Skip to content

Commit

Permalink
Add mpi_sum_jax in nk.jax.expect (see Issue netket#1690)
Browse files Browse the repository at this point in the history
  • Loading branch information
alleSini99 committed Jan 11, 2024
1 parent 07fb86e commit 6ccf51b
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions netket/jax/_expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from netket.stats import statistics as mpi_statistics, mean as mpi_mean, Stats
from netket.utils.types import PyTree
from netket.utils.mpi import mpi_sum_jax

from ._vjp import vjp as nkvjp

Expand Down Expand Up @@ -95,6 +96,8 @@ def f(pars, σ, *cost_args):

_, pb = nkvjp(f, pars, σ, *cost_args)
grad_f = pb(dL̄)
grad_f = jax.tree_map(lambda x: mpi_sum_jax(x)[0], grad_f)

return grad_f


Expand Down

0 comments on commit 6ccf51b

Please sign in to comment.