{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# REAX in 15 Minutes\n", "\n", "**REAX** is a lightweight training framework that works with any JAX neural network library.\n", "This guide will show you the essential concepts in 15 minutes." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What Makes REAX Different?\n", "\n", "* 🔧 **Library Agnostic**: Unlike frameworks built around a single library, REAX doesn't force you to use a specific neural network library. Use Flax Linen, Flax NNX, Equinox, Haiku, or any JAX-based library you prefer.\n", "* ⚡ **Minimal Boilerplate**: REAX handles the training loop, distributed training, logging, and checkpointing so you can focus on your model.\n", "* 🎯 **Flexible Abstraction**: Use the high-level `reax.Trainer` for standard workflows, or drop down to the `reax.Engine` for custom training loops." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The 7 Key Steps\n", "\n", "### 1. Install REAX" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Uncomment the next line to install REAX into the current environment\n", "# %pip install reax" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2. Define Your Model\n", "\n", "REAX works with **any** JAX neural network library. 
Here's an example using Flax NNX:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from typing import Any\n", "\n", "import jaxtyping as jt\n", "\n", "import reax\n", "from flax import nnx\n", "import optax\n", "\n", "class ImageClassifier(reax.Module):\n", "    class Model(nnx.Module):\n", "        def __init__(self, num_classes: int, rngs: nnx.Rngs):\n", "            super().__init__()\n", "            self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)\n", "            self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)\n", "            # 28x28 inputs are reduced to 7x7 by the two 2x2 max-pools\n", "            self.linear = nnx.Linear(64 * 7 * 7, num_classes, rngs=rngs)\n", "\n", "        def __call__(self, x):\n", "            x = x.reshape(*x.shape, 1)  # Add the trailing channels dimension\n", "            x = nnx.relu(self.conv1(x))\n", "            x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))\n", "            x = nnx.relu(self.conv2(x))\n", "            x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))\n", "            x = x.reshape((x.shape[0], -1))\n", "            return self.linear(x)\n", "\n", "    def __init__(self, num_classes: int, rngs: nnx.Rngs):\n", "        super().__init__()\n", "        self.model = ImageClassifier.Model(num_classes, rngs)\n", "\n", "    @staticmethod\n", "    def loss(model, x, y):\n", "        logits = model(x)\n", "        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", "        return loss, logits\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        # Gradients w.r.t. the model's parameters; the logits are returned as auxiliary output\n", "        (loss, _), grads = nnx.value_and_grad(self.loss, has_aux=True)(self.model, x, y)\n", "        self.log(\"train_loss\", loss)\n", "        return loss, nnx.to_pure_dict(grads)\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        loss, logits = self.loss(self.model, x, y)\n", "        acc = (logits.argmax(axis=1) == y).mean()\n", "        self.log(\"val_loss\", loss)\n", 
"        self.log(\"val_acc\", acc)\n", "\n", "    def configure_model(self, stage: \"reax.Stage\", batch: Any, /) -> None:\n", "        params = nnx.state(self.model, nnx.Param)\n", "        params = nnx.to_pure_dict(params)\n", "        self.set_parameters(params)\n", "\n", "    def configure_optimizers(self):\n", "        assert self.parameters() is not None  # nosec B101\n", "        optimiser = optax.adam(learning_rate=0.01)\n", "        state = optimiser.init(self.parameters())\n", "        return optimiser, state\n", "\n", "    def set_parameters(self, params: jt.PyTree):\n", "        super().set_parameters(params)\n", "        nnx.update(self.model, params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `reax.Module` organises your code into clear sections:\n", "\n", "* **Model definition** (`__init__`, plus the `__call__` of the wrapped network)\n", "* **Parameter setup** (`configure_model`, `set_parameters`)\n", "* **Training logic** (`training_step`)\n", "* **Validation logic** (`validation_step`)\n", "* **Optimiser configuration** (`configure_optimizers`)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 3. Prepare Your Data\n", "\n", "REAX accepts any iterable of batches (a DataLoader, NumPy arrays, lists, etc.) as well as data modules. This example uses the MNIST data module from `reax.demos`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from reax import demos\n", "\n", "mnist = demos.mnist.MnistDataModule()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 4. Train Your Model\n", "\n", "The `reax.Trainer` handles the training loop automatically:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Initialise the model\n", "model = ImageClassifier(num_classes=10, rngs=nnx.Rngs(42))\n", "\n", "# Create a trainer\n", "trainer = reax.Trainer()\n", "\n", "# Train!\n", "trainer.fit(model, datamodule=mnist, max_epochs=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 5. Scale to Multiple GPUs\n", "\n", "Want to train on 4 GPUs? 
Just change one line:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Single GPU\n", "trainer = reax.Trainer(max_epochs=10)\n", "\n", "# 4 GPUs with Distributed Data Parallel (DDP)\n", "# trainer = reax.Trainer(max_epochs=10, accelerator=\"gpu\", devices=4, strategy=\"ddp\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 6. Add Logging and Checkpointing\n", "\n", "Track your experiments with built-in logger support:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from reax.loggers import TensorBoardLogger\n", "\n", "logger = TensorBoardLogger(\"logs/\", name=\"my_experiment\")\n", "trainer = reax.Trainer(\n", "    max_epochs=10,\n", "    logger=logger,\n", "    enable_checkpointing=True  # Automatically saves best model\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 7. Use Your Trained Model\n", "\n", "After training, use your model for inference:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load the best checkpoint\n", "best_model_path = trainer.checkpoint_listeners[0].best_model_path\n", "checkpoint = trainer.checkpointing.load(best_model_path)\n", "model.set_parameters(checkpoint[\"parameters\"])\n", "\n", "# Make predictions; `test_loader` is a placeholder for any iterable of input batches\n", "predictions = trainer.predict(model, test_loader)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Works with Any JAX Library\n", "\n", "REAX works just as well with other libraries. The pattern stays the same: create the parameters in `configure_model`, compute the gradients in `training_step` and return `(loss, grads)`, and set up the optimiser in `configure_optimizers`.\n", "\n", "### Flax Linen" ] }, 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import functools\n", "\n", "from flax import linen as nn\n", "import jax\n", "import optax\n", "\n", "class LinenModel(reax.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.model = nn.Dense(10)\n", "\n", "    def configure_model(self, stage: reax.Stage, batch, /):\n", "        if self.parameters() is None:\n", "            x, _ = batch\n", "            params = self.model.init(self.rngs(), x)\n", "            self.set_parameters(params)\n", "\n", "    def __call__(self, x):\n", "        return self.model.apply(self.parameters(), x)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(\n", "            self.parameters(), x, y, self.model.apply\n", "        )\n", "        self.log(\"train_loss\", loss)\n", "        return loss, grads\n", "\n", "    # `apply_fn` is a Python callable rather than an array, so `jax.jit` must treat it as static\n", "    @staticmethod\n", "    @functools.partial(jax.jit, static_argnames=\"apply_fn\")\n", "    def loss_fn(params, x, y, apply_fn):\n", "        logits = apply_fn(params, x)\n", "        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", "\n", "    def configure_optimizers(self):\n", "        opt = optax.adam(learning_rate=1e-3)\n", "        state = opt.init(self.parameters())\n", "        return opt, state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Equinox" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import equinox as eqx\n", "import jax\n", "import optax\n", "\n", "class EquinoxModel(reax.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "\n", "    def configure_model(self, stage: reax.Stage, batch, /):\n", "        if self.parameters() is None:\n", "            x, _ = batch\n", "            model = eqx.nn.MLP(\n", "                in_size=x.shape[-1],\n", "                out_size=10,\n", "                width_size=128,\n", "                depth=2,\n", "                key=self.rngs()\n", "            )\n", "            self.set_parameters(model)\n", "\n", "    def __call__(self, x):\n", "        return jax.vmap(self.parameters())(x)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        # The MLP holds non-array leaves (e.g. its activation function), so use Equinox's\n", "        # filtered transforms, which differentiate only the floating-point array leaves\n", "        loss, grads = eqx.filter_value_and_grad(self.loss_fn)(\n", "            self.parameters(), x, y\n", "        )\n", "        self.log(\"train_loss\", loss)\n", "        return loss, grads\n", "\n", "    @staticmethod\n", "    @eqx.filter_jit\n", "    def loss_fn(model, x, y):\n", "        logits = jax.vmap(model)(x)\n", "        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", "\n", "    def configure_optimizers(self):\n", "        opt = optax.adam(learning_rate=1e-3)\n", "        # Initialise the optimiser state on the array leaves only\n", "        state = opt.init(eqx.filter(self.parameters(), eqx.is_array))\n", "        return opt, state" ] }, 
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Haiku" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import functools\n", "\n", "import haiku as hk\n", "import jax\n", "import optax\n", "\n", "class HaikuModel(reax.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        def forward_fn(x):\n", "            mlp = hk.nets.MLP(output_sizes=[128, 10])\n", "            return mlp(x)\n", "        self.forward_transformed = hk.without_apply_rng(hk.transform(forward_fn))\n", "\n", "    def configure_model(self, stage: reax.Stage, batch, /):\n", "        if self.parameters() is None:\n", "            x, _ = batch\n", "            params = self.forward_transformed.init(rng=self.rngs(), x=x)\n", "            self.set_parameters(params)\n", "\n", "    def __call__(self, x):\n", "        return self.forward_transformed.apply(params=self.parameters(), x=x)\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        x, y = batch\n", "        loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(\n", "            self.parameters(), x, y, self.forward_transformed.apply\n", "        )\n", "        self.log(\"train_loss\", loss)\n", "        return loss, grads\n", "\n", "    # `apply_fn` is a Python callable rather than an array, so `jax.jit` must treat it as static\n", "    @staticmethod\n", "    @functools.partial(jax.jit, static_argnames=\"apply_fn\")\n", "    def loss_fn(params, x, y, apply_fn):\n", "        logits = apply_fn(params=params, x=x)\n", "        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", "\n", "    def configure_optimizers(self):\n", "        opt = optax.adam(learning_rate=1e-3)\n", "        state = opt.init(self.parameters())\n", "        return opt, state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Next Steps\n", "\n", "Now that you understand the basics, explore:\n", "\n", "**Level Up Your Skills**\n", "* [REAX Modules](../user_guide/module.rst)\n", "* [Master the Trainer](../user_guide/trainer.rst)\n", "* [Scale to multiple nodes](../user_guide/distributed.rst)\n", "* [Advanced checkpointing strategies](../user_guide/checkpointing.rst)\n", "\n", "**See Examples**\n", "* [Real-world 
examples](../examples/index.rst)\n", "\n", "**API Reference**\n", "* [Detailed API documentation](../api/index.rst)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 4 }