REAX in 15 Minutes

REAX is a lightweight training framework that works with any JAX neural network library. This guide will show you the essential concepts in 15 minutes.

What Makes REAX Different?

  • 🔧 Library Agnostic: Unlike other frameworks, REAX doesn’t force you to use a specific neural network library. Use Flax Linen, Flax NNX, Equinox, Haiku, or any JAX-based library you prefer.

  • ⚡ Minimal Boilerplate: REAX handles the training loop, distributed training, logging, and checkpointing so you can focus on your model.

  • 🎯 Flexible Abstraction: Use the high-level reax.Trainer for standard workflows, or drop down to the reax.Engine for custom training loops.

The 7 Key Steps

1. Install REAX

#!pip install reax

2. Define Your Model

REAX works with any JAX neural network library. Here’s an example using Flax NNX:

from typing import Any

import jaxtyping as jt

import reax
from flax import nnx
import optax

class ImageClassifier(reax.Module):
    """MNIST image classifier built from a small Flax NNX convolutional network.

    The NNX network is held in ``self.model``. ``configure_model`` registers its
    ``nnx.Param`` state with REAX as a pure dict, and ``set_parameters`` writes
    any new parameter values back into the NNX module so the two stay in sync.
    """

    class Model(nnx.Module):
        """Two conv/ReLU/max-pool stages followed by a linear classification head."""

        def __init__(self, num_classes: int, rngs: nnx.Rngs):
            super().__init__()
            self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
            self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
            # 64 channels on a 7x7 grid: presumably 28x28 inputs halved by each
            # of the two 2x2 max-pools below -- TODO confirm for other image sizes.
            self.linear = nnx.Linear(64 * 7 * 7, num_classes, rngs=rngs)

        def __call__(self, x):
            """Return class logits for a batch of single-channel images."""
            x = x.reshape(*x.shape, 1)  # Need the channels dimension
            x = nnx.relu(self.conv1(x))
            x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
            x = nnx.relu(self.conv2(x))
            x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))
            x = x.reshape((x.shape[0], -1))  # Flatten all but the batch axis
            return self.linear(x)

    def __init__(self, num_classes: int, rngs: nnx.Rngs):
        super().__init__()
        self.model = ImageClassifier.Model(num_classes, rngs)

    @staticmethod
    def loss(model, x, y):
        """Return ``(mean cross-entropy loss, logits)`` for integer labels ``y``.

        The logits are returned as auxiliary output so callers (e.g. the
        accuracy computation in ``validation_step``) avoid a second forward pass.
        """
        logits = model(x)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
        return loss, logits

    def training_step(self, batch, batch_idx):
        """Return the training loss and its parameter gradients for ``batch``.

        ``batch`` is unpacked into the ``(x, y)`` arguments of ``loss``. The
        gradients are converted to a pure dict so they match the structure that
        ``configure_model`` registers as this module's parameters.
        """
        grad_fn = nnx.value_and_grad(self.loss, has_aux=True)
        (loss, _logits), grads = grad_fn(self.model, *batch)
        self.log("train_loss", loss)
        return loss, nnx.to_pure_dict(grads)

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accuracy for one ``(x, y)`` batch."""
        x, y = batch
        loss, logits = self.loss(self.model, x, y)
        acc = (logits.argmax(axis=1) == y).mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def configure_model(self, stage: "reax.Stage", batch: Any, /) -> None:
        """Register the NNX ``Param`` state with REAX as a pure dict."""
        params = nnx.state(self.model, nnx.Param)
        params = nnx.to_pure_dict(params)
        self.set_parameters(params)

    def configure_optimizers(self):
        """Return an Adam optimiser and its state initialised from the parameters."""
        # Parameters only exist once configure_model has called set_parameters.
        assert self.parameters() is not None  # nosec B101
        optimiser = optax.adam(learning_rate=0.01)
        state = optimiser.init(self.parameters())
        return optimiser, state

    def set_parameters(self, params: jt.PyTree):
        """Store ``params`` with REAX and copy them into the NNX model."""
        super().set_parameters(params)
        nnx.update(self.model, params)

The reax.Module organises your code into clear sections:

  • Model definition and parameter set-up (__init__, configure_model)

  • Training logic (training_step)

  • Validation logic (validation_step)

  • Optimiser configuration (configure_optimizers)

3. Prepare Your Data

REAX works with any iterable (DataLoader, numpy arrays, lists, etc.) as well as data modules. Here we use the MNIST data module bundled with REAX's demos:

from reax import demos

# MNIST data module from REAX's demos. Judging by the output of step 4, it
# downloads the raw MNIST files into data/MNIST/raw the first time it is used.
mnist = demos.mnist.MnistDataModule()

4. Train Your Model

The reax.Trainer handles the training loop automatically:

# Initialise the model
model = ImageClassifier(num_classes=10, rngs=nnx.Rngs(42))

# Create a trainer
trainer = reax.Trainer()

# Train! Here the epoch budget is passed to ``fit``; steps 5 and 6 instead pass
# ``max_epochs`` to the ``Trainer`` constructor.
# NOTE(review): confirm which of the two call styles the current REAX API accepts.
trainer.fit(model, datamodule=mnist, max_epochs=10)
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
An exception has occurred, use %tb to see the full traceback.

SystemExit: 1
/home/docs/checkouts/readthedocs.org/user_builds/reax/envs/latest/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3756: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)

5. Scale to Multiple GPUs

Want to train on 4 GPUs? Just change one line:

# Single GPU
# NOTE(review): step 4 passes max_epochs to trainer.fit() instead -- confirm that
# reax.Trainer() accepts it as a constructor argument.
trainer = reax.Trainer(max_epochs=10)

# 4 GPUs with Distributed Data Parallel (DDP)
# trainer = reax.Trainer(max_epochs=10, accelerator="gpu", devices=4, strategy="ddp")

6. Add Logging and Checkpointing

Track your experiments with built-in logger support:

from reax.loggers import TensorBoardLogger

# TensorBoard logger; presumably "logs/" is the output directory and
# "my_experiment" the run name -- confirm against the TensorBoardLogger docs.
logger = TensorBoardLogger("logs/", name="my_experiment")
# NOTE(review): this rebinds ``trainer`` to a fresh, not-yet-fit Trainer. Call
# trainer.fit(...) with it before step 7 so there is a checkpoint to load.
trainer = reax.Trainer(
    max_epochs=10,
    logger=logger,
    enable_checkpointing=True  # Automatically saves best model
)

7. Use Your Trained Model

After training, use your model for inference:

# Load the best checkpoint.
# NOTE(review): this assumes ``trainer`` has been fit with checkpointing
# enabled; the trainer created in step 6 has not been fit yet, so call
# ``trainer.fit(model, datamodule=mnist)`` with it first.
best_model_path = trainer.checkpoint_listeners[0].best_model_path
checkpoint = trainer.checkpointing.load(best_model_path)
model.set_parameters(checkpoint["parameters"])

# Make predictions on the ``mnist`` data module from step 3 (this guide never
# defines a separate ``test_loader``).
# NOTE(review): ImageClassifier defines no predict_step -- confirm what
# Trainer.predict requires before relying on this cell.
predictions = trainer.predict(model, datamodule=mnist)

Works with Any JAX Library

REAX works equally well with other libraries:

Flax Linen

import functools

from flax import linen as nn
import jax
import optax

class LinenModel(reax.Module):
    """Minimal REAX module wrapping a single Flax Linen ``Dense(10)`` layer."""

    def __init__(self):
        super().__init__()
        self.model = nn.Dense(10)

    def configure_model(self, stage: reax.Stage, batch, /):
        """Initialise the Linen parameters from the first batch if not yet set."""
        if self.parameters() is None:
            x, _ = batch
            params = self.model.init(self.rngs(), x)
            self.set_parameters(params)

    def __call__(self, x):
        """Apply the Dense layer with the current parameters."""
        return self.model.apply(self.parameters(), x)

    def training_step(self, batch, batch_idx):
        """Return the training loss and its gradients w.r.t. the parameters."""
        x, y = batch
        loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(
            self.parameters(), x, y, self.model.apply
        )
        self.log("train_loss", loss)
        return loss, grads

    @staticmethod
    @functools.partial(jax.jit, static_argnames="apply_fn")
    def loss_fn(params, x, y, apply_fn):
        """Mean softmax cross-entropy of ``apply_fn(params, x)`` against labels ``y``.

        ``apply_fn`` is a Python callable, not an array, so it has to be a static
        argument: ``jax.jit`` rejects non-array arguments that it would trace.
        """
        # NOTE(review): static arguments must be hashable; this relies on the
        # Linen module (and hence its bound ``apply``) being hashable.
        logits = apply_fn(params, x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    def configure_optimizers(self):
        """Return an Adam optimiser and its state initialised from the parameters."""
        opt = optax.adam(learning_rate=1e-3)
        state = opt.init(self.parameters())
        return opt, state

Equinox

import equinox as eqx
import jax
import optax

class EquinoxModel(reax.Module):
    """REAX module whose parameters are an entire Equinox ``MLP`` PyTree."""

    def __init__(self):
        super().__init__()

    def configure_model(self, stage: reax.Stage, batch, /):
        """Build the MLP from the first batch's feature size if not yet set."""
        if self.parameters() is None:
            x, _ = batch
            model = eqx.nn.MLP(
                in_size=x.shape[-1],
                out_size=10,
                width_size=128,
                depth=2,
                key=self.rngs()
            )
            self.set_parameters(model)

    def __call__(self, x):
        """Apply the MLP to each example in the batch ``x``."""
        return jax.vmap(self.parameters())(x)

    def training_step(self, batch, batch_idx):
        """Return the training loss and its gradients w.r.t. the MLP's arrays."""
        x, y = batch
        # The MLP PyTree also holds non-array leaves (e.g. its activation
        # functions), which plain jax.value_and_grad rejects. The filtered
        # version differentiates the floating-point arrays only; other leaves
        # get None gradients.
        loss, grads = eqx.filter_value_and_grad(self.loss_fn)(
            self.parameters(), x, y
        )
        self.log("train_loss", loss)
        return loss, grads

    @staticmethod
    @eqx.filter_jit
    def loss_fn(model, x, y):
        """Mean softmax cross-entropy of the batched MLP output against ``y``."""
        # filter_jit traces only array leaves and treats the rest as static, so
        # the whole model can be passed in directly (plain jax.jit cannot).
        logits = jax.vmap(model)(x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    def configure_optimizers(self):
        """Return an Adam optimiser with state for the MLP's array leaves."""
        opt = optax.adam(learning_rate=1e-3)
        # optax can only build moment buffers for arrays, so replace the
        # non-array leaves with None before initialising.
        # NOTE(review): confirm REAX applies updates in a way that tolerates
        # these None placeholders (cf. eqx.apply_updates).
        state = opt.init(eqx.filter(self.parameters(), eqx.is_array))
        return opt, state

Haiku

import functools

import haiku as hk
import jax
import optax

class HaikuModel(reax.Module):
    """REAX module wrapping a two-layer Haiku MLP (128 hidden units, 10 outputs)."""

    def __init__(self):
        super().__init__()
        def forward_fn(x):
            mlp = hk.nets.MLP(output_sizes=[128, 10])
            return mlp(x)
        # without_apply_rng drops the rng argument from ``apply``.
        self.forward_transformed = hk.without_apply_rng(hk.transform(forward_fn))

    def configure_model(self, stage: reax.Stage, batch, /):
        """Initialise the Haiku parameters from the first batch if not yet set."""
        if self.parameters() is None:
            x, _ = batch
            params = self.forward_transformed.init(rng=self.rngs(), x=x)
            self.set_parameters(params)

    def __call__(self, x):
        """Apply the transformed forward function with the current parameters."""
        return self.forward_transformed.apply(params=self.parameters(), x=x)

    def training_step(self, batch, batch_idx):
        """Return the training loss and its gradients w.r.t. the parameters."""
        x, y = batch
        loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(
            self.parameters(), x, y, self.forward_transformed.apply
        )
        self.log("train_loss", loss)
        return loss, grads

    @staticmethod
    @functools.partial(jax.jit, static_argnames="apply_fn")
    def loss_fn(params, x, y, apply_fn):
        """Mean softmax cross-entropy of ``apply_fn(params=..., x=...)`` against ``y``.

        ``apply_fn`` is a Python function, not an array, so it has to be a
        static argument: ``jax.jit`` rejects non-array arguments that it would trace.
        """
        logits = apply_fn(params=params, x=x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    def configure_optimizers(self):
        """Return an Adam optimiser and its state initialised from the parameters."""
        opt = optax.adam(learning_rate=1e-3)
        state = opt.init(self.parameters())
        return opt, state

Next Steps

Now that you understand the basics, explore:

Level Up Your Skills

See Examples

API Reference