
the training_kwargs in VariationalStates are unused #1484

Open
vigsterkr opened this issue May 30, 2023 · 3 comments

Comments

@vigsterkr

vigsterkr commented May 30, 2023

There's a dictionary parameter for MCState called training_kwargs which, as I understand it, is meant to hold the variables that get passed to the model when it is trained. This parameter is stored as a frozen dictionary in MCState, but I don't see it being used anywhere in the code-base.

Is this parameter going to be deprecated, or was the feature just never finished? I'd like to use a dropout layer in my model, and I want to be able to pass an is_training boolean parameter when the model is being trained.
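To illustrate, here is a minimal Flax sketch (not NetKet code; the layer sizes and dropout rate are made up) of the kind of model in question, i.e. a dropout layer toggled by an is_training flag at apply time:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class MLPWithDropout(nn.Module):
    """Toy model: dropout is active only when is_training=True."""

    @nn.compact
    def __call__(self, x, *, is_training: bool = False):
        x = nn.relu(nn.Dense(16)(x))
        # deterministic=True disables dropout (evaluation mode).
        x = nn.Dropout(rate=0.1, deterministic=not is_training)(x)
        return nn.Dense(1)(x)


model = MLPWithDropout()
key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))
params = model.init({"params": key, "dropout": key}, x, is_training=True)

# Evaluation: dropout is disabled, no extra rng needed.
y_eval = model.apply(params, x, is_training=False)
# Training: dropout is active, so a "dropout" rng must be supplied.
y_train = model.apply(params, x, is_training=True, rngs={"dropout": key})
```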

@PhilipVinc
Member

> or was the feature just never finished?

That one.
I was planning to use it as a starting point to implement things like batch norm, but never got around to finishing it.
Would love for someone to actually implement it at some point.

You need to store this dictionary in the constructor.
Then pass it to the jitted kernel in the body of vstate.expect_and_grad/forces in here as a static argument.
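Roughly, a sketch of what that could look like (hypothetical names, not the actual NetKet kernels): since a dict is not hashable, the stored kwargs would be converted to a tuple of (name, value) pairs that jit can treat as static, so the compiled kernel is specialised on them:

```python
from functools import partial

import jax


# Hypothetical kernel, for illustration only (not the actual NetKet code path).
# Both the apply function and the training kwargs are marked static, so jit
# specialises the compiled kernel on them and re-traces whenever they change.
@partial(jax.jit, static_argnames=("apply_fun", "training_kwargs"))
def local_kernel(apply_fun, params, samples, training_kwargs=()):
    # e.g. training_kwargs == (("is_training", True),)
    return apply_fun(params, samples, **dict(training_kwargs))
```

The constructor would then store something hashable, e.g. tuple(sorted(training_kwargs.items())), and every call site inside expect_and_grad / forces would forward it to the kernel.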


vigsterkr commented May 30, 2023 via email

vigsterkr pushed a commit to vigsterkr/netket that referenced this issue May 30, 2023
add support for training_kwargs
vigsterkr pushed a commit to vigsterkr/netket that referenced this issue May 31, 2023
add support for training_kwargs
@vigsterkr
Author

@PhilipVinc a bit unrelated, but on the same topic: when using Haiku as the model framework, a layer that requires a PRNG (say a dropout layer) blows up on hk.next_rng_key() when apply() is called, because no PRNG is being passed. Do you have any quick-fix ideas for that?
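For reference, a standalone Haiku sketch of the pattern that fails (names and sizes made up): hk.next_rng_key() only works when apply() receives an rng as its second argument, otherwise the dropout branch raises:

```python
import haiku as hk
import jax
import jax.numpy as jnp


def forward(x, is_training: bool = False):
    x = jax.nn.relu(hk.Linear(16)(x))
    if is_training:
        # This raises if no rng was passed to apply().
        x = hk.dropout(hk.next_rng_key(), 0.1, x)
    return hk.Linear(1)(x)


model = hk.transform(forward)
key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))
params = model.init(key, x, is_training=True)

# Training: the rng goes in as the second positional argument of apply().
y_train = model.apply(params, key, x, is_training=True)
# Evaluation: the dropout branch is skipped, so rng=None is accepted.
y_eval = model.apply(params, None, x, is_training=False)
```

One workaround idea (just a guess, not something I know NetKet supports out of the box) would be to wrap the apply function in a closure that injects a PRNG key before the framework calls it.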

vigsterkr pushed a commit to vigsterkr/netket that referenced this issue Jun 14, 2023
add support for training_kwargs
vigsterkr pushed a commit to vigsterkr/netket that referenced this issue Jun 27, 2023
add support for training_kwargs