Compatibility with multiprocessing / joblib - AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic' #198
Comments
Looks like they're not getting de/serialised correctly, so the [...] If you can open a MWE that'd be great. (Or a PR! The fix might just be to implement [...])
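For readers unfamiliar with this kind of de/serialisation failure, here is a minimal stdlib-only sketch of the general mechanism. It is not jaxtyping's implementation, and the thread does not show which fix the maintainer had in mind: `AnnotationMeta`, `make_annotation` and `_reduce_annotation` are hypothetical names made up for illustration. The sketch assumes the annotation classes are created on the fly by a factory, so `pickle` cannot find them by name, and shows one common way to make such classes round-trip by registering a reducer for their metaclass with `copyreg`.

```python
import copyreg
import pickle


class AnnotationMeta(type):
    """Hypothetical metaclass for annotation classes built on the fly."""


_CACHE = {}


def make_annotation(dtype, dims):
    """Hypothetical factory: build (or reuse) a class such as Float['batch classes']."""
    key = (dtype, dims)
    if key not in _CACHE:
        name = f"{dtype}[{dims!r}]"
        # The class exists only in this cache, not as a module-level attribute.
        _CACHE[key] = AnnotationMeta(name, (), {"dtype": dtype, "dims": dims})
    return _CACHE[key]


def _reduce_annotation(cls):
    # Tell pickle to rebuild the class by calling the factory again,
    # instead of looking the class up by name in its module.
    return make_annotation, (cls.dtype, cls.dims)


ann = make_annotation("Float", "batch_size num_classes")

# Without a reducer, pickle tries to store the class by reference and fails,
# because no module attribute has this generated name.
try:
    pickle.dumps(ann)
    failed_without_reducer = False
except (pickle.PicklingError, AttributeError):
    failed_without_reducer = True

# Register the reducer for every class whose metaclass is AnnotationMeta.
copyreg.pickle(AnnotationMeta, _reduce_annotation)

restored = pickle.loads(pickle.dumps(ann))
```

In the same process `restored` is the cached class itself. In a fresh worker process the factory would rebuild an equivalent class with all of its attributes, which is the property a `joblib` or `multiprocessing` pipeline needs. Whether the same approach carries over to `cloudpickle` is not verified here.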
I'm facing this same issue when trying to save an optax optimizer state using cloudpickle. Hope this issue gets fixed.
Do you have a MWE?
Yes, I'm training a model with JAX and Equinox, and I am trying to save the optimizer state. `lr_scheduler = optax.warmup_cosine_decay_schedule(` [...] `optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))` [...] `checkpoint_params = {` [...] `with open(checkpoint_params_file, "wb") as f:` [...]
I currently have to remove the type hints from type-checked functions that are called in `joblib.Parallel` or other multiprocessing pipelines; otherwise I get tracebacks with the `AttributeError` quoted in the title.