Numerical differences between shardings in random algorithm #21232
This is fixed by upgrading to partitionable Threefry, e.g. by adding the following line to the top of the file (after imports): `jax.config.update('jax_threefry_partitionable', True)`. See #18480 for more on the upgrade (which was delayed a bit, but is still planned).
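As a sketch, enabling the flag looks like this (the seed and array shape are illustrative, not from the original report):

```python
import jax

# Opt in to the partitionable Threefry PRNG (see #18480). With this
# implementation, the random bits produced are independent of how the
# computation is sharded across devices.
jax.config.update('jax_threefry_partitionable', True)

# Keys and sampling are used exactly as before; only the underlying
# bit-generation scheme changes.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 8))
```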
IIUC this is a bug (unintended behavior) even with `jax_threefry_partitionable=False`, and also we don't yet know what's causing this bug. Good to know that setting `jax_threefry_partitionable=True` fixes it though!
Yes, I consider it a bug as well, but still undiagnosed.
Description
We are seeing numerical differences between shardings during random-number initialization on GPUs. For example, with a mesh of (DP, FSDP, TP), the numerical output of my initialization changes drastically depending on how many devices I allocate to each of these axes. As a result, when we use TP we see divergences in the network.
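A minimal way to see what is being compared (a sketch, not the reporter's code: it simulates eight devices on a CPU host via `XLA_FLAGS`, and the `dp`/`fsdp`/`tp` axis names and shapes are hypothetical) is to run the same seeded initialization under two different mesh layouts. With the partitionable Threefry flag enabled, the two layouts agree; on affected versions without it, they could differ:

```python
import os
# Simulate 8 devices on a CPU host so the sketch runs anywhere
# (the original report was on GPUs).
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# The fix discussed in the comments above.
jax.config.update('jax_threefry_partitionable', True)

def init_with_mesh(axis_sizes):
    # Hypothetical (dp, fsdp, tp) mesh; axis_sizes controls how the
    # 8 devices are split across the three axes.
    mesh = Mesh(np.array(jax.devices()).reshape(axis_sizes),
                ('dp', 'fsdp', 'tp'))
    init = jax.jit(lambda k: jax.random.normal(k, (8, 8)),
                   out_shardings=NamedSharding(mesh, P('dp', 'tp')))
    return np.asarray(init(jax.random.PRNGKey(0)))

# Same seed, two device layouts.
a = init_with_mesh((8, 1, 1))
b = init_with_mesh((2, 2, 2))
```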
System info (python version, jaxlib version, accelerator, etc.)