Distributed Training#

REAX makes it easy to scale your training to multiple GPUs or TPUs.

Strategies#

REAX supports several distributed strategies:

  • ‘ddp’ (Data Distributed Parallel): Replicates the model on each device and synchronises gradients.

  • ‘fsdp’ (Fully Sharded Data Parallel): Shards the model parameters across devices to save memory.

  • ‘auto’: Automatically selects the best strategy based on the available hardware.

Configuration#

To enable distributed training, simply set the devices and strategy arguments in the Trainer:

# Train on 4 GPUs using DDP
trainer = reax.Trainer(accelerator="gpu", devices=4, strategy="ddp")

Launch Methods#

You can launch your script using standard tools like mpirun or SLURM. REAX will automatically detect the environment and initialise the distributed backend.