The Engine

The Engine is the low-level execution component of REAX. While the Trainer provides a strict, opinionated structure, the Engine offers flexible primitives for building custom training loops while still benefiting from REAX’s distributed capabilities.

Trainer vs. Engine

  • Trainer: High-level, “batteries included”. Manages loops, logging, checkpointing, and listeners automatically. Best for standard training workflows.

  • Engine: Low-level, explicit. You write the loops. It handles device placement, distributed communication, and optimizer wrapping. Best for research and non-standard loops.

Initialising the Engine

import reax
engine = reax.Engine(accelerator="gpu", devices=4, strategy="ddp")

Using the Engine

The Engine provides helper methods to set up your environment.

Setup

setup() prepares your model and optimizers for distributed training.

model, optimizer = engine.setup(model, optimizer)

Setup DataLoaders

setup_dataloaders() prepares your dataloaders, handling sharding and device placement.

train_loader = engine.setup_dataloaders(train_loader)
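
To make "sharding" concrete, here is a minimal sketch in plain Python of one common scheme, in which each process takes every world_size-th sample starting at its own rank. The helper shard_indices and its parameters are hypothetical and are not part of REAX; the strategy that setup_dataloaders() actually uses may differ.

```python
def shard_indices(num_samples: int, rank: int, world_size: int) -> list[int]:
    """Return the sample indices assigned to one process (strided sharding)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return list(range(rank, num_samples, world_size))


# Hypothetical setup: 10 samples split across 4 processes.
shards = [shard_indices(10, rank, 4) for rank in range(4)]
for rank, indices in enumerate(shards):
    print(f"rank {rank}: {indices}")
```

With 10 samples and 4 processes the shards are uneven: ranks 0 and 1 get three samples, ranks 2 and 3 get two. Distributed samplers usually pad or drop samples so that every process runs the same number of steps.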

Distributed Primitives

The Engine exposes methods for cross-device communication:

  • all_reduce(): Average or sum a tensor across all devices.

  • broadcast(): Send a tensor from one device to all others.

  • barrier(): Synchronize all processes.
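
The semantics of the first two primitives can be illustrated with a single-process simulation, where a list holds one value per device. This is only a sketch: the functions below are stand-ins written for illustration, and the parameter names op and src are assumptions, not the Engine's actual signatures.

```python
def all_reduce(values: list[float], op: str = "mean") -> list[float]:
    """Simulate all_reduce: every device ends up with the same reduced value."""
    total = sum(values)
    if op == "sum":
        reduced = total
    elif op == "mean":
        reduced = total / len(values)
    else:
        raise ValueError(f"unsupported op: {op}")
    return [reduced] * len(values)


def broadcast(values: list[float], src: int = 0) -> list[float]:
    """Simulate broadcast: every device receives the value held by `src`."""
    return [values[src]] * len(values)


# One entry per simulated device.
per_device_loss = [1.0, 2.0, 3.0, 6.0]
mean_loss = all_reduce(per_device_loss, op="mean")
sum_loss = all_reduce(per_device_loss, op="sum")
from_rank_2 = broadcast(per_device_loss, src=2)
```

After the calls, every simulated device holds the same value: the mean or sum for all_reduce, and the source device's value for broadcast. barrier() moves no data, so it is not simulated here; it only blocks each process until all of them have reached the call.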

Example Custom Loop

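import reax

# model, optimizer, dataloader, update_step and log are user-defined placeholders.
engine = reax.Engine()
model, optimizer = engine.setup(model, optimizer)
dataloader = engine.setup_dataloaders(dataloader)

for batch in dataloader:
    # Your custom update logic
    loss = update_step(model, batch)

    # Log from the global rank-zero process only, to avoid duplicate output
    if engine.is_global_zero:
        log(loss)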
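
The loop leaves update_step undefined. As an illustration of what such a function might contain, here is a minimal sketch in plain Python: one step of gradient descent on mean squared error for a one-variable linear model. The dict-based model, the lr parameter, and the in-place update are all assumptions made for the example; a real REAX update step would operate on your own model and optimizer state.

```python
def update_step(model: dict, batch: list[tuple[float, float]], lr: float = 0.1) -> float:
    """Run one SGD step on MSE for y = w * x + b and return the pre-update loss."""
    n = len(batch)
    errors = [(model["w"] * x + model["b"]) - y for x, y in batch]
    loss = sum(e * e for e in errors) / n
    # Gradients of the mean squared error with respect to w and b.
    grad_w = 2.0 * sum(e * x for e, (x, _) in zip(errors, batch)) / n
    grad_b = 2.0 * sum(errors) / n
    model["w"] -= lr * grad_w
    model["b"] -= lr * grad_b
    return loss


model = {"w": 0.0, "b": 0.0}
batch = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]  # samples of y = 2x + 1
losses = [update_step(model, batch) for _ in range(200)]
```

Repeated calls drive the loss down and move w and b towards 2 and 1, the values that generated the batch.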