
AttributeError: module 'jax.ops' has no attribute 'index_add' #1773

Open
cmosguy opened this issue Feb 26, 2023 · 1 comment


cmosguy commented Feb 26, 2023

Description

I am trying to do something basic in my code:

```python
import numpy as np              # regular ol' numpy
from trax import layers as tl   # core building block
from trax import shapes         # data signatures: dimensionality and type
from trax import fastmath       # uses jax, offers numpy on steroids
```

Even these basic imports fail with the error below. What am I doing wrong? Should I be pinning different package versions?

Environment information

OS: CentOS

```
$ lsb_release
LSB Version: :core-4.1-amd64:core-4.1-ia32:core-4.1-noarch:cxx-4.1-amd64:cxx-4.1-ia32:cxx-4.1-noarch:desktop-4.1-amd64:desktop-4.1-ia32:desktop-4.1-noarch:languages-4.1-amd64:languages-4.1-noarch:printing-4.1-amd64:printing-4.1-noarch

$ pip freeze | grep trax
trax==1.3.9

$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.2
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.30.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0

$ pip freeze | grep jax
jax==0.4.4
jaxlib==0.4.4

$ python -V
Python 3.9.16
```
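A quick way to confirm the mismatch in an environment like this one is to check whether the installed jax still exposes the attribute trax imports. A minimal diagnostic sketch, using only `importlib.metadata.version` and `hasattr`; the helper name `jax_has_index_add` is hypothetical and not part of jax or trax:

```python
from importlib.metadata import version

import jax
import jax.ops


def jax_has_index_add() -> bool:
    """Return True if this jax build still exposes jax.ops.index_add.

    Hypothetical helper for diagnosis only; it does not modify jax.
    """
    return hasattr(jax.ops, "index_add")


# Report installed versions alongside the availability check.
print("jax", version("jax"), "jaxlib", version("jaxlib"))
print("jax.ops.index_add available:", jax_has_index_add())
```

With jax 0.4.4 as listed above, this should report `False`, consistent with the `AttributeError` in the logs.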


### For bugs: reproduction and error logs

```
# Error logs:

...

      1 # coding=utf-8
      2 # Copyright 2021 The Trax Authors.
      3 #
   (...)
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     16 """Trax top level import."""
---> 18 from trax import data
     19 from trax import fastmath
     20 from trax import layers

File ./ds_work/miniconda3/envs/coursera-nlp/lib/python3.9/site-packages/trax/data/__init__.py:36, in <module>
     16 """Functions and classes for obtaining and preprocesing data.
     17 
     18 The ``trax.data`` module presents a flattened (no subpackages) public API.
   (...)
...
    217     'vjp': jax.vjp,
    218     'vmap': jax.vmap,
    219 }

AttributeError: module 'jax.ops' has no attribute 'index_add'
```

stephengineer commented Apr 10, 2023

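Downgrade jax to 0.2.21 (with a compatible jaxlib). `jax.ops.index_add` was deprecated in jax 0.2.22 and has since been removed, which is why trax 1.3.9 fails to import with jax 0.4.4:
https://gitee.com/mirrors/JAX/blob/main/CHANGELOG.md#jax-0222-oct-12-2021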
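If your own code (rather than trax's) used `jax.ops.index_add`, JAX's replacement is the functional `.at` update on arrays. A minimal sketch, assuming a recent jax; the array, indices, and values are illustrative only. Repeated indices accumulate, so position 2 receives both 2.0 and 3.0:

```python
import jax.numpy as jnp

x = jnp.zeros(5)
idx = jnp.array([0, 2, 2])
vals = jnp.array([1.0, 2.0, 3.0])

# Removed API: jax.ops.index_add(x, idx, vals)
# Replacement: .at[idx].add(vals) returns a NEW array;
# JAX arrays are immutable, so x itself is left unchanged.
y = x.at[idx].add(vals)
print(y)  # [1. 0. 5. 0. 0.]
```

This does not change trax 1.3.9 itself, which still references `jax.ops.index_add` at import time, so the version pin above is still needed to import it.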
