diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 9c645730f..24a216935 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -5,6 +5,7 @@ hmc, mala, nuts, + orbital_hmc, rmh, tempered_smc, window_adaptation, @@ -16,6 +17,7 @@ "hmc", # mcmc "mala", "nuts", + "orbital_hmc", "rmh", "window_adaptation", # mcmc adaptation "adaptive_tempered_smc", # smc diff --git a/blackjax/kernels.py b/blackjax/kernels.py index 663c3dc08..ba9ec6b81 100644 --- a/blackjax/kernels.py +++ b/blackjax/kernels.py @@ -15,6 +15,7 @@ "hmc", "mala", "nuts", + "orbital_hmc", "rmh", "tempered_smc", "window_adaptation", @@ -545,3 +546,81 @@ def step_fn(rng_key: PRNGKey, state): ) return SamplingAlgorithm(init_fn, step_fn) + + +class orbital_hmc: + """Implements the (basic) user interface for the Periodic orbital MCMC kernel + + Each iteration of the periodic orbital MCMC outputs ``period`` weighted samples from + a single Hamiltonian orbit connecting the previous sample and momentum (latent) variable + with precision matrix ``inverse_mass_matrix``, evaluated using the ``bijection`` as an + integrator with discretization parameter ``step_size``. + + Examples + -------- + + A new Periodic orbital MCMC kernel can be initialized and used with the following code: + + .. code:: + + per_orbit = blackjax.orbital_hmc(logprob_fn, step_size, inverse_mass_matrix, period) + state = per_orbit.init(position) + new_state, info = per_orbit.step(rng_key, state) + + We can JIT-compile the step function for better performance + + .. code:: + + step = jax.jit(per_orbit.step) + new_state, info = step(rng_key, state) + + Parameters + ---------- + logprob_fn + The logarithm of the probability density function we wish to draw samples from. This + is minus the potential energy function. + step_size + The value to use for the step size in for the symplectic integrator to buid the orbit. + inverse_mass_matrix + The value to use for the inverse mass matrix when drawing a value for + the momentum and computing the kinetic energy. + period + The number of steps used to build the orbit. + bijection + (algorithm parameter) The symplectic integrator to use to build the orbit. + + Returns + ------- + A ``SamplingAlgorithm``. + + """ + + init = staticmethod(mcmc.periodic_orbital.init) + kernel = staticmethod(mcmc.periodic_orbital.kernel) + + def __new__( # type: ignore[misc] + cls, + logprob_fn: Callable, + step_size: float, + inverse_mass_matrix: Array, # assume momentum is always Gaussian + period: int, + *, + bijection: Callable = mcmc.integrators.velocity_verlet, + ) -> SamplingAlgorithm: + + step = cls.kernel(bijection) + + def init_fn(position: PyTree): + return cls.init(position, logprob_fn, period) + + def step_fn(rng_key: PRNGKey, state): + return step( + rng_key, + state, + logprob_fn, + step_size, + inverse_mass_matrix, + period, + ) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 48c722731..fa23f1bd4 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -1,3 +1,3 @@ -from . import hmc, mala, nuts, rmh +from . import hmc, mala, nuts, periodic_orbital, rmh -__all__ = ["hmc", "mala", "nuts", "rmh"] +__all__ = ["hmc", "mala", "nuts", "periodic_orbital", "rmh"] diff --git a/blackjax/mcmc/periodic_orbital.py b/blackjax/mcmc/periodic_orbital.py new file mode 100644 index 000000000..faaa46c0c --- /dev/null +++ b/blackjax/mcmc/periodic_orbital.py @@ -0,0 +1,294 @@ +"""Public API for Periodic Orbital Kernel""" +from typing import Callable, NamedTuple, Tuple + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +import blackjax.mcmc.metrics as metrics +from blackjax.types import Array, PRNGKey, PyTree + + +class PeriodicOrbitalState(NamedTuple): + """State of the periodic orbital algorithm. + + The periodic orbital algorithm takes one orbit with weights, + samples from the points on that orbit according to their weights + and returns another weighted orbit of the same period. + + positions + a collection of points on the orbit, representing samples from + the target distribution. + weights + weights of each point on the orbit, reweights points to ensure + they are from the target distribution. + directions + an integer indicating the position on the orbit of each point. + potential_energies + vector with energies (negative log densities) for each point in + the orbit. + potential_energies_grad + matrix where each row is a vector with gradients of the energy + function for each point in the orbit. + """ + + positions: PyTree + weights: Array + directions: Array + potential_energies: Array + potential_energies_grad: PyTree + + +class PeriodicOrbitalInfo(NamedTuple): + """Additional information on the states in the orbit. + + This additional information can be used for debugging or computing + diagnostics. + + momentum + the momentum that was sampled and used to integrate the trajectory. + weights_mean + mean of the the unnormalized weights of the orbit, ideally close + to the (unknown) constant of proportionally missing from the target. + weights_variance + variance of the unnormalized weights of the orbit, ideally close to 0. + """ + + momentums: PyTree + weights_mean: float + weights_variance: float + + +def init(position: PyTree, logprob_fn: Callable, period: int) -> PeriodicOrbitalState: + """Create a periodic orbital state from a position. + + Parameters + ---------- + position + the current values of the random variables whose posterior we want to + sample from. Can be anything from a list, a (named) tuple or a dict of + arrays. The arrays can either be Numpy or JAX arrays. + logprob_fn + a function that returns the value of the log posterior when called + with a position. + period + the number of steps used to build the orbit + + Returns + ------- + A periodic orbital state that repeats the same position for `period` times, + sets equal weights to all positions, assigns to each position a direction from + 0 to period-1, calculates the potential energies for each position and its + gradient. + """ + + def potential_fn(x): + return -logprob_fn(x) + + positions = jax.tree_util.tree_map( + lambda position: jnp.array([position for _ in range(period)]), position + ) + + weights = jnp.array([1 / period for _ in range(period)]) + + directions = jnp.arange(period) + + potential_energies, potential_energies_grad = jax.vmap( + jax.value_and_grad(potential_fn) + )(positions) + + return PeriodicOrbitalState( + positions, weights, directions, potential_energies, potential_energies_grad + ) + + +def kernel( + bijection: Callable = integrators.velocity_verlet, +): + """Build a Periodic Orbital kernel [1]_. + + Parameters + ---------- + bijection + transformation used to build the orbit (given a step size). + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + + References + ---------- + .. [1]: Kirill Neklyudov and Max Welling "Orbital MCMC." arXiv preprint + arXiv:2010.08047 (2021). + """ + + def one_step( + rng_key: PRNGKey, + state: PeriodicOrbitalState, + logprob_fn: Callable, + step_size: float, + inverse_mass_matrix: Array, + period: int, + ) -> Tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]: + """Generate a new orbit with the Periodic Orbital kernel. + + Choose a step from the orbit with probability proportional to its weights. + Then shift the direction (or alternatively sample a new direction randomly), + in order to make the algorithm irreversible, and compute a new orbit from + the selected step and its direction. + + Parameters + ---------- + rng_key + pseudo random number generating key. + state + initial orbit. + logprob_fn + log probability function we wish to sample from. + step_size + space between steps of the orbit. + inverse_mass_matrix + or a 1D array containing elements of its diagonal. + period + total steps used to build the orbit. + + Returns + ------- + A kernel that chooses a step from the orbit and outputs a periodic orbital + state and information about the iteration. + + """ + + def potential_fn(x): + return -logprob_fn(x) + + momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_euclidean( + inverse_mass_matrix + ) + bijection_fn = bijection(potential_fn, kinetic_energy_fn) + proposal_generator = periodic_orbital_proposal( + bijection_fn, kinetic_energy_fn, period, step_size + ) + + key_choice, key_momentum = jax.random.split(rng_key, 2) + + ( + positions, + weights, + directions, + potential_energies, + potential_energies_grad, + ) = state + + choice_indx = jax.random.choice(key_choice, len(weights), p=weights) + position = jax.tree_util.tree_map( + lambda positions: positions[choice_indx], positions + ) + direction = directions[choice_indx] + period = jnp.max(directions) + 1 + direction = jnp.mod(direction + jnp.array(period / 2, int), period) + potential_energy = potential_energies[choice_indx] + potential_energy_grad = jax.tree_util.tree_map( + lambda p_energy_grad: p_energy_grad[choice_indx], potential_energies_grad + ) + + momentum = momentum_generator(key_momentum, position) + + augmented_state = integrators.IntegratorState( + position, + momentum, + potential_energy, + potential_energy_grad, + ) + proposal, info = proposal_generator(direction, augmented_state) + + return proposal, info + + return one_step + + +def periodic_orbital_proposal( + bijection: Callable, + kinetic_energy_fn: Callable, + period: int, + step_size: float, +) -> Callable: + """Periodic Orbital algorithm. + + The algorithm builds and orbit and computes the weights for each of its steps + by applying a bijection `period` times, both forwards and backwards depending + on the direction of the initial state. + + Parameters + ---------- + bijection + continuous, differentialble and bijective transformation used to build + the orbit step by step. + kinetic_energy_fn + function that computes the kinetic energy. + period + total steps used to build the orbit. + step_size + size between each step of the orbit. + + Returns + ------- + A kernel that generates a new periodic orbital state and information + about the transition. + + """ + + def generate( + direction: int, init_state: integrators.IntegratorState + ) -> Tuple[PeriodicOrbitalState, PeriodicOrbitalInfo]: + """Generate orbit by applying bijection forwards and backwards on period. + + As described in algorithm 2 of [1]_, each iteration of the periodic orbital + MCMC takes a position and its direction, i.e. its step in the orbit, then + it runs the bijection backwards until it reaches the direction 0 and forwards + until it reaches the direction period-1. For each step it calculates its + weight using the target density, the auxilary variable's density and the + bijection. + + References + ---------- + .. [1]: Kirill Neklyudov and Max Welling "Orbital MCMC." arXiv preprint + arXiv:2010.08047 (2021). + """ + + index_steps = jnp.arange(period) - direction + + def orbit_fn(state, i): + state = jax.lax.cond( + i != 0, + lambda _: bijection(state, jnp.sign(i) * step_size), + lambda _: init_state, + operand=None, + ) + kinetic_energy = kinetic_energy_fn(state.momentum) + weight = -(state.potential_energy + kinetic_energy) + return state, (state, jnp.exp(weight)) + + _, (states, weights) = jax.lax.scan(orbit_fn, init_state, index_steps) + + directions = jnp.where( + index_steps < 0, -(index_steps + 1), index_steps + direction + ) + + state = PeriodicOrbitalState( + states.position, + weights / jnp.sum(weights), + directions, + states.potential_energy, + states.potential_energy_grad, + ) + info = PeriodicOrbitalInfo( + states.momentum, + jnp.mean(weights), + jnp.var(weights), + ) + return state, info + + return generate diff --git a/docs/examples.rst b/docs/examples.rst index e2180a214..d9ff94075 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -11,3 +11,4 @@ Examples examples/use_with_pymc3.ipynb examples/use_with_tfp.ipynb examples/HierarchicalBNN.ipynb + examples/PeriodicOrbitalMCMC.ipynb diff --git a/docs/sampling.rst b/docs/sampling.rst index 53fd63ef7..d18cd708f 100644 --- a/docs/sampling.rst +++ b/docs/sampling.rst @@ -9,6 +9,7 @@ Sampling hmc nuts mala + orbital_hmc rmh tempered_smc adaptive_tempered_smc @@ -61,6 +62,11 @@ NUTS .. autoclass:: blackjax.nuts +Periodic Orbital +~~~~~~~~~~~~~~~~ + +.. autoclass:: blackjax.orbital_hmc + RMH ~~~ diff --git a/examples/PeriodicOrbitalMCMC.ipynb b/examples/PeriodicOrbitalMCMC.ipynb new file mode 100644 index 000000000..02769aefb --- /dev/null +++ b/examples/PeriodicOrbitalMCMC.ipynb @@ -0,0 +1,997 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Periodic Orbital MCMC\n", + "\n", + "Illustrating the usage of Algorithm 2 of [Neklyudov & Welling, (2021)](https://arxiv.org/abs/2010.08047) on the Banana density \n", + "\n", + "$$p(x) = p(x_1, x_2) = N(x_1|0, 8)N(x_2|1/4x_1^2, 1).$$ \n", + "\n", + "Bijection functions $f(x, v)$ used for sampling are the velocity Verlet, McLachlan and Yoshida integrators for the Hamiltonian function\n", + "\n", + "$$ H(x, v) = \\frac{1}{2}\\left(\\frac{x_1^2}{8} + \\left(x_2 - \\frac{1}{4}x_1^2\\right)^2\\right) + \\frac{1}{2}v^Tv. $$\n", + "\n", + "Using any of these integrators amounts to doing vanilla HMC (traditionally done with the velocity Verlet integrator) but sampling various points from an orbit that discretizes the Hamiltonian dynamics to then weigh these samples in order to ensure we target the correct distribution (where in vanilla HMC we would choose a sample from the discretized orbit and perform a Metropolis-Hastings acceptance step on that sample to ensure the target distribution is left invariant).\n", + "\n", + "The benefits of sampling the whole orbit instead of a single point in it are: efficiency, since we build a trajectory around an orbit and use all if it instead of discarding most of it; and wider reach, since even unlikely points will be sampled and given small weights, making the sampler more likely to explore the tails of our target. This at the cost of higher memory consumption since we have `period` samples per iteration, instead of only one, and the lack of diagnostics, theoretical guarantees and heuristic methods developed for traditional HMC and its adaptive mechanisms (such as NUTS) during the past decades.\n", + "\n", + "It is also illustrated the usage of normalizing flows, specifically the Masked Autoregressive flow ([MAF](https://arxiv.org/abs/1705.07057)), as a preconditioning step for the algorithm; using as a bijection function the ellipsis\n", + "\n", + "$$ \n", + "x(t) = x(0) \\cos(t) + v(t) \\sin(t) \\\\\n", + "v(t) = v(0) \\cos(t) - x(t) \\sin(t),\n", + "$$\n", + "\n", + "i.e. the solution of Hamilton's equations for $p(x,v) = N(x|0,I)N(v|0,I)$,\n", + "\n", + "$$\n", + "\\frac{d x}{d t} = v \\\\\n", + "\\frac{d v}{d t} = -x.\n", + "$$\n", + "\n", + "As it is later demonstrated, these dynamics alone fail to capture all the volume of our banana density. They are, however, cheap and easy to use, since these dynamics are both gradient-free (don't require the computation of gradients of our target distribution) and tuning-free (have no tuning parameters); in contrast with the integrators mentioned above, which need to compute gradients at each iteration and require tuning of the discretization step size and number of steps (when used for periodic orbital MCMC, these values are represented by the `step_size` and `period`). Paired with a preconditioning step which transforms our target to approximate $N(x|0,I)$, our cheap and easy dynamics can efficienty sample from the whole volume of our banana density while delegating the expensive gradients and cumbersome tuning to an optimization problem performed pre-sampling." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.scipy.stats as stats\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import blackjax.mcmc.integrators as integrators\n", + "from blackjax import orbital_hmc as orbital" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Python implementation: CPython\n", + "Python version : 3.8.10\n", + "IPython version : 7.28.0\n", + "\n", + "jax : 0.3.7\n", + "jaxlib : 0.3.7\n", + "blackjax: 0.4.0\n", + "\n", + "Compiler : GCC 9.4.0\n", + "OS : Linux\n", + "Release : 5.4.0-107-generic\n", + "Machine : x86_64\n", + "Processor : x86_64\n", + "CPU cores : 8\n", + "Architecture: 64bit\n", + "\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%watermark -d -m -v -p jax,jaxlib,blackjax" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "data": { + "text/plain": [ + "[CpuDevice(id=0)]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.devices()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Useful functions" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_contour(logprob, orbits=None, weights=None):\n", + " \"\"\"Contour plots for density w/ or w/o samples.\"\"\"\n", + " a, b, c, d = -7.5, 7.5, -5, 12.5\n", + " x1 = jnp.linspace(a, b, 1000)\n", + " x2 = jnp.linspace(c, d, 1000)\n", + " y = jax.vmap(\n", + " jax.vmap(lambda x1, x2: jnp.exp(logprob({\"x1\": x1, \"x2\": x2})), (0, None)),\n", + " (None, 0),\n", + " )(x1, x2)\n", + " fig, ax = plt.subplots(1, 2, figsize=(17, 6))\n", + " CS0 = ax[0].contour(x1, x2, y, levels=10, colors=\"k\")\n", + " plt.clabel(CS0, inline=1, fontsize=10)\n", + " CS1 = ax[1].contour(x1, x2, y, levels=10, colors=\"k\")\n", + " plt.clabel(CS1, inline=1, fontsize=10)\n", + " if orbits is not None:\n", + " ax[0].set_title(\"Unweighted samples\")\n", + " ax[0].scatter(orbits[\"x1\"], orbits[\"x2\"], marker=\".\")\n", + " ax[1].set_title(\"Weighted samples\")\n", + " ax[1].scatter(orbits[\"x1\"], orbits[\"x2\"], marker=\".\", alpha=weights)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "def inference_loop(rng_key, kernel, initial_state, num_samples):\n", + " \"\"\"Sequantially draws samples given the kernel of choice.\"\"\"\n", + "\n", + " def one_step(state, rng_key):\n", + " state, _ = kernel(rng_key, state)\n", + " return state, state\n", + "\n", + " keys = jax.random.split(rng_key, num_samples)\n", + " _, states = jax.lax.scan(one_step, initial_state, keys)\n", + "\n", + " return states" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Banana density" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA9gAAAFlCAYAAAAQ4UDzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzddVhUafsH8PtM0CCNqIgidivq2u2KAYpid2B3u3bXGmsHdne7tq69YrdrIxZKCNLM9/eH75yf45kZalDA+3NdXu8658w5z7zCnPt+4n4EAMQYY4wxxhhjjLG0kf3sBjDGGGOMMcYYY1kBJ9iMMcYYY4wxxpgBcILNGGOMMcYYY4wZACfYjDHGGGOMMcaYAXCCzRhjjDHGGGOMGQAn2IwxxhhjjDHGmAEofsZN7e3tkSdPnp9xa8YYYyxVrl279hGAw89uR1bDMQFjjLHMSFdc8FMS7Dx58lBAQMDPuDVjjDGWKoIgvPzZbciKOCZgjDGWGemKC3iKOGOMMcYYY4wxZgCcYDPGGGOMMcYYYwbACTZjjDHGGGOMMWYAnGAzxhhjjDHGGGMGkOwEWxCE1YIgfBAE4e43r80WBOGhIAi3BUHYIwiCdbq0kjHGGGMZCscFjDHGmFRKRrDXElH97147TkTFAJQgosdENMpA7WKMMcZYxraWOC5gjDHGNCQ7wQbwDxGFfPfaMQAJ//vrZSLKZcC2McYYYyyD4riAMcYYkzLkGuwuRHRE10FBEPwEQQgQBCEgODjYgLdljDHGWAakMy7gmIAxxlhWZZAEWxCEP4gogYg26ToHwAoAHgA8HBwcDHFbxhhjjGVAScUFHBMwxhjLqhRpvYAgCJ2IqBER1QaANLeIMcYYY5kWxwWMMcZ+ZWlKsAVBqE9Ew4moOoAowzSJMcYYY5kRxwWMMcZ+dSnZpmsLEV0iooKCILwWBKErES0iIksiOi4Iwk1BEJalUzsZY4wxloFwXMAYY4xJJXsEG0BrLS/7G7AtjDHGGMskOC5gjDHGpAxZRZwxxhhjjDHGGPtlcYLNGGOMMcYYY4wZACfYjDHGGGOMMcaYAXCCzRhjjDHGGGOMGQAn2IwxxhhjjDHGmAFwgs0YY4wxxhhjjBkAJ9iMMcYYY4wxxpgBcILNGGOMMcYYY4wZACfYjDHGGGOMMcaYAXCCzRhjjDHGGGOMGQAn2IwxxhhjjDHGmAFwgs0YY4wxxhhjjBkAJ9iMMcYYY4wxxpgBcILNGGOMMcYYY4wZACfYjDHGGGOMMcaYAXCCzRhjjDHGGGOMGQAn2IwxxhhjjDHGmAFwgs0YY4wxxhhjjBkAJ9iMMcYYY4wxxpgBcILNGGOMMcYYY4wZACfYjDHGGGOMMcaYAXCCzRhjjDHGGGOMGQAn2IwxxhhjjDHGmAFwgs0YY4wxxhhjjBkAJ9iMMcYYY4wxxpgBcILNGGOMMcYYY4wZACfYjDHGGGOMMcaYAXCCzRhjjDHGGGOMGUCyE2xBEFYLgvBBEIS737xmKwjCcUEQ/vvf/9qkTzMZY4wxlpFwXMAYY4xJpWQEey0R1f/utZFEdBJAfiI6+b+/M8YYYyzrW0scFzDGGGMakp1gA/iHiEK+e9mbiNb977/XEVETwzSLMcYYYxkZxwWMMcaYVFrXYDsBePu//35HRE66ThQEwU8QhABBEAKCg4PTeFvGGGOMZUDJigs4JmCMMZZVGazIGQAQEfQcXwHAA4CHg4ODoW7LGGOMsQxIX1zAMQFjjLGsKq0J9ntBEJyJiP73vx/S3iTGGGOMZVIcFzDGGPulpTXB3k9EHf/33x2JaF8ar8cYY4yxzIvjAsYYY7+0lGzTtYWILhFRQUEQXguC0JWIZhBRXUEQ/iOiOv/7O2OMMcayOI4LGGOMMSlFck8E0FrHodoGagtjjDHGMgmOCxhjjDEpgxU5Y4wxxhhjjDHGfmWcYDPGGGOMMcYYYwbACTZjjDHGGGOMMWYAnGAzxhhjjDHGGGMGwAk2Y4wxxhhjjDFmAJxgM8YYY4wxxhhjBsAJNmOMMcYYY4wxZgCcYDPGGGOMMcYYYwbACTZjjDHGGGOMMWYAnGAzxhhjjDHGGGMGwAl2EmJjY+nQoUN069atn90UxhjLEh48eEAHDhygL1++/OymMJZit27doiNHjlBUVNTPbgpjjGV6nz9/pkOHDtHDhw9/dlMMhhPsJHh7e1OjRo2oVKlSVL9+fXr9+vXPbhJjjGVKHz9+JF9fXypSpAh5eXlR9erVSaVS/exmMZZs+/bto1KlSlGDBg0ob968tGPHjp/dJMYYy5QA0KpVq8jV1ZUaNWpExYoVo/Pnz//sZhkEJ9h6vHr1io4ePUrDhg2jOXPm0Pnz56lMmTJ0+fLln900xhjLVO7evUtly5al/fv308SJE2nq1Kl07do1unHjxs9uGmPJtmrVKsqTJw8dPnyY8uTJQy1atKBhw4ZxRxFjjKVAfHw8de/enbp3706lS5emY8eOUbZs2Wj16tU/u2kGwQm2HurR6tq1a9OQIUPo6tWrZGVlRbVr16aTJ0/+5NYxxljmcPXqVapWrRrFx8fThQsXaNy4cVS/fn0iIp4VxDKV169fU/HixcnT05POnz9PvXv3pjlz5lCXLl0oMTHxZzePMcYyvNjYWGrevDn5+/vTmDFj6MSJE1S3bl0qUKBAlokJOMHWQ6FQEBFRXFwcEREVLlyYLly4QG5ubtSoUSM6ffr0z2weY4xleNeuXaO6deuStbU1XbhwgTw8PIjo/79XlUrlz2weYymiUCg0fnYXLVpEEyZMoHXr1lGXLl14JJsxxvSIi4sjX19f2r9/Py1atIgmT55MMplMPJZVYgJOsPXIkSMHEREFBQWJrzk5OdGpU6coX7585OXlRVevXv1ZzWOMsQzt4cOHVL9+fbK2tqYzZ85Q3rx5xWPq71VnZ+ef1TzGUixHjhwaMYEgCDR+/HiaOHEirV+/ngYMGEAAfmILGWMsY1KpVNSpUyc6cOAALVmyhPr06aNxPCgoKMvEBJxg65EjRw4yNTWlR48eabzu4OBAx44dIwcHB2rQoAH9999/P6mFjDGWMb1584Z+//13kslkdOLECcqdO7fGcfX3qru7+89oHmOpkj9/fnry5IlkOvjYsWNpyJAhtGjRIpo2bdpPah1jjGVcgwcPpi1bttCMGTOoV69eGsfCwsLo/fv3lD9//p/UOsPiBFsPmUxGxYoV07pFV44cOejo0aNEROTp6UnBwcE/unmMMZZhLV++nEJCQujIkSNak+hbt25Rnjx5yNLS8ie0jrHUKV68OMXExNDjx481XhcEgWbNmkXt2rWjMWPG0MaNG39SCxljLON59uwZrVy5kgYOHEjDhw+XHL99+zYRff2OzQo4wU6Ch4cHBQQEaC1ekj9/fjpw4AAFBQVRkyZNKCYm5ie0kDHGMp7x48dTQEAAlSlTRuvxK1euULly5X5wqxhLG3UNgStXrkiOyWQy8vf3p5o1a1KXLl3on3/++dHNY4yxDMnNzY1u3LhBc+bMIUEQJMfV36nq79jMjhPsJFSqVIkiIiLozp07Wo//9ttvtH79erp48SJ169aN114xxhh9TTYKFiyo9djr16/p5cuXVLFixR/cKsbSpnDhwmLBPm2MjIxo165d5ObmRj4+PvT06dMf3ELGGMuYChQoQHK5XOuxCxcuUL58+cjR0fEHtyp9cIKdhOrVqxMR6a0Y7uvrS5MnT6ZNmzbRzJkzf1TTGGPsp/n48WOqOxTPnDlDRP///cpYZiGTyahq1ap6YwIbGxs6ePAgASAvLy/6/PnzD2whY4z9eGmJCRITE+ns2bNZKibgBDsJLi4ulD9/fjpx4oTe8/744w9q1aoVjR49mg4ePPiDWscYYz/ewoULafLkyfT582cCkOKH6okTJ8jW1pZKliyZTi1kLP3Url2bnj59Si9evNB5jru7O+3cuZMePXpEbdu25e27GGNZVlpjguvXr1NYWBjVrl07nVr443GCnQz16tWjM2fOUGxsrM5zBEEgf39/Kl26NLVp04YePHjwA1vIGGM/xo0bN2jp0qXUu3dvypYtGwmCoLGeKqkHKwA6duwY1alTR+dUMcYysnr16hERiYVOdalZsyYtWLCADh48SGPHjv0RTWOMsR8qrTEB0f9/l9apUyfd2vmjcYKdDJ6enhQVFUVnz57Ve56ZmRnt3buXTE1NqUmTJhQeHv6DWsgYYz/G7du3qVGjRlSwYEEKCAigAQMG0LBhw2jChAkUGxurtXjJt27evElv374lT0/PH9RixgyrUKFC5OrqSocPH07y3N69e1O3bt1o2rRptHPnzh/QOsYY+3HSGhMQER0+fJg8PDyyzPprIk6wk6VmzZpkamqarKnfLi4utHPnTnr27Bm1a9eOp4UxxrKUvHnz0ocPH4iIaNKkSeTq6kplypShDx8+0Ny5c4lIf4/1wYMHSRAETrBZpiUIAjVs2JBOnDhB0dHRSZ67aNEiqlixInXq1Inu3r37g1rJGGPpL60xQXBwMF2+fJkaNmz4Q9r7o3CCnQxmZmZUp04d2rdvX7KmOlStWpXmzZtHBw8epEmTJv2AFjLG2I9RsWJFCg0NpSpVqlCePHlo8ODB1KJFC/Ly8qKgoCAiIr091vv27aPffvuNnJycflSTGTM4Ly8vioqKopMnTyZ5rrGxMe3cuZMsLS2padOmFBYWlv4NZIyxHyCtMYG6IKS3t/ePavIPwQl2MjVp0oRevXpFN27cSNb5ffr0oY4dO9LEiRO56BljLFO7dOkSbdq0iTZv3kxKpZJWr15N1apVoxUrVtDx48dJLpfTixcv6PXr13qv8+rVK7p27Ro1adLkxzScsXRSs2ZNsrKyoj179iTr/Bw5ctDOnTvpxYsXPLuNMZapGSomICLas2cP5c6dm0qVKpX+Df+BOMFOJi8vL5LL5bRr165knS8IAi1dupTKlClD7dq1oydPnqRzCxljzPDu379PjRo1oqCgIJo/fz7169ePrly5Qj179qS5c+dS3759afDgwbRw4UKaN2+e3mvt3r2biIiaNm36I5rOWLoxMjKiRo0a0b59+yghISFZ76lcuTLNnz+fDh06RJMnT07nFjLGmOEZMiaIiIigY8eOUdOmTZO1VjszEVK7Z5nGRQRhEBF1IyIQ0R0i6gwgRtf5Hh4eCAgISPN9f7Q6depQYGAgPXz4MNk/CC9evKCyZctSjhw56PLly2Rubp7OrWSMMcOZOXMmxcfH05gxYygmJob++usvevLkCdWtW5d8fX3pzZs39OXLFzI1NaVcuXLpvVblypUpMjKSbt269YNab1iCIFwD4PGz25EZpCQuyKwxwZ49e8jHx4eOHTtGdevWTdZ7AFCnTp1ow4YNdPDgQWrQoEE6t5IxxgzHkDHB5s2bqW3btnTu3DmqUqXKD/oEhqUrLkjzCLYgCDmJqD8ReQAoRkRyImqV1utmRL6+vvT48eMUBYd58uShLVu20L1798jPzy/Vm7AzxtjP4O7uTn///Tc9evSITExMaPjw4VS9enVatGgRnTp1inLkyEH58+dP8kH6+vVrunjxIrVo0eIHtZz9LL9KXFC/fn2ysLCg7du3J/s9giDQsmXLqGTJktS2bVt69uxZOraQMcYMy1AxARHR9u3bKUeOHFSpUqUf0PIfy1BTxBVEZCoIgoKIzIjojYGum6E0a9aM5HI5bdu2LUXvq1evHk2ePJk2b95MCxcuTKfWMcaY4TVr1owqV65Mx48fp0ePHhERUdu2balTp060Y8eOZK8lVSchnGD/MrJ8XGBqakpeXl60e/duiouLS9H71MvNfHx8kqxEzhhjGYWhYoLw8HA6cuQI+fr6kkyW9VYsp/kTAQgiojlE9IqI3hJROIBjab1uRmRvb0916tShrVu3pngketSoUeTl5UVDhgyhCxcupFMLGWMs7f755x9avnw5TZ48md6/f08dOnSg27dv0/bt2+nIkSNE9LUyclBQULKXy2zevJnKlClD+fPnT8+mswzgV4oLWrVqRSEhIXT8+PEUvc/NzY02bdpEt2/fpl69evHsNsZYhpUeMYG6Y7JVqyw3uYmIDDNF3IaIvIkoLxHlICJzQRDaaTnPTxCEAEEQAoKDg9N625+mTZs29OLFC7p48WKK3ieTyWjdunWUJ08e8vX1pXfv3qVTCxljLPWCgoKoc+fOFBoaSmFhYVSyZEl6+vQpjRo1igRBoC1btlDFihVp7ty5NHbs2GQ9TB8/fkzXrl2jNm3a/IBPwH625MQFWSUm+P3338nW1pY2bdqU4vc2aNCAxo0bR+vWraPly5enQ+sYYyxt0iMmIPra6e7m5kYVKlRI50/wkwBI0x8i8iUi/2/+3oGIluh7T9myZZFZff78GaampujVq1eq3n/r1i2YmpqiWrVqiIuLM3DrGGMsbZYsWYJ27dqJfz958iSKFy+OiRMniq/duHEDT58+TfY1x40bB0EQ8Pr1a4O29UcjogCk8Zn5K/xJaVyQmWMCAOjRowfMzMwQERGR4vcmJibC09MTSqUSly9fTofWMcZY6qVHTBAUFASZTIYxY8YYtK0/g664wBCT3l8R0W+CIJgJX7stahPRAwNcN0OytLQkb29v2rZtW4rWXKmVKFGCVqxYQf/88w+NGjUqHVrIGGOpV7t2bVIqlfTixQsCQLVq1aJTp07Rvn37aOzYsUREVKpUKXJzc0vW9QDQxo0bqXbt2pQzZ870bDrLOH6puKBdu3YUFRUlbkOXEjKZjDZu3Eg5c+ak5s2bU2YezWeMZT2GjgmIiLZs2UIqlYratZNMeM4yDLEG+woR7SSi6/R1Kw4ZEa1I63Uzsg4dOlBISAgdOnQoVe9v164d9enTh/7880/auXOngVvHGGOp5+joSIIg0J9//kmRkZFE9LX+xObNmykiIiLF17tw4QI9e/aM2rdvb+imsgzqV4sLKleuTG5ubrR+/fpUvd/W1pZ27dpFwcHB1Lp1a0pMTDRwCxljLHUMHRMQEa1fv57Kly9PBQsWNGRTMxSDlG0DMB5AIQDFALQHEGuI62ZUdevWpezZs9O6detSfY25c+fSb7/9Rp07d6YHD7Jsxz5jLJPA/4osWVtb06JFi+jDhw/k6+tLV65coffv39ODBw/o9OnTFBOjdStjndatW0fm5ubk4+OTHs1mGdSvFBcIgkDt27enU6dOUWBgYKquUaZMGVq6dCmdPHmSxowZY+AWMsZYyqRXTHDz5k26ffs2dejQIT2anWFkvbroP4BCoaB27drRoUOHUj2dy8jIiHbs2EGmpqbk4+OT6l4gxhhLi8TERIqOjtYoTGJqakrbtm2jmjVr0vz582ngwIE0efJkWrhwIZmYmCT72lFRUbRt2zZq3rw5WVhYpEfzGcsQOnToQABSPYpNRNS5c2fq3r07zZgxg/bu3Wu4xjHGWDKlZ0xARLR27VoyMjKi1q1bG7rpGYqg7qH4kTw8PBAQEPDD72tI9+7do2LFitHcuXNp0KBBqb7O6dOnqU6dOuTj40Pbt29PdvU9xhgzhH79+tGzZ8/ELbSaNm1KlpaW4vHXr1+TmZkZhYSEkLu7e4quvXHjRmrfvj2dPn2aatSoYeCW/3iCIFwD4PGz25HVZIWYgIioRo0aFBQURI8fP071szwmJoaqVq1Kjx8/pqtXr1KBAgUM3ErGGNMtPWOCuLg4ypkzJ9WoUYN27Nhh6Kb/FLriAh7BTqWiRYtS+fLlafXq1Wnav7JmzZo0Y8YM2rlzJ/35558GbCFjjOk3cuRIevbsGc2ZM4dsbGzozp07NGHCBHr48KF4jrm5Odna2lK+fPlSfP3Vq1eTm5sbVatWzZDNZixD6tKlCz158oTOnTuX6muYmJjQrl27SKlUko+PD3358sWALWSMMd3SOyY4cOAAffz4kbp06WLIZmdInGCnQZcuXeju3bt09erVNF1n6NCh1KxZMxo5ciSdOXPGMI1jjLEkWFhYkJ+fHxUuXJj69OlDzZo1I2tra9q4cSPFxcXR33//TUeOHCEiSvGI3NOnT+n06dPUuXNnksn4UcOyvmbNmpGlpSX5+/un6Tq5c+emLVu20IMHD6hbt25p6sRnjLHkSs+YgIjI39+fcubMSfXq1TN00zMcjnrSoFWrVmRqakqrVq1K03UEQaDVq1dT/vz5qWXLlvT69WsDtZAxxnRzc3OjP/74gy5cuEDGxsb022+/kaenJ92+fZvOnz9P7u7u5Onpmaprr169mmQyGXXq1MmwjWYsgzI3N6fWrVvTjh07KDw8PE3Xqlu3Lk2ePJm2bt1Kf/31l4FayBhjuqVnTBAYGEhHjx6lzp07k1wuN3DLMx5OsNMgW7Zs1LJlS9qyZYtYuj61rKysaPfu3RQVFUW+vr4UG5tlC64yxjKINm3aUN++fWnlypXiloEeHh7k7e1Nhw4dInd3d7KxsUnxdRMSEmjNmjXk6elJuXLlMnSzGcuwunfvTtHR0bR58+Y0X2vkyJHk7e1NQ4YMoX/++ccArWOMMd3SKyYgIlqzZg2pVKpfYno4ESfYada9e3eKjIykrVu3pvlahQsXprVr19Lly5dpwIABBmgdY4zp165dO6pRowYdOnSIevbsSU+fPiV/f39ydnZO9TUPHTpEb9++pe7duxuwpYxlfGXLlqVSpUrRihUr0jy1WyaT0bp16yhfvnzUokULCgoKMlArGWNMu/SICRITE2nVqlVUt25dyps3rwFbm3FxFfE0AkDFixcnU1PTNK/FVhsxYgTNmjWL/P39f5meHsbYz5OQkECPHz+m+fPnU1xcHDk4ONDs2bNTfb0GDRrQrVu36OXLl6RQKAzY0p+Lq4inj6wUExARLV26lHr37k1Xrlyh8uXLp/l69+7dowoVKlDx4sXpzJkzZGxsbIBWMsaYdoaOCQ4dOkSNGjWiHTt2UPPmzQ3Y0p9PV1zACbYB/PXXXzRgwAC6du0alSlTJs3XS0hIoPr169P58+fp/Pnz5OHB8Rxj7MeIj48npVKZ6ve/ePGC3NzcaMyYMTRp0iQDtuzn4wQ7fWS1mODz58/k7OxMrVq1SnPBM7WdO3eSr68v9ezZk5YuXWqQazLGWFLSGhMQEXl5edG///5Lr169IiMjIwO1LGPgbbrSUYcOHcjU1JSWLVtmkOspFAraunUrZc+enXx8fOjDhw8GuS5jjCUlrQ/SlStXkiAI1K1bNwO1iLHMxcrKitq0aUNbtmyhsLAwg1yzefPmNHz4cFq2bJnBknbGGEtKWmOCwMBAOnToEHXp0iXLJdf6cIJtANbW1tSqVSvavHlzmiuHqtnb29Pu3bspODiYWrZsSQkJCQa5LmOMqVQqOnr0qMGvGxcXR/7+/tSgQQPKnTu3wa/PWGbRs2dPio6Opg0bNhjsmlOnTqU6depQ79696d9//zXYdRlj7MqVKxQSEmLw66rrUfj5+Rn82hkZJ9gG0qtXL/ry5YtBH6ZlypShFStW0JkzZ2jYsGEGuy5j7Nc2Y8YMql+/Ph0/ftyg192zZw+9f/+eevfubdDrMpbZlC1blsqVK0dLly412D7W6tltOXLkIB8fH3r//r1BrssY+7U9e/aMPD09DV73KS4ujlatWkUNGjSgPHnyGPTaGR0n2AZSrlw58vDwMOjDlIioffv2NGDAAJo/fz5t3LjRYNdljP2aDh8+TGPGjKE2bdpQnTp1DHrtxYsXU968een333836HUZy4x69+5NDx48oDNnzhjsmnZ2drRnzx4KCQkhX19fio+PN9i1GWO/ni9fvlDTpk2JiOjPP/806LX37t1L7969o169ehn0upkBJ9gG1KdPH7p//75BH6ZERLNnz6YaNWpQ9+7d6dq1awa9NmPs1/H48WNq06YNlSxZUlwrbSh37tyhc+fOUa9evUgm40cLYy1btiRbW1tavHixQa9bqlQp8vf3p3PnztHAgQMNem3G2K8DAHXp0oXu3r1LW7ZsoXz58hn0+upOd09PT4NeNzPgKMiA1A/TRYsWGfS6SqWStm/fTo6OjtS0aVOeFsYYS7HPnz9TkyZNSKFQ0J49e8jMzMyg11+8eDGZmJjw1oKM/Y+pqSl17dqV9u7dS69fvzbotVu3bk1Dhw6lJUuWcNEzxliqzJw5k7Zv307Tp083+MyzO3fu0D///PPLdrr/ep84HZmamlK3bt1o7969FBgYaNBrOzg40J49eyg4OJh8fX0pLi7OoNdnjGVdKpWK2rdvT48fP6YdO3YYfC1UWFgYbdiwgdq0aUN2dnYGvTZjmVmvXr1IpVIZbJeRb02fPp3q1q1LvXr1okuXLhn8+oyxrOvw4cM0evRoatWqVbrUeVq0aNEv3enOCbaBqYv7pMc+lWXKlKHVq1fTuXPnaMCAAQa/PmMsa5owYQLt37+f5s2bRzVr1jT49VevXk1RUVHUt29fg1+bscwsb9681LhxY1qxYgXFxMQY9Nrqome5c+cmHx8fg4+SM8aypkePHlHr1q2pZMmS5O/vb9DlYkREoaGhtGHDBmrbtu0v2+nOCbaBubq6kre3N61YsYKio6MNfv3WrVuLe2GmR484Yyxr2blzJ02ePJk6d+6cLglwYmIiLV68mKpUqUKlS5c2+PUZy+z69+9PwcHBtG3bNoNf29bWlvbt20eRkZHUtGnTdIk7GGNZR1hYGHl5eZGxsTHt3bvX4MvFiIj8/f0pOjqa+vfvb/BrZxacYKeD/v3706dPn2jz5s3pcv1p06aRp6cn9evXj86ePZsu92CMZX43b96kjh07UsWKFWnp0qUG76UmIjp06BA9e/bsl36QMqZPrVq1qEiRIrRgwQKD7jKiVrRoUdq4cSMFBARQ9+7d0+UejLHMLzExkdq0aUPPnj2jnTt3kqurq8HvkZCQQIsWLaLq1atTiRIlDH79zIIT7HSg/qFKr4epXC4Xq/01b96cnj9/bvB7MMYytw8fPpC3tzfZ2NjQ7t27ydjYOF3us2DBAsqVK5e4zQdjTJMgCDRgwAC6ceMGnT9/Pl3u4e3tTVOmTKFNmzbRrFmz0uUejLHMbcSIEXTkyBFatGgRVatWLV3usX//fnr58uUv3+nOCXY6UD9M79y5Q6dPn06Xe2TLlo32799PCQkJ5O3tTREREelyH8ZY5hMbG0vNmjWj4OBg2rdvH2XPnj1d7nP79m06deoU9e3blxQKRbrcg7GsoF27dmRra0vz589Pt3uMHj2aWrZsSaNGjaKDBw+m230YY5nPunXr6M8//6Q+ffpQjx490u0+8+fPpzx58pC3t3e63SMz4AQ7nbRp04YcHBzS9WFaoEAB2r59O92/f5/atWtHKpUq3e7FGMscAFDv3r3p/PnztGbNGipbtmy63Wv+/PlkampK3bt3T7d7MJYVmJmZkZ+fH+3duzfdZp0JgkCrV6+m0qVLU5s2beju3bvpch/GWOZy8eJF8vPzo1q1atG8efPS7T7Xrl2jc+fOUb9+/Ugul6fbfTIDTrDTiYmJCfXq1YsOHjxIjx8/Trf71K1bl+bNm0f79++n0aNHp9t9GGOZw9y5c2n16tU0duxYatmyZbrd5/3797Rp0ybq1KkT2draptt9GMsq+vTpQzKZjP766690u4eZmRnt27ePzM3NqXHjxhQcHJxu92KMZXwvX76kpk2bUu7cuWnHjh2kVCrT7V7z5s0jCwsL6tq1a7rdI7PgBDsd9e7dm5RKZbqOYhMR9e3bl3r06EEzZ86k9evXp+u9GGMZ18GDB2nYsGHUvHlzmjBhQrrea8mSJRQXF8dbBjKWTLly5aIWLVqQv78/hYeHp+t99u3bR+/evSMfHx+KjY1Nt3sxxjKuiIgI8vLyotjYWDpw4EC6doYHBQXRtm3bqGvXrpQtW7Z0u09mwQl2OnJycqK2bdvS2rVr6dOnT+l2H0EQaOHChVSrVi3q3r17uhVRYYxlXHfu3KHWrVtT6dKlad26dSSTpd/Xe3R0NC1ZsoQaN25MBQsWTLf7MJbVDB48mCIiImjVqlXpep/y5cvTmjVr6Pz589SjRw+uLM7YLyYxMZHatm1L9+7do+3bt1OhQoXS9X4LFy4klUrFne7/wwl2Ohs8eDBFR0enaM/qd+/e0YABA6hr1660fft2SkhISPI9SqWSduzYQXny5KGmTZvSs2fP0tJsxlgm8u7dO2rUqBFZWVnR/v37k7WvpUqlov3795Ofnx/17t07RetCN2zYQB8/fqTBgwenpdmM/XLKli1L1atXpwULFlB8fHyy3gOAVq1aRW3btqUpU6bQ69evk/W+Vq1a0fjx42ndunU0c+bMtDSbMZbJDB8+nA4cOEALFiygevXqJes9T58+pXHjxlG7du1o69atyb5XZGQkLV++nHx8fChv3rypbXLWAuCH/ylbtix+JfXr14eTkxNiYmKSdX7VqlVhZGQEe3t7EBEKFSqE06dPJ+u9jx8/ho2NDQoXLozQ0NDUN5oxlilERUWhQoUKMDMzw7Vr15L1nqtXr6J06dIgItjY2MDU1BSFCxeGSqVK8r2JiYkoWLAgypYtm6zzsxIiCsBPeGZm9T+/Wkxw4MABEBE2bdqUrPO3bt0KIoKzszMEQYCxsTHGjBmTrJhCpVKhdevWICLs3LkzrU1njGUCy5cvBxGhX79+yTo/IiICAwYMgFwuh1wuR/bs2UFEOHHiRLLeP3/+fBARLl++nJZmZ0q64gJ+mP4Ax48fBxFh1apVSZ4bEhICIsKUKVOQmJiI3bt3w83NDUSE4cOHIz4+PslrnD59GkqlErVr10ZcXJwhPgJjLANKTEyEr68vBEHAnj17knX+1KlTIZfLkSNHDmzYsAHx8fHiw/jx48dJXmPfvn0gImzZssUAnyBz4QSbYwJDSExMRKFChVC6dOlkdVK1adMGOXPmRGJiIp49e4a2bduCiFCyZEk8evQoyfdHR0ejYsWKMDU1xZUrVwzxERhjGdSxY8cgl8vh6emZrJzh+vXrcHd3hyAI6NmzJ4KCghATEwMzMzMMGDAgyffHx8fD1dUVVapUMUDrM590TbCJyJqIdhLRQyJ6QEQV9Z3/qz1MVSoVSpUqhUKFCiExMVHvuc+ePQMRYc2aNeJrX758QY8ePUBE8PT0TNYDec2aNSAidO3a9ZcbZWLsVzFy5EgQEWbPnp3kueHh4fDy8gIRoWXLlhozXA4ePAgiSlbwXblyZbi6uibrwZ3VcIKdPnHBrxYTAMDKlSuTPULk6ekJDw8PjdcOHDgAOzs7WFpa4saNG0le4/3798iTJw+cnJzw4sWL1DabMZaB3b17F1ZWVihevDjCw8OTPH/Dhg0wNjZGzpw5cebMGY1jrq6u6NChQ5LX2Lx5M4gIe/fuTXW7MzNdcYGh1mAvIKK/ARQiopL/e5iy/xEEgYYNG0YPHz6kgwcP6j3XysqKiIhCQ0PF18zMzGjZsmW0fv166tChAwmCkOQ9O3XqRH/88Qf5+/vz2ivGsqBVq1bRjBkzqEePHjRkyBC95wYGBlLlypXp0KFDtGDBAtqyZQtZW1uLx9XfN0lV/rx06RJduHCBBg0aRAqFIs2fgWVpHBfo0a5dO3JycqJZs2Ylea6VlZVGTEBE1KhRI7p+/Tp17NiRihYtmuQ1HB0d6fDhwxQTE0MNGzaksLCw1DadMZYBvXv3jho0aEDm5uZ06NAhMZ/QBgCNGTOG2rdvTxUrVqQbN25Q9erVNY6HhoYmGRMAoNmzZ1PBggWpcePGBvssWYK2rDslf4goGxE9JyIhue/5FXur1VMoKleurPc8lUoFKysr9O7dO9X3io6ORnR0tMbaq61bt6b6eoyxjOXo0aOQy+X4/fffkxxJvn37NnLkyAErKyscO3ZM6znjx4+HIAiIiorSey1vb2/Y2toiIiIi1W3PzIhHsNMlLvgVYwIAmDZtGogoyRHoUaNGQaFQIDY2NlX3UccEAHDy5EkoFArUrl071ddjjGUskZGR8PDwSFYtlri4OLRv3x5EhG7dumldSvrhwwcQEebOnav3WseOHQMRYeXKlWlqf2amKy4wxAh2XiIKJqI1giDcEARhlSAI5ga4bpaiUChoyJAhdOHCBb3baAmCQKVKlaJr166l+B5RUVG0e/duatCgAbVo0YK2bdtGa9asoapVq1KHDh3o3LlzafkIjLEM4NatW9S8eXMqWrQobd++Xe9I8vnz56lq1aokCAKdP3+e6tatq/W8a9euUaFChcjU1FTntR48eED79u2jPn36kIWFRZo/B8vSOC5Ihl69epGlpWWSs8xKlSpFCQkJdOfOnRRd//uYYOvWrVSrVi1atWoVnTx5kvz8/NQdIoyxTCohIYFat25N169fp23btlGZMmV0nvvlyxfy8vKiDRs20OTJk2nFihWkVCol56lzkFKlSum998yZM8nZ2Znat2+fps+QJWnLulPyh4g8iCiBiCr87+8LiGiylvP8iCiAiAJy5879g/oVMpbIyEjY2dmhUaNGes8bMWIElEolIiMjk33t4OBgzJ07F15eXti1axeuX7+OokWL4uHDh/j06RMKFiwIGxsbPHjwIK0fgzH2k7x69Qo5c+ZEzpw5ERgYqPfcw4cPw9TUFAULFsTLly91npeYmAgbGxt06dJF7/U6d+4MU1NTfPjwIVVtzwqIR7ANFhdwTPDV0KFDIZPJ8PTpU53nvHr1CkSEefPmJfu6+mICAJgwYQKICOPHj0/jJ2CM/SwqlQq9e/cGEWHx4sV6zw0NDUXlypUhk8mwYsUKveeOHj0acrlc72y1gIAAEBFmzpyZqrZnFbriAkM8SLMT0Ytv/l6ViA7pe8+vOh0MACZOnAgiwp07d3Sec/ToURARDh8+nKxrxsTEYP78+fDz88O///4rvl61alVcvXoVAPD06VM4OjoiT548ePv2bdo+BGPshwsNDUWxYsVgZWWFW7du6T13586dUCqVKF26dJIJ8dWrV0FE2Lhxo85zAgMDoVQq0bdv31S1PavgBDt94oJfOSYICgqCkZERevXqpfc8d3d3NGzYMFnXTE5MoFKp0LlzZxAR/P39U/8BGGM/zYwZM0BEGDZsmN7zgoODUbp0aSiVSuzYsSPJ65YvXx4VK1bUe07z5s2RLVu2ZBVTy8rSLcH+em06R0QF//ffE4hotr7zf+WH6adPn2Bubo62bdvqPCcqKgqmpqbJDmaPHj2K33//HRcvXgTwdX3F1q1b0a1bN42q5VevXoWZmRlKly6Nz58/p+2DMMZ+mJiYGNSoUQNKpTLJqsObNm2CXC5HpUqVNCqF6zJx4kQIgoD379/rPGfgwIGQy+W/fPVhTrDTJy74lWMCAOjWrRuMjY31dn737dsXZmZmSdZJAJIfE8TFxaFevXqQy+U4dOhQ2j8IY+yH2bBhA4gIrVq10rtD0du3b1G0aFGYmJgka+Du/fv3EAQBEydO1HnOw4cPIQgCRo0alaq2ZyXpnWCX+t9Ur9tEtJeIbPSd/6s/TAcPHgy5XK53SpiXlxdcXFyS3GIrISEBLVu2FLf1io2NxcmTJzFo0CAsWrQIKpVK4xoHDx6EXC5HvXr1uMAJY5lAYmIiWrZsmeQoM/B1ez5BEFCjRo1kFyIrXbq03p7q4OBgmJmZoX379ilqd1bECXb6xAW/ekzw+PFjyGQyjBgxQuc56plt+/fv13utlMYE4eHhKF26NMzMzDRGuxljGdfx48ehUChQo0YNxMTE6DwvKCgIBQsWhJmZGU6dOpWsa69atQpEhOvXr+s8p3PnzjAxMcG7d+9S3PasRldcYJBtugDcBOABoASAJgBCk37Xr2vIkCEkl8v1bs/h4+NDgYGBdPXqVb3XEgSBTExMKD4+noiItmzZQocPHyYrKyvq0qULCYKgsa1XhQoVaMWKFXTs2DHq2rUrqVQqw3woxpjBAaAhQ4bQtm3baObMmdS2bVud565atYo6d+5MderUoUOHDiWrENmzZ8/oxo0b5OPjo/OcBQsWUFRUFI0cOTJVn4H9mjguSL78+fOTr68vLVmyRLIdl1qNGjXI2tqadu7cqfdaKY0J4uLi6PDhw+Tk5EQNGjSg//77z3AfjDFmcNevX6emTZtS4cKFae/evWRsbKz1vNevX1P16tUpKCiIjh49SjVr1kzW9Xft2kV58uTRWeAsMDCQNmzYQN26dSMnJ6fUfoysT1vWnd5/fvXeagDw8/ODsbExgoKCtB4PCQmBUqnEoEGDkrzW3bt3Ubx4cVSvXh3t2rXDqlWrNEavEhISMGPGDAwaNAjVqlXD6tWrMWXKFBARhgwZYrDPxBgzLPX6qgEDBuidzbJs2TIQETw9PcXteJJDvU3Q8+fPtR4PDw+HtbU1mjZtmtKmZ0nEI9gcE6STmzdvgogwadIkned06tQJVlZWSf6OpyYmePToEezt7ZE3b168efPGYJ+LMWY4T548gaOjI3Lnzo3Xr1/rPO/ly5dwc3ODlZWVuFQkOT5+/AiFQqF3TXe/fv2gUCj0Fk/9leiKC/hh+pM8ffoUcrkcgwcP1nmOl5cXnJ2dkZCQkOT1Pn78iJcvX0rWZyUkJKBx48bo2bMn9u/fj4CAALi6uuKff/5B3759QUSYPXt2mj8PY8yw1qxZk6z1VUuWLAERoWHDhnqnin1PpVKhaNGiqFSpks5z1Al4QEBAitqeVXGCzTFBemrUqJHefebV08R37tyZ5LVSGhOcO3cO//77L8zNzVGyZEmEhYUZ5DMxxgzj7du3cHNzg52dnd4dgV68eIG8efMiW7ZsuHLlSorusXTpUr3Tw9+9ewcTExN07tw5RdfNyjjBzoDat28PMzMznVV+d+zYASLC0aNHk33NefPm4ezZswC+BtA+Pj7o0qULoqKixM3kBwwYgHv37iEhIQEtWrQAEYnrtRhjP9++ffsgl8tRt25dvbUS1Ml148aNU5RcA8D169dBRFiyZInW45GRkbC3t0f9+vVTdN2sjBNsjgnS0+XLl0FEmDVrltbjCQkJcHZ2hpeXV7KvmZKYAPiaxCuVSlSrVi1ZBdUYY+kvLCwMJUuWhJmZGS5fvqzzvBcvXiBPnjywtrZOVU2FihUromjRojpnzA0bNgwymQyPHz9O8bWzKk6wM6D79+/rrcIXExMDGxsbtGrVKtnXDAkJwY0bNwB8HQHr1q2bRsXw48ePo0CBAnj06BEA4MOHD6hbty7kcjn27duX+g/DGDOIM2fOwMTEBOXKldNbqGzx4sWpTq6Br9O8jIyM8OnTJ63H586dCyLC+fPnU3ztrIoTbI4J0ludOnXg5OSkM7kdNmwYFApFsosLpTQmCA8Px5YtWyAIAho3biwm4YyxnyMqKgpVq1aFUqnE33//rfO8b5Nr9XZ8KfHw4UO9HXwfP36Eubk5WrduneJrZ2WcYGdQvr6+sLS01Bnk9u3bF8bGxjqP6zN9+nQsXLhQnF569uxZ5MuXT9wDLzg4GO3bt8f06dNRvnx5GBsbJ7vKIGPM8K5duwZLS0sULlwYwcHBOs9TJ9deXl6p2g0gOjoatra2aNmypc7jzs7OqFmzZoqvnZVxgs0xQXo7c+YMiAgLFizQevz+/fupXtqV3Jhg7ty54ndM+/bt9S5RYYyln7i4ODRs2BCCIGDr1q06z/s2uU7tkq5hw4ZBLpfr3C5wzJgxICLcvXs3VdfPqjjBzqBu3boFIsK4ceO0Hr9x4waICPPnz0/RdePj49GhQwdMmDABAPD3338jV65cWLt2rXhObGwsFi5cCDs7O+zevRtFihSBhYUFb9XB2E/w4MED2NvbI3fu3AgMDNR53rcj16ndam/Tpk0gIhw/flzr8YULF4KIuMPtO5xgc0zwI1SrVg05cuTQWcysUqVKKFCgQJLbeH4rpTHByZMnMXnyZBAR+vXrl6J7McbSLiEhAa1btwYRYenSpTrPS+vINfD1d9/R0RFNmjTRejwkJARWVlZo1qxZqq6flXGCnYE1bdoU2bJlQ2hoqNbj5cuXR+HChVP8gHv37h2KFy+Otm3bonLlyti+fbvG8QcPHmDRokXo1q0b7t27h6CgIOTNmxe2tra4c+dOaj8OYyyFnj9/jpw5c8LJyUnv2qa0jlyrVa1aFfny5dM6MhUTE4OcOXOiSpUqHFR/hxNsjgl+hBMnToCIsGjRIq3H161bByLCyZMnU3TdlMYEKpUKQ4YMARFhzJgxqf48jLGUUalU6NGjB4gI06dP13meIZJrANi6dSuICIcPH9Z6fPz48SAi3Lx5M9X3yKo4wc7A1KPU6p7l761evRpEhNOnT6f42p8/f8a7d+8k5fwfPHiA6dOno2fPnrh27Zr4+tOnT5EjRw5kz56dixgw9gMEBQXBzc0NNjY2uHXrls7zDDFyDQB37tzRu85KfR9do9u/Mk6wOSb4EVQqFSpXroxcuXJpra8QFRUFW1tbNG/ePMXXTmlMoFKp0K1bNxARZs6cmfIPwxhLEZVKhaFDh4KIMHLkSJ3nGSq5BoDq1asjb968WjvdQ0NDkS1bNp2j2786TrAzOG9vb1hbW2sdxY6KioKNjU2qHqbfUgflT58+FR+k367VUKlU2LhxI8aNGwd7e3u4uLjgxYsXabonY0y3Dx8+oHDhwrCwsNC7ncaiRYsMMnINAL169YKxsbHWNd4xMTHIlSsXKlWqxKPXWnCCzTHBj3Ls2DG9Vf6HDh0KuVyudy/cpCQ3Jli8eLE4VXXhwoWpvh9jLGnq0eI+ffrofA4/f/7cYMm1utNdVwfaxIkT9W7d9avTFRfIiGUI48ePp7CwMFqwYIHkmKmpKXXr1o327NlDgYGBqb5H586dqUWLFnT8+HEKDAykDh06UNmyZcXj0dHRZGFhQX///Tc1b96cIiIiqHbt2vTmzZtU35Mxpl1ISAjVrVuXXrx4QQcPHqTy5ctrPW/RokXUt29f8vLyoh07dpCRkVGq7xkWFkbr16+n1q1bk729veS4v78/vX79miZMmECCIKT6PoyxtKlTpw5VqlSJpk2bRrGxsZLjvXr1IpVKRcuWLUv1PZIbE6xbt47c3d3J29ub+vXrR/7+/qm+J2NMt1mzZtHEiROpc+fO9Ndff2l9Dj9//pxq1KhBYWFhdOLECfLw8EjTPRctWkQmJibUtWtXybGwsDCaO3cueXt7U+nSpdN0n1+Otqw7vf9wb7V2TZo00bkW+/nz55DJZHqniyQlJCQE5cqVQ968efHkyRO951apUgUrV66EpaUlChYsqLOqIGMs5UJDQ+Hh4QEjIyMcO3ZM53kLFiwAEcHb2zvNI9cAMGfOHJ090dHR0ciZMycqV67Mo9c6EI9gc0zwAx0/flzvWmwvLy/Y29vrLIaWlJTGBGfOnEH9+vUhCALWr1+fqnsyxrSbN28eiAitW7dGQkKC1nOePXuG3Llzw8bGRmMpR2p9+vQJpqam6Nq1q9bj6tF09VZ/TEpXXMAP0wzk5s2bICKMHTtW63EfHx/Y2NggMjIy1feIiIiAm5sbli1bBgBa11s8f/4clStXxp07d3Du3DmYmZmhSJEieP/+farvyxj7Kjw8HBUqVIBSqcSBAwd0nqd+2DZt2tQgyXV8fDxcXV1RtWpVrcf/+usvEBFOnDiR5ntlVZxgc0zwI6lUKlSpUkVnRfGTJ0+CiLBq1apU3yOlMUFUVBRq1aoFmUyGLVu2pPq+jLH/p14G1qxZM8THx2s958mTJ3BxcYGNjY3BpmtPnz4dRKS1/ou6cnjTpk0Ncq+sihPsTKJZs2awtLTEx48fJcfOnTsHIsLixYvTdI+oqCg8e/YMwNciCYGBgbhx4wYOHTqEHTt2oHXr1hg5ciQ+fPgAADh16hRMTU1RvHhx8TXGWMp9/vwZlSpVgkKhwJ49e3Sepx5pbtasGeLi4gxy7+3bt4OIsHv3bsmxqKgoODs7o1q1ajx6rQcn2BwT/GinTp3SuVWnSqVCiRIlULRo0TT93qY0JoiMjES1atUgl8uxbdu2VN+XMQYsXbo0yRorjx8/Rs6cOWFnZ2ew0eTY2FjkyJEDtWvX1nr8jz/+0Jl8s//HCXYmcefOHQiCoHUquEqlQvny5eHu7q5z+khKLFq0CHK5HE2aNEGbNm3QuHFj9OjRA2vXrsV///2nce6JEydgYmKCEiVKaC2OxBjT7/Pnz6hcuTLkcjl27typ8zx1j3KLFi0MllyrVCpUqFAB+fLl0/rd8eeff4KIcObMGYPcL6viBJtjgp+hRo0acHJywpcvXyTH1Ft2HTlyJM33SUlMEBERgSpVqkAul0u2+2KMJc/y5ctBRGjUqJHWHQMA4OHDh3B2doa9vb1Bk93169fr3JorODgYFhYW8PX1Ndj9sipOsDOR1q1bw8zMDO/evZMcU49C7dq1K833iYmJQePGjTXWXmibHqZ2/PhxmJiY8Eg2YykUHh4uJtc7duzQed6kSZNARGjTpo3OaWKp8c8//+ic/RIREQEHBwfUqVPHYPfLqjjB5pjgZzh//rzOKr/qUahatWql+T4pjQk+f/4sJtk8ks1YyqhHrhs2bKgzub579y6cnJzg6OiIO3fuGOzeSc1+GTp0KGQyGe7fv2+we2ZVnGBnIo8ePYJMJsPAgQMlxxISEuDm5oYKFSoYZCpnQkIC6tSpg0GDBiXr/OPHj8PU1BTFihXT2gHAGNMUFhaGihUrQqFQ6EyuVSqVOB2rQ4cOBpmh8q1GjRrB3t5e6wjY1KlTQUS4dOmSQe+ZFXGCzTHBz/L777/Dzs4O4eHhkmOzZs0CEaV5ux4g5THBt0n25s2b03x/xn4F6jXX+kaub968CXt7ezg7Oxs80T1y5AiICGvWrJEce/PmDUxNTdG+fXuD3jOr4gQ7k+ncuTOMjY0RGBgoObZkyRIQEU6fPm2QeyUmJmL69Ol4+vRpss4/efIkTE1NUbhwYbx588YgbWAsK1JX6VUoFFrXPgNfk+shQ4aAiNCtWze9I0apod7jcuLEiZJjoaGhsLa2RqNGjQx6z6yKE2yOCX6Wq1evgogwYcIEybHw8HBky5YNzZs3N8i9UhoTREREoHr16pDJZFxdnLEkqAuYenl56Uyur169ChsbG+TKlQuPHz82eBuqV6+OnDlzal3z3adPHygUiiR3FmBfcYKdyTx//hxKpRJ+fn6SY1FRUXB0dES9evV+Qsu+Onv2LMzNzeHu7o5Xr179tHYwllEFBwejVKlSMDIywv79+7Wek5iYiD59+oCI0LdvX4Mn1wDQtm1bmJub49OnT5Jj6lFz3oIjeTjB5pjgZ2ratCmsrKy0FkEdNWoUBEHAw4cPf0LLvhY+q127NgRBwIoVK35KGxjL6NQ1Vpo1a6azoNmFCxdgZWWFPHnyiMUHDenixYsgIsydO1dyTF/uwbTjBDsT6tu3LxQKhaS4CADMmDEjRVPCLly4gIULF+LgwYNap4nqe5+uHraLFy/CysoKrq6u3NPF2DfevHmDIkWKwMTERGfxoYSEBHTu3BlEhKFDhya55CMmJgZHjx7FwoULcfLkyWQtEXn69ClkMhmGDBkiOfbu3TuYmZmhZcuWyftQjBNsjgl+qrt370IQBAwdOlRy7P379zAxMUHnzp2Tda0PHz5g5cqVWLNmTYqC+JCQEJ0dclFRUfD09AQRYcGCBcm+JmNZnUqlwtixY8V9rnXVWDl58iTMzc2RP3/+ZA1ePXz4EKtWrcKaNWsQGhqarLY0atQIdnZ2Wrf87dixo87Zs0w7TrAzobdv38LMzAytWrWSHAsPD4e1tXWy9qdbtWoViEj8ky1bNowaNQohISF63xcYGAhjY2PUq1dP597bAQEBsLOzg7OzM+7du5e8D8ZYFvbixQu4u7vD3Nxc5zKO2NhYtGjRAkSE8ePH602WIyIiMGXKFNjb22v8Hk+fPj3JtnTv3h3GxsYICgqSHOvXrx/kcjkePXqU7M/2q+MEm2OCn61Dhw4wMTHB69evJcf69esHhUKBFy9e6L1GcHAwcubMqfF9Uq9ePVy8eDHJ+7do0QKWlpb4559/tB6PiYlB06ZNQUSYMmUKb/vHfnkqlQqDBg0CEaFr1646a6wcPHgQxsbGKFasGN6+fav3msePH0eVKlU0focLFCiQ5ADajRs3QESYNGmS5Ni9e/d0dsgz3TjBzqRGjx6tcwrnuHHjQES4ffu23msUKVIEv/32G968eYNjx47B19cXgiDAzs4O/v7+eh+A/v7+EAQBVapUQVhYmNZz7ty5g+zZs8POzs4gRVYYy6wePnyIXLlywdraWmfRsKioKDRs2BBEhNmzZ+u9XkBAAHLkyCFWGj1w4ADevHmDBg0awNHRUe/v7qtXr6BUKtGrVy/JMfU0sO7du6fsA/7iOMHmmOBne/bsmc4pnIGBgTp/57+lruNy6tQpPHjwAFOmTIGjoyOICG3bttVbwDQwMBAFCxaEqakp/v77b63nxMfHo127diAiDB8+nJNs9stKSEhA165dQUTo37+/zmVgW7duhUKhQNmyZbUuAVFTqVRo1qwZiAguLi6YM2cO/vvvP+zYsQNEpHeXEgBo1qwZrKystI52N23aFJaWlrwVbwpxgp1JhYaGwsbGBvXr15cc+/TpEywtLdGiRQu91zA3N5dUBL1x4waqVq0KIkLt2rXx8uVLne/ftm0bFAoFypQpo3N7rv/++w958uSBpaWlwYqvMZaZBAQEwN7eHo6Ojrh586bWc8LDw1GjRg0IgoBly5Ylec2wsDA0aNAAFy5c0Hh92rRpICJERUXpfK+6UIm20az27dvDxMSEp4GlECfYHBNkBPpmn/j5+cHIyEjv7/bw4cNhbGyskfhGRERgzJgxMDIygq2trd6K4O/fv0fJkiWhVCqxc+dOreckJiaid+/eICJ0797d4DsjMJbRxcTEoHnz5iAijB07VmdH08qVKyEIAqpWrap1l4DvTZgwATNnztRYvhkeHp5kp/3du3dBRPjjjz8kxy5duqRzZJvpxwl2JqbegkNb4jp69GgIgoC7d+/qfL+Dg4PWkarExEQsW7YMFhYWyJYtG7Zs2aLzGocOHYKJiQkKFSqkc13I69evUaRIERgbG2Pv3r1JfzDGsohTp07B0tISrq6uOit+BgcHw8PDAwqFAps2bUrTPtejR4+GXC7XGbQGBgbCyMhI6+/97du3IQgChg0blur7/6o4weaYICN49+4dzM3NtVYNf/78ORQKBfr06aPz/RMmTAARaS2ydP/+ffz2228gIrRp00ZnwB8aGopKlSpBJpPB399f6zkqlUqchde8eXOd9VwYy2o+f/6MOnXq6CwmpjZ79mwQETw9PZOVXOsSFBQEIsLixYt1ntOiRQtYWFhIRshVKhWqV68OR0dHREREpLoNvypOsDOxqKgo5MqVS+ve18HBwbCwsNA7iv3bb7+hWrVqOo8/ffoUFStWFNeH6FrDcfbsWVhZWSF37tw6K5V+/PgRFSpUgEwmw6pVq5Lx6RjL3Hbu3AkjIyMUKVJE67pI4Ot07cKFC8PExAQHDhzA5cuX4enpiSlTpuisMK5Ps2bN4O7urvO4evT6+fPnkmMNGjRAtmzZtFYVZ/pxgs0xQUahXiJ25coVybFu3brpHcXesGEDiAh37tzRejw+Ph6TJk2CXC6Hm5sbAgICtJ4XGRmJevXqJTlyNnfuXBARatWqlaYkgrHM4MOHD/Dw8IBcLsfatWu1nqNSqTBy5EgQEVq0aIFz586lKSY4deoUiAhHjx7VelxdIHHUqFGSY4cPHwYRYeHChSm+L+MEO9Pz9/fXub5CPYqt62HZt29fmJubIy4uTuf14+PjxZ7m4sWL6yx8dP36dTg4OMDe3l7neuuIiAj8/vvvICJMnjyZ11+xLGvJkiUQBAEVK1bUmbA+ePAALi4usLKywtmzZ3Hw4EHkz58fy5Ytw7hx48QHXnK36FKpVMiZM6fO6t+vXr3SOXp95swZEBFmzJiRzE/IvsUJNscEGcXnz5/h4OCAGjVqSJ6x6lHs3r17a33vgwcPQERYuXKl3nucP38eLi4uMDIywuLFi7U+y78t2DhixAidz/v169dDoVCgdOnSSRZwYiyzevr0KfLnzw9TU1OdiXJ8fDy6desGIkKPHj2wb9++NMUEwP8vG9O1ftrX11fr6HVCQgKKFy8ONzc3nduGMf04wc7kEhISULRoUeTPn1+SKKvXYvv4+Gh9765du0BEOqt+fuvvv/+GnZ0dLC0tsXv3bq3nPH78GK6urrCwsMDx48e1nhMXF4f27duDiNCzZ880TYdlLKP5dupjo0aNdM76uHLlCuzs7ODo6Ijr168D+Pog3LBhA4CvlUDbtWuH169fi/UNkuqQUgfHutZw9+jRA0qlUrL2WqVSoXz58siVK5fetdtMN06wOSbISBYuXAgiwqFDhyTH1N8D2uqrqFQqODs7J2uLvo8fP6JBgwYgIrRr107rd11CQgJ69uwJIkLnzp11Pu+PHDkCc3Nz5MmT56ft181Yerl27RqcnJxgY2MjqZuiFhUVhSZNmohroVUqVZpjAgCoVasWihcvrvXYrVu3QEQYM2aM5NjatWtBRNi6dWtyPyb7DifYWcCBAwd0rrEYP348iAjXrl2THAsPD4dSqdS6d6Y2L1++RPny5UFEGDlypNZ1nkFBQShevDiUSqXOtduJiYkYMWIEiAiNGzfWudUXY5lJbGys2HnUvXt3vcGkmZkZ8ubNq7GX/YIFC2BlZYVt27YhV65caNSoEdq3b4/WrVsna/rkjBkzQERaA+dnz57pHLnatm0biAhr1qxJ/odlGjjB5pggI4mNjYW7uzuKFSsmeU6rZ7J069ZN63u7dOkCKyurZK2LTkxMxKRJkyAIAkqUKIEnT55IzlGpVOK0dX2djv/++y8cHBxga2uL8+fPJ+NTMpbxHT58GObm5nB1dcX9+/e1nhMaGoqqVatCEASNfeLTGhOEhIRAoVBgxIgRWo97e3sjW7Zskq15o6Ki4OLignLlyvFM0zTgBDsLUBcicHBwkPzShYWFwcbGBg0aNND6Xk9PT7i6uib7lygmJgZ+fn7i/pjapr+GhoaiWrVqICLMmzdP57UWL14MmUwGDw8PnhrGMrXQ0FDUqlUryeUPa9euhUKhQKlSpXD06FGcPHlS4/iyZcswYMAAzJw5E8DXKZ1+fn7YuHFjkm0oW7YsPDw8tB7r2LGj1j1yY2Nj4ebmhuLFi3M13zTgBJtjgoxm+/btICKthcb69+8PuVyutfDiwYMHQUQ4cOBAsu915MgR2NjYwNraGkeOHNF6jnrZzG+//aZzuyH1NFpjY2Ns27Yt2fdnLCNavnw55HI5SpcujTdv3mg9JzAwEMWKFYNSqcS0adMMGhOsXr1aZz2GK1euiPHK99Sd9bzzT9pwgp1F/PvvvyAijB49WnJs+vTpICKtvcLr168HEeHcuXMput/KlSthZGSEvHnzat1vOzo6Gj4+PiAiDB48WOeakf3798PMzAy5c+fWuVacsYzs2bNnKFy4MJRKJdavX6/1HJVKhalTp4rb382ePRvm5ubw9fWVTBkbPXo05s+fL/7dz88vyer76unhf/75p+TY/fv3IZPJMGTIEMmxefPmgYh0BsUseTjB5pggo1GpVKhQoQJy5MghmSX27t07mJmZoXXr1pL3xcbGwtbWFq1atUrR/Z4+fYoSJUpAEARMnz5dayfjrl27YGxsjAIFCuDZs2dar/Px40dUrlwZRKTzOoxlZImJiRg2bJhYBfzz589az7tz5w5y5coFS0tLsSaSoWIC4Ov08Hz58mn9HapTpw7s7e0lbQsODoaVlRUaNWqUnI/K9Ej3BJuI5ER0g4gOJnUuP0zTpnXr1jAxMZFslxUZGYns2bOjatWqkl+0iIgImJubo0uXLim+36VLl+Ds7AwzMzNs375dcjwhIQF9+/YVqyFGR0drvU5AQACcnZ1haWmJv//+O8XtYOxnuXDhAhwcHGBjY6Oztzc+Ph49evQAEaFt27aIjo7G1KlTsXTpUqxYsQJ9+vTRmNZ9+fJlVKxYEbt27YK/vz/KlCmjs1qv2ogRIyCXy7XOBPHx8YGlpaWkyElISAhsbW1Rt27dlH9wpoET7PSJCzgmSJvz58+DiDBx4kTJsVGjRoGIcPPmTcmxPn36wNjYWDJ1NCmRkZFo2bKl+MzXtvzr3LlzsLGxgZOTk86CqNHR0WjVqpW4dpuLLLHMIjIyEk2bNgURoXfv3jqXip06dQrZsmWDs7Mzrl+/bvCY4Pnz5xAEARMmTJAcO3nypM5twvr27Qu5XI579+6l8JOz7/2IBHswEW3mBDv9PX/+HMbGxujQoYPk2OLFi3UWPenatSvMzMxStU3GmzdvxK28Ro8eLZlmqlKpxP38qlSponNq2KtXr1CyZEnIZDL89ddf3GvNMrwNGzbAyMgI7u7uOgvzREREiIWAhg0bJv5cf/nyBYmJibh58yZGjRqFP/74Q+yAiomJwaZNm9C8eXM0btxY50iPWmxsLBwdHeHt7S05pp4Gpu0hO3jwYAiCoDXAZinDCXb6xAUcE6Sdj48PzM3NJVNUQ0JCYG1trXX52PXr10FEGutBk0ulUmHmzJmQyWQoUaKE1u+vBw8ewNXVFWZmZjorKqtUKowdOxZEhGrVqumsgsxYRhEYGIjSpUtDJpNh3rx5eivnK5VKFC5cWEykDRkTAMCYMWMgCIKkJou6qKmLi4tk0OvRo0dQKBTo2bNnaj4++066JthElIuIThJRLU6wfwx18bDve7fUay1LlCghSYKvXr0KIsJff/2VqnvGxMSIWws0bNgQYWFhknO2bdsGY2Nj5M+fX6Ow07ciIiLg7e0NIoKfnx/3WrMMKSEhQfw9q1Gjhs5Oo9evX6NUqVKQy+Xo1asXatWqhSlTpkimdp06dQp9+/YVp4CpdwNIbkXvLVu2gIhw+PBhjddVKhVq1qwJBwcHyTSwJ0+eQKlUpmrmCpPiBDt94gKOCdLuv//+g1KpRNeuXSXHZs6cCSLCmTNnJMfKly+PQoUKpbqz+8iRI7C2toadnZ1kXSkAvH37Fh4eHpDJZHr32d20aROMjY2RN29eXkbGMqxLly4he/bssLS01DqQBXx9Jk+cOFHc9rZatWrpEhPExsYie/bsaNiwoeTYzp07QURYvXq15JiXlxcsLS3x/v37ZN2H6ZfeCfZOIipLRDV0PUiJyI+IAogoIHfu3D/oY2ddYWFhcHBwQLVq1SQPRnUgvm7dOsn7ypcvj4IFC6Zof71vqVQqLF68GAqFAgUKFMCDBw8k55w/fx52dnaws7PTuTVYYmIiRo4cCSJC1apV+RedZShhYWFo2LChuE+lrj3kr1+/jpw5c8LCwgIzZsxA0aJFsWfPHmzZsgUODg44ceKEeG5UVBR2796NIUOGoH///vj999917p2tTaVKlZAvXz7J7+7hw4d1dpw1a9ZM66gWSx1OsA0XF3BMYHi6ZqtERUUhZ86cqFChgiReWLduHYgIx44dS/V9Hz9+jCJFikAul2P+/PmSe0RGRsLLywtEhP79++sstHj58mU4OzvDwsICe/bsSXV7GEsPa9asgZGREdzc3HD37l2t58TExKBdu3ZigeAiRYqkW0ywadMmrZ3ucXFxKFCgAIoUKSL5XTt16hSICNOmTUvBJ2f6pFuCTUSNiGjJ//5bZ4L97R/urTaMJUuWgIiwa9cujdcTExNRtmxZ5M6dWzI1RP0LefDgwTTd++zZs3BwcICVlZXWKqRPnjxBwYIFoVQqtSb6aps3b4aJiQlcXFx0rtNi7Ee6f/8+ChYsCIVCgSVLlug8b+/evTA3N4eLiwtu3bqFixcvaizb2Lx5M0qWLInnz59rvK9u3brImTNniqr3qqeAf1sABfg6yl68eHG4ublJZoKcPXtW57pMljqcYKdPXMAxgWGo6y3UqlVLkuT6+/uDiLBjxw6N12NiYuDk5ARPT8803fvz58/izLSOHTtKYo+EhAQMGjQIRIQGDRroXKr2+vVrlCtXDkSEcePGpXowgDFDiYuLQ//+/cXipbpms3348AFVqlQRq3ZfuHAh3WIClUoFDw8PFChQQPI7os4Nvl+WkZCQgJIlS8LV1TXZo+QsaemZYE8notdE9IKI3hFRFBFt1PcefpgaRnx8PIoWLQo3NzfJXpbq4gazZs3SeD02Nha5cuVCzZo103z/ly9fokyZMhAEAZMmTZL8koeEhIhbGo0cOVLng/LatWvInTs3jI2NeY9e9lPt3r0blpaWcHBwwNmzZ7Weo157KAgCypUrJ44Onzt3Dh06dMCHDx/Ec4cPH47KlSuLf583bx4KFCiQ4hFlX19fZMuWTTIFfM2aNSAibN26VeP1hIQElClTBrly5dK5Hy1LOU6w0ycu4JjAcBYuXAgikkxHTUhIQNGiReHu7i7pjJs0aRKISOeoXHIlJiZiwoQJICJ4eHhICrECwNKlSyGXy1G0aFGda0yjo6PRqVMncU/t0NDQNLWLsdR69+6duB3twIEDdRYzu3v3LvLmzQsTExPxeZyeMcGZM2dARFi6dKnG658/f4ajo6PWYscrV67UGi+wtEn3Imdf78Ej2D/asWPHQETi3nnfatCgAbJlyybpbVMXI9O2Z15KRUVFidNhmjZtKumVjouLEysre3l56dzGIDg4GLVr1xan5H7fYcBYeoqPjxfXW5cvXx6BgYFaz4uOjkaHDh3E6rlRUVEaD7EmTZpg7NixGu+pWbOmGOymZjTmv//+g0wmw4gRIzRe//LlC3LmzIny5cvrHK3atGlTiu/HdOMEO33iAo4JDCcuLg6FCxeGu7u75Dl66NAhrUXNPn78CDMzM3Ts2NEgbdi7dy8sLS3h6OiotaPyxIkTsLa2hr29vd6OzEWLFkGhUCBfvnxcpJH9cOfPn0eOHDlgamqqdz/qAwcOwNLSEtmzZ8eVK1fSPSYAvsb3Dg4OkpHoMWPGaI3vw8PD4ejoiEqVKnFxYQPjBDsLa9SoESwtLSVb99y9excymQz9+/fXeP3z58+wtrZG06ZNDXJ/lUqFuXPnQi6Xo3DhwpJKyyqVCgsXLhR7rZ88eaL1Ot8mOWXLlk1WBUXG0urt27eoUaNGkp07QUFBKF++PIgI3t7eWLp0Kd69e6fRo62ukr9p0yZx1KVPnz6S/S5Tonv37jA2Npb0cE+ePBlEJKlzEB4eDicnJ1SsWJEfpAbGCTYn2JnB33//rXUGm0qlQq1atWBnZycZFe7fvz8UCgVevHhhkDbcv38fBQoUgEKhwMKFCyXfRY8ePRKX4ixbtkzndS5cuABnZ2eYmJhoLdjEmKGpY1p1586tW7d0njdt2jQIggBXV1dMnz79h8QEN2/eFKehfyswMBCmpqZo2bKl5D3q/bp5Kabh/ZAEO7l/+GFqWI8fP9ZZKdjPzw8KhQKPHj3SeF29LYYhq3WeOnUK9vb2sLKykkxPA772WtvY2MDGxgZHjx7VeZ29e/ciW7ZssLa21nodxgzl5MmTcHJygqmpqd5aARcvXhT3gq9YsSJq166Nrl27ok2bNvjrr780ComcPn0azZo1Q8+ePdGzZ0+923slJTAwEEqlEr169dJ4/e3btzA3N4ePj4/kPcOHDwcR4d9//03VPZlunGBzTJBZNGzYEJaWlnj37p3G6zdu3IAgCBg6dKjG669evYJSqUTv3r0N1oawsDA0btwYRIQOHTpIRttCQ0Ph6ekpdm7q2lHk3bt34nKzTp06ad13mzFDCA0NFfe3btKkic7lCZGRkWjRogWICC4uLqhRo8YPiQmAr0vGLC0tJW3r2LEjjI2NJWu81TsMdOrUKdX3ZLpxgp3FqXunvg+q3717B0tLS3h5eWm8/vHjR1hYWKBFixYGbcfLly9RtmxZEBHGjBkjqWD49OlTFC9eHDKZDDNnztQ5wvb06VPxOgMHDuQp48ygEhISMH78eAiCgEKFCunsaFKpVFi+fDmUSiXc3Nxw/vx5+Pr6isf37duHfv36YdGiRRrve/ToEQ4fPowxY8akqCro9/r06aN1VKlLly5QKpWSrfDUnW38IE0fnGBzTJBZPHr0CEqlEp07d5Yc69y5M5RKpWQ2Wffu3WFkZKRziUxqqNdlC4KAUqVKSWamJSQkiDuKVKxYEUFBQVqvk5CQgLFjx0IQBBQpUoS38mIGd/nyZeTJkwcKhQJ//vmnzvj0yZMnYhw7fvz4HxoT3L17F4IgYPTo0Rqvq7fhHT58uOQ9jRo1goWFhWSWKzMMTrCzOPW00N9++02ypmP69OkgIo3tAQBg9OjREATB4A+q6OhodOnSRdym4Ps14BEREWLPX7NmzXSuy46JiUG/fv1ARChTpoxkFJ6x1Hj16pVYtKRDhw6IiIjQel50dDS6du0KItLYPqNEiRLieqyoqCjs3LkTvXr1wvnz5wF8XYJhiL3dAwMDYWRkhO7du2u8fv36dQiCgMGDB0veo2u5CDMMTrA5JshMhg0bBkEQJB3vb968gbm5uWSZ2PPnz6FQKNCnTx+Dt+XgwYOwtraGjY2NZFshANi2bRvMzc2RPXt2ndt7Al/rzjg6OsLExATLli3jZTAszRITEzFjxgwoFAq4urri0qVLOs89dOiQ+HP8999/A/hxMQEAtGjRAhYWFhpxtUqlQpUqVeDg4CCpg3TkyBGty0WY4XCC/QtQVxReu3atxuvR0dHImzcvihUrprE25OPHj7C0tESzZs2SvPaePXvg5uYGhUIBd3d39OrVC5cuXdL7cFuxYgWMjIyQO3duyQNepVJhzpw5kMvlKFiwIO7du6fzOnv37oWtrS3MzMywatUqfqCyVNu+fTusra1hYWGhd0r4s2fPUKZMGRAR/vjjD42ZGNu3b0f79u3x+PFjAF8T4aFDh2LHjh14+fIlxo0bp7Fe+ubNmxg4cCAKFSoEpVIJFxcXvfdW69mzJ5RKpcZ0L5VKhWrVqsHe3l4yPUxdwGj27NnJ/H+DpRQn2BwTZCbh4eHInj07ypcvL+l4nzJlCogIJ0+e1Hjdz88PRkZGePnypd5rf/jwAY0aNYKxsTFsbGzg6emJ1atX652+/eTJE5QsWRKCIGDcuHGSGW53795F/vz5IZfLMXfuXJ3P+rdv36JevXriNN7g4GC9bWVMl9evX4vLD5o3b46QkBCt5yUkJGDcuHEgIslMjJTGBGFhYViyZAlq164NKysrcd20rsEmtVu3bokxybe2bdsGIsKKFSs0Xo+NjUWBAgVQoEABgyX4TIoT7F9AYmIiKlSoACcnJ4SFhWkc27lzJ4gIixcv1nhd/YVx7do1ndd9//49TExMUKpUKYwYMQLe3t4wMzMDEemsAKp29epV5MiRAwqFAosXL5Y8ME+fPg1HR0eYmZnprdL47Zdg06ZNNbY9YCwp4eHh6NixI4gI5cqVk0yt/tb+/fthbW2NbNmyYd++fZLjHz9+xMSJEzFgwADxgbhw4UIMHDgQAMQpjqdPn0bNmjVBRDA2NoanpyeGDx+OSpUqQSaT6W3D06dPoVAoJGuvt2/frnVrjpiYGLi7u6NgwYL8IE1HnGBzTJDZrFu3DkQkKRAWFRUFV1dXFC9eXKPj/eXLlzAyMkLXrl31XrdDhw4wNjbGgAED4Ofnh3z58om7K+jz5csXcb/sevXqSZ7lYWFhaNKkiTjD7ftYRi0xMRFz5syBUqmEs7OzOJrIWHJt374dNjY2MDMzw8qVK3V26Lx//x5169YVawB8X0sguTHB+/fvMXToUFhYWICIUKRIEfTu3Ru9e/eGTCbDkCFD9LbXy8sL2bJl0+gE+PLlC3Lnzo2SJUtKOqxmzpwJItI6Y4QZDifYv4irV69CEAQMGjRI43WVSoWaNWvC1tZWY/1HWFgYbG1t8fvvv+u85r59+0BEuHjxovja58+fsXbtWr1bDMTGxmLhwoUoXrw4nJycQERo2LChpJcuKCgIVatWBRGhe/fuki8vtcTERMyePRtGRkZwcnLCgQMH9P5/wRjwNdF1dXWFTCbDmDFjEBcXp/W8uLg4sZZB6dKldVa7B76udR44cCAaNmyI58+f4/fff8f06dMBfO1lVo+uODs7Y/bs2Rq/c//99x+ICKtWrdJ5/Xbt2sHExERjPaL6QVqiRAnJg3Tq1KkgIr3FA1nacYLNMUFmk5iYiEqVKsHBwUEy62XHjh1aO9779+8PuVyutxBTrly50LZtW/HvKpUK//zzD27cuKHzPeqYoFSpUihVqhQEQYCTk5OkorJKpcLs2bMhl8uRL18+XL9+Xec1b9y4gSJFioCI0Lt3by6AxpIUEhKCtm3bih3u+pYfnj17Fjly5ICxsbHeJFxfTBAZGYmJEyfC3NwcMpkMbdq0kVTzrlevHkqXLq2zHRcvXgQRYcqUKRqvq/edP3PmjMbrr1+/hrm5uaT+EjM8TrB/IX5+fpDL5ZK11bdv34ZMJpNUCVXvi33q1Cmt11M/hFOyD2ViYiK2b9+OZs2a4cqVK+JDXhAEFChQQLLtQXx8vFjopHjx4rh//77Oa9+6dQslSpQQexN1VXlkv7bIyEj0798fRAR3d3eNDqLvvXjxAhUrVgQRoWfPnoiOjk7y+gkJCejbty+6deuGfv36ISQkROyJtrGxwZw5c7R2FgUFBWkdhVa7desWBEGQFCsZP3681gfpy5cvYWZmprWiODMsTrA5JsiMrl+/DplMhn79+mm8/m3H+7drOt+/fw8LCwu9y8ccHR0l9SH0+T4mAIA2bdrA3t4ecrkcs2bNknTYnz9/Hjlz5oSRkZHWrb7UoqOjMWjQIAiCgHz58uldw81+bYcOHRJnVU6YMEFnh3tCQgKmTJkCmUyG/Pnz6+04+vY938YEKpUK27dvR65cuUBE8PHxwYMHD7S+t0mTJihWrJjWY+qlYU5OTho1Y168eAETExOts0ZatWoFY2NjPH36NMl2s7ThBPsX8vHjR9ja2qJatWqSB1Lfvn0hk8k0viyio6Ph4uKCcuXKaR2Rvnz5MogIe/bsSXYbHj58iD59+mDNmjUAviY7rVu3xuTJk+Hs7AxjY2MsX75c0r4jR47A3t4+yfXWMTExGD16NORyOXLmzImDBw8mu20s6ztz5ow4ZbFv3756RzV27NgBa2trWFlZYevWrZLjW7du1bt8IS4uDmvXroW9vT1kMhn69u2rt0ro+fPnQUQ6Z2DUr18fNjY2GtPAnj17BhMTE637W/r4+MDU1NRg+9cy3TjB5pggs1J3/n2fKNy5cwdyuRw9e/bUeF09Mqar4FPZsmVRp06dZN9fV0ywceNGNGvWTJzh9v166uDgYDRs2BBEBG9vb0nR1G+dOXMGefPmhSAI6Nevn84CluzXExISgk6dOoGIULRoUQQEBOg8NygoSFze1aZNG0nhsMDAQPTr109nch4fH4/Hjx+L08pLlSqVZKdPyZIlUb9+fa3HDh48qHWmSbNmzWBqaopXr15pvH7y5EkQEcaPH6/3nswwOMH+xSxfvhxEhA0bNmi8HhISAnt7e1SpUkUjeV27di2ICFu2bJFcKyoqCkqlEsOGDUv2/VesWIHu3buLa7suX76MIUOG4Pjx4xrrWVq0aCHZpzMoKEhcb92iRQudRSeAr1PiixUrJn4R8trsX1tYWBh69OgBIoKbmxtOnz6t89zIyEh0795dnCb2/ZTwL1++wM/PD0SE2rVra+3sefToEWrUqCFuMZOcWR6zZs0CEWmt9H3ixAmthcq8vb1hbm4u2T5HXSF06tSpSd6XpR0n2BwTZFYhISFwcHBAxYoVJR3p/fv3hyAIGrVYIiIi4OTkJIkV1Hr06AErKyuN9dv66IsJVCoVFi5cCCMjI+TMmVOy1EWlUuHPP/+EUqlEzpw5dc62U7e7b9++ICLkyZOHl8384lQqFXbu3Ins2bNDLpdj9OjRerd93bdvH+zs7GBmZgZ/f3/Jz/6BAwdgZ2cHc3NzyTRv4OsyiMmTJ8PY2BhWVlZYuHChZEnX98LDwyGXyzFmzBjJsfj4eBQpUgT58+fXSOiPHTumdcp4bGwsChUqBDc3N53LLZlhcYL9i0lMTET58uXh5OQkmUK9atUqEJFGJeOEhASULFkSefLk0To9tkaNGihatGiy71+yZEns3bsXwNfevnnz5qFfv37i+uvExEQMHjwYgiDA1NRUMm09ISEB06ZNg0KhQK5cufQ+UGNjYzF+/HgolUrY2tpizZo1XGn8F6NSqbBjxw44OztDJpNh8ODB+PLli87z//33X+TPnx+CIGDEiBGSwmC3b98W1/WNGDFC0lMdFxeHadOmwdjYGNmyZcOyZcv01iP4Vo0aNbROBVP/Drq6umr8Dh4+fBhEJK7nUouOjoa7uzsKFCjA+8T/IJxgc0yQmal3Gvm+/kNoaCgcHR0l23wuW7YMRIRdu3ZJrqVeOvb9khVdkooJgK+7lajrtdSuXVuSvF+7dg0FCxaEIAgYNmyY3u+9c+fOoWDBgiAitG3bFu/fv09WO1nW8erVK3h5eYl1VfSt5Y+MjBQ71EuXLi2Zyh0TEyMuOStZsqTW+gSXL18WB3x8fX01Kofro+93SdvvYGxsLAoWLIh8+fJJ4nX1trw8q/PH4QT7F3Tt2jXIZDLJnpaJiYn47bff4OjoqJF8q0fPZsyYIbnWX3/9BSLSu52WWlxcHLp37y5OwZk+fTr69eunMZoYGhqKJk2aoG3btnB0dAQRoU6dOpKevqtXr4qJ0ODBg/Wujb137x4qV64MIkK1atVw9+7dJNvKMr+nT5+iQYMG4lQsbb3KanFxcZgwYQLkcjly5colGeFWj6QYGxvDyckJx44dk1wjICAAJUuWFKvcJvchCnzdf1Ymk2Hs2LGSY/7+/pJZJNHR0ciXL5/W6uDqKZzHjx9P9v1Z2nCCzTFBZqbeL9fOzk4y1Vpdbfzb5Fs9epYvXz5JMhsREQETExNJ57g2KYkJBgwYIM5gy507t8Y2hcDXREg9S6lEiRKSei7fio6OxtixY6FUKmFjY5OijlCWecXFxWH27NkwNzeHqakpZs2apXemxaVLl+Du7i523Hz/rL137574zO/fv7/kdyEyMhIDBw6EIAjIlStXigvw+vr6wsHBQdLG8PBwODo6SmaRqJPo76uDv3jxAqampmjSpEmK7s/ShhPsX1S/fv0gCIJkH2pdyXejRo1gaWkpmbb97t07yOXyZE8TX79+PRwdHVG/fn20bNlSkvR06tQJ/fr1Q3h4OMLCwsT1LlWqVJEUZ4uMjESvXr3EbQ30JVCJiYlYuXIlbG1toVAoMGTIEMn6GZY1REVFYfz48TAxMYGFhQXmzp2r9yF6//59eHh4iMsJvl968PbtW3h6eoKI0KBBA8mIR3R0NEaMGAG5XA5nZ2fs3r07xW1WTw//vnf88+fPyJ49O3777TeNB6muJPrx48cwNjZGq1atUtwGlnqcYHNMkNnduXMHCoUCXbp00XhdV/KtXoby/bIV4GshJRsbm2QVhUxJTAAAo0aNgpGRESwtLbFixQrJrLT9+/fD0dERSqUSM2bMSPK7X72Up1y5cmKRNZb1nDp1CkWLFgURoVGjRpIOmm/FxMRg1KhRkMlkyJ07t9YO90WLFsHExAQODg5aR4VPnDiBvHnzgojQq1evFMebnz59grGxsaQAIQAMHz4cRKTxu6JOops2bSo5v3HjxjAzM0tyD3tmWJxg/6LCwsLg7OyMMmXKSB5Affv2hSAIGr+8jx49glKp1LoHZpMmTWBvb5/s6aghISG4dOkSEhISNIowHD9+HOXLl9dI4jt37ow6derA0tISpqamKFq0qOTL7O+//0bOnDkhl8vxxx9/6G1HcHAwunXrBkEQkD17dqxZs4Z7rrMI9XRwV1dXEBFatmyJ169f6zw/Pj4eM2fOhLGxMezs7LB9+3bJOXv27IG9vT1MTEy0Vqs9f/48ChQoACJCly5dUlW5PjExEe7u7qhcubLk2IgRI0BEGoGfriRapVKhTp06sLKyStHoOUs7TrA5JsgK1IH794WX1AXPvn/+N2jQAFZWVpKO9+PHj2ut9aJLSmKCLl26oH379uKsNCcnJ0kRyg8fPogF0ipUqKB39xGVSoWNGzcie/bsICJ07txZax0Mljk9f/4cvr6+4tp79XIEXQICAsTp3F26dJHst/7mzRvUr18fRIT69etLflbCwsLQrVs3EBHy58+Ps2fPpqrdc+fO1bpLz3///QcjIyN07NhR43UvLy+tSfSePXt0doSx9MUJ9i9s27ZtICLMmzdP4/WwsDBkz54dZcuW1ZiaPWTIEEniDfz/w1RdBTS5goODMW7cOHF6+ciRI7F48WIxibl8+TLc3NwQHByM58+fo0KFCiAi2NjYYODAgRrXCg0NRceOHcXR7MuXL+u997///ovffvsNRISyZcum+kuQZQxXr14V90wvUaKE3iJmwNeAsVy5ciAiNG3aVBIghoWFiZVFS5cuLVkCERERIc4CyZMnj9Yp48mlrgS6efNmjdfVD9IOHTqIr6lUKtStW1drEr1x40YQERYtWpTqtrDU4QSbY4KsIDIyEq6urihSpIhkOuywYcNARDh//rz42sOHD6FUKiWj3omJiShYsCDKlSuXoronKYkJEhISMHr0aMhkMiiVSslOCiqVClu2bIGtrS2MjIwwdepUndWdga/TbocNGwalUgkLCwtMnTqVi0FlYuHh4Rg1ahSMjY1hamqKiRMn6v33jIqKwsiRIyGXy5EjRw6to9Lbtm2Dra0tTE1NsWjRIq2zJ3LkyAGZTIZhw4al+ucnISEBbm5uWjvdGzduDAsLC43n/969e0FEmDVrlsa5nz9/Rq5cuVCiRAm9P/ssfXCC/QtTqVTw9PSEhYWFpJz/1q1bQURYsGCB+Fp4eDicnJxQoUIFjVFflUqFYsWKoVixYikuIqZSqcStkiZPngx/f3/xWIUKFTB//nzx7/Hx8Zg1axaUSiWMjY2xadMmyfUOHTqEXLlyQRAEDBgwQO92HImJidi4caO4F6GXl5fenm6W8Tx79gxt2rQBEcHBwQHLli3TW5kzOjoaY8aMgUKhgL29PbZu3Sr5mT1+/DhcXFwgk8nwxx9/SALNo0ePwtXV1WBbvtSoUQM5c+bUeACqVCo0bNgQlpaWGg/SzZs3g4iwcOFCjWt8+vQJjo6OKF++fJKVSZnhcYLNMUFWoe7w+74KcUREBHLnzo2iRYtqfCcOHTpUMssGAJYsWQIiSnHndUpiApVKhevXr4tFyzp37qxRHA34uoytefPmYhEqfUvJgK8zhLy9vUFEyJUrF9asWcPfqZlIbGwsFi1aBAcHB7GQ3ffx7fdOnz6N/Pnz65yJ9vHjR7Rs2VJcSvB9IbMPHz6gdevWICIUK1YsyZ+xpKiLm+3YsUPjdXVh028T6YiICLi4uKBYsWKSJHrAgAEQBEHnlnosfXGC/Yt79uwZTE1N4eXlpZFoqFQq/P7777CwsNDYAki9bde3Dz3g/wuhpLSIw7fGjBkDT09PnD17Fv369UP16tXFY98+4MaPHw8jIyMQEVq3bi1ZExseHo7evXtDEAS4uLhg3759eu/75csXTJ06FZaWlpDJZOjcuTPvHZzBvX37Fn379oVSqYSpqSlGjx6d5BqnkydPitO527VrJ9lXNTw8XCySU7BgQcksiE+fPomj2gULFtQYyUmtS5cugYjw559/ary+f/9+ybSukJAQODo6oly5cpKAr2vXrpDL5cnaDowZHifYHBNkJc2bN4exsTEeP36s8br6e2natGnia+Hh4ciePTs8PDw0vpeioqLg4OCgcw/f5EhuTLB48WJky5YNgiDA1dVV6xZcu3fvFneTGDBgQJLPizNnzoiznIoWLYo9e/bwLiQZWEJCAjZu3Ih8+fKBiFC9enVJjaHvBQcHo3PnzuL2ndoKg+7evRtOTk5QKpWYPHmyxpJKlUqFTZs2wd7eHkqlEhMmTJB0yKeUSqVC2bJl4e7urvEz/u3uIN/eY9CgQSAiXLhwQeM6V69ehUwmQ69evdLUHpZ6nGAzscDSzp07NV5/+vQpTE1N4e3tLb6WmJiIypUrw97eHp8+fRJfj4uLQ548eVI8Jex748ePR8OGDbF06VKNkbu3b9/i/PnzGD16NIoUKYJjx45h4sSJUCqVcHBwwNq1a7Fnzx6Na124cEFcS+Pt7Z1k0hwcHIxBgwbB2NgYSqUSvXv3luwvzH6uDx8+YNiwYTA1NYVcLoefn5/eddbA15+ddu3aiQ9RbdO5jxw5Io5aDxkyRGNql0qlwrZt2+Dk5ASFQoFRo0Ylq3hPcnh6esLOzk5jFDwqKgp58uRBkSJFNHqku3XrBrlcjhs3bmhc48yZMyCiFO1HzwyLE2yOCbKSoKAgWFlZoVatWpLnebNmzWBiYoL//vtPfG3Tpk0gIixbtkzjXHVV47QUD0tuTHDq1ClcuHBBHM3u0KGDZJZbWFiY2PmeI0cObN++XW+8olKpsH37drFj1sPDA4cOHeJEOwNJTEzEjh07xAJmJUuWTPLfKDExEf7+/rCzs4NCocCIESMk23e+f/8eLVq0EJeJfd95/fLlS3GXkgoVKkiK8KaWegbJ91vmTZ48GUSkEb8EBARAJpOhZ8+eGufGxcWhVKlScHZ2lqwhZz8OJ9gM8fHxKFWqFLJnzy6ZGjNz5kxJ8n3z5k3I5XL06NFD49yVK1emeRQbgFik7N27d1i/fj26d++O0qVLo3Xr1pgzZw5u374N4OuX5I0bN5AjRw4QEYgIY8aM0bhWXFwcZs6cCTMzM5iammLy5MlJJkevXr1Cjx49oFAoYGRkhF69evGI9k/25s0bDBkyBGZmZpDJZGjXrp1GgKdNXFwc5s+fDysrKxgZGWHMmDGSNVEfPnwQk+/ChQtLplK9evUKjRs3Ftfqf5/cpoV69Pr7faz/+OMPEJHGOnJdSXR0dDQKFCiAvHnz6t3fm6UvTrA5Jshq1Pvsrl69WuN1bcm3SqVCjRo1YGNjozGj7PPnz7Czs8Pvv/+eprakJCaIiIgQ63EQkaQYFPB1LXfp0qXFbUCTWhoWHx8Pf39/5MmTR5wmvG/fPi6Q+hMlJCRgy5Yt4iBKoUKFsHXr1iT/TQICAsT6O1WqVBF/dtRUKhXWrVsnrt2fMmWKRkd3QkICFixYAAsLC5iZmWHevHkGW0KgHr3OkyePxj2fPn0KExMT+Pr6iq99m0Tritu17VPPfhxOsBmA/+8J6969u8br6l/i75PvQYMGQRAEXLx4UePcfPnyoUSJEgZ58AQEBMDd3R1FixaVjFKGhoZi7NixGDJkCEaPHo1WrVpBEARYWFhg4cKFksroL1++FNdh5c2bF7t3706yF/r58+fw8/ODUqmEQqFAhw4dDNZLyZLnv//+Q8+ePWFsbCwm1t9vZaXNsWPHUKRIERARfv/9dzx69EjjuEqlwtq1a2FnZwelUomxY8dqVJ//9iFqamqKOXPm6N3uJaVUKhVq1qwJBwcHjdHr+/fvQ6lUon379uJrUVFRyJ8/v9YkevTo0ZJebfbjcYLNMUFWk5iYiKpVq8LGxkZSKXnp0qWSpWLavruA/58hZ4hCosmNCcaMGYNJkybB1NQURIQmTZpI1uEmJCRg0aJFsLa2hkKhwMCBAyVbNH4vNjYWK1euFLdfKlasGDZs2MAFpH6g6OhoLF++HO7u7mLH+KZNm5JMct+9eyfuIOPo6Ii1a9dK4tT//vsPtWvXBhGhUqVKko6XGzduoHz58mIFcX1bfaXGzp07JQWDVSoV6tevDwsLC42f+RkzZmidefrff//BxMSE97zOADjBZiJ1sZJTp05pvK7eG7tbt27ia7qqE6qLMCV3e46bN29iy5YtOHjwoNbp2KGhoejduzeqVq2KI0eOAPj6BdK2bVv4+PiI59SvXx9Tp05FvXr1xB7mgIAAyfVOnDghTiWqUaMGrl27lmQbAwMDMWDAAJiZmYGI4OnpiWPHjvE0sXSiUqlw7tw5+Pj4QBAEGBkZoXv37njy5EmS771//z4aNmwoTgffu3ev5N/p271PK1WqhLt372ocv3btmrgvdv369fHs2TO993z79i2OHDmCzZs348qVK8n6uVDvIfttEUGVSoXq1avD2tpaYxRo1KhR0Lbn9c2bN6FQKNCpU6ck78fSFyfYHBNkRQ8fPoSRkZHGyBnw/8m3tbW1RvKt7vA7efKk+NqXL1+QI0cO/Pbbb8n6bvz06RP27NmD7du3IyAgQJK8piQmWLp0KWbOnAkTExNYWFhg7ty5ko7SDx8+oHv37hAEAXZ2dli4cGGSCXN8fDzWr18vxhIuLi6YPXt2qrZpZMnz4cMHTJo0CY6OjuKMsp07dyY5mBMdHY0ZM2bA0tISCoUCgwYNkkybjo6OxsSJE2FsbAwrKyssWbJE47oREREYOnQo5HI5HB0dsXnzZr0/yzExMbh8+TK2bt2K/fv3J2sP7Li4OBQsWBCFCxfW+BnVttvP48ePYWJiItnzWt1xb2VlleTSOZb+OMFmoi9fviBfvnzIly+fZKRMvT/miRMnxNfU++vNmDFDfC0xMRFlypRB7ty59W5RoFKpxL0Cv/1TqFAhjB07VjL9d/fu3Rg+fLj496VLlyJHjhxYvHgxJkyYgBYtWojX3bRpE5ycnCCTydC7d29Jr3R8fDwWL14Me3t7seBVcnoiP378iMmTJ4tf8EWKFMGSJUskVUtZ6kRFRWHNmjUoW7asuB3bqFGjkrWn89u3b9GjRw/I5XJYWVlh5syZkv3QIyIiMGLECCiVSlhbW2PZsmUaD9Hw8HD0798fMpkM9vb28Pf31/sQjY+PFxPxb/80bdpU70M/ISEBxYsXh5ubm0axEn9/fxARVqxYIb527do1yOVydO7cWXLvMmXKwMnJSaMWAvs5OMHmmCCrmjp1KohIUuPk4cOHMDY2FpNa4Ot3eL58+ZA/f36NpVirVq3SWhX5e2fPnoWlpaXG96mlpSVatmyJAwcOaCQeyY0JgK/FXD09PaHexvHcuXOSe9+4cQM1a9YEEaFAgQLYtWtXkh0CiYmJOHDgAKpXrw4igrm5OXr16iXptGWpd/XqVXTq1AnGxsbiAMfJkyeT/LdJSEjA+vXrkTt3bhARGjduLKn+DXzt7FaPhrdo0UIj3lCpVNi7dy9cXFxARGjVqlWSseKCBQvEmRPqP3Z2dkkuL1u8eDGISKMob2hoKLJnz44yZcqIP/uJiYmoVq0asmXLhqCgII1rLF++XBJDsJ+HE2ym4fTp0yAiDB48WON19TTVPHnyaExpbdq0qaTgyalTp0BEmDp1qs77/PvvvyAi9O/fH/fu3cOFCxcwd+5c1KpVCzKZDESEunXr4tChQzqTlcDAQFSsWBFmZmbiNBn1l1BoaCj69esnJksrVqyQTCEKCwvDiBEjYGxsDCMjIwwcOFBSkVybmJgYrFmzBmXKlAERwcLCAj169NA6Ys6SdvfuXQwcOBA2NjYaHRfqrVr0CQ0NxR9//AEzMzMoFAr07dsXHz580DhHpVJh8+bNyJkzJ4gInTp10vh3VnfKZM+eHUQkbnfVrFmzJNfejxo1CjNnzsQ///yD+/fvY9y4cSAiHDp0SOd7VqxYASLC9u3bxdfevXsHGxsbVK1aVfx5j4uLQ8mSJZE9e3ZJJ5G6eFBSASv7MTjB5pggq9L3PaSepvrtd9nx48dBRBg9erT4WkJCAooVK4a8efNKOj6/Va5cOeTLlw8XLlzA7du3sXXrVnTr1k3sDM+VKxemTp2Kjx8/an2/vphApVJh165dYrLUtm1bSYKiUqmwf/9+FC5cWCxe9f2MPl2uX7+Ojh07iolglSpVsH79et5LOxU+f/6MFStWiB3Y5ubm6NmzZ7K2UVWpVDhw4ABKlCgBIkKZMmW0/hs+ffpU3I6tQIECkmVWT58+FWfDFShQAIULF0alSpWSjAvOnj2L3r17Y9euXbhz5w5Onz4NJycnNGjQQOd7QkNDYW9vj+rVq2t0HHTv3h0ymUxjpqU6Ef9+J5/AwEBYWVmhZs2aPLsyg+AEm0n06NEDMplMUvDpn3/+Eff+VXv9+rXWX+omTZrA3Nxc8gBTU4/WPX36VHLs9evXmDx5spgQFS1aFOvXr9d4UKqT5UaNGqF58+aSQlFqN27cQJUqVcRKkNrWgb169QpdunSBTCaDubk5Ro0alaxRQZVKhUuXLqFDhw4wMTERK1jOmzcP7969S/L9v7JPnz5h6dKlqFChAogISqUSLVq0wKlTp5L1cAgPD8eUKVNgbW0NIkLLli21Fj0LCAgQ//3LlCkj2crizp074uiDh4cHRowYgYkTJwIAhg0bhsGDB+PMmTMAkKx2ff78GUSEmTNnaj0eFhYGBwcHVKlSReN6vr6+MDIy0lhfPmnSJK0jRw8ePJCMHLGfixNsjgmysuvXr0Mul0sKhsXHx6Ns2bJwcHDQ2PawY8eOUCgUGpWX1Ym3rmc1AJiamko694GvSf7u3btRp04dEBHMzMwwYMAAcVlZSmKCyMhI/PHHHzA2Noa5uTmmTp0qSYLj4+OxcuVKMQapVauW5NmhS3BwMGbNmiWOimbLlg09evTAxYsXOfHRIzExEadOnULHjh1hbm4uxn5//fVXsiphq1QqHDt2TCxgli9fPmzZskUyQPP582eMGjUKRkZGMDc3x/Tp0zU6faKiojBu3DgYGxvDwsICc+bMwaJFi9IUF7Rv3x65c+fWeXzw4MEQBEEjkVYPdg0dOlR87fnz57CwsEDdunUl2+o2aNAAZmZmyVpKx34MTrCZRHh4OFxcXFCoUCFJxe3+/ftLCpaoq41+Oy3lyZMnMDIyQtu2bbXeY+vWrSAiSQXHb8XFxWHDhg1ilUg3Nzf4+/uL66OuXr2KGjVqaLxH25edegQzV65cICL4+PhoTcYePnwoFkuztLTEqFGjJHsl6xIaGorFixeLPa5yuRy///471q5dy+uy/iciIgJbtmyBt7c3lEqlWCTmzz//lIw66xIWFoYpU6bA1tZWnPalbepVYGAgOnToII5Ifz+D4dOnT+IMBysrK8yfPx8JCQlo2bIlJk2aJF5jzpw5GDJkSLI/45s3b0BEWLhwodbjAwcOlDxI1UstpkyZIr5269YtKJVKtG7dWuP9CQkJ+O2332Bra8udOBkIJ9gcE2R16t0Nvp+dc/v2bSiVSrRs2VJ87ePHj3B0dETZsmU1pnV7e3vD3Nxc5/aXdnZ28PPz09uO27dvo0OHDlAoFFAqlejZsydevnwJIPkxAfA1RmnatCmICLlz58amTZskyVh0dDTmzp0rLgurW7cu/vnnH73tU1MnjO3atROnDLu5ueGPP/7A7du3OdnG13+bq1evYujQoWJ8ZmlpiW7duiW7Q0KlUuHvv/9G5cqVxfXwy5cvl6yjj4+Px/Lly+Hk5CQuDfx2nbJKpcKOHTvg6uoq/luri6OmNS7w8fFBoUKFtB67f/8+FAqFRo2jL1++wN3dHW5ubuJyTZVKhdq1a8PCwkIygr5u3ToQEebPn5/sNrH0xwk20+rvv/8GEWHEiBEar0dGRsLNzQ1ubm7iVPHExETUqFEDVlZWGg/OMWPGSJJxtbt374JIugWINomJidi7d6+4NtfNzQ1r1qxBfHy82PP4/ZdpREQE5s+fr9Ez+eXLF0yaNAnm5ubidGJtU8Lv3LmDFi1aQBAEmJmZYeDAgZIKpPrcu3cPo0aNErf0UCqVqF+/PpYvX56s9cRZSXBwMNauXQtvb29xlD9HjhwYPHgwrl27luwg4/379xg9ejSyZcsGIkLDhg3x77//Ss4LDQ3FyJEjYWJiAmNjYwwfPlyj91u9/t7KygpEBAcHB/j6+orTt7Zt24YOHTqIyfjp06fRq1evZAdVBw4cAJHmFltqt2/flmxvFxISguzZs6NkyZLiz7C6cr+jo6Okg2f27NkgImzcuDFZ7WE/BifYHBNkdTExMShatChy5swp6TRW79H77ZKV7du3S2q0PHv2DCYmJhrJ+LeqV6+O5P6bP3/+HD179oSRkRGUSiV69eqFwMBAnTEBAGzcuFGyo8SpU6dQqlQpcRbTtwXa1CIjIzF79mwx0a5WrRoOHz6c7OdXeHg41qxZg7p164pL4AoUKICRI0fi0qVLv9R2X/Hx8Th79iwGDRqkESM1atQIW7ZsSfZ2kwkJCdi5c6c4qOHi4oLFixdLliCop/2rdxWpXLmyZF/2GzduiIVPLS0tUbBgQbRr184gcYFKpULu3LnRvHlzrcdq1aoFa2trjUGGIUOGgL4rOKyu3P/9XvNBQUGwtrZG5cqVDbZdGDMMTrCZTt26ddM7Vbx3797ia0+ePIGZmRkaNGggPnS+fPkCV1dXFC1aVPKwS0xMhLOzM5o1a5bs9qjX1qjXPhcoUACbN2/W+nBav349iL5uybVjxw6NB+GbN2/g5+cHuVwOCwsLjBs3TusUpPv376NDhw6Qy+VQKBRo3769xpS35LT38uXLGDJkiLith7r65ZgxY/DPP/9kue09EhIScPnyZUyaNAkVK1aEIAji2rm+ffvi7NmzKQomHj58iB49esDY2BiCIKB58+a4fv265LwvX75g5syZsLGxgSAIaNu2rUYxEpVKhcOHD4sP2fz582PChAnicTs7O3z8+BE3b95Ejx49sHfvXgBflw8MGTJEUsFbl549e8LMzEwy8yMxMRFVqlQR76Om/vn6dkR7/PjxeqeGe3t78+hHBsMJNscEv4KrV69CLpdLdi6Ii4tD2bJlYW9vL3Zaq1Qq+Pj4wNjYWGPt7MSJE0GkfWvBSZMmQRCEFHVEv3z5Ej179oRSqYSxsTH69+8v2VYM+PqMcHR0hEKhwIABAzS+hxMTE7Fu3TpxfXa9evVw9epVrddYsGCBONpavHhxrF27Vu+68u+9e/cOS5YsQe3atSGXy8WO3g4dOmDTpk3JqgOT2bx+/Rpr165Fq1atxForRkZGaNCgAfz9/VNUqPPLly9YunQp8ufPDyKCu7s7Vq5cqVEwVO3s2bPiyHaBAgWwc+dOSSzYtWtXCIIAW1tbtGrVSmMmpiHigjt37mhNjIH/33VnyZIl4msXL16EIAgaHfFPnz6Fubm51qnhDRs2hKmpKR4/fpxkW9iPxQk200k9VbxgwYKSNUqDBg0CkebWQQsWLJCMSu/fv1/nuqu+ffvC2Ng4yb0nv6dSqbBnzx4UL15cnGasbV/rY8eOidPLK1asiPPnz2scf/jwIZo1awYigq2tLaZPn65RwE3txYsX6N+/v7guqEaNGti9e3eK9kVWqVS4ffs2pk6disqVK4u92Obm5qhfvz6mT5+Oc+fOSRKzjC42NhaXL1/GnDlz0LhxY3GEWRAEeHh4YPz48QgICEhRQpiYmIgjR46IVV+NjY3RvXt3ycgD8HUK319//SUWKKtfv75kyvj169fFEQoXFxfs2bNHo6d30qRJaNasmRgELlmyBO3btxc7Aho2bCgmu/o+R3R0NGxtbbWOzqxcuVJSmEQ92j1mzBjxtYCAAMjlcrRr107j/fHx8ahQoQJsbW21Bo/s5+IEm2OCX4V6K64DBw5ovH737l0YGRmhadOm4vfku3fvYGtri/Lly4vPy+joaOTPnx/u7u6SuOLBgwcgIsyePTvF7Xr+/Dm6dOkCuVwOMzMzjBgxQlIM7d27d/Dz84NMJoO1tTVmzZql8cyNjo7GnDlzYGdnB6Kv+2ffunVLcq/Y2FisXbtWjC+yZ8+OiRMnpvi7OSQkBJs2bULr1q3Fe6oT9379+mH79u0669hkVCqVCi9evMCmTZvQo0cPFCxYUPxcTk5O6NixI3bs2JHi3VdevnyJkSNHisvDPDw8sG3bNq2jtpcvXxa3bHV2dsbSpUs1BjM+f/4srrOWyWTo3LkzQkJCNDr/DRUXDB8+HHK5XLKkKyQkBE5OTihXrpz4GaKiolCwYEHkzp1b3NorISEB1apVg5WVlWQm5erVq3lqeAbGCTbTS12YZNCgQRqvR0VFoVChQsiVK5c4XUy9fYCVlZW4Jgr4uv7k+0rjwNdpOUSEP//8M1VtS0xMxJYtW1CgQAFxZPjIkSMaX3YJCQlYuXIlnJ2dQUTw9vaWbKEREBCABg0agIhgb2+P6dOna923MCQkBDNnzhS3fXBxccHkyZNTNe07JCQEu3btQu/evcWKpeqpUh4eHujZsydWrlyJf//9N9lTptJbdHQ0rl+/jjVr1qBfv36oWLGiWDFV3ZPcrVs3bNmyJdlrqr8VHByMOXPmiMVhnJycMGHCBK09+lFRUViwYAFy5MgBIkL16tUl07WePn2KNm3aiGvi69atC29vbyxduhTA1ymPmzdvRo0aNbB06VKULFkSp06dwtu3b9G4cWP069cPy5YtQ8WKFbVOR/+eetbE9yMz6grh1apVE382P378iOzZs6NEiRJiz3tUVBQKFy6MnDlzSjqdpk2bBiLCli1bkv9/KPthOMHmmOBXERMTg+LFiyN79uySBHbWrFkgIqxbt058bcuWLSAiTJs2TXztxIkTINKsNK5WqVIluLu7p3ra9OPHj9GmTRuxlsqECRMkz/M7d+6IHbguLi5YvXq1RqIWHh6OCRMmiEuJmjVrprXWh3r9r/pa6rXoZ86cSfEso4SEBPz777+YOnUq6tSpAzMzM/HZ6uLigmbNmmHatGk4cuQIgoKCMsQsJpVKhZcvX+LAgQOYNGkSvL29xVhLPd3a09MTs2fPxo0bN1L8b6rubPf29oZMJoNMJoOPjw/++ecfrZ//4sWLqF+/vhjLzZ49WyN+io2NxcKFC8Wp/jY2NujVqxeaNm2aLnFBdHQ0HBwc4O3tLTnWvXt3yOVyjRl52gau5syZAyLCmjVrNN7/8uVLWFlZoVq1ar/UEoPMhBNslqTevXtDEASxaqLav//+C7lcrlHITD2VpU6dOuIvfVBQEKysrFCrVi3Jl2LVqlWRO3fuNE2Vjo+Px5o1a8T1PJUrV5aso4qMjMSUKVNgZWUFmUyGDh064NmzZxrnXLp0SfxytrGxwbhx47QWOYuPj8euXbtQu3ZtMXnz8vLCvn37Uv05Pnz4gL1792LEiBGoWbOm+GBXjwbnzZsXnp6eGDBgABYsWIB9+/bhxo0b+PDhg8G+XFUqFT5+/Ihbt27h4MGDWLRoEQYPHoxGjRohf/784qi7euS9SpUqGDJkCHbs2JHqteUJCQk4evQoWrZsCSMjI/Hfb9OmTVqnfIWHh2PGjBlioZJq1apJ9sQMCgpCr169IJfLYWpqCg8PD+zevRvA1+UNLi4uYu/5tyMoCxcuRMOGDcVrLF26FL6+vsmqHqtSqVCyZEkULlxY8jPeokULSYXwli1bQqlUagRtAwcOBBHh6NGjGu+/efMmlEolfH19k2wH+zk4weaY4Fdy48YNSWEz4Ov3eZUqVTQ62VUqFXx9faFUKjVGg7VVGgf+vwCq+js7te7cuQMfHx9xhtq0adMkM9ROnTolruEtXLgwdu7cqfE8DQkJwdixY8XnccOGDXU+Dx49eoQBAwaIO1sUKFAAM2bMSPUIdFxcHK5cuYJ58+ahZcuWcHNzE5+/RARra2tUrFgRHTt2xOTJk7Fx40acO3cOz58/T9GU9aRERUXh6dOnOHv2LNatW4eJEyeiXbt2KFeunMae5YIgiGuXFy5ciOvXr6dolt+3nj9/jgkTJogFxxwcHDBy5Eit22OpVCocP35c3MPc3t4eM2bM0Pi3TkhIwLp168SletWrV0fLli3FZ216xQXqfam/j0e1VQg/c+YMBEFAr169xNfu3LkDIyMjybKwxMREseCZtp14WMaQbgk2EbkQ0Wkiuk9E94hoQFLv4YdpxhQZGQl3d3e4urpKeoInTJgAIsK2bdvE19RfKn/99ZfktW/XtwDAwYMHJVNnUys2NhZLliwRt9aoUaOGZFQzODgYQ4YMgYmJCRQKBfz8/CRf2v/++6+4P6KZmRn69u2r80vs8ePHGD58uJjwOTg4oF+/frh06VKaepgTExPx5MkT7Nq1CxMnTkTLli1RsmRJcZr6t38UCgWcnZ1RrFgxVK1aFQ0bNkSLFi3QsWNH+Pn5oXfv3ujbty/69OkDPz8/dOrUCS1btkSjRo1QrVo1FC9eHDly5BAre3/7x9TUFMWKFUOzZs0wduxYbNu2DQ8fPkxTMQ2VSoUbN25g6NCh4gi0ra0t+vfvjzt37mh9z+vXrzF8+HAx0Klbt66kw+ft27fo378/TExMIJfL0bNnTzx48AAdOnTAxYsXxYd9s2bNxIrd3wZTjx49QsOGDVO1b6l6KcT3vczaKoSrR3S+fU09otOnTx+N9387WpTcivbsx+MEO33iAo4JMq6pU6eCiLB582aN1589ewYLCwvUqFFD/H4NDg6Gk5MTSpQoISZ/uiqNx8fHI1++fChbtqxBRmm/naHm4OCAOXPmaIxqqlQq7Ny5E4UKFQIRoVSpUti3b5/GvUNDQzFp0iRxGnflypWxd+9erc/BL1++YO3ateL2kDKZDPXq1cP69etTPC36e6GhoThz5gwWLlyInj17onr16uIz9Ps/1tbWcHd3R/ny5VG3bl34+PigXbt26Nq1K3r16oU+ffqgT58+6NmzJ7p27Yq2bduiadOmqFOnjrgf+bcd/d8m0i4uLqhduzb69u2LJUuW4Pz582n+bJ8+fcLKlSvFbTMFQUCdOnWwdetWrZ3tcXFx2Lx5s1iTx9nZGX/++SciIyPFcxISEjRmOZYqVQpHjhxBaGhouscFcXFxyJs3L8qVK6fxs/Tlyxfky5dPo0J4eHg48uTJA3d3d7H9MTExKFmyJBwdHSWz+P766y8QEZYvX56iNrEfKz0TbGciKvO//7YkosdEVETfe/hhmnFdvHgRMplMUtxEvTbUxsZGY09KT09PmJiYiOtXvq00/u06EpVKBQ8PD7i6uhqs1zU6Ohrz588Xk97atWvj3LlzGucEBQWhT58+UCqVUCqV8PPzk4xo37t3D506dYJSqYRMJkPTpk11Tv2Ki4vD/v370bx5c3HadJ48eTBs2DBcuXLFoKPM79+/x5UrV7Bjxw4sWLAAo0aNQteuXdGkSRPUqFEDZcqUQYECBeDi4gInJyfY2dnB1tYWtra2cHJygouLCwoUKIDSpUujWrVq8Pb2RpcuXTBixAjMmzcP27Ztw6X/Y+8qw6rI3/Zzku6WFgQJA1ERUBERDAzsRNfGDsy1A7u7O9fADsRERewAuzBQRFHpOnO/H3jnt4wHXHV1V/d/7uuaD8yZM3WGeep+7ic2FklJSd/1vK9fv47Ro0ezniypVIpGjRphx44dJf72cXFxaNeuHaRSKcRiMVq1aoXLly8Ltnn58iXq168PsVjMBM6KJkR69uyJMWPGsL8vXboEKysrAIVJmaysLCxcuBAVKlT4JoOlUChQoUIFODg4CBgM7969U1IIf/78OfT19VGtWjVm2N+9ewdLS0s4OzsrtQMMHToURIQDBw589Xmp8M9BFWD/GL9A5RP8vMjPz0e1atWgr6+vNHaL7w2dMWMGW8drTgwdOpSt27FjhxJ9HADWrl37XarYRREbG8vmaJuZmWH27NmC9y1f5XRwcAARwcPDA5GRkQIbmJGRgfnz57PKqqOjIxYsWFBicHn//n2MGjWKsevU1dXRrFkzbNmypdg2tG9FZmYmbt++jSNHjmDVqlWYNGkS+vbtizZt2iAoKAheXl5wc3ODvb09LCwsYGxszHwCExMTlCpVCvb29nB3d0e1atVQr149tG3bFv3790dERATWrFmDqKgo3Lt377vqxLx79w5r165FcHAwS/A7OTlh0qRJxVar+e9Mnz6dCdI5Oztj5cqVAh8iPz8fmzZtYskHExMTbNq0SfBb/mi/gB9f+6nt5mngRRXCf/vtN4jFYkFVnLf9+/btE3z/zp07UFdXFwgKq/Bz4h+jiBPRXiIK/Nw2KmP6c4Ofg/mp0Xvw4AG0tLRQu3Zt9gJ79eoVjI2NUalSJZZ9fPToETQ1NREUFCR4MURFRYGIMGfOnO96vpmZmZg9ezbrt6lduzZOnjwpOPazZ8/Qq1cvyOVySCQSdOzYEQkJCYL9vHz5EiNHjmTiGuXKlcPSpUtLNKofPnzA2rVrUa9ePUilUhAVjqYKCwvD/v37f5qe6h+NnJwcREVFoX///szBEIvF8Pf3x7Jly5T693hkZWVh3bp1qFq1KuvjGjhwoFIC5NGjRwgLC2P3OCAgAA0aNMCQIUMEYzji4+NRrlw5PH/+nAW6/v7+iIqKwtu3bzFgwAA0bdq02B67L8HGjRtBRNi8ebNgPZ8Y4PerUChQu3ZtaGlpMT0CjuPQokULSKVSpcQBTxn7q7mwKvz7UAXYP8YvUPkEPzfu378PTU1NQUsY8KeCuEwmE/SY9uzZEyKRSDDGsEWLFpDL5QL2Un5+PsqWLQtnZ+fvPmkjJiaGtXeZmppi5syZAjox33LGB9ru7u7YtGmTUpV9+/btqFatGrNRffr0UdJ34cFxHM6ePYu+ffsyQU6ZTIa6deti0aJFSrbtvwqO43D37l3MmTMH/v7+TEXdxsYG4eHhnxVEvXTpErp06cJmivv7+2Pfvn2C5y4nJwcrVqxgOi4aGhoYMWIEWrZs+Y/6Benp6bCwsICvr6/gemJiYpRo4Dt37gSRUOz0+PHjSkriQGHg7+npCSMjo/+5ka+/Iv6RAJuI7IjoGRHpFvNZDyK6TESXbWxs/pmrVuGbUPSf+9O+olWrVoGIMH36dLaOp8cWnaW9ZMkSEBETlOARFBQEfX39H0KDzczMxJw5c1igbW5ujoULFwpefC9evMDAgQOZsEjjxo2VhDQyMzOxatUqpkitra2N7t27Iy4urkSjkJqaivXr16NZs2aM3q2mpoY6depg2rRpuHjx4jf3Kf1sUCgUuH79OubMmYP69euze6muro7g4GCsXLnysyNIbty4gf79+7MetrJly2LhwoVKiYzLly+jdevWEIvFkMvlqFq1Ktq1awegUPF91qxZGDJkiOA3CQsLw9ChQ3H//n0UFBSgffv27Bn+dKbr1yAjIwNWVlbw9PQUGHreaE6YMIGt48VKVq5cydbxlZ5PVfbfv38PGxsbODo6Fqtsr8LPBVWA/f38ApVP8GuBr9TNnTtXsP7t27coVaoUypYty5LKGRkZcHJygpWVFRNyfPPmDUxMTODh4SGgAu/du1ep1ex7IiYmhs0+VlNTQ69evQS2ID8/Hxs3bmSjHe3s7LBgwQKl93FcXBxCQ0MFGiLr1q0TUJWLQqFQICYmBuHh4SwQ5Cuxffv2RWRk5FdPVvmZ8ebNG2zfvh09evQQjCt1c3PDyJEjcfHixRL9pw8fPmDp0qWMBq6pqYkePXrg5s2bgu3ev3+PadOmMYE1T09PdOjQgQWo/7RfMGbMGBARzp8/z9alp6ejdOnSsLe3Z8/QixcvYGBggCpVqrAg/+3bt4zR9ukzxCv479q165vPTYV/Dj88wCYibSK6QkTN/mpbVbb658edO3egoaGBoKAgpYx18+bNIZVKBfMju3fvDpFIxEQeFAoFAgMDoampKZjbFx8fD4lE8kXVOoVCgXXr1iEkJAQ1atRAo0aNEB4erqRS/ik6deqEKlWqsL4iJycnpTEPKSkpGDduHOu1qlKlCrZs2SIw/BzHITY2Fr/99hvLprq6umLatGlKVLmi4Cu6gwYNgpubGzM0Ojo6qFu3LiZOnIioqKhfxrh+/PgRJ06cwJQpU9CwYUM235K/t3369MGBAwc+W7F//fo15s2bBw8PDxAVzsZs06aNEtOgoKAAkZGRqFmzJogIurq6GDp0KF6+fIkzZ84gODiY/UanT59GWFiYQCzs5cuXmDFjBurWrQsXFxeEhoYiKysL7969Q2RkJIYPH44mTZqw52nJkiVflPjgWR1FWxBevXoFIyMjVK5cmRnNK1euQCaTCUbY3Lt3D1paWqhVq5ZSL1+7du0gkUhw4cKFL/glVPi3oQqwf4xfoPIJfn5wHIdGjRpBTU1NKfCJjo5WYuFcunQJUqkULVq0YO9CPhk/atQowX7r1KkDfX39L5pKkZCQgLCwMNSqVQsBAQHo3LnzX05d6Nq1K+rVq8dmKvPjvYqO21IoFNi7dy98fHxAVCiAOnz4cKWRSSkpKZg5cybr99XW1kbnzp1x4sSJz7Za3bt3D3PnzkW9evWYPyESiVC+fHn06tULGzduxL17934JpeiCggIkJCRgzZo16N69u2BCio6ODho3bozFixfjyZMnJe4jPz8fhw8fRtu2baGurg4iQvny5bFw4UJ8+PBBsO2DBw/Qv39/aGtrg4hQp04dHDt2DBzHfbNfkJ6ejri4OMyePRsdO3ZE7dq14e/vj/79+3+RoNijR4+gpqaGtm3bCtb36NEDIpGI6QIVFBTA399f4AtzHIemTZtCJpPhypUrgu+fOXMGIpEIXbp0+ctzUOHnwA8NsIlIRkRHiWjwl2yvMqa/BpYuXVospfvdu3ewtraGo6MjqzpmZGSgbNmysLCwYNVpvg/Vy8tLQP8aNGgQRCLRXwYVI0eOBBHBwcEBfn5+KFeuXLHGvShOnDiBRo0a4fnz58jNzUXVqlVZj7a9vT0WLVokCAQzMzOxePFiZngtLCwwYcIEJVrOx48fsXz5cmZ8RSIR/Pz8sHTp0s9Wa4HCQGzr1q0ICwuDu7s7RCIRM0alS5dG8+bNMWHCBOzatQu3b98uVujjn0BeXh7u3buHvXv3IiIiAq1btxbMtuSz7127dsWGDRuUHI9PkZqaijVr1iAoKIhRxCpVqoQFCxYo0cZTUlIwY8YMRjG3sbHBrFmzBIY2MTERYWFhLKv7+vVrTJs2jfVNcRzHnJPz58/j1KlTmD17NqpXr86U0WUyGdzc3ODn58fEbrp16/bZ67h79y7kcrlgZjXHcahXrx7U1dWZanh6ejqcnJxgaWnJri8nJweVKlWCoaGhUlKGp5xPnDjxs8dX4eeBKsD+MX6Byif4NfDmzRuYmZnBzc1NSQxq2LBhICLs3LmTrZs+fTqIhKKnfB/q2bNn2brbt29DKpWiU6dOnz3+3bt3oaWlBW1tbfj6+sLb2xtmZmZKFNuiKOoTAEDdunVRtWpViMViyGQydO3aVTD5ASi0H82bN4dYLIZEIkGLFi2UksF8cNelSxcW+FlaWmLw4MGfZbsBhXbhzJkzmDhxIgIDAwUq3bq6uvDz80P//v2xcuVKnDt3Du/evfvsfflR4DgOb968wZkzZ7B06VL07t0bvr6+AhFWfX19NGjQAFOnTkVsbOxnqf4FBQU4c+YM+vTpozRC69MKt0KhwKFDhxAcHAyRSASZTIYOHTooUbm/xi84deoUlixZgo4dO7JWQL61z8fHBz4+PlBXV4ehoeFnVeE5jkODBg2gpaWFFy9esPW8CGpR/QFeJHDNmjVsHe9bz5o1S7BfntHm4OCgYrT9QvhhATYRiYhoAxHN+9LvqIzprwGO49C4cWPI5XKlERtnzpyBWCxG+/bt2Uvx2rVrkMvlCA4OZuu2b98OIhKITHz8+BGWlpYoX758iS9jhUIBbW1ttG7dWvDSzcvLKzHDm5OTg4iICIwfPx5AYcATERGBlStXYteuXfDy8gJRoZJ1QEAA2rZty7KcBQUFOHDgAOrWrQv6f2Gu5s2bIyoqSul4Dx48wIQJE1iAJhaL4efnh3nz5n1Rj9WHDx9w7NgxTJkyBc2bN4eDg4Mg6BaLxbC3t0dAQAC6dOmCcePGYfny5di7dy/Onz+Pe/fuITk5GdnZ2X8pfsFxHHJycvDmzRvcv38fFy5cwP79+7Fy5UpMmDAB3bp1Q2BgIBwdHVmPM7/Y2dmhSZMmmDBhAg4dOvRFRv7FixdYunQpgoKC2P7s7e3x+++/K/W88w5K+/btmWCcn58fduzYUWxVuaCgAEuXLkX37t1ZEqJ///6YOXMmnj17hnHjxuHq1atYsWIFatSowa6jQoUKGDNmDM6cOaMksta3b1+IxeIS++w5joOfnx/09fUF1Q5e3XPRokVsXadOnZT6DvmRXHv27BHs99GjR9DR0YGvr+/fUmpX4Z+FKsD+MX6Byif4dXD48GEQEfr27StYn5eXh6pVq0JfX59VLhUKBerUqQMNDQ32/v/48SPs7e1hZ2cnSKCOGDEC9Iko1KcYNmwY5HK5IKgBUKJ4Zkk+werVq/HgwQOEhYWx6qmLiwvq16+PI0eOACh89z958gRDhw5lrK2yZcti7ty5SgnizMxMbN26FY0aNWIiXtbW1ujfvz9OnDjxl/3lBQUFuHnzJlauXImwsDB4eXkJ5mPzfkuVKlXQsmVLhIeHY86cOdi6dStOnDiBGzdu4Pnz50hLS/uiCnhBQQE+fvyIxMREXL9+HdHR0di8eTNmzZqFQYMGoXnz5qhUqRJr4ypanfb19UXfvn2xbt063L59+y+Pl5OTgyNHjiAsLIz1paurq6NFixbYvXu30m/36tUrTJkyhdHMzczMMHbs2BJ7kf/KLxg7diy2bNmCbt26QU9PjwX1oaGh2LJli9J+b9++DSLCzJkzS7ymP/74Q6n4xOsRVahQgV3TuXPnIJFIBH7szZs3oa6ujrp16yqxQ1u3bg2JRCLoIVfh58ePDLCr//8/300iuv7/S4PPfUdlTH8dpKSkwMLCAmXLllXqE5kwYQKIhCOLFi5cCCLC7Nmz2bpOnTpBLBbj9OnTbB1PFYuIiCj2uFlZWSBSVh39HF6+fImuXbvi2LFjAICHDx9ixIgRrBeW4zjs37+fiZqIxWIYGBhg69at4DiOBXX3799HeHg4o4/b2tpi3LhxSrQhjuNw48YNjBkzRkAFd3V1xZAhQ3Ds2LEvVuLMyMjApUuXsHHjRowePRpt27ZFtWrVYG5uLgi+P10kEgm0tLRgaGgIU1NTmJmZwdTUFIaGhtDW1maV45IWU1NTVKlSBa1atcLIkSOxdu1axMbGfrHyaW5uLk6fPo3ff/+d0b+JCGXKlMGwYcOK7bt6/vw5pk6dyih2urq66N27t9LoruIq+c+fP0e7du0YxbBbt25YuXIljh8/jqZNmzJnqWzZspg0aZKgPaE4rFy5EkTEZrl+ihUrVihVYG7cuAE1NTVBImnDhg0gIowdO5Ztx2ezi3NEvby8oKenV6KCqgo/J1QB9o/xC1Q+wa8FPnH4qfLxo0ePoKuri2rVqrGgMikpCaampnBzc2PssfPnz0MikaBdu3bsHcqPNXJwcCix3ahz586wtLT84vP8K58AKGyH8/b2ZrZDXV0dEyZMQGZmJvMJsrKysHbtWpakl8vlaNWqFQ4fPqyUDH7//j3WrVvH6PREBD09PbRs2RJr1qz5bHtZUSgUCjx69Aj79+/HrFmz0LNnTwQGBqJMmTLsXEta1NXVoaenB2NjY5iZmcHMzAzGxsbQ09Nj51TSoqmpibJly6JevXro3bs35s6di0OHDuHp06dfrGb9+PFjLFu2DE2aNGHVbi0tLbRo0QJbt25VSmjn5uYiMjISjRs3Zj6Ln58ftm7dKvADSmL3FecXzJw5E5MmTYKNjQ07fmhoKA4dOvTZhAfHcZDL5Rg2bFixn797905p7JxCoUBQUBDU1dVZIundu3ewsbGBvb09SyRlZGTAxcUFZmZmeP36tWC/vE5LST6xCj8vfliA/S2Lypj+WuD7q7p27SpYz/eWFM1O870lUqmUUcDT0tLg6OgIKysrQea3ZcuWkMvlxSpychwHIyMjdO7c+YvP8/nz56hUqRIzzocOHUL37t0Fgdv06dMRFhaGc+fOYcCAAczYeHp6YujQoahZsybmz5+PN2/eICcnB1u3bkWdOnVYkFu9enUsXbq0WJG2Bw8eYM6cOQgICGBZbHV1ddSuXRsTJ07EqVOnvmn2cl5eHp49e4ZLly7h0KFD2LhxIxYuXIiIiAj8/vvvGDRoEPr06YOePXuiR48e6NmzJ/r06YNBgwZh5MiRmDx5MhYsWIANGzbgwIEDiIuLw9OnT7+Jip6Tk4Nz585h2rRpqFu3LjOeEokENWrUwLRp05CQkKBkiN+/f4+1a9ciICBAcC/Xrl0rSNzk5ORg8+bN8PX1RWhoaLHn8PDhQ7Rv3x5OTk5wdXVlBlRfXx+9e/f+rJjKpxg2bBikUmmx9yIxMRE6Ojrw9/dn++MNpLm5OWsNuH37NjQ1NVGzZk1mcBMTE2FgYAAPDw+lDD1fqSk6U16FXwOqAFvlE6hQ+J6uWLEijIyMlKrJPGttyJAhbN3Ro0eV+konTpyolKA/ceIEiAgDBw4s9rhjx479LOPoU3yNT/Do0SOsWbOGjXzS19dHixYt4OXlxXwCoDDB2q9fP0YxNjc3x+DBg3Hp0iUlu5Oeno7du3ejS5cuTJiL/r/VKiwsDFu3bv3igLsoOI7Du3fvEB8fjxMnTmDHjh1Yvnw5Zs6cibFjx2LIkCHo168fevXqhR49eqBHjx7o1asX+vfvj6FDh2LcuHGYNWsWVq5ciZ07d+LUqVO4ffs23r9//9UjoTiOw+PHj7FhwwZ07dpVIHBmY2ODsLAwHDhwQKnYwIvA9e7dmxUzzMzMMGzYMNy9e1ew7fXr1xEWFgYjI6MS+/QfPnyItm3bwsrKCmZmZszP8Pf3x4YNG0oUo/sUiYmJICIsXLiw2M/bt28PqVQqYHXOmDEDRIRly5axe9K4cWPIZDJcvHiRbcez3KKjowX7vHv3LjQ1NeHv769itP2CUAXYKvwt8P3Q27ZtE6zns9Ourq7sBZaamgo7OzvY2NgwWvHly5chk8nQqFEj9gJPTk6GiYkJKlWqVGxGsUWLFjA3N//iF86VK1dQp04ddg69evUSqJ0DQNWqVQXiF02aNEGTJk1YBVpTUxOurq7w8vIS9OAkJiYiIiKCiXlIJBLUrVsXK1euLPaFn56ejgMHDmDgwIGoUKECe9lLpVJUrlwZffr0wZo1a3D9+vV/ref6r5Cfn4/4+Hhs3LgRAwYMgLe3N1NQ5Sv1ffr0QWRkpJIoCVAYVG/YsIG1GRAV9pyPHTtWSaju9u3bGDx4MDO0jo6OxRq4vLw8REZGon79+uye1q5dG5s3b/7quZ0cx8HV1RV+fn5KnykUCgQEBEBLS0tA++/cubPAQGZmZsLNzQ0mJibM0czLy4O3tzd0dHSUKujHjh2DSCT6y75vFX5OqAJslU+gQiH4fmg/Pz8lG92rVy8QEfbv38/W8crI69evB1CYoK9VqxY0NTVx+/Zttl3v3r0hEolw6tQppWOePHkSRIQ//vjji87xW3yC5s2bo0ePHmjdujVrc7KwsIC7uzsePnzItsvJycGuXbsQEhLCEuoODg4YMWJEsWOoeMbb7Nmz0aBBA0HftbW1NVq2bInp06cjOjq6xNGWPwNev36NI0eOICIiAiEhIYz2zVOvmzRpggULFuDOnTtK90ChUODcuXMYNGgQm2+trq6O1q1b48CBAwI2QHp6OlatWsVYA+rq6vjtt9+K1X65e/cuhg4dCmNjYxAV9sKPHj36i8TKPsXixYtBRMUWfnbt2gUiYi0HQOHcdb6lkL9efpJIUcV9fuZ7UZYbAGRnZ6NChQrFJqtU+DWgCrBV+FsoGjQUNTLAn0FDaGgoe8FcvHgRMpkMwcHBrM9k3rx5SvTx3bt3g0g4G5AHPwLpUxra586xS5cucHR0RHBwMIYMGSKoHl69ehXVq1dniYCCggI4Ozvj3r174DgOp06dQocOHVhV29raGvPmzUNycjKeP3/OjOa1a9cwfPhwlC5dmlHNa9asiZkzZxZrVIBCutD+/fsxYsQI1KpVi4miEBUKb7m7u6N169YYN24cNm/ejAsXLiA5Ofmrs8lfC47j8PbtW1y6dAnbtm3DxIkT0b59e1SsWFFAJdPU1ET16tURHh6O3bt3lyjs9ujRI8yfPx916tRhzomVlRUGDhyICxcuCK4nNTUVy5YtYzNGP9f3fufOHQwdOpQJ1pUqVQq///670rP4NYiNjRVknYti/vz5ICImlAL8SQPnaWgcx6Fjx44QiUQCB23w4MHFVqhfvXrFklH/KzPS/2tQBdgqn0CFP7Fu3ToQEcaNGydYn52dDQ8PDxgYGLA2mPz8fPj5+UFTU5NVkF++fAkTExO4u7sLRnw5OjrC1tZWKXFbUFCAUqVKoW7dul90fn/HJwAKCwhTpkxhtl5DQwNhYWGIjY3Fs2fPcPnyZQCFtmzVqlUIDAxkFGdra2v06dMHhw8fLjb5m5+fj8uXL2PevHlo3bq1oPLL27jAwEAMGDAAixcvxtGjR/Hw4cN/JCGfnZ2Ne/fu4fDhw1iwYAH69u2L2rVrM2Gyoq1gHTp0wOLFi3Hjxo1i+7EzMjKwd+9edO/enQXjcrkcDRs2xMaNGwVsBIVCgVOnTuG3335j7DgXFxfMnTtXSQMmPT0da9asga+vL/MfQkJCcODAgW+uAnMch0qVKqFcuXJKvhffY+3p6ckKQjwN3M7Ojo37Onv2LCQSiWCSyK1bt6ChoVFshbp3794gIhw4cOCbzlmFfx+qAFuFv42nT5/CwMAAnp6eSrTX8ePHK/WqLlq0CESEyZMnAxDSx8+dO8e241VF+bEGPPLy8mBtbY3q1at/VaB58eJFJlSybds2NgYhMjISffr0YeOxtm3bBn9/fwCFL/Zjx46hffv2WLhwITQ0NFhVWywWw8bGBmXKlIGdnR3++OMPcBwHjuNw9epVjBkzBuXLl2dGx87ODj179sTOnTtLzEQrFArcuXMHW7ZswfDhwxEcHAx7e3ulfmt1dXU4ODigRo0aaNGiBcLCwjBq1CjMnDkTy5Ytw8aNG7Fz507s27cPhw4dwuHDh3Ho0CHs27cPO3fuxKZNm7B8+XLMmjULo0ePRu/evdGqVSv4+fmhTJkySkIqPK2rbt26CA8Px4YNG3Dr1q0SR1l9+PABe/fuRb9+/VhPNf1/D/SwYcMQGxsrMLpZWVnYtWsXmjVrxqrabm5umDVrllLQ/uHDB6xYsQLe3t6MNdCkSRPs37//u8wUb9q0KfT19ZXohjdv3lTqsY6Pj4empib8/PzYsZcvX67kXPIJoz59+gj2WVBQgICAAGhoaCj1mqvw60AVYKt8AhX+BMdxCA0NhVgsVhIne/jwIXR1dVGlShXmLyQlJcHMzAxOTk5M54Onj3fq1Im9b2NjYyGRSNC+fXulY06aNAlEpCS8+jn8XZ9g1apV0NbWRsOGDdmILT09PVhZWcHKyor5BEDhfOM1a9agSZMmbFsNDQ00aNAAc+fOxa1bt0r0Z1JSUhAVFYUZM2YgNDQUlSpVEih2ExVOMDE3N0elSpUQHByM3377DeHh4Zg8eTIWLlyItWvXYvv27YiMjMSBAwdw+PBhHD58GAcOHEBkZCS2bduGNWvWYMGCBZg0aRIGDx6Mjh07on79+qhYsSJMTEyUfAIdHR1UqVIFnTt3xpw5c3DixIkS50cXFBTgypUrmDFjBurUqcPsvI6ODlq2bInNmzcLEiccx+HmzZsYOXIka/fS0dFB165dce7cuRKV2/n74uTkhOnTpyv1NH8LoqOji026KxQK1K1bF+rq6oxtoVAoEBwcLKCBJycno1SpUnBwcGDXmJaWBmdnZ5iZmQmEUgFgx44dICIMHjz4b5+7Cv8eVAG2Ct8Fe/bsAZGycFNBQQGCgoKgpqbGsrocx6Fdu3YQiUSIiooCUEgbLl26NCwtLVlAlZaWBgcHB1hbWytlKXnRtKIVwq/BtWvXEBsbC6BQ4TokJIRlyn18fLBq1Sp2rpmZmfj9999Rs2ZNXL16FQCwdetWGBoasr4sNTU11KpVC1u2bFEKzBITE7F06VI0btyY0b/4OZd9+/bFtm3bkJiY+NlkQXZ2NhISErBv3z7Mnz8f4eHhaNOmDRsrZWxszEZOfe0iFothaGgIJycn1KhRA61atcKgQYMwd+5cREZG4ubNm39ZWU1KSsKuXbswaNAgeHp6snPR0NBA/fr1MX/+fCX6d1ZWFiIjI9G+fXt2X0xNTdG/f3+lvrXiZmO6uLhgxowZ38WA8rh8+XKxdK3MzEy4urrCzMxM8HzyBpJXHL148SLkcjnq1q3LMtL3799Xcih58IKA/POmwq8JVYCt8glUECI9PR1ly5aFubm5UgDBJxx79erF1p06dQoSiQTNmjVj7/6xY8cqJej5Hm2eUs4jNTUVurq6aNy48Ted79/1CU6cOAFzc3P4+vqyhLijoyMmT56s1DuclZWFQ4cOoW/fvmwUKBHBxMQEzZo1w9y5c3Hx4sXPVqU5jsOLFy9w6tQprFmzBuPHj0fXrl1Rv359eHh4wMrKigXy37JoamrC2toanp6eCA4ORvfu3TFx4kSsX78eMTExePXq1V/6LOfOncOMGTPQqFEjprjOJ88HDx6M6OhowTXyQfW4cePg6urKEuj16tXDpk2blPyQR48eYfz48UygVltbG126dEFMTMx3Y/lxHAdvb29YWloqMQ54yvfSpUvZOn78Ft/Klp+fj1q1akFdXZ0lfziOQ4sWLSCRSJRaHvgElJeX10/bJqjCl0EVYKvw3TBo0CBQMX1QKSkpsLa2hq2tLavcZmRkwN3dHUZGRmx0x7Vr16Curg5/f39WDbx48SKkUimaNGkieGHm5OTA3t4e5cqV+9tVy8zMTHTp0gVlypRB48aNMWLECABQ6v8OCQlh2e5hw4ahXr16OH78OPbs2YPevXszmhRf5Vy+fLnSzMS8vDycO3cOkyZNQkBAgKBSbG5ujkaNGmHcuHGIjIzEo0ePvmi0Bg+O45CWloYXL17g/v37uHHjBi5duoTY2FicP38esbGxuHTpEm7cuIF79+7h+fPn+Pjx41cfIzExEfv378ekSZPQtGlTWFlZsWtQU1NDzZo1MWbMGJw8eVIpoExOTsa6devQtGlTdu2Ghobo0qULoqKiBL8lx3G4cOECBgwYwCjg/GzMv5on+i3gOA41a9aEsbGxEgWxW7dugoQOx3Fo3rw5JBIJG7/15s2bzz7nnyqD8yKBRVsoVPg1oQqwVT6BCsq4efMmNDQ0ULt2bSUK7JAhQ0BE2LBhA1vHByzTpk0D8GeCXi6Xs2pgQUEBatasCS0tLaU51VOmTAHR50d6fQn+jk9w8uRJXL9+HXPmzEHVqlWZbXR1dcXIkSNx/vx5pXuRmJiINWvWoGPHjrCzsxPYUy8vL/Tu3RsrV67ExYsXv1iUi0dOTg6Sk5Px+PFjJCQk4MqVK4iLi8P58+dx/vx5xMXF4erVq0hISMCTJ0+QkpLy1YFdWloaYmNjsWzZMvTo0QOenp6C8Z6Ojo7o2rUrNm7cqOQT5efn49SpUxg8eDALlEUiEfz8/LB48WIlBltycjIWLVoEHx8ftq2/vz/Wr1//Q2ZE8+J8RRXmASAuLg5SqVRA+Y6KioJYLEbbtm3ZuvDwcKWE0OzZs0FEmDFjhmCfRVsoeL9YhV8XqgBbhe+G3Nxc1o/N9yrx4Ct7QUFBgsqenp4ePDw8WGaS790KDw9n350zZw6IhD3awJ/CEkUFI/4OHj16xEaGvX//Hhs3bmSZ9+zsbPj5+eHy5ct4/fo1dHR0MGzYMAwaNAh+fn549OgRCgoKcPr0aQwcOFDQN1WhQgUMHz4cx48fVwo48/LycOnSJSxcuBChoaFwcXER0ME1NTXh4eGBNm3aYPTo0Vi7di1OnjyJhw8ffpPy+JcgJycHjx8/xunTp7FhwwaMGzcO7du3R5UqVQQ94rzhbNOmDebMmYPz588Xe30xMTEYM2YMqlSpwq7N0tISvXr1QlRUlMBp4TgOV65cwYgRI9g9lMvlaNasWbGzMb8nNm3aVCwNjO+xHjlyJFs3bdo0EBFmzZrFrrNWrVpQU1NjNEOO49C2bVulXmygsMfwUxFAFX5dqAJslU+gQvHgRZw+1VMpWtm7du0agMJ3ZqtWrSAWi9k78+3bt7CxsYG1tTUTDn3x4gWMjY0FPdpAYWXYzs4Orq6u36X693d9AgB49uwZ5s+fD39/f9aHbWRkhHbt2mHdunVKASdQqHL+xx9/IDw8HDVr1hQIn/HtZvXq1UP//v0xf/587Nu3D9evX8fbt29/SLJWoVDgzZs3uHr1KiIjIzFnzhz06dMHgYGBTJSMX/T19VG7dm2MGDECkZGRxTLMnj59ipUrV6JFixZsBrVcLke9evWwbNkyJcbD27dvsWrVKgQFBbF76O7ujqlTp5Y4SvN7IC0tDVZWVqhYsaIgKfLu3TvY2trC1taWtRE8efIEhoaGcHd3ZzZ969atSq1hJ0+eVGJq8OjZsyeIvlxfSIWfGyX5BaLCz/5ZVKpUCVevXv3Hj6vC98Pz58/Jw8ODLCws6MKFC6SlpcU+W7lyJfXo0YNGjBhBU6dOJSKigwcPUqNGjaht27a0adMmEolE1LdvX1q8eDFt3ryZ2rVrRwCoefPmtH//fjp58iRVr16diAqTQMHBwXTmzBlKSEggW1vbLzrHgoICun37Nr1584akUimZmJiQvb09aWpqsm0yMzNp/vz5tGnTJipdujQZGxvTu3fvaP/+/XTixAkKDw+na9euERFR//79qVq1atSuXTv2fQCUkJBABw4coMOHD9P58+epoKCA1NXVydfXl/z8/KhGjRpUtWpVwXH5Y9+6dYvi4+MpISGB7ty5Q/fv36fExETiOE6wrb6+PpmampKxsTEZGhqSnp4e6ejokJaWFqmrq5NcLiepVEoikYgAkEKhoLy8PMrOzqasrCxKT0+njx8/0vv37+nt27f05s0bSk1NFRxDJBKRtbU1OTk5kYuLC7m6ulK5cuWoXLlypKurK9g2NzeXLl++TDExMXT69GmKiYmhzMxMEovFVLVqVapfvz41bNiQPDw8SCQSERGRQqGguLg4ioyMpN27d9Pjx49JLBZT7dq1qV27dtS0aVPS19f/7G+ak5NDT58+peTkZMrPzycjIyNyc3MjuVz+Rc/E27dvydXVlezt7en8+fMkkUiIiOjWrVvk5eVFVatWpejoaJJKpXT06FGqX78+tWrVirZu3UoikYgGDBhACxYsoA0bNlBoaCgREc2ZM4fCw8NpypQpNHLkSHas/Px88vf3p+vXr9OlS5fIxcXli85RhZ8TWVlZpKWldQVA5X/7XP5r8PDwwJUrV0gsFv/bp6LC30DXrl1pzZo1dPDgQWrQoAFbn5ycTJUrVyapVEqXL18mIyMjyszMJG9vb3rx4gVdunSJHBwc6OrVq+Tr60teXl507NgxkslkFBUVRfXq1aMOHTrQ+vXrmT05ePAgNWzYkCZMmEBjx4794nN89eoVPXjwgPLy8khXV5dsbGzIzMyM7ffv+gRERO/fv6ejR4/SoUOH6OjRo/TmzRsiInJ2diZ/f3+qWbMmVa9enaytrQXf4ziOHj9+TDdv3qT4+Hi6c+cO3bt3jx48eEAZGRmCbeVyOZmZmZGJiQkZGRmRvr4+6erqkra2NmloaJCamhrJZDL2P8VxHOXn51NOTg5lZ2dTRkYGpaWl0fv37+ndu3eUkpJCb968ofz8fMFxdHV1ycnJiZydncnV1ZXc3d2pfPnyZGtry+4ZUaEv9OTJEzp79iydOXOGTp48SY8fPyYiIktLS6pbty4FBwdTYGAg6ejosO8lJSXR3r17affu3XTy5ElSKBRkZWVF7dq1ow4dOlC5cuU++3sCoKSkJHrx4gWlp6eTmpoalS1blkxMTD7/IBRBv379aPHixXT+/HmqVq0au1+NGjWiY8eO0dmzZ6lq1aqUmZlJ1atXpydPntClS5eoTJkydP36dfLx8SFPT086fvw4yeVyev78OXl6epKRkRHFxcUJ/KeNGzdSx44dafjw4TRt2rQvPkcVfj4UFBRQeno6GRoaFu8XFBd1/+hFLBZ/9UgdFX4+8OIkHTp0UMrQ9ejRQ4lGPnnyZBARG5ORm5uLGjVqQF1dnfVtf/jwAY6OjjA3NxdkfJ8+fQptbW0EBAR8EdV5165djG786WJvb49mzZph2rRpiImJYdXS3bt349ixY0x8ZcWKFYwylpWVhWnTpgmqm8UhLS0N+/btw4ABA5TGc3l6eqJ3795Yu3Ytbty4UexoMv6+PHjwANHR0Vi7di0mT56Mvn37omXLlqhduzYqVqwIe3t7GBsbQ1NTU0kYjV9EIhE0NDRgZGQEOzs7VKhQAbVq1UKLFi3Qu3dvTJw4EatXr0ZUVBTu3btXYtU4Pz8fCQkJbFxXtWrVBOO6XFxc0Lt3b+zatYtleXl8/PiRzQHlxVMkEgl0dXVhYWGBgIAATJo06bP3dOfOnWjdujXKlClT7LUaGBhgzZo1n90H8GfVRCaT4ebNm2z9+/fv2TPHZ9Tv3bsHfX19lC9fnmWp16xZAyLhjNbo6GiIxeJis9QDBw4EEWHr1q1/eW4q/NxQKBR8BUdVwf4BCxXTa6vCr4esrCxUrFgRBgYGgvGGQCHVVi6Xo06dOqxF6NGjRzAwMICbmxvTNNm4caNSNZDXsFi0aJFgn23atIFMJvsiwbPXr18jKCioWFupp6eHGjVqYNCgQfjjjz+Y1sb38AkUCgWuXbuGmTNnon79+oIqtZWVFVq0aIHp06fj+PHjSho0PDiOQ3JyMuLi4rBjxw7MnTsXw4YNQ6dOndCgQQN4eXnB2dkZFhYW0NXVZWPDiltkMhn09PRQqlQpuLi4wNvbm4mljRgxAvPnz8euXbtw6dKlz1bK37x5w8Z1NWnSROBv8eO65s2bh/j4eME+FAoFLl++jIkTJ6JKlSrsOxYWFjA3N4e9vT3q16//Wb/g2bNn6NevH7y9vZXYdvzSrFmzYseHfgp+9Fv//v0F68eNGwciwpIlS9hv0KpVK4hEIhw6dAhAIY3dxsYGlpaWrIKflZWFypUrQ0dHR6m14caNG9DQ0ECtWrW+i1CrCv8ueIZjSX7Bv2ZMi4pZqPDrghci+XRmcU5ODry9vaGpqYkbN24AEL6g+JEERV9QvFG7desWtLS04O3tLQj6VqxY8UVU8SdPnkAqlaJKlSrYtGkTTp8+jePHj2PLli2YOHEiWrVqxXqAiP4ULpswYQLOnj3LAt/U1FQEBARg1KhRmDRpEtq1a4fjx4+za+Hxxx9/YPz48Th37pxS0Jyamor9+/dj5MiR8Pf3FxgDuVyOihUrokOHDoiIiMCuXbtw8+bNb6ISFxQUICcnB9nZ2cjJyfmml3dWVhbi4+OxZ88eTJ8+HZ06dULlypWZ2BhRoZhZ9erVMWTIEERGRirNAC8oKMDFixcxZcoU1KpVi/VniUQiJg53+PBhHDx4EEBh4sTHx4cJyBQHXl20WbNmGDduHDZu3Ijo6GicOXMG27dvR61atUBEn90H8KfjFhERITjf+vXrQyaT4ezZswAKA25nZ2cYGxuz/qizZ88W6xwaGhoKnEMe27ZtAxGhX79+X3bzVfipcfDgwc8aUtXy9xaRSAQ3N7ev+k1U+Dnx8OFD6Ovrw8PDQ6m9iU9SDho0iK2Ljo6GRCJBo0aNWPKc79vm23h4xWapVCqYNvL27VuYm5vDzc3tL1upQkJCoKGhgcmTJyMqKgpnzpzBvn37sGDBAvTq1Qve3t4CW1emTBn06NED27ZtY3buS32C9PR0dOvWTfBdHvx4rvnz56NNmzZK47msra3RoEEDDBkyBKtXr8aZM2eQlJT01ZRwjuOQm5uL7OxsZGdnIy8v76v3oVAo8Pz5c5w8eRLLly/HoEGDEBQUBAsLC8E58+O6lixZUuy4rqSkJGzcuBGhoaGCQFwul6Nz586Ij4/HmTNnvtgvePHiBbS0tFCjRg3069cPixcvxoEDB3DmzBkcOXIEo0ePhlQqRffu3T97fe/fv4eNjQ0cHR0Ffhcv5ltU2Z5Xry9aIKpevbqgQMRxHDp06AAiwt69e5WO5eDgAAsLCyVqvAq/HgoKCmBhYcFrMv08AbZIJEKZMmV+9PWr8A9AoVCgYcOGkEqlLEDhkZSUhFKlSsHW1pYZmczMTFSqVAk6OjpsXNH169ehqamJqlWrMiPJjy/o0qULe8FxHIfGjRtDLpd/Npjie2n5cQol4c2bN4iMjMSgQYPg4eHBqqPa2toIDg7G3LlzsX//fowZMwZ9+/ZlvVafon///oLv1q9fH9OmTcO5c+eUqsIFBQW4c+cONm/ejKFDh6JevXpKvU1EhSqjlStXRkhICKs2L126FH/88Qeio6Nx8eJF3LlzB4mJiUhOTsb79++Rnp6OjIwMpKen4/3793jz5g0SExNx9+5dXLp0CcePH8fOnTuxfPlyVhVv1qwZqlatyuZTFl0sLCxQp04dDB48GOvXry92XFdeXh4uXryIsWPHokKFCgI10woVKmDIkCGoV68eQkJCWFLtU0eoSpUqSEhIKPF3+quZlm/evAEV07tfFA8ePICOjg6qV68uuIbhw4eD6E910Pz8fAQFBUEmk7GevKdPn8LExASOjo6swpCWlgZ3d3cYGBgoqabfvHkTmpqa8PX1VamD/kfAC+2oAuwfs/DCkby6swq/Ng4cOAAiQseOHZWCun79+oGIsG7dOraOnxYyfPhwAH8mPqVSKQtg379/DycnJ5iYmAiEJA8fPgwiQu/evT97Tjo6On+5TW5uLuLi4jBz5kw0bNgQurq6AnsWHh6OlStXYuTIkZ/1CS5dusR6jokI5cuXR//+/bFz585ig6uUlBQcPXoU06dPR7t27VC+fHmoqakJ7LGGhgacnZ1Rp04ddOrUCSNGjMCcOXOwceNGHDx4EGfPnsXNmzfx6NEjJCUl4e3bt/j48SPS09ORnp6Ojx8/4u3bt3j58iUePnyIGzduICYmBvv378f69esxe/ZsDBs2DKGhoahduzbKlClT7DlUqlQJHTt2xOzZs3H8+PFix3U9f/4cW7duRWhoqMC3MDIyQps2bbBmzRq0adMGzZo1+2F+QcuWLVG6dOkSP+c4Di1btoRUKsWFCxfY+vj4eGhra6Ny5crsnHbu3AkiYmxNjuPQpUsXJYbazJkzQUSYOHGi4FhFR3oVHVGrwq8LfkLC/6vJ/zwBNk8VffHixY++Byr8AyhKsf1UyOPixYtQV1dHjRo1WLDx/PlzWFhYwM7OjilHRkZGQiQSoVWrViz7OWbMGBAR5syZw/aXkpICS0tLwZzBT8EH53xW8Uvx7t077Ny5E2FhYXB0dBQEmh06dMC6detKFNp4+/YtduzYgV69eqFs2bKC6riPjw8GDRqELVu24P79+8VS3NPS0nDlyhVs27YNERER6N69O+rWrQs3NzfB2Ivvuejp6cHFxQWBgYHo0qULJk6ciM2bNyMuLq7Ye8txHB4/fow//vgDQ4cORc2aNQUBtZaWFsqUKYNGjRoxutS7d+/Qtm1bjB07FhEREUrqn/fu3UNgYODfEjBJTEwUBMmfoqhiZ9Hj8ImYsLAwdn29e/cG0Z/jtD5+/Ihy5cpBT0+P0b0UCgUaN24MiUSCY8eOCY6VmpoKBwcHmJubM0aGCr82UlNTQURo0aKFKsD+QYu7uzuIlMc/qvDrYvz48cWy2/Lz8xEQEAC5XM6CDY7jEBYWBiLC2rVrARS2i7m6ukJfX5+Nv7pz5w709PRQvnx5gS3hFZy3b99e4vmYmpqic+fOX3UN+fn5uHDhAiIiIuDv789ao+RyOWrVqoVJkybh/PnzxTLG8vPzERsbi8mTJyMgIEBgK0uXLo127dph3rx5OHv2bLGstYKCAjx8+BCHDx/GwoULMXjwYDRv3hxVq1aFpaWlQL37ey1yuRw2Njbw9vZGq1atMHToUCxZsgRRUVF4+vRpsb7Lx48fcerUKcyaNQutWrUSFAwkEgnMzc3h6emJTp06se//E35Bw4YNP8uKWbRoEYj+VLIHCv1Le3t7mJub4/nz5wAKfVgNDQ14e3uz1tYZM2aAiDBmzBj23YMHD0IkEqFly5ZKSSV+DN3ixYu/+XpU+LnQunVrEBEyMzNL9Av+FZEzFxcX3L17l3r16kVLliz5x4+vwvdHfHw8VatWjdzd3en06dOkpqbGPtu6dSu1a9eOunTpQqtWrSKRSESXLl0iPz8/qlChAp04cYI0NDRo5syZNGzYMBo5ciRNmTKFOI6jVq1a0e7du2nfvn3UsGFDIiI6d+4c1apVi4KDg2n37t1KwjgvX74kKyurrxY/+RSJiYl07Ngxio6OphMnTlBKSgoRETk6OpK/vz/VqlWL/Pz8yNLSUum7b968oXPnztG5c+coNjaWrl69Sjk5OUREpKOjQ+XLl6dy5cqRm5sbubi4kLOzM1laWgpEQ4oiLy+P3r59S+/evaP379/Tx48fKSMjg7KysignJ4fy8vKooKCAbS+VSkkul5OamhppamqStrY26erqkoGBARkZGZGxsTGpq6sXeywA9Pr1a7p//z7duXOHbt68SadOnaJHjx5RXl4eERHJZDLy8PCgatWqkVQqpbdv39L69euJqFDIZcuWLeTp6UmTJ08mR0dHkslkdOHCBQoPDydzc3PKyckhdXV16tGjB9nY2NDo0aMLX0glXP/nsHTpUurduzfduHGDypcvr3Qt3bp1ozVr1tD+/fvZM3T27FkKCAggX19fOnr0KMlkMpo7dy4NHjyYhg4dSjNmzKCCggImcnLkyBGqU6cOERGNGDGCpk+fTgsWLKB+/fqxYykUCmrYsCEdP36cTp06RT4+Pl99LSr8fBg7dixNmjSJIiMjqWnTpiqRsx+AypUrIzc3lx48eECpqalKgpAq/HrgOI5CQkLo8OHDFB0dTX5+fuyz1NRU8vLyorS0NLp48SLZ2tpSfn4+NWjQgE6fPk1RUVFUq1YtevLkCXl5eZGuri7FxsaSiYkJHT16lIKDg6lBgwYUGRlJEomE8vLyqFatWnTr1i26ePFisYKSzZo1owsXLtCzZ89IKpV+0zVlZWVRTEwMRUdHU3R0NF2/fp2ICm16jRo1yN/fn/z8/MjDw0PpGHl5eXT16lU6e/YsxcbG0oULFygpKYmICgVGnZycqHz58uTu7k5ubm7k7OxMjo6On7XTvGhpamoqffz4kdLT0ykzM5Oys7MpNzeXCgoKmGCqWCwmmUxGampqpKGhQVpaWqSjo0P6+vpkaGhIxsbGpKenV6INzsrKogcPHtDdu3cpISGBbt68SRcuXKDk5GS2ja2tLXl5eZGPjw/l5ubS48ePadmyZUT0z/oFGRkZZGFhQe3ataPly5crfX7hwgWqWbMmBQYG0v79+0ksFlNubi7VqVOHLl++TKdOnSIvLy9KTEykatWqkbq6OsXFxZGpqSlFRkZS8+bNqUWLFrRt2zYSi8WUkJBA3t7e5OjoSDExMQLR3z179lDTpk3pt99+ozVr1nyTj6PCz4W3b9+SiYkJ+fv704kTJ0gkEv08Imeenp7w8PCAVCpV6l1U4dcFP06rc+fOShm80aNHg4gwc+ZMtm7nzp0s46dQKMBxHBNH42lDmZmZqFy5MrS0tNhYJACYN28eiAjjx48v9lwCAwNhYWHx3cT0FAoFbty4gTlz5ihRx0qXLo1OnTphxYoVSEhIKDbLm5eXh2vXrmHVqlXo06cPqlevDn19fSX6lbu7Oxo2bIg+ffpg2rRp2LhxI44dO4abN28iKSnpb4+vys3NxevXrxEfH4/jx49j8+bNmDFjBvr3748mTZqgfPnySqIhGhoa0NPTQ7Vq1eDg4IDp06cL7uuwYcOwYMECxlD47bff2O8yevRovHz5Ek+fPkX37t2xdOlSlpU+efIkGjVq9Leup6CgAC4uLqhYsWKx/WWLFy8GEWHUqFFs3f3792FkZAQnJydG+d61axdEIhGaNWvGnsVevXqBiLB8+XL2XX4UTc+ePZWON2zYMKXtVfi1kZubCwMDA9jY2EChUKgq2D/QJ1i9ejWI/hyJp8Kvjw8fPjA9i6K0buDPanS5cuWYiNj79+/h4uICfX19xhiKjY2Furo6vL29GWWXf68XFaZ68eIFTE1NUaZMmWJpy3xf7ZYtW77b9aWkpOCPP/5AWFgYnJ2dmc3U1tZGUFAQJk6ciOjo6BL93JcvX2Lv3r0YP348QkJCULp0aYGQp0gkgrW1Nfz8/NCpUyeMGTMGy5Ytw969exEXF4fHjx/j48ePf2tkl0KhwIcPH/Dw4UPExsYiMjISS5Yswe+//44OHTqgevXqKFWqlMAnEIvFsLS0hJmZGerVqwdXV1esXbtWcB7/pl8wf/58EFGxdOyXL1/CwsIC9vb2zP5zHId27dqBiLBt2zYAhc+iq6sr9PT0GFU9Li4OGhoa8PLyYs9icnIy7OzsBFVvHjzdvEqVKiph5/8Q+NZCXuyuJL/gXzOmvODQ1KlTf+iNUOGfBU/rnj9/vmC9QqFAy5YtIRKJEBkZydbzPStDhgwBUEirqlevHiQSCRNCS0pKgo2NDczNzZnoFMdx6NixI4gIO3bsUDoPXhmyKP3ne6KgoACXL1/GnDlzEBISAmNjYwH1OjAwEL///jsiIyPx7NmzYg0gx3F4+fIloqOjsXjxYgwaNAiNGzdG+fLlBf1bny4aGhowNTVF6dKl4e7ujsqVK8Pb2xs1atRAzZo1UaNGDXh7e6Ny5cpwd3eHg4MDzMzMoKmpWeI+tbW14ebmhoYNG6J///5YsGABjh49isTERDRs2JCpdG/evBkDBw7E+fPn2XXMmTNHICayfv16tGnTBg8ePEDVqlXRsmVLuLm5QUtLC1WrVkV8fDwUCgU6duyIoUOHYv/+/ejXrx9Onjz51b8D75QXVavncfz4cUgkEgQHB7N+reTkZDg4OMDY2BgPHz4EUChgpq6ujmrVqjGjOX36dBD92RMIFD5TMpkMAQEBSmJ2mzdvFtDNVfhvgH+++P5+VYD943yCvLw8GBoawtra+osmRajwa+Du3bvQ09NDxYoVlajQUVFRkEgkqF+/PqNZP3nyBKamprCzs2P9ynwCNCQkhL3LBw0apNRCdubMGchkMgQFBSnRthUKBdzd3eHo6Pi3E9UlISkpCdu2bUPv3r1Rrlw5FiyLxWKUL18e3bp1w/Lly3H58uUS9TkyMjJw5coVbN68GePHj0doaCh8fX1hZWUFsVhcrP2WSCQwMDCAtbU1ypYtCw8PD3h5ecHX1xc1atRAjRo14Ovri6pVq6JixYpwdnaGlZUV9PX1P7tPGxsb1KxZE506dcKkSZOwbds2XL9+HVlZWT+tX/DhwweYmprCz89P6bOsrCxUrVoVWlpagkkiI0eOBBFhypQpAApFev38/CCTyXDixAkAhaKmpqamsLe3Z62NWVlZ8Pb2hoaGBi5evCg41tu3b1m7mKod9r+DnJwc6OrqwsHBgfn1JfkF/wpFvHLlyrhw4QKZm5uTlZUVo9mo8OuD4zhq3rw57du3jw4fPkxBQUHss+zsbEbjOn36NFWpUoUAUN++fWnJkiWMcpuenk61atWiu3fv0okTJ8jLy4tu375Nvr6+ZGZmRufOnSMjIyPKycmh2rVr0/Xr1+nUqVNUtWpVwbk0adKEjh8/TvHx8WRnZ/dDrxsA3b9/n1G/4uLi6NatW6RQKIiIyMjIiCpUqCCggJUtW/azc58zMjLo5cuX9Pr1a3rz5g2jgn348IHS0tIoIyODsrOzKScnh/Lz85Uo4kXpYDxFXF9fnwwMDMjExIRMTEzI3NycLC0tleZcF72u33//nSwsLKh///6UmJhIq1evJjMzM+rTpw8REcXFxdGQIUMoOjqa1NTU6MmTJxQQEECPHz+mIUOGkLOzM3l5edGePXvI0NCQ+vbtS48ePSJnZ2eqUqUKubi4kIeHB7Vp0+arZlempKSQq6srOTk50dmzZwXUq7t375K3tzeVKlWKYmNjSVdXlzIyMqh27doUHx9PJ06coGrVqtHt27epevXqZGxsTOfPnydjY2PW0tC6dWvasmULicViunv3Lvn4+JC5uTmdP39e8LtdvHiRatasSV5eXhQdHU0ymeyLr0GFnxtBQUF08uRJ+vjxI2lqapZMBVPhb6Fy5cq4fPkyjR49miIiIujChQvk5eX1b5+WCt8Jhw8fpoYNG1LTpk3pjz/+ELR1LV++nMLCwqh37960aNEi1kJWq1YtcnFxoVOnTpG2tjYtWLCABgwYQL169aLFixcTx3HUunVr2rVrF23dupXatGlDRESrV6+mbt26se2K2oWoqCiqW7cujRs3jsaPH//Dr/vDhw904cIFio2Npbi4OLp48SK9f/+eiArbrFxdXQXtYq6urmRjY1PiPPiCggJ6/fo1vXr1ipKTkyklJYW1jaWlpVF6ejplZWVRdnY2axsrShGXSqWkpqZG6urqjCKup6fHWsdMTEzIzMyMLCwsyMzMjCQSSbHn8TP7Bbw/efHiRapc+c9XNcdx1KZNG9q5cyft3r2bQkJCiIho8eLF1LdvX+rRowctW7aMOI6jtm3b0o4dO2jz5s3Url07evv2Lfn6+lJKSgrFxsaSs7Mz2+6PP/6gnTt3UvPmzdmx8vPzqV69enT27Fk6ffo0m62twq+PvXv3UkhICM2bN48GDBhARPTzUcQBoFOnTiAiAfVXhV8f6enpTBiKFyfh8fr1a9jb28PU1JTNySwoKECTJk0gEolYNfr169coXbo0jIyMGFXszJkzUFNTg5eXFxPFSE5OZvv7VNEzMTEROjo68PPz+0vFyR+BrKwsnDlzBt27d4eJiQlMTU2VVDlNTU1RvXp1dOzYERMmTMC6detw4sQJ3L9/H5mZmf/YeT58+FCgpMmjoKAAc+fOZWIeWVlZWLx4MZsFCvzJOti3bx+AwkxvUFCQEi1u2bJlmDt3LtLS0pCdnY23b99+8zlzHIdmzZpBJpMxNXoer169gp2dHUxNTRnjITc3F0FBQRCLxew8nz17BisrK5ibm7NnMTo6GjKZDH5+fqzKUdwzy6OoYF9KSso3X48KPx+ePXsGIkKDBg3YOlJVsH+oT5CQkACiwvE4Kvy3MGvWLCVhKB5Dhw5VaiE7cOAAJBIJ6tatyxhDPDWTpxpnZ2ejRo0akMlkiIqKYt/l23X4kUpF0a5dO0il0q8WQf0e4DgOd+/eRXh4OEqVKgUbGxt+zA9b1NXV4e7ujpCQEISHh2PhwoXYt28frl27hjdv3vwj7A6FQoHXr1/jypUrTLC0KH5Wv+DYsWMgIgwYMEDpM370W9FnbPv27RCJRGjcuDHy8/PBcRz69OkjaFXJyMiAl5cX1NTUBJNy+GdsxowZguMUFexbv379N1+LCj8nGjRoACKhSHdJfsG3KT18J/z++++0fv16mjVrFm3ZsuXfPBUVviO0tbVp3759VLVqVWrYsCFduHCBjIyMiIjIzMyMDh48SL6+vlS/fn1Wjd6yZQsFBgZS+/btycjIiPz9/SkqKop8fX0pMDCQzp07RzVq1KBt27ZR8+bNqVmzZrR//34yNTWlQ4cOka+vL8sYmpqaEhGRjY0NLViwgDp37kyTJ0+mcePGffW1ZGVl0e3btyk5OZny8vJIJpORrq4uGRsbk7m5ORkYGJQoWqGhoUEfP36kx48f065du+jgwYOUlZVF/fv3p4SEBLp37x7dv3+frly5Qtu2bWMCYkWho6NDhoaGpK6uTra2tmRiYkK6urqko6NDWlpapKGhQWpqaiSVSkkqlZJIJCIApFAoKD8/n3JzcyknJ4cyMzMpPT2d0tLS6P3795Samkpv376lN2/e0MePH4moMKOek5MjyJ5LJBKysbGhmzdvUkFBAWloaJBEIhGIeEilUurRowdt2rSJrly5QidPnqSuXbuSjo4OKRQKJsTi5uZGHz9+ZFVeExMTksvlpKOj89W/y9KlS2n37t00Y8YMcnd3Z+vT0tKoQYMG9ObNGzp16hTZ2dmRQqGg0NBQioqKotWrV1OjRo0oJSWFgoKCKC0tjc6cOUP29vZ09epVCgkJIWdnZ9qzZw+pqalRZmYmNWrUiF6/fk2nTp0ie3t7dqzMzExq3LgxZWRkUFRUFBkbG3/1dajw82LWrFlERN/03lDh2+Dq6kq+vr60fv16Wrhw4Te9G1T4OTF48GC6ffs2TZo0icqWLUvt2rVjn02bNo2ePn1KQ4cOJWtra2rdujUFBwfT8uXLqVu3btSlSxdav349TZ06lZKTk2n8+PFkZGREffv2pX379pGfnx81bdqUoqOjqVq1ajR16lRKTEyk4cOHU6lSpahDhw7sWAsXLqQzZ85Q69at6cqVK6Snp/dV1wGAHj16RE+fPqWMjAwSiUSkpaVFBgYGZGpqSubm5iWymEQiET148ICuX79O27Zto4MHD1J2djaNGzeO4uPj6e7du3Tv3j26ePEiHT16lHJycqjQf/8TUqmUjIyMSFNTk6ytrcnU1JT09PRIV1eXtLW1SVNTk9TU1Egmk5FMJmP+CcdxVFBQQHl5eZSbm0tZWVmUkZFBaWlp9PHjR3r37h2lpqbSmzdvKCUlhbHvli9fTj169BCcw9/1C7KzsykpKYlSU1PJ3t6esrKy6Pjx46Srq0sODg5kaGj41WJgr1+/ptDQUHJxcaEpU6YIPps3bx7NmjWLevfuTeHh4URUyGbo0KED+fr60rZt20gqldL48eNp8eLFNGTIEAoPD6f8/Hxq1aoVXbp0iXbu3Em+vr5ERLRkyRKaMWMG9erVi4YMGSI41sKFC2nZsmU0bNgw6tix41ddgwo/Nx4/fkyHDh2ikJCQYsWNlVBc1P2jFz5bDQD+/v4gIiY2oMJ/B+fOnYNcLkfNmjWVep74arSPjw/re3337h3c3Nygo6PDssvXr1+Hnp4eypQpwzKpa9asARGhadOmrM/q3LlzbEZj0RFTHMchNDQUIpEIBw8e/OJzf/fuHTp06MDGcpS0aGhoCIRWiqKgoABjxozB77//zs6xffv2rKoKFFZswsLCsGnTJjx79gzt27fH2LFjsX79ekydOhVhYWFwcHCATCaDo6MjSpcu/dm+qZIWuVwOQ0ND2NnZwcPDAwEBAWjdujX69euHiIgIrFmzBlFRUcVW+pOSkhAQEMBm1Hbu3BlLlixBcnIyDh48iOfPnyMhIQHjxo1DrVq1UKdOHQQFBcHV1RVaWlqfPS+JRIJmzZopjXf7HM6dOweZTIYGDRoIsvnZ2dmoVasWpFIpE5/gOA7dunUTZK4/fPiASpUqQV1dHWfOnAFQOBbExMQEtra2LDOZl5eHBg0aCKreRX/bJk2aQCwWf9VzpcKvgaysLMhkMri7uwvWk6qC/cN9goMHD4KIMHbs2M/8Qir8isjNzUXNmjWhpqamJEDFV6Plcrmg73by5MkgIgwcOBAcxyE/Px9NmjQBEWHDhg0ACm2Ug4MDDAwMcOPGDQCFvZK1a9eGVCplei48YmJiIJFI0Lhx46+qCK9du1Ywhqq4RSQSoVSpUsUKrX2tT5CcnIzOnTtjypQp2LlzJ+bPn49BgwbByckJMpkMrq6uKFu2LExMTCCTyb7KJxCJRNDR0UGpUqXg6uqK6tWro3HjxujWrRtGjRqFRYsWYffu3UrCXTw+5xccOHAAN2/exOnTpzFw4EB4eXnBy8sLlStXFszFLmlxdHTErl27vvh3yc3NRY0aNaChoSHorQb+HMfZvHlz5t/ExMRAQ0MDFSpUYL/TnDlzQPSnSK9CoUCHDh1AJBQu3b17N8RiMRo1aqTU53/w4EGIxWI0adJEpSPxHwQvfHv16lXB+pL8gn/dmPKUjqJiQir8d8CLP3Xs2FFJ6GvHjh0Ceg5QqARqa2sLY2NjRg0/e/YsNDU1Ua5cOUYf4lUi27dvz16aBw8ehFQqRfXq1QViKpmZmfDw8ICOjo4Snbgk1KtXD3K5HP369cPu3btx4cIFXLt2DXFxcTh27Bi2bt2KOXPmYPDgwczIf4rU1FQMHDiQ0d7j4+MRHh6Ow4cPs22WLl2Kfv36sZf88OHDWcCen5+P1atXo379+hg5ciQLBkeOHInffvsNBQUFmDZtGjp16oSbN2/ixYsXbHn16hXevn2L9PT0Ymd0fg04jsOwYcNQp04deHt7w97ens3jLCruxi+6urqoWLEiQkJCMHDgQMyaNQubN29GVFSU4D7u3bsXQ4cOhZaWFipXrvxFSqi8AI6Dg4MgKZebm4uGDRtCJBJh06ZN7Lz79esHIsLo0aMBFNK9fH19IZVKWWD87Nkz2NjYwMTEBPfu3WPf7dy5M4gIy5YtUzqPwYMHg4iwYMGCv3VvVfg5wdNZP6X4qQLsH+8TKBQK2NjYQF9fX0lMUIVfH2/fvoWjo6NAaJLHu3fv4OrqCl1dXVy/fh1A4bt4wIABICJMnDgRQGEwXrt2bUgkEuzcuRNAoW2wsrKCiYkJbt++DaBwRrOnpyfU1dWZWBWPBQsWfJXvuX//fhARfH19sWLFCpw6dQpXr17FlStXcObMGezZswfLly/H2LFj0a1bt2Lt2Y/2CTiOw4wZM/Dbb7/h1q1bePnyJfMJkpKSkJycjA8fPiA7O/tvKY8Dhba0X79+qFmzJjw9PWFlZYXatWujTJkySoUJiUSC0qVLIyAgAF27dsWECROwcuVK7N27FzExMbhy5QquXLmCEydOYPHixfDw8IBIJBKIppUEjuPQpUsXEBG2bt0q+Gz37t2QSCSoXbs2U/GOi4uDjo4OnJ2dmVjZihUrWBDOU8V532Hy5Mlsf3xhqFq1akotfNevX4e2tjYqVapU7FxzFX5tZGRkQCQSoVKlSkqf/bQBNsdxcHBwgLW19fe7Eyr8VJg4caKgb6ooFi5cCCJC165d2Qv//v37MDU1haWlJcvsHjt2DGpqavD09GQV6ilTpoCI8Ntvv7Fs4fbt2yEWixEQEMAq40BhEGVhYQFra+sSM7I8cnJyQEQYOXLk37ruV69eoVevXiwbf/36dfTu3RunTp1i2wwaNAgTJkxgf/fr148dd+nSpZg8eTJWrlyJESNGIC4uDgDQo0cPDB06FEDhqLP27dszR+VrjGZubi5evXqFW7du4eTJk9i2bRvmz5+PESNGIDQ0lBlLdXV1pSDa1NQUPj4+6NChA8aPH49NmzbhwoULSElJ+WrDzWeO/6qK/fbtW5QtW1YwwgUodDqaN28OIsLSpUvZfQgPDwcRYdCgQeA4DllZWahduzbEYjFTHX/9+jWcnJygq6sr0ILge/3GjRundB78iJh+/fp91XWq8OugUqVKMDIyUlL6VQXYP94nAIC5c+eCiJSYIyr8N3Dv3j0YGhoKRiXyKKqNweuq8OrSRIR58+YBKNR68fHxgUwmw/79+wEUKpabmZnBwsKCJUtTUlKYWnVMTAw7TtFe2SVLlvzlOXfr1g2GhoZ/K+nzs/sECoUC79+/ZyO79u3bh5UrV2LixIkICwtD48aN4eHhUWxiXU1NDa6urmjcuDEGDRqExYsX4+jRo3j48OFXJ/kzMjIglUpZpf9zmDBhQrG9/QcOHIBMJkO1atWYZs/ly5ehr6+P0qVLMz9ww4YNEIlEqF+/PmNajho1CkSEwYMHs/t348YN6OnpwdnZWUlv5fnz57C0tISVldVXsfFU+HXATxT5NIkD/MQBNvCnWEDRl4wK/x1wHMcE7dauXav0OT/aa9iwYWzdjRs3YGBggNKlSzPKLv/C9PLyYnMzx48fz2g9fCV7/fr1EIlECAwMFATZ165dg46ODlxcXD4rSKVQKKClpYWePXv+retOTU1FmzZtWBb2+PHj6Nu3r4DC1KlTJxYUAkCbNm2wbt063L59GyNHjkRqaiquXbuGYcOGMWfjwoUL8PLygoGBAXx8fATiLkVx/vx5NGzYEIGBgahevToqVaoEJycnmJubf3Zkl0wmg42NDby9vdGqVSuEh4djwYIF2LdvH27dusWM1ffC6NGjIRKJSpwVCgBpaWlMaOT06dNsfV5eHlq2bAkiwty5cwEUPm+8oEnfvn3BcRyys7MRFBQEkUiEjRs3AiislpQvXx6ampoCx4sfHRcWFqbknOzbt4/Rw/4N4TwVfjzi4+NZ4u5TqALsf8YneP78OYgITZo0Uf6BVPhP4MyZM5DL5ahRo4bSjOCEhAQYGhrCwcGBjerKz89Hs2bNQERYuXIlgMJ2n8qVK0Mul7O2oISEBJiYmMDCwoKJrL569QrOzs7Q1tYWCFXl5+cz5hM//7gkDBkyBGpqan/L/v3bPgG//3r16qFWrVqoWrUq3NzcGGOk6AzuTxdDQ0OUL18eDRo0QI8ePRAREYHNmzfj3LlzSEpK+tsV8aJISkoCUfEidUXBJ7s7deokOP7Bgwchl8vh6enJmACXL1+GgYEBbG1t2Uz2LVu2QCwWo3bt2sxXnDp1KohIwEJ49OgRzM3NYWlpqTTP/ePHjyhfvjx0dHQY60KF/x58fHwgkUhY7FEUP3WAzau11q9f//vcCRV+OuTm5qJOnTqQSqU4evSo4DOO41hvQ9G51TyVx8nJiRnZPXv2QCqVolq1aqySzQfZoaGhLFO6du1aiEQi1K5dW0DXOXXqFNTV1eHh4fHZvv+2bdtCR0fnbytDe3p6Mmpaz549leZyDxgwAEuWLGGZUw8PD1y8eBGDBw+Gt7c368Nu2rQpm7MYGhqKjRs3IjU1FZ07d8aKFSuUjBvHcTh+/DgqVqyIatWqwd/fH8HBwWjVqhW6deuGwYMHY+LEiVi8eDG2bduG48eP49atW0hJSflHe4cyMjJgYWGBwMDAz27j5+cHiUSCPXv2sPW5ubnM4eIVPzmOw8CBAwXBdVZWFurWrQuRSMTmdr5//x6enp5QU1PDsWPH2D5XrVoFIkKrVq2UAugLFy5AQ0MDlStXVlHA/sPgK2XFOUuqAPuf8QmAP9VaExMTlT5T4b+Bbdu2gYjQsmVLJbsTGxsLLS0tlC9fHqmpqQAK2WX169eHSCRi7Rupqanw8PCAmpoao1rHx8fD1NQU5ubmSEhIAAC8fPkSTk5O0NLSEhRzMjMzUaNGDUilUkRGRpZ4rufOnRMkcr8V/4RPsHLlymLtOMdxqFOnDjw9PVGjRg0EBQWhadOm6NixI/r27YtRo0Zh1qxZWLduHQ4cOICLFy8iMTHxh80NLwm///47iEhpCk1R8Fo8jRo1ErAK9u7dC5lMBk9PT/bcxMXFQV9fH7a2towVyQfXfn5+zJ7zbLp27dox+//y5UvY29vD0NCQPUs8cnNzERAQUKxfq8J/B1evXmWJnOLwUwfYAJhohYpe8d/Fx48fUaFCBWhrayuNyFAoFGjbti2ICIsXL2brY2JioKWlBRcXF9Yvs3v3bkilUlStWpVlJydNmsQCI57WuWHDBojFYvj6+gqEzw4fPgy5XI5KlSqVOBIiISEBUqkUHTp0+FvXfPDgQfj4+CAwMBB169bF8+fPceHCBdy/fx9A4Viotm3b4uPHj4iPj0etWrXw7Nkz3LhxA926dYOBgQHU1NSgq6uLiIgIHDx4EF26dMGTJ0+Qk5MDX19f6OrqwtPTkzmijx8/hlwuh4eHBzw8PNC7d2+l82rQoIGSgNO/gUGDBoGIBFWFokhPT0etWrUgFouxZcsWtj4rK4s54DxlUKFQoGfPniAqHNPBcRwyMzMRGBgIkUiE1atXAyisenh5eUEmkwnEb7Zt2waRSIR69eopUYPv3bsHY2NjlC5dutixJSr8N/Du3TsQEXx8fIr9XBVg/3M+QUxMzHdp1VHh5wavd8C/s4vi2LFjkMvl8PLyYgynrKwsBAQEQCwWY/PmzQAK/289PDwgl8vZOz0hIQHm5uYwMTHBtWvXABRWRl1cXKChoSHoe/748SOzCbt37y72PDmOQ2BgIHR1df9W0ufv+ASzZ89Gs2bNoK6uDolEgnr16gl8AgDYtWsX7O3tYW9vj2rVqv1yfkF8fDzU1NTQrl27ErdZvXo1RCIRgoKCBOyHrVu3Mt+QD67PnDkDHR0dlC5dmlWfN27cqBRc8+2KfB82UNhe4OrqCm1tbZbM4KFQKNC+fXsQEdatW/dd74EKPxf42ITXdvgUP32AnZSUhFOnTn13+qkKPxdevnwJW1tbmJqa4sGDB4LP8vLy0KhRIyUq+alTp6CpqQlXV1cW3OzZswcymQweHh6syswb6gYNGjC6z44dOyCTyVCxYkVBYHTo0CGoqanB3d2dVcc/BV8ZX7FixTdfL8dxuHXrFg4cOMD6hnfs2IHjx4+zDHOfPn3g5OQENzc3tr6goAAODg5ITEzE3LlzYWFhgTt37uDFixcoV64cXr58icWLF6NatWro27cv1q5di9atWwMoFHspX758iee0Y8cOdOjQ4bPb/BPYvXs3iAh9+vQp9vPU1FR4e3tDIpEwRwooDJBr1qwJkUjE1D3z8vKYsRsxYgQ4jkNaWhr8/PwgEonY88QH11KpVFAN37t3L6RSKWrUqKEkXpKUlAQ7OzsYGxszJ0iF/yb4dhWebvopVAH2P+cTAMDp06dZL60K/00UZR19Ws0FgMjISEgkEvj5+bF3c0ZGBku8Fg2yPT09IZPJmAL1vXv3YG1tDX19faZa/ubNG1SsWBEymUxAC//w4QOqVaumZG+K4tGjR9DW1oaXl5cSrf1rrvdbfAIAzC8YP348IiIiUKFCBZw6dYr5BADQuHFjlCtXDh8/fsT27dt/Kb/gw4cPKFu2LExNTUtMZPOBcFBQkKAFcOnSpRCJRKhZsyaj8R4+fBgaGhooW7YsazVctWqVEruR32dISAhLrr9//x4eHh5QV1cXqNoDQo2XiIiI730bVPiJkJycDCKCv79/idv89AG2Cv87uHv3LoyMjGBvb4+kpCTBZ9nZ2QgMDFSqWJ48eRKamppwcXFh3+GDZDc3N2Zcli1bBpFIhBo1arDqNv+SdXBwEKiWRkdHQ0tLC6VLl1YK9oFCY1a3bl3B2KcfBV5chMfZs2dRt25dKBQK7NixA61atcKUKVMAFIq7ubu7Q09PD1WrVkVCQgLy8/NhbGwMoNCQ8lnoTysCaWlpqFGjBu7evYty5cr90Gv6HPixaiU5Ki9fvkS5cuUgk8mYSixQGOyWL18eUqmUiU1kZmaiYcOGICJ2j1JTU+Hl5QWJRMKeo/fv36Nq1apKwTXPaKhatapSf8379+9Rvnx5aGlp4dKlS9/9Pqjwc6FUqVKwsrIq8XNVgK3yCVT4/ijKYOOZRkWxZcsWpqvC24uiQTavq/H+/XuWlOWnezx9+hSOjo7Q1NRkVev379+jevXqEIlEWLhwITtOWloaatWqBZFIVOKECD4xXLTS+SPwqU+gUCgQExODunXrYseOHdi0aRMiIiIwbdo05hNUq1YNRkZGzDb+Sn5BdnY2o1t/GtAChec8btw4ps3A09Y5jmPFkODgYBZ0b926lRVXePbjvHnzQESoV68e245f16RJExZcF2U0FGU68Jg+fbqgDU2F/y4iIiJARIiOji5xG1WArcJPhbi4OKX+Kh6ZmZms55ZXewYKqxlaWlooU6YMU4A8ceIEC5J5wY9t27ZBJpOhXLlyLGsZGxsLQ0NDmJiYMOVNoLCv1sjICCYmJrhw4YLSeX78+JFlMT8nHPItKCgoKJGivmPHDnTv3p39vXHjRoFq9YcPH+Dk5CQ4Z0dHR6SmpuLJkyfQ0dFBpUqVUKtWLTbKAyjs7zpw4ACePn36l4b09OnTuH379nd3Ii5cuAA9PT04Ojoyw1cU8fHxsLGxgba2tuCeJyQkwMbGBlpaWqzf6e3bt/D29oZIJGLCMElJSShXrhzkcjnrqUtJSUGlSpUgk8mwd+9ets+oqCioqanBw8Oj2OewevXqkMlk3/23V+Hnw+XLl4tVoy0KVYCt8glU+DHIzc1FUFAQxGJxsb3Qa9euZQw1PrjKyMhA7dq1IRKJsGrVKgCFbUW1a9cGEbHg+fXr16hYsSKkUikLvLOyslhr4rBhw1iVOCsrCyEhISAiDB06tNheZr5Xt3379t/dPn5OB+VzfsGHDx/w8OFDODk5CQoX39Mv4NXFP7WVfxfZ2dkIDg4GERU79jQvLw9du3Zlgrb8Pc/Ly0O3bt1Yfyzfi80HzTVr1sSHDx8EQXjTpk3Z88MHys2aNWPBdVpaGnx8fJQS8Tx4nZbWrVurZl3/D8Dd3R2Ghoaf3eaXDbCzs7ORm5uLR48eCeggKvz6iIqKgkwmg7e3t5JoVHp6OqpXr64UZJ87dw66urqws7NjAXVcXBwMDQ1hbm7OhImOHTsGbW1tWFtbs9nXd+/ehb29PTQ0NARVUX69urq64Fg8UlJSUL58ecjlcsH3vhWXL19G8+bNoa2tjQYNGhS7zc6dOz8bYAOF//hFDamDgwNSU1ORl5fHes6vXr0Ka2trpKen4+rVq2jUqBGAQkP5V71W+vr6ICJoaGggODhYoN79reB/FwcHBzx79kzp86NHj0JXVxfm5uaCsVnHjh2Dnp6eYP3jx4/h7OwMNTU19rs8ePAA9vb20NLSYuJlSUlJcHNzg7q6uoCJEBUVBXV1dZQvX14p0ZGbm8vEdIp7JlT474F3qj+dzVsUqgBb5ROo8OOQnp4OLy8vyOXyYitG/Lzi4OBgVsnmRSyJiFWds7OzWfA8duxYcByHjx8/ssB7ypQp4DgOBQUFTGC1ZcuWjIJeUFCA3r17s0p1caKW/JjQkJCQv/0cpqenY9y4cXB0dAQRsX7qT/Fv+wX8+CoigqurK6ZNm/bNVHkeHz58gL+/P4iItXwVRWpqKurUqQMiwujRo1nF+P379wgKCgIRYdSoUeA4DgqFglG3mzZtiuzsbBQUFKBPnz5sMgQ/53rs2LEgIrRp04YF5mlpafD19RXMVy+KHTt2QCwWo27duko6LSr89xAbGwsiYnPoS8IvF2Cnp6fj0KFDmDhxIho1agRnZ2e0a9euxCZzFX5N7Ny5E2KxGHXq1FF6UaelpbEgu2iv1KVLl2BoaAgLCwvEx8cDKKxuWllZQVdXlyl0Xr16FRYWFtDV1WXBVnJyMqpVqwYiwuTJk9nL+s2bN/Dx8WEVrE8zk6mpqfDx8YFIJMLMmTO/mRZ08eJFyGQyGBsbIywsjPWKfYrY2FjUq1eP/T1lyhSlkRX16tVjFez8/HyYmJgUu69atWrhypUrWLp0KaysrGBvbw9ra2uoqamhdu3aJZ7ruXPnsGHDBvTv3x+lSpWCSCT6W0qZq1evhlQqFfSL8eA4DnPmzIFYLEb58uWZMAvHcVi8eDEkEgnc3d2ZSMmFCxdgamoKAwMDlom/dOkSTE1NYWRkxFgKT548gYODA7S0tHD8+HF2vCNHjrDg+lOl+Pz8fDb66+/036vw64AfC1VSwouHKsBW+QQq/Fi8e/cO7u7u0NLSYuOsimL58uVs6gzvM+Tk5LAE2aRJk8BxHPLz89G5c2cQEXr06IH8/Hzk5OSgXbt2bAxTXl4eOI7DrFmzIBKJULlyZcZ64zgOs2fPhkgkQsWKFZXGMwHAggULIBKJ4O3t/c3ilxzHoXr16iAi1K1bFzNnzixxesm/7RckJiZi3759mDJlCvz8/Fgg+614+vQp3N3dIZVKGc2/KG7fvo0yZcpAJpMJdHkePHiAsmXLQiqVspaCzMxMtGjRgum6FBQUIDs7m9nyIUOGsCB8wIABrBrOq4V//PiRjWIqLql+6NAhyGQy+Pj4qKaI/I+AF9ItrhhUFD80wCaiekR0j4geEtGIv9r+r4xpeno6hg8fjo4dOyIiIoIpDEdERKBt27ZfcXtU+BWwbt26YsctAH/2BonFYgF1KD4+HhYWFjAwMEBsbCyAwnFvrq6ukMlkTKQkMTER7u7ukEgkjEKcnZ3NBLFatmzJXpY5OTnMIDdo0EBpjFdWVhZ7gbdv3/6bXrKjR48GEeHNmzef3Y4XM3n69Clyc3NRoUIFJojCY8mSJQgLCwNQ2KPWpk0bAIUVdz5L++DBA1hbWwtU1Pl79TVqoRkZGd88Gzw3Nxd9+/YFESEwMFDpXNLT01n/XdOmTZnQYU5ODnr06AEiQsOGDVl/9NatW6Guro7SpUuze3LgwAFoaWnBzs6Ojfa4desWSpUqJXhGAGD//v2Qy+WoUKGCkiOjUCjYzPbZs2d/9bWq8Guif//+ICJBEqY4qALsH+MXqHwCFYri1atXcHR0hJ6enoDJxIOvZAcGBrKqc35+PkJDQ0FEGDhwIBQKBTiOYyOfgoODkZ6eDoVCwSqxtWvXZnZ+79690NbWhrm5uSCwP3jwIPT09GBkZFRsgnnnzp3Q0NCAlZVVsW1mX3KtRITx48f/5bY/m1/AV4a/hSYfHR0NY2Nj6OnpCUZl8tixYwe0tbVhamqKmJgYtv7IkSMwMDCAkZER69VOSkpClSpVIBKJMHv2bHAch3fv3qFGjRqgImM88/Ly2BhG/hkBCgsovD5LcZXrkydPsvGuRfviVfjv4vHjx4yh8lf4YQE2EUmI6BERlSYiORHdICLXz33nc8Y0Ly8PrVu3RocOHZSEp5YvX46+ffsqzadV4dfH4sWLQURo0aKF0su6aJ9VUQrR48eP4eDgAE1NTRw8eBBA4YuSz6zymeyPHz+yTFSfPn1Y1nr69OkQi8Vwd3dn6tB8xVQmk8HW1lYQmAGFAdikSZMgEong6uqKmzdvftV17tu3j1XP/wqHDx+Gk5MTHB0dMXXqVADA2LFjsX//fgCFgWurVq1QpkwZVKtWjWXYd+3aBTc3N3h4eMDT07NYgbanT5+iQoUKX3zefNWgOAGaz+Hhw4eoUqUKiAiDBw9W+m1v3boFFxcXiMViREREMIP3/PlzxjQYMWIECgoKoFAoWIKievXqLEmxePFiiMViVKpUiVHjYmJioK+vDwsLC9YiABQabalUisqVKyslUDiOY2O+JkyY8FXXqcKvi8zMTDZR4K+gCrB/jF+g8glU+BSJiYmwtbWFkZFRsXZ27dq1SqrRRauT7du3ZzTeJUuWMBvBs6fWrVsHmUwGR0dHxoK4desWSpcuDZlMhiVLljCm2v379+Hu7g6RSISxY8cq2bGrV6/Czs4OMpkMc+bM+areXIVCAVtbW7i4uHxR8Paz+AWvXr2ClZUVKlWq9GUX+v/Iz8/H2LFjmQ/16ZSAnJwclvD08vJiejsKhQIREREQiUQoV64cHj9+DKCQFWhpaQktLS3WM833osvlcib4lpGRwfzAiRMnCtiLvKp8cT3XZ8+ehZaWFlxdXf+yMKLCfwdhYWGgz4yQLYofGWB7E9HRIn+PJKKRn/vO54zpkydPULVqVfZ3Xl4e7t27h/Xr16Nu3bpKAY8K/x3Mnj0bRIS2bdsqOUxF5x4XrSy+fv0alSpVgkQiYcFfTk4OOnToACJCaGgo68MZMmQIiAh+fn6MznX06FEYGRlBR0cHO3bsYPuNi4uDnZ0dpFIppk6dqnQ+UVFRMDMzg5qaGmbPnv3FDh7HcaxaO3LkyB+qQvo9oFAoMH36dDZz8muuc+XKldDW1oa+vr4SFZ5PZKirq8PMzExQOTx69CiMjY2hra3NfpMPHz6wEW5dunRBTk4O8vPzmSFu2LAhq3zv3r0b6urqcHJyEvSyrV27tti56Pz58FX2kSNHqpRB/4cwd+5c0BfOMlUF2D/GL1D5BCoUh0ePHsHS0hImJiZISEhQ+pyfe1ylShWmo8FxHOuPrlOnDgu+eZaTlZUVrl69CqAweDI1NYWOjg4LrlJTU5mv0a5dO2ZXMjIyGLupRo0aSpTxd+/esb7vwMDAv6SVFsXx48chk8lQvnz5Yiea/Gy4ceMGa726fPnyF3/v3r17LHHesWNHpbG8d+7cgYeHB+j/56LzCZK3b98yEbS2bdsy9uC6deugpqYGW1tbpr9z+vRpGBkZwdDQkLWPvXnzBlWrVoVYLMayZcvY8V68eAEXFxeoq6sXqxZ+4cIF6OjoKAnHqfDfRlpaGiQSCTw8PL5o+x8ZYLcgolVF/g4lokWf+85f0cHKlSuHRYsWYeXKlVi7di1GjBiBXr164cCBA192d1T4ZTFt2jSWff40mMvNzWUU7aJiF2lpaUzsghc04TgOEydOBBHB29ubzbreuHEj1NXVYWlpyRyzxMREeHl5gYjQq1cvJljy/v171r9To0YNljHlkZycjMaNG7Nj8P3gf4W8vDx0794dRARPT0+BqvnPhFu3bjE2QFEBmL/Co0ePEBgYCPr/2YF8PzWPpKQk5sDUrVuXJTvy8vIwYsQIEBHc3d0Z1fvmzZtwdHSEVCrFwoULwXGcQOBk4MCB7FmZP38+RCIRqlWrJqB/80FUYGCgErWf4zgWqIeHh6uC6/8x2NraQk9P74uqTqoA+8f4BSqfQIWScO/ePZibm8PU1LTYfvt9+/axcZ18/zRQmFCVSqWoUKECq4Jev34d1tbW0NTUZFTgZ8+eoXLlyiAi/P7774wtNWnSJIjFYjg5ObGAHCj0IXR0dKCrq4v169cL7AXHcVi2bBm0tLSgo6ODpUuXfnE1++jRozAwMICWlhZmzZr1U4poZWZmYvz48ZDL5TA3N//i5FZeXh6mT58OdXV16Ovrs6oyD4VCgYULF0JDQwNGRkaCSR8xMTGwsrKCTCZj9j8nJ4eJ0/n7+7PK8sqVKyGTyeDs7MxYiffu3YODgwPU1dUF6vT379+HnZ0ddHR0ihVwjYuLg66uLhwcHATPlQr/fcyaNQtEJBgV/Dn86wE2EfUgostEdNnGxuazJ3vx4kUMHDgQ3bt3x6hRo7Bo0aJfIqunwvcBP3euXbt2ShXegoICNpYhLCyMBVZ5eXmsf7pDhw5sDMOOHTugqakJS0tLXLx4EQBw7do12NvbQyqVYu7cueA4Drm5uUx90s3NjWVDOY7DunXroKurCy0tLSxcuFBgMDmOw4YNG2BkZASpVIphw4YhLS3ti67zjz/+gLm5Oej/x0T8LHOW4+Pj0aFDB4jFYhgYGGDVqlVfFHRmZWVh4sSJUFdXh46ODhYvXqx0r9avXw8DAwNoaGhgwYIF7PP79++jatWqTHyGD+bXrl0LDQ0NmJubsz6shIQEJnyycuVKAIW0s379+rGeGf77HMexfrtmzZqx56LoOfHfGzRokCq4/h/DgwcPWJLmS6AKsL+fX6DyCVT4Uty5c4cF2cUlsk+cOAEdHR3Y2dkJKMdHjx6Fjo4OSpUqxYLkV69ewdvbmyXqFQoFsrOzmV/h7+/PEvInT55EqVKlIJfLMWvWLGavHj16xITJGjVqpBSAPXr0CAEBAYzm/KW2/dmzZyz5bG9vjxUrVvxtle7vgfT0dMybNw+lSpUC/b/y9pfSpU+dOoVy5cox2/ypwOnDhw9Rq1YtJlzHV4rz8vIwZswYiMViODg4sHv4+PFj1nY2ZMgQ5OfnIy8vj/WDBwUFMar96dOnYWhoCGNjY0Ey4OrVqzA1NYWxsXGxv01sbCx0dXVRunTpr2IiqPDfgI+PD+Ry+Rf7g78MRbwocnNzVQ7v/yimTp0K+v9Zg58Kn3Ech+HDhzNRLL7izHEcJk+eDCKCr68vm7F87do12NnZQU1Njc3KTE1NZdXnJk2asH7cI0eOwNzcHHK5HNOmTWMBfGJiIhsFa9r/HAAAYf5JREFU4u3tLejrBQopSHyAb25ujuXLl38R/TstLQ1jx46Frq4uiAg+Pj5Ys2YNo7X9U8jIyMDmzZvZGBNNTU2Eh4eXOKe7KBQKBTZt2gRbW1vWR89XDHg8fPiQVbV9fHyYE8RxHJYsWQJNTU3o6+sz9c709HQmRlKrVi3m8PDCJ2ZmZizg/vDhA+rXr8/6vIsmXbp06QIiQvfu3ZUYEQqFgmXBBw8erHrX/A+Cf8Zu3LjxRdurAuwf4xeofAIV/gp3796FhYUFTExMiv1/vXz5MkxMTGBsbCxghd24cYNVrfkKZk5ODrMNDRo0YHOd+YSumZkZoqKiABSKg/EK5bVq1WJtRwUFBZg9ezY0NDSgq6uLxYsXC2wMn3w3MzODSCRCp06dlNhcxYHjOBw+fJhV1Y2NjTF06NAvZsh9L3Ach8uXL6Nfv37Q09MDUeFc6aLzsz+He/fuoXnz5iAi2NjYKM02z8vLw7Rp09j9W7lyJfvfvn37Nrv+jh07sqLFzp07oaenBz09PezevRtAISOOT3aEh4czv2vt2rWsml109CI/KtTGxoax5Iri7Nmz0NHRKXGUqAr/bdy5cwdEhSPdvhQ/MsCWEtFjIrKnP8VM3D73nS8xpkeOHClWUOnIkSPYtm0bEhISlESKVPhvYebMmSyI/rTyCADz5s2DSCSCr6+vIBDcvn071NXVBX05KSkpLMDr3LkzsrKy2GgomUwGS0tL1geckpLCDIOXlxfr/eIrsEZGRpBIJBg8eLBSIBwXFwdfX18QEZycnLB58+Yv6lv+8OED5syZAycnJxAR5HI5goODsXTpUjx48OCHOJVPnz7FqlWr0KxZM2hqaoKIYGtri4iIiBLHhBQFx3HYs2cPKlSoACJCxYoV2Yg0HpmZmRg7dizU1NSgo6MjYAA8fvyYZfkDAwNZUB4XFwdHR0eIxWKMGzcOBQUFyMvLw6BBg0BEqFatGqsY3L9/Hy4uLpBKpQIBvPT0dBZ0jxkzRun+FRQUMAdr+PDhKqf9fxBv375lybgvhSrA/jF+gconUOFLcP/+fVhaWsLQ0LDY3t/79++jdOnS0NTUxL59+9j6pKQkxpDix3PyyV2ZTAY7Ozu2v1u3bsHV1ZVVSHNycsBxHFatWgUdHR1oa2tjyZIlzI49ePCA2bHKlSsrKYl/+PABw4YNg5qaGtTU1DBgwIAv6uflOA7R0dEICQmBRCIBEcHFxQUjRozAyZMnf0hlOzMzE1FRURg8eDAcHByYL9KmTZtiR6YVhydPnqBbt26QSCTQ0tLChAkTlNrLjh8/zu5xSEgIs+f5+fmYNm0a1NTUYGRkxDRY0tPTWVtdlSpV8OjRIwCFDAMzMzNoamoyOm9+fj5jIwYEBLDkCQBs2LABMpkM5cqVK5b2feLECWhpacHJyUlFC/8fBa+R9DXjH39YgF24b2pARPepUDV01F9t/yXGNDc3l1E3srKyMGfOHFSuXBkNGjRAeHg4goOD0blz5y++ASr8mliwYAHr1S2uB/iPP/6AmpoanJycBFnKS5cuwdLSEpqamti+fTuAwqCKpwuXK1eOjbe4fPkynJycIBKJEB4ejuzsbHAchy1btsDIyAgymQzjxo1jQX5KSgq6desGkUgEU1NTpWo1H3i6u7uDiFCmTJkvpnpxHIfY2FgMHDgQ9vb2ICJWFW/SpAnGjRuHP/74A1evXkVqaupfBoYcx+HDhw+4ceMGdu3ahUmTJqF58+awsrJi+7a0tESvXr1w8uTJL+oXy8vLw6ZNm1C+fHkQERwdHbF582bBdxUKBTZv3gxra2tGKeOpYfn5+Zg9ezY0NTWhra2NZcuWgeM45OXlYcKECZBIJLC2tmZ9UU+fPmXCKH369GG9aYcPH4a+vr5gXAcAvHz5Eh4eHpBIJIKgu+j5t27dWtCzr8L/HsaMGQMiKnEWfXFQBdg/xi9Q+QQqfCkePXoEOzs76OrqCsY38Xj9+jUqV64MsViMxYsXs/VZWVlsPGeLFi2YwFZsbCysra0hl8tZj29mZiZTEa5QoQJTMX/69Cnq1KnDdFl4J5zjOGzevBkWFhag/xdX/bT6mZiYiC5dukAikUBNTQ29evX64jaH5ORkLFy4EP7+/pBKpSzwrVatGvr06YPly5fj1KlTSExMVGL8FYfc3Fw8efIEx48fx+LFi9GzZ09UrlxZsO+6detixYoVX5y0io+PR6dOnSCVSiGXy9G3b1+l2eAPHjxA06ZNQUSws7MTJEEuX76MSpUqsaIK/93Y2FiUKVMGIpEIw4cPR25uLgoKCjB58mTWI88zCt+9e8e0Wfr27cvuBcdxmDRpEmsB+FTkFCgcx8b38vOMORX+t5CcnMyeka/BDw2wv3b5UjoYUEhdHTBgADp06ICjR48iJyeHiRTZ2toKVIJV+G9i1apVEIlEqF69erEvxpiYGBgaGsLIyEhgcF+9egUfHx9Brw5QGJgZGxtDU1MTq1evBsdxyMjIYAbVxcWFUcySk5NZRsvJyYnRxoDCIJ6vVru4uGD37t2CYE2hUGDnzp3MaJiYmGDUqFFK6qMlgeM43L17F0uXLkX79u3h7OwMkUjEAmPeEJYqVQply5ZFxYoV4enpCQ8PD7i4uMDS0hLq6uqC7YkIDg4OaN26NRYsWIBbt259cYD56tUrTJ48GZaWluyaN2zYoJRcOHr0KLtmDw8PgYDIuXPnWMW7YcOGzAm5desWPD09We8930O1c+dO6OvrQ0dHhyVKio7rKF++vEB87vr167CysoKWlhYb3VYUWVlZTI10+vTpX3TdKvz3kJubC0NDQ9jY2HxVgkUVYKt8AhX+fTx79gxOTk7Q0NDAkSNHlD7PyMhgUyeKimByHIcZM2ZALBbDzc2NtSoVVakOCQlhjLh9+/bB1NQUcrkcERERbMTn6tWrYWBgAJlMht9//509f2lpaRgxYgTU1NSgrq6OYcOGKbVZPXz4EN26dYNcLodIJELjxo1x5MiRLxZD+/DhA/bt24chQ4agevXq0NbWFth3kUgEQ0NDlC5dGuXKlYOnpycqVaqEcuXKwd7eHgYGBko+gb6+Pvz9/TFy5EgcPnxYSQi0JOTn5yMyMpKxAzU0NNC/f3+lFrGkpCT07t0bUqkUWlpamDx5skBMtl+/fhCLxTA3N8eOHTvAcRyys7MxfPhwiMVi2NjYsCT6y5cvWStb27ZtGX2c19aRy+VMmwUobAfgW4FCQ0OLFY/btm0bpFIpPD09v4i9p8J/E7zILj/y7kvxywbYa9euRdeuXZWyga9fv0ZoaKiSsrMK/01s374dUqkUFStWVMqKAoXUsDJlykAulwvG7eTm5rI+Wz8/P0bNevnyJfz9/Vk2u2gPtpWVFcRiMcLDw1nV/OjRo3B0dGTZVb5aznEcdu3aBWdnZxARKlWqhL179yopix4/fhyNGjWCSCSCSCRC3bp1sWXLli82ZDwyMzNx7do17NixA3PmzMHQoUPRpUsXtGzZEo0aNUKDBg3QsGFDtGjRAp07d8aQIUMwc+ZMbN++HZcvX1Yai/FXyM7Oxq5du9CkSRNGU6tTpw4OHDigJGB27Ngx1gtla2uLDRs2sG1evHiB0NBQEBGsrKywc+dOJi43fvx4yGQyGBsbCyhhXbt2ZZQw/n6/e/cODRs2ZMa16P3bs2cPtLS0YGlpiWvXrildy4cPH1CjRg2IRCIsXbr0q+6DCv8trFu3DkSEOXPmfNX3VAG2yidQ4edAcnIym1+8bds2pc8LCgrYTOwGDRoI2rmioqLYeE5eTVyhUGD27NmQyWQoVaoUoqOjARTqq/DTRDw8PJhYWnJyMgvcrK2tsWXLFmb3nz59itDQUIhEIujo6GD06NFKgXZSUhJGjRoFExMTVtEdN24cU7/+UigUCjx9+hRRUVFYvnw5xo0bh969e6Nt27Zo0qQJGjRogAYNGiAkJATt27dHnz59MGHCBKxatQrR0dF4/vz5V7O4bt26heHDh7OKvaWlJSZPnlzsNQ4ePBjq6uqQSqUICwtj1eGCggKsXLkSJiYmEIvF6N27N0usx8TEMJ+qa9eu7LeLjIyEkZERNDU1mfgqn/D4dDoMUPjb8T7JhAkTir3OpUuXQiQSoUaNGsUWcFT430B2djZTjf/a/4dfNsDu2LGjgOaTl5eHuXPnonTp0pg4ceJX3QQVfm0cOXIEmpqacHBwENDBebx7945lNsPDwwW9zxs2bICmpibMzMyY4SwoKMDUqVOZQeXnIH748AE9evRgRo+vhGZnZyMiIgJaWlqQy+UYPHgwMyj5+flYu3YtSpcuDaLCMVPr169XypY+ffoUY8aMgY2NDej/xcRatmyJzZs3/zT9gx8/fsSOHTvQoUMHJr5mbm6OoUOHChRagcJ7uGPHDqbqaWlpiUWLFjE6fVpaGsaMGQNNTU3I5XKMHDmSBfmnT5+Gi4sLo5DzqqQxMTEoXbo0RCIRRo4cyWheFy5cgK2tLWQyGRYsWMBeghzHsYp2lSpVlFRKgUJDX6FChRKdMRX+t+Dm5ga5XP7VfYyqAFvlE6jw8+D9+/csabpo0aJit1m6dCkkEglcXFwElOzExETWl92/f39ms65cuYKyZcuy9XySfefOnTAzM4NEIkF4eDizYzExMWx2c7Vq1QQsuvj4eDZaVEtLC4MGDVJisOXk5GDr1q2oU6cOY6hVqlQJkydPxs2bN3+KFiaFQoHLly9j7NixcHNzAxFBIpGgYcOGiIyMVBJ0vXfvHnr27Ak1NTVIJBJ07NhRUJQ4cuQIazHz9fVlSYt3794x38vW1hZHjx4FUOiT/fbbb+ze8OJkRWeTBwQEMGFboJDNZmtrC3V1daWxYPx5TJgwgbHp+Iq6Cv+bWLFiBYioxPfI5/DLBtgHDhyAm5sbNmzYgG7duqFMmTJo3bq1kpCECv8biI2NhaGhIUxNTYsVOcnLy0Pv3r1BVDiuoWjQGh8fDxcXF4hEIvz+++8scLt69SozGl27dmVZzNOnTzNDGxISwqiHL1++ROfOnSESiaCnp4fJkyczmlJ+fj42bNjA9mdhYYEJEyYoiZooFAqcOnUKYWFhMDMzAxFBLBbD29sbo0aNQlRU1D+mJJ6RkYGTJ09i/PjxqFmzJuvDMjQ0ROfOnXH06FElA5qSkoIZM2bAzs6O0c6XLVvGnBS+R5LPzrdq1YoJkyQlJbFqtq2tLUtgZGZmYtCgQRCJRLC3t2dqpQqFAtOnT4dUKoWtra1AITY9PZ1VF9q1a1eskbx79y7s7OygpaXFDLYK/7u4cOECiAgDBgz46u+qAmyVT6DCz4WsrCw0adIERISRI0cWG5AeP34choaG0NfXZ4l0oJDhxle5K1WqxBLImZmZbHxjmTJlWNCcmprKxLasrKywfft2cByHgoICrF69mo2xatiwoYBFxY++lEgkkEgkaNGiBU6ePKl0rs+fP8fs2bPh5eXF6NtWVlb47bffsG7dOjx8+PAfCbgVCgXu3LmDlStXon379gIfpWbNmli0aJEgmOW/c/jwYUa1V1NTQ48ePQTFkHPnzjHmoL29Pbt/CoUCq1atgrGxMROQ5RMYhw8fZqzCUaNGsaLFtWvXWNvc2LFjBQWV7du3Q1NTE6VKlWLjWYsiPz8fPXv2BBGhU6dOX9S3rsJ/FxzHwcnJCZqamt80f/6XDbCBQjrPhAkTMHz4cDx58gSJiYn4448/sGjRIowbN+6rRGpU+PVx584d2NraQktLC4cOHSp2m5UrV0Iul8POzo5lR4HCYJJXj/by8mIZ7aL9PqVKlcKePXsAFBrgKVOmQFNTE2pqahg9ejQLpm/dusVGfRkZGSEiIoIF5/yoDX60l0QiQUhICPbu3av0MlcoFIiNjcWYMWPg5eXFqNgikQguLi5o164dpk2bhj179iA+Pv6raeU8srKycOfOHezfvx8zZ85Ex44dUb58ecHxKlWqhOHDh+P06dPFziA/evQo2rRpA7lcDqLCsR27du1ixi0jIwNz585l1LGAgAAWEGdnZ2Pq1KnQ1tZm1Wy+OhAdHc2q/7169WLG9cWLF0yhtXnz5gJF0Hv37sHNzQ1isRgzZswo1vE4e/YsS8j8LHPGVfh3cebMGfTu3btYpsNfQRVgq3wCFX4+5Ofns8pnSX22jx8/Rvny5SESiTB58mRBi9OePXtgaGgILS0tRj0GCgNzOzs7iEQi9OnThyW9z58/j4oVKzIbeOXKFQCFgfmUKVOgr6/PEvNFCwHPnj3D0KFDWR+0s7Mzpk+fXuy76OXLl1i5ciVatGgBQ0NDFnAbGxsjKCgIQ4cOxdq1a3Hu3Dm8evXqi3u4i0KhUODFixc4c+YMVq1ahcGDB6N27drs/IkIZmZmaNu2LdavX1/s7OsnT55gwoQJLNluZmaGsWPHClr5YmJimPiYqakp5s+fz5LxMTExbCSXj48Pm/zy9u1bRsEvqovDU/nlcjksLCzY9BdAqCDu4+NTrFhZ0f78ESNG/BQMARX+XSgUCoSHh2PDhg3f9P1fOsAuii1btqBbt24YOnQoZs2ahSFDhsDf3x+bNm365n2q8OshKSmJKUWX1E974cIFWFlZsfnXRV+k27dvh76+PrS0tLB8+XL22cWLFxl1qWnTpqzP7/nz50zszMzMDEuXLmWB8sWLF9GgQQMQEXR1dTFs2DCByMf9+/cxdOhQmJqasmC8Z8+eiI6OLjZzmpaWhqNHj2L8+PFo2LChQPG7qCiJs7MzfHx8UL9+fbRo0QLt27dHaGgo2rdvj5YtW6JBgwaoXr06XFxcBAaaXywsLFCvXj2MHj0aBw4cEASvPAoKChATE4P+/fvD3NwcRAQDAwP069dPMAs8OTkZ48ePh5GREYgK54WeOnUKQOHLa8OGDYwW36hRI5bYKNrH5ujoKFAD37p1KwwMDKCpqYkVK1YIfr8dO3ZAR0cHxsbGOHbsWLG//7Zt24pVmFdBhW+FKsBW+QQq/Jz4VCm6OHuWkZGBdu3agYgQHBwsYLg9f/6cVVdDQkJYhTY9PR39+/eHSCSCpaUl0w8pKCjA0qVLYWxsDJFIhNDQUEb/fv/+PcaNG8fmRwcGBuLo0aPMhmVlZWHdunVMJFUkEiEgIADLly9XqgwDhTb05s2bWLp0KTp37oyKFSuyJDe/yOVy2NjYwNPTEwEBAQgJCUGbNm3QoUMHdOjQAW3atEGTJk1Qu3ZteHh4wMrKirHV+EVdXR2VK1dGjx49sGrVKty5c6fYAPTFixdYsGABO38+mb5t2zaW3CgoKMCePXvYNiYmJpgxYwYrENy9exfNmjVjrWUbN25kPdVr166FsbExpFIpRo0axYLxxMRE1gbYuHFjgSBZUlISatasCSLhtJGiSEpKKlZhXgUV/g7+EwH2xo0b0blzZ+zcuRN3795l1a/IyEg0bdr0m/apwq+LtLQ0FtgOGjSo2HnTb968YWM1QkNDBSJfz549Y9XRevXqsWA6Ly8PU6dOhYaGBrS0tDBt2jT2gr9w4QITzXB0dMSmTZvYca9cuYJWrVpBLBZDKpWiTZs2iImJYQYqLy8P+/fvR5s2bdjcaQMDA7Rt2xYbNmz47GzMDx8+IC4uDps3b8bUqVPRp08ftGjRArVr14anpydcXFxQunRp2Nvbo3Tp0ihbtiwqVaoEf39/NG/eHL169cLkyZOxceNGnD9//rP93ikpKdi2bRs6derEKN5qamoICQnBzp07BX2rly5dQufOnaGmpsaocWfPngVQ6PBERkaiXLlyjILHZ5vz8/OxaNEi6OvrMyVWnt5dVFTGy8tLIPqSk5PDqHteXl5KQkf8cXlHq3r16krCKyqo8K1QBdgqn0CFnxv8rGNnZ+dix2BxHIdFixZBJpPBxsZGIIqlUCgwa9YsyOVygegmUGj7+eR7vXr1GJ286JxruVyOfv36MVv+4cMHTJs2jSWn3dzcsHTpUsaCAwqZWGPGjEGZMmVYsO3j44PJkyfj0qVLxfo1QKENvXfvHg4cOICFCxdi2LBhCA0NRf369eHt7Q13d3c4OjrC3t4e9vb2KFOmDMqVKwcfHx8EBwejU6dOGDlyJJYsWYLDhw/j4cOHJR4rLy8P586dw9ixY9m0D15rJiIiQtBXnpqaijlz5rA52ra2tliwYAH733zy5AkbV6atrY2JEyeyoPv69evMv/L29mZJfF7ITFdXF1paWli5cqUg8I+KioKpqSk0NTVLTKxdv34d1tbW0NLSEowHU0GFv4tfPsB+8OABAgICcOzYMQEV5u7du2jUqBE2btz41ftU4ddHfn4+C7iCg4OL7VsuKCjA+PHjIRKJ4OzsLOiNUigUWLRoETQ1NaGrq4tly5ax5+vJkyest8vR0RGRkZEsw7pv3z5mbJ2dnbFhwwZWjX78+DEGDRrEstdubm6YPXu2gDKVmZmJ3bt3o1OnTqyyze+rW7duWL16NW7evPnDe4MKCgpw+/ZtrF+/Hr169WLBcNHgf9u2bQKH4N27d1i8eDEbxaWpqYmwsDA2V5wfT8ZT6MqUKYOtW7ey+3rs2DF2nICAAMEs0S1btsDY2JiNRSlKU7937x4Tkxk4cGCxGers7Gw267RDhw4sMaKCCt+DCqgKsFU+gQo/P06fPg1DQ0MYGhoKWFFFcfHiRdjZ2UEqlWLatGmCZyg+Pp4Fks2bN2dU4/z8fMydOxe6urqQyWQIDw9nytfPnj1Dt27dIJFIoK6uLhhXlZOTg3Xr1jH7pa2tjW7duuHcuXMCsc7r169j/PjxgiBWT08PwcHBiIiIQHR0dLGV+e+NN2/e4MiRI5gwYQKCgoKgpaUl0ImZMmUKs/dAoc0/ceIEOnTowEaD+vj4YPv27cyGP3jwAN26dYNUKoWamhoGDBjAqvXJycno2bMnxGIxjI2NsWrVKoEfxtPL/fz8mJYLUBj4jxw5EiKRCK6urkhISCj2evbu3csmjBRtGVThfxvfqz3glw+wU1JSYGNjw/6+dOkSIiIi0Lp1awwePFglUvA/jiVLlkAikcDV1bXYrDUAnDx5EqVKlYJcLsfs2bMFBvXRo0eMelSjRg0W9AGF6uW82nXNmjWZmI5CocCOHTtYsGhra4t58+axYDQjIwOrVq1igiUSiQT16tXDunXrBEZSoVDgypUrmDFjBoKDgwX9T3K5HBUqVECbNm0wZswYrFmzBtHR0UhISEBKSopSn/SnKCgowLt373Dnzh2cOHEC69evx/jx49GhQwd4enpCQ0ODHUtHRwdBQUGIiIhAbGysYN9paWnYunUrmjRpwqhpFSpUwMKFC1nfeVZWFlasWMHGa5QpUwbr1q1j+7lx4wZjHNjb2zOqHVCYlKhfvz6rTMfHx7NjcxyHVatWQVNTE4aGhti7d2+x1/rixQumZj558mRVb9X/ODIzM5GTk4OXL18KEkR/B6oAW+UTqPBr4OHDhyhbtiykUimWLFlSrD14//49U/kOCAjAixcv2Gf5+fmYOnUq1NTUoK+vj+XLlzOf4fXr1+jSpQtEIhGMjIwwd+5clsx9+PAhOnfuDKlUCplMht9++w03b94EUGjLYmNj0blzZ8Zic3BwwOjRo3Hjxg3BOb5+/RqbN29G9+7dmdgqv1hbW6NevXro378/5s+fjz179uDSpUt49uwZsrKyPmv7OI5DRkYGnj59iri4OOzevRtz5sxBnz59UKdOHaafwlfT3d3d0bt3b+zYsUPAfOM4DhcvXsSQIUNgbW3NWuR69eolKGLExcUxZp+amhr69OnDEg8ZGRmYPHkydHR0IJVK0b9/f+Yb5efnY9asWdDU1IS2tjYWL14s8NkePHjAVOC7du3KKuSfXis/YaRy5crfpLuhwn8HP8InAP4DATYAdO3aFW3btkXFihXRvn17TJkyBfv372c3SiWz/7+NkpRCiyIlJYVVpQMCAgQUY56GZGBgAJlMhuHDhzPqUn5+PpYsWcKqzSEhIQKjuW/fPtZrpKuri4EDBwpGWiUkJGDEiBGwtbUFEUEqlaJOnTqYN28e7t27JzCICoUCd+/exaZNmzBkyBDUq1cPdnZ2EIvFSn3UfAXZ2NgYpUqVgpWVFUqVKgUTExOWdS5usba2RmBgIAYOHIh169YhPj7+/9q787ia8jcO4J/TppWyUykmS/aSLUTW7PvYJ0v2PdtU1iyD7Jqyh7EP2RlKTIwtRrZJIbImRWnRep/fH3R+XRXh3o66z/v16vUa073nPKftPM/5fr/PN9v0sEePHpG3tzd17NhRnAJerlw5mjx5stxT4MePH5OLiwuVLFlSnAq+Z88e8XhhYWE0YMAAEgSBDA0NaenSpeI08+TkZFqwYAHp6OiQvr4+rVq1Si6OV69eUbdu3cR1dVkToKwuXLhAZcuWJX19fTp48OAXf1ZY4RYbG0vDhw8nGxsbmjBhAo0bN47WrVsn1xvhW3CBzTkBKzhiY2PFh7pOTk45zmiSyWS0ceNG0tXVJSMjo2zbOIaGhlKLFi0I+LANV9Z7340bN8QlaGZmZrR582bxgfLjx49p3LhxYiHdqlUr8vX1FT//7t078vHxoVatWon39p9++okmT55Mfn5+2bYQjImJoVOnTtHixYupf//+ZGVlles9XlNTkwwNDals2bJkYmJCxsbGVKZMGSpWrFi2ddeZH0WLFiUbGxv65ZdfaNmyZRQQEJBtRmBiYiIdP36cxo4dKxbVGhoa1LFjR9q1a5f4+5acnEy7du0iW1tb8djTp08Xp84nJyfTmjVrxO7kXbt2FbfeIvrQ+CxzhmCnTp0oIiJC7vu1YcMG0tPTI0NDQ9q3b1+O3/v4+Hjx4UluO4ww1aGsnICokBTYycnJFBQURAEBARQWFkbnz58nX19fWrp0KVlbW5Ojo2Ouv2xMNWTtFDp//vwcO2tm/QNdtGhR2rJli1yBGxUVJe6taGxsTLt27RI//+7dO5o3b564P3SvXr3ErpdEH9Zp9e3bV7yJtWzZknbu3Ck+XZXJZHT58mWaPn263FPpChUqkKOjI23evJn++++/HONOSUmhBw8e0NmzZ2nXrl20evVqmjt3Lk2ZMoVGjx5Nw4YNoyFDhtCwYcNo9OjR5OzsTHPmzKFVq1bRzp076cyZMxQaGprj3r8ymYzu379P27ZtIycnJ3H9VOZo86RJkygwMFAsfpOTk+nAgQPUoUMHEgSB1NTUqGvXrnJbj9y7d48cHR1JXV2ddHR0aPr06eIT8MyHEpnn6dmzZ7b11Pv376dSpUqRlpYWLVu2LNfvpaenJ2loaJCFhYXcyDdTXePHj6eRI0dSZGQkBQQE0NatW2natGnk5ub2XfvNc4HNOQErWNLT08nV1ZUAUIMGDXLs20H04UFw5ojozz//LNdASyaT0bZt26hUqVKkpqZGI0eOlOuoffr0abET9k8//USbN28WlzDFxMTQb7/9Jhak5cuXp5kzZ8o13oyMjKR169aRg4ODOENMR0eHWrduTfPnz6eAgAC5/jFZ43r16hUFBQXRoUOHaMOGDfTbb7/Rr7/+SuPHj6fhw4fTkCFDaMiQITRixAiaMGECubi40OLFi2njxo105MgRun79OkVHR+c46h0bG0unT5+m2bNnU/PmzcXYdHV1qUuXLuTj4yP39/Tu3bs0depUsXeLhYUFrVq1SizUk5KSaO3atWLjVjs7O/rnn3/E9z979kxc4mVqakq+vr5ycT19+lSc6dayZctcv5f37t2j6tWrk5qaGnl4ePBsNqa0nICokBTYWd2+fZvc3Nxo+PDhNHv2bDp16hRdv36dSpQo8U3bFbDCIzExUfwj/Wmn0KwePnwodp10cHCQe0pK9GHPxsx1xo0aNZK7EcTExJCbmxsZGBgQAOrQoQOdO3dO/EP+8uVLWrBgAVWsWFGcfu3o6Eh//fWX3NTFhw8fkpeXF/Xo0UMcAc584mtnZ0fjxo0jb29vCggIoIiIiFybkORV5rYcf//9N23YsIEmTZpELVu2FLcNAT50KO/cuTOtWrVKrotoeno6nT17lkaOHCl2Jc9MFrJ+7a5cuUK9evUiQRBIR0eHJk+eLLddxo0bN8TmctWqVaPTp0/LxRgZGSk2ObO2tpbrVp5VfHy82BG2U6dO4lo4xqZPny5XWCUnJ9ONGzfIycmJBg8enON0wrzgAptzAlYwHThwQNx54tSpUzm+Ji0tjRYsWECamppUqlQpcZ/mTG/fvqUJEyaQuro6FStWjDw8PMQH1jKZjA4fPizmDKamprRy5UpxNkVaWhodPHiQ2rdvL45Y29ra0tq1a+XujwkJCXT06FGaMGGCXE+UzHXG/fr1o4ULF5Kvry/dvn37m7ftzCouLo6Cg4Ppzz//JHd3d+rdu7fYdC1z7bW1tTVNmTKFTp06JfeQPiIigpYtWyZet4aGBnXv3p3++usv8fcuJiaGFi5cKM4AbNq0Kfn5+Ylf2/j4eJozZ464Haqbm5vcdWXOMihatCjp6OjQmjVrcv2dzrrDiL+//3d/bVjhoKycgKiQFdjnz58nExMTmjt3rlz3QiKi9u3b8363jGQyGa1duzbHTqFZZWRk0Jo1a0hPT4/09fVp9erVckVseno6bd68WVyX1L17d7lGGm/evKH58+eLT2ytra1p69at4g0os/nHkCFDxFFvIyMjGjRoEO3bt0+uKJTJZBQSEkJbtmyhMWPGUKNGjUhfX19uGpeGhgaZmppS/fr1qX379tS/f38aOXIkTZo0iaZPn06//vorzZgxgyZPnkyjRo2igQMHUseOHalhw4ZkZmZGmpqacsfT0dGh+vXr04gRI2jDhg1069YtuRtXfHw8HT58mJycnMSbo66uLvXr149OnDghfq1SU1Np79694jT5YsWKkYuLi9yWIw8fPqSBAweSIAhUvHhxWr16tdzDBplMRlu2bCEjIyPS0tKiBQsW5LqO8tatW1StWjVSU1PLtqcpY/7+/mRqakoLFizINsrRqlWrXB/afAkX2JwTsILr3r17VLNmTRIEgWbNmpVrD5Nbt26JI9KdO3fO9vD97t274kiqmZkZbd26VbwXymQyOnHiBDVr1ky8Fzo7O8uNWD99+pR+++03qlmzplg8N23alDw8POju3btyRX1MTAwdP36c5syZQ507dxa3u8z6Ubx4capRowbZ29tTr169aOjQoTR+/HiaOnUq/frrr/Trr7/S1KlTady4cTRkyBDq0aMHNW/enCwtLcVmrFk/KlWqRN27d6cFCxbQ6dOn5aaKy2QyunHjBi1YsEAc8QdA9erVo5UrV8rd8+/evUujR48Wp8m3a9dObiAiJSWFfv/9d3GqeO/evSk8PFzuax0aGipun9aiRYtct9z8dIeRT79nTLUpKycgKmQF9qxZs2jz5s3iv1NSUujWrVs0fvx46tOnT46dpJlqytopdMmSJbkWYo8ePSIHBwfxRnH16lW5zyckJJC7uzsZGBiQmpoaDRw4UG6NdVJSEq1fv56qV68u3vCcnZ3livH379/ToUOHaNCgQeKIsbq6Otna2tKsWbPIz88v2zQwmUxGERER5OfnR+vXrydXV1dydHSkdu3aUb169ahSpUpUunRp0tfXpyJFipCmpiZpaWmRvr4+lSpViipWrEhWVlbUpk0bGjRoEM2YMYO8vLzor7/+ovDw8Gxfj6SkJDp37hy5u7tTixYtxILcwMCA+vTpQ/v27ZN7svzgwQNydXUVtyGpVKmS3FN7og9PuEeMGEEaGhqkra1NM2bMyDbafPv2bXE2QdOmTeWazH369Vi3bh1pa2tTmTJlKCAgIMfXMfbvv/+Si4sLDR8+nObOnUtnz56lwMBAKlOmzDcfkwtszglYwZaYmEiDBw8m4END09ymGWdtsqWnp0ceHh7ZHvj6+/uLHb+rVatGu3fvlntAf/nyZerTpw9paGiQIAjUrl072r9/v9wOGHfu3KG5c+dSnTp15HqkDBkyhLZt20bh4eHZpji/e/eOrl69Srt376ZFixbR6NGjqVu3btSkSROytLQkY2NjMjQ0JG1tbdLU1CRNTU3S0dEhQ0NDMjY2purVq1OzZs2oR48eNG7cOFqyZAnt27eP/v3332wj4jKZjMLCwmjTpk00cOBA8V4PgOrXr0+LFi2SayyblJREO3bsoObNmxPwYYvPIUOG0M2bN8XXpKam0qZNm8jc3Fz8Pnw6CJKUlESzZ88mLS0tKlasmFyTuU+FhoaKI+i57TDCmDJyAqJCVmDPmTOHevfuTZGRkXTmzBnavn07TZ8+nWbMmCG3FRJjRJ/vFJqVTCajPXv2ULly5UgQBHJycpJ7Gkv0oUnatGnTSEdHh9TU1Khfv35is7PMY/j7+1OvXr3EddgNGjSgNWvWyP1spqWlUWBgILm6ulKDBg3EKWNqampUu3ZtGjp0KK1du5bOnj1LkZGRSllDFB0dTYGBgeTl5UUjRowga2trMWZBEKhu3bo0bdo08vf3l7thxcTE0Pr168WCWE1NjTp16kTHjh2TSy7CwsLIyclJvMGPGTMmWxfPt2/f0uTJk0ldXZ2KFy9OGzduzPUmGh0dTd27dycA1LZtW/5dZ9kkJiZSYGAg+fv7U0hICP3999904MAB+u2336hWrVrk5OREhw4d+ubjc4HNOQErHLZv3076+vpkZGRE+/fvz/V1jx49ok6dOhEAql69Ovn5+cl9XiaT0f79+6lGjRpiob1t2za5YvzZs2c0Z84cce1xyZIlady4cXTp0iW5e3tERAStW7eOevToIbdsq2zZstS5c2eaM2cOHThwgO7du6eULvnJycl0584d2rt3L7m5uZGDgwOVKFFCjKNUqVLUp08f2rJli9iwjOjDTL2zZ8+Sk5OTOCJeqVIlWrx4sdxa9aSkJPL29hYLaxsbGzp58qTc10Amk9HBgwfF5XX9+vWTm0KfVWZjWj09vc/uMMJUl7JzAqLc8wLhw+fyl42NDV27du2b35+amoqJEyfi+vXraNiwIQDA0tISHTt2hJmZmaLCZIUIEWHTpk2YNGkStLW1sX79evTq1SvH17579w7z5s3DmjVroKenh5kzZ2L8+PEoUqSI+JqoqCgsW7YM3t7eSEhIQPv27eHs7IxWrVpBEATxNX/88Qe2b9+OW7duQU1NDc2bN0ePHj3QpUsXVKhQQe6cFy9exMWLF3HlyhX8+++/iI6OFj9ftGhRVKpUCRUqVICJiQnKlCmDkiVLwtDQEAYGBtDW1oaWlhbU1NRAREhNTcX79++RkJCA2NhYxMTE4NWrV3j+/DmePHmC8PBwvH37Vjy+oaEhrK2t0aBBA9ja2qJJkyYoXry4+PnIyEgcPXoUvr6+8Pf3R3p6OqpWrYpffvkFv/zyC0xMTMSv88WLF7FixQocPHgQWlpacHJywowZM2BqaioeLz09HZs2bcKsWbMQExMDJycnLFq0CCVLlszxe3Lq1CkMGTIE0dHRWLRoEZydnaGmpvY1PwJMBXTv3h0GBgYICQmBpaUlatSogdq1a6Nly5YoUqQI3r9/Dx0dnW8+viAI14nIRoEhM3BOwKRx//59DBgwAEFBQRg8eDBWr16NokWL5vjao0ePYtKkSQgPD0eXLl3g4eGBKlWqiJ+XyWTYv38/Fi5ciFu3bsHU1BQTJ07EsGHDYGhoCADIyMjAqVOnsHXrVhw5cgQpKSkwMzNDz5490a1bNzRu3BgaGhri8W7fvo1//vkHly9fRlBQEEJDQ5GZs6urq8PMzAzm5uYwNTVFuXLlUKpUKRQvXhxFixaFrq4uihQpIh4vPT0dycnJSEpKQlxcHN68eYPXr1/jxYsXePbsGR4/fownT55AJpOJx7e0tET9+vXRuHFjNGnSBJaWlmJ+k5qaisDAQBw6dAi+vr54+fIl9PT00KNHDwwePBgtWrQQ79GvX7/GunXr4OnpiaioKDRs2BCzZs1Chw4dxOMBwM2bNzFlyhScOXMGNWrUwNq1a2Fvb5/j9yM6OhojR46Er68v7O3tsX37djEPYSyTsnMC4DN5QU5Vt7I/FNHQJCkpiZKTk+nFixfiPryMfUloaKi4V3K/fv0+2z0wJCRE3OLD3NycduzYkW10NSYmhubPny+uT65RowZ5eXllm5J4+/Ztmjlzplzn8Fq1atG0adPo1KlTOU7Lev78OZ06dYpWrVpFY8eOpQ4dOlDNmjXFBmNf+2FoaEjVq1endu3a0ahRo2j58uV0/PhxioiIyDZC/v79ewoICCBXV1dxChw+dhSfNm0aXbt2Te49iYmJtGXLFvG1RkZG5Orqmm30KPPpdNZ9xa9fv57r9+Ddu3c0cuRIcfQg6/6ajGUVGhpKlStXFv99+fJlmjdvHvXu3ZtcXV0Vcg7wCDbnBKxQSU1NJTc3N1JTU6MKFSrQmTNncn3t+/fvadGiRaSvr08aGho0duzYHO9xx44dE6dI6+np0ejRo+VmuhF96NC9detW6tChg9id28jIiPr06UObN2+mR48eZTt/YmIiBQUF0bZt28jV1ZX69OlDjRo1IhMTk2z9VfLyoaWlRWZmZtSkSRPq378/zZ49m3bs2EH//vtvtt1GZDIZhYaGkre3N3Xv3l1s8Kqjo0Pdu3en3bt3Z8tjrl69SoMHDxa3+Wzfvj0FBARkyzciIiLI0dFR7M+yZs2az47QHzx4kEqXLk1aWlq0dOlS7sHCcpQfOQFR7nlBgb2Zforb8LO8Sk1NpXnz5pGGhgaVKVOGfH19P/v606dPU926dQkA1a5dmw4dOpTt5y05OZl8fHzEdUB6eno0bNgwunDhQrbXhoSE0NKlS8ne3l68KWpoaFCDBg1o0qRJtGvXLgoNDf1sx/DU1FR6+fIl3bt3j4KCgigwMJDOnDlDfn5+5O/vT4GBgXT16lX677//6MWLF59dk5SRkUEPHz6kffv20bRp06hJkybiDV9dXZ2aNGlCCxYsoJs3b2abynXt2jUaO3YsGRoaEgCytLQkLy+vHB8YnD59mho2bEgAqGrVqnTw4MHP/t76+fmRmZkZCYJAU6ZMyXF7McYy3b59m+zt7SkwMFDu5+rWrVvUtm1bOnLkyHefgwtszglY4XTp0iWxc/aoUaM+u24/MjKSRo0aRerq6qSrq0suLi45Pqy/fv06OTo6igVm48aNadOmTdmOHRcXR/v27SNHR0e5Nc6mpqbUt29fWrVqFV24cCHHrboyyWQyevv2LYWHh1NwcDD9888/FBAQQH5+fuTn50dnz56lixcv0s2bN+nRo0cUFxf32d+R2NhYOnfuHHl4eFCvXr2yxTV8+HA6dOhQtu7LMTEx5OnpSVZWVmIuNGrUqBz7qrx8+ZImTJhAWlpaVKRIEZo6dSq9efMm15hev34t7hxSp06dbA8tGMsqP3ICokJcYF+/fp1+//13foLFvtqNGzfEwrlnz55ya4o+lZGRQbt27SILCwvCx27hvr6+2X7uMve5Hjp0KOnp6RHwYV/MWbNm0Z07d7Ld0BISEuivv/4iFxcXatasGWlra4s3MV1dXapXrx4NGDCA5s2bR9u3b6ezZ89SaGgoxcbG5jmBlMlk9O7dO7p//z4FBgbSjh07aP78+eTo6EgNGzYUn0RnPtFu1KgRTZ06lY4cOZLjSFBYWBgtWLBAbOhWpEgR6tevn1x30KznPnXqFDVt2lS8MW/cuDHX7q1EH9ZaDxkyhABQlSpV6MKFC3m6TsY2bdpEI0eOJF9fXwoLC6OkpCQiIlq9ejUNHz78u4/PBfaPnxMkJiaSj48PXbt2TWHHZKohMTGRnJ2dSU1NjUxMTL6YgIeFhVHfvn1JEAQyMDAgV1dXuf2zM71+/ZqWLVsmzmDT0dGhPn360KFDh3IcKb59+zatXbuWevfuTcbGxtk6fHfs2JGcnZ3J09OTjh49Sjdu3KCXL19+1brslJQUev78OV2/fp0OHz5Ma9asoYkTJ5KDg0O2TuXm5ubUv39/WrduHYWGhuaYx+zdu5e6desmPpyvW7cueXp65phDPH/+nCZNmkQ6Ojqkrq5Ow4YN+2zXb5lMRjt37qRSpUqRpqYmzZ07Vylr0Fnho+ycgKgQF9jr168nALRp0yaFHZOpjtTUVFq0aBEVKVKEihYtSp6enp8dOU5LS6MtW7bQTz/9JE5b9vHxyXGEOD4+nnx8fKhVq1ZiE7OqVavSjBkz6MKFCzmeJzU1lW7cuEGbN2+mSZMmUZs2bcjU1JQEQcg2vUtDQ4NKlChBZmZmVK1aNapVqxbVqVOHateuTZaWlmRubk4lS5bMdepY+fLlyd7ensaNG0cbNmygoKAgSk5OzhZTRkYGXb16lWbPnk21a9cW39+0aVNat25djk+c09LSaO/eveKIvomJCXl6euZ4/EwymYy2bdtGJUuWJHV1dXJxceFRa/bVfHx8aNiwYTRjxgxavHgxTZ06lRo0aKCQPVG5wP7xc4KkpCTS0dGhWrVqKeyYTLVcvnxZ3EKrR48e9PTp08++/tatW9S7d28SBIF0dXVpwoQJOU7xlslkdOnSJRo9erTYPExfX59+/vln2rFjB0VHR+d4/OfPn9ORI0fI3d2dfv75Z6pVqxbp6OjkeF83MDCg8uXLU+XKlalmzZpUp04dqlOnDtWoUYMsLCyoXLly2bb/zPzQ09OjunXrUv/+/WnRokV04sSJbI1eM0VGRtKWLVuoW7duYizlypWjyZMn57qU6969e+Tk5ERaWlqkrq5Ojo6Och3IcxIWFkZt2rQhfGwY+z3bKTHVpMycgCj3vKBANjnLKiUlBYaGhjAzM0NISIhcwwTG8ur+/fsYM2YM/P39YW1tDU9PTzRu3DjX16enp2Pfvn1YsmQJbt26hbJly2LMmDEYMWIEypQpk+31kZGR8PX1ha+vL/7++2+kp6fDyMgILVu2RKtWrdC8eXO5BiKfSk5OxpMnT/D06VO8ePECUVFRiImJQWxsLBISEpCcnIy0tDTIZDIIggBNTU1oa2tDX18fxYoVQ4kSJVC6dGmUL18eJiYmMDMzg66ubo7nIiI8ePAAf//9NwICAuDv74/Xr19DTU0Ntra26NGjB3r16iXXtCzT27dvsXnzZnh6eiIiIgJVqlTBtGnTMGjQILkmcZ+6efMmxo8fj/Pnz6NRo0ZYv349ateunevrGcvqyJEjCAwMhIaGBkaPHg0iwtWrV/HmzRtERUXB2toanTp1+u7zcJMz5VBkTgAA48ePh6enJ86dO4fmzZsr7LhMdaSmpmL58uVwd3eHuro6Zs2ahUmTJn32PhYSEoIlS5Zg586dkMlk6NGjByZMmICmTZtmu7enpaUhICAABw4cwOHDhxEVFQU1NTXUr18frVu3RosWLdC4cWPo6enleC4iQmRkJJ48eYLnz58jMjISr1+/xtu3bxEfH4+kpCSkpKQgIyMDAKChoYEiRYpAV1cXRYsWhZGREUqVKoWyZcvC2NgYZmZmKFWqVK45SFxcHC5evIizZ8/C398fN27cAACYmJigW7du6NmzJ5o1awZ1dXW598lkMvj7+2PNmjU4fvw4tLW1MWTIEEybNg0VK1bM9WuZmJiIxYsXY+nSpdDW1sbChQsxevTobMdnLCf5lRMAhbDJWVbOzs4EgE6dOqXQ4zLVIpPJaPfu3VS+fHkCQAMHDvzik2uZTEZ//fUXtWvXjgCQpqYm9evXL8dGHpnevn1Le/fupSFDhojbdgAf9s7u0KEDzZ49mw4ePJjjPtWKlrnP9tGjR8nd3Z26dOkiNmwDQGXKlKEBAwbQH3/8kePUt8xjXLp0iYYMGSI+ybazs6ODBw9+djYAEVFUVBSNGjWK1NTUqESJEp/dpouxnHh6epKtrS39+eefNHToUKpSpQoNGDBAKQ3xwCPYBSIniIiIIADUsWNHhR6XqZ7w8HDq0qULAaDKlSvT4cOHv7g86+nTpzR9+nSxN0mtWrXI09OT3r59m+PrMzIy6PLlyzRnzhxq1KgRqauriz1QrK2tadSoUbRx40YKCgrKtuZZGd69e0eXLl0ib29vcnJyolq1aomz6DQ1NcnOzo4WLlxI//77b65fi6ioKFq2bJm4rr106dI0Z86cXEfEM2VOB8/MjQYMGPDZ5XuMfSo/cwKiQjyCDQAvX75E+fLl0bZtW5w6dUphx2WqKSEhAYsWLcKKFSugpqaGKVOmYNq0ablu35EpNDQUXl5e2LZtG+Li4lCpUiX88ssvGDhwIH766acc30NEePjwIQIDA3Hx4kVcvnwZISEh4lYZOjo6qFKlCiwsLFCxYkWYmJigfPnyKF26NEqUKAFDQ0Po6+vLbdMlk8mQlpaG5ORkJCQkIC4uDjExMYiKisLLly/x9OlTPH78GA8fPkRoaCgSEhIAAIIgoEqVKmjYsCFsbW1hZ2eHatWq5fpE++nTp9i1axe2bduGkJAQ6OnpoX///hg7dizq1Knz2a/V+/fvsWbNGixatAiJiYkYM2YM5s6dK7c1GGN50bFjR4waNQqdO3cG8GG2x+rVq7FhwwZMnjwZ48aNU9i5eARbORSdEwBAt27dcPjwYTx8+BCVKlVS6LGZ6vnrr78wefJk3Lt3D/b29vDw8EC9evU++57ExETs2rUL3t7euHHjBrS1tdG9e3cMGjQIrVu3hqamZo7vi4+Pxz///IMLFy7g0qVLuHbtGt69eyd+3szMDFWqVEHFihVhZmYGY2NjlC1bFiVLloSRkREMDAzEbboyR3wzMjKQkpKCxMREvHv3Dm/fvkV0dDQiIyPx7NkzcfvOsLAwPHv2TDxX8eLFxa26mjZtisaNG+c6+y0lJQUnT57E9u3bcezYMaSlpcHW1hajR49G7969Pzv6DwDnz5/H1KlTcfXqVVhbW2PVqlVo1qzZZ9/D2KfyMycACvkINhFRr169CACFhIQo/NhMNT169Ij69OlDAKhkyZK0cuXKPK0JTkpKou3bt1PLli3Fp77169enpUuXfnG9EdGHhiGXLl2iDRs20KRJk6h9+/ZUpUoVsRPp935oaWmRhYUFtW3blsaPH09eXl50/vz5z3ZNzRQREUGrV68Wm5YBIFtbW9qwYQO9e/fui+9PTU2lDRs2iI1bOnXqlGN3Ucbyytvbm8aNG5dtZCQsLIwGDRqUbRud7wEewS4wOcHVq1cJAI0ZM0bhx2aqKTU1ldauXUslS5YkAPTzzz/TvXv38vTea9eu0ZgxY8jIyEjMKUaOHEmnT5/+7C4fRB9GuO/fv0/79++nefPmUb9+/ah+/fpiHIr4KF26NDVq1IgGDhxICxcupEOHDtGjR4++OFr//v17Onr0KA0ePJiKFSsmznxzdnamO3fu5OlrExwcTB07diQAZGxsTD4+PjyTjX2z/MwJiAr5CDYABAcHw8rKCk5OTti4caNCj81UW1BQEFxcXHDmzBkYGxvDxcUFw4YNg7a29hff+/TpU+zZswd79+7F9evXAQDVq1dHp06d0K5dOzRp0uSLT3UzERGio6Px8uVLcQ12XFwcEhMT8f79e6SlpX34pc6yBltPT09cg12qVCmUK1cOpUqVgpqaWp7OmZaWhitXruDUqVM4fvy4uO6qZs2a6NOnD/r165fr6Pynx9mxYwcWLFiA8PBwNGrUCIsXL+b1key7vXr1CrNmzUJKSgocHR3RpEkTyGQyxMfHw8bGBqGhodDR0VHIuXgEWzmUkRMAgJWVFYKDg/H+/fs8/b1mLC/i4uKwbNkyrFy5Eu/fv8egQYPg5uaGypUrf/G9maO8e/bswbFjx5CYmIhixYrBwcEB7du3R5s2bVC+fPk8x5KUlIQXL17g1atXiI6OlluDnZqaKq7BVldXF9dgGxgYyK3BLleu3Ff9fkREROD06dM4ceIE/Pz8xGvo2rUr+vXrh9atW0NDQ+OLx7l9+zbc3d2xf/9+GBoaYsaMGZgwYUKuI+SM5UV+5gSACoxgExE1aNCAAGTbg5cxRQgICKAmTZqI3TKXL1+ep1HbTI8ePaKVK1dSy5Ytxc7eOjo61KpVK5o3bx75+fnluKVFfoqPj6dz587RokWLyMHBQew2qqamRk2bNqUlS5ZQaGhono+XlJREXl5eZG5uLm5vdvToUd6jlilUYmIiLV26lKytralLly40YcIEatWqFbm4uCj0POAR7AKVE+zZs4cA0JIlS5RyfKbaXr16Rc7OzqStrU1qamrUv39/unnzZp7fn5SURIcPH6ahQ4fK7TNdrVo1GjlyJP3xxx90//59Se+XGRkZFBISQlu2bKGhQ4eKO6jg47abo0aNopMnT35xFD6rK1euULdu3cSu5zNnzvzs/teMfa38ygmIVGAEGwAOHz6Mbt26Yf78+Zg5c6bCj88YESEgIAALFy7E2bNnYWhoiFGjRmHs2LEwMTHJ83Hi4+Nx7tw5nDlzBufOncOtW7eQ+btYpUoV1K1bF7Vr14alpSWqVq2KSpUqKfSJW3JyMh4/fozQ0FCEhITg9u3bCA4Oxr1798T139WrV0eLFi3ETueGhoZ5Pn5UVBTWrVuH33//HVFRUWjYsCHc3NzQqVMn7vTPFIY+ztjI6uTJk0hNTUW1atVgamqq0NEQHsFWDmXlBOnp6TA2Noa6ujqePXuW55k7jH2NyMhILF++HN7e3khMTISDgwOcnZ3RunXrPN/viAg3b96En58fzp07hwsXLojrro2MjGBlZYU6deqgRo0aqFatGipXrvzZrt9fi+hDV/L79+/j3r17uHPnDm7evIng4GAxjuLFi6NZs2awt7dHmzZtPrvzyacyMjJw7NgxrFixAoGBgTA0NMSECRMwceJE7r3CFCa/cwIg97ygUBXYGRkZqFChAlJTUxEZGcnt/JlSXblyBR4eHjh48CDU1NTQo0cPjB07Fs2aNfvqm15cXByuXLmCq1ev4vr16wgODsbjx4/lXlOmTBmYmJigXLlyKF26NIoXLw5DQ0Po6enl2uQsMTFRbHL2+vVrvHz5Es+ePcPLly/ljm1qaoo6derA2toaDRs2RMOGDVGiRImvugYiQlBQELy8vLBnzx6kpKSgQ4cOmDp1Klq0aMGFNVOKpKSkz245p8ifOy6wlUNZOQEALFy4EDNnzsTu3bvRt29fpZyDMQB48+YNvL29sXbtWrx69QqWlpYYO3YsBg0a9MUmqZ/KyMjA3bt3cfnyZVy7dg03btzAnTt3kJycLL5GV1cXFSpUQPny5VGmTBm5Jmc6Ojo5NjlLSkqSa3L26tUrPH/+HE+fPpU7tp6eHmrVqgUrKyvY2NigUaNGqFat2lc/pIqJicHWrVvh5eWF8PBwVKhQARMnTsTw4cNhYGDwVcdiLC/yMycAVKTABgAPDw9Mnz4d27Ztwy+//KKUczCW1aNHj+Dp6YktW7YgNjYW1apVw7BhwzBw4ECULVv2m4+bkJCAe/fuISwsDOHh4YiIiBCL49evX+PNmzdyN8TcaGlpoXjx4uIa7Mw9L3/66SdUrlwZ1apVQ7Fixb45zpiYGOzatQtbtmxBcHAw9PT08Msvv2D8+PGwtLT85uMylptjx45h37590NXVRbFixWBubg4HBwdxX9Xjx4+jRYsWue4h+624wFYOZeYEsbGxKFmyJOrVq4crV64o5RyMZZWSkoI9e/Zg7dq1uH79OvT09PDzzz9j6NChaNKkyTcn+BkZGeLMswcPHuDx48d48uQJXrx4gaioKERHR+Pdu3f4Ul4vCAKKFSuGkiVLimuwK1SogIoVK8LCwgJVq1ZFhQoVvnnGR0ZGBs6ePQsfHx8cOHAAKSkpaNq0KcaPH48ePXrkaX02Y19DqpwAUFKBLQiCB4DOAFIBPAQwhIhiv/Q+Zd5M4+PjUbJkSVSvXl1syMRYfkhKSsK+ffuwceNGXLx4Eerq6mjTpg369++Prl27fvUT7LxITU39YpOzvDZR+xqJiYk4fvw4du/ejePHjyMtLQ1WVlYYPnw4BgwYoJRrZSyTqakpvL29ERcXh/T0dPz3339ISkpCv379UKlSJWzfvh3Tp09X+Hm5wP6yb8kLlJkTAMDw4cOxadMmXL58GQ0bNlTaeRjLiohw7do1rF+/Hnv37kVCQgJ++ukn9O/fH3379kX16tUVfk6ZTIbExMTPNjnT1dVV+HIJIkJwcDD27NmDXbt24dmzZzA0NMTAgQMxYsQI1KpVS6HnYywrqXICQElNzgC0BaDx8b+XAFiSl/cpq6FJpkmTJhEA2rlzp1LPw1huQkJCyMXFhczMzMStsTp16kSbNm2ily9fSh3eN4mOjqbt27dTz549SVdXV2z2NnnyZAoODpY6PKYiLl++TA4ODuK/U1JS6L///iNvb29q06YNPX36VGlNgcBNzpSSFyg7J3j48CEBoIYNG3KDRSaJ+Ph42rp1K7Vq1YrU1NQIAFlaWpKbmxtduXKlQG5LlZaWRufPn6epU6eKzc80NDSoY8eOtGfPnjxta8rY95IyJyDKPS9Q2BRxQRC6A+hFRAO+9FplP61+//49ihcvjkqVKuHOnTu89pNJRiaT4fLly/jzzz/h6+uLJ0+eAACsra3Rrl07tGzZEra2tj/kthQpKSm4cuUKAgICcOrUKVy9ehUymQzlypVDt27d0Lt3b9jZ2XGvA5avEhIS0KVLF1SpUgUzZswQp4ABwJw5cyAIAubOnauUc/MI9tfJa16g7JwAAMaNG4fff/8dQUFBsLHhbyGTTmRkJPbv348DBw4gMDAQMpkMpUuXRtu2bdG6dWvY29ujQoUKUoeZDREhPDwcAQEB8Pf3h5+fH96+fQtNTU3Y29ujV69e6NGjx1f3b2Hse0iZEwD5sAZbEISjAPYS0Y4vvTY/bqbOzs5YuXIlLly4gCZNmij1XIzlBdGHLqEnTpzAyZMncfnyZaSnp0NTUxPW1tawtbVFgwYNYG1tDQsLi3zteEtEePToEf79919cvXoVly5dQlBQEFJSUqCmpgYbGxu0b98eHTp0gI2NDXfjZZJ68+YNlixZglevXqFq1apo0aIFGjdujJ9//hm1a9dW2i4SXGB/nbzmBfmREzx58gRmZmbo3Lkzjhw5otRzMZZXMTExOHnyJE6ePAk/Pz+8fv0aAGBmZgZbW1s0atQI9erVQ506daCvr5+vscXFxSE4OBjXrl3DlStX8M8//+DFixcAgPLly6Nt27bo0KED2rZt+119XBj7XlLlBMB3FNiCIPgDyKlTkxsRHf74GjcANgB6UC4HFARhBIARAFChQoV6ERERX3cFX+nly5coX748HBwccPLkSaWei7FvER8fj/PnzyMwMBAXLlzA9evXxaZlurq6qFGjBiwtLVGlShVUqlQJZmZmMDExQdmyZaGlpfXV50tLS8OrV6/w7NkzPHnyBOHh4QgLCxO35IiPjwfwoSmalZUVmjRpAjs7O9jZ2cHIyEih187Y94qLi0NgYCBu3LiBv//+GzExMbCysoKXl5dCt7TLigvsDxSRF+R3TgAAnTt3xrFjx/DkyROYmpoq/XyMfQ2ZTIY7d+7g3LlzOH/+PC5duoTnz58D+NCYrFKlSqhevTqqVq0KCwsLmJubw9TUFMbGxihatOhXz9YkIsTGxopdxB89eoQHDx4gNDQUd+/eRdbfSTMzMzRu3Bh2dnZo0aIFqlWrxrND2Q9FipwAUOIItiAIgwGMBNCKiJLy8p78eFoNAD179oSvry/Cw8Plpgww9iNKS0vDnTt3cOPGDdy8eRN3795FSEiI+MQ4q6JFi6J48eIoVqwY9PX1oa2tDU1NTaipqYGIxG26EhIS8O7dO7x58waxsbHZjlOmTBlUq1YNNWvWRO3atWFlZYXatWsrpTEaY8qQnJyMIkWKICIiAsbGxtDU1FTaubjAzpuvzQvyKyc4d+4c7O3tMW7cOKxdu1bp52Psez1//lzcuvPOnTv477//8ODBA6SkpMi9TktLCyVKlICRkRH09fWhq6srt01Xenq62BQ1Pj4esbGxiImJQVpamtxxdHR0ULlyZVSvXh21atVC3bp1Ua9ePZQpUybfrpmx75GfOQGgvC7iDgBWAGhORK/z+r78upleunQJtra2GDFiBNavX6/08zGmDElJSXj06BGePHmCZ8+eITIyUtym6927d0hMTERycjLS0tIgk8mydREvWrQojIyMUKpUKZQtW1bcpsvc3Jz3oWQFlo+PD9q2bQtjY+N8OycX2F/2LXlBfuUEAFCnTh2EhoYiOjo636fcMqYIMpkMz58/x+PHj/Hs2TO5bbri4uIQHx+frYu4hoYGtLS0oKurCwMDAxgaGqJEiRIoU6YMypcvD1NTU5ibm6Ns2bK8BIwVSK9evcKRI0fg5OSUr7MrlFVgPwBQBEDMx/91mYhGfel9+XkztbGxwY0bNxAREQETE5N8OSdjjDHliY2NhbGxMRo3bgx/f/98Oy8X2F/2LXlBfuYEW7duxZAhQzB79mzMmzcvX87JGGNMuUaNGoVNmzYhODgYNWvWzLfz5pYXfNdjKiKyICJTIqr78eOLxXV+8/DwgEwmw6xZs6QOhTHGmAIsWrQISUlJGD16tNShsE/86HnBwIEDYWFhAXd3dyQkJEgdDmOMse/09OlTrF+/HtbW1qhRo4bU4QD4zgK7ILC3t4etrS22bt2a4xpUxhhjBUdaWhrWrl2LypUro2fPnlKHwwoYDQ0NLF26FACwevVqiaNhjDH2vZYtWwYAWLVq1Q/TfK/QF9gAxPbsK1eulDgSxhhj32Pjxo1ITk7G9OnTpQ6FFVBdu3ZFmTJlsHr1ashkMqnDYYwx9o1SU1Ph5eWFGjVqwNbWVupwRCpRYDs4OMDExARr1qxBenq61OEwxhj7BkSE5cuXw8DAAIMGDZI6HFZAqampwdnZGa9fv8bu3bulDocxxtg38vLyQnp6OmbMmCF1KHJUosAWBAFTp05FbGwstm/fLnU4jDHGvsGxY8cQHh6OsWPH8lZy7LuMHj0ampqaWL58udShMMYY+wZEhBUrVqB48eLo27ev1OHIUYkCGwCcnJygp6cHNze3bPsHMsYY+7HJZDJMnz4dgiBg8uTJUofDCjgDAwOMGDECN27cwL59+6QOhzHG2Ffy9PTE06dPMXHiRKXvd/21VKbA1tPTw+LFixEZGQkfHx+pw2GMMfYVjhw5gnv37mHWrFkoXbq01OGwQmDp0qUoUqQI3N3d8T1bljLGGMtf6enp+O2331C6dGm4ublJHU42KlNgA/8fxV65ciXfTBljrADx8PAAADg7O0scCSssdHV1MWrUKNy9excXL16UOhzGGGN55Ovri5cvX2LChAlQV1eXOpxsVKrA1tbWxujRoxEWFoazZ89KHQ5jjLE8CAkJwcWLFzF48GAUK1ZM6nBYIeLi4gIAWLhwocSRMMYYyysPDw8IgoAJEyZIHUqOVKrABv5/M50zZ47EkTDGGMuLzOlfP1qXUFbwlSlTBt26dcPJkycRFhYmdTiMMca+ICgoCNeuXcPw4cNhYGAgdTg5UrkCu3jx4hg+fDguXLiAEydOSB0OY4yxz7h69SoOHjyIrl27olq1alKHwwqhefPmAfjQWZwxxtiPzcnJCYIgYObMmVKHkiuVK7AB4LfffgMALF68WOJIGGOMfc6SJUsAAMuWLZM4ElZY1a5dG3369EFAQABCQkKkDocxxlguLly4gFu3bmHs2LEwNTWVOpxcqWSBXaJECQwcOBDnz5/HzZs3pQ6HMcZYDp49ewZfX1+0a9cOFhYWUofDCrHM5QeZzfQYY4z9eDIfuk+dOlXiSD5PJQtsAPj1118BANOmTZM4EsYYYzmZPn06AMDV1VXiSFhhZ2VlhcaNG8PHx4fXYjPG2A/o8uXLOHbsGLp06QIzMzOpw/kslS2wa9SoAUdHR/j5+eH69etSh8MYYyyL58+fY/fu3ejQoQPs7OykDoepgM2bNwMA3N3dJY6EMcbYpzIbVP/+++8SR/JlKltgA8DcuXMB/H+6AWOMsR9D5rZJs2fPljgSpiosLS3RvHlz7Ny5E2/evJE6HMYYYx89fvwYp0+fRq9evWBiYiJ1OF+k0gW2ubk5unXrhj///BPXrl2TOhzGGGMAHj58CG9vbzRp0gQNGzaUOhymQjJHSMaOHStxJIwxxjI5OTkBKDjbLKt0gQ0Ay5cvBwDMnz9f4kgYY4wB/9/pIfPvM2P5xd7eHm3atMGePXvw8uVLqcNhjDGVd+/ePZw5cwb9+/dHzZo1pQ4nT1S+wK5UqRI6duyII0eO4OnTp1KHwxhjKi0uLg5btmxBgwYNePSaSWLKlCkA+AEPY4z9CDIfuv/oncOzUvkCG/j/NyyzszhjjDFpzJs3D0QkFjmM5be2bduiatWqWLNmDV69eiV1OIwxprLu37+P7du3o2nTprCyspI6nDzjAhtAixYt0LJlS+zatQsPHjyQOhzGGFNJsbGxWLlyJWrXro1evXpJHQ5TUYIgYO3atUhLSxOb7THGGMt/mY1OV6xYIXEkX4cL7I+WLVsGAPDw8JA4EsYYU02Zf4eXLl0KNTW+PTHptGnTBrVr14aXlxeSk5OlDocxxlROVFQU9uzZg1atWqF+/fpSh/NVOIP5yMrKCo0aNcKGDRsQHh4udTiMMaZSoqKi4OHhAQsLC7Rp00bqcBiDm5sbMjIy4OrqKnUojDGmcpydnQEUzO06ucDOwtvbGwB3FGeMsfzm4eGB1NRUrF27lkev2Q/h559/ho2NDdasWYO4uDipw2GMMZXx4sUL7Ny5E+3bt4ednZ3U4Xw1zmKyqFu3LmxtbbF161ZERkZKHQ5jjKmE+Ph4eHl5oXLlymjXrp3U4TAmmjx5MjIyMgrc+j/GGCvIFixYAKBgdQ7PigvsT2SOXmdOS2CMMaZcM2fORFJSEhYuXAhBEKQOhzFRnz59UKVKFbi7u+Pt27dSh8MYY4VeeHg4vL290ahRI9jb20sdzjfhAvsTLVu2RIsWLbB7925ERUVJHQ5jjBVqycnJ8PLyQvXq1dG7d2+pw2FMjrq6uth8b9WqVdIGwxhjKiCz4fSKFSsK7EN3LrBz4OLiAoD3xWaMMWVzd3dHenq6+HeXsR9N586dYWJigiVLlvAoNmOMKdHjx4+xbt06WFlZoXHjxlKH8824wM5B27ZtYW9vDx8fH0RHR0sdDmOMFUoJCQlYsmQJatasiQEDBkgdDmO58vb2RkpKCpYsWSJ1KIwxVmhlLtXdsmWLxJF8Hy6wczFnzhwAwG+//SZxJIwxVjh5enpCJpPB1dW1wE4DY6qhY8eOsLCwgJeXF1JSUqQOhzHGCp13795hy5YtaNSoEerWrSt1ON9FIQW2IAhTBEEgQRBKKuJ4PwI7OztYW1tjxYoVePLkidThMMZYofL27VvMmTMHFStW5LXXhVBhywsEQYCbmxvi4+Mxa9YsqcNhjLFCZ+LEiQAK5r7Xn/ruAlsQBFMAbQEUqipUEARxWw43NzeJo2GMscJlwYIFSE1NxZIlS6ChoSF1OEyBCmte4OjoCEtLS6xevZqboDLGmAI9fPgQW7duhZ2dHdq3by91ON9NESPYKwFMB0AKONYPpXnz5mjRogV27NiB0NBQqcNhjLFCITo6GitWrEDNmjXRs2dPqcNhilco8wJBELBs2TKkpqZi3rx5UofDGGOFRmaj08WLF0sciWJ8V4EtCEJXAM+J6GYeXjtCEIRrgiBce/369fecNl9ljmIvX75c4kgYY6xwyPy7umrVKqipcSuQwiSveUFBzQk6dOiAOnXqwMvLC8nJyVKHwxhjBV5kZCT+/PNPtG7dukB3Ds/qi5mNIAj+giDcyeGjKwBXAHmaKE9EG4jIhohsSpUq9b1x5xsrKys0bNgQGzduxMOHD6UOhzHGCrSoqCgsX74clSpVQqtWraQOh30DReQFBTUnAP6/bIyXjzHG2PebOnUqgMKx9jrTFwtsImpNRDU//QAQDqAigJuCIDwGYALgX0EQyio35Py3adMmAOApYYwx9p0WL16M1NRUrFu3TupQ2DdS9bygd+/eqF+/PlatWsX7YjPG2Hd49uwZdu7ciQ4dOqBZs2ZSh6Mw3zw3j4huE1FpIjInInMAzwBYE1GkwqL7QdSsWRMODg74448/cOvWLanDYYyxAunJkydYuXIlbGxs0Lp1a6nDYQqmSnmBu7s7ZDIZJk2aJHUojDFWYI0cORJA4RvE5MVvebRs2TIAgLOzs8SRMMZYwfTrr78C+PD3lPe9ZgWZg4MDmjZtiu3btyMsLEzqcBhjrMC5evUqTpw4gW7dusHGxkbqcBRKYQX2xyfW0Yo63o+mRo0a6NevH86cOYOLFy9KHQ5jjBUoDx48wO7du9G6dWs0b95c6nBYPijsecHq1asBAHPmzJE4EsYYK3hcXV0BAB4eHhJHong8gv0VFi5cCIA7ijPG2NfKnAVUGG+kTDVZW1ujVatW2LNnD6KjC+1zBMYYU7j//vsPZ86cQd++fWFhYSF1OArHBfZXqFixIrp06QJfX18EBQVJHQ5jjBUIDx48wPr169G4cWPUrVtX6nAYU5jMrrcTJ06UOBLGGCs4Mv9mzpw5U+JIlIML7K/k7e0N4P/TGhhjjH3erFmzAAC///67xJEwplh2dnZo3749du3ahfDwcKnDYYyxH15wcDD8/f3h6OiIGjVqSB2OUnCB/ZXKly+PQYMGwd/fH2fOnJE6HMYY+6EFBwdjz549cHBwgJWVldThMKZw7u7uAIDx48dLHAljjP34RowYAeD/D98LIy6wv8GSJUsAfOgoTkQSR8MYYz+uKVOmAODeFazwsrGxQdeuXXHixAlcunRJ6nAYY+yHdezYMQQFBcHJyQk//fST1OEoDRfY36BcuXKYOnUqbt26haNHj0odDmOM/ZCuXr2KgIAAODo6onr16lKHw5jSrF27FsD/12QzxhjLbubMmdDU1MTSpUulDkWpuMD+RjNmzIC2tjYmTpyIjIwMqcNhjLEfzpgxYwAAc+fOlTYQxpTM1NQUjo6O8Pf3h7+/v9ThMMbYD+ePP/7AzZs3MWnSJBgZGUkdjlJxgf2NSpYsCRcXFzx+/Bi7d++WOhzGGPuhBAQE4Pr16xg3bhzMzc2lDocxpVu1ahUAYNq0abx8jDHGskhPT8fMmTNRpEgRzJ8/X+pwlI4L7O8wZcoUGBgYYPLkyZDJZFKHwxhjPwQiwpgxY6ChoQE3Nzepw2EsXxgaGmLChAkIDg7GgQMHpA6HMcZ+GBs3bsSTJ08wd+5cFClSROpwlI4L7O+gp6eHWbNmITo6WnxyzRhjqu6PP/5AaGgoJk6ciLJly0odDmP5xt3dHWpqapgxYwbS09OlDocxxiSXmJiI2bNnw8DAAM7OzlKHky+4wP5OU6ZMgYmJCebPn4+0tDSpw2GMMUkREebOnQsDAwMsWLBA6nAYy1fFihXDggULEB4ezsvHGGMMgLe3N6Kjo7Fy5UpoaWlJHU6+4AL7O6mpqWHu3LmIjY3FuHHjpA6HMcYkNXfuXDx69Aiurq7Q1taWOhzG8t348eNhZGSE0aNH4+3bt1KHwxhjknny5AlcXFxQoUIF/PLLL1KHk2+4wFaAoUOHonPnzggODsb79++lDocxxiRz7tw52NnZYdq0aVKHwpgk9PX14ePjg3LlyuHkyZNSh8MYY5I5evQozM3N8ccff0BTU1PqcPKNIEWnS0EQXgOIyPcTf5uSAKKlDiIfqMJ1qsI1AqpxnapwjYBqXGdBukYzIioldRCFDecEPyRVuE5VuEZANa5TFa4RUI3rLGjXmGNeIEmBXZAIgnCNiGykjkPZVOE6VeEaAdW4TlW4RkA1rlMVrpEVHqry86oK16kK1wioxnWqwjUCqnGdheUaeYo4Y4wxxhhjjDGmAFxgM8YYY4wxxhhjCsAF9pdtkDqAfKIK16kK1wioxnWqwjUCqnGdqnCNrPBQlZ9XVbhOVbhGQDWuUxWuEVCN6ywU18hrsBljjDHGGGOMMQXgEWzGGGOMMcYYY0wBuMDOI0EQxguCcE8QhLuCICyVOh5lEQRhiiAIJAhCSaljUQZBEDw+fh9vCYJwUBAEQ6ljUhRBEBwEQQgVBOGBIAi/Sh2PMgiCYCoIwllBEP77+Ls4UeqYlEUQBHVBEG4IgnBM6liURRAEQ0EQ9n/8nQwRBKGx1DExlheqkhMAhTsv4JygYOOcoHApTDkBF9h5IAiCPYCuAOoQUQ0AyyQOSSkEQTAF0BbAE6ljUSI/ADWJqDaAMAAuEsejEIIgqAP4HUB7ANUB9BMEobq0USlFOoApRFQdQCMAYwvpdQLARAAhUgehZKsB/EVE1QDUQeG/XlYIqEpOAKhEXsA5QcHGOUHhUmhyAi6w82Y0gMVElAIARBQlcTzKshLAdACFdmE+EZ0movSP/7wMwETKeBSoAYAHRBRORKkA9uBDAlioENFLIvr343/H48MfX2Npo1I8QRBMAHQEsEnqWJRFEIRiAOwAbAYAIkololhJg2Isb1QlJwAKeV7AOUHBxjlB4VHYcgIusPOmCoBmgiBcEQThb0EQ6ksdkKIJgtAVwHMiuil1LPloKICTUgehIMYAnmb59zMUwptMVoIgmAOwAnBF4lCUYRU+JLUyieNQpooAXgPw+TjtbZMgCHpSB8VYHhT6nABQybyAc4ICjHOCAq9Q5QQaUgfwoxAEwR9A2Rw+5YYPX6fi+DD9pD6AfYIgVKIC1oL9C9foig/TwAq8z10nER3++Bo3fJhatDM/Y2OKIQiCPoADACYR0Tup41EkQRA6AYgiouuCILSQOBxl0gBgDWA8EV0RBGE1gF8BzJI2LMZUIycAVCMv4Jyg8OOcoFAoVDkBF9gfEVHr3D4nCMJoAL4fb55XBUGQASiJD09aCozcrlEQhFr48OTopiAIwIcpUv8KgtCAiCLzMUSF+Nz3EgAEQRgMoBOAVgUxIcrFcwCmWf5t8vH/FTqCIGjiw410JxH5Sh2PEjQB0EUQhA4AtAEUFQRhBxENlDguRXsG4BkRZY427MeHmyljklOFnABQjbyAcwIAnBMUZJwTFEA8RTxvDgGwBwBBEKoA0AIQLWVAikREt4moNBGZE5E5PvyQWxe0m2heCILggA/TbLoQUZLU8ShQEIDKgiBUFARBC0BfAEckjknhhA+Z3mYAIUS0Qup4lIGIXIjI5OPvYl8AAYXwRoqPf1+eCoJQ9eP/agXgPwlDYiyvDqEQ5wSA6uQFnBMUbJwTFB6FLSfgEey82QJgiyAIdwCkAnAsRE85VY0ngCIA/D4+lb9MRKOkDen7EVG6IAjjAJwCoA5gCxHdlTgsZWgCYBCA24IgBH/8f65EdEK6kNh3GA9g58cEMBzAEInjYSwvOCcoPDgnKNg4JyhcCk1OIPA9gTHGGGOMMcYY+348RZwxxhhjjDHGGFMALrAZY4wxxhhjjDEF4AKbMcYYY4wxxhhTAC6wGWOMMcYYY4wxBeACmzHGGGOMMcYYUwAusBljjDHGGGOMMQXgApsxxhhjjDHGGFMALrAZY4wxxhhjjDEF+B8wIIMgz1pyngAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "def logprob_fn(x1, x2):\n", + " \"\"\"Banana density\"\"\"\n", + " return stats.norm.logpdf(x1, 0.0, jnp.sqrt(8.0)) + stats.norm.logpdf(\n", + " x2, 1 / 4 * x1**2, 1.0\n", + " )\n", + "\n", + "\n", + "logprob = lambda x: logprob_fn(**x)\n", + "plot_contour(logprob)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initial state and sampler parameters\n", + "\n", + "Since the algorithm doesn't have an accept/reject step, we can't tune the parameters of the bijection according to its acceptance probability. By weighing the samples we are are doing, in a sense, importance sampling; hence, an alternative would be develop and adaptive procedure that aims at reducing the variance of the weights. \n", + "\n", + "The algorithm samples orbits of length `period`. Each iteration, starting from an initial point sampled from the previous orbit, shifts its initial point's position in the orbit, hence making the algorithm irreversible, and samples the whole orbit, forwards and backwards in order to cover the whole period, for steps of length `step_size`. The samples are then weighted and returned with its corresponding weights." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "inv_mass_matrix = jnp.ones(2)\n", + "period = 10\n", + "step_size = 1e-1" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "initial_position = {\"x1\": 0.0, \"x2\": 0.0}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Velocity Verlet\n", + "\n", + "The integrator usually found in implementations of HMC. It creates an orbit by discretizing the solution to Hamilton's equations of the Hamiltonian function \n", + "\n", + "$$ H(x, v) = \\frac{1}{2}\\left(\\frac{x_1^2}{8} + \\left(x_2 - \\frac{1}{4}x_1^2\\right)^2\\right) + \\frac{1}{2}v^Tv .$$\n", + "\n", + "The plots include the unweighted samples to get an idea of how the integrator is exploring the sample space before the weight's \"correction\"." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 557 ms, sys: 5.59 ms, total: 563 ms\n", + "Wall time: 567 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "init_fn, vv_kernel = orbital(\n", + " logprob, step_size, inv_mass_matrix, period, bijection=integrators.velocity_verlet\n", + ")\n", + "initial_state = init_fn(initial_position)\n", + "vv_kernel = jax.jit(vv_kernel)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cabezasg/.local/lib/python3.8/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n", + " warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.7 s, sys: 27.3 ms, total: 2.73 s\n", + "Wall time: 2.76 s\n" + ] + } + ], + "source": [ + "%%time\n", + "rng_key = jax.random.PRNGKey(0)\n", + "states = inference_loop(rng_key, vv_kernel, initial_state, 10_000)\n", + "\n", + "samples = states.positions\n", + "weights = states.weights" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA9gAAAF1CAYAAAATN0JoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdeZxddX34/9fn7HefO/tM9hAIIWENSsQFqYpCkR0XcGktX2219WtrNzdUtH611db2p60LVm0FN1BRqpVFBVHDkgCSAAnJJJlMZp+5+71n//z+uJNhJhtZJgnL5/l48GDuOeee87kXvefzPp/P5/0WUkoURVEURVEURVEURTky2vFugKIoiqIoiqIoiqI8H6gAW1EURVEURVEURVHmgAqwFUVRFEVRFEVRFGUOqABbURRFURRFURRFUeaACrAVRVEURVEURVEUZQ6oAFtRFEVRFEVRFEVR5oAKsBXlKBFCXCuEuOMgj/0jIcR9R7EtR/X8c0EIsV0I8erj3Q5FURRFORzqvn9o1H1feb5SAbbyvCSEkEKIZXts+5gQ4lvHqg1SypuklBfMxbmEEL8SQlw3F+dSFEVRFKVJCPEBIcTP9tj21H62velA51L3fUVRQAXYiqIoiqIoygvXvcC5QggdQAjRA5jAmXtsWzZ1rKIoygGpAFt5QRJCvFIIMSCEeL8QYlQIMSSE+OOpfUuEEEUhhDb1+qtCiNEZ7/1vIcT7pv7OCSG+NvX+XUKIT864Ic+aniWEuEAIsUkIURJC/LsQ4p49n04LIT4rhCgIIbYJIS6c2vYPwMuBLwghqkKIL0xtP1kIcacQYnLqvG+YcZ42IcSPhRBlIcQDwAkH+C4cIcS3hBATU5/7QSFE19S+PxZCPCGEqAgh+oQQ79rHd/i3M77Dy4QQFwkhNk+164Mzjv+YEOIWIcR3p863Xghx+n7apAkh/l4IsXWqXd8TQrQ+U3sVRVEU5RA9SDOgPmPq9cuBXwKb9ti2VUo5qO776r6vKM9EBdjKC1k3kAPmAX8CfFEIkZdSbgPKwJlTx70CqAohVky9Pg+4Z+rvbwAhzSfbZwIXAHtN6RJCtAO3AB8A2mjeuM/d47Bzpra3A/8IfE0IIaSUHwJ+Dfy5lDItpfxzIUQKuBO4GegE3gT8uxDilKlzfRFwgR7gHVP/7M/bp76HBVNt+1OgMbVvFLgYyAJ/DPyLEOKsGe/tBhya3+H1wFeBtwCraXYOPiKEWDLj+EuB7wOtU23/kRDC3Eeb/gK4jOZ33QsUpj7TM7VXURRFUQ6alNIH7qd5r2fq378G7ttj2+7R62+g7vvqvq8oB6ACbOWFLABukFIGUsqfAlVg+dS+e4DzhBDdU69vmXq9hOZN59Gpp6cXAe+TUtaklKPAv9C86e3pImCjlPIHUsoQ+DdgeI9jdkgpvyqljIBv0rxJ7u8J7cXAdinl16WUoZTyYeBW4OqpJ+lXAtdPtWvD1PkO9D20AcuklJGUcp2UsgwgpfwfKeVW2XQPcAfNG+jM9/6DlDIAvkOzk/CvUsqKlHIj8Dgw82n1OinlLVPH/zPNm/SafbTpT4EPSSkHpJQe8DHgKiGEcaD2KoqiKMphuIeng+mX0wxuf73HtnvUfX/6veq+rygHYBzvBijKURLRnPI1k0nzR3q3iamb3m51ID319z3AJcAAzafWvwLeSvPp8K+llLEQYtHUOYeEELvPoQE799Ge3pnbpZRSCDGwxzHDM/bXp86ZZt8WAecIIYozthnAfwMdU3/PbMeO/ZyHqfcsAL4jhGgBvkXzJhdMTVf7KHDS1GdLAo/NeO/EVMcAnn6aPDJjf2OPzzDzO4invoPe/Xy+Hwoh4hnbIpodj/229wCfUVEURVH2517gPVNTkjuklE8JIUaAb05tWzV1jLrvq/u+ojwjNYKtPF/1A4v32LaEA99wZrqH5hPbV079fR/wUmZPD98JeEC7lLJl6p+slHLlPs43BMzf/UI076Lz93Hc/sg9Xu8E7plx3ZapaWR/BozRnL62YMbxC/d74uYI/sellKfQnL52MfA2IYRN8+n4Z4EuKWUL8FNA7O9cB2G6TaK5xn0+MLiP43YCF+7x+Rwp5a79tfcI2qQoiqK8sP2O5hTk/wP8BmBqhHRwatvg1PIxdd8/dOq+r7zgqABbeb76LvBhIcT8qcQZrwZeT3Oq9zOSUj5F8ynsW2je0Mo0n9BeyVSALaUcojl16nNCiOzUdU4QQpy3j1P+D3DqVDIQA3gPzXVMB2sEWDrj9e3ASUKItwohzKl/XiSEWDH1ZPkHwMeEEMmp9Vlv39+JhRDnCyFOnZpiVqY5yh8DFmAzdeOeeqp9pOVHVgshrpj6Dt5Hs6Oydh/HfQn4h6lZAgghOoQQlz5DexVFURTlkEkpG8BDwF/RnBq+231T2+6dOk7d9w+duu8rLzgqwFaer24Afkvz5ligmTzk2ql1SQfrHppToXbOeC2A9TOOeRvNG9LjU9e5heYaqlmklOPA1VPtmABOoXkz9w6yLf9Kcy1SQQjxb1LKCs2b3ptoPgkeBj5D88YI8Oc0p2gN00zI8vUDnLt7qt1l4Impz/nfU9d4L/C9qc92DfDjg2zv/twGvHHqfG8FrtjPFK9/nbrWHUKICs2b8TkHau8RtktRFEV5YbuHZvKw+2Zs+/XUtpnludR9/9Co+77ygiOk3HMGiqIoR9vUNKkBmkH/L493e44FIcTHaCYoecvxbouiKIqiHEvqvq8oLxxqBFtRjhEhxGuFEC1Ta5w+SHM0fF/TpBRFURRFeY5T931FeWFSAbaiHDsvAbYC4zTXg182te5LURRFUZTnH3XfV5QXIDVFXFEURVEURVEURVHmgBrBVhRFURRFURRFUZQ5oAJsRVEURVEURVEURZkDxvG4aHt7u1y8ePHxuLSiKIqiHJZ169aNSyk7jnc7nm9Un0BRFEU5WnavhhZi7s+9v37BcQmwFy9ezEMPPXQ8Lq0oiqI8j63bUWBt3wRrlraxelF+Ts8thNgxpydUANUnUBRFUY6OyZpPse4D0JmxSTvmnJ5/f/2C4xJgK4qiKMpcW7ejwLU3rsUPYyxD46br1sx5kK0oiqIoyrNfHEtKdZ+0bRDGMaVGMOcB9v6oNdiKoijK88Lavgn8MCaWEIQxa/smjneTFEVRFEU5DjRNYBkadT/EDWIcUz9m11Yj2IqiKMrzwpqlbViGRhDGmIbGmqVtx7tJiqIoiqIcJ11Zh3IjIIolGfvYhb0qwFYURVGeF1YvynPTdWv4wfoB5PFujKIoiqIox5UQgqoXEcYxNT+ityWBZRz9CdwHfQUhxH8KIUaFEBtmbPsnIcSTQojfCyF+KIRoOSqtVBRFUZSDdOv6Ab7zQD/X3riWdTsKx7s5z1uqX6AoiqIcLXEsm8u+4sN/ZO6HMUEUk7INYilxg3AOW7h/hxLCfwN43R7b7gRWSSlPAzYDH5ijdimKoigvMOt2FPjiL7ccUVCs1mEfU99A9QsURVGUORbFkqFSg12FOoOlBmEUH9Z5DF0gBDT8CCnBMo7NOuyDniIupbxXCLF4j213zHi5FrhqjtqlKIqivIDMVQZwtQ772FH9AkVRFGUuhVGMJgReGOGHkpRtUPVDvDDG0A9uXDiImqPWuhDEEtrTNlEUEUrQpophx7FE08T0sZauoWsCMUfFsudyDfY7gO/ub6cQ4p3AOwEWLlw4h5dVFEVRnotm1qze18jzwQbYu8+TT1oU6j7XX7ySX20aZaTssmm4AjC9/5ebRhktu+QSJhsGyyRMjfecfyLXnKPuS0fBfvsFqk+gKIqizFSs+xTqProQtCQtJBIvjECCrh1c4BtEMYPFBn4YM1iokUvaWBrsKNQRQpA0dVqSNroOhhDsLNbx/JjJmgcIFrclOWtRK6kjLOc1JwG2EOJDQAjctL9jpJRfAb4CcPbZZ6v8M4qiKC9ge45YX3/xysMaeb75/n6uv20DUSz3mdjs0YHH0DU40OyyD/7wMQAVZM+hZ+oXqD6BoiiKslsUS4p1n6SlE8aSehDSk0tQ80NaTW1Wia146n4/M+iOYkmtETBZD+gv1Ci7AdvGahhIRqseE1WfMxa1okuBZbos7kiyYVeJhGmwc7LGcNmlM2M3A3xdsGZpB9pBBvX7csQBthDij4CLgVdJKdVNUlEURXlGP1g/gBfESJoj1hsHS1xx1nwEcMVZ86dHr9ftKExnBb9yavvN9/fzxV9todzwqbrRM2YMP5ilWz/bMMQ15yycNap+OFPUFdUvUBRFUQ6NH0Y0goh4ahq3aWnEUpK2jeng2g0i+ifqTFY98mmLrqyDY+pUGgF3PT7MWNUjjmPa0jZDBZekI7h/2yTFuk/CMtm0q0gyYfKiRXkMTcMNI9pTDjU3oNQI6M06yFjiRRBLCTFM1jy8MKYlaZKyD35U+4gCbCHE64C/Bc6TUtaP5FyKoijKC8O6HQW+/9DOpwNjIfj+QzsJY4llaFxx1vzp49781eYoN8C37++nLWMxXvHnvE0XruqZs3XgL2SqX6AoiqIcCjeIGC55RBEMVussak3hhRFVr5nxuytrk7JNRssepYaPoQvGyi5bRyrkUiaPbC+yq1gn5ZhUGz5dWYeEpVFrBDRCia7rFBs+vfkEi1ptWjMJGkHIqb0teGFMPmWRtDQm6z7Lu9MsbU9h6BrFuk+5EWKbGiNljwWtOuZBrgM/6ABbCPFt4JVAuxBiAPgozeygNnDn1KLwtVLKPz2kb1VRFEV5QVnbN0E4o+xGHEtimB7N3r3+em3fBEH49PCzhMMKrndP8tI16Mo67Cq60/syts4HLjqFa85ZyBd/uWV6VN0PDm0d+AuR6hcoiqIoR8oPI/wwou5HaEJQdkMsQ5BLWPhhTN2PSNkmEomhaZRdn22jNYqNgIyjs6tYJ4hjxio1JBqmoXHmolaGSw36Sy6TVY9SLcYQkvmtGbK2Tk9XmqxjMlHz6Mk5ZBIGVTeiLW3TmrKo+yFhLAnimHo9xAtienMJzINMQn4oWcTfvI/NXzvY9yuKoigKNDN9CyFgxuxhXRNIKTENjXzS4ou/3EI+aWEa2vQI9oHoWjNhiRftPSPZ0AVXn72AK8+az6bhCtfftoE4llimxjfecQ4AX/zlFiqNYHpUPQbySWsuPu7zluoXKIqiKEdKExq7ig3Gqy5taZuOtEAXgqoXIiVkEwZVLyTrmIRRzObhIg/sGMf1ArqzDjECWxe4gWRB3iZrm7SmbbpyDk+Nlqm7Ad05h13FBq4fkk6YeKGk4oVUvYgglhTrAe1pB8fU2FVoIAFTFxTrPlEsSVgGNT/EsQ4uwp7LLOKKoiiKcpBmBtdw3cuWsHGozMqeLDfcvnF6mvaaJa3c+9T4QZ3R30dwDRBGknktCQBuuH0jsWyW57j+4pUA09PCNSEQUy3TBBTqcz8VXVEURVGUp9X9AFOHaiPED2MytsEpvTmKdR9DE5RqPgEgY4muCx4fKjNachmv+YwUXU7ozrCkI8PSjiwdGRuEZOdknZN7Mpy+IM/AZJ2WfJKGH7FtvM6KnhaSpk6xHpCwdBJWgoYXMT+fYKjk4pgahq5RbgS0Ji3StkEo5ayZd89EBdiKoijKMXXr+oFZicfCGP7zt9sJo5jfbp0gmrqJBWHM+v7CQZ3zQInMJOxVCkwgKdT9WduQctZIuqqhrSiKoihHV92PGS17xEJS9yOiMGbHZB1TE0zUfCZrASe0J0k7BrtKddwgJpuwgBgRayRNg7ofMVEt0z+pMy/vEEpBZ8YmYep0ZG3CGBa3pUhZBroO+ZRNGLvU/QiApK03y3hZOuNVDy2ISdo6CUOn2AjQNUEucYySnCmKoijKodpX4YvpaeAzpo1rmpi++R3p9TYNV1iztG2fpcBmbrv+4pUU6r7KIq4oiqIox0DC1EiYGm5kYArBQLFBRyzRdY0oiklZGqNVjwhJwjCY12IzUKgRS42lnUl0Q2OwWGPreA1NSlb5ObK2yX1PxfS0JDh3WSdbRyskTJNc0mDHRB1d0+jKJnCMZiK1zFTwnEuYWIZGLCFh6uiaIJsw0YQ4pLJdKsBWFEVRjqkrzprPtx/o50CzrQRwSk+Wx3aVjvh6Erj+tg18910v4abr1uxVhmtf2xRFURRFOfpySYueliRGxWOi0qDcCJms+fhhzLy8w+kLW3HDmO6MTckLueSMhaya30qx6pOydTYOFhkqNmhP2oxUXIaLPlaHyULHIIhiym6IoevEMqZQD7FNg6obkrZD8qnZuVaao9izw2PjIDOHz3To71AURVGUI7B6UZ5PXnYq4gAPg21T440vWohlaPsc8T5UUSyns4K/5/xlswLpfW1TFEVRFOXoc0ydlfNamJdPYJsGk27AlrEKgQypeCGWoTOvJUFr2kZDsG2yxmjVY2FHklzKoSeXQBMCQwhyjkV7xmZRe4pM0mzWs45iurM2tmEQhjFuEBHGHLAPcqTUCLaiKIpyzF1zzkK++2A/jw7sPULdkjT52ttfNB3wfvhHj82cOX5YdE2oNdWKoiiK8iwkkaQsg/asw0jFRZOwqD2NjDV6cza63szerWvQlbExBUgpKNY9FrdnuHCVxpPDVfJpkyWtadIJEzdojlbHkSRG0pax6MjaWLoglzDIOge/pvpQqQBbURRFOebW7SjQmXWAvQPsN529YDq4PthM3jOrfomp1/GM14vbU2warqhRakVRFEV5Fqm4AWMVr5lwNJZ0pi2CGIJIsqI7hRdJLGIMXcM2dcIoRtc1BksNZBxjWzpd+RQvOqGdloRNGMVEMmaiGpBLGDT8CAQkTINSzWOw7DJadmlNWaSPUpCtAmxFURTlmFq3ozBdGsvQYGVvjlzCpH+yzhkLWsgkTG6+v59C3SeftDD0fdfCNjTBDZeuolD3qTQCbrxv23R96+svXsmGwRJbRio8sL3AltEqH/zhY0Bz9FxRFEVRlOOv3AhwTI20bWDpgpcsbSOS4AUhxUbIcMnFDUJySYuUZeDrOhqChKFRrEcEQYypCzrTDo5l4AYRu4p1SnWfshvQnrLpytqkbJO+8Ro5xySKYfNIhbMWtR6Vz6QCbEVRFOWYml0uCy5Y2c17zl82HXjf9sjgdJVsx9Q476QO7nx8ZK/zXHxaD9ecs3D6fbGUCE3w8hM7WN6d4ZpzFvLWr90/6z0/2zCkAmxFURRFeZawdI3xmoehaaQdA3sqyVgYx0SxxAsj+sarLO1I0/AjkpZBS9JiqNRgy1iNrO2yuD1N1QtxLINiPcDUNRa1p9gxUSOWMUEkkVICEqFpCCmb67OPEpXkTFEURTmm8kmrWfJCNNdGDxYbrNtRmA68Z97y/CCmM2Nj7KM8xkStOX18ZsAexZK7Hh/h2hvXsm5HgQtX9cx6z56vFUVRFEU5PuKpADqKYhp+gKkL6n6IlBLb1Cm7ITsLdepeTM1tltSyDI1GGBFMlfAyNA0hwJua6eYYzVlvpXqAH8Y4ls5E1cMNYk7pyeEGMVLA8u7sUftcKsBWFEVRjpl1OwrccPtGolgiaJbQ+vYD/Vx749rpwHsmTROs7M3xByd37nUuAXzxl1vIJ61Z2cYlEIQxa/smWN6d4YJTujh9fo5PXX6qGr1WFEVRlGcJP4oJYujMJpDAwGSDoaLLeNXDNnTySZPFbSmWdqZo+DGGpuEYGhnbQBOSSsNnsuEzUm5Q9wIGCnVA0pG2scxmDWtdaAghCOMYL4zoaUlwYmeGjEpypiiKojwf3Lp+AC9ojlLHEuJITgfEhbrPDZeu4vrbNhDFEl0TXPeyJdxw+0b8MJ6VyAzg10+N8+unxqeP6xuvcfcTI82p50JQaQTTa70tQ2N5d+Z4fWxFURRFUWaIY0mx5jFUbJC2DMpexPx8gqSpU/MiOjLQmbUZKXu0ahrzcjrtaZudxQaGrtGScDihS9KatCi5ISNll7JboVD3Wb0wj2ka+FHMUyMV2lIWJQ0CCUlTp1D3Sdo6tqEflc+mAmxFURTlmFi3o8At6wamp4BrGoBASolpaKxZ2sbqRXmWd2dY2zfBmqVte63Xnmn3ecJY8qV7+7B0QTS1MYolN963jSiWs0a0VRZxRVEURTn+6n5IPYiZ35pguNjAEDBcapBP2rSnLQBStsmCVp1YSixdw49ipATb0EnaBuNDAcMVD9cLuW19PzsmaiDhxK4s1523lMXtWTaPVDAMjdGqj23qOFNBtdirVzF3VICtKIqiHHXrdhT4/F2bCabWSAkACbFsjlRff/HK6eB39aL8rEDYMjSCsFmWI9hjjfZMfvT0nuYIefPcMwN4RVEURVGOHyklkzWf4VIDN4iZ1+LgR5KerE0MaELQlranjzf1p1c0W7pG1jGouAExcNr8FqI45pdPDjNadQljSFk6Oyfr3LlhhMtXW2gC0raJlBJNCMJY0pZqLi07WlSArSiKohxVM8tySZrJPzRNTI8uSyn3W+969aI8N123hrV9E+wqNvjOA/0cTOJPQfMmfd3LlpBJmNOj44qiKIqiHD9uEFNqBGQdg/FqnZGKR8YxsE0dMZWHRYh9jy4LIejMOrQkTdKOyc6JGltHa3gxpByLiWqVejViUWuK9oxFI2j+PVSs4wYx3S0OnVn7qE0N300F2IqiKMpRtbZvYnrdNcCp83MsaU9x2yODQHMt9iM7i6zbUZgVBK/bUeDW9QOMVzwkzaBZ1wRxtP8IWwCXntHL7b8fIool3/jddm66bo0KrhVFURTlWaLmBhQaPn4oMXWBjCT/u2GIgUKNhfkkZy5uY2l7GtvUiaIYP4xx/ZDJetDMIB5K8gmDMI5JmAYrujOYmiSla3iRZGVvipRjUfdj2jM2XhgzL28SxjHlRkBHRgXYiqIoynNYPmnNmtb92K4Sjw6Upl9L4M7HR7jz8RHmtzicvbiVB7dPsqvoHvK1JM3yXbFUa68VRVEU5dlGIKn6EWMVj5xjEUaS7eM1CjWfiWrA5qFhfrNlnM5ckuVdKQSCihtQckNSpk5rxkETggWtCWIp6colcCwdGces7Glh03CZTSMVhsoBSMkZC1twLJ0wbtbDzjhHv4iWCrAVRVGUo6pQ96dLckFzxHp/BoouA1Mj24dry0gFY2oK+u6117vrbKup4oqiKIpy/ASxpCNjYxsaNTcinqoaEoQxxYZHqRFS8UJiCbqM8WOo+SETNQ9T1zhtXo6eliQDEzVStklnl00Ym7SnE5TdgHI9JGPrtCUdto3VGCvVmNeaoRFEZByNhKkzXHLRNGhNWhj63AfcKsBWFEVRjqo1S9swdTErCdnRNFT20DXBH6zo4k/POwFgVrkuNWVcURRFUY4P29AxdI04ljwxXMSPYiwhGK+4FCs+QRiScExiKakHEcV6SKHhY+kaETGP7SyjCYFjaui6YMtYdTrxmecHoEsagSSQITLU2DreYKIWsrgjTT5pMVh0kUhiKYljSXcuMeefUQXYiqIoypzY3yjx6kV5PnbJKj74w8eOWVuiWHLn4yOc0J4ikzCnS32pKeOKoiiKcnSFUcxoxcUPJfmkSS5pTe+zDI2ujM2vnhhmy3CVwVKDMI7oTDs4FgSxRhBKHAuytkFX1uHxwRJxHJPPJrCEoB5EWIbOZC2gOF7npK4MwxWXUt2nI2lhCkHJDZnXYhPGMU8MlXliuMzJXTlaUhataYs4loQHmlJ3BFSArSiKohyxmZnCZ44S7w66dxUbx6VdX7q3j8vO6J0u9aXKdSmKoijK0VVqBM0g2dSYqPskbQNT16i4AZNTeVKqXoSmSaI4olB1cf2YUsMnk7BIGDoL8xlsU0NoGp3ZBEEYsagtjSRmsurzyFiBzqwDSAYmKjQCiaYJ/FjSnnOYrxmM11x+9eQopoBc2qIlYWEagoSlowlBV9Y5Kp9fBdiKoijKEVvbN7HXKDE8PTX7aKxxOlg/fnSQT152KoW6r9ZgK4qiKMpRpgmBRE6X1RTsHtX2SJg6bhCypCPB1pEyUSywTBMhY2zdxDY0whhGKx5JyyDrGOTTFrYGxDElL8QLI6I4QhOCqhtQsXTcMGZePkGxFqAB89qSJC2d/mKNMIgJopiJmk9n1qEjbaNrzUD7aFABtqIoinLE1ixt22uUeGbQHUUxyzrTbBmtHvO2SQkbBkvMa5n7dVaKoiiKosyWTZh4YTOo7UzbGLpGGMUACAGa0Dh7YRtBKNk6Umao7CJlM6iueSEL25Kc0pVjqFKn3AjIJEwiCbUgYkl7iq0jFXRNp+EHyBjOXNjGULlO0tJZ3JZBxDETdZ9syuT0TAuDBRdDCBa3pbB0jcFSA0PTaEtbpGwDXQg0bd+1tw+HCrAVRVGUI7Z6UZ6brluz1xrsmUH3qt7sUQ2wE6aGNxXQ7yYAUxfcsm6AMIoxNMHVZy/girPmq5FsRVEURTkKdE3QnZs9/drQNTozNhNVH8sQJCyDhKXTmUtS9gKE0FhzQgexlCBguORCFDNUrjFY1lg9P8eugosXRIRRTHfOIQgjIl3iRhFp2yCXMGn4IUHYrH99QkcGP4pZ2poi1jQytkGxEdCStAijmO3jNXIJC12H7mwCy5ib2XYqwFYURVHmxOpF+b2Sm80Muj9/1+ajev1GEKNrMD/nIIF5LQmWdWUQwLcf6CeW4EeSm+7v59b1AyqbuKIoiqIcQxnHJOOYANTckIYfU3ZDkpY1vX1hW5r2lMXtk7sYrgY0fMlwsUKp7mEInY6MjRvEZByDrpxDoe4zUfWwdR2ZkOwYr5JNmNiWzsBknYRtUA8iunM67RmLtG1QbQSU3ICaF9KVdQiimIob0Ja25+RzHr9FcYqiKMrz3upFed5z/jJWL8qzsie71/5lnWn+9BVLec0pXXNyvShu1tLeVXR5YHuBW9YNsLI3t9cacC94ep24oiiKoijHlqZBZ9rB0jUytklXxubsxa2s6s2yoC2FoUHdD5FRTCQlk1WfWhBS8QOStoYERqsuPbkkPVkHISAMwYsgiGJqbkTJDRBAteEzVPAZq/jkkhamoaEJqPkB63dMMlhyEXOYUFwF2IqiKMoxkUmYe227/Mx5vGZlN3O38mk2P4zZMFjiqtXzZ22XQH5G2RBFURRFUY4dQ9dIOlpzKVncnBauEbOr2ODBbRMkbB3LNKj4IZZp0JN1yNomQQDtaZt5LQkylkk+ZWFoGmnHJGHrzGtxyCUtlnYkyadsyn5AJGPcMKJ/vMZopYGha7SmLbK2jakLEoZGzNxF2GqKuKIoinJMrFnahiaYtUb633+1hc/+fNNh3db2PJeg+UR8Ko/KtO89tJMbLlmFoYnpmpcCKNT9w7iqoiiKoihHytQ1NCEYKtaJYsnarTUGSzVaEgm8KCSM4eqz5rG2b4yRcgPDMCGMactY6JrBeC3gpM4knVmbk3uz2IZOxQuZl00y6QYsak+TS5jsnKzSH0rGah5p26RvvMqK7hyuL/HjiJ6WJBnHQIjjkORMCPGfwMXAqJRy1dS2VuC7wGJgO/AGKWVhzlqnKIqiPG+sXpSnJWEyWQ+mt9W86LDO9eLFeVqSFr/aPEYYxWhCcMOlq1jeneHL92zld33jVNzmucNIsmGwxHUvW8JXft03HZQ/NVI54s/0Qqb6BYqiKMqRiGUzGep41WWo7GKbgmIjZlE+QaXus7Pgsqw7y6r5rVTdAMc2mN+SwA0igiimJWmTT9rYhoZtaoSRBqbO4s4UKdukVPexLYPWjIUfSRa2pSjUPYjBNjVMIVi7dRRD1zn3hDba0jZSyiMOtg9livg3gNftse3vgbullCcCd0+9VhRFUZR9WtaZPuJzCGB9f5G7nhghimJevaKL777rJVxzzkIA7n1qbDq43m284vGN322fDq4l8KNHBnnb1+7nrV+7n5vv7z/idr0AfQPVL1AURVEOU2/OIYoFY1WfMIopNWJ2TdSoeiFL2pOs6MmysjdPVzbBit4c3dkEtqGjC52E2UyY5gYBk1Wf/vE6w2WXloRJym4uSSvUA5KmTme6uUa71PBJWjqmoRHFkqoXMFH30TTBA9sKbNxVYuOuIhMV74g+10GPYEsp7xVCLN5j86XAK6f+/ibwK+DvjqhFiqIoyvPWZWfO54HthzagKWhOBxeaQMbNJ8tRLKenlf/iyVHedd4JANO1t2eyDI32jL3XdoB7nxoH4NdT/94dpCvPTPULFEVRlCPRkrI5uTeDoUlGKz4ZR6fc0Cg1QiZrZV5qm7RlTMbKIaZh0ZV2yCZN0o6BrgliCcWaT6nh4QcxUsJQ0SVhGViGhmUIvKC5uvq0+Xmyjo6h62gCxqseW0YqDE42sDWdiapPV9ahLW2xeaTCWY6BbeqH9bmOdA12l5RyaOrvYWC/aWCFEO8E3gmwcKHqwCiKorwQFeo+Ag56zbWhwRtftJArzprPpuEKP9swxMqeLDfet216PXUcS9b2TbBpuMIdG4ebwTjNOpyvXN5Je8ZmVW8OTRPE0f6v/LMNQyrAPnIH1S9QfQJFURRFE4JTunPkExabRioICZIGsYypNAJ+8tggJ3TlOKE9QVvaQhcanVkHXRPUvICJqk8USxp+jBtEtCZtDEMwWXMZKnhomqQ1bZOyLbKOQaHuU2i4+EHISNmj6sckLYNGENORNZuBuanTCALCWHK4RbvmLMmZlFIKsf8E51LKrwBfATj77LPnMBG6oiiKcjSs21GYrmG9elF+r9eHY83SNmxTa44mS0hYOjV//+uw4xh6WxIA3HD7Rvww5sHtk1z3siXceN824lhimRpPjVT40SOD0+97zSldnL+8c/o9hiaI4wPfei5c1XNYn0nZtwP1C1SfQFEU5bmlWPcp1gMcU6ctZVFo+NS9iGzCJJ80p9ctu0FELCWOoaNpB17LbOgaXTkHy9TpbU0gJPzwoQF2lV0m6wGmJtGIeXy4SlvGxtB1urI2tqEzWvamRphj2jM2+dhCArqAB7dN4oUxSdvAjyVnL0rihTEVNyRtG4yWXYQGL1qSZ/t4nbRjsqwzBQhqbkhPziFpHd7oNRx5gD0ihOiRUg4JIXqA0SM8n6IoivIssG5HgWtvXIsfxliGxvUXr5wOVi1D46br1kwH2QcbeN98fz8/2zDE61Z2s228RmfWQQB3PD6y3/fENMtp7Z76HUsIwphMwuS773oJa/smyCctPvyjx2a9zw0iCnX/6fdEctaouQDe9YqlLGxL8bMNQ1y4qkeNXs8N1S9QFEV5nvHDmEK9uX654UeMVz0aQUTK0inUPJKWjmPqVNyA0XJz/XLC0ujJJWYlDAuieLof0fAj+sYqNPyYXMLAsQwsXXDqwhylzQFj5YjJRsBDYYF6GBIFEcu603RnbFrTDjPv6o5p0JW1kRKGyy66BrmEScOP8IKQWDZHywHCKAYpqLghbhDSlbE5dUGehW0pgigmiiWWrh1RorMjDbB/DLwd+PTUv287wvMpiqIozwJ7BrQ/2zA06/XavonpUe2ZgfjMwBueDr4rjYAv3du3x1VK6AeRavNXm0Zpz9gYukYUxZiGNh3Mr16U50M/fIw9B6cvXNXD8u5Ms75mGCMEzFyC/a5XLOXvL1oBqHXXc0z1CxRFUZ6HBM01z7tvt3Lq75mBaNULm9m5dY2aGxLFEkNv7vfDmF2FOlI2/+6frLGzWKPmRuQSBit7s8RoCCnozFrUPQsvithZqOKGIUEQExGTTzksDCVBFAMRlqHRlram2iEJoxhT19gyVkEXGj0tDlJKLEOjK2tTboTkk81Rd4QkloJ5+STQLB12mMuuZzmUMl3fppm4pF0IMQB8lOYN9HtCiD8BdgBvOPImKYqiKMfbmqVt08GpaWhcuKqHB7dPTr9es7QN2DsQ3x14w+xR8P3Zs2b1vtz5+AhCNBOdrZqX440vWjjrGt9/aOes43UBy7szrF6U56br1rC2b4LBYoNvP9A/9RQbMgnzML8ZZTfVL1AURXlhsAyN1pRFuRGQtnVakxaFuk/Nj8gnLZypqDRp6oxVmwnHLFND12aPXksg5RjUyy5jNZc4hOFSnR1jEi+UmLqgM+vQnUsyWgloT8d4YUhxwmMMl189Ocqi1gwrenOMlgJ0oZG2DYyp61TcAC+M6cwm8CPJCR0ZLF1Q90OyCYuU3cww7pg6xUaAqYtmWa45/r4OJYv4m/ez61Vz1BZFURTlWWJmcLp7tHh5d2avqeB7BuK7A2+YHXwfyc1L0nxSHkt4dKDEhsENQHPkeW3fxHSys91i2bz27jbsbtOt6wf22U7l8Kh+gaIoygtHNmGSnfFwuj3j0L7HMbmkhWloxBISpj5rdNucmrJWmxrl7kzbPFooYhkGlh6zq1gjZZvYuqC3NcUpvVmKVZtizSUCNBFTdWN+u2WMWhCiaxqL25PsKtZZ3pWlK+dMz6bThcFo2cP1I9pSFp1Zm2K9mRAt45jkpj5HEMXkkuYzrhU/VHOW5ExRFEV5ftk9BXt/r3dv2zMQ323P4Pt1K7tnJSI7XFEs+chtG1jenWHN0jY0wawp4hK4v2+Cf7lzM1Essc3m1PX9tVNRFEVRlLmRtPYdXlqGxvx8Ej+KsXSNefkktq6zs1BnrOLRmjRwTJ2CG+EFEa1pk3zCpNDIM1hoUPMjDE0jkiG7inUsQ9CTtSk1IhyjhmPpZBImu4oNtoyWQcbkkgZVP6J/sk7FjUjbOjU/ZH5LklzCRAiOaK31/qgAW1EURTkiMwPvPROe7Q5q81PTyS47o3dOguzdpbnySYt9zUD/9VPj0+vEvCDmB+sH6G1JqOBaURRFUY6RIIqpeSGmrpG0dAwNokgSRhFlN2R+awI/iohiST2ICGVExjaoeRG5lElb2mZpR4YVPS2M1D08L0bXDbpzDmEYMV73aUla5BIm5UaAlDBW8YgjKHshCcsgjiNGyx6OqTFWCXDDmIylU/IiNAGdGWd6ivtcUQG2oiiKMidmrrnWhOCGS1fRP1HjR4/sYrTiIaemih9KHez90TRBPmnxsw1D+9wv9/j7uw/2TwfiK7ozfPLyU1WgrSiKoihHSRxLhoouNT+gWA9oT1kMFhsMFRs0/IgTulPsnHTxo5gWx6DuBaDrmIYgkzDQEfhRjKPrmJZGp+5gZKC3JUlPJklH1sY2RHOEutCgNWUSRBGWrnFid5ZGGFH3Qua1JJAwPYpNDFulxAtCNuwsY5jwqhXdLGxNz9lUcRVgK4qiKHNibd8EXtBMYhJLyYd++NhegfThBtbzWhxetLiVDbtKbB2vEcWSj/1kI+84dzG/fmr86ePyCeblHNb1F4lmzBufOcr9xHCFN3z5t3zvXeeqIFtRFEVRjoIwloxWGmwdreKYOjvGqwwUGrSlLQYKdbwwxjQFptDYMVGjEcYsSNiEoSSOJS1pC1MXRFKysjfHWMVnSVuKC07pIZUwCaMIN4h5bFcRx9Dwg5iqFyI0KDYClndmmdeaxA0jxqoehqGzpKNZymvHRI1NwyXSpokbCzbuKtGStGlJWnPy2VWArSiKohy2mVPCK41gr5Hjg3XBKV20Z2zu3TTKQNGdte+acxZy5VnzufbGtbjB05GyH8b8rm+CP33FUjYOlVnZk+Ubv9vOULGBJkDXBHEs99mOKGZWqTG1NltRFEVRjowbRFTcoFnuSoOdEzUKtYAd45PsKtYAjfmtCUxTxzAkYSgQekTZjabu2TGdWYeVPS305BNM1jwe2jGJrmnkkxYL21K0ZW1qXshIxUMAk/WARW1Jal6IX485sTODY+okTJ3hikvWMbE0QcEMsHSN8arHso40W0bKoEPaNNCY6i9ISdVrlhdL2wbGwdQS3QcVYCuKoiiHZc8a2Mu7Mod1HiHglcs7ueachbz6c7+avQ+48qz50xnJ9/ToQIlNI5Xptd4zs5a/8cUL2DlZ574Z67F307VmErZnquOtKIqiKMr++X7EZMNHF4KqH2JognIjJGnqtKRs3DBi+0QZ14/xA0mh7rOkPcXS1hQ9LTZh0Lxv17yQ/kKNjKXjRhF+FFOq+9iaTtY2CGPozDkIIQgiiaFpOKZGR8bC9SKKtYClHWncMCJlNxOm6ULghRGhhAX5JJahYRkahq6xZmk72ydr2IbBsq4M2YRJqdGczq5rgrof0TN1vUOlAmxFURTlkHz6p0/wvxuH6czYs2pgd2UdoHTI55MSPvbjZumtHZP1Wfve9Yql0wGvZWjT15tpd/3tPbOWX3nWfIDp+t0IyCVMlnWk+bsLV7B6UZ4v/nLL9Dl3J0NTAbaiKIqiPLMoirhv6xhjFQ8/jlmST3BSTwtSRpiGoCVhsnEwAKHRlrbYVaghI0EsY4aKdRa1Jfj9eAXflTw1XkXEkLMtFnT6+GGE60cgYoqNiJN7csxvSQKQtHVKjYCaF7KkPU3S1Bkuu7SnbbwwJoxiNE3QnXOmandr5BLNclyWodPwQ168tIPzTu5u1sGeCqLdMMI2NNwgYlfBnU6Cdqhrs1WArSiKohy0933n4eks4Nsn6uga6IBpaLxyeSe/2DRKGB36Sms/knzi9scJZrxXE/Cald0Ae2Uk3zBY4pZ1A4RRjBDNhGf7Kxl2/cUr+dmGIS5c1cM15yycdd180mreWGVzKvn3H9rJFWfNV0G2oiiKohyAlJKRssuOsTqLO1NU3YBCI6TmRWgaWIZOLmnxkiWt9E9U2T5Rx42gI6OTdizGqj5bRxvYQmdbtUrCgCjW2DZRZuLh5tRu3dCwNIFh6LSnLKypKdu2oTM/nyCMJYYQxEDDj9g6VsXQNE7pzU4f15GZnSHcD5uj47vrck8H10FEpR4wWvEI45ieXIIgaq7rnln/+2CoAFtRFEU5KOt2FLhtjxJbHWmbt75kMWuWtrG2b4J4z+HlQ9AIolmvYwnv/94jvG5lN5mEyZqlbbzn/GWs21GgUPd5x7mLufG+bUSx5IbbN7K8O7NXre51OwrccPtG/DDmwe2T08fM3DczGVo0Vf5LBdiKoiiKsn/lRkDFDUGTPDlUIusYnLEwz/zWBLoQuGHznt6WcfiDk7upegF9oxX6J+oIIXBsnVxCJ5tKMlhy2Vl2qfkhEljYpmPqgoExlyWtaTpyBr/ZOsFgyaMjbdOVc7ANDU0T7Co1iCWMll06Mw6mLijWAzLO3kFxw48YqXhYuqDihszLJ7CNZgA+XvFIJ0x0TTBccWlJWXh79EsOlgqwFUVRlIOytm9ir7XMl50xj/ecv2z6tTWVyXMfpakPy/aJOl+6tw/B1Cj5SR384skRoqkLzKx1va/AeOa67N1TyXcfs+e67t3XWLO0bY5aryiKoijPT40gQhMaK3qzbB+rsqK7hVN6W6ZHhB1DJ2HpVL2QfNIiCGMWd6RxTJ32tM38fJJ6EFOsh2QcnfaMiVeICNGoeD4ZR8fzm+unS3Wf9rRNxtZ5dGeR08lR9yOEkIwUPRzLYPNwibaszYJ8ko60s882R1IiaI5sB1FIPKOzomkCKcE2DdpTNkEoSVgGafvQw2UVYD8Dz/O46667mD9/Pqeffvrxbo6iKMoxsa/M2vk9ylcInp7CDbOncT81UpmeSr7ne/zxnQTFQZyFp6NZ+74J7knSzBp+x+Mj+91faQR7bd9zXfbM4HnmPl0TXH32AjU9XHlGjz76KIODg5x33nkkk8nj3RxFUZSjKohiGn5ELCWOqeOYzRHfWML28SrjNZ+ebIJ0wsALY2xDI4oluibozjr4YUyLo5NyDLwgwg8lGceg6EV0pk3yVshD99zPzjhFqnMRQRDh6DqTNZ+ko+MGAUs6srSlbNwoRiCxTZ3xuk/G1qmHEQOlGsMVj4Fine3jVV5xUhcLo9T0NPDdnKkkZzU/xDE1bOPp/e1pm4mqh9BgcXsayzi8DOKgAuxndOmll/Lzn/8cgNe+9rXceOONzJ8//zi3SlEU5ej59E+f4Cu/7pt6kvt0Zu1C3d/r2M/ftZkLV/VQqPuzgvE7Ng7vdWxULzF5x79T3/QbAKzuZXS/7Z8RYu+bWEe6uT7rUGwcKu+1bX/rsp9pn6Lsy2233cZll10GQGdnJ1/4whe4+uqrj2+jFEVRjhIvjNgyUqFvvIamCZa2pVgyNQodxZKF7SmSjomlNQPx8YpHEMeEkcQ2NTrSNgOFOjsm68gYTuhMk0mYpB0TAXz3W//Ff/7LJymXSghN55Xv+1dOOf1sFren8ENJ1jFpBCG2rqFpGoWqz4ndaSRg6RqW1hwlrxVikqZGV3sKDUEUQxhJzNnLrzF0jd5cc+22qYtZGcItQ6OnJTEn35sKsA+gv7+fn//85/zN3/wNXV1dfPSjH+Wss87ixz/+MWvWrDnezVMURZlzN9/fz5fu7Zt+7QUxn79rM+979UnTI767p1VL4NdPjfPrp8aBZrmtly9r5zdbJ2atawbwx7YzesvHiWoFci+7FqHpFO/9L/yRPuzuZexpd3AtgJO7MyxoTfKrTaME0b7rWgNcuKpnn9v3XJd9sPsUZU833ngjixcv5t///d/52Mc+xhve8Ab++q//ms985jNo2uGPdiiKojwbuX7IWMVrrqkOIvrGqpiGxuK2FBnHoOaGDBcbFBs+CUOjNWMyWHDJOCZCQHvawQ8jsgmL4UqD4VKdnGOiyYh/vP5v+J9bbmb5mWu46A/fxq3//Pdsue925i8/nXzSZLzmM173AGiEEWt6c0xUXIRo1qie15Kg6oWMVVzySZO+MZ9gosYp83JkEwamvu/M35omsA4xK/ihUneDAxgYGADgVa96Fe9///t58MEHyWazvOpVr+Luu+8+zq1TFEWZez/bMDTrtQTue2qca29cC8A7zl3M/m5LUsK9T43vFVyny9sZufnvIY7ovvafaHnpm3GWnAVAVBk/YHsk8MRwhbseH+GMBS2cNj+31/Xn5RN86vJT98oQrihzbWBggFNPPZULL7yQ++67j3e/+9189rOf5R3veAdRdHjJcBRFUZ6tLF0HKYllTKHmE8WSOJaMV30Spk6x4eP6IZYGA8U620ZrPDlc4tGdBfrGqmwcmGC07PPUSBVb0+nMOGSMiL9619v5n1tu5g/f9h6u+9RXsRacRrZrAY3SOKGUxBLcMCafslnWkcE2dKquz1NjVbZN1PnVkyM8PlTGD2OSjsHpC/K8dmUX557UyctP6mR5dxZDP35hrhrBPgDDaH49vt8cSVmxYgW/+c1vePWrX83FF1/MT3/6U84///zj2URFUZQ5deGqnukR6d0kzQRhX75nK3c+MbLfEeR98Ya30P+dD6E7aTrf9A+YLVNrtqMQAKHpB3j302Lgge0FTF2gCdhdzUsDrnnxQhVcK8eEYRjTfQLTNPnCF75AZ2cnH/vYx5BS8vWvf12NZCuK8ryRsA1Wzm9hV6GGpemc0JUmYeq4QcgTw3Ue3VlgoNDA0gVCSMaCCAGUGz5jFZeebIJUIqRUj7EMwabBcT78F+/kkd/czcvf9je89A1vI2WbBFIShQFOMo1jaPRP1Ck0fFqSJhnHpDVtUfVCam5AGMN41SVpGaRsHVPTKXsBCV2nM2fTmrIPuW71XFMB9gH09vYCsGvXrultXV1d/OIXv+D888/nkksu4Re/+AUvetGLjlcTFUVR5tTuQPVnG4ZY2ZPlG7/bjh/EIAR3Pn5owXUwsZPR712PZqfouub/YWQ7p/eF1QkA9HTrft9vGxpeODsfeRhJXnNKF794cpQ4llimyvqtHDu9vb309T29hEIIwUc/+tHpf2ezWf7t3/5t1ro+RVGU57KOjENbyqbW0awRvatYRwLbxmoIASlTo1j36crYTNR9wjDG9WPcMGJnocbjwxVWzctxcleCj7//L3jkN3dzzV9+nO5zLma47LJxYJK6F9EojtEyfxmuH9EIJSnLYLzokXQMMo5BEEYMFV3QIIxipIwBwbwWh7xvYhga7enjH1yDCrAPqLe3l0QiwaZNm2Zt7+jo4I477uBlL3sZF110Eb/97W858cQTj1MrFUVR5tY15zw9IjxcdrntkcFDrm8dViYY+d71IDS63vTJWcE1QDjZfHBptOx73TSwV3ANU6W6lnfSkbGRwJUq67dyDJ144onccccdRFGErj89++IjH/kI5XKZz33uc3R3d/OhD33oOLZSURRlbmmaIGWbNCZrbBurE8mYyZpH1Q3RhaAz51D3Q0ZKdep+SMOPm1nDQ40oDhks1PnoVz/D/Xf9hDe++2857dVXUqwH1F2f8ZqPVy3hVQp0zluMYWjsmqySdgx0oREOSSw0FrSlkEB70qDhxViGjqlB2QvRhCBpGdMZzo83NY/pADRNY9WqVTz66KN77evt7Z3OLn7hhRcyNjZ2rJunKIpyVN18fz8/emTwkEatd6s+8r/EbpXOqz/GicuWsWeuEX90G3quC80++DJH81oczjupg4/9ZCPffqCfH6wfOIyWKcrhO/XUU3Fdl82bN8/aLoTgH//xH3nLW97Chz/8Yb71rW8dpxYqiqLMLSklXhhR9wPGyj5BHDJUcvGDiKRlcuK8LPNakmwbq7FtssGmkRqDpTqTFZeRUgMviBkf6mfdnbfy+mv/hA/93d9z9oIW0o7BSKmBlILyUD8ARttC0raOpgkqXohmaIzXfNb3TzJQqJGwdHIJm/ltaU6dl8MNYkxdkLJ0SnUfKQ+nxzL31Aj2Mzj77LP51re+tdfTamg+yf7JT37C+eefz2WXXcbdd9+N4xxcTVdFUZRnq3U7Cty6foD/3SPh2aHIvezNpE45D7NtPtsn6nvt9wY3YfecdEjnHCq57Cq606+DMGZt34QawVaOmbPPPhuA+++/nxUrVszap2kaX/va19i1axfveMc7WLhwIa94xSuORzMVRVHmhJSSkbJLqR6wq9hg81CJyYZPqRaQSxrYSB7cWieIInYV65TqATGSIAAvllRdn8Ei2GaSc//yS5y++lQe3jnBlokaW4YrjFR9NGLK/Y8D0Ll0BTIGTQhqXnPkuzfnIKWkWPU4qTNDEEOx7vPEcIWdkw0sU+PU3hwdWftZszxHjWA/g3PPPZdKpcJjjz22z/1r1qzhv/7rv/jtb3/Ldddd96x5cqIoinI41u0o8Oav/I6b7+9nshYc9nmE0DDb5u9zX1geJyqPYs87+ZDOuecsdV1X66+VY2vFihW0tLTwm9/8Zp/7Lcvi1ltvZenSpVxxxRVs3br1GLdQURRl7vhRTMUNGS43GCzV8MKQhKXTljIZKXk8OVTBCyLqbozreVQbMfUGRBrEYUg2lSRlCSwBen4e6wdK3P3kCA/1TbBtpEShHFKtx0xu20CyvRfPSLGr1KBY82hP2yRtne6WBCd1ZVje20I6YdOecsgmDIaKdea3OpgauGFEW8o+3l/XNBVgP4PzzjsPgF/+8pf7Pebqq6/mE5/4BDfddBOf+cxnjlXTFEVR5sS6HQW++MstrNtR4AfrB/CjZ35QGNVLh/1A0d3ZfGBpL1h1wOM0AfoBkpVctVqtv1aOLU3TePnLX37APkE+n+f2229HSskll1xCuVw+hi1UFEU5fEEUs6tQZ/t4ldFyg9GyS7Hus6vYoNKI0DSN0YJLI4qZ15ogaWhsHy+xcXs/bhijiWblkSgC0zCoNzz8WFLzoVLzcL2A7aNlto00GKpBHZj0I4p9j2HNW8lExSOUEX4s6cg4LG1LIRCc0JnjxYtb6W1JkE2adGeTxAgafkw+aZNPWVjGsyesVVPEn8GCBQs48cQTueuuu/jLv/zL/R73oQ99iI0bN/LBD36QVatWcfHFFx/DViqKojSt21Fgbd8Ea5a2sXpRfq/X+zr+2hvX4gUxmiYOKmgur/sJYWGQlpe/BazmGupDmZblbn8EzclgdS7Z534BnNyd4ZOXn8qm4QrX37aBKJaz1oJbuuDKs/Y9Qq4oR9OrXvUqfvKTn7B9+3YWL168z2OWLVvGLbfcwmte8xquvfZabrvtNlW+S1GUY84PY/woxtI1TF1QagQ0goicY5K09w4DJ6rNWteVRsAj/SUStqDuhQxM1GiEEi8Imd+aoOQFFCouTwxX2Hjn9wlLg/Se/xaCOIljgKUJGmGAG4IWQyMCKSCsB/ghuDNu6P7IVmKvhrHwdCYbEs/3WTYviSYgjOCCVZ2ctbCVhG2QtE1KjYAoY9GVcxireGQSJgtbk8+a6eGgAuyDcsEFF/D1r38dz/Ow7X1PPxBC8LWvfY3NmzdzzTXX7HN9lqIoytGye930LesGCKMYy9C4/uKV3HD7Rvyw+fqm69bsFWSv7ZvAC+LmE+eDyBTuj2yl+vBP6bj8Q2h2aq/9UsoD3uSklLjbH8ZZfMZ+a2BLYNtEDWhmNF/enWFt3wT5pMXGwZLKHq4cVxdccAEAP//5z3nXu9613+POP/98/vVf/5U///M/5yMf+Qj/8A//cKyaqCjKC5wfxpQbPuMVH9vUQEDWaQantqExXHFZYCQx9b0f/EWxpOqHaJpktOSzs1AjimIW5VM8PlTCi2Nytsm6bRMEo1spPPxT5l35IYSeQshmMF2OwPJASIkmBBLQZHOpl6mBjJ6+XmPbegCcRWcgAcOCcjVg41CR+VOZw4tuQMI2MHSNtvTTsdji9vTR/SIPkwqwD8KFF17IF7/4Re65557pG+u+JJNJfvSjH3H22Wdz2WWX8cADD5DL5Y5hSxVFeSGaOQq9O0QOwpifbRjCD2Niuf+EYGuWtqFrgvAgy3D5o9tJnPAizLb5eENPUdv4C4RuIkyb3Jo3IAzzgO8PRvuIqpMklq4+4HFu8HR7d/+jKM8GJ598MosWLeKnP/3pAQNsgHe/+9088sgjfOpTn+LMM8/kqquuOkatVBTlhSqKJcOlBlU3YKLms7AtRSQlXhChCTB1DS+MiWLJnlWt2tIWI6UIXWg4hsZvBotYhqDWiOnOhbSmbCqNgIfHagxO1Bnc9CT2khcR5OYztHN2n6B9qk/gTZ07iGYH1ru5Wx/C7j6RZKqFlAP1BjT8gEQ9IGXq3L9tAkvTyNgmaee5Ebqq+UoH4fzzzyeRSHD77bc/47ELFizglltuoa+vj7e85S3E8d51XBVFUebSzFFoaE6xNg2NC1f1YBkaumi+3ldCsNWL8lz3siUc7Mwqo6WLqF4CoPTbb2NkO7G6TiCqlyg/+EOAA04zr295ABDPGGADPDVSObhGKcoxJITgD//wD7nrrrtoNBrPeOwXvvAFXvKSl/BHf/RHbNiw4Ri1UlGUF6ogihmrNHh0oMBvnxrn4e2TyFjSkrIwdI2qH5J2DOx9rFkOwmbckrF1kNCddZBRTKneYO2WSfonKjy2s8jARB3TNBDpLoL99AkmH/whIU/3CXb3DGaGyFG9hDe4iY4VL6I1rSMjcAEiqHqwa9Ll4f5J7n5yhN9uGWfnRPWofW9zSQXYByGZTPLqV7+a22677aDWJ7785S/nX/7lX7j99tu54YYbjkELFUV5Ias0glnrkyXwRy9ZzDXnLOSm69bwVxcs3+f0cID3fedhvnRvHwebr8zuPZnYrTL8rb/FyHWRffHlJE9+Gcll5xBVJ4ADr8dubLkfu3c5euqZR6R/9Mgg7/vOw3ttn5mUTVGOh0suuYR6vc7dd9/9jMfats0tt9xCJpPh8ssvp1gsHv0GKorygqVJ2D7RoNQI0TTBSKVBLCWmrtGRtlmYT9KZaZYVnhnXTFY97t00wk8eGeC32ybYPFohYWqMlHwGKx6bRso8tqvIzkKdQtVjsBwgD9AnCKsTSPbuE4Qz/m5seQCQaIvOoViNaMkYOAIaNNdpF2s+IpJMVD36xyo8tL1A1X26wkkYxQyXGuycrFP3Qp4tVIB9kC677DL6+/t5+OG9O3v78p73vIe3v/3tfPzjHz+okW9FUZTDtXFo7yzFX/l1H5/+6RN7JTybGZi+7zsP86NHBp/x/N6uJ6hu/CW1x3+F0A3aLvq/2AtWUnnkf2lsexih6YSlEcLKxAHPE5ZH8Ye3kDhpzUF/th89MsjN9/cDzcD6gz98jDd/dS2fu2MT1964VgXZynFx/vnnk81m+eEPf3hQx/f29nLLLbewfft2NbtNUZSjStMEbSkLQxO0ZSxMQ2PrWJm+8WozG7gbUvMCdkzU2TFRxw0iwijm7ieHeHDHJE+NVugbLpOxDRpeQH+hxtBEg8lKyHg5YmTzRkYf+SWlI+wTANSf+h16tgM6l9IAitWQhoQYMIFaCMQB/RM1+iar1L2AyZqPlJKKG7BjokbVC9EEjJRd4oNc7na0PTcmsj8LXHLJJei6zq233spZZ531jMcLIfiP//gPHnvsMd7ylrfw0EMPsWzZsmPQUkVRXmguXNXDr58an7UtlvCle/sQgG02E5597CfNhGe6JnjpCW3cu8d79sUf72f0lhvIrrmS+qbf4O16ksTS1WTOvBAj087kXV8isfRs3G0P03HV9Qc8V33T7wBInviSQ/p8332wn+XdmX2uM9/XunJFOdosy+Liiy/mtttu48tf/jKG8czdqZe+9KV8/vOf58///M/5xCc+wUc/+tFj0FJFUV5oDEPjjIV5ql7IjvEq8/JJBBrVesD81iTFug9IXD+m7AUU6z5V12ft1gkq9ZBGENAIm/faQs2jWPcoTg0ay/F+Bm65gdY1V1I5wj5B7NVpbHuYzBkXTo9yl3ywYHrddsKCeghBHGBqEMTNPkyx3gy0J2s+EsmC1mYytD3DazeIkBIcUzumWcbnZARbCPGXQoiNQogNQohvCyGcuTjvs0l7ezuvfOUrueWWWw669msikeDWW29F13Uuv/xyarXaUW6loigvRNecs5BPXX4qyzrT7Hn7kIAfxPzb3Zvxp9ZWRbE8qOAamtO3Mi+6lNw5V9F9zWfQs+3Un1qLN7iZzFl/SNebPkXmzIvofMMNmC3dBzxXfdN9mB2LMVvnHdLn86cCaT/ce535nuvK1fTxZ4cXQr/gqquuYmJi4oA1sff07ne/m7e97W18/OMf56c//elRbJ2iKC9kLUmLPzytl9es7OHU+S3Ma0lQ8SMKNZ+66zMwUefRgQJbRqr8fmeRx4dKyCiiv1BloFinJ22TMDT6x6vUvKfjnvKWB+g451Laj7BPYACNrQ9AFJA8+aXT22OeDq6DqePCSJBL2LgRVNyQ9pRJ2Q1wTEFvS4IwktS9iI6Mja493Qsq1n2GSi7DpQYTVY9j6YgDbCHEPOC9wNlSylWADrzpSM/7bHT11VezefNmHn300b327a9Tt3jxYr797W+zceNG3vnOdx50cK4oinIorjlnIXf91Xn8w+WnYmhiVqAdA8Plw7u5GPke3L71BBMDCMMid85V2AtWUVl/O40dj2Jk2jBb52Fk2w94nrA8jrfrCVInv/yQ27Cr2OCRnUUMvZmwzdLF9PrymaPXu7Op754+fvP9/SrYPg5eKP2C173udaTTab73ve8Rx5KGHxFEzYdYfhgzWnEZKzemRlCa934hBF/60pc4/fTTufbaa+nr6zueH0FRlOcxIQQL21LomoZtGqQtnU0jZX6+YYS12yd4fLBEEEcEYcRQocFQyUdHkDQs6n7ApqEywxWPYMbSZiPfQ2nLesLSACnn0PsEAmg1pgLsJ+9DT7diz2uWNZ4ZlGqAI6A17ZBOGZyxoIW2hIUQ8Nst42wbq7BtvErDjzipK8PSjhQZZ3YVk6ob4pgaKceg6kX4YUTFDXCDfaQyn2NzNUXcABJCiABIAs+8qO856Morr+Q973kP3/3ud4nyi6bXNgJce+Na/DDG0ARXn72AK2bUaL3gggv4xCc+wYc//GHOOecc3vve9x7Pj6EoyvPcyt4slqHxyM4iQXRkD/VSy1+KP7SZxvZm/gmzbT7pledDFFJ/8j6chacixDM/q60/+WsAkie/7JDbUHZD7nx8BF3Aq1Z00ZGxZ/3G7rZ7lDuWzQDn+ts2EEu53xrgylH1vO8XJBIJLrnkEn7wgx/wgU/+E2gGQkBP1mG04hPGMaMlF82oM68lQXc2ga6J6dltq1ev5oorruB3v/sdiUTieH8cRVGeo+JY4kfNqdN71rUOopiqF4AEKSQbB0oMl11sXVDzQwSCYsNDhiGFWgMhBaGMeWK4DlLi+rOvtbtPMLHlYVpDEIfQJxA0g2YMMNwa9b6H6HrxH5LXNYoz0lLYu481YWlPmlN6skShQNMlSdPkqdEarWkTR9fpzMSEcYwbxCSs2TXHUrbRnA4vwDF1dhUbhJFEF9CbT+LsWaNsDh1xgC2l3CWE+CzQTzPp2x1SyjuOuGXPQu3t7bz4pefxH//5X9yivZwobnbcrjhr/vS6QD+S3Hx/P7euH5jVofvABz7AAw88wPvf/35Wr17NS1/60gNfTFEUZT/W7SjMSl6228339/PBHz52xOd3d24gmNhJVCuSOeN1pFb+AZV1PyZ2q9jdJ5I44WyEYU5lDT+4NU21J+7B6jrhoKeH24bAC2c/HIgk3PH4CADffXAnN1y6imvOWTi9P5+00IQAJJoQRLFs/i6rtdrH1AulXxDFkosuvYKbb76ZO+64k8svfT0NP2S86jNRbVD1Iop1n958AjeIcIOIlN3sdi1dupSbbrqJiy++mD/7sz/j61//+jFdH6goynNTHEuCOMbUNDRNIKVkpOLiBTEg6c4lpgPHMIx5bKCIrgm2jlbYOFRk50SdpK4TGFqzDrYmKdcDto/XqPg+XgCmAMcRSKmjEVHfuYH6PvoE44fYJ5A07+NVFypP/BYZhfSe+UqkARkJWVvgR5KEqaMbBi1Jg0tX9rK0N8d4xaPU8NGExo6JGlEs2ThUom+y2syM3pZiVW8LuaTV/OxRjBdGSKAlYaIJ2DZeI2HqxFLSkgyPaoA9F1PE88ClwBKgF0gJId6yj+PeKYR4SAjx0NjY2JFe9rhYt6PAtpYzKY0OUtmxcXqU5J5No3uVyNmdfGc3TdP45je/yeLFi7n66qsZHh4+5u1XFOW5Z8/lJ3tOg545/flnG4aO+HphZZyJn36e2K0SezUGv/4XhMVhsmuuBiGoPXEPQ//9fsoP/ojcuW86qKAgmNyFP7yF1CnnHVQbdMFewfVe7Ywl19+2Ydb3csPtG4niZnB98Wk907/LsWwG38qxcTD9gud6nyCMYoZKdc4695W05PPc+r3vUHV9hooN+saqPDlUoeqG1PyIOAKkmLU2EOCiiy7i+uuv55vf/CZf/vKXj88HURTlOSOKJUOlBkNFl6FSgzCK8aMYP4xJ2jqaJmaVqgrjmDCWICX9EzUcw6DuSwp1D9cPaUvblLyQmhcRiZgoAhlBI4BqQ1KsRtRK4wzPYZ/AByLA3XQvTmsP3ctOAw3CEGq+pCVpkc84vG5lN5eesYD2fBJD0+htSRCHMRUvIGEIxise1aA5Aj9Z9al5EZM1bzqLeKHu40cxjqlRccPmoKjeXD4XhDHmPmqAz6W5mCL+amCblHIMQAjxA+Bc4FszD5JSfgX4CsDZZ5/9nFiIPHOUCODzd23GPuEchGFTe/xXJOafQixhV9Gd9b79Jd9paWnh1ltvZc2aNbzxjW/krrvuwjRnrxdQFEXZ/duTT1rccHsz8/fuac4zp0HvmUV7X9nED1VjywPY81aQW3M1AIkTXkTh7q+SXP5SWl76ZgD8kT6EnXzGpGa71Tb+ChAkV7zioI4/2FntUSy5df0AqxflZyVBk1IyUfNpjmU3nyQX6v6BT6bMpWfsFzzX+gRSSor1gJoXYhmCehAzUmxgGhqvv/QKbvnut5GuS8OXCAGmLkCDk7vTWLpOZ9be52jJ9ddfzwMPPMB73/tezjzzTM4555zj8OkURXk2c4NmfodYSvxIkrJ16l6EF8Y4po4moOFHxLJZNWQ3xzKY15Jk00iZMIbWpE1vzsYLJZ1Zm6StMVn20DUghKrXTCwGEATNetWVo9An8CsTFPseZeEr34gGdGUTlNwQQcyq+VnKjZjhcoOkY1KoujiWzmTFo+xFpG2DnvYMwxWX5kpqSdmLsHSNYj0kabu0pWxiCZoQzVH+SOIYBq0piyiClqRJ0jq6hbTm4uz9wBohRJLmVLBXAQ/NwXmPq92jRH4YIwTE8VTqdytJ8sRzqD/5a3qu+AsGq7MXyusC3vTihftcHwhw2mmn8ZWvfIW3vvWtfOADH+Czn/3ssflAiqI8J8z87Zk5zXl3ML1maRuWoU0/gZ35IG/3dOnP37WJ0crhBZTOotPxhjYTlkbQs50kFp2O9aZ/YPR710MU0vKKt2J1LT3o80kpqT3+S5xFp2NkDpwI7VBJ4JZ1A1x51vy9vpcLV/Xw4PbJfX5PylH3vOsXNIKIybqPowt+u2UcL5D05pO4vscfvP5K/vsbX+O/vvd9Frz4tWRsA10XuEFI2k6zuD2Noe97tETTNL71rW+xevVqrrrqKtavX09HR8cx/nSKojxbuUHEUKmBEM2HyjIW00m6dK05M6Yrm6Dhh5iGNitwlFKyuD1Fd8ZEypjHB8skLZMFeQvH1pio+kSxhhSQSAgSlams3cDu6Gau+wTQXDKGjGk79TwafkDCMjixM02h6lP3BEnHYLzq4+hVkqZGoe5T8XYvsYkZqXqs6MkipKTgBqzsaSFtG+SSJq4fMyl88kmL0XJz6nx7yiLlGPTqSYKo+VBizxlFc20u1mDfL4S4BVhP87/Jw0w9lX4umzlKtGdRtSUvuZANT9zLlvW/JnnSubP2/Z+XL+XvL1pxwHO/5S1vYe3atXzuc59jzZo1XHXVVQc8fn/rLRVFef6Z9duDRJ9aY7U7SFy9KD89kr1maRubhit8/q7NCGDzSIUFrUlefUo3N9/ff1jX15M5QFB+4Ie0vOJtCDuJnszR/vq/ofLwoZcV8nY9TlgcJjf1pPtIdWdtFrYmeXB7obmeK2o+eHjP+ctmfS+rF+VZ3p1Rv53HwfOxXyBlc3ZaxYsYLbuYhsHmkRK9uSQvPmcN3fMX8d2bb+LNy15GLYhY2p7i9PmtRDGUagHZpDk9JTGOm6Pcu6dStra2cuutt3Luuefy5je/mZ///Ofo+v7XBvphjBuEWIZ+VNcQKopy/LlBhK4JHLM5ap3PmESxJGEZ0///twwNy7BwvZANu4pUGgGGoeH5AWnHoj3tcHJPjhO7Mvy+v8BAoUG56lNq+AyXXapeTGvKYcgIiMKnR7Fh7vsEANUNv8DqOYnJxHxkLaRFSubl01gtCdBiKg1J1tbxYhgrexQbPrmkTSQlOyYaZJMmDT9gYWeKl7ak6Mg47JiokbB0olgSRs0cWfNbk0gpp39rHfPY/WbOyfi4lPKjwEfn4lzPFrtHQ9wg3mtfpe0U9FSe6oZf7BVg/+dvtzNcdnlkZ5EzFrSQtA3GKx4S6JyR+faf//mfWbduHX/8x3/MypUrWbFixT4D6ZmjWSoTrqI8/+05Env9xSsp1P1ZvwurF+XZNFzh7279PVtGq7PeP1z2eHB7AUODKN7r+eB+7b4JaU6a1tf8KRM//Txjt32alpddi5HrJJjYidv/e2ToI4yDX89ce+xuhOns9Vt5OARw2RnzeM3Kbq69ce1eo9OrF+Vn/T7u+Vo5dp5v/YKEqZO0dEbLLmga+YTOUyMuGcfANtKcet5F3Hnzl3hsUx+p1i5iKTmxI0ux7vNQY4KUrdOetEk4erPkjZT05hP0tiSRwJlnnsm///u/8yd/8id8+MMf5kMfvYHRioehaXRkLIQQCAG6EPSNVQAwdZ3elsRemXMVRXn+SFg6hbpP3YswdUHGMdH2GH2NY8mOiRobB4uU6iFZx2Bt3wSOpROEEafMy7Kiu4XJmk8oYXl3lrs2jlD3Q9wgJgwDDBzyKcF4WRJKiI5Sn8Af6SMY207ra/4UaK73ToQxVT/i1N4sFS9irOKRSpiMl+roQtKStKnWqxi6gWE0M4BvGKzQlU2QsTw6Mg7tGZvxioemCbqy9vT1jlfyyKM7Af05bPco0d/d8ihbxmqzd2o6qZXnU37oNqJ6aerpTpMfxvzokWY1ku0T9b3O+92HdnLWghY2j1SpvvjdRI//X6644gr++ou38v/u2t7M6Kc/Xeprbd/EdJDvBjHv/94jvPMVJ8zKnKsoyvPHniPUe9Z5Xts3QaUR8KV7D1w/N4qb08fC+MAhtowjZBSimU/fkDTTpuPSv6O09hbKD90GQDi5i9bX/Okh3UjjwKX25K9JLn8pmnXkZYgk8KV7+1jYltrvd6QoR4OmCbpzCXKOQd2PGCjUWdCaIGka3PvUOEvWXAg3/QeP3fNjVlzwdlKOzmC5TqEakjDhsZ1Vhks+K+ZlSJkGuaTNrlKdjlSNUMLOiSr2ildy+Zvfyqc//WlS85fz8ldfRFfWZtOwRzbRzNdSrvtsHq4gNIkAJmtpTp3XgnOU1xMqinJ82IbO/HxyelR2ZnAdxRI3iCjXfXZM1Ci7AUNFl5IpMAzwg4iyGzBYcEmZFUZLHn4sEUJSDwKCMKbqetS8iLpfI44hjiOiMEQchT4BQPWxu0A3SK5oJj2NZDOpaanu8dRYGVMTdOeSdGYd2pImpmFg6jBQClmYs9A0mKz6dOYselscJmshfhSTcUxSljFrdtDxpH6RD2D1ojyfuep03vjl3xLuMZCdWvUqyg/8gNrGX5J90WUHfc4wkjywfSrrb6KV7EV/w5Pf/TDvffc7ab/07xGimaL+pvv7+f5DO7no1J5Z798+UZ8uw6OCbEV5ftrXyOvM2SwHQ8IzBtcAhbu/Slgcxuo6AaO1l+SJL0GzkwDk1lxFWB5HmDaxW8HM9x7S56hv/h3Sb5A+9VWH9L5n8rMNQ1xzzkIVWCvHXMI2OWNhnqSlY2gCKWGo7PKS01fw4+VnMvDgnZxwwVvw/YihQh03gEfGy+ycaBDKmGK1wfz2NCd15ZAYPDU6jo7G1rEKbWmLFZf+OY88/Aif+ru/4B++sYjFS5aiC422nEPa0tk6WqUlY7FpqEyx7hMjKDVCXnFSB5porsd8NnQuFUWZO6ausefMZiklI+UGDS9msFhn22gFw9SaAWgjQNcE5XqIbWroAjbsKuPYOqNlj4FJmKy6VBsBpZpPLmEQxKAbGmO//CqNiWHMo9AnkFFA7fFfkVx2DnoiAzSno8cxTNRcCnWPEzrSWKaGoWusnN9C1Y3QoTmbJ5tgvOqSsnUMXaPQ8MlaBsbUQ4c9R/aPp6Obo/w5bvdo0Q2Xnsq15yzE0p/+D2d1LMLqOYnq7+9EysNPgOosOo2W895OfdNvKD/ww1n7/EjOKvU103/e1zerdI+iKM9va/sm8ILm2uyDiJsPSuFX3yAsDpM//0/QnDTB2A6Kv7mZYGLn9DHCctATGYyWngOcad+qv78To6Ube8GqI2rnotbkrNcXrjr0tijKkYhjSbnhU3ED8kmTU+e30JFxsE1B2jZA1zjllZfQGN+FMbqF9kwCTdPIJATjFZ9KEBJGMdtKDbaPVXmgb5z12yYJwggvDEGAFBpWIsEb/+5zCM3gM3/9Th7uG2ag1GBgosHjg2UsQ8PUBHEsMYWg4UXsKtR4YqhEf6HOSNmdLlOjKMrzVxjFjJU9Bss1hst16mFIHEnmtTicvaiVFy1qY0lHimXtaSIJk/WAhhtiIPF8j0YQM1H3qUdQ9kO8KOKpn/4nfmGYk1//f45Kn6C+5QHiRpn0qa+ZtT2KoNSIkVJi6AZjJY+2tM2qeS10ZWzmtSc4oSuFZcCp83JcfPo85mWT2IZGPuWgPQsfKqoAez9m1pq94faNrOzNcfXZC3jx4qdHTNKnvYZgfAf+0OYjulb2xVeQPOlcivd8A7f/97P27e9/MlvGavusg6soyvNTPmkd9HrqgyUsh/QZr8NsX0DmrItJnnQump2iuvFXyCig0bcOt6+Z/PlQR8WCwhBe/+9JnfpqhDi0W03GeXpylSbgXeed0Cx7RLP80fLuzCGdT1GO1GjFZbziM1J2GSu71PwQSbMMztL2BC2OySWvvwQ7kWLH7/6HOI7ozSdwTIsFrSkMIYgRGLI5ypJNGuQTBnEYE8uYlKkThREJQ+e0k5fx1g98ltGdffzsS58gY+pEcUgQxXhhRM2LydiCUiNgtORS92IK9YC01Zy+7obRM34eRVGe2ypuyEjF5YldZTYOVgnCkJSj4wYhv+sb57b1/dy7aZTRsks+YZBPGgxXXLZN1hkouYyUXaquxA2gUoNGIMGwyZ72OuLsvDnvEwBUf38HeroNZ8mZs7Y3omZA2pIwAUlMTGvKQMYSTdOIwphCxSOIJIau4Ucx2ZTJ0o4MkZTTWdWfTVSAvR8zM/n6Ycz1t23g2w/0s66/OH1MasUrEIZN9fd3HNY1dn/5QgjaLnofRr6Xsdv+kbD8dB3bobK33/fHEvwg5vN3bebm+/vViLaiPI9tHCzN+TnNlm6K9/4X7sDjCMPEnncyiaVnE4xtwx14HCPfg7P07MM6d/WxO0FopFe9+pDfW3HD6b9jCf9y5yaiqVG5MGrWvlaUY0VKSc0LSTkGSctgou5TdUMSpsZ41eOJ4TqDBZd6bHD2H1zEk7+7kxYroupGDE7W0ESIbUIQheRTFlJC1Q0YcwNM0+DUea1cfFovV6xewFmLWknZOq+94DVc9Pb38tAv/oeff/8bVBoRBlCo+ZTqPoMln4RpkEmZ5FMmCVPDDyNqfkjNCxmvuAwVGzT8Z1/HU1GUI1d2QzKW0UxGRsxw2WfrWB1D1xgqNagFMTXP59GdBdYNFKl5ESnHIKY5DS5lmQQR6IBtQxiAnu1m7N7/otg3932CsDyGu+1h0qe+GqHNnu8eAGEE/ZMeg8U6S9vSOIbFIwMFSg2f3w8U2TRSoW+0xt2Pj/DkYIk4ihkuNRguuwxXXMLo4JbPHSsqwN6P3Zl8dcF0LdpYMt3JA9DsFMmTX07tiXuJ/cYhX0PSHI3RAM1O0nn5h5Chx9ht/w8ZBgd8ry6a//Fi4L6nxvngDx9TI9qK8jy1bkeB7z+085kPPESpU15J5qyLqT76c2pP3geA3XMiiWVraGx9EDPfi+6kD/m8Mo6oPXYXiaWrMbJHXvt6rOqzeyXO7trX6ndOOVaEEOQSJhU3oOFHtCRM4lgSxDEyitEFdOUcUgmLpS99PYHnsuk3P8fUJMVGQNmHBS1JTuzKcHJXhs6Mw8LWFCu7ciztTBNLyUjNx9A1FrQlSdoWCcvkz/7v+3nlBRfx7S/+P0rbH8WXMWNVjyeGytS8gIofsmW0jKFDR8qkUPdBwq5Cg/7JOrGUDJcbz7qOp6IoR8YPY6qez2TNY6joEgQSSzdImRpSQhDGjFc8ym5E1Quo1Dw27iqxebiMgaARxhTrHo4Bhgau10yMmjr5leTOupjJOe4TwFRyMxmTOu01e+3TAQSkHUhagpIXsn2iRM0P0UWzvYVawPbJKoWGy5axCv2FOsWqT2/OQReCqhfudd7jSQXY+7E7k+9fXbCcGy5dhamLfU7XTp/+WqTfoPbEvYd8DUlzNEZMLco32xfQdtH78Ac3MXn3lw/8ZiE4dX4OTTxdhiee+j/V/tZtK4ry3HTr+gGC6Oisq0ytPB9n4ak0tj7ExM+/QFAYovb7O9BTrYd9zsbWB4mqk6RPf+2ctXPmpw/V75xyjLWlbebnk8zLJ+jMOHhhxPbxGoahAZJNwyVKdZ+OJSfTvugk7vrRt1m3fZJCzWOk5DJY9snZNo5jsbwnS1vGoRGEFOsBEzWfBS0JGn5I/3iNtpTFiV0pYgnv+/g/0zN/EZ/+m3czOTJMNmFS93y8ELpbHOp+TM2NeGyogiYEXVkHyxD4YYypayDFnC8tURTl+CrWPWpe3Fw7HYZsHy8yMFlnvNqcXbOkPYVj68zLmgRxzGDRI4ojam7IrkKVYiPE0MGxIGE1B+saQF2Cs/J8zDnuE8g4ovroHTiLz8Rs6Z61zwEioBFDoQr1Rki5ETBeDTA0HTeKiGOJowtMTac36xBLSd2PiKSkFgTEUqI/ixKcgQqwZ1m3ozBrmvXqRXnec/4ylndnkDQ7eHv+97PnnYzZvpDqI/97WNeUzB4VTy1/KdlzrqT6yP8ecOp5FEtsQ2um7J/apglm1YRVFOW5b92OAresGzhqnWTNSpBaeT7Zc64AKSn99jvY808hd84Vh33OyiM/Q0+3kjjhRXPY0qdpmlC/c8pRV3UDdkzUGC65xBIcU8cxdbwoRgL5lE2h5pNNGFhGc51gTz7Jyy5+I0N9T9K/eQNBJLGMZpK0Ba0JMo5Jww8RQKEe0Za2SFgavh/j+QFSxkRSMlHx2FmoM1CXvPsTX6BWq/Hpv/lTtg2WyKdMFrenyDvN2tq9+SRuEONHMVUvRNc0MgmTmh+SS5rNQFtRlOeFhh9RqAcMF2vctXGIjbuKDBZdar5Podqg4Aas6G1h9eI2Vi1oJZ90SFkGtmUiJTQiiaVJLE0jkuAHMHMx6tHoEzS2rSeqjO310D2jN3NN6UBaA02DfMbktPktZJMm7WkbTUpSlsHS7gwZR8cLJCnLojuXYH4+SRBK8kmrmWzyWeTZ1ZrjaGYJHMvQuOm6NdMlYL58z9bp0aM9k3MKIUif/joKd38Fb3gLdveyI25Lyyvehj+8hYk7/gOzYwl2z4n7PG6y5nPFWfMRwMreHIW6r2rCKsrzzNq+iaM+xVNoOlb7Qtpe9xfIKEToh39rCEsjuH3ryZ37xr3WWR0pQbO29w2XrlK/c8pRFUYxo1WPhNFMGlRqaLSmmvVexysuGwdK+HHEluESuaTNkrYUG4OQ9pTDi1/1en76tc+ya+3tLL/qr8kYJgk9ouCGzG9J0jdeRQiNkufz6M4iScOgf6JOueHjBhEp26DSCPDDmJofEEUZLnjXR7j9Xz/And/8R65670fozjnNzqem0QgibF0nZxtoGnQmHVqSJrHkWTeqoyjKkQmiCEMTOKbBcNlFTE35LtdDHNOkMVbH1jVaHIOaK1jelWGs6jFe82hP27TFJiMVn1hIWhyLEj6iMXuW2Fz2CQCqj/wMLdVC8sRzprclaI6ge35z/bWhg2UJlnbm6MknKDZ8JmoexbqPaegEkSRjW6zqzZBPO6QtA9PQOSGXIJswj6h9R4MKsKfMTGo2c5r1D9YPcNfjIwd8b3rVH1C855tUH/kZ9uv+4ojbIjSd9kv+lqFv/iVjP/wUPW//F/RUC7ah4c2ogbt9sk7feA3L0LjirPmqw6koz0NrlrahaYL4KE0R39OR3kgrj/4chCB9+gUH/Z7ujM1o1Ttg+bELTuni9AUt6iGickwJAQKBlJKqG7Cr2GDrSJXxisuOiToV16N/skGpHpJxDNpTDt2ZXta8+mJ+d8dPaD3/j8FKkjQ1JmouCVNHSig1fMo1j12TdfIJndFygCagPWczUGhgaBqxlIxXXRKmRurENZz1h29l/f/8NytPO41Trnk7XhhzSm8LCdvA0QVSaBiiWRs27Rhq5FpRnoeiWPLkcInNu8rYuk4cRUgJURyRSej05BJkHZvOjEkUC/rGy4RIzl7cRktS57dbJ6fqRTsU6gG656MD+1vBfKR9grA8RmPrQ2TPuRKhPx0Ih0ClAZbx9HTq+TmH81d0c/aidiZrPqMVlzCX4PcDJWpeyJKOJG0Zh46sw7x8AoHAMp6dv3MqwJ6yO6mZH8YIIag0Aq69cS1eED/j1EzNSZNc8XJqj99D/vx3oNmpI26PnszRcfkHGbnpbxn78WfoeuMn2b1+f1lHitaUxYPbC0iefiCgOp2K8ty3bkeBH6wfYLTi0ZmxydgG4RwH11LGuNseJrF09dyeNwqo/v6OqeRmnQf9vrRjMFrdf8UEQxO867wT1G+ccswYukbeMRko1jF1nfa0xWjVa/YTohjTEuiGIJewsaKYihdy+sI8jq1RrPmccv6V3Hv795l49Bec+pqr2FXwSLgRkzWPIAiQloVEI4xDim6M0AQ1P8CqmdiaACEpegFBJGlJ6ASRxpqr/5SJ/s3c8oVP0r3kZF718nMZrXq8rDtL3Y8o1jwSloWIIJZq5bWiPB+MlV36xiuAYElbmsm6hx9CV0uCkzpTjNcD0o6BY2pUGwG7ooiaG6FraepeyK6ihxuG6KJZM/rsxW3sKjYYKtZpTVk4JoRxnaG+TWj5XvTE3JbBrD7yvyAlmTNeN2t7MPWPiKYeZOrQnnXYPl5jWWcWicTQNNw4ZkVvloYfkrZNENCaNLGNuZ0hN9dUgD1l03CF+fkkfaNVIim58b5tRLE86HWPmTMvovbYXVQ3/ILs6tfPSZvs7mW0vvbPmfiff6bwy/+k9VX/BwA3iFjXX5xumxAwWGywbkdBdUAV5Tls3Y4Cb/5qc6nK0VReewvFe/+Lzjd8gsQe9SiPRH3z74hrRTJn/uEhvW/LWK1Z53qqWoOgOV1NTQlXjpdS3efJkTJuEDG/JclEzScIY0xdkE+ZuKFNZzrCjyQnpS3KXkBHxmpWG0GSXXASbYtXsOXe2zjpDy4niiOiOKa/0KA7bVPzAmQcs7QjzWChTm9rkqoXEMSSE9qzSAH5coPJRoRjCFKOyXnLOljyiX/jH//sSv7z4+/ltO/8DCfTykSlQcWLmyM+VZdswkII6Mw4JJ9l6xIVRTl4fhizZbRK3Q8ZKfk8NVwmmzRx/Rjb0lnSnqbTD4hjyfbJBtsnanjFmBO7Ndb3BzRcD8e2iGLJ7wcncYMMjSBsViLwQyyhkU8Z9G0ZZvj7H8VesIrOKz48Z+2ffuh+wtkYua699huAlJCwoSPtINGouRHDxQbduQQndKQIY7CN5oyeMIpJ2gYJ69n/u/bsb+ExcPP9/Xzwh4/N2hZLidbs76FpIONmlr39sXtOwuo+kerDPyNz1sWHVYB9X9Kr/gB/eAuVh27D6l5GeuX5DBTdPdra/Azff2gn337nS1RHVFGeo9b2TRAc5eC6sfVBivf+N8lTzsNZfMacnruy/n8wcl04S8865PdGkeTN5yyczieRT1oqr4RyXLh+yGMDBQYKDYpuiB9GtKUdbENnvBrSkbFZNS+HG0SMlN3maHEMFT9ASJ2JikfDDzn5lZfzm298isGN68kvW00+aVHzInpbHPJJi50ll5M6U7QnTVqSFsLQOGtBHtMQDE166DrsnKgxVnXRNYEXxwgnzUf/9Wv85Vsv5dPvfxdf+NYPKTaiqbXbOoW6R85p1pL3wwYndqXnrD+iKMrx4UcxQRyRdZoJFQfrDdAFUSwwdZNNY2X6xmqkEzo1L2LraAVTaFi6RlANsXTBiV0ZEJIHdxSYqLp05xIIKRgaLbD5O58EIH/+n8xpu+ub1xLVCqTPvGivfTpgApkEZB2b1pSFAEqNkF2lOjEwL5/EeI4udVEBNvCzDUN7bdNEM4FZHEl0TeNjl65kw2CJLSMV+sZrlBsB/h7TNjNn/SETP/08Xv9jOItOO+jr77717W+0PH/+O/DHtjH5v/8fdtsCrO5lT49e83TiNT+S3Lp+QHVGFeU5as3SNsyppSpHQzC5i7GffBazcwltr/uLOe14+2Pb8QY20vLKP0aIQ78hxjR/A685Z+GctUlRDkcQS4JYknJM3DCm3AhJWAHzWpJ0ZmzqQUxb2sEPY9rTNo8PlnhypIoXh5w6rwXXl3RkHF78qot46PtfYPyhn/L6C19LsRFykqNTdEMQGn+wvBs3CFjWmSPl6AxMNh+ebx2p4QYRpi5Y0JpkXmuCbWM1RksuqYTFqhPO5P9+7J/47Af+gpv/7Qb+5uOfIZayWbLGi6haAaamUSbghI4Uuq4CbEV5LrIMjRM60qzfEcDUlOmaF5FP24AkdmCyFnDq/BxDxRr1RkCt4RFZOk4qiW7EVN0QXTfYVagxUfOQQmDoOiNlHz0KePBb/0R9tJ/Oqz6Kme+Z0/ZXHm4+dN9zOVrahJQBPa0JVvbm6c4lsA0dL2wmiuxIJ/DDGC+MnrMB9nOz1XPswlV7/w9qZW9ueop4FMVsGCwxryVBb0uC8aq/V3ANkDz55WhOhsr62/fal0sYe5X42m13CTCApLX3mgKhG3Re+vcks614//uP6F4Jjea6xBU9s9dKbNxVmi4zpijKc8vqRXm+/X/WcO05C3nx4jxz2S2OvTpjP/gkQtPpvOJDaKYzh2eHyvrbEYZF+rTXHPY5VBigPBukLJ2OTDOAbkkYnNiVZUE+RSOIqAcRCVOjUPMZKNTZOFjm0V1F3Cikf6LOA1vHqYch4w2foq9x5qsvY9MDv4LaJGcsbOGVK3q5+kWLOH1RCy1JA8PQ2TFZ5aEdk3TmLAr1AC+KWd6TxYslw1WPlG3QmrbpySfJJwxsHd71x2/n/77vr/jGjV/hx9+7CcvQyDgmPVmHbaNVHt05yfaxKkPFBsFRrkKgKMrR05lzeO2p3Vx+1gKWdKaZ35KgPW1TrIWMVz2qrgeaYElnGtvWSdgmlg6mEWNrGiYaSctkvOwTBRGOAGSMpsU8+YtvM7Hh17Se97Y5z8nij23H27mB9JkXTT90F4CjgW0KsgmLsxe3s2ZZB2uWtnP+im4uWNXLos40jqGTtJ/biRrVCDbNEZP+iRpf+XUfsWwGri9Z2samkQpBGKNrglvWDRBGMQfKG6KZNunTL6D8wA8Jy2NYuY7pEeaaFx0wQ+5udT/a5/ZkLs//+69v8ydXXkjyjs9xzUe/ymmL2vnIbbOntv9+oMS1N66dVWZMUZTnjtWL8tP/3735/n4+8qPHONIcZ1LGjN/+OYLJXXS98ZP7XAt1JGK3Sm3jL0muOA89kT3gsQK49IxeTuzKkE9afOwnG5trW6eqISjK8aZpGqvmtdCRsRkpupiGRkfaIZLNh+5p22DnZJ2UpdPwm4mFQtlcLzlc8TilN8tqJ88DWydY+sa3cf+P/5tH7rqFP/qnT1P1pspwuc1a2GEoySUsal6DockGA4UaOwsNHtwyQj2IyaUsdozXOKEjTXvGIowknZkEXbkE/z977x0eR3nu79/TZ/uuumTZcsW4gw226TX03o0T0hNyUn5pJwmdACEk+SYn9SQQQgq99xJ6xwZsirHBvclW1/Yy9f39sZKwLLnQcjDsfV2+bO28Mzsj78y+z/s8z+fzy19cxeLXX+fiH32PkWN34+AD9iVvuZQ8D8enT+CoCLLEiHigUipeocIuiiSVA9Ld6lU2p4v4qQKSBA2RAHlDoTps4vkCy/Yo2T6O66NKMoauIJU8erMlurMlirZDyfOREbgb3mDdv/9FeNKBhGef+qGfc3bR0EX3uAGaAk0xk3BAJWbouK7A0FTG14XxfEFIK1J0fWrDOvrHXMhse+y6SwMfMj85ZhJXnDQNVZbwheAfL63j4uOm8P0jJnL6XiNxvbKF147muf3iPtnXHqIuYgxkp31f8H6/22aPTnDjV+ay+9QZVB39HTYuW8xffnERSzen2XphektV8QoVKuzanD1nFLeduy9VwcEej/p7LPlMP38TxVULSRz21ffUvrKz5N58DOFYRGcdt8OxArj/zTbmjq3m7DmjuPmrc/nhkRO5+auVRcEKHy88H6qjBvGQTm/RJhoo90qrioyuyqRKNr6A3RqjSAImNESZ2VKF6wuyJY+R1SH2mjaRQ484ikfuuInO3vKifd5ySQQ1wqaGrikkQhq+5/PyuiSvrEtiO4LlnTne7siwZGOKBSs6WduZI1d0iAZVJElCV2SQFX579d8Z0TyS733tc2xsbSWoKdRFAoQMlVjQQJbLgXZFVLxChV0fXZVpjgeoi5qMTAQYUxMgGtQoWm7ZStB20VQJTVNIBHUsV+D6Pql8iYwF6ZJHwRF0bG5l0fU/w6wbQ8PR3/lQFt+2PILXt+gemnzwwKJ7QoORMZNR1SFmjKwiZpgkiw4F2yFdsrFcH1WRaa4OMaE+QjxkfOBz+r+kEmBvQbJg4wsx4IWdLNh885DxnDKzGV2VkaVydrulKjiwjwzsPToxUP6txuoITJhD7o1/s7knUxZJk0DXZL5+wFiUbdWJb4fXN6aAsgCSufuBROecSnrxQyz6921l5d2tkCSJRFB/P7+CChUqfAzJlJxBPw/XorIt8u88T/rFWwhNO5zIzB0HwO8V4XtkX3sQo3kyev24ndrH88XAIuCslgTfPGR8Jbiu8LFDksqCp2KY6DRqahRKPp4QjK8Jc/jkBvYeXcWE+gjTmuJMaYpxyO71KJLC4ad9nmRvN48+eA+KBGFdZXRtiMkjYuwzporaiEFA16gJqdi+T0e2QL7k4To+qYJNxnFZ1paiNVkgETDIWW65N1GWqKut4c//uIVCLs9/fWE+puxRE9FB+BRsm2TBpWB7Q54hFSpU2DXJlRxcT2AaCu90ZFnVnmdzpkS24FIfC1ATMgjqKr4A1xcEVQ3LLccrBRsKmRytt12OpGiMPP1ClPfYLratKKb/KakAuTceRbgWkT5XpZgMExsjjKkN0xQLknNcamIGjTETRVaIGtonbhGwEmBvQb8XtiKBpsrMHVs94Ek7fUQMifLEsC1TQlcklL7Aebf6wX3Q0VnH4xcz5Jc9gwSMqgpy8XFTGFUdYkx1cNg+6+1he4If3/EGiaCOrspUH3QOoXF78fjfruKs5tyQ8Z4vuOyBpZVe7AoVPgHcubh1SKXKzmJ3rKHnof/BaNqd6iO++ZGUiBZXv4qbaicy64Sd3keRJeaOrf7Qz6VChQ+TmrCBKssIUba8cn1BV7bEhu4cq7pySJKgPmIiSTKjqkM0xAKMqgrSUhPC1GUyloOpy+y174GMnTCR667+X3ryFnnHpTtns7w9wyvre9ncW6QxbuIKgQTkig6qKqMpEnlb4Nge7ZkSL6zq4oXlXby6rpfVnTkyJYf6iMkBc/bk7//8F68tXsSFP/g2maJDQNOQFZmAJlEfMUgVKgF2hQq7Oo7ns7anQG/Ooj5iEtRUogGVXMmjLVOkMWwyrj7CXi1xpjbHMBQZUwNDV6gNScjCo/3+X2Gn2mk+6TyUcB2w8/on/Raa20IHNN8j/9oDBEdNJdYwhipTojoqY3k+thAEDY29WmqYOSrBlOY4tVGDmoiOqX2yQtJKD/YWzGopl2IvWNMzMPkbzpPW83zOmj2KpnhgYNydi1uxHB8BGCOnodWOJrvoPsLTP8O6ngIX37uEDyIMvKorz0X3LOGrB4wlEtCY+tk7+PrpR/Lrn5xLYt6v0OINA2MFYDvlMvFKVqhChV2PReuTLFjTQyKoc8ei1h22pgyHl0/RedflyEaY2pMvQFK1He/0PsguuhclUkNwt312arwiS1xe8bWusAugKTKNscDAz63JArbj0ZoqUHQ8LNvHdgSNiQAhXUXeokJNVxXa00USQR0h4Kgzv8j/XvETnn72ec487jMULIcXV3fRm7XoLTg0Rg2Cusb+E2rpylr0pPMoisbarixpywZkXFewOVekUfZ5fUMvHekiY+oitFQFOf3Uk7niiiu48MILMetG87mvf5uVHRlWduTQZJlRNaH/g99ghQoVPgxcr9xakrMcTFXG9WQ2JgtkHRdPgOu66JpCPGKwW8ggXfRY253BFgLPh7qoTqZg0/vsP8ivWUTjUd8kPHIqigp5d8ftr/3saFxAA3fNKzjpTiadeC5VYQ1JBkkSuH75OjRFoTFu4gPxoM5uEYOwqWK5Pqa26/Zcb00lwN6KLQWG/vTUqmE9aRVZ4pSZzYMmiBcfN4VbX9nAsrYMng/RvU6g5+HfU9rwJoGWGR8ouO7HE3DNc2u44qRpHDRtFPfddx8zZu5F111X0DD/l8jGu6XrPlTKxCtU2AVZtD7J/GvLC3uyJOHtjDriVgjXoeueK/ELGern/wIl/NEEs3bnWkrr3yR+0BeQ5O1/MWqKxOl7jeTUPiGzPz21quJxXWGXQQiB6wk0VaZkC1yvLEDmCZ+6iIEsS1iu1zcWurIWRaf8s+f7HHXSadzwx19w9/XXsPecucRDJis680i+T1euhCcEkxtjFB2XRECnOR5EU2RCpsrSTSk8X9CbL7FoVReBgIrrSbRUBzl4cgOSEIypi3D++efzyuLX+dcfriLeNJbYxNmMrVXpzts0JoLbu7wKFSp8TBFC0J4p4Xo+maKDAGIBnXTRYXQiyEZypPIyu1UF0FSZWEinI59DUSSmjoixrjOP6cmseeFhOhfcTdWsYwnPOBoLCErgfojnmnWg6/m7CVU3MOegw0kWBe3pAhFTZ2TCRFYUWqoDREydEYlyOXvedtmcKiEExIMa1eFdu/e6n0qAvR2G86SVgIMn1nH1M6tZ05VjbG2YgyfWcdkDS7FdH1WWmDYixuvOwSSf+SfZV+8l0DLjQzsnX8AF9yxhYkOEWbvtxnev/DNXffccuh/4NbWnXDBICj9ZsD+0961QocJ/hi2rYUCgyNKAZeDOIISg59H/xWpdRs0JP8JoGP+RnWvm1XuRVIPwjCO3O2726AQ/PnoSs1oSgxYQdFWuOB5U2CUoa5to9BZsDE1ClVVipkrY0GhLF1m4tgfH8ZjanEBXZWIBjXhQI2951IQNik6ck+edw/VX/4FkRyvhUaPxXY+VHVnyjoeCxMZkWZm8KqSxR0sCy/ExVZlcyWFVZxYbQWfOopS2qArIJPMWjudTKDo0xAIEDJXrrruOfZev4JrLvse5/+8GqkbvhaHKlLbhUFKhQoWPN0XbI11wiAZUEkGdUt93Z0PUJFdy6Ml5yIqM58HmZBHbcXm7LUu25CKArOWy6e3XWHbn7wi1TCdxyFfpjw7SH3LnSKl9FbkNS5l04jcIGCajawxWmxq271FyJcbEdaaMjDNtRAy9L1udybgYmowiSaSLDlUhHakvuSDBoMqgXYlPVsH7h0y/J+0Rk+tRpLJYmaZIPPFOJ48u62BVV55Hl3Vw4T1LsN2yyrjnC6aMiGEYBpE9jqG46hWc3k07fK+qoMbkrTytt4UQ8J2bF7NofZLmaXOoOuyrFFctJPXMvwbGKEqlx7FChV2NReuTg0rCVVnishOncvacUag7+bTOvnI3+SWPEdv3LEKTDvzIztXLJ8kve5rQtMNQAtt/dh00sW4giF6wpmfgednveLBofZI/PbWqohtR4WNNLKgzMhFk7zHVjKsPEw3qhE2NN1tTWI6HA6zuyJItOX3iaGCoCjVhgxHxAKfM+xKSLHPH9X9jU7JIfdQkGtSoChlETZVkoWwJlnM8iiWPqKlRGzeZUB+jLmIQM3U8ASUH0kWfbNEhV3JYtjnLm61pCpaLYZpcd+PtBENBrr/8W6zftJmc7VAb+WRkhSpU+DTheD4d2RIF22NNV4687dEYM2iKmuiqRLbkEtBkXNejPZXnlfXdPP1OJ53pEjI+Bctj08b1LPjbxWjRGqpOPA+hfHS51cwr9yDpAdSJB/NOWy8bkgUKrocsKURNld0b4wRVpWx52KdqFtBkSo5HwfYwNRlJkkgXbDb25mlNFig5u+biYCWDvQNmtSS45py9BnoiN6eK3Lhww6Ax/d7ZEuXysVNnNnPqzGb+NT7IHxbeQebVe6k+4r+2+z69BYfe9yBCsilV4oy/vMislgSxvY7D7l5PZuEdaDUjCU89jEO2mNBWqFBh12DBmh7cPkUzCTh9r5FMbIjw1uY08ZBOd3b7VSmFVS+TfOrvBCfuR2z/sz/Sc80ufgg8l+gOxM0UWWJTqsii9UlmtSQGxCT7va8TQX1QRvvi46aQLNiV8vEKH0tURUZVZEYmQvhC4AmB74GpKeQsl5Ln0xQPDjiL1EXLgW08qHP47Mmccspp3HfbjRw8778o2hJN8QBhXcMXPrmSR8jU2NxTpKtgU/AEYV2lIW4ypjpEa8oqu5JI4HmADPmSh+265IoOT77TQSKs0uMFOffyP/PL75zNny/6FtfcdDdVn5CyywoVPk24fY4ho6oCJIsKYV1hfU+R3ryFIgnWJwskCxYr2rPYnkd3toSqyJiqgu1qjIrJ3POPi/Fdh9Hzf44TiPBRhatutpv8O88R2fNYLClEb94nFnRQZInGuEFN2ERTJHK2S1umREhXqYsYVIUMdFXGFxAxNXxf0FuwCeoKridoSxdpiAYGgu9dhUqAvZP092YvWp/ktlc34mxhk6MpEj89YerApBDKE+VzDp3BnbMOZ/PiJ4gf8NkBL7gPC0/Ay+uSgETV4efiJjfT88gfUOMN1M0Z9aG+V4UKFT565o6tRpbKi3aKDBFD5Yy/vMjOuHLZXevovv9X6PVjqT72ewPtIh8FvmORfe1BAuNno1U3b3esEIJbXt7AHYtaOWi3WuoixqAgesuMtu36XHzvW/hCVMrHK3yskWUJGQkV2LMlzqvre4mYKnuPrqYpHhhkydmWKpK3XWzX54BTzuH2227h/tuuZ/YJ59ASDdEQDZC3XToyJTzHoTZu0BwPsCFZwvFk6qIG9ZMakBGs6MrTnimRKTmoMvTkS8TzKpvTeVRVpitXQhYSk2fsyTcu+hW/v+g7/PqS/2buTdejKDv3THA8H4nyYkKFChX+79AUic3pIpt7i0QCKvFAuU3F8wQygp6shUzZH7s9VSRbcvGFIGJoxAIKd/76fLLt6xkz7xJIjPzIzlMHsoseACFI7HUCLlCwfVp7szTEQ5Qsh3AiiOX7ICCkKfTmbHxfoCsyAUMhqJdDUiEEmizjeIJNvQUEAuELEiGDRGjX0ZaqBNjvkVktCW752j6DerC/ftC4gUngYIEi8Kcei3j5YXKvPUxs3zO3edwtpe+9XJL0gtvw7RKBMTMJTtx3hwJCkqJSc+J5tN/wQ7ru+hm1p+y709fUn52vZIwqVPjo2d79trw9OyCI6Prwl2fX7NQxvVySzjt+iqwHqD31IuSd8LUUwqe46hWKq18GWSE6+5RBbgTbI7/0KfxihujeJ+1wbL9Gm+36PLasAyhPBm7+6lwANqWKqIqM5/kgSbh9O/SXj1eeSRU+7rTUhGmpCQ953fV8bM8nb7tETI31PTmCDbsxdtpePH/PDZzz5W/QUBXC8wVNWpDaiE5n2mZcfZCIqREPeNRHDVzP54Z//oPnnn0aN9pI1fTDkQghCYGiAJQVhavDJqlCCVNVGFUVYvd9j2D+ud/nxr/8hp9Nm8JFF5y3wwxQumDTk7dBgrqwQdj8aNwHKlSoUCZVsEkVHExNoTZiDFqcKzoeluUxuibI5mSRBW1ZArqMoZT/xEMquZKDQBDQZTyh4/ouVWGDVQ/8lZWLnmPf+d/HnjCTnuKOz8VJtpF/6wncVDuBcXsTmnzQDveRgKBcZO2bj1A9eR9itQ2UnHJbrdG3SOf4kHM8JgR0nl/ZxVPvdKArCpOaojTGAyRCBnVhg7zt4Xg+EUOlK2fRk7OpiRpYfSrqlQD7E05/2fjWLFqf5LePrxgQKPIF6LWjMcfMIrP4fqKzT0FWtSFiRbNHJzhoYh3ZosO1z6+l/d6rsNqWI+tB8kseQ61qpvrI/8IcNX2756UEItSdejHt1/+AH33tbDYUbmb+AZO3O0GtCA5VqPCfY8v7TVVkTptVbifpv+cefqvtPR/Tdyw677oCv5ih/uxfoEZqdriP1baS3n//EbtjNbIZRrgO1oYlNH75f3c4ARfCJ/PKPegN4zFGTnvP5wvl4PnOxa3ctbh1QBzysEn1PPlOx8AYRZErOhIVdlmKtkd7plhW/u3ryV7XlaUzW2L/kz/Pvy77NouffogvnPM5MiWHVMGhPVvCExALmDTEglSHTTIlh9tuvZX/d/EPqKlroKerA/m+f9C4/ynU7XMmRUkiHlTI2y7rN/RSsFx8D7qyNomgxuxTv8rKlSu45KILSDSO4rNnnbHdSWqy4BDSFXxR/nclwK5Q4cOh5HgIwaBSZ9v1SeYdTE2iK1ukJ1eiOREk1ucCJAGSLCH3CYCFDBkZmVTJYXJ9lARG2W1ESKySs2xO5nGExspn7mHJ/Tcwev+Tie95NO1ZCw2wtnFuvl0k9ez1ZBc/AIASjJFf9jRyMEZg9B7bvS4ZKC17EreY4+DTvkQubNCds8r+3LZHrSQRNnQ8r7yYsKYji6JKVIcM1nflaYwHkAV05kqokoymSiSLNroiM7ouSMHy6M6Wfb93JSr1Px8S/RPn51d2Dwmgo7NPxs+nyC19Ck2RBq1OAby2McXcsdX85JhJ/PWsSVitS6ne/2yav30DtSefD75Lx83nk3z67wh/+90TWtUIak8+HyfZxp8v/CZnX/38IOGgrcWEhhMcqlChwgdny3ut/9/9AWV/OfRNCzcw75qXBu7Ho6c2bveYW4e+Qvj0PPg/2G0rqDnuhztUDBfCJ/3SbbRf/wO8fJLq435A87dvJHHYV3F6NuImN+/wuoqrXsHtbSW698nvux9KU2UkGCQOWXS8gWy3BJw2q7my2FdhlyVVtFEkibztkSu4lBwf1xVossTusw+iYdRYHrjpWhrjASY2RHE8n4CqsVtDFMvzyZZswqZKUyzAa88/TkNjE4+/spSf3fQ4o2cdTOvTt7Dsuu+T6mhlc9IiU3JQkNBkGU0t9zLmbY9UweXU713OmMl78KNvfY1nXnhxwE4MypP+tnSRnpxVLtdUJUqOT8n10HdWWbFChQrbJV2w2ZQq0pYu/2lPl9icKg4IeBUdj3Xd5W1vtqbJFMuaTGFTY1xtuFwWLgSGIuMJaIqZjKgK0lIdYmRVkN3rIkwfEWVSU5xAx1Leuuv3VE/cm7HHnUvOcvBhm0Kpfsdq2v/xHbKL7ic840hGfOPvjDj3OiTNoLjq5R1eW0DyWPfsXdSOn860PWayx8gqJjfHqQmX9SWChsKEujBBXWVZW46s7RE1VAxdpeh6eEIgKRKmqiBJIPfNK4KGiqkqhHSNlprQLpW9hg8pgy1JUhy4FphKudL5S0KIlz6MY+8q9Aeqw7VKmi0z0OrG4r1xH6d95Ut05Z2BUkkAxxP84uG3OWhiHSO1HAAjm0fQ0BjjbWlfzDEzST55LZmFd2J3raPutEu3O7E1R02n+qhv0fPQb9n84B+5Y6+WbdrjbC04VMkYVajwwdk6U40QuL5AlSVUZbD1n+0J7lzcyqyWBGf3aSdc9/waVnXlhxxX6usl6X/OpJ79F4XlzxM/+EsEd9tnu+fkWwW6H/g1xVULCe5+ANVHfhPZLJe1KpHyfe+Xcju8tszCO1GidQR3338nfhNDOWJyPV8/aBxQtiTrf/YcPbWRV9b1Dvzc75ddYdfk0z4vMFWFZN4mX3SJhXRKtkNbzgZZIqhonPmlb/C7S/+b+x56lMMPP4xJjVHW9uZ4Z3MWIXzG1oYZ7fqEdJXunl4SNXXlBalgNXt9/nyUCfuy+s7fsPLa7zH6+7+jNjadZMGmZLsUbA/fSxELa6hKWRhp/gW/5W8/Pocvzz+TZ59/gUkTyvdgR6ZUtsdxHCRJoi5iki46yJJENFDJXleo8H7o/47vX6TqLdhIkiCkqWxIFqkJ6eiqTLJgkwiqrE8W8PGpj4ZJFx0yRWfg/htVHUIANSGDV9Z1ky96TGmqQZYloqZGT66EoSk0xIK8tXQpD/3+fAK1oxhz+k+QZJmALONh47gCxWeQyFlu6VP0Pvx71GCUEfOuRB31blWaHIjhW9ueEwQAVQVv1QtYyQ72+/J/EzB0mmIKQoZcyaMuGsDzPWzfZ86YBCu6ckRNhaWbM9RFgxwxuZ7xDVFURUZQfh6VHJ+asEHYUAkZCkJAQFN2KYEz+PBKxH8HPCKEOE2SJB0IfkjH3WUoixNJ+GJoiC1JEo37n86Gu37B326+k9ju+6DKsMUcm5fXJcuCZaUMAOvbuki2ZwGQNZPqI7+FMWIykrxzH7LwtMNxkm1kXrqVv/6hkZ7ct+jIlIZkq795yHhu/MrcSg92hQofgK37qu/awst6y2Da8wVnzh7J0k1p3mhND7wubXGcZMFmzthqVnflhyzY+Vu8kH3j32QW3EF4j6OIzj55u+fnZrrovP1SnJ6NJA77GpFZxw96jvQH1rIR2u5xrE1vY21aRuKwr+5QFyIeUEkV3YGfg7rC9V+eM+gZs/WzZ2JDpPIs+uTwqZ4XxIManiewXJ+oqbKio0hLVYCujE266HD8SWdw459+xXV//h37HnQw9VGT1mSeREjDUCV6izbjZIkV7WnyQiOVSvH6hl5yJY+AoROdsBcTv/Zbul+6i2xoBCFVplv2cHwPTZZQFImQrtIYDRIOqhSsAD/8f9dxyVdP4YQTjufvtz8IZhiEz9jaCJos4XjlBcHqiuJ4hQrvm1TBJpkv90VXh3RURSKVd0iXbDRZJl9yEL7ocxiQiAZ0JhkaJcsrixcqMmGj/P2aLTnkLJdcqZyFntoUJ11yydsuXk6gKnK5hLpgk+3p4vrLv4WkGUz7/OWEIhEMFUxVRZIlcoXSwFxDCEHquRvIvHQrNeOn03Laj5HMGJYDWRd8IfBLOWQjhAmUtrpGHRhVa6AqCs//4y5CdSOpm7Qvq9ozjK0PUx0K4IsCIV0laCjURwNIkowsKVSFdGaPq2HPkQlaasOY+ruhaHNi8NdEUN91O5k/8JlLkhQDDgS+ACCEsIHte8l8ApnVkuCyE6eWFXB9AdLgyTDj9kWN1pFecCfhCXM4a/Yonni7g/bM4I4IYUSQ9CBOcmgvZnjqoTt1LsIt//rjB3wWN9VOz9P/5O5o/YAnriwxKFvdr5BeoUKF987WlSEXHzeF21/dOGw1iyJLAzZ+8/66YCBbe8rM5sFZb1lCliU8f7ijQHHtYnr//SfMMTOp+sw3trvoZneto/O2i/HtEnWn/5TAmD2HjCmXhkso0drtXmt64Z3IZoTw9CO2Ow5A3aoebcuy1H62fvZUnkWfDCrzgvLCek3UIBxQKTkeLZ7PW60pHE8wqjpIdTzE/C9/nd9ddTnvLHmTloPmUhsNIMsKwvfJpoq0pYok8zaxuhF0PP1vUvkChq4Q0hSaokHSejM1J32boKGwtjuHLCvYDkgyGLqMLEkEgwph2QdFUD1+N77/i7/ws2+fwze+/Dm+9fO/EjAMJFmmpTpEPFjJWFeo8EFJFRyCuowvBJ25EgFVoTaiUxXSWd2do6UmRKbksjlVYnpzDFmW0GWJGSPjpIplsbNYQKPkeHRlLUy1fC9LgCzLVAd1MkWHurBBzNToSBfRfZtf/fDL2Lk0+3/nd6QDCYoZm6YqA8v1EL6LK/UFfZ5L+8O/I7/0KcLTjyB8xH+RFCphB2rCGmNNlQ3t7WywC4QT9ShACEAG2wdDgYgBY6ojmD1v8dCmVRzy5fNxhI/tQnemhKbIKLKEoUiMqw6hIEgWbEbVBFjXlWdsVZDGuEmm5BIJ7Fql3zvLh9FgMwboAv4uSdJrkiRdK0nS9tMgn1DOnjOKW7++Dz84ciJfO2DsoG2SrBCdfTLWpmU4m9/mlJnNfOew3YYcQ5Ik9Pqx2O2r3vP7+06JwvIX6bj9UrruvYrC289Sc8x3MZqn0P3gbyhtfAuAxnigImZWocIHYMv+6i11DCzH59ZXNgwoYW/NpMayVd+slgQ3f3UuPzxy4oCadr9AYn9P8qxR8SE91wBe91q67vk5Ws0oak/8yXYzyaXWpbTf+GNAouGzvxw2uAaw21ehVTcja9vOXDndGymuXEBk5rHIemCb4/rZ2rPb9+Evz6zmc39byE0LN+xw/wq7NJV5AeUsEZSzMI2xALURg9F1QWrCOh1ZixPO+jyhcIR/Xv17VEVmTHWIkK4gyzJzxySoDhlUhw0mT5uO77msW7mcoK4SMlWaqgJ4vo/leeQsn45MgYD2rrZBR9YmJJVYteBJLv//vsA1P/0299x1GxP3nMvXzr+KZa++yA2/vgjLdXEcn+Z4EEPdflVKhQoVhqfkeOQtF9fzMTWZvOWysTdPd8aip2CTs8r2WVFTJairNMVM6iIGvhBlFXAhMHWVhliAeFAnZ7m0p0vkrXJG29AVRiaCVAU1fOGTc1yefKe9vLAmPK744TdYu3wpB37tp9SMnojnAxJkixbZkktv0aXkgOSWaL/rcvJLn6LuoM9SddS3EYqKAWgagMfccbWEs60ANIydQDQE0bBENCQTNiBslhMA61N57vvXNcRr6jjo2FOIBEwcz2V1Z45l7TlkIZFzfN5qy7GkPcPa7hyKkBiRCJCzHDZ05+nOlninLU13dusc+a7PhxFgq8BM4M9CiD2BPPCTrQdJkvQ1SZJelSTp1a6urg/hbT9+bFkmGgloQybHoWmfQQ5EqVn9yEC/5ZUnT+OACTWctEcTM5pjZVn7pt2xO1bj2zv/gfMKaXKvP0LurSeIzjyO+P7zSb90K266g9pTLkSNNdB11xU4PRvZlCyyvK/8vEKFCu+N/kzzrx9dzvxrF5AI6gN+sQJ4a1MaWS6vNqtyuQer/0G7ZFOa+dcuYNH6JLNaEnzzkLIo2fxrF/DCqncFEmVZ4vWNqSFZcDfTRfedlyHrQepOuxTZ2HbVbXH1q3TeejFKKEHD536FXjt62HFC+Fib3kZv2n27151++U4k1SAy6/jtjtsWAnhsWQfPrezm/LuXVILsTzY7nBd8GuYE3TmLzakirckCruczsirEyFiA9oyF5XjUVlcx75wvcucdt7NmzRpMXWXPlir2m1BLTTiAJEkUXZ/xU2cBUNqwlKih4gmfTb1FqoJyOch2HUqOoDVdxMInaMr0dnVyz603cetNNzDugJPZ4/gv8cgNf6bQuZHPnvM5jv3Ct3nxkbu47a+/JVmwyFnuoHP3fFGuxqtQocJ2Kdoem1NFOjIl2tIlqkIGQUMhb3lIkkS+6FC0PYqOR0PMRJYlfMrf8z05m46MRargDByv5Hh0Zi0UCdJFl3XdOYqWS6roYLmCFZ1Zlrdl6ck6vLiqk6suPo83X3ySr//4MvY64FCSBYf+btVcqbzgpgCKnWPjbReTX/sao477FtG5Zw1UvxWAZBEKjs+S1l4y699CkhVqR08hbOposoyhlKtyFBl8BN2rl7L2zZeZeviZJAsuyWKJ7pyL63vETBlFlakNG8QDCg1hk6CpkbEcVFlibE0YX8C6rgLZksM7bRlyljPkd7sr82EE2K1AqxBiYd/Pd1D+Yh2EEOIaIcReQoi9amu3X4a4KzLcpFtTBofYsm4SnXU8r73wBLc++gJQznofPbWRB95sY8mmNLIEZst08F2svozzjhCuQ37Z0zg9rcT2PZPgxH3R68chmxF8u1i27zr9UpBVOm6/FC+X5OJ73xpQLt5aWXzLaxru9QoVPs1smbG2XZ+H32qjMfqufYQnGJiYyrLMpcdPYb8JNch9bSNbq/Vvebx+fF/geIMnt34pR+ftl2IV89SdfilqdNt2XPnlL9B51xVo1c00zP8FarRum2Pt9tX4pVz5ubMN3Ew3+aVPE57+GZRgbJvj3gvvx5Kswi7DDucFn/Q5gecLcpZL2FAxVRnbE0QDGqmSS01IJxE2yBZdjjv7y6iqys9/8cvyfp7Puu48izcksTyPCXVhxrWMomnUGF5/+QUWru1mVVsO2xeUbIHrlSdyPlAfNqkPmxQKJTa88gQb1q7C3PsU2mOTecuqQQlEeGdjJ7Yr+Pq3f8jEA47jyZv/wq03/osNvfmBjHuqYLO+J8+G3sKAyjGUJ/6pgj1su0eFCp9WSo6HKkuYqkyqYNGTK1EouWzsLavzZ22XvOWQCOpYjk9VyKApHkSSIGSoBDSZ4hb3mecLJMotZQII6Cq+L7Bcj4Ams6GnQHfeoilu8tgt13L/rf/kiLO+wt5Hz6PgChrjAWJBGd8HZJDw0dwsG2++gPymFUw888eMmH0U1Vutz3tA0YG3NudYv+QVasdMIhwNISNIBAwihoaBwFQUsgVY9eStKGaIxrnH4ngQUhQaIhojq0KETQPH9VBkCdsTdOYdJCFRHdIJGRpF16PouCgKhEwNkHC9T9aC3gfuwRZCtEuStFGSpIlCiOXAYcCyD35quxZb210t3Zx+V/J3C8IzjyO98E6+/P0LWfvL/wXgL8+uGTRmzNRZdN1lUFzzKoFxQ/22t6a0cQnFNYuI7zcPo3E3hOdSWPEiWtUI9IayUqgWb6DutEvouPkndN5xKfXzfj4wyR/OB3s4xfH+66wIEFX4NDN3bDWqUlbe9wU8t7J7yJj+YLn/WfDdw3cbpJC9pVp/v5J/vyha//6yVA7WobyI1nn3z3B6N1F3+qXodWO2eX75ZU/T/cBvMJomUnfaJQNK4duiuOZVQCIwevjycYDMK3eD8InOPmW7x9qa/kWF4diRJVmFXZfKvKD82dcVmYLt4vmQCJZts6qCOu9kirSlS7i+z/SmBo4/9Uz+9Y9/8N8/Ph8RjPFWa5J03sXxXJrjYTRVZvb+h/DQXTdTd0KeYFCjWHQoWS6qDJqi4EsSkixRHzVZ9dpL5Fe/SvUBZ+JX7UbGcskveZZATTNOvIVHl7bRkSqw72d/iJXu5o7f/5Qp41uY+LX5SJJEZ6ZExFRxfUgWbBpjAUpOOUunSBLJgk1zIogE5CwXWYKwoSHLwzW1VKjwyUUIgaHKrMvZdOaKFGwfx/N5Y0OabMmhM1tidHWQ+lqDdNGhO2thaipNcZOwoZItuUhI1Ebe1T8wNQUEdGaKJHMWCuXniaLIrExmiQRUVnXmeOXRe3jm7//D1P2PZK+Tv8Fbm5IYqoxlu9iujOf5RAIypUwPb//jAqxkOyNOuRB93F5Yfrlta2tsF0QxRWrjcsYe/jlSBYei5bE56WAagAyqkCh0tZJa9iJNB5xGr62S7sygqyr1UQPLgxkjI+wxKoHn+7jegLQaTYkgyYJD0XZxPYGhKfRkLRIhnYix6wqaDceHdTXfBm7sUwpdA3zxQzruLsPWdlcCcL2hn14lECG8x1FkX72PP9z3Ilq8YdB2X0B7XmCO3oPCyoUkDv/6dgWMhO+Re/MxQpMOxBgxCeE5lFqXYW1ejl4/FpAQQiBJEkbjBGpO+DFdd11B1z0/J3nw37lzcXFgYt+fWZvVkhiyYHDX4lbu7FNGVmSJy06cOmApVKHCp4VF65PctbgVzxvekq+f/sBSALe/upFTZjYPKGYngvrA4taslgSPLW0nqKuYqj+gui0AQ1Mo2B5C+HQ/+BusDUuoPu4HBEbvsc33zS15nJ6Hfocxaip1p168U73SxZULMJomooTiw273CmlybzxCaPJBqLFtZ8K3RgKmjYjxZmt60O9KluBrB4ytPD8++Xyq5wWSVA52c5aLIkuEdIVMyUXTJBpiQXoVC1WV6M7bzD3pC9x1641cfOUvOemr/83STWna0wVeXCux79hqTE2lZY/9sW+6jvSaRUi7zcV1YUpTjI0ZG+EJgir4wieqQuqNxxg35wgYOYlU1iHfuozcxuVU149kY08ezxckSy6mprDnFy4h//vv8vMfnsuEUU3M2Xs2nVmLguURCagY5cZMbNdDliCgK+T7JsfJgkVvziJZcEkENSbUR5BlCU2p+GdX+ORjuR7t6RKW65Uz1xmL7pxF0XawfZdxdWHSeZuqiEYooNKWLlEd0ik57oAVVVBTSJccio6HoSl4js+La7pZ053D6POk94Gi5VByPF5b30t33mbt6wtYeN3PiI6ZRt3R3+LF1Z3URgxCukJbtkjJ9pBVyCa7WPGPC3CyXbSccSnyyOnkXRBuOQDUKVe/eIABICC/8mUQgsTue1OwPCzbJ+eDb4GiQMYV9L58B5KqEd3zBBzXw/J9XB8kyWBENMD4mjAS5XnQiEQAU5NwHEFIL5fOF4VASFAbMaiLBAZK5z9JfCgBthDidWDHqdZPMLNaEoMsZ4BBVj1bEt37ZLKLHyCz8A6qj/zWsMcL7rYvxVULsdtWYDRN3PYbSxKSqiO88sQ8v+xZnK51yEaQ0LTPDAnOjaaJZY/sh3/PVRd8j9rjvofo6xRQZGng3IdbMOi/FtcXXHjPEjb05IkEtEpGu8Kngv6qjpIzzLLvFigyHLp7PY8v60BQLvfqt8QDBimFj6sN8/Y29BDKwbUg+eTfKLzzHPGDv0B4yiHbfN/sG/+m95E/YI7ek9pTLkDWzG2O7cdJtWN3rCZ+8Je2fdxX70M4FtG5p+/weACaUi71kiQYUxPi7bYM9palXwIiFX/dTzyVeQGoikw8+K5CbnVYJ9dj4/mCWEgjpKlsThWoax7D4UefwP23/JPZJ3we21UxVYWeks2a7gKpvEXtuBmYoQjtrz/LqBn7EdINkiWHVLZIQJOxXYGvSLRJgKKi4LHX+BoevuNW8pvWoRpBRu17DIaukMq7aIpEUJUI4vHzv9zAxV89jW+ccya/v+l+pk3ZnZLjUq3pVIfKwoempoLkkLdcVEVGlSUKlkfR8dFVaEuXsDyfqqBBxFSoiez4+VOhwq5Mb95GkSQUSaI1VaRgOxQdn1TBRlVkUgUHXVcYWx2jNqojIRMxVXJ2uRxckiQyVnmxyvU82tNpNqeLLNmQpCkRZl13lvqYiS98VnbkyFs2KzqybFi5lOV/vxiteiSxEy9kXUZCxSNvF7CcstK3rkCht5tNt5yHm08xZt5lqE1T8EQ5mBaUbR2CgKSCJCAakhGeROeql9Dj9VA1llTeoyj6AvG+v+VCF5m3nqJ2r6MJRBMgPDwhCBsGMVNnfGOIMXURJMD1feIBjRWdGTxPImc7SJLEut4i8YCKrpbV1j9pwTV8eBnsCgy1mLnxK3O5c3Ert7+6caCfUpZAjVQTnno4uSVPENt3HmqkesixAhPmgKySf/vZ7QbYkiQTnX0K3ff/ivyyp1GjtRgjpxGadMCAIrDwPTIv342XT2J3rCY89TDiB3yO1HPXowTjJA79MhJw+l4jB85/uAWD217ZOKCO7ItyabsEGJpcUSWv8Imnv6pjOGrDOoosMaoqyI+PngTAcyu7hpSED+rf9sQ2g+t+MgvvJPvqvURmnUB09qnbHJd9/eGybdfYWdSdfAGSunO2F4W3nwUgtPt+w273rQLZxQ8Q2G0f9JodZ5xlCfYcGefldUmEgHte38zs0QleXveujoO8xUJehQqfJiKmxu4NMUYmXDqzRXwhEQ1obE4WOOOr3+axB+/hrcfvoOmgeUiyhKEolNzy5NVCZtSeB7Ju0dOMq76Mdzry+ALCAZ100UaVoFqSWd1boumA03nj+itY/uIjiGAVsTEzGDHrYJA0sgWLfMmh7YW7EMUUpbbVWEefykV/+Bc//uLJnP/1efz5lgfZbcwoGqIBlL6Jr67KNMcDuL4YsOCJBjVWdGSxnLJ6siLDyESQVMHB0BRCuvqJnDhXqADlpJTt+CBBTcSgI+3TUq0TC+iETYWgrqKpMjURjaCuEgtq5C2PmKliauXElu366EpZdfyVdT20Jwus6ylSsh16iy6bUwV6cxaGBrmCT6pzIytvuATJDFN7+qXIRgiPctCcLpaz0QBetpPWm8/HL2YYeeZlBBonETSgYENui2mMDURlsD3IlXy8QobkysWM2O8kBAJdAd+FiFkuUTdV6HriLgAmfWYeDSPC9GQdSiUbxy8n4cbVRFBkGV2VUGQNXVWQJJnGuEFP1sLQZKY0RshZHumiS1N822KtuzKVAPsjpD/gPnVmM3ctbkUAU5tiPL28kzXmF3lqyWNoyx5EzDlnyL6KGSYwdhaFd54jcciXtmvFo9e2UD/vSoRTQg7EBlntCN+j6+6foYSrCYzbi9CUQ+i6+0qqj/s+kZnHkXnlbuRQgup9TuWUmc3Dnn8/l504lQvuWTKgTgh9q2Cuz28fX8F3D9+tEmRX+MTSX9WxtSAZlPsUfQGpYlkFc+sFqv77ov8YO8qCQ7ncO/XMPwhOOpDEYV/ZZqtI9rWH6H30fwmM25vak85HUncuOyyEIL/saYwRk1Bj9cMfe/ED+Fae2D5nbvdYulru99JUGWurRYgNvQV0RcL1BbJUbi+pPCcqfFpRFZlYUCcW1PF8QbZoI4DYjBkceNiR3H3DX7nhnK+xPu2yKVmk4HokTJ1EUGPfz5zAiucfpGvZQsZMnkt3zkaVJTIlG0dAZ85C+D71I0az+5d+AW4JRw4Rj4fxhURIl0mEVV7660UYsWoaJs9l/6NO4t7f/IjJu0/gd9fdwn/NP5kffmUejz/x1JBSb1WR2dLJqy5i0hgz6MlKmLqC7QtSRZuurIUkSwQ0h4ZooBJkV/hEUhXU6REWni+z58gEbysKndkSpi5RHTKoiZi4nk+25BE2dWojJrWRwceoDmm0pUts7C2QKTgYqoIqS2xMldBUiQ29BWzbI1m0kAopXv/reQjfZ+S8y1AjNWypu93nzIWV7qTt5vPwSjnGnH05Ut1EUCAW0NF1j3zGG6islQDbAV0Dz4eet55H+B7KuANJlSCoQkAHVZUYVx3B8DK88Oq/2X2/ozh0zlTakkU6MzYBVUKVFTRVpj4WpDERwHF9goZaFkoUULBdLNcjHtKRkIiYKlVBY1CVzyeJSqPMf4BZLQl+dvI0Tp3ZzGUPLOXxtztYY4UITjqQ9S/ciyimh90vNOUQvFwvpfVv7PA9lEAUNVpH7vVHBvyuhRB03/sL5ECMxKFfITBmFkbDeIIT5qIEoiQO+yrB3Q8g9fR17Fl8HWC7quETGyJsJYw+0GPx/MruAfuhChU+qZw6s5nDJtUze3RiwIZPpnwPbK0Q3m/DtXUwecCEWtQdTDgLKxfS8/DvMUfvSc2x30OShn9UDwTX42e/p+AawOlcg9O9gdA2ys59u0TmlXsw+54b20KS4NLjp/D9IyZy41fmcubegzPdHRkLJImzZo/i1q/vU+m9rlChDwlIFh0kBMvbM+x/+ldJJZM8d9/NzB5XzZHTGpjVUs2k5hjjasLMP/kYqmvrWfLM/Zw8o4XxNRHwBUFVozqgIUkChCBTdAiGEzQ0NNGz+BEyq5dgytAQD/HMXy4lFIuz/7zv0TR1LpGRuzFh74MoYBIbNZHLfv831q1azmfPPJWNnUk2p4oU7eFVw0uOhyTJ6JqCDzTGDGRJojZkoEplC6KC7Q67b4UKuzKu59OZLZEpumSLDrbrMyIRYK/RVYyIBdD6BA6Ljk9AL69KCSFwPB/P83E9H88XpIsuliPIlBwc2ydd8qgK68wenSBuGDieIGN59PbmeePvF+Hkkow/+xIaRrcMOScPcNKdbL75PHwrx8gzL0eum4gP5DxYn7TxhUd9COIamEBNGDQdSjZ4HqSXPIlWMwpRNxYPyLrlAHt8bYSQqfHGI7fguy57n/BFMnmXeECjLmpQHw9j6jKxoIbl+ZiqQiyooylyX+VOGFmSiARUwqaKokiMSASpi31yW0kqGez/IFvb8UTnnkFu6dOkX76HxEGfH9KrHRw/B9kMk1vyOIExQ5zPhiU07TC8dCcA+beeQDbDJA79ykBWu7judYprFxGZeSySrFB1+Ln4pRx3/O4inm8tERg/Z5Ca+Nbn33/uEvCZyfUUHY/nV5b9e0uOzy8efpvbzt13kCd4JVtVYVenv/96OE0FKN8PUt8/Hl3aTrbokLFcJMpCRMmCTbbocPWza7YrjgZQ2rCE7vt+gd4wjtqTz0dShg+as4sfpPexP5eD6xPPe0/BNZQz5Cgqwd0PGH77G4/gFzPE9t1+9hpRzuD395j33+/XPLua9T2Fch+65zMiHqg8CypU2AJJKrdVvLYhRTrvUD16ClP33o+//vn3zD5mHpKuM7IqQLZgURUxkGWFY04+nZuv+wuqk2G/cVWUXBe/NUOqaOO4ZbsgXVWImzoZ22XaIcdRLxeINsVZ+eKDGKEIe575HZqqEnRli/S8s4jVi1/kgBPm8fbmLIXIKC74xR/46Q++wTmfPZt/3HAL7a7HyEQQRZYGKmmEEORKDvGARlVIJ2+5tFSFQILX1vXyemuKXNFlzthq5oyrpuT4lFyPeEAfCDgqVNhVyZZcbNfHclxSBYfaiEFb2mJifRhHkXF8j4LtIQufbLFssbU5VSoLpAqIBTQKjsPGniKO76NIErGIjly0SZYEJVeipTrEwrVd5It5Nt1xGaWujYw/+yJiLZMo2EOr4Lz+4LqUo/HMKzAaJ2BtPUaUg+mgLmF5AoSMofgUAau3ldLmd4gf/MVBFXOWA0gSppfhzcfvYMzsw6kfNbrcP150MVWNkAYlT2LuuFpMXSn3VfPuMepjQaIBg03JQl9lzCfLkms4KgH2f5C5Y6uRJQm/r8ZaqxlJcOJ+ZBY/QGT2KSiBwbUjkqoRmnww2Tf+jVfMDtk+HIoZRumz5fHySbS6MUh6eYWotPEtev/9J+IHfR6tagReIU3yqWsxW6bjW3na776K+tN/Ci3Tuez+pUwdERsoG79zcSvdWQtVebcU9OsHjeOxpe2DbIpeXpfku7e8xiNL24dYf1WosKvSvzi29VeCBKh95c9lQTN4ozXNG63DV6XsCKt9FZ13XoYaq6futEu3qQL+bnA9h9qTfrLNIHxbCNcmv/RpghP2Gfa5IlybzMt3YYyajtk8ebvHMjR5SE/12XNGMbEhwrxrXsLxxCABxQoVKpSRJIm6iIHjlxWFZRmO/tw3+NV3PsvT99/CnOPPRldkNqUdLE/gCzjw2FO5/po/cuXv/kJo5vH4QiYeUmlL5cg7kDBlVEWmIRFktCoTDtTi+OD5PmE/z/TpM9h7XANVQZPVby3nr9f9nP3P/C+URBOrWjfz0o2/ZdykqXzxh5dx3a8u4jvf/Do/uOK3dGZKVIcNaiMGybxDd66EJsuUHI/qsE4koGFoCvmSw8ruLD05i5aaEKu7sjTGTTRVwVBl2jNFRiaCA8/SiuJ4hV0RSSqXZPe3PmVLDqmCw9qeHEFNoStrk7UcVnRkGV8bpjtvkSzYZAoOo6rDzBgZY0VHlmTexvMgHtLRFQnLFeiKjKHJjK4KsldLnKsvvpBi6zLqTvhvpBGzcFwf12JQebib7qSjP7g+62eoDeOxtzpnDwgpEIoYRAyZXKlIplhW/xZA7s1HQZKJTTl00H6WA5Lv89S9N+JaRfY79csUHY+GWICIKrExbTGuNkDUNKiNmMQDOuow97XtunRlS7gCQrrCiHjoQ/5f+XhRCbD/wxy6ex2P9akLA8T2PZPC8ufJvnof8QPmDxkfnn4E2cUPkF/6JNG9Ttzp9xG+h9OzETXWgCTJFNcsoueRPxA/4HOEdt8fANkIojfuRvqFm6k66tukn72ezrsup/6sn/EGu/FGa5qbFm4oP0j6TlhVyqWeU5pi3Lm4lVtf2TjkvZ9e0TXI4qu/ZLaS0a6wqzLQf91nmdGPLLFdJfD3gtOzkc7bLkY2w9SdcTlKMDbsuEGZ6/cRXAMUVryIX8oSnnHk8O/xxr/xcr3UHPfD7R5nRnOMqSMGn2d/9UoiqJdnIYi+vytUqLAlluvh+rBnc4INyQK+gNOOPZLHbtiHm679I/sdezq9OR/heQihsildJBBuoH78NJ649xZmjD6srH+Q8shaoKlgeYJ0yaU+ojNtVILlbVmE7RI2JFavXMmolhaqQgGWvPws1/78PI4859vMOfx4lndkKPga1aN354mbr+Z7V/6RL3zrR/zjj78EPcR5P72KFW0ZFq9zaKkKISQZx/OIGArRvqx0pmjTni1RFTRoVyzSBYuwqaPIMopUtu+yXI9M0SFVKocHNSGDaMVVoMIuRtTUsD0f4YMsOWSKLtObozieoGi7lPoy1smCRU9O5Zl3urBdD9vzWbYpyaquJJqkMqYmhKsIsgULVSnb6rquR3umSKlo8+RfLyO5/GVajv0vRux1II4oK34LHRy7LFLmpjtpv/k8RF/mWmsYj6Ac4G0ZhKuAh4biu7T2lvuibb9vjOfQ8+aThMbPJhFPkOzr7JCAsAHdPb0sf/IOxu19KNOnTiVoKjg26Jpgcthk37G1yFLZErhgO4QMBUNVypUulovj+mQtl+ZEEB9wPH9AQPGTSiXA/g/RX2Jquz6KIuH75dVovW4Mgd32IbPoPqJ7n4jcl33uR68fi964G7nXHyEy64TtemJviSQrJA7+Ih23XoSb3Iyb7iRx6FcGgmsAN9WOJEkEJ8xFqxpB3ZmX03nTj+m87RLqz/45eu1oBAwSNXM9wcqOLLe+shGvL2u3NQfvVsuDS9rwvfL2bNEZuPZKRrvCrsiWomWvb0wNWHBBuTz6g+KmO+i45UKQZerPvAI1WjPsuA+auR44zmsPo8YbMVumD9kmXIfMgjswmidjjJq2zWMoMizZVM7W3/7qRm7+2j7AuzZksiQNPCM8r7zQVrnvK1QoY7s+m1NFJCBgqsxqSaDKEut7C5zw+W9x2Tfnc8ct17PXkfOwhUfB9XB8weZMkeqZR9Fx26/oXrGYqnEz8CWZxoSG5ULYUIjpOkjQkbWx/bIFjqHr7HvmN7n1Z99kzerV9HZsYv/532P8fkdQcByqwwbr16ygYPmMmHEA3V6A6cecw5HdPdx3099QzRDjj/oCmqywvrtAdUTDcn2qQwF8v5zNU2RIF13GVAfZ1Jsnb3tMqAsT0lVWdGTYkCxQHzEYEQ9Q3WfjlSzYhAz1Ez/ZrvDJQpYlqgI6ChJVQY0uzcIX5aqUqrDGuh7whcD3Jdb35LE8DyEoK/4rsK6ryOS6CAvXJomaCpIksN2ycGBvziEa9Ln9tz/lzSceYO9Tv0Fk9snkbBvhC2JBjZItCGiCDZvbBoLrujOvQG2cMDA3cYY572zRAQG9W2yUgdyKl/AKKSIzjhoIrvu3xUM6m5++A6eUZ/Th81nTleXY6SPwJYlswaEpHqA7Wyyrmfs+ckomXXTYvSFGwXbpzFposkSm6BDQFTRFJqApfNJv+UqA/R9iy/5ryRecNXsUS/smp/F959G24iUyr95HfP+zh+wb2eNoeh7+HdbGJZijhk6It4USStAw/5cI10L4Hmrk3Um707ORwsoFuOlOwnseM2DBU3vmFXTc+CM6b72I+rOvQqsaMeS4W1rubI0swfOrugdsyXwBVz+7BqgojlfYtXlsaTuPLG1nVFUQWS4vkilKuTz6ntc3v+/jutkeOm65AOFa1M/7+bD3HHw4mWsAu2sdVuvSvj6roWVcuTcfxcv1UH3s97a7oOf773ZR2Z7gzsWtjIgHttCZKJeGCyEGWZVVqFCh7A8rgJCu0pu30FWdznSe51Z2o42cwZgpM3ns5r+y31FnUBM16Ew7NCVMVnWkqZ12IMr9f2HzwoeIj5uBIpcVfBVVYCoKI+sCJIsutpdnQkOUpZvT5Kw8thnn5AuuxpA81nfnmDi2Bc/zSbs+fm8rXUtfIt/bzuRDTyFU10J30WWfM75Fe1cvd133B+bkBAef9iWWd2ZJZGQa4iFUucRLqy3G1YUJGBoIQWfOYnRthPqIzqruAh3pLt5pTzG+PsLmTImOjMXExiiJgEq25OH5PkFdpS5iVhTHK3zsEUKQKdq8sKqbbMkFBCFdAWSaEia14TANMaPvdRVDlRBCYnVXFtsTBDSNkuOTtjyChozrCnRDoWiXWNVpA4KFt/2VN5+4kwNP/TKHf+5cOnI2svApOg6uK9FVKNG5qXVQcG00Thj2fPvvKAdI2+VS8S3xgcxrD6HG6omOmznQty0DBlDKpVn2xO3UT92P6IjxOL7gudXd7NVShUBgux6GLuO5HpYrUBC09hSZUBfBdn00WcLUFFzPJxbUUWWJkKHudMJwV6USYP+HmDu2GkWW8Pv6qF5e0zMQhOr1YwlMmEv21XuJ7nXCkCx2cNKBJJ/6G9nFD76nABvKZeAYZY854TlIioaTan83uJ5x5IBCsBACa9PbhKcfSfa1B+m45UIa5v8CNVa37eNLUBPW6cyWs3i+gO7c4IzelnYA/Yrjr6zrrWSyK+wSLFqf5KqH3+aVvoWldT2FgW2e7/PQkjag/Pl+r7IdXiFN560X4hXS1J95BXrdmGHHZRc/QO9jf/nAmWsoK4+jaISnHT5km3Ad0gtuxxgxCbNlxnaPM1w/en8pfb//98XHTSFZsCutIRUqbIUiSXSni7zYkSVnuYyrj9CWLFAd1FAk2P+Mr3L9Jd/gpX/fyWEnnY2hl9V4TVVFM00aZh7BppfuJuRm8Y0qDFnCNFTG1YZpSAQp2i7dOYf2dJFEQMNyZGKmyfpeQTysIYeiqHjImkapezNLXnoKK9XBtENPon7MJNrSBYIljdZFj9EwZiLyQcew8LY/ouk6ex89D1mW6chYOJ7AUCFddMgUbbKWT6Zo4/mCF1cWiRg6jYkARQe6sw6aItOcMJGArrxNyemv6oNowCOoV6alFT6+CCHoyJRY0ZFhbXeOcXVhVnfm8AM6k5vjFCyXdb0FmuJBDF2lIeZSGzbozlls6s0RMiQE5TLwdd05EgGFlCWQZAGej6aprHjkn7z5wPVMPOQUmg//Ap2pEoomIQR4QkbIgkLXJpZc+6MdBtcA/ZKCHjCcQajdtQ5r41vUHfwFgrqMb4MGuEBNVGHzsw/glvLsccKXaYiZJLM2UdOjtSdHW9qiMWYSMVXqYwFcX+AB1UEZWZIImypZyyVvuQQNlaqQ/okPrPupPMn+Qzy2tH0goBbAqq78oO3x/ebRtnLBsFlsWTMITz+CzCv34Ga6UKO17+scuh/6LXge5ugZuJluwlMPHWS/I1wLWQ9QXLuI4MT9KLz9LB23XED92VehRobPPvmCgeB6ZxGA5fjctbi1Mumu8LFm0fok8/5aLnkeDs+HfukzCVBkaZutE0P2LWbpuPVC3HQndadfitE0cdhxmUX3k3z86g8luPZLOfJvPUlo0kHD9njn3nwUL9tN9dH/33v6ElQViVNmNm/T/7tChQrv4vmCdT15lnXkWN+VRVEVEjmLvOUSC+qEDIXDDvsMb963N0/eeg2HHH8GzYkQyYJF1FRJBA3G7n8im168i00LH2LckV8gbBpEAwoja8J0Zi02JYsEjLLoWG3EREggPB9VlinZAtv3eeIvP8V3XSbOmkuut4NDjz2N+JhJqLJMpuiQzOTwZJ23nr6dKbP3o2XPA3n+ht+gGgGmHngcli/hCh9Nkig5oux9a3tsSObxPMHG3gJRU6E3X6IxEUIIH1VVGF1TTiJYnoQmy7RnSkgCArpSCbArfKxxXJ/1PXmKjofl+mxOFvGERMCQSRUcPOFjqAqmppAIaOhhg66sRapgUxMN4Atoz9mMrQnQkbHZnLHRJHCQUCRY+/TNvHn/dTTMOpKpp3ybkuehGzK247EpXUT4UOjaxAt/+iFuMU/dWT/bro0mlAPl7ZFd/ACSqtM0+wjCQYVISKIqoOIJCcMv8tIzd9Cy54EkRu5GwXKwPR9Dg4zloikS8VC5XaQmYhDUNWQZqsNGubJGkmmOB/BEWbzt0xJcQ8UH+z/CovVJrnluzXbH6PXjCE6YS+bVe/FLuSHbIzOPBSD72oPv+zyqPvMN3EwHmQV3EN37JIwRkwZtlzWT4IS5NJ7zG5zu9cQP+RJeIUXHLRfg5d6fv7VE+UOmytKgoEMAt76yoeKbXeFjzV2LW7cZXMO72l39auJf2X/MTul5+aUcnbddjNOzkdpTLsTcRq9z5tX7ysH1hLkfOLiGcgAtnBLRvY4fsk24NumXbsMYMRlz9B47PJYEHDG5nrPnjOKyE6ayYE0Pi9Ynt+n/XaFChTJF2yVdtNFliBg6kg9526MpFqI+arLnqASnzRrFFZddTndHG6uev4+oqeH75ZLykdUhpk6eyMgZ+7H+xfsQro2uSn3CSwJTk7A8j3hQozNjIUmC6pBJT95mTG0ITZMpWTBn/vfo6djEY7dcy+TPnEli7GSqAjq+ENSEDZpqY9RM3o8TLvgbyxYt5ISzv8io6fvw9HVXsvS5hwgHFCzbRZIgbKp0ZEoISaJgeWzKFMhbDoos4QjB+LoIx+3RzHHTm4gGdKpCBjFTQ1UkNvXmcFyfVR05ujNbGwtVqPDxwfHL/tWmqlAVMqkNmxw3rZEZI6twPBd8QVe2xNruHCWnfA+qikzIUFEVGVPTiAQ0PF8mW3SwPJ+8WxYL7Hn5Xt685y80zTyU8Sd9B1lWqAmb5Esub7fn6MmUWLZiNU/+/nu4pcJOBdc7witmyb/1FKHJByMbUVriYaY0xRlTE+Wg3evoXljOXh/3xe+QCCrIPtRHNHKlsrWYqcms7Mhje4LGWJDdGiKMrg6hyzLdOQvL9VAVGUNVPlXBNVQC7P8Idy5uHVDhHg5ZgnMPHMuFF1+CsPJkXrlnyBg1Vk9wwlxyrz+Cb5fe13koZpj6s64ESaK0/nUAhBgaPLjpDhBgNE2k7vRL8bJd5SA7n3pP7zd7dII7vrEvPzhyIpedOBVTG/xxc324+pnV7+taKlT4T7D1bWuoMift0cQRk+uRpbIAoOgb5wl4aQuv+G3hWwU6brsEu3MttSedT2DMnsOOy7xyL8knriGw2z7UnvjjDxxcC98js+gBjOYp6PXjhmwvK4f3ENv/7J36Ivz6gWO55py9OHVmM5c9sJRfP7qc+dcuqCyaVaiwHYQQ9ORtckWXguXhCJ9oQCMWUGmIG4yri1AVMjE1hYMPPZS95uzD9Vf/nkKhgON7xEMaYU2hKWFyyCnnYOXSaOtfxDRUUGQKloupqsRMjUzRpStTYsnGNJuSBQxdYX1nlsXre+hMF+h1NE4+7w/4At5Z/BK25bO+J4cnyu0vwhXUxwwSIo0QAj1YxSHfvJIRk2bx9N9+xuKnHkKWyr7bVSGTgKpQFdaZ1hgmgITeZ9UT0FRUSSIeUImYGvVRk8Z4gPqoSVBTqI8GiYd0crZDb/H9zW8qVPhPoMjlYHlTsoDAJ2gq1EVMqsMmQU0jHtIp2C45yyFdcHmrNc2LqztpT5dY35nG8lwOGFtNTVRDV2U0BKm8y7In7+TV2/5A3bQD2HP+T4iEDIRf1i/x8fGFx6b1raz+13n4dpGGDyG4Bsi98QjCtYjNOp6AodBTtFBliaIryCbTLHzgBqbscxh69chyO4ihEg6aRE2dsfUhJjXFmdocZ0ZzAk0p91o7niBdcig5Hh2ZEv6OJkWfUCoB9kfMovVJ7ljUOui1qKkSNpRBr2Utl+iI8ex9yNFYr99PkzFU/y+y94l9JZ5PvO/zkfUAjV/6E+bo8qTey3TjZrqxO9ZQXP0K+XeeJ/nMvzBGTkYJJTCbp1B36iVlleO+XtGd5eV1Sa5+ZjVzx1Zz9pxR3PiVucxoHlyW+sTbHZUJeYWPHYvWJ/nTU6uY2hRjSztHqy+bXXS8IYG054sd+l/7VoHO2y/B7lhF7Yk/Jjh+9rDjMi/fRfLJvxLcbV9qT/jgwTVAYcVLeJlOonsPtfvzHYvMgtsxRk7dYe81lLPXkYDGovVJfvv4imFt+SpUqDAUxxMULIegoaDIEs3xAPuPr2VCbQRJkkjmbVIFi7ZMkc6sxbf++3w629t48p6biQcMAqaCopbbrMZOnU3NqAm8/ODNVAV0poyIUh8NkC45pAo2XekisiLRW7Ap2uXsWnsmT9GVyNk+XdkSrTmfIy/6F7Gxe7Kmt8CqtRvpbt9M1/qVbHzzRVYseIL7/vZrWibNRItFaUhEOP0nv6V+wnSe/etleGsWoKkyrckCjfEAE+oi6JqKoqpoqoLvCzIlF9vzeWZFF+t78mxKFbBcj7CpMbY2jKkrbOjN47g+tusjxKdzQl7h/x7PF3hbfbkLIUgVbDoy5cUfTZEI6xp1YZOOTIkXVnexvidHd96i5HgIqRyI+8InZ7u0dhcQgEt5IUpTFUZEg9THTPKOR/K1h1h+zx9p2eMAjv/OZUiKiibLVEU1dHx6MyU2r1vHupt+UhZDPetKtGEWyd8rwnPILrofs2UGkYYxgIRll89ZCMETd/6DUj7HqEM/S2/RIaDLOLZHe6qIkMpVLqYmMyIRoDFmYmgKluvRmSlhOR66IuP6Av9Tej9Xml0+Yn7x8NtDSkwzpcEdEb4ol0t7PrjjjsN6+hHefuwmYgd+ftA4Y8Rk9MbdyLx6D+E9jkKSBwfpO4usGcjxhrJw0uPXEBg/G0kzEHYRJVxFYMxMjOZJAz2aZst0ak+9iK47Lyv3ZJ/1s2169G7No8s6eHZl14Cg2cXHT+GMq18aeID5gop9T4X/KP0+zf2q1lv/OxHUueyBpQO2cmNqwqzqfLdt4/0qhpeD60uxNi+n5sQfE9xtn2HHpRfcTuqZfxLc/QBqjvsBkvLBH9NCCDIv340abyQwfs6Q7bnXHir7Xh//3zuVvdYUiURQZ/61C7Ccche6LFFRC69QYQes687x6oYkm3uKNCYMxtZEWNOToyEWYGVHloZYgERQJ1UoqxOPnz6H6Xvvyy3X/pGjTj2bjFCpiwRoTZXI2y6HnPoFbv+fC+hZ8TJ+w2F05opUB3WKiQBrO3MgQ8ERrO3Ok7UcHMfFd30s26PgChIBHVVRkKJ1rHjqNpbc9SdqJ+9DfSJKoZBDCiYIj57B6Bl7EY/VgPDYZOuc/KP/4cFff5+rL/s+X7/kt5x++mlYjsfKjhwZy+Og3etY1ZWjUPKYOiJKPKSyKV1ClcuqyZmiQ21EKQcbcZOS4xHUZTy/XBn0KasmrfB/SDJvkSo4OK6HJEuoikxNSMcTZb9mWZJIFRwMVaYtXaQuYrKht0imYJMuuHieoDqk0xQPIMsSzfEgEhIFx0WWBe2ZIiFTxfV8Cq5HquQQC2jUhgxefO1RNj74v1RNnMOen7sASTMwDYeYXvaZXpPPs2n9alZffx7Cc6k/62fbFEPdGoNyFtXt+7N1mJt/+zm8XC+1R30HQ4VoQCZg6NiORz6b5OUHbqJ22oFY0WY60hbrewqMiJvUhAxGxgKYsk44BI4vUBSJiKmxrieH4/n0FCwc16e5OoSqfDpzuZUA+yPkqofe3q6l1Zb0x+BKdQvBSQeSfvU+wrNOQAm9G3hKkkR09il033sVxZULCU7c9wOdX3j6kRTXLkYJRKk++jtAuWR8OOuewOg9qD314vcVZFuOz519gmazWhIcP71xIEgRQCKof6DrqFBhZ7lp4QYuvvctfCFQZQkkCdfzBzQCXE8MWHAJyhnZMTWhQQH2ljTHy6WcyYJNT34418ky7wbX71Bzwo8ITdxv2HGpF24m/fyNBCcfRM2x33/fi2hbY7UuxW5bTtVnvjHkmL5dJL3wDsyWPbbZC741B0+sI1koKwD3I0sSs0dXVRbLKlTYBpbt8eamNBFdxnJd1nTaBDSFfMlDk6U+vRKoDurkbZd00UHXZE796ve55Guncc/Nf+f4+V+jN2/TUh1gXWee6QcezWPX/47n7/4n+x58OIok6MhYdOYcUpZHPKDgez6S8IjqKgVJsDldQAiQBWxKFpBkaIwYVO9xJLElC8EIMerk7xM3NQK6Tlu6gBLWSVsOYV3F1BUcKcQpP/otD/3PD7j6p9+lKWoSmLQ/mizoSNv4nkdN2KChQcdHZn1PEc/z2dSbR5JlPCEI9omapQouridIl1zCml8Jriv8x+hIF1nRkcPUZfIll9qIQchQWddTKFvGCSg5HtVhHV2VsS2fSEBjVkuCxRuSGKqMK6A7bxM2VZpjASxH8NqGHtpSJTx8woZGumiVW8kcn82pIqokseq5e1h19++IT5zN+HkXIGsGmiQzIhJkXXeGrlyJzo3rWHvDBSCgft6V6LWjd+q6FN71wfYpB9tbqhuUF93vQq8ZRdOkWYQMBceFvF0iFtJY+uD1uFaJ3Y76AkKUs9CKkIgHTXRDIWO7bM7kGVcXxXc9DE0hmbd4Z3MGU5PJFm2yBZt4sKwc/mnk07ms8B/ikaXtOzVOkcvqgf3E9ju7zy7njiFjg7vtgxpvIL3wzg9cRiWpGrUnX4Cb7qT3ib+WXxsmuO6nP8h2k2103Hw+Xn7o4sHs0QlO2qOJLb8fBXDHolYWrU+yaH2SB95sG7TPw2+1cdPCDZx/9xIuuHtJpWS8wkfCovVJzr97SV/JUtm7ub+02fYEjlcOqj1fgFR+OGqqzLkHjeOkPZqGPWZrqsSqrvwOgus8nbddjNW2vBxc777/kDFCCJLPXk/6+RsJTT30Qw2uATIL70QORAlNO2zItuyi+/ELaeIHfHanj1cbMcgWB1+z6wueXdnNd2957QOfb4UKn0R8UV7M84VMdcRgXG2M6pBJNKDiC4GQZGzHxzBkRteEiARUAqrMnH32Yc4Bh3DH3//MmLhCdURH+IKqiMFe4+o45qwvsfTVF3ni2RdZ2ZlnSVsSSQgaIzrNsSCKDJYnSJUccrZDSJeIGgquDyW37GnfmbdZl4XqEy4g39PJy7f+kbzj05rOky45JPMOq9syvNOWZnVbkpUdWWQzxOcv/V92nz6TS35wLq88fj/VIZMRcQNVVTh0t1qOnTGCmX2L6wfsVovjC8bUhIiaKumii+sLNFmmKW5SGzJBhlzJZVMyzxsbk7zVmiRbfG9OJRUq7AyO57OiLcuC1V28tLKLDT15urIlevI2nZkC67tzrO3OsbIzTXu6SE+uHEQrkoQqy7QkAgR1ldqIhu14hDSF1zamuOXVdSxc08Pa3iwvr+0lGJDwBdRFDWJhg7Cu8spDt3Dvn39G09R9OPDrlzOhoYrqaICgqVBwHDrSRdrXrGbtv84DSaZ+3s+3GVwrDM2W9lty9S+Bby0dWFq7GKdrHdVzTqHoSvQWPTozNqYmE3CyrH/hHqYddAzjJkwASUKXIWjIbE4VKVoe0aCGQMLzfVzPJ6SrtKWL6JrC5pTFU8u72NBT5KG32mlPFfg0UgmwP0KOmtKwwzGyBJefOJXLT5qGKktIgFY1gtDUQ8m+9hAJPzNovCQrRGefgt22HGvjkg98jpKsUHfmZSihOE5qxwsCgdF7UHvaxbipdjpuPh831zuw7cAJNdx27r5MqI8MWYH2vLIt128fX4G7VX/L8yu7Of/uJdy0cAM3LtzAmde8VAmyK3xg+vuo+z9LF9698/eLEIAEX9hnNAAPvPn+ysK9Uo6OWy/Eal9J7Yk/3mZwnXrqOjIv3Up4+hFUH/PdDzW4trvWUVz9CpFZxyNr5qBtfilHZuGdBMbtjTFi9+0eR5LKvde6KnPKzGaWtmWGHff0iq4P69QrVPhEoWsqkxujmLpMVchg1ugqmhJBxtaGqY8GGVVlMr4uTHM8SCyoM7YmTCSgEw1ofP37PyGd7OWPv/8d4+siNCaCjKkOMb42wt5HnYERDPPk7X+jK10kZujkiy7JgkN3vghI2K6HLgt8D0K6ju35+D7IHngupHICH3BlheozLsMz46Q6W8kVbFRFoli0WZcssTmVoy3nkC+V/6QcmZN/8gfGTp3Fbb/+CbfdfCO9eZtR8RAjayKETJ1oQMfQFBRZJhbQ8IWgYHv4viBXcmjtzfLm2m5SeRsJwZqeHG9uTLG6I8eK9hxLNmWwHO//+r+vwi5MwXJZ151nQ0+5/x/AsTxeWtON8H2Wt2VpSxfwfFjXlWNjssizK7tJlSxKriBbcunMFilYDu2ZIqs7M7zTlqXkOkjIBDSF9d0FXl/fy7quAr15i1XtGTb35lm4soeOVAnX9ejK5nn0tuu4+88/Y8JeB/HZ835DQyKKEB4BRaEmqNOZsehat5J1N50Pqk79vJ+j1Ywc9roMwFT7/kjQnyseTrWlv2QcILPgdpRwNeFJB1HoW2jLuJDM2rxyzz8RnstZX/4uIxMhxtaFaKmOMKkxwriaAA3xAI2RAHURg6ip0VIdImqWRdviporlOiAgYCqUXH9IW+ynhUqJ+EfIT44p22Dd8/omunIWCFAUGYQYKEW97MSpnD1nFAATGyLctbiVW1/ZQHy/eeSXPs3qx/5F1ZHfGjimBISnHkbq+ZtIv3Q75qjpH/g8JUkmNvf0nR4faJlB3Rk/pfP2S+m46cfUn3UlarQWU1P401OryBYdZEkaJGwgyxK3v7oR2xuadd/6FdcTAyXlFSrsDFv2Vc9qSbBofZL51y4Y6KO+8Stz2Zh8b6uovoC/PreG1d15tuPUtU28QpqOWy/C6dlA7cnnExym91kIn+TjV5Nd/CCRmceROPxr260ieT9kFtyBpJlEZh43ZFv65bvxrTzxAz63w+OcOKOJCfWRgd/x0VMbeW5l95BxB+9W+6Gcd4UKnzQUWWJcXYSRVSEcr9/uRyZVcsgUHWTJYFRVcEAHIRbUiQV1XM9n2h6zOPyo47jx2r8w65h51CSqCJoaNRGDMY01TD3sNBY98E9q9l9J46jRKKpOUJPoylo4vkSuVO4xlSWJUEBHV2TCAY9cEVz33XJSKM8JInNPZ7MLQUcQliU0WUL4HrKkkgjIqLKC3Ze9aq6L8V9XXsMfz/s69//xYgKyz8jT5rO6I03I1LFdl6Ljky2WJ9qZUpaYqRE0Ve5Y3MY7mzJ4QKA9y+mzmgkaGq3JEr7wqQkbFG2PklMuQ61QYUf4viBvu/i+IGxqyBJ0Zi0MVcYXgp6cTVM8gC9LGIZCWFFIl5xyq5jvsborT0hXCOkyyaxLLKTQmS4SDmhoijwwt1UVmbxlEw0IOrIlQGCoClFTYUNPia6chYREKe+hKxI52yO54A6W3HsNzXsexBHf+TmWpOH6NvWxIBFdIVlwaH17MWtvugjJjFB/1s/Q4ttO1rlASAHPA0WFoAzt1uB5tUQ58Fal8rymuOltShvfourQr+D2CagOBN49Hax78X7mHHEKWnU9USETCxt0pIsYAY26oI7jQ0BT2Ht0FZqqoCsysizRUh1ivShQFwkwutqhO2MT0GQaooMX9j8tVALsj5ifHDOJnxwzaZvCSlsGkf09yl1Zi0eXQWSPo8i+/jCR2aegJcolqnuPTrC2O09m75NIPfMPrLaVGI0Tdngepda3sTtWocYbMEdNG5LJ2t5+RsN4JHXwepg5cir1Z15Ox22X0H7jj6k/62c8uqwsajYcbl/57c5SacGqsLMMF0wvWNMzRNn68En171mgzBPw+ob3Xk3h5nrpvOVC3HQ7dadcRGDsrCFjhO/R88gfyC95nOjsU4gf/MXtCowJ16G0cQlO7ya0mlGYo6bvUJDMSbWTf/tZonudiBKIDL62fJLsq/cQ3P0A9PqxO7ymCfURvnnIu7YgZ88ZxVPLO3lsi3t+UkOE3541vO1YhQoVykF2QFcI8G6waOoK1SEDVZbKfZ9bETY10kWHs879IU/8+0Hu+Nsf+fx3L6KQLdKVLTGiKsCsY8/m9Uduov2FuzCqv4OQfSKaguULLNshXwJXAAgc1yKXSpNZuQBFUVCbpyINM4m3AMsC23OIB2ViuorhW1idnYydOp3asE5vvsQ7m7OoMnzxkj9y41Xf59bfXYLs22SOm091WKc77zCzJU6maBPQVUqOz4aePJObEmQLDkFTR5UgVXR4uz3PWAcSIZVc0SNvezRXKQT0SnBdYefozVvlrKlUdvyoj5pIEviirBCuqeV7LGyq7D26iqff6SBgqMQDJpszFp7vky751IR0skWbTEmQtzxGxE16shbVYR3bK2e10wUH14fOVJH1ySKKLGG5LlUhlaJnkMy6OK6HKsO6x/7BhidvonnWYRz4lYuwPdiUyiPL0J4uMSJh0vb2q7x23QUooSrqzvoZanT7C9ZWTyu9ncvwfJnQ+LmosTAasGVTxUBWWwVFgraXbkMJRKnd4yigz2aUcjl557M3IUkS04/7AhnLJV9yyZVcFBlaqgIEVIU9RkRpjAWJBAb3VocMjd0bo0SDGhPrI6RKDk1xk+inVGepEmD/h+gPnrf8eWsWrU9y1+JWnlzeCUBsnzPJLXmM1HM3UHvCjwBY250nb3tE9jyGzILbSS+4jbqTL9jue2ffeJTeR34/8LNkhIjseQzROaeimOFt7udmuum45XzMkVOpPfkCZH1wUG6MmET9vCvpvO1iOm76MXVnXoFeM2rYY20vuG6IGnRkrIExiixxyszm7V5ThU83Wy5YDRdMzx1bja7KOK4/oGzdHxw+srR9kDjXjujKvbf+PzfdScetF+DlktSdfumwVSbCc+h+4DcU3nmO2H7ziO23be9p3y6SffU+Mq/ei198tyw7ftDnd1h5kllwO8gKkb1PGrIt/dJtCNfZqd5rUxteHfzcg8bx3Mqugd/zFSfvnEhahQoV3kWSJHR16P1ftD0s16Ngu4R0hYbR4znyxNN44t4b2fOoeYSr6inaXdSGVeLxBGP2O47Vz97LbkedQ1aLkfbKdoK2DQXx7vewlUrT+vfv4OXetdQzR+9JfP+zMUZMGnIeeRfcjM/UZo21d/yGd159nj1+fjXBmhl05CR838MSoBohTvz+r9D+cCE3/+FnpDI5TvrCN+nJWrR2F8jZPvGQIKApRAM66aJNbUinLZ2lK2sxsipIY0QlHNRoqQnjC59MyWH3xhi6WgmwK2ybou3hifJnq+j4BDQFWYJC33d9fdSkO1tCkSVqwsbAfvuOqyUe0EnnLZZszlAsuIyrC5EtlKtKokGdsKGxobdAyfHQVRVXlI9XcgqoSjkr3Ft0KdnOu1ouAvBBlj1KjmDTo3+l55V7GbPPMRz8pfNwvHJQnbdcgppMe7LEqsXPsPKWKzESTYw/5wpyamJIsDxwveteJ/3CTVitywZe06tGMPbLv0OSTTTKiSoXUDXAAc8Hr2sN+dWvUH/gfDTTRJHBUMDyIN++geSSp2je7yQIJNjUUyRr2UQMjVjIJGxqlFzoydns3jC8yLEsSzRETUK6SpMEUfODW4zuqlR6sD8m3LRwA2de/RI3LtyA21dGrYYTRPc6kcLbz2J3rAHKk/2C7SEbQSKzTqC44iXsrnXbPXb2lXvQmyYy4pv/ou6MywmM3pPMgjvYfM3XyL356DbF0tRoDdVHfIPS+tfpvP1ifCs/ZIzRMJ76eVeCEHTc9BOstpXbPZfhYohRVUEMTUaWQJUlLj9xaqU8vMI26c9Y//rR5cy/dgGJYFndU9nCJmpWS4KLj5vCvuNruPi4KQOfp9+etedAX/VHgdPTSvuNP8IvZKg/84phg2vfsei6+0oK7zxH/OAvEd9//jaDa6t9FZv/+nVSz12P0TSR2lMvZsQ3/0Vg7F5kXrl3u0KHbqaL3JInCE//DGpkcHDspjvIvvYw4emfQasascPrOmpKw7D35KyWBDd+ZS7fP2LigBVfhQoVPhhCCFp786xoT9OaLNKbt6mJmIyqCvGtH56H8Hyeue1qxtWFWdOV5d432ljXXaBu31NAknj73zdSLIFlQ65Unmhv+aTILn8BL9dD/VlX0vSVPxM/4HPYnWtpv+G/6b7//w0rYOoArT15Dvrsd4lX1/Hb//4yr7/0LAogKTK5osfbbVnCoQBf++nvOPCok3j4n7/juv+5kqLjsSGZx/E8ujPloCJiakjAoVMaOWvvkRw1pZ4DdqtG0XRGJQJUhXUiAZ3JTTEin+JJ+q6I7wssx8P330dv1fsgU3TYnCrQmS3Rni4SMRSKTrn6IRbQkCSpnMEGbNenYA3uCW6MmWxMl4iaCj15m0XrkkRMlZbaMKqi4HkCSfKJGCrj6yJETQ1FkRkVD9JcFWJJaw+vr+/mzU0ZVnfm2JzM0ZYs4nlQKHl0PPIHel65l+q9j6fl+P9ieUeaFR1pOlIFVvfaLOkoseGNZ1lx0xXoNS2MmPdzNDPR5509uKJTCEHX3VfSeeuFuOkuEod8iaavXUPdiT/B7t1EasWrZQcUyoG5D+SdcjWKrsDm525BNoLU7H08jlcOuoUEqgKdz12PrBk07X8GXXmbZL5EV84p3789ed7ZnEFXJYTHdqvnVEUmEdKJB/VhK3I+LVQy2B8DFq1PcvG9bw0S/5IAQ5P51WUXcu7xD5F85p/Un/HTQftF9jqBzKv3kH7xVmpP/PE2j+9mOgnPOBI1XIUariIwZk/sjjX0Pn41PQ//nvyyZ6g+5v9DjdYN2Tc8/QgkLUD3A/+PjpvPp+6My4bYc+m1o6mf/ws6br2IjlvOp+7UiwYFFoO+3AVETJXsFqIHE+oj/PjoScOWzVf4+HPVQ2/zyNJ2jprSMKA78FGyZcbadn0efquNi4+bQrJgD+rB7veyfmVdLxMbIgOfq22Jc31QrPZVdN52cVnx8+yfo9cNLbv2rQKdd12OteEtqo78JpE9jt7uMbVEI3rdWKInnofZ/O7v1mieTHHNqwjXRtKMYffNLLwDEMTmnDZkW+q5G5Bkmdi+83bq2l7fmNrmtq2rcypUqPDBSBYcWlMlJCEIICEJiUzRJR7SGb/3NL761a9x9TV/YcHxn6NdJDB0ia68gxSspn7Pw+l47VFis8/Aj9YgSyCpDEqDuekOUDSMUdOQJInYvmeW5xML7iD98p0U1yyi6jPnEpp80MA+PpAqQrtjcvZPr+XWK/6Lm6/4DoefexmR3ffFdn1sx6MrnachEeKMH1yJGY7w6B3/QHHzHPLFnzB1RIJoQMN2fSIBHcvxWdWRB0kwti5CTdjER1AVMjErJeG7JK7ns7QtTUe6REBTmDWqiqD50YYaBcfF1BQ0RSZVcAgbKnVRA1WWMfv69nvzNjIShi7TW7AJmxpKX/AXMTSaYiYrOzKYqkJVNEC66NKQ0BldE2RDT4FDJtYTCagosooE5GybjVmbsC7TlXHIWS4FyyVllwMrIyDhezZr7/x/ZN95gZp9zyK4/3zWpoeef/aNf9P7yB8xmifTeNoluEaQ3r4psmCrfmpJQqsdTbxpN6KzThxo31SCcQBK6U62nhH0e2Hb3etJv/MiLYfMw9HCWH1WfQgobX6HzPKXaDnsc8ihCJmSTURTMDWZoge1QZWwaZAIalieQK2kZ3dI5Vf0MWDBmp5BgmCKLDFvzihu/Mpcvnr4NPY47guU1i6itOHNQfspgQiRWcdTeOd57K712zy+pBkIpzToNb1+LPVn/5yqI7+J1baCzdd9m/yyZ4bdPzTpAOpOuRCnZyPtN/4YNzNUJVhLNNEw/xeokVo6bruEwsoFA9vqo+/e7gIGBdf9isSzWhJ885DxlYn6LsZVD73NX55dw7qeAn95ds1HatHUrwren7GW+wQ7nl/ZzWUPLB20ODNc2Xg/R09t3OF7tVQFmT9nFJ+ZXM/s0YkdagKU1r9Jx83nIWkmDfN/MWxw7RXSdNxyAVbrMmqO/wHh6Ufs8DxkI1QuM28evHDhOyWQZCRl+ImLm+km+8a/CU87HDU2eOHM7lpHfunTRGYehxqt2eE5wM45IlSoUOHDoWi7JAIakiyRKjqMqDJpqQkyIh5AVxVO+/K3UDWd+/72G0qex4auPJliCUWVqd3/9HKWa8HtCAGWANuFiFK28wkqoKoGeA74734Xy3qA+IGfo+kLf0CraqL7/l/Rc/+v8K13xSE94J22HOsLMsf99x9oGDeFx/58Ie2vPILre2zOWbSmbFZ1ZFi6Oc1Bn/8hh575NV586Hbu+s2PkXwHWZGJBNTyBF3ycYRPzCyLR5U8n5bqEKoi4fvbrs6p8PElVbRZ2Z4lk7d5fUOSV9b1Dqh2f5gIIUgVbNrSRRQkLNcnWbBJFS268xbdOXsggIZydaTr+7i+QEIa+E4Xoiw4PG1EHFPTiAQ06kImAgj3+bSPTASpj4UYXxcjqCsUbIdkwUEWkCy4lByHvOWB1BfIAqV8nnf+9VOy77xA4tCvEDrgs8NmfdML76L3kT9gjp1J7WkXI4zg4Osc5trj+59NbM5pg7SRfLsIgDbMgrtP+d5ve+4WFD1AYs7xQFkQTVdB0SU2Pf4v9HCcmjmnEDI0YgEDS0BTzKQpZjCtOUFDPICEzLTmGLHg8Av7Fd6lEmB/DOjvF92yRPrKk6cxqyXBVQ+9TdeoQ1AiNSSf/ueQktDoXici6SbpF2/Z5vHVeANO76Yhr0uSTGSPo2n84h/Qq0eWv1Af/n158r4VgXF7U3fGZXi5Xtpv/BFOT+vQ94nUUD//KvS6sXTdfSXZNx4duL6tHysScMCEGi49fgoL1vRUbLl2Ubb2er/n9c0f6v9lf1B908INA2Xhlz2wlC/sM5pRVeUvIsHQILr/ntqybLz/eMmCzbkHjmVEfNtCf22ZElOaYjy3souX1yW3qyGQX/4CHbdfjBqppeGzvxy25NrNdNFx009wutdTe/IFqPFGOu+8jPSLt1JYtfA9/17cnlbUeP027bwyC28HIYjtc8aQbamn/4FkBInuoH9blsr6COceOPY/UplQoUKFMtGAhqbKRE2NsbVhGqIBNKU8XXtjYy/Pb3aY/Jl5bFj8DNkNb1MfN4gGVAKKRLCqkaZZR5B5499oVjdRFQwNxlQb1IUgoEO4trzIaKQ2UW2U7Xs0ymJII0aOZLfP/5Kq/eeTe/s52v7xHaz2VUA5eEjasHRjjrVZwecu+TNjps3mhX9dxZJ/3wIeCA+KJY+lrSnWdefZ8+Svsd/8/4+3X3qc7315HnYuQ1DTaE0WWbopje14IEnURU1GVwXwBWxIFmhNFbDfj31Dhf9bhETJ8VjdnacjU2JNV45UfmsX5veGt8Vii+369OQsNvYWaM8UsR2fjOVgKhKaLBMP6ESMshWc4737+akKGQQNFUmSaIiZyLKE6/lsThVZ35vH8XyOndbEhIYwmZJNbcSkM1NibXeeoKkihEdPzu4L5F1SeYfunMVbm5Kkiy5CCNJ96t1uIc2q6y+guOFNao/5HtXDaKAIIUg+8w9ST19HcPcDiO1zBt33/fL9zwn65vhyYmgCQQbk1Hp6lj7P6ANOIhqrJmpCNCQjSeCte4PU2jeZeOTnGFkXozpkEjVUZjbHOWBiLVOb4sxoibP/hBoO3r2eyU3D919XGEylRPxjQH8f49Yl0ovWJ7n62TXImkF8/7Ppefj3FJa/MMhLVwnGiMw6nsxLt2Pve+awRvRGwwRySx5HeO6wGS8t3kD9/F+Qev5GMi/dhtW2gtqTzhsSKJgjp9Jw9s/puO1i2m/8EXWn/3SIgrkSiFJ/1hV03fNzeh/5PVIhyT3ijCHN14oiYWoKl96/FNd7V/25ksHetThqSgN/eXbNoNeufmY115yz1/s63tZq+/3q4JIkDXzJ2o7Ptc+vHaj6kBkcRMPw99RwauO/e3wFzw5jNeV65dJzawdiaNnXHqL30T+X+6NPu2SIUjeA07ORjlvLGgZ1Z1yGbxfpfuDXRGefjJftxdr0DsHxcxDC3ymLLiEE1ublGM2Th93uZrq2yF7XD9pW2rCE4ppXiR/0hWHPtR8JuOKkaQMWghUqVPjPEenzlBUCDFUeyHw5vk+q4BDSVPY98XMse+oult1/DfMvuZqRVWFyJYdI3iFx7DlsXvQoXc/fzrTTv0NYU6mOBhCqiiJJ+KN3YwMgelahNY/GUMD1oSig6IKsKtTvPw+zZQYd9/2K9ht+SNWhXyW85zFIkkRBwOpuG9vLMObMS8j5V9L66N8opNOED/o8ri8hKx5ru/KYmsXoA05DC8R55u9X8PnTj+eqv9yIGqkq225JAl2VydsusiRR8iyipobwRVkELfLptPjZVakK6TTFTJa3Z6iLGciKRNHe+YUSIQRClMWyPF/QkSliuT5BXaUmpNOeKSJ8wZquHCXboy4ewHJcqsMmEpAs2KiqjCxJA4tSUK4Mrdvqs5SzXBzPJ6SrZEouIxIaUxtjqJKErsl4nkbY81GQcHwIyeUA3/d9cpZLW7qI40HQkKmLmNhuiWx3O623XYyX7aH2lAsIjp8zyAIPyg4ivf/+E7k3HyW8x1EExu5Nz0O/e99zAgBr8zsAhOvHITNYGM0Hup++GVkzOfjkc/DNMI7r40kCx3V46u/XEasbwYEnnIWQFHRVoTlu0pwIM7k5jq6ArqqMqg4NqgqosH0qGeyPCcOVSC9Y0zOQOQtNPQytZhSpZ/+F8AYLNET3PrmcxX7+pmGPbY6ajnBKAzfgcEiyQuLAc6g7/ad4uV7a/vldCiteHDJOrx9Hw/xfImkmHbecT3Hd60PGyHqAulMvJjTlEHqevZ6eR/8X4b9bIlQb1pGBx5d1DOql3TIDWWHX4CfHTKI2PNiC4f9n77zjrCjP9v+ddnrZs33pvatIEVABFXsP9l5i1CTGJMbU12hieoxGEzW22LD3XhAFAekg0nvb3k5v039/zO5hl10Qje/7M8lenw8iM3OmnTnz3Ndz3/d17WjpKoa3P/zguU8Ze/tsfvDcp13Ey15ZVVN4PjrOYCOAYdlYtpNlPWqoI2S2byXEvr+pfcvGX1lVw6Lt3T9zoiBwypiq/Q4mtm0Tm/8k0dn34x08gfILf9stYVXrNtPw9E+xTZ3Ki/+Ap+8Y9KadhI+6iODYU3D3HY2ZasFItWDlUoV9HwhGtAYz3Yqnf1cBNYDE4ufBdlwIupzzvMeQgqUEx5/R5XPFPqVQaSIKTqDSgx704P8P3LKER5E6lZUqokhZwENp0IXi8zNl5jU0b1mNULeWo4aUccohvRnVK0TvfgMYdOSptH76Pj49Sb9yLz6PQsitMLF/iH4DhuINl2DUrCHkU0BsUz0GDB0Ey7H0UvqMouqqe/D2H0v0g3/Q+vZdhQo3HWiIGZiCReWZNxMeewrRpS+RfO8evC4TCREbiGdyJHMavlHTmfadP9JUs4tvX3gaHyxaQWMij27Y5HWDipAHG5tt9Sl2tmSojWWRe4L5fzuIosAhfYo5YlApA0r9iNhopolhfj7JzqkGS7e38srKapbuaCae1cgbDgHOqAY53cSwbFTdJqsZSLJAXnOsstqFy8JeF+UBN2UBNxnVIJ3fl952ONe2NjPDtEBw4tC0ZhLyuohnddJ5ndKAgoVN0CNTGvRQEnCR1kwM00AzLbBMFElCkgRy9dvY89TNWLkk5Rf8Ft+QSV2Oaekqza/9gfSa2YSmXEDxid9Fb9n9L8UEAPndq1HKBmD4wl1Ux7WmnaQ2L6R4wpnU5EXqYxlEUaAs6MXcvJDW6m1Mv/gGjhxcwcTBpRw7opypwysJ+10U+RQ8ioLXJfeQ6y+IHoL9NcbkQSUFIQFBlCiafiVGrA5py4edtpO8QUITzya7ZVGhlKsjPAPGgiiT27bsc4/pHTSeqivvRinpS/Orvyf28eOdyDGAUtybykvvQA5X0PTir7rt3RYkmZLTfkho0rmkV79L86u/w9Kcgbk5rRWsDNph2RD5L/XK+3fHD08Y3unfO5vTB1Um/oPnPuW11XXEszqvra7jikeXdiLANk6PfkcI0KkpSZZEThlTxe1vrS8Q847HXrk7xi9eXcv/vLq2i9q4zT7EvX2fosDtZ40BOusHtMM2dVrfvovk4hcIHHYSZTNv6dZXPrdjJY3P/QLR5aXy0j/jqhjsXIPiITr7H2Q2LqD1nbux8mniHz9BdM5DWGr2c72ts1ud8rHuvLX1eAPpNR84oob79F5nNy1Eq99C0dRLEbvp05o2rAy30rWsvgc96MHXA6IoMLwywOF9I1x4RH/uuvVmBg4azHuP/ZVeYQ/1iRyKLNKnyMPok68CQWD7B7PwKi68iggCNGVMKsJuBh1+JLVrl5LNqegGhSybAcTb5vAVwOcNUXburYSPvoTM+nk0zLoZPVaPgPM+Nm0RRZHofcp3KD36IuJr5rD7hd+hCBphn4LfLWIigACVo6dw8o/vx9DyPP3Lb7JlzXKiGR3dcDzE0qqGKEFeM4hndT5X/KIHX0uUBBRGVARJZBy1+Ixq0noQZeI7WtJ8VhOjLpFl9roGlm5vIp3T0XSDVE5D1S1s22bVnih7ojnSeQO/W6HI50I1TOoSeYIeGa9LpjWjksjqNCbzpNpIttO3rbKpPs762jiGbhP2Kk6LQsCNIjkl05VhD1UhD31LfJQHvYzrX0xlyENTMkdzKkdV2EPA7SLocoi1YIskt61ix1M/Q5RcVF5yRxfdFAArn6bphVvJbV1KZMa1RKZd5iic/4sxgZlPk69ej3dQ95WD8YVPI7r9eCZ+gy3NOrtaNDRN49BKPx89+w8GjzqMiy66kEGVISb2KyaZN6mJ5xhUHkASBLxuidJAT3z+RdFDsL/GGN8/wgUT95Zo+odMZMhhR9D08dOdhEcAQhPPRvQESCx4qst+RLcPz4DDyGxaeFAzYXKonMqL/0TgsJNJLnmJphd/hdk2k1bYJlhC5cV/xN17BC1v3kFy+etd9iMIIpFjrqT4hG+T276Cxmd/jpl2yE93r4t1dd3IK/bga4+LJ/VjZOXe7K1pw8uruvbod8TK3bEu/dtp1fFsbVedD7nlLpZaQuE/zl/nju9DLKt1K2j2zNI9nP/AIp5Zuoenl+7hV2+u59bTR3PTicO59fTRtKRU9p2QbSfXe1oz/OLVtdTGO+sRWPk0TS/eRmb9XMJTL6X4pBu67YNOr/2QppdvR470pvikGzoJA4YmnEnk2KtQazcSHH865efe5pBel/egeq+ymxfiqhzarep/4pNnEUSJ8JTO/dW2qROf/wRK2QD8o4/tdr/vrK0v3J+edo0e9ODrCUmS8HsU+pcEqSwO8tNf/pqNG9Yza9aThH1uLNumIaHSv19vhk47i00L3mbb1m0ookCJR8EwTDyyQsmoI9HzGeo2rMC2nTLSdnsfcEpM9bY/iiDSZ9pF9DnvV5ipFhqe+AH5HSuRJBBtk5BbpjKsMPCES+h36neIblnOZ4/8nJaWFkzLxjRMSgIKIbdI1eARXHz747j9IZ667Tq2Lf0Av0uiOpalOppD1S1K/G6KA+7PbdHpwdcToigS9iqEvArFAXebwvb+hc5006Ihkac+nqMmnqM5lacuodGUyGPZsLU5TV0iz4b6JM3JHL0jHiYOKkYRRYJumfKQh77FfqrCHkJeBdOyMUwbr8tRFlcNh5jXx3Os2BVjc0Oa1rTK7lgWr0uiKuzBsGzq4lkyeae32rTA51LQLRtJFNjVkubjLU1sa8ywoylNOq/TnNZpSOvsWvw6C/7xM7wlvRlw2V9QSvt2uUYj2ULD0z9FrdtM0dTLUMr6F9b9qzFBbstisEx8w6Z0WafVbSa3dQmRI76B7Akg4/zWd8c0Hrz/XhLNDfz6t7/niAGlWLaAJEkU+VyUB1zEshpuRaI86EGWeujiF0XPHfuaY+a4PrhkEQFQZIlf3v57MokoiaUvddpOdPsJTTqH3I4V5DsYz7fDP3I6ZrIJtbbruu4gyAolJ99A8cnfI1+9loYnftDFb1v0BKg4/3Z8w44k9tHDRD96BNvuOiAGx51G2cz/QW/dQ/2sm9Cad9HdsNkzWf3viZW7Y2xp7DwB89HGRp5Zume/21/yyBLy+wmebJyKhgfm72DOpqZO6waVBwpZaLcics64Pt0Kmq3cHeOXr68rlD3CXkuviM/Fr95cz+wNjZh2V4/JdXUJHlrQua8cnOxww1M/Jl+9gZLTbqLoyAu7zCzbtk1i8Qu0vvNXPH3H4B91DM2v/o70p++Sr9lY2C449hQElwdBclRA23ulRZe323tSOIfWarSGbfhHTuu6rqWazPq5BA8/DTnYWR08teodjHgDkWOuQpZlJMGpDhhS5t97f0zn2nvU/HvQg68vdNMintPZ2pSkJa1y/rnnMGHiETz3wJ1ouQzJvIlp22QNi9GnXIYoK6x49RGSeY2sbqEbFhvqYjT4RiB6gjSumUfHacSObzQRqAwJ9ApJ9ClxUTF6PAOuvBs5VEbDi7+iafGL5A0bj0vCsCwEAYrHn0rfb/ycZO02Fv/t+2zfVUtTRkXXbHoV+RhYHqKsdz+u+tPj9Bk6mnv+5wYevu+vRDMqAY9CQyJHYzyLzyXjVXqsuv4dYds2LRkNv0simtaoi+epjqZZUxMnrxldtm9Jq2iGSa+wF9G2wRaI+GSyhkUyq6G3qcv7XSIZ1cC0BPKaSVnIQ59iHwGPgmpaBNwyPpeMIgn4XBIZzfktBNwyqmGRzOlopkVGMzAsi3hGpT6eI5bVaErmaU6qJHIGiayKZlqIAuRVk1hWZWc0g98t43crpFUN0waPZLP0ub+xeNYdVIwYz9jr7yRcVNzl+rTmXTQ8dbNjmTv2ZBJLXvjKYgKAzIa5yEVVuKqGdVmXXDALyReibMKZ2DgTZyqQTbay+q0nGDP5GMZPOgpRFDi0d5iQR6bIpxDxuzHb7lUPvhy+MoItCIIkCMKngiC89VXt878R7arJnUps27POto1cMZTBk08itfy1LnZZwXFnIPkjxD/uqjbuGzYFQfGQXjPnC51P8LCTqLzoj9imTsOsH5HZtLDTekF2UXrWTwmOO53U8tdoeeMObKNr76ZvyCQqLv4TWCYNT/2Y3I6VXY/l7tHc+3eEYzPXeVlDUuUXr67tlmS390IfFPZ5jq8+aiBPXzO5U5a1XdCs47JXVtV0W/79ybYWbn19Xafjd9zKBpbtjO57WPI1G2mY9SPMTIyKC24nMOa4rqdqmURn30d8/pP4Rx1D2bm3gmUQOfZqPAPGkt04DyO5d8LAN2QS2Y0LyG5eRHrNbLSGrUjBA9tmpdd+CIKIf9QxXdbFF8xCUNyEJnf2vTbzaRKLnsMz4HC8g8ZjWjbj+0cYWRmk2N+57KtnkqsHXyV64oJ/DbZtE81o7GpTZLYs598VQTf9i/3IkkhSNfnxrb8l1tzIow/cS1o1GFoRwLJMDFeAAdPOoXrlHPZs2Ug8nWd9TYx4VkeVFPyjppHdshgzn6adyto4gWFAAr+C0yOtmeRyBnkVCFdSeelf8I04msa5T7D1xT/jFRzCobapKHuHHkn/C3+LmU2w58mbSe7eCgJIEnhliUjAxdgh/bjprieYeNxp/PPuP/CXW25iT1Oc5pRzvfXxDImsRkbVD6p/twdfH9g2eGSR3sU+qkIuNMNAliRSOZ1tzWksy0ZryyrrpuOfLosCkYCbqcPLGdu3iLBXQpJE0nmdnGrRnMyTUk0GlYU4tE8RA0v9HNq3CLciURHy0L/YT68iL5IoIAgCFSEPVWEPfSI+PIqEKAjEcxoSAvGMTk00iyKLyKJIQyLviKmKAqIgkFZN0prButo4G+ri7G7JItg2LSmVDzfXs7YmwcptdTz9xx+x7r1n6D3lDIZf8mtkxYXbK+HrwKzyu9fQ8PRPwbaouPgPyIHirzQmMBKN5HevxT/62C4T/rndn5HZtZo+U8/HH/LRPtq7gaYFz2FoeU658ofEchqxbJ4tTSkEwcYry8SyGjYCJf6e0vAvi6+S0Xwf2AiEvsJ9/lehO5XjJTtaMSynX9mwbG59fR3q2POxl39EfMEsSk+7qfB50eUhfOSFRD/4B/kdK/AOnthhnRf/yGlkNn6MNeNbiPt47R0I7t4jqLziblpe/T0tr/8Rrel8io6+pFAaK4gSkeOvQwqVE5/3KI3pVspm3oLk7fwouCuHUHnZnTS9fDtNL/2ayIxvERx3euGl8OD8HWxvyXD99ME9GbR/I0weVIJbEdF0q0tlwrvr6hleGSyoeX+wvoHXVtciAJIAkiTiU0Tiue5nSa8+elBhP6eMqSqoWu/7fLQT7XbsS61l0VHJtWzAdkq+uiPglg3bmtKdlqXXz6X13XuQQ2WUn3MbSkmfrp/TcrS8/idyO1YQPGImkWOuQhAEghPORJBd6M27MBKNpFa/T9GRFyDILlzlgwiOP53Mxo+xTZ3Ss3+OUrR/v2nb1EmvnYN3yBFIgc7Xr9ZtJrtlEeGjLkbydbbQSCx6DiufJnLsVYVly3Z17ZFv96TvQQ++QvTEBf8CVMMinnUygWlVx+eSkETIqhYWNqZlk1V1RoydyJEzTmXZG09w2IyZVEdtVN3AJbsYOP1cdi18jfnP3ceoS28npdq0F5oFDj2R1Kq3yaz7CGnCmQg4QaEAiG0euW4JDAEyhoVHAVMF0+Wh9MyfkK4cTPTjJ3nnD9+lauYtyEWVewWW+oym8tI7aHzhNnY89TP6+X7NiWPOAQR8loRh2iR0uPznd1LWuz/vzLqfHdu3c/HP/0qjHSKp6TQkVIZWBOlbspc89eDrD1EUKAt5aEmp5GQJQRCd51iR0Q2L+kQOzbRI5zRq4jmiGY2ygJtBZUFGV4VYvD0KCIQ9LooCbmRBIORTCHsVKkIeRFEEOhO/fZ8NQXCcatqhSG3kWdNxyc7/F/kUvC4J3bKQBZF4Xielati2zYASPwu2NeGRJRqSeQRsDMNG1Uxqavbw0b0/J1G7ndHf+C4Dp810VLltAcky8bhNUKFpzUe0vvs35KIqKi74Ne5QOUpxn68sJgBIt9nhBg49vtNy27aJf/wEUrAM9yGnEstDQIQSL+ixerYue4ehU89k5KjRrNgeRZIEZAGylk2/iKMgPrpXCL9b6e6wPTgIfCUZbEEQ+gCnAY98Ffv7b0VHlWPNsLj9zfWsro47M3Jt2xiWjRSuIDThLDLrPuoiahY47ETkokpiHz/RRZwsMPYUbF0lva6zSNrBQA4UU3HRHwgceiLJxS/Q/MpvsdS9atGCIBCeNJPSM3+KWr+1TQilrrBekZwrkENlVF7yZ0LDJhGb8yDR9+/DNtsEKIAPNjRywUOLe3yx/43QnkG+cFK/LoPc6KpQQRn8/AcX88D8HTQkVUwb+kR8mJbVhVwLAgwp8/P7bzg2URdP6sesb076QpZRY/bxaWxPWAuASxH5zVljOHFUxQFfgLZlEpv3OK1v3Ym790gqL7uzW3JtpJzeqtzOVQQOPxW9cRvJxS+Q3boEUfEgCCKu8kF4BxyOrWZIrX63cKH+UcdQctpNlJ9z6+cOpNnNi7CycYKHndz5PG2b2MdPIPrChPbx29Rj9aRWvoX/kONxlQ864P6vPnJAz8RWD74y9MQFXx0Mq71/NEpNa47mdJ5YWqc5mWdDXYJFm5s47eqbMA2duc/cS0tKR5EUcnmVhKkQnnweia0r2LXpM3KmY8flBfwVg3FXDSP96TvYto0bCHnA74Fin4RtQV6DdM75kzcdETQJ8IgCxZPOpfzc29CSzex+8ibSuz8rnLMCKCV9qbrsTtwl/Zh3/y94+alHkUSR4qCLoWUhhleEGN2niMu/+2P+9LeHqNmylvtvuoh169fTnFSpT2bY1pwildc7eRr34OuPoEehb8RHkd/F2L5hNNMmmTepDHtJawY+RWJtbZKtDWkakyqr9kTJGzq7W7M0pXIIAqypjbO9KY0lgCKJFPncbeT6i0M3bURBoCGeoyWVJ6tb7GzOkFYNvLJE72IvA0p89Cv24/PI7GpOY+gmbklka2OShVuaSasq9ds+453fXUumuZbx3/wt/Y6eSVYD1bQo9smEfS4QbBoXPEvr23chF/dG8oXIrJtL/iuOCWxTJ71mNt7BE7posmS3LHJETY++GE1yOdo2IrgVhfoPH0dxe/jGN28kr5u4ZQj5ZJpSKk2JHCIC0axGzth/33wPPh9fVYn43cBPoNvWWgAEQbhWEIQVgiCsaG5u3t9m/9Vo7yVttw74rCbBBxsaC4rbHRNu4SnnIfrCxD56pFM5uCApFE29DL15F5kN8zrt3101FFfVMFKr3u62V/rzIMgKxSd/zxEt27mK+idvQm+t7rSNf+RUKi78HVY+TcOsm8lXrwOgV3hvH4no8nLmj+6kctoFpD97j8bnbsHMxAvrDdPmlc8RyerB1wvj+0foXeTt9CyeOKqCpGqg6t1YbQG7o1m6i5kEoCaeY3jl/n2aPw/7s5iqCLmZOrSMeZubnJLL/XzeUjM0v/JbkktfIjD2ZCrO/02XigwArXE7DU/+CCNeT9HUy1Cr1xEcdwZyUSWt7/29k42dq9dwPP0PxUy2EJ3zIE0v346ZS3Wr6N0dUqveQi6qwjNoXKfl+R0rUfesIXzkhV0qU+LzHkOQZIqmXfa5+19fnzyo8+hBDw4Sd3OAuKAnJvh8eBSJiM9FLKOh6zZBt4RmmuQ1RzQskdXY0Zwhlteo6N+fIdNnsmH+m+zatJaMqpHI5kmrEBp3BlKghNZ5j6PbNiZg4kxqB8edhh6tQdi9Gh1I5MHU9opO5k2HULtdYFnO5xz1cEcALTRoPH0uvwvJV0T9878kueJ1ZNsm2PZa8wQi9L7oD4SGTeTth/7A6w/+Do8sUlXsRRYlNtYlsbG58vJL+f0jL2KZGm/9/lusX/wROd1CNUziGa2Tp3EP/j0gyyIeRcLvVhjTK8SEAWGymkFdPM+ne+Jsqo+zqSHB5roY25syzN/cwobaOE3JPF5ZptgjU1XkYUCJH0UWSav7t9z6POimRV63SKsmPrdM2CNhSzaS5MQLH25sYNXOGIJt4xZE0ppOIq+zfGczG2rjINq8+9oLvHT7txEVN9N+eC+DDjuK4oALvxsswcISwdBUtr9wF9GFT+MZcDiCbRGZeDbeokoavuKYILv5E8xMjODhp3VabpsG8flPopT0wz/mOGxAESHgESlJb6VuzSccd8E1zJgwlMP7RxhQHiSe1cmaFuVBNy5Fcgj552si9+AA+JffWIIgnA402bbdtam2A2zbfsi27Qm2bU8oKyv7Vw/7H4n2TOBRQw7ccwGOqFnR0ZegVq9zFAQ7wDdyKq7KIcQXPNWlHzo04UyMaE23PdAHA0EQCI47rUCi65/8Edl97L88fUZRedmdiN4Qjc/dQnrdhwyvDHZSml6wPYp7ymWUnvFjtIat1D/xQ9T6rYX1Pb/rfz90FBvzKCLHDC/npZU1X/i73FcN/Mtgf5ZvDUmVDzY0MntDI5/VdK9ar7dUU//kj8jtXEXxid+h5KQbEKSu3TTZrUuc3ipBpPKSP+PpNwZX5RB8w6bgHzWd4hnXEpv7T4xEIwCi4sY37Ei05l1kNy8iOO70br2zu4Nat7mgMCoIe1/btmUS+/hx5KJKgmM7Z7bz1evIbllEaNI5yAFHeEUS9laT7ItTxlQd1Ln0oAefh4OJC3pigoNDxO+iT7GPooDLEW0UbBRJJJrTWL67lfpEju3NGbbWxek77WIUb5C1bzzAjuY0aRVUC2zFTdHUS9DqN5Pd/AngiB0ZgG/ENER/ES3LX0PBKRHXAcMAjyLglkBUnCogWQS30Ea4RfDg/L+vTT05NPQIYh8+TPS9u5FsDVfbesHlYeA5v2DwMefxyqx/8tvvX4WayxDxywytCHBIn2I2N6QYNPIwbr7vJcr7DmL233/KnKfuI5HWcCtCF8eHHvx7oCLkJeJ3URH2EvC6nP5ov8Lm+gQuUSCjGcRzBoIN1a1ZUppGIquysyWJKNmUB9xkNdNpnZRE8rp58BouHaCbFiGPzIAyH7oFWxuSLNzcwsPztrFkRwuqZpLIG7z+WS0LtjXxydYWttQm2NaYpjGWYe5jf+GDB24nMvAQjrn5AYJVgwl5RRIZlcZkDkOHloYG5v7tJpJtLiPhoy7GVTWUkmFTKPqKYwLbtkkufx25uHdh0r29mDu9ZjZGtJaiY64otHK6JIj4JOY/dQ/FFb2Y/o0r2NmYY3hFmPH9ijn5kF5cOLE/h/SNIAADSvwU9Vjn/kv4KqYEjwLOFARhF/AccJwgCF29onpwUBjfP8IPjh/Wxf+3OwQOOwmltB+xeY9iG3tn9gRBpOiYqzCTzSRXvtnpM77hRyEFS0kue/VfOk9P3zFUXfFXlEgVzS//hvgnz3bKiiuRKiov+wuevqNoffuvvPTAn9lY35XQ+EdNp+KSP4MADU//hPRaR4Rt3xLfryO6FaT7L8a+YmPr6hLoX2IgbFcDj/hcX/j+tn8nD368/QsfF5yyqvpZN2HlU1Rc+DuCh5/aZRvbtkksfYnmV36HUtqXysvvxFU+sPD8m1nnOfePmo534Dha3ryz8Nnk8tcxkk1UXvFXfEOOOOjzSi57FcHtJ3DICZ2WZ9bPRW/eRdG0ywvqo9BGvD98GClYSuiIbxSWh30ujh1ezomjKjhhVAVHDIhwWJ9woRy/Bz34itATF3yFMcVg9AABAABJREFUCHkUykNuLBt8sswhfcJUBj2UBDz0L/ETdCtYlkCouJjBJ15OZtcacluXktOdIE8C/GNmoJT2Iz7/iUJbFoAsKxSPO43MjpWkm3eTpy1TLYBl2YQDChGfglcB0wBsZ3+iCLIEouSQ8pKIj8Ov+CX9ZlxGbM2HbH385wiZZtpDGVOQGHbGtUy/8qcsWzCXmafMINvaQNAj05LOkcpr2ALY3mJO/8n9HHLMGSx66WEe/833iMcSGN1oZnxdEE1rNCfzxNJ5R6AtkesRZmuDJAoU+Vy4ZJHmpEpjIk8yb6FbFhURH8VeF/1KfQwo8RPxu0jmTeqTGg1JjT0xlaxu4nOJVIU9ZDSD2niOmliWvP755ctmm4BaKq+zqyXFom1N5FWdCr+LhGqQ1zRqWnMs3NJEQzJHPJMjntWojuZYsSNKfTJPc2MzC++7iWVvP0u/o2cy6orfEjNcNCYz1CbyxPIa2BCr3sGiv32fTN02qs78CUVHXkh78Y6aT5Djq40J1Op1aA1bCU04q9Oku0vPklj4DO4+o/EO3ru/gEegdfVHtOzeymlX34wlybSm8+yKpkiqBlVhD4f2LeK4ERWcOKYXwyqDPVUj/yKEg/FFPuidCcIxwM22bZ9+oO0mTJhgr1ix4is77n8inlm6x7EZsmwUSaBX2MvuaLbLdrmdn9L0wi8pOuZKwpM6qwc3vfgr8rUb6X3dw53KWxNLXyE+71EqL7sTd6/h/9J5WrpK9P17yayfi3fYFEpP/WGnMlXbNIjOeYD06vfwDplE6ek/6lZgzcwmaHnjz+R3f0Zg7Ml8++e/5c/nj/+Xzu1/A88s3cP9c7eSyBtkNQPbpiBI17F/deXuWEHY67+xr3Xl7hgXPbzkS800H9YnzJRBJTy+eFcnwb8D3ceVu2M8+PF2PtzY2Mma62BhWybx+bNILn0JV9Uwys7+BXKoayWJbWi0vn8vmXUf4RsxlZJTf4AguwpCfU2v/BZX2QCKpl5a+EzDs78gNOFMfEMnY9tWp8HwYKDH6qh7+HpCR8wkcsyVheWWnqfuoeuQgiVUXnZnJwXR9JrZtL77N0rPuLlbxXGXLPLst3q8rr8oBEFYadv2hP/f5/HvhIOJC3pigs9HQyJfGHOKfApFXpk3P6tlY12KVN5gfL8gC7dH2VwXZf4d12IZOgOvuQ9RUdB1p6w7uX05dS/9msiMawlNOBMJp+TblU+y7f6r8A0/morTfogiOBOdIR/4PB5k0SKf1WjKOKXhMntLzGn72wbCHigNeqj+bCFbXvgTSG6qzv4Zof5jsCzweqDU7yWc2Mh7f/8fJEniW7/+OyPHHkFet8ipJo0ph0AJgkDtotdZ+NRfqezdh9deeYUJ4w//wvfNsmxM20ZuU5f+qmDbNrZts2JXC59si+KTRcrCbo4cUk5eNwl7XUTaFJht2yae1ckbJkVeF17Xf5/9WE0sCzZkVIPmpMquaArVMGnNqFQFvWi2Q4ibkjk+251Ax0IzbI4eXMzph/el1O8mp1sEPDKqYeJVJMqCnm6PZZoWLRmVbY1pmpI5ZEFg2c4oCVXHsCwqg27qkyrNKRVdNykOyFSGfFRHs2294jqt6RypnRvY/tIfsdQMEy66mb6Tjqc1bdCaVvEqIrrltL8lNi1j96t3ICoees28BVfVMIx9YoKyqZcWPOa/aEygsNefvh1NL/4KtWErva9/tFBSLuPYcjUvep5el92J0mu44wggQ4nbZMlfrqS0qh/f/P0TJDWTMr+Lw/pF6F3iI+xx0b/E95X+Rv5bsL+4oGd64muEjhnRWFYr9LNals35E/ty9theuPYp7/QOPBzv4IkkFj2Pme6c6Ss65ipsLUfik2c7LQ+OPRnR7e/ipf1lICpuSk67ichx15DbupSGWT9Cb93bPy1IMsUnfpfI8deR276chqduRo/Vd9mP5AtTfv7thCadS3r1e8z6xeW8sXD1l85gftVZ5ZW7Y5xy93x+8epaauJ5UnkD0+q+nLldDf7O2Zu56OEl/OLVtf91We4lO1oPOIN/oFf4ZzUJHpy/g3xb7/aBysVX7o7xrSdXcN4Diwq+1l8UZjpG4/O3FPqtKy/+U7fk2ki10vDMz8is+wjvkMm4+x2CrWWhQ+VG8fHXkd22lMyGeVh5R43cVdoX0etUZHxRcg2QXPoyiBLBCWd2Xr78Ncx0K5Fjr+40KFpqltj8J3H3GoFv5PRu9/mvluD3oAc9+N9HVjXY3ZqhOpohmdfbfH4lspqJJEmcMqY3ZxzeixPGVDC8VzGnHFLFyYf249Rv3kyutQ5t3VuEPDIhn0hJkUjx8An4+x9KYtFzSPk0HsAnwpA+JZSPO9F5byWaEGxQBAjILkxdpyWh0ZJ1ejLbS8vbCbbV9kfEEUSzTJ3I8CPofeldiB4/Nc/9D8lP30QUbXQDork8Tf4hXPX7J3H5w9z1wyt498VZbKxPsqk2xsaGJFnNoCzo4phvXMZfn3gZXc0z9egjeeChR9p8kA/uRa+blpPxjOZoSqld7Eu/LHKaya6WDIu3t/LOZw2YloVu2WxvznY7qZzK68SyGoZhsaUhSXU0Q0777xKQsiyn6iHokelb4qVvxI9XkRlcFmJkryIuGN+X40dVMrQ8gEsSneSSCK1pjR1NabY3p8hqOvGshmZYnRTC22HbNg2JHCt3R3l/XT1NqRwb6hN8tKWJXbEU2Ca6aVEfz5HXNAQLBBHymhNrxPJ5GuJZWhJZWpe8zuYnf44guxl97V2UHDYDQ7PQDRPTdEh8Lm9T//EL7Hr+N4ieIL2mnIkvUo7Ltmin/h1jAimfdoRWv2BMsC+51pp2OI4l48/oRK6tZAuty1+j9+HHMHni4ZR6ocgNPo9I3fxnySdjTLn4h+RMi+FlfirDHprTebY3pEhkNdQvkRDpwf7xlRJs27bnfV72ugfd44/vbOS8BxZxx/ubueSRJUTaSmray2UnDyrh7gsP51dnjunSixQ57hpsQyc2/8lOy11l/R0bjk/fQY/WFpaLbh/B8WeQ27IYrXnXv3zugiAQmng2FRc4vpf1T95EduuSzuvHn0H5+bdjpqM0PPlDcjtXdd2PKBE55koqz72FaP0evnHC0dx+7xNc8siSgyKoHYntwX7mYLByd4wLHlzExoZUt+vbv5927KsG/+zSPV/p+fw7oGM/tksWKQ107uU5flQFlxygJLljGCSJApMHlXSZPFm5O8ZFDy3mgw2NXXy4Dxa53Z9R9/j30Oq2UHLaD51+a7mrLYVau5GGJ3+I1rwbV68R2HoOvWEb0Y8eIfXpOwXFfjlURvGMb5HdvIjYx0/Q+v595HauOui+qn1hJFtIr/2QwKEnFPqowZkUSC55Cd+wI/H0HdPpM4nFz2Nl4kSOv3a/s9E2kMp9ecGYHvTgYNETF3w5xDIqH25oZGNtgmRWQ9cNohmNeE7H73bIhdsl0S/iJ+iWiWd1KoI+jhpWxgVnn8GQ8VPZ/eGzBOwUg8r89C3yE/QI9Dv5GqxcisSSF5BkCPsh4HUxYPp5CAjElr6MR4HSsExM02nOmKia04Odbzs3DYdU2+wl2xqQt6A5YZLNWgilfam8/C68g8ZT896D1L31V/I5lazqZH53m34m33gvpcPG8dLfb+fFe25j2Y4mmuMZ4hmNpniWUr/COafO4MMFizlk3ES+fd23uPTyK9hR13JQZDmdN7AsG79bIqMaaP9C2bbZwbu5Ja2SyGlsb0oTzao0JPNEsyplATemZeFRZEJepcNnQRIE8oZFPKdjmDa1sQyqZpBWddJ5DetrXAL/VaA85EY3bXTLxi1LuBUJtyxRFfI4wns2DCwNMH1EJeMGFtE77GVwWQBZlmhJ51hTHSetmeR1k4hXIehxRP7qEznSeWcsy+sWLSmVZE4nmzfZXJ+gNa0S8boxDVhXm2BXc5aIz4VHUehf7uOQXiFCPg+JrEEuL2BrGba/+Ht2vvswxSOOYNL376f3wOEEXQJ52ybokYj4JUQzT8Mbf6Zp3pNIwRJcRRWkm+upn/MI+Q3vINudY4LE5kU0ffwEsS8RE+w7lZBY9AKCy0tonPNaVXB+g7GFs8C2OOKc6zAAjyJTFfbQW0ywbe5LjJx2OkNGH4Zh2mR0C8O2UQ2Lbc0pVuxqZdmOZupiXStle/Dl0JPB/hrgmaV7eGD+jgJJyOsWczc3MW1oGWN6h7n19NGM7x9h5e4Yt76+trCdKIAkglLcm9CEM8ms/QC1fkunfRdNvQRBdhGb+2in5cEJZyK4vCQWPf+VXYen/6FUXXk3SnEvml/5LfH5szpZhXkHjKXyiruRgqU0vfgrEktf6jJIFnllfEMnU3rZ3chFVTS98lvq33uQ++ds/NzMdEdiq+rWv6RE3pHMvbyqhv1N7B0xINKlfLmdXLZTGxvnO/3Tuxu59skVnHXvQp5ZuudLn9u/AwrWXUf0A9umJd1ZbK886OZ33ziE66cd2DpKAM6b0Begy+TJkh2taF8mZU1bSfjCp2l67hZEd4DKy+8kMGZG1+1sm9Tq92h45ucIsouKC36DHCyl4sLfUXLKjfhHTMOI1ZFut9kAPP0OpWj6FfiGTELyhqi87K5u7b0OBsmlLwJ2l/aP2PwnsU2Dog4l4wB6tJbk8tfxjzked9WwA+77gfk7OP+BRf9VEz896MG/AyzLYn1tkpxuUBPPsmp3jKakSl0sS0tKJaM6xDGdN1hTE6cxqaKZJmlVRxYEKos8zPz2z7AMjdo5T1IS9FAR8tKnOEjFwOGUjTue6Io3kLKNuGSJ2niejFJMyeEzSK55H6/ZimU51XN5AzI2qPueIw657ggXTkYw1jbsi24/ZTNvoeSoi4mvm0vNUz8m0dpAa8pkV1OepAG9zvkfSqacS+Pyd9g66+dkE60IokjI7ybkdZFUDaKWh+/84VFOv+K7vP7is5x47FTeW7CM3S0pGhI5cprZLUGVJYGcYbCr1emLzmn7nvHBQTVMqmNZamJZGpN54lmNjbVJLCyGVQaJ+BT6lfg547AqBpYFqQx7OllWBjwyoiiQzuuEvDJZ1fne3lhdyxura3h/XSOfVkf/o/u2PYpE/xI/VWEvWd2kMuzB45KQJJEyvxsbJxlTGfbxrWlDOHt8XwaXB4n43VSGfRT53PgUR5VcNUyi6Tx7ohmyqkljUiWvmwgCJNq8rNOaQWNSxbYEXAr0iviZNKCMAaVe8oaJW5FRRAnbFinyirRk8yT2bGDV379LYvMyDjn724y65DZEj5doNk80pyLaAi5ZhkQ96x++mfjGTyg96iL8vYcz7NI/UH7yjQRHTCPRUEf9p93HBHyJmKBjrYPWvJvs5k+c7LUnADgZbr1+K/G1H9L36LMxA5WE3C4O6RVgeFWIda/+A8Xl4qTLbyTgluld7GPCwGL6FfsIeVzIgkRNLMv8LS18tLGR2m7aUXvwxdFDsL8GeHdd15LpjkrHv3pjXbdEzxlPBEqDLseix19EdM6DncTGJH+E8JTzyW1b2skeQPKGCI4/g+ymhV9JFrsdcqicykv+jP+QE0gsfp6mF3+Fmdtr/6MUVVJ56R34hh9FfN7jtLz2Byx17485nnNKr5WiSiovuYPg+DNIrnidp/7ncn7/7Eec849FnHrP/G5JweRBJYhtg5oNvLii+kuRh04l3g8t5vnl1V228bkkrp82iBeuP7JLH2s7ubx4Uj86akQs2xUrfKe/eHXtfwXJ7lXkRd+HBLtkkZnjnMHlZ6eO3C/JFnBKyt5dV8/1s1Z0KRmfPKiE/QhiHxBGspnG5/6HxCfP4h9zLFVX/BVX2YAu29mGRvS9vxN9/148/Q+l8vK/4ukzCj1aQ3r9XAA8A8bi7jsGrWUP+ZoNgFOiLYfL8Q6eQNG0y/6l7HXqs/cJHHI8cnivx6XWuJ3M2jkEx5+OEunV6TOxjx5BkBUi06/otFwSHH/xfbFsl9Mr30Oye9CDrxkE8LgkXLKEKAokVJMin4JHEUnkdHTLojWjktIMDNtmd2uGnGEiiQKVIS+HjBrJ5NMvZf38t2jZsZGQ10Vl0IPfLTN+5rUIksz2tx/B1AWSWROPAqEjzse2bWo+fom8ZaHpe225DgYmkNqnMEYQRIqPvpghF92Knmyk8YkfkNq+grQBTSkb3Zboc8KV9Jv5U9SmXWx66Adkd6+jLOSlOZVnQ02CZdtbSKg6R8y8lit/8yDxWCtnnzidX9/xN55bupOXVu5mR7PT09sRPkUikTVI5XRK/W7iWeNLlYknc7rTy+qW2RPNEs2oaJZFLK1RGvBy7Mgqzhvfj5Kgt9vPK5JIn4iXEVUhirwuVuyKsq0xzYLtTSzZ2szamigrd8VoSua7/fx/EgQBMnmdaEbF75JxSSI+t0TQszfj73UrDC0PceKYSoaWBzFMG59bJqcbrNod5Z11Dbz+aQ0bahM0JnPkdB3LtvEoEkU+F15ZRpYEJvaPcGjfMIogUO53EfK78HtkmhI5Utk88UyexlSW5liOz95+kk8f+CHYNpNv+BuDZlyA3y1R7HMxoCzE+P4lBLwiG5bOY8nfvoeabGH4pbfT7/hL0KK1NK2ZiwqIA8Yi/i/FBACJRc8huDyEJp5dWGbbNi0fPYzkCyONPY+krhP2KQQ9HmKbV7BlxXyu+/6POW3yGPqXBBhRGaIy7GVEVRifW8KwDHa2ZIllVDKGTkOqR6Tvq0APwf4a4PPscXTTZsmO1m4HOcuyiWV0RLePyPQr0eo2k1k3t9M2oQlnIYcriH34cKeMcmji2QguD4mFz3zuOWa3LKb2wWvYfcdZ1D70LVpn349au6nbwUqQXZSe+n2KT7qBfPVa6h//fqfMuujyUnrmT4gcezXZrUuof/ImtJauZFOQFYqPv46ymbdgJJqof/xGUp/NZn1dkvMfXNwtKeh4OkbbfYMD92Z3XLdyd4y752wpeDfrpt2l3+vEURVsuP1kfnbqyG7vVXt2dea4PlSFuhfhALh/7tb9rvtPQcTn6lTufeKoii7iWkGv0unZFoCzx/Zi4oAIpgXRjE5zhwy4KArUxnMAB2Vp1xGZTQupf/QGtMbtlJz2Q0pPuwnR1TUo0uMNNDz1Y9JrZhOacgHl595WGBTDR15Ifucq9GgtouLGXTUcUfFgpqMYySaSy17FzO1tJ9CadhD98GFqH76e3XecTc39V5Fe9+Hnnmti8fNgQ3jK+YVltm0T/fBhRG+wTaV0L3Lbl5PbvpzwkRfhLSrh+mmDOKxPmBNGVfDC9Udy3dTuJzK0nn7sHvTgawVRFBlWEUASBcI+FyN7FVHklcgbJqpuIokikiAgCVDkcSHYNpphUxrwoMgi0ZzOiMogZ1zxXUKREj556i8I2PQq8XNI7yLKK3vR/5gLSW1ZTHzHKizbsVb0FFVQfNgJ1Cx9l0xzQ/cG5m0wswmaXvo1u//yDarvuZDGF28juuYDNK0rSdQAT7+JjLjmblzhMqpf/DWxBU+jGiYSEPHJHHr0DKbffB++QIDlD9zMstdnIYpQm8izrTlNNK1iIzB87JE88cZchh92BI/9+Rbuv+UG1m2vpj6eJ57tzO6TeUf23K2INKZymG2kQTMsGhJ5mlL5bomEk7HO0JjMkVF1dNMilddpTuXJqyYeWebQPmHG9i9meGWQw/uG8Xu6tha1Qzct4lkdw7IJe2RiOacfPJ5S2dqUpiWVpzqaZUND4j++VNyybRAETEvA65IZWOqnV5GvU8Yf2yaeVWlJ621WbkFOGV0JpoBm2CiiSGNSRRAgndNIZnTyuklONRBti2hWxy2JRLMapg2jeoWpiHhwiQKlXheCCC0pnY0NaTZv383rf/k+2975J1WHHs24791H5bDRyIKAgIjXrWCZNtuaMyx+4WG2PfNrpFAFfS+/G6vXWDQdKqddSHrnKoyDjAksNUPq03dofO5/2PPX89lz5zk0v/6nTsmm7qA17SS7aQHB8Wd2Ei7OblqIWrOB8NTLMN0+4kmNrGpQEpD46Mk7GTR4KD+5+SaOGlrOuRP6ctzICvqX+ulb7GdkVQgbkSKvRMgtkstb+BQZsUfs7F9GV3PXHvyfo90e5/nle1hfn8Qy7U4DmyIJhR7fF1fWFEQ0BJz+1HYC6B9zHKlP3yX28WP4hk1GdPud7WQXRcdeTctrfyC9+j2C4xxTeskbIjThbBKLnkVt2Ia7cki352dm4rS8eQdycR9CE7+BHq0hs+5D0p++Q8XFf+zSA9qO4NiTcVUMpunl39Dw1M0Uz7iOwOGnIgiOmmfoiJm4KofQ/PqfaXjyhxSfdAOB0cd22Y9v6GRcVw+h9e27iL73N3I7llNy0g0s2dHaiagt2dHaifALHXp3L3lkSbeK1O19vLppI4lOYGOYVoEUCjgzru1jnksSuG764O6/SOh0LFkUumRvO6Imnmfl7th/tJJzLKsh4FQUCMBhfYs63ftXVtXQlFIROzzHggBDK4Ksro53u0/DtHl26R6eW7bnoHuvLTVLdM6DZNZ9iKtqKKVn/LhL9rcd2W1LaX3rLmygbOYv8Q2d1Gm9p/+h6K3VpFa9TdHUS5FDpUihMtTajfhHHE1g7MnIgWLye9YQ/+Q51D1rQFLw9D8U39BJqDUbaH3nHty9R+73HPR4A+k1swkcdhJyuGLvuW3+BLV6HcUnfqdQHgZgGzrRDx/CVdKHGedczs/POLTLc7W5IVX4LjpCEOikIdCDHvTg/z/KQ15OPaQXadUpa7YsPy2pPJIkURX2IEsipUEPumkT8SkMKbPxuGXyuknQrRDwyIwd0puLvvMzHvztzexY/A5Hn3Ie/SJemjc3M/aUi6lf/i4733mIUdfeg2mBywUDZ1xIbM0cdn30HGWn3IiAUw7upnOZeGzuP8ntWk1w7CnYhkZ+92e0vnsPuZ2rKDvrp12uJ2NDKFxF5SV30PjGX4gtepZ83SaGn3MzmidMWUmYYZWHcP5jr/PEn3/O7MfvpGXHWmZ+/3aCboXmlEpJwE1lyENWkvnuHx/h8QfvZdlL/+C+752L9zd3MerK8zsd07AsfC6J6qhKRjOoCjmTqU2pPLZtY+nQYmlUhj2Ylo1uOuP2hroEec1EMy1cikjE42JNdRy3S0S0oDjgJqVahL0u+pb48Lr2T64ty6Yh4Rwvr1tkNYO+ET/zNjUiSSIeWSKdM+hXImFazvrAAcj6/y/opoVp2bgksVAp+GVg205GP+BR0A2rsC/LsqmL56iOZnArUlvvvEzAJWEDmmnj8ygoEuR1E92yyeQNJElElETq2kr4W9Ma9YksLkliS1OKiNeFppvIIjQlczQkVfa0OmJ0zWsXsOvNe7FNnf5nfo9ek05BFiWyqk4qp2FbAiImHltl6WO3U7thOeFDjid0wrexFTcaINng7XsoYkM1iW5igsCIoykaezJSoBgtEye57BVSq9/F1nIoJf3wt8W86dXvIgdLiRz3zf3eu/iCpxDc/k7Wm5aeJzbvUZTygQQOdWw8RclRa18/+3n27NzOa6+/Se+SvYTctm0sG7KaQTqvEfHLeJQQCDb9I14GlQX+pe+4Bw56CPbXBBdP6sfFk/oVsp8Rn4v1dQls4JxxfQrB8rPfmlxYH8tqRHwubn9rvfOikiQGnnEDmx+6kfjCZyie8a3C/n3DjsTd71DiC57CN3JaIRsXOuJsUqveIj7/SSrOv73bc1PrNmMbGiUnfht3bydra6lZslsW4+4zar/XZJs6at0m51i2RfSDf5DdvoKyM39csOry9DuUqivvofXNP9P61p2o1euIzLi2oIzYDjlYSvkFvyW57DXiC56k7tHvsrP/XXDs3kmByYNKkCWxg4qnQyU69mZ3VE9esqOV1dXxQh+vYeFIXXaABQi20+8+qCzA1UcNPCAh7iRwdhD9wa+sqvlaEuyvymasYwbbbvt3+/472nhJwt7JIsuGtz6ro1+xj12tXWd0C/s7SHKd37OGlrfvxky1EJ5yAeGjLkKQur76bNMgPv9JkstewVUxmNKzfoYS6VpdInlD+EdOI/XpO7S8eQfFJ1xPbvtyPP0OAcDKJWl8527yuz5FChRTdMzVBA49ofCb02N11D10Lfk96/ZLsBMLn0YQJcJTLigss/Q8sbn/RCkbQOCwkzptn1z+KkasnvLzb2dFTZrNDalOExkvr6rh+eXVXcg1wHVTB30tn8Ee9OC/HYIgdCqdDfs6i0V6FIl+JX5s20Zqn9C1baJZjZxmMrZvERN+/G2Wvfcib/3zLo6YfgqecAnThpWxrdFN3Te+w9JHbyOz9gMqp5yGKVhYwVIih59CdOXbFB0xE3dJHwy6lonnd6/BP/woio+/FnCCdrVmPaKrqw0nOGNpStWJr/kAI9mEp2IQuV2fsu6BG+h3zi+I+MehWSIuWeTHf36AFx97kOcfuIMdmzdw1g//RJ+ho8hpFo3JHB5XgCmDStk+8yoGHz6J2ffdyq+/dzlNGxZzxx134Pc7yYWQV2F3axZFlihWJJrSeaqKPBimjUcR0SyTjOpYbjanVOw2q6hERqc05KYpmSedNynyCIiCSNij0JLSMW0bv0sm7FNI5Q0CbrNbVWsA07YxTBtZglTOEakbNyBCfSLXZt+lEs0YeFwSpiHQmlG/dgQ7mdOpT+RQRAGvW6Yy5PnSdk627RDjplQetyTSr8R5XjKaQU08h1sW0UwLy7YJCrCyJk5rSiXiczGgLEDIo6AZJmUBNxVhF6IosqMxTSKjsb4+QWXQRUNcozWTRZIkikrd7GzNkM8ZpDWDjKaSSibY8dYDJNbNw1s1lKHn3oy/vC+yIBLxu4jndERRxBYEajetZvOLf0bLJOh/5vfwjjyRXIdfQ84CwbX/mMAGLLef+CfPklz6Mrah4Rs5ldCEs3FXDS3sx4jVkd+zZr/3Ta3dSG7bUoqmXobkCSDi/KaSS1/BTDZTetpNCKKEDLhFaG1uYP5D9zD2qBkUDTsCw7SQJRHddKo3cm2iiVnVpDTgwbYh7FU4pE8Y936e5R58MfQQ7K8ZxvePHDDY7W798MpggQzBFI5b9jbJlW8SOPQEXGUD2novBYqPv5b6x24kvuApSk78NuCIkIQmn0d83qPkd6/B0//QLse0TafsSuhAekW3j8AhXUWhCp+xLbJbl6LuWUvJKTfiqhpK49M/Jb9jBfVP/pCys36Gq3wgAHKwhPILf098wSySS15CrdtM2Zk/RSnt22mfgiASnjQT78DDaXnrTu68+Wpa183nr3/9K0VFTmb0mGFlzN7QCDjKna+sqmHmuD645DbiLQjM29zEPR9uxTCtLoNEx4qAwrXgkLltTWluf2s9wyuD+/2OJg8qQRaFTuS6u6xhO15cUc3MDhMo/1c4EIF+Zukebn19HZZtH5QH9YHQMYMttv0bnIkIvYOggGXDiIpAQal9Y0Nqv6rtBwtLyxOf/wSplW8iR6qovORPhQmifWEkmmh548+odZsIjD2F4hnfQpBd3W4LjrBg5NiriX34MIlFz6MU9yYw9hRaZ/+D9Op3nZaNY68mcPhpXSaLBLnt31b3gjta004y6+cRmjQTObg3s5xc+rIzkF7kDKSFc082kVj8PL5hR+IdOA6A295YBzjvhkseWYKqW91mrq+bOmi/rQ496EEPvv5wSmudccwlO39Xhr3Ytl0Y3377p7s444SpzH/+Pn546x/wlQjUxLL0OXw6O4YezrbZTxAeNRVN8eOTYdCMC4mvmUNswZP0PvsXjlDnPse1TaNTTCAIwn6r2cCJCZJtMUHxSd/F22s4zW/eQW7np+yY9VNovooRMy5EEiyW7Ixx2KmXQMVg3v37LTx1y5VMvOB7DJ72DQzbh8+dZ1V1K4ZpMmzEIRz96JssfuE+HnjgH8yePZvHHnuMqVOn4pYl+kd8NCVz+F0SsiiSypuUBlw0JBzrLp9bZte2FB6XxICSALppE/AqpHJOabhXkYhlVBJ5lZZsnmzeJJXXMUyTI4eV4ZElUnl9vwRbkUQ8LpFtjSlEUcDvlklkDY4aUsrczU3UJzRkUUASBMpDLqeCrsN3938B07JJtSlxBz1Kp3LtjKqzrSlFVjOJ+F0Igolu2oVn7WCht5Xi66ZFJOCil8vriNN1rDoUHJ9127Ip9rmQJYFERkcUBWoSOUREKiJuvLJCRVhGFgVqY1maUnmSmkE8q1Mfz5HIaVimjWoY5Ep0oimN6liS1rROYstyqt++Fz0dp88xFxOcfD7hkAfbthAFyOkmyZyGX7TYPOdFGhc8jStSyZCr7qTvkKHkVJOGjBPPtDck2HQfE4QmnUtm00JiHz2CmWrBN+xIiqZdhlLSt8v9ERQ3drr7mMC2bWIfP4HoLyrYdQZlSLU2kVz6Er4RUwsT/ApOJea6l/+BaRq4j7yYv32wgXXVFVx61BAMy3m+FFEgq5p4XCJCTiDoVRhU5qc4sP+2xh58MfT0YP8HYHz/CN89dkiBfB954Q2Ibj/RDx7Atm1K/A5RcJUNIHj4qaRXv4vWuKPw+dD405GCZcQ+fqyTQFo75FAZAEa88aDPyYjWoVavxTvkCNy9hmPrGlKonPDUS7C1HPVP3kRq9XuFkm5BlIhMv5Ly836NmY5S/+QPSH02u9seb1f5QKou/ysjTr6cJ56cxfCRo3j77beBrkTWbrs/t54+GkFwyPPyXbFClrl95t8pt4fKsOeAgi6qfuB+1fH9I5w3oW9hH6IARw8t5YgB3RNUw7L/z/tfD2Rn9szSPdzy2lqMtkzyl+3Pbe9rj/hcuJU2uy5lr53Z5EElKPLe148iCTSmvjqBl/yetdQ/9j1SK98kOO50qq78+37JdWbTQuof+x5ayx5Kz/wJJSd9txO5zmycXxA16whBlCg+4XoiJ34HV8UQ6h66lvTqdwkefiq9rn2Y0BEzu5BrACPh/I6kYPf947F5jyN6nImvdujxBpJLX+40kBa2//ARsB27vsL2ps0vXl3Ln97d6FjLdNheFJxe+JeuP7KHXPegB/+haB/vdNNi5CGHctZFV/DK04+x9rPPyGomVUUeSoMe+p/+bfRchu2zH0cSBLJ5KCoupt+0c0ltXoTYuqnbMVEOlf5LMYGh5bFsqDzhWgJDJ7Pj3X+y9J+3oKbi7GhKsr4+Sb5oKDfc/QKDDpvMkqfvYsE/fkFDQyPLtrWyYkcrfkVCMy3q0wbX/fg2Hn7uNTTDZPr06dx4440kkykymk4qr1Eby5PK69i2TcCjUBX2UR70YJpO2XNGNalP5JElkWEVQYr9LkzLJqMZNGVyjOldhEsUkETYHU2xoyXD3I2NJLM6rg5qpqZld+mjjngVSoNuBpT4qSryUhHy0K84QIlP4YiBxfQKe0nkNaJZnfz/By/i5lSeWFYjltVoTu1tBLBtm7pYrk08zqY1owJ2537pg0Aqr1Mdy1IdzaLqzvXlNBNREHC1VR16ZIl+ES8CAiGvwoCyAD6XjGU7pLchnidlaOQ0g9aMSjyr05TSkASRiN+NbdgEZJFsXiOZNxAEAcUlktFM4nmVfDrDzlfvZvuzv0b2BJh0498YfdoVlATdeBQR2xIo9iu4JIFMrJVVj95C4/xZhEZOpc/ldyOVDCKZNUnmHb/pxjkPYpudCXF7TFB80ncJjjudphdupeX1PyJ6Q1Rc/EfKvvGLbsk1OHHB/mKC3I4VqNXrKDryIkSXl6AIbpdAy9xHAIHyY68GnDYOrxvqN31K42fzCE86lwaxgvW1GV5Yvpv7PtpMVjMwLBtRELAsGxEYWOpn4oBi+pcEUKQeWvhVoedO/gehndRMP2QQRdOvQK1eR2bDvE4WSeGplyJ6AkTnPLCX3MouiqZeila/lezGBV32q5QNAFFGrd140OeSr16HbRqF/hK9ZTdyoBh3rxFUXfk3PH3HEH3/Xlre+DNmZi/B8w4aT9VVf8fdazjR9/7mrM+nu+xfkBVyh51PxaV/IW66OP300znl7POYvbKzTdmYXmEA1tclumSmBUAWBWaMKC8IatXGcvvNNrejnSR2J5zW/v+KJBQ8oH9w/DDU/Qyalr23bPqrwoEE3WD/JfOODdy6Tn3NoiB84f7cjgT+9rfWc+vpo7npxOE8fc1kAO6buw1w2h0uaWuNePbaKQwpCxxotwcFS83Q+v69ND77c8Cm4qLfU3zC9YiurrOylpan9b2/0/L6H5GLe1F15T34R07bu17P0/qe84xm1s7pdrJHj9bS9MIvaX3nryiRXlRdeTfFJ1x/QJXQ9t+Ru3Jol3W5XavJ71xJaPL5SB16rGMfPQKCSOTYzv1ZuR0ryW5ZRPjICzopjbdj2a4YsiTSMR6SRUdHoKcsvAc9+M9FXjcLpEYQBb578/8QCEe4+/afURvLEvK48LpkevUbRK/JZ9K88j3Uuq1kDWiMmYQnnokrEKH6g8dQupvorhyCWr+lk3DqAc9nPzEB3iIqz/45g067ntiWVbx++1VsWvEJWxuT1CfSbIiZHHL57Qw65Rqq1yzmldsuY8uni6hNqGxqTLK1Kc2OxhSvrqqm2T+Eu597jwuvuIa///3vjDnkEB589lXiORO3IuJu6x92+oktUnmDuniOaJu3uCIL2DYkcjqq6WRsS4MeNM2m2OdmcEWIkFchnjOQRKiOZUjmNQJupxg0ltHYE81QHcuS101My6YmlqUukUcE8pqJbUNxwIVLFpEkmbxuYWJTGfbSv9SHW5YK2i2GaX2u6JlhWuQ0s5NYW8fPWZZdKIPfn4J6RjURBXBLYicl9rRqkNFNbMshuUVehd4R3xcm2LGshleW8LskMppJn4iPyrCHPhEvsaxGTSxLbTxHxO9mwsBiRvcuwqNIFPvd9Ctxo+om5SEXHlGkJaXic0t4FRGPJLC7NcuO5hRbm5Jsbk6zuzVLXTRLUyqHmjdRVYM9Kz9m6V3XEFv3EZEp53PkD+8j0mcEiiThdcn4XBLFQTeqarNtxXy2PvQ9srWbKTnlRkpPvxnT7SNvQ0seGjcvY8/jN5Je8wFa084u12qbOoklL1L36A2odZuJHH8dVVf89YDVHZaaRW/e3a0Okm2ZxOc+hhzpVWgNy1mQ2v4pqc2LqJx6PhVVZQVyrWk6Ne/+A1dRJZFJ5xYmx+I5ky2NSbY2pvAqErphURH2UOx3E/AoBNzyF/5ee3Bg9JSI/4dgXyGvSy+/kgfXzCY295/4Bk8sCCJJngBF068k+t7fyKz/qOD96x99DMkVrxGb/yS+YVM6ZfBExY27z0hyO1YQaZsp+zykPn2boqMvQRAljGSLozhuGbirhiG6fZSf/2tiHz1KasXrZLctJXDI8ZSc+B2grWT8/N+QXPoy8YVPo9ZupPS0m7otX3dXDaXyirvpt+d9Zr/8CCjvEznum/jHzEAQBJ5fvofXPq1h+a6uZNPnkshqJrM3NHKw7xVBgA/WNwB0EU7ruEyWRC44om+hf/6Cif34rGZtl/21l00fbM/zvtvt27PfnFKZt7kJw9p/eXfE52pTiLRR5L1Z5VdW1WB0GMxFAW4/a8wXImIdVdhtHAIfy2p899gh3YrN/e4be7OxZx/eh2XdfE8HA9u2yW7+hNiHD2Fm4gQnnk3R1EsRle7LndT6LbS8+ReMWD2hSec6fvHS3r43rXkXLa//Gb11T9v6SzuV7NmmQXLZK8Q/eRZBdlF80ncJHHYSgvD5c5a5HStQSvsjBTrfV9syic39J1KonND40/duv30Fua1LKJp+BXJo7wy3bWhE5zyAXNyb0MRvsD9MH1ZGXjdZuLUFGyfDsq9AYA960IP/HFiWTVMyj2iDzy2TVg1GD+zFj2/5Nb/80Q28+sIzTDz+bKqjGZJZi/4nXErTmnlse+sf9L/0DiSXSF730uuYi9n11n1kdyxGHHxkJ/FVT/+xpFe/h1qzoUtVTXfYX0zgqhqGSxAomnI6vkgZm179O+sf/x88/Q6j/wW3o+pZJBF848+md/kYmt78C6sf/TmtU2fS55jLqSotwhYhrZr43TrrmwxOve4XHH/a2fzqJ9/nj9+/gsknnsU51/+c8lED2d6UYn1NnGhOwy0LbK1PI4gCu5uhIuQh7HHRt9RPVjXRDatNbEtCEGwqQi7yqpcNdQlkERI5gxW7ohT5XAypCBHPafgUhyDHsho+l4RmWPjdMinbpjTgRhIdEh/0KoyuCrBgawtBj0xFyINgC0ii8yea0YjnnPLxipAHt9y1BN0wLeoSOUzLKSmvCLhJ5DXiWQPTsvApIi1ZDZco4FJkKkJeiv2dJ/R10yJnGMQSOrIoMLxi7+RwTjNRRIEivwu/KTO4LICrm/PouC/bdhILtm07Pe22jaZbZEwDr0vCrUhtegGQN0xSeQO/W0Y1TBI5vW2dhSKJuBWR3pEAoigBNtsb0yCK+FwimbzBpsY02ZzjeV3TmiatGqQ0E8sGRZJJt9bz+qwH2LFqPr6qIfQ65zY8lYORXV5M23YmFEyTZM5AMvNsefMh9ix5G2/lYPqefTNquG/Bg9o0dGLzHiW18k2U8oFOG+M+XtZq3WZa3/0bestufMOPJnL8tc4k0ucgt3MV2BaeAWO7rEuvmY3euoeys39R0I7xoLP9nQfwFFfR+8hvUB70UhowqY9r1C95Da21hiEX3YbX5yavO+KERW6B8oBMVrNwyyKeoKfQ2pBWDUzb7iGEXzF67ud/ANpJTcesZDRnEDnxOzQ8eRPxBbMoPuHbhe0Dhx5Pes37xOY+hm/IJERPwCnRPvabND1/C8kVbxCefG6nY/iGHUlszoNoLXtwlfY74PnYpkOk28tdMhvmYqaj+IZNKYib2WoWI9GAf9Q0crs+I/3pOxjROsrP/zWCKLUJPJ2PZ8DhtLx5B43P/Q/BiWcRmXZ5l95YQVKoHXQ6g64dS/Ubf6P1nbtJr51D8Qnf5jP67/c8M9remdp9J4lPHFXBMcPLuW/uVmrj+U7bPTB/Byv37C0z1wyLu+dswatIBWKpGRbbGlOF7PDwyiCi0PU4oiQQ8bn2q3LeEfsS1JNHV/LGZ3X7VdJuz0533NfK3TFuf2s9puWUed16+ugCUX9xxV6/b0kU+M1ZYwoK9+14Zuke3l1Xzyljqvbp/YcHP97Oh5uasCy70HOtyCIRn4tfvLqW9bWJTsS7o9jc5EElnfq1vwj0eAOxDx5wiGv5IMpm/rKTeEhH2KZBYvELJBY9hxQopuKi3+Hpt3fixrZtUqveIjb3UUSPn/Lzf4N34OGd9qE2bKP13XvQm3biG3YkkROuP6hBFMBIR1Gr13ey3mpHZt2H6E07KT3jx4Vn3CHRD7Yp+J/dafvEkpccYbMLfosg718UpzzoZua4PizfFUU3rE6TKj3oQQ/+s2BaNvWJHK1plaxm0ivsQZZFgm6Fk2deyNOznuCtR+6k72FHIwDBgIxq+Rh6ytVsfOkuWtfMIXLYiZhA+LAT8S59k9o5j9Gv/0QEWSkQDu+g8Qiyi8ymhZ9LsA8UE1huHxoQi6XZs/QDvCOmobXsIr/7M7Y9eA0DLv4j3tIKTAuCvYbgu+pumj96hOr5L5PY9ineS38BZf0xTGhJ5ygLeOhb7CPQbyS3PfIGb866n7eeepA1i+dx9fd/xtkXXc6Oliw53SSrajQmVFyKSCytssKwKAl6GFoeZMLACANK/WiGxWF9I3hdMqIAAZeLBVub2N6YweUSqQh52dSYol+JH1EQUA2TeFbH75YxTJOaaIaMZmHZFmpJwBFXMy2KfQqNKY0+xT5U3aK+TWm8LOiUkMdzWqEEPpHTKQ92JbaaaWEYjg1ZTTTDxtoErWmNsqCL3dE0O5uz5A2DiqCbif1KMQyrC8HOaQYBl0zE6yKjGXhdcuE5SqsG8ZxjL9av2IvP1ZkyGIbF9pYUyZxBxCcjipJTgu+WyRsWmuGQVxHHMq2qyE2vsJftTSnSqoEsiMiKM6GQUQ1w2dSoBrbgVBj2Cnsp8rkwLItP98SxAAUBryyiKAK2oWPa0JrSaMmoGBbkVbAsgx0rn2fnh7MQbJtJ53+XyIQzied13LJESzpHeZtneTqvkdi9kW0v34kWa6ByykxKp19Gyu4w4d6yh5Y370Bv2klw/BlEjrm605hraXniC2aRWvEGUrCEsnNuxTfkiM/7qRaQ3bQQ0RfuIhpsqVniC57G3WcU3mFTCssbl7xGvqWGvuf/iqThQmvJIUuQaW6ieeFzRIZP5tApx1AVcRNNGWiWyYjyIEOqiigJuHArEookkszrpFUDn0vq1ObQg68GPQT73xz7qjE7tlICo6tCfFI1hOC400itfAv/mBm4q4Y52wgixSd8uwv59g4Yi3fwRBKLnydwyAwk/15i5h9xNLEPHyazdg6uz8liC5KMu89oml78Fa6KwYhuH6EjZnYiPdEPH0EOlVE09TLApumV35HfvZrGZ39O8YnfwVU2AHAy1FVX/s2ZOVz+Gvkdqyg57YddCJRlgxXqQ8UlfyS95gPi8x6n/vEbCY0/k/BRFxWI/X7PGbhu2iDW1yc5ZUxVgVgu29lK7eq6Lttvb84UbLgsGxZsbemyzbJdMZbviiGJAseNKO+WCBumzdzNTXuJud6VFLdjyY7WwnZ53eK1bs6r0zV1sClrJ7Ht5eGOcJvdSXSsPXstABdM7NuJXK/cHePBj7cXBOQWbG0pTBiIwt5+v44YVO6ort/6xjqMDqJvIo6X9aurarjrgy1YlpNJ/9UZo3ErInn94HrQLF0lufQlkktfBlEictw1BMef0UkArCP0lmpa3r4LrWErvlHTKT7h253KsM10jJZ37ya/YyXeQRMoOfUHSP6iwnrb0IgvfIbksleQ/EWUfeMX+IYdeVDn2o7M+rlgW/hHTe98LWqW+PxZuHoNx9ehTD2x5CWMeBuJ7pBh16O1JJa8iG/kNLzdzHq3QxIpiOg9fc3kr0QZvgc96MHXF3ndoDWtOsKepo1m2QRFARsng3rxTb/m11efyWsP/YXgjOvRLchqYA86Fm+f94h+/DjFwycje0IgSvQ/+ZtsmnUbiZVv4p80s3Ac0eXFO3Qy2Y0fU3zcNw8oCnmgmKDd/qv+/UcQQ2VO5ZHbR+zjJ0gue5Udj95A7xOuIXzYiZiWgM/jYeCZN1A6ciI7X/87c++8nsPOvJphZ1yGKMlopsn8zS24XSLDygMcf+l3mXz8mTx/z23c+9uf88Hrz3PM5T+hz4jRpLIGNYkcmmbSmMigKDLxvIFh2hw9pJR+Jc740J6NFQSB6tY0h/YtIuhV2NWcoSaWRRCgPp7FrUhUR/OIoo1HlqhO5siqJrujGSpDHjY1JBhSHiSjOqrlqu70IadUDRuR3hF/oRc8mtWIA36PTMDtvPvzmqOELQoCIY+CLArUxLNkVYM9LRlG9S4imdfY0pAintfIqQYWsLUpS5EvhSUJ9MlopPIGNjYVIQ8uWSKnO5ZkiiQWyoQNy0IABpcFyKoGQbdSqOLK6ybNKZXaaIZkXifgVpi7uZmBpX7ckkgirzO0Ikh1a4YNdUkECUIehWhOxTBha2OCrG4R9Mj0K/ZjyRbRtEqTaRPL6vQv8TqK10EPvSM+BARKAllEXOxpyVAd06gMusnqJvUJlayuEvJINKVMMnvW0PTBA6gtexgyfipTLvkRhr+UXFZFVnWKfC70lErIp9AUS7L9vSdoWvQyUrCEiot+h7ffoeTbwhXbtkl/+jaxuY8iuLyUnXsbvsETOz3buV2rib73d4xEI4HDTyUy/crPjTc7wsylyG5bSnDsyV1il8Ti57GycSLn3la490aiidZPniMwbAqugRNQAdUADGj64AEQIHzctbRmdA7pHebQXn4CXoVIwIUsiAwpD+BzSQiCQJ+ID8OycEni/6mo3n8Legj2vzleWVXTwZbKyf4Zls2jn+zk6CGlzMtfSnbzJ0Tfv4/Ky+8q/IDdlUMIHn4qqVVv4x9zfIGwRo67hrp/fpf4/FmUnHJjYb+SP4J3yBGk186haOplB8yYAQTGHId3yBEYrTW4qoZipqOFdbldq9Fb91B+zi8LLyI5XAH9D0Or30L9kzehFFVSNP0qfEMmIro8lJz4HXxDJtH67t9omPUjQpPPo+jIC7uchyCIBA87Cd/QycQ/foLk8tfIbJhH0fQr8I85br8lvJIIi3e0UhFySorvm7uNVE7nrTX13W4fy2jIkvC52db272POxsZuM9gAK3dHC/ux2NuT/cd3NvLCimokUaA04KY1o36h7K5tO97Ht7+1vpD1vvX00c7AZVhIUmfRsfbliixyzrg+ncrPb3tjXRdP7/Zrsdpl1vfBtqY098/d2olcA3jb+rC2NWcKyzTDYn1dgltPH81987ZRG8sd4LraysHnPoqZbMI3YiqR476JvB+BENsySS5/lfiCpxFdXkrP+hn+EUd32ia7ZTGt7/0dW88TOf46guNO7zTg5Gs20PruPRjRWvyHnEDxcd/s5EN9MLBti/Rn7+HuPaqL0Eli8QuYmRhlM28pHHd/JNq2baKz/4EgKZ2EzfaFAFw4sV+BTH+eQ0EPetCDf3+oukU8pyMJEE2pNCYEQj4Fn0umf8TL5HGHc9w5V/LB848wbOjRmOVjMAAEgaITv0P9YzcSnf84fU6/EUEEeo3HP2gC0cXP4RlzbKeJ98ChJ5LdOJ/M5k8ItPVW7w/7iwlUuo8JzEwc/8hpGPEGat7+Ow3zZlF1wrUoY6ahaeAbPIlR14+g9r37Wf3qQ9R99gnHXXMLdriKaNbE6xbRDYOqIg+9qvrzvbueYNF7r/PSP/7Mgz++mEOOPYtJ532bjOYikdFIayAYBiGvRSqvk87pLN3egiwKyLKA3yWTyuusqo6yrSFNxO/cU7ciMbjMR0taxaMoeFzOOOp1SxgpcLskyoJuRGwaUipuSaI06EaynfapWEYnq+sEFIXWZB5REmhIZAl6XLSkVXK6QXnAQzKr8emeGNubktQn8lSEnB5m3YSAR0aSIK3qiAKkNQ1JFNEsCxsbRbRRJIleYS+7WjNUBN0gCDQm84Q8CqmsRlo3GVoaQG4j2LIgkDdMWjIaQbdMwKNgWTY53WRPa5amdJZNdSla0ipjeoVI53WqWzOEfC7ymsn2xjTpvAECbK5NE/Ep9Cn2oukGu1qySIKjDWJZNlndJJUzCXlkGpI5Qj4FRRKw2yzOMpojJrapLkFjUqUs7CGt6kiiCIKFathEm+vZ/fajJDctxFVUwaFX/IoTTj6NlGZREXSzakeOrAGZ1hxBF5hNO1h2/69IN+wicMgJRGZcg+j2Y+HEYUY6Sus795DfuRLPwPGUnvqDTm1dlpoh9tE/Sa+ZjRzpRcXFfzxgn/X+kFn3IZh6wcO6HVKsjuSK1wmNmUFF36HE2/TUonMeBAGKZ3yLjhJr2S2LyW1bRtExV+MKl2PZJpsaUyRUk0P7RqgIeSnyuigL7m2bc9oReiy5/rfQQ7D/zbE/wqWZNttbMohuP5HjvkXLG38itfItQhPPKmxTNO0yh3zPvo/Ky+5EECXHWmD8GSSXv0Zg7CmdMsXBcaeT27qEzMaPCRxy/Oeem+QJIPUegZlNkP5sNr6RU3GV9iO/ezWBMTMQfUWA07eiVq+l8rI7sfU8za//Ga1+M61v3Un+kBkFP2/voPH0+uZ9RD98mOTi58ltWUzJqd/H3Wt412P7wpScciOBsScTnfMgre/cTWrVW0SOu6bbl6BhwWc1CSBRyNIeCDZ0Ipztpc37K3E+kE5JNKN3+vffPtzCe+vqmd8hK97cQajuYGFaNu+uq+/UOrCuLoFu7C1jf3lVTYeLKkzb8sH6Bh5ZuLOQlf6iZdvtqIl3VQbvWJrfEa9/Wsuzy/Yc8F6p9VuJffQIas16lLIBlF70+04l3vtCa95F67v3oNVvxTtsCiUnfqdTgGipGaJzHiazbg6uisGUnP6jTi0QlpYjPv9JUivfQgqXd1syfrDI7ViJEaun6OhLOy3XY3UkV7yGf8xxhWfZtm2iHzzQLYnObJhHfvdqij+nNF2RRWaO67Pf9T3oQQ/+8yCKAuV+Nw2pPE1plXjOoDwgkzds8loIryJw9uXfZcHsN9n5xv1UXXVPoTrGVTaA8MSzaVn2Ct5Rx1MycBRuF5Qfdw07H72B2MdPUnrq9wvH8vQ/FLm4D6mVb+AfdcznZsG+TEwgegLEFzxNculLVL9xJ4ktiyk546fYBii+MH3P+RmlYxay4837eeGXl9NvxiUUTTqbIp+LvKbTlNRIlJoIgk3ZocdwxV8msvyNR1n65jNsXvIBg2dcRMXks3EJAhkNsjmNkrIAtYk8umWTyeuUhz34XRLr6pMoooTHJZPIavSK+Ah7ZXQL6hJ5ykPQO+ShOpEnr1v0CnvIaSZ18Sy6KHBY7zCqaaNpJtG8TlbVkWTw2CLr62LUJrL0LvZSGfKQ0zJUhFy0pAyiWQ0sgR3NSeoSGk2JLNGsTnNGwy1JDKsM4ne78CoCguDiG2P70pBSWesWyeQt8paTlVd1A69bxrRtRAREQWBzfYK5W5qIZzWW+1ycN6EfVUU+NMNEFkUiXhELG8O02NKQpDWtUh3LktVMBpYGiKc1auMqIypCxPMGqm4ScEnURrMkNR3dNHHLAqICLRmNvqV+BLK4ZIFtjSlqYhnyukUyr1MWdGNbOo3xFB5FAUySOYvGRIbNDSkaUxo+2fEK31inYpgGzbEE6995ipalrzlCoEdfQq/JM3EF3OxpySJJIslcnpasgVsWSaRz7PngWRYsewUlEKHPubch7ZOVzmxcQHT2/diGRvEJ1xM4/LROz3Z221Ki79+HmYkTOmIm4aMv6dYx5PNgWyapVW/h7j0KV/mgTuvqP3oEQVIon34FCOAXIbp5ieOFfcxVCKG9oqaWmiU650GUsgGUTDiToAfKgx5kWSbkU8ioJrtbMvQbdvCZ9R786+gh2P/mOGdcH55buofuimprYzkEwDfiaDzr5hBf+BS+4UcWbLdEt5/IjG/R8safSa16m1Cbv174qItIb5hL9IMHqLzsjkLW19P/MJTS/iSXvVoQETsYSL4w4aMvxtYd+wdB8SDIe8uNonMeIjj+DCRfGAhTccmfSK143SE2n76Du2oo/lHHOOfsCVB62g/xjzia1vfvo2HWzQTHn0HRtMsQXd4ux3ZXDaPy0jvIbPiY+MdP0PjMz/AOmURk+pVdfLa/DEQBZEnk3PF9GNMrTCyr8fHmpi8t1gXQkFRpSKqfu12RVyae6943sR2jq0Is3RktZKznbmzsRJafWbqHV1bVcM64PuimXZg4eHDBju6S0v+rSO+HeIPTZx1fMIvsho8RfWFHVOzQE/dbDm4bGolFz5NY+hKi20/pmT/BN2Jqp2c2t2s1re/cg5luJTTlAoqOurBTGXZu5ypa37sXM9lMcPzpFE27vNtn7GCRXPYKUqAE3/Cj9p6nbRP78GEESaFo+pWF5dmN88nv+pTI8dd1ItFmLkXso3/iqhpGYOwpBzzeueP/7/3Ve9CDHvz/hd8tI4pOb6lfkTAtiy2NGUoCblKqgSBYbIsZDDn926yb9SuSS18hfOQFhc+HjrqIzKYFNL5/L76r70EVFNylfSiacCbxZa8QHHtyYSJQEERCE84kOvt+1Jr1B53B+2IxgZMM8I84iuY3/kxywwJM0U3l8dci+ny4XQLukVMZ1nsMTbMfZNfsx/F+9jG9Tr0RuXwowYBBPKdiawY+v5u8ZhGedhWnHH4yq1+5n41v/ZPtC16nbOqllBxyLIguMppOdSKHadrUxbPsbElT6nfRkFLxKRJBj4uanEYsY6BZGco0g8mDyojndFTLZmRliLBPKfS1lvgUWrMGpUE3qbzO5oYkDbEcLWmVzQ1pJAmwbWzLJpU3SWbSpFSdXc0iLkVi4sASJEGkNpFrs9TSCXolfLIXWxAQRQh6RCIBT5u3sUzfYomyoIvqqCOCJksCggj9i/00JPNkNYPeRR42N6Spj+fQDIvmZJ54Jo8iS0gitKZV3LKILAmsqY6zozmNYVhg2+xuSVPikxlSGaTYLxP2uskaWZpSKqpmkrdshpYFWLi9laBHQjNsJNkmk9MJeURaU3miGZVUziCnaWQ0nU31YBogyOBXZNbWJjmsd4jalEPqRaDVEijySBi6yrL3X6Fm7jOY2QT+UcdQPP0KxFAZGcDMQV0qR9At4FZkTNuicdNn1L13L1q0juKxJzDklG/SZAUKMZGZSxKd/Q+ymxbgqhpK6Wk/6iRkZmYTROc8RHbjxyil/Q+o93IwyG5ZjBFv6DT2Q5uw6bZlFB1zFXqgmLjuTPa3fPAgSml/IhPOwic6BE5SoHreU5jpVoae/zMG9Q5SHvSiSCKVITfFHjc+RaJPxEtJ4ItPAvTgy6OHYP8HQJIELLN7NuRtU8ouPuHb1P/zu0Q/eKBTGapvxFQ8a+cQXzAL37AjkUOliG4fkelX0frOX8msnUPg0BMBp/8oNGkmrW//ldz25V9IxEEQBIR2qyRDJ7P5E+RIL7KbP0GQFUITnMy6bZmIkkx40jnYep7E4hdpefMvZLctJzLjGuS27KN38ER6ffN+Yh8/QWrlm2S3LKb4hOvxDZ3UzbFFAqOPxTdsCqkVb7RZKHwX/5jjKDrq4m7tjQ4Wh/QOc+sZowGnjzmV06mNO+XNX0aw64sg0UauRRyP6ZNHV3bqyxaApGp0ykx3R9w13eL99Q2Fc7UL/+kekigwrl8RK3bF/levD5y+6MTi50mtfg9BlAhNOZ/wpHMP2OOU2/0Z0dn3O+Xco48lctw1hUANnNne2LxHSa9+D7m4D5WX3tGpCqKdxGbWzUEu7kPFJX/Cs4/4yBeFWrsJdc9aIsd+s6AECpDbvozc9uUUHXN1gUib+TTRDx/GVTWU4OGndtpPfN5jWLkkJRfcvt/JBXAmftot6nrQgx78Z8K2bTTTQhIE5DYyJwkCrjZxyWhaB0tAEATKgo6CtVuWiWsGgydOZ9eyo4gveo6iEUdjF/dGAnB5KT7heppe/g2NS14lPOV88jaUHHkhqfVziX3wABWX/aXw/vGPmUF84dMkFr/4hUpkv0hMIIgSrorBhMafSezjJ8is+5Dq6jX0OvUGGDAOExDdESrP+hmBEYuofu8Btj/+I4rGnY419VJa3T4sQM6qCEAgq6PrIfrO/BnFk9ax6+1HqX3rbuLLX6X03Ovx9puBRxLJmo5nsE+RqE2qaIZBYzKHLACCRMin0LvIiyg6JNS2LYJuGa8iURfLOpVjpsXGuiQZzaAkoCBh8+nuBE3JLEZbe5VoS8iSTTxvoLX1PstAq2qAJNIYzxPwSrQm8zQlVQzbwLK9lAY9VBV56RPxUR/PEfIoeBSJaFalKuzFsGx2RdOICJQHPfjdCrLkxH5Bj0xzUiWvqtTGMrhkEZcoEcvplIZstjam2NqQRJIlWhI5dkZz5HWDZE5nWEUYWRKpjuU4vF8xvYt8tGTylPldZFWd1nSe1pROyK2gCAKiKJFK5ciJNnUxmd2xDK1plYaEimnqNCVNVBtcElgmyDYIlsGWhhRbGlJOq4IFoghuyWbLJ/Oo/ugZ8tF6PH3HUHbubQWNIdrunQ20JAyiQJknQe1bD1O/8gPkokrKL/gt3gFjiVp7Q53slkW0zr4fK5cmPPVSwpPPKzzjtm2T3fgx0TkPYalZwkddTHjKeZ0m5L8obNsmufQl5EgVvg4CZrahEf3wQccdpC3pBRBf8BRGqpleZ96BS5LxKuBxy4jN22hZ9hbDpp3F9GlHMqw8RK+In37FfuJZDUGwsWyBkqCLvG7hdfWUhP9foYdg/xuhOyunl1fVdOlx7Yj+xT42NqRQiioJH30x8XmPkd2yCH9bFk0QBIpP/I5Dvuc8QPnMWwDwjzmW9GfvE5v3ON6hUwq+vv6R04kvfIbEomfxDp74pYQRiqZdRnzh0ySXvox38ERCk88rrLOySfR4PbkdK8htWUL5ubeh1m0iseh58rtXU3TMVYhuP/42RfKSE7+Nf9QxRN+/l+ZXfoN36GSKZ1zbLWkWFQ/hKecTOOwkEotfIPXpO2TWzyNw2EmEJ5/XyQLpQOjYS72mJsEf393I6up4lx7lr5p8FvtdJPM6Zlum2W47l0N6h3HLYhfRM0USEHB6wG2ctoHuYEEnr/TPg2XZ/+vk2swmSC59mdSqt7FNncChJxI+6sL99lmDQ8Zj8x4ls36uM4h2U86d27HSyUqnWwlN/IbjC99W2mXbNtlNC4l++CBWLuX0+R910QHFew4W8UXPInpDBMaeXFhm6SrROQ+hlPTrPJDOfdQh0ed3JtH5PWtJr5lN6IiZXcrJ2hHyyKTzBrYNt7+1nuGVwZ4sdg968B8I27ZpTObJaCYiUFXkxaM4glUZzSkFriryEs/kCfgCeFwSed1iQLGf0ZUhPtqYoeqEa9m+61OaZ9/HoAt/hyUKZC3wDpmEf9iRJBY9h2/E0SiRXhhuH5HjvknLm38hvWY2wbYKGlFxE5p4NvGPn0Ct29xty9bn4QvFBDNvAUkh+u497Hr2VorGHIs8YDyh0cdgWiANPpK+1xxG68dPEl/5JqnNC4nMuBbf8KMw2uKVWFtXVjoLSmQMlZfcSXDzJ7QsmMVH9/2crbNHc863fsi4KdNQdRd50yKaylEW9lLidyMCQb+LaFpDM0wGlXvZVJ/ELQss3dkKto1blkjlTPZE06imRdinUB/PklENFEkmntWxBQsDgWKviM+tEPEolBV5SeY0amI5PCLE8zpzNzdQ6pdpzqiEPQq6JXJInyJOGlOFW5FI5nSKvI6ndmtWRbBgyY4ou1tS2LbTLuRTJMIeBct2/Li9LomUqhP2u+hT5COa0agIyrhlmdrWDG+srqUhmiOhaqRVi95hN25ZwrRsdMMk4FcAAUWWsC2TmlgW07TY3pSmKZmnNavxWU2U0pCHPmEv6Zwj0NaQaCWZ1yjyKaimQS5votttfc8maIBgQHqf4jzbtshuXkRq0TOozXtQygdSfu5teAZN6BKH6m1/bNsiu3YOW+Y9jqlmCE86l9BRFxbsO3WcXv/onAedrHXFYErOv73T+Gokm4i+fz+5HStwVQ2n5JTvFUR4/xXkdqxAa9hG8ck3dhrnE8tecdxBzv9NgcCrDdtIrXyTwNhTCPQZScgDYb+LsoDM/H/eSyBSwslXfh/ZpdCc0vC5ZSYNKqE44MbvcpToPbJEYypPv4gPscfv+v8EPQT73wTdeQhvbkjx/PLq/RIdAejbRrABQhPPJrPhY2IfPIC3/2EFkSaHfF9EfN7jZDZ/gn/4UY7S+InXU//4D4jPf4KSk25w9inJhKecT/S9v3/hLHZHFB19CbahO7YfmRjpdR+h1qxHa9iGXNwHd+VgSs/6Ca6yAXgGHIZ3yCSaX/oV0XfuBkCbcgGRaZcB4Okzkqor7yG54jUSnzxL3SPfJjTlPMJHzOyWHEm+MMUzvkVo4tkkFj9P+rP3SK953yFxk849YEa7LOCiNOAu3FMbuvXY/t/Aw5dPABxhu+eW78G0nAmSdXUJzG56BI4ZXs7oXuED9jR/GfxvEmsjHSW17FVSq9/BNnT8o6YTPuoilEiv/Z+PaZD69G3iC57GNjXCUy4gNOX8Tj1RZjZB7KNHyKyfi1LSl7JL/oy794i9x002E/3gH+S2LcNVOYSS827HVdE9if2iUGs3kd+xkqLpV3QqMU8sfgEz0UjpRb8vZLU7kegOx7cNjdb370UOVxA++uJuj+NRRM44rBfPLN2DjSN21ON33YMe/GdCN20yqknAI5PXTVJ5nbxusqs1TTxjsLMlQ6+wm6GVITKa5SgGmyaSKDC8Msja2hj2oL64zriG9S/9jdZ1c/CPcYSWfAJUHX8d2x9ZTfT9eym/4HcIgoBv5HTcn71P/OMn8A2dUnBZCB5+GsllrxJf+DQV59/+pa7ni8QEtm1RcfldNL/0a+Lr5sK6ueR2fUrpaT90bMTcfiInfhvfmOOIzr6fltf/iKf/WIqPv65La5gOuAQB34ijGTRsCvKe+dTOfZa7f3INvYeO4aSLr2PgYUejKBI51SDsdRHPaCRyaUxEyoJugh6FaFqnMaWxoymJIklUht3UJ5w+eFUz0E0bnyzhdYmUhR07K9G2KfG50U2DykCAPqUBinxu4jkNjyyzoT5OKqdi2lAfz5PRdASfgN7m3uFRZPxuCb9LoiaWYVtdvCBclsxpxFI6Eb+CbtrkdAufSyarGcRzGvVJR8nctAWGVoSoiWWpLPKQzutsqIvTEM8SVXWi6TwSAjtbNfwuhaBbIaMahP0yimiztT7JLtkmmtFpjGVY25DCMGxCPgWPLOISRdbUxKiNZbFFR01cN2xsnOvP6E5M0Z4LjsigGdAugWpbJtlNC0ksfgG9ZTdycZ+2lq+j9ytaCw4pjX7wD7S6zbj7jKLfid/B6kCMbdsms/4jYh8+gqXnKJp6GaFJ5xTGYqc/+m3iC2aBbRE57lsEx59+wMqxg4Vt2yQWPo0UriAw5ri9z2K8geTiF/ANP7qQHBBMg9Z3/4bkL6Ji+hWEvOB1KwwqDxFd/DK1OzbxkzseoLSilJxmo8gCiaxBa0ZjcFkQ3bJoieVpSFp4JYm+kZ4+7P8r9BDsryG6y1S32yu1i1W9vKqGF5ZXd7FG6liWvO8slSBKlJxyIw1P3kRs3mOUnPy9wrrQhLPJbJjfiXy7ygcRHH8GqRVvEBgzA3fvkQAExswgueQl4vOfxDt4wgFfcgdCuwK4kWolsehZR9Dh/Ns7ZSmtfJrk8tewdRX/IcdjxBvIblxAcsXrSP4iAmNPQZRkh/hPOhf/yGnEPvoniQVPkVnzAZHjvol36JRuM+1yqIySk24gPPk8EotfJP3ZbNKfvY9/1HSH4HQzS9mc1r6U4NiBMLIyWCDsB8KDH2/nuumDGd0rXCDU+37/HTFvSzOlwX+Pnhs9Vkdy2auk184By3SI9ZTzu6ht74vczk+JffgweusePAPHOQFUce/Cetu2yaz7iNjcfzqlXUdeSHjKBYVnr9MgallEjr2a4ISzvpJBtP34sflPIvrCBMedvvd6W6pJLn0Z/+hjCyJtlq7S+t7fuyXR8UXPYURrKT//N4XZ946oDLm575LxfLBPqX8qp3fZtgc96MG/F1TDJJnTkUWRsFdBFIU2BWCBvG6im05WsTWjkc6bNCRyGKZJfTKH3y1jAdubkhT7FFyKTN/SAAPL/eR0i/Lxp7F7+Yc0fvhPBgyYgBKI4HWBHCmh9JgraZ59P5m1HziaF4JAyQnfoe6x7xGb+09KT/8RgGO7Nekc4vMeI1+97kupKcMXjwk8fcfgGXA4ySUvkFn3IZaaITLjOpSwozPj7jWcvpffRWL1u0Tnz6LusRsIjjud8FEXdbJobB/RdVGidOSxjDvsWISdi1jy6mM8+uvvEe41kJEnXMz46aegmTrxrEaJ30WZz01jUsXrUUjmNXY2O+O4LArsasmi6jqZnEE0o2JalqPZIgjUxfMoigCWQEbVcSkimTY3mPKQi94RD9lSP0lNJ6uayJJAXs+RzBqouk2x36lU2taYpHexD78ikdcsqkJetjUm2d6UxK/I1CfzNGZFBpb4KQ0opFQdEOhV5GFPaxbbtgm6ZXZHMwwo8VMSdLG6Js6e1gzZvElLUkXXbXxugbxuU+QTGNk7hN/tIqsZpPMaFUGoieVoSeVJqhqa5lh+SXkRWbSojWZoSOSI5534NJ3R8bjBNAwEUcQrQ0YDAyeL7W37PmxDI73uI5LLXsaI1aOU9KX0jJsdLZUDjM9mJkZ8/izSaz5A9IUpOfWH+Mcci9UhTtVjdUTfv4/87s9w9x5Jyck3dpp40Rp30Pr+vWj1W/AMHE/JSd9xnG6+ImS3LEJr2EbJqT/YS+jb3EEQJSIz9gqbJpa/ht60g77f+DluXwC3BKJgo7fW8vpjf2P05OPod/h0MqqJahgIosyAEg9DyvwU+V20pPJsa8yiSAIuRWRwXifs+9er8nrw+egh2F8zdJepHt8/4tgotfkuS6JT+mt1o0LlkkXUthe1adl8uLGzIra7cgihiWeTXPYK/pHT8fR3gntBkveS77mPFiy6io6+hOzmT2h9/z6qrrgboY3MFk29lJY3HfGwz7PnANCadqC3VCO4fbjKBnYqx3ZXDqHq8r8Smz+LljfuIDzlfLyDxqPH6kgsfAbb0Cj7xi+w8mma37iD8LTLUPesJTbnQTLrP6L4xO/irhwCgBwqp+zsn5PbtZrYhw/R/Orvcfc7hMix3yxssy/kcAUlJ99A+MgLSS57hfSa98ms+wjPoPGEJpyNZ8DY/1WPwC1N6YPabvaGRj7Y0EjAfXDkzzQtWlKfL5b2/wu2baPWbiC1/HWyWxaDJBEYM4PQpHNRIlUH/KzeUk1s3qPkti9HLqqkbOYteIdM6vQ96S3VtH5wP+qetbh7j6T4pBtwlfUvrFcbtv0/9s46Wq76/PqfI+N+3XLj7u4QILi7e2kpXkpbKtRboAIUK7QU1xYI7hKIu7tf93E99v5x5k4yuTchFO3vvXutrKw7c+bIyPk+sp+96Xj3ATJN27D3HU/BMd/H4i874DG1WJBMy070VAzZX4a1fNBnfi9Su1aSrllL4Kjv5rrXhmHQ/t4DiBYbgSOuzG0bXvg8arCBknN/n5dEZ1p2msn4iNkHVDC/4ahBjO8d4Levb8h7fENj5KDn14Me9ODbDV03qeACoGgmbzbgsiIAVlmgI56h0GXBJstkVAMBg5ZoGrts0BhRqe9IMKTCiyQIBLOdUYckoCoCNklAM3T6nnI96x+8jo6PH6bfWbeCCC5ZYsC048hs/YTwx//C228iujuApagX3slnEFn0b1wjZ+PoPRoAz7gTiS5/leDcxyi76C+feW/UklHStesxdM28nxb3ySUbnzcmCBz5HfRMgtC8Z0ntuQb/jAvxjD8ZQZRQRQnvuJOwD5lpCqcuf434ho/xTT8fz5jj8zQxAIJJcwSscNARDPv+ZNo2fErj/BdZ/MQfWP/GI8w85UIKxx7DmqAVn9PKiHI3o3v5sRS6aA0laYqlCSUypDM6LpsFj12iLWogixLBRIaUouCxmlZdFouV1niKYrcDRdVpi2ZIZoJUBVxUBOxM7l1ALKHSEEyQVjUqAnY0DYLxNIu2t6KoBn2KXQwr97KrNUpzNE1zOIVqaLQkVQrdVmwWGb/DgiRJJNIaRW4re9pSbGqKkspolPsFSpwS6xpCpOogFEtT5bfjd0q0x3RUFVKqjiSArovsaU9S6FYxDAFN14klFYLJJMmMQSxtiqkW2Kxoho4siqjoZDSwihDXs51qDVQJ1IxOKlvdkDEt28KRMMFVbxFd+SZ6IoS1bABFp/0U56CpB23mGGqGyPLXCC96AUPN4JlwCv4ZFyDaXHnbhJe8RHjRvxEkCwXHXIN7zHG5/eqZJOEFzxFZ9gqiw0vRyT/COfSwA36XDVUh07ITNdyMYLFh7zXyMz2wDU0l9OlTWAp74dondk5snk9q1woCR16VKygpHfWEFjyLY9BUAkOmgwS6IaJqOm8++nsESebMa2/DbpGxyhKCKKJpKsMrAtgsFjKqTmPI9GXvVeAkltZIqRrOrOd5D75a9CTY3xJ0dq0bsmqOnZ3qxTvbAXPWOpdOCwLDK0yRiX09sIFcct0JzegqtuWbcUHO97f8ivtywbytbADeSacTWfJSzntXtDkpmP09Wuf8gciyV/BNOQvAtNdY+jKhT5/EOWjaAS0KDMOg4537iK19L+9xuaAK15AZuEYciSVQgWh3U3jM90lsXUiqZi2OfuOxBCqwVQ0jvPB5oivfREuEEW1O/FPPwZhytik68dEjND15M+4xx+M/7OJcVdrRZwz2y+8jtvodQvOfoemJm3ANPwL/zIsOWImUvUUUzP4uvunnEV31FtEVb9Dy79uwFFbjGX8SrmGzPvPm+d/gYF3o/WEA0fSB1bb3hW5Akcf2lYutfV7oSprE5nlEV75Bpmk7ot2Nd8pZeMaffFDbKTCT3NCCZ4iteQ/BYsc/6zK840/N80PXM0nCC18wF0mLzVQcH33s3kU0nSA07ymiK99EsLspOO4G3KOOPvAiqms0PXULmaZteY87Bk2l+LSfHnDRN3SN4NzHkP1leMbuVfyOr3ufdO16Co69LkexTDdtN5PokbPzPa91jfa3/obo8BA48kq6gwCsbwjz7JKaLgn18HIvD3y8PY8N04Me9ODbj3AiQzyjYZUFVM3AbZMR0MloOmlVM+ev0yoFTitJRcfrkNATGSRJpNBpZUVNG1saowjo7GgOE3DbGVDqRVGTJNMKzfEUqYyBzWahrFdfwkdcSM0HT5DatYzA4MnIEqQyEmXHX8/Wf1xPy/t/p+j0nwHgm3ouiU3z6Hj3ASquuB9BtpoaJzMuouOde0lsWYBryIwDXluqdj0tL/4GI5PMPSZYHTj6TcA1/Agc/cZ/rpigUwzSOXgGHe//neBHjxBb9wEFx3wfe9VwNLLWncddj2fsiXR89AjBDx42rTsPuxTHoL0sNwnQdUhkMmQECf+oIygafRTUraRu3su8/dhdSM88SPn4oxl/3Dm0uAazcGc7yZRK2jDIqBqlPgftkRSSLOK1yaRVnY5Exty3BqohEE7plNsMKjwuHFaRcFrFkUzTEYeWaJpwwoPFIjKjfyHrmiwkVZVwLE1DOI2qAsRIqDoCxei6gaLptESSNIaS2KwyiqZis0hYLRI1HQlCiQxTBxZR6LYiCOBxWEhkVFbs6sBlk2iOKlT67TSrGkt2t6LrIlZJQJJBwEDXIK0qeO0u2qJpAm47hqGzrTVGgduCRZZwahoBl41Ctw3BAEOAYFKhviNJRjHfWwvm49GE2akWMWn66cZtRFe+SXzTJ6ApZoNj0hnYq0cdtFhj6JrpEjPvKbRIK44BkwjMuiJP/RtM/ZWODx5CDTbiHDKTwFFX5eINwzBIbl9Cx/sPo0VbcQ6Zif/wy7D4D9y1jix/jdAnT2Coe5sYosNL6bm/P+h4WWzNu6gddRSfcVuuE6+nYgQ//AfW0v54xp+UPSed9nfuA8lCweyrietgNyAp6nQsfYeWrasYcfbNrGwz2BhqYHCFl9FVASoCpoJ4fTCBwybhsEo4LDI72+JU+u2EkyqJtI7bLuV5Yvfgy4dgfN1ePMCECROM5cuXf+3H/bZixZ4g5/9zcdZKyfQm1DQdiyzyy5OG89s3NpBS9ibOncJWG7Kzt//NJ5iqWUvzcz/DM/E0Cvbx2dWVNI2PXY+ha1RccX+u69Yy5w+kdq6g/Ir7c/OwqT1raX7+Z/gPuwTf1HO6PU66cStNT96MZ/zJuMccj56Kk2ncQnLHUlI168HQsfcZi3fCKdj7je82WVEjbbS9egeZ1l0UnngzrsHTc8qieipGaP4zRFe+iWh34z/sEjNZ2odCpKfjhBf9h8jyVwEDz9gT8U05O5fcHAiGqhDf9AnRFa+Tad6BYHXgGnY47tHHHbAb/m2DTRLIZEXRvmlkWvcQW/ueSeVLxfYWLoYfiWg9+I1eT8UIL51DdPkrGJqKZ8zxJs1vH3VwU+nzU4IfP4oWa8c1YjaBWZflPudOJdDgR/9CiwcRnT5kXymyp4jAkd856Ox98JMnEO0ubBVDkBw+4ps+JbzwOUrO+hWO/Tw0OxFd/Q4d795P0am35oJNLR6k4ZHvYynqTekFtyMIIoam0vjkD9DjIcq/8/c86mJ48X8IffJE3j66g4Cp7q7pez/rSX0CrK0Pd2HD9OC/hyAIKwzDmPBNn8f/NfTEBPlIZVS2NkeRJRFZFLDJEmnN/B0Xuqw0hpKsqQ3SHE4xptqPzSLjsclouo4kiqyqbefF5fVEUxkMXafYa8dttdCv2ENrPI0swJ72OPWhJHaLiFUWsYvw6V+vJhHuYNyND5OUHKQyBqoK4aUv0vTx43n3oeTu1bS88Au8U88hcNglgJnsND52A4aSouI7D+UVPvdF45M/QE/GKDzxZkSbA6WthtTu1SS2LUZPRpA8RXjGnoB7zHFIDm+X1x8sJjAMg+TWRXR8+E+0aCuuYbPwz7oc2VOYe71hGCR3LCU093GU9lqs5YMJHH5pjtFnBzTMDpTVAhYJqovdzB5UzKLlq/jwpSdoXjMXXVUo7DeSw08+F++wGZT6PKiCgcduwSpKuBwyqYyG0yKwvSXGxvoQoWQGhyyjCSAaOqIsUeK2UlXoJpnRKfNYUVQDVdOIqTrt0TRWQSSqZNjdkqAjY56XQwafU2Ri3wIMQ8RlEwnGVbY0R1EU0/e70GOn2GWjJZEh4LDic8qMry7EaZOIJhVW14ZpiaQo9Vpoj6RIGxI7mkLYrBKqpiNjoBoC0bS5rvQKWBhRWcTOlhAWWUIzdEJxhXG9C9AMsMkiwyu9rK4JkVENVF0npWrEUxrJtEo8o5j/UqDqkEolzLV0zbtkmrYhWOy4hh+Bd/wpn2mjan6Gywh9+iRK626spf0JHHFl7jPshBJqIvjRIyS3LUYuqKRg9tV5TDAl1ETwg4dNNlygAkTJ9Gt3BQ4aF6Rq1xPfNA9H79HIBRXoyQitr/0JW+kASs7+dbev0VMx6v/xXSxF1ZSef3uucND+zn3E1r5P2SV35WLL6Mo36Xj/7xQefwPuUcdgx+zwE2mj5tFrcJUPYNR37sBrt6IaJptl+oBiBpd5KfDYCbisYEChy4ZhGETTKn6njKqZsUI8rdK32N3Tyf4ScKC4oKeD/S3Ayyvrcp1oVTOY1MePzSJx/IjyLKUovyutG7CmLvyFjmmvHoV7zHFEl7+Ga/CMnOCTaLFRePwNND/7U0KfPknB7O8BUDD7ezQ8cg3t79xP6Xmm4Im99ygcA6cQXvRvXCOOylvAOqG07gbAM+HUHP3WXjUU78TTUKNtxNZ9QGzV27S8+BssRdV4J5+Fa9jhuYUSQzetwxweHP0moAZNpezOBFq0uymY/T3cI4+m44OH6Hj3fmKr3yZw1FW5OTDR5iIw6zI8404kNP9ZoiteJ7bmXTzjT8Y76YycQvr+EGQL7pGzcY04ikzDFqKr3yK+/iNiq9/BUtIX94jZuIYdhuT69iYt6YMozH8ZkASTJXEgaMkoic3ziK37kEzjFhBlnIOm4hlzPLbqkZ9JI9TTCaIrXiey9GX0dNysLB92cRfRs3TTdoIf/oN03Uaspf0pOvVW7FVDc89nWnfT8f5DpGvXYy0baPq4W2z4p59P8ONHiax4DeeAydirR2IYRpfzChx+ad7f3kmnE174HJm2Pd0m2Ho6TmjeU9iqhuX5Xne8/zC6kqLwuOtyxaTw4v+gtOyi+PSf5yXXSnstofnP4hw07aDJNZhFNt0wx0cMw8Aiiwwo9bB8TxDdMAXPXl5Z15Ng96AH/wMIJRUiKQWXTaY+lqbc58Qmi/jsFmRRpC6YJJRS0QyDuZtbGVzhpdxnpy2WYnzvQsq8Doo8VjRdJ5zIYJNFxvctwGGRSesq8bSGqht47DIWyVQeL/c5OfPG3/GPH11Iy4ePUnbSTehqhgxgnXA61k3z6Xj/79irRyI5fTj6jME14iiTeTNkBtaSfgiiROCoq2h54RdElr+Cbx818H2htNbgHnt87h5tLe6Da+hhFBxzDckdS4mufIvQp08SXvQC7lHH4p10BrK36JBiAkEQcA6ehr3vOMKL/0Nk6cskti3GN/UcPBNORbTYzG0GTMbRbwKxdR8Qnv8szc//DHvvUfhmXIxYNRQds7OaVMCpgizotMbTFPUZxKjzfszuwy4jvOYDwmve5eW//QKLw82gqUfTb8rxDB4zlhKXnYDLQnmpC7/TTmXAhdcus7Uphs0iEUyk0HUN3YD2qIIoxHFZLTSGDTpiKdoSCg6LRFpRERBIaxqSBDbMREuSwNANdrYmkAWBtlgaQRRIpBR0ARwWibaoQiieJqHohKJpQmmV+rY4cVUBJOyigM0qsK4+hU0SEUTVdCjRdQzdoEMB2TCwSeB1yNglidqOGD6nzKo9UcJqlvIthRnby8/ISj8pVSOcVIjEFdrjKfwOO6UFDkAnk1FpDCaJ7lpG8/IPCW1egKGksBRVE5j9PdwjjsyjdHcHwzBI7V5NaP7TZBq2IPvLs1TumXkNGj2dMD//ZXMQRBn/4ZfinXBaruijK2kii18kvORFBEkmcMQVIFnRU9FDigvsvUZ00Rpw9BlLqnb9gX/XC59HT0YJHPmd3L5SNWuJrXkX76QzcJUNQAZi4WaCnzyOvc9YXCNN4UFRBIdh0PDB/aBrDDrrJpwWiaRuYBMFCtw2Ag4bqgG9Ag5cVplwyvw8daBPoYuMplPbkSCcMpXlCt22Hm/srxA9Cfa3APvnJyv2BLPq1B388qThXxnFNzDrCpI7VtD21j1UXH5vTnHb3msEnvEnEV3xOs7B07H3GmF2+Y64nI53HzCT06zdUOCIK2n41/cJzX2MopNv6XIMIUs/N5RUl+dkTxH+aefhm3wW8c3ziCx+kfY37yK84Fl8U8/BNfxIBEkm3bgNXUlRdtavcq/d/2ZnLe1H6QV3mh3MuY/R/OytOAdNwz/rslwyJnuLKTrhRnyTzyS04Fkii18kuvINPONOwjvxtLxuaN41CAK2yiHYKoegH/Vd4hs/IbbuA4If/ZPgx//C3mcMrqGH4xw4OafM/v8Lukuu9UyS5PalxDfPI7ljOegqlqLeBI64EteIIw/4PuftIx0nuuINIsteQU9FcQyYhH/GRV2oV2qkjdC8J4mv/wjR6afg2OvyGAxaMkp4wbNEV7yBYHUQOOoqPONOou31v2ApqgbAM/4UEps/Jbl9KfZDSPoB9Oz3WZC7X5xC859FT0QInP2b3P4SWxeR2DIf/8yLc+JtmZZdhBe+gHPo4flemLpG21v3IFrsFBzz/YOeS9bWNMd4CSYyTOlXyJamqHlsw+w+/Gd5LWeMq+pJsnvQg285JFHAbZdRdcioGl6HhFU2bZW8DgsYBpqmUe53IIspbKKZfLdG0gTjCgGHhYHFHgpdVpIZnan9ChldHSCSzBBNpmkIRSjzWumIi9htEr0LnGiqTrPYm8mnXsbiOf+iYsxhFFaPI9WRQRElCk+4icbHb6Lj/YcoPvUnAASOvNK0Pnz7Xsou/iuCKOHoM8YsvC98AdewI7q1vxQsNoxM15hAkGScg6bhHDSNTOtuIktfJrrqTaKr3sI96mh8U89G9pYcUkwgWu0EDrsY98jZBOc+as5fr36bwOGXZudqRQRRwjP6WNzDjyC66i3Ci1+k+ZkfYe8zFt+0c3MJVMKAmrYEkbhCa0whpUDa6sMz8UwKJ56O1ryB2PqP2Dz/LTZ8NIe5xRWUjT6C4lGHU1zdn3F9Chlc7qfQ42SYKNIaS7G1KUUobiCKYJFBNQycskIso2ITRTxOmWgyTVoX8dgElLRGKGkmSwCZtJlgtzcksm8e2GTzf90AQ9UADUmAuAppi46hw7amOAkVAk6Ip6CswI7bJqLoBh67laRLoymYQRDBKYNVxuzqSgIOh5UCu4WacJJI1kZLAWraUnisITqiKZKKQlrVsVttCBYJh0MimsigtW5nw4L32broPZKhVkSrA9fQw3CPOhprxZDPXHcNwyC1ayXhhS+Qrt+I5Ck21/uRs/Pm6A1dI7b2fULzn0aPh8zRwMMvzc01G4ZBYssCk+kWacHWZywFR38Pa0EVra/e+YXiAkNJHzAmUNpqTautUUfnutS6kjKFTf1lFM+4AJcMimrQ8fa9ABQedz2CIOCWwG+H0IZPCW9bzpBTrmHq6GHYrRIb6iPYRPA5LOiGTnWBE0EQSao6pV47/n0FzdIq7bE0sijQK+AkklJ7EuyvED0U8W8BVuwJcv4/FqFoRs5j2cAMnG8+ZjDbmqNdPI47MbrK9EBeuo9VVKHLQiihHLSz2InkzhW0/OdXeCefRWDWZbnH9UyKxsdMa67yy+9DtDowDJ3m539OpmkHFVc+mFs4Q58+RXjRC5RecEeXil6mdQ+Nj15L4fE34h519EHPxTB0ktuXEl74vGnN4S/DN+08UwhC1xFkC4am5t1M9Uwy68l5wj6VyRSRpXOILHkJQ1NMOvG087pQwjOtuwkvfIHE5vkIFivu0cfhnXgasrf4s984INNWQ3zDx8Q3fYoWbgZRxt57NM5BU3EMmPSZ88T/l6AlwiR3LCexbRGpXSsx1AySuwDn0MNwDZuFtbT/IS1QWjxEZMVrpv91Oo6j/0R808/HVj4obzs9FSO85EWiy1/DMAy8E07BN/WcXPXb0DViq98h+OkTGOkEotOHvddIDCVFydm/Jr5pHsmdyyg8/kYEUSJVs5b45vm4hh52SAq4ie1LaX3pt5Se/8ecCngnMq27aXzsBtyjj8nZ22mpGI2PfB/R5af8krsRJDlLDb8ZLdZBxZUP5BUewkteJjT3UQpP+uFnigieNqaCgaUepvQzGSSLd7YTcFq7jJZ03k+uPeJ/Y7zh24geivhXg56YIB9pVaMpnELTDVKKhssmY2Bgl2XKfHYiiQxzt7WQSusUuC1YZRFZELFZJfoXu7FIpop1SlEQBZFyv4NERiOjaWxujJBMa7REUyQzKrIoUBVws6kpjMsiYpV07rj6TKKRMDN/9CgxVUK2iLSFVeoXPE9o3tN5VPH45vm0vXoH/sMvy2m0KKEmGv91DY4Bk3PJ+L5oevZWDCVF+aX3fOZ7oYabCS9+kdi698EA9+hj8E05B8np6zYmAIht+Bhb+aA8R4nUnrV0fPQISstOrGUD8c+6LCfQ1gk9kyK66i2TMZUIYes1At/ks7Lja92vXyJQbAebLKApSdo3LqRh5cdEd60BQ8daUIl/6FQGjZ/FzJlTSaQUtjRFaY2lCaZMKrqASUt3uwQkUULUtewwODgtViwWg2TaoC6s0GkPLQFOCaKa+XobIIngspk2V6GsgYTL3A02ixlbplRz9tklQUoDmwTFHglNF7DKAhlVJxTTkESza2oRQZTAZ7fSp8CFJsLmhgh10XwtGJ9svhmqZq41LitkGjYT3LyYpnXzSXY0IUgy5UMn4hwyA7F6MuluHDH2h6FrJLYtJrL4RTJN25A8xfimnGWq2u8zgrCX9v8ESnsNtsphBI68Ms+XPdO8k46P/km6Zh2CxYHkKcBWPgg9Gf3CcYFhGNQ/dAW28kEUn/bTLs+1vJCNnb/7j9xaH/zoX0SWzaH0vD/i6D2KIgfUL3qLxncfpODYa3Ne8xagXI6y+K6r8Ff05eLf/4silwOfSyatGjSGUnjtEoVuGyOr/IypLsAii1hEkY54mnhGw22TiaQUokmFcErFa7NQ7LVS4e+x7fqiOFBc0JNgf0vQKXLWGRgrqjmD3Tk3ecdbm3hnQxNjevl5a31T7vnnrpoCwHnZBN0iCZw9oRfPL605ZP/j9rfvJbbuA8r28wZO1a6n+dmf4h57AoXZLpoSbKTxseuw9xpJ8Vm/QhAEdCVFwyPXIFrtlF92b3410dCpf/AybBVDKM4KpHwWOmdrwvOfIdO8A7mgEv/0C7pQgABi6z+i/c27kH2l+GddjnPw9NxCqMY6CM9/ltja9xAsNrwTTsM76bQuFCSlrZbwkv8Q3zAXBAHX0MPwTjoda8mh+SAbhkGmcSuJzfNJbF2IGjaV261lA3D0HY+971hsFUO6BAH/yzB0jUzTdpK7VpLauYJ0wxbAQPIU4Rw4BeeQGdiqhh2yhZvSXkdk+SvE1n0Imopz8DR8U8/BWto/bztdSRFd+QaRxS+ip+K4hh2O/7CLc8J1hmGQ2rmC4MePorTXIAcqzG1mXAhA7b0XUHHVQ2jRNqKr3sLRbwLOgVNQI61El7+Gvd/4PJGxA6H93QeIb/iIXjc8l+e1bhg6zc/eitJeR8VVD+XmB9vevIv4hrl5M1ah+c8QXvAcxaf/PK97rbTX0vDYDTj6jaf49J8fUmHij6ePZHCZJ+dAIAoCumHk7gECYLP0zGF/UfQk2F8NemKCrtB0A003i+6xtJlWeewWpKz9pqrpKJqRtXDSSGZUvA4rdovJ3klmNIKJDFZJwCpLtMfSWGSR7U0xRNFgT0eCfsUuOqIptrXEUHUDuyzSHlPZs30tT916KcNmHs/gs28hk1FJKBod0SQbH7kFJdxCnysfRHP5MQyDtlduJ7FjGRWX3ZubnQ0teI7w/GcoOed3XdwPzOeepfLaJw65EK1GWkw7zbXvgyDgGXOcmWi78+9nupKi/qHvoKeipt7K9PNy92HD0Ilv+JjQp0+jRVux9xmL/7BLsJUP7LKP2Jr3iCx9GS3ahqW4D96Jp+MaeliXuXIRKHUJCIKBZogIuk4wCZl4kMjWRSS2LCBVs85Mtt1+CgZNwN1vLNbeo0nJfjTMBEoEnHZwWkUyio5FEvA6bJS7LTTEVJKZJA1hcgl2d5CBYgfElb2JdCcEwGM1k99w2ky6ZcwutYEp7KYbkDbM56TsOdktUOmz4nVa0XSRIo+FplCabS0xEvvprWrRNpK716DsWkFi90rUZMykXw8Yi2/IVOz9piLYPMQPITbVlRTx9R8RWfYKarABOVCOd/JZuEcciSDlfwap2vWEPnmSdP1GM1487BKcg6blx4Lznia29n0EuxtHn7HY+4zGM/pY4MuJCzKtu2l89Lq8xLgT8Y2f0Pb6nyk45pqcIF+6fhNNT/8Y95hjc4V4JdRE46PXYascSsk5v0UQTMcg2TCIvP572rev5vI/PcuQIYOpDabo5XfSv9RFUyhNqc/G5qYokigwtX8h0/oXEUmpdMQzuGwykaTZdPPZZdpjGRxWkT5FbuSeGewvjJ4E+38I3flgf9bz+z4G5ALtQ0my9XSChn9di2CxUX7Z3/IUwTs+/CfR5a9Scu7vczeYyPLXCH74j7yudGL7Elpf+h3+wy/tMnfV8f5DRNe8S9V1T+XNmH4WDMMguW0xofnPoLTuxlLUG//MC7v4Wid3rSL40SMobXuwVQzBf8QV2KuG5Z5X2usIffokia0LEe0evJPPwDPupJyAWyfUcAuRZXOIrX0fQ0lhqx6Jd/zJpv3TIfoiG4aB0raH5PalJHcsJ92wGQwdwWLHVjUce/UIbFXDsJUNzEvMvu0wNIVM807SdRtI1a4nVbsBIx0HBLOQ0H8CjgGTD7lTDWawk9q1isiK10jtXAGSBfeII/FOOiOv8wCmvUZ0zbtEFv0bLR7E3nc8gcMvzaOMZ5p30PbW31Badpr0sdnfxT5gEmL2swsteA6lZZdJ0S7qRXTVW6TrN1F44g8QBJGWF3+De9QxOAdN7XYOe99zqXvgUux9xnTpzkTXvEvHO/flhElgb7fbN/Vc/IddDJgz401P3oxr2OE5L1nIKpY//WPUYAMVVz7YJXg8EAYUuyj3O5i/rQ0DMzASs/PYkmgW3Xro4V8cPQn2V4OemOCrRUbVaQgnwYBkRsVpk8hoBpIg0BpNEUspbGmOsKctQbHXhqYazH32ft559iGu/M2DVI+eypKdQXa2JUg37WHXEzfi7D+RotN+hiAIWfHGa5AD5ZRd9GdTQ0XN0PDodYBB+eX358UVSnstDY98H/+sK/BNPuNzXYsabia04Hni6z9EkC14xp2Md/IZeWJoWjxIaN4zxNa+h2h14J16Dt7xJ+fWXEPNEF35BuHFL6InIzgGTsE/40KsJX3zjmVoCvGNnxJZ+jJK2x4kVwD32BPwjD4OyR3Ije/ZADm7XMgW0HTQVOjUSNdSMdI7l6PsXEZ05yq0pOn2YCnug716JJ5ew/FWDsVdXIjHYSGl6hQ5rRg6dCSS6LqEgkZb4sDvi4iZRFuAruR7E3bAZ4doBhJZcpMVsAhmYr1/8t7ZWbdYIOCx4bGIlPishOIqjaEYTQ2txOo3mjFBzXrUjjoAJJcf/4DxlI+Yil4xCl10kuTQoEZaiK56i9jqd9FTUaxlA/FOOgPn4Gld4rB0wxZC854mtXsVkrsA37TzzM52tpmhpxNEls4hvOQ/oGm4Rh5F4MjvINqcueL/lxUXBOc+RmTpHKqufSJPl0dLxWh45Gpkb7FpXydKppjw4zdiqGkqrngA0ebE0DWTIdq8k4orHyDgLyaRtTVLbXmfPa/8jdFnXs+EE89nYImXcp+VQo/T1MHRDDY2RBBkgV4FTgRgYp9C4mmVWFqld6GTpKLjsEikVA1ZFCn12rHKPcn1l4GeBPv/M+zbEQ8mMrn/V9eGeH9jc5ftOxVBPRNOpeCoq3KP524ESspUFbe7zS7dcz/L3gjuR/aaKoutc/5IcufyPKVxMGk5jY/fQOCIK/FOOv1zX4th6CQ2zSO04FnUjnqsZQPwz7wYe99xuZudoWtZsZJn0GId5oI58+IuvsfheU+T3Lkc0eHFO+l0PGNP7GK9paVixNa8Y9qARFqRPMW4xxyLe9Qxn5v2raVipPesJblnDematSjtteYTooy1pC/WsgHmv5K+WIqq8/yPvykYagalvZZMyy4yzTvING4j3bwDNJNzJgfKsfcaib33aOx9xhzSTPW+0BJh4us/JLr6bdRgI6LLj2fMCXjGntCFxq8raWJr3iWy5EW0WAe2XiPwz7woj66lhJoIzXuKxMZPQBCx9x6FYLHh6Dsez9gTMFSFxNaFRNe8g2vITKKr36bgyKuwFPai/d37kH2lWIqqia/7kMDs73aho++PTtbE/p2ZnEJ4cZ+cQqiWjND4r2sRnT7KL70bQbKgK2manrgJPZ2g/MoH8lXDF/2b0KdPUnTyj3ANO/yQ39NsUwvdMAMtqyV/Hrsnsf5y0JNgfzXoiQm+eiiajppluXV2rVRNpyOWYn1DhHV1QTpiaQpcVhRNJxRP8dDN55OOhbjpgTlsbs+woT5JwoDwkpcIzX2MwhN/gHvEUcDeLt2+riLdKY13ounpH6ElQlRc9fAhM53yrifrEZzY+CmC1Y534ul4J56Wt55nWncTnPsYqZ0rkDzF+GdcgGvEkXvtkdIJIsteIbLsFYxMAuegafimnddF66Nz/jdXDO4U6xx7ArZeI5AEIdfxlQRQs8nqvvK0IiYtO6JoZJp3kNq9mtSetaQbNmEoptWTzVdMoPcQAr0GIBT2wVLcF8EZwCoL2GWRmoh+QD0eCZNu/lnwiuY6oRhZVeqDwIpJIdcMAyPeSrRhN7TvQGnZSceeLaixDiBrsVY1DFu1GRO4S/rgsopkgITy2RpCncX26Oq3SW5fCoBz4BQ8E07BVjW8S1Kbrt9EaMHzpHatMGO5yWfhGXdCLn4yNIXo6ncIL3wBPRFCtLtxDjkMLRHE0Wfslx4XGGqGur9fbnaez/hF3nOdCuHll96dY+R117hKLXuZ5o8epfSEmygYNRuvy7R4y4TbWPDXqyjoNZDr/vQYTpsVySoyutJPkctOWzxNidvOij3tpBQVj91GgdvKwBIPTqtETTCBx2ah1GenwGXNsWIOtRHSg89GT4LdA8BMvM9+aGG3ne329/5ObNVb2bnSkbnH041baXrqFlxDD8sJmeWoLBVDKDn3twiCiBptp+GR72MrH0DJuX/I+wE3PfMT1EgLld/9539NlTZ0zaR3LXgOLdyMrXIYvpkX5s1R6ZkU0eWvEl7yEoaSwjV8Fr7pF+QUzAHS9ZsJLXjOvDnb3XjGnYxn/EldEkVD10huW0J01Zuk9qwBQcTRfyLuUcfg6Df+v7oOLREmXb+JdP1m0o1byTRtx8h0lqUFZF8JcmEVlkAlsr/MtJHyFiN5ChEdnv8qENkfhmGgp6Jo0XbUaBtauBkl1ITaUY/SUY8aagLDDA0Eix1raT9s5YOxVgzGVjXsv5otN3SN1J41xNa+T2LbItBUbJXD8Iw7waT170f50tMJcxZu+Svo8ews3PTz8zwx1Wg74UUvEFv9DoJkwVJUjW/q2TgHTSNVu5621/9KxZVmdVhX0rkOSmTF66R2raTkrF+hRttJbl9Cas9aPBNOzVMeP9B71/j4DaCplF/5YN53vPXVO0lsW0TF5fflRMxaX72TxNZFlF9yVy5oyy2u5/wWR99xuddnWnbS+MTNOAdOofi0Wz/zPe3snuQl1wJMH1DETbMH9STVXwF6EuyvBj0xwTcHVdNZtquDna1R6jsShNMZShx2gpkMyYad/OX6c6gccxgDz/oJdSGFNOb9vPm5n5Jp2Z0rshuGQdurd5LYtthMJrLd4LY37ya+cW72sb2Ja3zTp7S99ieKT/8ZzkHT/uvzz7TuJjz/2b0MtUmn4xl/ch5DLbVnLcG5j5Fp2oalsBe+mRfhHDQ1t55qqRjRZa8QWf4aRiaBo/9EvFPO6XY9UDrqTb/m9R+ip+PIBZW4Rx6Ne/gR2DyF6OQn1p2wYHa6Y9m/hew/j6QSrt1BunEzWssWwrVbyQSbcq+T7G7sRVU4iypRPeWIvlIzTvAUIbkKDmiF1h0EzKRZo2u3WlfSaPEgWrQNNdyMGm5BCzagdtSRaa9Dz3mWC1gLKrGUD8RRPghn1TBsxX2wyhKaAW4LWG0QTpqz4J1Jf3fJvxpuJrbuQ2LrPkCLtCA6fbhHHY1nzAld7LEMwyC1Zw3hRf8mXbM22yQ5A8+4E3OftemJPZfw/GdRw81m8cMVwD3qaBx9x31lcUGnJWfJeX/Ii0c77XC9k84wVcqBVM06mp/7Ge6xx1N4zDWmKnzrbmqe+AH+gePpc/bPcdhlil12hld6ePZ319K4bT03PvAi3qIqCt1W+hV7cFhE3HYZQRSIJVViiopLtlARsFPkstGWyGCXRVKqTqXfkRsd6cGXj68swRYEoRfwJFCKGe/9wzCMvx3sNT2L6TeLcx5amCeK1gk9k6Lx8esxtKwH9j6V4ND8ZwkveJaiU36Ca+hMYO9NJTD7e3jHn5z3WMGx1+WUxgESO5bR+uJv8qiz/y0MTSG29n3CC19Ai7Vjqx6Jf8aFeV1NLRE2VcJXvYmha7hHHo1v6jl5N+1041bCi/5NcttiBIsN98ij8Uw8LS8Z74TSUU9s7XvE1n+IHg8hOn24hh6Ga+jhWCsG/9fVQMPQUUPNKC27yLTtQWmrQemoQw02dlVeFyUkpw/R7kG0u02ak8WOIFvNBFWUsudhYGgahqZiqGkMJYWeSaKnYujJCFoiAnr+8irINmR/GZaCSixF1ViKemMt6YscKD9kenzXazNQWnYR3ziX+Ma5aLEORLsH1/BZuEcfi7W4T5fXqNE2oiteJ7rqbYxMwlRznXpOXsFHiwUJLf4P8TXvYGgq7tHH4hl/CpEl/8Ez5nis5YMQRInWOX/EWtof37RzMQw9F0wpHfUEP3qEolNvzaMtHgo6RyEKT7gJ98jZex/fuojWOX/AP/NifNPOBfbp6OzzWI4pMu5ECo7eqw5uqIrpDZsIU37F/V2KPQLgsknE0nvDFBE4f3I1wyt83eo29ODLR0+CfWj4vHFBT0zwzaI1mmLBjlbW7goSTCtomk5G0XA7JF594iG2v/0oQ8+9FXHADCwGRDRIhZpofOx6rGUDstadIloiTMOj1yI5s2KOsgUtGTHp496inNI4mMlQwz+vRrS7KLvk7i/cUctjqDl9+CafiXvsPl1NwyCxdSGhT59C7ajDUtIP/4wLcQyYlDu2nooRWfE60RWvoycj2CqH4Z18Bo7+E7usg7qSIrF5AbG175Ku25hlT43GNfwInAOndGHGgZnc7jsXndUFQ8H8kbhkUFWw6jEi9btIt+4m3VZDpq2WTLAh1zHeF6LNhej0mnGBzYVodSBYbGZMIMn7sPx00DUMNYOupjHSSfRMHD0ZQ0uE9yn0d0JA8hRhL6hALOxlxgXFfbAW9+lybfs63bgwxdYU3Sw0dPqJd0YzWjJKYutC4hs+Jl27HhCw9x6Ne7RJw96/2G5oKoktC4gsfZlM8w4kdwHeiafjHnM8ojX72eoaic3zCc5/Fi1Yj6WknzlCVjGY4If/+ErjAkNTafjn9xCdXsouvmvvd0lJ0fjo9YBB+RX3I1rs5jjmY9cjSSLjr7uP1rQdu6iw49Gb0eJBZv3oX8geHw6rhF22EFr5Bm/+4w4u+9HvmXHCuTjtEn0L3VQUOGkIJ7BJIhvqo1gkcNpkqgtcDC73YrdIRFMKsbSK2ybjsR96EaYHnx9fZYJdDpQbhrFSEAQPsAI4zTCMjQd6Tc9i+s1hxZ4g5/5jEeoBJMbT9ZtoeuYnuIYfSdGJN+Ue3zsbWk/55ffnPClbXvw16Zp1lF/6NyxFvfZTGn8gp8htGAZNT96MlghTedXDn6vqeiAYaobo6rcJL/4PejyEvfdofDMuwF41PLeNGm0nsvjfRFe/C4B75Gy8U87KS6IzbTVElrxMfONcMHQcAyfjHX8Ktl4juiz6hqaS3LWC+PqPSGxfCpqC5CvFNXg6zsHTsZYP/PK6zImwWUmOtKLFOtDiQfREGC0ZQU/HMdIJdCWFoWQwdMVUKen8PUsSgmhBsFgRLXYEq8NMyu0eJKcPyRVAcheY3XFfCZLL/6Wdt9K6i8TmBcS3LDBnskQJR7/xuIYfiXPA5G4/+3TDFiIrXiOxeT4YBs7B0/FOPjMnCAbmZ9n+zr2kdq0EA5zDZuGfeWHus2x/934kpx//zIvMfTZuo3XOH6i65nEMTcHQdbNIsvY9PGNPzCsAHdq16ea4RCZFxXf+nmMwaMkojf+6Jk8hXI200fjotciFVZRd+CcEUTK3e/Q6BKuD8svuyRsHCH78KJGlL1N81q9w7uOr3fnt6+7XKgC3HGuqgn+WbkMPvhz0JNiHhs8bF/TEBN8cNN1gW3OEmvYon2xpJZpWkEXY3WZ2LBvaY2x+9EckWmuZdNPD2PxFRJIKGQUaV75P69t/wz/rcnyTzwT2ak7s27XLKY3vQx8HiK37gPa37vnCXex9ka7fbM7l7lmN6PLjm3Qm7rHH7020O7ucC55HDTVmk63zcAycnFsD9UyK2Nr3iCx7BS3SghwoxzP+FNwjjuo2cVY66omv/4jYxrlo4WYE2WoKZQ2ejqP/RESbEysmjTvJXn9ciXyatgXzXq9izj6L2SVZEs1ucFJJoYVbUSMtqNF2tHgHejxkxgTJKHo6gaEkMZS0ueZp6t6YQDRtyQTJYhbmrXZEmxvR7kZyes2YwBVA8hSZHXJvMYJsRWZvt/vzWMZ2FhNkIJ2Mkty+hMSWBSR3rQJdRS6oxDX8CNzDj+zSrQZzXY2tfZfoijfRoq3IBVV4J52Oe/iRuRjC0DXimz4lNPdxtFg7otNH4Mjv4Bp2eO6z/Krjgujqt+l494Eua3cnU630vD9i7206jbS9eQ/xDR9RecGdVA4YSlKB5o8fpWXRywy64DbKR05HkASKXXY8qVae/OmFjJwwjbsfeQaHzYLfaYoYxjMa0VSGcp+DNbUhZFnCYxFx2CRGVQV6hMu+ZnxtFHFBEF4F7jcM4/0DbdOzmH5z+NmcdTy7pOag2wQ/fYrIohe6LHpKsIHGx27AVjGIknN/b1asY0EaHr3WFHC4+C8IkiVLH78WW+WwnBIimGJkLf++jcCR38E78bQv7Zp0JUVs1duEl7yEnghh7z0K37TzsPXa612oRloJL/4PsbXvga7jGj4L7+SzsGY9D8FM4KIr3yC2+h30VBRLcR88Y0/ANWxWt4uqno6T2LqI+KZ5pPasBl1DchfgGDAZR/+J2HuP+lbMVH/VMFSFVO16kjuWkti+1LQsE0RsvUbgGjoT5+DpeQI0ndCVNInN84muepNM41YEqwP3qGPwjD85rwCihJqILHkp+9lp2HqPRpAsWIuqcQ6enrPhyLTuoe31P1Ny1q/NgoEk0/Tcz/BNORtrWX8zoIq24p92fpcZu0NBbMPHtL/xV4pOvgXXsFm5x1tf/zOJzfMpv+RurKX9MAydlhd+QbphK+WX34slUJGlT95BYttiyi7+a17hIEcZG30shcddl3tcFgV+e+oI3l7fmBMv2xdWSeC5707tSai/RvQk2P8dPisu6IkJvjm0RVNsaIygqjpvr2+gMZhEN3SCsTRWi4ggSLTV7WbJ366mavBozv/FAzRHUmxujBCKa+yZczuJ7Uspv+SvuRnT9nfvJ7b6XUrP/0POxrD1ldtJbF9C+aX35NhLhq7R8K9rAai44v4v1WkjVbeB8PxnSe1Zg+j0m3Tiscfn04k3fEx44QuooUYsRb3xTjnLVArfp8ue2LKQyPJXyDRsMdeoEUfiHnNCnr4LmImkahik6jeR2PQpiS0L0OLBrH3nKJwDJuHpNwH2Y8jtm7SK7KWX2wGLZObHSf3Q5qu/LTAMA7WjnuTO5SS2LzU71YaO5C3GNXgGzmGHH1AQNd24jeiqt0hs+hRDTWOrHoV34qkmiyCbNBuqQmz9h0SWvoQabESQrXjGn4IaakL2lXxtcYGeSdLwz+8h+8oovfDO3PWk6jbQ/MytORo4QHzLAtpeuT0nduoWIFW7hh3P/oLySccz5tybkSSJdEbDLhssufd6Eh3N3PPCewzqW40hwLAyHwUuKwlFpTmcwiIJ1AWTWCURURToW+Qm4PrfEc/9v4KvJcEWBKEP8CkwwjCMyH7PfRf4LkB1dfX4PXv2fGnH7cGh4+dz1vHMZyTYhqbQ9PSPUMMtlF9+H7KnMPdcdM17dLxzL/5Zl+GbbHpedtJj9/XSjq56i473HsyzJQBofuE2Mo1b87wAvyzoSorY6ndMZdBECNEVwDf1HDzjTtpnbreNyNI5xNa8g6GkcQyYjHfS6XlCGrqSIr7xU6Ir30Bp2YlgdeAaehjuUceYNKNuFgUtFSO5fQnJbUtI7lpp0rslC/aq4dj7jMHee5S5oPyXdOtvEwxDR2ndTWrPWpK7V5GuXY+hpBFkK/beo3EMnIJzwOQugmWdyLTsIrb2PeLrP8rOsFXhGXdil+5Aumk7kSUvkdiyAEQRa0k/5EA5xSeb383Elvlo8RD+WZfnPpP2dx9AtDpwjz4W2V9G+5t34591ObKnED0VQ/wcKvb7Qs+kaHjkaiSXn7JL7sot9LlFc8aF+KefD0Bk6csEP36UguOuz9mAxNa+T/vbf+uisq+nYjQ8dr3pD3rZvbngrzO5vmByNSv2BLnwkcUoqo4kicwaVEyRx8aZPargXzt6EuzPjwPFBT0xwbcDzZEUu9pibGuO8MnGZpojSQwDMppKRgO/00Khy86uBa/y8WN3cva1P6Vk2imEYyqbG8M0dUSyzBxnjpmjZ1I0PnEjhpI2R17sbpM+/q9rkD1FuWI8QGLbElpf/l3eqNmXiVTdBkKfPk26dp3pVDFqNoHDLs2tBZ1d0Mii/6C01yD5SvFOPA33yNl5c9zphi1EV75BfPO8nIaIe/SxOAdPz1GV94Vh6KTrN5HcupjE9sWowUYA5IIqHH3GYO89Glv1yIM6q+ybcH/boSXCpPasJbVnDandq3J2pZaiahwDJuMcNBVr2cBu4yc9HSe+8RNia94l07wDwWLDNewIPONPyhsl01MxoqvfIbriNbRYB9ayAVgKeyFYbBQee93XHheE5j1NeOHzlF30Z2yV5py2nknS+Nj1YGSp4VYHarSNxkevM1X2L/wzgiQjJSPUPnY9ss3Bqb9+lLhqIaOquC0im95+nI3vPMXNdzzE6JlHU13gpHehm96FLsSs6EoyoxFLK9gkEYssIYlCjyr4N4SvPMEWBMENfAL8wTCMlw+2bU+1+pvBij1BXlpZxwvLatE+w79Laa+l8fGbsFUNo+Sc3+ytHHZ6Xm5fQtlFf8n5R7a/cx+xNe9Rct7vcfQebXbx/v0r0vUbKb/s3pztUqZ1D42PXY971DF53bruYBg68fUfk9i2CD0ZRbS7sAQqcY89Pk+lfH+0vXl3dpa5HiOTMH0RZ1yUZ/OgJcJEV75BdOWb6MkI1vKBeCecZm6TXfgNwyDTsIXo6ndIbJ6HoaaxFFbjGnEkrmGzkL1F3Z93Z0d353JSu1ehtJkFDcHqwFY51PxXMQRr+cDPZVv2TUFPJ8g0bSfdsDkr0LYJPWVKtcgFlWaw0G8C9uqRB+zYa/Eg8U2fEl//EZnmHSDJOAdNwzPmuDymgaFrJLcvJbL8VdK16xGsTjxjjsMz4VTUUCORJS9RfPrPECQLqdr1xDd+gnPQ1JxYmBptJ75xLqk9a9AirVjLBlBw7HXoaoZM7XrSDVtQOupy3ydHvwm4Rx/7mYWPTlZH6YV35kYQcuwNX6lpTSPJpv3WU7fgGDCR4qyFjdJRT+PjN2ItH0jpub/PO1br638msWmeuUBnK+4C8L3D+nHrCXuFVXoo4N8O9CTYnw+HGhf0xARfP3TdIJZWaIulaQwn+XRLM3UdSRpCMcIJjUFFLoIpBatVpshto7rQyaO/uo7NKxdw+q8exfBVsqUpTkTp1Ja4LY+Fk27cRtPTt+AcOIWiU29FEIS9xfip5xLIWhYahkHLC7eRadp2SIX3TFsN0RWvo3TUIQgikrcYR5+xB3VdaH/7XtRoO2qoETXYALIN7/iT8U44NWeFaBg6ye3LiCx5kXT9JkS7G/fo4/CMOzE36gZm7BBb9yGxte+idtQjWB04B0/HPfxIbNUjDjhqpXTUk9yxnOSulWZRWk0DApbi3qZ1Z+VQbOWDTN2TL2Fc66uEoWsoHfVkGraYcUHdxpxDimB1YK8ehaPvOBz9JyD7Sg+4j9Tu1cTWf0Ry2yIMNYOluA/u0cfiHnEkos2V21YJNhBd8TqxdR9gZJLYe4/BO+Us7L1Hk67b8F/FBYGjv2+6pdRtJNO6Cy3ahmGAtbg3ngmndqvFsy+UUBMNj3wf56BpFJ/yo9zj7e/cT2zNu5RecDv2XiNMQcAXbiPTuCUXC5tx9B9Jbl/G5Ov/xpgx45GtIo0dcZq3rmH+/T+g39QTOPem3zJraCkBl50ij42+Rd/+ePH/R3ylCbYgCBbgDeBdwzDu+qztexbTrxcr9gR5eWUd/1lei6obCMABRrDz0NmF3p/SrSWjND52A4IsU37p30w1xkyKxiduwsgkKL/8PiSnb+8cakEVZRfemaN/mbMpr1F28V9ySUV3CH7yBJHF/0H2lyNlK41KR72pRtqNQBaYaqGR5a9QcPQ1SC5fTr1cj4eylenTcY+anUsEdSVFfN2HRJa/ihpsQHIX4B5zvFnp3EctW08nsgnih6TrNwECtl7DcQ09DOegaQfs1oKZiKVq1pKqXW8uRG01dJLCZH8Z1pJ+WEr6Yi3qjaWwF3KgrIvQx9cBQ1NRw80o7XUobXtMm66WXTlvSzCr7/aqYdh6jcBePTIv8NgfWipGcusi4ps+NVXYDR1raX9cI2fjGnZ4vm9pIkxs3QdEV72FFm5G8hbjHX8y7tHH5hZaNdJCeNF/cPQZi3PwNLR4kNi6D00l+DHHYd7LDARBNIsAmorStJ3EtkWk600/ckQZS0EFosOLFg+hdtSZxZ7jbzjgdSjtdTQ8dh2uITNzntWGYdDyn1+Trl1H+WV/w1LYy6xcP3ETRiZF+RX3ITm8GKpC09O37GWD7FOU6aSc79v9zr3PosAL3+uhf3/b0JNgHzo+T1zQExN8vYimFBpDSZojaaoK7IQTCmtqOwjGFGqDSXa2Relb5MHrkGiPpRlc6kMXDaLtHfzp6lMR7V6m3Xg/zTGdtqxqV3DuY0SWvET5aT/FO2Q6KQNiS16kfe7jeaKnnXOopRfcgb1qGABKW615jx02i6ITf3DA81ba62h84iYQBKzFfcHQUcJNOAdOofDY7gv2+8YEsreI5n//Ej0VI9O0HQQB1/Aj8U0+I+f8AKYWTWTpHBLbFgOYllzjTswvBhsG6boNxNZ9SGLLfIxMEsldiHPoTFxDZh6Q7QZmET7duIVUzTrSdRtJN2zGyKp0C1Yn1tJ+pnVncR8shdVYCquQHJ5D/4C/JBiGgZ6MoLTXorTVkGndYwqytuzMibCKNpdZHOg1HHuvkVjLBhyQ6m/oGun6TSQ2zyO+ecFeC62hh+EeOTuvw20YOqmdK4muepPkjuUgSriGzsQ78fQ8KvfniQtSNetMwbjGzSS3L0NPRQFMPZpsISDTvANBtlF+xf157M3935fWF39DqnY9FVc9hOwx1/VOEdR99Qc6rTcLj78R96ijgb2xda9jr2DSyRdT6nXQt8DJmt2NvHTbJUiyzNm/fZKSAj+zR5ThslkRgKHlXnzOHgr4tw1fpciZADwBdBiGcdOhvKZnMf360EkvTSt7/RNFgW5tuvaHYRi0vvx7krtWmDZD+1hspGrX0/zcz3ANPYzCk36IIAim3/VTN+PoM5biM3+JIAjEN82j7bU78U07LycyoacTNDzyfUSHh/JL7+n2ZmwYOrX3nIuj3wSKTvnx3puupppiHd1UeA1VIbJsDoam4p9xAXomSXTF64hOH5LdQ3jJS2QatyDYPVhL+yE5fbhHzsbRdxy6rpHatZJo1qoBUcI5YDLuMcdj7zM673hKsIH4xk+Ib/zETD4FEVvVMLNiOmDyZ1Y+9XScdOM2MlmbrkzLLtMaq/MTEkRkbzGyvwzJW4LsLUJyF5oCJE4vosNrqoXanCBZDqq+ahgGaCp6JmGqiKdMxVAtHjSF0yKtqJFW1HATargF9L2TXpKv1PTqLu2PrWwg1orBn7nIq9E2ktuXkti6iFTNWtA1ZF8pzmGH4xo2K2/mPRegrH6H+JYFoCnYeo3AM+4kU010v66yoWs5ClnB0VcjSBY6PngYyVOMa+hMYmvfwzFwCpnGbaZCad0GACwlfXEOmIy9zxhs5YPzRNY63n+I6Kq36HXj893O2RuGQfNzP0Vp2UXFdx7KdTsiK14n+MHDFBx9NZ5xJwFZO5r1H+XNHXYKnRSf8QucA6fk9qt0KvAW96X0gtu7XKsowA+PMQXMevDtQU+CfWj4vHFBT0zw9UHXDWo6EiAYNASTuO0yTotMeyzFjpY4mq5jlQTa4gpWi0AqreF32VAUlUKPnffefYcHf/ZdqqefSuHsqwimzJXL0FRanvkxSrCe3pfdi+YrzepR/JJ0/SbKLr0ba1E1ejpB42PXYwAVl9+bK6AGP3mcyOIX80Sh9kdw7mNElr9K5fceySU1YK793YlnHigmkFx+bL1GEFk6h/j6DzHUDHJhFbK3FM+EU3D2G49hGGiRFqIr3yS29j30VMwcZxpzPK4RR+QViHUlRXLbEuKbPiG5cyXoKpKnGOegKTgHTsVWNeyg8+WGrqG015Ju2GLGBM07Udp25/yxAUS7BzlQhuzda93ZKUwmOjxZFXEngtX+mR1wQ9cwlDR6ujMmiKAnQmixDlM8LdJqCqyGmtDT8dzrBKsDa3FfswBQNtDsuBdWHvR4hqqQqllLYttiktsWo8WDpghc/4m4hh2Oo9/EvM9OiwWJrXuf2Jp3UcPNiC4/ntHH4x57fLcWoZ8VF0TXvIelsJL0nrXEtyzASMcR7W4c/Sfi6DcBW/XIvP0qbbU0/Ov7+GddgW/yGd1eU6do377Np05Gm+QppPziuxBkC6m6TTQ/+xOcg6fn4thM626anrwZT+8RnPDDuynzu7FLGmlN5OW7fkLNyrlMuu5vDBgxliElbnoVuRlR5SfgtGC3yFT4Hd2eUw++ORwoLvgyFCWmAxcD6wRBWJ197GeGYbz1Jey7B18Qi3e2k1H3JtcCYJVFLpvah3/M23nQRFsQBAqPv4HGx66n9dU/UX7pPblZI3vWlzg8/xnsfcaY1cfSfgSOuJLgBw8TXfYK3kmn4xo6k+TO5YQX/ducRe41AtHmpODoq2md8wciS1/OUxXthKEqGJkk1pK+eQnkwRYpLRlBDTXhHHqY+Xc8hJ5JIDp9OAdPwzFoKskdywh++E/SWV/r5K6VFB59Dc6hM3H0HYez/0TTkmv1O8TWf0hi60IkbwnukUfhGnEUFn8ZlkAF/unn45t2HkrrbhJbFpDYupDgh/8k+OE/sRRW4+g/AXvfcdirhiHI+RVH0ebC0WcMjj5jco/pmZRZJe6oQ22vM32pw00oO5ejxUMcULtTEPNsunKKKbqGoavmAm0ceIpLdPqRfcVYSwfgHDzDtOkqqMJSVN1twrk/DE0h3bCF5K6VpHauMOnfgByowDvxNFNZfb+5KzXSRnzjx8TWfZCl2DlNcbOxx+cxEwxNyevkC6KEo/8kc6ZuwfMEDrsYQ0kj2l0owQYyrbuJLHnJDJYKqvDNvAjXkJm58YTuYCkxuyB6Otbt9cbWvEu6dj0Fx16XS64zLbsIfvyo6Yk+9kRzu/UfEV//Ib5p5+eS68T2JUSXv2oWDPZJrg1Npe21P4MgUnTyLV2S687f6JR+3VfPe9CD/wH0xAXfUgiC+U8SRFw2C4qi4XRL9CooYGiFHwCLJNAQStIWS7GtOUo6o5LRDFx2kcETZjLtlItY+NrTOPuMwlM9mbgOhiQz8Owfs+nhG2h5809Un3cncVGm8KSbaXzsBtpevYOyS+5CtDkpOvkWmp75Ce3vPUjRSbcgCAK+aeeR2LyA9nfuo/yK+7odN9ISYSSHLy+5Bg7oSnKwmMASqKDw2GvxTDiV9rfuIdO8A7W9jlTNOnxTzsY7+XQkTxGBI67AN+NCEpvnEV31NsGP/knwk8dwDpyKe+Rs7H3GIFrsuIYdjmvY4eipGIltS0hsXZCdF34dwebC0Wcsjn7jsfcZ22W8TBAlrFkLLLK6HYaho4ZbUNpqUDvqUIINqKFmMi07Se5YiqFm9r/cfd4Pq1l8F6W9UuS6bibWagY05cCvtdiQPMVmYbxiCJZARc6+U/IWH5KdmhJqIrV7FcmdK0jtXo2hpBAsdhz9xuMcNC2nrN4JQ1NI7lhObN0HJHcsA0PH1msE/sMuyR/b2y8m6HzvuosLAGLrPyS+/gO0SCuCxY5z0FRcQw/H3mfMAWNJubAKJBk9Ge72eS0ZpeP9h7CWDcCT1QwwDJ22N+/CUFIUnfyjrD1dlLbX/4TsLabwuOsQBAE9k6Lt1TsRbU56n/oDUqqOzyEzuNTHKy8+x57lHzL1nGuoGjwSu0XEYpEIOCxYJREQcFn/9zV8/n/CF06wDcOYz143mR58yzClXyFWWcwJJJ01vionjnT08DIW72wn4LTy69c3kFG7JmKS00fhiTfT8sJtBD/8Rx6V1jf1HFI16+h47+9YywdhLarGM+4kUjVrCX7yuDlXVDGYgtnfI12/kbbX/0L55fciObw4B03FOXgGoQXP4hgwuYsipyBbER1elGDDoV+sYZBp3kFg9ncBUDvq0ZNRbFnFZ0EQUNprsPcZg2v4zeaCufpt2l7/E9ZlL2OrHkWmcSvOQdPwTjkL/2GXkNi60PTcXvA84QXPYasahmvYLFMZ2+kzO7wlffHPvAgl2EBy+1KSO5YRWf4akaUvI8hWbJVDsPUaib3XCKzlg7r1WBStdmzlA3Mz7XmXpalmxzkeMm26UlH0VAwjk8zadKVBVzF0zUyuBbILq2y+jxY7gs1pdr0dHiRH1pLDHfjcVHRDVUg3bSddt4FUzVrSdRtNqpggYqsciv/wy3AMmGQKj+yzEJtBx2LiGz4mtWctYGCrGoZvytk4B8/Y62epKmZgsvItZH9pjpLdCdlbhH/GRYTnP0P9P74Lokxy1yq0aCuizYVr5NG4Rx51QDGV/aEGG0x/cVdXKrYaaSH48b+wVY/CnQ169EyKttf+hGR3U3j8jeZ3qq2WjvcewNZrBL7p5+Ve2/7m3VhL+xM44sq8/YbmP02mcQtFp/ykW3uSGQOLuGn2oB56eA/+Z9ETF3x7IQgCZT47HfEM5X47AYcFi2wG7jbL3gC+zOcgmlaY1LeQ5btDNLREUepNevmIU77HuuWL2T7nb4z4/r0UeIpIKOAtrWT0ebew4onf0jrvSZyHX4HsLqDopJtp+fev6Hj/YYpOuBFb5dBcgT7eZ6wpKGaxU3jc9TQ//zNCnz5FwVFXdTl32VNs2lWmE4dUAP6smAAguX0x1pK+FJ18C+madQTnPUV4wTNEl7+Crc8YtEhbLnl2j5y9V6Rzw1wSm+chuQImQ2vo4VjLBpiz2yOPwj3yKPRMktTuVSS2LyO1awWJLfPN6yiowl5txgS2quHd6rkIgmgW9P1lwKT9LsswO8+xjr3WnamYadOVSWKoWZsuXTPtO2E/my6b6Zdtc5k2XQ4vktNvxgQ21+fyJDcMAzXcbFLda9eT2rMmJ3AmeYtxDT8Cx4BJOHqPzms2dIrAxTd+SmLzPPRkBNHlN624Rh6NpbAqt22mZSfRVW+T2LKAiu/8vcucfmdcEJr3FHUPXoahq+hr3wcMbNWj8M+8GOegad2K0e0PLdoKmork6V5jJ/jBw+ipKIXn/jZXHI8snUNq9yoKjr0Wa1E1hmHQ/tY9aLEgZRf9KcfS6Hj/7yjtdVSd+zvKyotxWmVqgnFcyRZeefCPVA4dT/mMMzEAURAocNnoW+ShyG0l4LJht/Qk2P9L+PI8EXrwrcT43gGe+c6UbgWSxvcO5P7+eEsL729s7nYfjj6moERk8X/M5DRbDRZEiaKTbzGr069kq9NWO4XH30jj4zfS+uodlF92L5LDQ9EpP6HpqVtMz8szbkMQBAqOvppU7Tra37yLsov/mldRFAQBe/VIUjtXYOjaIalva4mQSZOy2NFSMRI7liIHyvO6ooktC/HPvAh71VDsVUNzC4EaaiS69GWQbWiJMPGNcyk+/ee5hVWNtBDfMJf4ho/peO9BOt5/CHufMTgHT8c5cApStiJumXga3omnmQtr7XpSu1eTqllLeP6zhDFAlLCW9MNWMQhr6QCsZf3NZPQgia4gySZl/CDzzl8FTNpaHZmWnWSylPZ083bQTFdMS2E17pFHYe9tqqTvK0oC2aR6+9Ks9+UK0FRkfxm+aefhGnFEnlCd0lZLdO27prJ4MoIcKMc5dGbXc9JUlNbdaKloVpXVwN57FIFZl+IcNK0LW+Cg12cYJLcvw1Y5tMv7bxg67W/9DQyDwuNvyAUcHR88hNJeR8m5v0Ny+dGVFK2v3o5gsee60Yam0vbqnzB0zaSF7dNdSe5eTWTxS7hHHYOrm+uTRKEnue5BD3rwlcImS5T7Dkw1zag6LZEUHVGFlFWjJZqgxGPF77TQFI4T0QxGXfwLFt99Nbvn3MWIK27HbjGwyQLO4dPoN+Nkds5/mdLKEdgHTMLRdxzeqWcTWfRv7NUjcI84KlugX0vH+3/HVj4YS1Ev7L1H4R57ItHlr+EcOAV79ci887L3Hkl44XMkd63ENWTGZ17n54kJOpPZxI5lSA6vacO5ZSEYGmqkhdiadyg+4zasJX0pmP09ArOuILljGbENHxFd8QbRZa8g+8txDpmOc/AMrKX9Ea0OnIOm4Rw0DcMwTPeN3atJ7VlDfONcYqvfBkDyFGOrGIy1fIAZF5T269bashOCICA5PN/IXLYWD5Jp3kmmeQfpxq1kGraYdmSAaHebY14TT8PRZwxyQVVesm4m1VtIbJlPYstCtGirSRcfMBnXiCNx9B2Xi/X0TJL4pnkm/btxC4JsxTnksG4790p7HbG175HcvRojGUFyF+Kbeg6uUUd/5sje/khuXwaQY6Lti8SWhcQ3zsU3/YLcyGS6fjOhT5/EOWga7tGmzkB02RyS25cQOPIqbOWDANPzPb7+QwqmnY+rzxgEXSCd0VCVJHf86hoM0ULfM39IW1yhOmDFYZHQdJ2UplHgsuUVv3rwv4GeBPv/A+ybSHeHFXuCfLK19aD78M+4kHTNOtrfuQ9r2UAsgXIAszp98i20vHAbHe89QOGJNyPZ3RSf+hOanv4x7W/eRfGZt2ErG0DgiMsJfvjPHH1ccvkpPPZaWuf8kfCC5/BnVUU74Rx6mJmc7VyOc8Dkz7xOa3EfJG8x9f+4yqQ5F/bCO/7U3POZ5h0Ikoyt0hRW6VTCLDnzNuRABanadcTXvk9883zQFJqe+mGW5n44IGLvOw7vlLNRWnYR3/QpiS3z6XjnPjrefQBb1TAc/SfhHDARuaDKXFj7T8TZfyJg0opMBe6NpBu2EFv/EcbKN80TE+UcBctSWIWloBLZX47sK0V0+j5XNfnzwqyER1FDzaihRpSOetSOejJtNaYqaJZKJlhsWEv74x13Mraqodgqh3Ur7KaEmswu/valpGrXmd7gniI8Y0/ENfSwPOEXLRUjsXkesXUfkGnYcvC59/ZaYmvfJ7bhI1Owzl2Ad+rZZqU7+138vMg0bEFpr6FgwrVdnouueIPUnjUUHHtdboGOrf+I+LoP8E49F0efMRiGQcd7D6K01VJyzm9ytMXgJ4+TbthM0Sk/yaOna7Egba//BUthr1xHZX8IBxoF6EEPetCDrwnxtEo8o2K1COxoi2ERRUIJjY0NYVKKRjqtoruKqTrx++yZczftC/6Nd9q5xGJm8bXkqO9Qv20jrW/eRfll9yL7SswYom4jHe89iLW0P9biPhSddAuNj99A66t3UHbJXxEtdgKzLie1eyVtb95NxRX35RVubVXDkdwFxNa+d0gJ9n8VE7TVEDjzNiwFlaixduLrPiS29j2UthoaH70W14ijTEtJTyGSr4SSM35hrmVbFpLYPI/IkpeJLH7RnMEeOAlH/0nYq0ciyNYc48076XQMXSPTsot03QbS9ZvJNG7NdbjBFN6yZIVPLQWVyIFyMy7wFn3lIqiGmjG1WUJNJjW9ox6lvYZMaw16IpTbTg5UmNomFUOwVQ3DUty7yzy2nkmR2rOG5A4zLtDiQZBkU5Ts8EtwDpicYyMYhk6qZt1e4TglZa6XR16Fa8SReQUFPZMksXk+sbXvk67fCKJkjm2NPjYvUf9c120YxNa+Z4rL7aMXA+b63f7u/VjLBuTGGrVklNbX7kTyFOUK8am6jQTnPo5j0FQ8E04BINO6m473/o6jehQF08/DaYOGiIIro9Pw9oN01G6n33m/Iil5cQAZ3cBlszC80hQ1075EO+UefH3oSbB7wOKd7Sjd0MP3hSDJFJ3yYxofv4G21+40vfyynTlHnzEm3WvBs9iqhpvWS+WDKDjqO3S8/xCRRf/BN+1cPONPIVW7nuAnj2OtGIK9aijOQdNwjZhNePF/sPcdi73XiNwxnQMmI3mKiSx5CUf/SZ+ZaAqSTNEJN5Ju3IqejOLoN574pk+RAxXYygaghluwlvTF0M0gILFlgbmIFVRiGLpJpTIMAkdcSWjuo4g2pzlX/dEjSB5zUTM0lcARV+A//FL8h1+K0rKTxNZFJLYtJjT3UUJzH0XyleLoOxZ7n7HYq0ea9CuHB+eASTgHmFQvw9DNRLZ5J5mWXaZqd+NWEpvns++8tSBbTYEzdwGSy28KnOVEzhwIFnt2BltGECRzuM4wMAwNQ1Mx1AyGksbIJNBTcbRUDD0Z2StyFmvPE1IBk9ZlKeyFo/doU928tJ/ZZe9mwdLTcVI160ntWU1y10rUjnrApMB5J56Gc+BUrBWDcouurqRJ7lxBfONcEjuWgqZiKaomcMQVuIYfmZe06+k48U3ziK/7gHTDZhBEHAMm4R51DI5+47+wp3h46UsmrXxovr1LpnU3wbmP5RZr87E9ORq4f8YFgDmfHV//Eb7p5+fsQBJbFxJd9gqecSfmdagNXaPtjb9gZJIUnff7A9qZGYb5e+zpYPegBz34piBLApGkgt0iUeF34bJa0DWdmo4odcEEmq5hlSWGTTsWsWEDm995irG9hqGXDMcA4opE2am3UvvEjbS+ejtlF/wJQbbkYojWV26n/JK7kT2FFJ30Q5M+/t7fKTzhJkSrnaITf0jTMz+m472/U3TyLbnzEkQJ99gTCM97mkzLzjzh1e7wRWMCpc1Uz/ZOPYfgB//AXj2S+PqPiK1+2xy7spqJYeDI7+AedTSe0cegJSMkty0hsX0JsbUfEF35JoJsMynhfcZi7zMaS1FvBFHCVjYAW9kAmGAm/VoinI0JdqK07kZpqyG29r2cYnf2qpBcfjMmcBcgOnxIDg+i3Z09JweCbDMZgZK8j8WqDppqUseVNLqSQk/HTWp5MmKKn8aCaLF29ET+/LFgdWAprMLRf4I5K17SD2tpv279ozsLB6k9a0jtWkWqbj1oKoLVgaPveJyDpmRnsF3Z8zLItO7OCsfONeelrQ5cQw/DNfJobJVDuldu3zwPQ0mZNqyzLsM94qhuR70+D1J71phiacde26Xz3vbW3eaM9Yk/RJBkk+X25l17aeB2N1o8RNurdyD7Sik64SZz7jqdoPWVOxBtDipO/hF2WcJpNeO0xPZF1C58nbJpp1M+YjIaAgUOK26rjMsmkszoaJp+SKLEPfj2oSfB7gFT+hUiiQLqZ/yKZV8JhSfcROvLvyf48b8oOPrq3HO+aeeSrt9ExwcPY80uGu6xJ5Kq30Ro3tNYywfh6DuWouNvpPGJm2h79Q7KL/sbkstPwezvkq7bQNvrf83OaJtVSkGS8U4+g+AHD5PavSqXxHwWOik5AJaCqhylyFo+kNj6D3NU9OiK13GNPDq7pYCtcoipdrllPmUX/RlraX9iGz8h+P7fMXQNLdIKkoXoyjcxdA1n/4lmJb60P/6ZF6FGWkyfy50riG/8hNjqd+j0uTRnrYZhqxyC5Ck256sKe2Ep7JXn32moGVPgLNSIGmo21Txj7WixDjKte9Czs1YHEy47IATRTMyd5ryVtbQ/Uv+JyN4SZF9JtkJedsDkD0CNdZCp30yqbgPpug1kmneCoSPINmzVI/CMPRFH/wl59G9dSZPctdIUg9u+BCOTRHT6TTXW4Ueath77+GB38cYs7IV/1uW4Rxz5hRfQTqSbtpPcugjftPPzZvl0JUXbq39CtLtyM9bmAnk7gtVhUr5FiXTjVjo+eAh733H4pplz10pHPW1v3oO1fCCBI76Td7zwon+bHfHjbjigxRyYpZXVtSFW7An2JNk96EEPvhG4bTK9Cpw0R1IUua2kMwrIEl6HBcMQiGVANAwM0cqYc26iZddm1j5zB5WX3YvdE0AWIFBSTvKEm2iZ80c6Pvonhcdcg+wuoPiUn9D8/M9pf+seik77KY7sPTS88DlslUPNAn3lkL0iqn3H4h5xVO7cPONOIrLkZULznqHkzNsO6Xq+aEwQX/8RZRfeibW0P8k9a2l7/c/IgXIydZsAg+AnT6J01OMaPB1LYRXuUUfjHnU0upImXbOO5E7T/zq501TKF50+7FXDsVUNx1Y11BRzlSxITh+OvmNx9B2bO1/DMNBi7ajBRtRQE2qkFS3atjcuaNmFnoxmfbU/PwSLDdHuRXL5kD2F2MoHInmKzJjAX4bFX4Ho8h/EcixDummHyc6r20C6bqMZowCWrC6Po98E7L2G7xUryybViS0LSGxegNJeA4KIvc9YU9hs0NS8OEQJNRFf/xHxDR+jhhpN7/EhM3GPmo2tctiXwvIzDIPwvKeR3IV53zeA6LJXSO1aScEx12ApMu3cIotfJLljGYHZ38NWPghD12h97U70VIyyi3+NaHOZs9hv/w012EDleX9AdAdQdUhmDDLhJja98Be81UMYeuJVCJJEic/GwFI/iqbhcVhZsTtIMJnGIolUBpw4ekTO/qfwpfhgf170WHJ8u7BiT5A7397Esj1BMMzqtXIQo+xO66GiU2/No2lpiTCNj98EgkD5ZfeYc0yZFE1P/RAtHqT80ruRfaVkmnfS9PQtWCuGUHru73IJS9PTP8LRfyLFp/98b8KlKjQ8cjWC1UH5ZX/7Ql1LXUnR8f7DpOs2mMltUS8Ch1+Goal5898tL/8ez5jjcfQbT3DuY2RaduOdfAZGJkly1yriW+ZjJMIgWUw18AGTcfSfmOeZaGgq6cZtpGrWkK5ZZ/pcZjvFkiuAtXygmZxnPbBlX8lnWmvk9m0YpsBZJmF2p9WMaV+ma+RUzkTJ7GrLVgTZdsj2HfseQ4u2kmnZbc5gN+8g07gNLdpmbiBZsFUMxt5rBPbeo7BVDM232oiHSO5cQWL7YlK7VmbVvj04Bk7BNfQw7L1H5T5LwzDING41K9ibP0WP7+ONOeKog/qJ/jfIWW+11VD5vX/mURDb376X2Nr3KDnntzj6jsMwDNpeuZ3EtsWUnvd77NWjDuF7fk+eeFly92paXrgN1/BZFJ54c961SKKA1k1hSxLgd6eN5ILJ1V2e68E3hx6brq8GPTHBtweabtAaTRNJpTF08Dut6LpOXTDBO+tb+GRzM+FUEsMAmyzis9vYvX0Tmx/5IfbKwVSc8ztskoRVhqACwY8fJbL0ZQpPvBn3iCMBiCx9meDHj+I//DJ8U87C0DVa/vNrUrXrKLvwT7mEpfn5n5Np2k75pXfn+VR3egsfzNLrUPBFYgLflDMRbG7SNWuJbfoUpXErYGqTOAZOxjlgkrl27ROzqJEWUrvXkKpdR6p2A1pWAwbJYnaFywbs9cAurD4kUa5OGKpidqSVlBkTqBkzJugsxgtitqttQbTYECx2My74HHRzPZ3I+mHvItO0g0yzaTPaae8pB8qx9xqJrXok9upR+TGRrpGu20hi+xKS25aghhoBAVuv4biGzDSFY/dhsGnxEIkt84lv/IR0/SZz2+qRuEcelRUs+3LtqjotZQuOux5PlrkGkG7YQtMzP8YxYBLFp/0MQRBI7lpFy39+hXPITFN7RRAIfvQvIsvmUHjiD3IJemTpHIIf/wv/rMvxTz4TAfBZwWMVWP3QjaSCzVx2x9OUVfTGahM4a2wv2uIKW1qiuK0yGxvCDCnzUui2IogCQ8p8VPodiGKPfuS3CV+ZD/Z/g57F9NuDZ5fUcNur63NBviTAVTP78cnWVjY1Rbt9jaEpND/7UzJteyi/9J68OdN041aanvkx9l4jKTn71wiihNJRT+OTN2Pxl1F64Z2IFjuxdR/S/tbdeCeeTuBIU2U5suwVgh89QuCIK/FOOj23z8SWhbS+8kcCR16Fd+KpXc7n80IJNaFF27D3GmEKce1YhqP3GCR3AEPN0PzvXxI48jvInkLq//FdPGNPwNBUMs07KDzhJmRvMen6TSS2LiK5bXFOKM1S0hdH3/Em1b1yWF7CaWiqOW/VsDnnf62019FJBxcsNiwFVcgFlVj85WY32VeC5Ck2aWDdKI9/URiqYlbBo22o4Razax5sRAnWo7TXYWSSuW3lQHnO99JWMRhr6YAu15du3EJq1yqSu1aQadwOGEjuQjPYGDjVnEPLBi1GVt01sWU+iU3zzPdQknH2n4Rr+Kwu3phfJmIbPqb9jb9ScOy1eMYcv/fx9R/R/uZdeKecTeDwSwEIL36R0CePEzjiCryTzsDQVJr/fRvp+s2UXfRnbGUDzCT89b+Q2PQpJef8Jo9poUbbaXz8RiSHNycC2Ilh5R7CKZX64N73eV9IAvz76mk9nexvEXoS7K8GPTHBtwcNwSStsSTJtIaiGUgSOC0ysizwyooa3llXT0vUQBbAKoPNLpLJ6DQu/4Dmt+7BN/VcyrN6KgaQ1jUaX/gFmYYtlF30F6yl/cx75mt/IrFlASVn/xpH33FoyQiNj98IBmbh0ulDjbbR+NgNSK5AbkYbTFZUw7+uQZRtlF/+ty88k/xFYwKLvww10pobF0vXrgdDR3R4sfcdi6PPOOx9xuQlnGBaVqYbNpNp2EK6aRuZ5h15667kKzVnxwsqzPlrf6cHdhGi3fOl67MYho6ejO6NCcKd2iwNKO11psJ2FqLNhbWsP9byQdjKB5vsvP0YZmq4heTuVaR2rTQFyNJxkGTs1aNNj/ABU3L2l2DaqiW2LiaxeR6pPWvA0LEU9cY1fBauYYcje7u6bnwZ0NMJ8/vk8FJ+6d25ooiWjJrfSTDZlXY3ariZxsdvQnIXUHbxXxGtduIbP6Ht9T/jGXciBUd/H4B0zVqanv8FzoFTKDrtp0iCgBWoCMhsfOlempe9xVHX3cG0I4+hxO3C45IYWeEnqeo0hRLUhZKkFI3RFQHqwklcNplij42x1QEK3V9+PNiD/x5fpQ/254au/xf01h586Xh2SQ0/f2Ud+9ZYNAP+OX9Xt121TgiShaJTf2Iqhc/5Y+4mAyYVq2D21XS8ez+heU8ROPwyLAWVFJ18C60v/pb2d+6j6KRbcI88ikzTViLL5mAt649r2Cw8E04lVbchO6M9CHvVcAAcg6Zi7zee0LyncA6a2q21UXfoFCzREmEEUUR0+JD9pftYXwCijBZppfn5n5n0aIcX0erAVjaA5J41yP5yArMuB6Djg4fJNGzG4i8zO7e9RmAc+R2Utj0kdywjuXMFkWVziCx5MWvPNRRbrxHYew3PLkL5Nlx6JoXSuptM2x6UrKhYpmGLOYe9HwVctLkQXX4khxfR7ka0uRCsDjPg6JzB3qdSbuiaOW+lZtDVtNnxzs5baYkIeiKMntq/gCIgeYuwBCpxj5xtVvSz3pz7W6IYqkKqbmPWrms96boNObsua/lAfDMuwJGl0OfRv+s2kdy2iMTWhaihJkDA1nsUvmnnmbSwbma68o+bQQ03o8VDGLqG5PBgKao+5ABLS4QJfvQI1vJBuEcdk3s807rbFKvrNQL/zIsASO5cQeiTJ3AOmYlnolnwCX78L9I16yg88WZzdg6TPpbY9An+wy7JS65NNfE7MZQUxeff3qUbsbExinSQSrRmwEsr63oS7G8JEonEN30K/2ehaRq6riOKh8aw6cFXg4ZggsW7WmkLKyi6htNmQRagNGDHZ7PSGElhlS3YpQyCAA6rgEMSSRs6gZGzSdRuILzoBWyVQ3BkBT4RJfqe8WO2/+sHtM35A70uvRvN4aXw+BtR2mtpe+1PlF1yN5ZAOcWn/5zmZ35M6yu3U3ru75E9RaaI6r9/Rce7D+QYQKLFRsHRV9P64m8IL34R//TzD/ka1VgHarABQ1MRrY7siFTpF44JZG8x3gmn4J1wCloqRmrniiw1fBWJjZ8A+9tzDcu6g8zIMQENQzeFxVqycUF7LWpHPbH6jXmJNwCSjOQMIDn31WVxmrosFqtpxyXKpi6LuXMMXcVQFQw1jZ5JZbVZYmjJqGn3lQhDdh69E4LViaWgEnuv4aYQa1FvrCV9kLwl+80pGyihJpMmXmvaeJprPEjuQlNNvf9E0zd8n3hCjbaT3L6ExJaFpGrWgqEjeQrxTDwd94gjDjpS1XlcLdaOFm1HzyQRJBlLYa8uVl4HQ+jTJ9Gi7RSfeus+zDqd9jf+ihbroOzCO5HsbvRMipaX/4Bh6BSf/jNEq51M807a374XW9UwAkeao2FqpJWWV+80vdY7Z7EBiwXqV3xA87K3GHHchYyaPpthlT6q/R4MdAxMVsjEPsX0iydJpHV2tkfM43lsWCWBlKId8nX14KuFqqpEo903IuEb6mBLkmTE43Hs9kOnv/Tgy8WKPUHOeXjRQRPpz0Jy10pa/v2rbqmv7e/cT2zNO3k08vDCFwjNewr/rMvwTT7L7IQ//wsyTdsovfBP2MoGoKfjND5xE0YmRdml9+Qqvmq4hYZHr8VWPoiSc3/3mVTnxJaFtL//IHo81OU52Vdq0rPLB5mK2GWDEGQLia0LEaxObOWDEG1OoqvfQQ03ETj8MnQlTXTF6+jpeK672R30dMKkf2VpYErLbsjZc/XFWj4YW5Yebtpzda1xGZpiKniGW9CibWjR9r0+2KkIWjKGkY6jd/pdKhnoVn1aMCniFpuZjNucOc9L0elHcvmRPYVInmJz3spb0m3XuFNtvZMmnm7cQqZ5xz52Xb2wVY/C0Xs0tt6jkPZJkrVUnHTNGhLbl5HYvhQjGTapaha7qUxeVI2teiT+7Bxzd4hvWUBi83wyzTty1lz7QrS7CRxxJe5RR3e/g87r6OyabF1E+WX35BZuPRWj8ckfYGRSpq2cO4DSUU/TkzcjeYspu+gviFY7sbXv0/723/BMODXn0ZrcvZqWf/8yV6Xe9zeQG6U4+Ud5c/b5nxBM7BNg+Z4gumH+ve/VXTi5mj+cPrLb1/bg64Ou6/Tp04fa2tqeDvZXAEEQjCeeeIJLLrnkmz6V/2+RVjXW1obY1hRhxZ526kNpRpZ76VPiocBlwRAEVu9u5+MtzbQG08Q1KHZD70I30aTKnrYUWibNzqd+hBJpofzSe7D5y9AAF5Bs2sKup3+CvWoEJef8xmS3hZpoeiLbDbzoL4g2Z45htG83MLTgOcLzn6Hg6KvxjDspd86tr/2JxJaFlF9612cKnmnxIG1v3EVq96ouzwk2lyneVTbALIxXDUN2F3wpMYFh6Cgtu0nuXm2OjNXtTZYlT1HWnmsg1rKBWEv6dWu9ZRgGeiJsdpQjrdmYoMOMCRJhM0FOx82EOZPK0sPVLvsBQJTNuWuLLZuUuxAdHlMsze1HchUge4qQvMXI/tIDdsq1RJhM0/acXVe6YXMu3uq067JXj8Tee4xZBM8Jlemkm3aQ2rmc5I6lZBq3ma/Jdr8F2YqlsApb5dADxgVqpJXI0pdN+9C2PV2LD4Bz0DQKT7ixi33o/kjVrKX5uZ/hGX8yBbO/l3s8NP8Zwgueo+CYa0zGQmf8sHk+JWf9Ckf/CWjxEI1P/gB03WRduALoSprmZ29F7aijzyV3oWVHG5yAK1nLyr/fRGn/YUy/9q/0LvYwojxA/zI3douUsywfVumjLZrGZZNpi6ZpiCSxCAI2i8SoKj9OW4981rcBd955J7feeivAt4ciLgiC8Y9//IOrrrrqaz92D8zk+p4PtjJvW9sX3lfnwheY/T2840/OPW6oCs3P/ZRM6y6TFlbSN+8GVXzWL3H2n5h3gyq79G5kdwGZ1t00PXUL1uI+lJ5/ey7pi65+h4537/9Mqrgabqb+H9/FWtoPz/hTkL3FGLqGHg+hBBvMGaKm7dkZIMx54soh2KtHYe89Glv5IARJRkvFaHvldqwVQxBkC0p7ralg3Xs0hmHkFoz45vkobTXY+4zFVj4wL2nWUjHSdRtNi66GLWSatu1dDCQZS2E11uJOK44q5IIKZF/Z55q9gmzHWtfZO4Mtfu55dV1JmyIqoUaUjrqcgqrSticnCiPINqxl/U1bjmwgsm+l2NA1Ms07SO1eTXL3KtJ1G7PzWYIpgjb6OLNLbeg4+k9EDbfQ9vqfKTj6aqyl/bs9r+AnTxDfOBdb2UAsxb2RAxVIroD5GcWCRFe/RbpmHeWX/e2A+4C91HD/YZfkbDYMXaPlpd+S2r2G0vP/iL1qmJlwP3ULejKS0w1I1W2k+fmfdRMc/gDJHcgFh52Ib/qUttf+1GXR3h9WWeSs8VU8v7Qml2CLooCuG1hkkeeumtLTwf4W4K233uLEE0+EAyykPfhiEEXRGDZsGOvXr/+mT+X/HBRNR9UMLJKALHUtTBuGQSKjkVI16oMJ3lhdR11HAkESkCSRYwaXUORxYpFEGsMJXllZR30wTjCRQRZFZNHA47DQlsiQyugEGxupfeImZH8ZpRf+CdFiQwI0OGiR0tF/AsVn/AJBEHNz251jPIah0/rS70juWknpeX/IuY1oyQiN/7oW0eGh7JK7DzpK1fLy70ntWoVv6jnmbLRsQU/HUcMtKO21KFn17s61Tg5UZGMCMy6QnL5Djgn0TJLgh//E3mdM7rW59ztnz7UxRw/vHDMD0xPbWtw7Z9spByr2rnmfgxJuGIaZZHfG+IJoarN8rn3oaNEO1FBDlia+NybQYh257eRABbaKwWZMUDm0i12XGuvIqoqvJLl71d7GhyTjGjYL78TT0VJRyCQPKS5Qo200/PPqrMhsPywFVUi+EkSrA0NJk67bSHjJi7hHzqbwuOsPeH16KkbDY9cjSDLll92Xi7sS2xbT+vLvcY04KteBDi18nvC8p/drEJkaAZ0NIsMwaH/zLuIbPqbXGbdhHTgZgWzhPBWj9qkfIGoZzvvdE6StPir9ToaW++hX6qbYbQNRoNzrwCIJ7GlLkNE0ZEmgKuBCEMBplbDKPUJn3wZomkavXr3QNI2WlpZvT4ItiqIxYMAAtm7d+rUf+/86VuwJsnhnO1P6FXYbmK/YE+TCRxaTUQ8s/S9gsooOpbmdt/Cdfzv2qmG559RYB01P3ASiTPmldyM5fehKiuZnfoISbDCVuov7kGnZSdPTP8JS1JvS829HtNiIb55P26t34Bp5dM5f0DAMWl/+PcldKyi/+K8HTKY6Z2krrvx7TvGxO2iJMOm6jaRq15OuXW8qYmMgWB1Z8a4xyIFy0o1bMVIxPBNP20sj2wcdHzxMdMUbudfaqoaZQh9Vw7GV7TerrGuowQazG9y8k0yrSQ/fd7YJTJVR2VuC5Ck0bbpcfiSnz6SC2d1mN9rqMLvTkkkRR5TI9UB1DUPXch1uPSuKZlpyZOlgWasuNdqGFmk1/Sn3geQuwFJYjaWkj9l97+y670tF11QyLTtJbl9KYvsS1I76XIBiKemLvc9YlNY9CLIFR78JeMYch66k8wKhxid/QOEJP8Ba1L2gl6FrBy0WaIkwdfdd2GV2f18owQYaH78Ra0lfs2iT3V9w7uNElry4t0rdKbZTs84UNes1AjXcQuOTP0C0OSm7+C4khwc9naDp6R+hxdopu+SuPNV0s0D0Q6wl/Sk9/w/d0tdHV/kYUenDY5P5YHML21tiueeuPqwfHoflgL/hHnz9mD59OgsXLoSeBPsrQWlpqdHS0sKiRYuYMmXKN306/2egaDoNIVOQTBKh3OfokmS3RVNEUma3U9N1PtzQQE1HCo9LJplWmTGgiF4FXmwWgV2tUT7Z2sba2g4iiQyyJNESTqEaZn1XNECQILljGTue/w2uEUdSeMIP8pk9HzxMdMXrFJ7wA9wjs2JQK14n+MHDeCefRWDWZfsUPlebgpO9R2cLnz9ET8VMi6/sqFhy5wpa/vMr3GNPpPCY7x/wvai5+2xcw4886DaGppBp3kmqdgPprBCZkTFHQywlfXH0HoNcWIUaasLIJA8YE6Qbt9H8wi/MmWPAUtwHe/VIs6tbOSxv7hg67bl2mIKiLbtQWvegdNSDpuS2EWSb2VX2FGXjggCS04/o9CHZ3Qj7UsRzY2MHoohnMJRUNiaIo6ei6IkIWiKUs+9UI22okdYu52AprMJS3BtrcV8spf2wlfbvMtqlRtpI120gsXMFqd2r0LOxhejwmkWHPmNNJoGmYu877iuJC1pfuYNM83Yqv/dI9683DNpevZPEtkWmsF7FYMC05Gx6+hYsBVWUXnCHGY9uWUDbK7fjGn4EhSfeDJiCqPF17+cx1MJLXiY091F8My6kZPr56JgRmYxO/cu/I7ZzFSf/5AH6Dh+LIJgCp4PKfcwaXELfIjeSALva4yTSKuGEQpHXhiSK9Cty9yiIf8swZ84czjjjDP7whz/w85///Nszg11UVMS2bduor6+nsrLys1/Qg0PCvsmzKAj89tQReSrEnZ3rzuRaFMBjlwknD0AlOgQIgkjRST+k8ckf0PbK7Xm0btldQPEZv6D52VtpnfNHSrP+v8Vn3EbTUzfT8tLvzES5pB9FJ91C65w/mtYdp/wI15AZKK3nEV74PNbi3ngnnoYgCBQefwONj91A6yt3UH7ZPd3SfwTZCoD+GbYVktOHc9BUnIOmAqagRapmLandq01f5x3LzO3cBdh7jyZduwEEEct+M+AFs7+Hb/r5pPasNV+/Zy2hnSuyB7FgKxuQncEehLVsgClkVtgL17BZuX3o6QRKsAE12GB2kbNUMDXYQLp2fc724suEYHMhuwuQPEVY+vVG9pVgCZQj+8uxFFZ1eW8Nw0ANN5u0sMatpBu3kmnclrMHESx2c1+FVRQedz2SK4CWjNLx/kNYAhWmxVgmmaf+qXTUI9rcB+3Yf1Ynfu/xu+9eGGqGtlfvRBAlU/Ezu7/Y+o+ILHkR95jjcxSwjg/+QWr3KgqOu8EUvEknaHnptxiaSsmZv0RyeExPzDf+itJeS8k5v81LrrVUjNaX/4BodVJ02q0HnA0/d6L5u/zZnHX51wpE0iq3njD0oNfcg68PwWCQhQsXctZZZ/Hiiy9+06fzfxIlJSW0tLTwzDPP9CTYXyIyqk5K0UhmNNKKhttmwW2XMQyTPZPKaNR1JHDZZCRRRDAEThtXzbNLdlHbkaDS7ySVhrSqEkpqNIfT+B1WCpwy8bRuzmlbIZqCtG4GlIIKjn4T8U2/gPCCZ7GWDcxjtwWO/A5KWw3t795ndomrhuIZdxJKWw2RJS+aNlcjZ1N8yo9peupHtGV1XiyFVZSc8Qsan7qFlpd+S9lFf0a0OnD0G4934ulEls3B3msErqEzu30vBNn2mVZWQtYdw1YxGCafYXabm7aT2rOG5O7VRFa+bo5GSTK2yqHEN87FXj0aW0W+WritfCC9bniWTOM2s3Nbs5bYmveIrngdANlfhrVicDYmMKnhjr7j8jU8dM0UGgs2mLFBuBkt3IIabUPZXYuWCOXUu780SDKSK4DsLsRa2h/nwCnI/jIzJiioRPIWdRnP09MJUjXryDRtI924jXT95r0NA0FEdPqwlA7AWtybwhNuRBBEM9bavRpLUe+vNC7ojAW7Q2zVmyS2zMd/+GW55FpLhGl96bdmnHr6zxEtNtKNW2l/4y5sFUMoPO56BEEgvOQl4uvexzftvFxyndyxjNDcx3AOnkHBtPPojKpFoHnec0S3LeOIS3/IzOnTUTFIZzR0RDRVpyOeYXiFRG0wwe62OIqiE8lkKPPZiSQyNIQS9C50dctA6cE3gxdeeAGAm266iZ///OfdbvONdLCHDh1qbN68me9///s8+OCDX/vx/6/igY+389f3tuQ6z7JoJtnBRIaA08qvX99ARt0rnmWRBKb2K+TTL4Eq3knrthT1puyCO/K6tp0Ki/t2o9ONW2l+9qdYSvpQet4fES02wkteIjT3sZyKs2HotL1yB4mtiyg+8zacAyYBmHTd536atfT6WZcbvhpto/7By/DNuPBziZ/sj5wC5u7VpGrWoifCQNaKonpUriIte4q6vFaLh7K+kJtMGljzjr00a6vDrP7maGAmPVzyFB7Ya1JTTHGyVNTsQqfjGJkUupIGzbTpMvZZbIWsTddeSw4Hos1hdr/tXiSn94CLj2EYaPEgake9KbzWuotUzXqTUp+du0aUsZb2MxcmQUJPRig6yazs1v/zexSd/CNsZQMILXwei78cJJlMwxa8E0/PKbMKspX2d+5D8hbjn3ZeHsXu8yC66i063nuQ8svvw1rSt8u1dFaai8/85T7foQ00P/9zbJXDKD3ntwiSnFOx9046g8ARV5hdlBd/S2r3KrOL0mcMsLfr3WUsonP7PWuybI4DJ8kXTq5mfX2YNXXhLs8JQKnXxmljKnsS7W8BfvnLX/K73/2OOXPmcPrpp/d0sL8CTJgwwUin02zbto2Ojg6cTudnv6gHn4m0qrG6JogAIIDHJmOzymCA0yZR25GgKZRgQ2OUUreVgMuKwyqzoT6ComnYJJGORAanRcLnttAeTVPud9ASSeOQYV1jlGBMIRhL0xbPYAEUQNIhpuo0vPx7kjtXUHru77FX79WS0JJRmp66GT2dpPySu5B9JRiamrXqWk/pub/FXj3KHMN56oeIVidlF/8FyekzO9Yv/saklJ/+cwRRync2ufiubplrLXP+QKZhC5Xff+y/tvvUlZQp4JVNuJWWnYC5pturhmOrHoW9eoQp7rnfMQxNIdO0g1TdRjINm0k3bN6HZi0gF1RiLe6zNy4oqMQSqDjoOm0y0iImK61Tl0VJm8w1TTGZbNkYXxAEc/5atmS1WeyIVkduBltyeBFsrgOuwbqSMpP99jqUthrSrbvJ1G9GT4Ry20jekixVfAiGpqKGmig89lrg640L9EySugcuwTX0cAqPu67L8+n6zTQ9eyuOPmMoPuuXCIJojjW+kKV8n387torBqOEWmp76IchWyi/+K5LLT2LrQlrn3I5z8HSKTv0xgiDu7XoHKii94M68wkBq6yKa5/yB0vGzOfO63zCuXwkWi0hrOIXXIdMr4MLlsDK5byHr6jtYUxtB03RCyQyVAQc+u4V4SkXRdYZV+BlZ6UeWexLtbxJtbW0UFxdzxBFH8NFHH337bLp0XWfdunV0dHTg8XQVdejB58eKPUHOfXgRajbDFjB9dnXDQMBUJf4q0Wmn5Ro5m8Ljb8y7KYY+fYrwohfwz7oC3+QzALK0mztyNyoQ6Hj3AWJr3qHg2OuytKEUzc/+FKW9ltIL7sgpN0eWv0rww3/im34B/hkXdDmX5hduQ2nbQ+X3HjloFfNQYRg6Susesxq9Z00edUz2l5l08OxMsqWwqkvSb2iquSg1bUdp2ZGjgelZChmYFXbZX2YKjnVadHkKkVwBJJcf0eFFsnu+kH2VoSk5JXEtEUKPB1GjWauuSGuue54nGpKlm1kKe6EnIrjHHItn3CmIFvN9Dc59DMlThGfMcQiShbY3TR9o/4wLCH36FO6xJ4CuEV70AtbS/jj6T0D2lpCqWUtk2SuUnPnL//56dI3GR68DyWL6pO+3EEdXvknH+3/HO/VcAlnrGKWjnqanf2TO7V30FySHJ/vdvR3noKkUnXYrINDx/t+JrXor910EiK37gPa37sE95jgKjrk273jBuY8RWfJS3vadKHZbaY1lcn8PKHaRVDTqQ6mDXt/Vh/Xj6OFlvLSyDgE4Y1xVD238a0Qmk6GsrAyPx8OuXbuQJKknwf4KMGHCBOOaa67hyiuv5C9/+Qs//OEPv+lT+j+DXa0xFF3HKot0xDKUee1IokBzOEVLNE19KM7a+jD9ily4rDKiIZDSdOpDSQxBpymUJJlSscgCVkFkRHUBhW4rhR4bzcEka+qCbGmKUtORQMtAGjPJloFMOk7jkz/M6lnck+cAorTX0vjULcjeYsou/BOizfn/2DvvMKnK8/1/zjnT607ZBkvvXelFUeyABRHEGo0xian2ggXsBWs0+ZoYTaw0FUEsiIggqCBF6b0v7O7s7PQ+p/z+OMPACggajZrf3tfFpTtzztl3zszO+5T7uW/UTKIwehOi4vLHMPpakN27kbqpt2Mqa0vZRQ8gGs3F7/WDNS7keJCaF69DNNuo/MUTh9CW98/VfpPo5LeFkoqS2b2m0KVegxyqBvSE29ysM+aqrnpcUBBJ+zrkeENRKCwX2E6+fidypI4DUpcCktOv23O591t0+ZDsXj0msLn1xNhk/c6WXZqmouXSRXcRXVA1XBwd0626ahvNXCOIuiiZqmAqb4cSC+DqPxp7j9OK6/gx44LY8rcJf/Qc5Zc+ekihW443UPvy9QiSkYorniqw0jSC7zxGav1C/Ofegr3L0MJn8RbkRIM+zuhvSXbfJuqm3I6xtDXlF+uNISUZoeaVG0HOUfGLJzG5/Gjo7+D+xNvir+KMm57BaXfQpZmbUzqX4bIY2RPO4HWY0DQNl9XI9voktdE0DrOEJIrkFQ2v3cgXOxpoW+pCFGFgWy8GSSKTU7AaRTwOMzZTk+jZfxO33XYbjzzyCO+99x7Dhw//6SXY1113HZdffjkPPfTQfhW2JnwPmLx0NxNmrUVVNURRn/H4Pt7hCqeZ2vg306sAIoteJfrZVDyn/gZX33OLj2uaSnCW7ntZev7tRVr2/pmVw3UNSy+4C1u7fvos9ys3giJTcfljGNzlemfyvSdJrp3fSKl8P/YrQ5acdCXugWO+hzvQGEWhkj1r9Rnu6vWo6RigU6/NlR0xVXbAXNEBU0V7JKf/kA1Qt5cIFa048l+jh2sHJd8HQzAUVMFNFv3/DSa9Wi5K+ryVdvAMdk636sqn0QrV7cNe02TVLUPc5XqS72mmV9B9VTTMfRZ7x8E4ep5Ocv0CsjVbsHc+AXNzfeOKLZtJvmFPUUwksfYj0ttXUHLiZQRnP4bBVUa+YTdyNIDR3xLf8Gsx+lvQ8N5TSLYSzC26k9n5JbaOg7C07Pmt3ofE6rk0vP/0YT8D6V2rCEy7C2vbPrp4jiihJCPUvnozajap0w49lWSq1xOYdifGsjYHsSneILLgxeJMIBQ+U9MmYGnRjbKx9zQSs0uuX0Bw9mM4jhterNgfjC4VTjbVxY9J1+BgVLjMhJI5coXqmKlJ+Oy/in/961/86le/4vHHH+eGG25o8sH+gdC3b1/t888/p6KiArvdzs6dO5ssu74nZGWFYGHvFgWBdF5BEgQUTWVjbZwtdTEaEjlcFgOqquJ1mJBlCKbyqKpCQzxLdTiF2SDRvtyG02qhdysv+bxMdSRDQyJDMqfw5bZ6dgSzZDU9wd4PpaGava/ciNFdTsfLJ5EzWLCaIJKD9I4vCbw+EUub4ym7YAKCKOk+wy/fiGA0U3nZY0gOT7EAau0wgNJR4xFEqejS4Dnlalz9RgGQ2bOWuql3YmnZg7KxdzfWC9FUav71JzQlR7Or/u8/KlQfCXIiVIgJ1pGtXke+fheg6aNl/paFuEAfFTOVtjrsCJGay5APVSOHqsmH9iFHapAjdcixgJ7kaoexuRXEwvy1tagQjliw7vz6DHbBvlPLZ3T7zlz6iNeUnD49JnCXY/BUYvQ0KwivNSc462FsP8G4QM0m2fvcbzH6qqi45OHGz+Wz1E0ZTz64m4rLHys6iYQXvkRsyetFAVRNzlM3/S6yezdSfuG9WFrtZ1PchGg0U1HoZqv5LIGpd5AL7KD8kocwV3akxAgGEWKRGLteuRE1n6HNL5/E4/fTvqKEUzuXYreaaOWxk5IVzKJAKJVH0yCXV0nkZNwWE2UlJtAEEpk8y3aH6NHMRToPHcrt+Ox6XJDMyZQ7rVSWWHBavv/PcxMORTabpaysjNLSUrZs2YIgCD8tH2yAiy66iOuuu46pU6c2JdjfIy4Z0JJOFU6WbG/AYzNx7zvryMsqkiiQV757sn0syTWA+4RLyNXvJDz/eYy+quJMkSCI+EZejxwLEJz9WPHLyNX/fORoHbEvZiC5SnH1OYfS826lbsp4grMepvyiBzE360TZ2Huoe/Vm6qZPpOKySUhWF74z/4gc2kfDu09icJUW52gALC17Ym0/gOhnU7F3ORGDu/w7vvLDQxAlzBXtMVe0x9VvlD6fHNpLtkD9yu7bTGzJG8WNS7S6MJW1xljaRveR3K8Q6vTpM+sF+vHBUHPpAxZdqWhjKlg2pSfNhQQaRUHTChRxQdA9MCVJ/6/BjGiyIJpsCBY7ksWhW3LY3Hp33Ok7bHUd9CKAqbSVvgkD5uZdyTdUk6vbVtxIzc27kNr8GZqcRzAYMVd1I7p4MkZPMyxV3TB4m+MePI7UliWIFgem0lbkwzUk1y3AVNkBJR3TBdT8rb7Ve6CkooQXvIi5eRdsnYY0ei7fsIfgWw9i9FbhP+dmBFFCzaUJvHkPSiJE+UUPYPRUkgvupv7Ne5GcfsoumKALmqxfSGTBi9g6n0jJSb8oXq/+rQd1r9ZR4xsl19l9mwi+9xfMLbofUTF8Q+2RvRJ7Vbkpc1n4eFMA+Ws0k5ZeG3WxA397eVllyfaGpgT7v4SpU6diMBi45pprfuyl/M/DaDTyu9/9jgceeIBly5YxYMCAH3tJ/xMwGySae/Tvd1XViGX0YN5pMWCSJDJ5hbZ+OzXRDO1K7TisJkLxHKeW2aiJZ1m8KUAklUNWIJ2F1n4jbUvtfLShDr/NRDYnsyuUxuW0YIxlyWbBBOQAmwjGyiosY25m6+R7CX74F9qOuQWDwUQul4c2x+M943eEPvgboXnP4T39GgzucsrGTKBuyngCb95D+cUPYes0GM+pvyb80XOE5v0D7+m/wzPsKpRYPeH5zyPZPdi7noSlRXe8Z/ye0Jyni8ftL2wLgojnlF8RmD6B6JLplJxw6fd+rw0OL4YuQ7F3Garf70yC7L5NZPduJFuzidTmz0isnqsfLBow+lsUaOGt9ZjA3xLJVVqMLb4OTVVQEuHDWHQl0LKpr1HEZdAaU8QFwaxTxKUCRdxsRTTp9p2iVR8dE+2eImvuSFR6TdMw/kTjgsgnL6OmonjGTPzamlUa3n2SXM0WSs+/vZhcx1e+Q2zJ6zh6nYVr4Fg0VSH4zuNk96zFf85NWFr1RElFCbw+EVSZsrEPI9lL9Ou99xTZfRvxjxqPubIjoIsJmo0QeOcR8rEgPa5+BIvPj9tmAkUlnlHIyFna+ew4zAZkWcNnMxHLykiSQDOfFZ/NRAuvjUROIZrKkZUV8hpUecxUuixk8ip5RcVmNGA26N7YTQn2fwdz5swhFotx7733HpU18qMl2AaDgbPPPpuXXnqJlStX0rt376Of1IRjQp9WnmIAvj/ZHtjWx4fravn7J9uLx4mFZqcgwLm9mrF2X6yRmvF3wX7Rs9pXb6Z+1iNUFsRJAESjmbIL7qL2lZsIvHGvTgErqcB72m9QEg2E5z2nb5SdT6Bs7N3UvnozgTfuoeLSRzD5W1J6wV3UTbuLwOv36IJpJiulo++k9pUbCbx5X/F6++E9/bfse+EPBN99kvKLHvjOc1fH9roFXV2zINACerU0V7uF5LqPSW1ZQq5+F5nqjaAcoAqLthKM3mZI7gpdYMxVdoAG5vQVr/lDQs1nyYdrUNOxRkUKADQVyeYpdufFQlIuH6R6bqpoj2C0kt65Elt7PSg2eJqhZlN4TvlV8bjsvo26onk2hcHpo+pPryJZXd9pzZqmEfrgb6jZFN4zG1O1lUSYutfvBslI2diJiGYbmpKn/q0HydVuo3T0HZibd0aO1ROYPgFBMlJ24b36fN/Orwi++yTmFt3xj7xBF2RJhgm8fjeIBkrHTGxEPZRjQerfegDJ4T0k8T4WSKLAhHO60aeVp+gAsKUuzld7IpzVrYLTu1Vw8XOfFzvYRoPIwLa+73TPmvDtsGfPHj788ENGjBjRNBP8X8Ill1zCAw88wLPPPtuUYP8AEEWBEltBBFTVaO6xMkj0EUrlOa6Fh3yhvlfptlLpsuC1m0mmZfx2C6FUjnK3Ba/VRHUojUEQiWZlJKMBv81IMpUFDST0mKKlA+xWM5F0jvLu/dFGXM22d5/D27wtPc65CpdNorohA8cNRw7XEPtiBgZ3Oe4BozFXdsR/3m3Uv3kf9TMfouyCCbj6nouSCBFb+gaSrYSSEy7Bf85N1E27i+C7TyJaXVjbHI+z1xnI4b3Elr6JwV2Ge8ABBpu1TW9sXU8i+vl0rO0HHDaJ/V7vt8WBtW0frG37APq+lQ/tJbHqA1IbFqKmE6R3rCS57uPiOYLBhKGkEslTiamkEqkwMmbYPzLm9GJwHar78n1C01TddzvegKEwptb4gJ9mXJDe+RXxle/i7HMu5soOjZ6LLHhRFzU7+aoigzK5YRGhD/+Btf0AvAV1+fBHz5HatBjPsKuwdz0ZNZch8Ma9yNGAXpgvxGORhS+R2riIkpN/ib1Q4JcAt12ibs4/iW5fxcBf3I6rXQ8SuTwemxHJKKEi0NJrB1Ekm1Nw2wyIiETSMrKi4bGZqPLasJmN2MxG/HYzbUudyKqKURRRNY1APINRErEYJBTAYW5Krv9beO655wAYM+bozNgfjSK+fPlyNm/eTKdOnbj44ouZPHnyf30d/z9hf/AeT+dZVxNjePfKRsk3wJsrq5m2bDfKYdhC3xY6zeuGgrXR442+MPPBPdS+djOiraTYjVbzGQLT7iJbu4XysQVKTriG2tduRhANVFw2CYOrjNTmz6mf+RCWVr10SlnBi7L21Vv0mdpLJyHZS4q/a//M7JFmtY8GNZ8hHywodip5EA26RZbVhWT3IFoc31jFSm39gvjyWbiHXEx623LUfAZX33PJ1e9GDlcjh/bpc9nB3bpn5dcgmKyIFqe+6bpKkWxuhKJF1347DqPub3lwAUFTdSqYst+SI6tXtw+y6lLSMZRU9AAVXTTQ8qYZh8yPpzZ9RmrbF7qCpigR/+p9lFSUksEXHThm82ckNyzC6Ksis3sNjp6n4+h+KqqqoBWEWJR0DDWbKt5HyebG6GtxxO75N2H/DF7Jyb/EPeCCA+9XNqVTwELVulBJZUe9Ij37MVIbF+Eb/mccPc9ASUWpfe1Wfdbv0ocxlbUlW7uVuinjMbjKqLj0EUSLAzWXoW7qePL1uym/+MFGBQg1l6FucmPLuaPBIAkIQF7RkESB+76m9H84rNgVbprB/hFw7bXX8vTTT7N06VL699fF8Zoo4j8M9scEACeccAKffvopsVisSZ/lB0Iym6c+kUMUoMxpwWLU946srCArGiZJoCGZI53TVchjqSzJvEqzEit7wklsBgmnzUQklcNplpi8dBc7Qwn2NaQpsUk4TEYcFhNt/Daqw2m2BZNkZZl97zzDtk/fZfgf78fdbSixdJ4d9WmSyv4RssXFGViA+Kq5hOY8XbBHuh4QaHjvLyTXziuKTKqZBLWTb0OO1FI+7n5dYEtTCb79qP6df/aNOLoNK752JR2n5t9/QpCMR3Qi+SZomqbTtqMBXatEEPRusMVRZIV9U6H16zGBJmdxD9GZf3JoL/lQNbmazeRqtxaFURtBlPSYwGjG4PLrVl1me2FsbL91ZyEmkAzoajzoMYGqwP64IJ9FzWdQsym0bBIlEz9g4ZmKFtl3h9MUgf8wLshnURKhomjr/tE1wWzDWFKJ5Cr91jPlSiJMzYt/RrQ4qLjiSUTjAaGx2LJZhOf/E8fxI/Gefg2CIJDesZLAG/dibtaRsgvvQzSaiSx+jeinUw6MLCpywRp2JaWjbsPWcTBwIP5wHD+iEUtCAqT177Bl9t/pd84VVJ3+SxRFJZTOY5YEWvosnNK5GX6nBb/TgstqpI3PTiKnICsKFqOE2WjAdAwiZpqmkS04Bh3L8U34z7F9+3batWvHqFGjeOutt4qP/+RmsPdvpqeccgoff/wxDQ0NeL3e//pa/n/AwfZdJoPIa1c3nuE8+HmDKHBcixJW7I4U1RuVg0TT2pU5jrnLnaneQN3U8Zibdab8wvsazTxl9qylbtpdmCvaUzbufl0sIh2nbvKtyLF6PTmqaK97ZE8ej2RzU3HpI0h2D4nVH9Lw/l+wdhxE6Xm3IYgSmeoN+hytr0oXnyhsmpqm0fDuEyTXLaBszASs7fod09qVdJzwR8+R3LjogHL2YSAYzDh6nXFYarCmKkQ/nYKmqXiG/oJM9QYSX75LydDLi5T1XHA38RWzC97Z3QkvfAlDid7NVhIh5EiA9K4vUWL1GFxlgIaSTugCa4ebmzoSJAOi0YpgtukU8f2UMHtJgQ7mQXL6sbTqeUinX06EaHjncUpOvBxz884E33sKc2VHbB0Hk6vdgrG0DWouRXLjIrK71xbnvpR4A3IsgJb/BiEvQcTWYSCe035btHc7GjLVG3QV+YMUQEG346p7fSLZ6vWUjb4La7u+eqd7zjMkVs8tCuyp2SR1U24v2mxZWnTXhc9euwXBYNaLOU5/48119B3FKvz+97Z+5kOkt35B2QV3HfVztd/3enRvvfr9TV71TfjxkU6ncbvddOrUiTVrDlipNSXYPwwOjgnee+89Ro4cyYQJE7jnnnt+5JX970HTNHaHUpiNIqqqAQLNSqyNjsnkFfZF0jjMBlI5GZfFSCork9M0PY4QRb1YKAjYTSKPzd1IKJFlRyCOoom0L7PTv60Hq9HAxxvrCaUyGAQJgyDz4RPXEdyxgRP+9CT+tl2pj2XYG1OQ5Rx10+4iW7NJn3stzN1GP5tGZNErOPueh+eUq0FT9e/eLUvwjbwBR/dTkBMh6l67FTUTp/zihzCVtUGT8wTemEhmzzpKR9+J7aDv6Ez1OuomF5xIRt9xSFH5SEismUdk0WsHbKgOCwHJ4aHZr/7vEKG1bx0TtD6eyMIXMXqaYfA21/fUaB3pbcuRo7UYPc31bnMhUf12tl0Cgmm/krgD0epAtDiRbG5EuweDw4Pk8GKq6HjYjvk3xQXZmi1ILj9aNklq82dk924EBDRVRU0EUQq+2EeCwVOJ56RfYus0+JheiabkqZt6J7narVT84vFGxe7E2vk0vPsEto6D8Z93ayFeXEdg2gQM3mZUXPwQosVRdBHZL9ILGg3vPkly3ceNigypzZ9RP/NhrO36Un7+Hbr2DbqoWWbbMurevA9flwGc+rsH8DrNxHNglgRqowkMooH2ZS46lDu5aGAr3FbTdxana8J/H7///e959tlnWblyJccff3zx8Z/cDPZ+3H777Xz88cdMmjSJhx9++OgnNOFbY8n2hqL39eFmOA9+XlE1TupUxq3Duxwyx200iHhtx05FsVR1wT/iOoKzH6Phg2fwjbi++GViadEd/9k3Epz1CMG3J1F6/u1IVidlF95L7au3EJg+QaeGl7WlbMxEAtN1enj5xQ/i6Hk6ai5N+KPnaHj3SXwjr8dS1YXSUbcRmHE/gTfuoWzsvboImCDgPfMP5IO7qX970jF3GoOzHyOzexXO44ZjadkDyeFDkIy6CncuXVTblONBTGVtD3sNtSAgsn8mSTTbkOwe8g3Vxc00u2ctgmTA1q4fosWhJ3bZJI7up6KpCsm1HyHHA9g7n4i1bZ9iEq4kwviG/4nYF2+Rb9iDq9/5iNYDm7kgSFCcwTb9R/R4ye7BVNGeyKJXUfMZlGQYLZsmvmYeSqS2SBMr/m6TTRdK8zbD0uZ4nWLm8CJa3boQS+E+KskI2ep1xL98D3nG/VT84omjbjZytI76tx7A4CrFd/aNB5JrJU/9zIfI7l6L7+wbisl1+KPnSKyei3vQOD25zmUIvH4PufqdlI2+E0uL7sixeuqm3QlA+bj79PdA02j44K+kty/He+YfGiXXoCukprcswXPab4+paDOora+R5VZTYv3Txv/93/+Rz+e5+eabf+yl/H+Hs846i5YtW/L0009z5513YjQ20R+/TwiCgEEUkBUNVdMwHcZbVyx8D2dyCqoGFpOEx25CUTVEQSCWyaOouvJxJq/QymOjJpKmmddGa6+NZl473ZqXsHpPBJ/NCKJAJJHF77Jx8u8f4oOHf8uy5+9gwJ+fIW/xUWKCECZKR99J3Wu3Enjz/iKzyDXoQpRUlPjyWYgWByVDLqb03FsIvHE3De89hWA0Y+80hPKL7qf2tVupm3YnFRc/jNHfgtLz76Ru6u0EZz5E2Zi7sbTSk3ZLVTd9pnveP4gsfLkoZPlNSG39gob3nsLcvCvuweMwepvr40eahpbP6O4cyYium5JoQDhMZ/y7xASSrQQlGdbFtwoxQT5Uja3T4ENjghHXEvtiBvngHlz9zy8k+IUmmiAiCGJhBtsIkvE/Su5Ekw2jryXhBf9Gy2dRUmHkWD2xZTORo4HGjDxBLLijlGMu74fkKkVyePXuu9WBYDADoGVT5Bv2kFg9l/qZD1Fx2aTivToS9L36/8hWr8N/zs2N4rvU5s9oeO8pLK164j/nJgRRIrtvE4HX70ZylVJ+4X2IFgfxr+YQnv88to6Di6Js4Y/+SXLdx5SceHkxuc7sWUv9249iquyA/9xbkEQJDd3rWmvYTuDtSVgr2tDlovGogkSvln5yqsL66hhpWcRllUgVRAYFhKbk+meEZDLJ3//+d3r37t0ouf4m/OgJ9qmnnkq7du2YPHlyU4L9A8FjM+kbpqYddoZzYFsfBknUxdAksdhZ2z8XOrp3VZGeOmNlNV/s/Obq48Gwdz2ZfLiG6OLXMLgrGtG07Z1PQEmGCc/7B6EP/ob3rD9hcPopH3efvlFOvZOKyyZhqepK6ei7CLxxD4HpEyi/6AFcfc9Fy2eIfPIyiAZ8I/6MtV0//GffRHD2o9TPuI/SgmiVaLQcmNV+/R4qLnv0G2eYNDlPZseKoh/3d4acR1PySLYCPV5TUAs+j/uRD+1FtLqKlW5NziIUqE2JVR+gpuPYOg5GjtQUFUfVdFz3sRYlDJ5K8vU7dbpYITE81i9tTcmjphMoaV1ATUlGCv9CxSBBjgVR4sFDqGqZfBajpxJT2z4YSiqLCqOGkgpEq+uY12DrMADJ4SU8/3mUROgbu9hKOkbd9Img5Ckb8xCSVaePaqpC8O1HSW9bhveM3+PoNgxN04h8/C/iK2bj7Hse7hMvQ81nqZ9xry5Kcu4tWNv1Q0mGqZt2J2omSfnFD2L0Ngf0+arkmnm4h1yM87jhjdYRX/ku8WUzcfY5p5EP9jdhXU3s6Ac14SeDyZMn4/P5uOiii45+cBO+V4iiyPXXX8/111/PnDlzOOecY/sba8Kxo8RmYncoiVESqXBaDnneKAkYJIFALIPTYsBskPTEXNIFleIZGVEAp8WI2SBxYqdyKj0WArE8ZoOI2SRiFCX8DjOyoiI0JMhkJXxOC+UlVThvfZpX7rqCJX8fT9erH0UwOnEIkLI6KbvwHl1/ZfpEyi97FGNJBZ5Tr0bNJogufg3RbMPV9zw9Jpg+geDbjyKcb8TWvj/l4+6ndspt1E27g/KLH8Loba4LpE4ZT+DNeym78B4sVd0AcPY+m3xwN7Glb2Bwl+E8fsQ33rP0liWIFiflFz/4rbU2ivihYwJBwFBSQT6wQx8rc/q+XUygqWjZlC6cloqipKN6TJAIHfhXsPT8emEdyYRoCmP0tcDarp8eD3gq9fjAXXbMRX5Lq57Yu5/Knr+MI7Vt2VET7OhnU0mu+RD34Isa2a+lti2jftYkTJUdKB19F4LBRLZ2K4HpE5BsbsrH3Y9kLyGxdj6hD/6GpW2foiBq+JNX9Nih3yhcgy4EIBfYQeDN+3QRvgsmIBot7C8haLEge6bdg8Fsp+Ol92CxWOhS7iSvqDhMEqqg0dytN3wEBJp5rFibrLV+Vpg2bRqapn2rovuP/g4LgsAFF1zApEmTWLhwISed9P34EzZBx4pdYe59Zx2qptt2TTi72+G7Z/tHBTSNTbXxRt3r/dTy0b2rGN27iunL95D/Fqba7sEXIUdqiX46GYO7rCgCBugzVKko0c+mIlqdeE7+JUZvc8rH3UfdlPHUTb2D8kse1unA599O/YwHqJs2gfJx9+kVXUUm+ulkEMB31p+wdzkRTcnR8O5T1M+4n9LRdyIazRhcpZSNuZvaybcSmH4X5Zc8jGRzH37BkoRgtKBmjqz8fEwwGFEzieImqKTjCIKAaD0wV6hmEsWkDnRlbGub9uSDe5Bj9bgGXIASDSCH9yEW1uvoeTqhec+x5y8XYfS1wD3kEoyeSoBGG2l27wain08v2HJkizNXutpo6oiWXYgGJIcHg8OHqbwdhvYDDgitFCw7RJP18Od+B6jpOCB84zXVbIrA6/cgR+soH3cfRl8LQPcXD85+jNTmz/Cc8mucx4/Qk+sF/ya27C2cvc/WaYVKnvoZ95PZtQbf2TfoxZ10nLppd6HEg5RdeG9R8Ca6dAaxpW/gOG447iGN5/ZTW5cSmvcPrO3769c9RgzvXvntb0wTfhSsW7eOlStXcuWVV2IymY5+QhO+d4wZM4brr7+eF154oSnB/p6haRqRVA672YCiQjwr4zM2Tn6ysoqsaLT02knkZOKZPEZJRBSgLpbFZNBHxxqSWcqdFuxmA1UeO21KRbxWEzWxLAZJpGOFC7fNhCiC16F/v+dVjcpWrRhw9f0s/uuNbJn6AG0uvRcEE5KGvldfeK/eyZ52l66r4vDgG34tWi5D+KN/IhgtOHudSdnYu6mbeif1Mx+k7Pw7ddruRQ/qscOU8ZRf9KA+NjbuAWqnjCfw+t2Ujb0HS1VXnd12+jUo8SChuc8iWhzF2e/DQbQ49P1TyX/3BPtHjgkAgu8+iZKMoCm6NouWz+j6LNmUrpFyBK8Z0eLUbbucfkwVHQrxQBmSqxxDSbk+e/49dWTVXApU5aj6LPGV7xJd/Br27qfiPkgVPr1tGfVvPYCprDXlY+9BNFn15HranQhmO+UXPYjB5Se5fmGhw92D0lG3IxiMRJe8TuzzaTh6noFn2K8QBIF8pJbAdD2pLi8IohbXmk1R+8bdaLk0Q6/7C7bKcjRNRJJESl1m8gp0KnWSyOsjmB3LHAztVNY0N/0zwwsvvIAkSYwY8c2FuIPxoyfYAH/84x+ZNGkSjzzySFOC/T3jYPq3gEY4dahoxpLtDcgFv+yconHHW/rMoSQKqJrWiFo+sK0PAX0eWxShjd/BjmACVdUVjtv57YdYEgmCgO+sP6LEG2iY8wySw1u07wJwn3ApSjpGbOmbiBYn7oFjMJW10SvP0+6kbuodVFz8MLZ2/SgddRv1Mx+ibvpdlF94b7EjHv10MqgKvhHX4eh+KmgaDe/9hfo376F09AREkwVTeVvKLphA4PWJBKZPoGzc/cUuaOP1iljbDyC5fiElJ15+5ET8KJAsDuTwPjRZdwVNbVyEwV3RiMJUnBUvWFnIob0Y+40ivvoDcns3kjDbkKMBlFRErxiXVBBf+Q7OPmdjbdeP8PznyUfrsHytSq1pGmo+hxxv0ClhBjOSw47BqM9cCQXLLoPVhWh16iqgBYEW0eo85pm0/xRqLkNizYdYWh93xM1UzWUIvHkvuVrdXsPSorv+GpW8LmSz+TM8w67C1e88nRY+/3niy2fpyfVpv0WTc9S/9QCZnV/pImfdhqFmEgSm30U+tJeyMROLXY34qrlEFvwLW+cTi2Io+5Hdt4ngrEmYytvhP+eWY67Ijzqu2VGFzJrw08GkSZMAuO66637chfx/jKqqKkaMGMGsWbPYvXs3LVs2/f18X9A0yCkqNpOEouoz1V/H/q89WVGJp3PsrE9gMxnw2U3kFRXTQR3XaDqPqoLLYiKcyqFZoUOZg7SsYJYkbEYJs0FiRzBBbSxNyxI7ecVETc8+RC++hdWvPkD1rCdoO/YW0ml93zH5W1I2ZqK+/+8viFsc+M+5mcBb9xOa81cEyYCj+6mUjbuPwNQ7CLz1AGWj78Tato+eZE+9g7op4ym76AFM/paUX/QAdVNv1/f+MROxtOyBIEr4z7uVwPSJBN95HEEyFhWmvw5rh0HEvphBYtVcXP3O+073/r8VE8ixAKqqHuIlr2kacrwBNZPQ2X1mO4LTp89hm2wIBcsuyeZGtLkKMUEJks3zg/iGHwnxle8AYGs/8IjHJFZ/SOjDZ7G2768LrRU+tKktS6mf+RCmsta6vo/FodPCp09AMNupuOQhDO4ykusXEnznccxVXfUY0WgmtmwmkYUvYet6UtGdRI43EJh6B5oiU37pIxjcZcU1aEqehrceINewhzOunYTqa4lRNNHMa8FgEPE7TOwKJjFIEq0cVspdZoZ0KMVpbSrc/pzw5Zdf8tlnn3HFFVfgch27wv1PooTSokULzjvvPN5//3327dv3Yy/nfwoD2/owGUQk4cgWP/uP2Z9KaIV/cmHe6uBzD07GFRW2BRJIosjFA1oy5dcD6X2E2VJBMlJ6/u0Y/S31WdnarQeeK1SSbV1OIrLwReIr3wUo+F/fjRJvoG7q7SjJCLYOAyk971ZytVupm3YXaiZByQmX4D7xMpLrPiY4+zE0JY+jx2n4Rl5PZvdaAtMnoBaUsi0te1B6/h3kgrsITLsT5es0pwLcg8ehyTnC85//rrcegJITLiWy6BXqpt2FHA1g73oy2X2byIf2AmBt35/MnnVoSp5c/S5dhMzuwdH9VIz+lsSWvklizTwyO78is/Mr0tuWgWjAUtUV0WBGDu0l8vG/qHnpOuRYAIB8pJbdj59PZMG/EAQBU3l7Ki5+kLIxEyk971Z8w/+MHKomvXExzt4jsXcZirVVL0ylrXWV8v9Scg0QWfQKSiKEe8jFh31+v3d1tno9/rNvxNZB33B1uvcDenJ96q9x9R+NpqmE5v5NT677nFtIrrN653rHl/iG/0mf388mqZs+kVxgJ6Xn34614EGe3PAJoTnPYGnTB//ZNzRKoPOhvQTeuAfJ4aFsjF6wORpcFgPXDG3LUxcd27xOE358hEIhXn75ZQYPHkyvXr1+7OX8f43x48cD8Pe///1HXsn/FkRRoMRqIp1TycoqbuuhiZPZIFHqNJPOK2wLJNlSF2d1dYR9kTQWk0Re0Qu6PruZjKxgMggkszKyqhfzYxmZCpeVEpsRo0Gk3KX79zZz26gssWA1G2nlddDthDNpM/zXhNYtpuHjFxEPEt01N+9M6eg7yYeqCUyfiJpNIRiMlI66HUurnrqa+PoFSBYHZePux+RvSWDGfaS2LcNU2oryix8EoG7KeHJ12zE4fZRf/BAGVymB1+8mvX2Ffj+MFsrGTMRU3p76WY+Q2vzZYe+buXlnXXRs8WvFvfa74D+JCWydhpCt2UJizTzSW5YSWfRqo5hAsjiwtutH7PPp7Hvu19S8cuMhcYGWiSOgYSxtTfm4+yg7/w78I2/Ae/o15APbSa79CEePU7G164e5siMGV9l/NbnO1e8i9sVb2LqedESL0sTquTS8/zSW1sfrQrcFRkFy/ULqZz6oN1PG3Y9kcRQEde/UXWYueRiDu5zEuo+LyXXZmLsRTRZiK2YX57D9I/X9X0lFCUy7CyUdo+zCezD5DxT6NE0l/N5fSO1aRdexN+Bs35ucqlFqM2AURTQEMlkNh1W33XLbDbQpdRLN5Mnmv40YXRN+bDz66KMA3Hrrrd/qvJ9Egg3w7LPPsmDBgm9VHWjC0dGnlYfXrh7IDWd0OkQ9/OvHtPId2kE8u2clg9v7i9TywyXjiqLS/CAVUoPUmCYkFn4UzTb9y8ziJPD63eTDB4opgiDiH3k91vb9CX34LIk18wBdjKRszETkWIC6KbejJMPYOg6mdJS+adZNvaNoDeEZdhWpjYuon/EAaj6Lo/sp+M+9hWzNZmonjy8qV1rb9qHs/DvJBXdTN3k8SuLQmXKTvyXuQeNIrvuY+Fdzvs0tbwRL2754z/wDzr7n4j311xhcfuRYPXI8iKapWFsfh2R1UPPy9QTffgT34It0/2t/SzK711D5y6cpOekKBKMFW8dBGMvakKvdAqKBxOq5aJqKo/spuPqcQ/jjfxfupYDRW0XllX+h8sq/4Cv4O+5HcuNinZL2IwtspDZ/Vug0j8RS1fWQ55VMQrdu27MO/9k3FOer1GySwOsTSW9fgffMP+Lqex6aItPwzhMkvpqDa+AYPKf+Gi2XJvD63XrnesS1OHqeoSfX0yaQq9uq224UBMpSW5YeqGafP75I4QNdLbVu+gQAysbec6gn6BEQz8j867Od3P7WGlbsOnbdgib8eHjqqacAuPPOO3/chTSBE044gYULF3LllVf+2Ev5n4PHbqK5x0oLjw2b+fBERqfFiMdmxmQQKLGbkAQIJrIYRbAYJdxWIyaDSInVRE5RSeZknCYDBklA0yCdU9jVkCJT0HZpVmKjfbkTv9PMcVUejmvjZWzvVlx3/Q30OvMidiycQXL5m43WYG19HKXn3kqudguBN+9FzWcQjWZKR9+FuUU3gu88oSfZVqeeZJe21guvmz7Tu9aXPIwgmaibMp5M9QYMDi/llzyMwducwJv3kdzwCaDHJuXj7sVU0Y76mQ+TXL/gkPuxXzAVNOpnPnJ4C61jwHeNCUxlbXD2PZd8YAfuQWNxn3CZHtMYzcWYACC5dj6C0UKzXz6Dq++on1VcoGaT1M98CNFsx3uEEazYitmF5Po4SkffWZxfj3/5HsHZj2Fu3kWfr7Y4SG9fQWD6RCSHT3/f3WXEV82l4Z0nsLTs3ji5nvcPrB0G4j9Xn8PWWW4TkKO1lF1wF87NCtIAAL8XSURBVObKjsU12EWNzKJ/E1+/gHZn/ZLKPmdQH8uRzilURzPsDqcoc5gxSOC1mXHbTFiNerFpZ0OKncEUwfg3uKs04SeDQCDAlClTGDZsGF26fLMewNfxk0mwKysrOemkk3A4HEc/uAnfCn1aefjDsPbfqFzcp5WH3wxtd8jj76yu4dOtQe59Zx0rdoWLyfjFA1o26ox7bCYufX4JU77YjQBIAsX/9m3lKSbZBqeP8gvvBU0lMO0u5ESo+LsEyUDpebdhaX08De8/TXL9QkDvOutJdh11U25HToSwdRhI2ehCkjxlPHK8AVf/0XjP/EPhS3UCaiaBvfMJlF1wF3JoL7Wv3kI+XAOAtV1f/ZrROmpfu7lRsr8f7sHjsLTpTejDZ0lvW/6d7r0gCJhKW2Nr1684N2zvfALWVr2KnWLv6b+j2dXPUnHppOLj2X2bMHgqkZx+DE4/lhbdSG3+HIPTj73LUALTJxD+5CXQNBzHj8DebRiZXasO/OJCJ+DrNnxqNkV8xWzcg8Z9p9fzfSFTvYHg7McxVXbCM+xXhzwvxxuom3wb2dqt+M+7FXvXk/XHEyFqX7uV7N4N+M+5CedxZ6HmM9S/9QDJ9QsoGfoLPCddWUik79I73+fchKPHaaiZBHXT7iok1+OL3fD09hXUz3oIU3k7ysZMbOSfuX+TVVNRysbe02g27mjQgJysMmXpbi59fklTkv0zwAsvvEBVVRXDhw8/+sFN+MExdOhQOnbsePQDm/CtYZREDIdRED8YDouBEpsJTdOwW4y08NnJyPrIWCCWIZ1TsJokWnhstPHbyasayYxMmctMIJ7BZBCwGSU0FcwmiRYeK6oG6byK12bC7TDTucLF2D/cTrcTzmLf/BcxbJ6H+aA12DoOwn/2jWT3rKN+xgNocg7RZKHsgonFJDux7mMkq1P3wa5oT/2sh0msnY/R25yKSx9BtLkITL+T9PYVuuXnxQ9ibtaR4NuPElsxG9Cp2eUX3qdfc/bjxccb3bOSCvwjrydXs0lny30rWywd3zUm0DSV7N6NuqCoryUGdym2zieQ27e5GBPUvHIj2er1uAaPQzTbsHUa/LOJC/RRrgeRIzWUnnfLIYVsTdOILH6tmAiXFURs9ccnE5r7f3pcN/YeRLON5PqFuiiZt7neuXb6iS2fRWjO01ja9Kb0gol6cr18VvGapefdiiAZUbMpneVWv4vS8+8oWsYZgBIjpFbOIPD5W1QNOoeKIRfQEM8iqyo2sxGfw0jncju1sTRLd4bYHUpiEgXSeZlQMo1RBJfVQCwjk1e+hdVqE34UPP+8zmK94447vvW5P5kEuwk/Pi4Z0JJrhrZFLCTHBlEXMjl4Bhv0ZLx7MzddKpyc2qWc164eSDiVK856y4qGohW62xos2xlGFPSZbgCjr4qyMRN1+s30CSiZA77agsFE6eg7MFd1JfjO4yQ3LgbA0rInZWPv0S2VJt+GHAvqgiZj70GOBqibfCv5SC3O44bjP/dmsvs2UTv5NuR4sDCT9QBqJkHtqzeR3bcJ0Kvj5Rc9oItUvHpz8fHiWkSJ0vNu06viMx8kvePL7/V+a6pSpKgLgtjIM1NJhDC4ShEEUd982w8oduDdgy6k4rJJSDYPnlN/g8nfEkGUEC324r2UYwFqXryWuqm3k9mztnjdyKJXcQ0c00i19EjI7FlLPrjnOwUR34Tsvk0E3rgbyeml7IK7DllLrn4Xta/ehByto2zMROydhuiPB3dT+8qNyJFa/fGuJ6GkYwSm3kl623K8Z/we96AL9Y7z5NvIBbZROmq8flwqSt3UO8jVbad01O0HkusdXxKYcT8mfyvKLry30Ry4mtdnv/MN1ZSefwfmyg7f6fVqNP77acJPEytWrGDfvn388pe//LGX0oQm/CRgMogMbOPn5E7lDOtUSqdyF5IoYDKIiIKArOoJQiIrszkQJ5rJE0zm9NEyUUBVQdE0FE2lIZGlNpYhnMzpFHSXha6VLhxmIw6zxJBfTqC0U1+2zXya1JbPG63D3vUkfCOuJbPzS+rfehBNzheTbEvL7jS88wTxVXMRLQ7Kxt2nP/buE8RWzMbgLqPi0kkYPM0JvHkvibXz9eMuvA9rhwGE5/1Dt5rS1CLLrvj4x/9C0xonQbaOg/GccrVuAfXuk9/7/qikomiaekhMIAgiSjKMwVWKvfMJOLoNw+AqQ0k0FGMC/9k3IdpcRZ2S7zsuyIdryO7d2Chm+z6gyTnqZz5EZtcqfMOvLSa0xecVmdCcZ4h+OgV7j9MoHTUewWA86PHJ2LufSun5d+iz1MtnEZz9KObmnam45CFEm5vI4smEP/on1o6DKCuI30aXvkH4o3/qrMiDkuvA6xOLLDdr2z4AmIASM2Q2fsTOD/5NVe9hVJ3xa2JZSOVAVlUcRhGH1YSKiN9lpVOFm1haJiOr5BWVSFpGEAUS2TyyqiI12XT95DFlyhS8Xi+nnnrqtz73J59gZzIZcrkc27dvJ51O/9jL+Z/HbSO68Po1g7npzE7ce153zMZD57cnL93N7W+tYVV1lLnr63jl853MXVeLJAqIwqEalPo8N408tM3NOlF6/h3kG6qpf/1u1NwBusz+mShz884E3550IMlu0Z2yC+9DSUaoLSTUllY9C8lzkrpXbyYX2I69y1DKxt6td6dfuZlc/U79i/ayRxGNFuqm3E5y06fFdVRcOqnw+Pji7yquxWyj7MJ7MXiaEXjznuJ5/wmytVupf+tB9vzlIhreeeLwBx3li1c02xEkQyPBjf03XnL4qPr9i1Re+Re8p1xNcPbjqLk0ubptyNFabO366UGB9vV3qjHq37yPfS/8jj1PXkjgjXsabcjfFemdX1E37U4kq4vyix5Espc0fn7HSmpfvRkUmfKLH8La5vjiecXHL3kYa5ve5CO11L56C9m6bfhH3Ybz+BHkw/uoe/XmQhJ+N7aOg/SEe8p48g17KBt9J7YOAwq/60vqZ9ynW7mMuw/poGCm6KtdrXfK96/ju0BALy4dTv+gCT8d3H///QBcccV/YM3XhB8cTTHBfxcGg0iF20q524bTYsBqkkhmFYwGEWtBfTyazmEQREodFnKywq5gnHAySyyTL4hPCdjNBoIJ3bnCbpaojWZYszfCrlCS3aE0oigw+Df3UtKiE3tnPUJ651eN1uHocRreM/9Ievty6mc+WOxkl14wEUub4wnNeZrYitmIJmshSR6o+1wveg3RVkLFJQ9jadGNhnefIPr5dL2YP2o8juNHEFv6JsFZkw5Q0EeNx3H8SGJfzCA48+FG8QmAq98oSob+guT6BdTPfAj1SI4cxwg1lyay+DX2Pvdrqp+5FCVWf/gDvyEuEM12jJ5KBPFrlP/vMS5IrJlH7as3Uf2Xi9j3/O+JLnnjO1Pl90PNJql7/W7dZvPMP+Lofkqj55VMgsDrd5NYPRf3oHH4hl97gML9xj0kVs/FNWgcvhHXgSgSnv9CMZEuv/BeBKOF8Ly/F5Lw0yg97zaQDEQWvUZkwYvYugzFf+4tByXXd5Pdtwn/ubcUC/FmwG2F+nWL2fTmXyjr3Jf2o28kLotoKkgSaIqG2WykwmXGYTaQL4gHWgwioYT++cnmZGqjKYLxLKqqkcrJNOGniyVLlrB27Vouu+yy73T+TzbBTiQSvP/++zz66KOMGTOGESNGcPXVV7Nhw4Yfe2n/89hPKb9kQMvDzm+/v7am0fEzv9rHquooeUWjbamDg7eAg/+/PtH4i9ja5ni921yzmfoZ9zf6ot6/SRaT7MKslKWqC+UXPYCWTVH32i3k6ndhbtaJ8ksfAVGi9rXbyOxajbX1cVRc8jBoCrWv3kJ651cYfVVUXP44xrLWBGc+RPQz3dfO6Kui4hdPYCpvR3DWw0QWvdqoai3Z3JRf8jDmig4EZz5MdOmMQyhWx4pszWZqX7mJzJ612LuejKPnGYc9zuDwIceCxZ+VWD3S1zyiJae/uBFrqoKaSyFZHAiSoahEaipvh8FTQT60l+y+TeTqtrP377+ibsrtejI69fYjrrV0zN34Rt6Ao9cZ5Oq2UTd5POkdK7/T6wZdmCTw+kQM7nJ9HspVWnxO0zRiy2YSeP1uvePwi8cxV7RH0zTiK98lMH0CBqefisv1x7P7NlH7yk2oqQjl4+7DXhB/qX31FtRcmvKLHsDa+jjkaB11r92KHA1QOmYi1nZ9gQItfMZ9GDzNKL/oASTrAe0HTVUIzn6czPYVeM/8A/bOJ3zn1wwHRAOb8NNFdXU1M2fOZMSIEbRrd+ioTBN+fDTFBD8+RFGgwmWhymuleYm1SDF3W03Iqkp9IkNWVtkRSpPMKtTHM9iMEpIo4LWb8DvM5BWVz7cF2dOQZEcwycZ9YRoSWbKKgsPp4uybnsBeWkVwxv0oezdgE8EigA1wHneWnmRvW0bgrQJd3GimbPRdBxLqz6aCZNTZSz1OI/rZFEIf/A3BaKZszD3Yup5E5JOXCc15BjQN7+m/0/VbNn1K3eTxyPEggijhPf0aPMN+RWrz59S+dgtytLGwmXvQhXhO+y3pLUsJTL2jyDD7ttA0jcDrE4l+OgVDSSUlJ19V9L3+On7suMB53JmUXnAXJUN/gWhzEVn4IvWzH/1OrxtAjgb0Inn1Onxn34jzuLMaPZ8P7qH25RvI7FmLb8R1lAy9XLfNCu+j5pWbyOxejW/4n/EMvRxNzhKc9UjBnnOknkgDwbcfJb7yXVz9R+MbcS0IAuGP/kn0M70b7j/7RgTJUOxcZ/dtxH/uLUXmnBlwWSG/ayV7Zj6Gq2UnKkffRhYRowQWE5gM4DQbaeWxEUoqyIqMJGiUWIz0b+PDaJAQBYFgIsfeUIZkXkZWNJK5JrGznzLuu+8+AG666abvdP73kmALgnCWIAibBEHYKgjCbf/p9RKJBPfffz9Tp05FkiRuvfVWNm7cSLdu3YovuAn/HRxufvubPH3tJgmzUSx+sI6WWNg7DcE3/Foyu76iftbDaMqBit5+upa5eReCsx8jsXY+AObKDpRf8jAAdZNvJbt3IyZ/SyouexSD00/d9Akk1y/AVN6Oissf01VDp08g/uV7SPYSKi5+CHvXk4kseoXgrEdQcxk9ib7owcKGPJX6N+5FSR+wG9uvVGrrNJjIgn/R8M7jh1S1jwXpLUtBlWl29bP4zvwDtk6DD3ucqbIDcngfcjSApuRJblyErf2ARsfYOgwoisGlNi7G0kpXPVZSUTRVQdM08uF9yOFajJ5KnMePoOr3L9L8mheouGwSRk8zyi968IhrtVR1wdH9FLyn/ZZmv34OwWg+osLqN0FT8oQ+/LsuTNKyJxWXPoLhoKBAzaUJzn6M8PznsXYYoL+PrjI0OU/og7/pVhxt++iPF+w16qaMRzRZqLjsUSwtupPatoy6KeMRjGYqLp2EuVkncvU7qX31ZtRMnPJx92Mt3J/U1i8IzNBns8oveqCRDZumqTS8/zSpTYvxDPvVIRv+0XAwS+Ng5BWNGSurv/W9a8J/B/tVQm+88cYfeSX/O/g+44KmmOCnA0EQMEpiIwvDEpuJ3q28dK1006HUQS6nkFO0wtiYhtNiIJzMEc/IdChzIongsBjZ2xBncyBJOJkjEMtSG80gWByc8ucnMLu81Lx+N5narcga7O8R70+yM9tXEHjzPtR8pqAuPh57t2FEF72qu38IAr7h1+IadCGJVXP0Ir4q4z/7RlyDxhUKvrpWi6v/6KJiee1L15PduwFBEHD1P5+yMROQo3XUvHTdIQVmV59z8I+6jVxgBzUvXX/ImNmxQE1GyFavxz3kEsovvBf3gNFHtAb9seMCg6sMW/sBOiX9kodx9h5JevPn34kmn975le5+Eg9SNvYeHN2GNXo+uXExNa/cgJpNUX7xAzh6nKaft30FtS/fgJqOUT7uPhw9zyiy1FKbPsMz7Fd4TrtGn6OedldhL78Kz7CrQFVoeO8p4ivextn3PHzD/6yrhWcS1E27k2zNZl3zpVBUNwBGCTJ71rHmlfuxlrWi9biJmE0WfHYzogjJnK41ZLWaiGXy7IskiaVVTCYRs1HAatJZHuFUjr3hJAgaO+uTrK+JcBQJhCb8iNixYwfvvfceo0aNokWLFt/pGv/x2ysIggT8DRgOdAUuFgThUEngY0Q+n+fqq69m79693HXXXdx+++0MGaJXkvx+Pz6fD0Vpqvr8mOhU4eSMruX0qnIztIO/0XOD2vp47eqBDOngLwqbHQ2OHqfiPf13pLd+QXD2o42+rEWzjbKx9xRmqp4sKnqbSltRfukkRIuTuml3kN62DIOrlPLLJuld79mPEflsKpKzlIrLHsXapjehuf9H6MNnQRDxnX0jJSdfSWrzZ9S+ciP50F4EgxHf8Gv1tez8ipoXryW7d+OBtRjN+M+7VbcEW7+Q2pdvIFe/81vdO1MzXbAncRRl8v0V9Lrpd7Hvhd9j7zIUo68FkUWvkdr6hX7fep6Omkmw97nfEFvxNp6TrgQgu2cdNf/+E7UvXUfw7Un4zvxDsXJdhKbpRubHiOT6BWj5TCMlzWNBPlxD7Wu36D6d/UZRNvbuRmvJ1e+k9uUbSG1cRMnQX1A6ajyiyYocC1I75TYSq3RV8NLRdyKYLEQ+eYXg7EcxVbTX2Qi+FsRXvkv9mzrVu+KyRzH6qshUr6PuNd1SofyShzE376y/jo2LqX/rAUylbXSKeqPkWiM09/9Irv0I9wmX4up//rd6rQDRzJEpX01d7J8mUqkU//jHP+jevTunnHLK0U9owlHxfcYFTTHBTx+JTJ54RkYSBQxGgbymEoinySkqDrMRv8NMc4+NSrcVq0nCZzezpS5GPKugAYIIWUUjntbntFs1r2DEzU9jtjvYM20CSv1ObNKBgNV53Fn4RlxHZudXBF6/W7fwEiV8I6/H2edc4stn6eNXqoxn6C/wnvF70ttXUDf5NpREGM/Qy/GNuJ7MnvXUvnID+eAebIXirmA0Uzt5PPEv30PTNKzt+lH5iyeQ7B4C0ycSWfRaoxjF3mkIFZdNKjDobiW2bOYhc9vfBNHuRnKVkdq4CPUos80/pbhASYRJbVmKqbxdI0vLo0FTFSKLXiMw7S4kWwmVv3iiaJUJuv93aN4/CM56GKOvJZVXPIWlqhuaphL9fDqB1+9Gcvqp+MUTWFr21FmBL11PPriH0tF34Op/PnKkVtfbqdmE/5ybcfUfjZrLUD/jfpJr5+M+4VI8p1ytz7WnotRNub2gzzK+2LkGkIH4rvWsf2UiNl8l3a68D5fTidkokMsrGEWQFcjnIZmTScsabqsRr81MTTiDrAhE0nn6tPIysK2fNqVOmrvtuKxGfHYzFuPhFfyb8ONj0qRJwHfvXgMI35XqWryAIAwC7tY07czCz+MBNE176Ejn9O3bV1u+/PCqzDt37mTcuHEsXboU0DfXHTt2sGTJEiZPnszdd9/NwIFHNp9vwg+LFbvCXPr8EnKyiskgMrp3FZOX7gZ0OvhNZ3biD8Pas2JXmIv/uYS8rCIIoB7Dxyz2xVuEP34BW5eTDvEhVvNZgjMfIr19OZ5hvyomP0oyTOD1u8kFduA76484ep6BJudpmPM0yXUfY+82DN9ZfwJRIrLwJWJfzMDcojul592KZPeQ3rGy6J3tG35tsXKZ3beJ+rcnocSDlJx4Ga7+oxutJ73jS4LvPo6aSeIZ+gucfc89pk1G0zSCsx8jtWEhroFjKTnxsm+1Of23oWkqsS9mEFnwEpbWx1E29u5jfp2J1XMJz38eQRDxDb+2Ubde0zQSX75H+OMXEMw2/OfcXOwwH+49UbNJgu88TnrrF9h7nI7vjN8X5q2eJ75iNtZ2/fCfewuiyUpq82cEZz+G5Cql/MJ7MbjLAX1+rOH9pzE360zZ2ImNggtN0wjP+wfxle/o78vQXzTq0BwrRAFEQRcH1L72+OvXDP5GJf8m/Dh46qmnuP7663nxxRePOn8tCMIKTdP6/peW9rPFt40LmmKCny/yisqeUAqLQSQjq8iqikkSUVUNBY22fmdR4LQumiGUymKSRDbVRgnEc2ypjbOqOkQ2r6KoKhLQwmvHZhKJBWuYfs9vkOU8bS97CMnTkoOn7pPrFxJ89wnd93jsPUhWlz5utOR1Ip+8jKXVcZSefzui2UZq2zKCsx7RBc4uuAtTeTsy1esLomlZ/GffiK3DQJRMgobZj5Hevhxb15PwnflHRJMVNZch9OGzJNd+hLmqG/6zb2ykgaKk4zS8/xfSW5ZgaX08vuF/bjQG9U1I71pFYPpEjL4qSs+/HaOn2ff4Dn3/yAV2UP/WgyjJcGF8rv0xnZcP7SX47hPk9m3C3v0UvKf/DtF0wOI137CH4OzHyNVtw9nnXDzDfokgGVHSMRrefZL0tmXYupyE76w/IZosJNZ8RMMHf0Wyeyi74E5MZW3J7FlL/VsPgqZROvoOLC2664K6b9xDrnYr3jN+h/M43SVCjgd1N5toHaXn31EUNNuP7L5N1E27E5PDyynXP4XZXYZRVMjkdRtOFY14Jo+m6ePxnStcHN/Ki8NswGg2cHL7UuIZBZtFQtAEEpkswYQucNau1EmbUgcmQ1Mb+6eGeDyOx+OhZ8+erFx59LHII8UF30eCPQY4S9O0qws/Xw4M0DTtj0c655s2U4CePXvy29/+FrPZjMFgYNOmTUSjUUaOHMnIkSP/o/U24btjxa4wT83bzKdbg6iaTou5qH9L3lxZTV5WMRpEXrt6IJtq40xbtpu1+6KoKkgiaOhJx9EQXfIGkYUvYu96Mr6R1zdK5jQlryenmz7FPWgc7hMvQxAE1GxKV6Dc+SXuwRfjPuES/VqfTSW6+DXMzTpTev4dSA4PiXUfE5rzDKLFSemo8Zibd0aOBaif+Qi5mk04jh+BZ9ivEI1m1EyChjl/JbVpMeaqbvhGXo+xpKK4HiUZoWHOM6S3LsXcrDPes/6EqbTVUV+jpsi61/eqDzBVtMd7+u8wN+v0Hd6RHxa5+p2EPvw72T1rsXU6Ad/I6xpZWB0J+UgtoQ/+Rmbnl5hb9sQ/8joMrgOBiJwIEXr/adLbl2Np0xv/yOuR7B40RSay+FViS97A6G9F6ajxGH1V5Op3Uv/WA8jRAJ5TrsbZ+2y0bJL6WY+Q2fklzr7n4Rl2FYIoEVv+NuGP/ompWUfKLphQ7FDHls0iPP+fWFofryuNmg68Dk3TCH/0HPEVs3H1O5+SYVd9p+R6v5jZ1Se0IZ6VmbZ8D4qiIYoC953XnUsGtPzW12zCD4/WrVsTiUQIhUKIR+neNCXYx4ZvGxc0xQQ/T2RlhWgqTzCRxe8wk8rJWI0SmQI13GMz4bIYqQ6nqI9lUdHwOc1oGmTyMlvq4sTTOfaG0ny1J0wqmyOUVmhfZqPCaSWtqKxes57Fz1wPmkarSx9C8LQgf9AaUluXUj/zYYyeSl2Q1Kmz6hJr5tEw5xmM/paUXTARg8tPLrCdwBv3oWZi+EbegL3TEORYPfVvPUiudguuQRdScsKlIAhEP59OdPFkDJ5KSs+9BVO5rs2QWPcxobn/Bwh4T/8t9m6nFPcLTdNIrJpDeP4LIAh4Tv4ljuPOKtpvfRPSO1YSfHsSmiJTcsIlOPucgyAdfuTox4KazxBbOoPokulIFqfurlFgh30TNEUmtnwm0cWTESQj3jN+j73rSQee11TiK98lsuBFBKNZL8gXxEgz1esIvv0oSjKC55Rf4ex9Nigyofn/JPHle5hb9tQbJjY38VUfEJr7LIaSCsouuAujtzn50F4Cr9+NkmjAf87N2DoOAvRkv276BNR0jGZjJmIoKK/vh55c34Vkc9Pikofwl/oxm0RyClR57VS5zewIptgbTSMrGlazRJdSB4M7l9G9mYecrJFXFGxmAy29NsrdVhRVI5nJI4oiDosBYxNH/CeJxx9/nJtuuonJkydz8cUXH/X4Hz3BFgThN8BvAFq2bNln165dR7zmsmXLmDx5MslkkrKyMiorKznzzDNp3/7YqmRN+P6xv3Odzato6B05UyGhBliyvYGBbX1sqo1z+1trGp0rCTCuf0uC8Szb6xNsrU9+4++Kfj6dyCcvY+t6Ev6RjTvZmqoQ+uBvJFbPxXHccLynX4MgSmiKTMMHfyW5Zl6ha/1nBIOR5MbFNLz3JKLZodt/VXYkV7ddT9jiQTwnX4Wz77mgykQWvkxs2VsY/S3xn3MTprK2aJpGcu18QvP+AZpKyUlX4Ow9srhhappGct3HhOc/j5pN4uo3CvegcY3sno6E5MbFhOf9AyUZxtZxMK6BY7+zFdT3iVz9LmJL3yC5fiGi2UbJyVfh6Hn6UZNONZ8l9sUMYkteB1HCc9KVOI4f/rV7NZ/wR/9Ek/OUnHxl8V7mQ3sJvvM4uZrNOHqegee03yAaLSTWzCM091lEsw3/qNuwVHUjF9xN/Yz7kaMBvRrd60w0VTnQze4wEP85NyEaLbpH5qJXiX0+DVvHwfjPuRnBcCBoOTi5dvY9r0Ab+3bJdfMSC26rkY01cVTAJAlM+Y2+ie//u2jqXP80sXXrVjp06MB1113Hk08+edTjmxLsY8OxxAVNMcHPG3lFpTqcQkSgIZnFajTgshood1kQBQENveBYH8uwviaGKEJtJEvv1iWIiPjsJkRBQxME8orK3DX7WF0dIZDI0qWihExeZmcwTjwts2vHZlb8/WbQNJqNewC+VsjO7FpNYMZ9iFYX5Rfei9HbHNCT1vqZDxVEUydiKm+HkghT/9YDZPdtLBTqL9UTtg//TmL1XD1hO+dmJIeHzO7VBGc/hpKO6Uy1fqP0/SpSS8O7T5CtXo+1fX+8Z/y+mNhDocg85xkyu1ZhquyE9/Rrjmlvl2P1hD74G+ntyzG4y3ENHIuj+ynHZKv5Q0LNpUmsnkts6ZsoiRC2LkPxnvbbI86KH4zM7jWE5v2DfP1OrB0G4j39d400WPLhGhrmPE129xosbfvgG34tBocXTZH1Jsnn0zG4y/Gfewvmyg7kI7UE336EXM0WXP1HU3LSFaBphOf/k/jKd/Ui+nm3Ilocejd7xgMgCJRdMKFYDMjVbaNu+kQETaXL5fdgKO1AQgYFUIHs3o3UTZ+AZHNRfvFDmFylGDkg2ltRYqCVz4HbaiSZybM3ksJhsVDuNtOnpZeTu5RTG8sSTGQosRrREOjTyovD0kQJ/zlgyJAhLF++nEwmc0zx4A+ZYH+vFPGDkcvlMBqN36mb1ITvF3/7eCuPz92EqulfMid08HPdaR0PSRwuf2Epi7YEGz1mkgTG9m1Bt2Zuwqkc8XSef3++k2z+yHNK0SWvE1n4ErbOJxZVHvdD0zSd7r30DawdB+nej0azTg37fDqRRa9gbt6V0vNvR7KXkKvbTuCtB1ASIbyn/w5nrzN0Gti7T5LeuhRrh4H4hl+LZHWS3r6ChveeQsnEKTnhMlz9z0cQJeRYgIY5fyWzY2WhW/1HTKWti2tSUlHCC/5Ncs08JLsH9wmX6knpUejUajZF7Iu3iC2fhZZLYW7eBUfPM7B1GnJMSfr3BTWXIb11CYnVc8nsWo1gNOM8bgSuQWMbKWwfDpqmkly/kMgnr6DEAtg6DcFzyq8xuA4KOMI1hOb+n97Vbt4F34jrMHqb6xX/r97XqeKiAe9Zf9Ip4bl0gY43H3PLHpSecwuSw6MXTN7/C0LBSsVS1Q01m6T+7Ulktq/A2W8UnpN/eVDR5W8k13yIo9eZeM/4feNijabqQdWX7+nnDfvVd6aFf52cccmAljx4fo9vfa0m/HdxxRVX8PLLL7Nq1Sp69ux51OObEuxjw/dJET8YTTHBTwfpnMz6fVFAQBKgpc+G13Eow6k2kmZzfRynycDmQIKOZQ40AexGibSs4jBJaMDm2jhWg0gwkWVzIA4aZBSFunCGrJxny+YtrHjuZjRFofm4+6GsTaPfk63dSuD1iaBpus1ngRWWC+wg8Ma9qJkY/rNvwtZxkD5GNvf/SK75EGvbvvjOuQnJ4igWdAWzFf/IG7G2OR4lFdWZaluWYG7ZA/+I6zC4y9FUhfjyt4kserVQUL5C71YX9phi8X3Bv1CTUezdT6HkxEsbsbkOB03TyOxYSWTRq+RqtyBaXTh6nIa9+6nHxJD7vqBpGrm6bSTXfkRi7Xy0bBJzi+6UnHhZ0Wf7m5AP7SWy8CVSmz9DcpXiPfU3xe4xFLray2YS/XSKfv9O+RWOnmfoKuHBPTqVvHaLTiU/7RpEs43kpk9peP9pAPwjrsXWcTByIkRw1sNkq9frDLSTr0QQpQKD4a96N3vMRIweXZw3vfMr6t96AKPVQbcr7kd1N8dqksgrOvU7tHM9ta9PRLKVUH7xgxhcpQjQSLjXZQKLJDGgvReTKLC+NoHDYqSF10b35m4Gty8jLytsqInhdZiwm400c1uoKLE2da1/4ti4cSNdunThyiuv5N///vcxnfNDJtgGYDNwKrAXWAZcomnauiOdcyyb6QcffMDevXu56qqrDnk8EonQo0cPKioq8Hq9/9H6m3Bs2O99vR+ndy3nmpPa0aeVhxW7wkfsYLfy2qiJZZAV9ZjmsA9GdOkMIgv+hbXjIErPuaVR5xEgtnwW4Y+ex9y8C6UX3FlMBJMbFulda1tJcS5HSUUJzn6MzM4vsfc4De/pv0MwmIgvn0V4wYtINje+s2/A2qoXSipK6IO/kdr8GabKTvhGXIvJ3/JAB3b+C6iZBM6+51Iy5JJGiXB23ybC818gu3c9Bm9zSoZcjK3ziceQaCdJrP6Q+FfvI4f2gmTA2vp4rO36YWl9HIaSyu89qJSjAdI7vyKzfTnpHSvQ8lkkVxnO487C0evMo1anNU0jvXUpkcWvkQ/swFjWFu8pV2NpdSBRUfMZYkveJLr0DQTJQMnQXxzoWh9U5ddn1q7F4PKT3beJ4DuPIUfqcA8eh3vwRXqFesG/iS+fhalZJ0pHjcfg9JMP7aV+xv3kw/v04klB9VvNpamf9TCZ7StwD74I9wmXNrp/mqrQMOevJNd8iGvAGEpOuuI731+BQwXMzuhaTq8WJU3d658wGhoa8Pv9DBkyhMWLFx/TOU0J9rHh28YFTTHBzw+JTJ411VEENGrjGdqVOmnptWE3G4hnZAQBnBYjiqKyem+EjTUxFAV6tXRhMxtRFY0NtVEEJJLZHD6HibxcEJtMZ5Flhd3hFBv3RXFajYSTGer27GbZszeh5HOUjbvvkNlfnQ48ESUZ1n2MCwrbciJE/Yz7ydVspuTEy3ENulB/DV+9T2jec0hOnz4yVtGeXP1OgrMmkW/YrXdIT7wcJAOJ1R8Snv9PgEbU73x4nz4StWsVpooOerf6oJEvNZsk+vl0YsvfBsB53HBcA8dgcHzz51TTNDK7VhFf+Q7prV+ApmL0tcDaYQDWNr0xN+v8vXe21XyGbPUG0jtWkN6yFDlSA5JBZ9j1OQdz8y5HvYYcrSP62TQSa+YhGEy4BlyAq//5jcbL0rtWEf7wH+Qbdhe62tdgcPrRVIXYF28RWfwaosmqU8kLBffw/Of1sbrKDvjPvRVjSQWZ3aupf3sSWi6N76w/Y+96ki6ituBFYsvewtKqF/5R45EsDgASa+fT8P7TGH1VtLrwbiSHH58TBERESSS5Yw2rXrwLo8tH64sfwFriR9UgL4Os6fu8gG7b5bKa6VDhxGIUiafyqIKAySBy5ZA2VLptpDIy24NJvHYjVpOEKIk4zUa8dhMlth+XkdCEI+OSSy5hypQprF+/ni5djv55hx8wwS5cfATwFCAB/9I07YFvOv5YNtNcLsfq1avp27cv6XSav//970yePJmysjK6dOnCxo0bKSsr41//+td/vP4mHB1/+3grj32wqZhICIDZKDLh7G7c+866oujZa1cP5JXPdzLzq33Fcw+XgBwrYitmE573DyxteutiJV+bAU5uXEzwnccxuMsoG3N3sUqZrdlC/Yz7UbMJfMOvw97lRP2Ld/FkYp9Pw1jamtLzbsXoa0G2divB2Y8ih/bh6jeKkqGXg2QkteETQvP+gZpN4R40FvfACxEMRpRUlMjCl0is/hDR5qbkxMsadauLiecnr5AP7sLgaYar/+hjonppmkZu3yaSGxeR3rIEOVoHgGT3YGrWCVNZG4z+Vhg9lUjuckSz/RsTQ03T0HIp5GgAOVxDrmE3+brtZGs2o8R1poHk8GHtMAB75xMwt+h+1HkxTZFJblxEbOmb5Ot3YvBUUnLCpdi6DD2IDq6S2vAJ4QUvocTrsXUZimfYrzA4fQcq/4tfBUHEM+wqHL3OAlUhuuR1op9OQXL48J9zI5YW3ZGjAerffoTcvk04e4/UadySkfT2FQTf1tVbS0fdhqWlntjL8Qbq37yXXGAH3jN+f4jVlqbIBN95nNTGRcWZ/f+keGEyiORltfgZl0QBNA1V0/9GDvaQb8JPBxMmTOC+++7jzTffZPTo0cd0TlOCfez4NnFBU0zw80MqJ7M3nCaZlYmkcrTx21E0MIg6o0cFHGYDZU4LO+sTbAkm8NlMNCSzVLqsNCSzbK1LUGKTiGcU2pY6yORVcoqKzSSxfl+EvdEMbouEiMCmmijhdJ6Nm7ax5ZU7UNJxysZOxFLVrdG6lGSYwBv3kqvbhve03+Lsrc/oq/ksoTnPkFy/AFunIfhGXIdospLdu5H6WY+gpMK6NWPvs9HkLOH5L5D46n2MZW3wn30jptLWyNEADe8/TWbXV7ouy5l/xOhvgaZppDYsJPzxv1ASIezdhlEy9BeNRM7kWIDI4ikk134EooSj5+m4+o06JkEzJRkhuXERqc2fk61eB6oCkgFTeTvMFe0xlrbB6G2OoaQcye5txPg7HDQlj5IIkY/UIjdUk6vfQa52G7nA9uK1LS17Yus4GFunIUhW51HXmKvfReyLGSTXLwBBwNnrLNyDxyHZD+x9+fA+wgv+TXrz50jucryn/aZYBMnWbiU05xlydduwdhyE74zfI9k9ZPduJPjuE8jhGlwDLqDkxEtBEIkteYPI4tf0GflR4zGVtkZJxwm+PUnXZul9diFWMKBpGtHPpxFd9Crmlj1pMfoO3E47MlBiM6IJEFq/lHWT78Pqa8ZJ1z6GwexDlQANLEYJWVWQVZVYWsZvM9GhwoWqCbitBpxmE0lZwW83MrRTBaqmoWmgqiouq4l4JosggMNsREOklc/WxML5CSIQCFBeXs6wYcOYP3/+MZ/3gybY3xbHSgcDSCaT3HHHHTQ0NHD55Zdz0kknIcsydrud1q1bs2DBAlq3bv3DLrgJh8xggz5bPbi9v5Ho2Q1ndGLJ9oZGNHFBzze+M+Kr5hKa8wzmqi6UjZl4iLVEpnod9W/eD4Kgq0YWNlwlEaZ+5oNk924ozuoIoqQnZu88jiZn8Z72W+w9TkfLZwl/XNhQfS3wjbgOc7NOKMkIoY/+SWrDQgze5nhPuwZrm+MBPYkPf/RPsnvXY/S1oGTo5Vg7DDpI8EQltflzYp9PJ1e3DdHmxtHrLJy9zmykPnokaJqGHNpLZvdqstXrydZu1bvbB5crJAOS1YVgsiEYTAiiCJqGJudQsynUTBxNzjW6rqGkElNFe8xVXbG07IHR3+qYvuyVRJjE6rnEv3wPJdGA0dcC18Cx2Lue1Ki4kNn5JZGFL5Gr24apvB2eU39dpJRlqjcQ+vD/yAd2YG3XT59dc5WSq99Jw3tPkavdqiu3nv47RIuD5KZPCb3/NJqm6hXqLifqyuZL3iDyySsYS1tROvrOovicLmJzL2omQel5t2Jt16/Ra1DzWYKzHia9bRklJ1+Je8CYo77ug2E3SySzjS2BTu9azscbA8iqhlSgix/8ce/f2sP0aw7vd96EHwe5XI7KykocDgc7d+485mCnKcH+YdAUE/z8oGkaDYksNZEMsqpS7rKQUVRUVcNpNqBqIKsaLbw2djck2dGQwGs10ZDK06HUzrqaOA2JDKIgIIrQ1u/AYpQwiCINySxLtzdQG02hKBpui4m0rLC2OsTuUJpgsJ66aXeixIKUnn/7IerPai5DcPYk0lu/aCSCqWma7oqx8KWCYvcdGL3NG6lU6yNjf0ayukhtXUrD+8+gZhOUDLkEV//RIEok13xI+ON/oeYyuAaMxj3wQkSTBTWbIrpkOrFlsxAEAWefc3ANuKDRmFU+XENsyesk1s0HRcHavj/O40dgaXP8MYmhqdkkmT1rye5ZR3bfRnKBHWi5g7XVBUSLA9FiRzBa9GRb03QP7FwaNZs8xA5MNNsxlrfD3KwTlhbdMVd1ayQEesTPgKqQ3voF8S/fI7PzSwSDGUevM3D1v6DRiJicCOld7VVzECQj7oFjcfYbVRSTjSx+jfjKd5Fsbjyn/RZbpyGg5PWGyBczkJw+/COvx9KyJ3K8gYZ3Hyeza7WuKH7mHxDNtgPaOomGwijgmfoa5TwNHzxDcu18XZ9n+J8RJCNWwOsAm9lIbP1ilr38IM5mbRn6p0nYXF4kUUTQBEQ0KkssVHpstC11YBJFsoqG3WygR3M3qqZrrdRFU9itRipdVpwWA238DlRNt+/aUpcglswSiGfwOcwM7VhKlc9x1PvbhP8uxo8fz8MPP8zs2bM5++yzj/m8n22C/eKLL7J48WImTpzYyOy7rq6Om2++mXvuuYc2bdp8wxWa8H1hxa4wM1ZW8/ryPSiqhtFwoIP9dRXxg2ni/Vt7WLYzXEw6ROC0rrp10oLN9eTko3tGJjcsIvjOYxj9rSi/8J5GVVEoUMPeuAc5FsB35p9w9DgV0Cu1oY8KapMtuuM/9xYMDi9yvIHgO4+T3b0aW6cheM/844EZ7DnPoCRCuPqeh/vESxGNFtI7VhL68FnkcA3WjoPwnHwVRk+l3q3e/DnhT15GDlVjKm+He8glWNv3b6Qsmtm9mvjyWaS3LgPA0uZ4HN1Pxdp+wDFtZPuh5jPIoX3kw/tQ4kGURAglHUfLpdGUHKgqCIKebJusSFYXoq0Eg6sUg6cSo7d5I1uMo0GTc6S3LSexbn6RpmZpdRzOfudhbdunkYBZZtcqop9OJlu9HslVRsmJl2HvdjKCICLHg0QWvkRy3cdITj+eU3+NreNgUGWin79O9PPpiGZbY0rYR/8ksXquTgk75xaMnkrdCuXdJw6x6wBIbVlCcPZjiGZ7QdCmbeN7l00SeONestXrdXG040cc9fWLAnSucGKURMb1a0mnCidjnv3skAR6+a5wUZ/gcN+o1wxty20jjo1u1IQfHi+99BJXXnklTzzxBNdff/0xn9eUYP8waIoJfr5QVI1gIktOVvHYjMiqRjiZRxM0yhxmHBYjOVll5a4Q1eEUrXwOujd3szeSIierxNJ5Klxmqrx2crJKNJUjks7z5e4wO+vj7IvlkFAosRhYvS9OMJ5mT0xBSUaomz6BfHA3/rNvwN5laKN16cKXLxBf8TbWtn11G8fCOFd6x5cEZz+qW0GOuA57pyG6kvWyWYQXvoRkc+EbeQPW1sfpI2NznyW1aTGm8nb4hv9ZF0tLRggv+BfJtfORnKV4Tr6ywOISkKMBIoteIbluAYLJgqvPuTj7ndco0ZYTIeIr3yWxag5qKorkLsfR/RTsXU8uirQdCzRNRYkFyYf2IkfrUBINKKkYaiaBJmdBkfUDJQOi0YJosSPaSpAcPgzuMozeKiSn71t1VHP1O0muW0By3XyURAjJ4cN5/Agcxw8/5DXGvphB4sv30FQFR88zKBlyCZLDg6YqJNbMI/LJy6jpOI7jhuMZerkuTla9job3n0EOVeuip6dcrVutbf6chjnPoMlZPKf+FkfP0wFIrvmQ0Id/b+QOA7o2Tv1bD5CtXo/7hEtxD76o0esst0Fy9VzWzXgGZ8tu9PrlvXhcJVhMIjazifalNtKKQGufDZdVosJlo0ulC7fdhCQIGCQRRdXYXBtj8ZYAOUWjxGqkR1UJDqsRNFAUlX3RDIFYmhW7wnSudGI1GTizayUl9iaq+E8FmUyG8vJySktL2bJly7f6e/jZJthXXHEFAwYM4Pe//z2ge2D+7W9/45lnnuHKK6/krrvu+iGX2oTD4OCZ66/PYO//edw/PkNWdbrYvef1KNLIRUHg3oJl0eSlu5m2bDdmg9goAT8S0ttXUD/zQSS7l7IL7y3SwfdDSccJznqIzK7VjcQuQJ+9Cc39G4LJiv/sm7C2Pq4w7zODyKLX9A11+LVY2/ZBzSYJf/xvEqvmILnL8Z1+DdZ2/dDknC7K8fl0NFXG2fts3IMu1L03VYXkuo+JfjYVOVKL0d8K14DR2LsMbWS1IUcDJFbPJbH2I5RYPYLRjLVtP2wdB2Jp0+eYqFg/NNRsivSOlaS3LiW1ZSlaLoVk92DvNgxHrzMbbf6aquhd+i/eJFezBcnhwz1oLI6eZyIYjAURtxnElr2Fpiq4+p2Pe9CFiCYrmT1rCX3wN/INexqpkmaq19Hw7pPIkTpcA8dQcsKlCJJB9yaf9QhKIlS06xAE4SDf01cwVban9Pw7G6mUgr7RB16fSD6457DB2JHQr7WHjuVONOCC3lUAXPnvpcQzB7rY7UvtVEfS5AuFIuUwH+TWPhsLbh72Ld+JJvxQ6N69O1u2bCEajWKxHHuBqynB/mHQFBP8byGv6N+F+wWd8orKnnAKm1Eik1dwWYzImkYyoyBJUOGyEk3n2FQXR1X0jmFOVlmyLUgglqYhmSeTV6gOJVE0jW01KVKAmkkQePM+vWh6+m91C6evIf7le4Q+/DtGb3NKL7irSMnW7TkfJlezGWefc/CcfBWCwVgYGXsMOVSNs885lJx0BaLRorOpPnwWNRXTi+8nXKLvY9XrCM97TmdsNeuEZ9hVRRZdrn4X0U8nk9r0KYLRgqPXmbj6nteIwabJeVKbPyOx+kMyu1YBGqbydtg6DsbaYcAxM8x+SGiaSq5uO+ktS0lt/pR8cDcIIta2fXD0PENvKBykM5MP7SW27C0Saz4CVcHe7WTcgy8uNiUyO1YSXvBv8vU7MTfvivf03+pFi3RcH79bNQfJVYbvrD9ibdMbNZskNO+fJNfOw1TeDv85N2P0VTXyJre06oX/nJuR7CVAgc325v2oqQi+4dfi7noScqPXpJFdMpW6T17D0b4fXS++DaPFhtdpZkhrD2kZwqk8JTaJZE7vWHercDOgvY9kTkHVBFp7rUTTeZbvCrGlJk6bMjuCKNCnpYcSmxlRFECDVXsjbNgXYWtdkp5V+vrO6F5BqfPY954m/LD45z//yW9+8xv++te/8oc//OFbnfuzTbDfffddbr31Vm699VY++eQTFi5cSO/evbn++usZMGDAD7zSJnwXHKw4LgowpL2f4d0rCadyxST866Jprbw2doVSR712du9GAm/cA6JE2di7DxE50RS50LHW7Rr8595STFpz9bsIznqYfEM1rkFjKRlyCYJkIFe3jeA7j5MP7i5US3+FaLaT2bOWhjl/RQ5V60Icp/4ag7scOd6gV6fXfIRgtuHuPxpnn3MQzTY90V6/kNjSN8gHdyM5vDiOG46j15mNRE00TSW7Zx3JDZ+Q2vI5ajICgoi5siPmVr10mlazTv8VJXE1lyFXu5nM7rU6HX3vBlAVRItTn83uMhRLq16NNlAlFSWxZp5OF4/WYSipxDVgNI7up+mJdT5L4qv3iS55HTUVxdb5REpOugJjSQVyIkRkwb/1brarDN8Zv8Parh9qPkPkk1eIL38bg7sM38jrsbTorlPCv5hB5JNXkBw+Ss+7tSgio+bSNLz3F1KbFuvU8rP+jGg0N3p9+YZq6l6fiJqK6nTCNr2/030SAUEE5WuEi/2Cf0u2NxBP53lu0fZDBP2aOtg/HSxdupSBAwdy7bXX8tRTT32rc5sS7B8GTTHB/zYOTrDjmTxmo0SZ04IkCkiCgCgKfLUnhKxomCSR7fVJOlU6qQ6laIjr1NpUTiaelqmNpdgXTbOxLoNGYexn9qOktyzBO3AsjqG/OCQhTe9aRXDmw6Cp+M+9pUgp15Q84Y//TXzF23ridu4tGL3N9b1o4UvEV8zG4GmGb8S1WKq6oWQSRBb8m8SqD3Qm1rBfYet8AmgqybUfEVn0KkoihLVdP0pOvLzIojrY+hLA1nEQzt4jMbfo0WitcixIauMikhsXk6vZBIDk9GNpdRyWlj0wV3XFUFLxgyfcmqbq42nV68nuXkN611cHYpSqrtg7n4Ct0wnFZHb/OZkdX+qibNuWgWTE0f1UXAMuKDZDMtUbiCx6hezu1Rjc5ZScdKV+/9BIrJ5HZOGLBwRkT7gU0WRtzCocOJaSIRchSEadEv72JOTQXl3IdMhFxRgluWERDe8/hWi20+7CO3G26ojXZmJzUB+X01SF0IfPkvhqDp6ep1J61p8wSAaqSq20cFvJa9DCY8VklAhEU8iaht9mocpro2O5C0kSqY+n0QCHyYgkwJa6GCaTEb/dxEkdSymx6wl2NJVjbyTF5to4tbEMaNC8xMqInpVI0jeL3zbhvwNN0+jcuTPV1dWEw2FMpm/HLPjZJtgAH374IZ9//jmpVIprrrkGURRZunQpgUCA+vp6evbsecwiNU344bF/Xjsnq0Xa7NfFng5n5yUdJnk5HPINe6ibPhE1HaP0vNuwtjs03o2v+oDQh88iOXyUnX87pvJ2gJ5Mhub9g+SaDzFVdsJ/zo0YPc3Q5NyBeR97Cd4zfo+tw0A0JU/si7eIfj4NTVVx9x+Na8AF+sxP/U4in7xCeutSRKsLV79ROHuPRDTbi1Xa2PJZZHas1Ku97fvj6HG6Tq1uZDumktu3mfT25aR3fEmudgtoKiBg9FVhLG+LqbQNRl8VhpJKDO7yb0Ur3w81n0WJBciHa8g3VJMP7iRXt12vRhd+n6m8LZbWx2Nt1xdz8y6HeJBndq0isWYeqc2fgSJjbtEdV59zsXYYgCBKqLkMiVUfEPtC98u0tOpFydBfYG7WSWcALJ+lMwCUvN7NHjxOp+Dv/IrQB39FjtTiOH4EnpN/iWiyIseDNLz7JJldq7B1HIx3+J+LiqD50F7q33qAfEM1JSddgav/6EMCj0z1eurfvE8vyIyZ+L37jAvollyjC93tS59fQuZr9nP9Wnt4vWkG+yeDRYsWMXXqVO644w6aNTu6wNDBaEqwfxg0xQT/+4imcjQkszQkcnjsJkRBoFmJFYtR32O2BuLsDafJygqZvELvVh6SWV1Yam11lHAiQ30ixYrtEaxmA3uCMWqT+kiOqCoE5z5LZNWcRnO2ByMfqdUdJ+p3UXLiZbgGjS2OOKW2LKHhvb+gKXk8p/5GFy0VBNK7VtHw/tMo0QDO3iMoGXoFotlGdu8GGuY+Sz6wHXOL7nhOuRpzRXvUfIb48reJLX0TNZvE2mEg7sEXFRsBcqye+IrZJFbPRc0kMHircPQ8DXvXYYeyruINpLcvJ7NjJZldq1EzcQBEqwtTebuC4GlLDJ7mGEsqEO3uY5rhPhiapqIkQsiRWvKhfeQbdpMPbCdXuw01m9R/n70ES8teWNv2xtq27yHuInK0jsTa+STWzEOJ1iHaS3D2Go6z94jiKF+meh3RT6eS2fkloq0E96ALcR43HMFg1BkAHz1PrnYL5uZd8J7xO935JR0jPP95kmvnN9LFKVL5P3kJyerCd/aNWFv10l+PqhBb8CKRZW9hbt6FdmNvx1dWRnO3mbymURNOIefSbJg6ifiWL6g4YQwdhl9FTtUFydqV2rEYDaBBqduG12Ymo8jEMzIGSaRzhZsKp5m8qpGRFULJHAZBpKXPQiiRx+Mw4rWb8TksGCURj81IXTyLWRLYXp+gzGVG0wS8DhPlrmMf1WvCDwtVVbnlllvo1asXl19++bc+/2edYB+MKVOmMH/+fDweD+Xl5dTW1rJixQp+9atfcemll37PK23Cd8WKXWGemreZxVuCaBwQQBvY1lfs9P39k+2NzpEEuKh/Sz5aX0dtPPuN15cTIerfuEdXij79msPO02b3baJ+5kMoqSje039X3DRBr3CGPvgrmqrgOeVqHL3ORBAEsjWbaXj/afL1O7F2HIT31N9gcJUix4KEF/yb1IaFiPYSSoZcons2SgayNZuJLp5MevtyBJMN5/HDcfY+pyjykQ/tJbHqAxJr56OmIohWF7ZOg7F1OgFLi+6HKH6q2RTZfRvJ7t1IrnYLucCOouL3fohmO6Ldg2R1IphtiEaLHlAUFOU0RUbLZ1BzKdR0HCUZKW7Qxfvt8GIsbYO5oj2m5p0xN+9STF73Q1MVsvs2ktq4mNTGxSjJMKLFUaSL7/cCV5IR4l++R3zlO6jpGOaWPSgZcgmWlj10n+x1C4gsegUlVo+1fX88p1yN0dOs0RybwVOJ76w/FdXAk+sXEvrwWT3oOeXXxfcIOOCHLRn1jkTr4w55/5MbPiH47pOHKMx/GxzO4/rg5yRRQFE1TAaRC3pXMeWL3cXjD1dYasLPG00J9g+Dppjg/w9k8go1kTR2s4F0TsFpMSCKArKiC0SG0znimTwGUcRjN5POKZS7zCzbUc+rn+9kU20cNA2rWULNKyTyCoIAiirgsxtZ+97L7P1YV4ouPf92bBYH+yMJI5DNZWj44BlS6xdibdcP38gbigw3ORYk+O4TZHev1oXOzvwjkr0ENZcm8snLxFe8g+Tw4jntN7p+iKaSWPUBkUWvoqbj2LudTMmJl2Nwl6FmEsSWzyK2/G20bBJL6+Nx9T8fS+vjEQQBNZ8ltXExiVUfkN27HhCwtOqJrfOJ2DoMbNQZBj0Rzgd362KnNZv1wnjD7gPz1aALnto9SDa37i5ishaETwsCpKqCJufQcinUTBIlFUVJhnXV8AIEgwmjv5WewFd2xFLVBYO36pDCtRwPktr8OakNiwrrB0urXjh6nYmt4yAEyagLoG1bRmzpDLJ71yPa3Lj6j8Z5/EhEk4V8QzWRT17WfbIdPkpOvhJ715MBSK79SBeRyyZ1G83BFyEYjMixAA3vPUVm12qs7QfoYnSFhF9NhAjNnkRy91pKeo+k09lX43c5qPLZ6VLpQhUElq7azNy/jie8ZzN9L7yWFieeS10khyQKVLnNJPMKBkmidakNA1DptZPPK9RGsxiNIqd1qaBHixIWbwmwL5zB7zTjtBiRRIEKt4UKl5lUTsNpMZDIyritRqLpPA6zgWAiiyQIuGxG/A5zkx/2/xD+JxLsV199lfnz5zNy5Ei6d+9OixYtsNlszJw5k5dffpkZM2b8AKttwnfF/k72fgG0gy29REHg7J6V7AgmWVcTQyuIpl05qPUhifeRoGZTBN+eRHr78kZKoQdDSUUJvv0omV1fYe82DO8Zvy+KfMmx+sKX9SosbfrgO+uPGFylaIpMbNlbRD+dCoKAe/BFuPqep89n7dtE+OMXyFavx+CpxD3kEn3OWpTI1m4ltvRNUps+BUHA1mmITgNr3lWfFVZk0jtWkly/gPTWpWj5LKLFgaVNH6xt++h+10fwxlSzSfIN1ciRGuRYvS5wVkia1WwKLZ9FU/K6XLsggGhANJoRzHYkiwPRXoLB4UNyl2EsqcDgrTrivLeSipLZtYr09hWkty9HTUVBMmJt2wd7t2HY2vUr2o1la7YQ//Jdnfqm5LG264dr4FgsVV11AbgtS3Sf7PqdmMrbUTLsKqyteukCJ1+9T+STV1DzWV2JddA4RKO5sahMZSf8Z99QnPvW5DzhBf8ivmI2pspOlI66rZEVCtDYkqOqK6Wj72wkvPJtcHrXcuatr2ukD9C+zMGANvr7tD+hlgQY178lM1ZWk5dVJFFgbN8WjO5d1ZRc/w+hKcH+YdAUE/z/AUXVqImmizPaBkkgl9cIxDOIApS5LJQ5zIRTOTKyit9hJpGReeXz7XyxPUhtJI3VKJLJyzisFlr5rSSyKnI+TzOPnWU7G9i9ZB773nsaqaSC5mMmYPA0QwYMgIy+PyS+fJfQR88jOTyUnntrURRL747OJPzJy4imA6KboBfsG+Y8Q75+J5Y2ffCe9hudTt7I51or+FyPxeDwomaTxL98n/jyWSjJMEZ/S5y9z8be9eTi+Fc+tJfkuo9JbvgEObwPEDA374y1bV8sbXpjKm97SFwDesIsR2rJh/chR2oPEj2NoWWSqAcLnwKIEoLBhGiy6irjVheSw6sLoLrLMXiaYXCXHf53KTLZmi1kdqwkvX0ZudqtABj9rbB3PQl715OLc+VKJkFyzTziK99FjtQgucpw9T8fR8/TEY0W5GgdkU+nklz7EYLRjKv/aFz9zkc0WcgFthP68O9kq9djbtYZ71l/xFTaGk3TdCGzj54HTcVz6q9x9DwDgyCgANkdXxJ453G0fBr/mX/E3W0YPptI5+YltC+1k5FVvLla7vvzL0nHY1w6fhLG1n1AFUDUUGWNshIrubxCOJXDapJI53TbrYakTJsyO8e38NK50kWV1876fVH2hNIoqkK5y0rrUgdlTjNZWWVfJI1BFJBVjUq3hVhGJpWTsRbGIkSxyZ7rfw0/+wR769atXHPNNdx2222ccsopiKJe/dm0aRM333wzF154IZdddtkPsdwm/Ac4WABtyfaG4mw2gEHUBc/W7osiAKN7V3HrG6vYWp885uvrSqHPE18xG2u7fvjPufmQuWVNVYh+No3op1MweJtTeu4txdkoTVNJfPke4QX/BkHCM+yXhU6piBytI/TRP0lvWYLBU4nn5KuwdhgIQHrbF7rPdf1ODN4q3IMuLAiaGchHanUa2Jp5aNkkRn9LHD1Ox97t5CJlSs1nyOxYSWrLUtLbV6CmIvo98VZhqeqKuXlnTJUdMXqrjupp+Z9AUxXk0D6ytZvJ7ttEtno9+fqdAMXk39ZhANa2fYv3VUnHSW34hMTqueTqtiEYzdi7nYKr77kYfS2K9mTRz6aRD2zH4GlW8Mk+EUEQSe/8ivD85/VApVUvvKddc5CXaMF7PJfSbVEGXFDc8POhvQTfnkSubpteUDn5ykNogJqco+H9p0muX4C92zDKhv8Z9WvHHCv2078nL93dyON6+m8HFcX8Di4gvXa1/tk4WPCvCT8daJr2H88uNiXYPwyaYoL/f6CoGrlCEbIuplt8BWMZRFGg1GkhW3gOoNRp5qtdIZ77eCu7w0kCsTwmA7isIq38Tlr5HdgMIqFElmhWYW84QTiZJbR1LdunP4AGND9/PGKBFXUwsjWbddHMeJCSEy/HNWB0kWKdq99Fw3tP6raRHQfjPf13ReXr+Ip3iCx+DU3O4epzjj7mZHEgx+qJfjqFxJp5CJJBFzQr2FVpcp7khk+Ir3hb3zNNVuydT8Te43TMzTsXxTrz9TtIbV5CetsXxSRWMNsLMUEXTJUdMVW0P4Rp9r2/R6koudqtZGs2613zvRvQ8pmiToy1fX9sHQdh9Okq/pqmkt29lsSaD0lt+hRNzmFu3gVnn3OxdRqMIErkw/uILXmDxNqPQBBxHjdcF4m1l6AkI0QWv0pi1VxEi4OSk67E0fO0YhzWMOevZHZ+iblFd3wjrivacmqKTHzxa4SXvIHJ14L2427D4GuJKIDdJNHa56BzlYsVC+fz7l9vx+Zwcd8z/0Yobc2eYBpZU4in8jhtRswGAw6zRG08Vxg70DAZBbbXJUjJKm1KnVw2sCV92/jZG85glDTCKb1L3cJ7wNc6nVNI52RsZkNx/EFRteJnugk/HXwfMQH8DyTYwWCQPn36sGvXLgCWL1/O3LlzWb16Nc2bN+fhhx/GaPxugXQT/jvQ1cU/Ry5k2AJ6wqJqOsX2tasHctnzS0jnj2EQ+2s4oBRaRekFdxaVQg9GZvdqgrMfQ0nH8Ay9Ame/84obaj5SS2jO02R2rcZc1Q3fmX/E6Nc3j/T2FXpC2LBHn7c6+ZfFWaDUps+IfjaVfP1OvVLb7zwcPU5HNNtQcxmSGz4hseoDXbBEELG0Ph57l6FYOwwobpL7FTozu1aR3bOWbPX64vwTkgGjr4X+z9MMg7scyVWK5PDqVDCL47AV5/3QVEX3w05FUBJh5Hg9cqQOOVJDvmEP+eAe3coDEExWzM06Y2nZA0vLnpgqOxSvrWZTpLctI7lxEenty0GRMZa1wdHzDBzdT0E021HzWZLrPia2bCZyqBqDpxnuQeN0qy5RIhfYQWThS6S3L9cFToZdha3jYARB0O//h8+S2b4CU2Un3QqltFXh/mgkVn9I+KN/IEgmfCOuw9bhUDEjOR6k/q0HyNVs0QOmQRf+x1+eo45rxntrasgVKIz3jerBJQNaFp//uoJ+E346SKVSSJJEQ0MDTqcTp/M/V+hvSrB/GDTFBP9/Ip7JUxfNUBfP4LQYsBgMKKqKz2FGVnTLyaXb6pm+Yg/RZIaGRB6XWcLvtFDusiBrGqVOMwaDyJa9cbxOI5sDceLpHLnwPr7810SyoRrKTvstluOGH7IfqJkEDXOeIbXpUyyteuEbeT0Gpz7aVXQZWTwZ0WCi5ORf4uh1BoIgoiTDhBe+THLNPESrE/egcTiPH4FgMJIP1xD9fBrJdR8DAvauJ+PqP6rYic3t20R81RxSGxeh5bMYSiqxdxmKrfMJGEtbF9eoJMNkdq0is3sNmT3rkEPVxXVLzlKM/pYYvc0KuixlSA4fkr0E0erSqeFH2Ps0TUPLZ1HTMZRkGCXRgBwNFGKCveQbdqMkQoWjBYz+lphbdNfjgla9isw3TdPI1W4htXExyQ2LUOL1CCYb9m4n4+x1VrGJkd23idiymTqzT5Rw9joD14CxGFx+1FyG+PJZRJe+gSbncB4/AvcJlyJZHHoxY/ksIotfA0GkxalXQo/hB2K28D4aZj9GtmYzjp5nUHHabyj12YknFTQRbCL4XBa0NW8z9+VnKG/ThQlPvYDLX86+cIqsorC1Lo6sQd8WXhxWAyV2E6msgibAlto49bEUNZEMZpOB1n4bXZu5GXV8C+pjWeriGbx2E50qXBia6N4/C/wQMQH8DyTYAFdffTWpVIoNGzbQrVs3unXrRo8ePTjppJNwOp2k02ms1ibhgJ8yJi/dzV0z16AUaLUaFCm2F/VvyWtLd3/nax9JKfRgKKkoDXOeIb1lib6hjriuSDHeT0MKf/wv1FwGV79RuAdfhGiy6JTmVR8QWTwZNRXB2mEgJSdeVtw009u+ILbkTbJ71yOYbDh6no7z+BFFanMuuJvkuvkk13+CEguAKGFp0QNr+35Y2/bF4Gl2kG+2WugqbyEf2EGufhf5UDVKrL4gRtYYgtGMYDDrnW5BBE3V563yWb3qfBhIzlJdQM3fElNZW0wV7TH6qhol63K0TqeJb1tGeudXoOSRHF5snU/E0f2UonCcHA0Q/+p9Eqs+QE3HMJW3wzXgAmydhuiV69Beop9OIbl+IaLZhmvQhbj6nINgMKHJeWJfzCD6+XQQRUpOvEy33yqsQ0lGaPjgr/+vvfsOs6uqGj/+3afcfqf3kt4baRBCC70jWFCw4esPFXyxgAUVUMCKoIAKqGB5VYoNG4LSpAmhJCSkkZ5MJplML7eftn9/nMkkkw7M5CYz+/M8PobMnXvXmZnM3evstdcis3YhwREzKDv/mr4F0K6yjStp++t38ewsZeddQ2TC/Lfwk7NvVQVBOtJ23y71g59Q56mPBN3d3XzpS1/i9ddf57jjjsPzPKZNm8Z5551HXV3d235elWAPDrUmGL5s16MtkaUjlSMeMtGEwPIkSAibOksbO3h0SSMNHVmiIY0ptcVs78kytbKQnOdgWZLSqMnrDe3YHjT3ZAkYGkhoam3njfu/Q/fa14jPOJPSM65EGrtVPUlJ8o3H6Xzq5wjd9MvCdxnjaHdspf3fPyHXsIxAzURKz/x033uf1byBzv/8kuzmJegFFRQdfwnRaachNB2nu8WfAb3sCaSdIzTyKOKzzyM8rrchaC5Nes2LpFb8h2zDMpAeRlE14XHHEB57NKG6KX1HscCvHLO2r8Vq3oDVuhGnvRG7Y+ve3+N3HBEzAqDpfl8W6SId23/8LueudxCBCGZJLWZZPYHyUQQqxxGoGtevItCzs35X8Q2vkV77Cm6iFTSd8OjZRKecTHj8sWhmsG/0WGLxI+S2rtrZm2buhRixEqRjk1j6L7pf+j1eyl9PFS/4GGap/7s527iCjsfv8XvhjD2aUedeSUV1Ddu6HXJSklz6bzqfvg+h6ZSe/Rkik05AAwoDUBTWCAVNpG3z5h++z+bFzzDrlPP46Je+xcTaMtrTNgFN0NKTY8XWbmpKg8SDQaIBnQmVBWhCQwqPRMZh5bZuVjV3YQiNkWVRJlQUcMyYUlwJEUMjZXvU7tKoTzl8DdaaAIZIgp3L5Vi2bBmJRIK6ujqam5tpbW1l3bp1PPTQQ0yfPp3zzjuPiy++eBCiVgbCos2dXHqvX1ar6wJNCFzXT17eM7uOB95Bgg39O4UWnvghCue/f4/OmnKXX9AIjZLTPkF0+uk77xynu+n8zy9JLX8KPVZK8SkfJzL5JL85SS7tNy955S9IK01k4vEUHvcBAhW73K197W/+3VrPJTRyhj8ncvyxaGbIv+vbtIb0mhdJr3257660XlBOaMQMQvVTCdZMxiit3TNu18bpacNNtOImO3HT3XjZJNLK4Dk5cB2k9PzP041+5630SCFatBgjXoZRUN7vjXvH18TpaiK3dRW5LSvINizD6WoCwCisJDz+WCIT5vd1FpeOTWb9qyTfeJzMhkUgBOFxx1Aw9119o0fs9ka6F/6R1Ir/IHST+JzzKZj3PvRwvO+mROdT9+F0NRGZcBzFp32i33nq1Or/0vH43Xi5FMUnfZT40Rft/Xv5+j/peOpejMIKyt99fd/O91shYK9z2I8ZVcxrmzv7bgJdc+ZE/veUcXt5pHI4+exnP4tlWdx0002sXLmShoYGVqxYQSAQ4JprrqGkZO+9Dg5EJdiDQ60Jhq+c47K1M0MkoJO2XIojftMoXdMAycqtPaxrTvLyxhZMXWN8ZZxgQMOVkMrYdKVsENCdtulI+w2rcrZL2naJBHRsJ8crf76PDU8+SEH9JCZd+jWcUAldOdg1zbQ7ttL2yA+wmtYQmXQiJWdc0ddAS0pJasXT/s33TILYUWdSdOJH+j6e2fg6Xc/9Bmv7WoyiagrnX0x06ikI3cTNJEgu/ReJxY/iJlr90Z3TzyA6/fS+xptuqpP02oWk1ywk27AUXAdhBAnWTiY4Yhqh2ikEqsf39Y/ZQUqJl+72+7Ik23fpy5Lyb7A7FrI3mRaagTBMhBnym6SG4+jRIvRYKUZhBVoovucOfy5FbptfJp7dsozctjf92MwgoZEziUyYT3jcvJ2jUNsaSC17kuTyp/DS3RjF1cRnn7+zqs/OkXzjcXpe/jNuoo1g/TSKTrqMUJ0/wtJJtNH1zK9JrXwGPV5OyemfIDx+PqVhQdgQtLS20vjIj8luWERo5AxKz716jx4shSZ4XY1s+MN3yLQ2cvyln+HyK66ipiRMNBSgMGRiu5LmRIaWLr96QhMaU2viBEyDyrjf7GxseYywqbOmOcEbjR2EDYPptYVUFkXoSFmETZ2M7VJbHCZoqAT7cDdYawIYIgn2rpYvX85DDz1ES0sL1dXVHH/88ZSVlXHmmWfS0tLSdx5LObzsOiN7R2Oo2qIwx47xR1TsWkL+dnl2lo5//YTUymf26BS6K7trO+2P3kFuy/LeJmf/i1FQ0ffxbOMqOp/8KVbzegI1Eyk+5fK+NwI3k6Dn1b+SWPR3pJUhPGYuBfPeS7B+GkII3GQnyTceJ/nG4zjdzYhAmMiE43bOlO49V213bffHcGxaQnbLcrxMD+DfTQ5UjiFQPqq3FKwOo7gKPVa635LwA+k3lqO90S8Tb92I1bwBL5sE/A7lwfqphEYeRXj07L4uotJzyTX2zu5+87942UTfYiE286y+r11u22p6XnmY9OoXEUaA2MyzKZz3PvSYv/Pr3/X/BdnNSzFK6ig5/VOER8/qi9FNddLxxM/8JmeVYyk97+q+buX9vs9WhvZ//6SvI2zZ+V9A6y271wWcNrmSrrTVlyDvy66jth5e3MjizZ10pi0umlnLGVOr9jhnrXawD3/XXnstc+fO7Uuscrkcq1at4q677sJxHO666y4ikbc+Y14l2INDrQmGrx0JdjSgk7JcymIBCsL+DWDXkyzZ0oFluXRmbAK6RnksREc6S8g0WLm1i6aeDGu3JyiMmMQCJvGwwertCQwN0pYHUjKiPMqmV5/hrz/+OoYZYO5lXyNXOQXblSTS4OEn267n0rPwT3T990G0UJSS068gMumEvsTTyybpeuF+Eov/iQiEKZz/AQrmnO9XY0lJZt0rdP/3Aazm9ejxcgqOvshv7hWM+F21171CYum/yG58HaRHsHYykcknEZ14Qt/7o2dlyTa8QXbT62QblvX1RPHHdtZjVowmUD7SH9tZXINRWPW2xnbuysulcbp3jOragt26GatlQ2/TNUBoBCpG+5sAo2f50096b9I7PS1+mfjKZ7Ga1/s72uOOIX7U2YRGz/JL6jMJkkseo+e1v+OluwjWTaHw+A/6ayEh8KwMPa88TM8rD4Pn+Y3Pjr0YMxAirIHjSjreeJy2p38BnkvJyR+jcs55SDRswN7lWnadMHL8//s6M+adxJSqOOfPqqO2KELOcdnYnqatJ83KpgQ528XQNI4fX4qu6VQWhEhZLuW9HcL7vkaeRNMEnifpSOXI2B4FIYPCyFubmazkx2CtCWCIJdgvvPACl156KZdffjkf+9jHGDly547Vueeey80338zcuWoNdDjaW2OoXROWB15u4Ot/W47ryb3uKB4sKSWJxY/Q+fQv9ugU2v9xHonF/6Tr2f8DISg68SPEZ5/Xb7RFavlTdD3/O9xkB+EJ8yk68SMEyvxzuG42SWLxIyQW/QMv3U2gcizxOe8iOvnE3jfd3uYfK54mvfpFpJVGC8UIjz2a8Lh5hEfN7EsKpZQ4HY3ktr5Jbvs6rOZ12G0NSCuzM2BNR4+WoMf8s1ZaKIYWCO9SIu7X3UvX9u9e21m8bBI3veO8VQd4O0d7CCOIWT6CQMVYAlXjCNZOwiwb0bdT7FkZspvfILPuZdLrXsFLdyHMIOFxxxKbeor/BqrpSNchveYlEov+4ZfJB6PEZ51LwdwL+0aO2F3b6X7hflIrnkELxSg8/lL/3FrvzQa/RP9JOv/zCzw7S9Fxl/pNzvbS5M1q3UTrX7+H07mNohM+1G+mKcBRdYX87aoTuO4vy/Z67GBceZSGzkxf9cT+Emd1zvrI89RTT/E///M/fOpTn+KjH/0o9fX1fR87/fTTueOOO5g2bdpbfl6VYA8OtSYY3jpTFt0Zm3BApzwW7NdpOZGxaOrO0ZzIENI1so5HZ8amOGywYms3DR1JtnRkKI2GGF0exdTgzaYEjufSk3UIBwym1hYigE3r13L/d79A65b1TDjjw1QvuJSM49GTgcwuCw6rdRPtj96JtX0t4XHHUHLGFf1uvlttDf5N4g2L/NLwEz7U12tESkl2wyK6F/6RXOMKRDDae2TsvL4da6enjdTK/5Ba8Qx222ZAEKybTHjcsYTHzsUsrd9ZUZdJYG17s3c813p/bGdPa7+vnxaKo8eK0SJF/ujOQAQt0Du6c8eNJc9DOhaenUXm0riZHn/3O9mBzPVvLGsUVWFWjCZYOY5A9QSCNRP7SsWllNgtG8msf5X0upexmtYAEKgaR3TKKUSnLOh7z7faGkgsfoTU8qf8MvnRsyk89uK+jQhcm8zSx2l/8UHcVBfRiSdQfvLHEEVV7DgMZ3dspfPfd5FpeIPQiOmUnv1ZKquqcSRkLf8GeZb+E0ZC1ROpu+hapk8Zw7S6YqbVFHLR7Lq+Lt9bOzO0JTJsbU9THDOJhExOHFdBKGDQnbGJBHXKokHV8XsIGaw1AQyxBPvrX/86o0aN4uMf/zgAlmWxevVq7r33XlpaWvj5z39OQcHbG8ujDL4DJSyLNnfy58WN/OG1LTjuO/v57Ncp9KSPUHDMe/YoMwb/vHH7v+8mu3ERgapxlJz5aYLVE/o+7llZel79Cz2vPIy0c0SnLKDwuEv6zljvaPCVeO3v2O0NaKE40emnEZtxZl8yLh2LzMbFpFf/l8z6V/0dY6ERrJlIaORRBOun+W9ku5SBSSlxE63YHf4oDqenBTfRjpvqxMv04GWTeFYW6eSQrtM7pguEbvol4mZo50iOXcrBjMJKjJJafyzHLl8Pz85hNa0hu2U52YY3yDWuAs9BBMKEx8wlMvF4v6N47x1zu7OJ5LInSL3xBG6qE6OoivjsC/ru2oN/h7v7xT+QXPYEQtOJz3kXhce+r+/GAvgLmo4nfkpuy3KCdVMoPeszfU3mduWX9/+LzqfuRQQjlF/wZUIj9+wQe8VJY0jkHB56pYG9/QgdM6qYa8+ZzMOLG5HAe9U4rSHn9ddf549//CNtbW3U1tayYMECdF3n4osvZvv27W/rOVWCPTjUmkDZH9v1yNoum9pTeK6kIGyyuT3FpvYU65p7SFouOoJR5RGOHVPOqm3dLNvaRSJt4UqoLDBp6bGwpSSIwwv/dysrnnuEsrEzmHLptbTLQuze+8475mb3a7KFoPD4D1Iw9139bvhmNi2h69lfY21fh1FSR9Hxl/o73r036PsafK15ETyP0OhZ/qzoccf0TcCwWjeTXv1f0mtfwm7ZCPg9UkKjjiI0YjrBuqkYhZX9yre9XBq7oxGnswmnuxmnd0SXl+7ZZXRn1r/R3jumS2g6Qjf6l4hHCv2GqQVlGIVVmMXVGMW1/XbEpZQ4ndvIbllBbssyspuW+POzgUD1eCLj5/sN2noby3p2jvSaF/0Z31uWg24SnbKAgrkXEqgY7T+n65Bc/jTJl36P1d1MyZgZlC34KKHqSWRcsADsnF9N8PIfEUawr8mcJjQiABo4HkQMSLZsZdsj3yfdtJ6CuRdSferHiIZNxlYWMGdkMXNHlTK9togtnRkCuqCpO8PSzd1s6U4T1KEgEuSc6TVMq/VL/i3HJaDrSPw9C3XGemgYjDUBDLEE+8Ybb2TlypX8+Mc/ZsWKFWzdupXly5cjhODqq6+msrJyAKNV8uW6vyzrNyLp7dpfp9BdSSlJv/k8nU/fh5vsJDbjDIpO+mjf3Vjwz2f3vPxnEov/iXRtIpNOpHD+xX0lzFJKspuXklzyGOm1C8FzCVRPIDr1FKKTTugb0yU9l9zWVWR6y8Ot7ev8BmZC8xuPVY3vLREfiVlajxYpGpBxArtyMz3YbQ3YbQ1+45Tta7FaN/U2QBGYFaMJj5pJaPRsQvVT+xYEbiZBevV/Sa18xn8DFRrhMXOIzTqX8OjZ/cZq9bz8Z5LLnwYgdtRZFM5/P0a8tN/3puu/D5JY9A+0YJSikz9GbMYZe70J4mZ6/AZ1a14iNGoWZedf0/f13JUGGIaG7Xj7/NkRwLffPb1vLruh5lYPCel0mkWLFmFZFrW1tbS0tNDW1saaNWt44IEHmDdvHueffz4XXnjh23p+lWAPDrUmUA5GZ8qiM20hpKSpJ4upwX/ebPFLmHWoK40yqjTCwnUdrG9JkHYcyiIB3tzWgyclXVkLoUFFNEzzkid44Te3ITWdqrM+Q/GU4wnq0Gb1f02nu5mOJ35KZv2rmKUjKD79kxSNmomDX1YupSS95kW6X7gfu62hd2znxUQnL+hLxp1EG8ml/yb5xhO4iTa0cAHRyScRnXIygZqJfe/tTk8LmQ2LyG58nWzDG33HtvRoMYHq8QQqxxIoH41ZNgKjqGrAR3hKx8bu3OavCVo3YW1fh7V9bd/RNS1SSGjEDMKjZxMaMwcj5p9bldIjt2U5qRXPkFr9X2QuhVFUReyos4nNOKPvrLpn50gtf4rul/+M291MtGY848/+GONmHU97OoeuaTR3ZmlftZCWp+/D6W4mPnkB5adejhsrJgBEAxA0BaYOrd0e3SueoOXxn6ObJtM/8GUqph1HPKhREAlwxvQaJlcWIgQ4LmgCtvdkaEvlaE9YZC0Hz/MoLwgzY0QxFfEQricRSNpSFkUhE13XqIgHiYXURIIj0WCvCWCIJdiWZfG5z32ORYsWMW+eP65n8uTJnHfeef1Kw5Qj245y8py970TpYO1sbHYvQg9Qctb/Ep10wl4f6+XSdP33ARKL/oEwQ73nrC5A7NJ91E110fPKwySWPIa0MoTGzKHg6Hf3nSna8ZjUiqdJLn/aP0clNIL104hMmE9k/Lx+JWdeLu03GNu6ity21VjN6/ve1MA/k20UVWEUlKPHy/zd6HABWiiKCIQRemBniTj+HWLp5JBWFi+X8svBUl04yXbcnlacru19b97gn7sOVI31y8FqJxOsndLv3Lqb7CS97mXSa14iu3kJeC5GSR2xaacSnXoqRkFZ39c5t3UViVf/SnrNS6AbvWM53tevGUlfV/bnf7ezacxJH+17I95dZsMi2h+7EzfdQ9FJH6XgmD0bnr1VJ44v47/r2vrOZwsgaKpz1keyd7/73cTjcVatWsXkyZOZOnUqM2bM4NRTTyUYDL7jrtIqwR4cak2gHAzPk3SlLVZv72FbV4ayghC4Hh6C6qIwhi5Y19pDY3uanC1p7ElRYBr8fWkjnUkPR0IkAFUxDcMM0Na4mVW/v4Xk1jWUzTqdme+9im0pg6znvx84u7x2et3LdD51L07XdqLj5lF4ysf7KtjATzLTb/6X7pd+74/tjJdTMPcCYjPO3HkMzHPJblxMctlTpNe97E/mKKggMvE4IuOP7WsiuuP57NZN5BpXkdv2JrmmtTgdW+lrxym03mq0CvR4OXqsGD1ciBaOowUi/nSRvZWIOzlkLoWbSeJluv2eLIk2nO6W/pNKhIZZWk+gejzBmkkE66b0K12Xrk12ywoya18iveYl3GQHwgwRmTCf6PTTCY2Y3vceHcx207zoURKL/4mX7iJQPZHC4y+hauJcwgFBTVGIVE7S3riWdf/4GT0blxIsG0HF6VcQHzUD04Cs7fdWiYRBR4DTw+o//ZiuVS9SMPoo5n/0a9SNqqemIExVYZjJNYUkLZuIaWLogpZEjoAuSOc8SqIGjoQlmzsxdagtjjKuqoDKeAgBmIZgfUuSkWVRdCHQNEF1oZpGcCQa7DUBDLEEGyCTyaBpGh0dHUQiEQoL974wV45s33t0FT99bkO/v9uxj6tr4i03RPM7hd6G1bSWyOQFfqfQvTRAA7Dbt9D59C/IbHgNvbCSohM/THTKgn6JnZtJkHj9nyQWPYKX7sIsG0F89vlEp5zcb7yF1bqJ9KrnSa3+b1/ncLN8lH8neNRMgrVT9ijLcpMd2G2bsdu34HRu8+dX97T6pWDZxFu6bvCTaD1W2lsOVolZXOOP6iofiR4v77dDLh2L3NY3yWxeQnbjYn+HHb+jeGTi8UQmn0SgcuzO5i92lvSq50m8/k+s7evQQjFis86lYM4F/XaZpZRk1i6k67nf7JwrfurlBKv23pnby6XpfOaXJJf8C7N0BGUXfLFvvuY7oWvwzQv9Hexdb+CoTuFHrjVr1nD++eezZo1/JvDll1/m3//+N8uXL2f8+PF8+9vffsevoRLswaHWBMrBSmZtXtvUQUtPjmQ2S21JjImVBQRMnUhAZ1tnltXN3WxoSbC5I0k8aPLkim20pf3PDwgQEiIhMHQNV7p0Pv8Q6596kGhxObMu/RJtsUlI/FQ2t8trS8fCef2vNL3wR39u88yzKTzukj3f4za8Rs/Lfya3ZTnCDBGddirxWef2a9bp5VKk1ywk/ebzZDYvAddBC8UIjZpFePQsQiOPwijsX3nh2Vnsti3Y7Q3YHVv9EvG+Y2Nd/fqrHBTdQI+WYMTL0AvLMYtqMEpqCJSN8JPpXaaN9JWKb15KZtPrZDctQVoZhBEkNGY20YknEB43r986Jte0hvTiR0m8+SzSsQmNmUPhMe8lOMKfNBIEQgJMq4WtzzzI9sVPoodjVJz4QarnnoOtGWRtCBlgGBAN6JiGQefKF1n18B1Y6STVJ3+YwnnvpiSscfTYSk4YX0FtcZjOtEVnyiZruYQCGmWxIDnbZWt3lqm1BRSGTFY3JagoCJJzXCZUFVJZEGJ7TxYpJR1pi+JwEJCUxYKqmdkR6FCsCWAIJti7k1IOeAmtkl+LNnfyhT8sYVN7ut/fnzS+jPqSCK2JHE+tat7rGdv9ka5D98I/0v3iQ2jhOKVnfprIhOP2+fjMxtfpfOZX2C0bMMtHUXTih/1Zlv0SUpvUqmdJLPoHVvN6/0118klEp5/u35Xe5bF2+xbS614hs2ERucaV/puiphOoHOvvHldPIFA1zi8B20fHcOk6eBl/HIdnZfzdas/1z2CDf9bKCKCZYUQo6jc+0fde4iSlh9Pd4peDNa3pbbK2Blyn94z4JMJj5hAedwxm+ahd5nVLrOb1pJY9QWrFM3i5FGZpPfE5FxCdeuoeNwyym5bQ9fzvsJpWY5TUUbzgMsLjj93nv9vMpiW0P/Yj3J5WCo6+iKKTPrLHeLG364PzRvCdd0/vO+//p0WNB9XwTDl8LV++nM9+9rPcdNNNnHDCzs6/y5Yt44tf/CJXXXUVF1xwwTt6DZVgDw61JlAOhu16dCZzPLe+Fc2TrNqepChicFR9MeWxIKahk7Fs1jQnaU9k8ZCkLYcnVm5nW0cGy/KbmRlAQQgKggaWFBTHAiQ3r2Lx775LT8sWppzybo5675V052B1i9VvJzsG9KQ66X7hAXqW/hthmBTMeRfxY96zx8363PZ1JBb9g9Sq58C1CdZMIjrjTKKTTug/XzqXJrNxMZn1r5LduLjvjLMeLydYN5lgzUT/2FjF6D1Gde0gpfR3prNJpJXGs3Lg2ki54wy25vdmMUNowYg/lisQ3ue/ES+XwmreQK5pLVbTanKNK/vFFR4z25/XPWommrnzvd7NJEivepbkG09gNa9HM0PUzT2dsQveTYOs7Pe1DGU7aXvpD3QsegyEoOzoCyg//v04RgwNiAbB9QABQQ08K0Xjv35G0+KnKa4bx+wPfRm7qL7337rOxMoYJ0+uZkxpjKZEhsJQgJzj0pVxKIsF6MrYZCyHgK4RDwUwNaguDhMNmJQX+F3DbdfDcSW6JrBcD10IwgF1BvtIdCjWBDCEE+zFixezcOFCrrjiCjWGYwg5UHm4ofvjEoD9jmDaH6t5A22P3oHdssGfw3zGFX1ninYnpUd61fN0vfA7nM4mApVjKTzuA70JorbL4/w514kl/yL95vNIO4tRVE10ygIik07s7dC9S7MSK0uucQXZLcvJNa7E2r4W6fiHwIQZ9MdylNT1NiSrxCgo83eho0WIQOSgFpBSSqSVwU13+eVgPa1+Y5TOpt5RXbt0KtcNApVjCdVN9Ud11U9DC0b7PZ/dsbVvLIfd3gC6SWTCccRnnr2zO+gur53d9DrdLz5ErnElerycwuMvITb99H3ePHAzPf4c8mVPYpTUUnrO5wjVTTngdR6sgKHx4Cf6J9GqU/jQ8Itf/IJXX32Vs846i2nTplFXV0c4HOZHP/oRy5cv5+c///k7en6VYA+OgVwTpNNp/vCHPzB9+nTmzJkzIM+p5J/rSbZ2pXFdyRuNnSSzHl2pLKMrYmRsl3EVcYqjJuu2J8jaHjnXpTLu70guaWjnjYZOujIWriMoiZu0p2xiIYPCkI7tef7xKsvizUd/yZtP/YlwYRm153wKOWJevzFQu7I7ttL1wv2kVz2PCISIz7nAn5yx21EnN91NavlTJJY+jtPRiDCChMcdQ3TKAr9vyW47xXbbZrINy8htWUFu6yrcZHvfx42iKszSeoySWsyiKvSCCn8HOlqMFo4f9Lls6do7p4sk2nF6Wvw1QcdW7PaGfp3K9cJKgrWTCNVPIzRiBkZxzR7rmMz6V0mtepbMhtfAdQhVjqF09pmMO/ZMpoyuorUzw4rtGdIeOIl2kq88TGLpv/Acm9KjTqf4uEsIFlWgmyAtkBoETLBt0AxJ17Jn2fqve/FyKaafexmnfeCT2J5kdWsK27YxAwaV8QhTqws5bnwZbYkMbSmbeMhkYkWMoKnx/JoWUrbEcyXjKqNMrCqkPWUxuixKdWFYdQ0fggZ7TQBDOMH++c9/zqc+9Snuu+8+/t//+38D8pxK/u06L1vAOz6DvS/Sdeh55WG6/vsgQjcpXvBRYjPP2ffOseeSWv403S/9AaerCbN0BAXz3uOXju+2Q+xZGb8Z2Ir/kG1YBtLDKKkjMv5YwuOOIVgzcY/Xka7T23RsHVbrZr8JWXsjbqJtz6+CpqMFowgzhDB2OYMtJXgunmMhrQxeLr3X0jE9VoJRUueXg5WP6m2gMqrfWXPwby5Y29eRWfeq3+m0dy5nsG4K0amnEJl0IvouHcF3fJ3Sq1+k5+U/9c4ELaPw2PcRm3HWHs+/83UkqRVP0/n0L/CySQrmvZei4y8dsF1r8H+WLu3dvVaGpl//+te88MILlJWVUVxcTFtbG8899xzf+c53OO20097Rc6sEe3AM5Jogk8lQWlrKuHHjeOONNwbkOZX8y9oum9tTRAI66ZxLSyJLY2ea0liARNZhTFkc0xBsbk2gC8HGjgzxoIHteRTHAmzvyNCVtljVnEAgae7OMqYsjNA12hNZtnVZ5GyboGkgW9bx+oO3kty+ifiE4yg47ZN9fUb2xm7dRNd/HyK9+r8IM0BsxpkUHH3RHiXeUkqsbav9sZ1vvoCX6emb0BEedwzhMXPQw3t2u3cS7f6aoHlD75pgC05nE9LJ7fFYEQj7ozvN3cd0uX5XcTvnV73tOv5zx+eaIYziGsyyegJlIwlUjCFQNa5fo9cd3FQnmfWvkV73MtmNryOdHHqshNikEymdcRp6+RhMHaqLTMpiYVI5m83r17Ll+YfpXv40SI/qWadRccLFBItr6Mn5x79DJkRCGrGgjtAErVs2s+nRe+hYs5iiEZO46DPfYO6sWcTDOhtbk2zvzrGuNU1dgYnlSuLhIOUFQabVFFFdaNKRshlR6t+EeWlDO4YuiAUMQHDyxAoCpkZdUUQl10PYYK4JYAgn2LlcjqKiIkaOHMmqVatUSdgQsWhzJ5f+/CVsV2LqgnOnV/PXJdv6Pr4jjxwodsdWOh6/h+zmJQQqx1JyxhUEayfv8/HSc0m/+TzdC//kNzSJFhObdS7xmWfvtbO1m+wkveZF0mteJLtlOXiuf95qxAz/DHb9tH4NRPZ4Pcfyd553nMFOd+H2lohLK4N0bKS3Y0yX8EdyGIHdRnIU+Ul173mrXcu6+r1W31iO5WQ3v0F28xK8dLdfLl47iciE44hMPL5f07K+68wmSS59nMTiR3B7WjBKaik45j3Epp66z8QawGrZQMcTPyPXuIJAzURKz7qqb6THQNGEv3utSsCHnr///e8899xzGIbBlVdeiZSSV155hY6ODlpaWpg9ezbnn3/+O34dlWAPjoEuEf/MZz7DT37yE5555hkWLFgwYM+r5E9HMsebzQkcxyMWNphYEWdrZ5rNHRmE8IgFTEIBnYa2JJs6MjR3ZSmJmhRFDKqLowghKAnrbGlPs2xbNy2JHNWFQTa3pYkGTZp6smRyOUzNAE2wpbWL5hf/wtZnHwShUXjcJRTMvXCP97EQ/uQKB8i0baHz5T+RWvkMSElkwnzicy4gWDd1j/d26TpkNy/11wVrX8ZLd/nd0KvGExo10x/RVTOp31Grfp8vpb/73NPaN7rTTXf3ju7sPTbm2rDLmC50E80M+kl4KOaP6YoWo8dLMQoq0CKF+y0ZzzWuItvwBtnNS7Ga1wOgx8uIjD+WyMTjiNZNJWzq4PWeX5dgSo9Y53Ia/vt3WlYuRBgBymedzjHnfQQvVsbWzhwpB8IGhE0oLwoR1k26kwm6Fv6ZRf/8HboR4PgPfJoTzn8/k+vKiAR0ujMOOcshaAieX9dOdUGQRQ0dVMZNNN2kMGRy1IgiKgpDFIYCdKUs0pZLRyqHYeiMKQszc0Qp0aCBqavq16HmUK0JYAgn2ABf+MIX+OEPf8i///1vzjzzzAF7XiV/Fm3u5NJ7F2I7/rnYBz9xLE+s2M7Pnttw0LvZU6rjrGw6+GZgUkrSq56j8z+/wE12EJ16CkUnXbbfO9dSSrIbF9Pz2t/IblwMmkFk4vHEjzqrr5HH7rxskszG18lsXOTPlEy0AaCF4jvPWlWOwawYvcec6oG2Y8621bIJq3k91vZ15Lat9t/sAS1aRHjkTEJj5hAePXuvXb533JVPLP0X6VXPI50cwfppFMy9kPC4Y/ZZDQB+6VzX878jufTfaKEYRQsu2+eYrnfiopk1jK+MqxLwIeiuu+7igQce4Oqrr+axxx7jhRde4Oijj+aLX/wiM2fOHNDXUgn24BjoNUFDQwMjR47kvPPO45FHHhmw51XyZ3N7CkMT2K6H5UgmVMXZ0pHGcT3aElksV7KqqZvFDR10p3OMLS0gEjYoiQTQdYGpaUysihE2DV7d1ElLT4aKwiDrmhMEdJ3mngwJy6Yn5WLqgrTtkLMdtjY0svXxe0mvexmjuIbiU/6f/74mBDoQ0vxzwkETXAkBDTra2uhc/A+SS/7l9yYpH0V85tl+89Pdqr2gt0qsaS2ZDa+R2bgYq2lt39jOQMVoAr19WQIVYzDL6vd5c3ygeLm0X0nXssFfEzStwW7dDEjQDIK1kwiPnk14zBzMijF96xwTiIfAykF3qrcsfsm/cDq3YUaLKJh5DpXHnIcZKyJgaERMDVt6tPd4hIP4o9aKomibXuLx395JT1szJ5/7bj5+zfWMGVlHNudRFA1gOy5tKYuWrjQ5V4KQWI7H0sYuAho4Dowqi1EcCzCjvgjPg5zrUVMYZltXmop4iCm1fgzK0HMo1wQwxBPspqYmampqOPPMM/n3v/89YM+r5M+uJeI7OjsD3Pbv1QedYOsaINmjCZre+4a4L56VofulP9Dz6l8RQqPg6IsomPfefk1J9sZubyTx+j9JLn+6bw5kdOqpRKeegllcvdfPkVLidDWR27Lcby627U3s9sa+URnCCGKU1GAWVWMUVaHHS/0z2JHecRzBGFrALxFHNxBC85uauC7StfB6S8S9TA9u2h/J4faO5HC6/PNWO0vFBEZJLcEaf1RXqG4aRmndPu9oOz2tpFY+S2r5U9jtW/zGblMWEJ99HoGK/Xf69uwciUX/oPulPyDtLPHZ51F4/AfRw3E0YHqdn8i/0djd7/sdCeikLXe/z60L0DSB40qEgE+eOIavnLvvagTlyHbeeedxxRVX9DUryWaz3Hnnnfz85z/n6quv5qqrrhqw11IJ9uAYjCZnF110EX/7299Yv349Y8a888kDSn61JbL0ZP2jTtGgQWVBiLZEluZEjq0dKbb3ZHllQyuO49KT86gvDRM1AxxVV0jCdqmMBQkEdDSguSfL+uYEEkFVYZgxFRHipsnK7V1s68zSnbPY2pGhM5Vlc0cGJwft6xfR9tS92B2NBEfMYNSZHydQMQ4zAHETbKnheZJo2KCzx6bDAtfKklr5DMklj/nNT40A4fHziU09hdComfs8L71jbGe2cSXWtlXkmtYhrZ3NXvWCCswdfVkKK/r6smiRQrRQzC8RN4L+bnvvzWpTeliOjbSzeFba3+nuPYPtJNr88Z3d27E7tvXd9Af/xv+uo7qCtZP2Xf3m2LgbX6N7+dMk170KnkOwdjKxWedSPPEEMEziYTAEFEeCdKYtNMBDUhwJkmtczvK/3Uv7ppWMmTSNr974Pc48/WRcT1JVGKYtmcVyJN1pi+ZEFkMTFIZNWhIZVjclcF2XRM7BNHQWjK/EkpLqwiDRoElpJIDtSQKGRlzNtB7SDuWaAIZ4gg1w8cUX86c//YlVq1YxadKkAX1uZfDt3mRqR5OzHTvYXz9/Kiu2dfPQqw37TY7DhkbG2c8Devm/1PfP6W6m85lfk37zebRwAYXzP0B81jkHPBPs2Tn/7PXyJ8luXgZIAtXjiUw8gciE+ZjFNfv/fCuL3boJq3UTdvsW7I5GnK7tON0t4O6r3cpboBsYBeUYhVX+mK7SeszykQTKRx/wJoLT0+KPFln9gt8BHQjWTiY67TSik0864OdL1yG57Em6//sgbrKd8NijKT7545hl9Zi9jet2fL+Xb+vmT4sasR2/0Z3Ab25n76dt/BlTKrliwVgA1bRsmPjpT3/KihUruOGGG6io2Dlbfu3atXzzm9/k1ltvpbKycj/PcPBUgj04BmNN8Oqrr3LMMcfw6U9/mrvuumtAn1sZXJ4n6UjlyFgu8bBJUSSA50lSlp9gR0ydrOPheZKM5bC8yR/L9cSKZgKapCWZY3JVASdMqGR8RYy2pIWmCRraM5g6hAI6zT05khmL0RUxqgrC5CybhOORzdl0Z2zWbE9ie37X8paeHN2pLNJzaFv8LzY//QBuuof4pBMZc8aHCVeOQMMjqGv05FySGb9E2gIC+KO+ctvXkXzjCdKrnsXLJtHCBf786wnHExoxjaBuYu3j6yGlh9O1HatlY19fFqfLH93pZXoG5GuuRYowinaM76zHLBtBoGI0ekHFfo9eSscis2kJ6dUvkl77EjKXQosWEZ1yMrHpZxAoH4mJ3wVcaIDwb4LXFoexbBvQcNo3seaxX7Nt2YsUlFbwgU99gcs+9hFGlxaCJgibOp0ZC+lJXE9SHDFpS9oURUy60zlWNycpiwdYta2HoKFTXxqhNBKgtjhCcTSgjo4OM4dyTQDDIMFesmQJs2bN4vLLL+fee+8d0OdWBteOZNpyPDQhuPnCaUysivPw4kYkMK2mkJsfWYHleAghcPfTNnwwGqLlmtbQ9ez/kd28FD1WSuH8i4nNOPOgmm85Pa2kVj1H+s3n+2ZJm6UjCI87mtDo2YRqp+z3bPKupJT+TnSywz9rlenBy6WRdhbpWEjXgd5UVOiGP5IjEEILRNDCBf55q1hx7zmrgyuNkq5Drmk12Q2LyWx4re/clVk2ksjkE4lOXrDP3fndnye14j90v/R7nK7tBGomUrzgY4RG7Gw2ZuiCD8ytZ+qu329gdHmM9S3JfX5fhYCxZVE+fsIYPjhvxEFdlzJ0NDc3c8MNN5DL5bjssss4/vjj8TyPRCLB3LlzWb16NeHw3kfbvFUqwR4cg7EmAJg1axZLliwhk8kQCg1uWa0ycLrTFq0Ji56sRcpyqC0KowsNTROUx4MkszbdGQchIGRqZB2Pddu7+dfy7WzryiKly5iyOOcdVYvluKxpSdCTtklbHtWFAdK2ZHR5hOYui3hYJ5VzaenOMroiQihg9pYOSza0pGhPZslZNmkbymMGXVmHjVtaWPnkg6z9z5/wbIvq2adRt+ASAmV1dKdshIvfDVtA3NTYlvRwAR1wHJvMhtdIrXqOzPpXkXYWEYwSGT2b0Ji5REfNRMRL+309jN7P3bOlmT8b20124Ka6/DXBjjPYdu8Z7N5qOISGMEyEEfTXBKEYeqQALVqMESt5S81Ene4Wv5R9w2ukNi3ZeQ3j5xGdvMDfnd/laFhxCCKGwJGCyngIgUvOlegdjax6/LdsfO0/hKJxzrzkcha8+yNUlhZRUximuijCpOoCWnqyWI6HFJKWnhx1RWH/qIAnSecctnamEULwRkMHRdEgReEAM0cWM7Y8ppLrYehQrglgGCTYAPPmzeOVV14hmUwSjUYP/AnKYWHXcnAAXRPoApzecp73zK7joVca8KTfqEoTAuftzuZ6B7Kb36Dr+d+R27oSPVZCwdHvJnbUWQfctd3B6W4mvWYhmfUvk93iz78WRpBg7SSC9dN6Z2CP32Ms1qHkWRn/zNXWVWS3rCC3daVfQt7b4Cw87hgi4+djltQe3PPZOVLLn6L75T/jdjf7481O+BDhsUfv9Y3vQ/NGUFMU7v/zIEDXNVzXQ9P2voOtCbj0mBG8Z3ad2rEehtLpNHfddRcPPfQQdXV1jBo1ihUrVnDMMcfwne98Z8BeRyXYg2Ow1gS///3vueSSS7jlllv48pe/PODPrwyOrrTFlo4UWcsFoZHM2YwqjRA0dDz8kV1BQ0MAadujJGKysS3Fv5ZtoyWRozQewHY9ZtQUsqEtxfauDK4HwoACUyfrSWbUFBAxDTZ2pACNwpCBi2R0WQxdExSHgzT1ZNjcliRnOXRlLQQCx3HZ2J4mZ7s0bNvOluf+wPaFj+C5DiXTTqLq+PdSUj8W1/MIBQMIKchYNq1J/4bxjlvbWUCzc2Q3LSG9diHJDa/1zZk2Sur8Xe26qQRrJmIUVRMSAk2AJWH/h6QGlpQeTsdWclvfJNu4ktyW5ThdTQCYBeUUjZ9L8aRjyVbP2GOSCkBUwMjyILGwie14lMVC5Lat5aW//op1rz2DGYqw4N0f5cqrPseYugq2dPpnpIO6RnfO4YRx5aQth8aODK2JDCFTpyTmJ9GhgI5AkrVdnl/bytrmJOXxIEFDo6YowpiyGOUFQUKmmmM93ByqNQEMkwT7b3/7GxdddBHf/OY3uf766wf8+ZXBsWhzJx/42Uv9kuYdO9G6gEuOGcGfFzf2KxfvTFv89sVNbE/s7Z7u4JFSkt28lO6X/kCu4Q20YJTYrHOIzzp/v83Qdufl0mS3LCO7aSnZLcuwWzaxY+/dKKn1x2OUj+qbd2kUVaGZwYG7DsfC6W7unXm5xS9Hb96I3bHz/LdZOoLgiOmERs4gPPKovTZn2Rc31UViyWMkFv8TL91FoHoihce9n/BYvzmM1vt9Xdec4JVNnX2ft6PE+/0/e6mvUkEDLpk3ou9norEjzXNr2/Z4TQEETdUlfDiRUu5xo+axxx7DsiwmTZpEfX09kcjB3QA7GCrBHhyDtSZwHIfa2lp0XaexsRFNU02NjgSuJ9nYmmRrZ4byeICujEN1YYiAoSGBaMCgM+0XVBeEDMrifnXCuuYEf1uyBSHAc6G6KMRrmzpxXA/LdcjZkrJYkI0dKQKGwdjSCKMrY4Q0nazjYjkeFfEQICgvDFIYNOnOWCzc2EZTdxZTF4wujVAUNlnVlKAjZbG1M01XRyvrnv492xb+E8/OUjx+DkedcykTj5rPtkSWlOXieJJUxiaVA1MDR4Ll+bvTEnCkJNOykdym18k0LCPbuLLv3LUWihGoHEOkfDThqpE4hXWYxTX77fr9Vu3oSu50bsNub8Ru24zVshGrecMuccQJ1k8lNGI64VGzCJfWUxLREEi6MmBokNrt7F1cQE1JmPICk9yGV1n4j/tpXLWYYCTOUWd9gOMu+CCzJozkolm1RIImK7Z1sXp7AteTlMQCzKorJmBobOnMsLUzzdjyGOGAgaELDE2QyLk4rsu65gSb21O4UhIydcaVxxlVFkMKqC+OoKsxXMPCoV4TwDBJsF3XZcSIEViWxfbt29F1ddfqSPHAyw18/W/L8TyJYWgg/bM2Zu9YJdjzTO0DLzfwtb8sy1vMuW2r6Xn5z6TXLgQhiEw4jvjs8/Y6kuNAvFyK3LbV5JrWYG1fh9WyEbe7ud9jtGiRP2IrWowWKfIbgvXNwTZ752BrID2k5/pl43YOL5fCyyT8JmepDn+kR7Kj33Pr8XICFaMIVI4jWDOBQM3Evc7j3B8pJVbTGhKvP0pq1XPg2oTHzKXgmPfs0VF9clWcb717Ou//2Yv9ztRrvU3J7nthAzuO0gcMjRsvmMqNf1+OtZ/z17CzId7/njLuLcWuHLnS6fQ+3zD39mb7TqgEe3AM1poA4Nvf/jbXX389Dz74IJdccsmgvIYy8KSUdKYsso5HQBdke48MlcWDBA2drO3v4wYNre/feGsiy7bODBnLJWXZJC2H5Q1d2J5H0rLRJSQtj+6sQ0HIIKhpLJhSSSSg057MMaEqTirrsqE1RSykgxBMro7zz6VNxEIGpibY3p3GNEw8CTFTY317CttxMTRB0/YWlj75Jza98FfsZBfRihEcdebFjJh7GkkZZGt7hoTt3zQ2NfA0iJnQkfHHfIF/ZlsCluditzWQ27Yaa/va3hnYm5HOzpPawgyix8sx4iVokWK/5DsY88dwmQHQzb7jYH7zUxvPziFzabxcEjfdg5vu8tcEibbdnjuEWT6SYOVYolXjCNdNwi2uwxQaDn7JesSEwpDAlYLSSICenEsia5PI+dcogZDbQ2rlM7S9+k8SrVspLK/mzIs/Rs0xZ6MHopimRiigc/HseupLo6zY2s2apgSW51JfEkHXNIKGxoiSCD1Zh+6MRUVBmOJIgOaeDDnbY1N7EqTAcj2klNSXhqgtjhELGKRslxHFEQw1imvYOJRrAhgmCTbArbfeype//GX+7//+j49+9KOD8hrK4Ni10RkcuEnVXf9Z95a6ig8Wu2s7icWPkHrjCbxcCqOkjtiMM4lNPQU99vZ3Uj0r4zcz6dyK3bUdt7sFJ9GOm+rA6513uesb4j7pBnoojrbbzEujuLq3oUndOypLdzM9pFY+S/KNJ7BbNvjdxKedSsHsCzDL6vf5eZOr4qzafuAxakfVFVJREOKJlc17/XhtcZjWRA7X9fpuyKgd7KHtkUce4Q9/+AORSITCwkJGjRrF2WefzejR/uz0f/7zn5x88skDflRIJdiDYzDXBF1dXZSVlTFnzhxefvnlQXkN5fDQlbboSFmYuiBre3RlLDa3JGhNWRRHAowrj/Dwa1vZ2p1BCkFVQYhTJlVQGg8RNXUCpk5TV5rmnizl8RCtPVnqiiNsbE+ytTsDHnSmsoSCBjEzSCykMbkmzuJNXcSCBqmcw8ptnTR2JFnz0uNsffGvpLetQw+EKJt+EiUzTydYM5mMLQgYEAmAYRg0dTsHVfYtPRe6WxA9jWRam7C6m7F6Wvv6sriZHmQuzYE70Qi0YAQtUuCvCaIlfuPTokqMomrM0jr0gnKE0NCAwoB/794Q/k3wZBakhGhYoyxm4KFRURBiW2eaTNZCei7bVi+jY+kTJNe8iHRsqifMZMrp7+WoE85gTHkhjZ1ZOlI5NEOjpiBIRWGYE8eXs3xbD57rsXxrF3VFEYqjJqARDmpIIQgKQX1pDNeTNPdkKIoEWLmth8p4gFjYxNR1JlbF6cnaOK6kKBKgJHrw58uVI1O+1gQwSAm2EOJW4AL8Zonrgf+RUnYd6PMG8800kUhQVlbGlClTeP311wflNZTDw66dxnVdw/O8vl1PTUBFPMj2nj1LyDXhvzm8nZ/8oKGR20eXcs/Okn7zBZJL/01u6yoQGqFRs4hOWUBk/LEHfVb7rZA77kjbOb8mbke/bU1HGAG03t3tgeZZWTLrXyW16lky618DzyFQOZbYUWf5sz4P4lpjQZ1k7sDLCoE/dmtvze1MXfDQJ+cDqmv4cFJfX88999xDd3c3juOwcuVK0uk0l156KWPGjOE3v/nNoJy5VQn2gb2ddcFgrgkAPvGJT3DfffexcOFC5s2bN2ivo+SXlJLujE3O8SgMmziuR0N7GtdzKYmGSFkuyxs72NyeJpGzGV0aY2RZDF0XlEYD5ByPgKHR3p2loTPF5o4ss+sLKYwEkFLSnszy8sYOetI2oZDJ/NGlnDKpgrUtSVZt62Zda4KG1iTNqRzJrItl22gdG9ny4j9pWvIMrpUhUFxN0dQFjJp7KlWjx7G9I01n1t/B9vBnSWv456yd/V4tRDRwPPp1H5fSQ1pZpJPzG5/u2uRMNxFm0P/fARqd7thJDwj//zXd7wRu6pB2ev9sCCoKwmgCUjmXXMtGulY8y5qXHifT1YoZjlE39wxGnXAB9WMmUlccIhIyaWhLUxw2KAkHaMs6lMaDjC2LMr48zpvbE7Qn0rzR2ENtcYiq4gi1RRECuo6LJGbqxMIm0aBOY2eGwnCAzpSF67kURAJUF4SoK/aTKE9KtXM9TORrTQCDl2CfCTwtpXSEELcASCmvPdDnDfab6dVXX80dd9zB/fffzwc/+MFBex0l/3bf9f7z4kYEMLWm8IAlxQIYWRrB1ARrW1N7fNzcy0ioUaURNrWn93js7uz2LSSXP01q5bO4PS2gG/6ZpfHziYw9+h3tbOeLm+khs/41MmsXktm4CGnn0GMlRCadSGz6aQecfb27k8aX7fUs9d5ogND8MV66Ljh1YgXl8aBqajYMvfzyy9x444089thjAFiWxfr163n22Wd5+OGH+eUvf0ltbe2gdI9VCfaBvZ11wWCvCTZs2MDYsWOZN28eL730kuosPMxIKWnoSGNosKElSWfGor4kSk/GJmTo6BqETJ2ujEN5LMC2rjSbWtNousByJEVhk0hIY/GmTpI5h2TOIWs5nDe9Bk3XydoWDR1ZqgrDNHWmeH1zB8mcQ9qyiEdCOK7EzqTZ/NpTtCz9D10b3gDpEa8aSdnU44mOO5ps4ViE0AiZYNl+E7QDCeKXah94RfL26fgJdkkE7N4poWVxk1goSHN3D3rLWlpXvkzD68+T62xCaDo1U+dx1MnnMeXYUwmEggghKAqZhAMGSxs66Mw4REMBP8mOBRhREmFiZQFVRWFe3tDKS2vbsT2P8niY2qIwM0cWY7sSXZd4Lhi6TnHEwNA1pPR7r5THgiD8c9nq3/fwks81ARyCEnEhxLuB90kpP3Sgxw72m2kmk6GkpIQxY8awfPly9Y9tmFm0uZM7nlzDC2vb9j3aiYEf57U3Unrktq4mvfoF0mtexO1pBSBQOdYf0zXyKIK1k9DMw2+EjHRsf0TX5qX+SI6mtSA99FgJ4fHHEp14AsH6qf3GcRysK04aQzxscuu/V+/xMU3ArpvVmvDPYe9obqd2qYe3ZDLJu971LiZMmMC1117bVwIG8I1vfAMhBDfeeOOgvLZKsN+ag10XDPaaAOCqq67irrvu4tVXX2XuXPUtHE5cT7Jmew9Bwx/E3JbMUhgJksrZBA2NtqSF6/nnd3UhWNOSpDBskrNtXE8jaAgSlsO2zgyN7SlGlUaxpUc8bBINGaSzLmtbkkyvLaQwYrJ+e4KM7bCssRPbgeKoQUEoRE/OIhLQ2drUxKZFz9Cx4gU61y9DSo9ArIjicbOJjpmJqJqGF69g77VyfmM0E3+H24V9Pm5XAdjnnO390XpfzwUihsRItZDevJTuDUvpWLsYO51A6AaREdMJTzie2hnHUVdVwfETK0hbHpGATnfWIWf5e/SNnVniYZOxFVFae2w+Mm8EmmEwuixKIuvQ3J3hxXWtIPwd6Ok1RcweVYLjSizPQwMCukY4aBAJGHieRFMNzIa1fK4JYN/rAmMAX+PjwO8H8PnetnA4zJVXXsntt9/Oiy++yPHHH5/vkJRDZNeZ2pLeNwddcPLECp5Z3YLrSXRNMKI0yrqW5KDHI4RGqG4yobrJFJ96OXbLRjIbXiOz4TV6XnmYnoV/BM0gUDnWH9dVPYFA1TiM4uqDnlU9EKSUON3NfoO1pjXktr1JrmktuDYIjUDVOAqP+wDhMXMJVI9/R7HpAuJhk2PHlO6ZTAPfumg6K7Z1981AV0m1sqtYLMaf/vQnbrnlFm666SYmTpzIySefzPz581m1ahUzZszId4jKTofNuuDLX/4yd911FzfffDN///vf8x2OcohIKdnenQEB27qzlMUCTK8rwtA1XM8f8VRfEmVze5Jk1sV2PeJhg6KwyVbLJhqA7d0ZNE2nqiBIIutQEAuQzDgENZ32Houc41JXEKQ7ZVFdGCIeMnhzayddWZeAAE0LUF4QoD4QpjWRIaPHKZl9LkVzzqNUz5HZ8Dqblr7I9lWv0bzkaQDMwgoitZMJ1UwkVD0Or2w0WsCf3Wv2dhBzpL+IP1CCHQHCYUhlDm5nfAcvlyLbvAF3+1oy29aQ27YSJ+E3SA0XlTH12FOIjZtLd+FEgvEoPSnICkhaDpva0owui5B1HDoTWTQJWddhbFmEzR0ZWnosxpTFqC6OkLX9Uu54yKAn45eAO55HPGgwo76E4ui+J6io5Fo5XNcEB9zBFkI8CVTt5UPXSSn/1vuY64C5wHvkPp5QCPFJ4JMAI0aMmLN58+Z3EvcBNTU1UVNTw9lnn91XNqAMfbs2PhPACePL+PzpE5gzsrhfOfmfFzfywMsNeY3Vy6XJNa7w5003rsRqXtfXtEyYQcyyEZil9ZjF/pguvaACo6AUPVq813mTByJdBzfVhZtow+lpxenejt2xFae9Eattsz/vGkA3CFSMJVg3mVD9NIL109Dfwoiu/dl9lNYDLzdw/V+W9S0QrjhpDF85d/KAvJYytHV3d/Pcc8/x+uuv8+yzz9Le3s6sWbO4++67CYfDg/KaagfbNxDrgkO9JgC44IILeOSRR2hoaKC+ft8NGJWhw3E91rekkMJDkxCPmFQX7tkjpC2RY3NHCtf1yLkexeEAy7d1IaVkW2cW15OUxoJUxoNI4dHcncN2XFa3JCiOhKgtCoGAo2qL+e+GFv6yaCu6AFPXqYiHOHt6JT05SVN3ihdWtyKEh2V5lBVGmFQVw5GwuqmLji1raVyxmM6NK8hsexM70d4bocAsqiJQVk+0vI5AUTWBogqcSDkyXooIRPZZrakBIfYsJZdS+lNGEm04PW3+6M7ObTgdW7HaGvzjbb30ggoK6idRPHY6F5x5OiPGjqM17dKTzrByWw/N3TaahIKIoCgWoqogRHU8jOVJVjf34Lgu6ZzH1LoCJlUVMKIkxrTaYoTmrwmqCsLomiBru3SlcmxsSxILGtSXxChSTcqUg5CPNQEMYom4EOJjwKeA06SUB3UU5FCUgwG8973v5eGHH2bDhg39SgaUoWv30V3fefd0PjhvxB6PW7S5k0t//pJ/rkeDgnCAjtTbKaAaONJ1/NmTzeuxWjZitzVgt2/ZY6QWgAhE0MJxvxOoGfYbmWkGQggkElwX6Vp4VhZppfEyCbzcnufMtWgRZkkdgfKRmOWjCVSOIVA+elAaox0zqpgFEyv22I3uG9EmJQHVBVx5i7LZLMFgkM2bN1NbW4tpDvzP7g4qwT44b3VdcKjWBM888wynnHIKV111FT/+8Y8H/fWU/HNcj9cbOslYLhnbYUp1AfWle94w9jxJT9bG9STRgM7W7gz/ebMZXdOQUhI1dSbXFBI2dJoTWRCwdnuCrZ1ptnSmaO6xqOmd1V0aDbCksYNcTpJ2PCZWxpg5spjiSJBE1uKRJdtAQCprMbI0QsA0sR2XrOPS3J0la9m0JDyCQajQ0jRvWEFX43qS2zeRamnA7mzCc+z+F6Ab6OEdY7pCfiOzXcd0eS7Sc5B2Fi+XQeZSuJkEeP1bqQkj6E8YKRtBoHwUgYrRBKrGEY8Wo5tQFtOojMcwDJ2oAeGwydiKCC+ubsWVkohhohkaEyqiSKHRksixobWHrrRNyNSpLYrwrln1TK8rZkRpBE0INEG/mwPNPVksxyWga6Rtl3o1Zkt5Cw7lmgAGqURcCHE28GVgwcEm14fSF7/4RR5++GG+973v8bOf/Szf4SiHQGfa6is7FsDybd17fdyckcU8+Mn5LNzQTnEkwPW7zdMWwNjyKOtbU4dsDJjQ/VLxQOXYfn/v2VmcrmbcnlacRBtuqhMv0+MnzVYaaWf9mdde2m+PjvDPRRsBjHgpIliPHoqhhQvRY8XosVKMwgqMgopB6Wy+NyFT49pzJu81ce5MW7ie9Gd/2h4LN7SrBFs5oF/96leceeaZ1NbWAjBq1Kj8BqQAh/e64OSTT2bGjBnce++9fPe73yUWG5jKHOXw5UpJYcTEQ+JJjdakRVncJRzo3ztE0wRFEX+ntCttsbk9RXfKIuO4FIaCzB9dzpiKKFu7M1TEQ2xqT9GRsiiOmTQndcJBnaqiMCu2duM4EtuRRMIG1YbGcePLKI0GaU1aBA2didVxPA+EgHljS3m9oQtTA1d6tCVsoiFJOGf5vWLCRUyZdwrpOScRNkxylsvIkhDpnlYWLVtDurOF7vY2Uslu3HQPMpfCszJIJ4dn53bpIq4jdAMtFO9974+ihQvQo0VE4iWEissRsQq0SDFWb1Iu8JucCaAk7s+7ri6O0pWwiQQ9SiIhkjmXRFoytqqYkOnfXAiYGts60ui633SsOBLB0BzKoiYVhUEKwyaFYQNzH0mzLqAn62AIMA0dTfVRUg5Cc3Mzf//737n88ssRQuR9TfBOz2D/BL+R4RO9d58WSimveMdRDZD58+czZ84c7rvvPm644Qbq6uryHZIyyI4dU4qhCSzXT9j+tKiR9+6j0/SckcXMGVnMJ3/z2h5nmD44bwTvmV3Hh+5bSM728jprWzNDBMpHQvnIPEbx9h0zqnifyTVAcSTQ9/X1ev9bUfanq6uLq666ivnz5/Pkk0/mOxylv8N6XXD11VfzP//zP9x6663cdNNN+Q5HGWQBXSMSMGjqzhIPByiIGKQsZ48EewfPk2xsT9LSlaG+JIrtSkaWRhhXGcX2JIYQdNsOuq4xrjLO1q4M9cVhUhmX5q40QhM40iNoCEaXRSmNBbFdKImFcCUUhE1KYyYtPQ5tqQypnOs3OdNNSmMRempcGtrTFEYhlbGwXEFxJIxpOwjpURINgm5gBcsZNyPOtu4cgbRDzPYbke0++FLDT5LtPS8V8JPnsrD/mFBAI2N7dKYhFABdg+qiEGWRAHUlUTqSOYIBnWQ2A7ogaXmMKo4QCmhMri4gFjLpSGUJB00sxyOg6RRHgkyv1ejK2BiGxriyKLNGFlMU2fe5aokgazm4UlIRN9Q5a+WgfOMb3+C+++5j/vz5TJs2Ld/hvLMEW0o5bqACGSy33norp556KjfccAO/+tWv8h2OMsjmjCzm4rn1PPByAxJw3f3viC7a3MmTq5r7/Z0Q/pgvgPfMrmPF1m7eaOzOa5J9pBFAZUGQi2bWHvBM9a5VB5rw/1tR9uc73/kO6XSaK6+8Mt+hKLs53NcFH/7wh/n2t7/NzTffzJe+9CW1iz3ECSGoLghhO57fUFNC2Nx3uXHackimHDrTNsmcTW1RmMKowbbuLLYrkdLD0DRGl0UwdY1o0CBoFHD0qFLaUxbrmpNsbksSNAPkbJfCcNDfrTUEc0eVsLk9ydYOD12XlMeCVBeG2NSapCfnIIXGvFGlVESDrG1J0pNxQEhsKYkFdMriYbozHkUxE9t1MfUA7SmbslgQNIemTpewAV7vmJSSWIBAwCBsChzXJWd5pDM2nTnZVxkeDEDQNCiKBqgrDtOdcchZLqVRg+qiOMeNKyFteXhSEtR1bE9SU9CD3tskbkJ1jJqiCDnbpT1tIRF0pCw86Xf6romYTKsvJmQIDKERC5v73LnewXJcqovCGJpGKueoTuHKAW3ZsoWf/exnHH300UydOjXf4QAD20X8sHTKKadw3HHH8etf/5rbb7+doqKifIekDLL3zK7jz4sbsR0P09D6ZmTvzcOLG/t1sQa/yvrGf6wAKXE8v7ulqQvc3l/yU6oLmD+mlCffbNlrJ3JdE0yr8R9z3wsbcXZ/gSOAwL/RsLfQdQH7Gi+uCThtciVXLBh70GXex44pJWBoB/X9UhTbtvnxj3/M+PHjee9735vvcJQjjGEYfP/73+c973kPd955J9ddd12+Q1IGma5r1JdEyDkeuiYImfseLZlzPCxPUlEQJNvu4kiJqelsbk8xtjyGEDoSG4E/I/uo+iLK4/6YTdtx+ceSRloTGUxTpzhqMmdkETPqSzA0gQR0odGZcUmmc2zoTrGxLU3I1BhbGSdte4yrjFJXEqEoGmBxQxeW42DoEAkYaEJHEx44Hq6UCE0Q0HVc4RELhdDxmFpbgK6BIyUzRxaTy3nEAjqxSABPStq7s7ywro2erI1EMKI4hCugpjDEto4M5bEghRETTQim1xRSXhCmKBJACEFpNIDtedQUh0nmHCKGRnVxhNriMC+tb6WhNUlJNEh9SYhE1mVkSYSaogiVhW+twVQsaNCesgCPeEjtYCsHdttttwFwxx13HDajmYd8gg1w/fXXc+6553L77berkrBhYM7IYu6//Ni+juH7S/T2lfrajtf3cdf1uOSYEdQUhfs93xlTq3jfPS/2e46yeIATxpaxsS3FooZOZo8o4pVNnft8/XhIJ5Hdvagr/zQBHzhmBLVFYRIZm/te2IjnSQKmxjGjSnhubVu/xwtgRl0hU2sL91mSvy9v5fulKPfeey/ZbJYvf/nL+Q5FOUJdeOGFVFZWcuedd/LVr34VTVMNlIY6Q9cOqlFWJKATNjVSusG4ijiRgEHQ0AkHdLozNqauURYLURg2cT2Jqe9czJuGTllBkPEVcZK2i/Qk8ZDJptYk23rSWI5HQcgknc6RyHlUxcNsyCVIpCVr7AQji6MUhkKETIes5RI0oDwWwjR0dE3415D12NbjUFsUpTubY0J1gEzGpimZo6YkjONJIqEAp44rQaBTWRdkdGkMTQi2dKZxXY9JNQXUFIZZ05IkrAtcBFnLxcKlu8eioSPF3FElRIJ67zVqeEA0aGDoGgUhk/aUhUBQEg2Qsx22d2UZWRZjW1eWZM5l1ogSisIBDF285R3owkiAYO9NEH92uaLsm2VZ3H333UydOpXjjjsu3+H0GRYJ9tlnn01dXR0/+tGPuOGGGzCMYXHZw9qO89UH8t7ZdfzhtS04u2zJaoBhaCCl/+ZiaLxnL0njnJHFzKgrZGnjzkZqbQmLvy7Z1u9xvdVae5W2vP1+/GCdNL6M7oxNZUGITy0Yy+rtCW7467J97jTvixD+9ZuG1i9RPmNqFX9e3IgAntqtpJ7ex69q6mHZ1m4eXtz4ljuBH+z3SxnepJT84Ac/IB6P85GPfCTf4ShHKE3TuOaaa7j22mt58MEH+dCHPpTvkJTDRDhgML4yjqal0DWJ7Xg4nkddYYR42N9NDZs6Qgj0vSSNpTH/bHHIEOhagFXbu0lmXTRN+OeSy3QiIQNLOuQcB1PTCIQECMFRI4vQTY10xuXN5h4kgva0zXFjChlZFqMrY7O8EWJBl6aeNLbjUlEYoaYwxCgvRlk0gOV6TK4p4NQpNXi9zUN3xOlJSXNXlvoSsFy/u/m80WV0ZXI8vaqVseWFbDPSpLI2BeEAXRmbSMggbOoURMy+GxQBQ6eqIETGdv0GpRKSORdXgqELKuJh4iGDtqTfbT0aMKgsCJHr3bgIGtoBdxn3V2WgKLu6++67cRyHa6+9Nt+h9DMsMk0hBF/84hf5/Oc/z29+8xs+/vGP5zsk5TAxZ2Qxv//kfB5e3IgEptUU0pm2+sqUD7SrOn9Mab8Ee2/2m+NKSdDQyDr926ztOJN8MC6aWcMdl8zq93dzRhazYls397/FWd9SwmlTKvnUXkq8H17ciOXs3g4OplTHmTmimIdeacCT/u7/wg3+7M4dX799/Vkl1cpb8cgjj7Bhwwa+8pWvEAzuu0mOohzIlVdeyfXXX88PfvADlWAr/ZQXhCiJBbFdDw0Qmt8N+2BKT8OGSW1JlNYei4KIQcA0MGxAQta2kJ5HOGACGq0Jj3jYoLowQtpySeZcqgo1WpNZyqIRHM+hPZFjcnWc+tIC1rck8DyPllSWps4sZbEAJVETXeicMbUcXdMwDY0RJf50EMv13681ofXdEBhRGqE4FqAzlfN31zvSdKQsJJJ0xiGoCxzDYG1LkkTGoSBs0p7KURDpP+qoNZkjkXEQAgKGoKYoxPaeLLGgzsSqOFnbIxzQ/XPUlkNb0qInYyGEoCBsgpRkbY9IQCfreEjpzxlXibXyVkgp+eEPf0hJSQmXXHJJvsPpZ1gk2ACXX3451113Hddddx0f+tCH1OJM6bO/3dMDJYDx8MHN19vxtrx7zqxpAlMXZPuPouS0yZW8urGdrsxuH+hVFQ8yvirOOdOq++Z8L9rc2S9x3XEWPWd7fdvoB5OzP7GymZMnVgA7E+GFG9qxepvE7L7EOG9GDceOKeXhXc69F0cCfOi+hViOh+EPucRx+/9ZzbxW3grP8/jyl7+MEIKrr7463+EoR7h4PM4nP/lJ7rrrLv7whz/w/ve/P98hKYcRXRPo2ltP9iSSingIz4Os41AZDWD03jCPBSPEQwE0TTCpKk5HyubVTW2ksza256EL/6x1VTxCMNBFNiMpjgXwJJi6YHxFjFc2mrQmslQVhEBIdHTmjS5hVFmcaNBAE345fFfKor03cS6OBCiNBYkGTcJBvwN6RTxIWyJLc08Wx5NUFoTpyebQRAAhdEKmhvQk3RmbZM5hW2eGuh3n2AWkcy7xkIHrSTa1JYgGTabVhYgHDcKmgS5c2lIWSI9o0CBtOUSDBgLY3pUhEjQImRprm5OUFQQIGTrNPVlGlkYH/pupDFk/+clP2LJlCzfddNOgz7t+q4ZNgh2NRvne977HZz7zGX71q19xxRWHzdQQ5Qi2+0ipY0YVM3tEMS9taCdgaFiOx7Kt3fvcjbZdyfwxxXucaX5iZfMeiewOuoC7PjynX2K6aHNnX0IbMDS+fv5UOtMWH5s/ipc2tBM0NF7dz1nwXUnghr8tRyBxPf+N/cZ3TSNgaHuMLAv0NiXb/Rz1rgm57fqtWyW7/dlRM6+Vg/f3v/+dN998k69//etUVFTkOxxlCPj+97/Pfffdx80338zFF1982DTHUY5cQUPHk5L60iiGDiNLYkzWBQgIGTqulP5ZaAktPRkKQiaJnE3OAsuTrNjWzYzaIta0dPPy2iwFsSALN3YQDwfJ2C7hgEFtcZhk1iYYMJg1spC60gi262E5HrGQv6zvztpEgjoCaEvlyDmu//4tJTnHI5mTbO5M43iSpu40hiZYMKGSwrBBY2cGkGzqyCCBtOXiuFlak1kCmk7Q1AgYGtu7c7QlLJI5m4JIgGzOpSBoEjI1wgGdoOl/LUKGTkcqR3fvjnco4MelC4FE+kfThL8LIKVU/w6Vg+I4Dt/97nepqKg4LJtVDpsEG/xd7K985SvcfvvtfOpTn1L/iJV3bPcRUwsmVvC/p+ycUrMj8d3fLG0JlERMOtL2Hn+/u6NHFTNnRDF3PLmm3+71rgmt5Xh8/W/L8aQ8YJm5wN9Fd3d74K7/bbmS37/asEdzMwG8b87Oc9q7VwLs6Ayu9+5au27/P6uO4cpbceuttwJwzTXX5DkSZaiIRCJcccUV3Hnnnbz44oscf/zx+Q5JOcLFQgaVhWGChoblekSDer/mahqCyniQ7pSNaeoYhkYu4ZGxXeIBnVTOQwhBYShAZXGEeNikuStDc08GENiOQzRkUhwNcOKESsaXx1mytYtkNo2G5KgRxRRFAkR6m7JpCGzXw3El3VmbRMb2y7p1QUTXabIzlISDeEgMTRAOmswdFWFrV5qymH/OuiQWpCOZo7XHojQWZHNHimhA0JxwsCyHaNigVNeIBXSqCkN9a+tdy71LY0HCAT/lCBoaLYksGdtjVFkMy/GwXI/yWFCty5WD9vDDD9PU1MS3vvUtdP3wO1owrBLsUCjElVdeyW233cZ//vMfTj311HyHpBzhDjRiasfO7sOLG3mw94zy7s6ZVs3a5sQBX+uimTUcM7qUr/1lGQDP9ya7H5w3ol8cQvgJ88GUgwdNf7d7xbZu1jYnWLS5Ewl9z7HD3s6Z65rgvbPr9vq8u+9ogzqDrbx9q1at4sUXX+RjH/sYhYWF+Q5HGUK++tWvcuedd/Ltb3+bRx99NN/hKEe4gpCJJyFnu1QVhPfauTwaNIkEDJoSGdoTOXRDIx7U2dadZURJhEhAozQeoiSSJuf6R7y2dmUpjRiMKo1jef5zjy6LkXU9UlmX8niQnqxFc0+GokiA0qh/nllKSdrWsWx/lnVAE6BpOJ5HcSxIQTpHVUEEAcQjBtWFYQKGRlE0gOdJtnVnsBwPiaA4GiCZtbE9ieUJooZGcThMyrLRpJ9EF+zj2JwQgmhwZ8pRUxRRu9XKO3LrrbcihOCzn/1svkPZq2GVYIP/ZnrbbbfxjW98QyXYyjt2MCOm5owsZvX2xB4Jb8jUOHtqFR+cN4JXNrbv0X18d9GgwWPLm/r93WPLm/jgvBH94iiOBLj5kRV9O9o76H5jdGRvp8+L59bvMVJrxznu4kiAG/+xYq9NzcDfvb75wmn7TZB339He158V5UB2lH8dbl1ClSNfZWUlF110EX/9619Zs2YNEyZMyHdIyhFM0/zRVfvjeZLmngy2IxlZHsN1PbK2pCBkYOoaKculOGRgOS5bOjOMrYgzoihMd8Yi6XhUxYJ40q80Cxo6AUOjPZnDk1DUm+BqmiAe8v8cDni09GSJBgxc3d81D+mCaMikIh4k7bgEdJ2aIj+53vVaagrDWK5HadSkLWnRmbKIhwM4jsvWXBYTSTxsMr4qTkn0re1Aq+RaebteffVVXnvtNT75yU8Sj8fzHc5eDbsEu6SkhE984hPce++9PProo5x77rn5Dkk5wh1oxNSizZ18/W/Lkbtl2Fnb469LtlFVEOKRZU17/+RdSPzd7ud3KdM+Z1r1XuOYWBXvS5RXbOtGQt9u84FuBuz6HA8vbuSPr23BdvvviH/qpDF95emKMpheeeUV/vKXv3DhhRcyadKkfIejDEE33XQTf/3rX7nyyit56qmn8h2OMsR1pHJ0ph0s1yNs6niaRs6xiQZ1OtM2lpfAtj0KYyHqhd/BPOt41JTEEPiztT0JuuafhZ45oojOVI6waVIS2zO5N3WN2mK/s7jXe9d9x1zq4kiAnONh6AJzL7vtmiYIaTohUycaNKkrjmC5HlnLYWRZDCEk5fH+ibmiDLbLL78cIQTXX399vkPZp2GXYAN897vf5d577+V73/ueSrCVQbdwQzve7tn1Lv61Ynu/Odw76MIfD+K6ElMX/XabH1ve1O8M9u7eSWf03Z/jPbPrWLihnUTGZkVTz35fV1EG2i233ALAbbfdludIlKFqxowZfOADH+D3v/89q1atYvLkyfkOSRnCco5HYUhHEyHakjnqSiLIzhQSf+yQdCVBU0MXgkjAwDR0AobOuIoYOccjbbnEg36nb/A7j0cCB7ec13ab3a1pgnDg4M6vGrqGoUPQ1Pt2xxXlUHvhhRd44403uOqqq6ivr893OPs0LBPs0tJSPvzhD/O73/2OpUuXctRRR+U7JGUI23E+2nI8NCE4f0Z1v3Lws6dWcd9/N/ZLsjXgmxdN79uJ3nXH+YPzRhzSBPdAO/SKMlgaGxt5+OGHOeussxg3btyBP0FR3qZrr72W3//+99x666388pe/zHc4yhBWFDFpSeQIGBpTawqJhwz/rLTlEAv6c7HbUzmqCoKkcjplBWEmVRVQ1Du1RJ1dVoazHTfdv/jFL+Y5kv0Tcj87a4Nl7ty58rXXXjvkr7urFStWMG3aNM444wwef/zxvMaiDH27z6h+4OWGfrvQizZ38rNn17OhLcXosihXLBirklpl2PvgBz/Igw8+yLPPPstJJ52U73AQQiySUs7NdxxDzeGwJgA47rjjeOmll1i9erU6i60MKsf1e6TsKK12PYm1S6n2rqXcKqFWFN/ChQuZP38+73rXu/jb3/6W73CAfa8Lhm2CDfCxj32M//u//+O1115jzpw5+Q5HURRF6bV161bq6uo499xz+ec//5nvcACVYA+Ww2VNsGrVKqZMmcKHPvQhfve73+U7HEVRFGUXZ511Fo8//jhbtmyhrm7vU2wOtX2tC4Z1V4Ibb7wR2FluoCiKohwevv3tbwPw9a9/Pc+RKMPF5MmTWbBgAffffz8dHR35DkdRFEXptWnTJh5//HHe9773HTbJ9f4M6wR71KhRXHTRRfzxj3/kcLh7riiKosD69eu55557OP7445k3b16+w1GGkW984xsA/O///m+eI1EURVF2uPzyy4Gdv6MPd8M6wQb4wQ9+AMA3v/nNPEeiKIqigD/pAXb+flaUQ+WUU07hjDPO4KGHHqKp6cDjExVFUZTB9eabb/LUU0/xwQ9+kGnTpuU7nIMy7BPsMWPGcN555/H3v/+dLVu25DscRVGUYa27u5tf/vKXHHPMMWr3WsmLL3zhC4C6waMoinI42HHT/XDvHL6rYZ9gw85v2Fe+8pU8R6IoijK83XTTTUgp+5IcRTnUzjzzTCZOnMiPfvQjmpub8x2OoijKsLV27Vp+85vfcMIJJzBr1qx8h3PQVIINnHzyyZx66qk88MADrFu3Lt/hKIqiDEtdXV3cfvvtzJgxg/e97335DkcZpoQQ/PjHP8a27b5me4qiKMqht6PR6Q9/+MM8R/LWqAS712233QbArbfemudIFEVRhqcdv4e///3vo2nq7UnJnzPOOIMZM2Zw9913k81m8x2OoijKsNPS0sJDDz3EaaedxtFHH53vcN4StYLpNWvWLI499lh+/vOfs2HDhnyHoyiKMqy0tLRw6623Mm7cOM4444x8h6MoXHfddbiuy9e+9rV8h6IoijLsXHPNNcCROa5TJdi7uOeeewDVUVxRFOVQu/XWW7Esix//+Mdq91o5LLz//e9n7ty5/OhHP6K7uzvf4SiKogwb27Zt4/777+ecc87hpJNOync4b5laxexi5syZHHfccfz6179m+/bt+Q5HURRlWEgkEtx9992MHz+es846K9/hKEqfq6++Gtd1j7jzf4qiKEeyb33rW8CR1Tl8VyrB3s2O3esdZQmKoijK4Lr++utJp9N8+9vfRgiR73AUpc8HPvABJkyYwM0330xnZ2e+w1EURRnyNmzYwD333MOxxx7LKaecku9w3haVYO/m1FNP5eSTT+bBBx+kpaUl3+EoiqIMadlslrvvvpspU6Zw8cUX5zscRelH1/W+5nt33HFHfoNRFEUZBnY0nP7hD394xN50Vwn2Xnz1q18F1FxsRVGUwXbzzTfjOE7f711FOdxccMEF1NXVccstt6hdbEVRlEG0adMmfvrTnzJr1izmz5+f73DeNpVg78WZZ57JKaecwq9+9Sva2tryHY6iKMqQlEwmueWWW5g2bRof+tCH8h2OouzTPffcQy6X45Zbbsl3KIqiKEPWjqO6v/zlL/McyTujEux9+MY3vgHAd7/73TxHoiiKMjT95Cc/wfM8vva1rx2xZWDK8HDeeecxbtw47r77bnK5XL7DURRFGXJ6enr45S9/ybHHHsvMmTPzHc47MiAJthDiC0IIKYQoG4jnOxycdNJJzJ49mx/+8Ic0NDTkOxxFUZQhpbOzk2984xuMHj1anb0egobaukAIwXXXXUcikeCGG27IdziKoihDzuc+9zngyJx7vbt3nGALIeqBM4EhlYUKIfrGclx33XV5jkZRFGVo+da3voVlWdxyyy0YhpHvcJQBNFTXBZdddhmTJ0/mzjvvVE1QFUVRBtD69ev59a9/zUknncQ555yT73DesYHYwb4d+DIgB+C5DisLFizg5JNP5ne/+x2rV6/OdziKoihDQltbGz/84Q+ZNm0a733ve/MdjjLwhuS6QAjBbbfdhmVZ3HTTTfkOR1EUZcjY0ej0e9/7Xp4jGRjvKMEWQlwIbJVSLj2Ix35SCPGaEOK11tbWd/Kyh9SOXewf/OAHeY5EURRlaNjxe/WOO+5A01QrkKHkYNcFR+qa4Nxzz+Woo47i7rvvJpvN5jscRVGUI9727dv54x//yOmnn35Edw7f1QFXNkKIJ4UQy/fyvwuBrwEHVSgvpfy5lHKulHJueXn5O437kJk1axbz5s3j3nvvZf369fkOR1EU5YjW0tLCD37wA8aMGcNpp52W73CUt2Eg1gVH6poAdh4bU8fHFEVR3rkvfvGLwNA4e73DARNsKeXpUsppu/8P2ACMBpYKITYBdcBiIUTV4IZ86N13330AqiRMURTlHfre976HZVn89Kc/zXcoyts03NcFF198MUcffTR33HGHmoutKIryDjQ2NnL//fdz7rnncuKJJ+Y7nAHztmvzpJTLpJQVUspRUspRQCMwW0q5fcCiO0xMmzaNs88+m9/+9re88cYb+Q5HURTliNTQ0MDtt9/O3LlzOf300/MdjjLAhtO64Oabb8bzPD7/+c/nOxRFUZQj1qc+9Slg6G1iqsNvB+m2224D4JprrslzJIqiKEemr3zlK4D/+1TNvVaOZGeffTYnnHACv/nNb1izZk2+w1EURTnivPLKKzz66KNcdNFFzJ07N9/hDKgBS7B771i3DdTzHW6mTp3KpZdeylNPPcWLL76Y73AURVGOKOvWrePBBx/k9NNPZ8GCBfkORzkEhvq64M477wTgG9/4Rp4jURRFOfJ87WtfA+DWW2/NcyQDT+1gvwXf/va3AdVRXFEU5a3aUQU0FN9IleFp9uzZnHbaaTz00EO0tQ3Z+wiKoigDbuXKlTz11FNccskljBs3Lt/hDDiVYL8Fo0eP5l3vehcPP/wwr776ar7DURRFOSKsW7eOn/3sZ8yfP5+ZM2fmOxxFGTA7ut5+7nOfy3MkiqIoR44dvzOvv/76PEcyOFSC/Rbdc889wM6yBkVRFGX/brjhBgDuuuuuPEeiKAPrpJNO4pxzzuGBBx5gw4YN+Q5HURTlsLdkyRKefPJJLrvsMqZOnZrvcAaFSrDfopqaGj7ykY/w5JNP8tRTT+U7HEVRlMPakiVLeOihhzj77LOZNWtWvsNRlAF38803A/CZz3wmz5EoiqIc/j75yU8CO2++D0UqwX4bbrnlFsDvKC6lzHM0iqIoh68vfOELgOpdoQxdc+fO5cILL+TRRx/lpZdeync4iqIoh61HHnmEV199lcsvv5yxY8fmO5xBoxLst6G6upovfvGLvPHGG/zjH//IdziKoiiHpVdeeYWnn36ayy67jClTpuQ7HEUZND/+8Y+BnWeyFUVRlD1df/31mKbJ97///XyHMqhUgv02XXvttYRCIT73uc/hum6+w1EURTnsfPrTnwbgxhtvzG8gijLI6uvrueyyy3jyySd58skn8x2OoijKYee3v/0tS5cu5fOf/zzFxcX5DmdQqQT7bSorK+OrX/0qmzZt4sEHH8x3OIqiKIeVp59+mkWLFnHVVVcxatSofIejKIPujjvuAOBLX/qSOj6mKIqyC8dxuP766wkGg3zzm9/MdziDTiXY78AXvvAF4vE4V199NZ7n5TscRVGUw4KUkk9/+tMYhsF1112X73AU5ZAoKiris5/9LEuWLOHPf/5zvsNRFEU5bNx77700NDRw4403EgwG8x3OoFMJ9jsQjUa54YYbaGtr67tzrSiKMtz99re/ZfXq1Xzuc5+jqqoq3+EoyiFz8803o2ka1157LY7j5DscRVGUvEulUnz9618nHo9zzTXX5DucQ0Il2O/QF77wBerq6vjmN7+Jbdv5DkdRFCWvpJTceOONxONxvvWtb+U7HEU5pAoLC/nWt77Fhg0b1PExRVEU4J577qGtrY3bb7+dQCCQ73AOCZVgv0OapnHjjTfS1dXFVVddle9wFEVR8urGG29k48aNfO1rXyMUCuU7HEU55D7zmc9QXFzMlVdeSWdnZ77DURRFyZuGhga++tWvMmLECD760Y/mO5xDRiXYA+DjH/84F1xwAUuWLCGTyeQ7HEVRlLx55plnOOmkk/jSl76U71AUJS9isRi/+tWvqK6u5rHHHst3OIqiKHnzj3/8g1GjRvHb3/4W0zTzHc4hI/LR6VII0QpsPuQv/PaUAW35DuIQGA7XORyuEYbHdQ6Ha4ThcZ1H0jWOlFKW5zuIoUatCQ5Lw+E6h8M1wvC4zuFwjTA8rvNIu8a9rgvykmAfSYQQr0kp5+Y7jsE2HK5zOFwjDI/rHA7XCMPjOofDNSpDx3D5eR0O1zkcrhGGx3UOh2uE4XGdQ+UaVYm4oiiKoiiKoiiKogwAlWAriqIoiqIoiqIoygBQCfaB/TzfARwiw+E6h8M1wvC4zuFwjTA8rnM4XKMydAyXn9fhcJ3D4RpheFzncLhGGB7XOSSuUZ3BVhRFURRFURRFUZQBoHawFUVRFEVRFEVRFGUAqAT7IAkhPiOEeFMIsUII8f18xzNYhBBfEEJIIURZvmMZDEKIW3u/j28IIf4ihCjKd0wDRQhxthBitRBinRDiK/mOZzAIIeqFEP8RQqzs/bf4uXzHNFiEELoQ4nUhxCP5jmWwCCGKhBB/6v03uUoIMT/fMSnKwRguawIY2usCtSY4sqk1wdAylNYEKsE+CEKIU4ALgaOklFOB2/Ic0qAQQtQDZwIN+Y5lED0BTJNSzgDWAF/NczwDQgihA3cB5wBTgEuFEFPyG9WgcIAvSCmnAMcC/ztErxPgc8CqfAcxyO4E/iWlnAQcxdC/XmUIGC5rAhgW6wK1JjiyqTXB0DJk1gQqwT44VwLfk1LmAKSULXmOZ7DcDnwZGLIH86WUj0spnd7/XAjU5TOeAXQMsE5KuUFKaQEP4S8AhxQpZZOUcnHvnxP4v3xr8xvVwBNC1AHnAfflO5bBIoQoBE4CfgEgpbSklF15DUpRDs5wWRPAEF8XqDXBkU2tCYaOobYmUAn2wZkAnCiEeFkI8awQ4uh8BzTQhBAXAlullEvzHcsh9HHgsXwHMUBqgS27/HcjQ/BNZldCiFHALODlPIcyGO7AX9R6eY5jMI0GWoFf9Za93SeEiOY7KEU5CEN+TQDDcl2g1gRHMLUmOOINqTWBke8ADhdCiCeBqr186Dr8r1MJfvnJ0cAfhBBj5BHWgv0A1/g1/DKwI97+rlNK+bfex1yHX1p0/6GMTRkYQogY8Gfg81LKnnzHM5CEEOcDLVLKRUKIk/MczmAygNnAZ6SULwsh7gS+AtyQ37AUZXisCWB4rAvUmmDoU2uCIWFIrQlUgt1LSnn6vj4mhLgSeLj3zfMVIYQHlOHfaTli7OsahRDT8e8cLRVCgF8itVgIcYyUcvshDHFA7O97CSCE+BhwPnDakbgg2oetQP0u/13X+3dDjhDCxH8jvV9K+XC+4xkExwPvEkKcC4SAAiHE76SUH85zXAOtEWiUUu7YbfgT/pupouTdcFgTwPBYF6g1AaDWBEcytSY4AqkS8YPzV+AUACHEBCAAtOUzoIEkpVwmpayQUo6SUo7C/yGffaS9iR4MIcTZ+GU275JSpvMdzwB6FRgvhBgthAgAlwB/z3NMA074K71fAKuklD/MdzyDQUr5VSllXe+/xUuAp4fgGym9v1+2CCEm9v7VacDKPIakKAfrrwzhNQEMn3WBWhMc2dSaYOgYamsCtYN9cH4J/FIIsRywgMuG0F3O4eYnQBB4oveu/EIp5RX5Demdk1I6QoirgH8DOvBLKeWKPIc1GI4HPgIsE0Is6f27r0kpH81fSMo78Bng/t4F4Abgf/Icj6IcDLUmGDrUmuDIptYEQ8uQWRMI9Z6gKIqiKIqiKIqiKO+cKhFXFEVRFEVRFEVRlAGgEmxFURRFURRFURRFGQAqwVYURVEURVEURVGUAaASbEVRFEVRFEVRFEUZACrBVhRFURRFURRFUZQBoBJsRVEURVEURVEURRkAKsFWFEVRFEVRFEVRlAGgEmxFURRFURRFURRFGQD/HyZ+/fCVyZT4AAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_contour(logprob, orbits=samples, weights=weights)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## McLachlan\n", + "\n", + "A different method of discretizing the solution to Hamilton's equations, see [Blanes, Casas & Sanz-Serna (2014)](https://arxiv.org/abs/1405.3153)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 29.8 ms, sys: 3.73 ms, total: 33.5 ms\n", + "Wall time: 32.5 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "init_fn, ml_kernel = orbital(\n", + " logprob, step_size, inv_mass_matrix, period, bijection=integrators.mclachlan\n", + ")\n", + "initial_state = init_fn(initial_position)\n", + "ml_kernel = jax.jit(ml_kernel)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cabezasg/.local/lib/python3.8/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n", + " warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.5 s, sys: 997 µs, total: 2.5 s\n", + "Wall time: 2.6 s\n" + ] + } + ], + "source": [ + "%%time\n", + "rng_key = jax.random.PRNGKey(0)\n", + "states = inference_loop(rng_key, ml_kernel, initial_state, 10_000)\n", + "\n", + "samples = states.positions\n", + "weights = states.weights" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA9gAAAF1CAYAAAATN0JoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdd5xc1X3w/8+5ffpsL1qthCQQQsIGRBHGgAk2NgQ33EvsxA9JnCdO4jiPnzg4wX78ixMnsZ/EeezEcUkcJxgXDMYNGxdMF0VUCSHUpdX23ekzt5/fH7O77KqXVQHO+/XC0ty5c++ZAd97vvd8z/cIKSWKoiiKoiiKoiiKohwb7WQ3QFEURVEURVEURVFeDFSArSiKoiiKoiiKoijzQAXYiqIoiqIoiqIoijIPVICtKIqiKIqiKIqiKPNABdiKoiiKoiiKoiiKMg9UgK0oiqIoiqIoiqIo80AF2IpynAgh3iOEuPMw9/1tIcR9x7Etx/X480EIsUMI8eqT3Q5FURRFORrqvn9k1H1febFSAbbyoiSEkEKIZXtt+6QQ4r9PVBuklDdJKa+aj2MJIX4thLh+Po6lKIqiKEqTEOIvhBB37LVt8wG2vfNgx1L3fUVRQAXYiqIoiqIoykvXPcArhBA6gBCiBzCBc/fatmxqX0VRlINSAbbykiSEeJUQYkAI8WdCiFEhxJAQ4nem3jtNCFEUQmhTr78ihBid9dn/EkJ8eOrvOSHE16Y+v0cI8dezbshz0rOEEFcJITYJIUpCiH8RQty999NpIcRnhRAFIcR2IcTVU9s+DVwKfEEIURVCfGFq+5lCiJ8LISanjvv2WcdpE0L8QAhRFkI8DCw9yG/hCCH+WwgxMfW9HxFCdE299ztCiI1CiIoQYpsQ4vf38xv+71m/4ZuEENcIIZ6batcNs/b/pBDiFiHEt6eO95gQ4uUHaJMmhPiYEGLrVLu+I4RoPVR7FUVRFOUIPUIzoD5n6vWlwF3Apr22bZVSDqr7vrrvK8qhqABbeSnrBnLAAuB/AF8UQrRIKbcDZeDcqf0uA6pCiBVTry8H7p76+9eBkOaT7XOBq4B9UrqEEO3ALcBfAG00b9yv2Gu3i6a2twN/D3xNCCGklB8H7gU+JKVMSyk/JIRIAT8Hvgl0Au8E/kUIcdbUsb4IuEAP8IGpfw7k/VO/w8Kptn0QaEy9NwpcC2SB3wH+UQhx3qzPdgMOzd/wRuArwHuB1TQ7B38lhDht1v5vBL4LtE61/ftCCHM/bfoj4E00f+teoDD1nQ7VXkVRFEU5bFJKH3iI5r2eqT/vBe7ba9v06PXXUfd9dd9XlINQAbbyUhYAn5JSBlLKnwBVYPnUe3cDlwshuqde3zL1+jSaN50np56eXgN8WEpZk1KOAv9I86a3t2uADVLKW6WUIfDPwPBe++yUUn5FShkB/0nzJnmgJ7TXAjuklP8hpQyllI8D3wPeNvUk/S3AjVPtWj91vIP9Dm3AMillJKVcJ6UsA0gpfyyl3Cqb7gbupHkDnf3ZT0spA+BbNDsJn5dSVqSUG4BngNlPq9dJKW+Z2v//0rxJr9lPmz4IfFxKOSCl9IBPAm8VQhgHa6+iKIqiHIW7eT6YvpRmcHvvXtvuVvf9mc+q+76iHIRxshugKMdJRDPlazaT5kV62sTUTW9aHUhP/f1u4A3AAM2n1r8Gfovm0+F7pZSxEGLR1DGHhBDTx9CA3ftpT+/s7VJKKYQY2Guf4Vnv16eOmWb/FgEXCSGKs7YZwH8BHVN/n92OnQc4DlOfWQh8SwiRB/6b5k0umEpX+wRwxtR3SwJPz/rsxFTHAJ5/mjwy6/3GXt9h9m8QT/0GvQf4frcJIeJZ2yKaHY8Dtvcg31FRFEVRDuQe4A+nUpI7pJSbhRAjwH9ObVs1tY+676v7vqIckhrBVl6sdgGL99p2Gge/4cx2N80ntq+a+vt9wCXMTQ/fDXhAu5QyP/VPVkq5cj/HGwL6pl+I5l20bz/7HYjc6/Vu4O5Z581PpZH9ATBGM31t4az9+w944OYI/v+RUp5FM33tWuB9Qgib5tPxzwJdUso88BNAHOhYh2GmTaI5x70PGNzPfruBq/f6fo6Ucs+B2nsMbVIURVFe2h6kmYL8u8D9AFMjpINT2wanpo+p+/6RU/d95SVHBdjKi9W3gb8UQvRNFc54NfB6mqnehySl3EzzKex7ad7QyjSf0L6FqQBbSjlEM3Xqc0KI7NR5lgohLt/PIX8MnD1VDMQA/pDmPKbDNQIsmfX6R8AZQojfEkKYU/9cIIRYMfVk+Vbgk0KI5NT8rPcf6MBCiCuEEGdPpZiVaY7yx4AF2EzduKeeah/r8iOrhRDXTf0GH6bZUVm7n/2+BHx6KksAIUSHEOKNh2ivoiiKohwxKWUDeBT4CM3U8Gn3TW27Z2o/dd8/cuq+r7zkqABbebH6FPAAzZtjgWbxkPdMzUs6XHfTTIXaPeu1AB6btc/7aN6Qnpk6zy0051DNIaUcB9421Y4J4CyaN3PvMNvyeZpzkQpCiH+WUlZo3vTeSfNJ8DDwdzRvjAAfopmiNUyzIMt/HOTY3VPtLgMbp77nf02d44+B70x9t3cDPzjM9h7I7cA7po73W8B1B0jx+vzUue4UQlRo3owvOlh7j7FdiqIoykvb3TSLh903a9u9U9tmL8+l7vtHRt33lZccIeXeGSiKohxvU2lSAzSD/rtOdntOBCHEJ2kWKHnvyW6LoiiKopxI6r6vKC8dagRbUU4QIcRrhRD5qTlON9AcDd9fmpSiKIqiKC9w6r6vKC9NKsBWlBPnYmArME5zPvibpuZ9KYqiKIry4qPu+4ryEqRSxBVFURRFURRFURRlHqgRbEVRFEVRFEVRFEWZByrAVhRFURRFURRFUZR5YJyMk7a3t8vFixefjFMriqIoylFZt27duJSy42S348VG9QkURVGUF6ID9QtOSoC9ePFiHn300ZNxakVRFEU5KkKInSe7DS9Gqk+gKIqivBAdqF+gUsQVRVEURVEURVEUZR6oAFtRFEVRFEVRFEVR5sFJSRFXFEVRlPm2bmeBtdsmWLOkjdWLWk52cxRFURRFOQmklJQaAXU/Im0bZBPmCT2/CrAVRVGUF7x1Owu856tr8cMYy9C46fo1KshWFEVRlJegRhAxUfNJGBrjVQ/L0HBM/YSd/7BTxIUQ/y6EGBVCrJ+17R+EEM8KIZ4SQtwmhMgfl1YqiqIoykGs3TaBH8bEEoIwZu22iZPdpBc91S9QFEVRTkWxBAEYejPUlfLEnv9I5mB/HXjdXtt+DqySUr4MeA74i3lql6IoiqIctjVL2rAMDV2AaWisWdJ2spv0UvB1VL9AURRFOcUkTJ2EpVPzQ1K2gWOe2LJjh302KeU9wORe2+6UUoZTL9cCffPYNkVRFEU5LKsXtXDT9Wt454X9XHeeuhWdCKpfoCiKopyKdE3QnrYxdY0wivHC+ISefz7D+Q8AdxzoTSHE7wkhHhVCPDo2NjaPp1UURVGUpu89NsC3Ht7Fe766lnU7Cye7OS91B+wXqD6BoiiKcjyNVz2iWCKB4ZKLPIF54vMSYAshPg6EwE0H2kdK+WUp5flSyvM7Ojrm47SKoijKi8C6nQW+eNeWYw6I1TzsU8eh+gWqT6AoiqLsjxtEDBYbjFZcovjog+I4lmhCoAnBCZ6CfexVxIUQvw1cC1wpT+SjAUVRFOUFbz6rf0/Pww7CWM3DPolUv0BRFEU5GnEsGS67WJqg7kVoQHvGOapjtaVtRsouYSjpSFsIIea3sQdxTAG2EOJ1wP8GLpdS1uenSYqiKMqL3fSa1YPFxj6jzkcSYE8fpyVpUaj73HjtSn69aZSRssvPNwzPWRd73c4Ctz42gATecl6fWsbrOFD9AkVRFOVIxLGk2AgIopi0bSClRNc0dCmIjuARbRDFjJZdvDBCkyB0DUODYt1na8NnYT6BEAJdF4SRZLjkEkYRQghakibd+eRM1fFjddgBthDiZuBVQLsQYgD4BM3qoDbw86mnAmullB+cl5YpiqIoL0qzR60NTWDoGlF05KPO63YWeNdXmsfZnycHSgBYhsY1q7q5/cnBmaU6vvnQLj542RI+ds2KY/4+L1WqX6AoiqIcq2IjoNjwMTXBeBCRd0yKjQBD12hJWod1DCklQ8UGI6UGo1WPoaLLqgUZHto+iZAayzqTPLp9kqWdKSSwfbyOlJLtoxW8KKYnn+Sixa1cuKQN0zj29bIPO8CWUr5rP5u/dswtUBRFUV5SZs+VjmLJOy5cyIJ8YmakGTjoaPNnfrKRn24YBjhgcD2bH8Z8/4nBfbZ/6Z5t9LelePdF/TMj4bPboByc6hcoiqIox8r1Q3TA0jVqQUTaMUg7JoYuZtK63SBivOLiRzFZx6Q1ZaNpgrFyncd3F6l7EaausafUwDEE45U6/373KGU/YmFLkrrnY1smuaTFU7tLEEM9CBmt+HTnbKSUjFY9an5E3tBp+BFeGJGwdOyjCLiPeQ62oiiKohyudTsL7Ck25oxar+rNUaj7c/aZPTJ9y6O7+eQbVlGo+/x8wzBPTI1Mz4c71g+xvDszb/PAFUVRFEU5PJM1n7IXMl52aU3btKdtRioeQRSTS1i0pizcIGLXRI3RsotpCCpuyEjZBeCOp4epeh4agoRlYFsaxVrEnqLLcNklFhrDxRqhjHnFsjSNIMLUBTlHZ2S4TixjSvUATRNkHQPH1GeKrOmaYLLms7A1iXmEqeMqwFYURVFOiNmp4ZomWLUgx8VL2vjUjzbMCW7XbpsgmDUy7UeSv/z+0xxtMVFdcMB5XFev6tlv9XEVYCuKoijK8RNGMcWGj6VrIKDqBmQcgziGlGVQrPukbYMwlkgkjqVTcQM2DpYBjZrns7NQpSvtMFF10TTBeYtbGJissXuiRj5tMVx0qQqdaxe1sqQjQ8I0eM3KNFUvIp+0+I3lnYyWPTIJg9WnteOYOuWGj6Y119JuuFGzcKoKsBVFUZRT0exANo4kTw2U2DBYnlmncjq4XbOkDdPQ5qR/H01wLWjeID/1xlU8vH2C258YRE5t721J8PqzeyjUfVqSFpah4QfxVLGTw5vzpSiKoijK0RFCEIQx20arWIagNWlRagQkLYMgihECkBJDa45O192QR7dNMFiss7A9jRuEGAhGinUKXsjL+nK0pmwWZB0KVZfHdhWxdA1dgwhBa9JGCEHWsQgij1zKxjE0OnJJWpImFTdgrOKRtgyqbsh41cPUNFpTJkn7yEJmFWAriqIoJ0RL0poTKEsglhJdE0gpMY1mQZO12yb45OtX8vc/e5ZiPTjq80mahU82DJb46YZhhABdNAPuvdPCf/vixXz1vu3EUvKpH21geXdGjWIriqIoynEkpaTU8Kl5MZoQLOvKYE3Ngc7aJoNlFyklLUmL5wZLrNtdZLLi8vRgmf4Wh85Mgq5sigvzNu1Zh1ItIJFPsHJBKxuHaizpdIhi2D3Z4FWGwA0loxWXWEJ72qLuR/TmHRp+jBvEJC2dqh9i6xr9bUkcQ6fmR0gpj2iZLxVgK4qiKCfE7HnW0wxN8IFLTuPBbRNYhsYnf7iBMIoxdO2wCpgdimloSJgZORdICnV/n7TwDUNlYilVmriiKIqinABRLJms+kQRBGHERM1jmUxR8yR+FLFjokpH2iGfNJms+WweLZOzDGTKYqLmowtJOmlim4KRSoAfgR/VyaVM2jI2GVtHN3RaUxYCyXjVoyefxA+bo+OaEFgG2IaO6zf7G7EEJGQSBjU/wgtjHFM74jW0VYCtKIqinBD7S73uyDp87f7tBHtNkg7mIbjuSFtctbKblb05LENrzqOatRTY7G1Xr+rhkR2T++yjKIqiKMr80wT4cUQtCEk7Og0v4pmhKvmEAUJQc5sZbGEUkU1YpB2LehhQ9UJaEgZdmSS+F/LkrgLlRsjZCzJ05RLcvWmUha1JLj69nc2jVTIJi/aUgRvECAn5pMV41SOUks6MhRCCTMLEDWO8MKIja5OyDKpeQCwh45hH/N1UgK0oiqKcEPsbwd5TaOyzTQCGLvYJuo/UWNXn5od3YRkaN167kkLdn7MM13RBtelty7szaqkuRVEURTkBNCHozSWpehFBEDFRD6i7HnUvJIhjlnWkm0t1aRoJS+fipW10ZCyeGSiRTloEUcQzeyoYuoZmCLaPVfEjWNGbxfVjHEujJWWhIQkjwZ5iA00I+lqT9LcmAWZGpnVN0J1z5rQvmzj6eiwqwFYURVFOiDVL2prFxA4yOm1o8I4L+pHAzQ/t4thCbGZSvgt1nz+8Ytmc91YvapkTSO/9WlEURVGU40PTBIvaUoQyZttIhVLN47nBBpah0ZlPcGZPjmzCpDvrEEQxYxWXlGVxRm+e/rYkmwZLrJdlbFMnFYKhQ2umOS87BuII+ltTjJZ9Kn6AFBDJmLofARxx2vcRfbfjdmRFURRFmWX1ohZu/t01vLwvt9/3V3Rn+Pbvv4JPv/lsVvXmjjm4huZouEr5VhRFUZRTj23qnNGRpTXt0JtPkktZtKQTtKdtOlIm/S1JHFPH0ARRDJGUpC0DKaEl5fAbZ3bRljLpbbF4xdJ2LljcimkIdCGwTQ0pBZ2Z5nra2YRBLCGXMI9rcA1qBFtRFEU5wVYuyLF+anmu2RpBNDOCXKj7CDhkkC1oPgWPYokmQAiI4uff68ravOmcBWpkWlEURVFOIVJKJqoeFS8kjiWGLogAUxc4hk4+5aBpzUBYCEE2YZJLmGweqbJrokoQSxa2JulvX0zGNkFIEqZJIwjI2ib1IMKPIixDxwtiGn5IPmGSNI//+LIKsBVFUZQTYt3OwszSWJpgnwD6nIV5vnjXFtYsaWPNkjZs8/l0cimf39fQBVcs76QzYzNW8fj5MyNA83jvnEov3zJS4eEdBYbLHl+6Zxv9bSnefVH/Cfy2iqIoiqIciBfGlN2QtG3Q25KgZyot3NSgI+dQagSMll3StkHK0UnZBoPFBroGAkFXyqYRRixpT7OoPY2UkpoX8MxQgyCSpG2D7lSC9ozDUKHGxqpHI4goNEJW97dgGMcv0FYp4oqiKMoJMXtpLCnhXRf1c+HiFrqzNm86p5efbhjmc3du4h3/9iCbhivceO1KNCGa+886zm8s7+Qr7zuf687r49fPjc28JzSBBN5yXh+2qc859x3rh07U11QURVEU5RC0qTRtP4gAQWfW4ey+PGf25oljcP2Q0WqDp/cU2TVRxzEEacegJ++wc6LGfdvGGSy61PyIMIoJY8loxactbVN3IypugBDNDLeyF5E0NdrTDl4Q4YbR8f1ux/XoiqIoijJlzZI2DF1DALqu8Zbz+vjOB1/B2htezeldGbygGXyHseTG29ezfrBELPdNEv/lxhHW7SywdtsE4VQ+uABkLPnWw7t4z1fXsrInO+czV6/qOQHfUFEURVGUw+WHEXtKDQSQtp9PrJbARN1nYKJBxQuRQN2PsQ2NwckGSdsgYxvNe7+EIJLEUiIlpG0TXRdIKSg1AsYqHh1pBzeIGa95JG0dx9AP0KL5oVLEFUVRlBNnKmCO45hbHxsAmsXPWpLWnFHqKJYImmtVu8HcquORhH/6xXNcvapnZi3r6afUkmbV8LIXctVZXYyUXd5xQb9KD1cURVGUU0ip4ZO0DFqSFkMVl62jVdKOQUfGIWMbeH6EZepYuqBY9+nNJzA0wS5bo1IPKDUCio2A0zpSVLwAUxNkHJ2KF1IPQpZkU5imjhdGtKYslnWmEWh0Zu3jmh4OKsBWFEVRTpC12yYIp4LgMIZvPrSL7z02wE3Xr9mnqJmuCa47r4+VvTn+8eebGKvOXUP7vs3jPLJjkhuvXcmGwRKjFY+7nxsjimJ0TXDLugHCKMYyNJZ3Z070V1UURVEU5QCklPiRpOKFRLFGpRHSmbJpBM3U7qxj0tuapCeOqXghbSmbhKkzWHRJWxaWJVjZmsUydMr1gLGSy2ilQW9LglzCJueYbJ+o056x6Eg7DJUaICGUzcJnpgqwFUVRlBeDlqQ1Nae6GUZPjzav3TYxp6iZJgSfeuMqAD71ow37jGBPf9YNYm647enmXCcBhia4ckVz1PqpgdKc46sq4oqiKIpyaijUA2peiOuHVOoRhqbhhTFyavUsTRN0Z5uFzvJJm1zCpOoFBHFMLmEg0BkouJimRqnW4Cv3FBipeLQlLd5+wSKuXNmDVfGoewETwqPhR/S1JHGDCC+ISdnH9/upAFtRFEU5rtbtLHDrYwN899Hd+yzNpevNNapXL2rhpuvXzATbqxe18MW7tsxUET+YGJpPpiPJr54dbc7DollkRK2BrSiKoiinBikljSBipNwglzCxDZud4w3aMwYDxQaLWpNkHBMAx9RxZhUsNXUdGUsmGyELWh2CIKLqRdz3XIWdk3XSpk7FDfjJ+kHOWdRCxWtWKM85JqW6T7kRoOuCpH38w18VYCuKoijHzfTSXF4Q77OmtQDeurpvZnR59aKWOSPNa5a0YRnafj+7XwJiKYll89hn9+W48fUr1ei1oiiKopwCJms+xXpAw4uouiFJSyfl6LSnHRKmSXvGRp9a+3pvCUunO+cQxXV6c0mqXsDAZImEoWHpgrIb4Bg6GVtj21iNFT1Z4lgyXvXRhEZ7xiZlG5j68a/xraqIK4qiKMfN2m0TcwLk2bdNCTy+s8C6nYU5n1m3s8AX79rCpuEKl57ewdLONIdzP1zSnsKY2lECG4fK8/EVFEVRFEWZBzU/JIiay2o1vIj2tEXdC/jxUwPc99wwm4fL1P0QKSVRLJv7+SFj5QZbRysMFRsImktyVhsBCVvjjO4svfkkScvgtPYkFy1pp+5HdKRtIiRuENGaNqm4IcYBgvf5pkawFUVRlOOm0gjmjD6f2Z1h43Bl5vXG4Qpv+dcHWNyWpDNjs328tk9BMwBdY04RtP1ZtSDHhUvauPmhXUialcjV/GtFURRFOTX4QcyTA0USlk5P1mG45LKnUGfTYJWhSp1fbhxlYavD0o407ekENS9isuFiGRpCCHSh0ZFxaEmapB2LM5MWA5N18kkLLwh4dEeJ+7aMc3Zvnom6T9axyNgmhiao+xFSgjgBMbYKsBVFUZTjYt3OAl+9b/vMawE0gmi/++6YqLNjon7AY0WHnorN7U8M8uk3n41tNpfuUvOvFUVRFOXUEEbNImYLWhIgBZoQhLHEC2KCKKTmR9QbAWEcUvVj0laNmh8yVvFBClrTOqd1ZJvTwKQkl7RwLJ0wluSSJttGPbwooi3jMFkL2D1e44IlbYxUPPwopiVlTRVIjbB0De04jmarAFtRFEU5LtZum5hT1EwT0JmxDxpIHwsJfPuRXdx47UoKdX8muP7iXVtmCqcpiqIoinLiCSEwdY2OpM3msQq+pXNaWwpLg8GSS7nWzF4LI4lOzEgpwA0CGpHA90IqriBlGYyUGph6nkYYknYsOjM2Q0WXihehC4kAal7AeNVj50SNroxDytYRmsbuyTpRHGObOj25xAHnex8rFWAriqIox2zdzsKcCuDAzNJb03OwIwkP7ygc/EDH6MmBEk/veZqzF+SoNAK+/uAO/LC5HvZN169RQbaiKIqiHEfTlcJjCQlTnwli9amlt+7ZNMxjuybZM9FA1yBparQ4BnEY0AhjLFMHBK1pi7Eq4HnkkiZRLPFCSVvKoOoGjBZDTusSFGs+OyfqJCyN7lySQj2gNW2ia4JNQ2WeGijQmrDpzidwTJ1MwqTqhvhhTMLSD/pdjpYKsBVFUZRjMl0pfO9AdvWiFm68diU33Pb0CW1PLJuB9pMDpZl522o9bEVRFEU5/or1gMm6j6BZ+bsnlwCadVH8MGJP0cUxdISAwWKDjGMyUW3gx4LE1Kh2ZzZBPmExUKyxXUha0zZeJNCkZLgaUGyEWKaOMdlASokQGnUvIu9YdGUTBDLm/i0TtCQNKm7Emb0ZAinpzjmYhoYQHLfRa1ABtqIoinKM1m6bwA9jYjk3kF23s8Ad64dOatuEUOthK4qiKMqJ0ggiEqaGoWnU/JB4aqrYcKmB68cIYKToMlHzqfuSQHqEsSRtGiQtk9GKTxzDSMXD1gUvW9iCF8aU3IBiPWCi4tKTSxJHEWNVlyCI6G1JIKVJ3Q/pbUkQRzGTlSIlTxIEESNFD8fUWdWbJeWYJC0dyzh+i2mpAFtRFEU5JtPrVc8uLDZ7VPtkeuWydiRw9aoeNXqtKIqiKMdZytKZqPlATMo20DSBH8b4YUwmYfCKZR0YAnJJg5GiS6wJGp7PcMknZUNnyiQWYBmCRS1JwhjCKCSTENi6RhjFJG2N8YrHyxdm6M7aDJc9cgmTtJWk4AaYusZ5/a1M1j1qXkQmZZJLGGhCUPNCdAGOeXzSw0EF2IqiKMoxWr2ohZuuXzNnDvYX79oyM6qtCbANjUZw4oPtB7ZOEEvJIzsmAWaKn6lgW1EURVHmXy5pYZs6UoJjNkeJDU1gGhpVP0TXNJb35vBCSS2I0aRgWWeK8xZrZBM25bpPEEbsKTQYnKxzWmeWOI7ZPVFnQYuDpQuiSCKEThDFCAQ9OQdD15AS8gmLhW1JbF1jT6mBjCWdWQdL16n7MSnHYKTsUQ8iTE2QS1qY+vyOZqsAW1EURTlm03Oup+09qn1GV4YnB0rH7fyagCUdaUxN8NxoFSklmhDEUhLL5tqbN96+nlhKVfBMURRFUY6jvUeHtakCZ14YowmYrLlYlk5/PkUoQ1qTCVb0ZomimCf2FBguNyi6IcOlBn4YEMTg+RGeHxBr0J1LEkQBEzUPiWRhzmFXqbledkvKpFwPyTgGHSmbbMIinzSQUlBsBMRxzHjFJZIWCdOgEbj0tSQQ87hA9mGH60KIfxdCjAoh1s/a1iqE+LkQYvPUn6q3oiiKosyMan/kquXcdP0aLj7A/Of5qjESS9gyWmXjcAUpJa9e0cWn3rgKy9Ca5xDNAiuxBH9qnrhybFS/QFEURTlchq6Rsg0SloGGoFDzGau5RFLj9M4kbWmbhe0pXrGknUojRJMSXcZsHKqwY7xBEMXUg4gwhK1jVWqeT3vGoSVpkXAs2jM2DS+k7oZsGasyWXMZKNTZNFxiuOyRsAyEgK1jNYqNgHLdJ45jvCBCykO3/0gcyXj414HX7bXtY8AvpZSnA7+ceq0oiqIorF7Uwh9esYzVi1qoeOE+71+4uIV3XtjPotbkvJ43lvCrZ0dZ3p3hxmtXTo1kN6uJT7/fkrTm9ZwvUV9H9QsURVGUIyClJJM0OaMry1k9OZZ2pOhrS5O0mynce0ouGVun7IXUY0lL0qEtZVILYmJNo7clQWvaYmlnhq6UTUvKJpbgBTFeEJO0DXpyCapeTBCBYxm4QUSx7mMZGqe1p1jQkmDbeI1tY1WCOGIeB6+BI0gRl1LeI4RYvNfmNwKvmvr7fwK/Bv58PhqmKIqivHg8N1KZ89rUBY/uKBzVuti6JohjycEeOIex5M++8wSnd2WI93o0rdGci60cG9UvUBRFUY6UEIK0ZbJhqETDCxAaVBsBpqHTlU0wOOlyxZndtKULbBws05u3GSr75GWEbRrUvIiOlEVfa4qOnM2q3hyjZY8V3VnGqy5eGOOYBsWGR9UNEQi2jdSYrPq0pR3SjoYbRCxqS7GwJYk/ld2mz2OQfaxzsLuklNNrsAwDXQfaUQjxe8DvAfT39x/jaRVFUZQXkkJtbkAbREeXj6UL+N1XnsbXH9zxfBE1wDA0lraneHakMpPqtWOizo6JOoYGctYIttBQS3YdP4fVL1B9AkVRlJeuWMakTQ0/EDw3UqFeD+jMJ6l6IY0wZPukpLclydl9LfhRyK6CS3vaRkjww5CWlM3p3VmiWGLqOrmEST2IySYsUrZB1jGpeA6lesDO8SqZhE5LyqJc9+nOpal7IZM1ny2jVTqzzryviT1vRc6klFIIccAek5Tyy8CXAc4///x5znRXFEVRTmVLOtJsGasdcj9dg+6sw56iu9/3pYQNQ2V+++LFbBgqs7InSyZh0pK0+NSPNoAEAXNGtxe3p9k6Wp15HcXwXw/umFP1XJl/B+sXqD6BoijKS5eja8RSUHR9yo0QDQ2pucRRyGldOXQkSzszOJZJHMd0ZwN2TFYZr/ks60jRknIAwXjVQxcajSCgJWXTlkqStA0KNR8/jGlNWQyVdAQQxDGaLogkaELj8V2TSASdGZuso5F2LBKWMS/Ldx1rTfIRIUQPwNSfo8fcIkVRFOVF5/cvX3p4O0p41fLOfZ4m61rzhhUD920e50v3bOO+zeN8/cEdrFnSRqHevJnuL1I7rT21z/bvPzHIZ3+2iXd9ZS3rdh55mrpyQKpfoCiKohxUa8bhtPY0ltDpyVpkHI2qG1APJMOFBghBLmk3903Z1IKAuh/RkrTIJUyWtKcwNOjI2LhhxGQ9pNIIZ+q9CNHsL0SyORLenrGxDZ2zerP4QUy54TNc9ql6EbsLLvdtHuPpPSWe2l2g4e9bM+ZIHWuA/QPg/VN/fz9w+zEeT1EURXkRWr2ohQsXH3qkWAi47rw+VvVm52zvSNuc3ZdDE8+PTkuay2/90y+eoyVpYRkaugDb1HjTOb0sbkvywcuWcMXyzv2eS9KsKH7rYwPH9uWU2VS/QFEURTkoTRO8rD/HhUvaOL07z4ruLB0Zm2zSZKhY5/GdBR7fNQGaJJ+00ITGygV5VvRkqXoRlqmTT9qYmsZEzSNl6WSSBnU/ZLzi4voRjqERx9CXT7C0I0NPzsHUBI0g4pFtk4yVXBpeQKHqIdCagbwfUfOiY/5+h50iLoS4mWbhknYhxADwCeAzwHeEEP8D2Am8/ZhbpCiKopxU63YW5qRP7/36aP351St425ceID5IQnAYw6bhCu+4oJ8nB56e2T5S9pisBxiaIJwqSCJoPqG+d/M4D2yd4HdfedqcdHE/jPn3+7cftBgacMj3lf1T/QJFUZQXv6obUKgHWIZGe9omiGImaz6WLmhJ2XMyzoIoRhPisOY0t6VtVvXlyCZNGl5EEMPuQo2aG+FFEmdPBaRGd8qmK5tgx3gFCaRtk5GKi4wluYTJ6UaGqhcShJKRcoOHJ+tICUs7UrxsYbPPMjBZx9A1xio+E1WPJV1pAglJU9CWNsklbSqNAENo2Oaxjj8fWRXxdx3grSuPuRWKoijKKWHdzgLv+epa/DDGMjRuvHblTLBqGRo3Xb9mJsg+ksB73c4Ctz42wJUruhiYrDNQrFNx9/+U+DN3bGTNkjYuXNzCrsk6I2UPCURRzDsv7Kc3n6AlafHtR3bx5EAJaK5x/ZX7tvOd37+YtdsmZgqgBdH+q43rGsRxs5r5W87rO5af7CVL9QsURVFe3MIoZrTqkTB06n5IoSao+iGmJqh4EUII2tLNVO7xikvZDdE0QXfW2WcucxRLYikxdY1yI6DU8HAMjTO7sxiawNIEE3WP4aJHte7xbBSza7KKF4ZkEzZJq7mGtqVD2jIIopgglixoTeKFEUEYs2W0Qj5hITQYKXl43RGGroEAx9SpuCF+GGEY0JO3yDoWFy5uRTc0yo2QlqRJxjGP+XebtyJniqIoygvfnOA0jLlj/dCc12u3TcyMas8OxGcH3tOmA/CWpMVf3f40UXx4bSi7IXc+MwJMVQjXm8tymYbGdef1zZxn/WBpJsCG5s371scGuO68PixDIwhjhGiOis/2wcuW8JqV3arImaIoiqIchJz6H02AhiCWzSDZ0DSiWBBNLdsRRDEVNyJtG3hhRKkRzAmwvTBiuOQSx1BpeGwZrTJZ88k6Bl0tCVqSNkUvoK8lwa6xKtUgplKoEYaSlKVzeleahW0ZsgmDkhuSckKkhK5sM7i3DR1L10hYOrvHa9T8kHzKxg1i2mwT29CpeSGmLkiZOmUPOtIOFy5pIz8117s9PX+/mwqwFUVRlBlrlrTNBKemoXH1qh4e2TE583p6eau9A/HpwHvaNx/axY23r59Zg/pgaeEHEwMilrzzwv45wTXAW87r49uP7JoTuH/30d1cd14fN12/hrXbJhgsNrj54V0zKeXvvqifj12zAkAF1oqiKIpyEKau0ZayplLEBa2pZr2TQt3H0DTyCQsAXQg0DdwgIowlGef5NGspJTvHq4xXffJJs7kslx9RaoSs31Pk7L48i9ojilUPhMAydVpTBlUvptKo88SeAjsnaly+XJKyTDQdEqZO1jFIzRptLjcCWpM2xYSPZRqsXtRC1QvJJkx6cwmCOCYIYwxdsMQyqPsRhn7sFcP3RwXYiqIoyozVi1pmgtPp0d3l3Zl9Rnv3DsRnryu9bmeBG29fT3i0UfVeIgm/3DjCyt7cPkGxLgTRrCTwIJKs3TbBH16xbGak/XuPDcy08zqVDq4oiqIohy2XtMglrZnX+aRF1jERAoRozrXWNEF3zqHcCDB1jVzi+cDXDWLqQYyuaUxWA3Rdox74jFVdLENj41CJkVKDsVrAWQsytGUcMo7OrgmXUQQyhImaz2O7CvS0JIkjSdUN6cwkWN6dIWE1w9mqF5JPmizvzrJ+T5HdhTopS6c35+BHMVI208RtQ28G15qGYxz7fOv9UQG2oiiKMsfqRS1zAtm9X09v2zsQn7Z22wTRrOBaE0c/gj1tuOxxw21P8+tNo/z+5UtZvaiFtdsm9gniJVBpBHPmhx+onYqiKIqiHDltP0XMbEOnI7P/EWHH1Kh5IVU/5LxFebaO1Igj0HVJpRFh6FCbqDNe8poj07lmYbOxikcQR3hhTMOX3PXMKLmkwaq+FiarPjUvZPXiVhxTJ2UZ7Ck2GC27lBo+AkFZE+QSJn4ksYxm4D89mm1q2n6/x3xQAbaiKIpyVPYXeENzdNs0NPypyc8CWNSaZOdk/ZjPeeczI9yzeYybrl/DmiVt+w3ef7FxhK/et50olthms1CboiiKoignjh/GRLHE1AWxbM7h7s5auEHMOQtzGLrg2cESURwTxoKkYzJRczE1g1V9eSYrLmcvyLBjokE9CEmbglzSoD1tM1nzySVMbF0wWvbob0uia4KRcoORskcUw6KOJMMFl2eHyrSlbXJJi0LNJ2UZFBsBli5oTdnHJchWAbaiKIoyL2aPGr+8L8cjOwpAM8V7PoLraV7QXLt6ZW9unwJmANvGazNBtxvEc9LVV3Rn+Os3n61GshVFURTlOHGDiKFSAwBL1wijmJFynVLDJ2Ea5BIWhbrPsu4sewoN9hSqWEJS9WIyiRhDaKQdi7Rt0pOHshugmQb97SlaEjalRkBX1iGTMNlTrGPqgpofYhs6y3syPLpzkuFJFzTBgpzDQNFl12SdxW0pwljihiGbhip0Z23WLO0kYc3vXGwVYCuKoijHbHZRM10TBNH8rS6tAT0tCYZLLlHcXHbr24/uZmVPac5+aVunO5dg62h1zvbZaeQbhyu8/d8e4Du//woVZCuKoijKcVB1Q0r1oJka3gh4ZMcEYxWPnRNVMo5BXz6J60scU8fUBQnLYLRWQ5MQRh5P7inQmXZY1pVBm6jRm0twbn+O1Yvbsc1m+Fqquzw7UiVh6mwbr5GxNPwoQvPhjI40i1pTDJYb1IIIXYMl7SmCOObxnZPsnqwRS8H2sRqmqXP+orZ9lhU7FirAVhRFUY7J3kXN4nkMri89vZ2rV/VQqPvcvWmUh6dGxcNIYu9VnMQNIraOVve77vVsUcw+Vc8VRVEURTl6bhBNFRLTGCk32DhUJowiBibrFBseacsgbRlo6IzXAjrTFpN1n9aESWvKoi1pUG6EhBI6kzaeH1D1YvpaUrSkLPrbMziWQakREMUSTWg4hk5r2qZY90EYrOjOUPUj8gmLshvS35pmrOwigJRjMFxy6cjYPDNYpiNj49g6FTcgjCVhFONHMYamYR1j8TMVYCuKoihHZTolfE+xMW8Vw2czNLh6VQ+f+tEG/DCeqVY6zQtjPnjZEjYMlXFMnV9uHEHSnPMtDlJYTdea88Rnp7SrYFtRFEVRjlwURdSDmLGKBzSX0HL9iEWtSdZuG+epgQJRJKl6IbmUydL2LAJJytbJJQxakjalRkjVDdE1gaXpTDZ8cgkLAp+BQp2unEVr2p7pa6Qdg2LdJ5swmKx5jJZcEqaO0ByWdiTxo5hSo7nudUvKos9qbguimIxjMDBZp+KFCCnoyycxNMFg0SWIIwSCBS0JbOPoR7RVgK0oiqLsY3p5KwFz1p/+zE828tMNw5yzMM9PNwzjhzGGrqGL5lzrwyHgkKPMANe/cgmFuj+z3raGnHOep/eU2DRS4abr1wBw7+axmeW4brx2JRsGS3z30d3NiuYCcgmTZR1p/vzq5jrY7/nqWvwwRhOCT71xFe++qP/IfiRFURRFeQlwg4iK6xNGkHEM0o5JqeHz4NZxqm5Ee9rgjK4chq5R9UOySYNnhyoMFuvNJb6kpFhzsXWTshvwsv4sQQiuFxMkJGf2pHly1wRuEGIIjV2TNXpzMW4YkbMN0oZJX0uSihtQrMeIoLlG9xldWbaOVWlN2nRkbGpeSJiysA2dlpRFse7jmBptGRtdE6Rtg6of8oZzFhBJsAydbMLED2PCuDl6PVLy0ISgvzV51AXQVIB9CJ7n8Ytf/IK+vj5e/vKXn+zmKIqiHHfrdhZ4x5cfJJyKZL+7boCbf3cNP98wzJfu2QbAjonni5YFYcyZ3RmeHa4cVuDsj+8mKA7i9L8czXIOuN+D2ya48fUr56y3feO1K7lj/RD3bR4nls0qpdPrXu+9HNe6nYWZEe3ZDwkAPn7b07hBs0JaLCU33r6e5d0ZNZKtHNKTTz7J4OAgl19+Oclk8mQ3R1EU5bhyg4hnB0tsHa9iahqnd2Xob0uyYaBEtRHQkbHZNl6nJWmjaxqxhJQt6GtN0JVJUKgVm6Pbmk4+lUQgGS97VL2I9oxDrVFhy1PrKMosfqabOGqQcwwmdR0/jljQk2Co4lKpB6QTBl1ZGy+MyTgmjqmzsCXBjsk6I2WXXMJEn8p2a01ZtCTNmew3N4jQdUFX1sHUn08B98OY8ZrLWMWj5odYuqDh6RQbAa0pa7+/yaGoAPsQ3vjGN/Kzn/0MgNe+9rV89atfpa+v7yS3SlEU5fi59bGBmeAamjeff7t7Kw/vmNzv/pJm8bBDieolJu/8F+qb7gfA6l5G9/v+L0Lsf66TH8Zz1ttuSTarjq7syXLv5nGgmQbekmzeAGcvG7ZuZ2FmhNoyNK477/nr9rqdBb776O4554pjqeZlK4d0++2386Y3vQmAzs5OvvCFL/C2t73t5DZKURTlOKp7IZN1n4xjUm6EPDtcYrzmUaj6jNV8krZB0tbIpUwGCx5xHPPciEtXNkF/ewLb0NhTqFEOQvKOQdkLKbsRWUfnnh98mzv+43O4tQpC03nZ732WzMIVjOk6/R0AGtvGq/S3ptg2XmJ5d56EZZBNaJQbPhNVj2Ldw9QgjsHUxZxR59nB9WCxWdVcE4K+lgTGVJA9XHJBQDZhUHF9+tozaJqYWWr0aBzbDO4XuV27dvGzn/2Mj370o3z2s5/lvvvu47zzzmPt2rUnu2mKoijHzf5Goe98ZoRiPZizbVHr4Y/e+WM7GPrPD1Pf8hC5V76H/GXvwx/egj+y7YCf2Thc4TM/2cjabRNUGgF/9f2n+YefbeIr9z7/GQ0o1P19Prt228RMankwNco9+72954xbpsaaJW2H/X2Ul6avfvWrLF68mJ/85CcsXryYt7/97Xz0ox8ljo++I6YoinIqcywdXRf4QUyh5mJqOlEU05m1SRkae4p1VnRlsHSdKJJ4Uci6nZPcs2mY4aJHwtFZ2pnhnL48S7ozXLykg7QJ3/2nT3DbFz5Jz5IzedeN/4ruJBl8+CfYho4UzcDeNKDhSyp+xFN7Kjy4dYJdEzWeGy7z9O4So6UG20ZrhLEkm9L3u3QnNANsTQhStkEs5UwfQEpJJCWmrpF1TLqyCeK4Gaznk+ZR/2ZqBPsgBgYGALjyyit57WtfyzXXXMPrX/96rrzySn7wgx9w5ZVXnuQWKoqizL+3nNfHLY/uxj/EpOrkYa4baU1uZfc3P44wLLrf8w/YPafjDW+Be75BVBmH7mUH/OyXp4Lp2fHw7GZpUwXL9rZmSduc1PLZ+8x+T9cEbzt/4T4p5IqyPwMDA5x99tlcffXVvPrVr+bDH/4wn/3sZxkbG+NrX/sauj6/a6kqiqKcbEnL4OV9eYYKDbKOQXc+waahMjsn6tSCiHzCYEfBBU2j6Pk8N1QhZRnoGtQ8n4Tp0JKyKDQCFuWT7Jks8a3P/BlPP/grLnrzB7j+Tz7GhqESqY4+/PI4FT/ENnUiKYhjScrR6cokcUyBF4T4YczuyRo1P8KPYraOVHlk+yRpx+TiJW0saEnMSQEHcEydWPpUvQBT02feF0LQkbYYrXgIIVjWmWmOggtx1POvQQXYB2UYzZ/H95ujIytWrOD+++/n1a9+Nddeey0/+clPuOKKK05mExVFUebd6kUt3Px7F3PrYwN89yCB9rPDlUMWLPOGt7DrWx9Hd9J0vvPTmPnu5htRCIDQDh6QHKo4eRjDpuHKPsHx7NTyvauEH+w9RTkYwzBm+gSmafKFL3yBzs5OPvnJTyKl5D/+4z/QNJUcqCjKi0s+aZNP2vR7AVvHqhi6hmMKHMMiYRsgZXPkOYpx/QBDFwyM1xmrBfTkXE7vynJ2XwstCZ0P/+4f8vSDv+IPPvZpVl31VgQxO8arxGGAnciQdTR6sg7tmQRhFJM0DMqeR7khcJM63ogkjiWGrjFW8tA0QUvaIpcwqQQh4xWXnvzcDDvH1OlrSRLGMZauoc8KntOOSdJqxnzHElTPpgLsg+jt7QVgz549M9u6urr41a9+xRVXXMEb3vAGfvWrX3HBBRecrCYqiqIcF9PzmVf25vjL7z+930D3UAXNgondjH7nRjQ7Rde7/xYj2znzXlhtpmzr6dZjbusd64f2WwF89pzsI3lPUQ6kt7eXbduaWRVh1JyCcOONNyKE4BOf+ATZbJZ//ud/3mdJOUVRlBcD29CJpaTuhxSqPiU3JO0YLGhJMFRsEANJ22L9nkkmKh5ndKUp+xFPD5bRpOTWf/o4T9z/S67/87/mkje+hyj0eXDrJOV6gFeeoH3RcnoyCRA6QQymodHf5lBoRCxtT1Kq+5iaIJlsLs8ldEm1ERCagpRjYGmgH+Ahp2VoWAeYHT1fgfU0FWAfwLqdBR7YUsd2EmzatAl4fnma163s5s477+SVr3wl11xzDQ888ACnn376SW6xoijK/CvU/UOOIu9PWJlg5Ds3gtDoe8/fcNnqs3hg6/jM/Khwsvng0sj3HPYxcwmDUiPcZ/vVqw7/GIpypOJYMlb1qHkhfYuWcOedd1JzfZ4dLjNS8ujOO/zFDR+nXC7zuc99ju7ubj7+8Y+f7GYriqLMmzCK8cKYmhewdbRGteGxu1AnYZl0pm1yjomuSx7fMc5YxcfSmst57So0qPshvVmHb3z+06z98W185IZP8Ad/9EdMVhv8+wMjFOoBut/ArxTIdS0ETSObaFYIdzRBxY95bqRCue4TRBGL29OkQ4MoirENg9M7k1S8mNM70izrztCWtk/2z6UC7P2ZXX1WtC7kvofW8ZmfbJxZnqb55xJ+9rOf8YpXvIKrr76aBx98kI6OjpPbcEVRlHmybmdhprjY0ag+8VNit0rXu/6WRGsv924enzPi7Y9uR891odmHXyht7+A6YWr81bUr1frVynFVDyIqbkDaNlh8+nJc1+Vn969j7aRFFAu6sjb5pMXf//3fMzIywl/+5V+yaNEi3vve957spiuKohyzKJYMlVyCOGa42KBS93lmqEqh5tKVERS9iJQTs2WwzHMjNUZKDcIopjNrUWpE1P2AoV07ePzO73HWa95B5sI3c//WMXSpMVZqEMSSwuBWABYvW44A0pZGBEgd6l6ErWuMVT2CIKbml+nMOrhBhAQModGTd1i9uI1s4ugLk80nFWDvx+zqs1bXMtY/eTe1B+ZWuv33+7fzsWuu4Yc//CFXXHEFb3rTm/jlL3+J4xx4TVdFUZQXgjkPGY/yGLlXvovUWZdjtvXR2E9ZT29wE3bPGcfUzvdfvFgF18oJI4FVLz8PgB//4h4KvWsIIsmTuyOWtqfob03xta99jT179vCBD3yA/v5+LrvsspPbaEVRlGMURDFhFJO2DYSEiBhkhKnrFBoB+niZhhvy9O5JvKD5UD4Gym5I1tYxdIuytoAL/+RfWbZsCeMVj+GiR933GS3VMCzBxI71ALQsWoFuatR9ialH9LVlaYQRKVPDl2JmKa7OjM1AoUHS0hDoJA6z6OqJoipx7Md0hVldQLL/LNx6jdLg3ADbjySf+clG1qxZwze+8Q0eeOABrr/+eqQ8ilxKRVGUU8itjw3gBc2HjIcoJH5AQmiYbX37fS8sjxOVR7EXnHkMrYTMKfKkWnlxS1k6uYSJF8RccO7LyOfzbNvwGLGAWEiCMGbdrgJP7i4ghc5/3/wdTjttCddddx1bt2492c1XFEU5al4YMVbxGK/6FOsBqYRJV9oh49gEUtLwA8arAVXPB01jtORS90M0JJqug24QScg7BtnehYxWQ8bKLjvGywyV6hRdSbURM7ntGezWHgqByUihwe5CGV2Aaen05BO0pGwWZB3O7suRSzQLmrWkTDRN0JoyOK09hW2eOmHtqdOSU8h0hdmPXLWcr/75+wAIB9bvs9+X793Gup0Fllz4aq79nQ9z00038Xd/93cnurmKoijzZt3OAt99dPdMOrd+gLtEVC8d9QNFd/fTANgLVx10v4ONnhsHWJ5LUeabEIKOjMPi9hTtGYdLL72UbU8/QkfGRsaQS5poMUxWPYaKDerC5l++8S3iWPKGN7yBcrl8TOePY0kQxeoBvqIoJ9xYxWumbDs6FTegLWkh0RiteCxuTZGzNJ4aKHL/+m0MFxokzGZ6tx9Ba8Kk1TEwBCzIJmizddoTGlU/pFT3qNYDwhAKtYjKjqdx+laxcaTBrskGXhijaTp1LySftLn8zE7WLOvg/NNaecO5fby8v5VLlnbwiqUdnNaZYVlHBts4dUaxVYr4AcyuMHv66aejFZ7F5Q1z9oklfO+xgeZoT+eVZM56nBtuuIFVq1bRc/YlR7wEzPScx5akRaHuq+VjFEU5KtPXkukA9EiuRWu3TRDMGraO983uprzuh4SFQfKXvhes5hzqI6ma7O54As3JYHWedsB9LEPjk69fSaHuU2kE/HTDMDsm6s1zAe+4oF9dH5WT4sorr+SHP/wh155m8UDSIpYxubRFJGPcICKbtFi14kz+39f+i/e/7Q285z3v4aZv30KMIGkZWMaBxzaiWOKHMYYuqHkh41WPUiMglzBJ2QZdGWfeq90qivLiFcWSittM2844Jm4QMlENMA1BR9rGONBT9CluEDFSchmrerSkTB7f1aAjY7JyQY4ndhd4drjCtntuJSwM0fOq9yCNJEJEgGDHeJXWtIMUglgTuKEAIdCASGqU6zEh4I5sJfZqpBa9HMOApC0Iw5gndhcYKtdoSRic1ZuhO5sAIbAMjaoXEklJl2VQ9UIc69QKaU+t1pyirrrqKr727//BkqsiatHzT0eEaHb0/DBGImi7+o/IB+O8813vouO9n0Xk+7AMjZuuX3PAjuDsoPpTP9qAF8RIQBMc8rOKoiizrdtZ4HuPDXDLugHCKMbQmjezMIoP+3pSaQRzipHtPWbmj2yl+vhP6Hjzx9Hs1D6fl1IeNNiWUuLueBxn8TkHXQN7UUuCO9YPcfWqHv7wimW8ZmU37/nqWoIwxjQ0rjtv/+nninK8XXXVVQBseOQ+fuO1b2G85lNtRHTnEnTlEkzUPACuvPIKPv/5z/OhD32Ij93wl3z0Lz9BsR7QmbXRNYFt6PhhTBjHMyMvg8UGYRxPjViDIaDhhWRsk4Yf4YXxKTfXUFGUU08cS0qNgMFSHVPTMQxB3Y/wwgjH0PGDmELdpyNz8NpRw6UGd20cJZKSVyxpBwFhDGd0JVm7ZYxgeBuFdT9hwVs/TkCK2lQtUgvwg5iEERBJGK+46EJS8WIMHTIJjYYfo0mo7ngMgPwZ52DqUHclfhRgGREZR2ek7PH07iKyD9rSDpahkTB1HEOj5oekbAPnFEoPBxVgH5bTV1+K+8UvMr7lSRKnnTezXROwsjeHZWjNTl8iyef+82be9ZtXsOc7/x/d7/u/CCfFrY8NzATRGwZLSOAtU53D6UJCmhDEUs50ZmMJQRizdtuECrAVRTmk6cJk0w/pgKmR6OZ15XCuJ+t2FvjyvdsO+D6AP7qDxNILMNv68IY2U9vwK4RuIkyb3Jq3I4yDz4sORrcRVSdJLFl90P02j9XYPFbj3s3j7Jqo8bFrVnDT9WuOODNIUeZb/5JlLFjYz12/+Bnnv/at9GYd2nockrZBLmFgmxpSQsLU+YM/+APuf+hR/vXzn+Wcc8/l3MtehxdG6JqAWFINQrK2hWVqtCYtwigmZRtUGgFVPySbMIlopogbut78nKIoyiEU6j6TNZ/RoksmYbKgJUEjiJqDg1P/HEoQxTw7VKG/zWFP0eXxPZNcs7KHoVKDZwfLVMOA0Z3PkVx6AUG2j+pefYL2NW+n6oOUUAsCujMmUdSscdXwY9I26KbBwPZ15BaewcolC7E1ycaRKlEMXhhTqAcMFusMlhtUvYh8yuKVp3eQsg16cgliySl5XVQB9mGIu89CGDaNrY/MCbDjGDYMluZ0+gAu/4O/5Uef+QPGf/hZOt7yV9z00K59jnnLo7t52/kLZ6qVI2Uz7UtKYprBu2loao6hoiiHZe22iTnBNTQrbSIEURQf9HqybmeBWx8bYP2e0iHXvDbyXbi7ngKg9MDNOAvPRk+34g5soPzIbeQufvtBR7HrWx4GxCED7Nm+fO82XrOye87UHUU5WfxQcsWrX8v3vn0TN+ox+XSaitcMiNvTFvlUcw3WKIrZU2zwx3/5adY98RR//Ae/ywf+7r9YueIs2tI2bghpR6O/NUVHNkEsJUIT1P0QoQkWtCSo+RF9+SRJyyCftA6aXq4oijKt3PBZv6fA9rEaXhBx+fJOVi7Io2uCiZqPoQnySeuAny3WAwxNkHMMyq6PkNCetKj6EWMVl6f2FChXffRsJ/5zTwL79gkmHrmNxOVvpyOhUfNjyo0IXZPIEOpuM/4RlQkqA8/yqrf/Hp0Zh0o9QEZQ95uFwqLQY+dEjYVtabaMVmhJ2yxocVjR0/wu+qkXWwMqwD6kdTsLjLmQPO0cGpvXYl71+yAEQdQcFfr2o7vZPFLBC2Me2jbBvVvGkbKXlt+4nsIv/o3S/TeTf+V79jnu9OdnRr8NjRuvbc43VHOwFUU5UnundgN84JLTeM3K7oOO+n7zoV381e3riQ4VWU+xe8+k/PBtDP/3/8bqXkr2wjcj4wjNSdPY+jBw8PnYjS0PYfcuR08d/rUtlqhsHuWki2LJWMWlWA94xRVX8d//8RU2P76WCy69kpRtkLB0nt5Tau4bxTSCkImqT6ER8KY//Xu+8JF38N9//Sdc+bGvkMvkOP+0VrqyaXYX6rRnbMJYknNMdA0sQ8cxddpP8ndWFOWFyY8ko2UfQ9PAgMm6jxCgaYLefAJzP3OvpZSMV102j1QJI4kXRnTnHUZ31xmvueyaqHDn0wNYhsFQocZEPSJoP5PQPXCfoOJDxZ8q5tKIaS7gBUmtWUR1bP3DICXWaecxUqpTrPlEU3uZBiSs5n6TVR8vjHEsvTn/Opan5Mj1NBVgH8TstWBTZ6yhtvkhLm2tYnQu5efPjCCBMJI8vKOwz2cz512LP7yF0v03Y3WfTnLZhfvss2FPid++eDGZhKmCaUVRjsmGoX0rFf9i48hBry/rdha48TCCa2/PRoLiMEIIUme9irZr/oTyw7dRfuT7JJZeSOK0cwlLI4SViYMeJyyP4g9vIf+q3z6i76aLZsXw2cXb1PVSOdEqbkDdj8gnTS659HIymSw//fEPuPjyVyMR3PPsKMMVFz8Iac/YbBmps6A1Qd0LcPIdXPWhv+X2v/kDHv6PT/Gbf/L3FOo+jqnTlrGpeSE7J+poGqzoypBN7H9kSVEU5XBkHYPWpMlENUIIjclqwNMDJXpyCTRdsGCvIDueeoC4u1BnsFinXPcZq/qkTI0ggm2jNYo1j8l6QHnXRmoTw+hCYB9ln8CNQcRQ2fwgZraDEaOXarmOoxvoUzWuDA3yCZslbWnQDRa02LSnHUAgpWSwWCeW0J62ccxTqzaFCrAP4t/u3oobNJ+0OEsvBKHxvVtvpfWy9x3ys0IIWq/6nwRjOxj/0efoef8/Yrb0Nt+jWTjoyYESTw6U+Js3nw3AF+/aojqOiqIclatX9XDv5vE527aM1fjszzZhm88XOJtdWPGO9UOEhwiu/fFdjN7yKbJr3kJ90/14e54lsWQ1mXOvxsi0M/mLL5FYcj7u9sfpeOuNBz1WfdODACRPv/iIvtuVK7qA52tWqAKQyokmpWSy6jFcdunOOmimyYWXXclPf/wjrvzAx9EtjW3jdTpyNrvGamQdE8MQZCzBYCEkZUH36S9j9dv/mEe/9Y8889P/4lV/+ucs6UyTT1jc9ewoSdvAlDBQdEnaJmU3JGHq5JPmEVXpVxRF6co6nNPfwrpdBcIwpr8thRfH6Jognl6tQBOUGwFeGFHzm9XCa17AZM1n3Y5J8gmDsRhGK3XGKy5lN6QwuIvB73yK/EVvobTpfuyj7BMIQPPq1Lc/Tv7cq/EigRNJbKu59KFhhGiaRlvO4aLTu/CjiJaEhWlqnNGRZrzqEU6NYo+UXRa17Vt09WSalwBbCPGnwPU048angd+RUrrzcewTbbrzuXmkwp3PjMxs1xI5nP6zqW+6n/ylv3VYNzvNtOl48w0Mff3DjN36abp/63Ocu6SLiarHQPH5n+df7trMcNkjllJ1HBVFOSrvvqgfgM/euYnJmj+zfXaBM2gGqdMPDg9HY8vDZC54I7mL3kp29Rsor/sB9c1rcQKPzHm/SeL0NcjAJXvBmzGyB09orW+6D7NjMWbrgsM+v64JlrSn+KdfPDdTs0IVgDz1vVj6BQ0/wg0i9hTrPLm7SBDEDBZdutMG3S9/FZUf38ZjD9/HuWsupSNjk9AEMVBwA3pzCdKWTX+bZGFLkiiWXLzkD/lSYQcPfO8rvP9Nr6Hr/DczVKpT90MiGeOHkmzSZONwmYShYxs6hi5I24YKshVFOWyapnF6d5aefJLdkzUcS2fXRI1GEKLpAjeI8MOI0YrHZNVnx0SVFsdg63gDL/IxNZ0gkoxXfYYKPglLp1gLqWx5mNz5zT5BevUbKB1Bn2B6gDEJtKU1Rrc/gowCkssvIYohjEDqGgbQ35GmL5NgQVuKi5e1kTB1Jiou9anq524YNx9KCo1QykOuYHKiHXO1DCHEAuCPgfOllKsAHXjnsR73RPvmQ7t44xfu4x1ffpDP3bmJ7z8xOOd9CSTPfCXh5B6C0e2HfVwj10X7Gz5KML6L4p1f4O3nL2SwOLePMVB0CWNJLJtLfn3vsQG+eNcW1u3cN/VcURTlQN59UT9fed/5OKbG9NQkjWa10Ds3DPOlu7fiHUFwDWC09OBue4xgYgBhWOQueiv2wlVUHvsRjZ1PYmTaMFsXHDK4DsvjeHs2kjrz0kOeszVl8sHLlmBMPWn/0j3buHfzOLFUBSBfCF4M/YLxisfdm0b4xTNDDBYbPLm7gKkLunIOUjaXmjnvksuxEkkevesndGUtMo5OI4q56qwuzuvLs7g1QXvO5pyFLbSnHSZqAQlb5+/+8fOcufJsPvqHv8sTGzZS9kK6sg62obGo1cHRdbwgwg0jCnWPPYU6O8ZrFGY9OFMURTkcKVunK5dAE4Ll3Tna0hajJZcNgyUe31WkVPMRSOpuyHjdp+oHLMwl6M/bPDtYptYIyCU1bEPDsSHd0oO7/THkxAAcRp9A0FyyKykgIZp/zyYgBEaevBcj3YqzYAW6Bo6ugYSF+QSdKYdcymZ5d5ZcwiJlmyA0EqbBnpI7VV3cpeFHdGSsUyq4hvlLETeAhBAioPlgYvAQ+58y1u0s8G93b50zWn0gyTNeweSd/0rt2XuwupYc9jne9eZr2Zou8KN//0f+6fOfRy66cs770090pv8+vYatGs1WFOVIrV7UMrOyQUvS4q5No/z8mRGeHCgBpSM+Xmr5JfhDz9HY8TgAZlsf6ZVXQBRSf/Y+nP6zEeLQz2rrz94LNB9UHsr/uupMCnV/ztKF0HxYcMmydj786jPUdfHU94LrF8wuYrZuxyQxMdvHqtTckELNZaIWoAnBWT1Z0kmLsxd1cu4lV7LhoV+RNgTJliQDRRfL0Cg0AvxQUg9ibFNjT7HBonya9rRFS8bkm9/6NpdfcjHve/c7+eYPfk5PLkGxEbC8K0OpESA0i3IjxA0i2tMOSUunUPdJ2YaqJK4oymETQtCasmhNNes6jJRdGn6IYxiUGz6TlYiJeogfxkTElOsuj1VdxmoBsYjxYwijiCiIqbvgLL8EY+g5ijseJ8Gh+wQS8IFQNp+0moAnIapUKG95lNbV12AKjUYME40Y04ppS1skbIu0o9OXTyCQ1LwQP4oJo2aae0vSwjF1enIOCevUm/F8zFdpKeUe4LPALmAIKEkp7zzW454I00XMDhVcn9PXXOtaT+ZwFp9DbeO9SHl4FXcBnthd5B3X/zHpM9bw6Hf+mcbAM1OT9wUfvGwJtqmh0Szkc1pHmjBqpkF6QXM0W1EU5UisXtTCH16xjHdf1I8bREf8eXf3eipP3EHx/puJagVSK3+DYGwHtWfvpbH1UQCEYRJVJ2g+Fjy02sa7sbqWHjI9fHrJjT3FBpomZo6uAZapqeD6BeCF2i8oNQIafoQQksmGhw4kTJOy6xNLnXMW5jmzJ8OCXJLlXRnySYtr3/w2quUi69beS2umOQr91O4iW0YrFGse4zWP1rRFW8qht8Wm7AdsG60jM538v3/7Gls2PcNf/fmfYhk6K3pytGUcWtM2jmmQS5rkkybjNZdC3WOo1GC07BJGR5aFoiiKAhBGMQ0/YPt4g6cGioyW64xWfRqux7PDRR7ZNsFEJaBQ9ylW6ujARNVl94YneO7XP2Tw3ptxj7JPkLWeD7Y9DyY2PICMQtIrLiOmObrdmmiuGjJU8enK2bz6rB5SjsmOiTqDxQa7JmvsLtTxvAhDExiahmWcWsXNps1HingL8EbgNKAXSAkh3ruf/X5PCPGoEOLRsbGxYz3tvFi7bQI/PPSNSgI3/+4arjqri8xZlxOVRvD2bDzs8+yYqPOJHz5Dy9Ufxsh1Mn77Z1jdKfj271/Mx65ZwU3Xr+GdF/Wj6xpbR6sz69BKmqPZKlVcUZSjsW5nAe8IA+ywMs7ET/6J2K0SezUG/+OPCIvDZNe8DYSgtvFuhv7rzyg/8n1yr3jnYaVlBZN78Ie3kDrr8kPuK4Ebb1/PzQ/tIpxaztDUBe+8qF9l9LxAHE6/4FTsE0xL2ga9uQRlNyTtGJy9IM/i9iQJw0DGAi+OSFoGbWmbt7zhN8m3tPCT79/CZMWj6gYgBAOFOhNVj4e3TvCrjaMQR+wq1BgqNsgkDXaMVVm15nL+5KMf46e3fZuf3vJftEyNMOWTFgtbEtiGTnvaxtAE28bqdGUcolhSqKtUcUVRjkzV9dk5UWPneINYRtS8gCjSqHoBxXpAGMbsHK8yWKixeazKtok6uwse4yNjPPe9/0upXiU6yj6BAZT9Zlp4CLgSys/cjZXvJtu3HNMAx4BYEwgEKVMgNJ180sKfmtbmRzFJy2BJR4oFbUl68wl6884pu1TXfIypvxrYLqUcAxBC3Aq8Avjv2TtJKb8MfBng/PPPP/zh3+NozZI2LEPDD2IOFmavHyyzabhCI4hInnEx4mf/Qu2ZX+P0nXXY54piiZFI0/XmGxj8xv9i+7c/zcv+6NdAc7Rp7bYJwihGMjdlPIpUMR9FUY7ch7/1OLc/MbjP2tiH0tjyMPaCFeTWvA2AxNILKPzyKySXX0L+kncB4I9sQ9hJzHz3YR2ztuHXgCC54rJD7ithn8rm8dTr6UJt6np4yjtkv+BU7BPkEiZuEOKHksvP6CSOJW4Q4tgmS2PJ/ZtHsXQNXdN4Yvc4QupYpsa1b7yO7337Zhzh44YR+bRJX5hgtOLRlrJIWhqP7i6yuj+PjCVVN8CPYyZrPh/58xvY+NTjfORPP8wF56/moosuAsDQNTQh0DWNtpRDGIFpCFw/wrZUiriiKIfPD2Me3lag0vBphCGVuk/KNmmEAWMVl2eGSnhuSMUPKEVQ8iGMm/HI5HMPYx1jnyDcuz2VCao7nqJ9zdvoyNq0JE1MXVB1Yxa0JjitM4sfxGwZrWBqgqRtEkRxc+3rqaJmkzUfTRN0ZpxTctrMfATYu4A1Qogk0ACuBB6dh+Med3vPVSzU/Zk/f/zUIM8MVYBmcHzDbU83P2QlSZ5+EfVn72Xh1R+kHh3+v9RYwsUXnMfLzv48N/7pB/mLv/gLPvvZzwJzg/3pheCjqFkRryWp1sNUFOXwrNtZ4O/u2MjDO44u88VZ9HK8oecISyPo2U4Si16O9c5PM/qdGyEKyV/2W0dUg0JKSe2Zu3AWvRwjc/BCaM39577WRLOSuKpN8YLyguwX6JqgN5/c73tVN6Q145C2de7aNIYXRvS3pGhP21x57Vv4769/jZ/+6If0rH4NuhC0Z2yKdZ/2tIPnR9S8iIobsGWkxljZoyVl4eg6UsI3vvFfXHThBbz1rW/lscceo6OjAyEEHRmLoZJL1QsxNcn2sRqWqZO0jVOuYq6iKCdPHEu0A4zkDhXrrNs1yaY9RTIJi8GCh2EIFmVMHtle5emBSQaLPl4wVRSVZhr39MDjfPcJAOob7wYZY6+8grFKM1A+qztH2omIpUDIGCk1xqseUSQxdZclHRk6sxa2qRO4IZomiGWzbsaClv1ft0+m+ZiD/RBwC/AYzaU4NKaeSr8QzJ6rOPvPc/oP3Hnru+C1xI0K4xvXHvH5HtlR4CnnbN72vuv53Oc+xy233DLTjhuvXYmmCSTNTqmY+o/nkz/cwA23Pa1SxRVFYd3OwgFXGVi3s8DbvvTAUQfXAHoyBwjKD9+G9Bsz29pf/1HiqddHwtvzDGFxmNSqK47ocwJ40zm9/NlVy3nb+QtnalPMXnIMDv57KCfHC71fsD+6JmhP21QaEaW6j6Np7C7WGa00OOf8C+la0M/3vv1N9kzUeGLHOBNll7N6c8SxZKjkkU8Y7JxwQYBlCbaN19g0XOY7D+9iRxm+9PX/ZmxsjHe9611EUXNaR9IysHSNnGMSSkFLymJZR5qqGzJR8/DCI6+voCjKC1c4NYo7TU4FmDsnawwW63PeA9g1VuGmtTv51TMj3L9tgj2FBr4M6W9JUvMkuyYr1P0IL4CIZlA9O7iG+e8TAFTX/wqr5wyMtj4qAYxNemwaKpO3NPKOzmjJxTQEMoZiw6ceSIr1gKSp09eSxDF1hACB2CdLzw/jZkG0w5gCfDzNy5i6lPITUsozpZSrpJS/JaX05uO4J9NbzuvD0vf/NKjRtQo91UJ1/a+O6tiP7CjwcOdv0nraSt7/27/Dxo0bWbezwB3rh4imlusKY2b+7ocxNz+0i/d8da3qRCrKS8T+Asfpwoyfu3PTfq8HH/n2E8RHmWw7XbhRc9K0vuaDRPUSY7d/Bm9wE1GtQDCxG3fXU8jwyOZ/1p7+JcJ0SJ7xiiNrD/Cjp4ZYs6SN687raxaa3GuJrkP9HsrJ82LrFzimRk8+SS5pkLEtsimLaiNgtOwxVvFZdvHr2PzkQ2zYup1CPWKiHuBFku6cTVvKRAOiMCQkZvNojaFSncFijYmax32bR9getfNnn/wMv/zlL/n4xz/ORNVj03CZp3YXGSjWKdQ8dk9UGSjUGCw2mKx6DBYbBKrgmaK8JBTrPrsLdXYX6jPFS70wpuKGpCwDN4ypecGcz2yeqFJ3Q1rTNlIK9hTrmJpOw48o1j3qbkS5HuHRTOOWPJ/Ofbz6BP7INoKxHaRX/UbzPEAxhucmXO7cNMbjAwXWDxZ5aucEzwyV8cKYXMIk7ehUveb37szYzRotEtrT9vPHDmP2FOqMlF32FOonNcg+9eqanyJWL2rh5t+7mD+++TH27LVutdB0UiuvoPzo7UT10tTTnSOkmyRe91GK//knXPLqa8i/8x+IzcQBd5eAG8T8291b+fL7zj/y8ymKcspbt7MwM2XlUz/agB/OTYmeLsw4eyR3OlV63c4COyfrR3xOGUfIKEQzn79JaaZNxxv/nNLaWyg/ejsA4eQeWl/zQYRx+FNW4sCl9uy9JJdfgmYd+Pp2IFEs+d5jA/zNm8+emc4zHVx/8a4t7Ck25vwetz42MLOPSiFX5tP0UjdZp4WSG7B5pMoZHSkiJE/vKbH04tdx/y1fZsM9P6L/Ve8mjiUL8kmqwKL2JF4Iw6U6RDBUrlLxfLaPVtGI8IEohq7My7jsN9/G3/3d35FcsJwLrrgKKSO2jLiU3IDFrSk27Cki0BivuGi6oOYFnNmVwzgF5yAqinL0pJRUvRApJQnToFDzSVo6Ydycf9ybTzA9SySMY2QM2qxpI3EsMYVGyQ0YLbkkDJ2XLciiGwZ+HFGoNwsyzh4C9jm+fQKA6tO/AN0guWJu0dMYqPlQ9yNyCZNqGNKN5GV9OQSCmh+RsALcICJpGyyynw9hG35EI4iI42Zh1JRtUPNC6n5I3W9mIKVt44ROq1EB9kGsXtTCP7/rPN7+pQeI9hoVSq26kvLDt1LbcBfZC950VMc3su10vOHPGfn2X+L9+J9of+PHDvkv/85nRvjMTzbysWtWHNU5FUU5NU2PxvphjCYE0dSNYnYgPV2rIQhjdL25tu70qO0//eK5ozpv4ZdfISwOY3UtxWjtJXn6xWh2cz5Tbs1bCcvjCNMmdiuYLb1HdOz6cw8i/Qbps688qrZJ4NuP7AaaWUV/eMWyOb+ToWsYWvO30nWN7z66mzCWap62ctwYusbqRW1kHIs4llQaPg9tn+T0M5bSuuRlDD16J52XvB0/NBgo1klaBramYZqCnGNTDz1qXkS14VMPoO5DAFhAFDfovfR36HvuGT7zF3/Mx//tFlLt/TSCmCCSNNIRbhDTmbV4drCMY+pEsaThx5yzsAUhmu1TFOWFb7LmU2oEIMAxQgxdwwtjIilJTwWXtqHTlrIYq3i4YcR4FWIp8cOYYj3A1ODSM9oYKjbwY0nGsZioNBgs+WzYU2Sy1iCK5hZXPp59AhkF1J75NcllF6EnMvu8rwGWpuFHkLBsMo6JF4IXBAyVXJZ1pGn4Ed1Zh85sgqRt4AYRg8U6hq41l1kEalPj8SNlt1lkWtPozNrkT2BNK3UlPojp0aT/701n856L+uekjFsdi7B6zqD61M+PaE3svTmLXkb+8vdT33Q/5YdvO6zPfOmebXzzoV1q3qGivIjMHp2OpUTXxD4p0dO1GlYtyBHHMd98aBdv/dIDvPVfH+DezeNHfM7Cr79OWBym5Yr/geakCcZ2Urz/mwQTu2f2EZaDnshg5HuO+PjVp36Oke/GXrjqiD87LYol33xoF2/71wf45kO75vxOYRSzoifLOy/s562r+whnTauZPU9bUY5VNDVqNFnzSVoGL+vL05136MzaLGxJEgaS/jXX4E0OEQ8/R29Lkq60TX9bkkYUo9FMa/QC0DWouc3gejodc3pt2NAwufT3Pg2awT/f8CGe3T1GoeYSRSGbhitYQsPUNISAmJhyI2DXRJVNI2V2FeoUamoJL0V5MWj4EY6pkbYMvCCmM2PjWDpZx6Q11RxdllJSDyJKDZ/xikelEfLIjkme2DHBc0MlNgwWGSq55BI2GcciikKGi3VGSjUqboDnQyhnBdfHuU9Q3/IwcaNM+uzXzGzTeH60N21C3jFY3JpgWUeaM7rTlGouOyfqbB0pc+sTu/nXuzbzpV9v4TvrdjNeaQbQAoGhCUoNnzCWZGyDloTFcKlBxWuu613z9q5lfnypEewDmD1KYhkaN167Es6HzSOVmQJC6Ze9hsmffRF/6Dns3uVHfa7shdfhD26iePfXsXuW4fS/7JCf+cvbnm4WS1EjNYryojB7dNqcuuYU6v6cdOd1Owt88gfr8Wel1BzD8z2E5ZA+53WY7Qsx8t34I1tp7Hic6oZfk7/knbg7nyJ2K6TOetURp1YFhSG8XU+Ru/S9CHHsz3Jjmte9v37z2c0VF6aC7KcGSmwaqfDbFy+emX8eS9TqC8q8Gqu4NPzm/D83CNGEIAglQtNYmHfYOVnjzIt+g/Xf+2f2PHwHfSvO5bxFeWq+JGkZoAmCOMI2dDrTFuW6R1iPMWTzv9cAMAwgDMl0dfPW//UZ/vv//E/u/8+/5c1/+hkiAUlDkE6aCCBp6gyWGyA0EJI4jkmZOoWGT8Yx1Ei2orzAZRyDiVqz5FjWMbBNnU5Tn7NPEEmGiw22jFTYNloliEIsQ+fereNUagF+HHFGZ4pXLOsikjG7JhvsKfnsmKjhehHVcG5Bs4P1CVoueSeNY+gTAFSfuhM93YZz2rkImpk7tgGa1lwxpCtjEQC5lE1n2kETgt0TdZ4bqzFWrFPyAlrTNv0tKfaM19hTaLC8O4Oha1M1KeTUXO2QnGOStk2iWFIPQhzzxF4TVYB9ALNHSfww5sbb1xPvtSxGasVlFH75VapP3XlMAbYQgrZrPoz/jY8wdvvf0/P+f8LIHnw5mxhgP/MwFUV5YZq9bOCB5hB/77GBOcH1sTLz3RTv+QZaIovTdxb2gjNB0yk9cDPuwDMYLT1oiaO7tlWf/jkIjfSqV89be2OgUPe56fo1/NMvnuO+zeMzafQbhsozaW7a1H6KMl/cIMaxmp3bUi3ANDXStsFYxeWezRMMlGrUMVl04ZXsfOhOFuc+TqURMlLxmKwHOKagLe1weleKIIqoeSGm7lIPJMZUtgqaIBaCMA655NIrGPutD/HT//xnlpx1DmuufS+OJZis+5QEjFQ90pZJW9LCMjQMQ8ePYnQh5szDnM6wU0t6KcoLSy5p4VjNpfzs/dRYaKaBe2wdrbBrssrOiQqjVZ/enM14qRlsejE8O1Sh4jaXAE5aJpoGcRxhiOa9cnaAfbA+QeMY+wRheQx3++Pk1rwNoTWvpboG2aQGQuCHMRUvAi+g3EhS9QK0KqQSJlHUTI3POhZBGFF2I/Kp5u9jGToLWhJIJFEssYxmETfL1MklTMJY0pKyyCVO7EN39YjzANYsacPQBILm3ITpit7xrBK9mp0ieeal1Dbec9Sl6p8/VpLON38cGXqM3f63yDA46P6Gxj7po6CWrFGUF7LpZQP3F1x/86Fd/GrjyLyeL3XWq8icdy3VJ39G7dn7ALB7TiexbA2NrY9gtvSiO+kjPq6MI2pP/4LEktWHfFh4JAyNmYcPH371Gdjm85XFr17VM/PaMudeFxXlWLUkTUq1gN0TdTTRnJ5Q9QLiWFINAjpSJm0pk67zryb0PTbe91MaQcxkzccPQmQMhhAkbZMLTmvn5Qtb6W1N05NNYBo6YRwDgqRlkrIs2jMm/+uj/5uLXnUVP/rqP7Br46OMFBo8u6fI+oECbhBSdQO2T1RZ0ZujI2VhGTpdWWdmPVw3iNg1WWfnRH2f6sKKopz6bEOfWpJq7gMyL4h4YleBjYNldk7U2TxSZrDUoO5G7J6oEcSSugeNAKoeTNQbDBRqrN9TZKTYIGEaSLFvEHi8+gQwVdxMxqRe1kwPl0DGgtPasyzIORBL6n5EEAKxnBq9F2QcnXzapj3ncHpHioWtKRa32rzqzC4WtaXQNIGuCXpyCUxdww1iOjI2acegrzVJb95hcXvqhGf1qBHsgxECSbPQkCaaaVya1qz2OS398tdSW/8LahvvIfPy1+73MK1Jk8n6oW9uZvtC2q75MOPf/1smf/lvtL32Q/vfTxP8j1eeRiZh7pM+OjutXaWOK8qLw4e/9Tjff2LwuBw7tfIKhGHR2Poo7s4nyF74FmpP3UniCJfVmq2x9RGi6iTpq/7nvLVTE/CpN549c03b34j/8u6MqiKuzKs4lggBacdE01wcQ6fkBgRRhBAaMo4Jo4jN43U0YNHyVWxceDqP3XkrHedfQ82NqPohk9WA/tYEmpSkbINcxsIYFyxsTxKOlqiiE8RQ8yLyCYNCLWRw0udVv3sj27c8xzf+5iNc98n/JNXaSbnusbQrQ2faYaTqMV7x0IRgZW8WTXu+EzlR9TC05oj2WKU5d1yNZCvKC99IqcHuiRoJW2f7ZJXnRmvUg5hWWydlmSxrSzFQrNMIm0tZNbyYlG2iiZCJmj9TfTyhg4hg9hqKx6NPIOOI6pN34iw+FzPfPbNd0wQyjmkEkEtbIAWOobG4I0VHJkFPzmGgUGNxa4qX9+YIpWRpR5qlXRmkBGtW0GwZGgtbk8hZ2caO2XxAcTKoEexZZo/+rt02QTgVSccSoqkiAHsvOWkvOBOzvZ/qEz894HEPJ7iellp+CdmL3kL1iZ9SferO/e4TxJIv3bNtZo7h7DbvvYSPoigvbN98aNdxC64BNCtBauUVZC+6DqSk9MC3sPvOInfRdUd9zMoTd6CnW0ksveCY2zcdDkgJ6wdLc97be8T/YBkAinKkKm7Ajona1AhwSN2LqIch1UbAs4NlClWXrWNVsrZBb9bBNHQWtTq84pq3M7ZzE4PPPU0cRyQtHT9uFiwaqfhsGq3Qk0vgBhHbxqsMVXz8IEJDknI0ErbJWNXHsXWKocW7P/aP+G6DH3/+z3HrDVrTNova0qQck/7WFB1Zm8maT82PqHkBewp1RivuVJ9FEkuJJoQKrhXlRaDqBjw3UmGo1OChHZMUKh6dKR1HE8QIHMeiuzXN8p48Z3Sm6G9N4Jg6oYQokvgh2DrEUbPuQ09OZ/YimsejT9DY/hhRZYz0rIFIAaAJJqoeugZp08TWNRa2pTm3r5VzF7XS354i41gs605zWleac/vzLOlsjqAXGz51f9/CZafKdU6NYE/ZX1EzTQjiQ1QQEkKQfvnrKPzyy3jDW7C7lx1zW/KXvQ9/eAsTd/4rZsdp2D2n73e/f79vGzsm6sRSzrR5dpEklSKpKC98d6wfOu7nEJqO1d5P2+v+CBmFCP3obw1haQR322PkXvGOmXlWh7KoNclQ2cUPn3+CKYBXn9XF3ZtG8aNmJtEt6wZ4y3l9KoBWjrs4loxVPBJWcymsitucd132AiZqHptHKgyVXQYLdbwwojVpMVZzaUsnecVVr+fOr3+OgYd/wrIFZ5AyIa8bPDdaxTYMan5IFEUIKQiDCC2CRiyJ/BDT0JiouYxVfFqSJgMTNfLJNi75nRu4+0t/xUPf+Txv/aMbWdiSwDY06oGcKr4mmag2P9eetfDdmISlg6YRS0lXThX9U5QXg4YfYekaK3pzyIEiEwmTyYYgYUhySYO+fJIVvWl2TTbYMValO+ewfbxOzQ/QSTBWrVHzAS0ibekkHQu90pgzGXs++wQA1SfuQEvlSZ5+0cw2B2gEMY4pabV0UraJrgnedn4/KxfkqTR8tk/UsU2N1oSJY+osakszUm4wUvHIOCYjZRfT0LCNkzNKfTBqBHvK3qO/hbrPp964Cu0wHoSkV/0GwrCpPnHHvLRFaDrtb/jf6KkWxm77G6JaEYD2zNwb5Nax2pxlaaaL/3zkquUqPVxRXiSuXnXkS2Eci2O9kVae/BkIQfrlVx32Z95+wUI++fqV6LMuuKYuWNqeojX1/HUvilRmjnJiCAG6JggjSRDF6JpGV8amEUQMTDaAmOFSg8lqgz2FOjsma2QcC4ng9L52Vl12Ddse+gWFQonhYoNiLeC5sQqjlQZuELF5pEzF86j6Eb6EMARLgyAO2TxUIZ/QmKwGZB2Nui9ZcM7lLH/Nu3j2rtvYfv+PMQ2DM7qynL0gi0BiGYJCPSCMYiarPtpU+7tzDr35xCnZAVUU5cjZhsauQo2Hto5TdAPO6s1w7qIWzlqQZ/Vp7aQdi8laQMIyWdGbJ2WbdGZszl7QyhvPW8DSzhxL2lKsWpCjK+ewuDVJe0bjQFeIY+0ThOUxGlsfJX32axC6CTSDzwZQ8qDu+oyXXZKWztKODK1pi3zSpDufpDNj05lxCGOJpmv4YciOyTpD5QbpqaKT4TwWfp1PKsCeMr1Ejiaao9ItSYvl3RkOJ9FAc9IkV1xK7Zm7ib3avLRHT+boePMNxI0SYz/4O2QcUawFXHZ6+/Mpk7PbIMTMvEOVIqkoLx7LuzPzfqGWMqaxbd08HxVkFFB96s6p4madh/25J3YX2TBYmlNE8pyFeb50zzaGy8/PDtM1oTJzlBNCCEFXxqbqhZQbAY4hqPkRSzvSLGpLkU1YxEg6Mg7d+TSWprOsI03C0inWQlZcfh2h71JcfxcL8mkiCYbQ8EOJpWv4AZiaRtrWSdsaSRt0oUGk0YiiZuEzITFMHV2DSEoue+cf0LPiAr71+f/DY+septgIaIQxbhBTrIfUvJCUY1B3m2mT+b2WqpPHsqafoiinhBiwdZ2unEN31iGdcHjZgjzLujKctzDP0s4k7RmHnpwDyOaSVSkLQ9OouBGrF7eytDsLiOaTRE3gmCYMbyJqVOa9vdUnfgpSkjnndTPbZscxkw0ouj6L21OgwbaRCg9uGWPrWJndk3W2jleZqPr0ZGyeHCjhhxEjRZd7N49RagRY+qmREr43FWBP2TRcoa8lCbI5Z+lTP9rA9x4b4HBvR5lzr0EGLtX1v5q3Ntndy2h97Yfwdj1N4a5/J4wl920Z36dNmoDrX3nanKBaVRNXlBeHtdsmiA+92xEpr72F0e9+gsb2x+f1uPXnHiSuFcmc+5tH9LmfPzPCTQ/tmnNt2zaxn4eVp8jcKuXFr1j3eWDrGM/sKWIImKwFNPyQIJQsaHVY0pnlZb05FralubA/y8K2BBnbwtQ0TNPk3NXn0rnkLHY/8ENsXZB2TDoyCbwwxvMjEo5BwrbIJmw6sw5ndGXRdY2k3awY/PC2CUYLDUbLLjU/QBOQTdhc/sFPkmvr5NN/+rs8u203QkLCMsgmLIJI4vohuYSJENqcKW5VN2DnRJ2BQn3OVAxFUV5YDE0QSUmMRDM0DB3Spk5/W4KqH+OYOobQmawGlBsRbWmT3nyCMA4xNEEsmysaZB2DbMKgVA/Ra6Ps/M4nmLjj8/Pa1pmH7kvPx8h1zWyP5nwfsHQTQ9eRMewuujw7XOXJXUV2TdToyyVIWDqhlFS8EM+PMXWN1rRNxjGRhzUUeuKpOdg0iwjdcNvTc7YFYTxTmfNQ87AB7J4zsLpPp/r4HWTOu3beJtmnV/0G/vAWKo/ejtW9jPTKK2be0wQz675+/cEdvGZlN6sXtahq4oryItKSPPjcyY60xVj18Nd8bmx9hOI9/0XyrMtxFp9zjK2bq/LYjzFyXThLzjvmY8X7iQGmU8TV9Uw5nuI4Zv1AiUItYKTi4kcxC1uTZB2T+tSSW+cvbiVpapTdED+WOLpg53iD1qyB50uK2wNec917uemzN6CNP8f5K1fTlnKoez5S02n4ATU3IGFbJAQEMsaydCpuRKkWYtuQT2qEQUwsDAYLLrvGK0QRrPnAJ7nzcx/iL/7oA3z1m7djWM20y/58klhAxjEII8lkzacnl5iZT27qgkYQMlHx6GlJHPxHUBTllGQaGgtakuwu1LE0QbkRsrvk0t+aIJuw2TlRZnehQsYxWdqZZqhcZ6TiMVR2IZZM1kNCGVP3I8aqPqFX54F/+yuQ0HLF/5jXttafW0tUK5A+95oD7uNYzQKlC3IJxqoeuq4RxZK6F7BltIoXxXTnHHQhOL0zw9axMinLoK8lMbVKwrw2ed6oEWz2X0RICPj1c2NEUymLguZTozed08ulp7fzwcuW8O6L+rEMbebZSea83ySY2IW36+l5/WFbrvgAdv/ZTP70/+ENbwGaqZJXruhCwj5Vw2fPJ/eCmFsfG5jH1iiKcjys21ng47c9zQ23PT2TebJuZ4FP/WjDQZ/PHklwHUzuYeyHn8XsPI221/3RvFbb9Md24A1sIH3uNQhx7FfA/pYE5qzULw1U8UblhIhj8KMYQxNkbBM/jGkEEa1pi5akjRfG6JrAi2BZV5azenL0taZZtTBHPuHQlrE5uzfL617/ZtLZPM/88hbee/FpXLa8i6WdeRKmRsLQac8kOaMzjdQ1yl7IWV1pwjgkkhBGMF4OGK9FJCxBuRFQboAnYcTu5cy3foTnnnyUGz72v9AFvLwvz+KOFKau4YURDT+a0/H045g9xQajJY+JujdnOoaiKKeect3nmcEiG/YUqbrN1YiiWLJ9rErC1GlLWYxWmw8ASw2PHWM1tk2Ucf3m/7fv2zzOhsEipapLGEQYmuDZ0RpVN6DqhsQxhEHIkzf/A9XhHXS8/qOYLfNb86XyePOhe2LJ6n3e02jGVv15m/MXt7J6UZ5F7cnmqgtBzEjZI23rlOsBExUP09BZ3J7i3P42LlraRl9Lkp584oSvb3241Ag2zSJC924en7NtZW+Op/eUZlIWJc35S6d3ZfitJW2s3TZBf1uK8YrHtvEaW0arJM+8lMKvvkblsR+x6vyL2T5RnwnQj4XQDTre+DGG/vPDjN36aXre/3+5cOUSOjP2zPrckYTNI825E2uWtGHoGn4YI4HvPrqb61TlXUU5Za3bWeBdX1k7k7r57Ud2ceWZXYyUXdxgftI5Y6/O2K1/jdB0Oq/7OJrpzMtxp1Ue+xHCsEi/7DWHtX8uYVBq7LvExrSn9pT46zedTaHu05K0/n/2zjtMjurK22/lzmFyUM4ZIRFEBpMxIHJ0XmyztnftXe/aGDDGGHD257DetbHX60QGkcEkk1EWSRII5dFo8kzn7sr3+6NnBo1mFDDCBqnf5+GBmbp9u2roqr7nnnN+P1JFu+JvXeHvgqrKzGiMsHhjioAmM742Qk3YoGT7ZIo2kYBKUFcpWC6O59OVtSjZLl3ZEgFdpSqsUxvVaU1rHHvmhfzlzt/R19nB6NGj2NonMacpgeX6dOZLbO+zGFcVpuQ4bO7LoyoyId3HdEH4oBtguWWLUEUBTZLImQJp7FHUHnEeLz50Oz+YMJWbvv5l4iEDVYEX3+6hZHvMaIoR0cvnGtVVerMm8aCGoSo4vo+xlyr/FSpU+Ptiuz7rO7O0Z01yJY/V21KMqgqjaxLrOnKEdY2unElv1iJdcijYHrGAwuTaMGnTY2NnDuH79ORKFCyXgKaDcNmeLpIwVBwfNNlnzeO30vnacySO/xTGCEHwe7qG7i1Y21aTOP7TI266K4AiQ86CNW05xtZEGFcdZVTSpzNTorugkyq4qDLUx4PkTJdxNeH+FpgPaNp6ByoBNnDZ4WNo6S1wywubyr0JssQRE6pZ15nDdnx8yuXYmiqTDOlc/tslWI6/k8gYoBlEDjqF7LL72LB5C9MnTeDNjn0jGFAWPbuGzlu/Rvf932Np4EaWb0kNOYf7X22jLV3i66dPZ+6oOMu2lLNgni8qZZUVKnyAGag6GcDz4Ym1nftsfiF8eh7+MU7fduovvnFIL9S+wDfzFNY8Q2j6cSjB2F69ZnfBNZR9r1NFmy+e8N6tDytUeLc0JMKcMlOnJVXEFxAyFMKGSk3EIF20yZUcIgEVIcB2PXoKJttSJWQZevM6kiTxZluausPOwr/9t/zPr37FV75xHQFFQdFkdFmiPhYkXXRwPY+ortItSUQUhaLv4/hQHYKiA47joqqgSmA7gmL/oyJ09CeJdG5m0X/dSNO4KVzy0RNZ056hJ2eRDGos2dyLpsrURgOk8w7tWYtUyWVSbRjlQ7BArVDhQMUXgpJTrqJRFImWPpPOgoWuyPRkbWJBjb6ihfBcQFC0HTwh2NpXIFcSdORKKLJEwQZNVrAkh6zpIjyPvCkouB5da5ay/ek/Epl+LLHDzt/n15BbOfKmuwIEVcqZS6X8L1mF1lSR0dUR4pqG70NV2GB9T55CyWVsVYiwUd4Q/DAE11ApER/kqjOmc+M5s/sFAAS/X7yF686cyVdPncrN587mksPGcOzkWu5c3jIsuIZyFhkYFPdJrXyUtzrfCa73xR/aaJhE9elfxmpdQ99Tt4wowLZsS4qLb1k8GFwDyP3KuxXhswoVPpgsmFA9xKJqX5N58TZKG5aSPPGzBMbO2efz519/EuFYxOaf+a5ep0jlf3RVHqYEamiVcvAK/1hKjkdILwfVOdNDUxSSYZ2GeABJgr6Czba+PKbtki2WPaxVWUIgYdoOddEQx8yfyawFJ3DPbX+gszeDIkukizYF08ZxfTRZYkN7nrzpoasSLgJNKy9C00VwXOjNuwQVCV1V0HWJAOU1hSwrNJ/3NcJV9fz+hn9l5ZqNZAs+qqqQKrlYjk9IK1eKdBVMRiUCyKIsCaQqMiXbY3uqSEfGxPUqwmcVKnxQMFSZ5mSAguPTmzfJ5G0UoGi59BbMcszh+aiqiuX4SEImpsGWlENAkxECerIWvQWPnGXjuC6u6+H40FNy6G5tYcuiH6LVjSd5+r/u86DV6990D884fnDTXQU0ys+22jAEDYl4UCUSUChYDhFdpSdvkTVdmpJBJtZHOX5KHWcd3MzEuig1EWOfnuP7TSWDvQOpoo0vxBAv7C+eMImVW1Nc/9CavVLeVON1BCcfTv61x4kfeQmyVv5A7KuvrvCM47C7NpFdei963XiiIwgH7OwJN6Ox/OHeWfgMypmzStllhQr/WOaPTfKdhbO49v432NetkYW3XiTz8h2EZ59EdN67C4D3BuF75F55BGPUDPT6ie/qtZIscdEhozl/3iig/DyqlINX+KCgKjKO5yBLZSGdgT2wvOXieT5506XouFT1B91Z06E9XcJ1PEzHR5Z9enMmM0++iNWL/8qTj9zHQScsRJElArpMvuiyubdIMKBSyrm4XrkX2/MhqIPpljegbA+yRUFVxCeiypQ0DxyQFaiJJ5nxrz/g4Zs+y83/cQXf/tUdeLJOe9FhZnP5u79oOViOSzKkEQ1pGLqK7ws6cya6LGE5HqmiTSKkkzddVEUiGtD+gX/5ChUObCRJYnxtjHhQZ832NML32dJXRBZgqArRgIyqGOBLNMZDrGvPIEkQkB2mN8XI2w66AgHdYFtvDssFVQVJeFjZPC13fwdJ1ag779p93i4GkH/tCYRrEZ1/FjrlYNPQIRzSCCgSk+oi+EIiGdapCmk0JMPMH19NyfGoDuskwuXY6cP8HKoE2Dsw4IXtuP6gmM7KrSl++tTbw4Lrgd7nkYjNP4vOtxdTWPsc0YNO2efnmTj2E9jdW+h76tdo1aMJjJm92/EXHzpmiPCZ45aFz+5d1YrllMVablg4i8sOH7PPz7VChQp7x2WHj2F1W4bblrbsszntzk30Pvr/MJqmUX3KF9+X0qrSxhW46Q4Sx33qXb/W9wXNieBgIF0JqCt8kIgFNIQA0/WojQYRQFu6SHu6yPZ0CceDhniAgKbQUB3A9jw0WaYrZxIPawS0KEXX46CLzmTJbVO490//y/RjziYeVmnpKVCwHdIFG0kWSKpM2FCIGQo9BQ+Zcn+icEGTwfahryiIByFiQDRRtqiZUBMhoDdQ840f8JtvfYnffO8aLvmP7zKzOcq4mjCW41MdMfBF2Q+3NhqgMR4c1JWRZRnJFziuz1sdWfoKZdHE6Q1R6mIVpfEKFf6RGJpKQzyE7flkTY8JtRGKtktnxsTxfVxPoJYcYkEVJBCSzPacSVMyhKFKdOZthJDImya+UHAcj9YHf4ib7qD+kptQ43X7/JyF75Fb9TDG6FlE6sYTD4DrQ0BXCGgq46qCHDSmitqIgaGqjKsJ4wvIWQ6+kFA+oKJl75b94yr2EfPHJrn1igX8+ylTBzO8l/92CS/uJICmynDjObM5aFR8xHmM0bPRaseRW/kgYi8svt4tkqxQe/bXUBONdN//XZx0x/AxwLjqEDefO5vLDh8zuHmg9PeSCxgsdXd9wTcfWD2oXlwpJa9Q4R/D+fNGoav75rHsFdJ0LfoOshGh9txrkNT3Zyc4t/IBlGgNoSlH7NX4YyfXDP63L/ZsQ1ahwj8KWS5nWBrjQQKaQk/ewnZ9TMen6PjkSjbpooWhKoR0FdP2MW0PSZLIlVwMVWFUMszkuhgfvfgzbHprNYtffpGurEnWdOhIF3E8QUfaJKop6DJIskpIg2hApiasoWsQCkiEAuWMdX0swIymJONqIoyriTKzOUEkoDL3qJM58WNfYsXTD/HM3f9LJKBSsDw0rZyNrosGmFwXZXpjvLwWkCVqIkZ/pl0ibGhkig41Yb1sOdZbJGc6WI5HV86kO1cpI69Q4e+NKkv0FRxyJRfH95EkiYihoqkyhqahStCdL2G5AiFkJtVFmFwdYlJdmFjQwPU8LMehYEK64LHtyd9T2LSSqpOvJDB61vtyzub6pXjZLhoXnE08AJYDJReSQYW6iMb80TGOn1zPkZPqOW1OE5MbYsgS9OVtTNvF2EdroH80lQz2TswfmxzMovzymQ2DStw78pFp9YPZptdaM8PmkCSJ2CFn0/vYzzFbXic49qB9fp6yEabu/G/S8cd/p3vRjTRc/gNkI1Tu4xQCRZE5clINUxuig9d16xULBkvCAe5avg23Pw3v+YLbl7Zwz4ptIEm4XsVDu0KFfwT+SAbQ7xLhOnTffzN+MUv95d9Hibw/97DdtRlz6+skjvsU0m4UiU+ZUU9t1OC8eaNYsqmXF9b3lIVNKLfmVKjwYcD3BXL/96OKzJT6KLYHNWEd4fvkTQdDl+nOu7zVkWNCbRhXCCQEF1x8CX/+5fdZ/NCfOPHYY1GQSJkuNZpAkjxSeZPugoXvgUdZ0CwRVkEuW4H6wkcIiId0AoZKSFOoCqvUx4PYjkcwoHLSpVeSat3II//3U9TqMZx/7tlMrovRnTPpK9nUhYME9XIpOJTLLyOGitR/TUFdobdokyk6VIV0unMWecshESyPd32Lxnglq12hwt8LAVSFNXwRJKCpFF0XQ1GYVhfmuQ099ORMNEVFlVwKtktHziSiyyRCOtGgSsH0KTmQdyHzxtP0LrmPxLyPEpt7+og6TvuC/IoH0OP11Ew/nGhQJxIobzomIgHGVIWZP76O+mSIqnD5uVKybcIBjbp4kLzl4PoCHbBcDyHK/egfFmGzHdk/tgneJwayviP9b/3cH1fw3LquXfrThmccjxyKk1vxwPt2flpVMzULr8LpaaHn4R8jhM/sphiXHDYGhOCOZS1c/tslI2ai549NcsPCWaiyNHgNAnC8cqnYjqXklWx2hQp/H5Zs6uW9JomEEPQ+8d9YrWupPuPLGA3vnwp3dsUDSKpB5KBTdzvuoNEJbjp3NvPHJlkwoRpDK1fT6DsImVUqZyp80KmJGkA5I5wIq4QMlYa4QWu6yK3LWnjo9Xa29uYxPZ/5Y5JMboiRtzwMVabXhBMWXsKaxX9l27atNFcFmVkfoS1VoDVl0pErB7Y+LrYDJcsnpGuMrw0TM2Rihs5hE6twBP02YAFURWdLTwHLFUypjTB7VJyP/+dN1I6dwmP/dS2rV6+mYNn9ZZlhEiGVVNFBCIHt+qSLFumig+V6qIrMzKY4Y5Jh6qMGTckAjufTl7fQVRlNkXE9QcFy6Myag768FSpUeP9QZQlNlXFdgcAnHlAJ6QqqplCwXGRZpiGukyq5+J4ABJbrsz1VZF1nFiF8AjoUWt+k9/FfEBw7h9hHPvu+ZVetjg0UW9cQnncmlqugqRAwVJoSERoSQerjQaqjBqbjYbkeAAFNRQLyloMiyeiKTN502J4qsT1VpCtnvU9n+/5SyWDvhoGs76JVrdy5YhueJ1AUiaff6tqjv7Wk6kTnnkHm5Ttw+rajVTW/L+cYnnAwyRM/S+qpX5N+7o9sOPWfOHlmA67/jljbkk29rOvIcd0Dq/F8gaGVM9NTG6JcdOhoenIWf13XheuJsmr+Dlnwu1eUs9yVbHaFCu8/CyZUoykStve37y3nlt9H4Y0niR95CeHpx+7DsxuKV0hRWPsskTmnoASjuxwnS7A9XWLl1tRghdCO1TTzxyZZuTU1KMKoKjIXzB/F+fNGVZ43FT5QGKrCmOoQTYkAPXmLguUS0hWWbuohV3LQFInuvMvMxgi6qpAu2oR0lVhAoyNrcfDJF/Hwn3/DM/f9kcv++xdsD2v89a1OQOB4LqmCQCgC34P6hEze8okGJQ4ancATEq4o308lyyOgygRUcHyJLT15MqbDxNoIk5tr+dwN/8N//dvF/PabX+DYOU9x6IxxZWE2vxzsZ0oObZkS3VmLuqhBQFcYnQxRtD18UV6gL9nYS0fWJKLJBDWVuniAWEClM1u2Cuq0XFRFJqBVvLQrVHi/UBWZ5mSIrmyJJZuzyDKEdRVZlplQG2drXwFdVWlKBHBcgez6eJLA9T1aekpkiyY9bV1033cTaqyW6oXfQFJU3q/tsdzy+5H0IJGDTsH0wHcFoZCGrkpMrIlQEw2QNz18IWE6HnVRg0hAY1RVCMfz0RUZVZHJ5qz+tlaJ9nQJIQTJsI6hfnieN5UM9h6YPzbJTefO5s7PHcF/nDqViw8ZPWJwrUgMy2ZH550BikJ2xQOEtPfnT+0LiM47k8jc08guvYfOlU+wvjM3qHo64N39zQdW4/rlANp2fO5d1cqlv1nC7UtbeGZdF94OC3rPF5w4vZ4L5o8aFqhXqFDh/WP+2CTXnz1rl5Uxe6K4YRmpZ/6P0NSjiB992T49t53JrXoUPJfY/LN3O04IhlXTzB+b5IsnTBoMoHcUYbRdn9uX7rr6pkKFfzSqItMQDzKhtrxglCSlXNrtCRzPZ1RViEkNUSbWRplSH2FrXxFfCOJ1jRxywuk8+9BdrN7UxiutWapCOpbr05Xz8AHfG1iYKaiyR3U4gCRL2J5Ha7qA4/pkijat6RKdOYts0UbXwLY9Vm7tZVNXjrQIcdTnbyLb18PX//kTFAolVAkihkZtxKAjY9KXszDtclmp75eD6q5sCdcTeL7P9nSRqpCBqipkig5jqsIE9XJORldlJAksx3tfdGYqVKhQxnQ8ipZHquhQHTEYWxUmb7sIH8ZWB6gKa0yuj3LouGpmNEapiWoUTAfXA1VR0IRN6703IDyHuvOv2+1m+HvFzfVQeOsFIrNPRjbC2EDJFTQnQoyuCVMdMRgdD1KwHWQZhC/ImS4AmiIT0lUUudyuElDLNoIbu3O0p02yJZv2tIm/r21W3kcqGey9ZCDzsnJrituXtQxREJcl+OwxE1jTnqU6rPPAq23lTHA4SXjG8RTeeJrcMR8b9ILb10iSRNVJV+Km2uj5yy+4I9FAYNRMZAkOG1fFjx5/a8imgAA2dOYGldGdEbJlJcfj/HmjWLSqFcf1URR5SBaqQoUK+56VW1M8trr9b+qNsru30PPQD9HrJ1D90X9Dkt6//VPfsci98gjBSYehVY/a7diyWjGYjs9nfr+MM+c0cV5/dnrl1tSgNZeqyIPPJEE50F6yqbfyvKnwgUWSJDRF4pBxCfKmjev7TKqL0BgPUXJ8ogGZiKHQ2lcsi6QVLI4555Msfeoh/viH/2PmqZdSdF1yJQdVAVmA5ZdVwkMKGLqOjEdLn40iy1iWT9Gy6RYSigwBTUGRJFRVwrQFhiqzriNHXVSnfuIMjvz01Tx/y7f4+Kf/iV/86jf0yQ4dWcFz63qw3XJfZDKrMa0xQTyg0lewMR2PDV05ckUXVbaIBVTUflG0gKoQ0GTylkvOdPB8Qc5yaYgFUPcT5d8KFT4opIs2vXkLzy9719ueT6roENZVmpIGfXmHKXUR5o1LUrI8NnaD44EmK3TnLTIFi2V3/gC7p4W6C69Hqx69T89PoawXAWAAhVUPgxBEDylvuquA8F0yJYdISOftjjyGKpMq2DzwynZkGcbXRDhuan3Zriuo05U3MW0fXZMIajKW6xENqmRNj1hAwheCkRt3P3hUAux3yfyxSW48ZzbX3PcGgnLW+nPHTOD3i7dguz6yJA1ZHMcOPYfCG0+Rf+Ux4kdevFfv4eVTZJbchW+bBMfPIzT1yN0KCAFIikrNwm/Q8ef/oHvRTTR84idoiQae30kBHcqL1xVbdp8ZOn1W45AS+btXbOOOZS0sWtVaKRWvUOF9YMcy6T0hAzuO8vIpuu75NrIepPb8b+6Vr6UQPqUNyyltXAayQuyw89ASDXt1roU1z+CXssQOPWevxg+QKbncurSFu1e2cv1ZM7nh4TXYbllQ8fgptTyxtvOda5Skwf7sChU+yDQnwxw3tZ7OrEnEUEgV7UFRHtv1iAd1qkIOiaBKMXYwM+Ydzl/u+j1jjz+P1l6TuCGTs3wKoux/HTHKTh/JkIaQFITw2PDCQ2x8YzlGdRPRWSehxaqojoaoi6hkTQsfmaLjk8pbSBIEdZna2cdx0Fmf4aWHfsd3bhrPqZddiSQBwqc3b+N4HrarcfTkWralS+iKRMESjEoGCOoKXRkLVZGYXB8GyqrqjfEgectFCEE0oFG0XfKWOyicVqFChfeO7wt6chaaIhMNKDQlQiRDGjnLxfUFb7blMG2X5mSYvrxDMqQztSFO0e6luyjQFInX7vsfutcuYezpV6KMnzcYDO8OJ9WOufpprHQHwYmHEp5xHFCOdQZiG52ydaCqQtEGQ4KAWmLTa3+hasaR1NU3IElQGzZIRDQcX1Ad0pGAdNFlQ2eBjnSJcTXhfjHIaLk1FSjaHtGARt500BWZ0VVhMqXyhl5j3PhQbeRVAuy/gakNUbQd/LJzljtY3iiEGPpBrB1HYPx8sqseInbYeXtlldP9wPew2tch6yEKbzyJWjWK6lO/QGDMnN2+TglGqTv/Ojr+9FW67/k2DR/7IXIgMuLY3S3hj51cM+iJPX9skiWbeoeVilcC7AoV3hsD2duBIPKnT709aJ0nSzChNgJCsKmnwM5VUTvev75j0bXoRvxSlvrLvo8arWFPWO3r6Xv8v7A7NyIHIgjXwWp5g8Z/+u89qnUK4ZNdfj96wySM0bPf5VWXcVyfx1a3Dz43HdenJmoQ0GTsftugGxbOqjxnKnwosF0fx/dpjAfImy7pokNdrLzJlSla+AgkSWA6PgifY877DL++9vOsff4vhCcehabquMIloPmoioLvCzwhowBFy+Gtl55i6e+/ixGrxso9jfTU7VQdfh6lYy/BsoJEwzq+Dx3pEq4QmH0FkhGdaEBn5kc/Q2/bVu77zU+oGzWeUQcfB0LQni4SDGhoJZeHX2tnbHWY6ohGW6ZEQFWQkZlYGyYZ0RHinWeCJEkYqoIkSThe+f5V5Q9HRqlChQ8yBcshb3kEVImC7VGwPFKlInXRALWxAPWxBNmSzaqtKQxdxnElLMelK2chSxANlK0Cw5rC80/ewYZn7mb0UQtpOupsckUf1wMThlXISYBnl0g//ydyqx4GQAnFKax9FjkUJzhu7pDXeIDtg2pDWIVRNQG2v/AYbinPnFMvp64+hitcTFsgIVEfDdKYKLsddGVLpEwb2/UoOj6qLIMkQJQ31SXKzxUBxII6bsEiHtSoj+o0JUJ/l/8P+4pKgP03sGRTL27/B8Bxfbr6m/EHFos7EzvsXLruvJb8mmeIHnTKbuf2zDxW6xoSx3yc2BEXUlq/hNQzv6Pz9quJHX4+iWM/sdtstlbVTO25V9N553V03/9d6i78NpLy7v43v7ihh9uWtpAq2iyYUD2opm67ZQ++im9thQrvjh2D6WGiXrI0aI03YF2lqzKfOWo8j61uZ2N3YZfzCuHT+8j/w25/m9pzr96jYrgQPtkl95B+4c8o4QTVZ36V8PRjyb/+JH2P/xduqm2PgoylDctx+1qpOes//2brDE2VOX1WI8u39A1uVJ4/ryxstuPfqUKFDwOSVA6yOzMmBdtFoRx06pqM50FDLEDJ8ehpzzKtMUbjWWdw76/Gs/jBP/OJ751GS08BNSzwJIOC5aJr5X7ptzsLRAISufVLCSdrOfuGu3hr00a2PvVnel++k8KGZUgXXYU8dhzjkiHa00UCioSsKNRFDVRFQ8HnkI99jVKqk9/d9J9cdN2vqZ80k5LtokqQFQJEifpogHgwTFAtkS7arOvIUR3RSQSNfnswDQkI6Sq6KlMT0dncUyi3qIX2nDioUKFCGSFEWb3f80gEdQKaguV6dGRMdFWhr1BW+m+IGwQNGUNVqI0YuF5Z+d/1fXQZ0iWHdMmiK1dig6ER1GRkWebtVS/z+G+/R+30w5hw5j/jui4BDWy57Em9Y5hiAE7nRtof+B5uqoPo3NOIHXkJSjDGtp9fQmnDMoLj5g45/4FMuAu4EvT2FXj72XuomTib0dPmMKE2iCqrdGRL6LJMcyJIcyJINKiyrj1HUzLMmm19FGyXBRNrqI8EiAY1qkI6uiKRtzxqIwbxkEbYUHB98aG06tonAbYkSQngt8Asyv/vPiOEWLwv5v4gsmBCNapcVvoVwFNvdrLwoCZ6CzYv9vu77khg7EFodRPILltEZM5Ju+2N9M08AEq0GkmSCU05ksD4eaT++luyS+/F7t5C3QXX7/aDFhgzh+rTvkTvoz+l74n/puq0f3lXH0xfwLX9JfCaInH7547gU0eM45YXNuH5ghseXsPUhmhlAVyhwl6wYzA9oMa/o6iX4wnEDk+NoybXcPqsRm54eE0547Ub0s//keK6F0kc/xlCU47Y7VjfKtLz8I8pbVhKaNoxVJ/6xcEKFyVazqIPPH92R3bpvSixOkLTjt7tOFlixA3HU2bU8/njJjJ/bJKpDdFhAXXlubJ/cCCtCzRFxlAk8pZDqmjjegJHCGY3JyhIHgFdpT4aoDtr0pYqYrlw7sc/x2+++w1a16ykZtJBpIsyvu8R1g1sx8MXkAwqKIpCJpNGj1ZjAXqiieaF/0F85rFse/CnrPnNVxD/9ANsezKmB4YOnuvh+xK1cYO2dIFgwOC4z9/EEz+8kvt/9O98/KY/oGhJtqZMQjo0xYM8tbaNtW0pkCCkK/QVLMKGQn29QVeuxKotfWiKQl1MZ3J9jEzJwfMEybBKqugSDmhoH6LyzQoV/p54vigHxopMtlR+TuiKTHumxOhkCN8HJKlsiefKbO8rsKEzRyigcciYJJIELb1FWlNFMgUbhNTvVW/S0leiLgau56Jn2/jNt79MonEccz9+Le0FF00BzwXkoe1lClBY+wxtj/0cLRRj0iduRmmcjdl/XA7G8a09rAkEpNe9RKmvk6Mu/zeqIjpBTWd0dYhR1SEs28dHUBUxCGkKU5pidKdNZoyuojluMLYmwriayGDpdyyoEwu+M72qyHyIhMOHsK8y2D8D/iKEuECSJB34cOXx3yXzxyaZVBdhbXsOKAv43P9qG9Mbokhlh6shi0tJkogffh49D/2I0oblhCYfvsu5ZaP8p9txoStrAapP/RJG8wwkWdmrYDky+yScVDvZxXeiJhuJL7jwXV3jwA1oe4JfP7eRv77VNXg99g6K4pVsU4UKw9kxY71kU+9g6bft+IO/1/vbTAZEwAaY2RgjVbSx9hBc5157nOySe4jMPY3YYefudqyb7abr7utxereRPPFzROefNeQ5MvC8kY3wbuextr+JtX0tyRM/u0ddCCi7K+yooShLZU/sHYPpyrNjv+WAWReYtstb7Tk29ubpy1rMHp1EFoLXWvpQZBkfKFoeY6pClByPDV15DjvxbO741Y955ZE/cuq//4zGpIrjg2s7BIwQ6ZJTzgxLElowjN3RQl++LKQW1mWSs48i0TyO9hcX0TBqPD2mi+9AxhO4PuC7pIomkoC05RGJJjnmCz/k6R9dyR03/ytHfvFnSIEgmSJ05bKEdYlU0cJxfcbVRKiPGRRtj6xdFlUKGSphQ6UzY9MQc+jN2/3q5xYhTSFTVPqrcCSSIY2AXimQrFABymvm9kwJv98mV5NlLMcjV3LwhKAxFsRQZYKaTMFysR2fjqwJSJiOSU/OJGs6bOrOUR0xiBoa23qLdOZMZAmypkNQk0n1dfPADZ/FCIY4+gvfg2AQP29hOiCr5YBaptxDrQhBz4t/puflO4mOm82kC68iEk/guILeElhC4Jv5Pa4JFEmwffEi4g1jmH/MyciKPGgFmCo4BHSZuliQiKECElProoRUhbpYgOZk+Xno+uJDG0Tvjvf8BJQkKQ4cC3wKQAhhA/Z7nfeDzPcefXMwuN6RNzve+d3OmZvQtGNQnvsj2aX37j7ADkSR9BBOqn3Yscisj+zV+Qm3/OdPHPMx3HQH6ef+gBqv32tP3J0zT51Zc5g12W1Lt/KjJ9YhBOj9We7KQrlCheEZ608dMW4wP+0DyZA+REDw1qUtQ15/14ptzGqO71ZJvLR5FX2P/5LA+HlUnfzPu910s7u30HXXdfi2Sd2F3yY4/uBhY9xUGyChxGp3e22ZpfciB6JE5uy+1UWiXOZ+3ZkzeWZdF399qwshBLoqV4TLDgAOtHVBe6ZEuuRQGzZY356lLV3C9jxqwho526cvbzO9Icq2VImC7SIjMWVUNSec+wke/N1POGTbeiZNm0FvwUJWFQ4emyRfcsiYLus78xjJBoorn6U3X8T1NEZVqwRUmVhwDNJpV9KSAWvgZHwIAu05m7ANU+ojbN+SxyxZ1DeN5fz/+DF33vQFVvzheo74/PcQigLCw3Z8umwLTZXZ1ltk9ugYC8ZXc/C4KlxP0J4xyfkOQUPGdj0c4ZEpWdiOQFUl3tiepiNjMb0pSk0kwMFjkh8qQaIKFd4PipZLtlSuyw4bKnnLxQgq9BVthBAYiozpegR0hcZ4ENcXvLYtxaaeAgoSOctlW1+RoKFg2j5VEY3qcJC+okWuYFN0fKJBjYjqcs/Pv0Y+k+bj3/lfMqFatvUVyqXhgOuW+69dQHgufY/9jPyaZ6g++BRqTvoCpqJSzJVr6cIaOJkswi6ixut3eW0KEOhaQ6Z1Ayf80zfwhcAyPXKqS6pkgwTVepBRiQAl26UhHiQW1JlUr9CWKlF0PAxVRt9PnxP7YotxPNAN/J8kSQcBK4EvCyF23Tj4Iecvazre9WskWSF22Lmknvo1ZusaAqNmjjxOktDrJ2B3bHjX7+E7JuamVWRXPYysBwhPP46aM75CZ66Hnkd+ghKpIjB61q7PEThpRj2Zos2KrSl8US4RP2JCNa+1Zt55HwHb0+bgz7YnuHdVayXArnBAs3JrintXtbJme2YwY+24Pmvas4PChxKQKpbjjPljkyxa1Tpsnr6iM6L6/wB21ya67/8uWs0YahdetdtMstm6hq57bkDWAjR87AfoteNGnrNjA1r1KGTN2OVcTs82SuuXED/yEmQ9uMtxAIoscd2ZM7ns8DFcdviYYT3oFfZ7Drh1gSSBJ3zG10WYWh+hK2chKaC6ZdGhvOUS0TUaYgFa+4p05UyOPecynrrzFjY9fRuHzv8+sYBKImgQ1FVyRQfb8UkEFGLNExG+R75jK1rNJDpSNk01AUJK2Y9akUHy3+mtdIDegks+n6fttefY9tIDBIIhMgefyIJTzuSMK6/j4V9+i00P/ZzpF3wVy1ZozZh4PoyrChEN6WiqzOTGKPGgTqpkY7o+EjAuGqAtV+KuJdtImzYyMG9MAlVRKTgefQWLgKpguX4lwK5wQNPaV2BTbwHb8YkENJriASQJDE3GUGRsz8fyyh70UF7/u65HZ8YEIejMW3ieh+V6zE1UEYnLpEs2dbEAW3pyFBwHx5NojKo8/otvsH3DWv7pW/9FePRk0t05LEcgS2XrLp9y77Rum7Te/11Km1fSdMLHiC+4GNeXMHfY0bccKPXHIHrdhGHX1Rgpn2t9JMiqR+4lXl3LUaefh6rrtPYWsNMuyajBuKoIedPr121QiQTKWg2GqjC6KoTrC3RFRt5PRRL3xdNPBeYB/yOEOBgoAFftPEiSpM9JkrRCkqQV3d3d++Bt/3GcNnPvrGx2JjL7ZORgjOySe3Y7zmiaht25Ed82dztuR7xihvyrfyG/+mli884kcfTlZBbfiZvppPa8a1HjDXQvuhGnd9su5xDAU2s7WbYlNZjB9jzBr57ftMf378lZexxTocL+ysqtKS69ZTG3LW3htdbMoFiZpsrMbIwNLnzL1jvviAS+W7/rcqn3t5H1ECf/2/8bbCkZidLGFXTdeR1KOEnDx3+4y+BaCB9r+5voTdN2+96ZZfciqQbR+Wft8Tw9XwxuJEB5M+GLJ0yqBNcHDntcF+xPa4LGRIgpdVGiAY1pjTHG1UaY3hhDlWRMW1Ad0wnpKg3JAI3xAK4n0VNwMMJRzr744yx5+hGmhGwWzhvN2Nowb7ZleKU1w+vb+9jUU6R+Ulmt39q2Bl0FWYZcziJneuiqivCHPktcIJ3JkH3tSbYvf5xZJ17AsZdeyYYn/wypdhaccg4nXPrPvPHMQ7Q/exszR1cxripKdVjH8nw8HNozNn96eSsPvdbKay0penNlAbc1bRlSOZt0yUKTJRxP8GprhlzJJm86tGcsbM8joO2HNZ8VKuwG1/Px+xfPvi9o6SuSCGjURHRKtkOoP0s9oMIvyxKaLOH577SDdedNLN8nqKkIIWhMBvF8j1dbeunMW0yojeF5Hi19BWRFxdBknvy/H7H4mSe47qYfMOeoj9CZKZG3fIJaObiWKT8TPDNPy13XUdryCjWnfonQYZdQ9KURS4vM1jUgyeiNk4f8XgISAYNRyQg1Tjstq5fz0Us+QzwcImGoBDQFWYLegkVbqkjRstmWLpIpWbg7VMKqilweu58G17BvAuxWoFUIsbT/53sof7EOQQhxixDiECHEIbW1uy9D/KBz1RnTufLYCTTEDBS53GeoqzKHjUsy8FmRJVB3+uvKeoDo/LMobVxOjd2JsovPVWDsHPBdrG2r9+p8hOtQWPssTm8r8SMvJjT1SPT6iciBKL5dKtt3XXg9yCqdd1+Pl9+1B/bOC/49O/KWqYnuOvNVocL+xsqtKX75zAZuW9rCL5/ZwK+f24jtDb17jppcw61XLCAa1IY8F3YMPGc1xff6PX0zT9fd1+PbJWb903cJJ+t2Obaw7iW6Ft2IVj2Khsu/jxrb9Vi7YyO+mS8/d3aBm+2hsOZZInNORgnt+ZwFkCs5exxXYb9lj+uC/WlNENAUZo1OcOL0Bo4YX0N1xGDmqARnz2nio3MamT+2mhOm1zN3dBJPCMbVBDliQjX1EYMTL/g0sqLy2//+GaoiUxNS2dKVZ3tvgXTeomCbOHqcQHUTuU2v4vrljLWuqdiejeW4lHY6n4E1QXfrViacdDlzjjqFWNNkwtE4tpmnPWdy0Ec/zdyPLOSFu2/hqQfuRFclRldHiAQCKJJKLKhiqBJr23LkSi4Fx+tX+HXozTuUbJec6eIjOGhUnGRE58iJVRwxoZpx1WH232VzhQrDSRVsNnfnebM9y9a+PKmCRcl0eKWljzVtWSK6Sk00UFbDBqrCOk3xIDURg4D2TjGx5QrihoquSYR1BVWSGFMVZkpdjPpIgLqohuv5JEPlXuyVD/+Jlx66nS99+d858dyP0ZUpUXJ8JAHJcBBFLq/jnWKGzjuuwWp/m5qzv0Z47mns7hva3PIqRuOUYdVqhgSapmCoMi8v+j+MUIS5J11AtuTQW7RQZJnaeJDxyQhhXcUDWrryvLK1F8fdGyfu/Yf3XCIuhOiQJGmbJElThRDrgBOBte/91D7YXHXGdK46Y/qI9js7ett+/d7X2dD1jmBZdN6ZZJfey1uP/4mas/5jxLmNUTORVIPSphUEJx6yx3Mxt71BadNKEkdditE4BeG5FN9+Ga2qGb1hIgBaooG6C75F5+1X0XXP9dRf+t3dZr/2RE1UpzdnIwBVkTh/3qi/ea4KFT5MDPRYD5SB74qZjTGWbOolGdIHBc20nXqQ17RldjPDOwjXoeu+m3D6tlN34fWkA00s2zLyRllh7bP0PPwTjKap1F3wrUGl8F1R2rQCkAiOG96bPUB2+X0gfGKHnbdX5wtld4Wrzpi+1+Mr7D8ciOsCTZHRFJngji6WisyEuujgj7broysK3Xmb3oKN8Dyam5o4deEFPHDXrZz6sS+QFUG29OToyTvkbYjoEA3KhMfPo++1JzEtC1kY+DhlReARHkIDa4Lqoy5FqR9PtlDi7eV/JdY4llDDJPqKDposM/WCr7B9+3aW3/pDSmqUw479CLNHx+jL23RkLAKajKJItGZKxAIqRcsjFtBRFInTZjawpbdEUFc4YmItybDO1r4ing970GesUGG/wvXKImapokV7pkRt2ECSZTrzNgXbIxmUMTSVnpxJ3vJQZInqiEHecgnpCongOzZ3AU3G9gWJoEEioLM9XbbuqorqCASaqjCuJkJbpsSiu+7khdt+zgmnL+SUT/4LD762jfaMhez5ZC0HTfHRNcj1pei88xrcdAd15127x7jCK6Sx29cTP/qyYcdUCSzXp2f7FtYueZpjz/8MSiiCn7fQJImp9REcX1AbDVAd1ljVkmZrqojvCw4eU+DQCQdOMm5fyTz+C3Brv1LoJuDT+2jeDzw7q+Du+PPKrSm29A5tOVOCUSJzTyO34kHix3wMLTG83FzWDALj5lJcv5TkSZ/frYCR8D3yrz9JePqxGM3TEZ6D2boWq20dev0EQEIIgSRJGI2TqTn763QvurHskX3BdUjKu/evvPLYCZw8s4FLb1mM44l9UgZRocKHhQGLrT2Vd//mhU2DOgbXnz1r0Fd+XUeOnz71NqfPatyrEnEhfHoe+QlWyxtUn/nVYZ6UO5J/4yl6H/0ZxphZ1J1/3R57pQFK65dgNE1FCSdGPO4VM+Rf+wvhGcehxkfOhA88oXa8ng3dBW5b2sJlh4/Z4zlU2C85YNcFuyJdtOgrWHSmC9geTK6LkCrZnH3553h00R08evv/UnXcJ+jNOxTtci91wYaU7aNOOAyx4mEKW19FTDqcwi66snZcE6jN00mlHZavf4nU268RrRuNqsrEgiotvUV68xZjzv1PrD9fw9pbb2Rscz0zm49mRkOMTd052rMWddEAk+oiKIpMLKD290xCydE5Z1418aBBbczA8wVF2yekKziej+mWey8rVNgfEeKd3mldkenJWXTli/RlHAKKgqGphDSFRChSziL75SC8fE9IlByPpkQQ3xdszxRJ5W1qowbVYZ1kUKcnb5Eu2JRsj5Lr0JoyGZPQSRVL9OUdVi15gXt/di1jZx7C4Z+6mkWr2kmXSpi2T2/BoWD6aAqIfC/bbr8GN9dN3QXXj1ipFqJczqxIkBOQ27AMEIQmHTZsrOlDpmjR/dTtKKpG85HnsbErR8l08SVBpuQyvi7M/DFxeooWb7ZlMR2XaEBlbXsfc8cmDhg7v31ylUKIV/tLveYIIc4RQuy6BvkA4tfPbcT1hi+hY4eeC7JMdumue7FDU47Ey3Vjt7+9+zeRJCRVR3guAIW1z2NuXIFshAjPPhlJkoYE6EbTVKpP+xLmllfofeznCLHrreaq8NDg+6BRce795yO56ozpLNnUi+uXFQc9XwzadlWosL8yUBaeDOl7tMqT++2pBGURwNVtGb54wiTWdeS4+r43eGF9D1ff9wZFy93tPEIIUn/9X4pvvUDi+E8RmXnCLsfmXnuc3kd/SmDc3HLmei+Cayfdgd25keCUI3c974oHEY5FbDdWf5Nqw6gj9FI9tnq4G0KFA4PKumAoJdtl1ZY+bl3Swur2LBt78mzuy7G+I8vbVoTJh5/Io3f/id6ePgy13DMJ70ivG2NmIxthiute2v0b7bQm6F7zPK1rlmNJQQLTT2B7qkBbyqQjVaQn55EruYy7+FvokQRP/+w/6d2+GdMVzGpOMKkmgusLWtMmrusT0hQ8XyKgqjTGgtTHgoyuDpX7rQWI/i02SZKQ98JOtEKFDyupokNbusT2dIlMyUGRJYolH0+C7oKJ67l4widr2qiyRDSgkio59ORMOrMmcv+90le02NxTwPZ81nXm6MubuEJQsjwCmkQiqGJaHlFDYX1Pnkde7+C+p17kju/+O6HaMcz42DdpSZl0pov05VxSeYuSU1Ybl4p9rPv9VXj5XsZcdMMu28CEDLGYRHVMYXJCwVn/Emq8Hm0ngTOJsnJ4vq+LzUseZ+wRZ0AwzFttafryJWRfwvV8ipbH2rYsqzanyJk21VGdqkgA0/YPqNaRyvbi+8BtS1v43Yub2NA9smCqGq0mMusk8m88TfzIS1Gjw21rgpMPB1ml8ObzGE1Td/lekiQTO+w8eh76IYW1z6LGajFGzyY8/ZhBRWDhe2SX3YdXSGF3biQy60QSx3yc9At/QgklSH7kn0acu68wtENjVnN8MDufDOn9X6BiWNlrhQr7Gztbb41KBNnaVxw8vqNK+KHjkiRCOk+s7RxyHIYHnM+v3724U3bpveRWPEB0/tnEDjt/l+Nyrz5Wtu2aMJ+6c69BUvVdjt2R4pvPAxCedtSIx32rSG7VwwSnHIFes+tM9PpdPOtOn9W4V+dRocL+iO8L0kWbTT151rZlaenJYto2noCS5bGlu8SE2jCqKjPlxMt4e/GTbHhhEckFF6Dhs2PHoqJoBCcfQfHtlxGn2ru8x3e3Jsj4BmvbCii+R8tL92EXUjidG6k+6ERmfPIG1vzma/zX1z/L+df/huqaOiRkZEmQLdlkijYBTaEuWu4ZjQUl9H6hGd8XpEs2nufTnbeYUBOuiJxV2K/JlhzCuoKgrKtiey5VUYNGVSZbcmhOhunMmWiyzEGjE5iuT7MskTddHN8n2q+o7fuAgKCm0Ju3eWF9L0XHRVXAtASW5+H6kCk5bO8rkevazvJbvoESjDDlEzfS46q0tecIaRK+kHD9ss91oXs7a39/NV4xy/iLb8AYPR0JKHmwcyd0VIOmeAjH9TFzaXKbXyV5yMJhiQSD8lqm/cVFCKDx6PNxXIErIG/ZdORsmhNBArrC0oJNMqwT0srZ/eZYmIPGJA4oZ4ED50r/DqzcmuJzf1zB1fe9scvgeoDYggvA98q9jSOgBCIEJ8yn+NYLCH/3wgB67VjqL72ZmjP/napTv0T0oFMGs1fC9+i+7ybcTCeBsXNIfuQK0i/djj56ZrkffPl9ZJYu2u38A5625/X3Wa/cmuKGh9fg+QJZKtvxVNSBK+zPDJSF+wJsx6clVRxyfEeV8FUtaY6fWofeL2ay472zc8C58ybWjuTfeIr0c78nNP1Ykidescusee6VR+l7/JcEJx5K3bnX7nVwLYSgsPZZjObpu/S6zK16GN8qED/i4r2ac0eOnVwDwMf/dym37eT1XaHC/o4QgvZMkRVbelm9PU3OtLFdQUjX8QQkIhpjkiE6M0WeW9dJJtxM3cwFrH/mHhRhUROGuFJe1KpAUofkjGMRdpHSxuW7fe/drQnSlsdbd92E1b8mSHzkCjqfvx07l2b2p28kn+7lnhv/hbe2dLCxM0NHtkSqYFOyXTpzJvGQQVVYoyYSINbfO2p7PpbrUxcLUh3WUfZjZeAKFQCCukLB9ijaHqbjEtQ0XF8gSYL6WBDL9aiNGMSDGpmSS9RQkZCIBjRGJUIYmoIQgnhAJR5U6Ss49OYtgpoCQqI7bZExHRpiBomgwvZ0ATvdx/O/+Cqe5zHlYzdQXV0DjkAHEDKm7YMPdraLNf/3Ddxijumf+A6hsdMRMqjKcCFjBVBUic5MuX/87aVPg+8RmHHc4JiIDAEoe2oXUqRefZzqOScgwjXkLJeiZdNb8ig6Hl0Fi7asRapYVg2f0pigKR5iWlOM5mSYzT153tiWoitrIsS79VH5cFHJYO8DVm5NsWhVK3ev2DZMSXhXaIkGwjOOI//qY8QXXDiiMm945gmUNizF3PoawfHDhNmHoARjEIyRXf4AesNEAqNnIYSg54HvIwfjJD9yBZKsICkqockLUIIxkid+Fq+YIf3s71BCMSKzTxpx7qMn1/CVk6YMBtE79qAKMdSOp0KF/YkB0cJyxcYOgkI73OayVC6J9Haw50gVbW7/7IJhAoipos2Vx05gTXuW9nRplxtxxfVL6X3s5wTGHUzNR/8NSRp5LzT3yqP0PfHfBCcdRu3CbyCpe6+p4HRtwulpoeqUL4x43LdNssvvJzB+PkbDpL2ed4CXN/YM+nm/0P/vSj92hf0dx/PpK9g4rk/BclFVhYCm4PuCeMigqSqMbXpMbooRUCXuWtmC64HwXaKHXkjXmq/SuvgRqhaci4OgPi5TcARCCKomHURHpAp77V+JTz1qtyrA72ZNEJ68AIwYkebxHPLpb7Pst9fw7C+vYvonbyQaDjGqKoKhyOSKDpoi0RAPDtnwU2QJSQLT8fAEaEole11h/0MIQW/eIm956ApkShaWIzBUmTFVIaIBDcf3GVsVYkNXgbzpURvVCWgykYCGqsi4niCoyVhOudrD9XzGVIUJ6SqaIpEq2kQ8BZ8ATckARdOlPWOR7kvzzM+/ipXr4/gv/wStbgJdeYeUVc6U6nhYQKTYxbo/fgPfyjP64u8gaqegKmCbYMvlpYtBuQVFBXQFCpYgoLp4AjpX/RWtZsyQ8vB8fyepcCC17D6E51J11IWYpodpe0R0GVlVyJccYoZKTJepCukogGEoTKqt5tDx1WztKWK5HmFDZX1XjrChEjb23zB0/72yvxPfe/RNbukXM9odAyWkOxJfcBGFNc+SXX4/yeM+Oew1oUmHIwci5N94ao8B9gDh2SfiZboAKKx+GjkQIfmRKwbLxUtbXqW0eSXReR9FkhWqTroS38zT+9jPkQNRQpMPHzbnQHA9sJHwdmeu3GMlKuXhFfZfdiwLlyXo1zPBp6yc7/dXcNywcBYtvQVueWETQoCuySRDOveuakUC1nXkWLSqlTuXtwzOMSoR4JBxVSMG2GbLG/Q8+H30honUnnv1LoUIc6seoe/J//mbgmsoZ8hRVELTjhn5+Gt/wS9liR/57rPX8M7fa4DHVrdXAuwK+z1rWlP0FmxihoqLRFVYI5WziIc0Zo1K0pwIENI1evMWHdkSo+Ih0kWHnqygZuIM0pPnsfW5e6hfcBYBTcF0fTTK95OsKkRnHE9qxQPUFFJI4eQehRL3tCbwtrxKcfNKpKM+iocg3DyFxrP+jbb7f8Sa225m7ieupWjqtGdkekoOr7X00RQPEjRU0kWbjV05OrIW8aDK6GSQhliAoF4JsCvsf5Qcj6zpEtYV3u7KUbQ9qsM67RmTkK4Q0lWqwhrxkM6MRpnWVAmEjxCCdNGiI2NStD2EEDi+wHU9ZODtTpe6WIBoUGVLb56tPUWSYY26sMGqngLFUoGXb7maQudWxl30TTq0sdi9Dhrl4NoDSoCb6WLt7d9AmHnqLr4RGieT90DpLwtX/HIsYlOORxTA9spJAscFK9VKqe0tEsd/eljFnAIIM0v6lUdJzjyGRG0zIVWm5Pl4QiJpqMhAQyxAYzLI1Po4s0YlQALL8nB8H8f30DWZsKHSV7BxPY/9OQzdf6/s78BtS1v41fObhv1eolw+EtYV5o5JcuVxE/nT4i3c/2rbkHFazWhCU48it+phYoedhxKMDp1H1QjPOJ7ca4/jlXLDjo+EEoig9NvyeIUUWt14JD0AgLltNX2P/5LEcZ9Eq2rGK2ZIPfNbAmPn4FsFuh/4HvUXfnuYEMKTazoAuPiWxUNE2xS5Uh5eYf9lx7LwnTfQZjbGOGVmw+Dm0vUPrkb0i5rpisy197+x20231rRJ607PAwCrYwNd996AGq+n7oLrdylU9k5wfTi151z1rt0AhGtTWPMsoclHjPhcEa5NdtkijDFzCIyasVdzTm+I4ng+m3rKmwY7bkpApR+7wv5PumizritPWFfZ1JulPmxQlzCY3BjF93yUfiuveFClta+A6/mEAwoRTcUKaTSENewTL+GVX32NDc89QuKQs0mGwHSg6PSXis8+mdSyRfSufob44eehUd70U3hHEG1HdrcmsLatpvfxX9Jw4idRq5op9PWy9Yn/RaudQN3JV9L15P+w4b6f0vyZb5II6YyvCZMquWzuyTOpPsrGrjybe/PIkkzBkijYPqO0yrKywv6JhFSu3AQcT6CrMiFdJWx41EUDhAMauipjOi5vbE/Tm7N4uzNPc3WQgmmTKTkULI/acICJ9WHWdmRJ5R36ChbTGmI0xIMUTI+jJtWQNx16CxaxoMxff3ktvRvfoP7s/8QfO59C//fqjhUsbqaLjoHg+pKbhlSdDTSZ9rd7owKGXFYEh/IvfSD96hMgydTM/Miw6hgP6F32AMIxaT7+IiwHJFWmJqaTKXnUxoIcXh1E03RqowZTG2LURnTMftGz7qzFhOowXXmbvrxNXcwgEti7drYPK5Ue7PfArhRyBVC0PbrzNs+t6+LJNR3DgusB4kdejLBL5FY8OOLxyJxTwHMorPnruzo34Xs4vdvwSzkkSaa0aSU9D/2I+JGXEJ52NACyEUJvnEJ22X3EFlyIlmika9F3sHZSLv/Lmg5+NYIi+kAp7IC68sqtB7RIbIX9jAUTqtFVGVka/qC8+NAxfPGEScwfm+TeVa3YnhjMJGVNd48VLSPh9G6j667rkAMR6i76zohtI7BT5vpvCK4Bim+/jG/miBx06sjv8drjePk+EkdeslfzScDHjxhHa7qEEOWfPzKtniuPncAxk2u4+dzZlex1hf0ex/OJBzQczydbtLA9n860SXfWYm1HjkzJpqWvyNaeAms7smxLFYkFDc6e28zFh4zBFDJ2cgaB0bPoXXoPnmuXN+5E2X9WVsob84Hm6eRffxwhys8dXQZDK4sb7Yqd1wTmppX0PvQjmo67hLrZRyP5EIhEiY2aSteLdxGuH0X9sR+j85W/8uq9vyBfssnbHpoiIckSmZLDtr4ivTmLTNHG8Tw8XyBRLhXfnirSninheBVT7Ar7B4H+6rSs6TAmEcRQZHoKFqOSIZIRY1D0L1dy6SvYdOdstudK9GUtXlzfwwvreli9PcPT67p4cV03r21Okc5bhHQNyxFYjk9nzuS5NztZtGobT65p5QdXfZktr71E8xn/THL6sSOe15Dg+uIbCe2ipWtgWeIC+P3WXJSDZ99zyK3+K8FJh6FFhifNPDNPduVDxKYeSah2HImghO97mLZgcnWEs+c0c8y0Ro6aXMPo6jDNySDRgI7tCEYlgwR0me6CzeiqIIdPqGJaY2y/12qobDW+B06f1TjYW7grHE/wl/4M8EjodeMJTjmC7MoHiR26ELl/p3nweP0E9MYp5F/9C9H5Z+/RHmgASVZIHv9pOu/8Jm6qDTfTRfIjVwwG1wBuugNJkghNXoBW1Uzdxd+h89av0XXXt6i/7LvoteMA0GSJp9/sHPYesiyRKzlD1JVvvWJBJaNdYb9hTnOc5VtS5V1fGWY2xbn40DFDgsV98RXhZjrpvONakGXqL74RNVYz4rj3mrkenOeVx1ATjSPadgjXIbvkHoxRMzDGzN6r+RQZ7n+lFdMpL6Y9AU+s7SSgVZ4JFQ4c4kGdmqhBR7pIRNcIB1RSBRshgaHJGKqC7frkbY+aiEGuZNNaLJGoi+AXJGIhjUhIpuaoS2i941rstU9gHnImjoCQIZEIq/TmHZJzT6P9kf+H2fI62tiD0FSI6PSvlHeRyZYVmk/4NFvu+CZ+ug1yXUz46GeZePhJFE2HVMmlt307liNITFtAQ0Md6rT5xDWbt5++i7qaKoLn/TO+8NBI8shrrWSKNkFdoy9vE9ZV8D0KlkNf0UGTJRzXpydv0Rjfs2VghQofZDxf0J2z2J7K01twqIroNCUCNMRCBHZqiQjqCq4Q2L6H4sPmngJtGRMNQcb0kWV4q0sQ0mS2pYrEQyol26FkO6gyLN7Wh+X4vHbnj2hb+igTTvs0kUPOwHEop5r7kQF7p+DaaJw8TCV8JHZuTiu+vRivmCZ+0GmUdjomAYUVDyDsIjVHX4KsyIytj+E4HlURnYZkCGSFTMklYqjUxQLkbRfL8ciUbLryHlu6izQnAvTmHWLBPVud7g9UMtjvgcsOH8PN586mIWbscoymSJw2s2G38ySOvBRhFcjuIosdnXs6Tu82rG1vvKvzU8JJGi7/AcmP/BM1C782JLh2erdRXL8Eu3srkYPPQK8Zgxqtpu6Sm5BUja47v4nTtx0oW/CMlJHzfMEtL2zCdMpltI7rV/ywK+wXrNya4tJbFrOsP7iGsp3GKTMbhmViz5s36j3txLq5XjrvuAbhWtRf9B20quYRx+2LzDWA3b0Fq3UNkbmnjSieln/9Cbx8L/GjLtvtl+Ck2vDg5oLnw7ItwytYbKfyTKhw4KCrMvPGJDlqch2HjKuiLmrQEDOIBnRqQjqm4zMqGaIhZhDUFCIBjZlNMQTQlSniugJfQHT8QYRHz6B78b1EVI9kqGyXFdAUJlQHaDr4aORAlPwrjxJWQZFA1zRCRjnTveNdK1Fe6PmAF0oy9vIfMOaUf2LqJV+neuZRlEouhiLwereRenMJpa6t1B96BsnG8ciyzNwLvsiko8/ixXt+y+uP/4lMyeV7j6/j2bc6WdeZZ3u6hCIJEkGNlrTJi+u72dZbxHT8wex7hQofdnKmQ95y6M7bFEwH1/VJl1zkEb77w4bKzMY442rCTKyLMKoqwOTaMAFdw/UhrmkIHwpOORR2PEF1WMPHZ21bFsd12fjILbQtfZTRx15M9YLzUYBIEDTesctS8l107hRc6+y+kmWX1/fKo6jxevQJ84bpOnhmnsyKB4lOXUC8eQJhXSaTtwkYCk0xg1FVIQxNYkxVkLG1YaojBus7cqzvztOTt2jtLdIYCxA2NDKl3Ukz7l9UMtjvkcsOH8PUhiiX/mYJjuujKhLHT61DAmqiBufPG8X8sUnGVIf53UubaU+XKNhD95f0+gkEJy8gt+IBYoecPSyLHZp+LKln/pfcqkcIjBnZKH5XyEYIjBAAwnOQFA0n3UFx/RLcTBeRg04d7NUQQmBtf5PInFPJvfIInXdcS8Pl30eN1+1y/h0Dbx/IHUA3T4X9lyWbenF2aomQZWlEQb/5Y5N8Z+GsPfZdj4RXzNB157V4xQz1F9+IXjd+xHG5VQ/T9+Sv3nPmGspfpCjaiK4BwnXILLkbo3k6gbEH7XaeWc1xWtMlnEFHgeFjdvU3q1Bhf0WWZZJhg3G1kC451EaD1Md0BGWlbUMt2/OEdI2tfQVe2dLLmu0Z8raHL3wm1IZwXGg495945udfpe/Vp2lccBaq4lEfC+ILmWrHp/rgk+lecj9+oQcRrcH1fHQFDB0s652+y4GeUegvDTVCBNQQslH2vw2GNNIdrfS9tQQzXV4TRJsmYXsQCsD25U9RNWYK4w47maf+8FNkLUD84NNpSgRpy5SwXZ+wHubFjT3oiszkhihBTWHZll6iusL0pjhCiAMiY1Vh/8ZyPHIlh76ijarITKjVUEcIsCVJYnxNhJqIQV0siNev5ZIvWZR0CUn1KdgC1/Opjep4rs/WVAlZSEhCsPbRP7D56TuomvdRmk76GKoqI2yBIkFAEWgKSH2dbP7z1fhmntH9gmZQvu/fbebU7t6CtW01yeM/hSzJwwLs7MqH8K0CB53xSSaPS5IruQQ0hWTQIGP5GBkbQ5JpjAcQPpTscrtIfUjHE4Js0aUqopM1XRIhrWxFdgBQCbD3AfPHJodZ8ozEhq78LudIHHUp7euXkF3xIImjLxtyTNYMInNOIbv8ftxsN2qs9m86z55HfwqeR2DcQbjZHiKzPjJECEG4FrIepLR5JaGpR1F883k677iG+su+hxrd8yJZCPjV85sYUx2u9FtW+FCzYEI1iiIN0R244ujxu7y3Bz7v1z2wGncvo2yvlKPzzmtxM13UXXg9RtPUEcdlVz5E6qlf75Pg2jfzFFb/lfD040bs8c6//gRerofq07+8xwXx8q0pbr2i/NzLlZxhgo+qXFZYr5SHVzgQqY4YJPrt/Xa+lyRJwkfQnbVIlRyylkt9zMDzBONqQxiKjDvxFDY/NZeWZ++g8bBTMTQNx/HpyJfI5G2q5p9B9+L7aF/6GI0f+Ti9WQ8hgSYxrER04InhUM5+6RqkCrDxnh+iCEFo7Gy8XA9Vcz+CUj8JT0AkIOM6Dq6s07rsPqYfejQGDk/89nss+ISABWfi2D4HTYnSmXFJFWxkWVAfVUkENQKaTF08SGvapCqskwjvutKvQoUPOtGAhgASQZ1ESENVFMZVh0fMYEN5czke0pkoSWzsyhHUFGaNrqboOKxvzzM6qbOlN09LqoTwBZGihSxLtDx3N5uf/BNNh57KqDO/gO8LAqpGdaRsk2d7LrnuDrb8+Wo8s0DdJTehN0wabAvxGH7/74ncqoeRVJ3wnFMQQEwGyy8/L4SZJ7f8fpLTjsBonEjJcnF9qI3ohIMaEV0lbdpkO21iQZ0TZ8TRVRnPF2zrK5Sz81Ed2xWMTgZpTIQOmM22Son4PmL+2OSg6NFI7EoQbQC9fiLByQvIrngA3xweiEfnfRSA3CuP/M3nWHXyP+NmO8kuuYfYoedgNE8fclzWAoQmL6DxEz/B6dlK4oTP4BXTdN5xDV5+7wXM7lzeUhE9q/ChZv7YJB+ZOrRyI2e5ACOK+t22tIXHVrdz5pzGvSoX9808XXddh9O7jdrzriWwi17n7IoHy8H15AXvObiGcgAtHJPYIWcNOyZcm8ziuzCaZxAYN3ePc21PlXhyTQdfPGESV50xnSuPncDApQ8E15WNtgoHMmV/6JGfB6JfubcxYaDIkDU9xtaFOHpyLefMG8Ux0xr55Je+RjHVRebVpzC0soWPads4HsiRBqKTDiP16mN4jl3OOgkQ/jslpAM4vKM47AGlUlncaPyZ/0yhr53Ol+8lfug5hBunE1QgqElYHuhakBlHnMC//ORWNr++jIUfu4Jp849myZ9+QG71Mxw8vpraUADHcaiJqIQMja6CiyZLBHUVRZbQFejOW2zpKdCVMxGVmvEKHzKEEJQcl1LJwfEEhiSjKzIZ08HzfPKmQ7poDwr62a7Hhq4sb7Vn6cmXMHSFukSAvOXRk7awHR/bFXSkTXRZQpUlXEli0/OLeP3+XzNxwckc9clv0JAIkYwGqQrr+J5EyfHoa+9g05++gWcVB9XCR9Jc2Fu8Uo7C6mcIzzgeJRgDoOBDPAQz61Ss1x7AtwrMP+czVIcUqsMB6qM6JVegAJbjkC3ZSJJEV6ZEpmSjyBITa8McPCZJYyLA5PootTEDXVMGheAOBA6cK/0HM5JFjQQcNi7JKTPqURWJxNGXlXuxl98/bKwaryc0eQH5V/+Cb5t/0zkogQj1l9wMkoS59VUAhBiu8OlmOkGA0TSVuguvx8t1l4PsQnrEeRPBoYUQa9uz/PiJdVz+2yWVILvCh46VW1Ncfd8bPLWTsN/dK7Zx29IWLv/tkiGf7+89+iZX3/cGL6zv4f5X2/D2kMH2rSKdd30Lu2sztedcTXD8wSOOyy5/gNTTtxCccgS1C7/+noNr4XtkVz6MMWomev3EYcfLyuG9xI/efe/1jtz/6vbB/44GtcEy8QGHgQoVKoxMNFD2ja4OBzl2Sj1nzm7i0kPHcviEWsbURDl4TJLLzzmDsTMOZt0Tf0b1XZAkFElBUwAB8UPOxi9l6Xz9WUwY/GfHsnAoB9NQzmQrgFAgGgRPiTD+YzcjSRLOtteIR1RCuo+uqpRsD1kWVIcDbG5pwReCSKKKa//fb5k05zAe/9X19L7xLOPrw9QnguQsH88XRAyV2ohBYyzQ/7OGLyCoyeRNl6L9bvNrFSr84xBC0JWzWLapl6Vb+3ijtY8n13WiSGVf7I6sSWfOIlWwaO0r4nk+6zqyrGvPsXJLL3cv386qLb2kcg6m7VKwHJqSAQxNASRChkLedFj/13tZe99/0zT3WI761DV4QkKRJUL6QCWdT6q9jW23fQNhl6jfyYrrbyX/2l8QrkV0/jub7h7g22AXTXqWPkjTnGOINE7BlVQc30fXFCbUhDliYjUTaqM0J4KMSgYpuj7Zko3vCyxXENJVZEmir2BTsl38v8Ve5UNMpUT878RAJuex1e3MbIwRDWpDyslvW9rCN+8XhKYcSXbFA0QPOXtwN2mA6KELKb79MoXVTw9mtN8tsh6k8TO/xCuUA18v2wOSjF/K4uV78R2L4tuLMUbPQAkn0WvHUXf+t+i659t03nkt9ZfcNKy0NFNyOWduE70Fm6Cm8NSbnfiiLHD006fe5isnTamUiVb4ULBya4pLf1NWxd8Z1xM8trp90Bvbcnx+9dzGYYH47vCtIl13fwu7cwO1C68iNOmwEcdlly0i9czvCE05kpqzv4akvPdHdfHtxXjZLqpOvGL4eTkW2SV3Y4yetcfe6x0ZUxVi5dbUYJn4oCAckAzt3x6XFSq8FyRJYlJ9lHE1YWSp7K/r+u88d2RJQlJkvvQf1/Cfn7mAzS8+yKQTLmBbqogsg6ZBeNwcAnXjyK24n8ick5EkqdxnvRMe5cA6poOuSmRtgSZDJKBBRCPy5V/jpHsJGwrZTArX1cmm+tjU08Ua1cVav4ymqXPpMBVK3SVO+fL3eeAH/8Z/Xfdv+AKmH3kSybBOSFNpiAWxfWhKBqiJGDiuT3umhC8E2X6NlvqYQdh4bxuGFSr8PbAcn5a+IttTRbIll2TYQJJc8qZLKFBWyg6qKhnTJVN08HzB8s29ZIoeXfkijicoWi6u8EkVHYKGRtZ0qQppjEkG6cnbdK94lM0P/w+1M49g+sXfoOCWWzQ816cta+G6MtmuLWz849UIz6H+kpvR6ye852sTnkNu5UMExh40XP9FhU3P341TKjD+lI8jKxDXFTzPp9t0GJ0M9WtNROhbXc7sz2iMocoybZkSRcvDdT1KtosjwPX8A85NoBJg/x257PB37H0GFqVQLkdNFcslXvGjL6P49mKyyxaRPO5TQ15vNM8o+1avuL+sACz/bUIBsmYgJxrKwklP3UJw0mFImoGwSyiRKoLj52GMmj4YSAfGzqH2/G/Sfe8N5Z7snYJsATz8ejt3fv4IAJ5f343t+PjASxt6WL6lr2LVU+FDwaJVrSMG1wBCKvu7ShLlUkzg6Tc791oltxxcX4/Vto6ahV8nNOWIEcdlltxN+rk/EJp2DDVnfnWfBNdCCLLL7kNNNBKcdPiw4/lXHsXL91Fz1n/udfZaliAR0rn41y/j+eWf+/80yBKVDHaFCnuBqsg4nk9buoTnC3RVpql/Ier5cMQxxzJ+1iG89cStTD32HKojAZpjKq19Jbb2OsQOPZeuR/4f9uZVGBPmj/geBuX7MmOD4QiU/h7LQs4hGATTkwlVN9L68kO8ef9/EZ98GJ5i4Nkl8tEqwmNmEZg2h+2WTp0E8Wic//zhb/nZVZ/ll9/6Ny792veZtOAUPN+nO28SVGUMVSYWUAnpKsmwQUemhOf76IpEZ9ZiVPLAKhet8OHE8X1Mx6MqrPNGawaBIBFU6c6beAhiAY0Ws4DnSzTEA2zozlKyBTnTpjtjM6M5QkuqREiRyZbKfvGaouAKn7pEiLdfeoTND/2S6umHM+miq/BlDYFP0ZZozxbJFj3snq1s/NPV+J5L/SU37VIM9d1SePMFvHwf1af96+DvQoChgmxm2PL8IqpnHUOofgK9eZPWPp9IUGVsIoTj+/QVbWIhjSMmVOO4HrYv0ZEzKZgusaCG6XjIksSYqhCO56MqB9b9fmBd7QeEgSzZjx5fx6W/KZeZLphQjSRJ6LXjCE0/ltzKhwazzANIkkTssPNwU+2U1i99z+cRmXMqwYmHoARj1J71n9Sdfx1Vp3yByOwT0ZJNQ8YGx82l9vzrcFNt5XLxYmbIcc8XLNnUy/yxSW69YgFHTa5BlqjYd1X4ULG7WFkIWL4lhbdD/L23FU/vBNdvUXP21whPPWrEcemXbi8H1zOOo+as/9gnwTWA1boGu30dsUPPGbYx59slMkvvITB27i57wXdEptzaoioyT6ztxPXLfzdPgNQfZKuKXFEPr1BhLymYDp1Zk1TRpidnYbk+vhCYtsubbWkOOudzWLkUqx6/g0zRpD3t4AiJUABqZh2LGqkivezeYfMOLPAsyt7YLlAQkPMgW4KSC6kc5E3oybook08gOvEQfCNGVf+aIHbyFwhMO5FCsI5c0cMRULRsntqQYfzl36Z24kzu+MHXWf3iY9ieIBkwCAVUXE/QW7DJllwSQY2mRJDqiIGqlAWQ/P6dyUpPdoUPMpoiE9FVsiWX0dUBhOsiI7G5p0i2aCMEFC0fSQg2dGV4vTUNnoMkCcIBFd/z0WWFaKAsUpgqepQcH9eFtU8vYtVtP6R2+uGMv/BqUAwc30UGOtImmbxHrm0rG/54Fb7v03zpzfssuC5vui9CqxlDaMJ8wkBTuFwZUxXR6Fp8P65tMfqEj5G3bCzHJWbIKIDpC9IFB8d2Wd+Zw/Z8JEWmJqphWR7rO3OkijY506WvYLG5t0jOcjEOsA21A+tqPyAMZMkEYLs+v35uY/+R8hdN4ujL+u1y7hn22tCUI1ATDWSW3vuev5gkVaP23GtwM130Pf2b8u9G8MUd4J0gu53O268esgEggNuWbuVzf1wBwFdOmoKuysiUNwaSIX1EcagKFT4IDHw2ZzXF2debrL5VoOuu67Da15WD6x386AcQQpB6/k9kXryV8KyPUPPRf/+bK1RGIrv0XuRgjPDsE4cdy618CL+YIXHMx/ZqLh9Y1ZLGGSHTP9D76VcWzRUq7DUlx6douhRMh7c7smxLFWhPl+jNm3RnbUTdZKqmHMKWZ+9CcYq4nk/OdNAUGVXTiB2yEHPr61jt64fMu6u7UFAOuC3KPdtO/39n0Uiccw32DmsCIcmoMvhCQpd9enI2m3uKCCASiTHjEzdRN2kW9/3kala/8Bim67ChM8+yTb28uL6LJ9a080pLL5oi4QvY0J0nY9rk+jcVKuJnFT6ImI5HT97CdjwkGUKGwqhEmIIjKDg+mgxtGZO+ok1YV1EUeL0lS1uqwJquPPXRAHURA1dAzJBZtqWPkuuhSGA7Lhufv4+nfvc9xh18NGd8+fuMrksQ1jVkJDwhUGWPbNsmWm77BkgyDZd+F6N23L67vs2rcLq3EDvsPBRJwqFcMVOf0IhIBVqXPMj4w05m/JQpqLKChIwnyi2h+ZJDd9ak6HnkTIeQrmI5gm19JZBhdFUYVQJDkakK6UR0Be0Ay15DJcD+h7Dz18jTb3byq+c2DmbGtKpmwrM+Qu6VR3GzPUPGSrJC7LDzsNvXYW174z2fiyQr1F18A0o4gZPu2OP44Li51F5wHW66g87br8bN9w0e2542eWJtJ5fcshiA686ciSxL+EJw/YOrufQ3S95X8bNKAF9hRwY+D7ct3b2q/cqtqUHhsusfXE08uO96hz0zT+ed12J1rKd24dd3GVynn/kd2cV3EplzCtVnfGWfBtd29xZKG5cTnX8WshYYcsw382SX3ktw4qEYzdP2es6RrMgk3vHCdj3BolWt7+W0K1Q4YAhqCtVhnZLjURcPICGxsTtHb8GiK1dCkSSajrsct5hl03P3oakyti1wfR9VhuTc05GMMNml5U35gaeHAPbmabZjU8jAmkANJ6gxO6gNQUgBx/NxhEJjTKMhYSDJMpbj4usGZ3/tp4ybcTB3/fgqHl10N5IASfJxXEEipNKVtciZHhFDYVQiyJhkmPaMSVfWJKAp5E0X0xm5Nee9sKUnz6qtfXRli/t87gofPkzHoytr0pYq0NpXoDdv4bo+Jdsb0hpmuz5tmRLposXr2zOYtofj+mztydOSLrF6e5rtaZO6iEFtRGNCbZjtfRbVEZ2p9XEUJDQV2rJFtqdMXt+epVA0EUIilSvx+uN3sPjWHzPqoKNZ+NXvISsqmaKJJvtMrAmQK7p0blpH6+1XI6k69Zd+F61m9KATwN/KjgFfdsndKJFqamYcN3gsXYJs0eGNh/+I8FymnPYJ8pZLUIW6sEosqFEb1YgGZBRFpj1VorWnwNudGTZ2ZbFdj1zJJajL2J4ACWzPx/EE+ZLL3jWf7T9UerD/AZw/bxR3LmthwGLXF/DXt7qGjEkcdSmFNc+SefkOqk/70pBj4Vknkn7xNjKL7yYwZs57Ph9JkokvuHCvxwfHHkTdRd+m6+7r6bzt69RfcvMQb27HE4Ml4b4QZcEz7x1d04GS8X3Vk71ya4pfP7eRJ9aWxaYUGe76/JGVnu8DmNuWtnDdA6vxfDHYE6yrMtedOZNU0R4sX16yqZe2dGlQuMz2BH2FXfcOnzyjnhOm1vGdh9dScnavhusVM3Te+U2c3hZqz72a0Ai9z0L4pJ76NblVjxCddybJkz632yqSv4XsknuQtADReWcOO5ZZdh++VSBxzMff9bx7yjdV8lEVKuwd0aBGMqLTV7KpjQZQ+9ur3m7P8dq2NLbjoTZMJjH9CDpevo/uIxbSUF1FtmRh+oJELETN/DPofvkelHQrocQocpQD7ZFEz3Zmx3tVBpBkJpxwIaOrInRmSxQcj4guE1AFLX0WjQmdzd0FhA/TaiNUR0Nc9s3/4u7v/Rv3/fxaYrpg0tFn4vgQUGXq40E0RUKWZFzfpStn0tKTJ6CqmK5HPLBvBc9Mx+PljV38dU0XiaBGTUxn4cGjSVa8uA9YenImb3bk0BWJguUxuiqE6dr05E00RQEJ6qOB/spSD+ELbFegylAT0cmaNsG8xpT6ML4vIyswoznOjKYEsizRmiqyeGMR3/epjegIWQYkOlNF+komviKTyxXY+sLddP71DySmH8khn7qW7RmfvOUQ1lSiwQC6IpPaspJ1f7gWKRil/pKb0BINu702mXJl2a4ISdAQlwjpOr0Fl+y2NZjbVtN08hUYQQ3HBkkGx4dUZydbFj/MQScspKphLJbnUbI9JE0iYpS9wPMll3hIJl1wyFkOAkFH1uLoSbUEDYmiI5jVGKcrb6FJMrGQBkjlHrIDiEqA/Q9g/tgk3zlnNtc9sBq/PwDY2dpHjdcTnXsauVcfI3b4eUN6omXNIHboOaSf+z1W+3qMxsl7fE+z9U3szg2oiQYCY2YPy2Tt7nVGwyQkdegXYGD0LOov/g6dd32Ljlu/Xn4IJMtWZJoisWBCNes6ciPOqezD/syVW1NccstiHO+dv5/nw/cfe5O7rjxyn7xHhQ8XAzZbOzKgan/dA6vxhUCSJCTKmz+yVG5jkITYbVAYUGWWb+nDcjxsd/fBtZvvo+uOa3EzHdSd902CI4gPCd+j9y+/oPDGU8QOO4/E8Z/ercCYcB3MbW/g9G1HqxlDYMycPQqSOekOCm8+T+yQhSjB6JBjXiFFbsX9hKYds08USSUY1F3QFInz5416z3NWqHAgoMgSo6rCJMM6XVmT5Zt7WbE5xeq2FPGQQjrvIYApp32KZT+9ku0v3MXcz3yFrK3T3lciVXRoPuocepc9gLXqPiZd9FXaUjZFDwLASPlbr5ihuH4JkqxgjJ6FlmjAoLxQ15RyNUpXtoQsPFQViraPrri0dfdQbOtl4pjpgMDQyuWxUxtr+NEtf+JrX/g0v//htZyTK3DKRZ+mPhZgWn2UREgnU7DY1JWjO2sRCaqoqkTJcmmMBwnq771qx3Q8TMdlfUeOF9Z1ki45hAyV7rxN3nQrAfYBiuP5bO4u8HZ7Bl0pV4zURAyQBKblkQhLZC2bvpxFPKyTMx1c18cXEl6/5ZSmyDRVBSg5HrGARiigkCla/O75jYQCKuOrgoytDpO1HOaNS2I7Psscj76Si5AkHMtj21//TOdzd5CYeRzNZ/87eVdG9hxUGZAEqaLJ2pXLef331yNHqqi/5KYhyasRr623FbN1LZIsE5q8ADkQGXJcB3QNahMRptdGaUmX+Mtd96KFYkw7/mzSdlkE0fbL3+HdL9yGJElMOfXjuJ6L8MBxPSKaSjyogYD6WJDqiE48qBMoqTgCVEXh7a48W/sKVIcM2lMmh0+oQldVVEWhJqKjyJUAu8LfgcsOH8PUhijfvP8N1raPHIjGj7iY/BtPkn7hz9Se/bUhx6IHn0F2yd1kltxF3bnX7Pa9cq89Qd9ffj74s2SEiR58BrHDz0fZ6WbcETfbQ+cdVxMYPYvac69B1ocG5UbzdOovvZmuu66j87avU3fxjTSMncRvPnEIAN98YPWIIlDTG6LDf/k3sHJriq/f+/qQ4HqAlr5KSdiBwoAi/4Dt3bX3jdw6UbbB6f+s7NDvV/6VQJFAlsuekyMF2qbrY7o+z6/vGeHoO7iZLjrvvAYvn6LuwutHrDIRnkPPwz+h+NYLxI+6lPhRu/ae9u0SuRUPkl3xAH4pO/j7xHGf3GPlSXbJ3SArRA89Z9ixzOK7EK6z173Xe8N3zpk9WCFQqSCpUOHdETY0ssUsr7akSRVtCiUbfI14yMBQFYzqKYw+5CTaljzI9tMuww8n0RRIBBWCyUbc489mzTP3MeesT6NI5e/2AdeDHUl6GVb/37/i5t8RHw2MO5iqoy/DaJ5OQAZFknB9j5ILkig/MlUVNt37M7a8tpjPXP/fRMfPRpNlgko5Sx3QZC78jx9z+4++xv2/+i66sFn43W/TkAxh2i7Pr++mK2shEGiajCJLGKqM6Xj05i2qwvpeuxjsTN50aM+YbOrO8mZbBkVR8HHoyJSY1BAhEa7YBh4I+L4gVbSxPZ9kSCegKbiuywtvd5E2LTqyFuOqwwQNDYREznIwOzwSQZ2C7VAfDdCRsVBliRnNCRRJYUtvHkOWiRs6U+pB11WSIZW7lrYS0lSsPo+3tvcxuSlBquDwdkeOTMGm5HpUxQx6UkVan/gt25+/l9r5pzLxrC9Q8hSEJ7CFi5AltmRsim8vZ+t9N6Mmm6i/+EaUyK6/Q0tbXiXz0m1YrWsHf5epaqbxUz9D0QKogN5v55cIajSGDZJRnXTbBtpXL2bKaZ9CVgNEJBdD9ik6UOhsIfXGM4w/9jwi1fVkiw5Z0yIR1tB1jbCmoKoK1VGdcVVRAppMumCxsbfA2JogqZyDLCQakgG2p0oIITGtIVYWPz3AstdQ6cH+hzJ/bJLNPbsOBJVIkughCym++Tx256Yhx2QjRHT+2ZTeXozdvWW375Nbfj9601Sav/hH6i76DsFxB5Ndcg9tt3yO/OtP7FJcRI3VUH3KP2NueZWuu6/DtwrDxhgNk6i/9GYQgs7bruLIWJr5Y5N8/7E3h2XlB3i9NfM39WHv2GN929IWLvr1YjZ05Ucce87c5nc1d4UPJzv2Tw98pralRr6n9lSy7As4fmrdsCqmiKHstfCZ09tKx61fwy9mqb/4xhGDa9+x6L7vZopvvUDi+M+QOPryXX75WB0baPvN50m/8CeMpqnUnn8dzV/8I8EJh5Bd/sBuhYHcbDf5N54mMudk1OjQihE300nulceIzDkZrWrf3CshQ+Gyw8fwxRMmVYLrChX+BkzH482OLAXLIahJVEWDoMjUxQ2On17PIeOSnPXpr4Dw2fLUnwjKUtkyUAgk4THjpEuRkFj50J9QVagNgaIOzaQowPbXX8LN9zL6kptpuuJ/SBzzceyuzbT9+T/peehHSKUUvhBEDQVdldBVqI3pxAIGR176r4QSNfzvdZ8n9fZy5oyKUx0z2J4qsmRDLyWhcPRnb2Da0adz169/zPduuA7H9XhjW4q1bRnyllu28BMQ1hXChkpYV0kVbHrzNuYeWm92xPV8MiWbguXSmSnRnilStHwSIZ2IpjEuEeLoSTVccuhYovu4DL3C7vl7idZ5vhjSQ50pOWRKDr4vaM+U+jPQPr4Mzckw1RGdaEClOmrgA2FDoSdrUjTL3tYvb+whazrIioTtuvhAbcSgNmYgKzJTGxOcNL0R4ctl3RHZ4632FC+t7+KWv67nidVtPL66g7XbM1iuT6lksvqeH7Pl+XsZc9Q5jP3oF8k6CkJAwXJoSzu83WfTtfp5tiy6Ea12LPWXfXeXwbUQgu77bqbrzmtxM90kT/gMTZ+7hZqFV+H2bcfauIKgVL7nbR8sC9IFh9fbsrzVluee3/0SNRAmMe8MenI2JdNHVmQMHXpe/BOyZjDquEvIWS5I4PoQD+oITxDUVUZXhZlWH2Xu2CSzRic4fnoDC+eOYmZjknhQIxrUyJdcQprK+JoQsiwdkME1VDLY/1BWbk3tsY8zfth55F95lNRzf6D+om8PORY95GyyK+4n8/Kd1C78+i7ncLNdRA46FTVShRqpIjj+YOzOTfQ99Wt6H/s5hbXPUX3Gl1FjdcNeG5lzCpIWpOfhH9F5+9XUXXTDEA9sAL12HPWXf5/OO7/Jb77xabZ09LFaDC8PnVQXYWNXflA9/adPvc1XTpqyV4vxgUDKdn2U3WQZAc6Z28RVZ0zf45wV9g23LW3hsdXtnD6rcdDn/f1i52z1kk29g/3TA739J02v5/5X24a8boQkzjAkCeqiQ0sIZWDh3Gbuf2U7BXv396rVsYGuu64DSab+su+i1w0vu/atIl2LvoPVspqqU79IdO7pu51TSzai100gtvAbBEa985k2Rs2gtGkFwrWRtJHLHsuCR4L44RcMO5Z+4c9Iskz8yEt3+/4DHDO5htNnNQ5mp3/13Eae7Nc8GOCk6fV7NVeFChVGpjdvEQtqxIMam7oKTG+KMq4qRFUkSCygsqmnyJT6Gbx+xkW89PCdjDnuPPqkJLYLqi5oSNYx5eiPsu7FR6k58kKUcA1CgZAG2X6FJA+wMp2gaEhjZqNJEvEjLy6vJ5bcQ3bZvazdvJIxp18Js45DkiCgySRDATTVx5OrOeuq/+Hxn3yF313/RVLpm5l6xEnIkkxP3kaRBJKisvBLNzJldB2/+OlPaGnv5rTPXoMmSwjfR5UV5o+voj4SoDNnosgSfQWTkuMSNjWqwjqJ0O4zzr4vaEub2K5Hd84kb7sUTY+C41IVMmhIhJhQG2F8dRhN3XeikRV2j+v5vN2ZpSNj0RAzmNYYR36fSoN9X9CRKWH3KwQ3xoN4vo8qS6iyhO2WRXYT4QBTaiO8vj2D50EioJMruWRKFtGAiuULegsmYUPF88sVG0XTw/J84kEFy/XZ0lvCFwIhfHzhY2gyYV1mxZY+unMmnu9j+ZDJevRmbRoTBpZlsuaO79O75kVGHX8JocMup9t652+R6b8nc689Tt9f/gtj1AzqLvgWshHa5TVLkoRWO45E0xRi8xcOtm8qoQQATqaLoigHdx7vZFH7SjYdLevZ/upzTDrpcopSiJIDhgyS7ON0rif95mIOWngFY5vr0WQV23eJhXSQZGqiOuNqIyTDOrKsMLoqNKgM3pgIUrBc6mMGHdkS3TmTKfUxEuG9a0XdX6kE2P9A7t0LlV05ECG24ELSz/4fZsvrQzJiSjBKdP5ZZBffjX3kJei1Y0ecQ9IMhGMO+Z1eP4H6y75L/rXHST3zO9p+9y9Un/IFwv2KgjsSnn4Msh6g+/7vlvutL/7OsL4QLdlEw+Xfp+uu63jkh/9C7cKvE5q8YPD4OXObCBkqLX3FwYDohfU9LN3Uy4WHjOa8eaN2G2jvGEj5I5SEy8BJM+r5/HETK9mzvyO3LW0Z7Hd+YX0PLb2F921zY8dNFl2VufWKBSyYUI2uyjiuj6aWe/u/eMIk1nXkeLNfA0ACYkGVTGn3cj9nH9REV87asXocH7h1acsez83c+jpdi76DHIhSf/F3RswKe8UMXXdfj921iZqzvkpo2jF7nFc2wtRdeP2w3/uOCZK8S59sN9tD7rXHicw+CTU+dOPM7t5CYc2zxA47FzVWs8dzUBVp2EbYlcdN5IX13YPKv8dOruGnlxy8x7kqVKiwa3wEUUPn+GmNTGrIUxUKoGsS7SmToKZQHwugyXDU+Vew7In7WP3g72g+72skAiq27dGZt5h08mWse+FhOl+6m/FnfgFdFqTNci/mgHyjpBrgOeC7oJQX6LIeJHHsxwnPOJ70X37K5kU/JLNhOVMWfoGCHaLk2PTkPWTJIRQIMvdzP2b1H65h0Y+vYvbFX2XaMWchhEem5BHUFQxN5aJ/uZZoNMGtt/yMtq5uLvn372NLKodNSBAzVDpzJnnLZVtfgQ1dOZoSYUYlBJbjETXUweyX3e8LrsnljJqhKXiirKJueT5Zy8XqV2CP2ArV4QAN8QD1scABmz37R9Gbt1jbnsX3fN7sKCtLzx5dtU/7b8s2twLRL0wa1lVMpyzGFQvqlByTouMRD+mDQeDpsxqoiQaIGhq6WlbpLzmCoAYTayJkTQdPwKS6KDVhnaLtEdJVsqZFZ8bEdn3qYwZ9BYulm3rL88oy1VGdXMliW7osKKhQ3szv6kmx6e6byW1+laaTr0CZdw6lEa4ls3QR6Wd/R2DCfGrO/tpug+sBEkdfNsIfpTy73r/hPrDasQDFA0OCZQ/+L4oepP6IhWQEGArISlkXaf3j/0cgmmTOqZcyqTaO6fu0pYqMSWrUxgKMro5QHTZQFYWZTdEhtluSJBEJaIytCRMOaEyuj1MfO7CDa6gE2P9Q9vZxE513JrmVD5F69g80fPxHQ74wYocsJLfyITIv37HLLLaaaMDp2z78/SWZ6NzTCYw7mN6HfkTPQz/E3PoayZM+N0wELTjxUOouuoGue26g49avUX/Rd9Cqh2ap1WgN9Zd9j667v033fTdTdeqXiB50CgCPru7A9fxh5be2J7h1aQt3LN/GR6bVceVxE4FyQL2+M8er29KcNrOBk2c2oKsyluMPyURKUjlzdmUlsP6HcOfyocHnr57fxMkzG/apQvyiVa0IyvfLoNq3804FxK1XLBgcM8CN587m8t8uwXF9ZFnaY3B97OQaHnqtjRH2bvZIYd1L9Dz0Q7REU9neJjo8aHWz3XTddR1uppPac69BCcXpuvcGAs0z0OrGjagwvjvc3lbURP0u7byyS+8GIYgfcdGwY+lnf49khIjtoX97Um2YCbWRETet5o9NcusVC4ZUE1SoUOG9URcJ0JmxCEqCiXoMCdBVhZCmUbBd6uIa21Ml6uqbOOPSK7j//35BzZHnojZPwxaC2ngQLdFI3byT6HrlceqPvhA7UkOpv/hmQG1Y7Rckdfq2o+/kravVjKbmsh+QW3wXfS/dzivb3qLx7KsQ8iR8t9yHbdkWWUui+YLrKN51E6/f8UOsfJbpp16CokokwjphTSZV9Pj8V7+BEYnzu5/cQE9vH5/55i/IFiJ05CxqIwaO55M3XcYkw2zpK9BbtGiMBmlJFRlbVS7pzZkuRctlc08e1/VJhnUOGVdFUFPYnioS0RUimoLnCWY0JqiNGqgHoO/uBwHHFeRNl3TeoqfksL4rz4S6KLF9ZH/ZmzfZniqhyDL1MQNFLpdaF22foK6gyhKjkkEyJYec6eL6PjVhA03TmNYYJ1202dZboCdvkwiqFByXeEAjHlSxPUFAVYiHdeSiQ7pgsbo9S0TXiAUU1rRlMC0Py/FoT5sUbZes6WK671TJuYBSzPDW3ddjdW6k4Yx/Izb7xGHBtRCC9PN/ILvkHkLTjiE6/0x6HvzB37wmoH+Nr/Tf2ztS9MHp2krbq88x6cRLCUZiuJaHEIKSDeamVaQ2vc7hl/07kxqSNCQDaIrMlPowQpQdWA4bX01NNICmKEQCI4eOYUMjbFRaMQaoPIH+gZw3bxS6Wu7jUBVplzt8smaQOPoy7PZ1FNe9NOSYEooTnX8Wxbf+P3vnHWdHWbf977TT+/Zseq+QkAAJEDqE3nsTUEHF3uVRVBQbooj6SBPpvfcSSghJCCSBkN6TzWb76W36vH/M7iGb3YRQ1MfXvT4fcTMz555y5szc169c11u77cX21o9Bb9uIY/VPMpRYPXUX/pbIrHMofPAybfd8t19C7hsymfoLfo1j6rTd93201g19tpH8EerO+yW+4VNJvXgTmYUP4jhOhRhZu/ESsGyHV1a3c/bNCzn31kVc/9I6nny/ha3JEje/uZkfP7GCa06axAUHDsUjCYiALApcd9oUbrtkxsAE/9+E2n6ilLfM2/SJx9u5z37ptjTn3/Y29y1u4v7FTTz0bhOyJFYmiQs2dnHh7W+zri3PY8uaefCdpkofdg8B/Pax4xiW2HNEeHRNkAWbkp+IXOffe56uJ3+Dt240dRf+tl9ybSS303bv9zHzSWrPuRYEga5nbyAwZiaOZaLtWAu4ll17A8dx0FrW4akb3e96M9e5U/a6d9m22rSC8uYlRGee3UdVfGfIIvz2rH25dQ+/renD4gP91gMYwGcIryIxZXCUYYkgo2tDRHwKBc0g7JOpj3jJlUwEEUbXBjn/C18hFE3Q/uqdSI7DoKiPkCJRNBwmzrnQ1UWZ/whFA7pNcipWPt5699mhtazr9zgEUSJy8PnUXfAbLNNi673fpWnBc6R1h2IZkiVwbCg5PhrOvIbE5Nmse/YW1jx1K9gOLZkyW7oKbE8V6cipHHvuZVz0/V+zbdUybrv6MrY2t9CeKdOeVTEMm2RJoz2v0llQsUyL1kyRtkyZlTtSvLm2nWRBY1OywPLtGdoLGm9t6uKV1e34FYkxdWGqQj5qo37G1oepj/oGyPW/ETURL0GPSGtOI+GTKRk2BW1vzOL6wrYdOvIqW7uKpIs6juOwsb2AaloUNVfUriroxbQdBAFyqkFHXsXottuURYGCapJX3VrsRNDDoJgfSRII+xWqgh4EBwIehZjfS7y7NaGkuZ7tXQWDsmoR9cl4JJGqkJdIwINh25QNg4Jmomk2iiRSH5EJi+Bk2mi693voXU0MOuN/8E45Cm2X83Jsi9SLfyb39qOEph5HcOLhJJ//0yeeEwDkW9zP+OpG9bu+7Y0HEBUfYw4/i+qgl0FxP4mgh5qgwOaX7iBQ1UB06hw2Z1QCXpFBUR+1ET8HjKrhoDG1eBSZeNC7W3I9gL4YeAr9GzF9WJwHvjiT784Zx0NXzOIXp05GFgVEwSWPO9Pt4OSjUKqHknnz7j5EObL/6QgeH9m37u93P76h++AYKlr3D7A/CKJE/NBLqD3751iFFK13fZPS+oV9tvPUjaL+wt8hKD7aH7ya8tb3+2wjevzUnnkNwUlHkJ1/L6mX/xfH3jvhEtsBsx+ms6YtzzVPreCM/QbzwBWz+M6ccTx05ax/es/vAPaMnoqDnbG5q68Y3u7wzQffY+q1L/PNB9/rJVh2/m1vc+0zqyrCJeAGZybUh5kyOFqxgzJMmxdWtvbpw4YPCeDImt0r5QNs6iz2K8i3pwoTx3FIv3k3qZf/F/+oGdSe98t+CavWso62+36AYxnUX/BrfEMmY3RsIXrw+YSnHo93yCSsfBdmvgu7nK+MvSeYqWasQhLfsL4CagDZRQ+B47oQ9DnmN/6BFK4mPP3kPp9LBJTKOTsOles4gAEM4F8HryxRE/ZRF/EzviHClEExaqNeBseD1Md9jKkNEw96CIciHH/RV9ixZhnB5ApEQWBTV55UTqOgJIjtcwzJ915CzXWg4fZj9kBODEYKJVD7eX/vDN/gidRf9if8w6aSfOVvJJ/7A6qhotpgWq5vrs+jMOHs7zHs4JNZN/cBlt//WwxVw3EcWlMlmpNFUkWdoQfO4eTv3kDr9i38/MozWfLBCj7YnubxZdt5eUULjy5tYmVTmlUtWeav72LZtiTPfdDKxq4cy5rS5PIqpm2zoS1HvmzSni+zujVLfdTPiJogQ6sC1A6UhP/b4VUkZgyv4oBRCUbXhVAkAcOysHcjerszdNNiVXOal1e2saYlS041KKgmfkWktVvEDsG12TQdh1zZoD2n0pbTUETXp7momdjdxFQU+uqv+BSJxliAsFcho5rURnwMjvuQZQHVsPApEvURL4bpEPCITB4cA8G1oTpyfD37DokiCyIOuN7wHpFEyINHdCi2bGTHvd/FLueoPfeXKN1Z6J1zurah0fnkryl88DLRWeeSOPYqjK5tn2pOAKBuex+lZjjWLhpJAHrHFgrr3iI6/RR2lD2sai1SLFlEAl5SK98k27KZwUdfSsDvnvfm9hJFzSToVfApIpbtChIO4ONhgGD/m7FzFuiCA4fy0JWzOP+AoRw5vhZFEpAEN5skihKxwy7FTLdQWP5SrzEkf5jI/qdRWr8QrW1jn334hk8FUaa88Z2PPB7/yOk0XHojStUQOp/4Fel5d/Yhx0qikfqLrkeO1tHxyM8orp7XZxxBkqk68VtEDjyLwvsv0PnEddi62me7jwPThpvnbRrInP0fwvRhcb50aG8xry2dhb1SiP/mg+/x5PstZEoGT77fwpfuWfJhCbhps7w522t7B1eBfk1rDlkSkQRQZJHjJzfgkT/8964e61ceNgpFciddouD+rwe7Ez+TRIErDx1JTahvWZtjGSSf+wO5RQ8T2ncONWf8uF9f+fLmpbQ/eDWix0/9Rb/D0x1ZFhQfqZf/RnHNfJLP34itFsjMu4vU3FuxtdJHThBLGxYD9OutbWTaKHzwiitquEvvdWntW+it64nNvgixH2G0TNmoPHP6u44DGMAA/rWQJXfyLgoiOc1ARMAjyQyrChD2SUw/7kwitYOZd99f6MgWKGg2Hq+CTxGoOfhcECC74ME+4wqCgG/EdMqbl+KYxh6PQfJHqDnrGqKHXEhx1RvsuOe76OlWyrj9nZYFqi0y/tSvMuSIC2hZ+grL7voZm1qSbOzK8/r6Vt7blqIrW6Jx4oGc+IM/UyqVuPaKc3hj/nyWbk+RVy1SeY1kySBdMsmUNFrSZTozZbryGnG/jN8rMSjiI6+ZCI6rw9GcLNGaLSOLIt4BIbP/MxgSDzKqKkyyaOCTRHJlk2x5z/cZwJbOPAs2JlnTkuXp5Tt4rymFalgUNJPOvEquZKBIErIgEJAlwj6ZoEeiKqjQUdDIqwYRv4JPkYkHPJQMG59H6qMgPyjuZ+KgKDOGJWiIBliyJc0H2zPguCQ/HvRQHfYxbWiCfYfEOGBkNUdNrKOsm+xIl/F5ZGoCXkzbAkFwyWfrCrY98EMEyUP9hdf3EiU1cftxbbVAx8PXUN6wmPhRVxA79GIEQfjUcwJLLaBuX4V/5Ix+12feug/BGyR4wOlkTSiaUNA1IpLNuhf+Qe3IiexzyByqAh6CHpnmTJHWnIZuugGHhqiPqtCAh/zHxQDB/j+Ix5Y1M3dNO5YDR02o49z9h+Lg9kF7h0wms+B+bK23FVFk/9MQfSGy8+/tM57oDeAbvi/FtW/tVSRMjtRSf8FvCe17HLm3H6XjkZ9hlXt7dcvhKuov+A3exvF0PXM9uXef6jOOIIjED7+UxDFfprxpCe0P/Air8PGsuXbFa2s7Pra91wD+ufjhCRN6eZtbzkcL+C3dlubFVW29lnUWdGynd+ZYAPYdHGXfwdEKGbZsh7OmD+bbx47jvi/M5IIDh1bKwe/7wsx++4UfvMKtdpBEoZc3u4NLuD2SwIUHDuVXp0/he3PG8cVDRnDrm5vpLOi9xrLVAh2P/JTiqteJzr6IxJyv9tsHXVjxKh2PXYscbyQx56uYuc7KusiMU4gfcRnajjWEp59E7Vk/dUmvx09p4+I9XjeA0rq38NSP6Vf1P7vgAQRRIjqrd3+1Yxlk3rwLpWY4wUlH9Dtuj03Z7q7jAAYwgH899G4rqlzZIFPS8UiQCHlZ05pjS0Zl8DGXUmjbyvZ3XkY3TZIFlbzqIEdqCE89nsKKuf22fAXGHYSjlyhvfe8jj0EQRGIHn0/t2T/DynfRdtc3KW9eCtCtiQFdeYPAARdQfexXyG1Ywvq7/4dSPsOmDpX3mpJsaCuxuiWNkxjNzK/dhOgPc+dPrmDL26/SmVfRDIuCqtOaLlE2TDTDRpJF2rIammkxrCrI0JoQ04bGkCWJrrxOVchDWbfIlY1/mS3UAD4asixSG/GQCCrUR/3kywZFffdl4qZlkyxoNKVU2nMlUmWNpq4S25JFyrpJa0bFMG1Kho1fERldF2ZiY4yqkBfVsPEpMsMSARpjAaq7iWA86GF4VYCGqL9P+6VXlhgU8xPxK2RVHQQH2/W7Q5EldMuhPuqlIeanMR5gUNRPqqizYFMHOdUgoIjsyKq0ZXUUETqXvcRrf/kR4drBDP/c71Gqh/TanwWouS467v8BWss6YrMvRtlJlPjTzgnK6xeBbREYO6vPOq1lHeUNbxM54HQkn1vNJwBlAxY8ex/ldAeXf+Nqpg2vIhbyMr4+xMQhMUbUhNBtm7JmEfTKA5UhnwADBPv/GHZWy7Zsh9fWdhD2uj0PgiAQP+Jy7FKW7OJHe31O9AaJHHgm5c1LUHcynu9BcMJhWLkOtB191/UHQVaoOu6rJI77Gur2FbTd9c0+Pd6iL0TdOdcSGHsQ6dduI/Xa7f32jIT3O5GaM/4HI9lE6z3f/kjf7j3BcRweW9bcyw/74r8v5v69UHoewD8HS7elWd/eOwDz2pr23X4nPaXgPQrUu2JU7Ycl3Q5QF/Fx7v5D8SofZqnP3G9wryqGj6pqmD4sTmPMj9FP+8HBo6t54IpZXHf6FC44cCgzR1Zx2/zN7Hp0RqaNtnu/h7p9NVUnfpvYQef1eek4jkN20cMkn/8jviGTCU48nM4nrqPw3guozWsq24WnHo/g8SF0K/j29EqLHn+/x185huR29LaNBCcc2ndd13aKq14nPO3EPr3g+WXPY2baiB9+Wa+AwK5lXw4MVIcMYAD/hyAJgivypZuIokBOM9ANG8tx0HWomnwIgcZx7HjtXgRDxy+LNMS9+BQYfMg5CLJCpp/2Mf/wqYi+MMXVb+z1sfhHTqf+czciRWroeORnZN9+hJzlkLOgSwcVCE47gZrTfoTWvpnVt32PfGcbpgmG4yp9C46Dv2oQk774e6KDx7Dy/l/S9tYj2N2WSFGfTEARiYa8VIe8DE8EiPp9lA2H1rTbXzu5McrQeIDqiBdwrZrWt+XpyJVJFzVaMyWKqoFqWOjG3pUnD+Czg207JEsGsiTRkVdpy6m0ZEps6XRF6nZFV0EnrxrUhBSKuo2qW0R8Eo4j4JUl/B43W62bNpmSiWm5qvJ1ER8Rv0LYJxMLeFxNo53eyR9FCm3bYXuyRFdBI1s06CzoJIIehlUFqQ77Kgr2mztz3R7uFsmCztKmNDgOiaDMiqdu4eVbr2PU1Jmc/IO/EUsk+uxH79xK273fxch2EJt2HNm3H/7M5gQAxdWvI8ca8DSM7bMu8+Y9iP4IkemnVJaZgKhl2Tj3AUbtN5vj5xzNyVMb+c4x4zl53yGMSIRwHIdC2cKnDNDET4rP7MoJgiAJgvCeIAjPflZj/jdi5sgqxJ0eCpbt9Mr0eRvGEphwGPl3n+yVFQMI73cyUjBOZt5dfaK5gbGzEBQfhQ/mfqzjCe87h/rzf4NjGbTd8x2Ka9/qtV6QPVSf+gNX6fzdJ+l6+nocU+8zTmD0gdRd8FuwLdru/V4l+v1xIOD2pj+yZDu/f2kd59yyiKufWMH8DV1c/cQKjr/xTU79y1v9Erv+iPjOgloD+OR4e3OSXecvbTmNq59Y0e930RNE2h1Shd6SIK+sbufaZ1dxzUmT+Pax47jmpEm8vTnZ53v7qO8z3o+vqiT2tqD6zfNr+Pxd7/YRPFOb19B2z3ewimnqzr2W0OQj+4zl2Bapl/9K5s27CU48nJqzrgHbJH7E5fiGT6W05g3MXEdl+8DoAymtmU9p3UIKH7yM3rYBqR+RtJ1RWPEqCCLBiYf3WZeZfw+C4iUys7fvtaUWyC58EN/waZWy8pqQh9E1QSYNivTadlcf8AEM4NNgYF7w6SFLInURHwGPRCLgQRQECppJUJEIekSCisK4k67AyKcovfc0siTikWUiPpFYdRXVM06htGYeesfmXuMKkkJw4qGU1i/CUgt7fTxKrJ76i35PYPwhZObdRdfTv+vT/hUYdxAN5/0Sq5Sl+d7vkmnagGHZKJJEV7FMumQQiMaZ/Plf0zjtCLa+dAfbn7sJwbJRbQvNEqgNuPZAM4YnqIt60SwLG5uCZrAjUyLgFREFkXzZYHu6RE4zWLo1xTtbkqxtzfPa2g6WbUszf2Mna9tyqLrVa15UUE3WtGbZ0J5DN21My6Ytq7IjXUI1LCzboaxbe3xXDaB/WI6rxj00ESTilykbNpbl0JQqsT3tVl/2BD1s20EzLbyyRH0swOyxNUweHKUqqCCLsGRrivXtWfKaScQnYzsOedWkNVt2Kzs9UncJuU5LplzRU7Fth0xJJ1XUMXajrisKMLw6QHXER2M8wMjaEPGAa/mlmzYlzeD1te0890ErK5ozNIR9pIoqmZKBbarM+9uPWfni/exzzFmc9K3riYTDRPwKO8801G0f0HbfD8CxGXTBr5GCCRKf4ZzAzLajbltBcNIRfQIK5W3LUbe9z7DDzyEeDOABAgJU+6DtzYcxNZVzv/I9NyiVLrG2PUdnXqU25CXgERlVG6Q++tEEfwD947OUg/sGsAaIfNSGA9g9pg+Lc+2pk7nmqZVYtoMDbE32LgePH3YJpfULycy/h+oTv11ZLnp8RA86j9Qrf0PdvAT/qP13WucnOOFQimvmYR/1xb3y2uuBt3E89Z+7ka4nfkXXU79B7ziH2CEXVjJhgigRP/pKpEgtmTfuoL2QpOaMHyP5e98K3vrR1F98Ax2PXUvHoz8nftQXCe93Uu+II/33xIL7MBxdG2J1q5st3VWYqsf3eHnzCt5Y11GxF/rN82u4+U13ctHj1ZzTTB5d2oxpub1c15w0iXRJH7Ac+gSYObIKryKiG3afrO8LK1sZVx+u2Dm9sqqNJ9/fgQBIgvtd70rOU6XevVoOrnhZz/ezqxf29GHxfj2yd/0en3yvb9n69KGxSjn7K6vaKvfJziisep3kC39CjtRQe+ZP+9jTAdh6ma6nfkt58xLCB5zhZooFgfCMUxBkD0bnVsxsO/n3XyJ20LkIsgdP7UjC00+iuGYejmVQfdqPUGL1u73OjmVQWDEX/+gDkEK9z01rWUdp/UKiB1+AtIvISXbhg9hqgfgRl1WWdRZ0t/y980NBOo8scsZ+fc9tAAP4FBiYF3wGqIv4sG2HgmZiO1DWLRRFpLEqiGnDxEGHUlh2OGteeZDEtGPRQ1WUTYchiRDRky/i5WXPk5l3N7Vn/6zXuKF9jiW/7DmKK18jMuOU/nfeD0SPj+pTvk+ufhSZeXdjJLdTc8aPez2/lMGTqL/oejoe/inb7v8h4hnfZ8yMQ0mVDETBRkBAcySGnPk95MQgtr16H4vSbcy47OcMbqgj7Pdy+MQaasN+cmWdTFFHNy1qwj66ciq2I5AsqggIZEsGIY9MqmxQH/YR9sm0ZEropokkSGSKOVIFjWHVIWpCXgrdZNznkZAEkYJqEg96oNuSqCVTRhaFijp1Tci1/fJIIuJn6Of8/ysUSSQeUMipAmXdA2gUdROvLJHXTFoyJTTTxrQsmlMlMiWDuoiXuliAmrCHNa15SiZIgoPPI1IfD9CR1akKWlRHPPgUiaJuYlouORcR8HslipqJYdlIokSqqJEtG0iiSFEzGRz3IwgChmVT0kwcB9a15VjdnMfrEWms8hENKKxtzWPYDjVhH6Zpsqo5Q1EzaUkX6cireEQYJBe455ffILtjE/ud/XVmnnw+zTmDgmYg4eqYeBxIrXyN5As3IccaqDv350iRWgKJwZ/ZnACgsPxlAEL7HN1rueM4ZObdhRyuITTtBHQTqvzg93kwMztoWfwsY2efwogx42nJqNREfJTzKs0Zleqgh+GJEBMbIgOK/J8CnwnBFgRhMHAicB3w7Y/YfAAfgR5l7Fvf3MS2ZKkP4ZSjdURmnEpu8aOEp59SsdwACO17LLl3nyA97y58I/brVQ4amno8hQ9eprDyVSL9qAjvCXIoQd35vyb1yt/ILXoYo2ML1Sd/F9EbBNxSnOiBZyBHauh67g+03fNdas/+GUp8UO9xIjXUX/g7up69gfTcWzA6t5E45spKWcypUwfx1Pst/ZJsy6FCrj8KL69u5431nfzs5EncOr83abp1/mYc50Mirxk2P35yBY4DiiRw+LhaqsNeztxv8ADZ3gv0WGI9tqyZh97d3ivwMakhUiG+giD0WpcIKH3I9M6oj3hJlQwsy66Ibu3cQtGjGD59WHy3y5duS3PLvE2815Tu008tCPDO1jTvbE3z8LtNfe45x7bIvHkPucWP4h06hZrTftQnaARg5rvoePRajM6thKadgNG+kdyih1FqhhEYMxMAT+1I7HKB0vqF5N9/gciMU11xlImH4x8zq1/RsV1RWrcQu5QhvO9xvY/TcUjPuwsxECWy/2m91hnpVvJLnyU45Wg8tb3F6HbF5QcNH7jfB/CZYWBe8NnBzai676yybpHSNZqSKrVhH2GvTNAjY37+O/z5qtNZ98KdDD/5axg2eGQBXfQTn3U2yTfuRG1agW/olMq4nrpReBrGkn/vecLTT/5YfZbuO/8sPDUj6Hr6d7Td/W2qT/0B/mH7VrZRqoZUguqbH74ORUsRn3EyCBaqAdVBLyXLZsIJlxOpG8rKh37P/D9+hUO+/Cuqw1PwIBJURFa3lGhKFtmWKlEV9uIVRRpifrYni4ysDaIZCm15lfqwl3RJJ1nQUCQR1bApGTrZkonpWDTG/GxLFZEQKGpmd7WgRbqko1sOflmkNupzlahth6qQl3RRY0e6THXIS8Aj0RjzIw2Qjo9EddhHwOMSWYDtqQJeRWJ4tZ+ibhHxyszd2EW6pGGYNus7Chw3SWHNjiw1QQ9pETZ2FmmIBpAlAweHoYkgBd2iqJkosthdEg4pR3e/T1FA7g6AaKaNT5GQRPe77vkNbe0ssLY9R7as055VmTQoSntWw6dIlFSLom7hVSQyRY0dmSLvNWewLIu2rMZw0yLTtJYXb/oBhlrm8K/8murJM+nI6yiSRMwnY8qQLDlsfe1BUm/dh1w9DNkXorjydcI1wxA/wzmBYxkUPngZ/6gZfTRZSusXoreuZ9hJ38C0PCgyiJJIJCDz/oN34PH6uPiq76BIEkMTfjKqRmtGQxIFGmMBV6VdNQkM+Fp/YnxWGewbge8DuzVWFQThCuAKgKFDB6yV9oSl29Jc++wqNMPebTY3OutsCiteIf3a7dSd/+vKi1GQFGKzL6brmesprn6D0OSjKp/xNoxxX6bLniO834kIwsd7SQiyQuK4r+GpG0Xq1Vtpvfvb1J7xY5SqDwUdghNmI4Wr6Hz8l7Td811qTr8a35DJvcYRPX5qTr+azJt3k3v7UTf6fdqPkIKx3ZLr/iACgij0a7EErhL1Q+82sav2ya6bO1DZRrccXl7dDsCDi5s4emJdJRM+gN2jh+TuXIJ37MQ6cpr54X28yxexJ3INcNrURo6ZVF/Jfvd8Bx5ZxDDtXkrXM0dWVZZLokBLpsz9i5v4yVMrduu9vvPh7FoFaGtFup75PeVN7xKaehyJo7+EIPV9XOrtm+h49FpsvURs9sUUV79ObPbFOKZO8sU/Iyg+/MOnusc9aBy2VkBrXkNq7i0YqR1Un/y9PfpR74z8smeRYw34Ru7Xa7m6eSla0wfEj76yT2VK5o1/IEgysUMvrizbXZXIqtbcXh3HAAawl7iRPcwLBuYEe49c2cCrCPgUD+1ZlbZMCRHY2pHHEQWCioKvqpFBs05hx1tP4J1yItVDR9KVU0nmdQL7nUxmyTOk37iT+ot/34tIh/c7keRzf0Td+j7+EdM+L8Q48gABAABJREFU9rH5R06n/pI/0Pn4dXQ89BPiR36e8PRTKvuQQnHqzv81Xc9cz7on/8rgzh2MPvEKNGx8MliOREHVCI0/hImXV7P+get444arqP3Or9k4OE5G1UkXdddJxe+lMep3s9aqgd8r45cVBsVkRNEtp2+MgdatftyeVdnUlUeRRSxLoCWrUh3y4vWKDKsOsrGrgOAIjK4Osi1VpKugMW1YFYokUDBM2ltUyrpF0C/RlHLJYqakM7o2TMA74Af8UfApEoosUB/1EfTJVPk95FSdlqxGmwjrdmTIaQaaaaMaFgFZoDOvEvB6EIHakI/RdSFkSSIeUEAQaIz5MW0HRRKRRAFJdAXLTNvBK4uVjGssoNCe07BsB910y/878hpdBZVC2SBfNujIafiVAsmCzo5MEa8ikC9ZIDgMSgToyusUVZNUsQw2bFrwLPPvuQFvtJqZX/8tw8eOQ7NAsw0KhkXII2EbFtue+gOp5a/jGz4NO99FdP/TkEydHS/+merPcE5QWrcAq5gmPO1EwLUDswHTMsm8eTdK1VAi+xxJ2XIrBn2KhLdzFdven8+cS77B+BGDGV8fQZIk7E6HqpCJJLktKFUhLz55IJD0afCpr54gCCcBHY7j7LGp1nGcWx3HmeE4zoyamppPu9v/r9GTjdsT0RS9QWKHXIi2faWrILgTAhNm46kfTWb+vX36oSMzTsFMNX+iHmhwo9bh/U6k7rzrsNUCrXd/h9Iu9l++wROpv/gGRH+E9gd/TGHlq/2MIxI/7FKqT/4eetsGWu/6Flrrhj2esySCLAmV8uKjJ9ZVCJ0ouBnRXVEb8eFVRITubaYO7usRuLuYvY2bCT/n5oUDImp7gR6S6z7IRQ4fV8ujS5v3OmCyMwQg7Ff6iJf1ZMt3VbruWX7uAUNBEHjgnSZ+/MTuyfWeYHRtp/Xu71DesozEsV+has5X+yXXpQ1vu71Vgkj9hb/DN3QynvrRBMbOIjjxMBJHXUH69b9jZt2Ajah4CYw9CL1zK6V1Cwnvd9Jev0i1lnUVhdGdA2OObZGedydyrJ7w1N6ZbXX7SkrrFxI58EzkkCu8Iotw3elT2H9434DR8ZMb9voaDWAAe8LezAsG5gR7D68sUtTczF1j3EcsqBDyiaimjW2YlE2D5q4igw8/D9EXouO129Eth2RJJ2+6z57Y7AvRW9dRWreg19jB8YciBmPkljz5iY9PSTRSf/Hv8Y8+gPSrt5F8/sZecw/R46Pm9KupnXkqzQueZM29PyfhcTAs1xZJNx3KlkNo6DgmfPEPBGqG8MCvvskdN/2OFc1JOnIqBc0AARJBL/VRP35ZZEJ9hJqIl2hAYUg8iCJJRAIKVWEvQZ/MvkNjjKgOsW9jjKqggiKJDE74sR2IB70cOqqGQ8dWs7rV7T31KBItmSJ+RWJ4IkjQJ9MQ8+OVZZI5jWRRZ21bjkWbuyioH2079d8OURSoi/jweyQGRf14PCKSJJIIeliyNYUku3ot21JlyobN5q4SiiSRLmqolsmw6iDDEkGGJQIEPDKS4CZIesh1D3yKRMgro+xUWRD0KgxJBCr6BYoksiNVYluyxLZ0kS2dBUqayZItKcAl4du7VDanSqQLGutbs2RKOoroIDgOG575X9688zfER07msO/+DX/NYBxEFAmqgj68soRezjH3z98iudx1GYkefAFKwxgiY2dRNfkwaj/DOYHjOOTefQo50Uh81H4MCUF1ACIeKH/wMmZqB0OO/hy27VaxSiL4FHjtrptI1A3i8LMupaugUxvxMzjh5/AJdVwwcwSzRtUwrj7MPoNjRPx9dWsGsPf4LMITBwOnCIKwFXgQOFIQhL5eUQPYa8wcWYUsukRyTwVboX3noFQPJf3GHb28LAVBJHb4ZVi5TnJLn+n1mcC4g5HC1eTeeeJTHaNvyGQaPvdHlHgDnY/9gsyCB3opiCvxBuov/j2+IRNJPvdH10+7H4Xx4MTDqLvwdyBA233fp7CirwibJNJNkD+8GpLkErgeQueRRb47Zzwe6cNtZEngS4eN4r4vzOS7c8bxyJcOIuzvS8IlUUAU6I6G9j1Xy4Frnlo5IIb2EdiZ/F5z0iReWNmK8TEFYnqCJ15FJB7w9CtatjvF8B6lcFdhlD794HuD0vqFtN7zbWw1T9151xGedkKfbRzHIbv4UTofvw6legj1l9yAp3ZE5f62Sq5/d3DiYfhH7EfXMzdUPpt79ynMXAf1n/sjgdEH7PVx5d55AsEbJDTlmF7Li6tex+jcSuzQSyptFtBNvF+9DSlcTeSA0yvL4wEP6ZLOD4+fwGNfPohjJtax7+Aov+pWUB/AAD4jDMwLPkMokkhRM2jNqpQ1i0zZYHNHCVkUyJZtdN0iGlDwhaLUHnYB5W0fkFmzGNN2LYIAgpOPQqkeSubNu3CsneYLskJ42omom5eid277xMcoegPUnH410YMvoLjyVdru+0EvIVZBlAge9kXqj/0KHWvf5fXfX0X7jibAwadImAZoOlQ1DGLaF3/PoOnHMO/hW/nTD77E+qYWOnJlGuNe6iNexjVEGFsfRbccOvIqmmHTUdAIKkJ3b69DdchLxK8wvj5MPOhhRE2IMXVhgl6FhoiPhqiPhniAqoCXkM+1k0oEFSxb6BbSMkgEFRJhD2GvTMAroUgQ8kjops2WruLuL8YAKvDKEjVhHwGvRKaokyrpqLqFYZmE/V4c2yboFfHJAkGvjKwIlAyLzoLBgvUdFFQNjyySCCiops32VIntqRKaaX3kvhVJxO+RQIDNHXk6C2VEwUYRoSrsZcqQGA0xPx5JIFky0BwLwbEJ+71ohpuUEdQcK+/4AS2LnqJ+1qmMvfBnJFWZZNEkX9LZ2lUiWdJIN2/kmV98gXzzBupO+T6xg86jZxZSLmXJ2eCbeBihz2hOoG1fid62gdoDTsXvFakK+vB7FaKKSXrB/QSHTqJ68iwsQBbAL0PzO6/Q1bSBQy74Bm1Fi+ZUgXVteVrSZVTDwu+RGFcfYd+hCWoivgG9gU+JT02wHcf5keM4gx3HGQ6cB7zmOM5Fn/rI/gvRo4L8yqo2zG6Bsz1l/wRRIn7kFzEzbeSW9vah9g/bF//IGWQXPYxV/rD0U5AUwtNPQWv6AK1l3ac6XjlSS92FvyM46XCyb91H55O/7uXPLflC1J59LaGprp925+PX9fHvBlf8rOFzN+IbPJHk8zeSfOkvvQIGPb1nhuVgWu51sSxX9GrnbOYFBw7lgW6/4wsPHMpDV8zqQ8J2zdKJwEn7NHDw6Gp+cepkHr7yII6dWMeuzxXLdnh7c/JTXa//BkwfFmfmyCqufXYVCzZ2fezstQNMboxy6azhXPvsKm54eR0X3v72Xgc3eoJTHxeObZF+4046n/gVStUQGj73pz6tDYBb+v38H8m8cSeB8YdQd/5vkILuPeYbPAlbK5HfKagVP/wykGRKG94GIDzjZBq/eEslo7w3MNItlNYvJDz1+F4l4LahknnzHjwNYwmMn93rM8WVr6K3byJ++KWIiq+yvLOgV64pwG2XzOCprx4yQK4H8JliYF7w6aFqJhvbc2xqz7F8e5pC2SRX0kkWNYolg3BAJuARWduW4a1NXezIFAgrIuNmn4SvZghtr/29d+BdlIgffhlmupX8ey/02ld4vxMRFC+5dx7/VMcsCCKxQy6g5owfY6Saab3rW6jbV1bW24B32gkMOftatGwny/72TTZ/8B7pvIlXhnBAwjEtdEckevTXaTz+S7SueZfHfnY5W9evZlVzjmXbUizfnqIpWWRLV4HWjEpONcBxkCWJ4VUBhlUF8CkSgiBQE/Ezpi7MsOog8aCHgmrQnC3TmlXJlnQURWLq4BjZssmOtErUL2HZNo4jIIsStUEv4xsiHDS6hoAi05Yt05wqsb4tR6bU1zFlAH3hOA4dOQ1ZEpBFkaKmUxP2opo2iYDCpPoYNWE/jbEAEgKdBY2unMqKlizz1naSKuqYNhQ1i5BPRhAgX967CgJJAJ8ssT1dQhJEiqqFYUFdJEBZMwl5ZRAFaoKul7ZfETFtdz9t6z/g6WsvI711LTM+92PGnXIVmi1RttwWs7ZcmaJq8MGCN3j9D1/FMnQazv813u73cX9zgqrPYE4AkFv8GGIgSnDiUa6mkAg1kQDtCx7FLGYYftzlaIaNT4GgBxRHZ93zf6d29BRGzDgcx3YI+TykSlp3Gbw+4CX/GWOgwP7/AJZuS3POzQs5828Luf6lddz85uY+fcK7g3/ENPyj9ie78CGsQm8SEjv8Mhy9THbBA72Wh6ceh+gN9vHS/iQQFS9VJ36b+JFfoLxhMW33fAcj+aFasyDJJI69ivjRV1Le9K7rBZhu7TOOFIhSe861RA48i8L7L9J23/cwMm19tqtsLwqVvtxds5mNMT9ndAuU9ahL95CKcfVhfnX6FEbXhhBxX/hPvt/Cgo1dXPvsKgBuvWQGj3zpICY2fFiq47D3D/T/duwsONaj/P5xsLw5y63zN6MZvUXL9gbr2vIkgh+vrMkqpGl/6MfkFj9KaOpx1F/wW+RIX2sMM5+k7f4fUlz5Gv7RM/EOnYKjl2CnyozE0VdS2riY4uo3sLutbzzVQxD9bmvCx9U9APdFiigR3kXlN/fuk1iFJPEjLu/VU2lrJdJv3o130HgCEw7rM57tuMJ+jy/rq6o+gAEM4N+PjmyZx9/bzksr23hxZSsLNnbSlC6xvjPP9oxKumzgkSW2dRXRLRufJNJRMLFskDxexp54JUa6lc7FT+PlQ7Ed38gZ+IbtU3EW6IHkjxDa51iKq9/AzHb0e0wfB4ExM2m4+A+IviDtD/4PuaXP9Jq8C8OnMuziPyD7I2x78H9IffACsYDM0HgIQRTQHbARCE89idEX/wZDU3n+V1fy/qtPklEtVu/IsaI5Ta5sIIhQ1E2KuknAKyMIQuV5WNYtmlJF2vJuLzVAumTglyWCHol0tw7I8NoQ+4+s4oTJ9YyuCZMtm1SFPPgViXK3WFZNxMeBI6sI+jyMqg3TEPOzI1P+1NfqvwGOA6Zl45UlasNeRtdGaIgGCckSjdVhhtX4OWpCHSfs08Dkxjh+WcICHAS2Z9yM9dZkEdUwKWgGhuX0KgffHWzbZk1rjpU7shQ1k6JhuuTZqyAJDh5ZpDriIajISKKIqhnkNYu2nMq7z97DQ9degaB4mPm1P+EdezCy0F0RIrpzwqxq0TL/YTY+cC2CL0x4+ilI0Vpw7Aq52nlOYKkFBD79nEDv2Ow6lkw/GUvxIgLZsklbcxMbXnuEMQcezZzDZzM0EaQq7CES8pJa8AhaPs2Mc75BRnUId3uLt2VVVrdl6cyr5FVjgGR/hvhMCbbjOG84jnPSZznm/+9Yui3Nubcu4p2tn7z8OH7kF3BMg/Sbd/da7qkZ5tpwvPc8RmpHZbnoDRCefjLl9YvQO7d+4v32QBAEIvufRt25ru9l693frkTnKuunn0ztOddiFVK03f0tyluW9R1HlIgffqkb/U630nbnN3qNszPOnjGkT3Z6VzK9dFu6X3XpCw4cyunTGnvV3+9K5KYPi1MV6q3iuGggg71X6OnFFgW3rP/o8bVUh3uTXgG48MChVIf6J8M9ASYRKmJmO/tc9+d5ff/iJq5+YgVtOa3fMftDedtyWu78GnrLeqpO/Jbbby33bSPQdqyh7e5voXduwzNoPI5RxmjbSOq128m/9zyO7U7e5EgNiaO+SGndQtLz7iL50l8pb1m2131Vu8LMdVFY8SqhfY7pFeG2Cmlybz9KYOxBfTLt2UUPYRczxI++YreqwA7wwDtNA9oCA/inY2Be8PGgGhZbUkVU00Y1bdpzKpIkkVcNJCDhkxFFAcvQacmp5FUT3bLBEVAtG1kUSYzfn/Do/Wmf/yB2MY0XCEkgCQLxI76AXc6TXfRwr/1GDjgDEMgufuwzOQ+leggNl/wB/8jppOfeQvL5P2IbHz6brUQjNRf9nsjIqex47q+sfOxGtrUnyZVNkiqoQNkGuX4C+3/9rwSGjGf+P37Fbb/4LovW7WDZtiRLtnSwPVkkEfQyvCqEt1vosockpEs6XkkkqEikizp2tziWalqoho1H/vD5GPHKZFSTdNkg4JUp6e519ckimmmhmRYxv4dBMdc2rCVTxjugJv6RsG0H1bQIemQ6CxrNqRIF3cBybLyKzOB4gKDXy4wRVQyrDjFzdDXjGsKEPTKNMT8Orm3au5u7KOomubJBzC8T8SsUVINkQSNXNmjNlmnJlHuVjudVi66CRm3Ei2E5bEsWWd1SoKyZdBV1Aj6Z5rRGa05lU2eelTuyNO1o5YU//YC37v8TVRMOZPa3/0ZsiOvUYzoOkgCSA4KpsuOJ39H62t1I4SrkWB1mppXka7ejv/c8QfqfE7S+9FfUjzkn2PUuyy58GMHjJ7LfSVhAQYcdGZOVz92B49hMPvmLmJZDfdTPuPoIk4Ml1rz6COMOOZFxU/bFr4gEfQqSIOBYNls7i7y7Jcm8tR1s6MhVPMoH8OkwIIP4b0IP+duRKWNan+5mVhKNRGacQu6dxwlPOx5vw9jKutjsCymumUf69TuoPfMnleXhGaeQW/IU2YUPUXPqDz7V/nvgG7YPDZfeSOcTv6Lz8V8SnXUu0UMuqFiF+YdPpf5zN9L5+C/peORnxA67hMgBZ/YhAYExM2m49E90PfUbOh//JeEZpxI/7NIK8REFCHtlDrjuFfKqyZxJ9dx43rReZFrvJss7q0vvrDq9azZawO3r7iFyb29OMqkhwvwNXZVtVrXmWLotPaAozof3b8/13PXvS2cN5/a3tmDaTr/e0sdMrOO606cwaVCUq59Y0e8+HNxKhWtOmgTA+be9jWHaFaE703Yqntfg2trtLRzbIrvwQbILHkRONFJz7i/x1Azvu53jUFj+EqlXbkaOVFN37g/JL3mamtN+CEBpw2LUbe9TeP8Fwvu5HMI3dB+kUBVmuhVtxxrqL/7DJybYucWPAA7RA8/qtTz95t04lkns8Et7LTdSO8i9+xTByUf3eg70B9uB/+m+9gMl4gMYwL8PjuNQNiwcx30XyYKIX5HocDQEYHDMR0m3CXhFcpqBaVpsSqrUhDyUyho5zWRcQ5RyWccGdMOi+ogvULjjKpLz72LCWd9EliSCuk25cSS5KUeRW/I0oanHo8Tdtik5UkNoylEUPniJ6Myz+63i+bgQvUFqzvgx2QUPkl3wAHrHVmpOv7ri7Sv6QsROvwbPgntoX/go+datVJ3yI8RwVWUMFbA8YYac/QvSix6gaf5DdG5dQ/rin9A4Yiy1ERVJgO3JPIs2JUEUmdwYZcawBIok0JnXyJQMFElkUMxPTdhLriziABGfjGHZWLZDV0Fje7oEDgyr8qObNrGAggPsSJcRBFfDoi7iQxZFvIqAVx7oUe1R6S4bJrmyiSS4SRXVtAgqEmXTQjccHMHBshxqIl42d5WwbQGfRyARVKiL+vF0K1Yngl4+P3sUr63uoDVXdj3KAwp5SSDoVZAlAUkQKKgmTakikiDQWVRpjAXwSBKdeY3B8UC3zVaZroKrW2BaNiOqAiTzBu0FFa8k4RVERNGh2q+wqd1k/crlbHnsdxj5LobM+TwjDj8bj1/BcQTqwj5KuoWDQ6lrByvv+QVqZxOJg8/H6Gqi4bQfIQhQWr+Y4rb3aVm25zmB+DHmBDvryeid2yitW0Bk1tmIPrc6UAe01g2klr/KoNlnYodrEQSBEdU+6qIBbv7xj/F6vZz8hW8T9XsJVykcMCpBoWzRli0jlw22JYt05FQ6CyqKKDOiJvhZ3SL/tRgg2P8G9GRaddNGltxM36cNGEUPOo/CqtdIzb2F+ouur5SdSME40VnnkJl3F+Wt71fsASR/hPD0k8ktegT9oHP7JRefBHKklvoLf0fy5b+RXfQQWut6qk/5XsU/WInVU3/R9SRfuInMG3e6mcMTvtnHXkiJ1VN/4fWk37iD/JKn0LavpPqU76MkGhEEepG2J99vIVXUOW5yQ+U62o77MuwR3tqZAH7x7iW80m3FBW500AFsx+HmeZuYt74T07LxyCIHDI9Xqguc7j7s/2aCvXRbmseXNfPIku2YtuNaYjgOpu10B0ocbNsNguzulpZEgSsPGwW4xK4pWeyXhINLon/1/Gr8iissA24vfg8M0+axZc08urS5sv6jYOY66Xr2BrTtKwlOPpLEMV9G9Pj7bOeYOqlXbqbwwcv4RuxXsc9Ivfy/FFa9TmjSEfiGT8WxTdRty1GbV+MbPBFbKyFHa1ESjfhHzdirY+r/OLvIL3+J0JSjkaMfelzq7ZsorphLeP9T+/jMp1+7HUFWiB/2uV7LJcF9Se9a/eXgCviNqw//V9/XAxjAvxOZkkGyqCEI4JclGuN+MiUdn+T2EIuigGHZmLaDZliUdYs1HQW8iofRDXEyBY3RNX7WtpgUDQvDtBg0uJHyzFNoW/AEXfudQGjYWAKySLLsEDv0Ekpr3yL9xh3Unv4/leOIzjqHwoq55BY/QuKYL38m59bTl+1pGEPymd/Tdtc3qT7pu5VnoyBK1M6+FH/tKHY89yda7vom1af+oFKZ4wCdJTAFiehBF+EdPIkdT9/AO3+6iilnfY3x9ReweFOSzZ15WjJlGqIBcmWduqCHwVUhdqTLRH0Kfq9EqqTTEPUT724jUg2L1mwZzXBFy0bXhsiUdLYky+w3JEZbVqWs29TFvPhliUzJwKfIDElIgMDHsA3//waW7WBYNh5JxLIdWnNl8mWdHVmV0TUhCqpFQTOQRIHNBYNMWWdYdQjHdhBEgZggoGoGmmUR8IiIokTIqxDcyW85FvAyY0SCsm6xstmdf/kUiaLqEkHVsAl6ZGzHYXAiQFdOpyHiR1Tc+YFt22zsLGLbNrGAh7ZkkYJhgGkhi6CIrmjg9rRFQbfYVDZY+eJ9bHzpTpRwFcMu+C1DJk2mLuZnSDSIjUPMJ7NiR5Ztyxex4dHf4jgw4tyf4h85nc23f5X86tepnXYE0siplP9JcwKA7MIHETw+IvufVlnmOA7p125DDESRpp9LTjeIBjxolsCm9xbw7ptz+eJ3fsLsKWNRLYvBMT914QBD4yK6abO5s0RrVqUu4sW0HTIlDdsODIicfUoMEOx/AXbO9vX4BfdkWk3TRhAFcBwEYP/hcWIBD6+sbv9YAlGiN0D8sEtJPn8jxZWvE5ryof91ZMapFN5/kfSrt+G77KZKRjmy/2nklz5D9q37qTn96j2OX1q/yLUXyHUiR2vxDZ9GaNKReAaN65OBFmQP1Sd8A++gcaTm3kzrnd+g5rQfVTJqosdP9SnfJ98whvQbd6Lf/W1qTr8aT/XQXcZRSBx9Jb5h+5J8/k+03vl14kddSWifY/rs880NXQxJBCoevyJueRi45d4792OrRm8i1vMv03J6EW/DdB/Osii4pWU7ZcD/W9Bz78YDHla1ZHnwnSZ2LrjoRWp3Ym+W03//iQj84tTJvchc2K/0CjIlAgp5zawQ6YJmUdB6K4YKgCC4NiDPLN+x1+S6uPYtUi/+GcexqTrxW7184neGkWmj68lfo7dvIjLrXGI7VWJEDzqP8sbFeBvGoiQa8TaMQ29Zh1VIYeY6KCx/hdC04ysl3XrHZgorXqW8eSlmpg0pGCd26EW73XcPsoseAsed9PbAcRxSr96G6A93q5R+iPKmdylvepfY4ZcjheKV67T/8Dg/OH4Cr6xq6zeQYQ4EjgYwgH8pTMumoJluNZZPoaSb+D0SAgK5skFAERleHcQhSMyvEA96aU67vageSWRMbYhcUaM5o6IoCvsNibM9UySrWwQkEVmRqQl5sI68gOTy12l/9Vbil/8OwwKPCGYoQXTm2WTm30N523L8w/YFQI7WEZpyDPnlL7n2fpHa3Z6DVcqSfP5GylvfR1S8eAaNIzjuEALjZyN6fH22D4zaH+VzN9L55K/oePTnRA86j+jB5yGIEkVAHDebusRQOp+4jvYHriZ+xOWEZ5yKIAiVd4kKKEOnMeaKP9P81B9Y/uANFDYs4cSv/ARNCqKbFh2FMjYO7QWdmqhJzK9g0x0E3sWXpaiZSIJAIui+59tyKqphEfEqaJZDuqTjl0W2dRUZFPUT8irots22VBm/IjKhIfIZ3RH/GTAsm9ZsGcuGsm7iOA6rWrKuSFZJJ10wiAUlyrpNIuhFkgUE0X3HmJZNSJZY2ZylI6/REPW5oqaDI9QEvRiWjdidGDIsm0zJwLRsaiN+aiJehiWCvNeUoqCb4Dg0pXVqQx5U3STkkyvb+z0SG9rztGVLxAMeNrTlWLSpi7Juo5sWPo9AQS9RVN3foV1IsfyB39Kxbhk1U2bTeOJVKIEIE2rCjGoIUxP2YZoW729Ps/LZv7PupXtQakdSe/rV2LF6DAvqDz2P7PrFiKPGUvJ+9JzA1ooUV8+jtG4BWusGsC38ow+g6riv9Uk27Qy9YwultfOJzDq3krQCKK19C615NYk5X0X0BjB1k7yqEfTI3PPHX9A4bCSnX/R5JjTGsByoj/rwKRKiACXdZMnWtNuu50BXTiPi9wyQ688AAwT7n4yds9U95aw9ZcuqYbvkbqf09Zi6MGfsN5g3N3SiGXv2wt4VwclHkn/vBdLz/kFg7ExEr1viIcgeYkdcTteTv6bw/ouE93NN6SV/hMiM08gufACtbSPe+tH9jmsVM3Q9cz1yYjCR/U/HSDVTXPkqhfeep+6C3/SrtgyumJqnbhQdj/2Ctnu/S+KoKwlNO6EiQhI54Aw89aPpfOp3tN39LRJzvkpo0hF9xgmMmYnn8tEkn/sDqRdvorz5XarmfBUp0NvT+tU17YjdZFgQBeKB3v29PYGNXSGLHyqV90DAJYkvdxNuubtU+f8nErJr4Ke/f/fcu5+kwqK/j5x/4FDG1Yf56+sbK6S9I+9mbno+UNAtfn7KZH730loypf6F5Zzu/xiWg2F9tF2HrZVIzb2F4spX8TSMofrk7/XJ/vagtHExyWf/gAPUnPETAmMO7LXeN2wfjOR28sueIzb7IuRINVKkBm3HGoLjDyE09TjkUAK16QMyCx5Ea/oAJAXfsH0IjDkQrXk1yef/hLdxwm6Pwci0UfjgZUL7zkGO1n14bOsWoG1fSeLYr1TKwwAc0yD16q3ub3TGyb3GOmxcLdOHxVnXlmeny1yBIPBfFzgawAD+XTAtm61dhUoFkGlD0CvTnlPZ2lUkV9bxKxITB8UwbYu8alId9lEb8rKlM09LVmNQ1MvhE2pZ11HEI0i050tohk1d0EumbDIiEWRITQDVdJh62pW8e+9vyK14ndi0Y4kFwa9ZSPufRn75S27g/dI/7RRAPIfCyrlkFzxI1fFf3+15pF//O+Wt7xOeejyOqaNuW07yhT9R3rJst21nSryB+ouup+vp37vzjpa1VJ/83cq73FMzjIbP/ZGu5/5I+rXb0ZpXU3XCN9C75zIKoAhQW1vLiK9cT9Nbj7H6mdu447vnc8wVPyE2ejqOLbDf8BgRn8jypjRlw3bLvYNeYrv4+nplkUzJxjBt4gEPsiwwsiaELECyaBD2ysiySGexREe2zFo1y7rWPPsMiVIfibBrd19P7+p/KjlRu6sjfIqEJAq0ZsvYjkNDxI8ii6SKGmXdQhFFNncWCHslVjVn0C3X80YSIVtWqA97aE4WCQVkrG6v89qwl7JhE/RJ+FUJUYSR1RGwBbZnyoiCgGFZlHUbRYKOfBlJEAl5JRRRxOeRCXhkfLKEZtiAiWba5FQTy3FojPtpy5VJl0xKmkFHXiVZ0FnbkiWnGlQFfXQWVPKqhWWC4EBh41tsf/YvYBmMOPUbDJ95AibgkaBg2mzoKKKqJoVsmmf+8COaV75DdPLRRI79MqLiavQYgL9xH7Jt29k4f89zAquYIffO4+TffwFHL6NUDSXYPectvP8Ccria+JGf3+33k5l/L4I3SOSA0/Hglobbhkr6jTtQakcQ2ucYfACCQGdBZ+kb97J96ybO/uFNvLKmizWdZQ4cVsWgmA/DssmWNFqzZRqiXgTHRpRFxtSECfukf+p99t+CAYL9T8bjy5orRLlHufe606dw3KR6nny/pde2DvDQu02csd9gLp01vFe2aVgiwLZUX4urnSEIIoljvkTb3d8m89b9JI76YmVdYOxBeIfuQ2b+vQQmHFrpCY0ccBr5Zc+SefNu6s65tt9xtZZ1OKZO1bFfxts4AXAJS2n9IryDJ+72eBzLQGtZ6+7LsUm98jdKm5ZQc8r3KlE639B93H7rp39HsrtsN37UFZWHVw/kcDW15/6S3DtPkpl/Ny13XEXVcV/v5Ru4s7iVZTv87JlVvUpfewIbek9gA9cr+9pTJrOqJctDS7ZX+uF3JSGm7VQy4j3YlZD+J2HXwM81J03iZ0+v7FbnFPjZKZN5YWVrn2z/3qK/tgdRgI68Vuml3h1n102bv76+gX0ao7y5Uw/8rthbzq82fUDXczdi5btcXYCDz0eQ+j76HMsk8+bd5N55HE/dKKpP/WGlP3FnSP4IwQmHkn/vebqeuZ7EMV+ivOldfEOnAGCXc7Q/fyPq1veQQglih19OaJ9jKr85I91Cy61XoDat3C3Bzr51H4IoEZ11bmWZbaikX/87Ss1wQvvO6bV97t0nMNOt1J5zbW8/bGBDe56rn1jBQ+9u7/eazeiu8BjAAAbwz4VtO+zIlGnOlLFtN7ibLxvsMyROUTXY1FXAKwmsacvTmdMI+mVqwj4Gxfxkyq5I18iaIDvSJWrCEWYM87Cts8SqVp10ycS2TBxs/D6Z+rCXifVRxp90Jtvefo7tr97Fvoccw9YiSBJEw14ajryc5id/0yvwLkdqCU89nvyy54gccAZK1eB+z0Xd9gHBcQeTOPoKwK2u0ZpXIXp2n4FzLIPCB69g5jpQakeibn2Plju+Ss1pV+Mb7M4tRG+QmtP/h9w7T5CZdyf6nVuoOe2HeOpGIQCG69NJSRSoO+RMqsfsx6oHf80Tv/06h5xyPt/40c+YUB9ldYsr1rS5q0DY52GcRyJfNqiNfJhdD/kUBAFaMmXqoj78HgnddMladdhha5crfNWRKbG2NcvmdJmAIrKuLcfp0xppiH3YWlRQDTrzGoIgUBdxx/pnwrIdHKe7TeszgGZatGTKyKJAuqRT1g0yJQMBgVRBoy7iI69ZvL8tRUumSEva1Q8qGiZVQR+GZZEu6nTmdCQBBsX8xINequt9FHSD+oifDR15ZFFy/azTZQREfJKIxyMR8Eq05zQCHgnNgJJm4vNILG/OoxkODbE0Q+IB/B6JdF7FJ4nUhhVCfoU1zVkEx2FbqkTMq9BWLPPe1jSZsoaqWeQ1k5akRrm7ss5RC7S9cjPF1W/gaxjD8FO/S6xxMGXDIOCTERHQDZOgV2LRokXMv/1naMUsI0/9Oua43hWUDlDyRfDvYU4geoNkFjxAbvFjOKZOYMJsIjNOw9swpjKOmW5Bbfpg99/PjjWUNy4mMftiIsEQES9kVWhZ/DhWrpPqE7+NKEpuEN1xKKU6WfjY7QyZOpvtwdEYXTlymkWwW8lddxzaM2UcIB70IEoidWEvI2sC2M5/ZoDo/xoGJBD/iVi6Lc0jSz6c1DrAI0u2V4hZfzBteGxZM6tac72WN6VLe/VleRvGENp3Dvmlz/RSCBcEgcTRV2BrRTLz760sF71BIjPPRt2yDHVb/z9ux3KziMJOpFf0BghNOWq3FgOOY1PasBitaQVVx3+dwV+9G2/jBNTNS2i9+1voHVsq28rhKurO/xWRmWdRWP6Sa/XVtb3PmIIgEj3wDBou+SNSIEbnY9eSfP5GvFb/gQfd7G1F1NOPffCY6kqhmN1NnK87fQrThsT6HacHO2f5+lMs/2eiP9XsT4Nd1dUfercJvdtjXLccfvLkCt7aA7ndEwTg6Al1eKTeD2nbgVdWt6PvgVz3YEdG3SO53hvYukpq7i20P3A1giRRf+FviR16cb/k2sx20H7/D8m98zihqcdTf9H1/ZLrHiiJRuJHXI4crSO78CGURCOhqceTfPlvtN75DfS2DcSPuJxBV9xG9MAzeomcCXL378g2+x1b79hCcdUbhKefjLyT2E9u8WNYuU4SR19ZyTYBmLkOsoseIjD2IPwj9usz3pPvt/DA4iasfsoQJAF+ePyE3Z7nAAYwgM8OeregVizgoSuvUjYtFEkkVzZcHRDLIRrw4JdESpbJiKpQd2BSdW2ETIeOrIZtQ9mwsB2BjV1FaoMKAVkASWbKoCj7Do4iCiLThkepCfs4+cqrKeUyNL9+L2NqIiTCXoYmAoyddQTBYW7g3Sp/OOeIzjoXQfGS2cWZZGc4ltlrTiAIAr4hk/HUjex/+53nBHOuYtBlNxGceBjYFu33/4Ds4sdwui0PBUEgeuAZ1F3wGxxTp/We71BY+gyC47hWSYKAJEFjzMeUffbhKzc8yJxzL2PBMw/yzXPn8MDTL7G+PUe6pFNULVI5jWVb0zSlihTU3lVRQa9CyKdgOQ470ipt2TKG5eD3SNg4rG/L8tradpZuT7EjWSKdN2hJF9mRKVPWTFTDfY53FXT8HgmPJNBV2Hsni48D07Ipawapgsq2ZIGmdInsZ+TFbVquwp5PkUBwdQEUWSTsVyhoJmXDwiuLZEs6rTmNjrxOW16jq6ihmTqOBKpu4VNEMiUd24aITyHgFQn7FATBIR70kC2pKJLI4EQQHAe/TyZT1tmeLBFQJGRJQDVtwj4PflkiWdTQTIumrhJbOgtYpoMtSIT9CpYjsK4tR1dJZ2NnkbZcmZWtGV5b087G9hxdeQ1bANOEguNabOU2vcvWv3+F4tr5VB98AUMvvB6luhEch6IGqZxJc8Zgc2eO1x68hZdu+BqO7GHcF39PYtqxBHdTndDfnCBy4FkU175Fy+1fIvvWffhH7Megz/+VmpO/14tcgzu/dqz+5wSO45CedxdSMMbIQ08hEXArz+xch2svOn42gaFT8AJBL4T9HhY//Bccy2K/s76K7QiEfR4CXpF0WUc1bEzTwnYcQj4JWRKpD3upDnsJeD3Eg32dVAbw8TGQwf4n4u3NScxdJrVWd7/jkERgt3ZCAnD85IZeCtY4MLI2xMaOQr+f2Rmxwy6htG4BqVdupu78X1eibZ6a4YSnnUD+vecJ7zun8iKMTD+J/NJnSM/7B/UX39CHNMuRGgDMTDue2v5fnrvCTLWgbV+Bf/QBeAeNw9ZVpEgt0ZHTKbz3PK13f5vE0VcS2neOWzIuSsQPuxTfkCl0PXsDrXd/c7f91p7aETRc8kcyCx8g9/ajbNryHonjvkpg1P59juORJdsrntjgkuydr63twPLtGb754Hu8+xFWaRff/jaXzBrOD0+Y0K/912eRBbx/cRMvrGxlUkOEnGby3rY07XmVbMmdgPW0GXzafe2qru6Ve3/nn0TY3iMJmLaDKAgEPBITGiK0ZMp0Fj6bCcDHgdq0guQLN2FmWgnvdxKxwy7tty8Qdu7Ldqg+5fsEJxzae/2aN3Fsq0/7giBKJI75ErZlUlo9j5Zbr8BWC4SnnUD0kAt3qxxuZt22Ayncv0pv+o07EX1u4KsHRqaN3OLHCIyfXYmKV7Z/9XZwXLu+3aHn6xRwyxenD40xui7MmTv9NgYwgAH8c6FIIpIo4JVEYgEPQxMBvLIbLKsJ+WiI+SioBjURLyGvguk4+BQJWRCpDnmoi3jYmlQZWRsiFvCQ8CuMrguxqT1PdcRPTdTBRsQjSzRlyqimTUdeJTp4DJOPOpMFzzxIYtoxhKqHUzahLhpgn7O+yqI/fInsm/eQmHMVQUAPxojsfzrZBfej7ViLt3F8n3ORI9WYmfY+y3eH/uYEjgPxo66gvH4hmTf+gbZ9JVUnfLNSMu4bPJGGy24i9fyNJOfegrZtOSNP+zqylKBUNtmiZ9AtkWFVfqacfhXTDj6GW677Hr/8yvkccMJ5TD3ti+higKBHIGx7sWyL9pxGsNsvuwdxv8K2ziKC4JaNb+jIoVsWb29MsqmjQGdJo6Q5FHUQBQ2PLvLWuk7SJYPpw6o4cEQVDg66aaGaNh5JwnGcPnOXT4JtySKpggbY5FSbDa0ZcprF8JogBwxPkCoZRPzKp96XVxZd8S/dQnDAJ4ts6CwS8sqMrQ3hUyRa0yVSBdcaK13W8MsyiiShGgJBSXAzp7qFXxEZURukIeojXTKwbZtVO7I4iOgWRP0SYUWiOVtGkgQEBwKKRCLoYVNnnqJq0BAPIOLO0UQR2rMqiZCC6Qiohkmu7JAq6ng9Ig1RP0XNJChLbMyUKKoalg2qDoJjggCyWqDj1dvJr5yLUj2UUedeQ3TYaCIeCdVyK+w8tkVZAzWbZMuzN6A2fUB88mE0nnQVfk+AkgolxxUgzb3zGPEjPt8rYN8zJ3BsCzPTRsfD16BufQ+ldiTVJ393t+2U4M4LdjcnKG9e4raGHfNldPyUim4Ze+vc2wGBuiMvxyOCTwZFkSls/YAdy15j7HGfI+eposEDiiygCDCqOkRXScN2wBEg7vdSF/EzKOJHlkWEbhX4AXx6DBDsfyIqJcndRGxnT9+ZI6s45+aFFTEoSRIwLVdIq4cQNiWL3Dp/M44DXkXk8oNH8NPuMt49QfJHiB32OVIv/YXi6jd6EYPo7IsornmT1Nybqbvgty65lT3EZl9E8vk/Uloz340q7wSlZjiIMtqONQTGztqrc1e3r8SxzEp/idG1DTmUwDtoPOF9j6Pr2RtIvfQX1G3LSRx9BVLQneT7R06n4bI/k3zuBlIv3oS69T0Sc65C2qnfFFwBtPihlxAYM4vk8zfS+ejPCUw8jMRRV/TqzTYtp1It0CPW9cLK1l5jvbx67yYJJcPm5jc305ZTCXplZFGoqGfv2u/9cbF0W5pb5m2qHMv83WRvdeOzIfM7q6vHAx5+9vTKyjpJFJAEuvuq9g5TB0c5Z/+hXPPUSkzb6dP+0B8m1IdZ05av/Dvml8mU+4/g7i1srUj6jX9QeP9F5Fg9def/Ct/QffrfVldJv3YbheUvdfdlf79X1to2VNKv3k5h+Yv4hu1LcOLhfV48RmoHyZf+gta0Au+g8STmfOUjg1DajjUAeOvH9FlX3vo+6palrlDZTvd8+rXbQRCJH9G7P6u8eSml9QuJHXpJL6Xx/iAJcN4BQ3sFnAYwgAH86yCJAo0xP6ppEw96UA0bnyIS8StIosCciQ205suEvDKKINJV0PF5JKrCXkRRZN8hCeqiKrpp41ckIn6FyYNiWKbjkkZgW6ZEXrPIFjSKZYuSZoMAk076AusWvcKCu3/PIV/7I7GAn4hfYez4iXQcehqb3nyCuulzaBg9hlTBwTzgdPLvP0/6jTsqc4Wd4akfTXG1G3zcuaJmd9jdnEAKxqg+7Ufklz1L+vW/0/qPr1N1wjcq1TjhQJTRn/8ZW+Y9TtMrd7Hm5q8x4vTvEB6xD4Io4JEkbCTyZZ2G4ZM491f38/q9N7H4uQdY+fYbHHHZD5lw1DEogoBpuW4hmmlT1AxMGwzTct913ae3vj1He76MIkrIoqt4LeASsIAH4n4JSVYoGSZbuwr4ZAm/R6Q65KOt4OqJRP0eOvMqtZG+7hR7A820XHuxouu3HQsqLNmSRgCa0mUQoDOnYpg2EwZFcZzAp1Y0l7ttzAzLpqCaiIJAddhHtqwzOBHEq0ikCxrRkESt6kHVDfyKRDTgJeGX2NRVwqeIxGWJaMDD6Oowmu1Wq+VUg66SztjaMHnVoCrkoSbsxecRSRUNfB4JnyziOK46+JBEkFRRZ1RtmNGpMu9vSyKLIrIkUijrGLZFUQe/LOKXRLqKOgXNYEe6wNqWDJnSTi1qlkNx3UK2Pn8zVilLZNY5DDrofEJhBY8kIIoCPklEsG3KQGrdYjpe+BOOqVF1/NcJTzmGvCNgaK7IXmnjOySfvxHH1AlOOrJPJtqxDLKLHyO78CEESSF+9JWEp52wx9+IrZUwOrcRGHVAr+V+XEvRltf/gRwfRGjfORg2aDaUtr1Hbt1CqmZfTChegwCE/RIBBRY8dCPhmkEce/4XCfoChP0SUwbHmTI4hmk5iKKAalgEPBKN8SCKJAyQ6n8CBgj2PxG7kph0Se/Vr/vwlw7q10u4Z/0PT5jAMZPqey0fVx/mlnmbmLumfY/CU6F9j6XwwcukX/87gVH7VwSRJF+I2GGXknrxJoqrXqsoGQcnHU5uyZOk37ybwNhZCPKHhFFUvHgHT6C8eQnxIy7fq3PPv/ccsUMuRBAlzFwX2o61OLaJt2EsojdA7Tk/J/2aa8FV2riY0JSjqTr2K4BbMl57zi/ILX6MzFv3oe1YQ/WJ38Y3rC9R8jaMoeHSG8kuepjsokdQNy8jfuTnCU4+ylUfFQUWb07y+5fWfSzBuD3hyfdb6KkSsh33hf2zpz+0Otqb3uyl29I8tqwZAZg0KMq1z67aq35nGz41me9Bj7r6X1/fWKm0EIBz9x/CmfsN5qr7lu62ymJXhP0KK1uy/ZYhCzv9/85naFi9zzfo/eQE23EcSusWkH71VqxihvD+pxGbfRGi0n/WWmtdT9czv8dMtxI58Cxisy/s1busd26l66nfYSSbutdf1LvvyjLJvfM4mQUPIMgeEnOu6q7G+OhGjvLmJSjVwypK35UxbYv0639HitQSmX7Sh9tvWkJ5w9vEDvtcL29ax9RJzb0ZOdFIZP/T++xnQn2YTV1FTNNGFAWuPXXygN/1AAbwb4YsiYQkkZC37/Qr5FcY4//wOVQV9vZ67siSSGPM7wqkie6kuD7qIz6+lmRRpyNbpjWvsqmjiGVaWI6A6diIiIjeIBNOuZL37/8tq956gRlHn8Ho6gCZssHsc7/MjmWvk3z1VkaMvYFQSMAR/eQOuZDUS3+lvH4RgXEH9TpW37CpFN5/Ea15dZ+qmv6wpzmBIAhEpp+MHK4m+fJf6Xj4GnxD96X23GspiRIp1aHukDMRGqbQ/PTvWX/3/xA54HQSsy8mGnTwFyVEQUBIFuksOgw+/stEJ85m9SM38Ozvv8Hat+Zw9lVXM7o+hCzC8u1ptiaL1AQ8ZEsmg6p82Fi8uroDRxAJKgKWABGvgtcj4zgiNSEvhu1QFfRiIqDb0JHX8cl59h9ZRcgrs3RbCt20CHkVqsNetzQaV3tEMyyqQh78iozPI6GbFprhEPC6Jbp51aA9pyJLAo7toEgSjgPpsk5txIdu2QQUgYJq4JVd/2bDtPEpEmXDItjP/fRxIYkCkihR1kzKpklAkQn7PXjk7soLj0y6aOCRZOojPiYNimM6FvPWdZJVTaI+mfENIcbWR7Adh7ZsmbxqYFpQ1A26Chp+j0TIK6NbHwaZvIqMZloYllttZZgWtg0+RSDglRlTH3V7nTULUQZJtVEth4DHQ2dBQzVs1jWl2JYto+pun74MOMVONrxyM9l1iwk0jCJ69s/w1I1CVsCyQPSI5FUTxwHHVNn2wu10LnnR1WA5+bsoVUMq88ayabi2sUufQakdQc0pP+ijT6C1rCP5wk0YXdsIjDuE+NFXVJTD94TylmXg2Pi6bXR7YAG5D17GSDbRcNrVeCQZywavYLD5pVvwJhoYdeTZ1MS86KZDWYdNrz9Evr2Jg6/8NXldIOKHQtmkJuylJuKjqFkUVNdurjrsrfiPD+CzxwDB/idj+h4EhHZd1992/W1z6yUzuH9xE1c/sWK3+xUEkcSxX3EFz+bf08vTMrTP0RQ+eIn06/8gMPpARF/ILdE+4vN0PPRjckueJjrzrF7jBcYeRHruLehdTX3stHaFY7kvzZ5yl+Lq17EKKQJjZ1XEzRythJltIzjxUMpbl1N473nMVAu15/wcQZS6BZ7OwTd8Gl3PXE/7g/9DeP9TiR96SS/yDyBICrFDLiQwfjapF/9C8vkbKayYS+KYLzN12j6fupe3P+zKI3XL4cdPrKA9r5EqflgSXRPyMG1onIBH4u3NSYYmApw2bTDXPLWCHkFzqVv5fG8gQB+xtU+KnkDAhvZ8ZWxlp4dtPODpRbAlYfel41VBz25FtE6dOohk0VXG7QkMCcCuIXe/55M9joxMG+lXbnaJa+1Ias74SZ+ocg8cy3SDMQsfRAolqDv/ul4ZbsdxujMpdyD6gtSe8wv8I6b1GkNr20jyhT9hdGwhMPYg4sd8aa9eogBmIYW2fVUv660eFFe+itGxheqTv1e5x10SfUu3gv9pvbbPvv2oK2x27i8R5L49Uxs7C/z8lMl9AnsDGMAA/jPQX1ZJEASUnfQtBEHA55EZpEhkSxqqaiKKNv6AF8FxqAoFkEWBkqEgH3kKWxY+R9NLf2fwvofwjm1TFfbQVhYZftwXWPvI9Wxc8AKxaceiGhDa51jyS54h/cY/8I/av9dzxj9yOoLsobj2rY8k2HszJ7DVAoWVrxIcfyh611bUbcvZccsXqL/gN2SjdeiajVM3mobP3Uj69dvJvfM45S3LME/6DomxY4n5RHIlla6SgV8WiQydxAHfvpXk/IdY8eK93PSVt7G//xMiZ11IumyQKWpsTRbQTYflzUm2dOZJlkzqYwE8osCkxgiqZTI4ESStGnglifpYgJG1Pra0legqaiSCXoZVBxFwFbdVzcLCIV1yieSqHVk2tGfZ3FXGcWx0w2F4tZ+GmJ+ugk5eNamLBpg9uoqWrEqmpOPgvodG14aJ+mXiAQ+pfJmGiAdZFGmIeCnpNhG/RE3Eiyx+NEHqcU/xyCKGZeM4VIiVbbsWZrbtUDJMDMumK6+xI1UCUWTK4Ihb9u6AIgqE/Ar1ERnT9nPY+Dq2p4os2JhifF2A5kyZZMFiUMxPe0Fla2eBLV0lgj6ZiNfD4LifqpCPbMkgrxoULIeo30vAK6KaEoog0JQy6chr1IV9yJJI1C/jlf2s2J5FEgUQBLxeCUsXUHWbtGpSVHU2Z0ok8w5a9/2WXfI06QX3Aw6jjv88DQefjmWCR4GMDmJ3e59uQ75pLVufuAEj3Ub0gDOIHnpx74B7VxNdz1yP0bGF8PSTiR9+ea/fgq2rZObfQ37J00jhKmrOvKaXAO9HobT2LcRAtJdosAAIWonkW/cRGDKRkTNmUTbcite2+U+iJZvZ7wvXsd/IKiRBJlc2KHS1sPCVe2nY5xAm7D+brGaSUU1iPte6zAamD00Q9Svd7RADauH/TAwQ7P9QrGrJfuQ23vrRhPc7kfzSZwlOPqriQ+2qjX+5D/n2D5+Kf9T+ZBc9RGjKUZWybYDg+ENIv3obxRVz8XxEFluQZLyDJ9HxiBstFL0BIgec0Yv0pF69HTlSQ2z2xYBDx+PXoW57n/YHfkTi2K/gqRnunkPDGBouvcmNHL77JOrmZVSd+K1+CZSneih1F/6GwgevkHnjTlrv/DrzV59KcOZ5e/QW/LiQRKHfTO3O5c496CzovUrQ23Ia725N9yKiVnc2wnFcezGfLPbxff5w3/1bKu2cNQf6/XtdW56H3m2q9Fsv2Zpm15y5Zdk8sLiJh9/dTjzQm7Q1RH00Z9S+xyTAMx+09romogBTGqPMGlnFHQu3ops2YrdvtdMtpraznoCA+/L+OLANjdziR8ktfgxEifiRXyA8/eTdlmIZXdvpeu4P6G0b3HaCY77cqwzbKqTpeuFG1M1L8Y+c4fYCBmOV9Y6pk3nrfnLvPI4UjFFz+tUExh7Uz552j+Kq18Gx+7Rh2FqJzJv34Bk0jsBOPeDZtx/FzHST6J1e+EZqB9m3HyEw4VD8u0S9K9tYrkfpdad/dHZpAAMYwH8uDMvttV6+PYtXkQnbsCNTYmRVgANH12BaNrph45gdjD3z6yz545dZ8cTNjDr9G6zvAF0DZeyhhIc9z/a5d+IdMRPTF3ED70d+no5Hfkpu6TNEDzyjsk/R48c/ZialNfNIHPn5PoHvnfHx5gQXIXoDpOfdRe6dJ2i546vEj/wCzj7HulVpHh9Vc76Kf9T+JF/4M813fQvvnEvwHXEOqiPgVyRKmkXYK+L3BRhx1pXUTDuCTU/+md/+5Ls89sA9XPydnzF03D4YOR3TtmjJlkmpJqbtUFJ1bI+E4IDlOCRCPgZFA67/sm0jI3LYuDpaciUUUWLa0DhVYS+abhH0KbTny6i6jUdSWbwpSWexjGMLOJaFIYBhWyzfnsEyLUbVRWhJFfhgu0RBs5BE9x0ZCyjkVIOQV2H/YQmeX7mDN9d1kC8Z7DM4xrGT69G6SXPAK+FXer/zXB9pNxAvCpDtrgzzSgKa7SA4AlGfRMmw2NJZJKe6ZN8riYQDMi0ZjcmDImRVk/XtOerDfnTLRhJEHBtSJY2wTyHkVfB5ZBqr/BRLOl5ZYNaoRHf/s43Z3dtcE/IQ83uoCvnwyhLZcpFkUUM3LKpDXnyyh/aMikdysGzwKxJ5VWdjR55BET9NmTJ+RaAhFuCdrUk6siqCKCCLUNIMbEcg002u1W0fkJp7M0ZXE/Fx+9N43JcIVg3CAgzHIurxIEo2jVE/G9vSbH/lPjoXPYYUruo34F547znSr9+B4PFTc9ZP+2j9lLe+T+rFP2Nm2wlNO4H4YZd+rPmmVc5T2riY8NTjes1dBKB90UOYxQx1Z/6Uoi4Q94tomSQt8x9k2H6HcvKJJ1MdUmhKF0mXNN556E8IgsD+536dSMjHoCqFeFBBFkQifg+66aCZDtHAAPX7V2DgKv8HokedfG8Qm32RK3j20l+pv+QPlR+wt360K3i27DmCk4+uvOjiR36Blr9fRebNe3p5YErBOP7RB1BYMZfY7Iv7zZjtjNDkI/GPPgAz2YynYQxWIVVZV976Pkayidozf1J5EMnROhi2L3rrelrv/jZKrJ7YYZcRGL2/+0I99isERh9I8oWbaLvnO0Rmnk3soPP6HIcgiIT3nUNgzEwy8+4i9fYTZFe8TuywzxGcfORelfDuCZIo8ItTJ/Pke8288xGiaLvDrtRcFODI8bVUh700p0p7zLifu//QSiayh1THAx6ufXYVumm7lh2OU/FY7fkb+mbd+0NPhtq0nT7iZDv6Idc94zpO78EdByJ+hWVN6Ur03Hbc//R3GA4fBij8ski5H7/yD8fuLgd//Q6sXAeB8bOJH/l55N0IhDi2Re7dJ8jMvw/R46f61B8SHH9Ir21K6xeRfPHPOIbq9kztd1Kv7JHavJrkC3/CTO0gOOUYEkd+vpcP9d7AcWwKy1/E2zgRpWpIr3XZRQ9jFdPUnPHjyn53R6IdxyH18t/c/q49CJvB3luZDWAAA/jPRVdBo6iZGJYrsgUOo6qDNMR9rGnJoZkWrTmViNdD49DRNM06jfYFj+GfeBRSt/CSZAqEjvoy+X98nbZ5dxKf477//SOn4x85g+zCBwlNPqJX4D20z7GU1rxJcd2CPiKQu+LjzgmsYobghEMxM22kXvwzmTfvIX70FYS6A5CB0Qfi/fx4Ui/9lU0v3EHr8gXUH/9Nho4ZgSxLVId9BLwKubKOJz6YWV/5A+NWzuPth27ip58/nQPmnM7+Z34JbzBOvmxgWjaWI1DWLUI+mdZsmY58GUF0Bcv8ssjYujABrwcLmNAQwbQEaqN+6qM+SqrBO1u7yKs6yZxOzC8R9MuUDA/JfJl0SUUURTJF3S2FtmwkWSLoV9w+WFnEst3ecJ8iUadICDhs7Mjx8op2tqYLCLZNZ0GlLu5jbF2MmrCXurAPQRAwLdcTWgDy3UrpjuMq0A9JBBCALV1Fhla54mFLm1J05nW6ChqmbRP3KbSrJkMJkCvr5DTTDZo7kFNNsmXXW7oq6KGomwxNBDBsCwmLar9CMqcxojrIkHgABMirOmXDoqhZNCXLWDGH7ekiQUVk0eYkhmUTCyholk1zukRBt5BFkc6c6paJdwcJYgEvwxJ+OsM+3t7cyeaOPPmyRSLswbYdbAd0yySXbSf9+j8orXsLKVrHkDN/zNBpMzFsSIRkkjkNy4FsWSfkEcjs2MAHt/2SUvtWwlOOIXbUFxC7/dbBrTZLPv8n1C1L8Y2YTvUJ3+zV1mVrRdKv/Z3CBy8jxwdRd8Fv9ihitjsUV74KlkFon2N6LbfSLWSXPEVo8lGEGsbgkyGnOWx44i84wJA5V9CSLRH1h6kL++lctZAdHyxg5jlfpXHwYPyyxJjaID7FDda42gAKAc9ASfi/CgME+z8Q/amT7w6iN0j8yC/S9fRvyS99lsj+p1bWxQ692CXfL//VVQ8XJddaYPrJ5N59ktDU43tFmMP7nUR5w9sU18wjNOXoj9y35AshNY7HKmUpLH+ZwITZeKqHom57n9DkoxADMcDtW9G2r6D+4htwDJXOp36H3rqO5LM3oE45quLn7R85nUGf/yupV28jt+ghyusXUXXCN/AOGtd334EoVcd/ndDU40jNvYXk8zeSX/Ys8SO/8Ikegj0YW+uSqp3LwD8NerK6c9e0I4nCRwrYTRrkCrjt7GMtCm5G3e1dconprn/v9fFAn6x2DxzcqOqu4+2OMPcn1LY3t+2eyLXWuoH0a7ejNa9CqRlO9R5EzMDtpU6+8Cf01g34x86i6tiv9Jog2lqR1NzbKK6ci6duFFUnfadXC4Stl8m8eTf5pc8iRWv7LRnfW5Q3L8VMtxI75KJey410C7klTxKcfGTlXnYch9QrN/dLoour30Dd9j6JXUrTPZLAqJoQ69ryOIAiCZy5X/8etgMYwAD+/4FtO2iGzcjaCF5ZYkemzMiaIF5J4I3WTiYNitCR1zGxUU2bhsPOJ7ViPu0v/S+DL/sTtqRgAb6a4UT2P43MO4/jnXQ0vu6S1fiRX6Dljq+Snnc31Sd8o7Jf37B9kBODyS99ul8RyF3xSeYEoi9EZv595BY/SvLZGyivX0TNqT9wxwtEqT7tR6hr3iQ592a23Pk1SodewISjz8OjBPApIkVdJOx1e4mrDj6Ok08+idcfuoXH772d9+e9xIxTPsfw2acT9XsplHUCioToOLRmVbpKJpJoMjjmJ+KV2dxRYsYIH+nuDK6mmbRlyxQ1E92w3ZLviB+PLNNV1KgO+zFMCxG/W/5tQUrVGRzzE/B60E2b/RuiTGqM0JHXiPhlVMOmOuj2b29Pl2hOlSlpGtmCjk8RMSyL1ozKvkMkQKiolbfnVEzLwcIhU9RpiLoCa5IoUDIsBCDkl1FNm2xRZ31bHkmEsm5R1E1s26Gs2xR0g+qgl0LZpDHuw3Lc3mdFEt0S8YCHeMCDZTtsT5fpzBmMqY9QE/ZT1C2WbM9gmjamZZIpWlRHPGRUDSdrk1cNuooaAa+Eptms2pHBNNz+7KBHQrdt2rNldMsi4pUZkogweYjMB80ZdMOhbFgUNIv2omsPJokQRGXD3IdoXfSka+V6yIVEDjgDWfFSVN35TEtSo2i6Pc1OQSO18AFyix/HE4oz+KyfIu2SlS6umU/q5f/FMXUSx3yJ0LQTe93bpY2LSb30V6xihsgBZxA95ELEnSzr9haObZFf9izexol9xFE7X7sdQVKIHvY59zdhQNfqReTWL2bonMshXE1TVxHBthEtncf+ch21w8Yw5bjz8fu9FFULr6xw1MQ6QADHIeiTUQbKwv9lGCDY/4GYObIKURCwnb2jToHxh+BbOZfMW/cSGHdQxXZL9AaJH/VFup7+HfllzxGZcQoA0YPPp7D6dVKv3Ez9xddXsr6+YfuiVA8j984TFRGxvYEUiBI95AIcw+3nFRQfgvyhrURq7q2Ep5/crf4dpe7C35Jf8pRLbN57Hm/DGIITD3eP2Rei+sRvERx/CMmX/krbPd8lPP1kYodejOjpq9jpbRhL/UXXU1w9j8y8u2i//4f4Rx9I/LBLUaqH9Nn+o7CmLb/H3vePg30HR6mL+HhldXvFA3VPEIAXVrYyrj7cyyYMHKTuEnOpO2tt2e7flm1j7Z6v9oIkwvShcZY2ZbB3k2n+d8HItJGZfw+l1fMQA1FXVGyfY3dbDu6YOtmFD5Fd/CiiN0j1Kd8nMH52r3u2vPV9ks//CauQJDLrXGIHn9erDLu8ZRnJF/+CleskPP0kYode0u89trfIvfM4UqiKwLiDPzxOxyH96m2ujsBhl1aWl9a8ibr1PeJHX9mLRFvlPOnX/o6nYSyhqcf3Gl+3HDZ0Fvjl6VMG+q4HMID/IlSFvGTLOqZtM3lwjPqon21dRWzHpibkxSNL+GQJr+jgkSQaq6MUTv4yG+/7Ofl3Hic661wkXAISPfh8imvnk3rpLzRc+icESUGpGkxkxink3nmc8NTjKoFAQRCJzDiF1Mv/i9a8aq+D1x9vTuAmA4LjD6bz6d9RWjufLtlL4ugrEL0BBEHAP/EwBg/bh/Tcm2l7/W4yq+bTeurXiQ4e485fJCiZBlG/l1QowtgTr+SowbNZ++zNLHz4b7z38qOMmnMptVOPQpEltIKOLYmYlk3Ur6AZNppo4fNItOZKeGWZZEEFR0DtKiJJbsZ1VXMWRREZGg8wti5CXczHjKEJ3ljXRkfBg2lDV8nAsUVwBPYbHmOfxhgdOZV0t92WTxFpyZYJlHU2dRYoqAYNcT8b24sYlk3A68O23ZYqzTTZliwhSwKaYRP2yViOQ9Ajo5pum9mY2lClMm1wLEB7rszmjjwhr4IjgGo6hLweBsWDRLwKyZLGoIiP+miAZFEj4JUpaRZRv4Kqm6SLKq1ZjbBfZnAigCKLOAJYjo5l2WxPacR8Mmvb8uiWzdaUiIDDfsMSGLZDWXcYWuUjV8yRLKiYpkOyoKNg0Vm2SBcMRAEsG5rSGvUxBcMS2ZEt0ZwqUtJt6BZNy73/Ep0LHsAuZQlOPLxbCNSd39pA2oAg0KMkU2r6gOSLf8FMtxCccgw1R34eZ+c2sXKO1Mt/o7R2vusscuJ3egmZWaUsqbm3UlozD6V62B71XvYGpfWLMDNtvd790C1suvEd4odfhi+UQJKgXCzT9soteGuG4d/nFEzDJhqVURSJtx+9hUKqg3N/8HskSaaommi6zcbOHHUtPmaOqB4QM/s3YIBg/wdi+rA41546mWueWlnJXO4JgiCQOObLtP79KlKv3NyrDDUwfja+FXPJzL+HwNiDkCPViN4A8cMuI/n8HymumEton2Mr40QOPIPkc3+kvOndjyXiIAgCQo8PsWlQXLcAOT6I0roFCLJCZIabWXdsC1GSiR54Jo6hkl30CF3P/J7SxneJH/UF5B47r1H7M+jz/0t63l3klz5Daf0iEsd8icCYA/vZt0ho0hEExs4iv+Rpsm8/QssdVxGcfCSxgy/4SHujvYVHEhiWCLChs7hX26/ckWVVa26viawDvLWhi3e3prjmpEm9fKyvOWlShVSBW+WQLxvc/ObmPuPEAgqZktFnuWXDO1vTeCSB6cPjfXrFe47hXwmrkCa76CHy77+IIEpEZp1D9MCz9tjjVN62nNTL/+uWc086gviRX+hl3WZrJdJv3OFaeSUGU3/R9b2qIHpIbHHlXOTEYOou/G0lk/NJoe1Yi9a0oo9vZnnTO5Q3vUvs8MsrRNpSC6RevQ1PwxjC007oNU7mjX9gl3NUnXttv8EF03JY2ZLlVwN91wMYwH8NfIrExEExhpR1dMshXzZIhLzYQFXYR8QvMTQRxO8VSAS9vLM1SX7SLHL7zKZr4YPUTzkELdSIidtbnTjmS3Q+9gty7zxREWSMHnQexVU9gfffV54/wclHkXnrPrKLHvlY1WEfZ04giBKeulFEpp9Cet5dFFe+itr0AYnjvkqg286LYJzBp/2I5LqFdL1yM2tv/w7x6Scxes7FlB0/XkyiPoVsWSdTLGMH49Se+gOi00+g+eU7WfnQ9QTfeIRJJ15G476zqQl4yZR1vKJIUbOoDshotk22aOBRTDrzJRRJpjrgYUdBwy8K+L0yIa9EXjVRDZttXWUUoYxqWOimU7FGMh0DzZRpzagsb0riUWRqIz46smVqIn5qI17ea3JtuXyKRHXAx9j6MEXNoDHuB0Fg1Y4s4xrCBDyS248tQEm3cQSHxrgfv8dV5saBoCIhigKmadOcLhP2K2S7rbhmDE8wri5EXrfBgUBWpD7ipz2r0pQuMaw6iFcRKvsI+TwkLNdKbGNXgSq/Ql41yRQNFNmhK6uyI2lR1EzCPhnTsVFEgc1dOQRBxCMKvL2xxOaOIk3pAkGPTF7VyZRAEEFzbavxypAtajyxtBmf7CqqN6VUVMMit+ZNUm/dj5lpxTtkMvGzflrRGNoVRYBSluQb/6CwYi5yrJ7ac3+Jf/jUXvOZ0vqFJF/+X+xygejsi4jOPLtyjzuOQ2nNPFJzb8XWSkQPvoDorLN7BeQ/LhzHIbf4UeR4Qy/7W8fUSb16C3KikdiMUwh5QZZg88v3YuQ6GXrh9QQDMg0JH9GwgtO5hWUvPsRhp5zPwQfNZNWOLM2pMrURLwGPTL6kU1B1EqH+HVUG8M+DsGvv5L8CM2bMcJYsWfIv3+//b1i6Lc21z6xieXNvwbOQV2JoIoBu2mzciexlFz9G5o1/UH3ajwjulEUzMm20/v0qfCOmUXvGjwG3X7T9vh9ipJoZ9MVbkPxhd7llsuO2K5ECEeov/sMn9s7LvHUfettG/KP2xz9m5ofkopDGyLRS3ryE8vq3iR/1RbSWtWQXPoToCxI7/DJEb5DgTg8ktXkNqZf+gtG1Df+YmSSOumKPpNkqZckuepj8e8+D4xDadw7RmWf3skD6vwxJgG8fO46ZI6sqVl9hr8zcNe0gCFx+8AguOHAop/7lrV73Riyg8P054xlXH+bC29/GMG038w19StP7KwX/V8IqZcktfoz8sudwLIPQPscSPfi83fZZg3vvpN+4g+Kq15Fj9SSOvapPOXd581I3K11IEplxKtHZF1VKuxzHobT2LVKv3oJdzhM54AxiB5+/R/GevUX7Iz9Fb91A45f+XsmC24ZGy9+/gih7abjspgrxTr5wE4UVc2n43I146j4sG1ObVtD+wI+IHHDGHu3yLjxw6ICw2T8JgiAsdRxnxr/7OP5/w8Cc4LNDpqTzwgctILjClIok0Rj3Y9kOtm2zvj3PutYs7QWVztY27vn+OfjqR5M457pe7/POJ35FefMSGi7/C0p8EOC2p3Q983sSc64ivFMFTfbtR8jMu4v6i2/ot2Vrr457L+cEiWO+BJLSrYnRTHDSEfhHTKdq0uGYgInb+pOZ51bAyaE49cdeQeO+h4AgEvDIFHWdTBECXvDI0BANseGduWx+6U7UrmbqRk7k8HOvRB42naqgwqCYn6Jh0ZrTmDYkhuOAmyC2kSWBgmq53tq6QSzgoaQbRHw+vLJAMqchCw7LWnIUVZ1owEPE5yHolZBFEUFyqAv6aKwK4lhQF/NRF/Oysb1EzC+zoTXPwk0dCI5ARjMRHIfJg6McOLKaRMjD8KoQJcPCK4uEfTJF1UK3LFoyJZrSZbySyKjqEKPrwgiCwJsbOqgKKGimW1o+Y3gCryKhmRam5WA7Di2ZMtu6CuS6hd8yJY39hiUoqiaZkkHJtNAME0WQKBkGQZ9CvmzQki6zNZmnqFnoukXBsKgKeJk+LIqFSEPEx8aOAu9uTaIZBsm8W16nyKCaIAHlne4Jufv7BBAcm9K6haQX3I/R1YRSO4L4oZfgGzljt/NQx7EprphL+o07sbUikf1PJ3rweb3sO61ihtTcW9ysdd0oqk74Rq9ybTPXQeql/6W8eQmehnFUHf+1igjvp0Fp07t0PvpzEsd9nfC+x1aWZxY+SHb+vdSe8wuCI6YhAWrbRprv/jahfefQOOcq6iIS4wbFEbB47pefp5RN8edHX0XwBjFMk3TZQECkJuKlMRZgv2EJwr5PHgwYwJ6xu3nBAMH+D8bSbWnOu3VRL3IkCSBJbmZTEHr3vDq2Retd38Iuphn0hb/1EmnKLn6UzBt39iLfesdmWu/8JqF9j6Vqzlcr2+aXv0TqxT9/bCuCXeGYBoKsYBXTlLe8h9a8Cr1tI3JiMN76UfhG7IenZrhrb9Gxlc5Hf1YRRonMOpf4oRd/OJZlklvyJNkFD4ADkVln8//YO+vwOM6z6/8GllksWbJlZmYH7DgOk5OGGQvh9k25ffuWuUmTNMzMTLZjJzEzM0sW0zINfn+MtPbGdmKnSQqfznX1ajy7szOzs5rnhnOfE5hw3mcmR1qsleiSF0msnwOCYCVxE8//0jraXza6lhBJtMTOPss/uzLopCGaybv/vzt3eM4H+UCBtMcW7s4rxPwroSU6iC9/nfja9zA1Fc+QqQSOuSQX4B0Kpq4RX/MukQXPYuoKgQnfwD/5wryZKD0VJTzvEZKbPsJWWEXhabfh6DFo/3FjrXTMuZ/0zuXYy/pReOqtecntP4Ns/VaanrmD4NSrCEy6ILc9PP9pYktepPSAOfLDJdGmptDw+C2ga5Rf9w8km/OQBRBZEnjxm5O7qeFfEboT7K8G3THBlwPDMFldG6YhkmJPSwJJFjh1WDmKZpBSdBJpHR2NDbVh1tdFUVWVJR+8zK43/0HhabflhJZsgB5vp+aR72Av70fJRVbybZomzS/8BLVlDxXXP5BzWTCyKeofvB57WT9KL/zVFz7/o4kJDCVD6yu/JFu3CbA66UVnfDfv87IN2+iYfR9K8y68vUcx4Ozv4Cypoi1mrZmyaCXYQ8o9FPq9RONp4pvn8clLD9PRXE9J7yFMOu96xhw3HUyQRQmf24amG3gdMroJsmSSzGjUdaSoDaeo8jvRTIirOgG3HU3TSWVUEhoUuq1EtLrIQ0cig6oZ+DwOirx2JEGkyOegT7EPVdVJqRp729Mk01k21Edw22WiKYVIVmNkZYBCj4upg4op9DpJKjpBl51wKovTJlHTlmD5nnbcsoRqmDhkkT4lPkoCTjKqTntCochjZ3zvQgLu/BgpllLY05pgd1sczYTWRBZdMyzht0gawxAIeWwU+134nDLLd7fikGXSWYVtrQnsokgkpeGQDERJpNDtpMRnJ+SxE06q7GpPsbMlTjKtoQNJDexYybQgQVYHhwBp06J5m4ZOautCokteQm2rQS6oJHjspbgHHfuZorXZpp10zLkfpWEbjsohee40YBXUk5vmEZ77CIaaJjjlEvwTv5ErdFvz0e8SWfA0mAbB467EN/bMw46lHQ1M06Tpqe+ip+P0uOHB3DGtZteNuPpOoHjmj5AAm66x+6nvYaQiVFx3Hy6nl4ALBpT5aVr0KnOf/js//dtDXHTB+aRVA9MwOxXfDRBk+pZ4KfE5vnAzrBufj8PFBd0U8f9gLN3dflDn8UArpU/XTgRRovC0W2l66nuEP36cwlNvyb3mHzeT5Ob5hOc8gKvXSESnF3tJH3xjzyK+8i28w07E0WMwAN5hJxJb+gqR+U/h6jvuCytzdymAa/F2ooufR5BslFz4q7wupZFJEFvxBqaaxTN8BlqkidSWBcRWvonkCeIddRqiJCNIMoGJ5+MZfDzheY8SXfAMyfVzCE2/Dlf/yYd8uMj+YgpPuZnApAuILnmZxLrZJNbNwjNkKv4J530pVcovE123UzPg+eW1iIJwWLG7T9tpTagOEU4prKqxlM+7kuv/e2sjyufMfn8dUMMNxJa/TmLDh2DoVmI9+cKD1LY/jfSeNYTnPozaXouz9xgKZnwLW0GP3OumaZLcOI/wR49a1K4pFxOYfFHut5e3iBoGoROuxTfunC9lEe06fnj+U4juAL4xZ+6/3rZ9xJa9imfoCbnk2lCztH9wD3KglMCxl+Z9TmTxC2gd9ZRc+GvETyXXkihwYqcK/TfGVHYn193oxv+nMEwTTJ14RsNpl7BJAqmMlajZZZFERmVfR5K1tRFaEgo2wcA/4hScqz4m/NGjuPqMQ/KGkIAelaWkpl1N8+z7SG6YY2leCAKFJ91Iw+O3EP7oUYrO/B8Ay3Zr4jeIfPw4mX0bv7CQ6NHGBM6qYTirRxNb+hLJjXMxsklCJ34LW8Caw3VUDKTflX+jY+37tC54mtX33EjBmDMJHnMJktOLaoDUKTzmcWlMHlBM5cTrOP60mcx56zWWvfEYb/31e6x6bQCnXvJNjjnlDFRTRLAJhJNZNFNAMA003cTlkClwytTFFMoCDtpiWZJphYqQB1U3KHDKFHocmIaBLAq0J7IYGBQH3YTcduySRGXIi00W2dIYxcSkJZZG1w0UHepbE2QUnaBTpqY1TVMsS0nARp+iAD6XjM8pE0llqetI0ZZQUDWdxpSOompUhJxEU1l2NifoU+qh0GOnMugioxnIWRWHbNHMDdNkS1OMRFqnI6kTz2QwETBNWLyrnWxWozxgx2UXKfLIbGqM0Z5QwFTY1myx5UxTIJ7RCLplRMEEI0NTLE3QbSeaUtF0Hacs0tHZmu7qUisAulXc8ThBiyu0bZxHbPmraOFGbIVVFJ11h6Wl8hnrs54ME5n/NIn1cxDdAQpP/y6eYSfkxalquIGOWf8gU7MOR4/BFJ56a54mj9K8m/ZZ96I0bsfZeyyFp9xoOd18SUhtX4zStJPC02/fn9B3uoMgSoROtIRNdSC64g3Ult0UzfwxotOLAbjs0NFYyycvPsCoY2cwcOIMNtRFEYBin4M+JV5K/S5cnaMB3fjXoLuD/R+MVTVhLnhg8REpMx+I8EePEVv+GqUX/w5nr/0KzNmmnTQ99T28w2fkLLqMbMqisTq9lF91V+5hkNz8CW1v/5nCM//nc+05wOqGq237EBxu7MW9D6JjG5kE4flPo7buJTD5Qlx9xqKGG4gufA5TUyg+9ycYmQStb/0ZR9VQsrUbyOxdg728PwUn34SjrF/e56X3riU89yHUtlocPYcTOuG6g97zaWixNmLLXyOxfpa1ePcZi3/cTJzVo/4tq39HQ+Pu1D7LWXepunkQw+HrhmmaZOs3E1/xJqntS0CS8A47Ef/E87GFyj9zX7VtH+GPHyO9awVysIzQ9Otx9ZuYd5/Utn20z7mPbO0GHD0GU3DKzdiLe+VezzbtpGPWP1CaduDsPZaCk7+DLVh22GPqiTBKy26MTAI5WIa9fMDn/i7Su1fR8vIvCJ34zZyIoGmaND//Y6sLdMODuS5Q+JMniS19OTcf1gWlZTeNT34Xz5ATKDrj9oOO8e3j+/Cj0wd/5nl048tBdwf7q0F3TPDlYXdLnLlbm/E5bGQ1lWhaQ8BEQqJ3iYfFu9rY2hxDzRqkdY32eIZoXR07H78Zd79JFM/8EW4BKotshONZNj76Y5TWvVbHutOmKDz/KWJLXqLk4t/i6jUSAEPN0PDQN5H8xZRd/pfPfTbq6TjZfRsxDd16nhZX5+lTHE1M4O4/EUNJEVnwHIIkETz2Mnxjz8olYhKgpqKE5z9FYt1sRJePwDGX4Bt1GnZJpsgn0iPo5sTBZfQr8VHbbvk0t0RTtK7/iPmvPM7endsIFJVx3DmXMeyEmTi8PuyiQGtCoy2eQTN0VB1iySyIApph4HfK2GSZREYhq5m4bRJVhU6yGiQyCpG0iqAL9C3xMKFvIRUhD1vqo6yu6cDrllAUk46UQms8C6YOgkA0rZNRLXr7WSMrGFlVRMBlo0+pl45Elpq2FC3RJG1JS/Quqxl47TKyKJJUDYp9DkRJYECxl8HlgVyBXTUMGiMp9rYmcdtkNAw6Elm2NcdJZg0aoykcooDbKSOLAgNLfNSG00TT1veUVHT8TjsgIApQFnTQEVew2yQkScQwTHTdwOuSiSYVMppBOKWSzZqkLN0y63eRipJa8x6R1e9ipCLYy/rhn3QB7gGTP7OZY2oKsZVvEV3yIqam4BtzJsFjL82z3jI1heiyV4kuecly6Jh2Nd5Rp+Y+11DSRBc9T2zFG4guPwUn3oB78PGHp6BrKkrLbrRoM4LNgbNq+Od6YJu6RsNjNyMIAuXX3pv7jSa3LKDtrT8Smn5Dzu1H7ain8fFbcPYZS8m5P0UG/HbwuexseeKHRPdt538efJOh/fpS7HMQTmYpDTo5pk8xNlu3WvjXhW6K+H8pvvnUSmZvbj6qfQw1Q+NjVve6/Np78uZRwh8/TmzZq3lBfmr7Elpf/y3BqVcTmHQ+YM22ND35XfR0jIrrHzisRYFpmnR8cA+J9bPztssFlXgGHYtn2PQ8+m9q+2KyDdsITbsGgPia94gufoHA5IvQU1HU9n0Un/PD/aIT8x7BSMXwjjqN4PFXIB1AezcNncTaD4gsfBYjHcMz9ASCx13+uZVIPR0jvuY94qvewUhFsBX2xDf2TDxDpn3uw/PrggBfKEH+V89Wg9WpTW1dQHz1OyhNOxGdXryjTsM39qw8xexDQU+EiSx6lsS62Qg2J4EpF+Ife06eH7qhpIkuftFaJG0OgtOuxjvylP2LaDZFZMHTxFe/i+D0Epp6Nd4RJx1+ETV0mp6+A6VpR95214DJFM/88WEXfdPQaXziNkw1Q8X19+cEURLrZ9P+/t0UnHIzvlGnAvuLW55h0yk6/fb8Yz/1PbREOxXX3Z/TQjgQ3TPXXx+6E+yvBt0xwZeHjKqzfl+YjoTCB5sb8NllAh4HqazGkAofG+pibG+K0pFU8NklsrpKKmuw5r3nCM9/iopzf0r12Ml47Q5Sqsq+PbXUPH4L7r7jKT73J4D1DG987GYQBCquvTc3ihVfN5uOD+6m6Jwf4Rl07OHPcd9GWl75Jaayf+JWsLtw9RlnzVT3GZtLPI40JgCLYtsx534yu1dhK66m4OTv4KwcmndspXk3HfMeIVu7HrmgB6Hjr8I9YDJeQWD8gBC9gm5KQy4KvU5SWY3+pV4iKZWNSz7iofvuYfeGFdicLvpOPpXBJ57PwIFDiGWy7GtPIQomsbSOahgEXXYyuorXZqMs6KauI4WqatjsIvXhDHZZRNcMJBGK/W6CHie9ChxE0xqJrI5hWrR+mywST6tkNZ1kJkskQy4b7Vvi4pzRvRhW5cdpkwknFWIpFcUwaU9kqQzYSagmmAY7WpKImDhsMqJo4JRlJvcrwSZCR0olpWjs60hhYtAaz9IYSVPqd7GpLozLbqMxlsImCBT6LA2RXgVO9oXTbGmIk8hoSBJ47RI+lx1REqkucBJOajhsEq2xDAlFxTShR9CFx2nDJoqs2NuGapqE05YFZ3z1uyS3fAK6ajU4JpyHs+eIzyzWmIZuucQseBo91oqr3wRC067NU/8Gq9jd8eEDaOFG3IOOI3TiDbl4wzRN0juX0THnQfR4K+5BxxGcejW24OFjxdjKt4h88iSmls1tE11+Si/6zWeOl8VXv0vHnPspPu/nOVFeI5Og4ZHvIHkLKLvyb53e6wbNz/8EpWUPFdfdh+wrxAbYRdC3zWHba3/n6u//hiuvuQ5JEOlIKdgkAYdNYmLvQnyu7pnrrwvdCfZ/KVbVhLnooSVoR0nzzdSup/n5n+AbP5OCA3x2DTVL4+O3YBo6FdfemxNkann9t2R2r8oTPMnUrKf5hZ8QPP7KnNLop5Ft3E7TU9/DN/YsvKNOw8gkURq3kd61nEztRjANnNWj8Y87G2efsYdMVrRYG21v/gGldQ+FZ3wPz8BjcsqiRiZBZOGzxFe/i+j0Ejz+SitZOoBCZGSTRJe8TGzlm4CJb/QZBCZdkOscHg6mppLc8gnxVW+jNO9CsLvwDJmKd+Spn9sN/3fFvzLBVlprSKyfbVH5Mon9hYuh0xHtn61waWQSRJe/TnzlG5i6hm/UaQSOuSRPHdwquswn/NFj6Il2PMNmEJp2de4+dxVlwvMeRU+GEd0B5EApsq+I0PTrP3P2PvzJk4hOD46KQUiuAMkt84kufp6S83+B61Meml2Ir/2Ajln35gWbejJMwyPfwVbUi9JLf48giJi6RuNT38VIRii//v68IlGXeFDXZ8wcVcHb6xo48M/90ok9u1XDvyZ0J9hfDbpjgi8PpmmyoyXOx9ua2d4Ux2eXcEgyimHNHWdVHc3Q6Uip9C/1Ut+RoqYtQXM4zroHvouWjDDtB09QXhZiZ3Oc+oRJdOkrRD55Iu9Zlt67lpYXf4Z/8oWEjr/SOrah0/j4rZ1FxQfyCp8HovGp72KkExSe8T1Ehwu1rZbM3rWkdizFSMeQfEX4Rp+Od9SpSC7/Qft/Vkxgmibp7UvomPswerzVsm+adg2yrzDvO0rvWk7k4ydQ2/dhLx9IaOpV9Og/ggK/jaDLQbHPTsBpQ0DE55Jw2m3UhZM07d7Glrkvs37BB+iqQs/Boxlw/NnY+k3B63bRkcqgmVDsdZHOqpT6nciyyKbaCAlVw9BB0az5YgVwAgU+CbddoneRl8ZYGp/DgdMu4JQFJFFiZ3McRVFRsjqtGRAFUE3oXygzsncJI6usmey2eIZdrTHCKR0Mg/ICFy5ZQkKkPppANww6Uioeu4zDJuFx2hlW5ieaUcioOuGURiSVJasZ2CQRTNjdGielaaCLFPrsVIXcGBi47DLr9sVoiqWwiwJZRUOUJQaV+sjoJsPLAyQVlWRWY3dbEgETE4s9V+BxIgC7GlppW/cRe5e+j9K0A8HmxDP0BPxjz/5cG1XrHq4gMv8p1Na92Ev7EjrhujxWJlhFl/C8R0jvWIpc0IOCGd/OEz9VI02EP3zQYsOFKkCULL92T+gz44LMvo0ktyzA1WskckEFRjpG61t/wlHaj5IL/u+Q+xiZBPUPfRNbUU9KL/l9rnDQ/sE9JNbPoezKv+Viy65EvPC0W3NOPi5AUNrYft+NFFcP4oW33mVsz0KaohnW10Uo8top9DjoWeTpFjX7GtE9g/1fjC9CXnb2HIF31KnEV76FZ+CxOcEn0eag8LRbaX7ux0TmP0XBjG8BUDDjWzQ8ciPtH9xL6cWW4Imz1whc/ScRXfISnmEn5i1gXVBb9wLgG3dOjn7rrByMf/xMtHgbiQ0fkljzPi2v/BJbUU/8E8/HM2RqbqHENCzrMJcPV59xaOEG65o7E2jR6aVgxrfwDj+Jjg8foGPWvSTWvk/oxBtyc2Ciw0No2tX4xpxBZOFzxFe9TWLdLHxjz8I/4bxDdgXBmgfzDp+BZ9iJKA3biK99j+TGeSTWfoCtpDfeYTPwDDkeyfOfM/P6tdtspeOkti4gsWEuSuM2EGXcAybjG3Uajp7DP5dGaGRTxFe9TWz5axjZpFVZPv6Kg0TPsk07Cc99iGzdZuylfSk650c4K/fTppXWvXTMeYDsvo3Yy/pbPu42B8FjLrFGJla9hbvfRJw9h2Oa5kHnFZp6Vd6//RPOJbr4eZS2mkMm2EY2SWTB0zgqh+T5XnfMeRBDzVB46s25YlJ06cuoLXsoPvenecm12r6PyMLncA+YgmfQsRT57Nx18Wgm9C7k529swDDBJgl8Y0zlQcfvRje68f8nUoqOqpn0CLpJZXTCiSyiHfoU+RAMg20tCZxOiUKfk4qgG1kUUXQwEBl3xY9Z+Lfv0DT3YSbd9ks2NcYA63mX2raQjjn34+w5HMkdwFU9Cs+wEy0tiUHHYi/pgyBKhE68gZYXf0Zs5Rt5oo4HQm2txTv6tNwz2l5cjWfw8RScfCPpXcuJr36PyPyniC55Ee+IU/BPOA/ZX3REMYEgCLgHTsHZewzRpS8TW/4aqR1LCUy+EN+4cxBtluCTu99EXH3GkdjwIdGFz9H8wk+I9hpB4bFX0HPgYPa2QjhjJcIuEUp8AmUBLwl3BeVn3krB1KvRtn3Eunlv8OGDv8Tm8tJz7HRCI0/EXtqXBlWjZ8jLsHIve9tS+JwCsmQnmlWQdTAEkDqp0e1xnYxDB1K4ZRHT1PE5LPGzPa0ZHLKExy4RFjM40VGyVvDudjlpi6bZZhMJONM0RDO0xNIkMjopRaM5lqY05GZvS5SWaIasAS7JxG53UOZzUOhVeL2+A8UQsckiLoeMqhuUeuxoJuxuiRFTVEwdehU4KAu68DltSKKJYUChW6IhopHRrX0DDhGHJJDRNNbUtCHLEvGsTjSZRhDBZZOwiyI71i5l34pZNK1bgKFmsBX1JDTjW3iHTc+jdB8KpmmS2buWyMJnUBq2IQfLKTrr+7gHH5fXoDGyKev+r3gdQZQJTr0K/7iZuaKPoWaJLX2F6LJXECTZEhWV7BiZ+BHFBc6qYQdpDbiqR5PZt/Gw5x5Z/AJGOk5o+vW5z8rUriexbhb+Ceflkmst2kz4kydwVo/GM/yk3P66aRKddT8iBlf96LcMrwjhddro65AJuC27NJdNwmPvTu3+HdB9F/7DsXR3+1F3r7sQmnYt6V2raHvvLiquuTtH83JWDcM39kziq97GPfAYnFXDrC7fCdfQMesfVnLaSW0NnXAdDY9+x7L/OuuOg44hdNLPTTVz0Guyr4jglIsJTDyf5NYFxJa+Qvu7fyO66DkCky/EM3Q6giSTbdyBoWYoO/8XuX0//bCzl/ah9NI/Wh3Mjx+n+bkf4R4wheC0q3PJmOwvpuj02whM/AaRRc8RW/oK8dXv4BtzJv7xM/O6oXnXIAg4egzC0WMQxonfJLn5ExIbPiQ872HCHz2Ks3oUnsFTcfefmKfM/v8rDCVNeudyklsXkN61EgwNW1EvQidch2fY9MN+z3mfkU0SX/UOsRVvYGTiuPpNIHjs5QdRr7RYG5EFT5HcOA/RHaTglJvzGAx6Ok500XPEV72DYHcROvEGfGPOpO3tv2ArshTVfWPPJrV1Pumdy3EeQdIP1pgFgCAfejQisvA5jFSM0AW/zH1eavsSUtsWEjzuipx4m9Kyh+jiF3EPnprvhWnotL13F6LNScHJ3wGgT6EVeHQpwb+/sZHThpV3i5p1oxvdyENa0cioOn6XTM9CNyN6BNnbkaQunCSeVim1O+kRcuKxSWxPqRiGTsgjo1f1ZdLMq1ny2qNUjZuOPTAAO6CIEoWn307jE7fTMeeBHCU7NP06y/rw/bspu+KvCKKEq3qUVXhf/CKeIScc0v5SsDkwlYNjAkGScQ+YgnvAFJTWvcSWv0Z8zbvE17yHd8RJBCZfgOwvOaKYQLQ7CR1/Bd7hMwh//BiR+U8RX/s+oalXdc7VigiihG/kKXiHnmBRz5e+Qv2z36ejejT+KRflEqikAXuiJvFsnKwKqg4idpx9TmHs8DOhZQubP3mbPctmsWvhW9iDZQSGTSU24njCqX4MKgtis9tRBRWbCqYNFMWypLJjiXypGsTTGVKiSC+nh7Sqs70pgV2WQDARRAlZlCjxyWhOnaSq47IJ7OlIsactiWoY+B0S4bROgceBIQg0xTO0J7OdgmqQVa1rycSz1LVnMaGT2g0Oh43qoAePUyKS1miMpFBNi0ou2sAUBdpTGg67xI7GJBlVwy4LhJxOkqpKOqsSSUNzIotTtiFLAuFUClkyiSVNlOadRDcvILplPmqsDdHuwjP4eHwjTkKuGPS5665pmmT2rCa6+EWy9ZuRfMXWej98Rt7svmnoJNbPIbLwGYxkxBoNnHpVTijPNE1S2xZZTLdYC47q0RSc9C3sBZW0vvnHfyouMNXsYWMCtW0f8VVv4x1xUi6RNtSMJWwaLMsJm5qmSfv7dwNQeOotyIKADAQdkNy2gJbNy5hx9R2MHjyIZNayhVMNg7RiIAoQ8ti7hc3+TdBNEf8Px3PLavlZZzerC6MqA6z9lDf24dAlwuSfeD6haVfnthtKhsbHLWuu8mvuQbS7rJmQF36K0rTLmgnpXDgj858muuRFSi/9w0EVPaW1hsbHbsqzADkcTNMgvXM50cUvWNYcwTICUy7GM/QEMAwE2Yapa/lCKEqaxPrZ+EadfkBlMkNs+evElr2KqasWnXjKxQdRwpXWvUQXv0hq60IEmx3vyFPxj5+J7C8+ou9Oaasluekjklvmo0ebQZRx9hqJe8BkXP0mfO488X8T9FSU9K6VpHYsIbNnNaamIHkLcA8+Hs+QadhL+x7RAqUnI8RWvWX5X2eTuPqOJ3DMJTjKB+S9z8gkiC57hfjKtzBNE/+4swlMvjBX/e6avw/PfxIzm0J0B3BWDcdUM5Rc8H8ktywgvXsFhafdhiBKZGrXk9y6EM/g449IATe1czmtr/4qz2KrC0rrXhofvzXP3k7PJGh85DuIniDlV96JIMmd1PDvoSc6qLjuH3mFh+iy14h8/FieiGCXBRfAZY8sRdEM7LLIs9dP6k6yvyZ0U8S/GnTHBP88TNMkldVoT2bZ0hjDBPxOGb/TxpCKIOv3dbC2NoaAQdBrR1FNyoNOtjRF2NoQxy5LNERTeAWNZ352FWo6ybjbHiRi2Ikp1jG6PHoPpIonty6k7c0/5Gm05OyG+k3MJeMHoum5H2GqGcqvuutzr0uLNhNd+gqJDXPABO/IkwlMuhDJHThkTACQ2PQRjvIBeY4SmZr1dMx7BLVlN/ay/gSnXZ0TaOuCoWSIr3nPYkylIjiqhhGYeH7n+Fr++iVidbfBUr5WASmbIrNzMR0bPyFTsw5MA3tBD0qHH0vZ8CkEe/TH5bSzL5whm4Uuc0w74JAg4JUQBBs+2bQ6ynaZWFYlrRg47RKZtIohmqi6gd9pwzQEVMArm+yLKBT77MRSCm6HjIEAhoFqQlrRkQRIq5DV9p+3jsWAlAGbBFVBO36Xjcqgm2hWQwbaUlkQBOySgGCK+Fw2IqmsZeEVV0hmFQwDVMWy2TJMy/qs2CVSu3097ZuWENm+NBcjeXqPxjVkGq7+E/P0fw4H09BJ7VhKbOkrKE07kHzFBCadb6naHzCCsJ/2/yRqey2OHkMITb8uz5fdmr9/mGztBgSbC8lXgKN8AEY6/k/HBaZpUv/AtTjKB1A888cHvdbyYmfs/M2Hcmt9eN6jxFa8nic4HF/zHh2z78t5zTsBUYQCMc7yv92Ar7yaX93/AuP7lVDkc1ERdNGWyGKa1uiAZpj0KvxsFkA3vlx0U8T/y7CqJsxrq+t4eeW+vORaAE4aWkaJ33lE4meuPmPxjjiZ2PLXcPeftJ8qbndSePrtND/3Y8IfP0Hhyd9BEEQKT72VxsdvpmPWvRSf/wsEQcA/+QISmz6iY/Z9lF99d95iZyuqQvIWkN614nMTbEEQcfefhKvfRNK7VhBd+Czt791FdOnLBI+51KIAfWohTW1fQnjuw8RXvkVw2jW4Bx6DaHMSPOYSvCNPIbrwOeJr3iOxcS7+cTPxT5iZS8LsxdUUn/ND1GMuJbrsZeKr3ia++h08g4/HP+Fc7CWf7YNsL+qJfepVBI+/EqVxO6mtC0ltX0zHrHthFtjL+uHqPRZn79E4KgYddO7/yTANHaVpJ+k9q8nsXkW2YRtgIvmK8I44GfegY3FUDjliCze1vY7YyjdIbJgLuoZ74BQCky/EXto3732GmiG++h1iS1/ByCTxDJlK8PgrcsJ1pmmS2b2K8EePobbXIocq8Iw7h+CxlwGw7+5L0dMxbIU9yNSuI71rBe7+k5CD5YiyA1PXjuh807tWINgcOCoG5W03TYOO2ffl9AC6EJ77EHoqav3NdP4OokteRG3ZbVHDD0iu1fZ9RBY8jav/JDxDpuW2a7rJ0t3tACiagWGCqhm5bUt3tzOpT2F3st2Nbvx/BtM0aYlnqe1IkspqgMCelgTlISd2ScQEivwuxvWWiGZUsqpOolMpq9zvIZlWyWgmLVGD7e0qlWfdxuZH/ocd795P2am35Y4TmHg+6R1L6Zh9H86qYUieIO6Bx+AeMIXIwmdx95uIragKW7AM/6QLiC58lvSIk/NmXgGcvUYSXfgcWqLjcwvRcqCUwlNuIjD5ADvN9XPwjTrVSrS9+c87Q80QnvcoRiZu6a0cczGSy4+z1wjKr76L5KaPiMx/hpYXfoqzejTB46/EUd4fsOKewMTz8I05ncS62cSWv0bLK/+Hrbga//hz8Qw+fn8h/4Bjqp3/Lznc+IbNwDZ0BnoyTGr7EtLbFrFv/svs++RFZHeAQP+x2HqOwdFrFBxQ9M/qoKsmIa+AoutkFVB1hY6sSshuo9AOTZpAVtVJZ01EQ0GWZTQDWrMaiSzIpoLNIRNw2vA6HSTTCmkDnFKWREYn4BTQNJNIBrqkuczO8zd12Neu4PMoxNIKpT43KdMgo8HAYheSTaIjnqUtlaU1lsUhmoSTWm6eXAOItKHUrCO5ZxUbd69GzyRAknH1GoV7ykW4+k8+7Ejep2GoGZIb5xFb8QZauAE5VE7BqbfgHTY9Jxjahcy+jUQ+eYps/Wbkgh4Uzfwx7gFTckURLdFBdMEzJNbPQXB6cQ86Hmf1SHwjTwG+nLhAbatBj7XiPIQeUWrLfDI16yk4+cbcWp+t30JsxRt4R52aS67VSBPhjx7DWT0a70iLJSqLYBdNtr12F7qmMvSC77OuLk5cERjTO0RZ4NBFClW3FORlUcDZrSj+L0F3B/s/EM8tq+Xnb25EP4yE9O/OHc7AMh+XPGx1uT4PlhXXTQg2B+VX/z1PEbxj7sPEV76ZpyoeW/kW4bkP5XWlUzuX0frqrwlOveqguauOOQ8QXzeLypufzpsx/TyYpkl6x1IiC59Fbd2LragXweMuO8jXOr1nDeF5j6C21eCoGETwhGtxVg7Jva621xGZ/xSp7YsRnT78E8/DN+bMnIBbF7RoC7EVr5NYPwdTzeDoORz/2LMs+6cj9EU2TRO1rYb0zuWkd60k27AVTAPB5sRRORRnz2E4KofgKOufo+T/J8DUVZTm3WTrNpHZt5HMvk2Y2SQgWIWEvuNw9Zt4xJ1qsJLRzJ41xFa9RWb3KpBseIdNxz/hvLzOA1j2GvF1s4gteQk9GcbZeyyhqVflUcaV5l20vfd31JbdFn1sxjdx9puA2HnvIoueR23ZY1G0i6qIr3mPbP0WCs/4LoIg0vLKL63iwIDJh5zDPvBc6v5xFc7qUQd1Z+LrZtHxwT15wiRd3e7A5IsIHn8FcIBq+JCpOS9Z6FQNf+YHaOEGKq6776DgceaoCtwOmVdW1aHrBjZZ5H/PHMqv3tnU3dH+GtDdwf5q0B0T/HNoiWfYWBchmdUBg4yqo+oCxT4HpQEnfYq9yKJAYyRNUtGQRQFFM2mKZoiksvQIOWlNKCza1sj6+gSYGuvefJTmhS9R+o3/xdlvQu5YSmsNjU/ehqvveIpn/gRBEDrFG29EDpVTdvmfLQ0VTaHhsZsBk/Jr7s2LK9T2fTQ88h2C064lMPG8o7pWLdpMZNELJDfORZBt+MachX/ieXliaHoyTGTBsyTWz0a0u/BPvhD/2LNya66pKcRXv0N06SsY6Riu/pMIHnsZ9pLeeccydZXk5vnElr+G2laD5AnhHX06vpGnHvRsPhxsWAwmZe9K4jtWkNyzBj1tzbbbiqtx9RxOWf8ReCoG4isoIKWB0yYRcMt0JBTARBQl0oqOIQioGRO1M/Qr9ksk0jqmYHWQVSDoFqgKudFNE00zkSQDBBG/QybokjEMWLIrTFjNP88uAVQnIEswuMKLYQiYaBT6HMRSBnbJKvTu7YijaCbR5lZS9ZtJ7NtIunYjWkcdALInSKDfWNx9xkOvMUflvqLFWqymyNpZGJk49rL++Cech3vglIPisGzDNiILniGzdw2St4DAlIutznZnEdvIpogtf53ospdB1/EMP5HQ9OsRHe5c8f/LigvCHz9ObPnrVN70ZJ4uj55J0PDIt5G77OtEyRITfuI2TC1LxbX/QHS4MQ3dYog276biun8g+4sRgYATYuvmsPuNv9P79G8zcPo3CLodjKgMMqpnAQMq/ITcDuIZFUzLB9sui9SF0+imCaZJj5C7O8n+CtGtIv5fglU1Yc6/f/FhxaoE4I5TBnLTCf1yXe6N9VHW10Xz9pFFODD37lIE9Y07h4ITb8htzz0I1IylKu707rcPaN5NxXX3IvstlcXW139HevfKPKVxsGg5jU/cSuiE6/BPOPeor9k0DVJbFhBZ9BxaRz32sn4Ej7sCZ+8xuYedaeidYiXPoic6rAXzuCsO8j2OLniG9O6ViC4//gnn4ht9xkEPfz2TILHuA+Kr30WPtSL5ivGOOgXviJOPmvatZxJka9aTrllHtnY9avs+6wVRxl7SG3tZP+t/Jb2xFfU8IsrUVw1TU1Db96G07EFp3oXSuINs8y7QrRVZDpXjrBqOs9dInNWjjmim+kDoqSjJjXOJr30fLdyI6AniG3U6vtGnH0TjN9QsiXWziC17BT3RgaNqGMHjLs+ja6mRJiILnia1+RMQRJy9RiDYHLh6j8U3+nRMTSW1fTHxdR/gGXQc8bXvUzD9BmyFVbTPugc5UIqtqCfJDXMJzfjmQXT0TyOxcR7t7/6Nkgt/ndeZySmEF1fnFEL1dIzGR29CdAcov+pOBMmGoWZpevJ2jGyK8uv+ka8avuQlIvOfouis7+MZMvWQxxcFkEWBaQNLKPI5EIDnl9dimCAJ8L2Trb//bnz56E6wvxp0xwRfHBlVZ+H2JtbXx5AEKzlz2mWKvA6cNokyv4u+JV6ctk7hUCCt6myoi9ISSyOJAj6njNdhY8mOFpbsamNrU5T2uErdU9/FSEWt59QBCWx02atEPn6cwjO+i3fYiQAkN39C29t/znMVOZTSeBeanvk+eipCxQ0PHjHT6UCoHfVEFj1HavN8BLsT//hz8Y+fmbeeK617CX/8OJndq5B8xQSPvRTPsOm5RM3IpoiteIPYijcwlRTuAVMITLn4IK2PrvnfXDG4S6xz9Ok4qoblz36T390GK2kVBZBlULI68eZdZPauJV2znkzDFkzV6ic7gsV4ewzAU96Pwqp+6AU98YVKEQSTeEbHaxcJpw1SupUMu0TIGuARQRAhpnUmyFgK5Q5LCByvC8b2KiSmGCTSKpFEiqaoQfqAc/y0w4jPDtUFLqIplYyukeloIVy3F711F8mGXSQbt6MnOqx97S4clUPw9hxJycBRuEuryagiirG/U/5Z6Cq2x9e+T3rncgDc/SfhG3c2jsqhByW12fotRBa9QGbPKiuWm3g+vjGn5+InU1eJr/2A6OIXMVIRxM6utZ4K46oe/aXHBaamUHf/NTh6DKbkvJ/lvdalEF5+1Z05Rt4hG1fLXyP80WMUnn473uEzcAIeJ4jpDtbc/W185X3pedlvkSURWYJxPYvoXeJnWA8/vYt96KaBQ5IQRAG/Uyaa0fA6ZNKKTtBlI+j5z2no/KehO8H+L8GFDyxm+d7wYV+XRLhofE++MaYy18VaVRPmgvsXH/TQ/zTaZ99PYs17nXOl+21/so3baXr6DjyDj88JmamRJhofuxlHxSBKLvoVgiCixdtpeOQ7OMr7UXLRb/Meik3P/hAt1kKPbz78hanSpqFb9K5Fz6NHm3H0GELguMvy5qgMJUN85ZtEl72KqWbwDJ1G4JhLcwrmANn6rUQWPW89nJ1efGPOwjf2zIMSRdPQSe9YRnzNu9Y8lSDi6jse74iTLZ/OL3AdeipKtn4L2fqtZBu3ozTtxFRSna8KyIES5MJKbKEeyMEyy0bKX4zkK0R0+b5QIPJpmKaJkYmjx9vR4m3o0WbUSBNaRz1qRz1apAlM69ci2JzYS/vgKB+IvWIgjsohX2i23DR0MjXrSKyfQ2rHEtA1HD2G4BtzOu6BxxxE+TKyKWsWbuUbGMnOWbhjLsnzxNTi7USXvEhi7QcIkg1bUU8Cky/APWAKmX0baXv7r1RcZ1WHDTWb66DEVr1NZs9qSs7/BVq8nfTOZWRq1uMbd06e8vjhvrvGJ24FXaP8uvvyfuOtb/6R1I4lVFxzT07ErPXNP5LavoTyK/+WC9pyi+uFv8LVe0xuf6VlN41Pfg93/0kUz/zR536nkihgmiayKIAg5Dra3R3srw7dCfZXg+6Y4ItjV3OcF1fto64jSTytMLzcT69iDyGvC0zoEXJS7HdR4LYjS9b6oWgGH25uoj6cIuSxoelwTP8itjfHmL+liWU7m6np0Ik276bxqe/hHjA5j61jGjrNz/8YpWVvrshumiZtb/6R1I6lVjLR2Q1ue/dOkps/7ty2P3FNbplP21t/ovjcn+AeMOULX7/Supfowuf2M9QmnItv7Fl5DLVMzXrCHz+O0rQDW2EVgeMuxz1gcm491TMJ4iveILbyLUwlhavvePyTLjzkeqB21Ft+zRvnYmSTyAU98A4/Cc/QEw7ppAJWF1sACjyQTINhgGwHt11ENxSMlhoiezYTqd1Con472XBTbl/J6cVX2hNnYQX+0kpSciE2fwmmrwjRU4Ag2/g8ArMD6FXoRDEMFE1D1Q1iKSsJ78oADDWLngyjx9vQos1o0RYIN6CE68i21WHkPMsF5IIeOMr7Yy8fgLNyCLbi6lzRIiiC1wOt8c9PrrVoM4kNc0ls+BA91oLoDuAdcRK+UacfZI9lmiaZmnVEl7xEtnZ9Z5PkPHxjzsjda8sT+2Nr/CDajKNqmMU6GHESrt5jvrK4oMuSs+Ti3+bFo112uP4J51kq5UCmdgPNz/8E7+jTKDz5RqsgEqth80O34e4zjqrzf0p50EGB20bfIjdP/N+NtO/dzPhbH8RZUIIpiqQzGhVBJ9XFfoZU+KgIeIhkFAaW+OlR6CaZ1ZFFax4eEypCru4O9leIryzBFgShCngKKMX6W33INM2/f9Y+3YvpF8ek331IU+zwjy1JsG7CgVTRVTVhLnhgMYdhlOdgKBkan7gFU+/0wD6gEhxZ+BzRRc9RdPYP8Qw+Dtj/UAnN+Bb+sWflbSs45eac0jhAatcKWl/5ZR519ovC1FUS6+cQXfwieqIdR8/hBI+9LK+rqaeilkr4mncxDR3v8JMITL4w76GdbdxOdMlLpHcsRbA58A4/Cd/4mXnJeBfUjnoS62eT2DgXIxlBdAfwDD4ez+Cp2CsGHjEt+qBrMQ20SDNqyx6UthrUtlrUjjq0cOPByuuihOQOIDp9iE6vRXOyORFku5WgilLneZiYuo6pa5haFlPNYChpjEwCIx1DT8XAyF+SBdmBHCzDVtADW1FPbEW9sJf0Rg6VHzE9/uBrM1Fb9pDc/DHJzR+jJzoQnT48Q6fhHXkK9uLqg/bR4m3WLPya9zGVFM7q0QQmX5hX8NETYSJLXya57gNMXcM78hR8Y88mtuxlfKNOw14+AEGUaH39d9hL+xKYchGmaeSCKbWjnvC8Ryg650d5tMUjQdcoRFeVObd9+xJaX/8tweOuIDDlIuCAjs4B23JMkTFnUHDSd/Z/V5pqecOmopRfe+9RsQIkAS6a0JMeQVf3DPZXjO4E+8hwtHFBd0zwxbGhPsL7GxroiGVpSSr0LfEyuMzP6J4hmqIZvC4Zn8OGwyZS6nMiCBBJqWxqiJBSdcJJBYckMr53IbXtKdbta2POxiY2NyRJm4dn1aiRJhofvwV7Wb9O604RPRWl4bGbkNydYo6yDT0ds+jj/qKc0jhYyVDDw99GdHoou/LOL7yGdiGPoeYOEJj4DbyjD+hqmiap7YuJzH8araMOW0kfgsdehqvfhNyxjUyC2Kq3ia96GyMdw9FjCP6J5+HqO/6gddBQM6S2LiKxfhbZus2d7KmReIaegLv/pLz46cDeocMGEpBRwe0Av0umyGOjNZ4lljHwOkU6IjHE9j2Y4QY66ndZSX1LHdlY+0HXLTo8iG6/FRc4PIh2F4LNYcUEkpy7NskwcEg6mayCpmTRs2kMJYmRTqCnogcU+rsgIPmKcBRWYCuoQirqia24Gntx9WdSvmX2d9aVQ7yup+Okti8muekjsvs2AgLOXiPxjrRo2J8utpu6RmrbImLLX0Np3oXkLcA//ly8o05DtHfeW0MntXUh4YXPoYfrsZX0sUbIKgYSnvvQVxoXmLpGw8PfQnT7Kbvib/t/S2qGxsduAUzKr70X0ea0xjEfvwVBFCm/+h5Eu5Mqt87aB76Llggz7tZ/oDn8BL0OCt0OEmve480H/8CJ1/2Y8vGn0Z5SMU0QTCjwOXDLEg6HzIAyD167nfKgiwFlPmRRoiLotGawJQGH3J1cf5X4KhPscqDcNM3VgiD4gFXATNM0Nx9un+7F9IvjD+9t4YH5uw/5miwK6IaJiUVTOqZ/EbfPGMDS3e38Zda2PPqPCIiigPaprDtbv4WmZ3+IZ+h0is64Pbd9/2xoPeXX3JvzpGx55f/I1m6g/Kq/Yyuq+pTS+D9yitymadL01PfQU1F63PBgnvrjF4WpKcTXvk906csYyQjOXiMJHHspzsqhufdo8XZiS18ivnYWAN7hM/BPOj8viVbaaokte43k5o/BNHD1n4h/7NkHUb/Aepim96wiuXEeqZ3LQVeRAqV4Bh6De+Ax2Mv7f3ld5lTUqiTHWtETHejJMEYqip6OYWSTmNkUhprBVBVMQ7XK4l1/z5KEINoQbHZEmxPB7rKScqcPyR1A8oSQvAVWdzxQguQJfmnnrbbuIbV1Eclti6yZLFHC1WcsnqHTcfebeMh7n23YRmzVW6S2LgTTxD3wGPwTv5GzswDrXrZ/cDeZPavBBPeQaQSPuyx3L9tn3YvkDhI87nLrMxt30Pr6b6m88QlMXcU0DKtIsn42vtFn5BWAjuzaDGtcQslQcf39OQaDno7T+OiNeQrhWqyNxsduQi6spOyyPyGIkvW+x25GsLsov/quvHGA8EePEVv+GsXn/wL3IXy1DwcBcNi6u9ZfF7oT7CPD0cYF3THBF0c4ofDq6lr2tqbQMOhb5CHottO72EtS0elV6MYmikTSKk6bhCwKeB0y9ZE0sYyKbhj0CLqwSzLRdJZ4VuWVFTXM29xCVoGUodP87A9Q2+sov/YfebZbifVzaH//7wSnXUNg4jeA/ZoTB3btckrjB9DHARIbPqT9vbv+6S72gcjWb7XmcmvWInqCBCZ8A+/o0/Yn2l1dzkUvoEUaO5Oti3H1n5hbAw0lQ2L9bGIr3kCPtSCHyvGNPRvvsBMPmVyqHfUkNs6zCsnRZgTZjqvPONwDj8HVd3xuHzeWcrckgGpCiceik0uSiKEbKKqVeGsmOGWLIlzosSMIMk2xFI3RDHq0FS3WghZvR092YCQjVkyQjmNkU5hqGlPNWmueru2PCUTR0iORbFZh3u5EdHgRnV4kt9+KCTwhJF+RxaTzF39pOjF6Ok565zJS2xaR3rMGDA25oAeeoSfgHTr9oG511z6J9bOIr3oXPd6KXFCJf8K5eIdOz8UQpqGT3DKfyMdPoCfaEd0BQtOvxzNkau5eftVxQXzt+3TM+sdBa3cXU+1AhfC2d+8iuWkepZf+EWflYCQg9sljtC19jeFX/C++AZOIpw1KAzbEWDPz/nQDfUdO4NY/PEw4pZBRdERZJBJP057QqCry4HHIuO0iBV4nbpvE6J4hehZ5upPqrxFfG0VcEIQ3gXtN05xzuPd0L6ZfDKtqwizd3c6cTU0H2XBJosANx/bmiSV7UVQDA2vmx36gCJJqIIoC1x/bG5/LRshtP6RYWnj+08SWvHjQoqeGG2h8/FYcFQMoueg3VsU6EabhsZssAYcr/oIg2Trp4zfh6DGEkgt/lUtS03vW0PLSzwlNvx7/+Jlf2vdiqBkSa94nuuxVjFQEZ68RBKZcjKNqv3ehFmsluvRlEutng2HgGToN/8TzsXd6HoKVwMVXv0Ni7QcYmTi24mp8o0/HM2TaIRdVI5sktX0JyS0LyNSsBUNH8hbg6jcRV9/xOHuN+LeYqf6qYWoqmX0bSe9aTmrncsuOQxBxVA3DM/g43AOPyZvf64KhZkltXUh8zbsojdsR7C68I07GN/asvAKIGmkituzVznun4+g1EkGyYS/qiXvgMTkbDqW1hra3/0zJ+f9nFQwkmabnf0Jg0gXYy/paAVW8leCUSw6asTsSJDZ9RPs7f6XorDvy1L1b3/4zqa0LKb/yTuylfTBNg5YXf0a2YTvl19yNLVTRSZ/8A6kdSym74q95hYMcZWzkKRSeevNRnZMkCvz6nGE5b+xufLXoTrC/GD4vLuiOCY4epmmSyHYykQyTvW0xaiMZBEGkuLO7VRJwEs9oGIZJVjMo9NrJagYOWcRpk0gqGn6nDZ/TSlha42lW7GznzTU1rKyNkNZANCDbVs+uJ27FXTmYwVf8imhWtOjFpknbG78ntXM55Vf+NTdj2j7rXhJrZ1F6yW9zNoatb/ye1M5llF91V469ZBo6DY/eBEDFtfd+qU4bmbpNRBc+R6ZmHaI7aNGJR5+WTyfe9BHRxS+iRRqxFfXCP+l8Syn8gC57attiYivfQGnYZq1Rw6bjHXV6nr7LgfckW7+F1Jb5pLYtQk+GO+07R+DuNwFnn3HYgmVIWIm2HatIahfB7RRIZU3i+v7PswE+B6Sz5M1LfxoyfC5N/MuCHevc9c94j2maaB31pHevJLVzudWpNg0kfzGegcfiHjL1sIKo2cYdxNe8R2rLfEwti6PnCPzjz7FYBJ1Js6mpJDbOJbb8VbRwI4Jsxzf2bLRIE3Kg5GuLCwwlTcPD30IOlFF62R9z15Op20Tzsz/K0cABktsW0fbG73NipyKQrFlHyws/IzT6FEpOvhlZhpQGDl2l7vkfoEVb+PZdr6A5g/QtdtOvNIhdsppjm+uiiBLIooRhmgws89O7yIMoilQXddt0fZ34WhJsQRCqgfnAMNM0Y5967ZvANwF69uw5tqam5ks77v8PWFUTzvO+vXpyNW+src/RxbvEjeJplRdW7iOSUvO2T+pTmGfjs6omzM9e38CWpvhBxzJ1laZnvo8WbaH8mnvy5ori62bT8cHdBKddTWCi5XnZRY890Es75+V38o34Rp+e27/5xZ+jNG7P8wL8smCoGRJrP7CUQVMRRE+IwOQL8Y0584C53TZiy18nse4DTDWLq99E/BPOzRPSMNQMyc3zia9+B7VlN4LdhWfw8XhHnGzRjA6xKOiZBOmdy0jvWEZ6z2qL3i3ZcFYOxVk9CmevEdaC8gXp1v9OME0DtXUvmZr1pPeuIbtvI6aaRZDtOHuNxNV/Eu5+Ew8SLOuC0rKHxPrZJDfO65xhq8Q35oyDugPZpp3Elr1KatsiEEXsJX2QQ+UUn2X9NlPbFqInIwSnXZO7J+2z/oFod+EdeQpysIz2d+8kOO0aZF8hRiaBeBQq9gfCUDI0PPJtJE+Qsiv/llvoc4vmsZcRPOYSYL9YScGpt+RsQHKdnk+p7BuZhEUZk2TKr777IGV7r0Mikc0PZbx2iYRibesWNft60Z1gHz0OFxd0xwT/HNriGSJpFQHwOGTKAi6aohkUXccuiaQUnQq/k23NccJJFVmC6mIvqmbgc8oU+fZTpwVBoLYtwbLdHXy4uYFoWqW2I004rmK3Q4HHRmzN+6x+8S4GnXMjjmGnYwK6AfFkjPrHbkawu3PMHEPJ0PjkbZhq1hp5cXot+vijNyL7inLFeIDUjmW0vvbrvFGzLxOZuk1E5j9Ddt8Gy6lixAxCx1+VWwu6uqCxJS+jttciBUrxj5+Jd/iMvOdxtmEb8dXvkNy6IKch4h15imUNaj+4kG6aBtn6LaS3LyW1cylauBEAuaASV/UonL1G4ug5PCdy6cCapTgUrfpw6BIm6/Lh/ldCT0XJ1KwnU7OOzN41aFHLJtZW1BNXv4m4B0zGXtb/kPGTkU2S3PwJiXWzUJp3IdgceIacgG/smXmjZEYmQXztB8RXvYWe6MBe1g9bYRWCzUHhKTd/7XFBZMEzRBe/QNnlf8bRw5rTNpQ0jY/fAmYnNdzuQou30fjYzZbK/mV/RpBk7OkYux+7BdHuos81d6HITuxYc+sd858ituQlpn37t5w9cyayJDK8wo+GwJ7WBAVeJ6qqsqs1QVXITe8SH2UBFzZJwiYJ9AgduWp7N/55fOUJtiAIXuAT4Lemab72We/trlYfPf7x0U7+OntbnlLwpD6FXPbIUlTNEje6enJ1Hn38QPoowGur6zCBYRUBfv7mBvTPUD1T2/fR+MTtOCqHUHLhL/dXDnMV62WUXf6XnH9k+wf3kFg3m5KLf4Or10iri/fSL8jWb6b86rtztktKaw2Nj9+Cd8TJn9utM02D5MaPSO1YgpGOIzo92EI98I4+LU+l/NNoe/fOzlnmekwlhVzQg+Cxl+fZPOipKPHV7xBf/S5GOoa9vD/+cTOt93Qu/KZpojRsI772A1JbF2BqWWyFPfEMm45nyLQ8qlzeeXd1dHevJLN3DWpbrXU/7C4cPQZb/6sYhL28/1HZlv2rYGRTKE07yTZs7RRo24KRSQAgF/SwgoU+43D2HH7Yjr2eDJPcMp/kxnkozbtAknEPmIJv1Kl5TAPT0EnvXE5s5Ztk921EsLvxjToV37hz0CKNxJa9SvG5P0GQbGT2bSS5+RPcAybnxMK0eDvJzR+TqVmHHmvFXtaPglNuxtAUlH0byTZsQ+2oy/2eXH3G4R15yucWPrpYHaWX/TE3gpBjbwRKLWsaSbbst56+A1e//RY2akc9jU/chr28P6UX/SbvWK1v/5nUlgXWAt1ZcT8SyJKAaZjdomZfM7oT7KPDkcYF3THB0aOmPYlNEpFEgURGpU+xF0U3aIpm0A2TkNtONK2wvSlOwGMjEleoLvES8tgp9DhIZjUaIik03cTtkFlV00ZzTGHB9haaIkm8TpmWhILXZiPksZFVVFY+9r+0bVvJ4G/dhRTqRVK1OqfK3rU0vvhz/J0sHAOrE9n0zB24+0+i6JwfIQjC/mL85IsIdVoWmqZJy4s/R2nacUSFd6Wtlviqt1E76hAEEclfjKt69GFdFwDa378bLd6OFmlECzeA7MA/9iz8487J2W2ZpkF65wpiy14hW78F0enFO/JUfGPOyI26gRU7JDbMJbF+FlpHPYLdhXvgMXiHTsfRc9hhR63UjnrSu1aS3rPaKkprWUDAVtzLsu7sMRhH+QBL9+QoxrVErCT765QrNg0dtaMepWGbFRfUbc45pAh2F86eI3D1HoOr7zjkQOlhPyOzdy2JjfNI71iCqSnYiqvxjjwF77DpiI79XVg13EB81dskNnyIqaRx9hqFf9L5OHuNJFu36QvFBaGTvmO5pdRtRmndgx5vwzTBXtwL37hzDqnFcyDUSBMNj3wH94ApFJ/9/dz29g/uJbFuFqWX/h5n1TBLEPDFn6M0bsvFwqZp0v7G70juXEHFFX/BX94PSQYMiNVsZN+zP6Z4zEnMuP5nzBzTgz0daUQBXA4bwyv8tMWyNMcyNMbSlHqdBNwOpvQvIuCyETxAzLAbXw++0gRbEAQb8A4wyzTNv33e+7sX06NHVwe7K5k+UMCsqzN914fbWbCjLbdPdaGbv144CoCLH1qCqlv3WhJAP4Lb3tWF/jSlW0/HaXz8VgRZpvyqv1tqjEqGxidvx1RSlF9zD5I7sH8OtaCSssv+mKN/WbMpb1F2xV8+M6kIf/IksaUvIwfLkTorjWpHvaVGegiBLLDUQmMr36DgpBuRPIGcermRjHRWps/FO2JGLhE01AzJDXOJrXwTLdyA5C3AO+o0q9J5gFq2kU11JohzydZvAQQcVUPxDD4e94Aph+3WgpWIZWrXk9m30VqI2mrpWg7lYBn2kj7YSnpjL+qFrbAKOVR2kNDH1wFT19CizajtdahtNZZNV8uenLclWNV3Z+UQHFXDcPYcnhd4fBp6JkF6+xKSW+ZbKuymgb20L57hM/AMmZrvW5qKktjwIfE176FHm5H8xfjHnoV35Cm5hVaLtRBd8jKu6tG4B05BT4ZJbJhrKcGPOrXTgsZEEESrCKBrqE07Se1YQrbe8iNHlLEVVCC6/OjJCFpHnVXsOe3Ww16H2l5Hw+M34xl0XM6z2jRNWl7+P7L7NlB+9d+xFVZZlesnb8dUMpRfew+Sy4+pqTQ9c8d+NsiB84udlPMDu99Hin4lXs4d3aNb1OxrRneCfeQ4mrigOyY4ekSSCm2JLIZpEnLbKPZ3Up9N0yrEiwJra8NsbozidcpkFYMTBpVQ5HPSFs+wdl+Y9oRCoddOkdfJlsYoO5sTNEdTbGmMUex1oOo6vUr8VkIQz6Akorzxi8sRnD56X3MnWcGBjtVBbfn4cWLLXqVs5o/xDjwGFYgue4XIx0/kiZ7un0P9A87KIQCobfusZ+yQaRSd8d3DXrPaXkfjk7eDIGAv7g2mgRptwt1/EoWnHLpgf2BMIPuLaH7pfzEyCZSmnSAIeIZOJzDxvJzzA1haNLHlr5PasRTAsuQac0Z+Mdg0ydZtIrFhLqltCzGVNJK3EPfg4/AMOu6wbDewivDZxm1kajeQrdtMtmErZqdKt2B3Yy/tY1l3FldjK+yJrbASyeXL7f9pS62vCqZpYqRjqO37UNtqUVprLEHWlt05EVbR4cHZYzD2qqE4q4ZjL+t3WKq/aegWhX7rApJbF+230Bp8PN7hM/I63KZpkNm9mviad0nvWgmihGfwcfjHn5tH5T6auCBTuwGlbR9K41bSO1dgZCwGp+QtyBUClOZdCLKD8mvvPawqvGmatL7ySzL7NlJxwwPIPmtd7xJBPVB/oEsksPC02/COOAnYH1uXnXgtBePOI+gTKPW5qG9tY809NyFKMlO++xAj+pQwsDyAJIqUB1y0JxT6lnpJZHT2tCbIaBoVATeabnLsgCKqi/79Gzb/jThcXPBPD7wI1l/Do8CWI0muu/HFMLZXiP89cyjvb2zktGHluaB6bK9Q7r+HlvvzEuxRVUHG9grxzadW5pJrOLLkGsA76jTSu1cR/uQJi+LcabEhuXwUnfU/ND//Ezpm30fhmf+DaHdSfPYPaHz6e5ZoyTf+F9lfRMEpN9P21h+JLn4hJzIRPPYyUlsX0v7BPZRfddchH8amaRBf/Q7uQcdRdPYP9j90dQ3EQ1fnTE0l27AVe2k/ZH8RhpK21DzdASSnj+iyVwl/+ACRhc9iL+2D5A7gHT4D35gz8Iw6lcye1cRXvU104bNEF7+Au99EvKNOw1k9EtHR2UkddSpquIHk5k9Ibv6Ejtn30THnARyVQ6yKab+JB1U+JW8Iz5CpuQq7kU2SbdyB0mnTpbTsIbV9CbklUxCR/cXIwTIkfwmyvwjJW2gJkLj9iC6/pRbqcFuCJZ+hvmqaJugahpKyVMQzlmKongxbwmmxVrRYK1q0ybLlMPbTkaVAKfaS3niGTMVR1h97xcC8Rf5Q0OJtpHcuJ7V9CZna9WDoyIFSa7ZtyLS8mfdcgLL2A5LbFoGu4qgaRmjaNZaa6Ke6ypK3EHtJb9J7VuHqNx7JE0JPtIMgosVaSayfjav/JJTGHZZCad0mAGwlvS0l8upROMoH5omsdcx5gPia9whNv/6Qc/amadI+615E2UFo2rW57fHV75DZs4qCk76dC8w65jyA1tFA6SW/zRUPwp88gdK8i+LzfpaXXKuRJjpm32dZzR0g+nOk6FPk6aaFd+PfFt1xwVcPn8tGRzKLoptoBhiGiSgKCIKA1LkkeOwSZT4HtR1pvE6ZaFrDaVPZ1ZoknFTRDZNwSqHY46BXgZva9gRVhR4qCz2U+pykszopTSOaUulXXURr0se4K37Movt+QMu8xyk48ds5enLJcVeg7ttAy/t3I5X1s577E84js2ct4bkP46gcgr2oJwUzvkl23wba3vkrFdfcjejwYCuqwj/hXGJLX8E77MScKNSnkdgwB9PQ6PGtR3JJDVhr/6FwqJjAWTUMyRPEUTWM2PLXSW6cS3LDHOTCSmR/Kb5xZ+PuM5aimYPQYy3EV79LYv1sUtsWWeNMo07DM+wEJJcfZ9UwnFXDME76Fukdy0hu+YT4qneIr3gDyVeMe8Ak3P0n46gckhfnCLItty90doTb95Ft2GbFBM27SayfnfPHBhCdPuRQGbJ/v3WnzRNC8IQQXb5OFXE3gt35uR1w09Ax1SxGtismiGGkIuiJDks8LdZqCaxGmjCyyf3nbXdhL+6dS4Yd5QOQC3t85vFMTSVTu57UjqWkdyxFT4YtEbi+4/EMmYqrz/i8NVlPhElsmENi3Sy0aLMlVDflYryjTzukRejnxQXxdbOxFfYgW7Oe5LZFmNkkotOLq+94XH3G4eg5PO9z1bZ9NDz6HZKbPyEw8bxDXlNq2yLSu1cSmn597neoJ8K0v/d3bCW9CR5nsTMydVuILHgG96Dj8HS6jiitewnPewRP7zGUHTPT+l4libRu0jbrQbR4BxNvuYvq8gJ6F7hxO+yE3DZ000QQIKsZBJ02RvcMsbY+TFsyS+8iNwXur78p043PxpehKHEMcAWwQRCEtZ3bfmKa5ntfwmd3oxOrasKWUJlmsGJvBwPLfAd1rnyu/D+wd9Y3MqF3IXO3tuRtlwQ4a2QFb6xtyG1z2yVSSv6spyAIFJ52K42P30Lrm3+i/Kq7crNGzk5f4ujCZ3FWj7IeuKV9CJ1wHeEPHyS+4g38E87FM/g40rtXEl3ykjWLXDUM0eGm4KRv0/r6b4ktf+2QCYapqZhKGntJ77wE8rNEUPR0DC3ShHvw8da/kxEMJYXoDuAeOAXXgMmkd60gPPdhsp2+1uk9qyk86Ubcg4/D1XsM7r7jLUXQtR+Q2DiX1PbFSP4SvMNPxDPsRGzBMmyhCoLHXEJgysWorXtJbVtEavtiwnMfJjz3YWyFPXH1HYez9xiclUMOUuIUHR5c1aNwVY/KbTOUjFUl7qhDa6+zfKmjTai7V6InIxy2Xi2IeTZdudK2oWMamrVAm4efBRDdQeRAMfbSfrgHHmvZdBVUYivq+ZlWHF0wdZVswzbSe1aT2b3Kon8DcqgC//iZlrL6p+autFgbyc0fkdjwYSfFzm2Jm40+LY+ZYOpqXidfECVcfSdYM3WLXiB0/BWYahbR6UENN6C07iW27FVMTUEuqCRw3OV4Bh2XG084FGwlVhfEyCYOeb2JdbPI7ttIwSk356iESssewh89Znmijz7Det/GeSQ3ziUw5ZKcqE9q5zLiK9/EN+ZM3P0nHXBdGm1v/RkEkaKz7vhceroowIzBpXy4pRnDtOjh35ra9zP36UY3/sXojgu+YqQUDQSBQq+DeEYjqxm47PnPEr/bToVpIogipX4ndlmkI6VgGCYum0RSUZElmSK/k152Cc2EjKITy6j0L/GhmwZpxUDVdRJplfKQG/upJ7F73bk0LnodR6/RVA6bSNoEwZTpd/4P2PTgrYTf/hMll/wRJJnCM79H4+O30vbmHyi78m+IDjdFZ91B07M/pH32fRSdeQeCIBCYcjGprYuswvu19xxy3EhPRZFcgbzkGjisK8lnxQS2UAWFp9yEb9w5tL93F0rzLrT2OjK1GwhMugD/xHORfEWETriWwLGXkdq6gPia9wnPe5jwJ4/j7j8Z7/AZOKtHIdqcuSK6kUmQ2rGM1PZFnfPCbyM4PLiqR+PqMxZn9eiDxssEUcLeaYFFp26HaRpo0RbUtlq0jjrUcANapBmlZTfpXcsxtcNPawuy3Sq+i9L+hoRhWIm1poB++IltweZA8hUjB0pxVwzCFqrI2XdK/uIjslNTI01k9q4hvXsVmb1rMdUMgs2Jq89Y3AOm5Cmrg7XWp3etJLHhQ9K7VoBp4KgaRvD4K/PH9j4VE3R9d4eKCwASG+eS3PgheqwVwebEPWAynsFTcVaPOmwsKRdWgiRjpKOHfF1Px+mY8wD2sn74OjUDTNOg7d2/YaoZis76fqc9XZy2t/+E7C+m8NSbEQQBQ8nQ9uYfER1uKs78LjZRRAIq/W6aVs9m94q5TDz/25T1HYbDJiFINib2DmHosL4xSnnAScBlI+C0U13koX+pl6yqE3Tb8Ti7E+x/N/zTCbZpmguxwvpufIVYursdRTMwTFA1g6W72w9KsCf1KUQ+wHrLME3e39jIgWMAAvDrmcMJpxREgZw3dlrR8/7dBckdoPCM79Hy4s8Jz30oj0obmHwhmdoNdMy+H3v5AOxFPfGNOZNM7XrCnzxhzRVVDKRgxrfI1m+m7e2/UH7N3UguP+4Bk3EPPJbIoudw9Zt4kCKnINsRXX7UcANHDNNEad5FaMY3AdA66jHScRydis+CIKC21+KsHoVn6PesBXPt+7S9/SfsK17D0XMESuN23AOm4J90PsHjryS1fbHlub3oBaKLnsdROQTPkGmWMrY7gL2kN/aS3gSPuxw13EB653LSu1YQW/kWseWvIch2HD0G4agajrNqGPbyAYf0WBTtThzl/XMz7XmXpWtWxzkZsWy6MnGMTAJTSXfadGXB0DAN3UquBToXVtn6Hm1OBIfb6nq7fEiuTksOb+ioqeimppJt2km2bhOZ2vVk6zZbVDFBxNFjMMGpV+PqN8ESHjlgIbaCjqUkN31EpmY9YOKoHEJg0gW4Bx67389SU63AZPV7yMHSHCW7C7K/iOCxlxNd+Cz1D30TRJn0njXo8VZEhwfP8JPwDj/xsGIqn4YWbrD8xT0H06y1WAvhjx7F0XME3s6gx1AytL31JySnl8LTbrN+U2376Jj9DxxVwwgcc3Fu3/Z378Re2pfQCdflfW5k4TMojdsoOvuHB9mTzBxVQZnfyYMLdufcVWTRSqi/NbVvnlBhN7rx74ruuOCrhyQKGKaJqhsIwKEed8VeBy6bhCAIyIJAVtMJumwU+RwE3DaKM076Fnko8jtRNIPKkJvGSAq/UybosbGhLkLIbSOj6iSzOpppkspC/9OuI7prHc3v/Z3CXndTUFxKPKMjhMqoOuMWal77Ix3znyJ0wrXI3gKKzvweLS/9go45D1J0+m04egzOFeiT1aMtQTGbk8JTb6H5hZ8Qmf80BSfecND1yL5iy64ymzqiAvDnxQQA6Z1LsZf0puisO8jWbiC84Gmii54lvvINHNWj0GNtueTZO3zGfpHOTR+T2roAyRPCPWQqnsFTsZf1s2a3h5+Id/iJGEqazN41pHauILNnFaltC63rKKjE2dOKCRyVQw+p5yIIolXQD5YBEz51WabVeU507LfuzCQsmy4ljal12nQZumXfCSCKCKKVtAk2h+WX7fBYNl0uP5I7aMUEDs9ReZKbpokWbbao7vs2kqlZlxM4k/zFeIaegKvfBFy9RuY1G7pE4JKb55PaugAjHUP0BC0rruEnYSuszL1XadlNfM37pLYtouL6+w+a0++KCyILnqbuvqsxDQ1j/RzAxNFzBMHjrsA9YMohxeg+DT3eCrqG5Du0xk74wwcxMnEKL/pVrjgeW/46mb1rKDjlJuxFPS3m23t3oSfClF3+p9yYW8ec+1Hb6yi56NeI7hBZFQQRGvftZtHTf6F80BiKp3wDwwC7TSDgkinxuXDaJFTTcgLIqgYOr4BdFnNChd3498SX54nQja8Uk/oUYpfF3Az2pD4Hz4aM7RXiV+cM43/f3IhhmthlkdOGlbNibweKZiAKAr/qtPRZVRPGLotkVWO/QMZhmqSuaktQIrb0ZSs57awGC6JE0Vl3WNXpNzqr03YnhafdRuMTt9H65h8ov/pui1J+9g9pevoOiz5+3s8RBIGCk75NZt8G2t/9G2VX/DWfQiUIOHsOJ7N7FaahH5H6tp6KWDQpmxM9kyC1azlyqDyvK5ratpjgcZfjrByMs3JwbiHQIo3El78GsgM9FSW5+WOKz/1pbmHVYi0kN31MctNHOVq4s3oU7oHH4O4/CamzIm4bPxP/+JnWwrpvI5m9a8nUrie68DmimCBK2Ev64KgYgL20H/ayvlYy+hmJriDJFmX8M+advwpYtLU6lJbdKJ2U9mzzTtAtQxBbYU+8w0/E2ctSST9QlAQ6k+qdyzu9L1eBriEHywhMuRjPsBPyhOrUtn3E18+ylMXTMeRQOe7Bxx18TrqG2roXPRPvVGU1cfYaQWjaVbgHTDkq307TNEnvXIGjx+CDvn/TNGh/7+9gmhSedmsu4Oj48IHcAil5ghhqhtY3f49gc+a60aau0fbmnzAN3RpvOKC7kt67ltjSV/GOOBnPIa6vPanQv9SX97d4wbiqvJGQbnSjG91w2SSKvA5Sik6J34HTdvAaKYoCfpcNt10ioxlIgoDLLiFLItG0SpHPQYHHKvjaZZGKoAtNNzFME0yBgMuBoum0xLLsaoujKDqCKVDgcTP8sp+x8p4bqX/zToZ+905UPYMsmNhHTSVRu4H25a/hrxqG1G8C7t5j8E++gNiSl3D2HIZ32ImdBfr1dMy5H0f5QGxFVTh7jcA7+gziK9/C3X8Szp7D867H2Ws40cXPk96zGs+gYz/3OzqamKArmU3tWoHk8ls2nNsWg6mjxVpIrPuA4vN+jr2kNwUzvkVo2rWkd60gsWlejhYuB8txDzoG98BjsZf2RbS7cA+YgnvAFEzTtNw39q4lU7OO5OaPSax9HwDJV4yjYiD28n5WXFDa55DWll0QBAHJ5fvcka2vAnoyjNK8G6V5F9nG7SgN2yw7MkB0enFUDcM3fiau6lHIBZV5ybqVVG8jtW0hqW2L0eOtFl2830Q8w6bj6j0mF+sZSprklgWWsnjjNgTZjnvQ8Yfs3KvtdSTWzya9dy1mOobkLSQw+UI8I076XLGyTyO9cwVAjol2IFLbFpPc/DGBYy7NjUxm67cSmf8U7gFT8I60dAbiK14nvXMZoek34CgfAFie78mNcwlOuQRf9Sg0QNVANhS2PvlLkOz0Ofd7pDSBUp+ITZTxuWw4bCIOm4TbIWMTdQSHRND95fiTd+OrRXeC/R+Csb1CPHv9pM/tYF06sScDy3x57/v0vw/8vFdX1/Hyyn1ouokkCRiGeVAXG6y56WztBto/uAd7WX9soXIAqzp91h20vPhzOmb/g8Izvofk9FJ8zg9peuYHtL/7N4q/8XMcZf0InXAN4bkP5+jjkidI4Sk30fr674guep5gp6poF9yDj8/Nurj7Tfzc78heXI3kL6b+oRssmnNhFf6x5+ReV5p3IUgyjh6WsEqXEmbJN36OHKogs28DyfVzSG5dCLpK09P/00lznwqIOHuPwT/pAtSWPSS3zCe1bSEdH9xDx6x/4KgcgqvvBNz9xiMXVFoLa9/xuPuOByxakaXAvZlswzYSG+dhrn7XOjFRzlGwbIWV2Ap6IAfLkQOliO7AUVWTjxZWJTyOFmlGizSidtSjddSjtNVaqqCdVDLB5sBe2hf/mLNwVA7G0WPIIYXd1EiT1cXfuZzMvg2WN7ivCN/oM/AMPj5P+EXPJEhtXUBiw4coDdtAlPLm3g+c61Lb95FYP4fEpnmWYJ23AP/kC6xKd+dv8WihNGxDba+lYNxNB70WX/UOmZp1FJxyc26BTmycR3LDh/gnX4SrehSmadIx+z7Utn2UXPjLHG0x/MkTZBu2UnT2D/Po6XoiTNvbf8FWWJXrqHwapw2zruXAP8GhFV+unV03utGN/3wIgkDQbSd4BI1cWRLxHqAsbO13cJDutEn0LfHSnsiS1Q36Frup7UgRTanoisG6+iipjIYoGNiKyulz1s3sePUvLHr1EXyTL0UEVMOkYOr1xPdto+Hdv1F+9d3IgRIKjr2MbN1mOmbfh720L/biaorOvIPGJ26l9c0/UHblXxFtTkLTriGzdzVt795JxbX35BVuHZVDkbwFJNbPPqIE+wvFBG21hL7xc2wFPdAS7SQ3zCWxfjZqWy2Nj92EZ9iJlqWkrxApUELJeT+z1rJti0ltXUBs2WvElr5izWD3n4Cr7wScPYcjyPYc480/4VxMQ0dp2UO2bhPZ+q0ojdtzHW6whLdsncKntoIeyKFyKy7wF33lIqimpljaLJEmi5reUY/aXovSWouRiuTeJ4cqLG2TikE4KodgK+510Dy2oWTI1KwjvcuKC/RkGCTZEiWbeiXufhNzbATTNMjUbtgvHKdmrPVy+g14hk3PKygYSprU1oUk1s8hW78ZRMka2xp5Sl6iflTXbZok1s+2xOUO0IuBzhnrWfdiL+uXG2vU03Fa3/ojkq8oV4jP1G0m/PETuAZMxjfubMCau+6Yfb/l6X3MxXme5a0fPkq2ZQ/l5/+ClBxE1DWyukzALTOudwEehw1FM9B1g7SqY5elg0ZBuvHvie4E+z8IBwqaHc37Pms/AXIUchEY0yvE8r3hg98nyRSd/QMan7iVtrf+aHn5dXbmXNWjLLrXoudwVA61rJfKB1Bw4vV0zHmA2JKXCUy5CN/Ys8ns20j4kyewVwzCWTkY94ApeIbNILr0ZZy9R+dEPwDLS9lXTGzZq7j6TvjcRFOQZIpOv41s43aMdBxXn7Ekt8xHDlXgKOuHFm3BXtIb07Aeb6lti6xFrKAHpmlYVCrTJHTCdUQ+fgzR4bbmquc9guSzFjVT1widcC3BqVcRnHoVastuUtuXkNqxlMjHjxH5+DGkQCmu3qNxVo/G2XO4Rb9y+XD3m4C7n0X1Mk3DSmSbd6O07LFUuxu3k9q6kAPTK0G2WwJn3gIkT9ASOMuJnLkQbM7OGWwZQZAsnqBpYpo6pq5hagqmmsVUUhiZJHomgZGO7Rc5S7TnCamAReuyFVbh6jXSUjcv7WN12Q+xYBnZJJnajWRq1pLesxqtox6wKHD+8TNx95+MvWJAbtE11Czp3atIbv6Y1K7loGvYinoSOuFaPEOn5yXtRjZJcssCkhs+JNuwFQQRV78JeEecjKvP2H/aUzy6/FWLVj44395Fad1L+OPHc4u1ta0mRwMPHnspYM1nJzfOI3DMJTk7kNT2xcRXvGEJ5x3QoTYNnbZ3/oKppCm6+DeHnC/sX+zh0ok9+cdHO3PjGqIA4dTROKN2oxvd6MbRwzRNYmmVlKIRy2jIkohgQkXARVnISTiTRRQE3A6RaFojljRwDZxKaMQq9n38AgOrhiFWjSBrginbKT7nRzQ8cRutb/6eskv/BLItF0O0vvF7yq+8E9lXSNGZ/2PRx2ffT+HptyPanRSd8T80PfsDOmbfT9FZd+TOURAlvKNPJ7rgGZSW3bku4uHwz8YEapulnu2ffCHhDx/C2XM4yY3zSKx93xq7sluJYWj69XhHnIRv5Mno6RjpHctI7VxGYv2HxFe/iyA7LEp49Wic1SOxFfVCECUcZf1wlPWDcVbSr6einTHBbtTWvahttZ1iZ5kDrwrJE7RiAm8BoiuA5PIhOr2d5+RCkB0WI1CSD7BYNUDXLOq4msVQMxjZpEUtT8cs8dNEGD3RjpHKnz8W7C5shZW4+o6zZsVL+mAv7XNI/+iuwkGmZh2ZPWvI1G0EXUOwu3D1Hot7wKTOGWxP7nentO7tFI792JqXtrvwDD4ez/CTcPQYdGjl9q0LMNWMZcM67Wq8w0485KjX0SBTsw6leRcFp9x0UOe97b07rRnrM/4HQZItltu7f9tPA3d60ZMR2t78A3KglKLTb7fmrrMpWt/4A4LDRfFZ37do+l3H27qQ+Jp3CY6fScmQ8QgiFLlsFHmdVAadmKaAaZokFQ2bJFJV4LFGNQyT7h72vz+6E+z/T9Fl+9VFEQfQDZP+pT5W1YYP6ZEtB0ooPP12Wl/7DeGPHqXgpG/nXgtMuYhs/RY6PnwQe+ei4R19Bpl6S0XRXj4AV+/RFJ12G41P3k7bm3+g/Oq/I3mClqpo3Sba3v5r54y2VaUUJBn/xPMIf/ggmb1rcknM56GLkgNgK6jMUYrs5f1JbJybo6LHV72NZ/hJne8UcPQYZKldbltI2eV/xl7al8TmTwjPuR/T0NFjrSDZiK9+F9PQcfcdb1XiS/sSPO5ytFiL5XO5exXJzZ+QWPsBXT6X1qzVEBw9BiH5iq35qsIqbIVVef6dpqZYAmeRRrRIs6XmmWhHT3SgtNZgdM5afZZw2WEhiFZi7rbmreylfZH6jkf2lyAHSjor5GWH9bIG0BIdKPVbydRtIlu3CaV5N5gGguzA0XMYvtFn4Oo7Lo/+bahZ0ntWW2JwO5dhKmlEd9BSYx063bL1OMAH+yBvzMIqgtOuwTts+j+9gHYh27ST9PYlBKZckjfLZ6gZ2t78E6LTk5uxthbI3yPYXRblW5TINm6n48MHcPYeQ2CKNXetdtTT9u5d2Mv7Ezrh+rzjRZe8ZHXET731sBZzO1qTnHbXfNQuNeBOn+tDjYN0oxvd6MaXiURWoyWexTRMmqIZ+pV5UTQDr1NmZGUByYzKXn+GWDJNfUcGTQebLFB+8o2kGnaw89U/0+vquxG8IWTAFiqn6IzbaX39d4TnPUzJyTeCt4Dis39I8ws/pf29uyia+WNcnc/Q6OLncfQYbBXoewzaL6LaezTeYSfmztM35kxiy14jsuBZSr7x8yO6tn82JkhunEfZZX/EXtqXdM162t7+M3KoHKVuC2AS/uQp1I56PAOPwVZYiXfESXhHnIShZsnWbiC92/K/Tu+2rOhEdwBn5VAclUNxVA62xFwlG5I7gKv3aFy9R+fO1zRN9EQ7WrgRLdKEFmtFj7ftjwta9mCk452+2kcPweZAdPqRPAFkXyGO8v5IviIrJgiWYQtWIHqCn2E5ppBt2mWx8+o2ka3bbMUogK1Tl8fVZxzOqqH7xco6k+rUtkWkti5Cba8FQcRZPdoSNhswOS8OUSNNJDfOI7npI7RIo+U9Pug4vCNm4Ogx5Eth+ZmmSXTBM0jewrzfG0B8xRtk9qym4OQbsRVZriGxpa+Q3rWC0Ixv4SgfgGnotL71R4xMgrIr/g/R4bFmsd//O1q4gdKLf7vfcx0Qwo20vH83zvKBlE27CpcIRQE3fUr89ChwURbysnpPG7taHHgcMllNpzTgwiFL2Lp9rv8j0J1g/3+KLtG0ruTaEkkRGFoR4KVvTeGmZ1bRFD/4ge3uPwnfuHOIr3wTR9WwHE0rN4/9hLWgll99F5LLT+Ept6C27KXt7T9TftWdyIFSimf+hKZn7qD1rT9RetGvLVXRs79P0zPfp/39v1N87k9zD0zfyFOJr3iD8EeP4ew18qi7lgf6JYpOL6LTR+Pjt2IrrMJRNRTfyJMxdc3qANuchI6/kpbXftOp3A1qy27s5QPxTzwPU0mT3rOG5LaFtL/9Z9olm6UG3m8irs5E1Tf6dHyjT8fUNbKNO8jUriNbu4HEhjnEV78DgOQJYS/vbyXnnR7YcqAEoVMR3F7UM8/O6tMwTdMSOFNSVndaUyz7MkMnp3ImStY1yXYE2XHE9h0HHkOPt6K07LVmsJt3oTTuQI932sBJNhwVAy37q14jcFQMzrfaSEZI715FaudSMntWd6p9+3APPBbP4ONx9hqRu5emaZJt2GZVsLfOx0ha3pie4TPwDjvxM/1EvwhM0yQ87xFElx//hJl5r4U/fAi1vZaSC3+F5AnmxEqsBfI3yN4C9FSU1td/j+QpyM1dG0qG1td/hyDJFJ/z44PmrqMLn8Mz9IScD+bhsKUpnvvvCdUhfnja4O656250oxtfORTNwC6LyKKAEINwQkE1TDx2mT5FHoq8VYgGzNrcgCRYK01GA4fDRd9v/Igtj/4PLe/8maqLfo3dJpHQwD1gCoEJ5xFd/hr2ikF4h03H2XM4oWlXE/7oMWLLXiUw6XwCx1xMtmErHR8+gL20D47yAdaMds06i1pbPiBnhyg5vQQmnU9k/lNkatYf1tLrcPhnY4LMnlXYS/oQmPQNBIeXbO16ElvmE13wNNEFT1suIv0n4u43wWos9B2Hq69lkavFWsjsXUdm3wYy+zaR2r7YOhHJZnWFy/rt98Au7Ilod1oCdb4iawTpUzPpB8LUVKsjrWasmEBTrJigqxgviJ1dbRuizYFgc1pxwVHQzY1sqtMPew9K0y6UZstmtMveUw6V4x4wBUfP4Th7jsjzkTYNnUztBlI7l5HesQwt0ggIOKqGUjDmO5Zw7AEMNj0ZIbVtIcnNn5Ct32K9t+dwAsdc3ClY5jri8z4SpLYuJNuwlYJTb8nTcsk2bCP8yZO4BkzGO+o0ANJ71lj2W4On4htzJgCRj58gW7uBwjO+m2NWxFe8QWrbIoLTrsnpCYhAoV1l01t/QBBFhl72Q7yFHvxuG9P6l5BVNTpSCs2RDPFslo60hssu4bKJlPtd9Ai6uhPs/xAIBypMf10YN26cuXLlyq/9uN3Yj64OtqLt74SaJjhsIs9eP4nXVtfx7LLaQ+5r6irNz/0Ypa2G8qvuypszzTZup+nZH+CsGk7JBf+HIEqoHfU0PvU9bMEySi/7I6LNSWLDXNrfuxP/+HMJTbdUlmMr3iA87xFCJ1yHf8K5uc9MbVtM6xu/IzT9BvzjzznofI4WaqQJPd5m+VdmEqR2rcDVaxSSN4SpKTS/9L+d/oaF1D/0zVzCrDTvovD025H9xWTrt5DavoT0jqU5oTRbSW9cvcdaVPceQ/KSLFPXrHmrhq05/2u1vY4uOrhgc2ArqEQu6IEtWG51kwMlSL5iiwZ2COXxfxamplpV8HgbWrTF6pqHG1HD9ajtdZhKOvdeOVSe8710VAzEXtrvoOvLNm4js2cN6T2rUBp3AiaSt9AKNvpPtubQOjsFZqe6a2rbQlJbFljfoSTj7jsBz9BpB3ljfplIbPqI9nf+SsEpN+HrXDDBmrFuf/dv+CddQGjqVQBEl75C5JMnCJ1wLf4J52HqGs0v/Zxs/VbKLv8zjrJ+mKZJ29t/IbVlPiUX/jKPaaHF22l84jYklz8nAnikEIBXvjOlO8H+N4IgCKtM0xz3rz6P/zZ0xwT/emQ1ncZIhpSqYZoGpg42m4huQtBpoyzg5M21DbyztpZVe8KkVbDZwGUDmwh1yz9k39t3UTD5IqpnXEFWB7sE4YxO04s/I9uwjbLL/4K9tI/1zHzrT6S2LaLkgv/D1XsMejpG4xO3gYlVoHcH0OJtND5+K5InlJvRBosV1fDojYiyg/Jr/v5PzyT/szGBLViGFmvNjYtl920E00B0+XH2Ho2regzO6lF5CSdYlpXZhq0oDdvINu1Aad6Vt+5KgVJrdrygwpq/DnZ5YBchOn1fuj6LaRoY6fj+mCDapc3SgNpeZylsd0J0eLCX9cVePgBH+UCLnfcphpkWbSG9dw2ZPastAbJsEiQZZ8+Rlkd4v0m5ri5Ytmqp7UtJbV1ApmYdmAa2ol54hk7DM2Qqsj/fdePLgpFNWb8nl5/yq+7MFf/1dNz6TYLFrnR60aLNND5xO5K3gLIr/opod5Lc/Altb/8Z35gzKDjpOwBkatfT/MLPcPefRNHMHyMIAjagwAN1795Hw7L3GHblr6gYMYWQ10aR28WEPkXUdyTZ2Z5A16A06GBQmZemqILHKTOiIsiU/sXY5e4E+98Jh4sL/iUdbMP4AvTWbnypGNsrxNWTq3lowe48UbOsavDa6rrDCYoDIEg2is75oaUU/vrvcg8ZsKhYBTO+Tcese4kseJrQ1KuxFfSg6Kw7aH3lV7R/cA9FZ96Bd/iJKE3bia14HXtZXzxDpuEbdw6Zuk2dM9oDcFYOBcA1YDLOPmOJLHga94DJB1kbHQ5dgiV6KoogioiuAHKw9ADrC0CU0WOtNL/wE4se7fIj2l04yvqRrlmHHCwnNO0aADo+fBClYSu2YBnOqmE4q4ZhTr8eta2G9K4VpHevIrbidWLLXum05xqMo2oYzqqhnYtQvg2XoWRQW/eitNWgdoqKKQ3brDnsT1HARYcH0RNEcvmtqrvDg2B3WQFH1wz2Ad1909CteStNwdCyVse7c95KT8UwUlGMTJx8CEj+ImyhHniHz7Ao7J3enJ+2RDE1lUzd5k67ro1k6zbl7Lrs5f0JHHsprk4KfR79u24L6R1LSG1fjBZpAgQcvUYQmHKxRQs7xExX/nEVtGgzejKCaehILh+2op5HHGDpqSjheY9gLx+Ad8TJue1K615LrK5qGMHjLgcgvXsVkU+exD3oOHzjrYJP+KNHO6vU37Nm5+isUm/5hODxV+Yl15aa+B8x1QzFl/z+qJJrsEovr66u606w/02QSqX+1afwXwtd1zEMA1HsDhz/VXDIEl6HxPamKElFQ9EMehZ6MEyTSELB75JpjKYIJ1VsNgFRNKnw24lmdTRDJzRqBol9m+hY8iJFvQfh7D0eXbPYbT3O/gF7n/wuLa//lvKr7rTYbafdhtq+j7a3/kTZlXdiC5VTfO5PaX72B7S+8XtKL/oNsq/IElF96Rd0zLJEVAVBQLQ5KDjp27S+8kuiS18heMwlR3ydWqIDLdyAqWuIdlfniFTpPx0TyP5i/OPOxj/ubPRMgszuVZ3U8DWkNn8CfNqea0inO8ixOSagaRqWsFhLZ1zQvg+to55E/ea8xBsASUZyh5DcB+qyuC1dFpvdsuMS5f3+baaJaWiYmoqpZTGUTKc2SwI9HbfsvlJRMLS8wwh2N7aCHjirhlpCrEW9sJdUI/lLPjWnbKJGmiya+D7LxtNa40HyFlpq6n3HW77hB8QTWryd9M5lpLYtJlO7HkwDyVeIb/y5eIedcNiRqgOPqyfa0ePtGEoaQZKxFVYdZOX1WYjMfwo93k7xOT86gFln0P7OX9ETHZRd9kckpxdDydDy2m8xTYPic3+CaHeiNO+m/f27cVQOITTdGg3TYq20vflHbKGK3Cy2C/C5IbzhYxqWvceI0y7nxFNPw2UTKfI5UDTYUNdBeyJLZYGXMr+LlKKQzOh4XXaqC1wYpommG90J9r8JNE0jHv90HL0f/5IEe926dWQyGZzObg+3fxVW1YR5ZOGegxTDTeDFFbX86pzDU5EAZH/x/oVv9v6FD8A36lSUpp3Elr6CvbQfnkHH4u47nuBxlk9hrKQ3gYnnE5p+PUrLXtrfvxu5oBJHWT+KTr/dmtF+4w+UXXUXsq8QQRAoPPkmGh67ifb3/07JRb/+XKpzatti2ufch9FJ68o790CpRc8uH4CjcjD+8ecSmHwhqe2LEezu3LyWFm7E1cdKmAw1i+QtRGmt4UAzKkEQLNGP4moCky7AyKYs+lcnDSzfnqs39vKBODrp4bbCKssju8eg/Hugq5aCZ7QFPd6GHm/f74OdiaHF2zHbajG6/C5VhUN7rAkWRdzmsJJxhxvR6cVe3AvRHUTyBJF9hUi+Ymveyl9yyK6xaegobbU5mni2cRtK864D7Lqq8Aw7EVevkTh6jUA6IEnWM0mytetI7VxBaudyzHTUoqrZnIieEPainjh6Dv9M6nRy2yJSWxeiNO/KWXMdCNHpJXTCdZ9LvzZNk445D2BkkhRefOt+O5BMgtbXf4vocFN89g9zrIu2t/6ErbhXbhY7sX4O8VVv4xt3Dt5h0wGL/h3++PFO7/QL8o4X/vhxsvWbKTrr+7m5rc+DQxbJHsAq6TYS/veAYRgMGjTo89/YjS+EtWvX8swzz3DllVf+q0/l/1sYhklLPEtKMXHIMolMli31MaoKPBR6bcSzOl6bhCkIFPgcNHRkiGY0Cn1OdNOkNZZl0MybWde0i12v/5XeV92FLVSGQwDBF6J85o+pe+6HtL35J0ou/CWi3UnxeT+j6cnbaX3t15Rd/hccZf0oOPUW2t/5K+F5D1Nw0nesGe1jLyW68FkcFQNzlFx33/G4Bx9PdPGLuPtP/FzBMz0Zpu2dv5HZu+ag1wSHx1rHy/rh6DEYz/AZXzgmAIvG3mXxaZoGaste0nvXkqn9tD1XUac9V3/sZf2tkbFQBbZQBe6BU3KfZ5omRipqdZRjrZ0xQYcVE6Si6Ok4WqzVSpiVTCc9XOOQEGVr7trm6EzKPci+QsSSPkjeIJKnANlXhOQvRg6WHrZTrqeiKE07c3Zd2YatuXgrZ9c19iycvUZZRfCcUJlBpnEHmd0rSe9ajtK4w9rHE7JcU2Q7tsJKRIfrsMm1Fmsltvw1yz60rebg4gPWeELh6bcdZB/6aWRq1xNf/Q6+sWflxWLRRc+T3r2SgpNvxFExMDdPrbbsoeT8X2Ar6IGejNDy2q8RnV6KZ/4YQbIhqllaX/8dpqYw4JKfojrcyCIUemWMSA273riHkv6jGHXuDfQs8jCpTxFp1aC2PY7fbUMzwSaBJMKg8gDVRW6a4goSAj6X7ZB2fN341+Cvf/0rP/rRjw77+r+EIi4IgvnQQw9xww03fO3H/m/Hqprw51p5Afzjo538Zda2w3aqL53Yk5dX7kPVP/v3EVn0PNGFzxKa8S38Y8/KbTc1lebnf4zSuseihZX03k8L27qQ4vP/F3ff8ejJCI1PfRcMg7Kr7kT2FqC07qXp6TuwF1dTesnvc0lffO0HdMy693Op4lq0mfqHvom9tA++sWcj+4sxDR0jGUENN1gzRE07O2eAsOaJewzC2XMEzl4jcZQPQJBk9EyCtjd+j71iEIJsQ23fZylY9xqJaZq5BSO5dSFqWy3O6tE4yvvn+XnrmQTZus2WRVfDNpSmHfsXA0nGVtgTe3GXFUclckEFcqDs6Ludhm6poOdmsMWjnlc31KwlohJpRO2oyymoqm01OVEYQXZgL+tr2XL0GIyjckhepdg0dJTmXWT2riW9dw3Zus2d81mCJYI28lSrS20auPqOR4u20Pb2nyk46dvYS/se8rzCnzxJcvPHOMr6YyvuhRyqQPKErHuUCBNf+x7Z2g2UX/33w34G7KeGB4+/MmezYRo6La/+iszedZRe8juclUMwMgkan74DIx3L6QZk6jbT/MJPcFYOo+TCX1pJeKSJpie/i+QNUXb5X/Kq8skt82l760/4xp5FwYxvHfE9+PbxfXhs8d6c3/3zN0zq7mD/G+C9997jjDPOAOimiH8FEEXRHDJkCBs3bvxXn8p/HTTdIKsZSKLwmcG5YZjsbkuwfl+ErG7glARSik5lyI2OwOByHx3xLE8v20ssmWFvewq3LNKWzOK2y8iShNsusGvnLtbcdxtyoJTxN9+J0+miPZFF16FtzRwa3v07wXHnEDjRiv/Se9fS8tL/4uo7juLzfoYgiNZ89vLXcmM8pmnQ+uqvSe9ZTenFv825jejpGI2P3oTo8lF25Z2fOUrV8tpvyOxZQ2DyhZauh2zDyCbRoi2o7ftQO9W7u9Y6OVTRGRNYcYHkDhxxTGAoacJzH8ZZPSq3bxf223NtztHDu8bMwPLEthf3ytl2yqGK/WveUVDCTdO0kuyuGF8QLW2Wo/oMAz3egRZp6KSJ748J9ERH7n1yqAJHxUArJugx+CC7Li3R0akqvpr03jX7Gx+SjGfINPzjz0XPxEFJH1FcoMXbaHj4250is32wFVQiBUoQ7S5MNUu2bjPRZa/gHT6DwlNvOez1GZkEDY/fgiDJlF99Ty7uSu1YSutrv8Ez7EQKOzvQkcUvEF3wDMFpVxOYeL41KvnCT1GadlJ62Z9y42Lt7/6N5KaP6HfJzwkNmIhigE0GPZ1g60PfxdAUZvzwQZyhEoaUeelZ4MPntKEaJmUBFxlVpdjrxO+SqSzw4LHLRFIKim5Q4LZj706w/y2g6zpVVVXouk5LS8sh44J/SYItiqLZr18/tm/f/rUf+78ZB85V22VrlvpwwfmqmjAXPbgY7TBs/csm9mRfR4r5O9o+85h5C98lv8dZOST3mpbooOnJ20GULVqYO4ChZmh+9oeo4QZLqbu4GqVlN03PfB9bUS9KL/k9os1BcutC2t78A57hJ+X8BU3TpPW135Des4ryK/562GSqa5a24rr7P7NzqKeiZOs2k9m3key+jZYiNiaC3WVRwHuNQg6Vk23cjplJ4Bs/cz+N7AB0fPgg8VXv5PZ1VA7BWTXcUggt+9SssqGjhRusbnDzbpRWix5+4GwTWCqjsr8EyVdo2XR5gkjugEUFc3qtbrTdZXWnJYsijijRabwGho5p6LkOt9EpimZZcnTSwTqturR4G3qs1fKnPACStwBbYU9sJdVW972z655HRdc1lJbdpHcuJ7VzGVpHfS5AsZX0xlk9GrW1BkG24eozDt+oUzHUbF4g1PjUdyk8/buHFXYzDf0ziwV6Kkrd/2PvvOOcqPP//5yZ9LrZZLO79I5U6V0RO0VFRECsp97pNXsvoCCKWE/vvp6901QQsaAigqCINKnSO+xuNpvek5n5/TEhsAKCnpx4v309Hpy3yczkk8luPu/yer9ez156yOz+wcgG91Hx2o0YvE21ok3+esH5rxFZ8i7FZ/9Fm6tTZHzvPEBq1xpKRz2EqWF7cmEfFW/cjGi0UHb5k0hmO0o6QeVbtyPHaii74slaqulagehWDN7mlF4y4aj09SKzDo/dxNV9mzK6Z6NjLpLV4b+Hvn378s0330Bdgn1cUFpaqvp8PhYvXkyvXr1+6+X8z0BRVPYEk+RkBUVVKS8yYTUe/vsoJytsrAyzdHuQbE7BbBQpthmpX2RGFAQaFVsQgPkbq1mxo4aaZIaGRWY+WVuFUVJRESlz6ImlVCrWfc3i5++hrMuZdLviHjKygkkvEU9n+X7as9Qsm03poJsxddDUmiPLZxOc+zyOnsNxnXbVQYXP7/GOGIe58cn5wuetKKmYZvGVHxVLbluO752x2DoPxn32n494L3Y9dTHWdqf/5DGqnCVTtY3U7nWk80JkakYbDdF7m2Ju3AmduwG5UCVqJnnEmCBdsZmqafdpM8eAvqQJpkYdtJGx+m1rzR3DfnuurZqgqG872eqdZAN7Qc4WjhF0Rq2rbPfk4wIXkqUI0eJEMtkQDqaIF8bGjkQRz6BmU/mYII6SiqIkIsiJUMG+Mxfxk4tUH7IGvbsB+pLGGEqaoi9thrG0+SGjXbmIn/SedSS2LSe1YyVKPrYQzQ6t6NCks8YkkHOYmnY5LnFB9fsTyVRtof51Lx3+fFXFP+tREpsXU3bpJIz1WgOaJWflW7ehL25A6eiJWjy68Wv87z+Ctd0A3INvAaDmk2eIr/kcz3m3Y23bHxMQ+G4G1V++Qnn/Syk97RKQISmD2wzr33qQ8NaVdLxmEq5m7anvNFNkN9LIZaF5qR23xUggniaUytDcY6Oxx0ajYiuiWMdjOxExc+ZMhg0bxoQJE7j33ntPnBlsj8fD5s2b2bt3L/Xr1z/6CXU4JuxXBlfUA7PURwrQN1ZGj5hcSyL4omm+3R44/AEHQRBEPENupeKNm/G//0iB1g2gsxVTMuw+qibfRfXMhynN+/+WDLufyjdvwffeeC1R9jbDM+Q2qmc+rFl3nH871pP6ka0eRfibqRhKGuPoPlSjig+8gYpXb6D6/YmUX/X0Yek/+xUglaPYVkgWJ5ZWvbG06g1oghapXatJ7fhe83XeulQ7zlaMqfHJpHevA0FE/6MZ8OIzr8PZ9xJSO1dr5+9cTWjb8vyL6DGWtcjPYLfCUNZCEzJzN8Ta9rTCNZR0gmxwH7ngPq2LnKeC5YL7SO9eW7C9+DUhGK3obMVIdg/6Zo3ROb3oXeXoisrzFK3a91ZVVXLhKo0WVrGJdMUmMhWbC/Yggt6kXcvdAPe5f0eyupCTUQKf/xu9q55mMZZJ1lL/zAb2IhptP9mxP1on/sDrH757oeYy+Gc9WlC633+92Np5RJa8i63TQC25VlUCc18gtWMlxefeoAnepBP43huHKufwXjQGyWzXPDE/fIJszW68I8bVSq7lVIzqGRMQDRY8Q+86YnLdttzO+gptdieUzBFKxthVE+dfX26hVzM3fx3Q4iffcx3+ewgGg3zzzTcMHz6cd99997dezv8kvF4vPp+Pt99+uy7B/hWRVRSS2RyJtEw8k0MQoLFb0nx0JRFBQBtJUVV2+uNsqophNUmEEgfmPGuiGdw2I/5YhkRGpmvTYjw2A4u2+KmJZahfZMZj0+OPZ7CZdKTkNI069yM88A+s/+RVxO7d6HX2SHYGoiTSOVoOvpa0fxe+T5+l3FUPfYM22LsMIevfRWTJu5rNVYczKTn/DirfvB1/XudF726Ad9h9VLx5G773xlF22WOIBjPmZl1xdL+QyNKZmBq2x9rmlMPeC0FnPKqVlZB3xzDWaw09h2nd5sotpHauIrnjeyIrZmujUZIOY/02xNfPx9ToZIz1WtXap4zlLWl4w2QyFZu1zu2u1cRWfUZ0+WwAdEVlGOq1zscEGjXc3LRLbQ0PRdaExoL7tNggXIUc9pGL+snu2I2cCBXUu381SDokqwudzY2htDmWlr3QFZVpMUFxfSSH55DxPCWdILVrDZnKzaQrNpPeu+FAw0AQES1O9KUtMJQ0xj3oRgRB1GKtHd+j9zQ+rnHBwWrgP0Zs5UckNi6iqP9VheRaToSpfm+cFqdeeC+i3ki6YhM1Hz6Jsd5JuM/9O4IgEF7yHvE1n+PsMwpr2/5YBFD2LKP6y1ext+lHUd9RJFLgsoAqwrbP3ya0aSmtL/w7rmbtEASFcDKNXq9jVyCJy2KgSyMXqWyWZE7H3lAKQRAxGySSGQWTXsRlNmA01HWvTxRMmzYNgJtuuol77733sMf8Jgm22+2murqaCRMm8H//93+/xRL+J9GrmRudJBbst95Ztpt29ZwEE5mCl+6322qIJrNMXbb7yBdSYe76qloWXj/FcxBNNkouvJfKN7UkuWz0xELX1ljeCvfAG/HPfoyaT/8P98Ab0Dk8+cT7bnwzxlM66mEsrXpTdNpVhOa/SqioDFf/K3H2G03Wv4vgvJfRuephadEDyeLEc8GdVE25G/9HT1Fy4T2HfOHvn6NJbl1aEKI6FkhmO9bWfbG27gscpIC543uS21cQX/cloClqmxp1LFSkdXYPktmB9aQDYiVyPJT3hfyB9L4NxL7/hOiyWdr9NJi16m+BBqbRww2lzY+4XlXOauJkqajWhU7HUTMplGwaZM2mSz1osxXyNl0HLDnMiEZz3pbEgWRxHHHzUVUVOR4kU7VNE16r3k5q11qNUp+fu0bUYShthu3ks0GQUJIRPEO0yu7eF68jF61BsrqIrvwIS4seIOnI7NuImkmBwYyayyDoDES+m4GxYTt0Dm8tit3PQXKrpj5srN/msO+l5rPnyFRtpeSiMQUV0tSeddTMeQZjo44FCnd02SxiKz/C0WOYZtWiyFTPepSsX7Pt2m8TE1rwBsktS3CdeR3mJp0OvJYi4//gMXKRakoveQSdrfiIaz4ccej5hdtA1RoOXruRoZ3qc9egQ99THf67eOqppwC49NJL6xLs4wSj0Uj79u158cUXefTRR7FYLEc/qQ5HhV4USWdloqksRr1IOJlhe7WKKIkYdSKSIJDI5NhWHWNnTZxkVqHYoscfTWExWrEbdaRzCiaDhCQK1ETTOM16yl0WLuhUn0AsxardYaqiaURRomNDJzuro0RTOZr/4W/k/Nv48KXHadOuA3pbMzxWgdYlFsTL7mPFczfie38CDS5/Epxeis+8jlywgpo5/0Tn9GJq1JGS4WOofPNWfO8+SNnlj6N3N6Tk/Dvwvfsg/tmPaZaeokRR/ytI7/2BmjnPYChpcljmmrFBG80y8iidz4MhiFIh4Xb2HoGSTWkCXvmEO7zwLcK8pbHeGrTD2KgjpkbtNXFPUSporDj7jNS645VbSe1ZT2bfBtK71hTEz0BAV1xfW/v+uKC4fn4euxwzXQ9Zm6qqeUZaRGOlpeOaLks2rTHX5KzGZMtvNoIgaPPXOn1em8WEaDBrc9hmO5LZgWC0HnEPVrKaKGu2Zg9Z/y7S1TvI7N2AkggVjpEc3vx7Hooq58iFKnGf81dAiwsyVdswlrU47nGBkkmS2r0Wa5v+h30+vXcDgS9ewtysG46ew7T7mctSPXMCcjyo7d8OD7mwj+r3xiNai7TxBZ2BxKZvCM1/DUvrfjj7jQYg5NtJ1fRHsZY3o/GQm5ARUIB4ClLbF1O5cApt+w+mfq/zyQGiZoxLu3p2rCYdgqjt+9XxDKmsjCyrVEkCBr1IVlbYVBnFpJdoX99Bh/oHPMl/bRX5Ohwb/H4/06ZNY8CAAT+5V/1mNl2KorBmzRoCgQB2u/2/vob/Vdwzcw1TluxCBUQBREFAUVV0ogCCQPYg7+vDQRS04P/g5NqoF7mqdxNmr6lgb/BQMYn92G+nZe1wZkEYaj9CX71JePE0ik67Gmf+C02j3UzE0rovngvuAAQCn/6L2Ko5FJ/ztzxtKEXV5LvJ1uymdPTEQgIaWTaL4Bcv4uw7mqL8l9zBqJp2P1n/Tupf99JPVjGPFaqqkK3eqVWjd66qRR3TFZVpdPD8TLLe3eCQpF+Vc9qmVLmFrG9rgQam5ClkoFXYdUVlmuDYfosuuxvJ6kKyFiGaHUgm+39kX6XK2YKSuJwIocSD5KJ5q65IdaF7Xks0JE8307sboiQi2Dqdg73L+Yh67b4G57+KZPdg73QugqTH/9HT6JxeivqNJvTVm9g6DwJFJrx4GobS5pibd0Pn8JLatZrI0vfxXjTml78fRabilb+BpKf8qn8csuFEV3xE4PPncPQeievUywGtOl751u3a3N5ljyOZ7fnf3UewtOqNZ+hdgEDg8+eIrfy48LsIEFszl5qPn8bW6VyKz/5rrdcLzn+VyJL3ah1/JNiMErH00bsP15/ajLPalfHeij0IwLAuDepo4/9FZDIZysrKsNvtbN++HUmS6ijixwHdunVT//KXv3DNNdfw+OOPc+utt/7WS/qfQXU0jS+SwmwQCSey2M167CY9kWSWrKxgNUgs3RHArJeoDKdIywpem4EGxRYkQcSgE7Gb9ViNOvYGk1j0EjazjjKHCZ0kEk1m2RmIsXJXgERaxmHWE07mkCSBquoAz940ilQszJ+fnIxiKSGjqlSHE+Rq9vDeuD+is5fgvXQSotGCkorlR28ChYQ6vXcDVVPvweBthnfUBES9sfC9frDGRS7qp+K1mxCNFsqvePIQ2vL+udr9lN5fA3IiTGrXmnyXeg25wB5AK6Ib652EsUFbLS4ob3WIE4e25pqCUFjGt41s9Q5yoSoOjsAku0ez53Lut+hyI1mLtZjA4tQSY4P5FydbqqqgZpIFdxFNUDVYGB3TrLoqa81cI4iIVhcoMobS5sgRH44ew7B2OLOwjt8yLogs+4DgFy9QeuljmBrULlLnojVUvnEzgqSn7Mqn86w0Ff+Hj5NYvwDP+XdgbXNq/nfxDnKxGm2c0dOI9L6NVE25B31JE0oveRhRb9R0hN68FXIZWv7hKbJWDxLaJ2iM7mTDy7fhLGvC6AdeYHs4SzyVo7TIjNuqp0NDNx6rEYNOoEmxhfmbq4kmc3hsek4qc5BTBQKJNGv2hGlXz4GqQt/mbrKqQDorYzfpKLLojzj2UYfjg7vuuotHH32Ujz/+mIEDBx7Rpus3S7BvuukmLr/8ch555JGfVGGrw8/D/jnsbE5BEARkRd0vewUc2ok260XSWS3plkSBa/s15bXFmsiSJApc3K1hraB+8pJdvPL1drb4Dk9ZDi18i/A3U3Gd8Scc3c4vPK6qCv5Zmu9lyYX3FGjZ4SUzCM1/BUePYbgGXK3NXr07jtSOlZRcdD+W5t21We43bwU5R9nlj6NzlmqdyY+fIr52Hp4L7ip0jvcjtWs1VVPuoaj/VTh7Df8V7mxtFIRKdq/VZrj3rEdJRgCNem0sb4WhvCXGspYYylog2T2HbICavUSgYMWR/RE9XD0o+T4Ygi6vCm4waf9fZ9Aq8qKklUHVg2ewM5pVVzaJmq9uH/aaBrNmGeIs1ZJ8Vz2tgu5uQM1nz2Ft1Qdbx7OIr59PumIz1pP6FTrGkaXvk63ZXRATia39guS25RSdchn+2Y+jc3jJ1uwiF/ah9zTCPfBG9J6G1Hz8NJKlCGPD9qR2rMTSqjemRh1/1ucQW/0ZNZ88c9jfgeTOVfim3Y+5WVet+ixKyPEQlW/djpKOa7RDVzmpPevxTbsPvbcppaO0TTO85F1C818rzARC/ndq2hhMDdvhvfjBWmJ28fXz8c9+HFungYWK/c9Bmd1IZfTQz6bMYSQQz5DJiw0a6oTP/qt45ZVXuOaaa3jiiSe45ZZb6nywjxO6deumLl68mLKyMqxWKzt27Kiz7PqVICsqNfG0ps0iCURTMpIo5G3RBHKywqrdYUxGEZ0oUs9ppl6RiXAyhygIlNgMBJNZMjkFu0mH22o45LPJygq7axKkszmqoiniKZmcCkaDwMpVaxn3p+G4SutzwxNvklD0mI06/NEEX345j5Uv3oOpaWe8F41BECXNZ/iNWxH0RsovexzJ5ioUQM0te2pqzaJE4IsXiS6bhev0a3F0HwpAavdaqqbeh6lRB7wXP1BbL0RVqHjl76hyhnpX/99/VKg+EnKxQD4mWEd6zzqy1TvRaEkiek+jfFygjYoZShofdoRIyaTIBvaQC+whG9hHLlRBLlRFLuLTklz1MLN9gpifvzYXFMIR89adP57Bztt3qtmUZt+ZSR7xmpLdrcUEzlJ0rnKto+5ugM5VH/+siVhOwLhAScfZ+8J16N0NKBs9sfZz2TRVU+4m699F2eWPF1TKgwteJ/LtOwUBVDWXpWr6/aT3bqB0xDhMjTtqoqZv3oaoN1J2+RNI1iKUbBrf1HvJ+LZTOvoRjOWtsAiaN3wkHGHv27ehZlMMufdlmjZqQCAuIxnAYzXR2GOmtddBOJHFYzdSEU6RySnoJREFlT7NPfhjGTZWhNhaE6dDPRfRVIZ29YpwWPRUhJJIooDHZqK+y1ynLv5fQjqdxuv1UlJSwubNmxEE4cTywQYYNWoUN910E1OnTq1LsH9FdG3s4u1re/HtthpcFgPjPlxXSJYP18G+f0g7WpfZa4kqndWu7IgiS6N7NiKYyPDEZxsPsfgSAGe/0WSqdxCc9xJ6d4PCTJEgiLgH30wu4sM/+/HCl5Gjx4XkwlVEvpuB5CjB0fU8SvZTwGdNpHTUwxjrtcZ78YNUvXU7VdPHUnbZJM1H85y/kQvso+ajp9A5SgpzNACmRh0xt+hJ+JupWNucgs5Z+qveZ0GUMJa1wFjWAkf3odp8cmAv6X0b8v82Efn23cLGJZodGLxN0Jc01Xwk9yuE2t3azPpBVOP9UDLJAxZdiXBtKlg6oSXN+QQaWUZV8x1RQdA8MCVJ+6/OiGgwIRosCCYrksmGaHYiWZxad9zuPmx1HbQigKGksbYJA8b6bcnW7CFTtbWwkRrrtyGx6RvUXBZBp8fYoB3hRZPRu+phatAOXXF9nH1Gktj8bcEmLBusIL5uPobylsjJiCag5mn8sz4DOREmOP81jPXbYMnT+vcjW7Mb/8yH0Rc3wHPe7QiihJJJ4nvvQeRYgNJRE9C7ysn4d1H93jgkuwfvRWM0QZP1CzQK2EmnUNT/isL1qmc+rHm1Dr27VnKd3rcR/8f/wNiw/c9SDD8YpU4TgcSBRHo/GhVbqIocSLyzOYVvt9XUJdj/JUydOhWdTsf111//Wy/lfx56vZ4///nPTJgwgaVLl9KzZ8/fekn/E5BEAa/9wCyrzZQjk1UwGyUkQSCWznFKKz2RZA6jXtREzUSRIouKAIiigMWoQ1G1ax0OQv5/siroJImmXjOyouK1G1GUdtw4/hkm3XYN7z19H6PuepLyIguBeJq2nfsQH/p3Ns38B6n5L1DvrOtJOkvxDh9D1ZS78b33IKWXPIKldR9cZ/yR4BcvEJj7PMVn/RnXgKuRI9UE572EZHVpQlMN21N89l8IzHmmcNwBOq2I6/Rr8E0fQ/jb6RT1u/RXv9c6WzG6NqdibXMqoKlUp/dtJL13A+mKjSQ2fUNs9WfawaIOvadhnhbeRIsJPI2QHCWF2OLHUBUZORY8xKJLScdQ04kfUcRzoNamiAuCUaOIS3mKuNGMaNDsO0WzNjomWl0F1tyRqPSqqqI/QeOC0FdvoCTCuIaP/dGaFWo+eopMxWZKLrynkFxHV3xI5Nt3sJ18Lo5eF2ujXh8+QXr3Wjzn3YapcUfkRBjfO2NByeG9eCKStUi73sdPk963Ac/QuwtWbiYDGCSF6g8fJR2upu/fnwZLMRXRNG6bkeYeKw6zka6NXZQ4zWz1RQGBtKySyuXw2q2YjTrKnGZEUUAnuTDq9aiKQscGTuwmA6mcjKyoWI06BCD342C8DscNc+bMIRKJMG7cuKOyRn6zBFun0zFkyBBef/11VqxYQZcuXY5+Uh2OCV0buwoB+MHJMxyYwV5XEWFg+3JG92xUOOdw5x8OB896g7a5SqKAy6KnOpbBM+RWKt+6nepZj1KeFycBEPVGvBfdT+Wbt+F7d5xGASsqo/jMPyHHagjOfUHbKE/qh/fiB6h863Zt9urSRzF4GlFy0f1UTbsf3zsPaoJpBrPmo/nmrfjeG1+43n4Un3Ud+17+K/6PnqJ01ISfbVv1cyAIgqaumRdoAa1amqncTHzdlyQ2f0umeiepPRtAzhTOEy1F6IvrITnLNIExh/cADczuLlzzeELJpskGK1CSkVpFCgBUBcniKnTnxXxSnjtI9dxQ1gJBbya5YwWWFlpQrHPVQ0kncJ1+TeG49L4NmqJ5OoHO7qbB399CMjt+0ZpVVSXw6b9Q0gmKz6lN1ZZjQareeQAkPd6LxyIaLahyluqZD5Op3ErJsHsx1j+JXKQa3/QxCJIe74hxSBYnyR3f4//oKYwN2+MZfIsmyBIP4nvnARB1lAwfW4t6mIv4qZ45AclWfEji/WNowSrIh2kWjOzeqPC3urkqyve7Q5zbroyz2pVxyQuLC4m3XicW/pbrcHyxe/duPv/8cwYNGlQ3E/xfwujRo5kwYQLPPfdcXYJ9nGAx6LAcNDVVlP/B+6Ov4oOTaUEQkH4ilhQFAb0okEhpitM6UcBlMSAjs8ufwNmqO2defhOfv/4kjvJnOeeKvyMJ2nda077nkfDvZc/Cdymr34i2p48g4mqLdMFd7H1vPNXvP4L3ojE4up2PHAsQWfIukqWIon6j8Zx3G1XT7sf/0VOIZgfmpp2xn3w2ueBeIkveQ+f04ux5gMFmbtoFS9v+hBdPx9yi58/SaPklEE02zM26Ym6mzVCrqko2sJfYqk9J/LAAJRmrpe8CmkirrqgcyVWOoagcKT8ypts/MmYvRufwHNd1q6qi+W5Ha9Dlx9RqH3BixgXJHd8TXfER9q7nYyxvWeu50PzXNFGz064uMCjjPywk8PnzmFv0pDivLh/84gUSGxfhGnA11ranoWRS+N4dRy7s0wrz+XgstOB1EhsWUnTaHwq6PQZAUFR2f/o84W2rOOWae2nQsiNGvUhaUdkbSNLUY8Vh0ZFWFKqjacw6kVhapn6RiT0BBYdZT5t6TnSSSL0iC+VOM23KnciKil4SyMgKVeEUFqMOs16HThIx6erYPv8tvPDCCwAMH350ZuxvRhFftmwZmzZtonXr1lxyySVMnjz5v76OOvwyLN8Z5JIXFpOVVSRRSw7a1XNy38w17M8dNJrXLXlroydqfWFm/bupfPt2REtRoRutZFP4pt1PunIzpRfnKTnBCirfvh1B1FF22SR0Di+JTYupfv8RTI1P1ihleS/Kyrfu0GZqL52EZC0qvNb+mdkjzWofDUo2RdafV+yUsyDqNIssswPJ6kI02X6yipXY8h3RZbNw9r2E5NZlKNkUjm7nk6neRS64h1xgnzaX7d+leVb+CILBjGiya5uuowTJ4kQoWHTtt+PQa/6WBxcQVEWjgsn7LTnSWnX7IKsuORlBToQPUNFFHY1um3HI/Hhi4zcktn6nKWiKEtHvP0FOhCnqM+rAMZu+If7DQvTuBqR2rcHW8Sxs7c9AUWTUvBCLnIygpBOF+yhZnOjdDY/YPf8p7J/BKzrtDzh7XnTg80onNApYYA+ll2gsCVWR8c9+nMSGhbgH3oCt49nIiTCVb9+pzfpdOhGDtxnpyi1UTbkbncNL2aWPIppsKJkUVVPvJlu9i9JLHq5VgFAyKaom17acOxJEKOgiKAdpHDRxW/jTqc0Lha7DYfnOYN0M9m+AG2+8kWeeeYYlS5bQo0cPgDqK+HHC/pgAoF+/fnz99ddEIpE6fZbfCfzRFJH83HVOVqlXZMIgiazbF2LeD1Us2x4gq6psnP44q76cxajbJlK/y+n4YxkC8QyxTJYNkydQseor2o2+B1ubvvhjKsFVnxGY80zeHulmQKDm438QXzsX15nX4eh6njYrO/kucqFKSkc+hLH+SdpI2gePad/5Q27F1m5AYa1yMkrFq39HkPRHdCL5KaiqqtG2wz5Nq0QQtG6wyVZghf1UofXHMYGaS+PsqzH/coG9ZAN7yFRsIlO5pWB5WQuipMUEeiM6h0ez6jJa82Nj+6078zGBpKMwIKgqmhDq/rggm0bJplDSCdR0HDkVPWDhmQgX2HdH0hT5j+KCbBo5FiiItu4fXROMFvRF5UiOkp89Uy7HglS8dgOiyUbZlU8h6g+wNiJLZxGc9yK2zoMpPut6BEEguX0FvnfHYazXCu+I8Yh6I6FFbxP+esqBkUU5l7eGXUHJ0LuwtOoDHIg/7J0H4TqIJaEDMqs+YPecF2h15iV0HPpn7GYJVRZJZHOogsLZbctxWAw0K7FiNejR6wSyOa0oJUrQuNiKTjp6wqwoKllFQS+KdVZe/yVs27aN5s2bM3ToUGbOnFl4/ISjiAO0atWKAQMGMGXKFP75z39SXHxk1d06nDj4dlsNufxst6pCvSIzwUSmFvVc5yyl5ML7qJp6t2bRNWJ8YeZJ72moqYhPu5/q98bjHZm377poDFWT79SUxS95BGNZC0pHjKNy8t1UTb2PsksfxdKqN+5z/07NJ/+gevYkSi64S1MWvWgMvmn34XtnrCY+kd80re3PILVzFeGvp2Asb4m5efdjeo9yMkrwixeIb1h4QDn7MBB0Rmwnn31YarCqyGQqNmGo1xpTw/YgSMRWfoQgSlhbaVY0Gf8uWD4bR48LMTVsT3DB6+iKtG62HAuQC/lI7lypKXgrMrlQBXIypgmsHW5u6kiQdIh6M4LRolHETTYMjuaav/Z+Spjdo32gP/quNtQ/iejKj8hUbMZY/yTS+zZiLG+FHA+RqdyMvqQpuuIG6NwNSO1aiyDqiK+bT+Tb98hFfKjZ1JHXJYhYWvbCdeZ1BXu3oyG15wcCX7yoKYAe5Hut5jL4ZownU70D77D7teQ63+nWKs1XY+t4Nko6jm/6GOSID++IcRi8zcgG9uJ7ZyyiyY53xIOIJhuqnMM/a+KBrvdBybVGI3ucjG873ovu/8nkGigUntQ8zVJVVfQ6kSdGdDpqwnw0Rkkdfn0kk0mee+452rdvX0iu6/DfwT333MPgwYN5/PHHefDBB3/r5dThGJDOKRj0IjpRQFZkDDqt2JuVVQQBPA4zdqNIkz/eTaByN+/+YwwX3ltCg1YdCCeyWI16OlxyF5lokPXTJtFgxDj0jTpiP/lslHiQ0MI3Ec0OXKdfi3vg31HSMYJzn0c0WrG1Px3viHFUvX0nvncfoPSSRzB4m+IZfAu+ZJiaj59GNNmw5Pd+yWzHc/7tVE2+G/+HT1Iy7N5DispHQmzNXEIL3z5gQ3VYCEg2F/Wu+b9DhNaOFBOomQTmRh2gUQcy/l1Es2lsnQZiatKZ0ILX0LvqoSuujxytIReuIrl1GblwJaLOiBwPFxLVn2fbJSAY9iuJ2xDNNnTOUqTyVohWFzqbC8lWjKGs1WHP/qm4IF2xGcnh0RoRNhfJbcsBgciyDwnNfw0574t9JOhc5bj6/wFL6z7H9E5UOUv1rIko6QTekeNrJdextfMIznsRS6s+FJ/5JwRBILVnHdUzJqD3NCyMhkWWvk/46ylYO5xJ0Wl/0Cjgn/yD5LZlFJ/zt0Jyndj0DYG5z2Nu0QPXmdfVKgREty7F9+lLWFv2wt73EgRBpdRhIpPJsXN3DEHSs3p3kGKbma6NXdiMehIZGU+RgXROwWnWH1NyDdrohvE4sjLrcCgef/xxAMaMOTYBvt+0gw0wd+5czjrrLO68804mTpx4lDPrcCLgYCE1vU7k7Wu1ZHHk84sPmQXZL/5kbX867kE31/oyim9YhH/Wo5hb9NDstkSJXNRP5Vt3oGZTlF36KHp3Q02Eavr96IrKKb3kYSSzo6ASaW17Gu7BNyOIEsmtS/HNeCg/sz2u4KGoZFNUvX1sncb9qJo+ltSuVdg7DcTUqAOSzY0g6TUV7kyyoLaZi/oxeJtha3/6IdeQUzEtsa/fButJ/chU7yS+Zi6mJp0KlLHoyo/J1uymqN+liCYbwfmvoebSFJ95HaoiE1/7BfGNX2s+mc26FpJwORbEPfDvRL6bSbZmN47uFyKaD2zmgiBBYQbb8B/R41VVJbTgNTKVW1GyKeR4EGNpC7LRauRQZYEmVnhtg+UgNfRSjWJmK0Y0OzUhlvx9lOMh0nvWEV35MXp3Q8quePKoVev9AjiiwUzZFU8imbUO134KeHLrMtxDbsHWbgCqqhL84gWiy2fj7D2SolMv1+he08eQrtiId9h9mJt3JxeppvLtO1BzGcounYS+uL4movfJP4ivmUvxOX/F3mlgrXUE5r1EdOn7hS7K0SCgda/1OpExQ9oVrPPqEucTE0888QS33XYbr7/+OldccUXh8boO9vHBwTGBoig0bdqUSCSCz+dDr69TyD3RkUjnqIxohdQii4Fiq0Y73x2I89VGHz9URJFEBZNeRyBQw6t3XUEyFmXQXf8mZnRj0ouk0gpCJsaXT91AOuIvMIu07/EXiS7/AGe/Synqe4lWTH33AVK71uK54E6srfuSC1dR+fadqHKWsksmovc01BhNU+/RrBaHP4Cp8QGxrMjy2QTnPl9LyPIn3+OW76h+bxzG+m2xtj8dfXF9bfxIVVGzKc2dIx7SdFNiNRTnPZMPxnGPCQbdSOS7GWT9u3H0uDCf4OdjMkFEEMT8DLYeJP1/ZPOkZFKEFrxOxrcNNZtGTgTRFTfIK4/7ajPyBPFAPOAsRXKUINmKte672YagMwKgphNka3YTW/0ZmaptlF026bD2mwdD26ufIb7m80MU4hObvqH6/YmYGrXHO/wBBJ1BUwKfdh+SzU3ZaG2eOvr9HAKf/hNLqz54LrgTBLEQOxSdcjnOPiOBvIjetPsxlDajdNQERL0JEa2AnvFto/LtO9G76tFk9KPYnGZOKrXT0G1hZyBBZTiNQRSo77LSs7mbPi08WI16HCYdHvuR/b7rcGIgHo9jt9vp3Lkzy5cvr/XcCdnBBjjjjDNo3rw5kydPrkuwfyc4WEjt4CRh3AXta9HEAaxtTyMbrCC86G10zrJaNG3rSf2Q40GCc58n8Om/KD737+jsHkpHjqfy7Tu1rvVlkzA1aEvJsPvxvfsgvuljKB01AUe381GzKUJfvQGiDvegGzA3745nyG34Zz9G9YzxlOQrk6LedGBW+50HKbvssZ+cYVJzWVLbl+PodTGu/lf+8huVy6LKWSRLnh6vyih5n8f9yAb2IpodhUq3mksj5KuvsVWfoiSjWFr1IReqKCiOKsmo5mMtSuhc5WSrd2h0MbvnZ3lGqnIWJRlDTmoCanI8lP8XKAQJuYgfOeo/hKqWyqbRu8oxNOuKrqi8oDCqKypDNDuOeQ2Wlj2RbMUE572EHAv8ZBdbTkaomj4W5Cze4Y8cSK7z/tPJrUspPvsvheQ69OUrRJfPxt7tApynXIaSTVM9Y5wmSnL+HZibd0eOB6madh9KKk7pJQ+jL64PaPNV8TVzcfa95JDkOrriI6JL38fe9bxjSq4Brju1GXazvi6p/p1g8uTJuN1uRo0adfSD6/CrQhRFbr75Zm6++WbmzJnDeecd299YHX47WIw6GhVbUFTN6WA/GrgsnNmulKYlNrZURSixmxAaurA/+gIT/zySOU/fRrs/Po6ityJJAoqso97FD7LzrdvxTR9L6WWPoS8qw3XGtSjpGOFFbyMaLTi6XaDFBNPH4P/gMYQL9Vha9KB05ENUTrmLqmn3UnrJI+iL62sCqVPuxvfeOLwjHsTUoB0A9i5DyPp3EVnyLjqnF3vnQT/5HpObv0U02Sm95OGfpID/JI53TCAI6IrKyPq2a2NldvfPiwlUBTWd0ITTEmHkZFiLCWKBA//ylp4/LqwjGRANQfTuhpibd9fiAVe5Fh84vcdc5Dc17oi1/Rns/sdIEluXHjXBDn8zlfiaz3H2GVU7ud66lOpZkzCUt6Rk2P1acl25Bd/0MUgWJ6UjH0KyFhFbO4/Ap//C1KxrQRA1+NWbWuzQfSiO3iMAyPi243tvPDpnab7rrX0mEqCL+tn77oOIRivNRo3BaDdh0QkIIlgMBiKpCHaDgCjqSOdkWpXYaeq2IUkiJn3d/PTvAdOmTUNVVW6//fZjPuc3T7AFQeCiiy5i0qRJLFiwgP79fx1/wjocXxyOtjq6pybU9OgnP/DdjgMUIGefUeRClYS/nozO6S2IgAHaDFUiTPibqYhmO67T/oC+uD6lI8dTNeVuqqbeS+noiZibdKLkwnuonjGBqmljKB05XrNTkHOEv54MArjP/TvWNqegyhlqPnqa6hkPUTLsPkS9EZ2jBO/wB6icfCe+6fdTOnoiksV5+DcnSQh6E0oq+p/dJJ0eJRUrbIJyMoogCIjmA3OFSipWSOpAU8Y2N21B1r+bXKQaR8+LkMM+csF9iPn12jqeRWDuC+z+xyj07oY4+45G7yoHqLWRpvf+QHjx9LwtR7owc6WpjSaOaNmFqEOyudDZ3BhKm6Nr0fOA0Eq+Ai0azP/ZvTkISlJT0fypayrpBL53HiQXrqJ05Hj07oaA5i/un/04iU3f4Dr9j9g7D9KS6/mvElk6E3uXIbhOvxbkLNUzHiK1cw3uIbdoxZ1klKpp9yNH/XhHjCsI3oSXzCCy5F1snQbi7Ft7bj+xZckBetjp1x7T+/PaDdw16KeDhDqcOFi3bh0rVqzgqquuwmAwHP2EOvzqGD58ODfffDMvv/xyXYL9O8HhqK2CIFDqsOC1m9EJAhWRFLF0joaNWjD8zid4+4E/88Pb42lyyTgUyUA2C7Kj5ADle9r9mq6KzYV74I2omRTBL15E0Juwn3wO3osfoGrqfVS//zDeC+/D3LwbpaMe1mKHKXdTOuph9O4GlI6cQOWUu/G98wDeix/E1KAtgiBQfNb1yFE/gc+eQzTZCurfh4Nosmn7p5z95Qn2bxwTAPg/ego5HkKVNW0WNZvS9FnSCU0j5RAz1/3v367Zdtk9GMpa5uMBL5KjFF1RqTZ7/h90xA+GkkmAIh9VnyW64iPCi97G2v4MnAepwie3LqV65gQM3iaUXvwgosGsJdfT7kMwWikd9TA6h4f4+gXUfPw0psYdKBl6D4JOT/jbd4gsnoat49m4BlyDIAhkQ5X4pmtJdWleEHU/9EqC3TPHIWSTnH/Xc8St5SiINCux5kcojTR2m8llwG030qaek94tPZgMv3n6VYefgZdffhlJkhg06KcLcQfjhPiE//a3vzFp0iQeffTRugT7d4TlO4OHdLG7NnbRv7WXpTuCha9pQRBwn/s35GgNNXOeRbIVF+y7AJz9LkVORogseQ/RZMfZazgGb1Ot8jztPqqm3kvZJROxNO9OydC7qH7/Eaqm30/piHGFjnj468mgyLgH3YSt/RmgqtR8/A+q33uQkmFjEA0mDKXN8F40Bt87Y/FNH4N35EOFLujBEAQRc4uexNcvoOiUy4+ciB8FkslGLrgPNacpqyY2LETnLKtFUd8/K77fyiIX2Iu++1Ciqz8ls3cDMaOFXNiHnAhpFeOiMqIrPsTedQjm5t0JznuJbLgK04+q1KqqomQz5KI1GiVMZ0SyWdHptZkrIW/ZpTM7EM12TQU0L9Aimu3HPJP2n0LJpIit+RxTk05H3EyVTArfe+PIVGr2GqaG7bX3KGc1IZtN3+AacDWO7hdodMJ5LxFdNktLrs+8DjWXoXrmBFI7vtdEztoNQEnF8E2/n2xgL97hYwtdjfHNL4EAAMfmSURBVOiqzwjNfwXLSacUxFD2I71vI/5ZkzCUNsdz3h3HXJG/6czWRz+oDicMJk2aBMBNN9302y7k/2M0aNCAQYMGMWvWLHbt2kWjRkcWAazDiYFMTsEXTSErKh6bAavxALVfUcGol9CJAgIqWVWlW6++7P3j/cz79xj2zn6SsvPuIIW27xg8jfAOH6vt//sL4iYbJefdTtXMhwjM+SeCpMPW/gy8I8fjm3ovvpkTtLGfZl21JHvqvVRNuRvvqAkYPI0oHTWBqqn3aHv/8LGYGnVAECU8F9yJb/pY/B8+gSDpCwrTP4a5ZW8i380gtuozHN0v+EX36L8VE+QivrzPee19XFVVctEalFRMY/cZrQh2tzaHbbAg5C27JIsT0eLIxwRFSBbXcfENPxKiKz4EwNKi1xGPia3+nMDnz2Fu0UMTWsvv1YnNS6h+/xEM3iaavo/JRnrfRs0xxGilbPQj6Jxe4usX4P/wCYwN2moxYn4OO7TgdSxt+xfcSXLRGnxT70WVc5Re+ig6p7ewBidZfLMfJeHbyfXjn+fMs/qzYneIeg4DgiiyoTJKZSRN23ouymwGGnosdKhXXJdc/86wcuVKvvnmG6688kocjmNXuD8huAkNGzbkggsu4JNPPmHfvn2/9XLqcAzYP4f9xGcbufSlb1m+80DHulcz9yF+mYKkp+TCe9B7GlH9/iOkK7cceC5fSba06U9owWtEV3wEkJ+lfgA5WkPV1HuQ4yEsLXtRcsGdZCq3UDXtfpRUjKJ+o3GechnxdV/in/04qpzF1uFM3INvJrVrLb7pY1DyStmmRh0oufBeMv6d+Kbdh/xjmlMezj4jUXMZgvNe+o/uU1G/SwktfJOqafeTC/uwtj2N9L6NZAN7ATC36EFq9zpUOUumeqcmQmZ1YWt/BnpPIyJL3iO2Zi6pHd+T2vE9ya1LQdRhatAWUWckF9hL6MtXqHj9JnIRHwDZUCW7nriQ0PxXEAQBQ2kLyi55GO/wsZRccCfugTeQC+whuWER9i6DsbY5FXPjkzGUNNFUyv9LyTVAaOGbyLEAzr6XHPb5/d7V6T3r8Qy5FUtLbcPV6N4TtOT6jD/i6DEMVVUIfPYvLbnuen4+uU5rnevtK3EP/Du2jmehpONUTR9LxreDkgvvwZz3II//8BWBOc9iatoVz5BbaiXQ2cBefO8+iGRz4R0+pjDffyR4bAbKHEauP7XZTyqE1+HEQiAQ4I033qBPnz6cfPLJv/Vy/r/G3XffDcC///3v33gldTgW1MTTeSshkYpwmlTmgAezJAqU2I3kFBVFVbEadNR3WRkxcgR9R99AzdpFVM17CeEgTSBj/ZMoGXYf2cAefNPHat1VnZ7yofdgatxRUxNfPx/JZMM78iEMnkb4ZownsXUphpLGlF7yMABVU+4mU7UNnd1N6SWPoHOU4HvngbzwFoh6E97hYzGUtqB61qMkNn1z2PdnrH+SJjq26O3CXvtL8J/EBJbWfUlXbCa2Zi7JzUsILXyrVkwgmWyYm3cnsng6+174IxVv3npIXKCmogio6EuaUDpyPN4L78Uz+BaKz7qerG8b8bVfYOtwBpbm3TGWt0Ln8P5Xk+tM9U4i383E0rb/ES1KY6s/o+aTZzA16UzJBXcVGAXx9Quofv9hrZky8iEkky0/N32f5jIzeiI6ZymxdV8Wkmvv8AcQDSZtJn/eS9oc9mBt/5cTYXzT7kdORvCOeBCD58BeblIV9n3yDJUblnP6NffibNmJymiScqeRmliWTVVRikwGTHodzT12GnjslNithFJZ0rkDfxt1OPHx2GOPAXDnnXf+rPNOiAQb4LnnnmP+/Pk/qzpQh98O326rIZNTUFTIZBWenrupkGR3bezi2n5Na4lRC0D7JqWUXvwAosmO750HyAYPFFMEQcQz+GbMLXoQ+Pw5YmvmAmBq0A7v8LHkIj6qptyDHA9iadWHkqHaplk19d6CNYRrwNUkNiykesYElGwaW/vT8Zx/B+mKTVROvrugXGlu1hXvhfeR8e+iavLdyLFDFS0NnkY4e48kvu5Lot/P+cX3ydSsmyaS1e18is/4IzqHh1ykmlzUj6oqmJt0QjLbqHjjZvwfPIqzzyjN/9rTiNSuNZT/4RmK+l+JoDdhadUbvbcpmcrNIOqIrf4MVVWwtT8dR9fzCH75av5eCuiLG1B+1T8ov+ofuPP+jvsR37BIo6T9SnSuX4rEpm/ynebBmBq0PeR5ORXTrNt2r8Mz5JbCfJWSjuN7ZyzJbcspPudvOLpdgCrnqPnwSWLfz8HRaziuM/6Imknie+cBrXM96MaCgnjVtDFkqrZotht5ZdnE5iUHqtkX3l2g8AHkYgGqpmuqkd6LHzzUE/QwCCez+KJpXvlmB/fMXFOrAFWHExdPP/00APfdd99vu5A60K9fPxYsWMBVV131Wy+lDseA/btJKpejIpRkTzCBP3ZgDMlikCiy6PHYTHRt7KJNmYPLejVj1nMTGXDRFfiXfED8u/dqXdPSpBMl599JpnIzvvfGoWRTCHojJcPux9iwHf4Pn9SSbLNdS7JLmmiF143faF3r0RMRJANVU+4mtecHdLZiSkdPRFdcH99744n/8BUAotFC6chxGMqaU/3+ROLr5x/6/gSB4nP+CqhUv//o4S20jgG/NCYweJti73Y+Wd92nL0vxtnvMi2m0RsLMQFAfO08BL2Jen94Fke3ob+ruEBJx6l+/xFEo5XiI4xgRZbPzifXnSgZdl9hfj268mP8sx/HWL+NNl9tspHcthzf9LFINrf2uTu9RFd9Rs2HTxaEzwrJ9dznMbfshed8bQ5bY7mNIReuxHuR5kqyH26TSua7N/Gv/pK2Q67G0u40Vu8J8vXGar78oYq94SQZWcFm1BFNZtgdiGHWSYgCbPbF2Omv/bdRhxMXPp+PKVOmMGDAANq0+XmjfidMgl1eXk7//v2x2WxHP7gOvzl6NXNj0IkFBcWvt/gLnezlO4O8tnhH4VgBMOpFxg/twKOXn0bpiHGgKvim3U8uFjhwnKSj5IK7MDXprKlCrl8AaF1nLcmuomrKPeRiASwte+Edlk+Sp9xNLlqDo8cwis/5a/5LdQxKKob1pH54L7qfXGAvlW/dQTZYAYC5eTftmuEqKt++vVayvx/OPiMxNe1C4PPnSG5d9ovukyAIGEqaYGnevTA3bD2pH+bGJxc6xcVn/Zl61z5H2aWTCo+n921E5ypHsnvQ2T2YGrYjsWkxOrsHa5tT8U0fQ/Cr10FVsXUehLXdAFI7Vx144Xx19MdVUiWdKKhq/5ZI7fkB/+wnMJS3xjXgmkOez0VrqJp8F+nKLZpKbNvTtMdjASrfvpP03h/wnHcb9k7nomRTVM+cQHz9fIpOvQJX/6vyifT9Wuf7vNuwdTgTJRWjatr9+eT67kI3PLltOdWzHsFQ2hzv8LG1LD72b7JKIoz34gdrzcb9FLKyqhWfcgpTluw6hOVRhxMTL7/8Mg0aNGDgwIFHP7gOxx2nnnoqrVod3iaoDicWiq1GJFEgGM/itRuwm/VEUzmyskJOVgjEszQrseG0GIhmctRzmbGa9LgdZq69dQzlnU+nev5rJFZ/xn7ukApYWvXGM+RW0rvXUTNjAmoug2gw4b1obCHJjq37Esls13ywy1pQPWsisbXz0BfXp+zSRxEtDnzT7yO5bTmSxUnZJQ9jrNcK/wePEVk+G9Co2aUjxmvXnP1E4fGDoS8qwzP4ZjIVGzW23M+yxdLwS2MCVVVI792gCYq6G6FzlmA5qR+ZfZsKMUHFm7eS3rMeR5+RiEYLltZ9fjdxgTbK9TC5UAUlF9xxSCFbVVVCi94uJML77bW0xycT+Oz/tLju4gcRjRbi6xdoomTF9bXOtd1DZNksAnOewdS0CyUXjdWS62WzCtcsueBOBEmvqc9PH0umeiclF96LqZGmPi8AeqDq6/fYPv8dOp19MaWnjCIcz1ERTrPJF2ObP876vWF2+BNUxdMU2420KLUTSmWpjKSwGiUcZh3RVI6c/DOsVuvwm+CllzQW67333vuzzz1hEuw6/L6wX0m8b0sPoqDNWGVzCt9uqyl0t1W0X7B+LT28fW2vwpy23t0A7/CxGv1m+hjkVKxwXUFn0DyHG7TF/+ETxDcsAsDUqCPeix8kF6mmavJd5CJ+TdDk4gfJhX1UTb6TbKgSe6eBeM6/nfS+jVROvotc1J+fyZqAkopR+dZtpPdtBMDcpJP2eDpB5Vu3Fx4vrEWUKLngLq0q/v7DJLev/FXvoarIBYq6IIi1PDPlWACdowRBELXNt0XPQgfe2XsEZZdNQrK4cJ3xJwyeRgiihGiyFu5lLuKj4rUbqZp6D6ndawvXDS18C0ev4bVUS4+E1O61ZP27f1EQ8VNI79uI790HkOzFeC+6/5C1ZKp3UvnWbeTCVXiHj8Xauq/2uH8XlW/eSi5UqT3etj9yMoJv6n0kty6j+Oy/4Ow9Qus4T76LjG8rJUPv1o5LhKmaei+Zqm2UDL3nQHK9fSW+GQ9h8DTGO2JcrTlwJavNfmdr9lBy4b0Yy1v+overcuBvow4nLpYvX86+ffv4wx/+8FsvpQ51+N3BoBNp4LLQ3KOpI6eyMqIgIAoCgiAgqyqZnILbaqCpx4bTrEdWVALxNCaDnpG3PIS7VVeq5/yT6KbFwIGuuK1tf7yDbiSxYyWVMx9GzWULSbapUXtqPnyS6KrPEE02vCPHa4999CSR5bPROb2UXToJnas+vvfGEVs7TztuxHjMLXsSnPs8wfmvoqoKotGCd/gDBx7/8hVUtXYSZGnVB9fp15LY9A01Hz31q++PciKMqiqHxASCICLHg+gcJVhP6oet3QB0Di9yrKYQE3iG3IZocRR0Sn7tuCAbrCC9d0OtmO3XgJrLUP3+I6R2rsI98MZCQlt4Xs4RmPNswae6ZOjdCDr9QY9Pxtr+DEouvFebpV42C//sxzDWP4my0Y8gWpyEFk0m+MWLmFv1xpsXvw0veZfgF5pH9sHJte+dsQWW2377NACzCPE1n7F37mvU63waJaf/kVA8g6xmyeQgndU+N6NBRCdCmc1AM4+dVFbBohfw2k3YjXpSWQWdKB4ySlmHEw9TpkyhuLiYM84442efe8In2KlUikwmw7Zt20gmk7/1cupwELo2dnHTma0KXxKiKNCrmZtezdx5MRPQSQID25czY8Ue/vTGMp75YhOgzVeXXHgv2Zo9VL/zAEomVbju/pkoY/2T8H8w6UCS3bA93hHjkeMhKvMJtalxx3zyHKfqrdvJ+LZhbXMq3osf0LrTb95OpnqH9kV72WOIehNVU+4hvvHrwjrKLp2Uf/zuwmsV1mK04B0xDp2rHr73Hiyc958gXbmF6pkPs/sfo6j58MnDH3QUmpZotCJIulqCG/tV5SSbmwZ/eY3yq/5B8enX4p/9BEomSaZqK7lwJZbm3bWg4CgzQNXvjWffy39m91Mj8L37YK0N+ZciueN7zYPS7KB01MNI1qLaz29fQeVbt4Oco/SSRzA37Vw4r/D46ImYm3YhG6qk8q07SFdtxTP0LuydB5EN7qPqrdvzSfgDWFr11hLuKXeTrdmNd9h9WFr2zL/WSqpnjNesXEaORzoomFHlrKYVsEfrlO9fx89BjyYuDJKAlPe/7tXsyBZkdfjt8dBDDwFw5ZX/gTVfHY476mKCExtOix5Uld2BOAIqkiigqipZWaYqkiaSypJMZVm1O8TirX72BROEEzJ6g5GuV43DUq8l1R88SnLH96hoSbYOaNL9TLzn/I3ktmVUv/9woZNdctFYTE07E5jzDJHlsxEN5nyS3Ivg3OcJLXwb0VJE2eiJmBq2o+ajJwkvnq4V84feja3zICJL3sM/axJKNoWoN+YfH0zkuxn4359YKz4BcHQfStGpVxBfP5/q9x9BOZIjxzFCySQJLXqbvS/8kT3PXoocqT78gT8RF4hGK3pXOYL4IwGtXzEuiK2ZS+Vbt7HnH6PY99JfCH/77i+myu+Hko5T9c4Dms3mOX/D1v70Ws/LqRi+dx4gtvoznL1H4h544wEK97sPElv9GY7eI3EPuglEkeC8lwuJdOmIcQh6E8G5/84n4WdScsFdIOkILXyb0PzXsLQ5Fc/5dxyUXD9Aet9GPOffgaVlLxwGaOUSKTJA+IdFVHzyTxzNulA6+EYCqSwKkMyqlDn0lDkMmAwCgipi0Yk4LHqiqSxuu5FiqxmnWU+Zw4TDrKe8yPSrqa7X4fjg22+/Ze3atVx22WW/6PwTNsGOxWJ88sknPPbYYwwfPpxBgwZx7bXX8sMPP/zWS6vDQdhYGSUra1/IWVllY2Xe2ir/xaECYz5Yy9tLdvHZ+ioqIwc2InPTzlq3uWIT1TMeqvVFvX+TLCTZ+VkpU4M2lI6agJpOUPX2HWSqd2Ks15rSSx8FUaLy7btI7VyNuUknykZPBFWm8q07SO74Hr27AWWXP4He2wT/+48Q/kbztdO7G1B2xZMYSpvjnzWR0MK3alWtJYuT0tETMZa1xP/+RMJLZvxigYp0xSYq37yN1O61WNuehq3j2Yc9Tmdzk4v4Cz/LkWqkH3lES3ZPYSNWFRklk0Ay2RAkXUGJ1FDaHJ2rjGxgL+l9G8lUbWPvv6+haso9WjI69Z4jrrVk+AO4B9+C7eSzyVRtpWry3SS3r/hF7xs0YRLfO2PROUu1eShHSeE5VVWJLH0f3zsPaB2HK57AWNYCVVWJrvgI3/Qx6Oweyi7XHk/v20jlm7ehJEKUjhyPNS/+UvnWHSiZJKWjJmBu0olcuIqqt+8kF/ZRMnws5ubdgDwtfMZ4dK56lI6agGQ+oP2gKjL+2U+Q2rac4nP+ivWkfj/7vYoC9G/tZcqfenPL2a1rMTjqcOJhz549vP/++wwaNIjmzZv/1supw2FQFxP8PhBL59gTTKITRbZWx/DHUqRziib25LVhMejYXB3DH02xzRdlR02ccCKFKAjUK3HS6Q8PYfbUp3rGQ2T2/oABLY4QdQJFnc6l7Jy/kdy6FN/MPF1cb8Q77P4DCfU3U0HSa+ylDmcS/mYKgU//haA34h3+IJa2/Ql99QaBOc+CqlJ81p81/ZaNX1M1+W5yUT+CKFF81vW4BlxDYtNiKt++g1y4trCZs/cIXGdeR3LzEnxT7y0wzH4uVFXF985Ywl9PQVdUTtFpVxd8r3+M3zousHc6h5KL7qfo1CsQLQ5CC16jevZjv+h9A+TCPq1Ivmcd7iG3Yu90bq3ns/7dVL5xC6nda3EPuomiUy/XbLOC+6h48zZSu1bjHngDrlMvR82l8c96NG/POVhLpAH/B48RXfERjh7DcA+6EQSB4BcvEv5G64Z7htyKIOkKnev0vg14zr8Da+u+mNGSpJqEQnzbMio+eBxTvda4L7iHWFZPTlax6KHYamJAKy+9Wnlp7bFgN4l47CYURcBlNdCk2IpBJ5BTVJwWA8VWA/rD2NrV4cTC+PHjAbjtttt+0fm/yicsCMK5giBsFARhiyAId/2n14vFYjz00ENMnToVSZK488472bBhA+3atSu84TqcGPhkbcUhP3+7rYacrFHEc7JKTj5yMmpt3Rf3wBtJ7fye6lkTUeVc4bn9dC1j/Tb4Zz9ObO08AIzlLSkdPRGAqsl3kt67AYOnEWWXPYbO7qFq+hji6+djKG1O2eWPa6qh08cQXfkxkrWIsksewdr2NEIL38Q/61GUTEpLokc9nN+Qp1L97jjk5AEf7P1KpZbWfQjNf4WaD584pKp9LEhuXgJKjnrXPof7nL9iad3nsMcZyluSC+4jF/ahylniGxZiadGz1jGWlj0LYnCJDYswNdZUj+VEGFXRVCqzwX3kgpXoXeXYOw+iwV9eo/71L1N22ST0rnqUjnr4iGs1NWiDrf3pFJ95HfX++AKC3nhEhdWfgipnCXz+b02YpFFHyi59FN1BQYGSSeKf/TjBeS9hbtlT+xwdXtRclsCn/9KsOJp11R7P22tUTbkb0WCi7LLHMDVsT2LrUqqm3I2gN1J26SSM9VqTqd5B5Vu3o6SilI58CHP+/iS2fIdvhjabVTpqQi0bNlVVqPnkGRIbF+EacM0hG/7RIAggCRpdcr993V8HtKhLrk9w7FcJvfXWW3/jlfzv4NeMC+pigt8PsjkVFQGbSY8gCqSzMnpJJJOTCcTTZHMKleEUoUSWmniafYEkCoAAaRlEi51O1zyMzuai6p0HyPi2IAGiKGLQQdO+51Jyzt9IbVuO773xmvCZLp9QtxtAeOFbmvuHIOAeeCOO3iOIrZqjFfGVHJ4ht+LoPTJf8NW0Whw9hhUUyytfv5n03h8QBAFHjwvxDh9DLlxFxes3HVJgdnQ9D8/Qu8j4tlPx+s2HjJkdC5R4iPSe9Tj7jqZ0xDicPYcd0Rr0t44LdA4vlhY9NUr66InYuwwmuWnxL6LJJ3d8r7mfRP14L34QW7sBtZ6Pb1hExZu3oKQTlF4yAVuHM7Xzti2n8o1bUJIRSkeOx9bx7AJLLbHxG1wDrsF15vXaHPW0+/N7+dW4BlwNikzNx08TXf4B9m4X4B54g6YWnopRNe0+0hWbNM2Xk/phBGwmsBklpKof2DbtYcyljWl+yViMBhMZICdDIgteu5594QwWnR69wUCp04TLZiScylBqN5LIyqhAkeW/p8Zeh/8M27dv5+OPP2bo0KE0bNjwF13jPzZjEwRBAv4FnAXsAZYKgvCBqqrrf8n1stks1157LXq9nrFjx9KiRYvCcx6PB7fbjSzLSNKxedDW4dfHwf7XA9uXs3DzgYrqwPbltC6zo5NEsrljE3CwdTgDNZsi8Plz+Gc/ptF18hZJotGC9+IHqZ4xTpt3ymWwdzpXs+G4dBK+afdTNe1eSi64C3Pz7pReNonqGQ/hn/042VAlzt4jKbvsMfwfTCLw2f+R9e/EdfofcQ+5Fb23CaEFb5Ct2a1ZiBXXxz3wRoxlLQl88SIVr91Iyfl3YKx/krYWvRHPBXcSXjyd8MK3yVRtw3PBHbU8LI8GQz1NsCf2/RycfY4sKLK/gl41/X5QFWwdz0bvbkho4dsYyltiadEDW8ez8H/4JHtf+BOi2U7J+ZqFQHr3OkKL3tKUsAUB9zl/LVSuC1BVEI+9vhZfPx81m6qlpHksyAYr8M+eRKZiM/buQ3Gd9oda9leZ6h34Zz1KNrCXolOv0ObABJFcxE/1rEfI7NuIo9dwik65HASB0FdvEl48La/2fS+SxUl0xUcE5j6PwduUkuFj0dmKSe1ZR/W74xD0RkpHTyx8RvENi/DPfgyDtxneEeNq+aCrqkrgs/8jvvYLnP0uxdHjwp/1XrVrwBltS7muf/O6pPp3gkQiwfPPP0/79u05/fTTj35CHY6KXzMuqIsJTnwk0jlq4hn0kkCRSU+RWUdNLI1Vr8NtNZLI5JAVqAgnEAUBRVXwx7KoCiCCSSdiN+tp4jajlwR0pU70f36Mpc/dyp5pY2h5xcPoi1ogqBBKgKXTubglHTUf/wPfOw9ogldGC+7BNyOa7ESXzUJJhHEPuhHXqVegs3sIfP5vqibfRclFY3Cdejl6Vz1q5jxL5Zu34B02Bku+uFs94yEqJ99N8Zl/wtZpIObm3Sm/4kmqZz6Mb/pYnH1G4ew7qrCPWVv3RV9Uhm/mw1S+fSeu067C3u38Y7a7FK1OJIeXxIaFOLqdX2vu+sc4keICORYksXkJhtLmtfb0o0FVZMJfTyX8zVT07oaUDLu3lniomssSnP8K0eWzMZS3pmTo3egcHlRVIfLtu4S+ehN9SWNKht2Hvqgsz4CcgJKOUzLsXiwte5ENVuB7VxsT9Jx3O9a2/VEyKfyzJpLctgxnv0tx9hmlaQMkwlRNu5+sf1ct8dM0UJ0CZcsa9r4zBntJPdpf+zBpwYoigjEH5U49WUWgicfJ7kAcg17AYdQRS+fIKQo6SSSRlWnrNGOQRMS6mevfDSZNmgT88u41gPCferEJgtAbeEBV1XPyP98NoKrqI0c6p1u3buqyZYdXZd6xYwcjR45kyZIlgLa5bt++nW+//ZbJkyfzwAMP0KvXkc3n63B8sd//OpNTMOhE3r62Fxsro3yytoKB7csZ3bMRy3cGGfnC4p/sXB8Oke9mEvzyZSxt+h/iQ6xk0/jff4TktmW4BlxTSH7keBDfOw+Q8W3Hfe7fsHU8GzWXpWbOM8TXfYm13QDc5/4dRInQgteJfDcDY8P2lFxwJ5LVRXL7ioJ3tnvgjQU6cHrfRqo/mIQc9VN0ymU4egyrtZ7k9pX4P3oCJRXHdeoV2oZ6DJuMqqr4Zz9O4ocFOHpdTNEpl/2szem/DVVViHw3g9D81zE16YT34geO+X3GVn9GcN5LCIKIe+CNtbr1qqoSW/kxwS9fRjBa8Jx3e6HDfLjPREnH8X/4BMkt32HtcBbus/+Sn7d6iejy2Zibd8dz/h2IBjOJTd/gn/04kqOE0hHj0DlLAW1+rOaTZzDWOwnvxWNrBReqqhKc+zzRFR9qn8upV/zi+SidKDDtut51CfbvBE8//TQ333wzr7322lHnrwVBWK6qarf/0tJ+t/i5cUFdTPD7haKo7KyJY9CJZGQFq1GHx2ogkVEw6UR0OpEd/jiyorA3lCCRlgkls7isBsx6EVWBdfuCbPMnSGZyWAwSdpOeSDLD3h3bef+RvyDnsnT502MojnpEU7B/8j6+fgH+j57UfI8vfhDJ7NDGjb59h9BXb2Bq3ImSC+9BNFpIbF2Kf9ajmsDZRfdjKG1Oas96qmc+jJpL4xlyK5aWvZBTMWpmP05y2zIsbfvjPudviAYzSkZrAsTXfoGxQTs8Q26tpYEiJ6PUfPIPkpu/xdSkM+6BN9Qag/opJHeuwjd9LHp3A63Y76p3HD6pXw8Z33aqZz6MHA/mx+daHP0kIBvYi/+jJ8ns24i1/ekUn/VnRIP5wPM1u/HPfpxM1VbsXc/HNeAPCJIeORmh5qOnSG5diqVNf9zn/h3RYCK25gtqPv0nktWF96L7MHibkdq9luqZD4OqUjLsXkwN22uCuu8+SKZyC8Vn/xl7J80lIhf1a2424SpKLry3lqAZaHFg1bT70NuKGTn2Ocq95SzbGSAj5zDpDdR3mRGBEoeJ6niGBk4zNrMeFZAEkXouE+VOM63LHNQrslCH3wei0Sgul4uOHTuyYsXRxyKPFBf8Ggn2cOBcVVWvzf98OdBTVdW/Hemcn9pMATp27Mh1112H0WhEp9OxceNGwuEwgwcPZvDgwf/Reuvwy7F8Z5Cn527i6y1+FFWjwt5ydmv+OqD2l+u9M9fw9pJdv+g1wt++S2jBa1jbnoZ78M21kjlVzmrJ6cavcfYeifOUyxAEASWd0BQod6zE2ecSnP1Ga9f6ZirhRW9jrHeS1u20uYit+5LAnGcRTXZKht6Nsf5J5CI+qt9/lEzFRmydB+EacA2i3oiSilEz558kNi7C2KAd7sE3oy8qK6xHjoeomfMsyS1LMNY7ieJz/46hpPFR36Mq5zSv71WfYihrQfFZf8ZYr/Uvul/HE5nqHQQ+/zfp3WuxtO6He/BNtSysjoRsqJLAp/8itWMlxkYd8Qy+CZ3jQCCSiwUIfPIMyW3LMDXtgmfwzUhWF6qcI7ToLSLfvove05iSoXejdzcgU72D6pkTyIV9uE6/FnuXIajpONWzHiW1YyX2bhfgGnA1gigRWfYBwS9exFCvFd6LxhSodpGlswjOexFTk86a0qjhwPtQVZXgFy8QXT4bR/cLKRpw9X8kPiIAt51z6N9FHU5MNGnShFAoRCAQQDxK96YuwT42/Ny4oC4m+H0iKyuEE1n2BuOUOExkZRWjTqLMeZDVoaKyviLMzkCcvYEkZQ4D6axKkc1AmdOMLKvEkllW7Q1SFU5i0evIqVC/yMjOQIJ927fx5thrURSFk/7wCDlrQ+IHha2JLUuofn8iele5Jkhq9wD5guqcZ9F7GuG9aCw6h4eMbxu+d8ejpCK4B9+CtXVfcpFqqmc+TKZyM47eIyjqdykIgsZUWzQZnauckvPvwFCqaTPE1n1J4LP/AwSKz7oOa7vTC/uFqqrEVs0hOO9lEARcp/0BW6dzj6mbndy+Av8Hk1DlHEX9RmPvep7WbT6BoGRTRJbMIPztdCSTXXPXyDP8fgqqnCOy7H3CiyYjSHqKz/4L1rb9DzyvKkRXfERo/msIeqNWkM+Lkab2rMP/wWPI8RCu06/B3mUIyDkC814ktvJjjI06ag0Ti5Poqk8JfPYcuqIyvBfdj764PtnAXnzvPIAcq8Fz3u1YWvUGtGS/avoYlGSEsuFjMeSV1/dDS67vz2vwPILb7aHMbiSVk7Ga9bT22ClxmWnqtrLdH8EXzeK2GannMHJyo2JW7AqypyZJoxIz3Rq76dSo+Ff8JOpwPPHEE09w2223MXnyZC655JKjHv+bJ9iCIPwJ+BNAo0aNuu7cufOI11y6dCmTJ08mHo/j9XopLy/nnHPOqUUNq8N/F/s71+ls3n4rP2f6Y/Gm5TuDXPLCYjI/6l4L+X+iqAk9/BTCi6cT+uoNLG374xlcu5OtKjKBT/9FbPVn2DoNpPis6xFECVXOUfPpP4mvmZvvWt+AoNMT37CImo+fQjTaNPuv8lZkqrZpCVvUj+u0q7F3Ox+UHKEFbxBZOhO9pxGe827D4G2GqqrE184jMPd5UBWK+l+JvcvgwoapqirxdV8SnPcSSjqOo/tQnL1H1rJ7OhLiGxYRnPs8cjyIpVUfHL0u/sVWUL8mMtU7iSx5l/j6BYhGC0WnXY2t41lHTTqVbJrIdzOIfPsOiBKu/ldh6zzwR/dqHsEvXkTNZSk67arCvcwG9uL/8AkyFZuwdTwb15l/QtSbiK2ZS+Cz5xCNFjxD78LUoB0Z/y6qZzxELuzTqtEnn4OqyAe62S174TnvNkS9SfPIXPgWkcXTsLTqg+e82xF0B4KWg5Nre7cLcJ1+7a+i7PnwhR0Y3bPRf3ydOhxfbNmyhZYtW3LTTTfx1FNPHfX4ugT72HAscUFdTPD7hqKo7A0lURSFSCqLrIDbZqTUYcKgO5BQxtM59oa07vSGygiNi62YDTrKnCasRj3BWJr1+yIks1mi6RzpdI5oKofVrGNfMEU0nWbpivUsee4WVEWh2WUTyBY15uCp39TO1fhmjEc0OygdMa5AOU5uX0H1+4/kRVPHYihtjhwLUj1zAul9G/KF+ku1hO3zfxNb/ZmWsJ13O5LNRWrXavyzH0dORjSmWveh2n4VqqTmoydJ71mPuUUPis/+SyGxh3yRec6zpHauwlDemuKzrj+mvT0XqSbw6b9IbluGzlmKo9fF2Nqffky2mscTSiZJbPVnRJa8hxwLYGlzKsVnXnfEWfGDkdq1hsDc58lW78DcshfFZ/25lgZLNlhBzZxnSO9ag6lZV9wDb0RnK0aVc1qTZPF0dM5SPOffgbG8JdlQJf4PHiVTsRlHj2EU9b8SVJXgvBeJrvhIK6JfcCeiyaZ1s2dMAEHAe9GYQjEgU7UV3/SxqKpC4xEPIpW1JHvQmtN7N+CbPgbR4qD0kkcKTASrBKhgNECZ00T/Vl46N/Ewd30FJp2Ey2bErBNp6Lby9SYfwWSWEruRk+o5uahzQ4z6E5etWIcD6Nu3L8uWLSOVSh1TPHg8E+xflSJ+MDKZDHq9vk7K/gTAv77cwhOfbURRtUS5X0sPN53Z6hAa7MHH7YdBEnjg/PYEExmiySz//mrbUV8v/O07hBa8juWkUwoqj/uhqqpG917yLuZWvTXvR71Ro4Ytnk5o4ZsY67el5MJ7kKxFZKq24Zs5ATkWoPisP2M/+WyNBvbRUyS3LMHcshfugTcime0kty2n5uOnkVNRivpdhqPHhQiiRC7io2bOP0ltX5HvVv+t1uy1nAgTnP8q8TVzkawunP0u1ZLSo9CplXSCyHcziSybhZpJYKzfBlvHs7G07ntMSfqvBSWTIrnlW2KrPyO1czWC3oi90yAcvS+upbB9OKiqQnz9AkJfvYkc8WFp3RfX6X9E5zgo4AhWEPjs/7Sudv02uAfdhL64vlbx//4TjSou6ig+9+8aJTyTzNPx5mFs1IGS8+5Asrm0gskn/0DIW6mYGrRDScep/mASqW3La815a0WXfxFf8zm2k8+h+Oy/1C7WqIoWVK38WDtvwDW/2nfN6J6NePjCDr/Ktepw/HDllVfyxhtvsGrVKjp27HjU4+sS7GPDr0kRPxh1McGJg5yssHZvmJysIIhQYjXSyHPo/HA8naMinMKoE9hWHcdjN2DR66jnMmPUSdTEUizcVI0vnKY6lsBi0mHVGwgl04SSGWqiaWriWXZv3czKF29HlWWajHqIbEnTWq+TrtyC752xoKqazWeeFZbxbcf37jiUVATPkNuwtOqtjZF99n/E13yOuVk33OfdhmSyFQq6gtGMZ/CtmJt2Rk6ENaba5m8xNuqAZ9BN6JylqIpMdNkHhBa+lS8oX6l1q/N7TKH4Pv8VlHgYa/vTKTrl0lpsrsNBVVVS21cQWvgWmcrNiGYHtg5nYm1/xjEx5H4tqKpKpmor8bVfEFs7DzUdx9iwPUWnXFbw2f4pZAN7CS14ncSmb5AcJRSf8adC9xjyXe2l7xP+eop2/06/BlvHszWVcP9ujUpeuVmjkp95PaLRQnzj19R88gwAnkE3YmnVh1wsgH/WRNJ71msMtNOuQhClPIPhn1o3e/hY9K5yQBNX88+cgN5io/HI8RhLGiADBhXCOUjsWU/1O2PRW4tofeXDxI0l5DjgyW4RQa8HkyQysFN9Tmvp4cPVVagoeG1GylxmTJLEoq3V6AUBl91Mk2Iz53aoh9NSpx5+omPDhg20adOGq666ildfffWYzjmeCbYO2AScAewFlgKjVVVdd6RzjmUz/fTTT9m7dy9XX331IY+HQiE6dOhAWVkZxcV1tIv/BiYv2cU9M9cUfj6rbSnX54WcDhY9AwqdbgTo3tjFnQPbFBLxn0MfDy+ZQWj+K5hb9abkvDtqdR4BIstmEfziJYz121By0X2FRDD+w0Kta20pKszlyIkw/tmPk9qxEmuHMyk+688IOgPRZbMIzn8NyeLEPeQWzI1PRk6ECXz6LxKbvsFQ3hr3oBsxeBod6MDOexklFcPe7XyK+o6ulQin920kOO9l0nvXoyuuT1HfS7CcdMoxJNpxYqs/J/r9J+QCe0HSYW7SGXPz7piadEJXVP6rB5W5sI/kju9JbVtGcvty1GwayeHF3ulcbCefc9TqtKqqJLcsIbTobbK+7ei9zSg+/VpMjQ8kKko2ReTb9wgveRdB0lF06hUHutYHVfm1mbUb0Tk8pPdtxP/h4+RCVTj7jMTZZ5RWoZ7/KtFlszDUywuf2D1kA3upnvEQ2eA+rXiSV/1WMkmqZ00ktW25JkrT79Ja909VZGrm/JP4ms9x9BxOUf8rf9X7e3bbUk5uWFRQEq/DiYeamho8Hg99+/Zl0aJFx3ROXYJ9bPi5cUFdTPD7Qzonawl2TqEykqLcZaJxsQ231UAkpYk8FZkN6CWB6kiStfsiRFM5GhabaVfPiVGvFc33BuNsqoyyYV+YTdUxvHYDwUgWs0nCqhdYWxknkEgTjKbIhSv4/vnbkbMZykeMR/rR7K9GBx6LHA9qPsZ5he1cLKDZflVsouiUy3H0HgFA7PtPCMx9Acnu1kbGylrkhTcnka3ZpXVIT7kcJB2x1Z8TnPciQC3qdza4TxuJ2rkKQ1lLrVt90MiXko4TXjydyLIPALB3Goij13B0tp/+PVVVldTOVURXfEhyy3egKujdDTG37Im5aReM9U761TvbSjZFes8PJLcvJ7l5CblQBUg6jWHX9TyM9dsc9Rq5cBXhb6YRWzMXQWfA0fMiHD0urDVelty5iuDnz5Ot2ZXval+Pzu5BVWQi380ktOhtRINZo5LnC+7BeS9pY3XlLfGcfyf6ojJSu1ZT/cEk1EwS97k3YG3bH1WRCc1/jcjSmZgan4xn6N1IeeG42Np5BD55BnNJA066cjyYXNhNAllZEyMLbVnNrnfGoXe46fGniVjd5fgTKTIZUFRNA85kAINOwms30brcQYNiK2adwLo9EawWHZ0buEhlZH6oDLMvnOakcgdtyhy0KLVj1EmUOIzYTScW9b8OBzB69GimTJnC+vXradPm6L/vcBwT7PzFBwFPAxLwiqqqE37q+GPZTDOZDKtXr6Zbt24kk0n+/e9/M3nyZLxeL23atGHDhg14vV5eeeWV/3j9dTg6/vXlFh7/dCP7f1sEwKgXGTOkHeM+XFdL9OzzdZW1utQPX9iB1mV2vt1Ww/e7Q3y+vuqYXzeyfDbBuc9jatpFEyv50QxwfMMi/B8+gc7pxTv8gUKVMl2xmeoZD6GkY7gH3oS1zSnaF++iyUQWT0Nf0oSSC+5E725IunIL/tmPkQvsw9F9KEWnXg6SnsQPXxGY+zxKOoGz98U4e41A0OmRE2FCC14ntvpzRIuTolMuq9WtLiSeX71J1r8Tnasejh7DjonqpaoqmX0biW9YSHLzt+TC2r2SrC4M9Vpj8DZF72mM3lWO5CxFNFp/MjFUVRU1kyAX9pELVpCp2UW2ahvpik3IUU39XbK5MbfsqVlTNGx/1HkxVc4R37CQyJL3yFbvQOcqp6jfpVjanHoQHVwh8cNXBOe/jhytxtLmVFwDrkFndx+o/C96CwQR14CrsZ18Ligy4W/fIfz1FCSbG895t2Jq2J5c2Ef1B4+S2bcRe5fBGo1b0pPcthz/B5NAlCgZehemRlpin4vWUP3eODK+7RSf/ZdDrLZUOYf/wydIbFhYmNn/NZNrQdD8DxVV+xup88A+MTFmzBjGjx/Pe++9x7Bhw47pnLoE+9jxc+KCupjg94dMTmF3IE5GVqgIpWjstiAIAnpJICurSKKAgEDDYjM18TSrd4fx2Az4Yxk6NnRiM+oJJjLEUzkqQgmW7wxQE88hCDK+cJoShxFfNIXLrKcmnGRrIAFAsGI3a1+9FzkZxXvxWEwN2tValxwP4nt3HJmqrRSfeR32LtqMvpJNE5jzLPH187G07ot70E2IBjPpvRuonvUociKoWTN2GYKaSxOc9zKx7z9B722KZ8itGEqakAv7qPnkGVI7v9d0Wc75G3pPQ1RVJfHDAoJfvoIcC2BtN4CiU6+oJXKWi/gILZpCfO0XIErYOp6Fo/vQYxI0k+Mh4hsWkti0mPSedaDIIOkwlDbHWNYCfUlT9MX10RWVIlmLazH+DgdVziLHAmRDleRq9pCp3k6mcisZ37bCtU2NOmJp1QdL67613DaO+PtQvZPIdzOIr58PgoD95HNx9hmJZD2w92WD+wjOf5XkpsVIzlKKz/xToQiSrtxCYM6zZKq2Ym7VG/fZf0Gyukjv3YD/oyfJBStw9LyIolMuBUHUFMUXva3NyA+9G0NJE+RkFP8HkzRtli5D8rGCDlVVCS+eRnjhW9ibdOSMvzyM3mIllpEpNhswGnRUrf+WT/5xJxZ3PVpe/hDFZeUoKpTY9PjDSURBonMTF7sCcRRVoHWZnRZeO3pJQgUSGRmPzUAslcWkl6iJpYkkc3RpXITFoMNhMWA2iOhFHY3cdYJnJyJ8Ph+lpaUMGDCAefPmHfN5xzXB/rk4VjoYQDwe595776WmpobLL7+c/v37k8vlsFqtNGnShPnz59OkSZPju+A6HDKDDZrIWZ8WnkNEz77dVlPLuuvkBk42VkXJ5DTbgswx2nftR3TVZwTmPIuxQRu8w8ceYi2R2rOO6vceAkHQVCPzG64cC1L9/sOk9/5QmNURRElLzD58AjWXpvjM67B2OAs1myb4ZX5DdTfEPegmjPVaI8dDBL54kcQPC9AV16f4zOsxN+0MaEl88IsXSe9dj97dkKJTL8fcsvdBgicKiU2LiSyeTqZqK6LFie3kc7GffE4t9dEjQVVVcoG9pHatJr1nPenKLVp3m4P+ZiUdktmBYLAg6AwIogiqiprLoKQTKKkoai5T67q6onIMZS0wNmiLqVEH9J7Gx5RgyrEgsdWfEV35MXKsBr27IY5eF2Nt279WcSG1YyWhBa+TqdqKobQ5rjP+WKCUpfb8QODz/yPr2465eXdtds1RQqZ6BzUfP02mcoum3HrWnxFNNuIbvybwyTOoqqJVqNucckS7DiAvYjMOJRWj5II7MTfvXus9KNm0ZtWxdSlFp12Fs+fwo77vn4P9d/Hgb9UeTVxMv/7wfud1+G2QyWQoLy/HZrOxY8eOYy6w1CXYxwd1McHvE9FUFl8kTTCZocxhJJ1TEVQVg15CJwokMjJN3FYCiTSrdoVw243URNN0aOgkkVYQRe0am6oi/LAvgqKobK+OYzVI6CSBYCyjeWSrCnq9xPaqKKFElip/NVXT7kOO+Cm58J5D1J+VTAr/7Ekkt3xXSwRTVVXNFWPB63nFbs0i6mCVam1k7AYks4PEliXUfPIsSjpGUd/ROHoMA1EivuZzgl++gpJJ4eg5DGevEYgGE0o6Qfjb6USWzkIQBOxdz8PR86JaY1bZYAWRb98htm4eyDLmFj2wdx6EqWnnYxJDU9JxUrvXkt69jvS+DWR821EzyYOOEBBNNkSTFUFv0pJtVdU8sDNJlHQcJRWrdU3RaEVf2hxjvdaYGrbH2KBdLSHQI0FVZJJbviO68mNSO1Yi6IzYTj4bR4+Lao2I5WIBrau9ag6CpMfZ62Ls3YcWxGRDi94muuIjJIsT15nXYWndF+Ss1hD5bgaS3Y1n8M2YGnUkF62h5qMnSO1crSmKn/NXRKPlgLZOrCY/CniOtsZclppPnyW+dh7WdgNoOOgGSotMOK0SRWYL9T0WVi34hE/+NYaies059c+PkjTYEAStQNSg2IzHZqKpx0o9lw2jTiWSypKVtSTbZdGzdHuAZFamQZEVUYRQIoNeknCYJdxWEz9UREhkslSGkjT0mDn9pHp4HUe/v3X47+Luu+9m4sSJzJ49myFDhhzzeb/bBPu1115j0aJFjB07tpbZd1VVFbfffjsPPvggTZs2/Ykr1OHXwvKdQWas2MM7y3YjKyp63YEOdjanoD/ItutgOvnZbUuZ+0MVR9E2+0nEf1iI/8PH0XsaUzriwVpVUchTw959kFzEh/ucv2PrcAagVWoDX+TVJhu2x3P+HehsxeSiNfg/fIL0rtVYWvel+Jy/HZjBnvMsciyAo9sFOE+5FFFvIrl9BYHPnyMXrMDcqjeu065G7yrXutWbFhP86g1ygT0YSpvj7Dsac4setZRFU7tWE102i+SWpQCYmnbG1v4MzC16HtNGth9KNkUusI9scB9y1I8cCyAno6iZJKqcAUUBQdCSbYMZyexAtBShc5Sgc5WjL65fyxbjaFBzGZJblxFbN69AUzM17oS9+wWYm3WtJWCW2rmK8NeTSe9Zj+TwUnTKZVjbnab5Wkf9hBa8Tnzdl0h2D64z/oilVR9QcoQXv0N48XREo6U2JeyLF4mt/kyjhJ13B3pXuWaF8tGTh9h1ACQ2f4t/9uOIRmte0KZZ7XuXjuN7dxzpPes1cbTOg476/iURujZyEYhnaFZio5nHyuJtNazdF0ZRNNE+RVELCXULr40tvtgh17n+1GbcNejY6EZ1OP54/fXXueqqq3jyySe5+eabj/m8ugT7+KAuJvj9QlVVwsks8XQOm1GHQSdRGU6iquCyGnBZDSiKwuo9Ibb64nhsRro0dlIdzWIxSnyztZolW/0EYml0kojDZMBokNhZHcdkgEA8hwC4LTqW7QwQS0NS0bq6VdPHkPXvwjPkFqxtTq29LkUmOO9loss/wNysm2bjmB/nSm5fiX/2Y5oV5KCbsLbuqylZL51FcMHrSBYH7sG3YG7SSRsZ++w5EhsXYShtjnvgDZpYWjxEcP4rxNfOQ7KX4DrtqjyLSyAX9hFa+CbxdfMRDCYcXc/H3v2CWol2LhYguuIjYqvmoCTCSM5SbO1Px9r2tFq+0Ee//wpyxE82sJdcuAo5VoOciKCkYqi5NMg57UBJh6g3IZqsiJYiJJsbndOLvrgBkt39s1hcmeodxNfNJ75uHnIsgGRzY+88CFvngYe8x8h3M4it/BhVkbF1PJuivqORbC5URSa2Zi6hr95ASUaxdRqI69TLNXGyPeuo+eRZcoE9mujp6ddqVmubFlMz51nUXBrXGddh63gWAPE1nxP4/N+13GFA08apnjmB9J71Be9rnSBg1YNOgp5NXYRXfs6UZx7E26Ijfa5/GLPZSjSZQxFFyh1GnBYDp7bwYjLpcJr1NPfasBo1irdBEtFJIumczN6g9jtf7jThi6ZIZxVtnlyW8UVTbKmMsb06RvtGTkx6Ped1KEdfJ3p2wiCVSlFaWkpJSQmbN2/+WX8Pv9sE+8orr6Rnz5785S9/ATQPzH/96188++yzXHXVVdx///3Hc6l1OAwOnrn+8Qz2/p9HPv8NOQV0Ioy7oAPjPlxXq/v9S5Dctpzq9x9GshbjHTGuQAffDzkZxT/rEVI7V9cSu4D87M1n/0IwmPEMuQ1zk075eZ8ZhBa+rW2oA2/E3KwrSjpO8MtXia2ag+QsxX3W9Zibd0fNZTRRjsXTUZUc9i5DcPYeoXlvKjLxdV8S/mYquVAlek9jHD2HYW1zai2rjVzYR2z1Z8TWfoEcqUbQGzE3646lVS9MTbseExXreENJJ0huX0FyyxISm5egZhJIVhfWdgOwnXxOrc1fVWStS//de2QqNiPZ3Dh7X4yt4zkIOn1exG0GkaUzURUZR/cLcfYegWgwk9q9lsCn/yJbs7uWKmlqzzpqPnqKXKgKR6/hFPW7FEHSad7ksx5FjgUKdh2CIBzke/omhvIWlFx4Xy2VUtA2et87Y8n6dx82GDsSujdxcVpr7yGz1Pt/510WQ60C0yktSw47AtHEbWH+7QN+4SdSh18b7du3Z/PmzYTDYUymYy9w1SXYxwd1McH/FmRFRVHVgqCTqqrsqElg1InkFAW9JGLUSVSHk7y8aAsVoSThZJZUTsFhMtCzuZtkIkMwnSORyVIdzRJOpAknMmRkiOQln5VUDN9747Wi6VnXaRZOP0J05ccEPv83+uL6lFx0f4GSrdlzTiRTsQl71/NwnXY1gk6fHxl7nFxgD/au51HU/0pEvUljU33+HEoiohXf+43W9rE96wjOfUFjbNVrjWvA1QUWXaZ6J+GvJ5PY+DWC3oTt5HNwdLugFoNNzWVJbPqG2OrPSe1cBagYSptjadUHc8uex8wwO55QVYVM1TaSm5eQ2PQ1Wf8uEETMzbpi63i21lA4SGcmG9hLZOlMYmu+AEXG2u40nH0uKTQlUttXEJz/KtnqHRjrt6X4rOu0okUyqo3frZqD5PDiPvdvmJt2QUnHCcx9kfjauRhKm+M573b07ga1vMlNjU/Gc97tSNYiIM9me+8hlEQI98AbcbTtX1CeFwFBVUktncq+L9+mqHUPev1hDJLRgipAicWAgorXbqRDvSICyRzJbI5W5U46N3RpVDUV6hVZMBsOTZJzskIik0MSRVLZHGv2hPluew2ReIa2DZyAwLAuDTEd5tw6/DZ48cUX+dOf/sQ///lP/vrXv/6sc3+3CfZHH33EnXfeyZ133slXX33FggUL6NKlCzfffDM9e/Y8ziutwy/BwUriogB9W3gY2L6cZ77YRGUk/R9dO713A753HwRRwnvxAxh/JHKiyrl8x1qza/Ccf0chac1U78Q/ayLZmj04el9MUd/RCJKOTNVW/B8+Qda/K18tvQbRaCW1ey01c/5JLrBHE+I444/onKXkojVadXrNFwhGC84ew7B3PQ/RaNES7fULiCx5l6x/F5KtGFungdhOPqeWqImqKqR3ryP+w1ckNi9GiYdAEDGWt8LY+GSNplWv9X9FSVzJpMhUbiK1a61GR9/7AygyosmuzWa3ORVT45NrbaByIkxszVyNLh6uQldUjqPnMGztz9QS62ya2PefEP72HZREGMtJp1DU/0r0RWXkYgFC81/VutkOL+6z/4y5eXeUbIrQV28SXfYBOqcX9+CbMTVsr1HCv5tB6Ks3kWxuSi64syAio2SS1Hz8DxIbF2nU8nNvQNQba72/bM0eqt4Zi5IIa3TCpl2O+d4IaDPVh7Ok24+DC0wbK6Pc9/6aQ9gadR3sEwdLliyhV69e3HjjjTz99NM/69y6BPv4oC4m+N+GqqrsrEmglwSSWRlFUalXZCaVyfLiV9vZVhOnIqgl4M29NnQ6gcbFVorMBiQRtvvjbPFF2ReOUxNKEUpCIv8dq2TT+Gc/RnLztzh6XUzRqVfUSkhFIL5zFf73J4Kq4Dn/jgKlXJWzBL98lejyD7TE7fw70BfX1/aiBa8TXT4bnase7kE3YmrQDjkVIzT/VWKrPtWYWAOuwXJSP1AV4mu/ILTwLeRYAHPz7hSdcnmBRXWw9SWApVVv7F0GY2zYodZacxE/iQ0LiW9YRKZiIwCS3YOpcSdMjTpgbNAWXVHZcU+4VVXRxtP2rCe9aw3Jnd8fiFEatMV6Uj8srfsVktn956S2r9RE2bYuBUmPrf0ZOHpeVGiGpPb8QGjhm6R3rUbnLKWo/1Xa/UMltnouoQWvHRCQ7XcposFcm1XY62KK+o5CkPQaJfyDSeQCezUh076jCjFK/IeF1HzyNKLRSumw+zCVt6pl66YqMqHPnyPy/RzKup5F4/NvIKcKlDrMGPI2WyU2E6qijTuEElmaei3YDHpaldpp5LGRyOTQiwLe/Dn7kZUVZEXFkC8u+WMp9gaTbKuOUhPLkM0pdG5STJ8WJdThxICqqpx00kns2bOHYDCIwfDzxAN/twk2wOeff87ixYtJJBJcf/31iKLIkiVL8Pl8VFdX07Fjx2MWqanD8cf+ee1MTinYehn1Ih3qO1m6I/gfXz9bs5uq6WNRkhFKLrgLc/ND493oqk8JfP4cks2N98J7MJQ2B7RkMjD3eeJrPsdQ3hrPebeid9VDzWUOzPtYiyg++y9YWvZClbNEvptJePE0VEXB2WMYjp4XaTM/1TsIffUmyS1LEM0OHN2HYu8yGNFoLVRpI8tmkdq+Qqv2tuiBrcNZGrW6lu2YQmbfJpLblpHcvpJM5WZQFUBA726AvrQZhpKm6N0N0BWVo3OW/ixa+X4o2TRyxEc2WEG2Zg9Z/w4yVdu0anT+9QylzTA16Yy5eTeM9dsc4kGe2rmK2Jq5JDZ9A3IOY8P2OLqej7llTwRRQsmkiK36lMh3ml+mqfHJFJ16BcZ6rTUGwLJZGgNAzmrd7D4jNQr+ju8JfPpPcqFKbJ0H4TrtD4gGM7mon5qPniK1cxWWVn0oHnhDQRE0G9hL9cwJZGv2UNT/Shw9hh0SeKT2rKf6vfFaQWb42F/sMy4CfQ9jTXc4Bf1UtrbGQPcmLt6pm8E+YbBw4UKmTp3KvffeS716RxcYOhh1CfbxQV1M8L+PVFamKpzCF03jNOsQBIFii55FW/1srIyxeV8Il81Am3oOVFGgZ5NigokcqAp7Agk2V0bY4Y/zQ0UIh8lAPJWmJgYKkFNkfJ89R3TVHKztBuAeeEMt5hho3tTVMx4iW72TolMuw9H74sKIU2Lzt9R8/A9UOYvrjD9poqWCQHLnKmo+eQY57MPeZRBFp16JaLSQ3vsDNZ89R9a3DWPD9rhOvxZjWQuUbIrosg+ILHkPJR3H3LIXzj6jCo2AXKSa6PLZxFZ/hpKKoStugK3jmVjbDjiUdRWtIbltGantK0jtXI2SigIgmh0YSpvnBU8boXPVR19Uhmh1HtMM98FQVQU5FiAXqiQb2Ee2ZhdZ3zYylVtR0nHt9axFmBqdjLlZF8zNuh3iLpILVxFbO4/YmrnI4SpEaxH2kwdi7zKoMMqX2rOO8NdTSe1YiWgpwtl7BPZOAxF0eo0B8MVLZCo3Y6zfhuKz/6w5vyQjBOe9RHztvFq6OAUq/1evI5kduIfcirnxydr7OUhB3Fi/DWVD78FW5CKe03RRdICQSbFv9iQSW77D23c4XYZdRyYHmVyO5l4rWQXMOj0dmxQh51Q8VhM/VEZwmCRaljlp6bViMeqpCqdwmnVYDHrKi8yYDRKprMy+UBIVMOtFzDqJQCKDJAnsqklQ7jSSU6B+kRmH+bf1Nq/DASiKwh133MHJJ5/M5Zdf/rPP/10n2AdjypQpzJs3D5fLRWlpKZWVlSxfvpxrrrmGSy+99FdeaR1+KZbvDPL03E0s2uxHRRNAG9WjEVOW7OKnJM5aeG1UhZNE0/JPHJW33Xj3QU0p+qzrDztPm963ker3H0FOhCk+68+FTRO0Cmfg03+iKjKu06/FdvI5CIJAumITNZ88Q7Z6B+ZWvSk+40/oHCXkIn6C818l8cMCRGsRRX1Ha56Nko50xSbCiyaT3LYMwWDB3nkg9i7nFUQ+soG9xFZ9SmztPJRECNHswNK6D5bW/TA1bH+I4qeSTpDet4H03g1kKjeT8W0vKH7vh2i0IlpdSGY7gtGCqDdpAYUgaIImcg41m0LJJFCSUeR4qLBB74dkK0Zf0hRjWQsM9U/CWL9NIXndD1WRSe/bQGLDIhIbFiHHg4gmW4Euvt8LXI6HiK78mOiKD1GSEYyNOlDUdzSmRh00n+x18wktfBM5Uo25RQ9cp1+L3lWv1hybzlWO+9y/F9TA4+sXEPj8OS3oOf2Phc8IOOCHLem1jkSTTod8/vEfvsL/0VOHKMz/HEiiNtaucqBQtL+TfXAhyaATuahLA6Z8t6vQvf7x8XX4/aMuwT4+qIsJ/v9AVlbYHUhgNerI5BQEVaOS7wsnkQRIywr+WAazJOFxmIilcjTxmFmzN8j073azqSKMrAoYdSpZRcUgiCQzCkkZrHqVzfOmEVr4FsZGHSm58J5D9jMlk6Lm02dJrF+AuXl33INvKTDcchE//o+eJL1rtSZ0ds7fkKxFKJkkoa/eILr8QyRbMa4z/6Tph6gKsVWfElr4FkoyirXdaRSdcjk6pxclFSOybBaRZR+gpuOYmnTG0eNCTE06IwgCSjZNYsMiYqs+Jb13PSBgatwRy0mnYGnZq1ZnGLREOOvfpYmdVmzSCuM1uw7MV4MmeGp1IVmcmruIwZwXPs0LkCoyai6DmkmgpOLIiTByPKiphuch6AzoPY21BL68FaYGbdAVNzikcJ2L+klsWkzih4X59YOp8cnYTj4HS6veCJJeE0DbupTIkhmk965HtDhx9BiGvfNgRIOJbM0eQl+9oflk29wUnXYV1ranARBf+4UmIpeOazaafUYh6PTkIj5qPn6a1M7VmFv01MTo8gl/LhYg8MEkkrvX4uwymFZDrkWv16OXBCrCKgqgxAPse3c8qaqtdB5xA016DyGUOuDjnlZUvHYDVoOeYrOe+m4r0VSWcELGadPRo0kxnRu52RNMsr06hsduxGHSYTfrKbGbqImliaVzmPQSsXQOm0EinpGxGCR8kRQ2sx6HUUex1Ygo/ra0/zr8evifSLDfeust5s2bx+DBg2nfvj0NGzbEYrHw/vvv88YbbzBjxozjsNo6/FLsT0D2z6de1bsJLy7ajqyo+0dYDoEkgHyMv5JKOoH/g0kkty2rpRR6MOREGP8Hj5Ha+T3WdgMoPvsvBZGvXKQ6/2W9ClPTrrjP/Rs6RwmqnCOydCbhr6eCIODsMwpHtwu0+ax9Gwl++TLpPevRucpx9h2tzVmLEunKLUSWvEdi49cgCFha99VoYPXbarPCco7k9hXE188nuWUJajaNaLJhatoVc7Oumt/1EbwxlXScbM0ecqEKcpFqTeAsnzQr6QRqNo0qZzWjRkEAUYeoNyIYrUgmG6K1CJ3NjeT0oi8qQ1fc4Ijz3nIiTGrnKpLblpPctgwlEQZJj7lZV6ztBmBp3r1gN5au2Ex05Uca9U3OYm7eHUevizE1aKsJwG3+VvPJrt6BobQ5RQOuxtz4ZE3g5PtPCH31Jko2rSmx9h6JqDfWFpUpb41nyC2FuW81lyU4/xWiy2djKG9NydC7almhALUsOYwN2lIy7L5awis/B2e1LeWLgwT6RAFuPbs1fx3QotYohCTAyB6NmLFiD9mcgiQKXNytIcO6NKhLrv+HUJdgHx/UxQT/f0BVVaoiKeLpHDlFJZzIkMjIZGWFZFZBFEAvQTiRJStDlyZFRJM5Pvh+D2t3h9gdjGPUad/DsixQ4jSRyORQFQGrQWJjVZKatfOo+eQZdEVleIePweCqVyvWUFWV2MqPCHzxEpLNRcn5dxZEsbTu6PsEv3oD0XBAdBO0gn3NnGfJVu/A1LQrxWf+SaOT1/K5VvM+1xejsxWjpONEV35CdNks5HgQvacR9i5DsLY9rTD+lQ3sJb7uS+I/fEUuuA8QMNY/CXOzbpiadsFQ2uyQuAa0hDkXqiQb3EcuVHmQ6GkENRVHOVj4FECUEHQGRINZUxk3O5BsxZoAqrMUnaseOqf38K8l50hXbCa1fQXJbUvJVG4BQO9pjLVtf6xtTyvMlcupGPE1c4mu+IhcqALJ4cXR40JsHc9C1JvIhasIfT2V+NovEPRGHD2G4eh+IaLBRMa3jf/X3n2Hx1VdCx/+7dOmj3qz5N4ruIBtiukdAikkJKTdfCSBXFIgPUACpNwQSIAkQBIIqQTSSIFA6B0Mxr3hJjfZ6mU0fU77/jjSWHIBA5LHlvb7PHnuxR7NrCOBzl5nr71WxxO/INuwDt+IKZSefSVGxRhc1/UamT11D7gOJad92tvg6En801uXe5NhzDTlZ11JyYxTCOpQHvHjujZNMROrvZ76+7+LnYlz2hU3UjvjeGLpLEG/TlfK9Br0qQpjyiMYqqA1nsHQBE3dOWqLAxw/oZyjR5dRGfWxoz1Fd9okkbWI+HUmVoWJ+HUSGZOm7kx+RF1NkZ+2ZJas6VAcNCgNyV3roeiIT7A3b97M5Zdfzje+8Q1OPfVUFMUrhdmwYQNf/epX+eAHP8hHP/rRwQhXehf6NoP69r/WYL2bVuL74XUKvYf40ocIjD+G8gu+us+5Zdexib38Z2Iv3Y9WWkvFe76WPxvlug6J5Y/Q+exvQKiUnPI/PTulClasmY6n7ia9aTFaSQ0lJ3+KwMQFAKS3vObNuW7dhlZaR9HCD/Y0NNMwu5q8MrDVT+Jmk+jlowjPPIPQ9JPzJVOOmSGzdRmpTa+Srl+Kk+oCQCutw183DV/tFIyaSeildW850/Ldfv+sjt1kmzaS3b2BbMM6zNZtAPnkPzhxPoFx8/LfVzsdJ7X+eRKrHifXvAWh+whNP5XovPegl43MjyeLvfxnzJZ6tJIRPXOyT0QIhfS2FXQ+fY+3UBl9FKWnX95nlmjP7PFcyhuLMv/9+Ru+2bGLtn//iFzzFu+Bysmf3KcM0LVytD/6U5LrnvVKBc/+AkLr/5qDJYCPzB/Fn17dkV+gaYrgz59d2G8Hu28HfaBfwz/p8OG67rs+uygT7MEh1wTDh+u6ZC2H7rTJjnZvtnVDVwpDUXCESypj47guOcvh6FHF7O5I87tXNrO9I0Vbt42hQlFIYURRCFUBXEFzPIXteo2r4lmbxk1raHnQG7te8d5vEhg1a58H+tnGjV7TzHgbxSd+jOj89+VLrHOt22l/5FZvbOSk4yg944p85+v40ofpevE+XCtHdO4F3jEnfxiru5XYS/eTWP0kQtW8hmY946pcyyS5/nniS//t3TONAKEpJxKaeQa+2in5Zp1m61ZSGxeT3vJaPokVvlDPmmAqRs0kjOoJ++zMDzQ7FSPXtJls40Zv13zXelwzk+8TE5hwLMFJC9HLRvb8TB2yO9aQWP0EqQ0v4Vo5fLVTicx9D8HJxyEUFbNzN92L/0ZizVMgFCJHn+M1iQ0VYye76HrxjyRWPo7iD1N80icJzzo9vw5r/+/PyWxbjm/kDMrO/VJ+LKdrW3S9eB/di/+GXjaS2ou+ga9iFGG/wKcp+DSNgKGwffkLvPHnH2IEI1x24x20GDW0d2cxFMiYFhnbojISIOzTKI8G0RRBznIRAnZ2pOlKZRlZFuQTC8exYEIFDZ1JFAHxtEVVUYCa4j2TWZJZE9NyCfq0/Nnsgbj3SANvoH4uR3yC3dbWxty5c9m+fTsAr7/+Oo8//jirVq2itraWH/7wh+j6O1tIS4Pvjmc2c8tjG95VF/E3s6dTaB0V77823ym0r8yOVbQ9dAt2upuSRZ8gcsyF+Ruq2dVEx39/Smb7Knx10yk760r0cu/mka5f6iWE7Tu981Yn/0/+LFBqw8vEXn4As3Wb96T2mAsJzzwDxRfEyWVIrn+exMrHvIYlQsE/ZjahqYsITJyfv0n2dujMbF9Jducasg3r8uefUDX0spHe/0pGoBVVoUYrUMOlXimYP7zfJ869XMf25mGnurATnVjxVqyuZqyuRsz2nZhtO71RHoAwAvhGTME/aib+UbMwaibm39vJpkhvWULyjRdI178OtoVeOZbwrDMJzzgVxRfCMbMk1z5D95J/YnU0oJWMoGjhh7xRXYpKrmUrXc/9jnT9616Dk1M+RXDScd4oi64mOp64i0z9Uoyayd4olIrRPd8fl8SqJ+h86pcI1aDs3C8RnLhvMyMr3kbrP75PrnGTt2Ba+MF3/cvzoqNH8MjqRnK2iyrguxfN5CPzR+X/fu8O+tLhI5VKoaoq7e3tRCIRIpF336FfJtiDQ64Jhp/udI6GjjRdqRxNsQzVUR85x2Vbe5KRJQH8moIrFLqTGe57bQexRJa2lElAhdJogIhfYWd7FkNxSJs5gj4/02ujrN8dpyORobtlN5vvv4FsZyM1Z3wW/ahz9rkfOJkE7f/9GakNL+EffRRl512FFvGOduWnjLz4JxTNoPjk/yF81JkIoWAnO+l87vckVz+JEohQtPBDRGafi9B0zM5GYq/8meTaZwBBaNrJRI+9KL8Tm9u9gfjK/5J64wVcM4tWXENo6iKCU05ArxiTj9FOdpLZvpLMjtVkdq7F6mjIx61GKtDLR6GXjujpy1KJGi5DDRWjBKJeafgB7n2u6+KaWZx0N3ayEzvRjhVr6VkT7MJs34Gd6Oh5tUAvH4Vv5AxvXTD6qHzlm+u65Jo2kXrjRZLrX8COtyKMIKHpJxM56uz8JkZ29wa6l/zTq+xTVCJHnUl0/sVo0XKcXIb46/8i9urfcK0ckdnnUnTCpaj+sPcw4/V/0fXifSAUSk/6JKHZ5+xZs3Xupu2hW8g1bvSa057+GcKGH9OFoAE+TVASMGh/+a8s/dfd1I6fxnduv5cZk8eyZGs7W1oSbG7poqEzi09XKAn6mF0b5aSpNSQyNmt3x9jW0U1r3CJiaIwtDzB1RDHnzKqhNZ6lM2lSGtKZVBVFVd/e2XepMAZjTQBDIMEGuOyyy0ilUqxfv57p06czffp0Zs6cyUknnUQkEiGdThMIHPyMX+nQ6d3t6zuqS1NgTPmeucEHKhs/WOkDdArty07FaP/vz0hvWuzdUM/9Ur7EuLcMqfOZe3FyGaLHXETRcZegGH6vpHnlY3S9+CecVBeBiQsoPvGj+ZtmestrdC/+O9ld6xBGkPCsM4jMPjdf2pxr20Fy7dMk1z2P3d0Ciop/5EwCE44hMG4eWsmIPnOznZ5d5U2YLVvJtW7H7GjA7m7taUbWn9B9CM3n7XQLBVzHO29lZr2nzvuhRiq8BmrlozAqx2FUT0Avq+uXrFuxZq9MfMsS0ttWgG2ihksJTjmR8IxT843jrFgL8RWPklj5GE66G6NqPNH57yc4+XjvyXXHLmIv3U9y3XMoviDRhR8kOvcChGbgWibdrz1I7JW/gKJQfOJHvfFbPXHYyS7aH/s56U2L8Y2aRfn5V+cXQH1lGtbR9s//wzEzlJ93NcFJCwGvlPDdFE1UR310pMz8LvX9n5bnqY8EsViMr371qyxfvpzjjjsOx3GYMWMG5513HnV1de/4fWWCPTjkmmD4cV2X1niWWNpEVyGVc9BUQUBTSeRsdEUQ9KlsbunmPyt2sbE5iaEJZo6I0BzP4VcVVuzswsJFxcXQVPy6Sjxr49NUOpIZsoluNv/tZro3v07RrDMpP+MKLE1HA6w+cSRWPU7nU79CqLpXFt5njKPZsYv2x35OdsdqjBGTKTvzc/l7X665ns5n7iWzfQVqtJLi4y8hNOM0hKJixVq8GdCrn8A1s/hHH0VkznkEJvQ0BM2mSG18meTaZ8jsWA2ug1ZcQ2DCsQTGH4O/blr+KBZ4lWO5pk3kmuvJtW7Fam/A7Ni1/3t87xExzQBF9fqyuDauZXqvd/btcSOMIHppLXr5SIyKMRhVEzCqJ/SrCHTMjNdVvP51Uptew463gqISGDuH0LSTCUxcgKL78qPH4sseJrtr/Z7eNPMuRAuX4lom8ZX/JfbKn3GS3nqq5KRPopd5v5szDWvpePwurxfO+GOoO+sKSqqqiWcdkpZLYuVjdD59D0JRKT378/kyfh2ojghKQgGwUqz/y4/ZvOQZTjzrQn5w688oL47iuC4bmrrZ1BxnbUMXLd0ZDE3BBU6ZWMnsceVoqkB1Hda2xFm6pQPXdRlZEWJ6TQmzRxXjIggY3jnrupIAPk2O2zrcDdaaAIZIgp3NZlm9ejXxeJy6ujqam5tpbW1l8+bNPPDAA8ycOZPzzjuPiy++eBCilt6tP726gz8v2cGa3TEcB3RN4foLpufnCAsh3nUJed9OoUUnXkrRwg/u01nTdff8gkYolJ72aUIzT9/z5DgVo/OZe0mueQo1XEbJKZ8iOHWR15wkm/Kal7z2D9xciuDk4yk67kMYlX2e1r7+L+9prWPjHz3LmxM5cQGK7vee+jZuJLXxZVKbXs0/lVajFfhHzcI/cjq+EVPRymr3jds2sbrbsOOt2IlO7FQMJ5PAzaVxrCzYFq7reF+nav3OW6nBIpRQCVqkHC1a0e/G3fs9sboaye5aT3bnWjI7VmN1NQKgFVURmLiA4KSF+c7irmWS3rKExKrHSdcvBSEITDiW6Lz35EePmO0NxBb/leTaZxCqTmTu+UTnfwA1EMk/lOh86h6srkaCk46j5LRP9ztPndzwEh2P34mTTVKy6ONEjrlo/z/L5f+h46m70YoqqXjvtRgVo1EAQ1f49vnT+cMr21jf1L/B294UAcUBnY6U2e/Pjx1TwuvbO/PnrK/uOX8tHd6+8IUvkMvluOGGG1i3bh07duxg7dq1GIbB1VdfTWnp/nsdvBWZYA8OuSYYfrKWza6uNK4Djd1pRpcFcRwwVIVowKs8UARsaOpmS0ucldu7iGctJlRG8OuCHR0pFm9rpzuRIWBolEf93lnXgE5XOktTzMRQwc7l2PrUH2l68S+E6yZRc+E30YsqSO6VY5odu2h7+MfkGjcSnHIipWdcnm+g5bouybVPew/f03HCR51J8Ykfy/99eutyup7/PbmmTWjFNRQtvJjQ9FMQqo6djpNY+V/iyx7Bjrd6oztnnkFo5un5xpt2spPUpsWkNi4ms2Ml2BZC8+GrnYpv1Az8tdMwaibm+8f0cl0XJxXz+rIk2vv0ZUl6D9itHG5PMi0UDaHpCN3vNUkNRFBDxajhMrSiShR/ZN8d/myS7G6vTDyzczXZ3W94sek+/KOPJjhpIYEJ8/eMQm3bQXL1kyTWPIWTiqGV1BCZc/6eqj4zS2LV43S/+nfseBu+kTMoXvQJ/HXeCEsr3kbXs78lue5Z1EgFpad/mtDEhfiFoLpEp2lXIzv/8zOS9Uvxj55F2blX7dODRQVEZwO7//EDMm0NfPTz3+BbX/sKQb9BUUDPn5Vevr2TVTtaeWV7HF24jKsIsmBCFWPLgyQyFqUhHw6Qyprs7kyhayqTqiKEfBptSRNDE9iOS11JMD/vXTp8DdaaAIZIgt3XmjVreOCBB2hpaaGmpobjjz+e8vJyzjzzTFpaWvLnsaTDw/52sHsTlgXjygb0nLZjZuj4789Jrnt2n06hfZldTbQ/chvZnWt6mpz9L1q0Mv/3mYb1dD75C3LNWzBGTKbklMvyNwI7Had7yT+JL/03bi5NYNw8ovPfj2/kDIQQ2IlOEqseJ7HqcaxYM8IIEJx03J6Z0j3nqs2uJm8Mx7YVZHauwUl3A97TZKNqHEbFmJ5SsDq0kmrUcNmbloS/lX5jOdobvDLx1q3kmutxMl4lgeIL4Rs5Hf/oowiMnZPvIuo6NtmGntndb7yEk4nnFwvho8/Kf++yuzfQ/dqDpDa8jNAMwkefTdH8D6CGvZ1f76n/r8lsX4lWWkfp6Z8lMHZ2PkY72UnHE7/0mpxVjafsvKvy3cr7/Zxzadof+3m+I2z5+V9G8Ye5fNE4IgGdkqDBmt0x/ra0gZx14N71Am8h5+KdsR5fEaYzleOio2s5Y3r1Pues5Q724e/rX/868+bNyydW2WyW9evXc8cdd2BZFnfccQfB4NufMS8T7MEh1wTDT3c6R2fSxNAUNrfEGVESQBUKQZ9KZcQbQ+m6Lm80d9Mez9GVzNGVzlIe8dOdNGmMJ1i5rYvtXSniaYtRxX4iQYOwX6ehPU3ONGmJZ4llwK9BatPLbP/XbQhNZ/IHv0qu8miye8XkOjbdi/9G10v3o/hDlJ5+OcEpJ+QTTyeToOvF+4gv+w/CCFC08ENE557vVWO5LunNrxF76U/kmregRiqIHnOR19zLF/S6am9+jfjK/5LZuhxcB1/tVIJTFxGafEL+/ujkMmR2rCKzbTmZHavzPVG8sZ0j0SvHYlSM9sZ2loxAK6p+R2M7+3KyKaxY76iunZit28m11Pc0XQOEglE51tsEGDvbm37S85De6m7xysTXPUeueYu3oz3hWCJHnY1/7GyvpD4dJ7HiUbpf/zdOqgtf3TSKjv+ItxYSAieXpvu1B+l+7UFwHEqOfR/FCy7GF/ATUCFjusTXPM6uJ36Na9tUnvxJInPOw9zPWLK+E0ZO/vT1XPLeC7h04RiCPo2MaZPKWXQks+zuTLNqV4x0zsIQguMnVhD065SEDJp7GpWVhgwyOYe60j271K7rEkubZEyHooBOwJC710eCwVoTwBBLsF988UU+/OEPc9lll/HJT36S0aNH5//u3HPP5cYbb2TePLkGOpz07bgMBx5h9KdXdwxIku26LvFlD9P59K/36RTa/3UO8WX/oeu534EQFJ/4MSJzzus32iK55im6XvgjdqKDwKSFFJ/4MYxy7xyunUkQX/Yw8aUP4aRiGFXjicx9D6GpJ/bcdHuaf6x9mtSGl3FzKRR/mMD4YwhMmE9gzNEo+bPYLlZHA9ldb5Bt2kyueTNm2w7cXHpPwIqKGipFDXtnrRR/GMUI9CkR9zJF1za9p9dmBieTwE71nrfqAGfPaA+h+dArRmFUjseonoCvdgp6+aj8TrGTS5PZvor05ldJbX4NJ9WF0H0EJiwgPP0U7waqqLi2RWrjK8SXPuSVyftCRGafS3TehfmRI2ZXE7EX7yO59lkUf5ii4z/snVvredjgleg/Seczv8YxMxQf92Gvydl+mrzlWrfR+s8fYnXupviES/vNNP37Fd7M6b0f6PQ1oSLEjs40tu1VTtiOmx8nt/cutTxnfeR56qmn+J//+R8++9nP8vGPf5yRI0fm/+7000/ntttuY8aMGW/7fWWCPTjkmmD4yZg2u2PevS2ZMdFVlaKgTkXYh6buaQ6VzFo0xtJsboljmi5CgZ0dKXyaYNm2TnZ1polnc5SHfUysihAwFFoSOZIpk2UNnZg50DUwLagWHSz5zXeINW6j4vgPEVj44f0+sM61bqP9kdvJNW0iMOFYSs+4vN/D91zbDu8hcf1SrzT8hEvzvUZc1yVTv5TY4r+SbViL8IV6joydl9+xtrrbSK57huTaZzHbtgMCX91UAhMWEBg/D71s5J6KunSc3O43esZzbfHGdna39otX8UdQwyUowWJvdKcRRDF6Rnf2PlhyHFwrh2NmcLMp7HS3t/ud6MDt7fnSQyuuRq8ci69qAkbNJHwjJudLxV3XxWzZSnrLElKbXyXXuBEAo3oCoWmnEJp2Uv6en2vbQXzZwyTXPOWVyY+dQ9GCi/MbEYZtklzzOM0v3I+d7CI8+QQqTv4koZJqVB3SOch27KL98TtIbF9FeMxMRp/7BdSyGuI5+o18FZZJe58JIyMv+jpTJo7igjmjOXtaNVXFARo6Uri47GxP0ZHKsqM1SdCvURzUOWdGDaYDiazXZM91XUI+jWTOprY4gF+XifSRbLDWBDDEEuxvf/vbjBkzhk996lMA5HI5NmzYwN13301LSwu/+tWviEbf2VgeaXD07bisKIJpNVE+dMyo/TaMKgka/HnJDlY2xN715/brFLroY0SPfd8+ZcbgnTduf+xOMluXYlRPoPTMz+GrmZT/eyeXoXvJP+h+7UFcM0to2kkUHXdJ/ox1b4Ov+Ov/xmzfgeKPEJp5GuFZZ+aTcdfKkd66jNSGl0hvWeLtGAsF34jJ+EcfhW/kDO9G1qcMzHVd7HgrZoc3isPqbsGOt2MnO3HS3TiZBE4ug2tlcW2rZ0wXCFX3SsR1/56RHH3KwbSiKrTSWm8sR5/vh2NmyTVuJLNzDZkdq8g2rAfHQhgBAuPmEZx8vNdRvOeJudnZSGL1EyRXPYGd7EQrriYy54L8U3vwnnDHXv4LidVPIBSVyNz3ULTgA/kHC+AtaDqe+AXZnWvw1U2j7KzP55vM9eWV9/+XzqfuRviCVFzwNfyjZ/V7zbFjSphQFeH+Pt2/93bsmBIuml3Ho2samV4T5d6XtmLaLroquP8zC2UiPQQsX76cv/71r7S1tVFbW8tJJ52EqqpcfPHFNDU1vaP3lAn24JBrguEpa9ne7OCMhaYqlAQNSnpGGdmOS2MsTa5n7KEQsLExRsaGIp9KayLLsu1d7OpKks1564rxFSFGl4cI6IJHVzZT39pJexLMnhtB1A+lhsumf/+cHa8+hr9uOmUXfGWfMmOgf5MtBEXHf4TovPf0e+Cb3raCrud+S65pM1ppHcXHf9jb8e5J2vMNvja+DI6Df+xsb1b0hGPzEzByrdtJbXiJ1KZXMFu2Al6PFP+Yo/CPmomvbjpaUVW/8m0nm8LsaMDqbMSKNWP1jOhyUt19RndmvAftPWO6hKIiVK1/iXiwyGuYGi1HK6pGL6lBK6nttyPuui5W524yO9eS3bmazLYV3vxswKiZSHDiQq9BW09jWcfMktr4sjfje+caUHVC004iOu9CjMqx3nvaFok1T5Nc/GeyXc2UjJ1J5SmfQK2egmV7/XmwsrS8/DfaF/8VofmoPPl/KJ59JopQ0BVIW5AFDMDu2EXzwz8i2ehNGKk8+ZMYus7sUWFOm1pHXXmAMWVhmrszRPwa9S0JdrQnWdcYJ6ApVJX4+cDckYwqCxNLm2gCOlJZ0jmHyqiPqmhAdgEfAgZjTQBDLMG+/vrrWbduHT/72c9Yu3Ytu3btYs2aNQghuOqqq6iqqhrAaKWBsnR7Jw8ua+Cvr+/EclyMPqON/r6sgb8tbcCyHQzNOzt73T9XH/RM7DfzZp1C+3Jdl9QbL9D59D3YiU7Cs86geNHH809jwTuf3f3q34kv+w+ubRKcciJFCy/OlzC7rktm+0oSKx4ltWkxODZGzSRC008hNOWE/Jgu17HJ7lpPuqc8PNe02WtgJhSv8Vj1xJ4S8dHoZSNRgsUD/gveTndjtu3AbNvhNU5p2kSudVtPAxSBXjmWwJij8Y+dg3/k9PyCwE7HSW14ieS6Z70bqFAIjJtLePa5BMbO6TdWq/vVv5NY8zQA4aPOomjhB9EiZf1+Nl0v3U986UMovhDFJ3+S8Kwz9vsQxE53ew3qNr6Cf8xsys+/Ov/93JuqeLvSByLwegBYtoOmCFzAsl3ZyOwIl0qlWLp0KblcjtraWlpaWmhra2Pjxo386U9/Yv78+Zx//vlceOGF7+j9ZYI9OOSaYPja3p7EpykIIG06jCoN0pU26c7kyOQcSkIGyYxFRcRHVypHY3eGiKFS35akPZFmeUMM03SI+nVqigNkLJtExqS+OUkqk2VbZ4qkN2GKgAGjygK4tmDFcw/T8Mhd0Nsoa/Lx/eLqbbpqxZrpeOIXpLcsQS8bRcnpnyEw5uj861zXJbXxZWIv3ofZtqNnbOfFhKaelE/GrXgbiZWPkVj1BHa8DSUQJTR1EaFpJ2OMmJy/t1vdLaTrl5LZupzMjlX5Y1tqqASjZiJG1XiMirHo5aPQiqsHfISna5mYnbu9NUHrNnJNm8k1bcofXVOCRfhHzSIwdg7+cXPRwqU93wOH7M41JNc+S3LDS7jZJFpxNeGjziY864z8WXXHzJJc8xSxV/+OHWsmWjeREy65krIJs2lL5EhbFp0Jk+4Ni9n1xD3kupopnn4SladdhhIooTSiYFoOlgu2BV2mS2r1E3Q8+StUXWfOpV+jfOrx6BoowIJx5VSXRggbKnWlATKWS3cqR3syh2O7bG6JUxLSKC8KMKkyQmXUj64qtCdyCFwifh2frlJbLBPsI9VgrwlgiCXYuVyOL37xiyxdupT5871xPVOnTuW8887rVxomHX76loqrAiZVRXijKd5vp7G3THcgZ2fvaWx2N0I1KD3rf/OdJ/fmZFN0vfQn4ksfQuj+nnNWF/Sbp2wnu+h+7UHiKx7FzaXxj5tL9Jj35s8U9b4mufZpEmue9s5RCQXfyBkEJy0kOHF+v5IzJ5vyGoztWk929wZyzVvyNzXwzmRrxdVo0QrUSLm3Gx2IovhDCCOAUI09JeJ4T4hdK4uby+Bkk145WLILK9GO3d2K1dWUv3mDd+7aqB7vlYPVTsVXO63fuXU70Ulq86ukNr5CZvsKcGy00jrCM04lNP1UtGh5/vuc3bWe+JJ/ktr4Cqhaz1iOD/TbJch3ZX/hj3uaxiz6eP5GvLd0/VLaH70dO9VN8aKPEz1234Znb1fvAqr3tnmgEnHpyPHe976XSCTC+vXrmTp1KtOnT2fWrFmceuqp+Hy+d91VWibYg0OuCYavlniGRMY7tqQrCvVtcXZ3pakrCuAqgjFlIUzbpbYkgE9T6E6brGroYkVDJ+UhH7bt4tMVIn6dWDqHabmsaujCsh06ExnaUyYt3WkSpoMG2DYEdMhY0N2yi8aHbyHXuInQjNOpOv0z2L4gKl6zrFyfOFObX6XzqbuxupoITJhP6SmfQuupYAMvyUy98RKxV/7sje2MVBCddwHhWWfuOQbm2GS2LiOx+ilSm1/1JnNEKwlOPo7gxAX5JqK972e2biPbsJ7s7jfINm7C6thFftaKUHqq0SpRIxWo4RLUQBFKIIJiBL3pIvsrEbeyuNkkdjqBk455PVnibVixlv6TSoSCXjYSo2YivhFT8NVN61e67tommZ1rSW96hdTGV7ATHQjdT3DSQkIzT8c/aiZCKERVSMRjdCx/hPiy/+CkujBqJlN0/CVUT5jHhGo/kYBBU3eO3Vs3sevxu+ncvBx/xShGnXs50XFHoQuXnA1Bn4oCZE2bbCrGln/dQfcbLxMeM4ujPvxVZk4eT0U0SF1pkICmsrMjSXHQIBrUMTQVgSBn24woDtKdMXm1vp2oX6OuJMCsuhJCfp2wT2NbW4KgoVEZ9ZPMWIwqC+aPLUhHlsFeE8AQS7AB0uk0iqLQ0dFBMBikqGj/C3Pp8NK3VBwB9l69p/Y+m/3DR9bzi+fr+71GEV4l9Dv5N9frFOrdUINTT/I6he6nARqA2b6Tzqd/Tbr+ddSiKopP/CihaSf1S+zsdJz48v8QX/owTqoLvXwUkTnnE5p2cr/xFrnWbaTWv0Byw0v5zuF6xRjvSfCYo/HVTtunLMtOdGC2bcds34nVudubX93d6pWCZd68K/b+KL6QNyszWo5WVIVeMsIb1VUxGjVS0e8JrWvlyO56g/T2FWS2LvN22PE6igcnH09w6iKMqvF7mr+YGVLrXyC+/D/kmjaj+MOEZ59LdO4F/XaZXdclvWkxXc//fs9c8VMvw1e9/4TWyabofPZeEiv+i142ivILvpKfr/luqAqoioJte+WHCIFty0ZmR7KNGzdy/vnns3Gjdybw1Vdf5bHHHmPNmjVMnDiR73//++/6M2SCPTjkmmD4chyXZM7CdV12diR5aXM7uurtVE6vjTK2IkLErxHxew+40zmLZ99oYWtbglTWJmionDS5kpBPozOVI5bOsqkpQUcyR3Msha4r7GpPsnpXnHjPkAg/0NtEXHEtul65n6YX/4oRLaP23C/hH3UUVs9r+jYbd60cySX/pGPxX725zUefTdFxl+x7j6t/ne5X/0525xqE7ic041Qis8/t16zTySZJbVxM6o0XSG9fAbaF4g/jHzObwNjZ+EcfhVbUv/LCMTOYbTsx23dgduzySsTzx8a6+vVXOSiqhhoqRYuUoxZVoBePQCsdgVE+ykum+0wbyZeKb19JettyMttW4ObSCM2Hf9wcQpNPIDBhfr91jGjciL3+MRqWP41rmfjHzaXo2PfjG+VNGlGAYg3CdNH0zH1seuW/qP4wo0+7lGknX4iJinAVHOEdJyjxKeQchda1L7Lqr7eRS8UpO/FjlB37XkpDCvPGVXHmjGrKwgYrdnbSmbSJZ01GlviZUl1ELGPS2p1hxsgiNEVhe1uSyogfRVGYVRslbTuksjaZnI2iCDRVEPZpVBfJUX9HokOxJoAhmGDvzXVdWcJxhOg9a/2HV7bR1N2/j+eZ06ooj/h4/xxvLt2X/7KCbe2pfq85uq6IyqifJ9c3v6MZx65tEVv8V2IvP4ASiFB25ucITjrugK9Pb11O57O/wWypR68YQ/GJH/VmWfZLSE2S658jvvQhcs1bvJvq1EWEZp7uPZXu81qzfSepza+Rrl9KtmGdd1NUVIyq8d7ucc0kjOoJXgnYATqGu7aFk/bGcTi5tLdb7djekwfwzlppBooeQPhDXuMTVd//e7kOVqzFKwdr3NjTZG2jV4MlFHwjphAYN5fAhGPRK8b0mdftkmveQnL1EyTXPouTTaKXjSQy9wJC00/d54FBZtsKul74I7nGDWildZSc9AkCExcc8L/b9LYVtD/6U+zuVqLHXETxoo/tM17snRDAh+eP4v1z6vLNywDZyOwIt2bNGr7whS9www03cMIJezr/rl69mq985StceeWVXHDBBe/qM2SCPTjkmkACWL69kzcaYwigLWVy6uQKRpeHaEvk0BRBScggkTZ5bmML3WmTzc0JIj6VY8eXU10UQFUUupJZGroSZE2XjGlh2oKWWIon1zUQz0Da9HamNSBseA9bo0GdxM43WPvAj0i17aJuwflMueAy4pZGW2rPzOxedrKT+It/IrbyMYSmU3HMe/DNex/KXg/rs02biS99iOT658E28Y2YQmjWmYSmnNB/vnQ2RXrrMtJblpDZuix/xlmNVOCrm4pvxGTv2Fjl2H1GdfVyXdfbmc4kcHMpnFwWbBPX7T2DrXi9WXQ/ii/ojeUyDlz67GST5JrryTZuIte4gWzDun5xBcbN8eZ1jzkaRd9zr7fTcVLrnyOx6gmvm7rhp27eGUTnnks8OLLfxoid6CT92l/oWPYoQghGH38RVYsuxlSCRIM+gppLNOQnljSJp010K86Gh37BrqVPUVw3gSkf+DJOyUgcF0J+nSk1UU6ZXM3Y8hCvbe+gOhqgM5lFURRqS4N0xNPkLDd/3l9TFWqKAwQNldriIH5dwbRdNEVgOg6ui3d8Qf4eOSIdijUBDOEEe9myZSxevJjLL79cjuE4wuy9O71oYjmvbesgZ3nnYhEC09p/F2hNFdi2+452sXvlmutpe+Q2zJZ6bw7zGZfnzxTtzXUdUutfoOvFP2J1NmJUjafouA/1JIhKn9d5c67jK/5L6o0XcM0MWnENoWknEZxyYk+H7j7NSnIZsg1ryexcQ7ZhHbmmTbiWV5gmdJ83lqO0rqchWRVatNzbhQ4VI4zgQf3id10XN5fGTnV55WDdrV5jlM7GnlFdfTqVqxpG1Xj8ddO9UV0jZ6D4Qv3ez+zYlR/LYbbvAFUnOOk4Ikefne8O2vezM9uWE3v5AbIN61AjFRQdfwnhmacf8OGBne725pCvfhKttJayc76Iv27aW17nwTLkOesh69e//jVLlizhrLPOYsaMGdTV1REIBPjpT3/KmjVr+NWvfvWu3l8m2INjINcEqVSKv/zlL8ycOZO5c+cOyHtKh0Y8bfL8xhbiGZOaYj/zxpTTEs/g11Sylk1HMkdJyGBDYzcNHSkaOhOMrYgS9qnUlgapLQ6xsSWG60DGdAjogtZ4lvqWJC9saCSRs2nptvFpUBbWiGUswj6NqF8l5DMwXJNX//5LNj37dwJFZdSdczn2qPnsNTIbFQiqYHfvovGZ++ha+wLC8BOZe4E3OWOvo052KkZyzVPEVz6O1dGA0HwEJhxLaNpJXt+SvXaKzbbtZHasJrtzLdld67ET7fm/14qr0ctGopXWohdXo0YrvR3oUAlKIHLQ57Jd29wzXSTejtXd4q0JOnZhtu/o16lcLarCVzsF/8gZ+EfNQisZsc86Jr1lCcn1z5Gufx1sC6NyHKVHn8m0E8+mvLSYTc0JklmvMZkdbye55EG6lv8X1zapnHM65SdcQnF5NdGAQXs8Q3VJCMu2SGQtNKGxdclj7HrsbuxMkmMu/BQzz/k4ccuhvTtDMmdRHPQxuizExOoIc8eWsK0lSUvcJOJXmT+2FEWBp9a1IhSXrOkyY2SU8eVR2pI5plRHKQ8bMpEeggZ7TQBDOMH+1a9+xWc/+1nuuece/t//+38D8p7SofPDR9bz37VNnD29mkhAz5/P7nsudjC5tkX3aw/S9dL9CFWn5KSPEz76nAPvHDs2yTVPE3vlL1hdjehlo4jOf59XOr7XDrGTS3vNwNY+Q2bHanAdtNI6ghMXEJhwLL4Rk/f5HNe2epqObSbXut1rQtbegB1vY5/vhqKi+EII3Y/Q+pzBdl1wbBwrh5tL42RT+y0dU8OlaKV1XjlYxZieBipj+p01B+/hQq5pM+nNS7xOpz1zOX110whNP4XglBNR+3QE7/0+pTa8TPerf+uZCVpO0YIPEJ511j7vv+dzXJJrn6bz6V/jZBJE57+f4uM/PCC71r16d69/8N6ZA/ae0uHlt7/9LS+++CLl5eWUlJTQ1tbG888/zw9+8ANOO+20d/XeMsEeHAO5Jkin05SVlTFhwgRWrVo1IO8pHTqWZZMxHXyGiioE2zuS+DWVVM6iMZahriRIOmvywpYWWmNZDF0hZOgcM7aUsE9nV1eKrGWxamcX6axNSdhHsV9nR1eSho40HYkMhqqiaQqNXVmm1oTImA4dyRyuELTHM+QaN7LmLz8m0bSNyJTjiJ7ymXyfEQMI+aAooAIKrmOzY2s9LS88QGrDSwjdIDzrTKLHXLRPibfruuR2b/DGdr7xIk66Oz+hIzDhWALj5qIG9u12b8XbvTVBc33PmmAnVmcjrrX3JG8QRsAb3anvPabL9rqKm1mv6q3v+M/er9X9aCUj0MtHYpSPxqgch1E9oV+j1152spP0ltdJbX6VzNbluFYWNVxKcMqJhGeclj/KFRQQ9EEiA1ZHA62vPkh87dPgOtTMPo2qRR/CKBlBNucQ8kPEH8B2XYr8OroCW7fVs+GfP6Nz0zJKRk/lkqu/y9QpU+nOWuxsSxE3bdrjWabWhGlJ5AgbGmURH3NGlhD2G+QclxHFfjoSGZZu7yRk6OiaIKCpnDipEp8hm5gNdYO5JoAhnGBns1mKi4sZPXo069evl/+RHMH6ns9WFYED2D1txPf+t1T0dKgaqH97zY5ddDx+F5ntKzCqxlN6xuX4aqce8PWuY5N64wVii//mNTQJlRCefS6Ro8/eb2drO9FJauPLpDa+TGbnGnBs77zVqFneGeyRM/o1ENnn86yct/PcewY71YXdUyLu5tK4lonr9I7pEt5IDs3YayRHsZdU95y36lvW1e+z8mM51pDZvorM9hU4qZhXLl47heCk4whOPn6/o03sTILEyseJL3sYu7sFrbSW6LHvIzz91AMm1gC5lno6nvgl2Ya1GCMmU3bWlfmRHgNFEeQ718vd66Hl3//+N88//zyapnHFFVfgui6vvfYaHR0dtLS0MGfOHM4///x3/TkywR4cA10i/vnPf56f//znPPvss5x00kkD9r7SoZfMml6HadOirTuL6YImIGVaaAg2NicI+QVTa4qJBgw2N3fz6tYOOrozJC2L2uIgI4oDVEb8VBUFwHVZvSvGltY4jV1pbNulKGjgODbhgI8d7XFSOZPWrhQ7n3+QtpceAKFQctwlFM27kCJDJxgRGKgoikMs5xIQ3nzuHfU7aXv1byTXPQuuS3DSQiJzL8BXN32fe7trW2S2r/TWBZtexUl1gVAwqifiH3O0N6JrxJR+R636fb3rervP3a350Z12KtYzurPn2JhtQp8xXag6iu7zknB/2BvTFSpBjZShRStRgkVvWjKebVhPZscqMttXkmveAoAaKSc4cQHBycd517n3poHrkNm2guTSh0huWYLQDKrnnsF5l1yGEq1gyY4uLBv8BlRF/Jw2tZqupMX21g42PH4fr/779yiawbEfuIITzv8gE6uLyORsshY4wnv/zc1xqqI+Ftd3MKEyRNKE6ojB5OoixlYEURWFRMYinrFoT2YJ6Coz64qZXB0h5NO9PizSkHKo1gQwhBNsgC9/+cv85Cc/4bHHHuPMM88csPeVDr2+s7Cvf2htfm72gUYuqYIDjvI6dkwJxUGD5za2krWc/b+oD9d1Sa1/ns5nfo2d6CA0/RSKF30i/+T6QF+T2bqM7tf/RWbrMlA0gpOPJ3LUWflGHntzMgnSW5eT3rrUmykZbwNA8Uf2nLWqGodeOXafOdUDrXfOdq5lG7nmLeSaNpPdvcG72QNKqJjA6KPxj5tLYOyc/Xb57n0qH1/5X1LrX8C1svhGziA670ICE449YDUAeKVzXS/8kcTKx1D8YYpP+sQBx3S9GxcdPYKJVRF5xnoIuuOOO/jTn/7EVVddxaOPPsqLL77IMcccw1e+8hWOPvroAf0smWAPjoFeE+zYsYPRo0dz3nnn8fDDDw/Y+0qFs6szheO4OEA6Z+O4Dm80xgEXXVWIpy2S2SzLd3bREMsyoSRAW8ZkRDTIuPIQI8uDLBxXTixjsXZXF9vbkrgCGrvSTKoKkzUdWuMZ6ltSdOdM4okMCdMh0baL+v/8ktiGVwmU1TLx3P9HZMpxZEwbTQG/IUhlHRShYjk23WlIdbcRX/YQiRX/9XqTVIwhcvTZXvPTvaq9oKdKrHET6frXSW9dRq5xU35sp1E5FqOnL4tROQ69fOQBH44PFCeb8irpWuq9NUHjRszW7YALioavdgqBsXMIjJuLXjluv+ucfFn8iv9ide5GCRYTmX0OIxecR3VNFaPLQvhVhY0tCRLpHD5dJeTXOXlSBbuXP8Pvbv8/OloaWXDGe7j4c19j6tgxtCez+HSFjOmyuyuFZTuEfN6OdDpnsX5XDBQF13WZWBGhOGwws7aYlGmjKlAR8dMSSzOyLMzEyrA8VjpEHco1AQzxBLuxsZERI0Zw5pln8thjjw3Y+0qHRm9S3Tf56TvOSwDHjCmhJZ7dp+HZmzlzWhXPb2olY751ct2Xk0sTe+UvdC/5pzdm4piLiM5/f7+mJPtjtjcQX/4fEmuezs+BDE0/ldD0U9BLavb7Na7rYnU1kt25xmsutvsNzPaG/KgMofnQSkegF9egFVejRsq8M9jBnnEcvjCK4ZWIo2oIoXhNTWwb187h9JSIO+lu7JQ3ksPuGclhdXnnrfaUigm00lp8I7xRXf66GWhldQd8om11t5Jc9xzJNU9htu/0GrtNO4nInPMwKt+807djZokvfYjYK3/BNTNE5pxH0fEfOWBH94Ml8HaqFUVg2S5CwGdOHMc3zj1wNYJ0ZDvvvPO4/PLL881KMpkMt99+O7/61a+46qqruPLKKwfss2SCPTgGo8nZRRddxL/+9S+2bNnCuHHvfvKAdGhkTJuW7iwuLlURP37De0DblcrRlsiStWx0RTCuIsLOjhSrd3XyWn0H4LK9I0lpQKctaaEJl4ChMbO2iLryIJWRILqqYKgKtmOzsiGGcCGesZheV8SIogAt3Wk27I6xuS3Jjo4ksVSWrlQOv6pj71rJsr/+lHjzDsJjZjH2nE8RqJmAT4GIXydrQXfGJGdCV8+JLCeXIbnuWRIrHvWan2oGgYkLCU8/Bf+Yow94Xrp3bGemYR253evJNm7Gze1Z+6jRSvTevixFlfm+LEqwCMUf9krENZ9XNdb7sNp1vEo3M4OTS3k73T1nsK14mze+M9aE2bE7/9AfvAf/fUd1+WqnHLj6zTJJ179OYu3TpDcvAcfCVzvVq+6bfAKGplNVpFNd5KfY78PGoTGWQcWlOGKQa1jPy/f/jIaNq5k0fRZf/vb3OOnERSRyFrjQ3J1BUwTJrE1jLM2o0hBlYZ1dHRnq2xP4VUFzPEtJ2ODoumJcIRhRHCDi0ykLGeRsF7+uEjAO/NBfOvIdyjUBDPEEG+Diiy/mb3/7G+vXr2fKlCkD+t7S4Fm6vZMP3+2VheuawvUXTKczlfN2sP+9hlzP9rShKXzquDH7jOx6M71zjt/u3/WyYs10PvtbUm+8gBKIUrTwQ0Rmn/OWZ4IdM+udvV7zJJntqwEXo2YiwcknEJy0EL1kxJt/fS6D2bqNXOs2zPadmB0NWF1NWLEWsM23iPogqBpatAKtqNob01U2Er1iNEbF2Ld8iGB1t3ijRTa86HVAB3y1UwnNOI3Q1EVv+fWubZFY/SSxl+7HTrQTGH8MJSd/Cr185D6v3d/PSFX2He3WS1cFF88bme9AL7uCDw+/+MUvWLt2Lddddx2VlXtmy2/atInvfve73HzzzVRVVb3JOxw8mWAPjsFYEyxZsoRjjz2Wz33uc9xxxx0D+t7S4HBdl+3tKZJZk46UiSagutiP6wpKQwam5dDUnSGgq/h0FV2F+pYkz21sxnVdVjfEKAnpxDI2QVWwYEI5k6oi3vEmXWFTSzc+VaO2JEhrLI3jutQU+4kGDDpSObJZi63taZJZk23tSUzLJuTXUISgpTtLNptl7VMPsvqR32KluimadiJjTr+UihHj6EjnSGVg75PNPsABEk2bSax6gtT653AyCZRA1Jt/Pel4/KNmUGToKAKyFmTof+9zXQerq4lcy9Z8Xxaryxvd6aS7B+R7rwSL0Yp7x3eORC8fhVE5FjVa+aZHL10rR3rbClIbXia16RXcbBIlVExo2smEZ56BUeHNoTfwurH7VaiIqpSH/PgMjaCmkm7cypJ//IJNS18gWlbFR674Mief936SOa+CsbLYT2nARypjkTAtqiI+UpbN5KooG5riWI4Drsvm5jiVUT9V0SBlYZ26kiCRwMD1cJGODIdyTQDDIMFesWIFs2fP5rLLLuPuu+8e0PeWBs81/1jNfa/uyP+zKrwbi6EpLJpYwRPrmnF7/vy4CeW8sKntgO81WLKNG+l67ndktq9EDZdRtPBiwrPOPKjmW1Z3K8n1z5N644X8LGm9bBSBCcfgHzsHf+20Nz2b3Jfrut5OdKLDO2uV7sbJpnDNDK6Vw7UtvO+e8MZ0qTrC8KMYQZRA1DtvFS7pOWd1cKVRrm2RbdxApn4Z6frX8+eu9PLRBKeeSGjqSQfcnd/7fZJrnyH2yp+xupowRkym5KRP4h/11s3GSkMGHcncm77m2DElfP2cqTKZHoaam5u57rrryGazfOITn+D444/HcRzi8Tjz5s1jw4YNBAIDM8dUJtiDYzDWBACzZ89mxYoVpNNp/P7BLauV3j3XddnSkqA1kcWnKXSlTCIBjZElQVI5G0NTcF1vfZDMWgQNlc5kjsX17XRncrR1p2mLm7iKi+IK5o8voyTgI5bK0ZLI0p02KQqqlIX8zB5TggqsaIjhONASS6NoguKAD9txMBRB1nZojmWojPrY1ZUlbCi0JbK8sbORNY/ez7YX/o5j5iieeQrF8z9ItLyWWM8DYAUwFCgJKsTSDinb6z5u9ezyJtc/T3rLElwzg+ILEZ0wh6IJ8wiMPhozUIaCl5hbeF8H7NPNHLzZ2HaiAzvZ5a0Jes9gmz1nsHuq4RAKQtMRms9bE/jDqMEoSqgELVz6tpqJWrEWb6xY/evePGwzg/CFCE6cT2jqSd7ufJ+jYRpgCO96dBUCPoj6DErMZtb+53esf+VJAuEoJ73/U8w654PUlZfQljCZOiJKSFPY0Z7muEnlJDMWXekcM+tKyFk2qqLQ2JUibTmkMjnW7opRXRyiKKhz4oRyKqJyfvVwdCjXBDAMEmyA+fPn89prr5FIJAiFQm/9BVLBfesfq/lTnwS7lyrgkmNH8fdlDfnd7W+fP73frvahltm+iq4X/kh21zrUcCnRY95L+Kiz3nLXtpcVaya1cTHpLa+S2enNvxaaD1/tFHwjZ/TMwJ64z1isQ8nJpb0zV7vWk9m5luyudV4JeU+Ds8CEYwlOXIheWntw72dmSa55itirf8eONXvjzU64lMD4Yw66IaEqQFUVbNt7mm3u5+dfHfVxx6VzZYI9TKVSKe644w4eeOAB6urqGDNmDGvXruXYY4/lBz/4wYB9jkywB8dgrQn+/Oc/c8kll3DTTTfxta99bcDfXxp4ibTJyoYYmgoCgaoKaooCZHI2pWGdjqSJ60LAUKiM+OlI5ehIZGjpzrKzPcmG5jilQYNdsTSVUR+KEHQmTXyqS2Msi+WAEC4nTqpkZEmY5Ts6qC0OkDQtdranmVgVpjWZpa4oQMin09CZoiJisHJnFwEFUq7Lxt1x0jmLpqZmNj51P22vP4JrWxRNO5GSBR8gWDnWa6qpg1/X0FSXti4bVQccr29M0sFLgnesoHvjYlL1r2MmvDnTWmkd/lEz8NVNxzdiMr7iGtSe++UA1LC9La7rYHXsIrvrDTIN68juXIPV1Qh487CLJswjOGEByuhZ+0xSAW8Hv7bEwLYdujIWAUPFatrM9mfvp23ty+j+EOd+6JNc+YUv4vhC7O5MUxTS2dSUQNcURhYH0DWFiN+gPZFlXHmY4rBO1K/j0xRwvdL8lza1sK0tRU2x14V8anWUSTVFFAVkE7Ph6FCtCWCYJNj/+te/uOiii/jud7/LtddeO+DvLw28pds7+fCvXsG0XVRVoAC246L3dHuG/qW+S7d38ovntvDU+mYO0PfsHVF6plu91Vu6rktm+0pir/yF7I5VKL4Q4dnnEJl9/ps2Q9ubk02R2bmazLaVZHauxmzZRu+na6W13niMijH5eZdacTWK7nvH17fPdVg5rFhzz8zLnV45evNWzI4957/1slH4Rs3EP3oWgdFH7bc5y4HYyS7iKx4lvuw/OKkujJrJFB33QQLjj0UIwejSINs79n+e/ui6IlY2xPI/CwW4ZP6ofMn48u2drG+K7/N1fl12CB9uXNfd50HNo48+Si6XY8qUKYwcOZJg8OAegB0MmWAPjsFaE1iWRW1tLaqq0tDQIJsaHSEypk1nKocCWA5YjkNpyOg562xjOy5+TUXpSZwcxzt/3RbP8t+1u7Bt6Eqb1BX7cV1IZiwSGZPtnSnKwwaJjE3IpzO6LMSsumI601l2d6XIWQIhXCqjPiZXFVEZ8dOWyLCxKcYbTXFCPh2/CpGAj12xJI2dWdbvbifW1Unzi/+gc/kjXk+R8XMZtei91E6dh6oomK6CwCFnOSTSFgFdIWs7JHPg0yCdA8VxibVsJbdjOantq0ntXJc/d634wxhV4whVjUWUjkYtq0MvGfGmXb/frt6u5Fbnbsz2Bsy27eRatpJrru8TRwTfyOn4R83EP2Y2wbKRhHSB40DK8Xba9641M4CRJRoRn8LWZS+x48UH6dq6Gj0QZtZZH2Lhez7KqUeNY2x5iOZYlhU7OzAdl+KAju3AlKook2sjtMVztMSyTK4Jo6kqfk3Bdl0ypkMqa7KhqZvtHWkMRaDrCseMLqU45CPi16iIyOqV4eJQrwlgmCTYtm0zatQocrkcTU1NqKpsZHAk6NvkDN767GzfBmgD7WDOZvfK7t5A96t/J7VpMQhBcNJxROact9+RHG/FySbJ7t5AtnEjuabN5Fq2Ysea+71GCRV7I7ZCJSjBYtRApM8cbL1nDrbiNTNxbK9s3MziZJM46bjX5CzZ4Y30SHT0e281UoFROQajagK+EZMwRkze7zzON+O6LrnGjcSXP0Jy/fNgmwTGzSN67Pv26aj+Zt9nTfUalPW+zqcfXPWCKuDqMyfzv6dMeFtxS0e2VCp1wBvm/m6274ZMsAfHYK0JAL7//e9z7bXXcv/993PJJZcMymdIhdWbYLcnciTSOZI5i6xpIwS0xnNkchZNiSzpjEks7XWUHlcZxHZgzqgy0qZNLJ1jUlWE1riJpkJ52IflutREA7y4pYVk2qYtkaWxM0V1cRBdEezqSrCxMU5XJkdQV3DSCXa88hCNix/CSnbhrxjJrDM+wNhjz0IxggR9gt1dGXKmiwlkTROBS2Niz7UU65CzIGvbpNt2kN29AbNpE9nmesy27bjWnhRW6D7USAVapBQlWOKVfPvC3hgu3QBVzx8H85qfmjhmFjebwskmsFPd2Kkub00Qb9vrvf0YFaPxV42nqG4CRs0URHEdjlCw8B58h3RvvFbI0PArKrF0jpaU26+U3ZfuRmx9jqZXHybWsotwWRULzruUE8//AK4RxFAVgobG+MooigIbm2K0deeIBDSmjSjC0DX8msLo8hBtiQyZrEN51E9JyGB3R5qMbbO1LY6hql53ecdhVHmIyVVRHLw1RG3JwCZU0uHtUK4JYJgk2AA333wzX/va1/jd737Hxz/+8UH5DKmweudl5yxnwJPst5Ng9zK7mogve5jkqidwskm00jrCs84kPP0U1PA73011cmmvmUnnLsyuJuxYC1a8HTvZgdMz77LvDfGAVA3VH0HZa+alVlLT09Ck7l2VpdvpbpLrniOx6gnMlnqvm/iMU4nOuWC/zcvejgkVIeaPK6MlnuWJdc37fY3aU36gyxnXw8bDDz/MX/7yF4LBIEVFRYwZM4azzz6bsWO92en/+c9/OPnkkwf8qJBMsAfHYK4Jurq6KC8vZ+7cubz66quD8hlS4SWzJltbk8QzFtGARipnEU/b1LfFifhVNEVlZEmAV+rbaYmnUYQgZGjMGV3CyNIQW1uTqKqCaTvoimBkaRgHh/KwwfJtHazd3U1LPINp2YQDOpoqaIllMVTB+t1d+Ayd4qBBLJGhvTvB7hXP0rj4IdKNm1F9AUbNPZm6Y84hNHIqlgtp0yJkeGfJY0kX0/aamwXwHhajgNXzZxreWWzXsXFjLejxBuItjWRjzVjdrfm+LHa6Gzeb4q1XMQLFF0QJRr01QajUa3xaXIVWXINeVocarUAVCipQGQJVV1BdQSJrk8mC6UJJSCUa0vBrKj5VoTubo707Rypt07F9Nak1T5DY8DKOZVI18SgmnvQ+TjnnPCbXltARN2lPZGhL5hhXESLi0wkYGtvbkqRNix2dKaZWRxhRFMIwvO7vQngl96NLg2Qsb6xaNKCzZleMMaUh/LpK0FCoLQmRznkPWGqKA/h1udk21BVqTQCDlGALIW4GLsCrCtkC/I/rul1v9XWDeTONx+OUl5czbdo0li9fPiifIRVe33nZj65p5KXNbYOyo/12OGaG1Bsvklj5GNld60Eo+MfMJjTtJIITFxz0We23w+19Im1mwbHpbXKGoiI0A6Vnd3ugObkM6S1LSK5/jvSW18GxMKrGEz7qLG/W5wBeqyJAiP3PQjf6dJ6XHcOHj5EjR3LXXXcRi8WwLIt169aRSqX48Ic/zLhx4/j9738/KGduZYL91t7JumAw1wQAn/70p7nnnntYvHgx8+fPH7TPkQova9mkczYdySx+TWV7exKfpiCEQtiv4dMUnn6jBcexexqqaowpDyJccHCJBnTqioOYtkMia2E5DjvbEjiuwqpdHXTGTWy8EZCKUNBVl50daWwUNNchbVns7Mhguw62Y1MU38n2xY+wbclTmJkU/tIaSmecRO3cUxg/aQrb2tJ0Jax89/ESP1imd2Qt8fYmjALeTrWby+BaWa/xqet4GwdC8Rqf6j7vfwfR6NQPaIrXmMzQBUnTpdSnYDsgVIeqaBBFVXEcF9O26d65hV3Ln6Rh6TNkYq3ogTDjF5zFxJPew+QpM0F1GVkUZHNLN36fzsyaCA0dGUZVhFAEBHSNNbtiNMZSdCRyjCoNURXxM2NUcc/Px5t7HQ0YGKqgMZamONDTBFVAxK95XxP1YzveTqU8fz08FGpNAIOXYJ8JPO26riWEuAnAdd2vv9XXDfbN9KqrruK2227jvvvu4yMf+cigfY50eOjd0TYtB1XxRjVNH1FEZyrHyp1dPH6AnU8AIeDCo0bQnswNaIdys30niTVPk1z3HHZ3C6gagTGzCUxcSHD8Me9qZ7tQ7HQ36S2vk960mPTWpbhmFjVcSnDKiYRnnvaWs6/3VhP10didPajXKoBQBI7jndU/dXIlFREf75tTJ5PqYebVV1/l+uuv59FHHwUgl8uxZcsWnnvuOR588EHuvfdeamtrB7wMDGSCfTDeybpgsNcE9fX1jB8/nvnz5/PKK68Myr8b0uHDdlx2daWwbDc/TztoaCiKIJ2zeWVLC51Jk8auNKPKAlRGg1iOS1BXaUtlmVoVIeTTWbs7RmMsw47OBLVFAd5ojOO63jqjsTNNyKfQlbbJmTY1RX6mjoiypS3F6oZ2EmkL1xFMqI4SDRq0d3Xw8pOP0bjsCeJbV4PrUFQzhvIZx1M+bT7x8BiiAQNch3jKwRaQ3l/b8Lepd+/W4e1X54FX1acBVRGVWMKmrFgjbOiURwwMHBo2rmHj68+yc8WL5DobEYrKiOnzufhDH+Z977+IgM9gZ2cGx3FY35RgY2MXbYkcmqIyqtzPiGiImmI/ZRE/4ytCLN3WwZPrW9AUQVHAx/jKEPPHl5HO2vh0hZzlEvarRHw6ugq2Kwj5NMpDBi6gqbLPwnBTyDUBHIIScSHEe4EPuK576Vu9drBvpul0mtLSUsaNG8eaNWvkzXQY6HuOu2/CtXR7Jx/61Sv5M737oyqC48eX8fwgjABzXYfsrg2kNrxIauPL2N2tABhV470xXaOPwlc7BUU//JpwuJbpjejavpL01mXkGjeB66CGSwlMXEBo8gn4Rk5HKCq9D4kPtopAACdM3P/YNUX0fx9FeDvV3z5f7lRLkEgkeM973sOkSZP4+te/ni8BA/jOd76DEILrr79+UD5bJthvz8GuCwZ7TQBw5ZVXcscdd7BkyRLmzZM/wqHOsh1ytoOmKBjanqTLdlxeq2+jPZkhlrapiPiojgTozuRoT+awHJfysEFN1M8T61uI+hXqW5OkTRu/ptAUz1AW9CFw2dqeosivoqkqyazD7NFFAKzY1kE8Z5HM5CgK+jhlag3NXSme2dCK67rEOppJbHiF2PpXaN64Atd1MMLFFE+Yi2/0LEKjZuIGK8ngJcjvJs/uHb6l4JWbv12948IAAooL3U2I5nWkt62kcf0Sssk4QtEIjJqJf/LxjDr6OOZMGsP/O3EsI0oiNHelWbK9A9OycG2XZza3EOop2U5kHT44byR1pWEqwj40TaE7neXFza3ebGtb4bjJ5cyqLcZ0HCzHRUFgqAKfrhEw1EE5UysdWQq5JoADrwu0AfyMTwF/HsD3e8cCgQBXXHEFt956Ky+//DLHH398oUOSBtnc0SUHTLqm10RZ2RDL/3PvrO3eJM523EFJrgGEUPDXTcVfN5WSUy/DbNlKuv510vWv0/3ag3Qv/isoGkbVeG9cV80kjOoJaCU1Bz2reiC4rosVa/YarDVuJLv7DbKNm8A2QSgY1RMoOu5DBMbNw6iZuE9smlfHjWk5B/WUXBFwzowaXtzU1u/1CvC9i2aydrfXRXxGTyWCTKqlXuFwmL/97W/cdNNN3HDDDUyePJmTTz6ZhQsXsn79embNmlXoEKU9Dpt1wde+9jXuuOMObrzxRv79738XOhxpkGmqst/dTNN2iAR1NFVhfIWGC4wqCbK+ySJj2lRF/WRyNrGMRdSvsbsrSUssSdBnUFsewXRcSoM6rfEcwnVpT+bQFZVRZQHv64sCBH06Vk8MhqKSM22qivzUlgbZ3Z5CDZUyYdF7qXvvRylXTdYteYHVrz5Hw9olZFY8BYBeVElgxFTCoyajlE8gWjMWUw3s06X7reTwknSXt5+sO9kkmeZ6ck2byO7eSG73Oqy41yA1VFLB2NmLqJy2gNboJIQRwrTBNcCvCZq7c4ytECgqlAV1bFfjjeYYIV3hjaY4hqowujRIY3eakN+gNGJQ7FPY2WHiOoKo30d52MfMEUUUBQ88o1sm19LhuiZ4yx1sIcSTQPV+/uoa13X/1fOaa4B5wPvcA7yhEOIzwGcARo0aNXf79u3vJu631NjYyIgRIzj77LPzZQPS8LJ3MzSBt1t944UzeGZDywGbZh0qTjZFtmGtN2+6YR255s35pmVC96GXj0IvG4le4o3pUqOVaNEy1FDJfudNvhXXtrCTXdjxNqzuVqxYE2bHLqz2BnJt27151wCqhlE5Hl/dVPwjZ+AbOQP1IEZ0KYCuKWSttz44pquCG94zg+/8e02/udaXLxrHN86d+ravTRp+YrEYzz//PMuXL+e5556jvb2d2bNnc+eddxIIBAblM+UOtmcg1gWHek0AcMEFF/Dwww+zY8cORo58dw0YpSOPaTvs7EghBHQkcpSFdGpLQ+iqQlcyyxtNcRwXDFXBrym8trWdjS1xHMehJKhTXRykyKdjuzbLG7po6kzTHM8Q0FUWTKwgpGtUR/2s3R1jU1McXVMoCegE/QrjKorY0tzN2t0xVKFQEjKYWBECTUU4UBxUWdPQyRvr1rFp5Wu01q8m1fAGVry9J3qBXlyNXj4So7QOraQGJVqJGq1Ai5QhjODbTjZd18XNJnHjbeS627zRnZ27sTp2kWvb4R1v66EXVRIZNYXisTOZPf94ymvHYRgaODard3bQ3G3j02BsZYhoOMCsmjBVkSCKImjsSpHKuWxpjVMaVNnRlcG2XUaV+CkK+jnvqBEUB31Yjk3Qp9GVyLKrK8uYsiC1JUGKQwdOsCWpVyHWBDCIJeJCiE8CnwVOc113/4Nt93IoysEA3v/+9/Pggw9SX1/fr2RAGh7ueGYztzy2obftFydMLOdLp0/Kz9POz99WwHbe2fmkgeTaljd7snkLuZatmG07MNt37jNSC0AYQZRAxOsEqge8RmaKhhACFxdsG9fO4eQyuLkUTjqOk03u8z5KqBi9tA6jYjR6xViMqnEYFWMHpTFaX6qA4yaU92tO1zuSS3YCl96OTCaDz+dj+/bt1NbWouuD9++uTLAPzttdFxyqNcGzzz7LKaecwpVXXsnPfvazQf886fCSylqsa4zh01QMVaGyZ9wTeMlmImuRsxxCPpX6lgT/XtGAT1cJ+1R8hsqEyijFfp2OlEljV4rnNzTTnjJp705juTC6NIhP1zhlSgWbm+O0J3OYlk3OhtKwD00RaKpCRyJHznEIGzpjKgLs6kpTE/WxuzPDto4EiXSWXZ1ZFBVCVoKOreuIN24l0bSNRPN2sp2NuJbZ/+JUDTXQO6bL7zUyU3UUoeDS04HcsXDNDE42jZtNYqfj4Fj93kZoPoySGozyUQSqxlBcO5Zg1QSKKqsI+zUmV4XQNZ32hMmU6jCqELjCYdPuOMmsSTjgQ9cUZo4sYXdnCl1VaO1OE/SpIBRiKRNDV3Bsh9rSIOMrIiwYX04qZ5M2LYoDBrtjaTQhGFEcIGXajC4N5WeeS9JbOZRrAhikEnEhxNnA14CTDja5PpS+8pWv8OCDD/LDH/6QX/7yl4UORzrESoJGPml28UqSexO3uaNLuP8zC/OdyK/5x+p+X6sImFlbxKqG2CFLvIXqlYobVeP7/bljZrC6mrG7W7HibdjJTpx0t5c051K4Zsabee2kvNajCISigmZ4T7Z9I1H9YZRAEWq4BDVchlZUiRatHJTO5m9FEd5O9zkzaliyrYOs6ZWVu0DOdFhc3y4TbOkt/eY3v+HMM8+ktrYWgDFjxhQ2IAk4vNcFJ598MrNmzeLuu+/m//7v/wiH37oyRxo6MqaNZTt0p02EgFFle0b2CCGI+L2FeCyVoy2RwaerJLM2QggWjK9kUlWEhq401bpCU1eKeCaHoQgURcG2HBSh0BrL8uLmdoQDFUUG9S1JysMGFWGDXZ1pwKU7azK1OkzU52dcZRhNUbFshwmVEa9Bm6vQZbiYjoURKqPu6BNRjl5EUdhHZzJHddjAzXSyas1G0rFWOttbScdjWKlu3GwSJ5fGtbIoVhbXdbBdUBQVFA3FH+m594dQAlEi0WJKKirxFVWgFFWjBUJ0pBUCPjB0hbDPR2lAY8G4ChA2I0pCvNHYzfjKEJVRP4oiSGRsFk4IkM5ZRAIauq7RnTFpT2aZUh2lK52jLWFyzNhyUpkco8vDZG2LkSUhysN+utMmjgu1RUHiGRMFF8t1aUtkCfl0ZBW4dDCam5v597//zWWXXYYQouBrgnd7BvvngA94oqc0ZbHrupe/66gGyMKFC5k7dy733HMP1113HXV1dYUOSTqE1uyOvek/957b/szvX98nif7wsaOYPqKItbvXYDtuQXe3Fd2PUTEaKkYXMApP2KeSyO45xaWrol+J9/4YqoLleGX6uio4emQxuZ4y8m+fP517X9rK5pYE4DVTKXmT81aSBN5s4yuvvJKFCxfy5JNPFjocqb/Del1w1VVX8T//8z/cfPPN3HDDDYUORzqELMdBEQKfpmDZLumsRcDoPyPZdV12dqVo6ExTFfUDgpqiAJOqItiui64IOhImAUNjUmWUja1JykI63RmbtlQOXRMYCrRlTEYbQWpL/AhFoTTipyttM6kyRDxroQmBogkautJ0p0wmVIXx6yrbOpKYjkNFWKM1aWO5gnFlIUzTRQgH16fgCpW0XszIaXPozpgE0ia2DameTW0BhP0wtiJKeVChJZGlJWYSz1pYFjgCXAcCGhgGjCwNE9AFhqrQ0JUhbmYxBKgoHDUiwqTqKBXRAH5DQ8VFFQpBXSORs5g3qhQbl1jKZFdXCoRCyKdSU+THtW1UTaMkoFMV8TOyLIDr+JlZV0za9Lqx50yHzmTWm2udyFAR9mE5EEuZuCGI+g15zlo6KN/5zne45557WLhwITNmzCh0OO8uwXZdd8JABTJYbr75Zk499VSuu+46fvOb3xQ6HOkQ2vtX8v5+RS/d3smT6/ufxRYCIj6NGx9e2zNLsWdjWGJaTZTlO7vySXXv/60tCTC+PLTfZnEnT67If48t22XJtk4AVjasRlMFdp8EXRHQmXq7bVyk4eYHP/gBqVSKK664otChSHs53NcFH/3oR/n+97/PjTfeyFe/+lW5iz2MhH06puPg01WKgxpZZ99+IRnTIZ40SWQdulI5qiMGUb9KUyyF6XhHsFRFpbY4QMhfRTDYSVt3hrqSALG01yk7Yzu4QqE1nmVUaZiykMHIsiDFfoOc5RDABVswfUSUaMigNZ6hM2nh0xXmjymjKuJne3sCpT2JJgRp20JVVDRFoBsG5VEfjR02oysiJFIZ2lIWRYbG7u40jivQVK9c/ejRZYT8OmUhg9W7uhhT4mf1rgTrGjtIZW38usByFSxHoGsafl0jEnAYp2uEfILKaICzZo4gadqU+HUUTSGRzqEoAkMT6KpKwFApj/h5qbMFQ/HOrndnLI6qLWZGbTFrd3VTETJQVYGuqIyvDmG74NMUVEXQlsiSNh0qon62tSfRFYGuqRQFvaTclosv6SDs3LmTX/7ylxxzzDFMnz690OEAA9tF/LB0yimncNxxx/Hb3/6WW2+9leLi4kKHJB0i75tTx1+XNmBaDrqm8L45+1YwPLisYZ/RUq4L97y4Nb9zLfDODL/FRm2eqgjGlgXZ0prcp0O2C6gKzBlVQkfKpL41cdCjrQ4HRUGDBz6zkE///nU6knsS4cauND+9ZDZnz6jh0TWNlIUM2pM5zplRw+TqCM9ubMW0HO+b2ed6+45PE3jjuBaMKzt0FyQdcUzT5Gc/+xkTJ07k/e9/f6HDkY4wmqbxox/9iPe9733cfvvtXHPNNYUOSTpEAobKuPIIyayFogjCvn2XwDnbJmvbVEcMMpaNDYR8Ojs6MowsC+DXNBzHRVNVtrUnmVQRYeG4ckpDBpURP/UtcZbv7CQS0DEUQV1piJl1JTiOw+ItraxrSuPTBadNqcTQVVI5m6pIgFGlCn5dpaU7g19XSecsmmIZXCCTs0mmMuRcBVwX03bQFOHdMw2d4p5Z0NUKFAf86JpLScBgfFUI0xJMrAgzsixExFARqiCWyeI4kDYtaqIBbzRZyKC+NYHPUAlogoDPx/TaInIWlAZ9BA2VsF+jIuxjZ0cav6ZSHPKhqoKKsIHrCurbkwR0janVETK2S01A57iJ5d731XJwHAe/oZHKWqxv6iaeMikOGQigO51DFV5ybaiCRMYhazlURA6/EabS4eeWW24B4LbbbjtsKh6GfIINcO2113Luuedy6623ypKwYWTu6BLu//SC/c7H7nWg3NZxXVRF4Loues8M5j8v2ZEf9yWA06dVEUvlWLKtExev/PnieSOZMaKIe1/a2u+9p1ZH+N57Z+bPfN/48FpyluP9IhjgJ7THjilhxc4uLNvlrft5H9heuTAAlREfG5ri/ZJr8C7hwWUNjCgO5BvJ9Vq6vRPH8c5ZK27/99RUgeu4qKrCB+bW8f45dfL8tfSm7r77bjKZDF/72tcKHYp0hLrwwgupqqri9ttv55vf/CaKcuhGIkqFI4TwxnBZNooQ+HV1n9cEdI2gTyeetRlfHkJV8JJxv0YiY2FqLqUhg7KQD9d1MVQFVXhjKouCBrPHlFFVFCCetdAVhcqIj7Z4llW7utjWmqCmyEfI0EERVEcDZCybVNYilrEwbQu/phINaKRMh5BPI+JT2dQcx7IBHCzXJWvZjK2OEjB0AoaCKhTWN3YT0DUQLqVBH0ePKqUyEmRceZjikEHGtNnYGMO0FcaUhwhoGs3dWQKGgiIgnrPIWA7daW/82ISqIsKGga4p+HQV23UxVBVVEUyoihL2q/g1jbKgj3jWptivMa2qiJ3dafw+Db+uEkt7NetRv94zi9z778zFxbJdfIZC1rIZVxkmY9qMrwgTNDSylsPoMoFPV/Fp+/6MJKmvXC7HnXfeyfTp0znuuOMKHU7esEiwzz77bOrq6vjpT3/Kddddh6YNi8uWePP52ADvn1PHX17f2W8nVcHbSf32+dP7zWCeXB3h0nsW53fELz9pfL4jeW8SD/Dhuxfnzxj32thzxvh/T5nAHc9szo8OU3D3m8i+Xb1N2T50zCg+Mn8US7d3cuNDa/vN/347ekeaOa6b32E3eqoAbnty4z6v1zWFv76+E8tx9/ne/X1ZA73fDgc4ZkwJfl3N726/2QMQSerLdV1+/OMfE4lE+NjHPlbocKQjlKIoXH311Xz961/n/vvv59JLLy10SNIhoiiCoHHgNaChKUypiqAq4LgCy7JxEVRH/UT9OkKBgK4ihGBEUZCWeBYXqIr48u9RUxyg1PLOe9uOS1c6h8DFdGy2taYoDevUFQegBIKGRjxj4VMVDE0hmbXJ2Q67OtI4QEcqR0VxkICqkEhbJE2LY8eWE/JrZHI24YBG1Kfh11SE4tCVtJhQFWbOmFJqigL57tt+XaUi6idkJIj4DTKmw7iKEMdNrCCZNnlo1W4mVEfY0ZYiY9kUhw0UTaE8rBPy6QR0lYqIn+6syaTqCLbt4tMVSsM+EhmLWNZG1QR1RX7GlgUB13sQ70LWcigP+0hmTVRFIZNzSOUsysIGHQkTn6ZQV7JnzJg8tCG9HXfeeSeWZfH1r3+90KH0MywyTSEEX/nKV/jSl77E73//ez71qU8VOiTpMDF3dAl//sxCHlzWgAvMGFHUL6ne+7X3XbbvjnjfJP6OZzZ7pdB7cRyXvy9rYO7oEhaMK8PQlHyirglBImfv8zUHq3d3vG+8c0eXMKO26G0n2L279n2T5JKg0e97Mr0mygt9zlofO6aECVURHnhtB47rdQL/9r/W4PS8z6KJFf0+I2c5fOOcqf2+f5J0MB5++GHq6+v5xje+gc/ne+svkKQDuOKKK7j22mv58Y9/LBNsqZ9o0GD2yFJytoMqvDWkqoh9Sk+DPo3RPU3S+v6d6LM7njFthPCOQ6WyDromKA7otKZyBNuT1JQECBkqzRmLnO3g1xW6UiYVUR+mZdOecjlhXAldKZvWhElpWGVENEA8ZzK6PEh52E8sbTK9tth7GC5gRJGfooBBtmct4tcVhPDKyCdVR6gu8tOdtSgJ6OxoT9GezKALQTxrEfIphAOq13wskSUV9YPwEuCQXyPk3zdtsB2XuhI/qayFrmnUFgfZHcsQNFRcIJ2zaIw5JDIWjusSMlQyps3W1hQBA1oTOdKmQ1XUv9+qAkk6ENd1+clPfkJpaSmXXHJJocPpZ1gk2ACXXXYZ11xzDddccw2XXnqpXJxJeW+1y/12XrtgXBm6puyzg+0Cf1vakC+B/uTCMfx3bRNnT6/mL6/vhL0S7KChkrVs7APUeE+tjlAe8XHOjBo+Mn8UQL+d9LmjS/Jn0HOWk2/w9lY75bbjogj45MIx+d3lydWRftccz+6ZmymAkyZXsmBcGQ8u8867C7Fn5ztrOvs0LVvVEOPSexbLedfS2+I4Dl/72tcQQnDVVVcVOhzpCBeJRPjMZz7DHXfcwV/+8hc++MEPFjok6TCiKAK/8tbJ3lud9/TrXkO0lliGOWNKcSyb7pxNddSPUATdaZOykIGDQyJlUVweoq4kSEDrJJXNURb0UVsSYv74MKqq0NiV8na8dRXXcTF0hTHhEJURP8VBI/8woCOZoyOZBaAooFMR8RM0NMojfsI+nTGqoKk7Q0ssQ8a2qCkJUmxZGJpAuArVxX4SWQvTcQCXWDqHonhN4EI9JeAAlu2wK5YiazqE/QbFQQNVVSgJGvmjZMVBnW3tSRzH6+QeS+WoLQ6Ss22aYxkiPg0hBO2JLLUlh350qHTk+vnPf87OnTu54YYbBn3e9dsl3AJ06Js3b577+uuvH/LP/fnPf87nP/957rrrLi6//LCZGiINMQcqz1YFXH3mZEqCBt/qM3d70cTy/Xbf3rt0vPc2rquC+z+zcJ9zzpfe45Wm9919jqdN1jZ2UxYy+OeK3Qd9DaoiELjYzp7P29AU596XtrKlJZGPy9AU7v/0gn6l8iVBg+v/vYZcT9m9qnhlcvv7XvzvKYd1w2HpMPLPf/6T9773vXz7298uWC8NIcRS13XnFeTDh7BCrQlSqRSlpaVMmDCB1atXHzbNcaShpymWIZWzSJsWiYxFTXEA03YoC/lI5yzeaE4Q0BWylsOcuiIeX9vIy1s6iQRUQobO6dMqMR2HLc1JJlR587Jd16EyGkAIQUBXCBoaRT1jLre3J9FVr1N3VzpHUFfJWA6W7WDaLo7rsL0tSWcqR3Ms0/OwvJpoQGVXLI2uKmxvSxE0NKqiPnRVwbRdgj4Vn6JQXeTHdBwaOtIksiaqKlCFYGx5iBEl3nzxZNbMn3V/fXsHtu3gugIXh5riILoq2NaaYmSZl1QbmkJNUaBQPyLpCGNZFqNGjcK2bXbv3o2qFqb64UDrgmGzgw3eLvY3vvENbr31Vj772c/Km6k0KOaOLuHbF0z3El7TwcE7I633dMje+wyzC5QGdTp6h1j2+fO+PrtoHJGATjxtctuTG/vtXi+ub8+f685Ze0q0HXf/zcp6Hejv+ibEOdvlCw8sZ1dnep+v/cDcun3Ooc8dXcLa3TH+9OoOXLwSHk0ROI67z/dCkg7WzTffDMDVV19d4EikoSIYDHL55Zdz++238/LLL3P88ccXOiRpiKqM+EjmVAQ+VEUQS1tEfDpRv05HModPEUT8OtlEFhSFsGFQHjEoCmjs7MjQ2Jkh4FdpimdI5CxKggbHjS+nMuJn+c5OkjkLx4EZtVEqo17ZeVfaAlxcx8V1vW7ksVQOoQhUAWG/yo52i2hAx3FdEA5Bn595o4M0d2coDfrIWQ6lIZ1NLXEUoeC4DtviOeo7EqSyNqbtEDE0yiM+ikM+SsNe1++2eIbujIWqeHPERxT56c5YgKAkoIEQ5CyXaSMipHpmYpeHZWWpdPAefPBBGhsb+d73vlew5PrNDKsE2+/3c8UVV3DLLbfwzDPPcOqppxY6JGmI6ntee+8zzOfMqOl3hvmcGTVsao6/6fupiuCM6dVsaIpz82MbAPLv8ZH5o/qd6xY9jVV6U+Q3S659usInF47J73I/vKoRx3Xz79Fr7+S6N6b3z6nbZ/f8vssW8L45dfx92Z4RaQc6zy1JB2P9+vW8/PLLfPKTn6SoqKjQ4UhDyDe/+U1uv/12vv/97/PII48UOhxpiFJ6EuhegT6N1iojfnZ1pdnVkaIsYhAyVCqK/ZR0GmSyFprq0pLIUKH4GVMSxHRdqqJ+SkIGacsmlbMpD/vpzuToSOaojAYoC/vyn5ExbbozJrbroihed+5UzqIk6KcimqM6GsB1oThoMKI4gK4qFAcNbMelMZYmZzkEDR1NEbTGs/h1BdNxCeoKwqeSNR1CPp0RRQH8ukrOcujOWIR8mvfZ6RzVRUFCPhNFePPIe5uvAZQeuh+DNITcfPPNCCH4whe+UOhQ9mtYJdjg3UxvueUWvvOd78gEWxpUfc9r9+7wAkyu9jqU2o43E3tydYQF48retITbdlweXNbAjo5Uvz9/dE0jH5k/ap+EvncM2IF2sHtHiu09FutjC8fsKfN+aO0+Z8l7KQJuvHAGc0eX9OuKbloOi+vb+d9TJuy3IZwkvRO9s4oPty6h0pGvqqqKiy66iH/+859s3LiRSZMmFTokaYhzHJeutInlOBQFdIKGSk3UR2scfKogljYJGV6jsYZYmvHlIUaXBklkHVKmQ22JH8vx5mEHdY2ArtIWz+C4UFrhlYj3NjUD7xy43XMcNKir6KrCiKifgE+lLGSQMm00xWtkpqt7RtapiqC22JuTXRE2aI5n0RQwNJXWRJbOZI6AojCyNMTEqjDhngcIivDWHTnLwXIcNEVHVQTFPeXrkvRuLVmyhNdff53PfOYzRCKRQoezX8MuwS4tLeXTn/40d999N4888gjnnntuoUOShri9d3gXTazINy+zHfjlc1t4ekPLW76PC/vd/e7VN6HvbVDWu2O89/89UNK793s8uKyBv76+02t+pghGlwYZVxHmsz0jyoB9uqL3ln6/neZxknQgr732Gv/4xz+48MILmTJlSqHDkYagG264gX/+859cccUVPPXUU4UORxriOpJZutMWqipI52zKQgaNXRmCPo1Y2iZrJ3Bst6dpGeiad366riwABAgZGpbjoqsqmqYwq66IWMbCpyn7TWJVRVAZ8cN+8pCI3yBnOSgK+505LYRAVwV6wCDk03HLQmRthzHlIbKmi64JigJ6vyOXmqpQUxygM5Uj5PMRDRxezaekI99ll12GEIJrr7220KEc0LBLsAH+7//+j7vvvpsf/vCHMsGWBl3f89Gm5dDcnen3983dmX5zuHupAoQisG0XXRX9dpsfXdPY7wz23gYiue19j/fNqXvTnegDjS+TpIFw0003AXDLLbcUOBJpqJo1axYf+tCH+POf/8z69euZOnVqoUOShjDTcdE1BV0VJHsmiGiaguO6WK6bP76lKuDTVXRNJWhoTKgMkzEdklmbUp+KX/d2m/2Ghv9NZnu/GVURBIyDO7/qlXULgr273G/Sj8yvq7JhmTQoXnzxRVatWsWVV17JyJEjCx3OAQ3LBLusrIyPfvSj/PGPf2TlypUcddRRhQ5JGsL23uH90DGjWN+4BrMncf7QMaNY27imX5KtAN+9aGZ+J7pv4vqR+aMOmFgPhoNJ1uVutTQYGhoaePDBBznrrLOYMEF2nJcGz9e//nX+/Oc/c/PNN3PvvfcWOhxpCCsJGjR2pTEtKAkZBA2NEcUB4lmLqF+npjhARzLLmLIg3VmLioifCZURIn6DiB8qDs+KWEk6JHofun/lK18pcCRvbliN6epr7dq1zJgxgzPOOIPHH3+8oLFIQ9/eXbb398+/fG4L9W1JxpaHuLxPCbYkDVcf+chHuP/++3nuuedYtGhRocORY7oGyeGwJgA47rjjeOWVV9iwYYM8iy0NKttxcVw3f+bZcVxytoOmCDRVoXdtLqfdSNIeixcvZuHChbznPe/hX//6V6HDAQ68Lhi2CTbAJz/5SX73u9/x+uuvM3fu3EKHI0mSJPXYtWsXdXV1nHvuufznP/8pdDiATLAHy+GyJli/fj3Tpk3j0ksv5Y9//GOhw5EkSZL6OOuss3j88cfZuXMndXV1hQ4HOPC6QNnfi4eL66+/HthTbiBJkiQdHr7//e8D8O1vf7vAkUjDxdSpUznppJO477776OjoKHQ4kiRJUo9t27bx+OOP84EPfOCwSa7fzLBOsMeMGcNFF13EX//6Vw6Hp+eSJEkSbNmyhbvuuovjjz+e+fPnFzocaRj5zne+A8D//u//FjgSSZIkqddll10G7Pkdfbgb1gk2wI9//GMAvvvd7xY4EkmSJAm8SQ+w5/ezJB0qp5xyCmeccQYPPPAAjY2NhQ5HkiRp2HvjjTd46qmn+MhHPsKMGTMKHc5BGfYJ9rhx4zjvvPP497//zc6dOwsdjiRJ0rAWi8W49957OfbYY+XutVQQX/7ylwH5gEeSJOlw0PvQ/XDvHN7XsE+wYc8P7Bvf+EaBI5EkSRrebrjhBlzXzSc5knSonXnmmUyePJmf/vSnNDc3FzocSZKkYWvTpk38/ve/54QTTmD27NmFDuegyQQbOPnkkzn11FP505/+xObNmwsdjiRJ0rDU1dXFrbfeyqxZs/jABz5Q6HCkYUoIwc9+9jNM08w325MkSZIOvd5Gpz/5yU8KHMnbIxPsHrfccgsAN998c4EjkSRJGp56fw//6Ec/QlHk7UkqnDPOOINZs2Zx5513kslkCh2OJEnSsNPS0sIDDzzAaaedxjHHHFPocN4WuYLpMXv2bBYsWMCvfvUr6uvrCx2OJEnSsNLS0sLNN9/MhAkTOOOMMwodjiRxzTXXYNs23/rWtwodiiRJ0rBz9dVXA0fmuE6ZYPdx1113AbKjuCRJ0qF28803k8vl+NnPfiZ3r6XDwgc/+EHmzZvHT3/6U2KxWKHDkSRJGjZ2797NfffdxznnnMOiRYsKHc7bJlcxfRx99NEcd9xx/Pa3v6WpqanQ4UiSJA0L8XicO++8k4kTJ3LWWWcVOhxJyrvqqquwbfuIO/8nSZJ0JPve974HHFmdw/uSCfZeeneve8sSJEmSpMF17bXXkkql+P73v48QotDhSFLehz70ISZNmsSNN95IZ2dnocORJEka8urr67nrrrtYsGABp5xySqHDeUdkgr2XU089lZNPPpn777+flpaWQocjSZI0pGUyGe68806mTZvGxRdfXOhwJKkfVVXzzfduu+22wgYjSZI0DPQ2nP7JT35yxD50lwn2fnzzm98E5FxsSZKkwXbjjTdiWVb+964kHW4uuOAC6urquOmmm+QutiRJ0iDatm0bv/jFL5g9ezYLFy4sdDjvmEyw9+PMM8/klFNO4Te/+Q1tbW2FDkeSJGlISiQS3HTTTcyYMYNLL7200OFI0gHdddddZLNZbrrppkKHIkmSNGT1HtW99957CxzJuyMT7AP4zne+A8D//d//FTgSSZKkoennP/85juPwrW9964gtA5OGh/POO48JEyZw5513ks1mCx2OJEnSkNPd3c29997LggULOProowsdzrsyIAm2EOLLQghXCFE+EO93OFi0aBFz5szhJz/5CTt27Ch0OJIkSUNKZ2cn3/nOdxg7dqw8ez0EDbV1gRCCa665hng8znXXXVfocCRJkoacL37xi8CROfd6b+86wRZCjATOBIZUFiqEyI/luOaaawocjSRJ0tDyve99j1wux0033YSmaYUORxpAQ3Vd8IlPfIKpU6dy++23yyaokiRJA2jLli389re/ZdGiRZxzzjmFDuddG4gd7FuBrwHuALzXYeWkk07i5JNP5o9//CMbNmwodDiSJElDQltbGz/5yU+YMWMG73//+wsdjjTwhuS6QAjBLbfcQi6X44Ybbih0OJIkSUNGb6PTH/7whwWOZGC8qwRbCHEhsMt13ZUH8drPCCFeF0K83tra+m4+9pDq3cX+8Y9/XOBIJEmShobe36u33XYbiiJbgQwlB7suOFLXBOeeey5HHXUUd955J5lMptDhSJIkHfGampr461//yumnn35Edw7v6y1XNkKIJ4UQa/bzvwuBbwEHVSjvuu6vXNed57ruvIqKincb9yEze/Zs5s+fz913382WLVsKHY4kSdIRraWlhR//+MeMGzeO0047rdDhSO/AQKwLjtQ1Aew5NiaPj0mSJL17X/nKV4Chcfa611sm2K7rnu667oy9/wfUA2OBlUKIbUAdsEwIUT24IR9699xzD4AsCZMkSXqXfvjDH5LL5fjFL35R6FCkd2i4rwsuvvhijjnmGG677TY5F1uSJOldaGho4L777uPcc8/lxBNPLHQ4A+Yd1+a5rrvadd1K13XHuK47BmgA5riu2zRg0R0mZsyYwdlnn80f/vAHVq1aVehwJEmSjkg7duzg1ltvZd68eZx++umFDkcaYMNpXXDjjTfiOA5f+tKXCh2KJEnSEeuzn/0sMPQ2MeXht4N0yy23AHD11VcXOBJJkqQj0ze+8Q3A+30q515LR7Kzzz6bE044gd///vds3Lix0OFIkiQdcV577TUeeeQRLrroIubNm1focAbUgCXYPU+s2wbq/Q4306dP58Mf/jBPPfUUL7/8cqHDkSRJOqJs3ryZ+++/n9NPP52TTjqp0OFIh8BQXxfcfvvtAHznO98pcCSSJElHnm9961sA3HzzzQWOZODJHey34fvf/z4gO4pLkiS9Xb1VQEPxRioNT3PmzOG0007jgQceoK1tyD5HkCRJGnDr1q3jqaee4pJLLmHChAmFDmfAyQT7bRg7dizvec97ePDBB1myZEmhw5EkSToibN68mV/+8pcsXLiQo48+utDhSNKA6e16+8UvfrHAkUiSJB05en9nXnvttQWOZHDIBPttuuuuu4A9ZQ2SJEnSm7vuuusAuOOOOwociSQNrEWLFnHOOefwpz/9ifr6+kKHI0mSdNhbsWIFTz75JJ/4xCeYPn16ocMZFDLBfptGjBjBxz72MZ588kmeeuqpQocjSZJ0WFuxYgUPPPAAZ599NrNnzy50OJI04G688UYAPv/5zxc4EkmSpMPfZz7zGWDPw/ehSCbY78BNN90EeB3FXdctcDSSJEmHry9/+cuA7F0hDV3z5s3jwgsv5JFHHuGVV14pdDiSJEmHrYcffpglS5Zw2WWXMX78+EKHM2hkgv0O1NTU8JWvfIVVq1bx0EMPFTocSZKkw9Jrr73G008/zSc+8QmmTZtW6HAkadD87Gc/A/acyZYkSZL2de2116LrOj/60Y8KHcqgkgn2O/T1r38dv9/PF7/4RWzbLnQ4kiRJh53Pfe5zAFx//fWFDUSSBtnIkSP5xCc+wZNPPsmTTz5Z6HAkSZIOO3/4wx9YuXIlX/rSlygpKSl0OINKJtjvUHl5Od/85jfZtm0b999/f6HDkSRJOqw8/fTTLF26lCuvvJIxY8YUOhxJGnS33XYbAF/96lfl8TFJkqQ+LMvi2muvxefz8d3vfrfQ4Qw6mWC/C1/+8peJRCJcddVVOI5T6HAkSZIOC67r8rnPfQ5N07jmmmsKHY4kHRLFxcV84QtfYMWKFfz9738vdDiSJEmHjbvvvpsdO3Zw/fXX4/P5Ch3OoJMJ9rsQCoW47rrraGtryz+5liRJGu7+8Ic/sGHDBr74xS9SXV1d6HAk6ZC58cYbURSFr3/961iWVehwJEmSCi6ZTPLtb3+bSCTC1VdfXehwDgmZYL9LX/7yl6mrq+O73/0upmkWOhxJkqSCcl2X66+/nkgkwve+971ChyNJh1RRURHf+973qK+vl8fHJEmSgLvuuou2tjZuvfVWDMModDiHhEyw3yVFUbj++uvp6uriyiuvLHQ4kiRJBXX99dezdetWvvWtb+H3+wsdjiQdcp///OcpKSnhiiuuoLOzs9DhSJIkFcyOHTv45je/yahRo/j4xz9e6HAOGZlgD4BPfepTXHDBBaxYsYJ0Ol3ocCRJkgrm2WefZdGiRXz1q18tdCiSVBDhcJjf/OY31NTU8OijjxY6HEmSpIJ56KGHGDNmDH/4wx/Qdb3Q4RwyohCdLoUQrcD2Q/7B70w50FboIA6B4XCdw+EaYXhc53C4Rhge13kkXeNo13UrCh3EUCPXBIel4XCdw+EaYXhc53C4Rhge13mkXeN+1wUFSbCPJEKI113XnVfoOAbbcLjO4XCNMDyuczhcIwyP6xwO1ygNHcPl39fhcJ3D4RpheFzncLhGGB7XOVSuUZaIS5IkSZIkSZIkSdIAkAm2JEmSJEmSJEmSJA0AmWC/tV8VOoBDZDhc53C4Rhge1zkcrhGGx3UOh2uUho7h8u/rcLjO4XCNMDyuczhcIwyP6xwS1yjPYEuSJEmSJEmSJEnSAJA72JIkSZIkSZIkSZI0AGSCfZCEEJ8XQrwhhFgrhPhRoeMZLEKILwshXCFEeaFjGQxCiJt7fo6rhBD/EEIUFzqmgSKEOFsIsUEIsVkI8Y1CxzMYhBAjhRDPCCHW9fy3+MVCxzRYhBCqEGK5EOLhQscyWIQQxUKIv/X8N7leCLGw0DFJ0sEYLmsCGNrrArkmOLLJNcHQMpTWBDLBPghCiFOAC4GjXNedDtxS4JAGhRBiJHAmsKPQsQyiJ4AZruvOAjYC3yxwPANCCKECdwDnANOADwshphU2qkFhAV92XXcasAD43yF6nQBfBNYXOohBdjvwX9d1pwBHMfSvVxoChsuaAIbFukCuCY5sck0wtAyZNYFMsA/OFcAPXdfNAriu21LgeAbLrcDXgCF7MN913cdd17V6/nExUFfIeAbQscBm13XrXdfNAQ/gLQCHFNd1G13XXdbz/8fxfvnWFjaqgSeEqAPOA+4pdCyDRQhRBCwCfg3gum7Odd2uggYlSQdnuKwJYIivC+Sa4Mgm1wRDx1BbE8gE++BMAk4UQrwqhHhOCHFMoQMaaEKIC4FdruuuLHQsh9CngEcLHcQAqQV29vnnBobgTaYvIcQYYDbwaoFDGQy34S1qnQLHMZjGAq3Ab3rK3u4RQoQKHZQkHYQhvyaAYbkukGuCI5hcExzxhtSaQCt0AIcLIcSTQPV+/uoavO9TKV75yTHAX4QQ49wjrAX7W1zjt/DKwI54b3adruv+q+c11+CVFt13KGOTBoYQIgz8HfiS67rdhY5nIAkhzgdaXNddKoQ4ucDhDCYNmAN83nXdV4UQtwPfAK4rbFiSNDzWBDA81gVyTTD0yTXBkDCk1gQywe7huu7pB/o7IcQVwIM9N8/XhBAOUI73pOWIcaBrFELMxHtytFIIAV6J1DIhxLGu6zYdwhAHxJv9LAGEEJ8EzgdOOxIXRAewCxjZ55/rev5syBFC6Hg30vtc132w0PEMguOB9wghzgX8QFQI8UfXdT9a4LgGWgPQ4Lpu727D3/BuppJUcMNhTQDDY10g1wSAXBMcyeSa4AgkS8QPzj+BUwCEEJMAA2grZEADyXXd1a7rVrquO8Z13TF4/5LPOdJuogdDCHE2XpnNe1zXTRU6ngG0BJgohBgrhDCAS4B/FzimASe8ld6vgfWu6/6k0PEMBtd1v+m6bl3Pf4uXAE8PwRspPb9fdgohJvf80WnAugKGJEkH658M4TUBDJ91gVwTHNnkmmDoGGprArmDfXDuBe4VQqwBcsAnhtBTzuHm54APeKLnqfxi13UvL2xI757rupYQ4krgMUAF7nVdd22BwxoMxwMfA1YLIVb0/Nm3XNd9pHAhSe/C54H7ehaA9cD/FDgeSToYck0wdMg1wZFNrgmGliGzJhDyniBJkiRJkiRJkiRJ754sEZckSZIkSZIkSZKkASATbEmSJEmSJEmSJEkaADLBliRJkiRJkiRJkqQBIBNsSZIkSZIkSZIkSRoAMsGWJEmSJEmSJEmSpAEgE2xJkiRJkiRJkiRJGgAywZYkSZIkSZIkSZKkASATbEmSJEmSJEmSJEkaAP8fjVkRRTRtujIAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_contour(logprob, orbits=samples, weights=weights)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Yoshida\n", + "\n", + "A different method of discretizing the solution to Hamilton's equations, see [Blanes, Casas & Sanz-Serna (2014)](https://arxiv.org/abs/1405.3153)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 29.2 ms, sys: 547 µs, total: 29.7 ms\n", + "Wall time: 30.5 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "init_fn, yo_kernel = orbital(\n", + " logprob, step_size, inv_mass_matrix, period, bijection=integrators.yoshida\n", + ")\n", + "initial_state = init_fn(initial_position)\n", + "yo_kernel = jax.jit(yo_kernel)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cabezasg/.local/lib/python3.8/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n", + " warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.65 s, sys: 4.8 ms, total: 2.65 s\n", + "Wall time: 2.76 s\n" + ] + } + ], + "source": [ + "%%time\n", + "rng_key = jax.random.PRNGKey(0)\n", + "states = inference_loop(rng_key, yo_kernel, initial_state, 10_000)\n", + "\n", + "samples = states.positions\n", + "weights = states.weights" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA9gAAAF1CAYAAAATN0JoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdd5hd1Xno/+/a9fQzZ/qMRhVRJQyWKAIbMAHjQHDDvcWJ4yRO7CROcn3jFsKPG984N3ZunGsnromT2OCCwcQOxBhs00WRKJIQ6prRaPrMmdN3X78/9mg0o15GBVif59GjmX12Wecgzl7vXmu9r5BSoiiKoiiKoiiKoijK8dFOdQMURVEURVEURVEU5eVABdiKoiiKoiiKoiiKMgdUgK0oiqIoiqIoiqIoc0AF2IqiKIqiKIqiKIoyB1SArSiKoiiKoiiKoihzQAXYiqIoiqIoiqIoijIHVICtKCeIEOJ9Qoj7jnDf3xJCPHIC23JCzz8XhBA7hRDXnup2KIqiKMqxUPf9o6Pu+8rLlQqwlZclIYQUQizdZ9stQojvnKw2SCm/K6W8bi7OJYT4lRDiw3NxLkVRFEVRYkKITwkh7t1n25aDbHv3oc6l7vuKooAKsBVFURRFUZRXroeAy4UQOoAQogswgVfvs23p1L6KoiiHpAJs5RVJCPE6IUS/EOLPhRAjQohBIcRvT722WAgxKYTQpn7/hhBiZMax/yGE+PjUz3khxLemjt8thPjrGTfkWdOzhBDXCSE2CSFKQoh/EkI8uO/TaSHEF4QQRSHEDiHE9VPbPgdcAXxZCFEVQnx5avs5QoifCyEmps77zhnnaRFC/KcQoiyEeBI44xCfRUII8R0hxPjU+35KCNEx9dpvCyE2CiEqQojtQojfP8Bn+D9nfIZvEULcIITYPNWuT8/Y/xYhxB1CiO9PnW+tEOKCg7RJE0J8UgixbapdPxBCNB+uvYqiKIpylJ4iDqgvnPr9CuCXwKZ9tm2TUg6o+7667yvK4agAW3kl6wTywDzgd4CvCCEKUsodQBl49dR+VwJVIcS5U79fBTw49fO3gYD4yfargeuA/aZ0CSFagTuATwEtxDfuy/fZ7dKp7a3A/wG+JYQQUsrPAA8DH5NSZqSUHxNCpIGfA7cB7cC7gX8SQpw3da6vAA7QBXxo6s/BfHDqc5g/1baPAI2p10aAG4Ec8NvA/xVCrJhxbCeQIP4Mbwa+AbwfWEncOfhLIcTiGfu/Gfgh0DzV9h8LIcwDtOmPgLcQf9bdQHHqPR2uvYqiKIpyxKSUHvAE8b2eqb8fBh7ZZ9ue0etvo+776r6vKIegAmzllcwHbpVS+lLKe4AqcPbUaw8CVwkhOqd+v2Pq98XEN53npp6e3gB8XEpZk1KOAP+X+Ka3rxuADVLKO6WUAfCPwNA++/RKKb8hpQyBfyO+SR7sCe2NwE4p5b9KKQMp5TPAj4B3TD1Jfxtw81S71k+d71CfQwuwVEoZSinXSCnLAFLK/5JSbpOxB4H7iG+gM4/9nJTSB75H3En4kpSyIqXcALwAzHxavUZKecfU/n9PfJNedYA2fQT4jJSyX0rpArcAbxdCGIdqr6IoiqIcgwfZG0xfQRzcPrzPtgfVfX/6WHXfV5RDME51AxTlBAmJp3zNZBJ/Se8xPnXT26MOZKZ+fhB4E9BP/NT6V8AHiJ8OPyyljIQQC6fOOSiE2HMODdh1gPZ0z9wupZRCiP599hma8Xp96pwZDmwhcKkQYnLGNgP4D6Bt6ueZ7eg9yHmYOmY+8D0hRBPwHeKbnD81Xe2vgLOm3lsKWDfj2PGpjgHsfZo8POP1xj7vYeZnEE19Bt0HeX93CSGiGdtC4o7HQdt7iPeoKIqiKAfzEPDRqSnJbVLKLUKIYeDfprYtn9pH3ffVfV9RDkuNYCsvV33Aon22LebQN5yZHiR+Yvu6qZ8fAV7D7OnhuwAXaJVSNk39yUkplx3gfINAz55fRHwX7TnAfgcj9/l9F/DgjOs2TU0j+wNglHj62vwZ+y846InjEfz/T0p5HvH0tRuB3xRC2MRPx78AdEgpm4B7AHGwcx2B6TaJeI17DzBwgP12Adfv8/4SUsrdB2vvcbRJURRFeWV7nHgK8u8CjwJMjZAOTG0bmFo+pu77R0/d95VXHBVgKy9X3wc+K4TomUqccS3wRuKp3oclpdxC/BT2/cQ3tDLxE9q3MRVgSykHiadOfVEIkZu6zhlCiKsOcMr/As6fSgZiAB8lXsd0pIaBJTN+/ylwlhDiA0IIc+rPxUKIc6eeLN8J3CKESE2tz/rgwU4shLhaCHH+1BSzMvEofwRYgM3UjXvqqfbxlh9ZKYS4aeoz+DhxR2X1Afb7KvC5qVkCCCHahBBvPkx7FUVRFOWoSSkbwNPAnxFPDd/jkaltD03tp+77R0/d95VXHBVgKy9XtwKPEd8ci8TJQ943tS7pSD1IPBVq14zfBbB2xj6/SXxDemHqOncQr6GaRUo5Brxjqh3jwHnEN3P3CNvyJeK1SEUhxD9KKSvEN713Ez8JHgL+lvjGCPAx4ilaQ8QJWf71EOfunGp3Gdg49T7/Y+oafwz8YOq9vRf4zyNs78HcDbxr6nwfAG46yBSvL01d6z4hRIX4Znzpodp7nO1SFEVRXtkeJE4e9siMbQ9PbZtZnkvd94+Ouu8rrzhCyn1noCiKcqJNTZPqJw76f3mq23MyCCFuIU5Q8v5T3RZFURRFOZnUfV9RXjnUCLainCRCiDcIIZqm1jh9mng0/EDTpBRFURRFeYlT931FeWVSAbainDyXAduAMeL14G+ZWvelKIqiKMrLj7rvK8orkJoiriiKoiiKoiiKoihzQI1gK4qiKIqiKIqiKMocUAG2oiiKoiiKoiiKoswB41RctLW1VS5atOhUXFpRFEVRjsmaNWvGpJRtp7odLzeqT6AoiqK8FB2sX3BKAuxFixbx9NNPn4pLK4qiKMoxEUL0nuo2vBypPoGiKIryUnSwfoGaIq4oiqIoiqIoiqIoc0AF2IqiKIqiKIqiKIoyB07JFHFFURRFmWtreous3j7OqiUtrFxYONXNURRFURTlFJBSMln3qbkBmYRBU8o6qddXAbaiKIrykremt8j7vrkaL4iwDI3vfniVCrIVRVEU5RXI8SMm6h4pU2e86pIwdRKmftKuf8RTxIUQ/yKEGBFCrJ+x7e+EEC8KIZ4XQtwlhGg6Ia1UFEVRlENYvX0cL4iIJPhBxOrt46e6SS97ql+gKIqinI4iKRGArgmEEEh5cq9/NGuwvw38+j7bfg4sl1K+CtgMfGqO2qUoiqIoR2zVkhYsQ0MXYBoaq5a0nOomvRJ8G9UvUBRFUU4zSVMnbRvTU8QT5slNO3bEU8SllA8JIRbts+2+Gb+uBt4+R+1SFEVRlCO2cmGB7354FXeu7eckP6h+xVL9AkVRFOV0pGmC5rSFH4Q4XkTDC0nZJ29l9FyG8x8C7j3Yi0KI3xNCPC2EeHp0dHQOL6soiqIosR+t7ed7T/bxvm+uZk1v8VQ355XuoP0C1SdQFEVRTqTxqkcoQddhuOISRSfv8fucBNhCiM8AAfDdg+0jpfy6lPIiKeVFbW1tc3FZRVEU5WVgTW+Rr/xy63EHxGod9unjcP0C1SdQFEVRDqThheyaqDNUcgjC6JjPI5EIAQIxh607Msc9Vi6E+C3gRuAaKU/2EnJFURTlpWwus3/vWYftB5Fah30KqX6BoiiKciyiSDJccbB0geuHFOsebdnEUZ1DSokQgpa0zXDZwQsi2rMWmnbyAu3jCrCFEL8O/E/gKillfW6apCiKorzc7alZPTDZ2G/U+WgC7D3nKaQsinWPm29cxq82jTBcdvj5hqHputgAd67tZ7Ti0pq1eduKHlXG6wRQ/QJFURTlaISRpFhz8UJJPmkipUQTGkJIjmZWtx9GDJcdHC9AArqmYehQbviMVx3mN6cQCExdww9DhqaCbyEkhZRFd1Mae45KeR1xgC2EuB14HdAqhOgH/oo4O6gN/FwIAbBaSvmROWmZoiiK8rI0c9Ta0ASGrhGGRz/qfNsTffzl3esJD3IHfq6/BICpCyIpmTnT7LYn+vjIlUv45A3nHtd7eSVT/QJFURTleE3WPSpugKVrjFZdmlMWxbqPoQsKKeuIzhFFksFSg8Fig9GKw+6iw/k9OR7fNo4QkjPbs6ztneSMjjRSwvaxKhoa20YrBEFEZ95mxYImLl/agTUHQfbRZBF/zwE2f+u4W6AoiqK8osxcKx1GknddMp95TUlWLWmZHlVe01uczgi+72jz5+/ZyI+f3c1w2T2ijOF+eOC9vvrQdha0pHnvpQumR8JntkE5NNUvUBRFUY6XG4QgwdAEbhjFJbYKBoa+N1VYwwsZKTfwQkkuadKasdE1wXCpzpreSSqOR8I0GCo3sAyNiVqDbz08QtUL6cwncfyItG2Qs002DlTQpIYT+IyVXTpzNprQmXQian6AZerUXH/6mMQxBNwnL1+5oiiK8oq3prfI7snGrFHr5d15inVv1j7v+UY8wg1wx9O7uOVNyynWPX6+YYhnp0am58K96wc5uzM7Z+vAFUVRFEU5MmMVh6oTMlxp0Jyy6cjZDFdc/Cgiaxu0ZmzcIKJvosZo1cXUBFXXZ7jcQEr4+YZhaq6HjCQZ2ySV0BmtOQyUGoxUPSIpGa04IDUuX5rFCyMsQ5BNmmwdbqBrkrLrYxgauUSOpGnQ8EKGSi6GLig1fHoKKSzj6PKCqwBbURRFOSlmTg3XNMHyeXkuW9LCrT/dMCu4Xb19HD/YO5/bCyWf+fE6jjVdlqFBJDngWq7rl3cdMPu4CrAVRVEU5cQJwoiyG2DoAk1A1fPJeQaRhIxlUHECcklzehlYQtepuB4bBmpoCCquR/9EjeaMxaTrITSNFYub2D3hMDDp0pI2GZh0qNY93npBN0s7s6Qsg3O7clSdgNZcgmuWdTFecUhaBhcuaMbSNaquH2cfF+B4IUEYqQBbURRFOT3NDGSjUPJ8f4kNA2XCSCLZG9yuWtKCaWjTI9jAMQXXAtA1wa1vXk7feI2vP7ydSMbbuwtJ3nh+F8W6RyFlYRkanh8hxJGv+VIURVEU5dhoQhCGEZuHqhgGzC+kKDs+CdOYSj4GMpLoAlKWQcP1eWLrGP2TDeYXkkRComkwWnaYbPic35WjKWXT05Sm6ro8uc3B0jR0XeCGUEhZSAkJ06DmhWQsHUvXWdKWozltUqx7DAfxtPCaGzBedbFMjbJjkbKPLmRWAbaiKIpyUhRS1qxRZAlEUqJrAiklpqFRSFms3j7OLW9cxtcf2sbO8UMnotZE/Cc4QKlMSVyuY8NAiR+t7QfiNV63vnn5ftPCf+uyRXzzkR1EUnLrTzdwdmdWjWIriqIoygkigTCKqHse5XKAJgTnduaxDI2GH5K1dQbLDhIoJExeLDs80z9JseqycbDEgoJNV1OGrqYEq/IpCmmbyZpHyjI5q6OJtX1llnTaIKGvWMfUBG4oGak0AEFrxqbmhcwrJKl7IX4YkbJ0qk5A0tJZ0JoiZcRTxqNIHlWZLxVgK4qiKCfFzHXWexia4EOvWczj28exDI1bfrKBIIww9Nkj2AcTyUOPbpuGhoTpkXOBpFj39psWvmGwTCSlmiauKIqiKCdBJCUTVR8/AD+EUj0gCANcX+BLydaRMh3ZJM1pi2LDZ9tIlXzCJJKSUs1DF4JMwgAp6J+o43ghXhSRT9oU0hYtaYsIQVvGRtMlYxWX7uYUXhARSonQBLahTwf0USTj7UA2YVJ1A+p+SMrSj7qGtgqwFUVRlJPiQFOv23IJvvXojv0yfftHEFzvcbD4ui1jcd2yTpZ1x0/E/WB2KbCZ265f3sVTOyf220dRFEVRlLkngBBJ1QvIJgxqvs+LQzUKSQMpBK4vmah5eGFEU8qkOWvj+AE1x6cpZdKWT+K4Ic8NTTLheFzQlaOrkOTBLUMsKGRYtaSZTcNVUpagI5fECSKiSNKUMpmo+QShpD1rI4QglzAJwggniGjP2aRtg7RtEElJyjr6cFkF2IqiKMpJcaAR7N3Fxn7bBGDo4qDltY7UaNXj9if7sAyNm29cRrHuzSrDtSeh2p5tZ3dmVakuRVEURTkJdE3QlUtSbfh4Qchw1afmeJQbLpEUnNWRwdQ1NF2QsU0uWtRCezbBc7uKZJM6ri/Z0F9G6GAJwbbxGr7QWNaVi5OpGhqtGRuQ1LyI/sk6mi7obkoyvzk1qy2aJmjNJmZtSx/luuuZVICtKIqinBSrlrTEycQOMTptaPCuixcggduf6DuiOteHsmfKd7Hu8dGrl856beXCwqxAet/fFUVRFEU5MYQQLGhJIZFsGixRrwe8OFgmaRh05BMsm5cnmzDpzCcIwoiByQaREJzRkWNRS4qNu8tElDB1naQtsXSd9rRFa8ZCCo0glMwrJBkquXhBhOtFhGGE60dIKRHi6KZ9Hw0VYCuKoignxcqFBW7/3VXc+pMNPHeAWtbndmb567eez8qFBW6bg+Aa4tFwNeVbURRFUU4/lqGzsDnNWNWl5gXUgoCEZdCSiddQdzUlEULgijhni6lppBMGfhiRT1lce14Hq7dPEIQhZ7bnWNyeIQhCAikwTYgQdORtHD/E0nWkFDSlrBMaXIMKsBVFUZSTbNm8POunynPN1PDD6RHkYt1DcPD11XsI4qldYSTRpupWhtHe1zpyNm+5cJ4amVYURVGU04iUkrGqS8UNANB1nQjiINoyKKTt6UBYSkjZBpmEweahCr21OpGEznySd140n3zSIgSSlk7DC8glDOpehBsGJM04E3jdD0lbOvZR1rQ+FirAVhRFUU6KNb3F6dJYmmC/APrC+U185ZdbWbWkhVVLWrDNvdPJZ8biuga/dk4H7Vmb0YrLz18YBuLzvXtqevnW4QpP7iwyVHb56kPbWdCS5r2XLjhZb1VRFEVRlEPwwoiKE5CxDbRcks5cggt7mkAIWtMmxbrLwGRExtbJJ01yCYPeiTqGHtfQbs0Y+EFEdyHFotYMAJWGx8aSM1Xmy6Qjm6Q1m2BossbIkIvjh0w2ily0MF6ydqKc+BBeURRFUWBWaSwp4T2XLuCSRQU6czZvubCb/94wxBfv28S7vvY4m4Yq3HzjMjQh2GegGxnBR646g5tW9PCrzaPTQbrQBBJ424oebFOfdcy96wdPyntUFEVRFOXwNCFAQsMLiIDWbIKlHTmWtmcJInD8kPGaw/qBEn0TDQxN0JQ06cwn2VWs8fDmMXonGtScgCCM8MOI0apHW9bGDyJqXoAQ8Qy3ihuSMDWaUzaBL3GD8MS+txN6dkVRFEWZsmpJC4auIQBd13jbih5+8JHLWf3pazmzI4vrx8F3EEluvns96wdKRAcoch0RB+urt48TTM0HF4CMJN97so/3fXM1y7pys465fnnXiX+DiqIoiqIcESkhkBFDFReAtDX7wfhYxWPXeJ2KEyKlxAkiTF1nYNIhZesUUiaaBmgCL4yQUw/vk5YOCKJIUnF8RisObdkEfigZqzqkkzpJQ9+/QXNITRFXFEVRTp6pgDmKIu5c2w/Eyc8KKWvWdPEwkgjiWtWuH816TROwe7LB8hn1rfc8pZbEWcPLbsB153UwXHZ418UL1PRwRVEURTmNlBoeCUNncYvFQNFhm6yRMnU6cgkytkEYRiRtA1MTFOs+3YUk+aRG/4TGZN2nVPcxah6LWtK0pE1swyCfNJlseHhhSEcuhW3quEFEc9pmaVuGSEo6c0mME7wOWwXYiqIoyknxo7X9+GEcBAcR3PZEHz9a2893P7xqv6Rmuia4aUUPy7rz/NMvt9A/6ew9kYTvzahvvWGgxEjF5cHNo4RhhK4J7ljTTxBGWIbG2Z3ZU/BuFUVRFEU5ECklfiCpOAF+GFH1fdpzFn4Yb8snTbqbU3REkpob0JKxSZrx6HXSNkkYBgt70tiGTsXxGS65DJSKdGcTdDSnyHkGfcU6zSmb1ozFYKmBIJ4h1whCsirAVhRFUV7q1vQWuWNN/6yR6D2jzau3j89KaqYJwa1vXg7ArT/dgOPPrpsdTR3s+BGfvmtdvNZpqoTHNefGo9bP95dmnV9lEVcURVGU00Ox7lP3ArwgolSPp4jXvRBNFwgRVwfpyicpNTyaUxa5pEnDDwnCiOakgSYEfeMNLBOKVYevP1RktOLSkrK5aWUP17+qB1t3KTV8NA3qbsj85hSOH04nTz2RVICtKIqinFBreov8w/2b8Q9wU9P1uEb1yoUFvvvhVdPB9sqFBb7yy61HdCPcE3AHoeQXL44QyXiUXEPVwFYURVGU04WUkqobMFJukLFNLFNj55hPe8ZitOLQ05wilzCBeIlYWzYxfaw+VX5kpOqzoCWFF0aUax6PD4/TX2yQMXXKjs9/vzDEikXNOJ6cWqttUW7UmKx7mIZG2j7x4a8KsBVFUZQTZmZprn3TlQng7St7pkeXVy4szBppXrWk5YBrsA9KQCQlkYzPfX5PnpvfuEyNXiuKoijKaWCi5lFq+DT8kMm6Ry5pkbYNmjMJLNOgOW2jaeKAxyZMna58EilB16HSCOgbrZIwdBK6RsnxSegazQmDneN1zu/O4YUwVHIQQFvWJpswMfQTn+NbZRFXFEVRTpjV28ens4ML4IKePHvubRJ4prfImt7irGPW9Bb5yi+3smmowhVntnFGe4YjuR8uaU1P3zglsHGwPKfvRVEURVGUY1f3Qhp+QBDGWcFbUibVhs9da3q5f8MAG/qLTNYcwjDCD8J4fXbDZ7jUYPNQie2jFdwgiMt7uT6ZpMF587IsaI1Hvpe2Z1ixuJWGJymkLYSAMIpoydiUGkFcGuwkUCPYiqIoyglTafjTo88SWNya5rn+0vTrG4cqvO2fH2NRS4r2rM2OsRqjVW+/8+gas5KgHcjyeXkuWdLC7U/0IYkzkav114qiKIpyenD9kPX9JZKWRmc+wXDZZahUp2+8Rn/J4dHNY3Q1JVnclqY9n6DuhpQaHoYmEEJg6jqtmQQtGZNUwuCsVJ7+iTpXnh1nHX9s2wSPbhnj/Hl5RqtNZBImSdPA1AV1Pzyy2XBzQAXYiqIoygmxprfINx/ZMf27AJ7dNXnAfXeO19k5Xj/oucIjyEly97MDfO6t52Obcekutf5aURRFUU4PQRghBcxvTiGRCAShjPAlOIGk6obUHI8gCnGCiP5ijaoXMl7xiCJJU8pgaUeeSEqEBk0pk7SpE0aSppTFtuEqYSQpZA2KNY9d4zUuW9rKSMWl7kU0p624nrYfYepavKb7BFEBtqIoinJCrN4+ThjtfV6sCWjP2ocMpI+HBL7/VB8337iMYt2bDq6/8sut04nTFEVRFEU5+TQhsHSN1ozFpqEqCVOwpCWFhmSo7FCte4RAICW6kAyUPFzfpxFAEARUnYB8wmSorGFrgqrj0pxOUEibDJccGl6IRkQkBXUvoFjz2TZapSubIJMxEbqgv9ggkhJdE8xrSp6w9dgqwFYURVGO25re4qwM4MB06a09ScpCCU/uLB76RMfpuf4S63av4/x5eSoNn28/vhMviOthf/fDq1SQrSiKoignkJRx7eowkqRtYzqI1TRBZz7BgxsHeaavyO5inZCInG3SZOk0bIHjS1KWTiQjOnI2YxWB6zXIJU2iSOKGkkJap+z5NCohhqZRrgf0FmuYGswrpBmvujRnLDQh2TZc4fldJVrSFt1NSWxTJ5c0qbkBbhCpAFtRFEU5Pc3MFL5vIHvlmW3c98LwSW1PJONA+7n+0vS6bVUPW1EURVFOvFLDZ7zmoQFVN6C7KYkQgjCSOF5If8klacaJVUaLLnU7olhr4EtB0jLoac7QkbXIpixyiTp9uiSXtPDDCCSMVBwm6i4Jw0QIB6FFmLpO3fXJ2CZtuQRhFPHotgmaUwZVx+OszibQoCVlY+jx1HDzBGYTVwG2oiiKclxWbx/HC+JM4XsCWYD3fXM1rn8Ei6dPICFUPWxFURRFOVkafkjC0DB1jaoXj2TrGgyVGrhBhK1rDJRqTNYDvCAiiFxCoZE2BRnbYKzcwI8kdsXDsnSWdTXhRxHFuke57jNS8ZiXT6ABpYaDF0Z0ZpNkLR03DOnMJdAETFRL1N0QJ5CMVl3StsF5nVkyCYuEqWMZKsBWFEVRTlN76lXPTCy2J+g+WRk7D+a1S1uRwPXLu9TotaIoiqKcYLmEwXDZxQslaSueIu6HEV4gydhGvHzMgK3DFfonTSKg4fkMF11kQtCeSxBKiaZp9ORSeDLCd+JR7ISh4YYhSVtntOzwqvkF5hcSDFVc8kmLtKUxUQ8AwYoFBcbrLg0nIJ80yaUMkDBZ98kl47raJ4oKsBVFUZTjsnJhge9+eNV+a7D3BN26rmHrgoobnvS2PbZtnEhKnto5ATCd/EwF24qiKIoy99K2SU9BJ5ISe2qU2NAEtimougEIjcWtOcqNkIlGCALOas9iLRKkLJOxmouGRt9Eld3FEc7uzIMM2TJSp6eQIGUaBGFEQtdx/JAggo5sAl0ThBHkEyZL2tIYusZIqUEgoS2TwDY1Gr4kk9QZrTg0vABD12hKmdjG3AbbKsBWFEVRjtvKhYVZQeu+Qfc/3L+Zh7eMnbDrawKWtGUwNcHmkSpSSjQhiKQkkuD5ETffvZ5ISpXwTFEURVFOoH2nXwsh6MwlcYIQTQiKNQ9di0t2+UFIIW1zXlceBDzTO0bfeJ1y3Weo7OB4PpGME5w5QYAMJOfOayKQHpN1j74JWNicZrjsYuiCpoTBeM0lY5ukEya5pEVb1kZKyUTdIwgjJmouUsZJ2IbLIfMLKYSYu7JdRzz5XAjxL0KIESHE+hnbmoUQPxdCbJn6W/VWFEVRFCAOsj969VJWLiywrCt3wH00AYXU8T/rjSRsHamycaiClJJrz+3g1jcvxzI0NAEICKOpYHvGOnHl2Kl+gaIoinKkNE2QsgwSpo6hC0puwETVA+C87hzt+QRd+QSvXdpOyY1Hti0Dto3W2DnhEgQhXhCBIdgxVsFxA9pzNs0Zm2RCpzlrUnV8ql7I9tEa4zWH4ZLD1pEKuycb8XU1je2jVUqOT7Hm4fohnh8SzfF6tqNZ3f1t4Nf32fZJ4AEp5ZnAA1O/K4qiKMosFTfYb9sliwq8+5IFdOaSR32+Qz1njiT84sURzu7McvONy6ZGspleDx5JKKSso76msp9vo/oFiqIoylHKJQ3O6chzVkeWJa1pOrIJLEPDCSJ2TtRpSulU3ZCSG9KcTtCeMWn4EEWC7nyS5ozN4tYMHdkELWkbGUEQQBhBJmHRXUhRcyLcUJKydYIwoljz0TXBkrYs85tS9E3U2TVex49k/CB+Dh3xsIGU8iEhxKJ9Nr8ZeN3Uz/8G/Ar4i7lomKIoivLysXm4Mut3Uxc8vbN4THWxLUPjdWcduvxXEEn+/AfPcmZHlkjOfjStEa/FVo6P6hcoiqIoxyJjGTzfPzk1/RvKDQ/dMOnIWQxMNLhiaQft6TLr+ot05pIMVhyaw4ikrdPwAtoyCRa0pOlsSrC8O8dQxSNhaExUPapeQNY2KOmCqhfnftk2VGEk5VBIWWRTJnUvYkFLmvmFJF4kCSM5Xb5rLhzvvLwOKeXg1M9DQMfBdhRC/B7wewALFiw4zssqiqIoLyXF2uyA1g+PbT6WBrzurDZed3Y7v9o8Gk8XAwxdcGZbhheHK+yJp3eO19k5XsfQQM4YwRYaqmTXiXNE/QLVJ1AURXnl8kNJxtYJI8mLg5PUvIB5TSkcP8SXITvG6jRnLN5+0SJCGdFbrNOeTSDDCD+UFDI2S9szeH4EQiNnG9T9kGzCpCOfIJcwKTs+Fcend7xGNmnQlk0wWfeY15wgDCUjJY8Xdk/SVUihz/EQ9pwlOZNSSiHEQXtMUsqvA18HuOiii0515RZFURTlJFrSlmHraO2w++kadOYS7J50Dvh6BPz8hWF+tXmUq85qQwCtWZvl3Xlu/ekGkPH08Zk3mUWtGbaNVKd/DyP4j8d37pf1XJlbh+oXqD6BoijKK1cmoSMRTNYdyk6AwMEQOs1BwNLOPLoQLG3PYhkakYS2bIJtI1UmGg7nduTJJE2CSDJe99B1Qd0NKGRs2vI2Kdtgsu7hhxHNaZvhsoMfSpwgRNc0wkggJawfKBFGMFbzyNoG6YRJxjZIWscfHh9vhe1hIUQXwNTfI8fdIkVRFOVl5/evOuPIdpTwurPb93uarGt7111L4kRl978wzENbRnnbih6Kde+gdbcXt6b32/7jZwf4ws828Z5vrGZN79FPU1cOSvULFEVRlEPKJizO6sxg6QY9+SSFdIKy49EIJf0TdYJQkkmYIATNaYuaFxAiaU0lyCYMzmzPkjB12nM2bhBRcgKqbsBkI54tp4k430oQRnQ2JZlXSJK2dJbNyxFJSdnxGCg2KDcCdk3UeXjLCJsHKzy3a5LGAXLGHK3jDbD/E/jg1M8fBO4+zvMpiqIoL0MrFxa4ZNHhR4qFgJtW9LC8e3bW8baMzbXndWAZ2uxA24+49ScbeG7XJIYm0AXYpsZbLuxmUUuKj1y5hKvPbj/gtfYE6neu7T++N6fMpPoFiqIoyiFpmuC87jyXndHKmV1ZzmrP0p63aUpaDJcbPLNznKe2j6HJiKaUha5pnNOV4+yuHCUnQAjI2Dq6EEzWPZKWTi5pUHdDRkoNyg0f2xAgBPPySRY0p2jLJjA0QcMNWdM7zmjVpREEFOs+pqZRyFh4QUTVO/4A+4jHwIUQtxMnLmkVQvQDfwV8HviBEOJ3gF7gncfdIkVRFOWUWtNbnDV9et/fj9VfXH8u7/jqY4cshxFEsGmowrsuXsBz/eumtw+XXSa3jHLLG5exYaDED5/eRRBKIuC5/hJQQtfgXZcsmJ4u7gUR//LojgOOas+k5icfG9UvUBRFefkr1T0mGz62odGasfFDyUTNxTQ0mlMWhh6P10op8cIIXYjpbYfSkk5w3rw82ZSB60VEw5KdY1WqboAXRqSGKghNoyVj05lLsH2shpCSlG0wVvUII0lT0iTVrlN1QvxAMlZ1WLurgYwkPc0pVi5sRgjYXWxg6Rq7JuoUax7ndOQJIkjqBq0Zi1zSpNzw0XRBwtSP+zM7mizi7znIS9ccdysURVGU08Ka3iLv++ZqvCDCMjRuvnHZdLBqGRrf/fCq6SD7aALvNb1F7lzbzzXndtA/Uad/sk7FCQ+47+fv3ciqJS1csqhA30Sd4bKLBPwgolj3+Nxbz+emFT38w/2beXjL2PRxYRRPI98zXTyScSKVAwXQugZRFGczf9uKnmP8tF7ZVL9AURTl5c0PI8brHilTx/FDSnWfihdgafG6Zw1ozSYAGKu6VKZGl7vyyf0C1TCSBFGEpWtM1j0m6x6mJljankETGoYuGK+6DE46lBwXz4/YMVbF8QIKWZu0qZGyLAxDI20b+GGEH0m6m1J4QUQQRuwYrZJPmGgaFKtxnWvLiNthGRqaELhBQBhBV8Ymn7K4aFErhiEoOz75hEk2YR735zZnSc4URVGUl77V28f3BqdBxL3rB2f9vnr7+PSo9sxAfGbgvceeALyQsvjLu9cRRkfWhrITTJfg0ogzhEeRxDS06ezfKxcW+Pi1Z/HYtnHCGUPioxWXm1b0YBkafhAhRDwqPtNHrlzC65d1qiRniqIoinIUJPFIta5paJFgz+01CCMqbkDGNvCCiFLDnxVgu0HI4KQTr3+uu2wdrVKse6QNna7mFK0Zm9GKS3c+Qd94neGypN9tEIURKVPnnHk55jenSNsmk3WPlKWjCUF71gbi4NkyNNKWzvbRKjUvJJcyafjRVOIynZobYhka2YRF1Q1ob0ry6vnNNKUtAJpS1px9TirAVhRFUaatWtIyHZyahsb1y7t4aufE9O97Atx9A/E9gfcetz3Rx813r5+uQX2oaeGHEgEikrz7kgXctKJn1jVWLizwu69dzNce2j49Sv2rTSP8/lVn8N0Pr2L19nEGJhvc/mQf0VR28fdeuoBP3nDu9PGKoiiKohyYqWu0ZWwmah4J06CQtjB1QbHuY+iCpmQclGpCYGoajh8SRJJMYm+IKaWkb6zGcMWlkLbYOlKl5vpUaiHrxidZ5vg4rWkmaw66pmHrglzKxPEixp0aG4Yn2VWuc9VZ7WST8fVTlk4uaZKx9442VxyfppRFa94m0QhYsaBA3QvwQpPOXIIgklMP3gWLWg1qboA+h7WvZ1IBtqIoijJt5cLCdHC6Z3T37M7sfqO9+wbiM+tKr+ktcvPd6wmONareRyjhgY3DLOvOzwqK1/QW+fbjO2dNAfdDOd1WgGXd+VntvElNB1cURVGUI5bdZ9p0PhWvWRZib3CqaYLOfIKqE6AJyCX37u/4EY0gxNQExZqHaWg4VclI3SGVMNk0XGWs7jFScVnWnaOQtckkTHYUG4xUNEQoKdZ8nukt0pFPIkNJsebRkU1wbleO1FTbqm5AJmlyhpHjmb4iO8ZrpC2DrnwSx4+IpMTWNVJTo9mmoZGcg/XWB6ICbEVRFGWWlQsL+40U7zvae6BAfI/V22dP295TLuN4DJVdPn3XuukR6pULC9Oj6DNJYMtwhf/3iy2z1pEX656aDq4oiqIoc2BmcL2HqWsU0vtPsxYCLF2jRkjNDblwYRN96TqRlAgNKo0A29LonagxUnLI2SaprEFH3ma81KARhXi+jxdIHto0SmvG4uzuPBWnQsMPWbGwmZRtkLZ0+icajJTr1BwfQ4NKIySfqBAisAyNjG3QMTWabWjigO9jLhxvmS5FURTlFWrlwgIfvXrpfkHrqiUtmMbe24sAFjan5uSa970wzPu+GdeuXrWkBe0A98bV28dx/b3T1zcMlObk2oqiKIqiHBkviKh7AYYmkICQ0Jm3qDsh53XnOLcrjwGEYYjjRqSteH31eM1jcVuG9mySZfOyNKcSZJMGGUujOW3QnLGYrHpEkSBp6YxWXQB0TWO47DBSdXGCiPktaSxDsGm4SrHu0vBCyg2fhhcyWnEZrbqzBgPmkhrBVhRFUebEzKziF/TkeWpnEYinePdO1OfsOq4f165e1p3fL4EZwGjV2zttXMAPn96FF8Zbzu3M8tdvPV+NZCuKoijKCeL4IQOlBkIKDF3g+SG7J2tMNDyylkFT0mS85rOwPYtlmgyWG+hIam5ERgsRWlznOmObdOah1NAxDZ3FrRnySZuqG9DZlCBjmwwV6xiaIAhDdF1wVkeeZ/qKDJUbaMCilhS7Jhv0jddZ3JIhiCSuH7B5qEo+bXD5khZyqcScvn8VYCuKoijH7bYn+vjLu9cTRhJdE3P2VLgza7OgJcVEzWPHeJ0wistuff/pXSzrmj0ynbF1OvNJto1Up7dFEYQzVmlvHKrwzq89xg9+/3IVZCuKoijKCVB3AyZrLpahE4URT+wcp1jz2TlWJW1qzGtO4/ghtikwdUhqGmP1uOxXWAt5btck3bkEZ7Xn2D5RY15TkhWLC1w4v0DCNBASal7AiwNlLFOjd7xGJmEQRvG1F7emWdSaYnDSoeIG6AjObMsQAhv6i3FNbQTRBFiaxmVL20lac7ceWwXYiqIoynFZ01vkL3+8jqlB4jmdcvWByxexakkLq7eP8+CmEZ6cGhUPQoltzF7l5Pgh20aqs5KeHaglYcR+Wc8VRVEURTl2jh8SSUnC0BkqNXhxuEoYhOwYq1FpuGSSBhlLR2gaEzWf9qxF2Q1pSVi0ZCWtKZNyw8ePoD1j44YBxXpIVz5Fc9pkXlOapGlSaviEUYShCUxTpyltUWl4uGHI+fPyVJ2AbEKn1AhY0JpmrOKgCY1kwmCk7NKas1k3WKElZZGxdBp+hBeEmLrACyN0TWAbxxdsqwBbURRFOSZ7poTvnmxMB9dzydCgkLKm623vm4zEDSI+cuUSNgyWSZg6D2wcjtd5ESdVOVicr2vxOvGZU9pVsK0oiqIoRycMQ8IQ3ChitBKvhU6YGnUvYFEhyRM7JnhhYBKAyQGfpoTFGW1pBJJsUidnmzSlLUoNn4rjgaaRMjUmqg6FjI2pR/RP1ujMN9OSsYimbuyZhEmp7pNPGEzUXEYrDgubM3hhyMLWFEEkKTUCNCHIJS068wZRJAlDSVNCZ6CzwUTFJ0AwvzmFbeoMlhz8MF531t2UnFXH+2ipAFtRFEXZz5reIj9a24+AWfWnP3/PRv57wxAXzm/ivzcM4QURhq6hC444yBYceGR5Xx9+7RKKdW+63raGnHWddbtLbBqu8N0PrwLg4S2j0+W4br5xGRsGSvzw6V3xiLqAfNJkaVuGv7g+roO9J3DXhODWNy/nvZcuOLoPSVEURVFeARw/ZLLuEUpJbqpsV6nh8ciWMepOQD5lsWxeDlPXqHoBTSmTF4cqDBTr5JMmCBgtOSSyMOkEXDg/T90FTYakExHnzMvybP8ENcdHizR2lxqEUuCEAU0Jk7Sp051PUvdCio0IxwdTFyztyNI3XqclbdOasal7IX4osQyNloxFse5jGxotaQtD18gkDCpOwPXndcFU7e5UwsQPozizuC4YLjtEUcTitiz6gTKpHgEVYB+G67rcf//99PT0cMEFF5zq5iiKopxwa3qLvOvrjxNMRbI/XNPP7b+7ip9vGOKrD20HYOf43qRlfhBxNJUuvLFd+JMDJBZcgGYdPLHI49vHufmNy2bVsb75xmXcu36QR7aMEck4S+nq7eN89Oql+5UNW9NbnB7RnvmQAOAzd63D8eMn1ZGU3Hz3es7uzKqRbOWwnnvuOQYGBrjqqqtIpeYmO76iKMrpyvFDNg5Msn20hqFrLO3IsqCQ5MWBMnXXpz1ns3O8RmvWmg5I05ZGTyFFd95mou4yXnXQdEFTOoUgYrTkUvNDWrMJarUSOzespSTyOJkOgtCnKTSY0DWCUDKv02ak4lJp+OTTNp2AF0rSto5t6HQ1Jdg6UmVwskFT2sSYakM+ZZFPWbPeRxhJWjLWrCngbhBnFR+rONTcANvQCUJJsebSmj225GcqwD6MN7/5zfzsZz8D4A1veAPf/OY36enpOcWtUhRFOXHuXNs/HVxDHMR+7cFtPLlz4oD7S0AewZB0WC8xcd8/Ud/0KABW51I6f/PvEeLAFSN3FRtsGqpMB86FlEWx7rGsK8fDW8aAeBp4YeoGOrNe95re4vQItWVo3LRi7/f2mt4iP3x616xrRZFU67KVw7r77rt5y1veAkB7eztf/vKXecc73nFqG6UoinICNbyAibpPJmFSc3xeHJpkvBoHvMMVh4Spk7IN2rIWu4oOURjx4lCd7lyKntYMlqmzu+hQrHk0pQxqTkDdj8jaOg/95Pvc+69fxK1VEJrOBb/3BTILzmNEFywMQNc1tk/UWNAs2Dxa4TxdwzQ0MgmdUt1jpOxSdlxStj5d21o7wKiz44fsLjbQNCjWoKc5hanHfY/Riosg7kvU3ZB5hQSa0A5YpeRIqTrYh9DX18fPfvYzPvGJT/CFL3yBRx55hBUrVrB69epT3TRFUZQT5kCx8n0vDDNZ92dtO5ra1t7oTgb/7ePUtz5B/rXvo+nK38Qb2oo3vP2gx0zUPD591zq+9uA2Kg2fv/zxOv7uZ5v4xsN7j9GAYt3b79jV28enp5b7U6PcM18L9lmgbZkaq5a0HPH7UV6ZvvnNb7Jo0SLuueceFi1axDvf+U4+8YlPEEXH0RNTFEU5jSVNA8vQCMKQsapHQjMIw4iWrEVzymKoXOes9gy60PCDEDcMeL5vkoc2DdE/0cA0DRa1pHnV/DxL23NcsriFlCX53j/8FT/+8i10LzmH9/zlP6MnUux+8h5MTUMXgprvYxhQdyRVN2TzYIXHto3RP15n42CJdf2TjJYbbB+pU3F8klMjzwfi+CGaBinLIJLM2m9P9ZOUpdOWtwklREgKafOYPzMVYB9Cf38/ANdccw1//ud/zlNPPUUul+Oaa67hgQceOMWtUxRFOTHetqIHSz/8nO/UEZa0yFd2MnzbJyEK6Xzf39H0mveQWLwCgLAydtjj73thmK8+tH1vlvIZ909tKmHZvlYtacEy4rXhpjE7eJ75mqUL3nfpAr774VVq9Fo5rP7+fs4//3yuv/56HnnkEf7wD/+QL3zhC3zoQx8iDMNT3TxFUZQ5l7B0LpjXxNK2LOd15VjamaHmhGwcqDDZCMjYFr0TDcZrHhUv5Nm+MinLwDA0Kg0XGUia0iamrrOgJYnvOXzvb/6c5x64i1Vv/R0+8+Xv0HHeStJtPXilMepuQBBKhKaDlORsk9aMTco0aHghgYzYPV5nsu6zbbTKiwMl/uu5AX60to/e8RqOH+z3HpKWDhKqToCpa1gzqpC0ZWzcIMKPJEtaMyxoTjO/kDquTOJqivghGEb88XhePDpy7rnn8uijj3Lttddy4403cs8993D11VefyiYqiqLMuZULC9z+e5dx59p+fvj0LryDPBF+cahy2IRl7tBW1n3vM+iJDO3v/hxmU2f8QhjfAIV2fKUwggg2DVX2C45XLizstyb7SF5TlEMxDGO6T2CaJl/+8pdpb2/nlltuQUrJv/7rv6JpauxCUZSXl1zKIpeymO8GbBuroumQtjVSpk7K0pGAqYGIQlw/QBOSXWN1Rms+Y3mPpW1ZXjW/QN7W+JO/+CjrHv8Ff/Cpz7H89W/H90N2jNWJAh87mSWT1unOJWjPJXG9kHRSo+z41LyAXNJi02CFSMalOquNABC0pBM0pQycIB5l7ynMDnFtQ6enOUUwlQBtZvKylG2wcGrAYN9qJcdKBdiH0N3dDcDu3bunt3V0dPCLX/yCq6++mje96U384he/4OKLLz5VTVQURTkh9qxnXtad57M/XnfAkleHW3btj+9i5Ac3o9lpOt77Nxi59unXgmo8ZVvPNB93W+9dP3jADOAz12QfzWuKcjDd3d1s3x4vUQjCiFBKbr75ZoQQ/NVf/RW5XI5//Md/nLNOmqIoyunEMjSiSOJ4IcWaR7HmkU0azGtKM1LxCCJJytJZ1z/JWNXlrI40FTdi/WAJjYi7v/RZnn3sAX7nf/41l73xvUSBz2PbSpQaLm55nNaFZ9OTSxKh0/BDUqbOwpYMxbrHouYUJTfA0gSphEWp4aHpUHF9jFCQSuiYmoalH/ghp6lrHKzy1lx/Z6sA+yDW9BZ5bGsdO5Fk06ZNwN7yNL++rJP77ruP1772tdxwww089thjnHnmmae4xYqiKHNnT43ogcnGQetJH0pQGWf4BzeD0Oh53//mypXnxeuip0bDg4n4waXR1HXE50xbOjVv/2m41y8/8nMoytGKIslIxaHuhfQsXMx9991HzfF4YbDMYNlhXj7Jpz79GcrlMl/84hfp7OzkM5/5zKlutqIoypwJwgg3iKi5PtvGqlTqPn1jdVK2TlsmScY2QEQ8uWOc0bKLKSS5hEHfhEPd9enKJfnOP36OJ/7rLv70U3/F7//RR6nUPb7+8A5KDQ/Tc/AqRZo65iMRNCW0eJq5plFq+GwarFCsewRBxPyWFHk/JCQuE3peV5ayG3JuR5Yl7RlaMvap/rhUgH0gM7PPiub5PPLEGj5/z8bp8jTx30v42c9+xuWXX87111/P448/Tltb26ltuKIoyhyY+R14jCUgqT7730ROlY73/A1WoYuHt4zNGvH2Rnag5zvQ7CNPlLZvcJ00Nf7yxmWqfrVyQtW8gJobkLYNFp55Do7j8F8PP8WaYhIvknTlExTSFv/n//wfhoeH+exnP8vChQt5//vff6qbriiKctzCSDJQahBEkqHJBuWqx4bBEsWGR7tIUHECMrbBCwMlto/UGCw1iMKIjpxF2YmouAFD/Tt55r47Offad5G86E08tmUcDcFY1cELJBO7tgKwcOk5CCTJhEEkBZohaHghCVNntOLih5KqF9CZS+F6EaGI0IVGTyHJBQuaySWPPTHZXFIB9gHMzD5rdSxl/XMPUntsdqbbf3l0B5+84QZ+8pOfcPXVV/OWt7yFBx54gETi2OqlKYqinC5mfgceSfmtA8m/9j2kz7sKs6XngGu43YFN2F1nHVc7P3jZIhVcKyecECIuRQecf+FKAP77gYeY7L4cP5Ks2zXJGa0Z5hfSfOtb32L37t186EMfYsGCBVx55ZWntO2KoijHyw8jwlCSsQ2EBKGD0CIsTaNc99g5VqbiBDy/q0i94YKMkBImHZ+sZaBpJiUxn0v+5J85Y+liijWf1eUxaq7H0EQVy9Yo9m0AoGnhWRiGQb0RYlkwL5eh4fskLR0zEtS9ENsw6MjZ9BfrZE0DdI2kfXrlvji9WnOamJlhNrXgPJx6jdLA7ADbCyWfv2cjq1at4pYv/hOPPfYYN73nN5HH2htVFEU5TRRSFpoQaIJjHsEWQsNs6Tnga0F5jLA8gj3vnONoJWRPkyfVystb2tLJJ028IOKSC8+nqamJbRueIUISIQnCiLW7Jli3exIpdP7j9u+zaPESbrrpJrZt23aqm68oinLM3CBkpOwwVnUp1j3SyTijdzZhEUhJw/MZr3hUXQ9T15moe9T9CE1EmLoJmkEUSQq2TrZzPmPVkNGyy/axMsOVBmUPao2I8e0vYDd3UQxthst1BkpVDAGGqTG/JU1b1qY7n+RVPXma0yb5lEVT2kSYgkLaYmFLBts8fcLa06clp5E9GWb/7Lqz+eZf/CYAQf/6/fb7+sPbue2JPv51sJOmK97PvT/+IX/8yVtOcmsVRVHmzpreIrf+dAPh1MLrN17QzYFi7LBeOuYHis6udQDY85cfcr9DxfbGQcpzKcpcE0LQlk2wsCVNSzbBFVdcwc71T9OeTSCjiHzCRJMwUXUYmGxQJ8E///v3iCLJm970Jsrl8jFfW0pJFEm8IFIP8BVFOelGKy6C+EFjueFTSJrICIbLLguaU+Rsnef7J3ns+a30TzRI6FBIGrihpCmp05oy0HRBd96mJWXQktJoeD7Vhk+57hMGUKyHVHauI9GznE2DDfrGGziBRApwfJ+UZfDaM9tYtbSVixa38sYLerhgfoHXnNHOa89oZ3FrhiWtmeMqqzXX1BTxg5iZYfbMM89EK76Iw5tm7RPJOHutF0TkLnsX/lgfX/m7/8UbrriYrvNfc9QlYPYkFSqkLIp1T5WPURTlmOz5LtkTgB7Nd9Hq7eO4fhRPiZXwn88N7JctvLzmJwTFAZqueD9Y8Rrqo8nA6ex8Fi2RxWpffNB9LEPjljcuo1j3KKQs7l0/yCNT67gF8K6LF6jvR+WUuOaaa/jJT37CjUstVg/bhFFEPm0SSoEXRmQTJsvPPYf/963/4IPveBPve9/7+M737yCSgrRtzKq/uq8gjPBDiaELam7AaMWl7PjkkyZp26Ajm0A71mkliqK84oSRpNzwAcgmDBw/LmNl6IKOXALzIBm393C8kN2TNSYqHs0Zi+f6HVqzNq+eX+CpHeNsGq6w85G78IuDdL3ufUg9BVqIhmD7aI22dAIJ+ELg+hAJgaZpRFKjVA/xAXd4G5FbI7PwAgwDkpaGF4Ss7y8xMumQtgyWz8vTmU+BiLOB170AL4zI2AZVNyBhnT7BNagA+4hcd911fOtf/pUl14XUwr3/AYWIs9c+tXMCP4joftPH0fUS737Pe2h7/xcQTT1YhsZ3P7zqoB3BmUH1rT/dMN2x1QSHPVZRFGWmNb1FfrS2nzvW9BOEEYYmQAiCMDri75NKw58VUO+bQdwb3kb1mXtoe+tn0Oz0fsdLKQ8ZbEspcXY+Q2LRhYesgb2wkOTe9YNcv7yL9166gLM7s9PftaahcdOKA08/V5QT7brrrgNg/ZOPcMV1b2Oi5lHzArqbkrRlE0zUXEBw7TW/xpe+9CU+9rGP8clPf5a/+OwtlOoebbkEhi6wDR03CAkjOT3yMlhyCEKJH4VEkcTUBA0vIJcwcbwQL4xIHGfteEVRXv6iSDLZ8BmYrGNOfd/U/QDPj0haOl4QMVHz6MgdOnfUYKnBQ5tGCaTkkoUtaCJeJrukJcmDmwL8oW1MrLmHeW//DA5pnBAIwQIcPyJl+QQRTFRdEqZGpRZgmpBJCBoe6BJqO9cCkD/rQgwNHCfCDyMsIyJt6YxXA9b2jnPBfGjNxg8FEoZO2jamE1AmD1Z/6xRRAfYROHPlFThf+QpjW58juXjF9HZNwNmdWb774VXTI0TtH72IZRe8mt0/+F90/ubfIxJp7lzbPx1EbxgoIYG3TXUO92bqFURSTndsIwl+ELF6+7gKsBVFOaw9mb/3PKQD8MM4NZPkyL5P1vQW+frD2w/6OoA3spPkGRdjtvTgDm6htuEXCN1EmDb5Ve9EGIdeF+2PbCesTpBcsvKQ+20ZrbFltMbDW8boG6/xyRvOnfVdq74XlVOlZ9EZzJu/gF/e/zMuesPb6WlKUUhbJEydpmTc0ZNIkqbORz7yBzz6xNN89Utf4MILL+RVr30DThihIRBSUvMDMgkT29BpTlkEUUTaNqg0ItwgxLJNQhlPEddtHU3V11YU5QhMNnzGqy7DpQZJQ2dRWwbH3xtvQDxQeCh+GLF1pMqC1jQDEw2e2zXBW1bMZ6DU4IXBCg0/ZKR3C8klF+Pneqju0ydoWfVOam6cIHJg0qc9axJJjSiCmidJW2BYBrt2rKFp/lm8aukCTE2ycbBCJMHzQ4p1n4FilYFSHceHprTFa5a2kUkYdOQSRJE8LWf1qAD7CESd5yEMm8a2p2YF2FEEd67t53NvPX+6s7emF676g7/hp5//A8Z+8gXa3vaXfPeJvv3OecfTu3jHRfOnM/Uip/6BSElEHLybhqbWGCqKckRmTu3ew9TjEewwjA75fbKmt8ida/tZv7t02JrXRlMHTt/zAJQeu53E/PPRM804/RsoP3UX+cveechR7PrWJwFx2AB7pq8/vJ3XL+uctXRHUU4VP5Rcfe0b+NH3v8vNWkg+naLq+dTcgLaMRVM6rsEahhF9EzU+9um/Zs2zz/Mnf/h7fOhv/p1zlp9HIZkgjCSZhM6C5jRtuSSS+P+buhcgNEFXU5KGH7GgOUXKNMglrUNOL1cURdmj0vBYt7tI31iNqhtypZRcML+AqWuM1zxMQ6OQsg54bKnuMV510TVBNqFTrHoITdKeTjDZ8Bku1Xmub4Ji3UPLtRNsfi4+bp8+wcRTd5G88p20JQV1X1J1QjRNovsS1w0JJYjqOJX+F7nqnb9HW9qm7PhIBDVHIjTwA5e+iQYLWlJsHa1QaFh0NyU4r7sJXROnZXANKsA+rDW9RUYdSC2+kMaW1ZjX/T4IgR/Go0Lff3oXW4YruEFEPmny8NYxpOym8Gsfpnj/1yg9ejtNr33ffufdc7xlaNNTHm++ce96Q7UGW1GUo7Hv1G6AD71mMa9f1nnIUd/bnujjL+9eP53U7HDs7nMoP3kXQ9/5n1idZ5C75K3IKERLZGhsexI49HrsxtYnsLvPRk8f+XdbJFGzeZRTLowkw+UG5UbAZVdfx3f+9RtsefYJLrnyGrKWgWUInttVQmgQhCHVRsBk3aPo+rzp43/LP/3Zu/nO//44137yG+RzeVYubqYjZ9NXrNOasfGDiLxtYhgingJ5mk15VBTlpcMLI0arLprQyNqCYt1HTs2U7cglDviwTkrJeM1l02CFSEpcL6Q1azNWdRiruGwbKnPP87uwDZORUpXxeoTfeg6+c/A+QdWHqj/Vv2hEQARAUgNDwPiGJ0FKrIUrGCw3KNc8oigebEzqkLHBRFKs+rhBPL297gWEkUQ/TYNrUAH2Ie2ZcukFEemzVlHb8gRXNFcx2s/g5y8MI4EglDy5s7jfsdkVN+INbaX06O1YnWeSWnrJfvts2F3ity5bRDZpqmBaUZTjsmFw/0zF928cPuT3y5reIjcfQXDt7t6IPzmEEIL0ea+j5YY/ofzkXZSf+jHJMy4hufjVBKVhgsr4Ic8TlEfwhrbS9LrfOqr3pos4Y/jM5G3q+1I52coNH8ePaEqZXHHl68hmc9z70/9k1VXXEkQRv3hxhImah+uFtOUSbBmp0VNIUHV87EI7r//Y3/Cff/MHPPEvt/Ibf/J3TDZ8bEunLZug5vnsKtYBOLcrRzaXPMXvVlGUl7JC2qKQshnzGwRSUKn7PL97ku5CCl0IupuSs7JuR5FktOLQX6wzWGpQqrmMTq2bllHEjtEq41WXYt2n1LeW+vgQuhDYx9gncCPwgNLmxzFzbQzb3dRKDZK6Nl0eVAgopGwWt2fQdJ2ufILWbAKJQErJ7mKdSEJb1j7tHkiqAPsQvvbgNhw/ftKSOOMSEBo/uvNOmq/8zcMeK4Sg+bo/xB/dydhPv0jXB/8vZqE7fo14PcJz/SWe6y/xv996PgBf+eVW1XFUFOWYXL+8i4e3jM3atnW0xhd+tgnb3JvgbGZixXvXDxIcJrj2xvoYueNWcqveRn3To7i7XyS5ZCXZV1+PkW1l4v6vklxyEc6OZ2h7+82HPFd90+MApM687Kje2zXndgB7c1aoBJDKySalZLzSYLDk0N2UQugGF195DT+756dc++HPIDXBrgmHtpzN9pE6uYSJoQuytk5/0SFjhnSd9SpWvvOPefp7/5eN9/07V//pJzmzLUs2YfKrTcNk7PiY3ZMNkpZBueGTtAwKKfOosvQriqK0ZhKsWtzMUzsmCMKI7qY0YRhhTH2XeEGEpWuUGz4NP6TuhYyWHSqOT6nu8uTOSfKJeL30SLnOaNlhshFQHOxj8Ae30nTp2yhtehT7OPoEulunvuMZmi68HtcTJI0Q29bJpkx0zcewTAr5FJef2YETReQSBpaucVZbhom6RyTjUeyhksPCltRp9T05JwG2EOJPgQ8Tx43rgN+WUjpzce6TbU/nc8twhfteGJ7eriXzJBacT33TozRd8YEj+o+omTZtb/00g9/+OKN3fo7OD3yRVy/pYLzq0j+59+P5p19uYajsEkmpOo6KohyT9166AIAv3LeJiZo3vX1mgjOIg9Q9Dw6PRGPrk2QvfjP5S99ObuWbKK/5T+pbVpPwXbIrfoPkmauQvkPu4rdi5FoPea76pkcw2xZhNs874uvrmmBJa5p/uH/zdM4KlQDy9Pdy6Rc0vJCGF7Brosa6gRKBLxmpunRkLLoufB2V/7qLZ598hBWrrqQ9Z2NpAomg5Ib0FJKkLItFzZKFLUm8CFYt/kO+VtzJo3d8gw+86fW0XfRWBop1Gn5IhMQLJdmEycbBMmlTp+6GWLogZRmn7VpDRVFOTwtaMrRmEvRN1LFNjf6JOjUvxNAEjh/gBVE8Kl3z2DZaIZ8y2THSIAg9EobAjySTNY/dZY+EqUENqlufJH9R3CfIrHwTpaPoE2jEE8STQGtWY2THU8jQJ33Oa6ZmBcdlvGxdp6M9ybymFN35JBctbiZtG0xUHapuSLHh4/ohtqWTPE2rKhx3tgwhxDzgj4GLpJTLAR149/Ge92S77Yk+3vzlR3jX1x/ni/dt4sfPDsx6XQKpc15LMLEbf2THEZ/XyHfQ+qZP4I/1MXnfl3nnRfMZmJzdx+ifdAgiGWfMCyJ+tLafr/xyK2t69596riiKcjDvvXQB3/jNi0iYe6dYacTTrO7bMMRXH9yGexTBNYBR6MLZvhZ/vB9hWOQvfTv2/OVU1v6URu9zGNkWzOZ5hw2ug/IY7u6NpM+54rDXbE6bfOTKJRiaIIokX31oOw9vGSOSKgHkS8HLoV8wUnL45YuD3P/CIMNlh3UDJSxNoz2fQEqoOCEXXX4VVjLFU7+8h46sRc428MKQ689vZ8XCPPMLybhe7KIC+ZTNaMUlaev87d9/iXOWnc///Njvsnb9RqpeQEc2gWXoLCwksQ3wg5CGHzJec9g1Uad3vEZxxoMzRVGUI5GyDTrzCUxN4+zOHB1Zi7GKy4uDFZ7dVaRYc4lkRN0Lmaj61AKfjqYE85osNg2Umaz75JM6pqFjW5ApdNHYsRY53g9H2CcwgZSAhAAbyCUhiGD4uYcxMs3Y885F1yBpx2Hp/KYE7bkE2YTF2d058imLlGUQIkjaBoOTDmXHZ7Ts4vgh7Vn7tBq9hrmbIm4ASSGED6SAgcPsf9pY01vkaw9umzVafTCpsy5n4r5/pvbiQ1gdS474Gu95641syxT56b/8X/7hS19CLrxm1ut7pozv+XlPDVs1mq0oytFaubAwXc6qkLL45aYRfv7CMM/1l4DSUZ8vffZr8AY309j5DABmSw+ZZVdDGFB/8RESC85HiMM/q62/+DAQP6g8nHeunM+GwfKs0oUQPyx4zdJWPn7tWep78fT3kusXBFNJgSZrPmt6x5FE9I7VqTsBxZpLseaDBud35cmlLJYtaGfFa69l/RO/IGUKepqTDJYFhhBM1gP8QFJ3IxINwUDZYVFLhkLKopA1uf173+fK11zGb7333dz2nz+nqylJse5zbneOiZqHpmmUGwFuENfITlk6k3WftG2oTOKKohyVppRF01TG8JGyQz0IsDSNyZrHeDlivB7g+D5RIqJUc1lX8RipukREhJEgDEKCKMTxwD77NZiDm5nc+QxJjqxP4AOhjJ+0WgK8CNxKhfLWp2leeQNJQ6Mawlg1wjQiCmmLVMIkm9DozqdASsqOh+OHIOO8LPlkXBqxqyl52q2/hjkYwZZS7ga+APQBg0BJSnnf8Z73ZNiTxOxwwfWFPXksQ0NP5UksupDaxoeR8sgy7gL8avMol731d8ictYqnf/CPNPpfQACGJvjIlUuwTQ2N+B/M4rYMQRhPg3T9eDRbURTlaKxcWOCjVy/lvZcuiG9IR8nZtZ7Ks/cy+ejthLUi6WW/hj+6k9qLD9PY9jQAwjAJq+PEjwUPr7bxQayOMw47PVwT8C+P7uCRqRHr6e2AZWoquH4JeKn2CypOgOvHZWSKDQ9NCJKWSdn3EWi8ekGB8zpzzMsnObMjS3PG5k1vfTvV0iRPr36Y5mwCQ9NYv7vMlqEKxbrLRN2lJZegNZOgM29TrLtsH6sTZtr5x699i62bXuCzf/GnGLrOed15mlI2zWmbpGWQT5nkEgYjlQYTNZfBUp3hcoMgPLpZKIqiKBDXtXb8gJ2jNZ7fVWS86jBa82i4LltHqjy+eZSJqsdorcFEuYoExmoNel94hk2/+An9D92Oe4x9gpwJIeBI8FwovfgYMgzInnslfhhnFW9NxVVDBisu7Vmba87rIpMw6B2vM1x2GSjW6Z+s4wQhhi4wDQ1TPz0fOM7FFPEC8GZgMdANpIUQ7z/Afr8nhHhaCPH06Ojo8V52TqzePo4XHP5GJYHbf3cV153XQfa8qwhLw7i7Nx7xdSbrPl9/ZCeF6z+OkW9n7O7Ps7Jd8P3fv4xP3nAu3/3wKt596QJ0XWPbSHW6UymJR7PVVHFFUY7Fmt4i7lEG2EFljPF7/oHIqRK5NQb+9Y8IJofIrXoHCEFt44MM/sefU37qx+Qvf/cRTcvyJ3bjDW0lfd5Vh91Xyr1lDPcwdcG7L12gZvS8RBxJv+B07BNA/O8vbRv0NCWpOgEZy+D87jyLWlLYukYQSjwJKcugOW3xljf+Bk2FAvfe9UOKZY9ywwdgd6nOWMVl9bZx7n9hiDAK6R2rM15zySdMdo5VOX/VVfzx//gkP7vr+/zsjv+YHmFqSlnML6RIGDptORtT19g5VqcjlyCKYLKupoorinJ0Kg2PvvE6O8bqyAicQBICdTdgsu7jBwG7i3X6J6psH6uxs+gwWHQZHRpl84/+nlK9SniMfQIdqPjx+usQaADlFx7EauokN+9sDAMSJgRCIKUgZWoITSefNHH9KM4lE0YkLYMFzWl6mlN05VN05ZOnbamuuZgifi2wQ0o5CiCEuBO4HPjOzJ2klF8Hvg5w0UUXHfnw7wm0akkLlqHh+RGHCrPXD5TZNFSh4YekzroM8bN/ovbCr0j0nHdU19MTGTre+mkG/v1/sOP7n+NVf/QrIB5tWr19nCCM/xHNnDIehiqZj6IoR+/j33uGu58d2K829uE0tj6JPe9c8qveAUDyjIspPvANUme/hqbXvAcAb3g7wk5hNnUe0TlrG34FCFLnXnlE+wsRBzp7RFNPHfckalPfh6e9w/YLTsc+QS5p4gYRbhBy1dkdRFE0NUVbY0mr5LHtY6RsAyEkz/ePI9AwdI0b33wTP/r+7SSEhx+G5FMmPU1JhssurWmLlK3zVO8kFy1swg8kpYaHG8TJg/78k5/mxXXP8Gd/+nEuvmgll156KRAn9xNCoAuNloxNGMXbHD8kYZ2eIzaKopyevCDiqZ1FJmouju9RbHhkbYNaw2es3GDdYBXX9ai4PjKCshMHwgKY2Pwk1nH2CfZ9zO9VxqnufJ7WVe+gtcmmLWWjm1BthHQ3JTmjPYcbhGwfrWIZGkkrzm8RRGAZAi+QTNRctIagNXP6leiCuQmw+4BVQogU8UOJa4Cn5+C8J9y+axWLdW/67/96foAXBisAhJHk03etiw+yUqTOvJT6iw/TfM3vInTziK8ngcsuXsGrzv8SN//pR/jUpz7FF77wBWB2sC8EaJogDCVCCApTT7UVRVEO5/P3bOS2J/soO8ExHZ9YeAHu4GaC0jB6rp3kwguw3v05Rn5wM4QBTVd+4KhyUEgpqb3wSxILL8DIHjoRGsTfk3IqmZmmCWQUl+FQuSleUl6S/QJdE3TmEwd8reoEFNI2OVvngY0j+FHE/EKKtkyCa258G9/59re456d307XyOkxD0Ja1KdY8ChmLuhPgeAFlx2frSFxLNpeySOjxv/V/+/f/YNUlF/P2t7+dtWvX0tbWhhCC9qzNQKlB1QkwdcnOsRq2qZO0daSUp11SH0VRTo0gjKYfyu1rd7HOUzvH2DxYJp+0GCo10HVBW0uSJ8YqrNtVZGDSwwnikWZBvGZ6z8DjXPcJAOobHwQZYS27mpGyh4bg7K48WTMgkgIhJZqUTNZ8fBmhC8GStgxdORtL1wjCMO4fSBituMxvTh3vRzjn5mIN9hPAHcBa4lIcGlNPpV8KZq5VnPn3hQsO3nnrufgNRI0KjW1PHfX1ntpZ5DHOZcX17+aLX/wid9xxx3Q7br5xWfwPhrhTKjRBJCW3/GQDn75rnZoqrigKa3qLB60y8Pl7NvLVh7Yfc3ANoKfygKD85F1IrzG9rfWNnyCa+v1ouLtfIJgcIr386qM6LpLwa2e382fXnc07Lpo/nZtiZskxOPTnoZwaL/V+wYHomqAtbVFqBFQcn6SpsatYZ6hU54KVl9A+bwE/+v7t7J6osWbHOGNlh+Xz8mhCMFJxySV0+sYbCASmLugbr7J5pMoPntzFzjL88798h9HRUd7znvcQhvF4T8LUsHSNppRJGGkU0hZntGWoOSHjNRc3OPr8CoqivHT5YTQrB4OUkpGKQ1+xzu7JBv4++Rn6Rivctnonv9o0xmPbxtk5UiaQET2FJDU3ZKBYoxGEeEE8yiyZHVzD3PcJAKrrf4HVdRZmSw81H0YnXbYMlUlbOq1pk6Gyi65rBDKiWPXwgoiSE2DqgnmFFLahI9h/thvEo/U1NziiJcAn0pzMM5JS/pWU8hwp5XIp5QeklO5cnPdUetuKHiz9wE+HGx3L0dMFqut/cUzn3jhUYWzZu7C7z+YDH/wtNm7cyJreIveuHyScKtcVREz/7AURtz/Rx/u+uVp1IhXlFWTf4HFPYsYv3rfpgN8Htz3Zd8zX2pO4UUtkaH79RwjrJUbv/jzuwCbCWhF/fBdO3/PI4OjWf9bWPYAwE6TOuvyo2/SrTSOsWtLCTSt64kST+5ToOtznoZw6L7d+QcLU6GpKkU3qZBMmGduk6gSM1l1Gyi5nXvbrbHnuCTZs3UHJCRlr+LihpC1j05K20DWdMAwIidg6VmVwssFgscZ4zeGRzcNsD1r505s/zwMPPMBnPvMZxqsum4bKPNs3SV+xznjVYdd4jV0TNQZLDSaqLgMH6FArivLyNFn32DVRZ1exTs2Ncz24QTSdKyII48Bypt6JBnU3pC1lYSAYqgfouo7jR4zXfCp1n3ItxGXv+ug9ZzhRfQJveDv+6E4yy38Npq47GcLmMYcHXhxlzc4Jnu+f4Pld42wYKOOGEdmkSdbUqXvx9117zo7LG0fxz3v4YcRAqcFI2WH3ZP2UBtlzVabrZWflwgK3/95l/PHta9m9T91qoemkl11N+em7Ceulqac7R0foJq1v/hSD//YnvObaG2h6998RmcmD7i8Bx4/42oPb+PpvXnTU11MU5fS3prfI6u3j0wHk+765Gi/YOy16T2LGmSO5e6ZKr+ktHtPItYxCZBigmXtvUppp0/bmv6C0+g7KT98NQDCxm+bXfwRhHPmSlch3qL34MKmzX4NmHfz77WD8UPKjtf3877eeP/3+93w2X/nlVnZPNmZ9Hneu7Z/eR00hV+aSEIJC2uKCRDMVJ2DbSI2zOtP4geSFwQpnXv7rPHrH11n/0E9YePX7CIOInqY0QRSxuDVFw4eRagO8iKFKg0rdY9tIBU1GeDK+x7dnL+CK33gHf/u3f0ty3tlcfNV1QMTW4TKlWsAZ7Wk27J5EExqjZQehCWquzzkdeQxVuktRXlaiSFJxfCIJGXuqTJ+lE0pJseaTts04wZeIA8s9eRr2Hh9hGlAPAnZP1DEMwau6MyRsEz+QTDYcpGDWcHXAie0TAFTX3Q+6Qerc2UlPJVD1oeoH5FMmNS9EAK/qzqFpOmUvQDOg4aVIWjoLW9LTx9Zcn7oXIqVERnGiypobUPcC6l78uWRs46Quq1EB9iGsXFjgH9+zgnd+9THCfaYgpJdfQ/nJO6lt+CW5i99yTOc3cq20vekvGP7+Z3H/6x9offMnD/sf/74Xhvn8PRv55A3nHtM1FUU5Pe0Zjd0TUN+0ome/YHpPrgY/iNB1jd2TjelR23+4f/MxXbf4wDcIJoewOs7AaO4mdeZlaHa8nim/6u0E5TGEaRM5FcxC91Gdu775caTXIHP+NcfUNgl8/6ldQDyr6KNXL531ORm6hqEJwkii6xo/fHoXQSTVOm3lhDF0jZULW8glLMJIUnN9ntg+ztIzz6R5yQUMPX0/Ha95J56lM1CsYpoGKUND1wVNtkVNeFScgIoTr3msefGUTAvwwzrzrvxteja9wN9+6o/59NfuINkyHzeICCJJzQ1xA0l7zuDFgTJJS0ciaXgRF84vIETcPkVRXvomGz6TdQ9dCGpegKHFA20RkrQdh2+mrtGRSTBUdmj4AWOVODlyEEkmGz4SwRVL2xkpuTi+SyZtM1Z26C86vLB7komqQxDGa6/3LDg5kX0CGfrUXvgVqaWXoiez+72uAZau4fmSpG2Ss01qviSIXPon6ixty9LwxulsStCVTZFOGDh+yHDZxdAFdTdEE1D1JALJSNkhlBJdCNqyCQrpk5fTSn0TH8Ke0aT/9Zbzed+lC2ZNGbfaFmJ1nUX1+Z8fVU3sfSUWvoqmqz5IfdOjlJ+864iO+epD27ntiT617lBRXkb2HZ0WsN+06D25GpbPyxNFEbc90cfbv/oYb//nx3h4y9hRX7P4q28TTA5RuPp30BIZ/NFeJh+9DX981/Q+wkqgJ7MYTV1Hff7q8z/HaOrEnr/8qI/dI4wktz3Rxzv++TFue6Jv1ucUhBHnduV49yULePvKnnjK2NSympnrtBXleAVhxETNY6LmkbIMzu9porMpQVvGpqcljR9FLFx1Pe7EAHJoC93NadqzSRY2J6l6IRqCloxNPYiT9lUdqE4F1wAe4HkQahav/b2/Bs3g/336Y2zpH6NYdYjCkBcHy5iawNI0hIBQRpRqPn1jVTYNl+kr1inWVAkvRXk5cPyQhKGRtHT8MKI1myCTMGhKWrSk49FlKSVlNw7ERyoO1YbPEzsmWL1tlI39Jdb1T9A3UcWyBMmkiet67BqvM1yqU3F9XD8OrKeD6xPcJ6hvfZKoUSZz/uunt2nEo70CyFmQs3QWNKdY0pLhzM4MFdeld6xK73iNu5/v5+u/2sK3frWNO9bsYqziEEmJJA6iy46PH0ZkbZ3mjM1Q2aHmhEw2fOreseemORZqBPsg9h1NuvnGZXARbBmu8OTOOKjNvOr1TPzsK3iDm7G7zz7ma+UuuQlvYBOTD34bu2spiQWvOuwxn71rHQjUSI2ivEzMHJ02p0awb1rRM2vK85reIrf853q8GVNqjuP5HsJKkLnw1zFb52M0deINb6Ox8xmqG35F02vejdP7PJFTIX3e6456apVfHMTte578Fe9HiON/lhsRf+/99VvPjysuTAXZz/eX2DRc4bcuW8RUNS8iiaq+oMypkYqD48cPvlw/RAiBH8bp7nuyNjtGdM65/BrW3fmP9D95Lz3nrmDlwhYqbjyVE03DC8O4tnXWolJ3iBoSU0IYxVMzdR1kFJBt6+Ttf/55vnPrH/Lov/0Nb/nTzxMCKVsnmzSRQNIQDFRc2nMa1CUykqRNncmGRy5pnra1YRVFOTL5pMFwxYUgIpcwSJj6fuWogkgyWGywZaTMluEyQRCRNHQe3jbGZNXBCyVL29O89swOEJK+CYfhisuOsSqOG1IJZic0O5F9AoDq8/ehZ1pILH41ArAB2wRdgK5Ba8YiFILmjE1nIYUABiddNg9WGJx0KHserSkL09ToL9YYmHQ4uzOLbegMlRzCUJJNGlTdkLzQyNgGUSRx/ZCEeXLHlFWAfRAzR0m8IOLmu9cT7VMWI33ulRQf+CbV5+87rgBbCEHLDR/H+/c/Y/Tu/0PXB/8BI3focjYRwAHWYSqK8tI0s2zgzDXEM9dY3/qTDbOC6+NlNnUy+dC/oyVzJHrOw553Dmg6pcdux+l/AaPQhZY8tu+26rqfg9DILL92ztobAcW6x3c/vIp/uH8zj2wZi7OeBhEbBssI4mnl2tR+ijJXvEDGU7IllBo+thF33kbLDg9vHWd3qUZNmiy8+Bp6n7iPxfnPUq67DJZdJus+CVvQkk6wtD2NH4Q0nABDd2iEEgMxtZZSIkVcguayK17HyAf+iJ/925dYct4FXPrG92PrgomGh96AoZpH1jLjzqauoRsaXhihCcHMbu+eGXaqpJeivLSkbZP5hk4kJbaxf51n1w8p1T22j1bZNVZjsFhjcNJlfnOSsXIDL5R4IWweqlL14jJeCdNE12S8PlvEU8NnBtgnsk8QlEdxdjxDftU7EJqOJA6qc0kdTQM3kNS8kMjzmXTSVOoeOpJUQickAi2iKWkShBGlesCi5gS2pWHoGt35JJqI4zVL12lEIQlTJ580CUJJIW2RT57ch+5qivhBrFrSgqHFNyrB3ozeUbS3c6vZaVLnXEFt40PHnKp+77lStL/1M8jAZfTuv0EG/iH3NzT2y6gLqmSNoryU7SkbuO8DszW9Rd7z9cd5rr80p9dLn/c6situpPrcz6i9+AgAdteZJJeuorHtKcxCN3oic9TnlVFIbd39JJesPOzDwqNhaEw/fPj4tWdhm3un0F+/vGv6d8uc/b2oKMerkDKZrMbTK3UBfhRRdf04MY8X0Jq2aE+bdF58A4HnsvGRe3HDiMm6hxcEEMVTGDMJg4sXt/Kq+c3Ma8nQnU1iGDpBFBBJjZSpk0lZtGUt/scnPsGlr7uOn37zC/RueJqRyQYv7Cqyvr+I74dUHJ8dYzXOnZenLW1hGTqd+QTa1Oi144f0jtfpHd+bdVhRlJcOU9cOGlw/u6vIi8MVdo1X2TpSYddEA8eP6J2o4QeShguNABwPJhsOw6UaLw5MMjzZwNIFfgT7PnY7UX0CmEpuJiPSr4qnhwsgY8OS9hzzm5IQRdTckCACISMm6h6appEzLQpJi7aMzVntGRa0pljcmuSqc9pY2JxG1wSaFq+xNnQNx49ozdikbIN5hRTzmlMsaEmf9PwUagT7UIQgntkPmoinHWpaPJ1rj8wFb6C2/n5qGx8ie8EbDnia5pTJRP3wNzezdT4tN3ycsR//DRMPfI2WN3zswPtpgt957WKySXPWSNe+09rV1HFFeXn4/L0b53Tkeqb0sqsRhkVj29M4vc+Su+Rt1J6/j+QxlNXao7HtKcLqBJnr/nDO2qkJuPXN588a2d93xP/szqzKIq7MqSiSCBFnpdU1QdLUKDV8wihCaAIRRXhhwNaROgJYdNYyNi44kzU/v5OWi3+DuhtQdgImqh4LCmk04qnehayFPaFRaEniDPnU0AlCqLohTZbOZD2kv1jhdb97Mzu2buY7f/Pn3HTLt8m0dDJRdTirM0tbNsFYxWO84qIJwbLuHJq2txM5XnWxdIEQgtGKF09VVxTlJW+k3KB3rEbK1ugt1tkyWo2/OxI6actkaVuK3aV4WUsowXEl6YSJlCHjU3kahA4JCVoIzozuxYnoE8gopPrcfSQWvRqzqTPeBghNEIQhNVfSlLaRUiNhCZa0Z2jLpehqStA/0WBJe45XJwycIOLMjjSLWrOEMn4AsYepa/QUUrOuu++U+pNJjWDPMHP0d/X28eli7pGEcKqMxr4lJ+1552C2LqD67H8f9LxHElzvkT77NeQufRvVZ/+b6vP3HXAfP5J89aHt02sMZ7Z536zDiqK8tN32RB9P7TxxM1I0K0l62dXkLr0JpKT02Pewe84jf+lNx3zOyrP3omeaSZ5x8XG3b88Tdilh/cDsEfx9R/wPNgNAUY5FxfHpHa/RO16n7oXU/RA3jKi4ARsHSoxXXLYMV8jbBl25JJaus7AtzeU3vJORnZvo3/QcYRCRtjXcMMIyBMNlj61jNTpzSRw3ZPtIjaGKR+iH6ALSCYNUwmK06pCyNCYDk3f9xd/jOQ3+60ufpFar0ZFNsLglQ8Y26WlO0ZqzKdY9al5I1fHZNVFnuOwgiddohpFUa7IV5WWi6vpsGqoyWnFY01tkuOzQkdJJGRpCaCRsi47mDOd0NbG0M8PCliRJSyOKBFJGuAFYGgQhGDrMy+ukZkSDJ6JP0NixlrAySmbGQKQOGJqgWPfQDUHGtjBNwfzmNBf0NPOqeU30NKXJ2DqLWlP0tKRYsbDA/EIGKaFY9ageQ2nSk0WNYE85UFIzTQiiw2QQEkKQueDXKT7wddyhrdidS4+7LU1X/ibe0FbG7/tnzLbF2F1nHnC/f3lkOzvH60RSTrd5ZpIkNUVSUV767l0/eMKvITQdq3UBLb/+R8gwQOjHfmsISsM429eSv/xdCO3Inh535mwm6j5esPcJpgCuPa+DBzeN4IXxTKI71vTzthU9KoBWTrgokoxVXJKWThhJqm6AbWpUHZ+JqsvWsRojFYe+8TqNIKQ5ZROEEc0pm8te/yZ+9q9fZODJe0nNO5u0odOcMdk8UsM2NOp+iNOejIPfMMSQUPMh8nxMQzBSbTBR9ckkDHaN1WhKt3L5Bz/FQ1+/mad/8I/c9Mc309OSwtI16l5EzQ2QUjJadRgre7TlLQIvIm0ZRBpIJM0nsTyNoignjuNHGJrg3K4m1g8UKaRNSlWBrUc0JXV6mlKc25Vm94TDttEqnR0Jdkw0qPsBtp5ksFSj3JDoWkjaMrAtC436rGvMZZ8AoPrsvWjpJlJnXjq9zQRqXoRlRLSkNNIJC00I3r5iAct6mqi4PjvHqpiGRj5pYhs6C5szDFUajJQc8imT8ZqLbWqndKT6YNQI9pR9R3+LdY9b37ycI3nom1n+awjDpvrsvXPSFqHptL7pf6KnC4ze9b8Ja5MAtGZn3yC3jdZmlaXZk/znz647W00PV5SXieuXH30pjONxvDfSynM/AyHIXHDdER/zgcsWccsbl80aZTN1wUeuOoML5zdNbwtDNTNHOTmEAE0TBKEkCCWGLmjL2DT8kN3jdYgidk82mKi7DBTr7JqoUUibRFKwtLuF86+8ge1P3E+xWGJwskGx6rF1rMxYzcXxQ7aM1qi6LhU3xJXxaJKpgxcEbB6qkEsZlBo++ZRG3ZfMu/Aqzr7uPWz85V3seOS/MIXOWe1Zls/LoRH//1KqeYRRxETVQ0jQNUFnPkFXPnnAdZyKorz02LrGQKnO49tGqTR8lnflefWiAsvmNbFicSu5hMVoxSdhWyzraSKdMOnI21wwv8AbX9XNOfMKnNWRYdm8HN1NNme0pmnP6gcdcT3ePkFQHqWx7Wky578eocfLVAzAASZdaLgeoxWHlGmwtCNDU8YknzLpzidpyyZozdpECHSh4fg+u8ZrDJQcTE1DEs/QOR2pAHvKnhI5mohHpQspi7M7s/slADgQLZEhde4V1F54kMitzUl79FSetrd+mqhRYvQ//xYZhUzWfK48s3XvlMmZbRBiet2hmiKpKC8fZ3dm5/yLWsqIxvY1c3xWkKFP9fn7ppKbtR/xcb/aNML3n+qbdaN8x0Xz+fmGoemyiHuomTnKySCEoCNnU3MDyo6PrQtqXsgZ7VkWtqbJpW1CCe1Zm85CBl3TWNKeIWnplN2A866+icBzmFz3S3qasoRSYAkd14+wNEHkg2UYZBM6+YQgZQsMoSGlhhuG6JpGEIKm6WhAJODKd36ErnMv5vtf+v9Yu/ZJig2fuh/hhpJKI6TmhaQTBjUnQBeQT85ecx1FcjqruKIoL02hlNiGRmdTko5sklTC5PzuPEvas1w4v4mlbWk68ik6mxKEUUTVC2hK2WgIJhohF/c0c0ZrDl3TEUIQComdMJBDmwgblTlvb/XZ/wYpyV7469PbZsYxY/W46sfCtiSaEGwfqfLYllG2jpbZNVFnx2id4ckGXXmb9QNlnCBirNLg0a1jlOo+ln56Ln9RAfaUTUOVeHG8jDOG3/rTDfxobT9HeivKvvoGpO9QXf+LOWuT3bmU5jd8DLdvHcVf/gtBJHlk69h+bdIEfPi1i2cF1SqbuKK8PKzePk50+N2OSnn1HYz88K9o7HhmTs9b3/w4UW2S7Kt/46iOe2pncb8M6TU34OsPb5+1ba4/B0U5mMm6x+rt42wYmMRAMl73cf0AL4joKsRroC/syTO/KcNlC3MsakuRMi1MQ8PUTVZc+Gral5zHrsd/gqlJsgmD1pyNF4b4QUTS1klaJpmkTWs2wdkdWXRdI2kIBBpP7hhltOIwVnGp+z4SSSZpc9Xv30K+pZ2//tPfZdOOPnSIM48nTbxA0nB9CimLSAjCGcH0ZN2jd7xGf7ExaymGoigvLaYuCCNJGEZoelxRKG3pLGxN4voS04gf1k1UPaoNSXPKoqspQRAFGBrUwhBNh2zCIJu0KNYDRHmU3h/8FeP3fmlO2zr90P2MizDyHdPbZ2amMjUwDAOBRErB4KTDpuEq63aV6R2vsbA5RS5l4keSuhfieCG2rtOStcklzNO2X6DWYBMnEfr0XetmbfODiLGpzJxx/WsoJA+eDdzuOgur80yqz9xLdsWNc1ZzMrP81/CGtlJ5+m6szqVkll09/ZommK77+u3Hd/L6ZZ2sXFhQ2cQV5WVkTzLDg2nLWIxWj7zmc2PbU0w+9B+kzruKxKILj7N1s1XW/hdGvoPEkhXHfa5nd02y32CbjB84qO8z5USKooj1/SUmKi5jVYfnw4h5hRT5lEXd85FS8OqFBdKmRtkJCCLJ1bqgb8KhJaPTcCOeavhcd9P7+c4XPo05voWLlq2kJZ2g6vlI4qmOJScgZdlYhkUYRti2Tt0NKE76JC2wUgEyEkihMeg49I1WiSK45Ldv4f6//xif/Ojv8LXbfkzCthDA4pYUfiTJTdV+nah5dOWTBGG8hMzQBU4QMF516WpKnuqPWVGUY6BrGj0taRANTCGZdAJ2lRzmNyVpStv0jVfZNVkla+mc0Z5msFJnqOQwMOkQhpJyI6AehHhByFjFw3VqPPa1z4KEwtW/M6dtrW9eTVgrknn1DQfdJ2HBxQuaWdicZazuoWsCP5RU6i47xuo4QURHNsHybsHS9iw7xir4ocX8pgS6LtDnKN6aa2oEmwMnERICfrV5dO+URQl1P+QjVy7hijNb+ciVS3jvpQuwDG16qkN2xW/gj/fh9q2b0w+2cPWHsBecz8R//z/coa1AvLbqmnM7kLBf1vCZ68ldP+LOtf1z2BpFUU6ENb1FPnPXOj5917rpmSdreovcfPe6Qx53NMG1P7Gb0Z98AbN9MS2//kdz9iAQwBvdidu/gcyrb0CI4/8GbEqa2KY2a5mOrepbKydBFEEo4wfUadvCDyWeH9KSiuuxemGErWu4ESztyHF2Z455zWnO686SSyZozSZZ3p3j9Te+lUy+iRceuIP3X76Yq87tYHFrlpQV17btzCc5sz0N6JSdgHPaMrhBODWTDsbKAWM1n4QhqDoBVRc8CSOJbs55x5+x5fmnuflTn8AQcOH8Jha2prEMHdcPaXjhdA4ZTQi8IGJXsc5g0WGi7qip4opymivVPZ7fNcmzuyYoTZXWCiPJzrEqliYoJE2GKg4116fkeOyYqLFzrELdD9FlPOP1+d1FRisN/DDE1AWbRqqUGh51xycKoeF5PHvb31Ed2kn7Gz+BWZjbnC+VZ+KH7sklK/d7zSDOJD6/yebVCwu8elGBhc0pkoaOH0aMN3zyCZ2GFzBZ97AMnYUtKZbPa+LSJc0sbMvQlU+e9PrWR0qNYBMnEXp4y9isbcu686zbXZqeji2Jg9hs0uTj157F6u3jLGhJM1Zx2T5WY+tIldQ5V1D8xbeorP0pyy+6jB3j9TlZfC90g7Y3f5LBf/s4o3d+jq4P/j2XLFtCe9aers8dStgyHK+dWLWkBUPX8IIICfzw6V3cpDLvKsppa01vkfd8Y/X01M3vP9XHNed0sGGwzFzN5ozcOqN3/jVC02m/6TNoZmJuTjylsvanCMMi86rXz8n5nt9d4q/fcj7FukchZVGse6q+tXJSGIbGOZ1ZVm8vYpuCJa05WrM2jSCi7MTZvS1Do+aH/P/snXeYXNV5/z+3Ty/bi7TqXQIhARK9V9MxiGLHJdhxSRzHjm0MGGOqa+L4ZycGO+50ENU008EgoUJRAXVptdL2nT5z+/n9MbuLll0VQASQ5vM8eqTZe+bcc1dz75y3fV/H8+nMmliOR2e2SFDTqArr1CcCbE0VOfoTF/LYnf9LX2cno0c1EzIUDmxOYLs+7TmTtlSJ8TVh1jg2m3rLKuNC9ym5IPkgy2B5Atcvt9TRJYmcKRCjjqD2sPN48aHb+cn4qVz7na+RCBkoMrzwVje27zO1IUZE1wjqCkFNQUKiJqyhKwqOJ9DVj2bkp0KF/R3b9VnflaM9a1Iouby5NUNjMkBQV1nTmSNqqHTlLXrzNn0Fh6LtEtIlpjfE6C26bO7OIfmCVN6iZHuoqoOKT3u6RFRXcAUoks+bj99G1xvPkTj2s+gjGMHv6xq6N2NtXUni2M+N6HSXKaeH521Y11GgpbZAS3WEUdWCzrRJb96kr+AgydAYC5IzXaprwkQC2vCTfQSpGNjAJfNaaO0tcMsLG/FFuS/bYeOrWdOZw3Z8fMrp2JoqkwzpXPrbRViO/w6RMUAziBx4MtlX7mP9ps1MmzieNzv2jmBAWfTsSjpv/Tbd9/+QxYHrWbI5NWQN97+2ne3pEt85bRqzR8UHxYE8X1TSKitU+AgzkHUygOfDE6s799r8Qvj0PPwznL5t1C+4fkgt1N7AN/MUVj1DaNoxKMHYXplTiLLwyVePe/+tDytUeLfUx0OcPMNgS28egURQk6kJaFSHddJFm6zpEg1o+L7AcX26cyZtfRaKYtNXtEDAW+0ZGuafibj9N/zq1//Dv33nagxFQVUVhFRW+M6ZNiXHw9BUVBmCskTWAUdAdRBMFxzXQVXKm1HLERT7v/hDR36GSOcm7vnldTSOncSCT5zAyu0Zegs28YDKK5t60TWFmohBX95kW6pEr6YwtT6CWumLXaHCRxaBwHR9ZFHuZrA1Y9JZsgjICj25EolggK5iCUV4SMKj5DggFNZ35chago5MEUmWUGwZRZLRhKDXdPB9n5ItkXcculYtZttTfyI87Whih56/168ht2xkp7sGDOovSuALgZAE2/tKjK0OE9I1PFeQDOts6i6QtxxGVQWJGB+vTggfzbj6h8Dlp0/j+nNmocrlmus/vLyZq8+YwTdPmcKN587iokNbOHpSLXcuaR1mXEM5igwMivuklj3CW51vG9d74xdtNEyk+rR/xWpbRd+Tt4wowPbK5hQLbnl5iPKuLJcVxivCZxUqfDSZP756SIuqvU3mxdsorV9M8oQvEBhzwF6fP//G3xCORWzuGe/qfapcFmjRVXmYEmglHbzCh005KqRSHdbJWR6qLJMI6dRFA4Cgr2jR2pfHtF0yRRdNk1BlAImS5VEXC3LknOnMmH8c9976R7pSGRRFoq9gkyvZ2K5A8gXr2gsUHRddlfBliWCgHP3ImFByIJX3CagSqixjaBIByvorsqzQct63CVfV8/trv8bSVevIFT00VSFddLFdQUhXSJcc+ooOY6qDaEo5Ki7LEkXLZWtfuT7T9T6qUkEVKux/GKrC2OowluvTmzdJ52w0AUXHIVW0EQgUJBRJxRXgezJhTaYt4xA1FJAUenMWfQWPvGVjuS6O62G7Ll0Fi66trWxe+BO0unFUnfa1vVouBuD1O93D048ddLprgEH5e78xJhPSFZJhhWhQo2jbxEMKvXmbnOnSlAwxvjbCUZNrOf3AJibWx6iOGHt1jR80lQj2DqSKNr4QQ3phf/W4iSzbkuKah1btkfKmGq8jOGke+dcfJ374Rcha+QOxt766wtOPwe7aSHbxveh144iOIBzgekNN7+mN5Q/3O4XPoBw5q6RdVqjw4TJ3TJLrzp7JVfevYG+3dCy89SKZl+4gPOtEonPenQG8JwjfI/fqXzFGTUevn7DH7ysbCDKfnDuK8+eMAsrPo0o6eIWPCoos4foC2/WRJAadYDnLRQjImy6m7VId1mlKBMhbLttSJUqOS8lzQYLerMmskxew8uWnefzh+5h93DlI/aq/mZzFpr4SkYBCX87BcSUyRRcBBPVy9Dogg+VBriiojkjoqkzJ8sADVYOaZJJpX/sxD97wBW769y9y3S134qCRKrhMqk8gfEHBLqeQxoMaYUNDUxU8X9CZMzFUGcvxSBVt4kGNrOmiKxKx4K7FFStUqPDB0pwMcfw0lbfaM4DPpp4ikgSqIhMyZGRZwxUSo2vCrN6WxvdBl30m1cewHB9NhoCusbUvR8l20TUZWZIwszm23nMdkqpRd95Ve71cDCD/+hMI1yI690wClGutgwZEgyoBVWZcbRRVkYjoBjVxjfp4mLlja7Fcj0RQIxH++D9/Kgb2Dgz0wnZcH02VB6O+P39y7TDjeqD2eSRic8+kc+3LFFY/R/TAk/f6OhNH/wN292b6nrwZrXo0gZZZuxy/4JCWIcJnjlsWPrt3eRuW46PIEteePZNL5rXs9bVWqFBhz7hkXgsrt2e4bXHrXpvT7txI7yP/idE0leqTv7rXvdQApQ1LcdMdJI757Lt6nwBcz6c5ERw0pCsGdYWPEgN9pC3XpyZo4AtBV8akPV1kW6qE6wkaEgECukZL1MB0y8ZxqmCTDBlEdIW87TP7gk/w8l8ms/DP/8uMI88kFtbZ3FOgaDvkSg4gkGSZaFAjZimkCh5QjvS4bjnibPvQW/SJBxRCAWgMl1vUTKyJYOgNVH/3J9zy/a9y8w1XcNG3buKgUUlGVQWwXEF9NICKhGX71MUMGuJvK4jLkoQkCRzPZ+X2DPmSi5BgWkOUulhFabxChQ8TVZGpjgSY0pAgXXKZUBPFcn22ZQo4btkIyRZsIrqGpAhkRaEzY9KcCKHIPtuyJp4HpueSt2Q8x6X9oZ/gpjuov+gG1HjdXl+z8D1yyx/GGD2TaN04ogFwBeiaiqHpjKkOMXdsFVVhnVBAZ2x1EMeDTMlGSFCtfLwi1TujkiK+A3PHJLn1svl84+QpgxHeS3+7iBffIYCmynD9ObM4cFR8xHmM0bPQaseSW/bgB6LUKckKtWd9GzXRSPf9N+GkO4aPAcZWh7jx3FlcMq9l0Hmg9NeSCxhMdXd9wfceWDmoXlxJJa9Q4cPh/Dmj0NW981j2Cmm6Fl6HbESoPfdKJPWDEQbJLXsAJVpDaPJhezT+6Ek1g//2xe7bkFWo8GEhSRKJkE59LEBAU+jJW9hOuQ+r5XkUHJdsyUFTJYK6guX42J6PLEG25KIpCqOrQkysi3PaRZ9n41sreenlv9OVLZG1LDrSRWzPpytnkQxqqPjIsoqhQSKoUB3RCOoQDUiEjPLeoz4RYEZzknHVEcbVRJjSFCdsqMw64kRO+NQ/s/Tph3j27t8RMCQKlo+hyUQMjWTYYEJ9hMkNMTRFRpElaiMGpuMjyxJhXaVgetREDQxZYktvgZzpULJd2jMlOjImTiWNvEKF/1NUuVxSki05uJ7ABwKajCarGJqK7wu6ciVM38f1JMZUh5hYH6SlJkAsaOB4Eo7nYtmQKvpsefIP5Dcuo+qkLxEYPfMDWbO1bjFetovmw84iHgLLAdOBqrBMXVRjTkuSIyfVMX9SHcdPq2dsbRQkyJRcLMdH/4iqgr9bKhHsdzB3THIwivKrZ9YPKnHvyPFT6wejTa+3ZYbNIUkSsYPPovfRX2C2vkFwzIF7fZ2yEabu/O/R8adv0L3wehou/TGyESqnsAmBosgcPrGGKQ3Rweu69bL5gynhAHct2YrbH4b3fMHti1u5Z+lWkCRcr9JDu0KFDwPff/+bWOE6dN9/I34xS/2lP0KJfDD3sN21CXPLGySO+SySvHMBkpOn11MbNThvzigWbezlhXU9CMoe3lRxz9uMVajwYSJEOeLr+gJVUphQHcQVEjURA+H7FEyHkKbQmzNZ25ljfF35+7dNFLnwoou49Vc/4uUH/8xJxx8DQNYU1BoCSRJ05kqkCha+LxCAJwT1IQ0JCSGD7wmQIB7SCRoqhixRFdFpiAWxHI9oSOfki79Mqm0DD//+P5GSozj/vDOZXB+lM2eSLTkIEcDQbBL9Tq1IQBtU5HU9n6Au052zSBds6hMBunMWWdOhJmzgC5/evE1DfO+nk1aoUGFkBFAV1vD9IEFdwXRdFElmen2Y59Z20503kWWVoukiydCdLxFQVZJhg3BQxbJdCibkPciueIreRfcRn/MJorNP+8DWnFv2AHq8nqqp84gGdCKGQFUEiXCQUckwc8YlqY0HqepPA89bLkFdpTYaoGC6uL7AACzXQwgwVPkDyb77oNk33AQfEANR35H+W7/4p6U8t6ZrxGMA4enHIofi5JY+8IGtT6tqpubsy3F6Wul5+GcI4TOrKcZFh7aAENzxSiuX/nbRiJHouWOSXHv2TFRZGrwGATheWRF1x1TySjS7QoX/GxZt7OX9BomEEPQ+8d9YbaupPv1fMRo+OBXu7NIHkFSDyIGn7HLcgaMT3HDuLOaOSTJ/fDWGVs6m0XcQMqtkzlT4qFMd0UGSiAY0EiGFcFCjLmbQ2lfg1sWtPLKqg219BRwhmNuSZEJthLTpEdIUukqCY8+6iFWLnmbLpk2MrYkwvSHM1t48rX0lenImqZKDkHxcH0q2T0hXGVMbJq4rxII688ZV4flQHTJojAVQZY1NPQUsTzC+OsyM5hif+fYN1I6ZzOP//T1WrVxNwSob/WOrw8SCKumiA5Q3r6m8RV/ewnI9VEXmgOYEE2rDNCYC1EcNLMcllbdQFQlVlvB8n0zRpj1TImc6H/L/RoUK+z6aIqOpMo7n47o+QUUmYCiomkLJ8/GQaI7rZE0X4fsoSHiez/a+Imu2Z0ESBA0otr1J7+P/j+CYA4gf/wU+qEZXVsd6iltXEZ5zBpanoCkQMBTqo1Ea40EaEiGSYQPT8bDccimMocookkTBcpHlss5E3nTYliqxPV2iK2d9QKv9YKlEsHfBQNR34fI27ly6Fc8TKIrEU2917ba/taTqRGefTualO3D6tqFVNX8gawyPP4jkCV8g9eTNpJ/7E5tP/wInzWjA9d8Wa7t3eRsLl7dx99JyxHogMj2lIcqFh4xmfWduUHVcUK4vlwFFkYe9pxLNrlDhg2P++Go0RcL23ntpSW7JfRRW/I344RcRnnb0XlzdULxCisLqZ4kccDJKMLrTcbIE29Illm1JDWYI7ZhNM3dMkmVbUoMijKrytvBZ5XlT4aOEoZZTvpsSQbpzJiXHJ6QrLNrQTd52USSJ9qzN9OYouiqTLtpEdIWwodCVtZl7ygL+eutveO6Bv/CpU35BVVDlb6s7UCRBseSQyguE5OELaEgo5BxBQ1DhwNFxPE/CpvwdXbBdArqOIXm4AjZ1F8iWXCbURpjYWMcXrv0ffvX1Bfz26q9w3Oy/MWfqOIq2i+OBoSn0FWy2pwr05G3qYwGClsqoRJCs6eD4AsvxWbShh/acRUSXCeoqDfEg8aBGT94iqCl05Sw0RSagfbxa51So8HFCkSWaEyE6cyYvb8iiKhDQFTRJYWx1hC29BWRFoTkZwvY8bM8vO+mFRVvaJJcv0dXeRfd9N6DGaqk++7tIisoH5R7LLbkfSQ8SOfBkLBdsV5AwFHRNZlxthNpogLzp4QsJ0/GojRpEAxrNySCOV04PVxWZnrw1aHh3ZEr4wqcqbGCoH5/nTSWCvRvmjklyw7mzuPOLh/Hvp0xhwcGjRzSuFYlh0ezonNNBUcgufYCQ9sH8qn0B0TlnEJl9KtnF92CueoZcyUGWJOR+1dN7lrVx6+JWbG+o0X3xbxZx++JWlrWmh815wrR6Pjl31BBDfdHG3g/kGipUqPA2x05576IjxfWvkHrm94SmHEH8yEv24qqGk1v+CHgusbln7XKcEAzLppk7JslXj5s4aEDvKMJouz63L9559k2FCh82iizREA8ytjpEVVhHSOC5AtcTuMJndDLExLoYE2pjTKqLsDVVRAhI1DZw8HGn8fQDd/LG+u28ti1LTTiAafv0lsr1lUKUFXclSUYRLtXRAMgypufSniriOh6pgs2WngLb0xaZvElAhaLlsGRTD2s60/QR4rAvXk+ur4dvfekzFIslVFkiYmhUh3Q60yVSRQfT8Sg5Lr4vsFy/v8bcx/E9tmdKVId1NEUhW3IYUx0qZ/RJEqoigxBYroe/t9seVKhQAShnoxVtl7zl0pd3qIkYjKmOULLLQoSjqkJURXSmNEQ5eFwVM5vi1EYClCwHu7+cRfEstt97LcJzqDv/6l06w98vbq6HwlsvEJl1ErIRxgJMx6M5HmZ0dZBEUGdUIkjRdpFl8DyfnOkihEBTZEK6Wu7c4PkEVRnT9dnUk6cjbZK3HDoy5sfqeVOJYO8hA5GXZVtS3P5K6xAFcVmCLxw1nlXtWarDOg+8th0BKOEk4enHUljxFLmjPjXYC25vI0kSVSd+CTe1nTX3/pSfa0kCo2YgS9CUCLK5tzjsPes7c4PK6CM5DEqOx/lzRrFweRuO66Mo8pAoVIUKFfYuA1Fcy3lvOeJ292Z6HvoJev14qj/xb0jSB+c/9R2L3Kt/JTjxULTqUbscKygbDabj8/k/vMIZBzRxXn90etmW1GBrLlWRB59JgrKhvWhjb+V5U+EjiyRJaIrEIWOqKJRcPF8wuSFGfSxI0fGIBFSiAYOtqSIBTaI773LUuZ9h8ZMP8ec//p4Zp15MyXXIWw6KDAhwfYgGIChLBAwNSfi09ZnISrm3tul5dOctFFUmoJrokoSslkWEgprMuq48jXGDhonTOfxzV/D8Ld/n0s/+I7/8n1tAAUWSeH5Nb79gmaAzazCtMU7MKEe2S7bHhu482ZKLIllEgxqaUjasg5pCUFfImy4F20XkIGe6NMQCZaO7QoUKe41MyaG3YON75Z7XjhD05C0iAZ3mZJC+vMXY6hBzxlRhOT7ru/PUuT6qBH15i968xZt3/wS7p5W6C65Bqx69V9enAm7/vwNA3/KHQQiiB5ed7hogST7Zkk00pLG+M0dIV+jLW9z/ag5FkRmTDHPM1HoSIZ2asE53waLkeBiKREhTsNzyczRb8ogFZDwhGLlw96NHxcB+l8wdk+T6c2Zx5X0rEJSj1l88ajx/eHkztusjS9IQUbTYIedQWPEk+VcfJX74gj06h5dPkVl0F75tEhw3h9CUw3cpIAQgKSo1Z3+Xjr/8O90Lb6DhH/4DLdEwonHtCVi6edeRodNmNg5Jkb976VbueKWVhcvbKqniFSrsZQbaAQ4o++8OGdjRDPfyKbru+QGyHqT2/O/tUV9LIXxK65dQ2vAKyAqxQ89DSzTs0XoLq57BL2WJHXLOHo0fIFNyuXVxK3cva+OaM2dw7cOrsN2yoOKxk2t5YnXn29coSYP12RUqfJRpSoQ4enIdPXmrvIEs2sj9ojx2ziVm6CTDBvGgRikxm2lz5vHoXb+n5ZjzaOs1iRsKWeFR9CCgQSio4eGTCOoIJDzhs+GFv7JxxRL0qiZis05AjVVTHQ4Rj2mkTBPfl8kWXTIlB1n4BAMatQccw4Fnfp6XH/od1900npMv+hKSLBC+R6ZoU3JdhPCpmlDFlr4imiKRF4KmWJCQodCZKaeDT2ksR71kWaIhFiBvuoicIBrQKFguRdut9M2uUGEv4vllY1oBoiGNUVVBaqIaOdPFclze6shTcgSjkwG6sxa18QBTG2Ist126cyaSBKse+DXdqxcx7hNfgnFz2BPXvZNqx1z5FFa6g+CEQwhPLwsySv1/fMCg3NFAVaFkQ0CCgFZiw+uPUTX9cBrqG5BkqAkbJKI6jhAkQwZIkC45bOwu0JWxGFsbZH13jkkNUXRFplcSlGyPiKGSt1x0VdCcDJEtuRSssiNP+xg58ioG9ntgSkO0LDrQ3y87Z7mD6Y1CCCQY3CTrtWMJjJtLdvlDxA49b49a5XQ/8EOs9jXIeojCir+hVo2i+pSvEGg5YJfvU4JR6s6/mo4/f5Pue35Aw6d+ghyIjDh2Vzfa0ZNqBntizx2TZNHG3mGp4hUDu0KF98eO0dsBQ3NAWVtVZWaPitNXsNnYU+CdSSY73r++Y9G18Hr8Upb6S36EGq1hd1jt6+h7/JfYnRuQAxGE62C1rqDxH/97t2qdQvhkl9yP3jARY/Ssd33dUH6OPLqyffC56bg+NVGDgCZj97cNuvbsmZXnTIWPBbbn4/iCun7jM120qY8FEQgyRQcXDwVB3naRgGPO+yy/vurLrH7hcaKTjsRSwRYOQc1HkiUkXyAkGRkoOi5rFv2Nxb+/CSNWjZV9Cump26madx6Foy7CtgPEwkF836Mja/Wne5dIuh5hXWPGJz5Pz7bN3P+bn1E/aizNBx2DEIK2VJGQoZG1HB5ZsZ2W6hjVYY3t6SIBTUUSMhNqwiQiBu4OmhCSJKFr5Yi27fp4vkBTPj51kRUqfFQpWA55yyOoyeStslGZLjhURVzqYkHqY3FypsOrW1Lomoxp2+RMl56CiabKRAwZ03KIGArPPriQdU/dxegjzqF+/pnkigLXBwuGOfElwLNLpJ//M7nlDwOghOIUVj+LHIoTHDu7nInWP94FLB9UG0IKjKoO0v7So7ilPAecdin1TQk816Voe+BDQzxIcyKA5Xl0ZkukLBvTcyk6AlVW8F2BQKD07z0G2gFGDQ3XF8SDEvUxg6ZE8P/gf2HvUTGw3wOLNvbieuXNsOOWe1jqqjy4WXwnsUPPpevOq8iveobogSfvcm7PzGO1rSJx1KeJHXYBpXWLSD3zOzpvv4LYvPNJHP0Pu4xma1XN1J57BZ13Xk33/TdRd8EPkJR399/84voeblvcSqpoM3989aCauu36SJJU6VtbocK7ZMCYHknUS5YkfFF2YMkSHDGxhq+fOBmAnz+5lg3dhZ3OK4RP71//E7t9LbXnXrFbxXAhfLKL7iH9wl9Qwgmqz/gm4WlHk3/jb/Q9/kvc1PbdCjKW1i/B7Wuj5sxvvefWGZoqc9rMRpZs7ht0VJ4/pyxstuPvqUKFjwOyJGG7Hl3ZEnnLQ1NAlWV0VcH3Bc3xEJYj6Cs6TG6I0XDGJ7jn1+N4+cE/87kfncqW3iJIPhIGedtDlSXCusya7jxhQyG3djHhZC1nXnc3b65fz9an/kLvS3dSWP8K8oWXo44ZS3MsxPZ0iYAhIyFTFzVQVBUNmPup72Cmu/jdDd9iwfdvpn7iTEqOjSRBJi/wXKiLekSrQ0Q0hb6ixVsdOarDOslQAHxBLFjeR0QMDUNVqAprbO4toskSlusR1CtGdoUKe4IQgnTRwXQ9kiGdgKZguz6dWQtdkckUbVwhaIwHiQRUVKl8P7ueT1/exO5XE2+zXFJFi86cwvrOHIZSFhNbveQlnvjtD6mbdigTzvgnPN/D0ED3wXHA22EtAcDq3ED7Az/ETXUQnX0qscMvQgnG2PqLiyitf4Xg2NlD1j/wfhfwZOjN5Fnz9D3UTJhFy+QDGFcbQlcVOnIWKoKmeIDmRIh4SOWtjhxjqsOs3Joib/scPqmKpqqygGJV2CCgKeQsl5qIQSyoEQmouH5Z/Ozj1qprrxjYkiQlgN8CMyk7OT4vhHh5b8z9UWT++GpUuaz0K4An3+zk7AOb6C3YvNjf33VHAmMORKsbT/aVhUQOOHGXtZG+mQdAiVYjSTKhyYcTGDeH1NO/Jbv4XuzuzdR98ppdftACLQdQfeo/0/vIz+l74r+pOvVf3tUH0xdwVX8KvKZI3P7Fw/jsYWO55YWNeL7g2odXMaUhWtkAV6iwB+xoTA+o8e8o6oV4+4khSwwa15f+dhHmbuqx08//ieKaF0kc+3lCkw/b5VjfKtLz8M8orV9MaOpRVJ/y1cEMFyVaTsUeeP7siuzie1FidYSmHrnLcbLEiA7Hk6fX80/HTGDumCRTGqLDDOrKc2XfYH/aF2iKjKEqFB2PVNHGFwLXF8xoilOwJQxdpT4WoDdv05YqYLtw7qe+yG9++F22vbmMmvEHohTB93zCuoLlCVzPoyqko8kS2UwaLVpV7g+bbKDp7H8nPv1otj70c978zdeRPv9jzFGTsHwI4OO6Al+SqQpodKZLhAMBjvvSDTz24y9x30++wadu+CPIcbb2mQS1chrm429u4422AIoiE9EVUkWbqK5SHzPozJVYtkVClyXq4waT6mNkiw6eL4gHVPqKNmFD/Vilb1ao8H+J5wtc30eTZXKmQ6pooysy7ZkSo5Mh/P7s13IwS6IjXWT1tixhQ+HQceXv59a+ItvSRXKWg+OXU7B78jbbUiXqoxK2bxLIbOd/r/tXEo1jmfOZq9iW89A0cDyQ5LKAYll9oWwA5lY9Q/tjv0ANxpj0DzciN83C7P/eloNxfGs3ewIBmbdeopTq5MhPfYOaiEFA02ipCtOUDGPbHkKC6qhOQFOZWB+lL28xrbmK5qogLckgLdWRQQ2HSEAjEng701eRJZTdlMh+VNlbEez/Ah4TQnxSkiQdCO2leT+SzB2TZGJdhNXtOaC8P77/te1Ma4giSeXXO24uJUkiPu88eh76KaX1SwhNmrfTuWWj/KvbcaMrawGqT/lnjObpSLKyR8ZyZNaJOKl2si/fiZpsJD7/gnd1jQPbetsT3PzcBp5+q2vwegbEh9Z05Hh0ZTunzWwcTCmvUKHC0Ij1oo29g7XVtuMP/lzvLzNBYrD3tevDmo4cqaK9W7Gz3OuPk110D5HZpxI79NxdjnWz3XTdfQ1O71aSJ3yR6NwzhzxHBp43shHe5TzWtjextq0mecIXdqsLAeXuCjt2HJOlck/sHY3pikG9z7Lf7AtM2+XN9jSbeoqkCzazRiUQwButKRRFxhNg2h4tVUEKlsP67jzzTjqLO27+GUsf+iOnfvO/aEqEMR0Xz/cJGhq5ok11WMcFVCOMU2ylO2/jCYgYEskDDyfePJaOvy+kYfR4uoouOJBxwRMC4TrkTBlFgVTRJRZNctRXf8pTP/kn7rrxaxz2L/+FbATJW/B6W5aQLpHJWbjCZ0x1hIaoQcHxyLkuQVUhGlAI6SqdWYvGuENPwcF2ylG3kK6Q1m1M20WSZWoiOgG9kiBZoQKU98ztmRK+LzA0GU2WsRyPbMnB8X0aYkEMVSaoKxQsD9fz6chYCCBb8uhIF8iZDhu7c1SFDSLVGpoMndkivvDJmi6GapNNd/PAD76IEQxx5Jd/iGKE8TIl3BLIStngcyiLj8lC0PvCX+h5+U4iY2Yx/pOXE47FcVzAhpIQ+GZ+t3sCTRZse+le4g0tHHzUSaCUxRZzpkNf0Saky9RFg4QNHYFgcn2U7YZGTcShORmiYHu4vuBj1H1rj3nfT0BJkuLA0cBnAYQQNmC/33k/yvzwkTcHjesdebPj7Z+9M3ITmnoUynN/Irv43l0b2IEokh7CSbUPOxaZefwerU+45V9/4qhP4aY7SD/3R9R4/R73xH1n5Kkzaw5TGv/93zfRky+f54V1PQAVI7tCBYZHrD972NjBrBYfSIb0IQKCty5uHfL+nz7+FjOb47sUOyttWk7f478iMG4OVSd9eZdON7t7M113XY1vm9Rd8AOC4w4aNsZNbQcklFjtLq8ts/he5ECUyAG7LnUZ8MRffcYMnlnTxdNvdSGEQFflinDZfsD+ti/Yni6RM11qwwbru3Jsz5Rwen1qozqpokemYDGtKUprqkjedJElhSmjYhx33qd58H//k0M2r2H8tBlYbll/4eBxVeRKNqmSw9rtOfRkPcXlz9KbL+IJjeakSthQqQ6MRTr1S2zOiMG+trIop31uy1iEDcHk+gjbentwTIv6xhbO//efcMcN/8zS31/DYf/0I1BlPN/DcXy6bQddkWjrLTFztMa8CdUcPKYay/NoT9u4nkNIVynZHqbtki5aOD6ossSrW9J050vMHJWgOmQwZ2wVivzxSumsUGFvU7AcMkUH4QvChkrBcgmEFFJFB4HAkGVKTrnEoj4WwPUFq7elWd+VQ5NlspbLpt4CEUPF9QSJgEZ1NEC66JAvOVheOYskobvc/4tvk8+k+PR1vyMXqmVzX56gAZZdjmCXvHLk2vJc+h79L/KrnqH6oJOpOfEr2LKKWewXMZNBzmcRdhE1Xr/Ta1MAvXsVmbb1HPf5K/AR2JYgo7r0FRxkWVAdCdGUCFK0HepjAWLBciR7e6ZE0fExVBl9H8182RsuxnFAN/B7SZIOBJYB/yqE2Hnh4Mecx1Z1vOv3SLJC7NBzST15M2bbKgKjZow8TpLQ68djd6x/1+fwHRNz43Kyyx9G1gOEpx1DzelfpzPXQ89f/wMlUkVg9MydrxE4cXo9maLN0i0pfFFOET9sfDWvt2XePo9g0Lge4NGV7RUDu8J+zbItKe5d3saqbZnBiLXj+qxqzw4KH0pAqli+d+aOSbJweduwefqKDs/3O61Gwu7aSPf9N6HVtFB79uW7jCSbbavouudaZC1Aw6d+jF47duQ5O9ajVY9C1oydzuX0bKW0bhHxwy9C1nctNqLIElefMYNL5rVwybyWYTXoFfZ59q99gQSyXM5ImVATYXJdhJ6iDb7AUAXRkEbO9IgEdBrjKq29RbZnShx1ziU8ecdvWP/0HRx66E+JBXSS4XJauO0KLNsjHtKJj5qI8D1KnVtQaibSnXOQFIlAv7iYrpQ30FDeIJtAb8GjWEjR8cbztP39QYxQkMxBJ3LoCadzxpev5qFffp+ND/6c6Rd+i6ACW1JFfA/GVgcJhzR0VWZCbZRIQMMs+BRsG02RqYsYtPaVuOOVzRQcF8mHg8dWoSgylifozZkYqoLjeShyJYpdYf+lra/Ahp4Ctu0TCag0J4MgUW53pylYjoPt+5RsF8I6kiThuh7bMyaqLNGVK+F4grwJLWNriegyadOmPh5gc2+OvOXiCKiLajz5/75L27rVfOGaXxIYNZlMbxa3v4eW45Xrpn3AcE22LryJ0qZlNB/7KSLzF+AKabDdFoDjQ6nfBtHrxg+7rlERELJMUyzEkkcWEq+u5ahPnIemabT2FTFth6pYgDFVEXKmS9hQ0VWFsFFO/dZVmVGJIK4v0BUZeR91xO0Nt4EKzAH+RwhxEFAALn/nIEmSvihJ0lJJkpZ2d3fvhdN+eJw6Y89a2byTyKyTkIMxsovuQVd2/oEymqZid27At809ntsrZsi/9hj5lU8Rm3MGiSMvJfPynbiZTmrPuwo13kD3wutxerfudA4BPLm6k1c2pwYj2J4n+PXzG3d7/hmNH0yP7woVPg4s25Li4lte5rbFrbzelhlUA9dUmRmNscFotIAhIoF70pJrR8qp3j9A1kOc9G//OVhSMhKlDUvpuvNqlHCShk//ZKfGtRA+1rY30Zum7vLcmVfuRVINonPP3O06PV8MOhKg7Ez46nETK8b1/sNu9wX70p6gKRFiUm2UZEhjamOccbURptbHMFQF0xVUhXWiQZX6uF6OUnkSadMlFIpx1oJPs+ipvzI1UuLcuaMYXR3mzW1Z3tia4Y2tGTb3FqifUFbrt7auQlfKz41iyaVoeQQ0Fdcbuh4fyGYzZF/7G22vPM6ME87n6AVfYt3f/oyS7eDQk8/l+Eu+zIpnH6b92duYOTrJmKowybCG5QmE57I9bfHnRVt48PWtvNaaJlV0yRZdVrdnyZkWedPFUFR8YMW2NAXTI1u06cxZ+MJHryiLV9jPcDwft7/ey/cFbX0lkkGN+riB5boEdYXGeBBdVVBkUBUFQ5MRO+iwdOdNLNcjpKv4vkRTIlDOENncQ1feYkJtDNd1ae0roqgKhiLz9O9/ykvPPMHVN/6EmYedQGe2RNb0y463chUaDmUR5S13XE1p86tUn/LPBOZdhPkO43oAs20VSDJ646Rhx6IBg6ZEmIS1nS0rXuETF3+eeDhIIqQS1mRUVSJdcGhLlShZDlv7imRKFo739oNKVWQCmrLPGtewdwzsNqBNCLG4//U9lL9YhyCEuEUIcbAQ4uDa2l2nIX7Uufz0aXzp6PE0xAwUuVxnqKsyh45NMvBZkaVyn7gdkfUA0blnUtqwhClGGnUnH6zAmAPAd7G2rtyj9QjXobD6WZzeNuKHLyA05XD0+gnIgSi+XSq377rgGpBVOu++Bi+/8x7Y79zw70nfPIBocPftxypU2FdYtiXFr55Zz22LW/nVM+u5+bkN2N7Qu+eISTXcetl8okFtyHNhR8NzZlN8j8/pm3m67r4G3y4x8x9vIpys2+nYwpq/07XwerTqUTRc+iPU2M7H2h0b8M18+bmzE9xsD4VVzxI54CSU0O7XLIBcydntuAr7LLvdF+xLe4KApjBzVJxjp9Zz2MRqqiMG05vinDazkdNmNDJ3bBVHTapj9ugqPCGYUB/m8PE11EUNTrzg88iKyv/84r8AiaqAyqaeAttSRTLFEoWShWvECFQ3ktn4Gq4HmgKKJOO4DkXHxXrHegb2BD3btzDxxE8x64iTiDZMIhSJUyzk6UwXmfmJzzH7+LN4/q6beeTe29FkmfG1UWKhALKskogYBDWZN7fnKJZsCpaLrssUbJe+vIPleeQtB1/A3DE11EZ1jp5cw+ETammp2nXdZoUK+xqpgs3mngKrt2fZ2JWjJ29Rsl2WbO7l1dYUqixTGw2U1bARJMMGdbEAybBBcAe9AtsVRHUVVZWI6AqqJDM6GWJyQ5y6qEFtRMMXguqwRtRQWf7wn3nxodv56tf+jRPOvYTObAnT9pAQVEWDKJS/j71ihs47rsRqX0vNWd8mMvtUvJ1eDZibX8NonDwsWy0gQcBQCKgyLy38PUYowtyTLyBbcunN2SBJ1ERDtFSFiAYUhCSxpS/P8s2pQefD/sL7zt8RQnRIkrRVkqQpQog1wAnA6ve/tI82l58+jctPnzZi+52B1wDfufcN1ne9LVgWnXMG2cX38sydv6HmzH8fcW5j1Awk1aC0cSnBCQfvdi3m1hWUNi4jccTFGI2TEZ5Lce1LaFXN6A0TANASDdR98vt03n45XfdcQ/3FN+0y+rU7aqI6vXkbIajUVVbYrxiosR5IA98ZMxpjg32uBwTNtHfcK6u2Z3Yxw9sI16Hrvhtw+rZRd8E1pANNvLJ5ZEdZYfWz9Dz8HxhNU6j75PcHlcJ3RmnjUkAiOHZ4bfYA2SX3gfCJHXreHq0Xyt0VLj992h6Pr7DvsD/uC1RFHlTCHUSRGVcXHXxpOh6aItORNenKlpAln+amRk45+wIeuOtWTv/0V+nzgmzqydCXc8k5ENEhEpAIj5tL3+t/w3QsJMnA98s9tWUYLEEZPE//nqD6iIuR68aSKZhsWPI08cYxBBsnkjJdDFlh8nlfZ9u27Sy//ac4eoxDjjqeGc0xUkWL9pSJroChyWzLmsSDOpbtkQhoSDJ8YmYD29MlNF1n3oQaYgGF1r4Stutj7Wcb6Qr7N74v6MiUSBVt2tMl4iGtv02VieP5RHSViKHRnTXJWy6yXG53a7keqiwT3yFAFdQUHCBuaCRG6WxLFdFVhWREw/VBlhVaaiKMT5VYePedvHD7LzjutLM5+TNf44HXttGRKSE8Qd50UFUJTZPIpProvPNK3HQHdeddNWhX7Gz/4hXS2O3riB95ybBjMlC0fYrdm1j98pMcff7nUQJh/IKNjMTUxji251MXNaiJGizf1EdrbxEhYG1HnvkTA3v/P+Ajyt4qkPkX4NZ+pdCNwOf20rwfed6pgrvj62VbUmzuHVpypgSjRGafSm7pg8SP+hRaYni6uawZBMbOprhuMckT/2mXAkbC98i/8TfC047GaJ6G8BzMttVY29eg148HJIQQSJKE0TiJmrO+Q/fC68s9sj95NZLy7iPPXzp6PCfNaODiW17G8cSQNkMVKuzrDLTY2t2n/jcvbBzUMbjmrJmDfeXXdOT4+ZNrOW1m4x6liAvh0/PX/8BqXUH1Gd8c1pNyR/IrnqT3kf/CaJlJ3flX77ZWGqC0bhFG0xSUcGLE414xQ/71xwhPPwY1PnIkfOAJteP1rO8ucNvi1oo2w/7Lfrsv2BmpgkVPzqQjlccTElProqSLNp+45Is8svB2Hr79f0kc82n6Ci4lp9xntmRD2hao4w9FLH2YwpbXEBPnUXxn2LqfHfcEavM0MgWHV9a9SHb960RqR6OrEomAxpa+An15k5bzv4X1pytZdev1jB3VyMzmw5jdHGNNZ56OnE1NWGF8XQRdUYgHVVRFRpFlCo7PabOrqY0ESIQMHM+n4PiENAXH8zH701wrVNgXEUJgOj4CQUCV6StYbEuXSBVsFBlCukRYV4kFNVRZBgTb0kXCuopGOROkMR4sp5Kni3TnLKrDOrURg2RQoztjks6bFEouadNla6pES9Lg+TVFMkWXZYtf5O6fX8WYGQcz7zNXcv+rHaRKRWxH0JO3yZs+mgoi38u226/EzXVT98lrRsxUC1AWLFMlyAnIrX8FEIQmHjpsrCMgXbDofvIOFFVj1FHnsb4rR9F28fHJlGxaqiPMHh0nXXR4syOH5bjEAiprO9PMGZtE3xclw0dgr0i3CSFe60/1OkAIcY4QYuc5yPsRNz+3AdcbvoWOHXIuyDLZxffs9L2hyYfj5bqx29fu+iSShKTqCK9cRVFY/TzmhqXIRojwrJOQJGmIgW40TaH61H/G3PwqvY/+AiH23NN84Kg49375cC4/fRqLNvbi+uU+4J4vWLSxd4/nqVDh48hAWngypO+2VZ7c355KUG51t3J7hq8eN5E1HTmuuG8FL6zr4Yr7VlC0Rqp+ehshBKmn/5fiWy+QOPazRGYct9Oxudcfp/eRnxMYO7scud4D49pJd2B3biA4+fCdz7v0QYRjEdtFq7+JteERS14eXTm8G0KF/YPKvmAoRctl2eY+7lyylTfbc2zozrOxL8farhyb3AiT553IX+/6E73dvQRVBlXBB+xoo2UWshGmuObvuz7RO/YEnW88T/ubS7GkIIFpx9GWLtKWKtCRKtKd88nnXVoWfB89kuBvP/8mPW2bSZs+05vijK0O4brQkbZwvXJXBE+Ulc6b4wEaoiHqYkF0VQbBDrWkEvIetBOtUOHjSqrosD1ToiNj0lewUSQJy/GRKNdRO56L7Xnl0gpFIhHUyFkuvQWb7ZnSoFM6XbTZ0ldAlmBLb5HunInrC2y3rCxeHTVwHZ+orrKmK8ejKztZ+OSL3HHjvxGqbWHmp77HllSJ9nSRdM6lL29ScnySEQ0p38uaP1yOl++lZcG1hHdSBiZLEI/KVFdpTEqqOOv+jhqvR3uHwJlM2RDP9HSxefHjjDnsdCQ9yurtGXrSJVy3v9+367GiNcvyLSkKlkM8rBMPG5iu2K+eCxX34l5m2ZYUC5e3sbYzx5KdpHCq0WoiM08kv+Ip4odfjBodnl4dnDQPZJXCm89jNE3Z6fkkSSZ26Hn0PPQTCqufRY3VYoyeRXjaUYOKwML3yL5yH14hhd25gcjME0gc9WnSL/wZJZQgefw/7tG1zWyOD0bnkyG9/0YRw9JeK1TY13hn661RiSBb+oqDx3dUCT9kbJJESOeJ1Z1DjsNwg/P5dbsWd8ouvpfc0geIzj2L2KHn73Rc7rVHy227xs+l7twrkVR9p2N3pPjm8wCEpx4x4nHfKpJb/jDByYeh1+w8Er2ue2Rx6NNmNu7ROipU2BfxfEG6aLGlp8Cq9ixbenLYrofj+5iuYGN3kUn1ETRZYvJJl7Dm5SfY9OJ9JA+7AF3yKO3gn1cVjeCkwyiufQlxir3Te3xXe4KsMFizvQDCo+3F+7ALKZzODVQfeAIzPnsDK275Br/87hc473s3U1vbAEJCUWXSfTlSRQtdVWmMG0R0FUWS0bVyjMb3BemSjecLegoW46rDBLT9I0pVYf8kZzqEdQUJyJRcXCGIBVSaEiHSRZvRyRA9BRshBLNbkri+oIFy1Nt2feLB8v3rC1EutVRkbNfnxfU9mI6HLINp+Zi+i+/79JYstvWZ5Lq3seSW76IEI0z5h+vpclS2deYIquAj4/llw67UtY3Vf7wCr5hlwsXXojVNAwVMh2G11xEdGuIGrgdWPkNu02skDz57WCBBobyX6Xx5IQKoP+J8LM9DCEHRcejqtGmIGQQ0hZ68SV00RFhXyBYtxlZFmNuSGF5Gsw+z/1zp/wG3LW7lwptf5tbFrTs1rgeIzf8k+F65tnEElECE4Pi5FN96AeHvSooA9Nox1F98IzVnfIOqU/6Z6IEnD0avhO/Rfd8NuJlOAmMOIHn8ZaT/fjv66BnlevAl95FZvHCX8w/0tD1vziigbGxc+/AqPL/sjbr6jBkVdeAK+zQDaeG+ANvxaU0VhxzfUSV8eWuaY6fUoavysHvnnQZnX2HnQmD5FU+Sfu4PhKYdTfKEy3YaNc+9+gh9j/+K4IRDqDv3qj02roUQFFY/i9E8bae9LnPLH8a3CsQPW7BHc+7I0ZNqAPj0/y7mtnf0+q5QYV/H9wXt6SLLW1Os3J4hazq4HgR1Fd+H2qjGuOog29IFnl3bRSbURN2M+ax9+h5kz6Q6BEkFDEADkjpUTT8aYRcpbViyy3Pvak+QsjzW3nkDVv+eIHH8ZXQ+fzt2Ic0Bn7+BfLqXe276V1Zv2s66nizbM3myRRfL8ekrmsSCOlURg/q4QcQox2gs1++vuwxQFdIr/a8r7NP4viAa0CjYHgXbKyuEawqKLOEJn4aEQdHxqArrJMIGmZJLNFBOFY8YGs2JIIF+9fB4UKcmbNBXsOnNm0QDKhISnVmLrOnSGAmQDKq0p0tYmW6e/8U38TyPSZ+6llhVNZ4lUAUIIVOyfPDBznSx8g/fxS3mmPYP1xEaUzauNeltZ/8ACiAk6MzYdGRKrHn5SfA9AtOPGRwT638OhQ2wCynSrz1OzQHHQaiagulSsG16Sy5F26On4NCVs8jZLrbnMrE+SkM8zKSGCPWJMJt78ryxNUVHxhyinr4vUolg7wUGotZ3vNLKCBnhI6IlGghPP4b8a48Sn3/BiMq84RnHUVq/GHPL6wTHDRNmH4ISjEEwRnbJA+gNEwiMnokQgp4HfoQcjJM8/jIkWUFSVEKT5qMEYyRP+AJeMUP62d+hhGJEZp044txHTqrh6ydOHjSid6xBFWJoO54KFfYlBkQLyxkbDLav27HYWJbK/eu9/oN+f4uq278wf5gAYqpo86Wjx7OqPUt7usT6nUR+i+sW0/voLwiMPYiaT/wbkjSyLzT36iP0PfHfBCceSu3Z30VS91xTwenaiNPTStXJXxnxuG+bZJfcT2DcXIyGiXs87wAvbegZ7Of9Qv/flXrsCvs6luvRV7BxXB/TdtHlcjsa3xdYIZ2W6jCO6zCxPo4sSdy7bAueB5Lwic27gK5V32Tby4+QmHc2loC6uEzJEXhCkJx4IO2RKqzVT5OYcgS7+uZ9N3uCyKT5+EqIaPME5nzuByz9zZU8d/MVTPuH68lHQjQnwyiyRDpnoSkS9bG3hYp8X6DIEoKyiJsnQKu06KqwDyKEoDtvUTBdVEUiV7TJmA4hXaEpHuo3ul0m1kZZ25WjYHvUhHUMrfwMGJUM4bg+AU3Gcny68xau79MUDzK5PsqyLX30FWxChkK1b9BSEyZbtNmesUn1pXn+F9/GyvVxzNf+A71hPJ05h7QLkgs6PjYQKnax9k/fxbfyjF5wHV7tZFTAdcCXy9FrA7ApG9eGBOWGH+V7t/PVp9FqWoakh2f7Y3yyBalX7kN4LjVHXIDt+HS7FiFdIaLLZE2HaEAhGlBIBnVUCfSAxqS6GIeMq6Gtt0jRdYkaGhu780QMhUhg3+1AVDGw3yc/fORNbukXM3q3xOdfSGHVs2SX3E/ymM8MOx6aOA85ECG/4sndGtgDhGedgJfpAqCw8inkQITk8ZcNpouXNr9GadMyonM+gSQrVJ34JXwzT++jv0AORAlNmjdszgHjesf0d0mSkEUlPbzCvsuOaeGyBG6/XIEPqIqE35/Bce3ZM2ntLXDLCxvLqV6aTDKkc+/yNiRgTUeOhcvbuHNJ6+AcoxIBDh5bNaKBbbauoOfBH6E3TKD23Ct2KkSYW/5X+v72P+/JuIZyhBxFJTT1qJGPv/4YfilL/PB3H72Gt39fAzy6sr1iYFfY53l9a4psySWsy8iSTCKkkSpaGGGdGc1xRlWFCagK3TmLjkyRhniIVN6hO+9TP2EG6clz2PTc3Rx22Bm4moLlCmQEwgNVU4hMP5b00geoKaQgvPvMsd3tCfzNr1HctAzp8E/geT6h5gk0nvFvbH/wp6y+7UYO/IerKNkG7bkSmZLDq1v6aIwHCRoqvXmLjd15OnMmyZDG6ESQhliAoF4xsCvse5iOT950iRgqG7qyZE2P+liAzkyJkK4RNBTGxyJEgxrTGmK0pUr4ngdC0Je36MqZFKzya1eU66wlIXjLcqiOBFFkSBccWlNF4oZGVVClrTdH0Sqy+DffpdC5hZYLvkenPgarx0GhbCS7gAm4mS5ab/8uwsxTt+B6aJxE0YOSV44JKH45bdmh/FoGXAGSBzkP7FQb5va3SBz7uWEZcyrgl7KkX32E5IyjiNU1E1AkHMrf9UlDQ5YkGmIhGmNBJjfEmd4cR5ElipaD6XjYnoehyIR0hb6ijefv290GKgb2++C2xa38+vmNIx5LBDXCAZXpjTGOm1LH9x5YwTs7V2g1owlNOYLc8oeJHXoeSjA65LikaoSnH0vu9cfxSrlhx0dCCURQ+tvyeIUUWt04JL3sbTa3rqTv8V+ROOYzaFXNeMUMqWd+S2DMAfhWge4Hfkj9BT8YpjL4t1UdACy45eUhom2KXEkPr7DvsmNa+DsdaDMaY5w8o2HQuXTNgysR/aJmuiJz1f0rdul0a0ubtL22fdjPrY71dN17LWq8nrpPXrNTobK3jet51J5z+bvuBiBcm8KqZwlNOmzE54pwbbKvLMRoOYDAqOl7NOe0hiiO57Oxp+w02NEpAZV67Ar7Pn15m009BSK6wuYek6ZEkKqIwdjqCI7nY6gKqgRRQ6G118EXgrChEDF0LM+lMQrFEy6m53++xZpn/0ri4LNIKALLKUeZVBmqZp1E+pWF9K18hvi884Zs4kaSTNzVnsDeupKex39F4/GfQatuxsz2seHR36LXj6fupC/R9bf/Yf19/8Woz19FdUhnYn2MjOXR2ldkTE2Ijd05NnfnUVWFgu2Rt32atcq2ssK+iSSVv+Nd38cV5T2wKkNAV2lKBIgENDRFxnI83mhL0Zu3WNuRoykRomjbZAoOBdejJmQwsSHKqvYMqbxNX8FmckOUsVURMiWLI8ZXkzVduvMm8YDMU7+8kt6NK6k/81swdi6F/u/VHe93N9NFx4BxfdENQ7LOBrYiA8WmKhCUwfTfPuYDmdeeAEmmdsbxw7JjXKBnyQMIx6Tl2AuxHUCSqY8ZpIsOTckAY5IhDEOnKqQzpTFGbUTH8nwKlktv3mZcbZiuXPl666MBIoE9K2f7uFKpwX4f7EohN11y2JYq8dyaLu5/tW2YcT1A/PAFCLtEbumDIx6PHHAyeA6FVU+/q7UJ38Pp3YpfyiFJMqWNy+h56KfED7+I8NQjAZCNEHrjZLKv3Eds/gVoiUa6Fl6H9Q7l8sdWdfDrERTRB1JhB9SVl23Zr0ViK+xjzB9fja7KyNLwB+WCQ1r46nETmTsmyb3L27A9MfhFlTXd95TR4vRupeuuq5EDEeouvG7EshF4R+T6PRjXAMW1L+GbOSIHnjLyOV5/HC/fR+Lwi/ZoPgn49GFjaUuXEKL8+vip9Xzp6PEcNamGG8+dVYleV9jncX2fRFDDdgU5y6Fgu3RlTTpzJhv7ChQtj7aUyZZUkbUdObb2lUgEDc6a08wFh7RQ9BTs+DQCo2fSu/gehGsjAb4PigSSDHrNaALN08i/8fhgDaMuQ0iDXW1X37knMPv3BE3HXETtAUciC1CDUWKjJ9P5wl1E6kdRd/Sn6Hz1KV6995cULJdUwUJBoCiQytts7SvRW7TpK1g4rofvCyTKiumtvUW2pYo4lZ7YFfYRAppCdUjDdHyaEyESQY10yWVUMkQybKD1C3jlTJdUwaY7a7Mta9JXsHhhXQ8vrO9hxbYMT6/t5sU1XbyxOUWmYBEKqJi2R9Epl5c8+VYn9yxp5W8r2rj+219j82sv0nzal4lPO3rEdQ0xrhdcT3A3JV0e5T2NTDld3AeE55Bb+TShiYeiRcpBsx1j2J6ZJ7vsIeJTDydcN46aqA4IipbPhLoIp84axeFTGpg3oZrR1RGaE0FiQR3HE+Wac1WiK2sxKhHk0LFVTG6M7fNaDRVX4/vgtJmNg7WFO8PxBK19xZ0e1+vGEZx8GNllDxI75Gzkfk/z4PH68eiNk8m/9hjRuWfttj3QAJKskDz2c3Te+T3c1HbcTBfJ4y8bNK4B3HQHkiQRmjQfraqZugXX0Xnrt+m66/vUX3ITeu1YADRZ4qk3O4edQ5YlciVniLryrZfNr0S0K+wzHNAcZ8nmFIJy9GhGU5wFh7QMMRb3xleEm+mk846rQJapX3A9aqxmxHHvN3I9OM+rj6ImGkfsiSlch+yiezBGTcdombVH8yky3P9qG6ZT3kx7Ap5Y3YmuSNz+xcMqz4QK+wWJkE5dNMA2r0gioBPTNfpK5Ui1LivIcrlGu2BJ1PRHfjqyJomghucL4iGVaFim9oiL2HrHVZirn6A09wxcAaGARCyoks47JA86lfaH/xOz9Q20MQeiKBANyPi+Dx4j1mbLskLzcZ9j8x3fw09th3wXE06/jInzTqRgOaSKHoXuNkwbktMOo7GhDmXqXBKazdqn7qK+JkHg/K+iSALfdXlgdQe27RLQVbIlh0RIw/c88qZNX9FFVyVcT9Cbt2mIB0ZYUYUKHx88X9CdM2nrLdCZt6gKGTQmA4yKB1Hf0dc5oMm4QuD4PooPm/sKtGdMNFmQNn0UGd7qEoRUmS19JZJhl5Ll4LiCoCrz6pZeiq7HY7/9D9oXP8rEUz9PcO7peD7s2NlTAax3GNdG4yR259ISQK5/0IDMamnty3jFNLEDT6W4wzgo73EKSx9A2EWqj7gIIUNzdZh6T1AVLT/zZFkmZ7pEDZXamEHWcijZLum8SZcQbOopMioeJFV0iAb33brrHalEsN8Hl8xr4cZzZ9EQM3Y6RlMkzpndvMt5EodfjLAKZHcSxY7OPg2ndyvW1hXvan1KOEnDpT8mefw/UnP2t4cY107vVorrFmF3byFy0OnoNS2o0WrqLroBSdXouvN7OH3bgHILnpEicp4vuOWFjZhOOY3Wcf1KP+wK+wTLtqS4+JaXeaXfuIZyFOnkGQ3DIrHnzRn1vjyxbq6XzjuuRLgW9Rdeh1Y18vNib0SuAezuzVhtq4jMPnVE8bT8G0/g5XuJH3HJLh16E2vDg84Fz4dXRuicYHuCe5e3vad1VqjwcUNXZQ4cneCoibUc1JKkPhGgOREgEdSpiei4vmBsTZimeICAppII68xqiiGERFemiOMIPAHR8QcSGT2d7pfvJar5VEcUVEkQMVTG1ho0H3QkciBK4dVHCMll558mK0SM/kj3DmsauMMF4IaSjLn0x4w++R+ZfvHlVM86mrzpoakSTu9WelcvwuzeQu3Bp5FoGguyzNwLvsqkI8/khXv+l5WP/4negstNj61h6cY+3urMsz1toqsyIV2lNWPy/LoetvYVKNn+oPBjhQofdwqWQ85ySZccHMcDSdCXt/FHcLFHAhrTG+OMrQkzsaEczZ1UGyagaXgexDUVBBQcr1wT7QiqQyq+57CiLY3peqx/+De0L/4ro46+kMS885ApZ6moQIj+CHS+i853GNc6u85k2RnZVx9Bjdejj5/DO+9az8yTWfogsSnzSYweT8RQSeVtAgo0xoI0V4XQFIkxyRAtNWFqIgYbOnOs686TLnls6yvRHAsSDqj7lShyJYL9PrlkXgtTGqJc/JtFOK6PqkgcO6UOCaiJGpw/ZxRzxyRpqQ7zu79voj1domAPbbul148nOGk+uaUPEDv4rGFR7NC0o0k987/klv+VQMvIjeJ3hmyEwAgB5RQQSdFw0h0U1y3CzXQROfCUwVoNIQTWtjeJHHAKuVf/SucdV9Fw6Y9Q43U7nX/H708fyJV23naoQoWPC4s29uK8oyRClqURBf3mjkly3dkzd1t3PRJeMUPXnVfhFTPUL7gevW7ciONyyx+m72+/ft+Raygrj6NoI3YNEK5DZtHdGM3TCIw5cJfzzGyO05Yu4Qx2FBh53L6dBFahwlBkWSYeNhiDRKbkUBcLUBsx8AFZktCUcnuesKHS2lvk1a19rGxLU7RcQDChLoLj+jSeexlP/eIb9L72JE2HnYmsKNTHgrieRDEuqD7oJLoX3Y9k9uDKNdiei6JAQAfLGlpbOYAH+EaIoBZCaBBWfYLBAH0draTWLMJKl/cE0VGTsF2JeEDQuuQpki2TGXfoiTz+h5/jKwbxOZ+gORFgW9rEdQVhXealdT3oisKUphgRXeaVTT1EDY1ZzTGEMPY4+65ChY8ikiRhux7pkkNf3kZRFMZWh1B34lwfVxOhNhqgLhbAdX1AkDctSgEJSRYUbBfX9amK6jiuR2vaQul3h7351z+y4ck7qJr7CZpP+jS6qoDlo6kQUQWKDF5vJ5tvuwLfzNOy4HpE4ySgXCv9biOndvdmrK0rSRz7WWRJHmZgZ5c9hG8VOOATn2HKqCRZyyGkaSQjAbJFG01VCcgyjYkAQkiUHA8kqArryEBvwSYZMciZNvGgTlDbP0QQKwb2XmDumOSwljwjsb4rv9M5EkdcTPu6RWSXPkjiyEuGHJM1g8gBJ5Ndcj9uths1Vvue1tnzyM/B8wiMPRA320Nk5vFDhRBcC1kPUtq0jNCUIyi++Tydd1xJ/SU/RI3uXilcCPj18xtpqQ5X6i0rfKyZP74aRZGG6A5cduS4nd7bA5/3qx9YibuHVrZXytF551W4mS7qLrgGo2nKiOOyyx4i9eTNe8W49s08hZVPE552zIg13vk3nsDL9VB92r/udkO8ZEuKWy8rP/dyJWeI4KNEv+DbDj3AK1TYn0iGdRIhbcT7SJIkXF/QmTXJFGxypkt91MD1oKUmhK5IeONPZsOTB9H67B00HnoqhqZi2i6dOZtM0aZq7ul0v3wf2xY/SuMxn6Y3L5Clcq/bd6aIqvSnk1KObmkKpIqw4d6fIAuf0JgDcLM9VB14PHLDRIQvCOkyrmPjyBptrzzOtEOORMflb//7Iw53QDnsDHzP48AxcbanTDIlBwlBfUEjEdIIagoNiQBbUiZV0QARQ60Y2RU+tkQMFVWWqI3qZcNRkhhbG0HeiYEtyxKxoMa46jDru3LossKM5ipKnsuGziItYZUNvYWy0rgP4aKNqsGWZ+5mw9/+xKhDTmH0WV/G8UFXNGqTEinTwfFc0h0dbLn1CjyrQN1FN6A1TBwsC/EZfv/vjtzyh5FUncgBJyMo9722vbKx7pl5ckvuJzn1MIINE8jZLp6QqYrohAyVsKaSLlpkSzbhgMZJ02IYmgICWvuKWJ5PTUTHdl2akyEa48H95jlQSRHfS8wdkxwUPRqJXQmiAej1EwhOmk926QP45nBDPDrnEwDkXv3re15j1Ulfxs12kl10D7FDzsFonjbkuKwFCE2aT+M//AdOzxYSx30er5im844r8fJ7LmB255LWiuhZhY81c8ckOX7K0MyNXH/x00iifrctbuXRle2ccUDjHqWL+2aerruuxundSu15VxHYSa1zdumDZeN60vz3bVxD2YAWjkns4DOHHROuTebluzCapxMYO3u3c21Llfjbqg6+etxELj99GjeeO4sDR8VRFQlJAlWWuObMSpeBCvsvu9pIDnQdqE8EUBSJnOUyri7MUZNqOfeg0Rw1tYF/+Mq/U0x1kXv9SUKqRmuqiO26OC4osQYiEw8l9eqj+G55e+2LsnK/ztDNnUvZuKb/b9MsO8HGnfllSn2ddL10L4lDziHUPI2QDCFdxvYFoVCIWfOP51/+41Y2rXiFsz91GVMPPpKX/vxjMiueYfaYaqqCGo7rUh1WCRkaqYKDIkuEDLU/Yg/t6RJbeot0Zs1BYbYKFT4uCFGOOGfzNtmSA0KgKBLZkoPr+eRNh3TRHhT0s12PdZ1Z3tyeJV20CBoqjdUBsqZPd9rCtF0cD7ozJQxFQlfBEz4bnlnIivtvZsK8k5j/me/SkIhQEw1SFVGxPQnLdunZ3sGmW7+LZxcH1cLfT9K1V8pRWPkM4enHogRjQLnvdSwE0+s07NcfwLcKHHreZVRHVOpjQeqjOpbrowAeHvmigyJDV9akN28hyxLjaiPMHp2gOR5kUn2U2mgAVZYHheD2B/afK/2QGalFjQQcOjbJydPrURWJxJGXlGuxl9w/bKwaryc0aT751x7Dt833tAYlEKH+ohtBkjC3vAaAEMN9XW6mEwQYTVOou+AavFx32cgupPfoPKvbs/zsiTVc+ttFFSO7wseOZVtSXHHfCp58h7Df3Uu3ctviVi797aIhn+8fPvImV9y3ghfW9XD/a9t3W3foW0U67/o+dtcmas+5guC4g0Ycl13yAKmnbiE4+TBqz/7O+zauhe+RXfYwxqgZ6PUThh0vK4f3Ej9y17XXO3L/a9sG/33JvBZmNMfxPIEvypuS/aneqkKFd0M0oNKcDFIVMjh2Ui2nz2piwSEtHDKuhtE1EWa3JLnkvNMZM30Obz7+Z2RRjhDLsoxe3tmSOPgs/FKWjhXPYtFvPPfPv+M3+0BCptH/x5cgFgRHjjDu0zeUI+ptr5OMqIR0H11WKNoeQkBVSGdTayueL4jGq7jiZ7cw4YBDefzX19DzxrOMqwtTHw+Rtco9f4KGSn1EpyEWwEcQD+oIIKQr5E23nD5aocLHBCEEXTmLZZt6Wbylj9UdGZ56qxNFSJRcj45sic6cRV/BYmtfEdf1WdORZW1HlmWberhzSRtLNvXQl3FwPZeC5dCYCKDpMqBgqCq5ksO65+7jzfv/m6YDj+bwz12JLyQEPiFVxvUEkuTR19HO1tu+i7BL1L+jFdd7Jf/6YwjXIjp3qNPdc8Esluhe/CBNBxxFuH4ili/heB66JjOxJsL8SdWMqQ7TkAzSlAxhuT5528H3BSXbI6ArSBL0FWxKlrvfOdcqKeL/RwykkD66sp0ZjTGiQW1IOvlti1u56j5BaPLhZJc+QPTgswa9SQNEDzmb4tqXKKx8ajCi/W6R9SCNn/8VXqFs+HrZHpBk/FIWL9+L71gU176MMXo6SjiJXjuWuvO/T9c9P6Dzzquov+iGEVNLDx2bxNAUgprCk2924guwHZ+fP7mWr584uRLFqvCxYNmWFBf/pqyK/05cT/DoyvbB3tiW4/Pr5zYMM8R3hW8V6br7+9id66k9+3JCEw8dcVz2lYWknvkdocmHU3PWt5GU9/+oLq59GS/bRdUJlw1fl2ORXXQ3xuiZu6293pGWqhDLtqRYtLGXZEjnnmVtg/Vbyk5q1itUqFCObk+sizK2OowiS/3RZ3/IcVVV+NdvXcE3PvdJNv/9ISYddwGbe/PIMmgaBFsOwKgbS27p/UQOOAlJkvAZriI+0JYnooGqSxQtgaHJBHUZIx4j9M+/xsn1EVRk7EIvlgiS6+1lU28XqxQXa91imqfMZrspk+80OelffshDP/0mv/z+v+EB0484iWRII6yr1EWDlBzB6OogNRED2/XpyJo4nk+6WI6j18cMIoH9Q0m4wscbx/PZ0lugI2eSKTpURw0kZPKWTcRScRyZgKqQNV2yJRfP81m6sY9U0aG7UML1BSXbxfY9UgWHkKGRMV2qwhrjqoN0Zi16lj7K5of/h9rphzHjou9SdASRgILjwPZMASEUMp2bWf/HKxCeQ/1FN6LXj3/f1yY8h9yyhwiMOXCY/oskw6YX7sYpFRh/8qeRZEgaGr4n6DNtRiXDJEMGY2oipIsdeK5gemM5PbwjW6Joe9iOh+n4OL5PxhXMjAff95o/TlQM7P9DLpn3dnufgU0plNNRU0UbJIgfeQnFtS+TfWUhyWM+O+T9RvP0ct/qpfeXFYDl9yYUIGsGcqKhLJz05C0EJx6KpBkIu4QSqSI4bg7GqGmDhnRgzAHUnv89uu+9tlyTPYKR/Vpbhtu/MB+A59d1Yzs+PvD39T0s2dxXad9V4WPBwuVtIxrXAEIC0/GQ+guMBfDUm507Ffd6J2Xj+hqs7WuoOfs7hCYfNuK4zKK7ST/3R0JTj6LmjG/uFeNaCEH2lftQE40EJ84bdjz/6iN4+T5qzvzWHkevZanclmjBzS/h+eXXO7b1uODg0ZV7vkKF3aAqMo7n054p4Xllw7chFhzsfz3/yKMYO/NgVj/+F8YfdTY1UYOmiEpb2mRzj0P8kHPp+ut/Ym1aTmD83BHPMdDrNuOA4QpkGUqOj+f6BF2wfIVQdQPbFj/Em/f/ivikQ/EUA88ukYtWER41C23qAbRbBg0KqMkk3/zpb/l/37mM//7+v3Hxt3/EhHkn43oenTmTkKEQNtSyoJGukAzptGdKKLJc7oebtzA0Zb9KF63w8cR2fWzPJxHQyFkOPpAIqvTmHYRUIhbUyJVsLBeaEkE29uQoOj55y6YnYzNjVITNfSXCmkLO9rBcG1WS6fZ8kpEAbz3/IBsf/CXVU+cxacHlOLKK5/sUbUFHpkim5OP0tLL+z1fgey71F92wUzHUd0vhzRfw8n1Un/o1oGwQBiTQVZCtDJufW0j1zKMIN4wnVbTZlnGJBDTGVoXxfZ/unE1YVzlsfA2m7eIi0Z21yJQcYkENy/VQJJlRVSFcV+x39/v+dbUfEQaiZD99fA0X/6acZjp/fDWSJKHXjiU07Whyyx4ajDIPIEkSsUPPw021U1q3+H2vI3LAKQQnHIwSjFF75reoO/9qqk7+CpFZJ6Alm4aMDY6dTe35V+OmtpfTxYuZIcddr9yia+6YJLdeNp8jJtUgS1Tad1X4WLErW1kIWLI5hbeD/b2nquFvG9dvUXPWtwlPOWLEcem/3142rqcfQ82Z/75XjGsAq20VdvsaYoecM8wx59slMovvITBm9k5rwXdEppyxoioyT6zuxC1nhjKgBycBWkXcrEKFPSZXcujKWqSKNj05C9vz8YSgaLms3pZh9jlfwMqleOOJO8iWbLbnHSwB4SDUzjwaNVJF5pV7h8074CqzKPe7dYGCgLwH2RKUHOgpQMaGnqyHNOF4ohMOxjNiVJ35LWrPv5rYSV8hOP0ETKOWbMnDFhK5Qomn1qUZe8kPqJswgzt+/B1W/f1RHFeiNqIT1lVs16U7Z5ItucSDGo3xIImQhizLuK6P3++Z9CutvCp8hFEVmbCm0ldwGF0VREEgSxKb+wqkCjZCQNH2USTY0JVlRVsG4dnIikTYUHAcH0NViekqtuOSynuUHB+BxKqnF7L0tp9SN20e4xdcgZA1PK+s9dKeKZIpeOS3b2Hdny7H932aL75xrxnXZaf7QrSaFsLj5xKRoDEKug41cYOel+/HtS1GH/8p8pZNyXFIBDQUISg4Hqm8jef5bOopUHBcZEWhOqJjuj4bewqkCzZFy6WvYLOlt0DOsjHU/cvk3L+u9iPCQJRMUPaO3fzchv4j5S+axJGX9LfLuWfYe0OTD0NNNJBZfO/7rmeQVI3ac6/EzXTR99Rvyj8boS/uAG8b2e103n7FEAeAL+C2xVv44p+WAvD1EyejqzIyZcdAMqSPKA5VocJHgYHP5symOHvbyepbBbruuhqrfU3ZuN6hH/0AQghSz/+ZzIu3Ep55PDWf+MZ7zlAZiezie5GDMcKzThh2LLfsIfxihsRRn9qjuXxgeWsaZ4RIv+j/4+9ntVYVKrwfLNenaLkUbZe1nVna+gq0p4ukChadWQupbirVkw9m8zN3gV1C+FAoOuiKjKprxA8+G3PLG1jt64bMqzBymzxB2eAeMLz9/n/nZY3EOVfi9O8JZABJRpXBEaDLHr15hy1pE0mSiEYjTPn0tdRNnMl9/3EFK//+aP815Fi2OcWLG7p4bOU2XmvtQ1clfF+wvitH0fbImS4dmRJbegt05SriZxU+WpiOR3fOxHI8ZEkQ1GWaYyFypoPpeMgSdOZMUnmbkK6iaxJvbM3Q1ldgdWeRhkiAumgAD4jpMou29FEo2iiyoOR4rH12IU//7oeMPehIzvj6j2ipjRMOBlAVBVkGXfLJbN/Iltu+C5JMw8U3Eawdu/eub9NynO7NxA49D02SsAXYLtTFNIJujtaXH2T8vJMYP2ESqqwgIeF6HumSi2W5dOVtLMclbzlEdB3HE7T1lUAIRidDaKqMrigk+zsKaMrw9l/7OhUD+0PgnR+yp97s5NfPbRiMjGlVzYRnHk/u1Udwsz1DxkqyQuzQ87Db12BtXfG+1yLJCnULrkUJJ3DSHbsdHxw7m9pPXo2b7qDz9itw832Dx7alTZ5Y3clFt7wMwNVnzECWJXwhuObBlVz8m0UfqPhZxYCvsCMDn4fbFu9a1X7ZltSgcNk1D64kHtT32ho8M0/nnVdhdayj9uzv7NS4Tj/zO7Iv30nkgJOpPv3re9W4trs3U9qwhOjcM5G1wJBjvpknu/heghMOwWieusdzjtSKTOLtXtiuJ1i4vO39LLtChf2GoK6QDGqUHI/6WLmNzaaePJ15k45cCRA0HnspTjHL5ufuRaa8GfZ9UGRIzD4N2QiTXVx2yg88PVz2rA5Q4m1DfGBPoIUTVNsd1EUg2D+hJ1SakwZ1MQMo61BIRpBzvv1fjJ1+EHf97HIeu+8eNKksoOo4gmRQoytnkbc8gobK2OowTYkg21MlenI2Aa0sfmY677a50K4RQrCpK8/STb10Zop7de4KH09Mx6Mza7I9VaCtr0BPzsR1fUq2N6Q0zHZ9tmdKZEsOr7elKDo+tifY3JunNW3y+pYMW1MlkiGduniAibVhtvaWSIZ1pjUmUCSQJUFHrsi2lMVr29LkCyaoMum8xZtP3MGiW3/G6AOP5Oxv/hhPVkgVLXTZY0zSIFNwaN+4hrbbr0BSdeovvgmtZvRgJ4D3yo4GX3bR3SiRamqnH4NP+f5PlyBT8ljxyJ8QnsvUUz5DwfEJKlAf0YkENWpiOiGjLF62NV1ka0+BNR0Z1nZlcDyfnOUR0uT+IKLA7c/GyZacEZ19+zKVGuwPgfPnjOLOV1oHUyp9AU+/1TVkTOKIiymsepbMS3dQfeo/DzkWnnkC6RdvI/Py3QRaDnjf65Ekmfj8C/Z4fHDMgdRd+AO67r6Gztu+Q/1FNw7pze14YjAl3BdlRWHbG4htvZ0yvrfqM5dtSXHzcxt4YnVZbEqR4a5/OrxS/7kfc9viVq5+YCWeLxCU64N1VebqM2aQKtqD4luLNvayPV0aFC6zPUFfYefK1ydNr+e4KXVc9/Dq3arhesUMnXd+D6e3ldpzryA0Qu2zED6pJ28mt/yvROecQfLEL+4yi+S9kF10D5IWIDrnjGHHMq/ch28VSBz16Xc97+680fubt7pChfdKNKBRHQ2QsVxqozoy4DgeK9vSrGjLYDkuasMkktMOo/2l+2iadxZNddVkSg6ODcl4iNwhp9P14j0oqTZC1aPI++UN9Z6YrTveqwNR60nHX0BjLEBH3sWWXaKajK4INnUVqIvrbO4sgAST66NUh4N8+vu/4s6bvs7C/7qSqOYz9vAz8BEE1AQNiRCqLCFLMj2mS1/RZmtfgaCqUHJd4gGNvdkat2i5PL+ui+ff6qIqolMT0TnnoNEkwsbeO0mFjxW9eZPV7Tl0VaJguoyuCmO6Pr15G10rR1frowF8IcrttnyB4/poskxd1KBkuaSLNpPqwggkFCExvSnO7NFJZEliS2+R7ZkSCJ/6qI6kKrg+dGaKZEolHAkK2SKbn7+bzqf/SGLq4Rz02e/RlnbIWx5RXSUSMNAVhfSWZaz941VIwSj1F92AlmjY5bXt7j4PSVAXg3jIoDvnkGl9E3PrSppOugwtqGHboMlg+9DX0c7mlx5m9vFnk2wag+W4FBwZSZGJBDXwRXm9QZlsySFnOcgStOcsYmNVQoZCznSZ3hwjVfBQZIloQAVJ2m/6Xw9QMbA/BOaOSXLdObO4+oGV+P0GwDtb+6jxeqKzTyX32qPE5p03pCZa1gxih5xD+rk/YLWvw2ictNtzmm1vYneuR000EGiZNSyStav3GQ0TkdShip+B0TOpX3AdnXd9n45bv1N+CCTLrcg0pawevKYjN+KciiLvFXXhZVtSLFzexh1LWofUxXo+/OjRN7nrS4e/73NU+Pgx0GZrRwZU7a9+YCW+EEiShETZ+SNL5TIGSYhdGoUBVWbJ5j4sx8N2d21cu/k+uu64CjfTQd153yM4gviQ8D16H/t/FFY8SezQ80gc+7ld9811HcytK3D6tqHVtBBoOWC3X1hOuoPCm88TO/hslGB0yDGvkCK39H5CU4/aK4qkEgzqLmiKxPmVGuwKFfYIRZZoSpZrlDszJV5a383S1hRvbUtTFVJIF3xM22PyaZ9n8X9+kfaX7mHOF/6dTN4sqxAXbJoOO4eeRQ9gv3Yfkxd8i7aUScGBoFSuu34nXjFDcd0iJFnBGD0TLdGAQVlxXJfAcaGz4CDjoUjlOlNV9Wjv66XY2ceUlml4ko8qyaBITG6q4Sc3/4Vvf+Wz/OEnV3HOlwuccuHnaIwHmNYYJRHS6c2ZbOjI0FssiyApiozj+EQTKgHt/WftmI5HwXRZ257mxbVd5EyHcECjJ+9QdDwS7/sMFT6OOJ7Ppq4C6zoyKJIgpKtUhw1kBUqmRyKskzMderMmiYhBznRwXR9ZlvAEmI6PrEBjPEjB9IgaKqGgSq5g8uvnNhAPqLQkNMZUh8mZLge2JHE9H9/1yRQdhCSD59P6zF/ofPYOEjOPoemsb1CwJSTHQVUlXASposmq1Ut544/XIEeqqL/ohiHBqxGvrbcNs201kiwTmjQfORAZctyg3HGgLh5lRlOczX1FHrvzHrRQjJnHnU2fDV6/cQ3Q/cJtSJLElFM/je25+L7AdT1UTSLR7whriCtUh3ViYZ1QScXxQZNlNvQUaUsVqYrotKVNjphQg64pqP312YpcMbAr/B9wybwWpjRE+d79K1jdPrIhGj9sAfkVfyP9wl+oPevbQ45FDzqd7KK7ySy6i7pzr9zluXKvP0HfY78YfC0ZYaIHnU5s3vko77gZd8TN9tB5xxUERs+k9twrkfWhRrnRPI36i2+k666r6bztO9QtuJ6GMRP5zT8cDMD3Hlg5ogjUtIbo8B++SwbSeneWVtbaV0kJ218YUOQfaHt31X0jl04Idkht3qHer/wjUU7rkiVcb2RD23R9TNfn+XU9Ixx9GzfTReedV+LlU9RdcM2IWSbCc+h5+D8ovvUC8SMuJn7EzntP+3aJ3NIHyS59AL+UHfx54pjP7DbzJLvobpAVooecM+xY5uW7EK6zx7XXe8J158wazBCoZJBUqPDuCBkqXXmT1R05CqZLwfJAkggHdCIBBbl2IqMPOZHtLz/A1lMuRgolUSSJ2qBKINmId+zZrHxmITNO/wySH+vXQGFYOknCybDqD1/Dzb0tPhoYexDJIy/BaJ6GpoCsgC98Sg5Ifr/jTJXYcM8v2Pz6S3zumv8mOnYmuioTUmXqo0E0Gc79959w108u5/7/uQndtzn3R9dRGwtSMB2eW9tFpmDjCtA1mbikIcsSBcujJ2dSHTHec5QrZzq0p0ts6M6xtj1LUFdIFWy6siWmNMRIVNqC7Rf4viBVtLFcn2SorGLvui7Pr+0iXbLozNmMqQ4R0sv92Qu2S6kjQyKoU7Ad6ksBOjIWigIzmpKEdZVNfXl0SaYqZDBrtIwmq0SCcOfi7UQMhe19Lm9t8xnfGKc7U8TxHIqmS9H1qI4ZdKeKbHnkZra9eB+1c09h3FlfwXIVPE/g4yAJhda0TXbdErbddyNqson6BdejRHb+HVra/BqZv9+G1bZ68GeZqmYaP/tfqFoAhXKpSEiHaEilOW6QCKkktrbRsfJlJp/2WdAMwpKHofrkLUGho5XUimcYd/R5RJONZEyHvGmRDKkYukZIkwmoKsmIxpiaKIYqky7abO4t0pwMkrMchC+ojxl05218YHJ9DHk/M6wHqNRgf4jMHZNkU8/ODUElkiR68NkU33weu3PjkGOyESI69yxKa1/G7t68y/PkltyP3jSF5q/+iboLryM49iCyi+5h+y1fJP/GEzsVF1FjNVSf/GXMza/RdffV+FZh2BijYSL1F98IQtB52+UcHksD8LU7Xh0WlR/gjbbMe6rD3rHGeuHytl3WbJ0zu/ldzV3h48mO9dMDn6mtqZHvqd2lLPsCjp1SNyxVMWIoeyx85vS20XHrt/GLWeoXXD+ice07Ft333UjxrRdIHPt5EkdeutNNpdWxnu2/+SfSL/wZo2kKtedfTfNX/0Rw/MFklzywS2EgN9tNfsVTRA44CTU6NGPEzXSSe/VRIgechFa1d+6VkKFwybwWvnrcxIpxXaHCe6BgOaxpz1I0XQKaTF0igK5KjK2OcNzUeuaPq+Xsz38dhM+Wp/6EockIGUzXwxU+00+9FAmJV/96K4YOdWFQVdjRtFSA9pV/x831MvqiG2m67H9IHPVp7K5NtP/lW/Q89FMwUwgBCUNClyV0DeoTBjFDY/7F/0IoUcPvrv4nUmuXMqspQVU0QGtvnpfX92ILlaO+eC3TjjyNu27+GTf+4HuYtsvSLX2s7cySMT1SRRtFQDyoEQ1oRAMqmZJDX8HG3E3pzY6Ue2vb5Eo27alyim7J8omFdEKKytiqMEdPqeWiQ1sIVQzs/1P+r0TrPF8MqaHOmg6ZkoMQgo5sCdfzsVwfIUNzVZiqiEbEUKmO6/hAWJfpyVr9Di2Xl9b3kjUdFGRMx8XxfGrDAWqjBgLB2Ooox0ytQ5EVZASu8Fi9LcVzazr5zTPreeqtDp5e1cWqtgym61EwTVbe8zO2vHgfLUecw9gzvkrBVvAcKNkOHRmPtX02Xauep23h9Wi1Y6i/5KadGtdCCLrvu5GuO6/CzXSTPO7zNH3xFmrOvhy3bxvmhqUEpPI97wElG9IFl1e3ZlnTnufu3/0KLRAmMed0urM2puUhSTK6Cr0v/hlZMxh1zEVkzLL8oSN8YkEdIQQBTaUuGWJqU4KDWqqYOSrB8dMaOOegZg5sSZAM6CRCGvmiS1CTGVsd2m+Na6hEsD9Ulm1J7baOM37oeeRffYTUc3+k/sIfDDkWPfgsskvvJ/PSndSe/Z2dzuFmu4gceApqpAo1UkVw3EHYnRvpe/Jmeh/9BYXVz1F9+r+ixuqGvTdywMlIWpCeh39K5+1XUHfhtcN6YOu1Y6m/9Ed03vk9brn8cyxcfCX6CIbFxLoIG7ryg+rpP39yLV8/cfIebcYHDCnb9VFkCc/b+cP7nNlNXH76tN3OWWHvcNviVh5d2c5pMxsH+7x/ULwzWr1oY+9g/fRAbf+J0+q5/7Xt73puSYK66NAaPQk4fEINf1/fQ8He9b1qdayn666rQZKpv+Qm9Lrhade+VaRr4XVYrSupOuWrRGeftss5tWQjet14Ymd/l8Cotz/TxqjplDYuRbg2kjZyXWFZ8EgQn/fJYcfSL/wFSZaJH37xLs8/wFGTajhtZuNgdPrXz23gb/2aBwOcOK1+j+aqUKHCyPTkbGojBtGgSmtvgRlNMVqqIiRD5frM1r4CE+dM57XTF/D3h+9gzDHnkxJJXCAoBIFYDZOO/ARrX3yE2sMvwArXICkQ1MBxyufwADPTCYqG1DILTZKIH76gvJ9YdA/ZV+5l1aZltJz+JaSZx4AAQ5eJBwOEdIFPFWdc/t888Z//xh9+8FXS2ZuYfOiJaIpEV9ZCViWEpHDWP1/P5JZ6fvGfP2Pz9i5OvewKNFlGkgQhTeXg8dXEAzrdhbJ0U2fWomi7hA2NqrBOIrRrsUnfF2xPl3A8n85MqWwgOT4F26UmYtAYDzKlPsro6jDqftZ/98PE9XzWdmbpyFg0xAymNsY/MCPL9wUdmRK2V+7K0xQP4vkCVZZQZQnblRBAIhxgan2U17am8T1IBnXyJY9MySQW0LB8n76iSdhQEQgQULQ9HM+jKmzgeA6be0uD9dmu7yEERAyZVzb10ldwcHyB7ftkS4KenEtjQsexLFbd8SN6V73IqGMvInTopXSZb/8ucqX+v19/nL7Hfokxajp1n/w+shHa6TVLkoRWO5ZE02Ric88eLN9UQony7z/TRVG8bWDLlIML6ZJNx9YNbHvtOSaddCklKUTJBl2GMB5u5zpSb77M7LMvY9yoBnRFwfJcYiED5HJK+NjaCMmwjhDQlAwO9rWuiwUwHY/qkEFv0aQ75zChNkIivGelqPsqFQP7Q+TePVDZlQMRYvMvIP3s7zFb3xgSEVOCUaJzzyT78t3Yh1+EXjtmxDkkzUA45pCf6fXjqb/kJvKvP07qmd+x/Xf/QvXJXyE8/Zhh7w9POwpZD9B9/03leusF1w2rC9GSTTRc+iO67rqa9ru+T+3Z3yE0af7g8XNmNxEyVFr7ioMG0Qvreli8sZcLDh7NeXNG7dLQ3tGQ8kcwrhUJTphWzz8dM6ESPfs/5LbFrYP1zi+s66G1t/CBOTd2dLLoqsytl81n/vhqdFUui5Go5dr+rx43EYDHVnW8K2Xasw5sois3VKdTwKB43q4wt7xB18LrkANR6hdcN2JU2Ctm6Lr7GuyujdSc+U1CU4/a7byyEabugmuG/dx3TJDknfbJdrM95F5/nMisE1HjQx1ndvdmCqueJXbouaixmt2uQVWkYY6wLx0zgRfWdQ/+fo+eVMPPLzpot3NVqFBh58gyRAI6x09rYFuqRE1YRVEUOnIWuuNRFzOQJYkjzr+MV55YyMoH/pdR519OUFNw7HKa9aSTLmXtCw/T+fe7GX/WV5AlQaZQ3nD329hIqgGeA74LSnmDLutBEkd/mvD0Y0k/9nM23fsTsuuWMOnsr1CyQ5imRU/OQ5NlgoEQc774M17//ZUs/Ol3mLXgm0w58kwEHlnTI6QrhAMaF/7r94hEY9x68y9o7+zmwm/+CEsSHDEhgaGp9BTKRvXGrhxbUyUa4wEaYgaW6xHWVVSlLIxkuWWDRpXKZT4BXcX1Ba4nKNk+WcvF8X0a4gHyhkJdNEBTMkhtdP/e4H8Y9ORMVrdn8T2ftzqy2K7HrNFVe7X+1tpBA8X2BGFdxXQ8ipZLLKhRtD2Kjkc8pA8agSdPr6cqbBDWVYK6wrquLKYjCOkwsTZCpuTgCRhfE6U2ZlC0XIK6Sqpo0pE2sT1BQ9QgVbRZ/lofsiyhyArVsQAF26E7By5iUCm7szvNlntuJLfpNZpOugxlzjmURriWzOKFpJ/9HYHxc6k569u7NK4HSBx5yfAf2uXZ9X6H+8C9bgGyC4YOi+/9DYoepH7+2WQEGP2lIKois/7x3xOIJjng1EuYUBvD9Hw60kVGJwwaEgZNiVB/3brMrOb44O8VykZ/UFcZVxclktMZV0Pl3qNiYH+o7OnjJjrnDHLLHiL17B9p+PRPh6STxg4+m9yyh8i8dMdOo9hqogGnb9vw80sy0dmnERh7EL0P/ZSeh36CueV1kid+cZgIWnDCIdRdeC1d91xLx63fpv7C69Cqh4oYqdEa6i/5IV13/4Du+26k6pR/JnrgyQA8srID1/OHpd/anuDWxa3csWQrx0+t40vHTADKBvW6zhyLNvbSUhXinINGoasyluMPSfVVZIkFh4zm/N0Y6BU+GO5c0jrk9a+f38hJMxr2qkL8wuVtCMr3y6Dat/N2BsStl80fHDPAzy86aNAgf+dnZiSOnlTDQ69vZxeJETulsObv9Dz0E7REE3ULrkWNDjda3Ww3XXddjZvppPbcK1FCcbruvZZA83S0urEjKozvCre3DTVRv9N2XtnFd4MQxA+7cNix9LN/QDJCxHZTv92cDDKjMTai02rumCS3XjZ/SDZBhQoV3h+10QDdeQsEaIqMLEnoqkTQUClaDlUhg9Z0ifqGBk6/+DLu//3/o+bw86FpMr4PTckwciJA/ZwT6Xz1cRqOugApXIMlyirDCuWoltovSOr0bUN/R29drWY0NZf8mNzLd9H799vJbX2LhrMuxx87Ec8GXYeCbZGzJJovvJrSnTfwxh0/oZTPMP3ki1FlqArrBFWJVM7hS9+4Ei2U4A//eS09fX1cdvWvSOVcesMWNVED0/XwhM+YqjBrO7LkSg7JkEFrb4HRVWGqwzoF2yNvOmzsyeG5UB3ROXhMkqAu05YqEguoIMrtAWc2J6iLGiiVqPWHgutB3nRJ5y16Sg4buvKMr4sS20vtL3vzJttSJZR+dW9FhrzpUHJ9ApqMIkmMSgbJlBxypovr+9SEDTRNY0pjnHTRZmtvga6cTTygkTc9qkIa0YCC54Miy8RDGpIkkS7YrGrPEtY1EiGV1e1ZSpaL5Qp6iiVKlkO24GDaApnyvQUgFzOsvfsa7M4NNJz+b8RnncA7C9eEEKSf/yPZRfcQmnoU0bln0PPgj9/znoBUeY+v9N/bO1IS4HZuYfvrzzHxhIsJReO4lodAUDQFpY3LSW18g8Mu+QYTGuLUxgyChsKUxjD4oKkKc8ckqYsF0RSFSGBk0zGgKYyu2r2DYH+h8gT6EDlvTtlolChHiXbm4ZM1g8SRl2C3r6G45u9DjimhONG5Z1J868Wd1mIbDZOwO9YjPHfE41qigfpLf0TssAvJv/EEHX/+9xEN8sDomTRcchPCtem49dtY7euGjVGCMeovup7A2Nn0PfYL0i/dgRBi0DDydhJQ9HzB/2fvvOOsKO+2/516etu+9N5BFBBQsfeusWKJSYym92rypJiYZkw0iYkaY+y994ICKigKCNI7u8v2Pb1Of/+Y3YXDLoglyfO82evzCVln5txTzpyZ+/qV63plfTvn37KUC297i+tf2sSTq1poy2i8szPJj59cw09On8z82cNQJQERkEWBX5w1hV+dM3Vggv8fQk24b5Ty1sXbPvJ4e/bZr2hIcvHf3+a+ZY3cv6yRh95tRJbEXkuKJVu7uOT2t9nUluWxlbt48J3Gst7+HhJYG96/NcuY6gBLtsU/ErnOvvc8XU/+Bk/tGGov+W2/5NqIN9F27/cws3FqLrgWBIGuZ2/AP3YOjmWiNW8EXMuuA4HjOGgtm1Brx/S73sx07pG9Li/bLjWuobh9OZE55/dRFd8TAjD/0GHcdvnMff62ZgyPDfRbD2AAnyC8isSUQVGGVwYYVxck4pPJaiYRr0p1yCUNkgBja4LMv/JLBKMVtL36T7ySwNCIF0mEtG4x8eTLwHFoWfwIRcPNpAjsJgCeOvfZobVs6vc4BFEifPjF1M7/DZZp0Xjvd2hY+hwpyyFbgmTBHS9veak/9yfEJs9jy7O3sfHpWxEch6ZEgZ3xHLsSBdoyRU684LNc9r3f0LBuJbf88NPsbGqiJVVgVzJPSbNoTRZpiOfpzJcoGDrN6QKdaY0NLSne2NROPFtie0eOtc0ZOrIab2zp4pWN7fgVmfF1IaqCXmojPibUh6iLeAfI9X8Q1WEPAVWkNaNR6ZfJGTY5rf+55wfBth06siV2duVJ5nUcx2Fre87tbdYM2jIlKgOebmcQyGomnTkNo9tuUxYFcppJpuTmcysCKoNjPmRJIOxTqAp6cLDxeWSiXg9Rv0plwENRdwioIsmiQVG3iPkVJFGgPuwh5FOwHZuSbpAvmWi2hSKLVIdlIhI4qTaa7v0uRlcjg879EZ6px1Ha67wc2yLx4p/JvP0oweknE5h0NPHnb/rIcwKAbPdnfLWj+13ftugBRMXL6KPPI+ZTGRzxUuFTqQlK7HjpDvyV9YSnn8TOpEbQK1Eb8VIX8XPoyCoOH1uNV1WIBTz7JNcD6IuBp9B/EDOGx3jg83P4zknjeeiqufzirCndXpEuedyTbgemHIdSNYzU63f3IcrhWecgqF7Sb97f7368w6bhGCW0lo37PBZBlIgdeTk15/8cK5eg9a5vUNi8tM92au1o6i75HYLipf3BayjuXNVnG1H1UfOpnxCYfAzpN+4l8fJfcewDEy6xu6PQ/S1ftKmD686ZygNXzeXbJ43noavn/st7fgewf/RUHOyJ7V19xfD2hW88+B7Tr32Zbzz4Xplg2cV/f5trn1nXK1wCbnBmYl2IqUMivXZQhmnzwtrWPn3YPZgxPMa0IdH9HsO2zny/gnz7qzBxHIfk63eTePmv+EbPpOaiX/ZLWLWWTbTd930cy6Bu/q/xDp2C0bGDyOEXE5p+Cp6hk7GyXZjZLuxitnfs/cFM7MLKxfEO76tzAJB+6yFwXBeCPse86J9IoSpCM87o87nBMZ/7/AE8yidjpTeAAQzgw0GVRSqDHqpDPsbXhZk2KEp1SGVozEdd1MfYmiARv0okHOHkS79E84aVKO3vY+GwvStPNq+TkaJEp51AYtVLFDMdaJT75MoVQ5CCFZT6eX/vCe+QSdR/5iZ8w6eTeOVvxJ/7A5pRomiDZoDpgMerMPmC7zD88DPYtOBBVtz/O0xNx3agOZ6nJVkkXdIZeugJnPndP9DWtIOfXn0+776/hvUtGZ5c2ciiTR08uaqJjS1JNjRneHNLF8u2d/HMqhYaEnlWNqVIFzUMy2JDa5qibtKeKbGhLU1dxMeIqgBDK/xUh7z/dV67/9vgUSRmjqjk0NEVjK4JokgChmUdkOiZblqs25Xk5bVtbGhJky7q5EquYFZrukhrugCCg4CA6ThkigbtmRLtWQ1VFAl53PLwnn2JAgh77dYjSwyO+gl7ZFJFg7qQl0FhL6oqoRkOXlWiPuLBsgW8ssjkwVEcwUGWJI4YV8shwyNIAtiO4KpqKzIxn4IqOWSbt9J873ewixlqLvwlSncWek95PdvQ6Hzy1+Tef5nI3AupOPHLGF0NH2tOAFBqWIVSPQJzL40kAL1jB7lNbxKZcSZtJZUNHXkSeYNwUCWxdhHplu0MO/EKgn4V3XTY1pUjkzfwSSIeVca0IegZINYfFgME+z+MPbNA82cP46Gr53LxocM4dkINiiQgCSCLLgGOHnUFZrKF3OqXysaQfCHCs86msHkpWtvWPvvwjpgOokxx6zsfeDy+UTOov+JGlMqhdD7xK5KL7+xDjpWKwdRdej1ypJaOR35Gfv3iPuMIkkzlad8kPPs8cqteoPOJ67D1veN4Hw6vbGjn/mWNA5mz/0WYMTzGF44sF/Pa0Zk7IIX4bzz4Hk+uaiFVMHhyVQtfuGf57hJw02b1rnTZ9g6uAv2G1gyyJCIJoMgip0ypR5V3//fexPDqo0ajSO6kSxT6Euf+Xl2SKHD1kaMY3k+5k2MZxJ/7A5m3HiZ40ElUn/vjfn3li9tX0P7gNYiqj7pLf4faHVkWFC+Jl/9GfsMbxJ+/EbuUI7X4LhILbsPWCh84QSxsWQbQr7e2kWoj9/4rrqjhXr3XhY1vorduJjrvUsR+hNHa0iWuPWsK3z5pPPddOWfg9zWAAfyHIUki0YCrdpzRLMDBo8iMrA4S8MgccuJ5hGuG8Ob9f6ErW6CgWyAKBD0SNUdcCAKklzzYZ1xBEPCOnEFx+woc0+izvuwYfGGqz/sJkSMuIb9uEc33fAcj2YoG6Lit3AVbYvxZX2HI0fNpW/EKK+76Gdua42yN53h1QzPLd8bpyJSonziLU7/7Z4qFAr+46gIWLFzEquYUiXyJZK5EIqeTKBokCyVaM0Xi2SLxbImYVyKgygwO+1yVccFGlWFXskhrqogiinjkj++jPYBPBkNjAUZXhojnDLySSKZokirs/z4D2NGZZcnWOBta0jy9upn3GhNopkVOM+nMlsgUDBRJQhYFArJEyCsTUCUqAiqd2RI5zSTkkfAoEjG/SsGw8aoSob0U5OtjPiYOijBjeIxBsQArGhOsbky6mWnDJuJTqAp5OHhYJQcPizJnZDUnTKolr5nsihfxqTL1ES8mDpIjEPbKOM1raHjgBwiSSt0l15eJkpq4VSR2KUfHwz+huGUZseOuInrkZQiC8LHnBFYpR6lpHb5RM/tdn3rzPgRPgMCh55AyIKtDyTQIYrHx+X9SO2oSU+edQmXAR9Aj05YskS4aaJaDIrnicRWBT6bE/78JAwT7fyEeW7mLBRvasRxXuOvYCW6Zp2/0LDxDp5Bacj+2Vt7REZ51NqI3SPqNe/uMJ3r8eEccRH7jmwcUCZPDNdTN/y3Bg04m8/ajdDzyM6xiuVe3HKqkbv5v8AyeQNcz15N596k+4wiCSOzoK6g44YsUty2n/YEfYuU+nDXXnnAc+MlTaz+0vdcA/rX4wakTy7zNLeeDBfxWNCR5cV1b2bLOnJv12PNVIgAHDYlw0JAIPXaulu1w3owhfOtElwjOnz2M+66c0/vf/fULP3iVW+0giUIfQi0AqiRwyexh/OqcqXz3pPF8/oiR3Pb6dhr28lO3Szk6Hvkp+XULicy7lIqTvtJvH3Ruzat0PHYtcmwwFSd9BTPT2bsuPPNMYsd8Bq15A6EZp1Nz3k9d0qv6KGxdtt/rBlDY9CZq3dh+Vf/TSx5AECUic8v7qx3LIPX6XSjVIwhMPqbfcS3bYV1LeiB4NYAB/C+Cbtmkiwapgk4qb6BIAhGfwvtNSRoSJYYcfwW5tp00LHsJwzJJ5A1yJQslUk1o+ink1izot+XLP/4wHL1Aced7H3gMgiASPfxias7/GVa2i7a7vkFx+wrAzYzrJnRkDAKz51N14pfIbFnO5rt/RCGTZFtHiVWNcba05djYnESoHsNhX/0Toi/EPT+9moZ3XqMja6LpFnndoj1dxDBMTMtGkiRaMyVM22FwhY/BFUGmD4sgIpPImVQHVIqGRbqoY+/DFnQA/37IskhNWKUioFAX8ZEtGhT0fZeJm5ZNPKfRmCjRnimQKGo0dhVoTBQoaCatqRK6YVMwbHyKyJjaEBMHR6kMeigaFh7Z7f2tj/ioDLrB41hAZUSlu2zv9kuPLDE45ifqV0kVdQQELNsGAbyKiNnt5Vwf9TIo6qcu4iWR03lzayeZkoVXkWiIF+lIGwiiQ9vKF1l48w8J1Qxh5BW/R6kaWrY/Cyhlumi/7/toLZuIzrsMZQ9R4o87JyhufgtsC/+4uX3WaS2bKG55m/Ch5yB5g4A7jyrp8Oaz91NMdvC5b17D7OEVRAIqUweFOWholBFVQRzHIadZ+FRpoDLkI2CAYP8vw55q2Zbt8NrGjt51giAQO+az2IU06WWPln1O9AQIz/4Uxe3LKe1hPN+DwMSjsDIdaM191/UHQVaoPPkrVJz8VUpNa2i76xt9erxFb5DaC67FP+4wkq/9ncRrt/fbMxI65DSqz/0RRryR1nu+9YG+3fuD7Tg8tnJXb5/u/csauewfy7h/WeMHf3gA/xKsaEiyub08APNad7XBvra/5Pa396nwPbom2Pu3A9SGvVw4axgeZXeW+lOHDCkjgh9U1TBjeIzBUR/GXu0HogDzZw/jgavmct05UxlfF6IlVeS217ez99EZqTba7v0upab1VJ72LaKHXdTnpeM4Dum3Hib+/B/xDp1CYNLRdD5xHbn3XqC0a0PvdqHppyCoXoRuBd+eXmlR9fV7/L3HEG9Cb9tKYOKRfdd1NZFft5DQwaf16QXPrnweM9VG7OjPlAUEAmp5cGBvFfUBDGAA/1nIokC2aKAZJpIkkCkYGKYN2JiWQ820IwgMGU/za/chOwZRn0J9zIdfFRh6xAUIskKqn/Yx34jpiN4Q+fWLDvhYfKNmUPfpG5HC1XQ88jPSbz9C1nLImpA0oAQEDz6V6rN/iNa+nfW3f5dCVxuWBbZjUTRtcBzkWD2TP38D4SHjWHPfL+lY8pAryuiTCftUfB6ZsN9DZdDDmJoQXlXBsKA9XcSyHKYODjEkFqAq5MGybRrjeTZ3ZIjnNBI5neZknlxRp6iblHRrgHz/m2HbDl05A1mS6MiWaE2XaE4X2N6RwzT7vve7cjrZkkF1UCGv25Q0k7BXAgQ8soRPEQn7ZHTTJlVwgy+241Ab9hLxqYS8MlG/itKtOt+DDyKFDrArUaAjVyRdMGnPaET9KsMqA1R1txvops22zgzrWzKUTJuufIl3d8QRgYqAzNqnbuXl237FqIPmcMYP/kYkWtFnP3rnTtru/Q5mpoPYwSeTfvvhT2xOAJBfvxA5Wo9aP67PutTr9yD6woRnnFl23rKRZvurDzD6kHmcePxxnDhlEN84diynThvMiOoghm2TKZmEBnquPzI+sSsnCIIELAeaHcc5/ZMa978Nc0ZVIgoCdnem2bId2jO7S6s99ePwTzyK7LtPEpp+SpldVuiQM8guf5rU4ruonf+bsoeLf9xcEi97yb2/AO+QyQd8PKGDTkKtGk7nk7+i7Z5vU3nqNwlMOKJ3vSCrVJ31fZKv/p3su09iZeNUnfZNBLm8nMQ/Zja1839L52PX0nbvd6k+6wf9lrjuDwLuZOOR5U2YloMoCr29s29s6eKet3aiyiIXzhrWpze7P6/mPT2VgQFF5I+It7fH2Xv+0pbReu279v4ueoJI+0IiV07yXlnfzutbOvnJ6ZNJFnRifrW3z3rP72pvj+y9EevHV/W4ibVcd87U3s9fcOtb/fZjl3ZtoPOJX4JtUXvhtWV2eT1wbIvEK38jt+pFApOOpuKUr5J99ylix3wWBJHChkXI4crezLN/zGySr/4dOVSNreXQ27YQnH7yPq8LuJlxBJHApKP7rEu9cQ+C4iE8p9z32irlSC99EO+Ig3t/c9VBlYhPQZVF1rfuDo7s7QM+gAF8HAzMCz4+ZEmkPuojkdXxqQKGBdmSgU+RCagiAVlhwllfYMXN3yS3/Gn8R87HKwv4VQG1uoqqWWfSufQR9DmfQq3Z3c4jSAqBSUeSXf0yVinXm936ICjROuou/T3xF24itfgu9PbtVJ7ydUTVbZNxcLPjdYFf0v7YL2i89zsMOf+neEeNRZUkuvJFTEsgEIkw7cpfs/2xP9Lw0j8pdu6i6sLvkNV0ZEekJqbgU1WmD68k5JUxLBsL15JLt0uMrQmAIJAs6GSKJmGvzI6OHLIooMgSqYJOzKeiOzZDY35GVAbwKLszcYmcTkMih1+VGF0VwsKhszvAWBX0IIkuuVIkEVUeyEV9GFiOg1+VGVYRoDNbpCWlYRkOjQlXn2VUTRDbdudwtu2gmW4Wui7qZ964ahJ5jfZkAUmAFTsTqKpAddBLXchDumSRLZmkigaDo358qkRLWidXspAlgUFRN2Nt2w6ZkoFlO4R9SpmtVA9EAUZW+7Ec1zt7eKWfqE+mqJu9x/bG5k7Wt2XAthldFaAhniOjmfgwWPyPn9O48nWmHn8eR13+TWxLpCJok9EMemYxpYb36XjiOkTFw6D5v6a4fSUVx3wW5xOaE5jpdkoNa4gcfnGfgEKpYTWlhlWMPvVKAkE/JQNUBXwStL7yMKZW4uKvfI9EQUczbDq7heFifg+6aTMk6qMu/MEEfwD945MMTXwd2ACEP8Ex/+swY3iMa8+awk+eWotlO719p3sidtTlFDYvJfXGPVSd9q3e5aLqJXLYRSRe+Rul7cvxjZ61xzofgYlHkt+wGPu4zx+Q114PPIMnUPfpG+l64ld0PfUb9I4LiB5xSW8mTBAlYsdfjRSuIbXoDtpzcarP/TGSr/xW8NSNoe6yG+h47Fo6Hv05seM+T+iQ08sjjvTfEwvuw3BMTbCXEOxNhDa0uctX71rDok0dvfZCv3l+A7e8vh1wifjCTR0cM76Ga591RbRkUQBBwLR2+ysPkOwDx5xRlXgUEd2w+2R9X1jbyvi6UC/xfWVdG0+uakbA9S53oA85T+zVq+XgipclCzpzRlX28cKeMTzWr0f23t/hk++Vl62LuPfbNU+s4VOHDOG3L2zol1zn1i0k/sJNyOFqaj710z72dAC2XqTrqd9S3L6c0KHnupliQSA080wEWcXo3ImZbie76iWih12IIKuoNaMIzTid/IbFOJZB1dk/RInW7fM6O5ZBbs0CfGMORQqWn5vWsonC5qVEDp+PtJfISXrpg9ilHLFjPtO7rDOn05nTy7ZTJIFzD+l7bgMYwMfAwLzgE0Bd2ItlOxR1C0ewKRo2HkVkSFUQB5g6/Ehyy49l3UsPUHHwSWSsGCUThlYECJ5+KS+vfJ7U4rupOf9nZeMGp51IduVz5Ne+Rnjmmf3uuz+IqpeqM79Hpm40qcV3Y8SbqD73x2XPL3XIZOouvZ6Oh39K430/QDr3e4ybeSTxvI4k2GDLlGyRUed9D6ViEDtevY+3km3M/szPqayrJeBTOGFiLZVBH53ZIm3pApphMyjipTVdwhEEOrMapmmTLBp4ZZF4TmdopR+fItGUMLFsB1GE1U0p0kWDQVE/FQGVrmyJVbtS1IQ8ZIomumHjUSR8ioQkCrSmitD9fhKA6qAXSRTwyCLiJ+jn/P8rFEkk6pcRS1DUVRxK5HQTryyR1QxaUgU008a0LJriBdJFg9qwh9qon+qQyobWLAXTQRIcvKpIfcxPe1qnKqhSFVLxKhI5zf1+NdNCQsDnkchrJoZlI4kSibzmqu6LIgXdYkjM15uRzmsmOLCpNc3q5jSKCINrQoT9Cpvbc2imRU3IC47NmuYUhYJBU6pAV7qIIIsMUvLcde3XSDdvY8b5X2XOmZfSkSmRKhoIuNpJqgOJta8Rf+FPyNF6ai/8OVK4Bn/FkE9sTgCQW/0yAMFpx5ctdxyH5OK7kEPVBA46Fd2EmB98XhU72ULz288y7sgzGT5qHG0ZjdogFAsaTYkC1QEPgyt8HDwsijygyP+R8YlcOUEQhgCnAbd/EuP9t2P+7GFce9YUhlf6+yWccqSW8MyzyK99rY+oWfCgE5GjdSQX39VHnCw4/RQcQyO39tUPfUxysILai39NcNqJZN56mM7Hf4mt7VaLFgSByOxzqTrz+2itW1yrr2RL33HC1dRd8jt8Yw4lueBWEi/djGPtJlRnTR/Evt5flkNZtm1/eHl9Oxf//W3uX9bIbW9sL1v3yvp2fvLU2t2CWpZT5q987TPruOaJNQO93geIHjusi7p7nPfE5PpwrzL4Bbe+xS2vb6cto2E5EPEpfcj1nqgLe/qIl+3ZQrGnYvi+lq9oSHLV3cuZ9ctXeGdn+ffpCO59cv+yRi68dWmf79uxLZKL7iT+7A14Bk+k7rIb+iXXZraLtvu+T3HHSoIHn4rRvpXMWw9T2PI2ouJFEETUmlH4RhyMo+XJrnrB/aAgEJh0NJWnfYuaT/3kA1+khU1LsQspQgeVR7R7XqSiP0J41tll64xkK9kVzxKYenxZ9qo/HDO+ZiCwNIBPDAPzgk8OPWTCchzyJYuWVJHGVInqoIeJQ6JMrI1y4qe/gWUarH/uDvKagWnbeGUBXfBRMed8t32scU3ZuGrtaNT6cWTfe/6A9Fn2hPvOP4+a836Klemk7e5vUWxYXbaNUjm0+7k5jO0PX8f2N59AkmUQBUqGQ23Yh6QojD3ls0y66AdkGjfy+h+/xK7tmynpblZTFgVa0iWakyW2dWTZlSwiiiJ1YR8dWY2qkFuNk8wbDIp56Uhr7OjKo0oCRc2kK12iMZGnLVNCNywa4nk008YwLQqaSaaoE8/rdGY1MiW3TzhdMEiXDBRRoDNTYnVTip3xHM2pArZ94NZJ/82oCnmpj3ipDHoYVxsmU9TJlAwqAjJ53cKvSLy7I8n6tgw74zne2BonVzTY0JymOuB+p1s786SKFp0ZAweHobEAXkUmr5mosltZ4FUk93ehuVlnuXsOoplu0MSriGimRVG3yBYNNrdmeHVjG0+s3MWrG9oZFPNiOwIeB4qaTV4zkCWBZF5jQ2uKVQ0J1rdn2N6RpzVbZM3KFfz9u5eS72zm6C/9mlFHn0d7WkOUJCp8CmGvhE92iC95gPhzf0CuGIzsD1NYuxD7E54TOJZB7v2X8Y2e2UeTpbB5KXrrZoYcNx/TUVElkGWZqEei4aV/oHq9fPpL38anKkysC2Hb0JIsIQsCg6N+QCD7Ee3VBuDik8pg3wh8D9insaogCFcBVwEMGzZgrbQ/rGhIcu2z69AMe5/Z3Mjc88mteYXka7dTe/Gve7PAgqQQnXcZXc9cT379IoJTjuv9jKd+rPsyXfkcoUNOQxA+XHxFkBUqTv4qau1oEq/eRuvd36Lm3B+jVO4WdAhMnIcUqqTz8V/Sds93qD7nGrxDp5SNI6o+qs+5htTrd5N5+1E3+n32D5ECUZ5a1bLPc+73mHr+cfoGInTT5qF3G+lv3mB2M7u9Axg2sHpXmtW70jy4rJHjJ9X2ZsIHsG/MGB7j7e3xsknaiZNqyWjm7vt4ry9i70z13jh7+mBOmFzXp+xblUUM0y5TDJ8zqrJ3uSQKtKSK3L+skf95as0+vdf3PJy9K9ZtLU/XM7+nuO1dgtNPpuL4LyBIfR+Xevs2Oh69FlsvEJ13Gfn1C4nOuwzH1Im/+GcExYtvxHT3uAeNx9ZyaLs2kFhwK0aimaozvrtfP+o9kV35LHK0Hu+oQ8qWl7avQGt8n9jxV/epTEkt+ieCJBM98rLeZfuqEqkaKA8fwCeLG9nPvGBgTnDgyJZMfN0Z1rZ0ic5MEY8osL0ziyAIBDwKoapBDJ57JrvefALv1NOoGTGK9nSBZMnAd8gZSMufIbnoTuou+31Z1VjokNOIP/dHSjtX4Rt58Ic+Nt+oGdRd/gc6H7+Ojof+h9ixnyM048zefUjBGLUX/5quZ65n81N/ZUhXC2NOuwrdcfAqMrppki9pRKbMY1Koms0PXsfCP36Jum//mnX1p1EZ8pAt6iiySNinMCTiBwFSbaGzjQABAABJREFUeZ2IR8KnygyKSq5ThABDYyIFwy0Zb4jn2REvEHCgqFu0ZTQqAgpBr8zQWIDGRB5FFvHLAls68pSsDNOHRlFkiWzJoDVZQjNNIn6FnZ06guCQyOmMqwviU5UPuDID8MgSiixQH/G6fdI+hVzJoCWt0SbChuYUWc1AM21KhoVfFujKanhVBVGAmqCXMbVBZEki5ldAEBgS82NYbum+JApIosSgqA/TdvDIYm/GNepXaM9oFG03gdKe1ejMlGjPFskWTTKaRke2xJZmiBdNdnRmCHgVknkdURSoD3tJFw1ymkmqZGBaFptef4al9/4BT6SK2V/5DSPHT0B3BEq6Sb5kEFIVNMOk8Zk/kFi9EO+Ig7GzXURmnY1k6jS/+GeqPsE5QWHTEqx8ktDBp7ktlN3Ldcsk/frdKJXDiEw5loLhZlNVWcIT30jDe29w4qe/zpjhg5k8xM1S245A3jQRgZxuURFQ8SkDyvwfBx87gy0IwulAh+M4K/a3neM4tzmOM9NxnJnV1dX72/S/Hj3ZuP0RTdETIHrEJWhNa10FwT3gnzgPtW4MqTfuxTHLy0DDM8/ETOzqVQD9sBAEgdAhp1F70XXYpRytd3+bwl72X94hk6i77AZEX5j2B3/cb8ZcEERiR11B1RnfRW/bQutd30Rr3bLfc5ZEkCXXH7znfw7uTTx/9jDGVAf6fKYm7MWjiO72+7Bo2lfBl42b4ewvuzmAvughuVK3EufR42t4dMWuDxUw6YEAhHxKH/Gynmz53orhPcsvPHQYCAIPvNPIj5/YN7neH4yuJlrv/jbFHSupOPFLVJ70lX7JdWHL27Td930QROou+R3eYVNQ68bgHzeXwKSjqDjuKpIL/4GZbgdAVDz4xx2G3rmTwqalhA45/YBfpFrLpl6F0T0DY45tkVx8J3K0jtBevVqlprUUNi8lPPtTyEFXeEUU4LpzpjJrRKzsvpclgU8NlIcP4BPCgcwLBuYEBw6fIpLTTHKaybCYl7BPwSOJ6IaDbVoUSiY7EkUGH30hojdIx2u3U9IdEkWDbMl99kTnXYLeuonCpiVlYwcmHIkYiJJZ/uRHPj6lYjB1l/3erUx79e/En7+xbO4hql6qz7mGmtlnsWvJk2y89+dUe6BkWOimiWY6lAyHwNAJTL7qj/irh3Hvr77BnTf/nrXNCdpSRbJ5ExAJBxRqwn6CHpnJg6NUhzzE/Ar1YR+KJBH0yUR8KgGvwkFDYgyv8DNxUJjKoEpAlRheGUA3HWojPk6aVM9BQyKsb8+hWRZeRaQ9VyLkkRlRGcDvkRgWC2BbEM+VSOZNNrVleGdHkqJu7fuCDAAAURSoDXvxqRL1ER9eVUKSRCoCKst3JpBlV6+lIVGkaNjs6CqgSCKZooZmmQyvCjC8IsDwCj9+VUYW3dbAHnLdA68iEfTIZX3WAY/CsAo/dWEvfo+MKgm0p0vs6MzTnMyzsytP0TB5e0cSAdddpDldpClRIp3X2RHPkcnrqIqMYFtsf+5vLLnrt8RGTeGY7/wNb81QTAREHCpCKn5FRiukeO3P36JrtesyEjl8Pkr9WMLj5lI59ShqPsE5geM4ZN59CrliMLExhzAoBNUBCHmg9P7LGIlmhp7waSxDckXNJPDKDq/ddSMVtYM47rzPkCwaVIe8DI76OWJsNRfOGs4R46qZWB/ioCERov3o1gzgwPFJlIgfDpwpCMJO4EHgWEEQ+npFDeCAMWdUJbK4m0juC8GDTkKpGkZy0R1lXpaCIBI9+jNYmU4yK54p+4x//OFIoSoy7zzxsY7RO3QK9Z/+I0qsns7HfkFqyQNlCuJKrJ66y36Pd+gk4s/90fXT7kdhPDDpKGov+R0I0Hbf98itWdBnG0l0r4O4R9RdkgSUPcqHzz1kCL897yBUafc2siTwhaNGc9+Vc/jOSeN59AuHMW1IpJ/xu3tmZBFZ6nvFTfuDbacGUE5+f3L6ZF5Y29qtdnvg6OnN9igiMb/aqxa/9376UwzvUQp3FUbp0w9+IChsXkrrPd/CLmWpveg6Qgef2mcbx3FIL3uUzsevQ6kaSt3lN6DWjOy9v62Cq5kQmHQUvpGH0PXMDb2fzbz7FGamg7pP/xH/mEMP+Lgy7zyB4AkQnHpC2fL8uoUYnTuJHnl5r/oodBPvV/+OFKoifOg5vcunDY6QLOj84JSJPPrFw5g/exiXzB7GQ1fNHajSGMAniYF5wScIjyxS0E3a0q7XdSKr0ZgsoCoS8aJJwTCI+WRUf4Tao+ZTbHif1KZl2Db0zAwCU45DqRpG6vW7ytqyBFkhdPBplLavQO9s+MjHKHr8VJ9zDZHD55Nf+ypt932/zJ5QECVCR3+euhO/RPvGd3nld1+krWknAuBXJSwbHNth0OAhzLj69wyaeQKLH7yVP3/vi2xtaqUrX2RIzENt2MukIWFG14bI6yYdWQ3NtIkXdEKqjGbYyJJAzK8SDahMqAtT4fcwojLIiOoAfo/MoKiXQVEv4YBKVdBH2CNTF/bgUSQkQUAzLdJFk0ERH9GAQtSnEPIpqDKEfDJ5zWBXsrDvizGAXnhkieqQF79HIpXXSRR0SrqFYZmEfB5s2ybsEfHKIn6PjCgL5HSLzqzBks0dFDQNVRapDCgUDZumRIGmRAHN/OAAhyyJ+FQJcNjemaU1nccju3Q66pWZPCTGyKoAiujQkTEwdBvHtgl4ZTQdBElCLCZZc/v3aVn6FPWHncW4+T+jS5NJFk3yJZOGeIGudImOpi0888vPk9m1hdozv0f0sIvomYUUC2kyFngnHUXoE5oTaE1r0du2UH3oWfg8IpV+H0GvQlA2SSy5n8CwyVRPnoslunMqnwTN77xMZ+MW5s3/Ou1Zk8Z4ng2taXYlC+Q0E68iMbIqxOTBUSq7VdQH8NHxsQm24zg/dBxniOM4I4CLgNccx7n0Yx/ZfyFWNCS5eeFWXlnXhtktcLa/7J8gSsSO/Txmqo3MinIfat/wg/CNmkn6rYexipndn5EUQjPORGt8H61l08c6XjlcQ+0lvyMw+WjSb95H55O/LvPnlrxBas6/luB010+78/Hr+vh3gyt+Vv/pG/EOmUT8+RuJv/SXsoCBZXcLXVkOptV9XfbyQp4xPMaM4TEe6PY73hdpuHBWeSmiJMKVR4zksDFV/OyMyTx01VxOnFTb5xi3th9Y7/d/O2YMjzFnVCXXPruOJVu7PnT22gGmDI5wxdwRXPvsOm54eROX3P72AVcQ9ASnPix6+q07n/gVSuVQ6j99U5/WBsAt/X7+j6QW3Yl/whHUXvwbpIB7j3mHTMbWCmT3CGrFjv4MSDKFLW8DEJp5BoM/f2tvRvlAYCRbKGxeSmj6KWUl4LZRIvX6Paj14/BPmFf2mfzaV9HbtxE7+gpExdu7fNWudO81BfjVOVO57pypA+R6AJ8oBuYFHx8lzWRbR5btHRlWNaXJlExyBZPObImiaRH0KoiCzcbmNO/siLMrnifqkxg37wy8VUNpe/UfWMYeRFqUiB39GcxkK9n3XijbV+iQ0xAUD5l3Hv9YxywIItEj5lN97o8xErtoveublJrW9q43Ac/BpzL0/GvRMl2897dvsP3994hnTTwyRP1eV6jKFoie8DUGn/wFWje8w2M//yzbN65nbXOGFTsTrNjZxfaOHDu78rSniqSLBo7toCoiI6oCDIn5USQRQRCojfgYVxdieGWAqF8lXdDZlSyyK1kkXdAJ+RSmDY3SmtHIFd0yX9elxA2+10V8TB0aZc7IKmRZpDlZpDlVYlNbmmxp/21OA3DhOA4dGQ1ZEpBFkbymUx30oJk2lX6FifVRqkJe93sTBLpyGl3ZEmta0izc0Ekir2NYkNcsgl4ZQYBs8cCuvSiAVxZoThYQRYGcZmHZAnURP5pu4fPISLJMTVRlaIWPoE8E2yHolWja8B5P/fwzJBs2MuOyaxh75pcxHIm8CY6NG/DSDN5/ezGL//BVLEOn/uJf4+l+H+89J3CAik9gTgCQWfYYoj9CcPJx2DaI2FSG/XQufRQzn2LkyZ+jZDp4ZAh7QEJn0/P/oGb0VEbNPBoHCPtVknmdfMkkWRjwkv+kMSAP978AKxqSXHDLUj71t6Vc/9Imbnl9+37Fn/aEb+TB+EbPIr30IaxcOQmJHv0ZHL1IeskDZctD009G9AT6eGl/FIiKh8rTvkXs2CspbllG2z3fxojvzvYKkkzFiV8mdvzVFLe9S9u938FItvYZR/JHqLngWsKzzyO36kXa7vsuRqptn/uVRKGPF3IPBkd9nHvIkDJ16R5SMb4uxBeOHNVbGSAAdyzdyZKtXVz77DoAbrt8Jr86Zyp7Bu/e2Zkc8No+QOwpONaj/P5hsHpXmtve2I5m9BUt+yBsastSEfhwZU1WLkn7Qz8ms+xRgtNPpm7+b5HDVX22M7Nx2u7/Afm1r+EbMwfPsKk4esF903aj4virKWxdRn79IuxSDgC1aiiiz62c+LC6B+C+SBElQnup/GbefRIrFyd2zGfLIs22ViD5+t14Bk3AP/GoPuPZDpQMm8cHqjIGMID/lehIF3nivV0sWNfKS2tbWbK1g5ZEnm2JHM3JEomcgVcRaEwUsAUbjyjRkdfdibasMua0qzCSrXS98zQeoKe2xTtqJt7h03qdBXog+cIEp51Ifv0izHTHxz5+/9g51F/2B0RvgPYHf0RmxTNl+hzCiOkMv+wPSL4wOx/8Ecn3XyDmUxhaoaIINiULbFsgNP10xl7+GwytxAu/uZoVrzxBzrBY15JhW0eWjGYgCJDXTAqG2Z2t3I2ibtGYyNOaKVHQXcGmVNHAr0r4VYlUN0kbWR3iqPE1nDilnvqIl7xuUeFXkQQBw7RRZYnaqI9Dh1cS9SuMrw0Q9aok8uUteAPoH44DpmXjkSVqQh7G1ISpjwYIyBKDK4MMq/Jx/MRaTp1az+RBMbyyhAk4CDQm3Yz1zniekmGS0wwMy+nXdmtv2LbNhtYM61ty5Eom6byBYZgEPBIeWUARoSKg4FdFJFsgqxnkihYt6RLvPncvj/7yagRFZe7XbsI3cR4ydm9FiAOkixbNbzzMtgeuRfCGCM04EzFSA45Nz52495xA4OPPCfSO7RS3Lyc84wxs2YPgQEqzaG5sYsurjzB29vGcdvw8RlYFqA17iIS8JJc8gpZNMnv+N+kqWnhlkbBHoT2tsbE91R2oGiDZnyQ+UYLtOM6iAa/LD4cVDUkuvO2tPgrHHwaxY6/EMQ2Sr99dtlytHu7acLz3PEaiuXe56PETmnEGxc1voXfu/Mj77YEgCIRnnU3thb/EKqRpvftbvdG53vUzzqDmgmuxcgna7v4mxR0r+44jSsSOvsKNfidbabvz62Xj7InzZw7tQ6z3JtM9vsh7q0uHfEovebZsd/neRG7+7GFMG1xeTv7C2r6BgQGUY0VDkuZUEVkSEQW3rP/4CTVUhcpJrwBcMnsYVcH+yXDPM15kt4J4T4XHioZk2d89uH9ZI9c8sYa2jNbvmP2h2LCalju/it6ymcrTvun2W8t9xWu05g203f1N9M4G1EETcIwiRttWEq/d7irwdiv2y+FqKo77PIVNS0kuvov4SzdT3LHygPuq9oaZ6SK35lWC004oi3BbuSSZtx/FP+6wPpn29FsPYedTxI6/ar8lXvcta+Squ5cP6AsM4F+KgXnBh4NmWuxM5CmZFkXDpj2jIUsimZKO6LiEQFVFNN2iI10iVzIpWTaCKKHbDh6PROWEmYTGzqL19QdxCklUAYIySIJA7JgrsYtZ0m89XLbf8KHnAgLpZY99IuehVA2l/vI/4Bs1g+SCW4k//0dsY/ez2aoYTM2lvyc8ajrNz93M6kduZHNzkmTeokuDIlB0QK6dyMFfvgn/kAksufNX3Hrtd1i2sZll2ztYtrmDhkSeqpCHEVVBFFFAM6xeMp8s6HgkkYAikcy75MEjSxR1i6LhqpSDGwiWRYG2dJFEwSDQbflkOQ6qLFIyLDTTIhZQCflUOnM6bZlS7+cH4GLPIErP37btUDItAqpMV06jKV4gqxsu4VZkhlQECHq9zBxZybCqIIeOrmB8fYiwKjM46kPAFS19d3sXed0kUzSI+mTC3YJp8ZzmCtKlizQlCpSM3aXjBc2kM6tTE1KxHIfmTIl1bVlyJZO2jFu90JF2v8sdXRk2tKRpaG7jxZu+zxv33kjlxNkc8a2/ER06BhzB9csWQHRAMEo0P/k7Wl+7GylUiRytxUy1knjtdrT3nico9D8naH3pZkofck6wN1FLL30YQfUROuR0TCBnQGvKZP1z/8BxHKae/Xk0E2ojKmPrI0wNFtnw2iNMmHc646dMI+p1Wx4k0UHApqGrxIqdcd7c0sHmjswAyf6E8En6YA/gQ6CH/DWnipjWx7uZlYrBhGeeSeadxwkdfAqe+nG966LzLiG/YTHJhXdQ86n/6V0emnkmmeVPkV76ENVnff9j7b8H3uHTqL/iRjqf+BWdj/+SyNwLiRwxv9cv2zdiOnWfvpHOx39JxyM/I3rU5YQP/VQfEuAfO4f6K26i66nf0Pn4LwnNPIvYUVf0Eh9RgM6sxtHXL6Qrp3H8xFpuvOjgMjKtd5PlPdWl91SdlkQBu7sMTBQETMtBknb3/c4ZVcncUZWs3sODfHL9gJVrD3ru357r+fb2ODG/2ust7lqLC5i20+tBvidOmFTLdedMZfKgCNc8sabPenAjxJIo8JPTJwNw8d/fxjDtXqE703bKPK8fevfAKwwc2yK99EHSSx5ErhhM9YW/RK0e0Xc7xyG3+iUSr9yCHK6i9sIfkF3+NNVn/wCAwpZllBpWkVv1AqFDXA7hHTYNKViJmWxFa95A3WV/+MgEO7PsEcAhMvu8suXJ1+/GsUyiR19RttxINJN59ykCU44vew7sCy+vb2fR5k4e+PyA9/sABvCfguM4FA0Lx3GDjx5ZIuCR6cyWEIGhMT/ZgoFHcUgXDQzDYFtnnqqQQlHTyegmk+sjZEoaAiIFW6Dm6CvZfseX6Vp8FxPP/waSLOEv2RTrR5GZehyZ5U8TnH4KSqwecIlAcOpx5N5/icic8/ut4vmwED0Bqs/9MeklD5Je8gB6x06qz7mm135I9AaJnvMT1Dfvof2tR8m27qDizB8ihSp7xygCXjHG0At+QdeSB2ha8hBdOzaQuvx/qB85jiEVJSRgS2uGJds68cgy04ZFmTWiEllyValTBQNFEhkU9VEd8pApupQl7FPQu/2Y29NFknkTBIfBEQ+GJVIV9JApma5vMlAZUBkc9VE0FDyixECLKr1e1CXdIlMyXb0cUaBoWHhlAdOyKZnuHFc3bWoiHnZ2FbBsAZ8qUhFQqA578XQrVleFfHzmiFEs2tBJa6ZITrOI+RVykquWL0sCkiCQK5k0JvKAQ6poUtctptaWLjG80k+maNCcLBDPa+Q1nZJhM6Y2QDKn054p4VclZBkcbCI+me2Gw6b3V9P4xO/QM10MPelzjDjqPGSfAohUh7xopontCIi5Ftbecy2lzkYqD78YrauRQWf/EAEobllGrmEVu1bse05Q+yHnBHvqyeidDRQ2LSE893xEr1sdaAL51i3EV7/KoHnnQdBtcxxVEaA24ufmH12DR/Xwqau+g8+vEvIqHDqqgqJu05YskS6ZtGZKrl1dRkMRRUbXfLQ5ywB2Y4Bg/wfQk2nVTbs30/dxA0aRwy4it+41Egtupe7S63vLTqRAjMjcC0gtvovizlW99gCSL0xoxhlk3noE/bAL+yUXHwVyuIa6S35H/OW/kX7rIbTWzVSd+V0kn0tOlWgddZdeT/yFP5FadKebOTz1G33shZRoHXWXXE9y0R1klz+F1rSWqjO/h1IxGNtxiUEPnlzVQiKvc/KU+t7raDsQ86u9wlt7ksFbFm/D6A5qWDYg9Pxt8z9PrcXpjlqfe8iQXqXyHlXr/3asaEjy+MpdPLK8CdN2XEsMx+m1Pdvz+gv76MCWRIGrjxoNuJUCjfF8vyQcXBL9q+fX41Mk9G7BNGOPgFRP1cGmtixrmtP9jtFnzEwnXc/egNa0lsCUY6k44YuIqq/Pdo6pk3jlFnLvv4x35CG99hmJl/9Kbt1CgpOPwTtiOo5tUmpYTWnXerxDJmFrBeRIDUrFYHyjZx7QMfV/nF1kV79EcOrxyJHdHpd6+zbyaxYQmnUWSmxQ2WeSr92OICvEjvp02XJJcO/j/p4zPcGoAYI9gAH8Z5AqGCQKOuAQUGXqwl4SWQ2/KhDzu2JDZthCt93fq27ZbGjL4fOojK5XyBUMhlX72LBLR7dtbMdh0PAh5OecSduSJ+iaeSrBIePwKAKlEkSPvJzCxjdJLrqDmnN+1HsckbkXkFuzgMyyR6g44YufyLn19GWr9WOJP/N72u76BlWnf6f32SiIEjVHXoG3djTNz91E613foOqs7/dW5jhA0gJHkIgccSm+YZNpefoGlt30Zaad91Um1M1nyZZOtrZnaS9oDAn5yWgGg6I+6sM+WpJFIl4Fv0ciUdCpj/iIdbcRFXWL1kwRTTNpy2iMqQnSntFoz+pMqPeyvTOHblkMqQggd5M6vyIR8spYtkM/mqj/38OyHQzLRhYFHMehJV0kXzLZlSwyssqP5Qh05XS8isDOnE4irzGiKojgONhAFIWibqLbJj5FcMXvvCp+dTcdqQh4OWR4FM2MsHaXW2HlUSTyJZ2d8QKaYRP0KFi2zdBKP53ZklvqjUyPctG2ziyG5VAZ9NCcyJAtlHrb1lRRoGCYNMcLFE2bre0l1rx4L1tf/CdKqJJh83/L0ImTqYv4GVrhd4/bq7ChOc2O95ay5dHf4Tgw/MKfEhg1g+23f4Xs+oXUHHwM4qjpFP5FcwKA9NIHEVQv4Vln9y5zHIfka39H9EeQZlxARtMJ+aIULJEdq5bw7usLuOo7P+bIg8aTN0wGRT1Uh3z4VQnDdtgRz7MrWaQ25MF0IFM03fv7I+jZDGA3Bgj2vwF7Zvt6/IJ7Mq2maSOIAjgOAjChLoQqi6xpTn8o0i16/MSOuoL48zeSX7uQ4NTd/tfhmWeRW/UiyVf/jvczf+rNKIdnnU12xTOk37yf6nOu2e/4hc1vufYCmU7kSA3eEQcTnHws6qDxfTLQgqxSderX8QwaT2LBLbTe+XWqz/5hb0ZNVH1Unfk9svVjSS66E/3ub1F9zjWoVcP2Gkeh4vir8Q4/iPjzN9F659eIHXc1wWkn9Nnn61u6GFrhL7PuShbc/qgeAbSewEbJKNeX7uFrrqXT7kirACiS0N3vI/QS9P8m7HnvbmrL8j9PrmHPggt9PyrhokgfmywR+MVZU8rIXMinlAWZKvwKWc3sJdI5zSKnlSuG9tiuiaLA/csaaEmVDkhQLb/xTRIv/hnHsak87ZtlPvF7wki10fXkr9HbtxGeeyHRPSoxIoddRHHrMjz141AqBuOpH4/esgkrl8DMdJBb/QrBg0/pLenWO7aTW/Mqxe0rMFNtSIEY0SMv3ee+e5B+6yFw3ElvDxzHIfHq3xF9oW6V0t0obnuX4rZ3iR79WaRgrPc6zRoR4/unTGRTW5YfPbGmz3USBP4r7+0BDOA/BdOyyWkmkigQ9MjkNJe4AaQLBj5VZHhVAIQgUb9KxKfQlCiwK1lAkQTG1gTJlEx2xQv4PAKHjojR2FWgaIMqisiKQm1QgRMuI756Ie2v3ErkM9cjiDKK4GAEK4jMOZ/UG/dQbFiNb/hBAMiRWoJTTyC7+iXX3i9cs89zsApp4s/fSHHnKkTFgzpoPIHxR+CfMA9R9fbZ3j96Fsqnb6TzyV/R8ejPiRx2EZHDL0IQJfKANH4edRXD6HziOtofuIbYMZ8lNPMsBEFAws3SGYAy7GDGfv7P7Hr6D6x68AayW1dwylU/RpMDWJZNS6bIINEV1aoMqER9CjYOhu3g28uXJacZKKKAN6AiCbi+ybpJbcjjvnd0E68ksq0jy+Con8qASkF3s31hr8LQmL/Pef7/DMt2aEkVMS2bomGB47CxNU26ZJDMGbSn81SH/eQ1k4qAB1USEEUB3bJwEAh5JNa3ZmlJFKiv8CGLMpMGh6kOetBNG48sIgjuvCJdsjAtm5qwj+qwh+EVAd5rTFDQTWzHoTGpUx1QKWomIY9CXrMR0VAlka0dOVpSJcJeiU2tWZZuaado2GimjU8WyeQM8hYYpoVUTPPefb+hY/NKaqbNY8ipX0b2h5lQH2ZMTYjasA/HsljelOD9Z//BhhfvRqkZRc0510C0Ds2CuiMvIr15GeLocRQ9HzwnsLU8+fWLKWxagta6BWwL35hDqTz5q32STXtC79hBYeMbhOde2Ju0AihsfBNt13oqTvoKksePpplkSzohVeLOG37BkBGjOOeSKxlXH8RxoCbswae4YnFj9SArdySoCakIQFe2RMQnD5DrTwADBPtfjD2z1T3lrD1lyyXDdks/9mDShwyPce4hQ7jk9rfRjP17Ye+NwJRjyb73AsnF/8Q/bg6ix/WFFmSV6DGfpevJX5Nb9SKhQ04D3Cx2eObZpJc+gNa2FU/dmH7HtfIpup65HrliCOFZ52AkdpFf+yq5956ndv5v+lVbBldMTa0dTcdjv6Dt3u9QcdzVBA8+FUEQ3L7sQ89FrRtD51O/o+3ub1Jx0lcITj6mzzj+sXNQPzuG+HN/IPHinyhuf5fKk76C5C/vkX51QzuiKGDbDoLo2nTsiZ7Axt6Qxd1K5XuPp/ewyf8Pa8H2DvzsXfa9d6baNO0Dtr4ScIVN9sbFs4cxvi7EzQu3EvOrrGtJ05HV3MvbvX1Ot/j5mVP43UsbSRX6Vwp1uv8xLIfmVOkDj8fWCiQW3Ep+7auo9WOpOuO7fbK/PShsXUb82T/gANXn/g/+sbPL1nuHT8OIN5Fd+RzReZcih6uQwtVozRsITDiC4PSTkYMVlBrfJ7XkQbTG90FS8A6fhn/sbLRd64k/fxOewRP3eQxGqo3c+y8TPOgk5MhuVfvCpiVoTWupOPFLveVhAI5pkHj1Nvc3OvOMsrGOGl/DjOExNrVlEYS+38tZBw0ayF4PYAD/Jhimxc54HhsHRZQwLYeQV6YjU2JnV56spqNKIlMGV6CbFpmiTkVApTrkYWdXjrZMicFRL0eOqWRr1IsiirQk8hQMm0q/SrpgMKrKz7AKP5oDB5/7Bd65+9fkNiwiOv1EogHwGTbSrLPJrn7JDbxfcdMeAcQLyK1dQHrJg1Se8rV9nkdy4T8o7lxFaPopOKZOqWE18Rduorhj5T7bzpRYPXWXXk/X07935x0tG6k64zu973K1ejj1n/4jXc/9keRrt6PtWk/lqV/H7J7LKIAiQE1tLcO+9Ft2vfEYG579B3d9/2KOv/p/iI2ZhSyKzBpdhU8RWbEzSUE3sCyIBT1EfeVzAq8ikSm6/cARnwe/KjKuNoggCLSkS8R8KqZtk8hrtCTyvL6xjY68xmEjK/GHvVjObgE5oLd3Vfw/Sk5KhkVeM1FlEZ8i0ZIqYViW62Hd3ceeKRn4FYkdHVnCPoX1LRkyuo4sSFiO62ce8yu0Z4qEfAqW7aCbNnUhL4btEFBFogEPju0wujaMhEhDvIAsiRR6BcwEOjJFRFEgoMookoRXlfGrMh5ZQjNswEKzHDIlt1e+PuqjM6vRmdMoaSaJvEZrpsSW1jSpokmFXyWeK5EpWtg22DZkN75Jy4t/Actg1FlfZ9jcU7EdAY8Eed1ia2eBYtGgkEvx7A0/pGntO0SmHE/4xC8iKh4ALMA3eBrptia2vrH/OYGVT5F553Gyq17A0YsolcMIdM95c6teQA5VETv2c/v8flJv3IvgCRA+9Bw8gIbrJpJcdAdKzUiC007Ag1s1Es9qrFr8ME07t3HhD2/ipY2drO8sMntEBbVhL7plkyno7EqVqIt4MLFRRYlR1QF8ngFq+Elg4Cr+i/H4yl29RFnrVu697pypnDy5jidXtZRt6wAPvdvIuYcM4Yq5I8rKZodX+GlI7N93URBEKk74Am13f4vUm/dTcdzne9f5xx2GZ9g0Um/ci3/ikb39H+FDzya78llSr99N7QXX9juu1rIJx9SpPPGLeAZPBFzCUtj8Fp4hk/Z5PI5loLVsdPfl2CRe+RuFbcupPvO7vVE677Bpbr/1078j3l22Gzvuqt6HVw/kUBU1F/6SzDtPknrjblru+DKVJ3+tzDdwT3Ery3b42TPrGF8X6iUPPYEN3dhNFGVJ4Nozp7CuJc1Dy5t6++Ftp3y8/spo9yao/5ewd+DnJ6dP5mdPr8WwHCRJQMQlrz1cbH+Z6v4gCrC3tIAoQEdW6+2l3lfwSDdtbl64hWmDI7y+pWuf+zjQ4FOp8X26nrsRK9vl6gIcfjGC1PfR51gmqdfvJvPO46i1o6k66we9/Yl7QvKFCUw8kux7z9P1zPVUnPAFitvexTtsKgB2MUP78zdS2vkeUrCC6NGfJTjthN7fnJFsoeW2qyg1rt0nwU6/eR+CKBGZe2HvMtsokVz4D5TqEQQPOqls+8y7T2AmW6m54NpyP2xgS3uWa55Yw0PvNvVbFTO2dqDXagAD+HfAst2AYEuqhN1dgpkrGUweHCVXMtgRz6GIAs0pja68TtirUB30UB/xkSpoOI7D8MoAzak8E2sjzBhWwdb2NPGCQapkYloGNjZBVaQu6mOCYTPutE/RsPRZdr1yJ9MPO4EdeRBEiIY8lI75LC1P/aYs8C6HawhNP4XsyucIH3ouSuWQfs+l1PA+gfGHU3H8VYBbXaPtWoeo7jsD51gGufdfwcx0oNSMorTzPVru+ArVZ1+Dd4g7txA9AarP+RGZd54gtfhO9Dt3UH32D1BrRyMAhgO6pYMtM+jI8xg06VBW3fsrnvjN1zj6rPl8/Yc/Z2RVgDVNaQQRtrfliAVVvB6JXMmgJrw7ux7yKkiiwM6uPMMqfSiSiOXAkJifyqDK1vYsa1oyNMfzbGjJ0J7ViHhlmuJ5Lp87mqEVgd3Xw7BoS5dwHIfqkIeg91/bUmbZDo7T3ab1CcCwbFrTRSRBIFWwKZkmyZyBLIrEczqDY17yms2Glgw7ujK0JApuJtt0qPDLlByLTLFEMq+jRX2MqAoS83uYWBehYJhUB720pIsIuOXZqZKFIkuEfQq2DUGvRFdewytJWJKrmyPJAuuaUxR0k8ExP0NifnyqRDJbQhEEqkMyIa/M+l0ZBMehIVEg6lNoyxR5ryFJuqihGxbpksmuRIlidyJFKuVoe+UW8usX4a0fy/Czv0Ns0BBKmkHAJ+MIIrpmElIl3np7OUv+8XO0fJpRZ30Nc3x5BaUDFLxhfPuZE4ieAKklD5BZ9hiOqeOfOI/wzLPx1I/tHcdMtlBqfH+f34/evIHi1mVUHnkZ0WCQoArpErQsexwr00nVad9CFCUEwLJNColO3nz0doZNn0dTcBx6Z458ySYgi9SEfBi2TXu6iCgIhD0KQ6IiVSEPIyv7tssN4KNhwKbrX4gVDUkeWd7USwQc4JHlTb3ErD+YNjy2chfrWjNlyxuThQP6sjz1YwkedBLZFc+UKYQLgkDF8Vdha3lSb9zbu1z0BAjPOZ/SjpWUGvr/cTuWm0UU9iC9osdPcOpx+7QYcBybwpZlaI1rqDzlawz5yt14Bk+ktH05rXd/E71jR++2cqiS2ot/RXjOeeRWv+RafXU19RlTEEQis8+l/vI/IvmjdD52LfHnb0Q28v0eg26WWxH19GMfPraqt1DMth2SBZ3rzpnK2Or920ntWUbbn2L5vxL9qWZ/HOytrv7Qu43o3YTatJzevz8KBOC4ibWoezWo2Q68sr4dfT/kugfNqdJ+yfWBwNZLJBbcSvsD1yBIEnWX/JbokZf1S67NdAft9/+AzDuPE5x+CnWXXt8vue6BUjGY2DGfRY7Ukl76EErFYILTTyH+8t9ovfPr6G1biB3zWQZd9Xcis88tEzQR5O7fkW32O7besYP8ukWEZpyBvIfYT2bZY1iZTiqOv7o32wRgZjpIv/UQ/nGH4Rt5SJ/xnlzVwgPLGrH6Ydfqf2nrwwAG8J+AYbnPvmhAoSurYdo2kiT0+ikbFkT8HlccyjAZWRVEFEW6chqKJJHXLToyJWxbQDMtDNthe6JEbdiLTxaQRZlpgyNMHhwBR+CQEVFqQn7O/OKPyGVSNL52N6OrQlT4VYZW+pl42NEEhruBd6u4e84RmXshguIhtZczyZ5wLLNsTiAIAt6hU1BrR/W//Z5zgpO+zKDP/InApKPAtmi///uklz2G0215KAgCkdnnUjv/NzimTus93ya/4hkEx0ERQBZEkAQGRf2MGT+Zr/3pQU668DMsfvoBvnz+CTzw5Ets78qSzJbIlQzaM0VW7kzQlCiQ28u72q+6itSaadOULNCWLnaXK7tK4ltaUizZ2sXaliRtyQKdeZ3mVJF0USNXMihoJo7jEM9pqJKAT5XoyullitqfFEzLpqSbJHIlGuI5GpMF0oVPxirMfT8IeBUJWRRJ501kSSDokykaJrmSiUcRSRd04lmNRMGkK68Tz5bIlSwUUaCkuWKyRdPEsi3CXgWPIuJXZQTHIex1gxyqIjI45se2bXyqRE432dGZxyNJyLJA0bTxqjI+SSSe1ygZDg1deRq78mA72IJIxC+jGbC+JUO8oLO1M097tsi6thSLN3awpT1DZ1bDcMC0IG+7YmGFbe+y8x9fIr/xDaoOn8/wS69HiQ1GFATyBnRlTJoTOju7srz64G288oev4UgqEz7/e6oOOZGg3H91Qn9zgvDs88hvfJOW279A+s378I08hEGfu5nqM75bRq7BnV87Vv9zAsdxSCy+CykQZdSRZxL1gSgLkO1w7UUnzCMwbCoeIKBCNOjj7Yf/gmNZHHz+V1zLO5+K3yOSLppumb3t4CDg90h4VIm6sJfqkBe/R/3QVqcD6B8DGex/Id7eHu8Vf+qBZTu8vT3O0Ar/Pu2EBOCUKfW8sSfBcGBUTZCtHbl+P7MnokddTmHTEhKv3ELtxb/ujbap1SMIHXwq2feeJ3TQSb0vwvCM08mueIbk4n9Sd9kNfUizHK4GwEy1o9b0//LcG2aiBa1pDb4xh+IZNB5bLyGFa4iMmkHuvedpvftbVBx/NcGDTnJLxkWJ2FFX4B06la5nb6D17m/ss99arRlJ/eV/JLX0ATJvP0pxx3tUnPwV/KNn9TmOR5Y39Xpig0uy97y2tgOrm1J848H32NCW3e85ffaudzh0RCVfOGp0v/Zfn0QW+/5ljbywtpXJ9WEymsl7DUnasyXSBQMHylSzPw72Vlf3yOXfudRPBvqDoEquargoCPhViYn1YVpSRTpz/36v0FLjGuIv/Akz1UrokNOJHnVFv32BsGdftkPVmd8jMPHI8vUbXsexrT7tC4IoUXHCF7Atk8L6xbTcdhV2KUfo4FOJHHHJPlVCzbQr0CeF+lfpTS66E9HrBr56YKTayCx7DP+Eeb1R8d7tX70dHNeub1/o+Sp7FF5nDIsypjbEp/b4bQxgAAP410KRRGRRwCOJVAQ9DKvwIXZnIKuDXgZFPOSKOnURL2GvgmnZqJKAJApUBVWGxPzs7MozuiZI2KdQFfQwoS7ExrYMlQEPNREPIOJVVZqSBXIlneZ0keCgMUw59lyWPvsQlYecSLhqJEXToTYSYOqnvszbf/wi6dfvoeKkLxMSoRiIEp51Dukl96M1b8QzeEKfc5HDVZip9j7L94X+5gSOA7HjrqK4eSmpRf9Ea1pL5anf6C0Z9w6ZRP1n/kTi+RvpWnArpYbVjDrna4hSJZphsq0theWIJAo+pp7zZQ45/Dhuue77XPfli5l16kUcdPbnsUQfAY9C2OdgOxbtGY2ARy6bV0R9Ck3xPKosIouimzXVLJY3xWlKFEgUSugGFEyQCyUMU+bpVU2815Ri5shKpg+NYTuu77JpOji2O7f4JETQGuJ5EjkNsMmUbLa0pshoFiOqAxw6ooJkwSDsU/ZryXgg8MgiPkUk1x0w8MgiO+J5vIrO2JogYb9KW6pAZ7ZEZ04jmdfwqCKqIlKyHbymhSAKlHQTRXKrAIZV+IjndWzLYnOygC2Abtl4FZGYV6UlXaQzqyEBUb9CLCDTEC+QK5rUR/2AhWWDV3FLxqtDKqYDBd1CEaEjbRLwi9SHfRQME78ssbOrQKbkCpoVS2DbrrK5XMrR8ertZNcuQKkaxugLf0Jo6BjCqohuC4gSqDKUdNDScXY8ewOlxveJTTmKIad/GY/qJ1+EvOMKkGbeeYzYMZ8rC9j3zAkc28JMtdHx8E8o7XwPpWYUVWd8Z5/tlODOC/Y1JyhuX+62hp3wRXTbRyEPiuzQ/MrtgEDtMZ9FlcAjgcejkNm6muaVrzHu5E+TVasYpAqoiogqCYyqCZAs6uA47u/P56E27GNQxIvU/Sz6uPfSAFwMEOx/IXpLkruJ2J6evnNGVXLBLUuxupdLkmsVpXSrV88YHqMxnue2N7bjOOBRRD57+Eh+2l3Guz9IvjDRoz5N4qW/kF+/qIwYROZdSn7D6yQW3ELt/N+65FZWic67lPjzf6Sw4Q03qrwHlOoRIMpozRvwj5t7QOdealqLY5m9/SVGVwNysALPoAmEDjqZrmdvIPHSXyg1rKbi+KuQAu4k3zdqBvWf+TPx524g8eKfKO18j4qTvozkLc8uC7JC7MjL8Y+dS/z5G+l89Of4Jx1FxXFXlfVmm5bTWy3QYyW1t5/1nork+0O6YPLK+nZe29DOcRNrkUWXUApC337vD4sVDUluXbyt91je2Ef2Vjc+GTK/p7p6zK/ys6fX9q5TJIGfnzmFe9/eyfrW/QcdejB9SIQLZg3jJ0+txbSdPu0P/WFiXagsqBH1yaSK/UdwDxS2lie56J/kVr2IHK2j9uJf4R02rf9t9RLJ1/5ObvVL3X3Z3yvLWttGieSrt5Nb/SLe4QcRmHR0nxePkWgm/tJf0BrX4Bk0gYqTvvSBQSiteQMAnrqxfdYVd66itGOFK1S2xz2ffO12EERix5T3ZxW3r6CweSnRIy8vUxrvD5IAFx06rCzgNIABDODfB0kUqI940Uw3S1Q0bLyK1FuqfNLkejqyLgGURejM6QQUicqgB1EUmTI4QlXIg27Y+L0SYa/MhLoIhmkT9qoIjsOu7qxmPF/CK7skXRRFpp7xOTa/vYAld/+ew7/6RyoCPsJ+lUmTptB15Nlsff0JBs06ifpR44jnbcxDzyG76nmSi+7onSvsCbVuDPn1bvBxz4qafWFfcwIpEKXq7B+SXfksyYX/oPWfX6Py1K/3VuME/RHGfO5nbFv0OLsW3MWGv32VYWd/m/DIaQgqKKKI5YjkSjqDhk/h/F/ey6L7/8I7zz3AurcXcfxnf8CkY09AkgR0w8GybUqGTUF3+32LutkdTHbFuNa3pmlPlwh6ZDySQMmwcGxQJVeKpTqgIskKnTkTKBD1qUiCwNAKP+migWnbhL0KLakCg6P+j9SPrZkWXVmNZF6nJVUkGlBYscOtXmtMFhEE6MyUMEybCfURHMf/sWViBEGgNuzFsBwyRbc0vDrsJVXQGFoRwKe6Gf2gV6Yu7HMFwgSRWMhDlV9hU2sGVZWpCXuJ+r2MqQ25mjhAVrPpyGqMqwtR0EwCikRNxItXlejKaaiKiFeWkEWJgKowLBYgUdAZWRVmbEpj5c44kiAgCCJFzcBybLJFd77skwS68jq5kkFzIs/65hSJwm5pI8FyyGxZys7nbsEqpAnPvYBBh11MIKTglQUU2S2rxnaFhuMbl9Hxwk04pkblKV8jNPUEMo6AV4MSUNj6DvHnb8QxdQKTj+2TiXYsg/Syx0gvfQhBUogdfzWhg0/d72/E1goYnQ34Rx9attwHYFu0LPwncmwQkYNOQndAs6Gw/T0ym5ZSOe8yQpXVCAKEfRJ+Bd585EZC1YM4Zf7nCfr9BDwKU4ZEmDY0hm7YiCLotoNXEhlaEUAShQFS/S/AAMH+F2JvEpMs6GX9ug9/4bA+XsJ7rv/BqRM5YXJd2fLxdSFuXbyNBRva96syHjzoRHLvv0xy4T/wj57VK4gkeYNEj7qCxIt/Ir/utV4l48Dko8ksf5Lk63fjHzcXQd5NGEXFg2fIRIrblxM75rMHdO7Z954jesQlCKKEmelCa96IY5t46schevzUXPBzkq+5FlyFrcsITj2eyhO/BLgl4zUX/ILMssdIvXkfWvMGqk77Ft7hfYmSp34s9VfcSPqth0m/9Qil7SuJHfs5AlOOQxBcBctl2+P8/qVNH7nkeW9Y3aXOQrfyte04/Ozptb393gfSm72iIcljK3chAJMHRbj22XV91M37gw0fm8z3oEdd/eaFW3srLQTg/JlDmT972IfylQ75FJIFvd8yZGGP/9/zDI29JMYDno9OsB3HobBpCclXb8PKpwjNOpvovEsRlf6z1lrrZrqe+T1mspXw7POIzrukrHdZ79xJ11O/w4g3dq+/tLzvyjLJvPM4qSUPIMgqFSd9ubsa44MbOYrbl6NUDe9V+u4d07ZILvwHUriG8IzTd2+/bTnFLW8TPerTZd60jqmTWHALcsVgwrPO6bOfQ0fEWLUrjWnaiKLAtWdNYf7sYX22G8AABvDvgyyJyJJIoB8hoYBXYeQevbsVwfLnlyyJDI76ei10BEGgPuoj5ldIFXQ6shpt2SI7ugqYtoMp2GimA4KF4Akx6ayree++37LhzReYeeK5jIoFiBdKHHnBF9m1ciEdr9zG0C/eQNCvYAHZIy4h8dLNFDe/hX/8YWXH4h0+ndyqF9F2re9TVdMf9jcnEASB8IwzkENVxF++mY6Hf4J32EHUXHgtRVEiXnSonXMOcv1Ump75Pdvu+RHhQ8+hct5lhAPgU0T3elCirWAz7JQvEJ00j/WP3MBT13+ddW+cxAVf+xFjaoJIgsP7u5LsjOep9KukCgbDqnxols7ijV1YAkQ9CrppocoiHllGlSUCioQFRAMeDFvAsqEjp7GtI8PkIVGCHoWNrRk6cxohj0s0Iz4VnyLRkS2R0wyCqkxFwINXlbBsh1zJxKuIeFWZbMmgPVNCltyJhSxJOA4kizo1YS+aZeNXBPIlw1Xnltw2OJ8qUTSsfu+nDwtBEFBlAUmAomniV2RCPhVVFt11kkiqqCNKIhVBLwcNjoLgsGBjO+miiVd3GFMbYNKgMKIo0pEtEc+5lQolw6IzXcIjS1QEvRiWTSQgo5kWsiiiWXb3fAqKpoVpOfhUCZ8iM64+goMrPKZIApJtkbOhWvXQldcp6habm5M0JAsUDbdPXwacXCfbF9xCatMy/PWjiZ7/M5Ta0SiyK2wriCIZzcCxAaNEw4u307n8RVeD5YzvoFQO7Z03Fk3DtY1d8QxKzUiqz/x+H30CrWUT8Rf+hNHVgH/8EcSOv6pXOXx/KO5YCY6Nt9tGtwc2kH7/ZYx4I/VnX4MkyW5GXzDY9uKteCrqGX3sedTFfGiGRd5w2PLaw2TbGzn86t+Q1CEcEMlpBtUhD5VBDyXdIl008UhQFfJ8Yj38A+iLAYL9L0YPiTmQdf1t1982t10+k/uXNXLNE2v2uV9BEKk48Uuu4Nkb95R5WganHU/u/ZdILvwn/jGzEb1Bt0T7mM/R8dCPySx/msic88rG8487jOSCW9G7GvvYae0Nx3Jfmj3lLvn1C7FyCfzj5vaKmzlaATPdRmDSkRR3rib33vOYiRZqLvg5gih1CzxdgHfEwXQ9cz3tD/6I0KyziB15eRn5BxAkhegRl+CfMI/Ei38h/vyN5NYsoOKELzL94Gkfu5e333OkXI1Ztxx+88IGWlNFdu2hal0dVDl4WAy/KvH29jjDKvycffAQfvLUGnp0w6Ru5fMDgcBu+7GPi55AQLZouJUMjuO+3LvX7S1sNrzCT0uqQH9xgMn1YVY3pfoNYpw1fRDxvI5PkXoDQwL0UWb3qR/tcWSk2ki+cotLXGtGUX3u//SJKvfAsUw3GLP0QaRgBbUXX1eW4XYcpzuTcgeiN0DNBb/AN/LgsjG0tq3EX7gJo2MH/nGHETvhCwf0EgUwcwm0pnVl1ls9yK99FaNjB1VnfLf3HndJ9K3dCv5nl22ffvtRV9jswl8iyH0Fdd5rSvHzM6f0CewNYAAD+L8LQRB6n9M98KoytYpEMqdTLJkIokNYVRAEh+GVQTyy684gHX0GO5Y+R8NLdzDkoCNYbtlEfAqtJZGRJ1/JhkeuZ/uSF4gcfCKaBcFpJ5Jd/gzJRf/EN3pW2XPGN2oGgqyS3/jmBxLsA5kT2KUcubWvEphwJHrXTkoNq2m+9Urq5v+GTKQWL2DXjaH+0zeSXHg7mXcep7hjJXWnf5uKCeOIeEUShQLpoolPEQmPmMycb91G+xsPse7Fe7np6rcxv/djwuddSrJoksxr7OzKoRsOqxrjbOvMktcsaiJ+OkSNqfURLNu1NEqXdFRZYlDUx5haPxtb8iQKOpVBlTE1YbyiQFdWI5U3sG2HXNFAEgW2tmdpSuRY35rFtG0KmsmQqJ9J9WGSRZ3WjEbQI3Hs+FpXqK6gd88tHMbUhIj4ZWI+lXRJZ0y1GxzoymoUdJuwT6Im7HHnDx/Q7+1WUbpl36btYDsOquSSZsdx0C0b24aCbqKZFvFsieZkCQSYOjRKTjPd3mxJIOZT8coSRtjLjOExdNti4eYuJtX7aUgU6UxrDI75SBUMmuJF1remifgUfIrE6NogYZ+rdJ/VTPSi7Xpge0SKho1HEtAMi0SuRGXQi0dye61VxcvapoybeRUEvKqMaTgUDItsySKvaWxLFIhnbbTu+y29/GmSS+4HHEad/DkGHX4OpgUeVSBVcnv5LdNGsyHXuJGdT9yAkWwjcui5RI68rDzg3tVI1zPXY3TsIDTjDGJHf7bst2DrJVJv3EN2+dNIoUqqP/WTMgHeD0Jh45uI/kiZaLAACFqB+Jv34R86iVEz51Iw3Eq0ttefRIvvYsaV13HomFoQBDJ5g1xXC28tuJf6aUcwadY8MrpNPK8R88lsbstiOXDI0BhBr4IoCKjyALn+V2KAYP8fxbqW9Adu46kbQ+iQ08iueJbAlON6fahdtfEv9iHfvhHT8Y2eRfqthwhOPa63bBsgMOEIkq/+nfyaBagfkMUWJBnPkMl0PPIz1NrRiB4/4UPPLSM9iVdvRw5XE513GeDQ8fh1lBpW0f7AD6k48Uuo1SPcc6gfS/0Vf3Ijh+8+SWn7SipP+2a/BEqtGkbtJb8h9/4rpBbdSeudX+ON9WcRmHPRfr0FPyz29GzeE+/u7CtA1pnTy0rQ2zIa7+5MlhFRy3aQRfdFJ4humVZzstjvviWxf8/ivS22+vt7U1uWh95t7O23Xr4z2cd2y7Ac7l/WyMPvNhHzl5M2RRZRZQlDL/eklgS4Y+nOPoS8Luzh7OmDe9eJ3b7VTreA2p56AgKgfMhSOtvQyCx7lMyyx0CUiB17JaEZZ+yzFMvoaqLruT+gt21x2wlO+GJZGbaVS9L1wo2Utq/AN2qm2wsYiPaud0yd1Jv3k3nncaRAlOpzrsE/7rB+9rRv5NctBMfu04ZhawVSr9+DOmg8/j16wNNvP4qZ6ibRe7zwjUQz6bcfwT/xSHx7Rb17t7Ec1rWkue6cD84uDWAAA/i/C920ac8Ueb85iSorRH0iTckcoysDzBwVo2Da6CUT07AYf+7XeeePX2DNk7egf+obbGh10A1Qxh1JaPjzNC64kzGj5mB4wm7g/djP0fHIT8mseIbI7HN79ymqPnxj51DYsJiKYz/XJ/C9Jz7cnOBSRI+f5OK7yLzzBC13fIXYsVfiTDvRrUpTvVSe9BV8o2cRf+HPNN31TTwnXY56zPmUbBFFkchrFjGfjCcQYOi5V1Iz/Si2PfkXfvc/3+XxB+7lsm//jGHjp9GVMbrVv/PkDAvDcsiXdEIeGVEAzbK6y6RNdMtElUQUSebI8dW0ZzS8ssSskZUEfTKW5RBQJTJZnZLhgGDzXHOGzmzJrSYw3VJ03bRpSueJZ3Um1EXQdIt3GuJISEii+46M+hUyJYOgR+HI8TXEc0WeWtnE5rYsEb/KvLE1iN1BloBHJrBXcNqyHRJ5DdN28EhCb2WYLApYjgMO+D0ypmmyvatAqqiRLbokOuCRaU2VmDwoTNG02NSaYUjMT0G38Msume/MlQh6FcIBlWTeZESlj0TeRFVE5o6qAAQKuknRcDPOMb9CyKsQ9ql4ZIlMKU9HVsM0baqCKl5ZpT1dRBZAs2xUWSJb1NnYnmFQxEtTsoRPEaiL+FneEKc9W0Jw3Bh9yTAREEh1k+tSw/skFtyC0dVIbPwsBp/yBQIV9diigGlaRBSZChwGVfjY1pag+eX76XjrMaRQZb8B99x7z5FceAeC6qP6vJ/20fop7lxF4sU/Y6bbCR58KrGjrvhQ802rmKWwdRmh6SeXzV0EoPWthzDzKWo/9VMKukDML6Il47S88SDDDzmSs047ncqQSmOiQLKgs+yhm0AQmH3h1wmHPNSrChUBBUmQCPtUDBN0CyKeD27pGMDHxwDB/j+IHnXyA0F03qWu4NlLN1N3+R96f8CeujGu4NnK5whMOb73RRc79kpa/vFlUq/fU+aBKQVi+MYcSm7NAqLzLus3Y7YnglOOxTfmUMz4LtT6sVi5RO+64s5VGPFGaj71P70PIjlSC8MPQm/dTOvd30KJ1hE96jP4x8xyX6gnfgn/mNnEX/gTbfd8m/Cc84kedlGf4xAEkdBBJ+EfO4fU4rtIvP0E6TULiR71aQJTjj2gEt79QRIFfnHWFJ58bxfv9EOoDwR7c3NRgGMn1FAV8hD2yGX2bHvjwlnDejORPaQ65le59tl16Kbtlvs4Tq93dc/fguCWRB0oTNvpI062oyvHuJpQHzE423EjwXujNuxlZePuTLjtuP/0F2t3oHdcnyxS3I8tWG85+MI7sDId+CfMI3bs55D3IRDi2BaZd58g9cZ9iKqPqrN+QGDCEWXbFDa/RfzFP+MYJbdn6pDTy0rCS7vWE3/hJsxEM4GpJ1Bx7OfKfKgPBI5jk1v9Ip7Bk1Aqh5atS7/1MFY+SfW5P+7d775ItOM4JF7+m9vftR9hMzhwK7MBDGAA/3fRldMomRam7WBg4eAwuibIoLCPDa1ZNNOiPasRVGTqh42mZu7ZtC95DO/E45C6hZdkUyBy3BfJ/vNrtC28k9jJ7vvfN2oGvlEzSS99kOCUY8oC78FpJ1LY8Dr5TUv6iEDujQ87J7DyKQITj8RMtZF48c+kXr+H2PFXEewOQPrHzMbzuQkkXrqZrS/cQcv7S6g7+RsMHTsCnyJTFfTg9ch0ZQqosWEc/uUbmbR+EUsf+BM//dw5zDrxbKadcRWRiirSJQvTsDEdgaJlERVkGlIF2lNFZFlCFBy8XoVRNQH8iowoikwaFMK0BMJ+hfqoD92wWNFkky7qJPM6ohMk6JUpGQrxXIn2TBFBFMlpOrohoFkWPkVCEASqIx4iXhnLsclrBl5FolaRUGVXNOyJFbt4cV0bmm6imQ6aaXP6QUOoCXupDXkRBAHLdkgXDWzbxrBsNNNGFkW2JQvUh72oskRDIk9VUMWvymxsSdGULJAuukGG2rCX7Z15hlf6yOkm8bybtQd33ExJZ02yRFVIRRIFhlYEu6vwbCr9HroyOkMiHkbWhHFsh4Jhohkmec2iOVmiPurQ0JXDp0os3daFZtrEfAol06Q5WSCjW8hAPKdjOW65eLKoE/Z4GFHtpTrkY9mOLrZ1ZMgWLWIBFce2sBwBW7DJpNtJLvwnhU1vIkVqGXrujxl6yBwMCyqDMl0ZHcuBdMEg5BFINW3h/dt/QaG9gdDUE4gedyWiZ7ftmplLEH/+Jko7VuAdOYOqU79R1tZla3mSr/2D3PsvI8cGUTv/N/sVMdsX8mtfBcsgOO2EsuVWsoXM8qcITTmOUP1YfDJkNJvNT/wFgGEnX01TskDIr1AbCdC1bgnN7y/hsIu+yuChQwjICuNqA3hkhYxmdPdoy/jVgaz1vwsDBPv/IPpTJ98XRE+A2LGfp+vp35Jd8SzhWWf1roseeZlLvl++2VUPFyXXWmDGGWTefZLg9FPKIsyhQ06nuOVt8hsWE5x6/AfuW/IGkQZPwCqkya1+Gf/EeahVwyg1rCI45ThEfxRw+1a0pjXUXXYDjlGi86nfobduIv7sDZSmHtfr5+0bNYNBn7uZxKt/J/PWQxQ3v0XlqV/HM2h83337I1Se8jWC008mseBW4s/fSHbls8SOvfIjPQR7MK7GJVXah/SF3hd6sroLNrQjicIHCthNHuQKuO3pYy12v1wdXEVzoM/fnwTT2hdB39fQq3f1rbI4kNt2f+Raa91C8rXb0XatQ6keQdV+RMzA7aWOv3ATeusWfOPmUnnil8omiLaWJ7Hg7+TXLkCtHU3l6d8ua4Gw9SKp1+8mu+JZpEhNvyXjB4ri9hWYyVaiR1xattxItpBZ/iSBKcf23suO45B45ZZ+SXR+/SJKDauo2Ks0XZUERlcH2dSWxcEVq/vUIf172A5gAAP4/weO41DSbYZXhfDIEs2pEqNr/KiCyNqWDiYPjtCZ0bEFB912qD/qYhJr36D9pb8y7DM3YUoKJq6gaXjW2aTeeRzPlOPxdpesxo69kpY7vkJy8d1Unfr13v16h09DrhhCdsXT/YpA7o2PMicQvUFSb9xHZtmjxJ+9geLmt6g+6/vueP4IVWf/EG3D63QtuIUdd36VwpHzmXTiRYiiD78iEfJ4sdAI+GRic07m9NNPY8EDt/Lkff9g9esvM/OsTzNs3jlE/F5yJQO/IKDbNs3JIgXNAMFibE0AvyqxpaPIrOFeEnkNjyShmSZtmRKabqFbDpm8SWXAS0iVSec16mI+CroEqK4/tGmRyBlUBjwMivkxDIvxgyIcOrKS9rRGxC9TMmwqAyphr0I8p7E9l6U9XaI9VUCUJBzbpjFR6LYBc9vLRBHieY28ZiIJAp05jdqQF0US8MgSBcPGtB0Cqoxtu5V0WztyFE2LomaTLZrYlCjpJumCSnVQpaCb1Ee8gECyoKNKEh5FxKtI+GUFy7ZpSRVJF3RGVAepCPpIFzXe35WmZJqYlkWqYFMb8pApaZC2yWsWybyBLAvotluJaZg2BcsiIIpotkNHpoRmmcQ8CoMqgkwbIrGuOUNRt9w+Y92mI6cRL2g4AgTsEjtfe4iWt550rVyPuMT1b1c8aJrbj93cpZO3wAIsQ3MdaJY9jhqMMfi8nyLvlZXOb3iDxMt/xTF1Kk74AsGDTyu7twtbl5F46WasfIrwoecSOeISxD0s6w74d2tbZFc+i2fwpD7iqJ2v3Y4gKYSP+jQCUDKha/3bZDYvY9hJn4VgJbtSRUTRQTB1Hv7LddQMH8eUEy8i6FXJ6gaqpHDcpFpAAMch4JVR5IHs9b8LAwT7/yDmjKpEFD6476YH/glH4F27gNSb9+Iff1iv7ZboCRA77vN0Pf07siufIzzzTAAih19Mbv1CEq/cQt1l1/dmfb3DD0KpGk7mnSd6RcQOBJI/QuSI+TiGa0smKF4EebetRGLBbYRmnNGt/h2h9pLfkl3+lEts3nseT/1YApOOdo/ZG6TqtG8SmHAE8Zdupu2e7xCacQbRIy9DVH199u2pH0fdpdeTX7+Y1OK7aL//B/jGzCZ21BUoVUP7bP9B2NCW3W/vO7ilPQfyzdSFPQyr8PeWjNsfQK5FwX0h3bxwK82pYq86PbiCN47juDYLjuOK4HT/rR+g35YowPjaEJs7ct0eiX2x8QOszP5VMFJtpN64h8L6xYj+iCsqNu3EfZaDO6ZOeulDpJc9iugJUHXm9/BPmFd2zxZ3riL+/E1YuTjhuRcSPfyisjLs4o6VxF/8C1amk9CM04keeXm/99iBIvPO40jBSvzjD999nI5D8tW/uzoCR13Ru7yw4XVKO98jdvzVZSTaKmZJvvYP1PpxBKefUja+bjls6czxy3OmDvRdD2AA/0WoCnnIFg0cx2HakBiDojrbujJYpk1d2NMtpCXhk0QUQWRQVYTC6V9k830/J/PO40TmXogEmLjv//zGN0i89Bfqr7gJQVJQKocQnnkmmXceJzr9ZOTuQKAgiIRnnkni5b+i7Vp3wMHrDzcncJMBgQmH0/n07yhsfIMu2UPF8Vchevyu//akoxgyfBrJBbfQtvBuUuvfoO2srxEbMhYTEUGAbEmjKhCgIxBh4hlfQBt1NO8/cTNLH/obK198lDEnfZqag49HFARkxe3v1SyLmF+iK69TI6h4ZZFdqQJeVaIjWyTgldnRlaVZElFFkbXN7rt8cMzPmPoQg2J+pg+TWbqlg86sgW5bJAoGNgKmDVMGR5gzupqunEay4PZjeyWJ1nSJfMmgOaVRNE1qIl5kSSSnWXhkAcEB3bQwbYvGRKE7g23jUyQkUSDgcbPhhgHDYj5kye29HhyViOd01uxKIYsiQUVAsA38qsqQWICoX6Ejq1EZ8jC8IkBntkTIoyAJ7phRn0q6qNOWLqKqEnURLz5FpqAbmLYrTNaUKBBURbZ1FikYJruSBSRsJg+pwLYdCrrJqEiAhmKOeK6IZUKiqCPZFomiRVfOQJJgmwnViRL1ERXbkWhO52noypHXTQQHCgWD1KqX6FryAHYhTWDS0d1CoO781gK6dAgAOt0e2I3vE3/xL5jJFgJTT6D62M/h7NkmVsyQePlvFDa+4TqLnPbtMiEzq5AmseA2ChsWo1QN36/ey4GgsPktzFRb2bsfuoVNt75D7OjP4A1WIEpQyBdpfflWPNXDCRx0pusc4JPwqApLH/4zuUQHF33/BmRZIlsyKegOO7qyrG/xMWtk5UC/9X8AAwT7/yBmDI9x7VlT+MlTa3szl/uDIAhUnPBFWv/xZRKv3FJWhuqfMA/vmgWk3rgH/7jDkMNViB4/saM+Q/z5P5Jfs4DgtBN7xwnPPpf4c3+kuO3dDyXiIAgCQo8PsWmQ37QEOTaIwqYlCLJCeKabWXdsC1GSicz+FI5RIv3WI3Q983sKW98ldtyVyD12XqNnMehzfyW5+C6yK56hsPktKk74Av6xs/vZt0hw8jH4x80lu/xp0m8/QssdXyYw5Viih8//QHujA4VHFon5FNqy/fub7422jLZPL/T+IAgCjyxv6i3/lkU3c63IIj85fXIvqYLdfdevrGvrt+RccAOars6Y25KF7bgBBFUSmDEi1qdXHP79JcdWLkn6rYfIrnoRQZQIz72AyOzz9tvjVGxYTeLlv7rl3JOPIXbslWXWbbZWILnoDtfKq2IIdZdeX1YF0UNi82sXIFcMofaS3/Zmcj4qtOaNaI1r+vhmFre9Q3Hbu0SP/mwvkbZKORKv/h21fiyhg08tGye16J/YxQyVF17bb3DBtBzWtqT51UDf9QAG8F8DjywxYVCEYbqJZlik8jpVQR8CUBPxEfJIDI76kSQHr0dgTVOa4tTDyEybR8fSB6mbdgSlwGBM3N7qihO+QOdjvyDzzhO9goyRwy4iv24hna/cQt1lv+99/gSmHEfqzftIv/XIh6oO+zBzAkGUUGtHE55xJsnFd5Ff+yqlxvepOPkr+LvtvAjEqD/rh/gmLqXrlVtY//dvE5txOqNOvBRN8INjE/GaaIZOqqBh+CqoOusHxGauZedL/2Ttw78nsPhRJp/+GbyHHEVQVfBKEoIAuu72Bds25DSTvGagGRZ+j0pIlYgXDUJeGUWSiHhlMgWDYshmV6KEg+1mX02LnGbgVyUk0SJTEmjPamzpyKLIAjGfh2RBx6/KDK3ws64lRSJn4FVETMutnmtJFvGorgr3mpYMFQGVmN+Dg0Ou6KCJNo4FdWEvEZ9C0bDAAa8i9dqFNSUKBL0KuZJJwbCYOjTGhPoQRcPBtkGWBOojfjpzGg3xPCMqQ6iygG5YCAIEPBIVAZW8btGUKhLzuKQ+kdNRRZuOTIkGw0DTbTweEct21e8b4zmaEq4f8/IdRbZ35GhM5vCoMoWSTrLg6rkULcAEjwiZosYT7+3CJ0vYtsOuhEbRsMhseJ3Em/djplrxDJ1C7Lyf9moM7Y08YBfSrn3nmgXI0TpqLvwlvhHTy+Yzhc1Lib/8V+xijsi8S4nMOb/3Hncch8KGxSQW3IatFYgcPp/I3PPLAvIfFo7jkFn2KHKsvsz+1jF1Eq/eilwxmOjMMwl63Wux9ZV7MTKdDL/kenx+mSExLxG/itW2jRUvPsRRZ13MYYfPZkNjhl3pInURL36PTLZokCsaVIQ+fIZ9AB8PAwT7/yjmzx7G+LoQ1z6zrk8pbtAjURX0EM9pZDVXkEqJ1hE5Yj6pRf+ksHkpge4smiAIVJz4JZd8L7iFmnN/DEBgyjHkVr9EctGd+MbORfKF3OUTjyL15v2klz7gqop+BO+86JGXkXrzPjLLHsM3ehbhOef3rrMLGYxUK8Xtyylufpua836K1rKR9NKHKDWsInr0ZxA9AQLd6qOVJ36RwKSjSbz0Fzof/wW+sXOoOO6qfkmzqHiJzL2A4EEnkX7rYbLvPU9+3SKCB51EZM75ZRZIHwWaaR8wuf4oiHhlUkUD23HLv0+YVEtVyNNrg7V4Uwf3vLWTs6cP5genTmRFQ5IX17WVjVEX9nDzJTOA3b7gP3tmXZlAmW45/ZLrfyesQprMssfIrnwOxzIITjuRyOEX7bPPGlwynlx0B/l1C92XaD/l3MXtK9ysdC5OeNY5ROZd2lva5TgOhY1vknj1Vuxi1u3zP/zi/Yr3HChSSx9A9IUJTj+5d5ltaCQW3IZSOay3egQgtfAOl0RfUE6iS41ryL3/MuFDz92v1/aAm+UABvDfCb8qo5s2OxN5HBzqIj5UWaE+7MVyHCzLZnR1GFUQiAZURn7+h9z93fPY8ezNVFxw3e7A+5jZ+McdRnrpg/gnHIESG+QG3o/9HF3P/J7c+y8T6q6gERWPW1a++C60lk39tmx9ED7UnODcH4OkEH/hJjof/gmBycfgGzmDislHYwDecYcxaPhBpBbfTXL5M6ze+CZ1J1yFZ9rhNCQ0urIGRcMgXXJQJQgMn8Lsb/6VXcsXsuPlO3nnHz+ladEkjr7oi3iGHoxHEhhZH0K3HOI518s55vfQlMzhkUCS3d5WV4fEQZJAcFy1b49ikMhqyKJDvmSSKer4PRIeScXvlSlqFmt3ZaiPufZJluNQGfBREfRgO1AX8dCcKrKpLY0oyMRCPrIlA68KkiBSMG1M23aFUmWRqoCHgm4Sz2ms3ZWkOVVCFGBMVZCx9WG8ikSyYBDwSIyuDVDUbeaMrkKVJUzLRrdsqkIqbemS64UtiTTEc3TkNWYNjyKLAnnNQZIkZNHGNGy2pUsE/QoBr0RrSidd0MgbFkXNpJCyqPB7mTEsgiOJDAp72dyWY/nOBCVTJ56xEQQdRdrdgtYjoVqwQS9BvGTgYCA6NvlNS0kuuR+jqxGlZiQ15/0U76iZ+5yHOo5Nfs0CkovuxNbyhGefR+Twi8rsO618isSCW92sde1oKi+4tuz9amY6SLz0V4rbl6PWj6fylK/2ivB+HBS3L0dv20rFyV8re8+n33ncdQe54BcIkkK+BPG2rSSWP0Nw+in4hkwkrMpEw35k2+LJm39OtKKKK79xDYIoM214lKE5H5YgUB32EvDIKMpA9vo/AcE5wDLjTxIzZ850li9f/m/f7/9vWNGQ5KLb3irr25UEkCTR7b8Vyq2kHNui9a5vYueTDLryb2UiTellj5JadCdVZ/+wl3zrHdtpvfMbBA86kcqTvtK7bXb1SyRe/POHtiLYG45pIMgKVj5Jccd7aLvWobdtRa4YgqduNN6Rh6BWj8BxbPSOnXQ++rNeYZTw3AuJHXnZ7rEsk8zyJ0kveQAcCM/9f+yddZgd9dn+PyPHfX03u8nG3T1IQgjuFC1uFQoU+lKj7Vt5a1SAUloo7u5OAgHi7rbxbNbtuI3+/pizJzkkgYQCld/e18V1kTlnzsycOTvfR+7nvs8nMOncT02OtFg70cXPkVg3BwTBSuImn/eFdbS/aPidMsmspUQKVqVZFIQD1LsBgi6ZRFZDNwq7zr89Z2SBF/LKPWF++OJatrcnv+SzPzxoiS7iy14hvuZtTE3FM2w6gaMuxhaqOuQ+pq4RX/0WkflPYeoKgUlfwz/1goKZKD0VJTz3QZIbP8RWXEPxKd/F0WvIvuPG2umacy/p7cuwVwyg+OSbsJcfOok9EmQbt9Dy5K0Ep19BYL/AMTzvCWKLn6N8vznyTP16Wp/5Mf5J5xZ4zpuaQtMjN4KuUXnN3w7p7y1LAs99Y2oPNfxLgiAIK03TnPCvPo//NvTEBF8MDMNkdX2EpkiSPZ1JbDaBWUMqyCgGKVUnkdEQsGypVu+NgGEw/63n2PX63yg+5bt5oSU7oMc72f3gt7FXDqDswt/kLZ1an70NtW0XVdfel3dZMLIpGv9xLfaKAZRf8KvPff5HEhMYSob2F39JtmEjYHXSS067peDzsk11dM3+O0rrDrx9xzDorG8jFdUQT1njVbIINhGGVnooD3qJplLEN3zEx88/QEdLAyV9hzH57KuZNP14bJKETZYRJfDbbIQzCl67jNchEU0p7O5M0hBOE3DKeB0y4ZSGz2234hddJ5zSqA44aYpmqAm5SGQVTEMg5HPgccoEHHYwYXhNAN0wSCsa21pTxLNZdnelEHRLsHRXe4KBlQHKvA5OGVFJZchFUrF8sBXVIK1qhJMqi3e0I4ogCQKmKTCoPECJX0YWBRq7MvjddibWhij1F449xVIKezqT7GyPoRiWgJ6qGvjdMs3hNLohEPLYKPG7KHbbWLi9HZddIpHKUteWwC4KRNM6dskAUaLUa6fI46DU6yCaUtnekWRHR4JEWsMwIaFZvzcZEETIGuAQIGNaybZp6KS2LCC6+HnUjj3IRdUEj/467iFHf6pobbZlO11z7kVpqsNRPazAnQasgnpy41zCHzyIoaYJTrsY/+Sv5Rlm1nz0W0TmPwGmQfCYy/GNP/2QY2lH9Ds3TVoevwU9HafXdf/IH1ONtND80PW4+k+i9OwfIQAuQ2PHY99DT0WouubvuJxeQh6BwRVemha8zPuP/4Wf/Pl+Ljr/vLx+TVbVEUUQEelb5qXM5/hczbAeHB4OFRf0dLD/g7FkZ+cBoliVAec+H+ZP1E4EUaL4lJtoefx7hD96hOKTb8y/5p9wNslN8wjPuQ9Xn9GITi/2sn74xp9BfMXreEccj6PXUAC8I44ntuRFIvMex9V/wudW5u5WANfinUQXPYMg2Si74FcFXUojkyC2/FVMNYtn5Cy0SAupzfOJrXgNyRPEO+YURElGkGQCk8/DM/RYwnMfIjr/SZLr5hCaeQ2ugVMP+nCR/aUUn3QDgSnnE138Aom1s0msfQ/PsOlWl/AQVUqHJJA9zLnmLxKxjFbwb1030Q7RZ+625ehGhd/BqOoggyt8Berjv3h9w2HPaH+ZUMNNxJa9QmL9+2DoVmI99YID1LY/ifSu1YQ/eAC1sx5n33EUzfomtqJe+ddN0yS5YS7hDx+yqF3TLiIw9cL8b69gETUMQsddjW/CWV/IItp9/PC8xxHdAXzjTt93vR17iS19Cc/w4/LJtaFm6Xz3r8iBcgJHf73gcyKLnkXraqTsgv9DtDkL5vxlSWDmYEuF/mvjqnuS6x704P9TGKaJYWrEsxoOWcQmCmQUk+0dCRySSDKrsacrwZr6KG1xBUkwCIw8Ceeqjwh/+BCufhOQvCFsAlT1KiM540paZ/+d5Po5luaFIFB8wvU0PXIj4Q8fouT0/wGwbLcmf43IR4+Q2bvhcwuJHmlM4KwZgbN2LLElz5Pc8AFGNkno+G9iC1hzuI6qwQy4/A4617xDx/wnWHX39RSNO53gURcjOb2oBohAVFHxqRpTBpRTM+lqjjrlHGa/+gLLX3uEt+68lbWvDeLki7/BzNPPJJIGzdTJqBqGYZLMqhgYeO02Sn0GjZEMdrtMJKuSVlTKAg5UTcdrF/C77aQ1E6dTpC1qkDZUSn0Ogg4bFSEXGVVHliS2NMcwMYmkFBTdwI7I1o4oimoiiCbJpEK9orC+0YFqFOGwiVT6XUSNLFtaUmi6TlbRSKoGApa4WVLJUl+foCLgpMRjp9RnRzMhlrZo62lVxzRhc0uMZEanK6UTS2XRDAMBkYXbOslmVcoDdmyym8FuL7vaY4RTWToTUNcawzR1BFMipmgE7DKypNOeyNISzeBz2UhkdXRdwymJdKrW/ZQBFWtOGsP6t8cJalwhsmEusWUvoYWbsRXXUHLGrZaWyqesz3oyTGTeEyTWzUF0Byg+9RY8I44riFPVcBNd7/2NzJ61OHoNpfjkmwo0eZTWnXS+dw9K81acfcdTfNL1ltPNF4TU1kUoLdspPvXmfQl9zh0EUSJ0vCVsagKty15FadtJydk/ztuKOm0C7Y0NfPTsfYw95gSGTjueDc1RDEOgzGenX5mXcr8L136jAT346tHTwf4Pxso9Yc6/b9FhKTPvj/CHDxNb9jLlF/0WZ599CszZlu20PP49vCNn5S26jGyKpoeuR3R6qbzirvzDILnpYzre+CPFp//PZ9pzgNUNVzv2Ijjc2Ev7HkDHNjIJwvOeQG3fTWDqBbj6jUcNNxFd8DSmplB6zm0YmQTtr/8RR81wsvXryexejb1yIEUnfgdHxYCCz0vvXkP4g/tRO+px9B5J6LhrDnjPJ6HFOogte5nEuvesxbvfePwTzsZZO+bfrvonYHliH66YeU7vLG/dpeqWddeR/na+SJimSbZxE/Hlr5HauhgkCe+I4/FPPg9bqPJT91U79hL+6GHSO5YjBysIzbwW14DJBfdJ7dhL55y/k61fj6PXUIpOugF7aZ/869mW7XS99zeUlm04+46n6MRvYwtWHPKYeiKM0rYTI5NADlZgrxz0mb+L9M6VtL3wc0LHfyNPAzdNk9Znfmx1ga77R74LFP74MWJLXsjPh3VDadtJ82O34Bl2HCWn3Vzw+QLwm0+wEnrw5aGng/3loCcm+OKwozXOR3Wt+Fx2MqpCLK0jmCaSLNGn2MPCbW1sb0taisxqlmhSoWtvA9sfvgH3gCmUnv0jvCL0KrYTiWdY99CPUdp3Wx3rnE1ReN7jxBY/T9lFv8HVZzQAhpqh6f5vIPlLqbj0T5/5bNTTcbJ7N2AauvU8La0t0Kc4kpjAPXAyhpIiMv9pBEkiePQl+MafkU/EJEBNRQnPe5zE2tmILh+Boy7GN+YU7JJMiU+g3Ovm1JFVDOoVYHd7gvZYmqZImsj6D/n4pYfZta2OQGkFR5/xdcbOOhfT5cEtS4TTGslUlo6khm4YdCXS2G2yNfssibiccj5RdtpkaoIOTARSmk40qYFpUua3M6Z3yFIdzxos3tGBUxZJZXRUQ2dHewrdUBElgXDCQNXBYYPjh5QwbUAlPpdE31IfKUVje2uc9niWcCKLlrOxEiTw2mXiKR2/S8ZllykPOJjSr4SMZiBg6Xc0RZLsbE8SdNtJZFRiaZWNzXEUTacllsYmCHjcNgRgeKXf8l9OZumIZogrKl67HVkUQBAo99uJZVVERGRJxDBMVEPH77ITTyukFUvwLaOYpPV91HA9FSW1+m0iq97CSEWwVwzAP+V83IOmfmozx9QUYiteJ7r4OUxNwTfudIJHf73AesvUFKJLXyK6+HnLoWPGlXjHnJz/XENJE134DLHlryK6/BQdfx3uoccemoKuqShtO9GirQg2B86akZ/pgW3qGk0P34AgCFRefU/+N5rcPJ+O128nNPO6vNuP2tVI8yM34uw3nrJzfoIEBBwQ8NjZ9PCPiNTX8YP732D4wL4Uex10pRQq/U4m9y3BYe9RC/+qcKi4oCfB/g/HNx5fwexNrUe0j6FmaH7Y6l5XXv3XArpp+KNHiC19qSDIT21dTPsrvyE4/UoCU84DrNmWlsduQU/HqLr2vkNaFJimSde7fyWxbnbBdrmoGs+Qo/GMmFlA/01tXUS2qY7QjKsAiK9+m+iiZwlMvRA9FUXt3EvpWT/cJzox90GMVAzvmFMIHntZvsIHVncyseZdIguewkjH8Aw/juAxl352JTITI7LqbeIr38RIRbAV98Y3/nQ8w2Z85sPzq8KJw8rZ2BSlsZut8B8EQ82S2jKf+Ko3UVq2Izq9eMecgm/8GQWK2QeDnggTWfgUibWzEWxOAtMuwD/+rAI/dENJE130nLVI2hwEZ1yJd/RJ+xbRbIrI/CeIr3oLweklNP1KvKNOOPQiaui0PHErSsu2gu2uQVMpPfvHh1z0TUOn+dHvYqoZqq69Ny+Iklg3m8537qbopBvw5Wayu4tbnhEzKTn15sJjP/49tEQnVdfci8vnR9EKn9mXTO7Nb3pEzb4S9CTYXw56YoIvDhlVZ31DhI54ltmbmvE7ZYJuO+mszoAKLxsbYmxriRJNq3gdElldJZrSWPfOM3TNe5yqc35C3wlT8bucpFIKu3bXU//Ijbj7T6T0nNsA6xne/PANIAhUXX1PfhQrvnY2Xe/eTclZP8Iz5OhDn+PeDbS9+EtMJZ3fJthduPpNsGaq+43PJx6HGxOARbHtmnMvmZ0rsZXWUnTit3FWDy84ttK6k665D5KtX4dc1IvQsVfgHjQVryAwZXCImpCH8oCToMeJounUFLkxdYPlCz7g/r/dzc71K7A5XfSfejLDjj+PAQOHopo621viiBjEFBNNMwh67WQUjYDdhtdlJ5zKYGIS8sjUNaUwBQADQRQocjkJeW0Ue1zIkkl7VMHlEGmPZQm6bXQkspYvdDxFZwpkwVLFrgk5OHlkFWP6FFHkcdASTecstUTakwr9it1EMyoysLUtSVbXcdstb2sBgSn9Sgh5HLTHs6RVjeZwOpf4K7QlMhR57GxqjOKyybTF0kgilPjdZFWDASVuGiIptrTEiaY0ZAk8dhG/047dJlMVtBPLGNgkgfZYlqSioukmvUIO/C4HdlFkxZ5OsrpJOGNZcMZXvUVy88egq1aDY9K5OHuP+tRijWnolkvM/CfQY+24BkwiNOPqAvVvsIrdXe/fhxZuxj3kGELHX5ePN0zTJL19KV1z/oEeb8c95BiC06/EFjx0rBhb8TqRjx/D1Pbp7oguP+UX/vpTx8viq96ia869lJ77s7wor5FJ0PTgt5G8RVRcfgeCKGGaBq3P3IbStouqa/6O7CvGAdgk0DbPoe6Vv3DlD37NVVddgyhKdKayyKKA0yYxoU8RAfc/rx3Tg8NDT4L9X4qVe8JceP9itCOk+Wbq19H6zG34Jp5N0X4+u4aapfmRGzENnaqr78nbErW98hsyO1dSefU9+YQ4s2cdrc/eRvDYy/NKo59EtnkrLY9/D9/4M/COOQUjk0RpriO9YxmZ+g1gGjhrx+KfcCbOfuMPmqxosQ46Xvs9Svsuik/7Hp7BR+WVRY1MgsiCp4ivegvR6SV47OVWsrQfhcjIJokufoHYitcAE9/Y0whMOT/fOTwUTE0luflj4ivfQGndgWB34Rk2He/okz+zG/5lothtI5JW+eQtP1x7sH8VlPY9JNbNtqh8mcS+wsXwmYj2g88Ud8PIJIgue4X4ilcxdQ3fmFMIHHVxgTq4VXSZR/jDh9ETnXhGzCI048r8fe4uyoTnPoSeDCO6A8iBcmRfCaGZ137q7H3448cQnR4cVUOQXAGSm+cRXfQMZef9HNcnPDS7EV/zLl3v3VMQbOrJME0PfhtbSR/Kv/47BEHE1DWaH78FIxmh8tp7C4pE0SUvEPn4sfxnfOvYfjwwf2fBvf/65N49quFfEXoS7C8HPTHBFwfTNNnaFmNeXRvbWuOEXHYkQSSjqQiiQFbVMXSdSMagb4mX1miCbe1p2rsirLn3FrRkhON+/CiVZUVsb4nTkDCILnmRyMePFjzL0rvX0PbcT/FPvYDQsZdbxzZ0mh+5KVdUvK+g8Lk/mh+/BSOdoPi07yE6XKgd9WR2ryG1bQlGOobkK8E39lS8Y05GcvkP2P/TYgLTNElvXUzXBw+gx9st+6YZVyH7igu+o/SOZUQ+ehS1cy/2ysGEpl9B1cBRFPttFHsdBN12fA4ZEZGQx4bTJlEfTtG2Ywub5r7IuvnvoqsKvYeOZfCMs7H1m4LH7aIrkcEQodTjRtFU/C47dkmiNZamK50ikTLJaiAYkAIcQIlfxGWzEXLbyWoaumFSGXTjkESCXjtr9nSRyihEEwYx3UqwsybUBgTG9S1ncr9igi47bfEsda1RUlkdwzDoV+rBFAScsszecBzRhPakiiyYeJ0OnDaJgeU+NNMgllSIKSbRZJqkauKxSWiazu7OJClFQxBEijx2ehe5UQ0dt93G+qYYbZEUkiigqBqmCEPLAmQNk5G9/CSzOilVY097Ch0D0zRQdZOQy4UowramdrrWzmX3kndRWrYh2Jx4hh+Hf/yZn2mjat3D5UTmPY7avht7eX9Cx11TwMoEq+gSnvsg6W1LkIt6UTTrWwXip2qkhfD7/7DYcKEqECXLr90T+tS4ILN3A8nN83H1GY1cVIWRjtH++h9wlA+g7PxfHHQfI5Og8f5vYCvpTfnFv8sXDjrf/SuJdXOouPyOfGzZnYgXn3JT3snHJYCU7WDLvddT2mcIz73+JuN6l9CZzLB2bzSnKm+npsiN39WTYH9V6JnB/i/G5yEvO3uPwjvmZOIrXscz+Oi84JNoc1B8yk20Pv1jIvMep2jWNwEomvVNmh68ns5376H8IkvwxNlnFK6BU4gufh7PiOMLFrBuqO27AfBNOCtPv3VWD8U/8Wy0eAeJ9e+TWP0ObS/+EltJb/yTz8MzbHp+ocQ0LOswlw9Xvwlo4SbrmnMJtOj0UjTrm3hHnkDX+/fR9d49JNa8Q+j46/JzYKLDQ2jGlfjGnUZkwdPEV75BYu17+MafgX/SuXmF9AO+V9mGd+QsPCOOR2mqI77mbZIb5pJY8y62sr54R8zCM+xYJM9XO/PamVIPuv3fMbnW03FSW+aTWP8BSnMdiDLuQVPxjTkFR++Rn0kjNLIp4ivfILbsZYxs0qosH3vZAaJn2ZbthD+4n2zDJuzl/Sk560c4q4fmX1fad9M15z6yezdgrxho+bjbHASPutgamVj5Ou4Bk3H2HolpmgecV2j6FQX/9k86h+iiZ1A69hw0wTaySSLzn8BRPazA97przj8w1AzFJ9+QLyZFl7yA2raL0nN+UpBcq517iSx4GvegaXiGHI3bLvGjU4fSu9jDz15dj2GCTRL42rjqA47fgx704P9PpBQdXYfqkJdk1iCazOK2S/Qr9yEYVifTa5cpDciU+1xIEiQzJoYRYOKltzH/zm/R+v6DHHPzb9jYbDmU+CedQ6puAV1z7sXZeySSO4CrdgyeEcdbWhJDjsZe1g9BlAgdfx1tz/2U2IpXC0Qd94faXo937Cn5Z7S9tBbP0GMpOvF60juWEV/1NpF5jxNd/BzeUSfhn3Qusr/ksGICQRBwD56Gs+84okteILbsZVLblhCYegG+CWch2izBJ/eAybj6TSCx/n2iC56m9dnbiPYZRfHRl9FnyFB2aNCVI4g5RSj3i5T7PSQ8VVSeehNF069Eq5vL2rmvMefen2NzeekzYSbBUTNxVgyiSY1REfAwsXeAZFYno2ukVRuGTc3pp+SEvYC2mIHbniWrargcMm5ZwikJ2ESReFrDNCHkdWKKWbS4Tlaz9g16XWRVnc3NcVw2y/qzOZzCwJqt7khkKfHY2duVpD2eJqmBSzKx2ez08muUeGXeaYmQUE1cdgmXzYaBSZnHgaKb7OlMEMuoGAbUFjmpCrlxO23IpowBFLlE2iIGWRUcdht+u4BdFlFUlZX1YWRRIJHRiaasL9LrkrEJAtvXLaF+2bu0rluAoWawlfQmNOubeEfMLKB0HwymaZLZvYbIgidRmuqQg5WUnPF93EOPKWjQGNmUdf+Xv4IgygSnX4F/wtn5oo+hZokteZHo0hcRJNkSFZXsGJn4YcUFzpoRB2gNuGrHktm74ZDnHln0LEY6TmjmtfnPytSvI7H2PfyTzs0n11q0lfDHj+KsHYtn5An7rskwCb/3d0TT4Iof/oZhlUG8Thtepw2f00E0peCyy3gdn98+rAdfHHoS7P9wLNnZecTd626EZlxNesdKOt6+i6qr7s7TvJw1I/CNP534yjdwDz4KZ80Iq8t33FV0vfc3KznNUVtDx11D00PfJvLRI5SccesBxxBy9HNTPZDKLPtKCE67iMDk80humU9syYt0vnUH0YVPE5h6AZ7hMxEkmWzzNgw1Q8V5P8/v+8mHnb28H+Vfv93qYH70CK1P/wj3oGkEZ1yZT8Zkfyklp36XwOSvEVn4NLElLxJf9Sa+cafjn3h2QTe04BoEAUevITh6DcE4/hskN31MYv37hOc+QPjDh3DWjsEzdDrugZMLlNn/f4WhpElvX0Zyy3zSO1aAoWEr6UPouGvwjJh5yO+54DOySeIr3yS2/FWMTBzXgEkEj770AOqVFusgMv9xkhvmIrqDFJ10QwGDQU/HiS58mvjKNxHsLkLHX4dv3Ol0vPEnbCXW3LJv/JmktswjvX0ZzsNI+sEaswAQ5IOPRkQWPI2RihE6/5f5z0ttXUyqbgHBYy7Li7cpbbuILnoO99DphV6Yhk7H23ch2pwUnfhtAPoUWeMJ3fPW72xo5pQRlT2iZj3oQQ8KkFE0sqpGwGmjttjNqOogO9oTNIWTxNIK/pCbCr8Ln0NmU7OCZhgUe2SE3v2Zds5VLHzpQWrGz0T2D8QJZESJ4lNvpvnRm+mac1+ekh2aeY1lffjO3VRc9mcEUcJVO8YqvC96Ds+w4w5qfynYHJjKgTGBIMm4B03DPWgaSvtuYsteJr76LeKr38Y76gQCU89H9pcdVkwg2p2Ejr0M78hZhD96mMi8x4mveYfQ9Ctyc7UigijhG30S3uHHWdTzJS/S+NT36awdS2DahfkEKmnAzohBJB1HVcmJo9lx9juZiaNPx2zZwsaP32Dnkvcw5r+OPVRBYOh0omOORVUHM6yqGK9NRkbAYbMRSSuogIY1H24AqgJxdBIZnZpiN1lVJ2boBFx2ZFlCFkVEEUJeG4JpkFB1fE4bm1tibG2JkdF1PHaRaNqw7DsFaIym6Uhk6EhmMQ1Q1Jz9laHQ0KUgAKIILju47RLVQS9+l0wkrdASSZM1dZw2GyImpmjSnlSpsktsakmQVlVcsojPZc1rK4rC3iS0xbM4bDIOWaQzpWITTeIpk2zrdqKb5xPdNA811oFgd+Edeiy+UScgVw35zHXXNE0yu1YRXfQc2cZNSL5Sa70fOatgdt80dBLr5hBZ8CRGMmKNBk6/Ii+UZ5omqbqFFtMt1oajdixFJ3wTe1E17a/d/k/FBaaaPWRMoHbsJb7yDbyjTsgn0oaasYRNcza63efX+c7dABSffCOSIGATIOiE1Jb5tG5axqyrvs+44UNIK3putt2wHAIEAZ+zR9js3wU9FPH/cDy9tJ6f5rpZ3RhTHWDNJ7yxD4VuESb/5PMIzbgyv91QMjQ/YllzVV71V0S7y5oJefYnKC07rJmQ3MIZmfcE0cXPUf713x9Q0VPa99D88HcKLEAOBdM0SG9fRnTRs5Y1R7CCwLSL8Aw/DgwDQbZh6lqhEIqSznlynrpfZTJDbNkrxJa+hKmrFp142kUHUMKV9t1EFz1HassCBJsd7+iT8U88G9lfeljfndJRT3LjhyQ3z0OPtoIo4+wzGvegqbgGTPrMeeL/JuipKOkdK0htW0xm1ypMTUHyFuEeeiyeYTOwl/c/rAVKT0aIrXzd8r/OJnH1n0jgqItxVA4qeJ+RSRBd+iLxFa9jmib+CWcSmHpBvvrdPX8fnvcYZjaF6A7grBmJqWYoO/8XJDfPJ71zOcWnfBdBlMjUryO5ZQGeoccelgJuavsy2l/6VYHFVjeU9t00P3JTgb2dnknQ/OC3ET1BKi+/E0GSc9Tw76Enuqi65m8FhYfo0peJfPRwgYhgtwUXwCUPLkHRDOyyyFPXTulJsr8i9FDEvxz0xAT/PEzTJJnV6Exm2dIcQ0TA57Ljc0gM6xVgfUOENXsj2CSBoMtGWjGoCDjZ3BymriWBTZToSGXxSxoPfP8SsukkE2++j6jmIJojTUUWPUt0/pMFVPHklgV0vPb7Ao2WvN3QgMn5ZHx/tDz9I0w1Q+UVd33mdWnRVqJLXiSxfg6Y4B19IoEpFyC5AweNCQASGz/EUTmowFEis2cdXXMfRG3bib1iIMEZV+YF2rphKBniq9+2GFOpCI6aEQQmn5cbXytcv0Qs1piJZTOlAGI2RXb7Iro2fExmz1owDRxFvSgZeRRDJkwnG6wh5PWwpytNWskpZ0N+ttYlgyCJBN0SDknG77YTS6mkNB1JFMhkNQTBQDVMnKKMXZZQMAnaRXaF0xR57SRTCg67DUkAwzDQBZFkRkUUQNEho1iCYiKWeje58xeBmiIZj8tBbdBJRDGwC9CeUBBEsMsCkiDhsomEkyq6YNIRV8hkVTQDNM36LgwTHHYodUnU162lY+NioluX5GMkT9+xuIbNwDVw8iHtJveHaeikti0htuRFlJZtSL5SAlPOs1Tt9xtB2Ef7fwy1sx5Hr2GEZl5T4Mtuzd8/QLZ+PYLNheQrwlE5CCMd/6fjAtM0abzvahyVgyg9+8cHvNb2XC52/sb9+bU+PPchYstfKRAcjq9+m67Zf6fopO/gG3MKLkCSoEhIsOSOa/FV9OUX9z7N5IHllPhcVAVddOZm9CVRIKsZ1Ba7/+2Eef+b0UMR/y/Dyj1hXl7VwAsr9hYk1wJwwvAKyvzOwxI/c/Ubj3fUicSWvYx74JR9VHG7k+JTb6b16R8T/uhRik/8NoIgUnzyTTQ/cgNd791D6Xk/RxAE/FPPJ7HxQ7pm/53KK+8uWOxsJTVI3iLSO5Z/ZoItCCLugVNwDZhMesdyogueovPtu4gueYHgUV+3KECfWEhTWxcT/uAB4iteJzjjKtyDj0K0OQkedTHe0ScRXfA08dVvk9jwAf4JZ+OfdHY+CbOX1lJ61g9Rj/o60aUvEF/5BvFVb+IZeiz+SedgL/t0H2R7SW/s068geOzlKM1bSW1ZQGrrIrreuwfeA3vFAFx9x+PsOxZH1ZADzv0/Gaaho7RsJ71rFZmdK8k21QEmkq8E76gTcQ85Gkf1sMO2cFM7G4iteJXE+g9A13APnkZg6gXYy/sXvM9QM8RXvUlsyYsYmSSeYdMJHntZXrjONE0yO1cS/vBh1M565FAVnglnETz6EgD23v119HQMW3EvMvVrSe9YjnvgFORgJaLswNS1A87tYEjvWI5gc+CoGlKw3TQNumb/Pa8H0I3wB/ejp6LW30zudxBd/Bxq206LGr5fcq127iUy/wlcA6fgGTYjv13TTZbs7ARA0SwfV1Uz8tuW7OxkSr/inmS7Bz34/wymadIWz7I3nCSd0TGBbe0JKgNObCE3hgnFXgfj+hQRSysomoFmqAiYlAc8JLI6mm6SVLKsb0rT5+xbWPePW9j+5j+oyDmKAAQmn0d62xK6Zv8dZ80IJE8Q9+CjcA+aRmTBU7gHTMZWUoMtWIF/yvlEFzxFetSJBTOvAM4+o4kueBot0fWZhWg5UE7xSd8hMHU/O811c/CNOdlKtL2FzztDzRCe+xBGJm7prRx1EZLLj7PPKCqvvIvkxg+JzHuStmd/grN2LMFjL8dRORCw4p7A5HPxjTuVxNrZxJa9TNuLv8BWWot/4jl4hh67r5C/3zG7E2XZ4UYaPgvb8FnoyTCprYtJ1y2kcd6LNH78PJI7QNGg8Ug143D0GQO5or+AlfzaZChyimAKJFQdSVHpSmdxyyJBp0SXIZBRIZE0EV0qGjqGIdCc1UlkQDIUHE4bbodEsctJKquQ0cErCYTTKg47uESTSAaynzh/AWjo0vA6NZJphQq/k5hhktUNBha5cTpsdCYUYmmN9mQWwTSJJDVUzUrYNYB4B5nda8nsXsmGnavQMwmQZFx9xuCediGugVMPOZL3SRhqhuSGucSWv4oWbkIOVVJ08o14R8zMC4Z2I7N3A5GPHyfbuAm5qBclZ/8Y96Bp+SRTS3QRnf8kiXVzEJxe3EOOxVk7Gt/ok4AvJi5QO/agx9pxHkSPKLV5Hpk96yg68fr8Wp9t3Exs+at4x5ycT67VSAvhDx/GWTsW72iLJSqJ4BBNNr90J7qmMvzCW1nXlCCpiozvW0S534EoCmiagWDuGxlVdctL3SZJuHoUxf8l6Olg/wfi6aX1/Oy1DeiH8Fj67TkjGVzh4+IHrC7XZ8Gy4voOgs1B5ZV/KVAE7/rgAeIrXitQFY+teJ3wB/cXdKVT25fS/tL/EZx+xQFzV11z7iO+9j2qb3iiYMb0s2CaJultS4gseAq1fTe2kj4Ej7nkAF/r9K7VhOc+iNqxB0fVEILHXY2zelj+dbWzgci8x0ltXYTo9OGffC6+cafnBdy6oUXbiC1/hcS6OZhqBkfvkfjHn2HZPx2mL7Jpmqgde0hvX0Z6xwqyTVvANBBsThzVw3H2HoGjehiOioF5Sv5/AkxdRWndSbZhI5m9G8js3YiZTQKCVUjoPwHXgMmH3akGKxnN7FpNbOXrZHauBMmGd8RM/JPOLeg8gGWvEV/7HrHFz6Mnwzj7jic0/YoCyrjSuoOOt/+C2rbToo/N+gbOAZMQc/cusvAZ1LZdFkW7pIb46rfJNm6m+LRbEASRthd/aRUHBk096Bz2/ufS8LcrcNaOOaA7E1/7Hl3v/rVAmKS72x2YeiHBYy8D9lMNHzY97yULOdXwJ3+AFm6i6pq/HxA8nj2mCrdD5sWVDei6gU0W+d/Th/OrNzf2dLS/AvR0sL8c9MQE/xxaYxk2NUVJZjUETBRNR9GgxO+g3Oeib6kHSRRoiaZJZjVsooCim7TEMkTTWXoFXIRTGh9vaWJtYxzB1Fj1yoO0LnieXl/7X+QBk/LHUtr30PzYd3H1n0jp2bchCEJOvPF65FAlFZf+0dJQ0RSaHr4BMKm86p6CuELt3EvTg98mOONqApPPPaJr1aKtRBY+S3LDBwiyDd+4M/BPPrdADE1PhonMf4rEutmIdhf+qRfgH39Gfs01NYX4qjeJLnkRIx3DNXAKwaMvwV7Wt+BYpq6S3DSP2LKXUTv2IHlCeMeeim/0yQc8m7thaXTvS8DtQCaTwNi5gviO5SR3rUZPxwCwldbi7D2S8v4jcVUPpqikhKRi4rbb8LpF4kmdjKLhdthIKiqqCWoW1FwiFfRAVgHdAMMAxYSgW6Dc60SUBHTTxC6CZkLAIeNzSEiyyOJtXXQqhefdLZLqBEQBBvVyIZkypqFR7HeQyhpIsoimqNSHUySzBrH2dlKNm0ju3UC6fgNaV4N1XZ4gvgHj8fSbCH3GHZH7ihZrs5oia97DyMSxVwzEP+lc3IOnHRCHZZvqiMx/kszu1UjeIgLTLrI627kitpFNEVv2CtGlL4Cu4xl5PKGZ1yI63Pni/xcVF4Q/eoTYsleo/s5jBbo8eiZB04PfQu62rxMlS0z40e9ialmqrv4bosONaegWQ7R1J1XX/A3ZX2rZcjkhvu4Dtr9yJ/1O+xaDjvsaRR4nI6qDjK4JMajSR4nXSTyjYppQ4nXgsEns7UphmCamCRUBB56euewvDT0q4v8lWLknzHn3LjqkoJUA3HrSYL5z3IB8l3tDY5R1DdGCfeRPeCh3K4L6JpxF0fHX5bfnHwRqxlIVd3r32Qe07qTqmnuQ/ZbKYvsrvyW9c0WB0jhYtJzmR28idNw1+Cedc8TXbJoGqc3ziSx8Gq2rEXvFAILHXIaz77j8w8409JxYyVPoiS5rwTzmsgN8j6PznyS9cwWiy49/0jn4xp52wMNfzyRIrH2X+Kq30GPtSL5SvGNOwjvqxCOmfeuZBNk960jvWUu2fh1q517rBVHGXtYXe8UA67+yvthKeh8WZerLhqkpqJ17Udp2obTuQGneRrZ1B+gWoUwOVeKsGYmzz2ictWMOa6Z6f+ipKMkNHxBf8w5auBnRE8Q35lR8Y089gMZvqFkSa98jtvRF9EQXjpoRBI+5tICupUZaiMx/gtSmj0EQcfYZhWBz4Oo7Ht/YUzE1ldTWRcTXvotnyDHE17xD0czrsBXX0PneX5ED5dhKepNc/wGhWd84gI7+SSQ2zKXzrTsou+D/CjozeYXw0tq8QqiejtH80HcQ3QEqr7gTQbJhqFlaHrsZI5ui8pq/FaqGL36eyLzHKTnj+3iGTT/o8UUBZFFgxuAya9YOeGZZvUURE+B7J1p//z344tGTYH856IkJPj8yqs7Cra2sbYhik0RcNhGnTSbkteO2y5T5nfQv9eK05YRDgYxqsG5vhLZ4GpuMJZJks7F8VwcLt7VT156gtSvN3sdvwUhFrefUfglsdOlLRD56hOLTbsE74ngAkps+puONPxa4ihxMabwbLU9+Hz0Voeq6fxw202l/qF2NRBY+TWrTPAS7E//Ec/BPPLtgPVfadxP+6BEyO1ci+UoJHv11PCNm5hM1I5sitvxVYstfxVRSuAdNIzDtogO0Prrnf/PF4G6xzrGn4qgZUZB0SezzdO6GDUv522YDJaOTbN1BavcaUnvWkW3ajKla/WRnsAxf9SCcFf2oqB2IHuqL4A7gddmIpTTsIiRViGsWpdshQMYEF1b3O6pZCb2M1VF2ida5BH0wslcxCdUgllDJqFkaOlWS+wWFIoVdeY8AfYplEopJRtPIhNsJN+xGb99BomkHqeat6IkuwLJYc1QPw9N7NOWDx+AuryWjiihGYaf8UOgutsfXvEN6+zIA3AOn4JtwJo7q4QcktdnGzUQWPktm10orlpt8Hr5xp+bjJ1NXia95l+ii5zBSEcRc11pPhXHVjv3C4wJTU2i49yocvYZSdu5PC17rVgivvOLOPCPvoI2rZS8T/vBhik+9Ge/IWbgAjxuEVJg1f/kmvqr+9LnkN8iiiCTBhNoSakt9jKwO0bfUg6YbOGUJQRTwOWTLhs9pI6PqeB0yxd6Dz4b34J9HT4L9X4IL7lvEst3hQ74uiXDhxN58bVx1vou1ck+Y8+9dxGf1sjtn30ti9du5udJ9tj/Z5q20PHErnqHH5oXM1EgLzQ/fgKNqCGUX/gpBENHinTQ9+G0clQMou/A3BQ/Flqd+iBZro9c3HvjcVGnT0C1618Jn0KOtOHoNI3DMJQVzVIaSIb7iNaJLX8JUM3iGzyBw1NfzCuYA2cYtRBY+Yz2cnV58487AN/70AxJF09BJb1tKfPVb1jyVIOLqPxHvqBMtn87PcR16Kkq2cTPZxi1km7eitGzHVFK5VwXkQBlycTW2UC/kYIVlI+UvRfIVI7p8nysQ+SRM08TIxNHjnWjxDvRoK2qkBa2rEbWrES3SAqb1axFsTuzl/XBUDsZeNRhH9bDPNVtuGjqZPWtJrJtDatti0DUcvYbhG3cq7sFHHUD5MrIpaxZuxasYydws3FEXF3hiavFOooufI7HmXQTJhq2kN4Gp5+MeNI3M3g10vPFnqq6xqsOGms13UGIr3yCzaxVl5/0cLd5JevtSMnvW4ZtwVoHy+KG+u+ZHbwJdo/Kavxf8xttfu53UtsVUXfXXvIhZ+2u3k9q6mMrL78gHbfnF9YJf4eo7Lr+/0raT5se+h3vgFErP/tFnfqeSKGCaJrIogCDkO9o9HewvDz0J9peDnpjg82NrU4SX1zTR1JUillUYXuGjf4Ufv9OBCVSHXIQ8Dkq8dmTJWj+yms7cTc00RTIEPTZMHaYNLKWuJcbsDY2s3N1GQ6dBrHUnjY9/D/egqQVsHdPQaX3mxyhtu/NFdtM06XjtdlLblljJRK4b3PHWnSQ3fZTbti9xTW6eR8frf6D0nNtwD5r2ua9fad9NdMHT+xhqk87BN/6MAoZaZs86wh89gtKyDVtxDYFjLsU9aGp+PdUzCeLLXyW24nVMJYWr/0T8Uy446HqgdjVafs0bPsDIJpGLeuEdeQKe4ccd1EkFrKTXAEq9kEyDpoPNDg4JDEPD7NxDYu9mOnduItm0jUxXc35fyenFW1aDq7QX3pIqVFcppqcM0VeC6ClCkG18FoHZJUBNsRNVN1A0Dc0wiCat5Lc7AzDVLFoyjB7vQIu2okXbMMNNqOEGsh0NGHnPcgG5qBeOyoHYKwfhrB6GrbQ2X7TwAQEvtCX2UecPBS3aSmL9ByTWv48ea0N0B/COOgHfmFMPsMcyTZPMnrVEFz9Ptn5drklyLr5xp+XvteWJ/ZE1fhBtxVEzwmIdjDoBV99xX1pc0G3JWXbRbwri0W47XP+kcy2VciBTv57WZ27DO/YUik+8HjcgxXaz8f6bcfebQO15P6E86KTYI9O/zMsj//stOnZtYuJN9+EtrcQAUhmNyoCLfuV+hlX6qAy6iWVV+pf4qC5yk1J0bJKAopkIAlQFXThtPTTxLwtfWoItCEIN8DhQjvW3er9pmn/5tH16FtPPjym/fZ+W2KFrgpKQE93YL9BeuSfM+fct4hCM8jwMJUPzozdi6jkP7P0qwZEFTxNd+DQlZ/4Qz9BjgH0PldCsb+Iff0bBtqKTbsgrjQOkdiyn/cVfFlBnPy9MXSWxbg7RRc+hJzpx9B5J8OhLCrqaeipqqYSvfgvT0PGOPIHA1AsKHtrZ5q1EFz9PetsSBJsD78gT8E08uyAZ74ba1Uhi3WwSGz7ASEYQ3QE8Q4/FM3Q69qrBn1tQwjQNtEgratsulI49qB31qF0NaOHmA5XXRQnJHUB0+hCdXovmZHMiyHYrQRWl3HmYmLqOqWuYWhZTzWAoaYxMAiMdQ0/FwChckgXZgRyswFbUC1tJb2wlfbCX9UUOVR42Pf7AazNR23aR3PQRyU0foSe6EJ0+PMNn4B19EvbS2gP20eId1iz86ncwlRTO2rEEpl5QUPDRE2EiS14gufZdTF3DO/okfOPPJLb0BXxjTsFeOQhBlGh/5bfYy/sTmHYhpmnkgym1q5Hw3AcpOetHBbTFw0H3KER3lTm/feti2l/5DcFjLiMw7UJgv47OftvyTJFxp1F0wrf3fVeaannDpqJUXn3PEbECJAEunNSbXkFXzwz2l4yeBPvwcKRxQU9M8PmxvjHCnHUttCczhFMqfUvcDKnwM6ZPEQ3hNCG3jNtuw2ETKfM5EQXoSipsbo6SUQ3CySxOm8S4PiHqO9Ks3tvJ/LpW1tTHSZuHZtWokRaaH7kRe8WAnHWniJ6K0vTwd5DcOTFH2Yaejln0cX9JXmkcrGSo6YFvITo9VFx+5z8tylTAUHMHCEz+Gt6x+3U1TZPU1kVE5j2B1tWArawfwaMvwTVgUv7YRiZBbOUbxFe+gZGO4eg1DP/kc3H1n3jAOmioGVJbFpJY9x7Zhk059tRoPMOPwz1wSkH85MRKsEWsLrZoQloDpw2CbpleASctsTTRrI4sCGRTcbS2XUjRJjobd5Lu2EuqvZFsrPOA6xYdHkS334oLHB5EuwvB5rBiAknOX5tkGNglnWxWQVOy6Nk0hpLESCfQU9H9Cv3dEJB8JTiKqrAV1yCV9MZWWou9tPZTKd8i4BYsGvvBIlU9HSe1dRHJjR+S3bsBEHD2GY13tEXD/mSx3dQ1UnULiS17GaV1B5K3CP/Ec/COOQXRnru3hk5qywLCC55GDzdiK+tnjZBVDSb8wf1falxg6hpND3wT0e2n4rI79v2W1AzND98ImFRefQ+izWmNYz5yI4IoUnnlXxHtTnp7dFbfewtaIsykm+9FdXopcrkIeuwkV7/Fq/f9nhOu/TGVE0+lI61i6iaCKFDideCUrfnqQeU+3A6ZXkE3A8q9yJJEZcCJqptIomWd1oMvD19mgl0JVJqmuUoQBB+wEjjbNM1Nh9qnZzH9/Pj925u5b97Og74miwK6YWJiPeSOGljCzbMGsWRnJ396r66AIi4CQu79+yPbuJmWp36IZ/hMSk67Ob9932xoI5VX3ZP3pGx78Rdk69dTecVfsJXUfEJp/G95RW7TNGl5/HvoqSi9rvtHgfrj54WpKcTXvEN0yQsYyQjOPqMJHP11nNXD8+/R4p3EljxPfM17AHhHzsI/5byCJFrpqCe29GWSmz4C08A1cDL+8WceQP0C62Ga3rWS5Ia5pLYvA11FCpTjGXwU7sFHYa8c+MV1mVNRq5Ica0dPdKEnwxipKHo6hpFNYmZTGGoGU1UwDdUawur+e5YkBNGGYLMj2pwIdpeVlDt9SO4AkieE5C2yuuOBMiRP8As7b7V9F6ktC0nWLbRmskQJV7/xeIbPxD1g8kHvfbapjtjK10ltWQCmiXvwUfgnfy1vZwHWvex8924yu1aBCe5hMwgec0n+Xna+dw+SO0jwmEutz2zeRvsrv6H6+kcxdRXTMKwiybrZ+MaeVlAAOrxrM6xxCSVD1bX35hkMejpO80PXFyiEa7EOmh/+DnJxNRWX/AFBlKz3PXwDgt1F5ZV3FYwDhD98mNiylyk97+e4D+KrfSgIgMPW07X+qtCTYB8ejjQu6IkJPj/CCYVXVu1hd2ca0zToU+Yh6HTQv8xLPKvTp8iNLAqEUwoeh4wsinjsEg2RNNG0immY1BS5sMsSkZRCVzLLCyt2M29zBxkV0oZOy1M/QOtsoPLqvxXYbiXWzaHznb8QnHEVgclfA/ZpTuzftcsrje9HHwdIrH+fzrfv+qe72Psj27jFmsvdswbREyQw6Wt4x56yL9Hu7nIufBYt0pxLti7CNXByfg00lAyJdbOJLX8VPdaGHKrEN/5MvCOOP2hyqXY1ktgw1yokR1sRZDuufhNwDz4KV/+J+X08WMrdkmDNRZe6ySl020iqKqYK8Vzb1+8Am03A5xCxyXb2dqXpSmXQo+1osTa0eCd6sgsjGbFignQcI5vCVNOYatZa83RtX0wgipYeiWSzCvN2J6LDi+j0Irn9VkzgCSH5Siwmnb/0C9OJ0dNx0tuXkqpbSHrXajA05KJeeIYfh3f4zAO61d37JNa9R3zlW+jxduSiavyTzsE7fGY+hjANneTmeUQ+ehQ90YnoDhCaeS2eYdPz9/LLjgvia96h672/HbB2dzPV9lcI73jrLpIb51L+9dtxVg9FAmIfP0zHkpcZddn/4h80hWjaoCJgQ4i3Mvf26+g/ehLfvf0BwimVjKIh2USiySwdcZVeJS6CDjtOm0iRx4HbaWNUlZ/eJV4cck/H+qvCV0YRFwThNeAe0zTnHOo9PYvp58PKPWGW7OxkzsaWA2y4JFHguqP78uji3SiqYVVLBauTnRdBUg1EUeDao/vic9kIue0HWHwBhOc9QWzxcwcsemq4ieZHbsJRNYiyC39tVawTYZoe/o4l4HDZnxAkW44+/h0cvYZRdsGv8klqetdq2p7/GaGZ1+KfePYX9r0YaobE6neILn0JIxXB2WcUgWkX4ajZ512oxdqJLnmBxLrZYBh4hs/AP/k87DnPQ7ASuPiqN0mseRcjE8dWWotv7Kl4hs046KJqZJOkti4muXk+mT1rwNCRvEW4BkzG1X8izj6j/i1mqr9smJpKZu8G0juWkdq+zLLjEEQcNSPwDD0G9+CjCub3umGoWVJbFhBf/RZK81bLF3PUifjGn1FQAFEjLcSWvpS7dzqOPqMRJBv2kt64Bx+Vt+FQ2vfQ8cYfKTvvF1bBQJJpeeY2AlPOx17R3wqo4u0Ep118wIzd4SCx8UM63/wzJWfcWqDu3f7GH0ltWUDl5XdiL++HaRq0PfdTsk1bqbzqbmyhqhx98vekti2h4rI/FxQO8pSx0SdRfPINR3ROkijwf2eNyHtj9+DLRU+C/fnwWXFBT0xw5DAMk6RiMZFEBPZ0xmnoSmIKEmV+O05ZptTnIJbRMEwTTTcJuW1kNQOnTcQhS6Rz85lep5WwtMfTLN/eyevr6lm6K0xGzSlctzey89Gb8FQPZehlvyKaFdEAzTTpePV3pLYvo/LyP+dnTDvfu4fEmvcov/g3eRvD9ld/R2r7UiqvuCvPXjINnaaHvgNA1dX3fKFOG5mGjUQXPE1mz1pEd9CiE489pZBOvPFDooueQ4s0Yyvpg3/KeZZS+H5d9lTdImIrXkVpqrPWqBEz8Y45tUDfpRumaZJt3Exq8zxSdQvRk+Gcfeco3AMm4ew3AVuwIj+nbcNqdrhkcNoFYmmT1H7xmB3w2iGhfDrd2gkc6Cr+5aC7PK5+yntM00TraiS9cwWp7cusTrVpIPlL8Qw+Gvew6YcURM02byO++m1Sm+dhalkcvUfhn3iWxSLIJc2mppLY8AGxZS+hhZsRZDu+8WeiRVqQA2VfWVxgKGmaHvgmcqCC8ktuz19PpmEjrU/9KE8DB0jWLaTj1d/lxU5FILlnLW3P/pTQ2JMoO/EGZNliNjh0lb3P/AAt2sb1f3kB1Rmif4mHgeUBRNHALktsbIohIWCTBUwTBpb7qS3xIAgCtSWeI7qOHvxz+EoSbEEQaoF5wAjTNGOfeO0bwDcAevfuPX7Pnj1f2HH/f8DKPeEC79srp9by6prGPF28W9wonlZ5dsVeIim1YPuUfsUFNj4r94T56Svr2dwSP+BYpq7S8uT30aJtVF7114K5ovja2XS9ezfBGVcSmGx5XnbTY/f30s57+Z14Pb6xp+b3b33uZyjNWwu8AL8oGGqGxJp3LWXQVATREyIw9QJ8407fb263g9iyV0isfRdTzeIaMBn/pHMKhDQMNUNy0zziq95EbduJYHfhGXos3lEnWjSjgywKeiZBevtS0tuWkt61yqJ3Szac1cNx1o7B2WeUtaB8Trr1vxNM00Bt301mzzrSu1eT3bsBU80iyHacfUbjGjgF94DJBwiWdUNp20Vi3WySG+bmZtiq8Y077YDuQLZlO7GlL5GqWwiiiL2sH3KoktIzrN9mqm4BejJCcMZV+XvS+d7fEO0uvKNPQg5W0PnWnQRnXIXsK8bIJBCPQMV+fxhKhqYHv4XkCVJx+R35hT6/aB59CcGjLgb2iZUUnXxj3gYk3+n5hMq+kUlYlDFJpvLKuw9Qtvc6JBLZT0rm7EOPqNlXi54E+8hxqLigJyb459ARzxDLWAm22y5TEXDSGsugaDp2SSSl6lT4XWxtiRFOqthsArXFHrK6QchlJ+TJKWrnlJH3tMdZtjvMB5uaiaYUGsJpOmKWtVPIIxNZ9S5rnr+LIWddj3vUqWg5e6ZEMkbDwzcg2N15Zo6hZGh+7LuYatYaeXF6Lfr4Q9cj+0ryxXiA1LaltL/8fwWjZl8kMg0bicx7kuze9ZZTxahZhI69Ir8WdHdBY4tfQO2sRwqU4594Nt6Rswqex9mmOuKr3iS5ZX5eQ8Q7+iTLGtR+YCHdNA2yjZtJb11CavsStLA1Vy0XVeOqHYOzz2gcvUfmRS4dWDTyT0tcP4luYTIb1r34bM+YLw96Kkpmzzoye9aS2b0aLWrZxNpKeuMaMBn3oKnYKwYeNH4yskmSmz4msfY9lNYdCDYHnmHH4Rt/esEomZFJEF/zLvGVr6MnurBXDMBWXINgc1B80g1feVwQmf8k0UXPUnHpH3H0sua0DSVN8yM3gpmjhttdaPEOmh++wVLZv+SPCJKMPRNj50M3ItpdDLjqLrKyEzsWrb5z3uPEFj/P9G/9hrPOPhu7LDGs0gcI7OxIUeyxoxs6W1vjVIfc1JZ6qQy4sEkSNlmgKnj4qu09+OfxpSfYgiB4gY+B35im+fKnvbenWn3k+NuH2/nz7LoCpeAp/Yq55MElqJolbnTl1NoC+vj+9FGAl1c1YAIjqgL87LX16J/yNFY799L86M04qodRdsEv91UO8xXrpVRc+qe8f2Tnu38lsXY2ZRf9Glef0VYX7/mfk23cROWVd+dtl5T2PTQ/ciPeUSd+ZrfONA2SGz4ktW0xRjqO6PRgC/XCO/aUApXyT6LjrTtzs8yNmEoKuagXwaMvLbB50FNR4qveJL7qLYx0DHvlQPwTzrbek1v4TdNEaaojvuZdUlvmY2pZbMW98YyYiWfYjAKqXMF5d3d0d64gs3s1ake9dT/sLhy9hlr/VQ3BXjnwiGzL/lUwsimUlu1km7bkBNo2Y2QSAMhFvaxgod8EnL1HHrJjryfDJDfPI7lhLkrrDpBk3IOm4RtzcgHTwDR00tuXEVvxGtm9GxDsbnxjTsY34Sy0SDOxpS9Res5tCJKNzN4NJDd9jHvQ1LxYmBbvJLnpIzJ71qLH2rFXDKDopBswNAVl7wayTXWoXQ3535Or3wS8o0/6zMJHN6uj/JLb8yMIefZGoNyyppFky37riVtxDdhnYaN2NdL86HexVw6k/MJfFxyr/Y0/kto831qgcxX3w4EsCZiG2SNq9hWjJ8E+MhxuXNATExw59nQmccgiApBSDfqWeMhqOq2xDLpuUuSxE04p7GhPEHTJhJMqfUu9BN12Qm47sYxCSzSDYYDLLrFyTyet0QwLtrXTGk3hccm0xbJ47XaKvDLZrMrSh/6XjroVjPz2XRj+PmQ1KyGwtCV+hj/HwjGwOpEtT96Ke+AUSs76EYIg7CvGT72QUM6y0DRN2p77GUrLtsMqvCsd9cRXvoHa1YAgiEj+Uly1Yw/pugDQ+c7daPFOtEgzWrgJZAf+8Wfgn3BW3m7LNA3S25cTW/oi2cbNiE4v3tEn4xt3Wn7UDazYIbH+AxLr3kPrakSwu3APPgrv8Jk4eo845KiV2tVIescK0rtWWUVpLQsI2Er7WNadvYbiqBxk6Z4c5riWwD7F8K9Srtg0dNSuRpSmOisuaNiUd0gR7C6cvUfh6jsOV/8JyIHyQ35GZvcaEhvmkt62GFNTsJXW4h19Et4RMxEd+7qwariJ+Mo3SKx/H1NJ4+wzBv+U83D2GU22YePnigtCJ3zbcktp2ITSvgs93oFpgr20D74JZx1Ui2d/qJEWmh78Nu5B0yg98/v57Z3v3kNi7XuUf/13OGtGWIKAz/0MpbkuHwubpknnq78luX05vS77E76KAUg2EA2I1m+g/skfUzruBGZd91POHdeLPV0ZREHAYZcZWe2nK6HSHE3REklT4ncSctmZ0q+EoMdOwGXLixn24KvBl5pgC4JgA94E3jNN847Pen/PYnrk6O5gdyfT+wuYdXem73p/K/O3deT3qS128+cLxgBw0f2LUXXrXksC6Idx27u70J+kdOvpOM2P3IQgy1Re8RdLjVHJ0PzYzZhKisqr/orkDuybQy2qpuKS2/P0L2s25XUqLvvTpyYV4Y8fI7bkBeRgJVKu0qh2NVpqpAcRyAJLLTS24lWKTrgeyRPIq5cbyUiuMn0O3lGz8omgoWZIrv+A2IrX0MJNSN4ivGNOsSqd+6llG9lULkH8gGzjZkDAUTMcz9BjcQ+adshuLViJWKZ+HZm9G6yFqKOe7uVQDlZgL+uHrawv9pI+2IprkEMVBwh9fBUwdQ0t2ora2YDascey6Wrblfe2BKv67qwehqNmBM7eIwsCj09CzyRIb11McvM8S4XdNLCX98czchaeYdMLfUtTURLr3ye++m30aCuSvxT/+DPwjj4pv9BqsTaii1/AVTsW9+Bp6MkwifUfWErwY07OWdCYCIJoFQF0DbVlO6lti8k2Wn7kiDK2oipElx89GUHrarCKPafcdMjrUDsbaHrkBjxDjsl7VpumSdsLvyC7dz2VV/4FW3GNVbl+7GZMJUPl1X9FcvkxNZWWJ2/dxwbZf34xRznfv/t9uBhQ5uWcsb16RM2+YvQk2IePI4kLemKCI0ckpdCZyGKYUOSx5214zJz3rSgKrK7vYmtrHI9dRtF0pg8qp9jnoCOeZfXeLqJplaDTRlnAycbGKDvbErRG02xuiVLudaLqGn3K/Bi6SWcySzYR5eWfXYLg8tP3yjvICo481bn9o0eILn2JqrN/jHvwUahAdOmLRD56tED0dN8c6u9xVg8DQO3Yaz1jh82g5LRbDnnNamcDzY/dDIKAvbQvmAZqtAX3wCkUn3Twgv3+MYHsL6H1+f/FyCRQWraDIOAZPpPA5HPzzg9gadHElr1CatsSAMuSa9xphcVg0yTbsJHE+g9I1S3AVNJI3mLcQ4/BM+SYQ7LdwCrCZ5vryNSvJ9uwiWzTFsycSrdgd2Mv72dZd5bWYivuja24Gsnly+/f7Vn9ZcM0TYx0DLVzL2pHPUr7HkuQtW1nXoRVdHis4kDNcJw1I7FXDDgk1d80dItCv2U+yS0L91loDT0W78hZBR1u0zTI7FxFfPVbpHesAFHCM/QY/BPPKaByH0lckKlfj9KxF6V5C+ntyzEyFoNT8hblCwFK6w4E2UHl1fccUhXeNE3aX/wlmb0bqLruPmSfta53i6Durz/QLRJYfMp38Y46AdgXW1ccfzXFE88l6BWo8LlobO9k5V+vR5RkjvmfBxhZW8rgyiB2WaTM76IzodC31E0ia1DfmSCp6FQH3ai6wbQBJfQt/fdv2Pw34lBxwT898CJYfw0PAZsPJ7nuwefD+D4h/vf04byzoZlTRlTmg+rxfUL5/x9e6S9IsMfUBBnfJ8Q3Hl+RT67h8JJrAO+YU0jvXEn440ctinPOYkNy+Sg5439ofeY2umb/neLT/wfR7qT0zB/Q/MT3LNGSr/0vsr+EopNuoOP124kuejYvMhE8+hJSWxbQ+e5fqbziroM+jE3TIL7qTdxDjqHkzB/se+jqGogHr86Zmkq2aQv28gHI/hIMJW2peboDSE4f0aUvEX7/PiILnsJe3g/JHcA7cha+cafhGXMymV2riK98g+iCp4guehb3gMl4x5yCs3Y0oiPXSR1zMmq4ieSmj0lu+piu2X+na859OKqHWRXTAZMPqHxK3hCeYdPzFXYjmyTbvA0lZ9OltO0itXUx+SVTEJH9pcjBCiR/GbK/BMlbbAmQuP2ILr+lFupwW4Iln6K+apom6BqGkrJUxDOWYqieDFvCabF2tFg7WrQFLdoGxj46shQox17WF8+w6TgqBmKvGlywyB8MWryD9PZlpLYuJlO/DgwdOVBuzbYNm1Ew854PUNa8S7JuIegqjpoRhGZcZamJfqKrLHmLsZf1Jb1rJa4BE5E8IfREJwgiWqydxLrZuAZOQWneZimUNmwEwFbW11Iirx2Do3Jwgcha15z7iK9+m9DMaw86Z2+aJp3v3YMoOwjNuDq/Pb7qTTK7VlJ0wrfygVnXnPvQupoov/g3+eJB+ONHUVp3UHruTwuSazXSQtfsv1tWc/uJ/hwu+pV4emjhPfi3RU9c8OXD65DpSirouoGqGxiGiSgKCIJA95LgtdvoFXSxoz1JwGEjmlFx2AS2t8WJJjV00yCSUSnzOelT7GFvV5LeRR5qStyUe52kFJ20qtGV0phc4act6mXS5bcx/2/fp23uIxQd/618B7X0mMtQ6tfT8s7dVFYMsJ77k84ls2sN4Q8ewFE9DHtJb4pmfYPs3vV0vPlnqq66G9HhwVZSg3/SOcSWvIh3xPF5UahPIrF+Dqah0eubD+aTGrDW/oPhYDGBs2YEkieIo2YEsWWvkNzwAcn1c5CLq5H95fgmnIm733hKzh6CHmsjvuotEutmk6pbaI0zjTkFz4jjkFx+nDUjcNaMwDjhm6S3LSW5+WPiK98kvvxVJF8p7kFTcA+ciqN6WEGcI8i2/L6Q6wh37iXbVGfFBK07SaybnffHBhCdPuRQBbJ/n3WnzRNC8IQQXb6cirgbwe78zA64aeiYahYj2x0TxDBSEfRElyWeFmu3BFYjLRjZ5L7ztruwl/bNJ8OOykHIxb0+9XimppKpX0dq2xLS25agJ8OWCFz/iXiGTcfVb2LBmqwnwiTWzyGx9j20aKslVDftIrxjTzmoRehnxQXxtbOxFfciu2cdybqFmNkkotOLq/9EXP0m4Og9suBz1Y69ND30bZKbPiYw+dyDXlOqbiHpnSsIzbw2/zvUE2E63/4LtrK+BI+x2BmZhs1E5j+Je8gxeHKuI0r7bsJzH8TTdxwVR50NAoiSRFo3aZt9H1q8i6k33kXvihB9il14XTYCLjuGaSAKJroGIY+NMm8Rq/dEaI9nqS1zU+z56psyPfh0fBGKEkcBlwHrBUFYk9t2m2mab38Bn92DHFbuCVtCZZrB8t1dDK7wHdC58rkK/8DeXNfMpL7FfLClrWC7JMAZo6t4dU1TfpvbLpFSCmc9BUGg+JSbaH7kRtpf+wOVV9yVnzVy5nyJowuewlk7xnrglvcjdNw1hN//B/Hlr+KfdA6eoceQ3rmC6OLnrVnkmhGIDjdFJ3yL9ld+Q2zZywdNMExNxVTS2Mv6FiSQnyaCoqdjaJEW3EOPtf6djGAoKUR3APfgabgGTSW9YznhDx4gm/O1Tu9aRfEJ1+MeegyuvuNw959oKYKueZfEhg9IbV2E5C/DO/J4PCOOxxaswBaqInjUxQSmXYTavptU3UJSWxcR/uABwh88gK24N67+E3D2HYezetgBSpyiw4Ordgyu2jH5bYaSsarEXQ1onQ2WL3W0BXXnCvRkhEPWqwWxwKYrX9o2dExDsxZo89CzAKI7iBwoxV4+APfgoy2brqJqbCW9P9WKoxumrpJtqiO9axWZnSst+jcgh6rwTzzbUlb/xNyVFusguelDEuvfz1Hs3Ja42dhTCpgJpq4WdPIFUcLVf5I1U7fwWULHXoapZhGdHtRwE0r7bmJLX8LUFOSiagLHXIpnyDH58YSDwVZmdUGMbOKg15tY+x7ZvRsoOumGPJVQadtF+MOHLU/0sadZ79swl+SGDwhMuzgv6pPavpT4itfwjTsd98Ap+12XRsfrfwRBpOSMWz+Tni4KMGtoOe9vbsUwLXr4N6f3/9R9etCDfzF64oIvGSlFQ8DqXiezGopu4PzEsyTgsWFgoutQHnBikwTCKRVBAI9Toj2u4pcESgIOestudAMyqkYsozGozItqmKQVHVXXSWU0KvwunKecyI41i2ha+AqOPmOpGjYZBRCQGXD+D9j4j5sIv/EHSi+2WGvFp3+P5kduouO131Nx+R2IDjclZ9xKy1M/pHP23yk5/VYEQSAw7SJSWxZahfer/3rQcSM9FUVyBQqSa+CQriSfFhPYQlUUn/QdfBPOovPtu1Bad6B1NpCpX09gyvn4J5+D5CshdNzVBI6+hNSW+cRXv0N47gOEP34E98CpeEfOwlk7BtHmzBfRjUyC1LalpLYuzM0Lv4Hg8OCqHYur33ictWMPGC8TRAl7zgKLnG6HaRpo0TbUjnq0rgbUcBNapBWlbSfpHcswtUPLngmy3Sq+i9K+hoRhWIm1poB+6ElvweZA8pUiB8pxVw3BFqrK23dK/tLDslNTIy1kdq8mvXMlmd1rMNUMgs2Jq9943IOmFSirg7XWp3esILH+fdI7loNp4KgZQfDYywvH9j4RE3R/dweLCwASGz4gueF99Fg7gs2Je9BUPEOn46wdc8hYUi6uBknGSEcP+rqejtM15z7sFQPw5TQDTNOg4607MNUMJWd8P2dPF6fjjT8g+0spPvkGBEHAUDJ0vHY7osNN9em34BBFRKA65KZl5Rx2LvuAKed/i4qBI7HLEoJgY3zvIgzTZH1TlDK/C49bJuC007fUQ99SDxlNJ+R24rL/5+v7/Lfhn06wTdNcgBXW9+BLxJKdnSiagWGCqhks2dl5QII9pV8xsiig5WTBDdPknQ3N7D8GIAD/d/ZIwikFUSCvIJ5W9IJ/d0NyByg+7Xu0Pfczwh/cX0ClDUy9gEz9erpm34u9chD2kt74xp1Opn4d4Y8fteaKqgZTNOubZBs30fHGn6i86m4klx/3oKm4Bx9NZOHTuAZMPkCRU5DtiC4/ariJw4ZporTuIDTrGwBoXY0Y6TiOnOKzIAionfU4a8fgGf49a8Fc8w4db/wB+/KXcfQehdK8FfegafinnEfw2MtJbV1keW4vfJbowmdwVA/DM2yGpYztDmAv64u9rC/BYy5FDTeR3r6M9I7lxFa8TmzZywiyHUevIThqRuKsGYG9ctBBPRZFuxNH5cD8THvBZema1XFORiybrkwcI5PAVNI5m64sGBqmoVvJtUBuYZWt79HmRHC4ra63y4fkyllyeENHTEU3NZVsy3ayDRvJ1K8j27DJoooJIo5eQwlOvxLXgEmW8Mh+C7EVdCwhufFDMnvWASaO6mEEppyPe/DR+/wsNdUKTFa9jRwsz1OyuyH7SwgefSnRBU/ReP83QJRJ71qNHm9HdHjwjDwB78jjDymm8klo4SbLX9xzIM1ai7UR/vAhHL1H4c0FPYaSoeP1PyA5vRSf8l3rN9Wxl67Zf8NRM4LAURfl9+18607s5f0JHXdNwedGFjyJ0lxHyZk/PMCe5OwxVVT4nfxj/s68u4osWgn1N6f3LxAq7EEP/l3RExd8+ZAlEcM0UXXDeuwf5Nsu8ThwyRKiAJIoouomAadMsdeB32mjxOugf6mXIq8DRTPoVeSiKZzG57Thd9tY3xCh2OtAyRjEsjo2SUDVBYaeeg2RHWtpe/svlNbeTXF5BbGkhhGsoPfpN7L7pdsJz3uc0HFXI3uLKDn9e7Q9/3O65vyDklO/i6PX0HyBPlk71hIUszkpPvlGWp+9jci8Jyg6/roDr9lXatlVZlOHVQD+rJgAIL19CfayvpSccSvZ+vWE5z9BdOFTxFe8iqN2DHqsI588e0fO2ifSufEjUlvmI3lCuIdNxzN0OvaKAdbs9sjj8Y48HkNJk9m9mtT25WR2rSRVt8C6jqJqnL2tmMBRPfygei6CIFoF/WAFMOkTl2VanedE1z7rzkzCsulS0phazqbL0C37TgBRRBAlBMmGYHNYftkOj2XT5fIjuYNWTODwHJEnuWmaaNFWi+q+dwOZPWvzAmeSvxTP8ONwDZiEq8/ogmZDtwhcctM8UlvmY6RjiJ6gZcU18gRsxdX59yptO4mvfodU3UKqrr33gDn97rggMv8JGv5+JaahYaybA5g4eo8ieMxluAdNO6gY3Sehx9tB15B8B94TgPD7/8DIxCm+8Ff54nhs2Stkdq+m6KTvYC/pbTHf3r4LPRGm4tI/5Mfcuubci9rZQNmF/4fgCZFWLZu2xt07WfjYH6kcMo6yaeehayayTcLnlin1WcmzbppkVIOsauDwCtgkkSLvf79LzX8yvjhPhB58qZjSrxi7LOZnsKf0O3A2ZHyfEL86awT/+9oGDNPELoucMqKS5bu7UDQDURD4Vc7SZ+WeMHZZJKtai7MJh2ySumotQYnYkhes5DRXDRZEiZIzbrWq06/mqtN2J8WnfJfmR79L+2u/p/LKuy1K+Zk/pOWJWy36+Lk/QxAEik74Fpm96+l86w4qLvtzIYVKEHD2Hklm50pMQz8s9W09FbFoUjYneiZBascy5FBlQVc0VbeI4DGX4qweirN6aH4h0CLNxJe9DLIDPRUluekjSs/5SX5h1WJtJDd+RHLjh3lauLN2DO7BR+EeOAUpVxG3TTwb/8SzrYV17wYyu9eQqV9HdMHTRDFBlLCX9cNRNQh7+QDsFf2tZPRTEl1Bki3K+KfMO38ZsGhrDShtO1FylPZs63bQLeVaW3FvvCOPx9nHUknfX5QEckn19mU578uVoGvIwQoC0y7CM+K4AqE6tWMv8XXvWcri6RhyqBL30GMOPCddQ23fjZ6J51RZTZx9RhGacQXuQdOOyLfTNE3S25fj6DX0gO/fNA063/4LmCbFp9yUDzi63r8vv0BKniCGmqH9td8h2Jz5brSpa3S89gdMQ7fGG/brrqR3ryG25CW8o07Ec5Dr60wqDCz3Ffwtnj+hpmAkpAc96EEP3HaZEq+DtKZT7nYc1PdWFAV8Lhtuh0xG1ZFEAadNQpJE4hmNUp8jryZul0WqAi503UA3rLws6HKgqDrN8TS7OxIYuklGMfG63Qy/5DZW/fUG9r56J4NuvgNFMZAEcIyeQXzPejqXvYy/ZgS2AZPw9h1HZur5xBY/j7P3CLwjjs8V6NfRNedeHJWDsZXU4OwzCu/Y04iveB33wCk4e48suB5nn5FEFz1DetcqPEOO/szv6Ehigu5kNrVjOZLLb9lw1i0CU0eLtZFY+y6l5/4Me1lfimZ9k9CMq0nvWE5i49w8LVwOVuIechTuwUdjL++PaHfhHjQN96BpmKZpuW/sXkNmz1qSmz4iseYdACRfKY6qwdgrB1hxQXm/g1pbdkMQBCSX7zNHtr4M6MkwSutOlNYdZJu3ojTVWXZkgOj04qgZgW/i2bhqxyAXVRck61ZSXUeqbgGpukXo8XaLLj5gMp4RM3H1HZeP9QwlTXLzfEtZvLkOQbbjHnLsQTv3amcDiXWzSe9eg5mOIXmLCUy9AM+oEz5TrOyTSG9fDpBnou2PVN0ikps+InDU1/Mjk9nGLUTmPY570DS8oy2dgfjyV0hvX0po5nU4KgcBlud7csMHBKddjK92DKo1vYdkKCx97Jcg2el3zv+QVgWKfSIOUSTgsuOwizhsEk67jCQaeO0QdH8x/uQ9+HLRk2D/h2B8nxBPXTvlMztYX5/cm8EVvoL3ffLf+3/eS6saeGHFXjTdRJIEDMM8oIsN1tx0tn49ne/+FXvFQGyhSgCrOn3GrbQ99zO6Zv+N4tO+h+T0UnrWD2l58gd0vnUHpV/7GY6KAYSOu4rwBw/k6eOSJ0jxSd+h/ZXfEl34DMGcqmg33EOPzc+6uAdM/szvyF5ai+QvpfH+6yyac3EN/vFn5V9XWncgSDKOXpawSrcSZtnXfoYcqiKzdz3JdXNIblkAukrLE/+To7lPB0Scfcfhn3I+atsukpvnkapbQNe7f6Xrvb/hqB6Gq/8k3AMmIhdVWwtr/4m4+08ELFqRpcC9iWxTHYkNczFXvWWdmCjnKVi24mpsRb2Qg5XIgXJEd+CIqslHCqsSHkeLtKJFmlG7GtG6GlE66i1V0ByVTLA5sJf3xz/uDBzVQ3H0GnZQYTc10mJ18bcvI7N3veUN7ivBN/Y0PEOPLRB+0TMJUlvmk1j/PkpTHYhSwdz7/nNdaudeEuvmkNg41xKs8xbhn3q+VenO/RaPFEpTHWpnPUUTvnPAa/GVb5LZs5aik27IL9CJDXNJrn8f/9QLcdWOwTRNumb/HbVjL2UX/DJPWwx//CjZpi2UnPnDAnq6ngjT8cafsBXX5Dsqn8QpI6xr2f9PcHjVF2tn14Me9OC/AwG3ncN5OkiigMexL9wLuu0HDdKdNol+pT7CySxpzaB/mYc9nUmiSQ1MWNcYJZ7WkUwdR1Ev+p3xHba99GeWvPwwgWkXI5qgmFA641rie+toeusOqq68GylQRvHRl5Bt2ETX7L9jL++PvbSWktNvpfnRm2h/7fdUXP5nRJuT0IyryOxeRcdbd1J19V8LCreO6uFI3iIS62YfVoL9uWKCjnpCX/sZtqJeaIlOkus/ILFuNmpHPc0PfwfPiOMtS0lfMVKgjLJzf2qtZXWLSG2ZT2zpy8SWvGjNYA+chKv/JJy9RyLI9jzjzT/pHExDR2nbRbZhI9nGLSjNW/MdbrCEt2w54VNbUS/kUKUVF/hLvnQRVFNTLG2WSItFTe9qRO2sR2mvx0hF8u+TQ1WWtknVEBzVw7CV9jlgHttQMmT2rCW9w4oL9GQYJNkSJZt+Oe4Bk/NsBNM0yNSv3yccp2as9XLmdXhGzCwoKBhKmtSWBSTWzSHbuAlEyRrbGn1SQaJ+RNdtmiTWzbbE5fbTi4HcjPV792CvGJAfa9TTcdpfvx3JV5IvxGcaNhH+6FFcg6bim3AmYM1dd82+1/L0PuoitP0+t+v9h8i27aLqvJ+TkQPYdBXNkAi6bUzoW4TXYUPRDHTdIK1o2GWphw7+H4KeBPs/CPsLmh3J+z5tPwHyFHIRGNcnxLLd4QPfJ8mUnPkDmh+9iY7Xb7e8/HKdOVftGIvutfBpHNXDLeulykEUHX8tXXPuI7b4BQLTLsQ3/kwyezcQ/vhR7FVDcFYPxT1oGp4Rs4gueQFn37F50Q/A8lL2lRJb+hKu/pM+M9EUJJmSU79LtnkrRjqOq994kpvnIYeqcFQMQIu2YS/ri2lYj7dU3UJrESvqhWnmSvamSei4a4h89DCiw23NVc99EMlnLWqmrhE67mqC068gOP0K1LadpLYuJrVtCZGPHiby0cNIgXJcfcfirB2Ls/dIi37l8uEeMAn3AIvqZZqGlci27kRp22WpdjdvJbVlAfunV4JstwTOvEVInqAlcJYXOXMh2Jy5GWwZQZAsnqBpYpo6pq5hagqmmsVUUhiZJHomgZGO7RM5S3QWCKmAReuyFdfg6jPaUjcv72d12Q+yYBnZJJn6DWT2rCG9axVaVyNgUeD8E8/GPXAq9qpB+UXXULOkd64kuekjUjuWga5hK+lN6Lir8QyfWZC0G9kkyc3zSa5/n2zTFhBEXAMm4R11Iq5+4/9pT/HospcsWvnQQnsXpX034Y8eyS/W1rY9eRp48OivA9Z8dnLDXAJHXZy3A0ltXUR8+auWcN5+HWrT0Ol480+YSpqSi3590PnCgaUevj65N3/7cHt+XEMUIJw69KxdD3rQgx58ETBNk1haJa3qRDMqdlFEFaBX0EVVyElW09EMsNtMIimNWNLAPXgGRaNWsfejZ3D3Ho6tZhSaCqpkp+ysH9H46Hdpe+13VHz9DxiyLR9DtL/6OyovvxPZV0zJ6f9j0cdn30vxqTcj2p2UnPY/tDz1A7pm30vJGbfmz1EQJbxjTyU6/0mUtp35LuKh8M/GBGqHpZ7tn3oB4ffvx9l7JMkNc0mseccau7JbiWFo5rV4R52Ab/SJ6OkY6W1LSW1fSmLd+8RXvYUgOyxKeO1YnLWjsZX0QRAlHBUDcFQMgAlW0q+normYYCdq+27Ujvqc2Flm/6tC8gStmMBbhOgKILl8iE5v7pxcCLLDYgRK8n4WqwbomkUdV7MYagYjm7So5emYJX6aCKMnOjFShfPHgt2FrbgaV/8J1qx4WT/s5f0O6h/dXTjI7FlLZtdqMg0bQNcQ7C5cfcfjHjQlN4Ptyf/ulPbdOeHYj6x5absLz9Bj8Yw8AUevIQdXbt8yH1PNWDasM67EO+L4g456HQkye9aitO6g6KTvHNB573j7TmvG+rT/QZBki+X21h37aOBOL3oyQsdrv0cOlFNy6s3W3HU2Rfurv0dwuCg94/sWTR8r9k5vWUB89VsEJ55N2bCJCBIUux2UBtxUhxyYpoBpmiQVDZskUlPkIZnV0QyTnh72vz96Euz/T9Ft+9VNEQfQDZOB5T5W1ocP6pEtB8ooPvVm2l/+NeEPH6LohG/lXwtMu5Bs42a63v8H9tyi4R17GplGS0XRXjkIV9+xlJzyXZofu5mO135P5ZV/QfIELVXRho10vPHn3Iy2VaUUJBn/5HMJv/8PMrtX55OYz0I3JQfAVlSdpxTZKweS2PBBnooeX/kGnpEn5N4p4Og1xFK7rFtAxaV/xF7en8SmjwnPuRfT0NFj7SDZiK96C9PQcfefaFXiy/sTPOZStFib5XO5cyXJTR+TWPMu3T6X1qzVMBy9hiD5Sq35quIabMU1Bf6dpqZYAmeRZrRIq6XmmehET3ShtO/ByM1afZpw2SEhiFZi7rbmrezl/ZH6T0T2lyEHynIV8opDelkDaIkulMYtZBo2km3YiNK6E0wDQXbg6D0C39jTcPWfUED/NtQs6V2rLDG47UsxlTSiO2ipsQ6fadl67OeDfYA3ZnENwRlX4R0x859eQLuRbdlOeutiq+uy3yyfoWboeO0PiE5PfsbaWiB/h2B3WZRvUSLbvJWu9+/D2XccgWnW3LXa1UjHW3dhrxxI6LhrC44XXfy81RE/+aZDWsxta09yyl3zULvVgHM+1wcbB+lBD3rQgy8SiaxGR06VvC2eoX+pF0Uz8DhsjKouQtNMeoVcdETTNKQzlkaEJFB50vWkmrax48U/0ufKu8Ebwg4IoUpKTruZ9ld+S3juA5SfeD14iyg984e0PvsTOt++i5Kzf4wr9wyNLnoGR6+hVoG+15B9Iqp9x+IdcXz+PH3jTie29GUi85+i7Gs/O6xr+2djguSGuVRccjv28v6k96yj440/IocqURo2Aybhjx9H7WrEM/gobMXVeEedgHfUCRhqlmz9etI7Lf/r9E7Lik50B3BWD8dRPRxH9VBLzFWyIbkDuPqOxdV3bP58TdNET3SihZvRIi1osXb0eMe+uKBtF0Y6nvPVPnIINgei04/kCSD7inFUDkTylVgxQbACW7AK0RP8FMsxhWzLDoud17CRbMMmK0YBbDldHle/CThrhu8TK8sl1am6haS2LETtrAdBxFk71hI2GzS1IA5RIy0kN8wlufFDtEiz5T0+5Bi8o2bh6DXsC2H5maZJdP6TSN7igt8bQHz5q2R2raLoxOuxlViuIbElL5LesZzQrG/iqByEaei0v347RiZBxWW/QHR4rFnsd/6CFm6i/KLf5IVSAcxwM23v3I2zcjCVx12BS4bigIe+ZT6qi9xUhbysq+9kZ1scj0Mmo+mUB1w4ZAlbj8/1fwR6Euz/T9EtmtadXAtYcz3DqwI8/81pfO+5NezpSh2wn3vgFHwTziK+4jUcNSPyNK38PPaj1oJaeeVdSC4/xSfdiNq2m443/kjlFXciB8opPfs2Wp68lfbX/0D5hf9nqYqe+X1anvw+ne/8hdJzfpJ/YPpGn0x8+auEP3wYZ5/RR9y13N8vUXR6EZ0+mh+5CVtxDY6a4fhGn4ipa1YH2OYkdOzltL3865xyN6htO7FXDsY/+VxMJU1612qSdQvofOOPdEo2Sw18wGRcuUTVN/ZUfGNPxdQ1ss3byNSvJVu/nsT6OcRXvQmA5AlhrxxoJec5D2w5UIaQUwS3l/QusLP6JEzTtATOlJTVndYUy77M0MmrnImSdU2yHUF2HLZ9x/7H0OPtKG27rRns1h0ozdvQ4zkbOMmGo2qwZX/VZxSOqqGFVhvJCOmdK0ltX0Jm16qc2rcP9+Cj8Qw9FmefUfl7aZom2aY6q4K9ZR5G0vLG9IychXfE8Z/qJ/p5YJom4bkPIrr8+CedXfBa+P37UTvrKbvgV0ieYF6sxFogf43sLUJPRWl/5XdInqL83LWhZGh/5bcIkkzpWT8+YO46uuBpPMOPy/tgHgqbW+L5/59UG+KHpwztmbvuQQ968KVD1QxsooBLlmlLCESSKopu4LbL9CvxUO53IAo6r6zai00ERQXVAKfTxeALf8y6f3yPljf/SO1F/4fNJpFUwT1oGoFJ5xJd9jL2qiF4R8zE2XskoRlXEv7wYWJLXyIw5TwCR11EtmkLXe/fh728H47KQdaM9p61FrW2clDeDlFyeglMOY/IvMfJ7Fl3SEuvQ+GfjQkyu1ZiL+tHYMrXEBxesvXrSGyeR3T+E0TnP2G5iAycjHvAJKux0H8Crv6WRa4WayOzey2ZvevJ7N1Iausi60Qkm9UVrhiwzwO7uDei3YkgCMi+EmsE6RMz6fvD1FSrI61mrJhAU6yYoLsYL4i5rrYN0eZAsDmtuOAI6OZGNpXzw96F0rIDpdWyGe2295RDlbgHTcPReyTO3qMKfKRNQydTv57U9qWkty1FizQDAo6a4RSN+7YlHLsfg01PRkjVLSC56WOyjZut9/YeSeCoi3KCZa7DPu/DQWrLArJNWyg6+cYCLZdsUx3hjx/DNWgq3jGnAJDetdqy3xo6Hd+40wGIfPQo2fr1FJ92S55ZEV/+Kqm6hQRnXJXXE5CAIofGxtd/jyCKjLjkh3hDHnxuGzOHlKFqOuGUQnM4TVJRcTs03LKEwy5QFXTSK+jqSbD/QyDsrzD9VWHChAnmihUrvvLj9mAfujvYiravE2qa4LCJPHWtZSn0tXsXHXRfU1dpffrHKB17qLziroI502zzVlqe+gHOmpGUnf8LBFFC7Wqk+fHvYQtWUH7J7Yg2J4n1H9D59p34J55DaKalshxb/irhuQ8SOu4a/JPOyX9mqm4R7a/+ltDM6/BPPOuA8zlSqJEW9HiH5V+ZSZDasRxXnzFI3hCmptD6/P/m/A2Labz/G/mEWWndQfGpNyP7S8k2bia1dTHpbUvyQmm2sr64+o63qO69hhUkWaauWfNWTVvy/tdqZwPddHDB5sBWVI1c1AtbsNLqJgfKkHylFg3sIMrj/yxMTbWq4PEOtGib1TUPN6OGG1E7GzCVdP69cqgy73vpqBqMvXzAAdeXba4js2s16V0rUZq3AyaSt9gKNgZOtebQcp0CM6fumqpbQGrzfOs7lGTc/SfhGT7jAG/MLxKJjR/S+eafKTrpO/hyCyZYM9adb92Bf8r5hKZfAUB0yYtEPn6U0HFX4590Lqau0fr8z8g2bqHi0j/iqBiAaZp0vPEnUpvnUXbBLwuYFlq8k+ZHv4vk8udFAA8XAvDit6f1JNj/RhAEYaVpmhP+1efx34aemOBfD0UzaI6myaiaNcJkCDgkEc2EoMtGRcDJO+ubeHH5btbWR0hkwWUHnwNsksSOxe9R//pdFB99Ib2nX4ZmgixBJKPT/NxPyTbVUXHpn7CX97Oema//gVTdQsrO/wWuvuPQ0zGaH/0umFgFencALd5B8yM3IXlC+RltsFhRTQ9djyg7qLzqL//0TPI/GxPYghVosfb8uFh27wYwDUSXH2ffsbhqx+GsHVOQcIJlWZlt2oLSVEe2ZRtK646CdVcKlFuz40VV1vx1sNsDuwTR6fvC9VlM08BIx/fFBNFubZYm1M4GS2E7B9HhwV7RH3vlIByVgy123icYZlq0jfTu1WR2rbIEyLJJkGScvUdbHuEDphR0dfV0jNTWJaS2zCezZy2YBraSPniGz8AzbDqyv9B144uCkU1ZvyeXn8or7swX//V03PpNgsWudHrRoq00P3ozkreIisv+jGh3ktz0MR1v/BHfuNMoOuHbAGTq19H67E9xD5xCydk/RhAEbECxBxrevpfGJW8x6vJfUTl6GkUeGyUeF5P7l9LQlWRbRxxdhyq/k4EVHtpiGh6HzKjqAJP7l2CXexLsfyccKi74l3SwDeNz0Ft78IVifJ8QV06t5f75OwtEzbKqwcurGqgKHro6KEg2Ss76oaUU/spv8w8ZsKhYRbO+Rdd79xCZ/wSh6VdiK+pFyRm30v7ir+h896+UnH4r3pHHo7RsJbb8FewV/fEMm4FvwllkGjbmZrQH4aweDoBr0FSc/cYTmf8E7kFTD7A2OhS6BUv0VBRBFBFdAeRg+X7WF4Aoo8faaX32Nose7fIj2l04KgaQ3rMWOVhJaMZVAHS9/w+Upi3YghU4a0bgrBmBOfNa1I49pHcsJ71zJbHlrxBb+mLOnmsojpoROGuG5xahQhsuQ8mgtu9G6diDmhMVU5rqrDnsT1DARYcH0RNEcvmtqrvDg2B3WQFH9wz2ft1909CteStNwdCyVsc7N2+lp2IYqShGJk4hBCR/CbZQL7wjZ1kU9pw35yctUUxNJdOwKWfXtYFsw8a8XZe9ciCBo7+OK0ehL6B/N2wmvW0xqa2L0CItgICjzygC0y6yaGEHmekqPK6CFm1FT0YwDR3J5cNW0vuwAyw9FSU890HslYPwjjoxv11p322J1dWMIHjMpQCkd64k8vFjuIccg2+iVfAJf/hQrkr9PWt2jlyVevPHBI+9vCC5ttTEb8dUM5Re/LsjSq7BKr28tKqhJ8H+N0EqdSCjpwdfDHRdxzAMRLEncPxXwS6LeO0y21pjpLIqig59ij3ohkkkpeBzSuztShJNa8iygBuTwaUuIimNiKJSPHYW8fqNdC54juLeQ3D3m4gogF2UqD3nB2x7+BbaXvkNlVfcabHbTvkuaudeOl7/AxWX34ktVEnpOT+h9akf0P7q7yi/8NfIvhJLRPX5n9P1niWiKggCos1B0Qnfov3FXxJd8iLBoy4+7OvUEl1o4SZMXUO0u3IjUuX/dEwg+0vxTzgT/4Qz0TMJMjtX5qjhq0lt+hj4pD3XsJw7yNF5JqBpGpawWFsuLujci9bVSKJxU0HiDYAkI7lDSO79dVncli6LzW7ZcYnyPv8208Q0NExNxdSyGEomp82SQE/HLbuvVBQMreAwgt2NragXzprhlhBrSR/sZbVI/rJPzCmbqJEWiya+17LxtNZ4kLzFlpp6/4mWb/h+8YQW7yS9fSmpukVk6teBaSD5ivFNPAfviOMOOVK1/3H1RCd6vBNDSSNIMrbimgOsvD4NkXmPo8c7KT3rR/sx6ww63/wzeqKLiktuR3J6MZQMbS//BtM0KD3nNkS7E6V1J53v3I2jehihmdZomBZrp+O127GHqig99WYQBFyAzw3RDR/RuOQtRp9yKSeeeio2m0iJx4FmwIaGLtriWWpCHsr9LhKKSlIx8DplakvcmBhoutGTYP+bQNM04vFPxtH78C/pYEuSZCaTSZzOHg+3fxVW7glz4T8W5z2z94cswq/OGsltr6z/1M9I71pF2/M/xzN8Rn7h60bnu/eQWPsuJWf9KL94RBc9R2T+EwRnXElg8nlWJ/zZn6K0bKP8kj/gqBiAkU3S/NjNmEqGiivuyld8tWgbTQ9/B0flIMtD8DOozqm6RXTO+TtGjtZVcH2BcoueXTnIUsSuGIQg20htXYRgd+OoHITocBNf8y5atIXQ9Csx1CzxlW9gZJP57ubBYGRTFv0rRwNT23ZD3p6rL/bKwThy9HDLnuvAGpepq5aCZ7QNPd6BHu/c54OdiaGnE5jZJEa336WqcHCPNcGiiNscVjLucOc9L0V3EMkTRPYVI/lKrXkrf9lBu8bdauvdNPFscx1K64797LpqcPQehavPaBx9RiHtlyTrmSTZ+rWkti8ntX0ZZjpqUdVsTkuZvKQ3jt4jCebmmA+GZN1CUlsWoLTuyFtz7Q/R6SV03DWfSb/Od022LqbyyrvyC7eRSdD8+C2YSsaylfOGULsaaXn8e0j+Uiou/ROi3Uli3Rw63/kLvgln5T1a07vX0Pb8/xZUqbvR9cEDxFe8RskZ3y+Ys/80OGSR7H6skksm9+Y35xyaFtiDrwaGYVBbW8vevXt7OthfAgRBMB977DEuv/zyf/Wp/H8LwzCpa4mzvT2OTRLoimfRgaqgixKPnfKgiznrm3l1dSNJVaGlK4PfLVDu96AZBu3xLKqaYeM//odspI2h37wLvBUoqsXGiTXV0fT0D3FWj6Dsgl9a7LZICy2P5bqBl/4J0eHOM4z27wZGFj5DdMFTFJ3wrTwlF6D99T+QqltE5RV3fKbgmZ4M0/HmHWR2rz7gNcHhscS7KgZYhfHqYcjeoi8kJjBNA7VtN+nda6yRsYZ9ybLkK8nZcw3EXjEQe1m/g1pvmaaJkYpaHeVYey4m6LJiglTUSpCzSSthVjI5erh2wOcAIMrW3LXNkUvKPYgunyWW5g0ieYqQfSVI/lLkYPkhO+V6KorSsj1v15Vt2pKPt7rtupy9R+LsM8YqgueFygyyLTvI7FxBescylOZt1j657rcg27EVV+PoNfSQcYEWaye27GXLPrRjz4HFB6zxhOJTv3uAfegnkalfR+szt+EbfwZFs76Z3x5Z8BTRhc9QdOL1FmOhO37YsoCy836Oq/8E9GSE5sdvAcOwWBeeEJKapeHpH6F1NTD0ujtQvTXYJSjyypjhehbedSMlfYdy6vf/wvi+5UztV0JK1djblSGhZNnWksTnkCj3u+gVctG31ENbTEESweu0MaYmhCh+ec4yPTh83H777fzoRz8COGhc8C9JsAVBMO+//36uu+66r/zY/+1YuSf8mVZeAH/7cDt/eq/uUNbXfH1yb15YsRdV//TfR/fCF5r1Tfzjz8hvNzWV1md+jNK+y6KFlfUteECVnve/uPtPLHhAVVxxJ7K3CKV9Ny1P3Iq9tJbyi3+XT/ria96l6717PpMqrkVbabz/G9jL++EbfyayvxTT0DGSEdRwkzVD1LI9NwOENU/cawjO3qNw9hmNo3IQgiSjZxJ0vPo77FVDEGQbaudeS8G6z2hM08wvGMktC1A76nHWjsVRObAgadYzCbINmyyLrqY6lJZt+xYDScZW3Bt7abcVRzVyURVyoOLIu52Gbqmg52ewxSOeVzfUrCWiEmlG7WrIK6iqHXvyojCC7MBe0d+y5cgFIvtXik1DR2ndQWb3GtK7V5Nt2JSbzxIsEbTRJ1tdatPA1X8iWrSNjjf+SNEJ38Je3v+g5xX++DGSmz7CUTEQW2kf5FAVkidk3aNEmPiat8nWr6fyyr8c8jNgHzU8eOzleZsN09Bpe+lXZHavpfzi3+KsHmYl3E/cipGO5XUDMg2baH32toMEh7cgeUP54LAbyc3z6Hj9Dwcs2p+Fbx3bj4cX7c773T9z3ZSeDva/Ad5++21OO+00OMRC2oN/DqIomsOGDWPDhg3/6lP5r4OmG2Q1I++BfSgYhsnOjgQbGiKopoENEUXXqQy40U0YUukjHFd4fOlOYimV3R1Jgk6J1ngGr8NOSlUJue20NNSz6M7rkYPljLj2jwS8fsLpDBkVulbPoeWdv+CfcBahTxQpXf0nUHruTxEE0ZrPXvZyfozHNA3aX/o/0rtWUX7Rb/JuI3o6RvND30F0+ai4/M5PHaVqe/nXZHatJjD1AkvXQ7ZhZJNo0TbUzr2oOfXu7rVODlXlYgIrLpDcgcOOCQwlTfiDB3DWjsnv24199lyb8vTw7jEzsDyx7aV98radcqhq35p3BJRw0zStJLs7xhdES5vliD7DQI93oUWacjTxfTGBnujKv08OVeGoGmzFBL2GHmDXpSW6cqriq0jvXr2v8SHJeIbNwD/xHPRMHJT0YcUFWryDpge+lROZ7YetqBopUIZod2GqWbINm4gufRHvyFkUn3zjIa/PyCRoeuRGBEmm8sq/5uOu1LYltL/8azwjjqc4pwYeWfQs0flPfqJB9BOUlu35BpFpmnS+dQfJjR8y8KKfERoymawBdgn0bJLN/7gZU1U44Uf34wiWMrzCR+9iH363TEYzqQy4yKoaJR4HAbdMVciDxy4TTSmohknQZcP+KX/DPfjqoOs6NTU16LpOW1vbv0+CLYqiOWDAALZu3fqVH/u/GfvPVdtla5b6UMG51cFehHYItv4lk3uztyvFvG0dn3rMgoXv4t/hrB6Wf01LdNHy2M0gyhYtzB3AUDO0PvVD1HCTpdRdWovStpOWJ7+PraQP5Rf/DtHmILllAR2v/R7PyBPy/oKmadL+8q9J71pJ5WV/PmQy1T1LW3XNvXnFx4NBT0XJNmwis3cD2b0bLEVsTAS7y6KA9xmDHKok27wVM5PAN/HsfTSy/dD1/j+Ir3wzv6+jehjOmpGWQmjFJ2aVDR0t3GR1g1t3orRb9PD9Z5vAUhmV/WVIvmLLpssTRHIHLCqY02t1o+0uqzstWRRxRImc8RoYOqah5zvcRk4UzbLkyNHBclZdWrwDPdZu+VPuB8lbhK24N7ayWqv73t1135+KrmsobTtJb19GavtStK7GfIBiK+uLs3YsavseBNmGq98EfGNOxlCzBYFQ8+O3UHzqLYcUdjMN/VOLBXoqSsNfLzlgdn9/qOEmmh/9LvayvlbRJvd54Y8eJbb0xX1VakOn7YVfkKlfT/lFv8ZZMwIt2kbz47cgOtxUXHYHksuHkU3R8uT30ROdVFx+R4FqulUg+n/snXecFPX9/58zs73e3u7t3cHRexEQpAmK2CkqIoJijZpoEhN7wQKKoog1msRYYqw0FUREURFBUESaVOn14O529/a295n5/THLwgkIGo34/d3r8fCRsDcz+5nZu/28y+v9et2GwduK0ksnHJW+XmTW4bGbuKZfC0b3bnrMRbIG/O/Qr18/vvrqK2hIsH8RlJaWqj6fjyVLltCnT59fezn/Z6AoKntDSbKKgqqoNCqyHNFDNycrbKmOsHJPiGQmh0Wvw20z0MhlQRSgSbEFVVX5YrOfFTtqCaWzNHGa+WhdNZKoklOgrddKOKWwa9UXfPXCPZR0O4NeV44hnMphkERkWWbdzH9Su2w2ZUNuwZhXa46smE3dvBdw9B6B67SrDyp8fot35HjMzbrmC5+3oaRimsVXflQsuX0FvrfHYTtxCO6z/3jEZ7H76Yuxdjr9B49R5SyZmu2k9qwnnRciUzPaaIje2wJzs27o3BXkQtWomeQRY4J01RZqpt2nzRwD+pLmmJqeoHV1G3esN3cM++25tmmCor4dZP27yAb3gpwtHCPojFpX2e7JxwUuJEsRosWJZLIhHEwRL4yNHYkinkHNpvIxQRwlFUVJRJAToYJ9Zy4SIBfxH7IGvbsCfUkzDCUt0Je2xFja6pDRrlwkQLpyPYntK0jtXIWSjy1Es0MrOjQ/UWMSyDlMLbr/InGB/72JZGq20vj6lw9/vqoSmPUYiS1LKLtsEsZG7QDNkrP6zdvRF1dQOnqiFo9u+pLAe49i7TQQ95BbAaj96Fniaz8tMNQsQO2yGfjmv0L5gMsoG3gpag5SMhSbYf2bDxLZtoou107C1eIEGrtMuG0mKorNtC6z47YYqY2nCSeytPBYaeax0aTYitTQrT4uMXPmTIYPH86ECRO49957j58ZbI/Hw5YtW9i7dy+NGzc++gkNOCbsVwZX1AOz1EcK0DdVR4+YXEsi+KJpvt4RPPwBB0EQRDxDb6Pq9VsIvPdoPVq3zlZMyfD7qJl8N/6Zj1Ca9/8tGX4/1W/ciu/dh7RE2dsSz9Db8c98RLPuOP8OrO37k/VfQvirqRhKmuHoOQxBEHAP+itV//kr/vcmUn71M4el/+xXgFSOYlshWZxY2vbF0rYvoAlapHavIbXzW83Xedsy7ThbMaZmXUnvWQ+CiP57M+DFZ16Ps9+lpHat0c7ftYbQ9hX5N9FjLGudn8Fui6GstSZk5m6CteNphWso6QTZun3k6vZpXeQ8FSxXt4/0nnUF24ufE4LRis5WjGT3oG/ZDJ3Ti95Vjq6oHL274pBnq6oquXCNRgur2ky6ajOZqi0FexBBb9Ku5a7Afe5fkKwu5GSU4Kf/Qu9qpFmMZZL11D+zwb2IRtsPduyP1ok/8P6H716ouQyBWY8VlO73Xy+2bj6Rpe9g6zaoQAELznuR1M5VFJ/7V03wJp3A9+54VDmH96KxSGa75on5wZNka/fgHTm+XnItp2L4Z0xANFjwDLv7iMl1x3I7G6q02Z1QMkcoGWN3bZx/fL6VPi3d/Hlg6x+85wb871BXV8dXX33FiBEjeOedd37t5fyfhNfrxefz8dZbbzUk2D8jsopCMiuTyORIpHMIAjQttmo+upKIIKCNpKgqu2rjbPZFMetF0jkw6AUUVGpjGdx2I/5omkRGpnvzYtxWA0u3B/BHMzQqMuOy6KmJpdFJOvRShpY9BhAa/Ds2fPgf4l260nvQaPzRJNWRFG2H/J41/t3UzH2O8qJG6Cs6YO8+lGxgN5Gl72g2VyecScn5d1L9xh0E8jovencF3uH3UfXG7fjeHU/Z5Y8jGsyYW/bA0fNCIstmYmrSGWuHUw77LASd8ahWVkLeHcPYqB30Hq51m6u3ktq1muTOb4msnK2NRkk6jI07EN+wAFPTrhgbta23TxnL29Dkr5PJVG3ROre71xBb/QnRFbMB0BWVYWjULh8TaNRwc4vu9TU8FFkTGqvbp8UG4RrksI9cNEB25x7kRKig3v2zQdIhWV3obG4Mpa2wtOmDrqhMiwmKGyM5PIeM5ynpBKnda8lUbyFdtYX03o0HGgaCiGhxoi9tjaGkGe7BNyEIohZr7fwWvafZLxoXHKwG/n3EVs0hsWkxRQOuLiTXciKM/93xWpx64b2IeiPpqs3UfvAUxkbtcZ/7FwRBILz0XeJrP8V58iVYOw7ADAj7VuCb/x8cHfpT1O8SEklN0AwZts97i/DmZbS78C+4WnZCEGTCiTQGo47KuiQus4FuFcWkcjKJjMLeUAoQMOslUjkFo07CZdZjPEJxrAH/e0ybNg2Am2++mXvvvfewx/wqCbbb7cbv9zNhwgT++c9//hpL+D+JPi3d6CSxYL/19vI9dGrkpC6RKXjpfr29lmgyy9Tle458IRXmbaipZ+H1QzwH0WSj5MJ7qX5DS5LLRk8sdG2N5W1xD7qJwOzHqf34n7gH/RWdw5NPvMfgm/EQpZc8gqVtX4pOu5rQgv8QKirDNeAqnP1Hkw3spm7+v9G5GmFp3QvJ4sRzwV3UTBlDYM7TlFx4zyFf+MbG7QFIbltWEKI6FkhmO9Z2/bC26wccpIC581uSO1YSX/85oClqm5p2KVSkdXYPktmBtf0BsRI5Hsr7Qn5Het9GYt9+RHT5LO15Gsxa9bdAA9Po4YbSVkdcrypnNXGyVFTrQqfjqJkUSjYNsmbTpR602Qp5m64DlhxmRKM5b0viQLI4jrj5qKqKHK8jU7NdE17z7yC1e51Gqc/PXSPqMJS2xNb1bBAklGQEz1Ctsrv3pevJRWuRrC6iq+Zgad0LJB2ZfZtQMykwmFFzGQSdgcg3MzA26YTO4a1HsfsxSG7T1IeNjTsc9l5qP3meTM02Si4aW1AhTVWup3busxibdilQuKPLZxFbNQdHr+GaVYsi45/1GNmAZtu13yYmtPB1kluX4jrzeszNux14L0Um8P7j5CJ+Si99FJ2t+IhrPhxx6IVF20HVGg5eu5Fh3Rpz9+BD76kB/1s8/fTTAFx22WUNCfYvBKPRSOfOnXnppZd47LHHsFgsRz+pAUeFXhRJZ2ViySwGnUg0lWNHIIYkihh1IqIgkMopbPNF2VMbJ55V8NqMBMIZTHodXqOenKxg1AlIokAgmsZp1lPusjCka2NqoynWVEaoCSeRRIkuTZ0EIkmCsQy2UX9ADezgi7ee5cRuJ2J0tcFlUmnuNiNfdh9r/3UT/vcm0OSKp5CdXorPvJ5cXRW1c/+OzunF1LQLJSPGUv3GbfjeeZCyK55A725Cyfl34nvnQQKzH9csPUWJogFXkt77HbVzn8VQ0vywzDVjRQfNMvIonc+DIYhSIeF29h2Jkk1pAl75hDu86E3CvKmx3io6YWzaBVPTzpq4pyhhbNxe8/Q+eZTWHa/eRqpyA5l9G0nvXlsQPwMBXXFjbe3744LixuhdjdC7yjHT45C1qaqaZ6RFNFbafl2WbFpjrslZjcmW32wEQdDmr3X6vDaLCdFgLsxgS2YHgtF6xD1YyWqirNnaSrKB3aT9O8ns3YiSCBWOkRze/D0PQ5Vz5ELVuM/5M6DFBZma7RjLWv/icYGSSZLasw5rh8Nrn6T3biT42cuYW56Eo/dw7XnmsvhnTkCO12n7t8NDLuzD/+5DiNYibXxBZyCx+StCC17F0q4/zv6jAQj7d1EzbSK28pY0Pe9mVEEgB0RSkN6xhOovptBpwBAa9T2PnCqiNaUFTiizYzYaECUBBM1/PpmRyckqeknAYpDIKSqbaiIY9RKdGjnp3MhZeB4/t4p8A44NgUCAadOmMXDgwB/cq341my5FUVi7di3BYBC7/VBRhwb8NNwzcy1Tlu5GBUQBREFAUVV0ogCCQPYg7+vDQRS04P/g5NqoF7m6b3Nmr61ib92hYhL7sd9Oy3rCmbgH3VTvjz/0xRuEl0yj6LRrcOa/0DTazUQs7frhueBOQCD48T+IrZ5L8Tk35mlDKWomjyFbu4fS0RMLCWhk+SzqPnsJZ7/RFOW/5A5GzbT7yQZ20fj6l3+winmsUFWFrH+XVo3etboedUxXVKbRwfMzyXp3xSFJvyrntE2peitZ37YCDUzJU8hAq7Driso0wbH9Fl12tyacYS1CNDuQTPb/yr5KlbMFJXE5EUKJ15GL5q26Iv5C97yeaEiebqZ3N0FJRLB1Owd79/MR9dpzrVvwHyS7B3u3cxEkPYE5z6BzeinqP5rQF29gO3EwKDLhJdMwlLbC3OokdA4vqd1riCx7D+9FY3/6/SgyVa/cCJKe8qv/dsiGE105h+Cnz+PoOwrXqVcAWnW8+s07tLm9y59AMtvzv7uPYmnbF8+wuwGB4KfPE1v1YeF3ESC2dh61Hz6Drdu5FJ/953rvV7fgP0SWvlvv+CPBZpSIpY/efbjh1Jac1amMd1dWIgDDu1c00Mb/h8hkMpSVlWG329mxYweSJDVQxH8BnHTSSeqf/vQnrr32Wp544gluu+22X3tJ/2cQiKbxRVOYDSLhRA6HWYfNqCeSypKTVSwGieXbazEbJKoiKRQZimx6mrgsiKKAQSfiMOuxGHRUBhNYjTpsJh1lDhM6SSSWzLIzGGNjVZhURsWkg1Ayi4wAqQT3XjuMbCLCHf94m305KzqdRCCSQg7u4d3xv0eyl1B62SREowUlFcuP3gQLCXV670Zqpt6DwdsS7yUTEPXGwvf6wRoXuWiAqldvRjRaKL/yqUNoy/vnan+M6OTRICfCpHavzXep15ILVgJaEd3YqD3Gio5aXJAXSfs+ctHaglBYxredrH8nuVANB0dgkt2j2XM591t0uZGsxVpMYHFqibHB/JOTLVVVUDPJgruIJqhaVxgd06y6quvNXCOImiiZImMobYUc8eHoNRzrCWcW1vFrxgWR5e9T99mLlF72OKaK+kXqXLSW6tdvQZD0lF31TJ6VphL44AkSGxbiOf9OrB1Ozf8u3kkuVquNM3qakt63iZop96AvaU7ppY8g6o2ajtAbt0EuQ6urnyZn86DPx9GG6C42/vt2nOXNGf3gS+yuyxBJ5SgvMuO2GjihqRu3RetMNy0ys2Czn0gqR4lVT4dGDlRBpDaaYt3eCB0a25EV6N/SQ1aFdE7GZtRRbDUeceyjAb8M7r77bh577DE+/PBDBg0adESbrl8twb755pu54oorePTRR/ersDXgZ8D+OexsTkEQBGRF3S97BRzaiTbrRdJZLemWRIHr+rfg1SWayJIkClx8UpN6Qf3kpbuZtmw3oUSWXcFDbWtCi94k/NVUXGf8AcdJ5xdeV1WFwCzN97LkwnsKtOzw0hmEFryCo9dwXAOv0Wav3hlPaucqSi66H0urntos9xu3gZyj7Ion0DlLtc7kh08TXze/nlL5fuxXhiwacDXOPiN+pqd7AAWhkj3rtBnuyg0oyQigUa+N5W0xlLfBWNYGQ1lrJLvnkA1Qs5cIFqw4st+jh6sHJd8HQ9DlVcENJu3/6wxaRV6UtPanevAMdkaz6somUfPV7cNe02DWLEOcpVqS72qkVdDdFdR+8jzWtidj63IW8Q0LSFdtwdq+f6FjHFn2HtnaPQUxkdi6z0huX0HRKZcTmP0EOoeXbO1ucmEfek9T3INuQu9pQu2HzyBZijA26Uxq5yosbftiatrlR30OsTWfUPvRs4f9HUjuWo1v2v2YW/bQqs+ihBwPUf3mHSjpuEY7dJWTqtyAb9p96L0tKL1E2zTDS98htODVwkwg5H+npo3F1KQT3osfrCdmF9+wgMDsJ7B1G1So2P8YlNmNVEcP/WzKHEaC8QyZvNigoUH47H+KV155hWuvvZYnn3ySW2+9tcEH+xfCSSedpC5ZsoSysjKsVis7d+5ssOz6mSArKrXxtCacKAlE0zI6QaN/C2gK+av3RLAYtY52oyIzpU4TkWQOURAosRmoS2ZJ5xQcJh1uq+Gwn00okWG7L0oim6M6lEInidhMOlavWc991w3H6W3EXx9/nZRkxKyX8EeSfLFgPiteugdTixPxXjQWQZQ0n+HXb0PQGym//Akkm6tQADW36U3JsDEIolRwaXCdfh2OnsMASO1ZR83U+zA1PQHvxQ/U1wtRFape+QuqnKHRNf/8rwrVR0IuFszHBOtJV64n69+FRksS0Xua5uMCbVTMUNLssCNESiZFNlhJLlhJNriPXKiKXKiGXMSnJbnqYWb7BDE/f20uKIQj5q07vz+DnbfvVLMpzb4zkzziNSW7W4sJnKXoXOVaR91dgc7VmMCsiViOw7hAScfZ++L16N0VlI2eWP9n2TQ1U8aQDeym7IonCk4idQtfI/L12wUBVDWXpWb6/aT3bqR05HhMzbpooqZv3I6oN1J2xZNI1iKUbBrf1HvJ+HZQOvpRjOVtsQJmI0TCESrfvB01m2Lovf+mVfMm1MZy6PTgsZto4jLTqdxJMJnBazVRHU2RSMsYdCI5WaFfmxJqExk27QuzPRCnU+Mi4uksHcudOM0GqsIJRFHAYzPRqMj8gwKGDfj5kE6n8Xq9lJSUsGXLFgRBOL58sAEuueQSbr75ZqZOndqQYP+M6NHMxVvX9eHr7bW4LAbGf7C+kCwfroN9/9BOtCuz1xNVOqtT2RFFlkb3bsro3k0Pq0IuAM7+o8n4d1I3/2X07orCTJEgiLiH3EIu4iMw+4nCl5Gj14XkwjVEvpmB5CjB0eM8SvZTwGdNpPSSRzA2aof34gepefMOaqaPo+zySZqP5jk3kgvuo3bO0+gcJYU5GgBT0y6YW/cm/NVUrB1OQecs/VmfsyBKGMtaYyxrjaPnMG0+ObiX9L6N+f82E/n6ncLGJZodGLzN0Ze00Hwk9yuE2t3azPpBVOP9UDLJAxZdiXB9Klg6oSXN+QQaWUZV8x1RQdA8MCVJ+1+dEdFgQjRYEExWJJNNs+SwOLXuuN192Oo6aEUAQ0kzbRMGjI07kq2tJFOzrbCRGht3ILH5K9RcFkGnx1jRifDiyehdjTBVdEJX3BjnyaNIbPka0WTDUNKMbF0V8fULMJS3QU5GNAE1T7Mf9RnIiTB1C17F2LgDljytfz+ytXsIzHwEfXEFnvPuQBAllEwS37sPIseClF4yAb2rnExgN/53xyPZPXgvGqsJmmxYqFHA2p9C0YArC9fzz3xE82odNqZecp3et4nAh3/D2KTzj1IMPxilThPBxIFEej+aFluoiRxIvLM5ha+31zYk2P8jTJ06FZ1Oxw033PBrL+X/PPR6PX/84x+ZMGECy5Yto3fv3r/2kv5PQBIFvPYDs6y2TI5MTsFi0CEKEEvn6NdGTzSVw6ATaVxkRhRFXBYtARdFAYtRh6Lyg4JLWVkr6luNBiqKJUx6kUZFZjLZDvzxgWd45q7f897fxzHijicosuoJxjN07NGPqP8vbJ75N5Kfv0jjs24g5SxFHjGWmilj8L37IKWXPoql3cm4zvg9dZ+9SHDeCxSf9UdcA69Bjvipm/8yktWFteMATE06U3z2nwjOfbZw3AE6rYjr9GvxTR9L+OvpFPW/7Gd/1jpbMboOp2LtcCqgqVSn920ivXcj6apNJDZ/RWzNJ9rBog69p0meFt5ciwk8TZEcJYXY4vtQFRk5VncYi64YajrxPYp4DtT6FHFBMGoUcSlPETeaEQ2afado1kbHRKurwJo7EpVeVVX0x2lcEPridZREGNeIcd9bs0LtnKfJVG2h5MJ7Csl1dOUHRL5+G1vXc3H0uVgb9frgSdJ71uE573ZMzbogJ8L43h4HSg7vxRORrEXa9T58hvS+jXiGjcFY3la7bwMY9Sq+2ZNIh/2c8tdnwFrMvkiKEpuRZsVWiixGujcvxuswsas2TkZWSGQUUjkZr92E2ShR5jSjk0QMFSIWowEUlc6NHDgtRlJ5GrnTpAdBK6I14H+DuXPnEolEGD9+/FFZI79agq3T6Rg6dCivvfYaK1eupHv37kc/qQHHhB7NXIUA/ODkGQ7MYK+vijCoczmjezctnHO4848El8WAKMD+fEASBYotevyxDJ6ht1H95h34Zz1GeV6cBEDUG/FedD/Vb9yO753xGgWsqIziM/+AHKulbt6L2kbZvj/eix+g+s07tNmryx7D4GlKyUX3UzPtfnxvP6gJphnMlAy/T5vRevehwvX2o/is69n37z8TmPM0pZdM+NG2VT8GgiBo6pp5gRbQqqWZ6i3E139OYsvXZPy7SFVuBDlTOE+0FKEvboTkLNMExhzeAzQwu7twzV8SSjZNtq4KJRmpV6QAQFWQLK5Cd17MJ+W5g1TPDWWtEfRmkjtXYmmtBcU6VyOUdALX6dcWjkvv26gpmqcT6OxuKv7yJpLZ8ZPWrKoqwY//gZJOUHxOfaq2HKuj5u0HQNLjvXgcotGCKmfxz3yETPU2Sobfi7Fxe3IRP77pYxEkPd6R45EsTpI7vyUw52mMTTrjGXKrJsgSr8P39gMg6igZMa4e9TAXCeCfOQHJVnxI4v19aMEqyIdpFozq2bTwt7qlJsq3e0Kc26mMszqVcemLSwqJt14nFv6WG/DLYs+ePXz66acMHjy4YSb4f4TRo0czYcIEnn/++YYE+xeCxaDDctDUVFH+H97vfRUfnEwLgoD0A7GkoqioKmRkmXhawWHWUeowk1Nl9tQl8HTozVlX3swnrz6Fvaw551z5F3SCiqJCi37nkQjspXLRO5RXNKXjGSMJuzujv+Bu9rz7EP73HsV70VgcJ52PHAsSWfoOkqWIov6j8Zx3OzXT7icw52lEswNzixOxdz2bXN1eIkvfRef04ux9gMFmbtEdS8cBhJdMx9y694/SaPkpEE02zC17YG6pzVCrqko2uJfY6o9JfLcQJRmrp+8CmkirrqgcyVWOoagcKT8ypts/MmYvRufw/KLrVlVF892O1qLLj6nVP+D4jAuSO78lunIO9h7nYyxvU+9noQWvaqJmp11TYFDGv1tE8NMXMLfuTXFeXb7usxdJbFqMa+A1WDuehpJJ4XtnPLmwTyvM5+Ox0MLXSGxcRNFpvyvo9hgACZXdc/5FePu3nHrtfTRp3w2DCBkFdtcmaVJspciiJ6vI+KNpjDpNJ6GJy0JlKEGRVU+7Mgc6SaTMaabUYaJtmQNFVdFLIumcdp7FqMeok9Dn9RQa8L/Biy++CMCIEUdnxv5qFPHly5ezefNm2rVrx6WXXsrkyZP/5+towE/Dfhp6OqsgCnBGh1JOa+flvplr2Z87aDSvW/PWRk/W+8LMBvZQ/dYdiJaiQjdayabwTbufdPUWSi/OU3Lqqqh+6w4EUUfZ5ZPQObwkNi/B/96jmJp11ShleS/K6jfv1GZqL5uEZC0qvNf+mdkjzWofDUo2RTaQV+yUsyDqNIssswPJ6kI02X6wipXY+g3R5bNw9ruU5LblKNkUjpPOJ+PfTa6uklxwnzaXHditeVZ+D4LBjGiya5uuowTJ4kQoWHTtt+PQa/6WBxcQVEWjgsn7LTnSWnX7IKsuORlBToQPUNFFHU1vn3HI/Hhi01cktn2jKWiKEtFvP0JOhCk6+ZIDx2z+ivh3i9C7K0jtXouty1nYOp+BosioeSEWORlBSScKz1GyONG7mxyxe/5D2D+DV3Ta73D2vujA55VOaBSwYCWll2osCVWRCcx+gsTGRbgH/RVbl7ORE2Gq37pLm/W7bCIGb0vS1VupmTIGncNL2WWPIZpsKJkUNVPHkPXvpvTSR+oVIJRMiprJ9S3njgQRCroIykEaB83dFv5waqtCoetwWLGrrmEG+1fATTfdxLPPPsvSpUvp1asXQANF/BfC/pgAoH///nz55ZdEIpEGfZbfCALRFJF0DhHN7qvCZUEviazfF2L+dz5W7Aoiyyob3n6c1fNncckdj9G05xn4QilqExli6RwbJ0+gavVCul5+L+Z2/aiNKdSu/oTg3Gfz9ki3AAK1H/6N+Lp5uM68HkeP87RZ2cl3kwtVUzrqYYyN22sjae8/rn3nD70NW6eBhbXKyShV//kLgqQ/ohPJD0FVVY22HfZpWiWCoHWDTbYCK+yHCq3fjwnUXBpnP435lwvuJRusJFO1mUz11oLlZT2IkhYT6I3oHB7NqstozY+N7bfuzMcEko7CgKCqaEKo++OCbBolm0JJJ1DTceRU9ICFZyJcYN8dSVPkv4oLsmnkWLAg2rp/dE0wWtAXlSM5Sn70TLkcq6Pq1b8immyUXfU0ov4AayOybBZ181/CduIQis+6AUEQSO5Yie+d8RgbtcU78iFEvZHQ4rcIfznlwMiinMtbw66kZNjdWNqeDByk7XLiYIoOYknogfTq99kz90XanXUpXS/8M1aTCIpAKpNDRuWsDmU4rQZallixGg3oJcjKoBcFEDWVf7109IRZUVSyioJeFBEbrLz+J9i+fTutWrVi2LBhzJw5s/D6cUcRB2jbti0DBw5kypQp/P3vf6e4+Miquw04frDfDmx/ktC1SRF1iUw9urjOWUrJhfdRM3WMZtE18qHCzJPe00RTEZ92P/53H8I7Km/fddFYaibfpSmLX/ooxrLWlI4cT/XkMdRMvY+yyx7D0rYv7nP/Qu1Hf8M/exIlF9ytKYteNBbftPvwvT1OE5/Ib5rWzmeQ2rWa8JdTMJa3wdyq5zHdo5yMUvfZi8Q3LjqgnH0YCDojtq5nH5YarCoymarNGBq1w9SkMwgSsVVzEEQJa1vNiiYT2A0rZuPodSGmJp2pW/gauiKtmy3HguRCPpK7VmkK3opMLlSFnIxpAmuHm5s6EiQdot6MYLRoFHGTDYOjleavvZ8SZvdo89vf+642NG5PdNUcMlVbMDZuT3rfJozlbZHjITLVW9CXtEBXXIHOXUFq9zoEUUd8/QIiX79LLuJDzaaOvC5BxNKmD64zry/Yux0NqcrvCH72kqYAepDvtZrL4JvxEBn/TrzD79eS63ynW6s0X4Oty9ko6Ti+6WORIz68I8dj8LYkG9yL7+1xiCY73pEPIppsqHKOwKyJB7reByXXGo3sCTK+HXgvuv8Hk2ugUHhS8zRLVVXR60SeHNntqAnzsTBKGvDzIplM8vzzz9O5c+dCct2A/w3uuecehgwZwhNPPMGDDz74ay+nAceAjKxikkQkUSChgkEnoQJZWUUSVUrsJmxGkaa/v4dg1R7eeeZ+RtxXQqM2nQkns9iNOrpcehfZWB1rpjxGxcjx6Jt2wd71bJR4HaFFbyCaHbhOvw73oL+gpGPUzXsB0WjF1vl0vCPHU/PWXfjeeYDSSx/F4G2BZ8it+JJhaj98BtFkw5Lf+yWzHc/5d1AzeQyBD56iZPi9hxSVj4TY2nmEFr11wIbqsBCQbC4aXfvPQ4TWjhQTqJkE5qYnQNMTyAR2E82msXUbhKn5iYQWvore1QhdcWPkaC25cA3JbcvJhasRdUbkeLiQqP442y4BwbBfSdyGaLahc5YilbdFtLrQ2VxItmIMZW0Pe/YPxQXpqi1IDo/WiLC5SG5fAQhEln9AaMGryHlf7CNB5yrHNeB3WNqdfEx3ospZ/LMmoqQTeEc9VC+5jq2bT938l7C0PZniM/+AIAikKtfjnzEBvadJYTQssuw9wl9OwXrCmRSd9juNAv7R30huX07xOTcWkuvE5q8IznsBS+teFJ15fb1CQGTbMnwfv4y1TR/sJ18Kqky53UxGllm1O4yAjrWVIYptJro3c2E3GYilc5Q5DaRzCk6z/piSa9BGN4y/ICuzAYfiiSeeAGDs2GMT4PtVO9gA8+bN46yzzuKuu+5i4sSJRzmzAccDDhZS0+tE3rpOSxZHvbCE3PdmQfaLP1k7n4578C31voziGxcTmPUY5ta9NLstUSIXDVD95p2o2RRllz2G3t1EE6Gafj+6onJKL30EyewoqERaO56Ge8gtCKJEctsyfDMezs9sjy94KCrZFDVvHVuncT9qpo8jtXs19m6DMDU9AcnmRpD0mgp3JllQ28xFAxi8LbF1Pv2Qa8ipmJbYN+6AtX1/Mv5dxNfOw9S8W4EyFl31IdnaPRT1vwzRZKNuwauouTTFZ16PqsjE131GfNOXmk9myx6FJFyO1eEe9Bci38wkW7sHR88LEc0HNnNBkKAwg234r+jxqqoSWvgqmeptKNkUcrwOY2lrslE/cqi6QBMrvLfBcpAaeqlGMbMVI5qdmhBL/jnK8RDpyvVEV32I3t2EsiufOmrVer8AjmgwU3blU0hmrcO1nwKe3LYc99BbsXUaiKqq1H32ItEVs3H2HUXRqVdodK/pY0lXbcI7/D7MrXqSi/ipfutO1FyGsssmoS9urInoffQ34mvnUXzOn7F3G1RvHcH5LxNd9l6hi3I0CGjda71OZOzQTgXrvIbE+fjEk08+ye23385rr73GlVdeWXi9oYP9y+DgmEBRFFq0aEEkEsHn86HX//xiVA34eZFI56iJpkAVKLLocVk12vmeYJwvNvvZuC+CXlQw6PUEgwH+c/eVJGJRhoz5FzGjG6NeJJNV0WXifPLEjaQjAcovm4je2zL/Pf4S0RXv4+x/GUX9LtWKqe88QGr3OjwX3IW1XT9y4Rqq37oLVc5SdulE9J4mGqNp6j2a1eKIBzA1OyCWFVkxm7p5L9QTsvzBe9z6Df53x2Ns3BFr59PRFzfWxo9UFTWb0tw54iFNNyVWS3HeM/lg/OIxweCbiHwzg2xgD45eF+YT/HxMJogIgpifwdaDpP+vbJ6UTIrQwtfI+LajZtPIiTp0xRV55XFffUaeIB6IB5ylSI4SJFux1n032xB0RgDUdIJs7R5iaz4hU7OdsssnHdZ+82Boe/WzxNd+eohCfGLzV/jfm4ipaWe8Ix5A0Bk0JfBp9yHZ3JSN1uapo9/OJfjx37G0PRnPBXeBIBZih6JTrsB58iggL6I37X4MpS0pvWQCot6EiFZAz/i2U/3WXRhcjWhx+WPYHBbaeW00Kbawqy5JTTiNXoQKt4WeLTz0a1OCxaDDYdLhsR/Z77sBxwfi8Th2u50TTzyRFStW1PvZcdnBBjjjjDNo1aoVkydPbkiwfyM4WEjt4CRh/AWd69HEAawdTyNbV0V48VvonGX1aNrW9v2R43XUzXuB4Mf/oPjcv6Czeygd9RDVb92lda0vn4SpoiMlw+/H986D+KaPpfSSCThOOh81myL0xesg6nAP/ivmVj3xDL2dwOzH8c94iJJ8ZVLUmw7Mar/9IGWXP/6DM0xqLktqxwocfS7GNeCqn/6gcllUOYtkydPjVRkl7/O4H9ngXkSzo1DpVnNphHz1Nbb6Y5RkFEvbk8mFqgqKo0oyqvlYixI6VzlZ/06NLmb3/CjPSFXOoiRjyElNQE2Oh/L/BQtBQi4SQI4GDqGqpbJp9K5yDC17oCsqLyiM6orKEM2OY16DpU1vJFsxdfNfRo4Ff7CLLScj1EwfB3IW74hHDyTXef/p5LZlFJ/9p0JyHfr8FaIrZmM/6QKcp1yOkk3jnzFeEyU5/07MrXoix+uomXYfSipO6aWPoC9uDGjzVfG183D2u/SQ5Dq6cg7RZe9h73HeMSXXANef2hK7Wd+QVP9GMHnyZNxuN5dccsnRD27AzwpRFLnlllu45ZZbmDt3Luedd2x/Yw349WAx6miis6BCvQ5chcvCmR29tPBY2eaL4LGZEJoU4Zj0Eo/8cSRzn76dE65/kpzeilECWRZpfPGD7MwLmpZe/jj6ojJcZ1yHko4RXvwWotGC46QLtJhg+lgC7z+OcKEeS+telI56mOopd1Mz7V5KL30UfXFjTSB1yhh8747HO/JBTBWdALB3H0o2sJvI0nfQOb3YTxz8g/eY3PI1oslO6aWP/CAF/AfxS8cEgoCuqIysb4c2VmZ3/7iYQFVQ0wlNOC0RRk6GtZggFjzwX97S8/uFdSQDoqEOvbsJ5lY9tXjAVa7FB07vMRf5Tc26YO18Bnv+NorEtmVHTbDDX00lvvZTnCdfUj+53rYM/6xJGMrbUDL8fi25rt6Kb/pYJIuT0lEPI1mLiK2bT/Djf2Bq2aMgiFr3xRta7NBzGI6+IwHI+Hbge/chdM7SfNdb+0x0gBgLsO+dB5GMVlpdMhaT3YxJLyKIAmaznlh1FJsRRFFPOqvSzmunebEVSRIx6Rvmp38LmDZtGqqqcscddxzzOb96gi0IAhdddBGTJk1i4cKFDBjw8/gTNuCXxeFoq6N7a0JNj330Hd/sPEABcp58CblQNeEvJ6NzegsiYIA2Q5UIE/5qKqLZjuu036EvbkzpqIeomTKGmqn3Ujp6Iubm3Si58B78MyZQM20spaMe0uwU5BzhLyeDAO5z/4K1wymocobaOc/gn/EwJcPvQ9Qb0TlK8I54gOrJd+Gbfj+loyciWZyHvzlJQtCbUFLR/+4h6fQoqVhhE5STUQRBQDQfmCtUUrFCUgeaMra5RWuygT3kIn4cvS9CDvvI1e1DzK/X1uUsgvNeZM/fLkHvboKz32j0rnKAehtpeu93hJdMz9typAszV5raaOKIll2IOiSbC53NjaG0FbrWvQ8IreQr0KLB/N89m4OgJKOA8IPXVNIJfG8/SC5cQ+moh9C7mwCav3hg9hMkNn+F6/TfYz9xsJZcL/gPkWUzsXcfiuv060DO4p/xMKlda3EPvVUr7iSj1Ey7HzkawDtyfEHwJrx0BpGl72DrNghnv/pz+4mtSwnOewFz617adY8BFUUm7h78w0FCA44frF+/npUrV3L11VdjMBiOfkIDfnaMGDGCW265hX//+98NCfZvBLrDUFsFQaDUYcFrN6MXBPZFUyTSOSqat2LknU/xxgM3sP6N8TS/dDwZnYFUBhRHyQHK97T7NV0Vmwv3oJtQMynqPnsJQW/C3vUcvBc/QM3U+/C/9wjeC+/D3OokSi95RIsdpoyh9JJH0LsrKB01geopY/C9/QDeix/EVNERQRAoPusG5GiA4CfPI5psBfXvw0E02bT9U87+9AT7V44JAAJznkaOh1BlTZtFzaY0fZZ0QtNIOcTMdf/92zXbLrsHQ1mbfDzgRXKUoisq1WbP/4uO+MFQMglQ5KPqs0RXziG8+C2snc/AeZAqfHLbMvwzJ2DwNqf04gcRDWYtuZ52H4LRSuklj6BzeIhvWEjth89ganYCJcPuQdDpCX/9NpEl07B1ORvXwGsRBIFsqBrfdC2pLs0Lou6HXkmwZ8Z4yCYZds/zJKyNyanQttyOLKuUOsw0KzaTyal47GY6lDvo09qDyfCrp18N+BH497//jSRJDB78w4W4g3FcfMI33ngjkyZN4rHHHmtIsH9DWLGr7pAudo9mLga087JsZ13ha1oQBNzn3ogcraV27nNItuKCfReAs/9lyMkIkaXvIprsOPuMwOBtoVWep91HzdR7Kbt0IpZWPSkZdjf+9x6lZvr9lI4cX+iIh7+cDIqMe/DN2DqfAapK7Yd/w//ug5QMH4toMGEobYn3orH43h6Hb/pYvKMeLnRBD4YgiJhb9ya+YSFFp1xx5ET8KJBMNnJ1+1BzWQASGxehc5bVo6jvnxXfb2WRC+5F33MY0TUfk9m7kZjRQi7sQ06EtIpxURnRlR9g7zEUc6ue1M1/mWy4BtP3qtSqqqJkM+SitRolTGdEslnR6bWZKyFv2aUzOxDNdk0FNC/QIprtxzyT9t9CyaSIrf0UU/NuR9xMlUwK37vjyVRr9hqmJp21e5SzmpDN5q9wDbwGR88LNDrh/JeJLp+lJddnXo+ay+CfOYHUzm81kbNOA1FSMXzT7ycb3It3xLhCVyO6+hNCC17B0v6UghjKfqT3bSIwaxKG0lZ4zrvzmCvyfxrY5ugHNeC4waRJkwC4+eabf92F/H+MiooKBg8ezKxZs9i9ezdNmx5ZBLABxwcyOQV/NKMxXkkAAMd5SURBVIWsgMdmwGI8EF4qKhgNEnpByAuhqfTs24/K68fy2T/vZ+/spym/4A5SiKiAwdMU74hx2v6/vyBusuE97w6qZz5McO7fESQdts5n4B31EL6p9+KbOUEb+2nZQ0uyp95LzZQxeC+ZgMHTlNJLJlAz9R5t7x8xDlPTExBECc8Fd+GbPo7AB08iSPqCwvT3YW7Tl8g3M4it/gRHzwt+0jP6X8UEuYgPRVEO8StXVZVctBYlFdPYfUYrgt2tzWEbLAh5yy7J4kS0OPIxQRGSxfWL+IYfCdGVHwBgad3niMfE1nxK8NPnMbfupQmt5ffqxJal+N97FIO3uabvY7KR3rdJcwwxWikb/Sg6p5f4hoUEPngSY0VHLUbMz2GHFr6GpeOAgjtJLlqLb+q9qHKO0sseQ+f0FtbgErLUzH6MuG8XNzz8AmedOYCVe8KUO3ToJInvqmIEomk6NHLhtRlo5rHRqZGzIbn+jWHVqlV89dVXXHXVVTgcx65wf1xwE5o0acIFF1zARx99xL59+37t5TTgGLB/DvvJTzZx2ctfs2LXgY51n5buQ/wyBUlPyYX3oPc0xf/eo6Srtx74Wb6SbOkwgNDCV4munAOQn6V+ADlaS83Ue5DjISxt+lBywV1kqrdSM+1+lFSMov6jcZ5yOfH1nxOY/QSqnMV2wpm4h9xCavc6fNPHouSVsk1NT6DkwnvJBHbhm3Yf8vdpTnk4Tx6FmstQN//l/+o5FfW/jNCiN6iZdj+5sA9rx9NI79tENrgXAHPrXqT2rEeVs2T8uzQRMqsLW+cz0HuaEln6LrG180jt/JbUzm9JblsGog5TRUdEnZFccC+hz1+h6rWbyUV8AGRD1ex+8kJCC15BEAQMpa0pu/QRvCPGUXLBXbgH/ZVcsJLkxsXYuw/B2uFUzM26YihprqmU/4+Sa4DQojeQY0Gc/S497M/3e1enKzfgGXobljbahqvRvSdoyfUZv8fRaziqqhD85B9act3j/HxyndY61ztW4R70F2xdzkJJx6mZPo6MbyclF96DOe9BHv/uC4Jzn8PUogeeobfWS6Czwb343nkQyebCO2JsYb7/SHCadZQ5jNxwassfVAhvwPGFYDDI66+/zsknn0zXrl1/7eX8f40xY8YA8K9//etXXkkDjgW18TSyoqKToCqcIp094MEsiQIldiMyKgpgNeood1kYdfFITrnsZgLrFlE172WEgzSBjI3b4x1+H9lgJb7p47Tuqk5P+bB7MDXroqmJb1igJd6jHsbgaYpvxkMkti3DUNKM0ksfAaBmyhgyNdvR2d2UXvooOkcJvrcfyAtvgag34R0xDkNpa/yzHiOx+avD3p+xcXtNdGzxW4W99qfgv4kJLO36ka7aQmztPJJblhJa9Ga9mEAy2TC36klkyXT2vfh7qt647ZC4QE1FEVDRlzSndNRDeC+8F8+QWyk+6wayvu3E132G7YQzsLTqibG8LTqH93+aXGf8u4h8MxNLxwFHtCiNrfmE2o+exdT8REouuLvAKIhvWIj/vUe0Zsqoh5FMtvzc9H2ay8zoieicpcTWf15Irr0jHkA0mLSZ/Pkva3PYQ7T9X06E8U27HzkZwTvyQQyeA3u5RVCo/PBZqjau4PTr7sXRshtV4RRlTgPBhMzG6ihFFj1mg47mJVYqPDbcdhOhVJZ07sDfRgOOfzz++OMA3HXXXT/qvOMiwQZ4/vnnWbBgwY+qDjTg18N+JXFFhUxW4Zl5mwtJdo9mLq7r36KeGLUAdG5eSunFDyCa7PjefoBs3YFiiiCIeIbcgrl1L4KfPk9s7TwATBWd8I4YRy7io2bKPcjxOixtT6ZkmLZp1ky9t2AN4Rp4DYmNi/DPmICSTWPrfDqe8+8kXbWZ6sljCsqV5pY98F54H5nAbmomj0GOHapoafA0xdl3FPH1nxP9du5Pfk6mlidpIlknnU/xGb9H5/CQi/jJRQOoqoK5eTcks42q128h8P5jOE++RPO/9jQltXst5b97lqIBVyHoTVja9kXvbUGmeguIOmJrPkFVFWydT8fR4zzqPv9P/lkK6IsrKL/6b5Rf/TfceX/H/YhvXKxR0n4mOtdPRWLzV/lO8xBMFR0P+bmcimnWbXvW4xl6a2G+SknH8b09juT2FRSfcyOOky5AlXPUfvAUsW/n4ugzAtcZv0fNJPG9/YDWuR58U0FBvGbaWDI1WzXbjbyybGLL0gPV7AvHFCh8ALlYkJrpmmqk9+IHD/UEPdy9ZWR80TSvfLWTe2aurVeAasDxi2eeeQaA++6779ddSAPo378/Cxcu5Oqrr/61l9KAY4QgQDonsy+UpLIuQSB2YAzJpJNwmPQUWw2c2NRF+zIHo/u0YOY/H+GMEVcSWPo+8WXv1ruepXk3Ss6/i0z1FnzvjkfOphD1RkqG34+xSScCHzylJdlmu5ZklzTXCq+bvtK61qMnIkgGaqaMIVX5HTpbMaWjJ6Irbozv3YeIf/cFAKLRQumo8RjKWuF/byLxDQsOc28Cxef8GVDxv/fY4S20jgE/NSYweFtgP+l8sr4dOPtejLP/5VpMozcWYgKA+Lr5CHoTjX73HI6Thv2m4gIlHcf/3qOIRivFRxjBiqyYnU+uu1Ey/L7C/Hp01YcEZj+BsXEHbb7aZCO5fQW+6eOQbG7tc3d6ia7+hNoPnioInxWS63kvYG7TB8/52hy2xnIbSy5cjfcizZVkP0pMKqmv3yCw5nM6Db0WS8fTWLsvxJfb/Cz4rprKujjpnIJJryOSyrKnNo5FLyGJAlv8MXbX1v/baMDxC5/Px5QpUxg4cCAdOvy4Ub/jJsEuLy9nwIAB2Gy2ox/cgF8dfVq6MejEgoLil1sDhU72il11vLpkZ+FYATDqRR4adgKPXXEapSPHg6rgm3Y/uVjwwHGSjpIL7sbU/ERNFXLDQkDrOmtJdg01U+4hFwtiadMH7/B8kjxlDLloLY5ewyk+58/5L9WxKKkY1vb98V50P7ngXqrfvJNsXRUA5lYnadcM11D91h31kv39cJ48ClOL7gQ/fZ7ktuU/6TkJgoChpDmWVj0Lc8PW9v0xN+ta6BQXn/VHGl33PGWXTSq8nt63CZ2rHMnuQWf3YGrSicTmJejsHqwdTsU3fSx1X7wGqortxMFYOw0ktWv1gTfOV0e/XyVV0omCqvaviVTldwRmP4mhvB2ugdce8vNctJaayXeTrt6qqcR2PE17PRak+q27SO/9Ds95t2Pvdi5KNoV/5gTiGxZQdOqVuAZcnU+k79c63+fdju2EM1FSMWqm3Z9PrscUuuHJ7Svwz3oUQ2krvCPG1bP42L/JKokw3osfrDcb90PIyqpWfMopTFm6+xCWRwOOT/z73/+moqKCQYMGHf3gBvziOPXUU2nb9vA2QQ04vuC2GhEFgbp4Dq/diM2kJ5rKkZUVsrJCXSJLyxIbdpOBeDpHoyIzVpMet8PMNbeMpVH30/F9/iqJNZ8UZhcVwNK2L56ht5Hes57aGRNQchlEgwnvReMKSXZs/edIZrvmg13WGv+sicTWzUdf3Jiyyx5DtDjwTb+P5PYVSBYnZZc+grFRWwLvP05kxWxAo2aXjnxIu+bsJwuvHwx9URmeIbeQqdqkseV+lC2Whp8aE6iqQnrvRk1Q1N0UnbMES/v+ZPZtLsQEVW/cRrpyA46TRyEaLVjanfybiQu0Ua5HyIWqKLngzkMK2aqqElr8ViER3m+vpb0+meAn/9TiuosfRDRaiG9YqImSFTfWOtd2D5HlswjOfRZTi+6UXDROS66Xzypcs+SCuxAkvaY+P30cGf8uSi68F1NTTX1eQvO63vflu+xY8Dbdzr6Y0lNGEY7nqAql2FITZUcgyaaqKDuDCXzRNCVWI61LHdQls1RHktgMEg6Tjmg6R07+EVarDfhV8PLLGov13nvv/dHnHjcJdgN+W9ivJN6vjQdR0GassjmFr7fX1vPJFoH+bTy8dV2fwpy23l2Bd8Q4jX4zfSxyKla4rqAzaJ7DFR0JfPAk8Y2LATA17YL34gfJRfzUTL6bXCSgCZpc/CC5sI+ayXeRDVVj7zYIz/l3kN63ierJd5OLBvIzWRNQUjGq37yd9L5NAJibd9NeTyeofvOOwuuFtYgSJRfcrVXF33uE5I5VP+szVBW5QFEXBLGeZ6YcC6JzlCAIorb5tu5d6MA7+46k7PJJSBYXrjP+gMHTFEGUEE3WwrPMRXxUvXoTNVPvIbVnXeG6oUVv4ugzop5q6ZGQ2rOObGDPTwoifgjpfZvwvfMAkr0Y70X3H7KWjH8X1W/eTi5cg3fEOKzt+mmvB3ZT/cZt5ELV2usdByAnI/im3kdy23KKz/4Tzr4jtY7z5LvJ+LZRMmyMdlwiTM3Ue8nUbKdk2D0Hkusdq/DNeBiDpxnekePrzYErWW32O1tbScmF92Is/2mz1CoH/jYacPxixYoV7Nu3j9/97ne/9lIa0IDfHAw6kcYuCy1LrEiSQCorIwoCoiAgALKqkskpePKzqE6zHllRCcbTGI16Rt36CJ62PfDP/TuRzUsALaEBcHYcQNngm0jsXEX1zEdQc9lCkm1q2pnaD54iuvoTRJMN76iHtNfmPEVkxWx0Ti9ll01C52qM793xxNbN144b+RDmNr2pm/cCdQv+g6oqiEYL3hEPHHj981dQ1fpJkKXtybhOv47E5q+onfP0z74/yokwqqocEhMIgogcr0PnKMHavj+2TgPRObzIsdpCTOAZejuixVHQKfm544JsXRXpvRvrxWw/B9RcBv97j5LatRr3oJsKCW3h53KO4NznCj7VJcPGIOj0B70+GWvnMyi58F5tlnr5LAKzH8fYuD1lox9FtDgJLZ5M3WcvYW7bF29e/Da89B3qPtM8sg9Orn1vjyuw3Pbbp+kAswiJtZ+wb96rNDrxNLxnXksonkFVsmRljc2JKGDQixgEAa9dT3OPlWQ6i0UvUGo34TAZSGUV9KJ4yChlA44/TJkyheLiYs4444wffe5xn2CnUikymQzbt28nmUz+2stpwEHo0czFzWe2LXxJiKJAn5Zu+rR0oxO1TVUnCQzqXM6MlZX84fXlPPvZZkCbry658F6ytZX4334AJZMqXHf/TJSxcXsC7086kGQ36Yx35EPI8RDV+YTa1KxLPnmOU/PmHWR827F2OBXvxQ9o3ek37iDj36l90V7+OKLeRM2Ue4hv+rKwjrLLJuVfH1N4r8JajBa8I8ejczXC9+6DhfP+G6Srt+Kf+Qh7/nYJtR88dfiDjkLTEo1WBElXT3Bjv6qcZHNT8adXKb/6bxSffh2B2U+iZJJkaraRC1djadVTCwqOMgPkf/ch9v37j+x5eiS+dx6styH/VCR3fqt5UJodlF7yCJK1qP7Pd6yk+s07QM5ReumjmFucWDiv8ProiZhbdCcbqqb6zTtJ12zDM+xu7CcOJlu3j5o378gn4Q9gadtXS7injCFbuwfv8PuwtOmdf69V+Gc8pFm5jHoI6aBgRpWzmlZApdYp37+OH4NezV0YJAEp73/dp+WRLcga8Ovj4YcfBuCqq/4La74G/OJoiAmObzjNelBV9gbjiIKKJAqoQE5W8EXTRNNZkqksaypDLNkWYF9dgnAii16vp8fvxmNp1Ab/+4+R3PktMlqSLQFNep5JyTk3kty+HP97j6DmO9klF43D1OJEgnOfJbJiNqLBnE+S+1A37wVCi95CtBRRNnoipiadqJ3zFOEl07Vi/rAx2E4cTGTpuwRmTULZT0EfNgbbiUOIfDODwHsT68UnAI6ewyg69UriGxbgf+9RlCM5chwjlEyS0OK32Pvi76l87jLkiP/wB/5AXCAarehd5Qji9wS0fsa4ILZ2HtVv3k7l3y5h38t/Ivz1Oz+ZKr8fSjpOzdsPaDab59yIrfPp9X4up2L43n6A2JpPcPYdhXvQTQco3O88SGzNJzj6jsI9+GYQRerm/7uQSJeOHI+gN1E371/5JPxMSi64GyQdoUVvEVrwKpYOp+I5/86DkusHSO/bhOf8O7G06YNdB+2KJRwGqPtuMfs++jv2lt0pHXIzdUkZGYjnVMqcehq5LJj1IgICFr1AsdVAOJmlxGmm2GrGbtJT6jDiMOspc5p+NtX1Bvwy+Prrr1m3bh2XX375Tzr/uE2wY7EYH330EY8//jgjRoxg8ODBXHfddXz33Xe/9tIacBA2VUfJytoXclZW2VSdt7bKf3GowNj31/HW0t18sqGG6siBjcjc4kSt21y1Gf+Mh+t9Ue/fJAtJdn5WylTRgdJLJqCmE9S8dScZ/y6MjdpRetljIEpUv3U3qV1rMDfvRtnoiaDKVL95J8md36J3V1B2xZPovc0JvPco4a80Xzu9u4KyK5/CUNqKwKyJhBa9Wa9qLVmclI6eiLGsDYH3JhJeOuMnC1SkqzZT/cbtpPasw9rxNGxdzj7scTqbm1wkUPi3HPEjfc8jWrJ7ChuxqsgomQSSyYYg6QpKpIbSVuhcZWSDe0nv20SmZjt7/3UtNVPu0ZLRqfccca0lIx7APeRWbF3PJlOzjZrJY0juWPmT7hs0YRLf2+PQOUu1eShHSeFnqqoSWfYevrcf0DoOVz6Jsaw1qqoSXTkH3/Sx6Oweyq7QXk/v20T1G7ejJEKUjnoIa178pfrNO1EySUovmYC5eTdy4Rpq3rqLXNhHyYhxmFudBORp4TMeQudqROklE5DMB7QfVEUmMPtJUttXUHzOn7G27/+j71UUYEA7L1P+0Jdbz25Xj8HRgOMPlZWVvPfeewwePJhWrVr92stpwGHQEBP8NhBL59gbSiHpRLb54gRjGbKygtEg0arEhtWgZ1sgRm0sw05/nJ21CaKJDAJQ6nHQ7XcPY/ZU4J/xMOm936EHZECnFyjudi7l59xIctsyfDMnaEm23oh3+P0HEuqvpoKk19hLJ5xJ+KspBD/+B4LeiHfEg1g6DiD0xesE5z4HqkrxWX/U9Fs2fUnN5DHkogEEUaL4rBtwDbyWxOYlVL91J7lwfWEzZ9+RuM68nuSWpfim3ltgmP1YqKqK7+1xhL+cgq6onKLTrin4Xn8fv3ZcYO92DiUX3U/RqVciWhyEFr6Kf/bjP+m+AXJhn1Ykr1yPe+ht2LudW+/n2cAeql+/ldSedbgH30zRqVdotll1+6h643ZSu9fgHvRXXKdegZpLE5j1WN6ec4iWSAOB9x8nunIOjl7DcQ++CQSBus9eIvyV1g33DL0NQdIVOtfpfRvxnH8n1nb9sKDt5f6kTHTHcqpmP4GpUTtKLriHaFZHJqdi0UOx1cTpbUvp3dJNxzILdrNEsd2EoooUmQ00L7Zi0AnkFBWH2YDLaqjnF9+A4xMPPfQQALfffvtPOv9n+YQFQThXEIRNgiBsFQTh7v/2erFYjIcffpipU6ciSRJ33XUXGzdupFOnToUbbsDxgY/WVR3y76+315KTlXzVWiUnHzkZtbbrh3vQTaR2fYt/1kRUOVf42X66lrFxBwKznyC2bj4AxvI2lI6eCEDN5LtI792IwdOUsssfR2f3UDN9LPENCzCUtqLsiic01dDpY4mu+hDJWkTZpY9i7XgaoUVvEJj1GEompSXRlzyS35Cn4n9nPHLygA/2fqVSS7uTCS14hdoPnjykqn0sSG5ZCkqORtc9j/ucP2Npd/JhjzOUtyFXt49c2IcqZ4lvXISlde96x1ja9C6IwSU2LsbUTFM9lhNhVEVTqczW7SNXV43eVY79xMFU/OlVGt/wb8oun4Te1YjSSx454lpNFR2wdT6d4jOvp9HvX0TQG4+osPpDUOUswU//pQmTNO1C2WWPoTsoKFAySQKzn6Bu/suY2/TWPkeHFzWXJfjxPzQrjpY9tNfz9ho1U8YgGkyUXf44piadSWxbRs2UMQh6I2WXTcLYqB0Z/06q37wDJRWldNTDmPPPJ7H1G3wztNms0ksm1LNhU1WF2o+eJbFpMa6B1x6y4R8NAiAJGl1yv33dnwe2bkiuj3PsVwm97bbbfuWV/N/BzxkXNMQEvx1kcgqqCjajHkGErKKgE0WyOZlgIk06m2NfKEkwniYQT1FVl0RGQJQEcoqIZLVz4nWPoLO5tG6ibysGQBQkDHpo3u9cSs65kdT2FfjefQglm0LQ5RPqTgMJL3pTc/8QBNyDbsLRdySx1XO1Ir6SwzP0Nhx9R+ULvppWi6PXcEryiuXVr91Ceu93CIKAo9eFeEeMJReuoeq1mw8pMDt6nIdn2N1kfDuoeu2WQ8bMjgVKPES6cgPOfqMpHTkeZ+/hR7QG/bXjAp3Di6V1b42SPnoi9u5DSG5e8pNo8smd32ruJ9EA3osfxNZpYL2fxzcupuqNW1HSCUovnYDthDO187avoPr1W1GSEUpHPYSty9kFllpi01e4Bl6L68wbtDnqaffn9/JrcA28BhSZ2g+fIbrifewnXYB70F81tfBUjJpp95Gu2qxpvrTvjx6wmMBulhCqvmP71EeweJvRevQ49AYTWSCnQCINXoeefbEMRr0OUTLgsZtw24xEkxnKHEYSWRkVKLL879TYG/DfYceOHXz44YcMGzaMJk2a/KRr/NdmbIIgSMA/gLOASmCZIAjvq6q64adcL5vNct1116HX6xk3bhytW7cu/Mzj8eB2u5FlGUk6Ng/aBvz8ONj/elDnchZtOVBRHdS5nHZldnSSSDZ3bAIOthPOQM2mCH76PIHZj2t0nbxFkmi04L34QfwzxmvzTrkM9m7najYcl03CN+1+aqbdS8kFd2Nu1ZPSyyfhn/EwgdlPkA1V4+w7irLLHyfw/iSCn/yTbGAXrtN/j3vobei9zQktfJ1s7R7NQqy4Me5BN2Esa0Pws5eoevUmSs6/E2Pj9tpa9EY8F9xFeMl0woveIlOzHc8Fd9bzsDwaDI00wZ7Yt3NxnnxkQZH9FfSa6feDqmDrcjZ6dxNCi97CUN4GS+te2LqcReCDp9j74h8QzXZKztcsBNJ71hNa/KamhC0IuM/5c6FyXYCqgnjs9bX4hgWo2VQ9Jc1jQbauisDsSWSqtmDvOQzXab+rZ3+V8e8kMOsxssG9FJ16pTYHJojkIgH8sx4ls28Tjj4jKDrlChAEQl+8QXjJtLza971IFifRlXMIznsBg7cFJSPGobMVk6pcj/+d8Qh6I6WjJxY+o/jGxQRmP47B2xLvyPH1fNBVVSX4yT+Jr/sMZ//LcPS68Efd636c0aGU6we0akiqfyNIJBK88MILdO7cmdNPP/3oJzTgqPg544KGmOD4RzIjE4ynkUSRIpMOu1lHbTSNVa/DadKRyOSQFagOJxFErQgZjGdQFRVRUjHrRRxmPS2KwSCBrtSJ7o9PsOz5W6mcNpa2Vz6Crqg1ggKhJFi6nYtb0lH74d/wvf2AJnhltOAecguiyU50+SyURBj34JtwnXolOruH4Kf/omby3ZRcNBbXqVegdzWidu5zVL9xK97hY7Hki7v+GQ9TPXkMxWf+AVu3QZhb9aT8yqfwz3wE3/RxOE++BGe/Swr7mLVdP/RFZfhmPkL1W3fhOu1q7Cedf8x2l6LVieTwkti4CMdJ59ebu/4+jqe4QI7VkdiyFENpq3p7+tGgKjLhL6cS/moqencTSobfW088VM1lqVvwCtEVszGUt6Nk2Bh0Dg+qqhD5+h1CX7yBvqQZJcPvQ19UlmdATkBJxykZfi+WNn3I1lXhe0cbE/ScdwfWjgNQMikCsyaS3L4cZ//LcJ58CYIgaPos0+4nG9hdT/w0CwRSoGxdy963x2IvaUSX308kJZgxCmDOgtepRxYEWnic7AkmaFUiYjfoiWZzZGQVURSI52Q6OMwYdSJiw8z1bwaTJk0Cfnr3GkD4b73YBEHoCzygquo5+X+PAVBV9dEjnXPSSSepy5cfXpV5586djBo1iqVLlwLa5rpjxw6+/vprJk+ezAMPPECfPkc2n2/AL4v9/teZnIJBJ/LWdX3YVB3lo3VVDOpczujeTVmxq45RLy75wc714RD5ZiZ1n/8bS4cBh/gQK9k0gfceJbl9Oa6B1xaSHzleh+/tB8j4duA+90ZsXc5GzWWpnfss8fWfY+00EPe5fwFRIrTwNSLfzMDYpDMlF9yFZHWR3LGy4J3tHnRTgQ6c3rcJ//uTkKMBik65HEev4fXWk9yxisCcJ1FScVynXqltqMewyaiqSmD2EyS+W4ijz8UUnXL5j9qc/tdQVYXINzMILXgNU/NueC9+4JjvM7bmE+rmv4wgiLgH3VSvW6+qKrFVH1L3+b8RjBY8591R6DAf7jNR0nECHzxJcus3WE84C/fZf8rPW71MdMVszK164jn/TkSDmcTmrwjMfgLJUULpyPHonKWANj9W+9GzGBu1x3vxuHrBhaqq1M17gejKD7TP5dQrf/J8lE4UmHZ934YE+zeCZ555hltuuYVXX331qPPXgiCsUFX1pP/R0n6z+LFxQUNM8NuFoqjsrktgEAWysorVKFFsNZDMKhglEZ1OZGcgjqqq7A0nSaZl6hJpiq1GTHoRVVFZt7eOnbUJElkZq16HzSgRSmao3rmDGRP/RC6bpfsfngBnOZEkJPLvHd+wkMCcpzTf44sfRDI7tHGjr98m9MXrmJp1o+TCexCNFhLblhGY9ZgmcHbR/RhKW5Gq3IB/5iOouTSeobdhadMHORWjdvYTJLcvx9JxAO5zbkQ0mFEyWhMgvu4zjBWd8Ay9rZ4GipyMUvvR30hu+RpT8xNxD/prvTGoH0Jy12p808ehd1doxX5Xo1/gk/r5kPHtwD/zEeR4XX58rvXRTwKywb0E5jxFZt8mrJ1Pp/isPyIazAd+XruHwOwnyNRsw97jfFwDf4cg6ZGTEWrnPE1y2zIsHQbgPvcviAYTsbWfUfvx35GsLrwX3YfB25LUnnX4Zz4CqkrJ8HsxNemsCeq+8yCZ6q0Un/1H7N00l4hcNKC52YRrKLnw3oKg2X6k922iZtp96G3FXDLueSrKG7F0Ry2yImMyGCh3alT+MqeZQCxNmcOM02pEVmSMkg6vw0RZkZHWXgdNir9XzGjAcYtoNIrL5aJLly6sXHn0scgjxQU/R4I9AjhXVdXr8v++AuitquqNRzrnhzZTgC5dunD99ddjNBrR6XRs2rSJcDjMkCFDGDJkyH+13gb8dKzYVccz8zbz5dYAiqpVoW89ux1/Hlj/y/XemWt5a+nun/Qe4a/fIbTwVawdT8M95JZ6yZwqZ7XkdNOXOPuOwnnK5QiCgJJOaAqUO1fhPPlSnP1Ha9f6airhxW9hbNRe63baXMTWf05w7nOIJjslw8ZgbNyeXMSH/73HyFRtwnbiYFwDr0XUG1FSMWrn/p3EpsUYKzrhHnIL+qKywnrkeIjauc+R3LoUY6P2FJ/7FwwlzY56j6qc07y+V3+Moaw1xWf9EWOjdj/pef2SyPh3Evz0X6T3rMPSrj/uITfXs7A6ErKhaoIf/4PUzlUYm3bBM+RmdI4DgUguFiT40bMkty/H1KI7niG3IFldqHKO0OI3iXz9DnpPM0qGjUHvriDj34l/5gRyYR+u06/D3n0oajqOf9ZjpHauwn7SBbgGXoMgSkSWv0/dZy9haNQW70VjC1S7yLJZ1M1/CVPzEzWlUcOB+1BVlbrPXiS6YjaOnhdSNPCa/0p8RABuP+fQv4sGHJ9o3rw5oVCIYDCIeJTuTUOCfWz4sXFBQ0zw20RWVggnsuwLJSixG8nKKiaDhNd+kNWhorKxOsyeWs0bu9RhIJ1TKbYa8DpMZBWVWDLLmr0h/OEkRoOErEBjp5GdwQR7t2/nzQeuQ1FUOvzuETLWJiQOClsTW5fif28iele5Jkhq9wD5gurc59B7muK9aBw6h4eMbzu+dx5CSUVwD7kVa7t+5CJ+/DMfIVO9BUffkRT1vwwEQWOqLZ6MzlVOyfl3YijVtBli6z8n+Mk/AYHis67H2un0wn6hqiqx1XOpm/9vEARcp/0OW7dzj6mbndyxksD7k1DlHEX9R2PvcZ7WbT6OoGRTRJbOIPz1dCSTXXPXyDP8fgiqnCOy/D3CiycjSHqKz/4T1o4DDvxcVYiunENowasIeqNWkM+LkaYq1xN4/3HkeAjX6ddi7z4U5BzB+S8RW/UhxqZdtIaJxUl09ccEP3keXVEZ3ovuR1/cmGxwL763H0CO1eI57w4sbfsCWrJfM30sSjJC2YhxGPLK6/uhJdf35zV4HsXj8VDmMJLKKdiNetqU2imxm2nlsbAzEKUqlqXYYqCiyMwJTYtYXRlie02cZiUWejQppkeLBpHT3wqefPJJbr/9diZPnsyll1561ON/9QRbEIQ/AH8AaNq0aY9du3Yd8ZrLli1j8uTJxONxvF4v5eXlnHPOOfWoYQ3432J/5zqdzdtv5edMvy/etGJXHZe+uITM97rXQv4/UdSEHn4I4SXTCX3xOpaOA/AMqd/JVhWZ4Mf/ILbmE2zdBlF81g0IooQq56j9+O/E187Ld63/iqDTE9+4mNoPn0Y02jT7r/K2ZGq2awlbNIDrtGuwn3Q+KDlCC18nsmwmek9TPOfdjsHbElVVia+bT3DeC6AqFA24Cnv3IYUNU1VV4us/p27+yyjpOI6ew3D2HVXP7ulIiG9cTN28F5DjdVjanoyjz8U/2Qrq50TGv4vI0neIb1iIaLRQdNo12LqcddSkU8mmiXwzg8jXb4Mo4RpwNbYTB33vWc2n7rOXUHNZik67uvAss8G9BD54kkzVZmxdzsZ15h8Q9SZia+cR/OR5RKMFz7C7MVV0IhPYjX/Gw+TCPq0a3fUcVEU+0M1u0wfPebcj6k2aR+aiN4ksmYal7cl4zrsDQXcgaDk4ubafdAGu06/7WZQ9H7nwBEb3bvpfX6cBvyy2bt1KmzZtuPnmm3n66aePenxDgn1sOJa4oCEm+G1DUVT2hpIoikIklUVVodhmxGs3YdAdSCjj6RxV4STJTI6N1RGauW2Y9SJehwmrUU9dLM36fRGSmSyJjEwqnSOSyWEzSdSEM0QSSZauXM/Xz9+Gqii0umICWWczcgetJbVrDb4ZDyGaHZSOHF+gHCd3rMT/3qN50dRxGEpbIcfq8M+cQHrfxnyh/jItYfv0X8TWfKIlbOfdgWRzkdq9hsDsJ5CTEY2p1nOYtl+Fqqmd8xTpyg2YW/ei+Ow/FRJ7yBeZ5z5HatdqDOXtKD7rhmPa23MRP8GP/0Fy+3J0zlIcfS7G1vn0Y7LV/CWhZJLE1nxCZOm7yLEglg6nUnzm9UecFT8Yqd1rCc57gax/J+Y2fSg+64/1NFiydVXUzn2W9O61mFr2wD3oJnS2YlQ5pzVJlkxH5yzFc/6dGMvbkA1VE3j/MTJVW3D0Gk7RgKtAVamb/xLRlXO0IvoFdyGabFo3e8YEEAS8F40tFAMyNdvwTR8HqkLzUQ8ilrbhYC309N6N1Ewfi2RxUHrpowUmgk0HqGDQQ7nLTP9WJfRs6eGz9dXodSLFNiMWnUTTEitfbvYRjGfx2E20KbVy8UnNMOqPX7ZiAw6gX79+LF++nFQqdUzx4C+ZYP+sFPGDkclk0Ov1DVL2xwH+8flWnvxkE4qqJcr923i4+cy2h9BgDz5uPwySwAPnd6YukSGazPKvL7Yf9f3CX79NaOFrWNqfUlB53A9VVTW699J3MLftq3k/6o0aNWzJdEKL3sDYuCMlF96DZC0iU7Md38wJyLEgxWf9EXvXszUa2JynSW5dirlNH9yDbkIy20luX0Hth88gp6IU9b8cR68LEUSJXMRH7dy/k9qxMt+tvrHe7LWcCFO34D/E185Dsrpw9r9MS0qPQqdW0gki38wksnwWaiaBsXEHbF3OxtKu3zEl6T8XlEyK5Navia35hNSuNQh6I/Zug3H0vbiewvbhoKoK8Q0LCX3xBnLEh6VdP1yn/x6d46CAo66K4Cf/1LrajTvgHnwz+uLGWsX/2480qrioo/jcv2iU8EwyT8ebj7HpCZScdyeSzaUVTD76G0LeSsVU0QklHcf//iRS21fUm/PWii7/IL72U2xdz6H47D/VL9aoihZUrfpQO2/gtT/bd83o3k155MITfpZrNeCXw1VXXcXrr7/O6tWr6dKly1GPb0iwjw0/J0X8YDTEBMcPcrLCur1hZEUFQcVrNdLEc+j8cCKToyqUwqgT2B6IU2IzYtJLNHZZMOhE/JEkX24N4I+mqY0mMRl12I066uIZ6pIZwokMgWiWndu3sPKF21FlmeaXPkzW06Le+6Srt+J7exyoqmbzmWeFZXw78L0zHiUVwTP0dixt+2pjZJ/8k/jaTzG3PAn3ebcjmWyFgq5gNOMZchvmFiciJ8IaU23L1xibnoBn8M3onKWoikx0+fuEFr2ZLyhfpXWr83tMofi+4BWUeBhr59MpOuWyemyuw0FVVVI7VhJa9CaZ6i2IZge2E87E2vmMY2LI/VxQVZVMzTbi6z4jtm4+ajqOsUlnik65vOCz/UPIBvcSWvgaic1fITlKKD7jD4XuMeS72sveI/zlFO35nX4tti5nayrhgT0albx6i0YlP/MGRKOF+KYvqf3oWQA8g2/C0vZkcrEggVkTSVdu0Bhop12NIEp5BsPftW72iHHoXeWAJq4WmDkBvcVGs0sewuipAAEkFaJZiFVuwP/2OPTWItpf9QhxYwk5tLgXwCyA3ggWSeScro05rV0JH6yuAVXGYzMWrLoWb61FEsBjM9HIZeK8rhXYzfoG9fDjHBs3bqRDhw5cffXV/Oc//zmmc37JBFsHbAbOAPYCy4DRqqquP9I5x7KZfvzxx+zdu5drrrnmkNdDoRAnnHACZWVlFBcX/1frb8CxYfLS3dwzc23h32d1LOWGvJDTwaJnQKHTjQA9m7m4a1CHQiL+Y+jj4aUzCC14BXPbvpScd2e9ziNAZPks6j57GWPjDpRcdF8hEYx/t0jrWluKCnM5ciJMYPYTpHauwnrCmRSf9UcEnYHo8lnULXgVyeLEPfRWzM26IifCBD/+B4nNX2Eob4d78E0YPE0PdGDn/xslFcN+0vkU9RtdLxFO79tE3fx/k967AV1xY4r6XYql/SnHkGjHia35lOi3H5EL7gVJh7n5iZhb9cTUvBu6ovKfPajMhX0kd35LavtykjtWoGbTSA4v9m7nYut6zlGr06qqkty6lNDit8j6dqD3tqT49OswNTuQqCjZFJGv3yW89B0ESUfRqVce6FofVOXXZtZuQufwkN63icAHT5AL1eA8eRTOky/RKtQL/kN0+SwMjfLCJ3YP2eBe/DMeJlu3Tyue5FW/lUwS/6yJpLav0ERp+l9W7/mpikzt3L8TX/spjt4jKBpw1c/6fM/uWErXJkUFJfEGHH+ora3F4/HQr18/Fi9efEznNCTYx4YfGxc0xAS/PaRzMuv2hsnlFKojKcqLTDRz23BbDURTOXKKgtNsQC8J+CNJ1ldFiKVyNCm20KHcgVGvFc0rg3E21kTYuC/CNl+cMqee2mgWq1HCrBNY54sTjKYIxTMooSpWvHA7cjZD+ciHkL43+6vRgcchx+s0H+O8wnYuFsQ/42EyVZspOuUKHH1HAhD79iOC815Esru1kbGy1nnhzUlka3drHdJTrgBJR2zNp9TNfwmgHvU7W7dPG4natRpDWRutW33QyJeSjhNeMp3I8vcBsHcbhKPPCHS2H/49VVWV1K7VRFd+QHLrN6Aq6N1NMLfpjblFd4yN2v/snW0lmyJd+R3JHStIbllKLlQFkk5j2PU4D2PjDke9Ri5cQ/iracTWzkPQGXD0vghHrwvrjZcld62m7tMXyNbuzne1b0Bn96AqMpFvZhJa/BaiwaxRyfMF97r5L2tjdeVt8Jx/F/qiMlK71+B/fxJqJon73L9i7TgAVZEJLXiVyLKZmJp1xTNsDFJeOC62bj7Bj57F7K2g/ZUPgdmF0yiSkQWSGZm6HWvYNW08eoebPjc8isXTiGAsRSqFxpgQwKwHsyRR4jDTttxBeZEFm1FiXWUIi0lPt4oiZFVl494we8JJ2pc7aVdmp02JA4NBxGszYjMdX9T/BhzA6NGjmTJlChs2bKBDh6P/vsMvmGDnLz4YeAaQgFdUVZ3wQ8cfy2aayWRYs2YNJ510Eslkkn/9619MnjwZr9dLhw4d2LhxI16vl1deeeW/Xn8Djo5/fL6VJz7exP7fFgEw6kXGDu3E+A/W1xM9+3R9db0u9SMXnkC7Mjtfb6/l2z0hPt1Qc8zvG1kxm7p5L2Bq0V0TK/neDHB842ICHzyJzunFO+KBQpUyXbUF/4yHUdIx3INuxtrhFO2Ld/FkIkumoS9pTskFd6F3NyFdvZXA7MfJBffh6DmMolOvAElP4rsvCM57ASWdwNn3Ypx9RiLo9MiJMKGFrxFb8ymixUnRKZfX61YXEs8v3iAb2IXO1QhHr+HHRPVSVZXMvk3ENy4iueVrcmHtWUlWF4ZG7TB4W6D3NEPvKkdyliIarT+YGKqqippJkAv7yNVVkandTbZmO+mqzchRTf1dsrkxt+mNtX1/jE06H3VeTJVzxDcuIrL0XbL+nehc5RT1vwxLh1MPooMrJL77groFryFH/Vg6nIpr4LXo7O4Dlf/Fb4Ig4hp4Dbau54IiE/76bcJfTkGyufGcdxumJp3JhX3433+MzL5N2LsP0Wjckp7k9hUE3p8EokTJsLsxNdUS+1y0Fv+748n4dlB89p8OsdpS5RyBD54ksXFRYWb/50yuRUH7+1BU7W+kwQP7+MTYsWN56KGHePfddxk+fPgxndOQYB87fkxc0BAT/PaQySnsCcZJ5xSqwymauS0IgoBBEsgoKjpBQBAEKlxmauNp1laG8VgNBOIZulQUYTXqCCezRFNZ9tUlWLWrlkBcRiRHdTiD12HEF01TbJIIRFJsC6cQFJVg9W7WvXIvcjKK9+JxmCo61VuXHK/D9854MjXbKD7zeuzdtRl9JZsmOPc54hsWYGnXD/fgmxENZtJ7N+Kf9Rhyok6zZuw+FDWXpm7+v4l9+xF6bws8Q2/DUNKcXNhH7UfPktr1rabLcs6N6D1NUFWVxHcLqfv8FeRYEGungRSdemU9kbNcxEdo8RTi6z4DUcLW5SwcPYcdk6CZHA8R37iIxOYlpCvXgyKDpMNQ2gpjWWv0JS3QFzdGV1SKZC2ux/g7HFQ5ixwLkg1Vk6utJOPfQaZ6Gxnf9sK1TU27YGl7MpZ2/eq5bRzx98G/i8g3M4hvWACCgL3ruThPHoVkPbD3Zev2UbfgPyQ3L0FyllJ85h8KRZB09VaCc58jU7MNc9u+uM/+E5LVRXrvRgJzniJXV4Wj90UUnXIZCKKmKL74LW1GftgYDCXNkZNRAu9P0rRZug/Nxwo6VFUlvGQa4UVvYm/ehTP+9Ah6s5VoVsZjNaLX6wisX8Kcv92FxdOYdpc/RHFZIxRUPBY9wXgK0NG1iZOqUAJZFWhTaqO1146kk5AEgXg6h8dqIJ7OotfrCCcyBONZTmruxKTTUWw1YNLrkCSRJsX/O3ZiA44dPp+P0tJSBg4cyPz584/5vF80wf6xOFY6GEA8Hufee++ltraWK664ggEDBpDL5bBarTRv3pwFCxbQvHnzX3bBDThkBhs0kbOTW3sOET37enttPeuurhVONtVEyeQUdJJI5hjtu/YjuvoTgnOfw1jRAe+IcYdYS6Qq1+N/92EQBE01Mr/hyrE6/O89Qnrvd4VZHUGUtMTsgydRc2mKz7we6wlnoWbT1H2e31DdTXAPvhljo3bI8RDBz14i8d1CdMWNKT7zBswtTgS0JL7us5dI792A3t2EolOvwNym70GCJwqJzUuILJlOpmYbosWJreu52LueU0999EhQVZVccC+p3WtIV24gXb1V625z0N+spEMyOxAMFgSdAUEUQVVRcxmUdAIlFUXNZepdV1dUjqGsNcaKjpianoDe0+yYEkw5VkdszSdEV32IHKtF726Co8/FWDsOqFdcSO1cRWjha2RqtmEobYXrjN8XKGWpyu8IfvpPsr4dmFv11GbXHCVk/Dup/fAZMtVbNeXWs/6IaLIR3/QlwY+eRVUVrULd4ZQj2nUAeRGb8SipGCUX3IW5Vc9696Bk05pVx7ZlFJ12Nc7eI4563z8Gzd0WdgcT9UYkejV3Mf2Gw/udN+DXQSaToby8HJvNxs6dO4+5wNKQYP8yaIgJfpuIprJUh1OEU1nKHEbSORURTZ9FJwokMjLN3FaCiTRrKsO4rQbqEhk6NXISz8joBIFIKsvGqjAbqsKoKuzwRbGb9IiSSCSWRlZBRcGkl9jui1IXz1Lt91Mz7T7kSICSC+85RP1ZyaQIzJ5Ecus39UQwVVXVXDEWvpZX7NYsog5WqdZGxv6KZHaQ2LqU2o+eQ0nHKOo3Gkev4SBKxNd+St3nr6BkUjh6D8fZZySiwYSSThD+ejqRZbMQBAF7j/Nw9L6o3phVtq6KyNdvE1s/H2QZc+te2E8cjKnFicckhqak46T2rCO9Zz3pfRvJ+HagZpIHHSEgmmyIJiuC3qQl26qqeWBnkijpOEoqVu+aotGKvrQVxkbtMDXpjLGiUz0h0CNBVWSSW78huupDUjtXIeiM2LqejaPXRfVGxHKxoNbVXj0XQdLj7HMx9p7DCmKyocVvEV05B8nixHXm9Vja9QM5qzVEvpmBZHfjGXILpqZdyEVrqZ3zJKldazRF8XP+jGi0HNDWidXmRwHP0daYy1L78XPE183H2mkgTQb9lfJiM3aziMNippnbxsqFHzL3n2MpatSKU/40kazBgYqAKKg0clkpdZppUWymkcuGURIIpdLkZIG25XZcFgOrdocIJzM0c9kQRJVoKodOFHGYJIosRrb4IsRTMlXhOE2KLZzWvhyv4+jPtwH/W4wZM4aJEycye/Zshg4deszn/WYT7FdffZXFixczbty4embfNTU13HHHHTz44IO0aNHiB67QgJ8LK3bVMWNlJW8v34OsqOh1BzrY2ZyC/iDbroPp5Gd3LGXedzUcRdvsBxH/bhGBD55A72lG6cgH61VFIU8Ne+dBchEf7nP+gu2EMwCtUhv8LK822aQznvPvRGcrJhetJfDBk6R3r8HSrh/F59x4YAZ77nPIsSCOky7AecpliHoTyR0rCX76PLm6Ksxt++I67Rr0rnKtW715CXVfvE4uWImhtBXOfqMxt+5VT1k0tXsN0eWzSG5dBoCpxYnYOp+BuXXvY9rI9kPJpsgF95Gt24ccDSDHgsjJKGomiSpnQFFAELRk22BGMjsQLUXoHCXoXOXoixvXs8U4GtRchuS25cTWzy/Q1EzNumHveQHmlj3qCZildq0m/OVk0pUbkBxeik65HGun0zRf62iA0MLXiK//HMnuwXXG77G0PRmUHOElbxNeMh3RaKlPCfvsJWJrPtEoYefdid5VrlmhzHnqELsOgMSWrwnMfgLRaM0L2rSs/+zScXzvjCdduUETRztx8FHvXxKhR1MXwXiGliU2WnqsLNley7p9YRRFE+1TFBUVrWvdpcLJ6srwIde54dSW3D342OhGDfjl8dprr3H11Vfz1FNPccsttxzzeQ0J9i+DhpjgtwtVVQklsiQyOaxGHUadRHUkiQoUmw0UWQ0oisLayhDb/HG8DiPdKorwxTNY9RJfbvGzdEctdbEUkiThshowSiI7/HFMRoG6aBZVBK9Nx7IdIWJphbisdXVrpo8lG9iNZ+itWDucWn9dikzd/H8TXfE+5pYnaTaO+XGu5I5VBGY/rllBDr4Za7t+mpL1slnULXwNyeLAPeRWzM27aSNjnzxPYtNiDKWtcA/6qyaWFg9Rt+AV4uvmI9lLcJ12dZ7FJZAL+wgteoP4+gUIBhOOHudj73lBvUQ7FwsSXTmH2Oq5KIkwkrMUW+fTsXY8rZ4v9NGfv4IcCZAN7iUXrkGO1SInIiipGGouDXJeDk7SIepNiCYroqUIyeZG5/SiL65Asrt/FIsr499JfP0C4uvnI8eCSDY39hMHYztx0CH3GPlmBrFVH6IqMrYuZ1PUbzSSzYWqyMTWziP0xesoySi2boNwnXqFJk5WuZ7aj54jF6zURE9Pv06zWtu8hNq5z6Hm0rjOuB5bl7MAiK/9lOCn/6rnDgOaNo5/5gTSlRsK3tc6QcCiA50EPZsXEf32U6Y+Nx5v6y70u+ERTGYrkWQOSZQocRpwWYyc0saLySjhNBloXWrHbJAQAINOQhIFsrJCVSiJrKqU2k3UxtIkszKSJJLOyhoDwx9jqy9G5woHJr2eoSeUo28QPTtukEqlKC0tpaSkhC1btvyov4ffbIJ91VVX0bt3b/70pz8BmgfmP/7xD5577jmuvvpq7r///l9yqQ04DA6euf7+DPb+f4964StyCuhEGH/BCYz/YH297vdPQXL7CvzvPYJkLcY7cnyBDr4fcjJKYNajpHatqSd2AfnZm0/+gWAw4xl6O+bm3fLzPjMILXpL21AH3YS5ZQ+UdJy6z/9DbPVcJGcp7rNuwNyqJ2ouo4lyLJmOquSwdx+Ks+9IzXtTkYmv/5zwV1PJharRe5rh6D0ca4dT61lt5MI+Yms+IbbuM+SIH0FvxNyyJ5a2fTC16HFMVKxfGko6QXLHSpJbl5LYshQ1k0CyurB2Goit6zn1Nn9VkbUu/TfvkqnagmRz4+x7MbYu5yDo9HkRtxlEls1EVWQcPS/E2XckosFMas86gh//g2ztnnqqpKnK9dTOeZpcqAZHnxEU9b8MQdJp3uSzHkOOBQt2HYIgHOR7+gaG8taUXHhfPZVS0DZ639vjyAb2HDYYOxJ6NndxWjvvIbPU+3/nXRZDvQLTKW1KDjsC0dxtYcEdA3/iJ9KAnxudO3dmy5YthMNhTKZjL3A1JNi/DBpigv9bkBUVRVULgk6qqrKrNoFRJ5JTNCabUSfhj6Z48Yst+ENJwqkcmaxCkcVE71ZFxJM5apNZEuks/liGaDJNMJ4hK0Mko/G4lFQM37sPaUXTs67XLJy+h+iqDwl++i/0xY0puej+AiVbs+ecSKZqM/Ye5+E67RoEnT4/MvYEuWAl9h7nUTTgKkS9SWNTffo8SiKiFd/7j9b2scr11M17UWNsNWqHa+A1BRZdxr+L8JeTSWz6EkFvwtb1HBwnXVCPwabmsiQ2f0Vszaekdq0GVAylrbC0PRlzm97HzDD7JaGqCpma7SS3LCWx+Uuygd0giJhb9sDW5WytoXCQzkw2uJfIspnE1n4Gioy102k4T7600JRI7VhJ3YL/kPXvxNi4I8VnXa8VLZJRbfxu9Vwkhxf3uTdibtEdJR0nOO8l4uvmYShthee8O9C7K+p5k5uadcVz3h1I1iIgz2Z792GURAj3oJtwdByAnF+fqN0UqW+mUrXgLYra9aLP78aiM1lRUSmxGJAF8NqMnNC4iLpEllROpnWZgx5NXah5ybNypxmz4dAkOScrJDI5JFEklc2xoSrK0m0BQvEMnRo7EEWB87o0xtIwi33c4KWXXuIPf/gDf//73/nzn//8o879zSbYc+bM4a677uKuu+7iiy++YOHChXTv3p1bbrmF3r17/8IrbcBPwcFK4qIA/Vp7GNS5nCc+2UQwnjn6BX4A6b0b8b3zIIgS3osfwPg9kRNVzuU71ppdg+f8OwtJa8a/i8CsiWRrK3H0vZiifqMRJB2Zmm0EPniSbGB3vlp6LaLRSmrPOmrn/p1csFIT4jjj9+icpeSitVp1eu1nCEYLzl7Dsfc4D9Fo0RLtDQuJLH2HbGA3kq0YW7dB2LqeU0/URFUV0nvWE//uCxJblqDEQyCIGMvbYmzWVaNpNWr3P1ESVzIpMtWbSe1ep9HR934Hioxosmuz2R1OxdSsa70NVE6Eia2dp9HFwzXoispx9B6OrfOZWmKdTRP79iPCX7+NkghjaX8KRQOuQl9URi4WJLTgP1o32+HFffYfMbfqiZJNEfriDaLL30fn9OIecgumJp01Svg3Mwh98QaSzU3JBXcVRGSUTJLaD/9GYtNijVp+7l8R9cZ695etraTm7XEoibBGJ2zR/ZifjQAIR7Ck24+DC0ybqqPc997aQ9gaDR3s4wdLly6lT58+3HTTTTzzzDM/6tyGBPuXQUNM8H8b+xNsgySQzsrkVJVyp5lUJse/Fm5lVyhJdSCB1SzS0mNDECWau624rXpkRWF3MME2X5S9oQSBcIpwQiWu5pPsbJrA7MdJbvkaR5+LKTr1ynoJqQTEdq0m8N5EUBU8599ZoJSrcpa6z/9DdMX7WuJ2/p3oixtre9HC14iumI3O1Qj34JswVXRCTsUILfgPsdUfa0ysgddiad8fVIX4us8ILXoTORbE3KonRadcUWBRHWx9CWBp2xd79yEYm5xQb625SIDExkXENy4mU7VJW7/dg6lZN0xNT8BY0RFdUdkvnnCrqqKNp1VuIL17Lcld3x6IUSo6Ym3fH0u7/oVkdv85qR2rNFG2bctA0mPrfAaO3hcVmiGpyu8ILXqD9O416JylFA24Wnt+qMTWzCO08NUDArL9L0M0mOuzCvtcTFG/SxAkvUYJf38SueBeTci03yWFGCX+3SJqP3oG0WilbPh9GMrbcvBgoqrIhD59nsi3cynrcRbNLrgJWdESZr1eR0WRmbIiMzlZxqDTEYpnaVlqxaLX0a7MTlO3jUQmh14SKfmeLV1WVpAVFUO+uFQbS7EvlGS7P0ZtLEM6p9CtqYt+bUpowPEBVVVp3749lZWV1NXVYTD8OPHA32yCDfDpp5+yZMkSEokEN9xwA6IosnTpUnw+H36/ny5duhyzSE0Dfnnsn9fO5JSCrZdRL+K2GtgbSv3X18/W7qFm+jiUZISSC+7G3OrQeDe6+mOCnz6PZHPjvfAeDKWtAC2ZDM57gfjaTzGUt8Nz3m3oXY1Qc5kD8z7WIorP/hOWNn1Q5SyRb2YSXjINVVFw9hqOo/dF2syPfyehL94guXUpotmBo+cw7N2HIBqthSptZPksUjtWatXe1r2wnXCWRq2uZzumkNm3meT25SR3rCJTvQVUBRDQuyvQl7bEUNICvbsCXVE5Omfpj6KV74eSTSNHfGTrqsjWVpIN7CRTs12rRuffz1DaElPzEzG3Oglj4w6HeJCndq0mtnYeic1fgZzD2KQzjh7nY27TG0GUUDIpYqs/JvKN5pdpataVolOvxNioncYAWD5LYwDIWa2bffIojYK/81uCH/+dXKga24mDcZ32O0SDmVw0QO2cp0ntWo2l7ckUD/prQRE0G9yLf+YEsrWVFA24Ckev4YcEHqnKDfjffUgryIwY95N9xkWg32Gs6Q6noJ/K1tcY6NncxdsNM9jHDRYtWsTUqVO59957adTo6AJDB6Mhwf5l0BAT/N9HKivjj6bwR9M4TDoEUcRl0rFoW4At1TG2VIdx2/V0KHegING7RTF1sQwIsCcYZ5s/ytaaKBurQjhMRuLpNP6INp4jKzK+T54nsnou1k4DcQ/6az3mGGje1P4ZD5P176LolMtx9L24MOKU2PI1tR/+DVXO4jrjD5poqSCQ3LWa2o+eRQ77sHcfTNGpVyEaLaT3fkftJ8+T9W3H2KQzrtOvw1jWGiWbIrr8fSJL30VJxzG36YPz5EsKjYBcxE90xWxiaz5BScXQFVdg63Im1o4DD2VdRWtJbl9OasdKUrvWoKSiAIhmB4bSVnnB06boXI3RF5UhWp3HNMN9MFRVQY4FyYWqyQb3ka3dTda3nUz1NpR0XHs/axGmpl0xt+yOueVJh7iL5MI1xNbNJ7Z2HnK4BtFahL3rIOzdBxdG+VKV6wl/OZXUzlWIliKcfUdi7zYIQafXGACfvUymegvGxh0oPvuPmvNLMkLd/JeJr5tfTxenQOX/4jUkswP30NswN+uq3c9BCuLGxh0oH3YP1iIXyRzIgA4QMin2zZ5EYus3ePuNoPuFfyArC6RzOVqX2sjKKma9gW7NnWSy4LYa+K4qSpFFonWpk7alNqwGHTWRFHaTDotBTyOXGZNeIpWVqQonUVUw6UWsBonaeBZJhN21ScqcBnIKNHKacVp+XW/zBhyAoijceeeddO3alSuuuOJHn/+bTrAPxpQpU5g/fz4ul4vS0lKqq6tZsWIF1157LZdddtnPvNIG/FSs2FXHM/M2s3hLABVNAO2MDqV8chQF8dZeGzXhJNG0/IPH5WJB/O88qClFn3XDYedp0/s24X/vUeREmOKz/ljYNEGrcAY//juqIuM6/TpsXc9BEATSVZup/ehZsv6dmNv2pfiMP6BzlJCLBKhb8B8S3y1EtBZR1G+05tko6UhXbSa8eDLJ7csRDBbsJw7C3v28gshHNriX2OqPia2bj5IIIZodWNqdjKVdf0xNOh+i+KmkE6T3bSS9dyOZ6i1kfDsKit/7IRqtiFYXktmOYLQg6k1aQCEImqCJnEPNplAyCZRkFDkeKmzQ+yHZitGXtMBY1hpD4/YYG3coJK/7oSoy6X0bSWxcTGLjYuR4HaLJVqCL7/cCl+Mhoqs+JLryA5RkBGPTEyjqNxpT0xM0n+z1CwgtegM54sfcuheu069D72pUb45N5yrHfe5fCmrg8Q0LCX76vBb0nP77wmcEHPDDlvRaR6J5t0M+//h3XxCY8/QhCvM/BpKojbXvn7E+WBX84EKSQSdyUfcKpnyzu9C9/v7xDfjtoyHB/mXQEBP8/4GsrFAZTGI1SmRyCpIoaB3q2jgmnYiMQiCWQS9IlDrNhFNZmhWb2FAV4q2v97BpXwgFMBsE0lkFkyiSzCrEs2DVq2yZP43QojcxNu1CyYX3HLKfKZkUtR8/R2LDQsyteuIecmuB4ZaLBAjMeYr07jWa0Nk5NyJZi1AySUJfvE50xQdItmJcZ/5B0w9RFWKrPya06E2UZBRrp9MoOuUKdE4vSipGZPksIsvfR03HMTU/EUevCzE1PxFBEFCyaRIbFxNb/THpvRsAAVOzLljan4KlTZ96nWHQEuFsYLcmdlq1WSuM1+4+MF8NmuCp1YVkcWruIgZzXvg0L0CqyKi5DGomgZKKIyfCyPE6TTU8D0FnQO9ppiXw5W0xVXRAV1xxSOE6Fw2Q2LyExHeL8usHU7Ou2Lqeg6VtXwRJrwmgbVtGZOkM0ns3IFqcOHoNx37iEESDiWxtJaEvXtd8sm1uik67GmvH0wCIr/tME5FLxzUbzZMvQdDpyUV81H74DKldazC37q2J0eUT/lwsSPD9SST3rKOo+xDanXcdBr0RUVSoCmt7uBwPsu+dh0jVbKP7yL/SpPdQwmkFUYISp4lMRqbUYcJi0OOy6GhcbCWSzBJOZimyGOjVopgTm7rZG06y3RejxG7EbtRhN+spyc9ex9M5THqJeDqH1SgRT8tYDBK+aAqrUY/DpKPYakQUf13afwN+PvyfSLDffPNN5s+fz5AhQ+jcuTNNmjTBYrHw3nvv8frrrzNjxoxfYLUN+KnYn4Dsn0+9um9zXly0vdDVBg6ZyZYEkI/xV1JJJwi8P4nk9uX1lEIPhpwIE3j/cVK7vsXaaSDFZ/+pIPKVi/jzX9arMbXogfvcG9E5SlDlHJFlMwl/ORUEAefJl+A46QJtPmvfJuo+/zfpyg3oXOU4+43W5qxFiXT1ViJL3yWx6UsQBCzt+mk0sMYdtVlhOUdyx0riGxaQ3LoUNZtGNNkwteiBuWUPze/6CN6YSjpOtraSXKiKXMSvCZzlk2YlnUDNplHlLKiqlmSLOkS9EcFoRTLZEK1F6GxuJKcXfVEZuuKKI857y4kwqV2rSW5fQXL7cpREGCQ95pY9sHYaiKVVz4LdWLpqC9FVczTqm5zF3Konjj4XY6roqAnAbfla88n278RQ2oqigddgbtZVEzj59iNCX7yBkk1rSqx9RyHqjfVFZcrb4Rl6a2HuW81lqVvwCtEVszGUt6Nk2N31rFCAepYcxoqOlAy/r57wyo/BWR1L+ewggT5RgNvObsefB7auNwohCTCqV1NmrKwkmw8cLz6pCcO7VzQk1/+H0JBg/zJoiAn+/4CqqviiWhKiqhBOpomlZHKKQjqngKoiSQKRZJZkTuHEJkWkswozVuxmfWWYvXVxDHoQVQFZFfA4jCTTMgrgsOjYUJmgdt18aj96Fl1RGd4RYzG4GtWLM1RVJbZqDsHPXkayuSg5/66CKJbWHX2Pui9eRzQcEN0ErWBfO/c5sv6dmFr0oPjMP2h08no+12re5/pidLZilHSc6KqPiC6fhRyvQ+9pir37UKwdTyuMf2WDe4mv/5z4d1+Qq9sHCBgbt8fc8iRMLbpjKG15SFwDWsKcC1WTrdtHLlR9kOhpBDUVRzlY+BRAlBB0BkSDWVMZNzuQbMWaAKqzFJ2rETqn9/DvJedIV20htWMlye3LyFRvBUDvaYa14wCsHU8rzJXLqRjxtfOIrpxDLlSF5PDi6HUhti5nIepN5MI1hL6cSnzdZwh6I45ew3H0vBDRYCLj207w03+RrtyAsVF7is+9EUNJc1RV1YTMPnsZVAXXGb/XGhz5xD+5Y5XmDJNN4j7nRopPGPj/2rvv+DrL8vHjn/tZZ5/spJnde0AHtGWUvUFwoCiurz9U8IsDHKiAAo6vCAqogAqiXxXBxVcRQfaGQvemKx1Jm2bn5OzzrN8fT5I2HVBo0tMm9/v14qU0yXOuk5Q89/Xc131dBDUojfhRcNkZy2K21VP/0PdwMnHO/O9bqJxyIp2ZHFGfRmfSJGhoGLpgVGkUnyroSGTQVZWm7jQ1hSHmjS1m5qhSyqM+GjpSxHoa+4V9OuMqwkT8OomMSXN3Bk1RQHgl522JDFnToSBoUBySu9ZD0VGfYG/atIkrr7ySb37zm5x++ukoilcKs379er7+9a/z4Q9/mI9//OODEa50CPZsBvWdf67GOpRW4vvhdQq9n/iSfxEYexylF319n3PLrmMTe+3PxF59CK24mrL3faPvbJTrOiSWPU7nC78FoVJ02n/17JQqWLFmOp69j/TGhWhFlRSd+hkC4+cBkN78pjfnunUrWnENBfM/3NPQTMPs2uWVga16BjebRC+tIzz9LEJTT+0rmXLMDJktS0ltfIN0/RKcVBcAWnEN/pop+KonYVROQC+ueceZlof6/bM6dpLdtYHszvVkG9ditm4F6Ev+g+PnEhgzp+/7aqfjpNa9RGLlU+SaNyN0H6GppxOd8z70ktq+8WSx1/6M2VKPVlTVMyf7ZIRQSG9dTudz93sLlZHHUHzmlXvMEu2ZPZ5LeWNR5n6w74Zvduyg7dEfk2ve7D1QOfXT+5QBulaO9id+RnLtC16p4LlfQmjvrZGIAD42t44/vbG9b4GmKYI/f35+vx3sPTvoA/0a/klHDtd1D/nsokywB4dcEwwfruuStRySWZMtbUkUIWjoTOHXVHBd4mkTRxGYOZtj6orY2ZnigVc30tCRprXbxlChKKxQXRBEaAoKguauFJbjYtkOadNh58bVtDzijV0ve/+3CNbNYO/hoNmmDV7TzHgbhSd/gujcD/SVWOdat9H++B3e2MgJJ1B81lV9na/jSx6j65UHca0c0dkXecec/GGs7lZirz5EYtUzCFXzGpr1jKtyLZPkupeIL3nUu2caAUKTTiY0/Sx81ZP6mnWarVtIbVhIevObfUms8IV61gSTMSonYIwYt8/O/ECzUzFyuzaRbdrg7ZrvWIdrZvr6xATGHU9wwnz0ktqen6lDdvtqEqueJrX+VVwrh696MpHZ7yM48QSEomJ27qR74d9IrH4WhELk2PO8JrGhQuxkF12v/JHEiqdQ/GEKT/k04Rln9q3D2v/zCzJbl+GrnUbJ+V/pG8vp2hZdrzxI98K/oZfUUn3JN/GV1RHxK/h1gaHqBH2CrcteZt3DP8IIRvjsLXfT5quitTuDJiCbs8jYFmXRAGFDpaQghK4o2I4LuDS2p+lK56gpDvDJ+WOYN66MHV0pFFy6szYVET+Vhbsns6RyFjnL8RJ2bXeTv3w3qpP2NVA/l6M+wW5ra2P27Nls27YNgMWLF/PUU0+xcuVKqqur+dGPfoSuy458R6q7n9/E7U+uP6Qu4m9nd6fQGso+eENfp9A9ZbavpO1ft2Onuyla8Ckix13cd0M1u3bR8Z+fkdm2El/NVErOuRq91Lt5pOuXeAlhe4N33urU/+o7C5Ra/xqx1x7GbN3qPak97mLC089C8QVxchmS614iseJJr2GJUPCPmklo8gIC4+f23SR7O3Rmtq0g27CabOPavvNPqBp6Sa33T1EVWkEFarQMNVzslYL5w/t94tzLdWxvHnaqCzvRiRVvxepqxupqwmxvwGxr8EZ5AMII4KuahL9uOv66GRiV4/uu7WRTpDcvIvnWy6TrF4NtoZePJjzjbMLTTkfxhXDMLMk1z9O96B9YHY1oRVUUzP+IN6pLUcm1bKHrxf8lXb/Ya3By2mcITjgBIYT3/X/6XjL1SzAqJ3qjUMpG9nx/XBIrn6bz2V8hVIOS879CcPy+zYyseBut//cDck0bvQXT/A8f8i/PS46t4vFVTeRsF1XA9y6Zzsfm1vV9fO8O+tKRI5VKoaoq7e3tRCIRIpFD79AvE+zBIdcEw088Y9LYkaY7nWVnLEN52I/p2GzvSFFVGMRQFYQQdCYy/HnRdmKpLG3JHAFVUFoQpDCgsb0jjSocLNNEqAbH1EZ5qylOWzxDd8tONj10M9nOJirP+jz6Meftcz9wMgna//NzUutfxT/yGEouuAYt4h3t6psy8sqfUDSDwlP/i/AxZyOEgp3spPPF35Nc9QxKIELB/I8QmXk+QtMxO5uIvf5nkmueBwShKacSPf6Svp3Y3M71xFf8h9RbL+OaWbTCSkKTFxCcdBJ62ai+GO1kJ5ltK8hsX0WmYQ1WR2Nf3GqkDL20Dr24qqcvSzlquAQ1VIgSiHql4Qe497mui2tmcdLd2MlO7EQ7VqylZ02wA7N9O3aio+ezBXppHb7aad66YOQxfZVvruuS27WR1FuvkFz3Mna8FWEECU09lcgx5/ZtYmR3rqd70T+8yj5FJXLM2UTnXooWLcXJZYgv/iexN/6Ga+WIzDyfgpMuR/WHvYcZi/9J1ysPglAoPuXThGaet3vN1rmTtn/dTq5pg9ec9szPEfb5sR0IGmAYCkV+nfbX/8ri/7uP6rFT+O6dDzBt4mgWb2unviXBxuYYOzozGJpCQcjHMVURTp1URSpns2ZnjB2dcZq6TSKGxqjSAJOrCjlvRiVt8RxdKZOioMb4iiiq+u7Ovkv5MRhrAhgCCTbAFVdcQSqVYt26dUydOpWpU6cyffp0TjnlFCKRCOl0mkDg4Gf8SodP727fnqO6NAVGlYbZ1JIAvB3DQxrjdYBOoXuyUzHa//Nz0hsXejfU87/SV2LcW4bU+fwDOLkM0eMuoeCEy1AMv1fSvOJJul75E06qi8D4eRSe/PG+m2Z685t0L/w72R1rEUaQ8IyziMw8v6+0Ode2neSa50iufQm7uwUUFX/tdALjjiMwZg5aUdUec7Odnl3ljZgtW8i1bsPsaMTubu1pRtaf0H0IzeftdAsFXMc7b2VmvafO+6FGyrwGaqV1GOVjMEaMQy+p6ZesW7Fmr0x88yLSW5eDbaKGiwlOOpnwtNP7GsdZsRbiy58gseJJnHQ3RsVYonM/SHDiid6T644dxF59iOTaF1F8QaLzP0x09kUIzcC1TLrffITY638BRaHw5I9747d64rCTXbQ/+QvSGxfiq5tB6YXX9i2A9pRpXEvbP/4Hx8xQesG1BCfMB7yS7kMpmhgR9dGRMvt2qR/6rDxPfTSIxWJ8/etfZ9myZZxwwgk4jsO0adO44IILqKmpec/XlQn24JBrguHHdV3a4lk6UyaqcMjaYCgCTRWkTAddUQj5FDa1xnl8+U42NCcwNMH06gjNCZOwobJsexdZ10EFdCHw+xTiKQefodKZzJCJd7Ppb7fRvWkxBTPOpvSsq7A0HQ2w9ogjsfIpOp/9NULVvbLwPcY4mh07aH/yF2S3r8KomkjJ2V/ou/flmuvpfP4BMtuWo0bLKTzxMkLTzkAoKlasxZsBveppXDOLf+QxRGZdQGBcT0PQbIrUhtdIrnmezPZV4DpohZUExh1PYOxx+Gum9B3FAq9yLLdrI7nmenKtW7DaGzE7duz/Ht97REwzQPEqA1zXxrVM7/OdfXvcCCOIXlyNXlqLUTYKo2Icxohx/SoCHTPjdRWvX0xq45vY8VZQVAKjZxGaciqB8fNQdF/f6LH40sfI7li3uzfNnIvRwsW4lkl8xX+Ivf5nnKS3nio65dPoJd7v5kzjGjqeutfrhTP2OOrOv4qC0hHEsw5JyyWx4kk6n7sfoagUn/vFvjJ+HaiKCorDQZxcinV/+wkb33iOk865mB/e8XPKCqK4wFu7YmzYFWdtUxetsQyGruC4ggXjSzl+dDmaJhCOw1st3bxZ34HruNSWhZhSWcjskUW4riBgeOesq4oC+DQ5z/pIN1hrAhgiCXY2m2XVqlXE43Fqampobm6mtbWVTZs28fDDDzN9+nQuuOACLr300kGIWjpUf3pjO39etJ3VO2M4Duiawk0XTe2bIyyEOOQS8j07hRacfDkF8z+8T2dN1939CxqhUHzGZwlNP3P3k+NUjM7nHyC5+lnUcAlFp32G4OQFXnOSbMprXvLm/+HmUgQnnkjBCR/BKN/jae3if3pPax0b/8gZ3pzI8fNQdL/31LdpA6kNr5Ha+EbfU2k1Woa/bgb+2qn4qiajlVTvG7dtYnW3YcdbsROd2KkYTiaBm0vjWFmwLVzX8b5O1fqdt1KDBSihIrRIKVq0rN+Nu/d7YnU1kd2xjmzDGjLbV2F1NQGgFVQQGD+P4IT5fZ3FXcskvXkRiZVPka5fAkIQGHc80Tnv6xs9YrY3Elv4V5JrnkeoOpHZFxKd+yHUQKTvoUTns/djdTURnHACRWd8tt956uT6V+l46h6cbJKiBZ8kctwl+/9ZLvs3Hc/eh1ZQTtn7b8AoG4kCGLrCdy6cyh9e38q6Xf0bvO1NVehpTNJ/4XH8qCIWb+vsO2d9bc/5a+nI9qUvfYlcLsfNN9/M2rVr2b59O2vWrMEwDK699lqKi/ff6+CdyAR7cMg1wfBj2g47u9KYtkNzd4aRJUEcx2sMGfZ5lQeKgPW7utncEmf5tk6SOZsJFREMTbC9I82b29rpSmQxFIXSAh+m7VDgN+hOZ2nqMvGrYJk56p/9I7te+QvhmglUv/9bKOEyUns9qzY7dtD22E/INW0gOOlkis+6sq+Bluu6JNc85z18T8cJH3M2hSd/ou/j6S3L6Hrp9+R2bUQrrKRg/qWEpp6GUHXsdJzEiv8QX/o4drzVG905/SxC08/sa7xpJztJbVxIasNCMttXgG0hNB++6sn46qbhr56CUTm+r39ML9d1cVIxry9Lon2PvixJ7wG7lcPtSaaFoiE0HaH7vSapgQhqqBA1XIJWUI7ij+y7w59Nkt3plYlnGlaR3fmWF5vuwz/yWIIT5hMYN3f3KNS27SRXPUNi9bM4qRhaUSWRWRfuruozsyRWPkX3G3/Hjrfhq51G4YJP4a/xRlha8Ta6XvgdybUvoEbKKD7zs4THzyegCCoKdZp2NNHw75+TrF+Cf+QMSs6/Zp8eLCqgdDWy45Efkmlr5GNXX8cN132dgM+gIKhTENBJpE2Wbe9k7Y4OXt/Sheu6jCsPMW9sOaNLQySzFkUhA9NxSecsmjpT6JrKhIoIYb9OezKHoSqYjkNNUbBv3rt05BqsNQEMkQR7T6tXr+bhhx+mpaWFyspKTjzxREpLSzn77LNpaWnpO48lHRn2t4Pdm7DMG1MyoOe0HTNDx39+QXLtC/t0Ct2T2bWL9sfvJNuwuqfJ2X+jRcv7Pp5pXEfnM78k17wZo2oiRadd0XcjsNNxuhf9g/iSR3FzaQJj5hCd+0F8tdMQQmAnOkmsfIrEyqewYs0II0Bwwgm7Z0r3nKs2u3Z5Yzi2LifTsBon3Q14T5ONijEYZaN6SsFq0IpGoIZL3rYk/J30G8vR3uiVibduIddcj5PxKgkUXwhf7VT8I48hMHpWXxdR17HJNvbM7n7rVZxMvG+xED72nL7vXXbnerrffITU+tcQmkH42HMpmPsh1LC38+s99f8NmW0r0IprKD7z8wRGz+yL0U520vH0r7wmZxVjKbngmr5u5f1+zrk07U/+oq8jbOmFX0Xxh7lywRgiAZ2ioMHqnTH+tqSRnLXvzn8vgbeQc/HOWI8tC9OZynHJsdWcNXXEPues5Q72ke+6665jzpw5fYlVNptl3bp13H333ViWxd13300w+O5nzMsEe3DINcHwk8xaNHdnCOoqG1vjVBUE0BSFoE+lLOKNoXRdl7eau2mP54ilcsTSWYrCfhJpk85khtc3t9HYmSKWsqgpDlAU9Lo6b2tPkTVNWhNZYmnwq5Da/Brb/nEnQtOZ/OHryI2YQXqv24Lr2HQv/Btdrz6E4g9RfOaVBCed1Jd4OpkEXa88SHzpvxFGgIL5HyE6+0KvGst1SW96k9irfyLXvBk1Ukb0uEu85l6+oNdVe9ObxFf8h8yWZeA6+KonE5y8gNDEk/ruj04uQ2b7SjJbl5HZvqqvJ4o3trMWvXw0RtlIb2xnURVawYj3NLZzT042hRXrHdXVgNm6jVxLfU/TNUAoGOWjvU2A0TO96Sc9D+mt7havTHzti+SaN3s72uOOJ3LMufhHz/RK6tNxEsufoHvxozipLnw1Uyg48WPeWkgInFya7jcfofvNR8BxKDr+AxTOuxRfwE9QhYzl0r36KXY89Rtc26b81E9TMOsCsvsZS7bnhJFTP38zH37f+Xxs3mgiAZ2MaZPKWXQkszR1ZVjZ2EXOslGE4MSxpYT8OkUhg5buDIoiKA75SOdsqnvGcPX+nYylTbKWQ0FA7/tz6cg2WGsCGGIJ9iuvvMJHP/pRrrjiCj796U8zcuTIvo+df/753HLLLcyZI9dAR5I9Oy7DgUcY/emN7QOSZLuuS3zpY3Q+95t9OoX2/zyH+NJ/0/Xi/4IQFJ78CSKzLug32iK5+lm6Xv4jdqKDwIT5FJ78CYxS7xyunUkQX/oY8SX/wknFMCrGEpn9PkKTT+656fY0/1jzHKn1r+HmUij+MIGxxxEYN5fAqGNR+s5iu1gdjWR3vEV21yZyzZsw27bj5tK7A1ZU1FAxatg7a6X4wyhGYI8ScS9TdG3Te3ptZnAyCexU73mrDnB2j/YQmg+9rA6jfCzGiHH4qiehl9b17RQ7uTSZbStJb3qD1KY3cVJdCN1HYNw8wlNP826gioprW6Q2vE58yb+8MnlfiMjM84nOubhv5IjZtYvYKw+SXPMCij9MwYkf9c6t9Txs8Er0n6Hz+d/gmBkKT/io1+RsP03ecq1baf3Hj7A6d1J40uX9Zpr+8P3TmTgiss8DnT2NKwuxvTONbXuVE7bj9o2T23uXWp6zPvo8++yz/Nd//Ref//zn+eQnP0ltbW3fx84880zuvPNOpk2b9q6vKxPswSHXBMNP1rLZ0ZVGRdCdzmHoKgUBndKwD22PHcFk1mRnV5rNrQkyORtNFTR2ptCFYNm2Tho608SyOSpCBpMqo+i6QmsySzJpsqyxi2wODA1MC0aIDhb/9ia6mrZQftJH8M/76H4fWOdat9L++F3kdm0kMO54is+6st/D91zbdu8hcf0SrzT8pMv7eo24rkumfgmxhX8l27gG4Qv1HBm7oG/H2upuI7n2eZJrXsBs2wYIfDWTCYybR2DsHPSS2t0Vdek4uZ1v9Yzn2uyN7exu7Rev4o+ghotQgoXe6E4jiGL0jO7sfbDkOLhWDsfM4GZT2Olub/c70YHb2/Olh1Y4Ar18NL6KcRiVE/BVTewrFXddF7NlC+nNi0hteoNc0wYAjBHjCE05jdCUU/ru+bm27cSXPkZy9bNemfzoWRTMu7RvI8KwTZKrn6L55Yewk12EJ55E2WmfJlQ4AlWDtAnZjh20P3U3iW0rCY+azqjzv4RSUkki58227qVaJq17TBipveQ6Jo+v4+JZdcwbW8rY8gg7OtO4uDR2pOhMZdnamiLs04gENM6fXonlChJZC1wXy3YJ+TUSWYuaoqBMpI9yg7UmgCGWYH/nO99h1KhRfOYznwEgl8uxfv167rvvPlpaWvj1r39NNPrexvJIg2PPjsuKIphSGeUjx9Xtt2FUUdDgz4u2s6Ixdsiv269T6IJPED3+A/uUGYN33rj9yXvIbFmCMWIcxWd/AV/lhL6PO7kM3Yv+j+43H8E1s4SmnELBCZf1nbHubfAVX/woZvt2FH+E0PQzCM84uy8Zd60c6S1LSa1/lfTmRd6OsVDwVU3EP/IYfLXTvBvZHmVgrutix1sxO7xRHFZ3C3a8HTvZiZPuxskkcHIZXCuLa1s9Y7pAqLpXIq77d4/k2KMcTCuoQCuu9sZy7PH9cMwsuaYNZBpWk9m+kmzjOnAshBEgMGYOwYkneh3Fe56Ym51NJFY9TXLl09jJTrTCEURmXdT31B68J9yx1/5CYtXTCEUlMvt9FMz7UN+DBfAWNB1P/5Jsw2p8NVMoOeeLfU3m9uSV9/+HzmfvQ/iClF30DfwjZ/T7HKVn5voza5sPeKb/+FFFXDKzhidWNzG1MsoDr27BtF10VfDQ5+bLRHoIWLZsGX/9619pa2ujurqaU045BVVVufTSS9m1a9d7uqZMsAeHXBMMTxnTpiuZpStjYagKhXuMMnJdl9aeucKqEF65eFOMnAshQ6MtnmHZti4aOhOYpvebfsKICLWlIQKq4InVu6hv7qQjBVnHq1Aq8kORz2Xzo3ezdeF/8NdMpeSir+1TZgz0b7KFoODEjxGd875+D3zTW5fT9eLvyO3ahFZcQ+GJH/V2vHuS9r4GXxteA8fBP3qmNyt63PF9EzByrdtIrX+V1MbXMVu2AF6PFP+oY/DXTcdXMxWtoKJf+baTTWF2NGJ1NmHFmrF6RnQ5qe49RndmvAftPWO6hKIiVK1/iXiwwGuYGi1FKxiBXlSJVlTdb0fcdV2szp1kGtaQbVhFZutyb342YFSOJzh+vtegraexrGNmSW14zZvx3bAaVJ3QlFOIzrkYo3y0d03bIrH6OVIL/0ymq5mi0dMpP+1TaCMmYdqgq4CVpfnVv9G+8K8IzUfFqf9F4cyzURQFTelJvgEDsDt20PzvH5Pc6U0YKT/10/gMnenVIc6ZXkdpVGdsaYT2ZI6oX6O+NUFDe5J1uxIYmkJFoY9LZ9VRWxIinjFRgc50jkTGYkRBgPKoX3YBHwIGY00AQyzBvummm1i7di0///nPWbNmDTt27GD16tUIIbjmmmuoqKgYwGilgbJkWyePLG3kr4sbsBwXY4/RRn9f2sjfljRi2Q6G5p2dvfEfqw56JvbbebtOoXtyXZfUWy/T+dz92IlOwjPOonDBJ/uexoJ3Prv7jb8TX/pvXNskOOlkCuZf2lfC7LoumW0rSCx/gtTGheDYGJUTCE09jdCkk/rGdLmOTXbHOtI95eG5XZu8BmZC8RqPjRjfUyI+Er2kFiVYOOC/4O10N2bbdsy27V7jlF0bybVu7WmAItDLRxMYdSz+0bPw107tWxDY6Tip9a+SXPuCdwMVCoExswnPPJ/A6Fn9xmp1v/F3EqufAyB8zDkUzP8wWqSk38+m69WHiC/5F4ovROGpnyY846z9PgSx091eg7oNr+MfNZPSC6/t+37u7Z2amwm8HgCW7aApAhewbFc2MjvKpVIplixZQi6Xo7q6mpaWFtra2tiwYQN/+tOfmDt3LhdeeCEXX3zxe7q+TLAHh1wTDF87u9Le/GtFkMrZ1BUHSZk2iYxFd8akKGj0nInViWcsdsbSRA2dbe0pWuMZljZ0YNsuIUOhuiiI6bh0Z0zqmxPE01m2daRIehOmCPigrtiPKjSWPPMo25+4d3ejrIkn9ourt+mqFWum4+lfkt68CL2kjqIzP0dg1LF9H3ddl9SG14i98iBm2/aesZ2XEpp8Sl8ybsXbSKx4ksTKp7HjbSiBKKHJCwhNORWjamLfvd3qbiFdv4TMlmVktq/sO7alhoowKsdjVIzFKBuNXlqHVjhiwEd4upaJ2bnTWxO0biW3axO5XRv7jq4pwQL8dTMIjJ6Ff8xstLB3btV1HbINq0mueYHk+ldxs0m0whGEjzmX8Iyz+s6qO2aW5Opnib3xd+xYM9Ga8Sy47GoKx82kPZEjbdl0JnIkNiyk4an7yXU1UzT1FMrPuAIRKKI4omBZDpYAOwedpkt61dO0P/NrVF1nzse/QfGkE9EUl4ChMGtkKTUlERzHYUxpCAtBdypHZzKHZbtsau2mJGRQEvEzoSJCRdSPpiq0J3MoQNin4dNVqgoDMsE+Sg32mgCGWIKdy+X48pe/zJIlS5g71xvXM3nyZC644IJ+pWHSkWfPUnFVwISKCG/tivfbaewt0x3I2dm7G5vdh1ANis/5777Ok3tzsim6Xv0T8SX/Quj+nnNWF/Wbp2wnu+h+8xHiy5/AzaXxj5lN9Lj3950p6v2c5JrnSKx+zjtHJRR8tdMITphPcPzcfiVnTjblNRjbsY7szvXkmjf33dTAO5OtFY5Ai5ahRkq93ehAFMUfQhgBhGrsLhHHe0LsWlncXAYnm/TKwZJdWIl27O5WrK5dfTdv8M5dGyPGeuVg1ZPxVU/pd27dTnSS2vQGqQ2vk9m2HBwbrbiG8LTTCU09HS1a2vd9zu5YR3zRP0hteB1UrWcsx4f67RL0dWV/+Y+7m8Ys+GTfjXhv6foltD9xF3aqm8IFnyR6/L4Nz96t3gVS723zQCXi0tHj/e9/P5FIhHXr1jF58mSmTp3KjBkzOP300/H5fIfcVVom2INDrgmGr/ZEllja7NmldmmNZ9nSnqKywI8LjCwKknNcyiM+wn6d7lSO1TtiLG3ooCzow/KeS1MU1OlOm5i2w6rGGLbj0JHI0p7M0dKdpjvnoAnvOXbYgJQFXbt20PTY7eSaNhKadiYVZ34O2xf0GmUB5h5xpja9Qeez92F17SIwbi4lp30GtaeCDbwkM/XWq8Re/7M3tjNSRnTORYRnnL37GJhjk9mylMSqZ0ltesObzBEtJzjxBILj5/U1Ee29ntm6lWzjOrI73yLbtBGrYwd9s1aE0lONVo4aKUMNF6EGClACERQj6E0X2V+JuJXFzSax0wmcdMzryRJvw4q19J9UIhT0klqMyvH4qibhq5nSr3TdtU0yDWtIb3yd1IbXsRMdCN1PcMJ8QtPPxF83HSEUoiok4jE6lj1OfOm/cVJdGJUTKTjxMkaMm8O4Kj8FAT87ujLs2rKRHU/dR8emZfjL66g770oKxx6LhkPOhoChIgSYlkMm1cXmf9xN91uvER41g2M++nVmTBxLeTRITXGQoK6ytT1JJKBREDTwaRqqIshaNjVFQWIZkzfr24n4NWqKAkyvLiLi1wn5NLZ3JAnqGqURH4mcxcjiEKoiE+yj0WCvCWCIJdgA6XQaRVHo6OggGAxSULD/hbl0ZNmzVBwB9l5NRvY+m/2jx9fxy5fq9/mc9/q31usU6t1Qg5NP8TqF7qcBGoDZ3kDnc78hXb8YtaCCwpM/TmjKKf0SOzsdJ77s38SXPIaT6kIvrSMy60JCU07tN94i17qV1LqXSa5/ta9zuF42ynsSPOpYfNVT9inLshMdmG3bMNsbsDp3evOru1u9UrDM23fF3h/FF/JmZUZL0Qoq0IuqvFFdZSNRI2X9ntC6Vo7sjrdIb1tOZstSb4cdr6N4cOKJBCcvwKgYu7v5i5khte5l4sv+TW7XJhR/mPDM84nOvqjfLrPruqQ3LqTrpd/vnit++hX4Ruw/oXWyKTpfeIDE8v+gl9RRetHX+uZrHgpVAVVRsG3Hu3EKgW3LRmZHsw0bNnDhhReyYYN3JvCNN97gySefZPXq1YwfP54f/OAHh/waMsEeHHJNMHw5PTvOrgudqQwvb2jDchxcFyaUhxlVFiHi1ykM6gghSOcsXtrQQn1rglTWJupTmTeunJBfpStlEkvn2LQrTmciS3N3GsNQ2d6WYNWOOMmejNnXU+HkAoptEXvzIZpe/itGtISq879CoO4YLLyP75lku1aO1KJ/0L7wr97c5mPPpeCEy/a9x9UvpvuNv5NtWI3Q/YSmnU5k5vn9mnU62SSpDQtJvfUy6W3LwbZQ/GH8o2YSGD0T/8hj0Ar6V144ZgazrQGzfTtmxw6vRLzv2FhXv/4qB0XVUEPFaJFS1IIy9MIqtOIqjNI6L5neY9pIX6n4thWkty4js3U5bi6N0Hz4x8wiNPEkAuPm9lvHiKYN2G89SePS53AtE/+Y2RQc/0F8dd6kEQUoMiDixNj5/INsfP0JVH+YkWdcztTTL8ZyVRRXwRGQsRwKDAXTFTSvfpVVf72DXCpOycmfoOT491McUjhu7AjOmlpBWdjHsoYuOpM5ElmTqqIgUysL6ErnaI1nmV7jndWvb0lSEfajqAozqgpI2w7JrEXWslEQ6JpCyKdRET20JnJSfhyONQEMwQR7b67ryhKOo0TvWes/vL6VXd3Zfh+7fG4dU6sK6EzlKAoa/PqlzWxtT/X7nGNrCiiP+nlmXfN7mnHs2haxhX8l9trDKIEIJWd/geCEEw74+ekty+h84beYLfXoZaMoPPnj3izLfgmpSXLdi8SX/Itc82bvpjp5AaHpZ3pPpff4XLO9gdSmN0nXLyHbuNa7KSoqRsVYb/e4cgLGiHFeCdgBOoa7toWT9sZxOLm0t1vt2N4ZbPDOWmkGih5A+ENe4xNV3/+1XAcr1uKVgzVt6GmytgFsq+eM+CQCY2YTGHc8etmoPeZ1u+SaN5Nc9TTJNS/gZJPoJbVEZl9EaOrp+zwwyGxdTtfLfyTXtB6tuIaiUz5FYPy8A/53m966nPYnfobd3Ur0uEsoXPCJfcaLvRcC+OjcOj44q6aveRkgG5kd5VavXs2XvvQlbr75Zk46aXfn31WrVvG1r32Nq6++mosuuuiQXkMm2INDrgkkgHW7uli5PYaqKLR2pzludDFTqwvpSOZQFNE3YunljS3Esxbrd8UJ6irzx5VQUxQCBB2JLDu7EpiO1yDNtAUtXQmeWdtEIueSykIOb2ZyxO/dDwpCOrFtb7Huzz8m1baD2nkXMvHC/0fc1mlJwd4zKOxkJ/FX/kRsxZMITafsuPfhn/MBxF4P67O7NhFf8i+S614C28RXNYnQjLMJTTqp/3zpbIr0lqWkNy8is2Vp3xlnNVKGr2YyvqqJ3rGx8tH7jOrq5bqutzOdSeDmUji5LNgmrtt7BlvxerPofhRf0BvLZRy49NnJJsk115Nt2kiuaT3ZxrX94gqMmeXN6x51LIq++15vp+Ok1r1IYuXTXjd1w0/NnLMomH0+3cHafpsjdqKT9Jt/oWPpEwghGHniJYxYcCk5JUhByE9QdYiG/HSlTRIZEy2XZMOj99K45BkKa8Yx6UNfxSnyrhn2G0yojHLKxHLGloZ4c2snldEAXakcCKgtCdHRnSFnuWiad95fUwWVhQGCukpVYRC/rmA5LqoQmD0PeXyaIn+PHKUOx5oAhnCCvXTpUhYuXMiVV14px3AcZfbenb7k2Co+MX9Uv+7PB9qt1lSBbbvveScbvHFRbY/fidlS781hPuvKvjNFe3Ndh9S6l+l65Y9YnU0YFWMpOOEjPQmissfneXOu48v/Q+qtl3HNDFphJaEppxCcdHJPh+49mpXkMmQb15BpWE22cS25XRtxrRwAQvd5YzmKa3oaklWgRUu9XehQIcIIHtQvftd1cXNp7FSXVw7W3eo1Ruls6hnVtUenclXDqBiLv2aqN6qrdhqKL9TvembHjr6xHGb7dlB1ghNOIHLsuX3dQfd87czWZcRee5hs41rUSBkFJ15GePqZB3x4YKe7vTnkq55BK66m5Lwv46+Z8o7v82AZ8pz1kPWb3/yGRYsWcc455zBt2jRqamoIBAL87Gc/Y/Xq1fz6178+pOvLBHtwDOSaIJVK8Ze//IXp06cze/bsAbmmdHjE0yavbGylM52jJGhwwrhSOlImmiIwbYf2ZI6SkMH6Xd00d6Vp6EhQVxohYKjUFoeoLgqyobkb14V0zsavCdoSWepbUry0fgepnENrt41Ph5KwQVfGJGyoRPwaQZ+G4Vi8+civ2PjC3wkUlFBz/pXYtXP7dasGLzkPqGDFdtD0woN0rXkZYfiJzL7Im5yx11EnOxUjufpZ4iuewupoRGg+AuOOJzTlFK9vyV47xWbbNjLbV5FtWEN2xzrsRHvfx7XCEegltWjF1eiFI1Cj5d4OdKgIJRA56HPZrm3uni4Sb8fqbvHWBB07MNu39+tUrhZU4KuehL92Gv66GWhFVfusY9KbF5Fc9yLp+sVgWxjlYyg+9mymnnIuJQVFbGyOk8x5jcnseDvJRY/Qtew/uLZJ+ayzKD3pIxSVjiAaNGhLZKgoCGE7FomMha5q1L/5FDv+82vsTJLjLvkM08/7JAnToSOeIZG1KAj6GFkaYmJFmJljitnWkqK5O0fErzJvdAmqBs+vbQUBmZzD9NpCRpeFaE/mmFgRpSRsyER6CBrsNQEM4QT717/+NZ///Oe5//77+X//7/8NyDWlw+dHj6/jP2t2ce7UEXzz/Mn7jPMabK5t0f3mI3S9+hBC1Sk65ZOEjz3vwDvHjk1y9XPEXv8LVlcTekkd0bkf8ErH99ohdnJprxnYmufJbF8FroNWXENw/DwC447HVzVxn9dxbaun6dgmcq3bvCZk7Y3Y8Tb2edSgqCi+EEL3I7Q9zmC7Ljg2jpXDzaVxsqn9lo6p4WK04hqvHKxsVE8DlVH9zpqD93Aht2sT6U2LvE6nPXM5fTVTCE09jeCkk1H36Aje+31KrX+N7jf+1jMTtJSCeR8iPOOcfa6/+3Vckmueo/O53+BkEkTnfpDCEz86ILvWvXp3r3/4/ukDdk3pyPK73/2OV155hdLSUoqKimhra+Oll17ihz/8IWecccYhXVsm2INjINcE6XSakpISxo0bx8qVKwfkmtLhk8vZJE2LgKHh0xS2tafwaQpZy2Z7e4qRJSEytsXrm1ppaE/hNzSChsq8MaWEfDo7YymsnM2yxg6SGYfCsEFxwGB7V4LG9gztiTSGqqJrKru6MkyuCpHKObQnsiAEHfEsuV0bWPWXn5DYtZXopBOInPa5vj4jBhDyQ4FfxXEVFCwatmyh+aWHSa5/FaEbhGecTfS4S/Yp8XZdl9zO9d7YzrdewUl3903oCIw7nsCY2aiBfbvdW/F2b03QXN+zJmjA6mzCtbL7fK4wAt7oTn3vMV2211XczHpVb3uO/+z9Wt2PVlSFXlqLUToSo3wMxohx/Rq99rKTnaQ3Lya16Q0yW5bhWlnUcDHBSScTnnZG31GuoICQD+IZsNobaXvzEbrXPAeuQ+XMM6hc8BG0oioypkPIB1F/AMd1ifp1dAXqt9az4R8/p3PjUopGTuZjX/seEydOpjtj0dCWIm7atMWzTKmM0JrIEvZplEV8zKotIeTXyDkuVYV+OuJZljV0EvTpaIpCwBCcOK4Cv67IJmZD3GCuCWAIJ9jZbJbCwkJGjhzJunXr5H8kR7neM9o508GhfwOqPYmere2B+ttrduyg46l7yWxbjlExluKzrsRXPfmAn+86Nqm3Xia28G9eQ5NQEeGZ5xM59tz9dra2E52kNrxGasNrZBpWg2N7563qZnhnsGun9Wsgss/rWTlv57n3DHaqC7unRNzNpXEtE9fpHdMlvJEcmrHXSI5CL6nuOW+1Z1lXv9fqG8uxmsy2lWS2LcdJxbxy8epJBCecQHDiifsdbWJnEiRWPEV86WPY3S1oxdVEj/8A4amnHzCxBsi11NPx9K/INq7BqJpIyTlX9430GCiKoK9zvdy9HloeffRRXnrpJTRN46qrrsJ1Xd588006OjpoaWlh1qxZXHjhhYf8OjLBHhwDXSL+xS9+kV/84he88MILnHLKKQN2XenwS2ZNWuM5EhmTjmQOF/CpCrGecUr17QlUIZhRW0hhwGBTazeLN3fQHM+QsmxGFAQYWRyiJKQzoiCIImBFQxebW+Ps6EzjuC4FQQPLsogEfDR2JEhkTdpiKRpeeoS2Vx8GoVB0wmUUzLmYAkMnFFHwo+IKm27LJSBcFAFb6xtoe+NvJNe+AK5LcMJ8IrMvwlczdZ97u2tbZLat8NYFG9/ASXWBUDBGjMc/6lhvRFfVpH5Hrfp9vet6u8/drX2jO+1UrGd0Z8+xMduEPcZ0oeoous9Lwv1hb0xXqAg1UoIWLUcJFrxtyXi2cR2Z7SvJbFtBrnkzAGqklOD4eQQnnuC9z703DVyH7NblJJb8i+TmRQjNYMSsszj/o1egRstYvL0L04aAAeXRAGdMrKAzZbG9tYMNTz/Iwn/+HkUzOP5DV3HShR9m/IgCsqZDxnRxFXAdh43NcUYUBnhjcxvjK8J0my5VYR+TKqOMLg2hKAqJrEk8bdORzOD3qUytLGTiiAghny4bmA1Bh2tNAEM4wQb46le/yk9/+lOefPJJzj777AG7rpQfS7Z18vfecV62i6II7ANsaauCA47yOn5UEYVBgxc3tJK19j5FtS/XdUmte4nO53+DneggNPU0Chd8qu/J9YG+JrNlKd2L/0lmy1JQNIITTyRyzDl9jTz25mQSpLcsI71liTdTMt4GgOKP7D5rVTEGvXz0PnOqB1rvnO1cy1ZyzZvJ7dpEdud672YPKKFCAiOPxT9mNoHRs/bb5bv3qXx8xX9IrXsZ18riq51GdM7FBMYdf8BqAPBK57pe/iOJFU+i+MMUnvKpA47pOhSXHFvF+IqIPGM9BN1999386U9/4pprruGJJ57glVde4bjjjuNrX/saxx577IC+lkywB8dArwm2b9/OyJEjueCCC3jssccG7LpSfriuS1Msg+N4D96zpo1pO7y1M+6dtVdcciakslmWbO9kRyzH6OIA7ekcldEAY8vD1BYHmT+mlO6sxeodXWxvS+Hgsqszw4SKEFnLoSWeYXNrinjWJJ7MkjJtulsbqP/3fcTWv0GgpJrx519BdPJ8MqaNKiBoCJJZB6Fo2LZFLA2p7jbiS/9FYvl/vN4kZaOIHHuu1/x0r2ov7/055Jo2kq5fTHrLUnJNG/vGdhrlozF6+rIY5WPQS2sP+HB8oDjZlFdJ11LvrQmaNmC2bgNcUDR81ZMIjJ5FYMxs9PIx+13n9JXFL/8PVudOlGAhkZnnUTv/AqoqR1BbEiSkarzV3E0yncXQNcI+nQUTyti57Dn+92c/oqOliblnvY8Pf+EbTBpTR0fSxKepZE2HHV0pLNsh5NfRNcjmHNbs6EIRKi4O48sjFIV8TK8pJGU6qEBZxE9TPEVtUYjx5WF5rHSIOpxrAhjiCXZTUxNVVVWcffbZPPnkkwN2XWnw9TY82zvx2bNUXADHjSqiJZ7dp+HZ2zl7SgUvbWwlY75zcr0nJ5cm9vpf6F70D2/MxHGXEJ37wX5NSfbHbG8kvuzfJFY/1zcHMjT1dEJTT0Mvqtzv17iui9XVRLZhtddcbOdbmO2NfaMyhOZDK65CL6xEKxyBGinxzmAHe8Zx+MIohlcijqohhOI1NbFtXDuH01Mi7qS7sVPeSA67ZySH1eWdt9pdKibQiqvxVXmjuvw109BKag74RNvqbiW59kWSq5/FbG/wGrtNOYXIrAswyt++07djZokv+Rex1/+Ca2aIzLqAghM/dsCO7gdL4O1UK4rAsl2EgM+dPIZvnn/gagTp6HbBBRdw5ZVX9jUryWQy3HXXXfz617/mmmuu4eqrrx6w15IJ9uAYjCZnl1xyCf/85z/ZvHkzY8Yc+uQB6fDImDat8Sy4LmURP37De0DblcrRnsiSsWx8qsLIkhA7OtOs2tHFG1s6UBWHLa0pigM67UkLVQG/rjC9KsrI8jCl4QC6qmBoCrZts6whhgrEszbTa6NURAK0xtOs2xFja3uSLe1JkqksbSkTv6ZhN65g6V9/Rrx5O+FRMxh97mcIVI3DrwrCfo2M6Z0hz1nQ1XMiy8llSK59gcTyJ7zmp5pBYPx8wlNPwz/q2AOel+4d25lpXEtu5zqyTZtwc7vXPmq0HL23L0tBeV9fFiVYgOIPeyXims+rGut9WO06XqWbmcHJpbyd7p4z2Fa8zRvfGduF2bGz76E/eA/+9xzV5auedODqN8skXb+YxJrnSG9aBI6Fr3qyN1Fk4klomk5lgU5FYYCigIGNQ3Msi+I6FIQNso3reP2hn9OwYRUTps7g2hu/z6kLFpDM2biuS0t3BlVRSGUtmuJp6opClIZ0dsSy1LcmMVSXloTXIHdmTSFC8cq/gz6VkqBBznHx6yp+/cAP/aWj3+FcE8AQT7ABLr30Uv72t7+xbt06Jk2aNKDXlgZHXzm45aAIwS0XT2PiiAgL69spChrc9Ohqcj3b04am8JkTRu0zsuvtHKhBmsBrkmYeaOu7hxVrpvOF35F662WUQJSC+R8hMvO8dzwT7JhZ7+z16mfIbFsFuBiV4wlOPInghPnoRVVv//W5DGbrVnKtWzHbGzA7GrG6dmHFWsA23/ZrD4qqoUXL0ApGeGO6SmrRy0ZilI1+x4cIVneLN1pk/SteB3TAVz2Z0LQzCE1e8I5f79oWiVXPEHv1IexEO4Gxx1F06mfQS2v3+dz9/fxUZd/Rbr10VXDpnFo+OKsGkF3Bh4tf/vKXrFmzhhtvvJHy8t2z5Tdu3Mj3vvc9brvtNioqKt7mCgdPJtiDYzDWBIsWLeL444/nC1/4AnffffeAXlsaHK7rsq09RTpn05nKoiAYUejHRVAU1DEth+Z4loCm4DNUgrrC6p3dvLaxBYRgxfZOSsI6XRmbgID548sYXxHxjjfpChtbugloGtVFQZpjaRwXKqM+okGDjlSOXNaiviNNOmOxtT1J1rIJ+1VwBS3xLGY2y+rnHmHV47/DSnVTMOVkRp15ORU1Y2hL5khlvCZee96ifIANJHdtIrHyaVLrXsTJJFACUW/+9YQT8ddNo8DQUQRkrX2v4boOVtcuci1b+vqyWF3e6E4n3T0g33slWIhW2Du+sxa9tA6jfDRqtPxtj166Vo701uWk1r9GauPruNkkSqiQ0JRTCU8/C6PMm0Nv4N3PNRVGRBVKQgEMQyNkaKR2bmHRI/eyccnLREsq+OhVX+XUCz5IMuetDSuK/BQHfCSzFqmcRXnER8aymVARZWNLnJzlIHDZ0JKgIuqnPBygPGpQVRgg7B+4Hi7S0eFwrglgGCTYy5cvZ+bMmVxxxRXcd999A3ptaXDs3dBMVQSqAMtxMTSFBePLeHptMy5eKfgJ40p5eWPb215zMGSbNtD14v+S2bYCNVxCwfxLCc84+6Cab1ndrSTXvUTqrZf7ZknrJXUExh2Hf/Qs/NVT3vZs8p5c1/V2ohMd3lmrdDdONoVrZnCtHK7dO7lTeGO6VB1h+FGMIEog6p23Chf1nLM6uNIo17bINq0nU7+UdP3ivnNXeulIgpNPJjT5lAPuzu99neSa54m9/mesrl0YVRMpOuXT+OveudlYccigI5l72885flQR1503WSbTw1BzczM33ngj2WyWT33qU5x44ok4jkM8HmfOnDmsX7+eQGD/o23eLZlgD47BWBMAzJw5k+XLl5NOp/H75SzbI53rutS3JmhNZPFrCp0pi6hfo6YoQDJn49MUXNd74J7MeslvWyLLm/UdxNJZ2rvTtCRMEKAKhbljiyj0+ehK5WiOZ0lmTMIBlZKwn1mjihAurNwRw3GgJZZG1QQFQQPbdjEUhaxt09KdoTRssKMrS9in0pnIsa5hJ6ufeIgtL/8dx8xROP00CuZ+mILyamI9u9cKYChQHBR0pV3StvdnVs8ub3LdS6Q3L8I1Myi+ENFxsygYN4dA3bGYwRIEXpJtAyreA+f9Tbl2zIy3Jkh2eWuC3jPYZs8Z7J5qOISC0HSE5vPWBP4wajCKEipCCxe/q2aiVqzFGytWv9ibh21mEL4QwfFzCU0+xdud3+NomA9QVXBt0DXwG4KwT6fUbGb14//LutefIRCOsuCDn+GYcz9MbXkR7XGTSVVRAppCQ3uKEyaUkcxadKVyzKgpwrQdVEWwM5YmYzok0znW7OymsihA1K+zYHw5JRHfe/hbKB3tDueaAIZBgg0wd+5c3nzzTRKJBKFQ6J2/QMqrJds6+civXsfa43x1766lKuCy4+v4+9JGTMtB1xS+c+HUfrvah1tm20q6Xv4j2R1rUcPFRI97P+FjznnHXdteVqyZ1IaFpDe/QabBm38tNB++6kn4aqf1zMAev89YrMPJyaW9M1c71pFpWEN2x1qvhLynwVlg3PEEx89HL64+uOuZWZKrnyX2xt+xY83eeLOTLicw9riDbkioClBVBdt2UJT9Vx6MiPq4+/LZMsEeplKpFHfffTcPP/wwNTU1jBo1ijVr1nD88cfzwx/+cMBeRybYg2Ow1gR//vOfueyyy7j11lv5xje+MeDXlwZeMmOxorELXREgvGqzEdEAacumOKjTkTQRAvy6SlnYR1siR1siQyyVpbEjyVu7khT4NXbFM5SFDRRFpSORxVBddsWyZG0XRYGTx5dTUxRkRUMX1YUBkqZFY3uacRURWuMZaooChHw6jZ0pyiIGy7Z34VddMg5sbEqQzpo0Nzez/tmHaFv8OK5tUTDlZIrnf4hg2eieGMHQNTTFpS1mo+pevuu4kHTwkuBty+neuJBU/WLMhDdnWiuuwV83DV/NVK8vS2ElWs/9cgBq2N4V13WwOnaQ3fEWmca1ZBtWY3U1Ad487IJxcwiOm4cycsY+k1TA27muLdKxHOhKmQR8KnbzJrY9/xCta15D94c478Of5gtf+hL4wjR1pSkIamxsTqJrClUFAQK6QiRg0BHPMrosTGHIIOL3uswLF7qzFq9uaGFbR4qqwgCOC5MqwowdUUBBQDYxG44O15oAhkmC/c9//pNLLrmE733ve9xwww0Dfn1p4P3pje1855+rcRwXTVPAdbEdF72n2zP0L/Vdsq2TX764mWfXNQ/oKC+lZ7rVO13SdV0y21YQe/0vZLevRPGFCM88j8jMC9+2GdrenGyKTMMqMltXkGlYhdmyld5X14qrvfEYZaP65l1qhSNQ9IF7GutaOaxYc8/MywavHL15C2bH7vPfekkdvrrp+EfOIDDymP02ZzkQO9lFfPkTxJf+GyfVhVE5kYITPkxg7PEIIRhZHGRbx/7P0x9bU8CKxljfz0IBLptb1/fwpbEjxUv7qWTw67JD+HDjuu4+D2qeeOIJcrkckyZNora2lmDw4B6AHQyZYA+OwVoTWJZFdXU1qqrS2NgomxodJTKmTSxtIgDbdbFsl6KgTtivk7VsHAd8moLSkzi5rsv2jhS7YmmeXtuEZbsk0iaVhX6EUEhkbRKZLFvbkpSGdRIZh6BfZ1RJmGlVUboyOXZ2pchZLooQlEV9TKwooDzipz2ZYd2OGG81dxPxGwRUiAR87IylaOxMs25nB91dnbS8+n90LH3c6ykybjYjT3o/VVPmoKkqOVeA42DaLvGUScivkLUcUjlvlztlguK4dLVswdy+jNS2VaQa1vadu1b8YYyKMYQqRiOKR6KW1KAXVb1t1+93q7crudW5E7O9EbNtG7mWLeSa6/eII4Kvdir+uun4R80kWFJL2BDYFqRcb6d971ozH1BbbBDSBduWv8LWlx8htmUVeiDMsedcxvyLL+esY8dSUxikOZ5lVWMXWcemwKfhOIJJlVEmVEXoiHtVCBNHRNAUBb+uYrsuGdMmnbV4a2eM7Z1pfJpAV1XmjC4mGjSI+jRKI7J6Zbg43GsCGCYJtm3b1NXVkcvl2LVrF6oqGxkcDfZsdAbvfHZ2MGdlH+jc9v5kd66n+42/k9q4EIQgOOEEIrMu2O9IjnfiZJNkd64n27SB3K5N5Fq2YMea+32OEir0RmyFilCChaiByB5zsPWeOdiK18zEsb2ycTOLk03ipONek7NkhzfSI9HR79pqpAyjfBRGxTh8VRMwqibudx7n23Fdl1zTBuLLHie57iWwTQJj5hA9/gP7dFR/u++zpnoNyno/z6cfXPWCKuDasyfy36eNe1dxS0e3VCp1wBvm/m62h0Im2INjsNYEAD/4wQ+44YYbeOihh7jssssG5TWk/Oo9u92R9MrAO9M5r75agbZ4jkzOYmd3hmQ6RyxjowrB+PIQJi4za4pJmzaxdI4JFRFa4yaaCqVhH5brUhkN8PKmZpIZb1Z2c1eKyqIAilBo7EywsSlOVzpH0FBws0m2vfooTQv/hZXswl9Wy7FnXcro489G+IMEdYUdXRlMy8EEMqaFgsPO+O73UmhAzoSsbZNu205253rMXRvJNtdjtm3DtXansEL3oUbK0CLFKMEir+TbF/bGcOkGqHrfcTCv+amJY2ZxsymcbAI71Y2d6vLWBPG2va7txygbib9iLEU141ArJ6EW1mALBcv71hLWwW9AyNDwKRqxTJbWpNuvlN2X7oYtL9L8xmPEWnYQLqlg3gWXs+DCDyN8QRRVEPapjCmNoihQ3xJnV1eGgpDOpMooPk0jqKvUlgRpT2RI5xzKCwJEAzpNXWlylsPWtm40RSNlWbi2S11piAnlEW/UqxBUFQ5cObB05DucawIYJgk2wG233cY3vvEN/vd//5dPfvKTg/IaUn7t2RxtoJPsd5Ng9zK7dhFf+hjJlU/jZJNoxTWEZ5xNeOppqOH3vpvq5NJeM5POHZhdu7BjLVjxduxkB07PvMs9b4gHpGqo/gjKXjMvtaLKnoYmNYdUlm6nu0mufZHEyqcxW+q9buLTTic666L9Ni97N8aVhZg7poSWeJan1zbv93N6Kgn7qh7kDvbQ99hjj/GXv/yFYDBIQUEBo0aN4txzz2X0aG92+r///W9OPfXUAT8qJBPswTGYa4Kuri5KS0uZPXs2b7zxxqC8hpR/yazJtrYkHSmT4pBBxrToSltsa00Q8asoikpdoZ9X6ttpjafRhCDo05g1spjqwiBb25KoqoJpOxiKoKY4jINDSchgxfYuVu/spqU7jW1bhP0GiuI1PjMUWLuzC7+hUxg0iCUytHcnaVr+PDsX/ot00yY0X4CRs0+j+vjzCNVMwnQhnbOJ+BTiGZOulItlQwYI4CWvQgXT9hqeaXjnr13HhlgLeryR7pYmsrFmrO7Wvr4sdrobN5vinVcxAsUXRAlGvTVBqNhrfFpYgVZYiV5SgxYtQxEKOlAeBqEJVNerBshkwXShKKwSDer4DQVDqMQzWdrjOdJpm45tq0iufprE+tdwLJOK8ccw/tQPctq55zOxuojOhElHIkN7ymRMSZCQTyfgU2loT5HI2jR2Jpk8IsKIghCGruDTVQQQNDRGFgdIWw6t8SwFAZ01TTFqC4P4dY2goVBTFCRp2ihCMCLql13Dh4F8rQlgkBJsIcRtwEV4VSGbgf9yXbfrnb5uMG+m8Xic0tJSpkyZwrJlywblNaT86931LgoarN4Z429LvLPa+Tmd7XHMDKm3XiGx4kmyO9aBUPCPmkloyikEx8876LPa74bb+0TazIJj09vkDEVFaAZKz+72QHNyGdKbF5Fc9yLpzYvBsTAqxhI+5hxv1ucAvldFeE+h9zcLXVcFN79vGp2pnOwYPozU1tZy7733EovFsCyLtWvXkkql+OhHP8qYMWP4/e9/PyhnbmWC/c7ey7pgMNcEAJ/97Ge5//77WbhwIXPnzh2015Hyy3VdcrZD1nRo62mU1tCZREWgaSohn3du96UNLViOA7joisbIshDCdXGBsF+jpjCIaTskshaO67KtNY7jKqza0UEsZXrnuPHGQWoq7OxMY7kKqnBI5Sx2dGSxXRvHsYkkGtn2+r/ZuuhZzEwKf3ElxdNPoXbWaYydOImtbWk64ha9wzKL/GDb3u08sZ+JGQr9O4zv+z1wcHMZXCvrNT51HW/jQChe41Pd5/1zEI1O/XiVYT4ddE2QMl2KfAq2A0J1GBENIlQVx3ExbZtYw2Z2LnuGxiXPk4m1ogfCjJ13DhNOvZgpk6fhCpfqwiCbW7vxGTrTK6M0dKQZWR5GFwK/prJqZxdNXSk64ia1JUEqIn6m1RWiuOAKCOoqkaCBoQh2dWcoCvnoiGdxhUvYr1NbFKQi6sd2vJ1Kef56eMjXmgAGL8E+G3jOdV1LCHErgOu6173T1w32zfSaa67hzjvv5MEHH+RjH/vYoL2OdGTYs2RcAU4cX8p50yrpTOVY0dDFUwfY+QQQAi4+por2ZG5AO5Sb7Q0kVj9Hcu2L2N0toGoERs0kMH4+wbHHHdLOdr7Y6W7SmxeT3riQ9JYluGYWNVxMcNLJhKef8Y6zr/dWGfXR1J09qM9VAKEIHMdFVQWnTyynLOLjA7NqZFI9zLzxxhvcdNNNPPHEEwDkcjk2b97Miy++yCOPPMIDDzxAdXX1gJeBgUywD8Z7WRcM9pqgvr6esWPHMnfuXF5//fVB+bshHTlsx2VHVwrb9nq6VET9hP0aQgjSOZs3t7TRkTJp7kxRXeSnPBrEclyChkpbMsukigghn8765m4aOzI0diapLvKzZkc3wrURisLOzjRhn0Jn2sY0bSoL/EyuilLfnmJ1QzvdaQtcwdgRUaIBg45YJ68+/R92LX2a7i2rwHUoqBxF+bQTKZ4yj0R4JAVBHdeFeNLGAtJvl0kfpN40+r1eSuDtoFdEVWJJm5KoRtSnUxQ28OGwY+NqNix6ke3LXybX2YRQVKqmzuXSj3yUD3zgYoIBP9s70ggc1u1MsK6pi/aEhaoKaov8VBcFqSoMUhL2MbYszNJtHTy1thlNFRT6DcaWh5k3toSU6c0/zzkuYUMj4tcwNAXLgaChUhoyesaAyT4Lw00+1wRwGErEhRDvBz7kuu7l7/S5g30zTafTFBcXM2bMGFavXi1vpkNcb8l4b7fxPcuEl2zr5CO/fr3vTO/+qIrgxLEl+22cdahc1yG7Yz2p9a+Q2vAadncrAEbFWG9M18hj8FVPQtGPvCYcrmV6I7q2rSC9ZSm5po3gOqjhYgLj5xGaeBK+2qkIRaX3IfHBluwL4KTx+x+7poj+11GEN5blOxdOlTvVEolEgve9731MmDCB6667rq8EDOC73/0uQghuuummQXltmWC/Owe7LhjsNQHA1Vdfzd13382iRYuYM0f+CIc603bImDa6qvQrEbYdl0X1bbQns3RnLEojPkYUBOhOmbQlMrguFIV8VBf5eWZNMxGfyua2BMm0TSCg0hJLUxL04eKd+/YSPZVExmbmqAJcx2XZ1g6Spk0im6Mo4OOUyZW0xlI8u74Vx3Hp7mghuf51YuteZdeGFbiugxEupGj8bIy6GYTrpuMEy8mwuzz8vdLx7rcCr9z83drz2FxAuDjxXajNa0lvXUnTukVkk90IRSNQNx3/xBOpO/YEZk0Yxf87eTSVhWFauzO8ubUDy7JxHIcXNrYS1lUcXBJZhw/PqaWmOExZ2IemKXSns7yyqdWrMLAVTphYyozqQkzH8XalEfg0r2Tcr6uDcqZWOrrkc00AB14XaAP4Gp8B/jyA13vPAoEAV111FXfccQevvfYaJ554Yr5DkgbR7JFFPHjFvAM2R5taGWVFY6zv31Xh3TB6kzjbcQcluQYQQsFfMxl/zWSKTr8Cs2UL6frFpOsX0/3mI3Qv/CsoGkbFWG9cV+UEjBHj0IoqD3pW9UBwXRcr1uw1WGvaQHbnW2SbNoJtglAwRoyj4ISPEBgzB6Ny/D6xaV4d90GX6SsCzptWySsb2/p9vgJ8/5LprNnpdRGfVlUgk2qpn3A4zN/+9jduvfVWbr75ZiZOnMipp57K/PnzWbduHTNmzMh3iNJuR8y64Bvf+AZ33303t9xyC48++mi+w5EGma4q6PvZzTRth2hQR9UUxpVruEBtUYi3rG5ylkNpxE/OtoglLaJ+nV2xFM2xDCG/SnVhBMdyKAzqtMdzCBc6kjl0RaG2NIhpupREfYT8PmxMVFVBU1Sypk1JxEd1UZAd7Qm0UBETTnk/Ve//OCWqybrFL7Fy4Ys0rnmTzLJnvfgLyglUTSZcNxGldByREaOxtMA+XbrfiYnX3Vvp+V/7XXytk02Sa64nt2sj2Z0byO1cixX3GqSGisoYPfNkyqfMpzU6HsUIkbMBA/yadz59THkUoUBJUMd2NdY3xwjpgrXN3fhUlbqiALviacIBg5KIQdin0NBh4jqCqN9HadjH9KoCCoIHntEtk2vpSF0TvOMOthDiGWDEfj50veu6/+z5nOuBOcAH3ANcUAjxOeBzAHV1dbO3bdt2KHG/o6amJqqqqjj33HP7ygak4WXvZmgCb7f6loun8fz6lgM2zTpcnGyKbOMab95041pyzZv6mpYJ3YdeWodeUote5I3pUqPlaNES1FDRfudNvhPXtrCTXdjxNqzuVqzYLsyOHVjtjeTatnnzrgFUDaN8LL6ayfhrp+GrnYZ6ECO6FLxGY1nrnYvRes9Of/fR1f3mWl+5YAzfPH/yu35v0vATi8V46aWXWLZsGS+++CLt7e3MnDmTe+65h0BgcLrGyh1sz0CsCw73mgDgoosu4rHHHmP79u3U1h5aA0bp6GPaDo2dKQTQmTQpCetUFgbRVYXudI63dsWxbBe/pmBoCou2trNxVxzLcSgK6YwoClKoG9jCYvn2GDs7UjQnMwQ1lXkTyghpKhXRAGt2xti4K46uKRQGNMI+lVFlUTY3x1nTFEMTCiVhg3GlIRzVm+VcFNRZ3dDBW2+tZeOKN2ndvIpU41tY8XYveCHQC0agl9biK65BLapEiZajRsvQIiUII/iuk03XdXGySdx4G2Z3mze6s3MnVscOcm3bveNtPfSCciK1kygcM4PZ80+kpGo0hq4hXJuV2ztojtsYKowpD1EYCjCtKkx5NISqwM7OFOmcS31rgpKwxrbOFJblMrLYT2EwwLkzKikK+bAch4Ch0pnIsrMry6iSINVFQQpDB06wJalXPtYEMIgl4kKITwOfB85wXXf/g233cjjKwQA++MEP8sgjj1BfX9+vZEAaHu5+fhO3P7m+t+0XJ40v5StnTuibp/3RX7+OabuoCtjOu+8ePtBc2/JmTzZvJteyBbNtO2Z7wz4jtQCEEUQJRLxOoHrAa2SmeGfMXFywbVw7h5PL4OZSOOk4Tja5z3WUUCF6cQ1G2Uj0stEYFWMwykYPSmO0PakCThhXyqub2voqCXpHcslO4NK7kclk8Pl8bNu2jerqanR98P7uygT74LzbdcHhWhO88MILnHbaaVx99dX8/Oc/H/TXk44sqazFW7u6MVQFQ1UoL/BT2LM76rouiayFaTmEfBr1LXEeXbkTQxVeGbiuMK48SqFfpyNl0tSV4pX1LbSncrR1pzFdGF0cwtA1Tp1UyqZdcTpSOTKWjeVAccjX0zBMpS2Rw3QcooZBXYmfxu401WGvJ0l9W4JkOkNDZxZNAb+doHPrOuI760ns2kqieRvZziZcy+z/5lQNNdA7psvvNTJTdRSh4OJ1IHcdC9fM4GTTuNkkdjoOTv8CdKH5MIoqMUrrCFaMoqB6NIGKcRSWVxD2aUysDOFTdVqTJpMqwqhCgHDY2NRNKmMTDBr4VJVjRhbR0JHEUBVa4hmChreHHk/n0HQFx3KoKQ4yujzC/DGlpHI2adOiIGCwM+Z1ea8qDJAybUYWh/pmnkvSOzmcawIYpBJxIcS5wDeAUw42uT6cvva1r/HII4/wox/9iF/96lf5Dkc6zIqCRl/S7OKVJPcmbrNHFvHQ5+b3dSK//v9W9ftaAVREfTR3Zw9b4i1Ur1TcqBjb788dM4PV1Yzd3YoVb8NOduKku72kOZfCNTPezGsnBa73OEEoKmiG92TbV4vqD6MEClDDRajhErSCcrRo+aB0Nn8nivB2us+bVsmirR1kTa+s3AVypsPC+naZYEvv6Le//S1nn3021dXVAIwaNSq/AUnAkb0uOPXUU5kxYwb33Xcf//M//0M4/M6VOdLQkTFtTMuhK5VDEVBXuntkjxCCiN9biMdSOVoTGQwVkjkHRbjMH1vCuPIIjV1pRugKu2IpYpkcmrJ7yoUrBM1daV7Z1I7iQlnUYHNLkpKIQUlIp6krQzpn0Z01mTIiQmHQR11xGFVVMW2bUSUhcpbDDlcQMlxMx8Lwl1B3zMk4x5xESSRAVzJLaVBHM2MsW7meVKyNro4WUt0xrFQ3bjaJk0vjWlkUK4vrOtguKIoKiobij/Tc+0MogSiRaCFF5RX4oqWoBSNQAyG6Mgo+AwxdIeLzURTQmDemFIRDVVGI9bvijC0zGFHgRyiCeNpm/rgK0qZFxKehGxrdaYv2RJZJVVE6UibtiRzHjSogmTMZXRoiZVqMLA5TGvYRy5q4NlQVBIlnTFTAcr1O8CGfjqwClw5Gc3Mzjz76KFdccQVCiLyvCQ71DPYvAB/wdE9pykLXda885KgGyPz585k9ezb3338/N954IzU1NfkOSTqMVu+Mve2/zx5ZxOyRRXzu94v3SaI/NreOqVUFfOefq70b5yDH+nYU3Y9RNhLKRuYxCk/Yp5LI7j7FpauiX4n3/hiqguV4Zfq6Kji2tpBcTxn5dy6cygOvbmFTSwLwOp0Wvc15K0kCb7bx1Vdfzfz583nmmWfyHY7U3xG9Lrjmmmv4r//6L2677TZuvvnmfIcjHUaW4yAE+DUF04FUxtpnRrLrujR0pdjZlWFEgTeVurIgwLjyCLbroiuCjoSJX9eYUB5hQ2uSkrBOd8ahI5VD1xUMBdpTJiN9QaqL/AhFoTQaIJ61GF8WIZ6z0BAoiqAplqY7lWVcRQSfptHQmcJyXMrCGi0JG9sVVJeEsC0XRXFxDRWhaMTVAmqmzqQ7bRHImLgOJLLeg2oFCPthTHmUkoBKczxDa7dJPGthWuAIwIGABj4D6kqj+FXwaQoNnWkSpjffW0PhmOoo4yvClBWE8OsqKi6KEIR0je6sxZy6YlxculImO7rSuEIhaKhURf04to2qqZSGNUZE/dSUBnBtP9NrC8mYDooQmLZDVyJH2rRpS2YpD/uwXJdYysINQsRvyHPW0kH57ne/y/3338/8+fOZNm1avsM5tATbdd1xAxXIYLnttts4/fTTufHGG/ntb3+b73Ckw2jvX8n7+xW9ZFsnz6zrfxZbCIj4NG55bE3PLMWejWGJKZVRljV09SXVvf9bXRRgbGlov83iTp1Y1vc9tmyXRVs7AVjRuApNFdh7JOiKgM7Uu23jIg03P/zhD0mlUlx11VX5DkXay5G+Lvj4xz/OD37wA2655Ra+/vWvy13sYSTk0zAdB5+uUqBrZJ19+4VkTIdEyqQ7YxNL5aiIGER83o616Qgc10FTVWoKA4T9FQQDnbTFM1QXBujOWjimS8a2QSi0dOeoKw5TEjKoLQlS6DfIWQ4BXHAEU6uiRIIGrd0ZOlImPkPhuFElVET9bG1LorYnUIVCzrFRFBVVFWg+g+KIQabTorYkSjqdpTVlUqhr7IincVxvNvfIkiDH1hXj13UqogYrG2PUFvpY25RkbVMHGdPBUMFGwbZcdJ+G32cQCdiM1jVCfkFFJMA50ytJmQ6FPh1UQTJrogiBT/fi8esq5QUBGjY0YyiCgK4QS1scU13IlOpC1uzopjxkoGoKhqIyZkQY2/EevKuKoCOZI5WzKY/62dqexKcpaKpCQVCnIuLDkYsv6SA0NDTwq1/9iuOOO46pU6fmOxxgYLuIH5FOO+00TjjhBH73u99xxx13UFhYmO+QpMPkA7Nq+OuSxr7xXR+YtW8FwyNLG/cZLeW6cP8rW/p2rgXemeF32KjtoyqC0SVBNrcm9+mQ7QKqArPqiuhImdS3Jg56tNWRoCBo8PDn5vPZ3y+mI7k7EW7qSvOzy2Zy7rRKnljdREnIoD2Z47xplUwcEeGFDa2YltN/5gf0G58m8MZxzRtTcvjekHTUMU2Tn//854wfP54PfvCD+Q5HOspomsaPf/xjPvCBD3DXXXdx/fXX5zsk6TAJGhpjSiMksxaKIgj79l0C52ybjGUzImKQsWxsXEI+ne0dGWpLAvg1nZhrEvCpbG1PMr48wgnjSikKGpSFfWxuibNyRxeRgIEuoKY4xIzaImzbYeHmVtbuShPQBadNLMfQVVI5mxEFAUaWejvErd0Z/IZGKmfRFPMaj2ZzLsl0kqwtsF0Xy3HQEKgKGIZGkQshv84IFYoCPjRNUOjXGVsRxrQE48rCVBeHiPo0dK2NrnQWx3XJWA5VET+W61IaMqhvTeAzVPy6IGD4mFJThGlBccAg4NMI+VSI+GjsSGOogsKQ91qlIR0Xweb2JH5NY0plhIztUhnQOWF8qfd97alaMzSl7yx8d8akOGSgCEEsnUNTRM8/CvFMjqzlUBY58kaYSkee22+/HYA777zziKl4GPIJNsANN9zA+eefzx133CFLwoaR2SOLeOizBx7fBQdubOa4LqoicF0XvWcG858Xbe8b9yWAM6dUEEvlWLS1Exev/PnSObVMqyrggVe39Lv25BERvv/+6X1nvm95bA05y/F+EQzwE9rjRxWxvKELy3Z5537eB7ZXLgxAecTH+l3xfsk1eG/hkaWNVBUG+hrJ9VqyrRPH8c5ZK27/a2qqwHVcVFXhQ7Nr+OCsGnn+Wnpb9913H5lMhm984xv5DkU6Sl188cVUVFRw11138a1vfQtFOXwjEaX8EUJQEfWTsWwUIfYpDwcI6BpBn048azO21OuCrSiCkE8lkbGwdCgOGZSEfDiOi09TUIRACC/hnD3aR2VhkHjWQlcUyiIGrd1ZVu7oZGtbgsoCPyFDB0UwIhogY9mkshaxjIVpW/g0lYKARjrnNVsr9GtsaOnGtFVcx8LCJZWzmTwiQsin49cVVKHw1q44AV3DFd4xq5m1RZRHgowpDVMYMsiYNm81xcjaMLo0REDT2BXPEDJUFCHozlmkTYdYJodwXMaVFxA2NDRdwTBUbMfBUA1URTC2PELYr+LXNIqDPhJZm0K/xpSKAnZ0p/H7NAK6SixtIoCIX8fQdv835uKSs1x8mko6ZzO6LETGtBlbFibk08hZLnUlQXy6ik/b92ckSXvK5XLcc889TJ06lRNOOCHf4fQZFgn2ueeeS01NDT/72c+48cYb0bRh8bYldp+zPpAPzqrhL4sb+u2kKnhPWb9z4dR+M5gnjohw+f0L+3bErzxlbF9H8t4kHuCj9y3se1rba0PPGeP/Pm0cdz+/qW90mIK730T23VIETK8u4CPH1fGxuXUs2dbJLf9a02/+97vRO9LMcd2+HXajpwrgzmc27PP5uqbw18UNWI67z/fu70sb6f12OMBxo4rw62rf7vbbPQCRpD25rstPfvITIpEIn/jEJ/IdjnSUUhSFa6+9luuuu46HHnqIyy+/PN8hSYeJogiCxoHXgIamMLEigqqA4wosy8ZFUFkQIOrXEQoEdBUhBNWFQVriXiPUioiv7xqVhQGKLe+Mse24dKXTCMCyHba2pigOG9QWeQ1Gg4ZGPGPhU73RYMmsTc6yaexI4whoSWQpiwQJFCnEsxbdGZPjx5QQ8enkLIdIwCth9+sqQnHpTOQYWxFh5ugSqgoCfd23/bpKecRHyKcSMnxkHIuxZWFOGF9GKmPxrxU7GFcRoaE9RcayKQ4bqKpCSVAn5NMJ6CplET/dWZMJIyLYtotPVygJ+0hkLGJZG1UTVBX4GV0SxMWlI+E9iM9YDqVhH6mchaoIMjmHlGlRFvHRHs/h0xRqi3c3nAvt/lZK0ju65557sCyL6667Lt+h9DMsMk0hBF/72tf4yle+wu9//3s+85nP5Dsk6Qgxe2QRf/7cfB5Z2ogLTKsq6JdU7/25D16x7474nkn83c9v8kqh9+I4Ln9f2sjskUXMG1OCoSl9ibomBImcvc/XHKze3fE94509sohp1QXvOsHu3bXfM0kuChr9vidTK6O8vMdZ6+NHFTGuIsLDb27Hcb1O4N/552qcnussGF/W7zVylsM3z5vc7/snSQfjscceo76+nm9+85v4fHIVJr13V111FTfccAM/+clPZIIt9VMQNJhZW0zOdlCFt4ZUFbFP6WnQpzHSUPf5c7HH7njGtBHCOw6VzDromqAwqNGWzBLsTFJZECRkqDRnLHK2g19XaE+ZlEf9ZC2L9pTLSeNLaEuYtCdMJvgV6orDJHMWFYV+ykJ+YhmLBVWF3jE2AVUFfgoCBtmetYhfVxBCEPbrTKiIUhHNkcjZFAZ0GjpStCUyaIogkbUI+hUiQiNrOnQmsqSjfhAKtcVBQn6NkH/ftMF2XGqK/CQzNoauUl0YZGcsQ9Cn4gLpnEVzt0N3xhsJFtQVcpbNlpYkfl3QlsiRMV0qCnxyx1p6V1zX5ac//SnFxcVcdtll+Q6nn2GRYANcccUVXH/99Vx//fVcfvnlcnEm9XmnXe5387nzxpSga8o+O9gu8LcljX0l0J+eP4r/rNnFuVNH8JfFDbBXgh00VLKWjX2AGu/JIyKURnycN62Sj82tA+i3kz57ZFHfGfSc5fQ1eHunnXLbcVEEfHr+qL7d5YkjIv3eczy7e26mAE6ZWM68MSU8stQ77943sgQv2d77NVc2xrj8/oVy3rX0rjiOwze+8Q2EEFxzzTX5Dkc6ykUiET73uc9x991385e//IUPf/jD+Q5JOoIoisCvvHOy907nPf26SnVhgNZYhlmjinFsh3jWoiIaQOCdPS4JGTg4JFIWhaUhaguDrNQ6SeRsSkI+KgsCHDe6DEUImuMZEhkLn67gut4ordHhEOVRPwUBve9hQGcyR3syC0BBQKcs4idoaJRGfIT9OoYKzd1ZmrsyZGyLyqIg2ZyJpocQjsKIYj+pnE3O8ZqydaVyCOE1gQv5tL4HCJbtsCOWImu5RAI6BUFv57soaPQdJSsM6mzvSGLZYDsOXUB1QZCcbdPcnSHi0xAC2hM5qgoDh/Rzk4aXX/ziFzQ0NHDzzTcP+rzrd0u4eejQN2fOHHfx4sWH/XV/8Ytf8MUvfpF7772XK688YqaGSEPMgcqzVQHXnj2RoqDBt/eYu71gfOl+u2/vXTreexvXVcFDn5u/zznny+/3StP33H2Op03WNHVTEjL4x/KdB/0eVEUgcLGd3a+3flecB17dwuaWRF9chqbw0Gfn9SuVj6dNfvlSfd+1zp5SwdNrm/u9l97vxX+fdkQ3HJaOIP/4xz94//vfz3e+85289dIQQixxXXdOXl58CMvXmiCVSlFcXMy4ceNYtWrVEdMcRxp6dsXSpHI2KdMilbGoKgySsx2KQwbpnMW6XfGeB+sOM6sLeGptE6/XdxL1q4QNnTOmlGM6Lltak4wtD3sTPFyHioIgQngjtoKGRkHPmMtt7Un0nk7dXekcQV0lYzlYtoNpuziuw7a2JJ2pHM2xDIqAUyaMIBpQ2RnLoKmCbe0pArpGRdTnVd3ZLiFDRdcUKiMBco7Nzq403WkTVfES+9GlIaqKvHLv3o7jfl1l8bZOcF1sxwVcKguCaCpsaUtRV+yVyxuayogC2dRMOjiWZVFXV4dt2+zcuRNVzU/1w4HWBcNmBxu8XexvfvOb3HHHHXz+85+XN1NpUMweWcR3LprqJbymg4N3Rlrv6ZC99xlmFygO6nSkzH3+fE+fXzCGSEAnnja585kN/XavF9a3953rzlm7S7Qdd//Nynod6GP2Hq3Nc7bLlx5exo7O9D5f+6HZNfucQ19Y344i8M6YCyiL+PDpyn6/F5J0sG677TYArr322jxHIg0VwWCQK6+8krvuuovXXnuNE088Md8hSUNUecRPMmehCD+qgFjGIuTTiPp1OpI5/KpCxK+TTWQRqkLU56M0bFAY0NjelWFXV4aAT6M5liaRsYgGdU4YW0p5xM+yhk6SOQvHgWnVUcqjAUKGSlfaAlxcx8V1IZPzRo8JRaAJCPlUtrdbRAM6rusiFIeg38/swiDN3VmKgz6ylkNxSGdzWxLhAq5Da3uOrW1JUlmbnG0TNjTKoj4KQz6Kw16C3BbP0J210HqauVUV+IhnLBCCQr8GQsG0HaZWRkmZNooiKAkbef0ZSUeXRx55hKamJr7//e/nLbl+O8Mqwfb7/Vx11VXcfvvtPP/885x++un5DkkaovY8r733GebzplX2O8N83rRKNjbH3/Z6qiI4a+oI1u+Kc9uT6wH6rvGxuXX9znXvWaINb59c+3SFT88f1bfL/djKJhzX7btGr72T696YPjirZr+753ueMf/ArBo+MKtmv98LSToY69at47XXXuPTn/40BQUF+Q5HGkK+9a1vcdddd/GDH/yAxx9/PN/hSEOUoggi/t0lrP49Gq2VR/zs6EqzoyNFScQgZKiUFfoo7jRIZy0MVdAcz1KmwMjiEDnXpSLqpyhkkLZsUjmb0rCf7kyOjmSO8miAkrCPQM9rZEyb7oyJ7bpe6buuksxZFIX8lEVzjIgGcF0oDBpUFgTQVYWCoIHtuDTF0t7sbl1FVxVvjJimYNoQ0AVBn07GdAgaOlUFAfy6iml7563DPo2MaRPPmIwoCBLymSgCwj69r/kagFwJSO/FbbfdhhCCL33pS/kOZb+GVYIN3s309ttv57vf/a5MsKVBted57d4dXoCJI7wOpbbjzcSeOCLCvDElb1vCbTsujyxtZHtHqt+fP7G6iY/Nrdsnoe8dA3agHezekWJ7j8X6xPxRfde46V9r9jlL3ksRcMvF05g9sqhfV3TTcuhM5Q7YDE6S3oveWcVHWpdQ6ehXUVHBJZdcwj/+8Q82bNjAhAkT8h2SNMQ5jkssbWI5DgUBg5BPpTLqozUOPlUQS5sEdZVczmFHV4bRJUFGlgSJZ20Spk11oR/LccnZDiHdG4nVFs/guFBc5u0CCyEI9cz59usqds9x0GBPolwZ9RPwqZSGDNKWgyoENUVBdHX3OC1VEVQXBrAcl7KwQXM8iyrwEu1ElljKxqco1BWHGF8RJtzzAMEbW+ZV01mO21c+XhiUO9TSwFi0aBGLFy/mc5/7HJFIJN/h7NewS7CLi4v57Gc/y3333cfjjz/O+eefn++QpCFu7x3eBePL+pqX2Q786sXNPLe+5R2v48J+d7977ZnQ9zYo690x3vt/D7SDvPc1HlnayF8XN3jNzxTByOIgY8rCfL5nRBmwT1f03mvLhFoaCG+++Sb/93//x8UXX8ykSZPyHY40BN1888384x//4KqrruLZZ5/NdzjSENeZyhHLmGhCkDbTlIQMmroyBH0asbRNxkrgOi4FQYM6BQxdJW3a1JYEcN0AIUPDclx0VUHTFGbUFBDLWPg0Zb9JrKoIyiN+2E8eEvEbZC0bVRH77eAthEBXBXrAIOTzSsmztsPIkiA5i56u6P1fU1W8sWadqRxFPoOo/8hqPiUd/a644gqEENxwww35DuWAhl2CDfA///M/3HffffzoRz+SCbY06PY8H21aDs3dmX4fb+7O9JvD3UsVIBSBbbvoqui32/zE6qZ+Z7D3NhAJbu81esu73y4p39+OtSQNhFtvvRWA22+/Pc+RSEPVjBkz+MhHPsKf//xn1q1bx+TJk/MdkjSE5WwHn6qgKYKUaeO6oGkKjutiuS4BAbqhoqrg01R0TSVoaIwtC5M1HRJZm2KfSqCnk7ff0PqVnL8b6jvMBd+TV9YtCKoKvMPX+HWVygLZEVwaeK+88gorV67k6quvpra2Nt/hHNCwTLBLSkr4+Mc/zh//+EdWrFjBMccck++QpCFs7x3ejxxXx7qm1Zg9ifNHjqtjTdPqfkm2Anzvkul9O9F7Jq4fm1t3wMR6MBxMsi53rKXB0NjYyCOPPMI555zDuHGy47w0eK677jr+/Oc/c9ttt/HAAw/kOxxpCCsKGuzqzpC1HYoCBkFDo6owQDxrEfXrVBYG6EhmGVUcpDtrURbxM648QsRvEPFD6ZFZEStJh0XvQ/evfe1reY7k7Q2rMV17WrNmDdOmTeOss87iqaeeymss0tC394zq/f37r17cTH1bktGlIa7cowRbkoarj33sYzz00EO8+OKLLFiwIN/hyDFdg+RIWBMAnHDCCbz++uusX79ensWWBpXtuDiu23fm2ek5U60pAk1V6F2by2k3krTbwoULmT9/Pu973/v45z//me9wgAOvC4Ztgg3w6U9/mv/93/9l8eLFzJ49O9/hSJIkST127NhBTU0N559/Pv/+97/zHQ4gE+zBcqSsCdatW8eUKVO4/PLL+eMf/5jvcCRJkqQ9nHPOOTz11FM0NDRQU1OT73CAA68LlP198nBx0003AbvLDSRJkqQjww9+8AMAvvOd7+Q5Emm4mDx5MqeccgoPPvggHR0d+Q5HkiRJ6rF161aeeuopPvShDx0xyfXbGdYJ9qhRo7jkkkv461//ypHw9FySJEmCzZs3c++993LiiScyd+7cfIcjDSPf/e53Afjv//7vPEciSZIk9briiiuA3b+jj3TDOsEG+MlPfgLA9773vTxHIkmSJIE36QF2/36WpMPltNNO46yzzuLhhx+mqakp3+FIkiQNe2+99RbPPvssH/vYx5g2bVq+wzkowz7BHjNmDBdccAGPPvooDQ0N+Q5HkiRpWIvFYjzwwAMcf/zxcvdayouvfvWrgHzAI0mSdCTofeh+pHcO39OwT7Bh9w/sm9/8Zp4jkSRJGt5uvvlmXNftS3Ik6XA7++yzmThxIj/72c9obm7OdziSJEnD1saNG/n973/PSSedxMyZM/MdzkGTCTZw6qmncvrpp/OnP/2JTZs25TscSZKkYamrq4s77riDGTNm8KEPfSjf4UjDlBCCn//855im2ddsT5IkSTr8ehud/vSnP81zJO+OTLB73H777QDcdttteY5EkiRpeOr9PfzjH/8YRZG3Jyl/zjrrLGbMmME999xDJpPJdziSJEnDTktLCw8//DBnnHEGxx13XL7DeVfkCqbHzJkzmTdvHr/+9a+pr6/PdziSJEnDSktLC7fddhvjxo3jrLPOync4ksT111+Pbdt8+9vfzncokiRJw861114LHJ3jOmWCvYd7770XkB3FJUmSDrfbbruNXC7Hz3/+c7l7LR0RPvzhDzNnzhx+9rOfEYvF8h2OJEnSsLFz504efPBBzjvvPBYsWJDvcN41uYrZw7HHHssJJ5zA7373O3bt2pXvcCRJkoaFeDzOPffcw/jx4znnnHPyHY4k9bnmmmuwbfuoO/8nSZJ0NPv+978PHF2dw/ckE+y99O5e95YlSJIkSYPrhhtuIJVK8YMf/AAhRL7DkaQ+H/nIR5gwYQK33HILnZ2d+Q5HkiRpyKuvr+fee+9l3rx5nHbaafkO5z2RCfZeTj/9dE499VQeeughWlpa8h2OJEnSkJbJZLjnnnuYMmUKl156ab7DkaR+VFXta75355135jcYSZKkYaC34fRPf/rTo/ahu0yw9+Nb3/oWIOdiS5IkDbZbbrkFy7L6fu9K0pHmoosuoqamhltvvVXuYkuSJA2irVu38stf/pKZM2cyf/78fIfznskEez/OPvtsTjvtNH7729/S1taW73AkSZKGpEQiwa233sq0adO4/PLL8x2OJB3QvffeSzab5dZbb813KJIkSUNW71HdBx54IM+RHBqZYB/Ad7/7XQD+53/+J8+RSJIkDU2/+MUvcByHb3/720dtGZg0PFxwwQWMGzeOe+65h2w2m+9wJEmShpzu7m4eeOAB5s2bx7HHHpvvcA7JgCTYQoivCiFcIUTpQFzvSLBgwQJmzZrFT3/6U7Zv357vcCRJkoaUzs5Ovvvd7zJ69Gh59noIGmrrAiEE119/PfF4nBtvvDHf4UiSJA05X/7yl4Gjc+713g45wRZC1AJnA0MqCxVC9I3luP766/McjSRJ0tDy/e9/n1wux6233oqmafkORxpAQ3Vd8KlPfYrJkydz1113ySaokiRJA2jz5s387ne/Y8GCBZx33nn5DueQDcQO9h3ANwB3AK51RDnllFM49dRT+eMf/8j69evzHY4kSdKQ0NbWxk9/+lOmTZvGBz/4wXyHIw28IbkuEEJw++23k8vluPnmm/MdjiRJ0pDR2+j0Rz/6UZ4jGRiHlGALIS4Gdriuu+IgPvdzQojFQojFra2th/Kyh1XvLvZPfvKTPEciSZI0NPT+Xr3zzjtRFNkKZCg52HXB0bomOP/88znmmGO45557yGQy+Q5HkiTpqLdr1y7++te/cuaZZx7VncP39I4rGyHEM0KI1fv552Lg28BBFcq7rvtr13XnuK47p6ys7FDjPmxmzpzJ3Llzue+++9i8eXO+w5EkSTqqtbS08JOf/IQxY8Zwxhln5Dsc6T0YiHXB0bomgN3HxuTxMUmSpEP3ta99DRgaZ697vWOC7bruma7rTtv7H6AeGA2sEEJsBWqApUKIEYMb8uF3//33A8iSMEmSpEP0ox/9iFwuxy9/+ct8hyK9R8N9XXDppZdy3HHHceedd8q52JIkSYegsbGRBx98kPPPP5+TTz453+EMmPdcm+e67irXdctd1x3luu4ooBGY5brurgGL7ggxbdo0zj33XP7whz+wcuXKfIcjSZJ0VNq+fTt33HEHc+bM4cwzz8x3ONIAG07rgltuuQXHcfjKV76S71AkSZKOWp///OeBobeJKQ+/HaTbb78dgGuvvTbPkUiSJB2dvvnNbwLe71M591o6mp177rmcdNJJ/P73v2fDhg35DkeSJOmo8+abb/L4449zySWXMGfOnHyHM6AGLMHueWLdNlDXO9JMnTqVj370ozz77LO89tpr+Q5HkiTpqLJp0yYeeughzjzzTE455ZR8hyMdBkN9XXDXXXcB8N3vfjfPkUiSJB19vv3tbwNw22235TmSgSd3sN+FH/zgB4DsKC5JkvRu9VYBDcUbqTQ8zZo1izPOOIOHH36YtrYh+xxBkiRpwK1du5Znn32Wyy67jHHjxuU7nAEnE+x3YfTo0bzvfe/jkUceYdGiRfkOR5Ik6aiwadMmfvWrXzF//nyOPfbYfIcjSQOmt+vtl7/85TxHIkmSdPTo/Z15ww035DmSwSET7Hfp3nvvBXaXNUiSJElv78YbbwTg7rvvznMkkjSwFixYwHnnncef/vQn6uvr8x2OJEnSEW/58uU888wzfOpTn2Lq1Kn5DmdQyAT7XaqqquITn/gEzzzzDM8++2y+w5EkSTqiLV++nIcffphzzz2XmTNn5jscSRpwt9xyCwBf/OIX8xyJJEnSke9zn/scsPvh+1AkE+z34NZbbwW8juKu6+Y5GkmSpCPXV7/6VUD2rpCGrjlz5nDxxRfz+OOP8/rrr+c7HEmSpCPWY489xqJFi7jiiisYO3ZsvsMZNDLBfg8qKyv52te+xsqVK/nXv/6V73AkSZKOSG+++SbPPfccn/rUp5gyZUq+w5GkQfPzn/8c2H0mW5IkSdrXDTfcgK7r/PjHP853KINKJtjv0XXXXYff7+fLX/4ytm3nOxxJkqQjzhe+8AUAbrrppvwGIkmDrLa2lk996lM888wzPPPMM/kOR5Ik6Yjzhz/8gRUrVvCVr3yFoqKifIczqGSC/R6VlpbyrW99i61bt/LQQw/lOxxJkqQjynPPPceSJUu4+uqrGTVqVL7DkaRBd+eddwLw9a9/XR4fkyRJ2oNlWdxwww34fD6+973v5TucQScT7EPw1a9+lUgkwjXXXIPjOPkOR5Ik6Yjgui5f+MIX0DSN66+/Pt/hSNJhUVhYyJe+9CWWL1/O3//+93yHI0mSdMS477772L59OzfddBM+ny/f4Qw6mWAfglAoxI033khbW1vfk2tJkqTh7g9/+APr16/ny1/+MiNGjMh3OJJ02Nxyyy0oisJ1112HZVn5DkeSJCnvkskk3/nOd4hEIlx77bX5DuewkAn2IfrqV79KTU0N3/ve9zBNM9/hSJIk5ZXrutx0001EIhG+//3v5zscSTqsCgoK+P73v099fb08PiZJkgTce++9tLW1cccdd2AYRr7DOSxkgn2IFEXhpptuoquri6uvvjrf4UiSJOXVTTfdxJYtW/j2t7+N3+/PdziSdNh98YtfpKioiKuuuorOzs58hyNJkpQ327dv51vf+hZ1dXV88pOfzHc4h41MsAfAZz7zGS666CKWL19OOp3OdziSJEl588ILL7BgwQK+/vWv5zsUScqLcDjMb3/7WyorK3niiSfyHY4kSVLe/Otf/2LUqFH84Q9/QNf1fIdz2Ih8dLoUQrQC2w77C783pUBbvoM4DIbD+xwO7xGGx/scDu8Rhsf7PJre40jXdcvyHcRQI9cER6Th8D6Hw3uE4fE+h8N7hOHxPo+297jfdUFeEuyjiRBiseu6c/Idx2AbDu9zOLxHGB7vczi8Rxge73M4vEdp6Bguf1+Hw/scDu8Rhsf7HA7vEYbH+xwq71GWiEuSJEmSJEmSJEnSAJAJtiRJkiRJkiRJkiQNAJlgv7Nf5zuAw2Q4vM/h8B5heLzP4fAeYXi8z+HwHqWhY7j8fR0O73M4vEcYHu9zOLxHGB7vc0i8R3kGW5IkSZIkSZIkSZIGgNzBliRJkiRJkiRJkqQBIBPsgySE+KIQ4i0hxBohxI/zHc9gEUJ8VQjhCiFK8x3LYBBC3Nbzc1wphPg/IURhvmMaKEKIc4UQ64UQm4QQ38x3PINBCFErhHheCLG257/FL+c7psEihFCFEMuEEI/lO5bBIoQoFEL8ree/yXVCiPn5jkmSDsZwWRPA0F4XyDXB0U2uCYaWobQmkAn2QRBCnAZcDBzjuu5U4PY8hzQohBC1wNnA9nzHMoieBqa5rjsD2AB8K8/xDAghhArcDZwHTAE+KoSYkt+oBoUFfNV13SnAPOC/h+j7BPgysC7fQQyyu4D/uK47CTiGof9+pSFguKwJYFisC+Sa4Ogm1wRDy5BZE8gE++BcBfzIdd0sgOu6LXmOZ7DcAXwDGLIH813Xfcp1XavnXxcCNfmMZwAdD2xyXbfedd0c8DDeAnBIcV23yXXdpT3/P473y7c6v1ENPCFEDXABcH++YxksQogCYAHwGwDXdXOu63blNShJOjjDZU0AQ3xdINcERze5Jhg6htqaQCbYB2cCcLIQ4g0hxItCiOPyHdBAE0JcDOxwXXdFvmM5jD4DPJHvIAZINdCwx783MgRvMnsSQowCZgJv5DmUwXAn3qLWyXMcg2k00Ar8tqfs7X4hRCjfQUnSQRjyawIYlusCuSY4isk1wVFvSK0JtHwHcKQQQjwDjNjPh67H+z4V45WfHAf8RQgxxj3KWrC/w3v8Nl4Z2FHv7d6n67r/7Pmc6/FKix48nLFJA0MIEQb+DnzFdd3ufMczkIQQFwItrusuEUKcmudwBpMGzAK+6LruG0KIu4BvAjfmNyxJGh5rAhge6wK5Jhj65JpgSBhSawKZYPdwXffMA31MCHEV8EjPzfNNIYQDlOI9aTlqHOg9CiGm4z05WiGEAK9EaqkQ4njXdXcdxhAHxNv9LAGEEJ8GLgTOOBoXRAewA6jd499rev5syBFC6Hg30gdd130k3/EMghOB9wkhzgf8QFQI8UfXdT+e57gGWiPQ6Lpu727D3/BuppKUd8NhTQDDY10g1wSAXBMczeSa4CgkS8QPzj+A0wCEEBMAA2jLZ0ADyXXdVa7rlruuO8p13VF4f8lnHW030YMhhDgXr8zmfa7rpvIdzwBaBIwXQowWQhjAZcCjeY5pwAlvpfcbYJ3ruj/NdzyDwXXdb7muW9Pz3+JlwHND8EZKz++XBiHExJ4/OgNYm8eQJOlg/YMhvCaA4bMukGuCo5tcEwwdQ21NIHewD84DwANCiNVADvjUEHrKOdz8AvABT/c8lV/ouu6V+Q3p0LmuawkhrgaeBFTgAdd11+Q5rMFwIvAJYJUQYnnPn33bdd3H8xeSdAi+CDzYswCsB/4rz/FI0sGQa4KhQ64Jjm5yTTC0DJk1gZD3BEmSJEmSJEmSJEk6dLJEXJIkSZIkSZIkSZIGgEywJUmSJEmSJEmSJGkAyARbkiRJkiRJkiRJkgaATLAlSZIkSZIkSZIkaQDIBFuSJEmSJEmSJEmSBoBMsCVJkiRJkiRJkiRpAMgEW5IkSZIkSZIkSZIGgEywJUmSJEmSJEmSJGkA/H9BN53FU5eYNwAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_contour(logprob, orbits=samples, weights=weights)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Ellipsis\n", + "\n", + "We now create and use a bijection given by an ellipsis using the `IntegratorState` class. The bijection must have as inputs the potential and kinetic energy functions, which are the negative log densities of our target posterior and the auxiliary distribution used for the momentum variable. In the case of our banana density, we are targeting the \"posterior\" $N(x_1|0, 8)N(x_2|1/4x_1^2, 1)$ and using a standard normal distribution for our momentum variable, hence our potential and kinetic energies are $1/2\\left(x_1^2/8 + \\left(x_2 - 1/4x_1^2\\right)^2\\right)$ and $1/2v^Tv$, respectively. However, the orbit we build now is independent of these two energies and moves around an ellipsis given by \n", + "\n", + "$$ \n", + "x(t) = x(0) \\cos(t) + v(t) \\sin(t) \\\\\n", + "v(t) = v(0) \\cos(t) - x(t) \\sin(t),\n", + "$$\n", + "\n", + "which returns to its initial position every $t=2\\pi$ radians. The `step_size` for this orbit is set to cover the entire ellipsis. This ellipsis actually targets a potential and kinetic energy given by the product measure of two standard normal distributions, hence its inefficiency at exploring the real target measure.\n", + "\n", + "The bijection must output a function which takes as input an `IntegratorState`, composed of a position, momentum, potential energy (negative log density of our target evaluated at position) and the gradient of the potential energy, and a step size; and outputs a proposed `IntegratorState`. Even if the dynamics of our bijection are independent of the real potential energy, we need to return the potential energy at the proposed position for the computation of the sampler's weights. But, as our dynamics are gradient-free, we can return the same gradient as the previous state to avoid unnecessary computations." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "def elliptical_bijection(potential_fn, kinetic_energy_fn):\n", + " def one_step(\n", + " state: integrators.IntegratorState, step_size: float\n", + " ) -> integrators.IntegratorState:\n", + " _position, _momentum, _, grad = state\n", + "\n", + " position = jax.tree_util.tree_multimap(\n", + " lambda position, momentum: position * jnp.cos(step_size)\n", + " + momentum * jnp.sin(step_size),\n", + " _position,\n", + " _momentum,\n", + " )\n", + "\n", + " momentum = jax.tree_util.tree_multimap(\n", + " lambda position, momentum: momentum * jnp.cos(step_size)\n", + " - position * jnp.sin(step_size),\n", + " _position,\n", + " _momentum,\n", + " )\n", + "\n", + " return integrators.IntegratorState(\n", + " position,\n", + " momentum,\n", + " potential_fn(position),\n", + " grad,\n", + " )\n", + "\n", + " return one_step\n", + "\n", + "\n", + "step_size = 2 * jnp.pi / period" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 25 ms, sys: 94 µs, total: 25.1 ms\n", + "Wall time: 29.6 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "init_fn, ellip_kernel = orbital(\n", + " logprob, step_size, inv_mass_matrix, period, bijection=elliptical_bijection\n", + ")\n", + "initial_state = init_fn(initial_position)\n", + "ellip_kernel = jax.jit(ellip_kernel)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cabezasg/.local/lib/python3.8/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n", + " warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.57 s, sys: 7.87 ms, total: 2.58 s\n", + "Wall time: 2.66 s\n" + ] + } + ], + "source": [ + "%%time\n", + "rng_key = jax.random.PRNGKey(0)\n", + "states = inference_loop(rng_key, ellip_kernel, initial_state, 10_000)\n", + "\n", + "samples = states.positions\n", + "weights = states.weights" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA+IAAAF1CAYAAABs5lCZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdd3hURRfA4d9sSU9IQu+9CQoK0pSm9KI0kSYgAuKHKIoFEBBBkS4CgooKVlBARCD03qt0pEovIY30ZMt8f2wS0gmwCe28zxNl986dO7vZ7J1z78wZpbVGCCGEEEIIIYQQOcNwrxsghBBCCCGEEEI8SiQQF0IIIYQQQgghcpAE4kIIIYQQQgghRA6SQFwIIYQQQgghhMhBEogLIYQQQgghhBA5SAJxIYQQQgghhBAiB0kgLsQ9ppTqqpRalcWyPZVSW7KxLdlavzMopc4qpRrd63YIIYQQd0LO+7dHzvviYSWBuHikKaW0UqpMqudGKqV+yak2aK1/1Vo3cUZdSqkNSqnezqhLCCGEEA5KqSFKqeWpnjuZwXOdMqtLzvtCCJBAXAghhBBCiFvZBNRRShkBlFIFATPwZKrnyiSUFUKITEkgLkQmlFINlFIXlVKDlFKBSqkrSqlXE7aVVEqFKaUMCY9nKaUCk+37s1JqYMK/cymlvk/Y/5JS6tNkJ+4Uw8KUUk2UUseVUjeUUjOUUhtTX+1WSk1USoUqpf5TSjVPeO4zoC4wXSkVqZSanvB8BaXUaqVUSEK9HZPVk1sp9bdSKlwptQsoncl74aaU+kUpFZzwuncrpfInbHtVKXVMKRWhlDqjlHo9nffwg2TvYRulVAul1ImEdg1NVn6kUmqBUur3hPr2KaWqZNAmg1JqsFLqdEK7/lBK+d+qvUIIIcRt2o0j8K6a8LgusB44nuq501rry3Lel/O+ELcigbgQt1YAyAUUBl4DvlJK+Wmt/wPCgScTytUDIpVSFRMe1wc2Jvx7DmDFcaX8SaAJkGYomVIqD7AAGALkxnGCr5OqWM2E5/MA44HvlVJKa/0RsBl4U2vtpbV+UynlCawGfgPyAZ2AGUqpxxLq+gqIBQoCvRJ+MtIj4X0omtC2fkBMwrZAoBXgA7wKfKGUeirZvgUANxzv4QhgFtANqIajEzFcKVUyWfkXgfmAf0Lb/1JKmdNp0wCgDY73uhAQmvCabtVeIYQQIsu01vHAThznehL+vxnYkuq5xLvhc5Dzvpz3hciEBOJC3JoFGKW1tmitA4BIoHzCto1AfaVUgYTHCxIel8RxcjqQcDW2BTBQax2ltQ4EvsBxckytBXBEa/2n1toKTAWupipzTms9S2ttA37EcTLN6IpvK+Cs1nq21tqqtf4HWAi8lHBlvj0wIqFdhxPqy+x9yA2U0VrbtNZ7tdbhAFrrZVrr09phI7AKx4k2+b6faa0twDwcnYkvtdYRWusjwFEg+dXvvVrrBQnlJ+M4mddKp039gI+01he11nHASKCDUsqUWXuFEEKIO7CRm0F3XRxB8OZUz22U837SvnLeFyITpnvdACHuMRuOoWbJmXF8mScKTjg5JooGvBL+vRF4AbiI4yr4BuAVHFebN2ut7Uqp4gl1XlFKJdZhAC6k055CyZ/XWmul1MVUZa4m2x6dUKcX6SsO1FRKhSV7zgT8DORN+HfydpzLoB4S9ikKzFNK+QK/4DgZWhKGyX0MlEt4bR7AoWT7Bid0IODm1elrybbHpHoNyd8De8J7UCiD17dIKWVP9pwNRwclw/Zm8hqFEEKIjGwC+icMhc6rtT6plLoG/JjwXOWEMnLel/O+ELckd8TFo+48UCLVcyXJ/MSU3EYcV4AbJPx7C/AMKYelXwDigDxaa9+EHx+tdaV06rsCFEl8oBxn2yLplMuITvX4ArAx2XF9E4avvQFcxzFsrmiy8sUyrNgxIuATrfVjOIbNtQK6K6VccVxtnwjk11r7AgGAyqiuLEhqk3LMwS8CXE6n3AWgearX56a1vpRRe++iTUIIIR5t23EMfe4DbAVIuON6OeG5ywnT1uS8f/vkvC8eORKIi0fd78AwpVSRhAQgjYDWOIaY35LW+iSOq7rdcJz4wnFc8W1PQiCutb6CY8jWJKWUT8JxSiul6qdT5TLg8YSkJiagP455Vll1DSiV7PFSoJxS6hWllDnh52mlVMWEK9V/AiOVUh4J88d6ZFSxUqqhUurxhKFt4ThGDdgBF8CVhBN8wlXyu12WpZpSql3CezAQR4dmRzrlvgY+Sxh1gFIqr1LqxVu0VwghhLhtWusYYA/wLo4h6Ym2JDy3KaGcnPdvn5z3xSNHAnHxqBsFbMNxEg3FkQSla8K8qazaiGMI1oVkjxWwL1mZ7jhOXEcTjrMAxxyvFLTWQcBLCe0IBh7DcdKPy2JbvsQxVypUKTVVax2B4+TYCceV5avAOBwnUIA3cQwNu4ojsczsTOoukNDucOBYwuv8OeEYbwF/JLy2LsDfWWxvRhYDLyfU9wrQLoOhZV8mHGuVUioCx0m7Zmbtvct2CSGEeLRtxJEEbUuy5zYnPJd82TI5798eOe+LR47SOvWIFiHE/SJheNZFHBcH1t/r9uQEpdRIHIlWut3rtgghhBA5Sc77Qjw65I64EPcZpVRTpZRvwhysoTjurqc3PEsIIYQQDzg57wvxaMpyIK6U+kEpFaiUOpzsuQlKqX+VUgeVUosSMhSmt+9ZpdQhpdR+pdQeJ7RbiIdZbeA0EIRjvnqbhHlpQghx35B+gRBOI+d9IR5BWR6arpSqh2P95J+01pUTnmsCrNNaW5VS4wC01h+ms+9ZoHrCPBghhBBCPOCkXyCEEELcuSzfEddabwJCUj23Ktn6yju4veUWhBBCCPGAkn6BEEIIceecOUe8F7A8g20aR3bDvUqpvk48phBCCCHuT9IvEEIIITJgckYlSqmPACvwawZFntVaX1JK5QNWK6X+TbiSnl5dfYG+AJ6entUqVKjgjCYKIYQQOWLv3r1BWuu897od95Kz+gXSJxBCCPEgy6xPcNeBuFKqJ9AKeF5nMOFca30p4f+BSqlFQA1SrrWYvOy3wLcA1atX13v2SA4XIYQQDw6l1Ll73YZ7yZn9AukTCCGEeJBl1ie4q6HpSqlmwAfAC1rr6AzKeCqlvBP/DTQBDqdXVgghhBAPLukXCCGEEFlzO8uXzQW2A+WVUheVUq8B0wFvHMPK9iulvk4oW0gpFZCwa35gi1LqALALWKa1XuHUVyGEEEKIHCX9AiGEEOLOZXlouta6czpPf59B2ctAi4R/nwGq3FHrhBBCCHFfkn6BEEIIceecmTVdCCGEEEIIIYQQtyCBuBBCCCGEEEIIkYMkEBdCCCGEEEIIIXKQBOJCCCGEEEIIIUQOkkBcCCGEEEIIIYTIQRKICyGEEEIIIYQQOUgCcSGEEEIIIYQQIgdJIC6EEEIIIYQQQuQgCcSFEEIIIYQQQogcJIG4EEIIIYQQQgiRgyQQF0IIIYQQQgghcpAE4kIIIYQQQgghRA6SQFwIIYQQQgghhMhBEogLIYQQQgghhBA5SAJxIYQQQgghhBAiB0kgLoQQQgghhBBC5CAJxIUQQgghhBBCiBwkgbgQQgghhBBCCJGDJBAXQgghhBBCCCFykATiQgghhBBCCCFEDpJAXAghhBBCCCGEyEESiAshhBBCCCGEEDlIAnEhhBBCCCGEECIHSSAuhBBCCCGEEELkIAnEhRBCCCGEEEKIHCSBuBBCCCGEEEIIkYMkEBdCCCGEEEIIIXKQBOJCCCGEEEIIIUQOkkBcCCGEEEIIIYTIQRKICyGEEEIIIYQQOUgCcSGEEEIIIYQQIgdJIC6EEEIIIYQQQuQgCcSFEEIIIYQQQogcJIG4EEIIIYQQQgiRg24rEFdK/aCUClRKHU72nL9SarVS6mTC//0y2LdHQpmTSqked9twIYQQQtw70icQQggh7tzt3hGfAzRL9dxgYK3WuiywNuFxCkopf+BjoCZQA/g4o5OzEEIIIR4Ic5A+gRBCCHFHbisQ11pvAkJSPf0i8GPCv38E2qSza1NgtdY6RGsdCqwm7clbCCGEEA8I6RMIIYQQd84Zc8Tza62vJPz7KpA/nTKFgQvJHl9MeC4NpVRfpdQepdSe69evO6F5QgghhMgh0icQQgghssCpydq01hrQd1nHt1rr6lrr6nnz5nVSy4QQQgiRk6RPIIQQQmTMGYH4NaVUQYCE/wemU+YSUDTZ4yIJzwkhhBDi4SF9AiGEECILnBGI/w0kZjztASxOp8xKoIlSyi8hIUuThOeEEEII8fCQPoEQQgiRBbe7fNlcYDtQXil1USn1GjAWaKyUOgk0SniMUqq6Uuo7AK11CDAa2J3wMyrhOSGEEEI8gKRPIIQQQtw55ZjCdX+qXr263rNnz71uhhBCCJFlSqm9Wuvq97odDxvpEwghhHjQZNYncGqyNiGEEEIIIYQQQmROAnEhhBBCCCGEECIHSSAuhBBCCCGEEELkIAnEhRBCCCGEEEKIHCSBuBBCCCGEEEIIkYMkEBdCCCGEEEIIIXKQBOJCCCGEEEIIIUQOkkBcCCGEEEIIIYTIQRKICyGEEEIIIYQQOUgCcSew2+18+umntGzZknHjxhEREXGvmySEEA8srTV//fUXL730Ej169ODy5cv3uklC3Jbt27fTpk0b+vTpw759++51c4QQ4oF29uxZ3n33XVq2bMnChQvvdXOcRgJxJ1i4cCHDhw/nxIkTDB48mIoVK7J+/fp73SwhhHjgXL16lebNm9O2bVu2b9/OH3/8wcCBA+91s4TIMqvVSrt27ZI+v9WrV+e9994jPj7+XjdNCCEeKFprpk2bRsWKFfnqq684dOgQnTp14tKlS/e6aU4hgbgTbNq0CR8fH44fP8727dvx8vKiUaNGTJ8+/V43TQghHhj79++nevXqbNq0ienTp3Pu3Dk6d+7Mpk2b7nXThMiyU6dOcfXqVSZMmMD58+fp168fkyZNolGjRgQHB9/r5gkhxAMhLi6O7t2789Zbb/H8889z+vRpFi1ahNVqZceOHfe6eU4hgbgTxMXF4e7ujsFgoFatWuzevZvWrVszYMAAhgwZgtb6XjdRCCHuaxs2bKBevXoYDAa2b99O//79MRqNeHp6EhcXd6+bJ0SWJX5ePT09yZUrFzNmzODXX39l165dPPvss1y8ePEet1AIIe5vkZGRtGzZkl9++YVRo0axZMkSihQpgqenJ8BD0y+QQNwJ8ubNS1BQEFarFQBvb28WLlxI3759GTt2LAMHDpRgXAghMrBq1SqaN29O0aJF2b59O1WqVEnadu3aNfLly3cPWyfE7cmbNy/g+Owm6tKlCytXruTSpUvUq1ePc+fO3avmCSHEfS08PJxmzZqxYcMG5syZw/Dhw1FKATe/VxO/Zx90Eog7Qbly5bDZbJw+fTrpOaPRyNdff80777zD1KlTeffddyUYF0KIVNasWcMLL7xA+fLl2bhxI4ULF06x/dixY5QtW/YetU6I21ewYEG8vLz4999/Uzxfv3591q5dS2hoKA0bNuTChQv3qIVCCHF/ioyMpEWLFuzcuZN58+bRo0ePFNsTv1fLlSt3L5rndBKIO0Hi3Zu9e/emeF4pxaRJkxg4cCBTpkzho48+uhfNE0KI+5LFYuGNN96gfPnyrF27ljx58qTYHh0dzbFjx1LcIRfifqeU4oknnkjTJwB4+umnWb16NcHBwTz//PNcvXr1HrRQCCHuTzNnzmTHjh3MnTuXDh06pNm+d+9efH19KVas2D1onfNJIO4ElStXxsvLiy1btqTZppRi8uTJvP7663z++eeMHz/+HrRQCCHuP2azmZUrV7J69Wpy586dZvvOnTux2WzUqVPnHrROiDtXp04d9uzZQ2xsbJpt1atXZ/ny5Vy6dImmTZsSFhaW8w0UQoj70KBBg9i6dWu6QTjAli1bqF27dtJQ9QedBOJOYDKZePbZZzNcskwpxVdffUWnTp348MMP+eGHH3K4hUIIcW8EBQVlOi2nVKlSGc4BX79+PQaDgWeffTa7midEtqhfvz7x8fFs27Yt3e116tThr7/+4tixY7Ru3ZqYmJgcbqEQQuS8W/UJDAYDNWvWTHfb1atXOXbsGPXr18+u5uU4CcSdpHHjxvz777+cP38+3e1Go5Eff/yRJk2a0LdvX5YsWZLDLRRCiJw1bdo0Ro8eTXh4OFrr286TsXLlSmrWrEmuXLmyqYVCZI/69esnjfjISOPGjfnll1/YunUrnTt3Tkr4KoQQD6O77ROsWrUKcHx3PiwkEHeSFi1aALB06dIMy7i4uLBw4UKefPJJXn755YdmDTwhhEjtn3/+YebMmfzvf/8jV65cKKVSDCW71Qn42rVr7N69m+bNm2d3U4VwOm9vb+rWrZtpnwCgY8eOTJ06lcWLF/Pmm29KUlchxEPpbvsE4IixChQoQNWqVbOxpTlLAnEnKV++PGXKlOHvv//OtJyXlxfLli2jUKFCtG7dmlOnTuVQC4UQIuccPHiQVq1aUb58efbs2cPbb7/N+++/z8iRI4mLi7vl/K6lS5eitaZ169Y51GIhnKt169YcPXo0xYoq6XnzzTcZPHgw33zzDZ9//nkOtU4IIXLO3fYJ4uLiWLFiBa1atcJgeHjC14fnldxjSinatm3LunXrbpl4JV++fCxfvhytNc2aNSMoKChnGimEEDmkZMmSBAYGAjBq1CiKFy/OU089RWBgIJMnTwYyvwK+cOFCSpQoIRnTxQOrTZs2gOOzfCtjxoyha9eufPTRR/zyyy/Z3DIhhMhZd9snWLt2LREREbRt2zZH2ptTJBB3og4dOmCxWFi8ePEty5YtW5YlS5Zw6dIlXnjhBUnUIoR4qNSuXZvQ0FCeffZZSpQowbvvvkvHjh154YUXuHTpEkCGV8BDQ0NZs2YNHTp0eGgyo4pHT4kSJahevToLFiy4ZVmlFD/88AMNGzakV69ebNiwIfsbKIQQOeRu+gQAf/zxB7ly5eL555/PqSbnCAnEnejpp5+mRIkSzJs3L0vla9euzS+//MKOHTvo0aMHdrs9m1sohBDZZ/v27fz666/89ttvmM1mfvjhB+rVq8e3337L6tWrMRqNnD17losXL2Zaz6JFi7BYLHTs2DGHWi5E9ujYsSO7d+++5fB0uJlHpkyZMrRt25Z///03B1oohBDZw1l9gri4OP766y/atGmDq6trDrU+Z0gg7kRKKTp16sTq1auThl/cSvv27ZkwYQLz589nyJAh2dxCIYTIHkePHqVVq1ZcunSJKVOmMGDAAHbu3Em/fv2YPHkyb775Ju+++y7Tpk3jiy++yLSu3377jdKlS1O9evUcar0Q2ePll18GYO7cuVkq7+fnR0BAAC4uLrRo0SLLfQkhhLifOLNPEBAQwI0bN+jcuXMOtT7nqPs5Q2f16tX1nj177nUzbsvhw4d5/PHHmTp1KgMGDMjSPlpr/ve///H1118za9Ysevfunc2tFEII5xo3bhwWi4Vhw4YRGxvL1KlTOXXqFI0bN+all17i8uXLREVF4e7uTpEiRTKs59KlSxQtWpThw4fzySef5OArcB6l1F6ttVxFcLIHsU8AjqXMrl27xrFjx7I81WLXrl00aNCAKlWqsG7dOtzd3bO5lUII4TzO6hOA46bl1q1buXjxIiaTKYdegfNk1ieQO+JOVrlyZapWrcpPP/2U5X2UUkybNo1mzZrRr18/1qxZk40tFEII5ytTpgwrVqzg+PHjuLm58cEHH1C/fn2mT5/OunXrKFSoEGXLlr3lCffXX39Fa80rr7ySQy0XInu98sorHD9+nN27d2d5nxo1avDLL7+wc+dOevbsKVPXhBAPFGf1CUJCQli6dCmdO3d+IIPwW5FAPBv06NGDPXv2cOTIkSzvYzKZ+P3333nsscfo0KEDR48ezcYWCiGEc7Vv355nnnmG1atXc/z4cQC6du1Kz549mT9/fpYCCa01c+bMoU6dOpQpUya7myxEjnjppZdwc3Njzpw5t7Vfu3btGDduHH/88QfDhw/PnsYJIUQ2cEafABzTeuLj4+nRo0d2NveekUA8G3Tp0gWTycTs2bNvaz8fHx+WLl2Km5sbLVu2lLlhQoj72qZNm/jmm28YPXo0165do3v37hw8eJA//viD5cuXA+Dq6sqlS5eyNCR3586dHDt2jJ49e2Zzy4XIObly5aJdu3bMnTv3tldIee+99+jTpw9jxoy57UBeCCFykrP7BAA//PADVapUoWrVqtnY8ntHAvFskC9fPlq3bs1PP/1EfHz8be1brFgxlixZwrVr12jTpg2xsbHZ1EohhLhzly5d4tVXXyU0NJSwsDCqVKnC6dOnGTJkCEop5s6dS+3atZk8eTLDhw/P0kn3+++/x8PDIynBlRAPi169ehEWFsaiRYtuaz+lFF999RWNGjWib9++bNy4MZtaKIQQdy47+gT79+9n37599OrVKwdewb0hydqySUBAAC1btmT+/Pl06NDhtvdfuHAhHTp0oFOnTvz6668YDHLNRAhx/5g5cybbtm3j559/BmDdunUMHDiQDh06MGLECMBxEvXx8aFUqVK3rC8yMpKCBQvSoUOH2x5NdL+RZG3Z40HuE9jtdsqWLUuxYsVYv379be8fFhZG7dq1uXbtGjt27KBcuXLZ0EohhLgzzu4TALz55pt89913XL58GX9//2xre3aTZG33QNOmTSlatCizZs26o/3bt2/P2LFjmTdvHiNHjnRu44QQ4i49//zzmM1mzp49i9aa5557jnXr1rF48eKk+axVq1bN8gl37ty5REZG0qdPn+xsthD3hMFg4LXXXmPDhg2cOHHitvf39fVl2bJlGI1GWrZsSXBwcDa0Uggh7oyz+wTR0dH88ssvtG/f/oEOwm9FAvFsYjQaee2111i1ahWnT5++ozo++OADevXqxejRo5OuMAkhxP0gX758KKWYNGkSkZGRAOTJk4fffvuNiIiI267vm2++oXLlytSuXdvZTRXivvDqq69iMpn49ttv72j/UqVKsXjxYi5cuEC7du2Ii4tzcguFEOLOOLtP8Pvvv3Pjxg369u3r7KbeV+46EFdKlVdK7U/2E66UGpiqTAOl1I1kZUbc7XEfBL1798ZoNN7xSVcpxcyZM2nYsCGvvfYamzdvdnILhRDi9iROZ/L19WX69OkEBgby0ksvsXPnzqS1ktevX39b+S327NnD3r17ef3117OcwEXcv6RfkL6CBQvy4osvMnv27DvO/1KnTh1mz57Npk2b6Nu3L/fz9EIhxMMvO/oEAF9//TUVK1akXr162dHs+4ZT54grpYzAJaCm1vpcsucbAO9prVvdTn0P8nywRO3atWPz5s1cuHABNze3O6ojNDSU2rVrc/36dXbs2EHZsmWd3EohhMiczWYjPj4ed3f3NNvGjRvH/v37AThx4gRffPHFbZ08e/XqxR9//MGlS5fIlSuXs5p8z8gc8Zuc2S94GPoEa9eupVGjRvz444907979jusZNWoUH3/8MaNHj2bYsGFObKEQQtxadvYJ9u7dS/Xq1fnyyy956623nNXkeyazPoGzA/EmwMda62dSPd+ARzQQX7NmDY0bN+ann37ilVdeueN6Tp8+Tc2aNfH392fHjh0P9XwJIcT9Z8CAAZw5c4annnqKsmXL0rZtW7y9vZO2X7x4EQ8PD0JCQm5rDfDg4GCKFClC9+7d+eabb7Kj6TlOAvGbnNkveBj6BFprKlasSK5cudi5c+dd1dO9e3d++eUX5s6dS6dOnZzYSiGEyFx29QkAXnvtNebNm8fly5cf+ovzzp4j3gmYm8G22kqpA0qp5UqpSk4+7n3r+eefp0KFCkybNu2u6ildujR//fUX586dk7lhQogcNXjwYM6cOcPEiRPx8/Pj0KFDjBw5kn///TepjKenJ/7+/pQuXfq26v7hhx+IjY2lf//+zm62uD9IvyAZpRT9+/dn165d7N69+67q+e6776hbty49e/Zk27ZtTmylEEJkLDv7BMHBwfz222+88sorD0UQfitOC8SVUi7AC8D8dDbvA4prrasA04C/Mqmnr1Jqj1Jqz/Xr153VvHsm8aS7e/fuu7r6DfDss88ye/ZsNm7cKHPDhBA5xsvLi759+1KxYkX69+9P+/bt8fX15ZdffiE+Pp4VK1awfPlygNua422z2ZgxYwb169fniSeeyK7mi3vEGf2Ch61PANCjRw+8vLzu+gK9q6srixYtomjRorz44ot3nBhWCCFuR3b1CQC+++47YmNjefPNN7Oj6fcdZ94Rbw7s01pfS71Bax2utY5M+HcAYFZK5UmvEq31t1rr6lrr6nnz5nVi8+6dHj164O3tzdSpU++6ri5dujBq1Ch++uknRo8e7YTWCSFE5kqVKsVHH33E1q1bcXV1pVatWjRv3pyDBw+yZcsWypQpQ/PmzW+73iVLlnD27FkGDBiQDa0W94G77hc8jH0CHx8fevTowe+//861a2nemtuSO3duAgICsNvttGzZktDQUCe1Uggh0pddfQKr1cqMGTNo0KABlStXzoaW33+cGYh3JoPhZ0qpAirhkohSqkbCcR+ZRTC9vb3p1asX8+fP5/Lly3dd37Bhw+jevTsff/wxv/76qxNaKIQQGevSpQtvvvkms2bNYsGCBQBUr16dF198kWXLllGmTBn8/Pxuu94vv/wy6W6eeChJvyADAwYMID4+nq+//vqu6ypbtiyLFi3iv//+o127dsTHxzuhhUIIkb7s6hP89ddfnD9/nrffftvZTb5vOSUQV0p5Ao2BP5M9108p1S/hYQfgsFLqADAV6KQfsXHVAwYMSLrSc7eUUsyaNYsGDRrQq1cvNm3a5IQWCiFExrp160aDBg1YtmwZ/fr14/Tp03z//fcULFjwjuo7cOAAGzZsYMCAAZhMJie3Vtxr0i/IXPny5WnRogUzZ850Ss6XevXq8cMPP7Bhwwb69OkjU9eEENnK2X0CgClTplCyZElat27txJbe35yaNd3ZHoYMqcm1adOGLVu2cP78eTw8PO66vtDQUOrUqcO1a9fYvn075cuXd0IrhRAifVarlRMnTjBlyhTi4+PJmzcvEyZMuKO6evbsyfz587l48eIdXTm/n0nW9OzxsPUJVq9eTZMmTZg9ezY9e/Z0Sp2jR49mxIgRjBw5ko8//tgpdQohRHqc2SfYvXs3NWrU4IsvvmDgwIHObeg9lmPLlznbw3bS3bhxIw0aNODrr7/m9ddfd0qdZ86coVatWnh7e7N9+3by5cvnlHqFECLx/JBeshWLxYLZbL6jeq9cuULx4sXp27cv06dPv6s23o8kEM8eD1ufQGvNE088gcFgYP/+/bed1CijOnv16sWcOXOYM2cOPXr0cEJLhRDCQWvt9D4BQOfOnQkICODChQv4+PjcTRPvOzm5fJnIRL169ahWrRpffPEFdrvdKXWWKlWKJUuWcOXKFV544QWio6OdUq8QQnz66ae88cYb2Gy2NNvu5oQ7ffp0rFbrQ3fVW4jboZTi3Xff5eDBg6xZs8ZpdX7zzTc8//zz9OnTh7Vr1zqlXiGEOHHiBDVr1uT48eNptt1Nn+D8+fPMnz+fPn36PHRB+K1IIJ6DlFIMGjSI48ePs2zZMqfVW7NmTX777Td27dpF165d0+00CyHE7fj5558ZMWIE0dHRGAzOO1VERUUxc+ZM2rRpQ5kyZZxWrxAPoi5dulCgQAEmTZrktDpdXFxYuHAh5cqVo127dhw+fNhpdQshHk3Xr1+nRYsW/Pfff07P6zJlyhQA3nrrLafW+yCQQDyHdejQgWLFit3xHIqMtGnThilTpvDXX3/xzjvvSKIWIcQdW7t2Lb169eK5557ju+++c8qQ2UQ//PADoaGhvPfee06rU4gHlaurKwMGDGDlypUcPHjQafXmypWLgIAAPD09ad68OZcuXXJa3UKIR0t0dDStW7fm0qVLLFmyhNKlSzut7rCwMGbNmsXLL79MsWLFnFbvg0IC8RxmNpt555132Lx5Mzt27HBq3W+99RbvvPMO06ZNY/LkyU6tWwjxaDh06BDt2rWjQoUKLFy4EBcXF6fVbbVamTx5MnXq1KFOnTpOq1eIB1m/fv3w9PR0+gX6YsWKERAQQFhYGC1btiQ8PNyp9QshHn42m40uXbqwa9cu5s6dS61atZxa/9dff01kZOQje3FeAvF7oHfv3vj5+TF+/Hin1z1x4kQ6dOjAe++9x++//+70+oUQD68LFy7QvHlzvLy8CAgIwNfX16n1z58/n7Nnz/LBBx84tV4hHmT+/v706dOHefPmcf78eafWXbVqVRYsWMCRI0do3769rDEuhMgyrTVvvfUWixcvZurUqbRp08ap9cfGxvLll1/SuHFjnnzySafW/aCQQPwe8PLyon///vz111/pJjy4GwaDgZ9//pm6devSvXt3Nm7c6NT6hRAPp7CwMFq0aEFERATLly+naNGiTq1fa824ceOoUKHCI7VGqBBZ8c477wBky2i2pk2bMmvWLNasWUPv3r1l6poQIkvGjRvHjBkzeP/993nzzTedXv/PP//M1atX+fDDD51e94NCAvF7ZMCAAbi6ujJu3Din1+3m5sZff/1F6dKlefHFFzl06JDTjyGEeHjExsbSpk0bjh8/zp9//skTTzzh9GOsWLGCAwcO8MEHHzg1+ZsQD4NixYrRpUsXZs2aRVBQkNPr79mzJ6NHj+bnn39m6NChTq9fCPFw+fnnnxkyZAidO3dm7NixTq/fZrMxfvx4qlWrxnPPPef0+h8U0hu6R/Lly8drr73GL7/8woULF5xev7+/PytWrEhK1OLs4W5CiIeD3W5PGj3z448/8vzzz2fLcT7//HOKFClC165ds6V+IR50H374IdHR0UybNi1b6v/oo494/fXXGTt2LNOnT8+WYwghHnwrV65MStg6e/bsbLl4vmDBAk6dOsXgwYOdmhD2QSOB+D30/vvvY7fbnbpsSXLFihVj+fLlRERE0KxZM0JCQrLlOEKIB5PWmoEDBzJ//nwmTpxI586ds+U4W7duZfPmzbz33ntOTf4mxMPkscceo02bNkybNo2IiAin16+UYvr06bz44ou89dZbzJ8/3+nHEEI82Pbs2UP79u2pVKkSixYtwtXV1enH0FozduxYypcvT9u2bZ1e/4NEAvF7qHjx4nTt2pVvv/2W69evZ8sxnnjiCRYvXszp06dp3bo10dHR2XIcIcSDZ+zYsUybNo13332XQYMGZdtxxowZQ548eejdu3e2HUOIh8GQIUMIDQ3l66+/zpb6TSYTc+fOpU6dOnTr1o3169dny3GEEA+ekydP0qJFC/LkycPy5cvx8fHJluMsX76c/fv38+GHH2I0GrPlGA8KCcTvsSFDhhAbG5u0mH12aNCgAb/++ivbt2+nU6dOWK3WbDuWEOLB8MMPPzB06FC6dOni9GWTkvvnn38ICAjgnXfewdPTM9uOI8TDoEaNGjRq1IhJkyYRExOTLcdwd3fn77//pkyZMrRp04b9+/dny3GEEA+OK1eu0LRpU7TWrFq1ioIFC2bLcbTWfPbZZxQrVoxu3bplyzEeJBKI32MVKlSgffv2TJ8+nbCwsGw7TocOHfjqq69YsmQJffr0kaypQjzCFi9eTJ8+fWjSpEm2zf9K9Nlnn5ErVy769++fbccQ4mEybNgwrl27xvfff59tx/D392flypXkypWLZs2acfr06Ww7lhDi/hYWFkazZs0IDAwkICCAcuXKZduxNmzYwLZt2/jwww8xm83ZdpwHhQTi94Fhw4YRHh7O1KlTs/U4b7zxBh9//DFz5sx5pJcKEOJRtmnTJl5++WWqV6/OwoULs3XO9pEjR1i4cCEDBgwgV65c2XYcIR4m9erV49lnn2XcuHHExcVl23GKFCnCypUrsVqtNGnShKtXr2bbsYQQ96eYmBheeOEFjh07xp9//snTTz+drccbPXo0BQsWpFevXtl6nAeFBOL3gSpVqvDCCy8wZcoUwsPDs7zfuXPnmDlzJgsXLsxyYpePP/6Y//3vf0yYMIHx48ffaZOFEA+gf/75h9atW1OyZEmWLVuGl5fXLfex2+1s2bKFr776io0bN97W8T799FO8vLwYOHDgHbZYiEePUorhw4dz8eJF5syZk+X9LBYLf/zxB99++y3Hjx/P0j4VK1YkICCAa9eu0bRp02wdmSeEuL9YLBZefvlltmzZws8//0yTJk2ytN+lS5eYM2cOc+bMua3Eklu2bGH9+vW8//77uLm53WmzHy5a6/v2p1q1avpRsXv3bg3ozz77LEvlL168qP38/DSgAe3j46M/+eQTHR0dfct9rVar7tSpkwb0rFmz7rbpQogHwIkTJ3S+fPl00aJF9fnz57O0z+LFi/Vjjz2W9D0D6J9++ilL+x47dkwrpfSHH354N81+IAF79H1wDn3Yfh6lPoHdbtc1a9bUxYsX13FxcVnap1evXin+Vps1a6YPHjyYpX1XrVqlzWazfuaZZ3RUVNTdNF0I8QCw2Wz6lVde0YD+6quvsrTPxYsXdffu3bXRaEz6nqlVq5a22+1Z2r9Ro0Y6X758j9x3TGZ9gnt+Ys3s51E66WqtdYsWLbS/v78ODw+/Zdnp06drQO/cuVNv3LhRt23bVgO6dOnSesuWLbfcPy4uTjdr1kwbDAY9f/58ZzRfCHGfunDhgi5WrJjOkyeP/vfff29ZPjAwULdr104DukKFCvrHH3/UFy9e1I8//riuX79+lo7ZpUsX7eHhoQMDA++y9Q8eCcSlT+AMAQEBGtDffvvtLcvGxsZqs9ms+/btq0+fPq0/++wz7efnp00mkx4+fLiOj4+/ZR3z58/XSindrFmzLAf/QogHj91u12+99ZYG9OjRo7NU/ttvv9Xe3t7a1dVVv/POO/rw4cN68uTJGtCHDx++ZR1bt27VgJ4wYYIzXsIDRQLxB8SuXbs0oMeMGXPLsiNGjNBAiqtQa9eu1SVLltQGg0HPmzfvlnVERkbqZ555RpvNZr18+fK7arsQ4v4UGBioK1SooH18fPTevXtvWX7dunW6QIEC2sXFRY8dOzZFB759+/a6YsWKt6zj33//1QaDQb///vt31fYHlQTi0idwBrvdrp9++mldokSJWwbS165d04CeNm1a0nNBQUG6e/fuGtANGzbM0l2rWbNmaUB37NhRW63Wu34NQoj7T2IM8c4779zyeyE0NDTpwvxzzz2nT58+nbRt3bp1GtDr1q275TGbNGmi8+bNqyMjI++6/Q+azPoEMkf8PvL000/TokULJk6ceMs5F76+vgCEhIQkPffcc89x4MAB3nrrLRo1apTp/rGxsRiNRpYuXUqlSpVo164dmzdvvuvXIIS4f4SFhdG0aVPOnTvHkiVLeOqppzIsq7Vm/PjxNGrUCF9fX3bv3p0mq2lISEjSd09mRo0ahZubG++9954zXoYQjySlFCNHjuTs2bO3nCuemAwxeZ8gd+7c/Pjjj8yfP5/XXnsNpVSmdcTGxtKtWzcmTJjAH3/8weuvv46jDymEeFhMnjyZUaNG0atXLyZOnJjp98KhQ4eoXr06f//9NxMnTmT16tWUKlUqaXvi982t+gXbtm1j1apVvP/++7KMaWoZRej3w8+jdvVb65t3xT/99NNMyy1dulQDetOmTbdVf1RUlF64cKFu2LChbt26tZ47d66+du2aLl++vPbx8dG7d+++m+YLIe4TkZGRuk6dOlka8RIdHa07d+6cdCcsIiIiTRm73a7z58+ve/bsmWldR48e1Uop/cEHH9xV+x9kyB1x6RM4id1u1zVq1NDFihW75XDxEiVK6Jdffvm2j5Fev2D48OEa0G+//XaW538KIe5v33zzjQb0Sy+9dMsRL3/++af29PTUBQsW1Fu3bk23zKhRozRwyym1jRo1emTvhmudeZ/gnp9YM/t5FE+6WmvdqlUr7evrq8PCwjIsExgYqAH9+eefZ7ne69ev68mTJ+sXXnhBL1y4UO/bt09XqlRJ//vvv/rChQu6RIkS2t/fXx86dMgZL0MIcY/ExMTo559/XhsMBr1gwYJMy165ckXXqFFDK6X0mDFjMux0nzp1SgN6xowZmdbXsWNH7eXlpa9fv37H7X/QSSAufQJnWrFihQb0zJkzMy3XsWNHXaRIkdsKnDPqFxw7dkwPHDhQA3rYsGF3+xKEEPfYL7/8opVSukWLFple1LPb7Xrs2LEa0DVq1NCXLl3KsGzTpk11pUqVMj3uxo0bNaAnTZp0x21/0Ekg/oDZt2+fBvTHH3+cabnKlSvr5557Lkt1xsbG6ilTpui+ffvqXbt2JT1ft27dpLvgp0+f1oUKFdL58+fPUkInIcT9Jy4uTrds2VIrpfSPP/6YadnDhw/r4sWLaw8PD71o0aJMy86cOVMD+tixYxmWOXjwoAb00KFD76TpDw0JxKVP4Ex2u13XqVNHFy5cWMfExGRYLit/o8ndql9gt9t17969b2tFFyHE/WfBggXaaDTqhg0bZrq6Unx8fNLffKdOnTItGxMToz08PPSbb76ZYRm73a7r16+vCxQo8MhlSk9OAvEHULt27bS3t7cOCgrKsMwHH3ygTSaTDg0NvWV9K1eu1E2bNtXbtm3TWjv+2ObNm6d79+6tbTZbUrljx47pvHnz6sKFC6dIyCCEuP9ZLJakpCpff/11pmXXr1+vc+XKpQsUKKD37Nlzy7pbtGihS5UqlendtjZt2mgfHx8dHBx8221/mEggLn0CZ1u7dq0G9JQpUzIsc+7cOQ3o8ePHZ6nOrPQLrFar7tq1qwb05MmT7/6FCCFy1JIlS7TZbNa1a9dOd9pZovDwcN20aVMN6I8++uiWI2uWLVumgUynvq1evVoDeurUqXfc/oeBBOIPoMOHD99yDd7t27dr4JZ3vaxWq3755Zf17NmztdaOO2Zr167V77zzjp4+fbq22+0p/uD279+v/f39dfHixfXZs2ed8nqEENnLarXqTp063bKzrrXW8+bN0y4uLrpixYpZ+hsPDQ3VZrNZDxo0KMMyu3fv1oD+5JNPbrvtDxsJxKVPkB0aNmyo8+XLl+k8y6eeekrXqFHjlnXdTr/AYrHo9u3b39Z6w0KIe2/lypXaxcVFV69ePdPprleuXNFPPvmkNhqNetasWVmqu2fPntrHx0fHxsamu91ut+uaNWvqokWLZljmUZFZn0Cypt+nKlWqROfOnZk6dSpXr15Nt0zNmjUpXrw4v/32W6Z1KaVwc3PDYrEAMHfuXAICAvDx8aFXr14opVJkTSxcuDCrV68mLCyM5557josXL2Za/95zoXy1/hR7z4Xe5qu8O1k57r1qm3hwPAyfEZvNxquvvsq8efMYO3Ysb7/9doZlv/zySzp16kTNmjXZunUrxYsXv2X9CxcuxGKx8PLLL2dY5qOPPiJ37twMHDjwTl6CEOIWPv30UwIDA5k6dWqGZV5++WV27drFqVOnMq3rdvoFJpOJqVOn8sILL9C/f39mzZqVad03oi3sOBPMvnOhxFpst/kqhRDOsG7dOl588UUqVqzIypUrk1ZWSO3kyZPUqVOH48eP8/fff9O7d+9b1h0bG8uiRYto27Ytrq6u6ZZZsmQJO3fuZPjw4RmWEaAcgfr9qXr16nrPnj33uhn3zKlTp6hQoQJvvPEG06ZNS7fMsGHD+Pzzz7lw4QKFChXKsK4jR47QuXNn/P39KVq0KA0aNOCTI7kwuLgDcPqzZkycOJFr166xd+9eevbsSaVKlWjcuDH58uVjw4YNFC5cOE29e8+F0vW7HcRb7biYDPzauxbVivs55w3IRFaOm1GZvedC2XEmmFqlcudIW8X96159fp3JbrfTu3dvZs+ezejRoxk2bFi65bTWDBkyhHHjxtGuXTt+/fVX3NzcsnSM+vXrc/XqVf799990lzrZuHEjDRo0YMKECbJkGaCU2qu1rn6v2/GwedT7BACtWrVi69at/Pfff+kuGXTx4kWKFSvG8OHD+eSTTzKtK3m/YE+QAdeij+NZsS4GF3fOjm2JzWZL0S/o2rUrf/31F8uXL+f777+nV69eaeq02zVrjl0DBfFWO/m8XalRMrezXr4QIgs2bNhAixYtKF26NOvXrydPnjzpltuzZw8tWrRAa01AQABPP/10lur/448/ePnll1m9enW6yyXb7XaqVq1KTEwMR48eTbEM6qMosz6B3BG/j5UpU4bXXnuNb775hrNnz6ZbpkePHtjtdn788cdM66pUqRLr16/np59+4ttvv2X0yQJJQbi22/CuUIezZ8/SsGFDJk+ezCeffEJ8fDwrVqzg6tWrNGzYkMuXL6epd8eZYOKtduwaLFY7O84E3/XrzoqsHDe9MomB16RVx+n63Y4H+i6ouHv36vPrLHa7nT59+jB79mxGjhyZYRButVrp1asX48aN44033uCPP/7IchB+8uRJNm3aRM+ePdMNwrXWDB06lEKFCtG/f/+7ej1CiMx9+umnhIWFMWHChHS3FylShMaNGzN79mxstszvRifvF/g3fRPvKk2S+gXFP/ibtm3bpugXjBkzhkGDBtG0adOki3+paRwBuKeLCQ8XIzHxckdciJy0ceNGWrZsScmSJVm7dm2GQfjq1atp0KABnp6ebN26NctBOMD3339P0aJFadiwYbrb582bx6FDhxg1atQjH4TfigTi97kRI0ZgNBr5+OOP091etmxZ6tevz6xZs7Db7ZnWlTt3booVK8Y333xD7IXDgKMTHbR4HAb3XEyePJlmzZpRrVo12rRpg7+/P7Vr12blypVcuXIl3WC8VqncuJgMGBWYTQZqlcr6le+7GRKcleOmV+ZBD7yEc93N5/des9vt9O3blx9++IHhw4czYsSIdMtFR0fTtm1b5syZw8iRI/nqq68wGo1ZPs6sWbMwGo307Nkz3e3Lli1j27ZtjBgxAnd39zt5KUKILKpatSqdOnViypQpGU5b6927NxcuXGDlypW3rC+xXxC5f0WafkHevHnT9AsKFizIokWLaNy4Ma+99hpz5sxJUZ/RoKhUyIeQqHhiLTYqFU5/OKwQwvk2btxIixYtKF68OOvWrSNfvnzplps3bx4tW7akdOnSbNu2jXLlymX5GP/99x+rV6+mV69e6fYlLBYLI0aMoEqVKplOZxMJMpo8fj/8POqJWRK99957WimV4frec+fO1YAOCAjIUn0hISG6YM+puviHS3XuFgO11xNNdNGBfyRtX716tS5Xrpw+fvy41lrrGzdu6C1btmgPTy+dt0gJHbDjcIr69pwN0dPXndR7zoZkeMzUZfacDdHlhwXokoOX6vLDAlI8n7qujJ4b8udBPfTPg3d03FKpjiseXVn5/N5vrFarfvXVVzWghw8fnmF205CQEP3MM89opdQt1/9OT0xMjM6dO7du165dhu14/PHHdZkyZXR8fPxt1/+wQpK1SZ8gG508eVKbTCbdv3//dLfHxcXp/Pnz65YtW2a5ziJvz0vRL+jdu7cODw9P2p66X3Dq7AVd5sk6GqX0B599kbYNFpu2WG1pntfakcTpYki0PnQxTF8Ny3g5NiFE1q1bt057eHjoihUr6itXrmRYbvr06VoppevWrZulVZdS++CDD7TRaNQXLlxId/uMGTM0oJctW3bbdT+sMusT3PMTa2Y/ctJ1CAoK0rly5dKtW7dOd3viSbdFixa3VW/xD5dq3/o9tF+j15OWKtm4caMuXbq0nj9/vtZa6+vXr+tXXnlFvzPsU12sx0StXNy12a+gXrrtYJaPk17QPX3dSV1y8FJd/MOlutTgpUmBUOpyGT1X9qMAXeLDpbrsR7cfTD+IgZfImnvxu83pY1qtVt29e3cN6I8//jjDcpcuXdKVK1fWLi4u+o8//siwXGZmz56tAb127dp0t//0008a0PPmzbuj+h9WEohLnyC79evXT5tMJn3q1Kl0tw8fPlwrpTLcnpnPP/9cT5s2LdN+wdONXtSlWvTVeSrU0ICe8OVX2maza5st8yWPtNY6KCJWrzt2Te86E6w3/HtN34iRi3ji4WGx2nRsvDXDC1HZYc2aNdrd3V0/9thj+urVq+mWsdvtesSIERrQL774YqZrhGckOjpa+/v7Z3hxPjIyUhcoUEDXrVv3lsufPUoy6xPI0PQHQO7cufnwww9ZsmQJW7ZsSbPdxcWFfv36ERAQwIkTJ7JU59iAYxT3daWkKYy3ny2IwWBg5cqVdO3aleHDh9OhQwcAfHx8qFGjBt9OnUR8XBz5O47GGnWDbm2b8+Wirbzy/U5+23k+02MlHw4enzAcvFap3JgMCoVjKFtGw8bTe+7PfReJt9qT5qL9uS/zrO6pVSvuR/+GZR64pFwic8nn/7/8zfZbfi6T73enUyQSjzlh5XHaz9xGm+lp/z6dyWq18sorr/DTTz8xatQoRo4cmW65kydP8swzz3D27FkCAgJ46aWXbvtYWmu+/PJLKlWqlO48sLi4OIYPH85TTz11R/ULIe7ciBEjMJvNGeaF6NevH0ajkenTp9+yrv3nQqn88UrKDA3gk8UHOXbsGMHBwZn2C/xKPMaFDfMo99xL5K5Qk/ff7k/nd0bSYOIGun23g8Dw2AyPZ7HZMRjA3cUICmy2+zdpsBC3w27XRMXE8fI326kzZg1zd57L9mOuWLGCVq1aUaZMGdavX0/+/PnTlLHZbPTv359Ro0bx6quvsmDBgjuaSvbrr78SEhLCW2+9le72L774gqtXrzJ27Nh0c8qItCQQf0C8/fbbFCxYkA8//BDHxZWU+vXrh4uLC1OmTAGgxOBlST+pjQ04xtebznAuLI7rFTvw9Zzf6NatG6NHj2by5Mn06NEjKThZsvkflFI83aApJi8/XAtXIH+nzwi/cYNBPdqybucBhi46lGHQs/dcKPsvhGFPaLJdw4bjgRy/GoFNOxK7JJ6D05uvm95zqV+9M0/hD8NSVo+q5BdtrHbNiMWHb/l7zCx5X1Y+CzvOBBNruZmbYf/FG7SZviVbPkPx8fF06tSJuXPnMnbsWIYPH55uuX379vHMM88QGRnJ+vXref755+/oeBs2bGD//v28/fbb6Z5QZ8yYwblz5xg3bhwGg5xKhMhJBQsW5N1332XevHns27cvzfZChQrRsWNHvv/+e27cuJGiT9Dtm5QXDLt+v5PIOCtWu2b29gu8/v5wFi5cmKJf0LJdJ85cj+RGdDynTp2mQgFv8lR+hhA88Ww5GPcyNflj6igOr/iV/RfCmL4+4+XT/D1d8XQ1ERIVj5+HCz7uN5M52e063T6OEA+KRlM2s//iDQKjLAxZdJg9Z68TEWMh1mJ1+rH+/vtvXnzxRSpUqJDhnPC4uDg6d+7MzJkz+fDDD/n+++8xmUy3fSy73c4XX3xB1apVqVevXprt169fZ/z48bRp04Y6derc0et5FMnyZQ+QWbNm0bdvX/7880/atm2bZvtrr73G3Llzyd37O4weKROknB3bMunfDSas52xwdNLjIp6wsG91rFYrz0zbn/S8NeQCcad2UsXXygm/GhjylkraFn/tDNd+H4YyGMnX6TMKlyjD7mGNUxwzMciJs9jTBMtKQfKPXo0SftQvnw8/DxdCo+NTLC2WermxvedCefmbbVjtYDLA76/XuaO72+nV+6AvZfUoc3wutmNNuOpjAAY1LU//hmUy3Oer9aeYtOo4dg1GBZ1qFKOQrzt+Hi6MWnrklp+FvedCaT9zW5rnDQqnfoZiY2Pp0KEDy5Yt44svvshwre5169YlJVpctWrVbSVgSa1ly5bs3r2bc+fOpblyHhYWRunSpalWrRqrVq2642M8rGT5suwhfYKUbty4QZkyZahSpQqrV69Oc8Fs3759VKtWjedeeZvThVKen5P3CUoPWUbym9LfdHuKOsW9+GHtISasPonJOw/eLgZ+6VOL48ePs2fjSv49eYbnXuzMlAOOrOjaZiVoyUSij28hV91uvP7We4zrUCXDttvtmnibHVeTIand18NjORsSjYvRQNn8Xni43H6wIMS9VnLwshR93m41izKkRUWsNo2nqwmT0TkXrn///Xe6devGk08+ycqVK/HzS9vXiIiIoG3btqxdu5aJEycyaNCgOz5eQEAALVu25KeffuKVV15Js/3tt99m+vTpHD58mIoVK97xcR5GsnzZQ+LVV1+lYsWKDB48GIvFkmb7oEGDiImJIWLf0kzraVapQIrHraqVIn/+/DwzbT/a5qjXEnaVqBM7iAu7hqVMg6QgXGtN5JH1xF3+l/xdxgJw7bfBBJ39N81xEu9QpnepJ/X1n11nQ5m06jijlh5JERh/lXBVPfVQcoPBgEr4/51IfSf0t53nmbLmBHEWyaj+oKpW3I9RL1bGZFAYABfzrbOgJx9xoQyKebvOM3HlcUYsPpxpdv3kn82qRdJmBXbmZygyMpKWLVsSEBBAp3dGUbdtj3TLLViwgObNm1OsWDG2bt16V0H44cOHCQgIYMCAAekOXxs7diyhoaGMHz/+jo8hhLg7uXLlYvjw4axduzbdDOlPPfUUzz33HBsX/Yy2puwzJB8t92LVQkn/9nY10ahiPjw9vfhiRygm7zxom4WIeDsL1+1k46oAzp49R9P2XXi2lmO5I601Uf9uxrVYZTwrNeTG5l+4sPK7TO9sGwwKN7MxKQi32uycDYnC29WE1pqLodEZ7ivE/Sb5nN/i/inPmW2fLJzwOVdZHsFps2vs9oxLT/t6Fp06d8GtcAWefH0iXj5p+yGBgYE0bNiQDRs28OOPP95VEA4wfvx4ihQpkm4m9NOnTzNz5kx69+4tQfhtclogrpQ6q5Q6pJTar5RKc8laOUxVSp1SSh1USj3lrGM/KkwmE+PGjePEiRPMmjUrzfbHHnuMF154gYi9S7HHx6TYVmLwsqT53INbVKRfvVKUyO1Bv3qlGNyiYtIw2qCAKVz/ayyxZ//BGh5EricaYfUvkVSPtsZhcHEn8tAaov/dQv6u41BmVy7/NpStW7emOGbyIMdkcNwlzIxdQ7zl1ut97zgTjCUhwLfeYbCTYt66xc6IxYfZeioIjeOP4kFbyko4dKlZjN9fr82gpuWzdDe6WnE/fu1di5drFENrbk6XsGsMSqW7rFnqz+bw1pWoVzYPLkZFQR9XXIzp75eRzIbAh4aG0qRJEzZu3Ej+1u+y0/WpNH8PADNnzqRjx45Ur16dzZs3U7hw4ay9YRkYO3Ysnp6e6a4Lfv78eaZMmUK3bt2oWrXqXR1HPLykT5Az+vXrR+nSpfnggw/SXTf8ww8/xBYZQuSR9Wm2lRi8jHrj1vPmc2WZ/3otJnZ4gv0jGmE0GrkcdrMPkdgv+P6PZQReuUTLti+Tq0g5zodE83h+j6R+QdThdZh8C+BVtRlzv5vOm2++SZzFSpw147XEE4MXAIXCardjs2uMMt1FPEDsCYGz3a5Z8249GlXIQ5m8HnzTtSpl83kTb7VjNipMt+oIAxExFgLDY7keGYfFlnZZ4mnTpvHWG33xLlWVKr3HcuBaPFtPBaUo899///HMM89w9OhRFi9eTPfu3e/q9W3fvp2NGzfy7rvv4uLikmb7kCFDMJvNGeatERlz9jddQ6111QxuvzcHyib89AVmOvnYj4RWrVpRv359Ro4cSXh4eJrtQ4YMwR4bQeT+5Wm2bT4ZlDSfe3CLimx4vyGDWziuXCUGs/6N38Aafo3wHQvweboNroXKcynsZtIVg9kNj7K1KNh9MrHn9mOPvkGBruNRHr40btyY6T8tSAoqqhX3Y0SrStQpk4dRLz5O37ql0rQpNTuOL6HkgXKcxc6UNSeSgg8/D5ekq4r2hMe3O7c7+UUCg0Fh1xq7dlwseKZsHhmW/gC73WR81Yr7UdjXPcXVZ6NBMerFyrzbJG1An14CQX9PFyw2zdXwOFCKl2sU49fetQDH8Pffdp5P9/OZ2QWnq1ev0qBBA/bu3curw7/E/bGGae60a60ZOXIk//vf/2jZsiWrV69Od3ja7Thz5gzz5s3j9ddfx9/fP832xORQn3766V0dRzwSpE+QzVxcXBgzZgyHDh3ixx9/TLO9cePGPPnkU9j++RNtTxsQnw+NpufUjTxdMjcdqhdNWhc4Kv7mfNbEfoE+8Df/e3MAT9esyRNFfMjv48qM7k+z4p16lH8ysV9wAM9Kz+NTox0zZszg5c7dOBsYTkhUfJpja62xWO1YbRqtoWx+LzTg426miN/tJ5IS4l5IupCUMLrDaDRSIo8XMfE2Vh27hqvZiI+bCQ9XE1FxVq7ciCEsKg6r1ZbmrrfNromMs+LuYsSgICrOytaTVykzZBl1P1/N6NGjeeutt6jybGPKdBsFJjcAXE031/M+cOAAderUITg4mDVr1tCyZUvu1pgxY/D396dPnz5ptu3YsYP58+fz/vvvU7Bgwbs+1qMmJy85vgj8lJDJfQfgq5SS39htUkoxceJErl+/ztixY9Nsr1WrFk/XqUf47kXYLXHp1jF00aGkwDUxQPDzcMHNbMDo5kX+TmNAKep5XsFmB61TXpFTgPXGNdCgXNwx+eQlb+ex5CtakgGvduaTL79LGu49aukRtp4KYuTfh5m15b809aR3bfC7Lf8REWNJSvCmgS0ng5ICldDo+KS76wYFhy/fyDCYyUjindB3m5Rn1IuVk4JyF5OBgY3KSRD+iEl+cQeg97Ml6VKzWLoBffKLOEajgaUHLvPX/stoSMrkX9jX0Yns+t0OJq48ztBFh9J8PveeC2XKmhNpgvoSg5dRuN/3FK34FKdOnWLp0qX069klTdJCm83GG2+8wSeffELPnj1ZtGgRHh4ed/1ejB07FpPJlO4wtr179/Lzzz8zcOBAihUrdtfHEo806RM4yUsvvUStWrUYNmwYUVFRKbYppej0+lvcuHaR6H83p7v/eQtEx1u5EhZDYEQs54OjMRjgqw5lATC6efHa6Fm4u5rYtnkjGo3ZYMDTxUwudxdKF/Bn4f/qJfULDK7u+DZ4lXeGjGDxwt8Z8Fo3rgXfSBN02O2JfQvH8z5uZh4v7EvZ/N4pAgsh7mdKKZRSSaM7vt98hu+2nOXSjTgW7rvCJ0uOYDAYiLPaCIyIw2rTXAmPJdpiSzEEPSQyjoOXwjgXHEV0vBWrTbPnbBBdv9+LxW7nwJ9fMWLECLp3786GFX/zbPlCmE1GXqpelFqlHBfNN2zYQL169TCZTGzevNkpSdP279/P0qVLGThwIF5eXim2aa0ZNGgQBQoU4L333rvrYz2KnJkJQwOrlFIa+EZr/W2q7YWBC8keX0x47ooT2/BIqF69Ol27duWLL76gX79+aTrEtdv3Yfe2V4g8uAqfaq3TrSN5ErXExFIjWlVi6KJDGFzcKdjrKzq3LcXe5ZeJDgkCZcAeE44tMhi7JY7oE9txLfoYRk9HkGL09KXmm19ydfzbXF88Hh0bwfIS/jeDDJtOEegooJCvG4ERcY4vomQb7Vpz5Eo4ipsZ0TU3AxU/DxcMyrHVxeSYK546mMlKIF2tuF9SufIFvFMkbhP3RuoEejklNDo+6fNmALyTZfFNLXGkx++7z3Pk8g2OXY1IU8bPw4Uu324nLlkGpNR3s1P/DZpNBiasPE789bME/jECbY3Hp8MnNG7sSLL0a+9aSe/NY/nc6NChA3/99ReDBw9mzJgxTlkq5Pz588yZM4c+ffpQqFChFNsST7h58uRhyJAhd30s8dCTPkEOUUoxadIknnnmGSZMmJBmeGiVOk0oVbYCIbt+R1esh1Jp78H8eyUCjebgxRvk9nTBxWSg/4KTSduPXI/i0KFD/Hf+Im4mE+fOnyW3lxuHDp3i6pUrxMVEU/D4n0QUfQw3bz8mvfQEVYs3wM/Xj48Hv0uPl9uwavkyfH19qTVmNRExVv7sX5OSuXMl5IzRJL80v+zAJQb+vp9c7mb2DG+SPW+cEE5iMKiEjP+K7ammS+4/H4LWmrfn/sOKI9cwKJjTvToFfNwSEhdr4q2a41cj8HIzEaM1wVFxlPD3ot/P/6BtVoKXf0nUkfV4V3uB2bNnYzAYmNr5yRTHWbBgAV27dsUzb2F06xG0/Pks297yS3Muv12jR4/Gx8eHAQMGpNm2YMECtm3bxqxZs9IE6SJrnHlH/Fmt9VM4hpv1V0qlzW2fBUqpvkqpPUqpPdevX3di8x4uY8aMAWDo0KFptr3SrgXuRSsRvnMB2pp2OBiQIolaYoAwdNGhpO0GsytvLb1EyetbufxtH0w75+BxfBmRB5YTe24/7iWfwuuJJknZ2V1MBl6sUY5iXT/FvXR1glbO4OTy2Umn1cRjGXAEHRq4FBaLxaYpldcrxZ1xk9FA88oFcTUbkj6giYFKYjbrxDm8I1pVot1TRdLcLbxdsrb4vZfZMO27rTezaQt7z4VyKSwGc2I+A5OBS2Ex6Q4jTxxFMmrpEQ5evIE17fQtAD75+3CKIBxufoZrlcrNn/su3gzCgWfKOKZDxF48wrVfPwQgf5exuBaumLTkUPuZ2+jfsAwlvTWNGzfmr78W49fodebqZyk5JCDdpQpv1+effw445pWm9vfff7Nx40Y++eQTcuVKmxhGiFSkT5CD6tSpw0svvcSECRO4dOlSim2l83vxUp+3Cbt2gUFlQkh9r/mrV8uj0bibTcRZ7Hi6mFIsKQZwLkxjcnGlbJnSLJ//I01qP8nIoR8wbcpkvp/1LWvXrqNP57bMHT+YP99pRuWifpjNiqLPvkDTAWPYu2cXT9d+liJv/szV8HiiLHaaTtlOfFwMJqPCnCx7+r9Xwug/dz8WOwRFWZzy3SZEVlmsNuIsNizp5DZIntMg+XOJozvsWtO5RsqbYx2fLs6KQ5dYceRaQhno/uMe3BJGfSgFV8KjuXIjlqDIeIwGhZerGW93MwU94fqfnxJ1ZD256nbD7/k+lBq6PM3yxP6N+/HSSx1ReUvj2X4MJp+8ANSZ+s9dvRcHDx7kzz//5O2338bX1zfFtri4OAYPHszjjz/Oq6++elfHeZRly/JlSqmRQKTWemKy574BNmit5yY8Pg400FpnePVblirJ3EcffcSYMWPYuXMnNWrUSLHt67mLeaNLG94bOZYJH3+Y4g924Rt1HMt0WezYuXlHPPl6yEmsFgIXf47Rw5cirQdiA+Is1hRX1KsUycWI1pUAGLv8GLv/CyIoYCpRh9fiVbU5/o37oQyOLxyzAYrn9uTU9ai0x0pmTNvH2XA8kDPXI/H3dCSGiLPaye/jxppj11IsN6WBoIg48nq70u6pIrcVTCfegU29bNq9ujP7KEu9lNi7TTJfeiwrbrUkXfLtJoOiQfl8bDgeiNXuGG3Rs3YJjlwJp1JBH+ZsP0u81Y5BKWz2lCM8kks+kiO5rjWL0e6pIgB0/nY78QmBuovJwNw+tbi4fxNtO3TE6JOX/B1HYcqVP00d1huB+GycwLETJ8nT6j08KzybpkzyZYlux7lz5yhbtiyvvfYaM2emnK4bHx9PpUqVMJvNHDx48I7WIH2UyPJlKUmfIGf8999/VKhQgc6dOzNnzpwU28Ki46hd/UkUcOjQIYxGI5tPn6Nu6eJYbXaOX40gxmLlv+tReLia+GbdQQ5eTRmI/DOsEYsPXOTjRYe4vvhzCuTLy+F1izAoA7HxFjzcXNB2jVaayDgbE1ceZ9+5EBRQIPIkcz9/G7urD/lf+gRzbsd3YcdqhRn/UtWkY2itaThhHWdDYlMcO6vfa44RdhqjUhiykBgrI46AyxEkOWO0kXgw2O0ai82edIfbbDQkfY7OXI9g4/HrmIyKxo8VoEAuxxQ0rTV/7rvIBwsOYtPQsFweej5Tkldn78aOY+jxY4V9OHgpZV6nE582w8VkJNZi42JINJHxVs5cjyK/tyvVSvgTeSOUVq1as3PXLvybvIF31eZp2utmsGPZNZcza+fiXrYWeVq/j8HsmqLMnfYJANq3b8+aNWs4e/ZsmvwzEyZM4IMPPmDVqlVJI/dE+jLrEzglEFdKeQIGrXVEwr9XA6O01iuSlWkJvAm0AGoCU7XWNdKtMIGcdDMXERFB2bJlKVWqFFu3bk1xstBa07BhQ44fP87p06fTzB1NLwBNbz1kAG23EfjHx5jzFsf/+ZuJGhRgNCoqFfShZB5Plhy4nLQWqdaasI0/Er5zQcKXw3sYzI6kEmajwmLL/HNnNDiCnYy22e0ak9ExJyfxrmRiQJPVwDl5EGbXjtfjanYM0U9cQ9pkNNChWhHa32aAL25f4u/DYrVjvoM1uNP7TO84E5xpcP/V+lNMXHn85p3psnnYeioo6fOQejrF3XxbupkNScPLE9ukgMaP5Sdi/3L++HIkXoXL4dNmeNJIk+Tir50hcMFItCWOvO2G4Vbs8XSPc6cn3T59+vDTTz9x6tQpihYtmmLbF198wbvvvktAQADNm6ftDIiUHvVAXPoE2S/eaifGYsPNbEgxn/rDDz9k/Pjx7Nmzh2rVqjkuGmqNyWhg/vz5dOzYkV9++YWuXbumqM9qsxNrtWMEx/+NiidGrkpRZuuHDXhm3AbgZr+gQc2qjJ04iSJ+Hni5mDEaFUaD4vClcL7bfIrrEfFci4ijdF5P3G+cY/rg3qA1+dqPwLVwBeb1rsF/wTF8/PdhbHbNgIalsGvN1PUpc8pk5XvNZtdEx1lBOb5bPVxMdxSM2+0aq10nrbPqYpb56o8Ku10TZ7GCcvQz3RM+QzabnZ+2n8XH1Uyc3Y6LUdGh+s073+WHBRBnvdlD8HGB8GQDUl2A1ONTj3zSBE9XM3EWG+dDonAzm4iKt+DrZuba5Qu81PYFTp85S+7W7+FRLu1cb22zEBzwJVFHN+D1ZAv8G72edNMrubNjW/L3/kusPHKNRhXz0TbhhsCt7Nu3j2rVqvHxxx+nme4SGBhI2bJlqVu3LkuXZr5kssiZdcTzA1uUUgeAXcAyrfUKpVQ/pVS/hDIBwBngFDAL+J+Tjv1ISG94rbe3N5999hnbt2/n999/T1Fu3/kwRo8ezdWrV/nqq6/S1FetuB+1SuVOCljAcccuPcpgJN/LozB6+mIJuwo4gumnS/hhtWkOXLzBX/tvBuHguILs16Anfo1eJ+bkTgLnDcMWfQPglkE4kGEQDiQltnCcLG8+b7HaWbjvYpazpyfPfg0356EvP3zl5tJmVjtzd5536lBpkb7kCfTSu3Od0e9177lQPlp0iM7fbmfSKkditIkrHcPb/TxcMBkNCReN0k5bSJ2Bv1JBH1xMhnS/GG8nCFeAr3vKu8bxVjujlhxh4/FAR3IXHHdb5n8zkXlfjMC9VDV8O3yabhAe898+rv72ISgD+buOyzAIT/TbzvNJyxVmxcmTJ5k9ezb9+vVLE4QHBQXxySef0KxZMwnCRVZJnyAbxVltHL18g+NXIzh86Qaxlpt3rocOHUq+fPkYOHAgUXEWTgVGciowktCoeNq3b0+VKlUYOXIkFouF69evJw1xDQ0Jxs1kwGw04O1mwtNsTJNEqOkXG5P+ndgvOBqqaP3ZIkYsOkhsQhZorTX5vEy4GA24m8DdpCiR24MPujVn6Iw/MLh6cm3eR/he+4cqxfwYtugQ8TaNTcOUdWfoVKM4/u43v4Unt6+cNPzXbrdnuD65XWuUArPRgF07Ht8Ju3YE4QaDgoQkXOL+EmexERlrIc6S8dJ4d8Kxig5JS+QmfoYMBoXRALF2G3FWa5oLPKn7rOGpom4LMKBhKccNH5Ni47t1+GnbOcYFHOWf86HsOxvMf9fDMKDYuG03jRrW59q1QOq8+UW6Qbg9NpJrf3xM1NEN+Nbrjn/jN9INwvN7wpS1hxn4+34CDl1h0PwDLNx7MUvvxfDhw/Hz8+Odd95Jd1t0dDQTJ05MZ09xO7JlaLqzyNVvh8yG19psNp5++mmCgoKYu2o7r/16MEW5j17vzI6duxj201oaPlEixV1vN7MhaUguShFx7ig3di5Ax0fjkq80Ho/Vx7VA+kODVcL+MRY71oggwncswK9hL5Qp7fqCUce3ErRkIiafvOR7aSRmv7tLHJHIkHCLMvWA+sSh9snfp/SGmv+28zwjFh9OGmZsAFyS3RFPnMMLzhsqLW5fZp//xG3Jf1eJjAperlGMBXsuYLFpzEbF3L61UwT4yYfDGxQMalIePw8XRiw+7Lgjcofyermwe1jjNJ+x5LTNQvDyqUQdWY/XE03I26w/vp5uhERbUpSLPLia4JXTMecuSr6XRmLyzpPhcW1RoUQd2UDshUPY42PxLF+HqaMG07VW8Uzb26VLFxYvXszp06cpUKBAim3/+9//+Pbbbzl48CCPPfbYbb0Pj6pH/Y54dpE+gUNoVDwnAyOSRv+UzutFbq+bw1G//fZbXn/9dabP+pHmL7bDZFRExlopl8+bgIClvPjii8yYOZNxZ1NedDv5aXMMBoXVasdis/PYR38TtuUX4q+ewuiVG/fS1fEoVyfN0NdE5fO581nbxymT14uJYz+jRv3ncS9YjoK+bpTw98LVbMRkMhIYGEjLlq345599jJswiRmBZVOcx3cNboCri4lcHjePExNvxW7XuJodyzoZkq0zPmvzaU5cjaRrrWKUyetYAm3QH/tZdyyQfD6urHy7Lt4e6bc5PVprxw2DxMBeMrjfV6w2O9HxNozKse68p6sJk9E59xUdidMcQ9O11ihu/v7PBUWy7UwQZoOBZ8vlpYDPzSX2vt5wirErjgNQNrcbnm4m9l+KTNr+dDE/fu9XG6tdY7XZ+XTZEXaecSRxOxMck1QuT8hhjv46Gn9/f77+ZQFPVK5EvQkbkrYXyuXK+QsXCJw/EkvIZXI3fwuvys8B0P7xPOBipE6p/BTwNtP1h71ou42Y//YRc2I71tDLmHwL0Kbve/z8ZtNM34etW7fy7LPP8vnnnzN48OAU2w4cOMBTTz3Fm2++yZdffnlH7/OjJrM+gUz0ewCkt25xYjBhNBqZMmUK9evXZ8zY8cQXbOK4k2ux8+HCg4RXbMuNlSv5bOx4ptftlqLexDnhFpvGFhvBtT9GYHBxw+iTl/B9SwjfvQjXopXxrfcKbkUqpdhXAzEJ+8eeO0DEvqXEXztD3nYfpbmj51n+GYye/lz/czRXf36PvO2G41ak4m29B48V9ObolZTZqe3aMUydVAFT4p3sxOzUf+67yPw9F5Lm/Sau75yY9M1oUPR+tiTe7uYUgfqwRYduBuLp3E0VOSOzz3/itsTfU/LM5+aEjPrWhCDYYtMs3HcxRSCeuBRZ4nD4xOHsd3onJVFEnDUpqZvNrhMyo97cbouN5Pqiz4g7fwjfuq/gU7sjJrORWT2e5u25+7gYFovWmhubf+HG9t9xK/EkedsMweCacoqJi1HxUvWi/Lh2Pze2zSPy4CqwWTH5F0GZzISs/ppvShWha62PMmzrgQMHmDt3LkOGDEkThB88eJBvvvmG/v37SxAuxH3CzWxEoQiPsSQ9Tu61115jxowZfP7JMJ55vikYzQRFxmO12an+7HPUqFmLzz79FMPLU1ME1devXydfPkeSJ5QmdMMPRO5fgWvhCsRdOET0sY2EenxHvtrtcK3SClIF5McDY+jwzS5sMREE//I9evJkvv1+DnXatEEn3GHWWuOfOw8BK1byas8evD/oHeq3685/pdujDEZ61CqCwWjE1WQiLDqeazdiibc6huB7u5mJjLOQx8st6ZiDFx5g3m7HHb4/911k/fsNWH7oCquOBgJw+UYcLb7czOYhjYCbybYSl5xKj1IKsxG0Vsj08PtPUvJfg0LpO/8F2RLmgif/HCilMCjHNqVUigC/iJ8HHZ4qluIzobUmzmKnR52S9KhVHKvdzrmQKGw2zZiAf9lzLox65XIzu1ctrDY7KmG1n3PB0eRyN3Pm+s1+bcQ/AZxb/TXlKlZm7De/ULBoQaLibewf3oiNJwIpl9+TlZt28fHED7FFRZOv4ye4F69ys335fAmOiifOYqfLd7uIOrqBG1vnYQ27gsHVE3PeEkT9u5l9syMgk0Bca53UH0idKV1rzcCBA/Hz80szXF3cGQnEHwDpBQvJ1atXj44dO7L491kU7vMUBo882IFTgZHgWgSP8s9yY/dfeD3ZEqNXyjm3RuUIZuOvHkfHR5OnwwjcilbGHhtJ5KE1hO/6k2u/fohH+Wfwe653UibG5LwqP48yuRK0dBJXf3mPfB1GYvYvnKKMW5GKFOg2gcAFI7k2byh5WgzE87H6WXr9qYOY5DIawq6BiBhLmrulyZePSgzgtNZ4u5tT3O0+fPlGiiv09cvllTni90hmn//k24wJ8/krF8qVYsrF77vPY9eOz8Qfey7QPmF+VOIIieTLgiX+jk1Gx2iRW8lo3ni85eYUB8dn7OY2S+gVAhd8gjXsKrlbDcKrUkPAceIfteQIFQvl4lJwBNcDphB9dCMFn26GuX4/lDGdr2ulmLvrPDe2/ErkobV4Pd4In6dfxJy7KFrbufhVd9TFA0nF0xsZMmTIEPz8/Hj//fdTVJ14wvX19ZUTrhD3EXcXI48V8iEy1oqHqxFP15TfDUajkS+//JIGDRrw66zpvNDzLYr4ehBvsxEcbeXD4Z/QvlVTfPctIVfNDkn7FSyQz5GgyqSwW43EnNmLR/lnyPvih2htJ/b8IcJ3LuTq2tkUObGWuGrdcC9XO01Aa3T3Jk+XCQT9OZqe3Tpz8fPxvD3gTSDhu1CDt7cPv/+xgPffe5evZ86gVetQZs/5CS8vL8wmI1FxFq6FxxIRZ8GuISrehr+nGxZbshVf7HbWH3NkolaATcPmk0HsPBOSoj2BkTfHCScG4YmjQTMLxiUIvz+ZDAqzQWGx2TEbHDkJbteVsGguhsbgajJQroBPiotZZpMRUzqfD6PRgLLfHKoOYLHZsdrtaG0HDHi6mfFwMXPqeiStqhamS60SeLuZiIqz4mY2YrNr4uOtPFchL7/tPOfoI9hthG6YTcTuv3Av/TS//vUn2uTO6z/vJSrOhrebmd961+DAtvUM690VXz8/8rw0EptvypFuFQr4cOjSDUJjLFiunyN42Re45C9NnhcH41G2Fq2qFiV486/8/s1kIiMj8fLy4tjlcBbtv4S72UinGkUpmMudgIAANm/ezIwZM/D09ExxjIULF7JhwwZmzJiRJnmbuDMSiD8AEufOZpbFe8KECSxZsoTSZ/4if9vBbD4ZlLTNt94rRJ/YRti2eeRu8kaK/RKzjp/nFL8CBhfHHTeDmxc+T7fBq2ozwnctInzHAmL+24dvve54P9UyzTqknhWexeiVO9ld749wK1o5RRmzf2EKvDKJ639+RtCSCVhCL5OrTqdbZiTVGv69lnat5lvt8+3mM6lvlqMMikthMVQulAsXkyPYUkrh55FySH3qFuXzTn9Ym2RXzz7J39uMPv/J/zYSh2mWL+CdYvqBLVk8bbVpvtl4mk0nr6cY6p54EWbvuVD+3HcRmy1lEG5IyJyb+sJPhpnTFTSvXJCdZ4KTsqMDxF44zPVFY0Br8nf6NMXfiE3DgYs3sEWfJ2jRZ8RePEqup1pgiwomZudCzPlK4FGmZlL5uskSy/k+2w2fmh0w+xVM2l6vXD5W+/lSyNuY9No6fbs9RY6G2PMHubZ8OePHj09zUl24cCHr169nxowZ+Pv7Z/BKhRD3gqerKU0Anuir9SfZf96DNu068NWUSbzYsQvKqyA2u8bDxUjdunVp2qw5qzfMx+uJphjdvXE3OfKuJCamslhteBhsaBfH8FulDLgXr4J78SoMrWJh8uhhXPxrDO5lauLf5I00U2ZMnn7k6zyGysd/ZOiH77Fo415eG/QxnWuWwMvdBa3tuLia2V+kLX6NbCxdOot69euzYtlSihUr6hjxBhhRKKWJs9uJtVrxcDGlCLxqlc7N4gNXk0ZC1SmdmyqFc7Hu+M2l7lpXcUyH01on3elUqYa3iweHUgp3VxNuCRdVbldMvIXzwVHk8nAhOs7G5bAYSuVNuQZ28nrtdrvjbnbCHfKU2zQx8VbMJgMWqxWT0uT3ceNiSBRGpfH3MBMeZyXeYiM63obVZsfVqKhZIg87TgeTy2xnxfRhxJzahXe11vg915v3Fx2nyeOFCI+2YDQaCY2KY9DIcSz9dixlypalQIGC+FtPsvXU1aQ+wYzOTxAUGUdkrIW83i6806kxkywTcC1cAaUUdcv48VXXp5hyfRO/41gJxWqzM3zxYU5ciyA81sqXa0+i7TZifx9EmTJl6N27d8r3LSaG9957jyeeeII+ffognEMC8QdEteJ+GQZ6baZv4fDlcIo16MzqZT/wdbde7DAakzrcZv/C5K/ZisBdy/Cp/kLS3erGj+Vn3q7z2DTEXXEkvLCGXcUlf6mkug1mN3yf6YxnpYaErJpB6JpviD6+ldwtBmL2TTmM1a1IRQq8MonABZ9wbd4wcjcbgNfjz6coY3T3If/LnxK8cho3tvyKJfgiuZu/leGcsyR3MFI4vZvlVptm7s7zuJodS1N9t+U/bHbNqKVHUgRw7Z4qwvy9F5PuwrZLJ8tkVpbGutdB+t224V69hvTe24zm51cr7sfqI1eT5nqbDIpRL1amfAFvhi8+nOajcy089mYyPoudKWtOMLBROX7efpa/9l9Ov0Ea/D3NhETFk4VcgyTeSqlY0IcDFx1JCh1zvb/C5JuffB0+TjdXgiX4AoELRmGLDKJWm1e5fGATjTq+xsLNh4i79C8eZWqitR2lDBy6GOZYXgfw9MudYvnBs2NbYrFYcO93iS1XNGMDjnEmKCpFEK61ndD1szF65+XNN99M0Y7o6GgGDRrEE088Qd++fbPwgoUQ94Nnxq7lUphj6S9r/hbAMqZ8NoIp38zBbNL4uDnmWH8y6lNW1axOG3bx+chxGJUi3mLDZrdzLjiKWKudgoWKYFc3sKQ6Rlzexxj89ULe+Xg8N7b8yuXv++PfqC+elZ5LClKUgvoVizB44ByORLzB7oC5nD59msjxMxnQrAomk4H3/viHK+Fx+FRrjdm3AP/+PZ5atWqyePFiqj5VjVh3EzdiLbgYjRT1c8fX0w2zUSU7hmLSS1Uo6ufOqcAoetUtSck8joAq4K1n+Wr9KeqWyUOnmo47h3a7xmg0SvK1h8SdLitnUAqDMhBvsWHX2pEnKRWtNUHhcUTEW/B2M+Lt6gJaOy4gmRPnjEex/fQ1PFzMVCnqi11rwmLiMBkMlMvnidGgOHk9kqK+buw8G8z1hDXCi+Zy4/ClG0QHXWXr9A+IOX8Kv0av41OtNQD/hcTiYTagDAasFgtBa77l771LqVmzJkFBQXR6+WV+33iAwrEhaNeahMba+d/cgzffF2BYywpc+XlQmgsKXy/ZhsHsxruLTvDW8+X453xoij5N1OF1BJ8/xfQ//sBsNqd4T8aPH8+5c+fYsGGDLGHqRPJO3mN3G+g0nrSBkwlrcsdXbInbjmVM/2wov/y1jkmrT7L7bCgaMFd/CfatJnTjHPK1dcwXXXP0WlKQ4pK3BBhMxF3+F4/yaTM0mn0LkO+lT4g6tIaQtbO48sOb+D3fF68nGqf4Qzf7FaTAKxMJ+msMwQFfYAm5SIHneuDuYiYsYT6bMpnJ3eIdzP5FCNv0M9awy+RtOwyTd9o52JkNS0/NaHCUvVWOrcTs6EeuhGPXOulx8rnH1Yr7MbdP5qMQMpu7fKsgPSckTxTmar6z5cDu1WvI7L1N7bed5/l605mkx1a7ZsTiw7z8dNE0d7ANCl5+uhjHrzmS8dmBLSeD2HY6CFsmI9HtwPXI1IuPZMxm1wz765BjSLzdRtiGOYTvXoRb8arkaTMYo5tXmn1izu4n6K/PwWgif6cxnLp0mLfeep/hA19n26vjubB7BdaIIJTRjNEjF6HRlqS/PTeTAatNU7mQD3+96VhbfOD0P7FZLcTmKsHXm86QxyvlqI/oY5uJv3qS3C3fwd3dPcW2cePGcf78eX7++WeMRklUJEROcGQGvzn09U4CjcQgHMDkk48nWnZn4YJv+N8bb1C/fgMiYizYgbIVK9G5a3e+/forevftS9FiJfByMxFt0UTF2/FxN1Gh8hOsXvonLxT3ITTGRrxd819QNJPWnAQgV812eJSrRXDAFIKXfUH0yR20fH0I3RpUxdVs5LkK+dh/PoxSrV7Ht0BRDsz/gs/f7EjL1cspW6YU25INIXcv/TQFuk3AZe0E6tWrx+zZs+nQ4SVye7piMBjQdo3RmHY+r8lk4r1mjxFnsREVb8Vqcyw7Wr6AD0NaVCQ63kZIZBz+yZLZJR+aLh49ZqOBMvm9uBAaja+bmYK+7mnKBEXGse9CKGjHuvQ1S/rh7mLCYnP0I4KjYvl82WFCo+KwAZ1sxalQwBuTwUhsvIU4g4ECudxwMxnwcDWz7vh1XE2KTwMcCd1iLx4l7O/P0VYL+Tp8jHupaknHdjFCt1rFOXH+Gr+NHUTYyT28+fZAcvv7U6ZUSV7u3IWxW8cSfHYNvsVv9gmSpl0AU9eeYPSyfzECiwbU4YnCfrw1dz//Hd2PuUAZNpwMxmg8hdEAtoTE83ZLLGGbf8GlYHk6dOiQ4v04e/YsY8eOpWPHjtSvn7VppSJrZFzOPZQY6ExadfyOlsfaey40KQgHMJhd8a7/GocPH2bBz9/jajbeTDbm6YdPzfbEnNhO7IXDQKo1kk0uuBZ5jJgzGWekVUrh9URjCr32Fa6FyhGyYirXF32WtCxZIqObF/leGoVXlWaE75jPhfmfEnIjPE1duWp3JG+7j7AEXeDqT+8Qd/l4mmMmniuzcsq02R3llYLi/h5kNG3IkQXTQPPKBXExGTAq0p17n57ky2glzk9Ob//0AsmctPdcaFLmb43jzu/ttuFevobM3tvUft+ddokuq12z878QkidSVcCnbR6nfAFv6pW9metAQ6ZB+J2ya0dStsAFnxC+exHeT7UiX8dP0g3CI/4JIPCPERi981Cw+2RcC1fEbnBl9NBB1Ov3GScXTsAeG0nYxh8JWfMt9rjoFB3SsBgrVrtm/8UbtJm+BYDFSwMAhVvxJ5Lak/SarfGEbvwRc75SeCbMUU905swZxo0bR+fOnalXr94tX2fi8kclBi+7g3dJCJGaSljD+Hall9fi9f4DKVmyJAMGDCA6Js6xprhBse7fQLye6QLKQPf/vUubGdv4YtVxDNqOUWli4m3UfrY+MVGRPOsVRKsnCuCq0tb/WLFC5O/8Ob4NemH9by+bxvfBePUw9crlBRQVCvpQPLcnXlWbUunVMcSGXafus8+wbv0mPm5RIUVd/oVLsWr9RipVeZLOnTszZOhQomMt2Kx2VEJe1ogYC5GxlhTvT0RsPMsPX+Hv/ZdYcuASMfFWwqLjCY2Kx2RQXLoRQ6zFlpQsLjFxnHg0KaXw9XDh8cK+lMnnjTmdjOs3YiwYAH9PF2x2O1FxVmw2Rxb1yDgr/14JJzw2nmJ5vfA0mzgVGInWmkth0QRFxmGx27HbbZjNRsJjLPh7mPksIQiPOLCSa3OHYjO503zoLN7r0jjFsb3McOrESQLG9CbizH76DxtLn0HDcXf3pH///rz12XQu/f1Fpn2CG7GOv1Ub8MK0bZQYvIwVe44Tf/UUbgkJ3iLjbJRMNiQ/fNcibJHB+D3XK81FwHfffReDwZCl5crKDb3ZJ2g/fVNWfiWPNAnE76G7DXTSK+9erjaFK9Vk8thP2bj/ZIptPk+3weiVm9B13ycklkjJo2wtLEHnsQRdyPS4Jp+85Hv5U/wa9iLm9B6uzB5AzLkDKcooown/pv3xe64PMad2cfXXD7DeCEz3mAVemYgymrn624dEHlydlZeeocTEWOdCotEJS1L5e94cXqOAZ8vm4dfetehSs1im61anvkiy91wonWftYOLK43SetQOAEa0qUadMHka0cmSVz0qQnhNSZ/42GNRttyE7XkNm64En3w5k+LtJLb+PW7rPnwqMTBFgKwXng6Po+t0OVicbDZJd4q+f4+pP7xB77iD+Td/Ev3G/NOt8apuV4FUzCVk1A9fCFfCt3wNTrvwAeFd/Ae96Pdm/Zxfe1VqRr8PH+NbthsHFnehTOzM87v6LNyj+4VICD27AtUjFpFUMwmPiGdP2ceqWzUP4nsXYwgPxe+41zo1rnWL/gQMHYjKZmDBhwi1fY+rgW4JxIe7end6tjbfZebVWsaTHhXzd6F63HJMmTeLo0aN8/fUMus/ezROjVvPhwoP8E2zAu3obDm5azsnD+5i5+Sxfrj1Nmfw+5PZy4aUXWuDm5sauDSto9UQhrkWkHqQOuLqz6p2GBK76lm3btuPv70/rli0o9Vwnyg5ZwhOjVrPueBDvNS7Pt4Nf5b1pvxNhN9O0WVP2rf+bMW3K4+1ioGphb6Z1epypmy7R4O0veLJJeyZOGE+Tlq048N9FLDZNWHQ80RYrEXFWbsTcHKF0JSyOkOh48nm7ERwVx+XQmKRVVYxKJV3JV0phNBowGg3EWe3ExNvu6IJHaonBvXgwJGbMz2zESW5PF2w2zYXQaLTWuCYsYRZns2NUBkrk8cTLbORCcBQWm53qxXyJtdiJjI0nMDSWoMg4jAYTvu4mcnuZaVAuN3abxXG+XzENt2KPU6D7ZDzyFWP92ZSj7S4dP0CdZ57h6pUrNOrxFkXrtGLXmRCadeqJX/0e/Pr32tvqEyQKPboF0HiUdawctPO/EJa+WZdF/6vDi6VNhO9ciEe5Olz95YMU+61cuZJFixbx0UcfUbRoUQb+vo9aY9bwy/b/0hxj1eFLxCfrd+29eHv5nR5FMjT9HrpVNvSs7p/8Krir2UhMnVexH3uTkA2zydPy3aRtBrMbvvV7ELxssmPt4sop5297VHiW0HXfEXlkHX71e2R6bKUM+NRoh1vxKlz/ewKB84bhU/slfJ/tmhRsKKUSMjgX4frf47ny0zvkbTMkTRI3l7wlKNDjC4IWjyd4+ZfEXT2F//O9UUZzeofOssSgPDTqZufBbDIwsFE5qhX3S5oW4OfhknRRIzHg+3PfxaRs64kXSS6HxSS91/FWe4qkXzv/CwGtUyyRdqsEe9kp+WfDoBxzpm+3DVlJEpgVyd/nUUuPZDqnPivzwlNP53i9fmnWHw/EknC1OqPukF3DN5vOkKxfBjguzlQs6M2xKxFOC86jjm8lOGAKyuxK/s5jcCuSdukvW/QNri8eR9z5g7gVr0Lc5eNEHV6Hwc07aXk/76rNCd30U9LfQmKQbnBJO5QuOUvgGSxB5/FvfDM5o9UO41ccY2W/qix+509at27N33OHpthv6dKlLFmyhPHjx1O4cEIuiUkbOB0URek8nqwe1OCO3xMhROaUUhgMN5Om3S5PFyMvPFWEp0r4YwBqlsqNUtCqVWuaNW/O0OEjyN9rJibv3Ng0hEXH41GjPWEHVhGy9jsKvDKRgCNXGdb6MdzNBqxerhiKV+fXufPY6N+c3J5pz8kng2IAOzagcpUnWLZ6Ix8PG8ovc77D5fQB8rzwAWb/wuR2s+PuYmTOMRt5u00iaPE4hr3bn2otu9L/neF8vf4Ur/7kuKDvaoBnXx5EiTIVWfzNWFo3qs+3P/7Gk1Wr4GoycPTyDYIj4/h06THC4xzjasvncaP5E4WIttjBAL7uJiJiDYTHxJPX2zVFVuyIWAvXbsSiAS9XEwVyuaGUIt5qI95qx6gUbi7GLE0NsNvtWcrCLh4MiRdVcrmbKZXfi+sRcbgYHEsFergYsVgs2Ix2fFzN9KxbitNBkTxWIBdmk5l1/14lMtbKueBoSuRxp0x+H3K5mLDarARdv47H6s84f2APPjXa4Vu/B8pg5J+zoRTJ7ZZ07Mh/lhGy5lsM7rmwWKM5uv8fchXYwE5bCUcDH2uKCrr9PgE45n+bcxfDnLdE0nPjVhxhWKvHmbb9d0zYOLDsxxT7xMXFMWDAAMqWLcugQYNo8eVmjl5xjHAdtvgoAN1ql0wqP2VNyhuA4tYkEL+HnBLoaEfwYTQ45r5qYO5OOz412hC+YwFeVZqmWAPcs1IDIvYtIWzjj3iUq5Pij9fk5Y97qWpEHV7rCKjTWy4pFZf8pSnYYwoha74hfPsfxJ0/RJ4X3sfkky+pjHupahR8ZSKBf37KtXkf4d+4H95Vm6eox+juQ76OnxC28UfCd/2J5dpp8rQZnCYT651IHlzZ7Y5AOjHoSwy2DYoUa4zP33MhzRriX288naLe5Em/LAkBevLAvX/DMvcsSZuzgujMkgRmRfLg2qAUdq0znPed0QiR5K9h77lQOidk/jYbFXP71qZacT8+eaFy0lD8zOik/6R87uiViDSZ8u+EttsI2/QT4TsX4lKwPHnbDkn3Mxwf+B+Bf36KLTIE/xbvYI8MxqP8M6AMRB/bgMknd9LfkEeZmoSunYXJOy/2uEjir57Eq2qzTNsRcWAlGM14pFoiMCzGyrBhw4iJiUkzxCwmJoa33nqLihUr8vbbbwMpc1CcvB5F40kbJBgXIhs57tre2beRUooqRXJROq8nJoNyrDeuFCaTkSlfTKFipcqErv+BvC84liqMt2k8PDzxrfsKwcu/JOroBmp17ozdrrHaNDPXHcPr8UZEH99C9IkdqIp10z1uTJwVbbMTY7Hh7unB55OnsDK8AMEBX3Llx4H4N/kf3X6G4Y0dCV6N7t7k6/gJoeu+Y++yX/n32FH8Wn2A0d0bgDg7xMbFU6VRe3IXLcsfE96jU+smTJk+k1I1G3PwQhhXbkQlBeEAx4NieSIsFh93M5dDYiid15sSuW8uvWS325OypEfEWnE1GzEbDY4hx3aNQUGsxY7JoLDYNQarHVdz1vNjSAB+f0k+SuFWd7+Tl3dMW3B8TowGRbHcHsTG27gRG09eFF5urli15uCFULafCUVrjcUazoHzoYTHxrPlVCiWhD7GrjPB/NSrJtu2baNXz+6Eh4eTp/X7KZbujbFD+bwenLgSQcjqmUQeXIVbqeq4FihDhZJFuBYay/IVy/Gp2f6u+gTx188Rd/lffBukHHZeMrcne/fuZc6cOQwa5MiWntzEiRM5efIkK1aswNXVNSkITzRjw5kUgfiPXSvx9KRdmbZFpCRD0++xasX97jhg23EmOGkOsNZQyNed9k8VwdVswK92J4w+eQlZNRNtsybto5QBv+f6YosM4caOBeRNlcDJq2ozbJEhRJ/YnuV2GFzcyNPibfK0fp/462e58sOANPubcxel4CuTcCtRlZCVXxG8YjramnKYmzIY8WvYizwvfOioZ85AYs8d5E6l97VrtcOoJUdYmOyON5Ai+Et8XxPraFAuLwv3XWRtwnqliWqnGrptNqoMh3Hfakh2dribz5azJA+u7XaNQWX8HqUeCu/n4ZJmesDCfReJtyXMe7dpvt54mq/Wn3Ks+55w0jUAhX3dMlwDNqNQ/W7vhtuiQrn2+zDCdy7Eq2pzCnQZm24QHnVsM1d/eQ9tjadAl7F4P/483tVfwKtqM1wLlUO5eBCxfyXa6hiu5pKvFN7VWhF1bCPRJ3eQp80QXAukn0EewB4XRdSR9XhWrJtmProx+D++//57BgwYQLly5VJsGzt2LP/99x9fffUVLi6O74VTyXJQpPf47NiWmT4WQuQ8bzcz7i4mlFJExVuZu/M8523eVGnZnehjG5OmkhXK5crf/evQ4eUuuBUsQ+zWnxnzQvmkKT3L/w3GreSTGHPlJ+Ifx7ST9k+mXC2lWiE3fNzNRMdbOBUYwbEr4Zy+Fo5H2VoUfHUqLvlKErx0EkHLpvDJsrPk8UgYMWcwkr9pP0q3H0T0uUNc/XEg8YE3h7oW8nPnbEgMPsUf44Ov5uNftAyv9+pOu+6v8+Wafzl6Je2Q1/n7LvP91nN8/PfhpCDMbtdJP4nJ8EwKImMtRMRYHOfthNEHOtlZQN/ijOCoy570f5Bg/H6TGIDfatqAY1k7myOotmkiYuKJiI4lPDqOE1fDiYi1kt/LDaU1JqPCzWTgTFA0vu5GCvq6se9cCOFxCXfMkx1q/6VIhn06jjatWxJtN1Gxz+QUQXiibYdOEjJ3CJEHV+FTuyP5OoygcK0OxJaqR1z+u+8TAI6/X6MpxUpGCmhdtQgDBw4kT548DBs2LMU+Z8+e5bPPPqN9+/Y0bdoUAA9zys94rVJ+WCwWft5+ln/Oh5A3b14md7g5AlD6BLem7ud5LdWrV9d79mScPOxRl3i3MXFo+6+9a3H8agS/7z5PPh83KsT+y3uvv4Jfw9fwqdEWAJPBcQVsxw+fEH1iG69P+4uAszeHtmu7jcuzXsfo4UuBV26dlCE1S+gVgv4eR/zVU441ERu8ijLdDPa13UbY5p8J37EA10IVMrzrbQm6QOCiz7CGXsa3bjd8anVIs3b5nVKAyahSLOWkICmzOJD0vhoMjqQ56S1Z1bWmYy6eBtonLG+W3h3o+yGDenpyYmmy1J/REa0qERodn+Exk7dpx5ngpCXJjArebVKeAxfCWHX05gWRxNGbBoNC2x132x3Zhh1L1eWU2AuHCfp7PPbYKPyb/C/Nsn2Q8m65yb8IRncf3EtVw5y3eNKcLYDYcweJPrENk19BfKq/iLZZUUYTdkvcrZf5A8J3/0Xouu8o0P0LXAuWvXl8rTEv/4Twq+c4ceIEvr6+SdtOnjxJ5cqV6dChA7/++mvS80+NWkVI9M0LZv4eZioXzsWusyHUKOHPT6/dXNdc3KSU2qu1rn6v2/GwkT7BrSW/E6g1PD95I5fDYgFNx6r5mNq/DbF2ReOhs/m6Z2183Mx4uZrYsWM7zz77LEOHDuXjkZ8AEBQUTJ0v93Bj55+EbfiBgj2/5OPuzSni50a14n6ERlmxY8PVaMSm4Wp4LP4eLoRFx1O+gA81xq5H223c2DqXG9t+x+RfmA3LFvHjcTvB4TH8FxaH3a6JOHeEc/M/wx4XRe7mb5Pn8XpULerLkUthlMnnRbsni/DRX4cJXf8DEXv/xqVQeSp1HkKIKeMRc0dGNMLVxYTdbk9a+cFu18Ta7NhsdkKj43ExGfBxc8HLzXHRIs5iI85qTxhNYMh0nfHUfWcJwu8vt3NHPPGCSmh0POuOXSPeaiOXuyuVC/tgtWvirJqSeTyJjrPiYlSYjYqd/4Ww9dR1LFY7yqi5GhZNVJSVzecjHXXGRRO8/Euij2+lSJW6PNNzCCbPXGw5nfJmTOzFY4Qs/hxLTARmv0J4Vqzn9D6BLSaCSzN74lG+LnlaDkx6/rfeNTi5fRU9u3dj1qxZadYNf+GFF1i3bh3Hjh2jaNGiAOz97zpdv99NrFVT3N+dJf97hqqfrUlKBtusUn6+fuXuT31Wm52IWAsmowFvt7ubpno/yKxPIEPTH2Cphx8fvxrB0EWHErbeoFTdp/Ao8zRhW3/Do0JdTD55sNrhv+Ao/Br2JOrkdn6cPIq87W5eBVMGI97VXyR0zTfEXjya7tzWzJj9ClKg2wRCN8whYs9i4i4eTZojlli/X/2euOQvQ3DAFK78OJC8Lw5OM2/cnKcoBbtPJnjlV4Rt+onYC4fJ02pQUuKpu6FxBGnJ5xP7eZrx93Dh+NWIpCRuf+67mLTOemoG5Ri+njgnvP1TRVIM404dUGZ1Ka6scMba4F9vPM26fwPRWt/WxYFbHTv19tsdIp96KHzqHAqXwmJSlE/88rcn+yWlXrIsO2ltJ3zHAsI2/4LJtwAFXvoEl3wl05SzRd8g6O/xxJ47gFvpGtjCruBToy3aGk/wimkosxvuJaoC4FKoPPa4SOIuHiNkzTdYQi6Rp/X7ScM2M22PzUr47sW4FqmUIggHsJzYxPlDe5g1a1aKIFxrTf/+/XFzc2PSpEkp9pnV42naz9yW9LiYvwebTgYBsOlkEN2/3+mUYPy3nedZfvgKzSsXpEvNYrfeQQiRruRBx9HLN7gcFovJADatWHE8lJkzptPtpbaYjgaw4lBxSubxok6Z3NSpU4cuXbowadIkevbsScmSpcifLy+rBtak0dhIbmybS6Fzq4iMa4wlJpJa41KOVlPARy3K8mTR3CiDAQUU8TJwMRJ863bDtdjjBC+ZSP16z5DruT54VWmGi1K80bAE0U8VIrBOZeaNe5/rf4/H48Z/bKvRDWUwsu9iJKXzhTsSwDbqi2uRxwhe/iUHvhpAzR5DuOz7BG5GiLWlfB/s2p70PlittqRlnWw2OwpwNxvxcDVhMDjOGSajwtVsxNVsxGazo7WjrDGdjNqJZF44xFltWBKG8aeXffxeSf53kBiUp/49pf797TkbQkiUBT8PE+dCoiie251cHq6YTRBnsWIyOpK+xlmsVCvqQ1hkFBfDYqiY3wcP9wIER8dTJM915gRsJWjxOKw3ruHboCeNu/XhUlgc7smmUiTNB187C4N7LkzeefCt283pfQKAyH8C0JY4fJ5+McXznWdsxPr7B1Sp+iTtOnVLen7KmuPM/nU+B5csYcKECUlBOECZ/L50rlGMzaeuU7NkbmZsOpViRZYVR1KOHE0UcOgKwZFxtH+qCB6umYeedrvmyOUbhMdYAU35At7k87n1HPgH1f3zVyMylTi0+bed51MMcU4+/Hj54Ssp9ll84DK+z78Odjsha79Jet5qB4NXHnLVfpnokzuIObMXcAznLeDjitfjjTG4+xC+Y/4dtVUZzfg/34e87YdjvRHIlR8HEnlkfYoynhWepcArkzC4enJt7lDCdy1Kc4XZ4OpBntbv4d/kf8SeP8SV2QOIPX/nQ9VT1G0As1El3VENibJw6noUQxcdos9PjjsuhXzd0x2cZlTQqGJ+rAl3YOOtdqasOZH0O0mdcd3Pw8Vp2cedseRd52+3s/roNWzJ2p86A396Q+lvdey7bVtqiUF88szplQvlInH65L3u+tiiQgn842PCNv2ER4VnKdhjSrpBeNzl41yZ8zaxF4+Su/nb+NZ+CZeCZfEoVxvPx+rj/3xfQtd/j/WG4wRmMLviUa4O8dfPEn18G95PtcryCTfq6EZsEdfxqZVyDVB7XDSBa77nsSeeJKZkvRS/m6df/YTVq1dTvEkvChRIOex09ZGrKR4fupRyqcJdZ0NIz287z/PK9zv5bWfapeXSKzt00SE2nwxi6KJDWdpHCHFrRfw8MBsV8TaNza4pkduD+s834ZlGLVj+8wwiAy9zOiiSiBgrWmtGjf4Ms9nMO++8i1KO/CjlCuThzKSXePW13vyzcTmFjTcYuPhMmmNp4NOAkyg0+bxMuJkUfw6oT+L9OvfiVXh33LeYC1ciZOVXhPw9jri4KHxcTRTM5Ypf7rx0HPE1TzR+ifMb53Nt3kfYIh3fU8sPXE46jmfCd63ZJw/bvxmK295faP5YXirkvXln8LVahVEowmLiOR0UyfXIWIxGAzabjTirleh4GxZbYqCecvhy6uAso1GjyTNvpxeIPwqZ1OOsNgLD4wiLsXA9IhZrdqwDegvxVjvhMRZi4q3YbHaOXr7BmqPXOHLpRkKug/+zd5ZhktRn1/+VtPu4rbu7sqywsMvi7hY8ECwQYngMEjyQ4C6L6wqwxroL6+7jM+1aXVXvh5rpnZ6ZFSTJkzdzvsw1XdXV1dXd9b/PLedoKKqOktayPo/mn09SUYkrGklVpzacwiQKmCSJZFrHbZUJJ9OUB+JMXbGXZ2dv5ZO1B6kOp5AR2VwZxoRAR5+V9bM+yIyeFV7yFzwjzmPOdj9ba2KsPRDCKYCWSlA37XHqv3kOW6fB5J16O5ayXv+SmEBTEoRWfY6185AW8UlgyXuUl5dTcspN3PXhepbvruPKl5bx+PT1bPzoaUx5HcgdcUbWc5buruO1xXvZUR3j3WX72VUby9remrzFbz5ax21T1/DQl5uY8tQC1KN8T1JplUgiTa7TgsMiUxtJHtN7/W9FGxH/L0AjwXn0q6387pP1hyU6U/oWZ/3fLseOyVuEZ/SFxLctIbbDEFBo/J24h52N7CuhfvYL6GmFg4EE9dEUotmKa+gZxHeuIFm54weft73riOwZselPoqUSme3m/A4UX/E4tm4j8M99mdpP/4KWzJ5BFQQB16BTKL7iMQSzjaqp9xCY/ya6pjZ/ue+FfKeFB87oS7/SlhX2bzZVceELSzIEWhRAFgVuHNuZX03uwfs3juaGcV0y2zQdFm6vzXwmzSvgG8uDnDO4jIuGt/9Rbemr9vp5cta2H2x51/h8pVmJXxSyrc0OR6iPZrfX2vYfS86HdPBlugreWbaPh77cmOlQ+E+GOPFdqyh/9RaSBzaSM/kX5J3+K0SLPWsfXdcJrfqCyrd/DaJE0WV/w9n/pIx1oBozSK2j9zhsnQZT+8WhanRoxWekQ9UUXfkE9q7Dj+mcdE0luOQ9TAWdsXXO7oAKLp6KGqmnvv+lPPbNds7952I6/WYap/51Bms+fApzUTeCnSZkPMgbMbMZEZeaKTkP75jT4jy+L7FunkBs/n8b2tCGY0dTcuG2mXjt6uGM6JzDqf2LeemKYVhNErf8/o+IksRrj9+P12rCbJLQdWjXrh2//d3vmTbtS6ZNm545TlrTuO7GW5FNJj57459HfP2uBS50RBbsqGVXdZhV907k85+P5PYxBby/30bBBQ/iHXcVka2LqXjtVhYtXYmS1omkVPb5k5zws1/R88K7SVVup+K1W0nsW4/PIfPYuYe680w5pRRe/hhlo85g66ypvHnPVfR1K7x/7TDevnowFwxrT1RRWby9ht21UZbvqaciFEfRwCZLOCwSXqsps7bruiFO15RINhI1VdVIq1qrpPpwJLxxZvxwZLypMNh/M5QGf3ebSULTOapY6k+FDMFOawRjKTaUB/l2WzVbq0LsrI4gSwI7qsPUhBMZIT5BaNkt1/jZGcr3Ih1y7HQvcOKymRnVJY8Sn40ynxVFNeKaJTuqmbe1ln0hhU/WHCSW1sl3WzGbZLRUmLtuvJppL/4VW8dBFP/s6Syh5Eb4aw9Q+eYviW6ch2fMpeSfey9CQ3v5TxETnNu/kPMHFnLR4GJsMkTWzkSLBfGMujBrP6X+IKHln+LoewKb1GLmbKnmgueX8u2OWgIL30YN15Az+Rc8PXdP1vOWN4n7NKAumqJviaFFIwAvXDakxTnN3NCQUBCgPBhnd12sxT5NYZYl7GaJ+kiSSCKNz3709vv/ZrQR8f8CNBKc5sJiH60+kFW1vGREe/58dj8GlHk4qXchZw8qwywJuIefjSmvPfXf/BMhnaBx7EmQTeSceAPp+oOEVnwCkCFp7iGnI1ocBBe986POXXbnU3jxX/CMupDo+tlUvH47qepD2XTR4iD/rN/hHX81se1Lje1VLbPt5oLOFF/5JI4+JxBc8h5Vb/8aJVDZYr9jRWUoyQOfb2BDebDV7WlV55VFu7nvtD7cOakH790wit+c0ivTfdBYrT2uqzGjpgMpxfhMDgbiyJJRAZckkQ9W7mfq8n18tPrADz7fRkK7cHutMQf9PavrTZ/fdCmSRIFrx3TKkGY4POE+mq94a9uPRt4bz63p97jp/02J/H2fbSCh/Psz7k2hpxXq57xE9Qf3I9ncFF3xOK6BJ7cIxrRkjNrP/4p/1vPYOg0yquWFXQCwlvVBS8YIr/ois79v/M9AkoltN/zpXUNPp/S655GdLYnu4RDd9K2hqTD6oqzzSdXuI7TyM5z9J2Ep7XnovQDz330aLRYiZ/LNCKLEhvJsRdST+2RXyFVNZ2CZB7Mk0DHXzm0nZgu+wfcn1s0TiM3/b0Mb2nBsUFXN0DRpQiiHdczh7WtH8tRFg/A6zOQ6zJw6sg83/+oevlv6LZXr5uK1mw1bR13n1ltvo0ePntxxx+2s31vD3C1VrN4bwJufz7kXX8mHU9/hDye4W319kwA7qwJMfnoRN039jqfm7mTtwRDbqiM8ubAaMARjPSPPo+jSRxDRef2eK/nsrRfJc1gZ1sGH22LisksvYeStTyNY7FRN/T3WLdOpDsYY0eHQ6wqymZ/96kH+8PTLpPwV/POOC5j52fu4rGaqQgnmbatm9d56Eqk0kiASjCaIJdPEUjoJRSWQSLG/Psre+iixpGrojDQsjo06I3qDJamqqkTiCrFU+ntVfVNpjaU765i9uYr66CGv6ObV9/9WWEwSAhBPqUii8G9pTdd1o7tDEAQUVWPp7jo+WHmArzZU8dayfYQSCpqmE1c0w262IZmi602Jt/F8TdNQVWNswSwLFLmttM+xMbprDp3znYBB+PfVxdh4MMCOqjBaWsFnkbCaZfbUJ1h9MMyjr33EcaNG8vU3X3PeDb9i/oxPyM9pWeSJbvqWijfuQI0GKLjgITyjL0IQxJ80JijJtbFgaxVTV1cQjSUILvsQa4f+mTHTzk7jGtZ/8xyCbDZepwlSVTsJr/wc54CTsZb1ok9xduX9jIGliKJBukXg1L7FvH/DaDY9OJnV957EqK4tdRs65NrRdaMb1yyJlHisR3wPoijQp9RLt0In/co8FB1l//92tM2I/xcg4wmtaGgYJEwSBT5cdYC0mi0A1qPIxdaqMN8dCDJ3SzXXjumEy2aiutvjPHTDeRRs+5yq3hdkjm3rPAR799EEF7+Ho/e4Q56EFgfu4ecQWPAmyfKtWEp6/ODzF0QJ79jLsXToT92Xj1Hxxp34JlyNa/BpmYyyZ8Q5WEp7UPvZI1S8eSc5E6/DOXBKFqEQzTbyTr0dW6dB1H39DypevYWcE2/E0feEHzSjlTqKmNeO6gi/+2Q9wzv6Ml7jTeech3TwMaVvMQsaZmY14P2V+1FVIwM7sVcheS4LU5fv+9Hz4U2TMSJwXNe8jB96I95Zti8j1HfjuC6tWoM1fccCcHr/Yl5bsidLSO5w/vZHm/c+3PbG764gCPjs2Sr9zYXs7jutT5bX+LmDyzJE/j9bA4dUzR5qv3gUpWYPzkGn4ptwdatCKamqndR89gjpQCW27qOxdhwIWhp0DQRDNCjnxBuo/ughTLll2DoPRbQ6Mee1Q7QZi/f3FSbU1TTBRe8Y1fDuow49ruvUf/0PRLMN77grs56TOLCZyNqZuIaemVFcdVqM8zvrmYVsKA/Rt8TNsI4+Vuw51MnQu9TDlqow++pjXPrS0hZdHk1/E43/HwmNM+FtM+JtaMMPR9OW6uYzsa8t2MFz83fzq8k9GdM9n0K3lb/edzfzp3/MX+77LReedRpmm4sVe+o54I9zx31/4cZLz+a6O+9h5Hk3kkxr3HB8J2685XY+nvoGH7/4D3538/08/dVObDoERCjzWhE0lbNfOCSmt2SXnyW7VpBvE2jnldgfONTJZintRcGVT1M/42lmvvoY8+bNJe/UO5AdPv54eg8+uecSPp00jF//8nZmvfk0O75bwT0PP8XmyjCxpE5pjgVJkrj92ss57YQxXH/t1dx/1y0smPMNp95wDwnZRnkghoZA/1IPbrsJWRIRG0TZgrE0XoeZ+kiCbTGFboVObCaZxlb1houJiKGknVI1zOhEUmkcFhlJOryYW+NnsHhnLZsrI1hlkQP1MS4d2QG5GVltbXb53w1F1aiNJNF1yHWascjHZttmkkQK3VbSmo5JElt0TP2r0Pgqsiiwvz6O0yLhsZmoDSXx2ExEEgrtfBa8VhlJFFA1HUkUUFSVdftDBOMKpT4b7X12NE1DloyYushrI89lMUimKJBUBNbur+HWd79D1XRO65ePgMaO6iAlXhsd3fDe84+x44u3kHPKKD7vflZ4uxBXBW46vhN//MZwAdCUJP45LxJZOxPZW4yz/yTMBR3/JTHBe8v3Ut0gpRNe8yVaNIDnrN9lto/uVUrVxx+xb88afCfegOQ4tHbrmkrdV88g2dx4x1+FCPzulN4oisLoh+dSG1XoW2ZH0IzP4N7TejKlXzGqDmlVxSxLpNIaJlnLSsq8dvUw7vt0I/5Yijsn9TjqjDgYcWOh5//fufCmaFNN/y9BY8uzz27GH0tRHojzbgPBa1SUHtk5lydnbcsKgmVR4L0bDK/lG264gRdfeonCyx/PsjpQQzWUv/RzLB36U3DufZnHtWSMgy9chzm/AwUX/uknWSzUWJC6aU8Q37USW9cR5E65NUuATY0Fqf3yMRK7V2PvMYbck3+B2MyCCSAdrKZ22uMk92/A3n00OZNv/kFCbk0F245l30Zl9UZP6ydnbctUmZsfSxYFHjqzLw99uTFL2b65ovrhlNab+2c3V8hvTsIPCfUZ6vjv3TA6s887y/a16rMtCoayrs6h79HNE7q2+vo/VCDunWX7uPfT9Wg6Wd7fza+fJMDornks2mFU/QXgpN6FzN9uKJNKkmhkuP/N0DWV8MrP8M9/E9HiIPeU27B3GdZyv0bxlTkvIVrdmHJKEUQR2VOIlk5iKemJa9ApCKKx8Cb2fUd41ZeIDd/bxN61FJx7P6bcsu99juE106n/+h/kn3d/1rlFNsyhbtrj5Ez+Ba4mPqO6qlDx6m1oqTjdbnqeOIcSJF6bTCB+yPJQxEgyAVhNRnKk+b3n5gnZ1in/6+Jrbarp/xq0xQRHRtNKuFHVFRjz8CwOBA7NWF40tJh7T++PwyKzevVqhg0bxuVXXMUNv3+Y15fsxSpLpHWdVa8+wIpvv+Ku5z6nXs7l1L6FtM918OKjD/L8P57h4j+9ha2wPSv3hYkkVKwmkXDq8PfnKX0KmdGKkJOu60TWzcQ/+0UEs528U+/A1nlI9va1M6if/SI+Xw4jrr4PX5eBSILAyC45mAT481c7jHG1tZ9SPu8tzA4PxafdSmGPoUzsWchloztR4rWT1gV0jAppXTTJvK01vLhgN7IocNuJ3TilXzE2k4xJFjMV01SDkrqqa1hkiaSi4rKZkcTDz4c34sNV+0mmNWwmkYpAgp8d1xGbWc68r2Pxt/53oCIYR1E1JFFA04yxxqOhaWv9kdTlvw9UTSelqAgiWGQp69oEYil210QxySKd8xyYJBFBgC/XlfPakr04TCJDO/q4fFQHjGQKuMwysixlznNvbZSdNRGsJpHaSIqRXXLJcZhRlLShGSTJmThlV22Us55ZSKTZd/rZ8zpjMdnYumUDD9x1G7UHduEcdEpDYr5l5Vap20/NZ4+g1OxBzilFcuZi8hYdMSawuTwk1R8XE2iJCAefvxZzSQ8KzzdcEEo9FkaUWnnll+eSk5fPosVLGPXX+ZnnhFZ+hn/2i5ScdTemHmMBI0Z0WyQCzdUQG7DuvpOwW2QiCQWTJKKoOi6r3CLh9L+ONtX0/w/QXE161V4/H6w6kCEojZ7LzVt305qeqcI+8sgjvPneR9TP/DtFVzye+fFL7nzcx11CYN4rxLYtwd5QURMtdjyjLsA/+0USu1dnLY4/FJLdQ/559xNe+Tn+b1+l4tVbyD31lxl1SMnuoeD8Bwgt+5jA/DdIVm4n//S7sJT2yjqO7Cmg8KI/EVrxKYEFb5J4eRO5k2/C3n309zqfRloqCXD6gBI+X1fO4cacdLKFzS58YUmWRZYgGOSx8flqg/Lj2G75VIUSXDisfZaq+kerD7ToagD4aPUBPli5H0U1srh/OLNvRsn9cGS4eftvWiNzno2vo+nG8XRdz5yjrpN5rHn1u+m5NlatZVHg/KHtOKdBJb4RRyLqG8qDmbnulKrzyIzNdC108WHD97fxCgqCkJXM0IF5W6t54Iy++GMpvt1azfImldl/BxR/BXXTnyR5YCO2biPJnfwLJIe3xX5qPETdjKeJb1+KrfNQfCdeT+DbN8g/6zcAxLYvI7F3LZG1M3ANPg0Aa/v+SM5c0v4Kkgc3U3T548cswNIUWipOcNG7WMp6Z82Gq/Ew/rkvYy7pgXPAJLrmOzI+4MGlH6LU7SP/3PvQTVZoct9oSsLhEAkHkAWBcwaX8dHqAy06JprikhHt/ycJeBva8J+E2KQi2UhimpJwgKkrK7jn9P4ADBw4iFtuu52nnnicLqMmExHa4faaiMYVrrnrPlYvmss7T9zPpQ88T59SL+FEkuLjL8D0yius/Pg5bCffjT8FNlkgrR05SaqoKR45pwe/+XhrdmeWIOAaOAVLaS9qP/8b1R/cj2vomfjGXYUgmzI6MZaSnvi/fISvH72FCRdez2U3/JIeZR4ufnGFcRxRgsHnctXkk3jz4d+w8637qB14MoWX34ZvUk9ESUTSQNXAa7dgMUm8tngPpoYq+bNzd3Ba/5LMLLEsCYiiiNUiIolpkmnBuOdJIkcr/Oq6zs6aCOg6u6sjeGwyg9rnZEh4awS8aVGstW214SThZBqnRW7wiT+2qnXz4wTjCnFFxWWRcTbYQhkK8cac/JHeW0JRSWs6NpOEwKFK/o+t6gcaiksiAoUeC7XhFHO3VoEgcFKvAtrlONhcEcJhkYkrafb74/QocuGPppi/vYZCt5lwIk3nPCc5DkvDiIFuFBkarmt1JMFn35VT4Y8xtIOXVFrjQF0EXbVhs5hQdZ26QBRZlnCYJe5+f00LEg5Gp+TTT/yFXXPeRbK56Hj+A+idW/IrXdeJfPcN/tnPI8gW8s74NbGti44pJihUq9m1ecMPiglsIsQ1CC77EC0RwTf2UCdcdSjJmhVvEqqv4dePvkAkeej9pUPVBOa/ia/HMOTux2ce13QOS8IB9tTF6JhnxypLpDQj6dRGwr8f2oj4fzP0Q2xqQ3nwsNVCn93Ms3N38OL8nTjGX0ftZw8TWvEpnhHnZvZxDz2D6MY51M96HmuHARnhKdfAUwzSPO9VrB0HZsj7kZCO1IOuITlzWm2nEQQB97AzsbbvR83nf6X6vXtwDz8H79jLESRTZobM2q4vNV/8jcq3f41nzCV4Rp6f9fqCKOEZcS62zkOom/YENZ/8GXvvceSceAOSzZglM0kCaVU/atVb0w3RiaNB043r+fy3O1v4VDfObjdCB6au2EdjkWJjxQZjdKAyzL2fbcgSDlHSGs99u5O5W6qzqtaqpnPPp+vZVxdlY0WIKX2LW61IN28HlkVaTc5I6Azt4GPFXj+6bgRujeMLh6vKHwzEM+3hKVXnnWXGvHujJ7jPbs5qJ2/aMbB0Vx07qsJZ57p8j79VQp3W9Iw1Vua6qDpPz95GvzJv1ozdvxq6rhFePY3At6+BKJN7yh2HHYFI7P2O2i8fQ40F8Z1wLa6hZxrza/UHiGyci7PPBKwdB6JraRJ712VsAbVkDNlTgCmnFFuXJgQ6EUFLRJDd+cf0ewst/wQ16if/7N9lVxC+fQ0tHib3wj8gCCI5DjPURFHq9hNc8h72nsdj7zqc4R1zsq5784p4U0RSxoL8fSzp2tCGNvx7cCxkSAQcZollO2u48MXlaIxG9rzNQ7+9k5KfPU3PUh+D2/s4cUgX7nvwIe65+5eU1a/Ga+vGU7N2s7UqTc7Ic9k653WKuk7B0r4fibSO0yIaGeBWoKUSFAlRquvC3DyuA//4di/N9zTnd6ToiscJzHuF8MrPSOz7jrzTf4U5z0jomQs7k3/5k9R/8xyz332OTSsX8dZbb7d4rbVhNx1/9iS1C96ibuknfLN/LYu6v8KoMWNRdJAEAVHQefCLjVSEUsgi5DpMRju2qqEAZvOh0DitaugIyKKAhKH70rh2Hu5610VTbKkMU+y1gWC0xnfJd6DrOqIosrUyzPLddfgcZib0KMBullBVncbDSc2kp2vDSXbVRAgkFFKKRtcCJ8VeK06L6XuRnoSiEU6kscgi/piCSRaxyBJmSWBnbRxBN9xzWiPWkYRCeTCR6Qws9dhoTJt/34p4YzVdFEXiKZU1+/xIokBdNIlZ9rJgRw2KpuO2mZixoZJrx3RG0zVEwRCX1RoET/2xFClVo1uBm6pgnNpIClEUG8h3thr+k99sozqUoCYUJ5JUuHREO2qjSdKaTlmOnUgqza6qGE/P3Y6S1hBb+WxTNXt4+JZfUblrM/aex5Mz6efoNjf5Zij1Cayv0lEx1vD6mc8Q27oQa4f+5J56J7Irl+CS944pJqiXSvHmDzp0vVQFNVKPaHG02iHaiF75No7rlscLM1cTXvk5jt7jMRd2PvQZVuzg28/eou/E89iYLiS8ZC8yoOg69V/9A4Drfv1nZu5VCSaMGEASoGuBg61V0dZeEqtJpDqUpFuhi/+NRvKfHm1E/L8US3fVkdYMgqlqhie2cYNqSTnv+3xDhjTaexyHrdtIggvfxt59NCafMb8pSDK5k2+m8q27CSx4i5wTrzcel014x11F7eePEFk/G9eASYc9Jy0RofqjP5A8sBEA0erE2mmIcdPpPKTFjd1c2Jniq57EP+clQss/JrZ1MQXn3Y8pz/AstJT2pORnT1P31T8ILniLxO7V5J12Z2aOPXOc/I4UXf4YwSXvE1zyHok968g56UbsPY5DOUZx9cbW8mMR/dxYHmRXbes3pebPb6rrklZ1nv92J7O3VGeRcAGjKj2n2eNNj/ncfEPAbsH2WvbVRVsQ58bq4yuLdoOuc/WYzvhjKZLNOiQEAVbvD2RyOKpmiNI1bRdvHIFoJNeyJCKLAkpDQqOxM+C+zzag6Xrme6fpkFQ0Pm4Qpbv0paUkFe1HTXbrGMJ6lZuq/m3KkkrdAepmPk3ywCasnYaQe/ItyO6WAiR6WiGw4E1Cyz9Bziml6Nx7s0Y+PKMvIr5jGZbi7phySrEU9yBVvhU1Uk86VE1k3Tc4B01Bduag1O0nsn4WsW1LSPsNqx7R4SX/jF9jbd/vsOeaDtcRWv4R9h7HZXWNJPZvILLuK9zDz8Fc0BkBWLs/gK5r1M38O6LJSs6J15NrN/HGNSO4feoaZm2uIs9poXuhi5V76gnGFQCaSyks3VWXES1sQxva8H8Xuq7z+Ll9+OVHGzOPbf3jFFRV58IXDRcV0WwlZ/IvqH7/XgKLp7J17BU8cm5/itxWfn3HLXz2wbs8dM9vmDRpMtuqYgRjSeSBZyCtmE793JcpuuJxzILImC65zNhU0+IcQis/IzT/df6kpBBEEWdpN+SuY3D0m5hJmAOYANliQTzp51g7DaFu+pNUvHYbzr4TDTFJQWjQibkDW6fBVH71LKedMBrPhOsx956QiS/Gd8+nPq6y1Hw1pf3H4P/6Wc4+/RQuvOwK7r7nIXw5ubyycBcfrTG6yNIa1IYVXrqiH3LDbHTjrLPeMLdldBoI6EgZ0tm0gp3WjHWw8XlaQ0xmkSVsJgmnVUbVQNN1YqkkszdXkes0c8AfZ9VeP8d3yzM66lqZ7weIKSqSJGISRZAhrcJ3+4O4bCbyXRZKvbbDWqjFG5KnVpOEjkH2jfPUM3GAoup0zrEjSyJRRc2MHTVFJKlikgSsJolIIk1a1zGL4veuhDfVMtA0LZPkz7GZSaRUQok0ug6mBkLdGB91y3exuy6GWRbpkGskNcq8NjrnO9hdE0WWBMZ0z8scuxGCILBqTx1fbaxEFkU659lIpHW2V4aYvqkOiyzQo9DFOYNL+cvMLdRFkkZSpOk5qwrBJR8QXPI+FruTvDN/g6PnmMz2mhTUVBnvK7H3O2qnPY4a9eMddyXu4edkEurfJybQUgliW+YT3TSfxIGNoCqAgGvwKfhOvLFlPI0RzxV7bPjnv4Gu63jHXnHoPWgq9TP/jsnhpc+Z19GryEUgkeYv5/ThV4/8g/iulQy76DYevmoi9ybS3DJ1DcFIgsvHdMRuknll4R521UYZ3M7NV5uNxP3dk7qTVDRynNnaP234fmgj4v+laC6odc7gMvqUeFqdA85unxbIOelGyl+6ibqZf6fwokOz35bSXjgHTSG8+kscfcZjKTYUke09x2BZ+RmBBW/g6DmmhU1TI4LLPyZ5cLMh8mC2k6zYRnzHctL+gxQdpq1dECVMeR2QPYWkw7WUv/xz3KMuxHv8ZcbCa3GQd/pdRLsMpf7rf1D+yi9aFWgTJBnvmEuwdx9F3YynqP3sYWxdR5Bz0s9bJVGtoXk19nD4emMltZHWq7OSmE2+m6MqlGhBtvuXeehb6uHd5cfmn/z8gl0IkFV9BuhR5OKAP0YqrfHQlxu577Q+SKKQ+T4IQJ8SD98dyFaKV1Q908be2ILelFyrqsZFw9ujAx+uOoDa4L+qNiSCdF3PqKfowAcr92fI+o8h4W6rTChxqDLb2mVtvA7Jn2B2XFcVYyRi8VRE2UzuKbfj6Dux1UAjVbOH2i8fQ6nejXPAyfhOuBbRnD0fZu3QH6VuP+HV0/AefxmyOw/JnU/y4GYcPcfgHHgyaX85ddOfIrF7FYgS1g4DcQ6YhGh1EVr+MXUz/07p9S8c9pwD899A11S846469D7SCvVfPYvkLsBz3CXGYxjBYlnlQvYd2ETuKbcjO3xM7lfMO8v2EUupxFIqe+pi7GmwFjE3zPNf89ryTIXcajp2pf42tKEN/1kIgsCUAWWc1K8UJa3hsMgIAryycHfWfrZOg3D0nUho2Uc4eo6hZ9FkLCaDPLzwwgsMHTqUO+++m0Tvi6mKG77G3nFXUPflY0Q3zGHQyWdyXJcc9tbH2VQZAQwid/doLz//28tYOw7G3uM40oFK4rtXEZ77Mpay3llEXOHQhIyt0yDcQ88iuOwDIutmEt00j/wLHsTWYAfl6D0OS2lPgtMfp/LLxyk5uJrCyTdyx+mD2LQ/SOd8K8XeIs4ZOpzih67gj3/6I889+zRzv/ma3//hr8wNZrtBiAKM7pLX0DJu2Fk1qmwrqkoipSIKOhZZalDa1pEaWtcjKZWkoiIAXocZsyyR67RQ5LFSHUpS4LLgssgk0hpigwq3rhtCZ5JgWMMZa4yeRVKbIs9poTacJKGoWGSRQCKFxSQaAmWRJHlOC1ZTy+6pcCJNTFERdEikVbw2MxbZqEI7zDIW2Ugq2M0SdVGVpKpiM0uttqc7LBKheIq0qmOWjaTAD21HTygqwbiCVZZw2UzkOMzURBJYzRKd850Uuq188V050WSaKf2KEQTId1vJd1sRBIG0qhGOKYiiwJUjO/Lp2gO8snA3y3fXc+2YDpw/pD0OqymT2Pjz9M0ICEQSChvK0/xmSg+2VoSNWWYB1h8MMa5HHpFEClkyiloxxfgskgc3Uzfz7yi1+xgy/mTOvvG3vLjmUJdfrhnqUqCnUwTmv0loxafIvmKKLnsUS3G3rPd9LDGBIEr4v32dyJrpaMkosq8E16BTMOd3IHlwC+HV07B1HYmt06CsY6cw4sv161YR2TgX98jzkD0Fh74LKz8nVbWTduf+FrPNRU0khaaDTY8Tnf8qQ4cNY+YLf6E2kkQU4JmLBrK1Okw0pSKLAjdP6Er7PDvtfA4kUSCZVqmLpJBEgTzn/9/2Yv9qtBHx/1K0plDdqJr+0eoDvLdif6vVVQDZlYdvwtXUf/UM0XVf4R50cqaSmzvuShI7llE342mKr3wSQZINtesTr6fyjV8SXPJeC7uDRijVuzHld8Qz4jwAXINOybTUtJ6t1YhtX0Zy33ryzrgb2V1AxZt3ElryHqmKbeSechuyy1gcnX0mYC3rTe20J6ib/gSx7Utandc1F3Si6PLHCK34jODCtyl/+ed4j78c1+BTj6nN91hQ04yEN841S6LAxJ4FfL2ppSBNIzw2k6Ei3kAcZRHuO90ILj5afSBDgk/rX8wX31U02GxkV9obxdUaq89Nhc+aWoX5YykeOrNvpnJtlkVGdc5lXTMiLgpkVOEbn29k5I02xsZEz5AOPs4dXJZdMW9Q8m/KuBs7NMyy2EKzYHjHQ23xjSj1WikPJLJI+1kDS4il1CNey8aX/SlIeOLAJuq/egaldh/2HmOM8QZny4qvrqkZXQLR4iT/3Huxdx3R6jElmxtHr7GE10yn9ou/kXPSjcR3rsDavh/Jg1sILHiDxN7vEO1ePMdfhmvA5CwFUy0WJDD/DTQl2ao6e7J8K9ENs3GPODfT2QIQXPo+St1+Cs57ICs5oAarWfvRPxhx/Hj6XXAJc7ZUH9HjW1F1Hp6xGa/dzPgeBXQrdGW+J0BbRbwNbfgvgM0so2o6mqSBIBBOpFmxpx6TmCUNge+Ea4nvWoW08Hkk4UbAIITtuvbiF7fdwVOPP0rhRX2wdjDmyx29xxFe9SWBb1/not9ej89jy5BwMDppVq7bgK5puEddiLXM6Njxjr0cpf4gsq+k1fNtjAtSVTsouPCPpCq24Z/9IjXv30/uqXfg6HEcALKnkBsfeZXEmi947omHie3fxN52j3DcqHFUh1J0znVxzj+Nqj/2E3nvyzO491e3cdv1V9Bv1ATSAy9HdhtE5bhueRlRr8ZqtCHUliaUUBAaVE3jqoamqpglCZMso+saSVXDIoukNZ14Kp1RDx/SIQdFNXRV4kmloYNRx2k1MaJzDst31+KymRlY5gUaW95bn7V2WGT6t/OSVFR03WjJrgol8EeTpNJGYiCcVkhrGk6rCZMkklDSBGMpLCYJi0kimVYNQutq2XrusRsJBE1vmP9uZUbdYZYp89lRG/YRDzNMfrR5cUXVWbm3PtOtN7Cdl34lbirDCSyiiEkSiKXSDO1gXBe7WWbJTqMDdEzXXARBJJpKo2gqoZhCIqny3LydJFIqSU3nwS+38s3GKu45tTc9SrwkkylCiTROi4RZEkmmVWrDcUyyRFxJo2k6bquJEq+dS0Z04M0l+0hrOhY1zMHZbxBZMwPJlUf+efdT22UYL64JU+SUSKRhSu983ltdSbJyB3XTHkep3Ydz4BR8E65pkZgHkI8QE6iJCOHVXxJe9QW6ksTefRSuYWdiKe2duZ72HscR+e5rUtW7WxBxgEBK4+W/PYDLl4dn5CF3pHSwisDCt7B1GYbQZTSn9C1h/cEAwzvn8PYjvyIYDPLCCy+iaMbYQSCepD5lWL/VR4zxDbfdTDqt8tuP1rGvPsaN47syvkcBCUUlklCwmWXMctts+A9BGxH/L0ZzAbfGx5buqkM7Qo+1WRLIGXwy2o5FRBa8xtiJJ9G+XTu++K4C1eIgd9LPqfroj4SWf4xnlPFjthR3NzLmKww/YlNOaYvjCrIZPZ0tDCNIphat5I1I15eT3L8eW9fhWEp6oKUSWEp7InYeTHTjXMpfuskQYOs1DkEQkD2FDQJtRnW+/OWbyJl0U1aLEDTOjp+Dvcdo6r/+B/7ZLxDdMJucSTf9KBu21iAIh0b1s9hlEwws87C2gfzO317L2G55tMuxowPnNhE9e/vakXy8+kDm8ctHdWTprjrCcYXn5+9qUV1urD67LDIvLdydqVA39RhvTM40Jmxa8zLXdDIVdLMsZtrJ9Yb2tPtO65Nl2db0O/f07G1Uhpp95oJAnxIPfUo83NOglg6NqvNSi8t0MJBocU7zttUwvnv+91K1/yFQY0H8814juv4bJHf+EYm1Un/QEG47uBlb91HkTro5KxGUDtUgNSSOGmHKKcU34Wr8s18kuPg9JLuXVMV2At++jmj34jvhOpwDT26VaGtKAgQRQWp5m9Z1jfpZzyM5fHhGXZh5PFWzl+CSD7D3Hpc1d67rOjUznwFd5/d/fpI7ph9s0XLeHAJkbMv21MU4a2AJMzdWZmkBAD+Zun4b2tCGfw0kUUASDaLpssrkuiyc0KuQtfsCpDWNqdeNpNBt45uxJs477zzu++NfuPOuu1mxx8/szdUIg88lr+R96r76O8U/M0ZbBEEk58TrqXzzTp5+9BGGnn9zi9c9GDR8lJrHBa3FD41ojAu83YdT2rkHlXkdSOxehVJ3gNpP/0Ks93hyT7oRr9PJWUPaYx5+K3ndh/Dsg3dy780/Y8o5F/HAQ3/m0jfXZB33T0tifDxzNu+/9iKP/OVPqKuXUjrxMk676BpyPHZ63TcDiyzxyLn9mNyniGQqjYphy2SXzURSKrquY5YEArE0BS4JvaEzrCqUQANK3RYSKZWUpmEWRawZQTVjLl3VdGySyJAOPvqVeFDRqY8mUUM6Pof5iPPeJknMWELZLTKxlEpa1Sjx2gkmlMx4Yl0kiccmE0tpyJJAfSyJ22rGZZEza1O2n7bx/5HE3xqr+LIkYhWFw5Jww6ddRRTBYzOTUFT21EYRRIGOuXbsZpm4opJWdfJdVkJxhbpoio3lIcoDcRCgvc9OJJmmY54dAYE/fbmB5XsDABzfNZe/nDvQqIjH0+yti2I1iSSVNKoODV34VIfivLRwJ789tRdr9vrZWxMl1bDeeWX4cOV+rKLO+J6lmMwy47vnIuhw6Yh2XDSsjLse/gcfPfcIWiyEa8jpeI+/LKsTtDJivNCX68qpXziV4JL3DKHh8x7IXnfTCloykkmw67SMCWRvMYJkovz569ASEey9jsd73MWYctu1uL6aYvyORFPrreDRDXOJHNhKrwvuJtZwvrquUzfzGRBEcibdhCAIlHhM9CguY+3CWUydOpUHH3yQnr16448pJJU0aRUkWaTUbsMsS5gkEZ/DzAOfb+SrjVWoOqzet4IZt49FFkREUSCYSFPisbb4DjcWBEXh2DQs/hfRRsT/P0Dz4Hdk51wspuxqpIBBzt69zgieL31pKY4Tbyb4ys3MeO4hCs97gEalEGvXkeT1G0vtoneNOfIG6wTfuKuIbVtC/aznKTj/wRY/KlNee2LblqAlo4gWx1HPO7F/A7qaxtFnAgBK7V5kZy7WzkNwDzubumlPUPvFo9i3LsY38bqMeJVnxDmGQNv0J6n97GFiW8aQc9KNLarjJm8RBec/SGzLQvxzXqTyzTtx9p+Ed9yVP8jqrDU0JZWSJGY8KBsfPql3IdubiZXN317Ln8/uhz+WYmtlOFNh3lgezKilv79iPw+d2ZeRnXO59KWlhyWjiqpnkXSBlh7jjX8b1dhbvAcgpRgV9KtGdczMo4NR1dhQnl1Bb6r43ppAoKrpPPD5BnoVu7OujyAYLWnHgkBM4dO15fyrbtu6phJZ95VRcU7FcQ8/B89xFyOaW8qNGPZlnxNY8JahpXDanTh6jz+kGNuw3T//DfJOuR1H73FZzxdECe/YywksfIfI6mkIkoznuEtwDz+71ddrRKpyJ6bcslY7OSLfzTK6Rk69IxMg6JpK3YynES12ciZen7V/dP0s4rtXc/4t97G0RmwxvtIIUYChHXx0LXQxZ3NVVpJl3raarI6Lj1Yf4OOGLo7WPOCbW+y1oQ1t+PegtRbnVCpF9/u+yfz/y5O6ckq/Yib1LsqQsFNPP5MTTj6dxx7+E8PHTmJulZl2uXbiKZnTbnqA1+65muDCd/BNuBoAS0kPHP1OZOkXbzF88jlA9r0qYDFawFOVOzLOKEdDY1wg95pAdcKIC0w57XANPp1k+RaCi98juW89ySm3cuvrKYpz3WyokrFd8Ci+RW8z89P3WbHoW8SxN2Y5vfhTYDGZmXLxdayVe7Lk7cfZPf1FPt8yj+SwK3F0Gkg0pfK7Tzcytms+cUWjIhhnT10Eh8VEuxwHLqtsVMK1GBo6mqaxtSJIdSRFjwIHgaSERxAxySKptI6s6cii0FApNNqkZVnMdJz5IylkUUTVdEIJhRxHdkL2cNVlSRQo8dqQBAEdqI8l0DSQRBFN11FVY67dsFkTcVnljGJ7I8JxhYSiIYrgtZszftvJtIosipnqpq4bPuroOqIooAkiYisrc1rViKVUzCZjZDKRUtlRG0ESQE3rbKsKM6DMi90sYTXJ1EWS6MDu2ijT11egaxplXhsuq4lch4mqYBynxcTKvQHynBYEdBZur0VNp9FUHU3TyXGaCMUULhpWxhtLjCJDjsUoLnhsJvZURXh29jZSuvHNVIG0DrkOC3WROIqicObgUlKqSiSV5qR73uHgV8+T3L8BS3F3fOc9kKX70hSpqp1snf4USvUuHL3HG57cTRTOU1W7qJ32OKLFTuElj2SPUooSvhNvILZtCYF5rxJZU4G14yB843+WJa7W4jUrdwBGrN0cWiKCf96rmEt6EO04JvMJRTfMIbFnDTkn3Yjszgdg8a46Nu+t5NN7b6F///7cescviabSaKqGqmpYJYFgXCGliOS7rRS4jJGAbZVh1EOTiOysjtKzyI3VJBFLpUlrOk0t6JW0iqIa32FRIDPu0oZstBHx/3I0tZZqGvy6LDIJ5VALtd0sceYgIwt981urSCgasrcI79gr8c9+gciG2Tj7nZjZ3zr2OsTtq6mb8TSFlz6MIIhITh/e4y/FP/tFYtsWZ1rEMs/pMIDgoneJ71qNo9fxHA3hNdPwjrkUQZRIh2pJHtyCrqWxFHc3bl6XPoJ/3quEV35ObMdy7D3HkH/6XQCY8ztQdPmjhJZ9RGDROyT2fYdv4nVZBAmMQMTR63hsnYcQXPQuoVWfE926CO9xlxjt6q1UG38oxnXP59zBZXzcxNrpxnFd+GZjZRa5BTK+2o0icc1pUbpBLf3EXoVH9M5u/jxRFLCZJJ7/did5LgvnDjaSKEcTThNFAZ/dzOPfbGuxremSu2qvn4tfXHrUc0qpeosWeE0/VGE9VvwrquGJfeupn/0CSvVuLO37k3PSjRll3uZI1eylbsbTpCq2Yus6nJxJNyO7Ds1IG1Xyp0ge3ISt6wis7ftnn7+uEfluFoH5r6PFQjj6nYh37OXIzpwjnqOWSpDcvx7ngMkttqnxEIFvX8NS1htHnxMyj4dXfUGqYit5p9+VlWhKh2upn/MSlnZ9GTjpAl5ZsjfreGLD6ENTqzyAhy1y1vd2fPd8Zm6szHy3Bcgi5jM2VGT932ib2IY2tOHfh6YiYk2J3PF/m5+139OzdrD5D1MIxRIM/9M8wkmdAWVubrnnz6xaupB77vwF4+96lufn7SStAxTgHDCZ0IpPDd2YBg0Z37iriG9bwtSnHsJy5gNZ6683rwhncWcSu1biGXneMZ1/87ggdXALkpbGUtoLW+ch2LoMp/bzv1L9wf3UufLY2X8S3jGXIMgm3OOu4oQTJzH9+T8S/+B+HH1PwHfCtUg2N0V2CCUVPl69n7DsY9h1f6b8u0WUf/U8le/dg6P7KHwTriElFrGzLkw0kWZrZRhB0IkqOsVeByZZZGd1mJ1VIXoUq1QEory8aD+iINKxwMENx3fGY8322G5qVSY1VAuFBmE3TTNGxlprptMaZsc1TWtVldxulkg0CKuF42nWHwyiqDpd8h2GQ4YGqbSK3SxjNUmoqpEEF0WxgTSnsZplUg2kWRQFdlSH0XWjfb7EazME3hp0Ygy9GWOcrjlUVUXXdNJqGkkyZRINuqajIlARihNJpLGbJTrnORncwUuwQbX9jcV72FcXJaVqbKuK0qvETZHbjUUW2VkTQZZEDtTHMEsYKvQI2M0iaV1GkoxEx8XDO3HpiE7M3FTJ7C3VtM91MrZ7IYp26Pvf+K2UgLpInFAUlu2pxmE34TWl+MMDD1G+bDqi1UnO5F/gHDCpVdcfPZ0isHgqoaUfGna8Z/+eor6jCDWE24eE3d5Dsrnxjr2iRTJFqd1P/ewXSOxZgym3HQXnP3hM9sDxnSsQTBZyS3rSXC44sOAttHiInPMfyJy3GvHjn/MilrLeOAedktl31R4/6999kpC/jmuffR1NMCGhYTEJLNhWz7bqGOcMLcNrN2GRDiVkrjiuI3/4YhOCYOgWjOiUQyiRzoxkmJpVwxvjikZNoTa0jjYi/l+IphXwpnO9TWeGm4uJRVMqU5fvazET6hpyGrGti6if/SLWjgORXYawmeT04Zt4HXXTniC86gvcQ8809h98GpH1s/DPegFbx0FZ7TqW0l5Izhyim+YdlYjrqkG4pYbXi26aixqpx9591KEKXypOOlCJo+8JJPdvJLZpHuXVuyi86E9IDp9RHR91AbZuI6mb8ZQhHLNxHrmTb2rRDi9a7PhOuAZn/0nUz34B/5wXCa+dgW/Cz7B1Gf6TtMwUuCxsrQzTo9BFgdvKhB4FLN1Vx0l9ivh2Ww2bKw9VxhtJOByebGo6zN5chSyJhkBaw8J9uPtZQ302a676w5X7GdjO22JWuzmuHdMJfyzV6s2yT4lB6lbt9fPQFxuPSML/L0PxlxOY9xqxbYuR3Pnknfkb7D2Oa12/IK0YKvxLP0C02Mk7/a7MiAQ0VsE/a6iSm8g99Zc4+kzIOlayfCv13zxHqnI7ltLe+M5/8LDZ9eaIbV+Cnk5hb5bsAgjMew0tEcm0mYGREAjMfxNbl2HYex2qyOu6Tv3MZ0BNkzvlVl5bsrdV272GnfHHUpn7y0l9jGrWzI2VnNyniN+c0ivr3gNk+YlP6VvMij31R/QXhyN7zrehDW34YdBbYXNNiXgkoWRtS+tGl9LIh+dmZsXXHQgxN9fOHfc/zAN33ED5y89lNF8AfBOuJr5zJXXTn6T4yqcQZBOSw0v7E69kz7R/kLtpHs6GDrceeWZ8NpnSAWPZOvN10qHqzEz2Yd9DK3GBGK3H3WsUToedSBpMvmJMee0QbS5SFdsJLn2fdKCCvNPuBGCHVEb+FU8ZZGnZh8R3raLdyddz0x3XoWk6G8uDHPTHMJtEUkX98V3+JMmFHxNY+j7RnSsYcdqlLN1USkI0sWZPgHZ5TnwOC3mOKGZZ4E/Tt5BMG8S0xCXjtpmxmAR2VYXRdB2TZHSqyaKAP5bCLAq4Gvy6m64PkiSS77IQiBuiY437HO6zbb5OyZKIUxKJJNMk0hq5dguhRIqKYByrSaJ/mTdD+FVVzQiXpVWNtGaoxStpteH+r7O3PsrO6gh2s0QxEE+ZMwJwJtnwWtc55FWv67oxn67poKtYzSbMkkS5P4rdbMJjM1HssbL+YJDqYIJ+ZW7C8TT+aAqLJGISBWQgpaSxW2Tcogi6TonXTpHHSkUwwYcrDzCgnYd1+wN0zbPxu9P6ApBIa4QTCZS0htdmwWmR8ccVuuS7aJ/nwGs1UR9NIpklBpe5WVseIw3YBOhf5qQ6HKc8orK1JsXKJ54huOhdtFQc16ApeMZcdljv7sT+DdTN/Dvp+oM4+k5sSPK4MiQ8WbmDuhlPoVTvbrDSvTHrWFoyRmDRO4RXfYFosuKbeD2uQaccU0FIVxViWxdh7zKcqJzdmp6s3EF4zXRcg6ZkYgxd16n7+lk0JUnuybdmyPngEic1W5axZeE0vKMv5J8bIezYwTVjOvHC/N28vNhI1L+6ZDef33w8JR4bqbSRFLpkeAf6lnrZXxdjYs8CnDYTDotMWtMxS2KLkQVZEkmljZEO05EM6v/H0UbE/8vQvAJ+32l9Gr7sWmZm+JzBZZR6rRxoNnvbGoETBJHcU26j4pVbqJv5dwrOO5TRdvQ5gdjmBQS+fQNbl2GYfCUIokTupJupfOsuAgvfJmfidYeOJUo4+kwgtPwT0qHaI6qVC5KMpawP1R88gLmwC6LFjnv4OVkqk/WzX0J25+M9/nIEk4W6mc8Q3TCL8pduwjvuSpwDJiMIAua89hRd+lfD93m+MTvuOe4S3EPPbHGDM+W1o+CCh4jvXIF/7ivUfPQHLO374Rt/dQuFy++LGesreDuT6AgyZ0uVMVclCgxs581UvkXhkFVa049EAGTJsAlrhK4blfaEojKlbzEby4NNXiMbPYtcWWQfjMp0a57dzdFoh9ZUSK4RjeTsaJXw/6tQY0GCi98jvGa60RZ+/GW4h53d6lw2GPYjdV8/ayy2vcfjm3hdVoU5VbOnoUq+zVDmn3RTVpVcjQUJfPs6ke++RnLmtGhlPxZE1s5A9hZjKeudfW4HNhL57mvDliy/I2BU3etmPA2STM7kX2S9TnT9LOK7VuKbeD0mXwmpIw2GC0LGe75ph81vTjlki9ZcI+Dta0fy0eoDCBiq/UfyF1+118/HDeMR6YZKUFv7ehva8ONxNBIO8Nj5/bnxnbWZ/689rh0JRaV5jnb9gSDPX3ER0z79mBUL3sbWZXimY0i0OMg9+RdUf/gggcVT8Y29HKsE1sGTMa+ehX/Oy9i6DEOyOjngT1EbSTF68hlsnfk6kbVf4R17+RHfx+HiAoq7EWkw0Kif/RKyp5C80+5CqdtPzWePEN04FwDP8ZcRoBBBNuEbezmOnmOom/k0ez/+G3/eu5g/P/Io++tiaBrEk1rD3LAJ9+gLsfebSGzhmyz77HU2ffs5g864hpyBJ2OSZZwWGYsksbE8hKJqFHtsVATjWEwmApEUiipS4LZS5rVhM5swpVVmba5mb10MWRKY3KeIEq8t85mU++NsqQrhssgMaOdrVeSqkTg3rai3NnIgiwKyIBBT0pQH4pR4rcSSaQLRJF67mXhaR9c1zJKhcq6qGiZZwmc3EY6nsJtlZBG+OxCgwp8kqakEk3a6FRmK9qIoIEsimqYjNZkPT6sagVgKSQSzJJFSVMJJhQK3jWRaMwTSrDID23kxSQICAvG0QlJNo+kCKiBrIif0KuRgMImu6RS6zHTOs2GWRCoDRvt/p1wHuqZxar9iyrx2kmmNmnCCtAblgSTJFNRIKd5atpfdtXHSqkafEhfdi9zsiEZ5ZXll5lrFddhcGaEmoRPbspDA/NdJByqNtvATrsmsqQBWGRpNW9R4CP/cV4mu/wbZU0jBBX/IEkvT0ykCi94ltOwjo0p+zj3Yu408tF3XiW6aR2DuK6jRAM7+Jx3TiKTPAv6G6bDY1sVocaOrril0TaX+q2eQ7B76nX4tQU0gktKJbZ5PfPtSvON/lhkv7Zln4VcntOOk+8/DlNeBvDEXIaDz/qoDnNK/mE/XGu39ApBMw9yt1VwzplPDd87QFBjc3sfg9ofWbFkSs9rRm0ISBSwNGw+nK9CGNiL+H0Xz6tKxVImaVsAblbHPG1LGu8v2ZTzFl+6qY+FvJjLm4dkcDCTIc5oJJ9Moaa1VgSazr4S8CVdR883zRL77JuMVLggCOZN/QfkrN1M3/SkKL/kLgiBiKe2Jc+AUwqu+wNFnQlaVzzlwCqHlnxBe8yW+JpZKrcHZ9wRsXYeTrjuAubgbaqQ+sy2+Zy1K3T4Kzr03q+pu7zkWNVhF/VfPEPj2dbwTfoar/yQEUcI99Azs3UdR/81zBOa9SnTjXHIm/Rxrg+VJIwRBwN51OLZOgwmvnUFw0btUvnEH9h5j8B5/Weam9X1RH8uuODTamDUnw40JERGyFNFFUeCa4zrx8qLdWWR8zpZqNE1n2e56rh7dMdNK3Bxbm82ifx+UBwxRnQdO78Pjs7ZSGz7UURGOK/z6o++OSMJNzRII/xegpeKEVnxKaPnH6EoSZ78T8Rx/2WHbwtWIH/+8V4hunIvsLaLggoewdRqc2a6nUw1V8g8bquS/wt5rbFaVPLJ2BoH5b6IpCWPufPRFh7X7OxySlTtIHtiEb8LVWa1xuqpQP/NZJHc+nuMuzjweXvUFyQMbDUuyJgmBdKia+tkvYmnXF9eQ0476ujaTyNyt1ZkRhmNtL2+cE/9o9QHevnYkN09oWfVvTCA2HY9oa19vQxuyoevG7KsoCt8rcdecdDclb404uX8pS9p7eXbuTi4f2YFuhW40TaPQZaaqyf3+Vyf3pHOek1N/fg+rli+hbvoTFF32aEarwtZlmCHcuvQDg2wUdyOhSuRO/gUVr99OYO4r5E65lTQCKV2gU/tOlPQZTuW6mbhHXXDYBGgjvk9cYCnpgbV9P5TafUQ3zye2fZmRvB93FfauwwwXlcseJbFmGjsXvMl5J4/DPeoCPMPPRZCzK9CyK4/8KXeQO/JMkgtfZ/4bj2L98m26TL6SThddSPdiJzXhFJIoUhGMowO9StzkhVN47GZO71/S0MqtUx1OsL06AujUhBUW7ailxGsjkVbpWeRhwbZqFu2opSacYkKvPG6a0B2tYU7cIkvYzFKLzzChqNSEjeJKrsOC3WKE71aTRKd8J9XhBJqus6MqTFUwgdUk0sMko+qG0Bq6jtUsIssSWsPsus9hxiRLROMKobhCoduMomnYZSnzXgTBIOKN4/+ptEZ9NMmBuigJVcUsS+Q7rcbMuiQRiKYoD8TxWGU6FbiwmSXKfDZiKZUypwOHLJNUNSTBmGnvXujmjpO6UR9RaO+zkNZ0Hp62kZ11UYKxNKqq4rGbyXPbmL+tCrtJojwUZ1dNlFRapV2ujbgCu2piFHmshOMKW6si9Cn1ZEYBmv5O9m9dR+Db10lVbseU16HVtnAJUNJGkju6fhb+ea+hJaO4R5xraMmYDimiJ/atp+6rZ4zEfb8TjSq51ZnZnqreRf03z5E8sAlzcXfyz703M9ZxNDSScF3XCa36HNlbjLWZWnp41ZekKnfQ7uy7qUmZSGk6asRP/TfPYS7uQemYczljYBkXj2jP5vIAv7n7ThLhAO2uvAfJZEFRdbx2Ga/DUI6vi4Uza3TfUndmlEFoGDf4vmgj4EdHGxH/D6FpZVs25ARJq0cXOWruH95I4pvOJTc+tvA3EzPPe2fZPmZsqCDXYebTteWZx4d39LH2QBD7oFOxbF2Mf86L2DoOyLR2y+48ciZeT930Jwiv/Bz3sLMA8I27knijzdkVj2cqzyZvEfbuowmvmYFn5PlHFW2TrE6k0p6osSCRdV9j73U85rz2JPauxdl3IqLdCxhtvsn96ym6/DFEm4vg0g8JLp5K/cy/E904j8ILHkKQZGR3PgXn3kts+1LqZz1P1du/NubExv2shR2VIMm4h5yOs+9EQss/JrTiU2P2ve9EPKMvxOQtauWMfzo0t/3SNJ1ZW6qzCK0GGdadSmutqqdn9j0GHtyaz7kAvLt8H1NX7EfX9BZ+3c8v2NVihi3PacYfUzKt7P+XSLimJImsmU5w2YdosSC2biPxjr3isHPgupomvGYagQVvo6sp3KMuxNMsYEzs+466rxqq5H0mGIttk2x24sBm6r/5J0r1Lqwd+pNz4s8x5bVUPT0WhJZ+gGC2t5gPDy75wAhCz7s/I/Km1B8k8O3rmeA48550jbrpT4GukXvK7a3OujVHJKnyTZOxBkkUjuoZ3jwxeDhi3bhfU1HBI7Wvt6EN/2toJOHGXLCOKB67yrCmQzqtYmnFeqopir0O/nj2IR0LUZT4+o5x3PLuarbXRLhlfFcm9ipEFAWunzyEF6bcRPnHDxNc+gHe0Rdlnpcz8ToSe9ZSN+0Jiq96EkE2Yy7sjHvYWYSWf4yjz3hsnfoTTuq8vPwgav+z0Db+luj6WbgGn3rU9yNZnVjLepKIHkNccGAjRZc/hhr1U/vl4yjVu/BPe5R4r3HkTroJQZSwDTmDwu6jDaXqBW8R3TCHnInXZylcAyQBT1EXJtz1dxYvmMu+r15m47sP8/TiD/Hedz8XnX8Ofzq7L8sa7mcOs0ynTg6W7a7nzaW7sJskehU5ORhMsLE8gigYlmVV4STFHgtum5n99XE2HgxQHoxjkkW+2VjNpN5F1ERSVAQTmCSR0V1yyW3wZk4oKoFYivpoCq/NhCiK1EWT2C2yQdDiCrWRFDkOCzlOMwf9adxOM5XBBGU+Gy6b2Zjz1o35cBHQhMbKujH/HUwqFDgtVIaSuKwy3QucKGkNAQgnU4QTKmZZINdhoTIUJ6Vo1MaS2GUTDptMNKVglh04LRKbDvpxWk0gitRHkvgcZko9NgQEhMY2eVEgnkpTGY2hqlCWY6fIZSWSSPH5unJW7PWT47QAOsd3zWNYlzw+W2N0XsUSCtWhOCU5DlbtqWPpToEcp4X6SIJEyrAqK/baWLffz6byQ8WJxIFNBBa8RXLfd0jufHJPuQNHn/EIooQV6F5kozaWIqlArkNm/YaN+L95jmT5FiylvcmZfFNWxbxFlfzCP2YJEmqJCIGFbxNePc2YOz/5Fpz9Tzqmtbg5kvvXkyrfSs5JP896fjpYRWDBm9g6D6X/uJPZUBEzWtK/egY9nSTv1NuJKPDOigO8s+IAuVUrWT3jEy6+4Q58Y0axam8AQRC5bWJX/JEUd5zYjUe/3kZtJMFZg8oMf/CG4E/VdayHK3234UehjYj/h5AVwKo6oB9TFWpIBx/3ndaHGRsqmNK3OMv66kgtoQ98sdG4sTZbo5NpjbSqgSCSd8rtlL/yC2qnP0nhRX/K/OAdfU8gtnUhgflvYOs8BFNuO+PGctKN1Hzy5yybMwDPqPOJbV1IaMVneMdcckzXQ7J78Iy5BL3BnkEwWRFkUyaoqJ/1Aq4hp2fIj2fk+Tj7TKBu5t9J7F5N+Wu3knfyLVhKjTZae7eRWDsMJLjkPULLPyG2bQme0RfjHno6gpSdCRctdrzHX4Zr8GkEl7xPeO0Mohvn4Ow7Ed/oCxA9/1pC3ggd2FEdOeo+PxSyJDConbeFYFqjVdnh7Ndae7g2ksJlkQgnj00F/d8BTUkQWfsVoWUfokb9WDsMxDv28iNa1sX3rMU/+wWU2n1YOw4i56Qbs6x11GgA/7xXiW6Y3dCSll0lVyN+/N++RnTDbCRX3hHnzo8FqZo9xLYuxj0qO4ll2JK9j73XOGxdhgENKunTnkCQzS1a0iNrppPYu46cyTf/4ITS+B4FR61WHy4xeKT9JEnkvCFlWdZ9bWhDG34YIsk0K/fUo6gapV4bvYrdh73/6LpOZSiBpkOJx1BCjisq1xzfBZssIghGy7EsCuS7LNx6zaU8vG0JgUXvYu8yDHNhFwBEq5PcKbdS/cH9BBa+jW/8zwDwjLmE2LbF1M38O+afPZNJZlra9cVS2ovg0g8MMiK3br8E0L/EycbyCIr+/eICye5hzC//gbJuGgs/fJHI2pnG80ddYNiouvLIP+u3xHevpn7W81R/+AC2rsPxnXAtpgZPcxnoX+yid7GLioGj6dx/BAdWz2Xf7Le4+ZrL+fvjf+Xee+/llBNPZv62WrZXh0koFraU++mc52D1nnrmbamiW6ELdA1BkCj12Fi1L4BJEshxWCkPxLCaJBKK4fltlXVqwim2VYXJcVpQNY09dTFynRbSqsbu2ijoNHxuOj67BalBvK02kmRbVYh4QqUyGEcE3DYzuXYLkiDgMEuGf7luiLs1CsA1rVLqgoCOwKD2PrZXR3BbJQrcdiJJBZMkUh1KYjYJ7KhKUG6Jo2o6RR4rVpPMAX8Mi0miLKchMZzWiCRUUmmdynQCm1km320l32mhzGfsI4kyVotATNEIxlWcZonN5WFMos57Kw+w/mCA+mjSeL9AZThJhT9GIKYQTWnURRNY0Ch0WVFUgyCCRqnPRoHLTFUoSe9iN/v8cXRdQ6zaRMW8d0nsWYPo8Bpz2QOnZDoiJnb1sq4iwtYa4/oK8SDbPnmTyLqvEO0eg7D3PSFrPCC6cQ7+OS+jJSItquRGFX02/m9fQ4uFcA46Be/xh587Pxp0XSew6F0khy+rLd0g3M8iCHD+Tb8l2jDeEN04h/iOZfgmXJ1lgaZGA6x9/wnsxV259LrbyPPZuGykQLHXiqZDdShBlwIXd53cE5MgYDVLmGWRQNz4boJA+xw7pb7v1+HXhqOjjYj/h5AVmDZUxI3ZnSNXiVbt9WcsglbsqadHkSszt9laUPvOsn08PXtbpq24OamyyIZNRULRkD2F+E64jvqZT2dVvwVBIPfkWyl/+SZqpz1O0aV/Q5Bk7N1HY+8+2liku4/K/OjNhV2wdR9FaMWnuAafesxWYYIgIJgbWn7SCtGti5B9JcS2LjIUURsE43RNRRAlZHc+9u6jSVXtRE/GqHzrVzj6nYRv/FVIdg+i2Ypv3JU4+52If85LBOa9QmTdTHwTrsbWdUSLYEVyeMk58XrcI84htPQDwuu+IrJ+Fo4+4/GMOP8HVzj/3chzmbNayxuRVvXvrVp+JPxfIeFaMkp4zXRCKz5FiwWxtO9P3pm/xtqu72Gfo9QfxD/vVeLblyJ7Csk/+/fYuo3MbjPP2JslcI88H8/oCw8ttmra0CRY+DZ6OoV75Hl4Rl14RDuyY0FgwVsIZhvuYWdnHjNsyZ4ybMlOPGRLFlr2EcnyLeSd/qsWSu7+ua9i7TQE54CTM49LAkzsVcg3m6paaBO0pt4/b2s1q/b6j0iWh3TwHTEJ+H33a0Mb/ldhECSjM6rx/2PBvrooOpDjMHPAH6dDrgOHpfXQbtnuehZurwYdBnfM4YSehSRSKg6ziCQIDR7Qhr3hW0v28sKCvThPuJHInvWEZj7BcXf8g631Rixh62zcX0LLPsbWdTjWsj6IJis5J99C9dTfE170Np7xV2fei2fMpVS/dw/htTMya7lHhmA6+xxzbBJNV5bGuKDABIGjxAV1Ieg24ky85WGCi94luOhdopsX4DvpBuwdjZZeW6fBlFz9DKGVnxFc/B7lL92Ee+gZeEZdQNrqZGNlmEEdfHgsMnv8CYoGTuCKiy8gtnUxzz31Vy696AI6detJ90mXYe8+GkQRt00mGleMAEvXiCbTuG0mJEEgqen0LnESjKlsrgyRazfTo8hpCMZJ0LPEiywZomqxpEI0qdK3xJjPTioqFYEYdouMx2pC1XQ0NJwmmXgqTTSRpsKfoCIQJxBL0LPITZHFgtMi0afEg8tuRVUbxXwVDoST1EUS5DittM9xYpJFZNGwlvruYJBwSkXTZLZVh8h3WFHRUTWNeEowrKlEgWgyzfaqCOv21ZNQ0tSE45zvKyMYS1DujxFLJikPqtSGFU7qU4zXbqYmmqTUZ0OWRJS0ysH6OJXBBP5oikAU0rrOqj1+oyikqYTjKhazhs9qpjwYZ0ulzK7qKClVR0DD47Xy+bqD7K+PY5IF/NEEZllG1FXSaagMxqnesoLVn7xK5dbVSHYv+eOvpucJZ1KXlIg2+c6tLw9SH9OxCGlqln1OYPF76EoC19Az8B53MWLTNvOaPUab+f4NWEp6kjP5ZswFnTLbkxXbqZ/1HKnyrVhKen4vgdbDIbF3Hcl96/FNvB7RZGFAgcy66jTRDbNJ7F7Nw488yqCJQ7l+6npjHO2b57GU9cHV8LsAMlVyLRnDM+UO+rX3kee2IYnG5w/gtJpIpAwv8ISSxmU1UeKxsbkyhM0sY5ZF9tXHKPLYDM7Shp8MbUT8P4TmgSkcfkb8cCrpR6uev7NsH7/7ZP0Rz2PtgSBXj+7IG0v2EFM0nP1PIr5jKf5vX8facRDm/A6AoaKeM/lmaj97mOCS9zOV7pxJPyfx0npqpz9J0aV/zcyRecdcRsX2ZQSXvJ8l6HasMHyX3ya07CNsXYbhHnl+ZpsWC6EEKojvWkl821Lyz/g15qKuBBdPJbTyM6O9vNdYck66EUGUMOWUUnDe/cR3rqR+zkvUfPzHIwq0ya48ck76Oe6R5xNa/gmRtTOJbpiDrdtIQ1CutNdPorL+r0JrJPz/R6TDtYRXfUF4zQz0VAxrp8F4Rl1wRAJ+SLhtGoJsxjv2CtzDzsqq0CQPbjEW08odWNr3I/ekm7KSMPE9a/HPegGlbh/WTkPIOfH6rCr6D0XiwCbi25fiaZY9D634jFTFNvJO/1UmqZWq2kVg4TvYe4zB3mtsZl9dU6n98jEE2UzulFuzvqeaDvkuCxaTIcgnCgLXjumEy2bCZzfjj6VYuz/ArAai3qg30dr9pbny+bEQ62Pdrw1t+F+FIAhI0vdbWywmiZSiERPVBmGtwz9/7T4/XrtRCVy5p5Zx3fLx2M2EE2k0XafUa0MS4OM1B3nkqy2oOkg2FzlTbqP6g/tZ8uEL5JxwbeZ4vglXI1esJ/rVU5gvexLRYsfWYQDOAScTWP4p1m6jsZT2BGDi+LF8snQAwcXv4ex3IqLF0YKEA3y3P9jyQaBaOXpcUBWoYM9cIy4oOO9+9HQS/+znqXnvXmzdRmHrMgzXgEkIkgnPiPNw9DmBwPw3jHV+/Sw8x12MPvBkPl97gK4FTk7rV4osGmNsXUdNYoqnPzuXz+K7aa/z1bP3kFNUxsApFzP01HPYHdLwOS2IooCKQJnPTv8yD3XRNOg6K/b5SSlpir1uehS68NhMdM53YTeLqBp0yLUTT6mUemXaNVQdq8NJFFWnPBBH03TsJpE1e/10LXTRIcdGNJVm9b56EoqKRRSwWkz0K/VhMgloOgRjKRTVaDGvCcYpDydQVI1YyujA7FrgzojCaTq4LBLV4STVoSSO9hI2k4zNLDFrYzn76uN0L3TTJd+JwyajaBr5bhu7q8N8uPoA47rmAQIbK4IEo2kQRdYfrMdqzsdukhAFo81+2nflVAbjqCpUhxP4Y0niiobVJBCIqdgtZjw2jRKXiN1qJa3qFHtseB0WSr0WTCaJ7/bVE00qeGwQjuskgURKYUUwQWTrQoLLPyFSvoPc/EKGXPALeo89C5/HRWU4gjmisaMumfne6JpGeMsi9n/7miHc1mkIOSdcm7Xma8kogYUNaucWBzmTf0HPUZMoj4vGWhkNEJj/BpHvvkF0NFbRJ/ygNvSm0HWNwLxXkdwFuAZOASCeSpMO1+Gf/SKWsj78o64794fjDeNoTwI6uafekYnFAaLrvyG+fSm+CdfgLOiA2SwjCgKariOKhhCfz2oiZZIo9IgZgTVd13FZTdSGU6RVQ2C1jYP/9Ggj4v9BNA9MDxfwNldJP5ZWUIBn525v8ZgsGgF54zyxktZ4aeFu0k2y8Lkn32K0qH/5GMVXPJZp5Xb0HEN8+3iCi6di6zwES0kPJIcP34nXU/flY4RWfIZnxDmA4fPt7Hci4dXTcA065QcRFe+YS9HTCoJsQo36iWyYQ/LARlKVO5BzyrAUdSHvzLsx53dE1zW8465ESyvEty8hsnYGsR3LyD/j7gwxs3UZSknHgUTWzSSw8B1DoK3XWEOgraE1LetaufLImXgdnlEXGIRv9TSqti/FXNzDEIbrcdxP6kPehmNDsnIH4ZWfEd28AHQNe4/jcI8494iZZy2VILzqc4JLP0RXEoZq6ZjLsnQD0pF6At++RnTDHCRnTgsxNiVQiX/OS0YV3VtE/jn3Yuv601jf6bqGf87LSM4c3EPPyjyu1O03ZsC6jcwQbj2dovbLR5HsbnIm35T1+sHFUw3Sfsavs6rkYFS7+5R4OGdw2WGTfu8s28fcBnHAw91fmt+T2pTP29CG/xw65NhRNY1wIk2fEk8miG4N+S4za/bVI0kihS5DGMvnMGMxiaiajt0sUx9Nsa8+1nBfMeICW+chOAedQnjFp9i6DMPWYQAAVrudkgvvZtkzd8Ccl8idcivQYHO2axW105+k+KqnGNwxj7QO3vE/o/L12wkueT/Tzp71XtwyKV2D1OGFQb9vXFBY9jR1H/+R+I5lxLcvIb19If2ueIDKmIjJmYN8yu24h5yOf+7L+Gc9T3jl5/hOu5pN6mh21iYwmWSuGFnGJ2srCcRSKJ3HMPqucdRsXMT2We8w59W/8e27/yR/2KmMPeMiLhg7gHynMZvtc8psrQzx0sLd1EUU8lwWVu2tRxaMuWhJAEkQcdqMVvUcu4zPYSaUSOMVRfwxhXynBUdaZvPBILvjKbx2C+v21+O25gM6HpuMy2zCH0tQG45zoN6C1yaT1lTmbatGEETGdMvDYhJJKjoOs4QoCkQSaeoiSdKaRiSZRhQF4kmFikACiwk2lofpXuSiOhCjPJDEJIlsrw6T57ZQ6LOiqjr76iIEEirtJIGNFRHqIzE2HIzgsEj47GYCsTQ2ScdpFklrRvdGfTSF2yqxvSrM1soomg5eh0w4qeKySLgsErkOE2ZZwGmWaZfrIKFotPNZqYumcFshllRIKSqxBAQUUBMRIuu+Jrz6C9RQDZbcUq6++4/kDTiBJfvCbK1XKNXidMix47OpdM0xsac+zobVa/ju29dIVWzDlN8xS7jNLoPPorJn2Wyq5r6BGgviHDAZ77grkGxuDsYbO+O+JLDoXaOKPuxMo4p+FG2kY0V0w1xSVTvJPe3OTCv9Vr9O/VfPoDdYkgqCyIpd1RTsnsm+vd+Rc/ItWeNoQz0xPp79Ipb2/fAOO5MXrxrG+v0BBrTzkes0Z5J/oqhjMmXfOwRBoEOuA1EQSKs6ZTm2/9NFqP9WtLGI/+NoTSX9WFo8V+31c7CZfZnLKiMAoUR2GrqRhDdCcvjIPflWaj7+A4EFb+Mbf1VmW85JN5LYv9Eg6Vc9hWi24eg9ntjWRQZh6DI0I4rlPf5yolsW4J/zEgXn3f+D3n/jzScdriO4+F0EyUTBBQ9l/M7BEMUIrfgUXUkiWex4x15JcNE7aGqKqnd+g637aGNexluEIMm4Bp+Go88JBJd9RHjlp8S2LsLZfxKe0RdmHTdzPewevMdfhnvEeUQ3zCK08nNqv/gb0tyXcQ44GeeAyS1ITxt+WuhphejWhURWTyNZvgXBbMM1aAquoWcecQZaTytEvvuK4OL3UKN+bF2H4x13VZZwm6YkCa/4lODSD9C1NO4R5xlt6A1t5loyRnDJ+4RWfoogyq1W0X8sohvnkqrYSu4pdyA2jGfomkrttCcQzTZyJ92cWQD9815Dqd1HwfkPItncmWMkD24huPg9HH0m4Oh1fIvXEDCs6A6HxrEXI0sucN9pfTL3l8YKuM9uZsaGihZdOXBsrg9taEMbflrIkqE8fSyY0LMAi2SoYHfIdaCk05z42DzKAwkGtnPz0c3H47QYVdChHbws3GEolnusEtqEq0nsXWcItF39DJLVSfcCJxvULg3jXB9i6zIMe/dRiBa7MUP+/r0EF7xFTe4NlHitlHToSrjviYRWfoaz/6SsBP2QMgdVoTjR+NEtMr9vXGAu7UWP4cezbeabhHauZts/f87gs69nj2swiiBgLuxCwYV/IrFrJaH5r7PurT/hKu3K6PNvIq/PCAKxFJGEgsUk4Y+lqFV1cnofR1HhAKRdmwit+JTKhR/w/sIPWD9iArlDT8XerhenDepIRTiJmtbw2s0oqo7NLKPpsLs6xO7qIAlVZ3D7HPJcFiwmEwgCSSXN9uokNRGj7dxpkcl3m/HHkuytDeGPpbGaZLrmOfDYTGgaiJKZHgVOfA4ZFZ0vV+9nbyBJocvCrE1VnD2glByHxN66OC6rjEUW2XwwiIbOQX+ctKbjj6ZAV2mf60FTdXZXBlm6p45ATKHYa8Uqi9hkiVyHlTMGljJncyUeh4kyr52qcApFE+mY6+BgMIGm6Vgk+HB1OTrQu8RF93wnJkFnZ3WMpKJjM0sE4wouk0gUkYuGlRJLaXjtJtrl2lmzN8iOqhCr9tbTKc/BuJ75JFMaNcEYu6qiVB3cRXj1dKKb5qIrSSzt+pJz4o14ug7jnIv68+qi3VglkVyHhQP+OPXhGBIagQPb2TzjDeJ71iC58sidcpsxB96kiuzfvYGds18kVbUTS2kv8s9/IMujO7FrJfVzXiZdfwBrp8HkTLwuayb7x0JLxgjMfx1zcXccvcdlHo+un0V85wp8E6/L/H5WLP2ONZ+8SPuBY+h1whnE0yr7gml65ZrY8s5fcVpNzJ/xIQHRjc8u449qRFNp8kVjdOFILg0mSaRzvrPVbY3XotFWTxR/XAfA/yp+NBEXBKEd8AZQiJE+fUHX9aea7TMe+AzY3fDQx7quP/RjX/t/Aa2JIR1Li2djcNwUqqYRa5ZpFltR0QawdxuBc8BkowWs82Ds7ftjkUXiVid5p91B1bu/xz/nJXJPvsWook++mfKXb6Zu2uOG1YkkIzl9eEZfRGDeq8R2rsDeIDLVGnRNJbj0A6Ib56Eno0jOHMyFXbB2Goyt02AsRV0pvuIJ/PPfpPbzv+EZdQG2zkNQ/OUEF76Dnk6Rf/bv0BIRaj7/G+4R5+LoM94g28s/pnznctxDzsA96gIkqxPRYsc39nJcg08luPg9Ig3z4K6BU3CPPK9ViyvRbMU1+DScg04hvmsV4VVfEFz0jtEh0G0Erv6TsXYalHUzb8OPg1K7n8h3XxPZMBstHkL2leCbeF2mtfFw0FWFyPrZBJe8Z2TI2/Ul76zfYm3iy63rGtFN3xKY/wZqqAZbt5H4JlyDyVdsbNdUIt99Q2DBW2ixAPZeY/FNuOawSRclUEnku69JHtxM2l8BgKWsj9G+eYREjZaM4p/3KubiHjj6Tsg8Hlz6Qaa63Vi5j+9eTXjV57gGn5ZluaIlY9R++RiSK4+ck25s9XVEUWDt/gBPzd6eEWQ6f2g7zmkQTWua9BPQ2VAe5Nm5O/DZzRldiqY5O1EwlM9b8x5vI+P/22iLC/5vwm01M7hjLqqqI8si5z+/lP0NCftV+0PcOXUNj100iBN7FrB4Ry2dcq1UBhJ4bCaSipW80+6i8q27qP/qWa763WNceVwnLnhhGd4xl5LYvcYQaSvpgezMwdZpEM5BpxBa8Sm5J55E564jMcsSteOuILZtEfWzX+CkXzwMgkCnAicH6qPUhTXiOliBFBA/uJnAgjdRavcjmCyY8tpjKetLXtfhKLll3ysu2PL53yg76Wf069mVma88xlf/vB9LaS98E67OjJvZugyjS/+hqDsWs37aa3z15C/J69yXDrf9isHt+zFrSw2xpIKmKoTjEEqBpaw3+WW9SQeriK6exrZ136AunY0pp4z1gydz2RWXURNRSKqGjVi+y4zZJLAnkKJbkZNEOAkCmAVIJtOYRBOKaliVtc9x4LPKHAzE0dMalYEYyXQaj82MVdARgISiUeSx0LvYgc1iwmqS2OePciCYRBAEqoNxdtdG8Npk7CaRPJcFqyyyqyZMJKWQTmkEUyojO+VgkQT21isE4mki8QSSLlPssRFNKGyrjDK4vZc8lxW7WaJdjoOzBpaweEcte+qjOK0m0mqacCJFPKngzbMTS6WxmCSsssjCbdWUeWyYZYFwMk1NKIZZMvQI1pWHKXJZWLE3wAk9C8l1mRF1gVgqTa8SD9XhOMm0jlWWiUaC7FgyjW1fvE/swBYE2Yy91zjcQ07HXNgZAItZoCqUwCxCIJ4kpalUx1RSVTsJLHyH+I5liDY3vglX4xx0apZDilJ/EP+3rxHftgTJlU/e6Xdh7zUuQ1RTNXvwz3mZxJ41SN5i8s+9D3vX4a3+3nRdJ7JuJom935Gq3o2WjCC783GPOA9Hj+OO+FsNLp6KGqkn/+zfZ1rc08Fq6me/0GBJejoA9nSKvdMeR7Y6Of2G3xEQLXTMcRBOpKlaOJWNa1fy8quvUVBcyu6dtYTjaexmEYdJJBpPoegCggBOi5yZ/f6h1onNbRTbcGwQ9MMoJR/zAQShGCjWdX21IAguYBVwlq7rm5rsMx64S9f1o5vZNsHQoUP1lStX/qjz+/8Bzecxj/U5F7+4NCPS1lyIqRGH86QGw4e54rXb0NMpvlm4nFs/3ka4oZrun/caoWUfkn/27wzBEiC2dTE1n/4Zz+iL8B5/GWAQoopXb0VPpyi+5tks/8WmCK34DP+cF7F2HITsKSAdrCZZsQ09GUWQzdi6jsDZfxLWjgOIb19KsnxrprUtvGY6wcVT8Yy6EDUWRKnbT/6Zv84cOx2uJTD/LaIbZiNanXhGXYBr8KlZFU0lUElw8VSiG+YgSDLOAZNxDz8X2d2yQt4Uir+cyNqZBlGMBZGcuTj6TsDZZ+J/jbjb/zWoiQixLQuIrJ9FqnwriBL2riNwDpyCteOAI85d6ekUkfWzCC79EDVUjbm4B97jL8XacVDWAhHfvYbAt6+RqtqJubALvhOuwdresPXRdZ3E7tX4576CUrsXwWRF9hZh8pXgO+FaZE9Bq6+drNhG5Zt3YS7qYmTGdZ3YtiUNXrZ/O+w513/zHOHV0yi68olMxj1ZuYPKN+/E3uM48s+427gusSAVr/wC0eqi6MonsoKH2mlPEt04h8KL/3zEGfnmEACLySDPAJe+tDTjrqBjzA6KgoCq6S0E3sZ0y+P2E7uzdFcdj329FU03BOF+OalHq17i/0sQBGGVrutDj77n/5/4V8UFbTHBj4em6cRSaXRdZ/gfvyHeRBmtU66dub+awKbyIC98u5Mij5VN5UF6FLvokufio1UHmPv+C9R9+wavvPIqE08/h7GPLgCMMZqK127HUtabggseRBBEI4Z49VYkNCb+/jVG9yhlb12MRZ+9yaZPn+W4ax/kQO6QVs/Tmgix7Z/XIFqd2DoOQkvFSVXvJl1/AABTQWec/Sbi6HMCks1FbNviY44LcswqdWtnc3DOm6hRP/buo/GOuzJTYbQBFpNCzsFlrJ32Ov7qCtr16E+y3zlYOw9BEATKXDKV4TTNx9xFJUlk60JCa2aQLN+CKMnk9xpJ6YjJdBp4HCf2MfzGv15/kJii4bKauHpMR0q9TiRBZ2dtGEmQEESJaFIhnFDZXx9FUTXqInHcVjOVoQTBeBKHWUaWRLwOM+1z7NQE4shmiSEdfeysirOtKsxBfxRZEuhb4qE8EKd3iZsir419dTFynWbUhkp4p3wnkbhKjsOEIArsrA3jD6eIpzXSaRWv3WLYqrnslHgsWEwSiXSaDfsDRFMqoUSardUhFEXDIouUeO3IMuyuiSNJIoqiMLlvMav21rN8dwBVTVMXVeica2d/MEGJz45FFOlV7OLsoWVEYio7ayMEYgq7a8LUbV/LxoXTWfj1FyTjMcy57XAMmIyj78QsTRUJGNnFR6HLTJ7DxJaKCCtXLmfP3PeI71iGYHHgHnYW7qFnIloOKYCrUT+BRVOJrJuJIJvxjDgP17CzMutsOlJPcOHbRNZ9DYKA5M7DXNCFnInXHTYmAKh44w7USABzcVckm5tk+VaUmr1Z631zpGr2UPHqrTj7nZgZ99A1laqpvydVtZOSq5/JWAzH5r5AzfLPuflPz3H6aaewvz7GwWACoXo7D/38Iiafdhb/fOkVdlRHjd9FQmFklzw8VjORlILNbIgAioKhrA+08LAHQzMmrRqxgUkSs8h349/WntcGA0eKCX40EW/lxT4DntF1/Zsmj42njYj/27Fqr5/nv93J7M1VGbLd9NMu81qpjaZIKtphbbGSFdupfOsu7N1Gknfmbw79+FSFyrd+RTpQRfHP/p4hrLXTHie6cR5Flz6SsRJL7FtP1bu/xT3ivKw296aomnoPWiJM8VWHiia6ppI8uJno5gXENs9HS4SRvcW4Bp+Ks/9JWRXRdKiW2s8eJlWzm9xTf4mjx3EZdfVGpKp24Z/3qpHJdOXjHXMxjr4Ts/ZR/BUEl7xHdONcQMDZ9wTcI8496oy7rirEdiw32oZ2rQJdw1zYBXuvcTh6jUF2H/5G3QZjhju+cznRzfOJ71wJWhpTXnscfU/E2XcCkuPICSgtGSO8dibhlZ+iRuoxl/TAO/riTNDUiGT5VgLz3yCxdx2SuwDv2Mtx9B6XIfepqp34571GYs8aZG8RlvYDkF25eMdcgn/uK+i6hr3rCKzt+7XI/uq6jp6KZX0vg0veJzD/DdrdNjVLfbXp+VS+eReuwadmKtmakqDitdvRU3GjBdTmQtd1aj56iPieNRRf8USWUmt08wJqP38E96gL8Y29HLtZIpY6dkX7puR51V4/H60+wHsr9mc84gUMT/HmZPzPZ/fjkhHtMzPjjV07bRXxNiLeHD9VXNAWE/x46LpOPKUiinDPx+v5cE15ZttfzurDxSM7EoqleOzrLdRGklhMMrdO7A7oLNtVB7rOX2+9hC0bv+OBlz/j2TWHRuDCa2dQ/9WzeMdfndGLSR7cTOXbv6bfuFO4+jePUOyzs25vHa/cfSm1tTWUXPtcFiFqhHnvErZP/RNFl/0tE0sApEM1xLYtJrpxHqnK7Q0V0bENFdEuWcc4XFxgEiXcFgiF49Qu/4TQik/QlSTO/ifhGX1xJp4pdZm4YEgx62Z9zJsvPks6VIO5sAvuEedh7zGajrl29vqThy10pGr2ktwwi+DGuajRAJLVycCxk+l//MnEcrpR4rGyvSaC1y5R4nYQiKUQJBmHRaTY7cBhNZFWVWoiKbrk29lfF2VbdRCHxURtME4wqdMp305C0fCHk7gdMl6bCVkQ6JLvpCaqUBdNUB5IYJclqqIJehe66FGSQyqZJJ7WMYkiXfJtlOQ6jRlyi8iaA0FqgjFiCZVwQsUki/Qp9aCo4I8n6ZbnoH2+E0GHbRVBoorKjpoYOjo2k4iiga7p9Cx0EUqoBFJpjuvoIZrS+HRtOVuqwmi6TjSuoGqQSIPFDHZZZny3HMwWM4IOWu1OVs2dwfI50wjUVGC2Oegw5ARc/SciFHanNtaynbPQIdK9yEOPfDt7Nyzn/VefI7RrDaLViWvombiHnJ61FmuJCKHlnxBa+Rm6qhhz4MddnIk5tFSc0PKPCS3/BF1NYynrjaW4O77xVx01JgCj400w2zOPa8ko+5+8EO/xl+MZfWErv0+NqrfuRvGXU3Ldc5nxs+CyjwjMe5XcU27H2WBjFt+5guoPH+T0i6/mkb8+wvury0kqOtFwiHd+dylmCT76egG/+2I7e+uTnD2oiKEdchjdOYdCr514WkNAQAfMooCtwW1BbxBxO3ROOglFbRB4M7SmTE20KBp5ZBsJPzz+bURcEISOwHygr67roSaPjwc+Ag4A5RiL78bDHON64HqA9u3bD9m7d+9Pdn7/i3h27o5MtUoAXFaJSFKlS56Db+4cnyHrX2+qavHcxsUluPRDAt++Rs7Jt+AaMDmzXak/SMVrt2Eu6mr4josSWjJG+atGu3rxVU9nFtja6U8S3TCH4iufaLFYAlRN/R2akqT48sdafR96WiG2bRHh1dNJHtyEYLbh7D8J15AzkN15CKJE9YcPIshmzEVd8TRRU22O+J61BOa/TqpiO3JOKd7jLsHec0wWIU8Hqwgu+5jo+m/Q0wq2biOOWTFdjfiJbp5PdJMRKACYi7tj7zEae7dRP4nC9v8PUBMR4jtXEN+2hPiuVejpJJIzB3vP43H0mYC5sMtRr7WhnP4l4bUz0JNRLO37G8rpHQZkPTdVbaiMx7cvRbS58Yy6ENegUzKzhkqgkuCCt4humodgtuMZeT7uYWdRO+1xTHnt8R53MelQLbEt81EjfnwnXHNM7zG0/GP8c1+h7LapSM2IuK4qVLx+B1o8lBWM1n39TyJrplFw4R+xdRxoHGfVF/hnPY9v4vW4h55x6P0Hqyl/9RZMuWUUXfLIUcUDWwsYzZLAu9ePypDnZ+fu4NGvtmb2k0WBh87sy4wNFSzcXouO0UlzZ5PK9w/p2vn/GW1E/BB+bFzQFhP8tNB1nUhCIa3pKKrKhyv2MWdbHdcc15Ep/Y21KZo0WowP+OPkOa10zHOgpFViSZWlu2tZsm4rj9xwJq7CdtjO+VNG0FXXdWo//TOxHSsMAt3gStJ17xfMnvo8Nz7wFGWDJiCKEDmwhT/fdBHOAZPJnXxzi/Msq1/DohfvpejKJw9fOazeTXjNdKIb56IrCSzt++Eefg7WToMR4IhxgQnIc5vwWCU27q4kuOQ9wmtmgCDgGnQKnhHnITl9XD6shO3lQZbsCxDdOJfgso9I1x9E9hZRPOps7H0mkpBa7/RrvN/qmkpi92oim+aR3LEMNZXA5s6h58gJJIoH0bX/MNK6QCqt0znfharpqJrO2cPak1RUZm2uJJJQiacULJKE1SKBrrG7LoooCNhNJgLRGCbZhM0M4ajCkI4+9tXH8ccV6mIqqOCxCfTr4MNrNgh7LK1hlqDMayffY0PRRJKKQk04QTCucLA+gtliwmOWOVgfZ38ojlOWsMgCeW4HPQpdCIJGTUzBJolEUmniijH2kFI00rqOyyQxqW8RiqqRUjQ2VYb4dE05ddEkUcWoYDemjV2oqDXbqN60hMjWJST9FQiiRKcBo+g4YhIdB4/DbLGy9kAQp1lie228xTW36yo9EmtZNeM99m/fhOjw4h56Fq5Bp2QlfLRU3BDhXf4xaiKCvefxhnBvQ3ymqwqRdV/hn/8WejLSMLp2NYH5b/6omEBPp9j32Dl4jr8M7+iLWmxvXOtzT/0lzr4n4JKgrnwnFW/cib3rcPLO+i2CIJCO1FPx6i1IDh+PvPY5M7bUsr06Rp8iGxvffYQ9K2bR6cpHkIt6kWqy6J8/qJg7J3enwG1HRyCpqIgCmOVDxLt5ZbuRiEuiIerYnIi34eg4Ukzwk4m1CYLgxFhUb2+62DZgNdBB1/WIIAinAJ8CLX2jAF3XXwBeACP7/VOd3/8qRnbOpVH8VAciSZU/nmVUssBQah/QzntYb2EA94hzSOxZg3/WC1hLe2dark05peScdCN10580LM2OuxjRYifvtDupeuc31H/zT/JOuxMA34RriO9aSe30pyi+4vEWhMGU14HId19n1FCbQ5BNOHqPx9F7vKGaveJTwqu/JLzqCxy9xmLtMgxNSVDURBSuaXYysX8DCCLWst7YOg7E2mEA8e1LCCx4m9ov/oZp8Xt4Rl+YIeSyp5DcST/He9xFhFZ9SWTN9AbF9G64hpyBo+eYTPDRHJLTh3vYmbiHnYniryC2ZQGxbYsJzHuNwLzXkHNKsXUeiq3zUKzt+vykol//l6HrOkrtXuK7VhHfuYLkgU2ga0jOHBz9TsTR8zgsZX2OacY+Wb6V0KrPiW1ZCLqOvfsoQzm9uHvWfqnqXQQXTSW2bTGCxYFnzKVZLWnpSD2hJe8RXjMT0BAdPqzt+pA8sBFh1PnYu48mvmsFuqYiu/MwF3UlumUhif0bjqkFPFm5E8nha3WePbj0Q5SaPeSfc2/mfGI7lhFZMw3XsLMyJDxVtQv/3JcN652GuTBoEHP74m+g6+Sd/qujknCTJDC+RwFztlRnqt0A5w9tl0WeR3bONSzOFEPE5aEz+3LJiPb0KHKxYk99q44NbdZkbWgNP0Vc0BYT/LQQBIFEOs2qvXXsq01Q6LHz6Pll+BwWFFXDJBl2Rm6rmd7F5qx7harrBGMKBcWlXH7Xn3jxwVtxL3gr0w4uCAI5J99K8tVbqP3ir1zyx9e59Phe7Pd3Zdd3S3njUYNYC64ChnUqwTXkdMIrP8PR6/jMaFAjauV8AMNG8jBE3FzQidzJN+MbdyXhdV8RXvUFNR8+iCmvvTFallN62LggpSpsmzeNARNOJceXg3TiDbiHnUVg0buEV31BZO1MnIOm8LVwLlW6D0Ey4ew/CUfficS3LyO47CP2z/gn4pw38Aw8GfugU1u0KDdeOUGUDLX5LsMoMqcYbjnItzO/YP2300knP2Kr2Yaz0yByewylqstg7L5CfA4LWytDiEBVMEF5MEZC0SlyyFhThqq4x2IilVYJxxPIgkxdSAHBqFhurY4iipBIacg6xIBAQmfbQT95LjsFTgkdGbtFIJRQ6aqDySSzuypseJan0sRTadp5bUQSaeKKArpGTVQhrmikNJ142vBLV9IaTotIMq2Ra7cSSacxyQIWQaQyHGfZLj/5bjOdc2wcqE+QVFRcFpmEkkZPhIntXkNs1yr271qJFguCKGPrMIC8EefRfdjxeLy5pDTwp3RcpClwmBEFFbfZmNMHUOMhIuu+5sDqaWwO1+Arbk/x5F9g6ntCVoylJWOE10wntPxjtHgId7dh5I29DDXPKBDpmkp083wCc19BjfoRTDasnQaDmsbkK/kJYoIdxnc3t32LbelgFYFvX8faaTCOPhOwStDerbHlub8ZLikn/6LBbk6j7svH0FMJii69m/fXVlARUgBYOutL6pZ/jWfMpWiF2SQcYPqGCu45rW+m4m23HJ0GCoKALBpe8qJgCES24afDT1IRFwTBBHwJfKXr+uPHsP8eYKiu67VH2q+tDe3HozUvcVGAi4a359wGkaamdkSNa64kgNrkq5EO1xnZN2eOQaQbbmy6rlP35WNEN8/Pmk8NLHib4OJ3yT3tTpx9DAGq2LbF1HzyZzxjLsV73MVZ59TYYpN/zr3Yu404pveWDtUQWvEpkXVfoSsJw0d69EVYiru3ICVVU39PYu86rJ2H4B1zWSZTr+sasS0LCS6ailK3DzmnDM/I84125SbH0FIJohtmE1r1Oen6g4h2L84Bk3ENmHzE+aDs860mtn0Z8R3LjcSAqiDIFixlvbF26I+1XV/MRV0PS/D/26DrOulABcn9G0jsW09i7zrUiKHCa8rvaCjsdh2BuaT7MfltakqC2OYFhNdMN1oSM10Rp7dQTk+WbyW45H1jJsxsxz30DFzDzspUpdVYkNCyjwivnoauKlhKe2LrNgrP8LMB2P/0JZRc9xxquJbwmunYOg/F3m0k6VAN4ZWfY+08JEOUD3++SQ48czn27qPJO/X2rG2p6l1UvH4H9h5jyD/jV4BR3a949VYkVx7Flz+GIJvQkjEq3rjDaFP/2d8zPuIA/vlvElryHnmn/ypLVfVwEAXoV+rBLIus3hdA13TMJsMScUN5kLmbq6gOJ3HZTIzvnk9dNMWUvsWZpB20Vb6PFW0V8X9NXNAWE/x4aJrOun31bKgIkFJ1dE1H1aFnkQenVaZfqRdBgHjDeIum61SHk4jo5LvMrN4bZEdthAp/jM//8RCrvvqQ8bc8RlGfoSzZHQYgcWAjVe/8lrPOOZe33nobSRT4fP5qLjp1POaCThRe/Bejg05JUPHKLaBrFF/9TMapAoz14+Dz12LylVB44R+O6b3papro5vmEln+MUrMHyZWPa+gZuAZOQZDkrDU9tm0JNZ/8CdmZQ86oC7H2n3SoQ6r+oKEZs+nbJpoxZ7cYMUsc2ER45WfEti0BwNZ1OK6BUwzR1sOsaR18MiUeOz67mVQyybJF86jesITgjpUkg8ZX31XUgaKeQ+k/dAQDho3i610xQjGFZBqsMrTLseC2mokkFeKKSiylYpMlamMKogpaGjCDRQRBhLgKKQWsErhtIiVeKw6zTDiRxmuVSWvgtEqIskw4pVJkl6kMpUgqKXxOK5IImw6GqQoqxNPGcUp9MmlBxmUWcVlk6iIJHDYTPYo82GSRmrCCySSypTKCxSRR6LFQ4LJQ5w/xzZxvqd6+hsjudcTKdwA6otWJtdMQ7F2HY+syFMniQMeomMsCuC0gmUxGdkPQ6ZlnZVtlmN07txFeO4PY5vno6RSW9v3JHXom/UcMZ3P9oaqumogYFfCVn6Mlwlg7DcE75hJsJT3QaIgFty4muPAdlLp9SK487N1H45t4HYIg/CQxAUD97BcJr55Gu1veymqR13WN6vfuIVmxnZKrn83ElXUzniby3TcUXPiHzPEbR95yTr6FokGTsVpl6mLpVrtUm+PEHjm8cOWIDBFvygEzVq1plfpIgkteXEpFKMmUvsU8esHANnX0H4F/aWu6YHxyrwP1uq7ffph9ioAqXdd1QRCGAx9iZMKP+OJti+6Px5nPLGTdgWCr2+Qm1a5Ve/18vPoA763cj6rqCM1E3IpcFvauW0TlBw/gHHQquZN+ntmmJWNUvH47upLgpN+9RrvSYr7aUE7Vu78lVb2b4queyvh013z+N2JbF1J0+WNZWW5dTXPgn1dhKer2va3O1Hg4Ux3X4iEsZb3xjLwgaz5YUxKEV39JaOlHaIkwtq7D8Rx3SRM7ioab8OKpxgL+/9g77/g26vuNv29oD1u2JW9n771IAiRsCHvvDaV0UUZpmWGEvaG0pZS2lE0gCTPsDSF775048ZLloT3v7vfHV1bs2AlhlfZXP6+XXxDpdLo7Sff9rOd53F7c407GOfzInJ1U23aJrUsJL3lbcJnJeqwOPxJb33H7nETrqQSJHStJbF1CYvty0oFqACTVgrm0H5bygVhK+2Mu6YfiKvqv4N7oyRiphk0k6zaSql1PsmYtWrQFANmej7VqGLZeo7D2HP2NInhtMAyDtH+LUE7PKuqbCitxjT4Wx5BDO4yatYmshRbMJLF9heCEjTkB19gTOibgC18jtPAN0NLY+o4XAmz5xbnAqXXOS6T9W8mfdD6mokrCS98hWbOWwmOvRpJk/DNuxzn8SOz9J+5VJTSy8mOa3nkE31l35Xx3QdAs6p69Gi3WStmlf0GxuYUQy/SbSdVtoPTCMeef6QABAABJREFUxzAVVnQscp11F9aqYbl9xLctwz99Ko5hh1N0zJV7vH6KBEidnREkYFxPDyeNquC2N1eR0jrfitsLuXUn3d8O/+uJ+I8VF3THBN8fum6woT7E8h2thBNpkuk0kmLiwH5FRBIZBpW68TjMOVui9Q0RVte08tSXW1FlmVuOG4hmwO9nrCAajrLpn1diSkX4zWMzeG5VJPc+ibkv0/DF8/zpz3/hgosvpTGUYr+f3UbT2w+Rt//Z5E86l+NHeHn17U9pePEGBk4+ngt/dzt//bomt4/WL58n+PV0yi5/aq82lbujzV4qOO9VkjvXINvcIiEffVwHilCieiWtXz5HcucaFJeXvImn4xx2xK6EvKWW4NxXspox4Bh8CO79Tsbs7dHh/TIhP+Gl7xJZ8QF6LIiaV5zrnrdf6xQgzwI9C+3EMjCwxE5dS4JoVuDMv30D1asWENmyjNatK9HTSQDMhRVYSwdiLuuPs7w/pT17YbPaiKZ1VD2D2WwiHk+TSAOS4FzHAZcKhQ6QJZWklkGVVWwmhWK3FbtZwmxW0dJptgei1AYzmBQocqqoskRaIzsloeFxWtnWFCWWSJExwDDA0EE1gYxMKm2Q0gw8DpWKfDNl+U6KC+xEo3G+WLSaRN0m6jaton7TKiJ1WzAMHUlWsZf3J7/vKNx9RoK3PwldEaP8COpAut01dqqgSxI2kwzJMIFln1G/6D1S/i1IJiuOIQfjGn0cZm/P3T6bAOHFbwr6WiqOrc840bApG5D9rujEN8yj+ZO/o4X8qPkl5B90Ebb+E5BlUbj5oWICPZ2k5omLsVYOxXvyjR2eaxtJLzjqN7hGTgEguvYLAm/ej3vCaXgOukh8Z3esouGlG7EPPJCi439PuUvl4IHFvDBvm9BtCvqpuvhxjC5iLJ/TxBfXHcaqmiCRZIZBJS4KHLtEX1VFcMBboknOfepr1vt3jf7fetxALjqgd+7f/w0x6X8SfuxE/EDgS2Al0Bbu3QhUARiG8VdJkn4D/BLIIO4P1xiG8fU37bt70f1+WLy9hbP+Npd0F0F2G1RZYvrlgh/ank++J7R8+k9CC2ZRdOL1OAYemHs81bCFuud+h7VqOGVn3opmyGRCfuqe/i1qXrGwNFNNaPEwdf/8NbLF2Un5ufWrFwjOeYnSS/7caaHbF+ipBJEVHxBa8BpauBGTrzd540/twP/WkzFCi98kvPB19ERE3JQnnpETgzEMg/jmhYTmzSBZswbZ6sQ56hhco47tZEGVCfqztmcfokWakW1uHIMPwjH0sH3iN7eHFm0lsWMVyZ2rSdasI+XfArroSMg2N2ZfL0zenpgKKzEVVmAqqEC25/0kN0M9lSDTUku6eSfpwA7Sge2kGrfmLLsAIXRWNlAIm1QOxVRY+a2ONRNpJrb2CyKrPibt3wqKCfuA/XGNPFqMr7fnL2XSoguy8DVRRHEW4Bp7Eq6RU3KJuhZpIbTwNcJL38FIJ5CtLuwDDkCLB7H1HIVr1DFZHYKvCS9/D8fASYSXvUvBocIbtOn9x1HzijEVVRFd+TGew3/eaQy+PQxDF50eDEov+XOH4805DrSzPWmd8xLBr17oIMQSXv4+ze893mmCRIu0UPuvK1CsLkoueKRDoWh3/GJyb/76xZY9Pn/E4GI+2o2W0h7dKujfDd2J+I8TF3THBD8MIok06+qDbG2MYjOrKJJEgdOCphuMrMrHbt4l2rSmNsQ1ryzLcp0NHFaV/l4HH67z47aaaNi+ia3/uJLy/sMZfNFdNIaThONw47F9eWrqL1kwfy5vffgZ48aM5tS/fMW8f91FdNUn9Dz/Hl6YejGt8RT/fHAaM597Ct9pt2Hrs+tnkwkHqPnrz3CNPIqCI365h7PZOxI7VxOa+yrxLYuQzDZcI4/GNe6knEWpYRgkti2l9asXSNWuR3EV4R5/qijCtylmh/yEFrwm6HPpJLbeY3GNO6mTFsmuNeR9ktUrBB2u50icQw/F1ncCstmKCTAroMhQkmfGZTUh6RoWswWnVWF1bYRUJoWWSVMQr2XzqsW0bF1FeMc6tJhgd0iKiqukB67S3hh55ZgLKnH4KnB7fdgtJra2Qp4VJB3SOgwodSAD4UQSp9WMxSThsVsIxzOE4gkaIylCYYhnKYwWhHiaywpuqxWXVSaa0kkaGVJpmQKXQiShUdeSIpyCjK4hhwNIwRrk0A4skXqCtVupr95EJinE/GSzDVfFAIr6DkMpGYC391B01UwophPX2vHETSIhzLfKtIQ1WrNPGFqG1LalxNd8Smj9XAwtjcnXG9eIo3AMOaST4F+qYQuhha8RXfuFoK8NnETehNNyQqeGrhFbP4fg3FdIN24DxYSt12iQJGy9Rv/gMQEIJf/mD/4ipkfbUTHSTTup+9eVWCqH4jv9NiRJIt1SR92/rsRUVEnJOffRu9BKoLmVdX/9NahmSi98LHfOJW4T8tx/Mfftl/jV3X/l9l+dxx8/28yCLc3YzCp5NhXDgPMm9CSaymAATrNKLKVx+KBirGYFXTdQFTF+HoylOOzBTwgmd0UGJw4v4dGzR3ero39H/FtV039IdC+63w83vraSl+ZX54Jsr9PMqCoPH69tyI2dS8C1R+1STD737/P2qqJuaGnqX7gerWUnxRc8mvNchl03mfZqkLGN82icdSeu0cflVKHjWxbjf/VWXGOOp+Dwy3Ov1+Ihav56KbZeY/CedP13Pm9DSxNd/VlWVGUnan4p7vGn4Bx6WG6kXk/GCC9+i9DC19ETYSHyNeF0rD1H7uKV71xLaOEs4hvmCQutgQfgHn085rIBHRffNiGWlR8R2zQftAxqQQWOQZNwDJz8nWzMjEyKVMMWUg2bxH/9W0kHqjEyydw2ktmGmleM6vaiuL0ozgIUhwfFnodscyFbnMgWB5LZKgIKWe3y5mnoGkY6iZFOoqdi6IkIeiKCFguiRVvRos1o4SYyIT+ZYAN6tHXXiyUZNb8Es7cnpuLeWIr7YC7t32GEel+hxcPENswltvYLEtUrhPJ8ST+cww7DPvjgTmJnmXATkeXvEV72Lnq0FVNRD9z7nZylFewSYWsLoNAyYoRbknAMOQRbr9Ekdqwi8NZDlF36Z2SLHT2dzAVfocVvkdi6BN9pt5IJNxHfNJ/E9hW4xp6ItWJQp+Nvj+i6rwi8cW8Haga0OQjciHP4ETlbksT2FTRMvxnH4IMoPPYawV/0b6H+uWuxlGetgLKFJEPX8L9yC8matZRc8HCn6v/uKHFbqA8l9/j8iIo81taFujviPzD+1xPxHwvdMcH3R1cKx8FYmuZYgny7mXybebfnkpz+13loOphVYWs4vNzFu6sasJsU0prB/voK/jzt9/Q/6iJKDjmLAquJ3xzejyffX8GsW8/HZHXw4psfMW5ABS9/vYF7Lj8FKR2lxyWPItvyOH+Ul2svPQM9HhJOEe3Wj6Z3/0hk9SeU//wpVLf3O593yr+F4LwZQldElnEOPRz3fifvEufKJuTBOS+LIrwjn7yxJ+EcdXRO40OLhwgvmU14yWz0WCsmb09co4/DMfjgTgXRdEsd0ZUfEVn9CVqoEclkxd5vAvZBk7D1HJ3rupuB3oUKTruNhlCKaDJDS0IX919F/OXbVALhNOlgA47INhq3riNav5VYwzaSoaZdbyrJmN1FyHnFqK4iFFchdkcBJSVFOPLyKSzyUOHzUh+XyXc5cNst7AylqG2O0ZKEZEYk4rJhoGdSkEnS25nCYTaor28kHGwmEwmSjrUSbgoQa2kkFfKTCfqz8/ACZruboso+FPToR1FlX1J5VfTsO5DGcAqn3czmxghGBtKaqMCBmBYAKM830cfn5OC++byxuJqvFy4juu4rYuu+RI+HkG1u8odMpmT04UQ9HQvEhq4R2ziP8OK3SO5YhWSy4hx+BO5xJ+XsvoxMmsjqTwnNn0GmpRa1oAI1z4drzAnY+4z90WICI5Oi5m8/R3UVUXzeA107EF3yJ1RXIUYmTf0LvyfTUidcifJ8Yu1/9TaSO1bR+9JHyHh65vbdFm/89sqruOe++1EVmbrWOGtqg7y/pp7mSIpoWvjYRxIZjhkm1NP9kSRHDCrGahJXX1UkDENYnD360Tqe/nqH+FoB7151IAOKhYJ7dxL+7dGdiP8PYncfcRCqiC9dNoEPV9d36JK12RC1vW7aW6v3OM4OQlCi7l9XZjvdD3TgiwfeepDYui8FnyU7jtv88VOEF73RoYve/NGThBe/1akCLkbRXhaj69nRoe+KtpGj4PxXSdVtFOqZY07AOeqYXFKnp+JElr1LaGHW9qq4D+79TunQRU+31AnxlpUfYqTimIv74Bx1DI5Bkztw2qDNA/sroms/J1m9CjAwFVZh7z8RW/+J37pTvvv5ZIJ+Mk07SbfUkmmtIxNsIBNqRAsF0BPhb96JrCLluEE66LqYMdsLJNWC4ipAdXtR80pQ80tQ80sxFZajeso7TDV8W7QtZrENc0XyrWuo+aU4Bk3GMeRg4cfd4RoYJKpXEFn6LrGNc0HXsfUZi2vMCR2KKMn6TYQWzCK29kuQFRxDDiFv4uko9jyaP3oS18ijMZf2R5IVGl+7G3NxH/L2P1OMzGVH1NPNNbR88neKTrz+W52joaWp/cevkBQTpRc/nvseaYkIdf+8AklVhaOA2ZbrbssWJ6UXPoJstqEno1mqR5LSi/6I4sjP7bv1qxcJznmRgim/xTXiyG88Fnk3isnuuPvkYQwocTFzyc4OHPGzxlbispm6ueDfEd2J+I+D7pjg+2H3eK/tfqlluStCCMpA2U2Mac6mAPe/tw7dMDhmSAl9i508PWcrdcEEB/Ur4qRR5dxw5eW8/+Ysrnjgnxx1+KFsC0R45OOtxKtXUvPiTXiHTeIvf3uaykInG9av5sKTjsJSPgjfGdOQZIWUfyt1z16NrddovKdMzR1bJuin5qmf4xg0maJjr/ne1yDdUkdowUwiKz8GLYN9wP4dhD4NwyC5YyXBua+S2LYUyWzHNepoXGOOR3WJcV8jkyK65nNCi98k7d+KbHHgGXE4RWOmoLsrSXW45jrJHauJrvmM2Pqv0RNhJLMdW99x2PtNxNZrNLLFjl2ClAEmSXSm22BCiK/FdaGubFOh0qPisFppiqVpaGwi2biTZHMNmeZaiDaSaKkn0txIJtKcm6zbEyRZQZJlDGTAwNC1b3wNkozqyEd2FaK6i1Hzi1HzS3EWlOPwVeAoKMAkg89lpzRfZUsggVWRaImlKbSaaIgniSUNWlMdd+tSoI/XQmbnShpWzWH13E+Jh5qRVAu2vvvhGHwwtt6jURUTFgVi2cPMhANEVnxIZNl7aJEmFLcP1+jjyBtxJHl5ToJJYVMWXv6+aL5EW8TE5MQzsPYYTssnf/9RYwKA4PwZtH72r05UtdyE3Ek3Yh+wPwDNHz5BeMlsvKfcjL3fBGDX1FzBUb/BO3oKiWzYluucF1Zyzz9e5aSxPagscFIfjNMQTOC0qqzY2crLC3bgy7PRGE7gc1k4akgpg8vc9PPtanC09wbXdIPP1jXw1eYmLpjQg96+XV7t3fj26E7E/wfR1Zh525gpkHtudxsiyCbxf5vbZZesDaktC6h7dRrOkUd3sB7RU3FhyZSIUHrRo6iuolwXPd1UTemFj2IqKMfIpKh79hq0aIuo+GXHxPRkjNqnLkdx+yg5/4F9EvH6JhiGQbJ6JcH5M0hsXSKqpCOOwj32hN2qpJ8QWvAameadgiM+5ngxntaWtCdjRFd/SnjpO6QD25HMNhyDD8I5/EjMJf06JdiZSDOx9XOIrZ/TQSHc1mcctt5jsPYY0aWa9nc+z0wKLdaKFguhx0Ois52MYaQTouOtpTH0jEi+Qai4yIoQsVEtyGYrktmGbHEgW10odjeKI7+DB+b3PkZdI1m7gcTWxcS3LCKVVRBVPaXY+x+AfeCBXRYrMpFmoqs+IbLiAzIttchWJ45hh+MadUxOf8DQNeKbFxFa9DrJ6pUgyVjKBlBwzFWY21nGNb3/JxR7PvmTzgMgWbeRxtfuouJX/8peI11w0ld8gGvUsTm+1r6ize+zfZFJ2PrcQ2zT/KytT/9dvPDaDZRc8BBmb0/hF/763cQ3zqf4nHuwVgzJ7TfHCx96CIXHXN3hGu3Jx3Z3KBIc0LeI6uYYU4aUcP0xe6/id+O7oTsR/3HQHRN8P+i63vG+sY+JOEA8lSEQSbK8ugWnzYSm6RQ5LeRbVWavbmBrbSP/uv48MrEwz771IQ0ZJw98sBF0ncavZ9Dy+b849rI/8Ktf/hqPTeL4ax6k6d0/kjfxTPInnw9AaNEbtHz8VCebxpbPniY0f+YPUqBvgxZpIbT4DcJLs9aXlUNx73cKtj5jkSQZGYjXbyI0fyax9XNAknEMmoRr7InttGUMkjtXE1kym+iGuaBncPccgnnwkdgGHJDrkstk+Rlahvi2ZUTXzyG+aT56PASKirVyWNZJZTRqQUWHz0ikx6JzTnY/eSokJfGPhCa41CbAaQGHCcIpiKXAJulo6TBGpBUlHSIWDKMl46jESCeS6OmkUP+WNeIJg7QByAq6oiKpZvJsFlSLFavdTUa2ETO7ke35mO1uNFmhq3S9zClTUeBgR3OEUpeVApuKhkG+3UwglqIxmKI5nqIltosDngk1Et+6FH37IsKbl6Kl4ihmK2WDx5Kq2h9zn3Gdmh6GliG+ZTGRFR8Q37wQDB1rz1G4Rh+Lrc+4XY2U1nrCiwRHHC2NmleM57DLsPUdn7vOP3ZMoEVaqHnq51grh3bQQIpvXYr/lVtwjjiSwilXABBd8zmBtx7ANe4kCg79WXa7JfhfuRXHkINzU3MSoKWT1D//e7SQn/KL/shlR4/FZjHxm0P7oesG/lCcSDKD1SRz65triSQzSJLErw7uwyEDfZi61c//behOxP8HsbsSuowYOzl9bCVDyvKY9vbqnA3RLccNoSWW6tD9Wry9hUc/2pDzDYbOwX4bX7zwuN9ROPxQktnEPRWopv7Za7LqqHcjKSbBF//XVSgODyXnP4RstortnrkaS/kAfGfckbtxRlZ9TNPsRzqIVvxQSPm3EFrwGtE1n4NhoOYXkzfxLJzDBS/XMHTimxYSWvhabrTJMfQwXKOPxVxUld3GIFmzlsjy94itm4ORSWIqrMIx7FAcgw7uUohMiwWJb15EfNN84tuWYqTiIlEs7Y+1xwgsVcOwlA3cK9/3vxGGoZNu3EaiehWJ6hUkqldiJKO5JNnWZxy2vuMxFVV1Sr71dJL4xnlEV39KfOsSMHQsFYNxjpiCfcABuYq0logQXfkR4SVvk2mtR7Z7kFQVzyGXEls/B9XtxT7ggFwAl2rcTuCtB/CddpsoNCgq9S/dSN6E0zGX9CE452Uy4Uby9z8bc3HvTue0N2RCjdT+/ZdYq4Z1WHDDS96m+cO/kn/wJeSNPwXYVQkvPPZqnEMPA3ZVzT2HXIo7q+Au9hug7pkrUex5lJz/cIfviddppjGyW2vhG9DmG9o9dv7joDsR/3HQHRN8f3Q1mm4YBrrepoi8Z/5nWtPZ1hRFzY7aKIpEIJLklUU76FPkZNmKFfz9D+cxetRoXn/7XX4zfQVLqluR0Um+ez81q+Zx91MvM2bMeH724nIC7zxKdOVH+E67FVufcaIQOXMa8W1LKTnvwVzCqydj1P79FyJ+uODhfbK43FfoyRiR5e8TWvQmWrgR2VmIyVNC/uQLsVYMFufdWk940RtEVn6EkYpjqRiMa/Tx2PtPzCmxa9EWkqs+pmX5+2Ra6oSrx4ADKBxxKJQNxZBkzJDrlhu6RnLnGuKbFhDbvJBM804AFJcXa48RWHsMEwl6no807ZL5LFQEr1rq4nEj+5xbEXz0ZEpsI2Xfvy2W89ognha1+dZ2O2l7XkFwtiVJJPcZ9g4z4LYCkowsQYHdTCqjkUimsZokCtw2GkMJGhqaaaleTbx6BfFty3PnbnIXkt9vHJXD98facxjRhKnDceXEW1d9QnTN5+ixVmRHPs6hh+MccVSOKmkYOontKwgvfov4pgUgSchWJ3kHnEVy59p/a0wA0PjWA8TWz6Hskj/nqBC7tF7clFz4MLLJ2i527p2NnVUyQT91z3SMndsgFNU/wHvqLex/8OFM7ldEJKVzw9EDSWR0VEnYlhoG1IeSzN/aRLnHzvheBciyTCyVoSGYwKIqlOZbu0fOf0R0J+L/o2izGvLYzayqDTJj8U4ymo65XfLtsZuZ9vZqUhk9F5gDudfd8uYqMnvojBtahoaXbyLVsImS8x/m8P3HMG9LEynNILr2SwJv3teRG56t6tkHT6bouGuRJInIig9EVfyAs8k/8FyxX8Og4aUbSPu3UvqzJ3Ld8j2hreOd8m/B0DUURz6mwqqcRVlXaHzjPtKN20m31okqaWEV+fufiX3AAbmFNVm/ifDit4TYh5bGUjUM18ijs4uv4HfpySjRtV8QXfkxydp1gISlcgiOQZOx99+/w1hx++uWrFlLfNsyEtuWkarfKMbDZQVzcW8hclY2AHNJP1RP6Q8yFfDvghYLkqrfRLJuA8madSRr14nEG1DzikWA0XMU1p4jUWydR50MLU182zJia78gtnEeRiqO4irCMeQQnMMO78DnS9VvJLy0zbYkiaV8MK4xx5NpqSMTbqTwqF+TCfqJrf8KLdpK/sEXt6uA/xnZbMM54ijU/BKaZj9C/sEXozgLSDVsRgv6yYQaMfQ0qsuLrc+4TmIwnY7dMGicdQeJbcspvfTPOZXfZP0m6p+/FluPkXhPuwVJknNWfs6RUyg86jcAxLcvxz99qrA7O/G6DhyyhhdvIBXYTukFD3ca198b2ropXf2Cu4XYfjx0J+I/Drpjgp8ObSJNoXiaYDyFKksUOiwEIgn+/sVmrGaVQCQJm77igRt/y29+exWPPfIwi7c28dGaOlZtq+Hduy5DMdLMfPsjbvm4hi2NUdHRCzZQctFjmPJL0OIh6p7+raD2XPRobmqsLabYvUi5J4gu6xLBKzbbswKiAzpYRnU4Py1Dw/SbSTduFzQvScEx5CDyJ52f46briQiRFR8SXjqbTGs9isODc8RROEccmbM2y3XJV35EbP2cXWvYwEli4qu0f5cJT7q1nsS2pSS2LiVRvTJHNVNcRdjLB+EoG4BR2h+brzfGPhbsTYg1QEYk4G0dbHP2/2U6KpO3h4xI6gE8FmhOdt62fXFAAtp61hlEEu9zQTKRJtVSTcu2DYR3rie0Yy3pZqGML5ksWCqG4Ow5EnOv0ahFPZAlCZdZCMY1ZqsW6aadRNd9SWztF6Sbdgif8b7jcA49HFvvMbuKIfEw0VUfE172btZmNg/niClg6OiJyHeKCVRXoaAEButJN9eiJyLIZiuWymFdigorCCu4YEJc7TZNpDanAGjTeplKsmY9vS5+CK2gZ9aq9JqO06SZFPUv/IF0cy2lFz6Si38AIis/oumdR3FPOB3PQRcytMzFkFI3Z4yrpNBpIZ2lpZZ77FhUYZ8i8j0pGw8YzN3STErTSWd0Bpe66VH0w01odqMjuhPxbnQYVW8fgO/++Fn7VTFzyU5SGR1Vkclo+l45pplIM3X/+i2yxcER1/+D8uJCPlzTAEDzJ38nvPD1jh2/r6fT+uVzeA67DPfYE4VF0zuPEF31Kb4zbheqlQguTu0/f4Ot9xi8J9+0x0qdoWs0zrpTjCa1g5pfQvnlf+/yNYntKwgtep2CI36FbLFT/9INaOEAeiyI7PDgHDEF16ijcwUALRYUi++yd9GCDeLmPuRQHMOPyHXJQVidRNd8TnTN56LCK8lYKodi7z8Re7/xnTxI26AnYyR3riGxczXJmrWk6jdiZG1LJLMtq5jeC7O3B6bCCtSCChSH5yetXurJGOmWWtJNO0gHqkk3biPl34oWbrMAljAVVWIpHyx80iuH7tFvXU/FSWxdSmzjXGKbFmAko8gWB/YBB2AffBDWqmG5YoSWiBBb8xnh5R+Q9m9BMllwDDpITCwU9wGEvUdo/ky8J9+IpJhI7FhFdM3ngqef/X5lwk1E13xGYvvybBc9D8XhIbFjNUa8sz6C4vBQfM69HRbC3RFZ/SlNbz/UIVDUEhHq/3Ulhq5RevEfUWxu0oEd1D13DaZCoYYqqe0mRux5ourdLulv01MoOuE6HIMmfevPSpFEIt7+d9wtxPbjojsR/3HQHRP8NGiLEw3DyHkIt42z67rOlkCE+VsC5NtUxlQVMfX6a3j6H09x71+exjVgf5ZUt+K2Kqxfu4qP7/8ljrK+OE+ZJqblWuupf+YqZLeXkvMeQDZZSexcg/+lG3D3H0/eCTfkRuZzhc6L/7jXe3Fs/dc0vnl/BwExAN/pt2PrPabL17TFBZ4jfokWbCDw1gNo4Wahot13PK4xx2GtGi6ORdeIb11CZMls4lsWgyRh7TUK1/CjOtiY6ukE8Y3zia75jPjWpaBnUNw+ERP0n4ilfFCX3f3cJNmO1SR3riFZsw4t3CielGRUTxlmXy8hklpUiamwEjW/pEv7VCvi/p9i3+hLu0NBJO7xPTyvIhJvVdeIhxrJNO0g3bSDVON20m0is9nPQba5hTVr+WCslUMEpa+dt3sbzIZBxL9ViLdu+Jp0YDsdmhwDJ+UK+Yahk6heSWTFB8TWfw1aGnPZAFyjjsUx8EAk1fztY4KQH9XlxVRURXz7CjKN27q8evkHXUjehNM7Pe4yQTgtGjW1//gNstlK6UV/zAn0tX7xHMG506k67kqsQ44gYeg0vn6PoKSddSfWquEiNn73MaIrP8J7ylTs/cbn9p9q2EL989fiqhxE4enTcNkseBxmehc5+MNR/ZGQsFlUEqk0LquZ4jwbkoSwJpbFbymeyjB3SzNFLgvRZAa3zcSIivx9+k5049ujOxHvRm5UvW0cvS0A3/3xU0ZX8PKCanRj33mnbb6Gtn7jKTn5RoTeJ1kO7FRStesoPvd+LCV9MQydxtfuJr5pQe6Go6cS1D/3O8EXv+jRXMIaWjCLlk//SeExV+McdliX7x3fvBD/jNvJn3wBzpFHIykqWrgJPRntkktmZNKEFr6GoWXIP/Ac9FSc8OK3RCLmLCC8+G0SWxeDJGMqqkJxFuIaewL23mPQdY3ktmVElr8v1NF1DXNpf5xDD8M+aBKKTShKGoZBunEr0XVziG/4WlRwAZOvF7Y++2HrPRpL2cA9jtcZukY6sJ1k3SZSDZtI+7eSatyOkYrltskppucVo+b5UJyFWdX0fKGabnUhW51IZus+d9QNw8DIpNCTUfR4GD0eyqqmt6BFmoQoXNBPOljfUTldVjAVlGPy9cLs64OlpC/mkr577SCnW2qJbxE88cT2FaClka0ubH3HYx94ALaeI3NBhaFliG9bSnTVJ8Q2zhMLbXEf4dM65OBOPPtMyE9w7qvYeo7CPmB/tGgLkZUfC1/xkVMwDINMOEB84zwiKz4U9mgYKA4P1p4jsVQMwVzcJxfYpBo24Z8xTXSqj72qy/PJhALU/fPXmAorKT73PiRZEd/1mXcQ37qUknPuwVI+SAixPfs79ESY0gsfRXV70dNJGl74A+mWus5V72xy7xp7IgWHXbZPn+PuUCQ4dFAxn2TdElQZzhxXxSmjK7qT8B8J3Yn4j4PumOCnQVfc8jafcU3TCSczpLUMhiFjUWVmLdrKr889mVSgmp4XP4zsqUQCfC4zE6X1PHTjFThHHUvhkcKSzKheTPXLt1E07GA8x16D12XFvel93vvngx3oPJlwE3X/+FWH+2xX2PnXS1FsLoqOuxbFXYSejJFprsFc3KfLdWlPcQGyjB4LEVnxIXoijFpQjqmgAslix5l13ki31hNd8SGRlR+hRZp22ZgOORRzSd9dPPxEhPjGecTWzyG+bSloGbHm9RkrNGN6jtqr20gm0kyqbiOp+k2k/FtINW5DCzbs2kCSUdzerJOKD9VdhOIswOzwoNjzwebCYnORsdhza6sMuGQI7kWv1dA19GTWTSUe2uWmEmlCCwdEtzjkJ9Pa0KHwoTg8mLw9xZRfcV9Mpf1Q84r37LGdipPYvoL4lkXENy8ShQdJxlI+CPuA/bEPOCAnlgeiWRNd/SmR1Z+iBRuQLA6cQw7GOWJKzqIsd+32ISYwtBTJ7SsJLXmb5M61GKkoyCqWikFYq4ZjKe2PqagS2eZGj0do+fhvxDbNp/K3L+5R6ycw+1Giqz8RujDZeDS2eSGNM27HMexwio65CoDWr18m+OXzeA79Ge5xJwEQXjKb5g+fIG//s3LcdRBd//pnrsKuwnuffcU9nwrbWFmCcFLjzhOG4HGY0XQdzTDoUejEYVaRZdB0I/vbFfH98poggUgSCRhVlY/X9c2TFl1RW7rxzehOxLsB7BpVn9Bb+GHv6f/bEnNFkdF1nTbh9cGlLtbXh+lqUj208HVaPvk7+ZMvIG/iGbnHtViQumeuAgNKL3wYxeHZNYITD4lkJM9HurmGumeuwlRQQcm59yGpZpHIv3QjKf8WSi9+PDfq2x5tN6vyXz/7jSPsIBby4FcvYB80GVvPkaRb6oiseB81vxTXiKMASNSup/mDv5AOVIOWBkUVCpz7nYpsdyPJClq0VSwCqz4WHpSygq3XaOyDJmPvO77DYp9u2kls03zimxaQrFkLho5ktmPtMVyMalcNy3Kk95wwG4aBFg6IDnRzDZmWWjKt9dlFsLFDkr47JJMFSbWIxVdRd72PoWPoOoaexkinMNKJPSuoK6qwQ3H7MOVnldM9ZcLXvKCsy0p8h+seaSZZvVLwxLctI5MNIFRPKbY++2HvOx5L5ZBdNl2GTnLnGqJrv+xgW+IYfBDOYYfnut9GJg2K0uHaGbpGZPn7pBo2U3DEL5AUE80fPYlsc4OsEln2LlrID4CpqAf2/vtj6zd+r4r29c//ARSVkrPv7uKz0fFPn0qydh2lF+3q1LSpnHoOvxz3mON3TW9sXULxWXdhrRwqqt5vP0R0zed4T52a8xUHshZmv8dc2o/iM+/ssnPwTWjf+YZdv/N9ScDb3y+6E/Zvh+5E/MdBd0zw06At6QYRgOu6sEJSFBlNNwhGExiSsDZTJdjvrk+IB4WuhWx1UXrBw1hsdjQdbjhmADffcD3+r2dSMOUKXCOOwuc0UfvFy9R8+DSDTvgFw485lw11ITa8dBfRjfMoPvNOrD2E73KbmFXegeeSf8DZXR5r9f0n4J54Bp6sCNw34ZviAj2dJLL8fYLzXhFFaFlBNtvxHHYZjiGH5NZNYWP6cdbGNJ21MZ2MY9CkDpQiPRkjvnWJ0IzZsliItiFhLumTjQmGYykf9I10KD0ZI928k3TTTjIttaRb69Bas04q0Za9rueyahGON4oqlNOzCZpQTc9gZNIYmSRGZs/6I7I9TyT9eT7U/NJsTFCOqbAy15jYE4xMmmTdehLVK0lsX06yZh3oGSSTFWvPkdj77oetz34d6H2ZcIDYujlE135Oqm4DIAmf9mGHYes3Macb095yDPYcEyiuIlS3j+D8GWRa6jFSUSSLA3vf/bD3m4i158g9fgZt38OyS5/o0qI2un4OgdfvwT3xzNz3MN1SR90zV+Uch2SThdimBTTOvEPYlx73OyRJEs2tl28SDgKnTs3FN4au4Z8xjcT25Vz58POce/xhvLCgmqXVQTK6Ts9CB784qCeZjEFSMxhc6qbYLZJrk9rmlrNLB0LTDULxNGZVxmH55vgindHQdAMkCbMiI8vdyfi+Ym8xwbeP7LrxX4sxPTwduuDteeFtPNHF21s4dXQFBjC0LI9b3xQWXACb/BEum9SbJ7/Y0qlT7hp7Ism6jbR+8Rzm4j658S/Fnof35JtoeOE6Gl+/h+Kz7kK22PGdOpW6Z6/BP+sOSs59AFNBOUXH/Y7GWXfS9MFfKDz6SiRZoei431H79BUE3npAjPHulowoWe5WpnnnPiXiGAaphs14Dv959nU16PEwlsEH5zZJ7liJpbQ/3pNuILlzDa1fvUB44euEF72JubQfRiqBc9jhOIYeinu/k0k1bCG65jOia78gvnkhzaoZa6/R2AccgK3POEyFFeQVVpA3/lT0RCTLDV9KYtsy4hvnASBbXdmRrUFYygZ26ihLkiTsw9ze3ChVe+ipOFqkGS3Wih4LosUjGMkIejKOno5jZFIiadUzwrYMxM1dalNNNwu/8ZxquhPZJlTTFYcH2ebe5wqooWVIB6qzPPG1JGvWkGkRVVvJbMdaNQzXuJOw9R6TUzyHrB/79hXENswhtmEuWqSzbYmkmAQHr34T0ZUfEV3zOUUnXoet58hd10pWsPXZj8TO1bR++QLmkr7ENs4XY/OGjpJfQt4BZ++zx7th6KRbarH17nzdAULzZ5HYvpyCKVfkkvDYpgUEv3oRx5BDcI0+DhDWfPHNCyk44hdYK4eK1y54jeiaz8ibdF6HJFyLh2mcdRey1YH3xOu+cxJ+YL8irjq8fy6R3teEuqt7RHcy3o1u/G9CknYJuLUl4W3dcEWRsZlNBBMpZMBsUckYoLqL8J54HQ0v30xg9sMMPu9WIkkNWZL57JW/ccAh22n+4AlMhZX4KwajjDqFqubtrH3rb7RYSrD0HkvBMVeReu53NL5xb65o7xh8EPHNCwnOeUkkrVlBtfbHqriLcgJg+4RviAtkkwVDS2HvJ0bKYxu+JrLyY5pmP0zw65exVgwl1VSNY9BkCo78JQVTfpOzMQ3OeYngnBcxeXti778/9gH7YyrqgWPggTgGHoiha6TqNxHfuoTEtmWEFr5BaP5M4Qle3HtXTFA2oFNHWbbYsZT2z1mvdTglXRPTbNG2mEA4qRjJGHoqLtxUMqmcXZmRjeqEnZmKpJpEAd9sQ7bYkS1OZJsLJUvjUhye3Kj1vkB09DeQrF1HcudaknUbRKMDCXNxb9zjTsTaazTWisEdCvuZYIMYUV//NcmaNQCYi/uQf/AlOAZPznXJtXiI8MoPiaz8GMVViO+Um3P76BATzHkZx8ADSVSvJBMOYCQioFqw95+AY9DkDh7ve0O6eScgobgKO59ryE/zu3/EXNovVyzSU3EaZ92JJMl4T74R2WQhFagm8NYDmIt7UzDlCiRJIhP00/j6Paj5pRQdf21OwV8HWr96gcTWxRRN+TVVA0fgdVk4b0IP+vqaiKc0Dh3kIxRLY1IU0nqG6pYYRQ4zZpOc+95I0q5utiJLeBzmTsffFQzDQMsW34S9mY78A4om/i+jOxH/H8S8LU05NfV0RmfelqYuE3QJOgi1pTUDl83EXScPY+rrK9ENochoVWVCSY3CKVeQDmwn8Ob9lFz4SC7Jspf1pfDo3xJ46wGaP3gil7B4T/gD/hm30zT7YYpOuh57vwnk7X82wa9fwuzrnbUX81F41G8IvHkfrV88i+eQSzqci7VyKCgmouvmYK0a/o3nrsVaxdi2ySr4xpsXoHpKMXt75raJrf+a/EnnYcovwZRfQnzjPJT+E5Fklciqj9GjLbR88Qyhha/jOfxy7P0n4Cm+hPyDLyJZs5bYuq/E+NnGeSDJWKuGipH0PuNQPWW5BRjEIpOoXkVixyoh4JbjukuoBWWYfb2zHPGemIqqUN3eLsfxZLMNuaB8r7y5HwNaLCi69I3bSTVuJdWwhXTjtlwVXXDCBuEccTTWqqGi69zu+LV4mMTWJcQ3LyS+ZRF6IoKkmrH2HoNjwIHY+u6Xsy3JBBsEB3/1Z6SbqkExYe83oVPl3dA10i01GJkMoYWvCU9UxYRj+JG4Rk751l7uyR2r0GOtXRZAEjvX0vrFs9gHHIhzuPD1TjftIPDWg2JxPeo3SJJEdM1nhOa9KoR9Rh0LCFpF6+f/wj7gAPImntnh+ANv3k8m0kTJOfehOLpOgF0WhXCys4GMkuWAyZLEkFI387Y0AfuehMOe7xHd6EY3ugEdk3O71YTFpCCkoKC/18r6xgTWquF4D72Uxo+fYvsnLzDx1Ms4c2wlTpuJTfPep2rgCBpfu5vSCx9GdfvwHXsldTu20fDG/ZSe/xCmokq8p9xM3TNX4591JyXn3o9stlJw5K9I1q4j8OYDlF78WKc1wNZrtFDVTsa+sasM3y4usPUaja3XaLRIK5KiCA2ZlR8AwvM8vOxdfCfdgGvkFFwjp5AJN2VtTL/KJeVqXjG2vuOx9RmHtXIolrIBYnT5gLPRUwlRwN6xisTO1URWfCDG5BEFe3Nxb0ED8/bC5O2BqaC8k7UXiORTdRV1GOf+d8DIpEm31JAOVJNq3CaodQ2b0SLNYgNZwVzcB9eoY8S5Vw7tINxq6JqIo7IuM+nGbQCYvD3JO/BcHAMnYSqsyL5Xiuj6OcJZZfMi0DOYvD07FOZ3QUdxFhJe8jahudMBst30I7D1Hf+tHGsMwyC2fg6W8oGdvl+GliHwxv0Yhk7R8b9HUlQMQ6dp9iOkm3bgO/32nChh48w7kEwWvKfcjGyyoKfi+GfdgaFl8J1yc27kXZEgvG4Oobmv4Bx+JP0mn8Sxw0tx2UzYLCpleVby7WZUWWbhtiZ0Q8JiUpAlCVmVURQl181WFBlVljsk5PuCNss0XRclG7W7G/6DoTsR/x/EhN6FmFU5xwtvP57ePvg2EOMsqexsukmRcmOqA0pczFqyE384yUdrxZixbLbiPeVm6p+5msaZd1Jy/oPIFjGO5hh8EKlANaG50zEV9cA97kRsvcfgOeQSWj75O61fvoBn8vnkHXg2qcattHzyd0xFVdh6jsQxaBKJ6hWEFszCUjGkg2iFbLHjGHgg0VUfkz/p3G8chzJ7e6K4vdT87TJMBRWYCitxjzkx93yqYTOSomIpF1V2Q9dIBarxHXQhpoJy8iadS3L7CiKrPiG27ksCr9+N4vLiGHigWFCcBRQcfjmewy4jVbuB2KZ5xDcuoOWTv9Pyyd+Fcniv0Vh7jsRaNQw1rxjnsOIcB16Lh0XVuF5wwZK164mt+3LXCSgmTB4xAtbGD1ddXhRXIYpTcMEkdd8qnN8EQ9d28cTDTWTCATKhRjLBejIt9WRaanPKruKzcGDy9cI56hgsJf0wl/ZDzS/taJOjpUnsWEti+3ISW5eKqrihI9vc2Pruh73vBKy9RucWxUwoQGT5B8TWfZlVpQdL+WAKjvq1EGzJKuAa2Y5GdM1nxNZ+IbrpZpugCvSbgL3fhO9seRNa+Ibg8vWb0OFxLRYk8Ma9olh0tKhma4kI/pl3IKkmvKfchGyykKxZR+Cdx7BUDhVjcZJEKlBN45sPYPL27OQJ3vLZ0yS2LaVgym/36JmrSBDZLQmXJLh8Um+qCh3c8sYqNN3gr19s+U7CbHu6R3SjG93434YQa9NzdmdtaPMe1zSdA/oXM7pHGlmW0Mb9io2FYWa98jzHXnIsBgZPfbGZ7c0xjrzyQd644xIaZ95Bybn3E8CG79SbxbTczGmUXPBQtmj/e/wzptH0zqMUnXgdssVO0YnXU//8tQTeegjf6bd2oCc5hx9JZPn7RFZ8kOPc7g3fJS5IN+/Ed+pUTAXlpFpqia3+jOiaT8k07RBCs333wzHwQEy+PlgqBuMeewJapEXQ1DbOI7L8PcKL30QyWbBWDsPaaxTWHiMwFfXA1msUtl6jxHtpmaxmzAbBD2/YTHjJ7Gw3OXvtXUWonlJM+aUoeT5Ut1c85ixAdniQLY4fhNNrGAZGKo4WaxUTeOGmbEzQIKhyrXVkgv5dI/GSjKmwAmuPEZhL+mIu6Y+5uHeHsXHIKsZvX04i6yQjVOtlLBWD8RxyCbZ+E3KNHSOTJrZpAbF1X+acVWRHPq7Rx+IcelgHezEtHiK2fg7R1Z+R3LlafNalA7D3n4Bz+JF75eTvDYltS0kHqik8+rednmv5/F8ka9dRdMIfcsccnPMSsQ1f4znkUmy9RmFoaRpfv4dMuJGSs+9BdXsxDJ3A7IdJN27Hd9qtuWKDAkQbttI0+2HMZQMoPOKXKBiggIGBIknkOyyYVBHf9PO5WFsXRlFkehY6UGUZ3YBQUgPDIK1lyLOqqIqM2bQHnaKsQ8LuMKtyjmeudCfiPxi6E/H/QYzp4eGFn03oxP/cPfg+dXQFp46uYNaSnRjAqbuJO81cspNkWu8wpm7KL6HoxOvwv3ILgdkPC5XK7AKZP+lc0k3VtHz6D0wF5dj6jMU19kTSbQl6QTnOoYdSdOw11D//ewKv30PJ+Q9hKqyg4LDLSNVvJDD7YWHh1K7z6x5/GtHVnxGaN6NTx3x3SIpK0TFXkqzbgB4PY+s9hujaL1A9ZVhK+pIJ+jH7euVUPmPr56A4CzAVlAve8o7VRFd/iq3nSOJbFpE38XSSO1YTWvQGLHwNZDHm7R57Iu79z8RTPhDPQReJhWbLYuJbFxNd8xmRZe8CEiZvD1EVrhiMpXyQGD3vPaaDsquejJJqrCbdVE2maSfplloyzbUkti3Nqat3OEezTYyWWxzIFjuSySpGzxWTGHGWhJVFG0ccLS1E2jJJjFQcPREVwiyJCJ3k+iQZxVWEyVOCfeABuaDFVNQDxVXY2Qs8GSNRt0Eov+5cTbJmHUYmKcbuSvqSN/EMbL3HCEuXbKKcbq4hvHEusQ1zSdWuF98rXy/yJ1+AY/BBqHnFuf2nGrcTW/sF0XVfkmmpFbYmvceIUfa+45BN38+XPVm7nvim+eQdeG6HfYmu9X3oibDgelkcohL++j1kgn6Kz74b1e0jE/Tjn3UnqqsQ70k3ICkmtFiQxpnTkExmfKdO7VCJj6z6hPDC13GNPg7XiCP3eFy6IT4ZGRhWkUex24rXZeGIISXM29KEbhi5T85gz13tPfHA93SP6EY3utENWZaR9yBpIkkSxw8vY/qiatK6wTFDyxh79TSWrVzDb39xKW9u1bB6e3JgvyJ69OnHhEtv5+snriP6/iM4jrtedMZPvon6l26k8bW7KT7jDmx9xpF/8EW0fvY0wa8qyZ90LpaSvhQc9nOaP/gLwTkv5exPASxlA7D2GE5w3gycw4/8xq74940LtKCfTEst7gln0PLx37D1m0hy+7LcVJxstoGskH/QhThHHIVr5BT0dCIrTraYxLalxLcI7QPZnoe1YgiWiiFYKgaJqbjiPjldFBDrT6aljlRgu4gJmneSaakjtnlBRyHV3Aem5ARcZYsd2WwTMUEbR1ySQcqaXeo6hpbB0NIY6aSgtiXj6MkIWjzcoQCQ273ViZpfgrmkH47BB2MqrMBUVIWpoLLTmLdh6KQat4uu/87VJHasQgsJRXjFWYCt737Yeo3G2mt0rlOuJ2NE131FbONc4psWYqRiWWeVA3EMmoy1x/Bc/KAnY8Q2ziW69gsS25aBrmEqrCRv0nni2LrQGvo2MAyd1i+fEw2YwYd0eC669svc+u0YNDn72BcE57yEY+jhuMadhGEYNH/wBMnqlRQe9zss5YNwm2H7h88R3zAXz6GXdYj/UrEgjbPuRLY48J50I6gmbCaZYDhFL49DfL/YJaJW6LRwYH+r6H7LImFua6YpskRGM0jrOpIsd0q4DcPINuPEa81qx0RdkiRUpTsB/6HRLdb2/xjfRWzpm17T9nxta5yXsurqXSG06E1aPv5bVjDlgtzjeipBw4vXkW6ppeS8BzB7ewqf5FduJblzjVBSrxxKJthA3bPXIFvslJz/EIrNTSbop+6Zq1Ds+bluexsC7zxKdPVnlF3yp1wlcV+RatiCkUlhKR9IJhyg+cO/UnT8tcgmK/XP/x7HsCNwjTgyqyqeJDj3FZI711Bw2GWYi/uQrNtA4K0HcY09kcTWxcS3LBGiIxYHtt5jhSpqr9G56quhZUjWbSBRvYJk9UrhtZ1NqBVnAebS/qKjnB1BU5ydE1wQN009HhLKpeEmtEhzTtVUT0SE+nkytosLpqVBy2Rv2AZIMpIsI8kmJJMZSbWIBTrLEVdsrpytl+IsEGJtrsI9dpb1ZCxrWbJFdPTrNgrFeEMnV3SoGoa1ahiWquG7utmZNImdq7NqqQvJZD1GzcV9spy6A3aNohkG6cB2Meq3bo4YUc+O/9sHHSS827vwJ/8uMHSN+uevRQsFKLvsyY6WYm3WfFlFf8MwaH7/z0SWv5d7TE8Kj9xMuInS8x7EVFSJkUnTMP0mknUbKTnn3g4d72TNOupfuh5L+WCKz5i2R164qkjICAVUkypzy3FDmPb26hylpP2/29wPTIrESz+f2OE33c0D//HQLdb246A7JvjvQJs9kq4bJDIGs1fWEmlu4Jqzj0E2Wxlw+ePsN7AKq0kmGk9S+/UbvPeP+3Hvd0qumN4mhuUYcgiFx14DQNM7jxFd9RFFx1+LY/DBWfvTR4mu+hjvKTdjz04tyUC8bgP1z16De9zJeA699Fufw/eNCxK16wm8cS+2XmOIb12cSzbNZQOFEFnvMZh8vXNreyboF13hHStJVK/KCYpKqlkk4lk3ErOvN6bCij0KpOrpJFo4ILrVkSb0aKvgh8fDIi5IxUTBPZ0URXgtqxuTXaclWc42FExIqhnZlOWIW53iz56HYs8XE3jZ0fc9FToMXSPTWk+qYQuphs2k6jeSrNuYE5fNFR2qholpgMLK3PXIOatsWkBix0qhMm9zY+s7HseA/bG2c1bRExHRJV8/h/jWxaBlUNzerFDe5A7X+fsisuJDmt59rIMtL0DKv5X656/F7OtN8dl3IykmkrXraXjpBszFfSk+6y4k1URw3gxaP/8XeRPPJD8r4hZd+TGBdx7BM2oK3iN/jSxJxA0xQdgwfSrJ2vUd4oUrDuqFx2nh+JFleOwW1CxvW9eNnBBbh++EbtAaT6PrOsF4CkWSURWJIpcVa7uuuKaLRFyRJTRdx6Iq3YJsPxC6xdr+B/Fdg+w2Qbeu8OL86tzIq0mRUBU5N8K+O1xjjifduI3Q3FcwF1XhaBM9MVvxnjKV+ueuwT/jdkrOfwjVWYD35Bupf+5aGmfdScl5D2IqrMB78s00vHwjjbPupPjMu1DzfKLbPn0qgbcfwnvKTbluu+egC4lvmEvT+38SN8F9tOwCOowyicXGRd3Tv8VUWImlcohYbLWMEDUzWfFMvgD/rDvRspXn2Po5qPmlmIsqcQw4AMlkIb55kUgstywitvZzhCJqXzGS3mMElvJBQmRm/7MwtAwp/1YhYlK7jlTdxpyImzgml6guZxXK1YJy1LwS1LxiIZxiz+tQLf8xYehadgytIduZryHdvJNUoLqDlYpsc2Mu6Yd9wP5YygZmuVSC72QYOulANdGVH5PYvoxE9Uqh2K6oWCuH4Rp9XAffdcPQSdasI7ZpHrENc7OJuvAULRj9C+z9D0Bx/vAJZHjJbFJ1Gyk87ncdAo3Iqo9zVe82SkFowSwiy9/DPeE0kZhrGRpfu4d0c43ghBVV5nxBkzvXUHTCHzok4ZlQI/7X7kR1FeE96fouk/ACh4kpQ0s5dbQoSrQVzHanlLTEUrzwswk8+flmPlrbIIplXQQh34UH3q2k3o1udOObIEkSdosQ1jRnNPKsKs1mD0Mvmsayv/6O7dPvpOdVj3DBxF58vTmA9aCTGbB5I+s/m4XqKcM1cgqOwQeRbqkl+NULqPklDDr6IpzH/prNwXoC7zyK4vZirRhCwZG/Ih2oJvD2Q7nivg5YSvvjHH4koUVvYB98EJaSvt/qHL5vXBDf8DWmwiocgyeTd+C56NFmouvnkNi6hNYvnqX1i2eRHfnYeo7C2mMk1h7DcQ4/AufwIwBBy0rWrM3FBJEVH2BkueI5y9DCKkyFFUKt3FOKmleC7MgXz/2b9GIMw0CLh0Vc0FJHuqVGxAVNO0gHdogJOABZxeztgWPwQbmYQPWU7bJ3iwWJrZ+THVFfSqa1HgC1oDwbE0zo4LveNuYf2zCXxPblwqPdVZTzDzeXDfzBLba0aAstn/4TS/lgoZbf9ngsiD/btS7KTr6lW+vxz7wDxVkgYlXVRHTtl0IXZtBk8iaJCY5E9QoC7z2OvecIehz7C0o8dmpb48SSBs0fPklyxyoKj/sd1rIBGMCoChdpTSee1tjSGGV8b6EPIEkSyh661bIskW8zEU9liMYlbGaRYMeSmQ6JuICB3sYs6M7B/y3oTsT/n+KHFltavL2FW95YRSbbAs9oBmeNF96gry7aQUYzaG+UIUkSBUf+knRLLU3vPIajsAyKhbKn6i7Ce+otNLx4PY0zp1F89r0oVie+02+j/rnf4X/1VkrOfxBrxSCKjr2awJv3E3jnEYqOvxZbjxF4DruMlo+epPXzZ/AcfDEgPCs9h/6MpncfI7zozX3ihXUF2WSl6JgrSbfWo4UDWCuHimrr5oXYeoxEcXrEGHcigmzPQ4u2EF76Dq5RxxDbOJ/UVy9SeMxVOAZNwjFokhjDqtsoFFG3LiW0YBahea+Cogq104ohWCsGYy4biKW0H4w5Hsh2l/1bSPm3km7cTjpQTWzD11mbk3bHa3OjuApRnYXIDo/wELe5UWxO0dk22/Ywmo7wsdB1MYKmpdHTCYxUQnTSExH0eBgtHkSPtpKJNKGFm9EiTR0sUSTVLMb3SvtjGn4EZm8vzMW9UVxFuUXQyKRJ+beQ3LmGxM7VJHeuyZ2H6inDMfQwMYrWY3hOdEZPxYltnEd80wJim+ajx4LZzvdw3GNPxN5v4o+SfLch3bST1s+fEaJx7RT1EzvX0PTe41h7DMdz6M8AMY7W+tnT2AdOIn/yBSLhfu9PJLYvo/CYq3LCMcGvXiS65jMxYp8dW2s7V//MOzDSKbxn3b1HnYOxPQq4++RhuX+3/z13xef+ZJ0/N7GS0TrfA74tD7y7g96NbnRjX6DpRlYwEmIpjfJ8G4UOM3WTJ6FErmPhv+4g9PGTDDzlL/TxOXn6y00cccnvqd5eTfMHf0HN82HrNZq8/c8i09pAcM5LeEcP5C/X/II/Vz3EM1MvETo05z2QLdrfRP2zV+OfeQel5z+Us7zKP+QS4lsW0TT7YUovfPQ766d877igQcQFnskXwOQL0CItxLcuFhZmW5cQXf0pAGp+CZYKQVWzlg/CPvAAHIMmAW289Bohfta4VYihNWwmtuHrTmuy0IwpRHEWiJjAnoeSHU2XLHZkk1XYmuZigt1G09vbl6US6KmY8BKPh0R3PRZEizRnJ/GaMFLxDtdLcXkxFVbgHDkFs7dnVlyuR66DbRgGWshPbO0XJLK0tTZRNslsE84qY0/E1nssJk9p7jVp/1bimxcSXf81af/m3DVzjz0Be//9MZf1/1ZNmG+DtnVdTyconHJF7n3UTJr61+5Gj7ZQfPY9qM4CtHgY/6u3gZ7Bd9qtKPY8EjtWEZj9MJbywRQdcxWSJJNu2kHja3dj8pTiO+kGSj0OPA4Tk/t7eeCRx7PF/dPpNf5IVFmmR6GNHoVO+pS4cZpNJDJigU9rbTpOez53SQKTqmC1qKQ0HUWXcVo7JuFt4+i6rqNkLe268ePjPzoRDwQCP/Uh/Nfi+4gtdeU3XtsaR29PY5CEvdk546twWVReXrSDYCyd646X51v59SHDWDD0z/zlqjPZ8co0Si94GDVPdDktJX0pOuEPNM66k8Cb9+E95WZM+SX4Tr2FhpdvxD9jGsVn341j0GQyQT+tn/+LVlcRnkMuwTX6OMErnz8T1VOe49I6hh1ObOM8Wj7/F9aq4R0q2ntCdN1XhJe+QzqwHSOTQlItKI581PwSTEVVaJEW1MJKMkE/DS/fiJpfgmxzI5ttWEr6Et++HDW/NFcQaP7oSVK163I8JEmSOyqiJmOCF1W9ksSOVbsSc8gltOaSvtmx9N45m6s2aPFwOw/xBjIhP1p2LD3l34IWC0GWx/b9IWXtSsQYmqnHMBSXVwjE5RVj8pShuIs6enhn0qSbdpDYvlyIzdVtJOXfDJo4JjWvWCjFVg3D2mN4ruut6xqZQDXxrYIrl6xZI14jyUgmCyZvL2x9x3WgObRHm/VZyr+ZdKCadHMNWrgZPRkBSUZ1e7H3m4h7/KnfqI5qZNIE3noASTVTOOW37Ubl6micdSeq20fRiTcgKWp2cX1ILK7HXo0kybR++TzRVR+Rd8DZOIcdDohxtuDXL+EYdgTuCad3OO7AWw+SbtyG77RbMRdVdXlMsgQNoQQvzq9mQImrQ1e6Kz73nz/dJPw+c6+XOt0Dvi0PvFtJ/ZtRX1/P6aef/s0bduM7YfPmzei6nhUM68Z/IsKJNHWtCQwMzLLME59voiWWwm5SuebIAZw+7ioeUFt59e+Pc45SwOSTz+PIwaU89dV2ik68jvoX/kDj6/dQcu59mH29KZzya/KNEO8+cQeDelWxPlXJgPOnse6pa2iaeRvecx7IFveniuL+rDvxnXUXssmCYnVSePSV+F+9lZZP/0HBEb/8xuPPBP0Ev36ZRPVKtFirsI+yuVDcPkwF5Zi9PZFkBcVThhZq/E5xgeL04Bx2OM5hh4sJscbtJLavILFjJfHNC4iu+ggAyeLICZ8Knnhv7IMm4Rh8UO54jUyaTLCedGs9mdZ6tFCj8BGPNJGq34gWbe2UKH8fSCYLsj0fxZGPuagKpddoYa2aXyy8xPNLO6yxhmGghZuIb12aHU/fRLJufY7LLpltWMoG4hg4CWuP4ZhL+mWVxgX1Lrr2C2JbFpPctnSX8nq2AWDtNZqCwy/fY8KYbq4huXM1Kf82waMP+tFjQQw9g2xxYCkbgHvC6fs0LRFZ9i7xTfPxHPqznO2pYRjUv/dHkjtXU3T877GUDcDIpGicdQeZYD3FZ96JqbCSdGCHiB3yivGeejOSakaLtNDw6m2gqPQ8+zZ0ixMdmWNHlLNpyRxaP/0H+YMOYMwpP8dps2AAxS4rk/sVksgYhJMpehXZqWmOYVKFZoPTrGI17z2ty7OZiCYzWFQFh7XztoJX3m1L9kPiqquu2uvz/9EccVVVjXg8jsm0716F3diF78oRb+t6qYoMhkFGN4RVgSTlRtHblJgvmtiTv36xpct9nTSyDLtF5Zl35lD33LWorsKcsFUbwkvfofmDv+AceTQFR/4KSZKIbVpA46w7sfYcie/UW0BWaP7wr0SWzsZz6GW4x52IoWv4Z9xOYtsyfKfdmhO30GJB6p6+AslkofTCRzu81+6IrP6UprcfQi2owFo5BMlkxUgn0SJNpFvqyLTWCesrsjZcZQOQrA7MxX1xDj4ExZFHeNl7ZIL1eA66CD2dJLz4LfRkFM9BF3Z6v/DyDzAVlmMp6Z8TMNFTCVL1G0jWrBOqqHUbRdc5C8VVtGssvd0ImuLsmqvdpmqa44inYhjpZM5H3NAzohMOIElIkiyq4qpZVMjNNmGFluWD7ek99ESYTPYapZvaxtCqSbfU5K6ZZLZhLu4jigtlA7CUDUTNem4ahkGmuYbEjlUkqleQ2L4CPdYqztntwzHgAGR3Eaa8Yuz9JpAJ+gm89QAFR/yiyzF8w9DZ8ehZGKkYirNQjO+7vdnP3yAd2EFi+zJs/SZ08BftCk0fPEFk6Wy8p0zNKfRrsSD1z/8ePR6m5PwHhVJu4zYaXrgO2eGh5Lz7UWzu3PfZMewICo8WSXx8y2L8M27H2mMEvtNu7TB23sY19xx+Oe7sNER7SEBVgZ3tzbHcY6osxNr21pXO/Y7TOrIsMe3EoZwzvuskf1/Rts+24l53R7wzLr30Uv75z38CdHPEfwRIkmS88MILnHPOOT/1oXRjD9jSGEGRJWRZYv6WJt5bWUcvr5PqpijHDi9leHk+MxZXc/e1v6Bl9VeMveR2hhxwBOtqW6mPaGRCAeqfvxYMnZLzH6RfjwrOG1fMHy46jdaGHfQ4/16spf2J7lzHzhduQC2ooPjse5AtdqLr5xB4/V7sA/an6MTrckXitvts0QnX5TrMXUFPxqj9+y/RkxFsvceiuIrA0NFiIeEW0rQTPRkVG8sKZl9vLOUDxdh1aT/svcciW+z7HBcka9eTCfqxVg7NTXiJtXGniAlq15Os20A6sH3XumqyYiqqxFTYQ4iiFZRlXVRK9lhkNjLpbEwQEZ3tdELEBTndGD0XF3TkiFuQTRYksx3Z6shavFn2+B6ZYAPp1rocZS0d2EE6sD0r/ArClrUcS2k/LGUDMJcNxOzrtUtoLRERHfLqlSSqV5Bq2Jw7Z1vvsdkGSQ+cQw/5xpgAoOWzfxGaP0MU8wsqUPNLxLSErKLHQ8S3LMZIJym9+I97HeVP1m+i/vnfY60a3kGhv/XL5wl+/TJ5k84jf/+zMHSNxjfuJb5hLkUn/EE0ksIB6p/7PYaeFrTL/BL0VJyGl24g3bSD4rPvwVPRH1kGt1XFFa3h00euwFZUzuCfPUREVxhc6ua4YeX0LHLQs8jBe6tqqWlJcGC/Qso9dgocZqwmFc2Agna+4NFkhnhaw2VVsagKum6gG7tUz/ekjt6NHw4bN26kf//+sJeY4D86EZckyfjzn//Mr371q5/6UP5n8OdPN/HQB+tzIk8gVJcVCc7cr4odzTHmbAqgG+KxygI725pie9slAPFty/C/equ4ke2WjLTdLNtuZiCS1ub3/pgVarkaDIPAG/cR2/A1hcf9DueQQ9CTMepfvI5Maz3FZ9+Tq2omdq6m4aUbsfUeg/eUm/c4qtQwfSpaOEDpJX/qOuHMpEkFtguBkdr1JGvWkWneKZ5UTKLTXdqf+PZl2HqMQrLYSDftwDn8SGw9RnS4yenpBDsePUuIuKlmIchWPghL+UAspQNyo3QgeEiphrax9G2kAtVkmmt2ca0Ash6hitubHUHzoDjyUez5WXVUh7AsMdvEQqpaxDWXlewIWvaT1TUh1pJJZcfQklmRtyh6IowWC2WtSlrQIk1CGC7UuFuFXcoukEI93ezrhdnXS/C/2hbYdJJUw2YhZrd5EemW2tw+FGcBlsphaOFGJJMNe/+JWVXZZIdFv+7Zqyk85uo9do2T9ZsEb34PYm0tXzxLaO4rVF41fY8FmsjKj2l655EOokF6OoH/5ZtJNmym+Ky7sFYMJhNsoP753wNQct6DqHm+XBBo6zNWfO9khWTdRhpeugHVU0bJOfd24JqHl8ym+cMncI05noLDL+/yeIqcZgKRVJfPKRJcc+QAfn2I+N7vXnjrarLlu3K72/blsZtpiaW6OeJdoLGxEZ/Px4QJE5g3b153Iv4jQJIkY8CAAaxdu7Y7gPwPxY7mGPG0hixJ1LbGeGH+DkwyRFIZzhpbSb5d5cJ/LCSeTNDw0o2kA9sZcsm9hPP75faRatxG/QvXoToLKD73fsYPKGf5+m1se+b36OkEpefej1pQTnLLIupnTMNaNRTfabcjqSZCC2bR8uk/cY05Ac9hlyFJkhC9evEGUo3bKDn/wQ7+4O0R2zifxll34Dvjjpx9WHuI7m4jqfrNJOs2CP527Ybc2qx6yrBWDsVU0ofY6s+wVA1HUk17jAua3vsTkeXv5V5rqRic40+bCitz66eRSZFq3J6NCbaSDmwnHdiBFm3pcHyyzZ2zLlOcBUJo1ZEvuvUdnFQsgrKmqEiyCrKMJAnBLwyhmo6WRs+khNhrOzeV3Gh6tGXXeHq4ES3SQnuXFdnqFBz2oirMvp457/O2NdAwDDKtdcIvfONcMi11aOEmsQ9FxVw6QFiTWRw4Bk3GNerobx0TZEKNGJkUqqe0yzgwE26i5i8Xkj/5AvImntHlPrRYkLpnRAxaetGjOdHd8LL3aH7/T7miO5ATbPUcdhnusSeixcM0vHgdmVAjJefci7m4D4aWwT/zDhLbluI9dSr2PuOwSqLA7tRbWfqXq5AlGPrLP2LL95LRDXp7ndxz8lBiKYP5WxtZXxfGbVXZ0BDGoij09No5c3xPihzWXJc7GEuzcFszumFgt6js17MAs7rrGmiajpF1X1GVzgJv3fhhcOqppzJr1izYSyL+Hz2aDvDMM890J+L/Jize3kJNaxxVkdE0XfiCGkZOnblNJGr+1mbSGfH8yMr8fUrEbT1HUnjUr2l69480f/AXCqZckfvh5x90AVqkieCXz6M4CnCNOBLXiCPRoy20fvkcstWF57DLKDr+WhpevZWmdx5Ftjqx9xmH77TbqH/+Wvyv3kbJefdj8pRhrRiC59CfCR75F8912Z0GQJZBVvaoAi6pJiwlfbGU9MU18mhA3JSTO9eQ2LGK5M7VhBa+DoZOumErsiMPe9/xGJkUWiKSUwUHwTGr+M2zJHesIrFjNcmaNYQWzMpVuRW3D0tJ39z4mcnbC2uv0bt41oaOFg6Qbm4bS68nE2wUQUHderRIS8dE/YeEJIsk31kgAo0eI4SHeX4Jan4pJk9pB+6dUE/fRmL7u6QatpCs39Spoq848lHLB1Fw2M9RC8rRExGaP/wrJk8ZejyEnorn+OIgRsxki3OvY+XfNF626xi7XnCSNWtpel/wv/Oz3xlDyxB44z6StevxnnQD1orBaNEWGqZPxUgnKT73PtQ8H/Ftywi89QCWsgGiEyMrpFtq8c+4Ddnmxnf6bR2S8NjmhTR/9CS2PuNyXPOusKckXIIOlJM98bfbEvLvw+3u5obvG5555hkAfve733WPp/9I8Hq9rF+/nvXr1zNw4MCf+nC60QVK8qw0RZIYBkzoXYTHYWbR1mbK8qwMrcjj4zUN6IDJZMV36lTqnruWNc/fTun5D+WsKc3envhOuYmGV26hceYdrDlvGrqtgJIzplH7wh+onz6VkvPux9J7LIXHXEXT7IcJvPUARSdeh2vcyWTCTYQXvYHi9JA34XQkxUTRSddT/8zVQhD2/Ie69pDOUh4kU9dcckmSUN0+VLcPe/+JgFgjUg2bxbq+c5XQclnxASDsNRWnB8eQQ1GdBZ06kAVH/hLn8COy8cQa4psWEF2ZHUs3WXeppRf3xuzrhXPIIUjqLltLPRkVMUFLraCqBf05X+9kzdpOmjI/JCSLA8XhQXUVYuo1WlyXtpigoAzZ5t4Vv2Tt1uJbFpPybyZVL0bU9URY7ExWxNRh+SDyJ52HOTvinYsJEuHvFBOobu/ez6Etqd8D1aXN71uLtlBy7n2570xsw1yaP/gL1t5jKDzq10iSRMvnz+Q43e6xJ6KnEjTOuJ10Sy2+028XSbih0/TeH0lsXUzBlCuw9xkHQG+vjbU7mtj04lQyiSgn3/w3VF8VWwNRihxmLKrEyp0hSvKtRFMZoqkMqgyra1oYUlHAmroIn69v5PyJPYkmM+iGQUM4gSxDgd1CIJwkmsxgzsZAhmFgGCDJEmTV1vck9NaN745oNMqsWbOYPHkyX3zxxR63+49OxAsLC1mwYAGLFi1i7Nju5sKPiQ4j6bLEWftVccpu6sxtQX1utNkwcFh2fYUkYGCJi7X14S7fwzn8SCG8Mnc6ittL/gFni9dJMoVHX4kWC9H8/p9QbC7s/SfinngGWjxEeNEbyDYX+Qecje+UqTS8fCOB1+/Bd/rtWKuGUXzGNOpfuI6G6VMpOfc+VFeR4JE3biM071XhT57l6raHtXIorZ8/Q7q5Zp8VRhV7Hvb+E3OLsOB8ryGxcxXJ6lVEVn5MZPn7tNl1ia63+FPzS4QdV//9xWvTCVL1m0jVbRBc6vpNQnglC9niyI2lqwXlmDxlqJ5SLGUDOixI4qMQI+ld2ZSI0fSk6HzrWjthF0kUIRQ16ykqxtBkix3J4kSxOUUl3ebqVE02tAyZcIBMa70QWmmuIRXYQaphI0Z81+cvW12YS/piG38q6BpatJWiY68GoOapy9HTCSRJIrx0Nva++4Gikqpdj5FKgNmW5e2bCS2YhaVyCKrb953GqQzDIL5hLubiPl1araRb67N+316KTrweSVbEovnuY8Q3L6TgyF9hH7A/WiJCwyu3oEWa8J1xJ2ZvT5K162mcdScmTzne025FNlnJRJrxT58KhkHxGdNQnQW590rWbSTwxr2Yi3tTdMIf9lgI2huK3RYOHVTM+vow87Y0UdMa3yN/+/tyu7u54d8MTdN46KGHyMvL44QTTvipD+f/LXw+H42Njdx6661Mnz79pz6cbnQBkyJTkrdrfRpWns+gEjeheIpQQmNEpQezKhNN6SgOD8XZYnrDK7dSct4DuYkma9Vwio67lsAb97Ftxn14T74JpaAc3+m30/DSDfinT6X4nHtxDj0UPRGm5eOnaHr3cQqP+S2eQy9Fi7bS+vkzyFYXrpFThCPFKTdT/+L1NL52F8Vn3tlJvM1SNgBkhdj6r7FWDNmn85UUdZcOzPhTspzvbSIxr15JYucagl8+R/DL50SyWTF4V1xQ0rfda08VXeKWWpK167NWXxuILHtvV5FdklE9ZZiLqlALKzB5ykVMUDkU+6BJXa7TIiYI7WZpmsyNpht6BiGTbeTeQ1LUdpQ1q6CrZW1NZZsbxe7udO3aON2Z1noS25aJmKBpJ6nataLT3RZ3yAqmoioRQ0kyejJC0fG/R5IVap66PCsmZ6F14Ws/akwAEFsv4i1rxdBOzwlxtj+T3LGKouOvxVIqxIbj25fT+OZ9mEv64c1qxQTnzyA071WcI6YIwdZMmsbX7yZZtwHvidfnJiFaP32a6KpPyJt0Hq4RRwFQ7Mh2sGffR7JxOxVn3sovTj2Uinwb76/1E4mn8LqsxDIZGsMJehQ42RaI4Q+n8LptVBU48EdSRBMZAuEkNUExaWhkrciaIklMiozNvCvOkCQJJCOXkHdblP04ePDBBwFBWfuvTcTLyspoamrixhtv5IMPPvipD+f/NdoH25puUJZvywXb7YPueVuayOgGBlllVMBq2iUKd+fJw/hwdT1PfrmFrlgPeZPOIxNuFJYkbm8uQZYUFe9J19Pw8s00vnk/xWfcjrVqOJ5DL0VPRAh+9QKy2Y573IliIX7xevwzp1F85p1YygaIx16+Mbc4K/Y8Co74JZnWBpreexzFWYCt1+gOx+IYehitX71IcO50irI+pd8WssUufML7iEKRnk6QqtuQrY6vIbrmMyLL3hXbZnnmOZ/wkr5YK4d2EGTTk9FdI+mN20k37SC2aUGOP517X6sLxe1FdRXmRtBkex5K2wia1YnqKkIyt6mjmnOj6bu67LuNoaWTGOkEeiqOkYyRClSLpD4WzI6nNwtv0nBAiKa0V2o1WZDtHiRZxdJ7DJmgH/fYE3AMPyonrNTy2dOYS/piaGkkxYSlbBDxTQtEJzuTxlI5FHSNxNYlxDbOxdZnLKrbR6J6BVq0hcIpV4j3+g4LbnzzQlINmyk46jedntNiQfyv3gq6JhRObW7hC/7hk0RXf0r+pPNxjToGPRnD/8qtpJt24Dv1FqwVg0j5t+J/9VYURz6+M+9AsbnQEhH8r9yCFgtSfPbdHXzt0631+GfcjmzPx3fqrZ0KKvuKQCTJS/OrMRBCbqoio8pSbnqlvTDb9xFu/CFe/7+AZ599lvr6eu666y7M5u+mzNyNb4bVauWII47glVde4YEHHqCq6vvpHnTj3wNVkcmzmTGrGTwOMy/8bDxPfr6ZJduaaJUq6XHmzWx9YSotr91Bvwvuwe22U92cJG/wgTiNMNve/DOt7/2R/GOuwlLSF9+pt+B/9Vb8r95K8Vl3iQ5kIkpwzovIZguew39B0bFX4U9GaH7/z8hmW9Y2awBFx15D4M37CLz9MEUn/qFD8qrY3NgHHkhk+fvkjT/tOzlzSJKM2dcbs683jDl+lx5K1jEkWbNmlz2prAoqV2n/nFCr4H2Xw9BDgWw3ublWuKgEhINKKrCd2Kb5uSkzcfAmVHeR8PZ2FoqYwJGHYsvLjqY7UN1ewffO0dVMoCgdBVcNA9pU07V0LibQk3HhS964FT0WQms3np6jrKUT7a8Esj0PDB1LxWC0SDPOEVNwjT4OOTtx0PLZ05i8PbKxhPJvjQkMLU1o/quYfL0xt7MSbUPwqxdzoqttzinJmnU0zrwDk6cM3+m3IputhJfMpvUzYUlWcOQvwdAJvPUAia1LKDz6t9gHiMZLaP4MQgtfwzX6OPImnpl7n54+J4ufvYeWjYsZdPq1TDzyKMyyQkmencMGFrNqZysFTgt2k0JTJIXdLHH88BLyrSY+3xhg6c4wDovCUcNKsqKICmZVIRRPM7qHh2RKp8Bp7mRTJjjiWcH87rH0HxyJRIKHHnqIiooKzjvvPC68cA+TufyHJ+Jms5nTTjuNGTNmsHbtWgYNGvRTH9L/W+xrsL37dqeOrmBoWR7vrqrj6KGluXHYI4aUcN+7a1mwrSOHSZIkCqdcgRZpoendP6LY83NJrGy24Tv9VhpeuB7/zDsoPusuLKX9KTz6txipOC2fPIVksuAaOQXfmXfS8OL1NLxyS3a7fvhOuxX/K7fQMH0qxWfdhWJz4T35BupfuI7G1+4W27W74arOAtxjTyQ0fwbO4Ud2Uij/LpBNVqxVw7FWDQeyliOBauERXrOeVN0G4psX0VZ9VpyFOYV0s7cnZm9PLGUDOx2LnogI3+42tfSgXyx+4SaSdRuFvVeXju5doT37fx9fYbaL0XRXQU7tXM3bNYqmuAppnHkH9v774xx+BNE1n5Gs20i6bj2WcvG7VRwe0o3bchYm1h7DiG9ZTLqllvi2pcKapamaTNBPyr8VS8UQDEMnsvIjTAUVxDYtILFtKfb+E3PXd1+gpxK0fPw31IKKTpMReiqOf8btZIJ+is+6E1Nhhahcf/Y0kaWzcY8/FffEM9BTCfwzbidVvxHvSTdg6zWadNNOGqZPRVIt+M68E9VZgJ6K0/jqbaSbd+I77bZcFR2yCf8rU0HXKD79tm8d5PUosCNJ4HVZWLy9Jffp6Ybge521XxVl+bZO/O1vq46+O77v6/+/wzAMpk2bhsVi4ZprvltBrxv7jltuuYUPP/yQW2+9laeffvqnPpxu7AbDMEhmBP/UapJzQb6iyDgVkYCZVYXhFR56Fjn5ckMjprL98XATS569nciHjxE77GrBW0XigOPOJRkOUffpc2BxUHLEz7H2HoZx0g34Z92J/9Xb8J0xjbwDzsZIJwgtmIWkmMk/5BK8J92A/9XbCLz9EJJqxt5/Io5Bk8iEGmn97J+0fJSH5/BfdEhE8g88l9j6OTR/8hTeE/7wva+HJEkiuS6syHVBtWiriAlqhSBbdPUnRJbOFturFky+ntlkvhcmbw/MRT1wFB1Ee2UTQ8vkvLsFVU2MpWvhJpI1a9GiLRiZrqlNXR+onJ123Me4QFYFB93pwVRYiTU3nr6LthZ4497OMYF/808eEwAE588k01KH7/TbOyWi4aXvZF1ODifvACEMmazflC26e/CdeacQZ13+Ac0fPoGt7/hcM6fpnUeJbfgaz6GX4RwuaAThZe/S+vkz2AcdRNHhP0dv937vPfM4obmzGXzMxVx88cX0KXKwpTGKSTEYVVWA26ZS25LEYZZQZLCqMmZVxeu0cPGBeRzeGsdjNVHksuAPJagLJomnNRwWE16nZY9JtiRJ3T7hPyKeeOIJwuEw99577ze6fPxHi7WNHTvW+NOf/sTEiRO5+uqrefjhh3/qQ/p/jT2prH+TCNSe+KPthd/aICFu83oyJlQjm3dSfNbdHRLkTDhAwwvXoSdjFJ9zD2ZvT8HVmXUX8S2LKTzmKpzDDiMT9FP/4vUYqTjFZ92Fubi3UKiedYfgmJ15J4rVSSbSTMPzv++wvzboqQR1T/8GwzAou+gx5Ha87h8LejKWs/FINWwm1bCFdNMOdi2AQvzM7OuFWlAmRtLzS1Hzi5EdBchK5/qZoWuiex0PoyXCGNkRND0VzwqxpcQImtZuNF2Schx5MZqeHUOz2MWf1Yls7XoMrdP7GwatXzyD4ijAPfYEMkE/kRUfoDjycY0+DhAKsS2f/oPiM+8SAjat9fhfvonyX/yDlk/+IdRUywYQ2zgP2erEPeZ40i111D51OebSfpgKKzEX98ExaHLX/L49QAjivE/x2Xdjrdrlw21kUkJ5v3ol3pNvxN5vQu48QvNm4Bp9HJ7DL8fIJPHPmJYbUXMMmky6pY6GF6/D0HVKzrkHU2Flx/2deH2uEg5kVVJvJB2oxnfmnVgrvl1RsY0TntHauRloBjqiI97N3f7psGbNGoYMGcJvf/tbHnvsMQAkSeoWa/sRMHbsWGPRokWMGDGCLVu2EAgEsFi6VnHuxk+D1liK1lgaAIdFwevqzOF9ddEOIskMeTYT25ui6JpOJKmxcPbzzH7qPvJGH0PV8VcQSoiOb6lLpeb9p9jx5UyKJ59FySHnEU1AcN1XNL55P5bKofhOuwVJtdDy0ZOEl7yNe/xp5B90IUYqTsMrU0nVb8Z78g1CxyVbbA0tmIV74pl4Jp/f8RzmvETwqxcoOv73HazCfiy0eYSn6jeJP/8Wkg1bILVLf0e252Eq6oEpR1MrQ80vRnH7ULqgW3Wiq2VFWDvR1bSMiAnacgFZFgJuWcpaTjW9nZuKbM8TYrB7yeT+k2OCZN1G6p+/Fnv//fGeeF2H54SbzsPY+o7De/JNSLJCyr+FhpduQjJbKTlH6MIIYddHsfYaTekpN2MyK9TP/hOhFR+QP+l88vYXXe/oms8IvPWQEHE9+SZURUVRIKVBaPFbtHz0JM7hR1JxwpUMLnFS7LYST+s4rSpXHNKPApeZaFJDMcBsUkikdTRDx+e0sqUpQiKt47aqlORZcVlNhJMauq7jsppErNCNnwSjR49m9erVRKNRVFXda0zwH90RBxg/fjwDBgzg8ccf5/bbb8fl6loRuRvfH23d7PbYmwgUiGR7T/zRCb0LURWZVGbXCHNbqilb7PhOv436538vhNbOvT/nzai6irId7+tyo+amgnKKTrqBxpl30PTuYyDLOIccQvHZd9Pw4g00vHwTxWfdJZTST76Jxll34Z8+Fd+Zd6A6C/CddRcNL/yBhuk3U3K2SJwAZLOVouOupf7F6wjMfgTvKTftUWX9h4JssQsf7XZJYWz91wTnz8DacySJbcvRYq2kAtXENs7rMAIOZMXTPFgqBu8aSXfkI8kKWiyEpXIIpqIeWcV0q1hQ242j72nxNHQNI50UyXsqhp6Ikg5vI9GmkpodRbOUD8I1csruL0axe3LiMOKYPGTCjblNzCV9kUw24tuWYO8rbMFUTxl6Mobn0Etz2yVr12GkU+jJGKqrkIornkexub/TtY6s+lgIqIw/bbckXIiwJLYvp/DYqzsl4c6RU/Ac/nOMTJLGmSIJLzz2apGEt9bT8NKNGFomO3peKQpFb9y7a3/tknAjk6bxtbtJNWzGe8pN3zoJB/A4zLREUxiI39k540X3u1vF/KfH9ddfD8Cvf/3rn/hI/ndw6aWXcuWVV/LEE098o0drN/69iKU0rCYFRZaIpTKd1py0pmMzy6yri1Erg9dpxR9O0NvnxHPmJazauI3tn02n2pZH/oHnAlAXzlA55ec45DTrPn+ZtGSl8IDTcQ88EF1L0/T2wzTOugvvKTeL4qmuEZo/AyTIn3whxaffTsMrU2l87Z5cMp5/8MXoiQihudORTZYOqtl5E88gsXUpTe89jsnbE7O3x496zSRZwVxUJdS/syPpsY3zCc2fgbXveBJbFqHFWjHSSaJrv8Bos1DL7UAS1LfS/ijOgpzPtyQraIkIlvLBWTvPHkgmG5LJnItz9hoTGIYo4ucoa1Ey4QC6fyt6PIQWbUWLtaLHgkJbpf1+/kNjAi0WpPH1e1CcBRQc2VEIOrruK5pmP4KlaihFJwjR1VTDFhpevgnJZKH47Hs6JuE9RuA9+UbMZoX6954gtOIDfJPOxJZNwmMbvibw9sNYq4ZSfNL1GIqKzQx2k8rWRZ/Q8tHfsPWbQPlxv8GkyNSH00RSWk4MeeaSnUzu78VmVslkdHp6HXgcZsBge3OUZ+Zsw5Cgr9fJKaPLybNbyLN1J98/Nb7++muWLl3K5Zdfjqp+c5r9H98RX7RoEc899xwXXHABU6dOZdq0aT/1Yf1PoX1XW5bggL5FXHV4/1zQ/03ewje+tpIX51fvcf/pljrqX/g9kqxScu79qHm+3HOpQDUNL92ApJhEMp5fIqyk2rqTx12DY/DBIjF68QaMdBzfmXdiKelLbPNCGl+7C3NRD3xnTEOx55Fu2kH9izcgyTLFZ3Xk7rZVJt3jT8Nz8EX7fH0MLU10zRckti0l3VqHkYyJhU01I1vsgrPt8KA6C3AMOTinCtthH7pGcM5LGIaOZ/IFJHauJbJ0NvmTz0dxFZEJNZLYtpzomk9FFVpRSdSsBcPIWYt8IyQZFBOyahY+8G0Lpq6JBFxLd+SbdYXsKJpz6KHk79ZBAFFMiG1eQOGUK5BkhfCyd9FiwZwlHYiFKbrmCxS3l8T25VjKB2HylJIONYpRungYPRnBSCUxDB3ZZBH2ZmUDcQw7DNVV9M3nmkVi5xoaXr4RS/lgis+8o50VTJrGN+4hvmkBBUf9BtfIKR06JM6RUyg48lcY6ST+GbeT3LmGwmOvxjnkkGwSfgNGKpGbwjB0jcCb9xNbP4eCI3+Fa9QxuWMQzz1AbP1X2UmOzqKB+4K2SZI2TO5XxPjehbkE/MX51Tl6yPf1C+/GvqPNI/S4447jrbfeyj3e3RH/cdAWE6TTaXw+H4qi0NDQgKJ8e8HDbvw4CGY74gZdd8TjKY3aYIzmSJqmaIJ0WuPT9QHy7SYkCVojCV7/8+1smfM2nsMvxz3meAAKbTJl+RYW/OtO/Ms+pfjIy7CNOhEDiKz8iKZ3HsPac6SwjlRNNH/wBJFl7+Le7xTyD74YIxkVnfGGrXhP+AP2Aftj6BpN7zwqtEAOupC8CbvcDjLhAPXPXI2kmoXKejuL0W9Csn4T0TWfkfZvEwm0lhGTZ2YritUlEmWnB0vZAOz9JnR6/d5iAjWvGC0eIrF1KZEVHyBbXUgmM4nqVeK1mZRIfncv4HeCBGpbTCAJ/3DDEHGFnslpx3wTZJsbxeGh5PwHO2me7HNMsPZL1PwSEtuWYakciqmgnEwogBYJCKG5eFZ8NnsdZZsLU2El9gEHYO0xYp85zkYmTcMrU0nWrqfk3Ps6UMdi67+m8U3xmO+MachmG8m6jfhfmYpksomiu6c0a637uEjCT52KpJpo+fCvhJe+Q69DzoRx56FLEvHNC/HPugtLSR8qzrqDfLcTXYJ+xW62L/uaBU/diL18ICVn3o5qtmIg0d9nIxhJYbepqIpCvs3ElUcMxGM30xpL0bPQhstqxqTIPDd3K5v8UawmiYZQimuP6kdVoWhU6rrO0uoWtgWilOTb2K9nISa1O0H/d+HAAw9kzpw51NbWUlpaCuw9JvivSMR1XadHjx74/X6CwSBW657tCrrxw6J9R7zNW9xi6phw72mkve25s/82l5RmdEom2pDyb6H+xRtQ7HnCIsLh6fBcw0s3IpntlJxzL2qeT/B1Z95OcsdqkdwMPVQkSC/fhJ6IUHz6bVjKBxHfvAj/a3dh8pRRfOadKE5PNrm/ESQoPvOuXKXbMAyaP/gLkWXvdlj89wY9naThpetJ1W1EcRZgKqzMjrZLGJmkGA2Ph9CiLeiJCCXnPZDjRrWHlogQnPMSlvJBOAYeSKpxO9GVH2HtORJb7zGA4Cylm3aQf+C5yFYnLZ/9CyOTpODwy9GSUSLL3iO2cR6q24upsBLFWUhs7RfoiTDW3mOFlUksiLm0XwcbL6nNwq1NIVW1IFtsyOZ24+k29z6NomUizTS9/RB5k87D5Cmj6d3HckWIVP0mYesWaSbdUgd6psNrJdUi/E7t+eI9TRaQFYx0gkywgXRgB5LZmtMN+CakAtU0vPAHZJubkvMfzFXP9XSSxtfvJrFlcS5pNgyd5g+fJLJ0Nq7Rx4qOSiqO/9XbSNaup+i43+EYfJAYR3/pRox0guKz7hR2JLpG4O2Hia39HM+hP8M97qTcMRiGQfP7Yizec8gluPc75RuPuz0G7cWBAHaNpF80sSd//WJL7vG7Tx7WnYz/m3D66aczY8YMFixYwLhx43KPdyfiPw7aYgKA+++/n+uuu46///3vXHrppd/wym78u7AnjngbMppObTCOpsH7q2vZ5I+iGSJ5OHd8Fa2xNAu2BfjjDb+kcdUcCo+9BufQQ3GawOuy4rYqfPW3qTSv/oqCI36Ja/SxQLtkvMcIvKfejKSas2Pqs7M0o5+L+/ort5Ks25Arrhq6RmD2w8TWfE7egeeSt/9ZuWNO1q6n4aUbMRVVUXzWXV06buyO8NJ3aP7gL6CYMPt6oTgLkBSTSG6TcfREONdJdgycRNHx13bax/eNCfRMisjSd4ltmIPi9mEq6oHqKiS6+lP0RBhbn3EkdqxBjwcxl7TFBCI6k6S2mCCrnJ71HBfK6W2UNVdWPT1PiMDuAW0xQf6k8zH5ehF4+yGhNeP2kazbIJxUIs2km3d27vIjCbE5e9YH3WLPXUc9FiTl34aRiuEafRwFR/ziGz+X9kXx3SkH0bVfEHjrwZzwr2yxk9i5Bv+rtyHbXBSfdRem/BLCS2bT/OETWHuNFmPrqikbN76He/yplB52MQkN4luX4J95B2ZvD4rPvBPV5qS3x8LQSg/pmrU8eePPyC/twV1PvcKCmhiyJJPWDOKpNNGURiKtYzNL9C5yc9yIcjx2E3aLQv9id27cfPbyGr7cGMCkgozM747sj8tmRpIk6kNxvtjQSKnLQm0oyX49PPTydU8T/zuwYMECxo8fz7nnnsvzzz+fe/y/ejQdQJZlbr/9di699FL++c9/dvuK/8jYPbF+4WcTePSjDXy1MYABpNI6j360IdcZ39NI+7wtTXjsZsFFxuiUhNtMMom0jtnXOyu0NlWIr519T86D2+zrLcbUX76JhpduoPice1DdPnyn3UrjzGk0zX4EQ8vgGnEkJefcI7abPhXvqVOx9RmL7/TbaJx5B/Uv/EEkT0VVlJx9Dw3Txf58p9+OpbQfkiRRcMQv0KIttHz0JLLZ9o3dy/jGeaTqNoqR5SGH7jVJ1dOJnCBJJ2TVSRV7dtTK0NCzFh1tSDfXZJVPxXUxMkkkkyhIxdZ8DloG57DDybTWYes1GkvZAFJ1G5AtvfFMPp/o+jnEN8wlb+KZmDyl39ECTEdLRHepqEdb0aMtZMJNOeXUlH8rDS9c17kar1qEv2hhhRB1ySvO/nlRXF5kq3Ovx5NuqaPhhT8QnPcqvpNv2utxpltq8U+/GUkxiWmItiQ8GcM/606S1St3dcK1DE3vPkZ09ae5zokeD+F/5RZSjdsoOvE6HAMOEMJsL9+EoaU7dMKbZj9CbO3n5B90UackvPXTfxJZ/j7uiWd86yT8yMHFHDzAx42vrdzjNm10kPdW13d4/N1Vdd2J+L8BDQ0NzJgxg4MPPrhDEt6Nfw+uvvpqbr75Zu6///7uRPxHhKYbJFIakgRWk/KNVkeSJHVSaAZIZjSiSQ2bSaEsz8aWQISNDSL56llgozGSoo/PSb7dQjKjUXHK9SSiU2l651Fksw3LgIkEQglaEwquo39HPJmi+cMnQJJwjTomu15LNL3zKP4Zt+M79RY8h/8CFBPhha+jpxMUTrkC3xnT8M+6g6a3H8ZIxXGNOoaiY6+hSVYJfvUCRipO/sEXI0mSUFk/8ToaZ92Jf+Y0fKffJuhee4CgNz2LtcdIvCffgGxx7HlbXcPI7KHj/D1jgujKjzDSCRxDDxMxQY/hWMoGkNy5BtnXi/wDz20XE5zxnWMC8b4pQVuLBdGjrTkl9UykGS3SJCYRX7qhc3ddklHcRZjyS3EMOAA1rxglLysC6/aiOAr2muQbmTTNbXoAE05Hde3ZzUP4d/+J2Pqv8BxySYckPLz8A5rf/xOWisH4Tr0F2WInvnUJja/dheIqEom020tw3gxaP/8Xtr7j8Z54PcgyTe/+kejKj3BPOI38yRdmk/Clws60sALfGXcgW50YBsQzGlvXreT9B6/Amu9l0EV38cW2GBP7FtIcSVHTEqNnYR5Oi8xHqxtwmVUhjmySGViWhy1L92jDIQOLCSXTtERSDCnPI6nptKXaMhLooBniOyl3c8X/bbjvvvsAuPXWW/f5Nf8ViTjAxRdfzM9//nOefPLJ7kT8R8SeOOFXHd6fhduaSaV1dGDOpgALtzV3KRC1eHsLZz8lxtXlrKVSV3BaVJKZlKicVwzGe/JN+GdOw//KrRSfeUeu+mwp6UvxmXfQMH0qDS/ekOPpeE+9lcbX7qL5vT9iZJK4xxxP8Tn34Z8+Ff+rt+E98Trs/SZQfOad+GcIPrrvjGmYvT0pPuc+GqbfTMPLN+I95WZsPUYgyQreE/6Af+YdNL0jRJf2lowb2YXFXNz3GxewvS3eqCb0RCSXqGvxMJIkIdt2VTD1RKSD17kWC2Lr1Zd0YAeZUCPu8aeiBf1kWmqFZQjgHH4EzR/9jR2PnYWpsJK8A87B5MmNyeT2FVn1CenGbcKqJJMSPPHsyLue5YrriQh6ItL1uJusCPVUVyGWyiGABIaBrecoLOUDUPKKc4WV7wqTp1QEHFpmr9ulm2va8bfvwZRfkrte/hm3karfTOFx1+Accgh6OkHgjfuIb15I/qTzhW99uAn/K1PJBBvwnXIztj7jSDVuo+HlmwEoPvtuIR7YrouSP/kC8iac1uE4gnNezFmV5E/qPMa/N6iyxOUH9WHeliZkCfbw80GWhIDblCElHTriVpPC4u0t3ZzxHxn//Oc/Abjyyit/4iP534TJZOIXv/gFjz/+OMuWLWPkyJE/9SH9v0QokcZosyw1DFzWPRSU9wDDEH7GS6pbSaRFQj+iIo/mqLBX+my9n1W1KYqcVj5YXc+a2hDztwQIpyV6n30La5++gcY37qP0ommEvcORNA0UE94Tb6Dx9btF9xmyyfhhSIpC4O2HhUbM6bfhOeRSZJON4NcvYaTiFB13Lb7TbiPwxr00f/AXtFiQvP3PovCYK5FMFkILZqEnIhQc9WskWcHedz+KjruGwFsP4Z8xTSRr5r0k41oataBsr0k4kB1V3wOl4ieOCdLNNURWfChiAi2d8x3X03HR1U9mY4JkBCOd7PIUZKsTxVmIWtQTU1Y12lo1HEtZf9Q8H4qzMEcX+y6QVBNmX8/sxdgzrc4wdJrf+xPRlR+St/9ZHYriwfkzaf3s6WyH+0Zkk5Xoms8JzH4EU1ElxWdMQ7bn0/L5M4TmvYp90OScOnrbJFze/meTd+A5SJIkEvhZd6K2TWFmPy8D2Ll1CwtevAHFbOeAKx5m1KB+bA1EqSq0sckfoaY1TiSZodSl0rfYTZ5dZW1dmCHleTgtnVO1aDKNz2km327C0HU2+SPYzCacFhWvy8LAUjdbGyMUu6zYLTLJtIaliwJZN344tLS0MGvWLA499FD69eu3z6/7r0nEJUnisssu469//SuffvophxxyyE99SP8v0d5PvL34WvvO+JxNgS7F2UAk4dPeWp0TaOsqCbeZZOJpncZIR2sNW6/ReE+8nsbX7hbV59Nuyy14ltL+FJ95B/7pU6l/8XrB18kvwXfKVBrfvI+Wj57ESMVxTzid4nPuwf/qbTS+djeFU67AOfwIis+5V1ibvXAd3lOnYq0cSsm59+F/RfiRFh17DY5Bk5FUM95ThNhb0zuPYqSTudG33WHtORJkhfCStyk86rsLNSlWJ5mW2lx1PLbuS9S8kg7q7m2LupFJI6kmMs01mMadRHjF+6Rq1hGx2IWlWaxV8MPySwgveRvXmOOw9RlHyyd/Jx1swLpb1dswDGIb5hLfskjw2lVzh1E0yeLC5C5GsbmQbS7Bebe5URz52T8hFvdjC9zFty8n3bRjj58FQKpxG/7pUzF0TXSts7SDdGs9/ldvRQs14j3lJux9x6PFQzTOvINkzbrciHq6aQcN029BT0bwnX471qphJGvX43/1ViTVkrU3q8TQMgTefojYui9FEt5O5AcgOH8GwTnC+sRz+M+/dZfh0IE+1teHqWmNoyrCKnD3X5EiwVn7VXHK6ArG9PBQVehg+sJqVtcG+XhtA19ubOxWUf8RkclkuOuuu/B6vRx//DfTWLrx4+BXv/oVjz/+ONdffz3vvffeT304/y+h6wYmRXgOp7VvR2XUNDGiHoyniSYz+NwWAuEk/lASCYMBJS4cZpVQPInTaqKmJcbSHa20xNLoOkRkM74zbqfxpRtZ+extlJ5+G9YewqZKUk34TroR/xv30PzBXzB0DfeY43EOPhhJMdP45v3Uv3QjxWdMI3/SuchWBy2f/B1/IoL35JvwnnwTTe8+RvCrF9CirRQc/nMKjvglitVFcO50tFiQohN+j2yy4hh8MIZh0DT7Efyv3ILvtFu6dFiRJAlb77GiQDvpvO8sKPZTxwTpoJ/QgteEqFvbeLpqETGB2YqSX4LZ5hLJdltMkBWNlR0eEReYflw3A0PXCC99V1Dx3N6ut9EyQgNgzWfkTTyTvKz4n2HotHzyD8KL3sA+cBJFx12DpJgILXqDlo+fwlIxBN+pU5HMNprfe5zIig9wjpgifMJ1jcY37iW+aQH5B12UK8LnqJCFFSIJt+dhAtK0NQhuAklmwAX30Cq7aIgIbZ8Fm5tZtK0ZkyLTFEsjGTbKC+xkdAOnRaZnoZ1oMkNNS4xwIoPPbaXQYWLx9hbqWuNs8EcYVeVmaFkBzdEkTosQ5x1emU+vIgcGomgfS2mY1c5UkW78cLjrrrsAvnWz+L9qXmHq1KkA3HvvvT/xkfz/RZtPuJLttrX3E2/rjHf1/OLtLdz02krO/ttclu8MdrlvCcEZ61+8Z66Kvd8Eio77Hcmda2icdQd6u2qrpbQ/vrPuEnYkL15PurkGSTXhPfF6HIMPpvWLZ2n97Glkq+D0WHuMEAvt3FcwFfWg5LwHkR0eGqbfLMRBXEUUnyPEOQJv3k9w/kwxxmOy4jt1Kra++9H84RO0fvUCXWkpqK4iXKOOIbLsPeLbln23C55F/oHn0vrlczRMn0om6Mcx+GCStetJN9cAYOu7H4kdqzG0NKnG7UgWe1Y47TBMRVWE5s8ksvIjEtuWkdi2jPjmhSCrWCsGI6sWMs01tH76T+qeuYpMyA+IBLX6oZPRQn7MRVU4Bh9MxW+eo/zyv1N2yZ8oOe8BZFkm07SDwqN+jWfyBbjHnYRz6KHYeo3G7OuN4vD86Em4nojQ/N7jqHnFOIYd0eU2iR2rxEi8JFF8zr2Yfb0AwfGrf+5a9FgQ35l3YO87nnRrPfXP//7/2DvvMLnq8ot/bptedrZvsum9N0JIIaElAZLQkSagiIINBUEUFSkKiL2gCPwUKQGkBwgEEiCd9N5D6ibbp7c7t/3+uLOTXVIpAYQ9z+ODmVvmzt2Z+33Lec9BrdtO6Xk/wT/sbNR9m6l78lYsQ6Py8vtwdR5EZtdq6p/+GaLLR8UVv8lblGk0zviNnYSf8vWDkvD4speJvvsonn7jbYGaD3lvBAHe2ljPbS+u4+mle8CyqAgcHNBYFnQochcS7ctHdWbSgEpMizZFsnYcH/z9738nlUrxi1/8ol0o7DNE3759mTZtGrNmzWLXrl2f9eV8IeF1SOQMC820DtmZ+yAsy8I0LRoTWfaG0zSnVDwOCVGAxrjtcRz0yHQv8+J3ygzpFOSEbsWYlkU0kyOnmwTcMi5FpMLn4LQBnfnZ3x7DW9qB+ufvIluzofBegqzQ9aKf4u51EpHZ/7TXcKC47xjKL/wFengfddNvRY83EBh5HiVTbiS7Zx31T/0UM5OgZMqNBE68gOSq12h86V4sXaVo/JWEzriOzPal1D/1M4xUFADfgFMpPefHqLVbqZv+E/TEoZ+vwbGXYWpZwm/87ZBxw7His4wJYnMfxVHeFe+A0+h0w3Sqv/1vOn7zQaqu/iOi7MSI1lF69g8pPu1agqO/gn/omXh6j7bFV4sqj3sSDhBb+BRa026Kxl95yOTSzGVpfOFXpDa+S9H4qwr7mZpK08v3k1j+Mv4R0+wZfVEi/PYjROY8jLv3aMq/chdIMo0v3Uty7ZsERl9C8eTvYmkq9c/eYQu9Tvw2ZWMuwiFCeutiGl6wRx8rLr2nYKVm5e9p/dM/wzINul11Lz1796TE60SyLM4ZUoHXIZPTTTTDRBIg5HMgS7AnnCWu6mQ1g0RWoy6exaVI7I9miKQ1YhkdhyTid8jUxTRUw8TraPv7bGGkGqaFKNKehB9HpFIp/vCHP9C9e3cuuODDjSL+TyXiHTp0YOXKldx///2f9aV8YdHS+b5pUp9DdtRab7996gDe29HM9CV7uOKR95i+ZA+5I1TMKwJOvja6K86jqDdWDD2NkrN/SHb3Whpf+BWWfqBz7qzsScVl92AZGnXTbyXXuAtBkimZehP+4VOIL32B5pl/RpAdlF90O57+E4jOe4zwm39H8pfYgmlVvWma8Rti7z1rJ1mX/ApP35OJvvtvwrP+Zitzyg7KzrsN78AziC18iuaZfy5Q0VujaPzVKKWdaHr5PnJNh1eHPxpc3U+gePJ38Z9wDsWnfxM5YKul64kmLMvE3XUokttH7WM30jTjNwTHXIrkL0Ep7Ux2zzqqvv4XiiZcjaC48PQejVLejVzdNhBlkmvfxLJMfANPIzBiGpF3/g3YD2WluJqqr/2Zqq/9mZJJ325zTanNC2wq3Gf48G6xGdPjTZRMvfmQC3xq0zzqn/kFojdE5Vd/a1vAYF9//VM/RVScVH71t7iqB6Du20Td4z/CTMeouPRXePuMJb3tPTvhdnrt4yt6kNq8gIZn70AuqqTiivsPKPa/8CsyWxcTOv1bBEdd2OY64iteIfL2w3h6j6F06o8+Eu3Osg4IGpr5DtShgl9BFNoUyeDIRbR2fHIwDIO7774bn8/XPib1OcA///lPFixYQElJ+/f9eMDlkCn2OCj2OHAcZe2OpHLsCafZG06TVHU8TplU1sCy7NihV7mPoZ2KCHmdeBwyXUu9eJwyXkUi6FSoLvbQu8JHkdtJlxIPVSE3HUs8VFdVcekvH8RbXE7df++Auk0UOcElgUOxi/Eta3h0wZOYpoW7+wjKL7kLIxWl7vFbyDXtwTfwdMov/AVauIa6J25GD+8ndOo1duK9bQn1T/0UIxkhMGIaZeffhta4i9rHf0SucRcA3r7jKL/ol+ixeuoev7nwems4yrpSNP5q0lsXEZv/xEHbj/m+t8cEh0Vq41xii57GO/AMPL3HHLRdT4apf+onZHaupHjy9woFcyMZof6p20hvWWj/3U//Fpah0fTSfSSWvYR/xDTKzv0Jlpal4emfk9m2hNAZ1xEafyVmOkb9Uz9FrdlAydQf4R8+BUWEzKa5NL50L46KHlRc+usCHR044LKiqfT7+n1079GLqpCHQR2LuHFiX8b0LEMzzPyYpkWp10nfqiD7Iiodg06ymsHcbY32/D4gYmGaJm5FIOASSeR0qovddC320CHopNjraHMf3IqEUxaRJRGf88ONlLTjw+H+++/HsizuvffeD13w+J9QTT8astksoihSU1NDVVUVbrf7qMe04+Oh9Sy5KNhVt5ZvkgAokkDnEi/bG5If+twt6urJtW/R/PpfcHUbRvkFP28rVNK8l/qnf46lq7bgWoc+WJZFbOFTxBZOx91jpO0DqTiIzn2M+JLn8q/9GEGUaJr5Z9Kb5uIdeDolk78HkkR0/hPEF/8XZ6eBlJ1nq7jb55xuK5i2er01tGgd9U/cAkD5pb8uJILHA5ZlYqnpAi0uW7OR2KKnKb/4DtsqZMsCHOXdCY7+CrHF/yW1cS56ohGluJqSs3+IUtyRmr9dSacbpqPH6ml47i46fOOBg4RaTDVNw3N3UnLm92l8+T46XPO34/aZDvtZ9RyNL91L5v1ltsLtwNPbbrdMYgueIrboKZzVAyi74GdI7oD9+sKniS2cjrNjP8rO/xmSt4jUxndpmvln26f+ol+ilFQTXz6DyJyHcVT1ovzC25G8RXkru4fsYy+6Hcnlw1RTtm1ezUaKz/we/iGT21xLi/2du9dJlJ37kyOKzBwNH3QXaPmrtH5tYv8KHr7KFuBsLa4IHNbBoB2fDF544QUuvPBC7rvvPm699dZD7tOumn58cKwxAbTHBZ82VN2gLpbF45CJpFQM06TY5yKb06kIuHDlu3XZnN3lU0QRSRJoSuXANIlnNERRpCGhEs1obKmNk1A1yv1u9jSnSOQM6vfX8MZvv4saDzPs2vvQynqQzIKBTVVufuOvpNbNxn/CuYRO+waCIJJr2EnDs7/E0nP2WFr1ANTarTQ8dyeYBmUX/BxXp4Gkty2h6ZX7EV0Byi/8BY6K7qi1W2l84VeYuQyl024ueF3n6t+n4bk77den3oyn16g298KyrAKluejkKwnmfaWPB75MMQHYxYCmGffjrO5PxVfuahMXgm0d1/jCrzCzSUrP+TGenid+4PUEpVN/hKf3GPREE40v/Ipc3ft515Nz0cL7aHjuDoxEMyVTf2SLtUZqafjv7RjJMGXn/QR3D1uc09gwi5rX/kag60CqL/wFaclTWKe1aB0NT92GmUvT6fJfU9alF36XzPDOJXQMuZk6uAO7m9PM2dKAUxJI5HSGdwpRVeTi5VX76FbqI5LO0b9jgHMGV7M3nCaW1SgPOOkU8hBPq2zOix1W+J10LTsgeGtZVmEsVG4XavtUUFlZiSzL7N2795CJ+JFigv/pv1AymeT111/nt7/9LRdddBFnn3021157LZs2bfqsL+1zhRW7IzzwznZW7I58Yse2niU3LQtJFJDydkqXjerMU98azTVjux3TewhAx6ID4ictDzLf4ImUnHUD2Z2raHj+bkwtW9hHKelExRW/QXT5qH/6Z2R2rUYQBIrGXU7x5O+S2bGC+qd/ipmOEzrlaxRP+g6ZHSuoe/LH9uzXtJsJjr2c1Po51D99G2YqRmj8VZRM/RHq/i3U/ueHqHXb8+e8gtJpN5Or3Vp4vTWUokrKL7VnQ+qn/4RszSf3/dNj9fbCGau375UgtplNM5Jh5EAZgiDi7TsOd89RGCn7bxUc/RUqv3o/kidE6PRv4SjtbHtwurwYWbtAoscbqH30B9Q/fRvZvesL543Of4LASRcdtMh9EFrTXiLv/IvEqpmo+7dgHc2L/BhhZJPU//d2Mu8vo3jSdw5Kwk01ReOL9xBb9BTegWfkhVECmGrafn3hdLwDT6fi0nsQ3X4icx8t2JNUXvk75FAV4bf+QWTOQ7h7jaLisnsQPQEi7/47n1CPovySu5FcPoxUlPqnbkPdv5nSaTcfnIQvf7lVEn7rIZPwo7FAWtC/ys+4XqW0fozbvu/2XDjY/z21TzlwwB7wd7O2cNlDiwH47qk9AQ77m/84z4N2wF/+8hfAFhBtx+cP7XHB8YdlWYemXVu2LoxDEgm6FUzLIuR1FpJwy7KQRRG/24FDkRAFAVUz2BvOsLMpxY6GOI3xDJZpEnQpIIgkczrJnEGfch+9unfm9Bv/hMNfzKpHfoJas5GWFUcQJUrOuoHAiGkklr9M88y/YJkGjvJuVFxxP7InSP3TPye1ZSHOqt5UXvl7RE8R9U//nOT6OXh6jaLyivsBi7onb2m13x9QijvS+PyviOa9vR0VPai8Kv/6C78iumA6VishU0EQKJ78XbwDTiU6/3HCcx7+xNZGU8uS2bWa9PvL8u/1+YoJAKKLnib23rOkty/ByBzegvPDIrFqpp2Ed+hL+YW3H3QtyfVvU//kjwGByivuLyThyQ3vtH299xh7bO2xm9DC+yi78BcERp5Lds86mzGXTVFx6a/x9hlr7/fEzZiq/Zq7x0i7QfPes9S8+ldK+5xIzyt/RajIX1i37ST8p5i5NN2/aludGqZFJmexZl+M+rjKGxvr2dyQoHOxG4ciUeZ3UuJxEHQ7OL1PBZF0Dp9TZlyvMtxOmd5VAU7oWkynkC1iHPA4GdqpiIEdg3Qpbes6k8hoRNM5kqpOTj+2791hf9PtOCpmzJhBfX0911xzzUei///PiLV9EMlkkl/96lfU1tbSp08fbr31VsaOHcs999zD3XffzfTp0z/rS/xc4HAq6B/32BYarKabKLLI7VMHEEnn2nTiXlhZc0zvYwGNiUMrb/oGTwRBoHnmn/OKpb9AdNidDSVPG7YF1+6gdKotuOYfehaSt5imGfdT9/iPKL/4DvzDzkYOVtD48m+oe+wmys7/GUXjLkcp7Uzza3+k9j8/pOz82/ANOBWluJrGF++h7olbbJ/SIZPw9j8FOdTxwOtnfAvfkDMLPzpHaWf7Wp69nfqnf0rJ5O/jG3T6IT/TsUCL7Kd55p9R8/NwJWf/EN+gioN3PJpau9OLIMnIwfIDL+aftZKvhOrvPIro9NoV/ud/RYdr/44e2Y8eq8PTY6Tt+X2Eh7MWrSW+4pWCNYnkKyZ02jfx9jv5w33g1uds2kvDC3ejxxoO8vsEyNXvoPHle9Gj9YRO/yb+EecgCAK5xl02jT1SW3jdVFM0v/I7MjuW4xsymeKJ12NqKg3P3kF21yrbsmzC1dDiMbp5Pr5hZ1N8xnUIomQvqP/9BUYiTPkFtiVea7Sornp6j6H0nFsOa1Gn6odQmz8EqkMeTulTzrJdYVTNFmkTOOAX/siCnZiWxV2vbqBPpZ/nV9YUxkFyhsXz+d/c4X63H+d50A7bI3Tu3Ll85Stfoby8/OgHtONTRXtccPxhGAeeZa3nTp2yRMjrIKnqBNwKIa/joKC0EOxbkMnpWIBuGuyPpRGAeVvq0U2RVE7D7ZQp9zuJZ3S6l3kIeR1sq42wN+um/JJ7qH/mZ2x//Ha6XnI7eoe8gJsgUnT6txBcfmILp2Nm45Se8+NCrND4/N00vXQfxqlfxz/yfCqv/B1NL91D82t/RGvaS9H4K6m86g80vvhrml66l9zoSygadzkVl/+G8Ky/EVvwJLnarZRMuamgMRN+8wFiC6ej1m6hdMpNBcacIEqUTLkR0eUjsfxl9Mh+SqfdfFQ19cPBMg2iC6aTWP4SlqbiqOyFp8chbBM/45gAILXhHfT8LDuCiHfAqRRP/PYR1eaPBMvQibz9CImVr9rMxnNvbeNEY+k5wnMeJrn6dZu5eO5PkLxFWLpG5J1HSKx8rc3rybVv0vzm35F9JVR+1XbSSax+g/Bb/0AJdaDsol+iFFWS3rqYpld+h+QtovziO1FKqtsIvXUccQa9Lv4RaU0gmrYdXQ44t2gM/eb9VHTrT00kiSgIyLJEyKMgiALZnI4iifh9DgIuhcqgi96VQVRNo6RrET0q3FgIJLI6aVXD41QQBAFBEApsBZfj4AK/bphkNANFEslpJrIo4JCPPCbXkoC3/Ld9lvzD4fbbbwc+uoPKJ5KIC4JwJvBnQAIesSzrvg9sdwKPASOAZuASy7J2fdT30zSNa6+9FkVR+OUvf0nPnj0L20pLSykpKcEwjHYRHQ6vgv5xj22ZFT8cDXbF7gjPLt97zNd5pNly36AzQJTyiqW/pPziXxYWM9lXTMXl99H4/F00zfgtRipK4IRz8OS7nA3P303d4zdTev5tuLuPoPLK39H4/F3UPfUTSiZ/D9+gM1CKO9D4wq+pm34rxad/C9/Qs6j62p9omvFbwm/8BbVmPcUTv42zqpf9+qu/JzzrAbK711Iy+buFarQSqqLyyt/T+NJ9NM/8I+q+jYRO/+aRrcsOg6ZXf48e3k/RKdfg7jYUpbjTIfeTfSXo8abCv414I9IH/DQlfylGvBHZV4xlGpi5dMFOrKV766jogRyqRAvvI1e7lVz9DvY9+A0s08RIR6l/+jYqLr3noPf39DyRzj96HiPehLpvE/HlL9P0ym9xdujTdqE/RiQ3vEN41gMIiouKy36Nq3pAYZtlWSRWvkrknX8huf1UXH4vruoBWJZFcu2bhN/6J6LTQ8Vl9+DqNNBO2F+6Bz3edEAZvVWSX3LWDfgGT8JIRWl84Veo+zdTdMrXCZx4gZ3Y1++wKY2GRvklv8JV3a/NtcQWPU1swZN51dUffSw6egve3FjPu1sbuWZMVx6evwPDshVPWwpdpmW1+T1+cLkUOPLv9uM8D9oBd955J3BAHbUdh0d7XPDFRUsy8EEE3AoB98HFSNO0Cv7joiiQyuZI5AxEQWRPUwoJAVOwSOsmHqdIudNFIptjXySLLAlUBVwk0iob9kdJ5ED2l1B16b00Pfdzdj19B70v+znpquGFaysadzmSJ0j4rQdpeOYXlF14O5InSNmlv6b5tT8QeedfaJH9FJ9xPeUX30V49j+JL3kOrWk3pdNupvKy+wi/9Q/ii58hV7uV0mk3UzLlJpwd+hCe8wi1j/6AsnNvxdmhDyVn34izQ1/Ccx6i9t83UDrtR7g6HygMFJ9xHUpxR8JzHqb20R9Seu6tOCt7HnSPjobUpnnEFz+Dp98EfANPw3GYc3zWMQFAx2/+E1NNk2vcSXrLIhLLZyD5QoQmfO1Df2493kDTjN+i7ttkjxycek0b/RWtaS+Nr9yP1rCTwKgLKRp/VaGI3vTyb8jVbSMw8ny74G6ZNL/xV5JrZuHqMpTSc3+M6PAQfusfJFa+ZtuYnfNjBKeX+NIXiLzzbxxVvSm/8BeFxL5p5p9Ib5pL6YnT6H7ud3E7FETJQNMhVb+Luqd/DpZF9WW/pkP3fgR8DmIZBwYmXodARcBFzjARRQtZFJCAYV2KcUoCKVVDBOrDabY3pwm4JDoUeUnlDJyyhJ6nmx9Nq0ESbKtBE+uYqOktiX17Av7hMWvWLNasWcN11133kXVKPnbkKAiCBDwATARqgGWCIMywLGtjq92+AUQsy+opCMKlwG+Ajzw0s2/fPnbu3MmSJUsAewHeuXMn7733Hi+88AJ33HFH+2Kbxwc71x9GwOlox7bYmrVGy7zqmr3RD211ciT4BpyKICk0vfJb6p/+GeUX31moPEsuH+VfuZumV35LZM5DGIkmik75Gs4Ofai66vc0PHcnDf+9neKJ1+MfehaVV/2Bppfvo3nmn8jVbSd02jeovPpPNL3yO8Jv/t2eA578Xcq/cqedbC18GnX/VsrOuQVHRQ/KL76D+HvPEZ3/BOq+zZROubFgqSK5A1RccjfR+Y8Tf+851JqN9gJedeyeggBmNoUcqiIw8twjin45qnrlq9UNSL4Qqc3zKTvnx2328fQaRXLdbJwd+pDevABXlyGA7TsqunwgiOjRWvRIHUqoCmdlT/zDzgZAjzfS8Owdh11wwQ425GA5crAcPd5ArnYrZi79oT6vkU0SeetBUhvfxVndn9JzfozsLy1s15Nhwq//hcyO5bi7n0DJlBuRPEG74/3m30lvnIury2BKp96C6C0isWYWkdn/RHT5qLz8Xpwd+5Heuoim1/5oW5Hlk/xc4y4anrsLMx2j9Lyf4u0zFsD2A33pXkSnz7ZCazX3b1kW0Xf/TXzpC7bGwFk3fChhNo9DIpMzDrIka0FON3liyW5afj6GBev3x7hwePUhf4/PrqgpvHbB8GqAw/5uP87z4MuOLVu2MHPmTM4999w2SV47DkZ7XPDFRUtX+1iD9oyqk9UNspqBKGD7kAsCIiKmZaFIEl43mIZFl5CHTfUJFEnCJRp4nQ5yqs5zy2soDzoIZ4wCFd3pCzH6u39g9f/9lM1P3mXPA7cS7vIPn4LoCdL06u+pe+IWyi++A6WoktJzby3oxuiRWkrP+yklk7+Lo7wb4dn/pDbPmCs56wYcVb0Jz36Q2kd/SNm5P8Y/fCqOyl62NdqTP6bo5CsJjLoA/7CzcXboQ+OM+6l/6mcERl1A0bivIshK/lqmopR3z7P0bqbo5Cvsgu+HWDfMPHU8MGIazo59D7vf5yEmABCdHlzVA5CDlSTXzMLMpo75s4L9PUttfJfwWw+CZVJ6zo/x9hvfartJctVMIu/8G0FxUnbRLwsMgdTGuTTPegBBECg7/zY8vcegRWptUd369wmcdBFFJ1+JmYlT/8zPUfeut5P1U74GllmY7/f0GUvJlJsQFSdmNknDi/eg7llL0YSrCY66iEhGx+sUiMQ1iOxk7/TbECSFTpf/mqKqToS8Dko8Ms7KAF1L3QyuLsIly2ytjZHM6VSWOhjUMUCPch+WadGUsgUPdqR1Ak4JTbeoCafpWuKlKZVFFkWcsoRpWbgdh07fJFHA7ZTRdAu3IhzTWFzrLnt7Mv7hcPfddwMHivQfBR9brE0QhNHAHZZlTc7/+6cAlmXd22qfWfl9FguCIAN1QJl1lDc/kjDL4MGDue6663A6nciyzJYtW4jFYkyZMoUpUw7vNfxlRGsxpw/T/VqxO8ILK2uwgAvzfsVH2/+KR94rUGqPB9LvL6PppXuRg5WUX3JXm2TNMg0icx4isfI1PH3G2bQwxWnPE8+4n+yOFfiHTyF02jdBEIi8+28Sy17C2bEfpefeiuQrJrboGWILn0IOdaDs3B/jKO9Odvdaml79HUY6TtH4KwmMPA9BlFD3b7E715H9+EdMo2j81W2oV5ldq2l+7Y8YqQiBURcQHHPpMXfHkxveofnV3+PudRKlU248Ip0ts2MF4TkPgWXiGzyJ4EkXE53/JI6qXnh6nmirgr76B3L17yO6/ZSdcytysJz0lkVEFzxh06kFgaJxXz2Ieq3HGmh44W46fP2vR7zeQrd6zsO4e4yk7IKfH9MD3bIsMlsXE579IEYqSnDsZQRHf6UQoFiWRXrTPMJvPWhby5zydfzDpyIIAtm962l67Y8Y8UaC4y4neNLFWJpK85sP5BPzoTYV0O0nOu9x4kuew1HVm7LzbkMOlNq0s1d/j+j0UHbBLwrFkuTaN2me9QBKSSfKL77joO9Y+M1/kFzzBr5hUyieeN1Hsm+TJQHTtLvbx4IrRnXm1+cPOuRvueW1kMdRGA+Bw4u2fdTnwZcdF198Mc899xybN2+mT58+R9z3yy7WdrzigqOJtbXHBccXet4X3LIsZElAFI/87NMNk5Sqk84Z7I+mKPE4cTpkgm6ZSDqHpoNm6KRVA0WRMHSd5pTGjsYkS3Y0sDeiklZ1mpI5ijwCadUimwPNAr8HTuxSTEDUeOzO71K/YyMlZ/2gzUiYBzBr17PtmbtBVii/8HacVb0BSK6bTfMbf0MOllF+we0opZ3I7l1P48v3YeWylJx1A95+41HrttP00r3oiSZ7/T/xAiw1TfMbfyW9ZSHOzoMpnfJD5EA5Zi5D5O1HSK6ZhVLahZIpN7bpfhuZOOFZD5DeshBHVW9Kzvx+wWrzaDCzSfb/+/tYmkrZeT/F1XnQYff9PMQEYIvqNrzwa4xEE1VX/xGl5NDMvg9CjzcRnv0gmW3v4ezQl5KpP0IJVbW6hnqaX/8L2d1rcHUbQcnZP0D2FWOqKcJvPUhqwzs4O/Sl9JxbkIMVpDbNp/mNvyIIAiVTb8LTcxTZmk00vXwvZjZFyVnfx9v/FJsd99I9qDUbbd/xk69AEET7Mz93B1p4v82kG3gaAH4JnDLEdq1j5zN34/QFOPG639CxSxfKAl5iyRwOh0w4lcUpyjhdIj2K3cSzBm6HxMAOQTqX+hlYHcSlyGi6gaoZrNwbRREFopkcIY9CideJqluIokVV0IMiifhch1dBt4tl9pRCCxulPcH+5LF06VJGjRrF1VdfzaOPPnrEfY8UE3wSifhFwJmWZV2b//eVwCjLsr7Xap/1+X1q8v9+P79P06HO2YIjLbrLli1j+vTppFIpysvLqaqqYvLkye2dik8IH2WW9IF3tvO7WVuOWxLeguyetTQ8f7ftF37J3SjFHQvbLMsivvRFou/+G0eH3pRfkKcUmQbRdx8lvuzFNrNCqU3zaH79LwiKk9KpN+PuNuxA4p2JEzrl6/hHTMt7Wf+N9NZFOKv722qjoQ6YuSzRef8hseIVpGAFJZO+g7v7iML1GNkkkbcfIbVuNlKwguLTv4m756hjeijGV7xCZM7D9tz1qd/A03fscffs/ijINe4mMudhsrtX4+55IqXTfnxMs2BaZD+ROQ+TeX8ZSnk3Ss76QZugRY83EH7zH2TeX4ajqg+lU36IUtIJU1OJzX+C+LKXkIsqKZ36I5wd+9qFkVd+ix5rKCT0RjpG04z7UfeuxzfkTIrPuA4kKa+u/xSOql6Unf9zZH8JlmXaCft7z+LqOoyy836K6PQUrsfSNZpe+wPpzfMJnHSxTYH7iIubJMDp/SpYujNMNHPAFk8WhQL9rAUOWeSpbx7599c++318sXfvXjp37sykSZOYNWvWUfdvT8SPT1xwtES8PS44frAsC8OwEES7eyYKQoFufjgYpkVTIkttLEsyqxFwKwTdCpVFbgTTIqubdpJggSIL7ItlcckSa2rCzNvcQFMiy8a6JOlsDo9DJJExKfVLuF1OSnwOgi4HXodMOBLlmd/cRGzH6rz69XkAVDghq4ER3cuW6XdgpKK2YnYfu3OerdlI44v3YOmq/Xqvk9ATTTS9/BubCj18KqFTv4GlqzS//hfSWxfh6jrMZmR5Q6TWvUV4zsOAQPEZ38I78HQEQSD9/jLCb/zVHpUbeR7BsZcX1sRCcXnOQ5iZhF3EH3f5Mc2Oa5H9NDx3J3p4P/7hZxMce/lBLi6fB5halsSyl4ktfgZBcVF2/m24Og086nGWoRFf/gqxRU+BaRIcd0UbVqBlGiRWvkp03uMgCIROvaag1ZPZuYrm1/+CkWwmOOYSgmMutWfHZz9Eat1bODr0oeycHyMFykkse4nI3EeRA+WUnf9THOXdUWu30fjiPZiZGCVn/aCgTaPWbqXp+bswdY2uF90G1UNwiJDNSyWkti2hacZvcBdXMvXmP1NUXkVlwIkkiVgm7ItmqItmQIS0msNCoE9VALckUhZ0M3VwRzqGPBR5bOE5y4JYOkdNOI0sgc+l0JzKYZkmOQPK/Q46FnuPae7bMOwk3DQtJKk9Gf+kMXHiRGbPns2+ffvo0KHDEff9n0rEBUH4FvAtgM6dO4/YvXv3Ed8/l8uhKEr7F+wTxgPvbOf3b27BtOyk4aZJfQpqzIfDit0RLvrHorbWS3kvppbXJvWvoHupl4fm7zjmbuChoNZtp+G/t4Mg2PZlH5iXSm1ZSPOrf0D0BCm/6HYcZV2B/AzyG39FdAcoO++nODv0sWeMXroXrXkvwdFfITjucsxskuaZfyLz/jJ74T37B0i+ElIb3iY8+yEwdIrGX4l/xDQEUSK7dz3Nb/wNPVyDp+/JhE77RptOanbPWsJvPojWvAdXlyEUnfL1Y5oTU/dvoXnW39AadqKUdyM46kI8fcYeVhTs04Rau5X40hdJb16A4PQQGn8VvmFnHbVYYGTixBb/l8SKVxFkhaKxl9mCa/nZNHsxnkFsoS2sVDTuq/hPOMe+zzUbaH79L+jhffiGnkno1G8gSAqxxf8ltuhpJH8JpdNuxlU9gMyOFTS99kcsLVNQXjeySZpf/T2Z95fhHXgGJZO/gyA7MHNZml/7A+mti+yEfeL1bWa+W9TYs7tXU3TKNQRHXfCx7p0o2PPcH5zeEAHzA/v96rxBXD7qyJZ4H+X32o5jx+23387dd9/N7NmzOf30o4swtifin1xc8GFjAmiPC44X9BahNotjCuwN0yKcyrIvnEYSwbAESnwOREsgZ5oEXTJBj7PQsUtmNfaGU6zc1cy2+jjRjMH2+gRZTSPoduBySBR5nHQIOnE7HWiawe5ICkUU2by/mVWP/ZrG9QsoHX0RwZOvxuMQSGpQ5oJwc5S9z9+NWruFovFXETjpYgRBQI830vjiPfYscV6cDcuyGXPLX8ZR2dMelSqqIrnmDSJzHkFQnLYqep+xaNE6ml/7I2rNBlzdR1Ay+bt2dzybJPLOv0iufRMpUE7xGd9qU4Q3MnGic/9Dcs2biJ4ARWMvxzdk8lG1Rsxchui8x0isfA1BceIfehb+EVORA5+9eKSRSZBcM4v48pcwU1E8vccQOuM6ZP+RR6AsyyK9dRHRuY+iR2px9xhJ6IzrUIoqC/uo+7cQfvPv5Orft+/zpO8iB9veZ7m4mtIpN+Ls0Ad13yaaXvsDeqSOwOiLKRp7OWYubcd125fi7j2a0rN+gOjy5cXb/oHkLaLs/J8VYrPUpvk0z/wjsjdE1UW/pKpTZ3QT3ArUJyG29i0a3vgrjsoe9L7sDk4b3oPKoJukqlMXU0lmcuR0g7poBqciohrgdkj0qQzgd8uc2q8DJ3UvJaMZFHsc2MGyhSWIyKJd6IqkVBriWRLZHA5ZolOxl5DXiXSUIphlWZimVaCdt3fFP1lEo1FCoRBTp07llVdeOer+xzsR/0yo6WAPye/bt49rrrnmoNej0SiDBg2isrKS4uLij/rxvrRo6bC1zJIeS4dtxe4IFz+4qE2CHXDJxLN64d/9q/wM7Rzi6aV7PlYiDnl1ymd+gZlNUHbeT3F3G95mu1q7jcYX7s57fdrVbsh7gL54D0aimeLTv4lv2NlYmkp49oOk1s22qerTbkYKlJNc/TqRd/4PQZQJTbzepi8lmgm/+UC+U9uLksnfw1HRA0vXiC15jtji/yKIEsHRX8F/wrmIihOwlT8Tq18ntvApzEwcT59xBMdehqOsyxE/p2UapDa+S2zxs+jhGkRPEG//U/D2G4+jqven+nDV402ktywkteFtcvXvIzjc+IdNITDqAiR34IjHmtkk8eUziC97CUvL4h14OkXjr0T22b9Py7LI7lhO+O3/Qw/X4O55IsVnXIccrLAX27mPklz9BlKgnJIzv4+72zByjbvzs/7b8sqs1yNIDntGf+kLKKVdKD33VhylnVtRDJspPv1afMOm5AOxBhqe/xVa4y5Cp16D/4Rz29xTIxmh4bk7yDXsPIj6+EniUH7hx5pUT1+yh9tfXo9pWe0d8U8YqqpSXl6Oz+djz549xzRn3J6IfzbUdGiPC44nWlNej2Xd0XQDwzRpSubIaiYuSUCwDOKaRcCtkM3pdAh6cCgSkiRimhb18QxN8SzramNEExk6FrnY1pimIakiWSKVIScNcRW3DBnNIBLP4XRJROMqO8IJdr36IHsWv0po8Bl0Put7JJARsQWRNE2l8fW/EN80F2//Uyg+8/uIihNLz9H85j9IrXsLV9dh9vrvCZLe9h7NM/+EZRoUT7we74DT0MM1Nq27bhve/qcQOuM6RJeXxIpXic77DyDYRfrhUwtF+vCbf0dr2oOr23BCp13bRnNErdtO5O1HUPeuRy6qJDjmMrwDTjnq/LjWtJfooqdIb14AgKvbMHwDTsXd48Q2TK7jDUvXyOxeTWrDu2S2LcbSc7i6DCE49rKjdsHtNX8F0YVPkqvdhlLSmdCp17Shw+uJZqLzHie1fnbBlcXTdxwA6S0LbV2gFubBuCsQBIHowunEl7yA5C+ldOpNuDoNzDMdf4+RiRE65Rr8I6Zh6Tkis/9Jcu2buLoMofScHyN5gliWSWzBU8QWPYWzY3/KLvgZJYEgbgn8TohlLHbNe5aGeY/Z7Lnzb8PncHPWoBI6lPgxDIPlO5upjWdwSBIuh4TP7aAhlqMyoBDyuikPKJzQpZQSv5NSj4Og1wGIWJaFQwZRlPA4JJKqQU43iGdzhNwO3E4FlyIdNRGHtg4HUruf+CeKm2++md///vfMmDGDadOmHXX/452Iy8BW4HRgH7AMuNyyrA2t9vkuMMiyrOvzoiwXWJb1laOd+2iLbi6XY+3atZxwwglkMhkefPBBpk+fTnl5Of369WPz5s2Ul5fzr3/962N9xi8ajjQj2nob2HOmiYzGhto4Zw2satOZ++B5fvbiOp5csqewPd8MPwiiYCunGob1sWnseqLZnt1p2kPJmTcclCTpiSYaX/g1ubptBMddQXDMJQiCiJFJ2J3RHcvx9D2ZkjO/j+j0FEQ+wKJ44rfxDjgVPbKf5pl/Qt23yU4OJ30HyVeSp5c9jJmJ4x8+haJxVyC6fGjROiLv/B+ZrYuR/GUUnfzVNgurqaaIL3mB+IoZWLkM7t6jCZ544REFWMAWJ8nuXEVizRtkti8DU0fyFePqNgJXl8G4qgcgBco+0cTcVNOo+zeT3bOO7M6V5OrfB2xFVd/giXgHnHbURV9PhkmseIXEytewcmncvU6i6OQr2xQg1LrtRN99lOzu1cjFHQmddi2eHiOxLJPUhneIvPNv+z6PmEbRyV9FkBTiS54nuugpRIeH4knfwdt3HLnG3TS9+nu0hh34hp1td8xlB8nVrxOe87CtnHvurTg72gro2b3raXzpXixdo+ycHx80D6c111D/7C8x01HKzv0J7kPZxXxC+ODvRRQ4pqS6tTaDJArcde7Ao3bQP3h8+8z44fHHP/6Rm266ib/+9a9873vfO/oBtCfixysuOJZEvD0uOP4wTbMg7nSkOXHDtNB0w3Z7MEwcikRDPEM0rSHLIoJp0THkQQCcTgVJFGhIZKiNZGhKqViGRUXIhZ4z2BPNsKcxSZnfRTKns3ZvlHAmR3MsTcDrpLrIS1M8g8elMPeZh9j8xqN4uw2n+NyfIDo9OACnABnTonnxf4nOfxxHZS/Kzr8NOVDWyn3jQSR3gNJzbsHVaaCt2v3K71FrNuDpezLFk76D6HDbLKzFzyC5AxRP+jae3mPs2eVZfye7cwWOih4UT7weZ8d+dhF+5WtEF07HymXsue1xl7cpQmd2LCc2/wly9e8jBysIjLoQ78DTjqoro8caSKyZRWr9HIxEE0gyruqBuLoNw9VpII6K7p8og84yDbTmvah715PZvYbsrtVYuQyiy4en3wT8Q8886ty7ZRqkty4mvuQ5cnXbkQLlFI291Kb2F+KkNPGlLxJf9gKWaRA44VyCoy9BdHrQwvsIz37owH2e/D2cVb1Q922m+fW/oDXvwTd4EqHTrkWQlHxh/kXk4o6UTrsZZ2VPtPA+Gl++z1ZbH/0VisZdgSBKmGqaptf+QGbbe3gHnUHJpO/ikBXc2El4ImNQ89aDRFa/jrf/KZSc/QPckoJXganDyuleVkRDPMP87Y0kswaWZeCUJMb2KieuGfQq9+GWFSTJwOdwktNNgm6RnpUhRAFME0I+Jw5JQBREFFnMF7QsRFFEEjkqLb0dxxeJRILi4mJ69uzJxo0bjynmPq6JeP4Nzgb+hG1T8i/Lsn4tCMJdwHLLsmYIguACHgeGAWHgUsuydhztvMey6AKkUil+9rOf0dzczJVXXsmECRPQdR2v10vXrl1599136dq168f4hF8cHGmW9FDbttQluO3FdYXj7znfpskeat/nV9YwvVUifiQIApR6HTQmcx/7M7WmDQfHXk5w7GVtfhimphKe9TdSG97Ji5/dhOj0YFlmQf1cDlZQes6PcVb1Qo/V0/Tq71FrNuLpPcZeeN1+EstnEJ3/BIgioQlX4xt6VoEmllz1uk0vO/lKfIMn2pXwPWuJvPMvcnXbUUo7Exx7OZ4+YwrUbSMTJ7F8BomVr2Jmkziq+uAfMRVvn7EIsuOIn9nIJslsX0Jm+1Kyu1ZjqrYiqegpwlHeDaW0M0qoA3KwHMlXgugJIDp9CIqj8P6WZYGhY+bSmOk4RiqCHm9Ej9aiNe8l17g77wVqgSjh7NAHd4+ReHqNRimpPuL1WZZFbv9mEqtmkto0H0wDT58xBEdfgqOie2G/XONuYgunk96yENEdIDjmEvzDzkaQFNT9W4jMeRh1/2YcVX0omfwdHBU97MV21t/QGnfZgdEZ1+X/Pi8Tmfc4otNDyZk34Ok1yh4xyIvquLqNoHTqTfmKd15c7u1HkIsqKb/g5wcJyWT3rqfxhV+DKFF+0QGRn+MNAZjYv4IhnYqOKTn+OLT09tnyI0PXdaqqqjBNk/r6emT52IxGvuyJOByfuOBYYwJojwuOJ3RdRxTFPOVVPGIgapot3sR2hzyl2r7ICdWg2CPjcSiI+blz0xKJpnOE0yoikFB1EEx2NqaJpjI0JHSqQy4sYMHWRpKqQTKbo8glUhHwUBlyI4oWyazBkjeeZ86/7kcp60L5Rb9E9pciQUF1Pb1tCU2v/g5BdlB27k8K4me5+h00vnwverS+IAIK2IXfBU8ieYKUnHkD7h4nkKvfQdPMP6E17MDdezTFZ1xnF+m3LCQy52GMZDPeAadSNOFqZH8pRjpGbNHTJFbNRBBl/MOnEBh1YWHO27IsMtuXEFv8X3K1WxFdfnxDJuEbelYbmvahYFkm6r5NZLa+R2bHCrTmfCwmyTjKuqKUdkEp7ohcVInkL0PyFiG6/YgOd5vuu2UaWFoWM5vESEUxEs3osTq08H60pt3kGnZiaVn71IFy3N2G4+k1ClfXoUdN+I10jOTat0ismokRb0AOdSAw6kJ8A08rHGvmsiRWzSS+5LkCc7BowtUooSpMNU1s8TPEl71sj7Wd/FX8w6diaWqBri/5Syk583u4u4+w/z6v/QGtcVd+lO1aRIeL5Pq3Cb/5dwTZQemUGwsFdq25hsYXf40W3kfotG/gH3EODkGgyAE+JxhGltWP/4bE9mWUj7mI6lOvIqGLyECnIpGhXUP0qypCN03e3dZIIqNjWSaDqoo4f2RXfC6ZbQ0JVM3ELQu4FBlBFMlqGr0qgzgUEUs3KfG7ABGnLKDIUjul/HOGllG1J554giuuuOKYjjnuifjxwrEuuo8++igLFizgl7/8JZ06HQim6+vrueWWW7jzzjvp1u3YlCm/6DhS0H6obe/taGb+tgMjeyf3KuXxb4w65L4hj4Ofv7gOE/s1URQ+UQuzI8EyNJrf+Bup9XPwDjiVkjNvKFiHQF7Ve/kMIu/8H3KoirLzflqYG8/WbKBpxu8wUhGKxn+VwIkXQIvo24In7I7rxG/j6TsOPVqX9xBfjaOqF8UTv2NXYuu2E5n9EOq+jShlXQmd8nVc3YYDFunNC4kueBI9XINS2oXg6Ivx9D35QOU3lyG5bjaJFa+gR/Yjuvx4B5yCd8CpOCp7HfUhbJkGWuMu1H2bUGu3ozXuRGuuwdLVQx8gyvkBZYNDchYEETlYgVLWBUdFD5wd+uLs2BfR4T7q30FPhklvnEty3Wy0pt0IDje+QWfgHzENJXRAzEKt20588X9Jb12E4HATOOFcAieej+j0okXriM5/nPTGuYjeIkLjr8Y76HTMbNIueqyeheQrtrsQvU5Ca66h+fU/24yFXidRMvm7SN6Qrar+6u8xkuGC2q0giJi5DM2z/kZ641zcPUdROvWmg4Rykhveofn1P9vq/Hnbm08LDkngqW+NPma1848yRtKC9tnyI+ORRx7hm9/8Jvfffz+33HLLMR/XnogfH3yYRLw9Ljg+aLEva/nfsRanwO6kZzW7u4dl4VIkNNNCFGyaeTbP6omkNEwMUlmTlKqxJ5rG1A2aklkqgl5y6SyLd0doSqjE0hkqgl46lfoo9TiQBBGfx8Gm/WGefflNal++D8nhoeyi2ymt6EGSAzocWtNeGl78NXpkP0UTvkbgxPMRBAFTTRfcN5ydBlI69SbkQDlq3XaaX/3DgY7rqdcgKC7iy14ktvApECWKxl5u054Njdji/xJf9iKCKBEYeT6BEy+wO7qRWmILp5PaOBdBVvANPYvAyPMKujKWZaHWbCCxfAbpbe/Z96rrULwDT8PTa/QxiaEayQjZmg0FD3CteQ9GMnzonQURRNFWCjONQ+4iuvwopZ1xlHfDUdUbZ3V/5GDF0eMTQyezaxWpdXPsz2LqODsNJHDCOfbMfCumYGLVTOLLXsJMx3B1HUbR+CtxVvXGMjSSa94kunA6ZjqGd+AZhCZcjegtIr1pHpF3/g8jGbGZieOvQpDzujGL/2szFs66AU+Pkbaq+pv/yNukDqB02i3IAfuep7YspHnmn+zCzDm3FixpvUDfDgpGIsqcB24jWbeTzmdfT3Dw2XSvcOKRRXKGRc4At9tB92Invcu91ERz1MezhDwyZ/StontlgETOAEtAFiGcytKU0PA4bEuyzsUeRFEg4FLwuJ2YpoksiUe0FTNN+zcI7dTzTwvZbJZgMEhFRQW7d+8+5iLJFz4Rv/rqqxk1ahTf+c53ANs/9IEHHuCvf/0rX/va1/jFL35xvC/1fwZHCtoPte1YOuKiIHDtuG78a+FOcvnEW5EE7jxnIO9uaeDNjfWfymezLIvY4meIzX+iMNfzQUXR7N71NL38G0w1TfFkW8ALbJGR8Bt/tVXROw+i9OwbkYPl9gzy638iV7sNd6+TKJ54fZ6WPpfw249gpmL4hkyi6OQrET1B0lsW2oIj0TqcnQdRdPJVuKr72VSszfOJLfovWvMepGAFgRHT8A2eWEgCLcsku3styTWz7AXL0JBDVXj6jMPTezSOyp7HrJpuWSZGMoIRb7D/m4lhqimsnIplanb+LUqIihPB4UZy+xG9IWR/KXKw/ENR2fR4E5nt75HashB1z3rAwlHVG9/gSXj7jS9Q1y3LJLNjBYllL5LdvRbB6cU/fCqBkeciuQPoyTDx954lsep1BFHEP/I8gqMuQlCcJNfMIjrvcUw1lVeZvQJBdhBf+gLRhU8hKk5Cp38L74BTwdSJLphO/L3nbFX1aTfj7GDbTeWa9tD00n1ozXspOvmrBEZf3OaeWpZJbP6TxBY/g7PzIMrOuw3J7T/me9EaPodEMnfooOZIaLEqg7bWZHe9uuGwneuPY1H4UZP4Lzosy6JDhw6Ew2GSySSKcuy/ifZE/PjgwyTi7XHB8UNrL/EP260zTZN4OmcTrUQ7KcmZkNU0NENAkUHTLAQR4hmN/eEMeyMJdNMilsoR8CmYBqRSKjXJDOm4SknQTacKPyARcCoU+5y8vnInS3fGqNu7k33P3YWZjdP9/JvRuo1uez2qLeCV3rrIZsyd/UNEl8/2sV7/NuHZD4IgUjLp23j6TQBDOzCD7C2ieOL1tk91tI7IWw+S2bHcnnU+41u4uw61C8tz/0N683yb9TX6K/iGnoWoONGa9xJb/F9SG+eCIOLtP4HACee2YY3p8SaSa98kuW42RrwBQXHi7nEint5jcHcf8aHmwU01jR5vwIg3YaRjmNkkZi6NZWg2J1oQECQF0eFCdPkQPUEkXwlysOJDrYOWniO7ey3prYtIb3sPMxNHdAfw9j8F35DJbcbS9Fg98RWvkFwzCyuXwdVtOMExl+Kq7m/r42yaR2zBdPRoLc5OAwmddi3Oyp7kGnYSnvMQ6p51eXr6d3FW9bYZc2/8Ba1pD94BpxI6/ZtI7gDZPetsu9NEUxubVMvQiLz7qC3MV9WHjuf9BDFQhgV08UPQBY17d7D28bvJpJL0veRWgj1PwOsW6FVehKZb5EyNhphG1zIviqTQo9xDp5CXjJaj2OuiU7EXr8dBIq3jdcoIgoggmIRcDkxBQDBty1+nLCJLIrIsFX5fLayTwr21DiTftl5Duxjbp4l7772X2267jUceeYRvfOMbx3zcFz4Rf+2117j11lu59dZbmTdvHnPnzmX48OHceOONjBo16lO40v8tHOuMeMu26Uv28Pr62oNmxO+buamgfi6LAobZdub75F6lNCdVNtYmPo2PVUBq0zyaZ/4J0Rui/IKfHzSvZCQjNL5yP+qedXgHnkHxxOsRHS574V032/bfBIpP/ybeQRPBMokve4nYgukgihSdfCX+4VOwtCzRBdNJrHwVQXYSHPMVAiPOAUEgsfoNYoufwUxFcXUdRnDspbiqB9jJ6PalxJe+gFqzEcHhxtv/FPxDz2qz+BrZJOktC0lvmk92z1qwTERvEe5uw3F1GYKr86DPTCXVVNOo+zaR3b2GzM6VaI27AJCLq/H2PRlv//FtaN5GOkZy3WySq99Aj9Yi+UrwnzAN/9CzEJ1e9EQz8aUvkFz9Opah4xs8keDYy5D9pWR2rCDyzr/Qmnbj7DSQ4onX4yjrSrZmI+FZD6A17cbTZyzFZ1yP5AuRa9hB02t/RGvYiXfQRIpP/2Z+DMEitX4O4bf+gaC4KZ12M+6uQ9t+rlyG5tf+aCunD55E8aRvfybq9C3FrtYCbKIg2HOWx6Fz3T4jfmg8/vjjXHXVVdx+++3ceeedH+rY9kT8+ODDJOLtccGnj9ZJQkvy0KLebL8moOkmqayOIouomobHoYAA8XTeoskSCLodeJ0yCVVD03T2NqdoTKg0p7IgCIiIbKkNE06q5JDoVeLG55Lwux2kdBGHaDF/Sz2pjEoiq1Fb28iWZ+4mtX+7zY7KK6a3vu7E8peJvPtvJH8pZef8uFC81SK1NL/6e9T9m/H0GUfxpG8jeYJ2d/z1P6M17LSL9Gd8C8lfRmb7EiJzHkaP1ePueSKhU76OUtIJtXYr0bmPkd29GslXTODEC/ENnYyouNCidSSWvURy3VtYmoqzuj++oWe1GVOzLBO1ZiOpjXNJb12MmY7aI2PV/XF3HYar82C7WH8U1fXjAcsy0Zr2kN29luyuVWT3rMXSVASHG3ePE/H2G4+7+/DCemqZBtmdK0msfoPM+8sA8PQ9meCoC2zh25YEfNEzNpOwvBtF46/C3f0EzFSU6IInSa59E9HppWj8lfiGTMbKZYjMe5zkqplI/hKKJ3/X7oJrKtH5j5NY9nIbu1MALVpH04z7ydVuxT9ims1wkBRE7F6FG9C2zmPHK3/GHSji/Jt/j7eiCznTQtV1SrwuBMGiNp6mOZ6jW7mfbiV+qkIuepQHUHMGIlAWcONzySAIeF0KumkgmBB0OxEEAUUWoNWYR0tS3fJbajNqaZqFbWZ+brw9Ef90oOs6oVAIh8NBU1PTh7rfX/hEHOCtt95i8eLFpNNprr/+ekRRZMmSJTQ0NNDY2MjgwYO54IKPZzn0ZcORAvTpS/bw85fWtVE+FwU+thL6JwW1diuNL/wKU01TMuVGvH3GttlumYbtJb3oGeTijpSd8+NCIqxF62xxtr3r83Yk30MOlKFF6wi/+XeyO1eilHeneOK3cVX3Q2veS+Sdf5F5fxlSoIyicVfgHXAqlq6RWPUa8aUvYKZjOKsHEBh1Ie4eJyAIImrddhIrXiG9eT6WnsNR0QPvwNPx9jsZyXvgfhuZOJn3l5PZsdyeB8/EAZD8ZTg79MZR2aswFy75Sz9hsbYUWtMeco27yNW9j1q7Ba1xN1gmiDLO6n64u4/A02MUSumB5NvSNTI7lpPc8HZBWM5Z3R//sCl5+zUZrbmG+LIXSa6fA6aJd8CpBMdcihKqQq3dRnTuf2zxtqJKik75Op7eYzDTUSLv/sdWUPWXUTzpejw9R+VpgM8SW/xfRLePksnfx9NrVP4ztKIZdh5M6bSbCyI5LdCidTS+8Cu0pj22d/zI8z6R+3g4wcIWKJJAlxIv2xuShdeGVAe5ZGRnbn95fcFTXACk/Bxle+f608GECRNYuHAhsVgMr/foHr+t0Z6IHx98mJgA2uOCTxN2YmBiGEbh2akoCoZhtunaGYZJLJVDEG1VZ5csktJ04hkdSYKQ24XfrSCKIpFUDguTSCpDNJVlV2OaSFIl4HGwtylFOKdi5CwQDHwOFz6PQonXicuhsKkmzK5YGjWrYeo5kqksc5/8C+lNc22B1rN+cBDFW923mcYZ92Mkm9uMM1mmkZ8Rn47o8tqjan3GgmkQX/5ywWozOPoSAiPPByzbgnPxM1iaaouzjb0U2V9Kds86oguno+5Zh+gO4B8xDf/wKUjuAEY2SWrtmyRWvY4erUV0evH0n4BvwKk4OvQt3FfLNFD3byazfak9D54vhguyE0dlT5xVvVEqutuz4cUdj6o586H+zqaBHqkl17SbXP0OcnXbyO3fUtCpkYuqcHUbjqfHCbi6DG0zIqg17SW54W1SG97BSDQheorwDZmEf+jZyIFSTC1Lav3bxJe+gB6ts0f5xl6Gp88YrFwmL972Epah4R82heDYyxBdXlLr3yHybl7QdfhUik7+KqLTQ7ZmI82v/zlvd3oWoVOvKYzYpTbNo/mNv4EgUHLWDYeMEePznyD63rMUdenPoK/eRlGwCEEUkSSJflU+upaH2F4fwyXJ6EAyk2NU9xJcDhlFFjENC7/bQecyPw5JxOdUMC2LRFrDEiwcioRTEnDJEg5FRson061p5gcSbntb60Tc3k57Ev4p4bnnnuPiiy/+UMKtLfhSJOKt8dRTT/H2228TCoWoqKigrq6OFStW8I1vfOOYB+u/7DiaqNsl/1xcSBJa0K/Sz6a6T7f7fSToiWYaX7qH3P4tbVQxWyOzew3Nr/4eIxOn6OSrCJx4nr3wWqatcjr3URDy4mzDzgYEW4jl7UcwEk14+k8gNOFryIEyMrvXEH33UXJ125CLqykaexmevuPy802ziC99CSPRiFxcTWDE1ILiuJFJkNr4Lql1s21VckHE1XkQnj5jcfcc1caD07JMtIZdZPeuz8+Eb8WIHaD+C4oLuajSFmnzlyJ5Q0juAKLLi6C4EWSHfQ8EsEwTDA1TUzHVFGY2URBnMeKNaNE6u+qeh+j04qjshbNjX5zVA3BW92uj6GpqKtndq0lvWURm23uYagrRU4R3wCn4Bk/CUdr5gPL7ylftSrik4Bt0BoFRF6IUVZJr2EF04VNkti7O0/hs8TaA+IpXiC16GkvP2VYlYy5FdLjyVLS/2t3x/hNswZy8lVq2ZqM9Jx5vbENHa/Md2LmKphn3g2VSes6PcXcf8Ql9A48MSRT45rhuLN7RzJqaWNttQluPcTmvhh5J59o4Ghyui93e5f74iMftglcgcGRbvkOhPRE/PvioMQG0xwXHG6ZpFhJxSbKptS2JeAtaEoZMNkc2l8Mhy1gWRFUNWRDI6gbFHieKJKEZJtFMFsG0iKZzxDM5NM1gXyyD3+NifyRFczJLUyJNIqMxvEspliTgVRRcssSuSAoFC9UwaIhkWFkTpzFhUbfkOaJzH0Mp70rZ+T9DKapEAbT8NRrZpD2mtmUhri6DKTn7psIcca5xV94uc3thVE32l6LHGgi//TCZrYuRi6oInfYN3D1HYaZjxBY/Uxi38g09i+Coi5B8IbI1G4gvfpbMjuUIshPvwFPxD5+Ko6zrgTG1tW8VLMGkYAXePmPtMbUOfdqMVBmpqB0T1GxA3b+VXMMOMFo+kYAUKLPjgkApkq8YyR3MC7V5EBQnSLKdyFm23o6l5zBzGZu6no6hJ8MYiSb0WD16rP7AHLkgopR2xtmhD86O/XF1HogcrChcl2VZaE27SW97j/TmBXbBQBBtq7VBE/H0GoUgKejxRhKrZpJcMwszE8dR1ZvgSRfj7jUKK5e1GxpLXsDMJmzxtvFXohR3RN23mfCch8nVbsHRoQ8lk2xBV1NNE533HxIrZyIFyig564YCA85U07ZV7fq3cXboS+k5t7S55pbvQNMrvyW7YwWBwZPoNvHbuFwKggJeB3QOOXA7nIzpVcq+5hQ5006CBVlgXPdiTEFhXySJlrMoDbnoVhzA7ZQp9jkwDINkziCZNfE4RNwOEbei4HUpiPlueNvf1cFe4IfqlLfj+ENVVZLJJKFQ6IhuEYfClyoRf+KJJ3j77beZMmUKAwcOpFOnTng8Hl566SUee+wxXnjhheN0tV8sHE3U7bezthx0TGXASTiVQzdtCq0IyJJQmBv/LGDpGuG3/mH7RLbyB20NIx2j+Y2/ktn2nt0xPfuHyEGb9q21iLPtWpVX7v4ujorumLkssfeeJb70BQShZZ75QgSHm/TWRcQWPInWtAe5uJrg6Ivx9psAQHrLAuLLXiZXty1PS5+Ab/Dk/Py3QK5xN6lN80hvno8e2Q+Ao7In7m4jcHUdirND3zYVZrAXDa1hJ1rTHrTwPvRoLXq8ESPRjJn9cIURQXEieYuRAmUoRZXIoQ4opZ1wlHVFCpQfROfTm2vI7F5NdudKsrvXYukqgtOLp9covH3H4+o2DEGUMJIRkhvmkFwzCz1Si+gpwj/sLPzDpiB5i1D3byH23rNktr2H4PTa4m0jz7Pv5+YFROf9Bz1ah7v7CYROuxalpNr2Fm9NRZv0HTw9T7SvzdCILnzKnhMPlNlz4nnLsgPXf0A1XynpRNkFP2sjKPdZQxIOVLtbW5IdTem8XQn946FlNu/joD0RPz74qIl4e1xwfNF6ZjyXyxUotrIs0zrGbPldaZpe6JQbhk5cNRAQ0E2DgEvBsAScDplkViWRzpLJGSRVjcZYBoci43fLpDSL3fVRwokU4ZRJj0o/0VgWQZZRnCKZjInHIRBNZclhsfz9MDEVXDLsXLWcPa/8FgSB0qk/omffkWgmNBgHPk9y7VtE5jyEIEr22tJvvJ0EmUZ+VO1JW5wtP6omiBKZnSuJzHkYrXkvzs6DCJ1yDc6qXmjROmILnyK14R0EScY3ZDKBEy9ADpSRa9xFYvkMUhvfxdJzNi19yJl4+oxBVFyYapr01sWkNs8ju2sNmDqipwh39+G4ug7D1WXIQQwvy9DRwjVojbvRmmvQovsxovXoiWaMVARM/dj/uIKI5C3Kz4qXI4eqUIo72c4spZ0RFWeb3c1skuyedWR2rSazY3mhUeDs2A9P35Px9j0ZyRfCMg0yO1aQXDOrQE939zyRwMjzcFYPsMXbVrxCYsUrmJk47u4nEDz5q7b1WLSO6LzHSG+ah+QNUTTha3gHngoIpLcuIjL7IYxkGP+IqRSNv6rQBc/uXW/PiccbCY7+CsExlx5E5c/V76DxpXvQ402Un/EtvMPOxgfkAAXQgQoPBH0CJ3QqoWdVECyLHeEkfSuLEQSDgEOiKWPicSgE3BKKKFAZ9OGQYV80SbHHSyKjEvQp+F1uPA4JpyLbI4iiUChk2R3wgxPxdny6ON4xwRcqEd++fTvXX389P/nJTzjttNMKFYstW7Zwyy238JWvfIWvfvWrx+tyv1Bo7U0sCPCtk7vzk7P7sWJ3hBdW1vDUkj2YHzhGABRZ5KIR1QzsEGT9/hhLd4bb0G4/KyRWv0F49oNInhBl5/2kMP/VgsLC+/bDAIROuxbf4EmFB2Bq47tE3n4EM5No4xmux+qJzH2M9Ka5BSEW/7CzQZJJb15IbPEzaI27kPxlBE44B9+QSYhOL+r+LSRWzSS9eQGWrqKUdsE78FS8/cYjB8oLleTM9qVk3l+Gun8LWCaC7LAVSzv2xVnVB0dlDyT/4b3DLUPDzCQx1SRmLotl5OxqtgWIIoKkIChORKcH0RVAUJyHPZeppsjV70Ct3UZu/2ay+zZipqIAyEWVuLufgLvnibg6D0KQFMxclsz2JaQ2vktmxwqwTJzVA/APOwtP77EgSWTeX0582Ys2Tc/ls2l6J5yL6PSS3b2G6Nz/kKvbhlLahdCp1+DuPsL2Fm9DRZtii+XlRWtaW5a0nhNvDSObpPm1P5DZvhRPv/GUnHnDMSnRfpqYdBgLs9ZFMgG4vJW42we3tyuhHx3pdBpJkmhubsbv9+P3fzRxvtZoT8SPDz5KIt4eF3w6OFYBN8uyMIwD3XNRFEEQyGkGkijgciokszqSaOuHZTUNsIhnNbbvC+NWBCxBIalmqY/pmBbolopgSagGyJjEcxq6qiE7HOxsTOF1QiqnUR/NopsgWLBxRx17X/w1WsMuQmMuoevYy4iLEq1TVC2yn6ZXf09u/xZbi2TSdwqFfC1SS/itfxw0qmYZOsk1bxBdMN223+p7MkXjrkApqUYL7yO2+FlSG98BwNtvPIGR5+Oo6I6RiZNc+xbJNW+gR2oRHG48fcbhHXAKrk4DbY/rbJLMjuWkty8ju3NlodAuF3fE2bEfzg59cVT2xFHa5aCCfZu/k5rCyCaxcmksTcUy9LzyFyDKBRFX0eVDdHoPYpEVzmUaaOF95Oq2o+7fgrpvE1rDTsBCUJy4ugzB3WMk7h4nIvtLbFvT+vdJbXyX9Ma5GKkIorcI36CJ+IeehRwst8Xbls8gufZNrFwGd4+RBMdcirNDH4xkxGYYrH4jr0R/HoGTLkJ0uPPe4v8s/D1KJn+3EOeZWpbovMdJLJ+BXFRJyZSbcFX3O+jztPjIiy4fZef91L6nQKkMjbo9M54GAsCIbg4qSgIM6VhMecjD7vo4bqdMLKVTGXRhYiLLIj6HQnnQRcjrwgT2NSWoDHnIZHWCXgc+t8tWRzftWNuyLCRJys+Ki+3K6J8BPu2Y4AuViDc1NTFixAh2794NwPLly3nzzTdZu3YtHTt25L777vtQyrdfdtw3cxMPzjtg63r9+O48ungXOd087Cy4KMCP8lZmrWdcPw9Qa7fR+PJ9GIlmew74hHMOChi0aJ1thbVnHa6uwyg583sF2pLRYp91CM/w1jPNkq/YVkYdPAkkhcyO5cSXPI+6d71t5zXwdPzDpqCUdsLMJkltmkdy/Rxy+22WgaNDHzy9x+LpNQqluCOQrzLvXU92zzrUmo029SxPDxNdPpSSzigl1cihKlvh1F+G5AsheYqOmFy3hmXoGJk4ZiqCnmhCjzWgR2rRIvvQmvZixBsK+8rBCpzV/XFWD8DVZQhKqKpwj7I7lpPeupjMjuVYmorkK8E74FR8g84odLKT6+eQWPkaemQ/kr80X6Q4E8HhRm2Zodu7vs3MvSBKqPu3EJ7zELn9W3BU9aF40rdxVtpJpqVrxBY9TWzJc7ZlyZnfw9PzYFEmdf8WGl/+DUYybHuFDp/6qVaZKwNO6uKHtpaTRLsCrsgiT33z0J3sFbsjXPaw3fEG2+7sjnPa0tbbldCPDbFYjFtuuYVVq1YxZswYTNNk4MCBTJkyherq6o983vZE/PjgoyTi7XHB5wstMaem6bZFU14huiUplyQJw7BIqxqKLCKLIrppoWo5mhIZTEtgV12UpKbjcTowDfB5JHxOhb3hNNmcSV00RjpnIAsSWd3AwKA5kUURBARZZnc4yaZai5SmEn7rH6TWzcbVeTAlef0QAXBhe46LpkH90heILnjStjI947oD3XHLIr15gT2q1uIZPv5qe95ZTRNf8jzx5S9j6Tm8A04hOPoSlOKO6PEG4ktfspNNLYuz00D8w6fi6XUSiBLq3vUk180mvXURVi6D5A3h7nUSnt6jcXUahCArWKZBrmEn2d1rbEr6vs0F/RhECSXU0Y4J8t7hcqAMyVdi+4e7fIdNrj/4t7LUlD2ylgyjxxvRo3Vokf3o4Rq05r1Yeg4AweHGWdXbjgk6D8LZsS+ClL/O2q02PX3rIvRILYgy7h4n4Bt4uu3hLYpkd68lseo1MtuW2KJmfU8mMOpCHOXdMFIR4ktfJLHyNSxDy8/cX4bsL8FUU8QWPUN8+QwE2UHRuCvwj5ha+HzZPetofuMv6JFafMOmEDrl6wcV3c1c1v4erJ+Dr8tgupzzY3RPETkgBIRckAGCXtBUkEQY1LmIyiIXVSEvxQ6ZvfEcIbeM3+uk1KugSAKSKBHyuXHLErFsDh0TXTMp8iiIokSxz1Ww/zNNs83vo4VV0o5PF59FTPCFSsQBrr32WtLpNJs2bWLAgAEMGDCAQYMGMWHCBPx+P5lMBrf76H7I7YAr/29JGw/xriUe9oTTRxVku358dx6ev4PPkJF+WBjZJM0z/0Rm23u4e4yk5OwfHkRVtyyT5KrXicx9FCzLpp61erDn6t8nPPufqDUHPMNb5oqzu9cSXfAEas1GWxl15Pn4hkxGdHpQa7eRWDGD1Ob5YOg4qwfgGzIJT++xiA4XWqSW9Ob5pLcstGfFATnUAXd3m5bu6jSwYHVmaipaw05y9e+Ta2yhpe9vM9NdgCQjOn12Qi4p+RlxAUwzPw+mYuYyWLnMQYcKitOmp5d0xlHWBUd5dxyVPZG8Rfa9Mg1yddvJ7lpNZtcq1JqNYJl24NB7NN6+J+PsNAAQUGs2kFz7JunNC7F0FUeHPgRGnGOL3oiSXbBY/Czqvvy9O+li/EPORJAVm3kw7zHSG+fmqWhX4x14WmFOLluzgebX/4oerslblnzrILsVyzJtb/h5jx2kjPtpYnyvUuZvazqkiNv147vjdyuH9ApvPfN924vreGrJHpvYgE1fNy2rQEWHI8+Qt8PGDTfcQC6X484772Tjxo3s2bOHDRs24HA4uOmmmyguLj76SQ6B9kT8+OCjUtPb44LPFwzDKHTPdV1HFMV8988seJPbHXMTSbLXXd0wqG2Ok9EM9oSzmIZBVtcoC7joWBxANyxqwmkiqTTJbI6cnqM2msbQIakbxJIZJAR8XgdYMH9jmpbScnLdbMJv/gPB4aJ0yk24u4+g2gGRHIVnbKxxNw2v/5lc7VbcPUZSPPHbhRE2M5fJj6q9iCCIBE484BlupKLEljxHctXrWIZmq4OfdJGdYGaTJNfMIrFqJkas3u4ODzwD3+CJKMUdMbUsmfeXk948v1DYFhxuu9PcQksv7lgoCujROnJ12+2YoHE3WrgGPVp3CG9wAdHpQXDkdWMk2fYStyws08DSVaxc1hZgsw7mPkqBMpSS6gO+4hU9UEo6FWIkPd5Eds8asjtXkdm1CjMdA1HC1Xkwnr7j8PQeg+T2oyfDpNa/TXJtfmTNHbDF24ZNRQ6U5gsWL5Jc8yaWoeHtP4Hg2MtQQh3yzINZRBc8iZlJ4B10OqHxVyP57PXOzCaJvPsoyTVv2F3wM28oeIO3Rq5hB+GX70cN76PrKZfSd+KlaEgYlj3GUBn0oGsZfA6LaE7C6xbpFPRT5HER8ki43S56l/uojaTJmFDmc1Lic1PkceKQbeE1WVKwDI1MTkM3LTTNoMTvxu22CwKiKBQ0Flo64R92Brkdnww+i5jgC5eIq6rKunXrSCQSVFdXU19fT2NjI9u3b+fpp59m0KBBTJkyhYsvvvg4XfUXB9OX7GnjIX4sHXGAYo9COK0dfofPGJZlkVjxCpF3/4XkDlAy5aaDrKwA9HgD4Vl/J7Njue1TOek7hcTNsizbM/zdf6PH6nF1GULRhKtxVvXGsiyye9YSW/QM6p61iE4vvmFn4R8+FdlfipGKklw/uzAvLSguPL1H4+1/Cq4uQxAkGT3WQHr7EjI7lqPuWY+lqyCIOMq74ezYF0dVHxyVPW1F1FaV7TY+oakwRjqOmU3YHuKaiqVrWJYBlmUnsZLShoYmeYJIniIkfwlyoBzRW3RAqdWyMJJhO/mv3Yq6bzNq7ZZCAu+o6IGr+wg8PUYWxGQKM+8b30WP1Rfm4m27th5Yeo7UxrnEl72I1rQHKVBG4MQL8A+ZjCA7MNIxYov/S2LVa61m8S8qUM2NdIzo3P+QXPsmUqCcksnfPaTYmp5opnnmn8juWoWn9xiKz7oByeX7pL9aHxoeRSSt2YGOANw8+WAa+aFmvuFA11s4jtZmX3TceuutnHDCCYX1QFVVNm3axAMPPICu6zzwwAN4PMfu09uC9kT8+OCjJuLtccHnC62TjhZxN1VV891wA1mW0XUdRbHVvnXDLCQr8VSG3Y0pdMsklzPoWOKiIugnnMlhmibJtIaFQVI12d8YwZTc1DZH2ZcxyCQzRJJZvE6THVFojtrXIwDhpj3sf/k3aE27KTrhXKomfA1LVuyuOLaYm24aJFa8QnT+4wAEx15O4IRzC3PGbTzDPUGCYy4tFJONZIT4shdJrH4975c9gsDI83B1HQqWSWbnSpKrXyfz/nKwTBxVffAOOAVPn7HIvuK8GOoaMu8vJbNjZYGhJnlDODv2w9GhL86qXjgquhcK9mAXy41EU0E3xkhHMTN2TGCqGXtczdCxrLyPuCgjyA5EhwvB6UVy+RG9RUi+YmR/KXKgvA3t3dRUtMZd5Oq2FejperQOANETxN11mE1P7z4C0eXDVFOkt71HasO7ZHevyY+s2XPx3r7jEGSH7Siz7CW7aQF4+59KcPTFKMUdsSyT9OaFROc/jh7Z38ZbHPKx2aa5hN9+BDMdJ3DCuQTHXXFQF9yyTLIrXqHx3X+jeAL0Oe9HdOo7hJDPZqZ5nSIZQ6Rfh2Ka0ib9q9xkUiqSKDCseyUejxOnJLE3kkLEJJlSKQ+5CLgUXA4XHqeEYdje4E6HA1GAWCqDQ1YwTBMsg4DXnaeht/YDb0/AP0t8FjHBFy4Rb43169fz9NNP09DQQFVVFWPHjqW0tJRJkybR0NDQ/oU/BnzQQ7ylM7dmb5Q3N9Yf/QSfY+Tqd9A44370cA3+kecRGn/VQTYfB6hnD2MkI/iGTKZo/JWFLrqlayRWzyS26BlbUKTnKIrGXY6jogdg06DjS18gvXUxAJ4+Y/EPn4KzeoC9vWYDqfVvk9qyEEtNIbr8uHueiKfXSbi6DkV0uLF0DXX/pjwtfQNq7bZC8ivIDpSSTiglnZBDHWxqeqAsr5he1EbV/FhgmQZGOoaRDGPEGtBjdWiR2oLwS4H6JogoZV1spdROA3F1GYzkCWJZJrnabXYRYetitOa9tkpqlyF4B5yKp/cYRIcLPdZAYs0bJFe/gZmJo5R1JXDi+Xj7TUCQZIxM3KairXjFpvUNPJ2icVcU1Gst0yC59k2icx/DVFO2ivrYyw85553esojmWX/D0lRCp38T35DJn1vK1/XjbS2G1jjczHfLbzHkcXDXqxvaqegfAXPmzOHrX/861113HVdddRWdOh2w4DvjjDP405/+xMCBAz/0edsT8eODjxsTQHtc8HmBruuFefLWcWgqlcLhcJDL5XC5XFgW5DS9oCYtIBDJZNFyNq3drUjIokAyl0OWFWLJFOF4hqZ0DsOwMAyL1XsaiGR0KoJO9jRlaI6rNCcgrVJItFUgqqlE3/0XiZWvoZR1pXTazTjKuh587bEGwrMfJLN9KUpJZ4onXt+m26ru30Jk7qOoe9YhBcoIjr4E36DTESQFI5MguWom8ZWvYKaiKCWd8A+fUnBR0ZNhUhveJbXh7bwtmYCzuj+e3qNx9zzR7gZbFnpkvx0T5B1U9FbuKVKgHEdpZ+TijiihDraLSqAMyVds09KFY/+OH6CnR9DjLcrpdWhhe2RNj+wvdM1FbxHOjv1wVQ/A1XkwSnlXBEHESEZIv7+UzLb3yOxaBYZuq8D3PwXfwNPsBFvXSG9dRGLla6j7NtpjfIMnERh5bkE3J7PtPaILp6M17EQp7ULRhKtx9xhZWM9zjbtstuKedTiqelE86buFBL01jEQzsZl/IrFrFR0GnMhZ37iFrOQmqYJbAa8iU1nmo6Y5SdfiACUBFyO7lLIjHqfM7cQpWlimjCyLaGYO3bSVVSuCHnxuN7Io4HU6EbDwuBzIsoRumEQTtsWbYFk4HQ5kCRRFRpbl9kT8c4LPIib4wibiCxYs4LLLLuPaa6/la1/7Gl26dClsO/vss7nrrrs44YT2OOnDoHXgv35/jOdW1KDppm2F9fn9Gh0RppYl8s6/SK6aiVLamZIpNx3ywW2qaaILniSx4hVEh5vguMvxD5tSqISbapr48pdtj0s1hbvniQRP+grOjn0Bu1KeWPkqqbVvYaop5OJqfIMn4Rtwqq0gqmtkdq4kvWUBme1LbUqYJOOqHoir21BcnYfgqOiOIEq2QErzXrsz3XCAgmbEm/iga7WgOBFdB2xKBFmBlg66Zdkdcl3FymXyFmYHC+uJLh9KsU1DU8q64qjojqOiR0GJVI81kN2ztkBPN9MxEEScnQbi7TMGT5+xSF77M6a3LyG57i2yO1aCIODuMRL/iGk2E0AQ7K7B8pdIrJqJlcvi6TuOonGXo5QceBhm96wjPOdhtIYdODsNpHji9YcMlsxskvDsf5La8A6Oyp6UTv1Rm/N8HiGJAv+9bvQhVdCPlGi325V9dKxatYpnn32WpqYmOnbsyIQJE5AkiYsvvpi6urqPdM72RPz44OMm4u1xwecHH1RTz+XsWeNcLofT6UTTtILyumlaZLIquqGjyBKZnIVqWrglCY9TRM1p5CwTTKiPZ0irBulcjmgyi2oK+JwCNVGVplgaCYtttTF2NtlJOCaUF4NhQiIGO3VIv7+M5pl/xlSTFJ38VQIjzz/kTHV62xLCcx7CiNXj6T2GolOvQSmqLHy+7K7VROc/Qa52i52Qj7oI76AzEBUnlq6R2jSPxMpXyNVtt5lxfU/GN3gizo797HvStMceV9u6uOAVLhd3tGnpXYfi7DSwwOwyUlGblt6wg1zjLrTmvejh/TabrjUE0bYuc3oRFBei7IACNd20R9Z0FVNT86JuCTA+oLIuyshFlbajSml+ZK2qZ0E41sxlUfdvJrt7NdmdqwqjdlKwAk+vk/D2HVfwRc817ia1bjbJ9XMwM3Hkoir8w6fgGzwR0enFMg3SWxYRe++/aA07kUNVBMdejrff+MLfxMjEiS14ksSq1xGdHorGX2UX3D/wN7MsC3PTPGrf+geWoTHqvG8w/KwLcUoCsmWydncWSwTBAd1CEpWlQToV+QkoIjnTor45TnWZn0hWo3epB8XjRbEskpqJS4B0VqdHxyA+lwNZkHA6HUiiYLMNsO38TNMkm83ich1oGrRWSf+8Ngm+TPi0Y4IvbCJ+++2307VrV6655hrAfrhv2bKFhx9+mIaGBh566KGP5A/7ZUVremyLWrMiCVx8Qics7M754SAKHHWu/LNGZscKml//M0YqattajL7kkKqjuaY9ROY8THbXKuTijoQmXI271+jCw9PMJomveIXE8hmY2QTOTgMJjLoQd/cRCIKIqWVJb5pPcs0s1P2bC56a3v6n4Ok5CtHpwTJ0sjUbyL6/nMzOFWhN9r0VHG6cHfri7NAHR1UvHBU9kHwlB6jjes6uVrdQ0FIRW3wtY6ujmpqKZWgH5sVE0aahKU5ExWUrpLr9NjXdZ1uYycGKNrPWppom17iTXO121NotqDWbMBKN9unyNDRX9xG4u5+A5PZjmQZqzUZSm+aS3rwAM5tE8pXgGzwR35BJyIG8TVxzDfFlL5FcPwdMA0/fcQRHX4Kj7ECgrDXXEJn7KJlt7yH5ywid+nU8fU8+5MKVeX8ZzW/8DSMVITj6EoJjLjnIpuTTgMAHSyNHx5DqILdPG3DYGXE4tvnv9uT88Ein06xYsYJcLkfHjh1paGigqamJrVu3Mn36dEaNGsXUqVM599xzP9L52xPx44OPm4i3xwWff+i6TiaTQRRFFEVBFEVyuRzpdBpN17FEEYfDgSjKJNMpHLIDOS9yaVgWsXSW5rRGNqORUDNE1BylLhdZQSGZTpPLmazd20Bjk0FTBrwKdCkXaUybJFJQm7CtqjLpGPWzHiC9dRGODn0oPfuHhyzkmppKfNmLxN97Fss0CAyfRmD0VwrrpmVZZHesILboadT9m23rzhFT8Q87G8ltf9fU2q0kV79BatM8LC2LXFSFt/8EvP0moJTa76lF68i8v6ztuBqC7ePdsR/ODr3tcbWSTgiSkn9vEyMZQY81YCQaMZIRjEwMMxPHVNNYuYxdjDdbVNMFBFGytWQcbkSHOx8TBBG9IZuaHixH8pcWklzLNNAjteTqt6PWbkPdt5lc/XY7zhAlnB364u4+AnePkShlXW1dgHgj6c0LSG18107SRQlPz1H4hky2LU/zsVJq/dvEl72IHqk9YAXb/5QD763niK94lfjiZzBzGXxDz6Lo5CsK97X1+mukIoTf/AfprYvwd+zDyK/eRFFpR4qLHPgVkd0NWSIJCAZAV6FfJ4XuFeV0LnGSVMFQDXZFovi8DrRUmo4dgpR7A6iZDCXBIKIFUTVLtxIPPq8PSQCnouQLSQaKIufFCWUMw8AwDJT89nZ89vgsY4IvbCJ+xx13sHHjRv7617+yYcMG9u3bx/r16xEEgRtvvJGKiopP+Go/G3zcYPtIx7fe9t6O5gI9tgUC0KXEQ+diD/Naibr9r8LIJonkO6hKaWdKzvrBIcW8LMsi8/4you/+2/YL7dCXoglX4ep8gJpm5jIkV79BfPkMjEQjcnFH/MOn4ht4emG+WWveS3L9HFIb52LEG0FScHcbhqfXaNw9RhYE0fRkGHXPOrI1G1BrNtqJeQsVzB3IU9OrbWp6sKKNOuqHTT4ty8TMJm1qerwRPdZgK6RG9qM17WlLf/OX2oWBTgNwdRqIUtYFQRCxdI3snrWkt71HZtt7GKkIguLE3eskfANOw9V1qN3Zt0yyO1YSX/kK2R0rQFLwDTydwInnF9TiwRZ+iS16muTaNxEUJ8FRF+Efed5B/qVgz4xH3n7kwN/w7BtxVvX6UPfg8wC5lXf4B5PwY/EIb/cSPzLOP/98/H4/mzZtol+/fgwYMIDBgwdz2mmn4XQ6P7Z4V3sifnzwcRPxL0tc8L+Olu647aNsr3XZbBZRFElkMiiyjGHa23XTJKfl8DtdaLpOIpWmIaOTyOlYOY1IOkfOMCnxuzF0gbRu8X5zlFwmTjpn4nK4KfH72VnfRE2DSVMCGrETOLVl3vitf2JqWYrGXkbgxAsOua7qiSai854gtX4OotND4KSL8A+fVhiXsiwLde86YkueJ7tjBYLstPVSRkzDUd4NsOOG9JZFpDa8TXb3WsBCKemMp88Y3D1H4ajsUVhj1dotBRcVtXYrlmrTnhFllOKO9rhacQeUokqkYAWyrwTJV2wLtH3IrqupZW3l9PysuR6tQw/vs0fWmvYWuu6C7MRR2RNndT9c1QNxdhpgj9dZFnp4H+ntS0hvXXTAIaayJ94Bp+Ltf0ph3E+PNZBYNZPkmlmY2QSOql4ETrwQT+/RbZL/1Pq3iS6cjhFvxNV9BKFTrmlTuAfb99ttWTRueIfGOQ9jahk6jvsq/U4+n4pSCc0ChwQeCepisD0GXqDcDf2rYWDPjgiGTsZw06Xcxeb9YSRNI+QTECUnnSvKkK0sCdXEKckEfU5ciojf5UIWRVwOJd/tNgsCbC3OAO009M8XPsuY4AubiOdyOX7wgx+wYsUKRo2yLYz69evHlClT2tDR/pfxcYPtIx3/wW23Tx3AXa9uOKpQ2xcB6feXEZ71AEaiGf+IqW08qlvDnlN+i9jC6RjJMK4uQwmOu6KNP6Vl6KS3LCS+/GVytVsRFBfe/hPwDZ6Eo6p3fi7IRN23mfTmBaS3LbaTcgQclT1xdRuOu8sQ2wokP79u5jI2/ax+B1rjLrSmvWjhmgPz260gOL2ITq+tkCrb1HRBFIE8XcrQsfQcZi6NqaYxM4mDVFIF2YkcqrIT/rIueZXUnsh+OzG0Pc/32PT0nSvJ7lmHpWURFBfu7iPw9BmLu8eJhYBET4ZtKtqaWeixeiRvCN/Qs+wOQb740LJf/L3nSKx+HSwL/7CzCI6+pM0+hftsWaQ2vGN7vaspgiddfFhWw/8KJAHuPm9Q4XfnkEUuGF7dRi39R4cQd4N2L/EjYevWrUydOpWtW7cCsGTJEmbNmsX69evp1asXv/71rz/2e7Qn4scHHzcR/zLEBV8EqKqKKIoFRXVZltE0jXQ6jWmaiJKEhYhhWaRzWQTTBEFE1wySagZN1UhpFuGUitvhJpPTyBkaBhJ1iRT1zRFEh4zX6SLgcuNQJOqbU6ytSRBthqgFPhkiui3SpiQj7J79IPEtC1HKulIy+XsUdeyLDCQ+cO25hp1E5/6HzI7lSN4iAqMuxjf0zDaF41zjbttBZcM7WHoOZ8d++IZMxtNnXJt1Mr1lIektCwtuJKK3CHe34bi6DsPVeRCyP6+ZYpno4f0HXFQKiun1B6ueSwqS24/g8NiCbLIDRDmfnFtYhoFl5LByqh0XZBNY2gftNgUkf+kB5fSyrjgqe6CUdC4UKYxMAnXvOjK71pDduaIg4Oao6IGn9xg8fccVCu6WoZPZsZzkmlm2WJ0g4Ol1Ev4TzsFZPeAA6880SG+eT3ThU+jhfTgqe9lz4nmxXQc2m6EFSmQ/+9/8B+ldq3B17EeXM2+ge7dOVATB5XJj6BoNSR1TAzUJNQZUCVBdBn06eehTXYpTFImkVWRLIG0YBH0KLlFGESxCRX78TgceBVLpDG6X27Zec7txKrZIm2nayv+6riNJEoqitM+Df87wWccEX9hEHCjQm8LhMB6Ph2AwePSD/ofwcYPtIx1/qG0tnfGQx8Ff5mw9rBfyFwGmmiY67zESK19D8oUInXbtYWnQpqaSXDWT2JLnMNMxXF2GEBx9Cc7Og9rsr9ZuJbFqJulN87F0FaW0M94Bp+HtPwE5UAbkk9qGHaS3LyW7cyXq/i32QirJtkdnx34484rpUqCszflNNXWAmp4MY6SiNgUtm7TtyQrUdHteCUFEkCQ7QXe4EZ0eRHfAVk73FiMHSpEC5XYVvfX7ZJPk6neg1uWV02s2FooAclEVrm62Sqqr8+BC8GFqWTLblpDa8A6ZnSttpdROA/EPPQtPnzEFKh3YavXxJS/YHquGjnfgaRSNvazg5/5B5Jr2EH7rH7ZAS4c+lJz5/UPOjP8vomPIzf5IBgubgTKya4iluyKF7fecP4jLR3UG2mo4bNgf49nlezHynuTtHfEDWL9+PTfccAN33nkn48aNK3y3161bx80338z3vvc9pk2b9rHeoz0RPz74JMTavuhxwRcBlmUVZsRtVWn7N9qSoJumiabr5DTdHpOTZZoTcUxLIqflaIrHCfn9RGMqSd0klc6QUFVUw0Q3VfY3pigJeHDLMi6Xg7imo6omK3fGSaehMWl7R2uCTVM2LbAE2L15MfveehA9GaZi2GQqTr6KmPvQowxGzUYa5z9hO6d4iwiMvAD/0DPbFPWNTILUutkk1ryBHt6H4HDj6TMW34DTcHYacGAGOh0js2M5mR0ryO5a3Wa9dVb3w9mhL47KXjjKurQRnLUMHT3RZAuvJpowUhHMdMyOCdQUppbF0jVooaYjgGRT00XFheD0HNJNRQ6Wt32fFt2auu0F5XStcTdgISguXJ0H4e5+Au6eIwvjaJZlkat/n9SGd0htmouZiiL5ivENmohv6OTCfi2fI7XxXWLvPYse3mcLtZ18RZuxQIAgEMOmrMeWPE988X9Bkuk64WqkYWcTFEQ6FEPPchfZnIZLFPF5JfaFs9RFIJuGoB/8HhjZLUBZkR9JEPE7JZozGkGHgMfno6LIjyLoKCKYuomkKDgkgUxWozhojyQosoQkSYUueMt1tiTf7bPgnx981jHBFzoRb41kMokoih9Jdv7zimMRcfqoxx9t26UPLUb7PBqFf8JQ928h/ObfydW/j7PzYIrP+NZhkzwzlyW5eibxpS9ipCI4qnoTOPGCNpQqsBPm1Kb5pNbNtufEAWfH/nj6jsXTe3SbBchUU2T3rkfds55szUZ7nsq0hVNElw+ltAtKaSeU4mrkoirkogpkfymC0/uxHvSWnkNPNKPHG9AjtTY1vXkvuaY9GK3o6XKoKq+cPgBn58EFoRr7fmTyAnQLyWxfiqVlkfyleAecim/QGW3o52D7ecaXvkhq0zwAvANOIzj6KyihqkPfbzVFbOHTxFfMQHS4KZpwdV4R/YtbZRbzwohW/v9fdmJnOhS5C8rpqmYWtsmSyEUjqrlweHV7Ev4B/N///R/Lli1j8uTJDBw4kOrqatxuN3/5y19Yv349Dz300Mc6f3sifnzwScYEqqqiqio+n6+9M/U5RIuneGtldcuyCrR1p9OJqqoFT3LTNNkfjhLNqHgkEUVxoJsmuiWxpzGCaenURiMYOQHD1HE6nXgVGUlS8LrdNKY1lmyuxRIhmQZFAETImWCqUBQQsWTY35Rkx5wn2T7/FWSXj9KTr0IZPPEgYTAnthJ7du96YgufIrt7jW1lOnyKbWXqO+BHbFkWas0Gkutmk96yECuXQfIV4+kzFk/vMTir+x+gZVumXQhvGVfbtxkzHbVPJEo2Lb20ix0TFHdADlbmx9VChxScO1ZYlomZjuWp6fV2TBCuQWvag9a8F0vPjxM4PDg79CmMrDk79Gk1s243GlJbFpHesgA9vA9EGXfPkfgGnoG7xwkH2bEm18wqjPgp5d0Ijr7ELt5/YJ33YI8SZN9fSsOch9Gjdfj6nkzv067F5S8hAYREcHmga6ULv2xgAd2K/UTUFFv3qUiAbEH3aicn9uyEYJlkdRXBFLEsEb/bhSjbVHanQ8TncqNpKlouh9/nw+lwUhQMtEm2W1gdQJuiUjs+X/gsY4IvTSLet29fGhsbaWhoQJI++sPo84LW3a9IOveJzIhDWxGow82Pt+6WfxlgmQbJNbOIznsMU03bgiDjLi/MNB20v54juW428aUvokdrkQLlBRVQ6QPVcy1SWxAya1FFVcq75303h+Os6tNmHs3Sc+QadtoUtPodaE270Zr3HqR2LshO2/vT7Ud0+hCcbkTZCdIBajqWiWXqWFoOU8tgZW2F1JaKeRtICkqog01DK++Go6KH3ZVvdQ8sy0KP1pLduZL0+8vsGTdDQ3QH8PQeg7ffeJydB7ZZQC3TILN9KfEVr6DuWYuguPANmUxg5HkFlsAh/x5r3yI6/wnMdAzf4IkUTbj6sH+P/yV0KfawJ5w+osCbLNpBqSSJYFnopoUoCBim1ea4dkr6kfHoo4+yYMECSktLCYVCNDU1MW/ePO655x5OP/30j3Xu9kT8+OCTjAn+9re/8f3vf58ZM2Z87G5HO44/WvuOt8SthmELj7YkO6l0mnQ2h0NxYJgGWCa5nMae5jhhNUtjNIWAiaK4cBs6KgJBfwCnotCoatTsD1MbT+MUQFAkYmmd+giE42Ap4HeC1wteF+gNNcz819+I7lmPo7wbodO+2cbC7INQ928hvuR528pUlPD2Oxn/8KkH6dCYWpbM9qWkNs0ju3Mllp5DdAdswbPuJ+DqOvTgdTdWX1BM1xp22orpsYa2tHRBzDPeAraTitODoNjUdEGUbNX0AjVdw9KyeWp6EjMdw0jHDgi95iH5SvJuKnnl9MqeKCXVbdZ4U02T3buOzPvLyby/3BZ3FURcnQfi6TvedlVpJQgLoIX3kVj5Gsl1b2HlMjg7DSQ46kJc3U84bCLraNzF/rf/j9SuVbhKOtHhjOvo3Gcoug6SCboFxcUgmFDug3KfiG6aBP0OHIoDQdVwOpw4nA76VwcJ+jwEvV7C0ThZwyClGrhlgRK/B0kwcSsOW4gvm8XEJKeqVFVU4PV6C3ZkQKGQ1J6Af/7xWcUEX5pE/P777+fWW2/l4Ycf5tprr/1EzvlZ4XgIMU1fsofbX16PYVo4lSOfs+X9W7pvLegYcrMvkvlY1/F5RmuLDEFxERx9Mf4R5xxSNAxakswlxJfPQN27HkF24Ok7Dt+QyTg79j/owayF95HetpjM9qWo+zaDZSI43LiqB+DsPBBnxwE4K3sc5HUONnVNj9a1VUdNRzEycaxsyqam66pNQ7NMLCx7sRRlBNlhq6M6PYguP5K3lUJqoBy5qKKNSmrh8+Xn0tR9G8nu3Uh2z1qMeANgU+bcPUbi6X2SPeP1gWP1eCPJtW+RXPsmRqIJyV+Gf8QUfEPOLNixHHQ/LYvsjuVE3n0UrWk3zo79CZ3+zf9JMbZDQQBuntyHREbjwXk7DrmPQxa5Y9oANuyPsX5fjHX7YpiW3QEXBVs52MT+94d9NnwZVNZnzJjBvHnzkGWZb3/721iWxdKlSwmHwzQ0NDB8+HCmTp36sd+nPRE/PvgkY4J4PE4wGKRv375s2rTpEzlnO44vWrrihmEUuuQt/7+l66hpGs2RGLqu43U7SaRSRDMZYppBTUMYRBFTNfE4nSQRkSyREr8fBA0JBzXhOIJikdNMlm1vIhKHaBqyOgSc9rNV8dhq61vqLCKb57Pv3X+jxxvx9DyR4ISv4SjtfNjPoEX2k1jxCsl1s7FyGRyVPfENORNvv/EHadGYuQyZHSvIbF9CZseKAiVdKe+Gq9MgnJ0G4OzYr013vXCvdC0fE9SjJ1o5qeQL7VYuk6em58A0sCwTASEfE8i2tZnDg+jyIrqDSF7bTUXOu6nIRZUFC9PWMLJJcvs2k923EXXPOtTarWAaNkW961A8PU/E3XPUQYVzy9BIb1tCcvUbZHevBlHG03csgRPOO+IaryeaiS14kuS62UgON6VjL8c7fAolkozPAVkDhnSCjAZuGUQFyIEugCJCRZFChd+L6HDTweck4HNS5nUiywIhv5+cpoMgYpgWhqHhUmQ03UDAQhREDEPHyhcoPB5PIREXRbHdkux/AJ+HmOBLk4jncjn8fj8lJSXs27fvf/qH8VFmw4+mjn7JPxej51vcrUWgDndcy+uJjMbiHc1sqI2jfwmo6mDPJEff/TeZ95ch+YoJjr0M36CJR1QozzXsJLFqJqmN72LlMsjFHfENPB1v/1OQg+UH7W9kk2R3rSa7ew3ZPevQwzX2BlHGUd7V7kiXd8tT0zsjugPH9Tttaln08D5yTXvQGnaSq99Brn57oXMuugO4Og3E1WUIrq5DD6Kdg03dT29bTGr922R3rQbA1XUo/mFn4e456oi0uWzNRqLzHkPdux65qIqiCVfj6TP2f/p3/EFIIgzsEGR09xJmb25ge4N9bwWgIuBkcHUR103oAdCmENaSdN8+dQCRdO4jsWS+DCrrDzzwANOnT+fGG2/k9ddfZ8GCBYwcOZKbb76ZoUOHfqLv1Z6IHx98kjEBwPe+9z0eeOABXn31VaZMmfKJnbcdxxetE/EWCrvNErLncrPZLKlUyhbxTKfRLAvNtFiwfguSy4Wh59A1kJ0+Qi6FEn8AySGiyDL7muLEMhkiKZUtu5M0xG3Btvo0OEVwCPYz1+2BaBz2Y+vEpFbMILL4WSwti3fAaRSNO7yuCdid4tSGt0mseh2taTeC4rSZYwNPx9V50MGFb9MgV7ed7K7VZHavIbd/c4EOLvnLcFT1tLvSZV1QSrsgF1V+LCr60WBZJka8Ea1pL7mmXeTqd5Kr327TzQEEEUdlL1xdBtnich37HySealkWubpt9pz4xrmYmTiSvwzfkEn4hkw+ZIGhBUYmQXzJ8yRWvIJlGgSHnU3x2EsR3QG8+X1KJAgEoFMxBHxOFFFANwREEUzDxC1bGIiUB3w4JYuulaVUBd0IArgUBcswEABRtlmELkVGsCz0vPWYx+3GMHQ0TSuIr7ndbhwOxxcqNvmi4vMSE3xpEnGAX/7yl9x111089thjXHnllZ/YeT9tfNjZ8KMF2Q+8s53fzdpS6G6LAvzqvEH0qfQf8biWZHx/NMP0vJLzlwnZveuJ/n97bx4nR1Xu/79Prb1Oz5p1MlkhZGGRIGGT5cqmgggoCHj9er2yqKCIyE8QuMhFUAEvLqCi4P1evoAom+IOXBARCZAAJiGEhJBlkkwy+9ZLdVWd3x/V3elZM0lmpicz5/16Baa3qlPd1dXnc57n+Tx//b9ktq7BSEwmccz5RBf9y6CC3HdSJN9+ka6Vz5CpXw2APX0BkYOOI3LgMQOmY3vdrWS2riGz7R2chnU4O97tkT4u7Chm+ZSg73e8KmhfFkmghcuCFW0rXHBND1LQilLTXSdYHc8kg77jyY5g5byoVYnX1bxrMLqBVT0Ta8o8rKnzsacf1Ccdbdfxpkm9t5zk2y+SevcVZDaDXjaJ2OJ/IXrwyT1qyvsjs/Vt2v7+EOn3VqBFyyk/5pNBHbi+/7qh7y2WLvjEETOQwC9f2RxEwoFjD6jmypMP3CfhPBFc1j/ykY9w2WWXFdKQ0+k03//+97n33nv5yle+wuWXXz5s+1JCfGQY7jlBY2MjkyZNYsmSJQzndhUjTz5VvbjFWV6UQ2DK19nZicibu/k+K9/bTGVlFVsattOWzTIlMQknm2XetCmkM2ksXaepK8221m66Mlm27mhlYwtkfdjZBokYGBo4GRAmWGHY3hgIdQPYmWyn6+Vf07ri90gpiR1yComjP9HD96U3UkqcbWvpWvk03Wv+hnSS6NGKYE4w/1js6Qv6FdTSy+I0vBvMCxrW4TSsx23dTqFztqZjJCYV1YhXokcrgjlBOB7MCfKp6bpO8Gsig+ium8UvTk1PdQTty7qa8TqacNt34rY3FBYCILcYMGUu9pQDsKYfhD31wH4j5kGd+Ht0r/07ybdfCMasm0Ef8YNPDvqID7KA4KU66XztN3S89lukkyK28AQqP/ApjPIp6AS14jEgDUwJQ2UCDAEL6sKUR+Iksw4dXWmitk1Z1EB6PjMmVWCaBvNqytB1E+m6JJ0M7dkslvSJRyJUlyVwshlCloVtB2nppmni+z6u6+L7PqZpFurAlRAf+4yVOcGEEuJdXV1UVFQwa9Ys3nnnnf36i7InaaS7m2Q/tGwz1z2xssdrQqbGuYfX8nBuwl/8uuWbWnlsRT2PLq8n6/poAiZIMLwP+XTpthcfxGlYj56YTGLpucQOPrnfFPJism0NJNe8QPeaFwr14daUebn68COwpswb8AdJSonX2ZQzSqkPen23NeB1BO6o0tnHEgHN2JWGVj4Fo3xKrl/5DMzK6YMuNrgdO3PpdK+Q3vRmocYtMv9YogtPCIxnBjFUk1KS2byS9n/8ivSmN9DCZZQtPYf4+84otHaZyFhGUBu+O0f0PblG7Kvx4/7AT37yE1avXs0NN9zApEm7Jsbr1q3jP//zP7n99tuHrY+0EuIjw3DPCQA+97nPcd999/Hss8/yL//yL8O6bcXIk4+G5+ey+Tpy3/cLLut5U7etDY00diVpSXYRtiwqyysC8RaJ4nrQ1dVNm5OlpdMhlc7SncngmxaWNNjc3ETGhdYUeH6Q4hwyBN0pGaSvZ8DWIemBnm5i5z9+xdbX/oJEElv8QcqWnttvplgxfjZD6t1XSa55gdSG14LfzkiC8Jz3E557BOFZh6ENULoFwUJ/YU7QUp9LTW/A62jC626DfQyXaOEy9Hg1RmISZvlUjMrpuRZmM/vUefc+rsyWVcGxvftqYPgqNEJ1BxNZcDzR+ccOelwQBCM6XvsNnSt+j3RSRA48hqrjLkSrmUUCCAExE5qy4AHVwOQqmFGtUxGPUB3RaexM4usa8yZVM608TDrtgaFREw9haxqViRi2adLS3kE265DMZLAsg1gkQtS0MQwN2zQwTbPQVg+CBaB8Kno+M0Mx9hkrc4IJJcQBrrnmGm6//XYeffRRzj333GHd9lhld5Ps3hFxCIT3+UfW8fiKehzXRxOCm89aXIiS964Pn+hIKUlteI32l36Js20tWqSc+JIziB/2oSGZiGVbtpJ85yVS65YFLcuQaKE4obqDCc08BLt2EWZ13ZBTzXwnjZdsw0914me6kU4yiHy72aB/uJSBaVufGvEYWiSBFooN2X3c7WwmU/8W6S0rSW96s5CapicmE5l3JJEDju7RimUgpO+RXPsSHa8+gbP9nVzLl7OJv+/D/a6sT1Q04IKlgVv6QCJ7b1LNx3uN+I4dO7jhhhvIZDL8n//zfzj22GPxfZ/Ozk6OOOII1q5dSzg8POeZEuIjw0jMCbZv3860adM44YQTeP7554d124rRIZ+iDvRJWU+n04X7Hcehqb2LsrI4bd1JstksIdMkKyVtXWmcdIakq+M4KbISWjJQE9ERdojW9nZ2tjp0OVmiIZu2rgxpL4WGTsoVNHZncTKwow18IAl0dzTStuxROt/8C3gukQOPpuzIs7GnL9jtMflOitS7r5Fc9zLpDa/hZ7qL0r0PCRzJpy/oU1c+4Hvke0G2WyrfviyZa2nqIL3A2A4hEJqB0HM14nY4mBOEgramvVPLB9yX65DZ/g6ZLatJb/4n6fq3wMsiDJvQzEMIH3AUkQOOGnRuNEnATgnZ5i10vPobulY9G7yHBx3H9KPPw5w0m8rclCIRh7oKQTYj6cxALAqWZmDqLtGQRWXMJJl1iJtRojGTsBmiOmKgCUksHCJkSGriUYSmEQtFSWa6aAftwS0AAG9WSURBVGlpob27G891iUSjTKmqIhaJEAqF+qSeF/qcq5rw/YqxMieYcEK8ra2NiooKDj30UF5//fVx9YXZXR34YI9d9POXcbJ+H6OntQ2dBRM3XROcdNAknl2zY8I4pu8phWjussdIv7ccYVhEF55IfMkZWJPmDGkbXrKd9MbXSb0X1Ih7nY1Ari3I1AOCfqGT52JNmoVRMW1E68CKkVLidbeSbdwY1Ig3rCez/Z2CQZuwwti1CwnPeh+h2YdjVs0Y0vfLS7bT9c+n6Xz993gdjRgVUyl7/9lDyiqYiBia4JFLj+7RTvCxFfUI4Jxcq7KBsmDGu9jeHclkkrvvvptf/vKX1NbWMmvWLFavXs2RRx7JrbfeOmz7UUJ8ZBiJOQHApz71KR588EFWrFjB+973vmHfvmJ0yAshCKKU2WyWbDaL53mFXs5Zz6MzmUQiiBgmyayLh2RnUzsZx0EaYXThkhUCA41uT8N0PVzp0plxaehw6Er7JNOdeG5wgU1nfRq7HbY2Q3sKXILYsyDoax1Ec39L1+t/wM90Y089kKolZ2LOPw7TMMnu7rh8j8y2t0m/9zrpjW8EBmjSB0TgWD7lAOwp8zAnzcaqmYlmR3ezxeHDz6aDCPzODTg73iWzfR3Ojg2FNqtmzSxCMw8lPPtw7BmLBzS3LUZIn4ptK3jn70/R/d5y0E3KF3+QmUvPxqiYjklglleWgGgYyiyL2soILpLOziRdLnjpLFVlBgdOr0ETko3bWhCmRmd3hskJiwOmVhO3TWoScdLZLGURi4yTwTAMopaJl3VpbG4iFg4TjUZxXZdIJILv+1RVVWEYhmpHNg4YC3OCCSfEAa644gp+9KMf8b//+7+cdNJJw779UrCvZkvF7dBWbWsvTOofW1HPQ8s2F56nawJdgOsFrZLG7tlTepzGTXQu/y3dq59HuhnsaQcRO+x0IvOPG3KKdb41SSbXL9TZ/g5O46bCjxy6gVk+LdcvdDJGYhJ6vBo9Whm4n4fLglqw3UWjpURmM/jpDrxkB353K25nM15HI277DrKt23Fb6oNV+Rx6YjL2lAOwpx+EXbsQa/LcIS8KSCnJbFlF15t/pnvt38HLYtcdTNkRZxGe+/5RW1wYywj6fr8EcMrCyVx6wtxCi8EL7v0HTq42xDI0Hr74KIA+WTD5+8azIdtAFE/Q8/zxj3/EcRwOOuggZsyYQSQytMjSUFBCfGQYqTnB5s2bmTlzJqeffjp//OMfh337itGnOF3ddd2CINd1Hcdx0HQdKaG5I0hDbmhqwg7H6Uil8NHRfI3yRBzfdWjp6sbBRge2tHSSctJ4rkYyZ+bV3NmNrkF9O9TvhC6CmnEBdOf+DxByUmTXP0v9y0/R2bgVPVxG/OCTiR1yKpNqasn4kCJIrR4M30mR2fo2ma1vFXxj8o7qQFBWVjk9SB9PTEYvqw58Y6LlQXp5KD6k6Lb0XPx0V847pg2vqxW3szGoEW/djtu6LWiTlvulElYYa/LcoI/49AXYtQv7tGwdDLerhe6Vz9D95p/Jtu/AjldQdeiHqTzsQ9jRcsoMQIPqSgib4EsoD5uEDFg4vYrqynIa2rvYvH0nU8rj6CJIS39v504MKbB0gSfggJoKbEtnakUMNI2urk5a0ykqImF0YFpFObZts2PHDkzDwHWD+VZlZSWZTIbKykoikUgPt37F/sVYmhNMSCHe3NxMdXU1S5cu5eWXXx727ZeC4TJbKhb0uiYKgruYeZNivLuzS4nwIeKlOule9Sydb/wJt6UeYYWJzD+O2OKTsGcsHnIKeB7pZsk2b8bZuTHoI96yNfhRbN+BdDP9vkaYdmDWppuQ/+GQstBHXGbTPXuOFl6oocerMSumBCvR1XVY1TMxJ80etCZsILKt2+l+63m6V/0vbtt2hB0ltuhEYod9GKtm5h5vbyIiACHgkg/MIR42e5SVCODCXNp6b/f0iWDINhjJZHLAH9b+fpT3BSXER4aRmhMAnHPOOTzxxBOsWbOGgw46aET2oRg9itPT873H84JcCIFt22QyDjubm/B9ydamJsxIlK50CuFrxCJxGtu7sWyDjs4Uwo7gOC5OOs3GljY0T0cLR4gIcEUWTQuxvaWVd5slrW2QAaoMqKqElAeeA1ID3YCauMk7b7zGWy/8jpa1ryB9j/iMBZQv+iD2gcdCON4nSm4TbHOgY/U6m3B2vhfMCZq34LZsI9vWgJ9s6/9FmoFm2kHWmWYEPyrIoJVZzqwNr/9YvRaKBd4xFdMwK2uDecGk2RgVU/vMZwyChYWB5ouFvumr/pfUeytA+kTrDmb+Bz5EYsHR6JhMLtcxhEc6C04KysoEYSuEpWeJmBFiIZvqmElV1MTTTHTdI+n4+G6WluY2TEMDYaPpHtMSMaZPKsd1s9SUxdCFpLW1ifqWVmJ2iJrqSqpjUSKRCLquk3UcdF2nra0Ny7IwTZNYLFZoUaaE+P7LWJkTTEghDvBv//Zv/Pd//zcvvfQSRx999IjsYzTZG7Ol3mmqyze1ctcz7/DiuiYlskcAKSWZ+tV0rXyG5Nq/I50UeqyqyB31oD0W5b2376c6gn6hXS1BLViuRjzoI+4U+ogDu+rBDAthhYIacTuGHsn1DI1XoUcrBzVmGwpu+w6Sa/9O99sv4mx/BwC77hBiB3+QyPxj0UxlwLa3fOywafzun9twcx9pIWPFl32i3r2vEfl2Z+M5Tf13v/sdv/rVr4hEIiQSCWbNmsXpp5/O7NmzAfj973/PiSeeSDQ6vKmcSoiPDCM5J3jzzTc57LDD+NjHPsYTTzwxIvtQjD55QZ6PiOfbTOWdrpvb23GyLvU7d2CGQrhZD6lbCClxfEkknKA9lSTZnaHbcRCuR3vGJ+X7xE2TssoEIU3w7vYW1m3roLMbTB00C0K2RmXMprErRXMnpLOQcSEWBsuEVBoatrfStuZZmt58llTTFoRmEJn9PmIHfQBz3pHooRg6EBeQlYEYd/fg+P1sBq+zKZgTJNvxUx27asTdTDAv8HZJZaHpCMMMFu7tCJodRQ+XoUUS2LFKzHg1vh3ZozmiTs8ov5/NkN74Osm3XyS5fllhLhRdfBKTDj6FWOV0KsuCIRkSQmGoimiUhTR0I0R3dxpf6CQSESw3y7yplbSlXKrDkIiGMXXozEgQHi3tGWoTNmk84oZOVWWcqG1TFbUBH+lLtjQ04CNoaW5mcnUlNeXlJBIJbNtG0zRSqRSZTLAMEolEsG1bRcP3U8binGDCCvH33nuPOXPmsHDhQlauXDkuvlB745KcT1O98YxF3Py71aSz/URFFcOO76RJrV9G95oXSL23HDwXLVpOZO6RhOceQWjmYUM2YRlrSN/D2f4OqQ3LSa5/hezODQBYk+cSWfABoguOH7Sdi2LoiNx/8pdxLXe7OOp91JyqwnUBKJSg3Py71eM+TX3GjBn8+Mc/pr29Hdd1eeutt0gmk1xwwQXMmTOH//mf/+Gaa64Z9v0qIT4yjOScAODUU0/l6aef5sUXX+TYY48dsf0oRpe8kVteTEEQDcvT1NKC0HV2tnfSlewmHotjmyYYITq703Qn00wpD/P2tibaUh5h3SAcs8l6YORqzxE6a7c1s6O9m7BlkMn6dKd9pK7juR5N7dCaDtqfLagLEbEswmTZ2O5gINmR8tCaN1O/4jneffV5sp1NoOlEZiwmMfdIpsxfwvSZ02n3BJt3ZaGjEZjDTbYg5QSt1DyC//ee3eu5+wab5ZXnnN+dovsSBO3AfHZF5geqa48Q/C45uecIgnZiqc4mnPdW0LH+FTo2vo7MZtBCcSIHHk1k4QmEZyxG03RqbJgUh+qwQbfjgoCwpVMRFsTDNtPLy+h2fbIZhyllEZrSPmFLI5uVlFtQVR6mI+MghEGZJWjtdIjZEjSbuVPj1MRjWAboQqc7mSRsWzQ0NhGJhGloaGB23Qxs2y7UgOe9BYQQZLPZYU1XVow+Y3FOMGGFOMDVV1/NnXfeyQMPPMCnPvWpEdvPWOS6J1YWar8FMLnMpqFjoMQnxUjiZ5JBW491L5PasBzpJEHTsafNJ1R3KHbdYuyp88ds6y7pe2SbNpPesorM5pWkN/8z6HEuNOzpBxGedxSRA4/GrJha6qHul1i6KNSA7w5BEBWXUhai3v0J7omQpr5s2TJuuummQs2v4zi8++67/PWvf+Xxxx/n/vvvZ/r06SNitqOE+Mgw0nOCTZs2MWvWrHFVtqYIyKeaplKpoJd1zsjNsix838fJuiSTKdodh2jIRvo+yYwDmoFpmLiuT1tXkm7PpbmtiynlMVwXwmGTzozP1qZ2HGy2t7fT0pYkGgmRdByi4QgtbZ1saJBk3EAkT6uCQ+sqmBwPsbEtSXfapz2dJWLpdKayCM+hfsNatv/zH2xfvYzuxi0AxKomU3Pg4VjTFsKURdixGnRTIDwQGqRzofI0/UfNYzpkvUBQ784gLi/w84QAM3efJHCF748yAqGe7G4jU/8WbF1J58Y3STcG802rrJqqA44kMvdoonUHI3SDDgKxHtKgogzmTkqQSqYoj4UAl5TnUxGxmDc5QdS22N7eTZlpEo3atLQl0SwN4UN53KYiFIhnUwNHCoT0yaYcystCJKJRhBBETEFlLFLoAe44aVwpcVIpIpFI0AM+FiMajaLrOslkEtM0lRDfzxmrc4IJLcTzfcXnzp3LmjVrJoz74fJNrXzy3n+QHWRyn38nxu7ZMT6Rnktm6xpS7y0nvelNnIZ3c21FNKxJs7GmHphzTZ+DWTVjSA6kwzo+6eO2NeDsfA9nx7s429eR2bY2WDwA9LJJOYfU9xGa9b69qiNX9OTQ2gRv1rcP6bmGBjefdXAh5fzlDc0Duqfn09R1XePjS2o5N+e4Pl7o6uriox/9KAceeCD/3//3/xVSzwD+4z/+AyEEN91004jsWwnxkWGk5wQAF198MT//+c/5+9//zjHHHDOi+1KMPp7nkc1me7Q8y2azuJ6HlJB2XQzDIJvNkkplsENhJJJUKkNnxseXkHEzVMeiNLUnsU0dx4etbZ20dPs0tLST9R3KYhZdaYEhdFq7O2lolVSENbqkz/zp5Rw8vQopddbt6KQjk0XgEzIEq7d2EDLA8QWu71IVDmFk2lj9yotsXrWcrWtXkOnuBMAqqyJRO5/Y1HmISXPJVM5Gj1YQF4Jues7fDGBqGNodcLxAiPv0P8eLEYhpk2DhIANYuX8mQQ/1jmwg1jOAl+5C2/keyR0byDasI7l9LdnW7QBopk1sxkLqFh1O3cIlZMtnkEwJHB+qysDSoKkVfB3KYzC3Msb06hie51NmeaxvcoiGdeKWzfSqCOUhi1TWJR420XyXlAtTKsK0tGeIxywEPgYCXdPwgUTIxEPiZ100Xac6kUBID8sQaJDr+y0Kn3leExmGQSgUKtSBZ7NZTHNo7dsUY5OxOieY0EIc4Gtf+xp33HEHTz31FGecccaI7mus0F/f8GKmV4TZ3pbqt0WZlkt7VYwOfqY76NG9dQ3OtrfJbF9fEL0IDSMxGbOqFqN8Kkb5ZIyymsAdNVaZ6/u5Z+2/pO/hpzrxkm14nc24nU247TtwW7cHpnAtW3cZwmk6ZvVM7GkHYk9fQGjGYozE5GF+BxSVEZP2tIs3hC+eoQseueToAevC8+7p+fT01dva+fVrW3B9iaEJPnHEjEILtPFAS0sL3/nOd9ixYwfz58/nxBNP5Oijj+a8887jkEMO4frrrx+R/SohPjKMxpxgx44dTJkyhRNPPJHnnntuRPelGH2KzdvywstxHKLRKL7v09HRQUcyia0bWKEwjuviuh5l8QhdGZ901kGgkUlnaE97eFKS9SVltmDNtg4ELl2pDL7UCFsmHWmXiCHY1NRFt2YwKQSHz5qCLwRloSAiLn1Yv62Jxu4UqQwksy4h26Y8bBC1Q0xN2Di+ZGd7kqzrsXPzRta8sYzmjWto2biGrqZthePTQnHMqlqsiuloFVMwyiZhlFVjRispLy8nHI7S7QRCvTdVBvgalFuwvSsQ6aaQGF4KL9mO192C39VEsrWR9qYGvNZtOK31ZDtbCtsw45VEpx1IWd1BVM5cRFntPIRlETZ1PN9nenmYnZ1p0imPeFQjZIChG+iagRCSyoiBLwwqbYNJiRC+71MRt0A3saRHIh4mYes0tmWIhQQhy6Y8apPxPIQQtHcniVh2ob2bYRqknQxhDRzPJx4OEQ3ZRC0DTRPousDzPDzPw8kZs4XDYaSUhR7S46F0VREwFucEE16Id3R0kEgkOOSQQ3jzzTdHdF9jheWbWrngZ0F9+J6gCZg/Oc6ahs4RGplid0jp47ZuD9xRGwN31GxLPW5bQ+B83gth2mh2FGGFA1M23SwYwkkpwXeRroPvpPGdJDLTz8+zpmOU1QSu6VW1mNUzsSbNwqqZpfp8jxI1MYvGLmf3TySIoN945qIeYry4RrzYG+Kcw2v55SubC4trArDN8VUz3t7ezgsvvMDrr7/OX//6V5qbm3nf+97HPffcU5hoDTdKiI8MozEnALjwwgt5+OGHeeONNzj00ENHfH+K0pBOp9F1nY6ODnRdR0pJOp3Gsiwc10f6Ph4QDdmE7BBZzwV0HNcl47p0Zxw6upJIzSBimbzX2Ao+OL6P8H0MXWdnt4fvQXNHinjUpLIswuLactqSLral09aVYktLN8s3ttHRnSQWNmnqSlEZCzG5vIyKkElZ2MQ2dbpTKTTdorXLYe3OruB3oT3Jq+t30LR1A07jRrJNm8m21OO1bsPtaul70JqGEYqBGQ7mB4aFpgXp3JYOnufiZbN42Qx+Jkk23Y30+ia66+EyItXTKJs8g/iUmfiJWqyquYQrK5hdGcL1NaZEDRrSDhkXbA1SWY9oyKIjlcLUDDTNI27bTCsPEdIELVmXqC8xLA3dF5RFLSK2iab7WMKkPG4T0jTSXpa4bVIZCzE1EcGXPpphIjRBY2s3Ld0pDOESDUWYnIjRlUqRiIXRgGQ6Q0UsQjgUzF3ymRFCCHRdJ51O43len4i4Yvww1uYEE16IA3zhC1/gxz/+8bjqKw5wyp3P825TN9MSIS5YOrOHiduVv3ydP69uIJUzZxNAzNbpyngqHX0/JO+Y7nY0Bq7p3a34yXa8dCcyk9zlmu5lwS9yTddzrunmLtd0LVKGHilHj1dixGvQ41Wqr/d+hqEJbj5rMRcuretx/zdy3hCSIFX9k0fW8diKejJZv/C91wR8dZzUjLuuy7/9279x33334fs+tm2zadMmpk+fPqJphkqIjwyjNSdYu3YtBx10EKeddhp/+tOfRnx/itKQF+LpdLqHu3omk6W7uxszHKa8rBzwiYSC6KxEkHE9kukkbSmHTMYjZAjQDZraO0k5DjHLIhaLkMx6bGnO0OVkSWWyVMdtIpbJQVPKKQsbtKUytHRmeKuhg82NXTR2OXSm0oQMg5q4hWlZVJeHaGrPUB42aU15VMd0OjOwqbmTjlSWtmQKW0pWNQep5vmrWqUF8ZBDtqMFI9VMsr2dbHcbjY1NdHZ04Tkpstk0ws0ifA+BxDIgFgnanAojhGaFiZeVoYeieFYZWqwCIuV06DWYdoRZVRpViSjNHSkEEk2A6/tMSyRIRHWitkZTV4bt7Sm6kh4zayJYwmdLu0fMEtiWYFIiTsI2iFgGEQuSqTRSGDhZl2jYQtfARFIetUhlPUxNkM5CbWWYyngMHZfKeARdgOf5pJ0sO9u60DWBYZhUJyIYugZeFt/3KYuEiEYigCz0mvd9H13XMQwDKSWGsatTjBLh44/bbruNU045hcWLF4+JOcG+9SUaJ1x//fX8+Mc/5uKLL2bt2rXo+v4rOvLRr4eXbaK+LYiQ1reluf3Pa7F0wcOXHM3Tqxt48o1tPV4ngc6M188WFfsDQoig7VgkAVP2fwGl2DdcX3L9kysBCmJ8+aZWfv3all2CWxOcc3gt5xxey0//+i5/eWsHEJSeVETGR6bDTTfdxP/7f/+PM844g/PPPx+AWbNmlXZQijHP/Pnz+ehHP8pvf/tbnn32WT74wQ+WekiKESAvuGzbLqSsa5pGOp0mGo1ihcOYhobra/h+kLrckcpiWiYagqhp4jouzd1Z4mFJPBKmLBpGFxANmRi6RnVcQrdPMu2SdEEIn2TWpSpm0Z2WNCcdGtpTZDxJVnqELZtpFRHClk5Th4Olm0SMLCHTREu5aJrOjAqbqWUmnid5ZUMjzZ1JasskOzqCoErchETMpDJs0mFYJN0pVM+ymFZuk0xl+efmNEkCA7Y0wfwvpEPchklxm2lVMaKmRlNXlllVIZo7M3RnXRwJ7+3oJNUdiIeNzT6CJL4GhhBkXYkuJYkQ4Pm0dHjUlkfJelBheyyYXEVzMkNV1Gdnd1ClXhkxmF1TxuS4ja5JNjZ3kUx5CE0SM0DXDUKmwDY1GjvTxCMmlqWT8XwaO7spMzVClkPaDerAI5ZJWTxCdypN2NTQfAddM9B0Dd3USGfSRCNhhBCYponrugV39PzcP+8dUCzIFeODjRs3ct111/Gb3/ymYMhZ6jmBOsuAadOmcf3113PLLbfwi1/8gs997nOlHtJeUdySrL9yUseTPLainpfWNw26HVHUDkmhUOyf+BJu/M0q5k+Js2RmBS9vaMYtujD4vuTp1Q08s2YHLcldHroa0JocWhr8WKa1tZVvfetbLFiwgPPOO6/Uw1HsZ/zkJz/ht7/9LV/96ld5/fXXJ4yZ60SiWGi5rltIQ47FYoRCIZKZDL70idgmmYzD1rZObMtC+j54QV9vqWuYukZH2sfSIRGxMXQd35OAztTyKLGwRdzopMvViNgmzR0Zmrsc6luSQbmZD2ED6iriJCIGHhq6gJoyQdb3STo+O5JdWCIQxxHLIB4O0ZXKoBkmumljZdLEbUBCPKpRGTeIGTop16A64mPrNtIXZNCZNdmmuTtDMhmUUmcATQb9zyOWgZfN4kgD4Xus3NpGR8qnImphGxqmEJRHJFFbY2eXT0QH3TZJOz6a5RPSLJKuoDoeQjppGruSdKayREzBtvYOJiWiHDgpxubmLjQEtmUwrcxG1zQQkoVTE4QMkx3tnViGHrSBy3gkUxkSIZO6mgTJpIPwfaK2iRXSeK+pi5nVceLxGI6ToaM1hZtJ4+kegiiGptOV7CIRjyNwyWaz2LaN5wXBp3yLssC4TS/cpxh/XHfddQD86Ec/KvFIdrFPZ5oQ4nYhxNtCiH8KIZ4QQpQP8LyNQoiVQog3hBAjn1e2F1x33XVEIhFuueWWwmrY/sbLG5oHFOF5BHD6oikDPm5oSoQrFGMBYxjmAZ4veXlDMwBHzalCKxITnoSfvLCB9Y3dtHQHwlsAlqkV6sn3Z2655RYAvv3tbysRNYqMl3nB1KlTufzyy3nzzTd59tlnSz0cxQhjGAaapmEYBpFIBCEE0VCIsG3h+5L2VBpdE/gSOjq7iIUtLCHR0DAsk4gpMDWJ63kYukY0bBEN2WgCLN2kMh4lagS9zLd3dLOjrYMdrUkaOrsJGQaTKsqYmggRtQ3wobk7y6SEzbyaONMrYxwyvZxD6iqoidtUxQLDsrKQwcFT4sysDBO2TebUhJk/NUrMNGjqcNjQ1EVje4qtzRnak0mitk7UMiiPmkwrCzN/WphDZ8WorTQIWVBmCapiBpZt4GsC3TbwpUATki0tXWxoTRIPBdfShg4fQ0KX47G1JY0hBBXRMLGQRaeTpa0rhSslXVkPx/XI+oL2tMfUigiWZWBZFtWJMGg6DZ0ZXN/H8yGZkVimoCIeY1JFGWW2TtQUTCqzqYzaJLsylMXCTKuMUR6ziVsWUUug4eH6EksI7JhNNB6nOyNxfA/Hd0lm0vieSywWK4htz/Pwfb/w/7z4ViJ8fLJ161YefvhhTjrpJI44YuxUju3r2fY0sFhKeQjwDnDtIM89SUp52FitmwuHw1x11VVs2rSJxx57rNTD2S3LN7Vy93PrWb6ptXDfUXOqsIxgJTVkatSWhxBQ+GcZGoumJYiHTY4/oJreU1Nd7CofVigUpeWASXFqYnueIl78vZZAZyrL3c+tB+DmsxZjaEHbFq0fbTq5zB4XRm2dnZ384Ac/4KCDDuLMM88s9XAmGuNmXpBvZTNSTrqKsYWu6+i6jmVZmKZZEGxO1sFE4gsQuExOxIhHbKZWJaiMWHiuhxTgaxbRUIh4xM6ZgUkMXSNuBrXJiViISWUhklmftCvQdJ+0Ezixb21u490dHWxs7CbjekwqD7GlIwOu5IiZlVTEQnSmXDrTHpmsR0OnQ9L1qU6EKY+EWTA1wUFTE0ytiFIZCzNnShxXQrcLUgPTDBYa5k6KYugaNXGLRdMT1ERsjp03iY8vqWV6dQJfSpo6sqQcFyfr0ZnK0N6dpSMFTsalMyOojlrMqDapskFoGrYmqIibeFISC1nUlYeZOzlGXSJKTTRM2DSIR0PUVsSYUxOlOmozuypGTTwEwiXrebSlszR1pWnvTuNkfapjFqb0iYRD6LrB5tY02zrStKQ8HDcLuk46maWpI4WlmaSyLjEdyqIhTCkoi4apqozjZT10TaM6UY5pGOiahmmaeJ5HJpMplCSYpqkWbMc5N954I8CItSjbW/YpNV1K+Zeimy8DH9+34ZSWr3/963zrW9/iG9/4Bp/4xCdKPZwBKU5Bt4yeDsfHH1DDjo4057+/rkdtaL5d0U1PrcZxfXRNcOnxc1i9vYO/r2/Cl0EqqwqGKxRjg73uTiDo8UX+2d82IKFwrXjk0qN5eUMznaksP3lhQ4+XHlJbvt+LcIDvfve7uK7LN77xDTW5GmXG07ygqqqKz372s9x///0sW7aMpUuXlnpIihFGCNHDRdtxHKyccKvORczjsRi6ruF5HtFwiKlCkPUFHd0OTtYlnTGIR3Qipo6uW/jSpz3lMqMyTiqTYXtbBsMQSCQV6LjSZ0dLJ5pl4gto7MwQNXWiuoluQU0iTFk0xAE1MVZsakUIjVhYI50J+mFPr4piahFaUj7TdZ8yy6Ap5WCZBrOrdIQmiFkmtq4RNXQ0BDXxMJWxCLqRpbYyStgUpD0P0IiFs2giKFEqMw060s6u3xRdw7YspoR1NmY66cz4mLqgpcsjHjZIZ110odGZlqD5lEdDpD0fU9epjFuYQqMyarOtpZvWtIepmYRMjVTGpSxsEjI0NM1AR6OmPMT2jhSGkUZoGoYPQpeEdAvQCIV0LDSiYZu27i6ynkvI1CiPRki5WaTrYeoauqZTFitD1wW2bRW8APLmXOo3Yvyzc+dO7r//fpYuXcrxxx9f6uH0YDjzLz4L/HGAxyTwFyHEciHEJcO4z2ElGo1y5ZVXsm7dOv7nf/6n1MMp0Dv6XZyCnnV9Xt7QXGhJ9pe3dvBmfTs3/XZV4flLZlbwxZPmsXpbe6FlmedL7v3bBkKmHjhKokS4QjEukEF2Sx4vt8jmZINrxZKZFTy5op57/7aBmNXTmPLE+ZNGebDDz7Zt27jtttuora3loosuKvVwJjr7/bzghhtuAOCzn/3sflu2ptgzNE3DdV2SySRdXV2BoZemkSiLEw6FCvXEQggitokQOsm0Qyrj4AroTGXwpURqOhpgGiYVMZt4xCZih3jfzCrqqhPMrE6wcHoZZVGTsG3S1J1F1wRpx2FHZwaBJBG2SLsetq0TDZnMqI6S8SRtXVm6spLWpIvng+tpRMMmk8siHDW3ikNqKzhmXjXTKyKELYNplWFCtsn2jgzxkI5haXjA/OnllEcMykMGh84ow9AgZpuUWTrTK2McMK2C6qhOJAS6EMR1l6qIxuzqMLWVEQ6dkeDAqeWYhqAqGiYRssn4PtGwzszKKGVhg8W15Rw1u5qZiTDRsEUkZDK1KsqCaTEWTa+gMhxiSiLCpPIIoJGIGEhd4PgQsU1cH6IhC9Mwiesm8aiBLyWO45F1JalMBunpCDRcdKZWJZhaFqU6HqGmogLfdXGcNJZlFj5fwzAKHgEj6ZatGBtcfvnlwK6o+FhitxFxIcQzQH9Fxd+QUv4m95xvAC7w4ACbOU5KuVUIMQl4WgjxtpTyhQH2dwlwCUBdXV1/TxlRbr75Zu6//36uvPJKLrjggpJ/QfuLfh81pwpD18i6Ppom2NqW4vEV9WSL+oJnPVmYdOfpLbR9CU+/tWNYV2MUCkVpESJYfHtlY2uP+32CNPVT7nyedY1Bv/guZ1enhPFi0nb99dfjeR4PPPCAinSMEKM5Lyj1nGDWrFlceeWV3HXXXTz44IP867/+66iPQTH6OI6DZVmF6HgoFAKCFG9d13KLMiKITls6IcMmHNIwNJ22pEsynSFs2yA0hAAz1wK0PGaDsCnPutSWh0llsrSlssSnJLD0LirLbGrCJqZp4Uifne1pKqNhNClZvb0DXYNDp8fZ0NJFJiuxdIOulIvj+9RVRNCFxtQyiwOnlNOWqmBdYxLLELi+5OX1jSQdj6pYGEGwiDA9bpN2XBzp4/lBtDrjSbysR2XcZrMURKwQx9dG6Ug5RE2dmdUxysMW08pjbGpJkfFcwpZgWrmFbVnUtyVp7UoTNTUs3WRyzKauKkpr2iGVziJ9qImGSLk+adejqixEddgkIyUm0J3x8XwfS9eJmgazquLUVklcV2IZPhlXQxcu5WVhXFci8SmLBZ9R1vNwXTeYu6czWKaFRBKPRQu/B0KIoG+6tav0S/1WjF9WrVrFr3/9a04++WQ+/OEPl3o4fdjnPuJCiM8AlwIflFImh/D8m4AuKeUdu3vuaPUM7c1dd93FV77yFX7wgx9wxRVXjPr+i7n7ufXc+Ze1+Lko11WnzueoOVVccO8/cLzgs9NE0DdYEghwAEMXnHfEDM49vLYgxpdvauUTP36Joa7pl4dN2lLZ3T9RoVCUFEODonU4LEMrZL/0fJ7A82WfRTlBIOA/eug0Dpgc56g5VftlinpjYyOTJk3i6KOP5qWXXirZOCZ6H/GRmheUak7Q1dVFZWUls2fP5u2331aT9gmA4zhkMhlc1yUUChXqx2GXaPM8n5STJZv1cbIZOlIOnqYT1nXskEnUsojaBp7nk/UkWelj56OwukAAHaks21uTICT1LSnKoyabm7qxDIFt6NRWhqkqC/PGpjZsU0cXAseT2LpkQ0sXXUkXy9CZWxMnYhuYOsyqjBG2DBCwoyvDzrYk/7tmB51Jh3jYRGgQCVnMrIwifUlTt8OUcpt1DR0Yuk5V3Kap06G2PITvw5tbWqmOh3hnRxvt3Vlm1MSYkrA5bk41jR0OXZkUtRVlbGtLs70jScQ2QHrousbksgjd6SzxcIiYbTC1zCIesfF8CJsaPkHrNAiyDBq70viALwWVEQtDE3RlXFzXQ2pQFbFp6kyDEOiGRtbJUhUxae12SLseFRGLkGViW0FrMifrEg5Z2LZditNIMQY455xzeOKJJ1i3bh3z5pWmve9gc4J9dU0/HbgG+OhAP7ZCiKgQIp7/GzgVWLUv+x1pvvCFL5BIJLj11lsL7Q1KRbEBm2kEbsaPragvCG4IItueL/nEETO4aGkdpy6cjAb88pXNXPTzl3ukqN9y9sE90lY1Bj4JlAhXKPYPtF7CoD8RDkHLsnioZzp6zNKRBNeRJ9/Yxu1/XtvjujEU+jOPLAV5E5ZvfvObJR3HRGY8zgtisRhf/vKXeeedd3jqqadKPRzFKGBZFpZlUV5ejmVZPaKpEKQ3e57E83186ZPKeEQsg4ipE7ZNQoaJpWuF6KsmwBR6YdFTy6W3R8yg1Vfa9TA0QUuXy7SERUXUxDKDFOqQoaMJgSEEhiFwsi7bWpOs2dpJ1vWYmgiTcX0aOzI0tKdpS2dAQMjS0QS8vqWV+tYUnVkPXRdMSoQxNEHG89jYmsTSoTxsE7YshAZSCqaXh5hRGaUyZnHKgsnMqYlRFjKZURlBk4J3t3fj+oJDZpRz7IFTMUyN2uoIB00r47gDJ7N07hRilk1rtwMI5tREmVEVxrZNLMtA0wW6JgibgTO7lMF7GjINqmJhopYJUuL4kPY8gur1ILIfDxnoQsPP+sRtA6GZVMQjVEXDWKaFQBQM2OKxCIZhsK9BR8X+ycaNG3niiSf40Ic+VDIRvjv2NSv5R0CcIK3sDSHETwCEENOEEH/IPWcy8KIQ4k3gFeD3Uso/7eN+RxTLsrjmmmtoaGgoea+5JTMrePBzR3HVqfN58HNHAfDo8voeES0tJ9LPPbyWb519MIfOKMf1ZY8a8jwXLq3jV5cdw9dOm8+tZx/MJ5fWYQxHnySFQlEyHG9okwxNF8yujvW4rzreN1KQrycfCvnymTv/sucCfjjZuHEj99xzD4cddhinnHJKScagAMbpvCDvnP6Vr3wF13VLPBrFaJCPogohCv+klAVRp+sCXTPwfC8QvraNDrieD1IitMD4zfclAnC8wCFcAL7v47gSF41Y2MAyTGbUxCgLaaR8iIdDTCqLYpsGrgdLZlUhhMDzIBExSGY9qmIhEBoZN4svYEq5TVsywx/+uY1X32uktduhudtBk3DQtASxsEXSk9RWxJheHiZimegCOjIe7+zoYkrC5pi5NRwyI8G8yXHa0x5djofQNQ6eUc6Rs2uoiUcQwmdBbZzJZTadmSwdKZe12zppaEuRcXwyjhfU0GuCKeVRplXG6Mp4+FIj6fi0dTuEDA3T0INFCU0gRNA61zY1Mq5PyBDYpgFSErYMIiETy9CQQCIWoboszKTyKJWJWOG1wf99DEPHskw0bdfiiRLiE5MvfvGLwNhenN9X1/R+lxeklNuAD+f+3gAcui/7KQVXXXUVd911F1deeSWXXHIJ4XC4ZGNZMrOikCZ693Prg4s8QTrpKQsnc+iM8h6ppMU15Lo+cE/g+VPitCadHrXlCoVi/CJ9yT/r24Hg+mEaGnWVETY29wxcCk0MeN3Id2HIX3P6M48sRVr7l770JSAoLVKUjvE6L0gkEtx4443cfPPN3H333Xz5y18u9ZAUo0BegAcR8CBDsrjPdNjSkZ6JnhOTQmgkomEEkMkGaeMAnpTEwxZO1qetK00Wn7Tj4XkSiUD6krZuByk0DD3w9IiEDFwZXKfjIZOlc6vpdlze3NSCT9C3vLE9zfypUSqjBltautjckmJaeYS3GrqYVBZGIOl0PLrSWRIhg6NmVzJzUhn/3NLK5uYkjutz0NQEhq4xpyrK5PIwQgiautJkXEncNuh2XDpTWY6cW00sYpBxYsybnMAwDVKuT0c6S9p30TMCX/osmKZTGbXJAqmMi+N52IbA1AWJsIUrJbGQie/LXKlUUFppmTqmFUSvPV/mxLyG5kkyWY+QrhM2dTRNYFt6wTwxZBtkXR/bMhD4mKZZWDDpncmgmDg8++yz/OEPf+DMM8/k/e9/f6mHMyD7JMTHM6FQiB/+8Id88pOf5Dvf+c6Y6TuXT1XPuj6moXHpCXP7n/TmVv983+fxFfVAIOgfWraZG3+zCl9KLEPjM0fPUm7pCsUEobhFoSS4Pry4vqnP8/7loEn9XlcGMo8sviYNJOBHkrVr1/LUU0/xoQ99iBNOOGHU96+YGFx//fXcdddd3HrrrVx++eWFmmHFxKD48w4irBIpwTKNoPuM0LANicgJQFMTaJrAz4lN35c4rocrPTwfWjozREMGUdvEEB62oeMLA01KDDPYV8wyKAtbaJqgPZ3B9SRTK6Jsau5icjzCnOoo8yeXE49YOI7H9kiG6rjNjo4MHRmXsGlw1Jwq1jd2ELV0ZlbHkb7ksBnlxEMmO9pT+AKSjothaCQigYGZ47psb0vRkXII2wbRkEHY0jlqdg0hU8eTgTi2NQ1daFi6Tsr1MJAgg/7pMyuitHRniJg6vi/Jej6dGUkiZOUWN6A7ncXxfUxdQ7hBurqUkqwvsU2dMBJHCkK2iW3pSEluQUTgebuCSIau5RZIgkWS3lkMSohPPG6++WYAfvzjH5d4JIOjhPggfOITn+BrX/sad9xxB1//+tcLrpmlJJ+qXhyR6s3LG5pxc4ZMrg8PLtvMr5fXc+KBNTy7Zgf5LFbH9Vm9vaN322GFQjEOOaAmyqbWVI/68f6SYQwNLjthbr/b6C/6/cWT5g16TcpH0CsiFq1JZ0SM4P7jP/4DgNtvv31Yt6tQFGOaJtdeey3XXnstP//5z7n00ktLPSRFCdF1HSklhqHjuoG4Bp+s6xGyDBAaaccNepN7EpkT5h0ph660R2c6Gzh6hwwitknUNhGahgAQELYMDE3sMofzJbomKAsbTC4LEzJFIRIfswyOmldD0oNNzV3EQxaWbtCdcalJRBC6TlXMwpcCywwWTqcmQoRtjdbuLNGciZvr+Ri6RnU8jGnotHWlCNsW8ZCBmVuI0DVBMumwfmcnOzvTWLogbBnEbZNJCRvbNglbBpYEUxNkfYknfSxDx8lKwlYglrOeJOP7GJqG40oM4ZN2fbrSWVwJlWELXddAgC8l6YyLoQeRcEPXe5QKCJF3ss+ZGGtaIXNBifCJxyuvvMILL7zARRddxPTp00s9nEFRQnwQNE3j29/+NhdddBFXXHEFP/vZz4Zt28s3tfLYinoEcE6Rs/lgzy+e6A72/KPmVGFookfdqOP6/OWtHX2eWxW10AQMscRUoVDsp2xpS3HCgTU8XXQd6L0IN68mync+fuiA15eBot8DXZOKI+i+DPZnm0EkfbjE+N///nceeeQRzjzzTBYtWjQs21QoBuLLX/4y3/ve97jiiiu46KKLiMViu3+RYr8mqPMOVi3zvad7Ph6kUJu6QMrAoM00A7d0XYjgNbqPrglMX2AgiIcDR3UdsAyDymiIrOcjRWBcZukCU9eQUhZaqJWFTdqSWTJZj1lVETRdI+l61JRFiYRMWrrSHDi5jIqoBVIyvTJCZyqLEGAbGiFTx81KqmM2Egm+pGurg+v61NbEMIRGNifEARJhi7JQ0MLX9XxS2VxqPvBeU1eutlvgepJZ1TFCpk7E0ikLm+i6RsZx0XUNoUnaUx6mFlz/9aLUfkPTcmajQRS9NemgCQ3h+3RnspRHLAzTwPclyYyDaeggBFnXxzK14L3PCW0pCfanhPeE56KLLgLg1ltvLfFIdo9y6doNF154IUcccQQ///nP2bGjr5DdG5ZvauWCe//BQ8s28+CyzVzws8ENjvbUDGnJzApOnD9pt+PIuyQrEb5/UROz2F9+ZoRgvxnreCfr+kzqZcx21mHTsIwgAmMZWh8R3tsNfcnMCm48YxHHzKvmxjMW7VZMF0fQIRD9vQ0k95VrrrkGgLvvvnvYtqlQDEQ4HOa//uu/yGazfOc73yn1cBSjgO/7hQhrXpAXiz9N08hlZENuwdHzPHzfC2rEpQSCaLhhakTCJvGQTTxkMq2qjCnlUSxTJxoysXSBrYOGxHGyhbp0z/PQkVRFTCaVhSiL2hi6RljX8HNR4bTj4kmfWMjAcSXdmSzRkMGcmhgLppZRZptMqwhj6hq2obOjM83G1iRtyQz/2NBIxvUK4t/1/F1+RCJwK9eEKLTAFFKyrS1JY1eGbsejMmozrTxMXWWUeChIbxcEkWxNCMpy6e0x2yiYqJm6wNR1fAlR28AydTRNI+N6gYmbLtC1wKBN5BzWZa53u57LLrAsAzOXxq/roocxW7GxnmLi8Jvf/Ib169dz9dVXU1dXV+rh7BYlxIfALbfcAsDll18+LNt7eUNzj/Zju5uY9pcOOhjLN7Xy/Nqdhdu9hVDEUnVt+zPN3Q5lkb6r8r0/ZyHg+AOqMTSBViI1/P6ZFRxSmxj27QpgSpnqCzpUBEGk4E8rt/e4/8k3tnHigTVcuLSOhy8+qo8I770AuHxTKzf/bjV/X9/Ezb9bvdtFwXwEPX/+aTCsdeRPPvkkL730EhdccAEzZswYlm0qFLvjwgsvZPr06dxyyy00NDSUejiKESaf4pwX4XnyNcjBcwS6BoahYVk6uq5hWWZw29SwLZ2k45POSmzDIKRLKqIh4mETz5O5dPUshhaIz+KoblfKobk7Q3vKxfMlpq4TNjRMTWNSeQQpJU7WDVLbhUCTMKMywqR4iEnxEEIIIpZBZcwOeo0bGoausaUlSVNnBts0yLqSeCgwnXOyblDXLiVZN8gG0AhatWW9ILI/oypKxNQos00yrs+W5iRCiEI0HSBk6uhasBofCZmYulYQ4fn3LxYyKI+YRGwDIQTVUYuIZRDNpbpbpo5lBP/CIQtDFxi6KES+8//yt/P14MViXDFxcByHL3zhCwBcd911JR7N0FCp6UPgtNNO4yMf+QiPPvoo69at44ADDtin7R01pwpT35U6PtDEtLi2ck/MkPI14hBMwEOmRiq76wck6ZS2N7pi3/AltCd7ts85ZeFkauI2j7y6mbx/yVmHTuOAyXFOXzyV1dvaeXDZ5h6v0QgilLv7mdIHKF3QBHxwwWSeX7sT15OYhsahtQle3bhLnC3f3DbgD6GhBcfiFz08VL+Cg6bEWdPQOYRnKiC4BmSyPs39FIU//dYObFPjL6sbaOpymF4e4sWvf7DHAqCT9bnrmXeYURnZI4f0Yk+L4a4Rl1Jy9dVXY1kW99xzzz5vT6EYKkIIHnzwQU488US+8Y1vcN9995V6SIoRJB8JL647hnxUfJc7d15k5p+Tvz9oO+Yjpcy17LKwDQ3LDGrMm7vTpLMenoTykEk8ZBS27/kSx4OQGbiXG64gamtEQhaaHkSOHddDsw0MzWd6IoKXa/ml58YT7BsQMldTLkhnPTQNsj40tKeZXhEmYkBb0iHluNhGkGLu+X4QCdc18oceRKNhekWccHeG1pRD1DZoTTqEzcDULT/+sLV7mVG86GBZBpMMHc/3CwsSelEgQcv1HNeLaufz73XvbeVvK7O2icPdd9/Ntm3buPPOO6moGP3uLXuDGMurRUcccYR87bXXSj0MANasWcPChQs555xzeOyxx/Z5e7urEe/tTnzjGYtYva0dCZx7eC1AH3Okb/9hDX9a3cBhM8r50+qGoH2ZJgo9xRXjk1Cu5vaxFfU81FtsiyDl+PgDavp4BMybFGPDzi5217zuyFkVvLKx/8jnRUvrWDQtwR9XbedDi6cCcN0TKwuP54V1/icwfxrqAs48dBpPvbmrNOKw2gSnLJpCRcTixt+sxPWD5y2Z2Xf/ymBwZKktD/H9Cw4PrkFZH5/gXDJ0DXKtZUxjeGu995QnnniCc845h2uvvXbM1YEJIZZLKY8o9TjGG2NpTgBw9NFH8/LLL7Nz505qampKPRzFGKJY/HmeR1faoSPjEjGDyHA0l6KdzbpsbO4CEfTUNjWNuqpIoS+2L6EjnUVKaOpMYxkaFRGLsrBJRypL1vcxBZRF7KCWXQYmcnkRLqXEdf3gd1iArgVR6aTjsrklSUtnhraUw5zqKDOqYnRlXEKGRlfKJRExCZlari3bLkGbn092pLNsaOzCcT1CpkFZKEh9D5v6HgvfoBZ+l5jW9jCNrz8tU7wQohj/uK7LtGnT8DyPhoYGTNMs9ZAKDDYnUBHxIbJgwQJOO+00Hn/8cZYvX86SJUv2aXu7M1zrnY6+als7j6+ox3F9Hl1eD1Li+rLQQujp1Q385IUNAGxsTvKxw4Jo6Jtb2nqYM/UmYmoks6qP+EgzUFR5Xzl14eRCC7t8m7pi8ufPhqbuPo+t39lV+FsAc2uibGxJ4hYN1NAFm1uSfV6b55evbg5q44BXN7ZwzuG1PUSyJBBwWq7GrHhcv31zW48Fojfq2znv/XXMnxLnvPfXFRapHl9R30eIKxG+78RtnYwn8Ty/z7m5tS1diGbf9cw7/H19E74MIiufPLKOaeXhEXE/HyrZbLbQx/naa68tyRgUihtuuIGPfOQjfPnLX+ahhx4q9XAUo0zexK13pBx6Rma701myHkRNHcf1SIRNhADX9fERaFInlXXRDY1EyCDrBaLUMDRMXSNi6ezsTBMLGcRCJllX4noSXdcIWQaZrEvW9TB0rdC3vBjXkwgNfA80M1gatw2duG2gAZMTIWorIrkVc4kPmGYQsdf1XWnemqaRdjyyubT1sKFTE7Wpb0uRynpUR0yEDMyBNUFgrLaH7Eor3zPxLIQomNqp3uETk//6r/+isbGRb3/722NKhO8OVSO+B3zve98D4POf//yI7ytfW6mLIHVdQA9hnvVkj/TQP63uWaf2xpY2jppTxfNrdw4qWk5dNGVEj0MRsKciXNuNydmsqgi3nn0w9376iIIYOufw2n5rwYWAjc27hHifWnICJ9PvfPxQHrnkaC5cWsepCydz5KwKpISGjsyA4/D84Nj83I+vgMJKfH7bx86r5uazFmMZuy43mqDfLI1HXt3MBff+g4eXbebXr20BYGfnwPtX7B0acNmJ83j44qM4/8g6or18I+Ihg+WbWlkys4IrTz6wx7XonMNr+eJJ80omwgF+8IMfsGXLFm655Rbi8XjJxqGY2Hz4wx/mAx/4AA8//DCvv/56qYejGGXyIjydTuM4TsFYrT+C+nGjIBTzNdgANYkQ1fEQZbZB2DRIZT2yvk93JqgJD5kG1dEQpq7jejKYH+R+ZqWU6JqGYRgD9rXXdYEmRMHgLBiPYFp5mNk1MebUxLBNHUvXiFqBQ3nUMgIjTyHQciniGdej2/EwNIGRM1Eri1gcPD3BobXlhHJR/nwkf28ybvclS1fTlGP6RKW7u5trrrmGKVOmcOWVV5Z6OHuEEuJ7wMKFC/n3f/93Xn31VZ5//vkR3Vc+GnXVqfN58HNHcc7htT0mw6YuCn8fNaeK03sJ6sNmlHPzU6t7tDArRgi47Pg5HDBZTWJLSdzu+8P5gQOqueVjB2Pq/f+gGJrgzvMO48KlPd0gl8ys4OQFk/s8XyLwvF2eAcVbNXXBBUvrCinGS2ZWcOvZB3PpCXNZsbkNr0gtz6qK8P5ZA4uv/HY/d9zswCCOQOBfefKBBTOwUxdOZl5NlCUzK+jv8BzXx/EkEnA8yU/++i5/faexzz4Ug1MTs3rcLl6gEYBlBteNLz+8goeWbaa7l29EZ9otGLT1vhaVUoBDkH522223UVlZWXBMVyhKRb6t6c0331zikShKged5QSq4rg8oxKMhE8+HtONRFg6uzfkIuqEF9dfxsElNPETIDhJV86Znfk6Yhi2deMggZOgkIhamoWMbgRi2Da3HAngx+TRvTYiCq3jxY2ZRu6/APM2kKmYTDfVMmE05Hp1pj0zWpTXp4PpB6aNtaHhSkvX8YBHBcdnZkaaxK0PWG3q2ZWGceSGv6ikVe8Btt90GwI9+9CNse/8y8lWp6XvIzTffzC9+8Qs++9nPsn79+j7pSMNJ7/T1vOlR3qytv77i+RrxP6zcPrAIB05eMJl42KQiYqHn2lEohgcNOO6Aal5Y19Tz/n6iwAdNLWPFptZCxFwT8KHFU5k/JZ5b8u75AiECoQtw93Pre6QHL9/USnXcxtRFD1f+YMU8l+4len7WnzhiBuceXltw4s9v6+UNzYUJAOwS/4+vqO9hxlaMBB5cthldE1x83OzC+VXs8v+/b+/A9YHGbnQN9Nx4BMECxIvre75n7zV172qhkv+POlV3S2OXU/hb1wRnHjKV5m6HRVPL+OOqBra3p7jg3n/0uUYIEXTakUA66/P4ivrC9aXUAjzPddddR3NzM/fcc89+lX6mGJ/Mnz+fc889l8cee4znnnuOk046qdRDUowSmqbheV4hijtwRFqnuiwMBL/HeeM2w9ADt3U98N3Ii2k91ydbz7ULg/6NzyxDp+eSa0+ynk9nKoumCeK2ga73na+6nk9HOjB/LQsZhQUAKSVZLxiTrgkynh+0VjNMkhmPiKVjGnrOgE7LjRsa2rO57DdBWzJLzPYRQsPI9UUfjHxaesaVSILjt03V5UcxOJs2beJb3/oWCxcu5Oyzzy71cPYYZda2F9x0001885vf5Je//CXnn39+SceSd1YvFmR3P7eeO/68to9e0XIixtCDvCbX8zF0LVenpBguBHDErApe29Qa9BgVMDURImwZvLuzq8fnYuoCNxcBzmPpgk8cMYOHX9lcEO7TK8I0tKXwZfCa/OeX9wgACuZ+hiY4cf4knl+7s2CqlTf7e3jZ5h6f9ccOm8afVjfkaroEN5+1mAuX1vUwC+x9/wU/e5ms6xfSz/pDE3DJB+bw8xffw5cSQ9eoqwizvrFnivwpCyfz7JodeLJ/jf3+WRWFyPxA6eyKoaEVieyhYumChy85usdiT+/rzWjS0dFBIpHgoIMO4q233hqzaYjKrG1kGKtzgqamJmpqajj88MNZvnx5qYejGGXykfCBhHgxvi9xvKCMy9CD1o7pbL5fdxDdhuC3ThP7Vue8rS2F50t8KYnZBlWxvpHChvYU6WwwftvUmZoII6WktdshmfXQBFTFbKQv6XJckIKwqfeJmEMg3hs7M0ELNNcn63mYmo7je0RMI4jo78ZF3fMlTs5o2PN97NxiRfE+imvWFYpPf/rTPPDAA/z1r3/l+OOPL/Vw+kWZtQ0z11xzDTfffDOXX345Z511FqFQqCTjKBZLeYGST/ftb7J9yPTAlfr+FzfQ3J0Fghpzxe6JWoEL6fTyMG83dO5WzBRHjX0ZmF/1R28RDpDN3Wdou1rcNbSng5qr3OOQS992g7ZSdUVtpTxfcuiMci49YW4P0XT3c+v7LLi8vKG5MAnwpeTG36xi/pQ4S2ZWcOMZiwpu6Pk0+CUzK3j44iAz480tbX2c2IuP+acvbCgcm+P6PUQ4BH1XJbvq5/t7T4vT43uL8PKwSVsq2+/+FX0ZyiJG7+yYrCd5LBcV793JoRRp6l/5ylcA+OY3vzlmRbhi4lFdXc1nP/tZ7r//fn7961/ziU98otRDUowiuq4XjNugp0DMC8d8bbjr+QWBncm6mLoWRMd1LSeag+vwAJVpQyIvVIN0cR1PBuK2PzK5xfv8fALA9SWprEfEMnBcn2TGpSJqFyLqA0W2hRCURyw601k0IGabpLMelh6Yvjmez+5my8GCvCTfsr33Zb7YjV61JVO88cYbPPDAAyxdunTMivDdoZaT9oJIJMIPf/hDmpqaCgZupaDYWT2PJ3cJm8qIyZSyXSugb9S387MX3i2IcKDQ0kIxON2OR0faZU0/IlwQfJE0sas39+4QBNHG/lp0mLrg3MNr+cQRM3a1/cqll+sieNw0tEKE+MV1Tfz6tS0YutbDN2DJzIoeplpHzanC6PWNb+xlhOb7kpc3NLN8Uys3PbWav61r4obfrOrRFi2/3UtPmBtkVwzAYO+Drgn+86zFTIoPXsszWMnEzKrIoK/Nc+SsCnWhGwKWofV5vyXw6PL6QiS8dw/x0WTr1q3cf//9HHvssUroKMYc3/ve94hEIlx11VUFQaaYGAzUJqu4t/Wuvyl0oMgbqMlcS0gh6NdwdU/In3tSSsrDJqmsh+P6JML9l/HYhsZ7LUnea+ouzAX1XL12xvVwfb9gtGrq2m7Tyy1DoypmU1MWwjQ0pAwWBKRPD8PWgdC0wNFd18A2erZBK46GKxGuAPja174GwP/9v/+3xCPZe9T8dC+57LLLmDJlCrfddhvd3X1bQ40Uyze1cvdz61m+qbXgrD4QHWmXdK9V0LaU2+d55RFVZ7kvnHXYNL562nw+eWRdv25i/bmUH3dANTd9dHGQK1zEobWJQirwOYfXYpuBuLYMjZvPWsxVp87n4UuO5uGLj+LYedWFdG7Pl3x8Se2Ahlp5IfW54+YUfmz7S/XOm3jlW+WR2/YNv1nF8k09a8OXzKwouKxftLSOy46fg64JBMFigaUHhm26JjA0Csdx4dI6fnVp8LpzDq/tIeZ7/8YP9DOrCzg655UwGIEjvL5HpRe6YJ+iEfsTgsC08WunzeemMxf1+xzPC0R3RcRCE8Fnml/sGU2++tWvAkGLEjUBU4w1EokE11xzDfX19dx3332lHo5iFMm3zsq3zyqmd/mnoWsYubrrvCFr2NKxDA3bGF7X70TEYnp5mNqKCBG7/wRYX8KsyggHTIoHEXlfommC6phNxDKojFhEdpNO3h+aJojZBpMTISaXhSiLmITMoW1H1wRmPynpgYFbsHCh0tIVb7zxBs888wznnXce8+fPL/Vw9hqVmr6X6LrOHXfcwac+9SmuuuoqfvrTn474PvtLDb3xjEXc+JuV9Jd1tHhaGbOqozz5xrbCfYbWs58zgKUuaPtEMuc4vXhaokc6uRBw1qHTehjn5VuFXXnygTlDtF3bMTTBjWcuKojovFv1QDW5V558IK9ubCHr+piGxrmH1/abKlx83gQr88H9fuDdhkYwOfj4ktrCNh7r1ZPcy0XK+9v+9KKe0qcsmtKvoWDx38XbWDKzgps/uphHXt3M5LIQNXG7R218f/FwXcB/fuxgWpPObuvGJdC0h+3PRqLf+1hlenmIr394ARB4S/SHaWhURCxu/t1qfBlM0m48Y1HhNRURi9akM6J14y+++CKPPPIIp59+Ou9///tHZB8Kxb5y9dVX84Mf/IDPf/7zfOpTnyIcDpd6SIpRIC8Me4vuYlGdF45CiEEzyfaVfPQ9v29jNxFs29DoSPv4smc9tqlrJML7NjfMH+tetBPvg+/3TElXTGyklAWPrrxj+v6KEuL7wEUXXcR3v/td7r333kI7neHgoWWb+9TmAn1SQx9fUc9jK+p7iHA9V3u0eFoZT15+HADPr91JW8qlPGz0GxFvUH2aB2QoJt3Pvr2TZ9bswDI0Tpw/iaff2oEkELjN3U5h4UMDjj2gmitPPrAgWGxTI5P1C27ovYXMQG7V+Qj3jWcs6iGC+jPT6lHC0OsHTObq0W46c1GPc+3cw2v51aubC+eW1U8EdKCa4d5Cu7+/i7dx8+9W47g+a3d0cuMZi7CM4D3p/b6bORO7/GLB8k2tWIaGkx3cbHBNQ+cgj04M8iUTvd/T7R3pQouyoHSh50LdB3Lna/E5JJCs3tZe+Nx8uWuBaaTqxm+44QYA7rnnnmHftkIxXESjUe666y4+/elPc+edd3L99deXekiKUWCg1HTYN7O1PcHzJVnXQ9O0IaWA56mIWIUa8dgAUfOxgKaJQkszlRGl+O1vf8s777zDl770JebMmVPq4ewTY/dbt59w++23c9ppp3H55Zfz0EMP9Xl8IFE9EA8t28x1T6wE4G+59lf51+VT0fMR0HyLoWIu/sCcQoQL4GM/erEgvvsT4YrBGcq6q+8HxmlZ16cmbmObuz6jDy2e2iNqXSzC84ZoN/5mFZ4v+e9/bOSURVMGFNR5BhLAA91ffN7o/bjkSylpTTo99rFkZgWPXHoMj62oRwDn9Iq2L9/Uyl3PvFMQzPma4T0VYb0Xl1Zta+ecw2tp6szw/NqduDnzmvw4i39+izMGKiIWz6/dWXBgV/RcRBpoocLz4eanVmMbGhnX54xDpvK7f27H9yWW2fN87X3tKfan2JdzYHc89dRTPP/881xwwQXMnj17WLetUAw3F110Eddeey3f/OY3ufjii5k8eXKph6QYYXrWgI++SJRS0pEzLpV4RKQ+5DRwTRMkIoM1QRsbBH3Gd/2tmLhks1muuOIKhBDceOONpR7OPqOE+D5y6qmn8pGPfISHH36Yq6++msMPP7zw2GCieiD+uGp7n9vFjtUPfu4ovvqrN9jWluKtre19IrZ/Wt3QQ4iv2taxD0c3vhEiiAa7+6jc8mYrpqFxzuG1nJPry50X0fOnxAcU1a1JB1/KHkIGGNSduj/TrCUzKwa8P3/ePLainvU7OnmlyNFdMHC972DR+Pz48pH/va0ZztcdkzOje3R5faEt200fXcwfV23nxXVNSMD1g+/UYyvqe0Tf84sQrUmn0LZtIDf3vOFd7/7Z45GqmEVXxu03u6CYN+vbe/x92fFzeLepm50dadY2dPY4h4pLDfI+Ar7ct3NgMLLZLF/4whfQdX1Uyn8Uin1F0zR+8YtfcOqpp3LllVfy8MMPl3pIihFmoGj4aBF0VJFYho7r+Xjj1CtQCXAFwJ133smWLVu44447qKoaXa+akUAJ8WHg9ttv5/e//z3//u//zuuvv164fzBRPRAfWjy1INrzt4v5/jPvsLE5CQQu6DUxi8auXdHM0xdN6fH8xdPKeKNoon1YbYI369v3qJfweOWgyXHW7tj7tOV8Om7v9HDom5I9UJSwd5bDUXOqBhTUg71msPvzPL6ivk8GxcyqCHeed9geRTGLx6cJOHZez3T7oZJPS8/XHZ84fxLPrNlROO7WpFOog8+Lyd4LFo+vqGfdjk6Wb25DSllYuDhx/iS+++e3aUv2bG/2LwdNYlLcZtl7Lazf2bVH493faO5yuPT4OfxjQzMrt7YPuQf7M2t2FFrNvVkfLCReuLSuz3lcnI0wUjXi3/ve96ivr+e2224jHo8P67YVipHilFNO4YQTTuCXv/wlV111lfI1UIwogfFbUKYlkURt5fujGJ90dHRw7bXXMnXqVK644opSD2dYUEJ8GFiwYAGXXHIJ9957L08//TSnnHIKsHtR3R95oT5QOvsrG1t63O7MuFx2/ByefGMrdZURTuklxJ+8/Dg+9qMXWbWto0fdeHHq8wX3/mNCRAh783ZD5z61btME3HhGUFudfz+h/1rogRjIkG0wQT3QawYzd8uL595ccvzcPRZPvQX/3ojw4jHl646r43af4y6O5j+6vB7P8wvmYZ+89x+5nuq7cHIi/YsnzQMoZKRA4J/w13caybqDR4jHCxL42d824MmB3ef7pdeXYqAFxMEWmIaDbDbLbbfdRnV1daFFiUKxv/Czn/2MAw88kG9+85v87ne/K/VwFOOcmG3gmUGtur6vPdAUijHKrbfeCsBPf/pTLGvsl1QMBTGW3QePOOII+dprr5V6GENi+/bt1NbWMnPmTNatW4euBzaRe1ojvjs+fd8yXigS98cfUM2Gxi7q29IAhPbQMCmfZtw7UjrREQSttPpzoy/m1IWTufSEuYOmku+O/urBB6sR3xvyn3PW9dE0wcKpZZz//rq9PieHY3zFYzJz7xv0767ee5+Pr6jnwaLe5nl0AZ88sq5Q0/7Qss3c/+IGEILKiMmrG1snhAgfCEFQEygE+J7sUzteHja45vQFPRYwbj374EHLK0aKq6++mjvvvJOf/OQnXHrppaOyz+FCCLFcSnlEqccx3tif5gQA5513Hr/+9a95+umnOfnkk0s9HIVCodhv2bhxI7Nnz2bRokWsXLlyvypVGGxOoIT4MHLLLbdwww03cNddd/HlL3952LY76+u/L/y98dsf4dP3LeOVjS0cOauS7W0p1jX27GP+tdPmFyKCu+Pu59Zzx5/XTmhxUowmYP7kOLecfTCPr6jnoWWbCTpXwiG1CSaXhXrUH+sCzj+yjl/mWm7pAq46dX4hxXwg4ZIXlfm2UHsr4veE4Rb3w8HuTOkGan92yf+81qcOXAN0XeD5u1LUYVe9fe/U7KE44o83NAFzq6NsaklSGbHY2ZnpI8ZrYhaH1VWwsyPN+e+vY/6U+D4tNO0N7777LvPmzePggw/mn//854juayRQQnxk2N/mBB0dHSQSCWbNmsW6deswDJWEqFAoFHvD6aefzp///Gdeeukljj766FIPZ48YbE6gfhWGka997Wv88Ic/5Oqrr+aSSy4Zlh6ixSI8f3vjtz9SuD3n2t/3fskeGSYdNadqwomRwfAlrNvZxdOrG4CgZZbnB0ZsN54Z9E5+9u2deDlVJ3Npv8Up1RURa1DhUmx2polg+yPpOp1npFOJ94ahGMIZmgAhCiZuN56xiP99e5cIF8DJCycjoNA6rriOvLcIzz+/Jm4XFlCKEQTZ2UOtqd6f8CWFhbuB2hY2djk889YObFNj/pQ4j62o32d3/D3l6quvBuDHP/7xiO5HoRhJysrK+I//+A+++c1vcs899/ClL32p1ENSKBSK/Y7nn3+eP//5z5x11ln7nQjfHcrRYRixbZu7774b13VHzURgbnW0x+3a8tAe1yhPNKpjVqFmtr/EFteX/OSFDTy0bDOelNTEbT5z9KyCaPzPsxZjaAINsMzAKf3Bzx3FVafO58HPHUVr0uljtlZMcW20n3ML18XIuE7vCcs3tXL3c+tZvql1908eYXoY1nmSbNH7+UhRf3MIRPVlJ8zl+XcaC4tKuiY4ak5VoZ49XzKnERjsXXbCXM49vBatn1o6yfgU4XuCBJysX6jNz78dQuzZQl8xQz2//vrXv/Lkk09y9tlnc+yxx+7VvhSKscK1117L5MmTueaaa2hrayv1cBQKhWK/wvd9Lr74YoQQ3H333aUezrCjIuLDzMc//nE+8IEPcN999/HNb36T6dOnj+j+nv7qiZxy5/O829TN3OooT3/1xEGf3zvNHcAYQi30/s7HDpvGqm0dbNjZRVORy3xehvWnuyRBn+WGjgw/eWEDAF//8AIuXFrXb81s8aLGYGZrvc3O+nNd31P2Ne38oWWbufE3q/CLnMdHe5Gm+Bh69D7XBD7gecGixeSyELCrE8CkuM3LG5pxcz1bBPCJI2YUxj+Qu/fdz63Hn+iKexA0TSAIFj/y7O11YqAe9/2Rj4b/7Gc/27udKRRjCNu2+eEPf8h5553Hrbfeyne/+91SD0mhUCj2Gx599FHWr1/PjTfeOOKaqhSoGvER4Nlnn+Xkk0/mzDPP5Le//e0+b68/8byv28lz69kHc+8L7xZaohWjQZ/60cGI2TqfP3Eef127s0ev6rHA9PIQOzoyuEXCK99+7MhZlT0M8AZCAI9+/phBBWpx7fdg4no467X3ROQM9Przf/qPwnujCfjqqUP3GRgO+jsGoPBe3vTU6sLCxU1nLupx++GLd9WCFxu/9S4H6P1+P7RsMzc8ubLQhi0fCc//PYYvjSNKYFQouPmsxcyfEufcH7/U4/EpcZuXv7FnxlN3P7eeO/+ytoePQn/n1+OPP865557LZz/7We677759OYySomrER4b9dU4gpWTevHls2LCBTZs2UVe378atCoVCMd5Jp9PU1dXR2NhId3c3kUik1EPaK1SN+CjzwQ9+kDPOOIOnnnqK1157jSOO2Lf52L6I791x3RMrOf6A6h5C/MhZFfx/H1rAkpkVhdZnAsnujNUPr6vgiyfNY2tbqo8QL7UxVsrxCnXdeU5ZOJmIpfPkG9t63D+rKtLvwoSEQetj90QQD2e99u76jg/l9cXvjSbEqKfI93cMXzxpXiFy7XpBjbLnBf3FH764b5u2G89YVOhQMFBNfrHIv+m3q8h3P/PlrvNzuIPkEUsn6XjDu9ERQtcE579/BufmXOchWJgofk8au/qvLR+M3fW4hyD97Etf+hKhUIjbb799r49BoRhrCCF44IEHOPbYY7n22mt58MEHSz0khUKhGPPceeedNDY2ctddd+23Inx3qBrxEeJHP/oRAJ/61KdKPJLd057Kctnxc5hVFeGy4+fwq8t2RX2fvPw41t/6Yf79uDm73U5tZYTlm1qDFkm9HptbE+3vJf2yNw0JLENDH+RsPu+IGT16a+YXBnqLcF0EvbVDptbnGCx9cIHan5gcDfIiZ2/rzI+aU4WdO14jFwkd7bT0wY6hv8eWzKwoCHUIxPZNT63mxXVN3PTU6h51yP19Li9vaO7Rg3wo2nso52V/z9kfRHg+Cn54XTlNnRkeX1FfeA8PmZ7o8dyDe90eCvl+8Hkfhf7Or5tuuomtW7dy6623UllZuVfHoVCMVY455hg+8pGP8NBDD/HSSy/t/gUKhUIxgWloaOD6669n7ty5XH755aUezoihUtNHkK9+9at873vf47777uOzn/1sqYcD9J+efurCydz76b5R++LnPvb5Yzjvp//A82WfCFmeI2dV8M+t7Tiuv1epvfkWYYunJ4jbBj99YcOQBJKuCZbUlQ/YI/qy4+fw9Q8v2FUH7UssU2P+5Dhv1rf3eG7+vShOMV+9rR0JPaKE/dFfX+zRErT7muo+FlqbDbWVWX/j+8YTK3v0Fb9oaR3fOvvgwmsv+vnLONmgh3o+5fqCe/+B4/U9Y/Rcn223n8dCpkZ6d6khYxRNBILYcX3+Wd/e47tyaG2CVVvbKT5kK5f2X5wZs3haGU9eftywj23r1q3MmDGDGTNmsHHjxv2qP2h/qNT0kWF/nxNs3ryZmTNnctBBB7F69Wo0TcVCFAqFoj/OPfdcHn/8cf70pz9x2mmnlXo4+4TqI14ikskktbW1+L7Pzp07sSyr1EMCAmHyyXv/QdaTmLrgl5cc3aNu9o+rtvO3fmqm8wJ8oDTzQ2sTrNzaPuBzhpKenq/bPv6Amj59ovOPHzQlzpqGzn7HVowu4D8/dnAPYzXY1Y96bUMn1z2xssc2fn1Z3xrwPRGIY0HQTkSue2IlDxUJ8QuX1nFrTohD/2Z0AD/567s8s2ZHYdFIABcsrWPxtESPc2N/x9AEnztuNvGwSUXEGtKxCeC4A6q58uQDBy3HGI7z/dOf/jQPPPAAL7zwAh/4wAf2ejtjBSXER4b9fU4AQZvTO+64g1/84hd85jOfKfVwFAqFYszx5ptvcthhh3H66afzxz/+sdTD2WdUjXiJiEQifO973+Pf/u3fuOqqqwrp6qVmycwKfnnJ0f2aVw02QbcMDcf1Ebne172ZXBZizfaOQt/tzxw9i2fe3sn6nV3A0NJ/872Kd3SkB3z83aZudEGP6F0fEa4J/jMX+exdH5w3iVoys4LNzd2FyLuRS10vjoav2tbOo8vrCz2si6PcA9WEKwE++iyelkDPLcaYuuDcw2uBXZ/lm1va+vRr/+JJ8zhsRjnPFC346Frw2sdW1O/R/nUBNXGbho49r58eDXwp+fmL7xUWIoaCBF5c18SrG1v6ze7YV5PAPK+88goPPPAAp5122rgQ4QrFYNx0003cd999fP7zn+f8888nHA6XekgKhUIxprjooosAuOeee0o8kpFH5UWNMJ/5zGc4+OCDufvuu1m7dm2ph1Ogd40twB9XbR/0NTeesQhNiAFbPv3lrR04niRi6WSyPs+u2cHZ75te6OE8ULKpIGihpud6c5uGxqbm7gHH4bg+/WQN98D3Ja1JZ7d12/GwST4L1vMlj6+o56Kfv8wdf17LdU+s5OFlmwd8falqwhU9Wb6plZt/txpJkNVw4vxJhfvzn+Vf3trRp8c4DFwfP9TE6CNnVaDlFoXGqgiHwIDP82XhXO0Py9A4ZeFkTl04mUNrE4UMloHO7eE4/6WU/Pu//zvAuOwPqlD0JhqNcuedd5JOp/n6179e6uEoFArFmOJ//ud/WL16NVdccQWzZ88u9XBGHCXER4EHHngACFLSxjIfWjx10Meve2Ilbi6qOBidGQ8JrGvs5uFlm3YZbemij6GaoQtOXjgZTdMCoUDgjN2WcvflUBACKiIWR82pwtC1wIxK77+nt5Hrl6xrAgmFGnfYFcUXDM1ErJjlm1q5+7n1PYzDFMNPsSD0JDz91g4u+vnLPLaivsdnCX17jOdNxL562nweufRoLlwatBU65/BaLCM4b7ReqlwjiIBbumD55rZhdVkX7N4Urjy854lMN5+1GNvcda4+9vljejz+tdPm8/DFR/GzTx/BvZ8+ghvPXNTj+f0ZAO6rSSDAgw8+yKpVq7jyyiuZO3fuHr9eodgf+cxnPsNhhx3GD37wA3bs6FuCpVAoFBMRx3G48sorSSQS3HbbbaUezqiwT6npQoibgIuBxtxd10kp/9DP804Hvg/owM+llN/el/3ubxx66KGcf/75PPLIIzz99NOccsoppR5Sv+RFSL4F1IVL6/o1d9sTtrWn+fVlx+zqB/3bVXgEhm8fXDCZy06Yy8sbmnm6KD14OLywfAk3/241N56xaJdr3EB+CCIX+xOCxdMSQQp+1scnEGGGJvjEETM4p5dZW17E9VcjO1xpu4rdkxeEmWzOJJAgQiugz2dpGRrn5NLW8/RXTrBkZkWhRdqbW9p6+BUcMasC29QJm3qP83Y4GEjTv39WYLK2ensH7XuxSHXdEyu5aGkdOzszTIrbwOBtEQc7t/fkOYORyWS49NJLSSQS3HzzzXt2QIoxi5oX7B4hBPfccw/HHHMMn/rUp3j66adLPSSFQqEoOddccw2tra387Gc/Ixoderel/ZnhqBH/LynlHQM9KITQgbuBU4B64FUhxG+llG8Nw773G374wx/yyCOPcPHFF7N+/XoMY2yW51+4tK4gyIfCQA7qeeZWRwtC57onVhZaRgngsBnlwyJODU1wxiFT+7Qiy7o+f1y1vRDF93zZp8f2yxua+/SozouLiohFa9IZVGQMVBO+r729FUMnLwgfW1HPo8vr8bzAtf6cw2s55/DaIX+W/W03b8D3/DuNZF0fXRe8Ud+O6/kYuoapC7Ke7BN1N3N+CoMhCHrZ92dKWIwu4OsfWsDLG5pZubW9z74uXFrH48vrSe1mfw+/srlgpPjYivoBF4eKDdjyfgoDsS+eCNdffz3JZJJ7772XeDy+V9tQjFnUvGA3HH300Xz84x/n0Ucf5c9//vN+7wqsUCgU+8L69ev5/ve/z4IFC/jc5z5X6uGMGqOhBo8E1kspNwAIIX4JnAVMmB9cgJqaGu68806++tWvct111/Hd73631EMaFjbcFkTVBoqcr2sMar2Xb2rl0eX1fep085P+sKHtVkj0xtJ7RqqPnF3FI69u5q3tHfi+RNcEYVPH0ETBQK6/1HTL0Aotx/JCbV9Fc3/bVYwc+c/s3JzwLhbc+/JZ5s/Pm85cRGvSYVtbqiBoPc/nk0fWIYFfv7aFrBdkepy8YDKXnjCX837y0m69DHYXURcCLv7AHADe3NJWSF2X7Or9vWhagm+dfXAPZ/j+Fsfy9xXXfZ/74139jDd++yOjlsnxz3/+kzvuuIOlS5dOqB9cRQE1LwB+9KMf8ac//YkLLriAbdu2EQqFSj0khUKhKAkXXnghEJSsTST2qX1ZLgXtM0AH8BrwVSlla6/nfBw4XUr5udztfwWWSin77c4uhLgEuASgrq5uyaZNm/Z6fGMNKSV/+9vfqK2tZc6cOaUezpAZSGT3l9ra33Mf+/wx3PXMO7y4rqkgIC5cWsc5h9cWJv3FwsHWBf/x0cV844mVfVJ1dQHnH1nH9PLwgNHN5ZtaC9HRfOTy40tqB+wDPlItx1Qrs/2b/kTp2obOHr3o89kTd/5lbeEczrfNu+mp1T2i4vMmxXivqQuv13pT77Z+Fy2tI24b/PzF94IFJD3wLsgWqfrjD6jmpXebCwtMJxxYU9jOpLjNOYfX9hDZIVMrfM80wBqkH3o+y0UXcNWp83cbFd8bWlpaWL9+PTNmzGDq1MG9KfZHJnL7suGeF4znOQHA6tWrcV2XBQsWjJkWpwqFQjHarFy5EsdxWLJkSamHMuzsU/syIcQzwJR+HvoG8GPgPwnmf/8J3Al8du+HClLKe4F7IegZui/bGmsIITj++ONLPYw9ZqBa0qEKzbyYkewSAfmU4d4iXBfwpZMP5MKldfxp1XZe6NXPXM8Zq+0uXfzxFfVkc/vMu0TvaXr5vqJamY0N9nZBpHd5weMr6nlsRT2+lGia4MYzFhW2J4QoeBB4Eh55dXMfd/J3d3ZhGhqzqyOFln7BawEZXER1AdPKg3ZGvpQFAd77QvjCuiY0QcFcsDiybumiTx38jWcEEf3iFP1ioV5MXoSPZCZHZWUlRx555IhsWzHyjOa8YDzPCQAWLVpU6iEoFApFyTn44INLPYSSsFshLqU8eSgbEkL8DPhdPw9tBWYU3a7N3afYjxkohXXjtz/SIyr+tdPmF6KFmoBj51Vz5ckHFgRMvjd5PlJXPPlfOqeKv+Wi6ACH1iZY09DJw69s3m2N669f29LD+fzR5fUDRsQV45fi83R3mRG96V1ekBe9vgyyW1ZvaweCBZcPHjSpR6335LIQptHZIyKef/27RSJc1wQXHzeb//7Hxj5lDPl965rA6SfHXRMCv5+Mpqwn+4js655YyWOfP2bI578n4VfKYFAxAGpeoFAoFArFvrOvrulTpZT55tNnA6v6edqrwAFCiNkEP7SfBC7cl/0qSs9gZmTFEfTlm1p7iJliEV7sulwcqQO4+7n1VEQs7KIU2h0daVxvcAO05ZtaueuZd3qk8UJQz6sM0yYO+Sj4trZU4Tx1XJ+Hl23m8UEWcYrp7QoOwYJOPrvj169tKfgTXHrC3IKhm2loXHrCXE6cP4kbnlzZp068+OaUMpvOjMtnjp7F6u0dfGjx1MK4ivf95YdXUN+W7rEd15f9miXm+5r3pvf5f+vZB3PdEysHPP5zf/zSoM7qCkV/qHmBQqFQKBRDY1/N2r4rhDiMYG65EbgUQAgxjaAdyYellK4Q4nLgzwRtSu6XUq7ex/0qSsxQzch21+Kod/p270j7jPJwwfCtoSODJgZOmy1+bb4WXRIIE13X2NqWYvmm1t0KMFXbvX/TIwquCQxdK5QpFBuVDeWz7X1+nnhgTSHyXezCX9zuLH/evLyhecB2ZHm2tqV5cNlmIDhfX93Ywvwp8cI28/t+8esfZPbXf99ne8UivCZm0dLtDLjP/Pfl239Yw59WN3D6oincevbB/HHVdv7WqwREodgH1LxAoVAoFIohsE9CXEr5rwPcvw34cNHtPwB9+ogq9l/2pIfwntRK9460v9ec7PG4oQm+fPKB/e6z+LX5NPgPLZ7K6m3t/Pq1Lfzyld1HQ3uLuP76hyvGNsXngedLzj9yBoIggj2Qe/5QWL6plefX7izc1nWth/N/73ZfxYtVUsLuegIMtkiwfFMrRWXo/TKtPExzt1Oo8S6OiufT0r/9hzX85IUNAPzkhQ1cdvwcHvj3pcDApowKxZ6g5gUKhUKhUAyNsdnMWrFfMBJmZL0j7cURcYCFU8sGdHHu/dp8Gvzdz63H9eWQenoXizjHkzy0bPB6dMXYo/d5kK8JP6ef1mZ7wssbmnuUPHx8SWCI1p9XQl6cF5ukXf/EykHFeG+PhN77HkyE65rg/PfXsXbH6sJx91fj/afVDX1uf/3DCwD6+DuotHSFQqFQKBSKkUMJccWYonek/eUNzdz+57WFx09Z1J9Rb/+vzYuQPenpnX9uJrt3qcyK0jPQebCvC0cVEatH2vfiaYl+vRKgf3G+ubmbe/+2ASkDwX1obYLlm1rxJZh6kH0xkJHcUXOqMHPGhv1x8XGzuXBpHfOnxAddbDh90ZRCRDx/uxglvhUKhUKhUChGByXEFWOO3oIpZA5NRPf32vx9e5JG/+Dnjir0Ife8oe1XMbYYiWyN1qRTMEfTBAVzwd6LPAOJ8//+x0akDKLXN50ZRMqXbwraK/u+ZHp5eNAWew9ffBSPr6gvLAY8vGxzoSVgPGwO6bjz0e98jXj+tkKhUCgUCoVidFFCXDGm2RMRvbvtDPW1+eeeu4+pzIqxyd6a8fUnugc6PwcS55Kg9dlAIn4wis/h5ZtaeXxF/ZBfm6c49VyJcIVCoVAoFIrSIeRghYcl5ogjjpCvvfZaqYeh2M9RLuiKPL1d+XvXdO/uHHlo2Wb+uGo7H1o8lQuX1g26n+Lt5febF857st+BnrOn53V/ZmwqFX1kEEIsl1IeUepxjDfUnEChUCgU+xuDzQlURFwxrhlIeCkmJntS092b5Ztaufl3q3Fcv0ebsf7onYGxt3Xrg52/I5F+r1AoFAqFQqEYHbRSD0ChGEkGEl6KiUk+Hby4F/1Qz5F9PZeWzKzgiyfN2yPxrM5fhUKhUCgUivGJiogrxjV7WoerGN8Mtaa7P0pxLg3nPlV7MoVCoVAoFIqxg6oRV4x7VI24YncM9Rwpxbmkzt/9D1UjPjKoOYFCoVAo9jdUjbhiQqNqaRW7Y6jnSCnOJXX+KhQKhUKhUIw/VI24QqFQKBQKhUKhUCgUo4gS4gqFQqFQKBQKhUKhUIwiSogrFAqFQqFQKBQKhUIxiighrlAoFAqFQqFQKBQKxSiihLhCoVAoFAqFQqFQKBSjiBLiCoVCoVAoFAqFQqFQjCJKiCsUCoVCoVAoFAqFQjGKKCGuUCgUCoVCoVAoFArFKKKEuEKhUCgUCoVCoVAoFKOIEuIKhUKhUCgUCoVCoVCMIkqIKxQKhUKhUCgUCoVCMYooIa5QKBQKhUKhUCgUCsUoIqSUpR7DgAghGoFNpR7HEKkGmko9iBFmIhwjTIzjnAjHCBPjOCfCMcL+dZwzpZQ1pR7EeEPNCcYkE+E4J8IxwsQ4zolwjDAxjnN/OsYB5wRjWojvTwghXpNSHlHqcYwkE+EYYWIc50Q4RpgYxzkRjhEmznEqxgcT5XydCMc5EY4RJsZxToRjhIlxnOPlGFVqukKhUCgUCoVCoVAoFKOIEuIKhUKhUCgUCoVCoVCMIkqIDx/3lnoAo8BEOEaYGMc5EY4RJsZxToRjhIlznIrxwUQ5XyfCcU6EY4SJcZwT4RhhYhznuDhGVSOuUCgUCoVCoVAoFArFKKIi4gqFQqFQKBQKhUKhUIwiSogPI0KIK4QQbwshVgshvlvq8YwkQoivCiGkEKK61GMZboQQt+c+x38KIZ4QQpSXekzDiRDidCHEWiHEeiHE10s9nuFGCDFDCPGcEOKt3Hfxy6Ue00ghhNCFEK8LIX5X6rGMFEKIciHEo7nv5BohxNGlHpNCMVQmyrxgPM8JYHzPC8b7nADUvGC8MZ7mBUqIDxNCiJOAs4BDpZSLgDtKPKQRQwgxAzgV2FzqsYwQTwOLpZSHAO8A15Z4PMOGEEIH7gY+BCwELhBCLCztqIYdF/iqlHIhcBTwxXF4jHm+DKwp9SBGmO8Df5JSHgQcyvg/XsU4YaLMCybAnADG6bxggswJQM0LxhvjZl6ghPjw8Xng21LKDICUcmeJxzOS/BdwDTAuDQaklH+RUrq5my8DtaUczzBzJLBeSrlBSukAvySYKI4bpJTbpZQrcn93Elygp5d2VMOPEKIW+Ajw81KPZaQQQiSA44H7AKSUjpSyraSDUiiGzkSZF4zrOQGM63nBuJ8TgJoXjCfG27xACfHh40DgA0KIZUKIvwoh3l/qAY0EQoizgK1SyjdLPZZR4rPAH0s9iGFkOrCl6HY94/DHKI8QYhbwPmBZiYcyEtxFMPn1SzyOkWQ20Aj8Ipdq93MhRLTUg1Iohsi4nxdMwDkBjK95wYSaE4CaF4wDxtW8wCj1APYnhBDPAFP6eegbBO9lJUHKy/uBXwkh5sj90JZ+N8d5HUEK2n7NYMcopfxN7jnfIEhnenA0x6YYHoQQMeAx4EopZUepxzOcCCHOAHZKKZcLIU4s8XBGEgM4HLhCSrlMCPF94OvADaUdlkIRMBHmBRNhTgBqXjARUPOCccG4mhcoIb4HSClPHugxIcTngcdzP7CvCCF8oJpg1Wa/YqDjFEIcTLAS9aYQAoLUrBVCiCOllA2jOMR9ZrDPEkAI8RngDOCD+9ukaTdsBWYU3a7N3TeuEEKYBD+2D0opHy/1eEaAY4GPCiE+DISAMiHE/5NSfqrE4xpu6oF6KWU+cvEowQ+uQjEmmAjzgokwJ4AJOy+YEHMCUPOCccS4mheo1PTh40ngJAAhxIGABTSVckDDjZRypZRykpRylpRyFsGX4fD98Qd3MIQQpxOk9nxUSpks9XiGmVeBA4QQs4UQFvBJ4LclHtOwIoIZ4X3AGinl90o9npFASnmtlLI29z38JPC/4/DHlty1ZYsQYn7urg8Cb5VwSArFnvAk43heMFHmBDCu5wXjfk4Aal4wnhhv8wIVER8+7gfuF0KsAhzg/4yjFdOJxo8AG3g6t8r/spTystIOaXiQUrpCiMuBPwM6cL+UcnWJhzXcHAv8K7BSCPFG7r7rpJR/KN2QFPvAFcCDuUniBuDfSjwehWKoqHnB+GFczgsmyJwA1LxgvDFu5gVC/SYoFAqFQqFQKBQKhUIxeqjUdIVCoVAoFAqFQqFQKEYRJcQVCoVCoVAoFAqFQqEYRZQQVygUCoVCoVAoFAqFYhRRQlyhUCgUCoVCoVAoFIpRRAlxhUKhUCgUCoVCoVAoRhElxBUKhUKhUCgUCoVCoRhFlBBXKBQKhUKhUCgUCoViFFFCXKFQKBQKhUKhUCgUilHk/wd0dCJBS4SL7AAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_contour(logprob, orbits=samples, weights=weights)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Ellipsis + IAF\n", + "\n", + "The ellipsis used to build the orbit on the previous algorithm solves Hamilton's equations for $p(x,v) = N(x|0,I)N(v|0,I)$. We can use normalizing flows to approximate the pullback density of our target to a standard normal, thus allowing the algorithm to sample from a density similar to what it is targeting.\n", + "\n", + "To do this we parametrize a diffeomorphism as an [MAF](https://arxiv.org/abs/1705.07057) and optimize its parameters by minimizing the the Kullback-Liebler divergence between the pullback density and a standard normal (equivalently maximizing the Evidence Lower BOund (ELBO) or the Variational Lower Bound).\n", + "\n", + "Once we have a diffeomorphism that \"transports\" our target to something close enough to a standard normal, we can use our orbital MCMC sampler travelling around the ellipsis to sample from our target pullback density (the target density \"transported\" to a standard normal). This will be equivalent to sampling using periodic orbital MCMC where the bijection used to move around the orbit is the composition of: first the inverse diffeomorphism which transports samples from our target to the standard normal, then the ellipsis solving Hamilton's equations for $p(x,v) = N(x|0,I)N(v|0,I)$, and finally the diffeomorphism which transports standard normal samples back to samples from our target. Formally, if there is a smooth, invertible transformation $T$ such that for $x$, a random variable distributed as our target density $\\pi(x)$, we have that\n", + "\n", + "$$\n", + "z \\sim \\phi(z), \\quad x \\approx T(z),\n", + "$$\n", + "\n", + "where $\\phi(z)$ indicates the standard normal density. This implies that\n", + "\n", + "$$\n", + "\\phi(z) \\approx \\pi(T(z)) |\\det \\nabla T(z)|,\n", + "$$\n", + "\n", + "where the right hand side of the equation is what we call the pullback density of our target. Thus, letting the bijection $f(x,v) = (x(t), v(t))$ for\n", + "\n", + "$$ \n", + "x(t) = x(0) \\cos(t) + v(t) \\sin(t) \\\\\n", + "v(t) = v(0) \\cos(t) - x(t) \\sin(t),\n", + "$$\n", + "\n", + "we have that using the periodic orbital MCMC on the pullback with bijection $f(x,v)$ is equivalent to using the periodic orbital MCMC on our target density with bijection $T \\circ f \\circ T^{-1}$." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First we define our parametrized MAF bijection using autoregressive neural networks." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "import optax\n", + "from numpyro.nn import AutoregressiveNN" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "iaf_hidden_dims = [2, 2]\n", + "iaf_nonlinearity = jax.example_libraries.stax.Elu\n", + "init_fun, apply_fun = AutoregressiveNN(\n", + " 2, iaf_hidden_dims, nonlinearity=iaf_nonlinearity\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then we initialize the parameters of our MAF transformation and define our reference density as a standard normal." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "_, unraveler = jax.flatten_util.ravel_pytree(initial_position)\n", + "_, initial_parameters = init_fun(jax.random.PRNGKey(1), (2,))" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "log_reference = lambda z: jnp.sum(stats.norm.logpdf(z, loc=0.0, scale=1.0))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Some utility functions\n", + "\n", + "Define the log pullback density, our loss function (negative ELBO) and the optimization loop used to train our transformation." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "def logpullback(params, z):\n", + " mean, log_sd = apply_fun(params, z)\n", + " x = jnp.exp(log_sd) * z + mean\n", + " return logprob(unraveler(x)) + jnp.sum(log_sd)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "def nelbo_loss(param, Z, log_pullback, lognorm):\n", + " return -jnp.sum(jax.vmap(log_pullback, (None, 0))(param, Z) - lognorm)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "def param_optim(\n", + " rng, init_param, log_pullback, learning_rate, n_iter, n_atoms, n_epochs\n", + "):\n", + " epoch_size, remainder = jnp.divmod(n_iter, n_epochs)\n", + " n_iter = epoch_size + jnp.bool_(remainder)\n", + " rngs = jax.random.split(rng, n_epochs)\n", + "\n", + " optimizer = optax.adam(learning_rate=learning_rate)\n", + " init_state = optimizer.init(init_param)\n", + "\n", + " def _epoch(carry, rng):\n", + " state, params = carry\n", + " Z = jax.random.normal(rng, (n_atoms, 2))\n", + " lognorm = jax.vmap(log_reference)(Z)\n", + "\n", + " def _iter(carry, _):\n", + " state, params = carry\n", + " grads = jax.grad(nelbo_loss)(params, Z, log_pullback, lognorm)\n", + " updates, state = optimizer.update(grads, state)\n", + " params = optax.apply_updates(params, updates)\n", + " nelbo = nelbo_loss(params, Z, log_pullback, lognorm)\n", + " return (state, params), nelbo\n", + "\n", + " (_, params), nelbo = jax.lax.scan(_iter, (state, params), jnp.arange(n_iter))\n", + " return (state, params), nelbo\n", + "\n", + " (_, params), nelbo = jax.lax.scan(_epoch, (init_state, init_param), rngs)\n", + " return params, nelbo.flatten()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We train the parameters of our transformation by minimizing the negative ELBO. A plot of the loss shows convergence." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA20AAAEICAYAAADMVBwKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA+NklEQVR4nO3deXxcdb3/8ddnZrInTZq1bZo23VuK0JYCZZGt7KLl+lOEe5WCXHEXvSrivlz14s8FUfwhXFEWZRNRFmUp+76kLbTQhe5N2mZrlmbf5vv745ykkzRpkzbJTJL38/GYR86cZeYzkzOT8873e77HnHOIiIiIiIhIbApEuwARERERERHpm0KbiIiIiIhIDFNoExERERERiWEKbSIiIiIiIjFMoU1ERERERCSGKbSJiIiIiIjEMIU2EZFhZGaPmdnyaNcx1MzsPDP7xzA8zw/M7M/+9BQzqzez4FA/b7SZWYKZbTCznGjX0hsze9fMzoji84+ZfUFExgaFNhEZ1cxsu5mVm1lKxLz/NLPnhuG5uwJFJ+fcBc65O4bguZyZNfgHqp23a/uqI2K77WbW5K9fbWb/NLOCHutcYWZrzazRzErN7GYzyzhEST8Bru9R38yI+18zsz1mNt/MzjCzksN+8T7n3E7nXKpzruNIHyvWOedagD8C10W7lt445+Y7556Dg+9/g8Xfj8+OeP4xsy+IyNig0CYiY0EQuCbaRQyDY/0D1c7b/+3ndh90zqUCE4Ey4LedC8zsq8DPgK8D6cASYCqwwszie3swMzseSHfOvdbH8u8AXwZOd869288ao8bMQtGuoQ93A8vNLCFaBZhnSI8lYvj9FxEZNgptIjIW/Bz4Wl+tQ2Y218xWmFmVmW00s0silmWZ2SNmts/M3jSzH5vZSxHLbzSzYn/5SjN7vz//fOBbwMf8Vqy3/fnP+S19CWZWY2ZHRzxWjt/qlevfv8jM3vLXe8XMjhmKN6eTc64ZeAA4yn/+ccAPgS865x53zrU557YDlwCFwMf7eKgLgOd7W2BmPwb+EzjNOffeQGs0s2lm9ryZ1ZnZCiA7Ylmh36IXMrOPmVlRj22/YmYP+9MJZvYLM9tpZmVm9nszS/KXnWFmJWb2DTMrBf5kZklmdoffGrnezK6NbB00s0lm9jczqzCzbWb2pYhlPzCz+83sTr/ud81sccTyAjN70N92r5ndFLHsk/7zVZvZE2Y2tXOZc64EqMYL0r29Vwlm9msz2+3fft0Z8PzHvChi3ZD//Iv8+0v8fa7GzN62iK6O/j78EzN7GWgEpvfy3NvN7OyDfA7Szew281pbd/mfq6C/7Aoze9nMbjCzvcAPzGyGmT3jvz+VZvaXzs+zmd0FTAEe8Z/j2sh9IeL387B5n/HNZvapAfx+vuHXWGfe98PS3t5vEZGhpNAmImNBEfAc8LWeC8zrNrkCr9UiF7gU+H9mdpS/yu+ABmACsNy/RXoTWABk+o/xVzNLdM49DvwUuM9v9To2ciO/e9uDwGURsy8BnnfOlZvZQrzub58GsoBbgIdtCFtVzCwZ+BjQ2UJ2MpDo1xlZez3wL+CcPh7qfcDGXuZf7z/+ac65rYdZ5t3ASryw9t8c+Pvo9Agwx8xmRcz7d3/7zlpm4/3uZgL5wPci1p2A9zudClwNfB8vqE7He91dgdW8lqZHgLf9x1kKfNnMzot4vA8B9wIZwMPATf62QeBRYIf/+Pn+epjZMrzA82EgB3gRuKfH61wPHEvvvo0X6Bb465wAfMdfdg/d973zgErn3Cozywf+CfzYfw++BvzNup8/9wn/fUnza+/VQT4HtwPteO/9QuBcvDDf6URgK5CH19XWgP8BJgHzgALgB/5zfALYid9i3EcL871Aib/9R4CfmtlZEcv7+v3MAb4AHO+cS/Pfp+19vV4RkaGi0CYiY8X3gC/agQM3XARsd879yTnX7pxbDfwN+Kh/QP1/gO875xqdc+uAbuejOef+7Jzb62/7SyABmNPPmu7GC4mdIkPF1cAtzrnXnXMd/nlwLfTRquJb5beMdN7OO8i6kf5hZjVALV4g+bk/PxvvQL69l232ENHK1UMGUNfL/HOBx51zO/tZVzdmNgU4Hviuc67FOfcCXlg6gHOuEXgIP5j44W0uXvA1vPf3K865KudcHV6wiPxdhPF+7y3OuSa8QP1T51y138L1m4h1jwdynHM/cs61+oH0f3s83kvOuX/551jdxf6gdQJekPi6c67BOdfsnOtsyf0M8D/OufX+7+CnwILI1ja89zmjj7fsP4AfOefKnXMVeK2mn/CX3Q18yA/q4O17nYHw48C//HrDzrkVeP/4uDDisW93zr3r7/dtfTx/r8wsz3+sL/uvuRy4ge7v127n3G/9x29yzm12zq3wfx8VwK+A0/v5fAXAKcA3/Pf3LeAPwOURq/X1++nA+0wfZWZxzrntzrktA3m9IiKDQaFNRMYE59w7eC0aPQdumAqcGBl28A52J+C1boSA4oj1I6c7B9RYb2a1/rbp9B1menoWSDazE82sEK9F5O8RdX21R10FeAf4fVnknMuIuD3Rzzouds5l4LWqfQF43swmAJVAtvV+TtFEf3lvqvFaYHq6FPiImf2wn3X1NAmods41RMzrs5UHL5h0tib9O/APP8zlAMnAyoj39nF/fqcKv7to5HP3tR9MBSb1+F19C6+VqFNpxHQjkOi/rwXAjj6C8VTgxojHrMJrccqPWCcNqOn95TOJ7u/PDn8ezrnNeK10H/SD24fY/w+DqXj/tIh8Pafi/c57e/0DNRWIA/ZEPP4teC3dvT6+meWZ2b1+N8V9wJ/p/+dsEtAZzjvtoPv72Ovvx3+fvozXqlfu13Cwz6CIyJBQaBORseT7wKfofrBWjNclMTLspDrnPgtU4HXhmhyxftfIiuadv3YtXivMeD/41OIdWAO4gxXj/1f/frxgcRnwaMSBZTHwkx51JTvnenaPGzR+i96DeK0LpwKv4rXufThyPTNLxTtv7ek+HmoNXtfDnt4DzgY+Z2aHM+rhHmC8RYwEincuU19WADlmtgDv/e0MJZVAEzA/4r1N9wdj6dTzd7eHPvYDvN/Vth6/qzTn3IUcWjEwpY9gXAx8usfjJjnnXolYZx5et8ze7MYLSJ2m+PM6dXaRXAas8wNK5/Pe1eN5U5xz10dse9B9u4ee6xbj7VfZEY8/zjk3/yDb/NSf9z7n3Di81kA7yPqRdgOZZhb5j4QpwK5+Fe/c3c65U/HeS4c3MI+IyLBSaBORMcM/KL0P+FLE7EeB2Wb2CTOL82/Hm9k8P1Q9iDcQQrKZzaV7l6o0vFBXAYTM7HvAuIjlZUChHXx0vbvxzvP6D/aHCvC6133Gb4UzM0sxsw/0OPAciICZJUbcDjg3zn+eZcB4YL1zrhavS91vzex8/70pxAuaJXjdyHrzL/rouua80SLPBr5uZl/u8fyJPW7WY9sdeN30fmhm8WZ2KvDBvl6w323vr3jdPTPxQhzOuTDe+3uD7R/0Jf8Q3UnvB75pZuP9c76+ELHsDaDOH7AiycyCZna0eaNoHsobeIHwev93nGhmp/jLfu8/53y/xnQz+2jnhn4dmew/B7Gne4DvmDfATTZeF+HIoffvxeuy+lm673t/xmuBO89/LYnmDc4SGVoHotvnwDm3B3gS+KWZjTOzgHkDjRysu2MaUA/U+q/76708xwEDovjPVwy8AvyP/1qOAa6i+3vRKzObY2Zn+Z+XZrywHz7UdiIig02hTUTGmh8BXS01fsvWuXhd93bjdZP6Gd55LOAdnKf78+/COxBu8Zc9gdet7j287lbNdO/W9Vf/514zW9VbMc651/EGOpkEPBYxvwivVfAmvO6Gm4ErDvHa3rbu12n7dcSyy/AOODtvkeflPGJm9cA+vEEflvvhCucN6vAt4Bf+8tf917jUeYOp9PaaVuEdXJ/Yx/K38QZ0+L6Zfcafnd+jviZgRi+b/zveIBVVeC2nd/b9dgBeGDkb+GuPLojfwHtPX/O72z3Fwc9F/BFeUN3mr/sA/n7gh/uL8Lq3bsNryfsD3n5zUP62H8QbkGOn/xwf85f9HW9fvNev8R28Fs5O/w7c0dfvAW8gkSK8ls+1wCp/Xudz78FrTT0Z758ZnfOL8VrfvoX3D4livJB0uMcMvX0OLgfigXV4+/cDdO9+2dMPgUV4Ldn/pMfgOHiDlHzH7255wIBDePt/Id5n/O945ys+1Y/aE/AGranE+w7IBb7Zj+1ERAaVOTeQHg4iImObmf0MmOCc62vUQgHM7Fzgc865i6Ndy1Aws88Clzrn+jUYxhA8fwJet8jT/IE8RERkFFNoExE5CL9LZDxeS8XxeF3//tM5949o1iXDy8wm4nW/exWYhdfac5Nz7tfRrEtERMaG3k58FhGR/dLwukROwjtv5pd4Q8nL2BKPN8LhNLzRGu8F/l80CxIRkbFDLW0iIiIiIiIxTAORiIiIiIiIxLCY6B6ZnZ3tCgsLo12GiIiIiIhIVKxcubLSOZfT27KYCG2FhYUUFRVFuwwREREREZGoMLMdfS1T90gREREREZEYptAmIiIiIiISwxTaREREREREYphCm4iIiIiISAxTaBMREREREYlhCm0iIiIiIiIxTKFNREREREQkhim09eHxd0r5w4tbo12GiIiIiIiMcQptfXhuYzk3P7cl2mWIiIiIiMgYp9DWhylZyextaKW+pT3apYiIiIiIyBim0NaHqZkpAOzY2xDlSkREREREZCxTaOvD1KxkAHbubYxyJSIiIiIiMpYptPVhih/adlQptImIiIiISPQotPVhXGIcmSnx6h4pIiIiIiJRpdB2EFMyk9mh7pEiIiIiIhJFCm0HMTVLoU1ERERERKJLoe0gpmYms6e2idb2cLRLERERERGRMeqQoc3M5pjZWxG3fWb2ZTPLNLMVZrbJ/zneX9/M7DdmttnM1pjZoqF/GUNjSlYKYQcl1WptExERERGR6DhkaHPObXTOLXDOLQCOAxqBvwPXAU8752YBT/v3AS4AZvm3q4Gbh6DuYVGoESRFRERERCTKBto9cimwxTm3A1gG3OHPvwO42J9eBtzpPK8BGWY2cTCKHW5Ts7wLbG+r0AiSIiIiIiISHQMNbZcC9/jTec65Pf50KZDnT+cDxRHblPjzujGzq82syMyKKioqBljG8MhOjScjOY5N5XXRLkVERERERMaofoc2M4sHPgT8tecy55wD3ECe2Dl3q3NusXNucU5OzkA2HTZmxuzcNN4rq492KSIiIiIiMkYNpKXtAmCVc67Mv1/W2e3R/1nuz98FFERsN9mfNyLNykvlvbI6vFwqIiIiIiIyvAYS2i5jf9dIgIeB5f70cuChiPmX+6NILgFqI7pRjjiz89Koa26nbF9LtEsREREREZExqF+hzcxSgHOAByNmXw+cY2abgLP9+wD/ArYCm4H/BT43aNVGway8VADeK9N5bSIiIiIiMvxC/VnJOdcAZPWYtxdvNMme6zrg84NSXQyYnZcGeKHttNmxee6diIiIiIiMXgMdPXLMyU5NIDMlnk0ajERERERERKJAoa0fZuWm8p6G/RcRERERkShQaOuHORPSeK+0jnBYI0iKiIiIiMjwUmjrh6Pz02lo7WBrZUO0SxERERERkTFGoa0fjp2cAcCakpqo1iEiIiIiImOPQls/zMxNJSkuyJqS2miXIiIiIiIiY4xCWz8EA8bR+ePU0iYiIiIiIsNOoa2fjpmcwbu799HWEY52KSIiIiIiMoYotPXTMZPTaWkP63ptIiIiIiIyrBTa+qlzMJK31UVSRERERESGkUJbP03NSiY7NZ43t1VFuxQRERERERlDFNr6ycw4cVoWr23di3O6yLaIiIiIiAwPhbYBWDI9k921zRRXNUW7FBERERERGSMU2gZgyfQsAF7btjfKlYiIiIiIyFih0DYAM3NTyUqJ57WtCm0iIiIiIjI8FNoGwMw4cXomr23ReW0iIiIiIjI8FNoG6NSZOeyubWZTua7XJiIiIiIiQ0+hbYDOmpsLwFPry6JciYiIiIiIjAUKbQM0IT2Ro/PH8fT68miXIiIiIiIiY4BC22FYOjePVTur2VvfEu1SRERERERklFNoOwxnz8vDOXhmg1rbRERERERkaCm0HYaj88eRn5HEI2v2RLsUEREREREZ5RTaDoOZ8eFF+by0qYKyfc3RLkdEREREREaxfoU2M8swswfMbIOZrTezk8ws08xWmNkm/+d4f10zs9+Y2WYzW2Nmi4b2JUTHvy3MJ+zgobd2RbsUEREREREZxfrb0nYj8Lhzbi5wLLAeuA542jk3C3javw9wATDLv10N3DyoFceI6TmpLJySwd9W7tKFtkVEREREZMgcMrSZWTpwGnAbgHOu1TlXAywD7vBXuwO42J9eBtzpPK8BGWY2cZDrjgmXLC5gY1kdr22tinYpIiIiIiIySvWnpW0aUAH8ycxWm9kfzCwFyHPOdY7EUQrk+dP5QHHE9iX+vG7M7GozKzKzooqKisN/BVH0bwvzyU6N55YXtkS7FBERERERGaX6E9pCwCLgZufcQqCB/V0hAXBe/8AB9RF0zt3qnFvsnFuck5MzkE1jRmJckCtOLuS5jRWs37Mv2uWIiIiIiMgo1J/QVgKUOOde9+8/gBfiyjq7Pfo/Oy9atgsoiNh+sj9vVPr4kqmkxAe5YcV70S5FRERERERGoUOGNudcKVBsZnP8WUuBdcDDwHJ/3nLgIX/6YeByfxTJJUBtRDfKUScjOZ7PnTmTJ9eV8fLmymiXIyIiIiIio0x/R4/8IvAXM1sDLAB+ClwPnGNmm4Cz/fsA/wK2ApuB/wU+N5gFx6KrTp3GlMxkfvjIu7S2h6NdjoiIiIiIjCIWC8PVL1682BUVFUW7jCPy9PoyrrqjiM+cPoPrLpgb7XJERERERGQEMbOVzrnFvS3rb0ubHMLSeXlcdsIUbnlhC6+om6SIiIiIiAwShbZB9N2L5jE9O4XP372KHXsbol2OiIiIiIiMAgptgyg5PsRty4/HAVfe/iaV9S3RLklEREREREY4hbZBVpidwq2fWMzumiYuu/U1yuuao12SiIiIiIiMYAptQ+CEaZn86YoT2FXTxMU3vczbxTXRLklEREREREYohbYhctKMLO7/9EmYGR/9/avc+NQmWto7ol2WiIiIiIiMMAptQ+jo/HQe/eKpnHf0BG546j3O/tXz3PfmTprbFN5ERERERKR/dJ22YfL8exX88smNrCmpJS0xxNK5uZwwLYtFUzMozEohMS4Y7RJFRERERCRKDnadttBwFzNWnT47h9NmZfPq1r08UFTCC5sq+Mdbu7uW56YlMDUrmYLMZPIzkrzb+CQm+dMKdSIiIiIiY5NC2zAyM06ekc3JM7JxzrGtsoG1u2rZubeRHVWN7NzbyGtb9lK6r5lwjwbQ7NR48jP2h7jIQDd5fBLpSXGYWXRemIiIiIiIDBmFtigxM6bnpDI9J/WAZW0dYUprm9ld08SumiZ2VTexu7aJkuomNpbV8ezGcprbwt22SYkPMnl8MtNzUpiRk8qM3BSmZ6cyPSeFtMS44XpZIiIiIiIyyBTaYlBcMEBBptdVsjfOOaoaWtlV08TuGi/M7apporiqkY2ldTy5royOiKa6vHEJzMhJZc6ENI6dnMExk9MpzEohEFDLnIiIiIhIrNNAJKNQa3uYnVWNbKmo927lDWypqGdD6b6uFrq0xBDHF2Zy8owsTp6RzdwJaQpxIiIiIiJRooFIxpj4UICZuanMzO3e9bK9I8ym8nrWltSyuriG17bu5ZkN5QBkpcSzdF4uFxw9kVNmZhMf0tUgRERERERigVraxrjdNU28smUvL7xXwTMbyqlvaSctIcQ5R+VxyfEFnDgtUwOciIiIiIgMsYO1tCm0SZeW9g5e3lzJY2tLefzdUuqa25mek8K/nzCFS0+YQmqCGmZFRERERIaCQpsMWFNrB/9cu4d73tjJyh3VpCfFsfzkQq48uZDxKfHRLk9EREREZFRRaJMj8nZxDb97djNPrisjJT7I586cyVWnTtMFv0VEREREBolCmwyKjaV1/PLJjTy5roz8jCS+eeFcPvC+iTrnTURERETkCB0stGmIQOm3ORPSuPXyxdz9qRMZlxTHF+5ezVV3FFFa2xzt0kRERERERi2FNhmwk2dk8+gXT+V7Fx3FK1sqOeeG53lgZQmx0GorIiIiIjLaKLTJYQkGjE+eOo3HrzmNeRPG8bW/vs01975FQ0t7tEsTERERERlV+hXazGy7ma01s7fMrMifl2lmK8xsk/9zvD/fzOw3ZrbZzNaY2aKhfAESXYXZKdx79RK+ft4cHl2zm2W/e5lNZXXRLktEREREZNQYSEvbmc65BREnx10HPO2cmwU87d8HuACY5d+uBm4erGIlNgUCxufPnMmfrzqRmsZWlv3uZZ58tzTaZYmIiIiIjApH0j1yGXCHP30HcHHE/Dud5zUgw8wmHsHzyAhx8sxs/vml9zMrL43P/Hkld726PdoliYiIiIiMeP0NbQ540sxWmtnV/rw859wef7oUyPOn84HiiG1L/HndmNnVZlZkZkUVFRWHUbrEorxxidzzqRM5a24u333oXa5/bAPhsAYoERERERE5XP0Nbac65xbhdX38vJmdFrnQecMGDujI3Dl3q3NusXNucU5OzkA2lRiXHB/i9x8/jo8vmcLvn9/Ct/+xVsFNREREROQwhfqzknNul/+z3Mz+DpwAlJnZROfcHr/7Y7m/+i6gIGLzyf48GUNCwQD/vexo0pPi+N2zWzAzfrzsaAIBXYhbRERERGQgDtnSZmYpZpbWOQ2cC7wDPAws91dbDjzkTz8MXO6PIrkEqI3oRiljiJnxtXPn8NkzZnD36zv53sPv6FpuIiIiIiID1J+Wtjzg72bWuf7dzrnHzexN4H4zuwrYAVzir/8v4EJgM9AIXDnoVcuIYWZce94cws5xy/NbSYkP8c0L50W7LBERERGREeOQoc05txU4tpf5e4Glvcx3wOcHpToZFcyM686fS2NLB7e8sJWJ6Ylcccq0aJclIiIiIjIi9OucNpEjZWb84EPzKd3XzA8fXceE9CTOP3pCtMsSEREREYl5R3KdNpEBCQaM31y6kAUFGXzlvrdYv2dftEsSEREREYl5Cm0yrJLig9zy8eMYlxTi6ruKqGlsjXZJIiIiIiIxTaFNhl3uuER+//HjKKtt4Yv3rKZD13ATEREREemTQptExcIp4/nRsvm8uKmSm5/bHO1yRERERERilkKbRM3Hji/gQ8dO4oanNrFyR1W0yxERERERiUkKbRI1ZsaP/+1oJmUk8qV73qK2qS3aJYmIiIiIxByFNomqcYlx/PayRZTua+a/H10X7XJERERERGKOQptE3YKCDD57+gweWFnCsxvKo12OiIiIiEhMUWiTmPDFpTOZnZfKdQ+uUTdJEREREZEICm0SExJCQX7x0WOprG/l+sfWR7scEREREZGYodAmMeOYyRlceXIh975ZzFvFNdEuR0REREQkJii0SUy55uxZ5KQm8L2H3tFFt0VEREREUGiTGJOWGMe3PzCPNSW13PPGzmiXIyIiIiISdQptEnM+dOwkTpqexc+f2Mje+pZolyMiIiIiElUKbRJzzIwfLZtPQ0s7v3jyvWiXIyIiIiISVQptEpNm5aXxiZOmct+bO9lUVhftckREREREokahTWLWF8+aRUp8iOsf2xDtUkREREREokahTWJWZko8nztzJk9vKOfVLXujXY6IiIiISFQotElMu/KUQialJ/I/j60nrEsAiIiIiMgYpNAmMS0xLshXz53DmpJaHlmzO9rliIiIiIgMO4U2iXkXL8xn3sRx/OLJjbS2h6NdjoiIiIjIsFJok5gXDBjXnj+H4qom7n59R7TLEREREREZVv0ObWYWNLPVZvaof3+amb1uZpvN7D4zi/fnJ/j3N/vLC4eodhlDzpidw5Lpmfz2mc3Ut7RHuxwRERERkWEzkJa2a4D1Efd/BtzgnJsJVANX+fOvAqr9+Tf464kcETPjG+fPZW9DK394cWu0yxERERERGTb9Cm1mNhn4APAH/74BZwEP+KvcAVzsTy/z7+MvX+qvL3JEFk4Zz/nzJ/C/L2ylsr4l2uWIiIiIiAyL/ra0/Rq4FugcBSILqHHOdfZTKwHy/el8oBjAX17rr9+NmV1tZkVmVlRRUXF41cuY87Xz5tDU1sFNz2yOdikiIiIiIsPikKHNzC4Cyp1zKwfziZ1ztzrnFjvnFufk5AzmQ8soNjM3lY8dX8BfXt9BcVVjtMsRERERERly/WlpOwX4kJltB+7F6xZ5I5BhZiF/ncnALn96F1AA4C9PB/YOYs0yxl2zdDYBM3614r1olyIiIiIiMuQOGdqcc990zk12zhUClwLPOOf+A3gW+Ii/2nLgIX/6Yf8+/vJnnHNuUKuWMW1CeiJXnjKNf7y1i3W790W7HBERERGRIXUk12n7BvBfZrYZ75y12/z5twFZ/vz/Aq47shJFDvTZ02eQlhDi/z6xIdqliIiIiIgMqdChV9nPOfcc8Jw/vRU4oZd1moGPDkJtIn1KT47jc2fO5PrHNvDa1r0smX7AWDciIiIiIqPCkbS0iUTVFScXMmFcItc/tgH1wBURERGR0UqhTUasxLggXzlnFm8V1/DEu2XRLkdEREREZEgotMmI9n8WTWZGTgo/f2ID7R3hQ28gIiIiIjLCKLTJiBYKBvj6eXPZUtHA31aVRLscEREREZFBp9AmI9558/NYOCWDG1ZsoqGlPdrliIiIiIgMKoU2GfHMjO984ChK9zVz49Obol2OiIiIiMigUmiTUeG4qeO57IQCbntpG+v36ILbIiIiIjJ6KLTJqPGN8+eSkRTHt/6+lnBYlwAQERERkdFBoU1GjYzkeL79gXms3lnD7a9sj3Y5IiIiIiKDQqFNRpV/W5jP0rm5XP/4BjaUqpukiIiIiIx8Cm0yqpgZP/vIMYxLjOOae96iua0j2iWJiIiIiBwRhTYZdbJTE/jFR49hY1kdP3xkHc7p/DYRERERGbkU2mRUOmNOLp85fQb3vLFT57eJiIiIyIgWinYBIkPl2vPmsLWinv9+dB2FWSmcOTc32iWJiIiIiAyYWtpk1AoEjF9fuoB5E8fxub+s4uXNldEuSURERERkwBTaZFRLjg9x+5UnMCUzmStvf5NnN5ZHuyQRERERkQFRaJNRLyctgXuuXsKs3FSuvrOIu1/fqcFJRERERGTEUGiTMSEzJZ67P7WEk2Zk862/r+Ubf1ujywGIiIiIyIig0CZjRnpSHH+64ni+dNZM7i8q4cIbX+TN7VXRLktERERE5KAU2mRMCQaM/zp3Dn++6kRaO8JccsurfOcfa6lqaI12aSIiIiIivVJokzHp1FnZPPHl07ji5ELueaOYM37+LLe9tI22jnC0SxMRERER6UahTcaslIQQ3//gfB675v0cW5DBfz+6jvN+/QLPbCiLdmkiIiIiIl0OGdrMLNHM3jCzt83sXTP7oT9/mpm9bmabzew+M4v35yf49zf7ywuH+DWIHJHZeWnc+ckTuG35YpyDT95exCdvf5PtlQ3RLk1EREREpF8tbS3AWc65Y4EFwPlmtgT4GXCDc24mUA1c5a9/FVDtz7/BX08kppkZS+fl8cSXT+PbF87jjW1VnHvDC/z8iQ00trZHuzwRERERGcMOGdqcp96/G+ffHHAW8IA//w7gYn96mX8ff/lSM7PBKlhkKMWHAnzqtOk889XTueiYifzu2S0s/eXz/HPNHl3bTURERESiol/ntJlZ0MzeAsqBFcAWoMY519kEUQLk+9P5QDGAv7wWyOrlMa82syIzK6qoqDiiFyEy2HLHJfKrjy3gr585iYzkeD5/9yqW/0ldJkVERERk+PUrtDnnOpxzC4DJwAnA3CN9Yufcrc65xc65xTk5OUf6cCJD4vjCTB75wil8/4NHsWpHNef++gVuWPGeLswtIiIiIsNmQKNHOudqgGeBk4AMMwv5iyYDu/zpXUABgL88Hdg7GMWKREMoGODKU6bxzFdP57z5E7jx6U2c9+sXeP49tRCLiIiIyNDrz+iROWaW4U8nAecA6/HC20f81ZYDD/nTD/v38Zc/43QykIwCueMS+e1lC/nzVScSNGP5H9/gM3etZHN5/aE3FhERERE5THaoPGVmx+ANLBLEC3n3O+d+ZGbTgXuBTGA18HHnXIuZJQJ3AQuBKuBS59zWgz3H4sWLXVFR0RG/GJHh0tLewa3Pb+Xm57fQ3NbBxQvz+fLS2UzJSo52aSIiIiIyApnZSufc4l6XxUIjmEKbjFR761u45YWt3PHKdjrCjg8vyufq02YwMzc12qWJiIiIyAii0CYyxMr3NfP/ntvCvW/upLktzNnz8vj06dNZPHU8uuKFiIiIiByKQpvIMNlb38Kdr+7gzle3U93YxsIpGXz6tBmcc1QewYDCm4iIiIj0TqFNZJg1tXbw15XF/OHFbeysaqQwK5nlJxfykeMmk5YYF+3yRERERCTGKLSJRElH2PH4O6X88eVtrNxRTUp8kI8uLuDyk6YyPUfnvYmIiIiIR6FNJAa8XVzDHa9s55E1u2nrcJwxJ4flJxVy2uwcdZ0UERERGeMU2kRiSHldM/e8XsyfX99BRV0LE8Yl8uFF+XzkuMlqfRMREREZoxTaRGJQa3uYp9eX8deVJTy3sZywg8VTx/Nvi/K54OiJZKbER7tEERERERkmCm0iMa58XzMPrt7FAytL2FxeTzBgnDIzm4uOmch58yeQnqTBS0RERERGM4U2kRHCOcf6PXU8umY3j6zZTXFVE/HBAKfNzuaiYyZx1rxcxmn0SREREZFRR6FNZARyzrGmpJZH3t7NP9fuYU9tM3FB46QZ2ZxzVB7nzMtjQnrikD23LgouIiIiMnwU2kRGuHDYsWpnNU+uK+PJd0vZvrcRgGMLMjj3qDzOPSqPmbmpRxy0dtc08YHfvEgwEODnHzmG98/KJhQMDMZLEBEREZGDUGgTGUWcc2wur/cC3Loy3i6uAaAwK5kz5uRyxpwclkzPIjEuOODHfuLdUj5918qu++MSQ7x/dg5nz8vl5BnZ5KYlqAVOREREZAgcLLSFhrsYETkyZsasvDRm5aXx+TNnUlrbzFPry3hqfRn3vLGT21/ZTkIowEkzsjhjdg5nzMmlMDulX49dXOW14L183Vm8XVzDsxvKee69Cv65Zg8A6UlxnDEnh/PmT+C4qeMV4kRERESGgVraREaR5rYOXtu6l+c2VvD8exVsq2wAYFp2CqfPzjlkK9wPHn6XB1aWsPYH53aFsXDYsbq4hnd317KmpJan15dR3dgGQEZyHIumjGfpvFxOnJbF9OwUArpQuIiIiMiAqaVNZIxIjAv6XSRzAdhe2cBzG73Wss5WuMS4AEumZ/H+WTm8f1Y2syLOhSupbmTy+KRurWeBgHHc1PEcN3U8AO0dYVYX17Bu9z42lO7jpc2VPLOhHPBa4hZNyeC4qeM5vjCTYwsyDqubpoiIiIjsp5Y2kTGiua2DV7fu5fmNFbywqYKtFV4rXN64hK4Ad+NTm5iek8oflvf6T55eOefYUtHAqh3VrNxRzaqd1WwqrwcgPhhgQUEGx08bzwnTsjhu6nhSE/S/IhEREZGeNBCJiBygpLqRlzZV8uKmSl7aXEltk9fl8cpTCvn+B+cf0WPXNLZStL2aN7ZX8fq2Kt7ZVUtH2BEwmD8pnROmZXLCtEyOL8wkMyV+MF6OiIiIyIim0CYiB9URdryzq5Y3t1dx7lETmJKVPKiP39DSzuqdNbyxbS9vbK9i9c4aWtrDAMzOS+0KcCdOyxqya8+JiIiIxDKFNhGJKS3tHawtqeWN7VW8sa2Kou3V1Le0AzAlM7mrJe7EaZlMyUzWCJUiIiIy6im0iUhMa+8Is6G0jte3VXmtcduqukaozE1L4PjCTBYXeoObzJs4jqBGqBQREZFRRqFNREaUcNixpaKe17dVUbS9ije3V7OrpgmA1IQQC6dkcHyh16VyQUEGSfEaoVJERERGNoU2ERnxdtU0+QHO6065sawO5yAUMI7O9wY3WTx1PIs1uImIiIiMQAptIjLq1Da2sXKn1wr35rYq1pTU0trhDW4yIyeFE6ZlsmjKeBZOGa+LfouIiEjMO6LQZmYFwJ1AHuCAW51zN5pZJnAfUAhsBy5xzlWbN2LAjcCFQCNwhXNu1cGeQ6FNRI5Uc1sHa3fV+gObVFG0o5q6Zm9wk/SkOBZOyWBhwXgWTc3g2IIMxiXGRbliERERkf0OFtr6c5XbduCrzrlVZpYGrDSzFcAVwNPOuevN7DrgOuAbwAXALP92InCz/1NEZMgkxgW7znOD/efFrdpZzeqdNazaWc3z71XgHJjBrNzUrhC3cMp4ZuakqjVOREREYtKAu0ea2UPATf7tDOfcHjObCDznnJtjZrf40/f462/sXK+vx1RLm4gMh33NbbxdXMOqHTWsLvbCXOdFxdMSQywo8ALcoikZLCjIICNZ58aJiIjI8DjSlrbIByoEFgKvA3kRQawUr/skQD5QHLFZiT+vW2gzs6uBqwGmTJkykDJERA7LuMQ43j8rh/fPygG81rhtextYtaOa1cU1rNpRzU3PbCLs/y9ralYy78tP55jJ6RwzOYOj89NJTRjQ16aIiIjIEev30YeZpQJ/A77snNsXebFb55wzswE12TnnbgVuBa+lbSDbiogMhkDAmJGTyoycVD66uACA+pZ21hTX8FZJDWtLalm9s4ZH13j/czKD6dkpHDs5g/dN9sLcURPTdckBERERGVL9Cm1mFocX2P7inHvQn11mZhMjukeW+/N3AQURm0/254mIxLzUhBAnz8zm5JnZXfMq61tYu6uWNcW1rN1Vw4ubK3lwtfe1FgwYs3JTu1rjjpmczuy8NBLjFORERERkcBwytPmjQd4GrHfO/Spi0cPAcuB6/+dDEfO/YGb34g1AUnuw89lERGJddmoCZ87J5cw5uV3zSmubWVNSw9pdtbxdUsuKdWXcX1QCeEFuZk4q8yeN46jO28RxOkdOREREDkt/hvw/FXgRWAuE/dnfwjuv7X5gCrADb8j/Kj/k3QScjzfk/5XOuYOOMqKBSERkpHPOUVLdxNpdtazbvY93d9eybs8+yva1dK2Tn5HUFeA6A11+RhKR3c1FRERkbNLFtUVEoqSyvoV1u/exbs8+3t29j3W7a9la2UDnV296UhxHTfQC3PxJ45g7YRwzclNICKl7pYiIyFgyaKNHiojIwGSnJnDa7BxOm53TNa+xtZ0NpXXdwtyfX9tBS7vXmSEYMKZlpzBnQhpz89KYM8G7FYxP1rXkRERExiCFNhGRYZYcH2LRlPEsmjK+a157R5htlQ1sKK3jvbI6NpTWsbakln+u2ROxXZBZeV6Qmz0hjbl+mMtOTYjGyxAREZFhotAmIhIDQsEAs/LSmJWX1m1+Q0s7m8rr2Vi6jw2ldWwsreOp9WXcV7T/cpgZyXGkxIeIDwVYWJDBkhlZZCbHc/y0TJLjg8QFA8P9ckacW1/YQkZSPJccX3DolUVERIaZQpuISAxLSQixoCCDBQUZ3eZX1rewsdRrkdtWWU9jawdNrR2sWF/WdTkCgLigcdSkdCaPT2JyRhL545PIj/iZlhg3zK8oNv30XxsAFNpERCQmKbSJiIxA2akJZM9M4JSI68kBNLd1UFLdREVdC6uLq6luaGX9Hu/8uRXrymhtD3dbf1xiiPzxyeRnJDG5R6DLH59EVkq8RrcUERGJMoU2EZFRJDEuyMzcVGbmpnLSjKxuy8JhR2VDC7uqm9hV09TtZ0l1I69v3UtdS3uPxwswKSOpK9RNGJfExIxEJqYnMjE9iYnpiaQkjOw/JS3tHdEuQURE5KBG9l9aERHpt0DAyE1LJDctkYURg6BEqm1qiwhzjd5PP9it2LOPyvrWA7ZJSwwxKT2JCemJTMpI9IJdemJXuJuQnkRqDAe7qoYDX5OIiEgsid2/oiIiMuzSk+K8a8dNGtfr8pb2Dsr3tbC7ponSfc3srmmmtLaJPbXN7Klt5t3d+6isbzlgu9SEEHnjEpiQnkjeuEQmjEs8YDo7NYFgFC5pUFmn0CZj00ubKrn8j69T9J1zyEyJj3Y5InIQCm0iItJvCaEgBZnJFGQm97lOZLDbU9tM6b5mSmubKdvnTb+2ZS/ldS20h1237YIBIyc1gbz0RCaMS/DDXBIT0hO6hbvk+MH901XZsD9klu9rZnN5PSf3OFdQZDS67aWthB28ub2K8+ZPiHY5InIQCm0iIjKo+hPsOs+vK6tt8ULdvmbK/IBXtq+ZrRUNvLJlL3XN7Qdsm5YY6gpwOWkJZKcmkJUST1ZqAlmp8aQlhGhtDxN2sHBKxiHPudsb0eXz8j++wYbSOlZ9Vy0PMvpl+dd4VBdhkdin0CYiIsMu8vy695He53oNLe3dAl336Ra2VjRQWd9CS49RMTslhAIcNWkcSXFB0pPiKMxOAWBKZjIF45NJTgiyqayua/0Npd70q1v28oFjJg7iKxaJPRlJ3iU/9tQ0RbkSETkUhTYREYlZKQkhZuSkMiMntc91nHM0tHawt76FyvpW6lvaSQgFaG0P88yGcjaX19PS3sG7u/fx1PoynOOArpk9vbS5QqFNRr3Of3Zs39uIc06X9xCJYQptIiIyopkZqQkhUhNCTM1K6bbstNk5B6wfDjt21TSxu6aJxrYOKuta2Fxezy0vbO1aZ/XOGgA6wg7nHKFgYEhfg0g0NPiX+Hj47d2MT47jQwsmsWpHDf/5/mkKcCIxRqFNRETGlEDADjjnbvXO6q7QNjUrma0VDbR3hPnC3aspqWnkwc+eQnxIwU1Gl8jrMt7x6g7ueHUHAMkJQf7jxKnRKktEeqG/QCIiMualJe7/H+ayBfm0doT5x1u7efzdUt7ZtY+7XtsRxepEhkZ9czsnFGby4rVnkhAKdLVY3/zcFlrbw3zzwbVc+8DbNLYeOCCQiAwvhTYRERnzUhPiuqbPnpcLwI//uY7UhBBzJ6Tx16JiAG59YQvX3LuaLRX1UalTZDDVt7STkuCN9vr6t5ay8rtn86Nl8ympbuLi373MPW/s5P6iEj5910oq6lpwztHWEebd3bWU7WuOdvkiY4q6R4qIyJjX2dI2MT2RmbneoCc1jW1cdMxEjps6nh8+so5fPbmR3zyzGYB1u/fx8BdOpbi6kRXryjhrbi7zJvZ+QXKRWNXQ0t41ompGsneJi/PmTyAjeR3r9uzjE0um8r7J6Vz7wBqO/8lTBAziggFa2sPEBwNcfdp0ZuWl8ub2KhJCQebkpTFnQhoT0hMJmFHf0k7euIRBv7aiyFikT5GIiIx5KQkhbrx0ASdNzyI5PsTlJ03lzld3cN78CSyZnsXPHt/Ab57ZzLEFGXzl7FlcefubXHDjC5RUN9Eedtyw4j2+eu4cAgYvba5kalYyy08qJCctgY2ldWQkxzMzN5VgQIM7SOyoa2kntcd1DFMSQjzyhVP5++pdfPLUaaQmhHhffjovb66kprGNprYO5k8ax7MbK7jpWe+fGGkJIdrCYZrbDrz0RsCgMCuFQMBo6wgTF9zfDTMlIUgwYATMSIkPkRQfJD4UIBQw4oIB4kMB4oJGKBAgLhQgLmCEgt7yYMAI9XF//7S3bX/vB/T5lBhmzh182OPhsHjxYldUVBTtMkRERABvhMm1u2o5ZnI6ZsYrmyu567Ud/GjZ0eSkJfC/L2zld89t5qTpWVx7/ly+99A7vLipEoDZeals39tIa49rxyXHB8lMiWdvfStZqfEUjE8mMS5Ae9gRDBg5qQkkxQe7DiDjggFCwQDxwf3TQYNg50Gq+Qeqwf0HrMFAoOvgM/JA1vsZ8A9SexzoBrsvCwaMuEBAB7BjwLzvPs4nTprKty6cd1jbby6vp7apjQUFGQAUVzWyobSOvQ0tdIQdKfEhdlY18l5ZHea30rW2h6lvaae+pZ3Glg7CztEedjS2ttPU2kFbh9cF81CX5RgKAaPXz4n3+dv/2QgF9y+LXB6K+JzF+eExLnjgZ68zfB64vHNb77PcFVIjHjsu8rn9beKD+58zruunN61RQEcWM1vpnFvc6zKFNhERkSPjnBfy0pPimJqVQmV9C4+t3UNTWwez8tKobmjl7eIaapvayExJYG9DCyXVTbS2hwkFvRaIyrpWWto7aA872jsc7eEwbR3R+xttxv6D0K5wGHmw2DP87T+IDUUcdAYD+w9guy0L9t5a0m3dgBEM9vU4kXX1ODAOdD/IDgb3P17kss7nHIsHtu0dYWZ++zH+65zZfGnprGiXc4Bw2AtzbR1h2jrCtHaEae9wdPjz2/1gd1j3+1jmfe4i7oe952zrcHSEw7R1Lutcz/+Mdj5mm/+57Xqsrsfxg6i/bls4zHAdfnd+RiKDXCjQeytmZ/iM99ftnO65feQ/lEI95ndbJ2D+Y/vzOqdDB4bNnsFzLH4m4eChTd0jRUREjpCZcczkjK772akJfOKkwm7rfHjR5AE/rnOu6wCwtSPcdSDbEXZ0OO+gsCPi4HP/T/8A0XUeUO6f39bLge+BB7H7Dz47/ANO76f/2P46kcsiD3Jb2sK0hTu66ujX4wzjgWxPPbvKxfUIgl3hr49l3UNkZACNWDdgBIM9H+fAEBwM7D9QjmxZiWx1iQt0P8Dta3noIIG0oaUD8LpDxqJAwIgP2Ki91EZHxOelvcPb/zsiwl1k+Iv83LRHfKbaO/YHyTY/NLb1mO7cvjVien9rZpjW9v3Tbe2O5vYOb512r6a2XrbrnB5KkWGzt+DYM2zGhfb3Sui8xYd63A/uX3dieiLLFuQP6WsYbLH5SRURERHMzP+vNSQRjHY5Qy4cGf4iwmNfwbNbYOxxQNvtoLhbAO09vPYWJrueL+zo6CXItnWEaWzt5XHCYX/9ni03ritkD5fI7ntdrR/BAJ1ZLjVh9O9XsSgYMIKBIDGamQ+p5z+Uuoe7AwNee4/g2C1Ehh1t7eGulsvWiOmux+gKkb2H1Mamjm7zW9vD3Z6/tSPcrcv6oikZoy+0mdkfgYuAcufc0f68TOA+oBDYDlzinKs27985NwIXAo3AFc65VUNTuoiIiIwmXa0ro/yKROHw/lbQ3lpROs/pimx1aYtokYk8OO6+vHO7iAPbnl33OvY/1wmFmZw2Oyfab4eMQCPxH0rO7f/HTTgGTg8bqP7k+9uBm4A7I+ZdBzztnLvezK7z738DuACY5d9OBG72f4qIiIgIXjgN4B3wJsaNjANekZHOzO+KPEI/cof8V5Zz7gWgqsfsZcAd/vQdwMUR8+90nteADDObOEi1ioiIiIiIjDmH2/8gzzm3x58uBfL86XygOGK9En/eAczsajMrMrOiioqKwyxDRERERERkdDviTuPOu2bAgDuGOududc4tds4tzslRf2oREREREZHeHG5oK+vs9uj/LPfn7wIKItab7M8TERERERGRw3C4oe1hYLk/vRx4KGL+5eZZAtRGdKMUERERERGRAerPkP/3AGcA2WZWAnwfuB6438yuAnYAl/ir/wtvuP/NeEP+XzkENYuIiIiIiIwZhwxtzrnL+li0tJd1HfD5Iy1KREREREREPKP76pUiIiIiIiIjnLkYuCK4mVXgdbOMNdlAZbSLkFFL+5cMNe1jMpS0f8lQ0v4lQy0W97Gpzrleh9WPidAWq8ysyDm3ONp1yOik/UuGmvYxGUrav2Qoaf+SoTbS9jF1jxQREREREYlhCm0iIiIiIiIxTKHt4G6NdgEyqmn/kqGmfUyGkvYvGUrav2Sojah9TOe0iYiIiIiIxDC1tImIiIiIiMQwhTYREREREZEYptDWBzM738w2mtlmM7su2vXIyGNmBWb2rJmtM7N3zewaf36mma0ws03+z/H+fDOz3/j73BozWxTdVyAjgZkFzWy1mT3q359mZq/7+9F9Zhbvz0/w72/2lxdGtXCJeWaWYWYPmNkGM1tvZifp+0sGk5l9xf/7+I6Z3WNmifoOk8NlZn80s3Izeydi3oC/s8xsub/+JjNbHo3X0huFtl6YWRD4HXABcBRwmZkdFd2qZARqB77qnDsKWAJ83t+PrgOeds7NAp7274O3v83yb1cDNw9/yTICXQOsj7j/M+AG59xMoBq4yp9/FVDtz7/BX0/kYG4EHnfOzQWOxdvP9P0lg8LM8oEvAYudc0cDQeBS9B0mh+924Pwe8wb0nWVmmcD3gROBE4Dvdwa9aFNo690JwGbn3FbnXCtwL7AsyjXJCOOc2+OcW+VP1+Ed8OTj7Ut3+KvdAVzsTy8D7nSe14AMM5s4vFXLSGJmk4EPAH/w7xtwFvCAv0rP/atzv3sAWOqvL3IAM0sHTgNuA3DOtTrnatD3lwyuEJBkZiEgGdiDvsPkMDnnXgCqeswe6HfWecAK51yVc64aWMGBQTAqFNp6lw8UR9wv8eeJHBa/G8dC4HUgzzm3x19UCuT509rvZKB+DVwLhP37WUCNc67dvx+5D3XtX/7yWn99kd5MAyqAP/ndb/9gZino+0sGiXNuF/ALYCdeWKsFVqLvMBlcA/3OitnvMoU2kSFmZqnA34AvO+f2RS5z3jU3dN0NGTAzuwgod86tjHYtMiqFgEXAzc65hUAD+7sVAfr+kiPjdzlbhvcPgklACjHSoiGj00j/zlJo690uoCDi/mR/nsiAmFkcXmD7i3PuQX92WWe3If9nuT9f+50MxCnAh8xsO14X7rPwzkHK8LsaQfd9qGv/8penA3uHs2AZUUqAEufc6/79B/BCnL6/ZLCcDWxzzlU459qAB/G+1/QdJoNpoN9ZMftdptDWuzeBWf4IRvF4J8Y+HOWaZITx+9rfBqx3zv0qYtHDQOdoRMuBhyLmX+6PaLQEqI1o0hfpxjn3TefcZOdcId531DPOuf8AngU+4q/Wc//q3O8+4q8/Yv/jKEPLOVcKFJvZHH/WUmAd+v6SwbMTWGJmyf7fy859TN9hMpgG+p31BHCumY33W4PP9edFnWl/752ZXYh3vkgQ+KNz7ifRrUhGGjM7FXgRWMv+c46+hXde2/3AFGAHcIlzrsr/o3UTXveQRuBK51zRsBcuI46ZnQF8zTl3kZlNx2t5ywRWAx93zrWYWSJwF965lVXApc65rVEqWUYAM1uAN8hNPLAVuBLvn736/pJBYWY/BD6GN9ryauA/8c4f0neYDJiZ3QOcAWQDZXijQP6DAX5nmdkn8Y7XAH7inPvTML6MPim0iYiIiIiIxDB1jxQREREREYlhCm0iIiIiIiIxTKFNREREREQkhim0iYiIiIiIxDCFNhERERERkRim0CYiIiIiIhLDFNpERERERERi2P8HhB7NOwT9qMEAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 6.32 s, sys: 164 ms, total: 6.48 s\n", + "Wall time: 6.57 s\n" + ] + } + ], + "source": [ + "%%time\n", + "parameters, nelbo = param_optim(\n", + " jax.random.PRNGKey(0),\n", + " initial_parameters,\n", + " logpullback,\n", + " learning_rate=0.01,\n", + " n_iter=1000,\n", + " n_atoms=1000,\n", + " n_epochs=4,\n", + ")\n", + "plt.figure(figsize=(15, 4))\n", + "plt.title(\"Negative ELBO (KL divergence) over iterations\")\n", + "plt.plot(nelbo)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We define our log pullback given the learned parameters of the transformation and use the periodic orbital MCMC with an ellipsis to sample from this log pullback density." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "logpullback_fn = lambda x1, x2: logpullback(parameters, jnp.array([x1, x2]))\n", + "logpull = lambda z: logpullback_fn(**z)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 1.54 s, sys: 1.96 ms, total: 1.54 s\n", + "Wall time: 1.61 s\n" + ] + } + ], + "source": [ + "%%time\n", + "init_fn, ellip_kernel = orbital(\n", + " logpull, step_size, inv_mass_matrix, period, bijection=elliptical_bijection\n", + ")\n", + "initial_state = init_fn(initial_position)\n", + "ellip_kernel = jax.jit(ellip_kernel)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/cabezasg/.local/lib/python3.8/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n", + " warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2.87 s, sys: 21.7 ms, total: 2.89 s\n", + "Wall time: 3.07 s\n" + ] + } + ], + "source": [ + "%%time\n", + "rng_key = jax.random.PRNGKey(0)\n", + "states = inference_loop(rng_key, ellip_kernel, initial_state, 10_000)\n", + "\n", + "pullback_samples = states.positions\n", + "weights = states.weights" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We need to push the samples through the learned MAF transformation to have samples from the target density (banana) and not the pullback." + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "def push_samples(z1, z2):\n", + " z = jnp.array([z1, z2])\n", + " mean, log_sd = apply_fun(parameters, z)\n", + " x = jnp.exp(log_sd) * z + mean\n", + " return x[0], x[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "samplesx1, samplesx2 = jax.vmap(jax.vmap(push_samples))(\n", + " pullback_samples[\"x1\"], pullback_samples[\"x2\"]\n", + ")\n", + "samples = {\"x1\": samplesx1, \"x2\": samplesx2}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The pushed samples are much better at targeting the banana density than the algorithm without a preconditioning step. The transformation helps the sampler stay close to the same density level when moving around the ellipsis, thus reducing the variance of the step's weights along it. This preconditioning serves, in a way, as an adaptive step that tunes the parameters of the sampler through a transformation. Notice that if we move around the whole ellipsis there are no tuning parameters, only the number of samples we choose to extract at each iteration, in contrast with choosing step sizes and number of steps in the case of the other numerical integrators. Of course, we still need to choose a gradient descent algorithm, learning rates, number of iterations, and epochs for the optimization!" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA+IAAAF1CAYAAABs5lCZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOzdeXxc1Xn4/8+5984qjXZZsi3vG8Zms1lM2BMgkJINsqdN02b/pfk13/bXb9O0pS1t+qVLvk2XtNmbJmVJAyRkAQIkgCFgAzY2eF9kyZatfRvNfpfz++POjEfSyJZs2fLyvF8vv2zN3Ln3jGDunOec5zxHaa0RQgghhBBCCCHE6WHMdAOEEEIIIYQQQojziQTiQgghhBBCCCHEaSSBuBBCCCGEEEIIcRpJIC6EEEIIIYQQQpxGEogLIYQQQgghhBCnkQTiQgghhBBCCCHEaSSBuBAzTCn1YaXUk5M89qNKqRdOYVtO6fmng1KqTSl180y3QwghhDgR8r0/NfK9L85VEoiL85pSSiullo557C+VUv99utqgtb5Pa33rdJxLKfWsUurj03EuIYQQQviUUn+ilHp8zGN7J3jsA8c6l3zvCyFAAnEhhBBCCCGOZz3wJqWUCaCUmg0EgMvGPLY0f6wQQhyTBOJCHINS6kalVIdS6g+VUj1KqU6l1O/kn1uklBpSShn5n7+plOopee33lVKfz/+7Win17fzrDyul/qbki3tUWphS6lal1G6l1LBS6t+VUs+NHe1WSv2jUmpQKXVAKXV7/rEvAdcB/6aUSiil/i3/+AVKqaeUUgP5876v5Dz1SqmfKKXiSqmXgSXH+F2ElVL/rZTqz7/vV5RSTfnnfkcptVMpNaKUalVKfarM7/B/l/wO36WUeptSak++XV8sOf4vlVIPKaV+kD/fZqXUJRO0yVBKfUEptT/frv9RStUdr71CCCHEFL2CH3hfmv/5OuAZYPeYx/ZrrY/I97587wtxPBKIC3F8zUA1MBf4GPBVpVSt1voAEAcuyx93PZBQSq3M/3wD8Fz+398FHPyR8suAW4FxqWRKqQbgIeBPgHr8L/g3jTnsqvzjDcDfA99WSimt9Z8CzwO/p7Wu1Fr/nlKqAngKuB+YBXwA+Hel1IX5c30VyACzgd/N/5nIb+d/D/Pybfs0kM4/1wPcAVQBvwP8k1JqTclrm4Ew/u/wbuCbwG8Ca/E7EX+ulFpUcvw7gR8Cdfm2/1gpFSjTps8B78L/Xc8BBvPv6XjtFUIIISZNa50DNuJ/15P/+3nghTGPFWbDv4t878v3vhDHIIG4EMdnA/dorW2t9WNAAliRf+454AalVHP+54fyPy/C/3Lamh+NfRvwea11UmvdA/wT/pfjWG8DtmutH9FaO8C/AF1jjmnXWn9Ta+0C/4X/ZTrRiO8dQJvW+j+11o7W+jXgYeC9+ZH5u4C78+3alj/fsX4P9cBSrbWrtd6ktY4DaK1/rrXer33PAU/if9GWvvZLWmsbeBC/M/HPWusRrfV2YAdQOvq9SWv9UP74/4v/Zb6uTJs+Dfyp1rpDa50F/hJ4j1LKOlZ7hRBCiBPwHEeD7uvwg+Dnxzz2nHzvF18r3/tCHIM10w0QYoa5+KlmpQL4N/OC/vyXY0EKqMz/+zngHUAH/ij4s8Bv4Y82P6+19pRSC/Ln7FRKFc5hAIfKtGdO6eNaa62U6hhzTFfJ86n8OSspbwFwlVJqqOQxC/g+0Jj/d2k72ic4D/nXzAMeVErVAP+N/2Vo59Pk/gJYnn9vUeCNktf25zsQcHR0urvk+fSY91D6O/Dyv4M5E7y/HymlvJLHXPwOyoTtPcZ7FEIIISayHvhsPhW6UWu9VynVDfxX/rHV+WPke1++94U4LpkRF+e7g8DCMY8t4thfTKWewx8BvjH/7xeAaxidln4IyAINWuua/J8qrfWqMufrBFoKPyj/27alzHET0WN+PgQ8V3Ldmnz62meAXvy0uXklx8+f8MR+RsBfaa0vxE+buwP4iFIqhD/a/o9Ak9a6BngMUBOdaxKKbVL+GvwW4EiZ4w4Bt495f2Gt9eGJ2nsSbRJCCHF+ewk/9fkTwK8B8jOuR/KPHckvW5Pv/amT731x3pFAXJzvfgD8mVKqJV8A5Gbg7fgp5seltd6LP6r7m/hffHH8Ed+7yAfiWutO/JStLyulqvLXWaKUuqHMKX8OXJQvamIBn8VfZzVZ3cDikp9/BixXSv2WUiqQ/3OFUmplfqT6EeAvlVLR/Pqx357oxEqpm5RSF+VT2+L4WQMeEARC5L/g86PkJ7sty1ql1J3538Hn8Ts0G8oc9zXgS/msA5RSjUqpdx6nvUIIIcSUaa3TwKvAH+CnpBe8kH9sff44+d6fOvneF+cdCcTF+e4e4EX8L9FB/CIoH86vm5qs5/BTsA6V/KyAzSXHfAT/i2tH/joP4a/xGkVr3Qe8N9+OfuBC/C/97CTb8s/4a6UGlVL/orUewf9y/AD+yHIX8Hf4X6AAv4efGtaFX1jmP49x7uZ8u+PAzvz7/H7+Gv8v8D/59/Yh4CeTbO9EHgXenz/fbwF3TpBa9s/5az2plBrB/9K+6ljtPcl2CSGEOL89h18E7YWSx57PP1a6bZl870+NfO+L847SemxGixDiTJFPz+rAHxx4Zqbbczoopf4Sv9DKb850W4QQQojTSb73hTh/yIy4EGcYpdRblVI1+TVYX8SfXS+XniWEEEKIs5x87wtxfpJAXIgzz9XAfqAPf736u/Lr0oQQQghx7pHvfSHOQ5KaLoQQQgghhBBCnEYyIy6EEEIIIYQQQpxGEogLIYQQQgghhBCnkTXTDTiWhoYGvXDhwpluhhBCCDFpmzZt6tNaN850O8410icQQghxtjlWn+CMDsQXLlzIq6++OtPNEEIIISZNKdU+0204F0mfQAghxNnmWH0CSU0XQgghhBBCCCFOIwnEhRBCCCGEEEKI00gCcSGEEEIIIYQQ4jSSQFwIIYQQQgghhDiNJBAXQgghhBBCCCFOIwnEhRBCCCGEEEKI00gCcSGEEEIIIYQQ4jSSQFwIIYQQQgghhDiNJBAXQgghhBBCCCFOIwnEhRBCCCGEEEKI00gCcSGEEEIIIYQQ4jSSQPwEbWof5KvP7GNT++BMN0UIIYQQQgghxFnEmukGnI02tQ/y4W9tIOd4BC2D+z6+jrULame6WUIIIYQQQgghzgIyI34CNrT2k3M8PA2247GhtX+mmySEEEIIIYQQ4iwhgfgJWLe4nqBlYCoIWAbrFtfPdJOEEEIIIYQQQpwlJDX9BKxdUMt9H1/HhtZ+1i2ul7R0IYQQQkyrnUeGAVg5p3qGWyKEEOJUkED8BK1dUDvtAfjnH3yNZ/f0cuPyRr7ygcum9dxCCCGEODvc/s/r2dk5AsAFzZU88fkbZrhFQgghppsE4meIzz/4Gj/ecgSg+LcE40IIIcT5pxCEA+zqSsxgS4QQQpwqskb8FDiRrc2e3dN7zJ+FEEIIIYQQQpwbJBCfZoWtzb785G4+/K0Nkw7Gb1zeeMyfhRBCCHF++MS1C8v+WwghxLlDUtOnWbmtzSazlryQhi5rxIUQQojz25/esYo/vWPVTDdDCCHEKSSB+DQrbG1mO96UtzaT4FsIIYQQQgghzn0SiE8z2dpMCCGEECdKa03O9QgYBoahZro5QgghThEJxE+BqW5ttql9UAJ3IYQQ4jznepqOwRTpnEvQMphfF8UypZyPEEKciyQQn2GF4m45xyNoGdz38XUSjAshhBDnobTtkso5xMIBRrIOiYxDynYxDcWsWAilZIZcCCHOFRKIz7Byxd0Kj8sMuRBCCHH+sAwFKLK2i/Y0Gw/0s7c3iQKuWlzH5QvqZrqJQgghpokE4jNs3eJ6TEPhuRrDUNRGgzJDLoQQQpyHwgGTkGXwRscQTbEQ+3qTzK0J47ia7YfjEogLIcQ5RALxGba7awTb1QDYrubZ3T0ntP2ZEEIIIc5uiYzDT7ceIed6bD8SJxYyOTKUwdOaVbOrZrp5QgghppEE4jPs8W2do37ujmdOePszIYQQQpy9htI5srbHvPoIPfEsC+qjLGiowDQMljfFZrp5QgghppEE4jPs9tWzeX5vX/Hn918xnxXNMVkjLoQQQpxnGipC1MUCtA+kMLXikvm1LKivmOlmCSGEOAUkEJ9hH7pqPuDPjN++enbxZwnAhRBCiPNLOGjy4SsXcGQoQ1UkQGMsNNNNEkIIcYpIIH4G+NBV84sBuBBCCCHOX5GgxZJZlQBorRnJOgDEQpZsXyaEEOcQY7IHKqW+o5TqUUptK3nsH5RSu5RSryulfqSUqpngtW1KqTeUUluUUq9OQ7vPaZvaB/nqM/vY1D44000RQgghypJ+wal3aCDF9sPDbD88zKGB1Ew3RwghxDSadCAOfBe4bcxjTwGrtdYXA3uAPznG62/SWl+qtb58ak08v2xqH+TD39rAl5/czYe/tYFN7YMSmAshzkty7zvjfRfpF5xSXfEMNdEgNdEgXfEMjuux5eAgLx/oJ51zZ7p5QgghTsKkU9O11uuVUgvHPPZkyY8bgPdMU7vOWxta+0dtX/bw5g4e2dwh+4oLIc4rhUFJufeduaRfcOrVVQTpHckC0BgL8cS2Ll7c349SsOXQEJ+4brGkqwshxFlqKjPix/O7wOMTPKeBJ5VSm5RSn5zGa55VJjO7s25xPUHLwFQQsAwUjNtXXAghznVjByXl3ndWkn7BSVrUUMnSWf6fRQ2V7O1J0FQdYn5dhEODabKON9NNFEIIcYKmpVibUupPAQe4b4JDrtVaH1ZKzQKeUkrt0lqvn+BcnwQ+CTB//rlTwGyysztrF9Ry38fXFbcvA3h4c4fsKy6EOK8UBiXl3nd2mq5+wbnaJ5gs01A0xsLFny+bX8PTO7sBuHB2FeGAOVNNE0IIcZJOOhBXSn0UuAN4i9ZalztGa304/3ePUupHwJVA2UBca/0N4BsAl19+ednznYk2tQ8ec+/vcrM7E6VZrl1QO+q50sBcUjOFEOeDsYOScu87e0xnv+Bs7ROcKjcsb2ReXRTH1SxulP3FhRDibHZSgbhS6jbgfwM3aK3LlvNUSlUAhtZ6JP/vW4F7Tua6Z5rJzHafzOzO2MBcCCHOB3LvO/tIv+DUUkqxpLFyppshhBBiGkw6EFdKPQDcCDQopTqAv8CvhhrCTysD2KC1/rRSag7wLa3124Am4Ef55y3gfq31E9P6LmbYZGa7T3Z253gz7kIIIcTpJP2CmeG4HiMZG6UUsXAA05BibUIIcTaaStX0D5Z5+NsTHHsEeFv+363AJSfUurPEZGe7T3R2R6oHCyHOJPc+tpMntndx26pmvvC2lTPdHDFDpF8wM4ZSNgBaa+LpHLUVoRlukRDifPXMrm4+89+bybkeb13VxH/8puxGORXTUqztfHeq1zJOZX25EEKcSvc+tpOvrW8FKP4twbgQp4+rNUHTQGvwyi/BF0KI0+L3H9xCJr97w+PbutlycJBL50uMMlnTuX3ZeW3tglo+e9PSUxIgj93STKoHCyFmyhPbu475sxDi1KoOW9iuh+N5xMKBmW6OEOI85nijBwNHss4MteTsJIH4WaAw4/4Ht66QtHQhxIy6bVXzMX+ejE3tg3z1mX1sah+crmYJcd4IBy0aYyEaYyGClnTjhBAz549vu4BClYoVTZVct6xxRttztpHU9LOEVA8WQpwJCmnoJ7pGfDI1L8YWp5RilUKMli90Nyn/8PhOvvPrA1yztIFvffRKAP7Xg6+xpzvBP773YlbOqT5VzRRCnON++00Led/aFpK2Q0NleKabc9aRQPwctPruJ0jkXCqDJtvuuQ2Ad/3bC2w7Emf1nCp+/HvXznALhRBnsy+8beUJrws/Xs2LsYH63Xes4p6fbZdilUKcgC/88DUe3HQEgKd39bLub5/C9TS9Cb/g2+3/8gIbv/hmmqoiM9lMIcRZLBKyiIQkpDwRktN0DihN8ywE4QCJnMvyP32MW778LFs6hnE8zZaOYd71by/McIuFEOer49W8GBuoP76tc1zgLoSYmON6JLMOGdvlB/kgvKArnisG4QVfeWrv6WyeEEKIPBm+OI4zISXyWG0YO3uUsb1Rz+dczd7e5KjHth2Jn/I2CyHOTSebNn68XSbGbgd5++rZvNI2cNztIYU4n+h8tfSxKequp+mOZ3E8DzTEQgbxrFfuFEW3rW46Ze0UQpy7RjI2rx0cwjQUaxfUEg6YpHMuWcclGrSkhsUkSCB+DGfC/t3Ha8PY2SPLUOMqGI61ek7VqW62EOIcNNW08WMF7Z+9aWnZa5QL1Fc0x2Z8QFSIM4XrHg2sDWN0MO54Hq7nEQ1aZGyXT924lH/4xZ7i879/82JcF76xvpWcC3deNocbVkggLoSYukc2d9Adz+J5mp54hreububwYBrTUAwkc8yvi2KZEowfiwTix3Am7N99vDaMnT26c00L9288WHxeAaGAwbyaCAf6UxOuET8TZv6FEGe20vtRzvb4xvr9E96fTmat99jilFKsUojRCsG352lM82ggHjAMLNMglfO3EFrUUME1S+oxlaJnJEPItKivMnFcMBX80a1Lxp3bdj02tQ3Sk8iwek41ixsrT8+bEkKcNbx89k1zdRjb8egYSJFI2yggGrQYydg4nsYyZ7qlZzYJxI9hbJA7EymRtdEghlKgddk2FGaPvv7cfrrjGapCFuGA32bTULz38nncuaaF93/9RRwPth0ZHneNM2HmXwhx5ivcE3O2hwe096fQgFFmvfdk1nrLfUaIqcnaLq7nYSoDy1QYxujUdMNQzIqF2NUVJ562uWJhHUop2vqSrFtcz4L6Ctb89VP+wRqu/rvnabv3N0adY2dnnNcPD1ETDfLMrh7qK0NUR2S/ciHEUYahuHxhLS8fGMD1NNctayBgGeRcj5GMTThgEpDZ8OOSQPwYjreW8VTb1D7IPT/bjqc1hqG4+45VZduwu2uEJ3d0A7C1Y5hPX7+YWCRQbPPSL/4cJ5/J5niw9Is/Z9/fHv3iPRNm/oUQZ77CPfErT+/h1/v68LRf8fOapQ18/ublx8zWkbXeQpycnOOSdTwMpbBdl5gVKM6Mu57G8TxClsmuzjiPbesiYCq2dgzx21cv4vbVswGIx49fIyZjewRMg4qgyXDaJucce425EOL8dNOKWVw4uxrHc6mvCGEYijk1EYKmQShgYo4ZKBTjSSB+HDOZElkaICs0g6lc8bnSVPLHt3WOet32zjjf/9hVxZ/HfoeO/flMmPkXQpwd1i6o5fM3Lx8VVI8NwgvHyVpvIaaPpzUKMA2F69diAyCRddjQ2kcq67KwoYLDgykqwyYNlWE6BlIMpW2i+a2FqqqOXyPmgtkx2vqTdA1nWd4Uo74ieOrelBDirKWUork6jOvpfLaOJhIwCQUkH32yJBA/g00UII9NJf/o1Qt5fm9f8XWFke8CyxgdfI8tYjjTM/9CiLPLZO8ZstZbiOkTME1s18F2XYKmUZxtautLknM0jbEQbf1JFtVXsLNzhEMDKeorAuMC6RuW1vHcvgEAbl3ZMO46VeEAd61pIed4hAPGuMrsQghRyjQUkaCJ1oxbLiOOTRW2wDgTXX755frVV1896fPMRCGy6bpmufN89Zl9fPnJ3XjaL7byB7euoDYa5PFtndRXBOlP5rh99Ww+dNX84nkK6emWwai0dCGEOJOdjYUklVKbtNaXz3Q7zjXT1Sc4m2mtx3V293aPsKMzTkXQZCht85YLmrBdj56RDNs6hhlM2dywvJEL51Yf89x/+IPXeHx7F5cvqOV7H1t3qt+KEEKcF47VJzjnZ8RnohDZdF6z3AxSsWCS46GUojYaLAbdX/zRGwDFGfLC4+WC71u+/Cz7+5IsaajgqT+88YTaJ4Q499372E6e2N7Fbaua+cLbVp6260ohSSFGU0oxdoJ6UUMFmZzLzu4R6isCtA8kWdxQweHBNNs749RGg/zPpkP8r9oI1dHyaeaff3AzP97iL3Nbv7efd/3b8/z496471W9HCHGWcV2PJ7Z3MZDKcfPKWcyujs50k85q53w5u3KFyM72a65dUMvdd6zCUArX09zzs+1sah8ct1Z87M8Fm9oHWfvXT7K3N4mnYW9vklu+/Oy0tlEIcW6497GdfG19K239Kb62vpV7H9t52q59Ou7fm9oH+eoz+9jUPjjt5xbidLBMg2XNMS5ojrGkMUbINBlM2SSzDiHLJBoy8TzIlim6Zrsev9rZzc9eH91f2Npx/KJuQojzzzee388/P72H+186yB/98HXS+a0SxYk55wPxwuyxWWZ7nbP5moOpHJ7WaPz9fO/56XYytjvqmLFrxeHoDFN/0h71+P6+5LS3UQhx9ntie9cxfz6VTvW9tHA//PKTu/nwtzZIMC7OWgHTIGAq4mmbkUyOoVSOuTURtNYcGcpwxaJaGmOhca/b2z3CtiNx5teNntWaVxs+XU0XQpxFth4aprYiyLz6CAPJHN3DmZlu0lntnE9Nn4lCZKfjmmP3893a4e8PbipYPbea918xf9Qa8YLCDNNYSxoqpr2NQoiz322rmvna+tZRP58up/peKls3inOFaSiWNMYYSGYZTitGMi6VYYt3XjaXuTVh6ivLB9Zag1LwuTcv428f20lvIsf82jDr//gtp/kdCCHOBm9aUsd9Lx9iOGUzry7KnJrITDfprHbOB+IwM5V6T/U1S/fzLa2Y7mo/EC8XhMPoSuxevk7f0kZZIy6EKK+wJnyqa8QnKrI21eJrp/JeKls3inOB52kMQxG0DGqjQRzXozee5dW2OB7wrkvmTvjaJbMq2deboGMwzd+952JuXD5Lqh4LISb0kTctZvGsGMMpm2uXNhCUrcpOynlRNf1cMrYTu6l9kA9+c8OoWe6gqXjgk1cXO6+FomxzqsN88KoF1EaDDKZyZ1UVYiHE6THVQLnc8WOLrN19xyoGUzlqo0Hu+dn2M6r42qmoyi5V008N6ROMprVmJGPjeJqgaVIR8rcPimdsvrF+H0NJm1g4wJzaCL/9pkUA/M5/buSZ3f7g/aK6MA986hoMBXXRINbYvU2FEOe9QwMpXmrtp6EixLXLGgiW3Cc8T2N7HpZxdDtFMd55XTX9XDJRBeEHPrGOe366vZie7nq6mGJ5y5efZW+vv/67YyjDP/xiN+HAmdEBFkKcWY5XpbzcQODY4wG+8vQesrZXrGFx96Pb8LTGUApP65NOBZ/O4Fn2NhdnK9vVuB6ELJOc4+J4BgHToDoSIBoM0FgZJmCaDKVzxdcUgnCAAwMZdhwcpLIySMZ2mV8/eonafRva+Ndf7ePGFQ3ce9elp+ttCSFmyFAyx6fv20TIMvj6h9bgKsX9G9uxTIOdnXGUgpsumAX4QXhfMovr+d/t9RVBLFMG86ZKAvGzyCObO4qd29JO7NoFtdz99lV8+FsbxqVYlivCNlEHWLYzE+L885Fvb+TltgGuXFjHVYvrJ1wzXS7oHrvG+uHNHTyyuYOc49+nDPz9jgvBN9pPoVXoE04Fly3NhPBnw13Pw3E9/IkoRWE+SinFW1Y28eT2LjQOt6xsmvA8SdelOWSRyI6ufPy9Xx/g7p/uAODBVw7zRscwP//9G07NmxFCnBHW/s1TuPlE6Yv/+ile/JO3kHM9GmMhtNb0JbLFY23Pw/U0kYBJxnbJOZ4E4idAAvEzyLFmeTa1D/LDVw9RWEhgmqM7sRMVNVrSUFGcES9QhhrXAS6dOS9sZ1YuGD8VaZxCiJnxkW9vZH2+xsT6vX3E0/aEa6bLFTYbu8ZaQfEYQ8E1Sxu4ffVs7vnZ9uIxhTT1E72HSIE1IfzMt4BpEAmaOK4mGjRHdYJXzq5iYX6GOxI8uoZzYV2EtoG0/7ilSOU8BpI5Lmiu4i3/+AzDmRwPf+oavvLLvaOut70zMWE7co6LUhAOSJdSiLNVMmUXg3CAnKupiwZZPaeGbUeGCFrmqD6BqRQGiqzjovELRmqtUUpS1KdC7ppngE3tgzyyuYMfvnrIX+tVkuJZ6OxuaO3HyVdXU8B71raM63yWS7F86g9v5II/e5xMyRryhfUVrF1Qy6V/9QuG0g41EYt4ZvRoeLmZdJmJEuLc8a5/e4Et+eUsBbu6RyasUl4adJuG4shQmt1dI1y/rJHueIb3XzGfFc0xHt7cUQy6P3/zctYuqGVFc2zSA3ilg33AuNdJgTUhjgoHTAhQdiaqNAAvePZ/v5ndR4Zp7fe/4/viGVpqI1z+paeLx9zw5ed48/I6frVn4Oh1rPGda601w6kcHvjbqWqIBKVbKcTZ5KdbD/PfL7WxsW1o3HOGoXjXZXO4fnkDkaBJtOTzbZkGdZVB7Hx8YSqF42osEwnGp0DumFM03TPCheC2kHIO/izPI5s7eDif4lkodlTa+bxrTcukr3H321fxxR+9Ufy5rS/B6rufIJHz9x0fSjuM/cjURgNsah/kCw9tLaarv2tNi8xECXGWu/exnXzn1wfIueMLdV65sG7CNdOFrJuHN3fw0KYOHnj5YHHnBYCdndt44JNXc9/H1/HI5g70mNdOtvBbYbDPMg3QetTgZOE8p3tLSiHONKahcPMfwKkWSWquifD64WEO9CWpilpsOTQ07phL5lXT2pehbSBFQMGDH7+KZNbhkr/8BU7+w/2v77uEq5c1EgwYOK6H7XrIRkZCnD1+775N/PyNLsqV7Z5XHQL8oLq+MlT29QHTwDIUrqtRhoL8gJzE4ZMngfgUnIoZ4UKaZeFDoICAZfhFjkqC3sFU7oQ7nyuaYyydVcm+Hj+1TGuKQXiBBpY1VrAvn54+mLJ5z3+8WGzX3t4k31y/X2aihDiL3fvYzlF7go/1+zcvLx5XbruytQtq/ewc1xsVhINfOKpwfyoMIj6yuWNK98mxaefAuJoYhXZIAC7OZ0opLHNqvV2tNVs7htjTlWBXV5yW2ijzaiN0xrPjjr1sfh2/f8uF7O9NsP1wnM6RHF/40a+LQTjA5/5nK6//xa1kbReFIhySLqUQZ5OndvaUDcIBmqojeJ4mnrH52dYOsg7cdlEzc2uio45TSoHSaK1Bg5Jl4lMid80pOBVrE0ele5oGNyxvZFYsRCxkYSgFJUWNTqTzef/Gg9z96LbiyLmBP3puwrgZsXvfcwkbWvv58pO7x3WywZ85//T1i4lFAjITJcRZ6HsvtR3z+Q2t/Ty1vasYrBf+nl9fwePbOrl99eziPatwLywwDTgylC4WazuR+2RtNFi87xVmxF3vxAu7CSGOOtif4lc7e4iFLBpjQdr6UhiGImAYNFUG6E7YxWOf29vHm5bMYjCZozJkUhkOMJDKjTtnLGwR8UwMpWT7IiHOIkPJLK7rlX3usnkxbl7VhO16/NeLB9jVNYKhFG39Cf7i7atxXQ/H1QQDBgHLzK8P94NwSUufGgnEp+BYaxNPNGW9NM1yJG3zrRcO4Hq6OEJlGYq771g14TmPV+Dt7ke3jVpbflFLNTu7RkbtO17w8OYO7lrTUnyPnj+4Ncr2zjjf/9hVk35/QoiZV249+FiWoaiNBvn+mGD9ey+1kbL9+8Xze/v49PWLuXNNCwpYNaea7UeG6RnJ8tyeXh54+SCWobBMA9edWubMpvZB7vnZdlxPYxqKv3z7qimtLRdCkM9W0ViGgVESGGut6RxOU59f05nJaW5Z1URzVYiu4cyoIBzgF2908dkblzG/LsorbQMkcy73vP1CPn3/luIxAVOhlCIwxZl5IcTMae9L8uYvP0uZ1Wmsmh3jumUNtA8kGUrZfPelNlp7klSFLEwFRwbSJFJZMAwChkEq61KplL8bipIg/ERIID4FE61NPNmU9cKx7//6S8WguUBrzWCZUejjXXdT+yBfeXpPcSYc/DUbTVVhXp+gQ67KvMdPf/9VehNHr3/76tmTfl9CiJl3rCDcUHDx3Gq2dAzjeJp7fradRfUVdJWkqhaC8IJvPN+af63inndW86V3X8RXn9nHL3d2+7PgruaWC2dxybyaYhD+1Wf2HTeYLl2mU7jvSQq6EJPnuB5p28VQCtt1iAYtlPLXkqdt/+dZlSF6EhnqzCBKw0jGxSnTI2+oDJHzPJqrIty4YhY5x6MqEuDut6X4+yf3EbAM7nnH6hl4l0KIyXJdl7sf3cHBgRSfe/NSggGDD359Q9kgHGB31wgdQymytuaiuTluXTWbpuoQu7pGcLXmspYaEjmXqrA/0Kc9d1ScYcqg3JRJID5F5TqG5fbS3dDaT200OOlteja09o/6nxn8NPJjzShNlCo/tgCcn+gJnoZf7eo++sAYd+YLwJW+x1f+7Bbu33iwmJb6oavmT+K3JISYKYX13fPromjg9cMTz4R7mlFBetb2WLOgll3dI+j8PWLs7SK/DAxPa+5+dBsrmmOsW1yPZRrFQPrZ3T186oYlAJMepJRq6EKcHE/rYoq441IsmmQ7HgrFklmVxMImS9wKDg5kiIYseuIpGirDXDankteOHN2i7HeuXURNJAhARciiIl+r6XevX8bvXr9sJt6eEGKK3vnVF9l2JA7A8/v6jnu8q6EmEiCuXLZ3xrliUT1XL2ng+uWNWMqgMRZEozBMA9v1CJgGRn4mXGst25edAAnEp8HYbX0e2tRRLGakgFDg+LPk6xbXEwoY5GwPw1B8/NpFx12LPVHHtXRmyQDm10dp70+hgTIZ6QB86Kr5o65z7b2/5PBQhrk1YV74wlskABfiLFBajK2tPzXl12v8AblVc6r580e34eXXZ//umxayvTPOqtlVfOuFA8XMHc/zC7R99qalvGdtCw9sPIjG31t4Q2s/wKTXi0s1dCFOjmUY5FwH29VYhiqmppsG2I4fpM+qCpOxPTqGsmitsT1NTSTAuy6fj/tKO/t7kyyZFePWC5sYyTh0DKb4+nP7+OGmI4Dfp2i99zdm8F0KISZrV9fIlI43FIxkPX/w3DRIZR1mV0eIBgz6EjmG0jYLG0JUhgN4nsYwFK7r+YXakNT0EyGB+DQo7UAeGUqP2tanXMXfgrHru6faCS33mk3tgxweSo9ao3nbquZRnedyXtzXx/0bDzKYyvHdXx8opqN3DGW49t5f8sIX3nLCvx8hxOnxxPauso+HLQPH08e8B4Dfyd7dNcK2I8NHtx/RmltKqqfPr68oFoA08uvKAe5a08IjJXuIFwYGpzLLLanoQpw4w1BUBC087XeoAdI5F9t1CZgGpqGwDBOtHVzXZVdXivl1UWZVh0lmbHb3JPE82NE5woe/uZFg0CQWtnhqR0/xGh7+lkf/9uG1x22P1hpPT317NSHE9JhbG6G9zKD8W1Y0EAsHeOL1TjIl3YJPXLuQ9sE0nUMZPnjVfEyleHZ3LwvrIxgKFtRXksq51HoehuGXRzdNKZN+MiQQnyaFDuSm9sHi1j2enji9vNz67pO57thzWobiA1fOJxayigXgzHwxBe3pcetD2vpTfPFHb2AoxlVM7xjKAPCRb2/k5bYBrlxYx/fKFGyb7j3WhRBTc9uq5rLbk7la84NPXc3vP7C5+HkGaKkJM7smwittg/4DimKQXbgNOJ7mK0/v4fbVs9l+ZBgN3HHxbH6y9Qhufl35iubYhIOJMsstxOmjlKKwTDORsemMZzCAoGUytyaC43rs7R5hbm2UpqowIcvE0zCYtjEMg4qgIpF1ae1PsrwpRqDMDFcmv/3p+j09/Ouv9jGvNsI/vveSYscc/GyZoVQO14NgwKAqbMlsmRCn2S9+/3o+9d+b6BhIMZTOMZiyWVgfZU5tBR+9ZiF/d+cqbv6nF+geyfLmCxr537evRKN4YlsnqZzL4cE0s2vDWJYilfEIBQwytkPO0RjKJWAZ8rk+SVMKxJVS3wHuAHq01qvzj9UBPwAWAm3A+7TWg2Ve+9vAn+V//But9X+deLPPXKWd0WOtER+7vvuRzR3FAL5c4bXJdGQ3tPYX14U7rt+RLp0J11rzgSvmM6cmwsbWftbvHb9eZKIJs498e2Px+PV7+/jItzeOCsZPxR7rQoipKcxaP7G9i0zOpWvEL7pWSCF/4Qtv4dp7f0nHUIaKoMn1K2axek41bxwe9teRKjVu1tzTfrX058vcL8BfV368lHO5F5ybpE9w5vHXafprw+MZh8FEjqBlYBkus2IhElkHT/t9BKUU4YBJLBzgHRfP5ZFNHcTTDhqYVxsmnrEp18X+yvtWc6A3wUe+8woAr7QNsq8nwU8+d13xmKzj4mhNKGCSdVxs1yRoSYddiNMpHDT5r9+9kr3dcf7qpzto709xzdJGlIJkxibUWMlda+fx5I4u9vQk+dvHd/Hxaxdz44pZtPUlaYwFGUjkGEzmiKcdEpkoDbEIlumvCXc9jSUF2k7KVGfEvwv8G/C9kse+APxSa32vUuoL+Z//uPRF+S/mvwAux8/W3qSU+km5L+dzwWQ6nmPXd2vKr6U8XoBbGqTXRoPFWSwP6BvJjioAZyjFnWtaWLugls/etJS3/fN6dnQef/1ILGTyctvAqMfG/nwq9lgXQkxdIRj/+vNHZ8Y9Dd9/qY2RtM0dF8/h68+3ksy53L/xIJapuOOi2Ww5NMT8umjZAbpj0fj7f29qH+T933gJx/W/mO95x+pJF6ucDpKRMyO+i/QJzhhaa+x8IRjb8UjnHDQerucXYTIAwwBDaQ70J1hYF6WpOoxpKC6aV8ODn7yaH205QnXIxDIUPckskWCAtv4k3SP+crUFtWGUFeCXuw4BR4s57u6Kj2qLX1FZ43gelKTKCyFOrx++cpAv/nhbcYLu+xvaqQwaPLL5ELWRIB5Avgjr+j29rJ1fwxUL6/C0pjJoEq4KY5iKSMikpTpKJGgWg3DZuvDkTSkQ11qvV0otHPPwO4Eb8//+L+BZxnzpAm8FntJaDwAopZ4CbgMemFpzzx1j0ziBsusrjxXglgbphvK3GykwFDTGQvk0Eg+Fn05a+trL5teyryeB7epRVZFXNsfYWVLg4U/ediFPbOsc1UG/cmHdqPcjFY+FODPcv/HguPR0DXTFs2XT1h1X8+MtfiGmsQXeFP76ztJU9bEUMJjK8fXn9he3QXJczZ/96A1QjFp6c6oCZcnImRnSJzgDKYXjesQzObT2iAYsKkIBQpaBZRls3jvAt3/dhudprl7SwCXz/e/yzuE0GdvjzjVziQRNBpM5Dg+keH5fH3deNofhjIunNZfOr6UiFOCWlbP4m5/vLN4XVjbHikvgAEKWSVUkQNbxqIiYWLKOVIgZ8W/P7PMzZUoeS+Q8GisDDCRtUrZbfG6BEaZjIE3n8GEswyBju8yuCrO6pRpDKSJBk2DA9GfCDX/AzStZLy6mbjrWiDdprTvz/+4CmsocMxc4VPJzR/6xcZRSnwQ+CTB//tlRqftEZ2LGzpyXW0tZ2BbIdjxMc3SAWxqke5ri3r8Kv/N755oWYiGLr61vRQM/3nKEKxfVs6I5dnQtuWlw84WzeHZ3D3a+E72nZ4RPX7+4WCV5MJXj929eDvgz4Rc0xbhqcT2b2geL7ZSKx0KcGR7f1nn8gybJzM9s/+CVg2wt2ebsyoV+to6rIWAq1i2u58kxheIKo+yFLR0fmWDpzXSQjJwzynnfJ5hRWpPJ2VimQX1lmKFUjsqwRU00iFKKJ3f0UBcNEgtbvNo+SH8iS871eHjTYQzlz5i/Z+08ElmbP3poqz/Qbyh+/y3LuXlVMwp4eFMHc2si3P+xK/jnZ/bRVBHmD9+6nM6hNI1VIUKWCUAkaJHfAU0IMUNi4QABM+tnp1CypbHnYWtvVIA+krGprbDYeXiYOfWVeBriGZto0B9MMwyFUgpDTTw4L6ZmWou1aa21Uuqk/ttorb8BfAPg8ssvP+P/O0/nTMyEKe2FDX316F9HYRY6Y4/ek2xBfZQvv+9S1i6o5StP7xn13OPbOhlM5Yqd1pzjMZzKEQ2ZDKccAFwPNh8c5I9vX1m2oNxd//Fice/hv333RcWtzQptL2xbJB1hIU6/21fPnnA992QY+W9pDz+VFWBn59G006Bl8K7LWtjSMYzneBTKq7//ivls7XijeJxl+LesgGWgmPw2ZidCMnLOTOdjn2AmKaUIWAbhgEXSdnA9TW00RHU0UCyo1FIXZcP+flI5l8qgRUXI4nD3CKDZ251gR1ecTQcGePHAAMmsRzTgrwXdfHCIu9a2cPdPtuO4Hrar+cR1i7nvY+u49Z+e4YZ/fI6QCd/56FquWdZcbJPjeqTzxd2iIUsqqAtxmv39ey/mjx96ne54hp6RXDGAHky7hC1FDoqFmi3LYH9PkqTtcbg/QTgYYF5thFg4iFKQzU/gFe4nMht+8qYjEO9WSs3WWncqpWYDPWWOOczRVDWAFvx0tbPeqZ6J2dDaj5NPCy3szTt2Fvprz+3nqR3dxdd88volxWPGdspvXz2bg/3JUTH9y4WKySVeaRvknp9uH/fevv386NTWL/7ojWLFZEkPFWLmFQbG/s9jOxjJulN+vacpFmhyPc3j2zqLBdwU8J61LWw74hd3K2zP+MjmDr707osAf7Dv9tWzWdEcG7X05uEyS2+mi2TknFHO6z7BTFNKEQ0HCFgmGn9ArLSq8UfftJDKoMVwJsc7LplLIuPwzK5e1u/uoTeZpS4a4Lk9vah8wJyyNSaQyto8se0IWdulpTZKTzxDa1+SF/b20Nrn78SQdeFz973GC39yC9GQ371M5dx8J1+TsV0qQrJZjxCn04pZMX76ueu48M8fH/W41rCwodLfHzyVw1CwbnEDK+fESKRdcp7HsqYYtRG//pTWfq2pAqVU8Y84cdNxR/wJ8NvAvfm/Hy1zzC+Av1VKFXpHtwJ/Mg3XnnHTPRMzNs293Pnv33iw2Nn90FXz+eZHLh/3WEHh34XnDvYny64THUtDMRXVUP6sVm00yEDKHnfsw5s7WLugVtJDhTgNjrWNYOl9YKIgfFYsSG/JqPhYAVOh8IPwgGVw++rZvNI2ULwHrZ5TzV/+dHvx9Rr44auHuHNNCx+6an7xnrOp/egA3+kIlKU6+xnjvO4TnCkC1uhZKs/TKAUB0+B9V8xDa03QVPzFT7bTNpAmZTtkci5LFtSys3MEtL8G1PHABZ7fP8COzjgXtdRyeCiNaSguaani+T2jx1lGMh4DySzB/LZG/jSCdNaFOFWe2dXDvU/spLEyxHc+spZgMACA63r82aPbeG53L/PqoqTGZM9qoL4yyH/+zhX8amc3DZUh9nQneHF/PyNZl8taqmmOhZhTGy0eH7ZMlPLXhQMyGz4Nprp92QP4o9gNSqkO/Kqn9wL/o5T6GNAOvC9/7OXAp7XWH9daDyil/hp4JX+qewpFWs5209HBLATftdEg9/xs+7gZ5dLzP7W9qxhIF2a6C53f0g5w6UzUM7t7iKdtDvYn+cbzxw/Cx7pmaQOfv3k5j2zuKPv8vu4RvvrMPmqjQUkPFeIUOtY2gvdvPMgXf+Snhh8rNf3Oy1oYyTpsbh8cVZQR/LXfS5tirJ5TParieens9obWfhx39Bf62GydibJjJFA+t0if4OyQsR0ytkc8bTOQzLGzc5j+hE19ZZD9fQnmVEVoiAbI2MP0JXJcOq+GzQcHcUZ/zOlPOVyxqIbrls2iMRZCofl/bljML3YcDcZb6kJ8/6V2PnXDUgwDwpZB1vUwUISD5ml+50Kc244MJ/md7/q30d0kuOEfn+OJz9/A+r299CUyrN/TS200wMGBJCHTz1opddNyv95TOGixdFYluzrjJDIuGdumrTdJ+CKTcMD/3JYOpkkAPn2mWjX9gxM89ZYyx74KfLzk5+8A35lS684SJ9PBHF35/Gh14tIZ5cKfTe2D4wLpx7d1jksBLRZiMxSu1hT6zKXFlgoaKoOELIOhlE0yV34G7fM3L2ftgloeniAQf7ltkFfbBwlaBnffseq0blkkxPmk3DaChVnwzuHMqOcChsIesyd40FR884UDeKX55yVebT/6WS63tGR31wiHh9JYpoHj+tkvBowbeDuZ7BjZhuzsIX2CM5/naTK2PyA2lM6Rzjn0JnLEQhaOq1neGGN/fwrP87hxWSM7u0dwPE1NNEjW9YinnVHn+/en9/JbV86nqiJM51CahqoI3/ntS/nKU/uojQSprQzSNZIhkbWpjligoFoqtglxSjy13R8EKxRg64xnedu/rCdje9iOizIUlSGTrONx0/JGntrVi1vSLfiHX+zhgtnVxMIWm9sGqa8wiQQMaqMRgqZJ30iW5qoI4BdylMyW6SeLdWZYaYcV/K0/tPb/PjKUHlWZ/JHNHYzpV7NqdtWoCugrm2NHO8Du+KqGhvLXeSgF1y5tmNSewX/4P1v48vsu5a41LTz06iFy7vik1kLht8FUjs/etPTEfhlCiGO6cmHdqM+s6+riLPhYY4NwYPRnt0xueuElheAZ/KUnD23qKK4JNxRYhuKDV85n1ZiZ84ITXbIjdSaEmB5aazxPo7UuLjWJBCyGUzaOo7Etf9vTt65uJhI0CZmK/3juAGHLpLbCpGs4TTKnCVkG2ZKp8aQLN/7jc7zj0rl8/uZlZByXoZTLyuYqupNZ+pM5ljRU4rmaZNYlGgrM3C9BiHPcLatm8Rc/2THq6/zw0OhBecd1WTorxq6eJGO772kHdnXFuWlFI239Sd516VJ+8kYnhoZ59VFmxUL+EpN8YSmt8/cUWRs+bSQQn2FjO6x337GKbUeG+cErh7hv40H+59VDPPjJqwF/HWaBUvCp6xYTiwRGVUAvXdc9dkYcYEVTjDULarlzTcu4iupza8I0VIbGzZy39ad4z3+8yC0XNnHjilkcGkiNS2kFvxNfG5WRbyFOle997Co+8u2N/HpfH64uH2xPRXNVqLjtIYBlKnR+bXhtNMgHv/HSuIE3T/udemDC7JcTXbIjdSaEmB7+mnC/sxwJgONpWmrDNFQGqYsG2dEVBzxqwhZVFSFmV0eIRSwOD6VpG0gRDJj8+VuWk8jYPLrlMNuOHP3OH0g7fPeldn69r4frlzfx6NbDKKA6bPGhdfO5dmkjNZVBHM/PzBNCnBpzqiv4/seu4E8efp2OoSwG+a1DSyRzmnjapq0/XayOXjCrMkh/MsfBgRRXLGxg5Zxq5tRG6U9kaaqKUFcRHBWEF4Lv0n+LkyOB+Awr12H95PdeLXZ0bVfz9ef2c8m8mlGViz945Xy+8LaVbMqnkWbto3sBGhxd1w3wd4/v5JW2QTSws2uEXV0j9I5kWTW7atRa0u54hn/54Br++qfbi9uTFWjgyZLK7OUY+B1zIcSp872PXcWN//AMbf2pkz7XQMrmXZfOYcuhIW5b1cwtq5qL96KHN3eUzX5RgDIUP3z1EI6np3XmWrYhE2L6FDrQpqEIBgp7e/sd8bTjsr97hB9vPUzOhUjA4IZljew4MoxlKC6cHaNrOM3/umUF1y6t5zf+5dfjOvj7etNEQwOgoa4iSH8yS8722N2VIBoMUFsRlM66EKdAxnbJOR5VkQDXLZvFA5+8mhv+4dlxWbMF7f1pP0jPr0oLB0x/a0JgTnWYT1y3lOuXN6K1wjIMmqojmEqBUhiK4qy467oYhiHblk0jCcRP0omuZxz7utLXdsdHp5V0xzPURoP5bQP8ju9da1qKz9+1poXekSzP7u4pVjourOsGWNYUG7VFWSGoNg24YmEtr+Sfcz34sx+9Qdp2qYlYDI1ZG3Y8lnSchZh2n3/wNX6xvZuKkMkf3LKCD101n9tWNU9q94NS1y9rQANZ2y0OzNmOx0+2HgHguy+1ccuq5uLSkomKMxa2MbFL6lk8vLlj1P3sRFPMZRsyIaaHaRq4+XQ40zzaYdZa0x3PsK1jmL09I/QnsnTHsxiG4vBQmlVzqmmuiuB5mhf29vDN5/ZhmCbLmio5PJgmUVJLxt9dJQ7AcMomaCm2HYn7nfuaMHNrI6fzLQtxXrju3l9yKJ9+ftPyBv7zd69iXl0Ff3vnRfzZj94YV2QxZILjQiSgyLia6miAP7p1BQsbKjg0kObGFY1UhgI4rkfPSBrb9WiMhXE8jeN5hKyjRRYLwbcE4dNHAvGTcKKdzeO97v1XzGdrx9F1n1cvrueen23H9fy143ffsQqAP/3RG6Nmpf7yHavHpYpuah/kB68cLNsO1/NHxoKmIpdfT14u5Xyy3rO2RTrOQkyjzz/4Gj/e4gfKadstrgf/wttW8sud3eztTU7qPFcsrC1WVy+9/3jaD6rBX9pSmgp+55oWfripg9zYb3Uo1rEo/P3Qpg4c9+j9rDTFPGd7fOXpPaMGB49FqqsLMT0KAXhpGmnWccm5Hka+YEwq69CXzDK7OkLfSI6VzdX0jmSxHZutB+N4gOe47O5O0FgZBAWJMlsj2hpmVwRp7U2SdTyuXFR/zNnwjO2fI5Tf5kwIMTHP04xkHH69r7cYhAM8s6ePnpEMs2Jh6iIWjne0cFsAuH55Hbt7UvQnM3jA3JoIN10wi854lreums3lC+ow88tHOlNZTNMgZbv0juSorQiM2ze8kGUjn9npI4H4STjR9YzHe93Yvb8HUzly+UJJWmu2HRnmnp9tH5WObucLpRW2FwKKe3u74/vRRQcHUty4YtZx086Px1SMmqUfSyohCzF5hc/L0zvHfy5/8MpBPnTVfHqTEy8DqYsGGEzboCEUMHj3ZS3FLQYHUznuvmMVj2/rHLU0RTO6xsPaBbU88Il1PLK5wx/wczUe+foTpsGNyxtpiIVQwAMvHxx1Pytk8Gjtv+bX+/p4pW1Aiq8JcZq5rpcv0KqLgXkim8Pz/LTWBfUVDGdsokET7Wk87fK+K+ZREYT7Xz0y6ly9iRzRgCKoIFcmBfaSeTUMZxySWYeman823PM0KdslYKrizNpQKsdQykYD1ZEAdRVSW0aIicQzNn/6yBu09acIlJmI/tXObt55aUuxr1/4aHpA0tbcuWYuFzTFeGF/PxnbYyCVI2QZ/Pi1w/QlsixqqODyBbWEgn7F9GTaoSuepCJUdXSkPk9mwqefBOInodx6xskEnKWvM02DwyXV0UtfX7oveOl1FBQDc/BHvwrFlT78rQ1kbQ/TUNzzztWsW1yPoRhXKbGgO57ll7t6xhVwmAoFfOK6xePaf7w9hQvPSYAuxFGlnxdd5jP5Rscw7/y3F2isCDKUssueI2W7fOldFzGYylEbDXLPz7aX7M6An0Hz9lVsbO0vrgPXGv7yp9tZ0RwrfhYLs9N3rmkpBtjbjwzzg1cP8dSObgL585TenwrXcz1NfjWNFF8TYgbokhuI7bhoIGSZaBeaq8PMrY2Syjpc0lLNS62DDGVyHOhL86PXOljWWEFDVNGXGn0Tsr3y/QlLQcbRmErxzjXzaK4K43mavT0J4mkbFKycXUVlyCKesQkH/b5MPG2PCsRL2yyzbkLAC3v72N4xSENlgKGMh2VQTD+vDpn8/RO7+LsndvP2i+cQNiGTT1iZWxMmnrIJWybzGyr5VEs1w2mHg/1+1srhwQwdQ2mqowGe2tnNnWvmMpSyGcrYzIpFSOUckjmTKsPwM2jEKSGB+EkYu54RmFSqeuF1hZmmB18+yCObO7j7jlXFDrNlGrxnbQt3rWkZdZ3aaJBtR4ax8uu/zJLjHt7cQcb2P52Op/mzH7/B37zrIlbPrR5VCV0BTVUhuuNZNH4F5BP5iF3aUs3rh4fR2l9fOr++otj+0vc/UQaAbFUkxHh/8IMtxc9xOR6M29lgrIx9dCvBrz6zb1QQDv5A3vYjw7z38nnct/HgqMcf2dwx4edwRXOM7UeGcfI98Zzj8ezuHu5a04LGz4opfN4La8nNfAqsFF8T4vRzXZeso3FdDY5LyDJpqavg8FCavkSWyrDJ4oZa4jmX7niWSMDC9jxeOTjEkqZqcp3DxDNHbx61kQA9idHZOAHgysV1/MlvXEAm59JcFSFtOygUQ6kcdRVBklmHvpEMlaFKKoJ+MK5QVISOrj8tbLlWYJrS+RfiL378Br1Jm9aBDEETrllaT3NVmBf29pB1NPGMg2UYfG9D+6jXDaZz9I5keHpXN72JHL9zzUIubqmhMmSxoXWAeMYGrakIWWRtF0MpGitDOK7GMhWZnEM26+BYJsGghIunivxmT1LpesbSDu/xZn8KAarj6eLxj2/rHLUV2QMb/QC9NED94Dc35GfSFR+4cj535gP1Te2DPLRpdHElT8Pdj27j49cuGtVx18C8uig9I9nijFu5yXCtPRKvPYbWmsqLb8UIhEY9v7VjuPi6nD26/Xa+Q18YPChXCVm2KhJitI98eyPtA5Orhq61Jr3nJXK9bcTW3oEZqRr1/N5uv95DIQOndCkLQO9IlhtXzBp33h+8cpA788tMCp/f0gG2i+dWjzr+lzu70fiz7KvnVHN4KI1pKDy3sMuD5v0l9yohxOmhdT4d3XFQBpiGiedp5tVGqAiZBE1F0DLY3jlMMuswkrbpHMqwqLGSqojFNUsacFw/C0cpfxeXwYO7GNr2a6IrryfYuACAdUvr6UvkqLAsakIBDEORtT0qQhYB02AkY5NzNKAYTuWIhS0iQROtIRIwx7W7dC2qEOezjoEUvcmjmW85F65fNouRrM3qubW83NqP7UIun5duD3WRfOOXhOdfBAsuBmB/d5y043H10jrqK0PUVwa5anEds6tD7OqKc3ggyaKGCgAqQhbRsEkq6xAOKML5XPjCxJ+YfhKIT6Opbr0z9vjbV8/mlbaBYoe5UJG4EKA+svlo4STH1fSOZHlkcwcP56sbO2UWgzueprUvyaevXzyqyvIrJVXUFWAYjFtLntq5noGnvgZActsvmfW+v8aMxMq+F8NQxfYXUu5LC8ndfceq4hp2oLheVbYqEuKol9sGxj1WKLximYqKoMlw2kFrj4Ff/DuJrU8A4MR7aHjb50e9bkNrP/c+tpMntndx26pmUjl3VC2Ip3d083rH0LjrOR587bn9PL+3l5zjYSiFp3VxgHBT+9F7h6H8thWeu/vRbXjaz7AptFtrmFMTkSBciBlgGAaWqYgnc2Qdj3DAxDQUDZVh+hI5NrcN0B3P0lwVYtmsSlJZj4vn1bB+bw8b9g8wuyrEUF2ETM6jN55k/3//GV46zvDLj9D4jj+iefWb6E3kuGZJPbFIgIzt4Hr+8rhwwGRhfQVbDg0SC5kkszYp2yFs+Y8X0l1dT5N1XJQGy1AYhqSlCwEQtMZ/Di5pqWYobXP7qtlsPzzEHz38Bq6GzKFt9Dz0V+hcmuGND9Hye9/HDFcynNXUpnPsOhLnlzt6SOVs7rhkLresbMJ2PQ4PpcnYHlsPDrC8uZqwaVBfE8XzPCzTRCmF52nM8WNmYhpIID6Nprr1TrnjVzTHiinrha3ICgFq70h21Ouf3tldTDe1TIVRMgtV6skd3Vy/rGHCdsypCXO4pArjwvoolqHY2r0HFYzScMcf0Pvo39H70F/R9MG/5dKFjVy9uJ7vvtRW7Kjf887Vo4rMhQMmv8y3r1BI7rM3LR2Xjl4aoEtHXZzvrlxYx/qSAmoFCqAkbXPouf8isfUJqta9B7u/g9yR3eNPpikOvrX1p7hy4ejPlwd0xbPjXwcc6E0U0+MLldEVfuVl1yvMdMNbVjbx/N5ebMcrPqcBA4pV1WWQTYiZ4XegPVR+cCwaNMnaHvG0jWV79I9kCZomlgWvHRyipSHC0sYYvSMZOoezrFtcT2c8w6dvWMrh4TQ/Xf8KX03HqX3zx0nuep7eR/+eutr/w+++/X3cuaZlVIXlcMAgk98x4chwmpTt8qHL53HpgloSGccfpDeOVnIvDOq5GgIy8yYEALOqItywrI7n9vqD9O+8pJmmWIi7H32DgwMprlxQx/KmSvbu3E7PQ3+FFWug6so76X/8n8l17yey4BIAggGL1r4E3UNZTMvg28+3sqq5ipGsw/y6KMlUloODaVbMribjuJiGImoZKDRaQ6BM5oqYHhKIT7Opbr1TenyhcNmda1pGFUfa0NrP7q4Rnt3dU3zd2OJqrqu5uKV6wrWj5Tr3BSOZ0fuFW4bi4ECKTCaDEQwTXbaOhjv+kL5H72Xg6a/zxu2fo6kqTE04wEAqx4WzqxhM5bh/40H+4ifbsF2Nofw2FgrJFQrZfeXpPcUZ/9IAXQgB3/vYVXzk2xvZ0NpPJGgSTzvFdHLHg5GsS3LneuIbH6bysrdRc/1v0//z/4vnlKmgPmYgfUdnHLNM5ks5rX1Ht0XTwMevXUQsEiimqReyWD59wxI+fcOSUSnsheemY5BNijkKcXKUUhiG4ddqUAplKAKmn+ViGZqaygBHhjysgCJsGbT3J4vrRQ0DXMfFMKCuIkQmkwbAqp7FrPf+FV3f///Y++DfsPmGy1i3qI4vP7WLja2DVIQsPrRuAbWRII9v6wSgtiLIL3Z0U1MZZF5tFIWmL5HFwO8jGLIeXIiy/utjVwP+oPiRoTRffHgrO7r87+hf7e3n2vlhnn3ob7BCERb95pcYGez1X1jSL/iHO1fx/ZcPk7BtQgRIpm02HujH9TTbDw/jas2Sukoc18N2NQQ0hmlgmcao7Q/F9JNA/AxROlNsGYr3Xj6PVXOqi2szC+mhBSuaYuzrTfgfGPzO8tWL69l2JF6csZqssYH4vvzexCoYwcv6/6644Fpy3e8lvuGHhOdfzJP6huLxWzqG2ZJfQ1ZoYmF/4psvbGJWLMTurpHieynMmMlMmRDjjd3vu7Rwmz3YSf8T/0po7krq3vJJf8Yrm8IMRYqp4LGQyZ+87UL+5ek9o85bbu/fchSjB/kUEIsEigNmK5pj44Ljwt/lnjtRUsxRiJOntcYwFNGQRdZxCVqmX+xVK+oqwqSdFI1VYUIhC9vReCgqwwHqKgK09SWpigR4ckc3WsNbL1nEt8G/54QrmXXnn9L5X/+L+//uj5k163s8tb0bV0NfIsf3fn0A24ORjI3WkMg63HnpXFzHoyoc4NBgmlx+a7XKkEV9ZRADf826EOeqPd0jdA6luXReDdXRqW3bp5SiNhqkY9jPYA0YkHM1L/3X/8Ee7mXJR/8OJ1xHjXmYTmDlvAYqZsf4m3esor4qTO9wkra+NFqnWDOvlq6BJEnHxTINLmiKsaSpgu54hnm1USpCAQplnCUIP7UkED9DlBYuy7ma+zcexDSOpnqO3cdoZ9cIVy6s5eX8Wm8FbO+M84lrF/HNFw4UK6Hr/HMBUxULw5WzdFYl+3oSox6zYvVoO4ubSWCGK6m57jfJHnyD/if/nVDLhVhVjaOOH1tbRePvb1hog6f9xwwF1yxt4PM3L5eOtTin3b/xII9v6+T21bOLSzfKKTfzu3ZBLfNqIuzND4xpz6X/Z1/GNE0a3/FHKNO/fbsjfRgV9Vim4sFPXs33X2rjiz96Y0rtLNwrDOXXe/BK7hWmqUZtsThR1s90z15LMUchTo5SCtM080XbTCL5ImiFwXrXCxOrCDCvNsobh4fpGckytyZKRcjkzjXzONQ3wlM7e6itCLKrM85IyA+SVbIfgED9POpu+QyHHvsnvvXv/0zmwndQGDZsH/SDhbAFrgsaTU8iwyVWNYmszeGhJPNro3hA1vGISlVmcY57cV8f//jkbrTWzIqF+cf3XUIsHCh7bFtfkgO9SebURlja6NdTcF3NJ7/3Cvt7/YKutgeZ7b+i5+VfMu/W3yXbsAKAni4/C+Xtb7qIP3zP9XzvxQP89yPb6BtJs6QxSirnEYsG6MvYVIYCeJ5mMGETnmtSUxemKhzA9TSWCUrJwNipJne+M8TYysZ+ASQ9ap1lS210VLDsF14xyNkeHvDrfX280jbAX79z9ajCaBta+zkylOaBl/1tihRQHwvSN+KnrWhg9Zwq9vckRlVVtmqaAXAGOzFnL0MZJvV3/CGd//k5+h//F2a9757jjpSV3W/UUBKEi3Pe/RsPFgPi5/NLQ8oF46OyYUq2I/zrn24vBuEA2SN7yHXv4zN3f5mfp/xq51pr7KEuKuZcgO1q/vB/ttDWP7mq6wVmPvAOmIrL5tWwqX0QT/v3iQuaY+zvTRS3WJxoVvpUzF5PtfilEKK80u9ppRSmgT/YhiYaMHFdzYqmKjwvzmDKZjidI2d7fG9DOwf7kyxoqMBxXB7bNohZWYc71EnQAEMpgpe8GfvAK+z8+XeY13QxRv3CUdfO5RPuGitCtPYmGVqUY1kgRmXAIp52UXjMrqk4jb8NIWbGs7t7CFoGs2JhOgZS7O9JcOn88d+TvSNZntzRRSRgsqVjiBuWN7CwoYI/+sEmXthfUiwVzazD61l6zbV0XvLO4iCYPdQFhsEglezuHOaRzYeojph0DWs6hjI0RAMEDMADS4GjIBI20UpRFw3i3y4MCrcNmRE/tSQQP0MUCrc9vLmDhzZ14LpH11luP+JvE1YVskYF4u+/Yj4rmmN85ek9/HpfH56GrO3vD/yld1806vxfe24/SimU9mfYC0F4wY+3HBnXpkDdPADs/oOEZi/zH6udTe1Nv8vAk/9O4vUniV3y1mO+r/wWwsUAXwErZ1cd6yVCnBMKayNLfy4XiG9o7S+mn+ccj/s3HuShTR3YzujF3OGWlcz5xNd5OtdM4RPlJgbQ2SSBen+7sY7BqQXhc2vCdA5n0Pg7MbzSNlj8rGpgd34LtOPNSp+K2eupFr8UQhyf52kcT5OzHdB+pz+esqmpCGKYipqAxVAixwOvHMRU0FQVZueRYZJZm4wDVl0LyZ6DVHlQyLmrv/UzJNvfoOdnX6Hpt76MMo4WdqoMQs5TRIImHYMpuuNZ0rbHgoYKAqbBcMpGA8mMQ0VYuqTi3LVydjUbWgfocFKEAiZzaiJlj9vTGec/nt1HOudRHbF4rb2fi1pq2HBgcNRxHor/9U/fp8Z0+Y+Xe9lbmCnvO0hlYwv1VRXs7IqTzjkEIiFqKgIEDMV1yxt592UttPUlaR9MsiAWZFVLNUHTQKMx8rPg/tpwmRE/1eSuN42mmpo59vjCn7vyhdoKM0Cle/h++vrFbO+Ms2p2FduODLP9yDC3r57NxgMDxfXXP3z10Kj9xT/4jZfIFfb0LeSgToJVNwfMAHZP26jHKy+9jeSu5xn81beJLL4cKzbxTNXlC2pZ2hTjoU0dOK7fUd/aMcz7v/7SqErrQpxrbl89uzgTXvi5nMJ+36VyjoebGsaIVI0ajbaqm8g6Rz/Ads8BAIKNCwG/oNtUdI9ki3t+l7steJpi4cVjzUqfqtnrqRa/FEIcm6c1aP+7OBI0iSc1w6ksadtlKJGlL5ElnXOpCJp0DWdQhqKlroJdh4cACDYuIPH6k2jPLQbcXria+ls+Tc+jf0f8lR9TfdVdxetdtqCW/qRDOuf5W5ShqI4EqIoEGEpmiYYtRjIOWzuGmFcbZemsyuK2ZkKcS9556WyiQYNDAyluWDGLWVXhssd984VWElkXU0HPSI7ZVS5bDw1TFTKIDw+O6hds7UyjPc3qudXkXE1/Ikv/0EFmL76QR7ccIWM7VIVNDg8lWNYU446LWwiaJm39SSxDcdGcKlytGUk7xEIhcq4mpHR+W2MlhdpOAwnEp8lUUzOPdXxp5/Orz+wbNdMUz2818K0XWoud7oCpWFAXLRZZcz1dnJF6eHNHMQiH8eu4j0UZJsFZC8l27xv9uDKov+1zdH7n9xh8+us0vvuLxecW1EU5PJzGyV9zy6Eh/vj2ldy1poV7frq9WNXd8TR3P7qNFc0x6WiLc1Lpdn7HWiO+5dDQuMfim35KMNlN8KoPQjAKlE8Py3XvByDYtPiE2ui5motaqnm9Y3jC8Tmt/fT1u+9YNeFnVWavhTjzae1XUfU0aE/johnJuZiWwvE8+hMZjgxn8gVcNXOqw6Rsj9VzqukaSpEYzhFsWoq2f4o9cJhgw9F7WmzltSR3PsfwC/cTXXENgZpm6qIBrl/exOaDg/QmcsyujjC7KkIy61ARsnA9zZHhNOv39hEw/YrtoFneLFlz4txjGAZvnWBAvpRS/rZImfyX8pGhBKtb6qg9uJ7tm7dTc91vFvsFm9r6GM54RAIm8YyDmUuQ6D1Cbs1t9CYy2I5mIGlTUxFgdnWEuTURAqZib3eCaMiiqSZCznFpiIUIB0xsxyMcMP1dFsRpITkH06RcauZESrfxOt7xhZkmU/md4Yc2dXD/xoOjZr5sV7M/H4SXViO/97GdxXXhU1H68Qs2LyPXtQ+tR0+1BWrnUH3NB0nteZHU3g3FxzuG0ixrrCz+7Hqahzd38PDmDnZ0xkedw9P6mL8nIc52H7pqPt//2FXHzPy4bVXzqJ9z3ftJvPYY17/zwxihCn9JSUkQrktG07Jde7Hq5mKETmyNpWEo3n/FfEKB8l8FhQQarTWDqTJbpJVYu6CWz960VIJwIc5QrusCfj2ISNAkGrSojgaIhQO4WuMpA+35VWqyOZf2wSSWqRjK5Fgyy/9eDzb7uyfkOveOOndAKWre8mmUYTDwi69iKc2lLTUcGkjz7staePOKRm5f3YynPVK2S288SyrnEFAGtqNprgpREQrQn7JP829FiDNHznH5vZuWYpeMjPckXdJd+9j5q0eoXfsbo/oF3XGbtO0ymMqRdTwG2ncDkK1ehGNrXC9f1C1j8/yeXnYeGSKVdWipjzC3OkzPUAoTTWXQxPE04cDRLctkNvz0kEB8mpQGzMdKzSzMhP96X9+ktvEqzDT9wa0ruHHFLOx8+vlYxWrkyxq47+Pr2N01wtfWt46aAVcKrl/WcNz3Unr+0OwV6Fwau298QF91xbsJNCxg4Kmv4+X8/UVdT7Oz62iqrVLw0KYOHth4sLjVWqnaKW7fIMS55P6NB9neGad08DnX00blsitZunwF2c69DDz9dQaf+Q5DL9yHduziF6PWmuyRXYRmLz/mNeqi5auymobi49cuYjCV46NXL6SuYvRn8cLZMUKB/CCgaRQrpwshzj6ed3QwPZuz8+vUFM2xMKaCgeEszZVBlDJI2x4Z22UgaTOQzLGjM47rQWXQJFjfggpGsTt3ESnJqUx7YFU1UH39R8i0vUZy53reODLMK20DpLIOb101m8ZYkIX1ldRVBAlYimTOpakmjKE033vxAD/Zcoiw7CcuzlOHBlL8xaPb+e6LB7AKhdLyzwWHOrjjjt9g3qIleD37iv2C3ufvw87Z2J7fd08f2QUoKuYtH1UsORA0SWdtXj80RNA0MLUmnXOoCAexXU3ANKmOBIiEAvktDw0MQ0LE00F+y9OkNGA+Vlp66cx5aeB8rFmktQtqWbe4nmf39E6YPqqAoGUUq5GPLRRVOCYcMMc9fiyhuRcAkD28a/z5TIu6t/4e7kgvwy/cX/b1C+srcNzygweehr/4yTbp3IvzUqGq+vN7+0ZtK2jVNGFk46yeU83wiw9gVc0i2LQENzVM/JUfAX4Q7gx34yWHCM1deczrDEwww6TQfOfFNr785G6+tr6VgeTRGe+Aqfjrd13EfR9fxweunA9a8+DLB/nwtzbI51WIs1BhBk3lA3DT9Ksi257HrKoKVs2roT4W4crFtSyojWJ7/rFHhjMMJmxsxyZgQsAyic5dQfbILrLO+OtUXfY2gs3L6H36m+SSI+zuGuFfn9nH3/xsB/GUzUAqx3A6x57uBPGMje1p9nQnmVsdpSYa4vE3Okdl/QhxvnhyRze269FcHaYx5g+Ma6AyqJi3YCF9vb2sW9xAasMPiv0Cb0y/IHt4F4GGeRjBClbOruSdlzRTE1YEgKqwxXDK4dBAgta+JIPJDBk7h1YKT2s8zysO2Mln8PSRQHwaTSY1s3TmvDRw3tQ+yFef2Ve2k1tIZR9bRRn8/4BBy+CDV80fFdCXKwxlKEVrb2Lc48di1c7BiNaQ7dhe9vlwy0oqL3kr8VcfJZcvHFWqfSCFofxBAMuA5qrQqOdtV9LTxblros/1pvZBvrF+f9nXhOZcwKygw59//E6s6iaqrnw30QuuJbr0KtyE/1lRSpE95H8mQy0XnlDbHM9fFlM6CGAA1y1r4MFPXl2sVaHxP6flltEc674lhDhzKKWKM1zBgIXremhPYyoDz9NEQyZ1FRYXzq7i1lWzuXZFI6vn+tuMGQb0JVxCpsGcmgjVC1eT7Wkj4JbZpcEwqXvrZ3FTcdqe/A6OhnTWobUvyUjWJZF2+MHLh3jijSP8Yns3rb1JlIKaygCVQYOM457m34wQp4fnaRIZh6FUDtsd3Z9/+UA/Ww4O0tqb4MX9fYQsg8vnVTG/JsjsmgpmLbuIwcFBHrz7dzCqGqm58t1UXnAtkZJ+Adoje3gnoZYL8TwImgbzG2Jcv7yJpbMqqY6GAM1TO7s4PJBmOOXQPZwmFlbk8u0pzbgTp4cUazvNyhU1OlbhttLnxn4sDAXXLG3g9tWzx63fLC0UVV8R5Gevd+J6mvaBqW1vpJQi3HIhmUPlA3GAmht+m9Selxh48j9o+vC9o7Y7KE1HNwyDd106l6+tby05P2w9NMSm9kFZWyrOKRN9rguPZ+2SVNHDO7GHulBKUXHhjSTWfQL3tZ8w8uuHiSy5ksiiy3CGu3FGjgbBmUPbMMKVBBpOfOcBy1S4nh9kG0AwcHRwEPxZ+x+8cqh47zENVVxGcyr2DhdCnDpKKUzTxHU9jPynOhw0qVGaQwM21ZEgpqlIDWZY2VRJbSRIYzTJju4RkpkMGc8jl85hzF4JWtOSaae1YuW4vkmoeSmxy95G/LXHiF50C/HQCkIBzWAyw/N7+2ntS7C0McaFc6vYcWSYa5bU8dgbnVSELD5y2RxG0jbhgEEwIF1Uce5I51yyjotpKIbTNnXRIIahaOtL8pOtR5hdHfaXgbge9RUh2na+Rry3k1DA4AfZG+HKjzEy+ADxV36MWnAFkUWXYSS6cfP9glzPAXQuRXjeahxgV3ecd62Zw2Xz53CgZ4TN7cOYhmIgmaEyqIhFLCqDBpFAAJXvjjiOm8+WkSUip4vc5WbA2C15jrUHb+lzCkAdrXxuGYrbV88etb1ZaWf4Q1fN50NXzedPf/QGrudvT+S6mrqK4Kg01OMJzb+I1J4XcYa7saqbxj1vRqqovfGj9D/+LyS3/YrKi24ue56c4/FSaz+m4Xf+wX8vT+7o5pe7evjr/HZmU90GTogz0USf68Ljhc5rru8gPQ/dQ9W6u0jt/jXZw7uILF5LYNWt1IVqGXj6a0QWX07mwGs0vufu4vmzh94gNG/1Se3zubSxkjsumUNtNMhgKjfqM7epfZC7H91W/Kwq4L2Xzyt7b5quvcOFEKdeoQiTUgrP8whZFhUhE41F70gWT2tynqaxMsj8uggeirk1UTa3D3CwP0PNggs4bAbYtnkDsetX4pSZPKu7/rdI7X6BwSf/nchv/SPNVSHu29jOUMomGgqwu3uEZM5lUUOUcNDimmWzqAqbLGqsYjCVI2ia1FT4y+lcTxd3b5CtzcTZSuNPPo39X3gka+N5mvrKIPPromw/Mkz/nl3s+PafUvem9zCy8wUiu7YSWryW2GW3Y8Uaiv2CbNtrNL3nbkKmIn7wDQBC81YDkLLh1/v6ef3wMJmsTTzjYNsu4aBJImsTNA0WNMQIGgYBy/AH6AKmrA0/zSQQPwMcaw/e0udUfh2H5mineDCVO2Zn+P6NB3nw5YPFTr8HDE4hCAcIz78IgEz7ViovvrXsMRUX3Uzi9acYfPY/iSxbhxmuLHtcYfuysVxP82c/foOD/Um++1LbhNkBEqCLs8VEn+vaaBBD+SNqWkF638vErngn1Ve9h6q17yC+6Sek9m4gbGeJrfkNIsvWoe0MVVe8G6vKL7boxHtwhrqIrX37SbVxZ9cINyy3+exNS4uPFT5nh4fSxSAc/E7wnWtajvv+hBBnLs/z0NqjULtNo/LbmsFwNksqlyNgKVzbz2irUAZN1WGcobS/rM6AtA4SnbuCZNtWKq6b4EKhCupv+hg9P/syuR1PYc2+i0MDSQIBi1TWoak6zKXzqqkMB+hL5KirCDCQyDGSzhGyDEzTI2u7BAwFSqEUOK6Hqf1gXClVTJ+V2TtxNogETeyMR871q5QbhsLzNLGgheNqthwaYumsCg4NpBnY9zIN695N4zXvo3Lt20lt+impPRsI5cb3C8LVDYQtyB56nUDtHKyY308IAKmM4392PI0JpD3NsllVzKoOMbsmQk1FCKU1ufyOCqYBWFOrJSVOjgTiZ4Bj7cFb+lxtNMg9P9te7PgWOsXlOsOb2gd5ZHMHD758kLHFyo+38kONOSbQsAAjWkOm/fUJA3GlDOpu+Qyd//V5htZ/n/pbPzPF34JfvO0bz7cW16xmbY+vPL2Hz9/sV4WWNFhxNploGcpf/GQbjqcxFLzzkjnct2s2I688ir3iWgL1LVRf9R4S259hZPPPMCIxIgsuGXfuTNtWAMJlnpuqbzzfyi2rmsctk7FMg4CpcFyNMhQ3XTDruO9vKmRgTYjTT2uNafodbc/zcLUiZ9uEAga1KoDhGbhOmq6hJJWRAFnbwVTQM5zBdTUXNleyoytBZMEl9L1wP5Y9Qi4QK3ut8IU3UvHGk3T/6rscuOAaCMWoCZikbY8blzXSVB3B86A9l6I/kWVeXZSQopi9Z1p+e23Hw/U0acelImgV702l61klGBdnMq01tutRHQ6g1NHBo2TG5rsvtTGUzpJ1PC5tqaYnnqWrfg5HXv8h9RddR1VdC7Fr30vn5l+O6xcowNWQyDhkDm6j9uKbiFhQEw5SF7U4PJyh1jFxXU11hUVtRYhwEBY3VtFcHSESsHBcB9txCQat4gCXfJ5OHwnEzxBj09Unem5Fc2xcx/6uNS1o4K41LaxdUMv9Gw9y96N+Z7+Uwk+L8Y4TiY99WilFeMElpNu3HPMDGmxaTGzNbzCy6WdUXnwLoealZY8z8teYqJJ6aTte2NvHK20D3LmmRdJgxVmjNMgszDZvah/kjx9+vVg3wdOwp3uEihXXkOvcQ7rtNQAC9S1UrroJXIfUrhcIz79oXPp5un0LRrSGQMOCk26r1oxLm/c0uK7nV0wHfvjqIX65s5vn9/Zy38fXAYx7f1Mh68uFmBmGYeB5Ho7jopTCcT0crbA9l5ynqYqYuAQZSOVIuy47Dg/joYmnHZI5h654mmjIYtk1N/Dz5+8j0r8bPfvyUfseFyilWHHn53jtnz9DfP33mHPH72MoxUeuns/8uii7ukbIeZq3r25iYWOMxsognqfxAMfxCOULxbm2C0oRNP37oKc1Wvuz5EKc6VxPs+3wMAPJHJUhi4vnVRM0DYZSNnu64xzsS7KwoZKhVI7n9vZy0wWNeN4dHNqzja6drxJZ6BKob6Fi1U3oMf0CDaAgMtiKl0szf/UVxCqD5FyXpA3RkEXItMjhUBsOs6A+QmXIYl5tBVXhAKAJWiYBBYVeuQTip5cE4meZ0qB8U/sgH/zmhuJs+F1rWorrOscG4ZAPfk+wEGJk4aWkdj6H3ddOsHHhhMfVXPthkrueZ+DJ/6D5t/4BpQyiQZNU7mgl1MqwRSrn4HrHn53X+IF3YXs2SYMVZ7pyQSb4GR2ZfIG2zKFt2P2H2PjSMNGL30rFqjczsukneJkEoeZlRJZcjrIC+Wqoo78QtfbItG0lvPCSafmy1MC3XzjASNrmllXNxc+ZaRrFz6fjHa2a/vDmDh7Z3HFSQbSsLxdiZvgzXmBZfuX0gKnwPAfPNLE9h0jAxNUBFtRX0juSZrvjcag/Qd9IlobKAJapAAOrZilmKMrB1zdQ23z5hNeLB+cw95o76XjhIWZf+Rs0X3gpOdfj+X39xMIWRwZTrN+jiEUCNMWCYBhYhiJgGCilUcrAzKfwerpQWNIPwkvXugtxphpO2/QnszRUhuhP5OhPZKmrCDGStWmsDOFoj82H+oi3bqMm18vBp0fYUXX5pPsFroZs+1ZQCq/pQg4P5QgGIGC4uBpqQwEW18cIWIpY2GLNwjqqoxamqbAdjfI8QsEAWntSqG0GyIr8s1ihM6zxC6E9srmDDa39o9Z1TvbzdLzDwgsvAyDduvmYxxnhSmpv/B1ynbtJvP40wKggHCCecXA8uGJhLZbhX3ui/xEVFAOCu+9Yddx92oWYaeWCzMJjAM5IH/2PfQUvk8BOJzjyn5/DGeqiat17QSmSO5+j8/t/SPyVH1P9pg+M+1K0e9rwUkNEFq6ZtjYPJHN8bX0rT23vGrd3+EObOrAMhakgYBkoGPf+prqNWek2jjKwJsTpVbinuPktiyzTJBqwCBomyZxLKBCgMmBwaDBFMmsTsEwsy0ChiIUCLG6IktMGs5avYaT1NdxjjPDbgFr7XsyKWt546J9IZLM4nqZnOE3Wdjk0mGZPb4IfvtrBjs44h4fStPYmyLpesZ0BU6GBgGEQCZoETX8wwRu/o6sQZxzLVNiOR+dQmpGMTdAyUQp641n+71M7ee3gEG/sbuOl736Jnr4+Xtt3mB3//v9MqV/QsX0jtfNWUFFdi5Ev6mwoheM4JB2XaNhCmYraigiWMhlKZbEdP/uEfMZdMBjEMCQQP91kRvwstal9kG2HRxc+08BI2s6PFPtV1T9+7SK+8+sD2K4+5uxzQ2WQ3sTERdysqgYCDfPJHNhM9VV3HrNtFaveTGLrLxh67rtEl1+NGSm/fmwgZWMYBq7noYzyOfNLZlVysD/JAxsPYubfT2EfYwnGxZlo3eJ6LENhu3rUdl+mofBcTXrfy4TmrqR63XsBiCy5gsFffpPoimuoueaDAOS6W1GhKIGa5nHnTx/wB8PCiy6b9rY/sb2LL7xtJY9s7ijeMwop6nNqIsX38vDmjmJ2Sm00OOU085NdXy6EOHGFIlGF2WRDaVzPI2SC6yoyWYeM6zC3KsLurhEcR7OovoK04xCyLFK2SzRo0LjyCjrfeAFn8AiBurkTXy8Upfamj9H303/g1V88zGULPslFLdUksi6WgpaaCEZ+K9PVLdVUBAIcGkyxfFYlWmtG0jYZz9+HGaWIBs182/3HTFMCB3Fm0fnsDdNQRAMmWdvj0FCKypCfiRKyTFr7EuzqTKI1pPa+TGD2BfQtuh17AdQ3Xly2X2CGophj+gVeJkHm8G7mvOWDKK2IBBQZW2MrzbyaKMsaK6iJBqgIKCoCBuCRzWlG7BzVwaAfuOfPJUH46SeB+FmoNPW1IGgqqkLWqD26P37tIm5Z1cx3XmyjUA3R0+XT048VhBeEF61hZPPP8HIZjGB4wuOUUn7htu/+PkPrv0f9Wz+LqSAaNBnJHp0dPzKULr4Hb4KF6wd6E8Vic46n+dr6VgyFrCsVZ5TSNeFAPhXF7zQ+tb2L7Z1xZlWFOTyYJrzgErKde3CGuzGrZhFZcAnBD3yJnv+5G1yHmut/i2DT4gmvlW59lcCsRViVdSfUVoW/JVDadsc9Fw2a3L/xID98tWTvcNMvDFn6WSsNok80zfxYdTGEEKeOv5+4X33ctl08z599Hk7ZeJ7HUMZhX9cQoUCA1S1VtAaSNMXCHBpO0VQV4Y3Dw9RXhkmtupLXgXTrpmMG4gDRldcT2voEfc/+F0v+8BNcfeFcntvdzdNJmye3d1EZNtlxeIgX9vRwQXM1b1rWCPkU9JTtEg6YmAF/ECASMMk6DqZSBANS4VmcWRzXoy+RxdUQsfylFZZlcOWievriGb75/H4SOYe6cADP02Rtj8C8S/AO7yE72I03pl+gXIeWWz5KunkxSvtZJqXSbVtAe6y4/DouW15Px2CajuE0lWGDWNBiYUMFNZVBWmoioDTpnMuc2jCmMrBMA8MwijsRiNNPUtPPQqUdX0PBdcsaeOCTV7O9Mz7quJfyKbGOm9+zWENTLHTC140svhxch8zBrcc9NjhrEbE1v0FiyxOEh9r82frs6I5/acq6pvz/jOXi89IOvxAzrTAw9uUnd/Phb23gkc0dxc+c7Xh8bX0rz+/t4/BgGgAzWg0o4i//CJ07+ljD2/8IL//zRLxskuzhnUQWrz3h9hqGIuuMD8IBdnaOcPej24oF5RTwnrUt4wLmtQtq+exNS1m7oFbSzIU4SxUCcvCLuFlK05NIM5jKYJgGh4fTOI5mbm2YRM7F9mDb4bi/q4IByWA9wbq5pFtfndS16m/+NNgZ/vMrf8vX17fyrRfa8bSLUpq2vjQ7uxK8djjOgf4kw+lcfubeoyIcJOdCOudiGdAzkqV/JEtfMle8VwlxpkjnXDwNkYBJynYxDYNwwKB3JMvGAwNsOTRMOuexoyuO6/nHqmg1QVPRX6ZfoO00GdtFqfFBOPiD80aoAuqX0NaXwnZdZlWGWNwQIxIIUBUN0lARojJk0RgNEw4GMJRB2DTQWmMYCnPs5ubitDnpQFwptUIptaXkT1wp9fkxx9yolBouOebuk73u+ay04xu0DD5/83LWLqjl9tWzRx23ozNObTRYPNY0DebVRU/4uuGWVahAmPT+43/pAtRc95sYFdUc+Mm/4E5iMde4au1lHisoTfkFprxGVYjpMnZGWEPxM1dafKGw560RrqTulk/jpobpffReskd24yYHsfsPkTn4OtqZODsl3bYFPNcfFDtBbr7wWjkavyKxmV8THgr4RSCPpZBmLvUbRIH0C84OhZlwyzJxXY9oJEjYssg4/jrxxbUVzKsNM68+yoL6CMtnVRKxIJG16R7K4CmD2LLLyRx8Ay+XOea1qkKKay6/iIvf+n6e//n/sP7XL+Joj1TOYzhtYxgQDhpkbY+s61IXsfA8D8MwiIVMaqIBaiuChC0TD00oaGEqVTazR4iZZBoKL79dmZ+B5g9QXzS3ioZYkIqgQcSCgZRDIufSUhsiGqnkgnd+Divr9wvsfL9ADXWQOfQGYWUTLhOxae2Rad1EbMkaupIOBwcS7OtNcKAvQe9QhsZKi5aaMFWRAK6nCFgGNRGLdM4B7WEaCnXcssniVDrp1HSt9W7gUgCllAkcBn5U5tDntdZ3nOz1zkdj99udaH3lh66az49f6+DlNj8YdVzNs7t7uHhuNft6EgynbTa1D2IZELTMcUXUjkdZAcILLyXd+uqktjcwQhXU3vi79P/8/5J4/Slil7z1mMeX3goMBRfNreb1juGyAfp7L583qnq8bIUkTtaJ7mu9bnE9hRIHSvnbCN61poUNrf38bOsRdhwZQrsORuBoNooRCNH4zj9meMNDxF99FABn4DB1t3waZQUnvFZ6/ysYoQpCc1ee+BudgML/3AUsg7vvWMVgKjfp34WkmYtS0i8485V+h5v59NSk49JUHcEBOvrSGIZmeNgmozXDaYeBpEMoGEApm+5Ehngyi7nwCtj4KLmDWwkvvarstWbHAozkXLriObjsLsz1j7P9oX9i4Uf/Ly5mvj34wYGChTVhVrf495NU1mEgmaMqGqQ6YmHbHgaQsV0Mw1+WJ8Sp4Hkax9MYCixz8vOWkaCJkYIdR4apiQRojIWIBi2idRYrm2P8dMthdhwZojHi98P7U35fPKgN5t35xxxc/xBD+X4BQ0dY9s7/F0cFyTjjA+bI0AHc5CCVS68A7ZLIgmUZ1EQCzKmLcuOKWVSFLDRgmoqsowmYHkHlDxS4rn9twzBk27IZMt1rxN8C7Ndat0/zec9bEwWZ5Tq+m9oH2dJxtICbBp7c0T3unFqPr2Q+WZElV5DeuwG7t43grEXHPb5i1U1+4bZnC4XbqiZ1Ha3h6sX17OyMkytJPSusD7+zZJZOtkISJ+tkBnN2d41QKNfgePD9l9qoCFn0jGTZ1TXC4C+/iTPURbBpCVbdHKLLrsYI+Zkp1evegxPvQwVCeJkRArVzJryO1h7p1lcJL1qDMqZ/XaQG3rKyiU/dsGRKn58THcA4kXOeimuJU076BWcoXVIwJpO10a6Lh8ecWJiAoegYSJNSYKKYWx3FI0WVF2AkbdM7lKYyaPqZcsEIyX2vTBiId474CbWJbAYIUvvmj9P3k7+nd9MTzF33G1yxoI6N7QM0xEKsaKykuSZKxnbJ2DavtA1yeCiDMhRvWdnEnOoodRVBbNcjYJmEZY24OAW01iSyNvmNBagIWQStyQXjnqd5vWMY2/XoT2TJOi4XNFeRsh1+sa0L2/E4/PjX2N17BKPxaL/AtqKknNH9Ap0ZIVE7B8MZX3A5akBi/6ugFIsvu5ZIZZiBVJagYaC1oiZsURUNYqOJJ20qIy6ea1EdDhENhQBNKpMlVhGZ1t+dmJrpDsQ/ADwwwXNXK6W2AkeA/09rvb3cQUqpTwKfBJg/f/40N+/sMzbILGxRVq4jWlgPfjwnk4QSWXIF4M/MTSYQV0pRd+tn6PzP/5eh575H/W2/N6nraODrz7fyqesWE886KGDVnOqys3SFVH3ZY1ycqJMZzHl8W+eon3+85Ujx34PPfhdnqIvamz5GuvVV7N52hnoOELvkrQTq5wGggmHMcCVGuPKY18l17sVLDhFZeuUU393kpW13ykH4dGejTHROyXw5a51Uv0D6BKdG6f7bjuPiev6ymrBlYWDQqBRd8RyGoxlI58gEPZbOitHRnyYaMNCGxnY0AStAZOFlpPe/jNIaPYkZtegF1xHe+gRD679HdMU17I0EaK4KM6syQl8yxzN7+ggGLJY1VdIZz9JUHWYk5bDz8DBza6KEgxYTl4sV4uR5GlzPn/hxXA/H9SYdiLtaE0/nMAzoHs5weChN2DLZ3jnEUCrHtke/Rrr3CFU3jO4XuGX6BTpc6S8ZK3sd6Nn+EpE5K7ADUaqU5qI5lfQmHSzDoCJiEU87zKsJksq5KGUSi1hEgyaGoTDwCzZqrfNrxaVs2EyYtt+6UioIvAP4YZmnNwMLtNaXAP8K/Hii82itv6G1vlxrfXljY+N0Ne+sVboe3DQNfvjqoWJRqLHroWujwQnXfhZEgyc3emxV1hGcvYzU3o2Tfk2wcSGxtW8nsfUXZI/snvTrtIZvvXCAu9a08KV3X8SHrppfLBBVStaoipM1lYJjY+sRjK3NUEoFw1ReehuBhnnE1txBdPmbMEIVJLY/i3Zt0q2byOQLHR0vJSy172VQxkmtDz+eY72XcsoNYJysic55Kq4lTq3p6BdIn+DU0/jrWk3TwMAkZJlURwKETXDwMLXG8TxyjkdTTYSL5lWzoC6K5/qdyOjSq3ATA2S690/qekop6m7+DJ6dZfDZ/2QgYaO0pn0wwaGBFOmcQ2tfkkP9aSzToGs4Tdb1qK8MjprFF+JUMRSYBtiu/50TmCAIzzou7f1JWnsTpPOZppahsAKw/XCcvkSWmrBBdzwNKAKmgRWK0HD57VPqF5T7vz4Z7yfbtY/QkisZSjj0px3605qQaVIVCRFPZBhOp/GUQdg0qQkHqYgECFoWeBqlPcLhIIFAQFLSZ9B0zojfDmzWWo/LhdZax0v+/ZhS6t+VUg1a675pvP45qXQ9+OGhNA++fHDCWbttR0bvK24aCndMZJ7KuZgKxhYajQQM0vbxZ9MBIkuuZPiF+3ETg5iVkwt6a679MKldzzPw5L/T/JH/O+nUWk/rsrOT5dbNSwAuTtRk97UuNysLExcWDNQ0M7T+exiRKsItFxKaewEYJsMvPkCmYwdW7WyMyIpJtTG9byOhlgsxI7ETfJcTq6sIUhcNcLA/yVef2Vf2d7CpfZCHN3egoLid2anIRpnonJL5claSfsFZwDINdMDC00dnyT1PEQtZ1EfCJNI2pnKJZ3PgadJZTVUoAAY4QHjJ5aAM0ns3EmpeOqlrBhrmUXXlu4hveIjBS25FG6uojfrbOfWOZOmJZzAVvOWCWbjaozoaYumsqrIBg9Ya1/Nn9SzTkKBCnDSlFLFwAMfTKCZeI941nCFrexgGdAymWNxYyXDKpikaQc9VpBwHrSGZ9ZhTHaahwqJu9nxan/4etbfECEyxX2Dg9zdcIL3/ZcAfCDNM8LTCc11qYhEiAYO47eF6MJTK0VAVoioUJGQoApZJwDSxrKOfd/nMzJzpDMQ/yATpZ0qpZqBba62VUlfi/78k0xmTVAgyN7UP8sjmjgk7omM/RmOD8ILGWIh5dVHa+pJkHI9U1pl0EA4QXbaO4RfuI7VvI7FLb5vUa4xQtLguLLHlcWJrjl2fp7CTQrDM+5QUVXEqTGYwZ+ys7MObO/ifVw5NuNyj4sIb8bIpElt/gZsYoOKCawnNXkZk6TrS+1+h7s0fn1Tb7KEu7N42am/62BTf1WjXLWtgJG2PqiVhKhhI5hhI5tjX2wr4VV5LU8If2dzBD145WFwL/8NNHTzwiXWTHsCYionOeSquJU456RecBZTyO+eWaaC1P6g4ks7iAvWVQfYPaOyMR1UYsp6iN5EhmbHxXI/6CkXKqiY09wLS+zZQc92HJ33d6qs/QHLHc/Q9+R9EPvoVhhTggWE4DKay3LCsgQP9KS5oqmJJQ4ygVT7IdvNFtRSQczxCgUIBOCk+JU6cUorAcYoBup7GMPyJr6zjkbNdDEMRiwZRhiKRtYmFLbqHc7x+cJhkzmX+mjeTGhmmZ+sviEyxX1DaU8/u3YBVM5vq2fMxDX9r0oqwRTQIdZUhljRV01ARxDIMZlVGsR2/BkQ25xCwDJSyMAwZuJpp0xKIK6UqgFuAT5U89mkArfXXgPcAn1FKOUAa+ICW/KIpm6gjeu9jO3liexeXzqtBKT+lu2BuTZh4xmYkc7Q4W1c8S1c8iyI/az7F/xKBxoWY1U2k926YdCAOhXVhTzK4/vtEl19zzNn0tQtquXHFrAnXwk+UuioddHEqlc7KohQ/2XIY5zjrQSpW3YSygqT3v0qmfQtVV95F8vUniSx/06Svm967AYDI8qtPuO0BU6FgVBAO/sBcVzw76rHC52p3l7+v+Nj3WJqRcyqyUSY6p2S+nD2kX3B2KQStWns4jkZhEDAN+p0c82orSeYcDMOkKgxaaZKOS9A0cLUiYLrElq+j71ffQQ93o6qbJnVNIxim7i2fpPdHX2Lo1Z9QeeWdzKsOUFMRwlCwqyvBgf4kHQMpDg4muf2iuQQsheO4ZByPSMDANE1/favyAyfX07iuR9r2cD1N0FREQtNdDkkI3+zqCB2DKYaSOQ4OJNl5ZJjL5tcyqypExnaIhQN0DKY4OJCivS/FhgN9pLIu0VU3EHatCfsFFhANwkiufLadl02RbN/KnHXvYNXsSrKOR3UkSFU0yMo51TTFLEzTwlAahUs6Z4N2qYlGinuGSxB+ZpiWNeJa66TWul5rPVzy2NfyX7Zorf9Na71Ka32J1nqd1vrF6bju+WjtgtpR66TvfWwnX1vfSlt/ih9vOcIlc6tHHf/2i+dw6bzacbPl4H+4jxdElKOUIrpsHen2rXjZ1JReV3frZ9BOlsFnvn3MY19pG2QkbbN2QS33PraTG//hGe59bCcwfj1vbTTIh7+1YcK180JMl8Jg2NoFtbieJpE9OsA10deZEYxQseomqq66E7Rm+MUHCbVcSPVVd076uqm9Gwg0LiRQ03zCbV9QF+X5veOzfsuleBc+V3/+4zfK3iMkNVwcj/QLzj6e5/l7dys/xXtRQ4w5VRGaYkEWN1RQWxkk43rEMzbNVWEumFtFbYXFrFiEFVfcABwdNJysyLJ1RJZcwcAL96PifUSCFsOpDOmsx7aOQRzPozEWomckS28iQyrr8GJrHy/t72PLoaH89kuKrO0X1AqYBrbn4WqPgKXIut6E2YFCnKxI0GTprEoODSRo7U3Q2pPgh690oLRLKufQHU+zqyvOcDpHIGAwknXJ2GDrCFXH6BcoBYZS1EfLDyKlWzeB69B80TVYpkUy5zKYzOE6NpbhEQ0EmFMVoLYqTCwSJmXnCFgWnuehtcY0TQnCzxAyTHiWKqyR/vGWw6MeH0rb/O27L+LxbZ2sml3Fd19qI2v71VAL61gnWs86FdHlVzPy6qOkWzdRsfK6Sb8uUDeX6qvew/CLD1Jx8S1EFlwy4bHfeL6VrnimWIX6a+v9tNlbVjVz/bJGuuMZ3n/FfAZTOdm+TJw2axfU0jOSPf6BJbxsksFffpPam36HQP18lDn5W6+bHCLbsYPqq98/1aaOsqixkv29yVGPGUB0zGzRdcsa+PzNy3l4c8eobBlD+dubzYqFimvEJyLbjAlxdjm6VtQA/HTbcNBi4awYlYkgvUMZQpZmT2ccQ2sSaYeGKgvXi+B4mqA1j2jTQpJ7XqJh3TtJOZO/bu3Nn6Lz25+l/1ffYOXSe2iqilIdMkjkHOLJLAOpLA2VUcKWRc9wipyjmRUL0zOSpS+Rw9Uaz9MELYPGmIlCoT2wtSfF3cQpp5RiMONiGQYjts1QymH7kQQ9iSytvSPs7hrJ7yjgF00LGv5+3/3rv48RqaL+ts+hXWdUv8DWkHE0brb8Bym19yWMaDW5+qV0DWeYUxMhnnHwtCKe8WjvT1FVYVEfNsnZEDJMLMPwA3xDYU2yArw49SQQPwuVrpEeO55126pmPnTVfD501Xy++sw+co4fhBvANcsauH31bAZTObYeGuKpHd0nHJCH5q7EiFaT2vvSlAJxgKp17yW541kGnvwP5vzOv6KsQNnjtIZn9/SOeux/Xj3EN184UBzh3tm5jb98x2pMQ+G5GsNQMlMnTrnbVjUXB4YKjvVZGnr2u2Tat4IyphSEA6T2bQTtET2JtPTmWIibVszimV09o2a4lYK+kSyWofyObMDg8zcvZ+2CWh7Z3DHqHDevbOIbHxlfsX1s0D22hsNHr17I9s44t6+ezYeuku2nhDiTmaaB63oEDHBcjas19dEgiazDYCLDSM4hZJqknRQH+hzmVIXZ1ZMgkXGovfBqDj/7A1LDQ1BRM+lrBmqaqX7T+xla/z32bnqBxZe+iX19LpmcS0tVgHnVYa5Y1kDIBGWYZHIuw5kcWmv6kxlsB2orLLIu5BwXQymGUjls16UmGpwwW0mIqXIcj4zrEjbNUcHs1YvquG/jQTI5h7k1ETa29bOvewTHdTFNg+bKML0jWaqjQfpG0qQ69xJ/+RFil90OULZfULKidBTt2KT3v0L1hdcRDfv7gWsA5W9BqhyHcIVFz2CGkGURMBUxFcS2bXDAUEFc11/WIWaeDImchUrXSANcsbCWhfVRPn39Yr7wtpXF7ZVqo0EsQ+UrPqpiEL5ucT03rpiFaSgMIGgqrlxYS120fEBcjjJMokuvIr3/FbSTm1L7jUCIuls+gzPQwfDGhyY8rjCAUGowZY9KM7NdzbO7e7DzU3e2q9ndNTKl9ggxkfs3HuS3vr2R+zceHPX4F962knddOmdS58h07CDx+pNUXfEugo0Lp9yG1J4XsaqbCMxaNOXXFnSNZPnzH7+B6+ni577g2T29eFqjFFy/7P9n7z2j5Liuc+2nYueePMg5JwYwgDkTpEgFUokKVs6WbOtKli3RV5RMBUu2aUm+kpXvtWWLSowSMxjADJBEIHLOmDzTuXLV+X7UTGN6BmEGAPURRD1rcRFdXeF0T3Wds8/Z+30P2UO9ffHEUByJUDTxU5fPGHbegaB7cFnI4OeT7Qb85JldPLu9h1vvXT/se4yIiHh9MNhXXFXDIENVZfAFtuvhuT6lskvZcOkpGniyTH0c9udMUprClKY4M867CkRAqT89XQViI7x+9vyb0ZomsfvBH7Ozo0jZ8BEC9hdcdnZX6Mmb/HHtATYezON5Lj1lm4aUCkLC8jz2dhs4ThiElywPVQ0Fsxw/wAtGLkYbEXEkXC9gU2eBzW1FNrYXsN1DkfLssXV85OKpnDu1kZzpEQQCJxAESMiSxLbOEl0lC9f1Kdk+fY/9CCVZR/1lHxx1O8y9axGOSWrmRdTFFZqSOgXDpiWl0ZqNsbdoIDyBpqlIkoQuS6iKjCZLJJNxdH3kY/2I155oRfwUZKiNz5ffNK+a/jl4NUqVpeoqnQC+/scNVWVRJAhCvSnOmlTPmv35ajA7UpJzLqa87jHMPWtIzlwyqmMT088hOfdSCi/+ntS8y9AaJxx2vz7DrXk9tIUCWLOvtib84Q3t0cpbxAnz+d+uqZZFDNRWD76vPnDhVFbs6h0mdDYY4Xv0PfpDlGwLdRe/d9RtCKwy1p5XyZ7zlhOu5xr4eQeBqP6OfAG+d2iQumxTJ89s767Wwf/mE0dXKT+ccOLg55OgVjwy+m1GRLx+qXoWCxEKUiLQFImC5eEFAb4kmNKcRpFAlF1SquBg3iBv2ng+2OpY1PqxGNteIHPW9XiE9mYjurai0bj0L+n8zVfofuG3NFz+YTQpLIlZsauHnb0Grh+Q1VVSSZVsXGf9gQJnTmxgWlMCXxHIkqBoOhhOgKpIiP5nnRzVwkacBHZ1l9neXmJcXYIgEJQsj5im4HgBhuORq7gkYhqSBFOb05Rtn/19FUzHxweKloeiSJTXPITTsYPmt/4dciw16nYYW59HjqWYsmgx6ZhCKqlRJ+ko+GiSRCKmY3uCmCrRmFSIazpCBGia1p/x4ker4a8jokD8FORoNj41A+P+kbcgTDEb+PfgfwgBL+05PnGz+JQzkGMpjK0vjDoQB2i4+hOYu1bR99h/0HrLN4870Ogu167Iv2nhuOM6T8TpzeAU660dpWoQPsBAEDlg5/WHV/bjHGPyqvjyvbg9+2h5x1eR9cSo22TseAkCj+SckSusHwmJcFA7NDiW+50WBOF/o1FEP5y39+DnU8l0a1L4o99mRMTrGyFEv2ibTxCAYXkIoDETo+x5NPuChJqir9xHd9kjED6BAF2X8YOA1OyLKLxyP75VRomnR3Xt+ORFpBZeQ/Gle0nPvwJ14lTSMZmkptJXccjEFHb1GiSKMjPHpjGdgFf3dhPXxzB3TIaeskVnyUaIgLq4TlMmQX1Sr3pAB0GY+ROJVEUM4Ps+rg+aQk1wKoTA9QUCga7IVGyPnGmjKDJtBYPGlI6mgOF4HMgZdJdsypbHrOYkB3MmHQWTbFwhHVdIqTIr9pTpKnh45R5yz/yK+LTFJOeOrqwTwsl9c/tKUjPOwxEybXmH1gAmZuOUbUFO8ZnTGmN8Q4xsQqNi+yR0yMQT6FpYI64oShSIv46IAvFTkAG7susXjB02SB48MFZCPw98P1yZOtnCoZKikZi1BHP7CoTvIimjS3dR0400XP4h+pb9mMqm5aQXXHlC7TlzYh23nDeZOWMz/MO96xHAO44hKhVxejH4t/PlG+ZVtw+ta54zJjPs2ISmcOfKfdz+wMaqAOLRcPMdFJ7/DYnZFx7XRBWEM99Kugl9/JzjOn4wZ0ysY+mCscOC409eOp1dPRWe2NKFEGJUiugDQfc9qw/UfB8DAfyqvTl++fxuXD9cWZszdvj3GhER8fpBCBHaGuFhmOEKnmn6yLJCczyG7/jkTEjFFFRFQtM1gr4KcRn2Gg6JORdTeOkezO0rSS+6etTXb7jyI5g7X6K47EdcfdvP0RWZMdkEa/flKBsuui7h+bCzq8zkhjQN6QRJRSYd09jfWyGb1AiEiusLGtMxtP4g3HQ8PD+0OUvoKrIcBeOnO77v01aw8HyBqkiMr4tXA1TT8cN6a8DprwVP6SqzWjUO5Cpk4zpdJQfD9kjHVRqTOpsOFtnSVqCrbOK6AX2GQ5/hUjJdioZHNg7b7v0ZBD5N137muCaErH3rCKwSmTkXoyoSrRmFTEzDDGBGa5J0IkZjWmdaa5bGTAIR+MiAIot+obbIsuz1RhSIn2IM2JXBIRXxwQHF0NVygO8+vPm4V72PRXLOxVQ2PIm151USM4YLOR2L9FnXU97wJLknf0Fi+jkoiexxt2XpgrHMGZvh3T99sVpHftcr+/nNJy+MgvGIo/52hqZYt2bjwCG/bVmCxzd38sTmTkZSwSGEoO/RH4Gs0HjNp459wGEIbANz92oyZ72pX8n4xLjlvMnVtPDJTamqs0LR9sI68f768dvevICtHSW+//g2mlI6vRXnmEJrd68+gOMF3LP6QDWtHcLvdeC3GAQicjSIiDgFCILQ8ktWZSQpDF4bEzqurtBdNGhI6LRmU/SWLQICmjMxTNenPuXjjptNd7YFY+tzxxWIK8k6Gq78GL0PfY+OlQ9yxjXvoD4hs2hClq6SS2tWpzkdJ1exGZONMbY+weSWJDFdIZvU2NaWxxMS01sSoeicIlOo2OQsh7q4jq6GqcRxPVoRPN2x3ADX80jHdUqmg+UGpPoDcaf/3pFlCdvzaUhoZOMa3SWbCXUJVEUmHQ8zNeyKQ2/JZkNbjq6CRZ/hkO9388kmNCqOj6qAsesVytteoOnyD6E2HF92mLHlOSQ9Qf30xcQ1hWwyTjqmMj6jkYjpzGxNU5+MYfkBPcUKrZk4Sv9klCRJBEEQrYa/zogC8VOMRzZ2DHs9OBAHhqWTjsRqKRNTKNlHkGg8Compi5H0JJWtzx1XIC7JCk3Xf472//o8uaf+L803fP6Yx5w1sQ4BvHqgULN95a5eXt2frxFzc/xo8B8RcrTfztAU6yvntFZFAAenbo+UyqblWHvW0HDNp1AzzcfVXmPHSvBdknMvOa7jB/Ppy6YzZ2yGHz21gwumN/G+JWHmyPt/saJmdV+IUPzwsU2dNccfrkZ+gMPViQ9sb0jqw1LXIyIiXp8IIaqrZYqskNDB832SqkLJtLF8n0xSo9d0QRLYnsf0xiRNaZ1c2UYVEpps0THnYkqrHiCwysijTE8HSC28CmPjEzz9639n+jmXY7gNtKZ0EBab24vUJUzGp3VSsQTTmuOMq0shyxLZuEpdOkFdUsXzBZ4vMByPguni+YKuok1zJkYmHg19I0CVZZBkDMdDkuXw9cB7gOX6yIpEXFEIBGQSGk2pGCXbrQoHx1WJpK7y1OYOdEnGdn1yZQdfhJNZEhJxVUaTHF64/4ckx0xl3EU3UzoO/UAR+BjbV5CYcT6mGqOr4BFTLCZOytCUSSBJUJ9USekqE+sTuJ4HkoQm9TshRBptr0uip9EpxlDbpOsXjB31MYdjTDZOaYjH8EiQVI3krCWY215EXPfZUaenA+it08iefzPFFXeRWnDlUb3FAdYOCcAHeHZ7DzNaaoUvJIgG/xHAkX87d67cx8Mb2vnwhVPJJDQumN7EPasPHNJVGGUQ7ptFck/+An3cHDJn33Dc7TW2PIeSbiI2Ye5xnwNgZv9vYiBTRFflqgjbgL0hhL8VTZXpLFqHPc/vXt532EB86CRGQ1KvSfO/7c0Lqm4N0YRYRMTrl4FAXJIkNE3B9wWIgGxcx/R9PEuQ0jUmZhM4QUBjIo7pO7R1VcLA3PdRZJXWBZdSevk+jONMT5ckiYaln6X9/36Oe//jW1z08X+kWHFCOzVf0JWv0JE36TYsxmRTrHL7UJXQuiwQAQlFxpdBVeTQzkyRyOo6BdNGk0FXoxXBCIjpCuPrkpiOQ0LXiekKQRDQVjAwnVC1f1xdgoSuUDQdyoaL6BcQlAlYvz9HyXTQNJm8YaMpAUXToeL4+B5YgK64jKuPs/3+n+EUexl/85cpBccXell7XyUwi6TmXkJA2A5dEezPO+iaSVLXaC/azGtS8FwP2/NDtXRVYWARPEpLf/0RBeKnGAMreIercz3aMR1Fi+XbuilbLt5hZuI+esl0AG67f0ONz/BISM69lMrGpzD3rCU547xRHTtA3UXvxdjyPH2P/pBxH/khsjZS05NDCGB6S5odgyYUPnXZ9GjwHwEM/+1cu2Asn/jVKyzrX/19dnsP184fQ0NS5w+v7K9xHBgNuSd/SWCVabrlc0jy8Q34AruCuXsVmbNvPK609MGr+Du6K+zoPjQB4XgBP316J1fMaQ3VhEVYH/eucyfx9sUT2dpR4tUD64edc1N7kVV7c8N+T0PLYYaukOcMh89eOXPUnyEiIuLPiyRJiH4lR1mWkeX+4Fz2UVHRVZ+4pmD5Ac0JnaIk6OyoIAOyopCN62hygD93IfvqWjG3PENdfyAearCPHK1xAnUXv5euZ37FxhVPMW/JFTg+HMiZ+B4k4hI+go6CRWfRJptQ0FWVpB7W905vzaApEqqsUrZdLMelLqGRjkfLghGHiOsK8UFCqmXLx7R8ZEWmx7CJaQquJ1AUCUWVKBkembjCjq4KRctl7d4cjWmNiu2yo8vEDiTq4tBbAV2AHwRsWbeaHU/dQ2bxjYixxz+xPpCWnph+DgCODxXbpy4tYfkBbqUCTQnQFEzPJ6ErVEyHpK6Q6f+MUSD++iMKxE9BvnzDvBEF4AOs2pvjkY0dOIeJwBtTOn+7dE51peul3b3D1KKPRWLa2aF6+uZnjjsQl7UYjdd/jq7f/gOF539DwxUfHv05JEjqCrIUCtOpMlw7goyBiNOHgd/OgDib5db+JpZt6uSpLV015Q2jwdy9hsqGJ8he+G70E/D9NravAN8jdRyqqhAG4WdMrBtWvjHA45s7Wb61i0AIZFni629dWH0GDATaD29op6dks7mjBBy9xntoOUyUjh4Rceox2Et8AF1XsV0P23WRJYEsyTQmY7g+6JqH6aRwA0HedImnNToNl6LpkphzCcVX7scxS6iJDAojtzIbIHv+2zE3P8PuB/4DZcJCJC2FHYAPKK6gZLqs2tXNtLEZsok0BcOmLp5gTEMcL4C+ikNMU2hJ60C/L3pExBHw/ACBwBMShuGgKzK6IuEFYSCuqzKphIoiw8GcQcF06CiZtJcshC/QFZjamGT1XhtfQCYuUao47P/j/0HNNB2XZ/gAwncxtr1ActYFSKoOgAMoqmBSnU5v0SGlKRjOwGKbhON4xDQVuV+0WdOikO/1SPRUOg1YsasXyw1XqAIRBqwAmiLx8w+eW5Nu+v33nM1ls0ZX0yopGonZF2FsX4HwnGMfcAQSU84ktehaii/dg9O5c9THBwLuW9tWVYcPgvCzr9qb40dP7WDV3tdGsC7i1GPFrl5s9/BFWn6/aNloRXUDx6L30R+iNk6g/qL3nFD7KpufQcm2HrdaugDyhnvE9wfsDQMRrnj966NbmP6VB7n2juVAWAv+3x9bwjdvXkRck1EkRhxUD6yQf2HpnBrhtoiIiFOTTCpOc12SZCyGroWBuWHZlC2bpoTEjNYU4+tiKIqMKgISukrdvMsg8DG2Ph9aqB7HdSVFpfFNf41XztH++H+CdGj1yPbCVfvWbBh0t+VNbM/HcgWBD0XbRVNlLM9HSINW9/tX/T0/IDjZVjIRpyxCCHKGw+6eMkXTwHd9KpZLV8lBkSV0RUZFoiGhh5ZgqkJfxcLzIKFIBIGgo2ixZl8OLwANcFxBYeVdOD17aVr6l8ix5HG3z9yzlsAqVy3PEhIkZEioCo7vMakxxrmzGojHNdK6QkKXsT0PXRY4rofvj14DKuLPQzQ9chqwsl88aYAzJtRx7YKx1ZrND/5yJS/t6WPumAzXLhjL31wzm22dJTqKxxZ5GyA171Iq65dh7nzlhDyPG676GNauV+h9+N8Z+4E7kJTjv0UFsL2zxPeWbSMQYW3shy+cysb24jFVoCPe2GzvLB0xTVIAF81oYl+fQX1CO6ImwVDyz/43fqGTMe/7TnXG+njwjQLWnrVkz7vphNLI9vYZR3xPkkCTJVxf4Avo6w/at3dXuPaO5Sz74hXA8LTzkQbVx/Ifj4iIeP0zkKY+ELxqcuhD3FU0QICuKBRtl7qETtkJMCyPnK6S0jXSE2agNozH2PIMmbOuD8/D6Et99HGzyZz7Vgov30d83mXEJy8CIKmBrkgcLFlkXQ8vFSOTUFFUcANwPXB9H88NaLc83CBAlSWaUjrSgCiXBDFFCb2V5eHZABGnD0EQsK+3jK4pKIGGHbhMbc6GAXrFZl+fj+N7JDUNx/OZ2ppClgWKyOMICQkbx/OxPVAAF5B699Lz/O9Iz7uM1MzzOZFQ2Nj8DHI8TWLa2SiAI6AxAYokkZB0ZBHgejC5Lk5jOkVMEeC5uJ6LqsjV33LE648oED8NeH5HT83rdQcK1ZTtD/5yJc/0KyKvPVBg7YECcU3m+gVjR5WiHp9yJnKynsrmp48YiDtdu+i+/7tIkkLdRbeQmn/5sH2UeJrGaz9D933fpvjyvdRd8K4Rt2EoAmo+g+0GVbGuo6lAR7zxWTFkcmooA7+JkWIf3ELplT+SPvsG4pMWHn6fjh3kl/9f3J791F/+4SOKGBnbXoDAJzXvslG1YTR86tJQE+JwIo47e2pFG08kqP78b9ewfFs3V8xu4fvvOfu4zhEREfHnRwhRXUEOggBVVVBkj7Lp4rgCBegtV1BQSKoyfSWLtqJNR94gV3GIKxJ18y+j94Xf45X7UNONw4JwIQJ6H/we5u41JGdfQP3lH0Y5jMp6/aV/gbl9Bb2P/DvjPvJDVC2Gqkqosozj+ChJnbGZGKbp0lM0UWWF1nSMsulQdl16iw6aopCKayiqjCJDfTyGHwjKnouuysiSTFyLPJZPRwZnSEi+jypJeJLSn0Uh0VexURSJUsVlQ77AxIY46/YXKRg2PhKtqTiG4yFLENfBccANfHof/HfkWJLWpZ8irUPZoSYYD1yLwnN3Utm0HL11Oi0333rYSfzAtTG2ryAz91JSqkZLAiRdQpckWjMppozPUDE8GmIaTfUJUpqEAJIxnUCEGX6RZdnrlyg1/TRgaOcXAP/y6Fbe8eMXDhuQWG7ArDEZJtTHR3wNSVZIzb0Yc+fLBPbhV+IKL/yewCgiqRo9f/oX+pb9GBEMnyNMzrmI5OyLyD93J27P/hG34VgM/R4e3tB+0s4dceqwam+O7iGWfoPT0Ec7DBOeQ+/DP0DJNNNw+YcPu09l03I6/udvcXv2Iyey5J78xRFnqCubnkZtnIjWOn2ULTkyUxqTfPvmRVw6q5lv37yIL98wjxePMBkxozl12O2j5fO/XcN9a9vIGy73rW3j879dc1LOGxER8edhcFAqSRKZRIxMQmP+pEYqrk++5BEg0Vu0QZKQZQVdlpGAbFylcdHlIAKMLc8e9vz2/o1UNj6F1jyZ8rpldPzX53H7Dg7bT9biNF7/V3i5dgrP/g8ykNAhG1PJV2xc1+dAzmBPb4U93Sar9vTSUTTpLtkc7LNp6zPIGzbteQPbcrEdQdF0cDy/uhLuBwHRouHphRCCIAhtxhRFQQQ+L+3Ls6O7zJTGFJoiE9NkxtQnOFgw2dBeYEd7keWbu8hVDM6YUEcioVOf0ZnUkKQhqUEQjjVLr/yRyoGtNF7zSYjXURoShHvlPjr+5+8ovnQPWtNkzF2vYGx78bDtNHe+jHBMsvMuozkDsi6TVFXGNiRoqtPoyRXY29NHV8Ukb1j4IszwiPeXkqiqWvUSj3j9Ef1lTgMumXnkmm/HP3zPs72zxGevnDWq6yTnXY7wHIzth3+YuLk2YhPmMvaD/0b2vJsprX6Q3od/gBCHEZFb+hlkLU7Pw98/bLB+PMwbm6l53ZTSo9rx04g7V+7jA79cyU+e3snQ2/6Smc28f8lk3rdkMp+6bHQBcP753+D27qfp+s8dtgasvG4ZPX+6g9j4uYz72I9In7GUwCoRWOVh+3rFHuz9G0nNv/ykrsz0VmzmjM3w+WtmkzMcVu3NMSY7fKJtVkuqmpZ+oizf1n3U1xEREa9fBp4/g73FZVlCUxVUWaG5PsmSWa2MbUzjyRLj6+I0pWIgCyRZIZAg3TQJrXUalU1PH/Yabi7MWGu+4fOMfd93CGyDzju/jNt7YNi+iSlnkj7zeoqv3I/UtZVMTMXxfUquT5/h4AaQiesoqoLjeWzrzLPpQB7b9lCFQHg+puWSNxzq4wq6ppDQFFxfUDYdzH5lddfzozTe0wDfD8iVbfoqDmXDoWJ5OAFcML2FKY0pLM+jJROnOR2nKamjCMJ7xBe05y329Zq0FU18L2Bze4my4zN3XAPzx2eImW3kn/0fEjPPJznvcnzCBbABvFIvnXf+PV6ujdZ3fo3Wd30dOPR7GEpl03LUdCNjZi+iIZVkTDpOYzrO3NYsM+rTGHbAmGQcF4mc6SCCAFmWURSZmK6RTMSjFfHXMVEgfhrwq48t4bJZzaNa6XtiSxf//sQ2xmZGbiMWmzAPpW7METtdKczzQZIVGq76GHWXvB+/UgD/kIyL8ByE56CkGmi45pM4bVspvXL/KFp+eCQJLp/dUrPyed/aNu54bCvv/8WKKBh/g3Pnyn3ceu96nt3eU7UrG8y6gwW+dfMivn1zWH84UqE2u2MHxZV3k1p0TdVSZDCVLc/R+/C/E596Fq3v+keURBb6J54OF2hXNj8NiMOWbZwIFdvnvT9fwXt/9iL/+uhWbvnpi0xvTqEqYRtUReLuz1zEsi9ecdLEDa+Y3XLU1xEREa9vJElClg+la/u+jypLZOMazUkNQYDsB0xuShJXJGQ1QJdk4opPwfAoe5CafwVO+7bDBhmHrBkFsQnzGPu+7yLFkvhmqWa/gXFBw5UfRU03sf9P36crb1KwXWRJRpFlVFmmI2+ydm8va/b30d5r0lO26DMdjACQJMY3JgiAvTkDy+1fDRcCXVWQJQnT9XD84IhCnhFvHCzXR0iE9mSBQAoCZEnGdn0CARLhSnlnvsKafXlMxyNvhs4BuipRtn2e395L4HtkEzLj6pI0p1VsL2D3fT8ARaVx6WeH9fO+WaLr91/Fr+QZc8s3SMw4D6paDMNDMt8qY+56hcaFl9KQSaBrMqoioSBRsjwUGZqySRqa0ki2T52mhgNewt+upmnIchTqvZ6JasRPE371sSXMvPWhEXuElyyPkjU6nVNJkkjNv5ziirvwKzmUVG1dqZysw68cGtzXXfQeEEE4e+5aWLtWU1z9ALIeJzXvclLzr8DY8lw4szjjfLSmiaNqz2CEgJ8/t5uhH3/A6/hItkwRbwyGliHUJ7UaVfHxdXFW7c3xg8e3jbg+XHguvQ9+DyVVT+NVHx/2vnVgIz0P3EFswlxa3v4PyFo4qeVXcqCoSLHhKeCVTcvRx81Gaxg/mo937LZCjX2hFwh+8dxubn/bQnKGUxViG7B1c7wAXZVPSPV8oCY8qhGPiDj1cRwHRVGq6suaqmC7AiHD+IYEjidoSCdQBGzvVIlpNrt6HVLzLie//D+pbHqa+ovfW3NOOVUPgF/OodaNQWuexPiP/QeSHK7eHW5c0Hr952j7w9c48OT/MGnph0mo4PvQVbRIaBJj40m6ChUqgU9SCJKaSl1KJqaqxFQ1FG5VZOriGkXLRZIlPM/HcgMaUhqqIh+3fWXEqYOiyAjXxw8CJFkildSZN76Og3kTVYaYptJXdtjSWaKvaNOWM8kVTfpMDycQxHUVx/XZ0llGlWXs5oBtXUXWPvpbKvs20nTD51EztS4jwnPpvuebuLk2xrzrdmITQhvigXGxnKwb1k5ny3Pge0w59xrq4xoT6mJIBARCwvR8ipZJSzqO5fuk6nSmtGSIaSpBECBJYRZLxOubaJrkNGLh+Oxrfo3U/CtABFQ2PzPsPa1+HG6urUaJVZIVfKNAee0jlDc8QXbxm6m/5P0UXvwdXt9BGq/7LJKq0/vQiaeoH86qRB6FLVPEqcubFo6ref2ecycxb2yGhCYjS7Clo8S7f/rCqETa8i/8BrdnL03X/xXyEIEhN99B9z3fQq1rpeUdX0XWDqWBu7k21Lqxw2bKne69uF27w9/QnwEvEDy8ob1GDX3Frl4cL6iZoDoRvv+es1l729IoCI+IeIOgqgoICV3TmNaaYXxDEssRBBJ0Vmw6izZ526W35OABaraZ2OSFVDYtH5byrdWHorFu7lBd+EAQfqRxAXVjSJ+xlM4X7qFv1xY8z8ULPLwgQBKCsmkjKRKSkPAFxOSAlnSKac0pAkKBt/qkjh8AIrR/SsV10gkVSZLxfIF2mOAlCAI8LyAIotXyNwJxTSEVU1EIMzwUJbwv5o3NMrYuRTKmU7Jt9nSXWb2/lxd39NBecilaHkEQUDIcuss2gRcQICjbDgd372T/Y/9JYvq5pBbWirEKIeh95N+xD2yk+Yb/RXzKGdX3BnQR1IbacQpAcdNyEs0TSE+cjSxDJqYiAxldxXZcgkAwtTnD5IY088Y2oCoytuOE2SuqEokPngJEgfhpxH2fu4SzJtZV/+gSUJ84uUkRevNk9DEzqGxcPuw9rWUKwjHxCodSg4XnUtm0HLf3AHUX3UJyzkXoY2YgxzMEjomabqThmk9ht22h+PK9o26PpoR+0BJh+q0qh/9WJLjprPFcPLOZ2968YNiq30A98Z0r9436mhGvP963ZHJVrOzTl03nP1/cw9bOEm5/sXggCAdmI8Ru20pxxV1hSvqM82reC2yD7rtuBxHQ+o7bwnT0Qbhdu9Fbpg47Z2XTUyDJJ6SWLgGXzWrmzIl1NKaObaH2/I6emtKMC6Y3VbNGfEE0QRUREQGAoihYlkVPvkDZquA6DoEIU3vr4zq5ikNbT5HekoUGZOMSdUpo5ZSafyVe30Hc9m0151QbxoOi4Xbtqdl+rHFBw1UfR0k3se9P36et12JfrkxfyaCjYLDxYJ4DHSW68hbNSY1EXKclEyOu67RkdVqysbAu3HFRVYUAgRCQUFViqoIkDm+xVrY8CpZN3nDwvCgYfyOQ0FUySR1dUwZNsAgkBHu6y+zqKmE5LqWKw4G+EiXDIG879FZsyraFhk/esNnXU+FgX4nNf7gjTEm//nPDAuDiyrupbHyKukveP6z0zO3eDYDePKVmu1fswtq/gcZFV+D6HigS6ZhOUtPoNUxMF3b2VjA8n+ZsEkWVkWVQFQVZlmr0HSJev0SB+GnGfZ+7hBktYUqsAPKmh3qS74LU/CtwOrbj9tYqnutjZgLgtG+vbrP2r8fctYr0oquJjZuN8D0qm59Ba5yAPnZG9XyJ2ReSf/Z/cLr3jKotri8I+jtWCbj9bYt475LJXD1vDA9t6OD5HT3c/sDGmnrYwfXEt967PgrG3yC8b8lk/vtjS8gktOqq7+GyJI5F4Nr0PPg9lHQjjVd/ouY9IQJ6Hvw33L4DNL/ty2iNE2re980iXqETfezMYcdVNj5NfNrZKP3pmseDILReu+W8yfzt0jnH/ixDVr4/9MuVNe8PfR0REXF6oigKjidIJ9IkY0mSOmRiSmj1JATFssm2rjKG69BXtkjGYyT0sP4xNeciUDTKG5+qOackK+it07A7ttdsP9a4QI4labrhb/D6DtDxzK/oKsL2DpsXD1gcrEBXBdqKBj1lF02TSGgSIgjoKlhsbCtwMFcioakosoQuy4DADQLKtoushGU83qCZWc8LsDwPkPECge26RJx6BEGAf4wZd1kOPbeLloOCRGfRJRnXSMZlSk6/eCESTakkqqzQZ/kEQvDyH39Dcd9mWq79NGqmuSa4Mne+Qv7p/yI599KwJHMIdvt2lEzLsL7f6F/QajrzSjxf0J4z2NJexHVtAhda6nUyCZVixYLAp1gx6MkX8Xyn+jki4cHXP1Egfhoy1Cc4EKArJ2/WLDn/cpDkYZ2u3joVSY1hH9wMgAh8yuuWkZp3GbEJ8xC+i3VgI3bbVvQx04FDM3pNSz+LpCfpffB7CH90tesDuL5gQ1uBe1YfYNmmzmowZrsB96w+pNL6u5drA++hryNe3wxkM3z+t2sOm9VwwfQmdFVGkUDXZKaP0q4r/8yv8PoO0HTD55GH1HkXXvgd5vYVNFz1MRJTzxp2rH1wCwCx8bVBsr1vA36pm/SCq0bVliPxu5f3kTMcPn3ZdGa2psnEahVTZSn8zSuDSjNW7c1RdmrLP4a+joiIOH3RVA3LtihVSvh+gB9AoWKxsa2PnopBb9GgOZOgJaMzqzVOWg/VouV4mtTMJVQ2PzOs/46Nn4vTvh3hh8HtSMcFialnkT77Rkqv3E9x3zrsQfGGAVQqPl7gYDuCshVQMlwKpoumKPSWXPb1lHH9cAwgSWHKuusHWHaA5QaYtjeojE7g++AHod/0QPp8xOufIAgoGKG//f6+Ch1Fi5LpDNtPkqTq39sLBC3ZOJIkYTgujudSH9cYm9EZk4zRkonRZ9hUXA/hQ/ngbg4s/28Ssy8k2V9aNhDuu7k2ev70L2itU2m64W8Ou0Jtt20ZNiYQQmBsfIrUpPlkmseS1lXKZkCuVKItb6LLEJMhFUtgBwIv8EnGY8QUBdcLkPu9wwd/rojXJ1Egfhoy1Cd4RnOKbd+6gYmj8A0/Gmq6kfjUs6hsfKrGmkxSNPTxc7APbOzfICGperVjrmx6BmvnK8ixJKlF14Y15P0PLSVVT+OVH8Pp3Enh+d8cV7sE8FJ/DawYsv0Pr+xn1d4cq/bmqAwJPloPY/MU8fpkcDbDfWvbqlkNn/zVK9Wsh3OmNPDrj1/AF5bO4dcfv4CPXjJyuzJr7zpKr9xPZvGNwwJtc+fLFJ67k9SCK8mc89bDHm/v3wCKij5uds328oYnkfQEiVlLRveBj8D6AwX+5dGt/OSZXXh+gDHkntZkid988sLqdwDw/l+sGHaetB4NOCMiIkLSqRgxTSYVj5FNJrBcl/29BmXLIS5LSASUDBtVlqiLKVjIZPorZJILryQwi5i7V9WcMzZpAcKzcTp2hBtGMS5ouOIjKHVj6Hnw+wS2UT2nAGIa+J5EX9lmzcEcJctCDsD3PYIA/CDAcjzKtovthtdSJAk38MMsQUlUy3RkWaYuqZHUFbJxnZgWPRdPFQzbw3ACypZHb8XGDTwO5CqUTWeYKJ/o9xaqS+qYTsDGjgKG49FddMgm4kxtSeBJEl0lE8v1USWJku2y475/Q46laFr6WYJBgXbgWHTf+22QZFpu/ocarZgBvEInfrGb2KQFNdudjh3YvftpPuMqfA8C4aMrIMsaOdPFDXwqpkNjAupiKhXHQvFdVEXCtM2qqGKUnv76JwrET0OWffEKZrWkqqtiu3sNLvnOExzIWyftGumFV+EXu8PAYxDxyYtwOnfhmyUkSSZ7/tsprX6Ajju/jLV3LWrTJLLnv72qMC0Cn8KKu+h74ueU1y9DnzCPwoo/VFfVR8uO7gqBCFcEBycBOL7guw9v5v2/WMGu7kP+zqoMn758xnFdK+LPy50r9/HPj2457HuPbeqs1kKv2pvjntUHOJg3gTBlPTmCgDOwyvQ8+D3UxgnUX/GRmvfcfAc9f/pXtNZpocDgETo+a++rxMbPrd7fEHbWxrbnSc655LAd9fEwOPluT68xzDfd9gXnTGngs1fO5JwpDVWRtsGkdYUNt19/UtoTERHxxiAWi5FIxBFCEFMkYppCQlWJxzRmj81yxsQGxtTHKVkCzw0o9i8+Jqadg5yso7L+iZrzxSctBMJJTmBU44L8s/+NHEvhF7vpevyn1XOmgMktaXpNjwP5CrmyS1veoz6lgqzQmIoR0xRMzyfwfQzHAQHZhI6qyJiuR8HyqgE6QExV0BWZuK6gRErUpwSeF1CwbCzXQ5LB8wUVa2AiRmC7YbBqux65ikPRdKjYHgldYXxdDMN0SOsq9ek445pSBKj4vkCRFVwfeg1B33N34nbtoun6v6pJLRdC0Pvo/8Ht3kvzW79UFSYcirX3VSAcGw+msvFJUDSCmZdguCDLgpQMFdfGFQJdhWRMJxOPAzKSJCNJgnRcpyGTwXVdPM+LrMtOASL7stOUZV+8gpt++BxrDxQATmoQDpCYdQGSnqC8/knikw+pQ8annEXhuV9j7X2V1NxL0FumMOa930a4FnKiriZAEYFP973fQkk3kZhxLqkFV9J9z7eQE1l6Hvg3xn3k35H1xHG1rzGl84Vr5/DV+9dXRbpe2nOoTlwCLpnVzOevmR3Zmp0CDKyEHw3HC7h79QHuWnWgGnT+ZuU+PnXZ9BHNSPYt+wl+uZexf/EvNQFz4NrhrDfQcvOtRwymfaOA07mLukveV7Pd2P4iwjFJLzw5aekj5ax/fJS1X7uOmbc+yEAMPpCqfiK2ZREREW9cBlJdYzEdTdNozRpIBY/x9WkyCY1cWdBVNNlbsrAdGAhlJUUlNe9ySmsfwjdLKIkMAEqyDq11OuaeNdRddAvAqMcFnb+9lfKGJ4jPPJ+mORfjAEXDQ1U0uooeDSkPw1OpS8YZU6fhBQFly6FouDiajOP7xGSVlvokZtHD9j1URcF0PDRFDldKhUBT5Wh18RQhCAKKlktcVTEdlyAIaEyoOD4kdQ3HD1BkGdfzMV0fNwh1AYIgvMfb8wbdZYeuoklCUUhk45RsD9v1EQIsG4oHNlJYeTepRdeSnHVBzfVLqx/A2PQ09Zd+gMS0xUdsp7lnLXKqHm2QUJvwXSqbniY5cwlyPE3BA4qQUcGpwLgGjfpUkokNMZzAww9iNCVSNKR18D1s2yaRSFRXxSNe30RTJacxG9qKr9m5ZS1Oau6lGFufI3DM6vbY+NlIsRTmrkPpaUoii5ptpbz2Eaz+FXQhBD33fxc5UUfDVR8nMe0cYmNnkpx9IQ1XfgQv30HuiZ8ftQ2aInGkLrOn7PDS7l6O5EQiCC2vomDk1GCoT/jhJA9kKbwf3EErvwL4yTO7jlkLXdn0NJVNy6m7+L3Darn6lv0knBF/y98ecdYbwNqzFhDDOuXK+idQ68YMS017rcmbXk0QDuH3EQXhERERR0KWZWRZxnE9ugpl0skYjZk481pTjEslyJUM2osmFcshCMJBZqp/pJledDX4HsYQe9PEtMXYBzfXpJePZlyQXnAVWvNk+h75IeVSLy5QsC084dKeL7GtvYTjewgCIMDxXQzboqdUoVixSMc0CqaL7fgE/dKujhtgeT6C0HkFROQvfgohBARCIEtyqO6vyIyvTzO2PkE2qSFLMjFNQQCKJFOyXTYfLLCru8wrOzt5ZF0HAaDJMoEMDWkd27YoGh4F06diG3Q98G+oda3DRFvtg1vIPflLEjPOI3vhu47cxsDH2rOWxNSzayZ4zJ2vEJhFUotCCzQNkINwUqu5QWbOhDq0uE4iGWdiXYrWujgxTaMhnSKbyRCPx5HlcNIostt7/RMF4qcxr7WveGrR1QjXwtj6QnWbJCskpp6NtXtVTf34wP6yngSgsuEJ5Hiaxqs/gazFkBQVc89azN2riI2fS/aCd1Je91jNuYfi+uKwNiQDPLCu/ajv/+7lfTVq6hGvX4b6hA+kYg/OILxoRhNvXzwRdZTChF6hi97H/oPY+LnUXfjumvdKrz5GZf0y6i68heQQG7OhmLteQU5kaxTTvWJXmB2y8Cok6bV5HB8pjTKtKwx14QkEURAeERFxRCRJwnV9fN9HkmUCTyAJQcX1SCRiTGyMkYprxFWFTAKSHLKG1FqnE2uZirnx8ZpzJqafA/1ByVBGNC7Ys4aGKz+O8B16/vSvCBHQnYd9XWUc12d/X5mdnUVW782RN2yKFRfHE7QXLF7a3cX+riK25+B7br+IpYQih8Gb7/kYtkfecCmZLo57fGKxEX9eFEUmocm0Fwy8IEDIEiXHIxCwr7fCgVwFxwuQAdvzKJYdyo6P53msOVik5Hi4nkDTVNJxnXX7c3SXPRwBZQ/6lv0Yr9hN85u/iBxLVq/rGwW67/8OaqaJpjd/8aj9utO+ncAskph+bs328oYnUFIN1Ul7CXACkHyoi8VJ6CpTMxqzm7NMbskyriFBOq4CYT24oigIIdB1PUpNPwWI/kKnMQO+4ooEUr/X9skkNmE+asM4yhuGdLozz8Mv9+F07qrZrsTT/aqo4FdyaK3TkPQwzdfav4G+R39E/aUfQGucQGbxm5HjGXoeuAOv2HNc7fMOM7s9+Dt49UCB9/58RRSMnwIM+ITXJ7Wa7YO9tJ/Z3sNf/HwF3tCC6aMgAp+eB+4AEdD0lr+tUct1OnfSt+zHxKecNSzd/HDnMXe+QmL6OTXnKG94EhCkFl494jYdiSPNLyhSmB0yFE+IYcecbCvDiIiINxayLKNpCrqmY1lhyq+s6WTTSVqyMSY21ZFRFHxZojGpcPaUFJOyMK1eYkJKIrPoGsy27Tjde6vnjE2YhxxLYex8adj1RjouSExfTN0lf4G9fz29d38zTE+3oM+yyFk+O7sNtnWW2NNbobdis62jjOc6eK5gT28Z2wso2x4JVUFGCq0dfUHFdvCFQFUlBKG+xvHYXkb8+UnFdVIxlea0TlxT8IWgq2QhANN2eWJzG09u7mTVzm4O5iposmBvzqStr0xDTEaXBQXTpTtXZl9vGSEJAh8qm5aHnuAXvYfYhHnV64nAp+eP/4JvFGi5+VaUePqo7TN2vgSSTHz6OdVtfiWPufNlUguuJCUrJIAxOjSmYMZYjZnjMqR0hfGNWeqSKgEyrgsJTQp/l7qKruvVVfGI1z/RX+k0577PXcLOf7qRuz59Ee9dMhnlJN4RkiSRWng19r71uPmO6vbE9HNBkjG3D1dphvBh5vbuJ+gXdDN3raLnT/9K3UXvITX3EgCUZJbM4hsRnk3X3f+ICE5OLczQ7nWwxzLAqr05fvTUjig4fx0w9G/xviWT+bvr5tbsU7Fr7wtziGL+sSi8+HvsAxtpvPYzNWnnvlWm+95voyTraH7rl45pZ2Mf3ExglUjMPKSKLoSgsv4JYpMXHTWlfaT4ApLa8B+wHwjede4k3r9kMmdOrKtONrlewBeWzqkG36oMO7594wm3IyIi4o2NqqroukpTOk46qVMfU6hLxNBUlSktGZYunsyls1qYP7GBTFwjUKAuHkNXoG7+FSArVDYcEm2TFJX49HMxd758xL58JOOC1sU3Em+dSmXnS+TWPIQfQNkATQSYtsuG/Tm2tRfY3VVka2eBnOGRjKnEVJWYKlNxPTxf4PkeqiywvYA+y6NoOfh+gCRJhK7jEa9HBgTYXO/QPdSSTeB44AWQjakUyhZdhQordveyZW8vz27tYnNnmc1tRbZ3lHEcFzmA7T1luoomCgEH+irs6XTIFwKcfAe9j/6I2IT5VU2DAQrP3Ym1dy2N134GfcyxRX7N7SuJTZxfE7BXNj4Fgc+4s64hm4SYAo31MKM1RWs6heW4xJUA23XpKrpIkkQ6oSJJoW94EASRjsEpRiTWFgGE6ahbO0r8UVMo2ydP4CG98GoKz/6ayvrHqb/0L4BQnCU2YR7G9hXVbYORZIWGKz5C5+++ipdrwyt00XDVx6udLYCX70BJ1RObuAD7wEaKK+4a9lA8GQx4LEMY+L3/FytwvABdlbntzQvIGQ4NSZ2c4XDB9KYorffPxNC/xYAF14B39ou7etnUXsR0j/9etg5spPD8b0jNv6JGSE2IgN4H/w2v1MPY930HJVl3zHMZ214ERa2pD7cPbMTLt1N38XuOu41Dsb2AuCbjuEHo39svvvaOxRM5Z0pD9XtzvaB6b0fBd0RExGiRZZmYruFbJpomocgyuqoAMtOa0siBz6v7KyhJh6lNdZQ9H7ME9U31JGacR3nDk9Rf9kEkJRyGJmctwdj8NPbBzVUl9cGMZFxAqZOJ51zJzmV3kl/xBxrnX4GIJ6koLmXLpS6u0l4wGZONk1BV2nImaUUmPknD91xKgUR30SQQMr5wSWgxsskYigAhJFRZQlPlI5b7+IHA9QIEAkT47FVP5upGxGEJAkEgBK4fIEsSri+AUHwNYExGD+ulBaRiGju7y+RKJpqsYgcuveUAVZZoyxfpKXsEQmC7AXnTpVDy6OmvRpA8j74H/wVJkml+yxdrJuDNnS9TePF3pBZdQ+bMpcdss5trw+3ZS8NVH69uE0JQXv84+rjZxMdOpjmp0zwmTiIm0ZjQQ8FAQFLj7M8ZjAeashqqEn6+aAX81CQKxCOAkalOHw9qtoX4tLMpr3+CuovfW31wJWdfSO7JX+Dm2tAaxg87Tkk1MPb9/4zwbETgo2aaq++5vfsxtq+odsTFl+8j/9yviU1eRHzi/JPWdkWCr79lQTW4HrB4CkSowH3b/RsIROg1KgExLVKb/nOwam+O7z++rfq3cL2Ae1Yf4O7VB7DdAFmCaS3pUaWgD8U3S/T88V9R68bQuPQva94rrrgLc8dLNFz9yZq0tCMhhMDYvoLE1LNrasnK6x5D0hMkZ1983O0c1m4Bi8ZkuHbB2MNOEA14qK/Y1VvdftMPn2NDW5GF47Pc97lLjnGFiIiICHAcByEEcT2G53sk4jq2b+F4Dgf6iiSTCcY3+gihYzoeqZiO7/rsz5mkz7gWc/sKzF2vVNWmE9PPBUXF2PbiYQNxOPa4oLB9BaLQRePVn6D3sf/g4LIfM/7NX6SjAikNsgmZjr4K7UWbuAqqJGjNpEAItnc6VOwCpgiY1VpHyfRoyUgkfRUbibpkHFWR0Y4QWAshsFwPAig5LnFVwROCpCRFdmevAaJ/7BUEAbYX4AdhyUAqriFJULFdZEnCdHxs36M+rqPKEsgSdSmd8Q0J2nImioCy5eL4gvq4TpvvUjQ9JOGTL3vkB0kC9D37PxQPbKXlbV9GrRtT3e4VOul54I7QvvTaz4yo/ca2F4FwLDyA074Nt2cvU9/8WRoTEigyrVkV15UoGwZxTSFdFydXKFOfiVGX0FAlGVkIZEVCkqRoNfwUJJo+iQCGq07XJzUm1p8cT+P0omvxS901QizJ2RcBHFVsTY4lUVINqJlmhO8CoV/zQBCePvM6YuNm0XTdZ1GyLXTd/Q0KL91zUto8QM4ITVBX7c1xMG+iKjKKFCpw+0HYEUCYqjY0jT3i5DGQhn7nyn28/xcreG57T9UPXlNlBGC5Ydq5L2BHV/m40weFEPQ+/AP8So7mt/5dTfBs7n2V/LP/Q3LupWTOecuIzud07sQvdNZ0uIFtYGx5ntS8y5H1k/M7G2BLZ4nPXjmT9y2ZXPUJH8xg//ABC0MvEKw9UOCmHz53UtsSERHxxsS2bRzHIQgChBBICBRJICRwPZm0JjOzJU1c05BVBdN1KDo+iZjK+JnnoKQaKK97rHo+OZYkMfVsjG0vIMSRn97HGhckzryO9FnXU3fRe6hsfIr2R35I95qHkAFd9imaHlpgYZkWru2jqBKuHbAvX6Ti+li2z86OIiIQNKaSOI6LhNIvIBMGe0IILMfHcvxqvbgQ4X+SLIUOHf0uHUf7LBHHRxAISpZH2XLJG27/QogCksD1fPrjbUqWS3fZIm+45AyLtqKJJgt0CcZk4kyoTzG9OUVS0yhZLru7y6RVn8a4jCdJKBo09MvOmLtWUVx5F+kzryc5KAtDeA7d9/0TQghabrq1xmrvaBhbX0AfM6MmoC+vX4akxojPvoyOvCBXsdiXN8ikVCY21yEpGrIQKLJEWldIpzRSMY24rqFpYUOjQPzU46StiEuStAcoAT7gCSHOHfK+BPwAuAEwgA8LIVafrOtHnBhvWjiOZ7cfEj37u+vm8r4lk2u8xo+X5KwLkBNZyuuWheqogFrXij5uFsbW56i74J3HPEfPQ98H3yc+9Uy8Yg/phVcR61eflmNJmm/8Ap2/+QqFF36HcCzqjyGeNSIkaMub3LlyH7c/sLG62nrOlAbqkjpPb+sOfSdFOKM1OI094uQxOA19YAJEEH7nF89s5k0Lx3HfmgMn7XqlV/6IuX0FDVd9gti4WdXtXqmHnj/+M1rjBJre9Ncj7vCMLc+CrJAY5DNa2fwMwrNJn3HtSWv3AHPHZPjRUztGVCox1MLwtbQ0jDi9iMYEb1x83yeRSGCaNo7rkkomQIAbhIFnNi5xoLtALKYzPhujPRtHV1WKVoDjQ67skVp4NcWX7sEr96GmGwFIzrkEc+fLOO3bhtlEHo6jjQvqLroFc++rlDc8jtY2kXYrT/ObP4QIbHoqEpKkIEseuzskxjQkqPgS9Zoa1hdnZerSYbpvMh4jrkHRsNFVhbiu4Pug6zK+H2C5HqmYhuP7eP2p0aosIyEhy9Fq+GuB64fp/5oq43gBjh+gATFNJaEpOF5AyfLY01dBCgTphEJnwUVRFJAFuqpjeDC+SaW9r4Iih0J8faZLUteYPibJBEmwfl+Ozt4g7Psf/De0lqk0DLEq63v8pzgdO2h5+/9Gaxh32PYOxSt04bRvpf7yD1W3Ba5FZdMzJOdchKulgHDBRwhBd8FkzIQGHN1n9tg66uIKhhvQkooTVySECFBVLQrCT1FO9or4lUKIs4Z2uP28CZjV/98ngR+f5GtHnAADqtOXzmrm2zcv4n1LJgPw7vMmH3b/0fzcJVUjNf8KjO0r8I1DQX1yzqU4HTtwc+1HOTqk8drP4BU7Ka64i+x5Nw1LCY5PWkDDVR9H2BXKG57APrh5FC08PH4Av3lpH7fdv6FmtfWlPTmWbeokEIJzpzQwsyXF1fPHRGnprxGDSwKC/tlgRQJdk3nTwnF8/Y8beGnPyRHPs9u2klv+/0jMXELm3LdWtwvfpee+7yA8J5z11hMjOp8QAmPLc8SnnImSOGQXWF73KFrzFPRxs09KuweY1ZLi1YMF/uXRrbzzJy/UiApee8dypn/lQa69Y3l121ALw9fa0jDitCMaE7wBURSFIAiIx3US8Ri6riNJEJNlFKDH8DhYqLB5fw+2a9Ha715RF9OQApeKSzgJKQIq6w+5qiRnLQFZpTLEZ/xIHG1cIMkKLW/9O+RYGpUAY+8alL7tJBWJPjP0Ck+pCqqmkY3FmZiNocqC1voks5rTZHQNCQnP8/EliQBwAx9NVfCCANP2qgFhb9mkYrt4QWjlFtcVFBkSmhIFR68BiiwhBHh+qHWS0lXimkJCUxACLNdDQiYbUzE9n709Bl1Fk658mVd29PDynk72dOTY11viYNHiQKGM5QY0ZxI0Z+MQSDi2IKFpeIRK6MK1aXnb39eseJfXLaP86qNkL3hntcRiJFS2hJlnyTmHVtaNLc8jHIOGM68jDmQ00GRBX9EmUAJkSdDcEEfTVDRNZ1ZrlrSuo+k6XhDeY9G9dmry50xNfxvwKxGyAqiXJGlk00cRfxbet2Qy//2xJdUgHIanrA8w2mSr9JnXQeBR2fBkdVtqXvgQMrY8e8zjlXiaMe/5NkgS1t61YRuG+JBnznkL8WmL8QuduLm2Ubbw8AQiFGA5HJ4veGlPjh3dFZZt6mRrR+mkXDOilgumN6GrYUmArsrc/raFfGHpHH798QvY0FbAOYFa8MH4ZpHu+7+Dkm6k6cb/VdOp5Z78BXbbFpre9DdozZNGfE6nfRteoZPU3EsPbevahdO+nfSZS09qxylLcDBvMpAJKQR87tergDAI395dIRCwvbtSDcYHLAxVWeKsiXVRjXjEn5NoTHAKEwbfErqu4/sBQSDQVZmkqtLTm+dgr8XmjjJPbumiPqWT0VUURUaPaWRjkGicQGzSQsrrllX7cjmeJjF9McaW54b174fjWOMCNdNE81v+FrN7H0ZvG6hxbCQ8Bzr7XLZ3OuzqybO9s0TF8qjPpEjFdAJJouJ4uL6LLMv0FS1Mx8X2BIHvEwhByXIpWg6aLGMHAVq/UFZX2SZXccj1+45HnHxURSYVU9FVmXRMRdcUVEUOPe79sGYcBDISshA4tkvR9dmTs+kr2wQBrOsos7u9gKJINMTU0GLUdtnfXWTdvm42HOijLW9z8Kn/Dp1TrvssWtOhvt/u2EHvY/9BfMqZ1F/6gVG139jyLPrYmTUr6OV1j6I3TqB++gLqEqGmgSaB8EAPfHZ2l8lqKtOb0kwbU8eYbApZAl1R0BQ5KoE4hTmZYm0CeEySJAH8VAjxsyHvTwD2D3p9oH9bTaQnSdInCWfHmTz58KuxEa8933loM49s7GByY/LYO48AvWUK+vg5lF59lMx5NyFJEmq2ldiEeVQ2PU3dhe8+5jlkPcG4j/4IvxKu8vnFHpBkArOIX+4lcG1kPYEUS4WrmtMWo6ROfIV6pI+3hze010xiRJwcBguMDRYgA7hr1clJSRcioOeBO/ArOca+/59r7ETKG56ktPpBMufdVKvQOwIqm54GRa2pDy+/+igoGqkFVx3lyNEjAYZbO3htL9oA7Oyp1Gwf/DoKviNeI6IxwRscVVUJggDPC1BVBVmWCYTAQcLxXMY1pDACD+HZTGiOU59NsKe7SE/RpGB7pM+8jt4H7sDat57ElDMBSM27DHPHS9gHNh1RtG0wIxkXaM1TsHv2snbFi9SfczM+4SpUXIGCKcipZcY1xNEVGVXxKdoSmiLh+jq269BZsFA0mYQuoyKhaaD317235w38IABfENMUJCCpq/hBgOX6RDlGrw1av3Ce5XqULA9NkVFlMN1Qs6Bo2liux66eMobj4wfQklLYbQSYdonevEFB0ZkmKTiuR1yR6Kk49OQ8jCCslTG2ryS38i7SZ11PesGV1Wv7ZpHu+/6p3770745pXzoYt+8gTsd2Gq786KFtPfuxD2yi9cqPUBeXmFynkTdcFAGODslkEkVSsP0AT4Aih789SRIEgY+qymHafcQpyckMxC8RQhyUJKkVWCZJ0hYhxMjyiwbR31n/DODcc8+Npnj+f+A7D23mJ8/sAmBPr8Fls5oJE7n604SPcwUyc+b19D78A+wDG6sdbGr+5fQt+wlO9x70lqnHPIesxZDrx1Ja/QB9j/+MxMzzkbQYwjFR0o0kZpxPauFV9Nz/Xbr/+C+MueUbo3pIjgaJ2iD9TQujxZzXioGU/8GWZe9YPLFqT3KiFF74HdauVTQu/Utig9LFnc5d9D36Q2KTF9FwxUdGdU4R+BhbniUx/Vzk/sA+cC3KG5eTnHMRSiJzUto+wNF+ljOaU2zvrtS8joh4jYnGBKcBA0rNQRA+i3Vd4+r5U7irtIOibeO6Pjt6TSqWwBEBbXmTwPeoj4E6+yJy8Z9SXvtINRBPzLwASYtR2bR8RIE4HHtckDnvJsrrl9H21H+hjl8IA9ofPugOFByPfb0VipZHY0ZjalMdhu3TVTKwLIc+06YlkwFkcmUTIUsgSQSBIB3XaK1PUbYdsnEdTfEwHA8QZOPasLYGgcAL+kusovrx4yIIAlxfIAJB2fHQFAnb87GFoGw79FRC3/dMTMG0XCwPXM+jI1BRJYu1B8q4DgSyha5JKJKEJMuUbI9CAC7g5trpefDf0MfMoPHqT1avLYIwVd0v9zL2fd8dkX3pYCqbngYkknMvq24rv/oIyArZBVdjuNBnCUoOtCQlZFfQW3ZpSck01yUpOz5Nno8reyRiGrqmIvXXkkep6acmJy0QF0Ic7P9/lyRJ9wLnA4M73YPA4JzOif3bIl5nPLKxo+b1vj6D6xeMrQbnx0ty7qX0PfFzyq8+Wu1gB7ZVNj6FPopAJ33GdZi7V6MksjS96a+BcFVTksL0sMaln6X3oe+Rf/q/amYeTyYzWlJ89JLpPLyhnTctHHfY1fBVe3M1VlERo+fOlft4eEM7tutj99fq227A6r05jlA1MCrMXasoPHcnqQVXkj7rTdXtvlmk695vIccztIxy1hvA2rsOv5IjNWgm3djyPMKukDnz+hNv+Ai59o7lLPviFVx7x3J29lSY0Zxi2Rev+LNdP+L0JBoTvPEZCL7DpIdQE0MGZo3Lcsn8MezrqeD7UHKgIRuwt9tkQjbG3t4SaR38WIzUgisprX0Y3yigJOuQ9TiJWRdgbHmOxqs/haQOD2aPxNHGBclZS2j/z7/hwH3fZtyHvo+SrKMCpCVwXCibDkhQtn3SikI8oWF6GmbZJud7xGMexb4ylUyCic1pFBV0RcL1PHw/QJcUYppCMq5hu36Ytq/WVn8GgaCv4uAFARISjWn9iHZoEYcnCAJ6yg6262PYPsiChKqgKBKu57Nufw5ZgEeA6QgOFEx0RaE5rdNRMtG0fj2DOMQ0iVLZRInFiamCvBEG4YFj0X3vt5AkiZabb0VS9er188/+D9aeNTRe97kRCQoORghBZdNyYpMXoWZD6z3hOZQ3PElm1oW0NteTSWrosk9MAUXTGVevMD6rM2tsI81JhXRcJkAmZ7ogAnzPR9dVVDVyoz5VOSlPAEmSUpIkZQb+DSwFNgzZ7Y/AB6WQC4CCEOLYKl0Rf3auXzB22OuhwfnxIOtxUguupLLlOXwzrKdWknUkpi2msulpROCP+FySqtFy8z/gFbroe+Ln4Tbp0O2cXnQ16bNvoPjSPVVhjJPN9Jb0YevqBxhQ+77jsa28/xcraoSzIo7OgF3Zdx7azK33rufZ7T28tCdXzUAQwOaTUJPv5tpDJfTWqTRe99nqjPLgWe+Wm289rhKHysYnkWIpkjPOq24rr30YtXEisRGu9IwGRQoHhvEhg7+BmvBlX7yCXf90YxSER7zmRGOCNz4DK3ADq+EDq3KhmKbEnDGNnD21CVtImI5HX9mjaDgUKja+B30l0GKQPvN68D3Kg0Tb0vOvJLDKmLteGVWbjjYuUBJZWm66Fb+Sp+eP/1wdbyiAIsD0A4xAkFQg53joCCqmR6fhUTI89naX8TyXuAaO46NKCklNxQ2gLW9g+l41KymmKcOCcAAvEHhBQLzfast2Rz7mOd2xXA/D9jAdv5oVJ0uhZZ7jB6iyRL7i4LoBLoLugkHespg3NgMCkkkVxxWUbRchoGSA8AWaqqAIQcGw8QKIC0HvI/8Ht3svzW/5Uo29WGXLcxRX/IH0mdeROWv0k+lO+za8XBup+VccOufW5wmsEk1nXY8iQVJXQJaJq6CLgOaYxoXTGzlvegOZVIKEpofq/XLoICNJEr7vR6vhpzAnaypuDPCcJEmvAi8BDwohHpEk6dOSJH26f5+HgF3ADuDnwF+epGtHnGS+fMM8Pn3ZdKY2Jfn0ZdP58g3zhgXnx0vmrOvBd6lseKK6LbXgSvxSD9a+9Yc9RgiB3bYVt692sUSSFVpvuR0lVY+b70B4LsI7JI7SePUniI2fS+9D38fp3nNS2j+YpH70FdLBat+Rx/jIGTyB8bNnTywL42gEjkn3vd8CSaLl5n9A1g75eeef/q9w1vvavzzsrLdvFLD2byBwrSOe29j2Iqm5l1Rn053uPdhtW8iced0Jd5rpWO299+2bF/GFpXP4zScv5K+unjVs/6E14hERrzHRmOANzkDgPTglVlFClXBJlsnEVerjCmOyOi1ZjZiikI0rdFdcGpMKmgZBANmWKcQmzqf86iNVobX4tLORU/WUNz55xOt7pR6sA5tr+nwYPi4IbKP6XmzcLJqW/iXW3lfJL/9PAHosqARgWi7lYpmC46NLATu6S+ztLeIFPnFVRgifREwjCBTsQJBSQVYkQKI+FQMhY1jeUb8zRQ4nLmzPxw9C+62IY2M5HmXLx/J8Kq5HIAIsz8cJBMmYRjquIgQUDBfH8+ku2VRMn5aERjymEdMlmhI6LUmFg3mXppRCXRLG1KfQVJVUXEbrr7Huefk+jM1PU3/ZB6p2uxD2370PfR99/Bwar/n0sDaKwA/HqfkjL1qVNzyJpOqk5l58aNvaR1Drx9Ew/QwkFTQ5QJdVGpMa45sztGQSlD2ZfMVFBBKpmILr2vjCxfXdalZKxKnLScllEELsAs48zPafDPq3AD57Mq4X8drz5Rvm8eUb5tW8hjBtvT6hHbe3uN46LRRtW/sImXPfhiRJJGYuQYqlqGx4gsTUs4Yd0/fYf1Be+zAA8WmLabz2M1W1SUmSqbvgXQjfpfPOr6A2Tap6PEuKRvPNt9LxX5+n+55vMvaD/1ZjITUaUrpCxamdvV67P3/UYwbUvl0viDzGR8HgCYzXCiECeh/8Hm7PPlrf9XW0+kMTTZVNyym+dA/ps28kc+bSmuMCxyT/zK8orXkIAh+teTLjPvSDYemTxtYXEK5FauEhQbbSmodDkbZFV59w+6+ZN4YPXDj1iGUPQ/ULoprwiD8n0Zjg9ECWDwWSAwGBoigUSiX25izcACbUJxlfn2Lz/jyGpZJO6YzTkph2DtsXWB7Un/UmOh+4A3/vOtSpZyHJCqn5V1Ba9UA1ZX0wdvt2On79d+C7yKl6Gi7/CKmFV1UnBAbGBaVXH6Pw7H8z9oN3oGZbgdA2zencQfHle9HGTCe94EoKZVBTIKsSaVWmYgoKToBCQCIlYzsqY+piICsIOWBGS5pELBSqM2yPsuWgKypx/eiBtSJLNKZ0XC9AUSRiaiSwNRK8AGQZVFnCDyRa0nEc36cxqaMpCqbj4omA+nQMK/CI2T6uLhNIMgd7DDJxjYrtk0hogIvrByQ0MB0f23fpzgeYDhR2r6F3+f8jOfsishe8q3p93yzRfc83kfUELTfdOry/376Svid+hl/oBEmm9R23kZhR69goPBdj8zMkZl2AHAv7Y6dnH/aBjdRf8RFKgYxkgRmH+pRKxQpIyAqTWutozcRQJR9J0knqOlYQKvgjwPNdEok0Eacu0XRcxIj58g3zWP6lK7nvc5dUPcdb0vqxDxxC5qwb8PoOYPevgMtajNTcSzG2vVAzew3hw6u8bhnJ+ZdTf8WHsdu20v6ff42x9YWa/SRFIz7lTCrrl1Fa/UB1u5pupOWmW/FKPfTc/x2Ef/QZ6yMxri4+bNvkxiQ/emrHEVPOB9S+B6y2ohrxwzOQhj7wPV4wvQl1kIiNIkucObGOT182nS9dN4el88cc6VQjpvDC7zC2vUDDFR8mMW1xdbvdvp3eh/+d2KSFw7QFAtei/b/+F6VVD5A+41rqr/gwbs8+rAMbh52/vOFx1PpxxCbMD491TCobnyQ195LjngwaYFw2Rm/FYdkRykXOmdLAXZ+5iKaUhkToLR6lo0dERLxWyLKMqoZ1qrbtsq+3QuB5JHWJmKbgux5CgnRMJ6HIZBI6i6fXo8gSCQXGLbwYJZElt+ah6jnTC68OLU83PT3seuX1jyMpKs1v+zJa/Th6H/oePX/612EZSvGJ8wlcm+57vlXzXsNVnyA2aSF9j/wf7PZtOIAdgITEwaLBzt4CaV2nbHvYATRnZBrTOhPrkiCH/s6u7+P5Al2RAUFCk7Acn7IVBnpHQlNkkjE1CsKPghAC1w+qYqxxTcYPBGXHRZZAUWQycZ1MQieuK2iKQkLVaEhpCB8UZBRNoaO3Qs6yMN2Atp4Sa/bl6C7a9FiCXgP6ihZBIFMoQ0dbG933fwetaVKNfakIfHru/y5eqae/TK2+pq2FlXfRfc83kPUEzW/5Ekq6idKrjwz7TMaOlQRWKbyv+ymvfRgUldZF15CQIClD2XaIyYJMQiPQFQzXQ5ZkQEfGx+u/7xRJQpEUIFJMP9WJAvGI4+J9SyZTMl26y86oj03OvQQ5ng5XFftJL7oa4drD6rl9qwSBR3ziAuqWvJPxH/0hWtNkuu/7NoWVd9d4J9Zd+n4Ssy4g98TPMfesrW6PTZhL03V/hbV3XbVubLTsOkxq7zPbe/jXR49e/33OlAY+e+XM0yIIHxpQj/SYoXX050xp4F3nTqIaigvB0gVj+fIN82hI6of9W4yGytbnKTz3a1ILriRz3s3V7X45F856J+vIXvCu8B574XcYO1YCIGtx0mcsZcx7v0XTdZ8jOSu0JBuwzRnAzXdg71tPatHV1c68smk5wjHJnH3DCbUdQkuyZ7f38JNndvEvj27lvT97cdh3fs6UBlZ9dSm7vxPVhEdERPx58H0fw3ZJxFScwKVSMYgpIGsa4xtSzJ9Uz8SmNNNb0kxsrCMRjzFnbIrGuiTpRddgbF+BVwpLuPTWaehjZtSUsVWvU+lDzbaQmnsJY97/Xeov+yDG5mfo/M1X8I1D2Xpa00Ra3volnM5d9D70g+p4QVJUWm76CnKqge57vkml1EO7Cd25gGLFoVR06C1VKJs+rheQszxKpoMjAnRJZmePQa5iU7BcLD8gG4+RiuuUHA8/EFRsj+C1TOk6RfD8ALffY340GI5Pxfao2B6266MqMvVxjYyuh0r04pA4YNFw6C3bbO3Ic6C3jK5K9BYtCmWTzopNNqZjOj4F2yWuyKR0jYQChglFG9r7PHorZdruuh1khZZ3fBVZT1TbknvyF1h719K09LMgyXTdfXvNuCA58wKyF7yTcR/6Hqn5l6M1TRw2JgCorH8cJd1EvD/rM3AsKuufoGHOxSTTdSgqJJMS2bjKuIYk509p5IKpjUzJJmhMaqRjATFNRQCyLJHUY2iqQkyPRNpOdaJAPOK4WLU3d9zp6bIWI7XoGoztL+KV+wDQx89FbZxIZf2yIfuGK9GBE66Uq9kWxr7vn0jOvZT88v9Hfvn/O9S5SjLNN34BrWkiPfd/p6amPL3oarLnv53ymgdrVsxHypH6EUFU/w3HL0x3pDr6BePrUGQJWaKa1n/nyn3ceu96dnSVj7uddscOeh/4N/Txc2i6/q8OzXp7Dl33fpPALpM97+3kHv8pyVkXIHwP++CWcB8RULfk7cQnnwGEq9xATW05hB0uSNWZbyEEpTUPobVOQx8/97jbfiQcX3D36pPjpx4RERFxPAyuFW9K6uTzFmXLQwGaElCxXUzbY/GUBmaNySIJaE1p9JZt9uc9UmddDyKg/Oqj1XOmFl2D07kTp7NWK0TW4tXnryTJ1F34blre/g+43fvo+PXf4RW7q/smZpxH/RUfxtjyLIUXflvdriTraH3HV0OtkHu+SeBaGD5UXMjbUDJdLF9QcgMKJZN8yaA3b1O0fHIVm40HCrT3VvB9geUFWK5HQIDcn811uofhrh9gOD6WG/43eNHkaAyshuuqjCJLocCdH2C4AbYXeoYD+EFAyXQpWx7dZZMdHSU2tpfY0VFm9cEcz+3I0VNyyBk2KgGNqTiN6Rhl2yFnQAB4HhQ9n+77v4uXb6flpq/UlKmVXn2U0qo/kTn3bcipenoeuGPYuEBtHE/D5R9GUsJ09cA2kLVEzWfySj2Yu1eH5RP97iuVTcsJHIP6s2/EEyD5gKIhJJm2fJE9vRViskxLfZLWuiTZdJpsIo6myGRSKTRNJqarxOPDszUjTi2iQDxi1Hzwlyu55acvjmjfoQrOA2TOehMEfrXTlSSJ9BlLsQ9uxu3ZX91PjiWR42m8QQIYkqrT/NYvkT77Roov3UPu8Z9URV7kWJKWd9zWP3P5DXzrUNBWf/mHSMw8n77Hf4a5a9WoP/fRGKj/Pp5V4TcCxytMN1BHrwwKuFftzXH7AxvxA4EsSdz25gUAfOfhzSfURq/YQ/fdtyMns7S+/X9XRdSEEPQ+/O84bVtpvvGLCNek7uL3kjnrTcQmLcAv9eCVegj6lf4HBhQD96Ra11q9hgh8yuufID7tbNRsCwD2wS24XbvJnH3jcYm0jeSISC81IiLi9UA2FcMOAqaNr2fuxBbcQCYd07lkzjgWT2ti8aR6mjJxuopldvSUac97OIDWMJ74tMWhaFt/CVlq/hWgaJTXPVZzDbVuDH6pt0aoLTnrAlpv+QZ+OUfHr/++RjQre/7bSS24ksJzv67JutNbptL8li/hdOyk94F/oywCzAoICZAC4rJEvmSwudNg7f4+1rd1U7EdioZD0XHZmyuzv6eMYblYtsee7jLr9vVRMmwUORSzG20g+kbB8wMUWUJTZAIhGOnHl6TwGNcL/dZVWcL1AiQpFGZDCtXne8sOO7tKvLSri4fW7mbdwRx7Owrsz1VIKjJpXSYbU1AVlYkNKeriOj4KqbhEvQbxOFiBoOvxn4XCrEs/S3zyomo7rH3r6XvsP4hPW0zDlR/F7do9snFBoaNmTABQ2fAkiID0GddW9y2teQitZSr+hHmogBVAoeSAgIZkirSmsqu7RHM6hq6qxBW533teIaZppFIpEolEpJb+BiAKxCNGxQd/uZJntvfgjTDVyPIOXyulNU4IO921D1c73fTCq0BWKK17tHbflqk4nbtrtkmSTOO1nyZ7/tsprX6Qvsd+XH0QavVjabn5Vrx8Bz33f7d6fklWaH7Ll9BaptB9/3dwumrPebz4ApZt7Dit7coOF1CPhKF19ADff3xb1S9cCMGGtgLv/umLFI+hSHs0Atug6+5/JHBMWt/5tRo7suKLv6eyaTn1l36A5JyLkLQ4fY/9mMrmZ+l96PsEVpn80/9F3+M/I7CNasfndO0GSUZtnFg9l7V7DX6pm/QZh0TeSmseQNKTpOZfPqK2Du1WP3XZdFRleGcrSeG+uirz9sUTh70fERER8ediwMpMURRasikCodJbNlFkQYBEUleJKQpbOyq8uq+HbR0GIDAGDREyi2/EL/dhbF8BgJLIkJx9EZWNTxG4dnU/rXkKiACnZ29NG+IT5zPmvd9GOAadd36lGoxLkkTT9X9FbMI8eh/8Hnb79uoxyZnn03DVxzG2vUB++X+SB8oG7M+5eMKlt2yhSy7ZRBzHl9CQMGwP1/JQkegrG1iuS8XxMS0B+HSULAw7XL0tWS4ly6VsH+q/hAhXed/I6euaIhMEAscLUCSpmikwEpK6QjKmkI6pCASm62O7Pq4foMphvXhH3sByXNbvzbE/Z2J7PrawQfjE9XBcJqsqshSw6WCeFXt62NVVoNcQHDCg14LuV+6nvOZBsue/vUaY1c210X3vt1Hrx9Hy1r9DkpURjQu8ch+BUQjvz36ECCive4zY5EVoDeOBgcn5XWQWh5PzOQEGgAR9JQfDDUjENZKaRkxV0FSZTDJGJqGTiCmo/QtcURD+xiAKxCNGxUt7+mpeq7J0XIJtcJhON1VPcuYSKhuerJnpjo2bjdO1E+HV1qNLkkT9FR8hu+SdlNc+3L8yHnZs8UkLabrus1h71tA3aLusJ2h9x9eQ9SRdd/0jXqnnuNo+lF+v3HvEVeHTYZX8RITpBuroAd7/ixU8v6MHQRhkSpJET8nGP4EBiwh8uv/4XdzuvbS87cvoLVOr71U2P0v+2f8mPuVMshe+G4DsuW+l4cqPYB/cTOacN9P6zq9Rf+lfIOuJal0YgNO2Gb11GrIWq24rrXsUOVlHctYSIKwfN7Y8T3rR1TV1Z0dtLzC1KVm1D8wktMMO2N525nj+9ro5/OYTR/6+T4d7LyIi4vXBQDCu6zqTm5KMa0gytTlDfTJBEASsP1Dg5V19rN3VxfYug74yDDYfS0w/FyXbWlM+lj5zKYFdwdj6fHVbbPxsAOyDw7OkYmNnMuY930K4Jp2/+QpeoStsm6qHNpXJOrrvvr0mfT1z7lurGXb5NQ8RKOAKsPp9vgUyBcuhUHHZ21fG9X0cP6xfLjqCrZ1FchWTgmnRlrfY3VWmZLk4vo+uyuiqjNO/KCGEoGR55AyHnOFUBcneaKiKTCqmkoqpxI9h9TqUgVVxAVhOgKaE95UEaIqELEOf6dBXsei1PTQBhmVi2AENMZlCxcYNXLpLBgd7SmzrLrCto8j+HpNcCXxCZ5Pck78kOfsi6q/4cPXavlWm667bw1K0C9+NHA8VyUcyLnD609UH7k8Aa+86vHxH7eT86oHJ+StqPnfgQ30MEgok4yr1CRXb9bAcFwlQVbnGqSDijUH0F40YFedPbax5fdGMJn7ygXOPsPfRSUw/F6VuzJBO9zoCs4ix/VDqe2zSAvA97LYtw84hSRL1l3+I7Hk3U1r9ILknf1ENutNnXNsfpD9C6eV7q8eo2WZa3/U1ArtC1x++TmCfuMdyyfYpme5h06xPl1XykQrTHSk4HDyREQbhEAjB8q1dx90mIQR9j/4Ia9cqGpf+ZY0vqH1wMz0P/CtIEnIsVa35grB0QtLj1bovtS5Uah8IpoXnYh/cSmzi/OoxXrkPc/tK0guvrh5XWvsIBB6ZxW8eVbv39Brs7TX4xXO7q/fVUO5b28YPHt/GDx7fxnnfXMbULz/Ied88pLFwOt17ERERry8SsRhJLYas6sRiGpIQuMJHyIK86aMQTjqmgMb+x5skK2TOvgF7/wac7j0AxCefgdowrqZ2XMm2omRaqs4rQ9HHzKD1lm8ibIPO396KVwwn3JVUPa3vvI3Atem66x+rLi2SJNF4zSdJzDiPvmU/oWPLSjoMOJgDfBB+QF/RxvVsDuYr7M+bFEyPQsVBBC6u7dFTcugzbCq2x7hMHMPxSKgqthtguwEJLQxG/UBgez6x/te2+8YMxCEUFRvwTj8afiAwHA/T8WpS+INAYLoeJdsjbzoYjoMfhGnucVWhrWgTV3zsQFA2XBzTYn1bkaIFri/RVbDZ0WNg2h6OD10mmIB1IOz79fGzaXrzF5Gk8AYUvkv3vd/Cy7WB72LueAnrwKHJnmONC6z965HUGPrYmdVjyq8+ihxPk5oTeof75RzG1udJL7oGbejkvASXzxvHgrENLJhUz/SxdQSShBDgeD5BIPD9AP8NOnlzuhIF4hGj4lcfW8Jls5qJazKXzWrmVx9bwk+f3lmzz8zWkXka1nS6/Wni8Wlnh8H52kP2D/FJi0CSMXevOfx5JIn6Kz9K5py3UHrlfvLP/k/1vfrLP0hyziXknvp/VAbNqOut02m56Su4vfvpvvdbNSvwx8vG9uKwVeHjrZ1+o3K04LAhqVdF8QShQF4gwPWPfzW88PydlNc9RvbCW8icdX11u5trp/Ou25H1JPWXf5j4tMUYm5fjFQ8F/cmZSzA2P4ux9QXK6x7D6diOkmkGwDqwEeHZVQVUgMq6ZWEd2JnXhZ/B9yivfZj41LPRGieMuu2CsBbuF8/t5voFYw+7j+MLntneU3Uv6C471WA8uvciIiL+/8D3fSRJENdlknq4IhyLa9QndCqmTd5yWDgxTXMyFM0SCgzkFaXPXIqk6tUJekmSSJ95HfaBjTg9+6rb4lPPxNq3DhH4h21DbOxMWt99O75RoPN3/7uqZK23TKXlpi/j9uyje5ClqSQrNL/179HHzqDnj/+MOLgZ14feMuzqhe4S7M257O0t4fmCkuOxqavAgT6XzooNiszEhjSNyQSSIqFIMqm4SkNKpyGlk4yF6tayJCEh4foBvhAop/koXIgw2JYIA/KBzAHPD/CDAD+AXNlmd2eZHZ0Vuks2jueTiiuMz8aZ3JxBlyWUwKOz4rKzy6SrUKKzp8S+XoeuIhzIQalfoM3t2U/33bejZJpofcdt1Yy2Aa0Ye996ErMvouGqjxOfetaoxgXm7jXEJi2oBup+JY+x7UVSC6+uatKU1j4MgUd28Y0MuJHLwNgYTGpUGd+cJJ6M0ZSKk45p6IqMIhOKAQbBIYHZ00xz4I3Maf4IiDgefvWxJWz5xpv41ceWcOfKfTy2qbP6nqZIfPTiaWiHqWk9HOkzliKpMUqr/gSEtd+ZM6/D3rcOtzcUbZNjSWIT52PueuWI55EkiYarP0n6jKUUX/wdhRd/Xz1f043/i9j4OfQ+cEdNKlti2mKa3vQ3WHvX0fPAHUfs0EfKgnHZYavCRxIjO13ThQcHh7Yb8N2HN/OBX67kOw9t5uEN7Yc/6DjLoEprHqLw/G9ILbqG+kv/orrdK/XS9YevIQGt7/022fNvJjZ+NpKepLT20WoJhN46ncw5b6ay+WmM7StovukrxPpnus1dr4CiVhXUReBTevVR4lPOqAbdxrYX8Mt9ZM55y/F9gH6CQLB2f37E+w8E5cdbtx8RERFxIvi+P8jNREKVJQhgXF2Cq+aP44o5Y1kwoYG6pMLYTJhyLAEaoCSyJOddTmXjU1Wx1fSia0FRQ9/lfhLTziGwythtW4/Yjtj4ObS+6+v4pW46f/dVfLPYf+xiGq/7LNbu1fQ99h+DStfitL7jayiZRg7cdTu9Pfvo8sP0eScAowKFiiBXtpBFQCAkipUKXQWLvR15hAhw8XEFyIqgZDr96dSHhtqyLFGX1NBlmWxMRVeVqlXX6RxceX5AwXToKBp05I1Qtd4NcD2fXNlGViCmShQMG8cLEEGA5QYECGzXpc3w6egGx4WcEdBTBj8AVQaH8D+32EPn728DRaH13d9ASdZVr59/5ldUNj5F3SXvp/nGz5M+6/pRjQvcfAde3wES0w5l3ZXXLwsz4s4MFwGE74aT89POIdU4gboE1GswPgVTWhJkUynKls+M1jTN6TgNaR1NURFCJqZIuKeh6N/pQBSIR5wQ//f5WsGz1myc9y2ZzPxx2REdryQypBZcQWXT09VOMn3GtSCrNaviiRnn43btrtZ7HQ5Jkmi87rOk5l9B/plfUXzlj0Bol9byjq+iZJrouvsbuL2HrJ7SC6+i4cqPYmx9jr7Hf3pCD7nfrzrANXcs586V+6rbzpnSwG1vXsBFM5ur6t+nc7rwBdObUPsHJQJ4aU+u6on97PbD1+s3JLTDbj8alc3P0vfYj0nMOK/Gpqyy5VnafvEZvHwHmfPfTqx1GpIko7dOJzH1bIRdCWesASSJ1PwraLrxC7S+47aqrYkQAnPHSuKTz6impJm7VuEXu0ifdcgnvLTqT6j140jMOL7SjQFkmSOuiB+OAc2GE6nbj4iIiDgefN/H8zxM08RxnGrdeExXaEjHmT+hkasXjOPiWc3MaGkgm0wgSZDWoUkNB6XZc96CcG3Kr4Zq6UqyjuTsiylveJLAsQBITF8MsoLZrzFzJOITF9Dy9q/i9h2k6/dfq6ajZ868juyFt1Be91itrVmqntZ3fwNJUen63VfxCl1YQJEwINdk6DNK7OmukK847Out0Fk0MEXAwb4yDckY01pSdBVt1uzLsX5/jr6S1Z9SLQgCQdF0KVkelhfQXbZpL5i0Fyx6y/ZhP8MbGUmSiCkyRdvFsD3yFZf9fQZ5y8H3fXrKJr4Ay3EpmR6+L/C8MCD1AoHsA7JM4IdK6Mk4tGYUGrPQkIZMAlRAmEW6f38bgV1mzLv+scamrHfZjymu+APamBloLVORtfioxwXm9rBOPDHzfKB/cn7tI6FIW/MkANytz+NXcrSc8xbiOhTNUFQuHoOxdXEWTsiSjoU+4glNIa5pJHSVmCoDEkEQ4DhujVVgxKlPFIhHnBDtebPmdb4Szhrect7kEZ8jc85bEN6gTjfVQHLORVTWP07ghp1ucnaoqG1se+Go55JkhaYb/xeJ2ReSe+JnlAZ15K3vuh0kic4/fK3qXw6htUl2yTspr3mI/DO/GnG7h9JXcdjRXeHWe9fzyV+9wqq9uaoV1/M7erj9gY3cvfrAaZ0ufM6UBt55zsRRLXL3GaMrGzB3raLngTuITZxH89v+vurbae5ZS+/D/45wjP4yhvsw96ytHqePn0N8yhn4xR76Hv8pXXffjm+WasTYANyevXi5dpKzL6xuK615ECXdSHJWeJ/a7dtDUZfFb67Wnx0vXgCTm1JMrD+8X+hls5qrwXdLWufl/31t9b2R1u1HREREnAwGVsM1TauZ2NYUiZii4AWCpnQSSVKYNT6DLwXIEiRjkEiEq+L6mOnEJi2ktPqBaqZaZvENCLtCZfPTAMixFPHJZ2Bse/GYE+iJqWfRctNXcLp20XXX16vjivpL/4LUwqsoPPfr6lgBQueV1nffTuBaNWntFaBgQLEMLj7CD1BUJRRkkxXSiRgJTaJkeHQWbNIxBUmW2NVdoaNg0FdxKNsulusT1xXKtkeu4pCOqaRiChXnxLLyTkUGgsqkpoYibZpMNqHTU7bZ1lWip+hiuQ6GHeB4HrIc0FWy6C27JOMahu+T1HUaM3EkGSwHLMdHlqFQhnwFHNtg/x++hpNvp/Udt6GPmVG9fv7F31Fe/SBa6wyy591E36M/PK5xgbHtBbSWqWgN4wCwdq/GL3SS6Z+c14HCqj8Sa5pA/ZzFZLRw8qk1AQlZxhUCWVJIxePYno+MhO/7BEFoA6cqMrquoShKtCr+BiMKxCNOiKFizgOv54zNMG9shnRMOaaqut4ylfiUM8JOt79eK3P2DQR2hcqmsNPVGsajtU7DGOQBeiQkWaHlLX9HfNpi+h75P4POMY7Wd36dwCjQ9Yev1Yi01V/+IdJnXk9xxR8orPjDSD/+EXlsUyfv/8UK7ll9ANsNA2/HDap2U6/HdOHXKmV+8HlX7c3RXbJHZWUyGqx96+m+91voLVP667/C4FUIQWnVnxCOSeO1n6bx6k/QePUnyT31S7xCWFohazGSsy/C6d6DsfUFMovfjJLIDLuGseU5kORq0O3m2rF2rSZ9xnVISlgHWFr1RyQ9UfUNPVEe3tDO/PF1w7bf/ZmL+NXHlvDy/76WPd+5sSYIj4iIiPhzo+s6kiQhyzKKckgtOxAQ1yRSuoIqSVi+zxkT63j32eOZPTZGTFPJO2EdL0D2nLfiF7uqK42xCfPRWqZSWv1gNRBJzrkYL9+O27XrmO1Kzjyf5jf/LfbBLXTf8y2E5/Tbmv11OFZ49Ic1rhh66zRa3/l1/HJvf1p7CQGUfWgvQcUO6Cy5BMLHNG36TBtdkVElmZLp4HgenQUDw/Lw8YmrCo4XYLs+nu8T9H+GbEKl3J+anomPPvvrtcT1A4qmS8X2apw7TtR+zfV8TMfHdsPPXajYlE0HTZbw3ADP8xBeQDqmU5fS6Cu7FCyHkuPRVrDZ3J5jR2eRQskIV5DrkxRMi4IBbTZsycO2XsjqoAuL9rtvx+nYSctNX671Cj+wmcLzv0WOZxj7/u+QXnDlcY0LvFIv9oFNJGdfVN1WWv1guKjUP2FvHdyC0baNiRe+lbnj0kxrSTOpJcH45jSTWrIsaM0wvSXBpKYknh/gCYGiKmFdux8QBAGu6wEiUk5/gxH9NSNOiOsWjBn2etXeHO/52Yts7ihRtv1qzerRyJzzNvxST3XFOzZxwbBONzX3Uuy2LUdNTx9AUjVabr6V2KQF9DxwB8a2UIU9Nm5WKNLWs4+uu79R9SaVJInGpZ8hOf9y8k//F8VX7h/V93A4bDfgiU2dDHRVAZCJqUdMF/7/s3b8aCJqJ9Kuwed9789XcMtPX2DZpk78QCABZ02s42RlWNkHN9N19+2odWNoffftVdsRgPyz/425YyVayxSScy8FIDX/chLTFtPzpzuq+xVfvh+v2MXYD32PZH+K2WCEEFQ2P0N88sKqF3l5zUMgSaT7xeC8ch+Vzc+SXnQNcix5Uj7bgnFZWjKxqrCPLMG3b14UrXRHRES87tB1HVmWicfDidAgEFiOS850yZs2u7uL5MoO+bKLpuvMHFPP5OYUaS2s6QVIzFqCkm2t9sWSJJFZfCNu166q1kty9oUgK1Q2PzOidqXmXUrTm/4aa88auu//LsL3kBSVlpu+Eoq03f9drP0bqvvHJ84blNZ+G75dwSL0fD7YG1C2QUEgJA3L8PFEQMVxCGSZ5rROV8nBsFzq4zo5w6U9V2LN3j6e297Ds1s70Qjr5ifUJ5nYkKQhWRuIm45P3nCqNmonihChOrnheCPb1/ZRZPCCcAJhYHuu4tBRtOgqWaOyXwvvAy/0FpclHE9QMF1291bY3VumaLo0Z3TG1CVIxRW6ChYFw0ZVJVpSOnFFoaNQYdPeHLt7yjy1uY097X0czBkEflg+MIAPHDAcNv/+W5j7N9L85i+SnLmk+r7TvYfuu/8RJZklPuXMag348YwLjK3PAYLUvHBs4ebaMXetIn3m9dXJ+d5X7keKpcguvAqQmTM2y5IpjYypj3PmhHoWTGqiPqnj+wGuL/D7hdlEEFStXCUp/L9yuiv8vcGI/poRJ8T333M2N501nvqkxk1njef77zmbFbt6R610nZhxLmrDuCGd7pv7O91NACTnXQZAZdPyEZ1T1uJhGtLYWXT/8buYu1aF15p+Ds03fgF7/0Z6/vjdWtXUG79AcvZF5J74OaXVD47qMwxFAB2l2pqvnz27i7tXH+CC6U3DgvDXonZ8pEH00TzQB9p1y09frKl/P9b571y5jy/+fm01I8D1ArxBfbYA1h4ocDKyrOy2rXT+/mthfd8t36wRYSmsvIfii78nfeZ1qPXjqsKAAA1XfAQUteplnzn3LUz4xE9R043DrgHgdGzHy7WRnHc5AIFjUV73GMk5F6NmwuyG0uoHIfDJnPPWE/9gwLyxGf7zxT385qV9KLLM+5ZM5g+fvoj3LRl5+UdERETEnwtVVVFVtfpaAJ4PkhBYlovjB8wZX4eqymhKgKKENl8OMnEV0groskLz+W/BPrARu2MHAKn5VyLFUlVFdSVZR2Lq2VQ2PYMQIwsI04uuofHaz2DuWFkVaZX1BK3v/DpKtpWuu27H6TzkBHMorX03nb+/rVpjXiEUmMvbULItiqbDwR6D1XvzrNrVxkt7enBFgK6rdJZMntrcznM7+1i1u5ukrpAzHA4WLHJli7Idll8Nrvs1HZ9tnSX29hrs6CxjeycejHcWLbZ1ltjeWaazaB1zf8Hwztn1BYbrIwJBR9Gko2gOsxxzvAB3SIDuB4L2gkFb3qS9YOK6PhJQNBx8XxAIwY7uMru7SrTlyxhuQENSJfA8pECis2RjOD74AksIOosl2gsm2/ocdnaVcLxD2RQQWox23/dPWHvW0PSmvyY1//JDnyHfQefvvoqk6ox533cRgXdC44LKpuVordPRmsJa8NKaB0GWD03OF7sxtj5P5oylWHKCkmWTiCksmtTENfPGM3tcPamYDEKhq2ijyjJ+AHa/mvyABZyiKGFwHqWmv6GIAvGIE+b77zmbtbct5fvvORsIbahGiyQrZM55K07b1upsd2r+FcixFKVXwgekVj+W2IT5VDYuH/GDSI4laX33P6I1Tab73m9h7V3Xf+7Labz205g7XqL3oe9XO/HQwuRLJGYuoW/Zj084GB9KIOA3K/cNC7aPZjV158p9fOCXK4cFwcdiNMH9UIXthqTOj57awT2Datq9QHDb/Ruq5zna+e9cuY9b713Pnl4DQb8i7mG8sE8GYRB+G0oiw5j3fBs100R53WOU1j5M4aX7yC//vyTnXEzj0r+k8ZpPYexYSWXTcoJ+RV69eRJyIgzcj1XPXdn4FCha1RO0smk5gV0hs/hGAALXorz2YRKzllRrxUZKY0rncGYDuipX/wa+HzChPhGthEdERJwyyBKosqBoeuzqLdNVsbBth4aUzphsign1Kc6e2EBDQkaSw34yrcOYs5ei6AkqL98XnkePk150DcbW5/FKYR+ZWnAlfqkbe9BK9rHILL6R+is+irHlWXof/gFCBCjJOsbccjtyLEXn72+rEXVNzjyflrf9PU7HDroGBeMKEHggAjiYK/PCjjZe3NrO5n1F1u0vcLDXZFtnnue39rB6bx/d+Qrbuk1M1yOmKhi2S0/ZxvYCOgpmdXXZcjxyFQsJqEtoBEJgD5rF7ixarNjdzY7OEkEw8hXprpJNXVwjHVPpOYYwnCRJJPWwpl+RZFRFxnJ8/CAMsrvLNggJw/ax+r3QhRCU7TCVvWx51VV0ANv1sV1BOq4RSFBxPXRNpiGlYTgua/fm2dmR54mNB3l+ezf7eyp0VSz2FVz6DJdC2aS3VKG7XGHdgV5e3mnSmROUDQvbD7C9QwYrYRD+bcydL9N43eeqJWLldY9RWHEXnXd+GeE5tL77drT6sSc0LnB7D+C0bye94AoAAsekvG4ZydkXHZqc7w/yM+e8BeHDxPoEmqpScl0KhkvF9GkvO7Rm/z/2zjrOrvrO38+x6zpucXd3x4MFlwAthbqwNdoubWkLddl6u6UUKBI0JCFAgBB3d/fJZHzmuh39/XEmN0wEadnfdul9Xq/stlfOOffc6f1+P/Z+K9QUuRElEUWRcTsUvB4nsiximlbHSF9BpO2jRiEQL/Chs7s+9g/9VNitvF7im+yquOhw4Rt2OemDa9HjLQB4B1+E1laL2njofR9Xcvkov/Vh5FAFzXN/QLZuD9CxGE/9GKm9y2l/84yFiSAplF73rXwwflp9/cPConOwPWdDLW/taUQShXNmx08HtKsOtfLAvF0fKBj/ID7S71TYfvDqQTz06h5+9dYBXtx8slOW3jSt/HHe7fhnW5GFvQ5uGlXD1D4l7/v63w+5U/toev679nc8+ydI/hJa5v+E1N4VpPatJLrsUeSiLhRf9RUEUUIOlFJ08adIH1hLZMXfaXvzj2SObT3vLPjZWIZGau8KPL3HIbp8HXPnr+Ao74WzxlbET+1ZhpmJExg96wN/ltHdwjx83ZBOP8qKJDChZzGiICDyr6crUKBAgQLvhWWBU5HBsuhe5KVHcYDWpEGR301NkZduJU7a01lCbhdepWOe3CEiuLyUj7qMxP5V6AnbVcM/6howTRLbXgfsFnbB4Sa5a8kHuqbguBsITr6D1O6ltL/xByzLRA6UUX7bDwGBpue+jRZtzL/e03cCpdd+k1zjIZqe/y5kk2iAJUAmC5G0RUPMoCWjoiMRcikkszlqW1O0pNKkVIOWlIrPKZFVTSqDLoq8Ck5FxtGRpDYsi7r2JCsOtrD1ZJTGeIZ4RkMSBVyyPW8fS6usPtRCJKGxpTZCXST9vj+z3yWTyNpBst8lv+frHbJE0O3A45DI6QaGZaEZFiG3gkMWKfY6cDskzI5ChmnZlmEOWUQU7OT9aSTRtqnL6jp0jKaZFraFmwhZVaUtnWN3fZw9jTH2noqwu7YdwdJJZlVORFT2nEqyan8Tte32eEDMhNoIuCRwu0RCQJGeo3neD+0g/LLP4x9+BZZl0TL/JyR3LSG2/kWMVATf0EvzFex/Zl+Q3L3E1owZOB2A1O6lWLkUgdF2R5yZS5PY8aYdmAfLyAEp3SLgdKBqBvGsTtzQEAyRjG5hWWAaJlhGfh8tiiJOp4wggCQJBcX0jxiFQLzAh8qcDbXM2VB7noam90Z0uPENn0n64Nr8AugfeTVAvjLt7T8ZQXaQ3PX2Bzq2ne3+EZK/lOYXv5+vugcn3EJg/M0kd7xBZMlfzw3GO9TXYxte+gc+0RmKPAqXDSzHIXUOtn/6+j4emLeLHXUxNMPi4gHlnWbHzw5oL+i1fR4+qI/0aYXtSFo9U4E1LS7uX4bcsYgKAuw4GWXLici7Hn/m4M7V4PaUyrMbai9oUfaPkK3dZVfCvUHKZ/8UOVCGlbMbBv0jriJ3cg9ySVdcXQeTesffi6vrUELTPo6n9zgkd4CKu/4LpbjmPc+XPrwRMxPHN+Ri+/wndqC11uIfdU1Hu5hJfNMCHBW9cXYZ/IE/z5L9zSw/0Nwp4T2iS4gn1h3HtOxs+INXDypUwwsUKPB/hlxOpSWaojGaBnR0QUAHaoo8tjWlJdCtOEB72qQllialQtgn43fLiKLIoEtuBcsitcVuR1dCFbh7jyW5fRGWriIqLrz9p5A+sCZfqX6/BCfelrcwa19s25cqRdWU3/owlp6j6blv54sAAJ5+E+029aYjNDz3bTLpGJE01GXAMKDYLyNaEEmnaIqmiOcMRMvEJYuomkF7KodbkekS9jCgMkDPsiCiaFeVPU4Zy7JYcaiVRFbDLclYFnQr8dKn3J8P1nO6iW6Y+F0KDkkgmXv/LetlfhelPic1RfZc+vvF6vgniQKiAG5FojLgIqebaKZBTjXIqjoCFookouoGpkX+mi3LQpZEyoJOch3e4G3JHKfaUmw7GaElpZJRTeojGVI5ncbWNPtPxTAtaI7nON4SJ6PliOVyNLXZ13K67zIg2WMPTVGTpJrl8EsPkT26laIrvoR/hK1YbuVSWIaOqabB0AlOmg2m0cmX/h/ZF1imQWr3Utw9RiL7iuw9wJaFOCr64KjqD0By19t2YD72etxAsRckyw60i7wuwl4FCXA7BEQMSnwuygNOnLKCaph5Vfl3/ivw0aIQiBf40NhyIsID83a962veSyzbP/JqEAQSHbPicrAMT5/xJHe8galmEV0+PH0nktq7Im9BcjZqay3pQ+s7WZQBSL4w5bf9CMkboumF75GrPwBAaOrH8I+eRWLLK0SXP945GL/2m3gGTCW6/AkiK5/6h2dz2tMaHofEs5+ekBdqA3tm/J1kNKNToDXoLD/2QZWB9zXzveVEhPVH23jw6kEf2Ef67AD7M9N68cnJPQDb8/KtvU3c/ld7dup8wnNbTkSIpFUGVHTOJp9ezD8MMkc20fzi95D9JZTP/hlyoBQA0eUj13iQlgU/xVnZl/Lbfoyr+3DU1lqydbbWgJlLIwfLcPcaTWjqXUhuP6aaIXN0C9naXRf8jlM7FyP5inH1GAlAYtN8RE8Ib8e8eObIJvT2OgJjrv+HFkvDtFiyrwnpHf8j2XIikk+KWJZFJP3ewocFChQo8K9CpqM92aUoOGWJgEPE7xSpCLlwKhIlPieaCX6HPQObzIGq6vaIlBciUohAvwnEty/KB9qB0bMwM3GSe5YD4Bt6KZaWJbV/1XmvwcylSR/aQK7+QKdZckEQCE25k8DYG0hue43IkkewLMtWTL/lYcxMgqZnH8hX4wE8fcZTdsN30NpOcmrOfxJPtCICrRqoOR2nCJalIWJS7XPQnNZI5wwciojXISGJFnsaY6w81EIknaM65KE67KbU5ySS0vAoCrppURdN43cpBDuqz6fRTJOMZrCtLgoIhD0OattTtL9Hq3lTPMOx1hTtaRVZEjqtM++FJArIgoBmmB0zyyJZ3cQwTQwDUppBIqeTUU28TgmfSyHgVlAkEd0waU+rRFIqyYyGhYVugWpCLKsSz+kE3A4QLTySiMepIGOiKAJpTcWwBEJ+J1gCmmoiKvZIgK0fDg4HtKUhk01y8oXvkqjdRfFVX8E/7PL89VtA9vg2tJZaSm/4NoEx1+HsMvhd9wWiy4fadJT0wXUYmfh57kqHRVmyDW9H63v26Bb09jr8o6/tEFozSGxegLN6IMGqfoS8UO5VkCSRU7EcTbEsmmbQFs9iihZ+lxME28bMHuWz8sr0Be/wjy7v3ZtSoMD75L08sWVRyFt2XPA1gRK8A6aS3LmY4OQ7kFw+/GOuI31wLandS/CPvArfsMtI7V1O+sAafIMv7vT+9OENtMz9IWDZ7UL9p1B8+Rfy6tVWLk1g3C3E1z9P0/PfpfzWh3FW9SN80SfB1IlvfBmA0PRP2NlHSabk6q/RrriIr3seM5ug6JLP5L2pPwjzt9eTVg0+M60Xo7qF+eOyw+fYv72zkrzlRIQn1h0H7ATGtcOqeGLdcVTd9l0dVBXk1jFdzxHtOj27reomDln8QEE4nGlTX3+0LV/lfnT1sU5B9OlW9C/M6M2BxgS/efsgMwdX0q/Cnz/3hxZ1n0Vq73JaX/s1jtLulN3yUCdhtvThDRjxVkSXl/Cln0X2hqCyH2r9AYxkO3q8meSOxfhGzET2FWGqWeIb5hLfPB9LzQDgH3MdRRd9stM59XgzmWNbCYy/GUGU0NpOkjm6meCk2QiyrXQb3/AyUqAUT8f8+D+CZcHAygA762J24sICURTsSkOhLb1AgQL/h7AsCxEBzbSwTA2HouBzKWjJLG1JlWKfA6dDwdB1jramyKgaNQGoKQ7Qu9xPUcDJvK319LnoFrb8aQ3JnYsJjJmFs+sQlLIeJDbPxzf0UhxV/VFKupLc8WanAAzA0lUanvwqers98y0X1VB02edwdxsGgN5Wh6N6IP4xJomOWfTwxZ/GWdmH8lseoumF79L07AOU3/5jZL89WuXuNZqym79P89yHaXjmm1Td8hBiUTV1cXBK4EhamFIW1Yog4KBrkRvTyGEgUR/NgiDikCV2nowwqMrC71YIehykdI1uxW7aUjkyqkGfcnsEyrLstnXTNDnWkmJk1xCqDi5FpDGRRRZgX32MriU++pb6zlHVNk2L1qRKwCVjmBYtiRxB97tr+ZimSVYzEAQRpyzidsqYpoUgQDyjkdV0JAHqI2mqwh68DgnNNBEEGc0wSGRU3A4ZURDt4F0SOBXNIloioiCQyml4FJFYSieSzmJZEPBKqCYksxoWBpGYhiQbGJZJPKsTdMl0KYFjzeADNANUHZLRdhpeeJBcWx0ls76Z13ABMLJJmp//Lpah4+oyGDlYjqg4cb7LviBzbBvRFU/kRfskXxGV9/wByd25MJLY8SaiJ5RXUo9vWoDkK8LbfzIA6QNr0WNNFF10L73LHBQHXDhFAVEESbRoiKToVx3A4ZDte43d2q+aJrJuIEsikiQWKuEfcQqBeIEPjfE9ixG4cPw1sVcxG4+354U9LkRg7A2k9iwjue11ghNuwVk9yyfl0AABAABJREFUAEdlH+KbF+AbMRNnlyHIRdUkty06JxBPbluEHCqn5Oqvkz64FrXpCILDhWVoJLa/QXLnYuRgGSAgOj12MH7LQzir+xO+5LNYlmUH45ZFaMY99g+gKFF0xZcQ3X7iG+ZipmOUXP01BPmDi9K9tbeJpfubmdirmINNCSThjPd6/wo/e+pjzNlQSyStsuJAc/5eCUBb6kzLuGnBjroYO+rsDoR3BuPnm91+r0B8zoZant9US1nAxWc7EgWn32MnDDp/qxa2KN/pGXaAVYdauXRgef7c70QSoSbk4UT7B2sdPJv4pgVElv4VZ9chlN3w3U72YOmD62hZ8DOUsu64e4witXspSrgKOVCCFCgld2of3v6T8Q2/AtlXRPbkbtpe+zV6rAlPP/vx5M63SG5bRHj6JzolW5I73gLLwj/sMvs6Ni8AScm3vuXqD5Cr20P4ok/m7Uo+CKJwJuie0LOYA00JNN1EkUUevHoQkbTaSWn/dMfD2er7BQoUKPCvgiAIeN0ODFNHN8DvlmlMaDglEV0E3bADu9q2DC5ZRBYFqov89Cr10bXYh0OECr+TVHlfnDUDiW+ej3/kVQiSTGD0dbS9/muyx7bi7jkK37DLiSz5K2rTERzlvfLXkD2xA729jqIrvoQgKcQ3vozocJ+zJ9Da6vAOvjgvrBW++NM4q/tTdvNDNL/4IE3P/qctBtrRfeXqOpQut/2Yuhe/x6lnvkHZTd+Hyj5kDWzvLA2k5hwhTw6vAggSJT6F5niWtKpxrCnO8dYkhgVeh0xbUiWSzqIbFkO7hCnxOjjYlGBfQxxZFHA5JBoiaU62Z5El6FHiI+jxouV0jrWlOdiU4GQkQzKrMaZ754StKAq4FYmUamCY1nsK6jbHM2w7ESGR1elT4ac84KY84OwQCzvz3banVJI5naxmkNEMSvwKWVVn/eFWNMNut+9b5keSJRIZlYyqIwgCblmia7GXkNtBfSSDiQtFEGiIZ/A4TTI5BxnDoMgrUx9VMQwDv0Mmms6SzEHYD6oBRgYybac4PudBjHSMspu+h7vHiPw1GpkEzS98F7XlOMVXfRU92kBi62uEptx53n2BqDhpfe2/SO1eihyqoOjyLyB5QrTM+xHpg+vz6z+Anmglc3gjgbHXI0gKavMxsie2E5r6MQRJye8l5XAV/t7jsCzoHnZTH82S0UyOtqdxCiZlKQdBt5OmaJpSj4vuJT78DglRFHEoUiEA/zeg0Jpe4ENjVLcwL31uIt2KPDgkAcdZEtDrj7bRv/y9xS8cZT1w9RhJfMsrWLqKIAgExlyPHqknc3ijbW027Apy9ftRm491eq+RjiEXVduB9Yx7KLv1hwB2W1rtLoov/wJlN3wHd++x+EdejeQJ0vTCd8nW7bW9xC/9HP6RVxPfNK+jTa0jEBYEwtM/QXjGPaQPrKHphQcxMol/6D7ppsXKQ600xnMYFnQt8iBJAvsaEzzTEdj+4s0DbDx+pv1cEOxqueM8yuNnz42fb3b7vWzGTs+pL97bxK1/WdtJGf1UNNNpAQY7MbDsQDMPv7q30+P76mM4ZPEcsb7KgItJ/4RQm2WZtC99lMjSv+LpO5Hym3/QKQhP7V9Ny4Kf4ijvRcVtP8Y35BIQBFoX/gI91kTmyCYkbwiwRxSia56l6dkHQBApn/1TWw+g+3CcFb2x9BzWO8YeLEMnufMtXD1HIgfLMdIxUruX4hs0I3/M+Ia5iE4vvqGX8UFxKSKfntITSRQwTItHVx/j7gnd8y3/s8d15QszencKwv8nrO4KFChQ4MNGkkRCfi+lYR+KomCYBhaQ1mxP8eZEhlhWo1dlgC5lAXTToCLkIZHVeGNPI8msjiUYFI27ESPe0uHZDN6BU5F8RcQ7qtjewRcjyA4S2xZ1Or+RjgHg6jYM3+CLqPzE73FU9D53T9BzFM7qAfjHXEdiy0LaF/8ZyzJx1Qyg/JaHMVIxmuZ8q5OAm1DZh4o7fo4gO2l69j/JHNnU6dwpFQRBQhNMvIqIIoqUhTz4nQ5UCzDhtW2neHtvPdtq28jpBjkTDjfFkWURTTdpimWoj2bQdIvtJ2OU+u2ugrRqUBl0IYsCR1uSlPmddCv2cCqSOe/30KXIQ5FXoTxgzyBHUjla4tlzbMYA9jcmkUQRj1OmPpq1hdo6sutqh3p7KqcTzWqU+93kDBOnJJLM6uyqi3IqmsG0RFqSKjndxOeQMEyTYr+TEp8T07Lb0AUsnLJINJkimspg6jqGYaEbFk2RJJuPREgk02ia7WPeElGJpiCegdYENJzYx94n7sdUM5Tf/uPOQXg6RtNzD6C2nKDs+m/jGzQd74CpF9wXGMl2Gh6/j9Se5QQn3ErVvX/GP3wm7p6jADDT0U736HRy3jfMtiiLb5qHoDjxDZ8JQK52F2rjIQJjr0cRJQwsAh4nxV4XPreDXMakX6kbQ7NwORS6hHwE3TIhrxNJllHkQhD+70IhEC/woTKqW5gV35jBwR9dybAuoU7PqYbF9rrY+zpOYNyNmKkoyd1LAfD0m4QUKMu3jnuHXNKx6Ha2FxOdnrz9BNgBtN5eT+7kLty9x+Ks6oepZjFSUVuUZfZPED0hml94kGztTjvgvuQzBMZcT2LLQtoW/R7LPCOGEhh7AyXX3E+ufj+NT9/faVH+RznRnsZ4D9/1QVVB+lX4uXFkDWO6d66Anj1H/k4F9NOz6KcDt9sfWce35+3qFLydHcjrpp00OR3wPbexFuvsEjeweG9Tfv7vNMmcTsitUOzrnHGvi2Y/sP3aaUwtS+v8n5LYNB//yKspmfXNTt0IyV1LaH3l5zir+lF+68OILh+C4upQKS0ntvZ5lKJqguNvttvUXvoBsdXP4B04jcq7f4vrHcJqRjYFgoiguPKPpQ+tx0i24x9hW5Qltr6KpasExlwPgNZ+ivTBdfhGXNkpOfBeVPid3H+5/R35OyxqLOxEzaOrj12w2v1B1PALFChQ4H+b0wGFKIqU+j1YgohpQbHXiSyKaLqFV5HoEfZyzcgain0Kaw63UtuWpiGWJpODUN8xKMVdiG2Ya8/LSgr+kVeTPb4NtfkoksuHZ8A0UnuXYeZS+XOf/k0+vS94tz2B5AkSnnEv/rE3kNz2uq2mbho4q/tTftuPMHMpOxhvO5k/vlJcQ8Vdv0QpqqZ57sN5NXcADcjoBvGURs60OBlJoakGLckMJ9vibKttY+vJdnbURdhytIWTbWliGRVVt4hlVFI5A4csoeoGtW1JBBGa4vbz3Yo9HG5OklZ1ygJOREGgti2NYFm0nWdeXJFEygNuSv1O2pIqR1vSHG1Nsnx/M0eaE2TfsZYrol1FT+Y0cpqO0jFTrhkmTfEs8ayOblh0CXsxBJNkTudgS4I9p6IcbU2SVDU00yCZ0TjckmBnXRRRFJEFwRZrMy121Ed4a289J6NpcobJ/uYUzckcBxsi7KhvJ5mz0A2NxoRKbVuKU9EUKQ3actCSg9b9q2l67tuITi8Vd/4CZ2Xf/PXriVaa5nwLvb2eshu/i7vXGCzLQms9QWj63efsCxI73qTxmfuxLJOKO35GaOpd+ZEzo+Pv5p1ru2XoJHe8gavHSJRwJXqildTelfiGXpZXWo9tmIvoCeEbfDFeF2i6hSyKOGUBjyJQGVYY3r2MQdUh3LJAadBDl+IATkVGkYRzih8FProUAvEC/2N8a+aA8/oivx9cXYfiqOhNfOPLWKaBIEoExswiV7eX3Kl9SG6/vejuWdYp8JbDVWhtdZ0Et7Ind2MZOt5BMwDQWk8g+4oQHG5kfwkVd/wMOVBG84vfJ31ovS3gMuMeghNvI7VrMa2v/AJL1/LH8w6cRvmtP8RMx2h88qtkP4B/6YV4r3HqoFvhjkfX8+zGWnadinHd8Kq88N0T646fUxU9rYA+qlu4U+CmGhZzzvIxP1vhXBLsqvo732dZ7y20BxDN6DTGc7QmPxxRMT3RRtOc/yR9cB3hiz5J+Kz5/PjmV2h7/de4ug6l7OaHEJ0e9FgzTXO+ResrPyM84x6KLv8CRZd8Bq31JI1PfoXs8R0UXfZ5iq/66jmBs952EjlU3ukcia2vIgXLcfcchallSWx9DXfvsSgltvVJfOM8kGQCo679QJ/tj3eOyn9H4zssyk7zTpu4s/mgavgFChQo8K+AIAh4XTJlfhdlPjeiKJFSdcIemT5lAXpVBKgJ+dhVFyNnWEiyQFa1OiRfRIrG3oDWfIzssa0A+EZcieBwE99gJ+j9I6/C0nIkd55xyZDDVQCdgud32xOc7n4LTryN5M63aFn4SyxDt2fGb/8JlqHTOOdb5BoPnzmHr4jy2T/F3XMU7W/9ifa3H8EyDVQgo0Myp3KyJUprMkd9LMP++hipnElTXCWnGuiaSVIHRHtTLkoiDZEsR1viNMUzpFWdtrTGhO4hAk4JQbBwKSK6aVHkddKr1MeAygCVQRd9yn3UtqdJ5XTbzzunnyNAGs9peJ0SqZxBIqeRVg1q286MjQ2qDuFxSIQ8DgZUBfA5FARByFfF3Q4JlyTSmsyQyOpUBz2kcwY+t0zvMj9eh0yJ10nY40CWJXQDatvSyJKEiUljNMWp9gxNsSytCRWHKBNJZ2lJqsRyOaIpE93SiGQgkYaUBtEktBv2rHxs7fO0nu6A60iCnEaLNND0zDfRE62U3fx93D1G2kKnSx+lZd6PyRxYS9Gln6Xo8i8QnnEvbW/9mfY3fo+ryxAq7/4tzuoBne7V6b+b039HYI/BGcl2/CM7kvObXwHLxN9hW6o2HyV7bAuB0dciyg40AxyKQFXYTUXQRbcSPwPLA6g6BDwOBleHGFQTwudWbE2YQiX834pCIF7gf4xR3cK88NmJXDaw/AO/VxAEAuNuQo/Ukz64DgDf0MsQXT5iG+YCEBh1dceiuzj/PkdFH6xcCr2tLv9YYttruHuNQRAl9HgruVP7sUw9n0GVfUUUzbwP0ROi5eUf0bLwl3k1VbsVfTXNcx/CVM+0fLm6DKbirl8iuv00PfcdEtvf+Ifu0ftl5aFWctqZKuju+nh+Dvt0VfRC7eenA7fTP+1n+5jPHteVz07tyenfflEUmLu1jrDHkX+fCZ3mvs9OHPQu9VLkUT7Uz5yrP0Djk19BaztJ6Q3fITDmuvwCZVkW0VVPE1nyCO6+Eyi76XuIDhdaWx2Nz3wDMxOn9LoHEGQHgiiRObKZhqe+lm9h84+48pzFzrIscvX7cVScyayrzUfJndyNf8RVCKJEatfbtk/42BsA0JPtJHe/jW/IxUi+9z+rPbVPSf47A/t/Kw/NGows2n7hDuXCAfbZHQ+FGfECBQr8X0ESBRyKRMhrtyhnVJ2KgAvTghKPA0WWyOgGommh6Rom0L1UoTrsZOSMK5H9xcTWv2gfy+XDN/QyUvtWoseacVb0xlnVn8S2V/NjZUpxFwTFhVq/P38N77UnEAQB76AZePpPIbN/FfWPfwlLV3GU9aDijp8hyA6anv1PsrU788cUHW5Kb/hO3oGl+cXvY2QStGfgaBQaY5DJ6kRSOURRIuhUkAVQNYimcwRcMj3CPiRRIJbWONKaRDVNBATKAm5GdQ1zoDlNU1JFzWnM23yS/Q0xUjl77jqn65yKZDjUkkLTTRrjWQ43JzjclKDurHb1Eq+DtGoQSamEPQ6csohunmlRD7gVBlYFGNezhMqgh7pomkPNCWIZFUkQ2NcQZ/3xdo61pTAMk0RWI+hRSGYMsprJmO7FDO8WpjzgQAQkySKt2WOGTlGkOZqlOZohmsySSGU43BhFNiGVNUhnwadATrX/iQKkshC3QNeyNC/8BdFVT+EdON12wXmHWKvadJTGZzpa1W/7Ea6uQ7BMg/Y3fk9i8wL8o67FM2AqYKukN7/4IMltrxEYewNlN3//HDE2ALXDXcdR0fvM38+WhcihCjs5n0uR2P4Gnn6TUEIVAMTWz0VwuPGNuJISB5T57HGCtqRKyOeiOuCkW1kApyyiKBKaJZLNGaRVPZ/sKPDvQyEQL/A/yqhuYR752Gh+fP0Quhe//7ZdAE/fCcjhKuLrX7TVVx1u/COuInNoA1rbSRzlvXDWDCK+9dV8+7ir6xAAsrU7ADoy2X2ROtROU3uXocca8fSd0KllLb7hJTx9xiEXV5Peu5zml38E2K3oxVd+meyJHTQ9+wBGKpq/PqWomoq7foWr21Da3/wDbW/+oVPl/MMm//MsCBxuTnZ6bsfJKLc+so5fvmm3n78zGD8duN0+rusFK6l+t5IP1DXD4tkNtTz06h6uGFRx3mq9LAp8dmpPpvQp4cfXD+GeyT3xOD64kvyFSOx4i8Y530SQFCru+iWePuPyz1mmQfubfyC29jl8Qy+jdNa3EGSFXONhGud8E8vQKL/9Jzir+9uCKZsX0Dz3IeRQBZUf+zWumoHnPafeXoeRbMfV9UyrenzzK/bc17DLsEyD+MZ5OKr64awZZF/npvlgmgTG3viBPt+qw6386q0D3PqXdfmW/dnjuvL8ZybwtcvfO8B+Z8dDgQIFCvxfQ5ZE/G6FYp+LgNeJokiU+p2kVY3qsJdh3cKUB/1M6OanpjhETdDF4K5haqZcT+7kbnKn7MA6MOY6EATim+YB4B99LXqkgcyRzQAIooSzZiDZE3bQ/H73BJHljyMHSvENvwK97ST1f/scZi5tr/t3/ALZX0rTCw+S2r86/5kEUaLo4k9RdMV9ZGt30fjkV8g0H8tbdyaykFMNREulMZlFEECS7XGwsFuiLaPRJezB65LBFHDKMmnVIJHVcSsiQbdCj1IfexqTrD7SxuK9jWw43ELYrdCcyBFwK9RF0sSyKo3RDCdaUzTEshxrSZLIaKQ6quNhr5MBlQEm9C6m2OtAMyxqwu5O34/jtPVYKoeqm8hAW0JFkiCr6VT4nficCg5JxiVLDK4KMaxLiEFVQfqUeYmnNZyyRH0kyZ5TMRRRIprIYAl2xV+3LOojGXIGBDwK3Ys9OEQL07Q1cRQRSv1gCZAFpGgjjU9/g/S+VYSmfZziswRzMyd20DjnWwiiQsXsn+Gs7Iula7S+8guSO98iOOFWwhd/CkEQ0NpP0fjU18jW7aX4qq8QnnHPBZ1wsie2o5R2zwfpucbD5E7txT/yagRRIrFtEZaaJjDO3gNokQbS+1cRHD6TkMuHwymgYeGRZYo9Dop9TgJuJynVAEHC73Lac/eC7dGe1c7tYCjw0aagml7g/wuzx3WlX4WfW/+yDv19ZvwEUSIw7iba3/hdXhnVP/pa4pvmE9swl5Irv4x/9LW0zv8JmUMb8PSbiBKuRA5Vkj6yyf6hlGScNYNofvH7OMp7ITo9BMbegLOyT/487UseRQ6UEppyF6GpH6fx2W+RObSO2IaX8A2biW/IJYjuAK0Lfkbj01+n7OYf5FuhJJePspu+R3TVU8TXv4TadITSWf/Zocz+zyF0KGmfzdkZUxNbjf00qmHxlxVHGNYllJ8zPv3vxpE1ndS2T6tvn65+5zQzv2nIaman455GFOCTk3vgdyvMHFzJsgPNLD7P6/4RTC1H5O2/2OJo3UdQcu39nbLUppqh9ZWfkzmyicCEWwlNuRNBEMjW7qR5rj0fXn7rD1GKqu2A/e2/kNz2Ou4+4ym5+uuIDtcFz31aaMfdwxZnMVIRUntX4Bt6KZLLR2rvcvRYE6Udi7mRTZLYvghP/8ko4coLHvd8WHa3JaZl8Z35Z5Tv36lWX6BAgQIfZWRRsC2uRIFhHRZOBhZORUYSDboXewh7ZTQdunUL4ZAl1o6ZycllzxNf/wKlNz6IHCjFO3C6HWxNvA1P34lI/hISm+fnbaXcPUYSWfooWrQRJVTxvvcEwYm3ITo9iE4f8Q0v0fTctymZ9U2UUAXld/yMlrkP07rgZxiJNgJjZuXf7x92GY6SLrTM/wmNT32doss/j2vwxXhkcMoCGgKGrpNVQZHA5VQQRYkuRR7Kgi6iSRWHZJFRLdxuGbci0hCzhdVWHTzF3voE3Uu8uGTYUttGz/IAqZxOt2IfiiQScMq0pVTakzncDpnWhEGRz4ksChR5HZQFXLgdEm6Hh4qAHYCfnkmOplRaUjk8ikSR14lhGpxoy3GiPUVWM+gWdlPblkYU7BEqryIR8jrQdNvfXBEF2lMah1tS6LpBQySL1y1zvDXJ7jodyzRIqQbRtMapaJKsZlAd9pLWTSwsHLK9p2lOgGBCyoLEkU20LfwlgK2M3mt0p7+j5J5ltL3+W5SiKspufgg5UIKZS9My/ydkj28jPOOefBdb9sROWub/2BZqve3HF0zMg101z57cS2D0mbGzxOYFdrV76KVYukp883xc3Ufg7KiYxzfOBVGkesIsHA5wS+BzOyj1uyjyuyjzuwh7FDKqji0NZKHItrOMaVqIgljwDP83oxCIF/j/xun22wcX7LaDSQECbplYWr/ge3yDZxBb/Qyx9S/i7jkKyRPEN/RSEtvfIDT5Djx9xiMHy4lvmo+n30QA3H3Gkdj6KmYuZStZD74Id++xtl9oZR+MZHv++Jnj29Haaim78YwVlqOkG1YuTXT5E2RP7gVDJTB6FuW3/5jmuQ/R+NTXKL3hO3mRL0GUCE+7G2dlX1pf+w0NT9xH8ZVfxtNn/Pu+N4okIAsCGf1Me9j7TYqe73WL9zaxeG8TTqWzj/g7A72z/cYfvHoQ87fVdVJrT6udxdgEbD/zx9Yez6unflhobSdpWfAztJbjdpA9eXanLLWeaKVl7sOozccouuzzeeuw1P7VtL76S5RQFWW3nF6EU7TM/ynZ49sIjLuR0LSPIwjv3gCUPrgOpaxHPomS2Po6GBqBUddiWSaxdS+iFHfF3bG5S2xZiKVmCI6/+Z/63KYFDy7YTb8KfyEIL1CgwL8NgiAQcCvohkTYDRYWZQEXrQkVNaujeBVEQUBxg8clYukWltNNaNS1tK1+BrXlOI7S7gTG3Uhq9xISWxYSmnIn/lFXE13+BGrTURzlPXH3GU9k6aNkDq5FGXvDB94TGOkYrp5jyNXupPHp+1HClQTH30zZrQ/T9uqviCz9K3q0wa64dqxZzuoBVN79W1oW/Iy2135NtnY3PS77DJruIm1aiICOvaYqAoiSwIguQZpiOQzLRBFFdMGiyCWR1gz21EdpimXJ5AxaEzlaUzmwoCbk4kR7EocoktF1Am6Z9rRKWzLH8bY0TkUko+p0KfYSdCtEM1AWOJOQfqcoWE43aIhlcDslWhI5BMDjEKmNZPDJEmVeF/saUvQr97OnPorXoVDqd6KbJk0JjZZ4Dt0wMAAskxPttgCbJ2twsCVChc9NKmcgCQZ1kQypbI6sCUfaklT4XThkgZYkZE1QAcvQiK58ivjGl3GU96Lkuv/Mt3+DPU4WW/scsdXP2Jam138b0eXDSEVpfun7qE1HKb7yy7aLCnanXftbf0QJV1N604OdjnU+Mkc2ganj7tjL6fFWUvtW4h9xFaLTS2Lb65ipKMFrbkYC1EQbqV1vUzr8EpRAMS6HYPuwCwLditx4nSJup4TDqSApIh5ZBgQUCURRwjTN8zrjFPhoU/jGC/x/5XT77exxXVFEgfi7BOEAgqQQGHcDuZO7ydbtAejIbNoejYIo4R89i9ypvflWNW+/yWDopA+uzx9HcvlwVvfHzCZJ7ngLtdVuB86e2I5v8MWInhBwxg+6/I6f4x91DdkjGzE1lejqZ0gf3kDFXb9C9IRoeu47nWbTATx9J1J592+Qg+W0vPxD2hf/GVM7V730fGiG1SkIfz8UeR0E3efPpZ2uaqvahVW11x9tI9sxd57VTB5bfbRTEH6h487fXv+hBuGWZZHY8SYNf/8yRrKdspu+T3jqXZ2C8FzjYRqf/CpaxFZBPR2ExzcvoHXBz3BW9KH8jp8hB0rQoo00PnU/2dqdFF1xn+0J/h5BuB5vIXdqH55+kwBbqT2xrUOUrbiGzJFNaK0nCIy/CUEQMdUMic2v4O41BkdZjw/8mUWBThZvptVZnO3d7OYKFChQ4KOCIAgosoQk2UrqHodMyOugIuxAtwSSqkml382BxjSrDrWg6QKVE65BcLiJrXsBAEdJV9x9J5DYshAzl8Y/7ApbxK2jXV0JVeAo70Vq/6r8eT/QnuDkLkqu+jJlt/4QS1dRW04QWfY4sXXPUzLrm7bLytZXaX7poU6K7ZI3TPltPyIw4VZSu95m/+NfpqXhCBqQw7Ybz+mg6QZht8Lm41FWHmph/bEI2+pi7GqIMn97Hcv2N3GsNU0yY9CW0XDKYBr2GqIaJu1JlfKAi35lfk60JjnSnCKaUollNMqDbmRJpC2RozGWRTpLdVXVTXTDJKPqtCdynIgk2X4iwrGWBAeb4qw42Eoyo3MimiahargcAiGvQsDlwO+SSaoGW461su5AC/sbYxxqjHOkMc6x5gT7G5Ikshq76qPsqUuw/mgz20+0UR9LowsWCRXSGYgnDGpbU5xoV0l3BOFapJ7GZ75BfOPL+IbPpOLOX3QOwnWVtld/ZTugDJpB+c0PIbp8dtv5019Ha7W1ZXxDLsGyTCLLHqP9jd/h6jqUirt+8Z5BOEBq/yokXxHO6v4AJLa8ApaFf/S1WKZBbMNcHJX9cHYdgluE2KZ5WKZJ9eSbKPIp9C0LUBZ007Pci8uhcCqSoy2loao6QacTl0PG7ZQRRAmXQ7L/s0ChGv5vRiEQL/D/nVHdwlSF3Oim9Z5q4QC+YZcjeoLE1tqLrhwswztwBskdb2GkoviGXoro9BLbaIu4Oar6IQfLSe1Zds6xJE+Q4OTZyAG76ikoLgRZyf/wtb/9CP5R1yB7Q4Qv/jTBaXejntoLooTWXofkCVJx1y9xdRlM26Lf0r7kr53szZRwFRV3/tIWbNn6Go1//3InddUPC1GA9pRKLPPuiQxRFPKz4O8M7uZsqOXZDSc6vfZwS+p8h/gfxUhFaZn3I9rf+D3Oqn5UfuL357SdpfatpOmZb4IoUXHHz20rEtOgfclfiSz5K+6+4ym79YdIbj/Zuj00PvlVjFQ75bc8jH/Y+/P1Tu1dAYB3wDT7v+9akhdls7PuLyAFy/EOtJ9PbFuEmU0QnHDrB/7MogA9S7xM6VOSF8iTJbHT93THo+v55ZudZ8gLFChQ4KPK6d9CQYBYRmPHqTjNiSxht4SJgFtx0J4zwbKwHD78I64kvX81WvspAILjb7GFs7a9jvhOEbd4MwCeAdNQGw7lX3+a97snkDxBXDUDqLjjF4hOL1qsiVzdXixdJXzRvRRd/kWyJ7bT8OTXOp1DECXCU++i7NaHsdQ0DU9+jdj6F/P7Bg17nVYkibr2FPGsis8pk8hqRNI5asJeMqpJOqvRllHJZQ08TgcOGbxOGUGQSGR1XIrMm7sb2H4yzu5TcZqTGgGPTLnPicch43fJlPodlPqc+Wtrjmc52pJkZ12UlQdbWLK/ma0nIuw5FeNULMORlgRv72sirepIpomq6gysCHC0JcnBxhhL9jby1q5TnGxPcaQtTnNM5Whrkn2NMXafihFNZPAqMpmcigNIa7YHeFNUJZpU0XXQLDAskAVoy4BmWSS2v0HD4/eht9dTMutbFF/+hU7z4HqijcY5/0lq73JCU+6i+KqvIsgK2bq9ND59P2YuTfntP8bTeyymmqFl3o/tgH7ElZTd/H1Ep/c9/x6NTJzMkc14Bky1k++5dCdRttTe5RixJoITbkEQBFLJGMnti/AOnIYSqiCW0thRH6MpnkWWJHTDRBIFPIpISjVQZAHDsjBME8MwSGZyaJqBaZpYllWYE/83ohCIF/hf4Wwl73dDVFwExlxH9tgWcg0HAQhOuNmez9k0H9HhxjfyKjIH16O11XUonl5E9sQO9HjLOccTBOHMrLCukTqwhuzJ3bS//RcEWSHQYUGBZRIafxMl134DteEwmUMbMJLtSC4fpTd/D/+oa0hsXkDzCw9ipM/4owuyQtHFn6Lslocxc2kan/wq0ZVPfahCbkVex3u+RhTgoVmDGdUtzE9f38fN/722I7hbywPzdlEXzX5o1/NBsSyL1L6V1P/t82SObiY84x7Kbv0hsv+MgJxlGkSWPUbrKz/HUdGLyo/9GkdZj/zCeloFtXTWtxAVJ8ldS2xfUZcvL6L3fq8luettnNUDUMKVtijbpnl2prtmENnj21EbDhAcfxOCKGFqOeKbXsbVbXg+U/5eCNjfR6nPgWnZSY+Vh1rzYwXmOxbd9Ufb8rP6umnx4ILdhcp4gQIFPtKcroynczqaZtG73E//Sj+mIGGZkNY0YmmNSNrEAIrHXIcgKSQ6quLOyj64uo8gvmkeppa1Z7Yti/imBYBtO4ogktz19nnP/X72BJZp4CjtSsVdv0TyBMid3E1y52Isy8I37DLKb30YMxOn4cmvku7QHDmNu/twKu/5A54+44iu+DuNz3wDrdW2xkqpBvsa4uxrTNCeUDnRmqI1nsUhitRFUjTHM2iGxaCqAEO6BehX7ifgcdCt2Mvg6gDDa/y0pXIcbkliYSKLFpZpMKFHCUlVxbTgYHOKtoSKz6mQVXV21UVYfqCZ1kSWxniGSDqHLAp4FIW+5X6ymsGJ1gxBl8yx1jSGICKKErXtKfbXRamPZohnVA40tFMXtVvmPQ6LnKZTH8myvzHO4ZYULYk0imChOAVUFVQLIjlIp+25aIdsq6O3JSAbb6b5hQdpf/MPdmL+nj/i7T+5033M1u2l8e9fRms9Qen1DxCceKsdCO9d8Y71/5c4q/qhx5tpfOYbZA5vJHzxpym+7PMXFGU7m9TeFWDq+AZdBEBi+xlRNss07FG10u64e4/FAbRvmo+lqZSPvwXdArdTJOxSUCSJk60p6hNZcrpKWtUo9truAC5ZwjRMUppOTreIZc/sEQuB+L8PhUC8wP8Kp5W8u71PJXX/iKts67K1zwG2YrlnwBQS217DyMQJjLoGQVaIb7T9RL1DLgas8y667yQ09S6clX2Jb5iLUtyFkmu/kX/OTMfJ1u1FbTmOFChBkJ00PvU1Mse3I0qKPYs28mqydXtpeOLL5DpsLk7j7jGCynv/iHfQdGLrnqfhifvy7fX/DAGXfF6f7pBbpvodyqemBcsPNPPT1/fx3yuP2n7g2Aqt/5vo8RZaXv4hra/8HDlYTuXHf0tg7A2d2seNVISm57+bb0srv+1HSN7QmYX1yCbCl3yGoks+DUBk2WO2n3jNQCru+lUnX9H3IndqH3p7Hd4hlwKQ3r8aPdpIcLytghpb+xySrxjfYHvOLLnjDXsubNJt7/scvcp8CEDLBfzVDeNMa/r4nsWd2gfPblsvUKBAgY8ikijgdSoEPAohlwNRlOlW7KXE78QhClimiSyBJIA7EKZ89EwSe5fhSDSiAMFJt2GmYyS3v4EcKMM7cBrJHW9gZOLI/mLcPUeR2r2kUxfb2byfPUFi66sIshNH9QAiSx6h/a0/gWniKO9FaOrHUEIVtLz0A6Krnu50LskdoGTWtyi55n709nrqn/gS0TXPEklotCTSJFIZDMtCECx6lvoIuB1EkjnSOYOmVJaDjUmCDgdhr8KILkEEEZJZndpojpPtKQRBoCmeQxFFKkNuyoMuJEFmVLcQQ2sCHG9Ls/1klLVH2jjUGGfJvkb+suIwr+84xbGWJGnV9hRvS6n4HA6KvDJDa0KEvRJOEaLZLHsa42yojXKqPUUqkyOWUYnE0ngdMpG0jiAJtCeztCdzxDMqTYkMDglckohTAZdkB9+iBNEMJDKQMA1qNy+k/m9fIHdqH0WXfpayWx9GDpTk753tgPIKTc/+J4List1U+k7Eskyiq56mdeEvcFb1s33Fw1XkTu2j4cmvokebKLvpe50E194Ly7JI7nwLR3kvHOU9sXSVxOYFuLoNw1nRm/SBNejtdQQn2EmATCZBYuurePpPxirpQjQHjXGTlqSGzylRFPDQtdjNiK4l9KsMUdIhkCcItviuhK2YbpgmZocYb6E9/d+Hglhbgf81RnUL8+mpvXhg3q78YwLnelQDiE4P/tGziK1+BrXpCI7yXgQn3Ep630rimxYQnnpXh4jbmwQnzUYJVeDqPoLkjjft1qF3yYKGJt+BpWsIsoKRipDcvZRc3R7UxsPIRTU4K3pRduN3ERQXzS/9gObnv4ur12iUohoSW1/DP/pa0gfX0fjMNwnP+AT+Udfmf0Qll4+Sq76Kt/9U2t76I03PfBPvkEsJT7+7k//lByGePX87eiyjEzvrufOpnv9vYRka8c0Lia2ZA5ZFaPo9BMbMOue7ydbupHXhLzGzyU5CK9m6PbTM+wmWrtrKqT1HYWaTtC78JZmjm/GNuIqiiz+FIH2wn7Xk9kUIDjfeAVPsNvT1LyIX1eDuM57cyV3k6vYQvuQzCLJid2FsmIuzy+C8WN97IQpwtDnJu+U/3mkn905RQ9OycMgX9hQvUKBAgY8SiiwyqnsRR1oT9KsIIIsC2+uilIXclIVcOBSJ9lSW9oSOf8z1NG16jda1L1B8+X24agbh7DqU+Ia5+IbPJDDuJlJ7lpHYvJDQlDvwDbuclpd/SObwBjx9J17wGt7PnqD0um+hlHQlsvzvJDbOJXNsC+6eY0hue53g1I+hlPUgtvY5cqf2UnL1/Ug+W4hTEAS8A6fZtqdvP0Js9TOk965AmPlZsr2G405rOGQZwzCJ5TSakzkERJyIaLrBhuPtpHIGaU3HpQiE3E62HYsgKQIVfjdDa0KM6BIiZ1hYlkBa02ltzJDMadRFM4CFqptYgmW3xMsyoigS9siIQP8yD91L/eQ0na0nI+w9FUUWBLI67G+KouomLlGkWTVIqyYZ3aQlpSFIGXpV+KkMBdhTGyOn28WA1riK7leQJAtFtAXqTBVyoj0Lnqo/QPviP6M2HsbVYyTFl38BOVje6fsws0na3vg96QNrcPceR8lVX0F0+TBzaVpf/zWZg+vwDrnEbmGXFJI7F9P21h+R/aWU3fYTlJIuH+hvUG04iNYhCguQ3L0EI9lO8VVftYVb1z6HUtwFT39bUyaxeQGWmqF44q04BHvNDzjtVvysZlBT5KZHUYDqkBe388x+x7IsHLJEPKthWdj2spKIIAiFQPzfiH86EBcEoQvwJFCOHUM9YlnWb896zXRgAXCs46GXLct66J89d4H/+8we1xWA5zfV4pRFttZGLlixDYy+lsSm+UTXPEvZDd/BUdoNT79JJLa8QmDMdQTG3kBi+xvEN86l6JLP4B8xk5Z5PyZzeCOevhPOOV7qwBoyhzcgugO4e47G1W0oeqKN2NpnESTFVuD2n8nImtkknt5jSRsa2SObsHJpOwM+6XaCE2+j9bX/IrLkr2RP7KR45n2dAm13r9FU3ftnYmvmEN+8gPTBtYQm3Y5/5FUIkvKh3Esr/3/+tbAsi8zRzUSW/g29vQ53rzEUXfrZcxZbyzSIrXmO2LrnkcOVVNz8AxxlPezs9LbXaV/yCHKwnLLb7YVVaztJ88s/RI82UnTZ53H3Go2l5/KBuNpaa2euI/Uopd0IjJ51zr020jFS+1fjH3YZosNN+vAGtJbjFF/1FQRBJLp6DpKvCN9Qe9Y8seNNe0G++mvv+pmH1wRJqgZHW5K8m1vfsJog5QEXJX5np8dP2/2902quQIH/KxT2BQX+GYIeByO7FmOYttd0yOugPZXD51TIqBYlHgenojqWp5jQ8MuJbFtE0aRbwFdBaNJtND37AMkdbxIYfS3uPuNJbF1IYOz1uHuNsa3Ntr5+3kBcjzXbre1qFmdVXzz9Jr/nnkCUFVw9R5E9to3U/pUoxV0IjJiJ6PLhrB5A5O1HqH/8S5Rc9RXcPUfl3yt5w5TO+ib6kEtoXvzfHH7mOwT7T6DHzHuIeSppS+RwyraYG4IJgkTMtLAQME3b0UQSJdK6iZhTKXO6UGSRiqCLyrCHxqgdfAddEqZpcKg5Q0o1aIznaE9kcCsSiYyGxymT1SyKvS48TpmTkTRuZ5b6SAbTFFEkEY9LJKtpHGqI0ZTM4ZAkFNHC51HIJCziWbuSrhoGVWEPJgaIYKmQ0sEp6ySzFrJoV4F1E5yZNtrfeorI7reRfEWUXHN/xzx25wA0d2ofLQt/iRFvITT9EwTGXo8giGjtp2iZ9yO0tjrCF30S/+hZ0GFZmtiyEFe34YQuuhfJX3Tm+1IzpA+sJdeh++MfeTWOkq7n/B0ktr5qJ+cHTscydOLrX8JR2RdXt2GkD6xBa62l5JqvIwgiRjZJfPMrePpOJFTVHb8L2lPgcjjwegRqinyM7FrEsC4hTOzvze0Q8l1vkiQQdCtIomRXogSxEIT/m/FhVMR14GuWZW0VBMEPbBEEYbFlWXvPet0qy7Ku/hDOV+AjxuxxXZk9rit/XHaYze8yCys6vXZVfM2cM1XxSbeTPrCG+Kb5hKfelRdxC064BXfvcUj+UhJbF54TiGeObqF1/k8QvSGsXJrEpvkoxV0ITrqdirt+RXTVM7S+8gv7OD1HoUXqia2eg6WrVH36EaIr/k58w1wkfylGNokcKKXsxgdJbHmFyPLHaXj8SxTP/I9OC6/ocBGecQ++IZfSvvSvRJY+SmLrq4Sm3IVnwJT3VPa+EL3LfBxvTf6vt5yfj1z9AaIr/072xE7kompKb/oenl5jznmdFqmn9dVfodYfwDv4Ioou/Ryiw42pZWl/60+kdi/F3WsMJVd/DdHlI31wHa2v/ReC7KDk6q8R3zSf1J5lSN4w4Ys+iRwsI/L2f5M9sQvJX0JqzzLMdJzwjHs6nTe5400wNHwjruoQZXu+Q5RtOtnaneRO7iZ88acQFaddDV//Is6aQbi6Xnj+XBJh16lYh0fohZFFgVvHdOWhV/eg6iYvb627oNVcgQL/xyjsCwr80wiA2yExsDKAiAkmZEoNXtp0Mu8MEhx7E5Htb9C25gWCl9+Hq+tQnF0GE9/wEr5hlxOccAuNh9aT2PYawfE34x9xJdGVT6K21nYKwizToOmF72LEWxCcHlK7FhNZ8ii+4VdQdvMPiG9acME9QfnNPyBzYifNL34PLXuK9MG1eAdfjH/Y5Tir+tP6ys9pftHWlQlNuxtROZN4lXuOovreP5LZMp+WNS+w/befIzBiJsUTbkHzhnHI4HOJKLKE3yHQmjJIZHVb6EwxyOV0/A4nfoeMgEm5303/cj/xjEZde4pTkQxxVSedM4glNXJqHJ9boVvQTVnQiWWJVBW5CHkdRDM6Dllmy4k2atvTlHodNMSyZDSD1qRGOpdDy4EqGggWaHqOlAqGZaKqYFgGsbSGqpnIAqim/UOQzllE1Y5qeDZJYuM8YpvnY5kGgXE3Epxwa94qLv99GDqxdS/Yo2GBUiru+BnO6gEAHev/rxEkmbJbHsLdfThGKkLLgp+RO7kb75BL0FpPEHnrT532BNlj22h7/deI7gCWoZE+sIbqzzyK6Dgz0mekoqT2r8I/7HJEp4fk7iXosSZKL/40YBFb8yxyUQ2e/lMASGxagKWmKZt0O71K3YS9TuqjGVTTokvAybR+5ZT4nNRFMkiSQJHHgSgKuBXJ1iYQLBRFzuvFiIUY/N+OfzoQtyyrAWjo+M8JQRD2AdXA2QtugQLvyviexciigPouEUxgzCwSmxcQXT2Hshu/i6O0O55+kzuq4rMITriZ1J6lxDfOIzzjHvwjryK64gnU5mOdbKbShzciODzUfO4JLNMgfXAt8fUv2cJglf0ouuxzGPFmsrU7cfcchRKuwlkzkNja50huW4QgO3HWDERrOUHjE/9h+4b3nUBg9CxcXYfQ+sovaX7xe/iGX0F4+j2dFhmlpAvltzxE5ugWIssfp3XhL1DWv0hw4m14+k38wAH54eYkNSHX/6r42tnkGg8TW/Os3XHgCRK+5DP4h888p23cskyS2xYRWf4YgihTcs39eWVyre0kLfN/itZaS3DSbHsm27KIrPg78fUv4qjsQ+l1D5A5sgl3rzGEJt1OZNljxLe8gqf3OIou/yKC4kb2hWme+xDpwxs7BeKWoZHY+iqubsNxlHQlc3QLasNBii7/Iggi0VXP2NXwYVcAkNj+Rr4a/m4Za+M9EiIC9izkQ7MGE0mrqLptIafpttVcIfgu8H+dwr6gwIeBKAr4XTJkLUo8Lo6KKWIZDd2yKPGKNKdMCJRQNvpKmje9hmfcLSihCkKT76Dp2f8kuX0RgTHX4eoxivjGefhHXo1v2OVE1zxLYvMrFF/xxfy5tPZT6O2nKJ75H3Yg13yM+OYFJLYsJLlzMaHJs3F1HXLePUFi62sY6RjuHiOxtBxti35H9uReii79LI7SblR87L+IrniCxJaFZI5to+Sqr+Cs6pc/tyU78Iy7hUGjLuHY23OIb32NxM638I+8mpqJ1+MOllIRcJHKGhh6jpxuJyEEAXoUe4hnVfa3JPA5FRIZnaZ4CguRYq9MJKOyrTZCa1JFxMLtlKky3bS6MnhcDnqWuCj2OMGEk21pSgMO2pI5VNVgXzxBJKOjiBbRRI5UDtKAZIIE+EXwOewWdN2A9pROJK2ja/YMuI5t09acAzOXIr5lIYlN8+3uwgFT8/P0Z6O21tL22q9RGw/hHTidoss+h+j02r7iK560xVQr+lB63X8iB8vsqvn8n9rjbFd/DUvNIAfLz9kTuPuMo/xOW8gte2wrzS9+j1z9Adzdh+fPndj2Ohg6/pHX2F16a59HKeuBu/dY0vtXobWeoOSa+xFECSOTIL55AZ6+E+nWpzf9ywMUBVxM61NKPKMR8rnoWeono5nIIrgVmZZ4Dq9DQu+wKRNFAYcEVseceKEa/u/HhyrWJghCd2AEsOE8T08QBGGHIAiLBEEY9C7H+LQgCJsFQdjc0nKu4nWBjy6juoW5eXSXd1VSF51e/GOuI3N4A7mGQwAEJ92OpWaJb3gZpaga74CptohbKopv2OUIijPvJ3oaS8shOtwIkoyoOPENmkHlPb+n+Movo8ebaHzyq2RrdxOceEaQyz/iSiru+i9Se5YR3zgX/+hZVNz9W6RQhW3B9fZfsHQVR1lPKu/+DYEx15Pc/ib1j32BzNEt53wWd89RVH7id5Rccz+WodO64KfUP/p5kjvf+sAK6/8KQbhlWWSOb6fphQdt27aTuwlOuZPqT//VFtM7KwjX2k/R9OwDtC/+M87qgbZC6sBpeRXzhr9/GSMVoeyWHxCaPBszFbMF3Na/iLvPeMpu+j5yoJRs7RmNAf+oa5F9RWQOb0QJVyF3zOWJLj/WWZ7uqX2rMJLtBMZch2VZRNfMQfKX4htyMdkTO8jV7SEw/mZExYmpZe1qeJfBuLsN+6fu06UDy3n+MxOYPa5r3j1AEjrPiRco8FHhn90XFPYE/97IooDHIeNySLTEsyQyGl6HhFMW8SnQLeSgYvItCIJIokPM1dV1CK5uQ4mtfwlTzRKadBtmJk5i22tIniC+QTNI7Vnaye3k9Poguv0IgoCjvCclV32Fqnv/aLeYL32U2Lrn8fSblH/P2XsC7+CLKLvlB7Zv+O4lND75FdTmY4iKk6JLPmP7kGs5Gp++n8jSv2GqZ9ZtE9DlIoqv+CJVn/wz7j7jiW94mX2/v5dNz/2ebbsOEs3lSGrgUMApg65ZNCZyNMRV4ikVpwSnElkON6WIpHO8tKWOzcdbiaTt/YRuGKRzOeqjKZKqSVXQgSxaqIbV0TatcqwlRUMsSyyjEs2opHK297gpkNc5MbDnu9uTkDPsLjDdgmQOolmIGRBR7SBcT7QRWfEEdX++h9iqp3FWD6Dy7t9Seu03zgnCLUMjttYWttVjTR3Cdl+37eKijTQ+8y3im+bhG3EVFXf8HClQSnzzAhqf+SYAZbf+EN+gGRfcEwiihKu6vx0Au/32OfUz+wJTy5HY9hruXmNQimts+7tIPaGJt4NlEl09B6W4K54OJff4pnlYaprg5NnkdIvjkSzxjE7OtBjds5ixPYvpUuKlMujE51IQBQFJEpBFAVG0wy/DMBFFEUkszIX/u/KhBeKCIPiAucCXLcuKn/X0VqCbZVnDgN8D8y90HMuyHrEsa7RlWaNLS0s/rMsr8H+EG0bWoEjv/mMUGD0L0eUnuvppAHtWfOBUElsXYqQiBCfehqVrxDe+jOT2236ie1eix1vzx5D8xRipSKeAVxBEfEMuofpTf8E3fCaJLbaKZ+bYNtvX0TSQAyWIbj/unqPtdutgOZV3/hL/qGtJbFlIw5NfRW05gSA7CF90L+V3/BxBdtL84vdoeeUXGMnOrfeCIOIdOI2qe/9oZ1llhbZFv6Puz58gsvIp9Fjzh3h3/2cwc2kS216n4bEv0vz8d1CbjxKa9nGqP/c4oYm3ndtypmtE1z5H/WNfRG0+RtEVX7Jn7wIlGNkkra/8nLbXf4Ojsi+Vn/gd7h4jyZ7YSf3fPk/u5C6kYDmiw03ba78GwNN3Inq0If/9OCp6Y+o5sid358+px5s7zYpZlkV848soJV075vu2otYfIDjxFhBloiufQvKX4u+ohie3vY6RihCacud73o93ay0TBRjWJdSp/fyZT47nq5f169SWXqDAR4EPY19Q2BP8e2MBLfEcSAKDq0NM7VvK5N5FVAQ8+J0ysbSK5Q1RMvpKEruXYpz2FZ98F2Y6SmLbqzirB9jWZhtexlQzBMZebythb3k1fx6pwzrz7DVXKe5C2c3fp/S6BzDTMRqf+jqRZY9hqNlz9gRqwyGMdCzvG25mkzQ8+VXim1/Bskzc3YdTde8f8Q29jPimeXaS/h02Z+nT5yyqpvSa+6n65J/x9J9CYvsb7Pvjp9j81++RPLQFTTWxLDsgbk9m8cgCmqFzsCFONqvREM+wqy5GJKXRnsyR1TQwDTI5yKhgmnCyJcGWY1G21EZZvPcUczYcY8uRNg40xDEsE1OwVeodHcfO5jhHcDQDaCo4pDPflYG9vmZP7qbllV9w6r/vIb7hZdzdh1Px8d9QdtP3cJT3Oud7ztbtoeGJ/yC66ik8vcdTde+f8PafbCfm9yyj4fH70NrrbF/xyz6HpedoevYBIkv+iuBw46wZQHzd88D73BN0fM+S70zyO7VnKWY6Zv99mAaxNc/aFmV9x5Pau8JWSp88266Gp2MkNr+Cp/8UHKXd0QyToEsirevUtqbZcLydRFanzOegIujB45RAFKgp8uBQ7KKEIAhI0vuzUyvw0eVDUU0XBEHBXmyfsSzr5bOff+cCbFnW64Ig/EkQhBLLslrPfm2BAgMqA+yoi13wedHpITDuRqIrniBbtxdXzUBCk2aT3reK2LoXKLrkM3gHTiOx9TUCY64nMOY6EltfI75pHkUXfwoAZ3kvsExyjYdx1Qw46/heii/7HN6B02hb9DuaX/guvmGXE55xL2r7KUwtS9G0u2l4/EsY8WaKZv4HRZd8GlePEbS9/hsa/v5lwlM/hn/0tbhqBlD1id8TW/8CsfUvkjmyidDk2eeItAmihHfgNDwDppI9sYPElleIr3uB+LoXcHUfjnfwRXj6jO80y/S/iWUaZGt3kdqzjPSB1VhaDqWsp93WN3Aagnx+n/PMkU20L/kreqQeT7/JhC/5NLLPDpAzx7bR9vpvMNJRQlM/RmCcbR8WXfU0sbXPIXpCBCfNJjTpdgBO/m42RiaOUlxNtnYHmSOb8PQZjxyqRJSdWIaev1a18TDeQTPy15E9utkWZbvyK/Y5Vj+NFCjDN+QSMkc2ojYcsFvbZQVTzRBb/xKu7iPeUym9W5GHk5H0eZ8TBc6rgF6YBS/wUaSwLyjwz6IZJnXtadrTKqZuEvQppLMGiiKjKCJgocgChiXS9eLbadv6Bom1cwhdfT+emgG4eo4ivn4u/uEzCU2+g8anv05iy8K8hkxi66sExt2A6HAjecNI3jC5hgPnXIcgCHj6TcTVfZjd6rzxZdKHN1By5ZdBlDG1LBU3fY+m577ToQ/zZTy9x1D5id/Ttui3RJY8QubIJopn3occKKX4ii/iHTSdtjf+QPNLP8DdZzzhiz55ToVYKa6h5KovE5p6F4ltr5Hc8SapwxuQAmUEB19El9EzMIuqSaoGLhkEUaRL2EttJElWNe1ZcgtECxAsXA67nV0zTE7FMkiyRE6T2d+UAMvCsCy8qo7b4SCRVjEQMU2DpKrjkCB9HsMWSbDHsSxAizUT27eC5K4l6O11CA4P/hFX4R91DUq48rzfsZ5sJ7riCVK7lyIFSim98bt4eo8DbDHV9rf+TPrAapw1Aym5+uv5VvTmBT/HTLbi6T+Zkmu/iSAI73tPALYyOpKc1wmwTIP4hpdxVPbB2WUIqd1L0SP1lF7/AJgmsTVzUMp64ulni/zFN8zF0lVCk2YDUOpVEASBnGqSMkyCHgdHWpKUBdx0LfZSHfZ2+tynfcILVfACH4ZqugD8DdhnWdZ/XeA1FUCTZVmWIAhjsSvxBWPcAp3YciLCHY+uR30fqmP+kVcT3zyf6KqnKL/txyhF1fiGXEJi+yICY68nOOl2UntXEFt/JjBP7niD4IRbkDxBnF0GAQLZ49vOCcRP46oZSOXdvyW2+hnim+aTObaV4pn/QfnND4EkE5p+D9GVT9D4xH0UX/kVPL3G4LznD7S98Qciy/5G+tA6imf+B0pRNaHJd+AdOJ32tx+xRdq2LyI8/RO4e4/r9EMsCALu7sNxdx+OHmsmuXMxyd1LaHv1V7TLTtw9R+HuMw53z9H/sP3ZP4qp5cjV7iR9eAPpQ+sxU9EO269p+IZehqOq3wUXFbXlOJFlj5M9tgW5qJqym3+QF7Izc2kiyx+zvV+Laqi44Ts4K/uQ2PkWsVXPYCTb8A6+hPAln0Jy2otZdM2zuLoMxkzFcJT1xFHWs8PaZCxyoBStvQ5nzUAAsqf2YamZvMCaZVnE1r2IFCjFO3AqmcMbURsOUXTFfSBKRFc+hRyuzNumxTcvwMzE31c1PKcb59U5kES4bUxXbhhZA8Aflx0uqKEX+MhS2BcU+DDIaga6aVHmd9IQyVDmddEuaCiSRJciD43xLOmcjmwJOIIldJl0PSeWP49v3M3Ipd0JTbmLxr9/mfjGeYSm3Im752jiG1/GP/IqguNvpvHpr9tz5GNvQBAEnF2Hkj2+A8syz6vVIjq9FF/xJTz9p9C26Hc0Pv0NAmOvp/S6bwMQuuhe2l79FS1zf4B/1LWEp99N6Y0PktzxBpGlf6P+b18gfNEn8Q29FFeXwVTd83vimxYQW/sc9Y9+lsCoawlOuAXR5et0XtlfTHjqxwhNvJ30wbUkd71N+9rnaV/7HK6yHrj7TMDfeyxF3XpyqCWBqurEshop1dYkCblkJMFENU3iOTCzBiLQrGTJ5ESiaR23BFkdDF3FsExSqo7HKSNIIpYAgmig6HZL+mksy6K1qZaW45to3bcun8RwVg8kOPM+PP2nIjpc5/1uTS1LYtMCYhtewjI0AuNvIjjhtvzr0wfW0vbWnzCzyXxiPnN0M62v/opc3V7kYBmld/wcV4eA2wfZE1iWRebYVpxV/fNFg/T+1ejRBkpnPACmTmzNHBzlvXD3mUBy+yL0aCNlN30PQRDRE20ktr6Kd9B0XCVdKPGKjOpZhIlIazxLzoK0ZpLKmaja+e1mCwF4gdN8GBXxScBdwC5BELZ3PPYA0BXAsqz/Bm4CPicIgo7dzXKbdTodVKBAB+uPtuWFqy7kJ34a0eEiOP4WIkseIXt8O+4eIwhOup3knmVEV8+hpMN/+nRgHhh/M6k9y4lvXkB46sfsYLx6gG0jNnn2hc+jOAnPuAdP34m0vv4bmp//Dr7hMwlP/wTBsdfh7j6U1ld+QfML380ropbe8B1Se5batiWPfZHQpNsJjL0Bpaiaspu/n7fyann5hzirBxKaeheurkPOObccLCM05Q6Ck28nV7eX9P5VpA+uI31wLQCO8l62n3XNIBxVfZF8xR/qj7uZS5FrOETu1D5yJ3eRrdsHhoaguHD3HI2n/2TcvcZ0UoA9Gy3aSGzNs6R2L0V0egjPuBf/qKvz3QDpQxtof+tPGKkIgbE3EJx8B4LsoHnuQ3bLnijjqOiDs6ovktOLpWukD64lW7sTb/8ptCz8OUUXfQpPnwn2fV3yV5SSrpiZRL7VMHNwHUgy7h4jAMjV7iJ3yhbSQZSIruoIvAdfRGrvCrSW4/aYgCTbYiwb5+HuPa6TuM6FGN+zGK9T5pkNtZ3vpQlVIbub4XSyySGLhXb0Ah9VCvuCAv80iiRiWhYZ1cDtlPC6HeQMC4ckUhlwUeR14lUkFAEcDglz2i2cXPcKkVVPU3rDd3BX9MbTbzLxTfPxj7ya4JQ7OwXmrm5DiW+0Z45FxWnbk+5bQe7UPlw1F5QyslvM7/nDO6rjG/MCbBV3/Yro8sdJbHmF7IkdlFz9NfzDZ+LqbnfMtb/xO9L7VlB0+RdQwlUEx9+Ed9AMoiufJL5xHsmdbxEYfxP+EVefE8QKsoJ34DS8A6ehJ9pI719N+sAaImueJbJmDnWeIP7uw/B2G4RQ3h9nSTcESaY9redV5k+HhQrQltLQDBEZu6XcBFIq5DQdr0sgq+moHW/IqKCbBmrbSXL1B8id3E32xE6MpJ07c1T0JjT1Y3gGTD2vANtpLEMjuXMxsbXPYSTbcfedQHj6J1DCVUDHTPnbfyF9cC2O8l4U3/owjrIetC3+b1K73sbSskj+EvyjrsFVPeAf2hPo0Ua0luN58VbLMomtex6luCvuPuNJbn8DPdZE2aWfxdJzxNY+h7N6IP6eo9GB+LrnsUyD4KTZCIBpWQypKabC72Tj8XZSqkFta5xit0h18F+ji7HAvy4fhmr6anhXfS0sy/oD8Id/9lwFPtqEPQ4EQUCwLERRwDStdw3G/cNnEt80j+jKJ3F1H44cKMU/4koSWxbaQd2k2zoC82cpufI/OjzHFxIYcz2S249nwBQib/8FteU4jtLu73ptzur+dnV81dN2dfzoFopn3pefe+qkiHrlf+AbfDGu7iOILP5voiufJLV3OUWXfR5Xl8F4eo3B3X0EyV2Lia15lqZn/xNn16EEJ96Kq+vQc4JpQRBxdRmMq8tgwpd8BrXpKJmjm8ke305i2+skNi8AQPSEcJR2QymuQQ5VIgVKkX1FiO4AotOLoLjygmmWaWDpKmYuhZlNYiTbMRKt9gLVfgqt9QR6tPH0FaCUdsM/4krcPUbi6jrkgq3np9HaTxFb/xKpPUsRRCmfDJE6BFL0WDPtSx4hc2g9Skk3Sq9/AGdVP7R4K5G3/kjmyCaU0u6U3fggeryZ1oW/wjtwOqLTg7vP+LyyumUaxDfNo+ym71F02RfIHN5A9vgOQjPuxVnZ11bE378Kd49RiPlq+hlv8PS+lfnAGyxiq5+x288G2NYk8Q1zsXJpQlPvetfPe5pT0QzfmjmAF7fUdersUCSB8T2LOyWbCirpBT6qFPYFBf5ZLMtCsCwEAZJZHZcs4pJFaorc9Cnz0ZTI0bvMy5HmJK1JFTOroyoeQmNvoH3V0+RO7cdZ3Z/QlDtJH1xLbN3zFF3yGTsw37zADswn3m6rq+94g8DoWbh7jUGQnaT2LH/XQBzsEbniK76Ip9+kjur4/QTGXEdw8h0UXfpZ3L3G2KNqT36F4MTbCI6/mfLbf0xy+xtElj9B/d++QHDCLQTH3YjsL6bkqq8QGD2LyMq/E13+BPGN8wmMvR7/8Jnn6KyAXSUPjJlFYMwsjFSUzLEtZI9tI1m7k9jelfaLJBlHSTeU4i7I4SrkYDmSvxjJG0J3+VAdbnJZByU+EVU3Sac1FC1DLJOkNR1DzLSRaW9CjTRgtNeSaa7Ni5uJnqC9L+k+AnfP0ciBknOusdP3qaskd71NbP2LGPEWnNUDKZn1zfx9tkyDxNbXiK56CsvQ7Sr42BuwBIH4loUkty9CkGxnFclfTOvCX+EbetkH3hMApPYuB8iLrqUPrO3wBr8fS1fzgber52jiG17CSLZTNuublAUkWk+dIrHjTfzDLs8nHEQsNh1vZ2BVkKBLImtYVAW9VBUFaErk6O5+9/1SgX9vPpQZ8QIF/lm2nIjw/YV7MMyO0NuyUGQRrSOY8bkkElmj03sEWSE0+Q7aXv8N6QNr8PafTHDCLSR3Lia68knKbvjOOwLz6wlOuo30gdXEN80jPPVjeAdMJbL0UZI7F+dnx98NUXESvuhe3H0n0raoozo+7HLCM+6h6JLP4O49jrZFv6Xx6W/gH30toSl3UXr9A3bV9+2/0DTnW3gGTCM8/W47aTB8Jt5BF5Hcvoj4hrk0P/dtHJV9CIy5Hk/fieeojIMdlDsreuOs6A0donRq0xFyDQdRm4+itZ4guXsZlnr+OeX3RJJRQlU4ynvjG3IpjoreOKv6ndMqdz4syyJ3aj+JzfNJH1wHooR/xJUExt2E3JGJNjVb3T6+YS4IEJp2N/5R1yDIDhI73iSy7DEsXcNR0YfwxZ9C8hcjB8vsVvUtCwlOvBVBPjNb7+4xkuyxrZhaDtlfjH/ElfhHXJl/PntsK0ayHe/gi+z/fuK0N/inQRCIrnoapawHngFTSGx9zW4/u/kHZ9rPtizEO2j6eyZqTnOy3b7vN42qoTVhb1hK/E5uHFmTD7gdHX/XBZX0AgUKFDg/mm6iGiaSABVBF1nNJJ7VOBXNEPI6qAq7qW3L4FYchNyQVnUUQaTLlBuIbX2VyMq/22NrxTX4hl5KYtsi/KNnnROYO7sMJr7+JXzDrkB0evD0nUBq30rCF92LqJy/rfqduHuMoOrePxBZ+rd8dbx45n22K8q9f6R98V+IrX6G9MF1FM+8z05o9x5nq7CvfobU7iWEZ9yLu894HOU9Kb/5B2Tr9hJbPYfo8seJr3sB34iZ+EdejewvQcKuXr8TyRvCN/hifIMvRrIscvFmcqf2ozYdQW05TrZuD8beFVyoz/DIe3xGyVeEp7QG3/Ar7E68qn7I4ar31YFnpGMkt79BfOtCzFQUR1U/ii//Iq4eI/Pvz57YSfuSR9BajuPqPoLwJZ/BUVyD2naS9jf/SO7kbkRfEUWXfR5PrzEIovQP7wksyyS1622cXYciB8o6RNnmdHiDTya+cR5Gsp2Sa7+BmU0QW/8S7l5jcNQMojVh0rjiaQRRpnLqbWiASwBREGmO5wh7sgyu9CM5NIo8Tkp9DpI5nZxmYFoWDllCKhiFFziLQiBe4F+C9Ufb8kE32L6UI2uCbDpuq4yfHYSfxjtoBvENL9tKm33GI3mCBMffRHTlk2Tr9nQE5m/lA3NP/yl2YD56FpIniKfPBFK7lxKa+rF3bbF+J66aAVTe/bszs+NHNlF02Rfw9BlH1T1/JLL8cRKbF5A+tJ7iSz+Hp884XN2HEV//ErENc8kcWo9/zHUEx92A6PQSGHMd/hFXktz1NvFN82h95ecd/tWX4xt6GXLgwkrBgqzgrO6Ps7p//jHLsjCzCYx4q13pzsQxcyksPZcXKxFECUF2IDo8iG6/LVTjL0byFX1gH3Mzlya1byXJ7YtQm47Yn2ncDQRGzULqsA+zTIPUnmVEVz2NkWjF028yzi6DyNXuIrriCXJNx1DrduPsMpjiK75EfNM8sse24uqY6QqMv5mWeT8iOPFWMA1MUyO58y2SO9/C39FWeD4S2xchekJ4eo+1LcpWPY3kK8I//AoS29/omPv6PpaaJbbmWZxdh+LqMRKA2Jo5dvvZ5Dve970Y37OY2x9Zh2ZYKJLAs5+e0KnifVolff3RtsKMeIECBQq8C6Ig2LZehoUo2JZPIuCQJTKqjlOG1mQWVdeho5PO4/FSPX02ta/9ieyxrbh7jiI4aTapPcuJrnyS0mu/YY+tvSMwb5rzLZLbXiMw9gZ8wy4ntXc56X2r8A299P1dp9NL8cz77NnxN35P0zPfxD/yKkJTP0bptfeT7j+J9rf+TOOTXyUwehbBybMpnfVNMsMuJ/L2I7TM+xHOrkMIT/8Ezsq+uGoG4rrth+TqDxDfMLcjef0ynr4T8A2fiavb0Auu05IgIAfLkYPl+Sox2BVpPd5i7wnSMcxsEkvNYBkalmUiCqKd/FdciC4foieI7CtCCZTaj3GuavqFOJ2UT+5YRGrfKjA0XD1GEhh3Y6euP631JJEVT5DpEKDzj56FHm8hvXcFKSziG+Yiyg6KZ/4HufoDaI2HEfqMB/7xPUH22Db0WBOhqR8D7NlwrbXWDrxzaWLrX8TdawyuLoNpX/oolpohNO3jACSbjpLYu4LySTdTXFaKiElWtVB1k6BbIuhx4PUqhPxOMqqBKIiUB11ohoko2K3+HodcmA8v0IlCIF7gX4LxPYtRZDHfzqtIwvsSbRNEidC0j9Hy8g9J7lqMf/hM/KOvJbH1VSLLHqPizl8SHHcT0VVPka3bQ2jybNIH1hBb/yJFF30S/8grSR9YTWrvcvzDLn//F2waOKsH4I42ojUfo+Xlh/H0nUj4ks9QfPkXbMX1N/5A80vftx+/+JOEptyJb+hlRFb8nfi6522RmHE34h95FaLiwj/iSnzDryBzZDOJra8SW/McsTXP4eo2DO+g6bZq+vuoTAuCgOQOILkDUN7z/X+mD4Bl6GRP7LA3LAfX2qrpJd0ouuzzeAfNyKu7W5ZJ5uB6oqufRmutxVHZx/YFdXhoe/MPSMFyEltfBVEmMO4mQtM+hiCI+EdeQ+vCX+AbdgWSN4Szsg9yuIrMsW04KnoRW/MceqKFkiu/guMCn1GLNJA5vInghFsQJIXMkU32bPhln8cydLv9rOsQXD1HEV31FGYmTnj6JxAEAa3tJMmdi/GPvOpd590EoH+Fn0ha5brh1cRzel6oTTUs5m6tOyfYLqikFyhQoMC7o8giggDlPhdZ3cSlSLTEM+xtiAMCTdE0siQyoVcJkWSW5mQWRZJoiKbIjbyM+tVziax4AlePEXZldMx1xNc9T27M9QQnzya1dwXRlU9Reu39uLqPINZRFXd2GYxS3JX4loV4h1zygYImZ1U/QtM+TnzdCyS2vkr60HqKLv0snr4TcXYdale4N80jtX814Ys/iafvRCrv+T3J7YuIrp5D45NfxdN3IsHJd+Ao7Yazqh+l1z+AFm0kue11kjsXkz6wxhYaHTQD74CpKCXdOl2jeoFrE2QHSlE1SlH1Ba//3YLt9xOEa5F6UvtWktqzvEM13Y1v6KX4R16dVyeH09oxz9mja4qT0NSP4ew6lMjSv+LuNZb4hpew1AzOrsMoueZryL4iHJV9/+k9AUBiy0JEbwhPv4lYhk509TMopd3x9J9sd+Tl0oSmfRw91mQLsg2+KN8RF1n+OKLLT9WUGwm5FOI5k+4lCpYg4HcqeBSRXmV+ehT7SWU1ZFG0Ld06kkS6WZDAKHAuhUC8wL8Eo7qFefZT43l5ax0WcOPIGp5ad/xdbcxO4+49Dmf1QGKr59hzxA43wcl32qIo+1fjHz2LxLbXiCx9jIq7fol30EW2tdnoWTi7DEEp60Fi0wJ8Qy89J8tsWSbJnW+TO7kLweG2g+J+k4gsfRTL0HCUdsNMx3B1H0Fq99tkHt1GaMpd+EdeRdUnfk980zxia58nc3QLgXE3Ehh7A6XX3k9u7PVEV/zdXpg3zstXxUWnB0/vsXh6j0WLNpLatYTUnqW0vf4b2kQZV7dheHqPxdVz1LsGiP8TGJkE2ePbyBzZRObIJsxsEsHpxTtwOr4hl+Co6p/fEFimQfqA3f6ntRxHLqqhZNa38PSbBEDrwl+gtdaiNhzE3WcCjrLuWFqO02OljtJuOKsHkNjyCr5hlyOHKpB9RcjhSnKn9oFlIspO0ofWI4fK8/Pf7ySxeQGIEr4RV2JZJpGVTyKHKvANvZTommcx0zHC0z+BkWgjsWkBngHTcFb2ASCy4u8IipPghFvf9Z58ZmpPvnXlGdX9b8/b1en5Qt67QIECBT44giCgyBKKLBHoeKyuLU1NyI0gQFMsQ9CrIFoCzfEMWAKVISeHmuIEvB5Kpn2Mxld+QWrPcnyDLyI47kZbvXzZ3yi//Sf4x8wivu4FcmNmEZp6F41PfpX4pnmEJt+Bf/S1tL/5B3K1u3B1G3rOtaktx0luX4SpZlDC1fhGXInk9uf3BZ5+E0mLEmYuTcvLP8TdZzxFF3+a4iu+hHfwJbS/9Uda5/8EV7ehhC/6FP6RV+MddBHxjfOId4x2efpNIjjhZhzlvVBCFYRn3NPRVr+O5O4lxNe/RHzdC3ZLde+xuHuNwVndv5Mt6gfl/Va8T2OZBmrDITJHN5M+vAGt+RgAzi6DCY67AU//KZ0sV23tmBdJ7VkGgoh/1DV5J5v2ZY9hpKLEVj2FFCjF3X8KosuH5O0Y6brAnkAp6YqRbEdw+RBzaRLbXsc7+OLzuuForSfJHN1si6xJit0VF6mn9Mbvosea7cB7yMU4SrvT8srPEQQp75aSOb6d7PFtFM24h6JwGAETp2gRcsv0Lg8yrlcRvcoDdC2ydXC8Lvt7ME2LjKajGSZOWSpUwwucQyEQL/Avw9mVwt+8ffB9vU8QBMIzPkHj0/fbaqiTZ+MbcjGJLa8QXfEEnj7jCU25k7ZFvyO9fxWhKbNJ7VtOdNUzlFz1ZQJjb6Dt1V+RObwJT59xnY4dX/8S0ZVPIvmKsbQsmCaSy4+RjlJ06eeRAyW0tJ/CWdWPwNjraXvrT0SWPEJy12KKLv0swQm34B00nciyx4mtmUNyxxuEptyJd/DFlN/6sD0LtuZZoiueILbuBfzDr7BnwYJlKKGKDtX02agNB22F1EPraV/8ZwDkYDnOLkPs1vTKviglXf6pRfidWJaJHm1CbTyUV0hVm44CFqLLj7v3WDx9J+HuMbLTfJaZS5PcvYTE5lfQow3IRTUUX/VV21tclMie3E1k2eOoDQcQHB7KbvsR7m7DyJ7cTWrvCrLHt+HuaA0PTryN1N7ltL/9F/RIPYgS9Y/fB2oaJAXZV0Rq30pydbspv+3Hna7fSMdI7lyMd+B0ZH8xyT3L0JqPUXLN1zFSURKb5ncE3n1pfe2/sCyT8DS7VS1bt4fMofUEp9yJ5A2d9/44JIFp/cq4dFDnZMgNI2t4btNJDNNCEoW8XVmBAgUKFPjnKA44OBmVyGkGVWEXRR4niZzBiC4hdp6KcbAxQTSl4ZLB038Kjo3ziK58Ck+/SYhOD6HJd9D+1p/s3/dxN5Hc8SaRpX+jfPZP8fSdSHzjPPwjrrRVzFc9TWz9i+cE4nqilcan7wdAdPrsZP+oa2wF8XfsC7T2U7i6DcfMxomunsOpv36W0MRbCYy9nsq7f2tXwVc9Q8Pj9+EbeinBybMJTbkD/+hriG+aT2LLQtIHVuPqPoLA6Fm4eo5EkB151XQjFbFdVA6sJb75FeIbX0ZQXDirB+CsGYizqh+O8l4fqs2pmU2iNh8lV3+QXN0esnV7sXIpEESc1f0JX/RJPP0mdRqnsyyLXN0ee4zv0AYEWenQjrkR2V+CkYnTvuSvJLa8ClgEJ99BcNyN5BoOvvueINaM6PLS9OKD6C0nAPL7tOSuxVTe/dtztF1iG+YiyE78I6/CVLPE1syxOxt7jaX1HYF3rv4A6X0rCU64FdlfYifylz2GFCijYvQ1xDULtwkhjwPVhKBHxjAEElmToy1JqsNunLIE2L7tTknENE2EgilEgfNQCMQL/Msyc3Alqw61dnpMALoWeTjR3lmMzFk9oGMhfRnf8CuQfUWEZ9xL8wvfJb7lFQJjriO+ZSGR5U9Q/an/JjDSXuwCY2bhHTDVXnTXPY+799hOGcvUvpU4uw6l/LYfAWBmkiR3LMJR3hs5UIKpZuwfe8tACVdSfstDpA+soX3JozQ9802kYAWh0/NgI68iuuxx2hb9jvjGeQQn34Gn30TKb32YXMMh4htfJr5pPvFN83H3GoN/+BW2oIko4azqZ7e9zbgHvf0UmePbyJ7YQebIRlK737YvVpRRiqpQimqQQxVIgVJ79tsdQHB67Jmp04G6aWDpOcxc2lZNT0Ux4i3osSa0SD1a20ksNWPfc9mJo6ovwUm34+4xAkdlXwRR6nT/1aajJHa8SWrPUiw1g6OqHyXT78bTZzyCKJFrOER09dNkj25B8hUTmv4J9EgDVjYFgFJUjRwsR481A/biLfnCuLoOIX14o63gLsl4+03GO/giXF2GIMgKsXUvEF35JEY61mnDEd80H0tXCY67EUtXia58Ckd5LzwDptL26n9hWRbhaR8n13iY1O6l9qYgWI5lWUSWPobkKyIw+roL/m2qhsXivU2sOtTCM5+0Z9bWH20jkdHygoOGaXGgMVFoQy9QoECBD4HygJsx3UUiaZVxziKyhkk0qSEIFsdaUyRzOgGXSCRlYgoiZRfdQ92cB2xRr/E34Rt2OYktC4ksf5yqe/94JjA/uI7QtI+TPrSe6JpnKb7s8wTGXGfrlzQczKttA2SObMJSM1Td+2eUki6YagZBlMnV7z9nXyAIEBx3ky0Mu+RRoqueIr5pHo7yXvjH3kDVpx8huuZZkttet8fjRlxFYNyNhKd+jODYG0hsX0Riy0KaX/q+bbE5bCa+wRcheUNI3nBeiMzMpcme2EH2xHaytbuJrZ7DaWE2yVfU4aRShRwss7Vg3EFElw/B4bYT6YIIlomla1ha1t4TZOIYyTb0WAt6pB6tvQ4j3pK/D3JRNd5+k3B1H46r+3B7FO4dmNkkqb0rSGxfhNZyHNHlJzDhFgKjrkbyhjFzKaJrniW+cR6WlsXTbxJCh8r76Tb68+0J/MMuR2s/Ra7+AHp7HUpZT1ulvs8EZH8xRiZO3e/vJH1gTadAXI81kdq7DP+IK5E8QaId9mkls76FWr+f9P5VBCfejuQrpnX+T5G8YQLjbgQgtWcZWvNRyq/5OopbwYGAgIVlWgTcCmGvm4ZYjpKAHYC3JnNUBtxYlpX/J0mSrd9jmojiB9PhKfDRphCIF/iXZfa4rtS2pZi//RQtSRUsCwSB5kT2vK8PTb+b9OGNxFY9bauW9hiBu9cYYmufxzf4YsIXfZLm575tB+ATbiG5azGRZY9RfuvDBMffTPubfyB7dAvuXqPzx7S0HFJJ+EzLtaGiRxvxDJgKgJGKYqppxI4gUBAEXF2H4Ok3gVzdXrSW47S99muydbspmnEv5Xf8jMzhDcRWPU3rgp+ilHQjOOEWPP0nUzrrm+jxZhLbFpHcudgWMPEV4R04Hc+AqTjKe9ntesU1KMU1BEZdg2VZ6NEG1IZDqM3H0NpqUVtrSR/ZBIb2wW64ICL5i1HClfgGX4xS2h1HRW97Q3EeBXc93kJ6/6p8tRlJwdt/Mv6RV+c9t3On9hFb9wKZI5sQXX5C0+/GP/JqEERia58jc2wL7t5jkLxh249UENHjLUTXPIvWfgq1bg+iy0dw4u34R151ToVa9ITy39NpjEycxNZX8fSfjFLShdiGuRjxZopn3ofacIjU3uUEJtyCFCil9dVfInqCBCfcAkB6/yrUhgMUz7zvHA/X86FqJnO31vHy1rrzahos2t3A7HFdz/POAgUKFCjwQQl6HDhkAcO02FEX5bUdDbRnVIJOhQEVQY61pklocWQVxC5D8fYeQ3z9C4SHXorpCdoJ+pe+T2LLq2f0ZJY/RtW9f8Y/YiaJbYsIjLoG/4griW+YS2zNs5Td9L38+U+vNWKHEKnocKMn2t51XyAHyii6/AtYCHawfGIHuXp7nSm+5NP4R15FfN3zxDcvILH9dfzDr7QFXcffTGDMdaQPrCGx7XWiyx8juvLvuHuOwjtwOu5eYxEdrrzau6fvBADMXIpcwyHUpqNorcfR2upIH1yLmYl/4PstOr3I4UqcNQNxlHTDUd4LR0Xv81baLV0jc2yrrR1zaD0YGo7yXhRd8SW8A6chKi6MTILo6mdIbH4FM5fC3XcCocl3ohTXEF3zLOkjm867J4hvfRVLzZLaY7vCuHuNITDuRpw1gzoVT0SXH0GSO+0JAGJrnwdBIDD2RoxkhPiGubj7TsBZ3Z/Gp75uJ9/H3UB630py9fspuuI+RKfHdnpZ9RSuyj5UjpiGUxYIe2UUWSTgkvG7FCTBIqnrJLIaTkXCY4m2/Z4gYHVUwU///0JreoGzKQTiBf4lmLOhlkW7G5g5uDIfuGw5EeGJdcdRdRNZFOhd7mdvQ4LMBQQvlHAV/pFXkdiyEP+oa3CU9SA84x7qH/si0VVPUXzFl3D3GU9s3Qt4B19McOJtRJY+SubIJnxDLia2/kWiq5/B1XNU/sdSDpaiR06dOYlloTYdIXzJpwHQ209hZhI4B07PvyS5azEYOmU3PgiWQcvLPya14y0yB9bh6T8Zre0kvmGX25ZdWxbSuvAXyCufxD/mOjthMO3jhCbPJnN4E8ndbxPfvID4xpeRg+W4+4y3FT1rBiLItu+6Eq5CCVd1Vki1TMx0HCMVwcwkOlTT1bxqOqLYWTXdE0Lyhs4bcL/zmFrzMTJH7HkwtcEeHXBU9iF8yWfwDpyO5Pbb8+GH1hPfOI9c3R5El5/glDsJjLoWweEmc2gd0ZVPo8WbcfcYSXTNc4Sn3oWl5TBzKdrf/COZo5sR3QHb3qxjdv58nG5ZP63ODh3e32qW4MTb7Jmztc/b96zbUBqfuh/JGyY47ibS+1eTq9tL0eVfRHR6sXSVyIq/o5T1wDv44gveh3ciCHaXxmlv8LOX2JmDKwH7b7mglF6gQIEC749EVuN4awqHLNCjxI9DtoObtlQOw7CwLFiyt5nSgBPdgtZElnHdQxxojOOWRUzdRDMgPO0eUo99geS6Z/Bc/HlcPUfh6jGK6Nrn8A6+iPBFn7I75zbPJzhpNsndy4gs/RtlN3+fwNjria58Mu9JDnZQDfbaL3UknN/vvkDyBKi85/dkj2wmsuwxWl/5Oen9q3H1HosebyUw7ib09jp7zd+6EO/AGXbX3sDpeAdOR2s9SXLXYlJ7l5M5vBFBduDqMRJ3r7G4e45E9tte3qLTi7v7cNzdh3e6p6aaxUi1Y3aopptqFsvQ8kUOQZIROlTTJU8QyRu+4Np7GiMdI3NsK5nDG8kc3YylZhDdAfzDLsM75FLbahVbPDW65RWSOxdjaVncfcYTnHgbzore6LEm2hb9ntTuJTjKe3baEwiKg+TOxSS2LQI1bYvZTboNR9n5BdmMRCuWriIHy/KPaZEGkruX4B9+BXKghLZFv8PSNcLT7ia1Zxlqw0GKr/oKCAKR5U/gKO+Fb4i9B0humo8Wb2Xw7d/ikkFVpLIaQ2sCeJxOLOH/sXeWUXJVWRt+brl7taUl0nF3D3ElCSQBAkEHBhhcBrfBB4dhcNcISYh7Qtzd3dNeVV3udb8ft1MkJEAY+TKQetZiLbpyu2vf25Wcvc/Z+31FirP0VPqTtM0yEqrJVR3GHxXbBUFAVmNZJghCphDPcAaZQjzDeefbNUd5tEbk6mQr+lUdC1l90JUucJIpkRLv2U/CT8XcZTTB7YvwLPqYrCueQ2kvwNjmYvzrp2FsPRhrrxsp+eQvVC/5HPugu/BvmoV70Sfk/emfWLpeKXmS712JvkZUTJ3XGO/q70iG/ci10my4TGOUdnYjAUIH1qKw5p7WAhXasxJL96vT3tkKcxbqWo2Ie0oIbJ6NTGMgmEwgKFQ4L3+GeNk+fGsm41nwAdVLv8LQvA+GlgPQNeyCrmEXkmEfob2rCe9diX/TLPzrpyIoVNJseH4zaRbsJzvUgiCraV+z/Mu/l1Q8Qrzi0I/zYMe2p3fUVbn1sfS4Vjp1tuYB0imAd/VE/Jtnk/SWIzc5sfb+M4aW/RGUGsL71+Jd8S2x8gOSeNvge1Bm1cW34ltOvH8TqUSUVNAjnZxfdB3GNhefJvRyNiJHt6DOqZ+ejZe8v2egb3IRKmcRrjn/RIxHsPb6E8HtP0in3YPvBZkMzw+fosyqk7ao8a2fStJbjn3082e03v8cfRpn0zTPjKxm51suE7i4RS6uYCy9qbThiIcxH68mlkihUsj45qZOmWI8Q4YMGX6GVCrFtuNeZAJUB6UcoHGumZQoWZlplXLiyRRWg4qy6gjRRAqjVkkCgURSJMuk42A0QAxQOAqwthlE5YbZ5LYYgspZhK33TZR8dgfVS7/CPvAOaYN+5Xj0TXtj7nIF1Ys/I3xwA8a2Q/Gtn5r2JBcEAVVNQR45siXd+fVb8gKlKQtl68GED24gFY8QPryJ0N6VqItakPCUkIqGyB7zMsEdiySx1m3zURc2x9hqMLoGnSTRtouuI3p8B6E9KwntW0N432oAFLZ8NAXNUOc3Rp3bQPL6PmUtk6k0yFR5ULNm/1bS2jGle4me2EXk2HbilYeln62zoG/UHV3DrmiKWkqn0qkkof1rCWyeTfjAepDJ0TfpganDCFTO2iS8Fbjm/pPA1gUgCBjbD0ff+CL866dy4qNbQZAR3L1cOgFv0BlL16tQZdX5xRgjR7YC0rjiSaqXS97fps5XECs/KDmitBuGXG+levHnqHIboG/aC++KsST9lTiG3o8gk5Pwu6hePZFaLbszqH9fajs0VHqjhJIiymSSIocBm16LTIgRjIuIgNOoQgQC0QRyAZRyGXK5DLksU4RnODuZQjzDeWf29tIzvr6qYyGd6tpRKWTEEymUChk9GziZsrkkfZ1aISP6k3ZgudaIudtVeBZ8QHj/WnT1O2LpeqUk8DH/A7LHvISp/aX4Vn+HodUgrL1vonLSM/g3zpAW3TWTqV76JbrijghyBdr6HfGuGk943yoMLfqjctZGbnJy4sM/o7Tlo7QXYGo7PP3+sfIDCHIF6lqS/7WYShKrOkrWRdehtNUifHQr3hVjiR7dBnIlnoUfYh90NzkNuhDau1qaCauZC1PlNsDQrDe6Rt0xtuyPsWV/UrEIkWPbiBzaROTYNrwrxvLjLJgdpaMQpS0PhTkbudGJ3GCVEgS1HkGp/nFRFlOkEjHEaEjyHA95SfqrSHgr0jPiCU8piNLzlZuy0NZrh6aoJdrabX70B0/ECO1ZSWDHIsIH1km2bgXNsPa8QWqTE0VCe5bjXT1RUk+35GAffA/6pr3Ss+PJaJCEtwyZWp8+Of+1XXiAhLeCWOm+tB8onOL93f1qYuUHCGyZi7HtUOQGG54ln6HKbYi+WS+8y09fcJMBD95VE9AWd0Rb1PIcP7lQ16HnmRk7SKakRTiZEpmzo+y0YvvUDaV4IsXqg65MIZ4hQ4YMP0MqhVRo65REBIFIPAmAXCagVcoIx5OIwHWdarNgdznxhEjnenY2HHFhNagRBFAqBEhI/y4bu47Bu30x1Ys+xnn5MygdP27QG1oNrNmgv43qxZ9hH3Q3gS1zcC/8iLw/vS11zi34ID22Jllp1Se0Z0V6pOlfyQvi7hNkjXwCmcaAd80kAptmIsZjyA1WwgfWYe9/G5buV+Nb+z2BbQuomvYSMq2pRqytJ+qC5mgKW2DtewvxqiOED24kenQrwV1LCWyZA4CgVKO0F9bMiOegMDmRG+zIdNKMuEypqRFbFQARMZn4cUY85CMZcJPwVZCoLiPuPk686uiP2jFKNeq8Rui7X4OmThtUOfUQBKlrIV51hOCOxQR3LCIZcEtdaF2uwNBqEAqjnVjVUapmvkFw52JAwNCyP+ZOl0vz9dEgcksuyQPrEKNBtMUdsHQbgyq73jl9dkJ7liM3OlDWFOyx8gOEdi7B1GkUcr2VqqkvIdMasXS9Eu/KcSSDHpwjHifpq8S3ZhK6Rt3RFDQDoHrpV5BKMOLWhyiy6YjFUpi0KmKJFDaDCqdRjU4pJyiTo1VKdnuuQAy9WkQuQDiRwKFXI9R0cSrkmUI8w5lkCvEM552firKdbOdtW2Tlm5s6ndbSm2PSMGdHGQNr1KrfX3rwjJ9nbDWIwKZZeH74GG2dNsg0Biw9rsM95x8Edy7G3PkKgjt+wD3/fXKufV1qU1v+LfomF2HpeR2Vk57Fv1maE1PlSF6Vga0LMLTojyBX4Bh8N9HSvaTCfrR12xLctRSFNa+mxaoCVVYdxJTUAh7as0ISS7HVQhRTkEqhMNhRdxyFb933hPeu4sShTaiy6iI3ORGjQeR6K5qC5kTLD+Ce/z7uBR+iKWyBrkFntMXt0dWT/gNJpTxato9Y2QHilYck/+vt0gzVv4RcidKSg9JRiL5RD1TZdVHlNkif7kvvGSS4ezmhfasI71+LGAtLwibthmNo3g+lo4Bk2F+j/DqDpL8Spb0A+5B70TfpKamnH9+Bd+UEIoc2SDPg3cZgajfsrDZkP0dg23xASLfkxyoP13h/X4zCnE35Nw9JC263q6he/i2poBfbyKdIeCvwrpmIrnGP9ILrWfqF1KrW60/n/P4yYEepj2g8xclhCZEzi+2fbih1qmv/uR+ZIUOGDBc8CoWM2g49hyqDKOQCDXJ+7Pay6FTokyIyARRyGdd0/vGE1KRVsqPEz/7KACa1glAkThwQtCYs3a7CvfAjZEfXY6rfnlSNl7hn/vtkX/0y5g4j8K4aLxXmvf8sbdDXjLn510/Ds/hTNHVaI8jk6Jv0xLPwI2IVB1Fl1f238wJtnTYkveWkEjEiB9ZJfufHd6AwZSEo1SjMWSSQnFL8m6SNeoU5G13DrmiLO6Cu1Vg6fe84QiryXceIle2vmRE/SuTodpI7FgNnH+v7NeQGGwpbLQzNeqN01kGdWx+ls3Z6Y18UU8TK9hPet4bQ3pXEXcdAkKGt2xZD31vRFncAmYzIoU245vyDyMENknp568GYOoxEYXKQDHnxLP0K/8YZ6QLc3PWqdGv7uZDwuySr2A4j0psCnh8+QaY1SeNou5YQPb4D24A7SATc+NZPldrn8xpS8f3zIAjpHCBaupfg9gXkdh1FTO+kQY6JrSeqUQhyEESOuYMEY0mMGhV2nYpcg45EKiW1nyOdhIeTwr/4xDNcSGQK8QznnZMz4T+dEYczLc0eHtz4NN/mMl+E2dvLTjsZF+QKSZjtu6fwrZ8qqaW26Jv2ENUVd8Da+yaqpv6dwObZ2PreTMknt+P54VPsQ+5DU9RC8iRv2gu5xoCx1UA8P3xKrPxAelf2VBVVpS0fMREDpJbtwPaF6Vlr/4bp6Jv3OxkZ6lqNiBzdSvT4TnKvfV2aSVryJZGD6+EEaIpaYWw9BFVOfexD7iFeeYTgrqWE9q6UbMvmv4fSXoimdks0Bc1R5zdGW9TytFNcURRJRYMk/VWSaEzYRyoaOm1GXJDJERRKBJUWucaITG9BbrBJc+I/8VJPxcKEj2whemy7JDJzYrfk4601Sa1ojbpLFi+CTBJnm/kGod3LEBMx1IXNsfW/FW3NxkF4/1p8ayYTPbETmc6Mpce1Ugv6OZyAn4qYiBPYPAdNnTanqJ1/InnIdx1NcMcioid2Yht4FwlfJf4N0zG0GoA6tz4Vk55BkMlPX3C3LcDUYQRKWy0ANAoZkbOIr6U/Y4BKKaNprum0TSQBzii2z7ahlCFDhgwZfp4iu55csxYZUmF+EkEQUCnOfrJYy6rj7r71OVQZYva2E6w84OaIK0RcBHPrIQQ2z6Zk3ofoardCptZj7Xk9rllvEty2EFPnywjs+AH3vHfJvf4faOu1p3rFWHSNe2DteQOVU16QOqxaD0bftDfVS77Av2EG9kF3pd//P5EXZF/1ErHSvXjXTiZ6bDsygw1jy4GocusjU+tQ2gsI71tNcPeKtG2ZTK1HU9QSdWFzNAVNUTqKpMK8+Y/PRkzESQRcJANuKSeIBBHjYUkpvUZYDLkCmVKDTKNHpjVJOYHBLjmunIIopki4ThA5voPI0W1EjmwmFfJKNmYFTbG1uRhdw67I9RaSQQ++9VMJbJlDwlMqnY53G5NWL49Xl+Ge/740O56IoWvQGXOXK875BPxUAptngyhiaNkfgPD+NUSObMXa9xYQpHE0VU4x+uZ9qZjwJDKlBmvP6wkf3CAp5/e4FoXJiSimcC/4ALneQpth13NRfQfti6zsLPFy1BMgkUiSSkEsCTuOVVPboSfHokGjlOM0aAjFEkSTKbQKOXIEBEHq5siQ4WxkCvEM/xNc1bHwN6tLbzjiYc6OMuLJFDIBTtVw09Zti7ZYaivXN+2FwmjH1u8vlH15H9XLvsba52Y0tVunPUZNHWra1VsOwNr7Jko/vwfv8m+w9b0FQ4v+VK8Yi3fNZJzDHjgjDlX2j6IhMo0BmcZI6Wd3obQXoC5oirFlf8RkIi2EYu1xLRWTnyMZrEZbty0qZ1FaXCR8cBORI5sR1Ia0X6ilxzVYelxLwn2c8IF1hA9tIrBlHv4N0wGpbVyVUw+VozZKRwEKq2RTkl6Mz5FULELcdZxEdSlx1zHilUekHXXXMalFXZChyq6HqeNItPXaoc5rhCCTE3efwLtqAsGdi0m4TyCotOib9ZY2FLLqkIqFCWyahW/9NBKeEmQGB6ZOl2HqOAK5xvibfucnCez4gWTQg7291P4X3r+WyOFNWPv8+ccFN68h+ua9qfj2kXRXRGj/GsL712LpeUPaH9Q9/4Oa1rnR6Z//S0U4AAI8eXFTPKFYTVOfRLf6Du7p2+CMYvunG0oZMmTIkOGXUSl+m81TOJZk7SEPiZSIP5pCJf8xL0jJFVj73Ez5hCcpWTUFc6fL0DfrLXmJL/4Mbf1O2PreQuXkZ/Gtn4q1z82UfHIbnkWf4Bj2IOqCZlQv+xpdo+7ItUb0zfoQ2DYfc/erURhsZ8b+L+YFYjyCqcOlJIIeIkc2AzK8K74FBNSFzTA07Y2uQRcMLfqTioYkobSDG4gc2Uxo70oABKVG6mbLqovSUYjCVgulJQeF0YHSknPOz1NMJUkGPMTK9hF3nyBedZRYxSFi5Qck/3BArreird0aTd22aOu2Ra41kYpFJGvVHT8QPrRRGlnLb4K561XoG3YDuYLo8R24574jqasjoK3XvqYF/ZdnwH+OVCyMf9MstMUdUFrzEBNxPIs+kXSCWg3Cs/gzkgEPzksfI7R7GdGjW7H1vw2ZSot7wfsobLUwtb8UQNKTKdlD9pB7yMty0CTPSqxmVCIUVeMJxSn3BklUJTFrVbj9Mex6NWZdjYCuTBpTUypkGauyDL9KphD/HeD1etm2bRuNGzfGbs+0tZ7k1Nnbs/1TZ+19k7SQLv4U59AHUOfWx9B6EP6NMzE074ut/1+kk/CFH2EffDfBXUtxz32H3Bv+gaHVQOm6Fv1RZdXB2GoQvnVTiHe9EqU9/6zxxCoPEz64AYXJiaJ5X+RaE4bmfUlFApIlR1Er5AYrYiJGKhJApjOTDHrwb5qFsfVgxGQCpS0XQ4+rCR/eRHDbfAKbZiI32NE16IS2XgeMbS7G1GEEYjJOtHQ/sZJdkk1JxUHC+9akZ7oBSRVdZ0GuNUo+4go1nFRFTyWlOGJhaUY86D2jnV1usKHMqoOuQWfUeY1Q12qETGOQ2tDKD+JdOZ7QvlWSdRmgLmiGueModI26ISg1xKuO4J7/Pv5tCyEertkgyAGFgljFQfwbZ2LpMlqyYPOUECvbR6K6HBERdU59yUP9LOImYjKBb/V3qLLroandmlQ8imfhhyjthRhbD8G94H1SYT/2y58luHUB0RO7sA++B0GhxD3/A5T2QkztpAI+uG2hJOA25N7fdCoviuAJxbDqVKe1ng1qlnvBFNxLliyhdu3aFBUVne9QMmS44Ni4cSMqlYpmzZqd71D+ZwjGEsRTItlGDcXZRg5VBrAZFETjCXxRJIXxk8JsTXqhMDmw9/8LJZ/fQ/XSL7APuEPawF/xLfpG3TF3vhzv8m+INO+Lre8tlH5+N9XLvsbe/y+YOowgsGUuvjWTsPX581njSUVDhPasIOGrQOUoQFOrMTK9GV1xx1/NCwKbZ6fzAkEmR53XgPCBdbhmv4Vr7jtoardEV78T2rrt0TfqBkhe2ZHjO4mV7JE0UrYvTM90SwjIdGbkOlONbozmFB9xETF50kc8SCrsIxny/iSnUKN0FqJv3B11bkPUtRqjsNVCEASSwWrC+9cR2r+ayMGNiIkocoMdU/tLMDTri8KejxgLE9g2H9+mWSQqD0sK7TozMpUWMRUndGBNuhBPRYNES/YSrzpKKhJAbrRJp+w/8Ss/iX/TLFJhH+ZOlwHgWzuZRHUpWVc8R7zqSLorTmHNo2LSs5IGT6uBeFeMI+EpJevyZxEUSlLRIJ4ln6Gr1ZA2PQdhNajwRqLsKY9L9rQqOQVaJSqFSIk3ikapwGlSk6ppSz/ZXaCsUfm/kNi7dy9ut5tOnTqd71B+V2QK8d8BZWVldO/enYcffpgXX3zxfIfzP8GGIx5KqsMoZEJ659GoVkh+4zUorbnpua9Iy4FoCptj6XEtoT0rcc19h5yrX0kvtIZmvbH3u5WKiU/jXTMJS/drCO1ejnvee2SP+TumDiMkH8/l3+Ac/tAZ8QR3LaVq2iv8dAbL88MnqAubI8iUeFeMQ2nLQ6Y1IVNpUecUEz6yBYUlF2vPGwBwL/gABBnOYQ9Ki/j+NYT2rCCwdQH+jTMRFGrUBU3RFLZAnd9UKswVKgBS8SgJTwkJTykJXwVJv4tkqFqyL4uFSYa8klUJJ1vTVZJHqCVHsiox2KRNBHMOCns+co0BqLEtqzpGcPdyIke3EjmylVSoGqmlrjHW3jeha9gVhclJMhIguHMxga0LiJXuAUBT2AJLj2sRxSRiVPL/THgrqJr+CtHabaia9hIJb/kZz9TQahD2Abef8Xpg2wIS1aU4Rz6BIAh4V39HwltO9ugXiJbuI7B5DsZ2w2sUUT9DXdAMfbM+VC/5gqSvAsdVf0eQK0hGAniWfI66VmP0TXv9ps+fIIA/HGf1QVf6NRlScX4hsGzZMnr27Mnrr7/Ovffee77DyZDhguOee+5h8+bNeL3eC16NOZUSKfNFCMcSaBQCZb4IuSY1DXNNFNr1rDlURTAaJ4m0QV/6yW14fvgE5/CHyMqrS7zDUCrXTMPQrA+2frdQ8vFtuOe9i+OSRwnuXCK1q9/4To096gwMLSRbLn3T3vg3zcLU/hIUJufpMUWDlH5xjyR6+hOkTeQ2iPEw3lUTUFpyfjEvcM1/D1VOfax9biZWupfQ7uXSuNrBdwDS42rqgmZoilpiqFnPRFEk6a8k7i4h4S0n6auS7MvCflLRgDSyFoqnO96k03k1Cmsu8ryGyPRWFCYHClMWClstFOas9OhaMuQlWrKbwNZ5RI5sIVZ+AJA28PXN+6Jv1A11QVMAose2410zieCuxZBMILfkYBtwBwpLNiQTp+UEsXrtCR1Yh3f5t6dtAgB4l39LznVvpG3afnzWIXxrJqGp00ZyqKkuw7tqAroGXdAUNqfsq7+m7VCrf/iUVNiH/fJnSHhK8a7+Dl3jHmjrtAZI68nkXvEUwbhAMBxn+wkvRrUSfzSBQi5HFFM0zrPSIFskEIvTKNuEWqkgmZK0Cy5Uhg0bRiKRYP/+/ec7lN8VmUL8d0DDhg2pV68eEyZMyBTicJodlEIu44oOBYxsk8/8HWVniLel577mv0fu9f9ArjFg63MTVdNfJbB5do2Ax1Jc894l78Z30TXqjnflOPQNu2HteQOu2W9JAmAtB2BqNxzvqvFE21+Sti05iXf1RFTZdcm6/BlkGgMJXyWx0n1EjmwhfHA9SX8VyOQIKjUaewGW7lcDkPCUoq3bBpAKabnBTvWKb/FvnoW+UQ90DbtgaNqLVDxK9OhWwoc2Ejm8meoln0tvLFOgchahzKqLyiGpoyodhWjqtj1jruvXEFNJkn4XieoyQruXnd6GVrOrLjfY0NZuJZ0s1GmDmEoQ2rtaUh8NeQkfXA/JBAp7Aaqc+sj1FrTFHVHXakQqHk3HpDBnIaYSCCoN+qa9kRttqPMapWe03Qs/IrB5NpbuV59my5aKRfCu+BZ1XiO09ToQdx3Du2aipCJbqzGln9+N3OTE0v1qXLP/QSoRxT7gDuJVR/Ct+x59876nKaKmwn5sl//ljLn4X0JAOhF/f+nBtG+4DGlm/EIRYvv6668BuPzyy89zJBkyXJiMGDGCZcuWMX/+fPr373++wzmv7C7zsavUj0wAs1ZJm0IToXgCs07FD7vLcerVVPmlTWilJQdTp8vwLv+GcMsBCLVboek8BvnOFbjmvkPudW9i6X41nkUfE96/GvuA2ykf+wje5d9i6X61tEE/95/kXPMalm5XEty1mOplX+MYcvqGZHD3chKeUpwjn0Rbrx2paIh45WEix7YTObgR39pJIKYkizG5CnWtJugbdwfOzAvEaBjX3HeIle5B17Ablp7XY+n1J+JVRwkfXE/k8ObTx9WMDlTZ9VA5fxxX09XvhExr+k2bNqIokooESHjLiZXuI7h9EfHKw8QqDpKoLpMukitQ5zXC3P1qtHXbIdOZiRzaQPjwJsIH1hHctYykvxKUGhSmLORGO/rGF2FsNfCsOQFyJZpaTaDTZagLm6PKqoNMayRWuo+yrx8gsHU+lq5Xnhand/V3pMK+mg1/UdLTkcmw9rkZ/4bpxMr24Rj6APGKQwS2zpP0YLLqUDH+cQSFCmvvmwCIVRzCv2E65lYDqNOwMTq1EoVcYHeJD7kM6meZaFVoRKtWolTIMGhURGNxcsxaZDKBaDyJTq1AEKRnJ7uAqvLt27ezZ88eHnjgzPHNDL9MphD/nTB69Gief/559u/fT3HxuatI/hE5zV88maKWRfKa/nzV4TOulSk1p8x9TcPccQS6xheh2bYQz5Iv0BZ3wjbwTsq/eYjqpV9i63MzkcOSsmfWlS8S2L6Q6h8+RVevA6aOI/FvmYtn0cdkj3n5tAUtFapGU6dtumhUWnJQWnLQN+6OKIrEyvYR2rOC4K5l+NdOxr9hOrp67dHU60C0ZA+epV8hKJTEq46hqd2aWOkePIs+wrPoI1TZ9dDW64C2blusvW+S7LaC1URP7CRasldSKz2wluC2+affu8YgtaZrDAhqHYJChSBTSMe5p7amR4OkQt4z29CUGpTOIvRNe6POrS+1odX4jyY8Jfi3LcC/ZiKpaBBEEUGpwdh6iCRyZ8rCs+ADlNY8SRgmFj7NEzzuPoFMbUCm0mDpPuaM35umoCmBzbNJhX2nFeK+NRNJBtw4hj8CiLjmvoNMocba+0a8qyYQdx0la9RTRI5sJbR7mTS/Z82l7OsHJXGetEDbPgKbZmFse/Fps3w/x6096mLUKjlRHWbc2qOc7DgTAZkAXYvPPhv+R2X27Nl07NiRWrVqne9QMmS4IBkzZgz33nsvs2bNuuAL8TJfFJtBiVohp9IfxW5Q4auK0zDHSCKRZM62UhRyIAlJwNxxJMHti/DMfw/NDf9EptZh63crlZOfw7fue0wdRhDctQT3gg/Iu+k9DC3641s3BV2j7lj7/JmqaS/j3zAdU/tLMLUdhm/t9xjbDj1N4TsV8gKgrd0KQZAh1xiQFzSTNoK7jCYZrCa0bzWh3csI7VtFaO8K/Buno2/cA02dNoR2L0vnBcmgB3VOcbozTqY1oa3XHm3ddhha9MfccSRiMk6s7ADRkt3SuFr5AclS9NRTZbkSud6KXGtEpqlpTZefbE1P/WhfFg3VtKZXI8ajpzxpAYU1B1V2PQytBqLOa4Qqpz4ypbqmg28t1ZOeIRUJ1AjVCWjrtUPf8zpU+U2pXvz5OeUEKkcLSQD2FNR5DZGptOnnepKEtxz/+qnSZnxOMcGdS4gc3IC1958RUwmql32Ftl57NMUdKfv8LhSWHMzdriK4fRGRI1uw9b8NhcGGKKbwzHsXudZAg8E3IFfI8UdjlPgUiAkRuUJAowqj1cgZ1aYQrUpBIiUSSyiQyWTSiGRNSigV4BdOEQ4wc+ZMAK655przHMnvj0wh/juhb9++PP/880ybNo377rvvfIdzXjmbHdTqgy6i8bMLbOnqd0Rb3AHvim/QN+6GwpSFbcDtlH5yO+757+Ec8Xi65UzXqDvW3jfhmvUmgU2zsA+4g5LP7pSuu/RRLD2ulWzQdvyAoVnv9HsozNkkXMfO+v6CIKDObYA6twGWi64nVrKH4K4l6fYyFCqSQS8KkxNrzz+hypJmbuPuE4T2riK8fw3eVePxrhyLoNajKWiKOr8p6lqNMHcZjUylAaRWsbi7hER1KUlfJYmAm1SomlQkkF4Yf1RNl0mFuVIjtabnNpAUUo0OFJYclNY8ZAYr4b2riRzfQTLsJ3J0G9GV4yQrFH+l9HPUekztLkHboDOuWW+hb9oLdU4x1SvHoSvuAHIFsZI9iLEIqLSIiRiCQoVv7WTUBU1RmLJ+VGw9hbjrOAgy5Ke0oMWry/CumSRtpOQ3xr9pFtFj27EPuotk0IN39QTp/Ws1puST21A6a2PuOBL/xpnSHPjF9yPXmhBTSdzz3kGut6Q7E34JQYB+TXNoW2RlwxEPkzceT1uWCUiCQhdSEb59+3aOHTvGVVdddb5DyZDhgsXpdNKgQQOmT5/Om2++eb7DOa8UWLVsP+FDJE4tqxaVQo5OLccbjDFlcwkalZxkSkSQprhw6Iyk+t/K8QlP4V0zEUvXK9HV74SuQRe8K8aia9AF+8C7KP3iHjwLP8LW9xbCB9fjmv0WOde+LimqL/sKbf1OmLtcQWD7orQN2snuKoU5C4C469hZFcDlegvGVgMxthpIMuiRLEF3LaN66Zew9EuUjiKiJbtRGGzYBtyO0pIjOZgc3FBjHbqG4PaFgIAquy7qgmaoazVG17ArxnbDpXnlRJy4R8oJEt5yaVwt6CEZ9iFGwyTDVZD8iWq6Qo1Ma0Jpq4VMb0FhsKMwZ6Ow5iA3ZxM5uFFyPdGYSEWCeFeMJXpsO9HSvVLRL5OjK+4ozeKvGIu525j/SE6QDHgkj3FT1mmvuxd9DIKA5aLrSIa8uBd+iCq3PoY2g6n87ikQZNj634ZvxbckPCVkXfEcYiyCZ9HHqGs1xtBqIACBrfOJnNhF1uB78Cb12AUwa9QkUyKV/gitCi00yzMjiqBVylDIZSjkoJLLiNWIu/5WccE/EjNmzMBsNtOkSZPzHcrvjkwh/juhe/fu2O12vvvuuwu+EP9X7KBsfW+l5JO/4J7/Ps4RT6C05GDuNobqxZ8S2r0cy0XXEdq/Vlpor3sLza5lVC/5HG3ddli6Xkn10i8J7l5+ig3ap2iLO6TnqDVFLfGumkAy4EFu+Pl4BEGyKlHXaoS1901Ejm2XdsT3rCTiOkrZse2oa7dCX78j2rrtMXcahbnTKJJhH5HDm4kc2ULk2HbC+9fW/EAZSlstlM7aUlu6rRZKay6awhaSFZn83P6Ki2KKVNhP0l9FwltOcPcywoc3ESvdK+2Ip5LS22nNaAuboSm6jFj5AZTO2hhbDUSQK1HXakx4/1rpVCARR13QDFJJIoc2Etq3Cm29dihMWUSObpUUzwfemX4mp8ciEtq3Or0DfvI19/z3JGu6XjeQ8FbgWfwZmqKW6Jr0rJkBM2Lt82fciz4mGazGOfJJkgE31Uu/RFOnLfomPQGkwrxsP46hD5ybb7lI2he8bZGVJy9uypNTt5NMichlAk9e3PSCKcIBxo0bB8DVV//6JkaGDBn+e1x22WU8//zzbNq0idatW5/vcM4bxVkGrHoVyZSIwyC1OueatPhDceLJFI1zTOwo8RJLiPRp5KBhnpl1WSbG7VhA9aoJ6Bv3QGmrhbXfrYQ//guuOW+TfeULmDtdjnflWHSNumEbcAeVk57Bt3I8tv63UfLJbbhm/4Ps0c9JY2yz3kiPsQGoC1sAAqF9q3/Vikuut2JqOxRT26EkfBUEdy0ntGcZ0SNbiCK1TGvrdUDXoBO6hl3QN+qGmEoSLdlL5PAmIke3Etg8G//6qQDIdGZUzjooa8bVFJYctHXaIDfYEVTac25PT8UiJAMuEr5K6b1WjCNybLsk6nZSBE6Qo86tj6nTKBKeMtR5DTC2GYIgVxI5svU/khMAhPavAaQOg/RrB9adZjtWOe0VUpEA9iueI7h1PpEjW7ENuINkwI1v3RQMLQegrd2KyqkvkYqHsQ28E0GQkQxWS3oy+U3RNOtDSoAST4x6Djl6hRyPXOCYJ4x/VzmXtslDLpenY5DJJAG3C5ny8nKWL1/ODTfccNqzyXBuZArx3wlyuZxRo0bxwQcfUFpaSm5u7vkO6bxyNjsoWY1w20nUClnaX1xhzsLctabw3rsSfcOumNoPJ7R7Ge4F75NX9C72QXdRMf5xvMu+wj7wTko+vR3XzDfIGv2cJIwy7100Bc2wD7id0i/upfqHT9MeovomPfGuHEdg6zzMXa44p3sQZPK0B7it31+IHttBaO9KQvvW4KpZdFTZ9dDUaY2mqBXa4g7oG/cAkFrTS/cQK91PrOJAWsDlp2Jxkm2KQVp8FSoEmRyoaU1PxkmdopD6U2EU5IqaBbxtjZpqFEPLgWjyJR9337opxCsPS61tgKaoOeGDG4h7Sggf3iTZnbiOkvBWEKs4hDq/KaKYIrBtAUpbPqEayzFdg85oCn9sQ4uV7CFeeRhb/9vSr4V2LU23m8kNNirGPwGAfdBd+FZNIF5xEOeIx4mW7JE8wTtfjiq7HhUTngRBwD7gdgRBIOGrpHrZV2jqtEFX8yx/DaVCxonqMBuOeGhbZMUTipESRUQgJYrM3l5KwxzjBVGMi6LI+PHjqVevHk2bNj3f4WTIcEFz9dVX8/zzzzNu3LgLuhAXBCFdgJ8klkwRS4hc2roW07eW0L2Bg/pOA3qNktWH3Bx0B2k16m6WvLgB97x3yLrieRQGG9ZeN+Ke8w8Cm+dg7nI5oX2rcM/5J7k3vou+WR+8q79DW78j1t434Z7zNv6NMzG2uZjAtvnpMTa5wYrCYENT1JLA1gWYO19xzpviClMW5o4jMHccQcJbTmjvKkJ7V+FbMxHf6gnIdGbJLqx2azRFzbF0uwq4SmpNLz9ItHQvsfIDxCsPE9g6DzEeOf1ZKdRSXqDWI6jUCDIlNUPNiClpnU9FJTeVn34vgoDcYEdbt23N9QkMrQajrWkj962bQtx17D+eE4iiSGDzbGljIatGVT0Wxj3vXckFpcOlhPasJLRrCeZuYxCUGjw/fIqmdmtpo/6Le5Ab7Fh73Sg9z93LMHcbg8pRiEEOhxd9TCoWwT7gDgRBQC6AQi4jngJfNInNoCbbpCHfosOoUlLmjRBPpcgySt7hp8Z5UildEIQLRkRx7NixgDRCm+G3kynEf0eMGDGCDz74gFmzZnHjjTee73D+p/hgyYHTivAOta1YdCrm7fxRjdvUfjjBnYvxLPgAbVFLZBoD9sF3U/r5PdJJ+fCHMLQejH/9NHQNOmPrezOumW/g3zAN++B7Kf3iHlxz/4nz0sck3/E1kyS1zdqtUNrz0dRpi2/DNIzthqfbxc8VQSZHUyTNRVn73kKsbD++9VMJ719DrPwgvtUTQaZAnVOMulZjVHkNUeUUo6nbPu1TKSZixNOK6VUkA25pzisSQIxFpNb0VBIQEeRKZFojCqUGmVqPTGdCrrOgMDqQm5woLDn41k5CrrdhajeMhLeCwNZ5xCsOpAtxda3GhPauREzEERRK1PlN8S7/FqU1D01+UxS2Wpi7XEFo32pkGgMqZxFxTynBHYtR5dYnGfahyq6H0nG6/ZV3zURkan36BPvUdjNj24vxb5wpzXYNuINksBrvqgnom/VGnd+E0k9uR+kowtLlSknN9fAmbP3+IgnBnBRxSaWw9b/tnBbJHJMadzDGuLVHmbzxON/c1Ck9GnFSp2D5virWHXbzzU2d/vDFeGlpKfv37+eBBx64YJKMDBn+V2nUqBHFxcXMmzePl1566XyH8z9DKpVi4oajLN1dhUmroEs9B0UOHdF4inHrjqFTysgxajkcM+PsfT0Vc94ltH0hjlZ9UbToJ81nL/4Mbd22OIbcS+mX9+Ge/x72/rcRObKFqhmvk3PdG4T2rKR68edo67RJj7G55r+L85JHEQQBY7thVE56huDOJRia9/nN96EwZ0tz6O0vIRH04F31HaGdiwnuXUlw52IA5OZsNDU5gTq3Afrm/TC1HQqcVE13kfCWkfBVkgy4SAW9JCN+xGiIVDwKqR9b0wWZFsGoQqbSIdMYkOstkpuK0YncnE1g8yzJkuyUnCDhOgo1hfh/KyeIHN5MrPwAtppCGaB6yeckfVU4xrxMKhLANfefqLLrYeowgorxj4NMjn3QXXiXf03CfZysy5+tGUt7F2VWnR9tzg5uILhzMTk9RpNXUIg/JmJQy8mxaOlYx4qAnHJfGHcwSqNcEyqFnDJvBLlcwB9J0DjHdJoo28n4ztZe/0dl/vz5GAwG+vT57Z/xDJlC/HdFt27dUCqVTJkyJVOIn8K3a46eVnADbD5WTaHtdF9oQSbHPvBOyr66H8/iz7EPvAOVszbmrqPxLvuaYMOuWHveQOTQJqpmvkHu9f9A16AL1Uu/Rlu7NdYe1+L54RPp1LvrVYT2rcY1+y3y/vRPZGo95i6jKf/mAXzrp2Dp8q/vDEqenB5SQQ9Zo54itHcViepSlPYCIsd24Ns4A9Z9f/KmQCZHaauFvslFKMw5yE1OVM7aIELcdRR1fpPTRFFO5eRiIaaSUtEe8JDwV0pz4SGfJO6G1Oom11tJ1MyGA6hyihGUWsKHN6Ir7giAwppHKhrC2vvHz2e0ZDdiPEYqGkJhtJN/59c/6wUaPbGL8L7VmLuNQabWScXz3HdIRYPYB91N3H2C6sWfoa3bDl3jHpR9cTdyox1rn5txz32HZNhH1mV/Ixn0SDNghS0wtB4EQGj3csL712Lp+SeUlpxz+l0U2nSU+6KIQDyRYvVBF7f3Kuabmzrx5oK9LN9Xddqf/dEL8e+++w6QNCsyZMhw/unduzcffvhhplPuFPZXBtl81IteI2fdIQ+CINCilhm1QYZdr8asVlBok+GLxsnpMATvtsW4F32Mpm5b5HpruiOuatZbZI9+DnPXK/Eu+xpdcUfsg++hYvzjVC/+HPuguyj99A6qZrxKzphXsHQfQ/XizwntWoK+SU+09dqjyq5H9Ypv0TfujlBjNfqvECvdR6LqCM5LHyV8YB0JXxXqvAZEjm4ncmRLujAHQCZHYc5B17AzCksecpMTdU59EBoQry5Dcy45gSgixsLSTLnfReToFmIVh1D9P+cEopiieukXyE1ODM2kQi98ZIvUidB2GOpajaic+DRiPIL94vvxr5tC9MQu7BffT6K6FP+6qRhaD0ZbpzWV01+pyRGeRpArSMXClM5+F7U9n06X3IDeoEchyEikRIqzDTiMWqr8EaxaFY1yjQxrnU91MI43EseiUNZsxovIThFmO/VE/EIgHA4zd+5cBg8enGlL/xfJFOK/I3Q6HUOHDuX777/H7/djNBrPd0j/E8zefqZPZzwpUtdpYH9l8LTX1bn1MbUbLllZNe6BpqgF5o6jCO9bjXveu+Td+A72IfdR/u1DeBZ+iG3gHUQ/vYPKaa+Qc93rhA+ux7PwQzT5TXEMvpeybx7EveADHEPuQ5PfGG2DzvhWT8TQrC8Kk+OMuE4SqzhIYPsiEp4SxGQSmUor7TxbctAUtSJWuhdVXkNJYVWQE9g0E2PrwVh73oCYjBPatwb/xhkgUyDGI8QqDlK95IuzvpdMa0KmNUkd6WG/NNslU0jz0YJAKhpEjIb4aVu7qlYTSUk1lZQsRmQyBOWPJ/2CTI6x1QCCOxZLyu2HN6MpbE5w9zLinhKSvkrEaAgREYXBTuToNmmu/meSEVFM4V74IXKDDVP7SwAI7lhEaO9KLBddj9KaR+mX9yGotNgH3Y1n4UckPGVkX/kC4QPrCO1ehqXHtWlbEgDH4LulGbCwD/eCD1DlFGNqP/xnfy8yAU42VvSo72D1QVf6qcjlP9qTtS2yck/fBqw77D5NNPCPzvjx47FarZmd7wwZ/kcYPXo0H374IV9//XXGOqgGuSAQiac45gkiyAQOVQXYUeajf+McbuhSxNQtpQTCcQQEbEYt9kF31oiyfoDzkodRmLOkFvW5/8S/cSbmTpcRPrBO8hP/0zsY21+Cf90UtHXbYht4B1VTXqR6+ddYul9DeO9q3PPeQ53fFIXJiaXnDVSMfxzfuimYO/+83WMqFiawdR7Rkj2kIgEEuRK5zozc5ETpKCReceiMnEBXvxOmdsMRRZHIkS341k5CUKgRE3GiJbvxrZl85sgZksiqXGsCmUzKCeIRSahNbUCQy0lGQ1JOkIz/9DuRaQznnBNEjmxFU7et1JbuOk7CW04q7ENMxJCpdIjJOIaW/X+2CAcIbltIrGw/9iH3SQrykQCumW+isNXCctG1+DfOIHxwPda+tyDGwlQv/wZd4x7o6rWn5NM7UVhzsfb8E8E9KwjtlFrXTzqlVC/9koS3nAY3vUJxng2HUUeOSU25N0yzQiu+UAKLTkmuSYdCBlqFnPJEFFcwRoUvSot8Cwr5jwJtF1I7+kkmTZpEMpnkiivObSQzw5lkCvHfGaNHj2by5MlMnDiRG2644XyH8z/BoGa5LNtXddprggC3XFSPng2zGL/uKAergvgjkmK4ufsY6TR7zj/IveGfyFQa7EPuo/Tzu3HNeRvniCcwd7oM76rxaOu2wz7kPirGP0H1ok+wD7mX0s/uonLay+Re8yrmzlfgXTkWbZ226JtchLXXjZQevA33wg/IuvSxs8YbPrCeiknPIMgVkoeoQknCW0by4HrEeARrv7+QioVR15JawGVqHXK9lbjrOApzNoJcSSrsQ5VVB0u3Mcg0BjyLPycVDWJsPYh4dQXhvSuJluxOt53LlFpiFQcRZApUhc1JeisQEzFUOcXI9RZkGqO0w22woTA6UJizEVNJXDNfJ1a6D3WtRkRL9qDObUDCV0Vo/xoEQUbCV0GiuozwgXXShsCJXdJNyhTIjXbkWmmzKFSyV/Lv7DgKa8/rz/pcAptmESvdh/3i+5GptMQ9pbjnv486vymmDpfiWfgR8crDOEc9ReTYNoLb5mPqfDlyk5OKSc+grtVEspjbMCMt0qIwZwPgWfgRqYgf+xXP1szJn51TphtY+pPP1Ki2+aedeP8rooG/Z0pLS1m1ahU333xzZuc7Q4b/EXr27InT6WT8+PGZQryGOg4dbYosHKgM0CDHiFmnRKdU4DCqUSnkNMwxUuIJkThRDakUansBlq5XUb30S0J7VqJr2AVDywGE9q2SWs9rt8Zx8f2UfnYXVTNfI2vkU0SObME1601yb3hbsjdbPQlNUSvsF99P6ed3UTXjNbJHP4+2ditJjX3lOHSNuqO0ntm1ICZilH39APHKwyjM2ch0ZsREjFjpXpJBD8rsumgKmv98TiAIJDwlKG35p+cE8YjU1l5dSnDXMqLHdyLT6JHrzAgqLbHSfQgKldQW7q0gFY+idBSh0Zl/HFnTW5Eb7CjMWQgyOa7Z/zgzJwh4iBzeLBXKvgqSIS+RI1tIRQJEj+9I3+fJU3RBoSJRXUZoz3L8m2aSe/0/0rnCqSRDXjyLP0NdqzH6pj2lDrk5/yQZdJNz9Ssk3CV4fvgEbb326Jv2kmbBjQ5s/f6Ca967JAMucq5+BTEWxj33HVQ5xemW9Njxnfg3zMDefgiaWk1Zta8SmVxBlklFgd1AXiCGJxzHolUQSyRYc9xLiTdCbbuWXLOGcCyJ3fCvdzj8URg/fjyCIDBq1KjzHcrvlkwh/jvjkksuQS6XZwrxU7iqYyFrD7mYsrkk/drN3eumC6M95X4ip1ibyZQa7IPuonzsI1Qv/QJb31tQOQqxXnQdnkUfE9gyF3PXKwkf3oxr7j/Ju+FtTJ1G4ls9EU1RS+yD76Fy0jO4F32Mre8tRI5I16ly66O05mHuehXVSz4nuGtpWlztVHzrp6IwZ5N73RvIalTXQWppSoWqSUXD+KqOINfV7BKLSVI1Fh8nibtPSCfdNd8vJqLSzFVWXaIndqO01UJd0JREdSm6+p1R5zXENUfyS7X2knaHw3tXYe42BqU194x5JjGVJBlwITdl4V74oeQ3HvISPbYD9/wPQEzWPEw5SnsBuvqdUGbVRuUoQmEvQGFynlbwiqkkVTNfx7d+KpbuV58hXBOvLsOz5As0tVujb9ITMRGnappkBeMYej+hfavxb5yBsd1wlLZ8Sj+/G1VeQ8ydLq85/RZwDL1fal1f8jnaeu0x1KjXhvatIbjjB8xdr0RVI/Tyr9Asz3zGa2cTDfyjclIt/corrzzPkWTIkOEkJ5Pg9957j+PHj5Ofn3++QzrvyGQyru5UG1GEkuowOpWCLsUOBEHgYFWARDJFo1wjO0t9HIwHsGjl6LqPILxnBe5576IubIZca8I+6G5KP7mdqhmvkXP1K9j63oJr9lv4N0zDOexBSr+4h6oZr+K89HGiJ3ZTNf1Vcm/4B7Z+t+Ka+QbeleOwdBuDte/NhD/ejGv2W2Rf+ULa4uwk4UObiFcexjHswTNyhlQ8Stx1jMDWeb89J1BpUVpyiBzaiNKcjTqv4b+eE4giqUgAha0Wnh8+IRWPkAxWEys/hGfxF4ixH7sP5QYbqpz6qLJq1zi6FEh2aD9ph48c2075tw8T2r0MY+vBZ/we3fPfJxUNYRtwO4Igw795DqE9ksuN0pZP6Rf3INeasA26G/fcd0h4K8i+6iXC+9cQ2rUUS/drUOU2SLeuO4bcL7Wkx6NUzX4LhdmJ86Lr8UVS+ACLLoEvLMMbirP1WDUmnZoqf5Q9pX6K7Aa8oRhrfGGK7AZEwB+JY9FduMV4MBhk9uzZDB8+HLVa/evfkOGsZArx3xlKpZKBAweydOlS4vE4SqXyfIf0P8Gbo1vToY6d2dtLGdQsl6s6FgKkPZ9/iqawOcY2F+PfMB1dw65oCpphbDeM8IF1eBZ9hKawOY5hD1D62Z1UTXuZrCueI3p0O645b5N73ZuYOozAt3YymvwmOIY+QOnnd1E55e/kXP2KpOBZo7KurtXkjBZ1MRFFbrCeVoSDlFDJ9VaQK9OtaQDJsB9BEJCdsmOcigRQ2mqlv06GvGjrFBOvOkbCV4mp40iS3goSnhJpd10UMbToh3vBhxx7a7S0MOY0wLdhOmI8ihgLkYoESIa8pIIekiHvT1rahJoWuQJ0jbqidBShchahtBek4/wlBJkcpbWW1Or2k1a5kyfvAPaBdyIIAu4fPiFWtg/npY8hJuK4Zr2FKrcBlm5XUT5WEsJxDnsI76rxREt24xj2IHK9jYqv7kdQarAPvEuatQ/7cc/9p+Qp/gttgb+GTABPKPYvf/8fgR9++AG1Ws1FF110vkPJkCHDKQwbNoz33nuP+fPnZzboa9Ao5VzfpQ6VgShmrQK9WokoinjDcY55Qoiijva1bXQrdhCMxlm+rxLl5fex+Z27cc97D+fwhyQP75Ot5yvGYul+NeFDG6le9jWawubY+t2Ga9Yb+NdOxjH8Icq+uo+qaS+TPfp5Ike24F0xDnWtJmjrtMbW5yZcs/+Bb+0UzB1HnBbrSXVypb3gjPuQKdUoLDn/8ZwAQN+8L56FH6VzAn2zPvg3z0aMR0hFQ1JeEPKRDHlJBt2SlekpCBoDMpUafeNuKB1FKJ1FqJy1kevO3LQ+GyfvV0ycubYGdvwgjZt1vwaVszax8gO4F3yApnZrjB1G4Jr2ConqmtG0vSulay+6DrnWSMWE91EXNsfUaRSBTTPTretKh/R+vqVfEnefoHD0c0TkWhSASgEapQyDVkk0lsSsV9Ikx0BFMIYvGKOWTYsrEEMQZBRnG0mlUsST4hlxX0isXLmSZDLJxRdffL5D+V2TKcR/hwwaNIiZM2cyb948hgwZcr7D+Z/hqo6F6QIcYMMRD+PXHf1xxvcn9maWi64nfHADrplvkHvD28jUOkkd/bM7qJr+KjlXv4x9wB1UTX8F78pxOIY/ROnnd1M55UWyr/o70ZLduOa8Tc41r2G/+H4qJz6Ne9572AffLbWxfX43VdNfkXbATzkdVuU2wL9xBsmw/6ztWHKNQZodT0jzWaHdy1CYcyQBthpO+l+H9q3Gs/gzEt4K5FoTscpDxE7sJqDWkfBWkAi4cc95m4SvglQkiKnLFRia9cE1520CG6YCQs2clxqlrQCF0Y48pxi5wUZo7ypS8TDZVzyHwpR1zvYrP0f40AaUztpnCNZ4V4wjenwn9ovvR2HOIrDjB+n0u/0laGq3puyr+xFkcpzDH8az+DNi5QdwjniCuOuYNI/foj/6xj1wL/pYsjEb+UTay909/71TxFn+tU0rmQCqC2QG/Ofw+/3MmDGD4cOHX3AzcBky/K/Tp08f5HI548aNyxTip6BUyMiz/HgKu+V4NZsPu3EFo4SjCa7qWMSW4z62HKtmT3kA0VJ0mnirvlE39A27Em7eF9+qCWhrt8I+8A5ipXul8bTr/4G+eV+8q8ajrtUY24A7cM14jerFn2Prdxuxsv1UTX+F3OveRN+8H+ED66le+gWagqao8xqm41LV/H/40Iazdm39lpxATMSJHNlCaO8qosd3oTBnQSqZzgmSoWqSATfViz8nfGgDcp2FrCv/jn/9FCLHdxLasUgSgRVkyNRaVFl1UOU2qLFksxPYuRgxGjqjo+9fIXxoo3T/ufVPez3uPpE+yDB1GkUy7Kfi+xeQa004hv6VwPpp0sl4z+sRFCrcCz9EU7cthjYXU/71AwgKFY6L7yfuOibZmNVti7GNVCxGjm3Hu34ajraDMddvRSIBShlolJBr1mHVqbiogZMkApFEitpWLVh1lPljGLUKWhcaiCaSgED+Bd6a/u233wIwePCZ3QwZzh3Zr1+S4X+Nk6II33///XmO5H+bSRuPkzjl8LWOXcepJYQ0G34PCW8Fnh8+AUBhcmAfdBexsn1UL/0KfZOLMLQciG/NRGKVh6V/3CuP4FnwAY5hDyFT6aic/DzqvEaYu1xJcPsC/BtnoLTVwjbgdqLHd5whomZo3heSCfzrp/1s7JZuY6he9hXl458g4a1A36Qn0ZI9xN0nANAWdyB8dDuehR9h63srqryGRI5sQZPfFF3DrqTCfmLlB4ge20YqFsHW52aUjiKihzcj1xjQ1mmDTGOg1u1fUPTAVArvmUDuta+RNeop7IPuQumsjSq7LjKVDqU1798uwiPHthMr2YOhRb/TXg8fWI935Tj0zfpgaNqLaNl+3HPeRl3QDEuP63DNfou46xiOYQ8QObqVwOY5mDqOQpVdj6oZr6F0FEptfwfW4183BUPrIWm11uCupYR2LcXcZXRanOVs2HSnF+gdaltpmW+mf5NsXri0Off3b3hBWJP9EhMnTkQUxYxPaIYM/4MolUpGjhzJggULCAaDv/4NFygHKgLotAqa17Jg0KhQKuTYDUpsOgXJpEgwAeZOl6HKrS+1OgfcANj63oLCmkvVjNcQU0kcwx8iGfDgmvm6dNKaVYeqGa+irtUYY5sh+NZ9T3j/GqmjK5mg8vvnERNRbIPuQm6wUzn171LXWQ1KSw7qgmb410+XLMXOwrnkBJFjO0jFozVWXnXJvvIFkqFq1PlN0zlBrGQPvjUTERMRdA27YWw3HN+q8egadIZUAoW9gKIHp1H0wBQK7hpL9ugXcA57AFvvmyRHFkchglL9bxfhYiqJb80kFLZa6dl3gFQsQuWUFxFkChzD/gpA1bSXSQZcOC99lFjFITyLP0PXoAv6Zn2pnPIicr0Vx8X3U12jI+MYch8yjVEab1PpcAy+B0EQSEVDVM18A4Ulh7qDbkCrllPg0FCYZeTyjrX5+6gWPDeyBVd1qcvg5jk0zTNTnG2id+NsxnQo5JKWeTTNM9Mg20jjXCNm7YVbiKdSKSZMmEC3bt0ybg3/JplC/HeIw+Fg3759/O1vfzvfofxP89Nzu7pOA2rl6R95TX5TTB1HENgyl9D+NQDoGnTB0HowvrWTCR9Yj7XPn1Fm1cU14zXJB7PbVQR3/EBoz3IclzxKwldJ1bSXMXW5HG1xBzwLPyJ8eDOGpr3SPye4c0n6PVXO2ugadsO3djLx6rKzxq6p2w7bgNsxthuGrc+fUZgcJHyVJPxViGIKbe1WCKkEyVA1nkUfYel6JbqmvYi7T2DqcCnWXn9C37QXcp0FW9+bUWbVIRn2ET2xG4C4pwQQkCk1abuNk6SiIfwbpmPu/J9RwRSTCdwLPkRusKfntkHa9a6a/grKrNrY+v+FRMBN5aRnkenMOIc/jH/t5Jp2s2uRaYy45r4jqdx3HU3V1L8jJuM4L3mYVCRI1aw3UDprY+31JwAS/irc895FldvgZ1vSW+abeeHS5nx0XXtUCsmARCkXKM428uTQpnx4bTuu6ljI7b2KL+giHCRBFrVazfDhP684nyFDhvPHu+++y/bt2zOzmr+Aw6BizUEPc3eUEU0kseuUGDRK1EoFwUQSUZTGqBxD7kNMxHDNegtRFJGptDiGPUgyWI1r1puocupj7X0j4QPrCGyYjvOSR0EUqfz+eczdrkad3xTX7LdIxcI4hj1ArPwgrplvIlPrcF7yCMlgNZVTX0JMJtKxWbqNIRlw4Vs14ayxn0tOINcaKP3sDlKRAJYe16KwZGNo3heZ1pjOCYxthyEmExhaDSZWtk8q4I9sIV51FEGmSOdN/+28ILBlLvHKw1i6XpWemRdFUdp8rzyCY+hfUZiy8Cz6mMjhTdj734ZMa6Jq6ksobfnYBt2Ja/rLJIPVOC95lPD+tZIgbOfL0dZrJ4m7Vh3FMeReaeQPcC/8kKSvksJL7kWUa8kxqlErlRRaNVzSuoC6WWacRqmDwqpX0yTXhAzYUeKj3B9Jq6LrVArUigtbsHT27NmEQqHM5vx/gEwh/juluLg4LcoSiUSIxWIcPHiQcDh8niP736HpT8S1ejbM4vrOtc+4ztLtapRZdXDN/gfJoAcAa68bUTprUzXzdVJhP85LHkEURaqmvIix/SVo63fCs+gTxEQU+4DbiBzeRPWij3Fc/FeU9gIqp7xIrOootj5/Ti/K0ZOK4oC1900gk+Oa+TpiKnlGTIIgSAV7vfbpOSp9o25oi1qmFy1d4x7oGl9EzpiX0Ra1RGFySotyKokoppDrrYiJWFoJ3dDkIlKxMCVf3EP0+C7EZJzysY9QPu5RIse2p9+7etnXmDqN+rc8T0/Fu2Is8YqD2PrdgqzG6iQZ9lEx8WmQyXFe+hiIUDnpGVLRIFkjnyB6fCfVy75C36Qn+ia9qZz8nLTrPewhqhd/RrRkN/aBd6Gw5lE1/RXEeATnsIeQKdWIYgrXzDcRk3EcF99/VpV0hVzgyaFN06MMl7XNp1+TbARg3NqjjPl4NRuOeP4j9/97J5VKMXfuXIYPH45Go/n1b8iQIcP/O3a7ncaNG6NQKDI5wc+QTELTPCNtC20YNQpEQaBJtpF6TgONso04jWoMKgG9swBLz+uJHNpAYNNMANQ5xVh73UB4/1r866ZgbHMxusYXUb3saxLVZTiGPkC84jDuee/iGP4wMp2ZyknPSO4mPW8gtGc51cu+Rp1bH/vAO4ke3Yp7/nvpgldT2Bx9015410wkWrLnjNjPJSew9fsL1t43omvYJf263GAj4a9M5wRyo52k34Uqpxh94x5UTnyaVCxM+OBGdE17kfBXUfr53f/VvCDuPiG1jBe1QneKOJ13+bfpWW9t3bb4NkzHv2E6xnbD0TXoQuWkZwBwjnwC7/JviRzZin3AHSAIuOe9i6aoBZZuYwjuXEJgi9Q9p63bFoDQnpUEty3A1GkUspwmBBPgCSeRI1Jo11Fgk1r7UymR/RV+Vh90s/FYNSe8EVQKGcfcYfzRxJk3c4Eya9YsICPe+p8gU4j/jgkEAsyePZtXXnmFUaNGMXjwYG666SZ27dr16998AeAJxdK7u9KuppePlx864zpBocRx8V8RY2GqZr0p7YAr1TiHP4yYjFM17SUUJieOoX8lVn4Qz7x3sQ++F6U9n6qpf6+x1xqBf+NMAtvmkzXqKWQKFRXfPUUy5MN56aPIDXYqJj2bbiNTmBzY+v9FKjgXf/6b7isVjxLYtpBoyW5JsOWUFjFBEBBkcgRBhr5RN2lurKYQNXe+HIXRjn3AHdgH3UX+7V+Qe/1b2HrfRNX010jFwsTKD5DwlqGr117aIBBFYpWHqfj+ebxrJqVb9c6V0IF1eFdNQN+8L7oGXdLxV056joSvAueIx1CYnFRNe4lY+UEcwx6UFNNnvIYqryG2frdSOeV5UtEAWSMeJ3xgHf6NMzG1vxR94+5UL/ua6LHt2PrflhZi8a+bQuTIZqy9/3yaeM2pJJIikzYeZ8MRD2M+Xs3YtUdZtLuCeFIkJUI8kWL1Qddvutc/KnPmzAHIiLRlyPA/TiYn+GU0Kjl2g4ZaNi0GjRIBiKZSGLRKrmhfQG27Fp1SASIY21yMpm5bPD98SrzyCADGtsPQNuiMZ8nnRE/sxj7wTpT2AqqmvYzCno+l53WEdi8jsGUOWSOfJBULUznpWQytBmJoOQDfqgn4N8/B0Kw3ps6XE9gy97QTcFvfW35sXQ9W/6Z7i5bswbPkC0A4XQ9FkCHU/Kdv1A1D017pPzJ3vpycq19Grrdg7XMT2sLm5N/2+TnlBa557+Je8CGRI1vPOD3/JVKxsNR6LldiH3x3WnMksHUe3pVj0Tfvi6njSEn/ZuFHaIs7YulxDZVTXiTuKcV56SNEjmyRCvS2w9DWa0fl9y8g05pxDHuIhKcE15y3UddqgqX71dKb+qtwz30bdU4xzq5XIYiACGoZOI1qjldHWLW/kvWH3ZyoDlHqjWDRKanyRwhFE8hlUoy/4Tb/0IiiyIwZM2jSpAk2m+18h/O7J1OI/04JBAI899xzjBs3DrlczkMPPcTu3btp2rQpzz777PkO73+CTnXtqJUy5AKolDKW76sikTr7v6QqZxHWXn8icnBDenZbac+Xdq5P7JJmkuq1T7elB7fOwznySRBkVEx8GmOHkdICvfBjomX7cI56ilQkQMV3T4FMTtZlfwOQ5rt8kj+1oWmv9DyZf/Occ7qnZMgr+Z3PeoNY2YH0zwJI+iqRG08XFJMbHSR9lYA0k5WKR1DnFKOy56fFXVTZ9VBYc4i7TxAt2UOs/CAn3r+R8rGPEveU4Jr5JvHKI1Qv/oySD28mcoov6C8RKz9I1bSXUWXXxdbvVimGZIKqqX8nemIXjov/irpWE1xz3iZ8YB22freitBdQMekZ5HoLzksfwzXvXWIle3AMuR8xGU+3p1t6Xk9o/xp8q7/D0HIAhmZ9AIiW7cez5Eu09Tud1gZ/Nr5bf4xnpu8glkiREiElishlAnJBEvm5kMXZTmXs2LEAmRa0DBn+h8nkBL9O60IrNp2KYDRJhyIrz07fztUfrWH+jnLsehVN80wkUyniorSp7Rh8D4JKR+W0l0nFo9Jrg+6WNo+nvih1Yo14DMQUlZOfw9BqMPpmvfEu/4Z41VGcwx4iVnGIqqkvYe39Z7R12+Ge9y6hvSuxdL8afdNeVC/7Kr3+yzQGnJc+SirkpWLys6Rq1NR/jepl31D21f341n2PIJP/prxAUGoQEzHU2fUQ5IpzzgukDYe5lI97FPfcf55TnGIqSdW0l6WW8WEPoDA5AQjtXYVrzj/R1G6NfcAdRI/vkHKHnGLsF/8V97x3iRzZgn3gnYipFO7576Gp0wbLRdfVbFp4cF76KIJCReX3LyIo1TiGPYggVyCkkrhnv46YjFM86kEcJjVatQyjToZcoaTMH0OtkBNNimw77mXe9nICkTgpEfRqFVaDmupQHKdJhVGd0bcGWLt2LUePHuXyy/91J5oMP/IfKcQFQRgoCMIeQRD2C4Lw8Fn+XC0IwviaP18jCELt/8T7XqjE43FuuukmTpw4wRNPPMGjjz5K165dAWl+3G63k0ye2e58odG2yMqTFzeljkOPUiZwxB067c+N6tNblg2th6At7ohn8WdEy/YDoG/cA2PbofjXTyW4cwnmLlega9BFUip3Hcc54jESvgqqpryAfeDdqPMaUjX9VcRYGOeljxF3H6fyu78hN9jIuvwZUhE/5eMfIxmoaYHvczOaum1xz3uX4O7lv3pPwZ1LSLiP4xz1FFlXPEvCU0LCW4GYjBPcvSwtVHYSXf2OBLYtACC0ezmaopaAVNBLLewicU8JCU8ZSmsuxtaDyb/tc2rd+gk5V7+M0ppH7vVvUuvmD8n78wfI1Hp8ayb9apxx9wnKv3sSmVqPc8QT0ix6KknVjNekorv/X9A17Ipn0ccEty3A3PVKdA06UzHhSRBFsi57Gv/GGZIX6EXXS16gk59DYbDhGP4wieoyqma8jiq7Hra+twDSDFvVtJeQ6y3YB931q+re8aTIluNeUuKPyujPDG/GfaeIs2044uGdH/ZfsG3qsViM77//nt69e2d2vjP8JjJ5wf8fmZzg3DjmDlIViFLlD/Pqgj0s2lOFUiGw4kAVgWiCcCKFTq3ArJFyA7neimPIvcSrjhBd/ilQUyxf8iipSIDKaS+hMGXhGP4Q8aqjuGa+jq3/7ajzm1I1600EtRbbgNuJHNqAe/672Ic9iCq3PpXTXiZyeAv2QXejrdce99x3CGxfCEgt8I6L/0qsZC+VU15MK6X/HGIihnfVeHQNu1Fw57doizv8v+QFBXeNpeDusRhaDyGwZS6JmsL+Z+MUU7hmv5XedNfWaQNIgq2V015ClVtfEmOrPEzFxGeQm7LIGvUUvlXjCW5fhLnbGFQ5xVROeRGlrRaOYQ/iWfQR0aPbsA+6C1VOMa7Z/yDuPo5j6ANp21jvqgn4D22l5cg7adOiMXaDiiKbhvpOI0U2Lc1rWejZ0EmFP4pOrcCkVZBMiYRjSQpsWtoWWulQx0Y9pxGZTCCVSpFKpX5TF8AfjXHjxgFw9dVXn+dI/hj824W4IAhy4B1gENAEuFIQhCY/uexGwCOKYjHwBvDSv/u+FzInTpzg0KFDfPXVVxQXFxOPx9m7dy9ffvklkydPZsyYMcjlF7aQBEj2ZU9O3cb+yiD+6OlJiEEtp3Gu6bTXBEHAPvhu5HoLVVNfIhWVCndrrz+hzm+Ca84/iFcexj7kPpTO2lROewmZWo9j8L1Ej+/APe8dHCMfR2nJoWLSM9KCPfRBoqV7qZj0DEp7PlmX/Y2k30XZ2EdI+F01tlyPoK7ViKppLxPctfQX7+nkbrWgUCGTK7D1u5XyCU9Q8slt6Bv3QGkvoHrZN4T2r5Xus0U/UpEAJz68Gd+GaVgvuh6A6LEdlH52J2Vf3EPVtJexD7g9/bPTiCLIfvwnQqYzIyaiZ173E+JVxygf+wiIItlXPIvC5JBOwqe/Smj3Mqy9/oSh1SCql36Jf/1UjG2HYmg7jPIJT5IMuska+SSRo1vxrZognXa3HkzFpGdIxSM4Rz6JIFNQ+f3z0rOr2QUXRRHXvHdIVJfjGPpX5FrTL8Z4mno+0LXYwTc3dTpNnO1k2/pr8/ZcsDPjs2bNIhgMZk7DM/wmMnnB/y+ZnODc2HLci1IusGy/i01HqvEEo1R4wyRTSURRRp/GOUSTIvGkZLciAJbitlg7jqBi7Ux0x1aTb5KTXVAX24A7iB7dhmfxZ2jrtMHa+ybC+1bjXfGtNHJlzqJy0rNoajXB3G0Mwe2L8C7/BufIp1Da8qmc/JzUGTb8YTRFLXDNfJPANqkY1zXsgm3gHUQObqDy++d/VkkdAJkcQaUFMYVMrUOQyf/f8gJBoZI2vAUZgvLn9UPEVBLXrLfSBbWxtWR3FT6wjorvn0flKCLrsqdJVJdRMf4JZBoD2Vc8S2DbAnxrJmJoNQh9875UfPc3ZAoVWaP+RnD7wrSLiqFpL3yniLtqa7cCJLcWz4qxGJtchK55P0q9cfQaBTa9iib5Fq7uUpsRbfNpXWhFRCQUjWM3qKntMNC+jo26TgMymZDe1BdFMS3YdqEW4qIoMmHCBFq2bEm9evXOdzh/CIR/98MkCEJn4G+iKA6o+foRAFEUXzzlmrk116wSBEEBlAFO8VfevF27duL69ev/rfj+qLRo0YJbbrkFtVqNQqFgz549eL1ehgwZkvEWr+GdH/bzytwzRU8ALmmVxzWdazPyvZVn/Fnk+A7Kv30EXYMuOIY/hCAIJAMeSr+4B2Rycq97AzERp+yr+0AmJ+ea1wjuWEz14k8xtrkYY8eRlH/zEGI8QvaVLxCvPEzVjNdRFzQla+RTxMr3UzHxaeQ6M1lXPIfSkkMqGqJi4tNEj+/E1u9WjG3O/jtMxSOUfnI7YjJB9piXUFpy/pOP7GcREzEqJj1L5OhWcq+VTqLPRvTEbiomPQMyGdlXPI/KWUQqHqVq2suE96/B2utPGNtfSvXSL2vaygdi6Xk9Fd89Raz8QHqurmrqS2jrtsVxySNUTXmR8MENZI18Ek3dNlR+/wLh/WvJuvyZ9ILr3zwH99x/Yul+DeYuP6/qKhPgyg6FVPqjLNxVjihKYwtnsyd754f9vDZvDykR5ALc178ht/cq/o89098Do0ePZvz48Xg8HiwWy/kO53eDIAgbRFFsd77jOF/8t/KCTE7w82Rygl9n6d4KluytZNneSgRBYE95AACLVsH027tR4NAzY/MJvlpzmO3HvaREEbNKRigS49i3j+ArO0L+dW8is+YhAu4FH+DfMB374HvRN+uNe967BDbPxjbgDjS1W1H29V8RZEqyx7yEf90U/BumYe58Bca2Qykf+wgJXwVZo56SOr4mPUfkyGZs/f6SXv+lde0d1PlNcI58AvnPWIZVrxiLd/k3WPv8GVO7/z9ni+DOJVRNfxVD60HY+9921mtS8QhV018lvG815m5jsHSVxL2Cu5dTNf1VVFm1ybr8WZL+SsrHPY6gUJF91d+JHNqIe9676Bp1x9rvL1TUPK+cq/5OwldJ5eTn0TbohPOSR4gc2kTFxKdPz9mC1ZR+fheCUkP+dW9i0OtQKmQ0yDHSJM9CKJagUz079ZxGirMMbD/hY1+FD5NWSc8GWRg0yrPfT+pHT1yZ7MKb7l27di0dO3bkueee47HHHjvf4fxu+KWc4D/xKaoFHDvl6+M1r531GlEUE4AXyAxg/ht88skn7N+/n7Vr17J//37y8/O57777MgvuKXSqa0fxM5/wHJOGtkVWXri0+Rk2Z5r8plh6XEtoz3L8G2cAIDdYcY54jFSousa30oxz5JOkwn4qJz6NoeUATO0vxb9xBsEt88ge/TyCXEH5uMdRZtXFPuReosd2UDHhSVRZdci+4jlSkQDlXz9ArPwAMrWOrMufRlvcHvf893Av+PCsauoypQbniMcQE1HKvv4rkaPb/sNP7UwSvkrKxz4qWYgMuP1ni/DgziWUj3sUmVpPzphXUDmLSIb9VEx4gvD+tVKC0f5SPIs+/rEI73UDlZOeIVa6D+ewh0CQUTX9FdS5DbAPe1CygjvZylavHdXLviG8bzXW3jemi/Bo2X7cCz5AU7s1ps6X/eK9KOQyjGoFC3dXkBRBFOD6zrXPak/Wqa4dlUJ2wc6Mx+Nxpk2bRr9+/TJFeIbfSiYv+H8mkxP8Oh3r2unVMItsk5bqcBytQkYDpw6rTsm+KqkoH9wijzb5VnLMWuQyGYF4iuZFdmoNfwAEGWVTXiRZc0Jt7XWjdJo9922iJ3Zj63crmjrSqFnCdYzsyyQnkIoJT2LqfBmGFv3xrhqPf/Nsske/gMKUTcV3fyN6bAdZo55EW9wR9/z3qF72NaIoYmw1EMfQvxIt3UPZV39Ni73+FHPny2s0aj6S3FySv9zO/u8iiim8qyZQNf1V1PlNsPa68azXJfxVlI99hPC+NVj73pIuwn0bplM19SXUuQ3IHv08CU8J5WOl7rbsK18gcngz7nnvoi3ugK3/bVROepq45wTOEY9L3XXTXkGVW4xjyP0kXCeonPoSSmcR9hq/cDGVpGr6K6QiAWpd8jBKrQ5BLiACx6rDHHcHCITjlPkiVIdjVAWilFWHSKWoee3nn59MJkv/dyFyUjMmo5b+n+M/cSI+ChgoiuJNNV9fA3QURfGOU67ZXnPN8ZqvD9RcU3WWn3czcDNAYWFh2yNHjvxb8f3RicViKJXKX52HvVDZcMTD+0sO8MPuitOE2gxqOV/8qSNti6zcM24TUzaXnPZ9Yo34SvjgBrKv/Dua/MbAyR3gV9A374t90N1EDm6gYtIzaIpa4hz5OO557xHctgBLzxvQFXekfKxke5Y9+jniruNUTX8VpaOQ7MueJhnxUzHhKVIRP45hD6Ar7oiYSuJZ9An+DdNQFzaXZp0MZ87mxqqOUjn5ORKeUoxtBmPuNuZX27F/K2IygX/zbKqXfgWI2Afdjb5Rt7NcF8ez+HP866eizm8qqcTrzMQ9JVRMfIaEtwzHkPvQNeiCa84/CW5fgLHtUMzdrqZy4tNES3bjGPYgcoOViglPorDkkn3V3/FvmI53+TeYOo7C2vN6Ajt+wDXjNQwt+mMbeKe06x32S50KqSS517+FXGc+I75TkQsgAqdq9sllAhNu6XzWYnzDEQ+rD7roVNd+wXmJr1mzhk6dOvHmm29y9913n+9wfldkTsT/c3lBJif4bWRygl+n0h/hjXl7mL+jAq1GjgC8cVlLWhRaUchlPD5lK2sPuqgOxVErZDTM0rN8nxv/gXWUfPc0+mZ904rfybCfsi/vIxULkXvt68g0RsrHPkLcfZzs0S8gJuPSumbLJ+uK56j+4ROC2xdi7nolhtZDqJjwZNrvWteom7RGbpuPvklPSetEoSJybDuV37+AmEzgGHwPuoZdzrgnMZnAs+gj/BtnonTWljYFCpr9x59dtGw/ngUfED2xC12j7tgH3522JT2VyJGtVE5/GTEWwTH0AXT1a/KbHz7Fv34q2uKOOIY9QPT4Liq/fx653kL26OcJH9wgFeF122G/+K9UTX2ByNHtOC95GKW9kLJvHkSm1pFz9SvSxshXfyUVC5N73esoTFkAeJZ+iW/VBGpfeh+Nu1/MoaoQiSQ4jHJyTGq0ShVVoSid6zhpXmimvlPPD3sqselUBKJJ2hRa6VTs+I8/uz8CLVq0IBKJsHfv3vMdyu+K//aJ+Amg4JSv82teO+s1NS1oZuCs3kCiKH4oimI7URTbOZ3O/0B4f1zmzp3L119/fcaCO3fuXMaPH8/OnTtxu3+b3dQfjbZFVj66th1tCi2nvR6IJtNzv+sOn/mMBEGGY8h9KExZVE19MS2upm9yEeYuVxLctgDf6u/Q1muHfdBdRA5vwjXzTWz9b0PXqDvViz8jfGgD2Ve+iCCTUf7tIyjM2WSNfIKE54TUsiYI5FzzKkp7PpWTnsO7agIIAra+N2Mfci+xkr2UfnYnoX2rz4hP5Sgk9/q3MLa9GP+m2Zx4/0Y8iz8n4av4t59ZKhrCv3EmJR/fimfBB6hzG5B7/T/OWoTHq45R9vUD6Vnv7NHPIdeZCR/cICUnYR/ZVzyHtm47KiY/S3D7Asxdr8LUZTQVEx4nWroHx7AHURjtVHz3N+RGJ9lXPEto11K8y79B36w3louuI3JsO67Zb6EubI6t/19O2/VOBlw4L3nkV4twgKR4ehEOkEyJP2tV1rbImp4Zv9D48ssvARg6dOh5jiTD75D/WF6QyQnOnUxOcG44jRoeHNCQm3rUpmMdKzf3qIPdpMEVlE66K31RSr1RArEE7mCc5XvdRERQ1m0vrf/bFxDYPBsAudZI1qinIJmg4runQUyRNepvyPVWaaZZa5KEW11HqZz4FNZeN6Jv3hfvirH4104me/TzkkbM9Ffwr5+KbeCdWLpfQ3DnYsq+fYSErwpNQTNyr3sTpb0WlVNewDX3n6Rip/vDC3IFtn5/wTniCVKRIOXfPkz5+CcIH9yAKKbOeAa/BVEUiRzfQeX3L1D2xT3E3SewD7kXx7AHzyjCxUQcz5LPKR/3GDK1gZxrXkNXv6PUHTfx6XSu4Lz0UUJ7VlIx8WkUlhyyx7xMaM8KqQiv1x7H0Aeomv4ykSPbsA+5F1VOA8onPAGCjKzLn0Gm1lM5+XkS/iqyRjyWLsJDe1fhWzUBS8v+NO0xlGA0hQDYdDLEVAoROUatgmRSZMWBKmZtKaMiGCPXrEWllJNj0qDXZPQUzsa+ffvYtm0bgwYNOt+h/KH4T5yIK4C9QB+khXUdcJUoijtOueZ2oLkoircKgjAaGCGK4q/q3mfmwX6ZWCzG1q1badeuHeFwmPfff59vv/2WrKwsGjduzO7du8nKyuLTTz8936GedzYc8XD5B6tInlKFnZz7/WT5IdzB2Fm/L1ZxiLKv/4oqqy7ZV76AIFciiiJVM14ltHMJ9ovvx9C0F941k6le/CmGFv2x9vsLrumvENq7Emvvm9DW70T5uMdIhX04L30MmVpHxcRnIBnHcYkk1Oaa/TahXUvQFnfAPvhe5FojsaqjVE1/lXjFQXSNL8LW5ybk+jMLwljlEbwrxhLauxJEEXVhc3QNOqMtaoXCnn9OJyOJgJvo0a2E9q8lvH8NYjyKKrc+5q5Xoa3b7oyfISbj+NZMpnrlOGQqLfaBd6Br0AUxlcS7chzeFeNQOotwjngcQSaTPNQrj2Drfxvaum2pmPAU8epSnJc8jExjpOK7p5DrrWRf+SLRY9uomv4a2nrtcF76GInqUsq+fgCZzkLO1a8g1xoB8Cz5HN/qidgG3IGx1cDf8nE4DZVCxtg/nzkjfiEjiiJZWVkUFBSwcePG8x3O747Mifh/Jy/I5AS/TCYnOHfiyRRV/ijlvjAalQKHXkUsIZJr0XD3+I0cqgjiDUUp88WgxlkjkYKEmKJi4jOED28m+8rn0eQ3BaQT4PIJT6LOb0L2ZU9LbdnfPAiCQPYYybKr8vsXUGXVwXn50/iWf4N/40z0zfth63szrllvEtqzAkOrQdj63kL4wFqqZr6BoFDhuPh+tHXaICbjVC/9Ct/a75Gbs7DXrKc/JRWP4N84E9+670kFq5EbnegadUVbtx3qvEbIVD8vrHYSMREjWrqX8KGNhHYvJ+EpQabWY2w7FFOHS88q2ho5vgv33LeJVx2VcqE+f0am0hIt2UPl1JdIBt3Y+t6KoWV/vMu/xbtyHOrCFjgveQTf2sn4Vn+HrmE3bAPvoGrqS0QOb8Y+6E609TpQ9u3DJANusq98AVV2XaqmvUJo9zIcwx5C37g7IHUKln11P0p7Pi1ueo2ibDN7y/1EYiBXgEwEk06JRilDkMlAhDyrjkta5tAox0xVIIpOo6B5LQtaVcaq7Kc89dRTPPPMM2zatIlWrVqd73B+V/xSTvBvF+I1bzAYeBOQA5+Kovi8IAjPAOtFUZwmCIIG+ApoDbiB0aIoHvy1n5tZdM+NYDDIY489hsvl4pprruGiiy4ikUig1+upXbs2ixcvpnbt2uc7zPPOhiMeJm88znfrj5FMiSgVkkjX/B1lvL/05z+OwV3LqJr2EoaWA7ANuEM6jU3EKf/uSaLHd5F12d/Q1m6FZ+lX+FaNx9h2KJaef8I141VCe1Zg6XEt+ma9peLTfQLHkHtQ5TWictIzxF3Hsfa+EUOboQQ2zsDzw6eS/daQe9EWtURMxvGu+g7v6gkIchWWrqMxthmKoDhTSCThLSewbQHBXctIuI8DIKj1qBxFKCzZyHUWSV1VEBDjUVJhHwlfJfGqoyQD0kGUTGdG16AzhmZ9UeU1PLMAF0XC+9fiWfwpCfcJdA27Yut3K3K9lYS3nKoZrxM9vgN9s97Y+t1GrHw/lVP+jpiI4Rz+EPKak+9UNEjWyCcQRZHKSc8iN9rJHv080ZI90uxYQVOyRv2NVCRA2dcPICZj5Fz9alqc7uSIgKHlQOwD7+C3UuzUU9dpwGlUM6JNfqYI/wmLFy+mV69evPTSSzz44IPnO5zfHRd6IQ7/nbwgkxOcG5mc4NxIpkQCkTiBaBwQsOiU6NVK5m4/wQdLDuGLxPGFE8TicXxRkZPnymIkQMmX95KKhsm97o20H/bJ8Sld4x44hv6VeNVRyr99BEGtI+eqF4lVHKZyygso7YVkXfY0gU0z8a4ch7Zee+xD/4pv1Xf41kxEXdAM5/CHJQ2aqS8SrzqKqcMILN2vQVAoiRzfgWv22yTcx9HW74S15w0obT+VYJBOp0N7VxLcuZjw4U2QTIAgQ2nLR2HLQ2G0I1MbpCo1lSQVDZL0u4h7Soi7jkNKul5T2Ax9k57oGvU4axGf8FVRvexLgtsXITc6sQ24DV299oipJL41k6he/g1ygx3n8IdQ2guomvk64X2razYh/ox7/gcEty/E0HIAlotuoHLK80SPbsc+6C609TtSPvZREp5Ssi5/GnV+UzyLPsa/fiqWnn/C3HGE9LsM+yn76r6a38mbOLKysBvUeIIxUmKKaFxEqZDTttBEdShBJJEk26xFJoo0yDFzeYdCatv1yGSZkY6fo0mTJgSDQTLjQb+d/3oh/t8is+ieG59//jnLly/nqaeeoqDgx27A8vJyHnjgAZ5++mnq1KlzHiP83+Jsc7+D3lzKrjL/z37PyZkja5+bMbUbBiAVid8+TMJbTvbo51Hl1Mfzwyf4103B2G44lp434J71JsGdizF1GIGp02VUfv880WPbJQuPtsNwzXqD8L7V6Br3wD7gDuKeEqqmv0LCfQJjmyFYelyHTK0j7jqOe+FHRA5tQG50Yu58GYbmfREUqrPGG/eUEj22jWjZfuJVR0l4y0mFfIiJGhsUmQK51ojc6EBpz0eVXQ91fhNU2fUQZGe2ZYmiSOTQRrwrxxE9sQuFLR9b75vQ1muHKKYIbJ6DZ/FnANj6/QV905741k6hesnnKKy5ZI14nIS3gsqpf0em0pJ12d9IVJdTOU3yJM2+4jmiJbuorBFwybr8GcRUkvJvHiLhqyD7yhdR50iK5dHSvZR/+zCqnOIaUbyzq5v+EplT8F/m1ltv5YMPPqC0tJScnP8fZf4/EplC/L9DJic4NzI5wW/jZKecvKYIS6VEJm86xvbjXgxKgcX7KtldFkIU4aSEarzqGKVf3Y/CkkPOmJfTBap3zUSqF3+Ose1QrH1uJlZ+gPJxjyHXGsm+8sX0ybjc5CT7imcIH1iPe/770kn5yCeIHN2Ge87byDRGHMMfRJVdD8+ijwlsnoPSUYh90N2o8xoiJuL41n2Pd/V3iPEohuZ9MXUahdKad9Z7TMXCRI/tIFqym1jFIRKeUpJBN6lIEEk5RUBQ65DrrSgtOSizaqPOa4SmoBmyn1FrT/ir8K2ZjH/zbEDE1O4SzF2uQKbSEqs6imv2W8RK9qRPuZP+KiqnvEjCU4q1143omvbCNe3vRI5sxdz1KgytB0virWX7pcOIuu2oGP84saqjZI18Em2d1nhXf0f1ki8wth2Gtc+fpYORZIKK754icmwHOVe+gDa/CVoFNM4zE4on8IXiaBQyVEoFTWuZcBg0BMIJgokEcqBDXTs5Jg1dip3pz0CG09m1axdNmjThgQce4OWXXz7f4fzuyBTif3Cuu+46OnbsyG23SfYR8Xicd955h7fffpvrr7+eJ5544jxH+L/NhiMervhgJYlfGKMSxdSPtlkjn0RbT/r7lPC7KPvmQcRYmOyr/o7SXoBn4Yf4N0yXTsZ734RnwYcENs2UTon73op7wfsEty9C16g7toF3Etg4g+plX6OwZOMY+gBKRyHVS77Ev2E6cqMda++b0DXsiiAIhA9vpnrZV8RK9iDTWTC2GoihRX8U5qxzuldRTIEonrXYPhupaIjgrqX4N84gXnkYudGBucsVGJr3Q5AriJUfxD3/PaIndqEpapkWl6ma9SaRgxvQNeiCbdDdBLfOxbP4c5TOIrJGPkX44Hrc895FlVOfrMueInJkK1XTX5GK68ufBUGgYvwTRMv3kzXqb2mF9ISvkrKv7geZgtxrX0eut5zTfVzSKo/VB12U+aSNiAvVjuxcSCaTZGdnU79+fVatWnW+w/ldkinE/ztkcoJzI5MTnJ1USiSaSEmtyWcZ2QrHkhxyBTnmCrL+sBsEgWQyhVEtZ+y6Y7iCcTilGA8fWEfFpGfRFnfAeemjCIIMURTTG/Lmrldi6TaGaMkeysc/UVOMv0DCX0XFxGeQKTVkXf40SV8lldNeRqbS4rzkEQSFWipYveWYu12FudNlRA5twjXnbZIBN4bWg7F0vxq51kgy6MG7agL+zXMgmUBbvyPG1oPRFLU8p3VeFEUQU5IX+DmMsYmiSPTETgKbZhPcvRzElKTl0vUqFOYsUrEI3tUT8K2ZjEylxdbvFrSNehDcPBvPD58gU+ul2XKdmcrJz5LwVWIfeCeagmaUT3iKhLcc5/CHUddqRPn4x4m7jpM14nG0ddvi3zwb99x30DW5CMfF96ef90nbOPvgezA074sMMKmhUz0nHWpbUcrlhBMJkilQyOXUduhoX9vOthNekqkUFq0abyRG13oOFPILUw3913j88cd5/vnnM23p/yKZQvwPzsyZM3nooYd46KGHWLp0KUuWLKFNmzbce++9dOzY8XyHd974NcXrb9ccZfb2UiLxJOsOe9KvG9Vy/NEzrcNSsTDl3z5M3FNCzpiXUGXVBaTT5/Qs2FV/R2HJTbdOGVr0x9r/Nnyrv8O7/Bs0Ra1wDH+IwNa5VC/+AqWjEOelj5IMeqia/hrJgAtzlyswd76cWNkBXPPeIV5xCHVBM6w9b5B2wkWRyNGt+NdNIXxA+vuhLmiKrmFXdMUdUJiz/63nlowEiBzaSGjvKsL71yImoiidtTG1G46+aU8EuZKErwrvim8JbJ2PTGvE2utP6Jv1IbR3Je657yDGI1h7/Qldk1645/6T0O5lUlE++B58q7/Dt/o7NHXb4hz+CKG9K3DNegt1XiOyLvsbyORUTnpaUkod/nBaITYVDVH2zYMkvOXkXP0KKmfts8YvIO3vAxRnGfhT1zpc1bGQDUc8jPl4NfFEKj2WkDkRP5OFCxfSt29fXn/9de69997zHc7vkkwh/t8hkxOcG5mc4ExCsQSTNh6n0helrtPAkBa5KE8pulKpFN+sOcpxT5jDriAGpZwWhRaOukJ0rW/nscnbKa3ZyIUf1xnf+ql4Fn6Eqf2lWHtLNl6imMI1+x9pBxVzx5FES/dSMf4JBJWW7Cuek9TUv3uKVCwiiY0arJITiq8KW58/o2vSE/e8dwntWoI6rxH2wfcgN9ioXvYV/o0zkan1mLuMxth6EIJCRTLgwbdhGoEtc0mFfciNDiknqN8Jda3GCPJ/feZZTCWJle0ntG+1NC9eXYqg0mJo3hdju+EoLTmIqSTB7QupXvY1yYAbfdNeWHvfhJiI45rzNpFDG9DUaYtjyD1Ejkriq4JKg/OSRxHkCiomPQOJOM6RT6Cw5FIx/gkS3jKclz6Gtm5bAtsX4Zr5Rlo75uT9+NZ+j+eHT7B2HIW95/XEkVSojWro3TiXBtkm1EoZXYsdhGJJ7HoVhXZpvt0fibPtuJdESqQ4y0CeRfsvP6M/Og0aNCCVSrF///7zHcrvkkwhfgEwf/58Vq1aRSgU4tZbb0Umk7FmzRoqKiqorKykRYsWjBgx4nyH+f/GyaIrlkihOkvR9e2aozz6/dk9uG06Je7Q2X0kE/4qyr68H4Cca15Nz4bFKo9QPlbayc6+6kUU5my8y77Gu2o8uobdcFx8P8FdS3HNeRuFJYesUU+SqC6javqriMk49gG3o63bDveCDwju+AGlowjbgDtQ5zUgsGUu1cu/JRWqRluvPebOl6OuJdmpJbzlBLYvIrRrKXGXZNurMGen28yV9gIUlhzkBjuCUn3ajreYjJMMekh4K4i7jhOrOESspm0NMSXNizfsiqFp7/S8eMJbgW/tZPxb5gIixtZDMHe9EjEWwb3wA8J7V6HKKcY+5D5IJqic9hIJTymWHtdgaH0x7llvENq7EkPLgdj6/wX/xhl4Fn4k2b+NeBxBJqdi8nNEDm3CfvF9GJr2qok1QcWkZ4gc3kzWqKfOKlDzU8526n0h25GdK9u3b2fevHlcffXVZGWdW6dFhtPJFOL/HTI5wbmTyQlOZ9NRDwt3lVNg03HMHWZk23zqOH4UHPOF47y/eD/5Nh1HXUEOVgVplGMi26ymUY6BcauPUO6Lsf24m3gSFAqBWEwkkBTxLPgA/8YZWPvegqmt5DIhuXq8Smj3Mqx9/oyp3XBiFQcpH/8kIJJ12dPIdWYqJj5NvOootr63oGvcA9eM1wgfXC9Zgw28g/CBdbjnv08qHsXS9UpMHS4l7jqOZ9EnRI5sRm50YOo4EkOLfsiUGmkufN8qaS780EZIJhCUGtS1GqPKrY/KWRuFNQ+FyYlMa0QQftyMEMUUqUiApL+KuKeUeOVhYqX7iJ7YRSoarJkXb4G+aU90DbsiU2kRk3GCO5fiXT2BhPsEqtyGWHvfJOUum+fgWfI5iCksF12Pvnk/qhd/SmDTLNR5jXBc8jDR47twzXoTmd5C1qinpBxgwpMkwz6yRj6BprCFpAkz4zU0hc2la2pG8kK7l1M59SUsTbpQd9QjJAVprN2kk6NXq+jfLJdaFi1H3SEuaVNAOJbEolNSZNefcs8iokhmNvxX+Oc//0n9+vUZMGDA+Q7ld0mmEL/AGDt2LIsWLcJqtZKdnU1ZWRkbNmzgxhtvZMyYMec7vP8X3vlhP6/N20NKPHtBds0na1i270e7WqFme1upkNEq38zaU07If0qs4hBl3zyEwuQge8zLyGvmp2LlBykf9xiCUkP2VS+itOSkd2vVhS3IuvRRYhWHqJzyIqSSUhu6s4iqaa8QPbETfZOeWPvdSvT4Dtzz3iPpr0LfrC+Wi65FptLi3zAd37oppMI+1HmNMLQZgq5BF2RKNQBx13HChzYSObaNWMkekoGf2NTIFNICJgiQjCMmTleKF1RaVDn10RQ0RVO7Neq8hggyudSKdmwb/k2zCe1ZAYKAoVkfzF2uQK634ls/Fe/K8SCmMHe9EmPbYfg3zqB66VfIdSYcQx9AbnRQ+f3zxKuOYu15A4Z2w/Au/QrfmonoGnTBMfSvknDb989LRfigOzG06A9IC6Vr1psEty/ENvAujC37/+LvXl6znmZOvf81RFHMeBD/m2QK8f8OmZzgXyOTE8CuUi/TtpSSZVThCsS5smMhtU45AU0mU3y1+ggl1WFSIrQsMNOqwIrTqKbCG+bNBfs4WBUkmUph06uo8EXwh6OU+xLEUsn06JrzkkfSXVxiMkHVtJclB5WaYjzuPkH5+CdIRfw4L3kEdZ5kXRY+sE7qoOt7M/4N06X10+jAMeReFLZaeOa/T2jvyrQ+i6ZuWyJHtuBd/i3REzuRaU0YWvbH0HJgWtQ0FQ0RObyZyNEtRI7vJF55RGpDTyMgqDQ1Ld4pxHj0jD9X2vNR12qMpqglmjpt0q4lCb+LwLb5BDbNIhlwo8yqg6XrlWjrdyZ6YheehR8SK9uPpqgVtoF3IEaDVE1/jbhLEp4zdxuDd+VYfKsnoq7VBOelj5LwVVAx8WkQRbIu+xvq3AbpIlyd34SsUX9Lz+JHjm2nYvwTWAsb0vaWl7GbjQQiCZQKGbkWLYhg0CiQyeXUcehokmvGpldRL8uAWpGxJ/utZPKCf49MIX4B8fXXX7No0SKGDBlCs2bNKCgoQKfTMWXKFL788ksmT558vkP8f+HX2pB/eiJ+a4+6GLVKrDoVi/dUMG9n+S/+/PCRLVR89xTqnAZkXfFM2kszXYwrVGSPfh6lPV9SUp31FkpbHlmjnkoXnPGKw5i7jMbU+TJ8aybhXTEWud6CbcDtaApb4F05Dt+6qQhyOaZ2l2DqcCnI5AS2zse/cToJTymCWo+uQRf0jbqhKWx+mnhbMlhN3H2cRHU5yaCHVDSAGI8BIoJcKQmz6CwoTE4UtloozFnp3XFRTBErP0hozwpCu5aS8JYjU+sxtByAse1Q5AYbwR0/UL38G5K+SrTFHbH2+TMk4rjm/IPoiV1oG3TGPuAOoid2STYsgoBj6AOoC5rhmv0WoV1LJauWfrciJmJUTn6OyJGtpxXhAJ7Fn+FbMwlztzFYul75i78XAVDIBYpsOuo6DdxyUb1MIf4rhEIh5HI5LpcLo9GI0Wg83yH97skU4v8dMjnBbyeTE0gkUyIr9ldx2BWkaZ75rOtCMBrnYGUQtVJOsVNPIJZkd6mPQE0LcySZ5EhlkLpOA65ghHWH3fj8ESoikmVYxbjHiZYfIPvyp9EUtgBOL8YtPa/H3HGUNCP+3d+Iu45h6387huZ9qF7+Db5VE1DlFOMY/jCpkJeqGa+S8JRKejM9riV6bAfuhR+S8JSgKWqJpce1qHIbED2+A9+6KYT3rwUxhTq/KfrG3dEWd0JhcqTvT0zEiLtPkPCUkgi4SIV8pGKh9Iy4TKlBpjUhN9pRWHJQ2vJPU0lPBqsJH1hLcNcyIke2gJhCU7s1pnbD0dRtS7zqCNXLvyG8dxVygx1rrz+hrd9ZGs1bPQG5zox90N2S/dj0V4kc2SK50fS9lfDB9VTNeBW5zvJ/7d13eFzV1ejh3z5t+mhGXbYsdxtXDMaNZnpvphlIgBB6IL3XjyT3y01uvkIKgdBCCWAIJEAAhxo62GBjsI0LtnEv6mU09Zyz7x9HkmVbwsJtJGu/zyOwNKOZdaZo9j577bUovejn3thp8UvUzf192yT8P9As78RJtnoNWx/+IWY4zpFf/QODKsuYMawEx7HxmQaulORsh4Kgj6GlYWqa09QnswyIBpkyNE5xZPft2/o7NS7Yt9REvJ9YtWoVN9xwAz/4wQ844YQT0DRvUrVixQq++93vcvHFF/PFL34xz1EeOD3dI376+IqOPcSX3PkOOadn74nW5W9S+9RvCAybTMn5P+6o3p2tWcu2OT8BJGUX/xKrbBipdR9S8w+vD3nJrB9hlQ2j/oU7aF3yEr5B4yk+69s4rY3UPXcrudp1BEcdSfzEa5GOTePrD5Bc/gbCChI5/ExvIhyKkV6/mNbFL5H85F1kNoUwffgqx+GvHIdVMQqrdAhaMNajs5huJkmudj2ZrV4aWnr9R7itjV4q2uBDCY07vu0sv6B16b9pnvcEduMWrPIRxNr2rje/8zea5j2BZvmJn3gdwdFH0fjafbQs+GfH4EIYFjV//0+yW1YQm/klotMuwM20UvP4z8lsXuEVWxl/QkdcTe8+TuNr9xE+7AwKT75xh2PRBLi7eapUdfTP1tTUxHe/+10++OADjjzySFzXZfz48Zx55plUVlbmO7w+S03E9w81Jvh81Jhgz0kp+fvCjby3tp6c44IUhPw6H21oxNR1Epksja1ZhJBUN9tEfNDQ2Mzmh36A3VJD2SW/wlcx0rstx6b22f8huex1ojNmEzvmi8hsipon/y/ptR8QnTKL2HFfIrX6Peqe/V+klBSddjOB4VNpfP1+WhY8gx4ppvDEa/EPP4LEork0vf0obqoZ/9DDKZh2Ib6qCTgtdbQueZnWj1/t2KpmFlfhr5qINWA0vrIRGPGKHu0Xl66D3biVbPWnZDYvJ7N+MdltawCJUVBGcOxxhCeciBkfQGbzCprnPUFy5TsIy+91iZkyi8zm5dS/eAd2/UZvz/iJ15Hd+gm1z/4PMpOk8OQbCE042Wtx9tr9WBUjKb3gp16m3YJ/0vDSn/EPnkTJ+T/pOCGQq9/E1oe/j9AMxl/3W4YNHkJFQYjSWIDBhSGa01k+2ZbA79MwNZ0ZI4pYuaWF4pBFyG8QDRgcO6oMv6lWxbujxgX7npqI9xO1tbVMnjy5o8ff+++/zwsvvMBHH33EwIED+fWvf41pfv5WT/3FdQ+8v9uV8J21LPoX9c//0dsHfs53O6qU5uo3sW3OTzr6ZfurJpCr20D1E7/Abqqh8OTrCR96Gq1LXqH+xdsRmk785BsIjj6KlveepOntOSAE0akXEJ06C7txC01vP0pyxdug6QQPOYrIoafiGzQeHJv0ug9JrVlAet1H5OrWd8QnrCBGQSl6KI7mC3akpks7h8ymcJJN2C013qS7jR4uxFc1gcDQwwkMOwI9WECubiOJj14gsfgl3FQzVvlICo6cjX/4FFLL36Dh1ftxWmq89PoTrsZuqaPu2f8hV7veazNy3FVkt62m5slf4WZaKT7r2wRHHbl9ZaB+I8Vnf5fQ6KO2P7YLn6X+xdu9nqxnfXuXCrBjKyIs29LCZ/0FE8B3TlXV0bvzta99jWw2y89//nM+/vhj1q9fz9KlS7Esi29961sUFhbmO8Q+SU3E9w81Jvh81Jhgz7mu5KaH3seW3r8bExmiIQvHkbSms6yqTRI0BIm0Tcr2fidoQq6plk8e+L7XSeXSX3UUFZWu41X3/vB5wpNOo/DkGwFoePkuWhY+g3/IYRSf811vgv70/yO7eUXHdrVc7QbqX7iNXM1a/IMPJX7C1RixCloWPuttV0s2YpYMITLpdIJjZ6L7w+RqN5BcPZ/02kVkNn3spZ0DCA0jWoIeKUILRL3aMZoOroubS3t7xBN12M01Xt9xAN3EN2A0/sGHEhwxFbN0GDKXIbniTVoWzSW7eQXCFyJy+FlEp5yLm07Q+Op9Xip9rJzCk2/EVzmu7aTCPzGLqig+93sYBeXU/ev3JJe97u2JP+MbCMOi6c2HaXr7EQIjp1Nyzvc6Mv1kczWbHvo+Mpeh8rLfEB84iJHlEQ4bWESr7VBREKA06uPjzc3omsDQBEOKgqyrb6UgaFAU9FMYtpg+rJiwX73uu6PGBfuemoj3I9dccw3JZJJly5Yxbtw4xo0bx4QJE5g5cyaRSIRUKkUgoCpDduXcP77JhxubPvfvNc//Ow3/vpfQuOO9D5K2CaPdXEP1Yz8j17iF4jO/RWjMsTjpBLVP/5b0pwsIjT+RwpNvxGmtp+6Z/yGzeTmBEVMpPPkrIB0a/v0XkiveRAvFKJh+EeFDT8NJ1NGy4J8klryCzLSiR0sIjT6awKjp+AYcgtB0nHSC7LbV5GrWYjdsxm6uxWltxM20Ip2c175MN9B8QbRAFCNS7KWhFVdhlQ1Dj3gF6Oz6jSRXzSO54i2yWz4BoREcOZ3I5LOwKseTXv0ejW8+RK56DVbZcOInXINVMcpLqZ/3BHooRtHpX8c/9HCvKNsr92BEiymZ9WOs0qFka9ZR/fgtuOmEVxm1rUUZQMuHz1P/rz8QGDHNa+fSxRn89rXx3f0F+9WsCVw2repzP6/9wfe//32OOOIILrroIgAymQzLli3jtttuw7ZtbrvtNoLBYJ6j7HvURHz/UGOCz0+NCXqmoTXLwnUNDC4KMqLMS8P9xpyFZG2XmpY0G+qTxEMWqZwkYmmsr09iuy6tOQgKcCTkgAITqqu3sO3h7yNdl/JL/y9msdfLXUpJ4+v30/zu4wRGTqf4rO+gWX7v8+7F29FDhZSc+32s8hE0vf0oTe88ih6IUnjyjfhHTPVWwt98GDedIDjmWGJHXYJRUEZi6atei9HqNaCbBIZNJjj6KALDJqMHokjXIVe7nmz1p522q9XjplqQuQzSdRCahjB8aIEIeiiOUVCKWTgQs3QYVslghG7iZlpJffpBW0eVechcGqOwkshhZxCecBJuOkHTu4+R+OhFhG4SnX6htzK+aRn1z/8Ru3ErkcnnEJt5JU5TNTVP/V9ydRuJHXs50WkXQvvJio9eIDThZIpOu7nTeKqW6kd+iJNqpuoL/4m/fARBE6YOKcLRdGJBk9ElERK5HG+tqsVvGpSFfUwbUUzY0vlwQyMDCwNMH1LE0JIwumpT1i01Ltj31ES8H8lkMixevJiWlhYqKyvZtm0bNTU1rFq1ijlz5jBhwgTOPPPMjjdYf/Xr55bxr6VbOW1cOT84w6tA/lmV1Hen6Z3HaHz9AULjT6Do9K93fHg4qRZq/v5LMhs/JnbsFUSnXwTSpentOTS9NQezqJLic76LWTyYlvefovGNh0DTiB11KZHJZ5PdupqG1x8gs/4jr2/45LOITDodYfpIrnyH5MevkVq7CFwb4QvhrxzrVUctH4FZPBg9XNiz1PR0glzdBrLb1pDZvJz0+iU4LTUAWOUjCI05ltDY49H8YVqXv07z/H+Qq1mLESun4OgvEBo7k9SaBTS89Gfsxq2Exp9E/MRroK2NS+qTdwmMmErRmd9C94dJffoBNU/+XzTTR+lFt2CVDe+IJfHRC9TN/QP+oYdTev5PEEb3Z66Dlk4yu2uruXYa8G21It6tl19+mauuuorrr7+eK664gkGDBnVcdtJJJ3Hrrbcyfvz4PEbYN6mJ+P6hxgSfnxoT7F5Da5Yb/vo+25ozGJrgZ2eN5djRpby9qpanF21m8eYGSsJ+gqbOog0NBCydzQ0pdA2khFZ719vM1W1g6yM/RCCouPRX6EXb/7Y2L/gnDS/fhVU23GvXFS4ks2UlNU/+GidRR+zYK4lOPY/stjXUzf09ueo1BIZPIX7ideiBCE3zHqdlwT+RuSzBUTOITJmFNWA0ueo1tC55heSKN9uKtQqssmH4KsfhGzAKs2QoZnzAZ36mtpOO7aWm16wlu2UlmY0fk9my0uuoEoh6tWnGH49v4FhyNZ/S/N6TtH78GiCITDqVghmXIKVL46t/ofXjVzHiFRSd9jV8g8bTuvhF6l/6M8L0U3zWdwgMPQw3naDmqd+QXvsBBUdeQsHRX+gYu9gttVQ/4rV5HXfVr5AlIykOWQQCJiVhHxnbwdJ1KguD+A2NTE5SFvUjBAwqCjKiNEx9IkNpxM/wkjCG0bOe6f2VGhfse2oi3k8tWbKEOXPmUF1dTUVFBUcddRTFxcWccsopVFdXd+wX629+/dwy7nh9Tcf3Nxw7bIfJ+P97fjmN3bQv+yyNb8+h6Y2/Ehw7k+Izv9UxGZd2ltq5vyP58WuExh5H4WlfRTN9pNYuou6Z/8ZJtxA75nKiU87Dbq6h4aU/k1r9HkbhQOIzv0Rg5HQyG5fS9M7fSH+6AHST0CFHE5pwEv6qCchsitSnH5D+dCHpjUux6zd1xCRMH3qkBD1YsD01HYF0bWQ2iZNsxmmpw023dPyOForhHzgW/5BJBIYfgREtJVe7gcSSl0gsftlLgyuqIjrtAkJjZ5Kr20DDq38h/elCr6LrKTcSGHwoqTULqJv7O5xkM/GZVxKZch5AR8sys7iK0gt/hhHd3iKrZdFc6p+/rdMkfHvxuZ7QNYF0JS7eHvKuWtcpO/rggw/429/+Rm1tLQMHDmTmzJnous5FF13E1q1b8x1en6Qm4vuHGhPsHTUm6NqrK6r5+T+XMjAWZFtzislVcX594aGA19bs5Y+3Muf9jTSnsoR9BuMGRnlx6VZa0g4hn8a25hyda41bArIScrUb2DrnhyCh7JL/05GmDpBcNY/ap3+L5gtRcv5P8FWMxEknqJ/7e5Ir38Y/eKLXOzxUSMuCp2l86xGkkyM6+RyiMy4G16H5/adILHwWN9OKVTac8MRTCI45Fs0fIrvlE6+LyroPyW75BGlvT03Xw4Xo4TiaP+IVmtV0cB3cXMZLTW+tx2mp215BXTfwlY3AN3gigaGH4xs4xktNX/k2iQ9fILPpY4TpIzzxFKJTz0fzR2h+/0ma330c6doUTL2A6IyLkbk09c/fRnLl2/iqJlJ81rcxIkXk6jdR/cQvsRu3UnTqTYQnntzxONnN1Wx75Me4yUbGX/WfDDzkUITUCPlMGjNZ4n6LipgPITSifoOiiI+WtE15NICpw6GVcZozNn5DZ9yAKAFLV5PwHlDjgn1LTcT7oTfffJNLL72Ua665hi996UsMHjy447IzzjiDX/ziFxxxRP8cJx7323+zti7Z8f2QoiCvfvf4ju+/MecDnly0eY9uu+ndv9H42v0ERs2g5OzvdZx5llLS/M5jNL7xoHcWfNaPMArKcJJN1D3/R6//9oDRXkXR4iqSq9+j4ZV7sOs3YlWMJnb0ZfiHHk6ubgMtC5+ldem/kdkkeriQ4KgZBIZPxTdoPJrpw0m1kK1eg123gVzDFpyWWpxkEzKbwm3bJyZ0A80KeBVSw4VeanpRJVbp0LbUdEl262pSa94nueItcjVrQWgERkwlctgZ+Icchl23kca355Bc9jqaL0jBkZcQmXwWbi5Dwyv30Lr4RW8v2NnfxiobjpvLUP/Cn2hd8jKBEdMoPuvbaL7t6U3trd4Cw6d46eifcxJu6oKfnzOehmSWeNCiIZlV/cK7kUwmWbBgAdlsloEDB1JdXU1tbS0rV67k4YcfZtq0aZx11lmce+65+Q61T1IT8f1DjQn2nBoTdG9NTQvXPbgAAWRslyunD+HqY4cB3mf3XW+s4c1PaklkbIYUBvGZGrUtGT6tT6AhSKRy1KdyuC5YGmRtaG8OmqvbwLY5P0baOUov/kVHATfwqn9XP/F/cFobKDrlK4QnnoyUksRHL9Dwyt0AXrvPSafhJBpofP0BWpe8gvAFiU45j+gR54DQaF36b1oWzSVX/SloBv4hhxIcOYPAsMMxoqXbU9Nr1mLXb8ZurvY6qaQTyFwa6boIoSFMH5o/jB6KoUdLMeMDMEsGYxUPRhgmTqKB1NqFpD6ZR2rN+0g7i1E4kMihpxKaeApCN0gsep6meX/DbW0kMGoG8eO+jBErJ7nsdepfvhM309qx8CA0neQn86h95r8RukHJrB/hH7R9pdVr9fYTZCbJwIt/Qcmw0ZQWBBkUD7GtOUlNIgsIfIZgeEkEKQXxsEE86KOswM9JY8oYUhwmYzuYmqZ6he+GGhfsP2oi3g/97Gc/Y8iQIXz5y18GIJvNsmLFCu666y6qq6u58847iUajeY4yP3a3It45PT3i02nNOEh2vxe5XfP7T9Pw8p1etc9ZP9phsplcNd/70BGCojO/SXDENKSUtH78Kg0v34WbSRKdOouCGbMRhkli8Us0vT0Hp7kGq2w4kSnnETrkaKTrklo1j9blb5Bes9A7262b+CpG4RswGqtsOGbxIIxYRUfLj+5IJ4fdVE2ufiPZbWu8NLRNy3DTCUDgG3gIwUOOJjTmWLRgbHurlE/mIUyfV6Bl2gVe2vrSf9Pw73txU81Ep55P7OjLEIZFrn4TNU/+X3I16yg46hIKjrq0U6s0SeMbf6X5nUcJjj6K4rO/01GB/vOwdMEj181QE+8emDVrFpFIhGXLljFmzBjGjRvHxIkTOeGEE/D5fGrf6F5SE/H9Q40J9pwaE3y2t1fV8uxHmxlZGuHyGVXoupfR1pzKccvTS6gqCrG1KUUq63DxEZV8sKGJopBFXSJDXSJDQyrHhxvqyeQcdEAiQTok0pCo3sLaR3+Cm2qm9Pyf4h88seN+nWQTtU//hvS6jzr6iGumn1zjVur/9QfS6z7EN3AMhSffiFU2jGz1pzS+8VdSq+Z5nVQmnUbk8LPQoyVeavrHr5Fc8RZ2k1d41ohV4Kscg1U+EqtkMEZ8IHo43vH52xUpJW6yiVzDJnI168huXUVm0/KOQrDeAsCRBMfMxDfwENxkIy0fzKXlg2dxk034qiYSO+Zy/JVjyDVspv7FP5P+dAFWxUhvsaFkiNcR5o0HaZ73BFb5CErO+xFGwfbsuOy21Wx77D9AupTP/iVlg4dTWhCgqjCEpmmkclk21mco8Ftsa0kRsQSHDy6krtXmhEPKGFwcxtAFkwfH1Qp4D6lxwf6jJuL90C233MLHH3/MH/7wB5YuXcqmTZtYsmQJQgi++c1vUlZWlu8Q86qrPeIAl98zjzc+qd3l+lOGxPlwYxM52+3RhDyx+GXq5v4Oq2wYpRf+B3po++Qw17CF2qd+TXbbaiKTzyZ+3FUIw8JJNtHw73toXfIKeqSY2MwrCY2dCa5DYskrNM//B3b9RrRgjPCEEwmNPxGruAo3lyGzfjHpdR+S3vgx2erV26udApovhBYsQPOFvAmu8PZ/yWwaJ9WMm2yi82kGo7DS22teNYHAkMPQQzGc1kZaP36NxEcvkKtdh+aPEDnsDCJHnIMeLCCzaTkNr9xNZvNyrIpRFJ16E1bZcO8kw5KXqX/xDoRhUXzmtwgM3/63SLoO9c/fRuKjFwhPPIXCU2/apTp6T+kCvnWK2g++OytXruSss85i5cqVAMybN4/nn3+eJUuWMHLkSP7zP/8zzxH2fWoivn+oMcGeU2OCPeM4Lr99YTlralpZW9vK4KIQh1bGOHVcGesakjQksggB1c1plmxsYsU2r2K3XxcUFfiprk+yuSVNfU0tWx79GbnGzRSf+W1CY47puA/pOh19xM2iKorP+S5W6dC2z89XaPj3PbjpBOGJpxA75gvooTjZbatpevdxkiveAiAwbDLhiScTGDYFdINc7XrSaxeR3rCYzKbluMnG7QelG+jBOFogjGb4vNR06XakprvJRqSd7bi65o9gDRiFf9AE/EMmYZUNAylJr11EYvFLJFe+A65NYNgRRKdfiH/QeK9w2zuP0bzgaYRuEjvmi0QOPwuh6eQat1L7z9+S3byC8KGnUXjSdTtkwKXWLqLmH/+J5gtTPvuXRIsr8VkQD1kMjIcYN7CA+kSWBevrCftMbMdFkw5DSgr4tC7J8aOKOWFsBamcw4zhRWoi3gNqXLB/qYl4P5TNZvn617/OggULmDZtGgBjxozhzDPP3CElTdlRdwXbjhlZzDdOGsXfF27k0ffWY7td/PJOkqvfo/apX6MFY5Re+B9Yxdsrd0s7S8Orf/FaeRRXUXTmt/CVexPI9MalNLx8F9mtqzBLhxI75osEhk8FJOlPP6Dlg+dIrX4PpItZMsRLQRt+BFb5CISmI50cufpN5Go3YDdtbUtNb8bNJpF2DpAIrXNqehyjoByzcABm8WA0XxApJXbjFlJrFpD65F3S6xeDdLHKRxKedBqhsTPRTD/ZmrU0vvkQqZXvoIVixI+9gtCEkxBCw2ltpP6FP3n7wQaNp/is72BEizseAzeT9IqzfLqAghmzKTjmi3v8gSkAn6nxs7PGqZT03ViyZAlf+9rX+PnPf87RRx/d8ZgvXryY73znO9x8882cffbZeY6yb1MT8f1DjQn2nBoT7LnGZJaH3/2UT7a1MmNEMZsb05wxoYJR5REvoyuZ46MNDcxdsok121pJORJwuWxqFc8u3sLWphR1LRkaG1vY0l689biriE49f4fPvNSnH1D37P+01Y35ItEpszo6oTS9+RAtHzyH0E0iR5xLdMp56IEIdnM1LR/MpXXxSzitDQhfiOCIqQRGTMM/ZBK6P4yUEidRT652XVsnlRqvk0o64VVNl8721HRfCD0YQ4+WYMYrvKKv0RKEELjZFOn1H5FaNZ/kJ+/iJpvQ/BFC444nctgZmEWVuNmU11Zt3hO46QSh8ScQm3klRrhwx7R7oVF06k1eph107LFPLH6Zun/9HrOwktKLfk5RUTEa4GgQC5iE/RZHDC2kKh5kS2OSj7e0UBLxtS2SCMKWhqnrzBhRwpShhTSlc9S2ZCgvCDC8JKQm5d1Q44L9S03E+6nNmzdTV1dHPB4nEolQUFCQ75D6hIfnrefeN9ewqqa142edW2AtWNfA9x//cIfLu5PZspLqJ36BtHOUnPt9AkMP3+HyjoJmrY1Ep19IwYzZaKYPKV2Sy16n8Y2HsBu3YJYOo2DaBQRHH4XQDZzWBlqXv0ly+RtkNi0H6XpV0weOwaoYhVU2DLNoEEZB6W7TvKV0cVrqydVvJFezlsyWT8hs+hin2auabhRWEhx9FKExx2KVDEZKSWbTMprn/53UJ+8irCDRqbOITjkPzQogpSS57DXqX7oTN5vcYUDRzm7aRvXjvyBXt4HCU75CZNJpPX5+unLMyGJOH1/Bz55ajO2CocGj1x+pJuPduOeee3jvvfc49dRTGT9+PJWVlQQCAX7/+9+zZMkS7rzzznyH2Kepifj+ocYEeyeZTPLpp58CUFlZqcYEn8PWphR/e38jjivxmxqXTq2iIOit4rakc6zc1sJLS7ewpTGDrrnEAj5OHlvOX+etYVV1KxvrU0gJup3l06f/l8TyN7wssFNu3OEzeoe6MRWjKTr9qx1F3nL1m2h8468kl7+BsAJEJp1OZPI5GNFipOuQXruI1mWvk1o1v60Aq1c13RpwCL7y4ZjFgzEKB6L7w7s9XjeTJNew2dtbvm01mc3LyW5dBa6DsAIEhh1BaMwxBIZN8faPJ5to+eA5Whb8EzfVjH/YZOLHXumtnoOXav/8baTXfuAVajvjGxgFpRhAYQik7bLmxQepf+dv+AcfSsl5PyQYCOM3IedAOKARtAyClsXRo0rw6zqjKiIMLwmTztm8saqWdNZBF4KqohBHDI4T9Bs8uXATiXSOSNDkwsMHEQ99vtoz/YkaF+w/aiLeT919991ce+21zJ07l9NOOw0ppTob+Dk8PG89c5ds4fTxFTv0oV6wroFL73qXbE+WxfGqflY//gtyteuJH/clIlNm7fA8OOkEDS/fTeuSlzBiFRSedH1H+rZ0bFo/fpWmdx/Hrt+IHi4ifOiphCee1FFt3Ek2eSlo6z8is3EZuboNbE81F2ihAvRAW2q6YQDCS03PeanpTqIB3O2p7Hq4CN/AMfirxuMfchhm4UDA+2BuXf4GiUVzyW5dheYPE5l8NpHJ56AHvL6rubqN1L/0Z9JrP2gbRHwNq2TH1Zb0uo+oeerX4DoUn/fDHfqH94Rgx/36pi6Yc90M7nhtNS9+vK3j5yePLeOuK9RcqDv33Xcfb775JsXFxcTjcWpra3n99df51a9+xYknnpjv8Po0NRHfP9SYYO84jkMkEmHmzJnMnTtXjQk+p5qWNHWtWSqi/o5JuJSSldtasAyNrU1pnvxgPZrQGVkWJh7ysXprEx9sbKCuJYXjgCEgnUrz4dyHaHjnUXyVYyk594fo4e0njb2T2a9T/9KfcTOtRKecR8GRl3TUe8nWrKXpncdILn8TgOCoGYQnnY5/8ESE0JCu47UhXbuIzMalZLZ8gsymOm5fWAH0UAzNH0GYXsVxKV1kLoubSeC0NiIz2xcahOHDKh+Or3Ic/sGH4q8chzDMtpPyy0l8+C9al70OTo7A8CkUzJiNb+Ah3rHYOZrf+wdNbz8KmkZ85pcIH3b6DnvUw24rG578L5o/eY/ooacRO/kGNN2gKKChC4kjJbqhE/NZDCsNM3FQnKjf4JCKAhwJrusS8RvUJbI0pXMMLwkzrCTMRxsb+MubaymN+qlpyXDDzOFMHBTbb6+Pg4EaF+wfaiLeT9XU1FBaWsoXv/hFHnzwwXyHc9C47d+r+O8XVuB+jreOm01R9+z/klz5NsFDjqHotK/uUMQNvH1R9S/egV2/kcCwI4gd/+WOdHYpXVKr36dl4TOkP/0AAF/VeEKjjyYwchpGpHiH+8rVrCNXv6ktNb0OJ9WMzCaRtg1I0A00048WiKCHCr0z0/EBWCWDd9jP7maSpD5dSHLFW6RWzUfaGcziKiKHn0Vo3Alolh8AJ9VM09uP0rLwGYTp9/aDHXbGDqvgUkqa5/+dxtfuxywcSMn5P+mY5O/OmPIIn9a1krNdTEPjSzOG8M6aOsqifq6fOZzJg+Oc+8c3+XBjU8fvHFpZwFM3H93zJ6kfePrpp3n99dcxDIMbb7wRKSXz58+nvr6e6upqDj/8cM4666x8h9nnqYn4/qHGBHvvwgsv5IknnqCxsVGtiO+ldM4hlbXZ1JgmaOnYtuSjTQ2MKPVOTH+4oapybSEAAEojSURBVJ7qljQrt7TQlMxi4+A6Go2traxvkrQse526536H5g9RfO738VeO2+H2nWQTDa/+hdbFL6GHC4kdczmh8Sd0fK7mGreSWPgsicUv4aZb0CPFhA45hsCoI/ENGLW9hap0sRu2kKvbsD01PdnUkZqOdEFoCMPyqqYHC9Cjxd6WteIqzMKBnW5Lkt22muTKd0gufx27YQvCChAaexyRyWd3GrNIkivfpvHV+7AbtxAcdSTxE6/FiJbscIzZmrXUPfkrso3bGHDStcSOOJOcI4j4IB7xkbVd0jlvMj5uQJSLJg8i6DM5pDxKUcRHdUsaU9MojfrRNYHrSoQAIQSL1jfy8Ly1RAImrRmbK2YMYewA9ZrfmRoX7H9qIt6PHX/88cyfP5+mpiYMw8h3OH3agnUNvLumjnjQ4pZ/LiVnuxi6QACOK3F281aSUtI87wkaX38AI15ByTnf70jb6riOk6P5/X/S9PYcZC5NaNwJFBx1CWasvOM6ucattC79N60fv4ZdvxEAs2QI/sGH4qsci69iJHqkZI9WOpx0oq1C6rKOFXZcGy0QJTj6KMLjT8QaMLrjtt1MkuYFT9M8/x/IbIrwhJOIHXv5DpN58Cbqdc/dSmrVfIKjj6Lo9K/vciKiO+3V0IGOx7+rfeA77+/vvJ1Agdtuu42HH36Yb37zm8ydO5c333yTKVOm8J3vfIdJkyblO7yDipqI7x9qTLD3nnjiCS688ELuuusurrnmmnyH02elcw5LNjdhO5Kc7VDfmqMxlSMeNIgGLHQhkK7LqpoETckMK7c0U9+apSmVobYlS1MGckC2+lNq/vEr7KZtxI69gui083epaJ7ZtJz6l+8iu2UFZnEVBUd/geCoGds7j9hZkp+8S+vSf5P69IOOz2x/1UT8VeOxBhzS0YLs85KOTa5uA5nNK8hsXEp63Yc4iXoQGv6qCYTGHU9w9FEdq/VSStJrP6Dxjb+S3bISs7iK+AnX7LItr2O/+Et/RveFGHLh9xHl4wnqkJMQDemYQtKUdQlbJhLB5KoC/s/5k9CEIBLYfiyNySzJrEMsaBK0to9zE+kcj87fQH0yS3HEx+wpg3a4XFHjggNFTcT7sTvuuIMbb7yRp556inPOOSff4fRZC9Y18IW7vXR0Q9dwXRfHBU3AtccMIxIweebDzSzb2rLb20qvX0ztP3+Lk2omfuyVRKacu8sHr5Nsoumdx2j54DlwHUJjZxKdOgurdPvEXUpJrnY9qdXvkfp0IdnNyzsqnWq+EEZRJUasHCNc1FY1Pdi2F00gnVxbanoLTqIep3kbubpNOIm6tlsXmKVDCQyZRGD4Efgqx+2wuu0kGmhe+AyJhc/gZloJjJhG7NjLO/aydZZau8grQJNqJn7cl4lMPrvHJwliQZN7rpzSMeFesK6BS+98h5wjMbtoV9bddgIFzjzzTG644YaOgivpdJrf/e533HnnnXzzm9/k5ptvznOEBw81Ed8/1Jhg72UyGeLxOMcccwzPP/98vsPps+oSGVZXJ4iFLFZVt1DfmqMkZLGqJsGhg2JMrIyRSGV5dvEWNtS3sHxLC3WtGdKZLA2tICQkAR9AppUNc/9AcsWb+AcfStEZ39yhsCm0rTCveIvGN/6KXb8Rs7iK6NQLCI09doc95m6mldTq90l9uoD02g+3f6ZrOkasArNwAHqkpC01PYxm+kDoIB2kncVNt+IkG7Gba7xV9IZNHV1YtGAMf9UEAsMmExg+BT24fXVZug7JT96led4TZLesRI+UEDv6UkLjT9ylE4qTaqb++dtIrngL/+BJVJz9bUpLimjOeF1pXAcsEywDbAcKAgbBgMXQohB/uOwIDH37eKmuJc28T+uwJRQGLaYMLcRnbL+/rO3SnM4S81sYRvct2/orNS44MNREvB9rbm6moKCASy+9lIcffjjf4fRZndPR26eQ7e8cQxP84tzx/OzpJdi7WxZv4ySbqPvXH0h98i6+QeMpOuMbO6x6t7Nb6mie/3cSHz6PzKXxDRpPZNLp3tlwY8eiI9LOka1eQ3brJ2Rr1pGr34jTVI2dqNuhndkOhIYeiqFHSjqqpltlw/FVjETbqaCLlJLMhiW0fPgvksvfAtchMGq6tx+sfNeWYW42ReNr99Gy8FmMwkpKzvkuVtnwHj0+7c6bNIBbLzms4/sf/WMxD89b3/H9ZdOq+NWsCZ/rNvurO+64g6VLl/LTn/6U0tLt/Vo/+eQTfvnLX/Lb3/5WtTDaR9REfP9QY4J947LLLuORRx6hoaGBWCyW73D6pGTWZsmmJiRQ35qhriVL1nFJZmxGlEUYP6AAhGR9TQvPL93Gwg0NRC2d2qYk25ptLAOaMmABfh+0pCU1Hz5Pwyt3ITSD+EnXERp3wi4nraXr0LrsdZrffdxrJRqKEZ54KuGJJ+8yhpBS4jRXk9m8kmzNp+Rq12M3bsVprsHNdF9sVlhBjEgxRqwMs2gQZulQfBWjMOIDdonHbqmjdcnLtCz6F05zNUasnOi0CwiPP6nLFfjkqnnUP38bTrKZ2DFfpGTa+ZRFLRzp0pr16u6kcuDTIRYy8evQagsGxAMcP7KIy2YM36Hg2oJ19Sxc10A0YJDMOpw3aSDxkO9zPpv9lxoXHBhqIt7PnXrqqbzxxhvU19fj9/vzHU6f1L4inrNddF3DcdyOVHQNOGpkcZf9xz+LlJLWxS9S//LdIB1iR3+ByBHndtlH20knSHz4PIkPnsNu2ua1KBk1g9Doo/EPPvQzU86klMhsCjebQjo5AIRmICy/t0ouuj9LLKVLdutqkivfonXZGzhN2xBWkPCEE4kcfla3e7xTq9+n7oU/4TTXEJl8NrGZV3pn3j+n9rZx7aveP/7HYh7qNBE/ZWwZd6qCbD2ybds2fvrTn5LJZLjyyis56qijcF2XlpYWjjjiCFasWEEgEMh3mAcFNRHfP9SYYN94+umnOffcc7n99tu54YYb8h1On9WasUlmbfymzpKNTcxfW8+I0hClUT8lIR+ulHy4oYGV25pZtqWZnG2zrTlN2K+xpT5DNgclUQ1XarQkbRozkG7YTPWzt5Le9DH+oZMpPOXGLk/SSylJf7qQloXPkFqzAKSLr3IswUOOIThyxi4r6rv8vpPDzSTb2pe5Xvsyw2or6vrZKexOayPJVfNILn+T9LoPvfuumkj08LMIjJy2wxjGAGzAaW2g/uW7SC57HbNkCMVnfotI2TDiEY0RxVGka/NJbZJUxiXrQGlYozASIOQzOKIqRlVJGFM3OHRQnMp4oGNV/MMNjXywoQFLF+RsyXmHDaAgqCbiPaXGBQeGmoj3c+3V0x9//HEuuOCCfIfTZ7XvEZ8+rIgXl27lzjfWIOX2/tU/fWoJzuep4NbGbq6h/sXbSa2aj1kyhMKTrsdf1fUqr5Qu6XUf0br0FZIr30FmUwjT17YXbAK+gWMwS4ft0aQXvLPtufpNZDYtJ7NhMem1i3BaG7z9YIMPbdsPdiSa2fUJHbtpGw2v3ENy5dsYhZUUnf7VXQrQfB7t/cEfumY6kwfHWbCugdl3vtOReWAZGo9cO121KeuhZDLJbbfdxpw5c6isrGTIkCEsXbqUqVOn8qtf/Srf4R001ER8/1Bjgn2jvXr61KlTefXVV/MdTp+XtV02NSTZ1JhEExrRgMHIsgj1iSzvrq7lo42NbGlI4EpByKdh6QbLNjWh6TaaZpLO2bg4NLbaaAJyGYcN7zxL9RsPgusSnX4h0annd/u5bjfXeHVjlr5Krs47Ud1eN8Y/aDzWgNEY4cI9Pj4n2URmy8qOPeLZLasAiRErJzhmJuHxJ3R7Uj7k2NQteo7qNx5C2hkKZswmPv1Con4TvyEYWRaiqiTKys2NrKtL4koIWTpjKuMMjAcZVxFhQGGIbU0ZNCEojfqIBkzGVEQRQtCUzLFgXR1Z26Us6mfCwAIMY9fFDKV7alyw/6mJeD/X1NRELBbj/PPP54knnsh3OH1e5/3imvDS0i+bVsXD89bz4ycXsydvKSklqZXvUP/KXTjNNQRHHUls5pWfWVVc2jnS6xaRXP0+6bUfYDds9i4QGkZ8AGZRJUZBGXq4CD0YRbOCYJhe+y/Hxs2lcVMtOK312E012A2byNVuQNoZALRgAf6qiQSGH+HtBwtEu43FTSdoevdxmt9/CqFpFEy/mOjU8/eoOMyY8giWofHRRi/tTxfwrVNGc9PxXvr7j9vS07u6TOlaV22K5s6dSzab5ZBDDmHQoEEEgz0rnqfsnpqI7x9qTLDvXHHFFTz44INs3bpVpZ7upSWbmmhOedlmQUtndHmElrRNxGfw3vo6Vm5ppjmVoyaRJWyZJDNpmloz5ByHDQ1pDM3FtW2SNrgSXCCXhk21NdS9cjfJFW+hR0uIHXsFobEzPzOLLVe7geTqeaTWLCSzaRm0ZcFpwRhm8SDMWAV6tG2PuC/URfsyb4+401KL3bCVXN16rzgbgKbjqxiNf+hhBEdMxSwd1m29FyklqVXzaX79PjK1GwgMOYzik64nXFJJ0AexkEVlLMjhVYXops4LH21kQ2MapMRvGpw6voKqoiDTh5UwZkABK7Y2k7Vd/KZOYzLHxEEFHXvBkxmbjO0Q8ZtoAjRN7QXvCTUuOHA+a0ygygf2AwUFBZx++uk8+eSTZLNZLMva/S8p3Xp3jXf21dsvLvn3iuqOAmETBxbs0EKrp4QQBEcfiX/Y4TTP/wfN8/9O8pN3CY0/gYIZszHjFbv+jmESGD6FwPApgFdALbN5Gdmtq8nWrsOu30x67Qdee5LPounefrD4AMKHnurtER8wCqOwcrdF1dxMK80L/knL/H/gZpKExh1H7NgrdmlR8nnMHFXCyePKO7YCmIbG9GFFHZeff3glTyzc2OVlSteEECSTyR0+VE8//fSOf/fmE7KKoux7l112GQ8++CCPPvooX/va1/IdTp+WSOeI+L2TzrWJFM9+mCCZcwhZOhMrI9S3+MnkXIaVhBDSpSFh4Lguy7ckyWSzlBWH2VCXQEivQFnaBqFBJFqCdt4PSa//iIZX7qHumf+med4TFBx5CcHRR3Y5ITeLB1FQPIiCaRci7RyZravIbl1JtvpT7LqNJFfNx0027vaYtEAUI1aGf/ChmCVD8VWMwKoY1W02XLv2lPnGtx4mu3kFZuFAys//CYWjpyFdgeOCqYHmuiTSOWqSWYYXRzAMHb9l4tcEPlOjNe1Q3Zxh7pKt+EydeNDi09pWWrMOBQEDq1PBtqDPIGCpVfDPS40Lege1It5PPPDAA1x55ZU88sgjXHLJJfkOp0/rvF9cCLDd7ZeNrYjw8ZbdV07fHae10auavmguuA7BQ44mesS5+AaM/ly34+0PT+KkWtr6iHtnx9F0NCvgVU0NRD7zDHtX7OYaWhY+Q8sHc5HZJIHhU4gdc/ku7dj2RHnUx7s/OmmHrQArtrbsUA2982UqLb17zzzzDI899hjBYJCCggKGDBnCaaedxtChQwF49tlnOe644wiFQnmO9OCiVsT3DzUm2Hds26akpITx48fzxhtv5DucPm1LY4rVNQkADE1jxdZmKmJ+1lQnmTIkRs6VLNlUj2XorKlJkcnmaE1naU7l2NbSStTyUdfaiiYEhiGobXGRDjTlINt2H1K6JJe9QeNbD2PXb8IorCQ6dRahscd97q1o0sm19RFv9TqtdPQRN9F8YfRgdJdisLu/TZvkirdonv93sttWe1XTj5xNeMJJBE0D3QXdgIq4SdBnksxKhhSFGFwcpjhssb4+waINjfgtk6q4H1fqDC8L09ia5bjRpRw1soTmdA7HkfgMjcUbG2lIZhlfGaMyrlZtPw81LjjwVGq6QmtrK2VlZUyYMIF33nkn3+H0ee0TwReWbt1lBVwTXnrZvmAn6ml570laFv0LmU1iVYwkPPFUQmOOQfMd2D+S0rFJrf2AxIfPk1o1H4DgqCOJTr+wy6rp3fEZGpnOZy92MnVInMduOLLje9UffM8NGjSI22+/naamJmzb5uOPPyaZTHLppZcybNgwHnjgAb73ve/lO8yDjpqI7x9qTLBv3Xzzzdx2220sWbKEceP2vJaH4vUVB2hMZZn70RYsQ6O+NcOMYSWkbYcNtQlW1bXi0wSWDu+vqyHrQE1zkoBh4Ncg5Tpk0y6JNBRGoaYJam2vsrrES1m3XYfk8jdpmvcEueo1aP4wofEnEp54cpftQ/enMNDQsJnE4pdo/ehFnNYGzMKBxKeeT2j8CZi6V/VcCghoEA1rDC6OkrZdkBrDyoKURUIUhgyClkFNawbpQDRgsLYuScYBx3W5fuYIBsS2FwxbsLaOhesaCfkMmjNZLp0yhILg598G11+pccGBpybiCgDXX389d955J+vWraOqSk1k9oWdJ4n7i5tJkljyMolFc8nVrgfdJDh8CoFRMwgMOwI9ENkv9yvtHOkNi0mufIfkyrdxk01owQLCE04ictgZGAV7v7dwYMzP5sa0t+dbEzx2/Y69wS+/Z94OFekPrSzgqZuP3uv7PdjNmzePW265hblz5wKQzWZZvXo1r732Gn//+9+59957GThwYI97uis9pybi+4caE+xby5YtY+zYsXz729/mv/7rv/IdzkFj9bZmtjSlCFoGrhSApCWdZV19Ele6bGlMsuDTBlJpG4RkVHmIEeVRFq6uw9Al6+tS2Lgk09CcAp8G6BAwYFPKu4+OdqILnyX5ybvg2pglQwgecjTBEdMwS4bsl7/tUkrs+k1kVs0ntfJNkptXgtAIDJtM5LAziA6bjE/XCPlA00A3DEK6IBQwyTmC4ohB2GdSEg5Qm8wSC/iYMjSGRGN0eZhsziXqN9jWnKYhZTOyLMywkh3HN6+tqGZNbYJEymZbIs3JY8o5ckQxpq72hu+OGhfkh5qIKwC89dZbHH300fz617/m+9//fr7DOSgsWNfAHa+t5tOaBGtqW/fZSnh3pJRkt6yk9eNXSa54yyuiIjSs8hH4B43HN3AMVvkI9EjxHv0hddIJsttWk928gvSGJWQ2LkXmMgjTR2D4VEJjZxIYNhmh7/uzz5qA/3PerqvdO5/sMHXBnOtmqJT03UgkEpxzzjmMGjWK73//+x1pZwD/8R//gRCCW265JX8BHsTURHz/UGOCfW/s2LG0tLSwfv16NfjeB9a0jQVCps7o8jCbmtI4tsOa2iSOdElnXDTN5f1P60nbWWqas2jA+Iowta056lozOHaGxqSD7W7PsEs7IB2oSW1PV2/nJJtoXfY6yWWvewXaAD1ciH/wod6YoGIUVvHgPSqeKh2bXP1GsltWkt64jMz6j7CbtgEQKB9O+JBjiI89jkykmABg6RC0IBb2kZVQHvJRHvfTkLRJZGwK/AYjSqOMHxRlW3MWCUT9JhGfwcB4kIoCP35TZ1NDCp+pMagwhK7t+LqsS2R4dP56NjelGV0WYWhxkPGVMWJBVf9od9S4ID9UsTYFgBkzZlBWVsZjjz2mJuL7QOfq6Zahcd0xw7j7zU+x9+NsXAiBb8BofANGEz/xWrJbPiG1+n3S6z6k+f2nYf7fvev5QpiFAzCiZejhQrRgFM0KeBNoIZCOjeyomt6A3VyD3bB5e3VUwCyqIjzhZPzDJuOvmrjHLdG6M6I0zKrqRMf3roSlm3ctdHfZtCpeXVHNCx97H/6uK3l3TZ2aiO9GOBzm8ccf5ze/+Q0///nPGT16NMcddxwzZsxg2bJlTJw4Md8hKoqSZ7Nnz+aWW25h/vz5TJs2Ld/h9GmprMOamlYKw1bb/u8so8uiVLekyEkoDvmQQDZn05jIMm9tPUGfQWHAYFtrjqHxIBFLJy3DFGWzNLZmvO4smiCZTJPIQqEBW+0d71cPFhCdfDbRyWdjJ+q9McHaD0it/YDWpf/2riQ0jHgFRkE5RrQYLehVTdfMtqVrKb0xQSaJk2zyqqY3biXXsAVc7w41fxjfoPFEp55PYPgUBpeXommQs6E+4/ULNx3wGZDOZJCajt/wk85mSeXA1DUaUlniER9NrRnqW7Osq03iSsnw4hDThhbhNzXW1LZiGRotaZvaRJqy6I59rIvCPi48YhCLNjQSC5rYjsQy1Gp4T6hxQe+jJuL9iKZpXHTRRfzxj39k3bp1DB48ON8h9Wmdq6fnbJdIwOQX547nztdXs64uicQ7mz2sOMSqmtZ9fv9CaB2Tco75AtLOeqvZ29Z4VdMbtpCtXoPz6QJkNtX1jegmeiiGES3BP+QwzOJBWCVDsSpG7bd093aNyewu++m7O4Vx/czhvP5JjaqU/jkVFhbyox/9iNdff52FCxfy1a9+Fdu2Oeyww/j2t7+d7/AURcmzK664gltuuYU5c+aoifheEsL7ytoujisxNYHtuiSzLo4jcWVbAVUJR44sRQiN5lSGloxNQ8KbuGqWQRzYnAPp2rgIdCkxfSYmOVq6+SjX8PaQG+FCIoeeQuTQU7w08qZtZLd8QrZmLXbdBuymbSS3rcZNNXtF2nY9CrRABD1ciFE4kMCIaZglg/GVj8AoHOiNO2ibPOhgGQLblUQMME0wBdg2WH4IWj4aMw5hDEKGSyIrKAr7aElm0EM+sjmbSNsK+ab6JB9taqQ47CeZdQj7DG9s5XQ9KiiN+jm8KkZz2qYo7CNoqelMT3UeFzzzzDP8+Mc/pr6+Xo0L8kSlpvczb7/9NkcddRT/8R//odJP9lLn6ummofGzs8bxi2eWksm5HZNwQxNIuv8wOVCkk8PNppFODqRE6IbXP9Tw5TUd0dAFriOReCnnj3SRct5eGC8etGhIZlWl9D306quvcvzxx/Pggw8ye/ZsTFMVt9lfVGr6/qHGBPvHpEmT2LJlC5s3b0bXVRuovVHTkmZNTSthn8GI0jArtnldVNI5F78pKA37ERpkci6fVjfx6rJqNjclCZo6wyuiJDM2hQGTdTUJkjlJxk5Tn0jTmLIxcUnloLbJS0+38U5eh4SXDt6QAQdvQr5z+vrOpHSR2TTSziBdFyE0hGkhTD9C2/4aCLbdltZ2fybg1yHig4DfW4WWjnd7iSwUhwwMy6I+kcRFJ2gJSiMBDF0nkbUZWRrBb+ocOaKEmpYMizY0MiAeIJG2Of+wSoJ+k4bWDD5Tx9AEg4tC+E29U9yS9XVJtrWkiQcthpWEd0ldV3pu/PjxjBs3jt/85jcMHDhQjQv2E5WarnSYMWMGo0aN4g9/+EPHfhBlz0weHOeha6Z3tNFqXyGXeB9aR40opqowyCPz1+c7VIRuogd63x9Y6UounVbFgFigywn2zun/D10zXU3C99Dtt98OwMyZM9WHraIoHb785S/z9a9/nccff5zZs2fnO5w+rSTipyTi9dp2XYntuIR8JroQaJog4jdJZG18ho7fZzKg0M/w8gjvrqmjOWkzKB4g60LAb5CTNlubciTTDlnHJRb2UVlk4tMSpNvKqGey4DMh60BZENChoQV8QILtWWYmXvX19tw8ITSELwg+r/WXjjeJZ6frW3jjmVzb/0MamAaELG+xoTAcojWXJZvO4ffrBHwWRSEfqXSOpONS4DcJ+3XioSAt6QwBXSfgM7AdSUlBgBOCZsdj4bMMMrbDgHiA4pAPre0x66wlY7OlKUU8ZFHXmiEWNDseb+Xzefvtt1m6dClnnXUWQ4YMyXc4/dZebaoQQvxWCLFcCPGREOIfQohYN9dbK4RYLIRYJIRQp7PzSAjBNddcQ319Pa+++mq+w+nzJg+Oc9PxI5g8OM70YUVYhoYuwDI1vnHSKM4/vLLjZ/2dLuDYkcWUR33omkAXYBoa4wYUdPs7O6f/v7um7gBGfPDIZrP84x//4JRTTmHQoEH5Dkc5iKlxQd9z1VVXoWkajz32WL5DOahommBQPEgiY5N1XCrjAYI+najfIJOzkS40Z1021qeIB02GFAXQNY2wqTG2Ms7YihBhv0llcYThxREClo6pGwSCfmIBH5oO/hAEAhp+C0qiUBrWCFje6rUFFAJRoCoKlQVQpu068A/gTbzbibYvvwmhIMT8UBaCgWFvQq5bYJgGls/PgMIgEZ/JgMIQU4YUMSBsoklJeSzEhMo4Q0sLGBANUlUcYlhxlOKoj8FFYTRNMKQwwMnjB3DCuAomD4kTsnQKgyZSwoaGFPWtGXbO2hVt//XS/AVqaLXnHn74YQBuuummPEfSv+3tiviLwA+llLYQ4jfAD4HuqoAdL6Ws7eYy5QC64oor+N73vsejjz7K8ccfn+9wDho7r5C3r9w+dM10/vza6o5iY/3V5MFxpg0r4usnjQLoSDf/xTNLu13xbj+5ofaG752nnnqKXC7HJZdcku9QlIOfGhf0MZFIhDPOOIOnn36aVCpFIBDY/S8pPVIS9RMPWQghOlKoA5ZBxnYZGA8wdVCcpVuaKIv4KCkI4tc1mtM2BUETQwjGlGfZ1pIjkbEZEveRy2XR9SDbGlsp0CEpNbI5gWVqhIMm6YyGZaSwsrS1BIXCAJRELOrTOYoKJalayOB9BYGqGLSmvVX1tA2G5qW6IyDjQDID4QCYOkRN0BzQhUbMrxH06ZRE/IT8BlJCwG+h2y5oEPIblIUDVBX7CVne1jIHSTxsYWkahq5R3ZKmOOzHZ+hUxAIkszZbGtNE/AZNaZuQzyRgbU9ND/sMKuMBtjWnKYv6iIf2bRHZ/sJ1Xf72t79x+OGHq5PzebZXK+JSyheklO31G98FKvc+JGV/KysrY/r06fztb3/DcZzd/4LSY51XyDv/LJXrP49z2Kd3eZb6/XUN/PcLK/jC3e8CcNPxI2hIZj9zxbv95Ma3Thmt0tL3wqOPPgrAxRdfnOdIlIOdGhf0TbNnz8a2bf7+97/nO5SDjqFru+xjDvgMklmHoN9k/MAI0aBGYyJDLGwxtiLK8OIwQ0siTBhUwuCSIMNLggwtDtGUdqlJZMg6IAM64YCfwoIA8YBBRSxKQVAgJRjCy0ILmhAJawwtDVMc9hGyTEpCUBSASj8UGBD0WxQWGFgWhAIQsCCbhYwN2Yw3Abc0aEmA4YPCqEXUZzCyNERQ1wlZOj5DoymZIetIEAZlcT9RnyBoaRiGQUnEx6DCMOVhH9JxqWlOk8g4ZG1JdUsa23HJ2G3jJAGulNBFDSshBJWFQSYPKVT7w/fCG2+8QXV1tTo53wvsy3r/XwbmdnOZBF4QQiwQQly3D+9T2UOzZ8+mvr6e5557Lt+h9Aunj6/IdwgHTCLjIASURnbs6elKOibcTyzcyG3/XkU8aHWk7ne34t3VyQ2l5+rr63nqqac488wzCYVC+Q5H6V/UuKCPuPDCCxFCcN999+U7lH6hIuonHjQJ+TSyrmBgPEplYRApJcURH0ITFEV8jB4YYfKQYsZWxmhOSXDB0jXKoj6iRpCh8QAjiwMUFUTQNYEjJbEgZCVkJPh9gICmlIsmNHyGzpDyAMVhgWVCUZGJhkNx0EdJgZ8BRX4q4gahMPgsMEwvRb0gpFNRqDG4MEJR2EdxNEQ0ECQUNKgsKiCgW/j9JrGARdDScHI2dYkcjaks2ZzNmroksZBOc9YlHDCJBExqmtNoSFozNqurW1lT00pjMkssYJKxXQpDPvymaku2P9x7770AaiLeC+w2NV0I8RJQ3sVFP5ZSPtV2nR/jbUl5qJubOVpKuUkIUQq8KIRYLqV8vZv7uw64DqCqqqoHh6Dsifb09D/96U+cffbZ+Q7noLZgXQMNySznTRrAU4s279KiqzhiUZfIdnXyt89yJVS3dF23VdcEjy/YiO146eg/O2ucqoa+H917773Yts2NN96Y71CUg8SBHBeoMcGB4ff7ueKKK7j//vtZu3atKt60n0gpaU7b5GwXQ9MoKQiRyLikczZB0wApyLpe5XMdgS40CgImqYxNOGAwuDjE6upWDF1QWKATC/qwBJS7UJdMk3VcpIAiH6Qy3qKyhksqnaPAD1nHxDIMAhbEgi6FYT/VLSl8lsbggE59wqVRQMBy8UuXiAHxsB/L0jF0k6BlkHElpQUWsZCPjGPTms6SzOWIBwNEfTrr61qpTaaxLD+GobOxIc2QogASQVHIR21Tlm2tSUxhIF2XISURhAY+06A5ZTO81E9RWKWc7y+pVIoHHniAk08+WaWl9wK7nYhLKU/6rMuFEF8CzgJOlN30QpNSbmr7f7UQ4h/AVKDLibiU8k7gTvBalewuPmXPFBYWMmvWLB577DESiQThcDjfIR2UOlf91oRAiF2zrRoS3mYuXcCosgjLtrbkJ9gDZNKgGO+va+hYHW9IZrnp+BH5Duug9dhjj1FSUsJpp52W71CUg8SBHBeoMcGBc+2113L//fczZ84cfvCDH+Q7nINSKuuQzNhYhkYi4+AzNcoiftbVJwn4DDShoQkI+0wak1n8hsDUdUxdo7zAT9DSaUnZBA2d4ngAy9Ap8Bu8vXob9a02AdNC0zL4/JBzoSUJ0RAURXxUFAWpb8kgNJ1Y2KSmMYWdc4j5dYpCFsLUCfpdAgkb22fQlMrhWi6DioKYloVPl2i6D1PXKAlbjK6IUNeSxQk7CF0nkcqwdGsziayDphlE/dCcshlaFGBwUYSsLSmN+EikcwwpilAZC1CbzBL0GTQms7iStkKuKt18f2ovyvjlL385z5EosPdV008DvgecI6VMdnOdkBAi0v5v4BRgyd7cr7JvtLcpmTNnTp4jOXh1rvrtutKbjO90HVd6OZqO5KCZhH/W52imrTDbZ6WjK/vG5s2bee+997joootUf2DlgFDjgr7ryCOPpKKioqOmhLLvuVIihDfhNA2NkGUQDZiMrogyuChMLGiSzDrUNqeobk5T25rF0gUVRUHKC/wUhkwOGxJnxIA4ZQVByqMBTF0jZPopKwgSNKE8EqAgIJBAOAhBS8eWNq3JHEHLIh4wwIXR5WEKYxZ+w48tBQ2tDqFAgNEDoliGTlHEz6CiMBlHIB0XVzORUlIRsRhZHqEo4mNYWYRIyO8VYQtZhAydgpCPoGUAgoGFQY4ZXUpB0MTSBT5T9yb2mqAumaPAb1IatiiL+on4DQYVBtE0gZSSjO3guurc27726KOPYlkW5513Xr5DUdj7PeJ/BCJ4aWWLhBB3AAghBggh2jcflwFvCiE+BOYDz0op/7WX96vsA2effTaBQEDtCduPdm5p9otzx3Py2DJ04b35LEND7+Zd2JtanglgYMxPT+uiyLYV/lPGlnHDscN2uGz2lKpdCrAtWNfAbf9exYJ1Dfs++H7szjvvBFC9gZUDSY0L+ighBBdffDGLFi1ixYoV+Q7noBSwDAxNI51zCVsGQcvAMDSEhJa0Tdjy2pvphkY0oCMEpGwX6YIQOrYDxdEgJTE/miaJBkxsNCzTxdQkrbbOwHiIkliIQYUCvw903ULXdQxD4EhIZV2CpknAsvBrFn6fRiRoEgsYDCz0Uxr1E4/4qCqJoKORSKexgbCpURz1Ma6ygBFlBRQGfdiuS2nUR3HIIudKpC7Y1pRi6eZmGhJZSoIWH21swna9iXpBwODQyhiHDipkdFmEMRUFGIZOUdhHeUEAv6njuJLV1QmWb2nmk+oWco6b76ftoLFt2zbmzp3Leeedh9+v+q/3BnvVvkxK2WU+qZRyM3BG27/XAIfuzf0o+4dpmlxxxRX8+c9/5uOPP2bs2LH5Dumgs3NLM4BfPLMUiddj9Jazx7F0cxMPzVu/y+86vehEsAQ2NaY/3+9IOHRQjJuOH0FVUYi5S7Zw+vgKLpvm7fNs3w/eOX2/qxZmyp6RUnL77bczYsQIjjnmmHyHo/QTalzQt91www387ne/49Zbb+X222/PdzgHHV0TFEd8SCkRQpC1HUAQDpjkHBe/ZSClJJBzsV0HNyfx6TqJdI6A36Aw7KcxlWFAQQC/LtAMHSldIn6TmGYQC+hURAM0r65jYJFOdXMKTYCp64Qsg3QOhHQpjAUxNUnQ8tGcziCFpKzApCho4TiCUeVRaltzRII+AoYBukYilaUsEiAnJdUtKUrDAepaMsRDFqYh0NCQriRo6hxSFiQU8BEwNFJSknUkftc7Zk3TKAxb3T5GiYxNS9omFrJoSmZJpG3ioe6vr/Rc+8n5r3zlK3mORGmnyhH2c+0FnFR6+v7Tuep351R1KSUNySzjBhQclC04NE10nHy4bFoVD149rWMS3lnnx6SrFmbKnmlvT3L11Vcj1J47RVF64JBDDmHKlCk8/vjjuK5aidxf2v8m65qGlOC4Ek3QkXUWMnXiPovyqLdCHQoYpLIulqlRHPZRGvFTFg+RyzloUlAWCjEwFsAyTFbXteI4NsI0iYT9DC2OUuAz0XQNgcAwdHBdpNApi1sMK41w1MgyjhlZzojSGEPLIowdUEhZQZCAqVMQtBhZFmbUwDgDi0PEAz5ytmTZ1iZSOZd1Da2kMy5FUR+V8RCxoMW4gYVoQtCQzTE4FkC6EtPQ8Jv6ble4TV2AgHTOwQWM3pQe2Mc9+uijDBo0iGOPPTbfoSht1ES8n5s4cSLDhg3jkUceoZuaOso+1DlV3TQ04kHLWyGX3ofwwHiAU8aWcezI4nyHulc0AdccPZR319TtNt1858dE7RnfNx5++GHA65CgKIrSU7Nnz6a2tpaXX34536Ec9HRNELB0TF0QMI2OSbnP0rFMHV3TyDousYBF2KcTD5iMLo8RCweImAZBn05xQYB41E9AExi6hl8TlEXDVEQtBheHKQibxCMmYZ9JNCAoLwhSGPBRFPYzpChCeSxIadSbdJsGSOmtxgc0SVVxiPJ4kNKwnyOqChlWEqUg6KMlZdOazhHxG1REApQWBCiJ+Jk8pJCJVXHCfpMpVYUcUhrFMHQOKYtgaBo1LRk2NiRJ55xuH5OgZTC0OETA1BlcGCTiNw/gM3LwWr58OUuXLmX27Nnq5Hwvslep6UrfJ4Tgqquu4qc//SkvvPACp556ar5DOqjtnKreeTUYYFODV6Cl/U+kAM6dNIDLZwzhF/9cyocbm/ZZLAJ2aaXW49/tovp7Z8OKQ9z3ztoepZvv/JiotPS9l0wmueuuu5g5cyYDBgzIdziKovQh7e1N//d//5eTTz453+Ec9HRNoGteMc32lHXbdmjJ2hiaRiKTozjkoyTqR0qBaehIKQn6TSoKwjiuQ3MqS0tKw2hJEQn5kSJLwGdSHAmSymYZWRYnm8tR15qhIGjSmnaJBHSErhMwvQ90n6GTcSRBnwkuBPw+CgydjO3Sks7RnLEptXRsF0xDoNuC5nSW8kiAgQUBWjI2AjhxTBlpW9KcyhH0GbSkcmRdSSrrEPYZZGyHRMbGb3ZfQDQWtIgFVTr6vvRf//VfgKqW3tuoFXGF6667DoCHHuqu3auyL3VOVW9fDe58bjLnePupwFtZ3tSY4qaHFnRMavfVecyeTsJjwV3PRh86cHs6vWVonDdpx8nesJLw50o37/yYKHvviSeewHXdjve2oihKT5WUlDBr1izmzp1LS8vB0cmjrxBCYGgCxwXpgs/QMTRBOufguKBr0JLKsbammc1NrfgMjZDPojwWoiQSYmhJlGjQx5Airwr70KIwYwfESTs5Qn6DioIARWE/FYUBSqIhcrZDyG+gaRqr6xKs2dbCqm0tZB1JVTxAxG9guy7DSkJURP0kc173lwK/RWk0iKVphH06tpTomsBv6qRsie26JFI5GhIZLF0jaOkYuiCZtcnaLj5DTT8OJNd1+dvf/saMGTMYM2ZMvsNROlHvBIXS0lJOPvlk5syZQ2NjY77D6VfaV4MnVhbs8HNdbK+a/t7aBrY2Z1i2tYWs7VIW9XHepAGMKAkdkBi9YjI7WrSxqaOtiO24PPPRFsA7cXDDscO4fuZwlW6eR7fffjuhUEi1J1EUZY9cfvnlAPz5z3/OcyT9j6YJgn6DgKWRzjkEDI1owPJWkCWsr0+QcwR1LVlWbWumIZVBANGgxSEVcaYOK2HsoDhDi0LEQhbRoMWgWJic64LQSKQdisMBKmJBCoIWhiZwAdt20XWBjobjSgI+kxGlUUaWRgj7TTY2pWhqzbC1KY2uQ9DU8PtMQj6LutYsrWmbgGXgui7JjENx1EfYb1IZD2IaOhUFAWJBi4pYQKWbH2Bz5syhubm5432t9B5qIq4AcNNNN5HL5VT/0DyYPDjOz84eh6V7PcYtXfDL8ybwrVNGUxLx7XL9rc0Znly0mfEDC/b5G3hIUXCX1e1ktuvCKqLtZIEmBE7bpFwAkYDZcYKhc4sy5cBYv34977zzDldddRXBYDDf4SiK0gedc845DBgwgAceeCDfofRLQggKAj7iIYtYyI/P1NE1gSslGqDrgkQqB0Li0zQaUlksQ8OVErdt5TwSDuDgEjA0bClpanUwTR2fqdGUTLO1KUlrOkdTa476ljS249KadnFwiYdMQpZONGBQEQ14q/RZl+KwH01IXCnRhSDs00EIrx2rqZHKOaSzDiHLpCDgI2jpiLbsOVPXiAWtth7jyoF0//33Y5om11xzTb5DUXai3g0K4PUUtyyLOXPmcP311+c7nH5n8uA4j1w3Y5d90i2pHHe8vqbL33ly0eZ9Hsdp48q59+21PbrudccMIxIwOwrO5Wx3h9XvyYPjagKeB+0D58suuyzPkSiK0le19xS/9dZb+eSTTxg5cmS+Q+p3NE2w82Y0y9AoCFg0JnP4LEHUb+FKcKUkkc5hOzaJjLdnvLxA4NM1ElmH+tY0pq7RmMxgaBohS8fUJa4UmKbALwxMQ6MyJsg5knjYTyzo8+rBCEks4KMlbVPTksIVUO43MU0dJy2pTaQpClmUFwRwJMQCJk3JLCnbxtQ1rwq6kjcNDQ288MILXHzxxZimykTobdSKuAKApmnMnj2bV199lVWrVuU7nH6pq33SPzhjDDccO4zyqI/BhftvdXNIUZBfzZpAS8bbv7U7pRGL1bWtHZPu0WURThhTpla/88x1Xe6++24GDBjAjBkz8h2Ooih92JVXXgnAbbfdludIlHaaplEeCzCiPMrEgYVE/BaOKwkaOq6EjCMBl4AhaEzlSGVs/IZB1GdREDLw6xoDYgGCPgPHccg5Lo6UBCzdSxsvDFFVEqYk4scyNHK2S9p2cYRECDBNnYjPxJVgaBqlkQBCE2xrSrG5KUUqm6O2JUNda45EKkfUZ6gK3Xl21113AfCFL3whz5EoXVETcaXDt771LQD+8pe/5DkSpbMfnDGGd390Eq9973h+NWsCQ4q6n5AHPqMK6c4MzUst95sa/33xJC6bVtXjAm7VLVle/HgbF93+Nj/6x2I+3NjEix9vY8VWVdgnn9544w3WrVvH17/+9XyHoihKHzdp0iSmT5/Ogw8+iON0325KObA0TcMyNCzLoDDiVVP3WQaWoWM7kHMgFrIoDPqIhXzouiDgM9F1wfDSMBG/iaZp2FJgGt6qetCnUxz2YeoaQavT5FlA1GcStAxCfoPCkI+CkIVlapi6144slckhhaCxNUtDMkt9KktB0CTnSFJd1JhRDqz77ruPqqoqzj777HyHonRBTcSVDpMmTWLEiBHcf//9uO7uV0WVA2fBugZu+/cqRpdH+O+LJ+E3u37rji4L7/D9sSOL+cK0qh2S2wpDFr+aNYFHrz9ylz3cFxxe2bFXXRfs0EbN0HY9q73zq2Tuki17dHzKvnHPPfcAcOmll+Y5EkVRDgazZ8+mvr6eZ599Nt+hKJ24rsR2XKSUaJrA1AQ52yvsFgsYuBIClk7A0nGkiyYEpZEA4YCPkM/bUjasNMKgwjClUT8FQR+WaeAzdbROn/WWoXf0Oi0J+Qn6DPymTkHAR0HQwmcKIn4Tn+ntXxeAdF1yjosLajU8zz766COWLVvGJZdcop6LXkpNxJUdXHfddWzatIm5c+fmOxSlzYJ1DXzh7nf57xdW8IW73wXgSzOGdNnGbMnm5o5/C6CyLZ3d1EXH6vddVxzBZdOqukyFb9+r/p1TR/PL8ybgM73K5z5T4xfnjufksWU73N/Of0BOH1+xLw5Z2QOpVIoHH3yQU045hUGDBuU7HEVRDgJXXnklpmly++235zsUpY3rSmxXIqUka3uTcSnBdl3CfpOIzyLgMwn7LVwpCPlMCsMWftPwKqMLsHQN2/UKruldnGRvp2uCkM8g5DMoCFlE/AYh0/te1wQDYiGCPgOkV4itOBygKh5EAqVhn6qOnme/+93vALjhhhvyHInSHVWsTdnB1Vdfzfe+9z0effRRzjzzzHyHowDvrqnboSf3Ews38th7G7pMI3elxNAEsq2n5+MLNmI7LoauMXvqIC44vLJj4r1gXcMuxeFgxyJro8sjO1znsmlVPDxvPY++t57SqJ8bZg5nxdYW5i7ZwunjK7hsWtWBeEiULjz22GMAfPnLX85zJIqiHCzi8TizZs3iscceo6WlhUgkku+QFAApEZqXvWY7EsdxsXSv3ZnjuOiGN7kWQuDTvM4mrpT4TK/KuWVo6JqGENtXrbO2g5SgC+EVitO230b7Yqrf3HHaYBkag4t2zMQL+Q0K1csk72zb5vHHH+foo49m6NCh+Q5H6YaaiCs7KCws5Mwzz+TBBx/kf//3fykqUv2f8236sKKOoimm4dVQdeWu03AN70PxZ2eNoyGZZXNjikfmr8eV4DguA2OBHSbhX7j7XbK2N0m/cHLlDpP0dl1VPr9sWtUOE+72CbqSX7/73e8oKCjg3HPPzXcoiqIcRK655hoee+wxbr31Vn7605/mO5x+r31S7EoQSETbxNln6qRzOSxDR9MhlbUJWDqmriGEwHbcHdKTO6egZ3IOGdvFdh2ytle8LWDqXmq60ifdf//9NDc3q5PzvZxKTVd2ceONNwLwyCOP5DkSBdilJ/f5h1di6Du+dQVw1MhiHrpmOpdNq+Km40fscD0XrxVau86r7Fnb5ZF56/nC3e+yYF3DATwyZV9ZtWoVH3zwAZdffjl+vz/f4SiKchA5+eSTqays5OGHH853KAreCrVl6pi6wDS8/uKaJpCuJGgZBPwGpq7hM7yJdPvkWxOQSmepb0mSSGWxO3VIcdqy6RynrSaMEKRzqlZQX/bwww9jWRZf/OIX8x2K8hnURFzZxamnnkpBQQF33XUXsouVV+XA67yfe/LgOBdOrtzhcl0TfOOkUbusXjuO90EqJdzx+hoenrce2L7K3n4+XLI97f22f69SE/I+pn3/5lVXXZXnSBRFORhddtllLF++nLfffjvfoSht2ifYQngTcp9lYBqCnO0iEZjGjkP8nCNJOy5ZV5BzJc3pXEdhXkvXcFxvr7kmBI70CrQ6nYrCKX3HmjVreOWVV1Tv8D5ATcSVXRiGwVe+8pWOaotK73PB4ZX4TW8irQm45uihu0zC311Th7PTZ2d7VfP2VfbLplVhtRVya99T3l4UTk3G+wbHcXjggQeYMmUKhx9+eL7DURTlIPSNb3wDgHvvvTe/gSjd0nUNQ9cwNOF9ru9UhM1LZReYukAIbYfJtaFrhHwGRWGLSMDAMrzrpbIO6ZxDKueoyXgfct999wHb37dK76Um4kqXLrnkEgD+9Kc/5TkSpSuTB8f52Vnj0DWBlHDfO2tZsK6ho83ZgnUNTB9WxE4nxHeoaj55cJz/nDWBR66bwbdOGc1FRwzCdrYXhXt3Td0BPiplT/zrX/+itraWiy++ON+hKIpykKqoqGDGjBk88sgjJJPJfIejdMF1JTlXousatkvbCrdXyE1K6fUe1zVsR+JKF7+poWnbBwmaJtB1Db9p4DcNHNc70W/oGq7rZdYpvZ/ruvzlL39h8ODBTJ48Od/hKLuhJuJKlyZOnMj06dO5//77sW073+EoXWhIZnGlRAKZnMsdr63epc3Zo9cfycljyzi0soBfzZrQZVG19rT38w+v9CqpCjANjenDVKG+vuCOO+7ANE2uueaafIeiKMpB7KabbiKZTDJnzpx8h6J0Q+BNwDM5h5zt4DheMTfHkWgCwgGT8gI/xWEfIb/1mbdlGhqulNiui67tWNxN6b1eeOEFNm7cyE033ZTvUJQeUBNxpVs33ngjiUSCv/zlL/kORemCt+LtfTBK4JXl1WRyO65oTx4c564rjuCpm4/ebWXznYvC7ZzqrvQ+69at45lnnmH27NnEYrF8h6MoykHswgsvJB6P8/vf/z7fofRbXs/wrpemNU2gARnbwTI1XHf73u72YulehfUdV8K7o2vCK/5mGvhNVT29r/jd736HrutcffXV+Q5F6QE1EVe69cUvfhGfz8dDDz2U71D6rc6p5jubPDjORUcM2l5wra13+J6uaHfXV1zpvR588EEAvvKVr+Q5EkVRDnY+n48rr7ySDz/8kJUrV+Y7nH7Hcdy2nuES1+16Mm4aOkHLwDJ0dN2bPG+fjPd8Rdt1JcmsTSpr47ru5/pdJX8aGhr417/+xfnnn09hYWG+w1F6QE3ElW5pmsbs2bN57bXXWLVqVb7D6Xfae31/VvG08w+vxGd66eSWofGLc8fv0Yp2T+5L6V1c1+Xuu+9m4MCBzJgxI9/hKIrSD3zpS18CVP2YfHClt+otNIH7WaviwptIawJMU0fXNfS2VqY9LbiWtV2vgromyKiq6X3GXXfdBcDll1+e50iUnlITceUzffOb3wTgz3/+c54j6X869/rurnjazunk7T3EJw+Of+Zq+p7cl9K7vPzyy6xbt46vf/3r+Q5FUZR+4tBDD2X69Oncd999ZLPZfIfTr7RPsL1Jdvcr1Kah4zN1TGN7Onl70TbXlR1tTT+LEO1p8KDWwvsGKSV33nknQ4YM4ayzzsp3OEoPqYm48pkmTZrEpEmT+NOf/tTRb1I5MNp7fe8u1bxzj/F2n3eFu6f3pfQev//979E0jSuvvDLfoSiK0o9ce+21NDU18de//jXfofQruq6hawJDF3tcOK09xXx3K9yWoWFoGgjwm4ZKTe8DXn75ZVavXs1VV12lnq8+RE3Eld268cYbSSaTaq/4AbY3xdM+7wq3KtTWt2zevJlnnnmGCy64gNLS0nyHoyhKP3LZZZcRDodVplweaJrY40mWlD3fLy6EwG95+8137keu9E633XYbADfccEOeI1E+DzURV3brqquuwu/3c//99+c7lH6nq9XuntiTFe49vS/lwHvggQcAuPnmm/MciaIo/Y3f7+fqq69m/vz5rF27Nt/hKD0ghEBvW0lv3y+uHDyampp48sknueiii9TJ+T5GvRuV3TJNk1mzZvHyyy+rD90+Qq1wH7yklNx9992UlpZyzDHH5DscRVH6oUsvvRSAP/7xj3mOROkpIfZ8NV3p3e655x4AZs+enedIlM9LTcSVHvnWt74FbH+zK72fWuE+OL3++uusXr2ar371q2pQpShKXkybNo3DDjuM++67T9WPUZQ8u+eee6ioqGDWrFn5DkX5nNREXOmRyZMnM27cOH7/+9+rD11FyaNbb70VgKuvvjq/gSiK0q9dd9111NXVqfoxipJHr732Gh9//DHXXHMNmqamdX2NesaUHhFCcP3119Pc3MwTTzyR73AUpV/aunUrTz75JGeffTYVFRX5DkdRlH7sS1/6EoZhqKJtipJH7dtDVJG2vklNxJUeu/baa/H7/dxxxx35DkVR+qV7770XgG9+85t5jkRRlP7O7/dzww038NZbb7FmzZp8h6Mo/U5DQwOPP/44559/PgMGDMh3OMoeUBNxpcf8fj9f+MIXeOWVV1i+fHm+w1GUfkVKyR/+8AeGDh3Kcccdl+9wFEVRuO666wD47W9/m+dIFKX/aW9Z9pWvfCXPkSh7Sk3Elc/la1/7GgC33357niNRlP7l2WefZevWrVx//fWqSJuiKL3ChAkTmDZtGn/9619JJpP5DkdR+g3XdbnrrrsYNGgQJ554Yr7DUfbQXk3EhRC3CCE2CSEWtX2d0c31ThNCrBBCrBJC/GBv7lPJr4kTJzJjxgzuu+8+0ul0vsNRlH6j/eTXV7/61TxHoijdU+OC/uerX/0qiUSCRx55JN+hKEq/8eKLL7J+/fqOBTKlb9oXK+L/K6Wc1Pb13M4XCiF04DbgdGAscKkQYuw+uF8lT6699lqam5vVh66iHCAbN27kueee49JLLyUYDOY7HEXZHTUu6EcuuugiotEof/rTn/IdiqL0G3/84x/RdZ0rr7wy36Eoe+FApKZPBVZJKddIKbPAHODcA3C/yn5y6aWXEo1GufXWW5FS5jscRTno/c///A8AN998c54jUZR9Qo0LDiKWZXH11VezcOFC3n777XyHoygHvTVr1vDMM89w4YUXUlJSku9wlL2wLybiNwshPhJC3CuEiHdx+UBgQ6fvN7b9TOmj/H4/P/7xjxFCsGrVqnyHoygHvUWLFnHBBRdw5JFH5jsURekJNS7oZ77zne8wbtw4Xn755XyHoigHveeff54xY8bwk5/8JN+hKHtJ7G5FUwjxElDexUU/Bt4FagEJ/BKokFJ+eaffvxA4TUp5Tdv3lwPTpJRdLu0IIa4Drmv7djSwosdH89mK22Lty/r6MfT1+EEdQ2+hjqF3UMfQtcFSyoN2meJAjgvUmOAzqWPoHdQx5F9fjx/UMfQWB3RMYOzuN6WUJ/XkHoQQdwHPdHHRJmBQp+8r237W3f3dCdzZk/v8PIQQ70spj9jXt3sg9fVj6OvxgzqG3kIdQ++gjqF/OpDjAjUm6J46ht5BHUP+9fX4QR1Db3Ggj2Fvq6ZXdPp2FrCki6u9B4wUQgwVQljAJcDTe3O/iqIoiqL0PmpcoCiKoig9s9sV8d34f0KISXgpaGuB6wGEEAOAu6WUZ0gpbSHEzcDzgA7cK6Vcupf3qyiKoihK76PGBYqiKIrSA3s1EZdSXt7NzzcDZ3T6/jlglxYmB9g+T23Lg75+DH09flDH0FuoY+gd1DEoO+hD44KD4XlXx9A7qGPIv74eP6hj6C0O6DHstliboiiKoiiKoiiKoij7zoHoI64oiqIoiqIoiqIoSpuDZiIuhLhICLFUCOEKIY7Y6bIfCiFWCSFWCCFO7eb3hwoh5rVd79G2AjJ50xbDoravtUKIRd1cb60QYnHb9d4/wGF+JiHELUKITZ2O44xurnda23OzSgjxgwMd52cRQvxWCLG8rSfuP4QQsW6u1+ueh909rkIIX9vrbFXba39IHsLslhBikBDi30KIj9ve21/v4jrHCSGaOr3GfpaPWD/L7l4bwvP7tufhIyHE4fmIsztCiNGdHt9FQohmIcQ3drpOr3sehNfDuloIsaTTzwqFEC8KIT5p+39XPa4RQlzZdp1PhBBXHriolX1JjQt6z+dROzUuyB81Jugd1JggP3rtmEBKeVB8AWPweoy+ChzR6edjgQ8BHzAUWA3oXfz+Y8Albf++A7gx38fUKbb/Bn7WzWVrgeJ8x9hNbLcA39nNdfS252QYYLU9V2PzHXun+E4BjLZ//wb4TV94HnryuAJfAe5o+/clwKP5jnun+CqAw9v+HQFWdnEMxwHP5DvW3RzHZ7428PbNzgUEMB2Yl++Yd/O62orXE7NXPw/AscDhwJJOP/t/wA/a/v2Drt7PQCGwpu3/8bZ/x/N9POprj14DalzQy77UuCBvMasxQS/5UmOCvMXaK8cEB82KuJRymZRyRRcXnQvMkVJmpJSfAquAqZ2vIIQQwAnA420/uh84bz+G22NtsV0MPJLvWPaTqcAqKeUaKWUWmIP3nPUKUsoXpJR227fv4vW77Qt68riei/daB++1f2Lb661XkFJukVIubPt3C7AMGJjfqPaLc4EHpOddICZ2bAHVm5wIrJZSrst3ILsjpXwdqN/px51f8939nT8VeFFKWS+lbABeBE7bX3Eq+48aF/RZalyw76kxQd+hxgT7QW8dExw0E/HPMBDY0On7jez6xi0CGjv9Ye3qOvlyDLBNSvlJN5dL4AUhxAIhxHUHMK6eurkttebeblI+evL89BZfxjtL2ZXe9jz05HHtuE7ba78J773Q67SlyB0GzOvi4hlCiA+FEHOFEOMObGQ9srvXRl96D1xC94P/3v48AJRJKbe0/XsrUNbFdfrS86HsGTUuyC81Ljjw1Jig91Bjgt4j72OCve0jfkAJIV4Cyru46MdSyqcOdDx7q4fHcymffdb7aCnlJiFEKfCiEGJ521mfA+KzjgG4Hfgl3h+dX+Kl0n35QMXWUz15HoQQPwZs4KFubiavz8PBTAgRBp4AviGlbN7p4oV4KVEJ4e01fBIYeYBD3J2D4rUhvP2x5wA/7OLivvA87EBKKYUQqm1IH6fGBV1S44K9pMYFvZcaE/QOakywb/SpibiU8qQ9+LVNwKBO31e2/ayzOrzUD6PtLGBX19nndnc8QggDOB+Y/Bm3sant/9VCiH/gpR8dsDd0T58TIcRdwDNdXNST52e/6sHz8CXgLOBE2bZhpIvbyOvz0IWePK7t19nY9lorwHsv9BpCCBPvA/chKeXfd76884ewlPI5IcSfhBDFUsraAxnnZ+nBayPv74EeOh1YKKXctvMFfeF5aLNNCFEhpdzSlupX3cV1NuHtb2tXibfHWOmF1Ligy9tQ44K9dBCOC9SYoJdQY4JeJe9jgv6Qmv40cInwqkEOxTsjM7/zFdr+iP4buLDtR1cCveFM+knAcinlxq4uFEKEhBCR9n/jFRBZ0tV182GnPS2z6Dq294CRwqtOa+GluTx9IOLrCSHEacD3gHOklMlurtMbn4eePK5P473WwXvtv9LdgCIf2vam3QMsk1L+TzfXKW/fwyaEmIr3N63XDBx6+Np4GrhCeKYDTZ1SpXqTblfhevvz0Enn13x3f+efB04RQsTb0mZPafuZcvBQ44I8UeOCvFFjgl5AjQl6nfyPCWQvqGS3L77w/qBvBDLANuD5Tpf9GK9a5Arg9E4/fw4Y0PbvYXgfxKuAvwG+XnBM9wE37PSzAcBznWL+sO1rKV7KVN6fi06xPggsBj7Ce7FX7HwMbd+fgVf9cnUvPIZVeHtDFrV9tVcU7fXPQ1ePK/ALvMEDgL/ttb6q7bU/LN8x7xT/0Xjpix91evzPAG5of18AN7c95h/iFc05Mt9x73QMXb42djoGAdzW9jwtplN1597yBYTwPkQLOv2sVz8PeAOELUCu7bPharz9ji8DnwAvAYVt1z0CuLvT73657X2xCrgq38eivvb4NaDGBb3k86hTrGpckL+41Zgg/8egxgT5i7lXjglE2x0oiqIoiqIoiqIoinIA9IfUdEVRFEVRFEVRFEXpNdREXFEURVEURVEURVEOIDURVxRFURRFURRFUZQDSE3EFUVRFEVRFEVRFOUAUhNxRVEURVEURVEURTmA1ERcURRFURRFURRFUQ4gNRFXFEVRFEVRFEVRlANITcQVRVEURVEURVEU5QD6/z+dl3rrybMqAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plot_contour(logprob, orbits=samples, weights=weights)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "interpreter": { + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + }, + "kernelspec": { + "display_name": "imcmc_blackjax", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_sampling.py b/tests/test_sampling.py index 66edaec90..f87d26b37 100644 --- a/tests/test_sampling.py +++ b/tests/test_sampling.py @@ -24,6 +24,17 @@ def one_step(state, rng_key): return states +def orbit_samples(orbits, weights, rng_key): + def sample_orbit(orbit, weights, rng_key): + sample = jax.random.choice(rng_key, orbit, p=weights) + return sample + + keys = jax.random.split(rng_key, orbits.shape[0]) + samples = jax.vmap(sample_orbit)(orbits, weights, keys) + + return samples + + regresion_test_cases = [ { "algorithm": blackjax.hmc, @@ -115,6 +126,17 @@ def test_linear_regression(self, case, is_mass_matrix_diagonal): "num_sampling_steps": 6000, "burnin": 5_000, }, + { + "algorithm": blackjax.orbital_hmc, + "initial_position": jnp.array(100.0), + "parameters": { + "step_size": 0.1, + "inverse_mass_matrix": jnp.array([0.1]), + "period": 100, + }, + "num_sampling_steps": 20_000, + "burnin": 15_000, + }, { "algorithm": blackjax.rmh, "initial_position": 1.0, @@ -155,7 +177,13 @@ def test_univariate_normal( functools.partial(inference_loop, kernel, num_sampling_steps) )(self.key, initial_state) - samples = states.position[burnin:] + if algorithm == blackjax.orbital_hmc: + _, orbit_key = jax.random.split(self.key) + samples = orbit_samples( + states.positions[burnin:], states.weights[burnin:], orbit_key + ) + else: + samples = states.position[burnin:] np.testing.assert_allclose(np.mean(samples), 1.0, rtol=1e-1) np.testing.assert_allclose(np.var(samples), 4.0, rtol=1e-1) @@ -206,7 +234,7 @@ def logprob_fn(x): def mcse_test(self, samples, true_param, p_val=0.01): posterior_mean = jnp.mean(samples, axis=[0, 1]) - ess = diagnostics.effective_sample_size(samples, chain_axis=1, sample_axis=0) + ess = diagnostics.effective_sample_size(samples, chain_axis=0, sample_axis=1) posterior_sd = jnp.std(samples, axis=0, ddof=1) avg_monte_carlo_standard_error = jnp.mean(posterior_sd, axis=0) / jnp.sqrt(ess) scaled_error = ( @@ -233,7 +261,7 @@ def test_mcse(self, algorithm, parameters): ) states = inference_loop_multiple_chains(multi_chain_sample_key, initial_states) - posterior_samples = states.position[-1000:] + posterior_samples = states.position[:, -1000:] posterior_delta = posterior_samples - true_loc posterior_variance = posterior_delta**2.0 posterior_correlation = jnp.prod(posterior_delta, axis=-1, keepdims=True) / (