From e38a9329283d774d869b453a5a5e364402204ca7 Mon Sep 17 00:00:00 2001 From: andyElking Date: Sun, 18 Aug 2024 15:44:28 +0100 Subject: [PATCH] Added Langevin docs, a Langevin example and backwards in time test --- diffrax/_solver/align.py | 16 +--- diffrax/_solver/langevin_srk.py | 10 ++- diffrax/_solver/quicsort.py | 17 +--- diffrax/_solver/should.py | 16 +--- docs/api/solvers/sde_solvers.md | 37 ++++++++ examples/langevin_example.ipynb | 151 ++++++++++++++++++++++++++++++++ mkdocs.yml | 1 + test/test_langevin.py | 37 ++++++++ 8 files changed, 242 insertions(+), 43 deletions(-) create mode 100644 examples/langevin_example.ipynb diff --git a/diffrax/_solver/align.py b/diffrax/_solver/align.py index fae453ea..7ce805e2 100644 --- a/diffrax/_solver/align.py +++ b/diffrax/_solver/align.py @@ -19,10 +19,7 @@ ) -# UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1, -# so we need 3 versions of each coefficient - - +# For an explanation of the coefficients, see langevin_srk.py class _ALIGNCoeffs(AbstractCoeffs): beta: PyTree[ArrayLike] a1: PyTree[ArrayLike] @@ -46,15 +43,8 @@ def __init__(self, beta, a1, b1, aa, chh): class ALIGN(AbstractLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]): r"""The Adaptive Langevin via Interpolated Gradients and Noise method - designed by James Foster. Only works for Underdamped Langevin Diffusion - of the form - - $$d x_t = v_t dt$$ - - $$d v_t = - gamma v_t dt - u ∇f(x_t) dt + (2gammau)^(1/2) dW_t$$ - - where $v$ is the velocity, $f$ is the potential, $gamma$ is the friction, and - $W$ is a Brownian motion. + designed by James Foster. + Accepts only terms given by [`diffrax.make_langevin_term`][]. """ interpolation_cls = LocalLinearInterpolation diff --git a/diffrax/_solver/langevin_srk.py b/diffrax/_solver/langevin_srk.py index 5884c917..59a444f3 100644 --- a/diffrax/_solver/langevin_srk.py +++ b/diffrax/_solver/langevin_srk.py @@ -198,7 +198,7 @@ def init( args: PyTree, ) -> SolverState: """Precompute _SolverState which carries the Taylor coefficients and the - SRK coefficients (which can be computed from h and the Taylor coeffs). + SRK coefficients (which can be computed from h and the Taylor coefficients). Some solvers of this type are FSAL, so _SolverState also carries the previous evaluation of grad_f. """ @@ -259,16 +259,18 @@ def step( gamma, u, f = get_args_from_terms(terms) h = drift.contr(t0, t1) - h_state = st.h + h_prev = st.h tay: PyTree[_Coeffs] = st.taylor_coeffs coeffs: _Coeffs = st.coeffs # If h changed recompute coefficients - cond = jnp.isclose(h_state, h, rtol=1e-10, atol=1e-12) + # Even when using constant step sizes, h can fluctuate by small amounts, + # so we use `jnp.isclose` for comparison + cond = jnp.isclose(h_prev, h, rtol=1e-10, atol=1e-12) coeffs = lax.cond( cond, lambda x: x, - lambda _: self._recompute_coeffs(h, gamma, tay, h_state), + lambda _: self._recompute_coeffs(h, gamma, tay, h_prev), coeffs, ) diff --git a/diffrax/_solver/quicsort.py b/diffrax/_solver/quicsort.py index 6b128116..58749592 100644 --- a/diffrax/_solver/quicsort.py +++ b/diffrax/_solver/quicsort.py @@ -22,10 +22,9 @@ ) +# For an explanation of the coefficients, see langevin_srk.py # UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1, # so we need 3 versions of each coefficient - - class _QUICSORTCoeffs(AbstractCoeffs): beta_lr1: PyTree[ArrayLike] # (gamma, 3, *taylor) a_lr1: PyTree[ArrayLike] # (gamma, 3, *taylor) @@ -66,14 +65,7 @@ class QUICSORT(AbstractLangevinSRK[_QUICSORTCoeffs, None]): } ``` - Works for underdamped Langevin SDEs of the form - - $$d x_t = v_t dt$$ - - $$d v_t = - gamma v_t dt - u ∇f(x_t) dt + (2gammau)^(1/2) dW_t$$ - - where $v$ is the velocity, $f$ is the potential, $gamma$ is the friction, and - $W$ is a Brownian motion. + Accepts only terms given by [`diffrax.make_langevin_term`][]. """ interpolation_cls = LocalLinearInterpolation @@ -246,9 +238,8 @@ def _one(coeff): ).ω v_out = (v_out_tilde**ω - st.rho**ω * (hh**ω - 6 * kk**ω)).ω - f_fsal = ( - st.prev_f - ) # this method is not FSAL, but this is for compatibility with the base class + # this method is not FSAL, but for compatibility with the base class we set + f_fsal = st.prev_f # TODO: compute error estimate return x_out, v_out, f_fsal, None diff --git a/diffrax/_solver/should.py b/diffrax/_solver/should.py index 813a31b5..27138e01 100644 --- a/diffrax/_solver/should.py +++ b/diffrax/_solver/should.py @@ -19,10 +19,7 @@ ) -# UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1, -# so we need 3 versions of each coefficient - - +# For an explanation of the coefficients, see langevin_srk.py class _ShOULDCoeffs(AbstractCoeffs): beta_half: PyTree[ArrayLike] a_half: PyTree[ArrayLike] @@ -63,15 +60,8 @@ def __init__(self, beta_half, a_half, b_half, beta1, a1, b1, aa, chh, ckk): class ShOULD(AbstractLangevinSRK[_ShOULDCoeffs, None]): r"""The Shifted-ODE Runge-Kutta Three method - designed by James Foster. Only works for Underdamped Langevin Diffusion - of the form - - $$d x_t = v_t dt$$ - - $$d v_t = - gamma v_t dt - u ∇f(x_t) dt + (2gammau)^(1/2) dW_t$$ - - where $v$ is the velocity, $f$ is the potential, $gamma$ is the friction, and - $W$ is a Brownian motion. + designed by James Foster. + Accepts only terms given by [`diffrax.make_langevin_term`][]. """ interpolation_cls = LocalLinearInterpolation diff --git a/docs/api/solvers/sde_solvers.md b/docs/api/solvers/sde_solvers.md index ecc9fbf9..fc91e643 100644 --- a/docs/api/solvers/sde_solvers.md +++ b/docs/api/solvers/sde_solvers.md @@ -113,3 +113,40 @@ These are reversible in the same way as when applied to ODEs. [See here.](./ode_ selection: members: - __init__ + + +--- + +### Underdamped Langevin solvers + +These solvers are specifically designed for the Underdamped Langevin diffusion (ULD), +which takes the form + +$d \mathbf{x}_t = \mathbf{v}_t dt$ + +$d \mathbf{v}_t = - \gamma \mathbf{v}_t dt - u +\nabla f( \mathbf{x}_t ) dt + \sqrt{2 \gamma u} d W_t.$ + +where $\mathbf{x}_t, \mathbf{v}_t \in \mathbb{R}^d$ represent the position +and velocity, $W$ is a Brownian motion in $\mathbb{R}^d$, +$f: \mathbb{R}^d \rightarrow \mathbb{R}$ is a potential function, and +$\gamma , u \in \mathbb{R}^{d \times d}$ are diagonal matrices governing +the friction and the dampening of the system. + +They are more precise for this diffusion than the general-purpose solvers above, but +cannot be used for any other SDEs. They only accept terms generated by the +[`diffrax.make_langevin_term`][] function. They all have the same `__init__` signature. +For an example of their usage, see the [Langevin example](../../examples/langevin_example.ipynb). + +::: diffrax.ALIGN + selection: + members: + - __init__ + +::: diffrax.ShOULD + selection: + members: false + +::: diffrax.QUICSORT + selection: + members: false \ No newline at end of file diff --git a/examples/langevin_example.ipynb b/examples/langevin_example.ipynb new file mode 100644 index 00000000..818f6de0 --- /dev/null +++ b/examples/langevin_example.ipynb @@ -0,0 +1,151 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d0b76d478a4179b0", + "metadata": {}, + "source": [ + "# Underdamped Langevin Diffusion simulation\n", + "\n", + "The Underdamped Langevin diffusion (ULD) is an SDE of the form:\n", + "\n", + "\\begin{align*}\n", + " d \\mathbf{x}_t &= \\mathbf{v}_t \\, dt \\\\\n", + " d \\mathbf{v}_t &= - \\gamma \\, \\mathbf{v}_t \\, dt - u \\,\n", + " \\nabla \\! f( \\mathbf{x}_t ) \\, dt + \\sqrt{2 \\gamma u} \\, d W_t,\n", + "\\end{align*}\n", + "\n", + "where $\\mathbf{x}_t, \\mathbf{v}_t \\in \\mathbb{R}^d$ represent the position\n", + "and velocity, $W$ is a Brownian motion in $\\mathbb{R}^d$,\n", + "$f: \\mathbb{R}^d \\rightarrow \\mathbb{R}$ is a potential function, and\n", + "$\\gamma , u \\in \\mathbb{R}^{d \\times d}$ are diagonal matrices governing\n", + "the friction and the dampening of the system.\n", + "\n", + "## ULD for Monte Carlo and Bayesian inference\n", + "\n", + "ULD is commonly used in Monte Carlo applications since it allows us to sample from its stationary distribution $p = \\frac{\\exp(-f)}{C}$ even when its normalising constant $C = \\int p(x) dx$ is unknown. This is because only knowledge of $\\nabla f$ is required, which doesn't depend on $C$. For an example of such an application see section 5.2 of the paper on [Single-seed generation of Brownian paths](https://arxiv.org/abs/2405.06464).\n", + "\n", + "## ULD solvers in Diffrax\n", + "\n", + "In addition to generic SDE solvers (which can solve any SDE including ULD), Diffrax has some solvers designed specifically for ULD. These are `diffrax.ALIGN` which has a 2nd order of strong convergence, and `diffrax.QUICSORT` and `diffrax.ShOULD` which are 3rd order solvers. Note that unlike ODE solvers which can have orders of 5 or even higher, very few types of SDEs permit solvers with a strong order greater than $\\frac{1}{2}$.\n", + "\n", + "These Langevin-specific solvers only accept terms retuned by `diffrax.make_langevin_term`.\n", + "\n", + "## A 2D harmonic oscillator\n", + "\n", + "In this example we will simulate a simple harmonic oscillator in 2 dimensions. This system is given by the potential $f(x) = x^2$." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9deba250066ddc39", + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-18T14:34:51.363600Z", + "start_time": "2024-08-18T14:34:48.606225Z" + } + }, + "outputs": [], + "source": [ + "from warnings import simplefilter\n", + "\n", + "\n", + "simplefilter(action=\"ignore\", category=FutureWarning)\n", + "import diffrax\n", + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "t0, t1 = 0.0, 20.0\n", + "dt0 = 0.05\n", + "saveat = diffrax.SaveAt(steps=True)\n", + "\n", + "# Parameters\n", + "gamma = jnp.array([2, 0.5], dtype=jnp.float32)\n", + "u = jnp.array([0.5, 2], dtype=jnp.float32)\n", + "x0 = jnp.zeros((2,), dtype=jnp.float32)\n", + "v0 = jnp.zeros((2,), dtype=jnp.float32)\n", + "y0 = (x0, v0)\n", + "\n", + "# Brownian motion\n", + "bm = diffrax.VirtualBrownianTree(\n", + " t0, t1, tol=0.01, shape=(2,), key=jr.key(0), levy_area=diffrax.SpaceTimeTimeLevyArea\n", + ")\n", + "\n", + "# Use the make_langevin_term function to create the terms\n", + "terms = diffrax.make_langevin_term(gamma, u, lambda x: 2 * x, bm, x0)\n", + "\n", + "solver = diffrax.QUICSORT(0.1)\n", + "sol = diffrax.diffeqsolve(\n", + " terms, solver, t0, t1, dt0=dt0, y0=y0, args=None, saveat=saveat\n", + ")\n", + "xs, vs = sol.ys" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "62da2ddbaaf98f47", + "metadata": { + "ExecuteTime": { + "end_time": "2024-08-18T14:34:52.378938Z", + "start_time": "2024-08-18T14:34:52.259089Z" + } + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAA1kAAANBCAYAAAAShHTFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddXhUV/rA8e+dibt7QgLBNbgUimtpabcu1Lv1dtluf9uV+rbdylapC3VvqeEOxQnBgyUhbsR97PfHmUmgJMRm5s4k5/M8PPeSzNz75kZm3nve8x7FZDKZkCRJkiRJkiRJkqxCo3YAkiRJkiRJkiRJXYlMsiRJkiRJkiRJkqxIJlmSJEmSJEmSJElWJJMsSZIkSZIkSZIkK5JJliRJkiRJkiRJkhXJJEuSJEmSJEmSJMmKZJIlSZIkSZIkSZJkRTLJkiRJkiRJkiRJsiIXtQNwdEajkdzcXHx9fVEURe1wJEmSJEmSJElSiclkorKykqioKDSalserZJLVitzcXGJjY9UOQ5IkSZIkSZIkB5GVlUVMTEyLn5dJVit8fX0BcSH9/PxUiUGn07Fq1SpmzpyJq6urKjF0dfIa25a8vrYlr6/tyWtsW/L62pa8vrYlr6/tOdI1rqioIDY2tjFHaIlMslphKRH08/NTNcny8vLCz89P9R+srkpeY9uS19e25PW1PXmNbUteX9uS19e25PW1PUe8xq1NI5KNLyRJkiRJkiRJkqxIJlmSJEmSJEmSJElWJJMsSZIkSZIkSZIkK5JzsiRJkiRJkiRJAkSLcr1ej8FgUDuURjqdDhcXF+rq6mwel1arxcXFpdNLN8kkS5IkSZIkSZIkGhoayMvLo6amRu1QzmIymYiIiCArK8su69Z6eXkRGRmJm5tbh48hkyxJkiRJkiRJ6uaMRiPp6elotVqioqJwc3OzS0LTFkajkaqqKnx8fM67AHBnmUwmGhoaKCoqIj09nd69e3f4fDLJkiRJkiRJkqRurqGhAaPRSGxsLF5eXmqHcxaj0UhDQwMeHh42TbIAPD09cXV15dSpU43n7AjZ+EKSJEmSJEmSJACbJzHOwBrXQF5FSZIkSZIkSZIkK5JJliRJkiRJkiRJkhXJJEuSJEmSJEmSJMmKZJIlSZLkiOoq4NBSyEkGo1HtaCRJkiTJKeXl5XHttdfSp08fNBoNDz74oF3OK7sLSpIkORKTCTa/CL+/DvXl4mN+MXDVJxA9Qt3YJEmSJMnJ1NfXExoayr/+9S9efvllu51XjmRJkiQ5kpQvYN3TIsEKiAM3H6jIhu9uhYZqtaOTJEmSugmTyURNg16VfyaTqc1xFhUVERERwTPPPNP4sa1bt+Lm5sbatWuJj4/n1VdfZeHChfj7+9viUjVLjmRJkiQ5ipoSWP1vsT/pYZj8CNRXwFvjoTQdVj8G815UN0ZJkiSpW6jVGRjw6EpVzn34yVl4ubUtTQkNDeXDDz9kwYIFzJw5k759+3LDDTdw7733Mm3aNBtH2jI5kiVJkuQo1j0FNachtB9c+DBoNOAZAJe8IT6/6z3I3qNqiJIkSZLkaObOncvtt9/Oddddx5133om3tzfPPvusqjHJkSxJkiRHUFMCyZ+I/Xkvgda16XO9psLQa2Dfl7D+abjhR3VilCRJkroNT1cth5+cpdq52+vFF19k0KBBfPvtt+zZswd3d3cbRNZ2MsmSJElyBEd+BqMewgdB/AXnfn7y3+HAt3ByHZzaCj3G2z9GSZIkqdtQFKXNJXuO4OTJk+Tm5mI0GsnIyGDw4MGqxiPLBSVJkhzBge/EdtCfmv98YDwk3SD21/3HLiFJkiRJkjNoaGjg+uuv56qrruKpp57itttuo7CwUNWYZJIlSZKktoo8yNgi9ltKsgAmPQQaVzi1BbJ32yc2SZIkSXJw//znPykvL+e1117j//7v/+jTpw+33HJL4+dTUlJISUmhqqqKoqIiUlJSOHz4sE1jkkmWJEmS2g4vBUwQMxoCe7T8OP8YGHyF2N/2hj0ikyRJkiSHtmHDBl555RU+/fRT/Pz80Gg0fPrpp2zevJm33noLgKSkJJKSktizZw9ffPEFSUlJzJ0716ZxOU+hpSRJUld1bIXYDrqs9ceOuwf2fQGHf4LSU+dPyiRJkiSpi5s8eTI6ne6sj8XHx1NeXt74//asu2UtciRLkiRJTSYT5OwV+z0mtP74iEHQcwqYjLDjHdvGJkmSJElSh8gkS5IkSU0laVBfDlp3COvftueMvUtsD3wDRoPtYpMkSZIkqUNkkiVJkqSmXPMoVsTgs9fGOp9eU8EjAKqL4NTvNgtNkiRJkqSOkUmWJEmSmixJVvTwtj9H6wr9LxL7h5ZaPSRJkiRJkjpHJlmSJElqykkW26ik9j1vwKVie+QXWTIoSZIkSQ5GJlmSJElqMRogb5/Yb2+S1fNCc8lgIUrWNquHJkmSJElSx8kkS5IkSS3Fx0FXDa7eENKnfc/VukI/UTKoHF1mg+AkSZIkSeoomWRJkiSpJddcKhg5FDTa9j+/zywANCfXWDEoSZIkSZI6SyZZkiRJaik4JLaRQzr2/J6TQeOCUpKGV32B1cKSJEmSJKlzZJIlSZKkluLjYtveUkELDz+IHQtAeMV+KwUlSZIkSV3HDz/8wIwZMwgNDcXPz49x48axcuVKm5/XqZKsTZs2MX/+fKKiolAUhaVLl5738Rs2bEBRlHP+5efn2ydgSZKk8yk+JrYdTbIAek8HIEwmWZIkSZJ0jk2bNjFjxgyWLVvGnj17mDJlCvPnz2fv3r02Pa9TJVnV1dUMHTqUxYsXt+t5R48eJS8vr/FfWFiYjSKUJElqI10dlJ0S+51JshJniENUHgF9nRUCkyRJkiTnUVRUREREBM8880zjx7Zu3Yqbmxtr167llVde4eGHH2bUqFH07t2bZ555ht69e/PLL7/YNC4Xmx7dyubMmcOcOXPa/bywsDACAgKsH5AkSVJHlaSByQju/uDTiRs/4QMx+UbiUpmHPmsn9JlmvRglSZKk7stkAl2NOud29QJFadNDQ0ND+fDDD1mwYAEzZ86kb9++3HDDDdx7771Mm3bua6LRaKSyspKgoCBrR30Wp0qyOmrYsGHU19czaNAgHn/8cSZMmNDiY+vr66mvr2/8f0VFBQA6nQ6dTmfzWJtjOa9a5+8O5DW2LXl9z6UUHMYFMAYnYtDrO3es6NG4pP6EMXsXuoRJ1glQOov8GbYteX1tS15f2+oq11en02EymTAajRiNRmioRvNcjCqxGP+eDW7ejf83mUyNW6PReM7jZ8+ezW233cZ1113HiBEj8Pb25j//+U+zj33hhReoqqri8ssvb/bzIBIxk8mETqdDqz27+29bv8+KyRK1k1EUhR9//JEFCxa0+JijR4+yYcMGRo4cSX19Pe+//z6ffvopO3bsYPjw4c0+5/HHH+eJJ5445+NffPEFXl5e1gpfkqRurk/+Uvrn/UBm0AXs7XFHp47Vq2A5g3K/JNd/BLt6PmClCCVJkqTuxMXFhYiICGJjY3FzcwNdDQGL+6sSS9k9R8RoVjvU1tYyfvx4cnJyWL9+PQMHDjznMd9++y0PPvggn3/+OZMnT27xWA0NDWRlZZGfn4/+DzdCa2pquPbaaykvL8fPz6/FY3TpJKs5F154IXFxcXz66afNfr65kazY2FiKi4vPeyFtSafTsXr1ambMmIGrq6sqMXR18hrblry+59Iu/TOaQ99jmPIoxvH3d+pYhrRNeHx5GUafSAwPHLBShNKZuvrPcFpRNSU1DQyL8cdFa//p2l39+qpNXl/b6irXt66ujqysLOLj4/Hw8HCockGTyURlZSW+vr4oLZQRHjx4kDFjxqDT6fj++++ZP3/+WZ//6quvuO222/j666+ZN2/eeU9fV1dHRkYGsbGx4lqcoaKigpCQkFaTrG5RLnim0aNHs2XLlhY/7+7ujru7+zkfd3V1Vf0XxxFi6OrkNbYteX3PUHICAG14P7SdvSYxwzGhoKnKQ1N3GnwjrBCg1Jyu+DN8vKCSS97aRp3OiL+nK1P7hTGtfxhhvh4EeLnSO8wHRVEorqrHx90FD9cOLJzdRl3x+joSeX1ty9mvr8FgQFEUNBoNGo35ZovWV92gzCxlfZb4/qihoYGFCxdy1VVX0bdvX+644w4OHDjQ2Ozuyy+/5NZbb+Wrr746J/lqjkajQVGUZr+nbf0ed7skKyUlhcjISLXDkCSpOzMaO79G1pncfKj0iMavLhtykqHf3M4fU+oW6nQG7vtyL3U6I1qNQnmtjh/35vDj3pzGx/SP9CPEx40tJ4qJ8vfkjWuTSIoLVDFqSZKks/3zn/+kvLyc1157DR8fH5YtW8Ytt9zCr7/+yhdffMGNN97Iq6++ypgxYxqXcvL09MTf399mMTlVklVVVcWJEyca/5+enk5KSgpBQUHExcXxyCOPkJOTwyeffALAK6+8QkJCAgMHDqSuro7333+fdevWsWrVKrW+BEmSJKjMFSUYGhcIjLfKIUu9EkSSlSuTLKntnv7tMKn5lYT4uPHrfRPJLKlh9eF8tqWdpqbeQE5ZLUfyKhofn1NWy5XvbON/Vw5j/tAoFSOXJEkSNmzYwCuvvML69esby/c+/fRThg4dyltvvcXXX3+NXq/nnnvu4Z577ml83o033siSJUtsFpdTJVm7d+9mypQpjf9ftGgR0HSR8vLyyMzMbPx8Q0MDf/3rX8nJycHLy4shQ4awZs2as44hSZJkd5ZRrMAE0FqntKTMqyc9SjaLkSxJaoOvdmby2XbxmvniFUOJ8Pcgwt+D0QlNbY3Lahr4cW8O1fV6pvYL5431x1l2IJ9//HCAUfFBRPh7tHR4SZIku5g8efI5Hf/i4+MpLy8H4K677lIjLOdKsiZPnsz5+nT8MRt9+OGHefjhh20clSRJUjuVpIltUE+rHbLMO0Hs5CaLycptXF9E6p62nizm3z8dBOCvM/owuW/za7UFeLlx84SExv+/fs1wcsq2si+rjMd+Psg7N4y0S7ySJEnOxv4thCRJkrq70nSxtWKSVe4Rh0nrBrWlTceXpGasTy3k5o92oTOYmDs4gnunJrb5uVqNwnOXDcZFo7DyUAFrjxTYMFJJkiTnJZMsSZIkeyuxfpJl0rhgCjOvCSJLBqVmnCyq4qFv93HbJ7up1xuZ1i+M/105rMV2yC3pH+nHrReI0a3X1h4/b4WJJElSd+VU5YKSJEldQmOSlXD+x7WTKSoJ8vZC7l4YfLlVjy05puKqejYfL6KsRkdOaS2p+ZX4uLswKiGIi4dGEerrTp3OwGtrj/PupjT0RpEQXTY8mv/+aQiuHVwT6/ZJPfl4Wwb7ssv5/cRpLugdYs0vS5IkyenJJEuSJMmeTCabzMkCMEUNhz0fypEsO6is0/HBlnQyiqvRGU3MHBDORUOi0GrsMxfudFU9H287xfub06hpMJzz+RWH8nl1zTGuGRPHzym55JXXATC1Xxj3TU3sdAv2EB93rh4Vx5KtGSxef0ImWZIkSX8gkyxJkiR7qswHfS0oGvCPteqhTZFJYidvHxgNoLHdorHd2ebjRfzfd/vJNScuAL/tz+ONdSd49eokBkT52ezcdToD/156kKUpOegMYlSqX4QvvcJ8CPVxp1+EL6U1On7Zl8vhvAre2SgS+ih/Dx67eCCzBlpvoerbJ/Xks+2n2JZ2ml0ZJYyKD2r9SZIkOTxZAmydayCTLEmSJHuyNKXwjwUXN+seOzgR3HygoQqKjkL4AOsev5sxGE38uj+X9OJqqut0VBQobPv5MF/tygYgLsiL68bEUVmn55NtGRwvrOLKd7bx5nXDmdQn1CYxvbTqKN/uEecfGuPPny/sxZxBEefMq7p9YgIf/Z7Br/tzuWRYNNeOicPD1bpJd3SAJ1eMjOHLnVm8sPIoX98xtt3zuyRJchyurmJJkZqaGjw9PVWORl01NTVA0zXpCJlkSZIk2ZONSgUBMXIVOQxObRGt3GWS1WHV9Xoe+Gova44UnvFRLaSJBOfGcT34vzn98HITL6O3TUzgzs/2sD2thFuW7OLTW8cwrlewVWPadvI0728RSfria4czb0hki4910Wq4fVJPbp9kg5+zM9w3tTff78lhZ3oJm48X2yy5lCTJ9rRaLQEBARQWir97Xl5eDnPjxGg00tDQQF1dHRqN7fr2mUwmampqKCwsJCAgAK224zenZJIlSZJkTzZqetEoOkkkWTl7IOl625yji6uq13PNu9s5kFOOu4uGS5OicdUq7DqSgX9QMA9M78P4XmfPQQrwcuPjW0bzwJcprDiUz12f7+GneybQI9jbKjHV6w387bt9mExwzejY8yZY9hQV4Ml1Y+P46PcMXlp1lIm9QxzmTZkkSe0XESFKii2JlqMwmUzU1tbi6elpl78xAQEBjdeio2SSJUmSZE+2HMkCiBoutrL5RYcYjSYWfZ3CgZxygrzdeP/GkQyPC0Sn07FMSWPu3FEtlo+4u2h55ephXPXONvZll/PnT/fwy30XdLiD35m+2Z1Ndmkt4X7u/GueY41Q3j05kS92ZLIvu5xdGaWMTpBzsyTJWSmKQmRkJGFhYeh0OrXDaaTT6di0aROTJk3qVAlfW7i6unZqBMtCJlmSJEn2ZEmyAm01kmVOsgoOgb4eXNxtc54u6o31J1h1uAA3raYxwWoPD1ct7y0cyaxXNpGaX8ln209x84TOfa/r9QbeXH8CEAmNt7tjvXSH+rpz2fBovtyZxYdb0mWSJUldgFartUqiYS1arRa9Xo+Hh4fNkyxrkYsRS5Ik2YvJZJOFiM8S0AO8gsGog/yDtjlHF1XbYODdTSIJfvrSQe1OsCzC/Dx4aFZfAF5efYzTVfWdiuubXVnkldcR4efBVaOs25HSWiyJ5KrD+WSV1KgcjSRJkvpkkiVJkmQvtaVQXy72A+Ntcw5FaSoZzJUlg+2x6nA+VfV6YoM8uXx4TKeOdfWoOAZE+lFRp+eJXw53uB1wnc7A4vUnAbhnSi+rdwi0lj7hvkzsHYLRBB9vzejQMWTXaEmSuhKZZEmSJNmLZRTLNxLcvGx3nmg5L6sjvjO3Rr80KQZNJxcV1moUnlowEK1G4ed9ubxjHiE7U2WdrtXk6+tdWeRX1BHp78GVDjqKZXHT+HgAlqbkoDcY2/Sc4qp6nlueyshn1vHYHi2P/HiIvPJaG0YpSZJkH45V2C1JktSV2brphUVj84s9tj1PF5JfXsfvJ4oB+NPwaKscc0SPIB6bP4BHfzrEf1eksuJgPgkh3igKHM6tIDW/kvG9gll87XACvc9dM61OZ+DNDea5WFMScXdxzFEsi0l9QgnydqO4qoEtJ4qZ3DfsvI9PKxLrihVXNZg/ovBdcg7Hi6r58a7xnU50JUmS1CRHsiRJkuzFshCxrZpeWFhGsoqPQX2lbc/VRfywNxujCUbFB1qt7TrADWN7cPOEeEwmSMkq48e9OfyQnENqvvi+bD15mksW/87qwwUYjE2jWhV1Ov767T4KKuqJ8vfgypGdK1+0B1ethovMreV/Ssk972Pzymu54YOdFFc1kBjmw1vXDuPO/ga83bXsyyrjh7059ghZkiTJZuRIliRJkr00jmTZOMnyCQO/GKjIhtwUSJho2/M5udoGAx9uyQDgipHWLclTFIXH5g9k4bh4DuSUk1dWiwmI9PcgOsCTRd/sI7Okhts/2U1MoCc3jO2BVqPw4ZZ0csvr0GoU/jlvgMOPYllcMiyaT7adYuWhfGoa9I2LNZ/JaDRxz+fJ5JTVkhDizVd3jMXfXUNDuol7Jvfk+ZXH+e+KVGYNDMfXwzm6iEmSJP2RTLIkSZLsxV5JFohFiSuyRfMLmWSd1+c7TlFcVU9MoCcLhlmnVPCPEkK8SQg5d4Tsp3sm8Pamk3y9K4vs0lqeXZ7a+Lm4IC9euXpYh7scqmF4XABxQV5kltSw6lABC5LOvZ7fJWeTnFmGt5uWT24ZTYiPe+N6PDeO7cG3e3JJL67mkR8O8Po1SXJxY0mSnJIsF5QkSbIXW7dvP1NUktjmH7D9uZxYbYOBtzeK5PfeKYm4udj3ZTHQ241H5vRn+yPTeP5PQ0iKC2BYbADPXjaYlQ9OcqoEC8TI3WXmOW0fb8s45/PlNTr+a04kH5jem9igsxvAuLloeOHyIbhoFH7dn8f7m9NtHrMkSZItyCRLkiTJHuorobpQ7Nt6ThZA2ECxLThs+3M5sdfWHW8cxbqsk23bO8PDVcuVo2L58e4JLL1nAteMjsPTzTlKBP/oujE9cNNq2JtZxp5TpY0fN5lMPPHLIU5Xi3lYLS3SPDI+iH9fNACA51akknlarrslSZLzkUmWJEmSPZRmiK1nEHgG2P584eJNKsXHwKCz/fmc0J5TpbyzUaxB9a95A+w+itVVhfq6c8mwKAA+2NLUuv7bPdn8sDcHjQLPXDoYV23L13vhuB5ckBiCwWjiq12ZNo9ZkiTJ2uQriiRJkj3Yq327hX8suPmCUQenT9jnnE6krKaBh77dh9EElyVFM3tQhNohdSm3ThSjVCsO5rP2SAGrDxfw6E8HAVg0ow+jE4LO+3xFUbhuTBwgkjNdG9fdkiRJchQyyZIkSbKHxvlYdigVBFAUCOsv9gsO2eecTqKqXs+NH+0ivbiaSH8PHps/UO2Qupx+EX7MGRSB0QS3fryb2z/ZTZ3OyIV9Qrl7cmKbjjGtfzjB3m4UVdazPrXQxhFLkiRZl0yyJEmS7MHeI1nQVDJYKOdlWZhMon34vqwyArxc+fiW0fh7yTbhtvDq1UncND6+8f+3XpDAewtHtnmRYTcXDZePEPPkvtqVZYsQJUmSbEa2cJckSbIHS5Jlj6YXFmGWJOuI/c7p4H5KyWXjsSLcXTR8csto+oT7qh1Sl+XmouHxiwcye1AErlqFET3OXyLYnCtHxfLOpjQ2HSuisk4n182SJMlpyJEsSZIke7A0vrDnSJYlyZLlggCU1+p4+jcxqnf/tN4MiQlQN6BuYmzP4A4lWAC9Qn3oGeKN3mji9xPFVo5MkiTJdmSSJUmSZGv6eijPFvtqJFllp0QL+W5u8foTFFc10CvUm9sn2vH7IHXKhX1DAdhwtEjlSCRJktpOJlmSJEm2VnoKMIGbD3iH2O+83sHgEy72i47a77wOqF5v4JvdYl7PP+b2l+3anciUvmEArD9aiMlkUjkaSZKktpGvMpIkSbbW2PQiQXT9s6cw2fwCYO2RQspqdET4eTDZ/KZdcg6jE4LwdNVSUFHPkTw5IitJknOQSZYkSZKtqdH0wiK0r9h285Gsb82jWJcNj0bbxu52kmPwcNUyvlcwABuOyVbukiQ5B5lkSVJXoKuDbYth7ZNQLSeHO5xSyxpZKswDkkkWBRV1bDwm5vNYWoJLzmVyPzH6uPaITLIkSXIOMsmSJGeXuQPeHAMr/wGbX4JXh8HW18FoUDsyyeLMckF7CzEnWcXdN8n6bX8eRhOM6BFIz1AftcOROmB6f5FkJWeWUlhZp3I0kiRJrZNJliQ5s9oy+GahaA/uGwkRg6GhElb9C5ZcBOU5akcoAZSoOZLVT2zLsqCh2v7ndwCbj4tRrJkDwlWOROqoSH9Phsb4YzLJ0SxJkpyDTLIkyZmtfQKq8iE4Ee7dBXdsgvmvii52mVvhu5tBduNSl0EvWqiDOkmWdzB4BQMmKD5u//OrrEFvZEd6CQAX9LZjZ0fJ6mYOjABg1aF8lSOxjYo6HblltWqHIUmSlcgkS5KcVeYO2P2h2L/oFXD3BY0GRtwEd2wEVy/I2gEHv1czSqkiG4x60LqDb5Q6MVhGs4qPqXN+Fe3NLKWmwUCwtxv9I/zUDkfqhFkDxUjk7ydOU1mnUzka68kvr+OfPx5g9H/WMP65dcx5dTNL98oqBElydjLJkiRnpG+AXx4Q+0nXQ8LEsz8fkggX/EXsr34MdPLuqGoaOwvGiyRYDSF9xLYoVZ3zq2jLCdEIZkJiCBrZVdCp9Qr1oWeINw0GI+u7yMLEO9NLmPfaZj7fkUmdzoiiwJG8Ch78OoVf9+eqHZ4kSZ0gkyxJckZbX4WiI+AVAjOeav4x4+4FvxgxkrLjbfvGJzVRs+mFhWUkqxt2GNx8XCRZslTQ+SmKwtzBkQB8vydb5Wg6b/mBPK59bzunqxvoH+nHV3eMZe+/Z3D92DgA/vrNPvacKlU5SkmSOkomWZLkbEpPwcYXxP7s58ArqPnHuXnB1H+K/a2vQ32VfeKTzqZm0wuLUPNIVjcrFyyv0bE/uwyAiTLJ6hKuGCla8G86XkR2aY3K0XTc+qOF3P/VXvRGE/MGR/LDXeMZ2zOYAC83nrh4ENP6hVGvN3LDBztYe6RA7XAlSeoAmWRJkrPZ/BIY6iFhEgy+/PyPHXylWAC35jTset8+8Ulnc4gkyzySdfqkKDXtJtYdLcBogt5hPkT6e6odjmQFPYK9Gd8rGJMJvt3tfKNZBqOJ9zen8edP96AzmLhoSCSvXZOEp5u28TFajcJr1yQxsXcINQ0Gbv9kN59uy1AvaEmSOkQmWZLkTMqyIOULsT/ln6C0MsdE6wIXPiz2t77WbVt4q8qyEHGgiuWCvpHg7g8mA5zuPh0Glx0QXejmmEvMpK7hqlGxAHy7O4uSaue5aWAymbjjk908/dsRGvRGZg0M5+WrhqFtZq6gt7sLH940iqtGxmI0wb9/OsTTvx7GJLvFSpLTkEmWJDmT318Bo06MYsWNbdtzBl8JAXFiNOvocpuGJ/2ByXTGSJaKSZaiQPgAsV9wSL047KiqXs/GY6I5wtzBESpHI1nTrIERBHi5kltex7hn1/Kf3w5Tr3f8xdeXH8xnbWoh7i4anr1sMG9fPwJXbctvw1y1Gp7702D+NkssKP7+lnQ+2JJur3AlSeokp0qyNm3axPz584mKikJRFJYuXdrqczZs2MDw4cNxd3cnMTGRJUuW2DxOSbKJ2jJI/lTsT3q47c/TusDgK8T+oR+tHpZ0HpX5oK8FRSsSXTWFDxTbgoPqxmEn61MLadAb6RniTd9wX7XDkazIw1XLewtHMijaj3q9kfc2p3Pl29s4ddpxR+p1BiPPrxDdPe+8sBfXjI5Daa0SAdHs454piTw+X9wkeW55KilZZbYMVZIkK3GqJKu6upqhQ4eyePHiNj0+PT2defPmMWXKFFJSUnjwwQe57bbbWLlypY0jlSQbSP1NzMUK7QfxF7TvuQMWiO2JNbIBhj1ZOgsGxILWVd1YwgeJbTcZyVp+MA+AOYMj2vRmVnIuo+KD+OXeC3hv4UgCvFzZl13O9P9t5F9LD1Beq84aWodzK3hhZSq3f7Kbfy09QH55HQBGo4nX150g43QNIT5u3D6p/fMzbxwfz7zBkeiNJu75PJnCyjprhy9JkpW5qB1Ae8yZM4c5c+a0+fFvv/02CQkJvPTSSwD079+fLVu28PLLLzNr1ixbhSlJtmFZVHjQ5a3PxfqjiMGi8UJJGhxb0XrDDMk6GtfIUrFU0MKSZOV3/ZGsOp2B9amiVHDOIDkfq6tSFIUZA8L59b4LeOSHA2w+Xsxn2zM5lFvBZ7eOwdvdPm9xsktrWPJ7Bh/+no7xjClTPyTnMCExhPzyOg7klAPwwPQ++HQgLkVRePZPgzmcV0F6cTW3LNnF13eMs9vXKElS+3Xp385t27Yxffr0sz42a9YsHnzwwRafU19fT319feP/KyoqANDpdOh06twds5xXrfN3Bw5/jauLcUnbgALo+s2HDsSp6XcJ2q0vYzz4A4Z+l1g/xvNw+OtrI5rik2gBQ0A8Rht+7W26vkGJuAJU5aMrywPvrtvSfOPRImp1BiL9PegT6mmVn7vu+jNsL525vuE+rny4cDjb0k5z/1f72ZtZxi1LdvL4Rf1JDPMB4HR1A7lltfQN98XNpXNFPJV1en7al8uujFIO5VZyqqSplfy0fqGMSQhixaECkjPLWH1YtF/3dtPywLRErhoe2eGfIU8tvHdDEle+u4ODORX89ZsUXr96aJueK39+bUteX9tzpGvc1hgUk5O2qlEUhR9//JEFCxa0+Jg+ffpw880388gjjzR+bNmyZcybN4+amho8Pc9t6fv444/zxBNPnPPxL774Ai8vL6vELkntFV+8jqFZSyjzjGdjvyc7dAz/mgwmH30UveLGsqHvYFK0rT9J6pSR6W8QXbaTg9HXcDKs7aPwtjLt0EP4NBTye+LfKfYdoHY4NvPVSQ3bCjVMDDdyeU+j2uFIdnSqEhYf1lJvFKP9AW4m9Cao0on/h3iYuKSHkSFB7X/rU2+A1TkaNuUr1Buaqgk0mOjhCzOijQwMFMc1mSC1XOF0HRhNMCTIRIC7Fb5AIKMSXj2oxYjCn/sZGBDolG/jJMlp1dTUcO2111JeXo6fn1+Lj+vSI1kd8cgjj7Bo0aLG/1dUVBAbG8vMmTPPeyFtSafTsXr1ambMmIGrq8rzOrooR7/G2k/eBMB3/E3MHTu3YwcxGTG9+DwuDVXMGZUIYf2tGOH5Ofr1tRXtB6JUud+4OfTt28HvWxu09fpqa76Bo78yNsEH42jbxaMmo9HE0y9sBBq4efZIJiZaZ8Suu/4M24s1r++k3AreWH+SdUeLKGtoSoa83LQU1xn44KiWxy/qx3Vj2t6MZs+pUhZ9e4Bc8zyrXqHeLBgaycAoP4bF+uPrcW7M8zr1VZxfhf9RPvj9FMsLfbjvyvG4u57/ppn8+bUteX1tz5GusaXKrTVdOsmKiIigoODsldILCgrw8/NrdhQLwN3dHXf3c283ubq6qv5NdYQYujqHvMZFxyBrOygatEOvQtuZ+CKHwqnfcS06BNFDrBdjGznk9bUVkwlKMwBwCesDdvi6W72+kYPh6K9oi4507ufIgaVklVFU1YCPuwsTeofh6mLdEdtu9TOsAmtc32E9gnn/pmAKK+vIK6vDRasQG+SFVlF4cdVRPvo9g6eXHWVAdCCjE4JaPV5qfgW3f7qXyno90QGe/PuiAcwaGK5qQ5W/zOzHL/vzySyp5ZOd2dw9ObFNz5M/v7Ylr6/tOcI1buv5naq7YHuNGzeOtWvXnvWx1atXM27cOJUikqQOSP5YbHvPAr+ozh0r0ly/n7evc8eRWldTAvVisjuB8aqG0sjSxj3/gLpx2NAa8xyYC/uE4m7lBEtyLmG+HgyNDWBglD9+Hq54u7vw6EUDmD80Cr3RxN2f7yGvvPa8xyioqOOmD3dRWa9ndHwQqxdNYvYg9TtW+ri78H+z+wHw3qY0quv1qsZjDccKKtl6opjkzFKMRlkCKTk/p0qyqqqqSElJISUlBRAt2lNSUsjMzAREqd/ChQsbH3/nnXeSlpbGww8/TGpqKm+++SbffPMNf/nLX9QIX5LaT18PKV+I/RE3df54Msmyn1LzoqG+UeDa/Mi53YWZ52EVHwOj4y/e2hFrjogka/qAMJUjkRyRoij890+D6RfhS3FVA3d+lnzehYyf+OUQ+RV19A7z4b2FI/Fyc5wCoEuGRZEQ4k1pjY5Ptp1SO5wO0xuMPPbTQWa+vIlr39/BZW9u5a2NJ9UOS5I6zamSrN27d5OUlERSUhIAixYtIikpiUcffRSAvLy8xoQLICEhgd9++43Vq1czdOhQXnrpJd5//33Zvl1yHqm/Qm2JeKOeOL31x7fGkmTl7wejbAhgU6fNbxKC2r8mjs0ExoOLB+jrGksZu5KskhpS8yvRahSm9JVJltQ8LzcX3r1hJP6eruzLKuPuz5LJPF1DZZ3urDW2Nh0rYtmBfLQahdeuScLfy7HKwFy0Gu6dIsoE39vsnKNZdToDNy/ZxcfmJDE+WDQYe39zGjUNzvf1SNKZHOeWTBtMnjyZ8zVDXLJkSbPP2bt3rw2jkiQb2rNEbIffAFor/LoG9wYXT2iogpKTENK788eUmnf6uNiGtG2uhF1otBDSRyTZhUcguJfaEVnVWvMo1sgegQR4uakcjeTI4oK9eP2aJG5esou1qYWsTS0EwEWj8Pc5/ViQFM1jP4uFuxeO60H/SHUaX7XmkmFRvL7uOBmna/j30oO8dOVQ1UsZ20pvMHL/l3vZfLwYLzct/7tyGDMGhDP1pQ2cOl3D17uyuHmCA6wxKEkd5FQjWZLUrZw+CembAAWSrrfOMbUuEGFelFaWDNrW6RNiG+xASRY0lQwWHVE3DhtYc0S8UZ4xIFzlSCRnMKlPKEvvnsCkPqGNH9MbTTz92xFmvryJ9OJqwnzd+cuMPipGeX4uWg3P/WkIWo3CD3tz+Gy785QNPv3bEVYdLsDNRcOHN41i9qAItBqFOyaJ0f/3NqWhM8iKC8l5ySRLkhxV8idimzgdAtrearhVkcPENi/FeseUzlVsSbIcbLQwTEyWpzBV3TisrKJOx/a00wBM6y+TLKltBsf488kto9n5j2kcfGJWY/ldSXUDCSHefHH7GPyaac/uSMb2DObv5iYYT/56mD2nSlWOqHWHcsv5eFsGAK9dPYyxPYMbP/en4TGE+LiTW17HsgN5KkUoSZ0nkyxJckT6Bkj5XOyPuNG6x5bNL2zPaBTlmOB4I1mh5vXRCrvWSNbGo0XojSYSw3xICPFWO5yOMejF6PWxVaI7pWQ3YX4e+Li78NeZfXji4oHcND6epfdMIDHMV+3Q2uS2iQnMHRyBziC6JhZV1qsd0nk9tzwVkwnmD41i9qDIsz7n4apl4bgeACzZmqFCdJJkHU41J0uSuo1jy6G6CHzCoc9s6x77zCTLZAInqd93KpW5oKsBjQsE9lA7mrNZRrJOHxdv6q0x188BNHYVdIZRLIMOastA0cCpLZC+GcqzIWe3+L236DEBxt8n/gbI31O7UBSFG8fHqx1GuymKwvOXD+VofiUni6r586e7+eTWMfi4O97v9+bjRWw+XoyrVuFvM/s2+5irR8fy+rrj7M0sY392GUNiAuwbpCRZgRzJkiRHZGl4Mew60Fq5VCW0H2jdoK4cypynft+pWOZjBcZb//vXWf5x4OoNhgYoSVM7GqvQGYysT7XMx3LwroK5KfDKYHgxEV7oCd8shF3vNd1Y8QxqKjE99Tt8eTVs/K+qIUvOwcfdhXduGIGvhwvJmWUs/GAHlXW61p9oZx9vFa87143pQZy5m+Afhfl6MG9w5FmPlyRnI5MsSXI0pRlwcr3YH77wvA/tEBe3puYHuSnWP74ExebOgo42HwtAo4FQ893jLtL8YldGCRV1eoK93RgWG6h2OC3L3AEfz4fKM+aZBPWCsXfDRS/Dwp/hoWNw3274y2EYc6d4zOb/Qal8oym1LjHMl89vG4OfOdH6548H1Q7pLPV6A1tPFgNw+YiY8z7WMqL4y75cCivrbB2aJFmdTLIkydEkfwqYoOcUCLJR+1o5L8u2LGtkOVL79jOFda15WWsOi1Gsqf3C0GoctKzOaIQfboP6ClEG+PdM+Fch3J8Ms5+FkbdAzwubRj79o2H2c5AwCQz1sOYxdeOXnMaQmACW3DIaRYGf9+WyM91x5vftziilpsFAiI87A1ppi58UF8jwuAAaDEaW/J5hnwAlyYpkkiVJjsRosF3DizPJJMu2LGtkOVrTC4tQ87ysIufvMGgymVh9JB+A6Y7cuj1nD5RlgpsPXPsNePiDi/v5n6MoMOsZQIFDP4pjSFIbDI8L5OpRoivtYz8fwmBseY1Re9pwVNwQubBPKJo23BC580Kxlt+n2085ZOmjJJ2PTLIkyZHkpYhSInd/6DvPdudpbONubn4hWZcjlwtCU/JnGXFzYscKqsgqqcXNRcPE3iFqh9Oyw0vFts9scPdp+/MiBsPgK8T+3s+sHpbUdT00sw9+Hi4cyavg8V+PYMmzSqsb+HzHKf67IpV/Lz3I4vUn2HPKPqNdG4+Jxi6T+4a28khhev9wEsN8qKzT88WOTFuGJklW53htZySpO0vfLLbxE8TcKVsJHwiKFmqKoSJXlCZJ1qGvFyMW4LgjWcHi7jAlaU7fYfI38zo6k3qH4OXmoC9pJhMc/lnsD7ik/c8fdg0c+EaMZs3+r23/NkhdRrCPO09fOpgHvtrLV7uy2e2nYUVFChuPF1OnO3eR37evH8HsQRE2iye3rJZjBVVoFNp8Q0RjXpz44e/2886mNK4eFYe/l4M1E5KkFsiRLElyJBlbxDZ+om3P4+rRNC9HlgxaV0kaYAJ3P/Bx0E53gfGifXhDFVQVqh1Nh5lMpsbFSucNiWzl0SrKTYbyTNHVsfeM9j8/4UKxnENtKZxcZ/34pC7r4qFRvHLVMLQahRMVGlYeLqROZ2RApB83jY/n3imJXJAoEp5//HjApg0m1pqXWUiKCyTAq+03Ci5NiqZ3mA8l1Q28vOaYrcKTJKuTSZYkOQqDDjK3if34C2x/PjkvyzYs7duDeznuCJGLO/ibO3uVOG/J4LGCKk4UVuGm1TDNkdfHOrRUbPvMBFfP9j9fo4VBfxL7B76xWlhS93DJsGi+uX00F8cZeGR2H76+Yyy/3X8Bj188kIdm9eXDm0bRP9KPkuoGHvgyhfIa28x9WpqSC8Ccdo6WuWo1PH7xQEDMzUrNr7B6bJJkCzLJkiRHkbdPjCx4BED4INufTyZZtuHo87EsGudlnVA3jk5oLBXsE4qfh4OWEJlMcPgnsT9gQcePM/hysU1dBvWVnQ5L6l6GxPgzLdrELRPiGdMzGOWMG0BuLhpevXoY7i4atqWdZu5rmzmQXW7V82eV1LDnVCmKAvOHRrX7+RMSQ5gzKAKD0cSLK49aNTZJshWZZEmSo0jfJLbxF4i1jGytMclKsf25upPG9u0OnmQFmedlOWnzC5PJxG/7xZ3xeUNsN4+k0/JSxKLfLp4dKxW0iBouvmf6Wkj9zWrhSRJAn3BfvvnzOOKCvMgpq+XOz/ZQ22Cw2vF/SskBYHyvYML9PDp0jIdm9UVRYM2RQo7kydEsyfHJJEuSHIW95mNZhA8CFNHNsLLAPufsDhrbt/dSN47WNDa/cM4k61BuBSeLqnFzcfBSQcsoVp+Z4Obd8eMoCgy5Uuwf+LbzcUnSHwyNDeDX+y8gOsCTnLJa3tpgnVFuk8nUWCp4ybCON1nqFerD3MFi7uXi9c47Ai91HzLJkiRHYDJB9m6x32Ocfc7p7tM02pK/3z7n7A4a52Q5y0hWmrpxdNAPyeLO+IwB4Y5dKmiZj9WRroJ/ZGnlfnI9VBV1/niS9Ad+Hq78+yLRFOntjWlkFFd3+pjb00rE3EkXTae7F94zWZQ5/3Ygj7Siqk7HJkm2JJMsSXIEJWlQXw5adwgbYL/zypJB66opgZrTYt9pRrLSwHhuO2dHpjcY+XmfSLL+NNyBlx/IPwCl6eDiAb1ndf54wb1E2aDJAId+6PzxJKkZswZGMLF3CA0GI29t6NxIt8lk4sVVYg7VlSNjOn1DZECUH9P6hWEywdsbnXMUXuo+ZJIlSY7AkuSEDwStHe/Kn7kosdR5lvlNftGdKw2zh4A40LiIOT6VuWpH0y6bjxdTXNVAiI8bE3u3bVFTVaStF9teU9u3APH5NJYMfmed40nSHyiKwv3TxEj8T/tyKK/teLfB9UcL2XOqFA9XDfdPtc7o/t1TxGjWD8k55JTVWuWYkmQLMsmSJEeQmyK2UcPse17ZYdC6nGU+FohkPqCH2Hey5hc/7BWjWPOHRuGqdeCXsdy9YhszynrH7HeR2ObsgYYa6x1Xks4wskcgfcN9qdMZ+SE5u0PHqG0w8NzyVABuHB9PWAcbXvzRiB6BjOsZjN5o4l05miU5MAd+dZKkbsQykmUZWbKXiMFiW5YpSt2kznGW+VgWTtj8QmcwsiFVLKB8cQdaQduVJcmKSrLeMf1jwCdClAzKMl/JRhRF4fqxcQB8tv0UJpOpXc83mUz888cDHCuoItjbjTsnWffG071TxWjWV7uyOF1Vb9VjS5K1yCRLktRmMjWNJNl7JMszAAITxH5usn3P3RVZ1shy9PbtFk7Yxj35VCmV9XqCvN0YGhOgdjgtqymB0gyxb83fa0WBmJFi39IsR5JsYEFSNN5uWk4WVfPZjszzPtZoNLH6cAFL9+aw6VgRD327nx/25qBR4PVrkwj0drNqbON7BTM42p96vZGvd2dZ9diSZC0uagcgSd1eaTrUlYPWDUL72//8PcaLGNI3QeJ0+5+/K2kcyUpUN462Cna+JGvDMdFVb1LvEDQapZVHq8hy4yQwATwDrXvsmJGQ+ivkyCRLsh1fD1fumZrI8yuO8thPB6mp16PVKGg1Cv6erozvFUKEvwc5ZbU8/N0+fj9x+pxj/GNuf8b3CrF6bIqicOP4eB76dh+fb8/kjok9cXHk0mGpW5JJliSpzTIfK3wguFj3bl+bJFwIKZ9D2gb7n7srMRqbkhVnS7KcqFxww1GRZE3uG6ZyJK1oLBUcZv1jR8uRLMk+7rqwF+lF1Xy7J5tnzfOrLBQFovw9G5tPeLpqGRTtR35FHSPiArlyVKxNEiyLi4ZE8p/fDpNTVsva1EJmDXTgRcmlbkkmWZKkNrXmY1n0vNAcx35R4uQVpE4czq48Cwz1YkQyIE7taNrGUi5YmgFGA2i0qobTmoKKOo7kVaAoMLG37d68WYUt5mNZRCWBooGKHKjIA79I659DkhAjRs9cNhhXFw1pRVWE+npgMpnILq0lJausMcEaFR/Ic38aQq9QK3XRbAMPVy1Xj47jrQ0nWfJ7hkyyJIcjkyxJUluOeS5U9HB1zu8bAaH9oCgVMjZbZ9HU7shSKhjU0+GTlUb+MSIpNDSIJDEwXu2IzmujeRRrSLQ/wT7uKkfTCsvNE1skWe4+Yj29goOiZNBvvvXPIUlmrloNz1w6+JyP55TVkl5UzcAoP6vPuWqr68f24L1NaWxLO83ezFKS4qxcmitJnSALWCVJTUZD0x1vSwmQGhLMo1lpG9WLwdk523wsEMmgpfGJE8zL2pYm5nxM6uPAa2OBGBEuMzcKsCyTYG3RI8Q2e5dtji9JrYgO8OSC3iGqJViWGC5NEguSv7HuhGpxSFJzZJIlSWoqSoWGKnDzgdC+6sVhKRmU87I6zhmTLGiK1wmSrN2nxDIDI+MdvKS14JDYBsaDh79tzmFJsuQad1I3d/eURDQKrE0t5GBOudrhSFIjmWRJkposE9ejktQtMYu/ABStaIBQekq9OJyZs7VvtwjuKbYO3vyisKKOrJJaFAWS4gLUDuf8iswNAkL72e4cEYPENv+gWAZCkrqphBBv5pvXzHtlzXGVo5GkJjLJkiQ1WVowx6hYKgjibnvsaLF/Yo26sTgrZ+ssaOEka2XtOVUKQN9wX/w8XFWOphVFR8XWlqPTYQNE84uaYqgqsN15JMkJ3De1NxoF1hwpYHdGidrhSBIgkyxJUlf2HrFVcz6WReI0sT2xVt04nJGuVjSOAAh2tpEs52jjvtucZI2Md4KJ7fYYyXL1bEro8w/a7jyS5AQSw3y4alQsAM8tT8UkR3clByCTLElSS30lFB4W+2qPZAEkzhDb9I2gb1A3FmdTkgaYwCPA+VrgN7ZxPwUGnbqxnIclyRrRwxmSLDuMZAGEm0sGCw7Y9jyS5AQemNYHD1cNu0+VNi5a7gwKK+tYvP4ED3+3j4KKOrXDkaxIJlmSpJbcvYAJ/GJEG3W1RQwB71DRiCNru9rROJcz52MpirqxtJdvJLh6gcngsPPxahsMHDJPaB/Zw8GT2JoSqC4U+yF9bHuuM+dlSVI3F+HvwbWjewDwza4slaNpm9WHC5jw3DpeWHmUb3Znc9mbW0krqlI7LMlKZJIlSWo5tU1sY0epG4eFRgOJ08W+nJfVPqfNSZazzccC8X0PEG9MKHPMJGt/dhl6o4kwX3diAj3VDuf8io+JrX8suPva9lyNI1kyyZIkgMtHxACw9kgh5TWOOzIPUK838PjPh9AZTAyN8Sc+2IucslqueHsb2aU1aocnWYFMsiRJLenmNakSJqkbx5ksSdbRFerG4WxOp4mtZX6TswmIE1vL2k4OJjmzDIDhcYEojj5SaJmPZetRLGhKsoqPg06WGUnSgCg/+kX40mAw8tuBPLXDOa/Pt2eSU1ZLhJ8HX/95HN/dNZ7+kX6crm7gz5/uobbBoHaIUifJJEuS1NBQDVk7xb5lIWBHkDgdNK5QfBQKU9WOxnmUpottUE914+goB0+y9maK+VjDewSoG0hbNM7HsmHTCwu/KPAMFKWeRfL3VZKAxsWJf9ybrXIkLauu17N4vVhb8f5pvfFw1RLi4877N44k2NuNQ7kVPPHLIZWjlDpLJlmSpIbMbWDUiZIiR3pj7hkAvaaK/cNL1YzEuZSYk6zABHXj6CgHTrJMJhN7s8oASIqTTS/OoihNo1n5svmFJAFcMiwaRYFdGaVklThm2d13e7I5Xd1Aj2AvrhgZ0/jx6ABPXr8mCYDvk7Mpq5FNqJyZTLIkSQ1pllLBCx2vUcLABWJ7+CdVw3AaDdVQlS/2g2SSZW05ZbUUVdbjolEYHO2vdjits+dIFkDkULHN22ef80mSg4vw92BMgmiQs/qw460hZzKZ+Gy7mP96y4QEXLVnvxUfnxhC/0g/dAYTyw/mqxGiZCUyyZIkNVjmY/V0oFJBi75zRclg4WEoOqZ2NI6vNENsPQJE6ZYzcuAkyzIfa0CUHx6uWnWDaU1DDVSYS5Ts1QQlcpjY5qXY53yS5ARmDBAde1cddrwkZUd6CccLq/B01XLp8OhmH3PJsCgAfkrJsWdokpXJJEuS7K36NOTtF/uO1PTCwjMAek0R+ymfqRqKU7CUCjrrKBY0dResyne4BgqW+VhJsQHqBtIWjQm3v/3WS4saJrb5B8Ggt885JcnBzRwQDoiSwdJqxyq5s4xiLUiKxs/DtdnHzB8qkqwd6SXklzvW32Sp7WSSJUn2dugHwCTWpXKE9bGaM/xGsd3+VtMaUFLzSp18PhaIhMDVW+yXO9Zk8b3mkSynmI9VYu4yGdTLfmXAQb3AzQf0tU3t4yWpm4sN8qJfhC8Go4l1qYVqh9Mor7yWlYfE6Nr1Y+NafFx0gCej44MwmeCXfbn2Ck+yMqdLshYvXkx8fDweHh6MGTOGnTt3tvjYJUuWoCjKWf88PDzsGK0kNSPlC7Edeo26cZxPv3mQOAMMDfDrX8BkUjsix9UVRrIU5YySQcdZK6teb+BwbgUg2rc7vJKTYmvPZjYajbhhA7JkUJLOYBnNcqSSwXc2pqEzmBiTEMTAqPPPMb1oaCTgWPFL7eNUSdbXX3/NokWLeOyxx0hOTmbo0KHMmjWLwsKW71L4+fmRl5fX+O/UKcd5AyF1Q0VHITcZNC4w+Aq1o2mZosC8F8HFEzI2w9FlakfkuLrCSBY45LysQ7kVNBiMBHu7ERvk4IsQwxkjWXbuGGopGcxNse95JcmBWeZlbT5ejM5gVDkaKKqs58ud4u/rfVN7t/r4af1FkrjnlOOVPEpt41RJ1v/+9z9uv/12br75ZgYMGMDbb7+Nl5cXH374YYvPURSFiIiIxn/h4eF2jFiS/mDfl2KbOAN8QtWNpTWB8TDmDrG/6wNVQ3FoJU6+RpaFAyZZyafM87HiAhx/EWJQL8lqbH4hOwxKksXAKD8CvVypaTCwP7tc7XD4YEs69Xojw2IDmJAY3OrjowM86Rfhi9EEG48V2SFCydpc1A6grRoaGtizZw+PPPJI48c0Gg3Tp09n27ZtLT6vqqqKHj16YDQaGT58OM888wwDBw5s8fH19fXU19c3/r+iQpSq6HQ6dDqdFb6S9rOcV63zdwd2ucZGPS77vkIB9IOuwOQM38+hN+D6+6twci26wuMi8eqALvszbNDhUpaJAuh8Y8GJ/0Zo/KLRAsbSDAwO8n1KPlUCwJBoP9V/dtpyjV1OnxS/3/497Pv7HTYIV8CUvx99fR1oHLwLYzO67N8IB9Fdr++o+EBWHS5k6/FChkT52Ow8rV3fvPI6lmwVN+TunBSPXt+2JjWT+4SQml/JmsP5zBsUZp1gnZQj/Qy3NQanSbKKi4sxGAznjESFh4eTmtr8Svd9+/blww8/ZMiQIZSXl/Piiy8yfvx4Dh06RExMTLPPefbZZ3niiSfO+fiqVavw8vLq/BfSCatXr1b1/N2BLa9xRHkyYyrzqHfxZdVJE8Z05yjBG+s7mPDKA2R89xiHo6/q1LG62s+wV30BM0wGDIoryzbvAUXd4oDOXN/I0mJGA2UZ+9m8zDF+Nrcd0wIK9blHWbas+b/z9tbSNdYYG5hfIdotr95zkob9drzzbDIyT+OOi66GzT9+SKVn822hnUFX+xvhaLrb9fWtUQAtv+06Rly17f+GtHR9PzmuoU6noZevibqTu1mW1rbjeVQCuLD2cB6//JaN1gkG9G3NEX6Ga2ratsi10yRZHTFu3DjGjRvX+P/x48fTv39/3nnnHZ566qlmn/PII4+waNGixv9XVFQQGxvLzJkz8fPzs3nMzdHpdKxevZoZM2bg6tp8u0+pc+xxjbVffwqAy8iFzJ52iU3OYQvKUeC7hSRWbSd+5jvg0v7mMV31Z1hJWw+HQRPck7nzLlItDmtcXyU3Ej56g0Clkrlz51o5wvYrqKijdNsmNArcdtkMfNzVfblq9RoXpcI+MLn7Mv3iq+y+yLimeBhk72BSH39Mg9X//rVXV/0b4Si66/XtmV/J94u3kVnryoxZU85Z+Ndaznd992aWsWfbThQFXrp+HAOj2v5e0mA08XHaBkprdIQPHMvoeDstDeGAHOln2FLl1hqnSbJCQkLQarUUFJy9endBQQEREW1rg+3q6kpSUhInTpxo8THu7u64u7s3+1y1v6mOEENXZ7NrXJYFJ9cCoB15C1pn+j72nwd+MSgV2bge+RGGL+zwobrcz3ClGLlQAuMd4uvq1PUN6QWAUlWAKwZwVbcT68G80wD0Cfcl0Mdxml60eI0rxFw2Jagnrm5udo4KiE6C7B24FB4E1+vsf34r6XJ/IxxMd7u+A6MDCfBypaxGR2phjc27lDZ3fT/cKv42XD48hmE9Wp+LddbxgMl9w/hxbw6bjpcwobfsK+AIP8NtPb/TNL5wc3NjxIgRrF27tvFjRqORtWvXnjVadT4Gg4EDBw4QGRlpqzAlqXnJn4DJCPETISRR7WjaR+sCY+8U+1vfAKP6XZochmVNqYBYdeOwBgdbK6txEWJnaN0OZ6+RpQZL8wvZYVCSGmk0CmMSxOjP9rTTdj9/Vb2e9UdFB+ybJsR36BhT+4m5WI603pfUNk6TZAEsWrSI9957j48//pgjR45w1113UV1dzc033wzAwoULz2qM8eSTT7Jq1SrS0tJITk7m+uuv59SpU9x2221qfQlSd2TQw15RKsjIm9WNpaOG3wjuflB8FE6oXw/tMCzJiH/zczydiqJAYA+x7wBrZTUtQhygahxtdlqFNbLOZGnjnr9f3giRpDOM7SlGj7aesH+StfZIAfV6Iwkh3gyI7NiUk0l9QtFqFI4XVpF5um1zgSTH4DTlggBXXXUVRUVFPProo+Tn5zNs2DBWrFjR2AwjMzMTjaYpbywtLeX2228nPz+fwMBARowYwdatWxkwYIBaX4LUHR1fCZV54BUM/dSbt9MpHn4w4kbY+jpsWwx9ZqkdkWOwJFl+XSDJAtHGvfCw6m3cdQYj+3PKABjuLElWqcqLUgf3FuvaNVTB6RMQ2kedOCT7y02Bo8uh5rQYVR95K7jbrpOes7kgMQSAnRkl1OkMeLjar/vmL/vyALhoSGSHl6Hw93RlVHwg29NKWJdawE0TnHxNxm7EqZIsgHvvvZd777232c9t2LDhrP+//PLLvPzyy3aISpLOY/dHYjvsOnA5d76f0xh1u0iyMjZDTYkoL+vuKrrQSBY4zFpZR/MrqdMZ8fNwoWeIk7xZLDWP/nVwmYNO07pAxGDI3gl5KTLJ6g7qKmD1o7BnCWBq+vi2xTD7ORh0mVqROZTEMB8i/T3IK69jZ3oJk/rYZ43Kijodm8zrW100JKpTx5raL4ztaSWsTS2USZYTcapyQUlyOmWZcGKN2B9xk6qhdFpgDwgfLOaWHV+ldjTqMxqhXDS+kEmWdVnmYw2LC0SjcYKexUYDlGeJfbWSLGgqGZSLEjs/kwmydkFtWfOfN+jh6+thz0eASVRJXPAXCEyAqgL47mbY8F9xnG5OURQm9hajWZvsuKjvqkMFNBiM9A7zoW+Eb6eONbWfqNjakVZCVX3b1tiS1CeTLEmypQPfASbR8CJYpQnx1mQpEzy6XN04HEF1IRh1Ym0s3y7STMdBkqxky3ys2ABV42izihww6kHjqu7Pgmx+0TUUHIKP5sAH0+HVIbDpBag6o+mByQQr/wHpG0Wzmht/gas/h+mPwz07Yfz94nEbnoGf7gF9gypfhiOxjF5tPl5st3P+uFdUOlwyrHOjWAC9Qr3pGepNg8HIL/tyO308yT5kkiVJtnRyndj2v1jdOKyl7xyxPblOvnBbRrF8I0WpVlfgIElWU2fBAFXjaDNLqWBALGjsN9/jHGeOZBkN6sUhdVxZFrw/AzK3if/XlcO6p+GlvvDhbFj+d3j3Qtj5jvj8Ze9AwqSm57u4wcyn4KKXQdFCyufw+Z+g4KD9vxYHMqFXCIoCRwsqyS+vs/n58spr2XpSNNq4ZFjnFwdXFIVrR4u/z59tP4VJjlA6BZlkSZKtNNRA1g6x32uKurFYS9Rw8A6D+grI3Kp2NOqylId1lVJBgABzd8GqfNDZ/o1Ic0qqG8gwd9BKinWS9u2lGWKrZqkgQGg/MbLRUAnFx9SNReqYbW+Arhoih8ID++Gy98TfXZNRJF473hJJtIsnzHke+s9v/jgjb4FrvwY3H0jfhOv7kxl//Fkxj6sbCvR2Y0hMAACbjtu+ZPDnlFxMJhgdH0RskJdVjvmn4TG4uWg4lFvB/uxyqxxTsq0OJVnV1dX8+9//Zvz48SQmJtKzZ8+z/kmSBJzaCoYG8I+FYCdbG6slGg30mSn2j61UNxa1NXYW7PxdSofhGSjelIFqa2WlZIlRrF6h3vh7OcmiqZaW95YkVS0aLUQPF/vZu9SNRWq/6mLY87HYn/6EmAc75Eq4Yz3cnwIXvwFj7oRpj8JfDsGYP5//eL1nwK2rYeClmDSuhFYdQfvTn7vtKOdkc8ngmsMFNj/Xj3tFpcOlw633+hDo7cZFg0U58mfb1V9mQ2pdh2pcbrvtNjZu3MgNN9xAZGTH21JKUpeWtl5se04WaxB1FT2nwN7PIHO72pGoq6KLNb0A8XPa2Mb9lCoLZyefKgOcaBFiOKOzoMpJFkDMSNEBNHs3DF+odjRSe+x4G/S1EJUkXjfOFJTQseUBwgfAFUswnNqJsmQu2hOrYf0zMO3fVgnZmcwcGM6ra4+z6XgRtQ0GPN1sU9p7orCK1PxK3LQa5g6y7hzN68bG8cPeHJam5HD3lEQSQrytenzJujqUZC1fvpzffvuNCRMmWDseSeo6LPOxukqpoEXMKLHN3w+6WnD1VDcetXTFckFQfa2sHeliHsOIHk6UZFlGstQuFwSIHim22bvVjUNqn7oK2Pmu2L9gkdVvzJmiktgXdysjTr0NW16GpOvUWzhbJQMi/YgJ9CS7tJaNx4qYPSjCJuexdDAc0zPI6qPxw+MCubBPKBuPFfH0r4f54KZRVj2+ZF0dKhcMDAwkKEiukSNJLaosEG9UUSBhssrBWFlAnJiXZdRD3n61o1FPeRdbI8tCxeYX1fV69po7C07oFWL383eYZU6W2uWCIEayQPz9qa9UNxap7fZ8JJpchPSx2aL12UHjMfacBiYDbHzeJuewi4ZqyElud9mjoijMGigSq5WH8m0RGdA058vSNt6aFEXh3xcNwEWjsDa1kPVHC1t/kqSaDiVZTz31FI8++ig1NTXWjkeSuoZTW8Q2YhB4B6sbi7UpStNoVnee99Hlkyz71/zvzChBbzQRE+hJXLB1JovbnK5WrEsEjjGS5Rsh5oFigty9akcjtYWuTiwgDDDhQTH31UaMF/6f2Nn/NRQ5YXOUynx4byq8NwVeHQpbXhFrhrWRZfRq7ZECdAaj1cOr1xnYniZG42216HFimA83T4gH4KlfDtOgt/7XIVlHh36TX3rpJVauXEl4eDiDBw9m+PDhZ/2TpG7vlLnzXo8L1I3DVix3y7trkqWrg2pzhyq/LpZk+ceKrQqNL7aZWx6P7+VENyYsI37ufqJxiCPo7r+fzmbfFyJR94uBwVfY9FSmqOHQZ47oVrjuKZuey+oqckUb+6JU8f/yLFjzGHx2qWga0gbD4wIJ8XGjok7fmAxZ057MMup0RsJ83ekb3rkFiM/nvmm9CfFxI624mo+3ZtjsPFLndGhO1oIFC6wchiR1MY1J1jh147CVxpGsbjrvo9K8GKSLB3h1sdJpy8icZR0wO/r9hHijNCHRSUsFHaXBTcwoOPRj9/39dCb1lbDhObE//j6xzpWtTfs3HF8JR36G9E1nr7PlyNY8AaXp4nftum9F86UVj4iv4a0JcMkboqPieWg1CjMGhPPlzixWHspnYm/rjjZtPiESt4m9Q23aFM7Pw5WHZ/fj4e/28+ra41ySFEWYr4fNzid1TIeSrMcee8zacUhS11FTYp6PBcSNVzcWW4lKAkUDFdni7qJf51e0dyqV5np+30jHeWNtLZYkqzJXlOHYaaHl0uoGDueJNXzGOdNIliN1FrQ4s/mFydT1fka7kk0viFGsoF4w8mb7nDN8IIy4GXZ/IBY3/vMmx19QvSwLDn4n9q9YAqF9xb/Y0fDNQrEu3OeXw6S/wZR/nvdnfubACL7cmcWqQwU8efEgNBrr/X5sOS5uFE3qY/sbRZcPj+Hz7afYl13Okt8zeHh2P5ufU2qfThX+7tmzh88++4zPPvuMvXtl7bckAU2tzYN7g49tarJV5+4DYQPFfncsSarME1tf67bndQjeYaBxFeVEVbabHP5H29JOYzJBn3Af57oje/qE2HakvbatRA4R38PqQtW6REptcPokbHtT7M9+Flzc7Xfuqf8CjwAoPCRaxzu67W+JZksJk5rWggMI6y+SxDF3iv9vegHWPS1uLrRgfK9gfNxdKKysJyW7zGohljdAakEVioLVR8iao9Eo3DGpFyDW5TIYW/6aJXV0KMkqLCxk6tSpjBo1ivvvv5/777+fESNGMG3aNIqKbL+StiQ5tExLqWAXHcWyiB0ttpbSyO6kcSTLNi2AVaXRNI1M2nFe1vpU0SXrgkQnuzFRbG4eENJX3TjO5OoJEYPFfne8CeIsNr0ARh0kToc+s+x7bq8gmP642F/7JBQese/526O2FPYsEfsTHjj3866eMOe/MOtZ8f/NL8LeT1s8nLuLlin9wgDrdhk8Wi5GxAZF+RPkbYeyT2Ba/zD8PV3JK69j68m2zUuT7KdDSdZ9991HZWUlhw4doqSkhJKSEg4ePEhFRQX333+/tWOUJOdyqpskWfHmdfIyflc3DjV05ZEssHvzC6PRxAbz2jJTzW9+nEbxcbEN6aNuHH8UI9fLcmilp2D/N2J/yj/ViWHETZA4Awz18MPtYNCpE0drkj8FXbWonug1reXHjbu76Vqu+lfTzbBmzBoYLh52qADTeUa92iO1TCRZ9igVtPBw1TJ/qHgd+n6P/ZsVSefXoSRrxYoVvPnmm/Tv37/xYwMGDGDx4sUsX77casFJktOpr4K8fWK/qydZls6JBQfFncbupCuPZAH4R4utnZKsQ7kVFFXW4+2mZVSCg3Toa4uGajEvESCkt7qx/JGlOU2OTLIc0tbXxHpVPaecXf5mT4oimkV4BkH+Adj1vjpxnI/RALveE/tj72x9fuEFiyBymFhzbPnDLT5sct8w3Fw0pBdXc7ywqvNhGk1NSZYdSgXP9KfhYh7tikP5VNY5aKLcTXUoyTIajbi6nruKtaurK0aj7NcvdWPZu0TduH9s03pDXZVvuJh3hglObVM7Gvs6s/FFV9TYYdA+SZZlQc0JiSG4u2jtck6rsMzH8gp2vC6TlpGsvH2gr1c3FulslQVidAZg4iJ1Y/GNgOnmZmYbnhONmxzJsRViXqFnYNva22td4OLXQdHC4Z8gb3+zD/Nxd+ECcxfTFQc7XzJ4JL+Sar2Ct5uW4T3se6NoWGwAvUK9qdMZWXYgz67nls6vQ0nW1KlTeeCBB8jNzW38WE5ODn/5y1+YNu08Q7mS1B65e2HvZ6BvUDuStss0JxtxXbR1+x9ZSgZPdbOSwcZywS46kuVnHsmqsE8b93Xm+ViyVNCKAhNE8mdoEKMUkuPY/qYo0YsZBfET1Y4Gkm6A8EFQVwYbnlU7mrPteEdshy8Uc6/aInII9J0j9lN/a/FhsweKv9/WmJe12dxVcGzPIFy1tltMujmKovCnEeLG2Pd77L/0htSyDv0kvPHGG1RUVBAfH0+vXr3o1asXCQkJVFRU8Prrr1s7Rqk7Ks+BJfPhp3vg/alQcFjtiNqmu8zHsrCUDGZsVjcOe+vyI1mWOVlZNj9VcVU9+8wdvib3dbYky9L0wsFKBUGUVUXLeVkOp7YMdn0g9i9Y5Bjt9TVamPWM2N+zRJTBOoLjqyF9o1guZNRt7Xtuv4vE9jxJ1rT+YWgUUa6cVVLTiUBhoznJmpiozvITlyXFoFFgZ0YJp047yPdP6liSFRsbS3JyMr/99hsPPvggDz74IMuWLSM5OZmYmBhrxyh1NyYT/PZXaKgU/88/AB9fJF6cHJm+oamTV3dJsiwjWfkHRA28SnQGIweyy6mwRz16fSU0mGv4fcNtfz41NM7Jsv1d0WUH8jCZYEiMPxH+TtS6Hc5IshxwJAsgapjYypEsx7HrPfHaFjYA+sxWO5omCZPEzRVDQ1NFhprqK+HXv4j9MXe1v/y+zyxRMlhwoGktuz8I9nFnVLwo8111uKDDoZ6uqic5swxQbzQ+wt+jcRH375PlaJaj6PCYpqIozJgxg/vuu4/77ruP6dOnWzMuqTs7/BMcWy7WeVn4s5j3U3Mati1WO7Lzy90L+jpRouOob7qszS8KgnqKNZUs64PZUUZxNf9eepBR/1nD/De2MP7ZdTy3PJXaBoPtTmoZxXLzBXdf251HTZY5WbUlNr+r/VOKKDu/eKgTLmhdbJ6TFeyAI1kgSsBAvNGU1FdXLtZ7ArjgL2K5BEehKNDzQrGftkHVUDCZYOU/xUh6QBxM7UD3Ra+gppudR5e1+LBZVigZXJdaiNEEMd4mIlW8UXR5Y8lgttU6Jkqd0+Ylvl977TXuuOMOPDw8eO211877WNnGXeowk0msHQJwwYPij/60R+GbG0Qd+5g/g7f92qO2i2V9rLhxjlECYi89JkBJGmRssdtaL4WVdTy3LJWlKTlY1l90d9FQVa/n7Y0nKaqs56Urh9rm5F19PhaAhz+4+0F9hRjNCrXNTYOskhr2nCpFUWC+syVZRiOctszJctAkK8KcZBWmgkEvmgJI6tnyirhhGNwbBl6mdjTn6jlFzINWO8na9AIkfyz2578Kbt4dO07fuaKUPfU3GHtXsw+ZOTCcJ389zO6MEk5X1RPs0/4FodccEaNggwLVTWxmDojA201LTlkte7PKGB7nRJ1au6g2/8V9+eWXue666/Dw8ODll19u8XGKosgkS+q4zO2iJbiLJ4y7R3ys/3zRkjUvBba8DLP+o2aELbOsF9VdSgUt4i8QCz/aqfnFmsMF/OXrFCrr9YAoz7h5Qjzjegaz/GA+93+1l++Ts5k3JIKp/WxQztfV27db+EVDUYVoUW6jJOvnfWIUa1zPYML9nKxUsDxLjFxr3SCgh9rRNC8gHtx8RHnr6eMQ1r/Vp0g2Up4tbhQCzHjSMRPehElim38AqovVuaG5/1tYb36Nn/0c9Jra8WP1mQUrHxHvK/T14HJuAhUT6MWgaD8O5lSw5kgBV41qX1linc7ApmNiPtagQHW7a3u6aZk+IJyfUnL5dV+eTLIcQJvHqtPT0wkODm7cb+lfWlqazYKVugHLOh2DLxctW0GMCk35h9jf+5ljtiPW1zclGZYXqu6ih3leVm6KqKO3oeKqev767T4q6/UMjfHnp3sm8OFNo5jYOxQXrYb5Q6O4dUICAI/8cMA2c7S6etMLCxu3cTeZTCzdK+YOXDLMyUaxoKmzYFAvx3zDDKIcLXyg2M8/qG4s3d2qf4ukvMeEps53jsYnrKnENH2j/c9v0MO6p8T+hAdbHH1qs6Ce4n2EUQeFLTfPmjXAUjLY/nlZv58oplZnIMLPnZgODrhZ07zB4nVp2YE8jEZZMqi2DhUEP/nkk9TUnNuJpba2lieffLLTQUnWtS+rjGeXHeGZZUca7xw7pKpCMR8Lzu0klDgdfKNEi9mjDrjgddZO0NWAd5hYlb47CYgVd/JNBsjaYdNTPfPbEcprdQyI9OP7u8YzNDbgnMc8NKsv8cFeFFTU8+GWdOsH0V1Gsmzc/GLjsSKOF1bh7aZl9iAnTFgbm14kqhtHa+S8LPXt+xoO/SAaMcz6j2OXk/ecLLYn19v/3Id+gLJTYl7zhf/X+eMpiqiCAbFeXAtmDRJ/y7ccL6bKXCHRVh/9niGOMTDcIb6tF/YNxdfdhfyKOvZklqodTrfXoSTriSeeoKrq3BWya2pqeOKJJzodlGQ9uWW1LPxwJ+9sSuPdTWnc/+VeNh4rUjus5u37Stxxih7Z1BXLQqOFoVeZH/el3UNr1cl1YttzsmNNZraXeEsr9y02O8WqQ/n8sDcHRYFnLhuMSwtrkXi4anloVl8APticTnmNlUezGudkOWFi0B42Hsl6Z6Ooerh6dBz+nucubu/wHL2zoIVlXpYcyVJHaYbolgsw+e8QlaRqOK2ylOcdWwlGGzYQ+iOjUUwHADGC5eZlneNGmufm5qa0+JDeYT4khHjTYDCy3rxmX1ukZJWx5UQxLhqFm8c7Rsmwu4uWGQNFmfyvjnxTvZvo0LtBk8mE0kzKvm/fPoKCHGzV+27MYDTx4NcplNfq6Bvuy6Q+oQD8/fv99ml13V6Hl4rt0Kub//zQa8X2+Gox6uVI0sx3/TpTP+7MLCWDGbaZl7XhaCH3frEXgBvHxTOsmRGsM80dFEm/CF8q6/W8v8XKJczdZSTLz5xkVVg/ydqfXca2tNO4aBRuuSDB6se3i9PmzoKOnmSFDxbbAplkqWL530XL9rhxMPGvakfTuoRJ4BEA1YVN6z7aQ9o6UdLn5gujbrfecS03bM8zkqUoCnPMo1nvb0lvc2e+N9eLvwGXDIsmOqCNCyXbgaWJ0E/7cqnT2TFRls7RriQrMDCQoKAgFEWhT58+BAUFNf7z9/dnxowZXHnllbaKVWqn9zensTO9BG83Le/cMIK3rx9OfLAXeeV1PPPbEbXDO1tZJuTsARTof3HzjwntI0a5TAY4+L1dwzuvmpKmu2SWUovuJmGi2ObsgVrrlSg06I28se44d3yyhwaDkTmDIvjXvNYn72s0Cg9OF29+39+cTlrRuSPvHSZHsjrFaDTx3PJUQLRtd6Q3J+3iyAsRnyl8AKBAVQFUOWgVQ1d1cr15ORIXmP+aqMhwdFrXpoV8LTc+7eHgD2I79GrwDLDecS0jWQWHwNDyzeWbJyTg5aZlX1ZZm9q5rz5cwKrDBSgK3DW5p7WitYpJvUOJ9PegrEbXqdb0Uue1K8l65ZVX+N///ofJZOKJJ57g5Zdfbvz39ttvs2XLFhYvdvC1jLqJ01X1vLFO3GV5dP4A4kO88XJz4fnLxR+cr3dnkZpfoWaIZzv8s9j2GH/+BV4HmBOw9M22j6mt0jYAJrG4pF8Xf+PdkoA4CO0vEuDja6xySJ3ByLXvbefFVcdoMBiZNziS165JarFM8I9mDQxnbM8ganUG7v9qLw16K3R+MpnOGMnqogsRW5yZZFlxzZV3N6ex9eRpPF213DfNwROUltSWiaQFHHeNLAs3bwjuJfbPczdfsjKjQaz1BGKOsY06dNrEwAVie/hn+5QM6hsg9VfzuS+17rEDE8DdHwz1UNjyzeVQX3duM4+qP7/yKHrD2a8XJpOJDUcLeXb5ERavP8G9XyQDcO3oOBLDHGu9RK1G4cqRsQB8tTNL5Wi6t3a1RLrxxhsBSEhIYPz48bi6OmEdfTfx2trjVNbrGRjlxxUjYhs/PjohiLmDI1h2IJ/nVxzlw5tGqRjlGSwNLwYsOP/j4szt0TO3iRpuR5j/1Dgfa4q6cait72woOiLu3A65otOHe3vDSXafKsXXw4WnFwzi4qFRzZYpt0RRFF6+ahhzXt3MwZwKnl+Ryr8uGtC5oOrKQV8r9n26ermgueOfvk6M1noHd/qQB3PKeXHlUQAemz+AhBAHaMfVEZZSQZ8I8PBTN5a2iBouYs7ZA72nqx1N93BsJRQeEqV31mjiYE8JF4q18qoLxWutZc6traRtEH9bfcIhbqx1j60oEDlErJeVt0/st+D2ST35dPsp0oqqeeznQzy9YBA6g4n1Rwv56Pd0tqeVnPX46f3DeOJix2x0deWoWF5bd5xtaadJL6523r+1Tq7N71ArKppGPZKSkqitraWioqLZf5K60oqq+HxHJgD/mNsfjebsN6YPzeyLVqOwLrWQHWmn1QjxbOXZkL0TUSo4//yPjRwq1tCqLWkq11GTydS0cGN3nY9l0cfclvj4mvOWZbTF8YJKXjePxD51ySAuGRbdrgTLItLfkxfMo7fvb0ln/dFOzuWzjGJ5+FtvYrajcnEXb3pArAllBa+sOY7eaGLWwHCuGhXb+hMclbOUClrEjBTbnN3qxtGdWErah10HXk42V93FDfqZX4vt0WjKUpY44BLblFRaSgbzUs77MF8PV565dDCKAp/vyGTBm1sZ8fRq/vzpHranleDmouGy4dFM6xfGFSNieP2a4W2urLC36ABPJpvn4b+3WS6tpJY2/3QEBgZSWCjeoAQEBBAYGHjOP8vHJXU9v+IoeqOJKX1DmZB47mKCPUN9Gt/gvLvJAX75jvwitnFjWy+3c3FresOQacdJuS05fUK8AdW6db9FiP8oZqRovVtf3ukJ088uT6XBYGRqv7BOr6E0Y0A4N44TnZ8e+mYfhRV1HT9Yd5mPZeFnbuNe0fk27kfzK1lzRMxheHh2vw4lzQ7DskaWoze9sIg2/83M3m3V0k+pBbrapqVGBl2mbiwdNfwGsT3wvVXn2Z5DX99UKthaJUtHWTo65u5t9aFzBkfynwWiWcy+rDIq6/SE+bpzx6SerPvrhfzvymF8cNMoXrhiKJ5ujj3H7s4LRZnwVzszOVZg2zUspea1uVxw3bp1jZ0D169XYf0EqU12ZZSw4lA+GgUemdtyg4Cbx8fzxY5MNhwroqiynlDfc1dCt5tDS8W2rX9g48aJof/M7TDyFltF1TaWtUTixnb9kY3WaLTQexbs+wKOrYCeF3boMKXVDWwyLzPwz3n9rfJm/JG5/dmRXkJqfiVP/3aE167pYBvl7tJZ0MI/BnKTrdL84u2NJwGYMyiCXqE+nT6eqpylfbtFxCBxI6i2BErTxSKtku0cXwW6avCPg+gRakfTMbFjxBprBQch5UsYd7dtzpP6qygV9I2yfqmgRfRwsc0/IOZ/ubid9+HXjokjyNuV7NJaRicEMTDKH63G+W4KjekZzKyB4aw8VMAzy46w5ObRaofU7bQ5ybrwwgub3Zcch9Fo4j/mroFXjoylT3jLkzF7h/syNDaAfVll/JSSw20TVXrRrciFrO1iv7VSQYse48T21DbbxNQecj7W2XrPEElW+qYOH2LV4Xz0RhMDIv2s9mbcw1XLi1cMZf4bW/h5Xy63XpDQ7ELGrepuI1mNzS86Vy6YU1bbuBD63ZMdfPHetmgcyXKSckEXd4gYLOZkZe+RSZatHfpRbAcucOyFh89HUWDUrfDrX2DX+zDmTtvMgd6zRGyH32C77ouBCeAZKEbkCg42JV3n4ZQLpDfj73P6sy61kA1Hi/jzp7v590UDiAlsuiG89UQxS1NyOJxXQaCXG69cNYxgHxVvuncxHfqNWbFiBVu2NC06unjxYoYNG8a1115LaalcYVotb6w/QUpWGZ6uWhbNaP0O6+UjxBuo7/Zkt3ldCKuzlArGjgH/6LY9J2Y0KFooz7TZQqltYtCJETWAXjLJAprWyyo4JDqwdcCv+0UiM2+IdV/kBkX7c2mS+Bn7z7IjHfuZ744jWQDlnSsX/HJHJgajifG9ghkU7W+FwFRk0EGJuczaWZIsaCoZlPOybEtXJ5pegPOWCloMvlKsW1VyUqxjZW2nT5pvyCmQdL31j2+hKE0jijl7bHceB5QQ4s0/5vZHq1FYeaiASc+v55p3t/PYTwe5+aOdXPv+Dr7Znc3BnAo2Hy/m6ne3U9CZknrpLB1Ksv72t781Nrg4cOAAixYtYu7cuaSnp7No0SKrBii1ze8ninl5jShhefKSgYT5ebT6nPlDInHTakjNr+RATrmtQ2xeW7sKnsndp6lDkJqjWVk7oKEKPIMgYqh6cTgS33DzXXKTuD7tdLqqnq0nRTOWi6ycZIFo+uLuomFnegmbjxe3/wDdbSTLMierEzczdAYjX+8WI2HXj+1hjajUVXYKjDrRgMeyYLMziDljXpZkO9m7QFcjOk9GDlM7ms5x92mam7X1DesfP/kTsU2cLpYBsaXGJCvZtudxQDdPSOC3+y9gfK9gjCbYlnaaj7edYv3RIrQahWvHxPG/K4cS6e/B8cIqblmyC4NRzt20hg4lWenp6QwYIFohf//998yfP59nnnmGxYsXs3z5cqsGKLWupkHPQ9/uw2SCK0fGcMXItnXtCvByY85gcUf++RVH7T+aVVsm2sNC20sFLeLMJYNqNr+wJIh9ZjlGK3lHYWmz34HmF6sOF2Awmhgc7U+PYOu3nI0K8OSa0eLF/JNtp9p/gG43kmX+W9KJxhdrDhc0zvucMcD51xZTLO3bQxKd6/fe8iYzf7+YlyLZhqW6IWGi85YKnmnMnaBoIG095B+03nH1DZDyudgfcZP1jtuSbjqSZdEvwo8vbh/L5oen8Nj8Adw7JZH7piay4oGJPHPpYC4bHsM3fx6Hn4cLh3Ir+Cml882OpA4mWW5ubtTU1ACwZs0aZs6cCUBQUJBs4a6CtzemkVdeR0ygJ09cPKhdz/3rjL64uWjYcqKYVYcLbBRhCzI2g8koJo8HtLOdc5zK87KMhqYka6CTl4RYm2XOXGb7vzebj4uGFzNt+Gb8BnOnwXWpBWSX1rTvyY1JVjcZybKUC1bmgUHf7qebTCaWbM0A4KqRsbg6aLvj9lBOO1lnQYugnmLpAUMDFB5WO5quK92cZMVPVDcOawns0VRpss2Ko1lHl0F1kRjx6zPLesdtSZR5HlbxMdFoo5uKDfLi5gkJPDSrL3+d2ZfeZ8zdjw3y4s7JoiPhS6uOUa+3w0LUXVyHXvEuuOACFi1axFNPPcXOnTuZN28eAMeOHSMmxonKJ7qAnLJa3jF37frH3P7tbikaF+zF7RPFKudP/3aYOp0df6ks60v1nNz+51qSrKIjYqFUe8vcBlUF4k1LR+Lvyizfm5xk0cq4jYxGE9vMpYLjEzu/8G1LeoX6MCFRlE18YV5Prk1MJqhynpGskuoGHv/5EJe/tZVL3/ydjeaOje3iHQoaV3EzxFIq2Q7PrzzKjvQSXDQKV4924nWxztA4khXsRPOxwLwoa9vWC5I6qKFGlAuCGMnqKsbfJ7YHvhXNqqzB0vAi6XrQulrnmOfjEyq6PWKC3BTbn89J3Tw+gXA/d3LKavl8ezteH6VmdSjJeuONN3BxceG7777jrbfeIjpa1O0vX76c2bNnWzVA6fz+uzyVer2R0fFBzBnUsTd+d09OJMLPg6ySWl5fd9zKEZ5HZ5Isn9CmNzkdmPvTaZbuUf3mt9oOttsJ6ikWsTXq2lWacbSgktIaHV5uWobEBNguPuAG89ygr3ZlUVRZ37Yn1ZaKUQBoWqTXgT307T6WbM1g96lS9maWceOHO3nsp4PoDMa2H0SjAT/zOmXtnJf1/uY03togbgA9c9ngszpaObXTTtZZ8EyWOULyTaZtZG0Xf/f8YkRHu64ierhoamTUw453On+8knRRfojSNOfLHixdBXO737ystvJ00/LANDFK/8b6E1TW6VSOyLl1KMmKi4vj119/Zd++fdx6662NH3/55Zd57bXXrBZccxYvXkx8fDweHh6MGTOGnTt3nvfx3377Lf369cPDw4PBgwezbNkym8ZnT3tOlfDzvlwUBR6dP6DD6wl5u7vw+MUDAXhnYxpH8+2waF1ZlljIV9FC/AUdO0ZjK3c7z8sy6OHwz2J/4KX2PbczUJQOlXNaGl6Mig+yeVnZ9P7h9Az1pqS6gVs/3kVNQxtK4SwjOV7BoiW2A9t8vIh1qYW4aBReuHxI42LMH287xa0f76aqvvmvN6ukhp9Scnh97XFOFlWJDzaWDLb9Dvbi9Sd42rycxN9m9eXKNs4TdQZOWy4IEDVMbOVIlm2kd7H5WGeyjGbt/gjqO/kewdLwotdUCIzv3LHao5vPy2qrK0fG0DNEvD6+tzld7XCcWoffyRgMBr7//nuefvppnn76aX788UcMBtuWmn399dcsWrSIxx57jOTkZIYOHcqsWbMoLCxs9vFbt27lmmuu4dZbb2Xv3r0sWLCABQsWcPCgFSdvqsRoNPHEL6Ku/soRsZ1uizx7UAQzB4SjN5r419IDtm+Ckb5RbKNHiJK7jojr+NyfTjm+CqoLxZvtDi642+X1MDe/aMf3xlIqOK6X7UoFLVy0Gj64cRSBXq7szy7n6ne3szO9lbJTJ+ksqDcYefpXkeDcMK4HV4yM5YlLBvH+wpF4umrZdKyIGf/byOL1JxqTLZPJxHub0pjy4gYe+CqFl1Yf49r3tlNYUdf09baxTOj7Pdm8sPIoAA9O783d5hr/rsBNX4lSa16mJNgJ1/uyjGQVHJLNL2zh1O9i21XmY52p9yxRPVJfDsmfdvw4Bh3s/Uzs26PhxZm6cYfB9nDRanhoVl9AVCQUV7Wx2kM6R4eSrBMnTtC/f38WLlzIDz/8wA8//MD111/PwIEDOXnypLVjbPS///2P22+/nZtvvpkBAwbw9ttv4+XlxYcfftjs41999VVmz57N3/72N/r3789TTz3F8OHDeeMNG7QitbM31p9gf3Y5Pu4u/HWWde6oPnHJQNxdNOzKKOX3E6etcswWdaZU0MKSZOXuFbXw9pL8sdgOu9Y+teTOyPK9ydopmoS0wmA0sSPdPB/LDkkWiPVDPrhpFN5uWvZnl3PlO9t4/OdDLbeudZLOgj/uzeFoQSUBXq48MK2ppG36gHC+umMs4X7u5JXX8cLKoyxY/DtbTxRz40e7+M+yI+iNJobE+BMd4ElBRT13frYHfTuSLKPRxOINYs7SPVN68eD0Ph0eYXdEPnXma+AfB25OWP4Y1BPczc0vio6oHU3Xom9oKsOMG6tqKDah0cC4e8T+7690fDTr6HJxk9I7DPrOsVp4bRI5VHRKrMiBivbPMe1O5gyKYEiMPzUNBt5Yd0LtcJxWh5Ks+++/n169epGVlUVycjLJyclkZmaSkJDA/fffb+0YAWhoaGDPnj1Mnz698WMajYbp06ezbVvzd8u3bdt21uMBZs2a1eLjnUFlnY53Np7kf6vFmlh/n9OPMN/W18Rqi0h/T64dI9pbv7r2mO1Gs0wmyLDc8etgqSCIMgOfCFEnnrvXKqG1qjxHjGQBDL/JPud0RuEDwd0PGioh/0CrDz+cW0FlnR5fDxcGRtlvsdrhcYGse2hyY1v3JVszuPeL5OYbwDSOZDlukqUzGHnNPK/yrgt7EeB19nzBobEBbPzbFF66YigRfh6cKKzi2vd3sOlYEW4uGp5eMIif7pnAZ7eNwc/DheTMMrYVmksj29DGfeOxItKKqvF1d+GuyU440tMKnzrzz0CIk35titK0xqCcl2VdBQfAUA+egea1ArugYdeKuWZVBbD5fx07hr0bXpzJ3QdC+4t9OS/rvBRF4f9m9wPg8x2nyDxtxxvZXYhLR560ceNGtm/fTlBQUOPHgoODee6555gwYYLVgjtTcXExBoOB8PCzJ5yHh4eTmpra7HPy8/ObfXx+fn6L56mvr6e+vmlo1NKSXqfTodOpMwFQp9PRYID7Xv+KyJKdfKQXzUXundyTq0ZEWTWuW8fH8fmOTDGadbyQMQlBrT+pvUozcK3MxaRxRR8xDDoRvzZmNJrUnzFkbMUYPbrDx7Fcw9aupWbPx2hNRoxx4zH49+hU7F2dNmY0mpNrMKRvQRcgRltbur670sXCwMNjAzAa9G0Z/LKaIE8tT87vx5j4AP72/QGWH8ynqHI7b1+XhL9n05sATXkuWsDgFY7Rwb7vluv63e4sskpqCfFx45qR0c1eby1w8ZBwRsf7c+fnezmUW8mEXsE8Oq8fPUO90ev1xPi78e95/fjb9wf5KR0mAsbyXAytfN3vbRaVDFeMiMZdY1Ltb6Yt6HQ6fOtFkmUISnS4n4G20kQMQZuxGUNOMsYh16odTqO2/g12VJpTO9ECxqgRGPTtX+7A1qxzfTUo05/E5dsbMG17A/2Qa9rX4KPsFC4n16EAuiHXqvL6qY0chqbwEIbMXRh7zbTacZ3957c5o3v4M6FXML+fPM2LK1N56YrBqsbjSNe4rTF0KMlyd3ensvLcoeKqqirc3Jy709qzzz7LE088cc7HV61ahZeXeuUhfobTvFLxf3i4NFCkjYDwwSTWHWPZsmNWP9eYYA2bCzT854ed3Nm/HZ3I2ij29GaGAyWe8WxZvaFTx+pZ6cNgoCj5V3aUd75scvXq1S1/0mRi2uGP8AH2KkPI7kJNVGyhd00gA4CCXUvZVSwaH7R0fZcd1wAaPGsLVGtOowB/7qvw/lENu0+VMe/ldVyfaCDOR3x+dNo+IoGDp4rJcMDvvcEEr6xOBRQmhtSyfs3KVp9zaywUhUK4ZwGpuwo483aV1gRB7lpO1PqBO9QVnmT1eb7u3GrYetIFBRMxtSdZtsx2peNqGWMuFzyYV++QPwNtEV1qYiRQnrqZzSbH+xrO+zfYgQ3P+JlY4Fi1L0cd+Gej09fXZGKc7yDCKg9S8MW97Em4p81P7Zf7HX0xUeg7iG3bDgP2X6+tR4kbw4DTB1axrTbJ6sd31p/flozzgt9x4Zf9ufQji2hvtSNyjGtsWSu4NR1Ksi666CLuuOMOPvjgA0aPFqMHO3bs4M477+Tiiy/uyCFbFRISglarpaDg7AVzCwoKiIhovnwnIiKiXY8HeOSRR1i0aFHj/ysqKoiNjWXmzJn4+fl14ivoOJ1Ox+rVqynrcwURxz7nde8P0V+9wWZlSwNP1zD9lS2klmsYNv5CogI8rXp87a8rIRMChsxh7tS5nTqWkhMBS74gXHeKuXNmi3rrDrBc4xkzZuDq2kIJQ94+XFMKMbl4MuTKRxji5gB/bRyYkhUEn3xLpC6DGdOns3rNmhav7yuvbAFquGLqKCb1DrF/sGeYnV/JrZ8mU1BRzyuHXLlyRDTXjIolvOAVKIeBY6cxoG/nfm6tTafT8fLXayipVwj0cuXJhdPwcG3fmnnNKQvO5J3fRKMHT305c2fPAk3zx33kx0NADrMGRnDDZUM7fW5Ho9PpMB36KwADJ1/KgB6dKHVWU1EvePdNAnX5nfqbaW1t+hvswFzeehyAxMlX06vXNHWDaYZVr29BD3h/MtFlOwkf+SKE9W/9OQYdLm/8DYCgGX9hbn+V/obmx8IHHxHakGXVn39n//k9nyOm/fx2MJ8dtRG8f8XwNj2ntsFAYVU93m5aQnys043Xka6xpcqtNR1Ksl577TVuuukmxo8fj4uLOIRer+fiiy/m1Vdf7cghW+Xm5saIESNYu3YtCxYsAMBoNLJ27VruvffeZp8zbtw41q5dy4MPPtj4sdWrVzNu3LgWz+Pu7o67+7k/EK6urqp/U4Mv/S8s2Y9ScADXX++FG5bapE1sYoQ/43sFs/Xkab5PyWfRDCu3Ks4Sc+K0CRPRdvaaxg4HF0+U2lJcyzMgtG+nDnfe7/Ox3wBQes/A1TugU+fpFuJGg9YdpboI18pTQPPXt7xGR7q53nt4j2DVf88GxQax7P6JPPHLYX7el8uXu7L5clc2+/2y8QNcAmLAAV9EtxeKvwWXj4jB18s68zSvGRPPm+uPozdocMGAa0NZszd3iqvq+Xm/KKW7fVJP1b+HNtFQhWuDWNDZJXKwQ/4MtEl4X9C6oeiqca3KhSDHWs/JEV5r262mBErSAHCJG+3QPxtWub4xSTDgEpTDP+G65QW4qg3dBk+sFHO5vEJwGXAxuKh0jaIGi/cM9RW4VmRafb07p/z5bcXfZvdj5eECNh4vZk9WBWN7ttycymQy8fKa47y5/gR6owlFgb/O6MM9UxKt1gTJEa5xW8/frhTeaDTy3//+l3nz5pGTk8OCBQv49ttv+e677zh69Cg//vgj/v62m7S+aNEi3nvvPT7++GOOHDnCXXfdRXV1NTfffDMACxcu5JFHHml8/AMPPMCKFSt46aWXSE1N5fHHH2f37t0tJmUOz8UDrvgIXDxFdz5LG1QbuNrcCODb3Vktd1vriIo884uRAnFjOn88rWvTAoO2XJTYZILDS8X+wAW2O09X4uIOsWKkW5OxpcWH7csuAyA+2ItAb8coNw72cee1a5L44rYxzB0cgYIRr3oxb8wRG18UVtZzqFS8gF01ynprUnm6aVk0awBFBABQkJ3W7OM+355Jg97I0NgAhscFWu38jkQpFqXZJu9Q8FZ3tLVTtK4QYr4ZVSg7DFqFZd2loF7gZYN5zI5o8iOAAkd+bttalZauvEnXgYuKf+e1rk2t3C1LyUjnFR/izdWjxevKoq9TeHvjSb7dncVXOzPJKmkqm6uo0/F/3+/ntbXH0RtNuLtoMJngxVXHuOuzZHLKatX6ElTTriTrP//5D//4xz/w8fEhOjqaZcuWsXTpUubPn09iou27LV111VW8+OKLPProowwbNoyUlBRWrFjR2NwiMzOTvLymtpzjx4/niy++4N1332Xo0KF89913LF26lEGDBtk8VpsJ6Q1T/iH2V/0LKgvO//gOmjUwnEAvV/LK69hwtPl1yDok0/zHOGJwx9fH+qNYc7KWacMkK/+ASA5dPMR6IVLbJIh1xJSMTS0+JCWrDBCd7xzN+MQQ3rxuBP83MQQXxYjRpHCwwvEWIv5xby5GFJJi/UkM87Xqsa8eFUulWxgAH6/4ndw/vFAeyC7no61iwcpbJsR3qZbtZykSa3+ZQjo3Wu4QLOVdhYfUjaOryN4ttjEj1Y3DnsL6i26DAN8shLKs5h9nNMLW1+G4eR7N8BvtE9/59JoitifWqRuHE7l/Wm/C/dzJLa/jueWp/O27/fz9hwNMfH49s17exMIPdzL+2XV8szsbjQLPXjaY1Kdm859LB+GiUVhxKJ8pL27gP78dprS6+6zR164k65NPPuHNN99k5cqVLF26lF9++YXPP/8co9H6zRFacu+993Lq1Cnq6+vZsWMHY8Y0jYZs2LCBJUuWnPX4K664gqNHj1JfX8/BgweZO9ex5lJ0yNi7xXoPdWWw9C6xuJ+VubtouXxEDACfbT9lvQNn7RTbuJZLNtvNcqyT68QfdFvY/7XYJk4XbWCltjGvg6ZkbAZT898bS5I1zAGTLIs7hol5iafxY9F3h6jX27H9YSvq9Qa+2Cne4FwxItrqx9doFKLiREvq6uIspr20kcve/J3L39rKrUt2ccU7Wymr0TEg0o+5gx17oebOUMzrSplC2zD/xNGFDxBbOZJlHVnbxTZmlLpx2NvcFyB8MFQXwZfXQN0f5qlUFcLnl4sbwphg5C0Q7ACLkyea58ylb5SLcrdRmK8HaxZdyLOXDebCPqFc2CeU0fFBaBQ4WlDJpmNFVNXr6R3mw4c3jeKa0XEoisJ1Y3rw490TGNsziAa9kfc2pzPphfWsOtRyl++upF1zsjIzM89KUqZPn46iKOTm5hITE2P14KQWaF3gksXw/gw4uRZ+uhcWvCUWC7Si68b04L3N6Ww4VkTm6Rrigq3QXTF7l9jGdrzd+jl6XihGxSpzIWOz+L811ZScsbbHDdY9dlcXlQTufih1ZQTUZpzzaZPJxD4nSLI01WLE+LQSxLGCKl5Zc7xxDRG1fbkjk9zyOvxcTVxkoyTHJ6QHnISkgBo+Pm0gObPsrM9P7hvK69ck4ap1jCYKtqCYR7I6O+/TIYSZk6wC+3d363IMesgyv65Z8+ahM3Dzhmu+hPeminXCvr4OrvtOlIqfWAM/3ikSMBdPmPOcY4xiAUQMBa8QqCmG7J2dW6+zG/H1cOWa0XGN60oCFFXWcyCnjMKKeqIDPZnQKwSN5uxqhsEx/nx5+1g2HiviueWppOZXcu+Xe/ns1jGMtsUyQQ6kXa+Ier0eD4+zJ1S7uro6RM/6bidiMFz5CSha2P8VbF9s9VPEh3gzqU8oJpNYjK7TdHWQt1/sW7OswsUdBl4q9vd/Y73jWmx/CxqqxDXvI0sF20XrAvETAQipPPcNXXpxNaerG3DTaugfqU73zjYxL0QcFNkDgHc2niQ1v23dhWyppkHPG+tPADArxoinW+c7CjbLTyRvl/RU+PqOsbxzwwjevG44z1w6mNevSeL9hSPx9ehak73/SCkWDe67xEiWJck6fVzeye+sggOgqwZ3/6br2p0ExMJ134KbL6Rvgveni+Tqsz+JBCtsANyxAUbcZJNGXR2i0ZxRMrhW3VicXKivO1P7hXP16Dgm9g49J8GyUBSFyX3D+O3+icwcEE6D3shtH+/inY0nKa6qb/Y5XUG7kiyTycRNN93EZZdd1vivrq6OO++886yPSXbSZ6YYrgdY/wyUZVr9FAvHijeVX+/Ook7XyRKp/P1g1IF3KAT0sEJ0Zxhyldge/gkarLgyeV057HhH7E/6m+O8SDgTc8lgaOW58z+2pZ0GICkuwCotx22mUpQ2hEX2YM6gCIwmeGlVx9aoK6luoLLOOjem3t6YRnFVA7GBnowLs2KDmj/yE2WISkUuY3oGM2tgBHMHR3LtmDjmD43CpQuPYAFQV45SkQOAKdQxRjA7xT8G3P3AqBeJltRxp0S3XOLGWL2axGlEDYOrPxcjVvn7Yd+X4uOjboPb10GYA/7OJE4X25MyybInrUbhtWuSGNEjkIo6Pc8uT2Xcs2u554tkDmSXqx2e1bXrL8KNN95IWFgY/v7+jf+uv/56oqKizvqYZEcjb4EeE0BXA8v+JrrgWdGUfmFEB3hSVqPj1/15rT/hfCylgjGjrJ+sxI6FgDhoqISjVlwI8vfXoL4cQvtBv/nWO253Yi7fDK46Boaz75pvOymSrHG9Wm4J6xDMI1n4RvLXmX3RKLD6cEHjfLLW6AxGPt6awdxXNzP8qdWMeHoN/12RSkUnkq3kzFIWm0exHprRG5vmOX5RYluZa8OTODBzqWCta6D1GvaoSVHOaH4h52V1SqYlyRqrbhxq63khPJACs/8L/efD1V/AvJfA1brrbFpNT/NIVt4+MSVAshsPVy2f3zaG//5pMENjA9AZTPy2P4/L3vqdr3Zaf7BATe2ak/XRRx/ZKg6poxQFLnoZ3poAx1bA8VVWLWnTahSuGxvH8yuO8un2U43NMDrEnGQVBQwhyGhC28KwcodoNDDkatj0PGx6EQYsEKVqnVGRC9vMZZhT/91971J2VkgfTJ5BaGtLMOUfhHjRrMZkMrE9Tby4jTvPuhsOwTyShW8EiWE+XJoUw/fJ2Ty/IpXPbxtz3o56R/MrueeLZE4UVjV+rEFv5K0NJ/no93RmDYzg3imJ9A5ve1fAqno9D36VgsFo4pJhUcwdHMGyFpp7WYWvea5XRa64kdPdRnQLRalrpUc0XaZBfdgAsexFwUEYfLna0TgnkwkyzU0vutt8rOb4RsDYO8U/R+cbDsG9xUhu1k7oO1vtiLoVD1ctV42K46pRcRzKLee1tcdZeaiAv/9wgNyyWhbN7AJzX2nnSJbkoEL7wti7xP7qx8Bo3c5nV46MxU2rYV9WGfvNaxq1V73eQOUJ0b79/s1arnl3u/XrcMfdDZ5BUHQEdn/Q+eOt+w/oa8WLZ795nT9ed6UomMzrkig5uxo/fKKwiuKqetxdNAyLC1ApuDZqTLJEsvHg9N64aTVsPXmaL3e2nN0UV9Vzy5JdnCisIsjbjScuHsjuf03n/YUj6RPuQ53OyE8pucx7fQvvb07D1MaR6NfWHiezpIboAE+eWmCHJSksSZa+DmpLbX8+R1MgSl0rPLpQg6cI889N/gF143BmJWlQXQhaN4garnY0UntZRh8z27DOl2QzA6P8efv6Efx1Rh8AXlt3gg+3pHMwp/ysm5POSCZZXcXEReARIBKMlC+seugQH3fmDhYLsH66rf0NMDKKq7njjZ/xrS/AYFLYb+zJzowSLn59CxnF1dYL1DMQpv1b7K//D1Sf7vixyrMh5XOxP+Op7nfn3spM0aK1sZKzu/FjlvlYI+MDcXdx4PlYcNZIFkBskBcPzxZ32p769TDpzfwcN+iN3PXZHnLKakkI8WbNogu5cXw8IT7uTB8QzsoHJ/HzvROY3DeUBr2Rp387wrubml/s90wni6r4cItYl+rpSwfhZ4+GE64eohsXgHluUrdiXmy2zLunyoFYUcRQsbU0I5La76R5naWo4eJ3RHIuPcaLrWU0UlKNoijcN603i8yJ1pO/Huai17cw8+WN/JTivK85MsnqKjwDRWMGEE0wrNwx6oZx8QD8vC+XsprzH9toNHGisJLyGh0/78vlote34FYoXsir/BL59v6Z9Az1Jre8jrs/T+58Q40zDb9RdAGsK4dtb3T8OId/AkxiFCu2m619YgOmaNFN8qwkyzIfy9FLBQ16cbcamkZ0gFsmJDCuZzC1OgN3fLKbkj8ssPj+ljR2ZZTi6+7CewtHEuTtdtbnFUVhSEwAH900qrEd/EurjnEkr+WuhUajiSd/OYzeaGJqvzCm9A2z0hfZBpZ5WRWdnJvpbPQNjaM9pV4JKgdjReEDQdGIn23LTQSpfQ7/JLb9usD6m92RZSQrJxl0ted/rGQX901N5Kbx8QD4erhgNMFfvk7ht872BFCJTLK6ktG3izeBlblw6EerHnp4XAADIv2o1xv5dnd2i49bcTCP2a9uYvr/NjH0yVXc/+Vequr1zAwU6wz59xzJgCg/vrx9LMHebhzOq+CJX87tOtdhGi1c+Hexv+t9kWx1xKGlYmtpDS91iikqCRMKSnkWVOZTrzew+XgxAOMTQ1SOrhXVRWIhZUUL3k2xajQK/7tqKBF+HhwvrOLGD3dSXisaWeSV1/LGOtGU4vGLB5IY1vIC1oqicOeFPZneP5wGg5Fbl+zib9/u4+7P97Bg8e88s+wIWSU11DToefDrFDYeK8JVq/Dvi+zcLroxyXLeu4odUnAQDA2YPAOpcbNjUmtrbl5iTgrI0ayOqCqCU7+L/QGXqBuL1DGBCeATLroe5ySrHY2EeD18/OKBHHt6DvsencmVI2MwmuChb/c5Zat3mWR1JS7uomUqiHWzrNhpUFEUFo4Tbdc/23EKo/HcY3+8NYM7P0vmWEEVbuZWZ4oC909N5E/R5u49EUMACPfz4NWrk1AU+HJnFmuPFFgtVvrOFd0A6ytg94ftf355tligEAX6X2y9uLozd9+m+SzZu9hyvJiqej3hfu4MiwlQNbRWWToL+oSLJP4Mkf6efHbbaIK83TiQU87817ew4mAe//f9AWoaDIzsEchlw6NbPYWiKDz3p8GE+rqTW17Ht3uyWXYgn5SsMt7dlMbE59cz4NGV/LwvFxeNwguXDyUhxNsWX23LGpOsbtZh0FwqaIoc3vXKhiPF32Py96kbhzNK/UXcfIlKgsB4taOROkJRmhqWyHlZDsXNRYNGo/DsZUMYGuNPrc7Au5sz1A6r3WSS1dWMuBlcPERbUivXGV88LApfDxdOna5h84nisz63I+00T/0qOnDdPCGeXf+azoHHZ5Ly6EwWzeyLxjK52vKiDlzQO4Q7Joo5Dv9aepCqer11AtVoYMKDYn/bm6Bv590PSwlI3NjGRVilziv1ThQ7WTtZflCUJ80ZFNni4oUO4w/zsf4oMcyXz24dQ0ygJ5klNdz5WTKbjhWhUeCJSwaet/PgmUJ83FnxwERevmoof53Rh3/N68+LVwxl/Bnt7QO9XPnwplEsSGo9cbO67trG3XyH2xQ1TN04bCFisNjKkaz2s1Q7yFEs52ZJsjK2qBuH1CytRmnsNPjFzizKnWztdJlkdTXewU0L83ZmTlIzvNxcuGJELACvrDmGzmAE4EB2OXd9nozeaOLioVE8etEA/D1d8fVwxd/TVaxBUWEuMbS8qJs9OL0PcUFe5JXX8cKKVOsFO/hy8IkQ8w3SNrTvuQd/EFtZKmhVJeYky5i1k9WHxcjl7EHNJy4OpXGNrJZjHRDlx2/3TeSSYVFEB3gye2AEH9w0ioFR7VtTKdjHnUuTYrhvWm9um9iTy0fE8MXtY0l9ajb7Hp3Jjn9MZ1Kf0M58NR3n201HsnItSVYX7B4XYRnJkklWu+SmQMZmsT9ggZqRSJ1lWZQ4fTNUFaobi9SsSb1DGNkjkHq9kdXZzpW2OFe0UtuMuwdQIPVXyD9o1UPfOjEBX3cX9maW8dzyVL7dncVV726jpLqBITH+/PdPQ869c59nLkUJTDhnIU9PNy3PXiYSr892ZJJWZKV2nVrXpjuMljuObVFwCHJ2g8ZFvnhaWal3LwBMuXuprq0lxMeNUfFBKkfVBn9o394Sfy9XXr06id//PpW3bxhh1aYUHq5a/L1ccXNR8U92dywXrKtoXIjYFJmkcjA2EGnuMFia0fH5q91NXQV8e5MoFew/H4K6UDOU7igkEaJHgMkAB79XOxqpGYqiNHYdzKtVmp2u4qhkktUVhfZtGoXZ+F+rHjo6wJPnLxd3Pz/Yks7fvttPTYOBCYnBfH7bGDzdmmnFbblLekap4JkmJIYwvX8YBqOJl9cct16wliTr6G9t7rao2fux2Ok3TyxWKFlNlXsEJo8AtIZ6+iuZzBwYYd0FqW2lDSNZ3YKfuUSxOyVZeSmACfxjwacLNb2w8AoCP/NcSSvfkHNKRcfEEihHV4gyUcvi2xZGI/x8L5Smg38cXPy6erFK1jPkarHd95W6cUgtGp8Ywle3jeLeAQbHn2JwBplkdVUXPgwocORnq794zhkcye0Txd27HsFePDi9Nx/eNArfltbrsdT7RzSfZAH81Vxz+8u+XA7nttzCul3ixopmBXXlkL6p1YdrDXVoDnwj/jPyFuvEIDVRNFSFDgNguOY4N4ztoW48bdXKnKxuwzI/sb4C6ivVjcVeLB3HortgqaBFZDcvGTSZRFL17hRYPAqW3gVfXgXvTYH/9Yf3poq/ASYTrPyHmLOrcYXLPxBLp0jOb9CfRPVKXkrjyLXkeEb0CHS63kMyyeqqwvrbbDQL4B9z+7Pzn9PY8NBkHpze5/yLyTaOZA1t8SH9I/2YP1SUI72x3kqjWRqtKOcAONx6S/uY0m0oDVUQ1AviJ1knBuksG6vjAZgXmE3/SD91g2mrNpYLdnnuvuBu/p51l7WyzJ0FiR6hbhy2ZLn51R2bX1Sfhs8uE0lVbrJ4ox03HiKHiTmIilZ8/IMZsOQi2PGWeN6lb0PsaFVDl6zIOxgSZ4j9dU+JEUtJsgKZZHVlNhzNUhSFMF+P1junNVRDsTlpOs9IFsA9U8ScnVWHCiiqtNJ6CJZ5VYeWigYcLTHqSSxcJvZH3iw6FEpWlV8DX+eLRGWYYsWyUFuT5YJNuttaWZaRrK7Y9MKiu45klWfDR7Ph5DrQuomOtItS4Zbl8OeN8NcjcN8eMZe4LBNObQEUmP2caKwkdS2THhIjlEd+gQ3Pqh2N1EXId5JdmY1Hs9qk4BBgEmV7rcxx6hfhx7DYAPRGE98nt7zgcbvEXyA6GjZUnbfbonLwO3zqCzB5BcOIm6xzbuksv2RqSDH2woiCW2Wmc3RyMuigxrxcQXcfyYIz2rh3g5GsygJzV1QFumL7dgvLza+i1PYvd+GsDHr4ZAEUHxNzDf+8GWY8AT5/6NwZlAC3rISxd4vk6v5kGHuXKiFLNhYzEua/KvY3PQ/7v1U3HqlLkElWV3fmaFbuXvuf39JZsJVRLItrRosW8V/vysJkjcWUFQUu/LvY3/Fu02iWvkG80AIY9Gi3vASAcew9oixKsqqdGSUcLNVQo/FGHyS6BJG9S92g2qLKvEi2xhU8naAToq35dqORLHPrdkL7de2/Cf4xYm6RUQ+Fh9WOxj5OroXTx8XXfctKCOvX8mN9w2H2syK5Cuppvxgl+0u6DiY8IPZ/ugeyd1vnuBW5cORX0NVZ53iS05BJVlcX1h8GXyH2f10ERoN9z29JslroLPhHFw2JwttNS3pxNdvTzlPe1x795kH4YGiohOd7wtPh8HQoPJ8AK/4BS+ailKZT7+KLccSt1jmn1MhoNPHflccAuGpkNG7xY8QnsnaqGFUbndn0QpaQdq827o3zsbpwqSCIG1HdbV5W8idiO/QaCIhVNxbJsUx7DPrOBUM9fHUt1JZ1/Fg1JfD19fDyQPj6Ovj8cqi30jI1klOQ7xq6g5lPiQnrucmw6wP7nrsNTS/O5O3uwsXDxBu5H6xVMqgoMPNJcPEATKA3302qr4DtiyFrByZXL/bF3Ahu3tY5p9Too60Z7M+uwE1j4r4pvSBmlPiEM4xkyflYZ5NJVtdkWSS+O8zLqiqEYyvEftIN6sYiOR6NFi57D4J7i0qGTS907DgmE/x0r5jjZTKKaoiMzaLRijUSrdMnxXptn10OS+9umj/aVRn0KFk7xLV0IjLJ6g58I2D6Y2J/7RNQeMQ+5zXoms7VxnJBgEuTxLotyw/mU6ez0shbr6nwSDY8dAIe2Af/lwHXfC1We0+6Af2dO8gLlN2irO1YQSX/XZEKwCU9jIT4uEOM+TrnJDeVbDoq2b79bN0lyTKZzmjf3oU7C1pYboJ1h5GsfV+K0sjokRA+QO1oJEfk7iPm4AHseEckNO2V/IlYo1PjCresEmWpHv6QtQN+uL1zVUUlabBkHhz6EU6shpTPxZIDP94FtaUdP64jS9+AyyfzmHTsybPXrnNwMsnqLkbcAvETRQOIL64SrWttrSgVDA3g7g+B8W1+2sgegUQHeFJVr2fNkQLrxaN1FRObA+NFLX7f2XD993DJG01rAElW9Y8fDtCgN3Jh7xAmhJv/MIb0ET8T+loocPAFUC0jWT4yyQK6T5JVkgZ1ZaLrXNhAtaOxPctNsIKD9i8ptyejoamaY7gcxZLOo/d00dbdqIPVj7bvuVVFsOIRsT/tUYgbAzEj4LrvQOsOR5fBD3eIrsenT7avZXz1afj4YvHaFNoPLlkM/8/efYe5UZ1vA35GZXvvu/a6916wjanGFUyHQOghhQRiSAjkl4QUShohyZcCJJRQEzC9g7ExYGzce7fX9nq9vfeqOt8fRyNpdyWttCtpVJ77unzNrHY0czyrMu+857xnxjfF7w6sBp46Dyjb7lt7w8GhtwEAzQljEE6TZTHIihYaDXD9f0WA0VIKfPSjwB/TPgnxdJ/eFBqNhCttXQbf3xcFA+wjVFFNO3aXNkOnkfD7q6Y4XgIajfjCAUK/yyAzWb2lDBPLrobIrkSnZLHyZgC6GHXbEgxZ4wF9AmDqEhX3ItWxD8X3X3wGMP16tVtDoW757wFIwPGPgVofisLsfQkwdYoM8cK7HY8Xzgeu+rdYP/w28Na3gCfmAH/MB/44HHh8NnDqc/f7lWVx7dZaLubzvO1DYPYtwDXPAt/9XBRmaasA/neNmHYgUpi6RbdLABXpC1VujG8YZEWThAzg+v+J9aJPPc8b5Q/28VjedxVUXD1bXMx9VVSPpk6jP1tFQaKU4V88KQd5KXG9f6l0GQz5IEsZk8VMJwCRAdbZ/paRXMa9Koq6CgJiHIoyF1iovycHS5aBLY+L9fl3ADEJ6raHQl/OJGDKFWJ9yz+9e47FBOx6Qayfvap/waTp3wBuehOYcxtQMFtktsw9ojBX02ng1euArU+43vfe/4qATxsDXP9y72lxCueJqQgKF4gA75OfhlW3Oo9OrAWMHZBTR6A5cZzarfEJg6xokz8DyJ0GyBbH4N9AsWeyfA+yxucmY9qwFJitsv8KYFDQmC1WvLtXZCG/MXd4/w2U4hehXmGw3dZdlZksQZIcAWckdxmMpqIXiuFniaW/ylaHmtItInjWxgLz7lC7NRQuzr1XLA+95VV2SDqxBmivAhKzgalXud5owgrgiieA738F/LIK+NF+4J69IvCSrcBnvwa2/bv3cxqLgbW26WgW/8ZRrMZZbJLYr0YPnFwHHH3fu/9jqLN1FbROvSasugoCDLKi06TLxNKWfg0Iq8VRvn2QE3neNH8kAODVHWX+mTOLgubrkw1o6DAgIzEGiybm9N9A6S7YXAJ0NgS3cb5gJqs/pctgpAZZFpPjsytaMlmAU9XPCAyyTD3AJ/eL9Vk39Z90mMidYXOAMYvEjemtT3rcVJIt0GyzbTP3dkAXO/D+tTox6XXmWODyx4GLfi0eX/eAyLx2NorPpHfvEN15R1/QuwtiX9kTgfPvE+uf/nxoJehDQXczcPIzAIB16rUqN8Z3DLKi0WRbkFX8JWDsDMwxGk6IlLU+URQ6GIQrZhUgKVaHkoZObCsOQqEO8gtZlvHClhIAwBUzCxCjc/ExE5/ueF2EaulZswHotnWpZSbLISXCJySuOya678SminEP0ULJZNUdBQzt6rbF3774rSjElJgDLP612q2hcHPeT8Ry73893hScXPUmNNX7gJgk4KxBzLkpScAFPwXmf1/8vP43wF/GAn8cJrLrcanAVU8PPGfjefcBmeNECfovHvG9HaHk2EeigFrOVDHva5hhkBWNcqcBaSPFhYSnQZZDUbVPLAtmif7+g5AUq8NVs8UF3as7ImgQZ4T76kQ9vj7ZgBitBt85d7T7DZUPzFAdaK8UvdDGiqCQBKUSZ1uEjslSugoWzIquCaiT84DUQgCy4/M73MmyGN+y/V/i5yufBBKz1G0ThZ/RF4rxU+ZuUdLdBeno+xhf96n44cp/Db5isSSJ8vErHrUNtZDFxMiQRKYrddjA+9DHAZfbxpDtfgEo3Ta4toSCg2+K5fRvqNuOQYqibxCykyRg8uVi/cS6wBxDyU4UzB7SbpQug+uO1KCuvWeoraIAM1us+OMnYm60b50zEiMyPQwuVzJZjSeD0LJBcK4sGGb9wAPK3l0wQjNZ9vFYUdRVUGEflxUBxS9kGfjwHjG+BRBdrCasULdNFJ4kyZHN2vkM0NPW+/d1x6D9+McAAMvCe9yPxfKWRgss/CFw59fA/50GfnwQuL/It/2OOk9UHgREV8NwnD+rrQo4s1msTwu/roIAg6zoNeYisVRewP5mz2QNLciaUpCCOSPSYLbKeGs3C2CEsm6jBT958wBO1nUgLUGPuy8a7/kJSpDVEKpBFsdjuRTpc2UpY5KiqeiFIpLGZRWtAfb9D5C0wCV/tpXjJhqkSZeJLng9rcD/rgY66sTjbdXA6zdDMnWiPmkKrIt+5d/jJmYC6SN7VxL01opHgfTRouT7B3eHX7XBw+8CkIHCs8U5CEMMsqLViAXiy6elFGgp9+++zUag5pBYH2KQBQC3nC3eXKt3lMFiDbMPiShR1dKN657Zio8OVEGnkfC7K6chNUHv+UmZtlKsIRtkcY4sl5QgKxJLuHfUA/UiE4sR56jbFjXYq37uCL8LMmdmI/DZb8T6uT8GFvyA2WgaGo0WuOopIC4NqNwN/Gs+sPoG4Im5QFMx5JRh2D3qh4BGp3ZLHeJSgG+8IKoNHv84cMNDAkGWgYNviPUZ16nbliFgkBWtYpPFRHmAKG3rT/XHRB/iuFQxOd4QrZyej7QEPSpburHpRL0fGkj+tPtME654cjMOV7YhIzEGr3xvAS6fWTDwE7Nsma7OutDsymDPZDHI6iVZCbJqAItZ3bb425mvxTJ3mriDHG3yZ4lJibsaRaGIcLX7eaCpWJTRVrp5EQ1V4Xzge1+IG4TdzcCJT0WBr+HzYL7hTRj1KWq3sL9hc4B53xPre19Wty2+KN0q5lrVxQFTrla7NYPGICuajTpPLP3dZdB5PJYf7h7G6bX4xhwx19J/t50Z8v7If7acasDNz+1AQ4cRk/NT8MGqc3H2GC8vTmOTHRfsDacC18jB6uAcWS4l5YgsuGwRAXIkKdkklqMvULcdatHFiAtJIHBdyQOtah/w+cNi/aJfibv5RP6SNQ64axvw7U+BpY8AN6wGvrtelE4PVXNuE8uiTx3dHEPdVtvE4bNuCusbXgyyopkSZPk7k2Ufj+W/MQ23nD0SkgRsKKpHUU2ElRcOU1uLG/Cdl3bBYLbioonZeOeuhSjM8FDowpUsW5fBUCx+wTFZrmm0kTshcbQHWQAwMkA334KhtQJ47SZROXf8csfFJZE/6WKAkecA590LTLo09Lui5k4Bhp0FWM3AgdfUbs3A6ouAE2sBSMDZq9RuzZAwyIpmI84GJA3QdNq/5Zir/FNZ0NmorERcMk1kFJ7ddNpv+6XBOV7Thh/8dw8MZisWT8rB07fORULMIPqi24tfhGAZd47Jci8lAoOs1krRxUzSiAuoaDXqXLEs3RI+47KsVuDr/wc8OR9orxKfK9c+N+jpQ4gizpxbxXLv/8T7JZRt/7dYTrrUcSM2TDHIimZxqUDedLGu3MEdKlO3mMwT8Ht1rh9cICYG/WB/Japbu/26b/JebVsPvv3iLrQbzFgwOgNP3TIHsbpBXsyEcoVBZrLci8QKg8p4rILZ4rMxWg2bK8ZBdNaH5vvSlR1PiQmHbeNjcNOb0f03JOpr2rVATLLoNXL4bbVb456hAzhka9+CO9Vtix8wyIp245aKZdEa/+yv5rBISSdmO+bT8ZOZhWk4e0wGzFYZb+zyc0VE8orRbMUPX92L6tYejM1OxLO3njX4AAsI3QqDxi5RqhdgJssV+1xZETStgnKjadT56rZDbbpYR5VBJfAMZR31wFePifWlD4vxMRkeJkEnikaxycD5tiIwX/xW3BAPRUffB4wdomiaMqQljDHIinYTLxXLU18AZsPQ9+c8HisA/ZSvmCku7rYVN/p93zSwP645hj2lzUiO0+H5b80buEz7QJRMVtPp0KpU12HrKqhPAGI5cL6fVFGIBq0REmTJMsdjOQvUeN1A2PB7wNAqquWe86PQHx9DpJazfyhukLWWA5v/rnZrXNv7P7GcfUtEvJcZZEW7gtlAUh5gbPfPXUtlPFaAJvJcMCYDALCvvAU9JktAjkGu7TrThJe2ngEA/P36WRiVlTj0naYMA3TxgNUEtJYNfX/+4jweKwI+6P0utVAs/T3HnlqaS8SFh0YvxqpGu5G2cVlnQnxcVt0xYO9/xfrFf+IYLCJP9PHAkofE+sbHgE/uF3PKhYqGk0D5djEuduZNarfGLxhkRTuNBph4sVg/7ocug5X+L3rhbExWIrKTY2E0W7G/vCUgxyDXnvqqGADwzbMKsXTKIGafd0WjcXTtaQqhgiYcj+VZmi3Iao2QIEvJYg2fB8T44eZBuBt+FqCNERndUHpf9vXF7wDZCky+PLqLlRB5a8b1wKIHAEjArueANT9Vu0UO+2xZrHHLHMWVwlzYBFlNTU24+eabkZKSgrS0NHz3u99FR0eHx+csWrQIkiT1+nfnneE/kM7vlC6DRZ8O7a6lod1RJS5AQZYkSVgwWmSzdpxuCsgxqL+imnZ8ebwOkgTcuWisf3euTFjdVOLf/Q6FkslK8lMwGWlSR4hlRy1g6lG3Lf5QYsvis6ugoI8XJZ+B0B2XVb4LKPpE3PVe/Bu1W0MUHiQJWPQL4JuviJ/3viyGi6jNYgL228rLK5UQI0DYBFk333wzjhw5gvXr1+Pjjz/Gpk2b8P3vf3/A591xxx2orq62//vzn/8chNaGmdEXAPpEUfq2ev/g91N9AIAMpAwXE5YGiDLZ7fbTHJcVLM9sElmsS6blYbQ/ugk6UzJZjcX+3e9Q2LsLRsbdNL9LyBDdPAGgrVLdtgwVx2O5Zp+sPgTHZcky8MUjYn3WTaE9ESxRKJp8GTD/B2L9wx8BPW3qtufkZ2Jy+8RsYMLF6rbFj8IiyDp27BjWrl2L5557DgsWLMB5552HJ554Aq+//jqqqjyXEE5ISEBeXp79X0oKB7H3o48Dxi0W60PpMli2TSyHnzX0Nnlwtm1c1t6yZhjMHJcVaFUt3fhwv3ifKWX0/cqeyQqhbkmcI8szSYqcLoP1ReLLXRcX8M+usBLK82UVfykybNoY4MJfqN0aovC09CEgfZSoErvrOXXbohS8mHkDoB1iQa0QMojZQ4Nv27ZtSEtLw1lnOb4Aly5dCo1Ggx07duDqq692+9xXX30Vr7zyCvLy8nD55ZfjN7/5DRISEtxubzAYYDA4quy1tYno3mQywWQy+eF/4zvluIE8vjTuYuiOfQT5+Ccwn/+zQe1DW7IZGgCWwoWwBrCtI9JikZkYg8ZOI/aeacRZI9OHvM9gnONw9Z9NxTBbZZw9Oh1T8hIHdY48nV8pdSR0AOTGUzCHyPnXtlVBA8CckA05RNrkiRqvX23KcGgaTsDceAZyYeifI3c0pzZAC8BauAAWWQO4OYdR9xmRNxs6jR5SWyVM9afExVgAeX1+ZSt0nz8MCYBl7ndgTcxz+zcjh6h7/QZZWJ5fKQbSeT+F7qO7Ie/8D8zz7lQnwOmohe7kZ5AAmKbfFBafwd62ISyCrJqaGuTk9O5+ptPpkJGRgZqaGrfPu+mmmzBy5EgUFBTg4MGD+PnPf46ioiK8++67bp/z6KOP4pFHHun3+GeffeYxOAuG9evXB2zfejNwMTTQ1B3BV++9jK7YbJ+eL8lmrCzdBg2AjaVmtNf5ad4tN4bFatDYqcHr67ejrsB/d1kDeY5DmSwDX9dIaDVJGJciY3yKDJ0G6DIDr+7RApAwK7YBa9YM7e/q6vzGGxuwHIDcfAaffvIRZEn9CmGLa4qRDGD7kVI0lgX2texPwXz9zmyVMQrAqT0bUFQ19Bsdapl3+m0UADjek4OTXry+o+kz4rz4UcjsPImjHz2JM1mLg3LMgc5vQfNOzKs5CLMmDuu7psE4xM+kaBNNr181hNv51VjjsUyXgrj2Kux//XeoSg9+ddWRDRswS7agOWEMNu08CcDzvJmhcI67urq82k7VIOsXv/gFHnvsMY/bHDt2bND7dx6zNX36dOTn52PJkiUoLi7G2LGuuz098MADuO++++w/t7W1obCwEMuXL1etq6HJZML69euxbNky6PUBvMvQ+ipQthWLhxtgnbfSp6dKlXuh22+AHJ+O86+5QwxGDqDypBIcXH8SPYn5WLly1pD3F7RzHKKe3FCMd86IMVGfVwIFqXG468Ix2H66CUZrDSblJuG+mxZCGmQ5c4/nV7ZCPv4ANBYDLjl3BpA2cqj/naGRZeiO3AUAWLD0CiBzvLrt8YIar1/NliLgqw2YkBOHsSt9+7wIGbIVur/9CAAwYcX3MH7YXLebRuNnhCb9FPDlbzFdW4wpK/8a0GN5dX5lGbrn/wIAkM69B0svuCGgbYok0fj6DaZwPr+alCLg6z9jrmknZq38bdCPr31DFOFImXcDVp7n/rsklM6x0sttIKoGWffffz9uv/12j9uMGTMGeXl5qKur6/W42WxGU1MT8vK8HzOxYMECAMCpU6fcBlmxsbGIjY3t97her1f9jxrwNky6FCjbCu2JT6E9Z5Vvz63cDgCQRpwDfUz/8+dvZ43OBHASByvb/HpOQuHvHGxv76nAP78UAdbiSTk4WNGKqtYe/ObDo/ZtVi0ej5iYmCEfy+35zRgN1B+HvrUUyB435OMMSXcLYOwEAOgzRgFh9HoI6us3XQTDmrZKaMLoHPVSfQDoaQFikqErPAvQDvyVGFWfEdOvBb78LTSlW6AxNAe0oJHC4/k9swWoPQTo4qFd+ENoo+Xv4EdR9fpVQVie3wV3AFv/AU3lbmiqdgMjFwbv2MYu4IwoPKSdfKlX7+lQOMfeHl/VICs7OxvZ2QN3S1u4cCFaWlqwZ88ezJ0r7jR++eWXsFqt9sDJG/v37wcA5OezYphLky8DPvsVcGazKKetVH3zRulWsQzSXCXTh6VCIwHVrT2oae1BXmpcUI4baXpMFvz+ExFMrbpoLP5vxSR0Gy34z9en8cnBaozPTcIVMwuwfGqAC0BkjAHqj9uKXywJ7LEG0lohlgmZQIy6XYRDWiQUvlCqCo4616sAK+qkjwIK5ohJ5o99CMz7nrrt2fGUWM78pqhwSURDl5QjCk7s/S+w+W/AyLeCd+zTXwHmHjEtSO7U4B03SMKiuuDkyZNx8cUX44477sDOnTuxZcsW3H333bjhhhtQUFAAAKisrMSkSZOwc+dOAEBxcTF+97vfYc+ePThz5gw+/PBD3HbbbbjgggswY8YMNf87oSt9FDB2CQAZ2P2C98+zmB2VBYMUZCXG6jAhNxkAsL+8OSjHjEQfHahCS5cJw9Licd8yUQY5PkaLHy0Zj3U/uQBP3jQn8AEWEFoVBpUgK3W4uu0IdalKkFUJWK3qtmWwOD/WwKbaCksdeV/VZqC5FDj+iVhfcJe6bSGKNOfeK4Z5nPwMqDkUvOOe+FQsJ14sqtZGmLAIsgBRJXDSpElYsmQJVq5cifPOOw/PPvus/fcmkwlFRUX2wWgxMTH4/PPPsXz5ckyaNAn3338/rr32Wnz00Udq/RfCg3Knct//AFO3d8859iHQ0yru/OcFL4CdPUIMtt9X3hK0Y0aa/20vBQDcfPYIaDUqfsApWdOQCLJsmRkliCDXkvMBSQtYTUCH+wJEIctidmTgR52vbltC2dSrxPLMZqClTL12bPknIFuBMRcBOZPUawdRJMocC0y5Sqxv/ntwjmkxA0VrxfrES4JzzCALm/4RGRkZWL16tdvfjxo1CrLTXB6FhYXYuHFjMJoWWSasEGnb1jLg8DvA7Fs8by/LwNYnxPq8O4La5WZ2YRpe21mGfWUtQTtmJNlf3oKDFa2I0WrwzbNUDigybeOwGk6o2w6AmSxvaXVAyjDxWdFSDqQUqN0i39QeAoztQFwakDtN7daErrQRItNXsgnY9m/gkj8Fvw2tFaIrEwBc8NPgH58oGpz3E+DIu8CR94CLfiUCr0BSJiBOyAJGnhfYY6kkbDJZFCQaLTDvO2L9yz+IIgCelG4V/fV1ccD8OwLePGezR6QBAA5VtMJsCdPuSipavUNksS6dkY/MpMAXK/Eox9YXu6kEMLSr2xYGWd5TzlE4jssq3yWWw+cBGn4VenTuvWK592Wgqyn4x//6/4mM6ajzgVGReTFGpLr8GcC4ZSJjvPXxwB9v78tiOetGQDf0wlqhiN8s1N/874sxMu1VwNoHPG+rpJVn3QQkZgW+bU7GZichXq9Ft8mCM43ezVlAQrfRgjWHRBevG+aFQLe4pGwgKQ+ADNQeHXDzgLIHWSFwXkJdOBe/qHAKssizsYtFV3BTF7Dz2YG396eWcmDv/8T6ol8E99hE0eZ82xRG+1cDbdWBO05rpchkAcCcbwXuOCpjkEX9xSQCVz0tBkEeWA289W2gcm//7U5/BZxaL8ZlLLw76M3UaCRMzBPFL47XeDdnAQnrjtSgw2BGYUY85o0KkSpdedPFsjaIg25dYZDlPeUctYRzkHWWuu0IB5IEnHevWN/xjH2Kg6DY/DdmsYiCZeQ5QOHZgMUIbP9X4I6z/1WRMRt5LpAV+nNRDhaDLHJtxALgQttdwyPvAv9ZDGz8i6OKmNUCrPuVWJ/3vcD33XVjcr4Iso5VM8jyxTt7RSBxzezh0KhZ8MJZnm1cTM1h9dpgMYsMLsDugt4I10xWRz3QXCLWPUxATE4mXykq0HY3AfteCc4xW8qcslgD9KogIv9Qslm7XghM92CrxTHGMoKzWACDLPJk0c+BH2wCpl4DQAY2/B7431VA8ZfAxz8Bag+LQeMqduGYnJ8CADherfI4njBS3dqNzacaAADXzgmhQELJZAWzfGxf7dXi7po2BkgceA6/qBeumazK3WKZNRGIT1O1KWFDqwPOuUesb30SsJgCf8yvbVms0ReIucyIKPDGLxfFgEydwM7/+H//xRvEjbm4VGDKFf7ffwhhkEWe5c8ErnsRuPJfgDYWKNkI/O9qx4DFJb9RdVLISXm2IKuGQZa33ttXCVkG5o1Kx4jMEJpsN9cWZNUdFXe61KBkZFKGsRiCN1KdMllO1V1DntJVsJDjsXwy62Zx86G1DDj8bmCP1dMmxoUAjl4VRBR4kiQqDQLAjqf93z1470tiOeMGQB/v332HGF5FkHdm3wL8cJv4ktXFAQVzgJvfccyrpRJlTFZlSzdau4JwZzXMybKMd/aIroIhlcUCRJdTXbwYXK/WfFmsLOgb5TwZO4CeFlWb4pNyMWk9i174SB8PLLhTrG97IrCB9fGPAYsByJoQtEnuichmylWO7sG7nvfffjvqgCLbBMRzI7urIMAgi3yRORa46t/Ar2uB728Axi9Vu0VIjddjWJq4E8LiFwM7UNGK4vpOxOo0WDkjX+3m9KbRArlTxLpaXQY5EbFvYhLEHCdA+HQZlGWg+oBY53gs3531HXEzpOaQmKA4UA69LZbTviHurBNR8Gh1wAX/J9Y3/11klv1hxzOA1QwMOwvIneqffYYwBlkU9pTiF+wyODAli7Viah5S4vQqt8YFZVJY1YIsZrJ8Fm7FL1rKAEMboNGLMVnkm4QMMa8NAGx/KjDH6KgX1WsBYPo3AnMMIvJsxg0ik9zdBGzzQ6XBxmLH/Fvn/mjo+wsDDLIo7Cnjslhh0DOj2YqPDorKedfODdEgomCWWFa5mDIgGJRsDIMs74Vb8YvaI2KZPTFiJ8AMuAV3iWXRGnHh5G9H3wdkC1AwW7XKtURRT6sDLrJVkd725NAy17IMrPk/URp+7BJgcmQXvFAwyKKwN4ll3L2y5VQDWrpMyE6OxXnjgjtxtNeUMTIVe9QpftF4SiwzxgT/2OEqNcwyWbW2KQKUrCn5LnsCMG4ZABnY+Gf/7luWgX22su3TmMUiUtXkK4AR54hxty9fMfiM1tEPgOIvRAG1lX+Jmi7ADLIo7E0rSAUAHKtuh9FsVbk1oWvt4RoAwMVT86ANlbmx+sqeDOgTAWM70HAiuMc2G4GWUrEewZMj+l24dRe0B1mRPx4goC76pVgefB2o2u+33UplW8WYOV08MPNGv+2XiAZBowFueUd0HZQtwLpfAp8/7FvRG0M7sNY2z915P4mq7DSDLAp7IzMTkBKng9FixYlajstyxWyx4rOjtiBrWp7KrfFAqwOGzRHrSpntYGkuEXNkxSQDSbnBPXY4C9fuggyyhmbYHGD69WJ93S8BU7dfdqvZ+bRYmXkDkJjpl30S0RDEJABXPw0sfUT8vPnvwPrfeP/8jY8B7VWiWuF59waihSGLQRaFPUmSMGN4GgDgYEWruo0JUTvPNKG5y4T0BD0WjFZvXjOvKBXfKnYH97gNJ8Uyc2zUdGXwi3DKZBk7HWOIlMmvafCW/EZ0/yndAvx9mug62NU06N0l9lRDOrFW/HD2D/3USCIaMkkSAdJl/xA/b30C2PPywM87s9nRxXDlXyN+Xqy+GGRRRJgxXHQZPFjRom5DQtQ6W1fBZVNyodOG+NvePi4ryEFWoy3IYldB3yiZrM56v2UzAqbuOABZTKiblKN2a8Jf2gjgGy+IZVcDsOEPIth64xZg65NAZ6P3+2qvwYLT/4AEGRi/XIz7IqLQcta3HcUwPrlPBFumHtfbdtQBb39H9BCZeRMwflnw2hkidGo3gMgfHEEWM1l9dRnN+PhgNYAQ7yqoGH6WWNYdFX25Y5ODc9wGW9GLTAZZPolPB2JTRFn05lIgZ5LaLXKP47H8b/JlwISLRUXAzf8Aag8Bxz4S/zY+Boy5ECjZBEhasW3udCAxCxi3FIgTlWFRsRu6d+9AsqEacsowSJf4uZgGEfnPBf8nxkwfegv47NfApr+IYlHD5op59HKnAvVFIsDqqBVjrS/9q9qtVgWDLIoI023dBYtq29FjsiBOr1W3QSHk5a2laOw0YkRGAs4fn612cwaWnAekjgBay4DKPcCYRcE5rlJZMGtccI4XKSRJ9LWvOSjGtYVFkMXKgn6l1Yn5rKZdK8ZSlm4VkwkrAZdi738d63FpYvu2KuDEWkiQ0RWTBf2tH0KfMTro/wUi8pIkAVc/A4y+EPjqUaCtEqjaJ/7tek7cdDN1A1aTmKz++peBmES1W60KBlkUEQpS45CVFIOGDiOOVrdhzoh0tZsUEtp6THh6oxiDcu/S8dCHeldBReF8EWSVbgtikKWMyWKQ5bOM0SLIajqtdks8sxe9YJAVEJIk3ruF84FzfgQcfU/c0R59objgKvpUBFW1h8VrZffz9qdaZ9yAjdbzsDRtpIr/ASLyikYLzLkVmPFNoKEIaCoBDr8NHPtY9GoAgDEXiYIZyWHQgyZAGGRRRJAkCdOHpWJDUT0OlrcwyLJ5btNptHabMC4nCVfOGqZ2c7w36jzxgT2UyQ990dUEdNnGjzDI8l26LfPQVKJuOzyRZUcmK49BVsBpNCJT5Uy5YWK1iO6FFbvFmL6RC2HJngbjmjXBbiURDYUuRhQRypsOTLkC6GkTY7FkC5A1IeqLSDHIoogxszANG4rqsau0Gbefy+4m5U1deGaTyCzcv2xC6M6N5cqo88WyYqfodhDoikRKV8GUYVHbrWFIlO5dzSEcZLVVAj2tgEYnvvxJPRqtCMCcgzCTSb32EJF/xKU4xloSqwtS5DhnbBYAYFtxI6xWHybKi1B/+OQYDGYrFo7JDI+CF84yxwLJ+YDFGJz5spQgi1mswckYI5ahnMmqsWWxsiYAulh120JERBGPQRZFjFmFaUiI0aKp04hjNW1qN0dVm082YO2RGmg1Eh6+YiqkcEvZS5LoMggAJV8H/ngNLN8+JEp3wZYy0RUsFLHoBRERBRGDLIoYMTqNfaLdLacaVG6NekwWKx7+SAzwv/XskZiYF6QS6P6mdBkMxrgsFr0YmpQCQBsjihu0VqjdGtfsRS9Yvp2IiAKPQRZFlHPHiS6Dm0/5MAlmhHl56xmcqutAZmIMfrIsjMeeKJmsil2AsSuwx+IcWUOj0QJKVbhQHZfFTBYREQURgyyKKMo8UDtLGmEwh2i3pQDacqoB//hcZGV+dvFEpMbrVW7REGSMEeOyrCYx/0agWC2O0uOcI2vwMkK4wqCp2zHujpUFiYgoCBhkUUSZkJuErKRY9Jis2FvaonZzguqZjcW45fkd6DCYMX9UBq6bW6h2k4ZGmXMHAMp3BO44reWAxQBoY0U5aRocexn3EJwrq/44IFuBhEwgKVft1hARURRgkEURRZIknDcuE0B0jcsqb+rCY2uPQ5aBG+ePwH+/Ox+acCrZ7s5wW5AVyAqD9q6CY0W3NxqcUC7jrlQWzJ0a9fO2EBFRcDDIoojjGJcVPUHWK9tLYZWBc8dl4tFrpiNOHyHBQuECsSzfISaTDQR70Yuxgdl/tAjlCYkr94hl/ixVm0FERNGDQRZFHCXIOljRgtbuyJ/gsstoxms7ywAA3z4nwiZhzp8huvF1NQauG5pSvp1FL4ZGKX/feAqwWtVtS18Vu8Vy+Dx120FERFGDQRZFnIK0eIzJToRVBrafjvwqg+/vq0JbjxkjMhJw0aQctZvjX7pYoGCWWC/fGZhjNHKOLL9IGynKuJt7xDi3UGHoAOps5dsZZBERUZAwyKKIdJ4tmxUN47Le2yfmJbpt4UhoI2EcVl+BLn7RWCyWzGQNjVYHZNi6XDacULctzqr3i6IXKcOAlHy1W0NERFGCQRZFpGgZl9VlNGNfWQsAYPmUPHUbEyj2cVkByGQZO4G2SrHOMVlDl22bly2UgiylaMrws9RtBxERRRUGWRSRzh6TCa1Gwun6Tpxp6FS7OQGz60wzzFYZw9PjMSIzQe3mBIZSYbDuKNDT5t99K3MnJWQCCRn+3Xc0yrIFWfVF6rbDmTIeaxiDLCIiCh4GWRSRUuP19i6D7+6tULk1gbPVlqk7Z2ymyi0JoORcMd4HMlC527/7VoIBJTigocmaKJZKMRG1yTKLXhARkSoYZFHEunbucADAO3srYbUGqPy3yrYWi8Ie54zNUrklARaoLoN1R8UyZ4p/9xutlOIhDSGSyWqrBDpqAEkL5M9UuzVERBRFGGRRxFo+JRfJsTpUtnRj55kmtZvjd61dJhyuagUALIzkTBbgVPzCz0FWrS3IymWQ5RdKkNXVCHSGQGXPsu1imTcdiInQ7rRERBSSGGRRxIrTa3HpDFFN7J09kddlcHtJI2QZGJudiNyUOLWbE1hKkFWxy79zMNUdE0tmsvwjJhFILRTroVD8QgmyRixUtx1ERBR1GGRRRPuGrcvgRwer0NhhULk1/vXJwWoAwPnjs1VuSRDkTAX0iYChDag/7p999rQBrWW2/U/2zz7JMb4tFIKsciXIOlvddhARUdRhkEURbe7IdMwYnooekxUvbz2jdnP8pqnTiLWHawAA184ZrnJrgkCrA4bNEesVfuoyqARryQVAfLp/9kmhE2T1tAK1tkmIGWQREVGQMciiiCZJEu68UMx/9PK2UnQazCq3yD/e2VMBo8WK6cNSMX14qtrNCQ6l+EXpNv/sz170glksv8q2VRj0V8ZxsCp2iUmI00cByRE6hxwREYWssAmy/vCHP+Ccc85BQkIC0tLSvHqOLMt48MEHkZ+fj/j4eCxduhQnT4ZIaWEKmhVT8zA6KxGt3Sa8vqtc7eYMmSzLeG2n6OZ24/wRKrcmiMYsEstT6wGrZej7s4/HYpDlV8r4NuX8qqVsh1hyPBYREakgbIIso9GI6667DnfddZfXz/nzn/+Mxx9/HE8//TR27NiBxMRErFixAj09PQFsKYUarUbCd84dBQB4b194F8Awmq341fuHcbqhE4kxWlwxq0DtJgXPiLOBuFRRua5i19D3x/LtgaFkstoqRZc9tZTZMp5KBpSIiCiIwibIeuSRR/CTn/wE06dP92p7WZbxj3/8A7/+9a9x5ZVXYsaMGfjvf/+LqqoqvP/++4FtLIWcldPzoZGAw5VtKGvsUrs5PpFlGf/vsyIsfPQLzPvD51i9owySBDywcjKSYnVqNy94tHpg/HKxXvTp0PfH8u2BEZ8mxrkBQJ1KXQZNPY5JiJnJIiIiFUTsFVpJSQlqamqwdOlS+2OpqalYsGABtm3bhhtuuMHl8wwGAwwGRxW6trY2AIDJZILJZApso91QjqvW8SNBSqwG80elY3tJMz45WInvnTeq1+9D9RxbrTIe/OgY3tjtyMAlxmrxt+tmYPHE7JBrrzv+Or/S2GXQHXoLctEamBf9evA76qyHvqsBMiSY08YAYXIe3Qm11682exI07VUw1xyGnD8n6MeXijdCZ+6GnJzvt79vqJ3jSMPzG1g8v4HF8xt4oXSOvW1DxAZZNTWi8lpubm6vx3Nzc+2/c+XRRx/FI4880u/xzz77DAkJ6k5muX79elWPH+6GQwKgxRtbilDQdtTlNqF2jt8/o8GGag0kyLh2tBVjkmVkxprRU7wLa4rVbp3vhnp+dWYLLoEWmoYT2Pjei+iMzR34SS5ktR/FuQA6Y3PwxfqvhtSmUBIqr9+p7TEYB6B011ocrs4K+vGnVbyCsQBKYybgwKd+yHo6CZVzHKl4fgOL5zeweH4DLxTOcVeXdz2iVA2yfvGLX+Cxxx7zuM2xY8cwadKkILUIeOCBB3DffffZf25ra0NhYSGWL1+OlJSUoLXDmclkwvr167Fs2TLo9XpV2hAJ5rb14O2/bMKZDgmzz12M/FTHBL6heI4/OVSDDdsOAgD+cu10XBnG46/8en7bVwNnvsZFw02wzls5qF1odlUAp4CEUWdh5crB7SOUhNrrV9rfDHyyFqMTezBChfOre0rcKBu+6NsYNsk/xw+1cxxpeH4Di+c3sHh+Ay+UzrHSy20gqgZZ999/P26//XaP24wZM2ZQ+87LEyV7a2trkZ+fb3+8trYWs2bNcvu82NhYxMbG9ntcr9er/kcNhTaEs+GZeswdmY49pc346mQjbls4qt82oXKOd59pwi/fF3P83HnhWHxj3kiVW+Qffjm/YxcDZ76GtmwrtOesGtw+GsRYIU3uVGhC4O/tL6Hy+kW+GDurqT8e/PPbdBpoKgY0OujGLwH8fPyQOccRiuc3sHh+A4vnN/BC4Rx7e3xVg6zs7GxkZ2cHZN+jR49GXl4evvjiC3tQ1dbWhh07dvhUoZAiy9LJuSLIKqp3GWSFgk8OVuMnb+6H0WzFueMy8dPlE9RuUmgZfYFYnvlalHLXaH3fB8u3B5ZSYbCzDuhsBBIzg3fsk5+LZeHZQJw6vQ+IiIjCprpgWVkZ9u/fj7KyMlgsFuzfvx/79+9HR0eHfZtJkybhvffeAyAmob333nvx+9//Hh9++CEOHTqE2267DQUFBbjqqqtU+l+Q2hZNFEH91uIG9Jj8MNfSENW29aC9xwRZlnG6vgP3vLYPq1bvhdFsxdLJOXjutnnQacPmbRoc+bOAmGRRHrzmkO/Pl2VHkJU71a9NI5vYJCDNNodbfZDnyypaI5bjl3rejoiIKIDCpvDFgw8+iJdfftn+8+zZswEAGzZswKJFiwAARUVFaG11zMvys5/9DJ2dnfj+97+PlpYWnHfeeVi7di3i4uJA0WlSXjLyUuJQ09aDHSVNuHBCYDKpA6lr68FvPz6Kjw9WAwDSE/Ro7hLVajQScMf5Y/B/KyYywHJFqwNGngOcXCeyWQWzfHt+Sxlg7AC0MUDG4LojkxdypohzXXsUGHVecI7ZWgGc/kqsT7kyOMckIiJyIWyu4F566SXIstzvnxJgAWI+IecxXpIk4be//S1qamrQ09ODzz//HBMmsOtVNJMkyR5YfVVUp0obmjuNuPSJzfYACwCau0zQayWcNy4LH959Hh5YOZkBlidKl8GSTb4/V8liZU0Qc29RYChZwtrDwTvmgdcAyMDIcxlAExGRqsImk0XkLxdNysYbu8vxVVE9Hro8+Md/Y3c56tsNKMyIx1M3z0V+ahwqmrsxITcZ8TGDGF8UjUafL5al2wCLWWS3vFVnK9/P8ViBZQ+yjgTneLIM7F8t1mfdHJxjEhERucFb5RR1zh2XBZ1GQklDJ07Xdwz8BD+yWGW8sr0UAHDPReMxbVgqMpNiMbMwjQGWL3KnA3FpgLEdqN7v23PtRS+m+LtV5CxXVBhE3VFRoCTQyraJyoIxSewqSEREqmOQRVEnOU6PhWNFtbNPD7ufmDoQviqqQ0VzN1Lj9bh8ZvjOe6U6jcYxzqdko2/PrT4glgyyAitjDKCLA0xdQPOZwB/v4JtiOeUqUXiDiIhIRQyyKCpdNkPMnfbRgaqgHve/20QW65vzCpm5Gir7uKyvvX9OVxPQUCTWC+f7v03koNUB2baJ5AM9LstiBo59JNanXxvYYxEREXmBQRZFpRVT86DTSDhe045TdcHpMljd2o1NJ+sBADcvGBGUY0a0UbZxWWXbAbPBu+eUbRPL7ElAQkZg2kUOedPEsibAQVbpZqCrAYjPAEZdENhjEREReYFBFkWltIQYnDc+C4CY/DcYPthfBVkG5o/KwMjMxKAcM6LlTAYSsgBzN1C5x7vnKEHWiLMD1y5yyLUFWYEufnFEzI+IKVf4VgSFiIgoQBhkUdS6dLqty+DBKsiyHNBjybKMd/dWAACunjMsoMeKGpLkqDLobZfBUiXIWhiYNlFv9iBrEJNGe8tiBo5+KNanXh244xAREfmAQRZFreVT8xCn1+BUXQf2lLUE9FhHqtpworYDMToNVtqCO/IDpcugN/NlGbsclQgZZAWHUsa9pQzoafW87WCVbgG6m4CETGBkkCY9JiIiGgCDLIpaqfF6XDlTZJVe2VEe0GO9t68SALBsci5S4zkBrt+MvlAsK3YCpm7P21buBqxmILkASOOYuKBIyABSC8W6UtXR3059Lpbjl7OrIBERhQwGWRTVbl04EgDw2dFatBkDcwxZlrHWVir+ilks2+5XmWOB5HzAYgTKd3rettRpPJYkBb5tJAybK5YVuwOz/+IvxXLsksDsn4iIaBAYZFFUmzYsFbNHpMFkkbG1NjAX3seq21HZ0o04vQYXjM8OyDGiliQ5lXIfoMugkvFQxnFRcChBlrfFSXzRXmMrDy8BYy/y//6JiIgGiUEWRb3bzxkFAPiiSoOqlgG6nA3C+qO1AIDzxmVzbqxAUMZlnfFQ/KKzEajYJdbHLw98m8hh+FliGYggS8liFcwCErP8v38iIqJBYpBFUe/yGQWYOyINRquEhz8+5vdKg+uPia6Cy6fk+nW/ZKNkpir3AAY3c56d+hyALKrdpQ4PWtMIQP5MQNIC7dVAm58n/z71hViyqyAREYUYBlkU9TQaCb+7cgq0kowNRQ34zJZ58oeqlm4crmyDJAGLJ+f4bb/kJH2UKGRhNYuJiV05+ZlYMosVfDGJQM4Use7PcVlWiyOTNY5BFhERhRYGWUQAxuck4aJ8kcF6/IuTfstmfXRA3LmfOyIdWUmxftknuTBKGZe1sf/vLGbHeKwJK4LXJnIYHoBxWdX7Ren2mGRg+Dz/7ZeIiMgPGGQR2SwusCIhRosjVW3YeKJ+yPvrMVnw3OYSAMB1Z7GLWkCNWSSWxz8G+gbIZ74GelqAuDRg2FlBbhgBcJx3fwZZp2xZrDEXAlpOi0BERKGFQRaRTaIeuMEWDP17Q/GQ9/fWngrUtxtQkBqHq2czyAqoSSuBmCSg6TRQts3xuNUKfP6wWJ92LedRUouSaarYDZh6/LPPYmU81mL/7I+IiMiPGGQROfnOuSMRo9Vg55kmvLl78BMUmyxWPLNRBGrfv2AMYnR8qwVUTCIw7Rqxvu8Vx+MHXhPdymJTgEW/UKVpBCB7IpCUB5i7gfIdQ99fT6tjXjSOxyIiohDEKz8iJ7kpcfjhRWMBAL967xC2n24c1H5e3noGFc3dyEqKwQ3zR/izieTOrFvE8sh7gKEdqDsGfP6QeOyC/wOSWHhENZLk6NJ5esPQ91eyCZAtQMZYUfiEiIgoxDDIIurjR4vH49IZ+TBZZNz5yh6UN3X59Py6th784/OTAID/WzERcXrOjRUUhfOBzPGAqQt4fjnw/Aqgsx7ImQosuFPt1pEyWXCxH4IspXQ7s1hERBSiGGQR9aHRSPh/183EzOGpaOky4a5X96DHZPH6+X9ccwwdBjNmFabhurmFAWwp9SJJwNKHAW0MUHcUMLQCIxYCt38M6GLUbh0pmazqA0BX0+D3Y7UCJ9aK9XFLh9wsIiKiQGCQReRCnF6Lf98yF+kJehyubMMPX92Lpk7jgM8rru/A+/urIEnAb6+cCo1GCkJryW7yZcD9RcBlfweWPAjc+j6QkKF2qwgAkvNs82XJwOmvBr+fip1iYuPYFEfgRkREFGIYZBG5MSwtHk/cOAd6rYQvj9dhxT824ZOD1R7n0Hrua1GyfcmkXMwYnhakllIvCRnAWd8Bzr8f0Mep3RpyNsbWZVCZHHowjrwvlhMvAXSce46IiEITgywiD84bn4X3fnguxmYnor7dgFWr9+KGZ7ejocPQb9uGDgPe3VsBQFQUJKI+plwhlkfeG1yXQasVOPahbV9X+q9dREREfsYgi2gA04al4pMfnY8fLxmPOL0GO0qacNvzO9Habeq13UtbzsBgtmLm8FTMG5WuUmuJQljhAiB3OmDu6V1q3xOr1bFeuRtoqwRikoGxLHpBREShi0EWkRfi9Fr8ZNkEfPKj85GVFIuj1W244snNePjDI9hxuhFbTjXgKdu8WD+4cCwkiWOxiPqRJGD+HWJ99/OAdYCCMl1NwFPnAI/PAQ68Dnx4j3h84sXsCkpERCGNQRaRD8ZmJ+F/352PtAQ9Shu78NLWM/jms9vxrRd2wmKVcc2cYbhkWp7azSQKXdOvA+JSgeYzotugO7IMfPRjoP4Y0FQMvPcDoP44kJwPXPjzoDWXiIhoMBhkEflocn4Kvrx/ER6/cTauP2s4YrQamK0yZg5PxR+vns4sFpEnMQnA/B+I9Y9/AjQWu95u/6ti/JVGB8y8CYAEFMwG7vgSyBoftOYSERENhk7tBhCFo4zEGFwxswBXzCzA/csnYmNRPVZMzePEw0TeuPBnwJmvgbJtwBu3At9ZC8SlOH7fUQes/aVYv+hXwPn3ASv+AMSlARreGyQiotDHbyuiIcpNicP18wqRmqBXuylE4UGrB657CUjMAeqOAKu/CRi7HL///GExmXT+TODcH4vHEjIYYBERUdjgNxYREQVfch5w85tiUuGyrcALK4A9LwFf/z/RVRAALv0boGF2mIiIwg+7CxIRkToKZgM3vw28cg1Qc1AUulDMvhUYfpZ6bSMiIhoCBllERKSeEQuAe/aIebOOfwIk5QCjzgPm3aF2y4iIiAaNQRYREakrOQ+44KfiHxERUQTgmCwiIiIiIiI/YpBFRERERETkRwyyiIiIiIiI/IhBFhERERERkR8xyCIiIiIiIvIjBllERERERER+FDZB1h/+8Aecc845SEhIQFpamlfPuf322yFJUq9/F198cWAbSkREREREUS1s5skyGo247rrrsHDhQjz//PNeP+/iiy/Giy++aP85NjY2EM0jIiIiIiICEEZB1iOPPAIAeOmll3x6XmxsLPLy8gLQIiIiIiIiov7CJsgarK+++go5OTlIT0/H4sWL8fvf/x6ZmZlutzcYDDAYDPaf29raAAAmkwkmkyng7XVFOa5ax48GPMeBxfMbWDy/gcdzHFg8v4HF8xtYPL+BF0rn2Ns2SLIsywFui1+99NJLuPfee9HS0jLgtq+//joSEhIwevRoFBcX45e//CWSkpKwbds2aLVal895+OGH7VkzZ6tXr0ZCQsJQm09ERERERGGqq6sLN910E1pbW5GSkuJ2O1WDrF/84hd47LHHPG5z7NgxTJo0yf6zL0FWX6dPn8bYsWPx+eefY8mSJS63cZXJKiwsRENDg8cTGUgmkwnr16/HsmXLoNfrVWlDpOM5Diye38Di+Q08nuPA4vkNLJ7fwOL5DbxQOsdtbW3IysoaMMhStbvg/fffj9tvv93jNmPGjPHb8caMGYOsrCycOnXKbZAVGxvrsjiGXq9X/Y8aCm2IdDzHgcXzG1g8v4HHcxxYPL+BxfMbWDy/gRcK59jb46saZGVnZyM7Oztox6uoqEBjYyPy8/O9fo6S6FPGZqnBZDKhq6sLbW1tqr+wIhXPcWDx/AYWz2/g8RwHFs9vYPH8BhbPb+CF0jlWYoKBOgOGTeGLsrIyNDU1oaysDBaLBfv37wcAjBs3DklJSQCASZMm4dFHH8XVV1+Njo4OPPLII7j22muRl5eH4uJi/OxnP8O4ceOwYsUKr4/b3t4OACgsLPT7/4mIiIiIiMJPe3s7UlNT3f4+bIKsBx98EC+//LL959mzZwMANmzYgEWLFgEAioqK0NraCgDQarU4ePAgXn75ZbS0tKCgoADLly/H7373O5/myiooKEB5eTmSk5MhSZL//kM+UMaFlZeXqzYuLNLxHAcWz29g8fwGHs9xYPH8BhbPb2Dx/AZeKJ1jWZbR3t6OgoICj9uFXXXBaNTW1obU1NQBB9jR4PEcBxbPb2Dx/AYez3Fg8fwGFs9vYPH8Bl44nmON2g0gIiIiIiKKJAyyiIiIiIiI/IhBVhiIjY3FQw895NNYMvINz3Fg8fwGFs9v4PEcBxbPb2Dx/AYWz2/gheM55pgsIiIiIiIiP2Imi4iIiIiIyI8YZBEREREREfkRgywiIiIiIiI/YpBFRERERETkRwyyQsS//vUvjBo1CnFxcViwYAF27tzpcfu33noLkyZNQlxcHKZPn441a9YEqaXh59FHH8W8efOQnJyMnJwcXHXVVSgqKvL4nJdeegmSJPX6FxcXF6QWh5eHH36437maNGmSx+fw9eu9UaNG9Tu/kiRh1apVLrfna3dgmzZtwuWXX46CggJIkoT333+/1+9lWcaDDz6I/Px8xMfHY+nSpTh58uSA+/X1czxSeTq/JpMJP//5zzF9+nQkJiaioKAAt912G6qqqjzuczCfM5FqoNfv7bff3u9cXXzxxQPul69fh4HOsavPZEmS8Je//MXtPvkaFry5Juvp6cGqVauQmZmJpKQkXHvttaitrfW438F+bgcSg6wQ8MYbb+C+++7DQw89hL1792LmzJlYsWIF6urqXG6/detW3Hjjjfjud7+Lffv24aqrrsJVV12Fw4cPB7nl4WHjxo1YtWoVtm/fjvXr18NkMmH58uXo7Oz0+LyUlBRUV1fb/5WWlgapxeFn6tSpvc7V5s2b3W7L169vdu3a1evcrl+/HgBw3XXXuX0OX7uedXZ2YubMmfjXv/7l8vd//vOf8fjjj+Ppp5/Gjh07kJiYiBUrVqCnp8ftPn39HI9kns5vV1cX9u7di9/85jfYu3cv3n33XRQVFeGKK64YcL++fM5EsoFevwBw8cUX9zpXr732msd98vXb20Dn2PncVldX44UXXoAkSbj22ms97pevYe+uyX7yk5/go48+wltvvYWNGzeiqqoK11xzjcf9DuZzO+BkUt38+fPlVatW2X+2WCxyQUGB/Oijj7rc/vrrr5cvvfTSXo8tWLBA/sEPfhDQdkaKuro6GYC8ceNGt9u8+OKLcmpqavAaFcYeeugheebMmV5vz9fv0Pz4xz+Wx44dK1utVpe/52vXNwDk9957z/6z1WqV8/Ly5L/85S/2x1paWuTY2Fj5tddec7sfXz/Ho0Xf8+vKzp07ZQByaWmp2218/ZyJFq7O77e+9S35yiuv9Gk/fP26581r+Morr5QXL17scRu+hl3re03W0tIi6/V6+a233rJvc+zYMRmAvG3bNpf7GOzndqAxk6Uyo9GIPXv2YOnSpfbHNBoNli5dim3btrl8zrZt23ptDwArVqxwuz311traCgDIyMjwuF1HRwdGjhyJwsJCXHnllThy5EgwmheWTp48iYKCAowZMwY333wzysrK3G7L1+/gGY1GvPLKK/jOd74DSZLcbsfX7uCVlJSgpqam12s0NTUVCxYscPsaHcznODm0trZCkiSkpaV53M6Xz5lo99VXXyEnJwcTJ07EXXfdhcbGRrfb8vU7NLW1tfjkk0/w3e9+d8Bt+Rrur+812Z49e2AymXq9HidNmoQRI0a4fT0O5nM7GBhkqayhoQEWiwW5ubm9Hs/NzUVNTY3L59TU1Pi0PTlYrVbce++9OPfcczFt2jS3202cOBEvvPACPvjgA7zyyiuwWq0455xzUFFREcTWhocFCxbgpZdewtq1a/HUU0+hpKQE559/Ptrb211uz9fv4L3//vtoaWnB7bff7nYbvnaHRnkd+vIaHcznOAk9PT34+c9/jhtvvBEpKSlut/P1cyaaXXzxxfjvf/+LL774Ao899hg2btyISy65BBaLxeX2fP0Ozcsvv4zk5OQBu7PxNdyfq2uympoaxMTE9LvpMtB1sbKNt88JBp1qRyZSwapVq3D48OEB+0EvXLgQCxcutP98zjnnYPLkyXjmmWfwu9/9LtDNDCuXXHKJfX3GjBlYsGABRo4ciTfffNOrO3vkveeffx6XXHIJCgoK3G7D1y6FC5PJhOuvvx6yLOOpp57yuC0/Z7x3ww032NenT5+OGTNmYOzYsfjqq6+wZMkSFVsWmV544QXcfPPNAxYY4mu4P2+vycIVM1kqy8rKglar7Vc1pba2Fnl5eS6fk5eX59P2JNx99934+OOPsWHDBgwfPtyn5+r1esyePRunTp0KUOsiR1paGiZMmOD2XPH1OzilpaX4/PPP8b3vfc+n5/G16xvldejLa3Qwn+PRTgmwSktLsX79eo9ZLFcG+pwhhzFjxiArK8vtueLrd/C+/vprFBUV+fy5DPA17O6aLC8vD0ajES0tLb22H+i6WNnG2+cEA4MslcXExGDu3Ln44osv7I9ZrVZ88cUXve5GO1u4cGGv7QFg/fr1brePdrIs4+6778Z7772HL7/8EqNHj/Z5HxaLBYcOHUJ+fn4AWhhZOjo6UFxc7PZc8fU7OC+++CJycnJw6aWX+vQ8vnZ9M3r0aOTl5fV6jba1tWHHjh1uX6OD+RyPZkqAdfLkSXz++efIzMz0eR8Dfc6QQ0VFBRobG92eK75+B+/555/H3LlzMXPmTJ+fG62v4YGuyebOnQu9Xt/r9VhUVISysjK3r8fBfG4HhWolN8ju9ddfl2NjY+WXXnpJPnr0qPz9739fTktLk2tqamRZluVbb71V/sUvfmHffsuWLbJOp5P/+te/yseOHZMfeughWa/Xy4cOHVLrvxDS7rrrLjk1NVX+6quv5Orqavu/rq4u+zZ9z/Ejjzwir1u3Ti4uLpb37Nkj33DDDXJcXJx85MgRNf4LIe3++++Xv/rqK7mkpETesmWLvHTpUjkrK0uuq6uTZZmvX3+wWCzyiBEj5J///Of9fsfXru/a29vlffv2yfv27ZMByH/729/kffv22avb/elPf5LT0tLkDz74QD548KB85ZVXyqNHj5a7u7vt+1i8eLH8xBNP2H8e6HM8mng6v0ajUb7iiivk4cOHy/v37+/1mWwwGOz76Ht+B/qciSaezm97e7v805/+VN62bZtcUlIif/755/KcOXPk8ePHyz09PfZ98PXr2UCfEbIsy62trXJCQoL81FNPudwHX8OueXNNduedd8ojRoyQv/zyS3n37t3ywoUL5YULF/baz8SJE+V3333X/rM3n9vBxiArRDzxxBPyiBEj5JiYGHn+/Pny9u3b7b+78MIL5W9961u9tn/zzTflCRMmyDExMfLUqVPlTz75JMgtDh8AXP578cUX7dv0Pcf33nuv/e+Rm5srr1y5Ut67d2/wGx8GvvnNb8r5+flyTEyMPGzYMPmb3/ymfOrUKfvv+fodunXr1skA5KKion6/42vXdxs2bHD5maCcR6vVKv/mN7+Rc3Nz5djYWHnJkiX9zv3IkSPlhx56qNdjnj7Ho4mn81tSUuL2M3nDhg32ffQ9vwN9zkQTT+e3q6tLXr58uZydnS3r9Xp55MiR8h133NEvWOLr17OBPiNkWZafeeYZOT4+Xm5paXG5D76GXfPmmqy7u1v+4Q9/KKenp8sJCQny1VdfLVdXV/fbj/NzvPncDjZJlmU5MDkyIiIiIiKi6MMxWURERERERH7EIIuIiIiIiMiPGGQRERERERH5EYMsIiIiIiIiP2KQRURERERE5EcMsoiIiIiIiPyIQRYREREREZEfMcgiIiICcPvtt+Oqq65SuxlERBQBdGo3gIiIKNAkSfL4+4ceegj//Oc/IctykFpERESRjEEWERFFvOrqavv6G2+8gQcffBBFRUX2x5KSkpCUlKRG04iIKAKxuyAREUW8vLw8+7/U1FRIktTrsaSkpH7dBRctWoR77rkH9957L9LT05Gbm4v//Oc/6OzsxLe//W0kJydj3Lhx+PTTT3sd6/Dhw7jkkkuQlJSE3Nxc3HrrrWhoaAjy/5iIiNTEIIuIiMiNl19+GVlZWdi5cyfuuece3HXXXbjuuutwzjnnYO/evVi+fDluvfVWdHV1AQBaWlqwePFizJ49G7t378batWtRW1uL66+/XuX/CRERBRODLCIiIjdmzpyJX//61xg/fjweeOABxMXFISsrC3fccQfGjx+PBx98EI2NjTh48CAA4Mknn8Ts2bPxxz/+EZMmTcLs2bPxwgsvYMOGDThx4oTK/xsiIgoWjskiIiJyY8aMGfZ1rVaLzMxMTJ8+3f5Ybm4uAKCurg4AcODAAWzYsMHl+K7i4mJMmDAhwC0mIqJQwCCLiIjIDb1e3+tnSZJ6PaZULbRarQCAjo4OXH755Xjsscf67Ss/Pz+ALSUiolDCIIuIiMhP5syZg3feeQejRo2CTsevWCKiaMUxWURERH6yatUqNDU14cYbb8SuXbtQXFyMdevW4dvf/jYsFovazSMioiBhkEVEROQnBQUF2LJlCywWC5YvX47p06fj3nvvRVpaGjQafuUSEUULSeb09kRERERERH7D22pERERERER+xCCLiIiIiIjIjxhkERERERER+RGDLCIiIiIiIj9ikEVERERERORHDLKIiIiIiIj8iEEWERERERGRHzHIIiIiIiIi8iMGWURERERERH7EIIuIiIiIiMiPGGQRERERERH5EYMsIiIiIiIiP2KQRURERERE5EcMsoiIiIiIiPyIQRYREREREZEfMcgiIiIiIiLyIwZZREREREREfsQgi4iIiIiIyI8YZBEREREREfkRgywiIiIiIiI/YpBFRERERETkRwyyiIiIiIiI/IhBFhERERERkR8xyCIiIiIiIvIjBllERERERER+xCCLiIiIiIjIjxhkERERERER+RGDLCIiIiIiIj/Sqd2AUGe1WlFVVYXk5GRIkqR2c4iIiIiISCWyLKO9vR0FBQXQaNznqxhkDaCqqgqFhYVqN4OIiIiIiEJEeXk5hg8f7vb3DLIGkJycDECcyJSUFFXaYDKZ8Nlnn2H58uXQ6/WqtCHS8RwHFs9vYPH8Bh7PcWDx/AYWz29g8fwGXiid47a2NhQWFtpjBHcYZA1A6SKYkpKiapCVkJCAlJQU1V9YkYrnOLB4fgOL5zfweI4Di+c3sHh+A4vnN/BC8RwPNIyIhS+IiIiIiIj8iEEWERERERGRHzHIIiIiIiIi8iOOySIiIiIiIgCiRLnZbIbFYlG7KXYmkwk6nQ49PT0Bb5dWq4VOpxvy1E0MsoiIiIiICEajEdXV1ejq6lK7Kb3Isoy8vDyUl5cHZd7ahIQE5OfnIyYmZtD7YJBFRERERBTlrFYrSkpKoNVqUVBQgJiYmKAENN6wWq3o6OhAUlKSxwmAh0qWZRiNRtTX16OkpATjx48f9PEYZBERERERRTmj0Qir1YrCwkIkJCSo3ZxerFYrjEYj4uLiAhpkAUB8fDz0ej1KS0vtxxwMFr4gIiIiIiIACHgQEw78cQ54FomIiIiIiPyIQRYREREREZEfMcgiIiIiIiLyIwZZRETulGwCGovVbgURERENUnV1NW666SZMmDABGo0G9957b1COyyCLiMiV5lLg5cuBN25RuyVEREQ0SAaDAdnZ2fj1r3+NmTNnBu24LOFORORKa4VYNpWo2w4iIiIVyLKMbpNFlWPH67Vez9H17LPP4uGHH0ZFRUWvqoBXXnklMjMz8cILL+Cf//wnAOCFF14ISHtdYZBFROSKoU0szd2AqRvQx6vbHiIioiDqNlkw5cF1qhz76G9XICHGuzDluuuuwz333IMNGzZgyZIlAICmpiasXbsWa9asCWQzPWJ3QSIiVwztjvXuZrHsqAe++C3w8X2AVZ27e0REROSQnp6OSy65BKtXr7Y/9vbbbyMrKwsXXXSRau1iJouIyJWeVsd6VxNQsRt47weAqUs8NutmYPhcddpGREQUYPF6LY7+doVqx/bFzTffjDvuuAP//ve/ERsbi1dffRU33HCDqhMrM8giInJF6S4IAN1NwIHXHAEWALRXAWCQRUREkUmSJK+77Knt8ssvhyzL+OSTTzBv3jx8/fXX+Pvf/65qm8LjzBERBVvf7oIdtbYfJAAy0F6jRquIiIioj7i4OFxzzTV49dVXcerUKUycOBFz5sxRtU0MsoiIXOlxymR1NYnxWACQNx2oOcggi4iIKITcfPPNuOyyy3DkyBHcckvv6Vf2798PAOjo6EB9fT3279+PmJgYTJkyJWDtYZBFRORKr0xWE9BZJ9bzZ4ggq4NBFhERUahYvHgxMjIyUFRUhJtuuqnX72bPnm1f37NnD1avXo2RI0fizJkzAWsPgywiIlecx2S1lAHmHrGeN0MsmckiIiIKGRqNBlVVVS5/J8tykFvDEu5ERK45Z7Lqi8RSnwhkjBHr7bX9n0NEREQEBllERK45j8mqPy6WSdlAcp5Yb68OfpuIiIgoLDDIIiJyxeA0T5YyGXFSLpBkC7K6GgCLKfjtIiIiopDHIIuIyBXn7oKKxGwgIRPQ2IazdtQFt01EREQUFhhkERH1Jcu9uwsqknIAjUZktAAWvyAiIiKXGGQREfVl6gZkS//HE3PEUhmXxTLuRERE5AKDLCKivgwusliAKHwBOMZlsfgFERERucAgi4ioL2U8VlwqoIt3PN43k8Uy7kREROQCgywior6U8VixqUB8uuPxpL5BFjNZRERE1B+DLCKivpTugrHJQEKG4/FEW3dB+5gsZrKIiIhC2bvvvotly5YhOzsbKSkpWLhwIdatWxfw4zLIIiLqSwmy4lJcZ7I4JouIiCgsbNq0CcuWLcOaNWuwZ88eXHTRRbj88suxb9++gB5XF9C9ExGFox6nTJbeNiZLFw/EJIl1jskiIiIKCc8++ywefvhhVFRUQKNx5I+uvPJKZGZm4oUXXui1/R//+Ed88MEH+OijjzB79uyAtYtBFhFRX0rhi9gUINYWWCXlAJJkW7fNk9XVAFgtgEYb/DYSEREFkiwDpi51jq1PcHznDuC6667DPffcgw0bNmDJkiUAgKamJqxduxZr1qzpt73VakV7ezsyMjL6/c6fGGQREfXlPCZL6S6odBUEHOO0ZCvQ3QwkZgW3fURERIFm6gL+WKDOsX9ZBcQkerVpeno6LrnkEqxevdoeZL399tvIysrCRRdd1G/7v/71r+jo6MD111/v1yb3FdFjsp566inMmDEDKSkp9oFun376qdrNIqJQZy/hngIkZIp1JXsFAFq9I/jqbAhu24iIiKiXm2++Ge+88w4MBgMA4NVXX8UNN9zQq/sgAKxevRqPPPII3nzzTeTk5Ljald9EdCZr+PDh+NOf/oTx48dDlmW8/PLLuPLKK7Fv3z5MnTpV7eYRUajqaRXL2BRgyhVAyUZg3vd6b5OYLbJYnfUAJgW9iURERAGlTxAZJbWO7YPLL78csizjk08+wbx58/D111/j73//e69tXn/9dXzve9/DW2+9haVLl/qztS5FdJB1+eWX9/r5D3/4A5566ils376dQRYRuec8JittBHDzW/23ScgCcEKMyyIiIoo0kuR1lz21xcXF4ZprrsGrr76KU6dOYeLEiZgzZ47996+99hq+853v4PXXX8ell14alDZFdJDlzGKx4K233kJnZycWLlyodnOIKJQ5l3B3RxmHxe6CREREqrv55ptx2WWX4ciRI7jlllvsj69evRrf+ta38M9//hMLFixATU0NACA+Ph6pqakBa0/EB1mHDh3CwoUL0dPTg6SkJLz33nuYMmWK2+0NBoO9PycAtLWJiy2TyQSTyRTw9rqiHFet40cDnuPACrfzq+1uhQaAWRsP2U2bNfEZ0AKwtNXAqvL/K9zObzjiOQ4snt/A4vkNrEg5vyaTCbIsw2q1wmq1qt2cXmRZti/dtW3RokXIyMhAUVERbrjhBvt2zz77LMxmM1atWoVVq1bZt7/tttvw4osvutyX1WqFLMswmUzQantXEPb27yzJSqsjlNFoRFlZGVpbW/H222/jueeew8aNG90GWg8//DAeeeSRfo+vXr0aCQm+9Q8lovC0+NgvkNxThc3jHkBj8mSX20ysfheTat5HSdZiHCy8PbgNJCIi8jOdToe8vDwUFhYiJiZG7eaoymg0ory8HDU1NTCbzb1+19XVhZtuugmtra1ISXHf4yXig6y+li5dirFjx+KZZ55x+XtXmazCwkI0NDR4PJGBZDKZsH79eixbtgx6vV6VNkQ6nuPACrfzq/vnNEgdNTB95wsgf6bLbTS7noP2s1/AOulyWK51fScsWMLt/IYjnuPA4vkNLJ7fwIqU89vT04Py8nKMGjUKcXFxajenF1mW0d7ejuTkZEhezp81FD09PThz5gwKCwv7nYu2tjZkZWUNGGRFfHfBvqxWa68gqq/Y2FjExsb2e1yv16v+xgmFNkQ6nuPACovzK8v2MVn6pAzAXXtTREl3TXcTNCHyfwqL8xvmeI4Di+c3sHh+Ayvcz6/FYoEkSdBoNP1Kn6tN6fqntC/QNBoNJEly+Tf19m8c0UHWAw88gEsuuQQjRoxAe3s7Vq9eja+++grr1q1Tu2lEFKo6G2wz3EtAyjD329kLX9QHpVlEREQUPiI6yKqrq8Ntt92G6upqpKamYsaMGVi3bh2WLVumdtOIKFQ1l4hlyjBA1z+rbZeYLZasLkhERER9RHSQ9fzzz6vdBCIKN022ICtjtOftEmyZrO4mwGIGtBH9cUpEREQ+CK0Ol0REalMyWemjPG+XkAHANvi2uymQLSIiIgqaKKuJ55I/zgGDLCIiZ95msjRaICFTrHNcFhERhTmloENXV5fKLVGfcg6GUsiE/VuIiJzZM1kDBFmAKH7R1eB5XFZLObD9KWDB9wfOjhEREalEq9UiLS0NdXV1AICEhISglEv3htVqhdFoRE9PT0CrC8qyjK6uLtTV1SEtLa3fRMS+YJBFROTM20wWIIpf1B/3nMna9R9g+79E5mv57/zTRiIiogDIy8sDAHugFSpkWUZ3dzfi4+ODEvilpaXZz8VgMcgiIlIYOoBO2xeLN5kspbtgV6P7bZpLB96GiIgoBEiShPz8fOTk5MBkMqndHDuTyYRNmzbhggsuCPhcZHq9fkgZLAWDLCIiRfMZsYxPB+LTBt7eXsbdQyarrVIse1qH0jIiIqKg0Wq1fgk0/EWr1cJsNiMuLi5sJnxm4QsiIoUv47EApwmJPYzJamWQRUREFG0YZBERKXwZjwU4BVluMlkWM9BRI9Z7WobUNCIiIgofDLKIiBS+ZrKUCYm73MyT1V4NyFaxrmSyNv0V+Ow3g28jERERhTyOySIiUrSUiaW3pdYHKnyhjMcCgJ42wGwAvvw9ABk4+y4gpWCwLSUiIqIQxkwWEZFCGVuVlOPd9gMFWa0VjnVDm23/tlnk26oH1UQKMfVFwNPnA0Vr1W4JERGFEAZZRESKblu3v/gM77ZXgqzuJsBq7f9750yWbHVkygDRlZDC39EPgJqDwGe/AmRZ7dYQEVGIYJBFRKRQxlYleBtk2baTra4LW7RW9v656bRjnUFWZOhuFsvGU0DpVnXbQkREIYNBFhERAJiNgLFDrHsbZGn1QGyqWHdV/KLNU5BV43sbKfQoQRYA7H1ZvXYQEVFIYZBFRAQ4ugpKGkfg5A0lIHM1LotBVuRzDrKOftD7ZyIiiloMsoiIAEeQFJ8OaHz4aPRU/ELpLqiLE0t2F4w8zkGVuQc4sU69thARUchgkEVEBDiNx8r07XnuMllmA9BZJ9azJ4qlMtkxwExWpFCCrLQRYulcUZKIiKIWgywiIsD3yoIKd5mstiqx1MUBGWPEuqHV8XtmsiJDd4tYZk8Sy8561ZpCREShg0EWERHge2VBhdsgy9ZVMKUAiEvr/7zuJpHtovAly45MVtYEseyoVa89REQUMhhkEREBTmOyfA2ylO6CfaoL1h0Ty4wxQFyK6+eyy2B4M3YCVpNYV7qEdtSp1x4iIgoZDLKIiABHRsJfmayqfWJZMAeIc1OtkEFWeFNeM9oYIH2UWGeQRUREYJBFRCQMtbtgd59Mlj3Imt0/yEouEEuOywpvSpAVnw4k5Yl1BllERAQGWUREgj8LXxg7gfrjYn3YnP5jsnKniiUzWeGtp0Us49OBpGyxbmgFTD2qNYmIiEIDgywiIsARJPmju2D1AUC2ioxVcl7vTJY2FsgcK9aZyQorenM7pNNfAVaLeMA5kxWXJroNAo7S/UREFLV0ajeAiCgkDHqeLKW7YAtwfA1Qvh2ITRaPFcwWS+cgKyETSM4X68xkhZVZZS9Ad2gPkDcDuOwfjiArLg2QJCApF2gtF10GlXmziIgoKjHIIiICBt9dMC4NgARABt75LmDqAjS2j9ZhAwVZzGSFk0SjbQ6smoPAK1cDZ31X/Byfbtsg2xZksYw7EVG0Y3dBIiKrxTGprK/dBbU6ID5NrJu6bPszi6XLTFaG6EIIOObSorCgs3Q5fuhpBUq3inUlyErKFUsWvyAiinoMsoiIulsAyGJduWD2hXMXw5RhthVJlG8H+meyMseJ9eYzgNno+/FIFXolyEobKZZVe8XSHmTliCWDLCKiqMfugkRESlfB2FRAq/f9+QmZQOMpsX7F40DtESA2xZEV08WJoggWo3gspQCISQaM7UDTaSBnkn/+HxQ4shV6S7dYL1wAtJSKvyfgyGTagyx2FyQiinbMZBER2YteDCKLBTgyWXFpwOgLgXN/DJz1bcfvJcmRzUrIFD9nTxQ/K6XeKbQZOiAp2c7C+b1/17e7IKsLEhFFPQZZRERK+XVfi14oErPEctJl7jNhsSliqQRk2bbsVX3R4I5JwWVoBQDIujggf2bv3/XLZDHIIiKKduwuSETRzdgpxkYBvhe9UCy4U0xAu+gX7rdRLsTtQRYzWWGlRwRZiE0Bsib0/p29uiC7CxIRkcAgi4iiV0s58K8FgKlT/OzrHFmK3KnAtf/xvM287wH6BGDsYvEzM1lhRVKCrLhUETAn5ztK8PcrfFEf9PYREVFoYXdBIopeZdscAZZG5wiAAmHWTcDtHzuyZUomq/EkYDEH7rjkH7YgS1a6fSp/P6D/mCxTJ2DoCGLjiIgo1DDIIqLopVQEnH0L8MsqYOYNwTt2aqHIbFmMju6KFLoM7WKpFDBRMpGQRFVKAIhNAvSJYp1zoBERRbWIDrIeffRRzJs3D8nJycjJycFVV12FoiJ2zSEim4aTYpk1AdDFBvfYGo1jbA/HZYU8yaB0F+yTyYpLFX9LRf4MsSzdErzGERFRyInoIGvjxo1YtWoVtm/fjvXr18NkMmH58uXo7OxUu2lEFAqUTFbmeHWObx+XxSAr5Nm7C9qyVnm2YCqloPd2Y5eI5akvgtQwIiIKRRFd+GLt2rW9fn7ppZeQk5ODPXv24IILLlCpVUQUEmQZaCwW65nj1GmDctymEnWOH21Kvga2/QtY+WcgbYRvz3UufAEAw+YClz8O5E7rvd3YxcCG3wMlm8RYO21Ef80SEZEbUfXp39oqviQzMtyXaTYYDDAYDPaf29raAAAmkwkmkymwDXRDOa5ax48GPMeBFZLnt60aelMnZEkLc/IwQIW2aWJSoAVg7W6BZQjHD8nzG4K02/4NzYlPYRk+D9aFP/LpuVJ3CwDAok+CVTnPM24SS+fznj0Vuvh0SN3NMJftgDy8z8TF5BJfw4HF8xtYPL+BF0rn2Ns2SLIsywFuS0iwWq244oor0NLSgs2bN7vd7uGHH8YjjzzS7/HVq1cjISEhkE0koiDKaj+Kc0/9CR2xufhiyl9UacPwpi2YW/oM6pKnYtu4n6vShmhy4fEHkdZ9BqeyL8aR4Tf59Nx5p/+JgtY9OFB4O85kea5CeVbJkxjWshPH865CUf41Q2kyERGFmK6uLtx0001obW1FSkqK2+2iJpO1atUqHD582GOABQAPPPAA7rvvPvvPbW1tKCwsxPLlyz2eyEAymUxYv349li1bBr1er0obIh3PcWCF4vnV7KkFTgEJw6dj5cqVqrRBOqEBSp9BVnLskNoQiuc3FOmKfgIAGJObhJE+nm/N/54BWoFJsxZgygzPz5X2NwGf7MQEbQXGqvTaCjd8DQcWz29g8fwGXiidY6WX20CiIsi6++678fHHH2PTpk0YPny4x21jY2MRG9u/ypher1f9jxoKbYh0PMeBFVLnt+UMAECTPREatdqUKOZX0hg7/NKGkDq/ocbUDXQ1AgA0XQ0+n2/ZIL5UtYkZ0A303NHniePUHVXvtRWm+BoOLJ7fwOL5DbxQOMfeHj+iqwvKsoy7774b7733Hr788kuMHj1a7SYRUaiwVxYcq14bYpPFUpmDiQKnrcqx3tng+/OVv1GsFz0akvPE0tQFGFnNlogoGkV0JmvVqlVYvXo1PvjgAyQnJ6OmpgYAkJqaivj4eJVbR0SqarTNkaVWZUHAEWT1eNf1gIagtdyx3lHn+/Nt82TJSnVBT2KSAF08YO4GOuuBmETfj0dERGEtojNZTz31FFpbW7Fo0SLk5+fb/73xxhtqNy26FH8JHF+jditCV2Mx8NlvgJYytVsSPZpLxT8AyFJpjizAkRUxdQJWi3rtiAatlY71rgbfzrcsO0q4e5PJkiQgMVusd9R7fxwiIooYEZ3JipLCiaGtpxVY/U1xQXPfUUc3GnLY+gSw50Vg6+PAT44CqcPUblFkk2Vgzf8BsgUYdX7/yWSDSclkAaI7Wnyaak2JeK0VjnXZCnQ1AUnZ3j3X2AFJtop1bzJZgNh3axnQOYisGRERhb2IzmRRCCjdCliM4oK2ar/arQlN9UWO9f9eAZh61GtLNDj2EXByHaDRA5f+Td226GIBra3QDsdlBVZbRe+fO33IMNmyWBZJB+jivHuOksny5ThERBQxGGSR/5V8DfxjBnDyc6Bkk+Px6gPqtSmkOWVcG08BFbvUa0o02PG0WJ77IyB7grptAVj8Ilha+wZZPmSYbEGWSZsgugJ6g90FiYiiGoMs8r+DbwAtpcDnDwOnNzoeZ5DlWndz759NXeq0I1p01IrluKXqtkPBIMulrcUNeOKLk7Ba/dTtWxmTpbH1kvcl+LEFWWatDxPSJ+WIJbsLEhFFJQZZ5H9NJWJZewioO+J4nEGWa11NYhlrG+vBICuwlAIG3o6tCTR7kMUKg84e/OAI/t/6E9h8ahDl1vuSZaDNFmTlTBHLQXQXNPkSZLG7YPQo2w78ZwlQsUftlhBRCGGQRf7XVNz757SRYtlWAXQ2Br89oUyWHZkspeCFkUFWwDhXiQuZIMtWrY5Blp0sy6hoFu+DgxUtQ99hTwtg7BDrBbPE0qfuguJvM6ggi90FI9+ht4DK3cDht9VuCRGFEAZZ5F/GTqC9uvdjEy8BMmwTvtYwm9WLsROwmsR6ii3IYiYrcMw9ohALEEJBFrsL9tXWbUaPSVTzO1zph+BT6SoYnwGkjRDrg+guyEwWuaTcuBnM/GtEFLEYZJF/KV0F49OBwgVifcIKIH+mWK8+AFit6rQt2FrKAUOH5226bV0FtTFAQqZYN3UHtl3RTLkYkjRiwthQwCCrn5o2R4XNQ5WtQ9+hUvQidTiQOIixUh1iInufgiyOyYoeymTiDKiJyAmDLPIvpatgxhjghtXAtz4Gxi52BFmb/gr8PhvY94p6bQyGtirg8dnAq9/wvJ3SVTA+A4ixXcAxyAqc7haxjEv1vkpcoMUp3QUZZCmcg6zKlm40dxqHtkNlou/U4U7Bj5cXxGe2AFufBAC0xRd6f0wlmOtuBiwm759H4Ue5ecMgi4icMMgi/2o6LZYZY4HELGD0+eLnYXPE0tgBWM3AiXXqtC9Yao+IboDOc2C5ohS9iE8H9EqQ1RnYtkWzUBuPBTCT5UJNa+8bDYerhpjNUoru5E71baxUySbgtRsBiwHWCStRkrXE+2PGpwOSVqx3+qF4B4UuBllE5AKDLPKvRqdMlrNR5wMX/wmYcYP4ub0muO0KtrYqsTS0i2IL7iiZrIQMpyCLmayACeUgq4eFLxQ1rYZePw+5y2DVXrEsmO00VqrO/XtTloFt/wL+exVgaAVGLITlqmdEN1NvaTTiRpNyLIpcStGarkbAalG3LaSqgxWteG9fxcAbUlRgkEX+pYzJyhzb+3FJAs6+CzjrO+LnjggPspTiH1YTYDa4387eXTAd0MeLdRa+CJyQDLJYXbAvpbtgcpyY0+pwZSu6jGbInm5YuGPsBOqPi/WCOY7ughaj63Nu7ATe+R6w7peAbBE3hm59z/H+9EWij10TKTwpnyuy1dE7gaLSfW8dwk/eOIDi+gHGY1NUYJBF/tXkJpOlSM4Ty/YazxmecOdcYdFTNzCl8EV8miOTFWUl3P++/gT+9tkA3Sr9padFLEMqyGJ3wb5qbUHWookiSFl7uAZTHlyHBz844ulprtUcEhe/SXlASr4IlmJs59xVRv2Du0Upbo0OuOTPwNVPDy7AAhyZLJZxj1wWs2N6AIABdRSzykBFi+iJUus0rpSiF4Ms8h/n8u0DBVkWoyOLE4nanIMsDxkKpRBDlBa+aOky4p9fnMTjX55CXTC+lOyZrLTAH8tbDLL6qWkVr4Wlk3OQEKOF1XY/5rOjg8iAV+0Ty4LZjsdybRMSb/5H722tVuDkZ2L9htXAgh8MrUAKKwxGvr6f7/xbR61OM2CxfVi195hVbg2FAgZZ5D/O5dsTMlxvo4sVvwcie1xWe5Vj3dPFs8vCF9GTyapodgSUp+qC0L0ipLsLMshSKHeBx+ck490fnoNnb51re9yAli4fKw26CrKW/x6ABBxYDRRvcDzeeEpkJXTxwLilQ/gf2HCurMjX02e8IIucRK02p48mBlkEMMgif2otF8u0kZ63S84Xy76TFkeSNm+7CzoXvlDGZEVPJqui2RFQngpGH3ZmskKewWxBo61ke15qHCblpWD51DwMTxfvj6IaH89Tpa3ohVLhFAAK5wPzvy/W1/zU0XVZqUKYNx3QaAf7X3DwpZIhhad+QRb/1tGqzeTIerf3cNoGYpBF/qRcJManed7OeVxWJDIbgC6nu5neBFlRWvjCOZN1sjYYQVaLWIZUJotBlrO6NlEoJkanQXqC3v74xFxxnopqfThPnQ1A40mxnj+r9+8W/1qMu2o85bhBVL1fLAv6bDtY7C4Y+fp2F+zg3zpaMZNFfTHIIv9RLhJjkjxvp2SyIrXCYN/g0avCF+mAPlGsR1GQVdnC7oL27oLGdpZ/hqOyYF5KHCSn8VAT80SQddzbTFZHvSjBDgDZk4Ck7N6/j0sRGSsAqNgllkomS5k8fajYXTDyMZNFNm1OyStmsghgkEX+pFRYGijISsoVy0jNZPULsjwVvlAyWdHaXdApyApqd8FQCrKSHetGlv1Vil7kpcT1elwJsk54G2S9eRtQe0iUUf/GC663GT5PLCt2i6IX9iBrlq/Ndo3dBSMfgyyyaTM6dxdkJosYZJE/GWwXiLFeZrIiNsiq6v2zu0yWLPfpLhh9JdwrnYKs+nYDWrsDfPcvFIMsXSygsXWLY5dBe9GL3FTXQVZRbfvA82VZLUD5drF+67tA7lTX29mDrF1Ac4m4IaKNBbInDrr9vShBVleDCOIo8iiTiCs9ERhkRS3nTFYbM1kEBlnkT95msiJ9TFZbn4Ie7i6cDe2A1Xa3q1fhC++DLKVcbLhSCl9obDcAA95lMBSDLEniuCwnjkxWbK/Hx2QlQaeR0N5jRlWr63L/RrMtkOluFnNjAaKroDvDzxLL6gNA+U6xnjcN0OrdP8cXSpBlNTvGA1JkUT5TMseKJbOWUYuZrCHqqAc+uheo2q92S/yGQRb5j3KBOGAmK0KDrNojwJZ/ijviztxdOCtZLF2cbYJUWybLagIsA98F23yyAdMeWoc3d5UPodHqaesxoc32RTSrMA0AcKougEGGLIdmkAWI8UEAgywAZU0i8B6W1nsC4BidBmOzxWdLUU3/Lrgbiuow9aG1eHVHqaOMdlya54ApfTSQkCXm7dv2L/GYv8ZjAYAuxlHJkgURIpM9yBonlp31jmqVFFV6Z7IYZPns8DvAnheBrU+o3RK/YZBF/mPPZCV73k4JsjpqIuvLaO0DwPoHgV3Pi58TbZXF3AZZTkUvAEd3QcCrcVk7zzSh22TBRwerBtw2FCldBdMT9JgxPA1AgDNZpi5H5jDUgix7JsvD+L0oUWwbmzcmu//Nmgkeil9sLKqHySJj04l6R3XPxCzPB5MkR5fB2kOApAGmfWPwjXeFxS8im/KeVYIsczdg7FSvPaSa3tUF2V3QZ12NYhlBWX8GWeQ/yhdLTKLn7ZTCFxajI5sTCWqPiKVsqxCnjOsYKJMVb5u4WRsjLvIAr7oMdhlEwHCkqm3gMSohSAmyhqXHY1yOuKAOaJDV3SKWGt3Ar9FgUyoM9kR3kGWyWO2ZrLE5/YOs0Vni7+Y8lg8WE1B3HFX14gu6sqXbkclKGCDIAhxdBgExSfGocwfXeHdYxj2yKZmslHzHjTL+raNOh8EMo5XdBYdEuWERQT06GGSR/3hb+EIX6wgsImVC4q6m3nNjAUDWBLF0l51QLgSVTJYk+VTGvcskgrmmTqO97HU4UcZjDU9LwOR8kaHYX94Ca6DGmTl3FXQqDR4SOCYLAFDe1AWTRUa8Xov8PtUFASAnWYzTqm8Xc2nhwx8Bjw4H/r0A3656BABQ1dLjfSYLAKZcJbLOC+8Gzv6hP/4bvSlt6GzwvB2FJ+VzJTbFKWvJv3W0aegw9PqZmaxBUN5LEfQ9yCCL/MeozJM1QHdBIPIqDNYXiaVz0Y+cyWLp7gPj8LtimTfN8ZgPZdyVTBYAHK4MvwyIMkfWsPR4zBiehqRYHZq7TDhSFaD/S6iOxwIYZNmcrhfZ8NFZidBo+gfC2UqQ1WEAWsqAvS8DZnGDYbL5GABx08HYZuual5A58EGzxgE/PQGs+ENggm+l2zDHZEUm++dKGpAyTKw3nFStOaSOOtuNH2UC9R6TFSYLK4r6pIeZLCL3vM1kAaJrBQA0nQ5ce4KpwRZkFS4ArnoaWPEokGsLnlx9YDSdBk6sFevzvud4XAmyvCjj3mV0TFx7pKrVw5ahyR5kpcVDr9Vg4VhxQbzpZIDGrjDICnnKeCxXXQUBpyCr3QCc2SIezBLdctOkTiRABFxdLbabN7YsUm1bD+783x5sK250feBAZjbZXTCyOX+ujDhbrJ/5Wr32kCrq28WALKVLM8Augz5jJovIA29LuAOOL6NTXwSuPcFUf0IssycCs24EFv7Q84XzzucAyMC4pUDWeMfjSp9+b7oLOgVZYZnJchqTBQAXjBcXxJtORGOQxeqCgCOTNSbL9Zi57CQRZNW1GyCf2SwenHgxTHrxXsuXRBBlbFUyWeI19caucqw9UoOXtvap/BkM7C4Y2Zw/V0ZfINZLvo6sok40oHpbd8G8lFgkxGgBAG2Bnvcx0hicgqwIef8wyCL/8SWTNeFisSzZCJjCbzxRP0omSxmHBbgPsoxdwL7/ifUFd/b+nVLG3ZvugkbHXbJwzGQ1dIg7f0p24oIJYjzDntJmdBgCcAcwpIMsVhcEvM9kGc1WWJUga9T5aI8RxXSGSSKQsSqV/GwBzoHyFgAI/GTXrrC7YOSSZcd7Ni5F9GTQ6IG2isjppUFeUcaJZifHIjlOB4CZLJ8p39GyxatroHDAIIv8w2IWpWsB78Zk5U4T/ddNXYBysRTOnDNZCuXC2dzde96rxpPiizk+Axi7pPd+BpnJqm7tQWOfgbehrrlLBFkZCTEAgJGZiRiRkQCzVXbfrWsoQjrIUjJZDLIA95msOL0WyXE65KMR2pYzgKQFChegQSsC9HxJTIugsU2PUG9NgizL2G8Lstq6VbjoYXfByGXscEx6HZcqbpIVzhc/l2xSr10UdPYgKykWloSHlgAA98xJREFUKXFiXBaLX/jIubpuhPTqYJBF/mFymhfEm/LYkgSMXybWT64LTJuCxdgFtJaJ9SynIMu526TzB0abraJiWiGg6fMWtBe+8C3IAhC4ghEB0GOy2NufkRRjf/x8W5fBnSUBCLKUkvkhGWRxTFZTpxHNXeKiZEy2+8+Q7ORYLNCIIhfInwnEpaDcKsbzTU0QgXSsQQRZd75bip0lTWjsFAF9mxoXPewuGLmUGzfaGDGpPACMOl8sOS4rqhTbujqPyIi3Z7I4IbEPnLPCgGP4CSCCL3N43URWMMgi/1C6Cmp0okS7N5QugyfWhm//27IdwIlPxXpCJpDoVM1M5/TF63zx3G6bPDi5oP/+fKkuaOsuONw2pqm8eeDALFQoWSydRkJyrM7++MhMkclTuhL6VVulWLo672pjkIXTtizWsLR4JMTo3G6XneQUZI06DwBQbBCB87TkDkiwItHSAgCoNCbitx8ftT9XlTESSndBU5fjc5LUc/hd4JVrgU4/3MhR7rzHpjiKp9jHZW0K3++1ADGarXhnTwVqWiNgiIATq1XGSdscjxNyk5HMTJbvnLPCQO85s/45A3jhYnXaNUQMssg/nIteeFupa/QF4g5gSxnQfCZgTQuYit3AC8uBt78jfnYej6VwdfGsZLKUCovOfJkny5YJKkgTQZYq400GqcmWWUhPjIHk9HpJjRdfTgH5v7RW2A4y3P/7HioGWTjd4Cjf7kl2cixmaYoBAI2Zc7GvrBnHukSQNVLbhBR0QQfxZd2ElF4Z3naDOXDzsLkTkwjobDdPOgNU1IW8t+Np4NTnQLEfii656oI8/CwAkvhb8+/dy6eHq3H/Wwfw6KfH1G6KX5U1daHbZIVOkjHSKZPFMVk+6Okzrlz5Lmw8JXqhVO0DrOF3PhlkkX/Yi154MR5LEZMIZIwV6+EySLi9RlSOAsQXtbO8Gf23d3Xx7E0ma4AS7larjG7bZMQFqSJbFo5BljIeS5EaL35u6QpgJotBVkgqbRRBlpLNdCc7SY/RkrhR8d1PO3D1v7eiytZdMNVUi0xJBFXtcjyM0Pd6riwDHcYgf1FLEpCkTFLLi27VddkyWH0v6gbDVZCli3XMz8ZiJ72U2G6klDaGT68LbxyvEZ/beQmATquxZ7Laekz48EAVqloio4hDQPX0Ge6gfBe219oekIGupqA2yR8YZJF/2Cci9qKyoLP0UWLZfEZ0kdv279DOar1xK/DyZUDxl0DpVvHYOfcAFz8GXPjz/tv7nMlyMyarcg+w9UnAKgKrHrPF3hMlX8lkdYVhkJXYN8gKUCbLbHRMfJ1a6N99+4NykRbFQVZZk7gQGSjIGqNvRpxkghE6HOwQ560S4qJW216FPK04h01IxqKJ2f2er2qXQV50q0+5UPNHkKUUM+k76XWSqHaJjlqQQ22bGFfTEGZFmgZSZAuy8hPEl3KKLZP1+s5y/Oi1fXjkoyOqtS3kNZwCqg+4z2Q5v4e6wm9ca8QHWZs2bcLll1+OgoICSJKE999/X+0mRSZfyrc7cw6yDrwOrHsA+OK3/myZ/zSVABU7xfr+14CKXWJ95k3A2Xf2Ho+lcFU1rt0WZCW7CLKUoiHOY7JOrgdeuAT47FciuEPvohd5KSKT1RJGQVazmyArLSFAQVZ7FQAZ0MY6ChGEEudg3Gr1vG2EKmtUBo577i44UhaZ4DPWXFihwfnjs/DsDy+DDAmSxYA5CeLC1xCTge+cOxqAeF1l2l5rqlQYTGQmKyRYrUBPi1j3R5DVZuuVkDqs9+NJDKpdqWsTY7Hq2w2QI2i8WlGt+H4vsAVZSnfBGtv/91h19N4888hiBl66FHh+BdDcZw5De5DleA9JYVg8KOKDrM7OTsycORP/+te/1G5KZLOPyfKisqAz5yCr5qBjXU3rHwSePt9RjU5x7CPH+uG3RbYpLg3InuR+X8rFs3OlHOWLOcWLwhdlO4DXbwIstjt/tnPTZRBBVkKMNnCByRCZLe6DhSZbQJie2Ls7l3Mmy69fwvbxWMO8HzMYTPZutnLv10qEsFplPPTBYby7t8LtNqVNIns7UCYr31wOADgti/fPWSMzMKUwG5ItezBHXwoAiE3Jxvnjs/C7K6fiiRtnIzXB0YUn6NhdMDT0tDgG1/slyHJTTIdl+12qbRdBh8FsDcxciCo5bguiCmwfXSnxvb/XKlu6YfLwfRi16o4AHTVimpvKPb1/5zKTFX6fnxEfZF1yySX4/e9/j6uvvlrtpkQ258IXvnAOsupsg2GVbl1qkGVg90si4DvRp7T8sQ+dtrN9YI5Y2L8MuzPbxbNs62/c09XhuJPqKpNlnyfLVhL/wGrA4jQ+yRagdZnEF5QIsmzjmEIoyHp1RymmPbwO20+7ruDV1CmCxv5jssSXk8niGHPmF60hPB4LEFUoNbaKehHYZfBwVSte3laKP68tcvn71m6TPRNbmOE5yMrsEUFUsSzeP7NHpIlf2P62Z8eLIGzYsEJIkoRbF47C+eOz7XPXqNNdkEFWSHC+ceaPOensXb/dBFnMZPVS0+roJtjQYRSZjDAcZ+Osx2TBGVsWPr9PJkthscocl+VK2XbHes3h3r+zB1mO60Fmsih6DabwBdAnyLKVWm6vsY89CrquBsBgu8OpjLkCxEV6xS4AEjDhEsfjIxd63F1lt/iwXb+vGD0mC+566mMAgKxPcD1fU99MVqOoombPltkC0E57JkuHNCX746FYxNZTDXhxS0nQumhsLW5Ej8nqNshq7hQXun27CybEaKHXikyTX7s/tooL75AcjwWI7FoEF79Q/pZNnUaXr8FyWxYrKykGSbHuy7cDQFKH6FZy2ioubGcWpolf2IKshAbxZa1L7j0eS7m7rMrcNRyTFRqcL+j92V2wb5DFv3c/JosVjZ2OIKulthx4+lzgb1OA018BLeXAO3f0v7kZ4k7WdsAqA+kJeqTYEljJsfp+20VasQ+/cL7Gqu0zbs1Fd8FwnGvQ87dZFDIYDDAYHB8EbW3ibpfJZILJpE6mQDmuWsf3hqa7DVoAFl0CrL60Mylf1P9yvqsoW2BqrXYMHg4C5dxa6ors9cjksm0w2x7XHPkAWgDW4fNhnftd6GxzY5mHzYfs4f9b2q7BMADlNbW49fkd0DaWAzFAT1wOdOb+F3uSJhY6AFZjJywmE3SNxZAAWAoXQlt/HNa2SlhMJrR3i9dovF6DRL0ISlq73b9Gf/b2AVS09GBybiLmjkz3+fz4qsOWLWjsMPR67yjLxg7RbSQlTtuvzSlxejR2GtHY3o3sRP98RGmay8TrMynft9dnEOlikiF1N8Pc1ezxNeVKqH9GNNv+3kaLFW1dPf3mwTpdJ97/henxA/4f9M3ixsNpOR9jshKRoBP/b01yPrQAABHEWeLSe/2tk2PEb5s7ewZ1noZyjqX4DPG+7qiFJUT/RmoLxmtYaq+zX/RYu1uH/LfQtVVCAmBKyAWc9iXFZ4q/d3tNyPy91f6MqG7tsRdrykQrxn16I9AhqgrLb90OxCRBai2HtfkMLKMXq9LGwThaJbKj47MTIUndMJlMiO8fY+F0fTsWjk4LbuNCmSxDV7YN9s77tuJpMiRIkGHtaRPXQO219m3kjlpAGxrfc962gUFWH48++igeeeSRfo9/9tlnSEjw3I0l0NavX6/q8T2ZVnEIYwEUV9Ti2Jo1Pj13hS4VcebedxW3rH0brQmj/dhC7xzb/BFm29alhhP4/MM3YNQlY1bpGowEcMJcgBPH2nBhXCE0sGDDvirIB9z/f822akrjpErsPdOAyzTiTmqtMR4HXJynvNajWACgpa4KWz9+D5fZyr3vbUrAPAAd1aewYc0aHGySAGhh6GzHzs1fAdCh02jBhx+vga5PftoqA1UtWgAS/rduO2qH+57NeuWkBu0m4AeTrdB4MaSpvEYc7+ipM1izxlGeX3kNl9aK3588vB9rKvb1eq7OKn63bsNmnE71T+ZtQfF+5AE4WNqEMh9fn8GyyAikAtj59ReoTxncHfBQ/YzYWiterwDw3iefIb3PfOXrK8XvNd3NWOPh76O19OAyW+GYYjkfk6R2+/a5rXFYYPuCBoDdp5tR0+TYV3O9BoAGew4eRU7z4Kt9DeYcZ7UX41wAnbVn8GWIvv5CRSBfw4WNmzHHtt7ZWDWkv4XWYsBltq7fn20/DLO22P677LYzOAdAR81pbAixv7danxGl7YByyflj3btI7jiNbn0GDLoUpHWfsXfl7Ko7gy9C7Jx5sqlCfHZJtizp+vXrUdkJKP/XEYkyyjolbNx9BOkNh1RrZ6hJMNRhmYvqmz36NMSbmlFbdgo7P/kEl7ZV2QOV+tJjwJiLQ+J7rqvLu8wkg6w+HnjgAdx33332n9va2lBYWIjly5cjJSVFlTaZTCasX78ey5Ytg17v4hZJCNB+vA6oB8ZOnonR56707bn1Tzqq9tmcN2Ms5AnBm+FbOcfT8uOBMsfjyyYmQ564EtrXXgSagHFnXYSxM68ALrkUkCRcInnucfu/1ibg9Ju4UHsQb0i/w27rRACAnDkGK1f2P09SSSJw+h9IT4rFivkTgQOAHJeKWStuBp75F5LldqxcuRKmA9VA0SEU5Gbimsvn4ld71kOWgXMWLUFWUu8r2MZOI6zbvwIAtMflYOXKuT6dG6PZih9vE3OCTV94AUYOMGYGAJ4u2Qa0tyM+LRsrV87t9xr+w+GNAAxYsehcTC3o/b56qWIHastbMXnmHCyf4p9spu7ZR0X7z7sE08Zc5Jd9+pu24d9AeRnmz5wMebJv76FQ/4yo+LoEOH0SADDn7PMxOb93t+It7x8ByiqxcNo4rFwyzv2Oqg8AB4EWKRVtSMKlZ0/GyvlKF9CVMLd/F1LNAcDcgzmTLgec3p9HPjuBrbVnkFc4GitXeihW48aQznH9WODUn5Akdbt831NwXsOanWX2z/cknWVof4vGU8BBQI5JwvLLr+39u7pRQPGfkawJnb+32p8R64/WAYf3AwBm2iYT11/2Z+iGL4D82rWA1Qyp6TQS0RUy58wbe9ccB8rLMH38KEA+jWXLlqHNYMVfD21EUqwOt1wwFn/8tAja1FysXDl7wP1FC+ngG8BRQNboIVkdWaHY7DFA1R7kpidg5dILoNvvGAaRkyg+z0Phe07p5TYQBll9xMbGIjY2tt/jer1e9T9qKLTBLVuhBm18KrS+tjFjdL8gS9dVB6jwf9W2nLE1IA4w90BXuROYdqV9LJQuvdDWLu/adiRuDu4x3o2/xL2Is3ACszWnAACtumyMdvX/ixMXn5KpG/pWMcBfyhgLfbq4kJQMbdDLRhgs4m59UqwecbExSI7Voa3HjE4TkN9nvy09jgG3+8paodHqcKquA8PS4wcc/wIAbU7dZ5u6LBiXO/D/vctWtKK129zrNavX66HT6dBsGz+Wm5bQ7zWdnijefx1Gq/9e77YqYLqMUaq8rrxiG6OnM3cOuo3dZsAE2CvphYouk6Oylqu/a0WL6E44OjvZ89+8RYzHMqaNxbzcdFw2c1jv7TMKxT8X0vz0uhrU53CqKNIh9bRAL8mALmaAJ0SvgH7PGRw9JiRD29CO0yWyzVLKsP77SRVjtKSuRug1ALSh835U6zqiwTYuU4IV4yXb5/GwWUDmCOCHO8Tf5rFRkIwd0MMC6OMcT5Zl4PgnQO5Ucb0QQlpsU0JkpcQBreL85iXo8fzt85CREIMm23ddeXNP6F6/qaFSXPNJ45cBRY7MpSatEKjaA42xA5qe3kVRNLaJxEPhWtjb40d84YuOjg7s378f+/fvBwCUlJRg//79KCsr8/xE8s1gS7gDjuIXAJBgm8NIpQqDUpOty8fkK8RSGZipDHDuW6rXhfKmLhyvEXc5Og1mfGQ9BxvPehIAoIW42KzXuJhTCwBilOqC3YDSlsyxQFyKo3Jjew26jY4S7gDsFQZbu/sXv2hodzzWbjDjT58ew4p/bMKyv23EDjeFKZw5Fwqotc37MZBOW3leZdJhZ+0GM0y2IDE9of/Fpt8nJO5pdYz56zufTSgZYuELgwVY+eRWLP37RjSG2GSfznNTuaqCqQwKH6h8OxrFTYqc0dPw1p3nIDOp/w0xd1StLhifDkjivRqOE2pGDOfCF+YewDyE94mnqTgSMhx/7zAcrB8IynfHGF0jEiQDTJIeSLcFTBqNmA5FqbDa9z1SsQt442bg/buC12AvKd9xfSvlXjQxBzML0zAqU1wTlTV1RdTcYENWukUsp3+j9+NKBWBDh6N8u9Z2blnCPfTs3r0bs2fPxuzZIk173333Yfbs2XjwwQdVblmEUaoL+lrCHegdZI21deWyjUUKKtkKNNnGD82+RSyrD4i+4sodUFdfqM67kGVc/8w2XPHkFrR2m9BpFBeXHbnzgfHL7dvVym6KT9hLuHc5KgtmjBXL5DyxbKtyVBe0ZaKUubJcVeSr7+gdGP3na5ENqG7twY3/2Y4tpzxfBDgHO94HWRZbe/oHWcpExAkxWsTptf1+7/cgSynfHp8xuJsAwTLEIOtIs4TaNgPq2w1uS6WrxXluqr6vUaPZiupWkW0dMVBXVKUEt1K9zQeO6oIqBFkajaOMOyvOqae7T7nwniGUcVfmyHL1naDROiY9dzHuJBrV2sYnL84QN/bKNYWA1qknhSQ5TXXQ5ztJmUOz4USgm+mzxg5bkJXoOrMxLC0eGgnoNllQ3x5aN79U01opbphJGmDsEsfNdQBIsd0INbQ73jvZYpiFZOyE1hpe5zDig6xFixZBluV+/1566SW1mxZZjLZ5nWKHEGRp9MDIc8W6CpmseFMTJIvB0Q59IiBbHHM5xCSJjJIHDR1GVLf2wGi2oratxz7hYmKsDlj8G/t2FWZ3QZathLux0xHwZSpBlm1erfZqxzxZtiDFU2Di6oO9MCMeiyZmwyoDnxyq9vh/cr7zX+fFl4TF6pjjqtNogcHcuxy/cufPVRYLcPxf/FbC3Xki4lA2xCBrb4OjIskbu8uxt6zZw9bB5fwaaumTbV17pAZWGUiK1SE7eYDMlMk22Fh5n/ggxTZ3jXNWLajcXUBSfy1lwJ6X/D+VR985mYZSxt1TJgtwmpC4z9332qPA1ieBL37Xe56gCKfcoJufKL7bT8ou5ixULrb7vkcabd+FXY2O6U1ChFKWvu90JIoYnQYFaeLzSplwPeqd+Vos82cB8WlAmlMX71TnIMt2QypjLKAV3w0x5vCa4iTigywKElv5TcT4OE8WABTMBgrmAHO/BaSNEI+1eb7wD4SkHltglzFa3GHLHCN+LrF9IAyQxQKAsqZO+3pbt8nebS4pVgfkz8Dh6T/Ha+aLcMA6yvUO4tPFeDDn4E7JZCnHb69GV59MlqfARAmy7PMJAXjgksm4erb4MCuq8fyh5Xzn35tMlpK9U/RtkxJkZSZ5DrL8lsnqtH1QJ+X5Z3+BEmsL4AcxSWpbtwlHW0SQNW+UCOD/8flJvzVtqNqdupw6vx6aOo145ENR6e+7542GJA1QulK5wBpERlLVTBYAJClBFjNZA/r058BHPxaBlj/1y2QFMsiyFe3pm8n639XAZ78Cvv4r8MatQJR0IauzZbImQMxZeMg0rH/3OSX71zcwbXJUqLWf9xAgy7Kju6CbIAtwdINWukW/tKUE/9teGvgGhirlmmr0+WLpPH+l0l3Q1OnIFifn2W9SxZr8MIl4EDHIIv+wT0Y8iEyWPh74/gbg0v/XK1sTbIkGW5CVaatupgQ3ZzaJpdI2D5wnHGzvMdu7zSXG2ubomXEHHjDfgeZuq8vnQx8PLLhTrMu2u7hKsGfvLliNrn5jsmxBlodM1vIpuVg2JRc3zCvEJdPyMClPXNSfqGn32Ffc+c6/N0GWEgAqmvt0GRwok6X8X/wXZNnuiiZmed5ObUqWdBCZrPXH6mCRJYzPScRPl4uuFRUhdNe0d3dBx+vh958cRWOnERNzk7HqIg9VBRVDymR5Nybrb58VYcXfN7ns6joknKDWexW7xfLEWv/ut6tPdtcwhCBL6dKe4iZD7urvbewEOmzfM9oYEXDXHxc/Wy3A3v8Cm/4akYFXje27I8dwBgBw1DIM7YY+WWXlM7rvmCxlfDIQUkGW8/jivmOynI3IcIzLauww4OGPjuI37x9GZUtoZeWCQpaBko1iffQFYukcZKU4ZTiV4Dopx/7aiDUzyKJoZBzCmCxnKbZAprsJMHk3/sdfkuxBli24UoKtmsO2tg2cyXIOstp6HGOylCp+SmDRN/Do5fz7xPghQCzjbV0Lk50yWbb9KkGWkv1xdQFZbyuCUJAWh//cdhb+dO0MSJKEMdmJ0GsltBvMHj/snS+QlbuRnnT0+eJs7uzdJuX/7u7On/8zWba7ogluio2ECqW74CDGiaw5LF67l07Pd8rYBLdbXGuXCd9+cafLO7S9Cl/YMlmyLOPTQ6Ldv7tqGmL6TvDmij3I8n3OwpR48R5sN5hhtbq+iDWYLfjP1yUoqm3HjpIml9sMmru79NRbR60j21eyqX/3sIo9wNvfHVxvByWTpdwwC0Z3QecgS7nho40FRp4j1ku+BppKgOeXAx/eA3z5O6Cq99yB4a7HZEFrtwk6mBHXKgKmE3Jh/67s9i61Tu8RqwVoPuP4WcluhICmDsf44viY/uOLFcPTxU2hyuZuVDQ7Xs8bjkfhDZfmM0BruShyUni2eMy5u2BilhiyATjGpSc5ZbIYZFGwdfa9GxRsFrOo1AQ4LhQHKy5NdJcDHHf8giS7/ahYyZshlkqwZZvc1Lvugs5Bltn+t0ns062vucvkPnsUlwosesDWlmmOx5VMVrtzJstW+CJeBCyu7r4rX2TZSXG9HtdrNRibLYJiT10GfS180ff12DegbByge4Xfx2TZyr7av8BD1SDHZFmtMvaUtQAAFk/MRrIy9ijI3eJe3VmKDUX1ePqr4n6/65XJsr2eatsM6DZZoJGAWU5dWT1SLrgHE2TZMlmyDHQYXX9m7iltto8n9Gb8oU/cjdGhXqTaw44fzD3Amc1iLJXR9tm65n7g8NvA5r/7tmNTjyNIV6raDTbIMhscf0d3mSx7kOXUXbDLKas+ytZV6szXwAd3A5W7HdvVHRtcu0LUNlsV2wn6ekgWI7oQh0o5Ew39gizlRoRT1dvWCsDi9B0SQkHWQN9limG2MVlVLd2oaonyIEsZjzXsLEfPJ6WLYEyyKBqjfBfaqskiKdf+/R3DIIuC6YXNJZj28Dp8eVzFCkbOY0iGWr1Nkhx3GYM5LqulDCk9FZAlLTBuqXhM6S6o8Kq7oGNMVmOHwd6VQAmy0m0fxkaz1X4x59L8O4BvvAhc/rjjMSXIa3ORyfKiu6CrogIT88SH2XEPQZZzdqzTaOmXqeproCBrny0gGOWmXHfUdhccZJBV0tiJToMFeo2M8TmJ9kyW0WxFj6fXmB/Jsoy3dosCI9Wt3TCaHd1hzRar/aYA4LgRUNIg3iuFGQmus1h7XgbeuKV3Rts4+O6CcXqt/Tjuugx+fdLRTcnvlcBc3aX3kdUqR3wZ6F5BFgB89Sjw/yYC/7lIZLGULM+xDwGrm27XrihZLEnruKgbbHVBJTul0Tt6GvSljMly/nsrwUNCpiPIOvkZULpZ3NmfaJuEV+lCGAGO17ThR6vF3+zm0eI9X6kfCRkaNHT0uSmY4CLb29Tnpk0IdRe0jy8eIMhSCl9UtnT36jWypbghaJ/RIUPJThXMcjyWbZscXil6oXwXKpMUp+Q7ugtyTBYF09biBsgysOuMipXElFR+Yg6g837eGrdUGJelObkOACAXzhdznACO7oIKd3csnThnsmpaHReHibZgKDFGC71WDO5v9pSpkSRg2jW9J150ri5ocN1dsG9gYjRb7cfxFGR5ymT17XY2UDar09j7S6PZaa6s9h4T9pSK1+qFE1yX4XYuUOCuW5dP7N0FIzPIOlQh7sYPTwR0Wg2SYnRQ6ke0B6nL4M6SJnvQZJXR625t3zYoGcozthsSyjwy/Xz9V+DYR0D5DsdjQ+guCDiPy3J9Xr4+6bi483+QpWQ2BhdkybKMG/6zHUv+trFfxc5IItUeEivD54tl5R6Ryag/Drx+k2PD9urer42BKJUF49NFRTNg8JksJWBLyADcFWtRMlnO32PKZ1Fitij4pE9w9AKZerXjBl8EBVn3vXEA7QYzFozOwDdHis+2+ngxzriuvc93iasbEc5FLwDHlBwhoGmAyoKKYbbugtWtvbsL9pis2O7FXJURxd6zxOn7OGs8cMNrwHUviZ9jnao4504HcqaIzwEA4+r9PE4zwBhkhbnKFvEh1S/tHkzNYt4lv83Ebu8WF7zugpISZI1f4XgwIUN03VOkeM5kdRjMve7MVdmCrDi9BjqteKtJkmSfOLjZxUS9HiXlirudVhMKekQa3dFd0BZk9QnclPKyOo1k38bZJG+CrD6B24BBVr9MluP5W4qbYLHKGJOViBFuMllKwCjLfgoSwqa74OAKXxy0BVmFiSIg1Wgk+xjAYHUZfGN3ea+fS5t6j0101tItusoqQdnoLBdBltXiuGOtjPcEnLoL+p7JAhzjslydl8YOAw5XOu6S1ve9AByqIVYX3F/egp0lTThd34kyp7GfkcaeyTrnbsckpMqdbqULuTLtx9H3vd+xc2A0hEqeAJwCtgz32yjZstZKRyEL5+6CuhigcIFj+7Pvcvw/IyTIkmUZp+rE+/exa2dA1yDm7+tIGQ9AjFHqRfmMdi58oZRvzxTPCc3ugp5vLucmx0KrkWCyyNhf3gIAiLdNvxJ1XQa7nLK5ziatBHImi3Xn+xbn3yduZIxZBAAwaIc47j/IGGSFucpm8WWrFDdQRZMtyEr3U5CldL8YyqBkX/S0QbLNPm51DrIkqXc2a4BMVt8LnxrbBKvKBa8ifbDd4XQxwOQrAABXGj4A4Fxd0DYmq88+lbvxWUmx0Gj633GdaKswWFzf0auLl7O+F6QDFb/oV/jCqbvgJlt3rAsnug94YnVa+xfQkLsMyrJTd8FQL3zhdOHnQ5ewQ5UtAIARSY7nKBmbYGSyzBYr1tjmWstPFeP+eo1NtGWNkm3vA9GN0eo5yOqoBay2titz8AGOTNYguyV7qjC4uc+k3IHrLtjgWzc3m08OOjIi/bpZRQitxeDoTjRioeguff79wA82iTEcgCgAtPwPYv3Q28C7PwD2/m/gnTsHRsrNM39kstxRvi/M3Y5jO2eyAGDMhWJZuAAYNtcRZLWU9X7dhymD2QqjRbzWs5Jj7WPNNHlTAQBn+t4sUD6jnefJUroLjjpPLEOou6AyEbG76UgUOq0GeSnis/FwpXjNXTZD3LRVtReSGtwFWc5qDjnWp1wpluf9BOZL/4GNEx8JXNsCgEFWGGvvMdm7cqk6k7g9kzXGP/sb6hegr0o2QbKa0BGb67hbplDGZWn0A3Y3c54jCwCqbVnGxD5BVpo3FQbdWbgKALDcuhnZaOlXwr21u3dBDU/jsQCgIDUOyXE6mK0yTjd0uNxGCXQKbBfQNQNksrr6FBVQMnay7AiyFk103VVQYS9+0T3Ei0ljp7jIAcKnuyDk3tkbD8wWqz37omSyADiKX/hrXJsH7T1m9JjEhdSyKWIcSrlTkNVuC9LzUuOcusoaccYWZI1yFWQ5dwnqlcka/JgsAB4rLyqVDs8eIy6c/V74Qnn9yRag27cLK6tVtgeygCND7Q8vbinBlU9uRqOaN+psUnrKIUEWWfukHGDWjcCSB0U39Mv/IboOLXkQGL9M3JToagAOvg58ct/AExc7B0ZD/Y5x7nroji7WMS6r1Zbp7exzgTn/B8DiXwPXPCt+Tsx0BGANJwbXthCifP5oJCBRY7IHTEkjpgPoPYYZgOP/bupyBJlKd0Gl3HdXQ9ArD7vj7ZgswFH8wmzrAn/eePF5UNUaZWXcvQmyLviZWF79jCiEAQBaPeRZt6A7NsR7pPQRlCBr1KhR+O1vf4uysrJgHC5qVLU4PmgaXHxBWq0y9pU1w2Tx/a6pT5rOiKW/ugsGO8iydT9ojR/Z/3dKhcHkPEDj+e2ilG9XAh9lDpDEGNeZLI9jstwZfhbk4fMRAzNu0a3vV7XQYpV7ZZIGCrIkSbJPlFjlpoy7kokYnyuCgIG6C3bY5snKTRHHVP6fVV2iolycXoMFoz3cAYYfi18o3U50cUMvyhJo+ngxKB/wustgcX0nuk0WJMZokeMUdwQzk6W83uL1WoyxBUxljf27C6bE65Fqq4LZ1Gm0v1/GuAqy2ioc68ocfBazo8rYoMdkifdL3yqclS3d+OyoCLLuvFC85xs6DP4ZE6jQxYjqqYDPXQb3lbfYux8Djjvo/vD85hIcqGjFTn+XrB+E1G7bNULe9P6/zJsO3LVZBF66WHEBNudbACTxuuhs6P8cZ70yWbas8WALXyhBsqdMFuDUZdD2eu6byYpJAC74P0f3R8Cpy2DR4NoWQpzf+1LDSUC2AnFpGDZsFADRrbjXeywmSZS3B2wZX6fy7cPmOt737aGRzXJVXVA68BrwyU/7Bf3KuCzFvFHitdPSZVK/QnQw2YMsDzc9Fz0A3F8EzLwhOG0KoKAEWffeey/effddjBkzBsuWLcPrr78Og0H9u2bhrrLFcSHT0GHsd0Hw8EdHcPW/t+KVQM8s3uzn7oLBDrJsX7QmrYu748oXXpqLAKwPZRzKlPyUXo8rExEr7OXWfR2TZWOeLyYr/p52DRI6xd82Tq9FrK1yWpPTfpXgOzvJfZ/xgealUr4ox+eIvtADdRdUvjAK08UXonJBe6xFZDEWjslEnN79nCKAI8hydfPAJ51O47HcDVAPFZLkGJA/wAWjxSrjzd3leHqjuDM8pSAFzr1BPY098jflGElxOvs4u1IX3QVT4nT2v+vR6jYYLVbEaDX2ylu9tDoFWcodbbPTTYBBBlk5ySIb2zfz/79tpbDKwLnjMnHOWPHlb7LILqt1DomS2XAu6+2Fjw/2vqj0V9aprcdkH4gfrCIpniQYbEFI36JDrkxaCVzxuCNgGWjKD/vFnT8zWT4GWc5jstzJFpOJR8K4LOU7JSVO7yhLnzMF+Wnx0GslGM3W3j0jJKl3t9rWchFAa2PEuVQq7IZI8Qul8IVzd0Htht8Bu/7jmHDXpiDNMYVKcqwOBWnx9h4H7m5wRhyLGehuEeueMlkajWNsfpgLWpC1f/9+7Ny5E5MnT8Y999yD/Px83H333di7d28wmhCRKp0yWRZr7wuCw5Wt9klBD1UGMFgx9Tj6SIdrJss2+NmscXGxN+lSYPFvgBV/GHA3ShepqQV9g6w+3QUTh5DJAtAx5lJss0xBomRA0sc/FB9ccIxtOVHbgdYuEx764DDe3Se+jNxlsgDniYz7X2T1mCz2sVoTvMxkKd0FlTt3StB3rEV83AzUVdD5WEer+t9prm7txktbSryrsKbiRMRtPSb87O0DWH/UhwvqtBFi2eI56/9VUR1+9vZBvGf7+84Y1vs1l2zPZAU+yOqwXZwnx+kwIkO8BsubuuzdVp3vZitZXGXw94jMBGhdjBV02V1QKd8OadBVTJUxY9VOWaEekwWv7xLn+1sLRyFGp7G30+/dsJXiOT5MT2E0W/HRAfEZO832d24Y5A2avo5XOzKmwSqS0tZjwqs7Sl0GinGmFrHixXQZdsm2wLV9gPeZUnY9KReItX3HDLbwhTdjsgAg1TbJat/ugp6K8ERSJsv2nZIarwfqlSBrMnRajf0m3Jl+XQZtAWhXg6OrYPoo0W3MPo1JaGSymjqMSEIXJp94Gmg8BclqhqRkqUu39dp2WJrjxpByY2mYU2n3qNDTAvu8o5662kaQoI7JmjNnDh5//HFUVVXhoYcewnPPPYd58+Zh1qxZeOGFFyJ+/g9/61uZR7kgsFplPPjBYfvY+X4VfPyppRSALPrH++tCNuhBlrjQMGld3B3X6oELftp7Tgc3lPM/Lrf3hMx9g6x0pUjFYMZkAegyy7jfdCfa5ARIlbuBnc8AENkMQAQmr+0qw8vbSnG6XnyBjcl231VO6VrmKpOlXHhpJMc+BhqTpXQXVGa5b+sxo6XLhNO267lFHopeKKYPE68BpXKeQpZl3PXKXjz80VF8fMCLC1Vv7hwHyH82ncabuyvwt/U+jK1QMqYtnrPPStGIERkJuHF+Ib61sHemNcU+Jivw2QklA5Icq7P/zTsMZvtNBGX8U0qco7ugY640N6/LNheZLOfy7YPMSua6GFf4xbE6tHSZMDw9Hksmiwt25aaE34OsQUxP8eXxWjR0GJGdHItrZovMiL8yWceqHUFGsDJZf1lbhF+9dxjXPrUVFc29Cx/EmWzd8HwJsrzNDiq/T8pRL5PlzU0fJZMVARMSO26w6JwyWaKCnNJNvbRf8QunubKUIijK2GilmEgIVBiUZRmNnUZcrt2G/L1/g3bTY4g1O72eyvoEWU7dBZV1xyTFXo4x2/RX4J07Bh5/GKqUbHJcGqDVedw0UgQ1yDKZTHjzzTdxxRVX4P7778dZZ52F5557Dtdeey1++ctf4uabbw5mc8Je3xSzckHw+bFa7LVdxAABHljpfKfJX92xlHELwc5kaeMG2NAzpWvb2D5jTJLcjMkabFekLoMZVcjCM5rrxQMnPwPg6KZ4tLoVu8+Ii4ArZhbg3zfPweUzC9zuz1N3QWXgcnKcHvm2L4Tath6PY1WU7oLOd+4+PVIDqyxhVGYCRrq7sHYyY3gaAJGRdT7W9tNN9ixItTev675jIIKkw2DGf7eJQKnvhaRH6bZgqdlzkKXMwbZiai4evWaGPUOjCGomy6BksvSI02vtVbSUQe2O15Cju6BycT86y023v17dBW2ZLKV8e8zgugoCjkyW8xx2Sqb/wgnZ9qyaEmT1m8dnqAYRZL2+S2RBvjF3OPJs7ffXmKxgB1ldRrM9+3qmsQvXPb0NdU4BryPI8qGrUJJt24G6C9ozWTmOMVmGNrcXrLIs46mvivH6zrL+N4C9zmQ5BVnORXg8fR4p3e5bK3yqMhqK2np1FzwqHsyZAgD27wGlAI6d81xZSuViZWx0CGWyuowWGMxWZMN2ndJ8BvFKJhYAKnYDZsf7dJhTd0Gl62CBPcjy4rvMagU2/n/23jvOjepcH39m1Otqe6/uvdtgU4zppoZAeoHkJr8UIDeQ3IR803NJuekkpCekEEhCAgkhxmCKMS5g3HtZ2+vtfbWrXuf3xzlnZjQa1ZW22Ho+n/1IK42k0WjmnPO+z/M+7/8Bh/8G9B0d9/5PCtIxvbjAMCFB1r59+2IkggsWLMCRI0ewfft23HPPPfjSl76El156Cc8888xE7M4FAyXFzBb5T+0lC5Rb6cK6x+lHJJcF3HKwQTBXUkFg0pissFpNVpqIRAVRFtesYI1y6i4IMrgDwDHdQvJA9wFAEEQm62j3mNjw9551Tdi4qBo6TeJL3Z4kyBpl9TQmLSpsBnAcqVUZSiJXYkGW3aQVGZWn9pKF1RWz0mOUZpRbYNJp4AlGYlwPf7ntjHjfmY7cktU2TfCg/pfd7eLxdPnD6Uux0mSy+mhCpdKunhiQarImgsmiNVn0PGd1WczGXS4XVPZqW1hbBFXEyAWVTFb21ykLAHvH/OLC+RgNNObLZL6JarfGjQwXid1OH7adIomCd6ysF13Mkl1/meC4rD/eRATkzx3qgTsQRl2xCS1lFvSM+vGrbVKz2ZzLBbv2Al20JMEjlwvK5LUJTGbODXrwnc0n8PmnD+Mjf9wT24MwYyarQ0z4RHgDnjmaxF2S1WVGQ1JiYZqCjT9lupAkgaZMVlNpArkgu0aGzkj27Wx9odZHa5LAEh3FGjIucWNdMMiDrLAP6Dko/iuvPWUJyJpM5ILeQSBCxyN5Emo6oRBk5QerVq3C6dOn8fOf/xxdXV343ve+h7lz58Zs09zcjHe9a/o7iUwkmAyQDVYDrgCG3AGxud0nrpoBLc8hHBVyn5FlyLXpBSAFWYGxrPrJZIxkcsE0MewJIioQMq/cahAdBgHAqjC+kOSCWdZk0SCm19BMCoL9TmCkTWSyOkd8GPGGYNDyWFCTYBErg1STlVguaDfqoNPwqKAZfjkToIQnKLkqsgX3YWozfuXs9IIsrYYX60+YZPB4zxi2nhwQt0mLCVTrLp9nRKMCfrv9XMxjaRc2p8lk9dHjnyjImkgmyxWQarIAImEEpBpFyfhCJ+4XQJJANy9WYVjDgVj3PZHJkskFs0QFdbwMhqOinJGxOXLDGonJypdcML1G6/851IOoAKxuLkFzmQWl1MBm3IYwIImhk70Ty2Q9uZsstN+zpgFfumW++NioLwQE3dBF6XWSCyZr4CTw2+uAP9xCEnbMEdBSAeiMgI4mwxKYzPTJDH5eOt6Pb2+WyffSZrJojaW7D2P95JrujVjx6b8dSszE662Sy2iqRKO7n0jIUtWjTRJYoqmBp/tnLhWPWSNVfMTJBWtXkNuuvfFyQeZIx8b1SQRro1Cho3OhZwDm4EDsRu07xbtmvVZ0IWRMFpMNphVksbo+oBBkTSNMSJB19uxZbN68GXfddRd0Op3qNhaLBY899thE7M4FgVAkij4aOC2tdwAgE++/DnQjHBWwuK4Ic6vsorwkb+41eWGy2GJHyL4wORMwJkvN+CJNsEVPiVkPrYYX65wANSaLGV9kl41mr7NZLaL0Aj0H4DDrxV5WALCkzgG9NvUlnozJipF7AKgqovKGJFI9D63Jshi0+Nl7VuByyl6ZNQJWN6Vf7Lqo1gFACrL+QRla9p3SY7JYDcTEBVlDniB6Rv3gOMmRMe3rz9FEbp3nk0qF2LVfVZSAyRKb7k5cTZaVBlmsdm83bbLpktVlrGouhobncMeyWvzgHUvUTS+ULE9AIRccB5Nl0GpENqh31I9+lx8DrgA4DphTJdVSVkyRmqxzNMt/SQtZlJRRFzOXP5ye8UsStA15xP5mAOAK5Dcgb+13YX+7E1qew50r6rB+djlmV1rhCUZI8EVrpgSdRdYzLg1YqZGOMtB48UukoXXQTaRbAMBrpYJ7xpgksANX9iIT60MjYSn4ScVkmUsALTlfX311C3lfQUqGqYLj0ldz7Pgx8Mo3gDceTb7dJIHNHxVwkgdsUlKF1WO2DXli5Zis6XT/cSmJy+SCLKhlTOIkgilXSnjyO3IQ4PC2kSc11G1QYX6xuK4IHCcx+ExCmNb8IA+s5AHXdEDfUaDn0KQpSyYTExJkXXXVVRgais88OJ1OtLTkqIHtRYbeUT8EgSw451SRQXvAFcA/9pEL8e3LiUyB0dEJB/TxIBoBeg+R++lY7qYLrUGcmCZEMuhnNVnjD7LKaKaZZfUBlSBLxhxlY/bCmvuWmPVAzTLyYPcBALGSpxVpBjRJa7L84Zhtqu3xNS1KMKbNYtCgodSMP35oNR7/0ErctzCS0rpdjsV1zPzCCUEQ8NJxsoi6YUEV3d80glQ2qE9gTZZb1h+NNdpN23zGUQ+AI6xNggy7IAji8a+0JWKyJs7CXXIXJOfITYtIIPH66QF0OX3iOWQz6rB2RhkOf/U6/OCdS6FNJGFVZmnj5ILj63dWJZpf+HCcuus1l1lgltVO5q0myy5jstIoXmeMJaslsxt10NLAdHickkF5PRaQfybrGD3WyxocqLAZwXEcPnI5mf8f23EOUea4aKvKrL6XsV5y44szrwCnX5D+Z0GWpVzqd5hCusmOL+vjJjbY9juljVI5pHGcKBm0jZA6mmEaZPWO+uENhvHb7efi6zbFIMuJpKDjPgZPJ99uksDGnzLOSR5gATGI6YOG5+APRfGJP+/D3/bQwMFWSV1WBRIka/SS4QVbnKfqiTYBYJJdBy/9dsVeyrzNuJrctu+KUeP87L3L8dpnrsKMcpJ8Y+uz3tE0SjpigqxpxGQFvcDvbgQeu1EKmlMxwBcQJiTIamtrQyQSP6EEAgF0dU2+S8x0BKOXa4qMYtZ1d9swjnaPQafhxHqsukzdazJB23YysRkdQN3q3L73RNZlBVifrOyNL8Qgy0YyWHZZ7YlVEWSx56IC4AmqL7S8wTDufWIfnj8cn/Ee9pCJq9iil1wPu/cDiJU8rWjILMhSW2SJTBat8al2xFtgx+07DTLY9+Y4DmuaS1CTocqLBVlHu8dwss+FtiEv9BpeNPFIywJ/EuSCHlmQKVn0pnn9aQ0S25GgLmvUF0KA2uoz+ZsS9iS/aa7BmCob/b0bSy24pKUEggD8fU+njA0lz5sVRjBxYM5hrG6GBVnB8ddkAbK6rNGAqlQQyKO7oKUC4HhAiEgsaxKw64wFhjzPiZKj8ZpfHKESXsY85vtcYY6IFTKJ661La6DX8OgbC2Csn0gJhUz748jdBVnS6rX/i92mczfdVtY+IoVTHTu+S6hSZMwfJokoxqIYitJzSKNBVnOQBEI+PRmXe0f9+Me+LnzjuWP40UuKICmd+U8QgL7D5P7wucTbTSIYk+6IUKkm+61AEsQz6Ln3/JFefOHpw/CH6HxYt0p6k+JmYt8OyOzdhybdFIQF4TZINWXWAA30Z19HZM1+JzAoWfGbZRJ6gNR+atIt6ZiuQVbvISAwShjlM1vJYwUmKzd49tln8eyzzwIAXnjhBfH/Z599Fs888wy+8Y1voKmpKZ+7cMGCZcZri03igoCxVZfOKCMLcMgLKzNwOEsXh58itwtuB7T6pJtmjIkKsgRBZnyRfa0HW4ylw2QZtDz0NIuvVgcFAJuP9OK5Qz34vxfie6UwuWCJRQdULyUP9hyMMb8AgOWNOWCyFHJBqc+QD95gGEe7Y3+faFQQA8eUi+kUaCq1oMxqQCAcxX1PkCDykhmlYuCSUi4oCJPSJ8stCzKz6oMi1mW1qT7NakUcZl1CZnBCmSxFTRYAvHMV6Q/0+JvnMUAX13aTulQ8DmwBways42qyxhlkiQ6DPrEP2zxFkFWRr5osjZYEWkBakkFmNS93j8xVXdbrp8m1cfVcsj/5D7LIuFVmkeYKg1YjyqeDI/R4WLMMssJ+Ml8EXEAHDarm3UpuRSZLHmQlZ7KYXLCu2CRKTDtHvLJ6rDSlzyzIEghTo7ESVr13zI8z/eTcjpOLMfOLZPPfaIf0/EjbpAcdahCbEUfoMZMHuQB++p7l+MLGubAZtAhHBbE1RUyQxaSCgCTPjIYmppQgCViQZRXc8U8WNQB1VPZ4fmf88xQanhOTPiklg9MoyOod9ePHL50mgSMzngGk1hyFICs3uP3223H77beD4zh88IMfFP+//fbb8a53vQtbtmzB97///XzuwgULafI1iQt7husXSNkiVliZcyYrHACOkQAaC+9Ez6gvoYTMHQirLuCTgtVl5TvICvlIVhlAaFw1WXQBQX+L2Jqs2IUwx3Ey9zf148L6W50b9MSZF7DBvdisJzVZMvOLVU0lcJh1uGxmmZjxTgXGMLgDYYQjsUYjcmc4QKrJ6hn14xvPHcNNj2zHi0elgnNvSGLmlAxepuB5Dp+5bjYA4DRdjFwzr0JclI36gsnllkEPWXgBEysX9MuCLPH6yyDISuEwyK79qgSmF4B0/rkD4aR2+7nAmKImCwBuXFgNm1GLAVcAwXAUVXajGHCmfkPKLMiDLEGQWbiPUy4ocxg8ruIsCEhsi8sfFpnJnCHNhsT+UES81uW/NavLGg+T1Tfmx9HuMXAcRPORfJuksKClVDFfsXGKyQUzZrL0Zon1dPcB7W+QMb24CWi6jDzOZHcyJiVduWCpRY860czFl76zIANrLkxhdJDv1zvmFyWIcb8lSzL6nInft/ewdD/sS90nbBLA5g9LkCoK5McfpOn8R6+Ygdm0HpKN8zFBVomspERvloxvJtn8gv1mpohKkGWrBBrWkvuKfllKSOYXGTBZrh4gMjHNw7PBr7adxQ9fOoU/7jwPdO+L32AS+lZOFvIaZEWjUUSjUTQ0NKC/v1/8PxqNIhAI4OTJk7j55pvzuQsXLLplckHGZAFEAn7tfGkgE5msXNdknd5CKGBbDfw1a3DTI9tx0yOvxwVT0aiAmx55HRu+txW+BNI4VUwUk0WzYQLHI8KrS6/SwWASJkst2EhlTCC3LWd1IwwSk6UnDGLlAvJEzwGUWg3Y+fkN+O3dK9PedznDoMxmS85w5DvUyJisl44TB7jnDkmLRSYV5DnAqBv/8PKOlfVY3uAQ/796XqUYZIUigmhnr4qhVnJrKh73wjwTiO6KBm12118Kh8G+seTOgoB0/gmC5P6XLyhrsgDAqNPgwWtno7nMggevnY0XH7gi/Xo8xj4yx9JoGIgEc85knRv04MwAuc4WKJgsu1EnBjOt/SqLqPGAFf8nMFxgYL+zUceLbDMgjTFKY4ZMsPUkuXYX1zlEA4JAOIpgOH9uriwRVWqNTf6w65lj7oCZMllArGSw7XVyv+kyKWEhbidLtjB79QRyQba/JVaD2GQ7lslKMxu/8h74Z9wo/msvI0F236gfHbQWK+63TGf+6z0S+/8UlAwyJYQxSGuoFEwWAzMIEq+1qkWSeUSJom5fdBicXPOLYU8AHKLQRzzxT9qqgcZLyf3zKYKsdHtlxbBXwpToFZYIrXRc7RjxxjJZDAUmK7c4d+4cysounsh1ItAravVNKLHowUy6ltU7xB4vAGLkStmYLCTEmZfJ7fzb0OEMYNgTxJAniKf2xLre9LsCOD/kxZAnKPajSQsTFmTRAMZgG1czZSaJYgGvPHBRk82Jcq4EDB9jsgDSkFcOkcliTFUZYXuYvMys18KgTd9gQqfhYaGW88ogmWUii8yMySLnVsewT5RIvnZqQCzalUwvtOBy0Jya5zn87+2LYNFrsG4mkQqadBpRbpnUxp1leqsW5a5RdhpQkwv2ufwIRdJcwKZgsiT79sRJAaNOI7ow5puhYK50ymTC3eua8epn1uO+q2fFMLsp4aEZ6uIm6bGgJycW7oB0Dr/VNoKoQORg8kQVw6wKkl0/1afeRylryM0vkqB3VFIryK+l0hzUZL16ggSyV80pj2Eg83musJqsUos6k6X1UnfBTHpkMTD2y9UHnGNB1uVSwoJBjckaVQ+y2DhbZtGjvpicc50jMiYr3eJ9vQXnrv4lvhi6By9zlwCzrwdA1AAdwz7xs2IYZ6OD3CYzvmCmUwwJ5MVqCEWi2HykZ9zmKckg7H8cf498Ci1cN/Q+mjhRMFkMM8Ugi15rWgPQsh4AB9Qr6r3ZcZ9k84shTxA2eMEhdl0l8FrCctauJFb8Y52AM7EbYJVKg/Q4hPxSWwtmtjKFHQbFRvQjg1KvMzkuoiBrfHqeJHjkkUfw0Y9+FEajEY888kjSbe+///587cYFix6Z65SG51BiMWDQHcD1C2KzgKwfgzsQxpg/HJMRHReY1KV8Tky9yWM72nD32ibROUzeaPBY9yhWpFknFNMrK59g76/PwDJYBZJckCwYUjJZzGFQtqg53jOGU30u3LK4RtKmAziiqHuKcRcEqBMTpGaPWaDIpIMnGIkLshhrxs6bSrsRHBcr/x/1hXCgYwQrGksk+/Zx1mPJMb/Gjp0PXQ0TZUI4jkORWYcBVwBObzCxDE0MshbnbF/SgVwuWGrRQ6/lEQxH0TvqR31JGgFCKibLlVouCBD2cdAdJGxk+s75GYN9X7sxR7+5V5b11hqJ5DPozomFOxBb3wQAd69tUk0IzK60YtfZIUnClCuwgCCFXDCRLFSqycpugRwMR7G9lRzjq+ZUQMNzsOhJ42+XPxwn58sVlGMkA2vObmQL8UzlgoC0eB86DfQcIPebLot3/1MzvvAOkkWsLvY4ixbdVj3qS8g51zHsBcwZygVBTHoej1yLN21vw2MVNQBOxsybUYEkjESJd1pMFh3fSlqA4bOSc1sa+M+hHvz3Xw/gHSvr8H93Lkn7dZkgevCvmMF143p+DzReGiAk+G1nKpksALjjV+QaqZwfu7Hc/GISMeQOws6p1LpbK4mDpcEKVC8hcrn2XdQ5Nh4saZI04GVsq9ZEkobntk3ZuqxQJCr6A5SMHiMPOhqJvJGx9xeRu2Degqwf/vCHeO973wuj0Ygf/vCHCbfjOK4QZGUB1siQub1dO78Sr5zow21La2O2Yw3whj1BdI34chdkMf23tTKm3qvL6cOWY324kdo4t8saDR7tnuJM1jigtHBPVpMlf14uz3vwbwdxrGcM/lBEdI8DgKNdscdtWC4XBIAiOngnyZalgt2kQ/eoPy7IGlLUmuk0PMqthjhDgK0nB0iQJUrl0mfS0oHyvC2mQdZoMvMLOZM1gfDI2Dye51DrMOHcoAddTl96QRZjskY7ic03H3sse0fjXdrUYDfqMOgOpsVOBMIRnB/yQstzaCgxJ7ZXV0AQhLg+WeOG2EuljDRmDftJr6wcMVlymaXNoBVNOpSYVZmYyQpHoogIQkaMsfSh6ckFlc6CDExul61ccH/7CNyBMEoteiyi/XpsRp0YZOULIpOlCOKKzToAgtjIVchGLsgW70f/CQhRIjVlckBLuSRBlRtfmIqlIN7VE9PrMRIVxARTqcWAOjmTVZIhkwVpzC626Kl9fbxPxZA7oBJkOdXf0D8qMd3zbiH9sjKQC56k53TbYG4MsX617Qx6Rv343A1zRVlw1NUHDYDFmnPg2DybSC5Ir7Vzgx6EI1Ey/piK1S3yGQsyyUHWsCeIFsjaSoRIYlSwVkFM2TSuJUHW+Z3A4neovg/7zZP2zWRBVlGd1OB6ijJZnSM+UdlS4z1O9HK1y0kiw9VN2D1D0eTu5AQib3LBc+fOobS0VLyf6O/s2bP52oULFv5QRLSvrraTDNu37liENx66WrU5abWsL0zO4KaZKWulqCVmEq4ndkuMipzJmpJBFu2RJYwjyIpGBTELxWRHMe6CKqyOaHxBg5poVBB1zL95nUyWxVSi1zrgFq1tfcGI2EBUlAuyDNk4mCw1Zk0QBDHIkptoyJmA62j936u0xsOjsG/PFxwmsj8J5YKCAPTRmoUJDrLclM1jQUdNJg0nASJj4nXEQUtFd9+fJpMlOQymXji/45dv4LofbsOG77+G9/7mzfT2E6SOJ0wnVFsmksBEiEYAH7V7tpRJtXRBj8zCfXxBls2oE8/P96xpSLjfs+nC73RfLJMlCALe9rOd2PC917JrCJyhXFA5picyvnhydzu++uzRlEYnbJxZUu8AT3Xm7FzNl1zQF4yIrqPKmqxisx52eKAXaNCYQFKWFGzxzuyymy+XnpPVZX38X5144yxdnHNcQvOLEW9QDIKKzTrU05qsjhEvIkzOmqpHVsz7keNaYibMtlIyCSiYSfbeiea/PtJzC/Y6qVdiBnJBxjQMjNOhEiDrkW8/f4KoWB7bLcqleZqIvYQ/TjbUGiWDEgVqioww6zUIRQScH04R+IlB1uTJBX3BCHyhCOwcXd8U1UFgY5X8/GUOg2wuUkFxOkwWY62K6qTkwRRlsuRrvvkCrYuuWSa1mzGXSL3qLgJcPN/0AgKbfM16jbhYB5CwBkZyGstiQaAGQZAxWRXi4vHqeWSikxs1yAfMk72u9OtSJpzJUh/808GINyhmblgwwoIWs14jLmTkEI0v6KJmwB0Qi86ZPGl1cwnKrHpEogJO9JL9ZBlRvayOSmI+OrK28VWzcXcFwgjS30u+KKgukuRaD1D3vyNdY+hy+sQJdrz27Sn3lwagCW3cneeJFFSjl2rWJghuRY1SbabmF7xGmkhVAmexEXEqJkvslZV84dzl9OFgh1P8f1/7SNr1m+z85TjAnEGj6YTwDgOsxsFUQpgsIKdyQQC4fkEVah0mfOiy5oTbzK4kny0/rwHCPh/uGkWX04fzQ1kwAazmKEXheq+iETED+93PDXpEM6FQJIqvPnsUv9/Zhr3tI0nfl52HzMwByCwgzwaMddNreLGfGkOxWY9K2qw2qLFk9/vaZQoOex2wVqaOkdVl7ejl8e+D3fGvU/wWLIAtNuug1fCiA5w3GMGR1jYAwGDUmvbujYh1tKy2NT7Iillkp5r/GGtVNksyiMlALthFDTdy0QeuZ9QPFte/cXYY//P3g0A4AE3ACQAoBk2uWisS1sZyHCdKBpVJjThMASZLdMrk6ZhkcojnUow7JjPtSMIyFlO57MiFEmTJSh1mcZSBq1oE1Cwn97NhqqcxJiTIevvb347vfOc7cY//3//9H+66666J2IULCt1UKlhVZEzLXIBJt7y5chnzjZAsOwBYK0Rt+bqZRCs96A6Ii6/zsqxGMBIV3bxSYsKDrPQnTCVYNrDYrIOOsnksiFL2yGIQmSPq3tc5Er9Yaym3YkENOQ7M/EI+WYu/PVsohLxZOy6x/ZUHWcN0oWHRa2DSSwtolllvKjVjbpUda2eQSe/X286K5iY5k6UmgIO+f0KJBZMKVswDNPndFyWkujRyzFjRfFsmC3K2MFSYX4QjUVGaWqmyUJMjlbkKw64zZLEyhzI3KV0bZZDXn6klEzIGy06biklPKTmTxeSCOXCK/P47lmD7565KGqg6zHqRmZbXirBaKSBL19aiegAckYIlKd7vSVCTNa/KjvoSE9yBMJ4/Quq6TvS4RInxyd7kRh1svJbXMtqM6QXk2WJIVo+lnLNKLHo0cCRp59NlWasx+wYSWN3+c+D+/ST4YKBJqBC0GIMl1pI/QUNitohmSTODViMazRhCZCw+PBy/fPrzm+fx/t++mbztBtRZ6Bj5ZyoLdyYVc9RLBjGeAWk+SwF2DrgDYXiD41sXsGuAXf67z42oN9pOwVDOLFeYXySCGGRNnrug2FrBQH8zYxEEJgOWf08WAHsHE/42rLZ6JJn0nf3eRfWyBNzUkgv6QxF4g2FZ4klAHcecYpuAmVcDV34euOGbk7WLk4IJCbK2bduGjRs3xj1+4403Ytu2bROxCxcUEmU4E4GxCu5cBVlMKmh0AFqDGPTNqbKJTTzPDXggCIJ4wbHiTmV9UUKk0qRni7OvAX/7AHGhAqSBbxzGF4Ou2LolAFhYa8fls8pw99om1dfYFc1iO1UWay1lFiyuI8dh73mSnVZO1gBIwTYb2BM40qVCkSLoA2QLDYW8h2Uc19Kg+hPrZwIgcqXHtrcBAG5fVpPVfqQLqVdWgolpkuqxAJm7IF24xjlnpQPGTirML5gjnl7Do0xFciSHWt2fGliQtWFeRXqujTKw91ayE1lDXo8FSMmPoDtnFu4M6SSoGJslr8vqHZMWw52Z9D9jMFil+h95ryMFemXJNDl4nsM7VxKJ8F/eIgutAx0Se6UWZP3k5dN42892wOUPiWNNrQqTla+arEQ9sgByLc+mGe8xU112H2CwAtd9A1j6HtLWQg6asBgQigBwomwRgCQXPPoM8IMFwDfrgJ+sQKCXyA7lDH59sRkcoqjnyPy33xkb7AfCRDb3+ulBvHYqNsgYUdTRyoN7xnjHyAVTJRnli26TQ5IXJjDLid3PaExNLZu/skWXk1yXTF474g1CUJPCpgqyKlXML9TAgqxJdBccYkGWnl7/RgeEWdchzOshNF4mbWi0S/ubQM7J2E1fKJK4zQ0zybHXUGUGR6SxSRodTyT8oQiu+t5W3PKT7eL4U4YxGLkQBHCEXeY1wFUPAc1XTPLeTiwmJMhyu93Q6+Mbo+p0OoyNTW7X7umIHpm1bzoQmaxM+lQlg8z0IhIVxKCvxmFCcxmZeM4OujHiDYmT9nW0QXLadVmihW2OmaxdPwWO/Yv8AaTXFwDBmL1cUGl6AZDM558+vAafvGqm6muUNVBs4SNf97WUW3FJCxmgd7QOQhCEuMlahGN8xbBqckGxr41iMX/nijo8+p7l+J/rSbPYdTNLsaTeQfrsRKK4fFZZnMtlrsEcyZwyJqt9yCvVyEySsyAgt3An190s2eIh7TYKKkxWMBzFl/5FtP1vX1GbkjliC+dkjcAFQRBrVC5tKZXJMNNbeLHvmpN6LECSADEHMZHJyq1cMF0wG/fT8iBrdJxMFiAF//JaDc8g0LkXAGEsmZRLrc72zhX14Dlg97lhnB1wY3+7U3xOLcj6y1sd2N/uxI7WIZlcUKpts+c5yErUIwsgY9lsnoxbLmNt3PPjRvVSAMDpKHnvWCaLBlk9B4jVdtAFDLXCcfa5uP1tLrOgBkOwcAEEBQ1eHYhVP+xoHRSPX4+isWwyJosl0obk9VFs/guMAVEVib1cPgZIbFYadVm9o/4YVfmAO0UT3BRg5xNTXUSiArzDKlLYBKYXDM20X1tHqmtqCrgLisysjh47YxGiqz6K/yz+FYT6NbEbMzYrgWTQatBCpyFjeUJlhrw3W1EtsOKD5P//fAaI5LcPYjo4M+BGz6gfZwY82EXnkxVFZK3n1lfEJz4uIkxIkLVo0SL89a9/jXv8L3/5C+bPn6/yigKSQXQWTJPJYsYLnnHKAkSIphcVGHQHEIoI4Dmg0mZAC6X8zw14RKlgdZERyxpIpu1od5pBU77kgmzfmTwkF0yWokdWOlA2I2ZywQ1zpIloRrkFKxqLodfy6HcFcGbAHd8ji2GcDoNFpnhpmdgnRrEwMuo0uGlxtRjocByHe2kwqdNw+NqtC3LSIyv5/sbWZL16sh9XfPdVfPv5E2SDwVPktnxuXvdDDXJ3QQBoLLVAy5MMek+yXihyqDBZv379LFr73Siz6vG5G1J/r6oiqVg/ETqGfehy+qDTcFjZVCzKMBPWuinAZFE5cxZkckGW/RVrsuR9siausfRs0WFQXS6YtpmJEpU0yJIzWX/7APCbDUDHbgy4A4gKgJbnVBnLqiIj1tOx4i9vdeCArKbuRO9YXDDPruWj3aNiCwA1uSCrJ8w1hhIkbACSMJnDkaBhxJAlk5UMNUvxymVP4lOhTwJQBlmyoK5xHXDFZwEAVicxa5Ans+7bMAufXU6O61mhBsf7vKIhEQD855DE3iivc3Y9sfeTB85L6x0AFEYmbP4ToiTwU0IMsupjv4creVsAADHW8QDQPza+uizG5raUW2CmEmmfapCVnMliCZ5U8ubJNL74w842XPF/r2LveRL0lLCaLPZ7cSpLajEAVg+yOI4Tg++E5hfMDIgxlld/hdzvPwrsfSzTr5FztKtI4S8rp422dVkY2VxAyG91OsWXvvQl3HHHHThz5gw2bNgAAHj55Zfx5JNP4qmnnpqIXbigkMh1KhHMdLHnzZXxhYzJYgN2ld0IrYZHC2Wyzgx60EIvvIYSM5bRiWR/hxOeQDhhrZIIMciimbxcudEwiQGbjOQW7ln2HGUZ57IM+svYEsgFr19YhWqHEVaDTgxiVjYWY+eZIexoHYrvkcUwTodBNXfBRM1D1XDNvAp8/bYFqCs2iYF2PsHkgkzW9hfqaCk2bg7QRbHJkfd9UcIlq1MCiO19Y6kZZwY8aO13oyZRXy852MRMmSxBEPC77WSS/n83zRPPjWRgUrdkheQ7z5DrYWm9A2a9Vpzs0w+yGJOVK7kgzU6LQRYNqAKTw2TNqSJB1rEeKXDpky1KlQvWtMGYrF7KZI31AOd3AAB69v4bX3eTxWql3ZiQsXzvmga8cqIff9p1Hj662Oc5Yl7RO+YXlQ7MCQ0grRYEATBo+ZjkCZN75k0uKLL98eetXQ8YObIo79XmIcgC8LqvEU60AVDI5qsWAryWMFrv+CPQfwzY9l2UukiyRi5vbCg1o6HODRwD2jX1CAcFHOsZw/KGYgTDUWw5Jg+yYs8LZXKMzd16DY951faYbQAQCbjGAEQCJNHI5kOAmBspmSzRJVG9sbIcXQqWbbwOg3IjlWKzHt6gD6FRMr+OCWapl1QKJktpBpUQbGzwj5LeS3msuY1EBXz3hZNY3VyMDXMr8cddbWgf9qLzLfKdHDz9bsnmmZLkTBZAgu9+VyAxk6VsgG0uIQmBF74AHP83sPojGXyr3EPpCGnU8VhgcgIAelCBponfpSmDCWGybrnlFvzzn/9Ea2srPvGJT+DBBx9EZ2cnXnrpJdx+++0TsQsXFFiWrCZduSDNLuWOyZKCLCaLYAvHlnKyKCJMFrnwmkotmFlhRX2JCcFwFK+fTiMDJbr9CeqZvGwgCFJBLguyRAv37OWCYtPQFEYEckjOb+Q36RiWAtL/vX0RPn+jxFQwQ5EdrYMx/VZikEe5oLImSw0cx+EDlzZhw9yJyVoxC/dRbwjuQBivniS/q1g8PAmsBwO7zuQ29qLsLN3GtozJGusGwgGcGfBgyBOEUcfjpkXp1buxz2wb8iS0Gt/dRibvS6ksVZQL+tKTCyoDynGDZadFuaCMycqRhXsmmF9th4bnMOAKoI8mU8ZtfAGQxT1A6irCAeD0i+JTbftfwfNHyIL9Mnrtq2HD3AosrLWLAVRTqVlMcJyQSQblC7fDNAlRW2yKYZvzXZM14vZCg4iqXJAbaYOBC8ErGNDHJf6+44G8FjhGNu9oAD65G/jYDnLOVZLfpSTYAzs8Yi2xiAESfHntMwAAhzqcGPEE8diOczHOjN0KJkuUedMkxrxqO0w6DVY1F4sKiEFl3zO2cFeqObxDpLcXZBb0omNlaiZLyb6O12FQbqTCmLoIrXl+IzpP2jCFq5zaHKQKUzHAOlExhidPePVEP37x2hn8z98PY9QXwpkBos5hbop2UGMveRCsRBruj0mZrGhEOgfkbQPqVpHboTMpv0cytPa78cddbaJDcjZop+sXvZaEFE2lFpRHiWrofDQ/1/R0wYRZuN90003YsWMHPB4PBgcH8corr+DKK6+cqI+/oJCoSWUi5J7JkuSCbMBmQRaryTo36BH7JTSUmsFxHK6dRwbZLcf6Un+Gzkj6agC5kwwGXCQzCEiTUQ6aEfeNpWepLYckFwwhEhXEiUpuq8zAgqxdZ4fEIuUSsyJ7xxoUZslkqU1wrLg3bqExBeCQBQMvH+8T7e9HPEESTAfp5KefuAU5QBgnpVwQkJtfuNE35k/NgFjKaDBBstZ7aDC0tN4hTmSpUGk3wGbUIioAZ+niQBAE/HzrGWyniQ4WJMyksrhM5YI5r8lSGl/EWLjn1vgiHZj0Gsyiv90RulBnzaABoM/lT78thRz2WrJgiobJwv3UC+JTS3AaVl0U//j4pfjWHYmNWziOw6evkdoTLKl3iMzbqQRBFkOtgk21pcsiZINIGJ9p/SCe038BpcpxCyDsEYDTQi08kdw2MQdID0LmegqoGECVziAGBQBhCOhYOpdrjw8KB4ghhq6KBA9/2HUea771Mr5FZcrLGhwAgB7Z9U1c18jc66AmB2VWA9546Gr8/p7VCfueJXQYZGO8tRLQ0sQeC7ZSNLgGpKCISfsyCbIEQcCzB7vFUgB5TXZtsUlM/nE0yHozOo8YH7D9TQKWePSHosn7z/EaKdjIs/nFoU4nAFIS8K8D8SyhRUgjyEqTyQIS2Lj7RyG2tWC1egBQSuu9xzqlBFSG6B/z412/egNf/tdR/PnN7EyzAEkueN9VM7GysRgfXNsER4CssU4Hs3QMvUAwoX2y9u7di8cffxyPP/449u/fP5EffUHgtzva8OirrWK2YyowWV2KIKu+xAwtz8EXiogNalngdc18Ihd45URfelmTVDa2mUJuK6smF8wSTD6UUZBFa6DCUQHnhzwIRQRoeE7V2ndRbRFsRi1c/rBYVBrPZI23JiteDz/syVwGOVFwyPpkPXdIyt46fSEI4QAg0El6AlkPgDh3hSLk3JbXKTHzi33nR3Djj1/Hxh+/nthJCiAOKGJdVpvIOK1qSn/C4jhOVlNEzvPDXaP4zuYT+OI/SS2QWHdHz6eUro0KsJqsnMkF45gsykT6nVLbiAkOnBfV0jYK1LRHzmQJQqwRRjrwhyK48ZHtOBaliZHOt4CzrwIAopwWZi6A9fZerGgsSWlusmFuBZZQ44RVTSWiDb/c/GLEE/9bKpM5eWWyxrpQHe7EPL4DNRoV5qGfBCinonXw5OHjzw97YwIrTyCc3ICGSjkX8G2xBkOCIAZZZc3EUOfcoAfBcBQzK6z40LpmfPsO8viAOyAG3yzI1fJcjAtnEW35UULl2KO+UGzAnqguWSkVBGRyQRUm6+g/gcc2is8xuSAz3MgkyNpzfgT3P7kfd/5iF5zeIPrG/AhHBWh5DhU2o5j80/rIXNshlKO/5moSuFYkryO1GbSi8VPK83CCzC8Y8wuQmlglTBF6ncmDHyWY9Hu0k8gb1TahwfewWnKLsXV6W6yBhLlECjaH4/ctFUKRKD75xD6xpvz3O9tSNjJPhPPDJNhc01KKv398Ld69ugEmDzlPT/iKEc4mEXWBYEKCrP7+fmzYsAGrVq3C/fffj/vvvx8rVqzA1VdfjYEBlX4KOcajjz6KpqYmGI1GrFmzBrt37877Z+YDv3r9HL77AhnkrQZtTCPiZGAW7rlzF4xnsmodJDjQaXg0lJBFkNMbQnWREevnlAMAVjeVoMikw4g3JFqSJwUbQHIlCZAPyEE3kQqO0/hCEISsmCyTTgMtXUCxLGuNg9S1KaHhOVw7n2QB2eI3zl2QFUAHRrMKSqWaLGkBMpTEEWyywWqSAuEottJgHiCZ1TG3TF6ag55KmUBeVG+RNWSeQWVcJ/tcGPYEMeoLia0PEkLmMPgWDbJWZhBkARBZGGaLzBIC3aN+CIIgBllMEqrm2pgMIpOVMwt3RU0Ws3CXZ6wnOHBmi9EjXWMIRKSGvSz5oNZ+IRleOzWA4z1j2OmhC+NdjxKWzlaDnvJ1AIDL9KfSei+O4/DL96/EN25bgHeuqheZrERyQYZETJa8v1PfmD/rRVcMZHbelREVa2/KZJ0U6qESD44b7Nph10JUIGxJIghUMjifO49yeYLJ1UPc/jgNZs1bCr2GB88Bn79xLrZ8+gp8+Zb5mFVhhV7DQxAkhQMLcost8T3CAMIes1h6JJ2GxCzIYok1AGA9mtSML/b8Dji/A/zp5wFIcsGl9WR+zaQmi5UADLgC+Nq/j4nnfrXDCA3PieOH0U+u136hGH03/gb41IGUYzHPc6LsOGWSZwLMLwRBwGGZzLRjmHzXhbVSaYE+zIKsJEyWtYqocoSIupw/HMCS4AEYEFRnspSmF3IwNmuoNel3UcMTb7bjrbYR2AxaWA1anB3wYHtr5sczFImimwbujaV0bBYEaFzkPG0XysZd9zedMSFB1n333QeXy4WjR49ieHgYw8PDOHLkCMbGxnD//fenfoNx4K9//SseeOABfOUrX8G+ffuwZMkSXH/99ejv70/94imGty2twV0r6nDTomr87+0L03ZwYxbunmz7ZAkCqRtgkDFZ4iArY9VYXRYAfPnm+WKQp9Xw2DCXsFkvn0hDMpjrgVTZINHVI7Nwzy7IcgXCYvDKmlWmA47jxMDmGM2Q1zkSLx4/c90cmHSSlKZYaXxgsErZtDQcppQosehh0PKIRAWxbkisyZqCckGLXgpSQxEBq5qKxePjGnOSjXjdpDUiNuk00MhYiBnlVigv1yF3EAc7nNjwva3YfERafEaiAh56+jCOeB0AAHfvGXQM+8BzwHIqR0oXsxRMFmMng+EoxnxhscaPmZswJitpY0wZWMCRc3dBJZPFEjscD2gm9nxcVOcAQJgsJ10DWQ1azKUBTaYOgy/Q3/potIk8wLLQs69Hq4kwIYsjx9N+v6oiI95/aRN0Gh7zqsgCsLXfHcekyCG3bwfimaxNh3uw5psv47fbE0uc0kVUNh6VBDrjN+gn3/W0kD2TNeoN4XsvnMRBmcsiw6snyLlzw0KpJiiZqqPPQlpTLNScR1OZLDCg9VgoaUFJkQ1/+f8uwbP3XoaPXTlDnIt5nhObhDNZv7IeSwme50Q2K7ZXloPcsl6RkRAxQFBlsmhNFkseysESFJ4BRASghyZamLQxEyZLbjP/zP4u/GFXGwApaCdzhQBLiHzmgFCEKoeJSPzSgJqiQhVMeqjWjytH6BsLiCyPHA9cO1u0XNcG6bFOZnzB8xKbpSYZfPOXuOvYvfio5jlxPI6BaHqR2yCL1cZ/4qqZuGslOZd+v7Mt4/fpGvEhEhVg0PJin1S4+8GF/YiAR69QmjHbfyFhQoKszZs342c/+xnmzZOKIOfPn49HH30Uzz//fF4/+wc/+AE+8pGP4J577sH8+fPxi1/8AmazGb/73e/y+rn5wOdvmIPv3rUEj753OW5fln4/EfN4Ldz/+j7gB/NJRiUSEhkhv6FUXLyxDCoAzKUT/eWzymImNgBYSKU33c40Lrpcd3ZXBlljXeNmsvppttJm1IrHOV2w3jSMyVKrx2KocZhw7wap55Zq4MPcm5TfMw3oNDzWUPODbacGEI1KPbmmolyQ4zhUU/b0+gWVeOye1SimAYKL9d6bYFkZALgC6pbmJr0m7vcddAfw4rFenB304LlDUi3F/vYRPLm7Hc+cow6U3YTVmFdtz7j2SekwOCyjCk73u8R+OezYyQ1F0oFbdBfMQTArCBLbrKzJYue0zoK4aDXPmFtlg5bnMOQJos1FPruqyCguLDNxGAxFonjpOEkwbYquwZH69wIL3w4sex9w+YM4wJE5stl3GEi3p5oMdcUm2AxaBCNRkb1kbCVbGAKxjYiB+MbVbB//czjzhI0S/mGplsXqVQRZkTAwTAr3iVww89921BvC+3/3Jn76aiu+uSk2OJUbLV09r1KsQ0qWcDwUIQzRTK4LOkG2HZUKopwEYcsbisX5TA6WcGTBt+QsmPgaEeuy5OYXSibrqbuBHy6QTFKKZEyW3gIY6PbKJBtLXLj7MRokSRydhsOCGjJPD7oDaTOWLOhgjNN/qFS7liYIiy162OGBViDjh1PjSNk0XQ52HqZksphlfRpuitmC1WO1lFlgoHWwGp7DpS1leORdy/C/t8wGH1ZYuCdCMvMLel6t4k9OGJMlCAL2tZP3XdNSgg9e2gSAtEJJV8XA0C4z7eICY2S9+Pr3AADDfClC0GbM9l9ImBAL92g0Cp0ufoDR6XSIqjXayxGCwSD27t2Lhx56SHyM53lcc8012LVrl+prAoEAAgFpoGPNkkOhEEKh/PQQSQX2udl+voEnA6gnEM7qPbRnt4ILuhHuPgyhuAk6CBA4Hnv7BYSjAiptBlRateJ7f/CSetiNGty+tBrhcOxkZtGRSXTUG0i5L7yxGBoAEVc/ojk49vxYH+T5tMjAaWgEcv6FeDIxZnp8uqgWucJmyPi1LHvMsq81Rcnf44OX1ONVWs9WYtLEbasxl4IHEB7rhZDF8bpsRgm2nRrA1pP9uHVxpVg3Z9Vx4z73x3sOq+HH71iM9mEfblxQCZ4XUGTSoXvUj7FREpQLOjPCE3zNjtJFkkUf//vcvKgKf9vThWKzDq0DHvSPetFLJYP9Y35x+yNdTgDACYEuonoOAiAsVqLjl+j4NpWQQLRtyAO3L4BBlzTZHaWf4zDpIEQjCEUjsOpZU8zU12c0KqCL9uAyaXPw2/qc0EXJeBHSFwGhEDjeQCYpms0XdKYJ/001IFKz470uHB6mQZbdgEo7WRgf7XLiiTfasGFueUqTmB1nhkT2LwA9/mj/KB6+bYH4/Bu+OtwrcDCFnAiNdAK2zJt6z6224a22ERzqGMbMMpPIPiypK8Ke804AQKVVF/N7GbXkWveFIvD6AzhGa1GOdo/C5fXj888cxeGuUdyxrBbvWV0Xz6QngXeoEyzdwY+cjT1PRjugi4YR4XToRTFqQ5mdR9GogHt+/xYOdZL9be13x7z+zbPDcAfCKLXoMa/CDIteA28wAqfHjxq7+nfY2W/EasECB+dBqOeIWKOl6TkMHkCkZFbS+aiKZvO7RjwIhULiNecw6RJ+N5bk6B7xIBRyAAB4vY3Mf94RRIMBaM+8Ai7kBYZOAwDClqqYcV5rqwIXGEV4pAOCo4U8KAjQeofAARDcfegLkfO3vtgEh5HMhqGIgEGXN63fdIAmFT+0rhHPHSIJIgCotusRCoVQZOBRwTkBAE7BgmKbDZFIGJE0KxVsdJ9G3P6k5wFvrYQGQNTZiUiexoODHSQIWVpfhEq7AbvODmN2hRVaLopr5pYBHgHYQrYN8aakcxxvryW/5UhH3LmjGesGD2A+fx7DrvjvzbsHyHc1FMV9V87RDC2A6ODpjI7D+SEvhj1B6DQcZpebYdDyKDJpMeoLo2fEI67T0sHZAZKori82IXzsOWiP/1t8zmWsBrzAyZ5R3DC/POZ1oUgUY75QTJuEVMjHOiJbpLsPExJkbdiwAZ/61Kfw5JNPoqaGaIe7urrw6U9/GldffXXePndwcBCRSASVlbGuNpWVlThx4oTqa771rW/ha1/7WtzjL774Iszmic+My7Fly5asXjfkBwAtXL4gNm3alNFr+WgItwRJRnTf9i3w6suwHkBAY8OTL+8BoEG1zhfHSFYC2LX1aNz7nRriAGjQ3jOYcl/mdo9gDoDzx/fhsDuz/VbDws49mCH7//zeLWgBEAWPLVu3AxyX8THePUC+jyboyvjYBtw8AF6UZoV6T2HTppNJX/O+apLIf2FzPAO8ciyMWgDH3noN59oyZ58ELwBo8eaZQfztuZcAaGHWCHjpxc0Zv1ciZHsOJwIHYDOVuUd85HgeOngAawB4glG8nOFvMl4cHSHnQ8TviTsf5gL40iLgqXM8WsHjjQPHcN4NADzaeofF7becJd/jcJQslGqEPlRpxlDjPYtNm5IXOCuPryAAJo0GvgiHPz6zGYd7yHsDwJbdRwHwMEAaFzo9AKBF34g75fn8SjeH1gENdLyAwZN7sakt1dFJDou/F9cACPFGbHrxZQBAkbcN62XbeEMCXprg3xQAiqLkuJ1wUonq6ACGzvcD0OCFY/144Vg/1lVG8Y6W5EnDp+hva9MJcIU47D7ZiU2bJFevk70a9KMY1RjGzs1PwWmZkfjNEsDsJ5+xaedhGHsO4uhp8n9ZZBg6noOeB/ZufwVyT42IAPDQIAoOj/9zM071awBwCEUE/O/jL+I/Z8ni90cvt+LpN07jwcWKVbMgYHHnH+HXOXCq6raYp1paD4GZOI+2HcQ22e9X4j6JywE4+WII4OEJRzMaIwZ8wL52LTScgIhA2Ma/P7sJZrqy+Wcb+e4zzH5s3vw8ECbf6+XXtqMtQdeObUc1uC7aiLWaYziy5c9oL70CALDhxCuwAdjTK6A3yTnoGSSf+cbBk6gdO45/nST/e4Z6sGmTOvOi9ZBtHv73EYyePYhKEzCzrwcLAHS3HsUJz59wLXPXpNh+uA2jZ6T9uDSgQwWAQ9s3o+M4CX50YQ820sTFWNdpdFJi2CG48fKLm2HWauANc3h600uoTmN5c4Iez8Hzp3BblYAfDWoggMNQx2ls2nQKp0c5lHMk4B0QHNBHfBnNi75R8v479xwA35nYHK1mpA+rAIycP4rteRoPXj1O9oUb6UCFAAAaVHKj4vex+ntwNYCQxoxNmyV3ULXzd07PEOYC6Dh5EAd9sfu7vvs0igCUcWMIjXTEHa85PW9hLoDzAy4cUjxn93XjKgDh3uN4PoPj8BZdt9SZo3iZzu/aKLk2Xnh1G05lIOx5jZ4TkbE+tO5+EXJ7k6EoeaPth1oxKxBbZ/qPczy293K4f2EEzRkKiXK9jsgGXm96jo4TEmT99Kc/xa233oqmpibU15PsbEdHBxYuXIjHH398InYhbTz00EN44IEHxP/HxsZQX1+P6667DnZ79r2UxoNQKIQtW7bg2muvVWUEU2HYE8TX929FKMrh+htujKkVSYmxHoAk0rF8XhMpsD0JGEob4DVWARjAxjXzsHFtY1pvV3J2GL87tQcakxUbN65Lui2/uwPY8iyaKqyo37gx/X1OAM0zzwADgKCzgAt50KQj0iTO5MC1112X1THu2HYOaD2NBc212Lgxsd2yGjaNHsCpUVIvoNfy+PidV8OoS0+7rgZ+81Zg71tY0FiJeeszP16CIOCxc9vQOxaAp3gWgHOodFiwceNlWe8Tw3jP4XTwgusgTo32oba6EhgALI4KbMzBeZMJood6gBOHUVNRgo0bV6luc+aVVuzoO4vi6ga0tTsBuOGDDhs3Xg8A+NNvdgNw4n1XLkLPnlpUR7rw/F12WBdcl/Bzkx3fP3S9if0do6iduxwmfzcwQORDAVMpgBHUVxRj48bVAIj07buHXodf0ODGG69LWPd5vMeFTbvfACDgyzcvwLtWjb+JLNe5GzgOaO1V0u821Aqc/LK4jbmodMJ/UwAoaxvGm7/dg5BAjseqBTNx6YwS/PnMHnEbwVqGjRtXJnyPaFTAw9/bBiCAT2yYg++8cAoDQa14nMORKB5482X0aEtQzQ1j3aImCHMz/67+/V147emj8BpLsXHjKvytfy8wNISr1yzGxyuJMQNrKSDHn+h50mFsQVSQWkG83G8CEEStw4gupx+DIa14rooYaoXuFyQwnvm2h6QaFADDv/w5QNvDOaIj2HjjjcQAoKgO3BE3cBrE5toDeEJcRmPEnvMjwIG3UOMwIxiOos8VwOzl60Szkh//eAcAD9571VJsXFSFX53fhYFuFxYvX4UrZ5fHvV84EsXn97yCY0Ij1uIYFlfyWHjdRsDVC93+HgjgsPz2e5PW4Ay/2Y6Xu0/AUFwFx5x6HNi1FzwHfOaOtZhfrb6GuNQbxPt/twcn+9z47VkLNt+/DrZjg0D3X1Fbakb1wirgWOxr1t307hgJmebfzwOHjmBJSwUWrZNdP4fpsdcF0ekm5++1K+di42VN+EnrDrQOeDBv2RqsnVGa8nj/sm0XMOrCVWtXYv3sctga2/HsoR586q5lKLXocbLXhfYTbwAA+gUH5jVVY+PGxSnfl2Fb4AgODXejYcYcbLyyJeF2XGc50PYoSrT+vIwHgiDgG4dfAxDEO6+9FItq7bj+5CDWzigRW3PwOx8h41X5TGzcuDHpGMzv6QF6n0FDuRW1iv3Vnvy0eL9J6MCNN34mZtzlN78G9AINc5aiTjmvh7zAiS9CH/Fg4/pLpGbFKfDmv48B6MSGxU3YeAORv/76/BsY6h7D/KWrcNWc+GsjEZ574gDQ048rl8/H7H4D0AsIegu4oAdFc68A3gDcGlvceu/nj+5CFC6M2FrwyY3JnSfFrzsB64h0wVRuqTAhQVZ9fT327duHl156SWSQ5s2bh2uuuSavn1tWVgaNRoO+vliThb6+PlRVqcswDAYDDIZ4FkCn0036j5rtPhRZpNK7kMDBmMl7BJ3iXW3ACejosbFVYv9Z8tyq5tK096vYSuRLLn849WtspMaI942Az8Wx9xEZGVe5AOjcDb57H/m/ca24L5ke4wFaqFztMGf82zjM0nm2srEYNnP67oSqsBHGVuMfgibL43Xl7Ar8dU8Hnj1ECorLbIacnvf5vI5K6bkVDpAsLqe3TPg166cJfpsx8fessBNp6og3LJ4/Ln8YEfAwaHmcpPVTtyytRbV3LXD4KRSPHgN0N6X8fLXjW19iwf6OUfS7Q3D6JPnuKfo55Taj+JpyO5ncg+EowuBh1qlPEb94/RxCEQHXzq/E+y5tStuEJykCTgAAZymTvoPZEbMJZ7BNyji8blYl7tswA4+8QuqHakssWNNSjg+ta8aYP4S/7+1E31gg6b7tax9BvysAq0GLD6xtxg9eOg1PMIJ+Txh1xWb0e0gBea9QCqAVWncvkMV3XVxPFlonelzQaLQYpY6BZTYTljQkXkivm1mO/R2j+Me+WLaFmTHcvbYZD286Dm8wgijHw6CVJYTobwcAumNPA+s/J/3vk2pEOb8Tujd+DLzyv8BN3xedUIWiOqADcIczGyOcPnLBldsM0Gt59LkC6HAGsKJZh1FvSJSzrZ9XBZ1OB4uB9mKKIOYzekf9+M7mE1ha74AvFEWrnizwNf1HyVja9SbZ/6pF0NmTL0DrSkgA2z7swzeoMuH9lzQmPfYVRTo88ZFLcNMj29E75see9jFcamuGFUCo6zAMTUQiiLk3E4dDowM6W3lsfaKDJDo07l5p/A9KC0HOM4gOqnBaXF8MnU6HCrsRrQMejPgiaR1zVtNZVUTG1g9dPgMfulxiW8uLzGjmyNzRDwfqijObF9mc6A6pl5iIKCHtDzhXD3QaDTGXyCGG3AHxvF9YXwyTXouNS2R18CE/sPsXZB8u+UTMvqqev1bC5catZcLBGNfj2cJ5BAUeVnl9NzXn0lhK4+d1XRFgrwPGOqEbOw8UJe9FxrC/g5wXK5uktRtrRO8NCWn/Zv5QBMd6iFywudwG/jRJznA3/whwNMJimQe88TrahrwQOE1Mj0fmvvnmuZGMx/Spsh5PBxPWJ4vjSIbqvvvuw3333Zf3AAsA9Ho9VqxYgZdffll8LBqN4uWXX8all16a98+fKjBoeZG9ytjGXW577hkSnQXd2hKMeEPQa3ksqElR9CmDssA6KVhWJle9MJjLUrUiszbvlqzfMpseWQxyC37WcHhcsNDJfxwNGpndPivmL82gaHmyweoagj5mZpIbea/LH8KPXjqFMwPulNuygnprEktzpkHvHfOLDZ8B4vLVM+qHyx+GlueI7XvNcvIkTQhkg2ratLxn1C+6CwLq7QDMeo1okJCoIbHTG8RLxwgD++lrZucmwAKAQSonYU5pADFzqVwE8FqgYgGw7lO5+aws8MkrW7CkJAqdhsPKxmJoNTy+fMt8fGI9WWT2UFv8RHjhKFl8rp9TDotBK9r6M/Mg1hh6zEAXSlkW9c+kNuKuQBidI74YC/FkYEwGmyNWNclYEp7DHctrRYlh3Lkhd4A9+GSMaYfRr3Dyff0H5PbMq6JTnqGUKCGc1JghXTAjhjKrAc1l5HiywKqD1guWWfWiax27LpXGF795/Sye2d+FrzxLJO7hCmLjjt7DQDQKtG0n/zddnnKf2PV2ss+F1n43Si16PHDdnJSvK7UaRHZt7/kRPNFdjpCggcHXC5x+ib75UuCD/wbe+ad4AxgbvW7kxhey34UL+xAIkIUtM72oovt6No2xTRAE0ZgjUVsPh0mLWzU7AQA7owvE908XkrtgivWBtZI4jUZDebFxP0Obt9c6TOqGVgf+DHj6SYCz6M7Ub8hMvHwKEy93LAEwn2+LN79gr0nEUpXSILcvvjxDDe5AGCd7SZC1vFG6xtmxT7dHIgB874WTWD32In5s+jVW1FuAkTZpnxrWoKbEBoteI/YDZfCHImKZxIleV4xrpRyj3hB6UrU6meLIG5P1yCOPpL1tPm3cH3jgAXzwgx/EypUrsXr1avzoRz+Cx+PBPffck7fPnGrgOA5mvQYufzhzG3d5gOMdEpu8doaIiHZJXVFMdiIVmNmDLxRBKBKFTqUvlAjmMJazIItmVKtkQRanAWYllmGlQq/YIyvzYMQuc2RLR6qREizIcvcn3y4JrltQhTuW1+Jpms0umYI9shKB9WgJ+alWOkf9lN7z6zdxuGsUBzqc+P09q5Nuy5IHliRBFnNrZItrhn5XQLQubim3kOuqlgZZXdkHWVVikOXDsDsY97y88JjjOBSZ9Bh0B+D0hsQm43L8+1APgpEo5lbZML8mhxLqs6+R20aZPJXXAB97HYiGJ9yOXwme53DP7CiuuvYaFFmk48KOrzcYgSsQjrmuGQRBEK3bmePq7EobTvS6cLLXjQ1zK9HlJOdtwFwDjEGy6s4QOg2P2VVWHOkaw9HuUakXWgpjg+WNxdBreASp9fudK+pwpGsMvlAEa5pLUGo1wGHWY9gTxIg3GJtYkid2Rs6RBsv1q4FwAOYIWdCNmJtR7D1HJE4AsUQPk/HTUtEEDc8hEiV9m+oN6Y07jAkusxnQXErs1s/RIKtL7OEo/VbsunQHYpONr5yIHTNLGxcCIwbCGjnPy4Ks1NLpallgUWLR45fvXyEuYFNhRWMx/rqnA/vOj0Cn5bBaaMJS7gxwnn5+eZJgTWxILDmVKhNu5dwoHEUl4li5uqkET+/rwvbWwZSB4JgvLDZaT9TWw9C7HzP4HngFAzZF1uDqDIMs5rib0sJdoyOBlquHJCOYs26OwBJqM1RktQCAN35Obtfel964lChhrLCgn8+dx7AniPoS2dyVzF0QABrXAudeA7Z9F1h4R0qnw/3tI4gK5LqQX8OZBFmt/W48d6gbv9l+DtsNT6FOGATaXpQCfOqmyHEcZlbacLDDidP9brGliNLS/Y2zw7hpsZRc8wTC+PnWM/jdjnMIRaLY8bkNqMgikT0VkLcg64c//GFa23Ecl9cg653vfCcGBgbw5S9/Gb29vVi6dCk2b94cZ4ZxocOi19IgK0MmSz5Ie4eACMk4nPOTwWdZQ4ILPwFsMmtrlz+cvAeTaOE+RDKj48mYR6NSxkvOZDVdRgbALN1q+rNoRMzA+mTZDFosUrECzhgik5V9g28Nz+H7dy3BysYS/HFXG26WDXxTHcwiORqgWdksGhELgoCtJwfw6KutCISj+Oz1c3CYOq3tbUvdFDs9Jouc80pWecAVEBeIs+lkhKrFJGPr7iWLJ7aQygDMUrptyAuPCpOtdMRzmHUkyPKpW/n+Yy9Z/N+5Yvx1WCJCfqCdOr62rI99juMmPcBi4DjEZbbNei2KTDqM+kLoHfWrBlmn+txoG/JCr+Wxfg5ZEM6psgEHpWCbtbUQ7LUkyBqHPfWC6iIc6RrDvvYR+ELkN3cksRAHAKNOg+WNDrxxlmTOF9U6sLKpGK+fHhQXQMVmHQmylF2DlWzCwb+QIIsuIgOCDq7SxSTIYhg+S4JnALyjAZW2KLpH/egZ9aO+NL1KeNbjqdxqQHMZC7LI9c+YQbldvZX2jPTKko1nB9w4O+iBludwxexyvHKiH+vn1wLd84CeA0DrS9TRjwMaUytgSix6XL+gEiPeEL5/15LYBXMKMGbhYKcTAoC3MAdL+TPSBhXz1F8IpGSyAKAMo3DI6sIum1VGP28UY/6Q6rnLMEhZLJtBm7h2+MCfAQDPR1fBAxOqiuKTNMnAJGtj/jTmY3sNDbK6gZplGX1OKpyh7Q9mlKvMIZGwZJm+4G3pvaG8HY18LeOmQZajAXC2o5nrxavDw0C9g/SPs1TIgqwETNba+8j1NnIOePGLwK0/Sborr54g64N1M2MTu/Y0g6xdZ4bwnt+8AUEAeERRzdH9O/oMuTXYYwLCWRVWHOxw4lSfCxsXkXO0W8FO7To7GBNkPbzpOJ54U6oLPd3vLgRZSpw7N/4mhrnCvffei3vvvXeyd2NSYWYNiTPtlaVksmgWsidCBuqaDDNVWg0Pi14DTzCCMV8ovSArEiT9rIzjyJr7RgBq147yeWThKkTTlgpGowLahjxoLrOI8qhoVEC/K3u5IOuOftXcCmiTMXrpIgdyQYAkPt6zpgHvWdMw/n2aQLDsbDSQPZP16Kut+N6LkgvSB363W7wv7wWXCOz6ShZkJeobM+AOiDIO1ugWejM5X/uPEjYrqyCLnJut/S7V55XXoINNtipywTMDbhzocELDc7htafq9+lKi403Calirkmfrpyiqi4wY9YXQM+qXAmQZXqRSwctmlonnBvuNWRDPesnoSuqBTmTNZAHAwlo7/roH2HqSLKi0PAdbknOSYe2MMrxxdhg6DYeZFVY8fPsibG8dxDtXEcMqYvPtiW9w7KHzRMUCcq6efpEsJmmQ1S84SHa7Q/YaISrJi4rqUePoI0FWOj0UKUS5oM2AZrogPjfggSAI6kwWDZDdsnmQsVhrWkrw2w+uxJgvTBb7RxaRIGs7TRhXLUrMJsjAcRx++f7EBijJ0FJmgcOsE+WYb/Fz8BFQ1ziNXuq3pAY2Nrj7gB8uBFZ9WPpdKMo5J2bL2Oe6YjNayiw4O+jBrjNDuH5B4pYBQ5Q1TCQVRMgPHHkaAPD3yJUAMl8fsCAvJZMFkO/btTeWucsRmOSUSXpj4HcCoJJWc5oKFBYgRUNkLWOwAdGIxGRVL4HL7YYtPIxtr7+KDUXd4P5wC1C7EvCmYLL0FuD2nwGPbQT2/RFY8zGgcoHqpoIg4OUTRKK4YW4s0ZDusT/Q4YQgkHP1/tVWaF6hibtT1GHR0RiTEJ9F2cBDnaN45OXTWNNcIjJZjDnf0ToUo2w62Rs7V7H6remICavJAkjfqpMnT8b1Tiog/2CTizfjIGsw9j7VEHeFyUDtyKBfCoMt3bosvVlaKI9XMsi+h9EB6IxAw1oy8M27Na2XP7azDRu+/xqe2E2yK9GogI4RL8JRARxHCq8zxZWzy/H0J9bi4bctzPi1qrBQeWXQBYSmt445G4hyqBDVfmfBZP3rAJmw37GyLs6BTY0FUiIduaDdpI1pDMsw4ArgRC9r7i1LKDDmdeB43GvSAQuymNRHSQgrF03smnaqTLZ7z5MJf1VTcVbnfEKco1LBlvUT3mw4F2CSwd4E9QMnKVt1aYu0KFtS7wBApDejvpAYFNjKm8gGrl7S/D0LrG4mn3OaZuSLLfq0aueunlcBDc9hdXMJ9FoeDaVmvGdNg1jTy86NuCCLja/zbwV4HXEPHD4rZur7UAxdOXWLM9hJbZEcRbWooomq7tHMg6xyqx71xWbwHLlOB1wBicmSBVlmlZqsl4+TIGvD3Eoil6VsiigrZ4zi6o+kvV/Zguc5LJepQ/ZEZQmHstmAJkmgbC6TGhSPdgA7fxrHZJVzo1hQE5sEYGzW66eTKyBY3UzCvkZnXgECoxjSlOON6DxoeS6jHkhA+mwK2ZgmecaRjEgEUS6oFmR5aY2UsSj57yGH3gxo6XnoGQB+dgnw66ukfbdWQdO0FgBwWd+fMfafr5AkROdu0fgiaYDfuBaYTR0/T8a3d2E4O+jB+SEvdBpO/N0Z0pULshqpGxdV4fYW2ZhCpb8ojnWaZkmnV0704wdbTuGhpw+jh17j6+eUg+eIxHf517fgV9sIa8uuXTYH9xaCrOTwer348Ic/DLPZjAULFqC9nSxU77vvPnz729+eiF246GFhTFamcsEYJmtYrPfpCJILx2HOXMbDDB/SkgTIJYPjAZPQMbbn/U8D9+8XHflS4Vg3YRiOdo8hGhVw449fx5Xf3QqAmEMkrS1LAI4jE6otiUQjIxiLSLYTGDebNR3BmmnyoeyYrEA4Isr1/vua2fjrRy/BHctrccsSkiFOp54xHbkgx3GqhiJdIz5xcp8rZ83EhVN28rEyqwFaWduG+uLY46LcF3ZNxy2kAXQOk2PLTAZyhrNbya1SKjhNIDcXUQNrQCw3nyizGtBApWT720dwlDJaVbX1JFCBECv9ygCzK62okAXBxWmO0wtqirD5U5fjp+9ervo8e5844ws23jgagPo15P7ZrQiNkqRFn+CAecFGYhxx9ZeBmqXSay3lgM6EGkfyY6gGufGFXsuL0ryzgx6JyZKd71bFPDjmD+GtNrJovmaeoq5HLiu/9F5g+QfS3q/xYIXMjEBjK8eZKJVRlaewueZ5Ur94938AcIB3EIEe4vsu6Mn1WsaNiswCw2XUdGn76eRzxiCt7StLxGSdJkzGMftlEMCj0m7MrF0MZMYX6RhjqdWg5QD+UAQddJybUaGSqGNrkXRZLAZWl9W9n9Qj9h4CzlBTNlsVzNd/GRFOg2s1e1E0sCf+9alYVFZb3vpSwk1eoQmFS1pK4+aodIMsJmuuLjKpS5pl7RsAxCUrzw560EqTP3OqbPjM9XNQatHDFQjjJ6+00lYM5DNW0IRD/5i6McZ0wIQEWQ899BAOHjyIrVu3wmiU6ONrrrkGf/3rXydiFy56WPRaaBCBz58hwyGXGwTdpBAYQJufDD7jY7ImMcjSGpL2OlGCubINuAIY9ATEzDQgOTVNOjguJ3VZ0xWs5kQfpYs0XWb1AOcGPQhHBdgMWlQXGVFqNeAH71iKT15F3JuSBVnBcBRvnB1CL50MrMbkGU45e8RMU14/PYBQRECxWYc6WR0JimjGNssaHZ7nYuSstQ4TLHqppiKOyUoiF2SStpj9Gy/8o2ThAQAtV+bufScQVdSWX1nQzcAkOHbFebG8wQEA+MPONgx5grAZtFjWWCItILMMrDkuNlNdnME4PavSltCJkD0e54DGGBNzmRQon90K/zBZAA+hGPbiMuDu5wgjJA8YaBKhJkWgqoQgCFJNFg0opbosKciSn6sWBZO1p20Y4aiAplIzGksVC+ralUTpsOZjwLXfSGufcgE5k/Xu1Q14PUr7L9alIUE0FZM6YweRemv7jwAAImXkeJfDiTLFb3vJjFJoeA5tQ160DXqQCEmZLEEATpMGse2lxIExU2dBIFaylsypk2zMxsXcBlnnh7yICqT2rFztu+YiyGLopU3MbNVA+RyElktsqaCRHT+DPTVrNutactuxW2yNoASTxm6YG28UkimTVeMwqieBFEFWrcOEFY3FmFlhFaXpr50i65OqIiM+sX4mdnx+AziOKEGOdo9CEACjjse8anWjjOmECQmy/vnPf+KnP/0pLrvsshjJwoIFC3DmzJkkrywgVzDrNXhW/0Vc99odpHAzXajZo2qN6PKTCzLdDKkckoNQOjbuuQqy6PewlCXfLgGY1faAKyBmVcqsevz1o5fgkXfltuh2XGDf7yJksmwGLbQ8BxNHs14ZygWZDnx2lS1mnGJS20T1jKPeEN71q11416/ewPEekoRIJhcEJIdBAFhIWyCw+r7FdY5YaVcOFhNKxzO2KOW4+AW4IxFbAckWO5Ni/pQYbCXSGFt1VjVnUwFVRZItvxoYa29XuMwx46BXae3UFbPLCSteRE1FOt4Anvs0MHg64326PMsgKxmKRbmg0viCSqgspVKQdW4bQsOkCMujL489p2OCLPJd2aJcWRSfCJ5gBP4QqbNl19NMKu9669yw6KoYa3wRey0f6CDsodzKWoRGS6zSb/xOzvswJcPyRgeumF2O96xpwKUtpfhu+J34iuGzwKr/Sv9N6PHVcCRQ8TqI7LCKd8KgMK2wG3Wiu+3jb5wHAJzoHYsrLWA1WcogDQDQd4QkgXRmuKovARA75qQLpnIJR4XU7WZEJit7gxg1MDv7lgqrusSWrUUSGVEkAtu+52D8c1RRY7z2C+hBGYYEG/rWPCR7bRoGY44GoGwOcYA++2rc091On8jaqgVZbGxKVcbBkiDpMlk8z+HvH7sUWz59BZZRiTS7Ntk5YtRpUENNUhijWuOQjFMYszUdMSEjx8DAACoq4n9Uj8eTux4rBSRFidaPBfx5OLxtcb0ZkkIluBEsFfAGyeTmMGXPZGUkFxxv0OB3ktsU9qaJwCaYAVdALMKsLjJhTUuppOGfCriImSyO4+Aw62AGDbLSlAuGI1FEooLo8qY0uGABkz8URZjaWzO4/CG881e7sK/dCYteg+oiI+ZW2bCMMhSJIGePlEzokjrFOZqD2oNqWV1KiUUvLkodJl2cpKeILqSH1eSClMmqzyWT5SQLOzgak283hcEWA4mZLLJwscUxWbGLp6vY4ocFWS9/HdjzO+DNX2S8T/Lee6l6ZKULSS4oOzcEQRqfzWXE6c1gB/xOWDtJrV3ArJBly13yRCaLHMN0maxBmpQw6TTiNXoJrXnbdIRk2G1GbYxjnmh8QeWChzqdAICldPE3FWDQavDHD63GN9+2CDPKLfDAhD+NLYM/mnq59tSeDvxhZ1ucecywhbDxFfyYyquAD11GDDX+8lYHvrnpOG740ev40j9j+y4NJmOyTm0mt81X4qZlzbhhQRXuWdeUcn+VMOmkPn0p1wdyueDJzcBwbszWpHqsBEk6sW9VpkwW3V41yKKSUGMRPmH/KTYEvo/2ymul59MJsgCJzTodLxn82dZWhKMCLmkpiWdtkR6T5Q9FYgMklvirXiJtpAiyADI3cxwX1/KjWuY+yVjo11vJWEIs5sm5VpALpsDKlSvxn//8R/yfBVa/+c1vLqqmwJOJEl6WHUyXFYpGZUWeDvHhsJks5HkuftGQDqSarDSYLMbMjJfJYrbehvSsgZUYkskFWQPiilwW/ucKYpCVfa+s6QyHWQ8T6AIwjWbEvaN+LP36FnzqL/txspfqxCuVQZaU+VWaXzx/uBcnel0os+rxj0+sxa6Hrsbm/74iqRUyEMtkLVDY9y+uc8RuzOSCficQTCznSQZ5VrlYFmSpLZiq7czEIXaxGwhHRKamrjiHTNYotZxzTC83SzlS1WQxabTyvJhbbYNRR6ZhjpOagYuBNXNEzaL3XYXNKNb2laSwb08XDrUAPOgWW3vAUkYYINq0Vx8g80fIomgFYa2UEl4OGmTRmqxhTwj+UOraYclZUAogmfSNMVy1ij5vostuIAxBEHCwwwkAWKK85qYIym0G2AxaRAUiYwOA5w5145UT8YlSfyiCh54+jK88exQd2tiERa+RmI6Uwan6OVfOKseMcgvcgTB+te0sAOCfB2JZiqTugqdeJLezr0N9iRm/eP8KrGjMkOkBWRuyaySl+QULTCIB4Ml3Ao8sA558N+BsT/66FGCNiFVNLwCZXDDD78e2D6gEulbJ1dFgdWAUVvSiGCihZjHpBlkzryG3irqsbqcPf32LjLOfunq26ktZkOUOhOOSiQxsfDPpNGR7FmQxAzFeJ9UQq2B+tTLIkuYlFmTtbyfmSnXFJpHd7hvzI5pBk/KphLwGWUeOED3wt771LXzhC1/Axz/+cYRCIfz4xz/Gddddh8ceewwPP/xwPnehAIpijVf6R9l1PBH8TrH5MMqkCzNgJIFPkUkHPsPCVkBaaKRXk5WgiV+mYItTfeYF+95gWJy0g5EoTlMr7CnZt+EilgsCxGHQxLGarNRywd1tw3AHwnjuUA/ePEvOMaUFt17Di8YRSgnNGdqT5+bFNZhblX5tXlkSJmtxvYLJMhYBerpPWdboVMnO1VKLXlyYqrVQqCshC9POEW/M4z1OPwSBTLAJi9+zAVsUORJPzlMdbDEw6gvFnSP+UASBMBk/lHJBnYbH4loHALLQF4NvFlgzsF45GeJd1Hp9TXMOmp0jgfEFG2u0Jkmie8VngFnXobXkSjwavhUj5ati34jjpN5GdG6xG7Uw8GQh1e1MLRmU98hisBq0oiQJiK8dtMpqsjqGfRjxhqDX8JhbnV3yLd/gOA4t1Digtd+Nf+7vwr1P7MdH/7g3lk0ECTrDdCH6RJuUBAkIOnRxZBFfjFHCPCrA85zIZjE0KCTBrE9WnGlPOAB0UaMGtsgfB9g1krKcQGsAVn2EJGcqFgAQgJObgKf/v3F9ftsQWSu0lCWYP7zjZLJE0LUTr4sJ2ErkdY+N6+hr0wzoWN2eu1dKLAP4xWtnEIoQFuvSGer7rexhGoOt3wb+8l70DhF5bbXDSMgSFmQ1riX9ue74FXFvToAFNdLcZtTxMU26WZDFXHBrHSaUWQ3gOCIfVVNWTAfkNchavHgx1qxZg2PHjmHHjh0Ih8NYvHgxXnzxRVRUVGDXrl1YsWJFPnehAAo7lwWTxbYz2AG7rBu3jlyk2ZheADK5oC+M//fMYbzjF7sQSpA5yVlNlhhkZW7rzTJ4DEep0yCjsqcULmK5IIBYuWAaTFb7kMQMuWgx/OzK2ECc47i4gnkGVijeVJoZs8MWKjxHJBNssqkuMqLCpjJJieYX2UkGGUsAECar3Gqk+xF/DbPs/4g3BLfs+7J6rLpiU25l3mKQNX2ZLJtBK5qJ3P3YW7j1p9vF2g62YOE4qPaq2kBd7W5fKqtHU0onvWkmxhT44NomnPjGDbhidnlWr1dCNL6QL3jY2Cyvd61dDrz3Kfyu/mF8N/wulNtVro9bfwLc+Tug5SoAVO5Lh9R0JINyZ0E55IYfSiZLfh0foFLBeTV2GLQJmutOAcyh49HX/n0UDz1NjBLCUQHbW2MTaSzoBIA/nZau6yHYcNxFrncdIpJ0XoE7ltVhZWOxeMyUSVA2D5bbFGPGyHnCuOqtSVmMdCEFWWkkYW/6HvDfh4FP7AQ+toO467bvBNrfyPrzmeHP0hM/AJ57ID4oFYOsLGuyGJouI7e2qpi2FazuccgTBJa+l7Bcczam9xkGm5SQk5lSvEETiB9al7jPmo72MAVUWMSdPwVOPIfwue0AqLRXEKQgy15D3DcX3pF09+qKTeIYWF0UO480K4LaumIzdBpenCunq/lFXoOs1157DQsWLMCDDz6ItWvXIhgM4nvf+x6OHTuGxx9/HIsWLcrnxxcgQxEnkxmlO2HLXXTM0sTl0rIgKzsJCpMLDnkCeHJ3O3a3DYv1MHEw50guGKTvb8icyRpSOGkxO3fVxfBkw0JrOi7SIKvYrIeZGV/oLBAEIaH0AQDah2PZmjKrQVVCJ2XAY2VMbYPk9U2Jsp4JwCQ3ZVYDNDwnGlEsVtZjMYh1WVkyWTLte6lFj6vnVaCl3IJbl8QbTdiMOvHalrNZHcN5cBYEACeVC+ZggTZZ4DhOZLN2nxvGoc5R3PmLXTjQ4RRrS6x6rSrz/5HLW7Dp/svxwbVN0oMtVxHb8Kv+H/k/XfWByn4ZdbkLINgCcNQXQoTJd8R6rPgMudL9LwaOBmDh22NMJYr15D270mGymBGD4r0vk9Wi1SZisoIRHGh3AgCWJrrmpgjuvWoWWsot6HcF4AtFxJql107GjvHyIMstmNApkOMwLNhxpM+PEYHOfQnaApj0Gvz942vxt4+REo4xX1h0+AuGo+LCO47JGibyQpS05KTHHTPGSqtXlhxVC4El7yL3WQPpLDDqC6EMo6g++itgz2+lhsEMWbsLKrZfeQ+5LYuV75XGMFmXAp85CSy6M/3PsVHpoauHBEIj5zFGA0d5DZQaVOuywkFx/WTofpO+j5GsI5lM2FaNdMDzHOZRyaDSGEUZZLFrl5kK9U9T84u8BlmXX345fve736Gnpwc/+clP0NbWhvXr12P27Nn4zne+g97e3tRvUkBOYBWyCLLkk6dsgHDyDgCS1XOmYEzWka4xsHl60J2ACs45k5V5kMXs2xlYdr/AZE09lFj1MFEmK6Qx4pofvIbbHt2RsMaD1TkwzKlSPz/MeqmWgyEaFURpiXKCSIVlDcWYUW7B25aR4IlJnuLqsRjGaeMeU5Nl1mNhbRFeeXA9blykPjmyQKpzWFrsdubDWVAQZEzW9DW+AKQFjN2oxfxqO4Y9QTzwtwMik6WUCjJoeFIQHsMOarTA9Q9Li0bvsKrMa6LBgm9BkDEN3sRBVn+yIEsFxXSzdOSCiZisJfUOMZiqdcSeq+w6jkQF7G4bErefymgoNWPT/Zfj4+tn4IYFVfjuncRkYNvpgRib835X7DzVGiVjxrBgw/EeF3oE8vtwKcYQFuQEI1FR5sr6B1r0mhiJF/kA6hDN6ofGCZHJSqecQIl1/w2AI0YcfccyfrkgCBj1hbCAb5MeVAal47VwBwjjNv9twAeeBd4Wa2rD2OKs5XFMdTTWA+x6FPjxYlwdIPb6LMGd8KVqQZYswVM6vA8ANVJi55GlnEg30wQzv1Ba/NcVm2L6OTJGtdLGaoSnp/nFhBhfWCwW3HPPPXjttddw6tQp3HXXXXj00UfR0NCAW2+9dSJ24aJHbJCVrlxQZnsuG1AGQYows5ULskGcTZJAbBYuBrmqMWL65CzkgokCwCnJZFlpkDXaSYxLLjKUWvSiXLDPx+PMgAdHu8fw+51tqtszJuuO5WRBsnaGusU/kxnJ5XO9Y34EwlFoeS5OlpQKRSYdXn5wPR7aSFzW3rW6HkvqHarMEgDATt3msnQYLLMaYNFrwHPp9a+po4vTGCYrHz2yvMNAiI5NzFFvmuLDlzXjqjnl+MtHL8Xv7iY1SG2DHlFal41JkCgxigSAkDf5thMAnYYX5T6iZFBNLkjB+mmpyVLVUGwgQUOPM3XWOhFLptPw+NiVLVhQY8e6mbELYeYuCJAkH5AksTGFYNRp8Lkb5uIX71+BGxZWwajj0TcW269ROYeeFsiYNgQ7Rn0hdAvkXEoVZFn0WrC1Lgukj/WQWpz5NfZ4NlbOZOUA6fZrUkXpDMlhT8XGPBV8oQjCUQELOJlTYaIgK1MLd3mQVVRPGNyWKwFrrPN2SaJedOnCRucQVzdpAQFgTpQEwqLxzmgX0HsEcMcmY1UDXNl6sc5zFFqESU87dlzSZLEY3rW6HmtnlOLdq2Pl4VoNL9YBamW9HStl5hfTERPX/IFi5syZ+MIXvoAvfvGLsNlsMa6DBaRAOAB07SX3o1Hg6D9jmwWrgfbEMkWlIsi0pScxckFpsuoXiLwiW7mgTcV5TR5wxYC56vid6QUNoQQZ0HExWeqD3ZRkssrnEuty3wjQfzT19hcYSi1amDjyew0GJJnUo6+0xv2O/pDklvf/Ns7DSw9cif/vCvWFAnMYlPduYfVY9SVmaDXjG0pvW1qLf31yXWKWSG5XnAU0PIdffWAlHn3PclWzCyVEJmtEhcnKqbMgZbGslUkLpqcDrppbgcfuWY35NXaU2wzgOSAqSOdJIiYrKfQWkvUGxs/m5wis6bcYZMnt2xVgC+V054pi+lXT6ZXF5oxyFROWezfMwn/uvzwuEcjznMhmAaSWMqFV9xSFUacRreq3nZIWyQP0eLDv86phA9oNs/FMhNT+dFP5YCrJMc9zcW1WjtKAVOkOByBvQZZan760UDqT3GbSqoaCna+LNOelB+VBViRMmqcD45MLJjH5YZLcROuOlJAzWcNtAIAqjqz5rEYt6Uv4wwXAL9YB35sJ7PuT+FLVAFemfDIIASzk2kiijgXrTMqeJuZW2fHERy7Bqqb4IJUpQqqKjGJrEcZkFeSCaWDbtm24++67UVVVhc9+9rO44447sGPHjonchWkLPhqE9gezgV9vIBrhI/8Anvog8Oc7SfDRvR848ESspOTI08C3aoFjz8IUkQVZ6U7WHlmQZZEGiJ4IGWiz6ZEFAEUqlHVCJssgG9SDCeq2GI49C3yzFtj/5/jngozJyqImi05e8n5CPJegX8hkQ2uQCmrPZJ7Jm+4oM0iB+IAsyHIFwvjJK7ENXTtHfBAEIoEpsegxs8KaMFiS+utITNa5oexML7LCOOWCAOmblEgeqIRakMVqsnIqF7wATC/UoOE5lNDalbPUEtqeDZPFcVLGPEvzi1xDbEjsYXJBxmTFLjqjUSFhE+aE702H1HRqsiRL8czGYXmj8NuW1k7LXp1XUiOT109LCg/WS+h9lzTi/Zc04j233YQ/LPoDtkWJvLCbyQVdXWSd0LE7YVKSycpGqcMfM3uSu8OJGKJywdIZ4/xWBKq92DKBlfZkc2UfZC3kZUHWmCzI8jsB0DVWurbqDCYFk5UALAmWKsjqGPbitkd34PnDCqaNMUuubmCEMHIV3AjMeg1pdN5/FOJ3AICzW6XdUg2yYteLK/mTqHGYYk0vcgQWZMmVIWKj94LxhTq6u7vxzW9+E7Nnz8b69evR2tqKRx55BN3d3fj1r3+NSy65JN+7cEEgyuulTNH5ncC5reR+9z5g67eAx24C/vlxoG279KJz24CwHzi7FYawLEBJd7JmTUIVTFZnkLjXFGfZeyUjJktnlDK5fvVGitKO7abdzrfGP8eCrAyML473jOHMgFs0vpBbujLDgimJGRvI7ZlXJnc/JgFlBolp6vOR4Y1l0f9zqCemhqF9mCx+G0otKRdaVhV3QdFZMMN6rKwgygWzD7IyAeuD1ekk7JUvGBGv0ZzKBVmQNY1NLxKBydjOUpv/VL3TEoLJjLI0v8g1xCArBZPlCoTFnF9cHU8COPSSXFBIUYPGrsVMj6u8PvP2ZZll4acKmEEOC+ABicmqcZjwjdsX4ubFNTFSSonJ6gSO/xv47bXEPQ8A3vot8K97gSg5NnYZkyUIAo71UCZL0W4C4aDU5y5HTJZDPL+yZLKY8YM785r/UW8INnhRD9lr5cYXLOAwOkjdZCbQWwAN/T2S1J+WyBw8k10DT+3pwMEOJ/70xvnYJ1iQ1XtYXPdUck7pOmFMHAP7/SBdpzH2+YpxZzV/AtV2A9BGCZIcyryXN5LAVW4AxVrl9E3ThsR5DbJuvPFGNDY24ic/+Qne9ra34fjx49i+fTvuueceWCzTi6KfCojW04C0fRfJQjFs+z+pruHMy9LjzKrV0w99pkFW517gBJVyNl1OBgWdGSidiQGaUEh34lRCbVJMyGQBUgPhQAomi2Xl1OpWMrRw7x/z420/24G7frFLzKDMlUklKqdijywGFmSd35lYPpkKo13Ai18k9rzTCCU6MjF7BQMG3OT+hrkVMGh59LsCODMgMbrM9KIxDWZGbGIqkwueo86CmZpeZAXGZAVd8ZNkHiD1yiLnD6v9KLHos77uVcGcBS8wJguQeqGd6R+HXBCYgkyWQs4lr92VgdXzGHV82hbpzMLdF4qklIt5aD8yufwvHch7AE3ItZsHMMluz6hPbH8ySOfQCllgJe8h1s8z44tuoO118uDRpwkT9fz/APv/BHQRYwMxyPKF0OX0YdQXgpbnMEvR3gLOdmLfrjNLDNI4wYL47JksWuOUBZM15g9jHqeY81wyiXa29u0AYaXZ69KQC4YigthWRA37aSPtDkU/Q5FZGmkTHyqHE0UGmkhk80cJZR5l6yXVRtA0sHTbiQzzcs1h2LZ9jVjla03EITRHuHFhFZ7/1OX47PVzxceq7IWarITQ6XT4+9//js7OTnznO9/BnDlz8vlxFzyEBhpkndoMDJ4i99nAxtGJ5uxr0gt8TnLr7oc2JAUoQqqMaDQC/OcBAAKw5N1A3QrA5AA++SbwoRfFyS9b4wujjhdtaBmSB1k0uFHrlC4HCyiUvYQiYcLoAQnlghFFN/Etx/vgD0Ux7AliTxtpBDq3SmpYOSXrsRjKZpPi10gA2P846f6eqQnG3t8DO39C/qYRHFoaZMEgOmLVOUxY2UQyZDvPSNIHMchKQ+6n1ierTZQLTsBCTW8h2VNAYn/yCMZkOb0huPwhHKI9hRbXFRV6ZKUJxiKwur+s5IIAYKaypCwbEucaDiWTxYrnFUwWW6hlEpTreCk4TSYZjEQFsUF8pkEWM+F49+rpy56W2wwwaHlEBeLEKAiCqhFIhWye8plYrU4XYTkAMi/+7YNAlI5rdG3A5IJj/rDYsmRWpS0+WM6xfTsgBfFyJmvYE8S1P3gN39p0PPUbWBmTlZ1ccCFzFmTrDjUmK9N6LIayWeS2KnH7IpNeAxNtu5DI/CIaFXCQBlndTn9smxIVIwotF0WdgQZjLMiqXEBuXT1AhF2r9HePCbLIuPO8fxG2RpbAhCCw66fkuSseBIpz5wrLccTiXa+VQhNm5T7kCcbI9acL8hpkPfvss7jtttug0UzdRn/TCQJjstjCpHQWsf9suhx4F61D6jkgTcbs1t0PbVAKULigm5hoJMKJ58j7GIqAa78uPe5oACyl4uRZnKXxBcdxcZLBQXcAgiCgtd8d35jYSAe7VHJBxlaNdYuyB/K4rB5NJcg62j2KJd94GZvapcthyzFpgA7S/ZEHWeVT0VmQgeMkNmvTZ4DH3w4c+mtm78Gy0/1pTGpTCAaBnNc+wSD2XiuzGUTXwB2yBp4dw+lbklv1sUFWJCqgfWgCmSwAqFlGbk9uzvtHWQ1a8frucvpwqJNMzDl1YvOPAgP0/LqAgywGNZl0WhDbWEwVJksWZHmHJfMSZjhAkU2QBYA4lyF5Q2JvUFpsWVQaPCfD7+9ZjQeunY2v3bowo9dNJXAcJ45b7cNejPpC4jwlt7SXn4NRaxUEcOCiIaBzj/RmfYel+zQxK2eyWD2WuulFbu3bAZUgHmTcPt3vxi+3ncXucymuAyYX9A0TOWO68A6j+ejPcJeGJqpnXk1u5WZD4w2y7vw98F+vSAFOAqSqyzo35MGYX5qLYq4VawWA+IC3Tuskd1iQVTaLlGIIUbHWt8ic2MK91WPE5zSfRbiO9FFD6Uxg7f1Jv0cuYDfqxMSIsuXKdMCEuwsWMA5YK2MHs4Y1ZDF993PAnBsJgyFEpbosJhd094MPKCRGySZslp2ac2OcvSggDX7ZGl8AUlZ3NpUfjHhDeOFoL675wWv4pjJblSmTFQ0D7n7pcRZ88TpAG7/Pu84MIRCOYks3h84RH9yBMHa2xpuDzKywSo43U5nJAoAl7wQZaOlge357sq3jwQLagRO53Kv8g/7WXhhEaWCZ1YC1M8ik+MbZYZG1PD+cPpNlljUxBYADHSMIRqLQa3lSBDwRWHQXuT38twnpmSTWZQ37JCarNkeNW9t2AD9ZSSQtWiNQOX0XvIlQrjBkSNWjJiFEueDUcBesp1LS5w72oPvIVvJg2ew444tsgyzWYqDb6cOYP6QqE/LR65DjAIM2s2XMoroi3H/1rJhs+XREPc3wdwz7RBbLbtTGNJ+Wtxkptprg11FWNBoCOJXvTxOzcitvVo+1QFmPBeTcWRCQAgyXPywyNOeHpNqzrzx7NE55EgNTMZnrgczYrDd+jhVnH8U8niYN5tH2Qn6ntLbI1r6dwVJKlEEpIK/LUgNjsRhiJIManeq6rVZD14BM4WQqluqpqGTQrnCVBCB+5xFY8V8bFkD7/qeAjd8D3v9MRv2xxgM2RxeCrALyj4a10n3GbDE0X0lumWSQXUwhT+Ku5RRfffYo7n5sNykKZrVPxvgFlT8UEWUajiyNLwApq7u8oVhsQPfsQZIxYnaxIth+pKpFkfeRkddlpTC9YMYWUYHDo1vP4rWTAwhGorAoZCjlNoOYUZnSNVkA0HwF8FAn8I4/kP97Dmb2enYOeAfH36NsIkHPAR8MCEXIRFxmNWBRbRFsBi1GfSEc6x5Dl9Mn9shqLEnNRFlZTVYgDEEQ8J3nTwIAbl1SM3EGKPNuIQHJ4CnCNOcZzODiYKcTrf3kGlpcn6Mg682fA55+kg1971OS7fAFBCWTdaEYX9y8uAarmorhCoSxdcuz5MGGS+K2Gy+T1e304c6f78SV3301rj6HJTsseu20dAfMBRiT1THiTdgzzGHSifNrqdUAr14WCJfNiV9D0MSsxGSFxWtfruQQMUgdW3MYZBWZdKLy0EnPoTbZ4vp4zxj+vrdD7aUEHCeVUWQSZHlIYnZ7ZAGem/FVYMHbSK0ZINm4s2swm5qsDMAaEg8l6NF5QBFkyZvGA4iRDEZB5q5KauMurqOMDslwiNbGMpOo/rGAaLoRcpO1ohM2vHtNA6mRX/2RCVUfMHOptkKQVUDe0XipdL9+TexzLevJ7dmtpAZHHpREyMU6ItBAQzZh+0MR/H5nG7aeHMB/DvVIjXsN8YMqq8fS8JzYlDIbsKzujHIrSqkGfzu1o+0ZUwwYmTJZQGxdVgr79mHZQPbMgW58fwtZQL97dYPIuBl1PMx6rbjwzKnDWr5gsEoSs/4TySWiSsiP9cDJ3O5XPkGZLJ8gLTbKrHpoNTzWtJCJ8av/Por3/+ZNBMNRzK2ypfVbymuyXjrej91twzBoeTx43ew8fIkEMNoJuwwAh57K+8cx9u+X284iKgDVRcbcNeBmTPPVXyEJgQsQZXFM1oVhfKHX8vj5+1ag1mHC7ADpxRdVLtYhBVmZfm/GDO88M4RTfW74Q1GxfQADk+1mWo91IYGZX3QMe0VnQeX1yfOceB6WWvTw6WRBVuUCYON3gZUfApa9nzwmMlnMwj0o1sbFyaojIaDzLXK/ZmmuvhY0PCcGeawmiTFZTLIol/OrwpZFkEXXPa9Gl6Gr4VYSrDHpIUtSi8YXWcoF00SJWJeWPMhiSd+E5hcAesxkjioTlEFWkRRkUYfBuVV26DQcesf8Yk1zyEVqLk32ctFld6LRTOue24YLQVYB+UbzFURHW9QgFVEyNNAAbOg04BlATC8EivMCoZFDrvjaFAB4Yne7xGKoBVk+ctGTbFP2GcSbF9egucyCa+ZXitk3pjHuGw0gKpcDpFuTlYjJYkFjAmdBxmRpOQGRqCDa4t60uFpsmFdKe958/baF+NLN88UanymPonoiC4iGMquvkh/r6SQZpOeAF7Igi55fd69thk7DYe/5EZwd9KDWYcLv7l4FPg0mivXJ8gTDeORlkr398GXNqC6a4GB78TvJ7dGn8/5Rd66oR6lFj2CYMNdyW91xw0PNEizluXvPKYYLlckCSAD56/cswGKOyMWe7InvlZO1XJBKsQ93SUnCmBoREPdBIPN6rAsJTLbZMeITe2QpzzlAMr8os+rh08sYmMoFQPVi4OYfSmsJRU3WuUEvguEoeE6ScYrofIskMM1lQGViI4dsoDS/YDKxd6wk8rYDHc7kFv9WRXCUDmiCzgOjdM7aFE3gGaM1QUzWsCcU95w/FMFxKuHcSPsetiuDDxYcAjhtJL9NSUQlyHLEBlkWgxZrmkkA+epJMkbzfhJ4V1Xlrh9WpmBMVnuBySog73A0AB/eAnzw2Xg3H3OJ1FOKuQ/KIGiNGOZJcNDdI/XbkVOwe8+PwDVGL8YkTJYjS9MLhnevbsCrn1mP5jJLXO1CMBLFsDyDI1q4ZxJkSd+vd5BKIxMwWUMeMkG9vTmKL980F1+9ZT4eu2cVljUUY3UzDbIo27awtggfvqx56vbIUoLjgGrSjDIjyeB0ZbIom+mjQZZey4uM62WzyvDqZ9bjXavqsaqpGH/68Oq066nYYm7UJ9UovO+S3LkqpQ3G+rh6JDlwnmDSa3DPuibx/5yaXjAJ6oUcZCnGNVu27oJTjMlimC+chZ4LY0Aowpe3e8XMN0PWckGVa1IZZDEmy6S7eJksqWZSYrLUgqyZ5WTeay6zwCeXC8rrIJlzKZML0t+stZ8kXCvtRtLIVg7Wh7FlPcDndikpN7/wBsPop3LIjYuqodNwGHQHYxqlx4HVJGXCZFHFi1eQB1myYM07LNW716auqxoPSsUgK1590jbkQSgioMikwyUt5PfsiAuyaEBkdOA8T2R9RWGa2FJjspyS/HL9HDImbz3ZD0TCMNL2Pw31k+fG2VyQCxYwoahZCpQ0xz8u1yKrBFmcsQh6G7koe3skx5w2xeTYP0AXQKpBFjO9yF2vHKWsBlB09zaky2TJBl2amekY9uJb/6JOSgmYLObgU2kS8P5LGnD3umZcNYcM0rcurcGcShvumKZNKwEAVYvJbdZB1jRispjxBZULlln0MYxrXbEZ3377Yjz1sbVoKU+/MbWF1mR1jvgQiQow6nhUKzO7EwG9RZKqjCapS8gR3n9Jk1ibuLTekZs3Dfml88syTRjhLFAkq4cBxiEXFJmsqWHhLqJ9FwDgnHkRIlHgb3tiz8fx1mSpvReDl9VkGS7eIIvJ94Y8QXEOVwuyvnLLAjzxX2uwflaZIsiSOdyZYtsEMJk8q2utVUtGnXmV3M64ajxfQxVSL7agyGI5zDpU2I2YRyWDyrqkGChlfumABlluGKVrVXyfHuDIP0jZReWipBbsuUAyJmuEPlZuM8SwmTFgNa7FTeiOOgAAliBd18UEWcz4Qrp2ry/tRzlG8ObZYXhGB8THZzVOXpDFjC+GPEHI2txNCxSCrAsNLDOsEmTBWARbCQnCnEPS4MP6/Syhiyivi07mKkEWk9Yx95tcQG1iiLEkZXLBdJsRA6Il6blBDywg7xXWJQiyaE2WVWUtUF1kwgufvgJ3r1MJaqcLMmWyBCH2WPcfB17+OrDpf8R+GlMWCrlgmcq5lQ2YXJApVJpKLZNXcM8KjiegX1aRWYefv28FPnfDXLFGa9xg7QF4naq5zoUCeT0MMA4miwXVgbHMLKnzAXlrDDqe2GcSmfrT+zpjXN/GsgyySi36uD6KiZgss/7ilQsWmXRiMLSvnczZFSrjXZFZh7Uzy8DzHDx6yvCYS2PqdmBykFvKjhcplCq1yrpV3wjQTRoXoyUfQRZjskJiPVYjrcthyR5lkNU76scnn9iHna2DMuOLfqQN0ZlWxmSxY+TqAQ4+Se4vfU9mXyYLlKjY2DOM+qREN6vLG3AFiGkZw6zrSWufS+9FR9gBADAF+sn8HaJJdZNDJhfsJJNb/3HUPXUjnjY9jEgkhK0HiIrFKViwoC6/EslksBl1Yv+8gWnWk7gQZF1oYIOLmsTLWISqasLI+Ef7xYmLZYreu6YBc6tsMAk0WFEJsiTtd+6y+OpMlixgStv4QiEXdHYg1HcSZhpkOcPxgWEgHBG7qttyR85NLVQvJbd9R0lj5lQIukkrAAZPP/D694HdvwRe+mo+9jB3CErugoD6uZUNlLUfE9KAOBFUJB75xBWzy/Hx9TNyF1TK67EucGc4lkAy6zXxcqt0YSyC2IphMtksZzvww4XAU3eT/2mdyow5i1Bs1qFvLIDXT0uZ72yZLJ7n4modE9VkXczGF4DEZg26g9DwnJgoTQSXqQ6R674FvP03sdcekwuKTJYiyGJMlncY+O31wK83kDmibA5QlHuVh1wuyNYnTZTNSBRkPbO/C/851IOP/mkvuiM0eePOoiZLTS54bhvQtRfgtVIrjTyCMVlqzYjlJRsOs040o+iUm19Yy0lrn8V34XyIrOP0gZHYoNNgB+y1ADjSlNozCJzaDE6Iol7oxk38m/jnjkMAADdvn/T6RzbnDvin15xRCLIuNDAtcgImq6KC0MhFggubDpMiTsZktZRZcN+GWbByJMBxgwysbYMe7DpD6pqYNlotY5Yt5EzWnEoyIKgyWcnkguGg1LUeIIHBz9fiiq13oZpalw4E4id7JhXU8hxMF+p8XdJC6tHCPmCoNfX27DjzWknbzRZ5u34KHHs2L7uZE4QUckFrbhhXpSypsSx1b628gTFZEyAXzAvEeqwLVyrIwM6/rFksAOA1MqZhEuuyNj8EuLqBk8+TrDc1AdA5anHbUrLQfmqvZDiUbZAFADUOksQz6viY92LwBFiQdfEyWYDkMAgA/331LMxIQwIdXfURqVk9A5ML+p2AIMRJW8U6uSP/ADrekPpjzb4+211PClEu6AmJdTiMyWKB5JGuUYQiUjKwhyZm3YEwHt5GrxNX+jVZAlVveGCUgszyeeSWJYZmXksCmDyjSNanTAmneF0RKXydrF+aGjp9JgQEep0M0uS7wU7GFa1BSsyPtkvtfwB8XPcceDrehA3F4/5O4wUzvygwWQVMLliQReVyMMjkOMYicHRhU8y58I+9nQiEI+imFq2NpRbcuLAKdo6cxf88ThbbH/7DW3j3r9/A+SEPBlzkuYocNuOVB1lXzSX7H1uTRb5DwDOCUW8CuZqcxWLmH4Ex6CJeLODbAAA9vvgoivWhKLHoL9ykOs9LfTM8A8m3BWTuknag+XJyf+N3pe7uW7+d+33MFWiA6AJZfOSKyTLpNJB7nTRPJpM1gXLBvOAicBZkYGNb1s6CDJNtftH6EnDiOXI/7Ce/IXNas1XjzhWktuOlY31iA9nxBFl3r23GpS2l+OClTQDiF5veIEmoXcw1WQAwu5IEVaubSvCJq2Zm/0YsiI8EgZAPFn3seCfKBVkd1tL3AXf8Grjyc9l/ZhI4LHImi8oFKWvXXGqB3ahFIBzF3vMSsytPzO4domsATz9pZ5MKghAjFxSTIpXzgQ+9AFz7deDSe4EbvjXer5YWxGbQvnCciyJjsth11UCPy//+5xg2H4ll7gRBgCsQRr9Ag6R+Wl8tl2mz+aRzL9D+BrnPazGPa8PtOvI/b82vZX06YOYXA77ptVArBFkXGlhWgqFc1sfHWCRmrIo4D/acH8GO1kFEBcCi16DMqgePKMwgQdeO9gA8gTDOUEvz031uGZOVO7lgS7kFOg2HedV2seGhGpPlc43gC88cVn8TVo/FaeKa5M3kSMDZ7ok/3cUas3G6JU55pNvQGZBkmUY7cOtPgE8dIs0HmRY9EwnGRINm3py0H1yugiyO48S6LEDKqk0KRLngdA+yLgYmiwZZ4zUKYuYXk/Wbb/lK7P89B6lygAOsFZhfbYdFr0EgHEXbkAfRqJB1TRYA3LCwCk9+9BLMpsqGsQTGFxc7k/Xhy1rwjdsX4lcfWDE+x1u9lSgXAMA3Ao7jYs7ZOoeJSM3bXicPrPowsPgdpBdjHiAZX4QkuSBVD/A8h3Uzydjxod+/hWcPEtkqS8w2lJgxCDrfRcPpsb/hADiBnFOcwRrb1qPhEmDdp4DrH1Y3HMsDWK1dMBJFIBwbJIo1WfQYvXtNA2xGLc4MePCxx/eKjpAAadodFYBu0CCp7wi5lQdZczeS25e/RtQulgpg1UcAADdyxNymunryjb8KcsECpgYYk8VQpgiyqPa6VEMGru++QGSFjayQnzXuBXBkMBpjy9vl9CXsLD8eVNiM2PLpK/HEf60Re3H0jsmZLDLRWuHDjtYB9f4YjMnSmaUaJFowXsaRoKHTw4kF0wzMIjWXRh5TEixTmU6QxeSCBhuRExRTq3J5oJasR8lkgmb6fVoSmOfK+AIAzLKs+aTWZCl6m0w7XAT27QwSkzXOYKBxLbl95X/Tu4ZziWgU6D9G7jP5MGtCa60ANDrwPIc5NEF2rMcFdzAM5oExngCTBWjx7oKFZsQAMah4/yWNYg1T1uC4eBt3GftaW2wiNUmBMZKoZWZKeQIzvuhy+tBNZYANJdKY+/XbFuLSllJ4gxE8+LcDGHQHxDXDyqZihKGFS8PqstKQDMrWPXpTfgLHTGDRa0UmUZlgULbRuWpOBbZ/boPowNfllNZO7LUdoMn37v3kVh5krfwQUa2wY9ByJXD5A1ItPACddfITYpe0lODxD63Eh+ZEUm88hVAIsi40WJRBlqxhsbFIXGxbBQ84RMWmdoyKZY17g4IGHa5oTHFp54hXDLJyWZMFEGag2KIXbbF7Rn1iMHWUtrnSclEEfW517bEYZJmAW34EfGwHsPS9MZu4BROOdsfWdcnlghc0MmKy6DZyqan8PaLhWHnmVAJ1x+Jp5r8yh+cpK/w16TSozKFcNmMwJss7JEpcphUuopqs6xdUYd3MUrx3zTh7ql3xP0BxEzDWCbzwhZzsW9oIuiQjHGZd3bGb3DIZMoC51Fr7eM+YKOs2aHkYx9HLirncJa7JuriDrJxCrPuj5hcmMt49YHoO5ifvAPb8ljzffCWp58kjWADR5fRBEIiFt7y+ttxmwOP/tQb1JSaEIgKOdY9hkPYKW9VExv4xga5p0pnzaIDhE/SwmiahNYcCPM/BZlSvy1LKBdl9Zk7CzDJO9I6hjwaevRqaHGHtWFhADZB5fdWHpf+bryTJkw1flO3Q5F9npVYD1jSXoGiaLdUKQdaFBiWTZauWMhLGInGhzAlR3LFAWkSzLAirx/FyZgActhyTskBHusYQpunJXMmwlKi0kwHOH4qKE+vv3+pHWCCnqg1eHOx0xr+QyQV1JsLAVC2UFqMUXsGIw12xA24+LOmnJDIKsqjcwGiPfVxnliQlE51NTxdUGvLu9Utw99omrGjMXcEuc3FqLDVPnn07QBZDLACeIIfBnOIiqsmqcZjw5/+6BNfMr0y9cTIYrMDtPwfAAfsfz6z/z3jBrnWNASildT9de8mtzAac9S860TM2rnosOVIxWZPteHZBQeyV5QRAmKx3a17G/cITwLnXgEN/Jc/noS+WEsUKZu7quZVxY66G58RGy7vPDUMQAJ2Gw+I6MjY6o7SOLFV/TUBMVrnl9u2TDBbkjvpi1TfsWlCyl8UyR8a950dww49ex71PEOZqUE/lfixZomydsebjgNZEyi1a1pPHVn4YouEV67VZQMYoBFkXGpQ1WUaHtJgxFpEgREMCpC9uqBE7i4uNWekCO6ghQdfOM4PiWx2iwU2JRQ+9Nj+njlGnEQOenlE/3IEwnj3UIzodWjmfuB8xYMyKvOGwIzbIcsOIYwoma7jAZMVDLheUg+Mye5+JRjgoZiSvWDIHX711AbTZ2margGXNJ1UqyDCdJYMXUZCVUzSuJU6hADB4euI+ly66YXJIdt2sblPGZM2vJuPF8R7XuOqx5BBd1nwhRGU9uKSarMnPsF8wUMgFl+A0vq79PXnMJgXT+eiLpYQyyLpmXoXqds1lZN3C1imVdqNoaz8SoYxUkrnKH4ogGI6KCh6vMIWCrARMlhhkKfaTsX8j3hCOUYVSFzU1GzXVxb65MsiyVQJ3/wf4wD+luUWjBT5zmiR35t823q9z0aIQZF1oMFgJ48BgKgbq15AMRSWVelBZQDHvxWP3rMIn1s/AzYvpZEknz6iOTJis4ztAiiiB3EsFlWCSwd5RP072jiEQjsLLkYWtHV4c6lQZNOVMFoOSyYJRpM8ZGJNVWgiyJLAFlMEe/9xUDrJYDyGOj5c65gCMyZpU0wuG6ewweBHJBXMOFmQxC+2JALvWjUVAkWKxZpeCrDlVZLzoHfOLttu5CrKiAuAOShl9T8H4IvcQmSwyjl7pfQE6LoLjjiuB+/cRd731X5BqdPO5K3qNaN9vM2ixskm9EW5zORmLD9I1QXURsV8vMukwBlYCoc5kBcNRXP3913DTI69DoMk571RisoxSgkEOpzfW+IKBBaZObxBDVDrJ4DbHroVUm8DXrQCar4h9zFpODK80U+OYTEcUgqwLEXLJoMkB3PYo8NlWyWlQlrFaXOfA/9wwV9LNUyaLM8Y3ImbIpemFGliQ1T3qw+k+MviFdCRjZeO8ONI1iogsqwlAqk2RBZiCokmiRzDGGmoAGBKNLy7wQSQbJkspF8z0fSYazEXK6CC29TkGK+xf3Tz5PUOmrcOgIBSYrPFgUoIsJ7k1OgC7IsiSMRxWg1a0k37zHCmkHe+C1ajTiKoJefsOLzUwshSYrNxBrMlyAgBaDGQe0M25niQvr38YWJ8fy3Y1sKDhijnlCZUzrJUGWw+wcoO6YhNcAl0LJJirekZ96HL6cLrfDaeTBJZTUS445peSC8FwVEwwOEyxiWE5k8VqzRm05mKpDQSgHmQVkBdc0EHWww8/jLVr18JsNsPhcEz27kwc5OYXbMFpll1gisE0BjSjo7dIF6GW58SsEpBb+3Y1MDnWqV4XWvvJ/giUVSnVBuAJRnB2wB37IhUmy81ZMSZI/3tUmKzhQk1WPOR9shK9j9q5M9lgTJYpP0HQg9fOwfbPXYUNc8dZX5MLTNeGxAEXEKFZVnOBycoYk8lkmRxJmSwAmEclg2+eJQmPXCxY1eqyRLlgoSYrd1AwWRUc+d1ntsyYlN2poAFTIqkgIDFZDCxBW19sxhiSB1mDskBkcIgkBbyCUWx6PNlQY7LYNcBx8Q3O5UzWoILJsht10tgBSGvAAvKOCzrICgaDuOuuu/Dxj398sndlYqFkspRItuCmC2yTVXpdY6kZdbLO8rlsRKyGRbRw9XDXKE7TIEtrIgv+OQ5SuBknGRSDLGk/nd4QugVpIecRjHD5w2LRNCDVZBXkgjIEEtRkZfo+Ew3WqNWsLi0ZL3iei7kOJhUsyBppm9TdyBiMxdJbAf0UOZbTCWKQdW7iPpMlVIxFhH3kZYGTvFYHwFyZZBDIQX8wxNZlMRQs3PMARU0W3P3kVmmmNUH48s3z8Nnr5+DWJYl7NFXbjTDIWC5VJiuBXFAuqRtxOgGQROzMism3cAdk571fHmSR9YrdqIvt5QUpUTziDcYxWTajNrbHV4HJmjBc0EHW1772NXz605/GokWLJntXJhbM/EJrIn2OlFAOpnLQIMtgKRJ7u8wot6LGITFC+a7JWlhLBoBjPWM41Uf2x2gjWbZrozvwkv4z8LW9FfsieZ8silFfCF2yIEugphh9Y2RwDYQjcFHZyYXPZDnIbVpyQVkNRtz7TIEgKxIGBk7F9+rKM5M1pVBKs8tDZyZ3PzIFq8eiPewKyBBsoTR8duJ61YnjgYOoIuQybFtVzKaXzoj9XeW229mCzUNyJstTML7IPeQKl2gU8LAga3KY+xWNJfjkVTOTNlnmeS7GiKi6iKxT6kvMcCG5u6CcyRp2kgSdR5g6QZZdTC5ISWFljyw5RLmgJ4RBWgaxcRG5Ppc3FscyWYUga8JwQQdZFy1Y5ikRJZxMLkizPpzBJg42MyqsYg8GIP9yweZSC6wGLfyhKHpoF3ernbATM9x7MZPvxprzv4p9kYpccMQbjAmy7DYysLDO8Czbo+W58TcLnerIufGFMye7lRW2fBl4dBVw+sXYx1lNlik/TNaUQjFdbPudEoM3HVCoxxofHA3E2CXkkZiGfEOsyaLXPqsH1JnjFmuXtJRi839fjs/fOBcfvLQR71ilKLjPAkq5YDgSJY5wIE1bC8gR5HJB3wjphwhMGpOVLpplRkRVRRKTlapPlpzJ6u4jyZ+I1jxlEq5sTSJnssQgS4UhjjW+IGubT18zGwe/fB1uXFhVCLImCYURSoFAIIBAQLr4xsbIgjMUCiEUCiV6WV7BPjfdz+dNpdAAEIxFCKu8htfboAEQ8Q4jqnie942R53QWrJtRgn3tTqxudMT0lyoxa/J+LOZV2/BWG2EmSiw66Myxg8IYZ43ZBz7gIvutMYjfacjlR7dAMquC1oQyuwmtQ350j3gQCtnxynHSa6a5zIxwmEwok/Ub5x0aM3QAhMAYwsEAWaglgNY/Bg5AWGuGoDw/dOTciXpHEMngWGV6DieDpvcweACRjj2INm+Q9s09SM4BY1HceX3BgdNBa68FN9aFcN8JhCqXApj65y831gstgKi5NKPzZyogl+dw9uChtdeBG21HeOA0BGP+Ewoa7zC53vQ2REMhaGzV4AEItipx3JRjRqkJM9Y2iP+ne7wSHV8brbsa8QQQCoXgki06dbww5c/5qYJU5y+ns0ILQPCNIOzsIvOFuRThKIDo1D3GDSVS0reMrk2qbXqRyYr6R8Wx5niPC398ox33b5iBfpesPjvoBrSA3mzP+nzK9fhg0TPDl6D4nkNukky2G7Vxn2PVE8aPsLyE6bUbeJh1QDgcBmevFxf8Ia0FmIbXzdQYg5HRPky7IOvzn/88vvOd7yTd5vjx45g7d25W7/+tb30LX/va1+Ief/HFF2E2T24NwZYtW9LartQ1jMsA9AZN2L1pU9zzLf09WASg5+wx7FU8v6LtJOoAHDvTgRnlp/CtVcDYqTfRP8ABINKMY3t3of/o+L5LKlgCPBjRWswHcaKtGwtkzzs9fmyS7fvijhNoBnD6fDdO0sd39HLwUyYrAB3CriEAPF7bfQDarv345SENAA7zTWPisU33GE838NEgbgHAQcCL//4HwtrExb03uAZhALBt9wG4Dg/FPNc00IklAHrPn8JbKudWKuTi+K7va0MRgI5jb+KgW9qHJe2H0ATgVMcATmWxb9MNa6NFKEcXDr36DDpKCUM01c/f2b27MA9Ax5AXB6bpbzTZx/jSqA0VAA5t/Sc6SodSbj9erGk/jSoAh063o314E+b2+zEHwGDQgJ15+A2Vx3ekn8wF+46cwKaxY3AGAEALHgJeemEzJrMv+HREovPX5uvCBgChsX7seeXfWAvAFTXh1Sl+nbr6ydqEg4C921/FAR4IRAAXNb5wDXZjK/0Of27lsXuAh7u/HX0+DmyNYQYJuIIRIWZdkQ1yNT6cGibfq6N3UNynXd3kMc/IQNx+RgWAgwYCbSDMQcDOrS+BqS31oTHcSLd9cdtuhDWHc7Kfk4HJHoMBwOv1prXdtAuyHnzwQdx9991Jt2lpaUn6fDI89NBDeOCBB8T/x8bGUF9fj+uuuw52u4p8agIQCoWwZcsWXHvttdDp0igkFm5E+NxylFXMx0YVqp876AS6nkBNiRWVGzeSB53tQNADjcsGjADzlq7G3KUbxdeUt43g8VZSB3XXzdfDlGctfOhgD7b+nQwCq+fWY27pLKBbet6uB67YKO2f5t/PA4PArHmLMWMtebxt61m80Ebc1wwl9VjW0II929tQUtuMhiU16Nj1BnQaDl9499Ww6bnMjvE0hHD0XnBhP667Yo1knKAC7SHC5F5+zU1xbmLcES/Q+UdUOUzYKDv+qZDxOZwE2lZiI9xQpEGt/Bz4+9+AIWD24kswc2X6+zZdwW96Gdh/DEvqrZi77tppcf7yW3YAPUDdnKWo2TC9fqNcnsPjAf/8K8C+o1hSb8ei9fk/hpo//BQYAxatuhwL524EdyIK/ONZlMy/Ahuvzd3nJzq+p15uxeu9Z1Fe24CNG+fj7IAH2LcDFqMON910fc4+/0JHyvPXOwSceAj6iAdrZlcAZwBr9YyMxvnJQFW7E0+e2Y0KuxG33Hyl+Pi/jvYCEcDIh8Xv8NQf9gIDQzCV1cEw6geGiFrGwpEgq76+Eauz/L65Hh/K20bw65NvgTdYsHHjZQCAky+1AufPYv7MRmzcOC/uNV8/9CpGqKSw1GrAzTetj3k+Yj8DQMB1V7193Ps3GZgqYzAgqdxSYdoFWeXl5Sgvz5+e32AwwGCIN3bQ6XST/qNmtA9zrk38nIVI6PjAKHidDgh6gd/fQGpx7KSoWWt2ALLPmlNdBJ2GQ4XNCLslvzVZALC0QZLBzK6yQ1Mcy0xqI77YYxEhg6TGaIWGPu4KRHBMaMKzM76KW6/ZgJqzJLM14A7hb/tIxHbjwmpUOiwi9TsVfue8wVgEuP3QhT0xv20MQn4gQvTcOmtJ/HbiuTNGzp0MMe7jKwiiwQXv6o7dhwCRtGqspeI5cEGD9r3TOM+Jx3TKn7+0vkdjLZ+2v9GkH2NqeqIZbZuYY0hrNLUWOh4svB0oex2a8jnQaHP/+crjW2wh8/HrrUO49DuvYcNcMv9b9Nqpfa5PUSQ8f4uqiCGNdwiajl0AAN5WldU4P5FY1VyGT6yfgcV1RTHfa35zA9AKcIEx8XFmdtHlDIjBCABYKJNVUlo67nMqV+NDCa19dwXC6HWFsO30AEZ8zKjLoPoZJRa9+L3KrCrbXPdVAEyTNH0x6WMw3Yd0MO2CrEzQ3t6O4eFhtLe3IxKJ4MCBAwCAmTNnwmqdGg4ykwKl8cXRpyUnoWHqVqaw7y61GvDMJ9aJvRvyjZYyCyx6DTzBCDHgmHEzcOtPcKqtE7MPfQe6qC/2BUHmLigZXzhpoXRX/a1A9QxUDvYAADpHvDh7ijQvfvfqxIzOBQdjEeDuS25+Ibe71efAwr3vKDB4Cph9c/r7mQwhLxCmWvqx7tjn8mzhPuUgOgy2Tu5+ZAIvlbcV3AWzx0T3ymLzBJs3OA6oXjwxnw3J+KJjmIz5/9jXBQAwG6b7UnEKomwO0L4TOPc6+X+Km14AxGHwf26ILw+5YlEL0Aroo35EQkFodHoxyOoc8cJPzVNmV1phGSFzSnnJ1BmXpD5ZYXz9uWPYcqwPWqr9S9QagZhfkLVNaQ6cPQsYPy5od8Evf/nLWLZsGb7yla/A7XZj2bJlWLZsGfbs2TPZuza5UFq4v/Xb+G1UnOUW1hahoXRi6tJ4nsNDG+fhzhV1uKSllFgHL/8AwqUke6+PKIIsFQt3p5cMqMzalPXQONg5ClcgjFKLHmuaL5IFOZBegMQaEett5Jin+x6CALz4JWD/n2Mf//uHgafuBvqPZbXLcfDKalACY7H2vKK74EVg4Q4ApTPJ7dAE2nmPF4Uga/xgFuqu3on5vGQtHSYAyobGkSg51wv27XlA+RxyK9q3VyXedopjzTypL9S+U+cRiQoYptbmPWN+jND1waqmErEmq8QxdeYOFkgFI1HsaSNzW5ie+w6zegAlf7zMmt9WOwWkhws6yPr9738PQRDi/tavXz/Zuza5YBlJ/yjQvR/o3he/jVoj2gnG+y5pxPfuWgKdRjpNdWYS/BkFZZCl3owYkOxOKxVNlK+cUx7X0O+CRjpBlrigSlB/KH8P+cK+/xiw8xFg02eAKHE2QjQqsixcrprmKu3K5WyW2CfrIgmcHY0yO+8JWnCPF4Uga/zQUxVGML3C63EhHADCdGxlybkJhjLIYjAX7Ntzj3IFIzRJPbJyAb1ejwBPlC2vHzmDYU8QNEaBIJA/jgNWNhWLckHeOPnrHgaLXiOaVsiljYC6hTsAFMv6Z5VaCkHWVMAFHWQVkABsoRwJAm/8nNyvWxW7jX5qyikNJjIIJg6y4uWCRXTgUfb3umrO1JdC5BRpMVlJemTJ30OIENtbBpZVD3mBwdPkvndQtP7l3H1Z7rQCXoWb2lgnuQ3KZIQXi1xQqxcNTLiJko6NF6KksxBkZQ2WSAp58v9Z4ljBJR4T8gzW+wgg0i4GS4HJyj0Yk8UwDeSCSUHP2QOnz2NQ1heLocSsxw0LqlGmp0GMPrHr7kSD47iEskC1ZsQAUCzr8VWQC04NFIKsixF6K8DRCerEf8jtlZ+LXfhMASZLDXoz2S+T4I99gi04VJksMtjotTxK6SDEc8AVsy6yhqjpBFn9J8gtkyQpoTMBvC7+fViTWQDoPURux7pkz+eocaqSyRqln8FYLF47ZRMEeQGTDLJayqmMkF8KzC+WQDgf0NMxLhoGwsH8fharxzLa1eXDE4DGUgt+/K6leOpjl2LtDKm5fIHJygMuICYLALQWBwAg5B3FucH4pESpVQ+TXgM7T6+jKRRkAYipgZ9fbRdrshJJAeXBV1khyJoSKARZFyM4TpIMBt0AOKB+NVC3WtpmigZZRgvJTJnhRzgckZ5gTBZdgIz5Q2JNVrFFGnhYXdaKxmKR4bpokE6Q1UYLnhvXqj/Pcerv45YFUT0Hye1ol/rz44EvgVxQrMcqwUXVOKeC2PjypzZP8o6kAfYbcRrAMDn1PRcEdLKFYL7ZrEmux2K4bWktVjWVYEGNxKYVarLyAFtV7LU5zZksDT1v7fBi7/mRuOdLLQaiG2TJnym27rGbpETC5bPK8Mi7l+GLN81DU5l6MFgsq8kqyAWnBgpB1sUKub6+fA6ZROtWkv91FoCfmhMYC7I0nACvT7bAUNRkPbOvC+GogJkVVlTZJblJbTGRE66/2KSCgCw4cqo/H40C53eS+02Xp/E+CZgsFmTJ6qW4nDFZCeSC3ovM9IJh+d0Ax4NvfRFF3rbJ3pvkkNdjTRIrckFAqyeMLZC/uqzRTmJYw5Iuk1SPpcTCWikAsBgKTFbOwXFiawho9NN/PGVBFufBvnaVIMuqp6ZZtFhrCjNZsytt2LioGv91eeI+sDFBVoHJmhIozHQXK+SZSVaPVb+G3DKWawrCYJKkYH6PzFlOdBc0QRAEPP7GeQDA+9Y0gJMxGw9cOxufWD8D96xrmojdnVpIxWQNHCdsg84M1CzL7H2UckFBiJUL5orJYgt1O22SrJQLXmwytLKZwELSWHJ2778meWdSoGB6kTswNiuUpyDrrd8AR58BXn2Y/D/JTBbDzAor9FqybCkwWXkCq8uyVk5/VQA1cLLBhyNdZL6qdUh122VWAxBkyVouptxgKkAeZM2pSs2yFcfIBQtM1lRAIci6WCEPpFiQ1bgOWPffwLVfn4w9SgucRguvQAYPMcgKB0l9AgDoTNh9bhin+90w6TS4Y0VdzOvnVdvxPzfMvTj1/KmCrLYd5LZ+DaBJIqVMJRf0jwLO9pggK3dMFmWsqhaRW8aWDZ4it7bq3HzOdMLln4EADjWje6d2z6xCkJU7sLqsYJ7kguw8YuPqFEm86TQ85lSSxWYhyMoTWF3WNJcKAhCNL2zwIhQhbNXSeof4dJlVL2tbYp1yQSWTC/IcSTCkQsH4YuqhEGRdrJDLP+ppLRbPA9d+DVh056TsUrrwcyTICnjp4CjP5uos+MOuNgDA7ctqJqx58rRAyiCLSoOaLkv+PvIWAAzKIKr3UKy9ursvu15O/jGJpQKkhboYZHWR9z3zKvm/+YrMP2O6o2IuhFoi9eX6j07yziTBxdYsOp9gLqr5YrKGFG6VU4TJAoB1M4n5RXPZRWRwM5GYeS0JTmZdP9l7Mn6IckHpOlnW4BDvl8qZrCkmFQQkJqup1AKjLnVSocZhglmvQa3DdHEmkqcgCr/CxQq2UDbYSZf3aQQ/ZwKEMQR9bmDoDFnAAwCnwckBP54/QuzEP3Bp0+Tt5FSE2IRaJcgSBOA8ZbJSBVmqcsFBclu5COg7DPQcInUdFFwkCF0kwwVhNAr89loSWN23j0g/2EK9ejG5DbrJ53TuJv/PuCqzz7hQYK8BugBuohrUZoMCk5U7MLlgPmqyBAFQtgSYIjVZAJF837G8FrPSyOwXkAUq5gKfa5uyddkZwSgxWQyzKm2w6DXwBCOxckHD1DufWI+42ZXpGXJYDVo8/6nL0wrICpgYFJisixVs0qxdMe2K0FmDwYh7EPjNNcBjN5IndGb8+JXTEATgxoVVmFc9OX1dpixYcMRsmeUYPksWwRoDULM8vfdhQVY0KtVkzbya3Ha8Cbh6Yl5mCKt8bjL0HAAGTpD37jlA950GWfZawEztnF/7NpE1lcwAipsy+4wLBIKtityZyk2JC0FW7qDPY68sV4/UgJhhisgFAdKKY3alLabWtoAc40IIsABxrrLJmKwyqx6L6sjjMyusUm/IKchk3bKkBlfPrcCHL29O+zWNpRbRRbmAycf0Wl0XkDu0XElYrCXvmuw9yRhBGmRpnW0xlt5hjRGbDveC44D/vmb2JO3dFAZzigq6gEhsB3nREbByAXEvSwYWoDOmyu+UajcW3UVuz20jza7BAcVkgjCGkljHq+HMy7L9o2Ya8oX60veQ+/sfJ7czNmT2/hcSrCTIKjBZFwlYgX4+mKwh2nPNXgto6WJtCjFZBRSQNqgdvYOXkgblNgN++b6V2PLpK9BcZgH2/p48wXoOTiE0lVnw27tXYVVTQWI9XVEIsi5WtKwHPt8+LYOskIYGWa72mMc1PiJZu35+VVpOPBcdjA4ANPvrU9jZsiCreknq92FywtMvAgG3xGIZikiQVtIC0RLXWgkUEfMRQ6ZBVusr0v3eQ6T+JEybUJtLgcsfjF2wX8RBVoHJusjAsu75YLJYY+uKecDcm6X7BRQw3UDlgqVaMm9wHFBi1qPIrMOsShtw5hXg5CbSEuHKz0/mnhZwgaIQZF3MmKZyi7CGZHFN7s6Yxzm6sC8EWAmg0UqyH6+iqS8LsmqWpn6fulVEmhfyAsf/LTkLWsvJOTX7Bmlbew0JtAAYwjTIOvMq8LO1QOfexJ/hH5PqrACg97C0zxo9WWSaHMD6h8hjvDZ1LdmFjAKTdXFhIpiskhnALT8GPrr14r62Cpi+oHJBB5ULllr00GrosjcaBTZ/gdxf/VGpP1gBBeQQhSCrgGmHsJZkcc1eGmTZ6wC9DTvsNwGQikULUIGJyg7kTX0FQap5SofJ4jhgybvJ/YNPSs6CFmr5O1vmSlVUKwZZolzwpa8C/UeBw08l/oxz24gEke3vwEnJEt5cKiUIVtwDXHovsPG7YtbyYsT0YLIK7oI5g1iTlYcgi5lelM4gZgDJeuYVUMBUBrVwt1Dji5jeUYMnSW9InRm48n8mY+8KuAhQCLIKmHaI0iyuzUcX3Y1rgc+cxC/s9wMoBFlJwVgEeZA12kHkg7wWqJif3vssfge5PbcN6D5A7luoEUXDWkBP2UR7rdhvxRhyAn1HpIAuWUDA6rEWvp3ssxAB2rbHfgeAsHPXPwys/FB6+32hgjFZAReRcE41KOvpChgf8tmMWM5kFVDAdAZN6JjCY+AQRblNFmR17SO3NcukeuUCCsgxCkFWAdMOUbrA0EdpMaulDNBbMOYn5guFICsJ2AJXZhgiSgUr5gHaNLvEFzcCjZcBEIC9fyCPseaVWj0w+zpyv2xWjFyQP/Bn6T2SSdv6aL+nxrVAFbVrP7uVfocCExIHgw1hnpoUsJYGUwnKeroCxgd9GnLB0S7SqD0TRKPAyDlyvyR9R7MCCpiSoGMNjyjes9CKj18pSxx0y4KsAgrIEwpBVgHTDzqF1SodSEd9xDGvyFwIshLCrCIXzMT0Qg5mmhKgMkAmFwSAjd8Dbv4RsOz9YvBlDg6APyKTCCos3mPA6rxs1VJPrPZd9DsUFulq8Osc5A47rmM9wJYvAyPnJ22fRLDzTWOYklbJ0w66FHLBM68CP1xApLmZwNVNgmFeCzgax7WLBRQw6dDoRJbq4euqsbbUA2z9Nmk/0lUIsgrIPwpBVgHTDwbFIo3K1MQgq8BkJYYYZKkwWdVLM3uv+bcBWpP0P5MLss9ZeQ9hxmi9kDXQB87vBPS06aOrl8jI1MCaG1srJCYrGiamFyvuzmw/LxL4dFTywhjC3b8CdvwYePOXk7dTDOz3lNfTFZA9WKAaTOAuuP9PAASg863M3rfvGLktbiJS3AIKmO5g/RQ9A8D2HwBbvwW88P+IdB0AalP0hSyggHGgEGQVMO3AKzPh5jIIglCQC6YDsSZLFmQxaV7Voszey2gH5t0s/W+tUN/O0QCBujwJpbOAOx8jj4f9pMeWEiEf6eUFkMCtcR3J3Bc3Ax/eQtoPFBCHOCZr4AS5ZRb7kwm2T8ygo4DxIRmTFQ4Ap14k9zP97Vu3kNvGddnvWwEFTCVYysmtZwAYaSP39z9O+jiaisU+jgUUkA8UUlUFTDtojNbYByxlcAfCiEQJK1IIspJAaXwRcEmufeVzMn+/Je+SXAItCYIsgw3h/3oN217ejCve9mHo9HrSs8vvBFx98UXHbGGo0RN3KGMR8OAJIhMtZNcTQgqyKJM1eJo+4ZyM3YnFWDe5tddM7n5cKEhWk3Vum5SkyCTIEgTS+w6IdQgtoIDpDKaw8AzK6oCpgqJmWYFZLyCvKDBZBUw78AZFkGUuE6WCei0Po04zCXs1TaC0cB88RW4tFdk5LLVcRVzIdGZi+ZwIRXVwG2ukCc1WTW7V6rLYwtBSIW1vLCoEWCngF+WCPUAkLBkY+JyTtk8iRCarenL340KB6C6oIhc8/m/pftCdfi+twdMk06/RA81XjnsXCyhgSoAxWd7B+PmmpiAVLCC/KKxaCph20Jlimw3/46QPc5sL9VhpQekuOECDrGxYLADgNcCHXiCLPXlNVirYqkiPEjWHQVa/k8n7FRDLZDnPkxo2YIowWXRxYy8EWTlBIiYrGgFObop9zDMA6JOYWLRtB069ICU0GteR/lgFFHAhgAVZznZieAGQREIkWKjHKiDvKARZBUw7yIOskKDBZ587j1++n0jVCkFWCijlgoMnyW3ZOLrdW8sBlGf2mmRMFnMWtGT4nhc5YpgsJhUEpgiTReWCtoJcMCdI1CfrlW+QoMpUTJwc3b3k/+IkQdYLX5DMbwBg1nW5398CCpgssGRd72FyqzMDN/2AuNXOvHby9quAiwIFuWAB0w56s128PwwbogKHEz1jAApBVkowd0H/KJGUjZfJyhY20jtLncmicsFERhoFqCKGyfr/27v36Kjqe+/jn0kymSTkShJykYT7xWrgAK0Y2j4iIkI9Aq31gi4qirRVtGLtc+Csp4qcrlVL5fSsVl3UtoL2sd54vC211aICtYqogFUUIyCCXMIlkAuEhCHze/7YmVsykxuzZzKT92utrNmzZ+89P775sTPf+f32d9cEJFlNteGrOEaL75osRrIiwjeSFTBd8JMXpH/+j7X8nRX+WJ84bBWTORHi+iyPx38O8OJ6LCQS75d13kJAWSXSv82RZv7OuqcjYCOSLMQdV4Z/JOuYsRKuz6qtC71JsjqRliupdVrQqeORGcnqiQ6vyWK6YE80OfNkklKs0Y3tL/tfaDntvxFwrHinCzKSFRltqwueOi699BNrufI2qeL7/kI0J49I/+8m6bdj2idU9fulM6ekJKc0dZl02X0dX1sJxBtvkuWdPs11oYgikizEnbR+/pGsGmMlXNurGcnqkuQUKT3XWm44KB1rLY4Q9ZGs1lLeIUeymC7YE56kVBnv9Jev3g1+MZZTBpsb/NXuGMmKDO9tLM40WddhvfOANTpdeK6VLEmt03hljWR9sd5KyD56Kvg43hHP/kOkby2SKm+NRuuB6Gn7d4RzEKKIJAtxJyMzx7d8TFbC9eVRa9oMSVYXeK/L2ve+ZFqk1Kzof7vnG8nqYLpguJLwCMtTcU3oFwKLX0R76qB3FMuVLbmyOt4WXeMdyZKsioDvrrSWp/zcX4XT++Gy+l/+Ea9PXgj+/R/daT3mj7CztUDstJ0Rwb36EEUkWYg7Llea3MYq017TOl2w9RZZyibJ6py3jPvejdZjwYjo3yvE+4fuRHX7D/1MF+wxM/zS4FL8Ga0x9I5kNdVJv/s36ZW7otcoX9ELvkGOGGe6fNN+N/zaSqLOmSCNvty/jfdLir0Bo5rHdkmHtvmfe2/hUDDc1uYCMZOWKyUF1HjjPIQoIslC3HEkJemUXJKkYyb4m3FGsrrAO5K1pzXJivZUQUnKbC180XLaup4kENUFey7FJZ1/pbXcr1DKLbeWvSNZBz60Rj4Cr9myG+XbI8/h8I9mfbHeerzgR8FflmQGXJMV6NMX/cve6YKMZCFRJSX5v2ySSLIQVSRZiEunHGmSpNOu/kHrSbK6wJtk1e+zHkvGRr8NKS5/OwKLX3g81k0jJaoL9tTX51sfwIdf6r/+zjuS5Y2t934x0UD5dnt4KwyeaJ1yW9imeE3bkeDsc6zHT573jx57pwtGu/ANEE2BX9iRZCGKSLIQl5od6ZKk8rLg+7+QZHVBRkBi6sqWxs6JTTsyQxS/OHVcMh5r2ZuEoXuKvibdVSXNelBKa71+0TuSdbL1/mhnTklnmqPTHsq32yPwuixJyhsS/LztNY0X/FBKSZdqdkr7PrDKv3u/aClgJAsJLPALB67JQhSRZCEu7cgYqwaTrnO/flHQepKsLghMsi5Y4B/tiDbv+waOqngrC6bnScn8LnssLVtKSm4t2S9/jL03oZakpnr73t/j8S/7yreTZEWUt8KgZH0h0fb/cduR4NJx0nmzreWtf5ZqdlnL6f2DzwlAomEkCzFCkoW4VPHDVdp2/WaNr6hQfj//DQVJsrrAO0LkzJAujGHJZt8oS2CSRWXBiAo3XVCyb8rgztelX5VLH62xnnunC2YzXTCiAkey+oe4t1V6nuQI+BOfP0waN9da/vhZ6eCH1jKjWEh03pGs9DzJmRbbtqBPIclCXBqQk67KkdY1BiW5/pMmSVYXDL9UGnCedOl/xbaCn6v1fmfNASMqviSLohcR4RvJqrUeg0aybEqyPn/Nui/Wjr9bz32FL0iyIio1MMka2v71pGT/Bf8p6dY1cYMmWQmZ+6T0+r3WaxS9QKLz/p3julBEGUkW4l5JTrpvmSSrC7JLpFvfsaYKxlJaa5IV+GH/hDfJonx7RLQdyToZOJJVa897em9wfaJaajnjnwLKB5zIcgZMF8wPMZIl+acM5g+zqqw5HNL4H1jrGmus0taBZd+BROQt+pJbFtt2oM9J6XwToHcrzbFGslKTk5Tm5HuDuOGbLhgwklW7x3rk4uTIiMVI1vEvrceGaunEIauQSVIKo5OR1tlIluSPeWAS9o2bpdq91rUp4+fyfw2J79yZVp8f/e+xbgn6mIT9RPrll19q/vz5GjJkiNLT0zVs2DAtXbpUp0+fjnXTEGEludZIVna6U45o31QXPecKMZK15x3r8ZyvR789icibyPquybI5yfK0+BPlhkP+8vyZxdZICiIn6JqsIaG38Y1kBdxs2JUp/ftvpIv+NwkW+obUDOmi/7AqrwJRlLAjWZ999pk8Ho8efvhhDR8+XNu2bdOCBQt08uRJrVixItbNQwSVtI5k5aQnbHdOTN4EwHtNVlO9VP2RtTxoUmzalGgCKzgaY3+SVX/AusG0JDXX+SvYUb498gKrC4Ybyfr6TdZtEcZeF502AQB8EvZT6fTp0zV9+nTf86FDh6qqqkorV64kyUowFwzpr9wMpy4aSUW6uOK7Jqs1yfrqPWtqWd5gKeecmDUroQROF2yqlTxn/K/ZkWR5pwp6HdhqPVI2OfK8I1np/a2qaaGUXyhdvyZ6bQIA+CRskhVKXV2d+vfv+H4gzc3Nam7236Szvt76AOh2u+V2u21tXzje943V+/d2BRkpenfxZCUnOXocI2Jsr1DxdaT0U4ok01SrM263kna/pWRJnrJKtfB76Jaw/TclU05JcjfKfXyfAsvCtDQelyfCcXYc3Rn0R8Wzf4uSJLVkFkf8vaKtt50jkpLTrP8veUMS4v9Lb4tvoiG+9iK+9utNMe5qGxzGGGNzW3qFnTt3asKECVqxYoUWLAhfVe3ee+/VsmXL2q1/4oknlJGREWIPAD2Re3KXLvp8mRqd+Vp7/v/oW5//Qvknd2hr+c3am/+/Yt28xGA8mvnhjXLIaNOQOzRx9299L+3Lu1CbB0f2PmnnHlijkYde8j0/40hVijmtT0qv0c4iqthFUlnNWxq/94/anX+xPiq/MdbNAYA+o7GxUdddd53q6uqUnZ0ddru4S7KWLFmi5cuXd7jN9u3bNXr0aN/z/fv366KLLtLkyZP1pz/9qcN9Q41klZWV6ejRox0G0k5ut1tr167VpZdeKqeTEuV2IMb2Chnfmh1y/r5SxpWlM3d8opQVw+TwuOW+9X0pL8yF/Aipo/6b8t/D5GiqU8uUe5T85n/51nuGTVXLtU9FtB3Jz9+spE9faLf+zKzfy5z//Yi+V7T1unPEmWY5drwqM+jbUkbHMzTiQa+Lb4IhvvYivvbrTTGur69XQUFBp0lW3E0XvOuuuzRv3rwOtxk61H8R8IEDB3TxxRdr0qRJ+sMf/tDp8V0ul1wuV7v1Tqcz5r/U3tCGREeM7RUU3375kiRH8wk5D38sedxSZrGchSOs+/mg20L237QcqalOybVftq5wSDJKaq5XUqT7ureyYE65VLfXtzolr0xKkP9XveYc4XRKY+I7cQ2l18Q3QRFfexFf+/WGGHf1/eMuySosLFRhYdfut7J//35dfPHFmjBhglavXq0kSggDvYe3uqCMdHi7tVg4kgQr0tJyJe31V/rLGSjVfWVP4QvvjYjLL5Q+9idZFL4AAPQ1CZt17N+/X5MnT1Z5eblWrFihI0eOqLq6WtXV1bFuGgBJcqZJyanW8pHPrMes0ti1J1Flt1Zq9Fb685b7jnSSdeq4/6bH5RPbtIHfKwCgb4m7kayuWrt2rXbu3KmdO3dq4MCBQa/F2WVoQOJKy5FOHpGOVFnPuTlq5I2YKn3+N8ndaD3vP1TavSHySZa3fHu/Qqn/MP/6tFzJmR7Z9wIAoJdL2JGsefPmyRgT8gdAL+FqvWDUl2QxrSziRn0n+Hl+awJ05pR0prn99j114rD1mF0anCwzigUA6IMSNskCEAe8NyQ+2foBnZGsyMsulUrH+58HVm703gg6ErzHSsuRMov860mcAQB9EEkWgNjxFb9oxQdye4wOGM3KLPKPIEZyyqD3eixXtpSeJyW3VmllJAsA0AeRZAGIHVeb+0swkmWPUQE3As7o709uI5lkNQeMZDkcUlbraBZJFgCgDyLJAhA77UaySLJsMeBc6bzvScOmSHmDA5Ks2si9h3e6oDdxzmz9XTI6CQDogxK2uiCAOBCYZGXkSyntbwSOCHA4pKtW+5/bOpLVmmSNn2sV1xg+NXLvAQBAnCDJAhA7gdMFGfGIHjuSrMDCF5I0/gfWDwAAfRDTBQHETuBIFlMFo8cb900PS6umS8f3nP0xm9tMFwQAoA8jyQIQO2mBI1kkWVHjTbKObJf2bpS2PHb2x/SOiqWRZAEAQJIFIHaCpgtShS5qklODn3+x/uyP2bbwBQAAfRhJFoDYYbpgbBRXWI/ZA63HA1ulU8fP7pjNba7JAgCgDyPJAhA7aRS+iInzr5Ru/Jt0+2apYJRkPNLut87umG0LXwAA0IeRZAGIHUayYiMpWRo0SXKmSUMnW+vOZsqgp0U63WAtM10QAACSLAAxRAn32PMlWet6fgzvVEGJwhcAAIj7ZAGIpfQ8afC3relqmUWxbk3fNPhbkiNZOvaFVLdfyjmn+8fwThVMdnFDaQAARJIFIJYcDumGl/zLiL60bCm3XDq+W6rd07Mki6IXAAAEYboggNhyOEiwYs17PVxDdc/29xW9YKogAAASSRYAwDtV88Shnu3fzD2yAAAIRJIFAH3d2SZZjGQBABCEJAsA+rqs1iSroadJVp31yDVZAABIIskCAGS2XpN1oofXZDW3JllMFwQAQBJJFgDgrEeyqC4IAEAgkiwA6OvOeiSLwhcAAAQiyQKAvs5b+KKxRjpzuvv7M5IFAEAQkiwA6Osy8qWk1nvTnzzS/f19hS8YyQIAQCLJAgAkJUn9BljLPZkyyHRBAACCkGQBAM6u+AX3yQIAIAhJFgDg7IpfMJIFAEAQkiwAgJTZOl3w89ek+4dL767s+r4UvgAAIAhJFgBAymodyfr8Vav4xbbnurZfU53kPmktk2QBACCJJAsAIPnLuHsd3921/bb8X+uxcLSUnhfZNgEAEKdIsgAA/pEsr5NHpOaGjvdpOSNtethavvAWyeGwp20AAMQZkiwAgL/whSSpNVk6/mXH+3z2slS317rP1phr7GoZAABxhyQLACANOFfKHyGNmCaVjrPWdZZkfbDKevz6TZIz3dbmAQAQT0iyAABSaoZ0+wfSdc9I/YdY6451cF3WmWZp77vWcsXV9rcPAIA4ktBJ1syZM1VeXq60tDSVlJRo7ty5OnDgQKybBQC9l8Mh5bUmWR0Vvzj4L6ml2ZoqWDAiOm0DACBOJHSSdfHFF+uZZ55RVVWVnn32We3atUvf//73Y90sAOjdujKS5R3FKruQghcAALSREusG2OnOO+/0LQ8aNEhLlizR7Nmz5Xa75XQ6Y9gyAOjFujKS5U2yyi+0vz0AAMSZhB7JCnTs2DH95S9/0aRJk0iwAKAj3pGs2q+kFre1fKTKXwjDGOkrkiwAAMJJ6JEsSVq8eLEefPBBNTY26sILL9TLL7/c4fbNzc1qbm72Pa+vr5ckud1uud1uW9sajvd9Y/X+fQExthfxtVfE45uWr5SUNDnONMlds1vKKFDKHyZLqZk685OPpeNfyNlYI5Ps0pmCr0l94PdKH7YX8bUX8bUX8bVfb4pxV9vgMMYYm9sSUUuWLNHy5cs73Gb79u0aPXq0JOno0aM6duyY9uzZo2XLliknJ0cvv/yyHGGuIbj33nu1bNmyduufeOIJZWRknP0/AADiwJTtS5TVdEDvDPsPNTuzdfFnP5ck/f1r/63CE59q3N5HdLTfKL098v/EuKUAAERPY2OjrrvuOtXV1Sk7OzvsdnGXZB05ckQ1NTUdbjN06FClpqa2W79v3z6VlZXpnXfeUWVlZch9Q41klZWV6ejRox0G0k5ut1tr167VpZdeylRHmxBjexFfe9kR3+Snr1PSzr+rZfr9MrnlSnnKutnwmTlr5PjsJSVv/bNaKn8iz5R7IvJ+vR192F7E117E117E1369Kcb19fUqKCjoNMmKu+mChYWFKiws7NG+Ho9HkoKSqLZcLpdcLle79U6nM+a/1N7QhkRHjO1FfO0V0fi2XpeV3LBfSk3zrU6p/VI6WmW9VjpWyX3s90kfthfxtRfxtRfxtV9viHFX3z/ukqyu2rRpk95//31961vfUl5ennbt2qW7775bw4YNCzuKBQBolVtmPdbtk1L7+dfX7JAOb7eWB3wt+u0CACAOJGx1wYyMDD333HO65JJLNGrUKM2fP19jxozRhg0bQo5UAQAC5Ay0Huu+kuoDbuL+xXqpuV5Kckr5w2PSNAAAeruEHcmqqKjQm2++GetmAEB8yim3Huv2Sen9/euPfm49FoyQUtpf+woAABI4yQIAnAXvSFbDQSk9r/3rTBUEACCshJ0uCAA4C/0KpWSXZDz+a7ACDTg3+m0CACBOkGQBANpLSpJyzrGWTYv1mFPmf73ovOi3CQCAOEGSBQAILTCpSkqRyib6nzNdEACAsEiyAAChBSZZmcVS4ShrOTUz+DUAABCEJAsAEFpuQCKVVey/Dqt4jDWdEAAAhER1QQBAaN4Kg5KUXSKNnCFd9ktp6OSYNQkAgHhAkgUACC1wSmBWiZScIlUujF17AACIE8z3AACEFjiSlVUcu3YAABBnSLIAAKFln+NfziqNXTsAAIgzJFkAgNCcaVJmkbXMSBYAAF3GNVkAgPC+eYe0a51UfmGsWwIAQNwgyQIAhFe5kGIXAAB0E9MFAQAAACCCSLIAAAAAIIJIsgAAAAAggkiyAAAAACCCSLIAAAAAIIJIsgAAAAAggkiyAAAAACCCSLIAAAAAIIJIsgAAAAAggkiyAAAAACCCSLIAAAAAIIJIsgAAAAAggkiyAAAAACCCUmLdgN7OGCNJqq+vj1kb3G63GhsbVV9fL6fTGbN2JDJibC/iay/iaz9ibC/iay/iay/ia7/eFGNvTuDNEcIhyepEQ0ODJKmsrCzGLQEAAADQGzQ0NCgnJyfs6w7TWRrWx3k8Hh04cEBZWVlyOBwxaUN9fb3Kysr01VdfKTs7OyZtSHTE2F7E117E137E2F7E117E117E1369KcbGGDU0NKi0tFRJSeGvvGIkqxNJSUkaOHBgrJshScrOzo55x0p0xNhexNdexNd+xNhexNdexNdexNd+vSXGHY1geVH4AgAAAAAiiCQLAAAAACKIJCsOuFwuLV26VC6XK9ZNSVjE2F7E117E137E2F7E117E117E137xGGMKXwAAAABABDGSBQAAAAARRJIFAAAAABFEkgUAAAAAEUSSBQAAAAARRJLVSzz00EMaPHiw0tLSNHHiRL333nsdbr9mzRqNHj1aaWlpqqio0F//+tcotTT+3HffffrGN76hrKwsDRgwQLNnz1ZVVVWH+zz66KNyOBxBP2lpaVFqcXy5995728Vq9OjRHe5D/+26wYMHt4uvw+HQwoULQ25P3+3cP/7xD11xxRUqLS2Vw+HQCy+8EPS6MUb33HOPSkpKlJ6erqlTp2rHjh2dHre75/FE1VF83W63Fi9erIqKCvXr10+lpaX6wQ9+oAMHDnR4zJ6cZxJVZ/133rx57WI1ffr0To9L//XrLMahzskOh0P3339/2GPShy1d+UzW1NSkhQsXKj8/X5mZmbryyit16NChDo/b0/O2nUiyeoGnn35aP/3pT7V06VJt2bJFY8eO1WWXXabDhw+H3P6dd97RnDlzNH/+fG3dulWzZ8/W7NmztW3btii3PD5s2LBBCxcu1Lvvvqu1a9fK7XZr2rRpOnnyZIf7ZWdn6+DBg76fPXv2RKnF8ee8884LitU///nPsNvSf7vn/fffD4rt2rVrJUlXXXVV2H3oux07efKkxo4dq4ceeijk67/+9a/1u9/9Tr///e+1adMm9evXT5dddpmamprCHrO75/FE1lF8GxsbtWXLFt19993asmWLnnvuOVVVVWnmzJmdHrc755lE1ln/laTp06cHxerJJ5/s8Jj032CdxTgwtgcPHtSqVavkcDh05ZVXdnhc+nDXPpPdeeedeumll7RmzRpt2LBBBw4c0Pe+970Oj9uT87btDGLuggsuMAsXLvQ9b2lpMaWlpea+++4Luf3VV19tLr/88qB1EydOND/60Y9sbWeiOHz4sJFkNmzYEHab1atXm5ycnOg1Ko4tXbrUjB07tsvb03/Pzh133GGGDRtmPB5PyNfpu90jyTz//PO+5x6PxxQXF5v777/ft662tta4XC7z5JNPhj1Od8/jfUXb+Iby3nvvGUlmz549Ybfp7nmmrwgV3xtuuMHMmjWrW8eh/4bXlT48a9YsM2XKlA63oQ+H1vYzWW1trXE6nWbNmjW+bbZv324kmY0bN4Y8Rk/P23ZjJCvGTp8+rc2bN2vq1Km+dUlJSZo6dao2btwYcp+NGzcGbS9Jl112WdjtEayurk6S1L9//w63O3HihAYNGqSysjLNmjVLn3zySTSaF5d27Nih0tJSDR06VNdff7327t0bdlv6b8+dPn1ajz/+uG666SY5HI6w29F3e2737t2qrq4O6qM5OTmaOHFi2D7ak/M4/Orq6uRwOJSbm9vhdt05z/R169ev14ABAzRq1CjdcsstqqmpCbst/ffsHDp0SK+88ormz5/f6bb04fbafibbvHmz3G53UH8cPXq0ysvLw/bHnpy3o4EkK8aOHj2qlpYWFRUVBa0vKipSdXV1yH2qq6u7tT38PB6PFi1apG9+85s6//zzw243atQorVq1Si+++KIef/xxeTweTZo0Sfv27Ytia+PDxIkT9eijj+rVV1/VypUrtXv3bn37299WQ0NDyO3pvz33wgsvqLa2VvPmzQu7DX337Hj7YXf6aE/O47A0NTVp8eLFmjNnjrKzs8Nu193zTF82ffp0/fnPf9Ybb7yh5cuXa8OGDZoxY4ZaWlpCbk//PTuPPfaYsrKyOp3ORh9uL9RnsurqaqWmprb70qWzz8Xebbq6TzSkxOydgRhYuHChtm3b1uk86MrKSlVWVvqeT5o0Seeee64efvhh/eIXv7C7mXFlxowZvuUxY8Zo4sSJGjRokJ555pkufbOHrnvkkUc0Y8YMlZaWht2Gvot44Xa7dfXVV8sYo5UrV3a4LeeZrrv22mt9yxUVFRozZoyGDRum9evX65JLLolhyxLTqlWrdP3113daYIg+3F5XP5PFK0ayYqygoEDJycntqqYcOnRIxcXFIfcpLi7u1vaw3HbbbXr55Ze1bt06DRw4sFv7Op1OjRs3Tjt37rSpdYkjNzdXI0eODBsr+m/P7NmzR6+//rpuvvnmbu1H3+0ebz/sTh/tyXm8r/MmWHv27NHatWs7HMUKpbPzDPyGDh2qgoKCsLGi//bcW2+9paqqqm6flyX6cLjPZMXFxTp9+rRqa2uDtu/sc7F3m67uEw0kWTGWmpqqCRMm6I033vCt83g8euONN4K+jQ5UWVkZtL0krV27Nuz2fZ0xRrfddpuef/55vfnmmxoyZEi3j9HS0qKPP/5YJSUlNrQwsZw4cUK7du0KGyv6b8+sXr1aAwYM0OWXX96t/ei73TNkyBAVFxcH9dH6+npt2rQpbB/tyXm8L/MmWDt27NDrr7+u/Pz8bh+js/MM/Pbt26eampqwsaL/9twjjzyiCRMmaOzYsd3et6/24c4+k02YMEFOpzOoP1ZVVWnv3r1h+2NPzttREbOSG/B56qmnjMvlMo8++qj59NNPzQ9/+EOTm5trqqurjTHGzJ071yxZssS3/dtvv21SUlLMihUrzPbt283SpUuN0+k0H3/8caz+Cb3aLbfcYnJycsz69evNwYMHfT+NjY2+bdrGeNmyZea1114zu3btMps3bzbXXnutSUtLM5988kks/gm92l133WXWr19vdu/ebd5++20zdepUU1BQYA4fPmyMof9GQktLiykvLzeLFy9u9xp9t/saGhrM1q1bzdatW40k85vf/MZs3brVV93uV7/6lcnNzTUvvvii+eijj8ysWbPMkCFDzKlTp3zHmDJlinnggQd8zzs7j/clHcX39OnTZubMmWbgwIHmww8/DDonNzc3+47RNr6dnWf6ko7i29DQYH72s5+ZjRs3mt27d5vXX3/djB8/3owYMcI0NTX5jkH/7Vhn5whjjKmrqzMZGRlm5cqVIY9BHw6tK5/JfvzjH5vy8nLz5ptvmg8++MBUVlaaysrKoOOMGjXKPPfcc77nXTlvRxtJVi/xwAMPmPLycpOammouuOAC8+677/peu+iii8wNN9wQtP0zzzxjRo4caVJTU815551nXnnllSi3OH5ICvmzevVq3zZtY7xo0SLf76OoqMh85zvfMVu2bIl+4+PANddcY0pKSkxqaqo555xzzDXXXGN27tzpe53+e/Zee+01I8lUVVW1e42+233r1q0LeU7wxtHj8Zi7777bFBUVGZfLZS655JJ2sR80aJBZunRp0LqOzuN9SUfx3b17d9hz8rp163zHaBvfzs4zfUlH8W1sbDTTpk0zhYWFxul0mkGDBpkFCxa0S5bovx3r7BxhjDEPP/ywSU9PN7W1tSGPQR8OrSufyU6dOmVuvfVWk5eXZzIyMsx3v/tdc/DgwXbHCdynK+ftaHMYY4w9Y2QAAAAA0PdwTRYAAAAARBBJFgAAAABEEEkWAAAAAEQQSRYAAAAARBBJFgAAAABEEEkWAAAAAEQQSRYAAAAARBBJFgAAkubNm6fZs2fHuhkAgASQEusGAABgN4fD0eHrS5cu1W9/+1sZY6LUIgBAIiPJAgAkvIMHD/qWn376ad1zzz2qqqryrcvMzFRmZmYsmgYASEBMFwQAJLzi4mLfT05OjhwOR9C6zMzMdtMFJ0+erNtvv12LFi1SXl6eioqK9Mc//lEnT57UjTfeqKysLA0fPlx/+9vfgt5r27ZtmjFjhjIzM1VUVKS5c+fq6NGjUf4XAwBiiSQLAIAwHnvsMRUUFOi9997T7bffrltuuUVXXXWVJk2apC1btmjatGmaO3euGhsbJUm1tbWaMmWKxo0bpw8++ECvvvqqDh06pKuvvjrG/xIAQDSRZAEAEMbYsWP185//XCNGjNB//ud/Ki0tTQUFBVqwYIFGjBihe+65RzU1Nfroo48kSQ8++KDGjRunX/7ylxo9erTGjRunVatWad26dfr8889j/K8BAEQL12QBABDGmDFjfMvJycnKz89XRUWFb11RUZEk6fDhw5Kkf/3rX1q3bl3I67t27dqlkSNH2txiAEBvQJIFAEAYTqcz6LnD4Qha561a6PF4JEknTpzQFVdcoeXLl7c7VklJiY0tBQD0JiRZAABEyPjx4/Xss89q8ODBSknhTywA9FVckwUAQIQsXLhQx44d05w5c/T+++9r165deu2113TjjTeqpaUl1s0DAEQJSRYAABFSWlqqt99+Wy0tLZo2bZoqKiq0aNEi5ebmKimJP7kA0Fc4DLe3BwAAAICI4Ws1AAAAAIggkiwAAAAAiCCSLAAAAACIIJIsAAAAAIggkiwAAAAAiCCSLAAAAACIIJIsAAAAAIggkiwAAAAAiCCSLAAAAACIIJIsAAAAAIggkiwAAAAAiCCSLAAAAACIoP8PCne0acbnROsAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot the trajectory against time and velocity against time in a separate plot\n", + "fig, axs = plt.subplots(2, 1, figsize=(10, 10))\n", + "axs[0].plot(sol.ts, xs[:, 0], label=\"x1\")\n", + "axs[0].plot(sol.ts, xs[:, 1], label=\"x2\")\n", + "axs[0].set_xlabel(\"Time\")\n", + "axs[0].set_ylabel(\"Position\")\n", + "axs[0].legend()\n", + "axs[0].grid()\n", + "\n", + "axs[1].plot(sol.ts, vs[:, 0], label=\"v1\")\n", + "axs[1].plot(sol.ts, vs[:, 1], label=\"v2\")\n", + "axs[1].set_xlabel(\"Time\")\n", + "axs[1].set_ylabel(\"Velocity\")\n", + "axs[1].legend()\n", + "axs[1].grid()\n", + "\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mkdocs.yml b/mkdocs.yml index 6b4c2547..71205be8 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -112,6 +112,7 @@ nav: - Kalman filter: 'examples/kalman_filter.ipynb' - Second-order sensitivities: 'examples/hessian.ipynb' - Nonlinear heat PDE: 'examples/nonlinear_heat_pde.ipynb' + - Langevin diffusion: 'examples/langevin_example.ipynb' - Basic API: - 'api/diffeqsolve.md' - Solvers: diff --git a/test/test_langevin.py b/test/test_langevin.py index 8625d7e6..4ec872be 100644 --- a/test/test_langevin.py +++ b/test/test_langevin.py @@ -9,6 +9,7 @@ from .helpers import ( get_bqp, get_harmonic_oscillator, + path_l2_dist, SDE, simple_batch_sde_solve, simple_sde_order, @@ -194,3 +195,39 @@ def get_dt_and_controller(level): assert ( -0.2 < order - theoretical_order < 0.25 ), f"order={order}, theoretical_order={theoretical_order}" + + +@pytest.mark.parametrize("solver_cls", _only_langevin_solvers_cls()) +def test_reverse_solve(solver_cls): + t0, t1 = 0.7, -1.2 + dt0 = -0.01 + saveat = SaveAt(ts=jnp.linspace(t0, t1, 20, endpoint=True)) + + gamma = jnp.array([2, 0.5], dtype=jnp.float64) + u = jnp.array([0.5, 2], dtype=jnp.float64) + x0 = jnp.zeros((2,), dtype=jnp.float64) + v0 = jnp.zeros((2,), dtype=jnp.float64) + y0 = (x0, v0) + + bm = diffrax.VirtualBrownianTree( + t1, + t0, + tol=0.005, + shape=(2,), + key=jr.key(0), + levy_area=diffrax.SpaceTimeTimeLevyArea, + ) + terms = diffrax.make_langevin_term(gamma, u, lambda x: 2 * x, bm, x0) + + solver = solver_cls(0.01) + sol = diffeqsolve(terms, solver, t0, t1, dt0=dt0, y0=y0, args=None, saveat=saveat) + + ref_solver = diffrax.Heun() + ref_sol = diffeqsolve( + terms, ref_solver, t0, t1, dt0=dt0, y0=y0, args=None, saveat=saveat + ) + + # print(jtu.tree_map(lambda x: x.shape, sol.ys)) + # print(jtu.tree_map(lambda x: x.shape, ref_sol.ys)) + error = path_l2_dist(sol.ys, ref_sol.ys) + assert error < 0.1