Could not run trax normally using the GPU on a local computer #1778

Open
LiuZhenshun opened this issue Jun 22, 2023 · 0 comments

Description

Hi, I would like to install trax locally. First, I found that the jax I had installed did not support the GPU, so I followed the instructions on the JAX GitHub page to install the CUDA version of jax. Next, I verified that jax could detect the GPU on my local computer, but I still could not run the sample code, such as the Transformer and Fast Math examples.

Environment information

OS: Pop!_OS (based on Ubuntu 22.04)

$ pip freeze | grep trax
# trax==1.4.1

$ pip freeze | grep tensor
# tensorboard==2.12.3
# tensorboard-data-server==0.7.1
# tensorflow==2.12.0
# tensorflow-datasets==4.9.2
# tensorflow-estimator==2.12.0
# tensorflow-hub==0.13.0
# tensorflow-io-gcs-filesystem==0.32.0
# tensorflow-metadata==1.13.1
# tensorflow-text==2.12.1

$ pip freeze | grep jax
# jax==0.4.12
# jaxlib==0.4.12+cuda11.cudnn86

$ python -V
# Python 3.11.3
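The same version information can also be collected from inside the Python interpreter that actually runs trax, which helps rule out a mismatch between the `pip` on the PATH and the active conda environment. A minimal sketch using only the standard library's `importlib.metadata`; the function name `installed_versions` and the package list are illustrative choices, not part of trax or jax:

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution_name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None  # not installed in this interpreter's environment
    return versions


if __name__ == "__main__":
    print("Python", platform.python_version())
    # Illustrative package list; extend it with whatever else is relevant.
    for name, ver in installed_versions(["trax", "jax", "jaxlib", "tensorflow"]).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```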

For bugs: reproduction and error logs

# Steps to reproduce:
1) Install trax

- pip install trax
- pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

2) Use jax to detect the GPU
- code:
    import jax
    print(jax.devices()) 
- output:
    [gpu(id=0)]
# Error logs:
1) Run the pre-trained Transformer sample code from your README tutorial
- code:
      import os
      import numpy as np
      
      import trax
      
      # Create a Transformer model.
      # Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
      model = trax.models.Transformer(
          input_vocab_size=33300,
          d_model=512, d_ff=2048,
          n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
          max_len=64, mode='predict')
      
      # Initialize using pre-trained weights.
      model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                           weights_only=True)
      
      # Tokenize a sentence.
      sentence = 'It is nice to learn new things today!'
      tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                          vocab_dir='gs://trax-ml/vocabs/',
                                          vocab_file='ende_32k.subword'))[0]
      
      # Decode from the Transformer.
      tokenized = tokenized[None, :]  # Add batch dimension.
      tokenized_translation = trax.supervised.decoding.autoregressive_sample(
          model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.
      
      # De-tokenize.
      tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
      translation = trax.data.detokenize(tokenized_translation,
                                         vocab_dir='gs://trax-ml/vocabs/',
                                         vocab_file='ende_32k.subword')
      print(translation)
- Error Output:
      2023-06-22 15:58:35.266959: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".
      2023-06-22 15:58:56.630331: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
      INTERNAL: Failed to get stream's capture status: out of memory
      2023-06-22 15:58:56.630403: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2461] Execution of replica 0 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.
      Traceback (most recent call last):
        File "/home/littleliu/Documents/project/trax_learning/tryTrax.py", line 22, in <module>
          model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/base.py", line 349, in init_from_file
          self.init(input_signature)
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/trax/layers/base.py", line 310, in init
          raise LayerError(name, 'init', self._caller,
      trax.layers.base.LayerError: Exception passing through layer Serial (in init):
        layer created in file [...]/trax/models/transformer.py, line 371
        layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:float32})
      
        File [...]/trax/layers/combinators.py, line 108, in init_weights_and_state
          outputs, _ = sublayer._forward_abstract(inputs)
      
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File [...]/trax/layers/base.py, line 641, in _forward_abstract
      
        layer created in file [...]/trax/models/transformer.py, line 372
        layer input shapes: (ShapeDtype{shape:(1, 1), dtype:int64}, ShapeDtype{shape:(1, 1), dtype:int64})
      
      jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.func.launch' failed: Failed to get stream's capture status: out of memory; current tracing scope: fusion; current profiling annotation: XlaModule:#hlo_module=jit_PRNGKey,program_id=0#.

2) Run the Fast Math sample code:
- code:
      import trax
      from trax.fastmath import numpy as fastnp
      trax.fastmath.use_backend('jax')  # Can be 'jax' or 'tensorflow-numpy'.
      
      matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
      print(f'matrix =\n{matrix}')
      vector = fastnp.ones(3)
      print(f'vector = {vector}')
      product = fastnp.dot(vector, matrix)
      print(f'product = {product}')
      tanh = fastnp.tanh(product)
      print(f'tanh(product) = {tanh}')
- Error Output:
      matrix =
      [[1 2 3]
       [4 5 6]
       [7 8 9]]
      2023-06-22 16:03:23.041313: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
      2023-06-22 16:03:23.041386: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 36175872 bytes free, 4093902848 bytes total.
      2023-06-22 16:03:23.041476: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:453] Possibly insufficient driver version: 525.85.5
      Traceback (most recent call last):
        File "/home/littleliu/Documents/project/trax_learning/fastnumpy.py", line 7, in <module>
          vector = fastnp.ones(3)
                   ^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2161, in ones
          return lax.full(shape, 1, _jnp_dtype(dtype))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 1205, in full
          return broadcast(fill_value, shape)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 768, in broadcast
          return broadcast_in_dim(operand, tuple(sizes) + np.shape(operand), dims)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 796, in broadcast_in_dim
          return broadcast_in_dim_p.bind(
                 ^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 380, in bind
          return self.bind_with_trace(find_top_trace(args), args, params)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 383, in bind_with_trace
          out = trace.process_primitive(self, map(trace.full_raise, args), params)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/core.py", line 790, in process_primitive
          return primitive.impl(*tracers, **params)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 132, in apply_primitive
          compiled_fun = xla_primitive_callable(
                         ^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/util.py", line 284, in wrapper
          return cached(config._trace_context(), *args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/util.py", line 277, in cached
          return f(*args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 223, in xla_primitive_callable
          compiled = _xla_callable_uncached(
                     ^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 253, in _xla_callable_uncached
          return computation.compile().unsafe_call
                 ^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2329, in compile
          executable = UnloadedMeshExecutable.from_hlo(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2651, in from_hlo
          xla_executable, compile_options = _cached_compilation(
                                            ^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2561, in _cached_compilation
          xla_executable = dispatch.compile_or_get_cached(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
          return backend_compile(backend, computation, compile_options,
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
          return func(*args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^
        File "/home/littleliu/.conda/envs/trax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
          return backend.compile(built_c, compile_options=options)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
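The cuDNN lines in this log report how much GPU memory was free when handle creation failed. A small standard-library sketch that extracts those numbers from such a line and expresses them in MiB; the regex assumes the exact `Memory usage: N bytes free, M bytes total` wording shown above, and the helper name `gpu_memory_from_log` is hypothetical:

```python
import re

# Matches the cuDNN diagnostic wording seen in the log, e.g.
# "Memory usage: 36175872 bytes free, 4093902848 bytes total."
_MEM_RE = re.compile(r"Memory usage: (\d+) bytes free, (\d+) bytes total")
_MIB = 1024 * 1024


def gpu_memory_from_log(line):
    """Return (free_mib, total_mib, free_fraction), or None if the line does not match."""
    match = _MEM_RE.search(line)
    if match is None:
        return None
    free, total = (int(g) for g in match.groups())
    return free / _MIB, total / _MIB, free / total


if __name__ == "__main__":
    line = ("E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] "
            "Memory usage: 36175872 bytes free, 4093902848 bytes total.")
    free_mib, total_mib, frac = gpu_memory_from_log(line)
    print(f"{free_mib:.2f} MiB free of {total_mib:.2f} MiB ({frac:.2%})")
```

For the line in this log it prints `34.50 MiB free of 3904.25 MiB (0.88%)`, i.e. less than 1% of the card's memory was free at that point, which is in line with the "out of memory" errors in the first traceback.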