Simple CIFAR10 ResNet example with JAX.

hushon/JAX-ResNet-CIFAR10

Simple CIFAR10 ResNet in JAX

This repo provides a ResNet example for CIFAR-10 using Google's JAX. I aim to provide simple baseline code for deep learning researchers who want to get started with JAX quickly. For those who are not familiar with JAX, it is essentially Autograd + XLA: automatic differentiation of NumPy-style Python code, compiled for CPU, GPU, and TPU by the XLA compiler.
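As a minimal illustration of the "Autograd + XLA" description (a toy example, not code from this repo): `jax.grad` differentiates an ordinary Python loss function, and `jax.jit` compiles the resulting gradient function with XLA. The linear model and the data are made up for illustration.

```python
import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Mean squared error of a toy linear model; this is what "Autograd" differentiates.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


# jax.grad builds a function returning dloss/dw; jax.jit compiles it with XLA.
grad_fn = jax.jit(jax.grad(loss_fn))

w = jnp.zeros(3)
x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = jnp.array([1.0, 2.0])
g = grad_fn(w, x, y)  # → [-9., -12., -15.]
```

The first call traces and compiles `grad_fn`; later calls with the same input shapes and dtypes reuse the compiled XLA program.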

I built on DeepMind's Haiku (neural network modules) and Optax (optimizers) for the high-level neural network API, and used PyTorch and Torchvision for the data loading pipeline. My ResNet implementation is based on this repo.
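Torchvision datasets fed through a PyTorch `DataLoader` yield `torch.Tensor` batches by default, while JAX functions expect NumPy or JAX arrays. One common bridge, sketched here as an assumption rather than this repo's actual loader, is a custom `collate_fn` that stacks samples into NumPy arrays. `numpy_collate` and `ToNumpy` are hypothetical names, and normalization and augmentation are omitted.

```python
import numpy as np


def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays instead of torch.Tensors.

    Intended as the `collate_fn` of a torch.utils.data.DataLoader so that
    batches can be passed straight to JAX functions.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        # Stack per-sample arrays (e.g. HWC images) along a new batch axis.
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        # (image, label) pairs: collate each field separately.
        return [numpy_collate(list(field)) for field in zip(*batch)]
    # Scalars such as integer class labels.
    return np.array(batch)


class ToNumpy:
    """Hypothetical transform: PIL image (or uint8 array) -> float32 HWC array in [0, 1]."""

    def __call__(self, pic):
        return np.asarray(pic, dtype=np.float32) / 255.0


# Possible usage (requires torch and torchvision; not executed here):
# loader = torch.utils.data.DataLoader(
#     torchvision.datasets.CIFAR10("./data", train=True, download=True,
#                                  transform=ToNumpy()),
#     batch_size=128, shuffle=True, collate_fn=numpy_collate)
```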

Updates:

  • Support for mixed precision training using JMP.
  • Support for multi-GPU training: train_multigpu.py
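The contents of `train_multigpu.py` are not shown here. As a generic sketch of single-host, multi-GPU data parallelism in JAX (which may differ from what the script actually does), `jax.pmap` runs the same step on every device with a sharded batch, and `jax.lax.pmean` averages the gradients so all replicas stay in sync. The toy linear model and learning rate 0.1 are assumptions for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Toy linear-regression loss standing in for the ResNet loss.
    return jnp.mean((x @ w - y) ** 2)


@partial(jax.pmap, axis_name="devices")
def train_step(w, x, y):
    # Each device computes gradients on its own shard of the batch...
    grads = jax.grad(loss_fn)(w, x, y)
    # ...then gradients are averaged across devices so every replica
    # applies the same SGD update and the parameters stay identical.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return w - 0.1 * grads


n_dev = jax.local_device_count()
w = jnp.zeros(3)
# Replicate parameters: one copy per device along a new leading axis.
w_rep = jnp.broadcast_to(w, (n_dev,) + w.shape)
# Shard the global batch: axis 0 = device, axis 1 = per-device batch.
x = jnp.ones((n_dev, 4, 3))
y = jnp.ones((n_dev, 4))
w_rep = train_step(w_rep, x, y)
```

On a machine with one device this still runs, with a leading axis of size 1.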

Requirements

  • JAX
  • Haiku
  • Optax
  • dm-tree
  • PyTorch
  • Torchvision
  • JMP (for mixed precision training)
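Some of these packages are installed under a different name than the one you import: on PyPI, Haiku is published as `dm-haiku` (imported as `haiku`), and `dm-tree` is imported as `tree`. The hypothetical helper below, not part of this repo, checks that every requirement is importable before you run the training scripts. Listing `jmp` here assumes it is only needed for `train_mp.py`.

```python
import importlib.util

# PyPI distribution name -> Python import name.
REQUIREMENTS = {
    "jax": "jax",
    "dm-haiku": "haiku",
    "optax": "optax",
    "dm-tree": "tree",
    "torch": "torch",
    "torchvision": "torchvision",
    "jmp": "jmp",  # assumption: only needed for mixed precision (train_mp.py)
}


def missing_requirements(requirements=REQUIREMENTS):
    """Return the sorted PyPI names whose top-level module cannot be found."""
    return sorted(
        dist
        for dist, module in requirements.items()
        if importlib.util.find_spec(module) is None
    )


if __name__ == "__main__":
    missing = missing_requirements()
    print("missing:", ", ".join(missing) if missing else "none")
```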

Run

python train.py

Mixed precision training

python train_mp.py
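`train_mp.py` relies on JMP for casting and loss scaling, and its API is not reproduced here. To illustrate what a mixed precision step involves, the plain-JAX sketch below keeps float32 master parameters, runs the forward and backward matmul in float16, and uses a fixed (static) loss scale so small half-precision gradients do not underflow. The scale of 2**10, the float16 compute dtype, and the toy linear model are assumptions; a real setup would typically adjust the scale dynamically.

```python
import jax
import jax.numpy as jnp

LOSS_SCALE = 2.0 ** 10  # hypothetical static loss scale


def loss_fn(params, x, y):
    # Cast float32 master params and inputs to float16 for the forward pass.
    pred = x.astype(jnp.float16) @ params.astype(jnp.float16)
    # Compute the loss itself in float32 for numerical safety.
    return jnp.mean((pred.astype(jnp.float32) - y) ** 2)


def scaled_loss(params, x, y):
    # Scaling the loss up scales every gradient up, keeping float16 values
    # in the backward pass away from underflow.
    return loss_fn(params, x, y) * LOSS_SCALE


@jax.jit
def train_step(params, x, y):
    grads = jax.grad(scaled_loss)(params, x, y)  # float32, like params
    grads = grads / LOSS_SCALE  # unscale before the update
    finite = jnp.all(jnp.isfinite(grads))
    # Skip the SGD update if the scaled gradients overflowed.
    new_params = jnp.where(finite, params - 0.1 * grads, params)
    return new_params, finite


new_params, ok = train_step(jnp.zeros(3), jnp.ones((4, 3)), jnp.ones(4))
```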

Benchmarks

Model      Size (params)   Test acc.
ResNet20   0.27 M          91.5 %
ResNet32   0.46 M          92.5 %
ResNet44   0.66 M          93.1 %
ResNet56   0.85 M          93.2 %
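The depths 20, 32, 44, and 56 follow the 6n + 2 pattern of CIFAR-style ResNets: a 3x3 stem convolution, three stages of n basic blocks with 16, 32, and 64 channels, and a final dense layer. The sketch below counts trainable parameters under assumptions not confirmed by this repo: bias-free 3x3 convolutions each followed by BatchNorm with a learned scale and offset (running statistics not counted), and parameter-free shortcuts when the width changes. Under these assumptions it reproduces the Size column to two decimals.

```python
def cifar_resnet_params(depth, num_classes=10, widths=(16, 32, 64)):
    """Count trainable parameters of a CIFAR-style ResNet of depth 6n + 2 (assumed layout)."""
    assert (depth - 2) % 6 == 0, "depth must be 6n + 2"
    n = (depth - 2) // 6

    def conv_bn(c_in, c_out):
        # 3x3 kernel without bias, plus BatchNorm scale and offset.
        return 3 * 3 * c_in * c_out + 2 * c_out

    total = conv_bn(3, widths[0])  # stem on RGB input
    c_in = widths[0]
    for width in widths:
        for _ in range(n):
            # Basic block: two 3x3 conv + BN layers; shortcut adds no params.
            total += conv_bn(c_in, width) + conv_bn(width, width)
            c_in = width
    total += widths[-1] * num_classes + num_classes  # final dense layer
    return total


for d in (20, 32, 44, 56):
    print(f"ResNet{d}: {cifar_resnet_params(d) / 1e6:.2f} M")
# ResNet20: 0.27 M
# ResNet32: 0.46 M
# ResNet44: 0.66 M
# ResNet56: 0.85 M
```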