{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "319b616a870a351e",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "\n",
    "This notebook walks through the basic REAX workflow end to end: wrap a Flax autoencoder in a REAX module, train it on MNIST with the REAX `Trainer`, and reload the resulting checkpoint to embed images."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f4b434b619964dd",
   "metadata": {},
   "source": [
    "## 1. Install REAX\n",
    "\n",
    "```bash\n",
    "pip install reax\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17c5032e65192c8d",
   "metadata": {},
   "source": [
    "## 2. Define a REAX Module\n",
    "\n",
    "A REAX Module keeps track of your model parameters and gives you a place to put the code for the various steps in your training loop (`training_step`, `validation_step`, etc.)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from functools import partial\n",
    "\n",
    "from flax import linen\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import optax\n",
    "import reax\n",
    "from reax import demos\n",
    "\n",
    "\n",
    "class Autoencoder(linen.Module):\n",
    "    \"\"\"Dense autoencoder: flat input -> 3-d latent code -> 28 * 28 reconstruction.\"\"\"\n",
    "\n",
    "    def setup(self):\n",
    "        self.encoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(3)])\n",
    "        self.decoder = linen.Sequential([linen.Dense(128), linen.relu, linen.Dense(28 * 28)])\n",
    "\n",
    "    def __call__(self, x):\n",
    "        \"\"\"Reconstruct ``x`` by encoding and then decoding it.\"\"\"\n",
    "        z = self.encoder(x)\n",
    "        return self.decoder(z)\n",
    "\n",
    "    def encode(self, x):\n",
    "        \"\"\"Return only the latent code of ``x``.\"\"\"\n",
    "        return self.encoder(x)\n",
    "\n",
    "\n",
    "class ReaxAutoEncoder(reax.Module):\n",
    "    \"\"\"REAX module that trains :class:`Autoencoder` to reconstruct its flattened input.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.ae = Autoencoder()\n",
    "        # Flax `apply` pre-bound to the `encode` method: call as `self._encode(params, x)`.\n",
    "        self._encode = partial(self.ae.apply, method=\"encode\")\n",
    "\n",
    "    def configure_model(self, stage: reax.Stage, batch, /):\n",
    "        \"\"\"Initialise model parameters using example batch.\n",
    "\n",
    "        Does nothing when parameters have already been set.\n",
    "        \"\"\"\n",
    "        if self.parameters() is None:\n",
    "            # Prepare the example batch for initialization\n",
    "            inputs, _ = self.prepare_batch(batch)\n",
    "            # Flax Linen: Use init() with RNGs and example input to get parameters\n",
    "            params = self.ae.init(self.rngs(), inputs)\n",
    "            self.set_parameters(params)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Return ``(loss, grads)`` of the reconstruction loss for one batch; labels are ignored.\"\"\"\n",
    "        x, _ = self.prepare_batch(batch)\n",
    "        # Pass apply function to static method for JIT compilation\n",
    "        loss, grads = jax.value_and_grad(self.loss_fn, argnums=0)(\n",
    "            self.parameters(), x, self.ae.apply\n",
    "        )\n",
    "        self.log(\"train_loss\", loss, on_step=True, prog_bar=True)\n",
    "        return loss, grads\n",
    "\n",
    "    @staticmethod\n",
    "    @partial(jax.jit, static_argnums=2)\n",
    "    def loss_fn(params, x_batch, apply_fn):\n",
    "        \"\"\"Mean squared error between ``x_batch`` and its per-example reconstruction.\n",
    "\n",
    "        Static method for JIT compilation - receives params and apply function.\n",
    "        ``apply_fn`` is a static argument, so it must be hashable and is not traced.\n",
    "        \"\"\"\n",
    "        predictions = jax.vmap(apply_fn, in_axes=(None, 0))(params, x_batch)\n",
    "        return optax.losses.squared_error(predictions, x_batch).mean()\n",
    "\n",
    "    def encode(self, x_batch):\n",
    "        \"\"\"Flatten each example of ``x_batch`` and return its latent code.\"\"\"\n",
    "        x_batch, _ = self.prepare_batch((x_batch, None))\n",
    "        return jax.vmap(self._encode, in_axes=(None, 0))(self.parameters(), x_batch)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Return an Adam optimiser (learning rate 1e-3) and its state for the current parameters.\"\"\"\n",
    "        opt = optax.adam(learning_rate=1e-3)\n",
    "        state = opt.init(self.parameters())\n",
    "        return opt, state\n",
    "\n",
    "    @staticmethod\n",
    "    def prepare_batch(batch):\n",
    "        \"\"\"Split ``batch`` into ``(x, y)`` and flatten all but the first axis of ``x``.\"\"\"\n",
    "        x, y = batch\n",
    "        return x.reshape(x.shape[0], -1), y\n",
    "\n",
    "\n",
    "autoencoder = ReaxAutoEncoder()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "152211a9233c901f",
   "metadata": {},
   "source": [
    "## 3. Define a dataset\n",
    "\n",
    "REAX supports any iterable (NumPy arrays, lists, etc.) for the train/val/test/predict datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51e497504a54c8c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the data\n",
    "dataset = demos.mnist.MnistDataset(download=True)\n",
    "data_loader = reax.ReaxDataLoader(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58b5f052601b16c5",
   "metadata": {},
   "source": [
    "## 4. Train the model\n",
    "\n",
    "The REAX Trainer takes the module and dataset and combines them in a training loop, automating away most of the boilerplate."
]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acd9b8a01b0366bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = reax.Trainer()\n",
    "trainer.fit(autoencoder, data_loader, limit_train_batches=100, max_epochs=1);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b65f27a03c0a66af",
   "metadata": {},
   "source": [
    "## 5. Use the model\n",
    "\n",
    "Reload the trained parameters from the checkpoint saved during training and use the encoder to embed a small batch of random images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d7418fe3aa03aa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the checkpoint written by the training run above. Runs are logged to numbered\n",
    "# `version_N` directories, so pick the highest N instead of hard-coding `version_0`\n",
    "# (which would silently reload a stale checkpoint when the notebook is re-run).\n",
    "# NOTE: assumes the logger increments N for every run - confirm against the REAX logger.\n",
    "log_root = \"./reax_logs\"\n",
    "run_versions = [\n",
    "    int(name[len(\"version_\"):])\n",
    "    for name in os.listdir(log_root)\n",
    "    if name.startswith(\"version_\") and name[len(\"version_\"):].isdigit()\n",
    "]\n",
    "checkpoint = os.path.join(\n",
    "    log_root, f\"version_{max(run_versions)}\", \"checkpoints\", \"epoch=0-step=100.ckpt\"\n",
    ")\n",
    "ckpt = trainer.checkpointing.load(checkpoint)\n",
    "autoencoder.set_parameters(ckpt[\"parameters\"])\n",
    "\n",
    "# embed 4 fake images!\n",
    "fake_image_batch = jax.random.uniform(trainer.rngs(), shape=(4, 28, 28))\n",
    "fake_image_batch = trainer.engine.to_device(fake_image_batch)\n",
    "embeddings = autoencoder.encode(fake_image_batch)\n",
    "print(\"⚡\" * 20, \"\\nPredictions (4 image embeddings):\\n\", embeddings, \"\\n\", \"⚡\" * 20)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}