reax.Engine
- class reax.Engine(accelerator: str = 'auto', strategy: str | reax.Strategy = 'auto', devices: list[int] | str | int = 'auto', precision=None, logger: reax.Logger | Iterable[reax.Logger] | bool | None = True, listeners: list[reax.TrainerListener] | reax.TrainerListener | None = None, deterministic: bool = False, rngs: Rngs = None, default_root_dir: reax.types.Path | None = None, profiler: reax.Profiler | str | None = None)
The central execution and orchestration component for distributed training.
The Engine serves as the high-level interface that binds a specific Strategy (e.g., local, DDP, sharded) to the execution flow, providing standardized methods necessary to build training loops or a full Trainer.
It encapsulates the environment setup defined by the chosen strategy, managing details like accelerator configuration, process launching (e.g., via MPI or Python’s multiprocessing), and backend initialization.
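A minimal construction sketch, assuming `reax` is installed and importable under that name. Only keyword arguments that appear in the class signature are used. The helper names `engine_kwargs` and `make_engine` are illustrative and not part of reax; the import is deferred into the function so the snippet can be loaded without reax present.

```python
def engine_kwargs(root: str = "runs") -> dict:
    """Collect keyword arguments taken from the documented Engine signature."""
    return {
        "accelerator": "cpu",      # documented values include 'gpu', 'tpu', 'cpu', 'auto'
        "devices": 1,              # an integer count, a list of indices, or 'auto'
        "logger": False,           # False disables logging; True selects the default CSVLogger
        "default_root_dir": root,  # fallback directory for logs and checkpoints
    }


def make_engine(**overrides):
    """Build an Engine from the defaults above (hypothetical helper).

    reax is imported here, at call time, so defining this function
    does not require reax to be installed.
    """
    import reax

    return reax.Engine(**{**engine_kwargs(), **overrides})
```

With reax available, `engine = make_engine()` followed by `engine.finalize()` once the work is done would exercise the documented lifecycle; how the `'auto'` defaults resolve depends on the installed strategy and hardware.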
- strategy
The execution strategy defining the environment and distributed communication.
- Type:
object
Initializes the Engine with execution parameters and components.
The Engine sets up the chosen execution strategy and initializes core components like loggers, event listeners, and the random number generator state.
- Parameters:
accelerator (str) – The type of hardware accelerator to use, e.g., 'gpu', 'tpu', or 'cpu'. Defaults to 'auto'.
strategy (str | reax.Strategy) – The distributed training strategy to use, e.g., 'ddp', 'fsdp', or a reax.Strategy instance. Defaults to 'auto'.
devices (list[int] | str | int) – The specific devices to target. Can be a list of indices, 'auto', or an integer count.
precision – Currently not supported in JAX/nnx; a warning is issued if set.
logger (reax.Logger | Iterable[reax.Logger] | bool | None) – One or more loggers, or True (for the default CSVLogger) or False (to disable logging).
listeners (list[reax.TrainerListener] | reax.TrainerListener | None) – A single reax.TrainerListener object, or a list of them, for hooking into events.
deterministic (bool) – Currently not supported; a warning is issued if set.
rngs (nnx.Rngs) – The initial random number generator state used for all training steps. Defaults to a new nnx.Rngs(0).
default_root_dir (Path | None) – The default directory path for saving logs and checkpoints if not specified elsewhere. Defaults to the current working directory.
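The `logger` argument accepts several shapes (a bool, a single logger, or an iterable of loggers). The sketch below shows one way such an argument could be normalized into a list. It is not reax's implementation: `CSVLoggerStandIn`, `normalize_loggers`, and `main_logger` are hypothetical names, and treating `None` like `False` is an assumption.

```python
from collections.abc import Iterable


class CSVLoggerStandIn:
    """Placeholder for the default CSV logger; not the real reax class."""


def normalize_loggers(logger) -> list:
    """Map the documented forms of the `logger` argument onto a list."""
    if logger is True:
        return [CSVLoggerStandIn()]  # True: a default CSV logger
    if logger is False or logger is None:
        return []                    # False disables logging; None handled the same (assumption)
    if isinstance(logger, Iterable) and not isinstance(logger, (str, bytes)):
        return list(logger)          # several loggers, order preserved
    return [logger]                  # a single logger instance


def main_logger(loggers: list):
    """Return the first logger, or None if the list is empty (assumption)."""
    return loggers[0] if loggers else None
```

Under this reading, the `loggers` attribute would correspond to the full list and the `logger` attribute to its first element.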
Methods
__init__([accelerator, strategy, devices, ...]) – Initializes the Engine with execution parameters and components.
all_reduce(obj[, reduce_op]) – Reduces a tensor from several distributed processes to one aggregated tensor (AllReduce).
barrier([name]) – Wait for all processes to enter this call.
broadcast(obj[, src]) – Send a tensor from one process to all others.
call(name, *args, **kwargs) – Triggers a named event hook for all registered listeners.
compute(metric) – Computes the final value of a metric across all distributed processes.
default_device() – Context manager that explicitly sets the strategy's device as the default for the duration of the context, ensuring subsequent operations target this device.
finalize() – Clean up the trainer.
profile(profile_name, **kwargs)
setup(*args) – Prepares and transforms objects (Modules, Optimizers, DataLoaders) for distributed use according to the current strategy.
setup_dataloaders(*args) – Prepares and wraps dataloaders for the distributed environment.
to_device(data) – Moves the provided PyTree data structure to the Engine's assigned device.
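To make the collective semantics of `all_reduce` and `broadcast` concrete, the following pure-Python simulation models each process as one entry in a list. It does not call reax or any distributed backend. The function names and the `reduce_op` strings (`"sum"`, `"mean"`, `"max"`, `"min"`) are assumptions for illustration; the values reax actually accepts are not listed here.

```python
import operator
from functools import reduce


def simulate_all_reduce(per_rank: list, reduce_op: str = "sum") -> list:
    """Every simulated rank ends up holding the same aggregated value."""
    if reduce_op == "mean":
        result = reduce(operator.add, per_rank) / len(per_rank)
    else:
        ops = {"sum": operator.add, "max": max, "min": min}
        result = reduce(ops[reduce_op], per_rank)
    # AllReduce delivers the aggregate to all ranks, not only to one.
    return [result for _ in per_rank]


def simulate_broadcast(per_rank: list, src: int = 0) -> list:
    """Every simulated rank ends up holding the value owned by rank `src`."""
    return [per_rank[src] for _ in per_rank]
```

For example, `simulate_all_reduce([1, 2, 3])` gives every rank the value `6`, while `simulate_broadcast(["a", "b", "c"], src=1)` gives every rank `"b"`.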
Attributes
default_root_dir – Get the fallback directory used for loggers and other components when not explicitly specified, resolving path variables (e.g., ~) if necessary.
device – The JAX device instance currently assigned to this process.
is_global_zero – Whether this process is the global rank zero process.
local_process_index – Rank of the process on the current host (0 → local device count − 1).
logger – Get the first (and main) logger configured for the Engine.
loggers – Get all the loggers configured for the Engine.
node_rank – The rank of the current node (host) among all nodes.
process_count – Total number of processes across all hosts participating in the distributed task.
process_index – Global rank of the current process across all hosts (0 → total process count − 1).
profiler – Get the profiler that can be used to annotate and time sections of code.
rngs – The current random number generator state used by the Engine.
strategy – The execution strategy currently employed by the Engine.
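The rank attributes relate to one another in a simple way when every host runs the same number of processes and global ranks are assigned host by host. The sketch below encodes that layout. Both conditions are assumptions, and `RankInfo` with its `processes_per_node` and `num_nodes` fields is a hypothetical stand-in, not a reax type.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RankInfo:
    node_rank: int            # rank of this host among all hosts
    local_process_index: int  # rank of this process on its host
    processes_per_node: int   # assumed identical on every host
    num_nodes: int            # total number of hosts

    @property
    def process_count(self) -> int:
        """Total processes across all hosts."""
        return self.num_nodes * self.processes_per_node

    @property
    def process_index(self) -> int:
        """Global rank, assuming contiguous numbering host by host."""
        return self.node_rank * self.processes_per_node + self.local_process_index

    @property
    def is_global_zero(self) -> bool:
        """True only for the process with global rank 0."""
        return self.process_index == 0
```

A common use of a flag like `is_global_zero` is to restrict side effects such as writing checkpoints or printing progress to a single process.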