Welcome to REAX
REAX is a lightweight training framework for JAX that removes boilerplate whilst preserving full flexibility and scaling to multi-GPU and multi-node training.
Why REAX?
- 🔧 Framework Agnostic
REAX works with any JAX neural network library: Flax Linen, Flax NNX, Equinox, Haiku, and more. Bring your own models.
- ⚡ Scale Effortlessly
Built-in support for multi-GPU and multi-node training with minimal code changes.
- 🎯 Stay in Control
Choose your level of abstraction: use the high-level Trainer or drop down to the Engine for custom training loops.
- 📊 Experiment Tracking
Integrated logging, checkpointing, and experiment management out of the box.
Quick Example
import reax
from flax import nnx
import optax

# Works with any JAX library - Flax NNX example
class MyModel(reax.Module):
    def __init__(self, rngs):
        # A single linear layer mapping 784 inputs (e.g. flattened 28x28 images) to 10 classes
        self.linear = nnx.Linear(784, 10, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # Mean cross-entropy over the batch; y holds integer class labels
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optax.adam(1e-3)

# train_loader and val_loader are your own data loaders yielding (x, y) batches
model = MyModel(rngs=nnx.Rngs(0))

# Train on multiple GPUs with one line
trainer = reax.Trainer(accelerator="gpu", devices=4)
trainer.fit(model, train_loader, val_loader)
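To make concrete what `training_step`, `configure_optimizers`, and the Trainer take care of, here is a minimal hand-written equivalent in plain JAX. It is a sketch under stated assumptions, not a description of REAX's internals: the model is a single linear layer held in a parameter dict (standing in for `nnx.Linear(784, 10)`), the loss is written out with `jax.nn.log_softmax` instead of calling optax, and plain SGD replaces `optax.adam`. The names `init_params`, `forward`, `loss_fn`, `train_step`, and `LEARNING_RATE` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical learning rate for plain SGD (the Quick Example uses optax.adam(1e-3)).
LEARNING_RATE = 0.01


def init_params(key, in_features=784, out_features=10):
    # Hypothetical stand-in for nnx.Linear(784, 10): small random weights, zero bias.
    return {
        "w": jax.random.normal(key, (in_features, out_features)) * 0.01,
        "b": jnp.zeros((out_features,)),
    }


def forward(params, x):
    return x @ params["w"] + params["b"]


def loss_fn(params, x, y):
    # Mean softmax cross-entropy with integer labels, the same quantity as
    # optax.softmax_cross_entropy_with_integer_labels(logits, y).mean().
    logits = forward(params, x)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, y[:, None], axis=-1)[:, 0]
    return nll.mean()


@jax.jit
def train_step(params, x, y):
    # Compute the loss and its gradient with respect to every parameter.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Plain SGD update applied leaf by leaf across the parameter dict.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Usage on a random toy batch: 32 samples, 784 features, labels in [0, 10).
params = init_params(jax.random.PRNGKey(0))
x = jax.random.normal(jax.random.PRNGKey(1), (32, 784))
y = jax.random.randint(jax.random.PRNGKey(2), (32,), 0, 10)

params, first_loss = train_step(params, x, y)
for _ in range(20):
    params, loss = train_step(params, x, y)
```

Everything outside `loss_fn` (initialising parameters, building the update step, looping over batches, moving data to devices) is the kind of code the framework is meant to absorb.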
Installation 🛠️
pip install reax
Or with conda:
conda install reax -c conda-forge
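After installing, it can be worth confirming that JAX itself sees your accelerators, since a setting such as `devices=4` in the Quick Example presumably needs at least that many visible devices. This check uses only standard JAX calls, not REAX, and assumes a JAX build matching your hardware (CPU, CUDA, or TPU) is installed.

```python
import jax

# Which backend JAX selected by default ("cpu", "gpu", or "tpu").
backend = jax.default_backend()
# All devices visible to JAX for that backend.
devices = jax.devices()

print(f"JAX backend: {backend}")
print(f"Device count: {jax.device_count()}")
for d in devices:
    print(d)
```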
Get Started
New to REAX? Start here:
REAX in 15 Minutes - A quick tour of REAX
Installation - Detailed installation guide
Introduction - Interactive tutorial
Documentation
Get Started
User Guide
Examples
API Reference
Community & Support
GitHub: camml-lab/reax
Issues: Report bugs and request features on GitHub
Versioning
This software follows Semantic Versioning.