
Compatibility with multiprocessing / joblib - AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic' #198

Open
jaanli opened this issue Apr 5, 2024 · 4 comments

jaanli commented Apr 5, 2024

I have to strip the type hints from type-checked functions before they can be called inside joblib.Parallel or other multiprocessing pipelines; otherwise I get tracebacks like this:

joblib.externals.loky.process_executor._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/loky/process_executor.py", line 426, in _process_worker
    call_item = call_queue.get(block=True, timeout=timeout)
  File "/home/ray/anaconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class
    return _lookup_class_or_track(class_tracker_id, skeleton_class)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/externals/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track
    _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
  File "/home/ray/anaconda3/lib/python3.10/weakref.py", line 429, in __setitem__
    self.data[ref(key, self._remove)] = value
  File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 310, in __hash__
    return hash(cls._get_props())
  File "/home/ray/.venv/lib/python3.10/site-packages/jaxtyping/_array_types.py", line 295, in _get_props
    cls.index_variadic,
AttributeError: type object 'Float[Tensor, 'batch_size num_classes']' has no attribute 'index_variadic'
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/ray/evaluate_bm25_pytorch.py", line 314, in <module>
    main(cfg)
  File "/home/ray/evaluate_bm25_pytorch.py", line 295, in main
    trainer.evaluate(
  File "/home/ray/trainer.py", line 427, in evaluate
    predictions, objective = model_and_objective(batch)
  File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ray/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ray/evaluate_bm25_pytorch.py", line 134, in forward
    predictions = Parallel(n_jobs=self.n_jobs)(
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1952, in __call__
    return output if self.return_generator else list(output)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1595, in _get_outputs
    yield from self._retrieve()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1699, in _retrieve
    self._raise_error_fast()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 1734, in _raise_error_fast
    error_job.get_result(self.timeout)
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 736, in get_result
    return self._return_or_raise()
  File "/home/ray/.venv/lib/python3.10/site-packages/joblib/parallel.py", line 754, in _return_or_raise
    raise self._result
joblib.externals.loky.process_executor.BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
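
For context, here is a hypothetical minimal sketch of the kind of setup described above. The names, shapes, and per-batch work are illustrative rather than the actual code: a jaxtyping-annotated function handed to joblib.Parallel, so that the annotation classes have to be pickled into each worker.

```python
# Hypothetical sketch only -- names and shapes are illustrative, not the
# reporter's actual code. Run as a script, cloudpickle serialises `score`
# (including the Float[Tensor, ...] classes in its __annotations__) by value,
# which is presumably where the worker-side unpickling hits the
# AttributeError shown above.
import torch
from torch import Tensor
from jaxtyping import Float
from joblib import Parallel, delayed


def score(batch: Float[Tensor, "batch_size num_classes"]) -> Float[Tensor, "batch_size"]:
    # Toy stand-in for the real per-batch work.
    return batch.max(dim=-1).values


if __name__ == "__main__":
    batches = [torch.randn(4, 3) for _ in range(8)]
    predictions = Parallel(n_jobs=2)(delayed(score)(b) for b in batches)
```
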
patrick-kidger (Owner) commented:

Looks like these classes aren't getting de/serialised correctly, so the index_variadic attribute doesn't make it across.

If you can post an MWE, that'd be great. (Or a PR! The fix might just be to implement __setstate__ and __getstate__?)
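
A minimal sketch of the suspected failure mode follows. It is illustrative only, not jaxtyping's actual source, and it uses a defensive __hash__ purely to show the timing problem; it is not necessarily the __getstate__/__setstate__ route suggested above. The point is that cloudpickle rebuilds a dynamic class as a bare "skeleton" and registers it in a weak-keyed dictionary, hashing the class before its attributes are restored.

```python
# Illustrative stand-in metaclass, not jaxtyping's real code.
import pickle
import cloudpickle


class FragileMeta(type):
    def __hash__(cls):
        # Breaks if the class is hashed before `index_variadic` is restored,
        # which is what happens inside cloudpickle's skeleton-class rebuild.
        return hash((cls.__name__, cls.index_variadic))


class RobustMeta(type):
    def __hash__(cls):
        # A getattr fallback keeps hashing safe on a half-built class.
        return hash((cls.__name__, getattr(cls, "index_variadic", None)))


def make_dynamic(meta):
    # Created dynamically, so cloudpickle serialises the class by value.
    return meta("Dynamic", (), {"index_variadic": False})


ok = pickle.loads(cloudpickle.dumps(make_dynamic(RobustMeta)))
# pickle.loads(cloudpickle.dumps(make_dynamic(FragileMeta)))  # AttributeError, as above
```
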

@patrick-kidger patrick-kidger added the bug Something isn't working label Apr 5, 2024

sachith-gunasekara commented:

I'm facing this same issue when trying to save an optax optimizer state using cloudpickle. Hope this issue gets fixed.

File "/root/optimizer.py", line 89, in run_train_on_modal optimizer_state = optax.tree_utils.tree_set(optimizer_state, inner_state=cloudpickle.load(f)) ^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 831, in _make_skeleton_class return _lookup_class_or_track(class_tracker_id, skeleton_class) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.11/dist-packages/cloudpickle/cloudpickle.py", line 120, in _lookup_class_or_track _DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^ File "/usr/lib/python3.11/weakref.py", line 428, in __setitem__ self.data[ref(key, self._remove)] = value ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 321, in __hash__ return hash(cls._get_props()) ^^^^^^^^^^^^^^^^ File "/usr/local/lib/python3.11/dist-packages/jaxtyping/_array_types.py", line 306, in _get_props cls.index_variadic, ^^^^^^^^^^^^^^^^^^ AttributeError: type object 'Float[Array, '*shape']' has no attribute 'index_variadic'

patrick-kidger (Owner) commented:

Do you have a MWE?

sachith-gunasekara commented May 31, 2024

Yes, I'm training a model with JAX and Equinox, and I am trying to save the optimizer state.

```python
import cloudpickle
import equinox as eqx
import optax

lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == 'scratch' else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)
optimizer = optax.inject_hyperparams(optax.adamw)(
    learning_rate=lr_scheduler, b1=beta1, b2=beta2
)

optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))

checkpoint_params = {
    "optimizer_state": optimizer_state,
}

with open(checkpoint_params_file, "wb") as f:
    cloudpickle.dump(checkpoint_params, f)
```
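
One possible workaround, offered here as an assumption rather than a confirmed fix from this thread, is to avoid pickling the whole pytree and instead serialise only its array leaves with Equinox, so the jaxtyping annotation classes never pass through cloudpickle:

```python
# Workaround sketch, assuming the same `model`, `optimizer`, and
# `optimizer_state` as above; not a confirmed fix for this issue.
import equinox as eqx

# Write only the array leaves of the optimizer state to disk.
eqx.tree_serialise_leaves("optimizer_state.eqx", optimizer_state)

# To restore, rebuild a state with the same structure and load the leaves into it.
template_state = optimizer.init(eqx.filter(model, eqx.is_array))
optimizer_state = eqx.tree_deserialise_leaves("optimizer_state.eqx", template_state)
```
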
