Introduction#
1. Install REAX#
pip install reax
2. Define a REAX Module#
A REAX Module keeps track of your model parameters and gives you a place to put the code for the various steps in your training loop (training_step, validation_step, etc.).
import os
from functools import partial
from flax import linen
import jax
import jax.numpy as jnp
import optax
import reax
from reax import demos
class Autoencoder(linen.Module):
    """Dense autoencoder with a 3-unit bottleneck and a ``28 * 28``-unit output.

    Both halves are two-layer MLPs with a 128-unit hidden layer and a ReLU
    in between.
    """

    def setup(self):
        """Create the encoder and decoder stacks."""
        # Encoder: input -> 128 -> ReLU -> 3 (the embedding).
        encoder_layers = [linen.Dense(128), linen.relu, linen.Dense(3)]
        self.encoder = linen.Sequential(encoder_layers)

        # Decoder: embedding -> 128 -> ReLU -> 28 * 28 outputs.
        decoder_layers = [linen.Dense(128), linen.relu, linen.Dense(28 * 28)]
        self.decoder = linen.Sequential(decoder_layers)

    def __call__(self, x):
        """Encode ``x`` and return the decoder's reconstruction of it."""
        return self.decoder(self.encoder(x))

    def encode(self, x):
        """Return only the encoder output (the embedding) for ``x``."""
        return self.encoder(x)
class ReaxAutoEncoder(reax.Module):
    """REAX module wrapping ``Autoencoder`` with a mean squared-error reconstruction loss."""

    def __init__(self):
        super().__init__()
        self.ae = Autoencoder()
        # Pre-bind the Flax ``encode`` method so ``encode()`` can vmap over it.
        self._encode = partial(self.ae.apply, method="encode")

    def configure_model(self, stage: reax.Stage, batch, /):
        """Create the model parameters from an example batch if none are set yet."""
        if self.parameters() is not None:
            # Parameters already exist; nothing to initialise.
            return

        # Flatten the example batch the same way the training step does, then
        # let Flax build the parameters from it.
        example_inputs, _ = self.prepare_batch(batch)
        self.set_parameters(self.ae.init(self.rngs(), example_inputs))

    def training_step(self, batch, batch_idx):
        """Compute the loss and its gradients w.r.t. the parameters for one batch.

        Logs the loss as ``"train_loss"`` and returns ``(loss, grads)``.
        """
        flat_inputs, _ = self.prepare_batch(batch)
        # Differentiate with respect to argument 0 (the parameters) only; the
        # apply function is handed over explicitly because ``loss_fn`` is static.
        loss_and_grads = jax.value_and_grad(self.loss_fn, argnums=0)
        loss, grads = loss_and_grads(self.parameters(), flat_inputs, self.ae.apply)
        self.log("train_loss", loss, on_step=True, prog_bar=True)
        return loss, grads

    @staticmethod
    @partial(jax.jit, static_argnums=2)
    def loss_fn(params, x_batch, apply_fn):
        """Return the mean squared error between ``x_batch`` and its reconstruction.

        ``apply_fn`` (argument 2) is marked static for ``jax.jit``; it is mapped
        over the leading axis of ``x_batch`` while ``params`` is shared.
        """
        batched_apply = jax.vmap(apply_fn, in_axes=(None, 0))
        reconstruction = batched_apply(params, x_batch)
        elementwise_error = optax.losses.squared_error(reconstruction, x_batch)
        return elementwise_error.mean()

    def encode(self, x_batch):
        """Return the encoder output for every item in ``x_batch``."""
        # ``prepare_batch`` expects an ``(inputs, targets)`` pair; there are no
        # targets here, so ``None`` stands in for them.
        flat_inputs, _ = self.prepare_batch((x_batch, None))
        batched_encode = jax.vmap(self._encode, in_axes=(None, 0))
        return batched_encode(self.parameters(), flat_inputs)

    def configure_optimizers(self):
        """Return an Adam optimiser (learning rate ``1e-3``) and its initial state."""
        optimizer = optax.adam(learning_rate=1e-3)
        return optimizer, optimizer.init(self.parameters())

    @staticmethod
    def prepare_batch(batch):
        """Split ``batch`` into ``(inputs, targets)`` and flatten each input to 1-D.

        The leading (batch) axis of the inputs is kept; all remaining axes are
        merged into one. The targets are passed through unchanged.
        """
        inputs, targets = batch
        batch_size = inputs.shape[0]
        return inputs.reshape(batch_size, -1), targets
# Instantiate the module. No parameters are created here: configure_model()
# builds them from an example batch.
autoencoder = ReaxAutoEncoder()
3. Define a dataset#
REAX supports any iterable (numpy arrays, lists, etc.) for the train/val/test/predict datasets.
# Set up the data: the MNIST demo dataset (download requested via download=True),
# wrapped in a REAX data loader that is later handed to the trainer.
dataset = demos.mnist.MnistDataset(download=True)
data_loader = reax.ReaxDataLoader(dataset)
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MnistDataset
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MnistDataset
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MnistDataset
downloaded https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MnistDataset
4. Train the model#
The REAX Trainer takes the module and dataset and combines them in a training loop, automating away most of the boilerplate.
# Create a trainer with default settings and fit the module on the data loader.
# NOTE(review): presumably limit_train_batches=100 and max_epochs=1 cap the run
# at 100 batches of a single epoch; confirm against the REAX Trainer docs.
# The trailing ";" keeps the notebook from echoing fit()'s return value.
trainer = reax.Trainer()
trainer.fit(autoencoder, data_loader, limit_train_batches=100, max_epochs=1);
5. Use the model#
import glob

# Load trained parameters from a checkpoint when one is available.
# The run directory ("version_N") and the epoch/step in the file name vary
# between runs, so a hard-coded path can raise FileNotFoundError. Keep the
# original path as the first choice, then fall back to the newest checkpoint
# found under ./reax_logs.
checkpoint = "./reax_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
if not os.path.isfile(checkpoint):
    candidates = glob.glob("./reax_logs/version_*/checkpoints/*.ckpt")
    checkpoint = max(candidates, key=os.path.getmtime, default=None)

if checkpoint is not None:
    ckpt = trainer.checkpointing.load(checkpoint)
    autoencoder.set_parameters(ckpt["parameters"])
else:
    # Nothing on disk to restore: continue with whatever parameters the
    # module currently holds (presumably those left by trainer.fit above).
    print("No checkpoint found under ./reax_logs; using the module's current parameters.")

# embed 4 fake images!
fake_image_batch = jax.random.uniform(trainer.rngs(), shape=(4, 28, 28))
fake_image_batch = trainer.engine.to_device(fake_image_batch)
embeddings = autoencoder.encode(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)
---------------------------------------------------------------------------
FileNotFoundError Traceback (most recent call last)
Cell In[4], line 2
1 checkpoint = "./reax_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
----> 2 ckpt = trainer.checkpointing.load(checkpoint)
3 autoencoder.set_parameters(ckpt["parameters"])
4
5 # embed 4 fake images!
File ~/checkouts/readthedocs.org/user_builds/reax/envs/latest/lib/python3.13/site-packages/reax/training/_checkpointing.py:67, in MsgpackCheckpointing.load(self, filepath)
65 @override
66 def load(self, filepath: str) -> CheckpointDict:
---> 67 with open(filepath, "rb") as file:
68 return flax.serialization.msgpack_restore(file.read())
FileNotFoundError: [Errno 2] No such file or directory: './reax_logs/version_0/checkpoints/epoch=0-step=100.ckpt'