Source code for reax.modules

from collections.abc import Callable, Generator, Sequence
import contextlib
from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar

import beartype
from flax import nnx
import jax
import jaxtyping as jt
from lightning_utilities.core import rank_zero
import optax
from typing_extensions import deprecated

from . import _module_hooks
from .data import _datasources

if TYPE_CHECKING:
    import reax

__all__ = ("Module",)

MetricType = "reax.Metric" | jax.typing.ArrayLike
OutputT_co = TypeVar("OutputT_co", covariant=True)
BatchT = TypeVar("BatchT")
OptimizerData = tuple[optax.GradientTransformation, Any]


class LossAndGradDict(TypedDict, total=False):
    loss: jax.Array
    grad: jt.PyTree


LossAndGrad = tuple[jax.Array, jax.Array]
TrainOutput = LossAndGrad | LossAndGradDict


[docs] class Module( Generic[BatchT, OutputT_co], _module_hooks.ModuleHooks, _datasources.DataSource[BatchT] ): example_input_array: BatchT | None
[docs] def __init__(self): """Init function.""" super().__init__() self._trainer: "reax.Trainer | None" = None self._parameters = None self._automatic_optimization = True
@property def automatic_optimization(self) -> bool: """Automatic optimization.""" return self._automatic_optimization @automatic_optimization.setter def automatic_optimization(self, automatic_optimization: bool) -> None: """Automatic optimization.""" self._automatic_optimization = automatic_optimization @property def trainer(self) -> "reax.Trainer": """Trainer function.""" return self._trainer @trainer.setter def trainer(self, trainer): """Trainer function.""" if self._trainer is not None and trainer is not None: raise RuntimeError("Cannot set trainer, it is already set.") self._trainer = trainer @contextlib.contextmanager def attach(self, trainer: "reax.Trainer") -> Generator[None, Any, None]: self._trainer = trainer yield self._trainer = None @property def global_updates(self) -> int: """Get the global number of optimizer updates.""" return self._trainer.global_updates @property def current_epoch(self) -> int: """Get the current fitting epoch.""" return self._trainer.current_epoch def parameters(self) -> jt.PyTree | None: """Parameters function.""" return self._parameters def set_parameters(self, params: jt.PyTree): """Set parameters.""" self._parameters = params @property def rngs(self) -> nnx.Rngs: """Random number generators.""" return self._trainer.rngs def optimizers(self) -> "reax.Optimizer | list[reax.Optimizer]": """Optimizers function.""" optimizers = self.trainer.optimizers # Check for a single optimiser if ( isinstance(optimizers, list) and len(optimizers) == 1 and isinstance(optimizers[0], optax.GradientTransformation) ): return optimizers[0] # Multiple optimisers return optimizers def configure_model(self, stage: "reax.Stage", batch: Any, /) -> None: """Called at the beginning of each stage. A chance to configure the model. This method should be idempotent, i.e. calling it a second should do nothing. """ def training_step(self, batch: BatchT, batch_idx: int, /) -> TrainOutput | None: """Train step.""" def validation_step(self, batch: BatchT, batch_idx: int, /) -> jt.PyTree | None: """Validate step.""" def predict_step(self, batch: BatchT, batch_idx: int, /) -> jt.PyTree | None: """Make a model prediction and return the result.""" def test_step(self, batch: BatchT, batch_idx: int, /) -> jt.PyTree | None: """Test step.""" def configure_listeners(self) -> "Sequence[reax.TrainerListener] | reax.TrainerListener": """Configure model-specific listeners. When the model gets attached, e.g., when ``.fit()`` or ``.test()`` gets called, the list or a listener returned here will be merged with the list of listeners passed to the Trainer's ``listeners`` argument. If a listener returned here has the same type as one or several listeners already present in the Trainer's listeners list, it will take priority and replace them. In addition, REAX will make sure :class:`~reax.listeners.model_checkpoint.ModelCheckpoint` listeners run last. Return: A listener or a list of listeners which will extend the list of listeners in the Trainer. Example:: def configure_listeners(self): early_stop = EarlyStopping(monitor="val_acc", mode="max") checkpoint = ModelCheckpoint(monitor="val_loss") return [early_stop, checkpoint] """ return [] @jt.jaxtyped(typechecker=beartype.beartype) def configure_optimizers( self, ) -> OptimizerData | Sequence[OptimizerData] | None: """Create the optimizer(s) to use during training.""" return None def log( self, name: str, value: MetricType, *, prog_bar: bool = False, batch_size: int | None = None, logger: bool | None = None, on_step=True, on_epoch=True, ) -> None: """Log a key, value pair. Example:: self.log('train_loss', loss) """ trainer = self._trainer if trainer is None: # not an error to support testing the `*_step` methods without a `Trainer` reference rank_zero.rank_zero_warn( "`self.log()` was called before `self.trainer` was set. " "Probably, the model was not passed to `Trainer`" ) return if logger and trainer.logger is None: rank_zero.rank_zero_warn( f"You called `self.log({name!r}, ..., logger=True)` but have no logger " f"configured. You can enable one by using `Trainer(logger=ALogger(...))`" ) if logger is None: # we could set false here if there's no configured logger, however, we still need to # compute the "logged" metrics anyway because that's what the evaluation loops use as # return value logger = True trainer.log( name, value, prog_bar=prog_bar, batch_size=batch_size, logger=logger, on_step=on_step, on_epoch=on_epoch, ) def log_dict( self, dictionary: "Mapping[str, MetricType] | reax.metrics.MetricCollection", prog_bar: bool = False, logger: bool | None = None, on_step: bool | None = None, on_epoch: bool | None = None, batch_size: int | None = None, ) -> None: """Log a dictionary of values at once. Args: dictionary ("Mapping[str, MetricType] | reax.metrics.MetricCollection"): Key value pairs. Keys must be identical across all processes if using DDP or any other distributed strategy. The values can be a ``float``, ``Array``, ``Metric``, or ``MetricCollection``. prog_bar (bool, optional): If ``True`` logs to the progress base, defaults to False. logger (Optional[bool], optional): If ``True`` logs to the logger, defaults to None. on_step (Optional[bool], optional): If ``True`` logs at this step. ``None`` auto-logs for training_step but not validation/test_step. The default value is determined by the hook. See :ref:`extensions/logging:Automatic Logging` for details, defaults to None. on_epoch (Optional[bool], optional): If ``True`` logs epoch accumulated metrics. ``None`` auto-logs for val/test step but not ``training_step``. The default value is determined by the hook. See :ref:`extensions/logging:Automatic Logging` for details, defaults to None. batch_size (Optional[int], optional): Current batch size. This will be directly inferred from the loaded batch, but some data structures might need to explicitly provide it, defaults to None. Example:: values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n} self.log_dict(values) """ for key, val in dictionary.items(): self.log( name=key, value=val, prog_bar=prog_bar, logger=logger, on_step=on_step, on_epoch=on_epoch, batch_size=batch_size, ) def state_dict(self) -> dict[str, Any]: """Save any additional module state""" return {} def load_state(self, state_dict: dict[str, Any]) -> None: """Load module state from the passed state dictionary""" @property @deprecated("REAX uses the term 'update' instead of 'step', please use `.global_updates`") def global_step(self) -> int: """Get the global number of optimizer updates.""" return self.global_updates
PyTree = Any InputT_co = TypeVar("InputT_co", covariant=True) ModelT = Callable[[PyTree, InputT_co], OutputT_co] LabelT_co = TypeVar("LabelT_co", covariant=True) LossFn = Callable[[OutputT_co, LabelT_co], jax.Array]