Error occurs in nn.vmap when variable_axes is a nested dict #3751

Open
egg5154 opened this issue Mar 12, 2024 · 1 comment
Labels
Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment.

Comments

@egg5154

egg5154 commented Mar 12, 2024

Hello, I am trying to use nn.vmap to vectorize a module in Flax, where some of the parameters should be shared across the batch. Here is an example:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Test(nn.Module):

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(4)(x)
        return nn.Dense(2)(x)


BatchTest = nn.vmap(
    Test,
    in_axes=0,
    out_axes=0,
    # Intended: Dense_0 shared across the batch (None), Dense_1 batched along axis 0.
    variable_axes={'params': {'Dense_0': {'bias': None, 'kernel': None}, 'Dense_1': {'bias': 0, 'kernel': 0}}},
    split_rngs={'params': False},
)

_params = BatchTest().init({'params': jax.random.PRNGKey(42)}, jnp.zeros((3, 3)))
```

Here, the first Dense layer's parameters are meant to be shared across the batch, while the second layer's parameters should be vectorized along axis 0. However, I get the following error:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 9
      1 BatchTest = nn.vmap(
      2     Test,
      3     in_axes=0,
   (...)
      6     split_rngs={'params': False},
      7 )
----> 9 _params = BatchTest().init({'params': jax.random.PRNGKey(42)}, jnp.zeros((3, 3)))
    [... skipping hidden 11 frame]
File ~/miniconda3/envs/jax/lib/python3.9/site-packages/jax/_src/tree_util.py:243, in <listcomp>(.0)
    210 """Maps a multi-input function over pytree args to produce a new pytree.
    211
    212 Args:
   (...)
    240   [[5, 7, 9], [6, 1, 2]]
    241 """
    242 leaves, treedef = tree_flatten(tree, is_leaf)
--> 243 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    244 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
ValueError: Expected dict, got ({},).
```

Is there a way to solve this problem?
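For comparison, plain `jax.vmap` does accept a nested `in_axes` whose leaves mix `None` (broadcast, i.e. shared) and `0` (mapped), provided the spec's structure matches the argument it describes; the `flatten_up_to` call at the bottom of the traceback is the kind of pytree-structure check that enforces such a match. The sketch below is a minimal illustration, assuming the two-layer model is hand-written as a pure function; the `params` dictionary only mimics the Flax naming and is not what nn.vmap builds internally.

```python
import jax
import jax.numpy as jnp


def two_layer(params, x):
    # Hand-written two-layer MLP for a single example x of shape (3,).
    h = x @ params['Dense_0']['kernel'] + params['Dense_0']['bias']
    return h @ params['Dense_1']['kernel'] + params['Dense_1']['bias']


batch = 3
params = {
    # Shared across the batch: no leading batch axis.
    'Dense_0': {'kernel': jnp.ones((3, 4)), 'bias': jnp.zeros((4,))},
    # One copy per example: leading batch axis of size `batch`.
    'Dense_1': {'kernel': jnp.ones((batch, 4, 2)), 'bias': jnp.zeros((batch, 2))},
}
# Nested in_axes spec mirroring the params tree: None = broadcast, 0 = map.
param_axes = {
    'Dense_0': {'kernel': None, 'bias': None},
    'Dense_1': {'kernel': 0, 'bias': 0},
}

x = jnp.ones((batch, 3))
y = jax.vmap(two_layer, in_axes=(param_axes, 0))(params, x)
# y has shape (3, 2); with all-ones weights every entry is 12.0.
```

This works because `jax.vmap` sees the parameters as an ordinary argument whose structure is fully known when the spec is applied.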

@chiamp
Collaborator

chiamp commented Mar 13, 2024

I believe variable_axes cannot be a nested dictionary. At most, it can be a dictionary at the top level, keyed by collection name, for example: variable_axes={'params': 0}

@chiamp chiamp added the Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment. label Mar 19, 2024