Description
In Triton (if I have read this correctly), the masked-out elements of a load/store are never actually accessed. So you can request a load/store at an index that is out of bounds (OOB) for the ref, as long as that index is masked out. The current interpreter instead uses dynamic_slice/dynamic_update_slice and applies the mask to those slices. In line with the 'always be in bounds' design in JAX, a slice that overruns the edge of the array is shifted to be valid (if possible), so the interpreter reads or writes different elements than the ones requested. This leads to a mismatch between the interpreter output and the compiled Pallas output.
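A minimal sketch of the index shift described above, using plain `jax.lax.dynamic_slice` on a 1-D array. This is an illustration only: it is not the Pallas interpreter code and not the Colab reproduction, and the array size, start index, and fill value of 0 are assumptions chosen for the example.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(8)          # stand-in for the ref contents: [0, 1, ..., 7]
start, size = 6, 4         # requested indices 6..9; 8 and 9 are out of bounds

# Mask that marks which requested indices are in bounds.
idx = start + jnp.arange(size)
mask = idx < x.shape[0]    # [True, True, False, False]

# dynamic_slice clamps the start index so the whole slice fits in the array,
# so this reads indices 4..7 instead of 6..9.
block = lax.dynamic_slice(x, (start,), (size,))

# Applying the mask after the shifted slice keeps the wrong elements.
shifted = jnp.where(mask, block, 0)

# What a masked load is expected to return: the in-bounds elements at their
# requested positions, with the fill value where the mask is False.
safe_idx = jnp.clip(idx, 0, x.shape[0] - 1)
masked = jnp.where(mask, x[safe_idx], 0)

print(block)    # elements read after clamping
print(shifted)  # slice-then-mask result
print(masked)   # masked-load result
```

With these values the slice-then-mask result holds elements 4 and 5, while the masked load holds elements 6 and 7, which is the kind of disagreement reported here.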
I know Triton is not Pallas. Have you changed the desired behaviour for these cases in Pallas? If so, this isn't a bug, but it does need documenting.
I've opened a pull request, #21298, that fixes this and adds some tests.
Here is a minimal Colab reproduction showing the shifted load indices.
System info (python version, jaxlib version, accelerator, etc.)
(Problem persists in 0.4.28)