[export] grad of exported function fails with device assignment error #21314
Labels: bug (Something isn't working)
Comments
gnecula added a commit to gnecula/jax that referenced this issue on May 20, 2024:
Currently, the export code uses a manufactured device assignment for exporting the VJP function. We should instead use the same device assignment that was used when exporting the primal function. Bug: google#21314
gnecula added a commit to gnecula/jax that referenced this issue on May 20, 2024:
Currently, the export code uses a manufactured device assignment for exporting the VJP function. We should instead use the same device assignment that was used when exporting the primal function. This PR fixes that for the case when the export is done through direct use of `jax.experimental.export`, and leaves as future work the case when it is used from `jax2tf`. We add a disabled test for the latter case. Bug: google#21314
In #21319 we fixed this for the direct use of `jax.experimental.export`.
gnecula added a commit to gnecula/jax that referenced this issue on May 29, 2024:
Native serialization needs to construct XLACompatibleShardings from the HloSharding stored in the `jax.export.Exported`. In google#21319 we fixed the device assignment for the `jax.export` APIs by keeping and reusing the device assignment from the primal function. Here we fix the same bug for `jax2tf`. Fixes: google#21314
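The fix described in these commits can be sketched with a toy model (pure Python; all names are hypothetical illustrations, not JAX internals): the primal export records the device assignment it was lowered with, and the VJP export reuses that assignment instead of manufacturing one from the first N device ids.

```python
# Toy sketch of the fix (hypothetical names, not JAX internals): keep the
# device assignment from the primal export and reuse it for the VJP export.

def export_primal(fn_name, device_assignment):
    # The exported artifact records the device assignment it was lowered with.
    return {"fn": fn_name, "device_assignment": list(device_assignment)}

def export_vjp(primal_exported):
    # Fixed behavior: reuse the primal's assignment instead of list(range(N)).
    return {
        "fn": f"vjp({primal_exported['fn']})",
        "device_assignment": primal_exported["device_assignment"],
    }

primal = export_primal("f", [3, 1, 2, 0])  # non-trivial device order
vjp = export_vjp(primal)
assert vjp["device_assignment"] == [3, 1, 2, 0]
```

With the assignment reused, the shardings recorded in the primal's HLO refer to the same device order that the VJP is lowered under, so the conflict described in the issue cannot arise.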
Description
The `jax.experimental.export` module exports the VJP using a synthetic mesh built from the first N devices of the export platform. This seemed reasonable because the exported artifact captures only the number of devices, not their ids. However, during lowering there can be a conflict between the device order used by the exporting code and the shardings present in the primal function, so exporting the VJP of a primal function that was exported with a non-default device assignment fails with a device assignment error.
Instead, exporting the VJP should use the same device assignment that was used when exporting the primal function.
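The original reproduction and traceback were not captured above, but the mechanism can be illustrated with a toy model (pure Python; all names and the error message are hypothetical stand-ins, not JAX internals): the primal is exported under a permuted device order, the VJP export manufactures the default 0..N-1 order, and a lowering-time consistency check rejects the mismatch.

```python
# Toy model of the bug (hypothetical names, not JAX internals): the primal
# export records a concrete device order, but the VJP export manufactures
# its own order from the first N device ids, and the two can disagree.

def manufactured_assignment(num_devices):
    # What the buggy VJP export does: always device ids 0..N-1.
    return list(range(num_devices))

def lower_vjp(primal_assignment, vjp_assignment):
    # Stand-in for the lowering-time check that rejects shardings whose
    # devices do not match the device assignment in use.
    if vjp_assignment != primal_assignment:
        raise ValueError(
            f"VJP device assignment {vjp_assignment} does not match the "
            f"primal's device assignment {primal_assignment}"
        )
    return "lowered"

primal_assignment = [3, 1, 2, 0]  # primal exported with a permuted order
try:
    lower_vjp(primal_assignment, manufactured_assignment(4))
except ValueError as e:
    print("VJP export failed:", e)
```

Passing the primal's own assignment to `lower_vjp` succeeds, which is exactly the behavior the fix restores.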
System info (python version, jaxlib version, accelerator, etc.)
Not needed