REAX in 15 Minutes#
REAX is a lightweight training framework that works with any JAX neural network library. This guide will show you the essential concepts in 15 minutes.
What Makes REAX Different?#
🔧 Library Agnostic: Unlike other frameworks, REAX doesn’t force you to use a specific neural network library. Use Flax Linen, Flax NNX, Equinox, Haiku, or any JAX-based library you prefer.
⚡ Minimal Boilerplate: REAX handles the training loop, distributed training, logging, and checkpointing so you can focus on your model.
🎯 Flexible Abstraction: Use the high-level reax.Trainer for standard workflows, or drop down to the reax.Engine for custom training loops.
The 7 Key Steps#
1. Install REAX#
#!pip install reax
2. Define Your Model#
REAX works with any JAX neural network library. Here’s an example using Flax NNX:
import functools
from typing import Any

import jaxtyping as jt
import reax
from flax import nnx
import optax
class ImageClassifier(reax.Module):
    """REAX module that trains a small Flax NNX convolutional classifier."""

    class Model(nnx.Module):
        """Two convolution + max-pool stages followed by a linear read-out."""

        def __init__(self, num_classes: int, rngs: nnx.Rngs):
            super().__init__()
            self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
            self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
            # 64 channels x 7 x 7 positions; presumably sized for 28x28 inputs
            # (MNIST) -- TODO confirm against the data module.
            self.linear = nnx.Linear(64 * 7 * 7, num_classes, rngs=rngs)

        def __call__(self, x):
            """Return class logits for a batch of channel-less images."""
            # Append the trailing channels dimension expected by the convolutions.
            features = x.reshape((*x.shape, 1))
            for conv in (self.conv1, self.conv2):
                features = nnx.max_pool(
                    nnx.relu(conv(features)),
                    window_shape=(2, 2),
                    strides=(2, 2),
                )
            # Flatten everything except the leading (batch) dimension.
            flat = features.reshape((features.shape[0], -1))
            return self.linear(flat)

    def __init__(self, num_classes: int, rngs: nnx.Rngs):
        super().__init__()
        self.model = ImageClassifier.Model(num_classes, rngs)

    @staticmethod
    def loss(model, x, y):
        """Return ``(mean softmax cross-entropy, logits)`` for inputs ``x`` and labels ``y``."""
        logits = model(x)
        per_example = optax.softmax_cross_entropy_with_integer_labels(logits, y)
        return per_example.mean(), logits

    def training_step(self, batch, batch_idx):
        """Log the training loss and return it with the gradients as a pure dict."""
        grad_fn = nnx.value_and_grad(self.loss, has_aux=True)
        (loss, _logits), grads = grad_fn(self.model, *batch)
        self.log("train_loss", loss)
        return loss, nnx.to_pure_dict(grads)

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accuracy for one batch."""
        x, y = batch
        loss, logits = self.loss(self.model, x, y)
        predicted = logits.argmax(axis=1)
        acc = (predicted == y).mean()
        self.log("val_loss", loss)
        self.log("val_acc", acc)

    def configure_model(self, stage: "reax.Stage", batch: Any, /) -> None:
        """Register the NNX model's ``Param`` state with REAX as a pure dict."""
        param_state = nnx.state(self.model, nnx.Param)
        self.set_parameters(nnx.to_pure_dict(param_state))

    def configure_optimizers(self):
        """Return an Adam optimiser (learning rate 0.01) and its initial state."""
        assert self.parameters() is not None  # nosec B101
        optimiser = optax.adam(learning_rate=0.01)
        return optimiser, optimiser.init(self.parameters())

    def set_parameters(self, params: jt.PyTree):
        """Store ``params`` on the REAX module and write them into the NNX model."""
        super().set_parameters(params)
        # Keep the live NNX model in sync with the parameters REAX holds.
        nnx.update(self.model, params)
