reax.Engine

class reax.Engine(accelerator: str = 'auto', strategy: str | reax.Strategy = 'auto', devices: list[int] | str | int = 'auto', precision=None, logger: reax.Logger | Iterable[reax.Logger] | bool | None = True, listeners: list[reax.TrainerListener] | reax.TrainerListener | None = None, deterministic: bool = False, rngs: Rngs = None, default_root_dir: reax.types.Path | None = None, profiler: reax.Profiler | str | None = None)

The central execution and orchestration component for distributed training.

The Engine serves as the high-level interface that binds a specific Strategy (e.g., local, DDP, sharded) to the execution flow, providing standardized methods necessary to build training loops or a full Trainer.

It encapsulates the environment setup defined by the chosen strategy, managing details like accelerator configuration, process launching (e.g., via MPI or Python’s multiprocessing), and backend initialization.
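The binding between an engine and its strategy can be pictured as plain delegation. The sketch below is a minimal illustration under stated assumptions, not the reax implementation: `SingleProcessStrategy` and `EngineSketch` are hypothetical names, and the `"mean"` default for `reduce_op` is an assumption made only for this example.

```python
from typing import Any, Optional


class SingleProcessStrategy:
    """Hypothetical local strategy: one process, so collectives are trivial."""

    process_index = 0
    process_count = 1

    def all_reduce(self, obj: Any, reduce_op: str = "mean") -> Any:
        # With a single participant, any reduction returns the input unchanged.
        return obj

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        # The only process is also the source, so nothing needs to be sent.
        return obj

    def barrier(self, name: Optional[str] = None) -> None:
        # Nothing to wait for when there are no other processes.
        return None


class EngineSketch:
    """Illustrative stand-in for the Engine: it forwards calls to its strategy."""

    def __init__(self, strategy: Any) -> None:
        self.strategy = strategy

    @property
    def is_global_zero(self) -> bool:
        return self.strategy.process_index == 0

    def all_reduce(self, obj: Any, reduce_op: str = "mean") -> Any:
        return self.strategy.all_reduce(obj, reduce_op)

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        return self.strategy.broadcast(obj, src)

    def barrier(self, name: Optional[str] = None) -> None:
        self.strategy.barrier(name)


engine = EngineSketch(SingleProcessStrategy())
reduced = engine.all_reduce(3.0)
```

Because a training loop in this sketch only talks to the engine, swapping the strategy object for a distributed one would leave the loop code unchanged.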

strategy

The execution strategy defining the environment and distributed communication.

Type:

object

Initializes the Engine with execution parameters and components.

The Engine sets up the chosen execution strategy and initializes core components like loggers, event listeners, and the random number generator state.

Parameters:
  • accelerator (str) – The type of hardware accelerator to use, e.g., 'gpu', 'tpu', or 'cpu'. Defaults to 'auto'.

  • strategy (str | reax.Strategy) – The distributed training strategy to use, e.g., 'ddp', 'fsdp', or a reax.Strategy instance. Defaults to 'auto'.

  • devices (list[int] | str | int) – The specific devices to target. Can be a list of indices, 'auto', or an integer count.

  • precision – Currently not supported in JAX/nnx; a warning is issued if set.

  • logger (reax.Logger | Iterable[reax.Logger] | bool | None) – One or more loggers, or True (for default CSVLogger) or False (to disable logging).

  • listeners (list[reax.TrainerListener] | reax.TrainerListener | None) – A single reax.TrainerListener or a list of them, used to hook into events.

  • deterministic (bool) – Currently not supported; a warning is issued if set.

  • rngs (nnx.Rngs) – The initial random number generator state used for all training steps. Defaults to a new nnx.Rngs(0).

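  • default_root_dir (Path | None) – The default directory path for saving logs and checkpoints if not specified elsewhere. Defaults to the current working directory.

  • profiler (reax.Profiler | str | None) – The profiler used to annotate and time sections of code, given as a reax.Profiler instance or a string. Defaults to None.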
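The `logger` and `listeners` parameters each accept several shapes of input. One way such arguments could be normalized into plain lists is sketched below. This is an illustration only: `normalize_loggers`, `normalize_listeners`, and `CsvLoggerStub` are hypothetical names, and treating `None` the same as `False` is an assumption, since the parameter description only covers `True` and `False`.

```python
from collections.abc import Iterable
from typing import Any, Callable, List


class CsvLoggerStub:
    """Stand-in for a default CSV logger; not the real reax class."""


def normalize_loggers(
    logger: Any, default_factory: Callable[[], Any] = CsvLoggerStub
) -> List[Any]:
    """Turn the accepted `logger` forms into a list of logger objects."""
    if logger is True:
        return [default_factory()]  # True -> one default logger
    if logger is False or logger is None:
        return []  # logging disabled (None handling is an assumption)
    if isinstance(logger, Iterable) and not isinstance(logger, (str, bytes)):
        return list(logger)  # several loggers were passed
    return [logger]  # a single logger instance


def normalize_listeners(listeners: Any) -> List[Any]:
    """Turn a single listener, a list of listeners, or None into a list."""
    if listeners is None:
        return []
    if isinstance(listeners, (list, tuple)):
        return list(listeners)
    return [listeners]
```

Normalizing once at construction time means the rest of the code can iterate over a list without re-checking which form the caller used.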

__init__(accelerator: str = 'auto', strategy: str | reax.Strategy = 'auto', devices: list[int] | str | int = 'auto', precision=None, logger: reax.Logger | Iterable[reax.Logger] | bool | None = True, listeners: list[reax.TrainerListener] | reax.TrainerListener | None = None, deterministic: bool = False, rngs: Rngs = None, default_root_dir: reax.types.Path | None = None, profiler: reax.Profiler | str | None = None)

Methods

__init__([accelerator, strategy, devices, ...])

Initializes the Engine with execution parameters and components.

all_reduce(obj[, reduce_op])

Reduces a tensor from several distributed processes to one aggregated tensor (AllReduce).

barrier([name])

Wait for all processes to enter this call.

broadcast(obj[, src])

Send a tensor from one process to all others.

call(name, *args, **kwargs)

Triggers a named event hook for all registered listeners.

compute(metric)

Computes the final value of a metric across all distributed processes.

default_device()

Context manager that explicitly sets the strategy's device as the default for the duration of the context, ensuring subsequent operations target this device.

finalize()

Clean up the trainer.

profile(profile_name, **kwargs)

setup(*args)

Prepares and transforms objects (Modules, Optimizers, DataLoaders) for distributed use according to the current strategy.

setup_dataloaders(*args)

Prepares and wraps dataloaders for the distributed environment.

to_device(data)

Moves the provided PyTree data structure to the Engine's assigned device.
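The `call(name, *args, **kwargs)` entry above describes triggering a named hook on every registered listener. A minimal sketch of that dispatch pattern follows. It is not the reax implementation: `dispatch`, `RecordingListener`, `SilentListener`, and the hook name `on_fit_start` are hypothetical, and skipping listeners that lack the hook is an assumption.

```python
from typing import Any, Iterable, List, Tuple


class RecordingListener:
    """Hypothetical listener that records the events it receives."""

    def __init__(self) -> None:
        self.events: List[Tuple[str, Any]] = []

    def on_fit_start(self, stage: str) -> None:
        self.events.append(("on_fit_start", stage))


class SilentListener:
    """Implements no hooks, so every event is skipped for it."""


def dispatch(listeners: Iterable[Any], name: str, *args: Any, **kwargs: Any) -> None:
    """Invoke the hook called `name` on each listener that defines it."""
    for listener in listeners:
        hook = getattr(listener, name, None)
        if callable(hook):
            hook(*args, **kwargs)


recorder = RecordingListener()
dispatch([recorder, SilentListener()], "on_fit_start", "train")
dispatch([recorder], "on_missing_hook")  # no listener defines this hook
```

Looking hooks up by name keeps listeners loosely coupled: a listener only implements the events it cares about.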

Attributes

default_root_dir

Get the fallback directory used for loggers and other components when not explicitly specified, resolving path variables (e.g., ~) if necessary.

device

The JAX device instance currently assigned to this process.

is_global_zero

Whether this process is the global rank zero process.

local_process_index

Rank of the process on the current host (0 → local device count − 1).

logger

Get the first (and main) logger configured for the Engine.

loggers

Get all the loggers configured for the Engine.

node_rank

The rank of the current node (host) among all nodes.

process_count

Total number of processes across all hosts participating in the distributed task.

process_index

Global rank of the current process across all hosts (0 → total process count − 1).

profiler

Get the profiler that can be used to annotate and time sections of code.

rngs

The current random number generator state used by the Engine.

strategy

The execution strategy currently employed by the Engine.
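The rank attributes above (`process_index`, `local_process_index`, `node_rank`, `process_count`, `is_global_zero`) are related to one another. The sketch below shows one common layout, assuming every node runs the same number of processes and global ranks are assigned node by node. That layout is an assumption made for illustration; `RankInfo`, `processes_per_node`, and `num_nodes` are hypothetical names, not part of reax.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RankInfo:
    """Hypothetical rank bookkeeping for a homogeneous cluster."""

    node_rank: int            # index of this host among all hosts
    local_process_index: int  # index of this process on its host
    processes_per_node: int   # assumed identical on every host
    num_nodes: int

    @property
    def process_count(self) -> int:
        # Total number of processes across all hosts.
        return self.processes_per_node * self.num_nodes

    @property
    def process_index(self) -> int:
        # Global rank, counting processes node by node.
        return self.node_rank * self.processes_per_node + self.local_process_index

    @property
    def is_global_zero(self) -> bool:
        return self.process_index == 0


info = RankInfo(node_rank=1, local_process_index=2, processes_per_node=4, num_nodes=2)
```

Under this layout, exactly one process in the whole job has `is_global_zero` set, which is why that flag is often used to guard work that should happen once, such as writing a checkpoint.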