Provide as much information as possible. At least, this should include a description of your issue and steps to reproduce the problem. If possible also provide a summary of what steps or workarounds you have already tried.
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04
pip show flax jax jaxlib
flax 0.8.1 pyhd8ed1ab_0 conda-forge
jax 0.4.25 pyhd8ed1ab_0 conda-forge
jaxlib 0.4.23 cuda118py312h5dcafd1_200 conda-forge
python 3.12.2 hab00c5b_0_cpython conda-forge
GPU: 2/5 slice of an NVIDIA A100 (10 GB VRAM)
CUDA version: 12.2
jax.print_environment_info()
jax: 0.4.25
jaxlib: 0.4.23.dev20240229
numpy: 1.26.4
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='genuinely-jolly-jackalope', release='5.15.0-97-generic', version='#107-Ubuntu SMP Wed Feb 7 13:26:48 UTC 2024', machine='x86_64')
$ nvidia-smi
Tue Mar 12 03:14:50 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 GRID A100X-10C On | 00000000:04:00.0 Off | 0 |
| N/A N/A P0 N/A / N/A | 418MiB / 10240MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 42322 C ...r/miniconda3/envs/jax_12/bin/python 417MiB |
+---------------------------------------------------------------------------------------+
jax.default_backend()
'gpu'
Problem you have encountered:
While training Flax neural network models, CPU usage is almost always at 100%.
JAX operations are occurring on the GPU.
I'm unsure whether Flax is keeping the network parameters in GPU memory and updating them there.
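For reference, a minimal check of where arrays and parameters actually live might look like the sketch below (here params is just a placeholder for the pytree returned by model.init; this isn't my actual code):

import jax
import jax.numpy as jnp

# Any array created under the default (GPU) backend should report a CUDA device.
x = jnp.ones((4, 16))
print(x.devices())  # expected: {cuda(id=0)}

# The same check on a parameter pytree (params is a placeholder for the model.init output):
# print(jax.tree_util.tree_map(lambda p: p.devices(), params))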
What you expected to happen:
Flax to utilize the GPU.
Logs, error messages, etc:
I've attached a trace of one epoch of my training loop, generated by following these instructions: https://jax.readthedocs.io/en/latest/profiling.html
I'm a PhD candidate in Chemistry with little experience reading traces.
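For reference, the trace was captured roughly like the sketch below (train_epoch, state, and train_data are placeholders for my actual training code):

import jax

# Capture one epoch under the profiler, as described in the linked guide.
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    state = train_epoch(state, train_data)
    jax.block_until_ready(state)  # ensure pending device work is included in the trace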
Steps to reproduce:
Unfortunately I don't have a minimal reproducible example at the moment. I'm hoping someone can provide insight into whether Flax should be utilizing the GPU and, if so, whether the trace shows that it is (or whether something else in my code might be bottlenecking training).
perfetto_trace.json.gz
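If it helps narrow things down, a rough way to separate GPU compute time from host-side work might be to time the jitted update step on its own (sketch only; train_step, state, and batch are placeholders for my code):

import time
import jax

t0 = time.perf_counter()
state = jax.block_until_ready(train_step(state, batch))  # wait for the GPU to finish
print(f"jitted train_step: {time.perf_counter() - t0:.4f} s")

# If the jitted step is fast but a full epoch is slow, the 100% CPU usage is more
# likely host-side work (e.g. batching/augmentation in Python or NumPy) than the
# parameter update itself.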