From a50d715b8f54d55216028f1d7b038db209915984 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 14 Oct 2024 09:37:51 -0700 Subject: [PATCH] Add digits VAE tutorial --- docs/conf.py | 4 + docs/digits_vae.ipynb | 986 ++++++++++++++++++++++++++++++++++++++++++ docs/digits_vae.md | 514 ++++++++++++++++++++++ docs/tutorials.md | 1 + pyproject.toml | 4 + 5 files changed, 1509 insertions(+) create mode 100644 docs/digits_vae.ipynb create mode 100644 docs/digits_vae.md diff --git a/docs/conf.py b/docs/conf.py index f5b6ca7..17ea579 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -34,9 +34,13 @@ 'build/html', 'build/jupyter_execute', # Exclude markdown sources for notebooks: + 'digits_vae.md', 'getting_started_with_jax_for_AI.md', ] +suppress_warnings = [ + 'misc.highlighting_failure', # Suppress warning in exception in digits_vae +] # -- Options for myst ---------------------------------------------- myst_heading_anchors = 3 # auto-generate 3 levels of heading anchors diff --git a/docs/digits_vae.ipynb b/docs/digits_vae.ipynb new file mode 100644 index 0000000..c964bbb --- /dev/null +++ b/docs/digits_vae.ipynb @@ -0,0 +1,986 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "47OmRSTR1dJU" + }, + "source": [ + "# Debugging in JAX: a Variational autoencoder (VAE) model\n", + "\n", + "In [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) we built a simple neural network for classification of handwritten digits, and covered some of the key features of JAX, including its NumPy-style interface in the `jax.numpy`, as well as its transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`.\n", + "\n", + "This tutorial will explore a slightly more involved model: a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the same simple digits data. Along the way, we'll learn a bit more about how JAX's JIT compilation actually works, and what this means for debugging JAX programs." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k19povzxp7hS" + }, + "source": [ + "## Loading the digits\n", + "\n", + "As before, we'll use the small, self-contained [scikit-learn digits dataset](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) for ease of experimentation:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "aIwDAfS6PtFh", + "outputId": "4950f17a-7c47-4a83-cbcd-e206d07cca64" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "images_train.shape=(1347, 8, 8)\n", + "images_test.shape=(450, 8, 8)\n" + ] + } + ], + "source": [ + "from sklearn.datasets import load_digits\n", + "from sklearn.model_selection import train_test_split\n", + "import jax.numpy as jnp\n", + "\n", + "digits = load_digits()\n", + "\n", + "splits = train_test_split(digits.images, random_state=0)\n", + "\n", + "images_train, images_test = map(jnp.asarray, splits)\n", + "\n", + "print(f\"{images_train.shape=}\")\n", + "print(f\"{images_test.shape=}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2_Q16JRyrW7V" + }, + "source": [ + "The dataset comprises 1800 images, each represented by an 8x8 pixel grid.\n", + "To see a visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html#loading-the-data) in the previous tutorial." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z9TPYqipPyBp" + }, + "source": [ + "## Defining the Variational Autoencoder\n", + "\n", + "Previously we defined a simple feedforward network trained for classification with an architecture that looked roughly like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "HNlg-ydpr5yH" + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from flax import nnx\n", + "\n", + "class SimpleNN(nnx.Module):\n", + "\n", + " def __init__(self, n_features=64, n_hidden=100, n_targets=10, *, rngs: nnx.Rngs):\n", + " self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)\n", + " self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)\n", + " self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)\n", + "\n", + " def __call__(self, x: jax.Array) -> jax.Array:\n", + " x = nnx.selu(self.layer1(x))\n", + " x = nnx.selu(self.layer2(x))\n", + " return self.layer3(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3DwMBvoksOmG" + }, + "source": [ + "In this network we had one output per class, and the loss function was designed such that once trained, the output corresponding to the correct class would return the strongest signal, thus predicting the correct label in upwards of 95% of cases.\n", + "\n", + "In this VAE example we use similar building blocks to instead output a small probabilisitic model representing the data. While classic `VAE` is generally based on convolutional layers, we use linear layers for simplicity. The sub-network that produces this probabilistic encoding is our `Encoder`:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "Hj7mtR5vmcGr" + }, + "outputs": [], + "source": [ + "class Encoder(nnx.Module):\n", + " def __init__(self, input_size: int, intermediate_size: int, output_size: int,\n", + " *, rngs: nnx.Rngs):\n", + " self.rngs = rngs\n", + " self.linear = nnx.Linear(input_size, intermediate_size, rngs=rngs)\n", + " self.linear_mean = nnx.Linear(intermediate_size, output_size, rngs=rngs)\n", + " self.linear_std = nnx.Linear(intermediate_size, output_size, rngs=rngs)\n", + "\n", + " def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:\n", + " x = self.linear(x)\n", + " x = jax.nn.relu(x)\n", + "\n", + " mean = self.linear_mean(x)\n", + " std = jnp.exp(self.linear_std(x))\n", + "\n", + " key = self.rngs.noise()\n", + " z = mean + std * jax.random.normal(key, mean.shape)\n", + " return z, mean, std" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VwfCWbiRmkG9" + }, + "source": [ + "The idea here is that `mean` and `std` define a low-dimensional probability distribution over a latent space, and that `z` is a draw from this latent space that represents the training data.\n", + "\n", + "In order to ensure that this latent distribution faithfully represents the actual data, we define a `Decoder` that maps back to the input space:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "FoAmZuVDnjgn" + }, + "outputs": [], + "source": [ + "class Decoder(nnx.Module):\n", + " def __init__(self, input_size: int, intermediate_size: int, output_size: int,\n", + " *, rngs: nnx.Rngs):\n", + " self.linear1 = nnx.Linear(input_size, intermediate_size, rngs=rngs)\n", + " self.linear2 = nnx.Linear(intermediate_size, output_size, rngs=rngs)\n", + "\n", + " def __call__(self, z: jax.Array) -> jax.Array:\n", + " z = self.linear1(z)\n", + " z = jax.nn.relu(z)\n", + " logits = self.linear2(z)\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0QaT-KY6npSc" + }, + "source": [ + "Now the full VAE model is a single network built from the encoder and decoder.\n", + "It returns both the reconstructed image and then internal latent space model:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "Myo2MdxXnzlT" + }, + "outputs": [], + "source": [ + "class VAE(nnx.Module):\n", + " def __init__(\n", + " self,\n", + " image_shape: tuple[int, int],\n", + " hidden_size: int,\n", + " latent_size: int,\n", + " *,\n", + " rngs: nnx.Rngs\n", + " ):\n", + " self.image_shape = image_shape\n", + " self.latent_size = latent_size\n", + " input_size = image_shape[0] * image_shape[1]\n", + " self.encoder = Encoder(input_size, hidden_size, latent_size, rngs=rngs)\n", + " self.decoder = Decoder(latent_size, hidden_size, input_size, rngs=rngs)\n", + "\n", + " def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:\n", + " x = jax.vmap(jax.numpy.ravel)(x) # flatten\n", + " z, mean, std = self.encoder(x)\n", + " logits = self.decoder(z)\n", + " logits = jnp.reshape(logits, (-1, *self.image_shape))\n", + " return logits, mean, std" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xIm9Yi5YoIxN" + }, + "source": [ + "Next is the loss function – there are two components to the model that we want to ensure:\n", + "\n", + "1. the `logits` output faithfully reconstruct the input image.\n", + "2. the model represented by `mean` and `std` faithfully represents the \"true\" latent distribution.\n", + "\n", + "VAE uses a loss function based on the [Evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound) to quantify theset two goals in a single loss value:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "bMpxj8-Wsvui" + }, + "outputs": [], + "source": [ + "def vae_loss(model: VAE, x: jax.Array):\n", + " logits, mean, std = model(x)\n", + " kl_loss = jnp.mean(0.5 * jnp.mean(\n", + " -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))\n", + " reconstruction_loss = jnp.mean(\n", + " optax.sigmoid_binary_cross_entropy(logits, x)\n", + " )\n", + " return reconstruction_loss + 0.1 * kl_loss" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RaT0ELpvqo2W" + }, + "source": [ + "Now all that's left is to define the model and optimizer, and run the training loop:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "JPgoHL5rpKXd", + "outputId": "e7626646-7c14-42dc-c18f-774adff1306e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 loss: 16745235.0\n", + "Epoch 500 loss: nan\n", + "Epoch 1000 loss: nan\n", + "Epoch 1500 loss: nan\n", + "Epoch 2000 loss: nan\n" + ] + } + ], + "source": [ + "import optax\n", + "\n", + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "\n", + "@nnx.jit\n", + "def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array):\n", + " loss, grads = nnx.value_and_grad(vae_loss)(model, x)\n", + " optimizer.update(grads)\n", + " return loss\n", + "\n", + "for epoch in range(2001):\n", + " loss = train_step(model, optimizer, images_train)\n", + " if epoch % 500 == 0:\n", + " print(f'Epoch {epoch} loss: {loss}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f8m_uoL9q47M" + }, + "source": [ + "And here we see that something has gone wrong: our loss value has become NaN after some number of iterations." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SBS1mmxwrS25" + }, + "source": [ + "## Debugging NaNs in JAX\n", + "Despite our best efforts, our model is producing NaNs. What can we do?\n", + "\n", + "JAX offers a number of debugging approaches for situations like this, outlined in the JAX docs at [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html). In this case we can use the [`debug_nans`](https://jax.readthedocs.io/en/latest/debugging/flags.html#jax-debug-nans-configuration-option-and-context-manager) configuration to see where the NaN value is arising:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "JE7OYoZ4rRQ8", + "outputId": "2b77a9c1-aaa4-418a-da71-013d8f885340", + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.\n", + "Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.\n" + ] + }, + { + "ename": "FloatingPointError", + "evalue": "invalid value (nan) encountered in jit(dot_general). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. \n\nIt may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. \n\nIf you see this error, consider opening a bug report at https://github.com/google/jax.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_nan_check_posthook\u001b[0;34m(fun, args, kwargs, output)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mdispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpjit_p\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36mcheck_special\u001b[0;34m(name, bufs)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 321\u001b[0;31m \u001b[0m_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 322\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36m_check_special\u001b[0;34m(name, dtype, buf)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_nans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 326\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"invalid value (nan) encountered in {name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 327\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_infs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFloatingPointError\u001b[0m: invalid value (nan) encountered in pjit", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36mcheck_special\u001b[0;34m(name, bufs)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 321\u001b[0;31m \u001b[0m_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 322\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36m_check_special\u001b[0;34m(name, dtype, buf)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_nans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 326\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"invalid value (nan) encountered in {name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 327\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_infs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFloatingPointError\u001b[0m: invalid value (nan) encountered in pjit", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 333\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 334\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 1290\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marrays\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mout_arrays\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1291\u001b[0;31m \u001b[0mdispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1292\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_arrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36mcheck_special\u001b[0;34m(name, bufs)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 321\u001b[0;31m \u001b[0m_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 322\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36m_check_special\u001b[0;34m(name, dtype, buf)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_nans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 326\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"invalid value (nan) encountered in {name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 327\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_infs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFloatingPointError\u001b[0m: invalid value (nan) encountered in jit(jit_fn)", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 333\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 334\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 1290\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0marrays\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mout_arrays\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1291\u001b[0;31m \u001b[0mdispatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1292\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_arrays\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36mcheck_special\u001b[0;34m(name, bufs)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mbufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 321\u001b[0;31m \u001b[0m_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 322\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py\u001b[0m in \u001b[0;36m_check_special\u001b[0;34m(name, dtype, buf)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_nans\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 326\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"invalid value (nan) encountered in {name}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 327\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_infs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misinf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFloatingPointError\u001b[0m: invalid value (nan) encountered in jit(dot_general)", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/usr/lib/python3.10/runpy.py\u001b[0m in \u001b[0;36m_run_module_as_main\u001b[0;34m()\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmod_spec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0morigin\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 196\u001b[0;31m return _run_code(code, main_globals, None,\n\u001b[0m\u001b[1;32m 197\u001b[0m \"__main__\", mod_spec)\n", + "\u001b[0;32m/usr/lib/python3.10/runpy.py\u001b[0m in \u001b[0;36m_run_code\u001b[0;34m()\u001b[0m\n\u001b[1;32m 85\u001b[0m __spec__ = mod_spec)\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0mexec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_globals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mrun_globals\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'__main__'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0mColabKernelApp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlaunch_instance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py\u001b[0m in \u001b[0;36mlaunch_instance\u001b[0;34m()\u001b[0m\n\u001b[1;32m 991\u001b[0m \u001b[0mapp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minitialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 992\u001b[0;31m \u001b[0mapp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 993\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py\u001b[0m in \u001b[0;36mstart\u001b[0;34m()\u001b[0m\n\u001b[1;32m 618\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 619\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mio_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 620\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py\u001b[0m in \u001b[0;36mstart\u001b[0;34m()\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 195\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masyncio_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_forever\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 196\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3.10/asyncio/base_events.py\u001b[0m in \u001b[0;36mrun_forever\u001b[0;34m()\u001b[0m\n\u001b[1;32m 602\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 603\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_once\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 604\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_stopping\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3.10/asyncio/base_events.py\u001b[0m in \u001b[0;36m_run_once\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1908\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1909\u001b[0;31m \u001b[0mhandle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1910\u001b[0m \u001b[0mhandle\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;31m# Needed to break cycles when an exception occurs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/lib/python3.10/asyncio/events.py\u001b[0m in \u001b[0;36m_run\u001b[0;34m()\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 80\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 81\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mSystemExit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mKeyboardInterrupt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 684\u001b[0m future.add_done_callback(\n\u001b[0;32m--> 685\u001b[0;31m \u001b[0;32mlambda\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_run_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunctools\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfuture\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 686\u001b[0m )\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/ioloop.py\u001b[0m in \u001b[0;36m_run_callback\u001b[0;34m()\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 738\u001b[0;31m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 739\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mret\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/gen.py\u001b[0m in \u001b[0;36minner\u001b[0;34m()\u001b[0m\n\u001b[1;32m 824\u001b[0m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;31m# noqa: F841\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 825\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mctx_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 826\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/gen.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m()\u001b[0m\n\u001b[1;32m 785\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 786\u001b[0;31m \u001b[0myielded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mprocess_one\u001b[0;34m()\u001b[0m\n\u001b[1;32m 360\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 361\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0mgen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmaybe_future\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 362\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/gen.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0myielded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctx_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mStopIteration\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mReturn\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mdispatch_shell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 260\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 261\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0mgen\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmaybe_future\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhandler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstream\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0midents\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 262\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/gen.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0myielded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctx_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mStopIteration\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mReturn\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/kernelbase.py\u001b[0m in \u001b[0;36mexecute_request\u001b[0;34m()\u001b[0m\n\u001b[1;32m 538\u001b[0m reply_content = yield gen.maybe_future(\n\u001b[0;32m--> 539\u001b[0;31m self.do_execute(\n\u001b[0m\u001b[1;32m 540\u001b[0m \u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msilent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstore_history\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/tornado/gen.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 234\u001b[0;31m \u001b[0myielded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctx_run\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 235\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mStopIteration\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mReturn\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py\u001b[0m in \u001b[0;36mdo_execute\u001b[0;34m()\u001b[0m\n\u001b[1;32m 301\u001b[0m \u001b[0;31m# letting shell dispatch to loop runners\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 302\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_cell\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstore_history\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstore_history\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msilent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msilent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 303\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py\u001b[0m in \u001b[0;36mrun_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_last_traceback\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 539\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mZMQInteractiveShell\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_cell\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 540\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2974\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2975\u001b[0;31m result = self._run_cell(\n\u001b[0m\u001b[1;32m 2976\u001b[0m \u001b[0mraw_cell\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstore_history\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msilent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshell_futures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcell_id\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36m_run_cell\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3029\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3030\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mrunner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcoro\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3031\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mBaseException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py\u001b[0m in \u001b[0;36m_pseudo_sync_runner\u001b[0;34m()\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m \u001b[0mcoro\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 79\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mexc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_cell_async\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3256\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3257\u001b[0;31m has_raised = await self.run_ast_nodes(code_ast.body, cell_name,\n\u001b[0m\u001b[1;32m 3258\u001b[0m interactivity=interactivity, compiler=compiler, result=result)\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_ast_nodes\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3472\u001b[0m \u001b[0masy\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3473\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;32mawait\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_code\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0masync_\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0masy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3474\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_code\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3553\u001b[0;31m \u001b[0mexec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcode_obj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_global_ns\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3554\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2001\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimages_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/graph.py\u001b[0m in \u001b[0;36mupdate_context_manager_wrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1043\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1044\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/transforms.py\u001b[0m in \u001b[0;36mjit_wrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0mgraphdef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_graph_nodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m out, output_state, output_graphdef = jitted_fn(\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/transforms.py\u001b[0m in \u001b[0;36mjit_fn\u001b[0;34m()\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 158\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_step\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mVAE\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrads\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnnx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue_and_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvae_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgrads\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/graph.py\u001b[0m in \u001b[0;36mupdate_context_manager_wrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1043\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1044\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/transforms.py\u001b[0m in \u001b[0;36mgrad_wrapper\u001b[0;34m()\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 568\u001b[0;31m out = transform(\n\u001b[0m\u001b[1;32m 569\u001b[0m \u001b[0mgrad_fn\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/transforms.py\u001b[0m in \u001b[0;36mgrad_fn\u001b[0;34m()\u001b[0m\n\u001b[1;32m 511\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 512\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 513\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mvae_loss\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvae_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mVAE\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m kl_loss = jnp.mean(0.5 * jnp.mean(\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m__call__\u001b[0;34m()\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# flatten\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36m__call__\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mmean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear_mean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mstd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear_std\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/nn/linear.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m()\u001b[0m\n\u001b[1;32m 380\u001b[0m )\n\u001b[0;32m--> 381\u001b[0;31m y = self.dot_general(\n\u001b[0m\u001b[1;32m 382\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m: FloatingPointError: invalid value (nan) encountered in jit(dot_general). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. \n\nIt may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. \n\nIf you see this error, consider opening a bug report at https://github.com/google/jax.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mFloatingPointError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdebug_nans\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2001\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mtrain_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mimages_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/graph.py\u001b[0m in \u001b[0;36mupdate_context_manager_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1041\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mupdate_context_manager_wrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1043\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1044\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1045\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mupdate_context_manager_wrapper\u001b[0m \u001b[0;31m# type: ignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/transforms.py\u001b[0m in \u001b[0;36mjit_wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 357\u001b[0m )\n\u001b[1;32m 358\u001b[0m \u001b[0mgraphdef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_graph_nodes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m out, output_state, output_graphdef = jitted_fn(\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0m_nnx_jit_static\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mJitStaticInputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgraphdef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_constrain_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_nan_check_posthook\u001b[0;34m(fun, args, kwargs, output)\u001b[0m\n\u001b[1;32m 118\u001b[0m print(\"Invalid nan value encountered in the output of a C++-jit/pmap \"\n\u001b[1;32m 119\u001b[0m \"function. Calling the de-optimized version.\")\n\u001b[0;32m--> 120\u001b[0;31m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cache_miss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;31m# probably won't return\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 121\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 122\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_update_debug_special_global\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + " \u001b[0;31m[... skipping hidden 24 frame]\u001b[0m\n", + "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/jax/_src/pjit.py\u001b[0m in \u001b[0;36m_pjit_call_impl_python\u001b[0;34m(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, *args)\u001b[0m\n\u001b[1;32m 1697\u001b[0m \u001b[0;34m\"If you see this error, consider opening a bug report at \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1698\u001b[0m \"https://github.com/google/jax.\")\n\u001b[0;32m-> 1699\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mFloatingPointError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1700\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1701\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFloatingPointError\u001b[0m: invalid value (nan) encountered in jit(dot_general). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. \n\nIt may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. \n\nIf you see this error, consider opening a bug report at https://github.com/google/jax." + ] + } + ], + "source": [ + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "\n", + "with jax.debug_nans(True):\n", + " for epoch in range(2001):\n", + " train_step(model, optimizer, images_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "thw9URJmrROj" + }, + "source": [ + "The output here is complicated, because the function we're evaluating is complicated, but the key to deciphering this traceback is to look for the places where the traceback touches your implementation. In particular here, it indicates that NaN values arise during the gradient update:\n", + "```\n", + " in train_step()\n", + " 14 loss, grads = nnx.value_and_grad(vae_loss)(model, x)\n", + "---> 15 optimizer.update(grads)\n", + " 16 return loss\n", + "```\n", + "and further down from this, the details of the gradient update step where the NaN is arising:\n", + "```\n", + "/usr/local/lib/python3.10/dist-packages/optax/tree_utils/_tree_math.py in ()\n", + " 280 lambda g, t: (\n", + "--> 281 (1 - decay) * (g**order) + decay * t if g is not None else None\n", + " 282 ),\n", + "```\n", + "This tells us that the gradient is returning values that lead to `NaN` during the model update: typically this would come about when the gradient itself is for some reason diverging.\n", + "\n", + "A diverging gradient means that something with our loss function may be amiss. We previously saw `loss=NaN` at iteration 500; let's print the progress up to this point:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "KJ1gAh8uurVX", + "outputId": "21564ba8-2ea4-42fb-bf9a-0291a556004c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 loss: 16745235.0\n", + "Epoch 50 loss: 19.595727920532227\n", + "Epoch 100 loss: -13.440512657165527\n", + "Epoch 150 loss: -145.24871826171875\n", + "Epoch 200 loss: -683.0828247070312\n", + "Epoch 250 loss: -2291.444091796875\n", + "Epoch 300 loss: -6880.775390625\n" + ] + } + ], + "source": [ + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "\n", + "for epoch in range(301):\n", + " loss = train_step(model, optimizer, images_train)\n", + " if epoch % 50 == 0:\n", + " print(f'Epoch {epoch} loss: {loss}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Jk_wdvqpurTG" + }, + "source": [ + "It looks like our loss value is decreasing toward negative infinity until the point where the values are no longer well-represented by floating point math.\n", + "\n", + "At this point, we may wish to expect the values within the loss function itself to see where the diverging loss might be coming from.\n", + "In typical Python programs you can do this by inserting either a `print` statement or a `breakpoint` in the loss function; it might look something like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "9Klkz7qHwWia", + "outputId": "e727cf56-0cc1-4256-9567-5fe96396ce45" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kl loss Tracedwith with\n", + " primal = Tracedwith\n", + " tangent = Tracedwith with\n", + " pval = (ShapedArray(float32[]), None)\n", + " recipe = JaxprEqnRecipe(eqn_id=, in_tracers=(Traced,), out_tracer_refs=[], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1347]. let\n", + " b:f32[] = reduce_sum[axes=(0,)] a\n", + " c:f32[] = div b 1347.0\n", + " in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'resource_env': None, 'donated_invars': (False,), 'name': '_mean', 'keep_unused': False, 'inline': True}, effects=set(), source_info=, ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={})\n", + "reconstruction loss Tracedwith with\n", + " primal = Tracedwith\n", + " tangent = Tracedwith with\n", + " pval = (ShapedArray(float32[]), None)\n", + " recipe = JaxprEqnRecipe(eqn_id=, in_tracers=(Traced,), out_tracer_refs=[], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[1347,8,8]. let\n", + " b:f32[] = reduce_sum[axes=(0, 1, 2)] a\n", + " c:f32[] = div b 86208.0\n", + " in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'resource_env': None, 'donated_invars': (False,), 'name': '_mean', 'keep_unused': False, 'inline': True}, effects=set(), source_info=, ctx=JaxprEqnContext(compute_type=None,threefry_partitionable=False),xla_metadata={})\n" + ] + }, + { + "data": { + "text/plain": [ + "Array(16745235., dtype=float32)" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def vae_loss(model: VAE, x: jax.Array):\n", + " logits, mean, std = model(x)\n", + " kl_loss = jnp.mean(0.5 * jnp.mean(\n", + " -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))\n", + " reconstruction_loss = jnp.mean(\n", + " optax.sigmoid_binary_cross_entropy(logits, x)\n", + " )\n", + " print(\"kl loss\", kl_loss)\n", + " print(\"reconstruction loss\", reconstruction_loss)\n", + " return reconstruction_loss + 0.1 * kl_loss\n", + "\n", + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "train_step(model, optimizer, images_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mawnZtF8wxu9" + }, + "source": [ + "But here rather than printing the value, we're getting some kind of `Traced` object. You'll encounter this frequently when inspecting the progress of JAX programs: tracers are the mechanism that JAX uses to implement transformations like `jit` and `grad`, and you can read more about them in [JAX Key Concepts: Tracing](https://jax.readthedocs.io/en/latest/key-concepts.html#tracing).\n", + "\n", + "For our purposes, the workaround is to use another tool from the [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html#interactive-inspection-with-jax-debug) link above: namely `jax.debug.print`, which lets us print runtime values even when they're traced:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "wziDzgdTuloK", + "outputId": "5d640a4e-306b-4b93-9b99-725ee6e4baa1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kl_loss: 167451888.0\n", + "reconstruction_loss: 44.51668167114258\n", + "kl_loss: 21651530.0\n", + "reconstruction_loss: 6.270397186279297\n", + "kl_loss: 4448844.5\n", + "reconstruction_loss: -14.727174758911133\n", + "kl_loss: 1285240.625\n" + ] + } + ], + "source": [ + "def vae_loss(model: VAE, x: jax.Array):\n", + " logits, mean, std = model(x)\n", + "\n", + " kl_loss = jnp.mean(0.5 * jnp.mean(\n", + " -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))\n", + " reconstruction_loss = jnp.mean(\n", + " optax.sigmoid_binary_cross_entropy(logits, x)\n", + " )\n", + " jax.debug.print(\"kl_loss: {}\", kl_loss)\n", + " jax.debug.print(\"reconstruction_loss: {}\", reconstruction_loss)\n", + " return reconstruction_loss + 0.1 * kl_loss\n", + "\n", + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "\n", + "for i in range(5):\n", + " train_step(model, optimizer, images_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wObaYRxF1-qy" + }, + "source": [ + "Let's iterate a few hundred more times (we'll use the IPython `%%capture` magic to avoid printing all the output on the first several hundred iterations) and then do one more run to print these intermediate values:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "pBLiDfRX2OOu" + }, + "outputs": [], + "source": [ + "%%capture\n", + "for i in range(300):\n", + " train_step(model, optimizer, images_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "O87gxdxGP3uZ", + "outputId": "a4dc57cd-56e0-4597-fbfd-9572e4c4ef7a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kl_loss: 2462.782470703125\n", + "reconstruction_loss: -8067.7255859375\n" + ] + } + ], + "source": [ + "loss = train_step(model, optimizer, images_train)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FXHfa0942apE" + }, + "source": [ + "We see that the large negative value is coming from the `reconstruction_loss` term. Let's return to this and look at what it's actually doing:\n", + "```python\n", + "reconstruction_loss = jnp.mean(\n", + " optax.sigmoid_binary_cross_entropy(logits, x)\n", + ")\n", + "```\n", + "This is a binary cross entropy described at [`optax.sigmoid_binary_cross_entropy`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.sigmoid_binary_cross_entropy). From the documentation, the first input should be a logit, and the second input is assumed to be a binary label (i.e. a ``0`` or ``1``) – but in our case `x` is associated with `images_train`, which is not a binary label!" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "seLRa2qE3wd3", + "outputId": "8b6ba25f-2edb-44c7-c5a9-bca2f334b5eb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[[ 0. 3. 13. 16. 9. 0. 0. 0.]\n", + " [ 0. 10. 15. 13. 15. 2. 0. 0.]\n", + " [ 0. 15. 4. 4. 16. 1. 0. 0.]\n", + " [ 0. 0. 0. 5. 16. 2. 0. 0.]\n", + " [ 0. 0. 1. 14. 13. 0. 0. 0.]\n", + " [ 0. 0. 10. 16. 5. 0. 0. 0.]\n", + " [ 0. 4. 16. 13. 8. 10. 9. 1.]\n", + " [ 0. 2. 16. 16. 14. 12. 9. 1.]]\n" + ] + } + ], + "source": [ + "print(images_train[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KjF0ys3c30w8" + }, + "source": [ + "This is likely the source of our issue: we forgot to normalize the input images to the range ``(0, 1)``!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jkFIqZaTXRc5" + }, + "source": [ + "Let's fix this by binarizing our inputs, and then run the training loop again (redefining our loss function to remove the debug statements):" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "9Og1-tIw4BNu", + "outputId": "eebd9fd0-2b3b-43ce-9af8-729544a00aca" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0 loss: 0.7710005640983582\n", + "Epoch 500 loss: 0.3110124468803406\n", + "Epoch 1000 loss: 0.2782602906227112\n", + "Epoch 1500 loss: 0.26861754059791565\n", + "Epoch 2000 loss: 0.26275068521499634\n" + ] + } + ], + "source": [ + "images_normed = (digits.images / 16) > 0.5\n", + "splits = train_test_split(images_normed, random_state=0)\n", + "images_train, images_test = map(jnp.asarray, splits)\n", + "\n", + "def vae_loss(model: VAE, x: jax.Array):\n", + " logits, mean, std = model(x)\n", + "\n", + " kl_loss = jnp.mean(0.5 * jnp.mean(\n", + " -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1))\n", + " reconstruction_loss = jnp.mean(\n", + " optax.sigmoid_binary_cross_entropy(logits, x)\n", + " )\n", + " return reconstruction_loss + 0.1 * kl_loss\n", + "\n", + "model = VAE(\n", + " image_shape=(8, 8),\n", + " hidden_size=32,\n", + " latent_size=8,\n", + " rngs=nnx.Rngs(0, noise=1),\n", + ")\n", + "\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n", + "\n", + "for epoch in range(2001):\n", + " loss = train_step(model, optimizer, images_train)\n", + " if epoch % 500 == 0:\n", + " print(f'Epoch {epoch} loss: {loss}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4HD91gWfyJuJ" + }, + "source": [ + "The loss values are now behaving, and it looks like we've successfully debugged the initial NaN problem, which was not in our model but rather in our input data." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vA6wSi1k5GuZ" + }, + "source": [ + "## Exploring the VAE model results\n", + "\n", + "Now that we have a trained model, let's explore what it can be used for.\n", + "First, let's pass our test data through the model to output the result of the associated latent space representation for each input. We pass the `logits` through a `sigmoid` function to recover predicted images in the input space:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "fBzJyliAPCGc" + }, + "outputs": [], + "source": [ + "logits, mean, std = model(images_test)\n", + "images_pred = jax.nn.sigmoid(logits)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qiJy39iDPTVS" + }, + "source": [ + "Let's visualize several of these inputs and outputs:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "dRFxkKInn_gx", + "outputId": "4452517c-ac68-479d-a452-c282e4092291" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAB7CAYAAACl6fPbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABkqElEQVR4nO19aXNbyZHtwb7vBDct3Y6x3bZnJub//4wXDtvt6Xa0W1KTIrHvO/A+6J3kucULCgQuQOrFzQgEKYoE8lZl5XoyK7LZbDYIKaSQQgoppJBelKIvzUBIIYUUUkghhRQa5JBCCimkkEJ6FRQa5JBCCimkkEJ6BRQa5JBCCimkkEJ6BRQa5JBCCimkkEJ6BRQa5JBCCimkkEJ6BRQa5JBCCimkkEJ6BRQa5JBCCimkkEJ6BRTf5ZfW6zVubm5QKBQQiUSOzdNBtNlsMBgMcH19jWg0GvJ+IvpWeXf5Br4deQ95fxn6VmUdCHl/KfKT922/+FX6+PHjBsA39fr48WPIe8j7s/kOeQ95fy7v3xrfIe8vz/s22ilCLhQKAICPHz+iWCzaz9frNRaLBUajEdrtNv7973/j73//O/7P//k/+Omnn3B7e4vhcIjZbIbVauX73olEAul0GuVyGW/evMEf//hH/OUvf8Hvf/97XF5eolwuI5fLIZlMIplMIhaLIRqNbvWI+v0+3r17Zzw/l/e//vWv+Omnn3B3d4fRaIT1eo1YLIZMJoN8Po/z83N8//33+POf/4w//elP+P7771Gr1ZDNZpFIJBCLxXZZ0mfx/uHDBxQKBeN5Pp9jNpthMBig3W7j119/xf/+7//ib3/7G3755Rfc399jMBhgNpth22TURCKBbDaLcrmM9+/f4y9/+Qv+53/+Bz/88AOurq5QKpWQTqcRi8V28j73WfdOp4NffvkF//jHP/C3v/0NHz58wHA4RD6fx7t37/DDDz/gD3/4A968eYNqtYpcLod0Oo1EIoF4PL4zb8/h+yneSZvNBuv1GsvlEovFArPZDKPRCL1eD7e3t/j555/x448/4pdffsFgMEAymcTbt2/xpz/9Cf/1X/+F77//Hufn58jn80gkEk/Kc9C8u8+xXC4xnU7R6/Vwc3ODH3/8EX/961/x97//HR8+fEC328VkMsF6vUY0Gg3srJ6K9+l0itVqZWc4CHnfVdaVVz956Xa7uLm5wb/+9S/87W9/w88//4zb21v0+30sl0sA8PDD99lsNojFYsjn87i8vMQf/vAH/Pd//zf+8pe/4LvvvkO1WjV95EZjQevH1Wple+zyCXzRM+VyGW/fvsWf//znk+uYfe0SKRKJIB6PI5FIIJVKIZvNolQq4erqCv/xH/+BP//5z/jhhx/w5s0bVCoVZDIZxOPxr677NtrJIHOxisWi74PHYjHM53OPwtTDuG2x9efcRArsZDLBbDbDYrGwA5VKpZBMJu2Bn9pE/t8+vOsGUAnl83mUy2XU63Wcn5/j7OwMtVoN5XIZpVLJDsBTz7srubwXCgXk83ksFgtEIhEsFgssl0vPOs3nc6xWK6xWKzPCkUjE1yCTx0gkgmg0ing8jlQqhUwmg1wuh0KhgGKx+KzDso33p9Z9sVh41p17yueg48HnS6fTiEQiSKVSSKVSgRll5fcp3gGvcp3NZliv1/ZM0+kU0+kUs9kMy+USm80Gm83GDjUPdD6fR6FQQKFQOMggP5d3l2jUkskk1us1crkcMpmMnTGubTQatefQNXDPKvfoOWf1WLyr7jmWvH9N1nWt/OSFsk3DsF6vEYlETAfF43E7Gzz3dMgB2O8kk0mk02nPs2wzyLvyvqt+jEQiSCaT9vPNZmPPtlwuEYvFkEgkjMdT65h97BLJT1ZKpRIqlQrOz89Rr9dxdnaGSqWCSqWCcrlsz/O1dd9GOxnkYxMXcDweo9vt4v7+HqlUCsvlEqPRCLVaDYvFAgDsUAVh+PyI7x+LxZBOp5FOp1EqlVCtVlGtVnF+fo7r62vUajXk83mkUqnAjMI2ovJhVNzpdNBqtdBqtdBsNnFzc4Nms2lRMb3rb42ouBj1NBoNZDIZrFYrjMdj1Go1VKtVMw6qZE9VQ9psNuYsUF6bzSba7TZarRbu7u5wd3eHXq+HyWTyzezFc51mPvvnz59f7Kx+C7SLvNzf36PX62E6nVpGLpVKmSOeyWTMIA8GAwyHQ8znc9M7p1xbVz+mUilzMJPJJFarFSaTCUajEYbDoWVWXAfptRONcTKZtExipVLB2dkZqtUqLi4uPFkhdawPeb6TGGT1st2fA18U8Xw+x3A4RKvVQiKRwHw+R6/Xw/n5uaVG6BHSA1GvPSg+ySvTbuVyGVdXV7i8vLTI+OLiAhcXF5Zy0RQFoyL3GQ/hk4aKB/r29ha//fYb7u7u0Gw2cX9/b0aAh3qz2Xh48nvOQ/k6lJQvKi5GAe12G7FYDKvVCsPhEL1eD6PRyON1UxZOyS/5nM1m6Pf7uL+/x2+//Ybb21s0Gg3bj2azieFw6Enfca1P6UA8h/wiSv4cgOfZB4MBGo0G4vE4FosFer0e6vX6yc7qrrStZHOqz/6avNzf35u8sESWTCZRLBZN1xQKBUSjUYzHY7TbbTQaDfT7fQDwGOVjr+82/Viv11Gr1ZDJZLBer9Hv99FsNtFsNjGZTDxBy2vQO6Sn7BKzAOl0GsViEfV6HW/evMHV1ZVFxhcXFzg/P0ehUAgsMDuqQdYN5L/9vnJRptMpOp0O1uu1GZ/BYIDlcmmeWDabRSqVMs8rKD71+1gsZgJXq9VwfX2N77//HpeXl+YhlctlFItF2wgApnhdg+yn5J5DmhJlVPLrr7/i06dPaDab6Ha76Ha7ntoTlT550a8uHy99OKi4NpsN5vM5lsulpfAWiwWGw6EZ40QigXw+j3w+73GGTvUM6/X6kVH6+PEjPnz4gLu7O7TbbfT7fQyHQ0ynU0s3ulHCayI9j4x+qJASiQQWi4WlWtfrtT17JBLBcrnEcDi0s7parY56Vr812iYv//73v9FoNNBqtdDpdCyjAgCpVAq5XA71et2ycbFYDIPBAPF4HPP53EpUp1jXbfoxHo+jUqng+voa19fXhnNhUMVSozrQr8EZ3cUuUf5zuRzK5TLOz8/x/v17vH//HhcXF6hWq5amdjOlry5C5oPycNMYqFIiuYszmUxMgAnIyGazODs7w2QyMeUA4Chetyolpo1qtRouLy/x5s0b1Go1q9FkMhlLVajCYh2XwsuXKwC7kh7q4XCIdruNu7s73N7eotVqYTAYYDqdYj6fW02HToI6B+SPvHE/XgOpUeYeawYlHo+jUChgNBpZWp5ycGoeGclPJhNLrd/c3OD+/h79ft/kVJUXjVs8Hn9VRtk1xlqT5ItOHtecRrnX62G1WlntfL1en/SsvnbaJi93d3f4/Pkzbm9v0e12MRqNMJ1OsVgs7OyyJlypVFCtVk2HDIdDq5XzvL9UyjqTyVgkf3V1hXK5bA4ZyxrUP+l02re+f0ra1S7xHKTTaRQKBcsC0AYwVU2nU+vTrzZCVkNED0m9JE01Mq1HY0GvGwAymQwGgwEmk4l5hW5aOChS5USeCehiMZ81g1QqZcaYPBOEtFgsTAER/EClzOd+zsa5wJDxeIx+v29R8XQ6xWazMZCFKxhUCiwNLJdLLJdL24vX4LWST7+fua+XJrcuOBgM0Ov1LK3OiDCVShngpFAoPAk6ekly5T2Xy6FYLJqCjUajHqPBzIWeV+DhrI7H46Of1ec820uTKy/MKPA1nU49mS2/stdrkRN+dfVjuVw2fAcAqyEzS1cqlTx689i4m23EwOgpu6SOablcxtnZGer1uqWqq9UqSqWSLwDy1Rhk3SzXi0okEtY2xDQjH4DCSk97MpkYGEaRiKoATnHA+RzcmFQqZRGDnzEmkIEGkukaeln5fN7ed5+6J42yoo+p9CKRiCEYM5mMIZC19qeo7NFoZOmx12KMKTN0YjRdRBRjsVi0thq/1oJTkDoGq9XKHDE6OpvNBslk0mPUeKDL5bIH/fqa1p0ZoXw+b+BF1jMZAfT7fYzHYzuLfF7gS5qVWS2WHF7aGL8G2iYvk8nEg8pnRog1eXW8qS8Hg4EnG/aSa0z9yAwQUcjRaNRAmLPZDOl0Guv1GsViEdVqNdD07iE8sz2JXQ+0S9Qr1PvFYhHn5+e4uLiwrho619SzQUb8R0tZM0KgYmWbUKFQ8EDdmaKcTCaWjm02mxiPxwBgRlgj6GMLoVvzVY9KYfMK2CAgrdFooNfrYT6fI5lMolQq4ezsDMBD1kCNz3PIL1IkLJ+po0qlYkZLQV00xorQXiwWryKFqmvNTAJ7vnkQzs/PrYZPROfXetKPSdwDLVEA8JQ5CAJkuosAkJeoffuRC9LJZrOoVqt48+YNAFhm6O7uDqlUCsBD+cQ9l8y66Hq8pqzGS5MrL/qVL9bktd4cjUatXj+ZTNBqtQy8CeDkZRvAe16p06iHEokEAFhAUqlUsNlsrCbOXl2Nkk9xBlxZd+0SW7Do/NDRyGazqFQquLi4MN2zLdMVBAVikDX3TuNFzzoejxsw6u3btzbggRvHehQNGqHzNHYa5Z2KXCWih8mtE7NWwtrQp0+fcHd3Zwb54uLC0slaa9iH3HofezAzmQwuLy/x9u1bq29kMhkTeB7yXq+Hz58/G+CCIJF9a9uHkh9YhOmjWq2Gq6srfPfdd3YY/A7PSxljfuX3jOyphK6vr/H+/XtzImq1mnnYL8m7S7ruuVwOtVoNwBfHolwumywBsKiOfbN6LvTrazLAr4kXwB/k6Way2PYXiUQwm82QyWQQjUbNUHOoBc+xC9w8Nukeq/wzy0Kdkk6nDVvAenOlUjlplugpQNrZ2Rnevn2Lt2/fWhsrAxkaZgLsmKljhK/lviAp0JS1Cwph4/7Z2RnevHmD7777DvV6HYVCAfH4l48merjX6yEej2MymaDb7Vqv3UsqLR4Qps+ZKuZGMPJkTajRaFhLw2QyQTabNe+wWq1iPp97lNZznk09UnqjzDQwrfL27Vu8e/fOMzkMgGUg2Eo0mUwstf5a6lOAF0xRKpVwcXGBN2/e4Pr6GmdnZyiVSsjn81aLeskIWUn3Rnl/+/at8c468mvjHfiy7iwrbTYbew6m8tbrtWWvOp2Ob/3P7TB4bYbwpUkzbUTwspODRpXOPqNf9nYnk0lzrKmDqD9eMguhaXgOcAJgGS4i9OmsZrNZj/y/hOyrjikWi7i4uMC7d+9weXnpGVTi1pIzmUzgAC4/Otggu4YinU4jm81a7xy97nq9bn1b2WwW8XjcIjfWTV4aiafKhHVtRsDD4dBSR7PZDLFYzA5Pv9+3IR3sCZ5Op8jlcshms4+Qwc89PJpuoeJkjXK1WlnDuqZ1c7mcRWEcI0eUJtOmjLBfA7kpJXrU1WoV9Xrd6k9+6aLXQN8q75rdotJUsMt6vUav1zPj/Jrq37vQS/Op+lHHL7J1T8FcdNipVxgtMwDQUgFLgqfOHPKr6sfJZGI6knJEYCPlhc7HS+l21S/kh5O3WF6qVCqPykk0ynSiggRw+VEgBlm9DoJCptOpRQRMObJ3l143PStu6Hg8NiTnSwEWNA03n89tRjHrOaPRyEZqak8mJ+60Wi10u13MZjMAsLGC+xpjwJsO4izVs7Mz85DZJ8caMhGNmoWIRqMWtbugipdufdJUvKIfifjlK5vNvnjt2KWneGdEQN5fUy+mksqAZkzYcqiRwWtqlQuajtW54bbRKGiOkXK/3zdnn8aOaHaVFUbGdKZPrSP99GOn08H9/b0NL2FEzxefnZEyDfNLtcL56RhOG9NZ1C6O6BSzBA4yyOr90eOo1WqYTCbWJ5dOpw0yTmMciURsMhcnunCMXKvVshrJS6ZiiJzmIA6CKwiYojGkULbbbdze3qLZbKLX6wGApe3Vs92HeKAJ3qrX65jP59aIXyqVcHl5iWq1akhkpl70AKuHt83Lew0gIx5i9UyVdzVofmv6Us/gAl107q+uOem18K51QE2rnioqeAna1k4XpJHw04/VahXX19dWV2UvayKRQLvdtjKY1oZd8psudUpy9eP9/T3i8ThGo5FnjKTqLQZqADwp4VMONlH94so4I/htUfypSnuBRMhqLDjasFKpWA2ZxXzWjmnEOLXm8+fPZpg5eWo6nXrAT25xPijyW3AK3Hw+R6fTAfClr04R4iRFRXa7XXQ6HYzHY0tpf+02kV2Igk3E32q1Qjqdxmg0AvDF8FerVc98bQoV0ZuvqVb8NdLD4HqmwGPQnUvunp6Ctjk2rgH+Wo31lLy7/Oy6rt8SbXse1xgzVUxjF6RRVv14fn4O4EvPNif+sV1InXcATzrx5PnY9Fz9qJe+KFiQWVM3fX3MKNm1Ga5OcfXiU5HwqWT/4AjZ7WGMRL70xNZqNc+1Z+zbYqqa0fGnT5/w6dMnNBoNA3ONx2Mbsq7N2qqUgzbK7mYAsN5SIpQZZWrbE9NK7C1kTyEzARrR7cuzC0SIxb5cvcZbXyj0vPGGxljbLAhO01aLb4F03Vi7IvkhPflVa0CnNG5aX2NJhu1lVD5+vJySd9cI+Rkn3ihGmTk0y/NS5OcQuWAkPqteJej+7b6f7epH4OG88pY4IqjZu0sn2nXo3UzGqUoHz9WPzKbQESkUCri8vEQ0GrVWOoK79gG4Ppd3F0zHm6D4HFxnvty1PaUjGgjKmgaDrTZM1XKhmRrQGgnHyDFV3Wg0MBqNMJ/PDenJa/Y46Sro2qEKmjsyMJlMWu2XFzs89bkaXSgKWifT7Mu7pr1Yh8lms3ZY6YnyQLheLIetsG2FQwW0XUWf47UReSQAxq/XFXi8n9rof4q6rTpA7PtmTY0XLjxlkE/Bu/JIUI46aW59cDwe2xhMnulvxZkDHjttaox1jzjlbrPZHHxeXVL9SKAT9Uw0GrUZ9f1+37AnBEa5/d6UezfCOwbtqx/5d3zOSqWCaDSKcrlsOv7YGCHlXevF7CSgPaGsc0iLGuKXaE0MJGUNPNw6QsOhM5P54pQpRedx+gzH7QHwgJc4oUkBPUEKoW5YNpu16Uqz2QyRSMSAZjotzO899M5Per8EsR3Kuxst0VkhL26aZbN5uJOUI+y63S7a7bbNzt3lcu7XQGqM+Vw6GUufgWtANKsqD/7/MaJN4OFGLjpAvV4PnU4H6XQai8Xi0fQ0l49T8K7GmEpIQYdqcInM5yAK4jp0Yt63RvrsnB3APaIDq/cqA4evu6sf1UhEIhG7KYu32jG7xv5dHSDC9aeePEXr0L76kQaZV0cWi0XLABwCcN2Hd/YSk3emzNnHTVmIRCKmU9h7f2oQZmCDQdy0BvD4ej01ZppG5QZFIhFrmeJQ9fPzc5yfn3vGDgZhlJVf91aP6XRqHm2n00G32zVhU7AZjSPT2KzxHoN3PdgK6nDToBqh9ft9dDodtNttu5CCNxLRKBP49VoVrM7wZj8sb3/Si90BeKKPQqGAUqlkF5czuwAEn4LSNKi2wWWzWcxms6+OC6QcKRJXeWf0dijvGv3qPHTeTMVomWeToJ3b21uTmW0GWVHap6DnyKumgPUuZ963zQhvMpn43rcdROoaeIiU+Z5sb+K1lcvlEplMxmYw6OQznZs+GAws7XqsCHlX/ch53DrjnAOdlsslIpGIGWx1onfBLATBu95nTGQ7Z2zzEgzeMMhsCQfiuPt/CtkOdHSmGg3AmybSwQF+E6/oyXBM2cXFhaGzLy8vbaBIkE3l6kEVCgWcnZ1hPp8jkUigWCzagY1EIp6GfL3NiRufyWRsQhPvzAyad36enxC79T8iv29ubuzu1UajYTfNcPweo4PXRhrNMZvC0YGNRgOdTseT/gLwaC8JoNHsAWUzyMPF6IrRS6vVsvIMZcgdJuCm97TextvFXAf3EN7VaZjP5zaakYDKwWBgUQ6Nst4upjLDy1OUf7/XscjVK3omtyl5fabBYIBms4lkMonFYmGOCe9ypuFU5zcIo6yvzWZjmBDKcDqdxvn5uZUI3F5fIpqbzaa1SalRDnrN/fTjYrHw6Me7uzuL9BlcAXgU3bupdxe/wDU6Fu/1et0cHso6Hf27uzt0u13rEuJAKpWBfacrPpcCn2XtLqoaChdcpIaNKZFarYa3b9/avZO1Ws3unnSNWhARsgIuOHM6k8mgUCggm82ahzccDj03ThF0xhQ9//7du3c27vGYvPuRGwHRIP/yyy+meBnxD4dDe4bXMvLQBRvpRRrL5RK9Xg+3t7f47bff7BDxIo/NZmMRZqVSsdvCmPZVYxiEUfZDJ5NPGubBYOA7DEQ/n0ZXUfTj8RjRaNSwE5rqPoR3zTYoqPK3336z7oD5fG7GgFP02ItPg6zlKCort5XFdTiCIlefUKeo8nflmP+mQWaHx2q1wmg0svotjQ37xllmCJJcByyfz5vsFotFT71es12DwQD39/dWDuQeHQvY9ZR+1AtzONSJ2BQaZAZZbinN3TtFt/Nzg+JdnVxd436/j16vZ5PneI93JpPBaDRCJBKx4UvuBRLHpqNdvwg8jnKoXPlS4JeOHTw/P7dIU2vIfpHGIaQKhUJH4dD6zrbZq9x41ko4sOMUvPuR1luZOm02m2bE6FmPx2NrTyPI4bWQPgNlhoARpuAZ5asR4b4x5QfA6kY6O5rPGiRKVeWc4BDWkQnq01tkXBS11ug0Q0MULo3CIbyrA6FGmUMdGCWrYqVh4C1Oev0ieXBvQqPzoyjWoMhFsbv6RNO8roPJ52bdU0sMPA9sH+SNVhp1B3lm1QljNwbbQ+nw87N1tnUikbALYugcHdNQbNOP3N/NZmPpf/Kz2WzsGfh86hDTFnC/6GAGPVNf0+0MrKinOY2OJZhut2vAPgCeWQ9P9YMfg45mkFWxEuhCT7vf71vEqTB/jh3UGjL79FxIfVBCqOkeKhn1+P0GOijx91mrIBBNedfevFOk87RWxv7oVqvlSfNuNhsDhr00SGcbApbXWXImd7/ft3t51fAxcwHAWimy2SwajQbK5bJNVqMXrGCpoJ6bvNOIRaNfJqNRZrVjQLsOyBcAD9BqOByi3+8jnU4DgPG+Xq8P4l1BaNrqwc/m1ZwKINLMFs8qn4nREu991ta7II2yGmO9nnA4HKLX62EwGDy6s9nPKDPdTt3EbF06nUav1zvJXc6a7XANh9v9wFo+jXGxWLSsCx2MY5KrH7V9KJVKmb4ulUqWYaB+ocNGnqPRqNXrh8OhZefW6/UjHRkE35Q94kd47gDY8Cq24A4GA9PhfqOOT6Ufj2KQVZjoibNO0+l0rJ7JlCPHw2magb3L7F/W6S9BCaFfnZK8quPAubK6MS4fCsrx4/1UxtgtD7C1Q9ueeJDj8firaWHRVLVGb/l83vp46Uwwm0KP3UVex+NxtFotMxAAMJvNbKyoAr2COGguJkIzP6vVytO2Rw9d2/g0CiFmYTqd2tAF7iEvqNiHd8qd60Dm83kUi0VL/VNmmIbUWibfh+tPRcybcDgilwo46LnF3GueU6KkWeNutVoW5T9llJlW5fPQ2GkN/dhOqu4H9YN+purQWCyG+XzuKX+4+uQYemWbftQbp5jerdVqSCQSHl1JWWP2kK1SKtd05DjrPch6uBplvud6vbZ1JD6Ask7ZdzMtuh7HdoACN8juJuqsU6bGCCZpNBpWYNc2A78I9RjGWEEu5PPz58/GJ+dTd7tdTCaTR202wONxj+pBuqMeT2GM/YwyozatTblI7ZcgV9i1LY4jS4EvkTHTSZvNxmpXHDuYTCbRbDZNmY7HY7RaLevzZPqYF5xQWUSj0cAdEnUsmNaLRqOWDuVMd5YwNP1M40DF1mw2rcZZLpdRr9cP4t2NxmhEGXWTJ9YtaaQV68HaNmcyE7RIEB1vt+JVdkFEPCrPOnKX8wv4/W+//WZocMV6+L0fnVKtQ7sDUI51Ltz31fKF/g7XfLlcmj55qnUuaB41IzEej02Pt9tt9Pt9y0yxrEKEuj6fylwymcR4PMb9/b21YtZqNQvIVFcGpevVKOu4TMqlZhSj0ain3/4l9GKgBtk1xtPp1NCcHz9+xMePH21eNedX9/t9Oxx6EPQgqaEJGiDiok7Jp87YbrfbZpCBx2MRXQCLmwI/ZWTsB5hQNLu++LevgcgTa2YKlGJkyKsXK5WKRQ3dbhepVMrjAM5mM3S7XQCw7zudjgHZ2BoV9G1X7toCMAPGsYnv3r3DxcWFB+Snz86aOa/IBL6ADM/OzkwxM+J+Du9at6RB5bljqaVYLCKRSFialAAjbStjrz3np79//x5v377F+fk5arWa3VutzxdUD7WOqb2/v8eHDx/w6dMnc6Dp7LNt6Kn6n3tmthnhIM+He+5ch5TkRr5+euUURP2oyHTe986zxJkRtVrt0U1JauxYYmDgE41+mdo1Ho8BwDPYKGhEs6uPXZ1M+XZtzktQYAbZzxiz17HVauH29hYfPnywCxg07QHAPHPtTdaxg7rJQRg3vwiZfFLoiL6kYgJg10YCjw+O/uzY6SR9Dq2HMX2rIwFV2byU5/cUaV2Th5dKUhGcsVjMmvuZriZKdTQaWWTEFK+mv8fjMVarladXnGtzLNLogCjVN2/e4O3bt3bVGx1RvRd8tVqh0+mg2Wxaa9psNkMqlbIeyX14ZxTCiw6Y1WHqmm1AnKA3GAwQi8UsK0RngOMP6/W63XN+eXlpIEbe8hZkhKxndTgcotVq4ebmBr/++ivu7u4M6MQhQwrKeup9gdPc5ewa/20OgKaylRfXqT926tRPP7I0cHNzg9FoZFExr3+tVqsWLTODo1nS5XJpnR7L5dJGidLJPjaQbpfn1f14CWxNIAaZG6fgEJ2Ew3uCP3/+jM+fP6PdbtuEl83mCzyeBkTH2Q2HQwP0bDabwO+SdWvdFBw/T5t1MwoZN4sp9W10bGPsrjuHZhCcQ+P80sCtr5GbbqcBnk6n1sAPwKKzSqViQECOHhwMBgYo0Roo/zaXyxnK/FRrwqiSN/vw7tVqtWplGkakTK+v12vrO+UF9fF4HGdnZwfx7qJ7+T0N52KxQLvdtkEmjFZYbyM+gnc9n52d4eLiApeXl7i4uLAZxYqsPcVZZVmJOkVlfpf3PTYp766j7PKoe0Rn5qWAl8o3dQsxQJPJxHAEbCcifkCvp2W3x2q1QrPZxHw+N0T2bDZDNpv14HReu546Nh1skN2DQmSvtqgwpaQ9sBp1plIpz993u100m02LmheLhaUrgeDGmbl1En6+zu+NRCKWimG0QAAR/z/o+vauvPutuzpBVOj7GuVTHgw3JccUFkFQemcpW4K45kTvM/2VTCYN3EOlwLnMPPjb6otBE3l070hmK50qaXYiKHBmMBggnU57lNYhvKvC16iLYyNZk6cTQNQ6xw/SGTo/P0e9XketVjOwHG9DC3oWNOB/VnWtiLAmIOe1KHW/c+p37ztJwaFcy5c0yFxzFyDKXnl2FFC2aZCpKznHXX/GgA2AB2h66hYjPuNrokAMstbv2u02Wq2Wve7u7nBzc+O5zUmFkcAKBQ58/vwZ8Xjck24kYADwpnWC9L4V4EHwDOt/HBSwXq89Tof2LZ+Svrbunz9/tlQe+ywp8AQc+b3nUzXmoJ/RTfPTEAOwUX21Ws1qkxy2QuXPdWerEBGUbPGiQwI8DMVwEbT7HkhdCxfJ6aYWaZz0Wjq2O1HmOd6U4MdOp4N+v4/pdAoAnrajQ3lXeaVToPdla880I2VOEOOwm6urK1xdXaFWq3nu4D6GMVZyz6ofYHGXsswpyknkV9O+bEEcDAaeiVyA91Y3Ith1+EYQe7/vMyiojn3fnPKnpRCWNSjfxGxotM/3obH2yxQc4xm+BTrIILspDQKjPn36hJubG9zf36PVauH+/h53d3emYFzAhVvHpTFmtMeRdoqQIx168F1gBf/NIRME49RqNesT7Pf7uL+/RyQSsR7BU4C3lM9d1t29RUsF3w+84qaN2XP6FIhtX/IzxnqjTLlcxsXFBa6vrw3Fy3GkrCHz7yhLqVTK2nDo1AEwBKX7rEHxToW0ayvKZrOxrEW73cbt7a1NUiNqmCWT1WplmaIgMQCa2dHxgOo80DnK5/Oo1+t4+/atjbSlg8SxsKcwxkrbMBF+TpH+jft7T71/EDzSACkYjQEK9aEGADr17+LiAufn58hmsx5n8pStiuoAuFgU17n1A5/57ce2vTuW0XR1GvfkNabHA4uQiVRuNpv4+PEjfv31V4PIK+CCgB16VlQKm80G0+kU7XYby+XSvMnBYGC9pezrdYEiQaavgYf2k3g8bkPJr6+vkc/nsVqt0Gq1LBXKCPlUyEfl96l1b7VaNkGH6+6mhPRQ8MBrClWVtZuVCGKqjmvMOEmHYJHLy0u8ffvWAEMENDHd66IkCTjiHbN8FtZhgzQYyru2u+n8YwUiqhPF+jjHmxJIyMiYezaZTEzWj9H654fgVWXKfeF9tkRUn52dWaaCdeNTGWM//vhiGn4bpsONLvW9/D4nCPIDRt3c3ODTp09ot9sYDoeeVkQC94hi5qwAYlc0HX/qKFm/9zNwCshlDdnNSrk18m1GO0i+1TFyRze/VDlgGx0cIfNhWQPr9XpoNpsG4KJRZeM9jQI3g9ElL3BnKwvrFaxDEF1LhQ3g0VjCIBUVeVNUabFYNCVP58DtIz0FPXfd3RqNy6dbIxqPx4Ymd3sDVfntq4D1ADIS4+xYtlIUi0VDQ5fLZVP83HuWO5h+5xQrArpYV+YYTU6VorzsqwTcqFgH2bDuB8CUEuC9mpF1xPV67an5M8VOVD+BhBxN6Tep7ljypnsTjUatdlyv13F1dYXz83PPMIegwZZP8eSuO687JaKeMrCt/1ijO8B7LWLQRCWvaVq2sxEcpSU8AIYX2Gw2BgSkHClI050kdiqD4hpjHa06mUysrMd6Mf/Pnd7F6WjHHJzkGmKtgeuVkcecyvZcCiRC1nom0XgcRt/r9UyQFJms975qszsXkPNmOQCCwBFOKFosFjamD/DOQg2CXI/Qr2XhlKkjP/52XXc/wIRbF1JFwfF37jQpGk+dW7zvuqsx5pWb5XLZfsY2HBpgpuepeLUtRKN6yhcnUREFmkwmzajr5Ld9DiF5188ol8tW7yXAiH2+dDTZeUA8AkFcrA9GIhGTb2YAeAMa67SH8v7UM7n/5r7rSFtGxYVCwTOf/Zgz2pUnrjvnlNdqNXPaGf1sWxeN2DgaEYAnSjsW39v+rQaNaPv1em03QbGtLxqNmpP9UuA1Nf48c5ze1e127SpLZjHZLcFWNJZf6FgXCgW7xEcvUQliH1Q/6ljYwWDgCy5mFjbIstA+FFjbk/YgMxLQezAVuMBN4bCHdDptbRfcYK2PckIMo6DxeIxarWYpTOBByINYSBpdRjO9Xg+5XM42jmAhRvz83FNv4i7rvs0Y+0XYjUbDLkfXNCRfHAKg4+6eu+4a5dAYM/piqQKAIep5dymjdjcq14OngzXogXMSUDqd9vRKqtLYlfx4Z0ljMpl47tCmQ0TlxfT0/f09gC+RkLY7MRLSS1f4OeVyGVdXV6hWq57+3iAdQr/902jUvUTiWNkpP3pq3TkpivX2bcAnnmk6n7xXm1mgYxhkxRIwome2oVqtmkPAAIPnlzqP7We8RIIzHbT0d+qgQIMTOvHNZhPZbNa6YZg55LlksMBzWKlUAMCudCyVSp7JdUHshQYrGmwQPMlJkcxIveR0LqVAJ3X51RP86gfsJSVIh4hZ3mRyf3+PXq9ng79brZbB/+mNDQYDXF5eAoAHjLLPgrrGihvJm0BisRiWy6V5gMPh0DP2k60hL0Xb1l3HHvr9jYJNGo0G4vE4ptMp7u7uLL3LNDWjEra8aNnhuevO93TvKmW6jsqG9fp2u73TIVVFvNl86Vuv1+ueqz0vLi5QLpctM0OH6hDe9YYmAsoSiYRNeJvP5+j3+2g0GohGoxiNRubwEDzody81o9NcLmc1dS3bPJf3fcit87sli2NGlkrbZCabzVrUqO1grqGiMdZb0O7v7+26Pdadg34WPT/sQ5/NZojFYtbKx7Gwq9XKUqqdTsfKRq1WC9Fo1EbCdjodTCYTe65j6x5XP9Lxnc/nllFbLBZoNBoW/bpZM55JBmLJZNL2sVareTJiQeGBXP12c3ODu7s7A01ySNVoNDL98/+FQX6q4O8HluKw8cvLS7x79w7VahXRaNR6LnWoAy+lZzTR7XZtCAAA63+jx37Igmq6gtElhW8wGNj708MejUYGxti13SJI+hrQ4il+XLBJs9m06N+9/IBDUbhn7CvUyzP2iZCpXAleKRQK6Pf75s0yteTewOM+u4JzKF9ag9ZLEPiiR87a3SG8x2IxT1mFilQzFRyDye4B8kQe2Vqn9TRtPWIb2L68H0Ia6T2Fnj3m5+u6F4tFc7gqlcqjyyQAfydbR+Ry9Cj/T6P8oEFzOq6U2ABmhlhSYX2Tt5jR8A2HQyvLLRYLi+zn87k5ZadKs/rpRxq9fr9veAL3/gHKL0sM5XLZsmy8mCTI++IB793f7Ir58OEDPnz4gEaj4WmNnEwmiEQiryJKDnR0Jklba/TnWpPK5XKo1Wq4urpCvV5HLBZDt9vFcrm01MJgMLCiu06LmU6niEQilvJg6jgIb5FGjQ4AD0iv17MaAwcT6ASvl9rEbeu+y99p/zdT9HxOTUWyhlitVrFer23vqtXqXuuu6cdcLmeGlEMz9PYd9ldTCfk9Iw8w+SyXy4jFYnaRw/n5uR16vXyeSuM55PJOhZvNZu2eVU4OY2SzWq3MC6cMl8tlrNdrA0bV63VUKhXP3dt86ZAU7TQIEjPxNQV4KsDiU5+vKWtm2jhGdBswhz/j2vOOaso6o+tjpd3VUcxmswDw6F7eyWRich6JRExeOPyEUaOOxgVgd5mfOghg9or6cTKZYDAYePSG6nrqaZ47zhVgnzWxQJwIF1SEzKwIe/15AQnnM7DtjAHGS+OCgCPeh7yNtPldhz/QM+p0OshkMpZC4mFjioRGsFgs2mEKesqL1ldZA/FrYQHwau4U3ofoeAAPveC8p1SR1VQmNCCHrrtGf259MBqNYjqdekoUvF1Gezbd9+Pfs65drVYtEqEXrmMd9619Ku/8N4FokciXISWNRsOiWCpSBZ4RfMY2OnYR1Go1S+f5odv5Oafq9fV7dr/vT/XZKjOaBt6WOSGpg01j3Gg0LBNE8N2x+ea/9damp+SFeBDFSygw9qVKZerMa43YvW6Ten4+nxtmgpkCbWFUpHWQcq1lOXZd0PFh5o17T0zGS+vwkxtk7TklSEQ9Rfeiau2/499ns1lLYz7lHe9DWo+l0GnKHfD2Mb6mO4X3Ia4vn1lbuBRdvdl8acPwQ27vs+58f+DhAvHNZmM1eUVpcq65GmRdc60vRiIRlEolD2pZ79d2keP7kJ9DwQwD68H8HOCh7ka+I5GIKSqmvAmYY03RbW/S/TiGMX5pRbQLabqc678tTU3SCDkWe7hXWDESmiI9ZpT8HHn5mt7he5w6Qt6mH11njQ7kcrm0rg2Vdb/74oOUaw2ctBOF41bH47FlORVd/dK0k0Emo6yDkajM+ZBs93Av+daX9q9pbyYPizt4XZGEehE9EdCcZcsIjq0n7iHdh3du5jaD6zbtsw5EeH00GrUhELsIGnkMgvfnCBgVlhpJNQI0lO7FH/usu99na0nC7RNkiwdflA0lVdTcB+2N5FcAnv2IRCKP1nxf3sk3a8buXmiPNF86hpA8MjPgN/XLrdsewjt/V7sZeCEJ+zMpE3p723A4RDabxWaz8Ri059Kh6+5XN90m79Q3lCk+n+4RdZG2x3CsKJ9zG+/HkpdtoD136MaufD+H9331I2WBP6cO1A4QV9Z3zVY9l3eOOKZsM4ijbKuTrA5G0LLux/tW2uxAHz9+3AD4pl4fP34MeQ95fzbfIe8h78/l/VvjO+T95XnfRpHNV032F4/j5uYGhULhxcEdX6PNZoPBYIDr62vzJEPej0/fKu8u38C3I+8h7y9D36qsAyHvL0V+8u5HOxnkkEIKKaSQQgrpuHTaGxFCCimkkEIKKSRfCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0CCg1ySCGFFFJIIb0Ciu/yS9/yvZMh76ehb5X3/5/u5QVC3k9B36qsAyHvL0W73oeMzQ708ePHDYBv6vXx48eQ95D3Z/Md8h7y/lzevzW+Q95fnvdttFOEXCgUAAAfP35EsVi0n6/XaywWC0wmE7TbbXz69An//Oc/8Y9//AP/+te/cHd3h36/j8lkguVyidVqhc1mg0gkgng8jkQigXQ6jVQqhUwmg3w+j3K5jFqthnK5bN/X63VcXl7i7OwM5XIZmUwGiUQC0WgUkUjE4x31+328e/fOeH6K9+VyifF4jE6nY7z//e9/xy+//ILPnz8b7/P5HIvFwsN7LBZDKpVCOp1GNptFuVzG2dkZrq6uUK/XUa1WUalUUKvVcHZ2hmq1inw+j3Q6jXg8brwrfY33zWaD9XqN1WqF+XyOyWRivP/000/45z//iV9++cXWfTwee3iPx+NIJpPI5XLI5/MoFAqo1Wq4urrC999/j9///vd49+4dzs7OkM/nkUqlPLw+5YXuuu4AsNlssNlssFqtsFwuMZvNMJ1OMRqN0O12cXd3h19++QU//fQTfv31V3S7XQBAuVzG1dUV3r9/j+vra1SrVZTLZRSLRRSLReM5kUggFoshGo36rvNTfH+N9238cy9ub2/x888/48cff8Qvv/yCbreLWCyGy8tL/PGPf8Sf/vQn/O53v8PZ2RkKhQLS6TQSicTO67wv736yMxgM0Ov10Ov10Ol0cHd3hw8fPuDXX3/Fx48f0W63MR6PsV6vEY1GkUwmkUwmkc1mkc/nUavVTL7L5TIqlQrOzs7sVSgU7KxyP/hs/X4fv/vd75617u4e8FkGgwHu7+/x73//G//7v/+Lf/3rX7i9vUWn08FwOMRsNsN6vUYsFkOxWMT19TX++Mc/4j//8z/xhz/8ARcXFygWi0+ezafW/bnywvXv9/u4ubkxefnnP/+J29tb9Ho94zkajSKVSqFYLOLq6grfffcdfvjhB3z//fe4uLhAqVRCPp9HNps12Y/H41vlaFfedX1HoxFarRZub2/x6dMn3N7eotlsotVqod1uo9VqodvtYjKZmJ5JpVLIZrPIZDKm39PpNHK5nOmfSqWCy8tLXF9f4/r6GpVKBblcDslkErFY7Nm6fbPZAICt83g8Rrvdxr///W/8+OOP+Otf/4qff/4Z9/f3njXebDaIRqOIxWJIJBKm15PJJOLxuOl7fa5arYY3b97g+++/x/v373F+fo5isYhsNotkMvnoTPf7fbx//94j7360k0HmolDxccNo1OLxOBaLhUcoVBmqouG/+fCJRMIOeiqVskXg/+kCcSP50E8pMP7Mj3cAHt6Xy6UJD4VBFXosFrNnJv+qZPj77t/QAGYyGeRyOVNQXzv0frwXCgVTqIvFAtFoFMvl0jaenw3gyXXXf1PoKWh0igqFAgqFwrMM8q7rrrLDAx+LxbDZbDCfzx8pb64n1zyTySCbzXqcCjXIVKq7GmSX76/xDnyRHcrPfD4HANsL3QcS5ZlrTFlQg7WPQd6VdxoEd83pUM9mM6RSKXuRV/K9Wq1MplW+KTvRaNTzjNyfbc+nfD5n3ZVoMGazGQBgPB6b8uc+qPzo2SW/1Dfkl86RnzF4at33kReeYdUfVP6JRALL5dJzbrl21EUA7DmSyaQFBu6Z3Zd3Xd9IJILJZPLIeVS9l0wm7Tn1nNIQq1HmK5PJ2HngWc7n81sN8td4V8dnsVggEolgOp1udVZ0LSnLXEvKEvdEZVyfR2WdBtkNCvxkfBvtZJCfIj3s9Db4lRHxer02ZihgNMA0tnyISCSC9XqN2WxmkdNsNsNiscBqtbL3pjcUJJFffgXgOQxuhKyCT0XHSI/PRf6Xy6Wt0XNJnR9+xng8tghnMBhgPB5jNpuZd0j+KXjk3X0//v6x11afxT04XLPBYIB+v4/hcGjRPflSw6yKS5WDOhv7GLbn8M/9Ho/HGA6H6Pf76HQ66PV6FpVxH55az2OutfLrGuXlconlconFYoH5fG7yyQiHGSDuEQA7A6vVyl7z+Rzz+Ryz2cy+XywWnvdzn/OY+6LPx/PGF+uOun90SOikqI4KglfyxUyKKy/dbhej0cgMtTpBq9XKeNhsNp6z0u/3kU6njT89F/s6dk89A5036p/pdGpyowEWI3pmUWh8GXTR2LrBQRA8qh3i+lHOt9kiDRDVseHLzTa4ARBf+p6H0F4G2T3g7ksXQNPUGuloREZvhFHvcrnEdDrFcDhEMpk0Y0OjrJ9/yAKosuD3emCBL4ucTCbte90EHgL+3Ww2w2Aw8EQPmUwGxWLRhHcf3tV4aaqx2+2i3W6j2Wx6DIEqQlUu5JnvxQPG72n8jkW63mqMh8Mh2u32o+ehoiJPlB3Nqugh3zWyOfQZtFTT7/fRbrctdXd3d4dWq4XBYIDpdGqR87a1OAWpoVIlRQPK13K5BADEYjFTpJvNxqJpkkbWk8kE4/HYMkFU1PysYzt4fD4/Y0wZIy+RSMRknoZlMplgMpkglUp5jASVbhB8kY/xeIxut2up3k6nY+leRvqMwAAY7zyzs9kM/X4fzWbTomjKmJ4NGotD9KO7ptxvOhTj8dhKkZFIxPR3LBYzY1woFCybScPmOg1Kh/DKryoDdAzVKLu6kfvMdc/lclYGy+Vyxjf/HoAnJa3ZDbULrh7a9dn2jpA1ElavhIxzE/Uwao6ehoqbpgdgsVjYe8bjcRQKBYzH40dRchDkOhbqTFAZuYvuLjQ3bDKZmBdO45tOp1Eul+3g7MM713Y2m2E4HKLVanlqOHd3d2g2mxgOh5hOp4/q3RQkfWat3fo5O8ci95BTSTWbTdzf36PVaqHRaKDVapmDwRqPpkWZVaEicFNEp4iOR6MR2u02bm9vbQ8ajQbu7u7Q7XZtz3U9g4q8nsOvX2TMqHYymXhkAHgwCvl83tadhgyAp25OZcuIulAomAw+dVaDjIpch0OjX8oZHQp1BGlcBoOBJxLyi3z2JT27g8EAjUYDNzc3uLu7Q7vdthr+eDy26JJ6h3vC55vNZoZJYO18OBxiuVx6nCiehaDW18Ud9Pt9c5iZemcpLplMolAooFQqoVQqmUHWtQDgMWBfS03vyqvKOOWb2Q8aZ81+0oFntlaxQLVaDfl8HolEwtae55my7hplP/3zXF10UMraL0WgXrimaOkZ0ZNmzr1UKiGXy5lHp0aCaSQaNE1hBm00/BwMpnu5YUyp8/f163q9xnQ6NeFl7aVYLGI0GnnS1vq3u2yWHorhcGgAnPv7e4soO50ORqORGQC3XuZ6kfP5HNFo1KOMyZ8fBZW603Q5Pe5ut4tGo4Hb21tzMtrttikbAJ59YO0pm816wBd+qbCgDd82g/zp0ydzJDqdDvr9vsmAW7M/hTFWfv0iB40UedYYRSYSCeTzeQBfzi2jIUbGatRUGaXT6UfZLLcMEmSK0u8ZNTLS7A/PBP9NgzwajSwTp9gVnplDM3Dq/NKZ/u233/Dp0ydzOslvJBJBOp1+5PzTUeIeLJdLS1tPJhNEIhGL7Fi3DyKDCMDjPE8mEwyHQyuTMQBhJpD11GKxiEqlYiCteDxuRm0ymXgwCX7Gax+eVQZc2Vb9putCY0xHplQqoVKpoF6v4/z8HKVSyTIRLKlNp1NPRkCzEi5uZZ/neLZBdmtSasj04PNnmhbgg2tKo1wuI5vNIhqNWg2U9RL1TNw6l/IStLFwIxpGZUzJAPA4Hvp3VAibzQaZTAbj8dgiBl2n56asKWhM8TK92263DVHNQ80UKdPp+h66hq6C3qY8gzQg6sDpIWcqr9lsot/vm3MBwFLUCrbQutQpomPgseJnxNBut9FoNHB/f49ut4vBYGApSDqhfkCPUxlnV771jLpnlSAn1lypNCl/lB8avHg8jul0ilQqZed0l2xLkPVN4EGu/FLVejaVd+oWvjKZTGBZONcBJuq31+uh2Wzi8+fPaLfbmE6nAPDI8VcgKXlfrVaYTCZ2bmiMC4WCGch9s3DbeFfdoyUKns9EImH1VgZZlUrFukxyuRxisZjxzMgfeKjHajaCtK9RVp3mYhq0JEcDSgefvGu3QLlcRiKRwHw+t3Iq8ThaZvUDke57vgMBdfGrewAZpcXjcU+xnCkNtjZlMpkvzPy/1NhoNALgPWTbAEdBGGO/9yPvWq9Mp9MWyTPKZLpanREAJgR+de99eNSUOg8G0400xEwNMc3iF5WqgFKA3d89Bm2L1BglDwYDS8GxPBGNRj3eaz6ft7YJRQT7oRmPRe5aahqPjtFqtbLyjJZmFLnpgtBOTQpmUYeHhpkGgqlcOoNcg6eyY6eQJZcPN23t1pM3m42Vk/j/rtL2i+oPcfj9Inc61JQZKvtMJmNgV/LK8+nqFL42mw2y2ayVqlyjE9Qaq+5RvMFyuTTHOJ/Pm05ndFypVCzYWiwWFiDQqD23RXGXdVaHzN1bBXTRGGezWZRKJXMk2MpXqVQsQmY2cbN5wFMwI6Cp6yAc7YMNMhfEJUbGrPXR+8jn86agqGQJmWdthJ4eF+DQQvlT/PqlS/hZmmbnc3BjGL2xrqZRpbZcBKlwVeg0rU4vlQpVHRi3bkgPl6kXvq8bGQdJ7oFRD5bOxWg0Mr7Yjkb5YX96vV5HpVIxIKDrnW777GPUK3nwNdJiWYLAFnre1WrVDjn7LF2jfApyjTAVElOILsp0PB4jlUohEokYopzpR7/3dWXeTUfq7x9Kflk61zlQR9n9OzXiasiCLodtc0Qp//w8GohUKgXgwWipMwE8AOo2m40ZGxpINcZBPIPf+uoaMbWey+VQKpVsLgCNcalUMkeDwQv11Xq9ftRedIi+JJ9uhKzGmI4No+NMJoNqterhm/Vjpv/j8bjxDsBsVTqd9oDW3BryvhSIQXYPmx76RCKBUqmEer1uAzPYb0YPgykN9gQTKcwHd/vBgo6G1BBr/zAAT480DTINIeA1xCqoCjraVt88hLSXUpGCjGg0fc46CiNRPgPTku5a6JoEqTzdw6IpQ0b7dHAo8MyiVCoVnJ+fo16vG1jEjTb189znCYpcBate+Hq9tnqaKqharYbz83OrS6lRPnaaHfB2ByiilNFsIpFALpfz1N1YY+12u4Y1YPaCzigzSOq4PiXr2wz0c8nPGLuRsba5+Bkp/Vv9d1Dk8uj34jnkOjIbRKOla0UH2k3NK1ZHo9qgnkGfRTMHDDyYAapWq+YwE9DFWQbMsGgf73r9pV/ZHR5ziFz4RcdaPwZgaxuLxcwuXV5emsPsOhMMEinjuVwOm80GyWTSAkq//vUXi5DdyFINGtMZnGry5s0b1Ot1m4rDh6BiY2pjOp3a4uZyuaOk+vxS64rkTSaTlgbWFw8R34NAGLdmnsvlHoGOntsk7rfObooRgKV26bGx/qQ9y6PRCL1ez7xTCu3XFOeh5Ka+/AyyIn01wnQntNVqNTNy2k6hTsU2ZXSMKNk1AMCXsks+n8f5+bnJu6bClPdtHnVQUT3w+HwybaglmWw2+yhyY78rlSfbu5jC47PSKVSQy7a6YBDkypNGtq6RUsP81Htte/8gePRL31NPcv91KAsjSg62AB72ik62Rv9+r0NJ38d9X8qmyg4N8tnZmRkzBlx0LJj1SiQS1jkRi8U8g572jTDdc7mthuw6EpVKBRcXF3j37h3Oz88tKqbuJu+6P9qJwIEgtGVusLjPGX62QWb0xe/15Rrj1WqFdDqNWq2Gi4sLvH37FhcXFygUCp4hIGx/IVAmn8/bQdPxa34R8r6Kyy+qdwdOaNpaDxDwMA2Lxo8/p7GkoAWRkuHvuw6PthqwBlIoFKwOpQM32NqxXq+t3YOpym1OQhDRDOCftnN7YHmAVquV1V7L5TIuLi5wcXHhOeisIW8zxpriC9qxcJ9HXwCs7YNjVC8vL1GtVi1iIO/bpnpRaQRtlKnQ9d+sE6uSV4PMFpzBYPAoNQfAZFGRySrnfso8CHJT1O6AGzVc6ii7n+93FoOSeeXV1Ze6bpR3fXH9aFzcGu5sNjsaZsJPzt3/U32pkwh12hZ1n/bxalaPBlmnqx2i110do/V2dcq0NMNyEh1/dZYVWMdoPpVK2TlXXNSu09F2ob0iZL9oStMRBG7F43Fks1nznvhifyMjY8CLpGVdy88gH5rW2PYsbnScTqcxm81MUICHoSF+SEYeMncEpZta3ddQaATPz6ABY69zrVZDqVQyg6wDD4AvoJDhcGgHQOs6QaUTt5FfelEBKnpwmJnQ2hRrx1rCcI0x/1aVbxDpaz+lpP+mrFLZ8rAzMuYcc617a8nDdXKDMsbuZ6jsafrarW8yil4sFh7l5DprfF6N+Mi/C/Si83fIs7mOkCs/OgPha8CsY2aFXFId4xrfbU4M/45r6oelOUYGwo92iby3raficUisj7Nu7je9Kwg+3bMfiTx0oGjHD0dvMtp1jSt51QFRLio+iCARODBl7aZRCcCpVqvmZedyOdTrdU+hnH1yTF2oMqbXRCIwRudMB1V3U6Fx6zg6lm6z2XhQevTA+DO+D1tGuMGMipT3fSNkCrV+Br1oGmROmNH5wzoAxK9F6KnP1K+Hkl9qbRuKnjKl82LdOen8e5dXXV/XAAQZdbpE5UjHjIddSxduO4tiEbbxGJRhVqPMddZ1JP/aGuQHjPFTzsq71vCm06llwtzXvtHyNmAga4Vuv+lTa+I6Fceo5/tFxapjiMinjuEecZiIApK2rf0xHOmn1k5TwwSKMhPnlwnUISXq2HFNggBDubqcDiIdHupOyiOdIfcCCdcQu3zr+7vOVBB6Zm+D7EaWqVTKboCh8l8sFkin02aQWSgnco29aRw8oFNfdAwbgTCMjA49NG5EwjQGUaeFQsGGOrhtBsCDQaZBYFRNNDnrnBrZHZLWUKeHxpiR8Xq9tglJHPcWi8WM56ea1t11PFaU4FePUuXuKnoVeI0ktCfWNWb6ctOzh0af7nv4/T/3SA+qyzvwkGVRR073RaOeY0TLTz3Ter1+dBNUv9+3HlcOEGE90N1PZrx07K3fWdvHIOtn0RAoPoI4BJ3y5zqA7vPrnh3DKOv+atBCHcPnAGDlOvbospTn1kI1/arrGbRRfkov0PHS4TiZTAbr9cMdBMSDLJdLz/wGzWJxXQ4Fu6oDrql0rf3qfQmqXwB4si0uaNflkzpcnQl1Mg9d/4MjZDXIxWIRi8UCsVgMlUrFNoNpR7Y48VBpDyeHgtAga1uGm04Iql1EF5u1kHw+j3w+j/l8bmlfnWbkAkkSiQQ2m41FcETuEU5fqVRQLBZ3uuXpKT7JI4c2xONxyzDo7SpsmwC83qifYtKDcUp6KgUMPFZkmup3+1xdo6YerBst8733Ncouf/rZT/2bn6ngL76P+4yMXA9N737tWXTdKdOz2cyuwOTEtGazaSMedfCE1ja1bsf34Dhc4EEJU4mpw/Qc4hq6kRlfelXqtmiez+8aY79e2EMNhO4vdaQO0FAjrE6+28Kl3RIaGLiGOCgnYluZh9/rHhCFv16vTXbq9boFWIvFYmtrUJAOkNoiHc+st07xex30QSeC2RxmLPx0CWXXr5MgqNLT3jVkdxF0qABvOdI2EKKlo9GopWLG47HNRSXil4tGYAAHQQQZIZN3PZh8Bo5+c4FGHMShvYOMrKlgCON3WwCee7ewS5pS1wEC2qut6WgqTDXGbtSwbU/d74MmN1rexofLg0bSLnpWPVgeer/3Dhro5RpdPZwaNdLz1ucgP3rA3VdQPPutJV+qXPVOande+nA4NPlnZ4Hfc+r7KAiTSnK1WnnqibuSfpYOx6Ex5nAM92IL93nddXFfQZI6aAqA4iwGravryFFNufM5/Pqk/SLkoPn3+x6ARcjkjwOduA90PCORLzfh0Sgza7LttS+fbiaCUTGvfKRBZhsWz5c6dyzx6XnUYSb8LDXWbuZAde4+zxNIhKzE2h+FSoEMNBabzcYGDXBqDQEkzOnTk2T90IWWB0XcHPWqcrmcpcCGwyE2m40pAHqpACxtx/YjRrA6p1vBPH6o2l2Jhxp4GLOnQq/vzVqUKlx9uYC0bUYwCPIzjNsiZFe41fjqkAdVWvw79xAxc+GSmxJ+DvkpDf2ZGidGNFq+AR4i/PV6/ajGpS+CT5TnIPbHz4iqcev3+56xrK1W69GVklr3VEPnGvjZbIZkMvlocMXX6rtP8e7WLsfjsQ2UYescz4Xf57ilEm1dc88F32PfKFmzHW4WjhGypn/VydGzDcBjBDR74maD/JyvQ7JB+m/9OXlmVK89vwzKqLfVeKmzuY3nfddbdbnOj9B7mInwpu5mmy2nQ7qod+ABi6Kf42eQgcNxKgcbZMCr5GiAtd1AjSi9KneM3GKxMC8GgKV4NPd/CEp5G/8apTClpJOKtA6mETMAU6o6KEHbACgIh4LRXE+bXxXUxPdUAIirHFlj8zPKQUePu9A2Y61OBA0bnTEaECp4AB5DrEh51wAoujPIZ1Alqmnbfr+PTCbjqbnqNCVNsfFFJUInjxTE/rhRpjspjVd6djode3E2tzqjauh2TZvuY4Ddv1UjqtcnulPeKN+7Pr9eZuM6DoeQe25ZQ2YWTmuVwBc51imF6uixpEAdsl6vLTu3rQx2iHFwjbCf8dQIXrNVHIwzHA6tTOk6y/re7vf78usXJbs1ZGZpqAtYB9fMIo04DbxfmW9bdH/ocxwE6tIN17rAtpSh9ghSafFiAUZCuVzOFK1bXztGaok8a4Scz+cNMMLInM/LTVNPVYVRFaw7KCGItKOus583/1RqjwpML+r4Gh3q8bnv5fe9H/+r1cr4HQ6Hprzc24kAfxAYnSqt9Wt/+aH7oc9AmabSHAwG6HQ6Jssc1qKlA+ABZUoj7PZ0unSIUda1pTFihmo4HKLf76PT6aDRaKDZbFq6mgaZKVWmq8mHC3Dx60t25wfsc4b9jKmmrGmQtaSk8q1/rzPUh8OhYS90bCj3KYhaPhW7dkjQKNBIZzIZuxVOb6ci78vl0ga1EHPDtX7KKB9CfsbYDSo0wo/FYubgcE9Go5HnJi3t5dX3dj9zH1JZ1M9zX8xWKY5JQXT5fN7ex4/XYxljIIAIWQVHN82vfqNRBDeNXjm90kKh8EjZBqVA/fjnYaFiZLqaB1b7iPXz/VJvGm0rujaoA6Pr7GeMKTx6SGjQmIlQxepOMXoqugnSeD0Vebi8DwYDJJNJu9iARlpvBuOa61AajoJcLpeWOtOsjaYHDyFNdxINS2dsvV6j1+tZ2s6vtUvRoHQGS6WSYQBchbWPUXZTyuxu0IiYX+/v7+1qz263axdm8MIPRvtuNKLRvQJqWLcLYngCjaTqD17qwRvC9OpHv79n1oX4lU6nY2eUjlw+nzdA6iHkRm2pVAq5XM4cerZ0KkpcHQquEx3rfr+PZrOJeDyOwWAAAFunAB4q1+4zqA72A9Zq9oIOM0sKzHDybAaRfXiKR82S0dHSerJmQngZjAYrAMyYu++/i1E+JHAMZHSmq7D9lB0NANO+WrPq9Xp2gMrlskcog4zQtvHNw8hNI8qaaFFGyQqZ3yZQfD+38B9UdP/U3/t5rYwC2L6idzO7/H7t/U9Beqg5ezsejxtQhJkLpie5h2qQs9ms/b5GRi6Yat9IU7/XqJMz2Emz2QytVsuAJPr35Fnb7Yg/IMDRHSZD2geH4EaYLBcRvMWa8f39vUXJrpFzZUbBhAqi4bOwB1uvqdu3bOPn0KtsMwOkuoO86nvQcaLu4eAcgk+LxaI5e0GlrRXUlc1mAcCm+dF4uSMelWdGx+1226I2/tztheXf6NdDeFdD5wIQ3WEf/EyW95id4/5TzoPAE2zjV53E1Wr1yChz4JNmWujc5XI5y6ZxOJVbmnnK+AYRKQd6uYSf8VRlqDVBKly+2DPrd/dxkJvmx7saZXpx6lFp/UFvd9q2CW5WQH8WlMFzP8/9t9YziRBnZMmUDQVXkb1BOA1P8bstReU6EzRwrONTWW0zyJrlUEAhszWKjFdDt69c+dU16QTRKC8WC1NGlDE+q0ZN6XTaogmCSih/imFwszH78u1mqRgpdrtddLtdu4Bekcv8fDXEWp/T86JtJvy/XQfSfI33baAuRpg67U3LBPoePBecWqdtm8RZ+LXX7UtqlBl1sauDsq4oaveZ5/M5hsOhzYFmpssdvOL+nX59Lv/uGXUNs5Z/gAcwr07QU+CgYkL8pqoRj7PvWvNveKbd8onKK9dRsy10dBKJBPL5/KOe722fty3Y2ldeAjHILhOuh6eRsXuhAA2Fpjr0QuljeFIuv9s8QfX86fGzbYPRgiuY+qyLxcIzdtAvmxDUc7iG3wWuaN2VByoSiZjC/FoEc+hB4fdqjLRVyXUoyDv512dSZKw6RmrIE4kERqMRotGo/Zx1aK15HkLumusAGT4TP9vtddX/50CceDxuBmYwGBiwhO9BI+cO3HgOvxppKoKd58+dfkVjDDykR7U1kQP2OcNY5xjTkfWbELdvhKwodtUjdNBY96WCjUQinj3he6nDl06nzagzexRUWtXNxNEJBmBAIvKkf0Nin2wikcBkMkEulzMHR7Es7jq577kPz/yeL03t0wDTkOrcCPKo+sSNSFUOtF57SM1ejbIGGS5wklkFAB7dEIvFHmUrXPvztTU9VLcHapCBx5EDvSJNMbHgz5w9Dz1/1++mDi6iLsqhD+9nxPQViXhnKtODUmCUpoyYDqOi0FRlkGnrXZ5LvVO9fJ2pUh4gjvl0L2wIkk/XELstPmqoyL8+h9bUGFloCx3gxRoQjT0ajSyq5p5tNhuLpClPz32Wp8h9PyocbekjOpaKzQX/MSKighiPx9ZKp7J/iML1A75Qiere61AEt85dq9VwdnZmk+koSzTU2mXgDlN4DvnpFPf6Tp5H5ZVyoC1a/Gy/99D687GCAJIaI5VbV0doJ4fOQvdbR7/M3KE88qsbqLALIBKJeFqL6JSVSiUUi0XrP6ZjzGskFQ3PdeCZPtRRdoMrv2iZ/9aJi8ADNkidVbdd1G99g9KVgRtkwFvvoULUejFTYi44hw+ukRG/d9Gax+BVHQF61Rz2Ua/XkUqlzKGgZw7AQDsUuMFgYLOw1Tvm12NEye4zuQZZowRFlHOKGg+PXqYRFLnGWI2AKmtVRK7g88Cyj3CXthai+KPRqLWxAbCyBNvz9n0mP9KUHsFlrKnqNDXFI6giphz1+31MJhNDa5fLZdTrdc9afI13P6VM5cchM6xjUk4ItNQhDgp2YiTMy0w4FpfT9DRS1pt83AH8+5JbjlGDzJS61oX5uzSyKl8asWnbU5DGeFvmitGXGmW3owSA8a3Oq5tN8ovigky386ueX8XVqFNPefeb4b5ef7lpjk4mywPqIB6Stnb5dbNxLgCRgQpJy1g8D7QHnK2hKfZj6PJADbJfWolpIa1TdTodQ29qb6MaRUY1uvGMIoJYBD0omupUFC8j4kKhgFgshlwu5wGSDIdDux2HXux4PEav13sEeHAP1DEMsh5C1yBTkHiYOJxAb4nS4St+Edg+6+4eDBpBra9qdOintKkMOFJVL+ogX9qfSnQ2nafVamV12ng8boaF83efS3rolcg/61WazuWl7aVSyRM1qDeust/tdk1ZZDIZnJ2dGUiFUeeuvKti1hpwNptFuVy2ZyHqlDOsZ7OZOUHlchmVSsUuMSmXy3Z7G6/9ZFSs94BTGbsR8qEpa9eYUodQjihn6mCznUjTlTTEbh9ykJHxNqefDgLLRxotq2z7nQvNHgEPCPRj4G402lRnHoDnLmTKtjs7gs/Os8myGQFezGi4l8ccqiP9dI/Wk7nXzG4qv+5wG/197WMOmgIzyCogfmhqDqvvdDqeCFlTBto7S8PINA0buoNEX7teq1uXYrqL6dx8Pm888eBPJhMAsDQMjYEO8CDgRQ96kJ6Va4h1HyhECnTis9BAMDpOpVKP0JL0BPmMh9R3FGTDCIxK209R6z7zQOmF5lSsNPCUt/V6bcjbdrvtuduXbSY69u+5z/HU/1EBuEMgKpUKzs7OUKvV7M5qyow6oHQkeF44MIcXtdAo7tpHvm0fWIrRGuZ8Pke/3zelOplMEIs9XPCiBphXS/IWN72RS50sFydwiDOqcq41cI1+Fc9B48v+WNYySWrY1Ri7wy4OJb8gZTKZmGy7aVs1vn5rpRgK1Yl+xjiIEp9bm6UBjcfjKJVKOD8/x+XlJSqVCnK53KPLJNhmRoePOnMymSASiXgu2zh2S5RbW2ZrG7OhCkhzbztTZ83NogSlzwMxyH6eq976QuWiL6Z9tbbpRtUEfGnvmG5YUEZZHQEdNqAAKAIYuIHkk4cCgCc9T28sk8l4rlELyplQ/rcZYvfl1mI160BB1GgNeLi71E1jP4d/90C7dWT9t0bn+lxMtzOyo1MEPNQDI5GI8U3nioaZNVJt+3rO4d/VkLgHn+vMnlNGl3QQaIg3m41lZYgA7nQ6mE6nlrk4OzvzOLHPUVyuYqJMU6Ezxe+2+dEg0yjrpSn8ysiIEY62xgTd9qfPA3iNBKM4Var86hpaLZGp0+oHsjqE3OhYdRsdfu0W4N/QgXHTz1rj9Dsj7usYpDJEuaaDlsvlHjnK/BtmEDnoh86mzkjft4S073OoTGrdXkt9apDdtQaCDa4ONshuulAFjovf6XTQbrc90bG2VPCGKE3ZccqLAk40h68G4hCP20Vt6gCNzWbjmaJE5cI0Bj0qVWrknV6kDrwP2vvjM7hRPh0ifVEZRCIRT6qPNThFvLMsoGkcTacp7br2fgAR9Vb1M/hcCvBTx42GWUF+rtJTGWJbnaJoj7kf6hApElyBUUyB8ZCrA8uSCNPGVGTqwD7XGGuUpCk88qEALLeEoDVkveubAC46SAo48ouIg0hBMuvEqWbZbNYiLa0RA/CsrbtH/F3XWAddP3b1I0sCo9HI9F46nbZ9XS6XnlQ/n8MvI+BGyEGRX6mKX931UeeOmUDWi7VMwOenjNOw8arbYwDpXIfFPZeuHtA1dfE320oBQeuPwAwyjRGjYU29sb/RNciqYFQpDQYDDAYDT48gr81SJPChPY1udMxad7fbxWg0snQKD4ZfLRh4nGpnKlan7xwjHeMaIk2H6YxeNbx0ZrTtBoBFcqynMEXqAq8YLVM5Ptc71ChNkZv6UpAfDStlgAecRLnRO3wpYwq+o4euyvcQ8ntm3Q/KE/l2h7GQdx21yQEdxFgMh0NDDrtO3S6KyzXEbupOJxppOwgjW0Y6OtKT9XCtGbsy4p7NIKIHv/o3SwEAPJPEqFgpq4re5/+R/AybGiDd20Pq3pyOxr3udrs2m5pOxWw2szVVo8yzrRGba1Bc3XJoNsIvMlc9o5/Nz3PLTvx/10lm+joSiWydrx9U1OmWCjT4oG5kFM8Sh/uM29b4GHSQQdbNovfHebi8JYb3HdPIMh2sUTKRdhRaIj15wFmLI3KTi+MqnOfwza/qTDCabzabGA6HtjlMzbA+pWlVbjYPPpvMWUM8ViuF30GhwHMNdSCI9mtS2Gi81cgxMlXl4CoIKl1d933Tp276mgAXjSg07asODgBP1NFqtdDpdOyryhdBPkF6t2rg1MhSjmnkeE+2257DlHqz2cTt7S0ajYbx3+12LUvDcsmhsqMZDhort4SgBpXPpbcUKZJawVuapg4yKnZ557kqFAqo1WpYrVbW/aDZHmJAgAfAk6alGR0zXbwt3bvvevPv9HxOJhN0u100Gg00Gg2Mx2NEo1HkcjlUKhVMJpNHN9yxP5eyowA1jeZcY+x+v88+uDpGu2AI+NPslfKi2U6+mK1ixoq1fT9n8xByI2PFabDzQnUDHbT1eu15Jo2mg47et1EgETIFbjweo9vt4v7+Hp8/fzZPXwfA0zPRUXfcHLdfmakQ93dXq9VePaQu35peZN2u1+vZ/a+MjAnj5+HgJruppEjkC2JS08QqrMdIj26rf1Pw9Go6Knj1Ev0uDaAAuxHHarXyRFIaLe9T+9EImSkvnR/MyJe/w3oV66qr1UObGR0pGrRer2dR0z5O21PkRprkT1OMjMpp0OiMcu1Xq5XdPdxoNMwgc1IW20MYpe5bg/X7fbem6763lhTUCDJtrX3Gfj3GxzDG/EqQJO8dZ1kpm81aVgHw1i/1zFKO1an+WtZh32jNzWDR6W+327i7u7MsXC6XM/1GozydTs0obzabR9m2bWUXXasg9sGvJKbZOLdExqBFdbnfFZkM4txg5dB1V0dK+dbJbuy+oEHm52iE7Jc18ZONoGScFFiETKMwGo1MOTLSpDF2e4vpxfKAuBt/7IK6e2DIP6N51qoYGdAg64ZpP6HydUxwgptKUrAIywRqjDVtTWdG682AtzeX6WJN1SyXS0tN8qURkaYBnyI3dco11vGLXL/NZmNpvUQigcFgYChpRuu8JEGvDCQGgM/GViS/1rN9nCSN7mmsmN5Xb597QueIjlI2m0U8HvfcPczblfr9vile4OHua81Q7GuclVT2VfmQNHrm5+s8YLetKeh6sUtuylqvMNTMCtfdxTpoGYTPfugAil3JNQyKqCdv5NfFgXBaG42JRnWM9P3WPsj1d42RyoyWydjnD8CT7WS2Tkt3pKDlxOVb6/cahGjLrWKSNBtF/txMmLu+QT5DYChrN21Ko8Z+Y025cIH8hEprW37oW/3MIFO/5F9BZWwRIsiGG8X6txuxU0Gr8mLNW5HMQZGbDtN6vXttnhut08sGYGUB9vlScaknTG/dVch8rtls9izeFVjEkoTeckReGeVEo1Gk02l0u12LjmmQW60W2u223U7E/nAAZjR0pJ8akecS5dRFT5dKJTMElA2tb/Oa0U6nY45Iv99/BHZUp0kND8s1uubPcfrUYfSrq7n3Havx87vCTlHYx6gZb3sGlZl8Pu9BRc/nD3fb8md8Vq0r04l3MwJBGzLlmwrdLRPwHLLFjA6Ftn1SpjgshudWAXRumYGfGxT/yrvqMsp3v99HPB63YUlqC/gs5FvR+35jew8lN2ChLlPAsbY+Ag/DdhR74M6/8OsaCJoCSVn7RZrqifDhmaLQwwHAEyWxVsyhFQSPuIr0GB4go00+AwCrK8diMfMACf7SlCgHKLC/l+MEWRNiRBmk0JFft/7NG3y63a7VSjTDoB4gB9TrdDH1bofDoWfqDtOVNBAaqe5KGh2nUilrm2Cal04QnYnpdGqHhJmUwWBgNftut2uAKP7NZrOxGlwmk7HBFoVC4dH0oOeQGgUO1qjX64YdyGQy6PV6tu6r1cqi+FarhXQ6jel0ilgshuFwiLu7OyuRqDGmoS8Wizg7O0O9Xrd+X72Y/jnywq9anye+Q2cDMGJzL4/Y1mN8zCiHpEZB5yZTlpfLpeE+AK+z6gK9+H50rNRIHgoW3cY3ZZ2OJw2w3jxEGRgOh57LOqLRKBaLhe0V67cAPPtyDOfIzUxQT7PNcLVaYTAYoNVqYT6f25xubX/V1LTyyzkI+Xze93rOoFPtCixjCYDyQJmmo69lGr9BN37rHITMBD4YhAhTTZMqeIgeuDHw/6JHnR7FAQSck1upVHw3LYgFcKNu4OGCiNVqhW63a6lTGiymtjkrmRtIcMbl5SVqtRqq1aoNgnCnYB1KapBZu7+7u8Nvv/1mRpl32jI7wf3h2rnoT/ZE8v103CNRtnSU6GzQ6I1Go53XmkaNEWC5XPYgTnnNHGvEWnuiN85ZuovFwmr/mqbmYSIIqVqt4uLiAtVqFfl83hTYrql28k3e6UicnZ2ZIioWi2g2m7i7uzOAFgF1nU7H6lOtVsuinna7jUajYS1ONPScE82vl5eXqNfrKBaLtubP4V1lhjVuTs/TCXqUaR1o405fcrNWx4wYuO78qu1aGh3PZjMbmalOh9tjzOhY0f2uk7Gttr4P366sV6tVLBZfpvsVCgVPeYmoeuChVEEDoMh9gkgpKzqi1K+Wf8i6qzHWWdWM7BeLhU2Vazab1onh4mb4HnTk1Zmlo7ltQNCu5NZ6Ffvi1pCHw6HpRLau0pmIx+OmOzgSNp/P+85lD9ohDSxlrZug9RK3bqwCpzUqbjZH8l1dXdk4NvY+Bm3USK5R1po4+R8MBgY4YiqS6XYisKmgqTwrlYrxrtO9gqz/0SDzcvmbmxvc3d2ZomWaS9s93IxANBrFYDAwxdbv9z391zzwjGTr9bp5vYzYtH70NdJIhxe2AzBwDsdCEvTEFDBnobdaLTsQNC6K1qch4QxyOkb8ns6dosWfw7saZHrUdCyINeCgDb0qcrlcot/vW518sVhYFoIZGTX079+/x8XFhRnmer1uI05p3J9D2wBGBJMxAqOTGYvFPA7Z19qbjk1utElarVZ2TScNsg790N5tNcg07m7L177lDJdXAI8yQUTyZrNZA5De398bApgRMDNC6iSoweHeuDOj3ZbQfXWNH85DkfaUXcp0t9sF8LhdivqduiOfz1v2kC93wlcQpT0/fJOCzFjWUoAwzyX3h6Uozun2u9wjaEc00AgZ8N6B6QKzOACev09B0/mojMCoRHUxVIkeuhB+hXn+m4qLB4PpW1f5aGqDAsvZrlSeeutNkHVkFTbWKKlcqWCZgnTHROpzqsCyfqUKShHQRH8SjciULweg7Lru7trRseHnrddrDAYDq2ezNsj6MpGcwEMPOJ0N8qWpahq0arWKUqnkSbU/R4Zc3rPZrMkvDVcsFsN0OkWr1TKjQaeOiHDyr2djvV57rjesVqs4Pz/H9fU1zs7OUC6XTZ72UVyqJBWAqYA4gm/oZFDpu7OJX8oYA4/vgd5svrTYaPZMHQ8/MKhfTfRrNcJ9jRr5TSaTnhIdncZoNGq6ZTqdWrmDmUTF1dARpL7RckKQdVgaf9coU0fzohZtJdIMKEse/DveBVCpVJDNZm32NW8I88sgBhGwaNZWjTLT1dTvlGuuueJa+KL98csOcc2CoKNcLqHeqAIq3JQp8Njr1Tqy3hrj3gpE2veguLUkJTfVATzMXeUBoefHWg5rbkzrMlXqx3uQwAUKm3rYjHR4aNRgue+h7SF8ZvKpQDXWjeg9clwia3bPBXVx3/neJLYDcd0A2GGi00ByEffRaNRmhrMW7d7X+7Ur7Hblm8+tSj0SidhViTrak4oKeJAjXX86pzQIWmvk3GgCu/xqbbsSP8+vVU6vQ6WD5EaQyqN7hk9VRwYejDKVp5tu1nV1DbKSnsdjAna0zOHKi17TymwXo2TKOrsE6Lzq9Z0EFwYNOtL30GyCO3+eTrw6/1qX5bMVCgUAsLS1XtWpDnKQwFfXKNNJY3Co5VNFf/N8Uzf5ZYaOYYyBI12/CPh7KTwYuuh+3io30g9EcqrDz8NMg0fSVDuVPp9DhfYUvG9DiPP1nNtr+IzqlS+XS0sLM9Xmvre+diVVrFwzRtp8UWEpOEcPkZ+CTSaT9qxuvdGN8g6JJFyjDDxM3eJFJFSa/D9Gwfr8bsZC35syxMEsbkry0AhCFZQCb7jXqnzciPhYymgXUtlxwVh+jg7/rY6bkp5lrn2QjoZfZE/+VCYp65pNoRPNGif1DR0mTWkf05nYllGg86MtRVqXpU5Qh1SdC7cmewyZUhvEl5ZVWRpQHal6SbMTxzbGwI4GmYz2+33PzzVlysEHRLm6ClsXRr1YN8XtorS5sW7digvjLhB51MO4jXf1SN3hIzzQbtO6fu/2IjOKI+8EY+yqRPfhnW0+/ExVrOoEbdtTNRI0wFQOAKy+AnypbzLdQ1Ab94cR8td418/XCJ/rxpc7AEENCGWH70M52FYq0foR92K1+nKVIcFoqqy/xrvyr8A4Hd7gDjxwUb5KVACqNNTB0pYYPQvsY/0a725pgi2JBBPpFDcqTipVPY+sd0ciERsC4edo7qKkXFnfdd1J2l/q3lHuN2nJrW366R4iz9nLH41GLSukBnXXc6rkOs9u6pSG2JVfAJaFSyQSvmVAnn1myDabjRk6d2925Z28slzkzjXwW2NtrSS4i0hydfx4VjQDtouz+TXedY213VB1u55L6hCVBdf+UMdx2p8L+tuV/OTdlzY70MePHzcAvqnXx48fQ95D3p/Nd8h7yPtzef/W+A55f3net1Fk81WT/cWDuLm5QaFQOHmK6rm02XzpUb2+vrYaR8j78elb5d3lG/h25D3k/WXoW5V1IOT9pchP3v1oJ4McUkghhRRSSCEdl4Kd5RhSSCGFFFJIIe1FoUEOKaSQQgoppFdAoUEOKaSQQgoppFdAoUEOKaSQQgoppFdAoUEOKaSQQgoppFdAoUEOKaSQQgoppFdAoUEOKaSQQgoppFdA/xd1OAWzGlQG3QAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "fig, ax = plt.subplots(2, 10, figsize=(6, 1.5),\n", + " subplot_kw={'xticks':[], 'yticks':[]},\n", + " gridspec_kw=dict(hspace=0.1, wspace=0.1))\n", + "for i in range(10):\n", + " ax[0, i].imshow(images_test[i], cmap='binary', interpolation='gaussian')\n", + " ax[1, i].imshow(images_pred[i], cmap='binary', interpolation='gaussian')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eaM-A6rNQWFz" + }, + "source": [ + "The top row here are the input images, and the bottom row are what the model \"thinks\" these images look like, given their latent space representation.\n", + "There's not perfect fidelity, but the essential features are recovered.\n", + "\n", + "We can go a step further and generate a set of new images from scratch by sampling randomly from the latent space. Let's generate 36 new digits this way:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "aNV9CNC1r2Gn", + "outputId": "5855d88a-81c8-4aee-d9ff-885ad78ad421" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUkAAAFICAYAAADd1gwNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADbl0lEQVR4nO2dWXcjyZGlDeACEDu45aKuKqmk6e6HeZv//zdm1K2qklpZlZncABArV8Q8UJ/zhtEDxBLBUlH0c3DIZJJAmC/XzK4tXkqSJLG38Tbextt4G9FR/rUf4G28jbfxNv6ZxxtIvo238TbexoLxBpJv4228jbexYLyB5Nt4G2/jbSwYbyD5Nt7G23gbC8YbSL6Nt/E23saC8QaSb+NtvI23sWC8geTbeBtv420sGNvL/NJ8PrfPnz9bs9m0UqlU9DPlOpIksdFoZB8/frRy+UEn/JblMXt9Mr02ecxen0yvTR6zuExZv/js+PTpU2Jmv+nXp0+fXpU8r1Gm1ybPa5TptcnjZYqNpSzJZrNpZmafPn2yVquV+XtUOCZJYvP53O7v7+36+tpms5kNBgM7OzuzL1++2NnZmQ0GAxuNRnZ9fW23t7d2d3dnSZLY1taW1et163a79v79e/vd735n79+/t/39favX61apVGxra8vK5bKVSqWFGixJEhsOh/aHP/whyPCcPDz37e2tTadTu7y8tNPTUzs5ObGvX79ar9ezfr9vvV7Per2enZ+f23A4tNlsZkmSWKVSsUajYZ1OxzqdjrXbbet2u3Z4eGjv3r2z9+/f2/HxsXU6HWs0Gra7u2s7OztBJuRaNIbDoX3zzTdLy6TzoWtzd3dnNzc3dn19bcPh0M7Pz+3z58/26dMn+/z5s11cXNh4PLabmxu7u7szM7Pt7W3b2dmxer1u7Xbbjo+P7d27d3Z8fGzdbteazabV63Xb29uzvb09q1QqtrOzY9vb25myrSvPfD5PrdfNzY1dXV3ZZDKxfr9vv/zyi/31r3+1v/zlL/a3v/3Nzs7ObDqdmplZpVKxer1urVYrrNG7d+/s48eP9t1339nHjx/t4ODAGo2GVSoV297efna/rSqTX4vr62ubTqfW6/XCs//5z3+2H3/80b58+WL9ft+urq7s9vbW5vO5JUkSnqlUKlm5XLadnR3b2dmxWq1mnU7H3r9/b99//7396U9/su+//97ev39v3W7XarXak723SL5V1yhJkrAus9nMhsOh9Xo9Oz09tfPz83COLi4u7OzszE5PT20wGNj19bWVSiVrtVp2dHRkHz58sI8fP9rh4aHt7+9bu922/f39cLbYY4v21yprFBtLgSQf3Gq1lgJJ3by7u7tWLpft+vo6HJitra2wqLrApVIpHMJKpWLVatVqtZo1Go3w0gnRZ1tWhufk4blvbm6sXC7b7e2t7e3thYPO3/KV5y6Xy5YkSXiuWEk8/7+1tWXb29u2vb1tu7u7T4Bk2cVeViY+OwaQpVLJ7u/vw+Hitbu7a7u7u1atVoNspVLJtra2bGdnxxqNhjWbzfBVX4BkrVazSqViu7u7qYNYhDy3t7d2fX1tOzs7ZmZ2fX1ttVrNqtWq7e7u2vb2dth3/C17T9cFmdl3zWbTqtXqyiD5nEzNZvMJwCdJYjc3N6ln1bPB3tB51K/6/8jDWrIenCMMjmVBctU1AgNubm5sa2vL7u/vbTqd2s7OTuqM8PKfof93f38fvufsVKtV29vbs0ajYdVqNbzvOm7/c3+zFEguM9SK5CsTdXd3l9L00+nUptOpTSaTYKUwARxaNCVf9X1jE7uswKvI4w/gdDq12Wxms9ksaHQW0P/d3d2dXV1dBdCpVCpWq9VsOBwG4AGs7u/vrVqtWpIktrOz8+QA5CGLbjjk0bUYjUY2mUzs6urK7u/vzezRalSFoLIoCHlLOG8ZlpHPzKL7xu8d9hdWdLVatevra7u5uQleDX+vf1fk3prNZjaZTGw0Gtnl5aWNx2ObTqd2c3MTANTMgqLi+xhY6jOrvPqz585RXnIiq5dzPB7baDSy0Whks9ks4MDt7a1tbW3Z7e1t2J/j8TgYR+y/RqOROn95r5GO3EDSLD0parFgtcxmM5tOpzYcDm04HAZXTjU5Vgvvo++rr/l8HjSO2dPDuMlkZS0uYDIej20ymdj19XVwQ8vlctiILDCy8P8sMgB5e3ubohq865TngnNYWAs2Hy8O53Q6tdvbWzN7AEm1NrBKKpWK7e3tWavVsnq9btVqNWXhL2uV5D2ygAD5PUAC+tvb27a3t2ez2SzQPwoseQ9VpKwF8z8cDu3i4sL6/b4Nh0ObTqdBkca8Lw04qAvOv3VO/L/1eYoYXhnMZjMbjUY2GAwCXYXMzD3zjlFyeXlpW1tbKRCsVCrBi2Futre3w2fmve9yAcmYJeXdOrQC2uPy8tImk0k4kBw+hGaydGH1qwIQYKnadJ2F93J4rkhBZTqdBqtL/w65+T89cGoFMC83NzdBnq2trfBSF3eT4Q8IFu5kMrHLy8vADSMXa1IqlYIViYtTqVQCDbK3t2f1et2azeYTfkvBvuiRpUBjoMA6oZRQ3FtbWzYej1MgiTVZBFAqeFxfX4e16PV6KQAZjUbBY2Ev4PbrPjF73HeAqV/3l7Qk/VroGRqNRoGL7Pf7AQdms1mYdxTAeDwO8qlsKOhOp5Nap7wNC8bGIBlzr72Lra6dmtnj8dju7++DpiiXy1apVJ4sXMyC1BE7kOp2rCpPlnuKNYm7rQEN1d4AvVqKLCILCQDzc5QE/FlRAIksHEwslslkEoD/5ubGzCxQBQAk4FitVgNvhyXpXe5VOeN15fOyZlmRfPWgAeBUq1W7urrKtCTzBBOeTT0sDAeAkjW5vr42M0tZvco3Ms/sM3XNAc4sK9LPX95rpWfIW8yDwcD6/X4ASLwyVWTX19c2Ho+Ddzafz21nZ8darVbqb2LWcZ5jI5D0rowHFc8/chDhJIgI44aaWdSN4DPUQtDh3VT9+SayeauYFwvjn0F5RjYdQIoVQECE34eEhmPJa8G9YlnEgXEg4eRQWLw0iKZRa/2eYENWUCHPsWhuYv/nP1/3EutyfX0deElvRRZpbfnzwks9DJQUXggBJqxJ9XjMLHhniwC+KEDxQz1KXG7FAixITzdhSCATONFoNAJ3rmvlZc1zz+ViSXruAdMazTEcDlNcHgLCC+GiqbuJi4froAedz9VDqBHmWARu1eF5H6LRal3BhXjrFquNhTUzm81mVq1WbTKZBMsL8Gm1WpnWy6ZA7zept+x56WY1s5S1olakAqNavupm/5rDHzT/M7PHeTGzwJkrNRILdPB3vF8eIwvQAW4CZH6f4XmglKB38Mj0vZ6bp2V+d13ZvAXPXIMRNzc3wYXmbKEI9H34ffYt2KGU3D+lJeknQAFyMpmE4Az5UbgRGvAABAgI4Mbx0sgqWhf3Qd05tSA3XWxvlbJZSTcgsAGoYw1rpE1dHH7XW9fI2Gg0oi7eJgvuFZfyn95iUeqADQs4bm1tpXhIAjRZABl7Dv2a5/Dv6YNd/NtTAAoOzDUgqa74Itc0j+EVsKYeQeFgKNTrdbu+vg7PgMIGUG5vb8OZUwtV5dXP83NRNCXivRoMHcAR0FePUvet0hOaEZLFt/5TBW684PAOl5eXdnFxERJGIWoho6+vr4NWVI1Jcm+z2Uy5cWaPxC0pAn7x1QLypvuyQwNAunnr9bp1Oh0bj8c2n89te3vbhsOhXV5epvhF5kQXr1QqpbQnvBcgpS6epjPksSYaQcXFBhT5ioZWzsdHshUgsYKxpr0F6a2vlxi61gCLeiQ+KsxLnzPG3cWCHXkcQH3WnZ2doCzJWyXRncRp5d74O+W2r66u7PLyMng3nEX9Pf0aUxxFD90PfC77CHn39vZsd3fXzCyVDUOCuZk9UWSLUr7ykmstkIzxXWpFDgaDUKVycXFhg8EgvMbjcdCKgJBGSjUhWXkYANJrSU2cZeNvOnjfnZ0d29vbs2azad1u125vbwOowysmSRK4PObG7JGLITVI+UxAK5ab57nOdYbnu8guACQJ0ChIY/GaWcrN9pajgqMewhg//VJg6S2lGG2jX73y1DWLWSd5c14KVru7u8F6LJUeIresibqVgLgHeazIer1uZhb2leeVUfo+g8JTEnkMvxf0/VVuzkez2bROp2OtViuApKYLjUaj4OXo+8YAsoiRiyWpxDEgeX5+bl+/fg3lRkrWKg+BK1uv10M1AEnKRO+wiMwsFSXWQ2GWzlVc15XwFkm1Wg1lkliR8EGawhGrIvALiUXJVwAsK9l3k7GIi9QXViQKgLnTqiAFxpjLqnwxaxDT7EUM715nAaVaT6ydfy4PhrF1yQtImOPd3d3wGdvb21ar1Z7lRnk+0rlGo5GVy+VUDvLt7W3KUlTZs+iHPEaMBsmiFsweykM7nY4dHBzYwcGB7e3tWZIkNh6PA5BigCmo+/NVhMXPWBkkY/yCHkTVALjYl5eXIWSv+UwApLpzmm9nZsFK1cOnWlG5SE9cbzJ0E+vGZfKp7R4MBk/y1dQijPFaMb4rLzCJgbQHZTYdPwO4zexJtFDfk/X2G1VlU4D11RBF8EUMT7/4w+m/z7Im1TuKzScybCKLWlQYCmYPXKPSLvpc+py6pijo6XQavC9NE4pxkbF5KGJdvALDOyH4h6fW6XTs8PAwBZK7u7t2d3dno9HIKpXKk33n56ZIRby2JelTY9SNJFlZE5QhYDVQQz0pE6Yus09rIMKnkWXcQhYhL17CW5MsrEZBNUqtgB4DPH94vbVTpFb3YBnL49RUiq2trbCOGmlUhYiF6SP+mhqlLhW8Ms+0iYyLDkMWsJhl74MsZaVgWSQ3qYo9ZuF6vlfPmpk9qYuP7SEPhi8VtDFLZ0nglWmlzN7ennW7Xet2u9bpdGxvby9Y0ADkMnX/RY6VQNJvlBjHluXOJclDbbbPsVOQhGNRYNTGBYArGgn+IgaS6w6v/ZSfA1Surq5SjRNirjb/1s2o6TTeevZ8WVEbQg8/BL+mYwFygCl5ndr8wrvhyMAcYe2zXhrEylsuD3Cep1KQi/1tzDpR5e95Ss97b6KECTKiRLLoIm848FwYEWQn4BHEqAT9XP0+7zXx7+9zgTudTjg/d3d3Vq1WrdPpWLfbtXa7bbu7u+GMPZc9oT/L+j6PsZYl6d3s57guALJWq1mz2bRWq5VqjoAVeXd3F/L1ZrNZqjmE70xTq9VSm39ra2vhgVh1eGuSz729vQ3g5jua8FU3qQIk70E6kfKvCrh58l6eMFcuUYM6KCb1BlQpeetROxgp6GNxl0qlJx2ONrXwY8NbfApwntNbxO+p4vfvwQuLP8a7rTNYFzMLeZBZVqSCPlQP1Su8KNBAUZFq45/xpSwyH5xqtVp2fHxsW1tb1m63QzZFo9EIfQC2trZSAMn86Ijx0EUaF2tzkjEXm7QSDp3mQu7u7lqj0bD9/X3b398Pbc9KpdKTPMskSYIGUgtGgzyeA1MNmwe/53kjOBJfPqgg6bks71rz/BTn43r4qpU8FnwRJ8fzMeesmbrS0+k0fO8jo/ycqHe9Xg9dgRqNRnDbAdBKpRKsyzxGzGJXxa2vZcDS/9xXV/E9ilgP6bqgr7ymz/FVC5ffU4DU/oza25QGJdR75wnoqw61Itkj+/v7Vi6XrVar2Ww2e7JPdnZ2wjrGXOyYAont87zBci13O5aD5y1INpaZpQCCSFar1QqTQqkS78PfKVBihdXr9VTwRl2/rIDIqiPmcutzxHIEYweQ91Fg8b0KoRt87fO6I+ZSxXIHmR/AhHXjeWezWfhb3YD8PXOBLHgJ2rCkXq+nXMBVKkKWGR7gYsDmrUq11hTw1FLTnFzmBx5cFY5319cBSob3SPh/b+VqLnKv1wvB0V6vFzprcf4WfV7s33kN3lfpqlqtFjwqckIxKBQMoX1iNIhXSv/0lqRaf8prwY2oy6kTpd2gt7e3QwoMbdM0Cl4qlVLgtLe3FwBSrZTYAcjD5TZ7CpRa7RDTdFkuk0/Q1q/QDeoOb7rYfhPFgNIDsleA+hy8F1+RhaCPRsVZa19fy5r60rNNhudYY9ZkjJ/UNfL/r0CrFqnnVvMAfG8xZr2PWv4YFdooAncbKzKWpRD73KKGAqUGV+lqr4aUUnfqpXpPIAbyRYIjY+1kcgVJn0qiViQP7zuOAw5oE22nhDWp4KJpALu7u6mWVovyyTYZWSa9fy37Pqr1NE2jiOi25yI9l+iDSJqHypqYPU20Nnusaoq5mqVSyarValB2CqJqeXsXc90R826ywC3mZXhrVP9e93VWnqjKsY41+dwceEoBz42OWvRFoGkyhor/219jeM6Vfa8KDHk0hZB+AqpkseIXgX5RgLl2dNu7Jb5yRB/cWy4cXN5To3RMEIBYLpdDxGs+n6c6tiw6BHmOmAvrgYFRLpejrk7sQGalmOTxvDyL51W1ZpzqJtJ0tAuQgqOut/8cDi6HoVwuW7VatWazGUrKrq+vA52Q10aOzaUHN+92e15S5eP7WEAL/lkDa34/rAuQOvzf++ASMmqKnQKkNq71iuzXGuqBeEPD7LHFGxQCNII2HcbbhBP2LnfRVvHGlqRPVGZjmj2a2moxmT2ml5RKpSe5lVTlaHds3LokeUgy9Romi4jPa/j3iy1SzBr0fKVPvPeHOU+6wLv5AAAA2Wq1bH9/3+7v70OnZw4YVIl3Y1FU/IxnBFhxiSqVSmhMTDCPu478nK2ywf2ceL5Ou8x478a73Z439gCpKW1Kh+g6M89mtrIsq4wYvTWbzWw8HofO3qPRKCil29vbkFKnz/hrDW91K794c3Nj4/HY+v2+nZ2dBX717OwsdcmeKuLnPifvsTEnGdPcZmkrRvk2LEcOlprXaEda1ptZsHC4F4bmqB5c8kr9yZI3i9Pyh94nBCsBrZtcm134vnibyBF7nhiY8JkabdQ6YWTQAAbZBwpGgL32zaxWq6GIYDKZhM7lsYYL68ibZWF5gFSXG1n8PvGuNnOj2RqaAhVL3PZBhiIOq64dzUoo1sDi0i7mAFHsOX8N0PSfnSSP15wMh0M7Ozuzz58/2/n5eehaDtc6nU5tPp+HhtwvbR1v1ODCb04f0d7d3Q2LhauCtsZq0aYLCpbKSVK2tbOz8yJ3kHg5s9xjs3QEfmdnx25uboLlpham2aMFrZZKFrfK528SEIhpXl23JElSrbg8me7pFH1mDmqpVEpxTF7p8b2P3scaTaw7PP0Ti3Jreo+uqwK1VyLsU5LkY+lA6tYWaUkqkGNJqlHhA54+WJc3572JHKp8SWei18PJyUlIZ9LrRLyr/ZJjo36Syu0wlP8C3BQksSSxPGJJ6LrYRMa07C3LkstreLfdu3K+/E6DILgQgKlyWWaWslR8Fx5/gNcdniP1silIUgWlz+BfBAN8H0oSxBU8vHwxkFTvYlNZF7ngWRTMMhHuGOgq2BbJgWfJqeDiFRZ7jrPisyfYg0XSAs89P1+VOlDagHvs+/1+6hZFVWYvxUPqWBsklfPSnEVN+iaiaWbh7mIzSzVTiFXosBk5RPoZmtCc1eklTwtFLWW1/LB0KdJvNBrBwlCQ5HcAe9wMv8E1TSaPkQWUyEQ6ValUSqXyMPc+NUOtR+UY1bK+ubkJUcz5/PHGu9FolEpv0kj7OjmhzCPfx4ICq2YMaOCJr1nv5+e36OEBxkfyvcLWGmlyWLVC7KVBxntjnpvXfcINBhqw0TXxgaDYPOU9VgZJz3ep1ajF6YAb5W4soJmFxdVDp73z9AD48jftGuQ7B2VxRusMXUx1b/SqA/g3IsQaVAIo0X4ApYIkVIO/oS/PxdZDrcm9HsQ1uEGlVMzdpt8hVULM+87Ojl1dXYXATalUsuvr69DGS61XrQFfpXOTgiP/Ri5Nb+LFc8VyWv33Oj/sJ//V99F8CcBRgMzixllDWq2RtbC/v28HBwfWbrefdAh6KaD0z+wDvZwt71FqQrx6H89ZxEXwwmtZknrYAC8OnFan1Gq1JzcKlkqlMAGY2kyScprafYcSPuq+W61W+JneufJcQfwyw7vZGkmkooHqoFKpZLVazQ4PD0NASbk7H8zimaiaILHW3/yWJy/J36ONWTdkJAPBB2y81kdZaNBJ579erz/Jb726urJ+vx8UAjy01uFrFHZV2fgc3YscQpqnqAWlWRaanKxZAChkCh+wxGJ7bFEQJ++RRR+oIVGr1UKzCCrbjo6O7OjoyNrtdmgk89Iut/KpMX6bfeUDsWaPnD/rAraoDLE5yZMjXgkkvcbFtMck1lw8LA51o5kcvqplpt2CeC/N5+Oe3U6nY+12O9Q/U/vsm0RsshE8SU4zUzqrw5WUSiWr1+u2u7trrVbrCR+nVqVuBOTc3t4O+W2xCHceC+25HF07s3RLMwVoH6hiA7OW19fXqTVotVqpBGDlMYnCqmKh3pvg3iryMOC4aQ4BJ0wuoe4LbcWnYMbPdF58E2ieNQskiwAdbznGuFWeG3Bvt9t2fHxsx8fHoYltt9u14+PjAJKar1r08DJo0NL3efCVW2aPxph6CNo6jc9AmSt144FyE3nXdrdVe6tFotcxqLYg2qntxvSmPlw8PkML45vNprXb7QCQWJKq5RUkN01a1olnQWkkTPkkMtOwAh4upgCQ3cxS7kapVLJGo/Ek7zNvboUN4/k/1tG7b7FD6QMZak0CetpDdDAYBPnv7++DMuUwcx3Gujws64pCZXiQVEogBmjeGkXBqyWpd4t7SicPascP76L6ueffPD/ATu7ru3fv7Pj42Pb394Nh0Ww2Q1xgHR54E1nUilSayRtIaiQwn2CKV0w+oEr6Gfs8K2d5nTXa2JLkUAOS2ndROYdyuRzy8iD1vaUZixgDQu12O1iQWDCAox6EGP+0yvCgQASOMrDRaBRSEjg4eoERv6+Hkggw7ig8bblcTnEwKIoioqYx11tB8rmIsD+4KDq9OAw3iIYQd3d3oVTu6urKtre3rd1up3Iy1wVILwtAyVx7rtq7xP49lPOK3dzp7xb3QZwiAFI5ce2LoPuEZ9ELxdTrogUZz5/F4xVhWXpZ2AcYDRgS2g/TK03lnDF+Ylap2WNjbq/E9MV7rjLW4iQ9ULJYJFITCCBqWiqVwoFhc2qESxddJ8XffwNAxjiiPII2HhB0k2qy++3tbeomRwAdCxiLyVs4uvHn84deepp2UwQ46vCuqmreRQDp54jMA11LTU3RLAZtwxY7EOvK7GUxS3db8mWEvnmI8pL8ju8qpbxprCwxD1fOD6+gAQH1unyWhc6Jcq/eYFDQUks8bw6Pz+KrykIUm7t4KKlUbypGNenZ0WR6KJbb29togI3vsxTcMmOj6DZgFgNJJgZLhe7ipAQhtEY9zdKNOkllgJPUPozq/uQZsYtZkroo3NymysBbNVgzmjrjr29lYX3JXNFAabZe70O/Yc0spDz5DkZm6TZslJhqIEvd+jxkAeg9QCh/5YceGP39RellRaYBqXJWq4u0mMFgkOr2o3OpQRG8N21crZQDn8VZ1L2btzxYfdfX16l8SN8HExpLZcJqVCXMOez3+6HpzXQ6feJVamaCz3go3JJULarpG2ww5VFYFLWs9DB5QFArEpDEgsTdxn3AxY1xTXkApQdJ3O3xeBw2JNFpqmzK5XKqDh1NidZkc+Nu04z2OevtpUfWs2TNLwc7Zi0DlDHuU/8+jxFzqf1zxn43BqoeGGMcV57DBzjgtWmH5pOtNeCXJElKkV9eXgb6J5YtgRfDz/Qc500bcIYI4PV6PTs5ObHz8/MAkufn56H2HC+DZ2d9uHMJwN/aemi0PRwOQ2wCDxMvE6qE/0M+cGbZfbcyJ8lXH0lSl42FQVPd3NwEZI+Zugq0Zo9NLUj/ASDhWBqNRuhk7F0KdR1WAZ2Yq6kbljzJ0WgUDn6lUrFGo2Gz2SyAv27UwWBgvV4vWKBod5Ku1YL+ZwFHVXJ6cLzb5pOC1YrxqU9FDx+IyeLd/N94MPTt67zcRQGkPqOCymg0SgHjxcVFaAJB4wc6+ZMpcXl5GSrdtHZdnztLWejP8hjKX2v54cnJSSg/HAwGNhgMUvLovuH8+Xm6vX28b5zgIcFcBU0uHVM6RfnJZcbanKQPAMQARkFSuYKs9+N7OCHt4k1+pKZjeMs0j8WNEedsWrghuKBqtWrT6dRqtVo4TKQLqVsBQa2utpk9OYCLLJ+ihl8zdW88T8zzsfF9V3otDPCReqVo8gx0ZMmkwweHvBWpgYGs3Mcin5dnVsWM5dXv9+38/DxYXhcXFyHLQhtaYHkOh8PACWvBAFaYz/zwlENesvDVW8aDwcDOz8/t7OzMLi4uUjXaUAjkVXtvQ88kjTGwHtXjVGqOvF2uGsH7XMWAWsvdVlJViV8f+CiVSiHi9FyhfWzTKlDyIlKnHEtRm1hdBm814QIAnHy+9vnDPSdHUDMBNMiQdQlYkRaL/6r5kJ4n5Vl4HuWHUAC4f5oYzx7w65/3mvm9pz/j3x4o2WM6/7HnRHY/B0UMBQKChaPRKPB3dMfRumYt8VVlPplMUmeIlDw9O3gzmlbDyIOyMktbk57bBxx1z2gqoIKlvg9yAnx7e3spkFVOk/QozaZYldZa25LM+plOTlZNdZapq/+v/ANJpLGC/SI3rqcX9HPULZpMJmERlWjXDaAaXfM/oQ58Dt5LD3X1tETUUwL6e/C02tMQmeGVsE6URNdAT55rl0WZ+KG0jl4yl1XaalZsrXbW/Go0+PLyMrWv9NB7qguF7vMItRGJ5lxqAYGn0fKWU1/qrXjelOGBkqwE7dBECTBnq1qtphrHxBrjrDLWbnBhlgYRtTiwlgAGBb+YGxNbDI2ea3QuZpXmOfzzajoIRLDZ40Xy0+k0lOBRagnHQvkiBxWeles1fW1tEdfKZskYO5jKv2KpaMMRPVxaz65Ws17qZmZh7aBJiEAu8ipWGTG6QF/8jsoOSGt1GGsbSxgv2tX2sqiLqjmFNH7Q2mb16Hz6kzdQsuZI3fKsYN26Qz1EzX0mtQ+FQEZMzNXWwJ+CJH9jZqkmLQqGeZyljboAaYAky7r0QLhIM/tNqZZo1sLnxRl5oFdrVruqEGVDC1ItBLUAaIxGo7CZzR5vD9Ta2v39fTs+Pg61tbVaLVVn/JKHk2ohqou007UGZNTSVPfJJ8UzJwTYsJxJ3/L82KbDV6d4oDR75LuTJEkpLMj/WNbES64BX2ONVbQ7Dq4pMgE+WFEazNDrijX9xwOkB1zeO+tsLxoxQ0Pv3u50OqmCCgVwQNNbfRp84nywtgqO/L83rDYxqja2JLMmkZ8/B5QxV1Z/pikYMYDMc2QtrkbZqRoiCANIKCekGp9UqEqlkqqtPTw8tG63awcHB3Z4eGidTifVwTsv8I/JqMrALO1qE4XUO0bgXeG7tAKEwI0WBXgFU61WU/X2PjNh3RFzrWPupFoV5XI5uP2+UAFLN5Za9hJg6bk3VUTalJq5NntMoNdmML7wgvmOAZIHSgYu/CZy6zkina/T6dh0OjUzSwVeNadWrUl9BtZBaRQKO1hjPISsfprrnKuNQJIHR6DYz7P+JutB/YSoVec3bRb4biqPfqaCZKvVSmk5AAIA0YRx3ziYe2T29/ftw4cPT2prW61W6NLyErykd7nVhSZYQJAArlE7IEGSq+Vo9lh9BODTSk6zEwCivCxJ5bR8CpNaJLq2GhTE6lKQLIrOWWbEgh1KgajVxB4FJMkG8XXnWsgRUyz+XHmAWnV4YwMevt1uByuS34GmItjEZ2bxybwvbrgqQXXreemZWodz3Rgk/QeqC77M73sL0ixt7qtVmQWOecqgJruv5eUuGEDCzELPRIBSo3RmD242bt7e3l5wsw8PD+3g4CDV9u2l+v359YlFVWkNB7/KTXa+a5Hv2gJAYrkQXdTWZf4qhzxA0syeuJCerFfPBGsjqzfpSwdvVBbvempLMeUidQ49xwp14C1jPaexoI13c9cJdDA8bUVesbZFvL29tdFoFJ7TA5kHa6UFFMzN0lep6NdNOebcQJKxDEDG/ibrZzHLsuiNq58BUPpaXrSyuqnauAG3k7lQ110j22xm30DhJYcP3uhFWFpppFxlrISMZr58NXtsVsK8MXe+jHGTZ1cZYi8dPi9QsygWdQt6qeGDUJp+pm32NKdRSylVFt2rWTmQy87busN7ZVi8V1dXwXJXOkCBTJ/DP08WkPP3ebazyx0kVxnPPXQMKNd9r3WeybveGkACyIhE6q2BsUi/ajksFm+55AUc6wzP6flaYE0YV0sSHgu3h3IxqAa1yLOCbkXIEjvkMQon6zC9FA/pn1uH5wt9ExSVRfcocunP1CPjs7IASJ8nLwpL59srTV2DZYbSKzGaT+XNI8C7FEjyIMPhcKk39WkMWCIEAWLla7oJfGMJ+DAaSigJnQWg/JtnjmmemDz6+QRhNN8xFu3FBQIcAAjd5Ao6uOW0kNM0iGW4sFVlWrRGtG7TXpDaccZf0hZr2cX7IYsm3bOOejHYZDIJzXaxwteRR2XQwJnypx7QAXWvDPQOH6xn2r0pPbDsWGWNfHaBplNpgrRG62MAqm3VaE1HFFiVuVqb19fXqZLhLKUxGo1WWiMATD0TXRNtzhyT7zlrNsar6vlir2lKl5mlztl4PH4iU9aHPTs+ffqUmNlv+vXp06dXJc9rlOm1yfMaZXpt8niZYqOUPAujD1rr8+fPoQP3b2kkSWKj0cg+fvyYcpF/q/KYvT6ZXps8Zq9Pptcmj1lcpthYCiTfxtt4G2/jX3W8fJHw23gbb+Nt/IbGUoGb37JZ/a/iJvyWZXpt8pi9Pplemzxmy7vbb4Gb3/Drtcn02uR5jTK9Nnm8TLGxlCXZbDbNzOzTp0/WarXCz5N/hN9J4xgMBvblyxf76aef7L/+67/sp59+si9fvli/30/dYaGpI2aWyu3a2dmxWq1mzWYz1DVT50xzCC5d73a74dbEWFqQ2UN6wjfffBNkiMmTSDqB3i9yeXlpJycn9ve//91++ukn+9vf/hY6KpOiQZ/IcrlstVotXOn57bff2u9+9zs7Pj62brebun6C6z3XLX9bRqbnBrKydv1+337++Wf78ccf7c9//rP9+OOPdnp6Gi5v29vbS13r2+l07Pj42N6/f28fP3609+/fW6fTCd2MVkmZWVce9p+mgGgaUL/ft7OzM/vy5UuqozdpTpqOoiWW2lmm0WjY0dGRffvtt/bHP/7RvvvuO3v//r11u93U3eH+epLRaGTffvvt0jIlkdQacgF9t3I63lM2qtcXsydvbm7C+2nfTG2CoU1qqfhSWbS4YTKZ2P/5P/9n6XOkDTr6/b798ssvARd+/PFH+/z5s/V6vVQrQarUzNKYQHOYbrdrHz9+tO+//97+/d//3b7//nv78OHDk74Hy1bWxPZdbCwFknwY9be6mNq26O7u7knnEc23IqFUq3I06VOTTH3bJ03E9h3LF4Gkl8HL02w2n1SckFN1dXWVyiGLNd0gD81X6Wgir7Za01rhTbvNZMm0KkhubW3Zzc1Nqq4amVkvnwis66BNIhqNxlp5havKw5r5HDnWg9vzmHeeV/PytM+g7km+ZpWm0mWn2WwuvPN9FZk8SMbk414aveZYc1Z9mShy+b6slIn6ly/NzLoeJSaPP0fsqyRJouWesURv2qHx3lnJ/34taOThz9OyZ+q538ut4ia2yPpzrQLQwnXNjvft0PwB0NsGNXFbP2/VZ1bNR/I7iai8tMtNrLeibmYsUV/aRyNQ7c+4zjPnMVTu2IvfWTTWaRSQ92De1WrhpYnkWIZa9eQVsa9jjjVXWHWOVpEj9n1MVr/XfFNdlYPhq8Z85Y2+j1btrHquFCD1tkfta+BvzPQjpli90tAzH2tmgnLIa+Relhh7eLN0s9OsfpCAKE1ZzR7vbp5Op6FZgt6jsinYeHdGK3xo7uDvBqYTsl8ULBjcosvLy6A96YRSq9UCUBa1qMvKrevkQUKBYFEdvY6XkkGfXZvTauduqmZoLWZmYQ18R27/3mrN+J8vUix5yKXyxV7sVV82qlQW8sV6EKghwp5Vy089vvn88Z6q555bG6Ros2DOAU2oMThizUdiVmps/lVJxPavzmUeezJXkPQLqhpZXU+/UGbxqxuSJLGbmxubTCbhd2g5BmBtYpX5g8ZFSrTKh9Pq9XopoNTNqIunLff9tZ5mFprPcptbFhgVOfza+P6L/h4QDwJaA+zX7iWHUj1cCsXte9wBo63rcDu9hajKW0EDmf1hXjQ368rhv2adI7X48HiUh4SL1L6eaoD4phdmFoBNLUt+f9mGK96i5xxxli4uLsL92lw9AVCyBmZx70SVg9+vvqSZecq7OUkuIOndBV1UnQh6+Gn35+eEwWrDZapUKtZqtZ5c+sNnrzI5/qDppUuDwSBc4wnhrzce6mfrAQJs2YRs2lKpFJrPttvtlDZ9qbHo0Pk7UFQ76+b1HsBLdyxCDrWAUErcwsf1pNq928yCh+J5K23PpfvWW5K6zl7BeUWyrBz+77MAUgFC6+G1xl77CPAcSmNpzTbyUruvMimPvsz+jBkbnB/O0/n5uQ0Gg8CjKhesPSFjSozPiCl0bz0X4Z3lykn6xdVBDz9ahWn/OIbfEP7wlkqlVE/HmNm+6jN7lw3Nx810/X4/WCcxYDZLL+Td3Z3NZrNglaDVt7e3rdlshqa1flFf0lX10WB12Tz3mjW3nip5qaFg4qO+XHx/eXkZ+n1qd3TtMRizWu7uHu5Sx/LUz8yyHPNUcrEzFFNqvpEDVqQPRHlXWwGSNn78jRo0Skss47Z6d5smHQqUeueTdo7S/pCx9/Rzw5mKWZBZFMqmY2OQ9FYkX3WRzR7dbSJT9Xo9bFgiiwiJ2W722K3FzFKdZTa9AY3nVMCgU8nl5aUNBoPApaj2M0tfUgYYMnj+q6urlOyNRiN1k2AWcV3U0DXx1ggumzYNjrXmYvzaicP+wGhHde475y6iWq0W7qFWzg2QZR8BkLPZLPxeLHiTpwwqS5abDQhoRB6Q5Bx42kkpBN8yTa1m3o+5AbB2d3dX4vtj66FBGzVq+CyoKGg1hioGjdTr5+jvKW7kRYH4kXt0O7bYDOUb0fDqcqtJra3mb25uoqkImx5WnWSNoPucM91IdBrX6HsM8HCxZ7OZVavV1AVOz1lpRQ2VFTn9lbBEIrUvZux9lj08RQGqP5ge9K+vrwP3tre3Z2aP+8/s8X4VBRH9ahbvLepd9Kz+k+vKnUWJqOWv/8ZTYT1UHvhHvSrD7KnHliSPbe5IB1OAWpZWyQJ7s4ezQ/rR/f29bW9vp6gBfQ8sUs6kpz5iI5ailNcohJP0B0ijZLG8J+WIcAE076xSqQRXm3yuZXsvLvPsvHQjovnYQIAjbgibUwllAEg3Hwse0/gvBZBePnp8YjGTnHxxcWGDwSC4Rciiz6vWpcpZlBbPkkfl0mdRF0wPYIweWaRoAY1YLt+irtebgqOXxVuOPuGdz/RBUb4nV1TPFp+nCtvMrFqtWrlcDkCmrvbCLjkZrj2g6M9RvV4P8njs0BgBV+kSWIoZST5/uQj6ZyOQ9IciZjZnWZgM3FYN5LA5CPLc3t6mAh+xG/c2mRzdmL4BKO4HboGnEvRv/FWrMa7k17AgzdKXS8G9np2d2enpaQhQnZ+f2/n5uQ2Hw3APScyN0fnhvbP2QZ7D7yf/ec+lKHkrza+Ld+H0vSh2YM/ldV2pPr8HSRSrXi6n99wAPJraw+VXWI/6bH5/K+CWyw8Xa5XLZavX6wHAPDhlDQVILv3Su9drtdoTGWLBL6gq7lfa3t4Ot5FqYYMqKzW0igDKtUEyZj1mHX7vPvgIIq43VTpmltogaM1qtRpuFoyVIZmtp8k98GmUT91rXQizR7BQd498vVKpZFdXV+F5vMX1kkM3IBd9DQYDOzk5sc+fPwegVJKdIJV/n6x19LIV7WrrVwUyDszt7e2ToCD7SKPEvvomRp1oSgwgoNcOxHJ+15HFz68GaLSbt16PoSBBHi7ZI0SN1SL1L9zZcrlsNzc3tr29HdZeQXJRvqT+jt4uOp8/3HfEffVq6cf2C4YGJaVcFZskSei07q342M/yzrgo5I6bGOGt/JE385lcgA9iXVMacHW5nlTLkPLQHP7vWQCsBr2wyOfUoRkJHqjWfo6KKHr4wwdIYkl+/vw51KOTC0r6zHw+fwIE+l5K7nur4CXkYniOkLXzaT0xK0otM6wn72oqX+lTaTaxJBcBpBoJCpJKA5k9BqOw1rhcjnOEZQZIam23VoBRdURlmALxMiCpLj93bHNmY0GgWPaAWpG1Wi2AI16ZL2/29+PEFFYeo5CyxKz/85aXRq71NkGsST2IuL2xu4RjRPsqw/Mp8DhoUqLxXImqEXm1ImkGQbSVzaiBKb7+GmDJfCon2ev17OzszPr9fqrZAPOtm9EDpb7vS8m16P09R6WHD8qmVCqFw4hVhkWlQMnaxzh0PZR5HcyYq6137wCWWmbJXjWzcHa4P5yf87cocw1sKe/MGfIVYWa2lHWmlnylUgnvyff6O76+3Sx96+je3p7N53MbjUbhrF9fXz9xqTVqXxRAmr3AbYkKDFgyGrliMdSlqVQqwZpEA6m2BKyea2qxzNCDpWQzrgfRUQVnTf3RlAc0NyWUuiEWzc1LDZ9iRQmmdsMxS19Rqk0isE68RflrgH7WUKAxswA05XI5ACUJ1N4608KFmKXqD+a6hzIWrIgFa/wNlQpq0FR8tr99E8vQzJ4EQ3zyOdTSzs5O+Jlafs/J5xUJlU08o1q9MctbXe3d3V2bz+c2HA5D4xpwQSk1da2L4iIZuYGkCuC1jk4CN7QxYQCSTioWDENBjI4lm3bQ0Wclgo6b0Ol0bGtrK/BaWJK0k/KWLt2CSqWSXV9fB/J8Z2cnkOI6Ry89fMBJ3TlemiNHwEy5N42aqpWWlZ+m/85T7kXBGC1A0CwDrBT21Hw+T1loNMLQpGzlVNXT8JbMJtaLzpsHSH02atHxTKBBqtVqeE7cXNZH09r0lkp9LygHs0er1N+nnjXvsXVRa1sNH/4/CyQ5R2QU+E5gOsexCLePR+R9znIBydjDxSbBzMI1qmYWFtpH6wAtry201ZNvxbbOxDDpWJDNZjO4x/V6PSwcn6kuvtmji0AS8nw+D+6CRhdjiuOlhnfx1XJRSxA5aT/XbrdDwj9/rwEts3Ti70vykTEZNdiBa4qS0+tVlS9WC5LvlQZS5eYP6aaunQd3DdTo1baj0Sjkr2o5n29lpoFFlY9mLSTa814aSGHtsyqunht+fhQY1dVW15jfVw5Y3fplQNl//jJ/t85YGyQRMIbqfuOoq62HFb6PDa3aB4vSWzEAJSC5yWbVRSU3zMyCRsW64nM1UKRBG6zIu7s7m0wmKbcHVy82N0VbllkA6S1ATcNqtVr27t07Ozo6smazGaxhIvdXV1cpNy7mchcNlFn7y5fsXV9fp7hFzUogMOitT74mSfLExVu0x9cZWRwkLraCG8E0jAXOAWeD91MvgUAi6TSXl5ehZFMTuSmMqFQqKSWx6np6Y6BcLqdA0rvGZo9BXp2LrKwJP99Fudd+bGRJxg577MEhz5V7MXtYHO3uzUTAhanb6qOKPpiwzlA3XnmTvb29VMWBdxNY0Nvb2/AMKAHfpVotST9PRY6Y1eg3IOugfG+73bajo6PQ8Zl0muFwmLJS1AKKudsvMZTTUnm1Sw6y67z4fEidD5+RoJ+1CChXWc8Y/RGzImk1RsMOLEj2FRkXgJwWLVBNxd9rd6vZbPYkz3JraytajroKQJpZyhJUAPSKSudBFZwvufSAvcwz5D1y5SSzuAKzdDa9pvTohJg97QbtyWi1KvMgbFk0uBwoAB/d089QDgUFoG7PIrfsJbRfVqTUN7FQDY2FwtUZXJWxs7MTEsun06ltb2+n/tZbkkUCZUzJZHGDfq2yAkw+QIEi0Pfx0diYRbPqiIGkWpGA5Hg8tvF4HAAfqsnsUXmzJmbp/quALKldvHgvb117zyCWyrdoKFBmWX46stKd8AK07ZtG4Bd9dhEj9+h2FiDopjCz1OH1mzTmYi/Kidr0WflMDoT2HfSLDSiaWao00SfKei7mpYM3PBfgqF27dRPGcgLhaHHnyKODYlH59XC9lEXpOTB9Zq3SMrMQJPQWiSowD1asXSxwkCdVwrypFelBUoGNoTw+SkCtZ7UiNVij6w63HPMCVcGvOnR+/D7Qn6l3o5kWKAXtLO/zcV96bAySWZxNLAJo9jTlQa01tSLVvVbQ9BHGTS0z1X64CFQqxKK1CghsTE2lUVKc9+e98yihfG7ErBMOHC6Xtq1SKz4WTCiVSlFLlMOd1bAjBpZ5yqwASY4gFvDd3V3IhIhd9xHjxpgvgITf89UcMY9p1RFzMxXc9OVBMkmSVEoWLjOVKvCOl5eXNh6Pw1prsxbllGOGiVrO68j4HFDq+dcqNfoI0KZQO3BpIO2lR67Rbb73bgqbmUPn/5bfzeIdFwFkns/Pey4CSL7HrdFGEX5hlWMt8vl1eF4OcOQqil6vl+oSDYiQ46m8GBkG2vIKS3RrayszGpoVFY25YOsM3TOkvjSbTdvf37fb24cLwJrNZiq3UK0RzXXkeQj24KICSsov57l2mrOqHZmUP+R7fR7OEMEaWrvd3NyEIA/cI9dZAJSq4ABGDUrC/2vkfFPL2f+dniXtBUrvABqtnJ+fh9sgCTQtqvopcuTqbsesSW9R8nu6Ub177QFyUeJonkCpBHMMJHmpi0DVijboHY/HqatmVc6i3e4YSNLlhxddouG6AEkPFKwTcio/tru7GypANCK6yO3WOV536KEFJPf29qzVaoXbEuv1+hP3Ui0onzKTJEmQezAY2O7uro1Go1Q3Kt2/mw7dR1j7GolWgMSapFkEIInRUa1WzewB5FHYgCQWJH+vPVF5D82LJRDkZc5zr3qQnM1mNhwO7eLiwr5+/Wrn5+fBkry4uEidJR3eMCty5J5Mbva4CdWMj4GkguMi9zpvF/s5GZ47yCwwbgILirXGwqol6XPaYjLkYWkpSOp9I2hodWX0crNyuRw4Lf4Gi4ON7PP1/AVUWSlBzC3y5SEnewiQhB/GilSA9HeqqGImM4F71nFhkySx2Wz2JKk5L9Dw6+R5SL1PG1nMHtsOEmAiiKOlpuxBLH+9vM7sMZtBrwb218r67JG8ZOar0hvaVf7s7MwuLi4C0HOWfHFJkTjgRyGWpH6vgOgTlzUooyZ+VncV3rfooWDJv2MLjKWmN8L5FAsFRu/iFSGPt1KU6+L5tLmuVldo6RogWSqVwkZWMj1JkmiqhgZylI/Ny9U2expwIypv9hBsoS2XVhT5lC6sKMCGyD0W2WQysfn8scGH33+bHFDP2/pcTQ3KxDwazankdzRxXHlM7UFplr5WV++v54I6vVolj+BolvwatFFOEmOD/Yby07XOK3C77Ci8dls3NNUbJG/XarXUheg+G5/h3V1976Kfnc/XZ/G5eBqV83c9+wYRzy2uAvO6wweZNOjiGxtwOCnDRB6aC5dKD6WWelsk/JAmX2t3J63k0DSjPIcHStpqkcpE9xltCxYLVMDFApDaF0Dvusk76OS/auCSPYMrrHPpuw8ph66g688LpYJqfdOToNlshj6t3OQZK9bIY3hjw1vSWNNatOApAe3bUHQg1KzgK2UZuJxYifS9a7Va1mw2Q020NqzwB10By/NDL6FNVEaNHvu7PLDO1ArhQGK5xBbWA0lerncsKp8ViNJLnLa2tgJIYKUgo1qRsZfmjeZtRepQbpt9RVBjd3c3Bd6x1l+4cPweXk1MienhzuvZlYrB7eWqYX1OUn7wTDR/2MxS1iRBGF9Rw3NT8lur1azT6Vin07GDg4OQF9vpdKzZbKZKa4viJdl3GjDU1LRSqRRKgev1urXbbWu1WqHptj5fkaOwO25Um5VKpeBK00Si0WhYu922brcbFkVz8QAiFke5GD6r6EPo5VOw0e4/gKRGtTV/z5eQ+U2nXJ1Z3N1fZ6i1pe6+t9b1kE2n0+CCkmB/d3cXND0XtFGJ4y1Uqo48j5z3GnkrTIFAgdPnr+rflEqPVSE+msta8zte8W8ClgqQWEmNRiMoF0CMM6I9JFECPC97ES/Nd/bRjj7MCdkA3W7XDg8P7eDgwLrdbjiPnU4n1f2/SJdbvRlePCvVbxQ4dLtd29/fD/1kyb4omp8s1JJUDhIeEm1AEwWEpi0SmxJTnPeFYFcLwKz4+k09HJ5D0v58ehMisipB7tMq9DBmHUB+b1WAUXBU1zKrpBM+MkmSwMVhlZg9XmCvEXtAEm5MwVKpk6I1vVcE/AwA8cEkHbqWXjFpOWNWetO6QKkuP66vWk508vaWFaDN3/OcdJ5CoalbrR2NNBug3W7bwcGBHR8f29HRkXW73XAu1VrznOwmI4YP6pUBkmaP5cHgBFhxeHgYQFwpAd6/CCzI7SIw3Th+Q7FAqhE6nY612+1g3qMVzNJXXRJh1UOH1lSrsmig1AX1tbaUUsGjUAuuz+yTdHWjqBXJUO24Ci/G73iA1Dw4bRyiMmqbrNlsFv7fKy1cW3WvNYp8e3sbGmPEUoKK4icVPPjeA1qMu/PrgaxaPaWgmdczY0lqhQ/VQjq3+iwxTo8oPO4rbjZrAGiyD+r1unU6HTs6OrJ3797Z8fFx8OhoLh3jJTcZfn61Xl6/5/ewqHnWbrebogMAcK+Ai9hjud+77a1Is3S6Blwk2qHRaAQr0iwNkJpnqAde35PNVoQW8eDoK08UJLXcSw+dmT1JglU3AwvZD80J9Zbzc0PdOQVIeC/4LN1k+jzMe8za5Wd6eJVX8u3GtMpF/74ore/B2Oel+oMK2Cgw+gYgWWlN6w4FcM+lVqvV1HNoalVsP9Kj1cxSjYTv7+9DLqXZY4ml9kzd39+3g4ODwElyFrVfQh5WZMzI8EE15Y11jjRFiUCvXpGr763rzXvkMTa+LZGvsZcCJYcVgdFY2ukbtw73zyeR8x5qnfrDm9fwmk8tR+8KKUigtf0hVO6F90L7+wilAhxRyVXcVnU/tds6G43oNa63uisoqFjggudSje/dWS9zzArLyzLhqz+A6sHEchv1cPnOM76CyCz75s9NR4xLjc2bn18FGYyGm5ubAHAEepCTs0OyOCAJ9dVsNq3ZbKYi+7E0vHWGro9vB6fGhbat0xLEmJcamwNkjdFxm46Nb0vMAkjdRMrBAJQ+lA8vRGqNLo4e+Gq1mnLBOQR5AqUHSM89aqSXZGxtFhGzNH0r/tlsFkDHzFKgqmktzBPzGHPNdeic8R4oJTbjdDp9cgWG35Sxz/CW5bLzFwNFb+GtM2IHUAMVsfxUgMZ33dGSS6UdzNKcZF6WpNIi7AFA0hsZXhkhK/ueyLxGvbFOzSwFkBoQarfbwcXWi+584vwmAMmz+0Cn1qZrp3TliDU1CGNKG2+TLWL2aO1jVecJlLncu+0X1fMnTLSvFcVs1lw+PxQkmRwCOL5sMU+gjKX5aG0tVQ1oQzYBg/pmwJ3qAbqzEPjgYGj7MtwMiGudh+dGTKk0Go0Uv0Z1Rq1WC+uh1996L4DhaQP/uR5EY0ERXzmxzvDAoSkkABzvz8HS6hpt/KFJ9uPxOFXhwrPzWXmPRTxqlmfmMz58dZqeLz03JIvXarWQTqPpPr7KJi9rXw0NjAzfCFjnXQOEGsVHGTAXeubZZ+TKMvICyrVAMuZ6KGdi9vRaWeUVNdJr9shDKi+hrqe6C5RnaTuvPF1u3ZRqRY7H41B2SFkfi6ta0Echkb1arYa63L29PSuVSqFpAc0y9P4SuFv+3icPZw3lI+mOozyP2QN3NRwOA+UxHo9TJL++vyq6GFDwWd619Yc65l1sumYxK0W73GtgBIs5SdLVKb4hLYc1VvGS51D5FST92eIr+1HnVj0pNRo4G5yzer3+JGGcn8FDxtK2eKZ1Bs/uQZK59t1+SCBXBcW4v79Pra2640pj+fpus/RNj+vKUogl6TkEhh4kLySToOkOyoNRAZEkSWjWG4v65TH0gGNJYgFqG3xtEIELpPwbzwTA1+t1Gw6HoSkBAA/hrhq02WyGw613+mQFenSoJem5uSR5qEkeDAbBsuAKDSXOzR6vOUAWz+v5qHIWSOpYhjJYdo38OumlXoAkMmkgQ9tzKUCi9DiIyKyKIM+h76cWpMqosvJ9jGNVHhuQTJIk5EQSqNEINhy173SUN22lfQTodaANYVQ5xYI52sJOMywwspSz1T1p9phIr9TBqvIVXpaoI8YhmVmo3yRZOavOFp7Ct77KW9ur5vbuNgeLEkQFSQ6jVwzb29upC5n29vbMzMKdxLw/Fszu7m5IEK7X68FSVXoia+iB8T8ze7Bax+NxsCK4BVIbIPh5WGTxeYD0PJaCGQAfcyVXXRv/3j77AEtLKSCoDZQe66HdjZTy0OyAvAIZWXPov9c5j62DzoN6CihHvBn4aM2BhIPMO4rtR0yJaes+bQhDjjFrCLjq32uOLl6k965UCVMQ4TuJ+Tl/brwISCpvgNkNQGBJEQjRtkgsuA/r56ntYs/KV3XlFCR952S1IL3lVC6Xg4WDe0dd9O7ubji0uO3z+TwohHq9nvoMTdFZNDQoQOTP7DHxGH6q1WpZp9NJtfPf2toKgM1c8HmeQ4wBZBZYvsTwoKlrAl+M18JasJZKd8DjbW9vW71eDzXD/hrjPGRblnbw+1KVgipm1hyg1/QZXspBancqlSkPPlLXwzeEUW7fW5HQVih6ZC+XH68Y4X0oQPEWOBwl/Cyy8j6rjELu3dYJZpLgjEajUYhQ4xIon4T20HI+7VgSI5nzPoxqSSoVoHeGKB/J7+lG1giddi8fDAa2tfVQF03ZHCCMZU0kj8axGkEvlUrh955bD1UmcI5YFlRcTKdTS5IkuGiXl5fRTRSzXnl/jSJrezutuPHpXN5dX3bELK0YWGuaCICv/2YtWUduI4Q/xkIhXUbrmYvad15Of/B9fqRPRcNa14ongFEzSnwFWN4AqcPTVto1S/teqpLSNDKeS1PwMDj0EjT1Iq6urlKt3/g+SZJQRQa9sczY+LZEfanJrtwUPAJR3SR5KH8D2Vl0JgcNAhjSqaTZbIaSKRLQ8yx700nLchO0W7RG5FhUdYuw/OBjAJD5fB6aSGgOJrmh1LGrS391dRW05jIgyfqwFj7a3e12A3jwM9wwjRoqLxnjwbICBz7zQIE0L09AAVozJzQv1Sf7s16sHcpJlTPeC3QHCddaPluENxPjatUiAyD1LhsUtQdJMwvFG4CEv1TPKy2eIa/hOUltiQYnCVD6tn2xofGLyWQScqb1ipLxeBwMAe1ypIC7qnJbCyS9NvcEvgcutbSS5KELNDlZSvKbPUaCWUSivJRNwaUV3c4py03wGlAvefd/Wy6Xw6ZGiSC/5rEBlEmShNSgSqXyJB+TzU+Ef9m1Yl3gchqNRvg8mtTCVVHvi4USC4x5C04B0gOltyBjB3Od4fcblrBWy6Bc4L6Yy1ifSbhj9VrIKex0OnZ4eJgCSS0xLcKazLIiNX1Jcww1CZu+AZQgKkCqpR8LAOX17LHAml53Akjq9RSavmVm0X3C3pxMJsFgmM1mtre3Z6PRKKUU8JhIrNf9+eKW5HNulK80IeVEARUh0A61Wi1wKnBn7XY7LLqmLRTtbqu77C9X8gnMnjtinkj3YVFVkSjHCAVRq9VCgjOWA8C6Ckjy+cy1b1CLtq3X6yFgBPeKtRVzxWKKUS26WNXGpq62yuPlYu+gsLDuSXc6Pz+3y8vL1MVSMe67Wq2mcghRIJ1Ox1qtVqGdcfxQoPE5oWpJqhWJBaxJ15odoYn1m0R8V5HBu9uaBoS3BNCrJxbzUKHuptOpmVkojtCyW1zsZrMZKCo6K7F+Wun03FgbJJVw9gdGN5BaZAQOAESE56FZVII6aEUaY2jpVJFpC3zVBdZL4/VaVp/249+Ljc37scFZfB2431tbW0+uH1BF81zgxg/mhs+rVqvh87WN283NjQ2HQzs7O0tVX6gS0PeLudz6ikWENwVIL1fM5dZgH4eIu4jG43Eqp46/4VkI1njvRfde3nuOETu03t1WekZzB80srBdzAHBoQxNvVBQJ9D6bRcHd3y6KUsOq59l0jXlPDBczS/GzqhxqtVoo1sDgUEpvlYyKjS3JZYaCTerDt7dTVgA/8y6Q3sVRJED6Z45FEpU/VM7OW5E6NMGe3D0fsUdZwGWy+NrsQF+rDg+UXqHd3t6mgCBGY+g88/weABUwszwM/16bDG/J6ueydkr2cyiZC+YSRabWNtUpvrHCS1hgfsT2I3uSIKFaiJ4XzlJWzGHRz+6DK/56DZ+4n6VQ1cMrlUopd35nZyfl2QGQmoTuU9CWGYXdu60Tj2CamkEEUpM9EV4TeeHS/KuIvC4/vAZnMXSjPleVoblsWN/MlwIN3xOF5tBqFY/m/K0zPFBCbdzf3werY53StCwrMWtP5LFu6snoM/jn1kOkvS/NLKRHeSvZu6yaS5gHn/rcWGRR6h7QV0z5PaeoXmJ4ykANDj3rPhlch8cSBV32AQFh3gMg9h2GVrUizZYESd50OBw++TmuNJErvQhd0VsnyMxSZjVRXSbN92rENAdAYs1rswbP7InwRfLc3d2FKz41wqxur3aKiSVI+/f1IMHh5N8cWGTXKCZ86NbW4z00y8oUG2qRkLeqMqoi8Bayt2i8C0UZIO6Pt/Zj67XKGmXJoXtGgxqaf6sdZ1SeUqlku7u7KQtHMwr4XW1N9pwXs+q+y/qqFpgmwWs3fFKXVHkqMOnzowhXDTytc470OpBYIwt9KV3lz5JysoojfD4Glp5FswcjgLMDPsH5Exj1MkVHssT49OlTYma/6denT59elTyvUabXJs9rlOm1yeNlio1S8iyMPlh9nz9/tmaz+aKmeh4jSRIbjUb28ePH4NL+luUxe30yvTZ5zF6fTK9NHrO4TLGxFEi+jbfxNt7Gv+pYipP8LWuMfxUN+FuW6bXJY/b6ZHpt8pgtb0m+cZK/4ddrk+m1yfMaZXpt8niZYmMpS7LZbJqZ2adPn0IjWDN7EunSl+/6c3l5aRcXF3ZycmLn5+fW6/Ws3++Hr2dnZyEDX/vGlUqlkCO5v79vv/vd7+yPf/yj/ed//qf96U9/so8fP1q3282shBgOh/b73/8+yLBIHh3IlkQiqOPxOMjyyy+/2JcvX+zi4iKUv2lUHDkoEaNdFT3+Dg4O7N27d/bhwwc7OjqyTqcTql+y0k7WkUmTeomaa7uwXq9nJycn9unTJ/uf//kf+/nnn204HNp8PrdarWaHh4f24cMHe/funR0dHdn+/n7qFju9C3nVG/aGw6F98803S8nDelC5NBgM7OvXr/bXv/7V/vu//9t++ukn++WXX2wwGITGCVqyt7u7G0oNuQCrVquF/FwyNVi3Wq1mBwcH9vHjR/vuu+/sm2++CSWKWr3hc0pXkSm295BT6557vZ798ssv9uOPP9r/+3//z3744Qf7/Pmz9fv98MxEuMmbJO+TBszU7f/ud7+z3//+9/bv//7v9vvf/97evXtnnU4n9A2InaN11mg6nVq/37eff/7ZfvjhB/vzn/9sP/zwg/3888/W7/dDkrfvoPVcZZbmUDebTTs6OrLf/e539t1339m//du/2bt378K+pIemNscpl8s2Go3s22+/TckUG0uBJA/IAffgqAsbW2C63VBzSVrGbDYLSbok6moyefKP5F5NGPb3ttBpmasIsg6n/szLExteFi2tSpIkdBrRu3o0gZcXqQ1sPF++p4nyKo+mamTl5i0rk8pB0i1pE+QO+q5K+j25qlRxaOkX68BG3OQa0mXkUTnYVzQy1j3k3SfNowQ4kKNer4d83evr65BqVSqVUk0SOHA80yKQXHWNYmuld9lwj43fE16+rLxl3UN6jpAfubJAcp01urm5CXMJSPk72TVfWJ8/tg/90Fp7bQlHj1TOkmKEguSyebuFXQTmk6x9/a4vXdP7LDiUmkvpP1tfWWPdxFmvBADJRS99lkVJ9fpePt9yGZk2GR7wtZuMNhD2DTv4yvexpNyinjlLjuf2X2wuvbLV2ztpVIuy07Z1esCzSi3zlEsNDL2LR/uY+p4ByOdBxr+3fgZ/q7mULz0UxPl3rGGKNtdVebGQtd/nIqW1zlgZJP3m9AmcsYPktbjWluqlYLxIRqZTB4JmHYCsA7HJ8PL50kR/9Wis8oYF16GVH5q8HStxzNq06wK/t4i12QAvOhtxAGPvoS+oFZ2zokZszX3zh1iysR4qFDDeCBYUF2exNrTWqlQq4epVTyesYo0sO5QS8RdncScMRRuxNYq5qQq8WvUWO7cvDZSsB99rtRMAyPfMsZ4VrjfxV+Iu6hmw6tjoIjAtNYwdEG9JqovpXTZetPAys7AJntOIeR5OfR89fPCRWpWhQKkLpwrBW2DQDjT00EYZy5Q5rguQelCgPWj6cHFxEbhhOuVg5Wp3aK80fEXHSxw0XXNfheHLRdVC54AAkNqCb39/3/b29kK9t3J7cJjwW4tapW0ClCoPSkzvg+HirNPT09Cog33DfKtRoRal2aPRote2Zin3PMZze1g9S54T4wn+3tM4yKeNXlBk2u4PWsKXY647VgJJPQQKILoZmQS1/lSDe4C8uroKzQRwI3ShIdDZwN56jVkXeVmR3vqKdYP2hfPID52gP4enYfH8Be0x133T4ddMry8YDAZ2fn5uZ2dnASRpA6ftt/S9FIwWWdF5rIN/v5i76KkQVVp6VxIAqZey0Z398PDQGo1GsGhYA7N0OzmsSd+dPE8Z2WvcLMj6oMjOz8/t4uIi9GBU+fjq3VcG66VAGWvQUvRQhQXQKb+4t7cX2tV1u13rdDqB/+UM+es2CMZBjeR5f/ha7rYHELWAdJF04RQkcXW0HbsHST7LLH51AP+/yJJcdWL0Pbz1BUfEK+Zy6+eyOHCr6prSLk5viMyqBef99Ouqw68VPJdakhcXFyEy79ch9l4xS96vQ1FAqbIsejGUh+QQAnrdbjdEq3G5VQ7t6K7cV968pN9vWJK9Xs9OT0/t7OwsZIMMBoNU5kSpVErdjMnPeF++emUS22tFDR9AgjtkznGtCQR2u107Ojqy4+NjOzg4CA2h7+/vUzw6HbW4FVKDnos6ha2yXhsFbrw2Z5F8Q0tPxhJVAyTVktTei2qtZvGS+jx5L7R3L/3BjG0wvxF0AKQKVto0Y5HrE4sqriqLWipsNL39UdOW1APQjjKLWm4VNXS9Y9zwIoXF86KY6TbOPdRYLO12O/SK1KEAq9kLRXQD8p6LXyO6q6sFBVeXFRfwdIgq4pcGSA2YgQF0MFLl1Ww2g4V/fHxsR0dH1mw2g+WpTSvo6KTt7XzWwab7NZeLwBYdbB9R5M5s3Ff6vdXr9XClqrZPUnc7y2wuAhxj7mOWW6KWs0btkUEHHKtq8ljAZhmwXFYOPTQK9p4+8Hl2yKNRw9j9NbEofl4jZj1qhyilKwBK5eh4fg4gwEg6D5yXpi95hcQ88H7qzvI7eQ7vwWhnqCRJAqjAxynF4I0L/arvvex+23R4L1LTdMweDAalQMgl3t/fT708SAK019fXliRJykvIusvH79Nl1y236xvU1VYLxMxSmphFVW2mjWx1sdT1MEunNRRtvWTJqrJpKhPEcywtSJNkl9VmRfJ53gtQK0StRdwYzUHzt+4Vxc/552cvaGs2nxajVwAAINzpA0Du7+8HF67dbof8OWTSPqWe58szUPPc8MAMKCRJEi74Ul4YMGUOFFjJs/RWZcwQyFsGr6goori/vw8yqKutNMjBwYF1u11rt9shDxpQ5bz5a4B1f2bdub3q2Oj6Bg+GfmJ0Q2FFKtGskTdARQ+pmQXgRIPq5y2zWVfdyDEr0VvESh1oRM73z/ScIrKq5Zn1ymvEAFKfS8EeTa3KbZ3kXD+Pm8jjaQK9v5x7UvTGPU0N2d7eTvGORLKPj4/t/fv3tr+/b41G40mOXda+2pTyeG74M4VRwRqQokRArVQqhbmhKocrRtTKxsgATF7CwFBZOPvNZtP29/ft+vo6ZHYkSZKyMJUO6XQ61ul0wl7T602YH+YDLwfFzb70Cfgv5m570DB7TDFQKyuWSc+/FUjNHkl+zcTHcgBEV9EKeW6EGDjGQFIj1Fn8IlalrziIPW/WIV1GNm8p+Ofx3Cm3JPK3BAN8ZVPsxsq8r/b1crD+lLhyLSlXk5LfSRDM7PGumm63a+/fvw9lnwcHB7a/vx++ApLLXAfyUuDiAbLZbNp8/lAeirHAvuOcEOjRC7b0rp8kSUL1ThZY5C2fyoKVeHBwEGThki61JH0Vl+Y/kuHCs7NnURr+EjrlkZdt0h0bK4Mk2kvdbH6uE6NAlhXp1vIxfR+IaLQgrkOMByviYHp59bk8uHir8u7uLqQqxMh03k8tOP/+eQ4Fy5gV6Z+dNdXNrfe98OLqTk2u9p/pFeC6Q0FSryUFJAeDQcqSNHu8c7rT6di7d+/s3/7t34L1qJykT+kpej9lDV1/QLBSqQSODjDUIBT0wN3dnU0mE+v3+yGQwwuljUcWC2YUOdjj0B4osXq9HkBbK5/0Lit+preimlkq+VyzamJVfHkEbzayJBWkPAh6AIuBDIDr77wg8jsej61arQZ3YdV7RlaZiCw+xgNMltvqZfQ/K5fLT1xZr0yKItNjEXi15hUo2XgaiVS3BXeGtfB85KL52eTZffoSLrde70szCzMLYBJLJwEcOYweIFWWlx6qgDXgxP7Z3t4OCosE+Lu7OxuNRlYqlYIiIZhFcAPF7YNPRSlns6dRbcBdr5tAVg246O2OZo95q+CFeq3eaCsiuJsLJ6kPoj9/bvIBSAh2JV41L61SqZiZRTmGRWOTiYlFVWOpJ1p/nZVesYjn9J/n38MndK8zPNXhrX0FbL+G/u8XzRPPqnJvevj0/TWyrdf7atBGn4OAATwX1qPykEp7xOZqUXAjL0tZ38/zeFShUZigwAlIUmc+mUyCdakel3J46qrrHijaovT7jjRBPlfPmdljmpzKwvvEvDP1QOEtAWd+R/f4KiOXK2X94YpZVP6gE7b33J5aK2qGq1mu5jeT5gFpk43rAZKoIRcv6Z3BfFWiXK/L9IsfG/o7Csi4C3mATmxdFCDV2vXWmyoDrTyCOPeHUj8zjwBObC24IA0LUptzqPvlOy35a4m9Ml90gGKAmTe4qOUFIM7n83AhmfLElUrFSqVSKDMFEOAftbab/4vdx120661KjvxPChbwWmazWco9VmtXgVy/9xaxV/yak6nW5ou422bPBxVi1pK6lRws5Q40VUgBkrZYbBqNqPJ+/rUpQGpElWYQo9Eo9IqEE/P9I/01lpoczyHzYKKfBbgq4czGUZdj2ZEFjjHeWJ+Fr1tbWylAmk6nwVpR7e6V4Lr8z6I1UZfbg6S/olSfS7lX5cKXfT6dbw/6uiabyKnPolU+uKTVatXMLGVhUmUDV0k6nS+dBYjgBgmOaFS/CKBUY4Pn0nOk2QjMQWxdFDjVEsYy9tSV4gj5l7jn3rhaZqzNSeqmeA4w/f/5AxvTAghYq9VCLz34GG2wy2LkUT2gVqiPqAKMvDTCyr8BSDSkAiTDBwe8lvUgSVmjdzGeG94VjIGjHkpcFW0ewHPu7OyEyg8a1OoaxwIe+pmbKi2zp81GAAHl33C3/Wdlyb2M9eifQ+c0b0tS55L9b2YhNQvjQt1lVWhKRaCwtQZfrVDOkS+xzHv4c0Q3o8FgEK6Y1dxiP2LAp4Er3x5NgRROmr2shscqOLFxxU0WEK77d8rJYE1iQRCx9JURqrE2tSY9B0ZuHukVNIHAkiSyyoLf3NykODqVLWtxPAB4rlPlWlUmr4xiCsrLr7XP8GCTycTq9bqNRqNUFxzvsmd9xiZBnGX44WU6KGW9r85V7PuYgvJWfV4Ao/uf94aPVBoB5eNpCKx9X9usFSkEfbRHZlFBHPWSMDYGg4H1ej27vLzM9AQ0FxSgVO8SoFeg1KowvE5f9sh5WgUoc3O3lx0x11hBRTcdxDv8HNohZkVuUovqN7+3JLUhBOBIPS1f2ZDeelxkOWvAwNeIZzW7WGUsQ4Go3ApCWv3kSwBns9mzuWisnwZz8gSSdX7P7zc/p4vm6aUsSdxCzf0zS3cm4t80S9GGD4COKmy1TBuNRkjjirnbeQ5vcOBuA5RcscEz40V5g0CtSM3b9VcyaDyjWq1akiShKmeTjke51G4vGt6F9WDgo8NqOQEmmm7iI9v6frGo8rKTEQNsJZrRzvCPAKRPP1GXIZZqEbPAYhzlptSBH7GDz1z58jZtuFEul6NdkK6vr8OGvL6+fuKCq8WzqTUZ4xYh5DUIkZXx4LMFtD7dW5L6WTGFkuea+OEt1tiZ0e/Vy9FgIgqNxjBYYT7XNWZJ8rl5AWZM+cInc574XukBz1MqSGqt/s3NTSqv0ixteetYV6bCQFIXWhdbJ8sfOm1DplFKNr5aK0rmxwB2FXDU71Xzaa2wdoi+vLwMnXOo9KC3nz7zfD4P4KEJrj7PUINYiw5oXuuiL+bMt88CSFSRMS+6hkS5fRDNLF0H7uVZxRpUZUkVSqPRCBFSDeYgjwdLlZG7V7wyUk5LU4Oynruo9fHPqsFA/YpVf3l5mbpID6+GqDgASeejTqeTalDriwGKkIm9FmuswvnBmswKvu3u7qa6PCk1hAvNHT2azbBpaWLuIOnBURddD5emcvDyrgKuq5K3CMgE8V5YGOtwDoue0ycva+AGza2ujVpQPj9NE7KxgDT3U3mVIgDTWyi6afXw6Ub0QSPPg0HImz0CKgAbs5b1MC7LXWNJAJDtdjt0f+E9eB6+qmVEdBVZeV+1zlg3jZwmSZI6WP65+JrX2iiYMK94MHouABf+n2DI6elpcGGhS+Ai9YZIbrikA5JyzPo8y67RKvKxn7IwQflw1tfz+Z6iwhhBKaAQuKCMvFjvhS4r20og+dybZvGMWRYkHIq+snLe1KJQrcR7Kj+2bBL2IrdGrRNAUoM14/E49LPjWc0stQgcblIvvIuoeXva3inWMDTPw+jXQ19qkWd9nj/MZmkgUtDxTSOYF8ZzMnmQpLQNRYgVRP0y9cBY8Gots/c4qGqFmD0WK+C2+WeOAaP/usnw1hYAqFVFmoqlFBBXNhM11jLevb29VGcdrgPmOgpf9lfkWNbL89SD5hn75yyXy0Ex0nyX7kFa/71uJH+t2m19eC+Q8lwxgGSjaocSb0H6RFjl8Bi4fmwG5ao8L7loeCtJwReg9J1nuAdGr13QueGAapQe8AMMtbLIW5Sx5q5+/tcdi0ASWXTeF/F0rIF+r+6u5rLFNrbnBGPDg2S9Xg8gzDwlSRLKWGezWbAAlZrRIBzPyb95PwXIrGi9f7Y810TXBeWsGRXwjtoijoIG3Z9QP5oriCWJFUmbOPIti+IkzbL7Ovj0HtayVCoFL9LPE8qNiwKxlCnZrFarQSH4Vnjr5oRu7G5ncVw+GOPd7CzrxROusYhj1iGn39y6qSBZVi9AqQnMXKquoAJhzPdqSSlAEpFTCzJ2d3dWXl8ea6WBKbWoPGXgQV4bmSoAqYIslx/u79ELmXR+eQ+zp5VYfqiC5MAzx7zv9fW1jUajEIzQtBfenz2mtIIGOOC81EJdlNpUJECy5+DBNdWMwKHSPdqpG2NDvRnlcbXRcJH39fi5UjDUdmasJ/uGvElA0gdxdO715z76jbutedXrJs2vfVsiXz3A+Bw2Jfq9labaQt0z5cP4nt/hc9nw2iFIid1Vgjcqi9+w3qrUSgblTxRYdFMoz4Wl4qse1JqMVdqsG/TIWrMYtaDWMJsJzUzHn1hZn66JWvdXV1fBitR9QYmY2aPFvYo1qZ4CLrS2cRuNRnZ/fx8sQ/0Mnm0+n6f6L97e3lqpVAppIz6LIsanrrsOfk2yFLNe3aDVXQQMNbNCk+qxjHUtfQVbXgGNrKFr6teOrlLNZjMkeWuXcUAS48nnv2qkO6vfg9+vvsvTqjJudMeNX2R1YXj5RF+1ONVqoSEvk+BTg3TyNWigJrt2PV+W94gBvY+Y+2fxEWAzSx0k/9JFBXS0/Zh2/NYgTlaK0KZrprSCymRmYTOxcbl6VduLNRqNcIGT2aMyU3fo6uoquE1XV1dBvlqt9sSdX7RWWZaIutC+1yXPRH0zpD+/f3d3F/g++Dv4LHVTfYS7CKveB8EAR82m0L6ZgCOpM4CjBt004BQr5YvtsTytY4aum/bG3N/fNzMLjXhVDmTRqL5y/krj0OiDO4qI2KunphbkuudordsS+RpzE3SRuZrURzuVoyIbHuHVOtSXHkI2uy7q1tZWmNR1U4FiUbcsgPSaUgHN16BrgituDy6BT4qNWZSxgMG6I2b1q8LSclDfTh/Cn1b6pVLpiRJh/qAjaP7Ke7ZarfBZzNcya+Tdbt0HzGWn0wmJ7vP5PFjrNIdAgeKeY51xeyUA7i1Jr/x0X61r1S8DkPCRXCWrvKRPwGY9FSDVilPeL5ZylnfamSoW1r/Vatnh4aGZmdVqtZQcGo/QWAUUAtF66AO9OKzT6djh4aHt7++HYBRVeXkYGmu7294i0bwtFliL2DXPTTtB+4CLf6/pdGpmltKSmiLAZkVbeZBcRR4PkM9xpQSS9PDGXGY2ie+4jPUDuZxVj1qUJem5Y9KUuJOaZ6ONfrfbDTwP1pnSELEWcngHWA3wfygQLb/LGl4ZIsv9/X0KfDudTuiheH9/H2SB/yQYOJlMbDAYhGgwVm+j0bAkSYIVnVUX/ByPuuxaKEjqftcgDCBJGZ/yjp6/0/kCHPWcPQeOzwWqVh2sF3maPGelUrFOp5MCec2TZh4mk8mTe+C1wQfcKheHHRwcFHI3ei5XyqolqfWZpCNwUPSWNKxHTePA9YJUVzdJeTT/MrONyo88x+rf3xPHsc3lwVEtSVUOlFQBQN6SXOQK5aXlY7SCWpJsvE6nk7prhLSRarVqW1tbwZJTtxuAJD2F/NHd3d2QT8k8kH6yrDJj7pU35MDwzATw7u7ugpcCz42lwuHr9/t2cXFh0+nUtra2wp5jjThsrK33UPKy6tULIwgD76gd2Cl/1coUnybnAd0HofwrRiPkBZC6Vtqso1arpZSqWpDMA7LTLwBjS91sBcl2u23dbjeVE5mXhbxR4IbvPSepFSqTySSE59Xt5isvkN/MgrUBn6Uup5L2fL+1tZVKiF7F1Y4FoTzA+kAMh045VSxG31nZp/3olZpaQxvjI5UPyzOy7deMg2ZmqUg8m1EvAoPnY33MLAQ+zCzsAYIO0+k0WHxJktje3t6TzIBl3W2GB0oNCOA680zqamP1qjvL/ThYtPV6PQVEammvGgxcZi2ey6hQ6gr3VAOiDAVDLduMVXcVDZA6sGj1ezIQNNNAOzvNZrMU0KHA1Dsgeq0X08WqiPKQJbeKG11wXWj4BCaKw8/QCJwuHn/jAVIDJvyfbppVucjYJKq16CmBarUarCfy69h4PrVBU3w0AgdYaud1zSnM24L0Si32UuUTc91UQcBHahaCt4a0LheXcmdnJxVoWBV4mAdAUmkOnWPoHdbHc6VYaeQb0vCVOvSsYGMeAPnc8EosllWhTSCYD/aJ7lNVvkodxKq58gZHs/R6AVq436y/AiTPZGap4A1GFh4JdBAgqQFQbYDDe/0q7rZ+qJ9kXWTlFzVjHi5Lqx3UGtPfi7mGmhbEYfAR1lUPnwdGPXx6red0OrUkSVJJyAClT7WIvfR+Eh+s8YR7ngC5CBx9IMfX1MYCYmYWXFhcRPL4fNkm76/Ry3XStFgrH/jx9AZro2CHhTubzULOIQBJupL2pIztj7yARN8rpowBM1035cmVZlIgUkqHACFZFP66iixFXCRQepkxdjhHyItsGvTh73d3d5/c2qnnq4gmwrndux3TTkyAT6fRkkQFG5CfTRGLMntgVZCMgeOiQ6gbQ9N1fG4Z5n273ba7u4e2bTy3anPV4mxUjWbrTYOay+Xz1fIIDKjsfg5i/2bDUqs+HA5Tc6I8Il6BBupIUdFO7ePx2K6vr8P8qPu6LkiaxROVvbVLkMZXd8FHajMILa3Uz4jt6zyHBliUt/YXYunnxgAcmSnb9DyyXsuaZ9R3FTkZ/oyjzJReYL00j5K1pXpIk+J9rrEqgDzGxvduwwl6olgfEstR0xzgvGhUwN9yaHzES8sVdZCs7A/dqlZkjFvUlJ12ux2sSKwODqLmd0IsA4J6Z7Vq9liya1GBGuZjESgRMMP11FQXNioWCRsdl5pIsZbPafs4/t5X9qy6VrGhe1EtJPacthIDFMk/HI/HdnV1lbKqzR6zFXzKTJ6pMvrM2m/g7u4uRcXEavnV8lIXVlNiyEbQjj9YWkXROssM72mCCZogr5F7s/QFgMiI4aFWclHJ8Rvdu52lxX3Kgdljl5YkScJBpMaU9B3y38rlckgP8ECJ285B1SDKutyRtyQhiNV9abVa4Wrb3d3dwLVqc9pSqZRKGGcRYxaktyKLKA0zizcU8MCi0embmxubTCbBA0ABoMGr1Wp4RsBGc/nUOkOhaZWMPleeQ+cNKwWPRSkADYSwx9ijuof9K8+D5+efZ+a5Y0pUg4He4vUgSTs0QBKLK5bW9BIAGQsWam6opjxpuo/ZYyARKzJ2ntSCLOIcbexux1xudVn5ngOjhC15T3t7ezadTsMBLJVKwe3TBhhERDVYoodvHXD0HJfyW5pe0mg0gtWxs7MTgED736lryibXze7L+l4i3Scms1ID+sKqx0VmXuCBSNRl+DQVrDO1AtQC8xZZXiN2AHHhtPkDKTRas21mT9Y9qzw0z8OnlqCZBQPh7u7uCa+toIkVyTlSINeqFq2QwiXF1fb7rcjh4wq6NhpEYw9Np9NUZgLpZr7sMFYiyzrmfYY2vgjMg6R2u0GQLKJeNT5uOO+rG5waW21HT512HpHHLBLd5zcqNUCkVjuSK9Hsm1coKPpGFkUsbJacmiKirgtuMcoLpQboo+kp/TSzkJTtu7SzydUa14Oel0ukB9A3rSApm69YkewjwIa10eeL5avGlFheLrdZ+ooCLeFEOY3H45D7ub29bbPZLGVtsb9iKWZageKDGkXuObUg1XLkpfmqcNpkQ5CvSlmpRuz9lbhFGxkb37vtXQefY6cdWdRNxc2Gh6TVFUCoSel6h0ypVErlXW0SAIjJEgNJNiuf7+tfif6aWQoMnwNGn+BbxIb1nCugB3e1v78fknQBFYCfpHAFfrOHkjKzB06SqhXtr8k8KAjDkxGNVO2/zvAHUHNzyX8E2LMUre5V37E7iwopan1032nN/P7+vo3H40ABMdco6fl8njJMAEnlwPUemLy51ayh9AHro9ef8ILT7vV64fybWVCmSZKEOVHPTAEyls6U59jI3fbWpOdGABW0ojbUxRrBilIXTfPZ2OREkzG9zdJXbW4ih36PHGw+NizPq8nhuunglnylDYdMATIWXdTFLSqKyvrUajXrdDo2nU5DZ556vR6qpLRxa5IkNhqNgjKjWYWZ2dXVVSidoyoiSZJwWDXC2u127fDw0A4ODqzZbIZKmHVcb0/+a0S+1+uFqwy47pdu3ewh75oC4sfHx6H+1/cgzDsYEFsfBe5Go2Hdbjf0x6T2udls2vn5uW1vb9twOAzZFlqupxVdi+6yKXqopzidTkOFEyWWeCC8KDzY3t4OewSQ9N5IVpCmCNlySSZXnosIFN1YKEfThFhy5XwbMlwK5St4KUhhjm+Sb5clh8rC5tQItrrPAAeWCb/jLceY21YU1+Xl4ataKfV63TqdTnDf4FxJ77m/vw/Ugt4CeX19bcPh0HZ3d83sodJGk8YBoO3t7fAZ79+/T10ZoE0I2OTryO5BEovk/PzcTk9Pg/vG/sHqwhPwIE7t7+HhobXb7WBNvgRn7CPVrNH+/n7gwOEZqSgxewi0zWaz8DeafhYref01uEhAcjKZWK/Xsy9fvtjXr1/t4uIirA/5tHhj8N7U0quVrOW7L2UV5waSPsoGYKCxfSG7liDBZTFhWo7FBsdtJwUFbkwtyXWBUidXrRrlO3Hz1f2CaFcLGJCMRUcXJfIWqQljliQysT509rm/vw+BjvF4/KRMjs1p9hgN14AaCqZWq1m327WjoyM7OjoK4EjjDC07W0VmnwivSpUrDM7Pz63X66XuQ+cAwnFpVxpAfH9/P2VJZjUbKWroGtHjgGfFMqQfJ/SBAr9vxaegEqMNinS1+aogORgM7OTkxD59+mSnp6fW6/VSeZHw/c1m03Z3d1O8uL/upGgaRMfGIOmjw3CPHBoWT6NaPlCDtQKnBIdEqoa6SUmShAnEkvQJtpvIYvYIlICB/j+WT5IkwY0gb5K/jVmMz4FjUUPfG8sXdxl5sOhubm5sOBza3t5eUAKApNljr0k4YQ4Bh4JmFXqw9fIp+L48urR4SyXGSaJgtYMMII+yoJEH6TLa7j+WT+jnNK/Be7LnKpVKUMxYiNvbD31WR6NRqoOR2SPNo41ofVepl3a1Y9Z+v9+3s7Mz+/Lli/X7/aDEwAv6Avj2aJ7jL9oL05GrJektFrPsxNxS6aH29+rqysrlcqis0dpf7byM2c1GybPhADJoviOywIkAxFiP19fXT4IxDBZukbVYNA/pZeO5AErtEG5moS+jEvxmjzcgYjVmAbvmWwKkWs7pE3/zsgK0X4DvQ8jewdJXhY6LGrO8fNT0pRUac69cZan00KNgOBw+uZJAewjoWfNUwUtYXTpUkaFsoWdQZnDfKLH5fG7VajV1x5ViiBocLyVHrlfKeqDkQPJ/DCbv+vo6ZZn5S6l8U1gzSzUrKCI5ORaQwjpSzs3zO2p9KvD9WsCYJZsCJe61plRgafFsWItae++fXddQf6ZBLN/EI48Dq9aKz8Pz14ewjjwj+1P540VpWS8FLh4okfPu7i6VnuQtXP2qOcu/FkDq0D2kjW/4Sp9JArPaJ4A50SyQl6AMdCwFkmys4XAY/bmG+9mo2gpJX56P9E0UYheFMYml0mMnbG2t5K1PeEssU33WVeRRTahtrNTK1TtGiOarBmQoQJlZsEZVKy674Dz3MjJ5+XxeIZFhbQ+mbcx842GVR2VEicQqKihDZf0AJDb8aDRaaY2wHLFKtI2YdvHxDZOxVLQfgB5aupqzRre3tymlscqBXHeNvJwaHfaJ8NqImt/TGujd3d0w5zc3NxtZ8KvIw7NrxRPxBgVHDCEtNWZdSBvScw2Nt7u7mws/HJMpOpIlxqdPnxIz+02/Pn369KrkeY0yvTZ5XqNMr00eL1NslJJnYfSB9/n8+bM1m81fxVzfZCT/yPP7+PFj0Dy/ZXnMXp9Mr00es9cn02uTxywuU2wsBZJv4228jbfxrzqW4iR/yxrjX0UD/pZlem3ymL0+mV6bPGbLW5JvnORv+PXaZHpt8rxGmV6bPF6m2FjKkqTE8NOnT9ZqtcLPExfF6vf79unTJ/vhhx/s//7f/2t/+ctf7JdffglJoxptBLk1fUH/zfck1FYqFWu1Wvb+/Xv79ttv7U9/+pN999139v79e+t2u6kkZa2SGI1G9u233wYZFsnDsyFPr9ezn3/+2f77v//b/vznPz+RR7uka3pCVjsqmA2Srdvttn38+NG+//57+8///E/74x//aB8+fLButxvK9mJRvOFwaN98881SMmWNRKLEXGnw9etX++tf/2p/+ctf7K9//at9/frVRqNRaJ1GYjO13+/fv7dvvvnGvv/+e/v222/t+Pg43Ka4SvLyJvIkLguBiPpgMLAvX77Y3/72N/vhhx/sp59+sq9fv9pgMAiNLsia0Aqc/f19Oz4+tqOjo5BgTrI5/yan0qcMqbyj0ShTpr///e8pmfwcsU9UNu3or3dy8/X8/Ny+fPliv/zyi52enoZqqZ2dnVCbfnBwkKqjf/funX348MGOj4+t0+mEHFHNedU1WvYc6bNrgwu94oPuP2dnZ/bLL7+EV7/fDwUketGZzgtztrW1FcpLP3z4YP/rf/0v+8///E/705/+FM4RlWSxvRjbd7GxFEjy5q1WKxMk6RnpC+o1Xwtw1KRjTYL1eVD8jtYca/G+v7/adwjJqpKIyaOLykVf19fXTxKMSQeJ5WspQLLZtre3M0FSu5erbLTbzwJJL8eiNcoaCpLkS45Go+ja6XppYi/dgeg6ww12q4LkuvIoiCALVVk+7zOr85LuSf/yuZ7sBfbaMq26Nl0j9iQgSeHF9fX1k2setHpK83p91ZeXSW/CzALJVdfInydyc0k709xUvUdKq+vIzVVw5H3N0i3iSLBHFp5pEUjGZIqNXJPJsx7Abxh+pgvGZLE5fWL67u6utdvt0Igh1i5p3eRSf9iweLXs0XfWNkvf2ofFq1c/aNIv78/fsUG1FjVWZvWSXA/zoC9tIOL7d/rXrzEURHzOrN7MGCtdVa+F91HLh6tNa7VaKl/XWzRFyMRXP/+ag6r9EKhQwTL2eyn5R0K65h9rDqxfvyTJ715xb+Xr+mg5L3NNjnPMktT3K5UebwJQhZh34nyuIKmL64dqWNVo2tQ21meQ36c7Nk0IuEogr1ZW6rJpom7s4ipNCjez8PloRO2e7BtCUEK2u7sbOrtoM4UiFnkZ2WOKQpOuOTT6f/6AvSRQ6nppUjxNUWjFheumHeSz3k+vfCAZG2u5Xq+nkrfN7InyX3bEft/PoacRfMWKJllT/8w6aX095wjrUzv9+1swY2C5yUAGf5eN9vukO1OlUrFms2k7OzspwPeVX6wR67C9vR0Mp1inozzOUOGWpNnTm+dwC/T+GJoL1Ov1wPV4l5vmqIeHh0sB5TKT5IFByyK5p0WBUuXx1jB9APWidFxPs8ebI1ncer1u3W43BZQv2YggZjViQauCQKOr8vAW9ksCpefq6GLPFQ167w73pqglyMAToKQSC3I0GoX/M3toz0c/VG3PZ7YZUHpA1O89SPoqFL0Xhqoj+n9yvnQf8bf8X61WS3GzXtnlYUXqGjGv8JDa4/P+/j6cY53XmKWv3c2hH8AQbW+X5xnaGCSXORRqQZbL5VArTADg8PDQjo6O7ODgIMVpqZDaQqrdboeuMjEObdnJ0c3oSw+196UHBgVJ36Kq2Ww+Ifl3dh7uEtdadCWdURAeJIsCykWWI3L60lKGguivDZQKknSx7/V6NhgMAkhyGGmiAABsbW0Flw256CcAgAGq9HPkagpfnulppFXl4GuW0vIuNldSoBBoCUd5Hx29vYKbzWZmZkFBaxu5vBvGIJPSF6PRKDRF7vf74Xno7+nboGkMACVGa0U6hSXJY0Ni370pr3O0EUhmLbAOrxH0kqNWqxV6Dv7ud7+zd+/eWafTCdakH9opnGCBtySXtSC9HBpB1JpwtSi1mwyfp81AaQ0G6O/v74fOOmaWAh0AFoKZ3/PduvMGypiFos/FYVRLEpAslUqpWmFAR0HypYDSg+R0Og0g2ev1rN/v22AwCB3TsbI0wKbWotljf0xVEmYPVwl0u93U1SNqbW16EJ8DSHW39f5wGgpjken1vQR42LdqSW9vb6fq9T2ltOnwcmBJosjoTn59fW27u7vBMPJd4RXsWBN/bzqBR9+n1Nfbb7I+a4PkcxrwyQf9Y1MCkHr3ycHBgR0dHYUu1oAkGoTP0kg3r6x2UKvIoYDhLyvSWxp5FiwRje4S2Y1dVaAgqVetotGJluqmMCsOIJXbQd5Y0wEAFOs5doAWKUrvkuYth1pZNNJQFxRrQ29FxBVTK9g/v1pW29vbqWtylXbJS7YYD7kILBd1OuI88LtbW1sB3JWXZK3VS3rO0FlHLjU+lJMcDAZ2f38fzg4pfniTyjGizPAY8B4rlUqInGsfUDWY8hgbW5J+MWOTDSdHRFfvFQFUCMiQ8xgDSbPHFlzwgD6tY1U+0h827VzjCW7lJDWNB2sQGgBXu9vthi7LZo8gqfl5pGBoH8ciXG1PK6icejE8UUfvgvl59c+p66+0RNHcamzt9C4lnh3PBADRPcUe8IoBy03BKO/ghj5D7MX/6VesYU1T4v91PW5vb0NmhVqLSvvE6JK8ZdLotoI0HaE4C3odrt6nrSCpuFAul0M/2ljg91d1t73wns/yxDZAhout94oQhOGmOs0R1M9Sa5LNEcutXIcT8sm63pXBisACZGPWajVrt9vh+fWuFAIyeieJuj3z+TzV+NW72kWApB58kpJJ7NULmrh4LavxqXJGSqzrfsjT3Vkkk7r+Sh1AjcCB437qz3kmfs686LWzankXwb/6v49Z5gwNekLzXF9fB9l8oAP+Ehl9T9BFNIkC7iZyxegdTUFCmWreJo2aFSTn83kASN6LXGYzC95plru9yVgZJGNugO8ZqW6MNjdtNpvhsiUsx8PDw3D5kg/lozG8++6TftfhIZHFu2yABmkKek8Km2p7eztcT7C/v2/v3r0L1rBWZ8C16MVN2kQY98hf+Zk3sCiYQH5fXl4G/g6e6OTkxM7Pz8PFXjynpmEpqGvvRQUTXp6Pzts11aH8KJ+FEopFTTVABheJu87Ps9zgPCwvb8nGLEiG8qh67QTJ8+xNPQ8EoUajUfg355Oz5RPG8wLKLApGccMbPppjrC+1JAFLDX7SLZ/MkkVVQ+uOXCxJD5JeYxOsIc/x3bt3KaDc39+3drsdBNQQPpPjN3nMtV7F1VY5PK8FQHLnN81CsSThVpvNph0eHtr79+9DGRvcJBYxF0+xUVWB6ObQHK+i3G3kRMazs7MAjFiSZ2dnIX9NLWdNktcO2YCNKhtNjNc0myKGBxuGVj/5ihsAH7AAUIbDYajYATj9WmQBwKbPHntvs6eFF1SeYUFxyR7nhHNxf39vV1dXVqlUgusNB+0LMPTz87KMY7Lp+3teV71DfekeUiWRJA9ZB1jHWhG1bqPkrLFR4MZbk2pFeHcbpKc+luhvu90OHERMuNjXdUExNpCBhGSSdHFF1d1mYbEkuRv54OAgWJMEYXxyq84bQze/bty83W1dK7Ukqfc9OTkJ0WAUA/cK6TPqRVPeFfLurubFqlWWt1yx780elTOAqFkQ/t70+/uHe94BFHgzs0c+XT8nizfMwz31Q/lHAoW4qTs7OyE6jcxYT3d3dzadTm1raytEl6+urqxUKkXvrI7JmJflvwzwLrI6dbCugCQeixam/KopQIsE4ZB4YlsXGXDRi5c0+dMsnXDthxd2U+F1g8fy0QBOrToxs9SGJW+Tl/Iinhvx4K9zA6gUYUEyfMoMLjdWJBFhkqaRVevRNZtA5VGuyVtvebk9i2RUN5gBIOie498AJSA/mUwsSZKQSsSNnAo8sf0f4/PyWEN9DxTVfD4PVxhgWanyVqvr7u7Odnd3bT6f23Q6tUajEfhJXcsi12dZcFSukuooLHrSmnwQN1aVh7ezar70cyOXihuvLWIbR7/6HDyAiQ3rE6r9SzeuX+B1JybLYuU59KuWVbLh9LIr3/RAZYnJtWpUftXh6RGlFrCYSZfRqg1kzaqL1bUkcqn19D61pMjgjXo1mscIUGLdQxcgk9IgamHG0kj8PPrAx7LyZRkAzKl+NXvc40oFmFnw0JSH5W+gSq6vr1O3VfI7yivrHtzU3V5WfuTQQCJZFnCNWfnPMSwo0hPL9UpZRXYeEm2AyU+7JJ8WQ8cdnRgPMot4C/7fWxLLPDd/67u9xLi3GNfhLZnYyy9qFl1QJHenG9NfwervOWZe1c3W9AvNtdQ8S03L8soiTxfOy6R5n7xIL/FyaOeccrkcUmX8+nh3TeXWV17y6V7QfRMzQLCefN4mv0PEGw6z2WwGJVYul0PmBQojz+CayuOHrhnKejweW7/ft1qtZmZm0+k02mVJaR+UmcYvFtF0m4yVQHIRkvtos04KJPJ4PLaLiwvb2tpK1duSWK61lxw2NrK6cNpmiVQaeL9VJ0dNdd1Qyq/NZjPb29uzq6ursAGRS8sYr66uouWUsefJAskihwK5ry6CKtEoKoqD+UEWBUdttKCaXN0fdYPz5rrUOubQkSiNdchn6mFjzjXNRz0gD5RZ2Rzb29vh/9Z1W/kcfY8Y36l7mzmO5SYDmtRoU9s8n89Dlx3iAz7Q4T970+GxwMxSinoymVi/3w/rNB6PQ4muKjNtqUY7NIwZnY+895nZBpakAmTMuuMB0Rjz+TykV9zf34cSJfonUlakPKXehcxXJopJ9eWLgN4yi6yTq4nu2seOmleqLnQjq6vAlZdsdjYaz5dlbb8EQHqrVis3tOTN7LHXpXYwUsWXJMmTIJfKru6PBniKcLkXWcbX19epS+49tWGWDtrFmph42RWQfSYHe25VcNH5UE/IW8sxbwSQ9HSDBna0STLpQuVyOeT4ag9JvzbrrJenClRGtZAJkOG9wAdfXFyEoKfSWARlGo2Gtdvt8HfMmz93m8jgx1og6TVaDCixqJgY+sYlSRKsyNFoFAASC0BdANUkmNeVSuVJNxeeRV3uZSaGv+H99/b2UmV4Zha0HUnWWscMYACSmgDPV79oL5EWw4i5/Z67863Q1KKuVqvh8CuPnCRJ6ChTqVRsPB4/oUl0PvmsvF3uGLcFSLLfNGnZ/70qCA+Suq+VA/dg5LnJdYYHlJibrT+PWZI8E7IB8jSFQWETjCKFTTvgK4h5sF5HnkXurxpPeDTD4TBVcggw8qrX69bpdOzu7i54Klp5h1WfpyVstqElGQNJ5RA0F83s0fJSXsu7QmaWcmPVvSbiyIHTCKoewmUnSJ8fy0cj66SGYO2Ox+OUli6VSinXczqdplwuBSWv9YqO+OozqFsae2mggxQL5WK9tcV7qjWp6U5E/n0z2CJcIXWBcbmJkBIA0AIH5oI9qCCjc6AUTyzCHXNzV10XHVn8Xex7T3dpRogHI84PbinnBCWmmSUxpZrHOnmsUPBUZTabzQLwAZBagaMAj7cQax6cJ0CarQGSXkt4oGRRVAOwebFUFkVzletDyyhPRgcdDnJWutAq8iiHRrRQE6/b7bYNh8OQQnF3dxdAhH8TtPA8l1opquGLcD91eKvRJ/zH6nc9FaABCQ+2/L42p1UrEktFP0d5Pj5nUxm9nChXbSaifUE1pYShe0AzFnCj123Ft2h4S9H/zM/5MoffAxzvCXePUodrjqXe+UyMdfdoFj5wzhQHOOd6tYN6dLoeXrlnlUPneb5ysSTVitR7M+jBR0KrunPUnkIqN5vNJ01qlfMBGNEmtEzL0rCryqLPr3mfrVYrNLu4u7sLoI9G4/cpa1MXFl4MWVl0X0nAM+Q1PKD5RsIeNNRV042HG+7fQ+mIra0tm0wmqQDV9vZDg2TlBpXf9Gu2rnz6vedZUa6+o5PSQBrw032JBWNmqXXWNBOfDrUukMSsxSyO0f8tQxWXvvR3/TPDQaoy9/yoWtybUAneCCHgUqvVUl4Kz6xnXq1KXw6rPHlMQeQ1NkoBUqAEYEjcpfFDkiQBzAA3zzMQ2dbGFoAO6Sn39/cheXZr66FZhm/AsOqG9dqOhWSSdZMmSWK7u7vh9kB1+XE97+/vn7igqgx04YsqQfTWle+Ow3xqmywtAtD2aWYW0kxwYbXPoh5EpU/K5XIIgumlWWqF5WGRxXhJr6SUCtESUU0LAijv7u6s0WiETuaACgdaAwo+vWnTddQ9F7P+1RPRPatzn+UlKNWjCo65IjdWwTUPWkityFhbQax6dZ3NLAAjTSs6nY612+3QaYvotwK9p01+dU5SXe6Ym41gWJHwkKQkaLUNWhttXSqV7Pb21obDYZi8y8tLu76+DhwTzWo14BDjPFYBS+TwMvL9zs5OOEA8l7bR4t/eKoU4J31Ic/V8yWJeQ0FD26LR4Ya8Ru2PiCUJuKirlSRJiuvTv/FFAfP5QzEAJLsCsqYIrbuRn3NH9VnM7EkEXikBpX9wSalG0txJbbwSSziPue/rDD3oGrH3HXMUwBQk1OL3bqifOz6Dz9Xf95bnOkCpZ9BTMHiY7BXWBLAmcNhut0Ov2U6nE1os0jtS83aLAkizHC1JJew1TE9jTK13xmrUFBFtZ0Wj1MFgEKJeBEXm87k1Go1QOufJ9lU1uv6uWnbKo/B8zWbTptNpqgfj5eVl6tY6FgigrFQqQR7aq9Xr9ZT1VgQ3ycbR1Bh6RWpzXX2ZWfh9JfmxMDm4WDX6bw4nFjc0BcCqnVvUmllVJv3qZeXga0qTWpKaUqJWPYB3e3sb1oc5KJVKQaljEXuQ9BUhzw3/OzGLWNdNU2XUCgbEzOyJ0soKzilg+jlTrpmKKap8VgEepRE8NgB+0DYa7FNDgqwXQLLb7QZwBHAxMDzNkLfbvREnySQkSZICE803REMDENzPrE01mSgOo5nZeDy2crkcejyOx+Og+bVLtHdB1MpdVRbkyQJJAIZ0n/l8Hu7ZwM1W/ogUGfgxEtJ9Q9s8h24Q5RaVGFe32W8sntXMUiCpyfUaCVZOM0meXg2gVlDeUe4YN2mW5rS8+6lzj5WDy+1z81hH/bkvO1X3fRN+1XPBvvmzUkuaP+xzBGP8JGvpgxtmj14Un837+72xzlBLUkGy1Wql7tdWzwsl2263Uw25u91u9LI89ZhiFEMeBsjGlqRqNz8ZCpK42pjMaAR1OTmUuNYsEgAEEa9pHfos6wCk/n1MJk8nYOLf3NwEIMdi8wBYKpXCYatUKqk7RTbZfMsM734CcLqhFNz4HmDReTBL57VpRoG6zqxd1rUAecvnv1+GwF9k7ek+VgDUVDMFx03cUX1ez0HileCxcJGZRqZ9uawGMRTcvAvvgVMDORg7fu3WGd7dxqsioEvwU4tE1JKkyz98JIaVYkWW9Zu3Mt64djvL5SaBVStjIGK1/FCjVMpbMVS7wlcuAplNJsPzSjFXHmtL24Rp5JgoOAsHN8aVpLG8PbN8cwc9UKjb5dN+dEMhiwKjHj7PX5pZCjC3t7efBBqYy02t/VVk1s/0aT3eTfXz46PJPlDiubos2Z57Vn1m5VH1Lhi9r4fgJRav8vjPBZA8CMei37jV5FB6sFkHLBUXyHppNBqhFyZuc71eD5eZqbtNz1lAEksTDEBpa540FF6eVFauDS50Mnh4uBQ4SeV2NNGcA+kPuEb7fM5e7BDmJY8OpRTUwlAOCc0Ph8SzESCgca+/57hoa3KRG6aHVS1KbeKqCf5+7plzDq+Xx4PLpikzWTJmzaFX2vCNKGgzC7w2ua4+QZnPyJJn3ci2V2DKQdLTgDZ2g8HAptNp4Aj1Ir0snlSfSblpL5uZpawwgMYXDqw7FCTpqI7x1Gg0UsFExQpKKbkGZW9vz8rlcrC0yX4hQk6DZIBXlfSmIxdL0uwpQQt/yMMyUb4Di1k8KdgnjCrhXDS4ePn8Zyk4amRbU2xwXXErCDbFXJk8rUh9Rq9s9GuMz/NzTH6rl1tBQxVGDCC9a5pHfuFzQz9X83bVi1EOnHWKpUchN0M5a/26qiw658pBTiaTUAbb6/Xs9PTUzs/PbTKZBPqKwOj19XWq0UO1Wk09o4KkpkQpr8zvU1lFACtmca66Bhpk8sYT9/P4K5sByd3d3aAIsJgVF4hTEGCr1+sBXPOgC3TkYkmaPW5MBhuUB2WyFrk7GmCIlcH9GiPGG2kVhwKkNlfAzSayr1ZKrFSvaBliX/3gUGgwIBaJVRc8ls7DXlBrS4EyL3DUz/TuMXyXb7pLKg8uGlFWqBIf+Uc+/96bWpE8sypb5SJHo5ENBgM7Pz+3k5MTGw6Hdn9/b7u7u6HtmeeXzR6zKpSy8kCsl5wpkJlZqO/2Z29dd1vXAuqtXC6HogwffUcG7fqjudNktUBJzGaz4M0ArD4w+qtzkkyG2WNzUI1o6UP6BFXlYxYlPHttVqQVwljkDmn/RA+SfK85aD7txltkRQGl58piVhC/p3IrGZ7FH/L3WAa6Hj5DIMvdLnL9eH/fOYpcSTN7opjJmmCNYpkOnqNeV5YYLwk4ebd7MBjY5eWl3d7ehqozM0sZJQAd6U38n1qrKqsPvik9lgdAxtZCo9jKg/IMOrc+e0Apgvv7+6BMJpNJeH+yADa1gv3I1ZI0S5v6Md6QEePzIKuHw2HQEqolvXujn5vn0I3rk7HJj9Ru3molal4gz6e5e4vSYfKUJxZs0JeP1vpUErNH60ytKX3G2Np6C/KlwZHnVSrAW7Jm6SYqWHAACAAJP6cUkUaTi5BFvRZtSUcVkHJvWMTIBUjiwen7KW3lA2vI49No8nJX/XkCyDCeNHPE70dwBNrK7HHtMD7MLLSBKyKrIjeQNHuab2gWt5I8QMLBoDH7/b71er1wax8aJCv1oqiNqgAOKBJxPDs7C/dUaxWORt6Ibsciyll5a3kMBUjlgvU6WN/QlFQtM4tqYP2ZWpj6mZq350ExzzWKzZffC/6rHlQNZOid6hwwQIhgo7bvi1XZrCqbV/Keu9WMD4JKPC8WH2uoa03Um8IMLDCfEqTAyWcAkpoqtAnIeJpK77Snogn+kRhGrNzTzFJWJO+nHpzZQ4AUhacxDH2edfdgriDJ0AX2VoefOADy7OzMzs7Owq19vV7P+v1+qoej5zWL0uTq/nBhFpdlcaMgpPrFxUWqhJKF4n1wKbIAUnlAP3/rDA+QysvB5WhZqPbvI5ARyyLAyjRLpwZpMMZbqCrTpnIxFh3aGED69SRggztLVZAqYlUq2oDF3+m86f7TddKeqT64qa4zYDmdTgOPyHtp/iR/l5Vg7QN4rO+iwN6qg8+BytBrmskAoVSRs+Lv3PHrqNSBgiRFG0Xk5hYCkv6A8TO+qrk8Ho+t1+vZycmJffnyJaQ8eJebA69XO+QdCNBnB8Sn06n1+307OTmxk5MTu7i4SF2/ijWJ262VQDGw8a5PXuSyDg+S1Wo1fK6/GAqg1O4+uEVmtnDDqTuvjU+VT8p7fbLkVZm9pauuNf+eTqc2HA5tNBoF2UlTofac6Kreox5L4F73mZV/00CFvtS6Aizu7u5CcQV7Si1Jn12S5T7r3tPSxNj5XXV4Y4j5Pjs7s4uLC5tMJmb2kD+8v79vZpY62+op6pnxQS6uyVUvLu+zlStIqobSvDv+TxdEecjBYGBnZ2f29etXOzs7s+FwGDgY3xmEaJcmoud9AP3iYul+/vzZTk9Prd/vB+6UlA0lxLFKYtbhcxp9U1lilqQC9M3NTQokSeQFQMg5IydNn1OtSYa3WH1vyU3BZBVZFaRRUmbpAA1zAHWiaSRJkoRGtNQY05DBW5Kb7DulIHT+7u/vUy3BlB5hz/s+q2aPd1GT6kROIeCitfMx15v38N7NuiNmEM1mMxsOhyFaPxqNLEkSazQaITKt82v2GJgCGGMt/2jGklVmm8fIDSTVSlIC2Gsm5fq0SQSW2WAwCBsX0MBd1HywmNuz6WHMes7JZBLSMRQktU8hB9GT4bEghn6efq7+LA+w9BFODiHpMFhKKCR9Xv6t1U7+/X2wAD6Mg513o1o+V90w764CLjwb4EfwbWtrKyg/oqN4KpVKJew37TOgDVnU0snT3U6SJNWBX/c74Ke0jFqValXR9QhXG5DU9DN/Nkul0hNw2dSa9JwkQDkej+3y8tIuLy/DzzUnEp6V9Cuzx2i2drDSjlSezsoLHBm53but5rAK4c1fD5LcHQM3xMJq1Ev7TpIQHGu+aZZfEEf5FBoNYD0SrCE53KcpAVDqRvmUhiKHBxEShQFJbYpMswEzywQ0zakze+w1iNVGdydtXqJVVUVEg33Agz1CBx/q/yllg48tlUqpfEHAwkeItTFsFkDmJUPM7UaZaeJ4ufzQu4D18NVf6qGZWQjGaSRYuzUpH6lAqd5fHsMbUFiDJIVz39Xl5WVIQyJyb/YIkqPRKHQEg0cGK/IGRh0bg6S3vCDEKcwnZ1BNfR8UGY1Gwa3W+zi0Uaf2l9MuQnnwQ4tkUuDXnowaFfWRX60V1gbDvutR0XwdB1BJce3S1O12bTabmdljZ26ej0E0lc2oFjHWI22tut1uaEqgl0wVZU2iiLQ569HRkW1tbdlsNgvWGcAIfcB6Uobnu+XrK3b3el4AieLxFrFfI8AdBaDFCDw/yoAerljNGuxRK5Kh+Z/eFd9UPr56+dgPnK+rqysbDoe2s7NjNzc3T+7eUbzA29T0QM+B532ucrMklcMbDAYhwDEej1NA6U1wQPXu7uHemPBgkn6hViQHodVqhRrcIsDGgz8aWBvW+nZtmpOoz84L6yrmsuUNlPp+/gDSEBceUp8VxeOpAE0w14PZbDat0+nY4eFhaEhweHho3W73CVCqrOvIG3OzVR4AXxsmsO+wwFgvvqIUiLLiqXhqJ1YXnbc1CTepMmFBqky+y5LOyc3NjQ0Gg7B27F+1ms0sZbliteVZyqeyqVxYx5wlgi6DwcBubm6s3++n5hk5Yul4pNlhlBRB75jlZElqoINo9ZcvX+zk5CQAJVqMv9G/N3uYUNxA3+qdjasuIkXvCjh5Dx+N1qRwr5X9wQV4eF64LXVDF1mTeQVxzCxlTXIAeW46p2Odq4umrdWwJrEoqb9tNBq2v79v7969s6OjIzs4OAg9ALPSZjYdSiPw/FwVUqlUrNVqhYR/6BF1NzV/1N/nrK38vBVZVKDQW/w0f+C8VCqVYPUrJ4cS0GR4VeIAqXpyZpaau1qtZmYWbTO2Lh+pVrKCv95pBcXBPsNC1PcwS5dwIofKpl6b9tjcVCHr2AgkfRRLLcmTkxP7+9//bufn56mkcBaKTafutFpcPv1CtbvW4RYZ5fZpO7GglOYMmlmwspRTgvgnsJHV+r8It1uBEjeaDeytdXXZlCdWrc7GxzJoNBrW7Xbt6OjIjo+P7eDgINVq31MieQQ7VB7yOxW4m81maBBxf38fvJnhcBgAX616zYVUThVPJVYFkrfLrTJppgNBDb0yREt3CdRQ6811J6Q2adSXz1Srzuwxc8RnpOQhH/tMqRmaU5tZsI61IINzpbQA76femvL9+tJ1ymPkzklCsPb7fTs/P7cvX77Y5eVl4BBYfL3qwcyCq6fA2G63Uw03fbVILHCTx/AaVGWMRdIg0dW90AvPYpdIZbnbKkse1iTviaUNUKo1AZABKL1eL9Xv0286fqY9AlutVmhtRQAnFrzJY3h5yAvUeSelZjgcPgnWAHyAkN7vHHOxX4JDXrRG5LFqFJvI/GQyse3thz6eKAECotPpNKQ9MVg3zaX0+YVm6yeQqzx8nu41rEmsXQAT40rjGP65Ne93b28vnDlN/cobIM0KiG4TuUK7UWaoIIklUqvVrFwuB6DUwAKWZNZ1D7ESuCKGmvuxHDMPZOoKoul8XSq/UxQfGRv6mfpvfsbhivGm+h4xN0rpBbwCXaui1sk/m37O7e1t8DTMLKTCaEK52WPXHFW8ulYvwR0r5aQpPkoLqFvtObvb29sQoNISQEBSeUgCJhrM0jpus/yi2rrX/DyrwUCKFhkkNKrwVBYACS2hcsWU+T+NJakDtzvW1WcymYTFgCQmZ03B00+kJtbGcu9eAmR8DqMCpC8pjOVFak2uRhOzLMiihlqCfDZrprcALlMto7LhrmlFVBFR7ZgsCipmlpJHPQ1V4kR9+ftF7tpL7LFFQKlgyTNpehqWsUaMteMP9d68H7+bVfWVdyqNzp9P+PfpSzwvmSOaegZm4L4rgPrP8T/fdCwFkkzccDh88nM09Gw2C1EnTGZcA7QZWpxFQ0NqmZGCKwCLtrm7u1uZ3+KZY5ryOXmQRVtoAehZLw3yeNlQFuPxOCy0BjXM4i3JNpEpNpQyIMlXOSEfmPLUgk/l8gqxXC6HbAUFnqy12kQeP++sXez6XKyTWOWG8nu466TPeMttmbGOTDFl7AOHaiUqN+mvOFYrEXBhXcrlcqonKvvTu/K6T8mlXVYeT8Hxfnpjp2+6679XT01jGb4Ch9RDPV96t1RWzCK2RlkL8+z49OlTYma/6denT59elTyvUabXJs9rlOm1yeNlio1S8iyMPmi0z58/W7PZfBG3MM+RJImNRiP7+PFjsNZ+y/KYvT6ZXps8Zq9Pptcmj1lcpthYCiTfxtt4G2/jX3UsxUn+ljXGv4oG/C3L9NrkMXt9Mr02ecyWtyTfOMnf8Ou1yfTa5HmNMr02ebxMsbGUJdlsNs3M7NOnT9ZqtTJ/L0nS3T6otez1evb582f74Ycf7M9//rP9+OOP9vnzZ+v3++E+YVKCqIZotVp2cHBgR0dHdnR0ZPv7+9btdu3g4MCOj49DdYcvT4xFsL755psgw6ryJMnTGm4i+STMn5+fh0bB2ox3MBiE6H7yj9Qn7YtJw9GjoyP78OGDHR8f2/7+fqrbkSbREyUejUb23XffLZSJ52Zus9bihx9+CGtBYi8MjM//1K9aDkYCNg0Z3r17Z7///e/t+++/t9///vd2dHRk7Xb7Sa39JmvEmhCtJnuAXLt+v29fv361v/3tb/bTTz/Z3/72Nzs7O0s1UyH5nKqhDx8+2B/+8Af793//d/vDH/5gHz58sE6ns1ajjk33nVm62SwR+NFoZL1eLzSvHQwGofySF9kYJI5rgYbuK91nWuygDVmI7I/HY/uP//iPteTxa6bNKr58+WI//fST/dd//Zf9+OOP9uXLF+v1ekGG5B/RbUozP378aH/4wx/sP/7jP+xPf/qTffjwwbrdbqoh8rIVeLE1io2lQJIPa7VamZPhAZJ0n7u7u4VNWH0umuaGmT3NxyN/UheZCVqUphE78M/J44GGHLv5/KHDd6wiw+eCkb6kSkATsTWBnvLAVqv1pJwxVrmSJRPt8LXOPLYWOtfkn+n7ZQGCNrnQXpLaho16cA5oq9Wyvb29helAq6yRKi5KJ5N/pIn5ZPBYIr3KiEza3k4Bw3ckXyX3c519Z/a0Gzdrc3Nz86Thhj47L/J3+T3tVak5yFQaqZza/cgbH6vK42XjHJHKw2drFVQWRjBU1tjdTX5/r7JGsZFbxY1qCbQeOUuj0SiVb+gvR/exIwUmLcfSXDC9rEifIQ9uROVRoOHztSRsOp0+yVXTnphm8YRnXvw7ljS8TqWKArvPq2MtyFfTXDo+i/tRSLKOPbOvoMDyiLUWKyKh3K+PWlqsC6/YntN59zX5un/JG6SvYZIkoYKnqCR5M3vyTL6prs8DxVvRBGw/X+S7xnoPZJ1DHXmdK74u84r9nebC0odSq6qSJLHd3V0ze+xsvumz51qW6NsZYf5fXFwE14CrAnzxOu/DVz3gk8kkdTGTJnfn2dopJpO6crxo/d/v98N9N3onD5tVm0roQmm1g1ZDqCUTc++WWWy/FigquuH0er3UWtCAFUtSO6poTz/ekw2sNcC+jBT3zXdlyXt9ODTsOdaEVn2sC2VuyKrlbD5R2yt3PIH5fG57e3tmZqEEcNk1WUUm740pOGpXb2TjChHftJqhZbSagB3rUJ51jooASE2Uj5X66s/K5XIAR71dtdfrpbpWNZvN1Hla5A2tMnJrcKEt2vV2wX6/bxcXF3Z2dmbn5+eh2YW/IF3fj01L9j+Hla5B7XY7tL3Kuweet4pZFC6NGo/HYbMiG7wQmh0lAPCUSo9Xe5qlLQXfvkrdotilSM8tuK4FTY0BRq7qPT09DWtBdyadXygMqhVQWlhayq9695pGF54Gybt2W+VUro4bLHu9np2fn1u/37fRaJTqLsNeUS9BAZK1perG1zfneQBVHj1LaiDwotMPRgc3deIV+J6RPKMqAaWDaJWWZWzEyvw2kc0Do3oCatnq95wP9p/eOUXFHvN0cHAQPDh9beph5gqS2irt9PQ03H7IISXA4Wsz/fuxca+urgI4IOje3l7oradXSOrfb7qgal1woyNWI5qcIA0/8y3lzR7bplF3ykKb2ZNNEGuK4Wtcl5FLiX5tW3dycmLn5+eptRiNRqEZbblcDr0YO51OCIgBkloyhgLQri7wj3Qoh4MsoguQV2R6ydTXr1/DPUQoBW387K0s1gVlCBDV6/VULTTWs+eg86J4kEutSL0yBE8AkOTed5VNn8U3J+E9lU/Wq0d8DbeOPJSb5/c5C/pSsNRzqA052NO41FBfBB2RjVcexlMu7rZGGf2taFgt6pJivWS53EmShGCD/mxra8sajUZwa9UazYOT9FyVXnvLRWBcIatAiaXChVm+dRPzw1z55gJmTwMHvuZ5Wa3urXroDsCDqDtrwTPrHdM0zm00GgEkcfs4kAqSGj2lUzmWZKzfZ15r5AEOANF1Go1GNp1On4CIrrNyzngJuNasEd14aKOGG54nB55FW0EhqPeCl6aUiQZq2EvwcihOrEs419g1rDo2Bcgsfl9rx7W5sxoPnBmlVba3t204HIY6dPoN0EyZLlS1Wi005f1VLUnPM/jFBViwtmKXweuiYNGwUXSUSiWrVqvh772rnae77XkqTHwsE4CRF4dwPp8HcNOWWwr6AI5+FvJpsEaDN7GI4qLnV8sekOf5lRrQpgG0Omu329bpdMJdQjs7OykgwoI3swCsuNm0taNJct63C3oZPaggqzZZUc5Vgy4KCEqtsHer1WrKNWV+9MqOvLlw3XuqoP1toig4PQsa5SaCrZF9lXN7ezsTmLKoq3XWzbvZSmto4Ek9FO9hsXb8jPPIz/TccfNnLF6xCVDm4m77CVESnKCHj/zG+EizNFDi1pFuo5onb4D0CxqLDhOMIlqPRWxmKStQ70chHYO7Snywxn++l2cd2RZpb3UdcevL5bLt7e1Zp9NJNc9tNpshVYPuPgo2gKQP2pDWUYSr/dzw7enoQQgwAEAKEGaW2rOTySSVlgJPq3fm6N7LyzLWNdNn0SAoFIkHf+3yr3mdZpY6a8r5xdzdLKBcV0b1ynxUGoUG2Gu2hf9cgJ65UUNiZ2fHGo1GKmPD0we/KiepQw+4aqZVx3w+D24eJrOCVxagbPrcfqNiXRBA8q3geE6aB3Og9LY9Wk0RXND54SDzbw9o/D+KY9mhG0tzMsknM7NwVQP/jxWpF3rV6/UALnqtAJYkxH/swrOirEjkU8ubBH3lR+fzue3s7ASvRdv1acSYfotY+1iTmsJEcQMHOe+MiphS84DC1cvQTHhWzDHUh4Ik1pa2TvOftQgskXFTKxJ5CIDCr0Ib4GWyVv5zffBS3e+dnZ2UNapxirzWaGOQ1An0h1I7EGuyKJtbw/xmaZNYuTr9HP/vvElz3Tia/gM4qgXJAVI5uX8HawruZDweBwtMrWPV9pqL6RPwIa+fG8yvt2zhG0ulUsqapOoEINjf3w/31Ozt7Vm5XLbr6+vwd6rA+DsFSn8/TN55khqtVIDkdkHyGrH81GpU3hE3Fj6Ztcd6073MTYX+1s9ND6FXzF45a3d/XOzr62tLkiR1/YLeqcT8a/I5YEt1G9kWfJZyhNx140FmXaD0wTWi8r1eL1U55C1kszQ4sqf5XrlkfXlDKo+RiyWpml1L7/TOaT38EMg6PHfgLQXf3Tvv9ATdqLFmuZpArvcg622Iukm59bFcfrjmk3uF+Qz4IkBSeTXuTla3EWBaFihRVNw42e12zcwCoU16kt43rXfVNJvNcMkWd1hrZFtBUi9n0wqhvK6SjcmnMqqlh4yaS+v5ZQKLKhOWiO4t9h6uXMzdZg/lsQcVuNR7IVA4nU7NzEIAqdFoPJn/er0e+NT7+/sATgDL1dVV+FwFZR9EyaLDVpEplmlxdnZmX79+DcG18/PzkEbn78FSPMnKF2ZtY1ymft1kbASSHBa+91akup7a+fi5Ls+eV9IrHbCO8rZQsixJNDEWBu4Oz9JsNu3o6ChweHodKS7t1dVVuJsDnpbPI40EkJzNZqkb4LAUANZl1kQJfMCv0+kEqxBXFIBEkWm+o1IF5XI5WAPKw+o6P1dxk/dATqyoZrMZqJlms5kCNLXOuOjMzMKaAhxQCX4O6aadd0CAv89S0HgwBAiJ7nKNbrfbtXa7nbpymfxW1mw8HofvcWcJimLlxRLMPd+/riXJ5yHLxcVFSEkDHAeDQch6wRr0BpKv/vLPkzcFoiM3TlK1r3e3vcvtU0K8Wc/LW6QA7jL3Vq8yYnykAiX5kmy0+XweQLtWq4VosKa+oM2xxMwecrrg7HgfQFL5T+Zqd3d35YR5b2VpVJbDQvoOlqDW8fLSg3Z/f5+ac+TSVCfNTYvld+YduPFAVqvVQpSzVquloqWa84pVNZvNrN/vhyYOGkFVYNU9kCfXpQAUCxh6PpygBLmrXM+6v7//xJMhh5Bg4dXVVUpOPlPzEf1rE4D0cnkrnjStWNYLn630mrrdWXNfZHAwN3ebr54Pi90+p3/jh258Di6bQO/gVs4rT7fbHxQlsxWwzCxVcdJqtYJW12oV3NOrq6tUQINkXrNHl0FBUkHMB6ueWwutvaYZCNUWd3d3KQuMZ9JGFRwohlqDzJHOl//cWG4nG39Ty4vPMktfMwo1wN7xFRy4sKVSyabTaQAT5X491YNM/md5HkZdV/+86gZrKS88st5N32q1Us0imIvr6+snzV+eC3zGZM0LKEn90TtvtLjEpwYuirLHLp0rImaRuyWpLnIsZ3DRJvNkfK1WCxUcWGu0RyMPbxFf8dzwGyUWZfTfK6nso6qkw+BqUxJ2f3//RFn4muiYq12tVqOW5DJAqeBRr9dtPp8HUh6QxIIElPVaVbVw9TnhaQn8oERi6Vg6vwqQeQAlcqpS5X13dnZSgMMzbm1t2e3tbVBiKDLeA0vUW9axjjh5PL/fb8yxrz7RqLSZpc4Heanwk8jEey2yHv0cKucfU3TrDqUS2OvauCaW1uefWaubdM1VwevZyVOZ5ZoC5F1urR6JBV78UAsIgDw+PrajoyM7ODgIvRfpT+j7/G0yYlakanfvfgA0ysmhydHmpVIppM9k0QOanzedTsPvYKVqbumylqR3t7Fm2Yi8f+yuaV5mj3dTw11pEIvnqVQqwdqJJSZrVF61fJ5AiRUIcPukZIIHZpay1P26KIWCdcYr5r2sexBjwUJARAMRsXxMVX6aWaAZBaVSKXgMaqVqpF/XwNMmeXVv0v2qmSz+jPmcRjMLe1aHrhdrpWlPsbQz9sgmI1eQNEu7XQqMvsQu6+/UOqOB68ePH1NASQPeer2eciXWnYysTavWCM/oKYXY/eBatgbvGDtUaHOCIvp7RKazLpB/bv4VJM0suPf+YKi7yYsDxGd6IB+Px0G2arX6pMQNq01zPD3A5+ly68+gODz4lEqlYNF7Tpu5MrPA9VF5pMGRrLSmTYEkK2dRqR3l6RUkFSTwXhSY1ANQZQbtpcre372epzWpsjK84vd8r1bdqQfH+dC0p6wChn9Kd1utSU+6Zk24bjYWDUvy8PDQPn78aO/fv7f9/X1rt9uhBM53JM9jsyo4akqBPh+bVRtSqMmv7ioaMcbR8ZnazINXpVJJRWmXtST9PJpZsJIAvUXrA+ioFamcnrckfT9ND0DkeDLytCZjFioyKfiwFmoxq7yqoAl0UWLJfsviwTcZMXonZo0jm+47Td1SxWxmKQvNJ4oDkvCWyO0bq+QVFNW/9Rxv1rlVy1mBUvOFWSufVZGV+bKJDLlbkjqyLMbY8JtVE4T39/ft8PAw1BPHrjTIw4r05Lm33hRUfGDKc46lUinlaio48lUB2VuA2vrKA+QioFQlpc9NGlHs9/xGVe4wFnEllYTIKRzT9fV1quU/8vD+MX5y06FAqTKxlmaPpa5eWcUUOlaaurGLLKx1xyLvRXlIZPRWL199JoGCo08UjyVcx6ixvNxt3fuaF+wbcOjZ9243RoTZg2VJ5J73Uqogr7XxIzeQjBH2uhgx3iE21AqCH/KNE/I2q/1m9ZyWmT2hDWLpLt4qi80J3+vnaABkUdrJMlYkc2hmCzefB299XgXTrPnRfD7ccL2ewVuqKAIFyryGvheyAozI4mWIccwKmDEZ8j58/rn8nkMeVaAKah7IlH/1Hf197b6nPvz+jSmVdYYqIjUmNJ/WW866T3gP3Xs6Zy8xcq/d9geJRfONLbIOvC4YE5rF+elibvLMZk/vFPE8oFoZpGD4FJLnaAVvNeh84Fr4SPoq4OjnkaGBjWXnI/bZXi6lCqbTqVUqlSeAFeO2igZKlcWDuq8o0QOX5frqa5Usg6wRm2PdEzrveC1EbNVqVBnZM6yHv15EG0iYPfJ/XjHkZXh4Cg3+XosWtDBBZafyyRsVSk+pUeGT4dc9M1kjV0vSm/oa4vedk7MOoOcsfDqR57vyAMjYYdLUFrQgVTMk8ypQeoCMfZZaj/7gQaSrdvSW6KrDczJZ7xGzeGOWpwZ3WEOCTqPRyEqlUspV9Jynn588N3JMBrhUrF29h0gVtx5A1t+389IKnjwPYczd9gCm+15BTPlinp2ab67siN0vxd8qEPvXJhak0insG01HazaboY6c4TlYpRt0rjgzmqXgK6J8uh5/uy5W5HbHjVm6k412W/HRT78x/VBOrWi3R7WUWhCkyzC5EMWlUinUC+P2P/dcWZYKC55l6ak23sTd4Rmyni2LGvAAiTUAvwkokrqkh1wpEw4565i3JemVnd+H7EENMun6Mvgb/X2Stn0rL29RripLljWZNe8eJBXUqayhqsU38AAktTmt7u0YzZAXX6w8KiWy7XY75Noy1HI3s9QzaoYEnunW1pZNp9MQ5AQoY20UN91rhViSSvADlLrJYlqC4d21LLI9b37I85Fq8XLgcSeVU1lEcnuLLMuSjFnPvoLAz82qI8t6U4CMWZFsMDY6kXw2LOs9nU5Tv8+hIHEbZcCGz2tkza/uQSLyWIaAhVIpzI9G8llnrQqJUSHrrkuWOxkDSbwVDdAAkChZba1GYwxvSWr+oZ6jLB5y0zOm1rDvSKVBKp1bqCdwQiP8ZmmaB7CM9azNy9rPpTM53/vonAIl7ra3JHVoKkdsIRlFkOd8jVlR6vKYWcpCWkbzxninWHoHgaBYdU5eMuv7xDaRt2r4G827JPdS0zPY6Ds7O6nNqhazzmseGl7dat138FOxK1gVIPEQsJBRiFgr6gnFEuZje2XTwXuo96T5qwoUSg3M5/NUuZ++9H6f+/v7sI89QMaokXXl0mfl/XyMgfQdSnYVxFFWMYWt/++NsFgS/qYj13u3FSi9AHpo/ObKGrHFyRsgeU+vVdWa068+nSQLwBQM+Kq0BJ+LpRC7LD7LUi1iDvQZYyCuSf40BAYwNKCgzxezpPPauAqQmiztS95wm+n2Q41wqVRK3Ra4vb0dQJ91Zg/7S9B89sMm3B1fs6LYHiTN4iWi0B760vu41Sti3/nPyAq0FTUUPKvVaqqZc1b1lt9HijdFcMZmOV7fwPeq4TR4o/fSeC7IbLmC9CIBUiNwsXI9rNzYInlZYpap/z/V4FQO0NJMu7rESuHylN0/a2xTspGpACqXy08apCpvq00ysgBz003sOTnf71NzN9X1JucOeTqdju3s7Dxxp5kbDmCs5yKlf97zWWbevZWl3kosSIlLbfZ4hQGUAOdK5wCgVPdTPYJFOb55BUXVKEDpsBbModmDUqL9GwCpgKgxA1WE/nfyyj7wI1dOkgf1WlytSSVWNYrnR9FWE+/tyeVKpZJKxbm6ukqR5RDLPgilC8YB88EaFo0NiGtNdRG3FFIS1263Ux2FisrVM3uanaBaGXK8Xq+bmYXGG7qZmUe9ZhZKIgso+dm6z8peo6Erl7Sx5/z6IBsydDqdUFaprvl0Ok11aPJJ2YAlPKHuI7Pl96zfe7xiirpcLqfAhjQfs0fLlywDQFLvgDd7BGRtDKEZGlpxk1eepDectAsQ81wqlYJijdEXqqxURpose+rjn86SNHuaEKuaTrmg50rsshaiSLPfW5Gqie7u7lJNA1jsJElSkU4P+h4c/c/VzSZvzHc50rZweVQWxUbMilRuD/k4WI1GI/S4VFJd+SLmUS9CQ2l4y5u/WXVDe4XMdbInJyd2dnZml5eXgafTuVZFw/MpNUREuFwuh+t2F7nzsdSvVQESkGTe7u7unoAX7jBziPKeTCaB/rm/v09xknoToa5hqVRK3RigYJk3F+7dYb3Y7PLyMtyVXSqVoo2a+Wr22Ll/NBqFXqdmFmQrCiDNCohux4I2mmem2p2xKIpd5IhZknqA6eysmhzQ8zxILOHYJ7nGgiGkmLTb7VB+SSs4baKad9E+QzeWBj6Ug0uSJHw+2j4WjUUurQX25XJ5utusgV4N8Msvv9jFxYVdXV1ZkiRP+mbq3dlY6EnycM3v3t6elUqloOBZW08bsZ99dRGHddnh9979/f2TtnXasYi14HxNJhO7u3to3qE/87mdeAMYAwqSWvmS5XKvut+88lVFRIoSz6kASRMRD5jIR1WXmaXwJSsmkBdgFlJxo2kYnsvxnKS61T4Y8pJASdCBMZ/PUw0blDQnuqZyAi4++qYb1hPoaHesSW3NpV1Nsor28xjeilSrSZ/X7LH6g3mLaW9vIeln5AGU/r0UHPr9vp2fn9v5+XlwRbUzO/XYcGAEo+BcseKQEUDCClLPiM7zPsi3zJ7ld9SrmM/n4RmygFKpHJ4JV1q52VjyO2vi755ScCyqUMPziZqk72vRte2Znj0UyHw+t/F4HNYJWmSRJbkpUOYe3fZEaqwc0T80pnNWdO0lgNLssXzPE+cx90MtL90Ayn2xabnpTnPVkFfLtfT6BH9NRVHgGANIuGQoBdZAk8KXCVaoQlGrKQ9y3bvc/uKs8XhsZg8gqZFSnlcT43lWv3+RXZPLx+Ox1Wq1QMHEIsTLKgDd31pdpvmo7Avl36B+tAEK+01LEGNcpG8mvMiCXLS2q6xTDBOUstG14TNj8+rBPOuz+D6vUUgXIATXjHmdLJ14JoDDp5tlE5N/leGfZz6fpzasr3rQhWBzcpkUqTFYHXAw3OcBD8Pv+m5C6l4VlY4R27hY/PrMuJxmj+3WnuOrfH4fbiL/p9xtHi63BjNQSrxYH39AlW8lWn93dxeUmUbIzR6avXJnNLdgmj2mcmEF6Z5dZqjn5Pec3mBJCZ+Wg5pZqjLMp9xhYbEW2qCWi8Ni7cWK2m98VTzwVjDRes0k0StufQqWf8+sz+P7deXJHSSf2/hsCJ8XpoCkGuOlXG/v9uth9wB2d/fYtBSXejweh7SEq6urkMgLkT4ajcIhJHVG0z5iHYWKKsP01pIC5Hg8tsFgEK4w5Vlj6xKz+HlunTNAhf+H+81D23tZ1BoGBLFatPqGtcFtvrt7uBu93+/bcDhMcWbwZnQ54sCiQNQqzSox9UPd7dh+Iy2s1WoFICcSr3QI66eJ74AJz21mqesesu6KKsogiZ0tfsYacY0yNJfevbS9/dDAmrXzN1d6wPQe66ZAWUhncl30WDBGc/7UxVayv4hqk1VkUPDmEOi9KBwG31WcYA9AMxqNohF+deVjrdaKtCAXudhYvb1eL1xjavbYk1Kjrf551WrBVVRXCkvIW5N5gGUM9NWtu729DaAI4A0Gg9DNW1NouMGPoIAqCYIINJUAvPwlXKuMmLHAvdp8Dj9D6QIQrJ0HSf4GxUaAEM6be3Gyuq3nNfw598YPIGn2AJjc7KhNhZnTm5ub1NqwtoDoc69f3ZL0gRcPAAjigzYKkN6iLIIjWUcO3B8O/u3tbSD37+7uwqXqZmaz2cwGg4FdXFwEoMFtZaHYJGpFessx76HaVANOSqaTAkMAZDgcpq6+RVnELF7vKqq8zCWudywBf9PhlQkKLEmSoJz8M6KQAW/WkiRsrDGzB2txOBymItjb2w/3ELXb7QDKqwK/7mu/3+r1erCSMCBI9yF6rw08NEDIHGApawceQFLvxcmb99b3iimB3d3dsLcAfAwOrUDTFLybm5sUHaIJ5epR5M1N5n6lrOZ8qaBYE/A/fO+tzCwLVD/nJUZs01JfqpYwBwu3bjqdWq/Xs9PTU+v1ejYcDgM/xOKj3f21D0VxkAzPQWrCv1ICAOVwOAyHkXVVDsvTI6SxmD1WUcQaQuQJkH7f+AopPpsgjKbC8MJaxF3XlC2AXdN/NDob6wy0CkCqwaDKhCCLLycEIPFk1BvQgI2+NyAJHwknWXR6WczYQK5KpRJAUrNgwBBAH5Asl8vBLddigUVliHnts9wsSdUUvqM4HN729nZo68RhemlXetGI8US6YanzpV4WkFStPhqNbDAYBHd7MpkEy3Nvby8EQCDRi7zAiKHgpIm9gCMR+H6/H559OBwGS9LMgqLw8+VzA7OUnqcR8hhZiowX/LD2DNBn9An6WCOaW+gpokUpP+vK5a1J3XcKvgBkLP80VrCgHLC6rz6y7T2YTdfHg7/KQxpWo9FIBVs0dZBzxQtOm4T52WwWgJW/zdsz0bExSHpAgUvpdDp2eHho9/f3YZEIZuhl5GbpQ6xa2WuJvK2QRTL5qGCj0QiX2+NeqxbHolROC43Hgdva2rJqtZq6jY9LzYq4ic/saR4kAIlrDW96eXlpvV7Pzs7O7OLiwgaDQeC+AEI2KyChVjAHTtNWfOWIz8XbxGKOrRGuJBeV8TlEh5UPxaPhfWLv7RUk78+r0WiEslFc4VVlilmTmjfJ82JxxRSqvpcHWrXe9GZB72oXxUfyHBo46na7wUrf3t4O6VrIqDigQVKUBPy+VtsUOTYCSRZYtTlNA46OjkJiLNqj3+8HgbG81MLRiVGgLFpTeJn84pKMTEqCtsUn3URLwgAegjR6kKmsoU6bWm1ffsizbDJUuWAhaRnf2dmZnZ2dpSzfi4uLkK5EWpOCACAIOHBotYrFW3S+9G3T3M8sxdztdm00GlmSJAHgdL+plQhAmlnKSvQe0d7enrVaLet2u3ZwcGBHR0dh3bjauNFoBNnWBRv+hrXXLABA8ubmJlodA8jp32gATa+eVYBUSzJvikcBnzVqNpu2v78fFK4+R6n02GkdAwqrUukQlDyBG7INsryZPEZulqTyHt1uN0SdqCSh7pfFhsNTzgjewd/sVgSX9ZxMHiTZeBrhxh3wrfK1E7TZIyfEQT48PAyHDYuSgxazEDYZaklqdUqv17OvX7/aly9fAlAiw3A4DPmcHECAB0DiSl/vvnGIYw0UsGp8CtGqa4M7CQepBxDOWDMl2F8a0NDEcr6qG44CqNfrofnI+/fv7d27d3ZwcGCdTic0IWm1Whtdb6wy8W8Fd543VkIIUGJ1MnzwQ61IrWpRyz7voTwxyoZ9RBOUarVq5XL5SVs6Te/Bo1QuXemTLFyIxTXWGblwkt6SJGEcgKEjzHw+D2kmEN8+RK9cxK9lSSITG03/zSG/u7uz0WgUUksuLy9tOBymysIUZLRG++DgwA4PD8Mhg0T3IJkHUHp3G5AcDAZ2enpqv/zyi52cnNhgMLDRaJTKQzOzcDDNLAASUVJ4Zw6dut36vU93UktyHTljVgoKCGXLZ2jjB1K1NMOCr8y5tyLh1ff39+34+Ng+fPhgh4eHKZebw75qMrmXScHfrx9leZrWowCnKUq6VvpSheYv1CsiaGOWtojxOrX8kMCYemXMBWBI3qfW0jNXKBNGXjyxjlwsSTNLuT76b7Q5oHJ+fh7qZflbzGp+T7PqX9KCVJmyeCpy6iaTSUhr8rfT6YVLsdy3VquVsiCL4Id0zhQkfa3zxcWFnZ2dWb/ft9FoFJ6dDQ2oqLegHB1K0NfbZh1k/bqJjAqUCmhaGkrOqlpMfKYvg2N4d1uVAnfA06kJcGTtNqVKYkCJJRwrOmAOkc2XXWoWRawFW4wfznvoGvlAkibxaxmuWfp6XN+LlqCv5kdmfXYeI9c8SRaUTjEMor4cJAUDz5t56/HXcLe9TDpIOvadgTSpVzWd2aPC0MCGckN+TvJ0tfnqgzeks2gAB4oAT4CUHi0rVK7Lt+HXdlsAjQfFWIrXOkN5SQ4egAZvpV1uSPXRz2ONYnyWurEa+MByRLnx3nlZZB4olfP3ecP8XOeVv/XWe6y89iUA0uyRZ+Vn/Pv29jbVhV+fR93r2C2IcJisnSq9rOdYV8ZC8iSTJAnutUbkfC4gQwUtKp9u1eGtCwhiBQFPKGunIzMLAKsbmg2rJXsxji5voOR7X2kDD0ReGpa8yufnQZPHVR51BTXJHPn1ax4HMxZsWVQDH3ODPUDq+/L8/j09N5g32Oi8e15NlUOMc9Of+9Qr3WdZ65A3WMaAEnxQr0P3iv6ez3pRt31RXmqeciwFkjzEcDhc+DsaIEADxC5H94m33ppUDUJLevo6klK0rMXFM3uweE4e/V2NLvqINnJpArLmgGrXE21EPJ1OU4EsZFrmsC0jk7ce/XWjVC34+4qVEIcv0ufXpGp9ZjZyVtli1mFfZ42Up9OcTw2YxdYllnLm/63vHVs7Pou5oa7a78dN9x2/rwE31i229zRv0Oxh7/mKKsoyCZTE9lwWYOZxjmJrxl7yV2Ooa61pUJwts8c+EKr0tTZ/Z+ex9VxMScZkynr4Z8enT58SM/tNvz59+vSq5HmNMr02eV6jTK9NHi9TbJSSZ2H0wQ3+/PmzNZvNQriLIkeSJDYajezjx49Bk/yW5TF7fTK9NnnMXp9Mr00es7hMsbEUSL6Nt/E23sa/6sg/g/RtvI238TZe0VgqcPNbNqv/VdyE37JMr00es9cn02uTx2x5d/stcPMbfr02mV6bPK9Rptcmj5cpNpayJJvNppmZffr0yVqtVur/EgntkyZBLbP2Jjw5ObHPnz/bzz//bKenpyGN4uDgwL799lv705/+ZN9//719+PAh1TRgmXtVFo3hcGjffPNNkEHl+fvf/75QHpKuufOFVCT+PRgMQnux8Xic6lDCy7eu0s4sVIlQybG/v58qddMaaW0+Oh6P7X//7/8dlcmvEbLQ0m04HNrp6an9/PPP9te//tX+/ve/28nJSai4IY3m9vbWyuWyNRoNOzg4sO+++87++Mc/2u9//3t79+6ddbvd1K2OmyTEL1qj2J5T2fTaBtJJqCb65Zdf7K9//av95S9/sf/5n/8J+y75R54ec0x3n263a8fHx/b+/Xs7OjqyTqcTyi/1crZlciNXkYk9p6k+FxcX9vnzZ/v73/9uv/zyi52enob6enp/Utrr28D5ck2aqlAGy1fKY7vdbmjCS3K8ylgul204HNp333239hr53Fyuyvjy5UtYo7/97W92fn5u0+nUyuWHju/adYmmIgcHB/bhwwc7OjoKlU/Uzq9yu2hsjWJjKZDkg3jYrEmgVyRtwzBh2QRZL19dw2dqH7x1mwd4GVaRhxw4ZLu6unoy8T5plzpVvtdmwz7hmQqOdrtt7Xbbms1mKPUDKDmksS4zz8nkAZ/cPn0/v5n4G+YZWXwiM4nkgH0etzsuu0Z+nfR6Va3W8F2r9W/5DP1MTZL3txVq708tAHhuPy4jkwfJUunhhkZVQL60kAuyzCzkgTLYe75ogffyMqGMdV9oFdEqe86vkTc4yAvW/2e/+eR3zbfVZ9G58Ffk6vMvm+D/3P/ndqWsAoteBTAYDKzX64VrPuklifZD83PlgSbp3t7eBpT3lQZFDA/cJBLrDYLIMRqNUpaXygRQVKvV1HP7jas10P4Ablqi6GXRjir+7hd9aYNXM0vdo8Ic1Gq1UE2kJZe+lrbodWKNeLbhcBisrYuLi9AXU+9EUbmwbNhr6gFoRxo6rL9Es5XYunnjgSqgvb29UP+sikAVmQcKX/niv+YlmwdIvUMJD7PX69nFxUWwin2zYzMLc85aaYI8yf2+gIE69rwqoDYGSb9h1ZQ+PT0Nd73Qp5A7X7hzZDqdht6G29vb4QoENjbWgWo0NE8Rw7sHXFPKotItR0GSphYsjDYV0DJGtJxqQq/VtXfjpn3+ngNIraAB5PUuZzNLKYrhcGi9Xi+AodlTkCyqgWtMNq0kGgwGdn5+bufn59br9UID4bOzs9D6DaAzswCKdMun47V/6RpWKpWFpXCbyqPvGfO4zB6tRLoOUbWl1SlmD0CppX/afCP2ubHPykMmQFLvfgILaLByfn4eOvjz7LqHAEbAVqvetGeoemu+58ImYyOQVA2kkzEajcIm5a4XGrtSDkeXHD1k9/f3QdtMp9PQ9krdjaIK8r0GV3nU0uXeGq49wIoEzAFCLETfodt3yvEukXZqWRdwYrIgj94nogCp96OoZQIQYanR5QmwYH3U1SnS0kI+/2z9fj/VH7PX64VDOBgMAs+qzTpQghwsLYHVRrqsC3+vQFmEMvCWnXelUaJ+jfUWRzNLKa9F58aDYx5gGVujwWBgJycn9vXr11RzZ7wzQFKNA2+NYpGyVlpaqWdKvZpN16gwSxLrkKAA9zhjddGxm9+n1hQt4Zv2cnMcrZSK2KBqeel9xgSisFKQhZpTbd0ECNJaS11p38dPiXHfiCDWjGDddUEeJc+1VtbX1Jul+Tpc2tFoFK5w4Hfok8ldRi8BIP4AYknSH/Pr169hz1GrrletlkqlJ2u8s7MTDqHWOeu1FHt7e6kbEYsc+v6sBR2YaMCrv6s15nSh8qARq1+Oudx5PLu6297a//LlSwhEQVddXV3ZfD5/wn+z1t7NVktSg6FekfE8m+zFXDlJLTIHWNSC1DZiLJZOonJh1WrV2u12KIL3XUDyPIDe1VDLS69bJaKN+wYRbZbukIMl2Wq1nvSM9M1Odf6UwI61slrVgs7iJH1LOt8ODUXEc5g9WI5wfzwX0UdVFnluzudk8xkIHEIsSW2ArOvkA1UKmB4oseq5KdNbkkWNmNLEauL/2StZvRe11Zsq3qJbpDFUQYMLZLtgbNABH0+MlnYqn661KgKv7D1G5LU+ufWTjG023FE9REyEToJqduWH9AJyPdxFBAdillesiwovns3MnrQN096DRKwBSbgkNgNWgDYZ9kGq2GtduWLzp+3//SHECimVSiHKP5vNrFKppOiG2NWxLxG40dQzvRpXA2rIs7u7m5JZFZGZpfYgAAnI1Gq1JweRZ+E98xoKjlA1tVot7BWeXzl6uvpsb28/4cdVMcca7RZp8XsvU8836UtmFhRALDNC/+0zYWJcat7KKzd3O8bnkb+m7qhyKj61QCfI82mxSSkCIHFJYy2y9AJ47q9h02kQhm7dsVxH5VbNLAATh5nN7ZWI/9my6xHbRGqtagCJ99WGrawJgKLWm7bqiq1T3iMmj2+tp+uj953HuqIDIqwfVpm2QyNQwm2MMXcur+HXhSbCrVbL5vN5iLLr76K8UGDqcppZiirQ/ecDbTHLch0Aja2PKmedN28lY0Co8cBLccJ7W1mt+fIauVmSZnGrRRvQchhxOxUo/KISEWYoYPqGsJtMSBb5jSmv/RMVJAE3kna5EZLcRhKUm81mAEltsqsbwWvGrA2h+WOryqhD+S0OIodNg0gkr2vCtndzYz0bY6Cc18jaY6ydWuTMFQEzlJPvUu5TR8iH5f23trbCRVYxlzsP+fxacz9Uq9UKfTqxvDSQw3PgcvpUJ0CIPepzCWOu96bWpVp0nlvVXGHuTkLWWq0WFJN6ldxzr3SB5klmFZ3kte9yi2577cFXs8dNWKvVQtI0lwApKXt//3CPrr/JzSwNkrjtuqibysHh9xak5s1pg1rkwsoAHAFGwJGkcBZSF5H5QW51f9RqVHCMke/PyaVrpe41lm+j0QgAgmvn10dpBn4Wa5LqPzNvaz/LSvEArYpA06u0KCFrkJiOAtja2rJ6vW7dbjfw40VakrjI9Xrd9vf3w5lQ+slzyhqQU+8NS1iT0P1efM7yWnYNn3N5wQE4XtaIq5bJKDB7uPJlPB6HtWIdNEijieTIFYvk/2rR7azJ8D/XFviU4FF2yFUPqjHMLHUBO2CibrACDe+x7kTELEk2mwIl0VE2YalUCs+ouY6AIprRV6P4DendZ/2/GFCua0kq+JJfR9kaoLe9vZ264EvTskjtUVBUOiWLM9bPz2vEFLK3Xpk/lBiWvrfkUY7KQ2s+3t3dnVUqFWu320/417z5caUAUFZ4LICk70SuLingoSCulpfm7/rcyWVomnXWx+OA3gXucaHValm1WrUkSWw2m1mpVAprorL4iih/bXHe1mTu7jZDrSHV6ER82+12aK8+m81sd3c3bErl7TQ6xp3KPojBQq/Dn/DVA7FGzjgYyr0BVLpoqq31K4vno6qxeUMO1YKxNKFVhud9OYRYwGYWXG3KJBUkR6NRAMvt7e0n7fV9RLEo7tgs7srFDrMCjpa3cuGUgqReMjWfz1OK0MysXq+HLAvlJPV58hre5YbiAFDwanQPxOZbU7k0sBjj/ZWDV0Wcp2zqwXBOYkqMSwST5KEaz1MjMT5d6ZQiihlyBUnPaSjhr1Ffaj31ci3Gzc1N0HBYOBDp3uLyr3WGd918vpl+r9UqHDJ1u7LcZX02b7lmBQH8+6wjq38eBXS4SKpOCAxw5W2j0QgVUPP53MbjcXDR/CHzIJU3cKw7fNDLB77gYc2eZmdoYE7vYMmylvMaft19UIK5V6/HW5Q8I3PAV6UlNAtFYwPIpAp5E8CJBVpQXGbpoFrsLKoiVuDWuSk6ap/bvdv6bz8pyiHohe9YkpjVRHZ5T6zM0WgU/u35Fl6butwxgNRyPZ+GpC6nr3/2v6v8kW4CtVyzwNb/+zkZY1a2biQFyHa7bWZme3t7QVlpswOi7tfX15nXsuoc6teihp8TL7M+j7rRrBPACGjwe7FgHfKzrkVFtXV4oPDPRq4q3aiwbtXL0fOhipE0LwV6DcJVq9UU5+7P1zrDe0LsQxSz2WOxAvON1wIPjpICL9RK9lx9nuDIWBsksaT4Pktjc7DUvSM1hnQFMwvggvszn8/t+vraRqNRirfU3Cqf0rEqUHpuC7JeM/t9BFdB2lcAVCqV8IzaIEE3hLcg9WuMrohp8mVkzLIgFTzMzHZ3d1PRQ6x9cgq3traepI7wnouojiKBJGaxxxQJewmLCWWs7rYCkU8f4tCqdRaznPMauhexEvUmSF7aaAWeVEEcHlODNgQ4dB2vrq5SNBHnE2WiVuwqssbWwwMka3B/fx+qbZSHJPF8OBzaZDIJdFCtVgs4oZ8V2wd5jY0syUVWpNYlEzX1qTK7u7sh8ZXNqSkxcJT0qKzX6zabzUJahhbwr8t/xfhI3wUG7aaWSZIkqU0MCNVqtVQSPbXBbDStr+Wzdf68m6PW0qoDCwD+NLZptdRODxU5kUmSpAjxWD5a3po7NrKsaj9nnvfV+mzmQikDZMSaYs3U2lzkYufJu2ogSZPataGK9milPFYj7upmK/espbEavOJ7rRJjzZVTXwYkdd51fXTfaOUQgEg5Ms+PwsJivrq6MrMHj4cAaizDoKi9mIu7HXPvFChZLN/xhgW4v78PwRtNc9Amvjs7O0HLkTZAs01NZ1jXmvRAuSi9hcNXLpft6uoqRSvUarXU3cikOvBsao0q/6OHN6Ydmetl14Sh76mkPEqGOWWoe8VmhSPGAlCQ/DXAMgaOnpsiwGRmAXCQz/OqrIdaZHd3D70DPG2ySdR30eC91LKlAxWWo7YepHEHpZfa2MLsMT+Sc3JzcxPWmvxXLRTY2dkJJYL0TIiVmi4zFq2PNgzBg/Eem89L5qyZ2RPPLpZZkffILXCjh5pJZzKSJHmSp4Xpj5WpNaZ6SJnIra2t0JB0b28v1SjDuxk8z3MjyyrQCdeNp+/L4YIr0RZO/quZhURltSTN0odWXVgPPpsAJe+jbhhrFCv5VGtZ50kBlk3vAwovCZT+8LHHqtVqOERmjyVvy4ysPVC0q+0/25dbqhVJR/x+v2/j8TjVE4H54Vxp53LWlXPn59DMAi+ZBwfrFbNPAMfAMHug27CK1aLHqmevxTpX8cLgYq3yUty5gqSa1rpxMd99SRRusi/CZ+I0FQfTfGdnJzRV8CkZq07MIrrAA0JW4q26dL5DCfW/0AO4UjxvqfRY0I8FrJ+7aXRRgcu7QPpZWrGiVpVuSAV1LHdtlR+bd3VF83BJPfDr4dNgFIdPyxJ9cwd9RmSHI9NuQfyO//2syP6mMsZcbpQtrjdBGzocsUb6DBrkwAtAbgUd9l/s2deRJUaLKB6o+4+lSo4u516DSsrbY30qT6tZF7Gzk8fIBSR1YmLJnkmSPCmHYlLm83mqflMPnWpTNuXe3l6qAQZazydqLwsuXpv6JFUlvOFoNHijml/vWKEze6lUCmk2/D6bWXkynjnvlAYPUMpT6lc2Ii4eDSI4lJDrUCfk8flSSy8j/84LKJFJuW+S4rvdrk0mEyuVStZoNIIliTLy86nUCT0PmZ/JZBL+1syeWJP6ddPsCj98VFs7E8F1611KVAhp1FcVMfOkslQqlXBeoV2gwZTCylKAq6wRn0Fmiy/r1TXhM80spbyV/qHjU7PZDEFRzY7xEW9dx3VG7u42mh2NgRtAqaFP+NS/USsToXC5WXTl+5TLUE3JAVj1ub2rRgs3vXMGghnNzWexmLqI8GJcUuTdaTauWpGqRGLuLM+86tp4gpuhBxJQ5BI3tVi0xyeuXIwT1uHBMk+LkgNGsKzT6djx8bGZmTUajRD0M3uaBaGfj8IjmkpUX8tfkSWWu6fuaB7WZCzCrZkd6kFp2g9AonsS2T1Imj0oaABRS2mpfCGxe5PkbF0nKog0MKTxCSrTsJbBCuYEq5PL7LjqhX2LotBYCAbApvst18CNWkIEaRQkvXvmSV1NR8DyApCYJPhI1aoaXFBu9LnJ8RYJ76/mPotCLTZWbblcTgVfNI1kMpkEK5ILnZRq4EVuIgupuYyLGhBsMmKBKu0Bqh3Y9UY+XG7WD+sN2WKW5CLuLg8KAcXWaDRS+2x/f/8JRRCzyFlj6oQvLi6Cm86cAPzeevQBg01oET+UltF8XQ+OOr+xOS+VSnZzc/PEYiOgs7W1FSxwburk1Wg0QiRcA13rrJHWoptZAMhms2mTySScrVarFZT0YDAInqbOM+eR/8P1JnaBZ5HH5YGMQixJ7eaMu61F9TFgVeJdrU1SAgBJvd+C77XFkj8QiyxKfj9JklQeIQOQ1G4+eisfHB4DQJ9MJoFvxX317nu1Wk2l2bCRNSUjD17SDz3k6s5Np1MbDofW7/ft7OzMzs/PbTgcBjebw6XuLRaAbyygn8M8xyiDTWTSvUavxe3t7WBFqgJTxennQhUhVApcH+vMPlJw9MEuteLyGppQnpUVwTxAGyAX/y6Xy0ERonTpao6C2d/ft8PDw9B8Btc7tq7LDubczMJZMLPUmkEbUMAwHA7DRXNY9Jq3iocHFUI+JXt0d3fXWq1WaJahc7XJ2uQKknzlwGtwhslWl5i/iUUoAVSzx8uo4P2UyJ5MJoFf8VzEMsStLqaWR7LJptNp6npXFoRIm2r1UumxrZgmxmIJ7+3tmdljDpv+nZLOPtVG5yoPV04tFL0JEpDs9Xp2fn5uo9EoBM2wBLACNPfOWxs+K8CDpK7Lum6cziPrr8nG6gZngbSuFXXSl5eX4ZBmAWvWa5MRmweV058rzooGK/Fu4Mt5Xg/svKfvp9DpdKxerwcO3pehrjN0rVWxsf/0DiiMCM4hBoc2T8ZoIsGcIBsB3clkEm0C/U9hSXqiXjdOjLcxe1p/qZOl6UD6HhocGQ6HoaWXd/chdJ+zJPnKs6g7qk10W61WiCTyuzs7O6mcLdXquN4qs9ZC+wPsI+o+DUifd5218QDpO3nrVRuDwSBUdJBRgMJSD0ErOJST9K4pz66Ka9PIYwxsCSapK+rBMcYxQrUMh8MnLdX0b4oeMdpKy3mVm0c2wFIrxWJBDD1nsUizD9rk0U0HC1wNI29AqfGka6rUlRomWiZM6ahWuik1l5cCy70LEC6C1kEDGjEuhUmrVqtP6kiVmzNLu7NofNwJNoZuBFzp54YuDiDHJqJ1WKfTCQDJzzWApDWlCuyUeJGSoXLH0os8r6rPmMfaaCUHgAhAYkFyI+R0Og0ZBXQu18Ol3I8+b6w6RQ+sehabygU48v7+s70i9POB0ooBvu4fb43GXpvI4PeF0hqtVsvu7++DElDljcuqebokl5vZk4wNBUQNmGDR5QWQKhtyeYMgtv4Kkgr60ATwj+CKr5DTJiR5BdVytSS9pcLCQWxrwwgABW2tqA8vpLxIqVQKADoej+38/DwQ05jjpKj4csVFgw0a40g1IodViNuJFUb9bKyNltlj2ZtaVJrupPlqnoPM4wDG1oY7tLk0q9/vpxKU+/2+DYfDkLqkHVs8r6odoZlLZPZWj9IvCl7ranqdF+YsZjlkzR/PqfSGKihPIQHIntbJCyjV0uIWyk6nEzh9Ah3sPeXluXgPXs8HMsghhbODt4MyKTJYiHxmllI8Zk+9TuWAeW1tPTQ9vry8TJVoQgWZPVInWalam4xc77jRSgECFuPxOByKRqPxpHMyVpZODiVUvvlukiShUYFyhvAQpVIpbAjchmUtSeUGkQvOBguQfLJ2ux3yBzVVRrP/Y7mUZpZydzT3soiNyWebWWptptOpXV5e2unpqX3+/NnOzs5SLvZoNApRfNKizNJXcOidKTFaBE+CuQMklT/2LvM6w4PtsgdCFeNz7690jFJEm6THxIam6+zt7aVq5wlI8QIstRqHYKCZhfQz9nC73Q59XPf391PXHestAHmBfmzEaCMPkl7xYD1jUff7/cAXD4fD1BkzS+ezxoJq6/CThdyWiHXHIsKP6BWfuN8sDBaGJnIrN4mAmNNYq2jUJEnCNbTdbjdldi87VNsBCEywEt3klF1eXqZ4VPiTra2tVAI8f8/B8pH8LA2e10aNWZKXl5d2cXFhX79+DXeja8qPVpywEfXZ9bmxDrRCQlNWeC8CDToXeQU9FPQWvZ8qdE0vUavX7xlv5Xl6JA9g4e+YE64x2NraCo0dtKZZq7qGw6H1ej0rlR4DhWYPAU8sUM4FNwPoNSqq6IoARy+nGiR+D+icst9IFxoMBkHJqpttZgvXdNOE/9zu3fZ8JEKQMrO9vZ2qaaZNExtPN3cWSJAKgdvNK0ke8uO63W6o6V617tRrGQVK/o3rwsZSq1MHYOStJDSj1rVr0niRGzTGSV5eXlqv1wsuN3yPB8WsIJNa9MoH6tprdQVdp5Vu2BQgGcybrqHno7zHo8na2one0wRm6Qizn4O83VL2nbreXGnry/IoVMDL0vQlQJLOW+pqkzCu2QlFWpFeRjOLgqX/Pc4dzzifz4MnRJDOu9zIrlH9TTj+wgI3ugnJRdN0Ew4Mhekgvd4r41siaSqQmaXcYE0ViLVRWmZkud1wcixmqVQKQaTZbGbVatWur69TixkDSI1aLsqHfAmg1ARyXtPpNFh+WM6+AYYPzLFu1D4zL6QXsf5YRGYWXPZNGyisIrPfm5owrk1sfRfyLEVXBEh6QFbrFcMjVgJ6c3OTuhCLJhXMuXo/miyuhQB5Uwfryq/A6PfG/f1DM95YipZfW335Cr9fzd3WgeulVQNouX6/b5VKJWgEeC24FHKgBoNBuGRegyK8H5uegBBRMJ/ku8mIuQQAgie6lfiPvU8MJLNcbT+Xm27cGG+sVpRafdpbslwupzaddsfWXDYftPEgifXYaDTMzEJ0XO8zyXPE+Ckt32M/qtLWgNVsNgsKGjn1gCkormudZA19D1+ZhhKiPFSDSJrio41l4Ojh9LAiNWizTlXNJkPXPOYmK8UFWN7d3UWvw+U9VPFpGaenEZ6jY2KjEJDUyYY/TJLEhsNhyF8cj8eB/GdhS6VSqKrp9/t2cXFho9EouOhMApyRpv6oxeCjZ8uO2O96FzMW2eRvfYa/vkcMID3QZm1S/1ybgIoHSwUO5jNJklAN5MGRDuy4d+pms9basZ1UFHLZzCwVXNM523SoMvAKQWkaFLbPDz07O7OLi4tQaYSbxvA0UNEBDrPHlBhAg0PuZdNsEVLqzCwAJIEbXlTWeJAs2tVmxLzO2NUnzIHSUljSzAWGC3uU4NZ0Ol3IHy+77wq7CEw3lx4+ItK9Xi8EZpQ4xl2j+gOLkoPK5BW5kDFuhK8Kih5sNHjhwdqDpAZufEKtPkceVqQ+r36vz++/hxxnA2u2QrlcDm5MzMVRCxXOiNZYOzs71mw2U0nPeYCkV4zerVbgpoltr9cLe0yb2ZJCQ3cps2yAfImR5QX4ZrU+QFYqlVIWJEEbLEkFyay8xaJkyfJq/J5QPlzpLv9esbZyNOjwL5+G9NzI3ZLUfDJNDYErJM9R+ThNYOV3dTOTh6g5iMrXZOUZ8nvLjJi1FntluXBaW6ug6efGu9rekvSHL7aYyy7wMn/rP4v/h0f09d2j0chKpVIqgqqVDrHKj+3t7dDcgE5Kvioij+HXyVuPuNdE9qlRx82m+zfpXGaWubcWud15yRHj2XQtNCUIA6JUKgWKgNQhuEgN2mh+5HOeTBHD8+O+P6xZOmdSI9aa9sPPPY0Clcf5wijRIM6LWpKec4uBgLZm1+CIj14p+Y/lwoEze0wR8InN3iJdV9N7iyTGmcTA0V/34NOPUBw+cOMPYBHDu/585aXcl/6NBmm0zrtUKoUUJ3VxtFhANzuHll6CeXS9zpIvy+LSS7SgcjSZngojWqyVSqUoNRLjjosASL/HVA5N/9GAJYpNgQF3G/ea5hWa+vNS6T8xg8O7yXrWNQCqKWWxJiN6FjXzRcsYvZf34u62WnY02GRREEwJcSYCzcfigvAKOmhIdVn1elqvHTdtkbSI2/JBjNilYdqtxUcsPcnuAfI5K3LRzxetjQ7lVhW0qbDRbjJaXjqdTkNAAG6ZCLEGfrT8ErDJOzjjh18zDRiSS0jyv1YYUZZJgJASOPYYHaB80nURLrc+uw+sqZLyVi/gMp/PQ0S7VCqFskZca3jgWAnmS9MHeqaUygEjzCz1bJQjoyQ0B1c5caWIvHenv/eilqRakQAkpXwkdauVcn9//8SsJkLq+QYEw9oslUqpGlTSGmj11O12Q7PQTWuDszasEsP+pRwRi+OJfx+NVLlVdn5fnyf2/XNro++l1rvW9PLSe1JYGzYecgOi8JEaTdQqIwUardCJNV7Oa/g103xQbhfk+4uLi/AzADJJkpTlSHXVwcGBdTqd0GOxCPfUPzvzTWoSYKgBJzrGa14wubwkYVNl02g0Uo2viyo/XFZW9VDIexwMBoHqUG8TfnwymViv1wuVYZw1M3sClovoMdaOvfrc2PhKWYTBEqxWqwG4IPcJ0PBggA2HCmvD5zJxYFlQrXghSkclweHhoR0fH1u32w15VOtuZG9BalsxrdmmNJEXeXa6eGpFeWtYieSs9J/nvl9mjfR7pUSwxnlhQWrpJ1Y9EeH5/OGqX0CSuVGLHxcOS58aZKKq2s087wPKgfA16qenp8Fy5EASFMQKwwPynbqPjo7s+PjYOp1Oam8VAfJqnRNMwtrVTvGAJopNzwcvvKz9/f0n1TUvHc1WGRUk6eVJUcPl5WXILDB73L94onQl5zpd+mTGqBsN/niwNLMUHbZo5GJJKreIFYkFyUHETYN0pnRPo8J6iNXiUfeaRT8+PrbDw8NQh0p3ZXri+ZriVUeMM1GApJEAL5Kx1UqOWYYq33P5kVn/Xsfd9gCpLbhqtVoqR3I+n4dAGXmPgCYRaiXUsX6gTnDfARkUGRZZXta+nx+1xrRG/ezszL58+WKnp6c2GAxSSg059ZmPjo7C3up2u+HZ8VRizaPzeH4OMK7lYDCwk5MTOzk5CZ3i9foSSvIIjGGkcAUDNBQvvUbkpaLZKl/M+ODKkPPzc/vy5Yv1ej0bj8fBlWZonIJGHvSU5P91eM9MP5cAMUbBcyO3wI23JFW7UdqH1hgMBin3WwGSNCByoXgP3hfN/vHjR3v//n04eGh++KM8GoZ6S1I7sOhtdXqJukZteX4Pkh4gvcvzXFR6HX4vBpTQFrTc4tk18OIjidfX16l51Si+KiW6JXU6nQAw1AwXAZI8Y+wAXlxcBLDhGlbkJTGbV71et4ODA/v48aMdHR2FZ1ZuMutOn02em688+3Q6tcFgYKenp/bzzz/b+fl5ioP0lUylUslarVbgUQ8PDwNFAB/5aySOZ8mLtU8WS7/ft9PTUzs5OQnWvXLbGqeAilDu0p8xr8A8UJq9sCVp9rSdPv8Hal9dXVm/30/lZSl/oPmPAAw5aspt4bodHBzY8fGxHR0dhbwv7XK87gFcFIHzLeD8XTvKy2FVqRugcxVLcM16ntjzrTM0mq2RW20mwiHa2tpK5bYCmFlReFV6rB2KDUum3W6H4FpRpXDeIiNtSXnIyWSSapTMnEIRwENC3wCOsasq8gR4z9VpfT1WsFagmVnoU8oVFjs7O6k6bYI1vpH1S7vZMTlZI9xouuIPBoMUlcDfAawoErOHfUdxAyNm4XvO0uyRd39uFBbdVg6ANu2Y+7FUEyVqcX/M0knYWKVEzvUA+o7SefBdCt6xmtBY6g/AUiqVMpPKma/Yy39+7PtVR+xz1KJUwIwBF5vKJ/HzPfPtLWjAl7X3ZWV5HNQsq1sDblp+SBs4OFRt6swzw6XimRCwKbrOeVHUl4g2gQ1VxKTY4b0RtNFmur9WJPs5WT2dRVBKDY4sblHLFlf53FUNjdzcbQU5D2qawxircVZ05/38+/scTN/8VQ9fHhtBF8aDpfJwml6g1QKLsvpj7sAyz5LH8FHuZaxa3VQeJDXAppo5ljNbVE6eXytvlaknAHWiJW+qpHXvatu+ZcpH8xoK8poMz0sTxwmGIrPnhX/tSHaWfP5MaaoZoBkDSW84oOB0LDI61vHElgJJ3nQ4HC78Ha8FiQj7RONYUqe62jqBZpZ6X79xptNpiG7Fyqt45phVpvLwM01OVXdNUzGUOKc+WRts8H7lcjlVCeCbKpTL5RCdWxS88Ys6Go2WkomhLhxy0NvT53bG1oZDq0MtYjghH8jxLe1ms1mKK6aBBLffLSuPnyNeajnqmsHh+WRkvJasqhb2l5mFTI1lPZRl9x0/x5UkMEEg0K+T71Dky0F9WV6SJLm42avIExu6P3wHKu0zq+3qkDMGbKqgWU9dOy530zJEPpu5GI/HT2TKevhnx6dPnxIz+02/Pn369KrkeY0yvTZ5XqNMr00eL1NslJJnYfTBEvn8+bM1m81f3VRfdSRJYqPRyD5+/Bisn9+yPGavT6bXJo/Z65PptcljFpcpNpYCybfxNt7G2/hXHcV0VHgbb+NtvI1XMt5A8m28jbfxNhaMN5B8G2/jbbyNBeMNJN/G23gbb2PBeAPJt/E23sbbWDDeQPJtvI238TYWjDeQfBtv4228jQXj/wNzKaVxQqISvQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "\n", + "# generate new images by sampling the latent space\n", + "z = np.random.normal(scale=1.5, size=(36, model.latent_size))\n", + "logits = model.decoder(z).reshape(-1, 8, 8)\n", + "images_gen = nnx.sigmoid(logits)\n", + "\n", + "fig, ax = plt.subplots(6, 6, figsize=(4, 4),\n", + " subplot_kw={'xticks':[], 'yticks':[]},\n", + " gridspec_kw=dict(hspace=0.1, wspace=0.1))\n", + "for i in range(36):\n", + " ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4oCtm6V3TsQJ" + }, + "source": [ + "Another possibility here is to use our latent model to interpolate between two entries in the training set through the latent model space.\n", + "Here's an interpolation between a digit 9 and a digit 3:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "8iJ9f60VUBwY", + "outputId": "ce34c9a4-3d47-43b4-ec1a-c5d47db0b394" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAABMCAYAAAD9eAbFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABJWUlEQVR4nO2d63IbSZKsAyAJ3qXenp7uNmm6d/f932pkpp0dXUgCBEGQIHB+6HxFL2dkASSqIG2z0gwGXlHleYnw8IjMGqxWq1X0rW9961vf+ta3vvXt1bTh976BvvWtb33rW9/61re+7bb1BLBvfetb3/rWt7717ZW1ngD2rW9961vf+ta3vr2y1hPAvvWtb33rW9/61rdX1noC2Le+9a1vfetb3/r2ylpPAPvWt771rW9961vfXlnrCWDf+ta3vvWtb33r2ytr+5v80XK5jI8fP8b5+XkMBoOu7+m7tNVqFZPJJN69exfD4SMv/qtjf624I3rsjv214o54vdhfK+6IHvtfGftrxR1Rxl7647Xtw4cPq4h4Fa8PHz68SuyvFXeP/cOrx/2asb9W3D3214H9teLOsGdtIwXw/Pw8IiI+fPgQb968afzb1WpVvZbLZTw8PMT9/X3c3d3Fzc1NjMfj+Pz5c3z8+DH++c9/xj//+c/417/+FRcXFzGdTuP29jYeHh5iuVzGYDCI4XAY+/v7cXR0FCcnJ/HTTz/F3/72t+r19u3bePPmTZydncX5+Xmcn59X3x8dHcVoNIr9/f3Y29uLwWBQZP3j8Tj++OOPCus22MG9WCzi/v4+7u/vK+xfvnyJ//mf/4kPHz7Ex48fn2C/v7+P5XJZ4R8MBjEajeL4+DhOT0/jzZs38dNPP1U4ef/pp5/i559/jp9//jnevn0b5+fnFf69vb0YDocp/m1xK+blcllhXiwWcXt7G7e3tzEej+Pr16/xv//7v/Hx48f4+PFjfP78Ob5+/Ro3Nzcxm83i7u4u7u/vq7GPiCdjTx+cnp7G+fl5/PTTT/HLL7/Eb7/9Fr/99lv8/e9/j7dv38bp6WkcHR09Gfc2ses897Gez+dxe3sbs9ksJpNJXF5exqdPn+Jf//pXfPr0KT5//hxXV1dxfX0ds9ksbm9vY7FYVH0Afu754OAgRqNRNQ/Ozs7ib3/7W/z222/x/v37+Mc//hHv3r2LX375Jc7Pz+P4+Lga9xL+DPs63Kv//9AgH+u7u7tqrGezWdzc3MT19XVcXV3Fly9f4vPnz/Hp06e4vLyMq6uruLm5qf52Pp/H3d1d3N3dxWKxiOVyWV0nImJvb6+aA2dnZ/Hzzz/H77//Hn/++Wf8+eef8f79+9q4j0ajODg4qMbe5/1kMnnRmJdsGnin02n1fn19HePxOC4uLuLLly/x5cuXarxZ57e3txVu1gv9ulqtqnve29uL0WgUZ2dn8R//8R/x22+/xT/+8Y/4448/4t27dxV27N3h4WHVBz7+k8kk/vzzz2ePOXP84eGhGmvGeDKZxHQ6fYL78+fP1XhPJpO4ubmpjTVrhfVO3zJe+/v7MRqNKpv/97//Pd6/fx/v37+v5jpjjm04OjqKg4ODavz5rMlkEv/5n//5ojHXuc66Bjcvvr+6uoqLi4v4+vVrfP36tbbG5/N5tcYXi0WF++Hh4cl8Z8xPT0/j7du38csvv8Tvv/8e79+/j19//bXCfn5+HicnJzX87u8mk0n813/9V2vYp9NpjMfjGv6rq6vaWv/y5UtcXl7GZDKp5rmONbi9qa3Tcf/999/j3bt38euvv1bzHezHx8dxcnJSzXuwTyaT+O///u8Xr3NsOvOd+Q3m8Xhc+bUvX77Ep0+f4t///ndcXFzEeDyucLNuFK+Ot9t4xvzXX3+txvz333+vxvzs7KzCfXx8XMOt8z0b86xtRAAxnm/evFlLBiKiNsg4ib29vVgul3F3d1fdNK/Dw8Nq4dJRe3t71bVZ0PwN/6dfHx4eVkQBksDE2IQAOtaXYPdFwwSaz+c17Lz0/ufzeezvfxuOxWJRSbeQIBwhBoK+4DMODw8rggQRViJQIoDb4M7GG8N+d3cXg8EgVqtVDbOO3eHhYc35cV0lgOrM9R0H4eN+fn4eZ2dnVf82EcBtsTshgPxpes3H3Oe9OnxI/2KxqH7GeA+Hw+prx69zHsO4jgBm2Nfhzkgv85v5xe/u7++LmHGAOD/vA34O/uFwWOHHDkCGT05OquDv9PS0GvcSAfT352JX4nt7e1sba7V3illtFWOrTgCsrHt+z31n2I+Ojmr41d7p3M/Gf9Mxzwj/3d1d7O/vV2PFWPI7tUn6tTt8PpvP0bmluMGuuB27EiAd/9LYb4rd5zprSuc4hFbxql1X7KwP7gnCq/NBcasAwufi+D0gPjk5qfBn/u652H1OQzJWq1Ut8JvP50/sOv3PS/Ezl8DNdSKittZ97B370dFRhbuE/bm4we7EF7+sYw5mnW/6Yr1yTfiP4qVlmHXO4+PA7kFPSeTaJMW9EQF8TnPyxwKZzWaV2jGfz6uIwB0dxI/J4Z3DNdTYahSNQdIFzOd1lfPPVE91jqgdRIJgj4gni5z7VPzqyLgek1Exa2Tp+Ltq63Az3uBWoueTnZ8zfkp+9FoaPTPfFHPXTfs2U4XAy0uVvWy+Y1hXq1Xc39/XjKM7b1VL/D66xuz4fR06ZtQOn8vMdR0/fq8kiFZat7sYa66jmDPFWzG7qoeipeSdtQ4GcOu8joiaM/Mx8K/1va2+cexZdsMxq21TzGqbIh5tuwYDOu8zfKVXVy3DD07w81Jsbtc3we1rvukeSuu/7XHXr7nfDLsH8dh08On81iAoI/5+H7rmssCxLewZ5szOqYLtWQvWuvpyn9vZvM3GW22B9r37+G1aawTQIyZ1CJCf6XRapQxIe0U8qlyj0ag20SKiRhAghyxCSIarYG130qb4lQgo8SNdcn19XaW/WASoXKPRqOoLnTDqNCFCrrS58VWn2iXejAAx3tPptMKuY35/f1/DNRqNqohwb2+vRuBVAVJ1iYWg5HdXJFAXsRoGTRd4SvDm5qZKB0RELbpzh7G3t5cSIyUB4Mych95jV5jdIOlcBy/znHmp461qkAZ3zCG9BnMlu5eu8Dp2f4G9lP6mpEHXuM51Te9i+1DPmEuqlDj+LojeOtwZ+WOtg1/TXhGP6fvDw8PaetbAR9exkhud89yPk28nAW31R9OYc6/4Ng1yFTt2/fDwMCIeiYFi5vM0mKCfSrg1pagEoU3s2gelcdexB7vOdyVBrG3FqiTOVX+3dW5zvA+UgLXRD03Ej/HGvqmwoeSPOQwuTfU7uYt4JP0639y/aOakrXFvhQBmHTafzysHSP6c19XVVUyn05jP57FarZ4QAZ0cLotGRO3zVeZHJtZB2QUh0EHTe9P6GK2RmU6ncXd3F6vVqiK+Dw8PlVHUSa2RJMqIphrpu1KtRRfYm8gfZHc8HtdqougLcA+Hw5pjODg4qO5dI0ScBotDVWUlv7sg/JlxzMacuiAd+5ubm7i/v4+IqAU7jC39VyJBEXX1z+tp3AF2QQI3GXPGWUngw8NDZRxxiMz1+XxekWFVtHGmaiAdG/eyKd7nZgCcZKoT8hpAXdtgv7u7q2yYjjcpLSV9qijgTJkvft9N5KdrEqCOUAM8rW9k/rLGI6KyU6qegJMXP3O7lRFRV0fasnVNY+7EV9c4xJ8Al7mudZyqHNKX2XtE1EiQ+hYlAE4i25oD2VrPsOsL/4M/xxdDeH28HJNnBL3/XXFUW9GGr1s35zWw13En2FP/PRqNqmyGj4nOX89YapZHswuZvytl+55r41pVAOkwDOP19XVVCDwej6vJghKoBPDo6KiYDoiIGgligwGdxd8fHBxUtQEuz760g9Zh9oUCdor/KYYFN86CukCMA7h8kWhURA3F/f195RAhgcfHx7XNFB4VtYlZv9a0L7gphsYhqhrImBMlDQaDyjiqMctILEQC0qAR2K5UXzcQip3gRrFDgrjXiKjS3RgLdfz0A0ZFDV1EVNdUR9Cm8y/hdaPoiicF0Ypbnf1gMIjDw8PKOShe3QTDz1AVmC/Z/TjervA76fDicNY6Gx40Hfzw8FARob29vVrdKzj1a80eYA9dFfH3rkhgNuYZ6dfNDoyf2nXGu+TIVVXa29urqcY+Dk76ulACfcxLGQ7wU/QP+Y34tsaxbRo06JrWdc98UhLouLP+035wm/+SgKe03jXgUexZdsfHvUTcPXPi2F0BUxI0n8/j6Ojoid3fpmW2HfsEbuY7IhZClgY8nsXzz+eewcE88CyXYmV+eS31tuR3awLoE0Ydou6Subi4qHaDaV3YcrmsFgu1UC71e1ssFlW0pQbn4OAgTk5OaqmILJpsuxbQHQNESHfDTSaT2q5HvWcmjC8Qf9HPGjmz4CBZOJQ2o2JvTSRoMpnE169fn+DWe4b4UsCqpM/rGVUdWK1W1XtWX+iLrm3i20SEsjEnMtYoLyKqVNDh4eETjKr+sPAhjp4u8jRI2+OdGa9MASPQY/cfRlEjW1QwlH4CmSzCvb29faJ4uyLSVYDjePXl2CG/V1dXcXl5Wdk4TQVqHdzR0dGTz1Ebxpxmk0VExN3dXXpPzyE8TRugNsFdIv0EPMx1H/PRaPREzdG5jq1iDd/c3MTe3l7c3NzU7j2irpx4eqxJ/XqOrV+H3Qkgwa6uc66vZT3Z5+jncUKG3qsq29pvrhY32fqX+rlNsEN8x+Nx5Xe4BwQJtx/q19XeZdhZJ/q3/D3kr63Av2ncuS52bjwex+XlZRXoMn5KAL20h3f9XGyn4sZmYPNYb04ACSJVAHhpaz0FnDnEf//73/Hly5cqUmLRRtSLJvWzNA2cLXp2mWrUdXh4GOfn57Xao64UoWzCOPm9uLiIT58+xZcvX6ot8WogtRaGAlmP+DxyBBdEcLlcxv7+fpyenlbEVwmHvrYlvtmC1vEgHQQJUtyMN3VPEF/9LE8H6IJfLpc1bMPhcCMFsG3C74TA5ztE6MuXL7UyB039ofiqUXAjp7stuVZEpCS5K6LveEvEVw3j169f4/r6uqYGsItN63jBonghAuyeo18yA72r1oRdj4dA+QV7RFQknzFXYqdKkBLf2WxW22nK+K7bFNG2nSvhVkVEjz8BNwE9OzIV92q1qq1xVTwp5+FvNPh38sC88J+1FRCsw+7HgkB+I+r1vXoKQUSkxB9SBXYlPxFPia8SwSwF6nPgOcRfsZfIf6YA4tciojbmXr+qeCD+t7e3tdp+xp+mapgGS+tKf9okv3oEDkfAXF5e1rgMPEYxM/Zu65nzrHN+rveQBciaFfFM30uxb0UA9cIuF9NhGEYIoKc+mSxOCNQQ6ALUa0AEyL+fnp5WaUafHF01NxRaLwAJJDWmRoKUkG7h1s9TJ8+EWS6XVdSsi2Q0GtV2GJcmRlt43SCrQSc6ZMxRRCK+GQfOLtLUN5/rahj9xTgrMd7b26upf23VgpQw+7unCbQ2RlURyMDR0VFERO14CN3VjoFVcsEOMuaUXttTYV3PcTeMmsJwEjidTuP+/r5SvVCCiI6Z607y+Z2WOjDflQxkTl+bK17b1P75NcGeqWG8SOfoXNfjGsDuBFizAZ4e8nvbBQku4cYp+hmI9/f3tU1bHNtCH0RE6thQgDTtxXWyDTDqczzYK82HbbA7ft0EobV/Ed/WOdipgdOABoxOgFB96BdIRUZ+XfFviwBlAb77Nl3z4Ff7jtJPAKCbF+kD5gBzPiJqGQAlTE6Ks6xHW75uHfnVzW4EAih/elQPxN9P8IDTkMHMxtz73lXjTPHeZtw72wTixhG5eD6fV85gOBxWhlGNI5/ZlAJUFTDim/qnGyG0o/i8tlqmSPgi0SgR7Hd3dxX5Y+JrtOiKELjVIUZEhT3i24BrrRX906VSkhlHVTF88wdqlh5MjZFggahhWCwWNdVAU4BEmqoGZuSvSxLoxtFJIJsAptNphQUSoDsj/Tw1NsIoQdJzHGl6bb2nLluJBGZpChR4XdN6dp+SIHBrEESgAyFqMmpO9LpSe/Wl2NUpokg/PDxUm7kgA6qC6pjf39/XlNHFYlHNDd0h7ccB+bufefdSvE24MyXK69G4J9Y4Z5Ux1z2o1XWO/XDlaN34a1+8tA8yu1HCr/YZmw8R0FIHAgBd5/f3387HxBeg/N3d3dXEANZ20/wuzYU2Mj1N9k7nP98TuFGHr8GPq/7L5bLCj30/OjqK2WxWWzdtl2o9F3/TvNfUKwSQNc446lmUKmDAWxaLRU31Z/1k99JVezEBLC0MT21otEAtmEZFGAp9coMTKiJBHIMTIjYGeIrUidBq1X460BdFltbCKS4Wi9oBkUQNKCOaBgCHkgEiioh6nYRHB12of02kl0Wt0YweEcB4YAhJk0CEldyxI1bTQRpFgplzxTQa7FodacLvNTpe3Bvx6Bg5zNbJAA4EjOokFL/fi96ftjbJgF/T17waRi0/0MNc9ZBiVTe1Tuzh4aFGCJsOsqZPSs5iGyfifelqhCswvua0tEUPMEYNZb3oDncIP//HeslIgKaWMpLYlgMtkUHtI53fekg5REDnuqpotMViUXuCi5JfJbh6IgR/l+HXfnop5qbv9fNx/mAAtx7OrEEuNlKDOsWvx2Hx+WoDFLuek9pGAFDC6ph5cW18lx9WrAcz89ka9LEOCPaygE/7AOx6lFBp7NtsOtba//QVew84qFrT/2q7dc2D2/82ojzXM7zbrvXWNoFoOsoVrIygRNQJkKpDEVFzrDBmDORsNqtdm+t6XUgX7HldZOzqgEbHmtLQqIFJowSQYyKUCJXq5jwVmBnptvFnEaFL9EoIVBL3MVcHqAsBopcZN7+HLvH6e5aKye6FhvFACcrUIFc5SaHqCe9aJ5q1bdKeTdj52kmQry8lJBH1eig9zR71m6BG57Ibdl6qiCj50Z+1TX4y7NrcMWCsV6tVZdiVBJ2entawazCnSrHjd8zqAPXv3UFwj8/F2tSUcOmh1kdHR7FcLmtPaIEIqOqrWQ2uR4mD4tHreP/qdfWJN03BwjbNAzj3W6x7iIA+mccVb4JDPpNx15diV9z+hBF9+oamGrdZB/73GflhjI+Pj2tzV59EBQHk/l3Uub29rfk1J0yaVtX+ps/1kYdKjtoigT7Pub7iRqgYjUa1R/LpvPTglqD+4eGh2tRBvyJo0WdqO9V++tN+tiH/LyKAJSUgIwNKUFAG1Ijpo7yUAOpkoSPpNC2m1uu2XROwrg82IUFNyoA+1khrwhS3pgh8kisB3iRS3warXq+Eu7QpIVtISgZUHleCS3pM1RC/p+w+227rFDDHq/ObtIiSoCYiRD9mER+fva1Rfwn+0tyKiBoZ4Fib1WpVYWV988IhYvDAXIrqm14lsuS421KD9PN0Xuuj/SK+laTwmLLT09PqawgiCo+m0Hys6Vucr+J1MuTkMeuHtpyiBnE4fdYwuPVFdoc1DSHwQEcVT5wk9nI4HNYIkD9iUB99t04Zfi5mJ/lKcrl/xImTk5Pa8+lRwRhzgvqIR0XISRyfx9irzdS1pM+CVfxtYXaVE5/FI/ggb7zznHIeTUiZl9p3yJ/6NVV+eXqIqsmZDfHnP2eq8bbYHTfkFsWWMeJxnP5IOid/KN/D4bAqBVDyq5lB9ZHYD/rVnwHsgc9z8LeyCaSJDKhiQAer+sWLwfSoWAkghEBlZf62pEpwn220EhFoIn80Nd7qMNRJKglwJUzrJD2y8HvqomWqp2PW30VEbUFqBKu4Ib6ugqEmaE2QL25XqLpsmQKWkXtXRlar1RPip4YrIircSoScAK0zbG2nPRS349emToqx5W9cDcJo4xB13LMo1g3aJoSwrVRYU18o+dMnEJ2enlbkBoeB0VY1iBooSEFG2PxnSgJdBVMH2lVazBVIxUzAEvGN+J6ensabN2+qZ5KjBIGbuU45kKuYikE3FqjdVCXIn//ra+WlAVNGBlSw4FmyHO8DAQT76elpJWp43Rv+TYmbpzgHg0FF7lhLzCV/BrDvvG2L/HqQA8ljHFHDUAD1udSKnXUOAYp4fKCDqmAa8Cjx1Wed87z3pucfvwS7rz+9h5OTkzg/P6/4yM3NTSVmKTlVn6YBHqVR8BZqPtWnM+89uDo7O6ueXcxzz3nut6uAz23PJoAlNcjJT0b83JhhvHCIGQFkUIkudaI7GchUr65SoE3qZ6YIMelV0lYypEaSCIjrQXw3jXB2qXxmqqumAsHtKQxNgw6HjwfeQoBd2WjLkW2LPRt7neOeromIWqCjZBCCCPH3dG8pnZcRH1pGmtrA7fj5bAyWGi2ur4bb1zj95vVQpaYERPvGXyViuA3u0r14akiDVY6kUgJIkEeA5OQvu54GjlxT1T+1J5kC2Mb4OwnSTQ6np6dVkBPxlACqEsQa1/VdGjMlg2DWGjslf54KfAl2+t/Hwe9FSRDjx4Y8UsA4ac1waMrfVX6va2PM9ZrMMU0xKwFSAtzG+tc+5B65h/Pz8woPgYymgLlP3djo5A9CqPNY15QGVxAsVcKUACqJaksBVJUdAshO3YODgzg7O6uIrwpY6tMI8HTDE+Ovcxb7yfpQmwLxBPObN29q6m9G+p+Df+sUcKYMleqUMH7OsLUTlAACxA0Gk4j3koPqUiEqEQFXSdRhqRH1KJ4JrPWRWjdIf5UmuCuOXbTSWKv654YDkqNKhX/N3y+Xy1TBKEX134MMlpQwHVtwMx6apuI9m+saIGVtHenL/j77uq3mpABjrQQQw6jOGoeI8qetRIJ8zSshbJP0rcObzW99BrkSQC0MV2VANzZxn25DlstlzWYoKXDi4ApCm+pfRspcoYiIqpYX5whBobxlMBhUWaEsgNfrOdFXm5Glgb0OLAuaXopb70kJ/8nJSWWbKWWAACrhR+XnKB9V+LPghesQGCrp1bSgEo6mOsDnYnYbrkRIU8CMox4B40oYadCHh4eq7lHLm/SlfpH1oRuosrrSJvLbBvF15ZMxPDg4qMpcuE9XIyOiUn3BDhEulW6AQYUC5pIqf2pPPOh5bmv1GBgnYjTvVDXo+hkQAG8KbJ1SoJ+3q+bEiLZOxfF7dcK6aTq7a6KbfW72c184GI+IxwOgSzUqpYCC3+kmgKx14fA37Ut1VhqxKwH0QmfuV3FusqnECZCuh22j301xl5wDRot70jqtUoTO2GY1vNn1IEJOjjLy1yUJzHA7AVRDzfhHRC1Q88wBX/s13WGW1KNtnSCOiq9pOr+dCEVEpXaTBnUH7dmBiMexz7JFpEo9UM4CqVLqe9vxbyK/R0dHVQ2YEgNIkCo0jHdp/oNb5ze74Z3weurb11YbCphn7RQ7ZAhfDXYIoI47Y8g+gKaAT1VW7KRj1s/XddVG/Z/OyZJt09Ms9Gsl6FrLr3W+lPl4UK6BBUGfb/pw3Ep6fU49F39rzwLW5kZYXwy4Hh3COXbIphH1TSC6i5gXqcJdEr1Sc6xO/JjYGMCI+uPjmBwuGyt2Pdgag7IJ9q77J8OuC4dISVP5pPf1iJTSWGfH+nTp4J/bdLzdYOPUMGjMf1VNWQM+1jreGRnSrzPD0hVWn+Oanjs8PKwZefpBU7yonRAdx+zYtY9d5cpIYJt94J+XzXFVPrUWjp9nG5h8DejxSX5+Kf2W1Yg1KUltKGBNuBVzxOOGHnWGGvCozcqOTdK5ro5RbUmmmmTkd9vm/ZeRX3b/ugKopxuQChwMBlU/ud3XjXMRjyRRCW6269f7Y1vymxE/xw4+bBdElfv2Ew4Yd9a8jn22zhnziMeNEKr0quKbkf82iL/bVOw6mz64N8YT28ecYCx0LjtuXed8BrjpRyV7TeUO2+J+MQEsGYiSrO0kkCJgDg4eDAbpuVh+rlzpUWcR9ch6F82xu0qhsq6mdyF5nJEY8bgRwHFzjiDnCuoh1xnepkmQRfabYHSsjtfrkjzNq8QXGZzdcLoTVB8Hpk828d3FSgI3xf6SVrqGG0av+8NQcr/MAcadtMlwOKzSI8xvnee8NNDh+tnibwu/f0429jreupFLd3Ir+YHoRtSP/vBnOuuTbnR+e3+XFJ+2iWDJGeocPzo6qmqhlAygzjjx97UNfl3jvpGqRH5KfbAtZk0Fqj331Bx1S6oEaUpSsWcHZ/PCrivxRQ0pOf1N58A2/cC61XWuJAiyFlEnAxr0OH49NFyfkw5RQg1ibqmSXiJ+bZFfx58FPEr69Gud85rJA7s/TUOfXQ+Z1vo/PfZEMwltEz8lvxluL+lx3FkJG35Kdz7zgAB9hCnrTDfCqIJOMAF2tQFtYH8RAWwyEJlj5HscHs5gNptVaRE/EZ9OZNHwZA0e9YZjpBahiQS2rYKVnIJHq0wcx8UJ4NPpNCKi2h2ESqjRAk5RseuuWxr3UcJdIk3bYPeUlDoHXpA3DKkeA5AdCeEk0ImgL1Q3fm0bwQx7RkZ0Uw8bAiIez/kiKsyIrz9aiXd/pjX3kCk/jr9NMugGkfSPKmDHx8dVkbSq3cx7xpf1z/Nfp9Np9QIzuPXYqMFgsJb8dNEy/K4E6RyHvNE3rElN8epY61NjfMwj6rusfcNH0xx4SX9sogJpSgxMmgrkXpX46trG+St2H3PGGseo6o+nvtueA+7bFD/pXFW1qNVmvquDjqj7MMiPjjkBH2RAVTA/JSJTfyDZXZDfiKgRYMae/lFRg77QmjT1Zdg4nqc7mUxqcx7sGlT5cS+bzv1tMasKrePOz/V7XRea5WLOY+f02ck8JQtfjuoX8VhCQQ1tF6qftmcTwNICcTLgipBOVHUARI+QQR1MjZz0IdxETSyapvq0rK4OHC9pGfl1EsTL5XpN/VAYCnbHrekxVYj0IdiK1VsXqd8mIqBkQGVrjB/3BPHXYnD6BbyZQqDEN7snXQxdkIGSMwS3pgKVAPI+GDwWQyshBq86wkwZAbf2P1+3afgdcxboueIJbi2M1giZftDdr6z/yWQS19fXFQl0pZ/78DXG9104P64ZUXcG6giXy2VNBeN34OSeCU418MWO4QjAD/nX54Kq6uapv4z8KP5t+wLcEVGR+tVqVStvGA6HaUpLnaEWw/ujIq+vr1Pc4NNjXrKaryb82+DO8Ec8rdlj96YG2Dhp1jdrXLFDghQ75Bc1yFOBJSLQpvrnn6P2RrMvYFdi5PeiJEifF351dVURIda6nqk3HA5rmz708GNXf7uw+fqZ4GazohJCms5D5SxkLvX58OPxuLL1rBtw7+19e0oUu+ezzR7ZvN8G+1Y1gGoYMyUsS1tgKCB1kEGiZZXOI6K2iDCcKEKlnbddtcwhZlGAq2AYbt01C/bsGAxNF6nxcMl809ZG37i64ERA61WUBGLUuF9SZf7Qb60L0poofZxcE/nrspWwk6ZxIqQ7xpQM4SggSloDW0qHunN1ItZV+suxKwGC5GkqUIMSamR0jjK+JRKkaoCvbw0uXf1p2wFm+DWVB3bGTx2gEkAPPDX9iQLm+FXtVnXNbWnXTtBJpDo+7QcluV6vyjulHa7+Qf4ggBro8Lm6tvTA33Xkf9vgXnFH5BsPWX9+Bqri18BWFTBwj8fjas3Tf6o6+eHC6za8tEH4XSDJsDM+XpNO04DH1T+I0GQyidlsVsvsDAaPh+ZznE4Jf9uBv/p2GthRo1kH2UY1cLvIQ5B7dXUVFxcXFQHEB3AdhARwqwJYUv/aaC9OAevXJSXMi1gPDg5qBf0YOsgdkZ0D1A4lkoQs4JjWtTYJok68LBWY7d5ikWMsUINI/ypeJpIqgRACTQkyQXdBgNQwKBkoqUG8tEhcxx1lDLzMAyeCSoLU8WyifrVFfDPsil/HWgkgwYHWMDJ2uglC095aC0ZkrOQP46sp/y4IkOJW7BGPahA7Fdn8oaqfbubRYEYjY697VOxK+F1x9UhY71Hft8Xudi4L0tQRKG51jox1qQ7Kaz9V6dYUsNeXdU3+M9yQQP0dBFDT3Mx1V0QgPEqCeU66K75aX9q047ft1kQG9HslgMwDHW8P8lQBVeVTa8EUd3bczbq0d1sKqOL3NTYcDmvzW/2VbnBzAUOVXwhgE+H3Xa8lBayLVsJO4O/ik/5McUN8UT6vrq4q8Wcw+JZOJrhhV70/5aVU69kG9q13Aasi4STQHSMkSCcYHUYnZoOri2pd6pOmk7Jt8udpoRL5VSWMh59rbRP3heHz+9dIah2OLkmQXoP3TUgg50Wp0fBoWReN1tdkEZbi7Jr4lbBHxBPsXg+n6dDBYFDb3aoOkrmgBCGLqBWnpn6dAHXRMifInNcgTwvj/f50I4tizUiDY1fSWar969IRaD8ofk0Fe32mn1jgm7/0BTHQDV5KpOlnrflqOw3kOJ38RNSDDEigBsA+jx231jVrwMPXWkenn5sVvndJfEvkPwuGNBOg8zjisRY9C+SVBEN+PYukpSXZxpeu5r77ZiVCLnwo6cuwe+DjxJ9x941Dnj1bt/Ghi7EvYVfcil1xE/Co8gvpR+lfrVa1GmFS/hl2V/oV87bYt94FzNdOAJXJw2jZ6apqAc07WQeYwfD6uU2jACUc2QC/FD8Gy9PASv5OTk4q4ru3t1erXVxH7CDEYD84OKgUQyKIEgHydMy22F0J07FyAsiY66YA0oLqKLQWDoLoC1zVI02LeRog6wN936aVHIASASe/WhSvj7PjfnkGZkTUDH5E1BwnpAGVZR358VRUG7izta41oHpMghJA5juNv8+IjP5c14Vj9vfS2G9rGEvrJBt7XReUdGhZB3NYSST3makIfj1emfKV4WzTIXpTh8XfaHC3WCxqgb3bcQ8CnPTqNXQeqY/ZBeHn+o7d1x1jRh+AE+x87yU9+lI77+tH18j3CHoyJdTJkGN3e6QZHVQxiLBmU7gmc11f69Z81wEQ11Ufqn/nNtvHW9VftYm6qcZx72LcWzkHUCcEEStEgOf3UfPFztBsR2dEpItc64Yi4kl0WXKC/K87xbbIX7ZgM+wYxcPDw9oxDyqXZ00jR+oQVCHTvtf/8X5oGzvvOkGV+PKEAJz3aDR6onBoekj7QZUxcONwfb6UnH8Jc1vY9frZuPuOPoj7fD6Pg4NvZyOORqNKFfQdnvSXlgZAItUQrms+F9owHE7+uWdNVWsg6EqlkyJXi9zo4RjcIXIviq8L5Vdx6/3jtDy4UsKSOUKv4fNABqfKtfTn7gRK49mWfVPcNFecS/fE+Hr6ivtz1VdT3qVAcx2mtkmAfu0kKMOu2HweO27IELZOdz3r9TZV/NokBU3jr/2BCKE+NuuPiHhC+CG/fKZmh0p9W7qftlsT/ojHOeqYwatzGvumKj/PUR4Oh5XtWzfnuyL7raSA3fgpETg/P6/Sn9TCsRFAjbYCVfCQIHZKoppoNFkigE7+nLW3gT0jAmA/OzurVE+waw0fKlHJeWntjGJXqV0NRKb4eR+06RgyAoTSC/k7PDys0j2e+lPSxzt9g4EAt17blUJaacz5XVukn+9VvVL1EwdAn6jBg/RpGhCjoAaCNNHNzU1FgjStrC2b44qb+20Dt2JHrQMnf6eqqL9QPvWoDP5H/0bPBFQisY78tNmyQE/xRzw9WUAd//39fUp85vN5ra5LA6SIx/W9zvg7we8Ce3Z9JYElh6XzUI/C8nSW2gG9ZtN9lVrbtn0dCSwRHs1SZNksrwuF8Ht9GX+/7j67aD4WTUQo4nG8PWihlcgvfaibJCNy4WaXbR0JLgkwEbkKqASYdY5yyu+VPKov67JtvQtYF4EqAkdHR9UDs1erb/nus7Oz9LR7J3++YNgkcn19XTkaV4ncKHt9QmlCteEYM/KLAkZUd3p6WitudyXPB5p71x2TmiJcrVZrN5A0EeBtjKU7QiW+jC24Z7NZWhNVmvC6UHTMIx530aKkZKRf+6cL0q99AHHh3g8PD6tr0R+MuS5+vi9FxZouGI1G1Xjd3NxUzmVTEtgGdp8vus48BcoaABcbRCBz4HMCRMmE1kcNh49HjDjBWIdbf98W8eddHVuWCtVaLldBsV2z2axW5E6f6HNDm5yLO179uisSpC1TwbKyFNarb+IorV3HUsK6i7aOBLDus7S+KqAavKxWT1PgXjOX+S397F019cGKOesLTQdngZoTX/BHRFoapD6iy0CnqWX4+b7UFzr++oKrYOsHg0HND3iZk9fKd4W7FQWQdydBZ2dnsVp9S5UcHx9XBa+e9iwRQCVBNzc3MR6Pq8gahxLxmB7y6NMHgQnatpFUMgB2FjW7fFTud0LKPWtjElAz4dg1isgib8ft12qD+Go9kJ8J5bh18es9ae2I103c3NzE8fFxFRmzgOhrnK/jdmLZBemPqNdDQfi4N1K8vrM5U0H5O931TNEw2PkfSimyYKeJCLXVdK3S/zofShsCqItUhdN3Oh4eHla75nQ9OyFyw5y1LohQiQTyvaqiSgSVBDCWs9msdto/WZGIqNlIneMZxq7bOhLo45I5REoe9DyzrIbX5zI/4+ffo5VIAE1VsIjHHfJew+ZlAZ4CJ6jNSoNcCdolCYp4Sny80QdaopJlaNQfad2nigNNitiucdNKa6DJvqgNdl8Mb9ENY54VygKALlprzwJWJczl3P39/epJAVorpP+nnanGFac3nU6rWiPUEdQldYgReY1JF6qQOwQUEU2HHRwc1NKipQntg8zf8hQQ+hTsuqN6EwLYJvn1BYHjQ6ljDuhTEsDjRt6/xgCAG6OIGnh0dJSSAceekaE2CIGTQPDrWVEQYo9m9WuPhv3sw+Pj4wq7PhqPccwcbdNc1/t/KW6uoxG/1ugpAXJ8ngLxQ9L1xABUM4wl2DLlk3vKsHZFAmmugnHP3KsXdEc81jMfHx8/IYBeA+Vk0x0iP9P3LlpGALg/XVvZekQZp/7Vd7TqWHvzMd2FIlJq69QfHSOtgXUVuEnRJSCmL7J0oOPeZT+UiFAWBLhf0r9V+xcRtaBRM0SqiH7PsY94mhLPmgcrJT+s4+uqnwsEJRGjrdbaJpCI+vEY/sgYTQ/6JCrV9qxWj09KgAChhnFODg4iU0UyJ+RkSe9jG+zgiIiqVg9iwBEZLmmXBpT7hwjRl7PZrNpRjcKkkZdjV7KRkaFtJlSJsPM9hLBE/LKXLojb29sK9+3tbXUgKLi5jtdiuQIIYWqzNSlBzHeN5te9lByh+GIU5/N5lQZnJ7X3f1Oftt2yoMexu8KrGKl3Y45kqgH/588GLpG/TX7WFnY+320e960kMFPnF4tFpXb6g951Y5w+FaPUdkkCI5pTgCVisFwuK4LfdKC1bgLR/2+a012R/aZW6gN+p+n+pt2cWQNPKXtTwr7LlvmNjPytI4K0jPw6dn3/nu05PtODtGwOK65M+dwF6W1NAaQ5CeRn+sgcWmnC0DCYbCKZz+dPDkkkcibFwv9p+kmP0WBBtmkw1BlE1OuCMG6ZnFsymhGP0QF9CHY9GBNS7OnvjPhCfsHeBv4SCXQlyKPX7GtfCPThfD6vlBI9Fd2JL33WpIKpo25b/eVnqohmSqcbNsXMeGut2HQ6ffIcVK61ibNsa6xL2CPqARzjAj7mvzr60mYBvqfmVzcOoHZzfW2ZoeySEGQqsPe12jJd03rOl5JAVwAjorZL0tsmpK+LPshIME3rwPABSoL8AQGqCEbUj71yHOvsxy5bifBmpMdJofaDKsP6uSUC1DUZ2LT5/Hef7u/eD7o2skBWhYuS8vmjtab1WLJb/L2SW/UJu8DcagqYhgNUg6gOkb9fRwDpiMFgUNURQQT42rfRRzwW06Ko6Cni2bb6NjvaHbOTgYh8wWhjQmiRrOLGceg5e5BAXUBag4YTXi6XnTkFJ0JqvMDlfaDvet/gBqsfjomDUdXIjQeYuyJDTgLVgalC5i8nhsxfdo1Cgvz5nxrklBTAEunWv20Df0akVRF0ss28y4ifqsSUfKD2sgO8dLRKieh3gXkdfn6uqjPzlHWanRlJiYgqHNo34G5SQHyNda2GKX4nw5kCpAccO/7lclmdEhARtYDW5/L3VoFoTX6jifx4P7Bhytd1tjZ+dBLkLZsHSoA1kKeBObNl/P5HbU335pmOiKdjXQp0usTcugIYUU8H6fdunJyIZRE1aWONmJ0Aklpy9YQdh6WT9Ns0liU1TCd/xHryx9/gMJfLZc1gqnpATRgEl37zjQXgdxXQ7+eluN0BKCnTfvbr+c+UQDHmrhS48ulpQ1XUNA3bdhrY8fM1zVPT+tI+yVKFDw8PT2rkNF1Gf7sB2UT97AK7qyD0h84zbavVqlLktUxANwA58VUFkM/wz/TvuyZC2diXyJCnBlnDHJiuqX11/t5KjuJ7NCeB/Ex/7ySQTXJ6WHzE42P0IHg+bv8XSEA210r4eWGTIx5LhyKe1v9lqeAfsR+yewI/9kvXNfYwC+6ygLl0je/Z1gWgTv6wi467RAC7bK0TQHf8JbLhxkGNJo2B19SBPycQkqMkaLl8fAD37e1tjfR5WsINWJv94GqgX6PJOWEIITL6nEQlwThRJUJaT5YdPNu2cyw5gOz3tIwIaqCQpYtIgxI1qyLG57jqqcpn1+PdhFkdu5MyNwI+5uD0TRC+VrKXG5RdECI35JkSoHWSWZ1g6dmnfCbvpZdfv2vM+r3eo4+Xqj+QIK3l5X+Zu/P5vHGuZuuIr7tWATdtOuZKfE9OTir1a39/v/a8dJoTgCZHu+vmY99EUDICfHx8XNk67JPurPdU6I9OgrN1GfF0/PWkDAiwl3Ap5oh4Yst+lJYRNr9PXfe8Iuqnl/ha3RXOVglglhZoUr+U/GUEMCJqyh3RAyTw+Pi4qhdjEkXUD1Cez+fV//pTC7pYTK4ElAxxCa82dZS8qwKIcsCkgvyqksKTJ1BRulbEMoxZHzhpVAfoY66LR4m+K4BeU5ftpCrdT9f9UCIEOhYQVSfs4KdPILU0jZSbVMAfgRBkJBBFVwlg03MwI9bXPP4ojiKzef7UHN38EfG4U5ggjv/l8340jNpKhFSdoB6Ur6UBw+GwOuIKx69YfzQSkN1Hpt4wfzP/xfqH9Oo4N63nH62tW4ea+sWHsd41M6ef5STwR2s61iW1VtU/tecR8eQpOd/DPneSAl7XXPXwn/v3XkCrD6DXx6q4AqgpYH9ovQ5S2x3fZBiyti51oM5Sd9SRMvCDRv0pEyikXRLfrK0zVpvg1oWj+EtPWYBIKPlz49nFQnOcTjhLf6eBQIbZDYTXj7jhyV5KFrtsJQUgi4gzJVCDmezA4Ox6mUKkv9+1Uc3GvUQCCeT0dARqIG9vb6t+oCa2pDT8aKSgSQkBtwbwTn61AJ7PKNmuHwV3iagqgcF/aRaH0gY9B5LPKyn6P0rL5qDaYPex6r95QhL2zAlQKYjVa3/PxvXV/vpRX/w84lHM0c2wWrufvXbRWiWAJQPlk7eU/uV/1SnqS5UhT23q56iaoERoF5GURwVZ1FpSBkuEuEQGs7o+nYyugnWFPRvvdddpIn/ZS0kQc8HnjRrKXe0iy4x2RoD8b/17n/NK+FxJz+5BCUSJgDHn2uqL0nr3e/J3vZdSkKPE169Zut73IETZ2Jaw87WSX3WGBGqZ8pnhbLqfXbSmvi8RF1XCHDtHPKGE62frNX8E8ufY9WsnA77ONZtFOYs/T1s/rzTP/X6+l8Jf8jt6xp2PvwoYEU+fefx/oanNVb6hhzo7blcAPdj1criu+6J1BdCNgB/oSHPS0lQIGlFOmTalhlxC7pr4+XVLx7+UCHDTgK+bCJkRUkJQIkLb9kV23SxtoRg8yvEx16Z/545g3XWbIsi2sZdSAJs6q4wMZtdQzIo9Imrvuwx0Sn29rl9KQWG2DhRH6WzLEhkrBVxt4Nd7K5Gg7HwzT31qhsOPB8muEZE/8lAxd9kyQuIBmL/cDnqJg9tDPteDGp/TXczv5zQfaz3XU2uSFb+q+7pZTT/zuTZkVy0be8fNYy15EpSSIZ/v/ExJoAf8P0LL5qCfOMLZpZQyOAkcDofV+bZ+NiabWD3702UftEYAs45RRqzRjefEdUI07YzJHIouLJVed7lYMmeoUZAvbjf6iplXyakozlL6b1eGI8OdRX4lJ5/VfzqZcFzqTEvKpmPvmgCVlMcsZVMKZCKiOIaZwXGnWlI7/frbkgOfU9kczMawaVz9/7L+pEZKA5osuMnIX9vN+6CJBJUec1WyVVm/6rz31NK6+d12H/i9ZnbJy0/0pYfdZvedfab/fJN77MphlvArdiVAkAF9FKiKIZum/Up9tetWsvmQIDZezmaz6mldt7e3FX4lgZA/vlcSWPKL2f3sgiD6uOt4U7IB5ul0WmHnIHvuk53PEVF7+pFvdNRAsEsS2IkC6KxYnwdLR2RpXB/8iKe7G/lcjzL4TCZmFl132ZHuoD0CVOwqB5cmfQm7SsxsAlkul7VowyOOLjBnBMhT7pn8n42H32NGDtR58tkR9QeJOwHqymiWyJ87+00ULif8+lmZipLVN2qAUSK+beP2+81KDZrUP/9f7bfsxTEpGfEtkWW97zbmfxP5WUf+dN26XfSd6xo08jM9S7QU5Pn3XRI/vWbJRkGAcJB8rapQFiirWJCNadPPuiQDGfHPcEOCIAW8wF+yV942xbJrEpQRX3DzLG9I0PX1dUyn05jNZpXd5n7xfRH1DY/OCTYhyLtoJdw8nnUymcRkMonr6+u4vr6OyWQSs9msOq8Xvx/xbWz1LEzOPdVzb5s2wbXVWiGAmRH0haAn3cOE/Xw33/k4GAxqRlANqX82k0evoekFz7G3NaEy7LoDWaVgWmmiKynkszPsalwpnB+NRjUiUFLa2lxIij2TwjXaj6jL/CX1NyJqY+5kGgeiBFAJdtdRcckIuuLh2H0sSqqvOsNSMMHOdy+vyMjQOhzb4laympHxdYqm/n+mGCnuiKgRqU0U4G0wl/63ifitU8HUfilBcjLo5CgianibMGP/dkl8m1JhkD+IkJJBMPtcZn2TLtt0XLsgQyXy6yR9sVhUeG9ubioVSFUhsOs4u+3ysiB9b7q3LluJ9LsKBhGCBEEGp9Np3N7e1nZ464Y2UqPOC/QEhCbf1SUJzrA76b2+vo6rq6u4urqqsEMCOd4o4nHDJkfi6JF2/pQzPQKsKxLc+iaQ0qSYzWaVQkUH6MHG2Q5AagTVmGJcNMLimBMKapUAeo7dGXXbZMixs+BL9+UTXYmRYi/hp4D48PDwyUGqTq7aiiRKxkBxY+C99sNxOvGF9INVnac6TBZU5jxoXaqemSHUeyzVfzQpoHxeiQDpZiZPJzalQbd1Eq586Oeq48+I2TqV0ueOv+hTSJBu5srUxuwet50H68ifpmedtDu507lM8KrrWYlRNr6K08fd71Xv+aV9sCnxdczg0RfqECRQMTsRVBLr+DYNctpom4y9r30IASQQAqjYNXujimfE+lKZdfe7C7unvs6J0Hg8jvF4XBFA8KsaRlMxoEkUUrK4SzWwRP4U983NTVxdXcXl5WVcXFzEZDKp8DLe2C/1/eyE50xMCKCSwK5VwE5qALOJAVFj8CF9vHzg9WBIjypVWieq0MMlI55Kyl3k1EuEAExqCFnsSgD1QOMS+fV6Eq8rUWUkS7dmmD0F2QZ2Xxhq2PWelPBmZBCjr07TVRLHDfasdWksMuw6VnpfSgDVoGU72PUzSjVUumuwRP6y+21jzpdUAE/rZUQwc+aumnuakM8nOFCnmZG/dX3wXLyOna8zAlRKgTYFM6wVVwWztLr2md9j22TIsevX69K+GviT/tR0KD/TOV5ay5usYce+K+KfEV9IHyQQEkR/+NzWeaVBoWdJ6ItdtpKPU+yMOZjH43FcXl7WlD+thyN4Jx2q/lDPueV79w/ZPXZV1sV7pvbqeE+n04oEXl1d1biJYqb20dO/kEGee+8H4f+wCmBpkagqghMnwhsMBk+i5YwAMkEWi0XNiGAwlWTwdyUC6CpTGySohN+NIsaBe+Ne+L0TQCe/GinrO0RoMBhsrDa1aRQdcyn9pZMfwpspgE58NcXvKSO9XkZ8SmPcRfrbnaE6eiflWopQqvt0dahJEerK+SvOddizVHVWw6dkzT/HA6YsPcjc0DHP5mMX/bHOznm9nwZqBL+KSVVOT5eqMpTN69I9+b21OddLuBU7eFTxI+3p9XBqu3WDRIZ30xKWXRCBdeRPlU6theOluJXg0zK/pYHjLglg07i7b9PxJv05Ho+fEH0XKnQ3MKTHn/vuKtiu6x0z+65zfTqdVuneq6urSgFkXdC0LE3PJYb8HR0d1R4TWDoSqk38nR4ETScy4Ex4fqYdqmAzAqgRVVZUy1b6iEejkSlrXcmpJWKkE0hreSBDTn4y9dOLaYksMCLsLFKikSlO2a7bLjD7glEjNxwO4+7u7sm96dlvKIAYUjWcjLeTIMXUttq5aV9kuCE+2ldN6R2wq2KgDqOk9pWco37fxZx3h+jpaq9ny1LBageyta0OIyJ26gg2xe1lGk4IMiXQ093qJLGTmZL/EjKwLSHOAj0db4L80m5ITwU70XcbERFPfIFvCmgLWwlvCX+JAKkaBClQ+8W7B7IR9aBdn5bBa10qsCvyq5+PD1MSVCK96q+Y3zp3VOnU1K8qgLvcDFHCnI25qtzYK/XR19fXVcnbcrl8ouT6E1H00a6of5kC2EXbmgCSsuNr3p2A8PIUzv39ffp3SoJQAJFX2V2DUV0ul7Xt1V5XUIoi2nQi7mhL6lumWGV9pdipMRiPx5WkjhGJqKemsjq7phRwm5h93PleVRt35kqCXAGczWbVeN/c3KRESK/dpHruwnD4NZwkaB/4/6kCiPrp4w3ujDw2kcptsOv6dmy8Z4GOKjtKfDISyPfgVsehtbPfe3wdd4kMuCKCk1Cyl20QURKoaTICw+wRebtsjLPXvTlmnOB4PK4FMV7n6Jsh6Fuvk1JF5HsoQaWslhN9L/7XWjDHroRI57Y+711VIYhQUy13V6qvZ3c0YPFdv7oBRIkQ4gf3iM8eDB43hI5GowozL8e+CwW4pHi66ucv3+yzWCwqfxZRT+8rZsifk8BsE0jbmFtTAF2BYREDUCM8rxXK0mQ6yReLRdze3sZkMomvX79W9QUoYTgpv3ZJ/csW0Es70Qmwkjg96V3rlZhMjt/JE6kxjItip56CxZSRyFLt4zbNCYETD+97VX3BnaUC9HMU92QyiYuLi8qZaBmBkh0nzyXi21bLiJeTFK738PDwRP0ppa6ZHxiZi4uL2nhHxJNxLuHumiB5CjJLk6jiU1IylTii/F5fX8d8Pq8RIV1TmSLUtdpbSgV6jbJvBMBBeG0jfaHv/BzMYFEnoQ4xU8S6aBlmL0lB5dKNAKoAaspX7Z+mfiEF7JKkSB5i4E6xNEbqE9poPkd9I6IeA3J5eRmXl5dV4Ko1njruHsxp+u/4+DhOT0/j9PQ0Tk5OUuxgc5xdEKHSmtbxJlhHrFAipOVPzNuIqKlh7IQ9OTmpcLMhYhMy1HZT8qfrGtI7Ho+rMedrMjcQfF3DBwcHNV/lBFDJb0n9bBt3qylgdYAMqhdA3t7eVmkujL1uD3c1kclHWmw8Hlc7bait4ZEyTkJwFu402upMNTJOgvQxR9oweBTNai2fN7DztxgXtpavw97lBhh1tNm1OZYGpwFGrQtShUfJBAuPMWeBQYQeHh5qtXSOs4nwt62GOfFTA+cEUOuAIAJqZDVAQAHF0CgZImpW7I63a0PpDsgdhe6Q01SYFoIrCeR7rQVEPQCvpkay4vBNyO5z+yMb80z90/pkJwQ4BpQw+kjVYSfQq9XqiRIGESgViXfdPID1jR6QXSUBqEC+2cOJH2PsysjJyUmcn5/XyIAe67VL8qvEVzd66HhD/iDAvhlA7Z2KHdSDQfwgQefn53F2dpZi30Vz0u9pbh1vJb6on/P5/MkY88SLiMdNEczvs7OzCj9fHx8f18hQl9jdHvuYo27rZheOftGd3nd3dxER1f3u7z9SLcZ9f3+/In++E5g13jXmVhVAJ0BZsWdEVGQHhUeNhKeHMBJaH6TOxOu/1GjqxpKmFHAbZChTPr2mBeK7Wq0q7DgHr3dypzifz6vJxd/jJDSFqiTIyYGrBW2qnj7uWqMHYXt4eEhxe9E7RidzrCwsVUCaNpRkZKgN/D7urvyy2YVxYvxwiqiZYFf8YGfcNQWsTkNPjc+IflPq/znYNdApqQtKYHTscIwYSU3taoCg776pIuKxHkxToUqASuu6LcOpc97JupIhlFvmeEaElABD6vU6GkhAeiEEp6enRSWsKyfhDlGVWicDeh4aihBKmdYDa+3bYPDtUNyIqB2PARk6Pz+P8/PzGgnaJenNlM8s3cl4X1xc1Oa7nwgAbrXH4D47O6tejPfZ2VmcnJxUSlhb2ZxN8TsBZH4zrx23lu0Q5IJRiW9E1MgveCH8qn56Gljvryu1v4SbsYUAQvZ1ww+22jOAEfWyjoODgye7gP0swC7HuxUCqI4B46Wqjk52nGHENyJ4fX1dk8uJqNUh0IGoJppKguxwnZemgNvCvlqtapsyuA6TaTabVVhI63rNkzpSvtcaBByI1j16/WQT+WuDBOr/er/7RoW7u29PavFUPgYy2+Gq5Bfsqoz4dT3lnam9baq+fO3k32tYB4NBhWM2m9WOSFBip7tklQzq7yKiIj2ZwtuEu62WKWIR+VFI2SGppVpOTydzz5pRwFFktXC7UD25t4h6ekhropQQkRbToyFQAJX8eRAxGo1qjlEVAj8jrBTcdIU9I0JNJEifhuDqF36BOQ1JQPlDBYIAORHYNQnUdKCeeQcRZG3z8idgRNQ3KEZE7R3cZ2dnlfIHKfoeCmCmhBGIM9aQH53n/BzSj80ejUbVmLsqqOqnKn+qhpWyd541bAu3ZmN0zLFn4GYOaP2fpvezciffCFI6BsbVbr/PpqB8k9Z6DWBE/Tl/6oRQ8vidEgJ1ijh9JUJKCLR4lpPi9dqahivtAG5bHdDrQwJpquow+dndO51O4/Lysqp5clUIzGp8fcAVj2LP1LC2cGcqoBJ/NZgYbMggTkPrGf14F1UQ3Wlo+gDMjnsdKWibCJbq8vg7CKBGzyig1IuoSqBzms/3GpJS7Z/2UzbXX4LdiV9mgCPqT3FxkqBqf4kAaj9qKYNGzG4YM5wZxpeOuWN3VchJoKbJ/GgIDfJo3D/OIFNHSkXxu6qJypQw3Qig6TGtiyINqsdggU3HEHsJIXTSq7Vgu0x703yDk+JWzPo4MNRP3cjEmgWrigeMtaq9Xgf3vbHjsz0d6vWuZG0Yc0qVGHf8t65zSKC+slRoE/a2FMEmlT/b5a0bvfR0Dsbeyx08W6g7gjfZBdwWztZrACPq2/gjHjtzsVjUQGltFB2KKqK7pDRadlURYsF1lQiViEBXEbNem3slJQr5Gw6HacF7hl1VIG1MHCYo11ZC1JQGbAu3kwonQUq+6Q8MCLhVCXPcGcEajUapsux4u8LM5yghdSXQr+/qgR5p5IeFKsHHMRI9048RTzeC6DW7dhClz3by75sF/IkA9Ik2dQh8z3tGeDdxCl00dRBODHyXJOkjrYEs2TENalUpyGqZnxPItkH6m4ivbgbxA5/9eagRUdnF0WhUc44Qoew4kNKTIbpqmvLnXcfaj0Hxlz4EAGysZT/2RhUh3QwC6fegx7F3mQrN5rnW6WaHe+sJHdyXZwbBoAq/j7kTIca965bNdXD7sUa6uxtbjq+D7NI8QNeyIc/qNPGWtlrr5wA6CYx4XDiasol4mkbRM7Ng0Wosazf+/yPG7He7qAnypqSAe9jb26t2AjmT97QuzoLvwe51I1pHkaUinYzsArveQ0ZMlAjpBg8dc8ftcrmSP8fdNNa7xJ7dD01xu/PwHZE0jEfmJNUJbIp7F86Sr0vG059ik20EoH5G57q+mur+FGcXQZ5j9rS9Okk/LgSCpBmM4fDbeaCMcaYUlEpZnPzuggRnmP1sOD/b0J8AodkRrQeMqAc2rva6Dd0V3ibcPs76zte6CTIiauPstkwDoBL2XQQ8WVmGq/u6a91fWu8JLs3quF9z4UDfPd2/q2AvK2sBs46x1/Tq+vZ+5P4zEtiUsesKc6cHQbsi4gOpSkG2mLQzl8tlTeFYLBY1ZYnr+fWbHEQXeP3aJew6YXRCYSzVOWqaQMlf08QqYe2iH5T8riMhakDVkPjxMFpDoUqqq6EZdsfpf9dmy8bc70dVMSJDf7nirXWOKOclvLsmAU1NjSaYM6LgczwiavNba2Yi6gp30zjvGr/j9ayF4qQvtDB+MBhUj/XzHfFa6rArdXsdVsXsWL2GV9eyBnXs7NY14SUervLuyiGuw++4FbsSBC1xUJvlZS0RdfKnPq5U0/u9cDeNu768f8j2lXyWYy5lNr43/tK4Z5gVq9symtvvddnKLtpGBJCbH4/HG3+wp4KQTv3ROP6YGDeYajg0mhgMnj5ODhKhtRmoKPP5vDF9AjYfqOdizyRzTfVqfQSKgDtExa/Fw4rbHYwfzHl9fV1FXfP5vHZshk6wNnC74qNqgD8LUyVzH3N3lBHxZKwzVUkPoD0+Pq7S7Iy5R5L05zbYneRoAMO46gGhehhuRnoVn342huH+/r72/Gz+V1Orh4eHFWHUHcKaJm/Cvg63Y9b+9wff+6OvVB1RB+ljvbe392ROeH/50yZY4xyLpKrJJrifgz3bCEDNl9cDNT3LWQkPnz8cDp+oR6zr0g7oDHcWcG4z5qpez+fzWt2XP+2D95JCosEi96jj7GppRn7WjbUT45eMuds0t+H6VCrNYqhN0zlO29vbq8qBHLeeLQh2nXOs+2zHv2IHt2J77nxXkuMbIEqPt3P8YOc+sWEu9PgTZFCIHx4enqT+nSSXguBt1jn3zX2xvhWz+nC346pyZgJXhvno6KjyW4zxJpmOLAAsYc/aRgRwMplERMQff/yxyZ//n26TySTevn1b+z7ir4/9teKO6LGD/bXi5vuI14f9teLm+4ge+1+5vVbcEU+xZ22w2oAmLpfL+PjxY5yfn3/3FFNXbbVaxWQyiXfv3tXqF//q2F8r7ogeu2N/rbgjXi/214o7osf+V8b+WnFHlLFnbSMC2Le+9a1vfetb3/rWt79O2+3TxPvWt771rW9961vf+vbdW08A+9a3vvWtb33rW99eWesJYN/61re+9a1vfevbK2s9Aexb3/rWt771rW99e2WtJ4B961vf+ta3vvWtb6+s9QSwb33rW9/61re+9e2VtZ4A9q1vfetb3/rWt769svb/ACZOcwEklA5wAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "z, _, _ = model.encoder(images_train.reshape(-1, 64))\n", + "zrange = jnp.linspace(z[2], z[9], 10)\n", + "\n", + "logits = model.decoder(zrange).reshape(-1, 8, 8)\n", + "images_gen = nnx.sigmoid(logits)\n", + "\n", + "fig, ax = plt.subplots(1, 10, figsize=(8, 1),\n", + " subplot_kw={'xticks':[], 'yticks':[]},\n", + " gridspec_kw=dict(hspace=0.1, wspace=0.1))\n", + "for i in range(10):\n", + " ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WHAaq0ryS7H5" + }, + "source": [ + "## Summary\n", + "\n", + "This tutorial offered a first example defining and training a generative model, in this case a simplified Variational Autoencoder (VAE).\n", + "Along the way we explored some approaches to debugging JAX programs, in particular the `jax.debug_nans` setting and the `jax.debug.print` function. Both of these are explained further in JAX's [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html) doc." + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/digits_vae.md b/docs/digits_vae.md new file mode 100644 index 0000000..83af38f --- /dev/null +++ b/docs/digits_vae.md @@ -0,0 +1,514 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.15.2 +kernelspec: + display_name: Python 3 + name: python3 +--- + ++++ {"id": "47OmRSTR1dJU"} + +# Debugging in JAX: a Variational autoencoder (VAE) model + +In [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) we built a simple neural network for classification of handwritten digits, and covered some of the key features of JAX, including its NumPy-style interface in the `jax.numpy`, as well as its transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`. + +This tutorial will explore a slightly more involved model: a simplified version of a [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) trained on the same simple digits data. Along the way, we'll learn a bit more about how JAX's JIT compilation actually works, and what this means for debugging JAX programs. + ++++ {"id": "k19povzxp7hS"} + +## Loading the digits + +As before, we'll use the small, self-contained [scikit-learn digits dataset](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) for ease of experimentation: + +```{code-cell} +:id: aIwDAfS6PtFh +:outputId: 4950f17a-7c47-4a83-cbcd-e206d07cca64 + +from sklearn.datasets import load_digits +from sklearn.model_selection import train_test_split +import jax.numpy as jnp + +digits = load_digits() + +splits = train_test_split(digits.images, random_state=0) + +images_train, images_test = map(jnp.asarray, splits) + +print(f"{images_train.shape=}") +print(f"{images_test.shape=}") +``` + ++++ {"id": "2_Q16JRyrW7V"} + +The dataset comprises 1800 images, each represented by an 8x8 pixel grid. +To see a visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html#loading-the-data) in the previous tutorial. + ++++ {"id": "Z9TPYqipPyBp"} + +## Defining the Variational Autoencoder + +Previously we defined a simple feedforward network trained for classification with an architecture that looked roughly like this: + +```{code-cell} +:id: HNlg-ydpr5yH + +import jax +import jax.numpy as jnp +from flax import nnx + +class SimpleNN(nnx.Module): + + def __init__(self, n_features=64, n_hidden=100, n_targets=10, *, rngs: nnx.Rngs): + self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs) + self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs) + self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = nnx.selu(self.layer1(x)) + x = nnx.selu(self.layer2(x)) + return self.layer3(x) +``` + ++++ {"id": "3DwMBvoksOmG"} + +In this network we had one output per class, and the loss function was designed such that once trained, the output corresponding to the correct class would return the strongest signal, thus predicting the correct label in upwards of 95% of cases. + +In this VAE example we use similar building blocks to instead output a small probabilisitic model representing the data. While classic `VAE` is generally based on convolutional layers, we use linear layers for simplicity. The sub-network that produces this probabilistic encoding is our `Encoder`: + +```{code-cell} +:id: Hj7mtR5vmcGr + +class Encoder(nnx.Module): + def __init__(self, input_size: int, intermediate_size: int, output_size: int, + *, rngs: nnx.Rngs): + self.rngs = rngs + self.linear = nnx.Linear(input_size, intermediate_size, rngs=rngs) + self.linear_mean = nnx.Linear(intermediate_size, output_size, rngs=rngs) + self.linear_std = nnx.Linear(intermediate_size, output_size, rngs=rngs) + + def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]: + x = self.linear(x) + x = jax.nn.relu(x) + + mean = self.linear_mean(x) + std = jnp.exp(self.linear_std(x)) + + key = self.rngs.noise() + z = mean + std * jax.random.normal(key, mean.shape) + return z, mean, std +``` + ++++ {"id": "VwfCWbiRmkG9"} + +The idea here is that `mean` and `std` define a low-dimensional probability distribution over a latent space, and that `z` is a draw from this latent space that represents the training data. + +In order to ensure that this latent distribution faithfully represents the actual data, we define a `Decoder` that maps back to the input space: + +```{code-cell} +:id: FoAmZuVDnjgn + +class Decoder(nnx.Module): + def __init__(self, input_size: int, intermediate_size: int, output_size: int, + *, rngs: nnx.Rngs): + self.linear1 = nnx.Linear(input_size, intermediate_size, rngs=rngs) + self.linear2 = nnx.Linear(intermediate_size, output_size, rngs=rngs) + + def __call__(self, z: jax.Array) -> jax.Array: + z = self.linear1(z) + z = jax.nn.relu(z) + logits = self.linear2(z) + return logits +``` + ++++ {"id": "0QaT-KY6npSc"} + +Now the full VAE model is a single network built from the encoder and decoder. +It returns both the reconstructed image and then internal latent space model: + +```{code-cell} +:id: Myo2MdxXnzlT + +class VAE(nnx.Module): + def __init__( + self, + image_shape: tuple[int, int], + hidden_size: int, + latent_size: int, + *, + rngs: nnx.Rngs + ): + self.image_shape = image_shape + self.latent_size = latent_size + input_size = image_shape[0] * image_shape[1] + self.encoder = Encoder(input_size, hidden_size, latent_size, rngs=rngs) + self.decoder = Decoder(latent_size, hidden_size, input_size, rngs=rngs) + + def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]: + x = jax.vmap(jax.numpy.ravel)(x) # flatten + z, mean, std = self.encoder(x) + logits = self.decoder(z) + logits = jnp.reshape(logits, (-1, *self.image_shape)) + return logits, mean, std +``` + ++++ {"id": "xIm9Yi5YoIxN"} + +Next is the loss function – there are two components to the model that we want to ensure: + +1. the `logits` output faithfully reconstruct the input image. +2. the model represented by `mean` and `std` faithfully represents the "true" latent distribution. + +VAE uses a loss function based on the [Evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound) to quantify theset two goals in a single loss value: + +```{code-cell} +:id: bMpxj8-Wsvui + +def vae_loss(model: VAE, x: jax.Array): + logits, mean, std = model(x) + kl_loss = jnp.mean(0.5 * jnp.mean( + -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1)) + reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) + ) + return reconstruction_loss + 0.1 * kl_loss +``` + ++++ {"id": "RaT0ELpvqo2W"} + +Now all that's left is to define the model and optimizer, and run the training loop: + +```{code-cell} +:id: JPgoHL5rpKXd +:outputId: e7626646-7c14-42dc-c18f-774adff1306e + +import optax + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) + +@nnx.jit +def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array): + loss, grads = nnx.value_and_grad(vae_loss)(model, x) + optimizer.update(grads) + return loss + +for epoch in range(2001): + loss = train_step(model, optimizer, images_train) + if epoch % 500 == 0: + print(f'Epoch {epoch} loss: {loss}') +``` + ++++ {"id": "f8m_uoL9q47M"} + +And here we see that something has gone wrong: our loss value has become NaN after some number of iterations. + ++++ {"id": "SBS1mmxwrS25"} + +## Debugging NaNs in JAX +Despite our best efforts, our model is producing NaNs. What can we do? + +JAX offers a number of debugging approaches for situations like this, outlined in the JAX docs at [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html). In this case we can use the [`debug_nans`](https://jax.readthedocs.io/en/latest/debugging/flags.html#jax-debug-nans-configuration-option-and-context-manager) configuration to see where the NaN value is arising: + +```{code-cell} +:id: JE7OYoZ4rRQ8 +:outputId: 2b77a9c1-aaa4-418a-da71-013d8f885340 +:tags: [raises-exception] + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) + +with jax.debug_nans(True): + for epoch in range(2001): + train_step(model, optimizer, images_train) +``` + ++++ {"id": "thw9URJmrROj"} + +The output here is complicated, because the function we're evaluating is complicated, but the key to deciphering this traceback is to look for the places where the traceback touches your implementation. In particular here, it indicates that NaN values arise during the gradient update: +``` + in train_step() + 14 loss, grads = nnx.value_and_grad(vae_loss)(model, x) +---> 15 optimizer.update(grads) + 16 return loss +``` +and further down from this, the details of the gradient update step where the NaN is arising: +``` +/usr/local/lib/python3.10/dist-packages/optax/tree_utils/_tree_math.py in () + 280 lambda g, t: ( +--> 281 (1 - decay) * (g**order) + decay * t if g is not None else None + 282 ), +``` +This tells us that the gradient is returning values that lead to `NaN` during the model update: typically this would come about when the gradient itself is for some reason diverging. + +A diverging gradient means that something with our loss function may be amiss. We previously saw `loss=NaN` at iteration 500; let's print the progress up to this point: + +```{code-cell} +:id: KJ1gAh8uurVX +:outputId: 21564ba8-2ea4-42fb-bf9a-0291a556004c + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) + +for epoch in range(301): + loss = train_step(model, optimizer, images_train) + if epoch % 50 == 0: + print(f'Epoch {epoch} loss: {loss}') +``` + ++++ {"id": "Jk_wdvqpurTG"} + +It looks like our loss value is decreasing toward negative infinity until the point where the values are no longer well-represented by floating point math. + +At this point, we may wish to expect the values within the loss function itself to see where the diverging loss might be coming from. +In typical Python programs you can do this by inserting either a `print` statement or a `breakpoint` in the loss function; it might look something like this: + +```{code-cell} +:id: 9Klkz7qHwWia +:outputId: e727cf56-0cc1-4256-9567-5fe96396ce45 + +def vae_loss(model: VAE, x: jax.Array): + logits, mean, std = model(x) + kl_loss = jnp.mean(0.5 * jnp.mean( + -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1)) + reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) + ) + print("kl loss", kl_loss) + print("reconstruction loss", reconstruction_loss) + return reconstruction_loss + 0.1 * kl_loss + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) +train_step(model, optimizer, images_train) +``` + ++++ {"id": "mawnZtF8wxu9"} + +But here rather than printing the value, we're getting some kind of `Traced` object. You'll encounter this frequently when inspecting the progress of JAX programs: tracers are the mechanism that JAX uses to implement transformations like `jit` and `grad`, and you can read more about them in [JAX Key Concepts: Tracing](https://jax.readthedocs.io/en/latest/key-concepts.html#tracing). + +For our purposes, the workaround is to use another tool from the [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html#interactive-inspection-with-jax-debug) link above: namely `jax.debug.print`, which lets us print runtime values even when they're traced: + +```{code-cell} +:id: wziDzgdTuloK +:outputId: 5d640a4e-306b-4b93-9b99-725ee6e4baa1 + +def vae_loss(model: VAE, x: jax.Array): + logits, mean, std = model(x) + + kl_loss = jnp.mean(0.5 * jnp.mean( + -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1)) + reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) + ) + jax.debug.print("kl_loss: {}", kl_loss) + jax.debug.print("reconstruction_loss: {}", reconstruction_loss) + return reconstruction_loss + 0.1 * kl_loss + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) + +for i in range(5): + train_step(model, optimizer, images_train) +``` + ++++ {"id": "wObaYRxF1-qy"} + +Let's iterate a few hundred more times (we'll use the IPython `%%capture` magic to avoid printing all the output on the first several hundred iterations) and then do one more run to print these intermediate values: + +```{code-cell} +:id: pBLiDfRX2OOu + +%%capture +for i in range(300): + train_step(model, optimizer, images_train) +``` + +```{code-cell} +:id: O87gxdxGP3uZ +:outputId: a4dc57cd-56e0-4597-fbfd-9572e4c4ef7a + +loss = train_step(model, optimizer, images_train) +``` + ++++ {"id": "FXHfa0942apE"} + +We see that the large negative value is coming from the `reconstruction_loss` term. Let's return to this and look at what it's actually doing: +```python +reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) +) +``` +This is a binary cross entropy described at [`optax.sigmoid_binary_cross_entropy`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.sigmoid_binary_cross_entropy). From the documentation, the first input should be a logit, and the second input is assumed to be a binary label (i.e. a ``0`` or ``1``) – but in our case `x` is associated with `images_train`, which is not a binary label! + +```{code-cell} +:id: seLRa2qE3wd3 +:outputId: 8b6ba25f-2edb-44c7-c5a9-bca2f334b5eb + +print(images_train[0]) +``` + ++++ {"id": "KjF0ys3c30w8"} + +This is likely the source of our issue: we forgot to normalize the input images to the range ``(0, 1)``! + ++++ {"id": "jkFIqZaTXRc5"} + +Let's fix this by binarizing our inputs, and then run the training loop again (redefining our loss function to remove the debug statements): + +```{code-cell} +:id: 9Og1-tIw4BNu +:outputId: eebd9fd0-2b3b-43ce-9af8-729544a00aca + +images_normed = (digits.images / 16) > 0.5 +splits = train_test_split(images_normed, random_state=0) +images_train, images_test = map(jnp.asarray, splits) + +def vae_loss(model: VAE, x: jax.Array): + logits, mean, std = model(x) + + kl_loss = jnp.mean(0.5 * jnp.mean( + -jnp.log(std ** 2) - 1.0 + std ** 2 + mean ** 2, axis=-1)) + reconstruction_loss = jnp.mean( + optax.sigmoid_binary_cross_entropy(logits, x) + ) + return reconstruction_loss + 0.1 * kl_loss + +model = VAE( + image_shape=(8, 8), + hidden_size=32, + latent_size=8, + rngs=nnx.Rngs(0, noise=1), +) + +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) + +for epoch in range(2001): + loss = train_step(model, optimizer, images_train) + if epoch % 500 == 0: + print(f'Epoch {epoch} loss: {loss}') +``` + ++++ {"id": "4HD91gWfyJuJ"} + +The loss values are now behaving, and it looks like we've successfully debugged the initial NaN problem, which was not in our model but rather in our input data. + ++++ {"id": "vA6wSi1k5GuZ"} + +## Exploring the VAE model results + +Now that we have a trained model, let's explore what it can be used for. +First, let's pass our test data through the model to output the result of the associated latent space representation for each input. We pass the `logits` through a `sigmoid` function to recover predicted images in the input space: + +```{code-cell} +:id: fBzJyliAPCGc + +logits, mean, std = model(images_test) +images_pred = jax.nn.sigmoid(logits) +``` + ++++ {"id": "qiJy39iDPTVS"} + +Let's visualize several of these inputs and outputs: + +```{code-cell} +:id: dRFxkKInn_gx +:outputId: 4452517c-ac68-479d-a452-c282e4092291 + +import matplotlib.pyplot as plt + +fig, ax = plt.subplots(2, 10, figsize=(6, 1.5), + subplot_kw={'xticks':[], 'yticks':[]}, + gridspec_kw=dict(hspace=0.1, wspace=0.1)) +for i in range(10): + ax[0, i].imshow(images_test[i], cmap='binary', interpolation='gaussian') + ax[1, i].imshow(images_pred[i], cmap='binary', interpolation='gaussian') +``` + ++++ {"id": "eaM-A6rNQWFz"} + +The top row here are the input images, and the bottom row are what the model "thinks" these images look like, given their latent space representation. +There's not perfect fidelity, but the essential features are recovered. + +We can go a step further and generate a set of new images from scratch by sampling randomly from the latent space. Let's generate 36 new digits this way: + +```{code-cell} +:id: aNV9CNC1r2Gn +:outputId: 5855d88a-81c8-4aee-d9ff-885ad78ad421 + +import numpy as np + +# generate new images by sampling the latent space +z = np.random.normal(scale=1.5, size=(36, model.latent_size)) +logits = model.decoder(z).reshape(-1, 8, 8) +images_gen = nnx.sigmoid(logits) + +fig, ax = plt.subplots(6, 6, figsize=(4, 4), + subplot_kw={'xticks':[], 'yticks':[]}, + gridspec_kw=dict(hspace=0.1, wspace=0.1)) +for i in range(36): + ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian') +``` + ++++ {"id": "4oCtm6V3TsQJ"} + +Another possibility here is to use our latent model to interpolate between two entries in the training set through the latent model space. +Here's an interpolation between a digit 9 and a digit 3: + +```{code-cell} +:id: 8iJ9f60VUBwY +:outputId: ce34c9a4-3d47-43b4-ec1a-c5d47db0b394 + +z, _, _ = model.encoder(images_train.reshape(-1, 64)) +zrange = jnp.linspace(z[2], z[9], 10) + +logits = model.decoder(zrange).reshape(-1, 8, 8) +images_gen = nnx.sigmoid(logits) + +fig, ax = plt.subplots(1, 10, figsize=(8, 1), + subplot_kw={'xticks':[], 'yticks':[]}, + gridspec_kw=dict(hspace=0.1, wspace=0.1)) +for i in range(10): + ax.flat[i].imshow(images_gen[i], cmap='binary', interpolation='gaussian') +``` + ++++ {"id": "WHAaq0ryS7H5"} + +## Summary + +This tutorial offered a first example defining and training a generative model, in this case a simplified Variational Autoencoder (VAE). +Along the way we explored some approaches to debugging JAX programs, in particular the `jax.debug_nans` setting and the `jax.debug.print` function. Both of these are explained further in JAX's [Debugging runtime values](https://jax.readthedocs.io/en/latest/debugging/index.html) doc. diff --git a/docs/tutorials.md b/docs/tutorials.md index 19473d0..03f0add 100644 --- a/docs/tutorials.md +++ b/docs/tutorials.md @@ -8,6 +8,7 @@ The following tutorials are meant as an intro to the full stack: :maxdepth: 2 getting_started_with_jax_for_AI +digits_vae ``` Once you've gone through this content, you can refer to package-specific diff --git a/pyproject.toml b/pyproject.toml index d217601..16cec95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,3 +63,7 @@ requires = [ [tool.setuptools] packages = ["jax_ai_stack"] include-package-data = false + +[tool.ruff.lint.per-file-ignores] +# F811: Redefinition of unused name. +"docs/digits_vae.ipynb" = ["F811"]