Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MPI] logsumexp #1495

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft

[MPI] logsumexp #1495

wants to merge 7 commits into from

Conversation

inailuig
Copy link
Collaborator

@inailuig inailuig commented Jun 9, 2023

Adds an mpi-aware logsumexp.

Will be useful for e.g. computing the statistics of importance sampling.

@codecov
Copy link

codecov bot commented Jun 9, 2023

Codecov Report

Merging #1495 (d55d1a1) into master (26372e3) will decrease coverage by 0.92%.
The diff coverage is 72.00%.

@@            Coverage Diff             @@
##           master    #1495      +/-   ##
==========================================
- Coverage   83.78%   82.86%   -0.92%     
==========================================
  Files         240      241       +1     
  Lines       13687    13742      +55     
  Branches     2076     2095      +19     
==========================================
- Hits        11467    11387      -80     
- Misses       1706     1823     +117     
- Partials      514      532      +18     
Impacted Files Coverage Δ
netket/utils/mpi/primitives.py 35.00% <0.00%> (-24.22%) ⬇️
netket/utils/mpi/_logsumexp.py 75.71% <75.71%> (ø)
netket/utils/mpi/__init__.py 100.00% <100.00%> (ø)

... and 21 files with indirect coverage changes

Also fix the docstring.

Previously one got the following errors:

mpi4jax:
```
MPI_ABORT was invoked on rank 0 in communicator MPI COMMUNICATOR 4 
CREATE FROM 0 with errorcode 10.
```
mpi4py:
``
mpi4py.MPI.Exception: MPI_ERR_OP: invalid reduce operation
```
reproducer:

```python
import netket as nk
import jax.numpy as jnp
x = 1.j
y = jnp.ones(2, dtype=complex)
#nk.utils.mpi.mpi_max(x)
#nk.utils.mpi.mpi_max_jax(x)
nk.utils.mpi.mpi_max_jax(y)
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

1 participant