Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove numba4jax, numba callback closures from operators #1747

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

PhilipVinc
Copy link
Member

@PhilipVinc PhilipVinc commented Mar 5, 2024

I wrote Numba4jax to allow HamiltonianSampler to work with operators back when jax could not call back into python.
This necessitated also to introduce those pesky get_conn_flattened_closure functions that returned numba closures such that numba4jax could embed them in a jax.jit context.

This was all shaky and ugly, plus I know that it often crashes on GPUs (@inailuig you confirm, right?)

This PR removes completely numba4jax dependency from netket, which was only used in HamiltonianSampler rules replacing it with a callback into python. I am not planning on supporting numba4jax in the future if possible, and this buys me more time.

This implementation is a bit slower than before on CPU (though for large models this should be negligible), and works fine on GPU (before it was crashing).
Testing should be done to ensure that this is ok.

Another plus of this PR is that it removes those closure functions from operators, and makes it easier to define a custom operator that works fine with Hamiltonian samplers (while before it was actually complicated, as those closure are not documented and hard to create)

@PhilipVinc
Copy link
Member Author

This is in draft because I want some feedback about the change from anyone who used hamiltonian sampler.

Copy link

codecov bot commented Mar 5, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.98%. Comparing base (b4465f2) to head (a1286d7).
Report is 25 commits behind head on master.

❗ Current head a1286d7 differs from pull request most recent head 37421a9. Consider uploading reports for the commit 37421a9 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1747      +/-   ##
==========================================
+ Coverage   82.60%   82.98%   +0.38%     
==========================================
  Files         309      303       -6     
  Lines       18768    18365     -403     
  Branches     2757     2718      -39     
==========================================
- Hits        15503    15241     -262     
+ Misses       2581     2453     -128     
+ Partials      684      671      -13     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@PhilipVinc PhilipVinc marked this pull request as ready for review March 5, 2024 15:03
@inailuig
Copy link
Collaborator

Is this ready?

@PhilipVinc
Copy link
Member Author

This would be ready, but It leads to a somewhat severe performance penalty in some cases so I'm not sure we really want to merge it.

An alternative would be to add a third implementation of MetropolisHamiltonian which does not use numba4jax, and default to that one if the get_conn_closure is not implemented/raises...

Though I would like to remove the dependency, as well.
Not sure.

@inailuig
Copy link
Collaborator

inailuig commented May 1, 2024

see #1777 (comment)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants