
[export] grad of exported function fails with device assignment error #21314

Closed
gnecula opened this issue May 20, 2024 · 1 comment · Fixed by #21479

gnecula commented May 20, 2024

Description

jax.experimental.export exports the VJP using a synthetic mesh built from the first N devices of the export platform. This seemed reasonable because the exported artifact captures only the number of devices, not their ids. However, during lowering there can be a conflict between the device order used by the exporting code and the shardings present in the primal function.
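To illustrate the mismatch (a minimal sketch; the exporter's actual internal mesh construction may differ), the export path effectively builds a mesh from the devices in default order, while the primal function's shardings can reference the same devices in a different order:

    import jax
    from jax.sharding import Mesh

    devices = jax.local_devices()
    # Roughly what the exporter uses for the VJP: devices in
    # default order, e.g. ids [0, 1] on a two-device host.
    synthetic_mesh = Mesh(devices, "i")
    # What the primal function's shardings may use: the same
    # devices in a different order, e.g. ids [1, 0].
    primal_mesh = Mesh(list(reversed(devices)), "i")
    # During lowering, the VJP sees shardings from both orders,
    # which surfaces as the "incompatible devices" error below.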

For example, the following code fails:

    import jax
    import jax.numpy as jnp
    from jax.experimental import export, pjit
    from jax.sharding import Mesh, NamedSharding

    def f(x):
      return jnp.sum(x * 2.)

    mesh_rev = Mesh(list(reversed(jax.local_devices())), "i")
    shardings_rev = NamedSharding(mesh_rev, jax.sharding.PartitionSpec(("i",)))
    input_no_shards = jnp.ones(shape=(jax.local_device_count(),))
    input_rev = jax.device_put(input_no_shards, device=shardings_rev)
    exp_rev = export.export(pjit.pjit(f, in_shardings=shardings_rev))(input_no_shards)

    g = jax.grad(export.call(exp_rev))(input_rev)  # Failure here

The failure is:

E       ValueError: Received incompatible devices for pjitted computation. Got ARG_SHARDING with device ids [0, 1] on platform CPU and pjit inside pjit with device ids [1, 0] on platform CPU at /Users/necula/Source/jax/tests/export_test.py:1091:14 (JaxExportTest.test_grad_sharding_different_mesh)

Instead, the export of the VJP should use the same device assignment that was used for exporting the primal function.
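Until that fix lands, a possible workaround (a sketch, assuming the conflict is purely one of device order) is to build the primal shardings from devices in the default order, so that the exporter's synthetic mesh and the primal mesh agree:

    import jax
    import jax.numpy as jnp
    from jax.experimental import export, pjit
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    def f(x):
      return jnp.sum(x * 2.)

    # Same devices, but in the default jax.local_devices() order
    # that the exporter also uses for its synthetic mesh.
    mesh = Mesh(jax.local_devices(), "i")
    shardings = NamedSharding(mesh, PartitionSpec("i"))
    x = jax.device_put(jnp.ones((jax.local_device_count(),)), shardings)
    exp = export.export(pjit.pjit(f, in_shardings=shardings))(x)

    g = jax.grad(export.call(exp))(x)  # no device-order conflict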

System info (python version, jaxlib version, accelerator, etc.)

Not needed

@gnecula gnecula added the bug Something isn't working label May 20, 2024
@gnecula gnecula self-assigned this May 20, 2024
gnecula added a commit to gnecula/jax that referenced this issue May 20, 2024
Currently, the export code uses a manufactured device assignment
for exporting the VJP function. We should instead use the same
device assignment that was used when exporting the primal function.

This PR fixes that for the case when the export is done through
the direct use of `jax.experimental.export`, and leaves as future
work the case when the use is from `jax2tf`. We add a disabled
test for the latter case.

Bug: google#21314
gnecula commented May 23, 2024

In #21319 we fixed this for jax.experimental.export. When used from jax2tf, the bug is not yet fixed.

gnecula added a commit to gnecula/jax that referenced this issue May 29, 2024
Native serialization needs to construct XLACompatibleShardings from
the HloSharding stored in the `jax.export.Exported`.
In google#21319 we fixed the device assignment for the `jax.export`
APIs by keeping and reusing the device assignment of the
primal function. Here we fix the same bug for jax2tf.

Fixes: google#21314