Checkpointing of a T5 model results in serialization error with new Jax 0.4.5 #262
Comments
Hi, I have the same problem. I think it is caused by an inconsistency between prompt-tuning and t5x. To fix it, I checked the release dates of the dependencies, the prompt-tuning repo, and the t5x repo, and found a combination that works.
There may be better solutions though.
Thanks for the response. I've tried these downgrades. I was able to get everything built properly in a Python env, but I'm now running into mysterious CUDA and NCCL errors that I wasn't getting before, which is a whole different issue. I'll see if I can make it work and report back. Ideally, this type of version matching wouldn't be necessary.
Flax removed optim in favor of optax in versions newer than 0.5.3. This means that in order to run the code in this repository, one needs to downgrade below Flax 0.6. However, if you do that with Jax 0.4.5, or even with Jax 0.3.25 + jax.config.update('jax_array', True), the code cannot save a model checkpoint because msgpack is unable to serialize the Jax arrays.
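To check whether an environment matches the combination described above, a small version check along these lines may help. This is only a sketch: the helper names (`parse_version`, `is_reported_combination`, `check_environment`) are hypothetical, and the thresholds (Flax below 0.6, Jax 0.4.5 or newer) are taken directly from this report rather than from any official compatibility table. It does not detect the Jax 0.3.25 case where `jax_array` is enabled manually.

```python
import re
from importlib import metadata


def parse_version(text):
    """Turn '0.4.5' or '0.5.3.dev0' into a tuple of its leading integers."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if not match:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def is_reported_combination(jax_version, flax_version):
    """True for Flax below 0.6 together with Jax 0.4.5 or newer.

    Thresholds are an assumption based on the versions named in this issue.
    """
    flax_is_old = parse_version(flax_version) < (0, 6)
    jax_is_new = parse_version(jax_version) >= (0, 4, 5)
    return flax_is_old and jax_is_new


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_environment():
    """Describe whether the installed jax/flax pair matches the report."""
    jax_version = installed_version("jax")
    flax_version = installed_version("flax")
    if jax_version is None or flax_version is None:
        return "jax and/or flax is not installed"
    if is_reported_combination(jax_version, flax_version):
        return f"jax {jax_version} + flax {flax_version}: matches the reported combination"
    return f"jax {jax_version} + flax {flax_version}: does not match the reported combination"


print(check_environment())
```

Running this in the environment used for training prints one line describing the installed pair, which can make it easier to compare setups when reporting results in this thread.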
Expected Behavior
Model should be able to be saved as a checkpoint.
Actual Behavior
Steps to Reproduce the Problem
With Jax 0.4.5 and Flax 0.5.3 one can minimally recreate this issue in a python repl as