Error occurs in nn.vmap while variable_axes is a nested dict #3751
Labels: Priority: P2 - no schedule (best effort response and resolution; no plan to work on this at the moment)
Hello, I am trying to use `nn.vmap` to vectorize a module in Flax, where some of the parameters are shared across the batch. Here is an example:

In this example, the first Dense layer's parameters are expected to be shared across the batch. However, I got the following error:
Is there a way to solve this problem?
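For comparison, here is a minimal sketch of the same pattern in plain JAX, outside Flax. It is not the issue author's code: the function `mlp` and the keys `shared` and `per_example` are hypothetical, and it assumes all parameters can be held in one nested dict. In Flax, `variable_axes` maps variable collections such as `'params'` to an axis. `jax.vmap`, by contrast, accepts a nested `in_axes` pytree, so `None` can broadcast one subtree while `0` maps another. Within Flax, a comparable workaround (an assumption, not confirmed in this issue) is to keep the shared layer outside `nn.vmap` and lift only the layers whose parameters should differ per batch element.

```python
import jax
import jax.numpy as jnp

def mlp(params, x):
    # Hypothetical two-layer model for a single example x of shape (d_in,).
    # params["shared"] is broadcast: every batch element sees the same weights.
    h = jnp.tanh(x @ params["shared"]["w"] + params["shared"]["b"])
    # params["per_example"] has a leading batch axis that vmap slices off.
    return h @ params["per_example"]["w"] + params["per_example"]["b"]

batch, d_in, d_hidden, d_out = 3, 4, 5, 2
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "shared": {
        "w": jax.random.normal(k1, (d_in, d_hidden)),
        "b": jnp.zeros((d_hidden,)),
    },
    "per_example": {
        "w": jax.random.normal(k2, (batch, d_hidden, d_out)),
        "b": jnp.zeros((batch, d_out)),
    },
}
xs = jax.random.normal(k3, (batch, d_in))

# Nested in_axes: None broadcasts the "shared" subtree,
# 0 maps over the leading axis of "per_example" and of xs.
batched = jax.vmap(mlp, in_axes=({"shared": None, "per_example": 0}, 0))
out = batched(params, xs)
print(out.shape)  # (3, 2)
```

Because `None` in `in_axes` marks a whole subtree as unmapped, only one copy of the shared weights is stored, while the per-example weights carry the batch dimension explicitly.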