The reax.Module organises your code into clear sections:
- Model definition (__init__, __call__)
- Training logic (training_step)
- Validation logic (validation_step)
- Optimiser configuration (configure_optimizers)
3. Prepare Your Data#
REAX works with any iterable (DataLoader, numpy arrays, lists, etc.):
from reax import demos
# MNIST data module from the reax.demos package
mnist = demos.mnist.MnistDataModule()
4. Train Your Model#
The reax.Trainer handles the training loop automatically:
# Initialise the model: 10 output classes, NNX random state seeded with 42
model = ImageClassifier(num_classes=10, rngs=nnx.Rngs(42))
# Create a trainer with no extra configuration
trainer = reax.Trainer()
# Train on the MNIST data module with max_epochs=10
trainer.fit(model, datamodule=mnist, max_epochs=10)
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
An exception has occurred, use %tb to see the full traceback.
SystemExit: 1
/home/docs/checkouts/readthedocs.org/user_builds/reax/envs/latest/lib/python3.13/site-packages/IPython/core/interactiveshell.py:3756: UserWarning: To exit: use 'exit', 'quit', or Ctrl-D.
warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
5. Scale to Multiple GPUs#
Want to train on 4 GPUs? Just change one line:
# Single GPU
trainer = reax.Trainer(max_epochs=10)
# 4 GPUs with Distributed Data Parallel (DDP)
# trainer = reax.Trainer(max_epochs=10, accelerator="gpu", devices=4, strategy="ddp")
6. Add Logging and Checkpointing#
Track your experiments with built-in logger support:
from reax.loggers import TensorBoardLogger
# TensorBoard logger; "logs/" is presumably the output directory -- confirm against the logger API
logger = TensorBoardLogger("logs/", name="my_experiment")
# Trainer that reports to the logger above and has checkpointing switched on
trainer = reax.Trainer(
max_epochs=10,
logger=logger,
enable_checkpointing=True # Automatically saves best model
)
7. Use Your Trained Model#
After training, use your model for inference:
# Load the best checkpoint
best_model_path = trainer.checkpoint_listeners[0].best_model_path
checkpoint = trainer.checkpointing.load(best_model_path)
# Restore the saved parameters into the model
model.set_parameters(checkpoint["parameters"])
# Make predictions
# NOTE(review): test_loader is not defined anywhere in this guide; supply your own
# iterable of test batches before running this line.
predictions = trainer.predict(model, test_loader)
Works with Any JAX Library#
REAX works equally well with other libraries:
Flax Linen#
from flax import linen as nn
import jax
import optax
class LinenModel(reax.Module):
    """REAX module wrapping a single Flax Linen ``Dense`` layer with 10 outputs."""

    def __init__(self):
        super().__init__()
        self.model = nn.Dense(10)

    def configure_model(self, stage: reax.Stage, batch, /):
        """Initialise the Linen parameters from ``batch`` if none are set yet."""
        if self.parameters() is None:
            x, _ = batch
            params = self.model.init(self.rngs(), x)
            self.set_parameters(params)

    def __call__(self, x):
        """Apply the layer to ``x`` using the currently stored parameters."""
        return self.model.apply(self.parameters(), x)

    def training_step(self, batch, batch_idx):
        """Log the training loss and return ``(loss, grads)`` for one batch.

        Gradients are taken with respect to the parameters only (``argnums=0``).
        """
        x, y = batch
        loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(
            self.parameters(), x, y, self.model.apply
        )
        self.log("train_loss", loss)
        return loss, grads

    @staticmethod
    # ``apply_fn`` is a Python callable, not an array, so it must be a static
    # argument: jax.jit rejects non-array values passed as traced arguments.
    @functools.partial(jax.jit, static_argnames="apply_fn")
    def loss_fn(params, x, y, apply_fn):
        """Return the mean softmax cross-entropy of ``apply_fn(params, x)`` against ``y``.

        Args:
            params: Parameter pytree passed to ``apply_fn``.
            x: Batch of inputs.
            y: Integer class labels.
            apply_fn: Hashable callable mapping ``(params, x)`` to logits.
        """
        logits = apply_fn(params, x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    def configure_optimizers(self):
        """Return an Adam optimiser (learning rate 1e-3) and its initial state."""
        opt = optax.adam(learning_rate=1e-3)
        state = opt.init(self.parameters())
        return opt, state
Equinox#
import equinox as eqx
import jax
import optax
class EquinoxModel(reax.Module):
    """REAX module whose parameters are an Equinox ``MLP`` with 10 outputs.

    The whole Equinox model is stored as the REAX parameters. It contains
    non-array leaves (its activation functions), so the Equinox ``filter_*``
    transformations are used instead of the plain JAX ones, which reject
    non-array inputs.
    """

    def __init__(self):
        super().__init__()

    def configure_model(self, stage: reax.Stage, batch, /):
        """Build the MLP from the input width of ``batch`` if no model is set yet."""
        if self.parameters() is None:
            x, _ = batch
            model = eqx.nn.MLP(
                in_size=x.shape[-1],
                out_size=10,
                width_size=128,
                depth=2,
                key=self.rngs(),
            )
            self.set_parameters(model)

    def __call__(self, x):
        """Apply the stored model to every example in the batch ``x``."""
        return jax.vmap(self.parameters())(x)

    def training_step(self, batch, batch_idx):
        """Log the training loss and return ``(loss, grads)`` for one batch.

        Gradients are taken with respect to the floating-point array leaves of
        the model; all other leaves are ``None`` in ``grads``.
        """
        x, y = batch
        # NOTE(review): confirm that REAX applies updates in a way that tolerates
        # ``None`` gradient leaves (e.g. as ``eqx.apply_updates`` does).
        loss, grads = eqx.filter_value_and_grad(self.loss_fn)(
            self.parameters(), x, y
        )
        self.log("train_loss", loss)
        return loss, grads

    @staticmethod
    # filter_jit traces array leaves and treats everything else as static, so
    # the model's activation functions do not trip the JIT.
    @eqx.filter_jit
    def loss_fn(model, x, y):
        """Return the mean softmax cross-entropy of ``model`` on ``(x, y)``.

        Args:
            model: Callable Equinox module applied per example via ``jax.vmap``.
            x: Batch of inputs.
            y: Integer class labels.
        """
        logits = jax.vmap(model)(x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    def configure_optimizers(self):
        """Return an Adam optimiser (learning rate 1e-3) and its initial state."""
        opt = optax.adam(learning_rate=1e-3)
        # Initialise optimiser state only for the floating-point array leaves,
        # mirroring the leaves that receive gradients in training_step.
        state = opt.init(eqx.filter(self.parameters(), eqx.is_inexact_array))
        return opt, state
Haiku#
import haiku as hk
import jax
import optax
class HaikuModel(reax.Module):
    """REAX module wrapping a Haiku MLP with layer sizes ``[128, 10]``."""

    def __init__(self):
        super().__init__()

        def forward_fn(x):
            mlp = hk.nets.MLP(output_sizes=[128, 10])
            return mlp(x)

        # Pure init/apply pair; apply takes no RNG argument.
        self.forward_transformed = hk.without_apply_rng(hk.transform(forward_fn))

    def configure_model(self, stage: reax.Stage, batch, /):
        """Initialise the Haiku parameters from ``batch`` if none are set yet."""
        if self.parameters() is None:
            x, _ = batch
            params = self.forward_transformed.init(rng=self.rngs(), x=x)
            self.set_parameters(params)

    def __call__(self, x):
        """Apply the network to ``x`` using the currently stored parameters."""
        return self.forward_transformed.apply(params=self.parameters(), x=x)

    def training_step(self, batch, batch_idx):
        """Log the training loss and return ``(loss, grads)`` for one batch.

        Gradients are taken with respect to the parameters only (``argnums=0``).
        """
        x, y = batch
        loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(
            self.parameters(), x, y, self.forward_transformed.apply
        )
        self.log("train_loss", loss)
        return loss, grads

    @staticmethod
    # ``apply_fn`` is a Python callable, not an array, so it must be a static
    # argument: jax.jit rejects non-array values passed as traced arguments.
    @functools.partial(jax.jit, static_argnames="apply_fn")
    def loss_fn(params, x, y, apply_fn):
        """Return the mean softmax cross-entropy of ``apply_fn`` on ``(x, y)``.

        Args:
            params: Parameter pytree passed to ``apply_fn``.
            x: Batch of inputs.
            y: Integer class labels.
            apply_fn: Hashable callable invoked as ``apply_fn(params=..., x=...)``.
        """
        logits = apply_fn(params=params, x=x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    def configure_optimizers(self):
        """Return an Adam optimiser (learning rate 1e-3) and its initial state."""
        opt = optax.adam(learning_rate=1e-3)
        state = opt.init(self.parameters())
        return opt, state
Next Steps#
Now that you understand the basics, explore:
Level Up Your Skills
See Examples
API Reference