Having the same issue: `nn.tabulate()` with `compute_flops=True` or `compute_vjp_flops=True` does not work (MPS/Metal). The traceback ends with:
```
337 e = jax.jit(fn).lower(*args, **kwargs)
338 cost = e.cost_analysis()
--> 339 flops = int(cost['flops']) if 'flops' in cost else 0
340 return flops
TypeError: argument of type 'NoneType' is not iterable
```
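The `TypeError` means `e.cost_analysis()` returned `None` on this backend, so the membership test `'flops' in cost` cannot run. A defensive extraction could treat a missing analysis as zero FLOPs instead of crashing. The sketch below is a hypothetical helper (`flops_from_cost` is not part of Flax or JAX, and this is not Flax's actual fix), written under the assumption that a cost-analysis result may be `None`, a single dict, or a list of per-module dicts depending on the backend and JAX version.

```python
from collections import abc
from typing import Any, Mapping, Optional, Sequence, Union

# Assumed shapes of a cost-analysis result: None (no analysis available),
# a single dict, or a list of per-module dicts (older JAX versions).
CostAnalysis = Optional[
    Union[Mapping[str, Any], Sequence[Optional[Mapping[str, Any]]]]
]


def flops_from_cost(cost: CostAnalysis) -> int:
    """Return the FLOP estimate from a cost-analysis result, or 0 if unavailable."""
    if cost is None:
        # The backend provided no cost analysis (as in the MPS/Metal traceback).
        return 0
    if isinstance(cost, abc.Mapping):
        # Normalize a single dict to a one-element list.
        cost = [cost]
    total = 0.0
    for entry in cost:
        # Skip missing entries and entries without a 'flops' key.
        if entry is not None and "flops" in entry:
            total += float(entry["flops"])
    return int(total)


if __name__ == "__main__":
    print(flops_from_cost(None))                   # 0
    print(flops_from_cost({"flops": 128.0}))       # 128
    print(flops_from_cost([{"flops": 64.0}, {}]))  # 64
```

The argument would be the result of `jax.jit(fn).lower(*args, **kwargs).cost_analysis()`, as in the failing line. Reporting 0 keeps the summary from crashing but silently hides that no estimate was available, so a real fix might prefer to mark the FLOPs column as unavailable instead.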
Hello Flax Community,
In one of my projects, I was implementing a DINO projection head using Flax, and I ran into a problem. The problem occurs when I try to tabulate the DINO head.
In the function `init_model()`, the model parameters are generated and the model summary is printed using `nn.tabulate()`. If the `compute_flops` and `compute_vjp_flops` parameters of `nn.tabulate()` are set to `False`, there is no problem; the entire code works fine. However, when they are set to `True`, an error is raised. The error does not show up for the MLP, but it does for the DINO head.

I ran the code in Google Colab with the runtime set to CPU. While implementing the DINO head, I used the DINO repository: https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
How can I solve it?
What exactly causes it?
Thanks in advance.