
MinSR regularisation #1696

Open
attila-i-szabo opened this issue Jan 19, 2024 · 8 comments
Labels
contributor welcome (We welcome contributions or PRs on this issue) · enhancement (New feature or request)

Comments

@attila-i-szabo
Collaborator

I've just seen this paper. The algorithm they describe is interesting too, and we might want to consider supporting it, but there's a more immediate point. In Sec. 3.2, they explain that the MinSR matrix is guaranteed to have (1,1,...,1) as a zero eigenvector, so regularising it (a lot) is a good idea to protect against numerical instability (in exact arithmetic, adding any multiple of the all-ones matrix makes no difference).

Does NetKet already do something similar?
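For reference, a quick numerical check of this property (my own sketch, assuming the usual MinSR construction where the N_s x N_s matrix is built from mean-centred log-derivatives; none of these names are NetKet internals):

    import numpy as np
    import jax.numpy as jnp

    N_s, N_p = 32, 8                            # toy numbers of samples and parameters
    O = jnp.asarray(np.random.randn(N_s, N_p))  # stand-in for the log-derivative matrix
    O_bar = O - O.mean(axis=0)                  # centre each column over samples
    T = O_bar @ O_bar.conj().T / N_s            # MinSR matrix in sample space
    print(jnp.abs(T @ jnp.ones(N_s)).max())     # ~1e-15: (1,...,1) is a zero eigenvector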

attila-i-szabo added the enhancement and contributor welcome labels on Jan 19, 2024
@PhilipVinc
Member

PhilipVinc commented Jan 19, 2024

I was discussing it this morning as well. I think their algorithm is summarised well in this figure:

[screenshot: pseudocode of the algorithm from the paper]

The major difference is line 8 of the algorithm, which requires replacing the local energies with something that depends on the previous gradient.
This can easily be implemented in another driver by combining the local energies and the jacobian here.

I think the regularisation you're talking about is at line 9, where they also add a full matrix of 1/Ns for increased numerical stability. I'm unsure if this helps in general...

Then, they add momentum and gradient clipping which is easy to do with optax.
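For what it's worth, that part would look roughly like this with optax (a sketch; the hyperparameter values are placeholders, not the paper's):

    import optax

    # Chain gradient clipping with SGD-with-momentum; values are illustrative only.
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.sgd(learning_rate=0.02, momentum=0.9),
    )

This could then be handed to the driver like any other optax optimizer.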

@PhilipVinc
Member

cc @riccardo-rende @llviteritti

@attila-i-szabo
Collaborator Author

attila-i-szabo commented Jan 19, 2024

> Also, a little detail, line 9 they also add a full matrix of 1/Ns for increased numerical stability. Unsure if this helps in general...

Yeah, I was mostly talking about this - it should improve MinSR too, and it seems like a quick change for someone who's worked on that code.

EDIT: I've just noticed the second half of your comment. I think it could be an improvement - this is guaranteed to be a zero eigenvector, so if you eliminate it, MinSR might be stable with a lower diag_shift, which should mean faster convergence.

EDIT 2. I also suppose we'd recycle the MinSR code if we ever implement the new algorithm, so it makes sense to add it anyway.

@PhilipVinc
Member

I think it's as simple as replacing this line
with

    matrix = (
        matrix
        + diag_shift * jnp.eye(matrix_side)
        + jnp.full(matrix.shape, 1 / N_mc)
    )
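
For context, a rough sketch of how the regularised matrix would then be used (placeholder names like local_energies and O_bar, not NetKet's actual variables):

    # Solve the small N_mc x N_mc system and project back with the centred Jacobian.
    alpha = jnp.linalg.solve(matrix, local_energies)
    updates = O_bar.conj().T @ alpha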

@llviteritti
Contributor

Probably the easiest thing is to implement it in another driver, starting from VMC_SRt and taking into account line 8, and then, as @PhilipVinc said, adding momentum and gradient clipping with optax. It would be interesting to see whether it actually gives an improvement over vanilla MinSR for a lattice model.

@attila-i-szabo
Collaborator Author

> Then, they add momentum and gradient clipping which is easy to do with optax.

Is this something (especially momentum) that we can encapsulate into a new driver? The updates would be wrong without adding the momentum, so it wouldn't make sense to leave it to the user to add it with an external momentum driver in optax (especially since conventions differ, so it would be an invitation for errors). Personally, I also don't see the point of using a separate library for this particular line of the algorithm, which is basically a single FMA.
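
For concreteness, something like this inside the driver (a generic momentum convention; the paper's exact one may differ, and the names are illustrative):

    velocity = beta * velocity + dtheta          # the "single FMA"
    params = params - learning_rate * velocity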

Gradient clipping is a different story; I think it can be left up to users whether they like it and want to do it themselves in optax.

@hendrydouglas11

hendrydouglas11 commented Feb 14, 2024

The addition of the projector P = (1/N) one_vec one_vec^T should be scaled: OO^T -> OO^T + cP takes one_vec from eigenvalue 0 to eigenvalue c, but for c = 1, if the other eigenvalues of OO^T are all much less than 1, it becomes the dominant eigenvalue and can leave the matrix still ill-conditioned (going from one extreme to the other). I instead rescaled by c = tr(OO^T)/N = avg(eig(OO^T)), which ensures that it becomes neither the largest nor the smallest eigenvalue. This worked well for me; the Kaczmarz/momentum parts also made a noticeable improvement over MinSR for the 1D Heisenberg model with RBMs.
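
In code, that rescaling is roughly (placeholder names, with T standing for OO^T):

    N = T.shape[0]
    c = jnp.trace(T) / N                         # average eigenvalue of T
    T_reg = T + c * jnp.full((N, N), 1.0 / N)    # add c * P, P = (1/N) ones ones^T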

@chrisrothUT
Collaborator

chrisrothUT commented Feb 28, 2024

> In Sec. 3.2, they explain that the MinSR matrix is guaranteed to have (1,1,...,1) as a zero eigenvector, so regularising it (a lot) is a good idea to protect against numerical instability (in exact arithmetic, adding any multiple of the all-ones matrix makes no difference).

@attila-i-szabo this is why it's better to regularize MinSR by ignoring the modes where the eigenvalues are below some threshold rather than adding a diag_shift. Just throwing away this zero mode works completely fine. Here's some pseudocode

elocs = jnp.matmul(evecs.conj().T, elocs)                              # rotate into the eigenbasis
elocs = jnp.where(evals / jnp.amax(evals) > rcond, elocs / evals, 0)   # keep only modes above the threshold
elocs = jnp.matmul(evecs, elocs)                                       # rotate back

I find that rcond can usually be set to 10^-12, and occasionally it has to be increased to 10^-6 or so.
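
Put together, a self-contained version might look like this (my sketch; T is the Hermitian N_s x N_s MinSR matrix, elocs the local-energy vector, and O_bar the centred Jacobian, all placeholder names rather than NetKet internals):

    evals, evecs = jnp.linalg.eigh(T)               # eigendecomposition of the Hermitian matrix
    elocs = jnp.matmul(evecs.conj().T, elocs)       # rotate into the eigenbasis
    keep = evals / jnp.amax(evals) > rcond          # drop the zero mode and other tiny modes
    elocs = jnp.where(keep, elocs / evals, 0.0)     # apply the pseudo-inverse
    elocs = jnp.matmul(evecs, elocs)                # rotate back
    updates = O_bar.conj().T @ elocs                # project to parameter space (MinSR step)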
