REAX Module
The Module is the central building block of your model in REAX. It organises your
JAX code into five sections:
- Computations (__init__, __call__, etc.)
- Train Loop (training_step)
- Validation Loop (validation_step)
- Test Loop (test_step)
- Optimizers (configure_optimizers)
REAX Modules are fully compatible with Flax NNX.
Basic Example
Here is a minimal example of a REAX Module:
import reax
from flax import nnx
import optax


class MyModel(reax.Module):
    def __init__(self, din, dout, rngs: nnx.Rngs):
        super().__init__()
        # A single linear layer mapping din input features to dout class logits
        self.linear = nnx.Linear(din, dout, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        x, y = batch  # inputs and integer class labels
        logits = self(x)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
        self.log("train_loss", loss)  # record the loss for this step
        return loss

    def configure_optimizers(self):
        return optax.adam(1e-3)
The Life Cycle
The methods in the module are called by the Trainer in a specific order.
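The following toy sketch in plain Python makes that ordering concrete. It shows how a Lightning-style trainer might drive the hooks described in the sections below. `RecordingModule`, `toy_fit` and `toy_test` are hypothetical names and are not part of REAX. The real Trainer calls more hooks than this, and its exact sequence may differ. The sketch only illustrates the relative order of optimizer setup, training, validation and testing.

```python
# Hypothetical sketch of a Lightning-style hook order. This is NOT REAX's
# Trainer; it only records which hook is called, and when.


class RecordingModule:
    """Stand-in module that logs every hook call it receives."""

    def __init__(self):
        self.calls = []

    def configure_optimizers(self):
        self.calls.append("configure_optimizers")
        return "optimizer"  # placeholder object

    def training_step(self, batch, batch_idx):
        self.calls.append(f"training_step({batch_idx})")
        return 0.0  # placeholder loss

    def validation_step(self, batch, batch_idx):
        self.calls.append(f"validation_step({batch_idx})")

    def test_step(self, batch, batch_idx):
        self.calls.append(f"test_step({batch_idx})")


def toy_fit(module, train_batches, val_batches, max_epochs=1):
    """Hypothetical fit loop: set up optimizers, then train and validate each epoch."""
    module.configure_optimizers()
    for _ in range(max_epochs):
        for i, batch in enumerate(train_batches):
            module.training_step(batch, i)
        for i, batch in enumerate(val_batches):
            module.validation_step(batch, i)


def toy_test(module, test_batches):
    """Hypothetical test loop, run separately once training is finished."""
    for i, batch in enumerate(test_batches):
        module.test_step(batch, i)


model = RecordingModule()
toy_fit(model, train_batches=["a", "b"], val_batches=["c"])
toy_test(model, test_batches=["d"])
print(model.calls)
# ['configure_optimizers', 'training_step(0)', 'training_step(1)',
#  'validation_step(0)', 'test_step(0)']
```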
Training Step
The training_step() method is the heart of the training loop. It receives a batch of data
and the index of that batch (batch_idx). It should return either the loss as a scalar or a
dictionary containing the loss under the key 'loss'.
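A trainer can accept both return forms with a few lines of code. The sketch below is a hypothetical helper, `extract_loss`, that normalises either form to the loss value. It is an illustration only, not REAX's internal implementation.

```python
from collections.abc import Mapping


def extract_loss(step_output):
    """Hypothetical helper: return the loss from a training_step output.

    Accepts either a scalar loss or a mapping that stores the loss under "loss".
    """
    if isinstance(step_output, Mapping):
        if "loss" not in step_output:
            raise KeyError("training_step returned a dict without a 'loss' key")
        return step_output["loss"]
    # Anything else is treated as the scalar loss itself.
    return step_output


# Both return styles yield the same loss value.
print(extract_loss(0.25))                             # 0.25
print(extract_loss({"loss": 0.25, "accuracy": 0.9}))  # 0.25
```

In the dictionary form, extra entries such as the illustrative "accuracy" key are simply carried alongside the loss.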
Validation Step
The validation_step() method is called during validation to evaluate the model's
performance on data it was not trained on. You can log metrics here to track progress.
def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    self.log("val_loss", loss)
Test Step
The test_step() method mirrors validation_step(), but it is typically run after training is
complete, to evaluate the final model on a held-out test set.
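Because test_step() runs once per batch, a test-set score is built from per-batch values. If you combine per-batch means yourself, weight each one by its batch size. This gives the true mean over all examples, whereas a plain average over-weights a short final batch. The sketch below assumes you have already collected the per-batch values. The helper name `weighted_mean` is hypothetical, and the sketch makes no claim about how REAX's logging reduces values across steps.

```python
import jax.numpy as jnp


def weighted_mean(per_batch_values, batch_sizes):
    """Combine per-batch means into a single mean over all examples."""
    values = jnp.asarray(per_batch_values, dtype=jnp.float32)
    sizes = jnp.asarray(batch_sizes, dtype=jnp.float32)
    # Each batch contributes (mean * size) examples' worth of the metric.
    return jnp.sum(values * sizes) / jnp.sum(sizes)


# Two full batches of 32 examples and a final partial batch of 4.
batch_acc = [0.9, 0.8, 0.5]
batch_sizes = [32, 32, 4]
# A plain average would give (0.9 + 0.8 + 0.5) / 3 ≈ 0.733.
print(float(weighted_mean(batch_acc, batch_sizes)))  # ≈ 0.829
```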
Optimization
The configure_optimizers() method defines the optimizers and learning rate
schedulers. REAX uses Optax for optimization.
def configure_optimizers(self):
    # Adam with decoupled weight decay (AdamW)
    optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-5)
    return optimizer
Organizing Code
By using Module, your code becomes more organised and readable. The model definition and
each stage of training live in clearly named methods, while the Trainer runs the loops.
This makes experiments easier to share and reproduce.