Checkpointing of a T5 model results in serialization error with new Jax 0.4.5 #262

Open
emersodb opened this issue Mar 11, 2023 · 2 comments

Comments

@emersodb

Flax removed its optim module in favor of optax in versions after 0.5.3, so running the code in this repository requires downgrading Flax below 0.6. However, with a downgraded Flax and either Jax 0.4.5 or Jax 0.3.25 with `jax.config.update('jax_array', True)`, the code cannot save a model checkpoint because msgpack is unable to serialize the jax arrays.

Expected Behavior

Model should be able to be saved as a checkpoint.

Actual Behavior

image

Steps to Reproduce the Problem

With Jax 0.4.5 and Flax 0.5.3, the issue can be reproduced minimally in a Python REPL:

image

@wzq016

wzq016 commented Mar 17, 2023

Hi, I have the same problem. I think it is caused by a version inconsistency between prompt-tuning and t5x. To fix it, I checked the release dates of the dependencies, the prompt-tuning repo, and the t5x repo, and found a combination that works:

PT version: 287949b546999d34eafc3770fc1b2320074912b3
T5X version: e5f61889114b2cb5bbfa916eb1ec35e6767427a0
jax==0.3.15, jax[tpu]==0.3.15
libtpu-nightly==0.1.dev20220722
jaxlib==0.3.15
flax==0.5.3
clu==0.0.7
numpy==1.23.4
orbax==0.0.4
protobuf==3.19.6
seqio==0.0.8

There may be better solutions though.
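
Because the fix above depends on an exact set of pins, it can help to confirm what is actually installed before debugging further. The following is a minimal sketch that uses only the standard library; the `PINS` dictionary is copied from the list above, and the helper names `installed_version` and `find_mismatches` are hypothetical, not part of prompt-tuning or t5x. It assumes the distribution names on PyPI match the names in the list.

```python
from importlib import metadata

# Pins copied from the working combination listed above.
# libtpu-nightly only matters on TPU hosts and can be dropped elsewhere.
PINS = {
    "jax": "0.3.15",
    "jaxlib": "0.3.15",
    "libtpu-nightly": "0.1.dev20220722",
    "flax": "0.5.3",
    "clu": "0.0.7",
    "numpy": "1.23.4",
    "orbax": "0.0.4",
    "protobuf": "3.19.6",
    "seqio": "0.0.8",
}


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pins, lookup=installed_version):
    """Map each package whose installed version differs from its pin
    to a tuple of (expected, found)."""
    mismatches = {}
    for package, expected in pins.items():
        found = lookup(package)
        if found != expected:
            mismatches[package] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for package, (expected, found) in find_mismatches(PINS).items():
        print(f"{package}: expected {expected}, found {found or 'not installed'}")
```

Running the script in the target environment prints one line per package that deviates from the pins and prints nothing if everything matches. It only compares version strings; it does not check the prompt-tuning or t5x commit hashes.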

@emersodb
Author

Thanks for the response. I've tried these downgrades. I was able to get everything built properly in a Python env, but I'm now running into mysterious CUDA and NCCL errors that I wasn't getting before, which is a whole different issue. I'll see if I can make it work and report back. Ideally, this type of version matching wouldn't be necessary.
