From a50d715b8f54d55216028f1d7b038db209915984 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 14 Oct 2024 09:37:51 -0700 Subject: [PATCH] Add digits VAE tutorial --- docs/conf.py | 4 + docs/digits_vae.ipynb | 986 ++++++++++++++++++++++++++++++++++++++++++ docs/digits_vae.md | 514 ++++++++++++++++++++++ docs/tutorials.md | 1 + pyproject.toml | 4 + 5 files changed, 1509 insertions(+) create mode 100644 docs/digits_vae.ipynb create mode 100644 docs/digits_vae.md diff --git a/docs/conf.py b/docs/conf.py index f5b6ca7..17ea579 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -34,9 +34,13 @@ 'build/html', 'build/jupyter_execute', # Exclude markdown sources for notebooks: + 'digits_vae.md', 'getting_started_with_jax_for_AI.md', ] +suppress_warnings = [ + 'misc.highlighting_failure', # Suppress warning in exception in digits_vae +] # -- Options for myst ---------------------------------------------- myst_heading_anchors = 3 # auto-generate 3 levels of heading anchors diff --git a/docs/digits_vae.ipynb b/docs/digits_vae.ipynb new file mode 100644 index 0000000..c964bbb --- /dev/null +++ b/docs/digits_vae.ipynb @@ -0,0 +1,986 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "47OmRSTR1dJU" + }, + "source": [ + "# Debugging in JAX: a Variational autoencoder (VAE) model\n", + "\n", + "In [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) we built a simple neural network for classification of handwritten digits, and covered some of the key features of JAX, including its NumPy-style interface in the `jax.numpy`, as well as its transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`.\n", + "\n", + "This tutorial will explore a slightly more involved model: a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the same simple digits data. 
Along the way, we'll learn a bit more about how JAX's JIT compilation actually works, and what this means for debugging JAX programs." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k19povzxp7hS" + }, + "source": [ + "## Loading the digits\n", + "\n", + "As before, we'll use the small, self-contained [scikit-learn digits dataset](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) for ease of experimentation:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "aIwDAfS6PtFh", + "outputId": "4950f17a-7c47-4a83-cbcd-e206d07cca64" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "images_train.shape=(1347, 8, 8)\n", + "images_test.shape=(450, 8, 8)\n" + ] + } + ], + "source": [ + "from sklearn.datasets import load_digits\n", + "from sklearn.model_selection import train_test_split\n", + "import jax.numpy as jnp\n", + "\n", + "digits = load_digits()\n", + "\n", + "splits = train_test_split(digits.images, random_state=0)\n", + "\n", + "images_train, images_test = map(jnp.asarray, splits)\n", + "\n", + "print(f\"{images_train.shape=}\")\n", + "print(f\"{images_test.shape=}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2_Q16JRyrW7V" + }, + "source": [ + "The dataset comprises 1797 images, each represented by an 8x8 pixel grid.\n", + "To see a visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html#loading-the-data) in the previous tutorial." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z9TPYqipPyBp" + }, + "source": [ + "## Defining the Variational Autoencoder\n", + "\n", + "Previously we defined a simple feedforward network trained for classification with an architecture that looked roughly like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "HNlg-ydpr5yH" + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from flax import nnx\n", + "\n", + "class SimpleNN(nnx.Module):\n", + "\n", + " def __init__(self, n_features=64, n_hidden=100, n_targets=10, *, rngs: nnx.Rngs):\n", + " self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)\n", + " self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)\n", + " self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)\n", + "\n", + " def __call__(self, x: jax.Array) -> jax.Array:\n", + " x = nnx.selu(self.layer1(x))\n", + " x = nnx.selu(self.layer2(x))\n", + " return self.layer3(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3DwMBvoksOmG" + }, + "source": [ + "In this network we had one output per class, and the loss function was designed such that once trained, the output corresponding to the correct class would return the strongest signal, thus predicting the correct label in upwards of 95% of cases.\n", + "\n", + "In this VAE example we use similar building blocks to instead output a small probabilistic model representing the data. While classic `VAE` is generally based on convolutional layers, we use linear layers for simplicity. 
The sub-network that produces this probabilistic encoding is our `Encoder`:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "Hj7mtR5vmcGr" + }, + "outputs": [], + "source": [ + "class Encoder(nnx.Module):\n", + " def __init__(self, input_size: int, intermediate_size: int, output_size: int,\n", + " *, rngs: nnx.Rngs):\n", + " self.rngs = rngs\n", + " self.linear = nnx.Linear(input_size, intermediate_size, rngs=rngs)\n", + " self.linear_mean = nnx.Linear(intermediate_size, output_size, rngs=rngs)\n", + " self.linear_std = nnx.Linear(intermediate_size, output_size, rngs=rngs)\n", + "\n", + " def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:\n", + " x = self.linear(x)\n", + " x = jax.nn.relu(x)\n", + "\n", + " mean = self.linear_mean(x)\n", + " std = jnp.exp(self.linear_std(x))\n", + "\n", + " key = self.rngs.noise()\n", + " z = mean + std * jax.random.normal(key, mean.shape)\n", + " return z, mean, std" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VwfCWbiRmkG9" + }, + "source": [ + "The idea here is that `mean` and `std` define a low-dimensional probability distribution over a latent space, and that `z` is a draw from this latent space that represents the training data.\n", + "\n", + "In order to ensure that this latent distribution faithfully represents the actual data, we define a `Decoder` that maps back to the input space:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "FoAmZuVDnjgn" + }, + "outputs": [], + "source": [ + "class Decoder(nnx.Module):\n", + " def __init__(self, input_size: int, intermediate_size: int, output_size: int,\n", + " *, rngs: nnx.Rngs):\n", + " self.linear1 = nnx.Linear(input_size, intermediate_size, rngs=rngs)\n", + " self.linear2 = nnx.Linear(intermediate_size, output_size, rngs=rngs)\n", + "\n", + " def __call__(self, z: jax.Array) -> jax.Array:\n", + " z = self.linear1(z)\n", + " z = jax.nn.relu(z)\n", + " logits = 
self.linear2(z)\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0QaT-KY6npSc" + }, + "source": [ + "Now the full VAE model is a single network built from the encoder and decoder.\n", + "It returns both the reconstructed image and the internal latent space model:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "Myo2MdxXnzlT" + }, + "outputs": [], + "source": [ + "class VAE(nnx.Module):\n", + " def __init__(\n", + " self,\n", + " image_shape: tuple[int, int],\n", + " hidden_size: int,\n", + " latent_size: int,\n", + " *,\n", + " rngs: nnx.Rngs\n", + " ):\n", + " self.image_shape = image_shape\n", + " self.latent_size = latent_size\n", + " input_size = image_shape[0] * image_shape[1]\n", + " self.encoder = Encoder(input_size, hidden_size, latent_size, rngs=rngs)\n", + " self.decoder = Decoder(latent_size, hidden_size, input_size, rngs=rngs)\n", + "\n", + " def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:\n", + " x = jax.vmap(jax.numpy.ravel)(x) # flatten\n", + " z, mean, std = self.encoder(x)\n", + " logits = self.decoder(z)\n", + " logits = jnp.reshape(logits, (-1, *self.image_shape))\n", + " return logits, mean, std" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xIm9Yi5YoIxN" + }, + "source": [ + "Next is the loss function – there are two components to the model that we want to ensure:\n", + "\n", + "1. the `logits` output faithfully reconstructs the input image.\n", + "2. 
the model represented by `mean` and `std` faithfully represents the \"true\" latent distribution.\n", + "\n", + "VAE uses a loss function based on the [Evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound) to quantify these two goals in a single loss value:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "bMpxj8-Wsvui" + }, + "outputs": [], + "source": [ + "def vae_loss(model: VAE, x: jax.Array):\n", + " logits, mean, std = model(x)\n", + " kl_loss = jnp.mean(0.5 * jnp.mean(\n", + " -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))\n", + " reconstruction_loss = jnp.mean(\n", + " optax.sigmoid_binary_cross_entropy(logits, x)\n", + " )\n", + " return reconstruction_loss + 0.1 * kl_loss" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RaT0ELpvqo2W" + }, + "source": [ + "Now all that's left is to define the model and optimizer, and run the training loop:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "JPgoHL5rpKXd", + "outputId": "e7626646-7c14-42dc-c18f-774adff1306e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 loss: 16745235.0\n", + "Epoch 500 loss: nan\n", + "Epoch 1000 loss: nan\n", + "Epoch 1500 loss: nan\n", + "Epoch 2000 loss: nan\n" + ] + } + ], + "source": [ + "import optax\n", + "\n", + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "\n", + "@nnx.jit\n", + "def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array):\n", + " loss, grads = nnx.value_and_grad(vae_loss)(model, x)\n", + " optimizer.update(grads)\n", + " return loss\n", + "\n", + "for epoch in range(2001):\n", + " loss = train_step(model, optimizer, images_train)\n", + " if epoch % 500 == 0:\n", + " print(f'Epoch {epoch} loss: {loss}')" + ] + }, + { + "cell_type": 
"markdown", + "metadata": { + "id": "f8m_uoL9q47M" + }, + "source": [ + "And here we see that something has gone wrong: our loss value has become NaN after some number of iterations." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SBS1mmxwrS25" + }, + "source": [ + "## Debugging NaNs in JAX\n", + "Despite our best efforts, our model is producing NaNs. What can we do?\n", + "\n", + "JAX offers a number of debugging approaches for situations like this, outlined in the JAX docs at [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html). In this case we can use the [`debug_nans`](https://jax.readthedocs.io/en/latest/debugging/flags.html#jax-debug-nans-configuration-option-and-context-manager) configuration to see where the NaN value is arising:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "JE7OYoZ4rRQ8", + "outputId": "2b77a9c1-aaa4-418a-da71-013d8f885340", + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.\n", + "Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.\n" + ] + }, + { + "ename": "FloatingPointError", + "evalue": "invalid value (nan) encountered in jit(dot_general). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 
\n\nIt may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. \n\nIf you see this error, consider opening a bug report at https://github.com/google/jax.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_nan_check_posthook\u001b[0;34m(fun, args, kwargs, output)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mdispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpjit_p\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36mcheck_special\u001b[0;34m(name, bufs)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 321\u001b[0;31m \u001b[0m_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 322\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 
+ "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36m_check_special\u001b[0;34m(name, dtype, buf)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_nans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 326\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"invalid value (nan) encountered in {name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 327\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_infs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFloatingPointError\u001b[0m: invalid value (nan) encountered in pjit", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... 
skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36mcheck_special\u001b[0;34m(name, bufs)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 321\u001b[0;31m \u001b[0m_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 322\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36m_check_special\u001b[0;34m(name, dtype, buf)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_nans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 326\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"invalid value (nan) encountered in {name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 327\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_infs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m 
\u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFloatingPointError\u001b[0m: invalid value (nan) encountered in pjit", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 333\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 334\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 1290\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marrays\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mout_arrays\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1291\u001b[0;31m \u001b[0mdispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1292\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_arrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36mcheck_special\u001b[0;34m(name, bufs)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 321\u001b[0;31m \u001b[0m_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 322\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36m_check_special\u001b[0;34m(name, dtype, buf)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_nans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m 
\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 326\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"invalid value (nan) encountered in {name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 327\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_infs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFloatingPointError\u001b[0m: invalid value (nan) encountered in jit(jit_fn)", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... 
skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 333\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 334\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 1290\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marrays\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mout_arrays\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1291\u001b[0;31m \u001b[0mdispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1292\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_arrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in 
\u001b[0;36mcheck_special\u001b[0;34m(name, bufs)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 321\u001b[0;31m \u001b[0m_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 322\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36m_check_special\u001b[0;34m(name, dtype, buf)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_nans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 326\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"invalid value (nan) encountered in {name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 327\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_infs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m 
\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFloatingPointError\u001b[0m: invalid value (nan) encountered in jit(dot_general)", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/usr/lib/python3.10/runpy.py\u001b[0m in \u001b[0;36m_run_module_as_main\u001b[0;34m()\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmod_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0morigin\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 196\u001b[0;31m return _run_code(code, main_globals, None,\n\u001b[0m\u001b[1;32m 197\u001b[0m \"__main__\", mod_spec)\n", + "\u001b[0;32m/usr/lib/python3.10/runpy.py\u001b[0m in \u001b[0;36m_run_code\u001b[0;34m()\u001b[0m\n\u001b[1;32m 85\u001b[0m __spec__ = mod_spec)\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0mexec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_globals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrun_globals\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m 
\u001b[0;34m'__main__'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0mColabKernelApp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlaunch_instance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py\u001b[0m in \u001b[0;36mlaunch_instance\u001b[0;34m()\u001b[0m\n\u001b[1;32m 991\u001b[0m \u001b[0mapp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 992\u001b[0;31m \u001b[0mapp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 993\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py\u001b[0m in \u001b[0;36mstart\u001b[0;34m()\u001b[0m\n\u001b[1;32m 618\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 619\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mio_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 620\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py\u001b[0m in \u001b[0;36mstart\u001b[0;34m()\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m 
\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 195\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masyncio_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_forever\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3.10/asyncio/base_events.py\u001b[0m in \u001b[0;36mrun_forever\u001b[0;34m()\u001b[0m\n\u001b[1;32m 602\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 603\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_once\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 604\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stopping\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3.10/asyncio/base_events.py\u001b[0m in \u001b[0;36m_run_once\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1908\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1909\u001b[0;31m \u001b[0mhandle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1910\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;31m# Needed to break cycles when an exception occurs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3.10/asyncio/events.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m()\u001b[0m\n\u001b[1;32m 79\u001b[0m 
\u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 80\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 81\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mSystemExit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 684\u001b[0m future.add_done_callback(\n\u001b[0;32m--> 685\u001b[0;31m \u001b[0;32mlambda\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunctools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfuture\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 686\u001b[0m )\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py\u001b[0m in \u001b[0;36m_run_callback\u001b[0;34m()\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 738\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 739\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mret\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/gen.py\u001b[0m in \u001b[0;36minner\u001b[0;34m()\u001b[0m\n\u001b[1;32m 824\u001b[0m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;31m# noqa: F841\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 825\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mctx_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 826\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/gen.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m()\u001b[0m\n\u001b[1;32m 785\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 786\u001b[0;31m \u001b[0myielded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mprocess_one\u001b[0;34m()\u001b[0m\n\u001b[1;32m 360\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 361\u001b[0;31m \u001b[0;32myield\u001b[0m 
\u001b[0mgen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmaybe_future\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 362\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/gen.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0myielded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctx_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mStopIteration\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mReturn\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mdispatch_shell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 260\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 261\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0mgen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmaybe_future\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstream\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midents\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 262\u001b[0m \u001b[0;32mexcept\u001b[0m 
\u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/gen.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0myielded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctx_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mStopIteration\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mReturn\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mexecute_request\u001b[0;34m()\u001b[0m\n\u001b[1;32m 538\u001b[0m reply_content = yield gen.maybe_future(\n\u001b[0;32m--> 539\u001b[0;31m self.do_execute(\n\u001b[0m\u001b[1;32m 540\u001b[0m \u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msilent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstore_history\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/gen.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0myielded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctx_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m 
\u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mStopIteration\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mReturn\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py\u001b[0m in \u001b[0;36mdo_execute\u001b[0;34m()\u001b[0m\n\u001b[1;32m 301\u001b[0m \u001b[0;31m# letting shell dispatch to loop runners\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 302\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_cell\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstore_history\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstore_history\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msilent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msilent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 303\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py\u001b[0m in \u001b[0;36mrun_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_last_traceback\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 539\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mZMQInteractiveShell\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_cell\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 540\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2974\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2975\u001b[0;31m result = self._run_cell(\n\u001b[0m\u001b[1;32m 2976\u001b[0m \u001b[0mraw_cell\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstore_history\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msilent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshell_futures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcell_id\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36m_run_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3029\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3030\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mrunner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoro\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3031\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mBaseException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py\u001b[0m in \u001b[0;36m_pseudo_sync_runner\u001b[0;34m()\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m \u001b[0mcoro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 79\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m \u001b[0;32mas\u001b[0m 
\u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_cell_async\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3256\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3257\u001b[0;31m has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n\u001b[0m\u001b[1;32m 3258\u001b[0m interactivity=interactivity, compiler=compiler, result=result)\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_ast_nodes\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3472\u001b[0m \u001b[0masy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3473\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;32mawait\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_code\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0masync_\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0masy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3474\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_code\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3553\u001b[0;31m \u001b[0mexec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode_obj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_global_ns\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3554\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2001\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimages_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/graph.py\u001b[0m in \u001b[0;36mupdate_context_manager_wrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1043\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1044\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/transforms.py\u001b[0m in \u001b[0;36mjit_wrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0mgraphdef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_graph_nodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m out, output_state, output_graphdef = jitted_fn(\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/transforms.py\u001b[0m in \u001b[0;36mjit_fn\u001b[0;34m()\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 158\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_step\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mVAE\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrads\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue_and_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvae_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrads\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/graph.py\u001b[0m in \u001b[0;36mupdate_context_manager_wrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1043\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1044\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/transforms.py\u001b[0m in \u001b[0;36mgrad_wrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 568\u001b[0;31m out = transform(\n\u001b[0m\u001b[1;32m 569\u001b[0m \u001b[0mgrad_fn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/transforms.py\u001b[0m in \u001b[0;36mgrad_fn\u001b[0;34m()\u001b[0m\n\u001b[1;32m 511\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 512\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 513\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mvae_loss\u001b[0;34m()\u001b[0m\n\u001b[1;32m 
1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvae_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mVAE\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m kl_loss = jnp.mean(0.5 * jnp.mean(\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m__call__\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# flatten\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m__call__\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m 
\u001b[0mmean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mstd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear_std\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/nn/linear.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m()\u001b[0m\n\u001b[1;32m 380\u001b[0m )\n\u001b[0;32m--> 381\u001b[0;31m y = self.dot_general(\n\u001b[0m\u001b[1;32m 382\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m: FloatingPointError: invalid value (nan) encountered in jit(dot_general). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. \n\nIt may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 
\n\nIf you see this error, consider opening a bug report at https://github.com/google/jax.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_nans\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2001\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimages_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/graph.py\u001b[0m in \u001b[0;36mupdate_context_manager_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1041\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mupdate_context_manager_wrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1043\u001b[0;31m 
\u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1044\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1045\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mupdate_context_manager_wrapper\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/transforms.py\u001b[0m in \u001b[0;36mjit_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 357\u001b[0m )\n\u001b[1;32m 358\u001b[0m \u001b[0mgraphdef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_graph_nodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m out, output_state, output_graphdef = jitted_fn(\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0m_nnx_jit_static\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mJitStaticInputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraphdef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_constrain_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_nan_check_posthook\u001b[0;34m(fun, args, kwargs, output)\u001b[0m\n\u001b[1;32m 118\u001b[0m print(\"Invalid nan value encountered in the output of a C++-jit/pmap \"\n\u001b[1;32m 119\u001b[0m \"function. 
Calling the de-optimized version.\")\n\u001b[0;32m--> 120\u001b[0;31m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cache_miss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;31m# probably won't return\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 121\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_update_debug_special_global\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + " \u001b[0;31m[... skipping hidden 24 frame]\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py\u001b[0m in \u001b[0;36m_pjit_call_impl_python\u001b[0;34m(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, *args)\u001b[0m\n\u001b[1;32m 1697\u001b[0m \u001b[0;34m\"If you see this error, consider opening a bug report at \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1698\u001b[0m \"https://github.com/google/jax.\")\n\u001b[0;32m-> 1699\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1700\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1701\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFloatingPointError\u001b[0m: invalid value (nan) encountered in jit(dot_general). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. 
However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. \n\nIt may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. \n\nIf you see this error, consider opening a bug report at https://github.com/google/jax." + ] + } + ], + "source": [ + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "\n", + "with jax.debug_nans(True):\n", + " for epoch in range(2001):\n", + " train_step(model, optimizer, images_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "thw9URJmrROj" + }, + "source": [ + "The output here is complicated, because the function we're evaluating is complicated, but the key to deciphering this traceback is to look for the places where the traceback touches your implementation. 
In particular here, it indicates that NaN values arise during the gradient update:\n", + "```\n", + " in train_step()\n", + " 14 loss, grads = nnx.value_and_grad(vae_loss)(model, x)\n", + "---> 15 optimizer.update(grads)\n", + " 16 return loss\n", + "```\n", + "and further down from this, the details of the gradient update step where the NaN is arising:\n", + "```\n", + "/usr/local/lib/python3.10/dist-packages/optax/tree_utils/_tree_math.py in ()\n", + " 280 lambda g, t: (\n", + "--> 281 (1 - decay) * (g**order) + decay * t if g is not None else None\n", + " 282 ),\n", + "```\n", + "This tells us that the gradient is returning values that lead to `NaN` during the model update: typically this would come about when the gradient itself is for some reason diverging.\n", + "\n", + "A diverging gradient means that something with our loss function may be amiss. We previously saw `loss=NaN` at iteration 500; let's print the progress up to this point:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "KJ1gAh8uurVX", + "outputId": "21564ba8-2ea4-42fb-bf9a-0291a556004c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 loss: 16745235.0\n", + "Epoch 50 loss: 19.595727920532227\n", + "Epoch 100 loss: -13.440512657165527\n", + "Epoch 150 loss: -145.24871826171875\n", + "Epoch 200 loss: -683.0828247070312\n", + "Epoch 250 loss: -2291.444091796875\n", + "Epoch 300 loss: -6880.775390625\n" + ] + } + ], + "source": [ + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "\n", + "for epoch in range(301):\n", + " loss = train_step(model, optimizer, images_train)\n", + " if epoch % 50 == 0:\n", + " print(f'Epoch {epoch} loss: {loss}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Jk_wdvqpurTG" + }, + "source": [ + "It looks like our loss 
value is decreasing toward negative infinity until the point where the values are no longer well-represented by floating point math.\n", + "\n", + "At this point, we may wish to inspect the values within the loss function itself to see where the diverging loss might be coming from.\n", + "In typical Python programs you can do this by inserting either a `print` statement or a `breakpoint` in the loss function; it might look something like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "9Klkz7qHwWia", + "outputId": "e727cf56-0cc1-4256-9567-5fe96396ce45" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kl loss Tracedwith with\n", + " primal = Tracedwith\n", + " tangent = Tracedwith with\n", + " pval = (ShapedArray(float32[]), None)\n", + " recipe = JaxprEqnRecipe(eqn_id=, in_tracers=(Traced,), out_tracer_refs=[], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1347]. let\n", + " b:f32[] = reduce_sum[axes=(0,)] a\n", + " c:f32[] = div b 1347.0\n", + " in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'resource_env': None, 'donated_invars': (False,), 'name': '_mean', 'keep_unused': False, 'inline': True}, effects=set(), source_info=, ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={})\n", + "reconstruction loss Tracedwith with\n", + " primal = Tracedwith\n", + " tangent = Tracedwith with\n", + " pval = (ShapedArray(float32[]), None)\n", + " recipe = JaxprEqnRecipe(eqn_id=, in_tracers=(Traced,), out_tracer_refs=[], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1347,8,8]. 
let\n", + " b:f32[] = reduce_sum[axes=(0, 1, 2)] a\n", + " c:f32[] = div b 86208.0\n", + " in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'resource_env': None, 'donated_invars': (False,), 'name': '_mean', 'keep_unused': False, 'inline': True}, effects=set(), source_info=, ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={})\n" + ] + }, + { + "data": { + "text/plain": [ + "Array(16745235., dtype=float32)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def vae_loss(model: VAE, x: jax.Array):\n", + " logits, mean, std = model(x)\n", + " kl_loss = jnp.mean(0.5 * jnp.mean(\n", + " -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))\n", + " reconstruction_loss = jnp.mean(\n", + " optax.sigmoid_binary_cross_entropy(logits, x)\n", + " )\n", + " print(\"kl loss\", kl_loss)\n", + " print(\"reconstruction loss\", reconstruction_loss)\n", + " return reconstruction_loss + 0.1 * kl_loss\n", + "\n", + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "train_step(model, optimizer, images_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mawnZtF8wxu9" + }, + "source": [ + "But here rather than printing the value, we're getting some kind of `Traced` object. 
You'll encounter this frequently when inspecting the progress of JAX programs: tracers are the mechanism that JAX uses to implement transformations like `jit` and `grad`, and you can read more about them in [JAX Key Concepts: Tracing](https://jax.readthedocs.io/en/latest/key-concepts.html#tracing).\n", + "\n", + "For our purposes, the workaround is to use another tool from the [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html#interactive-inspection-with-jax-debug) link above: namely `jax.debug.print`, which lets us print runtime values even when they're traced:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "wziDzgdTuloK", + "outputId": "5d640a4e-306b-4b93-9b99-725ee6e4baa1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kl_loss: 167451888.0\n", + "reconstruction_loss: 44.51668167114258\n", + "kl_loss: 21651530.0\n", + "reconstruction_loss: 6.270397186279297\n", + "kl_loss: 4448844.5\n", + "reconstruction_loss: -14.727174758911133\n", + "kl_loss: 1285240.625\n" + ] + } + ], + "source": [ + "def vae_loss(model: VAE, x: jax.Array):\n", + " logits, mean, std = model(x)\n", + "\n", + " kl_loss = jnp.mean(0.5 * jnp.mean(\n", + " -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))\n", + " reconstruction_loss = jnp.mean(\n", + " optax.sigmoid_binary_cross_entropy(logits, x)\n", + " )\n", + " jax.debug.print(\"kl_loss: {}\", kl_loss)\n", + " jax.debug.print(\"reconstruction_loss: {}\", reconstruction_loss)\n", + " return reconstruction_loss + 0.1 * kl_loss\n", + "\n", + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "\n", + "for i in range(5):\n", + " train_step(model, optimizer, images_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wObaYRxF1-qy" + }, + "source": [ + "Let's 
iterate a few hundred more times (we'll use the IPython `%%capture` magic to avoid printing all the output on the first several hundred iterations) and then do one more run to print these intermediate values:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "pBLiDfRX2OOu" + }, + "outputs": [], + "source": [ + "%%capture\n", + "for i in range(300):\n", + " train_step(model, optimizer, images_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "O87gxdxGP3uZ", + "outputId": "a4dc57cd-56e0-4597-fbfd-9572e4c4ef7a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kl_loss: 2462.782470703125\n", + "reconstruction_loss: -8067.7255859375\n" + ] + } + ], + "source": [ + "loss = train_step(model, optimizer, images_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FXHfa0942apE" + }, + "source": [ + "We see that the large negative value is coming from the `reconstruction_loss` term. Let's return to this and look at what it's actually doing:\n", + "```python\n", + "reconstruction_loss = jnp.mean(\n", + " optax.sigmoid_binary_cross_entropy(logits, x)\n", + ")\n", + "```\n", + "This is a binary cross entropy described at [`optax.sigmoid_binary_cross_entropy`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.sigmoid_binary_cross_entropy). From the documentation, the first input should be a logit, and the second input is assumed to be a binary label (i.e. a ``0`` or ``1``) – but in our case `x` is associated with `images_train`, which is not a binary label!" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "seLRa2qE3wd3", + "outputId": "8b6ba25f-2edb-44c7-c5a9-bca2f334b5eb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0. 3. 13. 16. 9. 0. 0. 0.]\n", + " [ 0. 10. 15. 13. 15. 2. 0. 0.]\n", + " [ 0. 15. 4. 4. 16. 1. 0. 0.]\n", + " [ 0. 0. 0. 5. 16. 2. 0. 
0.]\n", + " [ 0. 0. 1. 14. 13. 0. 0. 0.]\n", + " [ 0. 0. 10. 16. 5. 0. 0. 0.]\n", + " [ 0. 4. 16. 13. 8. 10. 9. 1.]\n", + " [ 0. 2. 16. 16. 14. 12. 9. 1.]]\n" + ] + } + ], + "source": [ + "print(images_train[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KjF0ys3c30w8" + }, + "source": [ + "This is likely the source of our issue: we forgot to normalize the input images to the range ``(0, 1)``!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jkFIqZaTXRc5" + }, + "source": [ + "Let's fix this by binarizing our inputs, and then run the training loop again (redefining our loss function to remove the debug statements):" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "9Og1-tIw4BNu", + "outputId": "eebd9fd0-2b3b-43ce-9af8-729544a00aca" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 loss: 0.7710005640983582\n", + "Epoch 500 loss: 0.3110124468803406\n", + "Epoch 1000 loss: 0.2782602906227112\n", + "Epoch 1500 loss: 0.26861754059791565\n", + "Epoch 2000 loss: 0.26275068521499634\n" + ] + } + ], + "source": [ + "images_normed = (digits.images / 16) > 0.5\n", + "splits = train_test_split(images_normed, random_state=0)\n", + "images_train, images_test = map(jnp.asarray, splits)\n", + "\n", + "def vae_loss(model: VAE, x: jax.Array):\n", + " logits, mean, std = model(x)\n", + "\n", + " kl_loss = jnp.mean(0.5 * jnp.mean(\n", + " -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))\n", + " reconstruction_loss = jnp.mean(\n", + " optax.sigmoid_binary_cross_entropy(logits, x)\n", + " )\n", + " return reconstruction_loss + 0.1 * kl_loss\n", + "\n", + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "\n", + "for epoch in range(2001):\n", + " loss = train_step(model, optimizer, images_train)\n", + " if 
epoch % 500 == 0:\n", + " print(f'Epoch {epoch} loss: {loss}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4HD91gWfyJuJ" + }, + "source": [ + "The loss values are now behaving, and it looks like we've successfully debugged the initial NaN problem, which was not in our model but rather in our input data." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vA6wSi1k5GuZ" + }, + "source": [ + "## Exploring the VAE model results\n", + "\n", + "Now that we have a trained model, let's explore what it can be used for.\n", + "First, let's pass our test data through the model to output the result of the associated latent space representation for each input. We pass the `logits` through a `sigmoid` function to recover predicted images in the input space:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "fBzJyliAPCGc" + }, + "outputs": [], + "source": [ + "logits, mean, std = model(images_test)\n", + "images_pred = jax.nn.sigmoid(logits)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qiJy39iDPTVS" + }, + "source": [ + "Let's visualize several of these inputs and outputs:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "dRFxkKInn_gx", + "outputId": "4452517c-ac68-479d-a452-c282e4092291" + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "fig, ax = plt.subplots(2, 10, figsize=(6, 1.5),\n", + " subplot_kw={'xticks':[], 'yticks':[]},\n", + " gridspec_kw=dict(hspace=0.1, wspace=0.1))\n", + "for i in range(10):\n", + " ax[0, i].imshow(images_test[i], cmap='binary', interpolation='gaussian')\n", + " ax[1, i].imshow(images_pred[i], cmap='binary', interpolation='gaussian')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eaM-A6rNQWFz" + }, + "source": [ + "The top row here are the input images, and the bottom row are what the model \"thinks\" these images look like, given their latent space representation.\n", + "There's not perfect fidelity, but the essential features are recovered.\n", + "\n", + "We can go a step further and generate a set of new images from scratch by sampling randomly from the latent space. Let's generate 36 new digits this way:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "aNV9CNC1r2Gn", + "outputId": "5855d88a-81c8-4aee-d9ff-885ad78ad421" + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "\n", + "# generate new images by sampling the latent space\n", + "z = np.random.normal(scale=1.5, size=(36, model.latent_size))\n", + "logits = model.decoder(z).reshape(-1, 8, 8)\n", + "images_gen = nnx.sigmoid(logits)\n", + "\n", + "fig, ax = plt.subplots(6, 6, figsize=(4, 4),\n", + " subplot_kw={'xticks':[], 'yticks':[]},\n", + " gridspec_kw=dict(hspace=0.1, wspace=0.1))\n", + "for i in range(36):\n", + " ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4oCtm6V3TsQJ" + }, + "source": [ + "Another possibility here is to use our latent model to interpolate between two entries in the training set through the latent model space.\n", + "Here's an interpolation between a digit 9 and a digit 3:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "8iJ9f60VUBwY", + "outputId": "ce34c9a4-3d47-43b4-ec1a-c5d47db0b394" + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "z, _, _ = model.encoder(images_train.reshape(-1, 64))\n", + "zrange = jnp.linspace(z[2], z[9], 10)\n", + "\n", + "logits = model.decoder(zrange).reshape(-1, 8, 8)\n", + "images_gen = nnx.sigmoid(logits)\n", + "\n", + "fig, ax = plt.subplots(1, 10, figsize=(8, 1),\n", + " subplot_kw={'xticks':[], 'yticks':[]},\n", + " gridspec_kw=dict(hspace=0.1, wspace=0.1))\n", + "for i in range(10):\n", + " ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WHAaq0ryS7H5" + }, + "source": [ + "## Summary\n", + "\n", + "This tutorial offered a first example defining and training a generative model, in this case a simplified Variational Autoencoder (VAE).\n", + "Along the way we explored some approaches to debugging JAX programs, in particular the `jax.debug_nans` setting and the `jax.debug.print` function. Both of these are explained further in JAX's [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html) doc." 
+ ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/digits_vae.md b/docs/digits_vae.md new file mode 100644 index 0000000..83af38f --- /dev/null +++ b/docs/digits_vae.md @@ -0,0 +1,514 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.15.2 +kernelspec: + display_name: Python 3 + name: python3 +--- + ++++ {"id": "47OmRSTR1dJU"} + +# Debugging in JAX: a Variational autoencoder (VAE) model + +In [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) we built a simple neural network for classification of handwritten digits, and covered some of the key features of JAX, including its NumPy-style interface in the `jax.numpy`, as well as its transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`. + +This tutorial will explore a slightly more involved model: a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the same simple digits data. Along the way, we'll learn a bit more about how JAX's JIT compilation actually works, and what this means for debugging JAX programs. 
+ ++++ {"id": "k19povzxp7hS"} + +## Loading the digits + +As before, we'll use the small, self-contained [scikit-learn digits dataset](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) for ease of experimentation: + +```{code-cell} +:id: aIwDAfS6PtFh +:outputId: 4950f17a-7c47-4a83-cbcd-e206d07cca64 + +from sklearn.datasets import load_digits +from sklearn.model_selection import train_test_split +import jax.numpy as jnp + +digits = load_digits() + +splits = train_test_split(digits.images, random_state=0) + +images_train, images_test = map(jnp.asarray, splits) + +print(f"{images_train.shape=}") +print(f"{images_test.shape=}") +``` + ++++ {"id": "2_Q16JRyrW7V"} + +The dataset comprises 1797 images, each represented by an 8x8 pixel grid. +To see a visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html#loading-the-data) in the previous tutorial. + ++++ {"id": "Z9TPYqipPyBp"} + +## Defining the Variational Autoencoder + +Previously we defined a simple feedforward network trained for classification with an architecture that looked roughly like this: + +```{code-cell} +:id: HNlg-ydpr5yH + +import jax +import jax.numpy as jnp +from flax import nnx + +class SimpleNN(nnx.Module): + + def __init__(self, n_features=64, n_hidden=100, n_targets=10, *, rngs: nnx.Rngs): + self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs) + self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs) + self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = nnx.selu(self.layer1(x)) + x = nnx.selu(self.layer2(x)) + return self.layer3(x) +``` + ++++ {"id": "3DwMBvoksOmG"} + +In this network we had one output per class, and the loss function was designed such that once trained, the output corresponding to the correct class would return the strongest signal, thus predicting the correct label in upwards of 95% of cases. 
+ +In this VAE example we use similar building blocks to instead output a small probabilistic model representing the data. While classic `VAE` is generally based on convolutional layers, we use linear layers for simplicity. The sub-network that produces this probabilistic encoding is our `Encoder`: + +```{code-cell} +:id: Hj7mtR5vmcGr + +class Encoder(nnx.Module): + def __init__(self, input_size: int, intermediate_size: int, output_size: int, + *, rngs: nnx.Rngs): + self.rngs = rngs + self.linear = nnx.Linear(input_size, intermediate_size, rngs=rngs) + self.linear_mean = nnx.Linear(intermediate_size, output_size, rngs=rngs) + self.linear_std = nnx.Linear(intermediate_size, output_size, rngs=rngs) + + def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]: + x = self.linear(x) + x = jax.nn.relu(x) + + mean = self.linear_mean(x) + std = jnp.exp(self.linear_std(x)) + + key = self.rngs.noise() + z = mean + std * jax.random.normal(key, mean.shape) + return z, mean, std +``` + ++++ {"id": "VwfCWbiRmkG9"} + +The idea here is that `mean` and `std` define a low-dimensional probability distribution over a latent space, and that `z` is a draw from this latent space that represents the training data. + +In order to ensure that this latent distribution faithfully represents the actual data, we define a `Decoder` that maps back to the input space: + +```{code-cell} +:id: FoAmZuVDnjgn + +class Decoder(nnx.Module): + def __init__(self, input_size: int, intermediate_size: int, output_size: int, + *, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(input_size, intermediate_size, rngs=rngs) + self.linear2 = nnx.Linear(intermediate_size, output_size, rngs=rngs) + + def __call__(self, z: jax.Array) -> jax.Array: + z = self.linear1(z) + z = jax.nn.relu(z) + logits = self.linear2(z) + return logits +``` + ++++ {"id": "0QaT-KY6npSc"} + +Now the full VAE model is a single network built from the encoder and decoder. 
+It returns both the reconstructed image and the internal latent space model: + +```{code-cell} +:id: Myo2MdxXnzlT + +class VAE(nnx.Module): + def __init__( + self, + image_shape: tuple[int, int], + hidden_size: int, + latent_size: int, + *, + rngs: nnx.Rngs + ): + self.image_shape = image_shape + self.latent_size = latent_size + input_size = image_shape[0] * image_shape[1] + self.encoder = Encoder(input_size, hidden_size, latent_size, rngs=rngs) + self.decoder = Decoder(latent_size, hidden_size, input_size, rngs=rngs) + + def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]: + x = jax.vmap(jax.numpy.ravel)(x) # flatten + z, mean, std = self.encoder(x) + logits = self.decoder(z) + logits = jnp.reshape(logits, (-1, *self.image_shape)) + return logits, mean, std +``` + ++++ {"id": "xIm9Yi5YoIxN"} + +Next is the loss function – there are two components to the model that we want to ensure: + +1. the `logits` output faithfully reconstructs the input image. +2. the model represented by `mean` and `std` faithfully represents the "true" latent distribution. 
+ +VAE uses a loss function based on the [Evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound) to quantify these two goals in a single loss value: + +```{code-cell} +:id: bMpxj8-Wsvui + +def vae_loss(model: VAE, x: jax.Array): + logits, mean, std = model(x) + kl_loss = jnp.mean(0.5 * jnp.mean( + -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1)) + reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) + ) + return reconstruction_loss + 0.1 * kl_loss +``` + ++++ {"id": "RaT0ELpvqo2W"} + +Now all that's left is to define the model and optimizer, and run the training loop: + +```{code-cell} +:id: JPgoHL5rpKXd +:outputId: e7626646-7c14-42dc-c18f-774adff1306e + +import optax + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) + +@nnx.jit +def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array): + loss, grads = nnx.value_and_grad(vae_loss)(model, x) + optimizer.update(grads) + return loss + +for epoch in range(2001): + loss = train_step(model, optimizer, images_train) + if epoch % 500 == 0: + print(f'Epoch {epoch} loss: {loss}') +``` + ++++ {"id": "f8m_uoL9q47M"} + +And here we see that something has gone wrong: our loss value has become NaN after some number of iterations. + ++++ {"id": "SBS1mmxwrS25"} + +## Debugging NaNs in JAX +Despite our best efforts, our model is producing NaNs. What can we do? + +JAX offers a number of debugging approaches for situations like this, outlined in the JAX docs at [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html). 
In this case we can use the [`debug_nans`](https://jax.readthedocs.io/en/latest/debugging/flags.html#jax-debug-nans-configuration-option-and-context-manager) configuration to see where the NaN value is arising: + +```{code-cell} +:id: JE7OYoZ4rRQ8 +:outputId: 2b77a9c1-aaa4-418a-da71-013d8f885340 +:tags: [raises-exception] + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) + +with jax.debug_nans(True): + for epoch in range(2001): + train_step(model, optimizer, images_train) +``` + ++++ {"id": "thw9URJmrROj"} + +The output here is complicated, because the function we're evaluating is complicated, but the key to deciphering this traceback is to look for the places where the traceback touches your implementation. In particular here, it indicates that NaN values arise during the gradient update: +``` + in train_step() + 14 loss, grads = nnx.value_and_grad(vae_loss)(model, x) +---> 15 optimizer.update(grads) + 16 return loss +``` +and further down from this, the details of the gradient update step where the NaN is arising: +``` +/usr/local/lib/python3.10/dist-packages/optax/tree_utils/_tree_math.py in () + 280 lambda g, t: ( +--> 281 (1 - decay) * (g**order) + decay * t if g is not None else None + 282 ), +``` +This tells us that the gradient is returning values that lead to `NaN` during the model update: typically this would come about when the gradient itself is for some reason diverging. + +A diverging gradient means that something with our loss function may be amiss. 
We previously saw `loss=NaN` at iteration 500; let's print the progress up to this point: + +```{code-cell} +:id: KJ1gAh8uurVX +:outputId: 21564ba8-2ea4-42fb-bf9a-0291a556004c + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) + +for epoch in range(301): + loss = train_step(model, optimizer, images_train) + if epoch % 50 == 0: + print(f'Epoch {epoch} loss: {loss}') +``` + ++++ {"id": "Jk_wdvqpurTG"} + +It looks like our loss value is decreasing toward negative infinity until the point where the values are no longer well-represented by floating point math. + +At this point, we may wish to inspect the values within the loss function itself to see where the diverging loss might be coming from. +In typical Python programs you can do this by inserting either a `print` statement or a `breakpoint` in the loss function; it might look something like this: + +```{code-cell} +:id: 9Klkz7qHwWia +:outputId: e727cf56-0cc1-4256-9567-5fe96396ce45 + +def vae_loss(model: VAE, x: jax.Array): + logits, mean, std = model(x) + kl_loss = jnp.mean(0.5 * jnp.mean( + -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1)) + reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) + ) + print("kl loss", kl_loss) + print("reconstruction loss", reconstruction_loss) + return reconstruction_loss + 0.1 * kl_loss + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) +train_step(model, optimizer, images_train) +``` + ++++ {"id": "mawnZtF8wxu9"} + +But here rather than printing the value, we're getting some kind of `Traced` object. 
You'll encounter this frequently when inspecting the progress of JAX programs: tracers are the mechanism that JAX uses to implement transformations like `jit` and `grad`, and you can read more about them in [JAX Key Concepts: Tracing](https://jax.readthedocs.io/en/latest/key-concepts.html#tracing). + +For our purposes, the workaround is to use another tool from the [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html#interactive-inspection-with-jax-debug) link above: namely `jax.debug.print`, which lets us print runtime values even when they're traced: + +```{code-cell} +:id: wziDzgdTuloK +:outputId: 5d640a4e-306b-4b93-9b99-725ee6e4baa1 + +def vae_loss(model: VAE, x: jax.Array): + logits, mean, std = model(x) + + kl_loss = jnp.mean(0.5 * jnp.mean( + -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1)) + reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) + ) + jax.debug.print("kl_loss: {}", kl_loss) + jax.debug.print("reconstruction_loss: {}", reconstruction_loss) + return reconstruction_loss + 0.1 * kl_loss + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) + +for i in range(5): + train_step(model, optimizer, images_train) +``` + ++++ {"id": "wObaYRxF1-qy"} + +Let's iterate a few hundred more times (we'll use the IPython `%%capture` magic to avoid printing all the output on the first several hundred iterations) and then do one more run to print these intermediate values: + +```{code-cell} +:id: pBLiDfRX2OOu + +%%capture +for i in range(300): + train_step(model, optimizer, images_train) +``` + +```{code-cell} +:id: O87gxdxGP3uZ +:outputId: a4dc57cd-56e0-4597-fbfd-9572e4c4ef7a + +loss = train_step(model, optimizer, images_train) +``` + ++++ {"id": "FXHfa0942apE"} + +We see that the large negative value is coming from the `reconstruction_loss` term. 
Let's return to this and look at what it's actually doing: +```python +reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) +) +``` +This is a binary cross entropy described at [`optax.sigmoid_binary_cross_entropy`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.sigmoid_binary_cross_entropy). From the documentation, the first input should be a logit, and the second input is assumed to be a binary label (i.e. a ``0`` or ``1``) – but in our case `x` is associated with `images_train`, which is not a binary label! + +```{code-cell} +:id: seLRa2qE3wd3 +:outputId: 8b6ba25f-2edb-44c7-c5a9-bca2f334b5eb + +print(images_train[0]) +``` + ++++ {"id": "KjF0ys3c30w8"} + +This is likely the source of our issue: we forgot to normalize the input images to the range ``(0, 1)``! + ++++ {"id": "jkFIqZaTXRc5"} + +Let's fix this by binarizing our inputs, and then run the training loop again (redefining our loss function to remove the debug statements): + +```{code-cell} +:id: 9Og1-tIw4BNu +:outputId: eebd9fd0-2b3b-43ce-9af8-729544a00aca + +images_normed = (digits.images / 16) > 0.5 +splits = train_test_split(images_normed, random_state=0) +images_train, images_test = map(jnp.asarray, splits) + +def vae_loss(model: VAE, x: jax.Array): + logits, mean, std = model(x) + + kl_loss = jnp.mean(0.5 * jnp.mean( + -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1)) + reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) + ) + return reconstruction_loss + 0.1 * kl_loss + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) + +for epoch in range(2001): + loss = train_step(model, optimizer, images_train) + if epoch % 500 == 0: + print(f'Epoch {epoch} loss: {loss}') +``` + ++++ {"id": "4HD91gWfyJuJ"} + +The loss values are now behaving, and it looks like we've successfully debugged the initial NaN problem, which was 
not in our model but rather in our input data. + ++++ {"id": "vA6wSi1k5GuZ"} + +## Exploring the VAE model results + +Now that we have a trained model, let's explore what it can be used for. +First, let's pass our test data through the model to output the result of the associated latent space representation for each input. We pass the `logits` through a `sigmoid` function to recover predicted images in the input space: + +```{code-cell} +:id: fBzJyliAPCGc + +logits, mean, std = model(images_test) +images_pred = jax.nn.sigmoid(logits) +``` + ++++ {"id": "qiJy39iDPTVS"} + +Let's visualize several of these inputs and outputs: + +```{code-cell} +:id: dRFxkKInn_gx +:outputId: 4452517c-ac68-479d-a452-c282e4092291 + +import matplotlib.pyplot as plt + +fig, ax = plt.subplots(2, 10, figsize=(6, 1.5), + subplot_kw={'xticks':[], 'yticks':[]}, + gridspec_kw=dict(hspace=0.1, wspace=0.1)) +for i in range(10): + ax[0, i].imshow(images_test[i], cmap='binary', interpolation='gaussian') + ax[1, i].imshow(images_pred[i], cmap='binary', interpolation='gaussian') +``` + ++++ {"id": "eaM-A6rNQWFz"} + +The top row here are the input images, and the bottom row are what the model "thinks" these images look like, given their latent space representation. +There's not perfect fidelity, but the essential features are recovered. + +We can go a step further and generate a set of new images from scratch by sampling randomly from the latent space. 
Let's generate 36 new digits this way: + +```{code-cell} +:id: aNV9CNC1r2Gn +:outputId: 5855d88a-81c8-4aee-d9ff-885ad78ad421 + +import numpy as np + +# generate new images by sampling the latent space +z = np.random.normal(scale=1.5, size=(36, model.latent_size)) +logits = model.decoder(z).reshape(-1, 8, 8) +images_gen = nnx.sigmoid(logits) + +fig, ax = plt.subplots(6, 6, figsize=(4, 4), + subplot_kw={'xticks':[], 'yticks':[]}, + gridspec_kw=dict(hspace=0.1, wspace=0.1)) +for i in range(36): + ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian') +``` + ++++ {"id": "4oCtm6V3TsQJ"} + +Another possibility here is to use our latent model to interpolate between two entries in the training set through the latent model space. +Here's an interpolation between a digit 9 and a digit 3: + +```{code-cell} +:id: 8iJ9f60VUBwY +:outputId: ce34c9a4-3d47-43b4-ec1a-c5d47db0b394 + +z, _, _ = model.encoder(images_train.reshape(-1, 64)) +zrange = jnp.linspace(z[2], z[9], 10) + +logits = model.decoder(zrange).reshape(-1, 8, 8) +images_gen = nnx.sigmoid(logits) + +fig, ax = plt.subplots(1, 10, figsize=(8, 1), + subplot_kw={'xticks':[], 'yticks':[]}, + gridspec_kw=dict(hspace=0.1, wspace=0.1)) +for i in range(10): + ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian') +``` + ++++ {"id": "WHAaq0ryS7H5"} + +## Summary + +This tutorial offered a first example defining and training a generative model, in this case a simplified Variational Autoencoder (VAE). +Along the way we explored some approaches to debugging JAX programs, in particular the `jax.debug_nans` setting and the `jax.debug.print` function. Both of these are explained further in JAX's [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html) doc. 
diff --git a/docs/tutorials.md b/docs/tutorials.md index 19473d0..03f0add 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -8,6 +8,7 @@ The following tutorials are meant as an intro to the full stack: :maxdepth: 2 getting_started_with_jax_for_AI +digits_vae ``` Once you've gone through this content, you can refer to package-specific diff --git a/pyproject.toml b/pyproject.toml index d217601..16cec95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,3 +63,7 @@ requires = [ [tool.setuptools] packages = ["jax_ai_stack"] include-package-data = false + +[tool.ruff.lint.per-file-ignores] +# F811: Redefinition of unused name. +"docs/digits_vae.ipynb" = ["F811"]