The Error in FLOP Computation of Model Tabulate Function #3716

Open
GoktugGuvercin opened this issue Feb 24, 2024 · 1 comment
Assignees: cgarciae
Labels: Priority: P2 - no schedule. Best effort response and resolution; we have no plan to work on this at the moment.

GoktugGuvercin commented Feb 24, 2024

Hello Flax Community,

In one of my projects I was implementing the DINO projection head in Flax, and I ran into a problem. The problem occurs when I try to tabulate the DINO head.

In the function init_model(), the model parameters are generated and the model summary is printed with nn.tabulate(). If the compute_flops and compute_vjp_flops arguments of nn.tabulate() are set to False, there is no problem; the entire code works fine. However, when they are set to True, an error is raised. The error does not show up for MLP, but it does for DINOHead.

I executed the code in Google Colab on the CPU runtime. While implementing the DINO head, I followed the reference implementation in the DINO repository: https://github.com/facebookresearch/dino/blob/main/vision_transformer.py

How can I solve this?
What is the exact reason for it?

Thanks in advance.

import jax
import jax.numpy as jnp
import flax.linen as nn

from typing import Any
from dataclasses import field

def init_model(model: nn.Module, seed: int, input_shape: tuple, train_mode: bool, tabulate: bool = False):
    rng = jax.random.key(seed)
    sample_input = jnp.ones(input_shape)
    model_params = model.init(rng, sample_input, train=train_mode)

    if tabulate:
        tabulate_fn = nn.tabulate(model, rng, train=train_mode, compute_flops=True, compute_vjp_flops=True)
        print(tabulate_fn(sample_input))

    return model_params

def feature_normalizer(feat: jax.Array, p: int, axis: int, eps: float = 1e-12):
    # L2-style normalization as in DINO: divide by the p-norm, clamped from
    # below by eps to avoid division by zero (cf. torch.nn.functional.normalize).
    norm = jnp.linalg.norm(feat, ord=p, axis=axis, keepdims=True)
    norm = jnp.maximum(norm, eps)
    return feat / norm


class MLP(nn.Module):

    batch_norm: bool = False
    features: list = field(default_factory=lambda: [2048, 2048, 256])
    activation: Any = nn.gelu

    @nn.compact
    def __call__(self, x, train):

        for feat in self.features[:-1]:
            x = nn.Dense(feat)(x)
            if self.batch_norm:
                x = nn.BatchNorm(use_running_average=not train, axis=-1)(x)
            x = self.activation(x)

        out = nn.Dense(self.features[-1])(x)
        return out

class DINOHead(nn.Module):

    batch_norm: bool = False
    features: list = field(default_factory=lambda: [2048, 2048, 256])
    activation: Any = nn.gelu
    output_dim: int = 4096

    @nn.compact
    def __call__(self, x, train):
        x = MLP(self.batch_norm, self.features, self.activation)(x, train)
        x = feature_normalizer(x, 2, -1)
        x = nn.WeightNorm(nn.Dense(self.output_dim, use_bias=False), use_scale=False)(x)
        return x

train = True
init_seed = 13
input_shape = (4, 1280)

mlp = MLP(batch_norm=True,
          features=[2048, 2048, 256],
          activation=nn.gelu)

mlp_params = init_model(mlp, init_seed, input_shape, train, True)

dino_head = DINOHead(batch_norm=True, 
                     features=[2048, 2048, 256], 
                     activation=nn.gelu, 
                     output_dim=4096)

dino_head_params = init_model(dino_head, init_seed, input_shape, train, True)
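
For completeness, a minimal fallback consistent with the observation above (the error only appears when the FLOP options are enabled) is to tabulate without the FLOP columns until the underlying issue is fixed:

fallback_fn = nn.tabulate(dino_head, jax.random.key(init_seed), train=train,
                          compute_flops=False, compute_vjp_flops=False)
print(fallback_fn(jnp.ones(input_shape)))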

youurayy commented Mar 5, 2024

Having the same issue: tabulate with either compute_flops or compute_vjp_flops set to True is not working (MPS/Metal).

    337 e = jax.jit(fn).lower(*args, **kwargs)
    338 cost = e.cost_analysis()
--> 339 flops = int(cost['flops']) if 'flops' in cost else 0
    340 return flops

TypeError: argument of type 'NoneType' is not iterable
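
The failing lookup is the FLOP-estimation helper inside Flax's tabulate machinery: on backends where Lowered.cost_analysis() returns None (as with METAL here), the membership test 'flops' in cost raises the TypeError above. Below is a None-safe sketch of the same logic (estimate_flops is an illustrative name, not Flax's actual internal function):

import jax

def estimate_flops(fn, *args, **kwargs):
    lowered = jax.jit(fn).lower(*args, **kwargs)
    cost = lowered.cost_analysis()  # may be None on backends without cost modeling
    if not cost:
        return 0  # guard: avoids "argument of type 'NoneType' is not iterable"
    return int(cost.get('flops', 0))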

cgarciae self-assigned this Mar 6, 2024
chiamp added the Priority: P2 - no schedule label Mar 19, 2024