OWL-ViT playground imports #1043

Open
kaninaba94 opened this issue Apr 16, 2024 · 9 comments

@kaninaba94

kaninaba94 commented Apr 16, 2024

This is using the T4 GPU runtime. All the installs ran successfully, but the second cell failed:

AttributeError                            Traceback (most recent call last)
<ipython-input-4-2d9aa1f916a4> in <cell line: 10>()
      8 import numpy as np
      9 from scenic.projects.owl_vit import configs
---> 10 from scenic.projects.owl_vit.models import TextZeroShotDetectionModule
     11 
     12 from scenic.projects.owl_vit.notebooks import inference

/content/scenic/projects/owl_vit/models.py in <module>
     23 import jax.numpy as jnp
     24 import ml_collections
---> 25 from scenic.projects.owl_vit import layers
     26 from scenic.projects.owl_vit import matching_base_models
     27 from scenic.projects.owl_vit import utils

/content/scenic/projects/owl_vit/layers.py in <module>
     29 import numpy as np
     30 from scenic.model_lib.base_models import box_utils
---> 31 from scenic.projects.owl_vit import utils
     32 from scenic.projects.owl_vit.clip import layers as clip_layers
     33 from scenic.projects.owl_vit.clip import model as clip_model

/content/scenic/projects/owl_vit/utils.py in <module>
     22 import jax.numpy as jnp
     23 import numpy as np
---> 24 from scenic.train_lib import train_utils
     25 import scipy
     26 

/content/scenic/train_lib/train_utils.py in <module>
     34 import ml_collections
     35 import numpy as np
---> 36 import optax
     37 from scenic.common_lib import debug_utils
     38 from scenic.dataset_lib import dataset_utils

/usr/local/lib/python3.10/dist-packages/optax/__init__.py in <module>
     15 """Optax: composable gradient processing and optimization, in JAX."""
     16 
---> 17 from optax import contrib
     18 from optax import losses
     19 from optax import monte_carlo

/usr/local/lib/python3.10/dist-packages/optax/contrib/__init__.py in <module>
     19 from optax.contrib._complex_valued import split_real_and_imaginary
     20 from optax.contrib._complex_valued import SplitRealAndImaginaryState
---> 21 from optax.contrib._dadapt_adamw import dadapt_adamw
     22 from optax.contrib._dadapt_adamw import DAdaptAdamWState
     23 from optax.contrib._mechanic import MechanicState

/usr/local/lib/python3.10/dist-packages/optax/contrib/_dadapt_adamw.py in <module>
     25 from optax import tree_utils
     26 from optax._src import base
---> 27 from optax._src import utils
     28 
     29 

/usr/local/lib/python3.10/dist-packages/optax/_src/utils.py in <module>
     23 import jax
     24 import jax.numpy as jnp
---> 25 import jax.scipy.stats.norm as multivariate_normal
     26 
     27 from optax import tree_utils as otu

/usr/local/lib/python3.10/dist-packages/jax/scipy/stats/__init__.py in <module>
     16 # See PEP 484 & https://github.com/google/jax/issues/7570
     17 
---> 18 from jax.scipy.stats import bernoulli as bernoulli
     19 from jax.scipy.stats import beta as beta
     20 from jax.scipy.stats import binom as binom

/usr/local/lib/python3.10/dist-packages/jax/scipy/stats/bernoulli.py in <module>
     16 # See PEP 484 & https://github.com/google/jax/issues/7570
     17 
---> 18 from jax._src.scipy.stats.bernoulli import (
     19   logpmf as logpmf,
     20   pmf as pmf,

/usr/local/lib/python3.10/dist-packages/jax/_src/scipy/stats/bernoulli.py in <module>
     14 
     15 
---> 16 import scipy.stats as osp_stats
     17 
     18 from jax import lax

/usr/local/lib/python3.10/dist-packages/scipy/stats/__init__.py in <module>
    606 from ._warnings_errors import (ConstantInputWarning, NearConstantInputWarning,
    607                                DegenerateDataWarning, FitError)
--> 608 from ._stats_py import *
    609 from ._variation import variation
    610 from .distributions import *

/usr/local/lib/python3.10/dist-packages/scipy/stats/_stats_py.py in <module>
     35 from numpy import array, asarray, ma
     36 from numpy.lib import NumpyVersion
---> 37 from numpy.testing import suppress_warnings
     38 
     39 from scipy.spatial.distance import cdist

/usr/local/lib/python3.10/dist-packages/numpy/testing/__init__.py in <module>
      9 
     10 from . import _private
---> 11 from ._private.utils import *
     12 from ._private.utils import (_assert_valid_refcount, _gen_alignment_data)
     13 from ._private import extbuild

/usr/local/lib/python3.10/dist-packages/numpy/testing/_private/utils.py in <module>
     55 IS_PYSTON = hasattr(sys, "pyston_version_info")
     56 HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON
---> 57 HAS_LAPACK64 = numpy.linalg._umath_linalg._ilp64
     58 
     59 _OLD_PROMOTION = lambda: np._get_promotion_state() == 'legacy'

AttributeError: module 'numpy.linalg._umath_linalg' has no attribute '_ilp64'
@Chi-chicken

What's your version of numpy?

@DishantMewada

Getting the same error. I used to restart the kernel after installing the dependencies on Colab, and that used to work. But now, after running the first cell and restarting the kernel, I get the error below:

AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'

NumPy version: 1.26.4

Could someone please provide a Dockerfile or something similar? I have never managed to install it on my local machine.

@evantkchong

I encountered this issue in Colab as well. It seems similar to numpy/numpy#25150.

Running the first notebook cell:

!rm -rf *
!rm -rf .config
!rm -rf .git
!git clone https://github.com/google-research/scenic.git .
!python -m pip install -q .
!python -m pip install -r ./scenic/projects/owl_vit/requirements.txt

# Also install big_vision, which is needed for the mask head:
!mkdir /big_vision
!git clone https://github.com/google-research/big_vision.git /big_vision
!python -m pip install -r /big_vision/big_vision/requirements.txt
import sys
sys.path.append('/big_vision/')
!echo "Done."

results in the following Python environment:

absl-py==1.4.0
aiohttp==3.9.5
aiosignal==1.3.1
alabaster==0.7.16
albumentations==1.3.1
altair==4.2.2
annotated-types==0.6.0
anyio==3.7.1
appdirs==1.4.4
aqtp==0.7.2
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
array_record==0.5.1
arviz==0.15.1
astropy==5.3.4
astunparse==1.6.3
async-timeout==4.0.3
atpublic==4.1.0
attrs==23.2.0
audioread==3.0.1
autograd==1.6.2
Babel==2.14.0
backcall==0.2.0
beautifulsoup4==4.12.3
bidict==0.23.1
bigframes==1.2.0
bleach==6.1.0
blinker==1.4
blis==0.7.11
blosc2==2.0.0
bokeh==3.3.4
bqplot==0.12.43
branca==0.7.1
build==1.2.1
CacheControl==0.14.0
cachetools==5.3.3
catalogue==2.0.10
certifi==2024.2.2
cffi==1.16.0
chardet==5.2.0
charset-normalizer==3.3.2
chex==0.1.86
click==8.1.7
click-plugins==1.1.1
cligj==0.7.2
clip @ git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
cloudpathlib==0.16.0
cloudpickle==2.2.1
clu==0.0.12
cmake==3.27.9
cmdstanpy==1.2.2
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
community==1.0.0b1
confection==0.1.4
cons==0.4.6
contextlib2==21.6.0
contourpy==1.2.1
cryptography==42.0.5
cufflinks==0.17.3
cupy-cuda12x==12.2.0
cvxopt==1.3.2
cvxpy==1.3.4
cycler==0.12.1
cymem==2.0.8
Cython==3.0.10
dask==2023.8.1
datascience==0.17.6
db-dtypes==1.2.0
dbus-python==1.2.18
debugpy==1.6.6
decorator==4.4.2
defusedxml==0.7.1
distrax==0.1.5
distributed==2023.8.1
distro==1.7.0
dlib==19.24.4
dm-tree==0.1.8
docstring_parser==0.16
docutils==0.18.1
dopamine-rl==4.0.6
duckdb==0.10.2
earthengine-api==0.1.399
easydict==1.13
ecos==2.0.13
editdistance==0.6.2
eerepr==0.0.4
einops==0.8.0
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
entrypoints==0.4
et-xmlfile==1.1.0
etils==1.7.0
etuples==0.3.9
exceptiongroup==1.2.1
fastai==2.7.14
fastcore==1.5.29
fastdownload==0.0.7
fastjsonschema==2.19.1
fastprogress==1.0.3
fastrlock==0.8.2
filelock==3.13.4
fiona==1.9.6
firebase-admin==5.3.0
Flask==2.2.5
flatbuffers==24.3.25
flax==0.8.2
flaxformer @ git+https://github.com/google/flaxformer@399ea3a85e9807ada653fd0de1a9de627eb0acde
folium==0.14.0
fonttools==4.51.0
frozendict==2.4.2
frozenlist==1.4.1
fsspec==2023.6.0
ftfy==6.2.0
future==0.18.3
gast==0.5.4
gcsfs==2023.6.0
GDAL==3.6.4
gdown==5.1.0
geemap==0.32.0
gensim==4.3.2
geocoder==1.38.1
geographiclib==2.0
geopandas==0.13.2
geopy==2.3.0
gin-config==0.5.0
glob2==0.7
google==2.0.3
google-ai-generativelanguage==0.6.2
google-api-core==2.11.1
google-api-python-client==2.84.0
google-auth==2.27.0
google-auth-httplib2==0.1.1
google-auth-oauthlib==1.2.0
google-cloud-aiplatform==1.48.0
google-cloud-bigquery==3.12.0
google-cloud-bigquery-connection==1.12.1
google-cloud-bigquery-storage==2.24.0
google-cloud-core==2.3.3
google-cloud-datastore==2.15.2
google-cloud-firestore==2.11.1
google-cloud-functions==1.13.3
google-cloud-iam==2.15.0
google-cloud-language==2.13.3
google-cloud-resource-manager==1.12.3
google-cloud-storage==2.8.0
google-cloud-translate==3.11.3
google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz#sha256=3e056e666c1589f62dc076e4ca3199adf497d659a8bbfc3ed3234459fa688933
google-crc32c==1.5.0
google-generativeai==0.5.2
google-pasta==0.2.0
google-resumable-media==2.7.0
googleapis-common-protos==1.63.0
googledrivedownloader==0.4
graphviz==0.20.3
greenlet==3.0.3
grpc-google-iam-v1==0.13.0
grpcio==1.62.2
grpcio-status==1.48.2
gspread==3.4.2
gspread-dataframe==3.3.1
gym==0.25.2
gym-notices==0.0.8
h5netcdf==1.3.0
h5py==3.11.0
holidays==0.47
holoviews==1.17.1
html5lib==1.1
httpimport==1.3.1
httplib2==0.22.0
huggingface-hub==0.20.3
humanize==4.7.0
hyperopt==0.2.7
ibis-framework==8.0.0
idna==3.7
imageio==2.31.6
imageio-ffmpeg==0.4.9
imagesize==1.4.1
imbalanced-learn==0.10.1
imgaug==0.4.0
immutabledict==4.2.0
importlib_metadata==7.1.0
importlib_resources==6.4.0
imutils==0.5.4
inflect==7.0.0
iniconfig==2.0.0
intel-openmp==2023.2.4
ipyevents==2.0.2
ipyfilechooser==0.6.0
ipykernel==5.5.6
ipyleaflet==0.18.2
ipython==7.34.0
ipython-genutils==0.2.0
ipython-sql==0.5.0
ipytree==0.2.2
ipywidgets==7.7.1
itsdangerous==2.2.0
jax==0.4.26
jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.26+cuda12.cudnn89-cp310-cp310-manylinux2014_x86_64.whl#sha256=813cf1fe3e7ca4dbf5327d6e7b4fc8521e92d8bba073ee645ae0d5d036a25750
jeepney==0.7.1
jieba==0.42.1
Jinja2==3.1.3
joblib==1.4.0
jsonpickle==3.0.4
jsonschema==4.19.2
jsonschema-specifications==2023.12.1
jupyter-client==6.1.12
jupyter-console==6.1.0
jupyter-server==1.24.0
jupyter_core==5.7.2
jupyterlab_pygments==0.3.0
jupyterlab_widgets==3.0.10
kaggle==1.5.16
kagglehub==0.2.3
keras==3.3.3
keyring==23.5.0
kiwisolver==1.4.5
langcodes==3.3.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.4
libclang==18.1.1
librosa==0.10.1
lightgbm==4.1.0
linkify-it-py==2.0.3
llvmlite==0.41.1
locket==1.0.0
logical-unification==0.4.6
lvis==0.5.3
lxml==4.9.4
malloy==2023.1067
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.7.1
matplotlib-inline==0.1.7
matplotlib-venn==0.11.10
mdit-py-plugins==0.4.0
mdurl==0.1.2
miniKanren==1.0.3
missingno==0.5.2
mistune==0.8.4
mizani==0.9.3
mkl==2023.2.0
ml-collections==0.1.1
ml-dtypes==0.3.2
mlxtend==0.22.0
more-itertools==10.1.0
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.0.8
multidict==6.0.5
multipledispatch==1.0.0
multitasking==0.0.11
murmurhash==1.0.10
music21==9.1.0
namex==0.0.8
natsort==8.4.0
nbclassic==1.0.0
nbclient==0.10.0
nbconvert==6.5.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.3
nibabel==4.0.2
nltk==3.8.1
notebook==6.5.5
notebook_shim==0.2.4
numba==0.58.1
numexpr==2.10.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
oauth2client==4.1.3
oauthlib==3.2.2
opencv-contrib-python==4.8.0.76
opencv-python==4.8.0.76
opencv-python-headless==4.9.0.80
openpyxl==3.1.2
opt-einsum==3.3.0
optax @ git+https://github.com/google-deepmind/optax.git@7d43c5c0cc1ab343229c3c394b3179a1404e97e8
optree==0.11.0
orbax-checkpoint==0.4.4
osqp==0.6.2.post8
ott-jax==0.3.1
overrides==7.7.0
packaging==24.0
pandas==2.0.3
pandas-datareader==0.10.0
pandas-gbq==0.19.2
pandas-stubs==2.0.3.230814
pandocfilters==1.5.1
panel==1.3.8
panopticapi @ git+https://github.com/akolesnikoff/panopticapi.git@a698a12deb21e4cf0f99ef0581b2c30c466bf355
param==2.1.0
parso==0.8.4
parsy==2.1
partd==1.4.1
pathlib==1.0.1
patsy==0.5.6
peewee==3.17.3
pexpect==4.9.0
pickleshare==0.7.5
Pillow==9.4.0
pip-tools==6.13.0
platformdirs==4.2.0
plotly==5.15.0
plotnine==0.12.4
pluggy==1.5.0
polars==0.20.2
pooch==1.8.1
portpicker==1.5.2
prefetch-generator==1.0.3
preshed==3.0.9
prettytable==3.10.0
proglog==0.1.10
progressbar2==4.2.0
prometheus_client==0.20.0
promise==2.3
prompt-toolkit==3.0.43
prophet==1.1.5
proto-plus==1.23.0
protobuf==3.20.3
psutil==5.9.5
psycopg2==2.9.9
ptyprocess==0.7.0
py-cpuinfo==9.0.0
py4j==0.10.9.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycocotools @ git+https://github.com/cocodataset/cocoapi.git@8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9#subdirectory=PythonAPI
pycparser==2.22
pydantic==2.7.0
pydantic_core==2.18.1
pydata-google-auth==1.8.2
pydot==1.4.2
pydot-ng==2.0.0
pydotplus==2.0.2
PyDrive==1.3.1
PyDrive2==1.6.3
pyerfa==2.0.1.4
pygame==2.5.2
Pygments==2.16.1
PyGObject==3.42.1
PyJWT==2.3.0
pymc==5.10.4
pymystem3==0.2.0
PyOpenGL==3.1.7
pyOpenSSL==24.1.0
pyparsing==3.1.2
pyperclip==1.8.2
pyproj==3.6.1
pyproject_hooks==1.0.0
pyshp==2.3.1
PySocks==1.7.1
pytensor==2.18.6
pytest==7.4.4
python-apt @ file:///backend-container/containers/python_apt-0.0.0-cp310-cp310-linux_x86_64.whl#sha256=b209c7165d6061963abe611492f8c91c3bcef4b7a6600f966bab58900c63fefa
python-box==7.1.1
python-dateutil==2.8.2
python-louvain==0.16
python-slugify==8.0.4
python-utils==3.8.2
pytz==2023.4
pyviz_comms==3.0.2
PyWavelets==1.6.0
PyYAML==6.0.1
pyzmq==23.2.1
qdldl==0.1.7.post2
qudida==0.0.4
ratelim==0.1.6
referencing==0.34.0
regex==2023.12.25
requests==2.31.0
requests-oauthlib==1.3.1
requirements-parser==0.9.0
rich==13.7.1
rpds-py==0.18.0
rpy2==3.4.2
rsa==4.9
safetensors==0.4.3
scenic @ file:///content
scikit-image==0.19.3
scikit-learn==1.2.2
scipy==1.11.4
scooby==0.9.2
scs==3.2.4.post1
seaborn==0.13.1
SecretStorage==3.3.1
Send2Trash==1.8.3
sentencepiece==0.1.99
shapely==2.0.4
six==1.16.0
sklearn-pandas==2.2.0
smart-open==6.4.0
sniffio==1.3.1
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.5
soxr==0.3.7
spacy==3.7.4
spacy-legacy==3.0.12
spacy-loggers==1.0.5
Sphinx==5.0.2
sphinxcontrib-applehelp==1.0.8
sphinxcontrib-devhelp==1.0.6
sphinxcontrib-htmlhelp==2.0.5
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.7
sphinxcontrib-serializinghtml==1.1.10
SQLAlchemy==2.0.29
sqlglot==20.11.0
sqlparse==0.5.0
srsly==2.4.8
stanio==0.5.0
statsmodels==0.14.2
sympy==1.12
tables==3.8.0
tabulate==0.9.0
tbb==2021.12.0
tblib==3.0.0
tenacity==8.2.3
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tensorflow==2.16.1
tensorflow-cpu==2.16.1
tensorflow-datasets==4.9.4
tensorflow-estimator==2.15.0
tensorflow-gan==2.1.0
tensorflow-gcs-config==2.15.0
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.36.0
tensorflow-metadata==1.14.0
tensorflow-probability==0.23.0
tensorflow-text==2.16.1
tensorstore==0.1.45
termcolor==2.4.0
terminado==0.18.1
text-unidecode==1.3
textblob==0.17.1
tf-slim==1.1.0
tf_keras==2.16.0
tfds-nightly==4.9.4.dev202404300044
thinc==8.2.3
threadpoolctl==3.4.0
tifffile==2024.4.18
tinycss2==1.2.1
tokenizers==0.19.1
toml==0.10.2
tomli==2.0.1
toolz==0.12.1
torch @ https://download.pytorch.org/whl/cu121/torch-2.2.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=1adf430f01ff649c848ac021785e18007b0714fdde68e4e65bd0c640bf3fb8e1
torchaudio @ https://download.pytorch.org/whl/cu121/torchaudio-2.2.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=23f6236429e2bf676b820e8e7221a1d58aaf908bff2ba2665aa852df71a97961
torchdata==0.7.1
torchsummary==1.5.1
torchtext==0.17.1
torchvision @ https://download.pytorch.org/whl/cu121/torchvision-0.17.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=27af47915f6e762c1d44e58e8088d22ac97445668f9f793524032b2baf4f34bd
tornado==6.3.3
tqdm==4.66.2
traitlets==5.7.1
traittypes==0.2.1
transformers==4.40.0
triton==2.2.0
tweepy==4.14.0
typer==0.9.4
types-pytz==2024.1.0.20240417
types-setuptools==69.5.0.20240423
typing_extensions==4.11.0
tzdata==2024.1
tzlocal==5.2
uc-micro-py==1.0.3
uritemplate==4.1.1
urllib3==2.0.7
vega-datasets==0.9.0
wadllib==1.3.6
wasabi==1.1.2
wcwidth==0.2.13
weasel==0.3.4
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
Werkzeug==3.0.2
widgetsnbextension==3.6.6
wordcloud==1.9.3
wrapt==1.14.1
xarray==2023.7.0
xarray-einstats==0.7.0
xgboost==2.0.3
xlrd==2.0.1
xyzservices==2024.4.0
yarl==1.9.4
yellowbrick==1.5
yfinance==0.2.38
zict==3.0.0
zipp==3.18.1

This indicates that the installed version of NumPy should be 1.26.4. However, running print(np.__version__) shows that the version actually in use is 1.25.2.
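A minimal sketch for spotting this kind of mismatch from inside the notebook: it compares the version recorded in the installed package metadata (what pip reports) with the __version__ of the module object the kernel has already imported. The helpers installed_vs_imported and needs_restart are hypothetical names, not part of scenic or NumPy. My assumption, not confirmed in this thread, is that Colab had already imported NumPy 1.25.2 before pip upgraded it, so the running kernel combines the old compiled numpy.linalg._umath_linalg with the newer numpy/testing files on disk, which would explain the missing _ilp64 attribute.

```python
import importlib
import importlib.metadata
from typing import Optional, Tuple


def installed_vs_imported(
    dist_name: str, module_name: str
) -> Tuple[Optional[str], Optional[str]]:
    """Return (version in installed metadata, __version__ of the imported module).

    importlib.import_module returns the module already cached in sys.modules
    if it was imported earlier in this process, so the second value reflects
    what the running kernel actually uses, not what pip last installed.
    """
    try:
        on_disk = importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        on_disk = None
    try:
        in_memory = getattr(importlib.import_module(module_name), "__version__", None)
    except ImportError:
        in_memory = None
    return on_disk, in_memory


def needs_restart(dist_name: str, module_name: str) -> bool:
    """True when both versions are known and they disagree."""
    on_disk, in_memory = installed_vs_imported(dist_name, module_name)
    return on_disk is not None and in_memory is not None and on_disk != in_memory


if __name__ == "__main__":
    disk, mem = installed_vs_imported("numpy", "numpy")
    print(f"pip metadata: {disk}, numpy.__version__: {mem}")
    if needs_restart("numpy", "numpy"):
        print("Mismatch: restart the runtime so the newly installed NumPy is imported.")
```

If needs_restart("numpy", "numpy") returns True right after the install cell, restarting the runtime before the import cell should make both values agree.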

@Chi-chicken

> Getting the same error. I used to restart the kernel after installing dependencies on colab and it used to work. But now after running the first cell and restarting the kernel. Getting the below error -
>
> AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'
>
> Numpy version: Version: 1.26.4
>
> Can someone please make the docker file or something? I have never been successful in installing it on my local machine.

You can try to change PRNGKeyArray to PRNGKey. It worked for me.
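An alternative to editing files in site-packages is to restore the removed name at runtime before anything imports ott. This is a sketch under assumptions, not tested against the exact setup in this thread: ensure_alias is a hypothetical helper, and aliasing PRNGKeyArray to jax.Array is my choice (as far as I know, jax.Array is what JAX recommends for annotating keys since PRNGKeyArray was deprecated). It only helps if it runs before ott is imported, because ott evaluates the annotation when its module is loaded.

```python
import types
from typing import Any


def ensure_alias(module: types.ModuleType, name: str, value: Any) -> bool:
    """Bind module.<name> = value only if the attribute does not resolve.

    Returns True if the alias was added, False if the name already existed.
    Setting the attribute puts it in the module's __dict__, so a module-level
    __getattr__ (which JAX uses for removed names) is no longer consulted.
    """
    try:
        getattr(module, name)
    except AttributeError:
        setattr(module, name, value)
        return True
    return False


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("JAX is not installed; nothing to patch.")
    else:
        # Must run BEFORE importing ott or scenic.projects.owl_vit.
        added = ensure_alias(jax.random, "PRNGKeyArray", jax.Array)
        print("alias added" if added else "PRNGKeyArray already present")
```

In a notebook, the ensure_alias(jax.random, "PRNGKeyArray", jax.Array) call would go in the first cell after installation, followed by the usual scenic imports.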

@DishantMewada

DishantMewada commented Apr 30, 2024

@Chi-chicken Can you please clarify how to do that? Thank you.

I checked the jax.random module on my local machine. It doesn't contain PRNGKeyArray, only PRNGKey.

Attached is the jax.random Python file from my conda environment.
random.txt

Here is the full error when running from scenic.projects.owl_vit import models:

WARNING:absl:Type handler registry overriding type "<class 'float'>" collision on scalar
WARNING:absl:Type handler registry overriding type "<class 'bytes'>" collision on scalar
WARNING:absl:Type handler registry overriding type "<class 'numpy.number'>" collision on scalar


AttributeError Traceback (most recent call last)
Cell In[1], line 1
----> 1 from scenic.projects.owl_vit import models

File ~/Downloads/owl_vit/scenic/scenic/projects/owl_vit/models.py:26
24 import ml_collections
25 from scenic.projects.owl_vit import layers
---> 26 from scenic.projects.owl_vit import matching_base_models
27 from scenic.projects.owl_vit import utils
28 from scenic.projects.owl_vit.clip import model as clip_model

File ~/Downloads/owl_vit/scenic/scenic/projects/owl_vit/matching_base_models.py:23
21 import jax.numpy as jnp
22 import ml_collections
---> 23 from scenic.model_lib import matchers
24 from scenic.model_lib.base_models import base_model
25 from scenic.model_lib.base_models import box_utils

File ~/Downloads/owl_vit/scenic/scenic/model_lib/matchers/__init__.py:25
23 from scenic.model_lib.matchers.hungarian_jax import hungarian_tpu_matcher
24 from scenic.model_lib.matchers.lazy import lazy_matcher
---> 25 from scenic.model_lib.matchers.sinkhorn import sinkhorn_matcher

File ~/Downloads/owl_vit/scenic/scenic/model_lib/matchers/sinkhorn.py:24
22 import jax.numpy as jnp
23 import numpy as np
---> 24 from ott.geometry import geometry
25 from ott.tools import transport
28 def idx2permutation(row_ind, col_ind):

File ~/anaconda3/envs/owl_vit/lib/python3.10/site-packages/ott/__init__.py:15
1 # Copyright 2022 Google LLC.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 """OTT library."""
---> 15 from . import geometry, initializers, math, problems, solvers, tools, utils
16 from ._version import version

File ~/anaconda3/envs/owl_vit/lib/python3.10/site-packages/ott/initializers/__init__.py:1
----> 1 from . import linear, nn, quadratic

File ~/anaconda3/envs/owl_vit/lib/python3.10/site-packages/ott/initializers/nn/__init__.py:1
----> 1 from . import initializers

File ~/anaconda3/envs/owl_vit/lib/python3.10/site-packages/ott/initializers/nn/initializers.py:21
16 # TODO(michalk8): add initializer for NeuralDual?
17 __all__ = ["MetaInitializer", "MetaMLP"]
20 @jax.tree_util.register_pytree_node_class
---> 21 class MetaInitializer(initializers.DefaultInitializer):
22 """Meta OT Initializer with a fixed geometry :cite:`amos:22`.
23
24 This initializer consists of a predictive model that outputs the
(...)
56 )
57 """
59 def __init__(
60 self,
61 geom: geometry.Geometry,
(...)
65 state: Optional[train_state.TrainState] = None
66 ):

File ~/anaconda3/envs/owl_vit/lib/python3.10/site-packages/ott/initializers/nn/initializers.py:64, in MetaInitializer()
20 @jax.tree_util.register_pytree_node_class
21 class MetaInitializer(initializers.DefaultInitializer):
22 """Meta OT Initializer with a fixed geometry :cite:`amos:22`.
23
24 This initializer consists of a predictive model that outputs the
(...)
56 )
57 """
59 def __init__(
60 self,
61 geom: geometry.Geometry,
62 meta_model: Optional[nn.Module] = None,
63 opt: optax.GradientTransformation = optax.adam(learning_rate=1e-3),
---> 64 rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
65 state: Optional[train_state.TrainState] = None
66 ):
67 self.geom = geom
68 self.dtype = geom.x.dtype

File ~/anaconda3/envs/owl_vit/lib/python3.10/site-packages/jax/_src/deprecations.py:54, in deprecation_getattr.<locals>.getattr(name)
52 warnings.warn(message, DeprecationWarning, stacklevel=2)
53 return fn
---> 54 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.random' has no attribute 'PRNGKeyArray'

@Chi-chicken

Chi-chicken commented May 1, 2024

@DishantMewada You can try to change PRNGKeyArray to PRNGKey in the initializers.py file.

@DishantMewada

@Chi-chicken Thank you for the reply. Can you please tell me where the file is and what you changed in it?

@Chi-chicken

Chi-chicken commented May 1, 2024

@DishantMewada The file path is ~/anaconda3/envs/owl_vit/lib/python3.10/site-packages/ott/initializers/nn/initializers.py
You can edit it with vim ~/anaconda3/envs/owl_vit/lib/python3.10/site-packages/ott/initializers/nn/initializers.py and change line 64 from
rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0)
to
rng: jax.random.PRNGKey = jax.random.PRNGKey(0)
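The same one-line edit can be scripted so it works wherever the environment lives (a conda env locally, or /usr/local/lib/python3.10/dist-packages on Colab). This is a sketch: find_ott_initializers and patch_file are hypothetical helpers. importlib.util.find_spec locates the top-level ott package without executing its __init__.py, which matters here because importing ott is exactly what fails. The default substitution matches the one described in this thread, and the original file is kept as a .bak copy.

```python
import importlib.util
from pathlib import Path
from typing import Optional

OLD = "jax.random.PRNGKeyArray"
# Replacement used in this thread; annotations are not enforced at runtime,
# so jax.Array would also work here.
NEW = "jax.random.PRNGKey"


def find_ott_initializers() -> Optional[Path]:
    """Locate ott/initializers/nn/initializers.py without importing ott."""
    spec = importlib.util.find_spec("ott")  # does not execute ott/__init__.py
    if spec is None or spec.origin is None:
        return None
    path = Path(spec.origin).parent / "initializers" / "nn" / "initializers.py"
    return path if path.is_file() else None


def patch_file(path: Path, old: str = OLD, new: str = NEW) -> int:
    """Replace every occurrence of `old` with `new`; return the count.

    The first time a change is made, the original is saved next to the file
    with a .bak suffix. Running it again is a no-op and returns 0.
    """
    text = path.read_text()
    count = text.count(old)
    if count:
        backup = path.with_name(path.name + ".bak")
        if not backup.exists():
            backup.write_text(text)
        path.write_text(text.replace(old, new))
    return count


if __name__ == "__main__":
    target = find_ott_initializers()
    if target is None:
        print("ott/initializers/nn/initializers.py not found")
    else:
        print(f"{target}: {patch_file(target)} replacement(s)")
```

After running it, restarting the kernel is the safest way to make sure the patched file is picked up. Other ott modules might reference jax.random.PRNGKeyArray too; I have not checked, but if a similar error points at a different file, patch_file could be applied to that path as well.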

@DishantMewada

You are such a saviour, @Chi-chicken. Thank you so much. It is working on my local machine. I know the original question from @kaninaba94 was about Google Colab, but I can confirm that this method works on a local machine if you have the same dependencies as me.
