The Trainer

The Trainer automates the training loop. It handles the routine parts of training: iterating over epochs, running validation checks, creating checkpoints, and logging.
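To make the automation concrete, here is a minimal sketch of the kind of hand-written loop a Trainer replaces. It is plain Python on a toy one-parameter model, not reax code: the names manual_fit, sgd_step, and mse are illustrative assumptions, and the "checkpoint" is only an in-memory record of the best weight.

```python
import random


def sgd_step(w, batch, lr=0.5):
    """One gradient-descent step on mean squared error for y = w * x."""
    grad = sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)
    return w - lr * grad


def mse(w, data):
    """Mean squared error of the model y = w * x on a dataset."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def manual_fit(train_data, val_data, max_epochs=10, batch_size=4, seed=0):
    rng = random.Random(seed)
    w = 0.0
    history = []  # stands in for a logger
    best = {"epoch": None, "val_loss": float("inf"), "w": None}  # stands in for a checkpoint
    for epoch in range(max_epochs):
        # Training pass: shuffle a copy, then step over mini-batches.
        shuffled = list(train_data)
        rng.shuffle(shuffled)
        for start in range(0, len(shuffled), batch_size):
            w = sgd_step(w, shuffled[start:start + batch_size])
        # Validation check at the end of every epoch.
        val_loss = mse(w, val_data)
        # "Checkpoint" the best weight seen so far.
        if val_loss < best["val_loss"]:
            best = {"epoch": epoch, "val_loss": val_loss, "w": w}
        history.append({"epoch": epoch, "val_loss": val_loss})
    return w, history, best


# Toy data generated from y = 3 * x.
train_data = [(i / 10, 3.0 * i / 10) for i in range(1, 9)]
val_data = [(0.9, 2.7), (1.0, 3.0)]
w, history, best = manual_fit(train_data, val_data)
print(f"learned w = {w:.3f}, best epoch = {best['epoch']}")
```

Every piece of bookkeeping in manual_fit (the epoch loop, the validation pass, the best-weight record, the history list) is something the Trainer is described as handling for you.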

Basic Usage

To use the Trainer, initialise it and call fit().

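import reax
from flax import nnx
# MyModel, train_dataloader and val_dataloader are defined by the user.
model = MyModel(din=32, dout=10, rngs=nnx.Rngs(0))
trainer = reax.Trainer(max_epochs=10, accelerator='auto')
trainer.fit(model, train_dataloader, val_dataloader)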

Under the Hood

The Trainer uses an Engine to execute the training. The Engine abstracts away the hardware and distributed strategy details.
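The hardware details that an Engine hides can be inspected directly in JAX. The sketch below uses only standard JAX calls. How reax maps accelerator='auto' onto them is not described here, so treat the describe_hardware helper as a hypothetical illustration rather than part of the reax API.

```python
import jax


def describe_hardware():
    """Hypothetical helper: summarise the devices JAX can see on this machine."""
    devices = jax.devices()
    return {
        "backend": jax.default_backend(),  # platform name of the default backend
        "device_count": len(devices),
        "device_kinds": sorted({d.device_kind for d in devices}),
    }


info = describe_hardware()
print(info)
```

On a machine without a GPU or TPU this reports a single CPU backend. An abstraction layer like the Engine would let the same training code run unchanged when more devices are present.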

Key Arguments

  • max_epochs: The maximum number of epochs to train for.

  • accelerator: The hardware accelerator to use (e.g., 'cpu', 'gpu', 'tpu', or 'auto').

  • devices: The number of devices or specific device indices to use.

  • logger: The logger to use (e.g., CsvLogger).

  • listeners: A list of listeners to extend the Trainer’s behaviour.
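The listeners argument follows the familiar observer pattern: the loop calls hooks on each registered object at fixed points. The sketch below illustrates that pattern in plain Python. The Listener base class, the on_epoch_end hook name, and run_epochs are assumptions made for illustration and are not taken from the reax API.

```python
class Listener:
    """Hypothetical listener interface; the hook name is illustrative only."""

    def on_epoch_end(self, epoch, metrics):
        pass


class HistoryRecorder(Listener):
    """Keeps a copy of the metrics reported at the end of every epoch."""

    def __init__(self):
        self.history = []

    def on_epoch_end(self, epoch, metrics):
        self.history.append((epoch, dict(metrics)))


class BestLossTracker(Listener):
    """Remembers the epoch with the lowest validation loss."""

    def __init__(self):
        self.best_epoch = None
        self.best_loss = float("inf")

    def on_epoch_end(self, epoch, metrics):
        if metrics["val_loss"] < self.best_loss:
            self.best_epoch, self.best_loss = epoch, metrics["val_loss"]


def run_epochs(val_losses, listeners):
    """Stand-in for a training loop: one precomputed metric per epoch."""
    for epoch, loss in enumerate(val_losses):
        metrics = {"val_loss": loss}
        for listener in listeners:
            listener.on_epoch_end(epoch, metrics)


recorder, tracker = HistoryRecorder(), BestLossTracker()
run_epochs([1.0, 0.6, 0.7, 0.4], [recorder, tracker])
print(tracker.best_epoch, tracker.best_loss)  # prints: 3 0.4
```

Because each listener only reacts to hooks, behaviour such as metric tracking can be added or removed without touching the loop itself.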

Methods

Fit

fit() runs the full training routine, including validation loops.

trainer.fit(model, train_loader, val_loader)

Test

test() runs the test loop on the given dataloader.

trainer.test(model, test_loader)

Predict

predict() runs inference on the given dataloader.

predictions = trainer.predict(model, predict_loader)

Automatic Optimization

By default, the Trainer handles backward passes and optimizer steps automatically. This reduces training_step to computing and returning the loss.
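As a rough picture of what "automatic" means, the sketch below writes out in plain JAX the work that would otherwise sit around a loss-only training_step: differentiating the loss and applying an update. This is an illustration under stated assumptions, not reax's implementation. The function signatures, the automatic_step helper, and the plain gradient-descent update are all hypothetical.

```python
import jax
import jax.numpy as jnp


def training_step(params, batch):
    """User-side code: compute and return the loss, nothing else."""
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def automatic_step(params, batch, lr=0.01):
    """Hypothetical trainer-side code: backward pass plus a plain SGD update."""
    loss, grads = jax.value_and_grad(training_step)(params, batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# Toy regression data generated from known weights.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = x @ jnp.array([0.5, -0.25]) + 1.0
batch = (x, y)

params = {"w": jnp.zeros((2,)), "b": jnp.array(0.0)}
loss_before = training_step(params, batch)
for _ in range(100):
    params, _ = automatic_step(params, batch)
loss_after = training_step(params, batch)
print(float(loss_before), float(loss_after))
```

The user-written part is only training_step. Everything in automatic_step is the kind of work the Trainer is described as taking over by default.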