Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize independent and dependent caches separately in ARNN #1656

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

* Recurrent neural networks and layers have been added to `nkx.models` and `nkx.nn` [#1305](https://github.com/netket/netket/pull/1305).
* Added experimental support for running NetKet on multiple jax devices (as an alternative to MPI). It is enabled by setting the environment variable/configuration flag `NETKET_EXPERIMENTAL_SHARDING=1`. Parallelization is achieved by distributing the Markov chains / samples equally across all available devices utilizing [`jax.Array` sharding](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html). On GPU, multi-node setups are supported via [jax.distributed](https://jax.readthedocs.io/en/latest/multi_process.html), whereas on CPU it is limited to a single process, but several threads can be used by setting `XLA_FLAGS='--xla_force_host_platform_device_count=XX'` [#1511](https://github.com/netket/netket/pull/1511).
* Caches in ARNN that depend on model parameters can be initialized in {meth}`~netket.models.AbstractARNN._init_dependent_cache` [#1656](https://github.com/netket/netket/pull/1656).

### Breaking Changes

Expand Down
36 changes: 35 additions & 1 deletion netket/models/autoreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from netket.nn import MaskedConv1D, MaskedConv2D, MaskedDense1D
from netket.nn.masked_linear import default_kernel_init
from netket.nn import activation as nkactivation
from netket.utils.types import Array, DType, NNInitFunc
from netket.utils.types import Array, DType, NNInitFunc, PRNGKeyT, PyTree

Check warning on line 29 in netket/models/autoreg.py

View check run for this annotation

Codecov / codecov/patch

netket/models/autoreg.py#L29

Added line #L29 was not covered by tests
from netket.utils import deprecate_dtype


Expand Down Expand Up @@ -154,6 +154,40 @@
log_psi = log_psi.reshape((inputs.shape[0], -1)).sum(axis=1)
return log_psi

def _init_independent_cache(self, inputs: Array) -> None:
    """
    Initializes caches that do not depend on model parameters.

    By default this traces :meth:`conditional` on the first site, which
    causes flax to create any cache variables used during the forward pass.
    Subclasses may override this for caches that are independent of model
    parameters.

    Args:
        inputs: a batch of input configurations used to trace the model.
    """
    self.conditional(inputs, 0)

def _init_dependent_cache(self, inputs: Array) -> None:
    """
    Initializes caches that depend on model parameters, or on the
    independent caches created by
    :meth:`~netket.models.AbstractARNN._init_independent_cache`.

    Does nothing by default; subclasses may override it.

    Args:
        inputs: a batch of input configurations used to trace the model.
    """
    pass

def init_cache(self, variables: PyTree, inputs: Array, key: PRNGKeyT) -> PyTree:
    """
    Initializes the cache before sampling.

    Subclasses may override :meth:`~netket.models.AbstractARNN._init_independent_cache`
    for caches that are independent of model parameters or any cache,
    and :meth:`~netket.models.AbstractARNN._init_dependent_cache` for caches
    that depend on model parameters or those independent caches.

    When calling this method, `variables` should contain model parameters
    but not any cache.

    `_init_independent_cache` is called without providing the variables.
    When `_init_dependent_cache` is called, the variables contain model
    parameters and independent caches, but not dependent caches.
    """
    # Phase 1: build caches that need no parameters by tracing the model
    # from scratch.
    indep_vars = self.init(key, inputs, method=self._init_independent_cache)
    indep_cache = indep_vars.get("cache")
    if indep_cache:
        variables = {**variables, "cache": indep_cache}

    # Phase 2: with parameters and independent caches in scope, build the
    # parameter-dependent caches; only the "cache" collection is mutable.
    _, mutated = self.apply(
        variables, inputs, method=self._init_dependent_cache, mutable=["cache"]
    )
    return mutated.get("cache")

def reorder(self, inputs: Array, axis: int = 0) -> Array:
"""
Transforms an array from unordered to ordered.
Expand Down
7 changes: 1 addition & 6 deletions netket/sampler/autoreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,6 @@ def is_exact(sampler):
"""
return True

def _init_cache(sampler, model, σ, key):
    """
    Builds the model's cache collection by tracing `model.conditional`
    during `model.init`, and returns it (`None` if the model defines
    no cache).
    """
    variables = model.init(key, σ, 0, method=model.conditional)
    cache = variables.get("cache")
    return cache

def _init_state(sampler, model, variables, key):
    # The direct sampler keeps no state besides the PRNG key.
    return ARDirectSamplerState(key=key)

Expand Down Expand Up @@ -155,7 +150,7 @@ def scan_fun(carry, index):

# Initialize `cache` before generating a batch of samples,
# even if `variables` is not changed and `reset` is not called
cache = sampler._init_cache(model, σ, key_init)
cache = model.init_cache(variables_no_cache, σ, key_init)
if cache:
variables = {**variables_no_cache, "cache": cache}
else:
Expand Down
6 changes: 4 additions & 2 deletions test/models/test_autoreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,11 @@ def test_same(self, partial_model_pair, hilbert, param_dtype, machine_pow, skip)
model1 = partial_model_pair[0](hilbert, param_dtype, machine_pow)
model2 = partial_model_pair[1](hilbert, param_dtype, machine_pow)

key_spins, key_model = jax.random.split(jax.random.PRNGKey(0))
key_spins, key_model, key_cache = jax.random.split(jax.random.PRNGKey(0), 3)
spins = hilbert.random_state(key_spins, size=batch_size)
variables = model2.init(key_model, spins, 0, method=model2.conditional)
variables_no_cache = model1.init(key_model, spins)
cache = model2.init_cache(variables_no_cache, spins, key_cache)
variables = {**variables_no_cache, "cache": cache}

p1 = model1.apply(variables, spins, method=model1.conditionals)
p2 = model2.apply(variables, spins, method=model2.conditionals)
Expand Down