Data Handling

REAX handles data through two abstractions: data loaders, which yield batches, and DataModules, which bundle data preparation with the construction of those loaders.

DataLoaders

REAX works with any JAX-compatible data loader: a standard PyTorch DataLoader, or any iterable that yields batches of NumPy or JAX arrays.
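As a minimal sketch of "any iterable that yields batches", the hypothetical `NumpyBatchLoader` class below (not part of REAX) wraps in-memory NumPy arrays. It implements `__iter__` instead of being a one-shot generator, so it can be iterated again on every epoch; that a trainer iterates its loader once per epoch is an assumption here, not something taken from the REAX source.

```python
import numpy as np


class NumpyBatchLoader:
    """Hypothetical re-iterable loader yielding (inputs, targets) NumPy batches."""

    def __init__(self, inputs, targets, batch_size, shuffle=False, seed=0):
        if len(inputs) != len(targets):
            raise ValueError("inputs and targets must have the same length")
        self.inputs = np.asarray(inputs)
        self.targets = np.asarray(targets)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._rng = np.random.default_rng(seed)

    def __len__(self):
        # Number of batches per pass; the last batch may be smaller
        return -(-len(self.inputs) // self.batch_size)

    def __iter__(self):
        # Each call starts a fresh pass, drawing a new permutation if shuffling
        order = np.arange(len(self.inputs))
        if self.shuffle:
            self._rng.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            idx = order[start:start + self.batch_size]
            yield self.inputs[idx], self.targets[idx]


x = np.arange(20, dtype=np.float32).reshape(10, 2)
y = np.arange(10)
loader = NumpyBatchLoader(x, y, batch_size=4, shuffle=True)
print([xb.shape for xb, _ in loader])  # [(4, 2), (4, 2), (2, 2)]
```

Because the RNG lives on the loader and advances between passes, each epoch sees a different order while the whole run stays reproducible from `seed`.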

DataModules

A DataModule encapsulates every step needed to process data: downloading, tokenising, and splitting. Keeping these steps in one class makes data handling easier to reproduce and reusable across projects.

A DataModule defines five methods:

  1. prepare_data: Download, tokenise, etc. (in distributed settings, runs only once, in a single process).

  2. setup: Split data, apply transforms (runs on every device).

  3. train_dataloader: Returns the training dataloader.

  4. val_dataloader: Returns the validation dataloader.

  5. test_dataloader: Returns the test dataloader.
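The five methods can be sketched on a small synthetic dataset standing in for MNIST, so the code runs without a download. `SyntheticDataModule`, its `data_dir`, `batch_size` and `num_samples` parameters, the `synthetic.npz` cache file and the 80/10/10 split are hypothetical choices for illustration. To stay runnable without REAX installed it is a plain class with the same method names; in a real project it would subclass `reax.DataModule`. The "loaders" are plain lists of `(inputs, targets)` batches, which are re-iterable.

```python
import os
import tempfile

import numpy as np


def _batches(inputs, targets, batch_size):
    # Chop arrays into a list of (inputs, targets) batches; a list is re-iterable
    return [
        (inputs[i:i + batch_size], targets[i:i + batch_size])
        for i in range(0, len(inputs), batch_size)
    ]


class SyntheticDataModule:
    """Hypothetical DataModule sketch; with REAX, subclass reax.DataModule instead."""

    def __init__(self, data_dir, batch_size=32, num_samples=1000, seed=0):
        self.path = os.path.join(data_dir, "synthetic.npz")
        self.batch_size = batch_size
        self.num_samples = num_samples
        self.seed = seed

    def prepare_data(self):
        # Stand-in for a download: write raw data to disk only if it is missing.
        # This runs in a single process, so results go to disk, not onto self.
        if os.path.exists(self.path):
            return
        rng = np.random.default_rng(self.seed)
        x = rng.normal(size=(self.num_samples, 8)).astype(np.float32)
        y = (x.sum(axis=1) > 0).astype(np.int32)
        np.savez(self.path, x=x, y=y)

    def setup(self, stage=None):
        # Load the prepared data and split it 80/10/10 into train/val/test.
        # A fixed seed gives every device the same split.
        with np.load(self.path) as data:
            x, y = data["x"], data["y"]
        order = np.random.default_rng(self.seed).permutation(len(x))
        n_train = int(0.8 * len(x))
        n_val = int(0.1 * len(x))
        train, val, test = np.split(order, [n_train, n_train + n_val])
        self.train = (x[train], y[train])
        self.val = (x[val], y[val])
        self.test = (x[test], y[test])

    def train_dataloader(self):
        return _batches(*self.train, self.batch_size)

    def val_dataloader(self):
        return _batches(*self.val, self.batch_size)

    def test_dataloader(self):
        return _batches(*self.test, self.batch_size)


dm = SyntheticDataModule(tempfile.mkdtemp(), batch_size=100)
dm.prepare_data()
dm.setup()
print(len(dm.train_dataloader()), len(dm.val_dataloader()), len(dm.test_dataloader()))  # 8 1 1
```

Writing the raw data to disk in `prepare_data` and loading it in `setup` follows from where each method runs: `prepare_data` runs in a single process, so anything it stored on `self` would presumably not be visible elsewhere, whereas `setup` runs on every device and each one loads and splits the data itself.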

Example

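import reax
from torch.utils.data import DataLoader  # assumed; any iterable of batches works


class MNISTDataModule(reax.DataModule):
    def prepare_data(self):
        # Download MNIST to disk (runs once, in a single process)
        pass

    def setup(self, stage=None):
        # Load the downloaded data and split it into train/val/test sets
        pass

    def train_dataloader(self):
        # Loader over the training split
        return DataLoader(...)

    def val_dataloader(self):
        # Loader over the validation split
        return DataLoader(...)

    def test_dataloader(self):
        # Loader over the test split
        return DataLoader(...)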

Using a DataModule

Pass the DataModule to the Trainer:

# `model` and `trainer` are assumed to exist already
dm = MNISTDataModule()
trainer.fit(model, datamodule=dm)
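Before handing a DataModule to `trainer.fit`, it can be worth checking it on its own, for example in a unit test. The hypothetical helper `smoke_test_datamodule` below calls `prepare_data`, then `setup`, then fetches one batch from the training and validation loaders. That ordering mirrors the list above; the `stage="fit"` value and the assumption that REAX's Trainer calls the methods in this order are not confirmed by this page. `TinyDataModule` is likewise a hypothetical in-memory module used only for the demo.

```python
import numpy as np


def smoke_test_datamodule(dm, stage="fit"):
    """Hypothetical helper: run a DataModule's methods without a Trainer.

    Returns the array shapes of the first train and validation batches.
    """
    dm.prepare_data()
    dm.setup(stage)
    shapes = {}
    for name in ("train_dataloader", "val_dataloader"):
        loader = getattr(dm, name)()
        # next() raises StopIteration if the loader yields no batches
        inputs, targets = next(iter(loader))
        shapes[name] = (np.shape(inputs), np.shape(targets))
    return shapes


class TinyDataModule:
    """Minimal in-memory DataModule for the demo (nothing to download)."""

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        x = np.ones((6, 3), dtype=np.float32)
        y = np.zeros(6, dtype=np.int32)
        self.train = [(x[:4], y[:4]), (x[4:], y[4:])]
        self.val = [(x[:2], y[:2])]

    def train_dataloader(self):
        return self.train

    def val_dataloader(self):
        return self.val


print(smoke_test_datamodule(TinyDataModule()))
# {'train_dataloader': ((4, 3), (4,)), 'val_dataloader': ((2, 3), (2,))}
```

An empty or misconfigured loader then fails in a few milliseconds instead of partway into a training run.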