VGG Training in Jax/Flax

This is the training code for the Jax/Flax implementation of VGG.

Table of Contents

  • Getting Started
  • Training
  • License

Getting Started

You will need Python 3.7 or later.

  1. Clone the repository:
    > git clone https://github.com/matthias-wright/flaxmodels.git
  2. Go into the directory:
    > cd flaxmodels/training/vgg
  3. Install Jax with CUDA (see the example command after this list).
  4. Install requirements:
    > pip install -r requirements.txt
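Step 3 depends on your GPU setup. As a hedged example, recent JAX releases with CUDA support can be installed via pip; the cuda12 extra below is an assumption about your CUDA version, so check the official JAX installation guide for the command matching your drivers:

    > pip install --upgrade "jax[cuda12]"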

Training

Basic Training

CUDA_VISIBLE_DEVICES=0 python main.py

Multi-GPU Training

The script automatically uses all visible GPUs for distributed training.

CUDA_VISIBLE_DEVICES=0,1 python main.py
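For illustration, below is a minimal sketch of how data-parallel training is commonly written in JAX with jax.pmap. This shows the general technique, not the repo's actual training step; loss_fn and the plain SGD update are hypothetical placeholders.

  import functools
  import jax
  import jax.numpy as jnp

  def replicate(tree):
      # Copy a pytree onto every device by adding a leading device axis.
      n = jax.local_device_count()
      return jax.tree_util.tree_map(lambda x: jnp.stack([x] * n), tree)

  @functools.partial(jax.pmap, axis_name='batch')
  def train_step(params, batch):
      # loss_fn is a hypothetical placeholder for the model's loss.
      grads = jax.grad(loss_fn)(params, batch)
      # Average gradients across devices so all replicas apply the same update.
      grads = jax.lax.pmean(grads, axis_name='batch')
      # Plain SGD update, for illustration only.
      return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

Both params and each batch need a leading device axis of size jax.local_device_count() before being passed to train_step, which is what replicate provides for the parameters.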

Mixed-Precision Training

CUDA_VISIBLE_DEVICES=0,1 python main.py --mixed_precision
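A hedged sketch of the usual pattern behind this flag, not necessarily what main.py does: keep the parameters in float32 but run the compute-heavy forward pass in float16. apply_fn here is a hypothetical Flax apply function.

  import jax
  import jax.numpy as jnp

  def half_precision_apply(apply_fn, params, x):
      # apply_fn is a hypothetical placeholder; cast inputs and
      # parameters to float16 for the forward pass...
      params16 = jax.tree_util.tree_map(lambda p: p.astype(jnp.float16), params)
      logits = apply_fn(params16, x.astype(jnp.float16))
      # ...and cast back to float32 before the loss for numerical stability.
      return logits.astype(jnp.float32)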

Options

  • --work_dir - Path to directory for logging and checkpoints (str).
  • --data_dir - Path for storing the dataset (str).
  • --name - Name of the training run (str).
  • --group - Group name of the training run (str).
  • --arch - Architecture (str). Options: vgg16, vgg19.
  • --resume - Resume training from the best checkpoint (bool).
  • --num_epochs - Number of epochs (int).
  • --learning_rate - Learning rate (float).
  • --warmup_epochs - Number of warmup epochs with lower learning rate (int).
  • --batch_size - Batch size (int).
  • --num_classes - Number of classes (int).
  • --img_size - Image size (int).
  • --img_channels - Number of image channels (int).
  • --mixed_precision - Use mixed precision training (bool).
  • --random_seed - Random seed (int).
  • --wandb - Use Weights & Biases for logging (bool).
  • --log_every - Logging frequency in steps (int).
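For example, a run combining several of these options (the values are illustrative, not recommended settings):

CUDA_VISIBLE_DEVICES=0,1 python main.py --arch vgg19 --batch_size 64 --num_epochs 100 --learning_rate 0.01 --mixed_precision --wandb --log_every 50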

License

MIT License