SAC with Q-Ensemble for Offline RL

Single-file SAC-N [1] implementation in JAX, with both Flax and Equinox variants. It is 10x faster than the PyTorch SAC-N implementation from CORL [2].

And still easy to use and understand! To run:

python sac_n_jax_flax.py --env_name="halfcheetah-medium-v2" --num_critics=10 --batch_size=256
python sac_n_jax_eqx.py --env_name="halfcheetah-medium-v2" --num_critics=10 --batch_size=256

Optionally, you can pass --config_path pointing to a YAML config file; for more details, see the pyrallis docs.
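A config file passed via --config_path might look like this (field names are assumed to match the CLI flags above; check the script's config dataclass for the full set of options):

```yaml
env_name: "halfcheetah-medium-v2"
num_critics: 10
batch_size: 256
```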

Speed comparison

The main insight here is to jit the whole epoch loop with jax.lax.fori_loop or jax.lax.scan, not just a single network update, as is usually done (in jaxrl2, for instance). With jitting only the update, the speedup would be roughly 1.5x.
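A minimal sketch of the trick, not the repo's actual SAC-N update: a toy quadratic loss and plain gradient steps stand in for the real training step, but the loop structure is the same.

```python
from functools import partial

import jax
import jax.numpy as jnp


def update(params):
    # One gradient step on a toy quadratic loss (stand-in for the SAC-N update).
    grads = jax.grad(lambda p: jnp.sum(p ** 2))(params)
    return params - 0.1 * grads


@partial(jax.jit, static_argnums=1)
def train_epoch(params, num_updates):
    # The entire epoch loop is compiled by XLA: one host-to-device dispatch
    # per epoch instead of one per update.
    return jax.lax.fori_loop(0, num_updates, lambda i, p: update(p), params)


params = train_epoch(jnp.ones(4), 100)
```

Jitting only `update` and calling it in a Python for-loop would still pay a dispatch cost on every iteration, which is where the remaining speedup comes from.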

Both runs were trained on the same V100 GPU.

[Figures: episode return vs. training epochs, and episode return vs. wall-clock time]

References

  1. Uncertainty-Based Offline Reinforcement Learning with Diversified Q-Ensemble [code]
  2. Research-oriented Deep Offline Reinforcement Learning Library [code]
