
Initialize independent and dependent caches separately in ARNN #1656

Status: Draft (wants to merge 3 commits into master)

Conversation

@wdphy16 (Collaborator) commented Nov 24, 2023

Now that the PR for RNN has been merged, I can start upstreaming the changes from my MPS-RNN branch into the master branch, and I'd like to split them into a few small, orthogonal PRs.

The motivation for this PR: while the caches in ARNN are usually independent of the model parameters (such as the currently implemented fast AR sampling caches and RNN memories), some of them actually depend on the parameters (such as the gamma in MPS-RNN). The gamma is a partial contraction of the MPS, and it does not change during the AR sampling procedure, so we want to precompute it once before sampling rather than recompute it at every AR sampling step.
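To illustrate why precomputing pays off, here is a dependency-free toy sketch (hypothetical code, not NetKet or MPS-RNN) of caching the partial contractions to the right of each MPS site: they depend only on the parameters, so they can be computed once before sampling instead of at every autoregressive step.

```python
# Hypothetical, dependency-free sketch (not NetKet or MPS-RNN code):
# the contractions to the right of each site depend only on the model
# parameters, so they belong in a cache built once before sampling.

def matmul(a, b):
    """Multiply two square matrices given as nested lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def transfer_matrix(site):
    """sum_s A[s] @ A[s]^T for one site's real-valued matrices A[s]."""
    n = len(site[0])
    acc = [[0.0] * n for _ in range(n)]
    for a in site:
        a_t = [[a[j][i] for j in range(n)] for i in range(n)]
        prod = matmul(a, a_t)
        acc = [[acc[i][j] + prod[i][j] for j in range(n)] for i in range(n)]
    return acc

def right_environments(sites):
    """For each site k, cache the contraction of all sites to its right.

    This plays the role of the parameter-dependent cache (the 'gamma'):
    it is fixed during sampling, so it is precomputed, not recomputed
    at every sampling step.
    """
    n = len(sites[0][0])
    env = [[float(i == j) for j in range(n)] for i in range(n)]  # identity
    envs = [env]
    for site in reversed(sites):
        env = matmul(transfer_matrix(site), env)
        envs.append(env)
    envs.reverse()
    return envs
```

During sampling, looking up `envs[k]` is a constant-time read instead of an O(system size) recontraction per step.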

The independent caches need to be initialized before the variables are provided, while the dependent caches need to be initialized after Module.setup(). The user can now override AbstractARNN._init_independent_cache and AbstractARNN._init_dependent_cache separately for them. An example usage is in https://github.com/cqsl/mps-rnn/blob/master/models/mps.py .

Note that model.init_cache is called like model.init, rather than model.apply(..., method=model.init_cache).
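To make the split concrete, here is a minimal dependency-free sketch of the two override points. All names here are schematic stand-ins mirroring the PR; the real NetKet/Flax classes and signatures differ.

```python
# Schematic sketch; names mirror the PR but the real NetKet/Flax
# signatures differ. This is not NetKet code.

class AbstractARNN:
    def _init_independent_cache(self, batch):
        # Override for caches that do NOT depend on the parameters
        # (fast AR sampling caches, RNN memories); can run before the
        # variables are provided.
        return {}

    def _init_dependent_cache(self, params, batch):
        # Override for caches that DO depend on the parameters
        # (e.g. the gamma in MPS-RNN); must run after setup().
        return {}

    def init_cache(self, params, batch):
        # Called like model.init(...), returning a fresh cache
        # collection, rather than via model.apply(..., method=...).
        cache = dict(self._init_independent_cache(batch))
        cache.update(self._init_dependent_cache(params, batch))
        return cache

class ToyMPSRNN(AbstractARNN):
    def _init_independent_cache(self, batch):
        return {"memory": [0.0] * len(batch)}

    def _init_dependent_cache(self, params, batch):
        # Stand-in for gamma: computed from the parameters once,
        # then fixed for the whole sampling procedure.
        return {"gamma": sum(params["weights"])}
```

A model that only needs parameter-independent caches overrides just the first hook and leaves the second as a no-op.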


codecov bot commented Nov 24, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison: base (90c9f4f) 82.24% vs. head (1c5de8a) 82.24%.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1656   +/-   ##
=======================================
  Coverage   82.24%   82.24%           
=======================================
  Files         291      291           
  Lines       17834    17842    +8     
  Branches     3481     3482    +1     
=======================================
+ Hits        14667    14675    +8     
  Misses       2495     2495           
  Partials      672      672           


@gcarleo gcarleo requested a review from Z-Denis December 7, 2023 11:18
@PhilipVinc (Member) left a comment


@wdphy16 So this PR moves the responsibility of constructing the cache from the sampler itself to the model. I suppose this makes sense.

This also goes from having 1 'override point' to 2 different ones. Does this make sense? Can't we simply have one single init cache that does depend on the parameters, and expose a single 'extension point' to the user instead of 2?

Also, is there a reason for not following the standard flax philosophy of model.apply(...)? If I am not mistaken this would make it easier to define and interweave the cache with the code?

Regardless, I would ask that with this PR we have a documentation page (could be an ipynb file + a lot of discussion) explaining how to use this and giving at least one example of an implementation.
Otherwise this is a 'hidden feature' that benefits no one but you, and anyone who discovers it will probably need to ask you how to use it anyway.

I would like to avoid having lots of these 'hidden features'.

@gcarleo (Member) commented Dec 13, 2023

I wonder if the caching mechanism could be generalized to other samplers; for example, there are networks that benefit from precomputed quantities when you only do local updates...

@wdphy16 (Collaborator, Author) commented Dec 13, 2023

If there is no AR sampler, I guess the straightforward choice is to save those precomputed quantities in SamplerState, and actually I think it's also possible to move all the AR sampling caches to SamplerState.

A benefit of SamplerState is that it's not reset unless MCState.variables changes, so if we sample multiple times without changing the variables, we don't need to initialize the cache multiple times.

But we still need an interface on the model to provide the initial cache to the sampler.
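A toy sketch of that idea (hypothetical names, not the actual NetKet SamplerState API): the sampler state holds the cache and rebuilds it only when it sees a different variables object, while the model still exposes the interface that provides the initial cache.

```python
# Hypothetical sketch, not the actual NetKet SamplerState API: keep the
# precomputed cache in the sampler state and rebuild it only when the
# model variables change.

class ToyModel:
    def __init__(self):
        self.cache_inits = 0  # count how often the cache is rebuilt

    def init_cache(self, variables):
        # The interface on the model that provides the initial cache.
        self.cache_inits += 1
        return {"gamma": sum(variables)}

class ToySamplerState:
    def __init__(self):
        self._cache = None
        self._seen = None

    def get_cache(self, model, variables):
        # Reinitialize only when a different variables object arrives,
        # mimicking 'not reset unless MCState.variables changes'.
        if self._seen is not variables:
            self._cache = model.init_cache(variables)
            self._seen = variables
        return self._cache
```

Sampling repeatedly with unchanged variables then reuses the cached quantities for free.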

@PhilipVinc (Member) commented Dec 13, 2023 via email

@wdphy16 wdphy16 marked this pull request as draft December 13, 2023 10:00
@wdphy16 (Collaborator, Author) commented Dec 13, 2023

Looking at it again after these months, I now think it's not very useful to define _init_dependent_cache, as the user can just initialize the dependent cache inside a conditional (if they know the trick).

For a more general way to initialize the cache without such tricks, maybe we can discuss that when we have more use cases.
