
Solstice

Solstice is a library for constructing modular and structured deep learning experiments in JAX. It is built with Equinox, but designed for full interoperability with other JAX neural network libraries such as Stax, Haiku, Flax, and Optax.

Why use Solstice in a world with Flax/Haiku/Objax/...? Solstice is not a neural network framework. It is a system for organising JAX code, with a small library of sane defaults for common use cases (think PyTorch Lightning, but for JAX). The library itself is simple and flexible, leaving the most important decisions to the user; we aim to provide high-quality examples that demonstrate the different ways you can use this flexibility.

Solstice is in the alpha stage of development; there may be API changes until we settle on a stable version 1.0.0.

Installation

First, install JAX, then:

pip install solstice-jax
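
For example, a CPU-only setup might look like this (the jax[cpu] extra follows the JAX installation guide at the time of writing; see that guide for GPU/TPU wheels):

pip install --upgrade "jax[cpu]"
pip install solstice-jax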

Solstice is fully documented, including a full API reference, as well as tutorials and examples. Below, we provide a bare-minimum example to get you started.

Getting Started

The central abstraction in Solstice is the solstice.Experiment. An Experiment is a container for all functions and stateful objects that are relevant to a run. You can create an Experiment by subclassing solstice.Experiment and implementing the abstract methods for initialisation, training, and evaluation. Experiments are best used with solstice.Metrics for tracking metrics and solstice.train() so you can stop writing boilerplate training loops.

from typing import Any, Tuple
import logging
import jax
import jax.numpy as jnp
import solstice
import tensorflow_datasets as tfds

logging.basicConfig(level=logging.INFO)


class RandomClassifier(solstice.Experiment):
    """A terrible, terrible classifier for binary class problems :("""

    rng_state: Any

    def __init__(self, rng: int):
        self.rng_state = jax.random.PRNGKey(rng)

    def __call__(self, x):
        del x
        return jax.random.bernoulli(self.rng_state, p=0.5).astype(jnp.float32)

    @jax.jit
    def train_step(
        self, batch: Tuple[jnp.ndarray, ...]
    ) -> Tuple["RandomClassifier", solstice.Metrics]:
        x, y = batch
        preds = jax.vmap(self)(x)
        # use solstice Metrics API for convenient metrics calculation
        metrics = solstice.ClassificationMetrics(preds, y, loss=jnp.nan, num_classes=2)
        new_rng_state = jax.random.split(self.rng_state)[0]  # advance the RNG state so the next step differs

        return solstice.replace(self, rng_state=new_rng_state), metrics

    @jax.jit
    def eval_step(
        self, batch: Tuple[jnp.ndarray, ...]
    ) -> Tuple["RandomClassifier", solstice.Metrics]:
        x, y = batch
        preds = jax.vmap(self)(x)
        metrics = solstice.ClassificationMetrics(preds, y, loss=jnp.nan, num_classes=2)
        return self, metrics


train_ds = tfds.load(name="mnist", split="train", as_supervised=True)  # type: Any
train_ds = train_ds.batch(32).prefetch(1)
exp = RandomClassifier(42)
# use solstice.train() with callbacks to remove boilerplate code
trained_exp = solstice.train(
    exp, num_epochs=1, train_ds=train_ds, callbacks=[solstice.LoggingCallback()]
)

Notice that we were able to use pure JAX transformations such as jax.jit and jax.vmap within the class. This is because solstice.Experiment is just a subclass of equinox.Module. We explain this further in the Solstice Primer, but in general, if you understand JAX/Equinox, you will understand Solstice.
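
Because an Experiment is an equinox.Module, and therefore a registered pytree, you can inspect and transform it with standard JAX utilities. A minimal sketch, reusing the RandomClassifier defined above:

import jax

exp = RandomClassifier(42)
# the Experiment is a pytree: its only leaf is the PRNGKey array,
# which is why instances can flow through jax.jit'd methods directly
print(jax.tree_util.tree_leaves(exp))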

Incrementally buying-in


Solstice is a library, not a framework, and it is important to us that you have the freedom to use as little or as much of it as you like. If you are interested in using Solstice but don't know where to begin, here are three steps towards Solstice-ification.

Stage 1: organise your training code with solstice.Experiment

The Experiment object contains stateful objects such as model and optimizer parameters and also encapsulates the steps for training and evaluation. In Flax, this would replace the TrainState object and serve to better organise your code. At this stage, the main advantage is that your code is more readable and scalable because you can define different Experiments for different use cases.
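
For example, here is a minimal sketch of what that organisation might look like for a linear classifier trained with Optax. The model, hyperparameters, and flattened-input assumption are purely illustrative, not part of the Solstice API; solstice.replace and solstice.ClassificationMetrics are used as in the getting-started example above:

from typing import Any, Tuple

import jax
import jax.numpy as jnp
import optax
import solstice


class LinearClassifier(solstice.Experiment):
    """Sketch: an Experiment holding model params and optimiser state."""

    params: Any
    opt_state: Any

    def __init__(self, rng: int, in_dim: int = 784, num_classes: int = 10):
        key = jax.random.PRNGKey(rng)
        self.params = {
            "w": 0.01 * jax.random.normal(key, (in_dim, num_classes)),
            "b": jnp.zeros(num_classes),
        }
        self.opt_state = optax.sgd(1e-3).init(self.params)

    @jax.jit
    def train_step(
        self, batch: Tuple[jnp.ndarray, ...]
    ) -> Tuple["LinearClassifier", solstice.Metrics]:
        x, y = batch  # assumes pre-flattened float inputs, shape (batch, in_dim)

        def loss_fn(params):
            logits = x @ params["w"] + params["b"]
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
            return loss, logits

        (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(self.params)
        # optax transformations are stateless, so re-creating sgd here is
        # equivalent to storing it; only opt_state lives on the Experiment
        updates, new_opt_state = optax.sgd(1e-3).update(grads, self.opt_state)
        new_params = optax.apply_updates(self.params, updates)
        metrics = solstice.ClassificationMetrics(
            jnp.argmax(logits, -1), y, loss=loss, num_classes=10
        )
        return solstice.replace(self, params=new_params, opt_state=new_opt_state), metrics

    @jax.jit
    def eval_step(
        self, batch: Tuple[jnp.ndarray, ...]
    ) -> Tuple["LinearClassifier", solstice.Metrics]:
        x, y = batch
        logits = x @ self.params["w"] + self.params["b"]
        metrics = solstice.ClassificationMetrics(
            jnp.argmax(logits, -1), y, loss=jnp.nan, num_classes=10
        )
        return self, metrics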

Stage 2: implement solstice.Metrics for tracking metrics

A solstice.Metrics object knows how to calculate and accumulate intermediate results before computing final metrics. The main advantage is the ability to scalably track lots of metrics with a common interface. By tracking intermediate results and computing at the end, it is easier to handle metrics that are not 'averageable' over batches (e.g. precision).
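
As a sketch, a custom Metrics class for mean absolute error might look like the code below. The method names merge() and compute() are our assumption about the interface, and we assume solstice.replace works on Metrics as it does on Experiments; check the API reference for the exact abstract methods:

import jax.numpy as jnp
import solstice


class MAEMetrics(solstice.Metrics):
    """Sketch: accumulate sums and counts per batch, compute MAE at the end."""

    abs_error_sum: jnp.ndarray
    count: jnp.ndarray

    def __init__(self, preds: jnp.ndarray, targets: jnp.ndarray):
        # intermediate results: cheap to accumulate, exact to combine
        self.abs_error_sum = jnp.sum(jnp.abs(preds - targets))
        self.count = jnp.asarray(preds.shape[0])

    def merge(self, other: "MAEMetrics") -> "MAEMetrics":  # assumed hook name
        # combining intermediate results is exact, unlike averaging averages
        return solstice.replace(
            self,
            abs_error_sum=self.abs_error_sum + other.abs_error_sum,
            count=self.count + other.count,
        )

    def compute(self) -> dict:  # assumed hook name
        return {"mean_absolute_error": self.abs_error_sum / self.count}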

Stage 3: use the premade solstice.train() loop with solstice.Callbacks

Training loops are usually boilerplate code. We provide premade training and testing loops which integrate with a simple and flexible callback system. This allows you to separate the basic logic of training from customisable side effects such as logging and checkpointing. We provide some useful pre-made callbacks and give examples for how to write your own.
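
For instance, a minimal custom callback might look like the sketch below. The base class name solstice.Callback and the hook name on_epoch_end are assumptions for illustration; see the Callbacks API reference for the real interface:

import logging
import solstice


class PrintMetricsCallback(solstice.Callback):  # assumed base class
    """Sketch: a side-effect-only callback that logs metrics each epoch."""

    def on_epoch_end(self, exp, metrics):  # hook name is an assumption
        logging.info(f"Finished epoch with metrics: {metrics}")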

Our Logos

We have two Solstice logos: the Summer Solstice 🌞 and the Winter Solstice 🌛. Both were created with Dall-E mini (free license) using the following prompt:

a logo featuring stonehenge during a solstice

Solstice Logos