Wrong parameter names when nesting Modules within flax transformations #3747
Hey @PhilipVinc, what is happening is that …
Ah, I see. Thanks for the answer. Is there an alternative workaround? What I actually want to achieve is to be able to build … However, the lifted … Ideally, the usage I'd like is for this to work:

```python
sub_subnet = nn.Dense(features=1)
sub_net = Net(subnet=sub_subnet)
net = VNet(subneta=sub_net)
```

Just as I can pass an already constructed module to a Module, I would also like to pass an already constructed module (with submodules) to …
@cgarciae, do you have an idea of how to fix this, or could you point us in the right direction? We could try contributing a PR.
Hi, I have a complex case where I nest different submodules inside each other, and this results in what I think are wrong parameter names.
MWE:
I would expect the network parameters to be stored as a nested dictionary that mirrors the subnetwork's structure, as follows:
but instead, the subnetwork's parameters are split into two blocks.
We noticed that the bug disappears if the `some_args` dictionary is removed and the keyword arguments are passed directly.

cc @Adrien-Kahn