The Trainer
The Trainer automates the training loop. It takes care of the repetitive parts of training: iterating over epochs, running validation checks, saving checkpoints, and logging.
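The plain-Python sketch below writes out by hand the kind of loop the Trainer automates. It is not reax's implementation. `run_training`, `train_step`, `val_step`, and the in-memory `log` and `checkpoints` lists are hypothetical stand-ins, and the "model" is a single float parameter.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical step signatures: (params, batch) -> (new_params, loss) and (params, batch) -> loss.
TrainStep = Callable[[float, float], Tuple[float, float]]
ValStep = Callable[[float, float], float]


def run_training(
    params: float,
    train_step: TrainStep,
    val_step: ValStep,
    train_data: Sequence[float],
    val_data: Sequence[float],
    max_epochs: int,
) -> Tuple[float, List[Dict[str, float]], List[float]]:
    log: List[Dict[str, float]] = []
    checkpoints: List[float] = []
    best_val = float("inf")
    for epoch in range(max_epochs):
        # Training: one parameter update per batch.
        train_losses = []
        for batch in train_data:
            params, loss = train_step(params, batch)
            train_losses.append(loss)
        # Validation check at the end of every epoch.
        val_loss = sum(val_step(params, b) for b in val_data) / len(val_data)
        # Checkpointing: keep a copy whenever the validation loss improves.
        if val_loss < best_val:
            best_val = val_loss
            checkpoints.append(params)
        # Logging: record per-epoch metrics.
        log.append({
            "epoch": epoch,
            "train_loss": sum(train_losses) / len(train_losses),
            "val_loss": val_loss,
        })
    return params, log, checkpoints


# Toy usage: fit a scalar w to the target 1.0 with squared error and SGD (lr = 0.1).
def train_step(w: float, x: float) -> Tuple[float, float]:
    loss = (w - x) ** 2
    grad = 2 * (w - x)
    return w - 0.1 * grad, loss


def val_step(w: float, x: float) -> float:
    return (w - x) ** 2


w, log, ckpts = run_training(0.0, train_step, val_step, [1.0, 1.0], [1.0], max_epochs=3)
```

Every step shrinks the distance to the target by a factor of 0.8. The validation loss therefore improves each epoch, and a checkpoint is saved at the end of all three.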
Basic Usage
To use the Trainer, initialise it and call fit() with your model and dataloaders.
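import reax
from flax import nnx
# MyModel, train_dataloader and val_dataloader are user-defined elsewhere.
model = MyModel(din=32, dout=10, rngs=nnx.Rngs(0))
trainer = reax.Trainer(max_epochs=10, accelerator='auto')
trainer.fit(model, train_dataloader, val_dataloader)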
Under the Hood
The Trainer delegates the actual execution of training to an Engine. The Engine abstracts away the details of the hardware and the distributed strategy.
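The hypothetical sketch below shows one way to picture this separation. The training code calls `engine.run(step, batch)` and never needs to know how the batch is executed. `Engine`, `SingleDeviceEngine`, `ShardedEngine` and `mean_squared` are invented for illustration and are not reax classes. The sharded engine only simulates data parallelism by splitting the batch in a Python loop.

```python
from abc import ABC, abstractmethod
from typing import Callable, List, Sequence

Step = Callable[[Sequence[float]], float]


class Engine(ABC):
    """Hypothetical interface: the trainer hands over a step, the engine decides how to run it."""

    @abstractmethod
    def run(self, step: Step, batch: Sequence[float]) -> float:
        ...


class SingleDeviceEngine(Engine):
    def run(self, step: Step, batch: Sequence[float]) -> float:
        # The whole batch is processed in a single call.
        return step(batch)


class ShardedEngine(Engine):
    """Simulated data parallelism: split the batch into equal shards, run each, average."""

    def __init__(self, num_devices: int) -> None:
        self.num_devices = num_devices

    def run(self, step: Step, batch: Sequence[float]) -> float:
        if len(batch) % self.num_devices != 0:
            raise ValueError("batch size must divide evenly across devices")
        size = len(batch) // self.num_devices
        shards = [batch[i * size:(i + 1) * size] for i in range(self.num_devices)]
        results: List[float] = [step(shard) for shard in shards]
        # Shards are equal-sized, so the mean of shard means equals the full-batch mean.
        return sum(results) / len(results)


def mean_squared(batch: Sequence[float]) -> float:
    return sum(x * x for x in batch) / len(batch)


batch = [1.0, 2.0, 3.0, 4.0]
for engine in (SingleDeviceEngine(), ShardedEngine(num_devices=2)):
    # Same step and batch, different execution strategy, same result (7.5).
    print(type(engine).__name__, engine.run(mean_squared, batch))
```

The design point is that swapping engines changes only how the work is executed, not the code that defines the step.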
Key Arguments
max_epochs: The maximum number of epochs to train for.
accelerator: The hardware accelerator to use (e.g., 'cpu', 'gpu', 'tpu', or 'auto').
devices: The number of devices, or the specific device indices, to use.
logger: The logger to use (e.g., CsvLogger).
listeners: A list of listeners that extend the Trainer’s behaviour.
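The hypothetical resolver below illustrates how an `'auto'` setting might be mapped onto one of the available platforms. The helper name and the TPU > GPU > CPU preference order are assumptions, not documented reax behaviour. In a JAX program, the list of available platforms could be gathered from `jax.devices()`, whose device objects expose a `platform` attribute.

```python
from typing import Sequence

# Assumed preference order when accelerator='auto' (illustrative only).
PREFERENCE = ("tpu", "gpu", "cpu")


def resolve_accelerator(requested: str, available: Sequence[str]) -> str:
    """Hypothetical helper: return the platform to train on."""
    if requested == "auto":
        # Pick the first preferred platform that is actually present.
        for platform in PREFERENCE:
            if platform in available:
                return platform
        raise RuntimeError("no supported accelerator is available")
    if requested not in PREFERENCE:
        raise ValueError(f"unknown accelerator: {requested!r}")
    if requested not in available:
        raise RuntimeError(f"accelerator {requested!r} is not available")
    return requested


print(resolve_accelerator("auto", ["cpu", "gpu"]))  # gpu
print(resolve_accelerator("auto", ["cpu"]))         # cpu
print(resolve_accelerator("cpu", ["cpu", "gpu"]))   # cpu
```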
Methods
Fit
fit() runs the full training routine, including validation loops.
trainer.fit(model, train_loader, val_loader)
Test
test() runs the test loop on the given dataloader.
trainer.test(model, test_loader)
Predict
predict() runs inference on the given dataloader.
predictions = trainer.predict(model, predict_loader)
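One way to think about the difference between test() and predict() is in terms of what each loop produces. In the sketch below, a test run consumes labelled batches and reduces them to a single metric. A prediction run consumes unlabelled inputs and returns the raw model outputs. `run_test`, `run_predict`, `double`, and the choice of mean squared error as the metric are illustrative assumptions, not reax's API.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

Model = Callable[[float], float]


def run_test(model: Model, loader: Iterable[Sequence[Tuple[float, float]]]) -> float:
    """Evaluate on labelled (x, y) batches and return one aggregate metric (MSE)."""
    total, count = 0.0, 0
    for batch in loader:
        for x, y in batch:
            total += (model(x) - y) ** 2
            count += 1
    return total / count


def run_predict(model: Model, loader: Iterable[Sequence[float]]) -> List[List[float]]:
    """Run inference on unlabelled batches and collect the outputs batch by batch."""
    return [[model(x) for x in batch] for batch in loader]


def double(x: float) -> float:
    return 2.0 * x


mse = run_test(double, [[(1.0, 2.0), (2.0, 5.0)]])  # squared errors 0 and 1, so 0.5
preds = run_predict(double, [[1.0, 2.0], [3.0]])    # [[2.0, 4.0], [6.0]]
```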
Automatic Optimization
By default, the Trainer performs the backward pass (gradient computation) and the optimizer step automatically, so training_step only needs to compute and return the loss.
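The division of labour can be sketched as follows, under stated assumptions. The user-written `training_step` only computes and returns the loss. A hypothetical trainer-side `automatic_update` function turns that loss into a parameter update. A JAX-based trainer would normally obtain the gradient through automatic differentiation such as `jax.grad`. Here, a central finite difference stands in for it so the sketch runs without dependencies, and plain SGD stands in for the optimizer.

```python
from typing import Callable, List

LossFn = Callable[[float, List[float]], float]


def training_step(w: float, batch: List[float]) -> float:
    """User code: compute and return the loss, nothing else."""
    return sum((w - x) ** 2 for x in batch) / len(batch)


def automatic_update(step: LossFn, w: float, batch: List[float],
                     lr: float = 0.1, eps: float = 1e-6) -> float:
    """Hypothetical trainer-side logic: differentiate the returned loss, then take an SGD step."""
    # Central finite difference as a stand-in for automatic differentiation.
    grad = (step(w + eps, batch) - step(w - eps, batch)) / (2 * eps)
    return w - lr * grad


w = 0.0
batch = [1.0, 3.0]
for _ in range(100):
    w = automatic_update(training_step, w, batch)
print(round(w, 4))  # 2.0, the batch mean, which minimises the loss
```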