Paper: https://arxiv.org/abs/1409.1556
Project Page: https://www.robots.ox.ac.uk/~vgg/research/very_deep/
Repository: https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
Images must be in the range [0, 1]. If the pretrained ImageNet weights are selected, the images are normalized internally with the ImageNet mean and standard deviation. If you don't want the images to be normalized, use normalize=False (see here for details).
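For reference, "normalized with the ImageNet mean and standard deviation" usually means a per-channel shift and scale. The sketch below is not flaxmodels' internal code. It assumes the commonly used torchvision constants (RGB order) and a channels-last (NHWC) batch in [0, 1], like the one built in the usage example.

```python
import jax.numpy as jnp

# Commonly used ImageNet per-channel statistics (RGB order), as in torchvision.
# Assumption: flaxmodels may store or apply these differently internally.
IMAGENET_MEAN = jnp.array([0.485, 0.456, 0.406], dtype=jnp.float32)
IMAGENET_STD = jnp.array([0.229, 0.224, 0.225], dtype=jnp.float32)


def normalize_imagenet(x: jnp.ndarray) -> jnp.ndarray:
    """Normalize an NHWC batch with values in [0, 1], channel by channel.

    The (3,) mean/std arrays broadcast over the last (channel) axis, so each
    channel of every pixel is shifted and scaled by its own statistics.
    """
    return (x - IMAGENET_MEAN) / IMAGENET_STD


# A dummy batch of one 224x224 RGB image filled with 0.5.
x = jnp.full((1, 224, 224, 3), 0.5, dtype=jnp.float32)
x_norm = normalize_imagenet(x)
print(x_norm.shape)  # (1, 224, 224, 3)
```

Because broadcasting handles the channel axis, the same function works for any batch size or spatial resolution.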
For more usage examples, check out this Colab.
```python
from PIL import Image
import jax
import jax.numpy as jnp
import flaxmodels as fm

key = jax.random.PRNGKey(0)

# Load image (convert to RGB so the array always has 3 channels)
img = Image.open('example.jpg').convert('RGB')
# Image must be 224x224 if classification head is included
img = img.resize((224, 224))
# Image should be in range [0, 1]
x = jnp.array(img, dtype=jnp.float32) / 255.0
# Add batch dimension: (1, 224, 224, 3)
x = jnp.expand_dims(x, axis=0)

vgg16 = fm.VGG16(output='logits', pretrained='imagenet')
params = vgg16.init(key, x)
out = vgg16.apply(params, x, train=False)
```
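With output='logits', `out` holds unnormalized class scores. Assuming it has shape (batch, 1000) for the 1000 ImageNet classes, a softmax followed by `jax.lax.top_k` gives the most likely classes. This is a minimal sketch: random logits stand in for the model output, and the helper name `top_k_predictions` is hypothetical, not part of flaxmodels. Mapping indices to class names is not shown.

```python
import jax
import jax.numpy as jnp


def top_k_predictions(logits: jnp.ndarray, k: int = 5):
    """Return (probabilities, class indices) of the k most likely classes.

    Hypothetical helper. `logits` has shape (batch, num_classes); softmax
    turns each row into a probability distribution, and lax.top_k picks the
    k largest entries along the last axis, sorted in descending order.
    """
    probs = jax.nn.softmax(logits, axis=-1)
    top_probs, top_idx = jax.lax.top_k(probs, k)
    return top_probs, top_idx


# Stand-in for `out` from vgg16.apply: random logits for 1000 classes.
logits = jax.random.normal(jax.random.PRNGKey(0), (1, 1000))
probs, idx = top_k_predictions(logits, k=5)
print(idx.shape)  # (1, 5)
```

Softmax is monotonic, so the top-k indices are the same whether you rank logits or probabilities. Computing probabilities just makes the scores easier to read.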
Usage is equivalent for VGG19.
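VGG16 and VGG19 differ only in depth. The small Python sketch below shows where the names and the 224x224 input size come from. It assumes the original architecture from the paper (configurations D and E): 3x3 convolutions with padding 1, five 2x2 max-pools, and three fully connected layers. The names `VGG_CONFIGS`, `count_weight_layers`, and `final_feature_size` are hypothetical; flaxmodels' actual layer definitions are not reproduced here.

```python
# Hypothetical layer configurations following the VGG paper:
# integers are 3x3 conv output channels, "M" is a 2x2 max-pool (stride 2).
VGG_CONFIGS = {
    "vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
              512, 512, 512, "M", 512, 512, 512, "M"],
    "vgg19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
              512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}
NUM_FC_LAYERS = 3  # two 4096-unit hidden layers plus the 1000-way classifier


def count_weight_layers(cfg):
    """Conv layers plus fully connected layers give the model's depth."""
    return sum(1 for v in cfg if v != "M") + NUM_FC_LAYERS


def final_feature_size(cfg, input_size=224):
    """Each 2x2 max-pool halves the spatial size; padded 3x3 convs keep it."""
    size = input_size
    for v in cfg:
        if v == "M":
            size //= 2
    return size


for name, cfg in VGG_CONFIGS.items():
    print(name, count_weight_layers(cfg), final_feature_size(cfg))
# vgg16 16 7
# vgg19 19 7
```

Five max-pools shrink a 224x224 input to a 7x7x512 feature map. In the original network, that is the size the first fully connected layer expects, which is why the classification head ties the model to 224x224 inputs.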
The documentation can be found here.
If you want to train VGG in JAX/Flax, go here.