Welcome to REAX#

REAX is a lightweight training framework for JAX. It removes training-loop boilerplate while leaving you in control of your models and code, and it scales from a single device to multi-GPU and multi-node training.

Why REAX?#

🔧 Framework Agnostic

REAX works with any JAX neural network library: Flax Linen, Flax NNX, Equinox, Haiku, and more. Bring your own models.

Scale Effortlessly

Built-in support for multi-GPU and multi-node training with minimal code changes.

🎯 Stay in Control

Choose your level of abstraction: use the high-level Trainer or drop down to the Engine for custom training loops.

📊 Experiment Tracking

Integrated logging, checkpointing, and experiment management out of the box.

Quick Example#

import reax
from flax import nnx
import optax

# Works with any JAX library - Flax NNX example
class MyModel(reax.Module):
    def __init__(self, rngs):
        self.linear = nnx.Linear(784, 10, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return optax.adam(1e-3)

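# Instantiate the model (the seed 0 is arbitrary)
model = MyModel(rngs=nnx.Rngs(0))

# train_loader and val_loader are your own data loaders, defined elsewhere.
# Train on multiple GPUs by setting `devices`
trainer = reax.Trainer(accelerator="gpu", devices=4)
trainer.fit(model, train_loader, val_loader)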
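
The `training_step` in the example above reduces each batch to a single scalar with `optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()`. As a rough illustration of what that expression computes, here is a dependency-free sketch in plain Python. The helper names and the toy numbers are assumptions made for illustration; they are not part of REAX or Optax.

```python
import math


def cross_entropy_with_integer_labels(logits, labels):
    """Per-example softmax cross-entropy for integer class labels.

    Illustrative stand-in for the Optax call used in the example above
    (hypothetical helper, not a REAX or Optax function).
    """
    losses = []
    for row, label in zip(logits, labels):
        # Subtract the row maximum before exponentiating for numerical stability.
        shift = max(row)
        log_norm = shift + math.log(sum(math.exp(v - shift) for v in row))
        # The loss is the negative log-probability of the correct class.
        losses.append(log_norm - row[label])
    return losses


def mean(values):
    """Arithmetic mean of a non-empty sequence."""
    return sum(values) / len(values)


# Toy batch: two examples, three classes (made-up numbers).
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
labels = [0, 2]
loss = mean(cross_entropy_with_integer_labels(logits, labels))
print(f"mean loss: {loss:.4f}")
```

In a real training step the same quantity is computed on JAX arrays so that it can be differentiated and compiled; the sketch only shows the arithmetic.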

Installation 🛠️#

pip install reax

Or with conda:

conda install reax -c conda-forge
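
After either command, one way to confirm that the package is visible to the current interpreter is to query its distribution metadata with the standard library. This is a minimal sketch: it assumes the distribution is registered under the name `reax` (as the `pip install reax` command suggests), and the helper `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(distribution: str) -> Optional[str]:
    """Return the installed version string, or None if it is not installed."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


# "reax" is assumed to be the distribution name used by the install commands.
found = installed_version("reax")
if found is None:
    print("reax is not installed in this environment")
else:
    print(f"reax {found} is installed")
```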

Get Started#

New to REAX? Start here:

Documentation#

  • Get Started

  • User Guide

  • Examples

  • API Reference

Community & Support#

  • GitHub: camml-lab/reax

  • Issues: Report bugs and request features on GitHub

Versioning#

This software follows Semantic Versioning.
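
Under Semantic Versioning, a release number has the form MAJOR.MINOR.PATCH, and only a MAJOR bump is expected to break the public API. As a small sketch of what that means when pinning a dependency, the hypothetical helper below checks whether an installed version meets a minimum requirement without crossing a major boundary. The version strings are made-up examples rather than actual REAX releases, and the sketch ignores pre-release tags and the weaker stability guarantees of 0.x releases.

```python
def parse_version(text):
    """Split 'MAJOR.MINOR.PATCH' into a tuple of three integers."""
    major, minor, patch = (int(part) for part in text.split("."))
    return major, minor, patch


def is_compatible(installed, required):
    """True if `installed` is at least `required` and shares its MAJOR version."""
    have, need = parse_version(installed), parse_version(required)
    return have[0] == need[0] and have >= need


# Hypothetical version numbers, for illustration only.
print(is_compatible("1.4.2", "1.2.0"))  # True: same major, newer minor
print(is_compatible("2.0.0", "1.2.0"))  # False: major version changed
```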