Initialize independent and dependent caches separately in ARNN #1656
base: master
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files:

@@           Coverage Diff           @@
##           master    #1656   +/-   ##
=======================================
  Coverage   82.24%   82.24%
=======================================
  Files         291      291
  Lines       17834    17842    +8
  Branches     3481     3482    +1
=======================================
+ Hits        14667    14675    +8
  Misses       2495     2495
  Partials      672      672

☔ View full report in Codecov by Sentry. |
@wdphy16 So this PR moves the responsibility of constructing the cache from the sampler itself to the model. I suppose this makes sense.
This also goes from having 1 'override point' to 2 different ones. Does this make sense? Can't we simply have one single init cache that does depend on the parameters, and expose a single 'extension point' to the user instead of 2?
Also, is there a reason for not following the standard flax philosophy of `model.apply(...)`? If I am not mistaken, this would make it easier to define and interweave the cache with the code?
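(For reference, the standard Flax convention mentioned here would look roughly like the sketch below, assuming a model instance `model` and inputs `x`; the `init_cache` name is the one this PR introduces, and the exact signature is an assumption, not the PR's actual code.)

```python
import jax

# Standard Flax style: everything goes through init()/apply(); extra
# variable collections such as "cache" are declared mutable when written.
variables = model.init(jax.random.PRNGKey(0), x)  # creates "params"
_, mutated = model.apply(
    variables, x, method=model.init_cache, mutable=["cache"]
)
cache = mutated["cache"]  # the freshly initialized cache collection
```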
Regardless, I would ask that with this PR we have a documentation page (could be an ipynb file plus a lot of discussion) explaining how to use this and giving at least one example implementation.
Otherwise this is a 'hidden feature' that benefits no-one but you, and if someone discovers it they will probably need to ask you anyway how to use it. I would like to avoid having lots of those 'hidden features'.
I wonder if the caching mechanism could be generalized to other samplers; for example, there are networks that benefit from precomputed quantities when you only do local updates... |
If there is no AR sampler, I guess the straightforward choice is to save those precomputed quantities in `SamplerState`, and actually I think it's also possible to move all the AR sampling caches to `SamplerState`.
A benefit of `SamplerState` is that it's not reset unless `MCState.variables` changes, so if we do sampling multiple times without changing variables, we don't need to initialize the cache multiple times.
But we still need an interface of the model to provide the initial cache to the sampler. |
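(A hypothetical sketch of that idea; the class and field names below are made up for illustration and are not NetKet's actual API. It only shows how precomputed quantities could ride along in the sampler state as a pytree.)

```python
import jax
import jax.numpy as jnp
from flax import struct


@struct.dataclass
class CachedSamplerState:
    """Illustrative sampler state carrying precomputed quantities."""

    rng: jax.Array            # PRNG key for the chain
    sigma: jnp.ndarray        # current configurations
    precomputed: jnp.ndarray  # e.g. partial contractions, rebuilt only
                              # when the variational parameters change
```

Since such a state would only be rebuilt when `MCState.variables` changes, the `precomputed` field would be initialized once and reused across sampling calls.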
It's not going to be easy and I would leave this effort aside for a future attempt. We would have to rewrite all layers to define a cache and how to use it. And while this works well for the first layer of a NN, I'm not sure this helps in the following layers?
|
Looking at it again after these months, now I think it's not very useful to define it this way. For a more general and trick-free way to initialize the cache, maybe we can discuss that when we have more use cases. |
As the PR for RNN is merged, now I can start upstreaming the changes in my branch for MPS-RNN into the master branch, and I'd like to split them into a few small and orthogonal PRs.
The motivation for this PR is that while the caches in ARNN are usually independent of the model parameters (such as the currently implemented fast AR sampling caches and RNN memories), some of them actually depend on the parameters (such as the `gamma` in MPS-RNN). The `gamma` is the partial contraction of the MPS, and it does not change during the AR sampling procedure, so we want to precompute it before the sampling, rather than recompute it in every AR sampling step.
The independent caches need to be initialized before providing the variables, while the dependent caches need to be initialized after `Module.setup()`. Now the user can separately override `AbstractARNN._init_independent_cache` and `_init_dependent_cache` for them. An example usage is in https://github.com/cqsl/mps-rnn/blob/master/models/mps.py .
Note that `model.init_cache` is called like `model.init`, rather than `model.apply(..., method=model.init_cache)`.
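To make the split concrete, here is a minimal Flax-only sketch of the pattern (the module body, shapes, and the `init_cache` wiring are illustrative assumptions, not NetKet's actual `AbstractARNN` implementation):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyARModel(nn.Module):
    """One parameter-independent and one parameter-dependent cache,
    both stored in the "cache" variable collection."""

    size: int = 8
    bond_dim: int = 4

    def setup(self):
        self.tensors = self.param(
            "tensors",
            nn.initializers.normal(stddev=0.1),
            (self.size, self.bond_dim, self.bond_dim),
        )

    def _init_independent_cache(self, inputs):
        # Never reads the parameters, so in principle it could be
        # initialized before the variables are provided.
        self.variable(
            "cache",
            "memory",
            lambda: jnp.zeros((inputs.shape[0], self.bond_dim)),
        )

    def _init_dependent_cache(self, inputs):
        # Reads self.tensors, so setup() must have run first: the analogue
        # of precomputing gamma as a partial contraction of the MPS.
        self.variable(
            "cache",
            "gamma",
            lambda: jnp.einsum("iab,iac->ibc", self.tensors, self.tensors),
        )

    def init_cache(self, inputs):
        self._init_independent_cache(inputs)
        self._init_dependent_cache(inputs)


model = ToyARModel()
x = jnp.zeros((2, 8))
# Called like model.init, as described above: the returned variables
# contain both the "params" and the "cache" collections.
variables = model.init(jax.random.PRNGKey(0), x, method=model.init_cache)
```

During sampling, a bound method could then read the precomputed value with `self.get_variable("cache", "gamma")`, and pass `mutable=["cache"]` to `apply` whenever a step needs to update the fast-sampling buffers.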