Checkpointing#
REAX provides flexible checkpointing to save and restore your training progress. This guide covers checkpoint strategies, resuming training, and best practices.
Automatic Checkpointing#
The Trainer saves checkpoints automatically when checkpointing is enabled, which is the default:
import reax

trainer = reax.Trainer(
    max_epochs=10,
    enable_checkpointing=True,  # Enabled by default
)
trainer.fit(model, train_loader, val_loader)
By default, REAX saves checkpoints in ./reax_logs/version_X/checkpoints/.
Custom Checkpoint Directory#
Specify where to save checkpoints:
trainer = reax.Trainer(
    default_root_dir="./my_experiments",
    enable_checkpointing=True,
)
Checkpoints will be saved in ./my_experiments/version_X/checkpoints/.
ModelCheckpoint Listener#
For fine-grained control, use the ModelCheckpoint listener:
import reax
from reax.listeners import ModelCheckpoint

# Save the best model based on validation loss
checkpoint_listener = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=3,  # Keep the 3 best checkpoints
    filename="best-{epoch:02d}-{val_loss:.2f}",
)

trainer = reax.Trainer(
    listeners=[checkpoint_listener],
    max_epochs=10,
)
Monitor Options#
monitor: Metric to track (e.g., "val_loss", "val_acc")
mode: "min" for loss, "max" for accuracy
save_top_k: Number of best checkpoints to keep (-1 keeps all)
filename: Checkpoint filename pattern
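How mode and save_top_k interact can be sketched in plain Python. This is not REAX's implementation: TopKTracker and its update method are hypothetical names used only for illustration, and the sketch assumes the worst-scoring checkpoint is dropped once more than save_top_k are held.

```python
class TopKTracker:
    """Hypothetical illustration of top-k retention; not part of REAX."""

    def __init__(self, mode="min", save_top_k=3):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.save_top_k = save_top_k
        self.kept = []  # list of (score, path), best first

    def update(self, score, path):
        """Record a new checkpoint and return the paths that should be deleted."""
        self.kept.append((score, path))
        # Best score first: ascending for "min", descending for "max".
        self.kept.sort(key=lambda item: item[0], reverse=(self.mode == "max"))
        if self.save_top_k < 0:  # -1 keeps every checkpoint
            return []
        evicted = self.kept[self.save_top_k:]
        self.kept = self.kept[:self.save_top_k]
        return [path for _, path in evicted]


tracker = TopKTracker(mode="min", save_top_k=2)
removed = []
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.4]):
    removed += tracker.update(val_loss, f"epoch-{epoch}.ckpt")

best_paths = [path for _, path in tracker.kept]
print(best_paths)  # ['epoch-3.ckpt', 'epoch-1.ckpt']
print(removed)     # ['epoch-0.ckpt', 'epoch-2.ckpt']
```

With mode="max" the sort order flips, so the same tracker would keep the highest-scoring checkpoints instead.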
Save Every N Epochs#
Save checkpoints at regular intervals:
checkpoint_listener = ModelCheckpoint(
    every_n_epochs=5,  # Save every 5 epochs
    save_top_k=-1,  # Keep all checkpoints
)
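To get a feel for which epochs produce a file under a periodic schedule, the rule can be sketched as below. The helper name epochs_with_periodic_save is hypothetical, and the sketch assumes epochs are counted from 1 and that a save fires at the end of every N-th completed epoch; this guide does not state REAX's exact indexing convention.

```python
def epochs_with_periodic_save(max_epochs, every_n_epochs):
    """Return the 1-based epochs after which a periodic save would fire.

    Assumption: a save happens at the end of every N-th completed epoch.
    """
    if every_n_epochs <= 0:
        raise ValueError("every_n_epochs must be positive")
    return [e for e in range(1, max_epochs + 1) if e % every_n_epochs == 0]


saved = epochs_with_periodic_save(max_epochs=20, every_n_epochs=5)
print(saved)  # [5, 10, 15, 20]
```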
Save Last Checkpoint#
Always save the most recent checkpoint:
checkpoint_listener = ModelCheckpoint(
    save_last=True,
    filename="last-{epoch:02d}",
)
Resuming Training#
To resume training, load a checkpoint, restore the model parameters from it, and call fit again:
# Load checkpoint
checkpoint = trainer.checkpointing.load("path/to/checkpoint.ckpt")
# Restore model parameters
model.set_parameters(checkpoint["parameters"])
# Resume training
trainer.fit(model, train_loader, val_loader)
The checkpoint contains:
parameters: Model parameters
optimizer_state: Optimiser state
epoch: Current epoch
global_step: Global training step
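Before restoring from a checkpoint, it can help to confirm that the expected entries are present. The following is a minimal sketch that uses a plain dict as a stand-in for a loaded checkpoint; missing_checkpoint_keys is a hypothetical helper, and it assumes the loaded object supports the `in` operator the way a dict does.

```python
# Keys listed in this guide as the contents of a checkpoint.
REQUIRED_KEYS = ("parameters", "optimizer_state", "epoch", "global_step")


def missing_checkpoint_keys(checkpoint):
    """Return the expected keys that are absent from the checkpoint mapping."""
    return [key for key in REQUIRED_KEYS if key not in checkpoint]


# Stand-in for a loaded checkpoint that lacks the optimiser state.
checkpoint = {"parameters": {"w": [0.1, 0.2]}, "epoch": 4, "global_step": 500}

missing = missing_checkpoint_keys(checkpoint)
print(missing)  # ['optimizer_state']
```

A checkpoint without optimizer_state can still restore the model parameters, but the optimiser would then start from a fresh state.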
Manual Checkpointing#
Save checkpoints manually:
# During training
checkpoint_data = {
    "parameters": model.parameters(),
    "optimizer_state": optimizer_state,
    "epoch": current_epoch,
    "custom_data": my_data,
}
trainer.checkpointing.save("my_checkpoint.ckpt", checkpoint_data)
Load manual checkpoints:
checkpoint = trainer.checkpointing.load("my_checkpoint.ckpt")
model.set_parameters(checkpoint["parameters"])
Checkpoint Formats#
REAX supports multiple checkpoint formats:
MessagePack (Default)#
Fast and compact binary format:
import reax
from reax.saving import MsgpackCheckpointing

trainer = reax.Trainer(
    checkpointing=MsgpackCheckpointing(),
)
Pickle#
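Python's native serialisation format. Only load Pickle checkpoints from sources you trust, because unpickling can execute arbitrary code: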
import reax
from reax.saving import PickleCheckpointing

trainer = reax.Trainer(
    checkpointing=PickleCheckpointing(),
)
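To illustrate what a Pickle-based format does with checkpoint data, here is a round trip using only Python's standard pickle module. It does not use PickleCheckpointing or any other REAX class, and the dictionary contents are made-up placeholder values; it is a sketch of the underlying serialisation, not of REAX's file layout.

```python
import os
import pickle
import tempfile

# Placeholder data shaped like the checkpoint dictionaries in this guide.
checkpoint_data = {
    "parameters": {"dense": {"kernel": [[0.1, 0.2], [0.3, 0.4]], "bias": [0.0, 0.0]}},
    "epoch": 3,
    "global_step": 1200,
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example.ckpt")

    # Serialise the dictionary to a binary file.
    with open(path, "wb") as f:
        pickle.dump(checkpoint_data, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Read it back; only do this with files from a trusted source.
    with open(path, "rb") as f:
        restored = pickle.load(f)

round_trip_ok = restored == checkpoint_data
print(round_trip_ok)  # True
```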
Best Practices#
Save Based on Validation Metrics#
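Select checkpoints by a validation metric rather than the training loss, so that the saved model is the one that generalises best and not one that has started to overfit: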
checkpoint_listener = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
)
Keep Multiple Checkpoints#
Save several checkpoints in case the best one is corrupted:
checkpoint_listener = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=3,  # Keep top 3
)
Use Descriptive Filenames#
Include metrics in checkpoint filenames for easy identification:
checkpoint_listener = ModelCheckpoint(
    filename="epoch={epoch:02d}-val_loss={val_loss:.4f}-val_acc={val_acc:.4f}",
)
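The placeholders in the pattern use Python's format-specification syntax. Assuming REAX expands the template the same way str.format does (an assumption, not something this guide confirms), the resulting name can be previewed with made-up metric values:

```python
template = "epoch={epoch:02d}-val_loss={val_loss:.4f}-val_acc={val_acc:.4f}"

# Made-up values standing in for the metrics logged at checkpoint time.
metrics = {"epoch": 7, "val_loss": 0.123456, "val_acc": 0.91}

# 02d zero-pads the epoch to two digits; .4f keeps four decimal places.
filename = template.format(**metrics)
print(filename)  # epoch=07-val_loss=0.1235-val_acc=0.9100
```

Zero-padding the epoch keeps filenames in chronological order when they are sorted alphabetically.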
Checkpoint Large Models Efficiently#
For large models, consider:
Saving less frequently (every_n_epochs=10)
Keeping fewer checkpoints (save_top_k=1)
Using a compact format (MessagePack is more compact than Pickle)
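The effect of the first two options on disk usage can be estimated with simple arithmetic. The helper below is a hypothetical back-of-the-envelope sketch: it assumes every checkpoint has the same size, that a save fires every every_n_epochs epochs, and that save_top_k caps how many files are retained.

```python
def estimate_checkpoint_disk_mb(max_epochs, every_n_epochs, save_top_k, checkpoint_mb):
    """Rough upper bound on the disk space used by retained checkpoints, in MB."""
    periodic_saves = max_epochs // every_n_epochs
    # A negative save_top_k keeps every saved checkpoint.
    kept = periodic_saves if save_top_k < 0 else min(periodic_saves, save_top_k)
    return kept * checkpoint_mb


# Saving every epoch and keeping everything, with 500 MB per checkpoint.
every_epoch = estimate_checkpoint_disk_mb(100, 1, -1, 500)

# Saving every 10 epochs and keeping only the best checkpoint.
sparse = estimate_checkpoint_disk_mb(100, 10, 1, 500)

print(every_epoch, sparse)  # 50000 500
```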
Example: Complete Checkpointing Setup#
from reax import Trainer
from reax.listeners import ModelCheckpoint

# Save best model based on validation accuracy
best_checkpoint = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=1,
    filename="best-model-{epoch:02d}-{val_acc:.4f}",
)

# Save checkpoint every 10 epochs
periodic_checkpoint = ModelCheckpoint(
    every_n_epochs=10,
    save_top_k=-1,
    filename="periodic-{epoch:02d}",
)

# Always save the last checkpoint
last_checkpoint = ModelCheckpoint(
    save_last=True,
    filename="last",
)

trainer = Trainer(
    max_epochs=100,
    listeners=[best_checkpoint, periodic_checkpoint, last_checkpoint],
    default_root_dir="./experiments/my_model",
)
trainer.fit(model, train_loader, val_loader)

# After training, load the best model
best_path = best_checkpoint.best_model_path
checkpoint = trainer.checkpointing.load(best_path)
model.set_parameters(checkpoint["parameters"])
See Also#
The Trainer - Trainer configuration
Listeners - Custom listeners
Logging - Experiment logging