TDVP convergence issues in multi-process execution #1771
Comments
I guess this is because, by default, NetKet uses 16 chains per device (which, by the way, is extraordinarily inefficient for GPUs). So for 1000 samples and 1 device your chains have length 1000/16, with 100 steps of thermalisation after every change of parameters. With 2000 samples and 2 devices you have chains of length 2000/32 with 200 steps of thermalisation after every change of parameters, so it should work even better than the case above. With 1000 samples and 2 devices you have chains of length 1000/32 and 100 steps of thermalisation after every change of parameters.
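The arithmetic in the comment above can be sketched as follows (a minimal illustration assuming NetKet's default of 16 chains per device; the helper name is hypothetical, not part of NetKet's API):

```python
CHAINS_PER_DEVICE = 16  # NetKet's default, as noted above


def samples_per_chain(n_samples: float, n_devices: int) -> float:
    """Samples drawn from each Markov chain: total samples are split
    evenly across all chains on all devices."""
    n_chains = CHAINS_PER_DEVICE * n_devices
    return n_samples / n_chains


print(samples_per_chain(1000, 1))  # 62.5  (the working single-process run)
print(samples_per_chain(2000, 2))  # 62.5  (same per-chain length as above)
print(samples_per_chain(1000, 2))  # 31.25 (half the per-chain length)
```

This makes the symptom plausible: keeping `n_samples` fixed while doubling the number of devices halves the length of every chain, which degrades the convergence diagnostics.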
@andrea-lizzit can you share the parameters for the simulations you plotted above? The command-line arguments you used to launch them.
Here it is
Thanks. I'll try to find the time to look into it. However, a quick thing I noticed is that you're starting from a random state, which changes every run. Already setting `vstate = nk.vqs.MCState(sampler, rbm, n_samples=n_samples, seed=1234)` should ensure more reproducibility.
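The reproducibility point can be illustrated independently of NetKet: two Monte Carlo runs that share an explicit seed produce identical chains, while different seeds generally do not. A minimal sketch using Python's stdlib `random` as a stand-in for a seeded sampler (not NetKet's API):

```python
import random


def mc_chain(seed: int, n_steps: int = 100) -> list:
    """Toy random-walk chain whose initial state and proposals are both
    driven by an explicit seed."""
    rng = random.Random(seed)
    x = rng.gauss(0.0, 1.0)  # seeded random initial state
    chain = []
    for _ in range(n_steps):
        x += rng.uniform(-0.5, 0.5)  # seeded proposal
        chain.append(x)
    return chain


# Fixing the seed makes independent runs bit-for-bit identical:
assert mc_chain(seed=1234) == mc_chain(seed=1234)
```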
Distributing a job on different processes with MPI results in much worse values of the `R_hat` metric for TDVP, often producing incorrect results.

The plot below shows three simulation runs on the Ising model. The upper row shows an observable measured during time evolution; the bottom row shows the `R_hat` metric of `"Generator"`. Running the script on one process with `n_samples=1000` produces meaningful results (orange lines). When the same script is distributed between two processes, the Markov chain does not converge and measurement of the observable produces different values (blue lines). Increasing the samples to `n_samples=2000` on two processes restores the correct behaviour (green lines).

Example script:
The issue seems to persist when running on multiple GPUs with native jax parallelism.