Neural Net Training is bottlenecked by maxed out CPU #3750

Open
josephdepaoloboisvert opened this issue Mar 12, 2024 · 0 comments
Labels
Priority: P2 - no schedule (Best effort response and resolution. We have no plan to work on this at the moment.)

Comments

@josephdepaoloboisvert


System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
    Linux Ubuntu 22.04
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
    flax 0.8.1 pyhd8ed1ab_0 conda-forge
    jax 0.4.25 pyhd8ed1ab_0 conda-forge
    jaxlib 0.4.23 cuda118py312h5dcafd1_200 conda-forge
  • Python version:
    python 3.12.2 hab00c5b_0_cpython conda-forge
  • GPU/TPU model and memory:
    2/5 slice of NVIDIA A100 (10 GB VRAM)
  • CUDA version (if applicable):
    12.2

jax.print_environment_info()

jax: 0.4.25
jaxlib: 0.4.23.dev20240229
numpy: 1.26.4
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:50:58) [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='genuinely-jolly-jackalope', release='5.15.0-97-generic', version='#107-Ubuntu SMP Wed Feb 7 13:26:48 UTC 2024', machine='x86_64')

$ nvidia-smi
Tue Mar 12 03:14:50 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  GRID A100X-10C                 On  | 00000000:04:00.0 Off |                    0 |
| N/A   N/A    P0             N/A /  N/A  |    418MiB / 10240MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                             |
|  GPU   GI   CI        PID   Type   Process name                             GPU Memory |
|        ID   ID                                                              Usage      |
|=======================================================================================|
|    0   N/A  N/A     42322      C   ...r/miniconda3/envs/jax_12/bin/python       417MiB |
+---------------------------------------------------------------------------------------+

jax.default_backend()
'gpu'

Problem you have encountered:

While training Flax neural network models, CPU usage is almost always at 100%.
JAX operations are occurring on the GPU.
I'm unsure whether Flax is using the GPU to store the network parameters and to perform the parameter updates.
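
As a sanity check, here is a minimal sketch (with a hypothetical stand-in model, not my actual network) of how parameter device placement can be inspected:

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical stand-in model; any Flax module behaves the same way.
model = nn.Dense(features=16)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# Each parameter leaf is a jax.Array; .devices() reports where it is stored.
print(jax.tree_util.tree_map(lambda x: x.devices(), params))
# On this setup every leaf should report {cuda(id=0)} if the GPU holds the parameters.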

What you expected to happen:

Flax to utilize the GPU for storing and updating the network parameters, so that training is not bottlenecked by the CPU.

Logs, error messages, etc:

I've attached a trace of one epoch of my training loop, generated by following the instructions at https://jax.readthedocs.io/en/latest/profiling.html. I'm a PhD candidate in chemistry with little experience reading traces.
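
For reference, the trace was captured roughly like this (the log directory and the train_loader/train_step names are illustrative placeholders, not my actual code):

import jax

# Write a Perfetto-compatible trace covering one epoch.
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
    for batch in train_loader:            # hypothetical data iterator
        state = train_step(state, batch)  # hypothetical jitted update step
    jax.block_until_ready(state)          # make sure pending GPU work is included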

Steps to reproduce:

Unfortunately I don't have a minimal reproducible example at the moment. I'm hoping someone can confirm whether Flax should be utilizing the GPU here, and whether the attached trace shows that it is (or whether something else in my code might be bottlenecking training).
perfetto_trace.json.gz
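
For illustration only, the training loop has roughly this shape (a hypothetical skeleton with a placeholder model, optimizer, and data, not my actual code):

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

model = nn.Dense(features=1)                        # placeholder model
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
state = train_state.TrainState.create(
    apply_fn=model.apply, params=variables, tx=optax.adam(1e-3))

@jax.jit                                            # the whole update step is compiled for the GPU backend
def train_step(state, x, y):
    def loss_fn(p):
        pred = state.apply_fn(p, x)
        return jnp.mean((pred - y) ** 2)
    grads = jax.grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads)

x = jnp.ones((32, 8))                               # placeholder batch
y = jnp.ones((32, 1))
for _ in range(100):                                # stand-in for one epoch
    state = train_step(state, x, y)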

@chiamp added the "Priority: P2 - no schedule" label on Mar 19, 2024