diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 7d1547c6..e2f094be 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -14,9 +14,12 @@ VirtualBrownianTree as VirtualBrownianTree, ) from ._custom_types import ( - AbstractBrownianReturn as AbstractBrownianReturn, + AbstractBrownianIncrement as AbstractBrownianIncrement, + AbstractSpaceTimeLevyArea as AbstractSpaceTimeLevyArea, + AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea, BrownianIncrement as BrownianIncrement, SpaceTimeLevyArea as SpaceTimeLevyArea, + SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea, ) from ._event import ( AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent, @@ -69,6 +72,7 @@ AbstractRungeKutta as AbstractRungeKutta, AbstractSDIRK as AbstractSDIRK, AbstractSolver as AbstractSolver, + AbstractSRK as AbstractSRK, AbstractStratonovichSolver as AbstractStratonovichSolver, AbstractWrappedSolver as AbstractWrappedSolver, Bosh3 as Bosh3, @@ -78,6 +82,7 @@ Dopri8 as Dopri8, Euler as Euler, EulerHeun as EulerHeun, + GeneralShARK as GeneralShARK, HalfSolver as HalfSolver, Heun as Heun, ImplicitEuler as ImplicitEuler, @@ -93,8 +98,14 @@ MultiButcherTableau as MultiButcherTableau, Ralston as Ralston, ReversibleHeun as ReversibleHeun, + SEA as SEA, SemiImplicitEuler as SemiImplicitEuler, + ShARK as ShARK, Sil3 as Sil3, + SlowRK as SlowRK, + SPaRK as SPaRK, + SRA1 as SRA1, + StochasticButcherTableau as StochasticButcherTableau, StratonovichMilstein as StratonovichMilstein, Tsit5 as Tsit5, ) diff --git a/diffrax/_autocitation.py b/diffrax/_autocitation.py index 368ab508..484d1ebb 100644 --- a/diffrax/_autocitation.py +++ b/diffrax/_autocitation.py @@ -16,14 +16,23 @@ from ._saveat import SubSaveAt from ._solver import ( AbstractImplicitSolver, + AbstractItoSolver, + AbstractSRK, + AbstractStratonovichSolver, Dopri5, Dopri8, + GeneralShARK, Kvaerno3, Kvaerno4, Kvaerno5, LeapfrogMidpoint, ReversibleHeun, + SEA, SemiImplicitEuler, + ShARK, + SlowRK, + SPaRK, + SRA1, Tsit5, ) from ._step_size_controller import PIDController @@ -374,7 +383,15 @@ def _backsolve_rms_norm(adjoint): @citation_rules.append def _explicit_solver(solver, terms=None): - if not isinstance(solver, AbstractImplicitSolver) and not is_sde(terms): + if not isinstance( + solver, + ( + AbstractImplicitSolver, + AbstractSRK, + AbstractItoSolver, + AbstractStratonovichSolver, + ), + ) and not is_sde(terms): return r""" % You are using an explicit solver, and may wish to cite the standard textbook: @book{hairer2008solving-i, @@ -467,6 +484,12 @@ def _solvers(solver, saveat=None): Kvaerno5, ReversibleHeun, LeapfrogMidpoint, + ShARK, + SRA1, + SlowRK, + GeneralShARK, + SPaRK, + SEA, ): return ( r""" diff --git a/diffrax/_brownian/base.py b/diffrax/_brownian/base.py index 53b1ddfc..1642d315 100644 --- a/diffrax/_brownian/base.py +++ b/diffrax/_brownian/base.py @@ -4,17 +4,22 @@ from equinox.internal import AbstractVar from jaxtyping import Array, PyTree -from .._custom_types import AbstractBrownianReturn, RealScalarLike +from .._custom_types import ( + AbstractBrownianIncrement, + BrownianIncrement, + RealScalarLike, + SpaceTimeLevyArea, +) from .._path import AbstractPath -_Control = TypeVar("_Control", bound=Union[PyTree[Array], AbstractBrownianReturn]) +_Control = TypeVar("_Control", bound=Union[PyTree[Array], AbstractBrownianIncrement]) class AbstractBrownianPath(AbstractPath[_Control]): """Abstract base class for all Brownian paths.""" - levy_area: AbstractVar[type[AbstractBrownianReturn]] + levy_area: AbstractVar[type[Union[BrownianIncrement, SpaceTimeLevyArea]]] @abc.abstractmethod def evaluate( diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index 3b359d79..66075069 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -11,6 +11,7 @@ from jaxtyping import Array, PRNGKeyArray, PyTree from .._custom_types import ( + AbstractBrownianIncrement, BrownianIncrement, levy_tree_transpose, RealScalarLike, @@ -87,7 +88,7 @@ def evaluate( t1: Optional[RealScalarLike] = None, left: bool = True, use_levy: bool = False, - ) -> Union[PyTree[Array], BrownianIncrement, SpaceTimeLevyArea]: + ) -> Union[PyTree[Array], AbstractBrownianIncrement]: del left if t1 is None: dtype = jnp.result_type(t0) @@ -129,13 +130,13 @@ def _evaluate_leaf( ): w_std = jnp.sqrt(t1 - t0).astype(shape.dtype) w = jr.normal(key, shape.shape, shape.dtype) * w_std - dt = t1 - t0 + dt = jnp.asarray(t1 - t0, dtype=shape.dtype) if levy_area is SpaceTimeLevyArea: key, key_hh = jr.split(key, 2) hh_std = w_std / math.sqrt(12) hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std - levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh, K=None) + levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh) elif levy_area is BrownianIncrement: levy_val = BrownianIncrement(dt=dt, W=w) else: diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py index 9699f608..88c10524 100644 --- a/diffrax/_brownian/tree.py +++ b/diffrax/_brownian/tree.py @@ -13,7 +13,7 @@ from jaxtyping import Array, Float, PRNGKeyArray, PyTree from .._custom_types import ( - AbstractBrownianReturn, + AbstractBrownianIncrement, BoolScalarLike, BrownianIncrement, IntScalarLike, @@ -59,7 +59,7 @@ Float[Array, " *shape"], Float[Array, " *shape"], Float[Array, " *shape"] ] _Spline: TypeAlias = Literal["sqrt", "quad", "zero"] -_BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianReturn) +_BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianIncrement) class _State(eqx.Module): @@ -71,7 +71,7 @@ class _State(eqx.Module): bkk_s_u_su: Optional[FloatTriple] # \bar{K}_s, _u, _{s,u} -def _levy_diff(_, x0: tuple, x1: tuple) -> Union[BrownianIncrement, SpaceTimeLevyArea]: +def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement: r"""Computes $(W_{s,u}, H_{s,u})$ from $(W_s, \bar{H}_{s,u})$ and $(W_u, \bar{H}_u)$, where $\bar{H}_u = u * H_u$. @@ -105,18 +105,18 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> Union[BrownianIncrement, SpaceTimeLev u_bb_s = dt1 * w0 - dt0 * w1 bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s) hh_su = inverse_su * bhh_su - return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su, K=None) + return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su) else: assert False -def _make_levy_val(_, x: tuple) -> Union[BrownianIncrement, SpaceTimeLevyArea]: +def _make_levy_val(_, x: tuple) -> AbstractBrownianIncrement: if len(x) == 2: dt, w = x return BrownianIncrement(dt=dt, W=w) elif len(x) == 4: dt, w, hh, bhh = x - return SpaceTimeLevyArea(dt=dt, W=w, H=hh, K=None) + return SpaceTimeLevyArea(dt=dt, W=w, H=hh) else: assert False @@ -243,7 +243,7 @@ def evaluate( t1: Optional[RealScalarLike] = None, left: bool = True, use_levy: bool = False, - ) -> Union[PyTree[Array], BrownianIncrement, SpaceTimeLevyArea]: + ) -> Union[PyTree[Array], AbstractBrownianIncrement]: t0 = eqxi.nondifferentiable(t0, name="t0") # map the interval [self.t0, self.t1] onto [0,1] t0 = linear_rescale(self.t0, t0, self.t1) @@ -294,11 +294,14 @@ def _evaluate_leaf( bhh = (bhh_0, bhh_1, bhh_1) bkk = None - else: + elif self.levy_area is BrownianIncrement: state_key, init_key_w = jr.split(key, 2) bhh = None bkk = None + else: + assert False + w_0 = jnp.zeros(shape, dtype) w_1 = jr.normal(init_key_w, shape, dtype) w = (w_0, w_1, w_1) @@ -334,11 +337,13 @@ def _body_fun(_state: _State): _w = _split_interval(_cond, _w_stu, _w_inc) _bkk = None - if self.levy_area is not BrownianIncrement: + if self.levy_area is SpaceTimeLevyArea: assert _bhh_stu is not None and _bhh_st_tu is not None _bhh = _split_interval(_cond, _bhh_stu, _bhh_st_tu) - else: + elif self.levy_area is BrownianIncrement: _bhh = None + else: + assert False return _State( level=_level, @@ -359,23 +364,7 @@ def _body_fun(_state: _State): ru = jax.nn.relu(su - sr) w_s, w_u, w_su = final_state.w_s_u_su - - if self.levy_area is BrownianIncrement: - w_mean = w_s + sr / su * w_su - if self._spline == "sqrt": - z = jr.normal(final_state.key, shape, dtype) - bb = jnp.sqrt(sr * ru / su) * z - elif self._spline == "quad": - z = jr.normal(final_state.key, shape, dtype) - bb = (sr * ru / su) * z - elif self._spline == "zero": - bb = jnp.zeros(shape, dtype) - else: - assert False - w_r = w_mean + bb - return r, w_r - - elif self.levy_area is SpaceTimeLevyArea: + if self.levy_area is SpaceTimeLevyArea: # This is based on Theorem 6.1.4 of Foster's thesis (see above). assert final_state.bhh_s_u_su is not None @@ -414,6 +403,21 @@ def _body_fun(_state: _State): inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r) hh_r = inverse_r * bhh_r + elif self.levy_area is BrownianIncrement: + w_mean = w_s + sr / su * w_su + if self._spline == "sqrt": + z = jr.normal(final_state.key, shape, dtype) + bb = jnp.sqrt(sr * ru / su) * z + elif self._spline == "quad": + z = jr.normal(final_state.key, shape, dtype) + bb = (sr * ru / su) * z + elif self._spline == "zero": + bb = jnp.zeros(shape, dtype) + else: + assert False + w_r = w_mean + bb + return r, w_r + else: assert False @@ -499,7 +503,7 @@ def _brownian_arch( bkk_stu = None bkk_st_tu = None - else: + elif self.levy_area is BrownianIncrement: assert _state.bhh_s_u_su is None assert _state.bkk_s_u_su is None mean = 0.5 * w_su @@ -510,4 +514,8 @@ def _brownian_arch( w_t = w_s + w_st w_stu = (w_s, w_t, w_u) bhh_stu, bhh_st_tu, bkk_stu, bkk_st_tu = None, None, None, None + + else: + assert False + return t, w_stu, w_st_tu, keys, bhh_stu, bhh_st_tu, bkk_stu, bkk_st_tu diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index 0d45fe60..69853e98 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -1,5 +1,5 @@ import typing -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING, Union import equinox as eqx import equinox.internal as eqxi @@ -53,40 +53,54 @@ sentinel: Any = eqxi.doc_repr(object(), "sentinel") -class AbstractBrownianReturn(eqx.Module): - dt: eqx.AbstractVar[PyTree] - W: eqx.AbstractVar[PyTree] +class AbstractBrownianIncrement(eqx.Module): + dt: eqx.AbstractVar[PyTree[FloatScalarLike]] + W: eqx.AbstractVar[PyTree[Array]] -class BrownianIncrement(AbstractBrownianReturn): - dt: PyTree - W: PyTree +class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement): + H: eqx.AbstractVar[PyTree[Array]] -class SpaceTimeLevyArea(AbstractBrownianReturn): - dt: PyTree - W: PyTree - H: Optional[PyTree] - K: Optional[PyTree] +class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea): + K: eqx.AbstractVar[PyTree[Array]] + + +class BrownianIncrement(AbstractBrownianIncrement): + dt: PyTree[FloatScalarLike] + W: PyTree[Array] + + +class SpaceTimeLevyArea(AbstractSpaceTimeLevyArea): + dt: PyTree[FloatScalarLike] + W: PyTree[Array] + H: PyTree[Array] + + +class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea): + dt: PyTree[FloatScalarLike] + W: PyTree[Array] + H: PyTree[Array] + K: PyTree[Array] def levy_tree_transpose( - tree_shape, tree: PyTree[AbstractBrownianReturn] -) -> AbstractBrownianReturn: - """Helper that takes a PyTree of AbstractBrownianReturn and transposes - into an AbstractBrownianReturn of PyTrees. + tree_shape, tree: PyTree[AbstractBrownianIncrement] +) -> AbstractBrownianIncrement: + """Helper that takes a `PyTree `of `AbstractBrownianIncrement`s and transposes + into an `AbstractBrownianIncrement` of `PyTree`s. **Arguments:** - `tree_shape`: Corresponds to `outer_treedef` in `jax.tree_transpose`. - - `tree`: the PyTree of AbstractBrownianReturn to transpose. + - `tree`: the `PyTree` of `AbstractBrownianIncrement`s to transpose. **Returns:** - An `AbstractBrownianReturn` of PyTrees. + An `AbstractBrownianIncrement` of `PyTree`s. """ inner_tree = jtu.tree_leaves( - tree, is_leaf=lambda x: isinstance(x, AbstractBrownianReturn) + tree, is_leaf=lambda x: isinstance(x, AbstractBrownianIncrement) )[0] inner_tree_shape = jtu.tree_structure(inner_tree) return jtu.tree_transpose( diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 34e8c030..5a18ff2a 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -95,17 +95,22 @@ def _term_compatible( args: PyTree[Any], terms: PyTree[AbstractTerm], term_structure: PyTree, + contr_kwargs: PyTree[dict], ) -> bool: error_msg = "term_structure" - def _check(term_cls, term, yi): + def _check(term_cls, term, term_contr_kwargs, yi): if get_origin_no_specials(term_cls, error_msg) is MultiTerm: if isinstance(term, MultiTerm): [_tmp] = get_args(term_cls) assert get_origin(_tmp) in (tuple, Tuple), "Malformed term_structure" assert len(term.terms) == len(get_args(_tmp)) - for term, arg in zip(term.terms, get_args(_tmp)): - if not _term_compatible(yi, args, term, arg): + assert type(term_contr_kwargs) is tuple + assert len(term.terms) == len(term_contr_kwargs) + for term, arg, term_contr_kwarg in zip( + term.terms, get_args(_tmp), term_contr_kwargs + ): + if not _term_compatible(yi, args, term, arg, term_contr_kwarg): raise ValueError else: raise ValueError @@ -137,7 +142,9 @@ def _check(term_cls, term, yi): ) if not vf_type_compatible: raise ValueError - control_type = jax.eval_shape(term.contr, 0.0, 0.0) + + contr = ft.partial(term.contr, **term_contr_kwargs) + control_type = jax.eval_shape(contr, 0.0, 0.0) control_type_compatible = eqx.filter_eval_shape( better_isinstance, control_type, control_type_expected ) @@ -148,7 +155,7 @@ def _check(term_cls, term, yi): # If we've got to this point then the term is compatible try: - jtu.tree_map(_check, term_structure, terms, y) + jtu.tree_map(_check, term_structure, terms, contr_kwargs, y) except ValueError: # ValueError may also arise from mismatched tree structures return False @@ -758,7 +765,9 @@ def _promote(yi): # Backward compatibility if isinstance( solver, (EulerHeun, ItoMilstein, StratonovichMilstein) - ) and _term_compatible(y0, args, terms, (ODETerm, AbstractTerm)): + ) and _term_compatible( + y0, args, terms, (ODETerm, AbstractTerm), solver.term_compatible_contr_kwargs + ): warnings.warn( "Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to " f"{solver.__class__.__name__} is deprecated in favour of " @@ -770,7 +779,9 @@ def _promote(yi): terms = MultiTerm(*terms) # Error checking - if not _term_compatible(y0, args, terms, solver.term_structure): + if not _term_compatible( + y0, args, terms, solver.term_structure, solver.term_compatible_contr_kwargs + ): raise ValueError( "`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with " f"structure {solver.term_structure}" diff --git a/diffrax/_solver/__init__.py b/diffrax/_solver/__init__.py index e97c3336..da6fe6c9 100644 --- a/diffrax/_solver/__init__.py +++ b/diffrax/_solver/__init__.py @@ -38,6 +38,16 @@ CalculateJacobian as CalculateJacobian, MultiButcherTableau as MultiButcherTableau, ) +from .sea import SEA as SEA from .semi_implicit_euler import SemiImplicitEuler as SemiImplicitEuler +from .shark import ShARK as ShARK +from .shark_general import GeneralShARK as GeneralShARK from .sil3 import Sil3 as Sil3 +from .slowrk import SlowRK as SlowRK +from .spark import SPaRK as SPaRK +from .sra1 import SRA1 as SRA1 +from .srk import ( + AbstractSRK as AbstractSRK, + StochasticButcherTableau as StochasticButcherTableau, +) from .tsit5 import Tsit5 as Tsit5 diff --git a/diffrax/_solver/base.py b/diffrax/_solver/base.py index d318f2a8..42f19e4c 100644 --- a/diffrax/_solver/base.py +++ b/diffrax/_solver/base.py @@ -1,6 +1,16 @@ import abc from collections.abc import Callable -from typing import Generic, Optional, Type, TYPE_CHECKING, TypeVar +from typing import ( + Any, + ClassVar, + Generic, + get_args, + get_origin, + Optional, + Type, + TYPE_CHECKING, + TypeVar, +) import equinox as eqx import jax.lax as lax @@ -20,7 +30,7 @@ from .._heuristics import is_sde from .._local_interpolation import AbstractLocalInterpolation from .._solution import RESULTS, update_result -from .._term import AbstractTerm +from .._term import AbstractTerm, MultiTerm _SolverState = TypeVar("_SolverState") @@ -49,6 +59,18 @@ def __instancecheck__(cls, obj): _set_metaclass = dict(metaclass=_MetaAbstractSolver) +def _term_compatible_contr_kwargs(term_structure): + origin = get_origin(term_structure) + if origin is MultiTerm: + [terms] = get_args(term_structure) + return tuple(_term_compatible_contr_kwargs(term) for term in get_args(terms)) + if origin is not None: + term_structure = origin + if isinstance(term_structure, type) and issubclass(term_structure, AbstractTerm): + return {} + return jtu.tree_map(_term_compatible_contr_kwargs, term_structure) + + class AbstractSolver(eqx.Module, Generic[_SolverState], **_set_metaclass): """Abstract base class for all differential equation solvers. @@ -61,6 +83,11 @@ class AbstractSolver(eqx.Module, Generic[_SolverState], **_set_metaclass): # How to interpolate the solution in between steps. interpolation_cls: AbstractClassVar[Callable[..., AbstractLocalInterpolation]] + # Any keyword arguments needed in `_term_compatible` in `_integrate.py`. + term_compatible_contr_kwargs: ClassVar[PyTree[dict[str, Any]]] = property( + lambda self: _term_compatible_contr_kwargs(self.term_structure) + ) + def order(self, terms: PyTree[AbstractTerm]) -> Optional[int]: """Order of the solver for solving ODEs.""" return None @@ -250,6 +277,10 @@ def term_structure(self): def interpolation_cls(self): # pyright: ignore return self.solver.interpolation_cls + @property + def term_compatible_contr_kwargs(self): + return self.solver.term_compatible_contr_kwargs + def order(self, terms: PyTree[AbstractTerm]) -> Optional[int]: return self.solver.order(terms) diff --git a/diffrax/_solver/sea.py b/diffrax/_solver/sea.py new file mode 100644 index 00000000..b4f985e7 --- /dev/null +++ b/diffrax/_solver/sea.py @@ -0,0 +1,63 @@ +from typing import ClassVar + +import numpy as np + +from .base import AbstractStratonovichSolver +from .srk import AbstractSRK, AdditiveCoeffs, StochasticButcherTableau + + +_coeffs_w = AdditiveCoeffs( + a=np.array([0.5]), + b_sol=np.array(1.0), +) + +_coeffs_hh = AdditiveCoeffs( + a=np.array([1.0]), + b_sol=np.array(0.0), +) + +_tab = StochasticButcherTableau( + a=[], + b_sol=np.array([1.0]), + b_error=None, + c=np.array([]), + coeffs_w=_coeffs_w, + coeffs_hh=_coeffs_hh, + coeffs_kk=None, + ignore_stage_f=None, + ignore_stage_g=None, +) + + +class SEA(AbstractSRK, AbstractStratonovichSolver): + r"""Shifted Euler method for SDEs with additive noise. + + Makes one evaluation of the drift and diffusion per step and has a strong order 1. + Compared to [`diffrax.Euler`][], it has a better constant factor in the global + error, and an improved local error of $O(h^2)$ instead of $O(h^{1.5})$. + + This solver is useful for solving additive-noise SDEs with as few drift and + diffusion evaluations per step as possible. + + ??? cite "Reference" + + This solver is based on equation (5.8) in + + ```bibtex + @article{foster2023high, + title={High order splitting methods for SDEs satisfying a commutativity + condition}, + author={James Foster and Goncalo dos Reis and Calum Strange}, + year={2023}, + journal={arXiv:2210.17543}, + } + ``` + """ + + tableau: ClassVar[StochasticButcherTableau] = _tab + + def order(self, terms): + return 1 + + def strong_order(self, terms): + return 1 diff --git a/diffrax/_solver/shark.py b/diffrax/_solver/shark.py new file mode 100644 index 00000000..59d5d9da --- /dev/null +++ b/diffrax/_solver/shark.py @@ -0,0 +1,63 @@ +from typing import ClassVar + +import numpy as np + +from .base import AbstractStratonovichSolver +from .srk import AbstractSRK, AdditiveCoeffs, StochasticButcherTableau + + +_coeffs_w = AdditiveCoeffs( + a=np.array([0.0, 5 / 6]), + b_sol=np.array(1.0), +) + +_coeffs_hh = AdditiveCoeffs( + a=np.array([1.0, 1.0]), + b_sol=np.array(0.0), +) + +_tab = StochasticButcherTableau( + a=[np.array([5 / 6])], + b_sol=np.array([0.4, 0.6]), + b_error=np.array([-0.6, 0.6]), + c=np.array([5 / 6]), + coeffs_w=_coeffs_w, + coeffs_hh=_coeffs_hh, + coeffs_kk=None, + ignore_stage_f=None, + ignore_stage_g=None, +) + + +class ShARK(AbstractSRK, AbstractStratonovichSolver): + r"""Shifted Additive-noise Runge-Kutta method for additive SDEs. + + Makes two evaluations of the drift and diffusion per step and has a strong order + 1.5. + + This is the recommended choice for SDEs with additive noise. + + See also [`diffrax.SRA1`][], which is very similar. + + ??? cite "Reference" + + This solver is based on equation (6.1) in + + ```bibtex + @article{foster2023high, + title={High order splitting methods for SDEs satisfying a commutativity + condition}, + author={James Foster and Goncalo dos Reis and Calum Strange}, + year={2023}, + journal={arXiv:2210.17543}, + } + ``` + """ + + tableau: ClassVar[StochasticButcherTableau] = _tab + + def order(self, terms): + return 2 + + def strong_order(self, terms): + return 1.5 diff --git a/diffrax/_solver/shark_general.py b/diffrax/_solver/shark_general.py new file mode 100644 index 00000000..74aaea40 --- /dev/null +++ b/diffrax/_solver/shark_general.py @@ -0,0 +1,76 @@ +from typing import ClassVar + +import numpy as np + +from .base import AbstractStratonovichSolver +from .srk import AbstractSRK, GeneralCoeffs, StochasticButcherTableau + + +_coeffs_w = GeneralCoeffs( + a=(np.array([0.0]), np.array([0.0, 5 / 6])), + b_sol=np.array([0.0, 0.4, 0.6]), + b_error=None, +) + +_coeffs_hh = GeneralCoeffs( + a=(np.array([1.0]), np.array([1.0, 0.0])), + b_sol=np.array([0.0, 1.2, -1.2]), + b_error=None, +) + +_tab = StochasticButcherTableau( + a=[np.array([0.0]), np.array([0.0, 5 / 6])], + b_sol=np.array([0.0, 0.4, 0.6]), + b_error=None, + c=np.array([0.0, 5 / 6]), + coeffs_w=_coeffs_w, + coeffs_hh=_coeffs_hh, + coeffs_kk=None, + ignore_stage_f=np.array([True, False, False]), + ignore_stage_g=None, +) + + +class GeneralShARK(AbstractSRK, AbstractStratonovichSolver): + r"""ShARK method for Stratonovich SDEs. + + As compared to [`diffrax.ShARK`][] this can handle any SDE, not only those with + additive noise. + + Makes two evaluations of the drift and three evaluations of the diffusion per step. + Has the following orders of convergence: + + - 1.5 for SDEs with additive noise (but [`diffrax.ShARK`][] is recommended instead) + - 1.0 for Stratonovich SDEs with commutative noise + ([`diffrax.SlowRK`][] is recommended instead) + - 0.5 for Stratonovich SDEs with general noise. + + For general Stratonovich SDEs this is equally precise as three steps of + [`diffrax.Heun`][] or a single step of [`diffrax.SPaRK`][], while requiring + one fewer evaluation of the drift, so this is the recommended choice for general + SDEs with an expensive drift vector field. If embedded error estimation is needed + (e.g. for adaptive time-stepping) then [`diffrax.SPaRK`][] is recommended instead. + + ??? cite "Reference" + + This solver is based on equation (6.1) from + + ```bibtex + @misc{foster2023high, + title={High order splitting methods for SDEs satisfying + a commutativity condition}, + author={James Foster and Goncalo dos Reis and Calum Strange}, + year={2023}, + eprint={2210.17543}, + archivePrefix={arXiv}, + primaryClass={math.NA} + ``` + """ + + tableau: ClassVar[StochasticButcherTableau] = _tab + + def order(self, terms): + return 2 + + def strong_order(self, terms): + return 1.5 diff --git a/diffrax/_solver/slowrk.py b/diffrax/_solver/slowrk.py new file mode 100644 index 00000000..86eeb6ae --- /dev/null +++ b/diffrax/_solver/slowrk.py @@ -0,0 +1,88 @@ +from typing import ClassVar + +import numpy as np + +from .base import AbstractStratonovichSolver +from .srk import AbstractSRK, GeneralCoeffs, StochasticButcherTableau + + +_coeffs_w = GeneralCoeffs( + a=( + np.array([0.0]), + np.array([0.0, 0.5]), + np.array([0.0, 0.0, 0.5]), + np.array([0.0, 0.0, 0.0, 1.0]), + np.array([0.0, 0.0, 0.0, 0.75, 0.0]), + np.array([0.0, 0.0, 0.5, 0.0, 0.0, 0.0]), + ), + b_sol=np.array([0.0, 1 / 6, 1 / 3, 1 / 3, 1 / 6, 0.0, 0.0]), + b_error=None, +) + +_coeffs_hh = GeneralCoeffs( + a=( + np.array([0.0]), + np.array([0.0, 0.0]), + np.array([0.0, 0.0, 0.0]), + np.array([0.0, 0.0, 0.0, 0.0]), + np.array([0.0, 0.0, 0.0, 1.5, 0.0]), + np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + ), + b_sol=np.array([0.0, 0.0, 0.0, 2.0, 0.0, 0.0, -2.0]), + b_error=None, +) + +_tab = StochasticButcherTableau( + a=[ + np.array([0.5]), + np.array([0.5, 0.0]), + np.array([0.5, 0.0, 0.0]), + np.array([0.5, 0.0, 0.0, 0.0]), + np.array([0.75, 0.0, 0.0, 0.0, 0.0]), + np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0]), + ], + b_sol=np.array([1 / 3, 0.0, 0.0, 0.0, 0.0, 2 / 3, 0.0]), + b_error=None, + c=np.array([0.5, 0.5, 0.5, 0.5, 0.75, 1.0]), + coeffs_w=_coeffs_w, + coeffs_hh=_coeffs_hh, + coeffs_kk=None, + ignore_stage_f=np.array([False, True, True, True, True, False, True]), + ignore_stage_g=np.array([True, False, False, False, False, True, False]), +) + + +class SlowRK(AbstractSRK, AbstractStratonovichSolver): + r"""SLOW-RK method for commutative-noise Stratonovich SDEs. + + Makes two evaluations of the drift and five evaluations of the diffusion per step. + Applied to SDEs with commutative noise, it converges strongly with order 1.5. + Can be used for SDEs with non-commutative noise, but then it only converges + strongly with order 0.5. + + This solver is an excellent choice for Stratonovich SDEs with commutative noise. + For non-commutative Stratonovich SDEs, consider using [`diffrax.GeneralShARK`][] + or [`diffrax.SPaRK`][] instead. + + ??? cite "Reference" + + This solver is based on equation (6.2) from + + ```bibtex + @article{foster2023high, + title={High order splitting methods for SDEs satisfying a commutativity + condition}, + author={James Foster and Goncalo dos Reis and Calum Strange}, + year={2023}, + journal={arXiv:2210.17543}, + } + ``` + """ + + tableau: ClassVar[StochasticButcherTableau] = _tab + + def order(self, terms): + return 2 + + def strong_order(self, terms): + return 0.5 diff --git a/diffrax/_solver/spark.py b/diffrax/_solver/spark.py new file mode 100644 index 00000000..bae71803 --- /dev/null +++ b/diffrax/_solver/spark.py @@ -0,0 +1,75 @@ +from typing import ClassVar + +import numpy as np + +from .base import AbstractStratonovichSolver +from .srk import AbstractSRK, GeneralCoeffs, StochasticButcherTableau + + +_x1 = (3 - np.sqrt(3)) / 6 +_x2 = np.sqrt(3) / 3 + +_coeffs_w = GeneralCoeffs( + a=(np.array([0.5]), np.array([0.0, 1.0])), + b_sol=np.array([_x1, _x2, _x1]), + b_error=np.array([_x1 - 0.5, _x2, _x1 - 0.5]), +) + +_coeffs_hh = GeneralCoeffs( + a=(np.array([np.sqrt(3.0)]), np.array([0.0, 0.0])), + b_sol=np.array([1.0, 0.0, -1.0]), + b_error=np.array([1.0, 0.0, -1.0]), +) + +_tab = StochasticButcherTableau( + a=[np.array([0.5]), np.array([0.0, 1.0])], + b_sol=np.array([_x1, _x2, _x1]), + b_error=np.array([_x1 - 0.5, _x2, _x1 - 0.5]), + c=np.array([0.5, 1.0]), + coeffs_w=_coeffs_w, + coeffs_hh=_coeffs_hh, + coeffs_kk=None, + ignore_stage_f=None, + ignore_stage_g=None, +) + + +class SPaRK(AbstractSRK, AbstractStratonovichSolver): + r"""The Splitting Path Runge-Kutta method. + + It uses three evaluations of the drift and diffusion per step, and has the following + strong orders of convergence: + + - 1.5 for SDEs with additive noise (but [`diffrax.ShARK`][] is recommended instead) + - 1.0 for Stratonovich SDEs with commutative noise + ([`diffrax.SlowRK`][] is recommended instead) + - 0.5 for Stratonovich SDEs with general noise. + + For general Stratonovich SDEs this is equally precise as three steps of + [`diffrax.Heun`][] or a single step of [`diffrax.GeneralShARK`][]. Unlike those, + this method has an embedded error estimate, so it is the recommended choice for + adaptive time-stepping. Otherwise, [`diffrax.GeneralShARK`][] is more efficient. + + ??? cite "Reference" + + This solver is based on Definition 1.6 from + + ```bibtex + @misc{foster2023convergence, + title={On the convergence of adaptive approximations + for stochastic differential equations}, + author={James Foster}, + year={2023}, + archivePrefix={arXiv}, + primaryClass={math.NA} + } + ``` + """ + + tableau: ClassVar[StochasticButcherTableau] = _tab + + def order(self, terms): + return 2 + + def strong_order(self, terms): + return 1.5 diff --git a/diffrax/_solver/sra1.py b/diffrax/_solver/sra1.py new file mode 100644 index 00000000..fcc16c0e --- /dev/null +++ b/diffrax/_solver/sra1.py @@ -0,0 +1,62 @@ +from typing import ClassVar + +import numpy as np + +from .base import AbstractStratonovichSolver +from .srk import AbstractSRK, AdditiveCoeffs, StochasticButcherTableau + + +_coeffs_w = AdditiveCoeffs( + a=np.array([0.0, 3 / 4]), + b_sol=np.array(1.0), +) + +_coeffs_hh = AdditiveCoeffs( + a=np.array([0.0, 1.5]), + b_sol=np.array(0.0), +) + +_tab = StochasticButcherTableau( + a=[np.array([3 / 4])], + b_sol=np.array([1 / 3, 2 / 3]), + b_error=np.array([-2 / 3, 2 / 3]), + c=np.array([3 / 4]), + coeffs_w=_coeffs_w, + coeffs_hh=_coeffs_hh, + coeffs_kk=None, + ignore_stage_f=None, + ignore_stage_g=None, +) + + +class SRA1(AbstractSRK, AbstractStratonovichSolver): + r"""The SRA1 method for additive-noise SDEs. + + Makes two evaluations of the drift and diffusion per step and has a strong order + 1.5. + + See also [`diffrax.ShARK`][], which is very similar. + + ??? cite "Reference" + + ```bibtex + @article{rossler2010runge + author = {Andreas R\"{o}\ss{}ler}, + title = {Runge–Kutta Methods for the Strong Approximation of Solutions of + Stochastic Differential Equations}, + journal = {SIAM Journal on Numerical Analysis}, + volume = {48}, + number = {3}, + pages = {922--952}, + year = {2010}, + doi = {10.1137/09076636X}, + ``` + """ + + tableau: ClassVar[StochasticButcherTableau] = _tab + + def order(self, terms): + return 2 + + def strong_order(self, terms): + return 1.5 diff --git a/diffrax/_solver/srk.py b/diffrax/_solver/srk.py new file mode 100644 index 00000000..76237401 --- /dev/null +++ b/diffrax/_solver/srk.py @@ -0,0 +1,643 @@ +import abc +from dataclasses import dataclass +from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import TypeAlias + +import equinox as eqx +import equinox.internal as eqxi +import jax +import jax.lax as lax +import jax.numpy as jnp +import jax.tree_util as jtu +import numpy as np +from equinox.internal import ω +from jaxtyping import Array, Float, PyTree + +from .._custom_types import ( + AbstractBrownianIncrement, + AbstractSpaceTimeLevyArea, + AbstractSpaceTimeTimeLevyArea, + BoolScalarLike, + DenseInfo, + FloatScalarLike, + IntScalarLike, + RealScalarLike, + VF, + Y, +) +from .._local_interpolation import LocalLinearInterpolation +from .._solution import RESULTS +from .._term import AbstractTerm, MultiTerm, ODETerm +from .base import AbstractSolver + + +if TYPE_CHECKING: + from typing import ClassVar as AbstractClassVar +else: + from equinox import AbstractClassVar + +_ErrorEstimate: TypeAlias = Optional[Y] +_SolverState: TypeAlias = None +_CarryType: TypeAlias = tuple[PyTree[Array], PyTree[Array], PyTree[Array]] + + +class AbstractStochasticCoeffs(eqx.Module): + a: eqx.AbstractVar[Union[Float[np.ndarray, " s"], tuple[np.ndarray, ...]]] + b_sol: eqx.AbstractVar[Union[Float[np.ndarray, " s"], FloatScalarLike]] + b_error: eqx.AbstractVar[Optional[Float[np.ndarray, " s"]]] + + @abc.abstractmethod + def check(self) -> int: + ... + + +class AdditiveCoeffs(AbstractStochasticCoeffs): + """Coefficients for either the Brownian increment or its Lévy areas in an + SRK for solving additive noise SDEs. + """ + + # Assuming SDE has additive noise, then we only need a 1-dimensional array + # of length s for the coefficients in front of the Brownian increment + # and/or Lévy areas (where s is the number of stages of the solver). + # This is the equivalent of the matrix a for the Brownian motion and + # its Lévy areas. + a: Float[np.ndarray, " s"] + b_sol: FloatScalarLike + + # Explicitly declare to keep pyright happy. + def __init__(self, a: Float[np.ndarray, " s"], b_sol: FloatScalarLike): + self.a = a + self.b_sol = b_sol + + @property + def b_error(self): + return None + + def check(self): + assert self.a.ndim == 1 + return self.a.shape[0] + + +class GeneralCoeffs(AbstractStochasticCoeffs): + """General coefficients for either the Brownian increment or its Lévy areas in an + SRK for solving SDEs with any type of noise (i.e. non-additive). + """ + + a: tuple[np.ndarray, ...] + b_sol: Float[np.ndarray, " s"] + b_error: Optional[Float[np.ndarray, " s"]] + + def check(self): + assert self.b_sol.ndim == 1 + assert all((i + 1,) == a_i.shape for i, a_i in enumerate(self.a)) + assert self.b_sol.shape[0] == len(self.a) + 1 + if self.b_error is not None: + assert self.b_error.ndim == 1 + assert self.b_error.shape == self.b_sol.shape + return self.b_sol.shape[0] + + +_Coeffs = TypeVar("_Coeffs", bound=AbstractStochasticCoeffs) + + +@dataclass(frozen=True) +class StochasticButcherTableau(Generic[_Coeffs]): + """A Butcher Tableau for Stochastic Runge-Kutta methods.""" + + # Coefficinets for the drift + a: list[np.ndarray] + b_sol: np.ndarray + b_error: Optional[np.ndarray] + c: np.ndarray + + # Coefficients for the Brownian increment + coeffs_w: _Coeffs + coeffs_hh: Optional[_Coeffs] + coeffs_kk: Optional[_Coeffs] + + # For some stages we may not need to evaluate the vector field for both + # the drift and the diffusion. This avoids unnecessary computations. + ignore_stage_f: Optional[np.ndarray] + ignore_stage_g: Optional[np.ndarray] + + def is_additive_noise(self): + return isinstance(self.coeffs_w, AdditiveCoeffs) + + def __post_init__(self): + assert self.c.ndim == 1 + for a_i in self.a: + assert a_i.ndim == 1 + assert self.b_sol.ndim == 1 + assert (self.b_error is None) or self.b_error.ndim == 1 + assert self.c.shape[0] == len(self.a) + assert all(i + 1 == a_i.shape[0] for i, a_i in enumerate(self.a)) + num_stages = len(self.b_sol) + assert (self.b_error is None) or self.b_error.shape[0] == num_stages + assert self.c.shape[0] + 1 == num_stages + assert np.allclose(sum(self.b_sol), 1.0) + + assert self.coeffs_w.check() == num_stages + if self.coeffs_hh is not None: + assert type(self.coeffs_hh) is type(self.coeffs_w) + assert self.coeffs_hh.check() == num_stages + if self.coeffs_kk is not None: + assert self.coeffs_hh is not None, ( + "If space-time-time Levy area (K) is used," + " space-time Levy area (H) must also be used." + ) + assert type(self.coeffs_kk) is type(self.coeffs_w) + assert self.coeffs_kk.check() == num_stages + + if self.b_error is not None and (not self.is_additive_noise()): + assert self.coeffs_w.b_error is not None + assert (self.coeffs_hh is None) or (self.coeffs_hh.b_error is not None) + assert (self.coeffs_kk is None) or (self.coeffs_kk.b_error is not None) + + if self.ignore_stage_f is not None: + assert len(self.ignore_stage_f) == len(self.b_sol) + if self.ignore_stage_g is not None: + assert len(self.ignore_stage_g) == len(self.b_sol) + if self.ignore_stage_f is not None and self.ignore_stage_g is not None: + # Check that no stages are ignored for both the drift and diffusion + assert not np.any(self.ignore_stage_f & self.ignore_stage_g) + + +StochasticButcherTableau.__init__.__doc__ = """The coefficients of a +[`diffrax.AbstractSRK`][] method. + +See the documentation for [`diffrax.AbstractSRK`][] for additional details on the +mathematical meaning of each of these arguments. + +**Arguments:** + +Let `s` denote the number of stages of the solver. + +- `a`: The lower triangle (without the diagonal) of the Butcher tableau for the drift + term. Should be a tuple of NumPy arrays, corresponding to the rows of this lower + triangle. The first array should be of shape `(1,)`. Each subsequent array should + be of shape `(2,)`, `(3,)` etc. The final array should have shape `(s - 1,)`. +- `b_sol`: The linear combination of drift stages to take to produce the output at each + step. Should be a NumPy array of shape `(s,)`. +- `b_error`: The linear combination of stages to take to produce the error estimate at + each step. Should be a NumPy array of shape `(s,)`. Note that this is *not* + differenced against `b_sol` prior to evaluation. (i.e. `b_error` gives the linear + combination for producing the error estimate directly, not for producing some + alternate solution that is compared against the main solution). +- `c`: The time increments used in the Butcher tableau. + Should be a NumPy array of shape `(s-1,)`, as the first stage has time increment 0. +- `cfs_bm`: An instance of a subclass of `AbstractStochasticTableau` representing + the coefficients for the Brownian increment and possibly its Levy areas. +- `ignore_stage_f`: Optional. A NumPy array of length `s` of booleans. If `True` at + stage `j`, the vector field of the drift term will not be evaluated at stage `j`. +- `ignore_stage_g`: Optional. A NumPy array of length `s` of booleans. If `True` at + stage `j`, the diffusion vector field will not be evaluated at stage `j`. +""" + + +class AbstractSRK(AbstractSolver[_SolverState]): + r"""A general Stochastic Runge-Kutta method. + + This accepts `terms` of the form + `MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))`. + Depending on the solver, the Brownian motion might need to generate + different types of Levy areas, specified by the `minimal_levy_area` attribute. + + For example, the [`diffrax.ShARK`][] solver requires space-time Levy area, so + it will have `minimal_levy_area = AbstractSpaceTimeLevyArea` and the Brownian + motion must be initialised with `levy_area=SpaceTimeLevyArea`. + + Given the Stratonovich SDE + $dy(t) = f(t, y(t)) dt + g(t, y(t)) \circ dw(t)$ + + We construct the SRK with $s$ stages as follows: + + $y_{n+1} = y_n + h \Big(\sum_{j=1}^s b_j f_j \Big) + + W_n \Big(\sum_{j=1}^s b^W_j g_j \Big) + + H_n \Big(\sum_{j=1}^s b^H_j g_j \Big)$ + + $f_j = f(t_0 + c_j h , z_j)$ + + $g_j = g(t_0 + c_j h , z_j)$ + + $z_j = y_n + h \Big(\sum_{i=1}^{j-1} a_{j,i} f_i \Big) + + W_n \Big(\sum_{i=1}^{j-1} a^W_{j,i} g_i \Big) + + H_n \Big(\sum_{i=1}^{j-1} a^H_{j,i} g_i \Big)$ + + where $W_n = W_{t_n, t_{n+1}}$ is the increment of the Brownian motion and + $H_n = H_{t_n, t_{n+1}}$ is its corresponding space-time Lévy Area, defined + as $H_{s,t} = \frac{1}{t-s} \int_s^t (W_{s,r} - \frac{r-s}{t-s} W_{s,t}) \, dr$. + A similar term can also be added for the space-time-time Lévy area, K, + defined as $K_{s,t} = \frac{1}{(t-s)^2} \int_s^t (W_{s,r} - \frac{r-s}{t-s} + W_{s,t}) \left( \frac{t+s}{2} - r \right) \, dr$. + + In the special case, when the SDE has additive noise, i.e. when g is + independent of y (but can still depend on t), then the SDE can be written as + $dy(t) = f(t, y(t)) dt + g(t) \, dw(t)$, and we can simplify the above to + + $y_{n+1} = y_n + h \Big(\sum_{j=1}^s b_j k_j \Big) + g(t_n) \, (b^W + \, W_n + b^H \, H_n)$ + + $f_j = f(t_n + c_j h , z_j)$ + + $z_j = y_n + h \Big(\sum_{i=1}^{j-1} a_{j,i} f_i \Big) + g(t_n) + \, (a^W_j W_n + a^H_j H_n)$ + + When g depends on t, we need to add a correction term to $y_{n+1}$ of + the form $(g(t_{n+1}) - g(t_n)) \, (\frac{1}{2} W_n - H_n)$. + + The coefficients are provided in the [`diffrax.StochasticButcherTableau`][]. + In particular the coefficients $b^W$, and $a^W$ are provided in `tableau.cfs_bm`, + as well as $b^H$, $a^H$, $b^K$, and $a^K$ if needed. + """ + + interpolation_cls = LocalLinearInterpolation + term_compatible_contr_kwargs = (dict(), dict(use_levy=True)) + tableau: AbstractClassVar[StochasticButcherTableau] + + # Indicates the type of Levy area used by the solver. + # The BM must generate at least this type of Levy area, but can generate + # more. E.g. if the solver uses space-time Levy area, then the BM generates + # space-time-time Levy area as well that is fine. The other way around would + # not work. This is mostly an easily readable indicator so that methods know + # what kind of BM to use. + @property + def minimal_levy_area(self) -> type[AbstractBrownianIncrement]: + if self.tableau.coeffs_kk is not None: + return AbstractSpaceTimeTimeLevyArea + elif self.tableau.coeffs_hh is not None: + return AbstractSpaceTimeLevyArea + else: + return AbstractBrownianIncrement + + @property + def term_structure(self): + return MultiTerm[tuple[ODETerm, AbstractTerm[Any, self.minimal_levy_area]]] + + def init( + self, + terms: MultiTerm[tuple[ODETerm, AbstractTerm[Any, AbstractBrownianIncrement]]], + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + args: PyTree, + ) -> _SolverState: + # Check that the diffusion has the correct Levy area + _, diffusion = terms.terms + + if self.tableau.is_additive_noise(): + # check that the vector field of the diffusion term does not depend on y + ones_like_y0 = jtu.tree_map(jnp.ones_like, y0) + _, y_sigma = eqx.filter_jvp( + lambda y: diffusion.vf(t0, y, args), (y0,), (ones_like_y0,) + ) + # check if the PyTree is just made of Nones (inside other containers) + if len(jtu.tree_leaves(y_sigma)) > 0: + raise ValueError( + "Vector field of the diffusion term should be constant, " + "independent of y." + ) + + return None + + def _embed_a_lower(self, _a, dtype): + num_stages = len(self.tableau.b_sol) + tab_a = np.zeros((num_stages, num_stages)) + for i, a_i in enumerate(_a): + tab_a[i + 1, : i + 1] = a_i + return jnp.asarray(tab_a, dtype=dtype) + + def step( + self, + terms: MultiTerm[tuple[ODETerm, AbstractTerm[Any, AbstractBrownianIncrement]]], + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + args: PyTree, + solver_state: _SolverState, + made_jump: BoolScalarLike, + ) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: + del solver_state, made_jump + + dtype = jnp.result_type(*jtu.tree_leaves(y0)) + drift, diffusion = terms.terms + if self.tableau.ignore_stage_f is None: + ignore_stage_f = None + else: + ignore_stage_f = jnp.array(self.tableau.ignore_stage_f) + if self.tableau.ignore_stage_g is None: + ignore_stage_g = None + else: + ignore_stage_g = jnp.array(self.tableau.ignore_stage_g) + + # time increment + h = t1 - t0 + + # First the drift related stuff + a = self._embed_a_lower(self.tableau.a, dtype) + c = jnp.asarray(np.insert(self.tableau.c, 0, 0.0), dtype=dtype) + b_sol = jnp.asarray(self.tableau.b_sol, dtype=dtype) + + def make_zeros(): + def make_zeros_aux(leaf): + return jnp.zeros((len(b_sol),) + leaf.shape, dtype=leaf.dtype) + + return jtu.tree_map(make_zeros_aux, y0) + + # h_kfs is a PyTree of the same shape as y0, except that the arrays inside + # have an additional batch dimension of size len(b_sol) (i.e. num stages) + # This will be one of the entries of the carry of lax.scan. In each stage + # one of the zeros will get replaced by the value of + # h_kf_j := h * f(t0 + c_j * h, z_j) where z_j is the jth stage of the SRK. + # The name h_kf_j is because it refers to the values of f (as opposed to g) + # at stage j, which has already been multiplied by the time increment h. + h_kfs = make_zeros() + + # Now the diffusion related stuff + # Brownian increment (and space-time Lévy area) + bm_inc = diffusion.contr(t0, t1, use_levy=True) + assert isinstance(bm_inc, self.minimal_levy_area) + w = bm_inc.W + + # b looks similar regardless of whether we have additive noise or not + b_w = jnp.asarray(self.tableau.coeffs_w.b_sol, dtype=dtype) + b_levy_list = [] + + levy_areas = [] + if self.tableau.coeffs_hh is not None: # space-time Levy area + assert isinstance(bm_inc, AbstractSpaceTimeLevyArea) + levy_areas.append(bm_inc.H) + b_levy_list.append(jnp.asarray(self.tableau.coeffs_hh.b_sol, dtype=dtype)) + + if self.tableau.coeffs_kk is not None: # space-time-time Levy area + assert isinstance(bm_inc, AbstractSpaceTimeTimeLevyArea) + levy_areas.append(bm_inc.K) + b_levy_list.append( + jnp.asarray(self.tableau.coeffs_kk.b_sol, dtype=dtype) + ) + + def add_levy_to_w(_cw, *_c_levy): + def aux_add_levy(w_leaf, *levy_leaves): + return _cw * w_leaf + sum( + _c * _leaf for _c, _leaf in zip(_c_levy, levy_leaves) + ) + + return aux_add_levy + + a_levy = [] # if noise is additive this is [cH, cK] (if those entries exist) + # otherwise this is [aH, aK] (if those entries exist) + + levylist_kgs = [] # will contain levy * g(t0 + c_j * h, z_j) for each stage j + # where levy is either H or K (if those entries exist) + # this is similar to h_kfs or w_kgs, but for the Levy area(s) + + if self.tableau.is_additive_noise(): # additive noise + # compute g once since it is constant + + @jax.vmap + def _comp_g(_t): + return diffusion.vf(_t, y0, args) + + g0_g1 = _comp_g(jnp.array([t0, t1], dtype=dtype)) + g0 = jtu.tree_map(lambda g_leaf: g_leaf[0], g0_g1) + # g_delta = 0.5 * g1 - g0 + g_delta = jtu.tree_map(lambda g_leaf: 0.5 * (g_leaf[1] - g_leaf[0]), g0_g1) + w_kgs = diffusion.prod(g0, w) + a_w = jnp.asarray(self.tableau.coeffs_w.a, dtype=dtype) + + if self.tableau.coeffs_hh is not None: # space-time Levy area + assert isinstance(bm_inc, AbstractSpaceTimeLevyArea) + levylist_kgs.append(diffusion.prod(g0, bm_inc.H)) + a_levy.append(jnp.asarray(self.tableau.coeffs_hh.a, dtype=dtype)) + + if self.tableau.coeffs_kk is not None: # space-time-time Levy area + assert isinstance(bm_inc, AbstractSpaceTimeTimeLevyArea) + levylist_kgs.append(diffusion.prod(g0, bm_inc.K)) + a_levy.append(jnp.asarray(self.tableau.coeffs_kk.a, dtype=dtype)) + + carry: _CarryType = (h_kfs, None, None) + + else: # general (non-additive) noise + g_delta = None # so pyright doesn't complain + + # g is not constant, so we need to compute it at each stage + # we will carry the value of W * g(t0 + c_j * h, z_j) + # Since the carry of lax.scan needs to have constant shape, + # we initialise a list of zeros of the same shape as y0, which will get + # filled with the values of W * g(t0 + c_j * h, z_j) at each stage + w_kgs = make_zeros() + a_w = self._embed_a_lower(self.tableau.coeffs_w.a, dtype) + + # do the same for each type of Levy area + if self.tableau.coeffs_hh is not None: # space-time Levy area + levylist_kgs.append(make_zeros()) + a_levy.append(self._embed_a_lower(self.tableau.coeffs_hh.a, dtype)) + if self.tableau.coeffs_kk is not None: # space-time-time Levy area + levylist_kgs.append(make_zeros()) + a_levy.append(self._embed_a_lower(self.tableau.coeffs_kk.a, dtype)) + + carry: _CarryType = (h_kfs, w_kgs, levylist_kgs) + + stage_nums = jnp.arange(len(self.tableau.b_sol)) + + scan_inputs = (stage_nums, a, c, a_w, a_levy) + + def sum_prev_stages(_stage_out_buff, _a_j): + # Unwrap the buffer + _stage_out_view = jtu.tree_map( + lambda _, _buff: _buff[...], y0, _stage_out_buff + ) + # Sum up the previous stages weighted by the coefficients in the tableau + return jtu.tree_map( + lambda lf: jnp.tensordot(_a_j, lf, axes=1), _stage_out_view + ) + + def insert_jth_stage(results, k_j, j): + # Insert the result of the jth stage into the buffer + return jtu.tree_map( + lambda k_j_leaf, res_leaf: res_leaf.at[j].set(k_j_leaf), k_j, results + ) + + def stage( + _carry: _CarryType, + x: tuple[IntScalarLike, Array, Array, Array, list[Array]], + ): + # Represents the jth stage of the SRK. + + j, a_j, c_j, a_w_j, a_levy_list_j = x + # a_levy_list_j = [aH_j, aK_j] (if those entries exist) where + # aH_j is the row in the aH matrix corresponding to stage j + # same for aK_j, but for space-time-time Lévy area K. + _h_kfs, _w_kgs, _levylist_kgs = _carry + + if self.tableau.is_additive_noise(): + # carry = (_h_kfs, None, None) where + # _h_kfs = Array[h_kf_1, h_kf_2, ..., hk_{j-1}, 0, 0, ..., 0] + # h_kf_i = drift.vf_prod(t0 + c_i*h, y_i, args, h) + assert _w_kgs is None and _levylist_kgs is None + assert isinstance(levylist_kgs, list) + _diffusion_result = jtu.tree_map( + add_levy_to_w(a_w_j, *a_levy_list_j), + w_kgs, + *levylist_kgs, + ) + else: + # carry = (_h_kfs, _w_kgs, _levylist_kgs) where + # _h_kfs = Array[h_kf_1, h_kf_2, ..., h_kf_{j-1}, 0, 0, ..., 0] + # _w_kgs = Array[w_kg_1, w_kg_2, ..., w_kg_{j-1}, 0, 0, ..., 0] + # _levylist_kgs = [H_gs, K_gs] (if those entries exist) where + # H_gs = Array[Hg1, Hg2, ..., Hg{j-1}, 0, 0, ..., 0] + # K_gs = Array[Kg1, Kg2, ..., Kg{j-1}, 0, 0, ..., 0] + # h_kf_i = drift.vf_prod(t0 + c_i*h, y_i, args, h) + # w_kg_i = diffusion.vf_prod(t0 + c_i*h, y_i, args, w) + w_kg_sum = sum_prev_stages(_w_kgs, a_w_j) + levy_sum_list = [ + sum_prev_stages(levy_gs, a_levy_j) + for a_levy_j, levy_gs in zip(a_levy_list_j, _levylist_kgs) + ] + + _diffusion_result = jtu.tree_map( + lambda _w_kg, *_levy_g: sum(_levy_g, _w_kg), + w_kg_sum, + *levy_sum_list, + ) + + # compute Σ_{i=1}^{j-1} a_j_i h_kf_i + _drift_result = sum_prev_stages(_h_kfs, a_j) + + # z_j = y_0 + h (Σ_{i=1}^{j-1} a_j_i k_i) + g * (a_w_j * ΔW + cH_j * ΔH) + z_j = (y0**ω + _drift_result**ω + _diffusion_result**ω).ω + + def compute_and_insert_kf_j(_h_kfs_in): + h_kf_j = drift.vf_prod(t0 + c_j * h, z_j, args, h) + return insert_jth_stage(_h_kfs_in, h_kf_j, j) + + if ignore_stage_f is None: + _h_kfs = compute_and_insert_kf_j(_h_kfs) + else: + drift_pred = jnp.logical_not(ignore_stage_f[j]) + _h_kfs = lax.cond( + eqxi.nonbatchable(drift_pred), + compute_and_insert_kf_j, + lambda what: what, + _h_kfs, + ) + + if self.tableau.is_additive_noise(): + return (_h_kfs, None, None), None + + def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in): + _w_kg_j = diffusion.vf_prod(t0 + c_j * h, z_j, args, w) + new_w_kgs = insert_jth_stage(_w_kgs_in, _w_kg_j, j) + + _levylist_kg_j = [ + diffusion.vf_prod(t0 + c_j * h, z_j, args, levy) + for levy in levy_areas + ] + new_levylist_kgs = insert_jth_stage(_levylist_kgs_in, _levylist_kg_j, j) + return new_w_kgs, new_levylist_kgs + + if ignore_stage_g is None: + _w_kgs, _levylist_kgs = compute_and_insert_kg_j(_w_kgs, _levylist_kgs) + else: + diffusion_pred = jnp.logical_not(ignore_stage_g[j]) + _w_kgs, _levylist_kgs = lax.cond( + eqxi.nonbatchable(diffusion_pred), + compute_and_insert_kg_j, + lambda x, y: (x, y), + _w_kgs, + _levylist_kgs, + ) + + return (_h_kfs, _w_kgs, _levylist_kgs), None + + scan_out = eqxi.scan( + stage, + carry, + scan_inputs, + len(b_sol), + buffers=lambda x: x, + kind="checkpointed", + checkpoints="all", + ) + + if self.tableau.is_additive_noise(): + # output of lax.scan is ((num_stages, _h_kfs), None) + (h_kfs, _, _), _ = scan_out + diffusion_result = jtu.tree_map( + add_levy_to_w(b_w, *b_levy_list), + w_kgs, + *levylist_kgs, + ) + + # In the additive noise case (i.e. when g is independent of y), + # we still need a correction term in case the diffusion vector field + # g depends on t. This term is of the form $(g1 - g0) * (0.5*W_n - H_n)$. + if self.tableau.coeffs_hh is not None: # space-time Levy area + assert isinstance(bm_inc, AbstractSpaceTimeLevyArea) + time_var_contr = (bm_inc.W**ω - 2.0 * bm_inc.H**ω).ω + time_var_term = diffusion.prod(g_delta, time_var_contr) + else: + time_var_term = diffusion.prod(g_delta, bm_inc.W) + diffusion_result = (diffusion_result**ω + time_var_term**ω).ω + + else: + # output of lax.scan is ((num_stages, _h_kfs, _w_kgs, _levylist_kgs), None) + (h_kfs, w_kgs, levylist_kgs), _ = scan_out + b_w_kgs = sum_prev_stages(w_kgs, b_w) + b_levylist_kgs = [ + sum_prev_stages(levy_gs, b_levy) + for b_levy, levy_gs in zip(b_levy_list, levylist_kgs) + ] + diffusion_result = jtu.tree_map( + lambda b_w_kg, *b_levy_g: sum(b_levy_g, b_w_kg), + b_w_kgs, + *b_levylist_kgs, + ) + + # compute Σ_{j=1}^s b_j k_j + if self.tableau.b_error is None: + error = None + else: + b_err = jnp.asarray(self.tableau.b_error, dtype=dtype) + drift_error = sum_prev_stages(h_kfs, b_err) + if self.tableau.coeffs_w.b_error is not None: + get_err_w = self.tableau.coeffs_w.b_error + assert get_err_w is not None + bw_err = jnp.asarray(get_err_w, dtype=dtype) + w_err = sum_prev_stages(w_kgs, bw_err) + b_levy_err_list = [] + if self.tableau.coeffs_hh is not None: + get_err_hh = self.tableau.coeffs_hh.b_error + assert get_err_hh is not None + b_levy_err_list.append(jnp.asarray(get_err_hh, dtype=dtype)) + if self.tableau.coeffs_kk is not None: + get_err_kk = self.tableau.coeffs_kk.b_error + assert get_err_kk is not None + b_levy_err_list.append(jnp.asarray(get_err_kk, dtype=dtype)) + levy_err = [ + sum_prev_stages(levy_gs, b_levy_err) + for b_levy_err, levy_gs in zip(b_levy_err_list, levylist_kgs) + ] + diffusion_error = jtu.tree_map( + lambda _w_err, *_levy_err: sum(_levy_err, _w_err), w_err, *levy_err + ) + error = (drift_error**ω + diffusion_error**ω).ω + else: + error = drift_error + + # y1 = y0 + (Σ_{i=1}^{s} b_j * h*k_j) + g * (b_w * ΔW + b_H * ΔH) + + drift_result = sum_prev_stages(h_kfs, b_sol) + + y1 = (y0**ω + drift_result**ω + diffusion_result**ω).ω + dense_info = dict(y0=y0, y1=y1) + return y1, error, dense_info, None, RESULTS.successful + + def func( + self, + terms: MultiTerm[tuple[ODETerm, AbstractTerm[Any, AbstractBrownianIncrement]]], + t0: RealScalarLike, + y0: Y, + args: PyTree, + ) -> VF: + return terms.vf(t0, y0, args) diff --git a/diffrax/_term.py b/diffrax/_term.py index 28b2b213..2f7eca30 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -50,7 +50,7 @@ def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF: pass @abc.abstractmethod - def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> _Control: + def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control: r"""The control. Represents the $\mathrm{d}t$ in an ODE, or the $\mathrm{d}w(t)$ in an SDE, etc. @@ -198,7 +198,7 @@ def _broadcast_and_upcast(oi, yi): return jtu.tree_map(_broadcast_and_upcast, out, y) - def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike: + def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> RealScalarLike: return t1 - t0 def prod(self, vf: _VF, control: RealScalarLike) -> Y: @@ -267,8 +267,8 @@ class _AbstractControlTerm(AbstractTerm[_VF, _Control]): def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF: return self.vector_field(t, y, args) - def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> _Control: - return self.control.evaluate(t0, t1) # pyright: ignore + def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control: + return self.control.evaluate(t0, t1, **kwargs) # pyright: ignore def to_ode(self) -> ODETerm: r"""If the control is differentiable then $f(t, y(t), args) \mathrm{d}x(t)$ @@ -411,9 +411,9 @@ def vf(self, t: RealScalarLike, y: Y, args: Args) -> tuple[PyTree[ArrayLike], .. return tuple(term.vf(t, y, args) for term in self.terms) def contr( - self, t0: RealScalarLike, t1: RealScalarLike + self, t0: RealScalarLike, t1: RealScalarLike, **kwargs ) -> tuple[PyTree[ArrayLike], ...]: - return tuple(term.contr(t0, t1) for term in self.terms) + return tuple(term.contr(t0, t1, **kwargs) for term in self.terms) def prod( self, vf: tuple[PyTree[ArrayLike], ...], control: tuple[PyTree[ArrayLike], ...] @@ -455,10 +455,10 @@ def vf(self, t: RealScalarLike, y: Y, args: Args) -> _VF: t = t * self.direction return self.term.vf(t, y, args) - def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> _Control: + def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control: _t0 = jnp.where(self.direction == 1, t0, -t1) _t1 = jnp.where(self.direction == 1, t1, -t0) - return (self.direction * self.term.contr(_t0, _t1) ** ω).ω + return (self.direction * self.term.contr(_t0, _t1, **kwargs) ** ω).ω def prod(self, vf: _VF, control: _Control) -> Y: return self.term.prod(vf, control) @@ -558,8 +558,8 @@ def _fn(_control): ) return jtu.tree_transpose(vf_prod_tree, control_tree, jac) - def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> _Control: - return self.term.contr(t0, t1) + def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control: + return self.term.contr(t0, t1, **kwargs) def prod( self, vf: PyTree[ArrayLike], control: _Control diff --git a/docs/api/solvers/abstract_solvers.md b/docs/api/solvers/abstract_solvers.md index 2942c989..1dbbb54b 100644 --- a/docs/api/solvers/abstract_solvers.md +++ b/docs/api/solvers/abstract_solvers.md @@ -84,3 +84,16 @@ In addition [`diffrax.AbstractSolver`][] has several subclasses that you can use ::: diffrax.CalculateJacobian selection: members: false + +--- + +### Abstract Stochastic Runge--Kutta (SRK) solvers + +::: diffrax.AbstractSRK + selection: + members: false + +::: diffrax.StochasticButcherTableau + selection: + members: + - __init__ \ No newline at end of file diff --git a/docs/api/solvers/sde_solvers.md b/docs/api/solvers/sde_solvers.md index 7c823352..0e1bd86f 100644 --- a/docs/api/solvers/sde_solvers.md +++ b/docs/api/solvers/sde_solvers.md @@ -20,7 +20,9 @@ See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochast --- -### Explicit Runge--Kutta (ERK) methods +## Explicit Runge--Kutta (ERK) methods + +These solvers can be used to solve SDEs just as well as they can be used to solve ODEs. ::: diffrax.Euler selection: @@ -44,32 +46,60 @@ See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochast --- -### Reversible methods +## SDE-only solvers -These are reversible in the same way as when applied to ODEs. [See here.](./ode_solvers.md#reversible-methods) +!!! info "Term structure" -::: diffrax.ReversibleHeun + These solvers are SDE-specific. For these, `terms` must specifically be of the form `MultiTerm(ODETerm(...), SomeOtherTerm(...))` (Typically `SomeOTherTerm` will be a `ControlTerm` or `WeaklyDiagonalControlTerm`) representing the drift and diffusion specifically. + + +::: diffrax.EulerHeun selection: members: false ---- +::: diffrax.ItoMilstein + selection: + members: false -### SDE-only solvers +::: diffrax.StratonovichMilstein + selection: + members: false -!!! info "Term structure" +### Stochastic Runge--Kutta (SRK) - These solvers are SDE-specific. For these, `terms` must specifically be of the form `MultiTerm(ODETerm(...), SomeOtherTerm(...))` (Typically `SomeOTherTerm` will be a `ControlTerm` or `WeaklyDiagonalControlTerm`) representing the drift and diffusion specifically. +These are a particularly important class of SDE-only solvers. +::: diffrax.SEA + selection: + members: false -::: diffrax.EulerHeun +::: diffrax.SRA1 selection: members: false -::: diffrax.ItoMilstein +::: diffrax.ShARK selection: members: false -::: diffrax.StratonovichMilstein +::: diffrax.GeneralShARK + selection: + members: false + +::: diffrax.SlowRK + selection: + members: false + +::: diffrax.SPaRK + selection: + members: false + +--- + +### Reversible methods + +These are reversible in the same way as when applied to ODEs. [See here.](./ode_solvers.md#reversible-methods) + +::: diffrax.ReversibleHeun selection: members: false diff --git a/docs/devdocs/srk_example.ipynb b/docs/devdocs/srk_example.ipynb new file mode 100644 index 00000000..2e99e5a1 --- /dev/null +++ b/docs/devdocs/srk_example.ipynb @@ -0,0 +1,613 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bcbdcebc7f4e019f", + "metadata": { + "collapsed": false + }, + "source": [ + "# Stochastic Runge-Kutta (SRK) demonstration\n", + "The `AbstractSRK` class takes a `StochasticButcherTableau` and implements the corresponding SRK method.\n", + "Depending on the tableau, the resulting method can either be used for general SDEs, or just for ones with additive noise.\n", + "The additive-noise-only methods are somewhat faster, but will fail if the noise is not additive.\n", + "Nevertheless, even in the additive noise case, the diffusion vector field can depend on time (just not on the state $y$). Then the SDE has the form:\n", + "$$\n", + "\\mathrm{d}y = f(y, t) \\mathrm{d}t + g(t) \\mathrm{d}W_t.\n", + "$$\n", + "To account for time-dependent noise, the SRK adds a term to the output of each step, which allows it to still maintain its usual strong order of convergence.\n", + "\n", + "The SRK is capable of utilising various types of time Levy area, depending on the tableau provided. It can use:\n", + "- just the Brownian motion $W$, withouth any Levy area\n", + "- $W$ and the space-time Levy area $H$\n", + "- $W$, $H$ and the space-time-time Levy area $K$.\n", + "For more information see the documentation of the `StochasticButcherTableau` class.\n", + "\n", + "First we will demonstrate an additive-noise-only SRK method, the ShARK method, on an SDE with additive, time-dependent noise.\n", + "For more additive-noise SRKs see the langevin.ipynb notebook.\n", + "\n", + "We will compare various additive-noise-only SRK methods as well as some general SRK methods proposed by Foster." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-04-02T15:04:56.415203Z", + "start_time": "2024-04-02T15:04:56.369949Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: JAX_PLATFORM_NAME=cuda\n" + ] + } + ], + "source": [ + "%env JAX_PLATFORM_NAME=cuda\n", + "\n", + "from test.helpers import (\n", + " get_mlp_sde,\n", + " get_time_sde,\n", + " simple_sde_order,\n", + ")\n", + "\n", + "import diffrax\n", + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import matplotlib.pyplot as plt\n", + "from diffrax import (\n", + " diffeqsolve,\n", + " GeneralShARK,\n", + " ShARK,\n", + " SlowRK,\n", + " SpaceTimeLevyArea,\n", + " SPaRK,\n", + " SRA1,\n", + ")\n", + "from jax import config\n", + "\n", + "\n", + "config.update(\"jax_enable_x64\", True)\n", + "jnp.set_printoptions(precision=4, suppress=True)\n", + "\n", + "\n", + "# Plotting\n", + "def draw_order(results):\n", + " steps, errs, order = results\n", + " plt.plot(steps, errs)\n", + " plt.yscale(\"log\")\n", + " plt.xscale(\"log\")\n", + " pretty_steps = [int(step) for step in steps]\n", + " plt.xticks(ticks=pretty_steps, labels=pretty_steps)\n", + " plt.ylabel(\"RMS error\")\n", + " plt.xlabel(\"average number of steps\")\n", + " plt.show()\n", + " print(f\"Order of convergence: {order:.4f}\")\n", + "\n", + "\n", + "def plot_sol_general(sol):\n", + " plt.plot(sol.ts, sol.ys)\n", + " plt.show()\n", + "\n", + "\n", + "def draw_order_multiple(results_list, names_list, title=None):\n", + " plt.figure(dpi=200)\n", + " if title is not None:\n", + " plt.title(title)\n", + "\n", + " orders = \"Orders of convergence:\\n\"\n", + " for results, name in zip(results_list, names_list):\n", + " steps, errs, order = results\n", + " plt.plot(steps, errs, label=name)\n", + " orders += f\"{name}: {order:.4f}\\n\"\n", + " plt.yscale(\"log\")\n", + " plt.xscale(\"log\")\n", + " # pretty_hs = [\"{0:0.4f}\".format(h) for h in _hs]\n", + " # plt.xticks(ticks=1 / _hs, labels=pretty_hs)\n", + " plt.ylabel(\"RMS error\")\n", + " plt.xlabel(\"average number of steps\")\n", + " plt.legend()\n", + "\n", + " # Write the orders in the corner of the plot\n", + " plt.text(\n", + " 0.05,\n", + " 0.05,\n", + " orders,\n", + " transform=plt.gca().transAxes,\n", + " verticalalignment=\"bottom\",\n", + " fontsize=10,\n", + " )\n", + " plt.show()\n", + "\n", + "\n", + "dtype = jnp.float64\n", + "key = jr.PRNGKey(2)\n", + "sde_key = jr.PRNGKey(11)\n", + "num_samples = 100\n", + "keys = jr.split(jr.PRNGKey(5678), num=num_samples)\n", + "\n", + "t0, t1 = 0.0, 16.0\n", + "t_short = 4.0\n", + "t_long = 32.0\n", + "save_at_solver_steps = diffrax.SaveAt(steps=True)\n", + "\n", + "\n", + "def constant_step_strong_order(keys, sde, solver, levels, bm_tol=None):\n", + " def _step_ts(level):\n", + " return jnp.linspace(sde.t0, sde.t1, 2**level + 1, endpoint=True)\n", + "\n", + " def get_controller(level):\n", + " return None, diffrax.StepTo(ts=_step_ts(level))\n", + "\n", + " _saveat = diffrax.SaveAt(ts=_step_ts(levels[0]))\n", + " if bm_tol is None:\n", + " bm_tol = (sde.t1 - sde.t0) * (2 ** -(levels[1] + 8))\n", + " return simple_sde_order(\n", + " keys, sde, solver, solver, levels, get_controller, _saveat, bm_tol\n", + " )\n", + "\n", + "\n", + "def pid_strong_order(keys, sde, solver, levels, bm_tol=2**-18):\n", + " save_ts_pid = jnp.linspace(sde.t0, sde.t1, 65, endpoint=True)\n", + "\n", + " def get_pid(level):\n", + " return diffrax.PIDController(\n", + " pcoeff=0.1,\n", + " icoeff=0.3,\n", + " rtol=2 ** -(level - 1),\n", + " atol=2 ** -(level + 3),\n", + " step_ts=save_ts_pid,\n", + " dtmin=2**-14,\n", + " )\n", + "\n", + " saveat_pid = diffrax.SaveAt(ts=save_ts_pid)\n", + " return simple_sde_order(\n", + " keys, sde, solver, solver, levels, get_pid, saveat_pid, bm_tol\n", + " )\n", + "\n", + "\n", + "time_sde = get_time_sde(t0, t1, dtype=dtype, noise_dim=7, key=sde_key)\n", + "terms_time_sde = time_sde.get_terms(\n", + " time_sde.get_bm(key, levy_area=SpaceTimeLevyArea, tol=2**-10)\n", + ")\n", + "time_sde_short = get_time_sde(t0, t_short, dtype=dtype, noise_dim=7, key=sde_key)\n", + "\n", + "mlp_sde = get_mlp_sde(t0, t1, dtype=dtype, key=sde_key, noise_dim=3)\n", + "terms_mlp_sde = mlp_sde.get_terms(\n", + " mlp_sde.get_bm(key, levy_area=SpaceTimeLevyArea, tol=2**-10)\n", + ")\n", + "mlp_sde_short = get_mlp_sde(t0, t_short, dtype=dtype, key=sde_key, noise_dim=3)\n", + "\n", + "commutative_sde = get_mlp_sde(t0, t1, dtype=dtype, key=sde_key, noise_dim=1)\n", + "terms_commutative_sde = commutative_sde.get_terms(\n", + " commutative_sde.get_bm(key, levy_area=SpaceTimeLevyArea, tol=2**-10)\n", + ")\n", + "commutative_sde_short = get_mlp_sde(t0, t_short, dtype=dtype, key=sde_key, noise_dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3b0a11dc7bb9f9bc", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-02T15:04:57.426134Z", + "start_time": "2024-04-02T15:04:56.415910Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAGdCAYAAAAvwBgXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAACouElEQVR4nOzdd3hT1RsH8G9Gm+699y5ll7333ihDBfcWVMSJCuLEgVt/7oEKKiBLRZFN2aOUTQd07z3SlSb398fpvclt0kWbJqXv53n6cHdOoTRvznnPeyQcx3EghBBCCDEBqakbQAghhJCuiwIRQgghhJgMBSKEEEIIMRkKRAghhBBiMhSIEEIIIcRkKBAhhBBCiMlQIEIIIYQQk6FAhBBCCCEmIzd1A5qi0WiQlZUFe3t7SCQSUzeHEEIIIS3AcRzKy8vh4+MDqbTpPg+zDkSysrLg7+9v6mYQQggh5Aakp6fDz8+vyWvMOhCxt7cHwL4RBwcHE7eGEEIIIS1RVlYGf39/4X28KWYdiPDDMQ4ODhSIEEIIIZ1MS9IqKFmVEEIIISZDgQghhBBCTIYCEUIIIYSYDAUihBBCCDEZCkQIIYQQYjIUiBBCCCHEZCgQIYQQQojJUCBCCCGEEJOhQIQQQgghJkOBCCGEEEJM5oYDkUOHDmHmzJnw8fGBRCLBtm3bhHMqlQrPP/88evXqBVtbW/j4+OCuu+5CVlZWe7SZEEIIITeJGw5ElEol+vTpg88//1zvXGVlJWJjY7Fy5UrExsZiy5YtiI+Px6xZs9rUWEIIIYTcXCQcx3FtfohEgq1bt2LOnDmNXnPq1CkMGjQIqampCAgIaNFzy8rK4OjoiNLSUlr0jhBCCGknmSVV2HY2E+EedhjXzQNyWftmarTm/bvDVt8tLS2FRCKBk5NTo9fU1NSgpqZG2C8rK+uAlhFCCCFdx28n0/DClgvC/hPjwrB8UqTJ2tMhyarV1dV4/vnncfvttzcZGa1ZswaOjo7Cl7+/f0c0jxBCCOky1v6XINqvUqnRDoMjN8zogYhKpcKCBQvAcRy++OKLJq9dsWIFSktLha/09HRjN48QQgjpMmrrNCioYCMPYyPd8efSEXhpendIJBKTtcmoQzN8EJKamop9+/Y1O06kUCigUCiM2SRCCCGky7peUAEAsLeS4/t7Bpo0AOEZLRDhg5DExETs378frq6uxnopQgghhLRAfE45ACDS094sghCgDYFIRUUFkpKShP3k5GTExcXBxcUF3t7emDdvHmJjY/HXX39BrVYjJycHAODi4gJLS8u2t5wQQgghzTqfUYIdcVlQ1qrx68k0AECkl72JW6V1w4HI6dOnMXbsWGF/+fLlAIC7774bq1evxo4dOwAAffv2Fd23f/9+jBkz5kZflhBCCCEtVFlbh7n/Owq1RpyM2u1mCETGjBnTZJatKTNwCSGEEAIcSijQC0IAIMrbfGpz0VozhBBCyE3qUlapsD0kxAUSCTC7rw/6BzqbsFViHVbQjBBCCCEd5+i1Any6j+Vyvj67B+4cGoSqWjWsLWUmbpkY9YgQQgghNxllTR3u+OaEsN/D1xEAzC4IAahHhBBCCLmp/HMhG1fqp+nyorzMJyekIQpECCGEkJvEmdQiPLo+Vu+4OfaE8GhohhBCCLlJnEsv1Tv21IQIE7Sk5ahHhBBCCLlJlFWrRPvblgxHHz9HE7WmZahHhBBCCLlJpBZWiva7eZlPKffGUCBCCCGE3ARUag3OpBYL+2/O7QkrC/PNDeHR0AwhhBByE/j7fDbSiirhamuJfU+PgaONhamb1CLUI0IIIYTcBDafyQAA3Dk0sNMEIQD1iBBCCCGdWnZpFdYfT8PhpAIAwIze3iZuUetQIEIIIYR0YvO+OIbMkioAgFQCBLramrhFrUNDM4QQQkgnVVatEoIQAPBxsoaFrHO9tXeu1hJCCCEEBRU1AICUAqXoeKCrjSma0yYUiBBCCCGdyK8n0zDgjT349WQaipS1onM9fcy7eJkhFIgQQgghZoTjOJxNK0Ztncbg+RVbLgh/Hk5kCap+ztb4+s7+eHJCeIe1s71QIEIIIYSYidIqFRZ+dRxz/3cUb/x9udnrvz2cDACIDnDGpB5esLHsfHNQKBAhhBBCzMCV7DL0efU/nEwpAgD8dCwVVbVqvetsDKyk69KJ6oY0RIEIIYQQYgae2XRO79gPR5NF+0eSClBpIDhxtrU0WruMjQIRQgghxMTq1BpcyirTO/7XuWwAgLKmDgDw3q544Vyou7ZeiAsFIoQQQgi5UQm5FcL2zidGYv8zYwAAl7PLsOFEGnqu3oWfj6fiUlYpAGByD0/09XcW7vFysOrQ9rYnCkQIIYQQE0uurwfSP9AZ3X0c4O9sLZx7cesFcBywcttFqNQcpBLg49ui4emgEK4ZFubW4W1uLxSIEEIIISaWW1YNQNuzIW+iOmqQmy2sLGS4bWAA7BRyLB7C/uysOm/LCSGEkJsEH4h46gyxvD67B1Zuv6R3rb8zq54a4GqDC6sndUwDjYh6RAghhBATE3pEHLXDLXcODUIffye9a32ctMM2EokEEonE6O0zJgpECCGEEBPLMdAjAgARHnZ619oaqCPSmVEgQgghhJhYXhlbxK5hIHLn0EAAgIVM2+uh4TquXR2BckQIIYQQE+I4rtEekd5+TvjtoSGwkEnx07EU7LyQjbvqg5ObBQUihBBCiAmV19QJ1VJ1p+TyhoS4AgCi/Z3w2qyecOzE5dwNoaEZQgghxITy6ntD7K3kTS5aJ5VKbrogBKBAhBBCCDGpnFKWH9KZq6O2xQ0HIocOHcLMmTPh4+MDiUSCbdu2ic5zHIdVq1bB29sb1tbWmDBhAhITE9vaXkIIIeSmYqiGSFdyw4GIUqlEnz598Pnnnxs8/+677+KTTz7Bl19+iRMnTsDW1haTJ09GdXX1DTeWEEIIudk0lqjaVdxwsurUqVMxdepUg+c4jsNHH32El19+GbNnzwYA/PTTT/D09MS2bdtw22233ejLEkIIITeVPAPFzNrDlsQtyFZm47E+j5l10TOj5IgkJycjJycHEyZMEI45Ojpi8ODBOHbsWKP31dTUoKysTPRFCCGEdDYZxZX4NuY6aus0zV5rjB4RjuPwytFX8OW5L3GpUL9MvDkxSiCSk5MDAPD09BQd9/T0FM4ZsmbNGjg6Ogpf/v7+xmgeIYQQYjQcx2HEO/vxxt9X8Pvp9Gavz2mkmFlbVNVVCdu5ytx2e64xmNWsmRUrVqC0tFT4Sk9v/h+QEEIIMSfnM0qF7djU4mavzzNCj0hpjbYNJTUl7fZcYzBKIOLl5QUAyM0VR2G5ubnCOUMUCgUcHBxEX4QQQkhncux6obBdWqVq8tqaOjXyylmPiI9jOwYitdpAJKdSPBJRWlOKSZsnYfmB5e32em1hlEAkODgYXl5e2Lt3r3CsrKwMJ06cwNChQ43xkoQQQohZOJVcJGynFCibvDa5QAm1hoO9lRzu9u2XrKrbC3KpQJwjsjt1N7KV2diduhsl1SUwtRsORCoqKhAXF4e4uDgALEE1Li4OaWlpkEgkWLZsGd544w3s2LEDFy5cwF133QUfHx/MmTOnnZpOCCGEmJeqWjWOXtP2iKQWVaJapW70+sTcCgBAuIddu85s0R2aicmMQXxRvLB/pfCKsH0m7ww4zrSr6N1wIHL69GlER0cjOjoaALB8+XJER0dj1apVAIDnnnsOjz/+OB566CEMHDgQFRUV+Pfff2Fl1TXnSRNCCLn5HU4qQJVKDT9na7jZWUKt4XApq/EZoIl5fCBi367t0A1EAOBc/jlhO75YG5Qs278MWxK3tOtrt9YNByJjxowBx3F6Xz/++CMAQCKR4LXXXkNOTg6qq6uxZ88eREREtFe7CSGEELOTVB9YDAxyQW8/JwDA+YySJq4vBwCEe9q1azsaBiLXSq4J24VVhaJz+9L3tetrtxatvksIIYS0k5xSNm3W29EKHg4K7Luah7SiykavT+CHZjzbt0ekqJrlqbhYuaCouggbrm7AgsgFCHUKRWE1C0ReHPwiFDIFRviOaNfXbi2zmr5LCCGEGFtFTR3+OJPR7IyWltLNsbiaw3o4vJ2s4WJjCQAoqTT8Onll1UIPSrhH+/aIFFQVAACmBU8Tjr13+j1cK7km1BiZFToLt4TfAg8bj3Z97daiHhFCCCFdyqrtF7ElNhNTenjhyzv7t+lZf53Pwqrtl3DX0EB4OVjhRP2MGR9HKxRW1AIAiitrDd778V62EKy/izW822nqbq4yFyuPrMSxbFbFvI97Hwz2HozH9z2OI5lHcCTzCADASmYFG7lNu7xmW1EgQgghpEvZEpsJAPj3UuOVvluiTq3B0g1nAQAf7RGvLu/laAW1hvWUFDfSIxJf33vy2Jiwdpkx8+nZT/H1+a9Fx9ys3dDPs5/eta7Wrmaz/gwNzRBCCCE34HK24dkwE6I80c3LAc62/NCM4R6RrBI2RBLp1fb8kP1p+/WCEABwt3GHVCLFxMCJouOu1q5tfs32QoEIIYSQm1pygRLX8iva9ZnrjqbgrZ1X9I672lri27sHQCaVwNnGAgBQrNQPROrUGuTWV1T1dbJuU1uKqovw7KFnDZ5zt3YHAKwethpSifYt38fWp02v2Z4oECGEEGIyhxML8Pn+JKMV1TqXXoKxaw9g7udHUK1So6ZOXFxMpW5+ddyGTqcU4ZUdl3D8OssHkTYywuFUn6xaVl2H+V8exR9nMgCwwCQ2rQRqDQcLmQTudm2rqPpP8j+oUdfAz85PdHxq0FTYWLA8EAdLB7w0+CXhnK+db5tesz1RjgghhBCTWfzdCQBAqLsdpvRsfC0yQ0qrVDh2rQCjIzxgbSkzeM2GE2kAWDCQVVIFC5n483deeU2reiSS8iow78tjomP/W9QfZ9OL8dXB63htdk/huJO1BeRSCeo0HE6lFONUSjE2n8kQrUXj5WgFaWORjAEaToOsiix42HjAUsYCndjcWADAvIh5UKqU+CPxD2yYvkEv2Ojh1kPY9rEznx4RCkQIIYSYhG7p8zOpRVDW1GHHuSw8NTEC4R52WHcsBbaWctw9LMjg/R/tScAPR1IQ5mGH3U+NMph8mVVf1wMAskuroZCLA5HCitYFIv/bn6R3rJuXPcZHeWDhAH+EuGun4cplUjwwMgRfHtQWE9MNQgAgyqvli7uW15bj/l3340oRGxJaHLUYzw18DunlbKX6EMcQjA0Yi8ejHzf4dxHhpC0qamth2+LXNTYKRAghhJhEZok2SPgmJlnYPpiQL7qum5c9BofoJ1fy67Qk5VWgUFkLNwNDHDml1cJ2VkkVbBXit71CZS0qauqw53IuJnb31Duvi+M4UdusLKQYGe6OABcbSKUSURDCe35KJI5fL0RceonBZ3bzblkgcizrGDYlbBKCEAD45covWBS1SAhEAhwCAKDR2TAWMgssilqEkzknMcZ/TItetyNQIEIIIcQkmqo4quuv89kGAxFlbZ2wnVlc1YJApBqO1uK3vaKKWkz56BAyiquwYmo3PDw6VO8ZHMfhnX/jUVatQqGyFnKpBPufGQN/l+brcEgkEvx0/yAs/z0Oe67kAQDuHxGMQwn5SMyrwOgI92afwXEcHtr9kMFzV4uuokLFArKW5H28MOiFZq/paJSsSgghpM2KlLV44Y/zOJtW3OJ70lsYiKQVVeLD3Qk4eq1AOKZSa5BWqL0/S6d3haesqUN5jTZYWX8iFWlF4ute//syMorZsdOphtsen1uOLw9eE/JN+vo7tSgI4TlYWeCOwQHC/q39/LD50WHYvmQ4+gc6N3t/Wa14mvDiqMXCdkxmDAAWhFjJO+eishSIEEIIabNV2y/it1PpmPu/oy2+J7WwZYHIwYR8fLw3EXd8cwL7r+bhUlYpXvvzMgp1psVmGghEdIMTKwsp8sprsOl0uuga3fLrskaGNFIKxO1cPCSwRe3WNSrcHbcPCsDL06PQ3ccBjtYW6OPvJLpmX9o+LNm7BNkV2cIxpUqJ1LJU0XX+9v6YGjwVALArZRcAoLd771a3yVzQ0AwhhJA2O1lf2rw1Wjo0o+veH09BKgE0DWb78r0aui5kshVoBwQ6I8DFBlvOZgo9JGMi3XEgXpyLUmSg3gcAvRokc6JbP/VVLpNizS29mrzmqQNPQcNp8GTlk9g4cyMKqgowa9sslNeWC9coZAqM8huFbCULVpQqJQBWyr2zoh4RQgghbVZW3foF5Pihmfn9/TA6wh39A50xV+dN3sPecH2NhkEIACTmlesdO1efINrbzwnRAU7C8VB3W0zVmSp8/4hgAEChssbg6/EL0wHAE+PCDH8zbVRQVQANx2qaXCm6goKqAhxIPyAKQsb4j8HR24/Cz94Pnjaeovsb1hDpTKhHhBBCSJtVq7SFwf67lINJPRqvCcJxHD7akyisVPvomFDRjJOtZ9laMLcN9Mcn+/SnyxoSn6NfOZVfgK5vgBN8dBaVYzVLvJGYW4HBIa4IcLHBd4eTRUM9uvgekUdGh2LZhAiD17TV1aKrov1Xj72Knq49RccinCOE2iENV8x1VDgapV0dgXpECCGEtEnDqqgP/Xym0Wu3ns1A8IqdwsqzYR52CHQV17R4Y05PDAh0xn0jgvH6HPGbsW2DwmXTe3sDAAoqalBYoe3R+DbmOq7mlEMqAUaEuSHMQxvouNkr4GhtgZdndMfE7p5wEdaEUWHCBweFxej47+1afY/IvP6+LSo+xvdstIZuzwcAHEg/gO3XtouO3dvjXmGbAhFCCCGkXn65/pBGZf3U2q1nM3D39yeFhd8+3C1epfbbu9i6LLoWDwnE5keHwcnGEosHB2DzI0OFc7pByxPjw/H5Hf3g5cB6O747nIxr+RWoqVPj4/rVcKMDnOFiaymUWwcAVZ04UODXhAHYMMy0T2KE/Zyyaihr1ZBLJXoBkyHxRfEY8dsI/HDxh2av1cVPwR3jPwaTgyYDgFAfxEZugz/n/Ak7S20w5WUr7nFyUji16vXMCQUihBBC2sRQ0mlCfbGxp34/h4MJ+fhoTyLSiyqFa9fc0gvX35qGILem39wlEgkGBLkI+6Eedrilny98naxx91A2e8XDgeWS/O/ANYx//yCOXisUklK/uWuAcO/kHiyvomGlVrlMipemRQn76voklGJlLXZeyAEABLnZ6pWHN+TTs5+ivLYcH5z5oNlrdVXUsr8vewt7jPIbJTq3aeYmBDmK29xw9VwHy5ZXaDU3lCNCCCGkTd7dFa937HxGCXr4aN8cr+VXYOS7+4X92wcF6N3TlG/uGoCfj6di5fQoeDiI62V42FsBKBX27/3hFABg4QB/YdgFAD69vR8KKmrgY6Ck+93DgvCmzmq6ZdUqzP/qmJCoOjjYRe8eQ3RXuG0NfmjGztIOwQ7BwnFLqaXBQmUWUgvRvkxqeK2dzoB6RAghhLRKWmElnt54DhczS5FdWmVw6u6q7Zcw67Mjwn5MorYYmaEKqM2Z2N0TP903SC8IAbQ9Ig311ZkpAwCWcqnBIIQ/d4vOjJ3tZzNFs2WGh7m1qJ32lvbC9vHs4y26B9BOw7WzsBP1frjbuDcaZLhYtSw4MncUiBBCCGmxi5mluOfHk/gjNgP3/HAS59K1PRE/3TdIlM9xJbvM0CPw0cK+7domT3vDFUV7+bYugfODhX3R24/ds3L7JQCAj6MVnpkUgclNzALSpea0C/k9+N+D+Df53xbdx+eI2Fnawd7SXpiee2v4rY3e88yAZwAAA70Gtug1zBUNzRBCCGmRrJIqzP78iJBDUVBRi0d+YTNkRkW4Y1SEOziOw9MTI7Dnap5Qx0PXyhndMSK8Zb0LLeXtpA1EAl1t0MvXEZZyKbq3cEE5Xf0CnHE+QxtcvTG3J8Z182ziDrHSmlLR/t/Jf2NK8JRm7xOGZixYQura0WtxvfQ65oTNafSemaEz4WXrhSCHoBa3zxxRIEIIIaRFrmSXCUHIk+PDhSm4AIQ6HRKJBI+PD8fScWEY8MYeFCpr4WhtgdIqVvCsr3/7TzOd1ssbZ9OK4WqrwJMTwluUVNqYST088ePRFACAjaUMI8KaX5ROFx+IDPUeimPZx3Aq51SL7hN6ROoDkb4efdHXo2+z93X23hCAAhFCCCEtxM94mdrTC5N7eIkCES9H8fCIRCLBlseGIb+8BmlFlVi+8RysLWTo7t3+gYidQo41t7TPWitDQ1zx2uwe8LC3wphId1jKWx7UFFYV4kLBBQDA7d1ux7HsY1CqlFBpVHrJpQ3xs2Z0p+h2FRSIEEIIaRF+kboAFxv4OouTPhcN1l8ILtDVFoGuthgQ5AI3OwVsLGWwtjTv2R0SiQR3DQ26oXvXX1kvbPvaaxNfK2or4GzV+Cq7HMchtzIXwM2TgNoaFIgQQghpEb5HJMDVBo7W2k/4t/bzg3sj68LwRkW0boijs/n7+t/YmLBR2A9zCoO13BpVdVXNBiIFVQUoqi6CVCJFqFNoRzTXrFAgQgghpEX46azBDYqQRXnbG7q8y4jNjcULMS8I+7/N+A1SiRT2FvaoqqtCuUp/QT5d/DozQQ5BsJYbnl58M6NAhBBCSLMqa+uEHpFITxZ4/HjvQOy7moc7h+oPy3Qllwsvi/bDncIBsHyPvKo8If+jMXwgEukSaZwGmjkKRAghhDQrsb5ku5udAq71BcnGRHpgTKRHU7d1CSdzTor2+RVy+cTT5npErhSxiq7dXboboXXmjwIRQgghzYrPZW+mkV5db1ZHU2JzY7E/XVu6Xndoxd6C9RwZ6hHhOA67UnbhatFV7E7dDQDo5trNyK01T0arrKpWq7Fy5UoEBwfD2toaoaGheP311/WWiyaEEGL+EnJYIBLh2bXzQRqKy48T7S/rt0zY5ntE+Bohuv66/heePfQsvrv4nXAsyiVK77quwGg9Iu+88w6++OILrFu3Dj169MDp06dx7733wtHREU888YSxXpYQQogRCD0iFIgIquqq8OGZDwGwIGLF4BXo695XOM8XJ+OrpvLWnlqLdZfX6T3PUdH+NVY6A6MFIkePHsXs2bMxffp0AEBQUBB+/fVXnDx5spk7CSGEmJuE+kAkwosCEd7Pl38Wtu/ucTeiPaJF5x0sWYn5slrtmjulNaX4+Qq7L9gxGMmlyR3QUvNmtKGZYcOGYe/evUhISAAAnDt3DocPH8bUqVMbvaempgZlZWWiL0IIIaZVUlmL3LIaAEC4B+WI8LYlbRO2I5wj9M47WTkB0JZ9r1RV4oMzH0DDaQAAP0z+AR+OYT0q7416z7iNNWNG6xF54YUXUFZWhm7dukEmk0GtVuPNN9/EokWLGr1nzZo1ePXVV43VJEIIITcgoX7GjK+TNeytmi5V3lXUqGuQVZEFgOWFhDuH613jpHACAJTUlAAAvjz3JbYkbgEA3N/zfrhau2JC4AScv+s8JBJJh7TbHBmtR2Tjxo1Yv349NmzYgNjYWKxbtw5r167FunX642K8FStWoLS0VPhKT083VvMIIYS0kHbGDA3LAICG0+Cp/U9BzalhI7fBfT3vM3gdn/NRUl0CADiefVw4t7j7YmG7KwchgBF7RJ599lm88MILuO222wAAvXr1QmpqKtasWYO7777b4D0KhQIKRdNlggkhhHQsfsZMuCcNywDA8azjiMmMAQC427g3Gkg07BEprikGALw76l24WbsZvZ2dhdF6RCorKyGVih8vk8mg0WiM9ZKEEEKMgGbMiCUUJwjbSpWy0eucFWx9mbTyNKSXpSNHmQMAGOw92LgN7GSM1iMyc+ZMvPnmmwgICECPHj1w9uxZfPDBB7jvPsNdWIQQQswPx3HaGTMUiKCgqgBfnv9S2G+qNpbudNyl+5YCYEXO+ACFMEYLRD799FOsXLkSjz32GPLy8uDj44OHH34Yq1atMtZLEkIIaUf/XsyBtaUMJZUqWMqkCKMZM/jq3FeiXpBXhr7S6LW6gcj10usAAH8H/y6fE9KQ0QIRe3t7fPTRR/joo4+M9RKEEEKM5EB8Hh755Yyw38vPEVYWMhO2yDzsS9sHAPhiwhcY5DVIWFfGELlUjmcGPIO1p9cKxwLtu/YCgYYYLUeEEEJI57Xvap5of2CQi4laYjpZFVk4nHkYtepaAKwOSF4V+3vp5darySCEd1f3u0RDMd523sZpbCdGgQghhBA9FTV1ov3FQwJM1BLT4DgO9++6H4/ueRSvHXsNAJBezkpKOFg6tLgcu0QiwbL+y4T9MKewdm9rZ0er7xJCCNFzLV+bB2FjKYOfs40JW9PxUspSkFGRAQA4lXMKgDYQCXRo3fDKLeG3IMI5AvvT92Ny0OT2behNgAIRQgghIsXKWlzIKBH2h4W6mq4xJnI697SwnVuZC5VahdSyVACAv71/q5/X060nerr1bLf23UxoaIYQQojIocR8aOpnpU7q7om3b+1t2gaZAJ+UCgBqTo3UstQb7hEhTaMeEUIIISInkosAAA+MCMbLM7qbuDUdr6S6BMezWDl2ewt7lKvKcb30ept6REjjqEeEEEKIyKn6QGRgcNebKQMAe9L2oI6rQzeXbhgbMBYA8OyhZ4XhGuoRaV/UI0IIIUSg1nC4XsASVXv7tWxmyM3mYMZBAMCkwElC8TENx5YncVY4I9Qp1GRtuxlRIEIIIQQcx6GmToOyahXUGg4SCeBu1/UWIdVwGsTmxgIAhngPQV6ltp7KOP9xWBq9FLYWtqZq3k2JAhFCCOmi/r2Yg3f+vYoPFvTB9rgs/HYqDWvn9wEAuNkpIJd1vdH7rYlbUVZbBhu5DaJcoxDqFIqBXgMR7RGNx6MfN3XzbkoUiBBCSBdUU6cWSrh/sDsBMYkFAIDVOy4BADzsb+7ekPTydFSqKhHpEik6vilhEwBgUdQiyKVyyKVyfD/5e1M0scvoeuEuIYQQJOVVCNvnM0qF7YIKVs7c08Gqw9vUUVRqFWZtnYV5f87D9qTt4DgOpTWl4DhOWJxuRugME7ey66AeEUII6YIK6wMOACitUumdv5l7RE7lnEIdx0rYv3zkZaw+uhp1XB2W91+OqroqyCVymqLbgahHhBBCuqDiytomz8/s49NBLel4O67vEO3zQckHZz4AAPjZ+8FCatHh7eqqKBAhhJAuiO8R8Xex1js3sbsnhoe5dXSTOkRFbQV2pexq8ppgx+AOag0BKBAhhJAuqUjJApFBQfrryIy4SYMQAEgqSUKdpg4eNh7YPHMzVg9drXcNBSIdi3JECCGkC8kvr8HqPy/h7/PZAFiPSHdvB1zJKcO0nt4YHemO2X1v3mGZayXXAABhTmGIdImEs5Wz3jXett4d3awujQIRQgjpQjacSBOCEABwtbXE+gcGo7iyFiHudiZsWcdIKkkCAKE6qoeNB3bO3QkruRXGbRoHAAh3DjdZ+7oiCkQIIaQLiU0rFu17O1rD2dYSzraWJmpRx8oozwAABDkECcf8HdgMmV+n/4rUslT09+xviqZ1WRSIEEJIF6HRcIhLLxEd6xvgZJK2mEpuZS4AwMvWS+9cT7ee6OnWs6Ob1OVRsiohhHQRyYVKvZohbl1sPRk+EPG08TRxSwiPekQIIaSLiE1lwzJ9/J3QL8AJYyI9TNyijlWjrkFRdREAwz0ixDQoECGEkC7ibP2wzOBgF7w4Lcq0jTEBfiVdK5kVHCwdTNwawqOhGUII6SLOppUAAPp1sbwQXlpZGgDA284bEonExK0hPApECCGkC1DW1CE+pwwAEB2gXzujK7hceBkA0M2lm4lbQnTR0AwhhNzkTiYX4c2/L0PDAc42Fjf1yrpN4QORHq49TNwSoosCEUIIuQnVqTWY9+Ux1NZpcCWnDBzHjvs6668t0xVwHIdz+ecAAN1du5u4NUQXBSKEEHKTUak1uJBZqlczBAAcrbvmqrLp5enIr8qHhdQCvdx6mbo5RAcFIoQQcpN59c9L+OV4msFzVbXqDm6NeTiTewYA0MutF6zkXXNoylxRsiohhHQyHMfh25jr2Hg63eD5hkHInUMChW1vx645NHM69zQAUPl2M0Q9IoQQ0smcyyjFG39fAQCMiXCHh4MV4nPK8fHeBIzrJq4YumxCOJZNiMCYSHd8fyQZK6Z1rRkjp3JOYdWRVcioYGvMDPAcYOIWkYaM2iOSmZmJxYsXw9XVFdbW1ujVqxdOnz5tzJckhJCb3q8ntD0eg97ai7yyany0JwE7L+TgmU3nhHM+jlZYOJAt6DY+yhPrHxgCP2ebDm+vKd236z4hCAmwD8AALwpEzI3RekSKi4sxfPhwjB07Fv/88w/c3d2RmJgIZ+euOX+dEELaQ1x6CTaeEQ/JDHprr951fs7W2LN8NKwsZB3VNLM3L2IeLGVdY5XhzsRogcg777wDf39//PDDD8Kx4OBgY70cIYR0CR/uTgDHAbdE+2JkhBue+v2cwetGhrt1+SCEX1eGF+EcYaKWkKYYbWhmx44dGDBgAObPnw8PDw9ER0fjm2++MdbLEULITWvHuSz0f303Xt52AadS2Jvr3cOCMKevL7p52Ru8p4+fUwe20DxllGeI9kOdQk3UEtIUo/WIXL9+HV988QWWL1+OF198EadOncITTzwBS0tL3H333QbvqampQU1NjbBfVlZmrOYRQkin8dXBayhU1opmwwS52UIikWDdfYOQUVyFfgFO+OlYKl7ZcQkAMDDYxVTNNRv5VfnC9oO9HqQVd82U0QIRjUaDAQMG4K233gIAREdH4+LFi/jyyy8bDUTWrFmDV1991VhNIoSQTqmkUiXad7G1FAqTeTpYCSXbFw70h1QCRHjaI9TdrsPbaW4KKgsAAGP9x+KJfk+YuDWkMUYbmvH29kb37uIyulFRUUhLM1xkBwBWrFiB0tJS4Ss93fAceUII6Sr2XslFZkmV6Ji/i+GZL1YWMtw5NAiDQ1w7omlmr6CaBSLu1u4mbglpitF6RIYPH474+HjRsYSEBAQGBjZyB6BQKKBQKIzVJEII6VTi0kvw4E+s5IGHvQJ55WzoengoBRotkV/JhmbcbNxM3BLSFKMFIk899RSGDRuGt956CwsWLMDJkyfx9ddf4+uvvzbWSxJCOrGskiocv16I2X19IZNKTN0cs/DCH+eh4YBgN1t8ens0YhILcCmrFI+PCzd10zqFvMo8AICbNQUi5sxogcjAgQOxdetWrFixAq+99hqCg4Px0UcfYdGiRcZ6SUJIJ/bc5vM4nFSA5RvP4fM7+mF6b29TN8mkCitqcDWnHADwx6PD4GJriZ6+jiZuVedRp6lDXH4cAJq2a+6MWuJ9xowZmDFjhjFfghByE0gpUOJwUoGwv/TXWER5j0ZIF024rKytQ/839gAAQtxt4WJLRbha62TOSZTXlsPB0gE9XXuaujmkCbToHSHEqDQaDmXVqkbP/3AkGWPWHhAd4zhg4+kMwzeYkYTccmQ1SCRtq7JqFbqv2iXsT+zu2cTVpDHfXvgWADA1eCpk0q5d2M3cUSBCCDGa+JxyDHt7H3qv/g9v/HVZ77xGw+HVP7XHHx8Xhnfn9QYAXMgsAcdxoutr6zRY/O0JPPHrWb1zHS2vrBqTPjyEYW/va9fnHr9WKNp/ZlJkuz6/q0gpTQEAzA2fa9qGkGZRIEIIMZr/LuUgp6yabV/O1TtfqKwVti3lUszr7yfUvziSVIg+r/6H9KJK4ZqLWaU4nFSAHeeysOdKnpFb37QLmaXCdkllbRNXtk5ceomw7WAlh4WMfk1X1Fbgvl334efLP7f4nvJall/jpHAyUqtIe6GfcEKI0WSVVgvbGcWVqKlTi87n1gcpzjYWOLtyIgJdbRHqbiucL6uuw84L2cL+pSxtteUjOjklpsAHWACQWljZxJWtoxvgfHx7dLs9tzPbmrQVp3JO4d1T77aoJ6xWXYtqNfv3sbc0XAKfmA+jJqsSQrq2nFJt/oSGA9IKKxHuyd4YipS1+PFoCgDA19katgr268jJxhJeDlbCG32dhr3xcByHf3SCkoIK7XIQppBSoBS204oq0cffqdXPqKlT4/nN51GlUiM+pxzzB/gjKa8CAJsp0z+QVisHgBq19t86sSSxyVkwZbVlQn6IBBLYWXTNhOfOhAIRQojRZOv0iADA/etOo6+/E+b288XO89nYfIYlpLrbiQsZfnxbX9z2zXFwHJBfX8TrbHoJjurkT/DHTSWxPmAAgNRCZRNXNu6/S7nYFpcl7L+3S1sEMsyja7+BppWlwcXKBXaWdkJhMgB45+Q7+G7yd43e9+juR3G+4DwAwM7CDlIJdfybO/oXIoQYRUZxpVAH49Z+fgBYz8GOc1m494dT2HRGOyvG2UY8PXVwiCtemcGWiMgrZ8FMQv2zeKYMRDQaDmfTSoR93byO1ohJzDd43N1eIawl0xUlFSdh+tbpeHTPowCArAptsHYy5yRKa0obu1UIQgAaluksKBAhhLS7qlo15n95DABgp5Bj5YyoRq/1sFfgjsEB+sfrF3LLK6sBx3HYeJqtPTUijFXJNGUgcr2gAqVV2inJp1OLodG0bhYPx3E4mGA4EAnrovVTeJsTNwMA4vLjoNKocLX4quj8hYILBu8rqBLnDVEg0jlQIEIIaXdXcsqQXVoNmVSCbUuGw8nGcEEue4UcJ14cjwFB+kvWe9iz4Zrc8mrsj89DbH0PBJ83UV5Th6patd59HSG5gCWnRnk7wEImQUmlClmlrasnciW7HLllhoOprj4sk1mRKWxvS9qGHGUOXKxcMM5/HADgYsFFg/clFCWI9ikQ6RwoECGEtLvkfJYzMSjIRXhT3fjwUAwNccUjo0PBLyUzpacXJBLD68p4ObIekZzSahxK0H7SjfJ2gI0lK1DV2jf/9sInyvo4WsGzvuemYT5Mcw4ksOnH47t54I05PdEvwEk419UDEb4GCAC8duw1AKwwWS/3XgBY/oghiSWJon0KRDoHSlYlhLS7lPrkzSA37VTcQcEu+PWhIQCAyT08cSa1GIuHNL4at5eDFWRSCVRqDleytdN2x0d5IMDFBldzypFWWCnUHelIBfXDQq52lvCpsUZGcVWrK6weiGfDMmMi3bF4SCCm9PTCkLf2ok7Dobdf111T5uXDLyOlLEXveKRzJGwsbAAA6eXpBu9NKBb3iHRz6dbu7SPtjwIRQki7S66f2hrsZmPwfHSAM6IDmp6aKpdJ4eVghcySKpxILgIAfLm4HyxkUgS6skCEn62i0XCQduCKvXwhNjc7BWrrNABa1iOi1nB4vb7C7OkU9j0NDnEVnnXoubFIyC1v9u/mZlWnqcP2a9sNngt3DodMwnrC0soN94g0zB25JfyW9m0gMQoKRAgh7U7oEXG1bebKpvk6WyNTp6eB72Hhn5tSWImfj6Vg9Z+XsXCgP96a26tNr9cS1Sq1UP/EzU4BPkU1u6QKeWXVUFjI9Ga8/HQsBXllNRgS4ircy/NxshZt6+7f7DiOQ15lHjxsPCCRSJCjzBGd97b1hquVK9xt3NHNpRuq61iwV1RdBKVKCVsL7c9XaU0pkkuTAQBbZm2BjYUNvGy9Ou6bITeMAhFCSLviOA4p9cmcIe5tC0SCXG1wsr43JNLTHhEebMyfD0guZpZiS2wG1BoOG06k4bnJkY0mxraXLbHaREpnWwtY1+ernEopxqj39sPTwQoHnhkDiUSCPZdz8cm+RJzPYNNN/7mYLXqWRMJmFXVVx7KO4eE9D2NBxAKsHLpSNOTyzIBncHePu0XX21nawd7CHuWqcuRV5iHYMVg4F1/EarD42fkh3Dm8Y74B0i4oWZUQ0m4uZZVie1wWKmrqIJUA/i6Gh2Zaam60n7D94vQoYfgl0osFJKdTi1FWXSdc09gslPZ0JrVY2A73sEeUtwMA4HJ2GapVGqQWViKlsBIqtQav/nVJCEIA4Fq+uPCZidftM7k9aXsAABsTNiJHmSMEIqP8RukFITw3m/rp25Xiqc85law3xc/eT+8eYt66bihOCGlXHMfh7u9PoqCC5U/4OFlDIW/b8utDQlzwxPhwKORSjI5wF45HehqeDZFXXi0EKcaSWcJ6e2b09kZPX0dUq9SQSlgJe97YtQfgYmuJIqX+Yngjw90Qk2jadXLMhUKmraibUJwgDM342Po0eo+HtQeSS5ORVyVe9DBXyRZV9LTxNEJLiTFRjwghpF3kltUIQQiAdpn5IZFIsHxiBJaMDRMdt1XI4eeszaXwrc+ryDNyj4hawyG+vsLrw6NCAQBWFjKhV0SXoSAEAIaHueGxMeze6b28jdTSzkG3Qup/Kf/hXP45AIC7jXtjtwjnGvaI5FbWByK2FIh0NtQjQghpF9fzK0T7UV76b87tKcDFBhnFLJF1YJAzMuOqkGekaquHEvLx7q6ryCyuQnGlCk42Fojw0k4bntPXV7QysK5J3T3x8W3ReP+/eBxKzMf8/n6wt7JAbz8nDA11NUp7O4uSmhJhW3e2jJu1W6P3uFuzQCSvknpEbhbUI0IIaRfXdFajtVPIMSfa16iv9+K0KFjKpJgb7SsUFePXpWlvd31/Ehczy1Bcycq6z+vnJxp2mh3d+FBCoKsNrC1leHlGd/z31Gi42ilgKZdiSk+vLr2eDIBG14xpKhDxsPEAoO0BAdi0X74MPM2U6XwoECGEtMnvp9Kw8KtjOFNfF+P+EcE498qkNieqNqenryOOvDAOb9/aC771wzTpRcavtPrUhAg8MzlSdMzD3goz+/hAIgFsLbUBio2lrMmibV2dbo+IrqYCEV87FuDuTt2NjfEbAQBHMo8gR5kDZ4UzBnoNbPd2EuOioRlCiEEcx0Gt4SCXNf155fk/xEWkwjzsIOug4mLu9evR8NVVGw4PGcJxHDafyUCgqy0GBeuvcdPw2pr6gmW8J8aHGSxL//78Pnh5ehRsFXL8eiINM/v4wMNe0aGF1jobPhC5JfwWhDuF451T7wAAXK0aH7LSnRXz+vHX8UfiH7hcyIrEjfAdAWt516nDcrOgHhFCzAjHcfj7fDayTbSGiq77fjyFUe/uh7KmrtFraur0F50zRcl1/jVTiyqFSqcAUKysxezPj+Dz/UnCscNJBXh283ks+OoYyqpVes/ivf3PVfR7fTd+OZ4qHNu+ZHija+NYyqXwdLCCnUKOB0eFwMvRioKQJqjUKpTXssTfJ6KfwILIBQh0CESwY3CLekR4fBACAL72xh0OJMZBgQghZuS7w8lYsiEWT/4WZ9J2lFWrsD8+H1ml1TiXUdLodXzhMl1tLWJ2IzwdFLC1lEGt4ZBWpM1V+TrmOs6ll+C9XazY1YWMUjyw7rRw/s9zWXrPSilQIvTFnfjy4DUUV6rwxt9XAADBbrbo4+9k3G+kC8lWZoMDByuZFVysXGAps8SmmZvwx8w/IJM2Pu2bX2/GkIZBCukcKBAhxIys/Y+9YZ5MLsL5jBLRp/uOkllShRPXi4T9apV+rwcA1NZp8H59e3kTu3vC1da4lU0NkUgkCKnvFUnK0wYiaUXaQCnohb8x87PDoqGWhPqpuLoe+eUM1Br9SmNBrsbNeelqMioyALDgge9lspZbw0LWfALv5+M/N3icApHOiXJECDETGg2HapX2TXLWZ0fQzcse/zw5stHhgPaWWVKFcWsPiN6sCypqsfNCNhysLDAiXNtl/vPxVPx3OReWcikeHBmMef39EezW8b0hvFB3W1zILMX1ApYnwnEczqWXNHlPcqG4R0et4ZCQqx+c9PJ1xOtzerZbW1tLw2nw5P4noeE0+GTsJ032GHQWmRWsVP6NDKeM8huFp/s/jffPvC8cs5BaINQptN3aRzoOBSKEmIlcA1NPr+aUI6u0WijYZWyHEvL1kjNPXC/CH7Hs02vymmmQSCSoU2uE3IkXpnTDfSOC9Z7V0fg8kWv1PSKxacVCnRFdXy7uBwcrC9zx7QkcSsjH5jMZmNefJUCmFCpFFVKXjA3FgEAXDAlxFdaUMYW8yjwcSD8AADhfcB7RHtEma0t7yapgw2I32ouhO0Tz363/oU5TBxerppOPiXmioRlCzITu1NN7hgXB34UFH5cyDddaMIZTKUV6x/7VWaitrIolrn53OBnJBUrIpBJMM5PqoKEe9YFI/cwZQ2XU3ewUGBPpISyaBwBv/s2SHUsqa3Emha0jI5UAD48OwePjwjG2m4dJgxAAosXg7vrnLnx74VuoNYaHzDqLgir278PXBWktb1vtz523nTf8HfzbpV2k41GPCCFmIr0+n2F4mCtWz+qB8uo6pBdl4GJWGSb1MH6RppLKWvx7ka318cS4MMTnlmPXpVwoa7VveAXKGjjaWOD49UIAwMKB/vBytDJ621pC6BHJrwDHccJicy9Pj0JqYSW8HK3wwMhgKOQyeDta4ZZoX2w5m4niShWCXvgbFjIJVGrWHbJ4SCBWTI0y2ffSUEZ5hmj/49iP4WfnhynBU0zUorYrrmZBn7PC+YbuH+E7Ag/2ehBRrubz70RuDPWIEGImskpYj4ifE+ty7h/IfkEfSeqYBdIOJuSjslaNCE87PDUxArf001/FtLB+LZmUQu3Cb+YisD6ZtLy6Dst+jxPyQ/oFOuP1OT2xZGyYUA1VIpHgg4V9MVIn54UPQgBgRFjj00dNQbdHhHc066gJWtJ+hEDE6sYCEYlEgif6PYGJgRPbs1nEBCgQIcRMFNYvkuZmz2adjI5ka2qcTStust5Fe+EXcxsQ5AKJRIIJUZ6iKqEAUFhRA5Vag+T6cu6mTE5tyMpC29btcVkoVNZCLpWgu4EF6XjPT+kGeys5BgW7YP0DgxHiZgt3e4XJ14DRcOI8Hb5H5On+T+OLCV8AAI5lHwPH6c/u6SyKqtkwIOV1EApECDETBRVswTYXW1Yt1NfJGh72Cmg4VtvCmGrrNEJORUR9roVMKsH6B4eI26isxbOb2AqpVhZSeNqbx7AMb8XUbqJ9NzuFKEBpqKevIy6snoyNDw/F8DA3/P3ESBx8dgzsrUy3BoxSpcSMrTPwwH8PCMf4HhF/e3/08+gHC6kFcpQ5SCtPM1Uz26y4pm09IuTmQYEIIWaCXzZetw4Hv16LMddQ0Wg4TP34EC7UJ8WGedgL5/r6O2HH0uGY2J2taHompQg76ouAPTwq1Owqhz48OhQnXxoPOwVLf5vUo3UrsVpbymBjadrUuT2pe5Beno4T2SdQUcsSb9MrWCDiZ+8HGwsb9HHvAwA4nXO60ecYQ426Bvftug9rTqxp03Nq1bVQqlhwTYEI6bBA5O2334ZEIsGyZcs66iUJ6VT4QMRFJxDx4xdzK9avYNpeUgqVuJbP3hRGhrthQJD4jaG3nxMG16/Jsi0uCxoOGBPpjqcmRhitTW3hYW+FfU+PxlMTIrBkbJipm9NqsXmxwnZ6eTrKasuEVWr5dVaCHdl06ZzKnA5t24X8CziVcwobrm7AhE0T9IaQWiqvMg8AIJfKYW9h38zV5GbXIaH/qVOn8NVXX6F3794d8XKEdBrl1So8+VscpvfyFnJEdAMRf2fWI5JhxEDkUlYZAKCPvxN+vn+wwWsGB2tzJhRyKVbO6G609rQHDwcrPDkh3NTNaLX4onhsSdwi7KeVp+Fs3lkAbJqrrQXLyXG1Zv8eRVX6062NiZ9yCwC5lbnIqsgSLULXEhW1FZi6ZSoAINI5ssOK9RHzZfQekYqKCixatAjffPMNnJ2pC44QXetPpGHf1Tw8vemcdmjGTndoxvjL22+pL1bWw6fxpM7uPg7wcbSCpUyKb+8eYJKF7W52Ococ3Pb3baJj6eXp2JWyCwCwKGqRcJxP8OQTPjtKbmWuaD+pJKmRKxun2+MzwHNAm9tEOj+jByJLlizB9OnTMWHChGavrampQVlZmeiLEHNQUlmLO787gT/OZDR/cStU6dToUGs4SCUNh2bqc0TaqUekoqYOR68VCGupJBcosT8+HwBwa7/GK1zKpBJsXzoC+54ZjZHh7u3SFiK2K2UX6jTilY4TixOFXgg+LwQAXK1Yj0hhdaHBZ53IPiHqvWgv/JAK73z+eZRUlyBXmYtPYj/Bb1d/g0rT9Awv3eDptm63NXEl6SqMOjTz22+/ITY2FqdOnWrR9WvWrMGrr75qzCYRckO+ibmOmMQCxCQW4Nb+reuKBoD88hrIpBJRkAFoZ8rwIjzthVoXgO7QTBU0Gq7NyaFL1sfiYEI+1tzSC7cPCkBKIcsN6eZlj/6BTU+jdLdXtOm1SdN0l7OfGzYXW5O2Ir4oHvlVLFB0s9bWNmmqR+RI5hE8sucR+Nj6YNe8Xe3aRj4Q8bPzQ0ZFBr658A2+ufCN6JrC6kIs6buk0Wfw9UNmhsxs9bAOuTkZrUckPT0dTz75JNavXw8rq5ZN8VuxYgVKS0uFr/R0/SI+hJhCadWN1/GoqKnD6Pf2Y/onMaJVXdMKK7H+hHj6ZQ8fR9G+t5MVpBI2vbZh0HIjDiawN7XfTrH/Wxn11Vz52Tmk48VkxOCN428gJjMGAPDFhC+wNHopAOBa6TVU1bFhOXdrbU9UUzki+9P3AwCylFmoUbf9ZwYArpdcx6GMQ8JQzOPRj4vao+ts7tkmn0X1Q0hDRusROXPmDPLy8tCvXz/hmFqtxqFDh/DZZ5+hpqYGMpl4fr9CoYBCQZ+6iPmx1qlFUVlb16opnp/tS0JlrRqVtWpkFFci0JUlHG6LyxSumdzDExczy3DH4ADRvRYyKbwdrZFZUoWrOeXwcLjxuh26RdH8nPjZOOxNju95IR0rozwDj+19THQsyiUKLlYuQq8DT3eRNz4IKFeVo6K2AnaW2pwdmUT7s3oh/wIGeLUtD0OtUePeXfeKel/6efZDX4++2J26W+96fkiIL7bWMBmVfw5N2yU8o/WIjB8/HhcuXEBcXJzwNWDAACxatAhxcXF6QQgh5ky3/HdWif4quY1JyqvAlwevCfuJuRXCdkmlNjD4cGFfHHlhnFDWXdeoCPams/lMBlRqDWrrbmzKZEKOzvL29e8NafWl2vlpwqRjXSm6Itr3tPGEq7UrJBIJZoXOavQ+O0s7oUfhz+t/is5lVmgD3JSylDa3sbC6UBSEOFg6wMvWCyGOIQavL6guQFZFFub/OR+Ldi7SW5yPz2uhHhHCM1ogYm9vj549e4q+bG1t4erqip49exrrZQm5YceuFWL6JzFYseWC3jl+RgsAPPxz00WkOI7DS1svYOmGWJxIFicTJuZpA5EiJes2f2laVJM9LDP7sPVczqYXY8YnhzHxw4NQqVsfjGSWaGfeFFXUguM4nE1n4/VRTZRBJ8aTWpYq2tddwO3+XvdjuO9wAMCUIP3F7fhekbdOvIWYDDaso9aoEV8cL1zTcJZLS3EcJyTO8nkhEkgwKXASVg1dxdrqIl5s7uOxHwMASmtK8enZTxFfHI8LBRcwbcs0/HmNBUvp5ek4knkEgHZ4iRBafZeQemv/i8elrDJcyirD81Mi4WSjTSwtrtQGItfylbiaU4bzGaUYHuaG7w8nI7WwEi9M7YYwDzsUVNQKuR9/nc8Wvcb+q3l4YGQwHvrptDBbRXe6riH8sInuFN70okqEtHIKbU6ptienuLIWSXkVyC2rgYVMgugAp1Y9i7SPhovZdXfV1mexlFniywlfIr08XZglo0s3yFh/dT1G+o3Ejms7kKPUFjnT3W6NB/97ELF5sVg5ZCUcFCxI7enWE++PeV+4ZpT/KChkCtSoa/Dz1J/R2703JJCAA4e/rv8lXJelzMKLh1/EjJAZePvk2wBYQbZBXoNuqG3k5tOhgciBAwc68uUIabGyahXi6ldrBYC+r+3Gu/N6Y8EAfwDiQAQApnwUo/cML0cFnp4YiXVHU/TOLRjghz9iM3EypQg7L2QLQQgAuNo1nRfl6WAFiQTQXd+sSFmLkFbOos3WCUSu5pRj4oeHAABDQlybXI+FGE9yabJov7uLfqE4f3t/g/cu67cMq4+tBsBmynx9/mtcyGe9ed623shWZiNX2foekbzKPJzIOQEA2JK4BdNCpgGAXnKqhdQC/976LxKKE9DXoy8ANrOHn+XTUFp5Go5mshWD145eCyu5ea1TREyH1pohBMCFjFLRjBYAeG7zeWFF2oJyFoi8P7+P3r28hJwKPLr+DD7br1/kaUCgCwLrZ6Zsj8sSnXNrpkfEUi6Fq604WMkrb/lsiLzyaqzYcl4oXNbQCw0WiiMdIy4vTqiaytMdmmnO3PC5+Hnqz8L+p2c/xYGMA8I5oPUl4AuqCjB+03hhv6SmBNlK1qvnYeOhd72btRuG+QwT9scHjNe7hncg/QDquDpIIBFK1BMCUCBCCAAIy9o3dC69BJW1dcgpY70J47p5IMzD8JDIyZQiHL9uuNKln4s1fOsTQvddFReF8mjBCrZqjTgnJLes5Qmzb/19Bb+eTEdZdZ3B890pP6RDnc45jUV/L8Kd/9wpHJsSNAW3ht9q8M2+MVKJVOiJ0CWTyDDOfxwAoKCydUXNPo/7XLSfUpaCHy7+AEA8bNSYFwe/KGx72XrB105bJG/t6bUAWJKqhdR0qxsT80OBCCHQBiLTe3vjziGBGNeNvSFcy6/ArM9Ycp29Qg5nW0sEu9mK7g1xs4W8QaExS5lUVAAswMVGqJLKu2NwAFbN6N6iQmGz+4qrnuaWtaxHpKCiBtt0emAWDQ7AzidGgm/u3UMDaa2PDvbEvidwvuC8sD87dDbeG/0eVg9bfUPPWxy1WDQDxVJmCU8btupwuaocKnXLa+DwNUC8bb1Fx+0s7DA9ZHqz90skEtwafisA4LVhr+GPWX/gjm53iK5pTbBFugYKRAgBkFIfiAwJccXrc3piVDirYvnVoetIqp/pUl7DehRCdAKRx8aEYsODQ3D/CG1Xc3SAE/5+YgQG6EzF9Xa0Fk2RDfOww5tzeuK+ES3rol41oztuG6jNFUgr0vbgFCtrsey3sziSJP70u/tyLga8sUfYf25KJF6f3RPdfRxwfc107H9mDF6c3vKhANI+ylXaadRyiRyvD3+9Tc97ftDzOLDggLBfVVcFB4WDUE+kuKa4Rc+5UngF10rZVPMN0zeIzvnb+8NS1vQQIu+FQS/gzzl/YqjPUNha2KKHWw/ReXcbWiKAiFEgQgiA5PpS58H1xcYMzUh5ZHQoAGBsN+0nukk9vODlaCWcA4DZfXwQ7mkPjU52qUwqQZS3drnzoSGureqJkEolWHNLL7w7j61gffRaIdQaDhtPpyP69d3YFpeFRd+eEN3zyvaLwvZP9w3CY2PCRCXig91sReXkSftSa9T4JPYT7E3bKxyrVInXDKrj6tqlR0oikaC3O/vZmBo8FVKJFE4KJwDAr1d/FaqzGrIlcQt+uvQTlu5l1Vw9bDxE5eQBwELW8qEUK7kVghyDhH3dHBL++YTookCEdHl1ag3S60udB7mx4RPd4Rd7hRx/PDoMy+qXlR8c7II5fX3Q198J3bxYcOFsa4npvb1hbyXH1F6sW3t+f9aD0cePlW3XXSxueJj4F31LSCQSzI32hb2VHCWVKqz9Lx7PbT5v8NqaOjWy6mfJDA52EYqikY4TkxmDby58g2X7l2HNiTVQaVS44+87mr/xBn027jMs778cq4awOh985dJvL3yLZfuXGbyntKYUrxx9Be+dfg95VSx36fmBzwMA1k1Zp72QM3R3y7hZu2HHnB3C/rTgaTf+MHJTojoipMvLKqmGSs3BUi6FjyMbPvF10g6juNhZiiqeSiQSfHRbtN5zPrktGhzHQS5j8f34KA/88ehQhHmwYMVCJsWOpcNxLr0Ek3t43lBbLWRSPD4uDG/tvIovDlxr9LqEHDac5GRjgd8eGnJDr0Xahi/cBQAbrm5AP89+wtCHjdwGVXVVeG7gc+32es5Wzri3573Cvm7eyNGsoyiuLtYrq36tRP9naFLQJACsjDsvwCFA77rWCHYMxlsj3oJMIsNAr4Fteha5+VCPCOny+GGZIFcbYehCdwjDuoU1NmRSiRCEACxg6R/oAkdrbbd2bz8n3Dk0qE3d8fcMazyvpFrFymlfL2CBSISnPSWjmsjFgouN7r898m2cWHQCi6IWGe31G5ZQ/+r8V3rX8IvY8d4b9Z5o//Pxn2O032gs77+8ze2ZGTpTqElCiC4KREiXl5zP3rSDXMWzYRbWFzN7dnJkh7epKZZyKd65tZfBc/n19UXy6mfVeDtS0ShTaZgkygci3rbeGBswFtZya6MGiX3cxTVv/kj4Q6/AWcMekSnB4lLyo/xG4bPxn1GCKTEqCkRIl5dSv/Bbw2m5r8/piT3LRwlTec3JwoEBSHl7OpLXiD9h8gXY+Lonnm1YrZe0TVltGQAg0pkFsnwgMsZ/TIe8/qywWfCw9oCfnR8inSNRra7GyiMrRdfwK+UCaNH0XEKMgQIR0ulcz69AXisKejWHryES1CAQsZRLEeZh3kMbEokEfy4dAXsrlu71+K9n8d3hZPx2kq11Q4GIaWg4DSpqWU8bP321Ws1+ZvnZLMbmYOmALbO3YPOszcIU4ZM5J0Uzd0prSgEAc8PmCkmuhHQ0CkRIp5KQW47JHx3CiHf3i9aGuVEcxwl1QhoOzXQWvfwcceCZMZBIgCqVGq//dRnKWpYr4kWBiEmU15aDq59q0sNVXEfDUeHYYe1wVDjC1sIWUa5R8Lb1hppT40zuGeE8P3w0OWgybCxsGnsMIUZFgQjpVHZeyIZKzaG2ToNfjqc2f0MzTqcWI7OkClYWUvTw7bylzl3tFHhUp5YJT7eIGuk4/LCMtdxab12VjuoRaYifBfPY3seE1XFLakpYm6xM0yZCAApESCdzKEG7sue+q3nQaNpQ4ABATCIbI5/SwwsOVp17/QsfJ3HQMbG7J3r7ddynb6LFByL2lvZ6q+eaLBDx0E7HXRGzAhzHoaS6xKRtIgSgQIR0IrV1GlzMKhP2i5S1SC2qbOKO5uWUsoqToQYqqXY2Pk7aYZjlEyPw1eL+Zp3fcjMoqi7Sq5YKANdLrgNgeRoeNh6wkbNhD4VMgW4uplnteIDnANH+iZwTqNWwVaWdFc6GbiGkQ1AgQjqNhNxy1NZpYG8lR19/JwDA+YySNj0zp36aq9dNMM1VdxXf+QP8RLVQSPsrri7G6N9HY9FOcS0QjuPw4mG2Cq2thS2kEik+G/8ZJgZOxCdjP4GrtaspmosQpxCsHroadhYs6P7p0k8AAAupBazlNIRHTIcqqxKzVVqlgoOVXPhUn5jHpqZ293ZAlLcD4tJLcDA+X29l2tbge0RuhkCkm5c9Rke4w8NeAW9HemMxtkMZhwCwomBqjRoyKSt8l6PMEa4Z4s2q2g70GmgWFUVvjbgVjgpHPHXgKcRkxgBghc+o54yYEvWIELO072ou+r72Hz7akygcK61ky5m72llidl8fAMCOc1nC8RuRXb8ey81Q+Esuk2LdfYPw3vw+zV9M2uxK0RVhu0LFZl5pOA0OZhwEAES5RGFp9FKTtK0p4c7hon0vWy8TtYQQhgIRYpae2XQeHAd8vDdRKFteXl0HAHCwskB0gDO8HKxQp+GQUl+ivbWUNXXCM72oB4G0QmZFJtZfWS/sl9Ww3KVfLv+CN0+8CQDo5Wa4+q2p+dn5QSFTCPsUiBBTo0CEmJ1qlRpFylph/2ImK7rEFx7ji3fxyZlZJY0vcd4UvvqonUIOOwWNUpKWiy+KF+3zs2TeO61dq6WXu3kGIjKpTJQw62VDgQgxLQpEiNlpGFgUVNRie1wmtpzNBABhmq2vM5uJkHmjgUj9sMzNkB9COhZfkVTYr2X7cqk2oO3t1rtD29QaEwImCNvUI0JMjQIRYnaySsTl24uUtXhm0zlhv2GPyI0GIjdTfghpu9KaUmxO2Cz0bjQmsyITq46Ky6Hz9/C9C9ZyawQ5Bhmlne1hbvhcRDpHQi6Ro69HX1M3h3Rx1B9NzE5WqTiwKKyoqc/qZ8XL7Ot7RPzqC3jxJdpbS5gxQ2XQCYDnY57HkcwjOJt3Fm+OeLPR65bsWaJ3jM8R4XtKfp/xO6QS8/2c56hwxMaZG1GpqoSdZeevoUM6N/P9n0K6pJQCJV7/87Lo2Pu7E1BbpxH2HaxZIDI0lNVjiEkswA9HknEoIV90XUOlVSoh8RXQ5ojQ0MzNQ6lS4tYdt2LlkZVQa9TN36DjSOYRABDKnxui0qhwrfSa3vGy2jLUqmtRrmJTzF2sXFr12qYglUgpCCFmgXpEiFl5dH0symvYTBZvRyth+ESXQs7i5zAPezjbWKC4UoVX64MXBys5dj01Sq+OhrKmDgPf3AMPewUOPz8OAOWI3IwuF15GQnECEooTMMR7SIuXtudLnQPi5M2CqgLkKnMhl8oR6RKJK4VXRPcpZArUqGuQq8xFUXURAEAukcPBsvOuW0RIR6MeEWJWrmRrx+f5Ho+GZDoVQ7+/R1wkqqy6Do/+EmvwubV1GmQUV6GylgU6nSlHJKk4CTEZMeC4tq2tc7PTze+4Xnq9xfetOblG2HazcQPAej9u3XErbvv7Nsz7cx6yKrJwNu+s6D4fO1bPJqE4AXmVeQAAF2sqEEZIa1AgQsyGSq0dVvF2tMIzkyL1rpnWywtDQrQBSl9/J9hYykTXxKWX6N1XozNkw/eE8H96mnmOSFJxEubumIvH9j6GA+kHTN0cs1ZeWy5sF1QVNHmtSq3Cy4dfxkP/PSQKMGrVbOp4fmW+0MsBABcKLiC9PF30jHnh8wAA8cXxeP/0+wCgt8gdIaRpNDRDzMapFPZLXyoBDj47FpZyKRLfnIrlG8+htEqFFVO7Icpb3OUtkUjg5WCF6wVNFzUrqKgRtv934BoeHROKwvpaJeZeDj2+WFuz4s0Tb2KY7zBRQSqipRuI8D0UjdmYsBHbr23XO3616Cp2Xt8Jbztv0fH08nQhuLmz+50Y6DkQI/xG4P0z70OpUiI2j/XEUSBCSOtQjwgxmTU7ryDohb8x+K09uJpThrd2svH3gUEusKzPA7GQSfHp7dH46b5BekEI79b+fs2+VmGFtkDa5jMZGP8+K8NtKZfC2caird+KUel+ss+tzMV7p95r4uquq1JViXdPvSvsN9Yjkl6Wjhp1DY5lHWv0Wc/HPI9cZa7oWExGDPam7QUARHtEY2zAWFhILeCkcBJdZyE1758nQswNBSLEJFRqDb46xMbwc8tqMOWjGFzMZOP7r8/p2apnLRkbhvg3pmDljO4AWI+KWiPOpShU1hi6FeO7eZj9eH7DT/a/x/8OlfrG19e5WX174VvRfn5lvt415/LPYdrWaRjwywBhTRjecwOfE+3nVrJAxFJqCQBCjwcAuFu7C9sNA5FBXoNa33hCujAKRIhJPPV7nMHjoyPcEeFp3+rnKeQy3D00EBIJoOEgKhHPcRy+PmQ4cXHNLeZZhluXoTfUlLKUVj+nRl2DjPKMdmiR+VBr1Hj75Nv49eqvSChOEJ0rqi4S8j14v1/9XbTvauWKV4a+gt3zduOObneIzl0ouAAAmB4yHRKIg1U3azdhWzcQmRY8DZODJt/w90NIV2TUQGTNmjUYOHAg7O3t4eHhgTlz5iA+Pr75G8lNjeM47L2iP34/NtIdb99644GBXCaFsw379Jpfru0BOZlcBJVa3EMyLNQV/1vUD07115urlNIU/JPyDwBgad+l8LXzBQB8eOZDZFZktupZT+5/ElO3TNWbgtqZncw5ifVX1uOtE2/pJZJy4PD4vsfx5vE3Uaepw8b4jfjz+p/CeS9bL2ycuRHzIubBy9YLMqkM8yPmC+d3pewCAMwMnSlamwUQByKOCkdh+7G+j5l9Dxsh5saogcjBgwexZMkSHD9+HLt374ZKpcKkSZOgVN7Yaqnk5lCorEVVfWGxjQ8PxZEXxuHEi+Px/T0D25w46ltfbTWjuBIAUFunwZIN2hkRHy3si3dv7Y0NDw7BtF7eBp9hTjYlbBK2R/iOwGDvwQCAmMwYvHH8jRY/h+M4oWDXH4l/tG8jTeh07mlh29B03aNZR/Fb/G/YmbwTrx9/XXQu0jkSHjYeomOrhq4S1QAJdw7HQK+BGOYzTDi2YtAKWMm1M62q67S1bvhAkRDSckYNRP7991/cc8896NGjB/r06YMff/wRaWlpOHPmjDFflpi5E9fZ7BgvBysMCnaBr5M1PB2s2uWTZIALWwgvragSRcpa9Fq9S5gxs3RsGOZE+2LBwI6f1XCl8AoW71yMk9knW3XfsWyWUDkhYAK6u3ZHX/e+wrmEooRG7tKXpcwStq3l5j1LqKFLhZcw/8/5OJVzSu/cubxzBu7Q98PFH/SOTQ2eavDacOdwYTvUMRQAcF+v+zDAcwAe6v0Q7ogSD+EU1xQL27qL3hFCWqZDc0RKS9k6DC4uhssf19TUoKysTPRFbi7VKjWWbGBJf7p1Q9qLf30gkl5Uie8PJ4vqh/i7mO4N+KvzX+Fc/jnc/9/9enkLDWVXZKNWXYtKVSUSixMBAC8NeQkSiQSzQmdh7ei1AICC6oIWJ63qBi35Vfo5J+bgSuEVzN42G8N/HY6jWUeF4/f8cw+uFl3FE/ue0LunYa7M5KDJOHr7UTzc+2HR8aSSJNH+kr5LMC14msF2fDruU2Fbw7GfHwdLB/ww5Qc8Hv243vX9PfsD0E9aJYS0TIeF7xqNBsuWLcPw4cPRs6fhWRFr1qzBq6++2lFNIkakUmtQVqWCq5223gXHcVj07Qlhf3yUh6Fb24TvETmZUiyq0trNyx4z+/i0++u1lG59i39T/sWs0FkGrzuffx6Ldi7C9JDpWBy1GABLqORzEmRSGSYFThJKi+dU5sDB0gGVqkq9uhe6dIOPrIqsRq8zpbv/vRtVdWwhwod3P4zzd50HAFSr2dBHhUq8uOGhjEPCzJaerj1hZ2mHNSPXwEJqgQd7P4huLt1QVVeFFw+/KNzz2rDX4GnrKRpqacje0h7hzuFILE7ElOApzbZ7Sd8lcLVyxZSg5q8lhOjrsEBkyZIluHjxIg4fPtzoNStWrMDy5cuF/bKyMvj7U3GgzujZTeew82IONj08FH38nfDryTR8fzgZiTor5T45IaLdX5cPRHSDkOemROKxMWHt/lqtoTsFd3fK7kYDkV8u/wIA+Pv635BL2H/PMGdx2yUSCXzsfJBcmoyYjBh8c+EbFFcX45Nxn2CU3yiDzy2pKRG2zTEQ0XAaIQjhZVRkwN5CO4PKRm4jbCtVSizZq10Fd8P0DaKhPYVMgQmBE1BRW4GBXgNhKbXEK0NfaTJY0/X9pO9xqfBSkwELz97SHg/2frBFzyWE6OuQoZmlS5fir7/+wv79++Hn13jxKYVCAQcHB9EX6Zy2xWXVJ4rGguM4rNhyQRSEbF8yXEgsbU98IKJrTl/TJhCqNWrRDJe8KvGMoRPZJ/D2ybdRVVcFB4X2Z56v+hnhrB+w9fPoB4CtkVJQVQA1p8a6S+sabYNuIJJflW92dUhSy1L1jiUUJYjaXVlXiUoVS0LekrhFOB7pHNlofpGdpR2+n/w9vpz4ZYuDEABwsnLCcN/hNAOGkA5g1B4RjuPw+OOPY+vWrThw4ACCg4ON+XLETFSrtMuvZxRX4VRKsej89N7e6OPvZJTX9nYSrxvj62QNLxOvJXOp8BJUGu0bf3p5OuZsmwMXaxe8PORlPPDfAwCAbi7dDJYlvzX8Vr1jj0c/jq1JW4UcBkA/D0JXaU2psK3hNMipzDGbUuTp5en48xqbVutq5Yoebj1wKOMQrhZfhbuNu+ja7y9+j0MZh3CliE1B7unaE68Pf13vmYSQzsOogciSJUuwYcMGbN++Hfb29sjJyQEAODo6wtq6c2Xuk5bLLBF3sb/z71XRvr+zfq9Fe7GQiTv5jrwwzmiv1VL/pfwHgPVixObFory2HOW15bhWeg2zt80Wrosvike2Mlt0rwQShDqF6j3T1doV3Vy64XLhZeFYUXURSqpL4GTlpHd9cbU4GMyuyDaLQITjOCz4c4GQ/zExcCKCHINwKOMQ4vLi0MtNXFfmq/NfCdvett74YcoPoqm0hJDOx6hDM1988QVKS0sxZswYeHt7C1+///578zeTTutIkniNjzOp4jdBY89emRvtC4kEeHdeb6O+TkvxhbYmBE5o8rpfrvyCq0XioK1hIS1dQ72H6h0zVEsDEPeIAOLpvKakVClFSajdXbsL3/Px7ON47tBzBu/r5dYLn43/jIIQQm4CRg1EOI4z+HXPPfcY82WJCZVWqvD+f2yqqJ+z4YBjRJibwePtZc0tvXDwmbFYMMD0n/gBoKyWJc42HGZozHCf4Xhv9HsIdw7H26PebvS6SUGThO0erj0AwGC1VZVGhctFrOckxDEEgPkkrOouTCeBBEN9hopyYpQqVvxQJpEJx8b4jcGG6RsM5s4QQjofqr5D2tWmM+korVIh1N0WP98/GMPe3ic6P7+/HwJdbY3aBisLGQJcjTf801qltaw3QrdiZ4B9ACYGTsR3F78DAAQ6BMJCaoEQxxC8PORlOFs5NzsdNMolCi8MegHWcmucyT2DS4WXhOmsAHA27yx2peyCs8IZdZo6WEotMcR7CK6XXjfLQGTr7K3wsvUCwHI/LhZeFM4N8xmGmMwYAICfffOrLRNCOg8KREi7ulw/bXZutC98nKzx0cK+WFa/wN26+wZhpJF7QzpKUXURKmorEOAQ0Oy1ZTXs78TR0hGrhq7CgfQDWD10NfKr8oVA5J1R7wi9Gi0lkUiwKGoRAG0PR3aFNsfkrn/uEl1/T897EOwYDFyFXi6KLo7j8MGZD6CQKbA0emmr2tRaBdUsEOnn0U+UC/O/Cf/DswefxYkcVncmzClMCEQalmUnhHRuFIiQdnWtfopumIcdAGBCd0/hnK+TFaTSm2M65JQ/pqCqrgp75u2Bp61nk9fyQzMOlg6YHzFfWFjN3cYdP075EddKrqG7S/c2tYfvSdiYsBEbEzYK03t1zQ6dLfRAnMw5iatFVw3moFwtuoofL/0IAKhV1+Kp/k8ZbRprYVUhAPEicgDgbOWM0f6jhUDEQeGA14a9hoMZB7EwcqFR2kIIMY0OLfFObm5qDYf4XFZBNNSdBSJ2Cjmen9IN9wwLEo51dnWaOqH41rn8ptc6UalVwrW6NUJ4/T37Y0Hkgja/0fvYiqvGxubFivYHew1GgEOAaFhj/p/zUVRdpPesw5naooM/XPpBL4G2PfGBUcNABGDDVbyh3kMxN3wuPhr7EWwszGfYjRDSdtQjQtpFSoESD/50GtUqDWwtxTkaj47Rn37amem+edeoa5q8VnctFDsL4wVifTz6NHn+vdHvAWDDGnPC5mBb0jYAwPWS63DxchG2Vx5difP550X3GgpW2ktTgcgAzwHo494HPVx7oIdb64atCCGdB/WIkHbx7OZzQuXUF6dHQSGXNXNH56WbYKm7bYjuQm0yqfH+TmwtbPHMgGcaPe9s5SxsvzbsNXjasOGkjIoMACwImb19tl4QAgCP7HkEs7bNQlIxK5h2NPMo8ivbZ+G8pgIRGwsb/DLtF6wYvKJdXosQYp6oR4S0WVm1Sqieaq+Q47aBzSdwdma6wYfuLJWGNJxGeKNXyBSNXtde7u5xN2QSGd459Y5wrL9nf7w54k3RdRKJBCP9RmJzwmZklLP26ZZMNyS5NBkxmTGoUFXg4T0Pw0pmhdlhs6HhNHis72NwtHSEhcyi1W3mc0RcrV1bfS8h5OZAPSKkzQ4lsE/HCrkUR1aMg+wmSUhtjG4gYqgkO49/kweA7XO2G7VNvDui7hC91mDvwfC1019rx8+O5YrwxdYuFFwQnR/mMwzhzuGiY3mVeYjLiwPAVsT9Pf53bErYhLEbx+KNE2+0uq0cxwkF2Az1iBBCugYKRExErVHrld3urH49mQYAeGhUCBysWv+p2FyoNWo8svsRPLHvCag16kav0w0+cpWN94gkFicCYNVCDQUDxiCVSBHiGAIPazbFdaz/WIPXBTuydZ+SSpJQp6kT1m6ZEjQFrlauWDFohd5Mnl+u/NLoejZbEreI1tNpibdPvi3k2FAgQkjXRYGIiXwW9xlG/z4aRzKPmLopbXY5i01Pndar5aubmqO08jQcyTqC/en78df1vxq/rixN2M6pzGn0uvwq1lPkbdvxfy+/z/wdv07/tdES8VEuUQBYbsiVwiuoqquCrYUt3hn1Dg4sPIAgxyDc1/M+uFq5oqdrT+E+fkVgQ/am7m1x+w6kH8CGqxuEfRcrlxbfSwi5uVAgYgJ1mjp8e+FbcOCwZO8SPLLnEXxw+gPUaepM3bRWU6k1KK5kn4Q9TbzKbVvpVhs9mnW00et0Z8IUVhWiTlOH3am7sSlhk+g6fraJq1XH5z+4Wbuhp1vPRs972XrBSeGEOq4OmxM3A2DVTKUS7a+EEKcQ7F+wH88Per5Fr/nsoWdxMP2gwXOVqkrEZMQIPUhfnPtCOPfi4Bchl1K6GiFdFf3vN4GLBdrS1WpOjSOZR3Ak8wiG+Q7DEO8hJmxZ6xUrawEAUgngZN15h2UA8TotyaXJBq/hOE50Ts2pcSjjEJYfWA4AGOI9RFjVlk/EdLE2v0/7EokE/T37Y2/aXiFRtbe7/iKBEolEKJbWEj9f+Rmj/Ebp1UV57tBzOJhxEOHO4RjnP05YNXj7nO3C+jeEkK6JekSMJC4vrtGVUOOL4lt13JwVVLBAxMVW0emrpuoGIillKdBwGr1rrhRdQYWqApZSS+HYk/ufFLZ1a26YskekJcYFjBO2beQ2uCPqDoPXedl64c0Rb2Kk70iD598f/b6wfSL7BKZumYpKVaVwLEeZg4MZrKcksTgRX53/Sjjnb2ceCxMSQkyHAhEjyK/Mx73/3ovZ22YLVSq3Jm7FtC3TcKngEhJLWBJjP49+ooXQEooTTNLetiioqE82tLNs5krzpzs0U1VXhRylfv4HnzsyLmAcJgZO1DtfVKUNRAqr63tEzDT/YULABFjL2QrJ9/a8t8mE0Vmhs/DZ+M/w8diP8dUEbSDx89SfMSloEvbO1+aHZFZkiqqxfhz7caPPvZEpv4SQmwsFIkZwpegK6jiW77H8wHIkFSdh1dFVSC9Px9akrUK39PzI+Tiw8ADeHfUuAAi/vCtVlXj+0PONjrebi/SiSjyziZU4d7Mzfp0MY2u4EJyhHq1LBZcAAKP8RuGdke/gkT6PiM7zi7j9dvU3nM07C8B8AxEbCxusn7Yej/V5DPf0uKfZ66USKcYFjMMArwHo694Xs0Jnoa9HXwCsYuvcsLnCtXweDcdxOJB+wODzGpvRQwjpWihHpB29ffJtpJeno697X+FYVV0VdlzbIexfKbqCCwUXIIEEAzwHwEJqgQGeAwCwHpG4vDjE5cVhZ/JO7EzeiXN3nRMlEJoLjuNw53cnkFfOekQ87G+eQMTb1hvZymxcL7mOM7ln4GPng/kR88FxnNBrFekSCQuZBUb7jcaX574UnvHTpZ/Q3bU73jyhLSLWmhyLjhbuHK5XL6Q5ljJL/DztZ73jq4auQlx+HJJLk5FSmgKA9QpVqCoggQRDfYbiaNZRdHftjhcHv4ggh6B2+A4IIZ2d+b3DdVLVddVYf2U9DmUcwi9XfhGd25um7bbmS2hHe0QLb1DuNu7Cp+Y7/7lTWK0VAHYm7zR2029ISmElUgq1eQCdfequSq0SypYP8xkGAHjv9Hv49sK3eO3Ya+A4DhkVGahQVUAulSPYgdXhiHSJxECvgcJzUspScP+u+0XP9rETL0h3s5JL5bi92+0AgOSyZHAch5PZJwGwv4PpIdNhZ2GHJ6OfRB/3PnBUOJqyuYQQM9GlA5FTOafw0uGXUF5b3uZnfXb2M2GbT1IMcwoDwOpTNNTHXbxImW63tm7S5Fsn3kJFbUWb2xdfFI+52+dic8LmNj8LALJLq4Tt2wcFYFw3j1a35+HdD2Py5smi75enUqtQXVfd5na2VG5lLjhwUMgUGOKjP3Opsq4S/yT/A4D92/G5DRZSC3w/+Xu8NeIt4VqlSim61xx7tIyFXwU4V5mLny//jOdj2NTfIIcgzAqdhaO3H8Uw32GmbCIhxMx0nd+QBty36z7suLajyWQ6XceyjuHuf+7G+ivrAbC1RHZc24Fz+eew7vI6/ef3vA8WUsPJeA27w5f1XwZ7S3sA4iXcy2vL8U/KPy1qX1N2pexCUkkSXj32Knqt64XrJYZn9LRUfv2QzJAQF6y5pVeLZ8xoOA0W/rUQ8/6ch6NZR5GlzEJsrnjJerVGjYV/L8SsbbPaJUhsCd1hGX6oTFd6ebrw735r+K1656cETcHLg1/WO744anE7t9S8ediwgDSvMg/vnX5POO5nz0rKN5zWSwghXTYQqVXXCtstna2y+uhqxObF4u2TbyNHmYNfLv+Clw6/hMU7Db/Z9PPsh4WRC4X9J6K1K7EaWtacLwPOz9bge1RSS1Nb1L6m6A73AMDs7bNxJvfMDT8vr4wFIq0tYpZRniEk6/IarmAblx+HxOJEZCuz8UfCHzfcxtbgAxEvWy+4WbvpVST969pfKKougp+dH6YET9G730JmgQWRC0SB5yi/UXiq/1PGbbiZ8bRlq/ryM4Z4XWV4ihDSel02WVX3zbCspqyJKxkNpxFKdgNATGYMvr7wdZP3eFh74LG+j6FWXYtIl0gsiFyA4b7DUVBVYLCIk4+tj2jaY3/P/kgqSdKbzdHQgfQDsJBaYLjv8Eav0a1vwdtwZQP6e/Zv8tmNyS1jwyYtSVItrSnFz5d/hlwqNzgM0zAQ2Ze2T9g+mnUU9/S854ba2BrZFdoeEQD4fPznyKrIwrun3sWFggvConC93Xs32sslkUjgaeMprLg7O3Q2LGWdf1pzazgrnGEhtdBbd4YfsiGEkIa6bI/IufxzwnZyWTKq6qqauJq9kev+cv3uwncorSkVXaP7pu6kcIKFzAL2lvZYOXQlFkQuAMAWQBvlN8rga4Q6hQrbcokcA7zYEIGheha8rYlb8fi+x7F031K99ujiq3zeEn4LrGSsFyMuLw4cxzV6T1O0s2Ws8MGZD3Dfrvsa/Tt89uCz+Or8V/g87nNsS9omHO/hynqFdAORK4VXsDVpq7DP11wxNmFoxo4FIh42Hujr0VdIIuaHy5pbN8bdxl3Y5tdz6UokEokwPKPLnGcOEUJMiwIRsN6O5qqaNlxllf9k38+jn+g5PFsL21a3Sbeugr2lvVAqvLEekeyKbKw5uQYAW7+Gr3HRUI4yR3gjnRkyE4dvPwyFTIG8qjxcK7nW6nYCQEIuy91IrN2BHy7+gFM5p4QZErpq1DU4ln1M7/g7I9/Bnd3vBKANRL45/w0W/LVAlBdSUFVg9FWKNZxGKLvfMNBoWAOkuUBkavBUAIC9hT187TtmxV1z42fnp3fMFAv/EUI6hy4ZiHAch3N5LBCxkdsAgLAMemPX3/b3bQbP3dn9TmHGi25RqF5uvVrdLt1FyjxtPeFry97I8qvyDfZ2fH3ha1EvhG5wxfs87nNM3KytAOpq7QqFTCEkZB7Jav3qv2XVKsTnlkMiL8W/md8Lx5fuW6o3wydPmad3//yI+ZgWMk2o5MkPeX1y9hPhmnDncCFn5tE9j4pyegC2Au4rR18R6lW0RUxGDOKL4yGTyEQ1YAA2zVqXnaVdk8+6LfI2vDH8DXw87uMuNVtGV6BDoLD93MDn8M7Id4TcEUIIaahL/qbMVmYjryoPcokcs0JnARAv7d6Q7tBIwyJMPVx7YOXQldg6ayvG+o/F4qjF8LTxxNMDnm51uyQSCTbP3IxBXoPwzIBn4GTlJCSsNlwNtrquWpiKy+eGpJaLk1p3pewSFdsCAE8b9obADyPp5qS01MWMUnAc4OZ1BRzEQzv8zBJeTiX7uwt0CMTXE7/GrNBZQgInnzeQVZGl1+shl8iFmUWXCi9h3SXxrKQn9j2BLYlbsGz/sla3vyG+t2hGyAwEOQaJzs0KnYUZITMAAFYyKwz2HtzksyQSCWaHzRbVFulqrOTaBObbu92OaSHTTNgaQoi565KBSFxeHACgm0s34Y0ntzK30etP5miHHD4Y84HonJetFyykFghzDoNEIsHzg57Hnvl7bnhMPNIlEt9N/k54wxvpxxYa+z3+dxzKOASADSXM+3MeADYEMDGA9XiUVJcIz6mqq8IzB58RPfvnqT/DxoL1APHTKXXXV2mp6wWsTobCjg3r6L7p8mXNAdaTtPLISgAsABrqMxRvjnhTmKbsY+cDhUyBGnUN9qfvF71GiFOIEDQBrLdkY/xGYf9a6TXRn21xpZD1hvXx6KN3TiaVYc3INThxxwnsmb+nyfVYCKM7/Vku7bL58ISQFuqSgUh1XTXcrNwR4dRTeLNrLBApri7G+6fZ6qJ3d79br/6Hsesi8Cuensk9gyV7l+BUzinsT9+P1DLW+3F71O1CgmRxDetVqFHXYMSvI0TP+Xnqz8K6IIB2OmWWsvWBSFpRJQANlNIkAMDTA57GbZFs6Co2LxYqNUvqjS+OF3JpDNUDkUllCHZkFUpfOfoKAJZbMyFgAp4f+DwinCNE179+/HU8tucxXMi/0Oo2N6ZSVYnzBazabXfX7o1eZ2NhQ5VAW2iM/xi8Pvx1/DGrY6ZeE0I6ty4ZiBTk9EHy2eUoTp8oBCJZFVn4PO5zUZ5FrjIXo34fheKaYthb2OPenvcC0M6GmBw02eht1Q0eADZtOLGYzSSZGzYXj0c/DmeFMwAgoSgB66+sx6TNk1Cr0eZUDPIapPcc3Zol/6X816o2pRQoIbUsQB1XBWu5Nbo5d8OKwSvgYuUi9MSo1Cqcyjkl3DM9ZLrBZzUMNr6Z+A0+HPshnK2ccUv4LRjhKw6oYjJj8MB/D4iObU3ciozyjFZ9D7z/Uv+DUqVEgH1Al5zlYgwSiQRzwubo/dsSQoghXTIQYUW4JMgtUwtJdAVVBfjy3JdYvHMxKlVsDRXdCqchTiFwtXYFAHw2/jMs7bsUq4euNnpbLaQWWDlkpbB/JueMMBWXHyZwsnICANRxdXj75NuimiH397wf749+X++5rlauwvY3F75pcXs4jkN8bjmkCpaEGuIYAplUBqlEKgQN+9L34cPYD3E65zQA9gn5jm53GHxew4Jfugm7cqkc/xv/PzzW9zFhyjHAyq3rWnV0lSiZmOM4rLu0Dg/vfhhjN47FjK0zsDVxKwzhh2XGBYzrssmlhBBiSl1yANfb0RoAcCK5CN8ecICTwgklNSXC+b+T/0alqlJUN+Su7ncJ2x42Hni4z8Md1t4FkQvgbOWM5QeW43zBeSHBlJ9a2tgy8zvm7BCGPhriOFaW/N+Uf0Xfe3Ou5VcgtbASVm4sGNKdIbG8/3L2utd24OfL2tVZH+z1oLA2S0MNcy4aDnVJJBI82udRLIhYgK/Pf41oz2g8e/BZveeU1pSiqo710JwvOI+1p9cK5wqqCrDq6Co4WDpgfOB40X38OkC63wchhJCO0yU/Ano7aj9df30oBQsCnhedf+3Ya1h7eq2wBk2YUxgmBk6EKU0MnCi0gR924QMQfgqyrjF+YxpdZv2nYykIe2knBjmxIms5yhy9EvCN+e8yy6XxcmXX676Gq7WrXi+RtdwaUa5ND3nwwzYNh2F0uVq7YsXgFZgc2PhwGF8LprHCbssOLNM7ll6eDgAIsA9oso2EEEKMo0sGIh4O4rLkVWWhsFCFNXr9aL/RZrFYV8OaFs5WLDdEIpFgsNdgOCmcsDhqMb6d9C0+Hf9po21etf0SNBzw2nbtlOUHdj1g8FpdKQVKvPsve7O3sGYBSYiTuFS9hcwC7456V9iPcolqtCS60J4hq/DS4Jfw9si3m22DRCLB49GPQwIJvp/8PSYETBDOnc49jVp1bZMVZnUrydZp6pBZzpJpAxwoECGEEFMweiDy+eefIygoCFZWVhg8eDBOntSvvtnRFHKZaD8pTwku5wEorz9h8Ho+N8TUGg4f6A7JfDnxS/x76794ftDzTda6qK3TVn8tr9Sug9JUQTferyfrAxdJHQpqUwBoy7TrGuM/RtjWnYLbGBsLG9zW7bYWz0p5sNeDOLnoJAZ6DcS7o98VCsl9HPsxZm6dKVRq9bTxxKfjPhVmHgHamUUAUFJTgjquDhJI4G7tDkIIIR3PqIHI77//juXLl+OVV15BbGws+vTpg8mTJyMvT7/aZkebEKVdDyM+txzllRJoagwvzFWjrumoZjWpYb6H7romcqm82bLy285m4v51p3SOiP/596buRXVdNZJLk/HntT+h1qhF5zNLWBXX6f01qOPq4KRwEmbf6LKWWwvbCnnzi+K1lkQiEYpmWUgtRHVMspRZQg7NWP+xGOM/Bh+N/QgyCQs++VL9SpVSKBLnqHCETCoOTgkhhHQMowYiH3zwAR588EHce++96N69O7788kvY2Njg+++/b/5mI/vfov7Y/dQoWMgkSC2sRJ2GddlX57IqkH0tn4aNlAUrQ32Gmqyduvzt/fHcwOcw1HsoXhz8YqNJqobklFZj2e9xiEkUr3QrlWjfgJcdWIbJf0zGrG2z8OLhF7ErZZfeMyQWhaiyYmvH9HTr2ejwD98LcXu321vcxhsV7REt+rtIKmH1TfihK0uZJSJdIgEAeZUsCH5q/1N46fBLANgChYQQQkzDaIFIbW0tzpw5gwkTtGP4UqkUEyZMwLFj+ougAUBNTQ3KyspEX8ZiKZci3NMeM3uLe0FURSNRfvU1xJxzR+7Vh9Fb8orB4QdTubP7nfh60tetfoPffCbd4PGB8tdF+7pTf49nHxe28yrzkKraA5vAL3Gq8B8A4qm2DX049kP8c8s/TRYJay/2lvbYdesu9HbvDQBCnRXdoR6+nPy5/HOoqqsSLcTXmoCOEEJI+zJaIFJQUAC1Wg1PT3GOgKenJ3JyDC9rv2bNGjg6Ogpf/v7+xmqeINzTXrRvIZMCnCV8nawBtS1OxtugtErVyN2dw4nrhfgmJtnguT3npbCUWhk8dyb3jLD9zMFnUO24CVILbYXUphb2U8gUQhn5jmAltxLyUfj1b3R7Ovj1Tr658A0GrR8kupd6RAghxHTMatbMihUrUFpaKnylpxv+FN+ePHVm0AwMcsYv9w/Gk+PDsfOJkQh1t4VKzeFIUkETTzBv8TnlWPj1cb1g6p5hQYjydgAAWNV1E47rLteeVp6G0ppScBwnWkMGAJ7s92ST021NoWHCqW5OzRi/MY3exw/hEEII6XhGC0Tc3Nwgk8mQmytewyU3NxdeXoYXhFMoFHBwcBB9GRurssq42iowOMQVT02MgKONBXr7OQEAHlsfi+1xmUZvS3vbeDodkz86JDr29Z39MbOPD56dHInVM9mwiaZoMnrb3YInwzbgr7l/4cCCA0IS6pQ/puCJveLF83ztfPFArwfMrhKpbvKutdxaVGK8YUG1Hyb/IGyb2/dBCCFdidF+A1taWqJ///7Yu3evcEyj0WDv3r0YOtQ8kj8BwMNe2yMS5mEnOqe7/+RvcaIaFO3tTGoxVu+41K7DQM9tPq93bFIPL3x6ezRsFXIEubFZNtn5zjhyahDe+DMNEshRWWWNSGdWhKxCVYEDmeK1aJqrC2IqwQ7aHpBebr2aXPl1gNcAvDz4ZTgrnDEtmJapJ4QQUzHqR8Hly5fjm2++wbp163DlyhU8+uijUCqVuPfee435sq3iodMjMjxMXG7cx0mcO1FWVdfur6/RcNhzORe3f30cPx5NwXObzzV/UwsYCpqeGC9eOdjDXgErC/GPwC/HUzH6vf1ITNWvq1FbPBg28BWtfWNOdHtADCXSLuu3DIB2fZuF3Rbi4MKDGOA1QO9aQgghHcOoa80sXLgQ+fn5WLVqFXJyctC3b1/8+++/egmspuRgJce8/n4or1ZhYJA4V2BaL2/EJBZgSywblsktr4ajTfv2Bny8NxEf700U9nddysXFzFL09G3bkvMNg6aPFvbFjN7eomMSiQSBLraIz9UmoL7652UAwJXrfrCrLzZbmXo/JPIK1JX1xXAfPwzy7tumthmLr722pkmoU6je+Xt63IMRviMQ7qwNyMyhYi4hhHRlRl/0bunSpVi6dKmxX+aGSSQSrJ3fx+A5hVyGDxb0xaXMMsTnliO3rBoRDWbZtAXHcaIghPdtzHV8dFu0gTta7nK2eOrz7L4+Bt90wzztRIGI0DaVK6pzpwPgEOHYH1fqn9fNq/2+//YmlUjx/MDncbHwIqYGTdU7L5PKhHoihBBCzEOXXH23tTwcFPWBSPtWWM0pqzZ4fNelXNTUqfVK0bdUWbUKt3+jrQEyOsK90U/+fk7WBo8DrKYKAIT2tsWXi/th69lM3DHYvFepXdx9sambQAghpBVoukALeNizXJFLWY0vptacfy9mY9S7+zHt4xjM/vwICitqkJhbIZwPdLXB/mfGwMpCiiqVGjmlhoOUlvj9pHja89d39W/0WnedZN07BgfA0doCcqk4aInydkCgqy2WTYiAnYJiV0IIIe2H3lVaoLefI/6IzcDPx1Jxaz8/g/kbypo6SCUSWFsa7sV4dH0sdPNH3/z7Crr7sOnJU3t64YvFLFjwdLBCamElckqrEeja9NoxjbleoA1wXp/do8meldsGBeC/y7kY180Dj4wOxSszu6O0SoXkfCUWfs16VXq1MV+FEEIIaQz1iLTA4iGBmBDliToNh7u+P4nCCvEQTUVNHca/fxAzPo1BnVpj8BkNJ7FsOZuJN/5mK97y9UoAbV2TxoZtWiK/vBYA8ObcnrhzaFCT19op5Nj48FA8MpoldyrkMnjYW6Gbt7aGS1sTZwkhhJDGUCDSAjKpBO/N6w1fJ2sUKWux6UyG6PzuyznIKavGtXwl3t+dYPAZLraWBo872VjgzqHavAuv+kAktw2BSEF9oORmd+Mr3zpaW+CHewbiu7sHNNp2QgghpK0oEGkhZ1tLLBjA1r55+5+rOJNaDICtSPvS1ovCdV8cuIY/zmQgo7hSOPZtzHUUKVkvxf5nxsBCps3B2L5kuCjvwsuxvkek9MYTY9sjEAGAsd08MD7KfKZaE0IIuflQINIKM/po63D8EZsBjuPwbcx1VNaqRdc9vekcHlh3GgBwOqVIGIIBgCBXG0zpyZ4zJMRFLw/Ez5nNYkkpVIqON1bVleM4fLQnAR/sTgDHceA4DvnlLBBxb2MgQgghhBgbJau2Qqi7Hb66sz8e/vkMNpxIw4YTacK5r+/sj5+OpeJw/QJ5V3PKkZBbjtf/uixc89CoEEgkEny0sC8eHR2KYDf9ZFR+ITq+bkduWTUeWHcaZdUq/HjvIL17rhco8dEeVotkYJAz+vg7oaaO5am42dOQCiGEEPNGPSKtNDTU1eDxYWFuonVrAODhn8/gXAab8rvuvkF4cRpbv0UmlaC7j4PBGTZ8wbDs0moUKWvx/ZFkXMgsRWphJdbsZD0rdWoNDibko6xahf1X84R7t53NwoH4fACAr5M1bCwpziSEEGLe6J2qlRysDJd4t1PIkd9gNk1yARte8XG0wqhwN0O36bG3skCwmy2SC5Q4m1aMy1naCqknU4qg0XD4+XgqXv3zMkZHuEOjM2STlFeOmjo2TDQ32lfv2YQQQoi5oUCkjRYPCcCM3j4AgPtHBCMmsQDRAU44m1YiXPP0pMhWrWkyJMQVyQVK3L/uNHRri5VUqvD4b2dxKZP1shxMyIelTNuplZRXgVo1C0x6+9GUW0IIIeaPhmZuwIT6mST3DAvCG3N6YUgIG64ZE+mBf54ciV8fHAJfndLpoR52rXr+2Ejtyrea+g6P6AAnAMDf57ORUqidkVOr1sDdXgGZVAJlrVrILQlwtWn190UIIYR0NApEbsB783rj/fl98PyUbnrnorwdYGUhw6BgF+FYeCsDkYndPbFiqvbZ7vYK/PHIsEavHxzsgmENclf8nSkQIYQQYv4oELkBzraWuLW/X6Pl3AHg4dEhmN3XB5seGQrbVq7PIpFI8PBo7TL2ag0HqVSCOwYHGLy+u48D3psnXkG4ta9JCCGEmAIFIkbSzcsBH98WjYFBLs1f3IiHRoUAAF6f3RMA8NqsHqLziwYHYGYfHyweEggvRytcfHUyxka648Vp+j01hBBCiDmij81m7JlJkZjX308Y2pHXJ6beL/sbCtTh2TlfiZJg7RRy/HDvIJO0lRBCCLkRFIiYMUu5FBGe9qJjk6UnsdJiPdspWwU4+pmgZYQQQkj7oKGZTmaCNFa7s+tF4NMBwJW/gO1LgHeCgCMfm6xthBBCSGtRINLJDHIu1+5c3g4UJgK/LwLO/gJUFQP73gBKMxp/ACGk7U5+A/wyD6gua/5aQkiTKBDpZAK4rKYvUNcCB97umMYQYm7UKqA8x/ivs/MZIGk3cOQj47+WMRReA67tM3UrCAFAgUjnsmc1JBW54mOjntNuj1nB/rywCaip6Lh2EWIu/ngAeL8bkHvJeK+hVmm3s88BHAdkxgKqauO9Znv7aQ7w81wg4T9Tt4QQCkQ6jbhfgcMfavddQoA5XwIjlgH2PoCDHzB8GTteV81yRqjbmHQlGg1weRsADji73nivc+wz7XZBInDuN+CbscDulcZ7zfZUVwOU1q8cfvIr07aFEFAg0jloNOJfcvf8DTxxFuh7O2BpCzx6BHgkBrCwAkYsByBhv5D/e5ndq7MwHtFRVwvUKk3dCtJekg9qt+WKxq9rC1U1sGe1dr8kFdj2CNs++bVxXrO9FSVrt/OumK4dhNSjQKQzyLsMKPPZ9uOxQNAI8XkbF/YFAP3uBOZ9z7Zj1wGvObPekaZo1F0vWKkqAT7oBrwTDFzbb+rWkLbiOGDH49r98uz2fX5lEZCfAOxY2vR1Ffnt+7rGUJio3S7LBLLOmq4thIACkc4h9if2Z/gkwDW06WsBIGqmeD9uPbBhIVBmINH18g7g7QBgwwLx2PfNLjsOqCwE1DXApS2mbg1pq5JUoDRdZz+98WubU5wKpBzW7mfGAu+FAp8PZPlXvPDJ+vemn7jx1+0oBQni/a/HdK78FnLToUDE3NVWagORIY+17B6ZBfDUZfGxhH+BbY/qX7v/TaC2Akj8Dzj3a9va2pkUXtNuZ9Inwk7v8nbxfknqjT2nrhb4fgrw43Qg+ZD22ZxGfN1tvwIjntLue7JlGHD6+xt7XQDIuwqkHb/x+1uqoL5HxENnyYg3PYF/V7DeUUI6GAUi5i75IFBXBTj6AyFjWn6foy8w+3Og563ArE/ZsesHxG/AHAcUp2j3dzwOrHYEdq9i52orWWLbzaYkDfh7uXY/7zL7Xknnsn8N+3n9bjL7mQWA4U+yP0szAGVB657HccAnfYHy+p7DdTMBZaE2IJnzBfB0ArC6FOg2DfAfzHKyprwDDK0fsrm2F0jc3frvRV0H/G8w8P1k49YBSjuh/cAx+jlg4mvac8f/p/1eCelAFIiYqxNfAV+N0v7SiJgM6Kwr0yLRi1m+SL+7gKCR7FjSXu15ZT6bYdPQkY+BhF3AF8OAzwezT4k3E/5Ni8epgZzzpmkLuTHVZcDB+no56fW9CNYuwPhXAEs7ABwbTlntCGx9tPk3d40G+Gk2y5nQdeILIOcC2w4aAdh7as9JpcCEV4AhjwDdZwEOvux4XCtn7HAc8Em0dr8gkR37bRHwbijw9zNAdWnrntmYXS9qt90igN4LxeePfc6CL15FHnBpGwuUCDESCkTM1T/PsRoFfJdzwwTV1gqbwP5M/A/Ivcx+sfHj6PY+wLAnxNdveQgoTmZfBfFte21TqyxidSWuH2SB1aWt2nPefdifmWdM0zbSOhzHvt721z8nlbGv2gY1dM5tAP43DCjP1b+Hl3tRPOuGd+g9QKMCrJxYr2RjLG2BBT+z7aS9LLBpqaLr2um0AAuGci8CV/8CKguAU9+wuh+tDQY0GpaUDbC/s92rgMzTbN+jB+AeCdh7sVl4gx6ub/tu4Oc52uT1H6cDm+5u25ATIc2gQMQcGeqB8B/StmeGjWd/Ju0GvhjKPv1d3saOOQUAk15nXc7T1rJjNTqfwL4cwT4pdRZ1teLE3E33sN6dn2YB+Ve1x1cVAd1ns22qMmne1HXA5vuAtRFASoz2eMQUIPpOtj32JfbnyGfYn4MeBvrfy7ZrSoH3IxoPEPJ0cqo8ewIvpLHaPDyvXs33SHr3AeRWQE2ZuOeBl3wIOPgeUFMuPt4weXTv62z9KF1ZsSygOrOu5UNOxz4D3gkEEvew1+bXofLuAzx2lAVtAPuQM16nPEDOeWDjXUB+vLZtDXNwCGlHtPquOSpOFu8PXQo4eLftmZ492Sc6fmZB1lnttL1e87TXRUxm5asb2vUi4BEFHPkE0NQBizYBFtZta5OxbF8CXNzM3jwiphr+pCuRsl/E3ecAe19jgUhFHmDn0eHNJS1weRtw8Q+2vU5nVtgtXwNya6DvHdpgfcwKoP/dLMAGWI8D/zOQfZb1jp1dDyz8BbBzZ8f54RcAuG0DYOUIzPoE+OUWdix0XPNtlMkBqQWAajakM+oZwNaNnasqBn69A6gtZ225+09tYJPfoMexIkc77DRtLRtCPfiOdnpyViwwswWLW/K1h/5+ig3P8gYbSFpX2AM9btHOILuyg33xLG2afz1CbpBRekRSUlJw//33Izg4GNbW1ggNDcUrr7yC2tqbLNfAWHQ/8d37LzD5zbY/UyJhwcO0tWxsmOceBQx6ULvvFABMeRvw6M6utXLSnvt5LnB9P2tfwq62t8kYyrKBCxvZLIfsc9pf6Lo8ugO3fse2XUMB13B2PRV3ujEaNVts8djnxlt0MfWI/rH+97KAQW4JBA5jORsACwj4IARgSaa8b8axN/T048Chd1nbY3/STsud/TngHMi2Q8cBI58GhizRH7pszDCdOiO6wc2531gQArD/PwffZb0MGxYCe15hxwMNDL9699XP4zjzI/uT49hrnPyG9WAoC4DzG1mCuW6hvpI09u8CAN1mAH1uM9z2Of8DxhjoyQGMVyCOEBipR+Tq1avQaDT46quvEBYWhosXL+LBBx+EUqnE2rVrjfGSN4fMWNbzkHaM7Y99GQgc2n7P94hiXwp7YGv9mPCAe/WvG/Io+wKAfnez2gjrZoivOfUtG9ZobQKtsf0wtfFzA+5jbypuYeLjTgGsyFNJmuH7iL7qUiDmfUDhALgEs1wK3oXNrGep2wxgeP0beMYZ4OqfrLfiRt7UMk6xP8MmAnaeQFWR9me0OY6+wIyPgL+WiY+XZ7MARLcQmm7Ph0QCjG+Q2Nyckc+w3I6cC+wrdCw73nBo48Bb4n2pBTDzI0BVyZLUeZ49WG+E3EqcWF6SzvLI4ndqj/GvkXeF9WwaMmZF4/9nLayB6EX6bQO0uSaEGIFRApEpU6ZgypQpwn5ISAji4+PxxRdfUCDSmMTdrKgYX69ApmBvnMbQYy5w/AtWb6T/PU1fK7cEgkcCd+1gORa8lBj2KbWtSbTtqVapnY7sHKTdnviadlqnIfynZwpEWqZWCXw9Fiiqnwpu7yM+LyQ5JwDDHmdvfN/Wv8Gf+BoIGAw41q+N1JICfRo1q7EBADM+EPd2tFT4RP1jyTHAlT+1+wFDAQcf/etaQyZnU+ZzLrChkZAxgGuYttDZyGeAGAO/A+f/CLiFs+3RL7DemrEvaodEFm1i05Wz41iw8lHPxttw+AMWqDV07z+AVxP3Aezf5a7tgKqKtfvbCUB1CRu2JMRIOixHpLS0FC4uLk1eU1NTg5oabd2KsjIzXrSN41rWG6CqYuO7V/9mBclCxrA3Sd178+NZIh4fhHj2Au77F1DYGaPl7BPpwwbyJpoSMBQIHM6GamRy9ukr/UTrAhGOY28qsiZ+7Mqy2C+9c7+y7nBH36afWV3KZsW4BAM5FwFwgK0H8EQcy2WpLgNsXZt+BgUirZMcow1CAG3dDd3gD2B5ESWp9dNp66mU2sTg2J+AcS8Do55t+vXKstjMFamFdopsazn6ARNfZzlShddYvY/qEu35EU/Vr9PUDnot0K5H89dTwNR32f9tW3egxxxxIOISAjywV7tEAwCMeQEYuVzccxQ8in0l7QF+ubX5NvDDN26RLBAPG88+eLSEbr2i+/4F/jeE5a0QYiQdEogkJSXh008/bbY3ZM2aNXj11Vc7okksez7zDPtF0NwbVUNnfwF2PsuS2jRq9qnOJVh8DcexdSnO/qI9xncN+w1kxZAmrGa/HLYvZZn2AcPYpxGpXDvebS7klsC99d3ARz9lgUhGK6e87n2N3fvgXu20WV3Z54FvxwPq+lyi/Hj2RtZrAfslDABBw8X3/H6nfjKqRxQL9GQWLfu35XMCiq637vvpqq7V16JxCWU9H5yG/SyPeIr9GxYmAYfWAnmX2HBjeRNvYvveYHlKUTMav4avkurop53pcSOG6+R5/DhDnIs16lk2Bbc9OPqyIZADa9jvGL7OiWdPbQVWgAX29+7Uv18iaXz4KnQ8m3qbd8nw+W4z2AyZmvoPcaOfAyKnGL62Rd9L/ZRlPuC3aeTDZFk2K4g2+JHmPzwQ0kCr3u1eeOEFSCSSJr+uXr0quiczMxNTpkzB/Pnz8eCDDzbyZGbFihUoLS0VvtLT27BeRHN2LAW+m8DmyDelVgmkHtPOq1er2KwMVSWbb7/+VuCzgeLENI2GTYHTDUJ0ZZxiU+uOfQ4ceAfIOMmO3/ote8M3tyCkocBh7M+k3S1PTNSoWZexRsV6f678pS3SVJ4D/LcS+GqkNggBWGJs0XWWcPrjNPaVUp+0qKpmOQqGZsT0XtC674d/c8i91Lr6D11Rcox2ldmxL7LAeeYnwLD6oS/v3kDPW7S5TUc+AnatED9j+DLx/gEDCcW6iusDET5gbA8TVrMCaM5BwMqC9gtCeGNeAJwCAXBsFWyABd8SCXD7b+wD0IQb+NAlkQCzPwVklqyXkjfve2DcSpaEPWUNO+bgy6Y3t4XCTtsLxZeGbyj3MltA8ugnwL7X2/Z6pEtqVY/I008/jXvuuafJa0JCQoTtrKwsjB07FsOGDcPXXze/RLZCoYBC0QHZ2WXZ2uqHKTHsjbH7bPap29pZfO2me4HEXax7NWwCW4+hIY2KFTHy6sX2970GHP5Qe96zFzDxVSDmAyBVZzEtPluev6azfJLw6Qf4DWIBVMIuYOD9zd+TFafdLkwCfl/EphRaObGgTHctj3Eva7P8G/pxGluB+L+XxYl6vMGPsIqyreEaxqaAqpQsgPTp27r7u4qkvdrprBFTWa6RVKbtrdLl258lNGefY/vOQcCkN9hQWa95LMjkz+VeYKvW8lNpdVUWafM4nNoxEPEbADx1iU3jbumQRWtF3wnsf0P7s80HBZFT2deN8u3P/g9YObKf16pi8UKXfe4AJDKWi9Mew7tu4azI2j/PArf/rl9KYK9OQJV7se2vR7qcVgUi7u7ucHc38MvCgMzMTIwdOxb9+/fHDz/8AKk5fcrXXUETYPUJLv7B3sCmf8iqDxYkAPH/siAEYBnqhsgUbAXXPa8AAx9g//F1gxCFA0s0c/Bm47Rn1gF/GpgK2G16+3xvHUEiYW/WGSdb3iNiqIQ6v5gfT24FzPuBreNxcWvj3c+/3qZfBCpqJguOmku+NUQqA3yigbSjLNC875/WP+Nmp6oGdr2k3Z++tulhEp9+4v1bvgH8B2n379jE8kx2PMFmLOWc01b/5ZXnAJ/211ZK9eyBdmXs2hhDHgXO/sRyj1zDxN9/WznVD5kYytGSSoG+t7ffa7lFsnWqss+xDxAPNij+lxmr3VY4tN/rki7DKDkimZmZGDNmDAIDA7F27Vrk5+cL57y8vIzxkq0TPJJNS60u1VYXBdhQSkUeK4PelF4L2Kd6mSXrnuZnkxz+ALBxE1/7eKz4k16/u9gnMN2VcBWOhqfRmjO+u/b8RrYuzsRXxfVIdO18Vtud79kLGPqY+Pv36g08sIcl9lo7sWNBw1kg4tmTzSiQK1iJ7u8maIOQxsbYb8TUt9m0ybSjLJmxJbM5uoqKPGBtuHb/ueTGcwV4un9/MoX+m7C9J/vy7sMCkWwDgci1feJy7V69b6z9pqKwA+7cBlzcAvS7s235LaYUMAQ4+RXbbrgUQnkOoNSZUZN6hPWSZp9nOTlUf8T8qevYz6YJSzEYJRDZvXs3kpKSkJSUBD8/P9E5js+1MCWfaGBW/SJTl3ewlViV9cGSoSCk1wL2qYbTsKS2iEnacxwHRE5jwwQx74vve+qSfnezRMKqQLpFsJk0fW5jnyLszSBAaw3H+n/XsvoekZ3PGA5EMk5rgxAAGPQA+/7P/apd6dO3P/uFpftLq/89bChg+DLttEanACByOhD/N9v3H9x+3493HyBkLBsyuLSl+ZkcXYGqmq1zottz1WtB80EIIB7uUDexgrN3b1YFN9tAj1nDMuft3SPSEVxDgdGd/Gep4dCbbtIqP7wmtWBD1AArQwCwXuDWDpMS49No2Irulrbsd+y/L7CZVW0ZLmwjo4yX3HPPPeA4zuCX2ek+C3g2SZz4FTAMmP0/wMKGFU+asBq4fxfwwG5xEAKwwGLeD/rj17M/175ZG+I3gK3c6R7Z9vLtpmBoGqWhyqSx68T7nvV5NO5R2mOGeh88ewBPxAK95+sf5+n+m7WHnvX5D/veYN/Lia+BH6ZrFwdsSvY5IPVo+7bH2DQaoKai8fOnv2e5OEIP1Aj2f6E98bOn+Dc03tWd2oCz/z1sQTkr6vY3CVs39vuMp/vBgv93C2/wexFgw5y1SkoANze/LwI+7MESz/lh7qOfmbRJZpS4YWIe3bXb3WezCoMvpANPxzefRGphBSw9BfgOYAHJC2k3/ycBtwjW5Q6d7ryfbxHXkdCoWa8PwBYku/U7wLc+d8A9Untda7rcdZOJ/Qe2ttVNi5rJgk+A1U7451mWXMzXZGhMeS7w/RTgh2lA+sn2bZMx/b4YeC9MvBpxZixLfgTYmia8qJnAvX+3LqF6XP1sEX4hRUP4f/viZDZriXf0E/bnwAfZuirdZ+nfSzpO9GLglm/Zdtx67SrIKfXJ90HD2XRugPVwAmwK8Vs+wCd9qDKrudBoWO99VTGrlq2uZUtc3L7BpM2iRe94ulMD+cTRpgpvNSRXAPfvZj0k5lb23BhsXVlRNIUDm3nw43SWfLjzWZacC7AVTSsL2TUjlov/PrtNB87+zEp2G5p10ZjIqWw6qFdv/RlObWXtzP4Nvx6j7WYG2No1I57Sn4Fw/QBw/SDLbVFVsmM/TgeePNf2Cp3GlndF2+Ow6R429Li7vpy5vTfwyBE2K4NndwNDhyOeZgupuYQ0fo2Ni3ZI7I8HgYcOsCFQPqAb9njj95KOxZeNL0ljgbr/YO30+eBRbNiurppNudetwlySxgKWpmrFkI7BFx/UNfwJ8f91E6BAhNd7IXDud6DPQm1GemuZ08ygjuChM7xy+2/A5wNZjk3yIdazEPMBO+fTVz+os/dibzqt5RLM3uh1F+NrT1492bTUCxu1x0rSWHGqia+xNVWcg1m+yq93sCm/utS1rMjauJfZOiN1taw2jLlJ3C3e54MQgK3BkrhLXNZ76JLWv4ZU2rKk31u+Af43mCUnH/2YJa1yalbn40bKuRPjsHIA3LsB+Ve1X5AAsz7Vli4ADOcQZZ6hQMQcGCra2GNux7ejgS72ztkEey/gsaNNr0lCGuceoR3eWjeTVUjlP3H7RLfvazkHaWfXGEO/O7XbQx5jf55Zx2YIHVgDbH0I+GGKOAiRSFn9BoBN//79Tlbk7u0AVqzN3PD5PP0aKeiXcVq7DPy87/UrB7cnO3dWYwQA4jZo8w68enWN3sXOZNpawNKebUtkrAij7v8XgCVBLj0tPnb9gLYoJDGN7HPsdzPAFnecsJpVB1fYm7RZAAUipD01lhcT1cnG94NHAfftYgl6k95kgU9tObDtkcbviZgiHoKoLWc5GHVV2nwHc8LXaAmfxIZhADbzYWD9zKfT32mvdeyAXgk+2bHoOvBn/YcBv3bOASJtFzwSeDEDeDgGePgQK05niHMQW6qClxUrrjdCOt5OndlbfoPYcLOZ1K+ioRnSfoYuYTMcOA2r0SJTsIz7zvipNmAI+wJY8vKRj7XnJr3JkvOkclbQ7ewvwLT3WNXQg2+zxfYAceJurbL9y4jfqOxz2l4Hzx6st+ORw2xYSTdhlOceYfw22brpHwsbb/zXJTfGu5kEc5kF8OR59rtg57NAwj9seMavf+P3KAtZrp2xFvvs6oqStdt9FpquHQZQjwhpX5a2rKvP0Y91uXfGIKShgQ+I9/vdyYabvHqx5Nnb1rPkVI9uwLKLhhMs+XwZcxBXnyEfMUU75OLVi812aJiU+sC+jktki9T5dBYwlHpEOjtHX5Zvx9cBKkxq/NryXODTaGB9Iz0spG04TptQP3lN0wnkJkCBCCHNcQoQJ+M19cbs4K2dxqgr5n0g/RSr2spPaQZYT4luUuiNUhayAmDKguav5euCGOqW9YmGMCXbLcLwKsnGMu1dYO7XwMv5bPl5Y60BQzoWH4hknDJ8nuOA/W+yXtS0YzTVt6U0GrYMSWOLEQrXqVkl69oK1ovb8IOVGaChGUJaYvbnwHeTWFXY5jg2mHXlP4QtBX9lO5DwH1AQz2aK9JrPaq9kn2N1aG50thYAbL5HW6l29AtsyCX3EhsHtrASX8v/4nIzMORi5w48d529OSjsWzeFva0c/cyuy5i0A7f6mkFZsaxQXbdp4vP/PC8ufJhzgeWikKZd3soWbAXYrMXGKqMmH2KVrAGWx2eGs/goECGkJbz7AM8ktizPQ7ei7u2/s+UD0o8DaSdYEAIAWx5k63KkH2f78TuBwQ+3vD1px9kU3CGPsh4aPggBWJ4KrzgZmPkJC0YubWPl1EvrK8W6hsOglpRwJ6SlfPuzn7XCRDZ7hg9EsuJYefG0Y+Lrs89RINISuZe129uXskrUhnprS9K021PfM367bgANzRDSUlYOLVu4TLf6qFdP7fTljAZVV3Urtv7zHEvq46uaNiX9JPD9ZCBmLfBeKLCjiaJf539nxaXqaoBNdwNX/mTH7TxZUTpCjE0mZz1zACtyCLDlE74eLQ5CnIPYn3xwTppWkaPdriwALmw2fF1ZfRGz/veYZW8IQD0ihLQ/hT0w/X02huvoxxJAFQ6s5DXAgoCaCjajwDUUyL3Ijp/8mtUreTaRBRtyBZB3ldUr6TmPBS4VuawirS6+21UiY0MyOQ0WkEs/AbzhIT7mO6Ddv21CGsWvEZV7kQ37nfpW/5pbvgG+m8jWbCrLrp9xJ2ProSjr84aaWs23OJUF3rkX2bpIgx8yzvdianEbgL+Ws9IAAAvgilNYb9PA+/WvL8tkfxpaH8xMUCBCiDHoJoTJ5EDgcDaFEWDjtAMfYNObLW2BXS9q63aoa4Bf5un3nuxZbeA1HmSfiq78yQqq3fE7ED4RiP+HJcH+YeCXEq+91+khpCnukexntKqYJUvzCdO88MmAd19WMbmyEPigGzDgPiBiKqvyC7AifD59xfelHmXFKC1sWQ4X30tweQcQOeXmqsyr0QAXNrHEU119bmeFFq8fMFwmgO8RsTffxVUpECGkI0TN1AYiwaPEa9HM+IAtBri9vox6wyDEkHk/aFcLzrvCPmV61le25ZPW5FZsto6NCwAJqxoss2Rl+Pvf2y7fFiEtYmHNZpMVJrJeP+hUWe02g5WJl1sCgx5kyygAbPVn3Rkh/2/v7qOiqtc9gH+Hd0QcBeVl0hFUFAVFDOEqltf0iOXV0AwtNM27utcOHgWNtNMl1zmZqB3LJA9kt7yeU/aeFqYpEpKmSDKisVJ8CdEkRU0FMZAzs+8fv4aZgeF9hj0M389as5i9Zw/z/EQ3j7+352aJaSJypUgUmoS5HVt/73X5w18t3xa55L8l5tTUN3iK6BW9eUH8p6SmUhQmnPiS2EFVPyfMhutfMREh6ghhT4ghkjvlgHpsw9fD5wIj5og5H9W3DOd9hon1/8abowGmhQKNa/4YG/of5ut7qKNaGz1R+/mGiETk2LvieMQckSh09zHsN/TAcrG7btFn4vjCQcP769dJOfEBGiQhz+SI4csP5gB56cCYxeL724OCbebPew0Qu1cf3iRKM5zbL5K2f84A/lxm2L9FPzxmgzhZlagjODgA0zcBT37Y+IQxRydgwL8bjv/8iygM2DfScO6/csUmY+Z2IiWyZfpJ2/q5DeoowNPXdNNDZ3dR28j477zeDaNEpOy4SDSMPb5N9CwOniLmQGnviTkj9uBf90SCZcxVCSw5bihGCIghr5tGO6iuUYm5aN39xBCWjWKPCJEtmZIqxshHPQW4dBPnJr4kxsJDYhuOkRN1FmFzgP2rDMf9xzV+bezfgey/AH2GiiHHTxaYDll+/7+iQvPAh8TyYGd38e8DEIlN+FxRfPL4e6JXpLPv8Hw6E/jtV/G8VyDw+FbRE6JfrqvfNK4k1/z7bfy+wUSEyJb0UAELdpme69kPWGamBgxRZ+LpJ4YlLx4Wx70b2cdG/9rs98Tz326J1TPXzwC3LomVaKd3i9ceWA4EmEloQmcCX78AXDstlgzb8LBEk3RaIPuvwLGt4njMYiDmlYbX9Q6C2BG5kQrHg6dYK0KL4NAMERF1jBnpIhl57J2W91K49xRzpQCRVNy8IHoHHF3MD+EAoqdAX5ahqRo3cis/Dez/S+Pb2hfvBr7bCNTcFsd9hpi/zr2XKLyp5zcC8B4knvdUGya22yj2iBARUcfoFQAs3NP69/VQAVd/EKto9Jvy+YY2vUGXsq8Yzrn9c5tCbbXaauDtCYBzN+A/s8S8sOa8GyMmp9+9LlYO6UmSmONiXPUbMCQX5kQ+A4TMBL57HQieJopGOjiI72XjQ1NMRIiIyLZ5+oqv+140nBs0sen36Est3L7c8s+p+AX4Z6z45Q1JLJP3C23Ze0tyDTvH/lIoJs42R79C7tw3hnNZq8RQjCqsYaHAphIRQOyWPHm16TkbT0IAJiJERGTrutdb8fFQCjBuWdPv0ScieZuB89+Igo5PftKwCKSxU1+KeSV6u58TO7q2xJm9hufnsptPRCSj+RySVny9UiSGYgBD/aipG0SVXdfugEeflsXSyTARISIi21Z/6emI2c0Pfehr1wDAtVPikfd34IEmEpjLGtPjf1W3PMbLx0w/DxClHBwcxaqe+ozrSun+JYoA7ko0vaZXIDBqgelOzXaIk1WJiMi2GW9P7h1kWuG6MQMfAtzrVZLO3wJoaxt/z+UC02OpkVUo9d04L6oG690sBS58B2wYAmSMA6orGr7n7D7D86proghg2XHTax58TuwvZOeYiBARkW0LiBb7jvQKBGa+1bJ5D47ODYdVKn9pWBRSr7Ya+PW8eP7I38TXW6UNrzu3H8hJFZuMAaLqbVq9YZhbpcDXK4B7d8SqHXO1oo5sbjx21Sgg9DGx+2wXYP+pFhERdW5uSuDpr1r/vt6DG5678ZPYBK2+62fELqTuvURJht3PieGT326JJcQ6HQAJeO8xcX3uWiDuH4bikt19gYf+B/jyT6KHo+qa4Xsfe1dUDi7eDczPFNfqJ7bWN3SaYQ+VLoI9IkREZJ8UCuDJj0VC4vv7viK//iQq1R5OA6puiHM3Lxh2Je0z1HRi6K1SUdU2LRxIrTck9PFT4mt3PyDpx993RPY0vD5mMeATAkAS81NuXgAyE4HVPmJeiJsSiP19q/qh08X3mLXV4n8Mto49IkREZL8Gx4jHwQ1iL5JzWcCBNeK1il/E5NU3wgzX64tI9goQvRq/lohH/cKTxsYlGeZy+IUCF4+I55NXA5lLgXKjnZHPZxue940UvS+D/iBW9XRRTESIiMj+6XdnNd6b43JBwwmqxonIz9+LOSXH6w2VhM4SwzCXC4AhDwMuHobXpm4Qwzdj/yR6ZAZOADRmKuc+8jeRICkUXToJAZiIEBFRVxA0WUwCLTNaonspDyj4P9Pr9ImI1wDx9eAG8VXhAPzxqFimO2AC0MMf8Aps+Dm+IcByo71IhsWKiadFnxnORT0rdkIlAJwjQkREXYGDI7Bwb8PzxbtNj/U9J6GPmZ6P/G+gz2Bg5JMiCWkphQKY9S7w7BGj78UkxJjVE5GamhqMHDkSCoUChYWF1v44IiIi85xcgAW7gSGPALEZQMADhtf6/ZtY0dLt971H+gwBYlINrwc/0r7P9h0mhmOe/ATwHti+72VnrD408/zzz0OlUuHEiRPNX0xERGRNAdHiAQAj4gDNP0TyETyt4W6tY/4IhMwQS20DH2z/Z7MnxCyrJiJ79uzBvn378Nlnn2HPnjZUXCQiIrIWB0cg4ummr+nh37qhGGo1qyUiV69exTPPPIOdO3eiW7du1voYIiIi6sSskohIkoQFCxZg0aJFiIiIwIULF1r0vpqaGtTU1NQdV1SY2Z+fiIiI7EarJquuXLkSCoWiycfp06eRlpaGyspKvPDCC60KJjU1FUqlsu7Rr1+/Vr2fiIiIOheFJLW0vCBw7do13Lhxo8lrBgwYgLi4OGRmZkJhVJhIq9XC0dER8fHx2LbNzOYuMN8j0q9fP9y+fRs9evRoaZhEREQko4qKCiiVyhb9/m5VItJSFy9eNBlWKSsrQ0xMDD799FNERUWhb98WlHBG6xpCREREtqE1v7+tMkdErVabHHfv3h0AMHDgwBYnIURERGT/uLMqERERyaZDas0EBATACiNARERE1MmxR4SIiIhkw0SEiIiIZMNEhIiIiGTDRISIiIhkw0SEiIiIZNMhq2baSr/ShjVniIiIOg/97+2WrJi16USksrISAFhzhoiIqBOqrKyEUqls8hqrbPFuKTqdDmVlZfD09DSpW9Ne+ho2ly5dstut4+29jWxf52fvbbT39gH230Z7bx9gvTZKkoTKykqoVCo4ODQ9C8Sme0QcHBysuiV8jx497PYvl569t5Ht6/zsvY323j7A/tto7+0DrNPG5npC9DhZlYiIiGTDRISIiIhk0yUTEVdXV6xatQqurq5yh2I19t5Gtq/zs/c22nv7APtvo723D7CNNtr0ZFUiIiKyb12yR4SIiIhsAxMRIiIikg0TESIiIpINExEiIiKSTZdMRDZv3oyAgAC4ubkhKioK+fn5codkEampqRg9ejQ8PT3h4+OD2NhYFBcXyx2W1axduxYKhQKJiYlyh2JRly9fxty5c+Ht7Q13d3cMHz4cx44dkzssi9BqtUhJSUFgYCDc3d0xcOBAvPzyyy2qR2Grvv32W0ybNg0qlQoKhQI7d+40eV2SJLz00kvw9/eHu7s7Jk2ahLNnz8oTbBs01b7a2lqsWLECw4cPh4eHB1QqFZ566imUlZXJF3AbNPczNLZo0SIoFAps3Lixw+Jrr5a079SpU5g+fTqUSiU8PDwwevRoXLx4sUPi63KJyEcffYRly5Zh1apV0Gg0CAsLQ0xMDMrLy+UOrd1yc3ORkJCAvLw8ZGVloba2FpMnT0ZVVZXcoVnc999/j7feegsjRoyQOxSLunnzJqKjo+Hs7Iw9e/bgxx9/xIYNG9CrVy+5Q7OIdevWIT09HW+++SZOnTqFdevWYf369UhLS5M7tDarqqpCWFgYNm/ebPb19evXY9OmTcjIyMDRo0fh4eGBmJgYVFdXd3CkbdNU++7evQuNRoOUlBRoNBp8/vnnKC4uxvTp02WItO2a+xnq7dixA3l5eVCpVB0UmWU0177z589j3LhxCA4OxoEDB3Dy5EmkpKTAzc2tYwKUupjIyEgpISGh7lir1UoqlUpKTU2VMSrrKC8vlwBIubm5codiUZWVlVJQUJCUlZUljR8/Xlq6dKncIVnMihUrpHHjxskdhtVMnTpVWrhwocm5mTNnSvHx8TJFZFkApB07dtQd63Q6yc/PT3r11Vfrzt26dUtydXWVPvjgAxkibJ/67TMnPz9fAiCVlpZ2TFAW1lgbf/75Z+m+++6TioqKpP79+0uvv/56h8dmCebaN3v2bGnu3LnyBCRJUpfqEbl37x4KCgowadKkunMODg6YNGkSjhw5ImNk1nH79m0AgJeXl8yRWFZCQgKmTp1q8nO0F19++SUiIiLw+OOPw8fHB+Hh4Xj77bflDstixo4di+zsbJw5cwYAcOLECRw6dAgPP/ywzJFZR0lJCa5cuWLyd1WpVCIqKsou7zmAuO8oFAr07NlT7lAsRqfTYd68eUhOTkZISIjc4ViUTqfDV199hcGDByMmJgY+Pj6IiopqcnjK0rpUInL9+nVotVr4+vqanPf19cWVK1dkiso6dDodEhMTER0djdDQULnDsZgPP/wQGo0GqampcodiFT/99BPS09MRFBSEvXv34tlnn8WSJUuwbds2uUOziJUrV2LOnDkIDg6Gs7MzwsPDkZiYiPj4eLlDswr9faUr3HMAoLq6GitWrMATTzxhV0Xi1q1bBycnJyxZskTuUCyuvLwcd+7cwdq1azFlyhTs27cPM2bMwMyZM5Gbm9shMdh09V1qu4SEBBQVFeHQoUNyh2Ixly5dwtKlS5GVldVxY5cdTKfTISIiAmvWrAEAhIeHo6ioCBkZGZg/f77M0bXfxx9/jPfffx/bt29HSEgICgsLkZiYCJVKZRft68pqa2sRFxcHSZKQnp4udzgWU1BQgDfeeAMajQYKhULucCxOp9MBAB599FEkJSUBAEaOHInDhw8jIyMD48ePt3oMXapHpHfv3nB0dMTVq1dNzl+9ehV+fn4yRWV5ixcvxq5du5CTk4O+ffvKHY7FFBQUoLy8HKNGjYKTkxOcnJyQm5uLTZs2wcnJCVqtVu4Q283f3x/Dhg0zOTd06NAOm71ubcnJyXW9IsOHD8e8efOQlJRktz1c+vuKvd9z9ElIaWkpsrKy7Ko35ODBgygvL4dara6775SWlmL58uUICAiQO7x26927N5ycnGS973SpRMTFxQX3338/srOz687pdDpkZ2djzJgxMkZmGZIkYfHixdixYwe++eYbBAYGyh2SRU2cOBE//PADCgsL6x4RERGIj49HYWEhHB0d5Q6x3aKjoxssuT5z5gz69+8vU0SWdffuXTg4mN52HB0d6/5XZm8CAwPh5+dncs+pqKjA0aNH7eKeAxiSkLNnz2L//v3w9vaWOySLmjdvHk6ePGly31GpVEhOTsbevXvlDq/dXFxcMHr0aFnvO11uaGbZsmWYP38+IiIiEBkZiY0bN6KqqgpPP/203KG1W0JCArZv344vvvgCnp6edWPQSqUS7u7uMkfXfp6eng3mu3h4eMDb29tu5sEkJSVh7NixWLNmDeLi4pCfn48tW7Zgy5YtcodmEdOmTcMrr7wCtVqNkJAQHD9+HK+99hoWLlwod2htdufOHZw7d67uuKSkBIWFhfDy8oJarUZiYiJWr16NoKAgBAYGIiUlBSqVCrGxsfIF3QpNtc/f3x+zZs2CRqPBrl27oNVq6+47Xl5ecHFxkSvsVmnuZ1g/uXJ2doafnx+GDBnS0aG2SXPtS05OxuzZs/Hggw9iwoQJ+Prrr5GZmYkDBw50TICyrdeRUVpamqRWqyUXFxcpMjJSysvLkzskiwBg9rF161a5Q7Mae1u+K0mSlJmZKYWGhkqurq5ScHCwtGXLFrlDspiKigpp6dKlklqtltzc3KQBAwZIL774olRTUyN3aG2Wk5Nj9t/d/PnzJUkSS3hTUlIkX19fydXVVZo4caJUXFwsb9Ct0FT7SkpKGr3v5OTkyB16izX3M6yvsy3fbUn73nnnHWnQoEGSm5ubFBYWJu3cubPD4lNIUife0pCIiIg6tS41R4SIiIhsCxMRIiIikg0TESIiIpINExEiIiKSDRMRIiIikg0TESIiIpINExEiIiKSDRMRIiIikg0TESIiIpINExEiIiKSDRMRIiIikg0TESIiIpLN/wONbwIR2BkYxwAAAABJRU5ErkJggg==" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# A plot of the solution of the SDE used to compare the methods\n", + "sol_general = diffeqsolve(\n", + " terms_mlp_sde,\n", + " GeneralShARK(),\n", + " t0,\n", + " t1,\n", + " dt0=0.02,\n", + " y0=mlp_sde.y0,\n", + " args=mlp_sde.args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + ")\n", + "plot_sol_general(sol_general)" + ] + }, + { + "cell_type": "markdown", + "id": "3114fb2bcb2ab174", + "metadata": { + "collapsed": false + }, + "source": [ + "## ShARK\n", + "`ShARK` is an SRK method for additive-noise SDEs. It uses two vector-field evaluations per step and has strong order 1.5, but applied to a Langevin SDE it has order 2.\n", + " While it has the same order as `SRA1`, it has a better proportionality constant.\n", + "\n", + "Based on equation (6.1) in\n", + " Foster, J., dos Reis, G., & Strange, C. (2023).\n", + " High order splitting methods for SDEs satisfying a commutativity condition.\n", + " arXiv [Math.NA] http://arxiv.org/abs/2210.17543\n", + " \n", + "\n", + "## General ShARK\n", + "`GeneralShARK` is a generalisation of the ShARK method which now works for any SDE, not only those with additive noise. It uses three evaluations of the vector field per step and has the following strong orders of convergence:\n", + "- 2 for the Langevin SDEs\n", + "- 1.5 for additive noise SDEs\n", + "- 1 for commutative noise SDEs\n", + "- 0.5 for general SDEs.\n", + "\n", + "\n", + "## SRA1\n", + "Another method for additive-noise SDEs.\n", + "`SRA1` normally has strong order 1.5, but when applied to a Langevin SDE it has order 2. It natively supports adaptive-stepping via an embedded method for error estimation. Uses two evaluations of the vector-field per step.\n", + "\n", + "Based on the SRA1 method from\n", + " A. Rößler, Runge–Kutta methods for the strong approximation of solutions of stochastic differential equations,\n", + " SIAM Journal on Numerical Analysis, 8 (2010), pp. 922–952.\n", + " \n", + "\n", + "## Shifted Additive-noise Euler (SEA)\n", + "This variant of the Euler-Maruyama makes use of the space-time Levy area, which improves its local error to $O(h^2)$ compared to $O(h^{1.5})$ of the standard Euler-Maruyama. Nevertheless, it has a strong order of only 1 for additive-noise SDEs.\n", + "\n", + "\n", + " ## The \"Space-Time Optimal Runge-Kutta\" method\n", + "This is a general Stochastic Runge-Kutta method with 3 evaluations of the vector field per step,\n", + "based on Definition 1.6 from\n", + "Foster, J. (2023).\n", + "On the convergence of adaptive approximations for stochastic differential equations.\n", + "arXiv [Math.NA]. Retrieved from http://arxiv.org/abs/2311.14201\n", + "\n", + "For general SDEs, this has order 0.5.\n", + "When the noise is commutative it has order 1.\n", + "When the noise is additive it has order 1.5.\n", + "For the Langevin SDE it has order 2.\n", + "Requires the space-time Levy area H.\n", + "It also natively supports adaptive time-stepping.\n", + "\n", + "\n", + "## SLOW-RK\n", + "This is a general Stochastic Runge-Kutta method with 7 stages (2 evaluations of the drift vector field and 5 evaluations of the diffusion vector field) per step. Remarkably, it has order 1.5 for commutative noise SDEs and order 0.5 for general SDEs.\n", + "Devised by James Foster." + ] + }, + { + "cell_type": "markdown", + "id": "12f62a28fb25ada", + "metadata": { + "collapsed": false + }, + "source": [ + "# Comparison of the orders of convergence of various SRK methods\n", + "## General SDEs" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5a8d281b0522bb92", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-02T15:06:25.955243Z", + "start_time": "2024-04-02T15:04:57.426878Z" + } + }, + "outputs": [], + "source": [ + "# A comparison of SlowRK, SPaRK and GeneralShARK for general SDEs\n", + "# We compute their orders and plot them on the same graph\n", + "out_SLOWRK_mlp_sde = constant_step_strong_order(\n", + " keys, mlp_sde_short, SlowRK(), levels=(4, 10)\n", + ")\n", + "out_SPaRK_mlp_sde = constant_step_strong_order(\n", + " keys, mlp_sde_short, SPaRK(), levels=(4, 10)\n", + ")\n", + "out_GenShARK_mlp_sde = constant_step_strong_order(\n", + " keys, mlp_sde_short, GeneralShARK(), levels=(4, 10)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5beb86506adfa933", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-02T15:06:26.185839Z", + "start_time": "2024-04-02T15:06:25.956242Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAABIIAAAOOCAYAAAB88LmZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzddVhU6fs/8PfM0B0KBgii2B2YKGAHduzq2q4da6y7qxv60Q111dW1de1YY+0OQOwu7EJBQZDumDm/P/xyfjMyA0MNCO/XdXk5w9znOfeZOVP3PCERBEEAEREREREREREVe9LCToCIiIiIiIiIiHSDhSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAiIiIiIiIiohKChSAioiJs06ZNkEgkkEgkGDp0aGGn81mIiIjAnDlz0KRJE1hbW0Mmk4n34aZNmwo7PSIionyX8T4nkUgKOxUi+gzoFXYCRES68Pr1axw8eBDHjx/Hixcv8P79e6SmpsLe3h7ly5dHmzZt4O3tjcaNGxd2qpQHL1++RKtWrfD27dvCToWIiOizl5KSgn379uHIkSO4desWQkNDERcXByMjI1hbW8PJyQk1a9ZE48aN4eXlBRcXF41teXh44Ny5c2pvMzQ0hKWlJSwsLGBvb4/69eujYcOG8PLyQoUKFXKUc16KYRs3buQPb1QisBBERMXahw8fMHv2bKxZswbp6emZbn/9+jVev36NS5cuYe7cuejQoQP+/PNP1KpVqxCypbwaPXq0WAQyNjZG27ZtUb58echkMgBA9erVCzM9IiKiz8aRI0cwZswYtT+uJCQkICEhAcHBwbh48SLWrl0LAJgwYQL+/vvvHO8rJSUFYWFhCAsLw/Pnz3Hx4kUAgFQqRceOHTFp0iR06NAhbwdERCIWgoio2Hr48CE6duyIoKAg8W96enpo2rQpnJycYGhoiHfv3uHSpUuIjY0FAJw8eRJ+fn7Yvn07evfuXVipUy6EhITgzJkzAD7+snj37l24uroWclZERESfn40bN2LEiBEQBEH8m6urK+rUqQNbW1uxcHPnzh2EhISIMVFRUVq137hxY7i5uYnXFQoFYmJiEB0djQcPHuD169fi348dO4Zjx45h6NChWLZsGczNzbU+jh49eqB8+fJax/MHIyopWAgiomLp4cOHaNmypfiBRF9fH9OnT8e0adNga2urEpuSkoJdu3Zh+vTpCA8PR0pKCvr164ctW7Zg4MCBhZE+5cLt27fFy+7u7iwCERER5cLz588xduxYsQjUrFkzrFixAvXr11cb//TpU+zbtw///POP1vvo3LkzZs+erfH20NBQbN26FcuWLUNwcDCAj/MmPnjwAOfOnYOxsbFW+5k8eTI8PDy0zouopOBk0URU7CQnJ6N///5iEcjExASnT5/Gb7/9lqkIBHzsPTJ48GCVHiQKhQKjR4/G06dPdZo75Z7yr5Bly5YtxEyIiIg+X0uWLEFKSgoAoGbNmvDx8dFYBAKAKlWq4Pvvv8fTp0/xv//9L19yKFOmDL799ls8evQIffv2Ff9+/fp1zuFDlA9YCCKiYue3335DQECAeH3r1q1o3bp1ttuVLVsWp0+fFrscJyQk4Ouvvy6wPCl/paWliZelUr69ERER5capU6fEyxMnToSRkZFW20kkkiwni84NMzMz7Nq1C126dBH/tnv3bvj7++frfohKGn5SJqJiJTExEStWrBCv9+jRA7169dJ6eycnJ8yZM0e87u/vj2vXrqmN9fDwEJdq9fPzA/BxnprffvsNbm5uKFOmDGQyGaysrNRuf+vWLXz99ddwcXGBsbExSpcuDTc3NyxYsACRkZFa5/yp69evY8qUKahXrx5Kly4NAwMDlClTBq1bt8b8+fO1Gr/v7OwsHltgYCAA4MWLF5g1axbq16+P0qVLQyqVol69ernOM0N8fDyWLVuGDh06wMHBQVyJpFatWpgwYQKuXr2qcVs/Pz8xz2HDhol/37x5s8pSuhKJJN9+QTx+/DhGjx6NWrVqwdbWFvr6+rCyskKDBg0wevRoHDp0SO3E5MoEQcCePXvw5ZdfolKlSjAzM4OZmRkqVaqEAQMGYO/evSrzMmii7hyMjIzE/Pnz0bhxY5QqVQrGxsZwcXHBiBEjVAqkn1q8eLHYVk4m5Lx06ZK4nY2NjfgrsjoJCQlYtWoVvL294eTkBBMTE5ibm8PV1RXDhw+Hj49PtvvbtGlTpsdULpfj33//Rffu3cXnk0QiwYEDBzJtHxUVhXnz5qFRo0awtraGmZkZqlatipEjR+L69etiXE6XYo6IiMCiRYvQrl07ODo6wsjICFZWVqhRowbGjx+PGzduZNvG7NmzxX1mDJlIT0/Hli1bxInPDQ0NUbZsWfTo0QNHjhzRKjdl9+7dw/fff48mTZqgTJkyMDAwEO+D/v37459//kFMTIxOjje3Hjx4gG+//Rb169dHqVKlYGhoiHLlysHDwwPz589HREREtm2oO48AYP/+/fD29kaFChVgaGgIOzs7tG/fHtu2bdPqOZkTSUlJOHDgACZNmoSWLVvC3t5efDycnZ3Rs2dP/PPPP0hNTc3X/Sq7du0aRowYkel9SPl+1HRfZaWg3oeCg4Px008/oW7durCysoKpqSmqVauGiRMninPKaCstLQ1bt25Fv3794OLiAnNzc5iamqJixYr48ssvsX///mwfc+X3IOXhR8eOHcOXX34JV1dXmJmZQSKR4K+//sq0/5MnT2LGjBnw9PREuXLlYGRkBGNjYzg4OKBTp07466+/EB8fn6Pjyg/Kk0M7OTnpfP+fkkgk2LJli8rcQL/++mshZkRUDAhERMXIpk2bBADiP39//xy3ERcXJ5iZmYltDB06VG1c69atxRhfX1/hwIEDgrW1tcr+AQiWlpaZtp01a5Ygk8kyxWb8c3BwEC5fvixs3LhR/NuQIUOyzDsyMlLo3bu3xjYz/llZWQl79uzJsi0nJycx/tWrV8KaNWsEIyOjTG3VrVtXy3tVvcOHDwtlypTJNucBAwYICQkJmbb39fXNdltt77/sBAQECI0aNdJqX/3799fYztOnT4X69etn20bDhg2FFy9eZJnTp+fghQsXhPLly2tsUyaTCWvXrlXb1rt378RzUiaTCSEhIVrdL2PHjhXbHzVqlMa43bt3a/VYd+3aVYiOjtbYzqfPibdv3wotW7ZU29b+/ftVtvXx8RHs7e017lsqlQqzZ88WBEFQ+Xt2li9fLlhaWmZ5XBKJRBg+fLiQkpKisZ1ffvlFjP/ll1+E4OBgoXnz5lm2O2zYMEEul2ebY1RUlNC/f39BIpFk+xjY29vr5HhzKi0tTZg4cWKWr50Zr3GbNm3Ksq1Pz6Po6GihW7duWbbbsWNHITExMV+O5cqVKyrvM1n9c3Z2Fm7dupUv+82gUCiE6dOnC1KpVON+y5cvX6Teh/bv35/leWdsbCwcOXJEq+P39fUVKlWqlG2eTZs2FYKDg7NsJyO2devWQnR0tNCzZ0+1bS1ZskTc7s2bN4Ktra1Wj7+tra1w6tSpbI8pJ69Z2TE1NRXbWrlyZZ7by6D8nvXLL7/kePuJEyeqvMZERESojVO+L3x9ffOWNFExxcmiiahY8fX1FS87OjrC3d09x22YmZmhe/fu2L59OwCIPS2ycunSJcyePRtpaWmwtbVFq1atUKpUKYSFhalMYgwAM2fOxO+//y5eNzExgZeXF8qWLYvQ0FD4+PggODgYnTt3xjfffKNVzqGhofDy8sKjR4/Ev9WsWRN169aFmZkZwsLCcP78eURERCA6Ohr9+vXD1q1btZoMe8+ePZgxYwYAoFy5cmjRogUsLS3x7t27PPVc2rVrFwYOHAi5XA4AkMlkaNmyJSpXroz4+HicP38e7969AwDs2LEDr169go+Pj0oX9fLly2P8+PEAgMePH+Ps2bMAgGrVqqFNmzYq+2vatGmuc/Xz80O3bt0QFxcn/q1ChQpwc3ODjY0NEhIS8OTJE9y9exdpaWlITk5W286jR4/QunVrhIeHi3+rXbs26tWrB4lEgtu3b+P+/fsAgJs3b6J58+bw9/dHlSpVss0xICAAP/zwA+Lj42FnZwd3d3fY2tri7du38PHxQVJSEuRyOcaMGYPatWtnuj/Kli0LLy8vnD59Wuxhk935l5aWht27d4vXBw0apDZuyZIlmDZtmvjruoWFBZo1awYHBwfI5XI8ePAAN27cgCAIOHLkCDw8PHDx4kWYmJhkuf+UlBR069YNN2/ehJ6eHpo3b45KlSohJSUFt27dUom9cuUKunbtisTERAAff2Fu3LgxatasidTUVFy7dg3Pnj3D7NmzUapUqSz3q+ybb77B0qVLxeulSpVCs2bNUKZMGSQnJ+P27dsICAiAIAjYsGED3r17h6NHj2Y7fDE+Ph4dO3ZEQEAATExM4O7uDkdHR8TFxcHX1xdhYWEAPq7sU7VqVXz33Xca23r37h28vLzw5MkT8W9WVlZo0aIFypYti7S0NLx58wY3b95EbGysxvO3II83OwqFAr1798ahQ4fEv9nY2MDDwwM2NjYICgqCr68vUlNTER0djaFDhyI6OhqTJ0/Otu309HT07t0bZ8+ehYGBgXgeJScn4/z583jz5g0A4MSJE5g6dSpWrVqVp2MBPvZMy+jpYWdnh5o1a8LBwQGmpqZITEzE8+fPce3aNaSnpyMwMBCtW7fGrVu3ULly5TzvGwCmTZuGJUuWiNfNzMzg6emJMmXK4P379/D19cXbt2/RpUuXIvE+dObMGYwZMwZyuRwVKlRAs2bNYGFhgVevXsHPzw/p6elISkpCv379EBAQgIoVK2psa8+ePRg4cKA4nNjY2BhNmzaFs7MzpFIpnj59isuXLyM9PR1XrlxBs2bNcP36ddjb22eZoyAI+Oqrr3DkyBFIJBI0atQINWrUgCAICAgIUOldmJCQIPa4sra2Rs2aNeHk5AQzMzOkpqbi1atXuHLlCpKTkxEREYHOnTvj3LlzaN68ebb3VX6oVKkS7t27BwD4+++/MWTIkGxfj3Whb9++4tL0giDgwoUL6NatWyFnRfSZKsQiFBFRvlP+ha9Pnz65bmfZsmUqvyip+0VQ+ZctPT09QSKRCHPnzhVSU1NV4pKTk8XL586dU/lFvk+fPkJkZKRKfHR0tPDFF18IAAQDA4Nsf4mVy+WCp6enGOfm5qb21+OkpCRh9uzZ4v5NTU2Fly9fqm1T+ZdYPT09wcDAQFi7dq2gUCg0HltOPH/+XOXXcDc3N+HZs2eZjmvRokUqv1hPnDhRY5s5+dU6J968eSOUKlVKbLtixYrC8ePH1cZGRkYKq1evFqZPn57ptpSUFKFu3bpiO3Z2dsLp06czxZ08eVJlfw0aNMh0TmVQPgcNDQ0FmUwmLFq0SEhLS8t0DLVq1RJjPT091ba3efNmMaZhw4bZ3TXCoUOHVO6XT88PQRCEM2fOiI+hgYGB8Mcff6jt3XX79m2hRo0aYntjx45Vu0/lx1lPT0/8Jf7Vq1eZYjPOz6SkJMHV1VUl12vXrmWK37Vrl2BiYiIYGhpq9ev6P//8I8ZYWFgI69atU/tY+fj4qPTUmj9/vtr2lHsEZeQwZMiQTL96JyQkCF9++aUYa2ZmJsTHx6ttMy0tTWjRooVKr4nly5erzTMlJUU4dOiQ0KNHD50cb07Mnz9f5TH5/vvvM/U2CgkJEdq3b69yfly5ckVte8rnUcZ93alTp0yv9WlpacL06dNVeiGoO9dy6sqVK8LMmTOF+/fva4x5//69MGjQIHHfbdq0yfN+BeHjc1L5vhw4cKAQExOjEhMXFycMHTpU5f4pzPchQ0NDwdTUVNi6dWum15mAgACV823YsGEajz0gIEAwNjYWH8vp06cLUVFRmeJevHih0tOwU6dOattT7hGU8XpUu3Zt4d69e5lild8vAwMDhYkTJwpXr17V2KMvJiZGmDZtmth+lSpVsuz9p81rlrZ+/vlnlfZq1aolbNu2TYiNjc1Tu3ntEZSQkKDSI/CHH35QG6ecO3sEEanHQhARFSsZH8QAiEM8csPHx0flg8SFCxcyxSh/oAEgzJs3L9t2lYd5tGnTRkhPT1cbJ5fLVb7QZPUBfMuWLWJM06ZNsx26oPxlc8yYMWpjlD+AAxC2bduW7bHlxODBg8W2K1eunOVQoMWLF4uxUqlU45eGgioEDRw4UGzXyclJCA0NzVU7GzZsENvR19fPcqjHtWvXVM7lzZs3q4379Bxcs2aNxjbv378vfvmSSCTCu3fvMsXExcUJJiYmYnuPHz/O8pj69+8vxv7444+ZbpfL5SoFmH379mXZXkhIiDh0S19fXwgKCsoUo/w4Z3zpyu6cX7VqlRhvYmIiPH/+XGPsvn37VNrX9KUqNjZWsLKyEgtcmgoOGR4+fCgOr7S1tVVbDFN+bgIQvvzyS43tJSUlCY6OjmLsv//+qzZu3bp1KuddbobLCkLBHK+2YmJiVArH6gqtGZKTk4XGjRuLsZqKnp+eR+7u7pkKqBkUCoVKm3/88UeujyU3OnXqJO774cOHeW6vSZMmKgUOTcUFhUIhdO/evUi8D0kkEo0FeEEQhCNHjoixZmZmGh9LLy8vMW7x4sVZ5hgfH69SnFZ3zn86PLlMmTJCeHh4lu3m1JgxY8T2jx07pjEuPwtBkZGRmT4HZLyGNG7cWBg3bpywYcOGbN8jPpXXQpAgCIKzs7PYxvDhw9XGKOfco0cPYfz48Vr9+/nnn3OVE9HniIUgIio2YmJiVN78ly5dmuu2bt++rdLWoUOHMsUof6ApV66cxg+eGR4+fKjS5oMHD7KMf/LkiUrvIU0fwOvVqyfG3LlzJ9tjS0pKEr/QWVpaqv0SoPwB0M3NLds2cyIqKkrlF+bsigNyuVyoWbOmGP/999+rjSuIQlBwcLBKQSarLyLZUf7yNWnSpGzjlefeadq0qdoY5XOwdu3a2bbp5uaW5TktCIIwYMCALIs7GWJjY8Vf1jUVjQ4cOKDyYVwbv//+u7jNokWLMt3+6Rf4rL4YZVCe22nmzJnZxiv3bND0peqvv/4Sb//mm2+yPzBBEEaPHi1u899//2W6XfnLsYGBQbbzNM2YMUOMnzp1qtqYatWqiTHfffedVnmqUxDHqy3lQp69vX22RYarV6+qPH7qzs1Pz6ObN29m2ebKlSvF2F69euX6WHJj165d4r6XLVuWp7YCAgJUjvvJkydZxr969UqlV2ZhvQ95e3tn2Z5CoVCZg0xdj5w7d+6It9evX19tD8ZP7dy5U9xGXY/UTwtB+TmfTgbl81nT81wQ8rcQJAgfP4Mov/dq+lemTBlh8uTJwtOnT7NtMz8KQco9a3v27Kk2JrucNf1zcnLKVU5EnyOuGkZExYby/C0AYGpqmuu2zMzMVK7HxsZmGd+nTx/o6WU97Zry/EUNGzZEjRo1soyvUqVKtvPahISE4M6dOwCAGjVqoG7dulnGA4CRkRGaNWsGAIiJiclyJSkA+OKLL7JtMycuXbokrixVqlQpeHt7ZxkvlUoxfPhw8bry/VjQzpw5I64A5urqio4dO+aqnbi4OJVVlJSPR5ORI0eKl69fv46EhIQs4/v27Zttm/Xr1xcvZ6zC86mvvvpKvLxjxw6Nbe3btw9JSUkAgEaNGqFq1aqZYo4dOyZeHjBgQLb5AYCXl5d4+cKFC1nGWltbo3379lnGxMXFqcwXpHx8mmgTU9DH1rJlS5QpUybLmOwez9evX+Px48fi9QkTJmiVpzoFfbxZUV5N7ssvv4SxsXGW8W5ubqhdu7Z4PbvXDBcXFzRo0CDLGG2eO7mVmJgIHx8fLF26FD/++CMmT56MCRMmiP927twpxma83ueW8px3bm5u2c4/5uzsjBYtWmQZo4v3oexe3yQSicp+1T1Gyufwl19+qdWKgDk9h/v3759tzKfS0tJw4cIFrFixAj/99BO++eYblcdfeSXUvD7+OVGlShXcunULK1euVHk+fSo0NBRLly5FzZo1MXPmTCgUigLNS/nz2aef+4hIe5wsmoiKDeVlRQFk+8U5K58u12phYZFlfMOGDbNtU3nS6IwPwNlp1qwZLl++rPF25duSkpK0/qL34sUL8XJQUBDq1KmjMVabY8sJ5fvBzc0t2wIaAJUvIrdv34YgCFov650XV65cES8rLw2cU/fu3RMnxTYzM8vy/s5Qr149mJqaIiEhAXK5HHfv3s1yotCsPqhnsLW1FS9rKm62a9cOdnZ2CAsLw8uXL3Hp0iW1+922bZt4WdMk0crn53///Ydz585lm6Py0uVBQUFZxtarVw8ymSzLmHv37olfTCwsLFCtWrVsc2jSpEm2McrHtnbtWmzevDnbbYKDg8XL2R1bfjyeyuevq6srHBwcsm1Tk4I+3qwov2ZoO1luixYtxInXP504/FP59dzJqcjISPz888/YsmWL1l9oP3z4kKd9KhcStDnPM+LOnz+v8XZdvA/lx2OknKevr69Wy80LSsvHZ3cOV6xYETY2Ntm2mSEpKQm//fYbVq9erfXjmtfHP6cMDAwwduxYjB07Fs+ePcO5c+dw+fJlcWGDjB9KgI8Frd9//x1v377V6vUht5SfK9l9NgM+PtZ5ef8mKq5YCCKiYsPCwgJ6enriB5O8rGgVFRWlcj27D3elS5fOtk3llaIqVKigVR7ZxWWsqgUAr169UvnlUFufHuuntDm2nFC+H5ycnLTaxtnZWbycmpqKuLg4rT4A5tX79+/Fyy4uLrluR/mYHR0dtSpiSaVSODo6ij06svsCYGlpmW2b+vr64uWMFXM+paenhy+++ALLli0DAGzfvj3Tl++QkBCxl0ZGvDrK5+euXbuyze9T+XFuKt/3Dg4OWt332RVM4uPjVb6MrF+/Pts2P5XdseXH45lf568ujjcreX3N0NVzJydev36NVq1aiSuSaSuvPSA+fS3SRnbPB128D+XHY6Sc5/Hjx3OQ3Uf5+V4ZFRUFLy+vHPfwKcweMK6urnB1dRV7q8bHx+P06dNYuXIlzpw5I8Zt2bIF3bt3R69evQokD+UfC3JSeCMiVRwaRkTFivKXhOy6mmfl022Vv1Sok91QBUC1l5G2y7BmN7xN+QNRbin/oqeONseWE8r3g7bD9z6N09WHYeX9fDpcMCdyc8yfxmZ3zPnZQ0q5h8/u3bszfanauXOn2Mumffv2sLOzU9tOXs/P/Dg3c/O8y+6x1sXzLj8ez/w6f3VxvFnJ62uGLp872howYIBYBDI3N8eUKVNw4sQJvHz5EvHx8ZDL5RA+zuWpMrQtr8NuSvLzIa95ZvTq1CQn75Xjx48Xi0AGBgYYOXIkDh48iKdPnyIuLg7p6eni4//q1Stxu4IedpUTZmZm6NmzJ06fPo2lS5eq3Pbp9fySkJCg0tMwu+GzRKQZewQRUbHSokULsbv51atXc92O8rbOzs4oX758nnNT/jCdmJio1TbZDW9T/rLTrVs3HDx4MHfJ6ZDy/aDt8L1P4z4dBlhQlPfz6XDBnMjNMX8aq6tjBj7O+VOtWjU8fvwYHz58wMmTJ9G1a1fx9u3bt4uXs5pPx9TUVPzydevWLZV5VnSloJ93wMfeh9bW1jlProDl1/lb2MdrZmYmnke5ec3Q5XNHG5cuXcKlS5cAfDy2K1euZDlnXH4Wvkvy+5Bynvv27UPPnj0LJY+3b9/i33//BfCx5+eJEyfg6empMf5zmAdn0qRJOHr0KE6dOgXg4zC89PR0rYZ+58SNGzdUCnLZzaNIRJqxRxARFSvKH6aCg4Ph7++f4zbi4+NVPshm9QEtJ5S7jWs7HCC7OQns7e3Fy6GhoblLTMdycz8oT/xpYGCgsy92yvev8q+yOaV8zMHBwSrzTmiiUChUHv9SpUrlev+5MXDgQPGy8nxAjx49EudcMTc3R48ePTS2URTOT+X77e3bt1pto/yLszpWVlYwNDQUrxfV515+nb+Ffbx5fc3Q9XMnO2fPnhUvDxkyJNuFA7SZy0ZbyvdFdue5tnFF4XmujaKSp4+Pj/ge0KlTp2w/Y+Tn41+QlBdTSEtLQ0RERL7vY8+ePeJlqVSKli1b5vs+iEoKFoKIqFjp27evyi/VixcvznEb69atU/n1fMyYMfmSm3JvCOVJXLOS1UTRgOpkn3fu3MnTBNm6onw/XLt2Ldvu9gDEX88zttfVUA7lXxvzslpZnTp1xEmN4+LixElss3L37l3x8ZTJZFqtxJOfBg4cKN7Phw4dEn+VVu4N1KtXryyHQyifnxcvXiygTLNWp04dSKUfP+7ExMSorKKlybVr17KNcXNzEy8X1rFlR/n8ffr0qdZf/NUpzONVfs1Qfi3IinJcdiuC6ZryXDXaTIKcmx80NKlXr554Wdtes9k9Hz6X96Gi8HoEFO7jX5CMjIxUrisXj/NDRESEyiTUHTt21GruKCJSj4UgIipWTE1NMW7cOPH6wYMHsX//fq23f/36NX7++WfxeqtWrVS+AOWF8q9+N27cyPYL6fPnz7MtBLm4uKB69eoAPk6i/M8//+Q90QLWvHlz8QNieHg4jh49mmW8QqHAxo0bxevKy/kWtHbt2old2589e4aTJ0/mqh1zc3M0atRIvL5p06Zst1F+LN3c3HI0t1B+qFixojhJdFJSEvbt2wdBEFSWlM9umXXl4WQbNmxAcnJywSSbBQsLC5VCgnIhSxPlHlCaKB/bqlWrtOrlpWtOTk7i6wOAXE3im6Ewj1f5Of/vv/9mex7duHED9+7dE6/nV6/O/JJRmASyH5717t27fB1qpbx60rVr1/D8+fMs49+8eZPlimHA5/M+pHwO79u3T2UydV3KyeOfmJiILVu2FHRK+eLu3bviZXNzc1hZWeVb24IgYMiQISo/0v3444/51j5RScRCEBEVOzNnzlTpav/VV19p9YtaaGgo2rdvL37QMDU1xbp16/Itr+rVq6usvjR58mSNEz8qFApMmjRJqy9b3333nXj5xx9/1Kq3SYbC6B5vZWWF/v37i9e//fbbLOdAWL58uXhMUqkUo0aNKvAcM5QrV04l19GjR+f6y8Po0aPFyytWrFD5ovqpmzdvYs2aNeL1/OqVllPKk0Zv374dly5dEocYlS9fPtuiXO/evVG5cmUAH1caGzdunNYFhPj4+HzrWTB8+HDx8l9//ZXlMKlDhw6pDN3RZPTo0eIXnVu3bmHOnDla5/PhwwetesLlh6lTp4qXFy1alO2Xek0K83gHDBggzm0TEhKS5b5TU1MxceJE8bqnpyeqVq2a630XBOUV3A4dOqQxTi6XY9SoUUhNTc23fdeqVQuNGzcG8PHL9TfffJPlc3LKlClaTVD8ObwPubm5iYWwpKQkDBo0SOv7NjU1NU8r3ylTfvyPHTuW5XNj2rRphVKwmjJlCp48eaJ1/JMnT7B161bxuvIwsbyKj4/HF198ofKj0aBBg9CsWbN82wdRiSQQERVD9+/fFywtLQUAAgBBX19fmDlzpvDhw4dMsSkpKcLmzZsFOzs7MV4qlQrbtm3Lch+tW7cW4319fbXKy9fXV5BIJOJ2/fv3F6KiolRiYmJihAEDBggABAMDAzF2yJAhattMT08XvLy8xDgLCwth9erVQkpKitr4mJgYYdu2bULr1q2FPn36qI1xcnIS23v16pVWx5YTz58/F8zMzMR9NGvWTHjx4oVKjFwuF/766y9BJpOJcRMnTtTY5saNG7O9r3LjzZs3go2Njdh2xYoVhRMnTqiNjYqKEtasWSN8++23mW5LSUkR6tatK7ZTpkwZwcfHJ1Pc6dOnhdKlS4txDRo0EFJTU9XuL6fn4C+//CLG//LLL9nGR0ZGiuegTCYTevToIW4/ffr0bLfPOB7lx7BTp07Cw4cPNcbfvn1bmDFjhmBlZSXcv38/0+25eZwTExOFypUri9u5uLgIN27cyBS3Z88ewdTUVDA0NBRjs/qopJwLAGHw4MHC69ev1cYqFArhwoULwtixYwVjY2MhLi4uU0xOHx9fX18xvnXr1mpj0tLShObNm4txJiYmwooVK9SeUykpKcKhQ4eEHj166OR4c2L+/Pkq+/7xxx8zvcaFhoYKHTt2FGP09PSEK1euZHss2pxHr169EuOdnJzydCyPHj1SeR+YNm2akJiYqBITEhIidO/eXQAgmJqaZvs458SpU6dU7stBgwYJMTExKjFxcXHCiBEjBAAqz4ei/D40ZMgQMX7jxo1qY+7fv6/y3tOkSRON54ggCMKTJ0+E//3vf0LZsmWFw4cPZ7pdm+fgpyIjIwUTExNxu4EDB6r9HPD1119nevyzOve0ec3Slr29vSCTyYRu3boJe/fuFeLj49XGpaWlCbt27RLs7e1VPj+pe30VBNX3rOxe40JCQoSFCxcKjo6OKsfWvHlzITk5OcttleO1/XxGVNJw1TAiKpZq1aqFCxcuoFOnTggODkZaWhp+++03LFiwAM2aNYOTkxMMDAwQEhKCS5cuqSwra2hoiK1bt6Jv3775npeHhwemT5+OhQsXAgB27dqFI0eOwMvLC2XKlMH79+/h4+OD+Ph4WFtbY/LkyZg9e3aWbcpkMuzevRvt2rXD7du3ERsbizFjxmDGjBlo1qwZypcvD5lMhqioKDx58gSPHj0Sl+rt3bt3vh+jNipVqoT169dj4MCBkMvluHz5MqpWrQp3d3dUqlQJ8fHxOH/+vMrkvk2bNsWCBQt0nqujoyN2796NHj16ID4+Hq9evULHjh3h5OQENzc32NjYID4+Hk+fPsWdO3eQlpaG7t27Z2rHwMAAO3fuROvWrREeHo7Q0FB4eXmhbt264rwdd+7cUeleb2dnh507d0JfX19Xh6vC2toanTt3xoEDByCXy3HgwAHxNuXeQllp27YtVq1ahbFjx0Iul+P48eM4ceIEatSogTp16sDCwgKJiYkICQnB3bt3ER4enu/HYWxsjE2bNqFdu3ZISkrCy5cv0bhxYzRp0gQ1atRAamoqrl27hqdPnwL42AttwoQJALJetnro0KF4+fIl5s6dCwDYsmULtm/fjnr16qFatWowMzNDfHw8goODcefOnXxZZjun9PT0sGvXLnh5eeHZs2dITEzE+PHjMWvWLLRo0QJly5ZFeno6Xr9+jZs3byI2NlbjvBuFebzTp0/HhQsXcPjwYQDAvHnzsGrVKnh6esLa2hpBQUHw9fVFSkqKuM3ChQtV5oUpKqpVq4ZBgwaJQ34WLVqEHTt2oHHjxrCzs0NgYCD8/f2RmpoKc3NzLFy4MF97BbZr1w6TJk3CsmXLAABbt27FgQMH4OnpCXt7e4SFhcHX1xexsbGwsbHBN998Iw6ZVh7WpOxzeR+qVasWdu7cif79+yMxMRFXr15F06ZNUalSJTRo0AA2NjZITk5GWFgY7t27p/UE8zlhbW2N6dOn43//+x+Aj70tjx8/jiZNmqB8+fIICQmBn58fEhISoKenh5UrV2LIkCH5nkd25HI5Dh06hEOHDkEmk6FOnTpwcXGBra0t0tPT8e7dO1y/fj3TpNALFixAw4YNs23/2LFj+PDhg3hdoVAgNjYW0dHRePjwodqem19//TWWLFmSo/mHli5dir1792odX7duXXz99ddaxxN9tgq7EkVEVJDev38vjBkzRtDT01P5hUjTv/bt2wv37t3Tqu3c9AjK8P333wtSqVRjHuXKlRMuXbqUo1+tExMTc3SsxsbGwm+//aa2rYLuEZTh8OHDKr8kavr35ZdfCgkJCVm2VVA9gjLcuXNHpUdPVv8GDhyosZ0nT54I9evXz7aNBg0aCM+fP88yp4LuESQIgrB3795MudWuXVurbZX5+PgIrq6uWt1/AISaNWsKb9++zdROXh7nT3tbffpPKpUKs2fPFlJTU8W/WVpaZtvurl27hHLlyml9bG5ubmp/0S6IHkEZIiIihJ49e2qVX/ny5XVyvDmVlpYmTJgwQaWHmbp/lpaWGnuEZCjMHkGCIAgJCQlC+/btszwOBwcH4cKFC7nqdZIdhUIhTJkyRaVnkrr3ocuXLwtr164V/zZ58uQs2y3M9yFtegRluHPnjtCwYUOtz2FnZ2fh9u3bmdrJ7WOTnp4uDB48OMt9WllZCfv379f63FPeNq/GjRsnWFtba33/ZLxu7NmzJ8t2ld+ztP0nk8mErl27CmfOnNE6/5zuQ/lf9+7d83jvEX0e2COIiIo1Ozs7rFq1Ct999x0OHDiAEydO4Pnz5wgLC0NaWhpKly4NBwcHeHl5oXv37uLcCQXt999/R58+fbBy5Ur4+PggJCQEZmZmcHZ2Rq9evTBq1CiUKlUqR2P0jY2NxWPdtm0bfHx88PTpU0REREChUMDS0hIuLi6oW7cu2rRpg44dO8LCwqIAjzJ7Xbt2xfPnz7FhwwYcOXIEDx48wIcPH2BsbIxy5crB09MTgwcPLhK/6tetWxe3b9/GgQMHcODAAVy+fBnv379HQkICLCws4OLiAjc3N3h7e6NDhw4a26lSpQpu3LiBvXv34r///sO1a9cQFhYG4OP52qRJE/Tp0we9e/fW2epoWenatSusrKwQHR0t/i27SaLV8fT0xKNHj3DgwAEcPXoUV65cQWhoKGJjY2FiYgJ7e3tUq1YNzZs3R6dOnVRWN8ovbdu2xePHj7F8+XIcOHAAL1++RFpaGsqXL49WrVph9OjRaNy4scqcHNpMeNqvXz90794d//77L06ePInr168jPDwc8fHxMDU1Rfny5VG9enW4u7ujc+fOqFKlSr4fW3ZsbGywb98+XL9+HTt27ICfnx+Cg4MRFRUFY2NjODg4oF69eujYsSP69OmTZVuFdbx6enr4+++/MWbMGGzYsAFnz55FUFAQ4uLiYGNjgypVqqBz5874+uuvYWtrmy/7LCgmJiY4fvw4duzYgc2bN4u9aEqVKgUXFxf07t0bQ4cOhbW1Nfz8/PJ9/xKJBIsXL0b//v2xevVq+Pn5ie9DFStWRO/evcX78dy5c+J22T0fPpf3obp16+LGjRs4deoUDhw4gIsXL+Ldu3eIjo6GoaEhSpcujapVq6JJkybo0KEDmjVrlq+vxzKZDJs3b0bfvn2xdu1aXL16FVFRUbC2tkaFChXQvXt3DB8+HOXKlUNgYGC+7VdbK1aswNKlS3H58mWcP38eN27cwJMnTxASEoK4uDjo6enBwsICTk5OqFu3Lrp06YLOnTvDwMAg1/s0MDCAhYUFLC0tUaZMGdSvXx8NGzZE27Zt4eDgkI9HR0QAIBGEIrjMBREREVEhOX36NNq3bw/g46Snx48fL+SMiArPwIEDxdUC//33X5UJ9ImI6PPEVcOIiIiIlOzatUu8rKtegkRFUXx8vMpqTXw+EBEVDywEEREREf2fq1evipP4Ah+XLicqqWbOnClO+t2kSROVpc+JiOjzxUIQERERFXtv3rxB3759ceHCBagbFS+Xy7Ft2zZ06NABaWlpAIBu3bqhWrVquk6VqMAtX74cc+fORXBwsNrbw8LCMGrUKPz999/i37777jtdpUdERAWMcwQRERFRsRcYGIiKFSsC+Dgpd8OGDVG2bFnIZDK8f/8ely9fVlm6vmzZsrh58ybKli1bWCkTFZjZs2djzpw5kEgkqFGjBmrWrAlra2skJyfj+fPnuH79OlJTU8X4IUOGYNOmTYWXMBER5SuuGkZEREQlSlhYWJYTQDdq1Ah79+5lEYiKPUEQ8ODBAzx48EDt7Xp6epg8eTIWLFig48yIiKggsUcQERERlQjXrl3D4cOHceXKFQQHB+PDhw+Ijo6GmZkZ7O3t0axZM/Tq1Qve3t6FnSpRgUpISMDRo0dx5swZ3Lt3D2FhYfjw4QOSk5NhY2MDFxcXeHh4YPjw4ahcuXJhp0tERPmMhSAiIiIiIiIiohKCk0UTEREREREREZUQLAQREREREREREZUQLAQREREREREREZUQLAQREREREREREZUQLAQREREREREREZUQLAQREREREREREZUQeoWdAH1+kpOTcf/+fQBA6dKloafH04iIiIiIiIgov6WnpyM8PBwAULt2bRgZGeW5TX6Dpxy7f/8+3NzcCjsNIiIiIiIiohLj2rVraNy4cZ7b4dAwIiIiIiIiIqISgj2CKMdKly4tXr527RrKli1biNkQERERERERFU8hISHiiBzl7+J5wUIQ5ZjynEBly5aFg4NDIWZDREREREREVPzl1/y8LARRtmrWrKlyPS0trZAyISIiIiIiIqK84BxBREREREREREQlBHsEUbYePHigcj04OBiOjo6FlA0RERERERER5RZ7BBERERERERERlRAsBBERERERERERlRAsBBERERERERERlRAsBBERERERERERlRAsBBERERERERERlRAsBBERERERERERlRAsBBERERERERERlRB6hZ0AERERERFRXiUnJyM6OhqJiYmQy+WFnQ4REWQyGQwMDGBhYQEzMzNIpUWjLw4LQURERERE9NkSBAEhISGIiYkp7FSIiFSkp6cjJSUFcXFxkEgkKF++PMzNzQs7LRaCiIiIiIjo8xUREZGpCKSnx685RFT45HI5BEEA8LFo/fbt2yJRDOIrJBERERERfZZSU1MRHh4uXrezs4OVlRVkMlkhZkVE9JEgCEhMTERkZCTi4+PFYlCVKlUKdZhY0RigRkRERERElEPx8fHiZVtbW9ja2rIIRERFhkQigampKRwcHGBmZgbgY3FI+bWrMLAQREREREREn6WEhATxsoWFRSFmQkSkmUQigY2NjXg9Nja2ELNhIYiIiIiIiD5TqampAD5+yTI0NCzkbIiINDMxMYFEIgHw/1+7CgsLQURERERE9FlSKBQAPi7RnPEFi4ioKJJIJOLQVblcXqi5sBBERERERERERFRCsBBERERERERERFRCsBBERERERERERFRCsBBERERERERERFRCsBBERERERERERFRCsBBEREREREREOuPh4QGJRAIPD4/CToWoRGIhiIiIiIiIiHIkISEBq1evRufOnVG+fHkYGRnB0NAQpUuXRuPGjTF8+HCsW7cOQUFBhZ1qvvLz84NEIlH7z8TEBI6OjujatSs2bNiAlJSUbNvL2Da7olh6ejr69+8vxjdt2hTR0dH5c1BU4ugVdgJERERERET0+bh8+TK++OILvHnzJtNtHz58wIcPH3Djxg1s3LgR9vb2CA0NLYQsdS8pKQnBwcEIDg7G0aNHsXjxYhw5cgTOzs55ajctLQ39+/fH/v37AQAtW7bEsWPHYG5ung9ZU0nEQhARERERERFp5enTp+jQoQPi4uIAAN26dUOfPn1QpUoVGBgY4MOHD7h79y5Onz4NX1/fQs62YI0dOxbjxo0Tr4eFhSEgIAALFy5EcHAwHjx4gG7duuH27duQyWS52kdKSgr69OmDI0eOAPg4rO7IkSMwNTXNl2OgkomFICIiIiIiItLKrFmzxCLQxo0bMXTo0Ewx7dq1w/Tp0xEeHo7du3frOEPdsbOzQ61atVT+5uXlhWHDhqFOnToIDAzE/fv3sX//fvTp0yfH7ScnJ6NHjx44efIkgI/368GDB2FsbJwv+VPJxTmCqEQIeBuDGXvv4m10UmGnQkRERET0WZLL5Th69CgAoFGjRmqLQMpKly6N8ePH6yCzosXc3Bw//vijeP3MmTM5biMxMRFdu3YVi0CdOnXCoUOHWASifMFCEJUIf556gt03guG50A9zDj9AeFz2E7cREREREdH/Fx4ejqSkjz+sVq5cucD3d+HCBQwaNAjOzs4wMjKClZUV6tevjx9//BHh4eFqt/nzzz8hkUigr6+P+Pj4TLcnJyfDyMhInHT5zp07atupVq0aJBIJvvjii1zlXrt2bfFyTifMjo+PR+fOnXH27FkAH4ffHThwAEZGRrnKhehTLARRsXc9MBJ+Tz6+UaTKFdh4MRCtF/riz5NPEJOUVsjZERERERF9HgwMDMTLjx49KrD9KBQKTJgwAe7u7ti2bRtev36NlJQUxMTE4M6dO/j111/h6uqK06dPZ9q2devWAD6usnXhwoVMt1+9elVlNS8/P79MMe/fv8eTJ08AINdL3CvfV/r6+lpvFxsbi44dO+LcuXMAgD59+mDv3r0q7RHlFQtBVOz97fM8098SU+VY7vsc7vN9sNLvORJT0wshMyIiIiKiz4eNjQ2cnJwAAHfv3sX8+fOhUCjyfT/ff/89VqxYAQCoWLEiVq9ejWvXrsHX1xdTpkyBvr4+YmJi0LVrV9y9e1dl2wYNGoiraakr8nz6t+xiMgpLOaVcKNN21bCYmBi0b98eFy9eBAB8+eWX2LlzZ44KSUTa4GTRVOz92bcOVvq+wParr5EmF1Rui01Ox4ITT7DhQiAmelXGF26OMNTL3Yz+RERERFT0KBQCohJTCzsNnbI2MYBUKimQtidOnIjp06cD+FiwWb16Nbp164bmzZvDzc0NFStWzFP79+/fx6JFiwAAtWrVwvnz52FlZSXe7uHhgfbt26NLly5ITU3FqFGjcPXqVfF2mUyGli1b4vjx42qLPBk9bby9vXH48GH4+/tDoVBAKpVmirG3t0f16tVzfAxyuRwLFy4Ur2szUXRMTAzatm2LGzduAAAGDx6MjRs3quRFlF9YCKJiz87cCLO71cRI94pYdvYZ9t4MhkK1HoQP8Sn45dADrPV/iW/auqJn/fLQk/FFl4iIiOhzF5WYiobzcj5Z7+fs5o9tYWtmWCBtT5kyBQ8fPsSGDRsAAIGBgVi2bBmWLVsG4GPxxMPDAwMHDkTXrl0hkeSsILVq1Sqxl9H69etVikAZOnbsiOHDh2P9+vW4du0arl+/jsaNG4u3e3h44Pjx47h58ybi4+NhZmYG4ONS7FeuXAEAfPfddzhz5gyioqJw79491KtXT9w+o4DUqlWrHOUeHh6O+/fv4+eff8bt27cBfCwCtWzZMtttlecq+vLLL1kEogLFM4tKDAdrEyzoUxenprRGlzpl1ca8jU7Ct3vvocNf/jh2PwSKTytGREREREQlmFQqxT///INTp06hY8eO0NNT7Vvw/v177Nq1C926dYObmxtevHiRo/YzVtiqWbMmmjRpojHu66+/zrRNBk3zBF27dg1JSUmwtLRE06ZN0bRpUwCqQ8HCwsLEYV3ZzQ80Z84ccdJpiUQCOzs7tGnTBhcvXoSJiQmmTp2KHTt2ZH/QgErB7PLly3j37p1W2xHlBgtBVOJUtjPDigENcGRiS3hWLa025kV4AsZtv4VuKy7A70kYBIEFISIiIiKiDO3atcPx48cRERGBY8eOYc6cOfD29oalpaUYc+PGDbi7uyMkJESrNlNSUvDs2TMAyLIIBAD169cX584JCAhQua1hw4ZiLyDlIk/G5ZYtW0Imk4mFHuWYjGFhQO7nBwKAevXqYdKkSVrP79OyZUtxhbLAwEC0adMGoaGhud4/UVZYCKISq1Z5S2wc5oY9Y5rBzdlGbUzA21gM3Xgd/ddcwbVXkTrOkIiIiIioaLOwsECnTp3w888/49ChQ3j//j02bNgAa2trAEBISAh++uknrdqKiooSL9vZ2WUZq6+vD1tbWwBAZKTq53Q9PT20aNECgPoiT0YBKOP/jHmClGNKly6NmjVrZpnD2LFjcf/+fdy/fx+3b9/G4cOHMWTIEEilUly6dAkeHh4al7n/lFQqxdatW9GjRw8AwNOnT9GuXTtERERotT1RTnCOICrxGjvbYNfopvB/9gELTz5GwNvYTDHXAiPRb81leFQtjentq6JWeUs1LRERERFRUWNtYoCbP7Yt7DR0ytqk8JYaNzQ0xLBhw1CuXDl07NgRALBv3z6sXbs2R3Pe5HRuoU95eHjg5MmT4jxBhoaGuHz5sngb8LHXkZGRkco8QRmFIG3mB7Kzs0OtWrXE6/Xq1UPXrl3h6emJoUOHIjAwECNHjsTBgwe1yllPTw+7du1C9+7dceLECQQEBKB9+/bw8fFR6WlFlFcsBBHh4xtN6yql0cq1FE4EhOLPU0/wIjwhU5zfk3D4PQlHl9plMaVdFVS2MyuEbImIiIhIW1KppMAmTibNOnToAEdHRwQFBSEqKgoREREoXVr9tAwZMnoRAR/nGspKenq62FvGxiZz7/5P5wkyNzdHYmIiLC0tUb9+fQAfi1ZNmzaFn58f/Pz84ODggAcPHgDIfn6grAwZMgSHDx/Gf//9h0OHDsHHxwdeXl5abWtgYIB9+/ahS5cu8PX1xa1bt9CpUyecOnVKHO5GlFccGkakRCKRoFPtsjj5TSv82bcuylsZq407ej8E7Zecw7d77iI4KlHHWRIRERERFX3lypUTL2vTw8fQ0BCurq4AoLIkvDq3b99GWloaAKj0ysnQuHFjmJqaAvg4PCyjp0/G/EAZlOcJ8vf3F+cGzcv8QADw22+/ifuZOXNmjrY1NjbGoUOH0KxZMwAfJ4/29vZGUlJSnnIiysBCEJEaejIp+jR0gM/01vhf95oopeZXJIUA7LkZDM8//TD70AOExSUXQqZEREREREVPYmIiHj58CODjPEIZ8/lkp23bj8P4Hjx4gGvXrmmMW79+faZtlOnp6aF58+YAIPb4ATL39FGeJ8jHxwcAYGtrq7a4lBNVqlRBv379AHwsap0+fTpH25uZmeH48eNo2LAhgI/H0KtXL6SmpuYpLyKAhSCiLBnqyTC4mTP8Z3jgu47VYGmcedb/NLmATZcC0XqBHxaceIyYxLRCyJSIiIiIqGDFx8ejSZMmOHLkiDi5sjoKhQITJ05EXFwcAKBbt25az/kzduxYcS6hUaNGITY28/ydp06dwj///AMAcHNzQ+PGjdW2lVHkuXnzJi5evKjytwxNmjSBoaEhoqKisG3bNgAf5wfK6xxFwMeeQBntzJs3L8fbW1pa4uTJk6hduzYA4MSJE+jfvz/S09PznBuVbCwEEWnBxEAPYz0qwX+GJyZ6VYaJgSxTTFKaHCv9XqDlAh+s8H2OhBS+QBMRERFR8XLt2jV4e3ujQoUKmDBhArZv344LFy7g7t27OHfuHP766y/Uq1cPGzZsAPCxmDF37lyt269duzamTZsGALh79y4aNGiAdevW4caNGzh37hymT5+Orl27Qi6Xw8DAAGvWrNHYlvI8QZ/OD5TByMgITZs2BQDExMQAyNv8QMpq1aqFbt26AfjY4+jChQs5bsPW1hanT59G1apVAQAHDhzA4MGDsyzEEWWHk0UT5YClsT6mta+KIc2dsdL3BbZdeY1UueqLcFxyOhaefIKNF19hvGdlDGhSAYZ6mQtHRERERESfEz09PZQpUwahoaF4+/YtVqxYgRUrVmiMd3V1xc6dO+Hs7Jyj/fzxxx9ISEjAypUr8eLFC4waNSpTjKWlJXbv3o169eppbMfNzQ0mJiZITPw4p+en8wNl8PDwEOcQAvI+P5CyWbNmiauGzZ07FydPnsxxG/b29jh79izc3d3x6tUr7Ny5E8bGxli/fn2+9Fyikoc9gohyoZSZIX72rgHfbz3Qv5EjZNLML8Af4lMx5/BDeP15DruvByFdzqo9EREREX2+jIyM8PbtW1y8eBFz5sxBp06d4OLiAlNTU8hkMlhYWKBatWro378/duzYgYCAAHGOm5yQSqVYsWIF/P39MXDgQFSoUAGGhoawsLBAvXr1MHPmTDx79gzt27fPsh19fX1xwmVAc08f5b/b2NigTp06Oc5Zk8aNG6Ndu3YAPg5pu379eq7aKV++PHx8fODo6AgA2LBhAyZOnJhveVLJIhEypkUn0lJwcLD4AhQUFAQHB4dCzqjwvQiPx5LTT3HkXojGGJdSppjavgo61yoLqZrCERERERHlzLNnz5Ceng49PT1xtSkioqIqN69ZBfH9mz2CiPJBpdJmWD6gAY5OagmvanZqY15+SMCEHbfR9e8L8H0cBtZgiYiIiIiISNdYCCLKRzXLWWLD0MbYO6YZ3CraqI15GBKLYZuuo+/qy7j6MkLHGRIREREREVFJxkIQUQFo5GyDXaOaYstwN9Qub6k25sbrKPRfewWDN1zD/eAYHWdIREREREREJRFXDSMqIBKJBK2qlIa7aymcfBCKP089xfOw+Exx/k/D4f80HJ1qlcG09lVQ2c68ELIlIiIiIiKikoA9gogKmEQiQcdaZXHym1ZY1LcuHKyN1cYdDwhF+yX+mLb7LoIiE3WcJREREREREZUELAQR6YhMKkHvhg7wmeaBud1rorS5YaYYhQD8dysYXov88PPBAITFJhdCpkRERERERFRcsRBEpGMGelIMauYM/2898X2narA01s8UkyYXsOXya7Ra6Is/jj9GdGJqIWRKRERERERExQ0LQUSFxNhAhjGtK+H8d56Y5FUZJgayTDHJaQqsPvcC7vN98ffZZ0hISS+ETImIiIiIiKi44GTRlK2aNWuqXE9LSyukTIonCyN9TG1fFYObO2OV3wtsvfIaqekKlZi4lHQsOv0Umy4FYrxnZQxoUgFG+pkLR0RERERERERZYY8goiKilJkhfupaA37TPfClmyNkUkmmmIiEVPzvyEN4/emHXdffIF2uUNMSERERERERkXrsEUTZevDggcr14OBgODo6FlI2xV85K2P83qsORrWqhCWnn+LQ3XeZYt7FJOO7/+5j9bmXmNquCrrULgupmsIRERERERERkTL2CCIqoiqWMsWyL+vj2CR3tK1upzbm1YcETNx5G13+vgCfx+8hCIKOsyQiIiIiIqLPCQtBREVcjXIWWD+kMf4b2xxNXWzUxjwKicXwTTfQZ/VlXHkZoeMMiYiIiIiI6HPBQhDRZ6KhkzV2ft0U20Y0QV0HS7UxN19H4Yu1VzDon6u4Fxyt2wSJiIiIiIioyGMhiOgzIpFI0NK1FA6Mb4E1gxqiir2Z2rjzzz6g2/KLGLP1Jp69j9NxlkRERERERFRUsRBE9BmSSCToULMMjk9uhSX968LRxlht3IkHoWj/lz+m7r6DoMhEHWdJRERERERERQ0LQUSfMZlUgp71HXB2qgfm9agFO3PDTDGCAOy79RZei/zw04EAhMUmF0KmREREREREVBSwEERUDBjoSfFVUyec+9YTMztXg5WJfqaYNLmArVdeo9VCX/x+/BGiElILIVMiIiIiIiIqTCwEERUjxgYyjGpVCf4zPDG5jStMDWSZYpLTFFhz7iVaLfDFsrPPEJ+SXgiZEhERERERUWFgIYioGLIw0seUdlXgP8MTX7tXhIFe5qd6XEo6Fp9+ilYLfLH+/Eskp8kLIVMiIiIiIiLSJRaCiIoxWzNDzOpSA+e+9cCAJhUgk0oyxUQmpGLe0Ufw/NMPO6+9QZpcUQiZEhEREdHnJCEhAatXr0bnzp1Rvnx5GBkZwdDQEKVLl0bjxo0xfPhwrFu3DkFBQZm2HTp0KCQSSaZ/UqkUVlZWqFu3LsaPH487d+4USO6BgYFq9y+RSGBkZIRy5cqhffv2WLp0KWJjY7Ntz9nZGRKJBM7OztnGTp06VdyXq6ur2vuHqKBJBEEQCjsJ+rwEBwfD0dERABAUFAQHB4dCzoi0FfghAX+deYqDd99B0zPf2dYEU9pVgXedcpCqKRwRERERFRXPnj1Deno69PT04OrqWtjplBiXL1/GF198gTdv3mQba29vj9DQUJW/DR06FJs3b852W6lUiu+//x6//vprrnNVJzAwEBUrVtQq1tHREQcOHECDBg00xjg7O+P169dwcnJCYGCg2hhBEDBp0iQsX74cAFCtWjWcPXsW5cqVy3H+9PnKzWtWQXz/1stzC0T02XAuZYq/vqiPMR6VsOjUU5x++D5TTGBEIib/ewer/F5gevuqaFPdDhIJC0JEREREBDx9+hQdOnRAXFwcAKBbt27o06cPqlSpAgMDA3z48AF3797F6dOn4evrm217J0+eFIshCoUC79+/x9GjR7FixQqkp6fjt99+Q/ny5TFu3LgCOZ7u3btj3rx54vWoqCg8fvwYS5YswaNHjxAUFIQuXbrgyZMnsLCwyNU+BEHAmDFjsHbtWgBAzZo1cfbsWdjb2+fLMRDlFAtBRCVQtTIWWDe4EW69icKfJ5/g0ouITDGPQ+MwcssN1K9ghW87VEXzSqUKIVMiIiIiKkpmzZolFoE2btyIoUOHZopp164dpk+fjvDwcOzevTvL9qpUqZJpSFW7du3Qpk0bdOvWDQAwe/ZsjB49GjJZ5oVQ8srKygq1atVS+Zu7uzuGDh2KVq1a4cqVKwgNDcXatWsxffr0HLevUCgwcuRIbNy4EQBQt25dnDlzBqVK8bM1FR7OEURUgjWoYI0dXzfF9pFNUNfRSm3M7TfRGLDuKr5afxV3gqJ1mh8RERERFR1yuRxHjx4FADRq1EhtEUhZ6dKlMX78+Fzty9vbG+7u7gCA8PBw3Lp1K1ft5Ja+vr5KT6EzZ87kuA25XI4hQ4aIRaCGDRvC19eXRSAqdCwEERFaVC6FA+OaY+2ghqhqb6425sLzD+ix4iJGbbmBJ6FxOs6QiIiIiApbeHg4kpKSAACVK1cu8P25ubmJl1+/fi1efvnyJRYtWgRvb284OzvD2NgYxsbGcHJyQv/+/XHixIl82X/t2rXFyzmd1Dk9PR0DBw7Etm3bAABNmzbF2bNnYW1tnS+5EeUFh4YREQBAIpGgfc0yaFPdHofvvsPi00/xJjIxU9yph+9x+tF79KhXHt+0dYWTrWkhZEtEREREumZgYCBefvToUYHvT19fX7wsl8sBAK9evUKlSpXUxr958wZv3rzB7t278dVXX2Hjxo3Q08v9V17l41XOJTtpaWn44osvsG/fPgBAy5YtcezYMZibq//BlUjX2COIiFTIpBL0qF8eZ6e1xq89a6GMhVGmGEEA9t9+izaLzmHW/vsIjUkuhEyJiIiISJdsbGzg5OQEALh79y7mz58PhUJRYPu7f/++eDljQmm5XA4DAwN4e3tj2bJlOHPmDG7duoUzZ85g5cqVqFmzJgBg27ZtmDt3bp72r1zs0mZpeABITU1Fnz59xCKQp6cnTpw4wSIQFSnsEUREaunLpBjYxAm9Gzhg25XXWOH7HFGJaSox6QoB26++wd6bwRjS3BljWleCjamBhhaJiIiICoFCASRFFnYWumVsA0gL5jf/iRMnipMmf//991i9ejW6deuG5s2bw83NTetl2bNz9+5dcYiXiYkJGjduDAAoW7YsAgMDUbZs2UzbtGnTBmPGjMHw4cOxadMmLFq0CFOnToWlpWWucvj999/Fy3369Mk2Pi0tDT179sSxY8cAfJz0+uDBgzA2Ns7V/okKCgtBRJQlI30ZRrq7oH9jR2y4EIh1518iPiVdJSYlXYG1/i+x4+objHSviBEtK8LcSPvus0REREQFJikSWKh+KFGx9e0LwLRgJiSeMmUKHj58iA0bNgAAAgMDsWzZMixbtgwAYG9vDw8PDwwcOBBdu3aFRCLRum1BEPD+/XscOXIEP/zwgzgcbNKkSTAy+thL3dTUFKammqcmkEgkWLRoEbZu3YqEhAScOXMGvXv31jqH6OhoPHr0CL///jsOHz4MAGjWrBn69++f7bbv3r3Du3fvAACtW7fGoUOHxLyJihIODSMirZgb6WNyW1ecn+GJ0a1cYKiX+eUjPiUdf515hlYLfLHO/yWS0+SFkCkRERERFRSpVIp//vkHp06dQseOHTPNwfP+/Xvs2rUL3bp1g5ubG168eJFlexUrVoREIoFEIoFUKkXZsmXx9ddf48OHDwCALl264H//+5/G7dPS0hAcHIxHjx4hICAAAQEBePfuHWxtbQF87FmUlc2bN4v7l0gksLa2RvPmzXH48GHo6+tj6NChOHHihFZzBCkXve7fv4+nT59muw1RYWAhiIhyxNrUAD90rg7/GZ74qmkF6Ekz/8oTlZiGX489QuuFvth+9TXS5AU3dpyIiIiIdK9du3Y4fvw4IiIicOzYMcyZMwfe3t4qw7Bu3LgBd3d3hISE5KhtAwMDtGjRAps3bxYLMsrS0tKwYsUKNG3aFGZmZnB0dESNGjVQu3Zt8V9YWBgAiAWl3HB1dcWUKVNgYWGhVXyFChXw7bffAgAiIyPRrl07PH78ONf7JyooHBpGRLlib2GEeT1qY5R7Jfx15in233kLQVCNeR+bgln7A7Dm3EtMbVcF3nXLQaamcEREREREnycLCwt06tQJnTp1AgCkpKRgx44dmDZtGqKiohASEoKffvoJ69evV7v9yZMnxYmgpVIpzMzMUKZMGZUVu5RFRkaiffv2uHnzplb5ZSx3r0n37t0xb948AIBCocC7d+9w4sQJrFmzBg8fPoSHhwcuX76MqlWrarW/BQsWICkpCcuXL0dYWBjatm0Lf39/uLi4aLU9kS6wEEREeVLB1gSL+9fD6NaVsPj0E5x88D5TzJvIRHyz6w5W+b3AtPZV0K6GfY7GixMRERHlmrHNxzlzShJjm0LbtaGhIYYNG4Zy5cqhY8eOAIB9+/Zh7dq1kKqZwLpKlSpar8gFAJMnTxaLQD169MDw4cNRp04d2NnZwcjISPyMWaFCBQQFBUH49JfKT1hZWaFWrVri9Tp16qBjx47w9vZGx44dERUVhQEDBuDatWuQyWRa5bhs2TIkJiZiw4YNePv2Ldq0aQN/f384OjpqfZxEBYmFICLKF1XLmGPNoEa4ExSNRaee4PyzzN1wn7yPw6itN1HX0QozOlRFi8oFM4khERERkUgqLbCJk0mzDh06wNHREUFBQYiKikJERARKly6dpzZjY2Oxa9cuAMDAgQOxbds2jbFRUVF52lebNm0wefJkLFq0CLdu3cKmTZswYsQIrbaVSCRYt24dkpOTsWPHDgQGBorFoDJlyuQpL6L8wDmCiChf1XO0wtYRTbDj6yaoX8FKbczdoGgMXH8VA9Zdwa03eXuTJiIiIqKiKWPIF4B86Q3+7NkzpKWlAUCWq3g9fvwY8fHxed7fzJkzxfmB5syZg9TUVK23lUql2Lx5M3r16gXgY+5t27ZFREREnvMiyisWgoioQDSvVAr7xjbHP0MaoVoZc7Uxl15EoNfKSxi5+QYeh8bqOEMiIiIiKiiJiYl4+PAhgI/zCGWs4pUX6enp4uWEhASNcatXr87zvgDAxsYG48ePBwAEBQVh8+bNOdpeT08PO3fuFOdPevDgAdq3b4+YmJh8yY8ot1gIIqICI5FI0Ka6PY5NcsfSL+rB2dZEbdyZR+/Rael5TP73NgI/aH5TJyIiIqLCEx8fjyZNmuDIkSNQKDSvCqtQKDBx4kTExcUBALp165YvPYIqV64strN582a18/8cPnwYy5cvz/O+MkyZMgUmJh8/w/7xxx+Qy+U52t7AwAD79u2Dl5cXAODWrVvo2LFjvvRYIsotFoKIqMBJpRJ0r1cep6e2xu+9aqOMhVGmGEEADt55hzaLz+GHffcREpP1Cg9EREREpHvXrl2Dt7c3KlSogAkTJmD79u24cOEC7t69i3PnzuGvv/5CvXr1sGHDBgCApaUl5s6dmy/7trW1RefOnQEAJ06cQPv27bFv3z7cvHkTx48fx8iRI9GzZ0+4uLjkeT6iDKVLl8bXX38NAHj58iV27NiR4zaMjIxw6NAhtGjRAgBw5coVdO3aNdsVzYgKCieLJiKd0ZdJ8aVbBfSsXx7brrzGSr8XiExQHWstVwjYee0N/rsVjMFNnTDWoxJszQwLKWMiIiIiyqCnp4cyZcogNDQUb9++xYoVK7BixQqN8a6urti5c2eOVgXLzqpVq9CyZUu8efMGZ86cwZkzZ1Rur1ChAg4cOCAWjPLD9OnTsWrVKqSmpuL333/HwIED1a6AlhVTU1McO3YMbdq0wY0bN3Du3Dn07NkThw4dgoGBQb7lSqQN9ggiIp0z0pdhpLsL/Gd4Ymq7KjA3zFyTTk1XYP2FV2i1wBeLTz9FbHJaIWRKRERERBmMjIzw9u1bXLx4EXPmzEGnTp3g4uICU1NTyGQyWFhYoFq1aujfvz927NiBgIAANGzYMF9zcHR0xK1bt/Dtt9+iSpUqMDQ0hKWlJerWrYtffvkFd+7cQY0aNfJ1nw4ODhgyZAgA4NGjR/jvv/9y1Y6FhQVOnjyJOnXqAABOnjyJ/v37q8x9RKQLEkHdwEqiLAQHB8PR0RHAx0nTHBwcCjkj+txFJaRitf8LbL4UiOQ09ePNrUz0MbZ1JQxu5gxjA5mOMyQiIqKi6NmzZ0hPT4eenh5cXV0LOx0ioizl5jWrIL5/s0cQERU6a1MD/NCpOs5964lBTZ2gJ808mWB0Yhp+P/4YrRf6YuuV10hN1zxBIREREREREanHQhARFRn2FkaY26MWfKZ5oFeD8lC3uERYXAp+OhCANov9sO9WMOQKdmokIiIiIiLSFgtBRFTkVLA1weJ+9XDym1boWLOM2pigyCRM3X0XHf/yx4mAULXLhxIREREREZEqFoKIqMiqYm+O1YMa4uD4FnB3LaU25llYPMZsu4keKy7i/LNwFoSIiIiIiIiywEIQERV5dR2tsHVEE+z8uikaOlmrjbkbHINB/1zDl+uu4ObrKB1nSERERERE9HlgIYiIPhvNKtli75hm2DC0EaqXtVAbc+VlJHqvuoQRm67jUUisjjMkIiIiIiIq2lgIIqLPikQigVc1exyd2BJ/f1kfFUuZqo07+zgMnZaex6Sdt/HqQ4KOsyQiIiIiIiqaWAgios+SVCqBd91yOD2lFeb3ro1ylkZq4w7dfYe2i8/hh3338C46ScdZEhERERERFS0sBBHRZ01PJkX/xhXgM90DP3etAVtTg0wxcoWAndeC4PGnH+YeeYiI+JRCyJSIiIiIiKjwsRBERMWCkb4Mw1tWhP8MT0xvXwXmRnqZYlLTFfjnwiu0WuCLxaeeIDY5rRAyJSIiIiIiKjwsBBFRsWJqqIcJXq44P8MTYz0qwUg/88tcQqocy3yew32+L1afe4GkVHkhZEpERERERKR7LAQRUbFkZWKA7zpWg/+3nhjSzAn6MkmmmJikNPxx/DFaLfTF1suBSE1XFEKmREREREREusNCEBEVa3YWRpjTvRZ8pnmgT0MHSDPXgxAel4KfDj6A1yI//HczGHKFoPtEiYiIiIiIdICFICIqERxtTPBn37o4NaUVOtcuozYmOCoJ0/bcRYe//HEiIASCwIIQEREREREVLywEEVGJUtnOHCsHNsThCS3RukpptTHPw+IxZtstdFt+Ef5Pw1kQIiIiIiKiYoOFICIqkWo7WGLzcDfsGtUUjZys1cbcfxuDwRuu4Yu1V3AjMFLHGRIREREREeU/FoKIqERr4mKLPWOaYeOwxqhR1kJtzNVXkeiz+jKGb7qOB+9idJwhERERERFR/tEr7ASIiAqbRCKBZ1U7tHYtjeMBoVh0+glehidkivN5HAafx2HoWqcsprarApfSZoWQLRERERERUe6xRxAR0f+RSiXoUqcsTn3TCgv61EF5K2O1cUfuhaDdEn98t/ce3kYn6ThLIiIiIqKPnJ2dIZFIMHTo0ALbx9ChQyGRSODs7Fxg+yDdYiGIiOgTejIp+jVyhM/01pjtXQOlzAwyxcgVAnbdCILnQj/MOfwA4XEphZApERERUeFJT0/Hf//9h1GjRqF27dqws7ODvr4+LC0tUblyZfTs2RMLFy7Eq1evCjvVEk0QBBw6dAhffvklXF1dYWZmBj09PVhZWaFWrVro27cvFi5ciLt37+o8N09PT0gkEkgkErRv317r7Tw8PMTtlP/JZDLY2NigYcOGmDx5Mh48eJBtW7Nnzxa39/PzyzL2woULsLCwgEQigZ6eHrZt26Z1zkUJh4YREWlgqCfD0BYV0beRIzZdCsSacy8Qm5yuEpMqV2DjxUD8ey0IQ5o7Y3QrF1ibZi4cERERERUnhw4dwrRp0/D8+fNMt8XGxiI2NhYvXrzAgQMHMGPGDHTp0gV//PEHatWqVQjZllzv379Hnz59cOHChUy3xcTEICYmBg8ePMDevXsxY8YMPHr0CNWqVdNJbq9fv8a5c+fE62fPnsW7d+9Qrly5XLepUCgQFRWFqKgo3Lp1CytWrMC8efPw/fff5zlfPz8/dO3aFQkJCdDT08P27dvRr1+/PLdbGFgIIiLKhqmhHsZ7VsZXTZywxv8FNl4MRFKaXCUmKU2O1edeYNuV1xjewhkj3F1gaaxfSBkTERERFZx58+bh559/hiAIAD72zujatSvq1KkDW1tbJCYmIiQkBP7+/jhy5AgCAwNx9OhRODg4YPXq1YWcfcmRmpqKdu3a4f79+wCA+vXrY9iwYahXrx7Mzc0RGxuLR48ewd/fH0ePHkVMjG4XRdm6dSsEQYChoSHkcjnS09Oxbds2zJgxI0ftZBwf8PGYX758iQMHDmD79u2Qy+X44YcfUKlSJfTt2zfXuZ45cwbdunVDUlIS9PX1sWvXLvTs2TPX7RU2FoKIiLRkaaKPGR2rYWgLZ6z0fYHtV18jTS6oxMSnpGOZz3NsuhSIUa1cMLRFRZgZ8qWWiIiIiocNGzbgp59+AgDY29vj33//hYeHh9rYvn374q+//sK///6LmTNn6jBLAoB169aJRZJhw4Zh/fr1kEpVZ4dp1aoVRo8ejZSUFOzcuRNWVlY6y2/r1q0AgK5duyIpKQnHjh3D1q1bc1wI+rSXWYMGDdCnTx80adIEkyZNAgDMmTMn14Wg48ePo1evXkhOToahoSH27t2Lrl275qqtooJzBBER5ZCduRFmd6sJ3+ke6NfIATKpJFNMbHI6/jz1FK0W+GKt/wskpcrVtERERET0+QgKCsL48eMBABYWFrhw4YLGIlAGmUyGgQMH4u7du+jSpYsOsqQMBw8eBADo6elh8eLFmYpAygwNDTF06FCUKVNGJ7lduXIFT58+BQAMHDgQX331FQAgICAAt27dypd9jB8/HhUqVAAAPHjwAKGhoTlu4/Dhw+jRoweSk5NhbGyMgwcPfvZFIICFICKiXHOwNsGCPnVxZmpr9KxfHpLM9SBEJqTit2OP0WqhLzZefIXkNBaEiIiI6PO0ePFiJCcnAwB+/fVXVK5cWettrays4O3trfH20NBQzJo1C40aNYKNjQ0MDQ3h6OiIfv364cyZMxq3CwwMFCf63bRpEwDg9OnT8Pb2RpkyZWBoaIiKFSti7NixCA4O1ipXX19fDBkyBC4uLjAxMYGFhQVq166Nb7/9Fu/evdO4nfKkw8DHOXjmzp2L+vXrw8rKSiVHAEhISMCuXbswcuRI1KtXD5aWltDX10fp0qXRunVr/Pnnn4iPj9cqZ3XevHkDAChVqlS+9vSJjo7Gzz//jJo1a8LU1BRWVlZo1aoVtm/frnUbW7ZsAQBYW1ujS5cu6NGjB8zNzVVuyyupVIqaNWuK14OCgnK0/b59+9C7d2+kpqbCxMQER44cQYcOHfIlt8LG8QpERHlUsZQplvSvh3EelfDXmWc4ej8kU0x4XArmHH6Itf4vMcGrMvo2dISBHmvxRERE9HkQBEEcymNubo5hw4blW9vbt2/H6NGjkZCQoPL34OBg7NmzB3v27MGIESOwevVq6Oll/RX2hx9+wB9//KHyt8DAQKxevRr//fcfzp07h+rVq6vdNjk5GcOGDcO///6b6baAgAAEBARg1apV2LlzZ5ZFLQB49uwZ2rdvj8DAQI0xXbp0UZksOcOHDx/g7+8Pf39/rFy5EseOHcvVBM4GBh8XMHn//j0iIyNhY2OT4zY+9eTJE3Ts2DHTcZ0/fx7nz5/H5cuXsXz58izbSE1Nxa5duwB8HD6YkWevXr2wefNm7Ny5E3/++We2j7U2MtoGAH197efv3LVrF7766iukp6fDzMwMx44dg7u7e57zKSr4LYSIKJ+42ptjxcAGODqpJdpWt1cbExKTjFn7A+C1yA+7bwQhXa7QcZZEREREORcQEICIiAgAgLu7O0xNTfOl3d27d2PQoEFISEiAi4sLFi9ejBMnTuDmzZv477//0LlzZwDAP//8k+3cMevWrcMff/yB1q1bY8eOHbhx4wbOnDmDwYMHAwDCw8MxfPhwtdsKgoA+ffqIRSBvb29s3boVFy9exOXLl7F06VJUqFABCQkJ6NOnD27cuJFlLn369MHbt28xceJEnD59Gjdu3MDOnTtRtWpVMSY9PR21a9fGrFmzsH//fly9ehVXrlzBrl278MUXX0AqleLVq1fi0KScatCggXhsX3/9dZ56FwFAYmIivL29ERERgR9//BF+fn64ceMG1q1bBwcHBwDAihUrcPLkySzbOXLkCCIjIwFAHBKmfDksLAwnTpzIU64ZHj16JF52cnLSapvt27dj4MCBSE9Ph4WFBU6dOlWsikAAewQREeW7muUssX5II9wNisbi009x7ml4ppjgqCTM2HsPq/xe4Ju2ruhap5zauYaIiIgobxSCAtEp0YWdhk5ZGVpBKsnf3/zv3bsnXs4oMOTVhw8fMGrUKAiCgOHDh2PNmjUqvUAaNGiAXr16YdasWfjtt9+wdOlSjB49WqWYouzSpUv4+uuvsWbNGnF4FgC0adMGBgYGWL9+Pa5cuYLbt2+jfv36KtuuX78eR48ehb6+Pg4dOoSOHTuq3N60aVMMGjQI7u7uePDgAb755hu1S7JnCAgIwPHjx9G+fXvxbw0bNlSJ2bhxI1xdXTNt26RJE/Tr1w8jRoxAhw4d8OTJE2zfvh0jRozQuD91xo0bh61bt0KhUGDfvn3w8fGBt7c33N3d0aRJE9SsWRMymUzr9sLDw5GamorLly+rDLlq2LAhPDw8ULt2bSQnJ2PlypVZDqHKGPrl7OyMli1bin/38vJCuXLl8O7dO2zZsiXPc/Hs27dPnIeoTZs2sLa2znabLVu2YPPmzVAoFLC2tsbJkyfRuHHjPOVRFLEQRERUQOo6WmHzcDfcCIzEolNPcfllRKaYVx8SMPnfO1ju8xxT21VBh5plIGVBiIiIKN9Ep0Sj9a7WhZ2GTp3rfw42RnkfBqTsw4cP4uXSpUtrjFMoFHj48KHG26tWrSoO0Vm1ahViYmJQvnx5rFy5UuNQoDlz5mDz5s14+/YttmzZgl9//VVtXNmyZfH333+rFIEyTJ8+HevXrwfwcRiTciFIEATMnz8fADBp0qRMRaAM1tbWWLhwITp37oyLFy/i2bNnags5ADB06FCVIpA6mrbN0LZtW3Tr1g0HDhzAgQMHclwIcnNzw5o1azBu3DikpaUhOjoaW7duFYf4mZqaonnz5ujbty8GDBigVS+vuXPnqhSBMlSuXBk9evTAv//+m2WBLCIiAseOHQMADBgwQOWxkkqlGDBgAP78808cPnwY0dHROZ7bKGP5+P3792PevHkAABMTE43nzKc2btwIADA2NsbZs2czFQyLCw4NIyIqYI2cbbBzVFPs+LoJGjmp/yXiWVg8xm6/ha5/X8CZh+8hCILaOCIiIqLCEBcXJ17OqmAQGxuL2rVra/z39u1bMfbQoUMAPi4fbmhoqLFNPT09NGvWDABw+fJljXF9+vTR2E7VqlVhZmYGAHj58qXKbQ8fPsSLFy/ENrLSqlUr8XJWuQwcODDLdtQJDw/Hs2fPxPmIAgICxKLb3bt3c9weAIwcORL379/HsGHDxMmYMyQkJOD06dMYNWoUXF1dsx2OJZFIMGDAAI23Z/R4ioyMRHR0tNqYnTt3Ii0tDYDqsLAMGX9LTk7Gnj17ssxHOa+Mf4aGhqhevTpmzpyJxMRENGjQAKdOnUKTJk20bgsAkpKScPToUa22+RyxEEREpCPNK5XCnjHNsGlYY9RxsFQb8zAkFiO33ECPlZfg/zScBSEiIiIqEpSLCJ9O6pwbcrkcd+7cAQBxKFdW//bu3QsAWS4Bnt2EyhlDg5SLWgBU5vtp1qxZlnlkFJOyy6VOnTpZ5pLh4sWL6N+/P2xtbWFnZ4cqVaqoFM7WrVsHQLVHVk5VrVoVGzZsQEREBC5duoTFixdj4MCB4rw+ABASEoKuXbtmuUJbqVKlYGtrq/F25cmoP72PM2zevBnAx2F/6ibtrlu3LmrVqgUg76uHGRgYYMSIEWjRooXW2/z222/iuf7TTz9hyZIlecqhqGIhiIhIhyQSCTyq2uHg+BZYN7gRqpUxVxt3NygagzdcQ781l3FFzZAyIiIiIl1SLgCEh2ee/zCDlZUVBEFQ+TdkyJBMcZGRkUhPT89xHomJiRpvMzExyXJbqfTj11+5XK7y97CwsBznkV0u2sxHM3v2bLRs2RK7d+8WJ0/WJCkpKcf5fUpfXx/NmjXDlClTsG3bNgQFBeHs2bPiUC+5XI5x48Zp/CFS2/s3o61PPXr0SCy6qesNlGHQoEEAPhbJXr16lfVBAbh//774z9/fH8uXL0elSpWQmpqK8ePHY+HChdm2kaFp06Y4cuSIeKxTp07F6tWrtd7+c8E5goiICoFEIkG7GvZoU80OxwNCseTMUzwPy7ySw/XAKHyx9gpaVLbF1HZV0VDD0DIiIiJSz8rQCuf6Z16iuzizMrTK9zbr1q0rXr59+3ae21MuFIwcORKTJ0/Wajvl5cDzi3Iuhw8fhrOzs1bb2dnZabwtu0mYz549izlz5gAAXFxcMH36dLRs2RIVKlSAqampOF/Szz//jLlz52qVT254eXnh9OnTqFWrFiIjI/Hs2TPcuXOnQObGUe7hM3XqVEydOjXLeEEQsGXLFvzyyy9ZxmX0IMrg7u6OwYMHo2XLlrh37x5mzpwJDw8PrSd9btWqFQ4cOABvb2+kpKRg3LhxMDExEVefKw5YCCIiKkRSqQRd6pRFx1plcOjuWyw98wyBEZl/Xbr4PAIXn1+CR9XSmNauKmprGFpGREREqqQSab5PnFwS1apVC7a2toiIiMD58+eRmJiYbQ+RrCgPIxIEIdOXeV1S7u1kZWWlk1wyhnxZW1vjypUrGifgzq6nUH4oW7YsunTpIk4i/fz583wvBCkUCmzfvj3H223dujXbQpA65ubm2LJlCxo0aID09HRMmzYN/v7+Wm/frl077NmzB71790ZaWhqGDx8OIyMj9OvXL8e5FEUcGkZEVATIpBL0rO+AM1NbY0HvOihvZaw2zu9JOLyXX8CoLTfwKCRWx1kSERFRSSWRSMThPLGxseJcL7llYGAgDkm6ePFinvPLC+Wih65yefDgAQDA09Mzy1XYlOcvKkjlypUTL6tbdS2vfH19ERQUBACYOHEidu7cmeW/b775BgDw4sWLXD8mdevWFSe3Pn/+fLaTYX/K29sb27dvh0wmg1wux1dffYXDhw/nKpeihj2CKFufLg+YMcs7EeU/PZkU/Ro7okf98th1IwgrfJ4jNDY5U9yph+9x6uF7dKlTFlPauqKynfq5hoiIiIjyy9SpU7FmzRokJyfjhx9+QMeOHVGxYsVct9etWzc8ePAAjx8/xsmTJ9GhQ4d8zFZ7DRo0gIODA4KDg7F27VpMnjwZRkZGBbrPjPmRspp4+/bt27h69Wqu9yEIgtZFHeWCk4uLS673qUnGsDCZTIYff/wxy2F1ANC2bVssX74c6enp2LJlS44mfFY2a9Ys7NixAwqFAvPmzUPHjh1ztH3fvn2RlJSEoUOHIi0tDX379sXhw4fRrl27XOVTVLBHEBFREWSgJ8Wgpk7w+9YDP3etgVJm6pdCPXovBO2X+GPqrjt4HZH3FTyIiIiINKlQoQKWLVsGAIiJiUHLli1x4cKFLLcRBEHjUuKTJ08WV+EaNmyY2EtGk6NHj+LevXs5TzwbUqkUM2fOBPBxafnBgwcjJSVFY3xsbCyWL1+ep326uroCAC5cuIDnz59nuj08PFycNDm3evXqhZUrV2a7ytumTZtw9uxZAB8f4/weFpaQkIB9+/YB+Dh/T3ZFIODjCmWtW7cGAOzevTvLxyMr1apVQ69evQB87O3l6+ub4zYGDx6MVatWAQBSUlLQo0ePHA0zK4rYI4iy9ekLcnBwMBwdHQspG6KSxUhfhuEtK+ILN0dsufwaa869QFSiaq88hQDsu/0WB+++Q9+GDpjgVRkO1rkfs09ERESkyddff423b99izpw5ePfuHdzd3eHl5QVvb2/Url0bNjY2kMvlCA0Nxa1bt7B7927x+4RMJlOZ7Nne3h6bN29Gnz59EBISgkaNGmHo0KHo1KkTHBwckJaWhuDgYFy7dg179+7Fy5cvcfjwYa2XZs+JMWPG4PTp09i/fz/27NmDW7duYfTo0XBzc4OlpSViY2Px+PFj+Pn54dChQzAyMsKECRNyvb/Bgwfj8OHDSEhIQOvWrfH999+jYcOGACAu8R4aGopmzZrh8uXLudpHUFAQxo8fj++++w7e3t5o1aoVqlatCmtrayQnJ+Px48fYs2cPjh07BuDjkLAlS5bk+9Cwffv2IT7+46IovXv31nq73r174+zZs4iOjsahQ4fQt2/fXO1/5syZ2Lt3LwBg3rx58PT0zHEbo0ePRlJSEqZMmYLExER07doVZ86cgZubW65yKmwsBBERfQZMDPQwpnUlDGxSAZsuBmLt+ZeIS1ZdclWuEPDv9SD8dysYXzSugAlelWFvUbDdmomIiKjkmT17NurWrYvp06fj5cuX8PHxgY+Pj8Z4iUSCDh06YOHChSpz0QAfe60cPHgQQ4cORWRkJFavXq1xuW6pVApTU9N8PRblHHft2oXJkydj9erVePHiBWbMmKExXpteLVnp06cPhg0bho0bN+Ldu3eYNGmSyu0ymQxLlixBVFRUrgtBDg4OuHnzJuLj48W5dzSxtLTE33//LfaeyU8Zw8IkEkmO2u/VqxcmTJgAhUKBLVu25LoQVL9+fXTu3BnHjh2Dj48Prly5gqZNm+a4nW+++QaJiYmYNWsW4uLi0LFjR/j4+KBevXq5yqswcWgYEdFnxNxIHxPbuOLCDC9M9KoMU4PMS5OmyQVsvfIarRb4Yu6Rh/gQn7uutERERESa9OzZE0+ePMHu3bsxYsQI1KhRA6VKlYKenh4sLCxQsWJFdOvWDb///jtevHiB48ePa1yNy9vbG69evcKff/4JLy8v2NvbQ19fH8bGxqhYsSK6du2KxYsXIzAwMFe9ObSlr6+PlStX4u7du5g4cSJq164NS0tLyGQyWFpaol69ehgxYgT27t2LR48e5Xl/GzZswNatW+Hu7g5zc3MYGhrCyckJgwYNwqVLlzB58uQ8tX/gwAE8fvwYS5cuRb9+/VCzZk3xeExNTVGhQgV07twZf/31F54/f57noWjqvH37ViwSNmvWLFMhMCv29vbi3EAnTpxAeHh4rvOYNWuWeHnu3Lm5bmfmzJn48ccfAQBRUVFo3759vpwLuiYRBEEo7CTo86I8NCwoKAgODg6FnBFRyRWZkIo1515g8+VAJKcp1MYY68swpLkzRrdygbWpgdoYIiKiz9GzZ8+Qnp4OPT09cc4VIqKiKjevWQXx/Zs9goiIPmM2pgb4oXN1+M/wxLAWzjDQy/yynpQmx+pzL+C+wBeLTz9FTBJX/iMiIiIiKqlYCCIiKgbszI3wi3dNnPvWAwObVIC+LPMkf/Ep6Vh29hnc5/tghe9zJKSkq2mJiIiIiIiKMxaCiIiKkbKWxvi1Z234TPNAv0YOkEkzF4Rik9Ox8OQTuC/wxVr/F0hKlRdCpkREREREVBhYCCIiKoYcbUywoE9dnJnaGj3qlYO6VUAjE1Lx27HHaLXQF5suvkJKOgtCRERERETFHQtBRETFWMVSpvjri/o4+U0rdK5dRm1MeFwKZh9+CI+Ffthx9Q3S5OonnSYiIiIios8fC0FERCVAFXtzrBzYEEcntUTb6vZqY0JikjFz/314LfLDnhtBSGdBiIiIiIio2GEhiIioBKlZzhLrhzTCgfEt0KpKabUxQZFJ+HbvPbRf4o+Dd95CrhB0nCURERERERUUFoKIiEqgeo5W2DLcDXvGNEMzF1u1MS8/JGDyv3fQaak/jt8PgYIFISIiIiKizx4LQUREJVhjZxvsHNUUO0Y2QUMna7UxT9/HY+z2W+j69wWcffQegsCCEBERERHR54qFICIiQvPKpbB3TDNsGtYYdRws1cY8DInFiM030HPlJfg/DWdBiIiIiIjoM8RCEBERAQAkEgk8qtrh4PgWWDuoIaqVMVcbdycoGoM3XEP/NVdw5WWEjrMkIiIiIqK8YCGIiIhUSCQStK9ZBscmuWP5gPqoVNpUbdy1wEh8sfYKBq6/gpuvo3ScJRERERHR56Wo9KhnIYiIiNSSSiXoWqccTk1pjSX968LJ1kRt3MXnEei96hKGbbyG+8ExOs6SiIhKMplMBgCQy+VQKBSFnA0RkWZyuRxyuRzA/3/tKiwsBBERUZZkUgl61nfAmamtMb93bZS3MlYb5/skHN7LL2DUlht4HBqr4yyJiKgkMjIyAvDxV/b4+PhCzoaISLPo6GjxsomJ+h9YdYWFICIi0oq+TIr+jSvAd7oH5vaoBXsLQ7Vxpx6+R6el5zFhxy08D+OHciIiKjgWFhbi5dDQUMTGxrJnEBEVGYIgIDk5GWFhYQgLCxP/bm2tfrVeXdEr1L0TEdFnx0BPikFNndC3oQO2X32DVX7P8SE+VSVGEIAj90Jw7H4IetQvj8ltXOFkq36uISIiotwyNTWFsbExkpKSIJfL8fbtW0gkkkIfdkFEBHwcDvbpvECWlpYwNFT/g6quSISiMlsRfTaCg4Ph6OgIAAgKCoKDg0MhZ0REhSkxNR1bLr/G6nMvEJ2YpjZGJpWgb0MHTGzjqnFoGRERUW4oFAq8efMGSUlJhZ0KEVGWSpcuDVtbW0gkEq23KYjv3ywEUY6xEERE6sQlp2HjxUCsO/8SccnpamMMZFJ84eaI8Z6VYW9hpOMMiYiouBIEAQkJCYiLixN7BxERFTapVAoDAwOYmprCzMwMBgYGOW6DhSAqElgIIqKsxCSmYd35l9h48RUSUtV/EDfUk+Krpk4Y61EJpcwKt2ssEREREVFRVRDfvzlZNBER5StLE31M71AV57/zwuhWLjDSz/xWk5KuwD8XXsF9vi/mn3iMqIRUNS0REREREVF+YyGIiIgKhI2pAX7oXB3+MzwxtLkzDGSZ33KS0uRY5fcC7gt8sfj0U8Qmq59jiIiIiIiI8gcLQUREVKDszI0wu1tNnJvhgYFNKkBPmnlyvPiUdCw7+wzu832xwvc5ElLUzzFERERERER5w0IQERHpRFlLY/zaszZ8p3ugb0MHyNQUhGKS0rDw5BO4L/DFOv+XSNIwxxAREREREeUOC0FERKRTjjYmWNi3Lk5PaYXu9cpB3eqZkQmp+PXYI7Ra6ItNF18hJZ0FISIiIiKi/MBCEBERFQqX0mZY+kV9nPymFTrXLqM2JjwuBbMPP4TnQj/suPoGaXKFjrMkIiIiIipeWAgiIqJCVcXeHCsHNsTRSS3Rtrqd2ph3McmYuf8+vBb5Yc+NIKSzIERERERElCssBBERUZFQs5wl1g9pjAPjW6BVldJqY4Iik/Dt3ntov8QfB++8hUIh6DhLIiIiIqLPGwtBRERUpNRztMKW4W7YM6YZmrrYqI15+SEBk/+9g45L/XEiIASCwIIQEREREZE2WAgiIqIiqbGzDf4d1Qw7RjZBQydrtTFP38djzLZb6Pr3BZx99J4FISIiIiKibLAQRERERVrzyqWwd0wzbBzWGLXLW6qNefAuFiM230DPlZdw/lk4C0JERERERBqwEEREREWeRCKBZ1U7HJrQAmsHNUS1MuZq4+4ERWPQP9fQf80VXH0ZoeMsiYiIiIiKPhaCiIjosyGRSNC+Zhkcm+SO5QPqo1JpU7Vx1wIj0X/tFXy1/ipuvYnScZZEREREREWXXmEnQERElFNSqQRd65RDp1plcfDOWyw9+wyvIxIzxV14/gEXnn+AZ9XSmNquKmo7qB9aRkRERERUUrBHEBERfbZkUgl6NXDAmamtMb93bZS3MlYb5/skHN7LL2D01ht4HBqr4yyJiIiIiIoOFoKoREhMjkbIh0eFnQYRFRB9mRT9G1eAz/TWmNu9JuwtDNXGnXzwHp2WnseEHbfwPCxex1kSERERERU+FoKoRPjX5zt0PtIX8/7rhfeRzws7HSIqIIZ6Mgxq5oxz33rip641UMrMIFOMIABH7oWg/ZJzmLr7Dl5HJBRCpkREREREhUMicI1dyqHg4GA4OjoCAIKCguDg4FDIGWUtIfEDOu7yQLRUAgAwEAT0s6yBEZ4LUMrKuXCTI6IClZiajs2XXmON/wtEJ6apjdGTStC3kQMmeLlqHFpGRERERFQYCuL7N3sEUbG33ec7sQgEAKkSCbbFPkKnA12x6MAARMYGF2J2RFSQTAz0MNajEs7P8MSUtlVgbph5jYR0hYCd14LgudAPPx8MwPvY5ELIlIiIiIhIN1gIomIvNSkSBorMHd+SJRJsirmPjv91xNLDgxEdH1II2RGRLpgb6WNyW1dc+M4LEzwrw8RAlikmVa7Alsuv0WqBL+YdeYgP8SmFkCkRERERUcHi0DDKsc9taBgAhL4+j/Xnf8Z/6eFIl0jUxpgqBHxV2g2DvRbAwqSUjjMkIl2KiE/BGv+X2HI5EMlpCrUxJgYyDGnujNGtXGBlknmuISIiIiKiglYQ379ZCKIc+xwLQRnevTyLtRfn4KA8UmNByFwhYLB9c3zl+QfMjG10nCER6VJYXDJW+r7AjqtvkCpXXxAyN9TD8JYVMcK9IiyM9HWcIRERERGVZCwEUZHwOReCMgQ9O4E1l+fhsCIaCg0FIUsFMLRsKwzw+B0mRhY6zpCIdOlddBKW+z7H7utBSFczlBQALI31MaqVC4Y2d4apmrmGiIiIiIjyGwtBVCQUh0JQhsAnh7D6yu84JsRB0FAQslYAwx3aoH/reTA2MNNxhkSkS0GRiVh29hn+uxUMDfUg2JoaYEzrShjUzAlG+pnnGiIiIiIiyi8sBFGRUJwKQRlePNyLVdcW4qQkUWOMrQIYWaEj+rb6Hwz1ucQ0UXH2MjweS88+w6G776DpXdLO3BDjPSvjCzdHGOqxIERERERE+Y+FICoSimMhKMPT+zuw6sYSnJFqXj7aTiHBSOcu6N3yFxjoG+kwOyLStafv47Dk9FMcDwjVGFPO0ggT27iiT0MH6Mu4GCcRERER5R8WgqhIKM6FIACAIODR3c1Yeftv+ElTNYaVUUgwqlJP9GgxC/oyrihEVJwFvI3BX2ee4syjMI0xFWxMMKmNK3rUKwc9FoSIiIiIKB+wEERFQrEvBGUQBNy/tQ4r7q7CRVm6xrDyghSjK/eDd7MZ0JNxRSGi4uxOUDQWnXqC888+aIxxKW2KyW1c4V2nHKRS9XOPERERERFpg4UgKhJKTCEogyDgzvUVWBGwHldkco1hFQQZxlQdgM5uUyGTcUUhouLsemAk/jz5BFdfRWqMqWpvjintXNGhZhlINExGT0RERESUFRaCqEgocYWgDAoFrl/7CysebMZNPYXGsIqCHsbVGIr2jSdCKuHwEKLiShAEXHoRgUWnnuDWm2iNcTXLWWBa+yrwrGrHghARERER5QgLQVQklNhC0P8R5Om4emURlj/ejrt6mp8+laGPcbVGok2DMSwIERVjgiDA72k4Fp96ivtvYzTG1XO0wrT2VdCycikWhIiIiIhIKywEUZFQ0gtBGYT0NFy69AeWP9uNgCxGglWDIcbVGQOPeiP45Y+oGBMEAacevseS00/xODROY5xbRRtMa1cFTVxsdZgdEREREX2OWAiiIoGFIFVCeir8z8/Fipf78UhPc6GnpsQI4+tNRMvag1gQIirGFAoBR++H4K8zT/EiPEFjXMvKpTC1fRU0qGCtw+yIiIiI6HPCQhAVCSwEqSekJcPn3GyseH0Ez7IoCNWRmGB8w2/QrMYXLAgRFWNyhYCDd95i6dlneB2RqDHOq5odprarglrlLXWYHRERERF9DlgIoiKBhaCsKVITccrvJ6wKOoGXeprnBmogNcOExt+icbVeOsyOiHQtTa7AvlvBWHb2Od5GJ2mM61DTHlPaVUG1MhY6zI6IiIiIijIWgqhIYCFIO/LkOJzwm4VVb8/idRYFoSYyC4xv8gPqu3bVYXZEpGsp6XLsvh6E5b7P8T42RW2MRAJ0rVMO37R1RaXSZjrOkIiIiIiKGhaCqEhgIShn0pOicdT3B6wO8UdwFgWh5nrWGN9sFuq4dNBhdkSka8lpcmy78hqrz73Ah/hUtTFSCdCjfnl806YKKtia6DhDIiIiIioqWAiiIoGFoNxJS4zEobMzsCbsMkKyKAi10rfFuOY/o6azlw6zIyJdS0xNx+ZLr7HG/wWiE9PUxujLJBja3BkTvFxhaayv4wyJiIiIqLCxEERFAgtBeZMWH4b9PjOw5sN1hMk0F4S8DOwwruVsVHV012F2RKRrcclp2HAhEOvPv0RcSrraGGsTfUxpVwUD3CpAL4vXDSIiIiIqXlgIoiKBhaD8kRL7DnvPzsD6qNv4kMUXu3aGZTHOfS4ql2+iw+yISNeiE1Ox7vxLbLwYiMRUudqYynZmmNW5Ojyqluaqg0REREQlAAtBVCSwEJS/kmOCsevsdGyIvo9IDQUhiSCgo7Ejxrb+DRXL1NdxhkSkSxHxKVjl9wKbLwciTa7+LdrdtRR+7FIDVcuY6zg7IiIiItIlFoKoSGAhqGAkRr3CzrPTsTH2MWI0FISkgoCups4Y3fo3VLCro+MMiUiXXkck4I/jj3E8IFTt7VIJ8IVbBUxtVwWlzAx1nB0RERER6QILQVQksBBUsOI/PMUOnxnYFP8McRoKQjJBQHezyhjl8TvKl6qu4wyJSJeuvozAvKOPcP9tjNrbzQz1MN6zMoa1cIaRvkzH2RERERFRQWIhiIoEFoJ0IzbsAbb5fIetSa8QL1VfENITBPS0qIpRHvNRxqayjjMkIl1RKATsv/0WC04+xvvYFLUxDtbG+L5TNXSpXZbzBxEREREVEywEUZHAQpBuxYTexWbf77Et+Q2SNBSE9AUBfSxrYqTnfNhZOes2QSLSmcTUdKz1f4k1514iKU39hNKNnKzxY9caqOdopdvkiIiIiCjfsRBERQILQYUj8u0NbDr3A3amvEOyhoKQoSCgr3UdjPBcgFIWfFyIiqvQmGQsPPkE/90K1hjTo145zOhYDeWsjHWYGRERERHlJxaCqEhgIahwfQi6jA3nZmF32nukaCgIGQkCvrRtiGGeC2BtZq/jDIlIV+4Hx2DukYe4Fhip9nYjfSlGubtgdOtKMDXU03F2RERERJRXLARRkcBCUNEQFuiP9ed/xl75B6RpmA/ERBAwsHQTDPH4A5ampXWcIRHpgiAIOPkgFL8de4w3kYlqY+zMDTG9Q1X0buAAmZTzBxERERF9LlgIoiKBhaCiJeTFaay7+D/sV0QhXUNByEwABtk3x1cev8PC2EbHGRKRLqSky7Hl0mss83mGuOR0tTE1ylrgp6410KySrY6zIyIiIqLcYCGIigQWgoqmt0+PY+3luTgoxEKuoSBkrgCGlmuNgR6/wdTQQscZEpEuRMSnYOnZZ9h+9Q3kCvVv8e1r2OOHztVRsZSpjrMjIiIiopxgIYiKBBaCirY3jw5gzdU/cATxUGgoCFkpgKEObfBlq3kwMTTTcYZEpAvP3sfh12OP4PckXO3t+jIJBjdzxiQvV1ia6Os4OyIiIiLSBgtBVCSwEPR5ePlgN1Zf+xMnJIkQNBSEbBTAcKdO6NfyFxgbsGcAUXF07mk4fj36EE/fx6u93cpEH9+0ccXApk7Ql6mfgJ6IiIiICgcLQVQksBD0GREEPL+/HStvLsFpaarGsFKCBCOdvdGn5U8w1DPSYYJEpAvpcgX+vR6EJaefIiJB/WuBS2lTzOpcHV7V7CDRUDwmIiIiIt1iIYiKBBaCPkOCgCd3NmHl7b/hI0vTGGYvSDGqUi/0bPYD9PUMdJggEelCbHIaVvg+x8YLgUiVK9TGtKxcCrO6VEf1spxHjIiIiKiwsRBERQILQZ8xQcCDm2ux8u4q+OvJNYaVFaQY7doP3Zp+C30ZC0JExU1QZCL+OP4YR++HqL1dKgH6N3bE1HZVUdrcUMfZEREREVEGFoKoSGAhqBhQKHD3+nKsDPgHl/TU9woAAAdBhjHVBqKL2xToSfV0mCAR6cL1wEjMPfIQ94Jj1N5uaiDDOM/KGNGyIoz0ZTrOjoiIiIhYCKIigYWgYkQhx60rS7Dy0RZc1dP8UuAEPYypPhSdGk2ATMovg0TFiUIh4ODdt1hw4glCYpLVxpS3MsZ3narBu05Zzh9EREREpEMsBFGRwEJQMSRPx/XLC7H8yQ7cyqLjjwsMMLbWSLRvMBpSCVcXIipOklLlWHf+JVb5vUBSmvqho/UrWOGnrjXQoIK1jrMjIiIiKplYCKIigYWg4ktIT8Xli79jxfM9uKen+Vd/VxhifN0x8Ko7gr0DiIqZ97HJ+PPkE+y9FQxNnxC61S2H7zpVQ3krY90mR0RERFTCsBBERQILQcWfkJaMC+fnYcXLA3igr7nQU11ijPH1J6BVrUEsCBEVMwFvYzD3yENcfRWp9nZDPSlGulfEWI/KMDPkHGJEREREBYGFICoSWAgqOYTUJPj5z8aKwCN4oq95KFgtqQnGN5yCFtX7syBEVIwIgoBTD9/j92OPEBiRqDamlJkhprevgr6NHCGT8vlPRERElJ9YCKIigYWgkkeREo+zfj9hZfApPNfTXBCqJzXHeLfpaFKlJwtCRMVIaroCWy4HYtnZZ4hNTlcbU72sBX7qUh3NK5fScXZERERExRcLQVQksBBUcimSY3HS9wesfOeLQD3Nq4c1lFliQtPv0ahyVx1mR0QFLTIhFUvPPMW2q28gV6j/+NC2uh1mdq4Ol9JmOs6OiIiIqPhhIYiKBBaCSJ4YhWO+P2B1qD/eZFEQaqJnjQnNfkQ9l/Y6zI6ICtrzsHj8duwRfB6Hqb1dTyrBoGZOmNzGFVYmBjrOjoiIiKj4YCGIigQWgihDekI4Dp/9HmvCL+NtFgWhFvqlML75T6jt7KXD7IiooJ1/Fo5fjz7C49A4tbdbGutjchtXfNXUCQZZDCslIiIiIvVYCKIigYUg+lRa/HscOPst1n64gdAsCkIeBnYY13IOqju21GF2RFSQ5AoBu64HYfHpJ/gQn6o2pmIpU8zsXB1tq9tx/jAiIiKiHGAhiIoEFoJIk9SYt/jv7LdYH3UHYVkUhNoYlsW4VvNQpZybDrMjooIUl5yGlX4v8M+FV0hNV6iNaV7JFrO6VEfNcpY6zo6IiIjo88RCEBUJLARRdpKjX2PvmW+xPjYAETLNBaEOxg4Y2+pXVCrTQIfZEVFBCopMxPwTj3HkXoja2yUSoF9DR0xrXwV2FkY6zo6IiIjo88JCEBUJLASRthIjXmC3z7fYEPcYURoKQhJBQGfTihjT+lc429XRcYZEVFBuvo7C3CMPcScoWu3tJgYyjPOohJHuLjDS11wwJiIiIirJWAiiIoGFIMqpxPDH2OEzAxsTXiBWpn7CWKkgwNu8MkZ7/AFH22o6zpCICoJCIeDwvXeYf/wx3sUkq40pZ2mE7zpVQ7e65Th/EBEREdEnWAiiIoGFIMqt+PcB2Ob7HbYkBSJOqr4gpCcI6G5RFaM856OcdWUdZ0hEBSE5TY71519ipd8LJKbK1cbUc7TCT12ro6GTjY6zIyIiIiq6WAiiIoGFIMqrmJDb2Or7PbalBCMhi4JQb6uaGOk5H2UsnXWbIBEViLDYZCw69RS7bwZB06ePrnXK4ruO1eBoY6Lb5IiIiIiKIBaCqFDUrFlT5XpaWhqePXsGgIUgypvo4GvYdG4mdqSGIElDQchAENDXph5GeM5HafPyOs6QiArCg3cxmHfkES6/jFB7u4GeFCNaVsQ4j0owN9LXcXZERERERQcLQVQoWAiighbx+gI2+v+Ef9PDkKKhIGQoCOhfqiGGe8yHrVkZHWdIRPlNEASceRSG3449wqsPCWpjSv0/9u47KqqrawP4c2eGXgVEFLCLvfcC9t57xZaoMfYSE2NibIkxxpZYorH33nsXC7H3igoINhSUXqbc7w9e+NQ7g6jDFHh+a7mW8dyZ2RjRuc/sfY69JUY3Lo6uVb0hl3H/ICIiIsp5GASRSeBoGGWVV8EnsOz0L9iijkKKjps+GxHo7l4dfev+hlx27gaukIj0LUWlwdr/QjHvWBCiE5Varynh4YAJLUvCt1huA1dHREREZFwMgsgkMAiirPYi6BCWnpuKbeJbqHScImQrAr08aqN33elwssll4AqJSN/eJqRg7tEgrP0vFCqN9rcmDUq448cWJVHU3d7A1REREREZB4MgMgkMgshQnt3fgyWBv2EXYnUGQg4i4J+vHnr5TYODtZOBKyQifXv0Kg7T99/F0bsRWtflMgG9qufHyEY+yGVnaeDqiIiIiAyLQRCZBAZBZGhhd7Zj8fnfsUdIgEZHIOQoAn29GqGH7xTYWTkYuEIi0rezD19j6t47uPciVuu6o7UCwxsWQ++aBWGp0L63GBEREZG5YxBEJoFBEBmFKCLk1kb8c3EW9suSIOoIhHKJAvrlb46udSbC1tLOwEUSkT6pNSK2Xg7DzEMP8DouWes1BV1t8UPzkmhaOg8EHX8vEBEREZkrBkFkEhgEkVGJIh5fX4NFV+bioFz7xrIA4CIK+LpQG3SuNQHWFjYGLJCI9C0uWYVFJx/i39PBSFFptF5To7ALfmpZCmU8OSJKRERE2QeDIDIJDILIJIgiHlxZhkXXFuCoQqXzstyiDF8X6YBONX+ApcLKgAUSkb6Fv0nAHwfvY/f1Z1rXBQHoWMkL3zUtjjyO1gaujoiIiEj/GASRSWAQRCZFo8Hdy/9g4Y3FOKnQ3ikAAHlEOQb6dEX76mNhIbcwYIFEpG9XnrzB1L13cPXJW63rtpZyfFO3CAb4FoaNpdywxRERERHpEYMgMgkMgsgkadS4df5vLLi9HGcsdP+15gkFBpXohVZVh8NCxkCIyFyJoog9N55jxoF7ePo2Ues1eZ2sMa5ZcbQt7wmZjPsHERERkflhEEQmgUEQmTS1Ctf+m40Fd9fgvwxyHm9YYHCpvmhReQjkMnYMEJmrJKUay84EY+GJh4hPUWu9pryXE35qVQpVC7oYuDoiIiKiL8MgiEwCgyAyC2olLp2dgQUPNuKShe5OgIKwxOCyX6NphYEMhIjM2KvYZMw+ch+bLoZBo+OdTcuyefFD8xLwdrE1bHFEREREn4lBEJkEBkFkTkRlEi6cnY75D7fhWgaBUFHBCoPLD0ajcv0gE2QGrJCI9Onu8xhM23cHZx9Gal23lMvQr05BDKlfFI7WHA8lIiIi08YgiEwCgyAyR2JKIs6dnooFwbtw00J30FNcsMG3lYahfuleEATuKUJkjkRRxPF7Efh1/108fhWv9RpXO0uMauyDblW9oZAz/CUiIiLTxCCITAKDIDJnYnI8TgdMwvzQfbhroXsUrJTMDkMqj4JvyS4MhIjMlFKtwbr/QjH3WBDeJii1XuOTxx4TWpZCXZ/cBq6OiIiI6OMYBJFJYBBE2YGYFIvjJ3/CgvAjCMogECond8CQKmNRs3h7BkJEZio6QYm/jgdh1bkQqHRsIFSveG5MaFESxfI4GLg6IiIiIt0YBJFJYBBE2Ykm8S2OnJiAhc9O4HEGgVAlhROGVPsB1Yq1MmB1RKRPwa/j8dv+uzhy56XWdblMQI9q+TGqsQ9c7CwNXB0RERGRFIMgMgkMgig7Use/xsHj4/HPy7MIySAQqmqRC0Nr/IRKhZsYsDoi0qdzj15j2t67uPM8Ruu6g7UCwxoURZ9aBWGl4GmCREREZDwMgsgkMAii7EwVF4F9x77HP6/PIzyDG8CaFm4YUnsiyheob8DqiEhf1BoR266EY+ah+3gVm6z1mvwuthjfvASalfHgaCgREREZBYMgMgkMgignUMY8x55j32Fx1BU8yyAQ8rXKgyG1J6O0d20DVkdE+hKfrMLiU4+wOOAxklUarddUK+SCn1uWQlkvJwNXR0RERDkdgyAyCQyCKCdRRodhx7HvsPjNDURkEAg1tPHEiHozUci9rAGrIyJ9efY2EX8cvIed157pvKZDJU+Ma1oCHk7WBqyMiIiIcjIGQWQSGARRTpQcFYytx77D0pg7eK0jEJKLIjrmKofBDWfBzT6vgSskIn24FvYWU/feweXQN1rXbSzkGOhXGIPqFoatpcLA1REREVFOwyCITAKDIMrJkl4/wKbj47A87gGi5NoDIRsR6OtZH33rToetpZ2BKySiLyWKIvbdfI7fD9xD+JtErdfkcbTCd01LoENFT8hk3D+IiIiIsgaDIDIJDIKIgISIO9h4/HssT3iEaB2BkKsow7c+3dG+xhhYyCwMXCERfakkpRorzoZgwYmHiEtWab2mrKcTfmpZEtULuxq4OiIiIsoJGASRSWAQRPT/Yp5exvLjY7FWFYFkmUzrNQVhiZGVRqJBmV48eYjIDL2OS8bsIw+w8cITaHS8a2pW2gPjW5RAAVd2ARIREZH+MAgik8AgiEjqxf29WHB2MnbJEiHqCHsqKJwwuvYkVCzYyMDVEZE+3HsRg1/33cXpoNda1y3lMvStXRBD6heFkw27AImIiOjLMQgik8AgiEgHUcSDS4sx9/pCnLbQ/VdrAxsvjKj/Bwrn5gljROZGFEWcvP8K0/bdwaNX8VqvcbGzxKhGxdC9Wn4o5No7BYmIiIgyg0EQmQQGQUQfoUrBhYBpmP14K25b6D5hrINLWXzbYDZPGCMyQ0q1BhsuPMGcIw/wJkGp9Zqi7vaY0LIk6hd3N3B1RERElF0wCCKTwCCIKHM0iW9x+Oh3mBtxFk91HDlvIwJ9PBugb91fYWdpb+AKiehLRScoMf9EEFaeC4FSrf0tlZ9PbkxoURLFPRwMXB0RERGZOwZBZBIYBBF9GuXbUGw+PBL/xN3HWx0njLmIMgwu3g0dq4/lCWNEZijkdTx+P3APB2+/0LouE4Du1fJjVGMfuNlbGbg6IiIiMlcMgsgkMAgi+jyxTy9jxfGxWKOKQJKOE8YKCJYYUXEEGpXx5wljRGbov8eRmLbvDm49jdG67mClwJAGRdG3VkFY6xgdJSIiIkrDIIhMAoMgoi/z8v5eLDw7CTtlSdDoCHvKKZwwutYkVC7EE8aIzI1GI2L71aeYeegeXsYka73G28UGPzQriRZlPRj6EhERkU4MgsgkMAgi0gNRRNClfzD3+kIEZDAJVt/GCyPrzURh9zKGq42I9CIhRYV/Tj3GkoBHSFJqtF5TpUAu/NyqFMp7Oxu2OCIiIjILDILIJDAIItIjVQouBkzF7EfbcMtS+5iITBTRwaUcvq0/C7kdeMIYkbl5Hp2ImQfvY/vVpzqvaV/RE+OaFUdeJxsDVkZERESmjkEQmQQGQUT6Jya+xaGj3+Gvl2cQZqHQeo2NCPT2bIB+PGGMyCzdCH+LqXvv4GLIG63r1hYyDPQtjEF1i8DOSvvfA0RERJSzMAgik8AgiCjrKN+EYPORkVgc9wBvMjhhbFDxbuhcbSws5DxhjMiciKKIg7de4LcDdxEWlaj1GncHK4xtWhydKnlBJuP+QURERDkZgyAyCQyCiLJe3NNLWH5sLNaoX+k8YSz//04Ya8wTxojMTrJKjZVnQzD/+EPEJqu0XlM6nyN+alkKNYu4Grg6IiIiMhUMgsgkMAgiMpyX93Zj0dkp2CHP+ISxUbUmoQpPGCMyO5FxyZhz9AHWn38CjY53ZE1K5cGPLUqioJudYYsjIiIio2MQRCaBQRCRgWk0eHRpMeZeX4iTlrovq2fjhZH1Z6JIbp4wRmRuHryMxa/77uLUg1da1y3kAnrXLIjhDYrByZYjoURERDkFgyAyCQyCiIxElYxLp6Zi9uPtuJnBCWPtc5XFtw1mwd0hn4ELJKIvdfJ+BH7ddxdBEXFa13PZWmBkIx/0qJ4fFnLtY6NERESUfTAIIpPAIIjIuMSENzhy7DvMe3kWT3ScMGYtAv6eDdDfbxrsrRwMXCERfQmVWoMNF8Mw58gDRMWnaL2mSG47TGhZEvWLu3OPMCIiomyMQRCZBAZBRKZB+SYYWw6PwuL4B4jSccJYLlGGQcW7o0u1MTxhjMjMRCcqsfDEQ6w4G4IUtUbrNXWKuuGnViVRwsPRwNURERGRITAIIpPAIIjItMSFX8CK4+OwRv0KiTpOGPMWLDG84nA0LdOb3QNEZiY0Mh4zDt7D/psvtK7LBKBr1fwY3dgHuR2sDFwdERERZSUGQWQSGAQRmaZX93Zh4dkp2C5P1nnCWNn/nTBWlSeMEZmdC8FRmLbvDm6ER2tdt7dS4Nv6RdC/diFYW2jvEiQiIiLzwiCITAKDICITptHg8aV/MPf6IpzI4ISxujZeGFn/DxTNXdZwtRHRF9NoROy89hR/HLyPFzFJWq/xdLbBD81LoFW5vOwAJCIiMnMMgsgkMAgiMgOqZFw+NRWzH2/DDUvtG0rLRBHtXMrh2/p/Ig9PGCMyKwkpKvwbEIx/Tj1ColKt9ZpK+Z3xc6tSqJg/l4GrIyIiIn1hEEQmgUEQkfkQE6Jw9Oh3mBdxDqEfOWGsn99UOFhxw1kic/IiOgkzD93HtivhOq9pWyEfxjUrAU9nGwNWRkRERPrAIIhMAoMgIvOjjArGtiOjsCiDE8acRRm+4QljRGbpZng0pu67gwvBUVrXrRQyDPAtjMH1isDOSnsoTERERKaHQRCZBAZBROYrPvwCVh77Dqs0r3WeMOYlWGIETxgjMjuiKOLQ7ZeYfuAuQiMTtF6T28EKY5v4oHNlb8hk/P4mIiIydQyCyCQwCCIyf6/u7sCic9OwXZ4MtY6wp4zCCaN5whiR2UlWqbH6XCj+Oh6E2CSV1muqFXTBjE7lUMjNzsDVERER0adgEEQmgUEQUTbxvxPG5l1fhOMZnDDma+OJUfVmopg7TxgjMidR8SmYe/QB1p1/ArVG+nbP2kKGsU2Ko1/tQpCzO4iIiMgkMQgik8AgiCibUSbhasBUzHq8HdczOGGsjUs5DKn/Jzx4whiRWQl6GYvf9t/FifuvtK5Xyu+MPzqVR1F3ewNXRkRERB/DIIhMAoMgouxJTIjC8aPfYW7EOYToOGHMSgT8vRqgvy9PGCMyN6cevMJPO28iLCpRsmapkGFko2IY6FsYCrn2/cOIiIjI8LLi/pv/0hMREQBAsHVBwzbLsL3DPvxsVRguarXkmmQBWPr0OFps8MXawOlIUacYoVIi+hx1fXLj4Ag/9K1VULKWotLgj4P30WHROdx/EWv44oiIiMhgGAQREdF7LFwKo0u3XdjfeDkGIxdsNBrJNW8FDWY8WI8262rhwM1V0IjSa4jI9NhZKTCpTWlsHlQTBV1tJes3wqPR6u/T+OtYEJRqfl8TERFlRwyCiIhIKzvvGvi2TwD2V5+GLkpLyLVMEj8VkzHuyp/ovt4P54OPGKFKIvoc1Qq54MAIPwzwLYQPDw5UqkXMPvIAbeafxa2n0cYpkIiIiLIMgyAiIsqQW6n2+Ln/RewoORgNk7VvK3dHFY2vA0Zj8OZmePDqpoErJKLPYWMpx4SWpbBtcC0UyS09Rv7u8xi0W3AWsw7fR7JKOipKRERE5olBEBERfZxMhkLVh2Buv0tY49kGFVJUWi87k/gUnfZ1x0+7u+NF7FMDF0lEn6NS/lzYN9wXg+sVkRwjr9KI+Pv4Q7T++wyuh701ToFERESkVwyCiIgo8yysUaHRr1jd4wzmOldFQaU0EBIFAbve3EKrbc0w5+gIxCRztITI1FlbyPF9sxLY8W0tFM/jIFl/8DIO7ReexfQDd5GkZHcQERGROWMQREREn0ywc0XDtsuxo/0+/GxVCK5axkaSBWD50+NosdEPq3nCGJFZKOfljD3D6mB4w2JQfNAdpBGBxaceo8Vfp3E5NMpIFRIREdGXYhBERESfTeFaGF267cb+xsvxLZy1njAWDQ1mPliPNutqYt/NlTxhjMjEWSpkGN3YB7uG1kbpfI6S9cev4tHpn0BM2XMHiSnsDiIiIjI3DIKIiOiL2eavgcG9U08Y66q0gELrCWMp+OHKLHRb74f/eMIYkckrnc8JO4fUxtgmPrCQv98dJIrA8rPBaDYvAP89jjRShURERPQ5GAQREZF+CALcSrXHT/0vYkfJQWis44Sxu6poDAgYjW82N8P9iBsGLpKIPoWFXIahDYph33BflPdykqyHRiag25L/MHHXLcQna99EnoiIiEwLgyAiItIvmRwFqw/D7H4XscazFSrpuDk8m/gUnff3wITdPfA8NtzARRLRp/DJ44Btg2thfPMSsFRI3z6uDgxFkzkBOBP02gjVERER0adgEERERFnDwgYVGk3Hyp5nMM+pKgrpOGFs95ubaLWtOWYf4QljRKZMIZdhUN0iODDCF5UL5JKsP32biF7LzmP89huISVIaoUIiIiLKDAZBRESUpQQ7VzRotxzb2+/FRMuCcNNywliKAKx4dhzNN/phFU8YIzJpRXLbY/Ogmvi5VSlYW0jfSm64EIamcwJw4n6EEaojIiKij2EQREREBqFwLYLO3fdgX5PlGCI6wVbLCWMx0ODPB+vRel1N7OUJY0QmSy4T8FWdQjg00g/VC7lI1p9HJ6HfiosYs/k6ohPYHURERGRKGAQREZFB2XrXwDd9TmN/tanopuOEsWdiCsb/74SxwODDRqiSiDKjgKsdNgyogaltS8PWUi5Z33YlHI3nnMKROy+NUB0RERFpwyCIiIgMTxDgWroDJvS/iJ0lBqFxsvbOn7uqaAwMGINBm5vhXsR1AxdJRJkhkwnwr1kQh0b6oU5RN8l6RGwyBqy+hBEbryIqnmOfRERExsYgiIiIjEcmR4EawzC73yWszddS5wlj5xKfosv+nvhxdw88i+EJY0SmyNvFFmu+qobfO5SFg5VCsr7r2jM0mXMK+28+N0J1RERElIZBEBERGZ+FDco3/h0re5zG305VUDhFuqeIKAjY8+YmWm9vjllHhiOaJ4wRmRxBENCtWn4cGuWHesVzS9Zfx6Xg23VX8O26y3gVm2yEComIiIhBEBERmQzB3g312q3AtvZ78YtlQeRWSTuEUgRg5bMTaLHRDysDpyNZzZtJIlOTz9kGK/pWxZ+dy8PRWtodtP/mCzSZcwq7rj2FqGWfMCIiIso6DIKIiMjkKNyKolP3PdjbeBmGic6w03HC2KwH69F6XS3s4QljRCZHEAR0quyFo6PronGpPJL1NwlKjNh4DQNWX0ZETJIRKiQiIsqZGAQREZHJss1fCwP7BGB/tanokaLQesLYczEFP16Zha7r/XCOJ4wRmRx3R2ss8a+Med0qIJethWT96N2XaDT7FLZcCmN3EBERkQEwCCIiItMmCHAp3QHjv7qEXcUHoqmOE8buqaIxKGAMBm5phrsvecIYkSkRBAFtK3jiyOi6aFk2r2Q9JkmF77beQL+VF/HsbaIRKiQiIso5GAQREZF5kMmRv+Zw/Nn3Itbna4HKydINpQEgMOEpuhzshfG7u/OEMSIT42ZvhQU9K2FRz0pws7eUrJ+8/wpN5gRgw4Un7A4iIiLKIgyCiIjIvFjaomzjGVjR4zTmO1VGES0njAHA3je30Gp7c/x5dARPGCMyMc3L5sXhUXXRrkI+yVpcsgrjt99Er2XnERaVYITqiIiIsjcGQUREZJYE+9yo224ltrbbjcmWBeCu5YQxpQCsenoczTf6YUXgbzxhjMiEuNhZYm63ivi3dxW4O1hJ1s8+jETTuQFYHRgCjYbdQURERPrCIIiIiMyaIrcPOnTfi72NlmG4xknrCWOx0GD2gw1ova42dt9YCbVGbYRKiUibxqXy4MiouuhU2UuylpCixsRdt9H93/8Q8jreCNURERFlPwYJggoXLozChQtj/vz5hng5IiLKgWwK1MKAvqexv+pk9EyR6zhhLBkTrs5C1w11cTb4EPcgITIRTrYW+LNzeazsVxV5nawl6+eDo9BsXgCWnn4MNbuDiIiIvohBgqDw8HCEhoaiQoUKhng5IiLKqQQBLmU64Yf+l7DbZwCaJWnv/LmvisY3AWMxYEtz3IngCWNEpqJecXccHuWH7tXyS9aSlBpM23cXnf85h4cRcUaojoiIKHswSBDk4eEBALCxsTHEyxERUU4nV8C71gjM7HcRG/K2QFUdJ4ydT3yKrgd64YfdPfA0JszARRKRNg7WFpjeoSzWfV0dXrmk7x2vPHmLFn+dxj+nHkGllo6CEhERUcYMEgRVr14dAHD79m1DvBwREVEqSzuUaTIDy7oHYIFjZRTVccLYvjc30Xp7C8w8OgJvk94atkYi0qp2UTccGumHPjULSNZSVBr8fuAeOi46h/svYo1QHRERkfkySBA0ePBgiKKIOXPmQKnU/iaciIgoqwgO7vBrvxJb2+/GFAvdJ4ytfnocLTbVxfLA6UhSJRmhUiJ6l52VApPblsGmgTVQwNVWsn49PBqt/j6Nv48FQcnuICIiokwxSBDUoEEDjB8/HtevX0erVq0QFsb2eyIiMjy5mw/a99iLvQ3/xQiNI+x1nDA258F6tFpfGztvrOAJY0QmoHphVxwc4Yev6hSCILy/plSLmHXkAdrOP4vbz6KNUyAREZEZEUQDHJkyZcoUAMC2bdtw8+ZNyOVy1K5dG+XKlUOuXLkgl8szfPzEiROzukT6BOHh4fD29gYAhIWFwctLetwrEZHJE0W8ubUVS/77FRstVFB9eHf5P8UUjhhdcyJqF2oCQcc1RGQ4l0PfYNzW63j0SnqcvEIm4Nt6RTC0QTFYKgzyeScREVGWyor7b4MEQTKZ7L03z6IoftKbabWan8aaEgZBRJStqFUIO/83/r75Lw5Y6/5gorqNJ0bV/R2l81QwXG1EpFWSUo25R4OwJOARtJ0mXzyPA2Z2LodyXs4Gr42IiEifzDoI+hIaLa37ZDwMgogoW0qJx+2TUzA7ZBcuWFnovKx5rjIYXncGvJykx1sTkWFdD3uL77Zex4OX0uPkZQIw0K8IRjYqBmuLjLvPiYiITJXZBkGUvTAIIqLsTIx9iTNHv8Ps1xfw0FJ7IKQQgW5e9TGozhQ4WzsbtkAiek+ySo0Fxx9i4clHUGlpDyqS2w5/dCqPygVyGaE6IiKiL5MV998cniYiInqH4JAHvu1XY2u7XZiq8EYeLSeMqQRg7dMTaLGpLpYG/sYTxoiMyEohx+gmxbFraG2UyusoWX/0Kh6d/jmHaXvvIDGF2w0QERExCCIiItJCnrs42vXcj70N/8VItQMctBxNHQsN5j3YgFbra2MHTxgjMqrS+Zywa2htjGnsAwv5+3tRiiKw9Ewwms8LwPnHkUaqkIiIyDQwCCIiIsqAdcE6+KrfWeyvMhH+yXJYaJmofimmYOLV2ei0wQ8Bjw+CU9dExmEhl2FYw2LYO8wX5bycJOshkQnouuQ//LLrFuKTpd1+REREOYHBg6CoqCjMmjULzZs3h7e3N+zs7GBnZwdvb280b94cs2bNQlRUlKHLIiIi0k0Q4FyuK8Z9dQm7fb5CiyTtnT8PVTEYcvo7fL21OW6/vGbYGokoXXEPB2wfXAvfNyuh9Rj5VYGhaDo3AGcfvjZCdURERMZl0M2iFy9ejLFjxyIhIQEAJJ+Yph0pb2tri1mzZmHgwIGGKo0+ATeLJqIcLzkOt09OxpyQPThvnfEJY8Pq/QFvR28DFkdE73oYEYdxW6/jypO3Wte7V8uPH1uUgEMG38tERETGYtanhv3++++YMGFCevjj5OSEihUrwsPDAwDw4sULXL16FdHR0amFCQKmT5+OcePGGaI8+gQMgoiIUokxL3D26DjMibyABx85YWxgncnIZc1Ti4iMQa0RseJsMP48fB9JSul+X3mdrDG9Q1nUK+5uhOqIiIh0M9sg6NatW6hYsSLUajXy5s2LmTNnonPnzrCweP9Ns0qlwpYtW/Ddd9/h2bNnUCgUuHr1KkqXLp3VJdInYBBERPQ+9at72Ht4FOYnheCFQqH1GnvI8JVPV/SsOgo2ChsDV0hEABDyOh7jtt3AhWDt2xB0quyFn1uWgpMtu4OIiMg0mO3x8fPnz4darUbu3LkRGBiIHj16SEIgAFAoFOjevTsCAwPh7u4OtVqN+fPnG6JEIiKizybPXQJtex7AngaLMVrHCWNx6SeM1eEJY0RGUtDNDhsH1MCUtqVhaymXrG+9HI7Gc07h6J2XRqiOiIjIMAwSBB0/fhyCIGD8+PHInz//R6/39vbG999/D1EUcezYMQNUSERE9OWsC/mhX7+zOFDlZ/RJlmk9YSyCJ4wRGZVMJqB3zYI4NNIPtYu6StYjYpPx9epLGLHxKt7EpxihQiIioqxlkCDo6dOnAIBatWpl+jG1a9cGADx79ixLaiIiIsoSggCnct0w9qtL2FOsP1omaT+iOu2Esa+2Nsetl1cNXCQRebvYYu1X1fFb+7Kwt5KOdO669gyN55zCgZvPjVAdERFR1jFIECSXp7beqlTa3wxro1antszLZAY/4Z6IiOjLyS3gWXs0fu97AZs8mqFGkvbOgosJT9H9YG+M3d0dYTFPDFwkUc4mCAJ6VM+Pw6P8UNcnt2T9dVwKBq+7giHrruB1XLIRKiQiItI/g6QsaeNgnzLmlXZtZkbJiIiITJaVA0o1nYl/e5zCYocKKJ6iPRA69OYW2mxvielHhyMqSftGtkSUNfI522Blv6qY2akcHK2l3UH7bj5H49mnsOvaU45zEhGR2TNIENS4cWOIoog///wTN2/e/Oj1t27dwsyZMyEIApo0aWKAComIiLKYgwdqdViDzW134je5Fzy0dMmqBGD90xNosak+/g38DYmqRCMUSpQzCYKAzlW8cWR0XTQqKT1G/k2CEiM2XsPANZcREZNkhAqJiIj0wyBB0MiRI2FlZYW4uDjUqVMHf/75JyIjIyXXRUZG4s8//4Svry9iY2NhZWWFkSNHGqJEIiIig5C5l0TrXgewt/5ijFHbaz1hLB4a/PVgA1pv8MXxoN1GqJIo58rjaI1/e1fBvG4V4KzlGPkjd16i8ZwAbLsczu4gIiIyS4JooH/BVq9ejX79+v3/CwsCChUqBHd3dwiCgJcvXyI4OBiiKEIURQiCgJUrV8Lf398Q5dEnCA8Ph7e3NwAgLCwMXl5eRq6IiMhMiSKib27E0v+mY52lBkpB0HpZQ6cSGN/4b+Sx8zBwgUQ526vYZEzcdQsHbr3Qul6/eG781qEs8jrZGLgyIiLKKbLi/ttgQRAA7Nu3D4MGDXrvJDDhf2963y0jX758WLJkCVq0aGGo0ugTMAgiItIztRLP/vsL828uxV5rOUQtgZAdBAwvMwBdK34LuUxuhCKJcq79N5/j5523EKnlOHkHKwUmtCyJrlW909/XEhER6YvZB0FA6slhO3bswNGjR3Hr1i1ERaVuiOni4oIyZcqgUaNGaNeuHSwspK24ZBoYBBERZZGkGNw7OQV/hu7BeWtLrZeUs3LDxIbzUTx3aQMXR5SzRcWnYNLu29h9/ZnW9TpF3TC9Q1l4u9gauDIiIsrOzDYIevIk9Thce3t7uLi4ZPXLURZjEERElLXEt2HYu3cgZqaE4o1c2v0jF4He+ZtisN9U2Cg4kkJkSIdvv8CEnbfwKlZ6nLydpRw/NC+BntULQCZjdxAREX25rLj/Nshm0QULFkShQoWwceNGQ7wcERGRWROcvdG61wHsrvIL2iVJN5NWC8CKsENov8EPZ4MPGaFCopyrSWkPHB1VFx0rSd+Ix6eo8fOu2+ix9D+ERsYboToiIqKPM0gQZGOT+mll1apVDfFyRERE2YJzua6Y2vc/LHeugYJKpWT9qSYJ3wSMxbjd3fE64bURKiTKmZxsLTCrS3ms6FcVeZ2sJev/PY5C07kBWHYmGGoNTxYjIiLTYpAgyNPTEwCgVqsN8XJERETZh5UDqrb9F1ubrMQ3SisotEx0H3hzC223NMT2a4t5nDWRAdUv7o5Do/zQvZq3ZC1JqcHUvXfQZXEgHr2KM0J1RERE2hkkCGrSpAkA4MyZM4Z4OSIiomzHKn8NDOn3H7YV7IZKydLuoBho8Mv1+ei3uTEeRz0wQoVEOZOjtQWmdyiHtV9Vh6ezdM+uy6Fv0GLeaSw+9QgqtXTUk4iIyNAMsll0UFAQKlasCHt7e1y+fDm9Q4jMEzeLJiIyLs2bUGzfNwCzU54iVi79TMdCBL4u1Bpf15kES7n208eISP/iklX44+A9rA4M1bpe3ssJMzuXh08eBwNXRkRE5spsN4suVqwY1q9fj4SEBNSoUQPr169HSkqKIV6aiIgo25HlKoBOPQ9hd+Uf0TxZ2mGgFIBFIXvQcb0vLoaeMEKFRDmTvZUCU9qWwcaBNVDAVXqM/PXwaLT66wzmHw+Ckt1BRERkJAbpCGrQoAEAIDQ0FMHBwRAEAZaWlihWrBhy5coFuZajcdMLFAQcO3Ysq0ukT8COICIiE5IUgzMHR2Da6//w1EKh9ZL2rhUxptFfcLJ2NmxtRDlYQooKfx56gBXngqHt3XbpfI6Y2ak8SuVzNHxxRERkNrLi/tsgQZBMJoMgCACQ6U0sBUGAKIoQBIGbTJsYBkFERKYnIfQc/jkyDKsVyVD/79/cd7lAju8qjUTLMn3S/00moqx3OTQK3229gcevpMfJK2QCvq1fFEPrF4WlwiCN+kREZGbMNgiqV6/eF73pPHGCbe2mhEEQEZGJUqtwL2AqJj/chFuWFlovqWXjiZ+aLIS3c2EDF0eUcyUp1Zhz9AH+DXgMbafJl/BwwMxO5VHWy8nwxRERkUkz2yCIshcGQUREpk39Jhgb9w7AX6rnSJBJuwysROCbIh3Rp/YEWMi0B0ZEpH/Xwt5i3NbrePBSepy8XCZgkF9hDG9YDNYWurdNICKinMVsN4t+8uQJnjx5gqioKEO8HBERUY4mz1UIPXsdwa6K36OBls2kkwVg3uNt6LreD9fDzxqhQqKcqYK3M/YMq4NhDYpCLnu/W16tEbHw5CO0+vsMrjx5Y6QKiYgoJzBIEFSwYEEUKlQIGzduNMTLERERkSDAo0JvzOt9DnMdKsBdpZJcEqSOg//RQfh1/1eIS441QpFEOY+VQo4xTYpj15DaKJlXulH0w4g4dFp0Dr/uu4MkJffJJCIi/TNIEGRjYwMAqFq1qiFejoiIiNJYO6FhhzXY1XAJeigtIXwwES4KAja+uoC2G/1w9PaGTB/qQERfpoynE3YNqY3RjX1gIX+/O0gjAv+eDkbzeadxIZgd9UREpF8GCYI8PT0BgKd/ERERGYl9QV+M7/cf1nq3hU+KtDsoAiqMuvQbhm9rjRcxYUaokCjnsVTIMLxhMewZVgdlPaUbRQe/jkfXJYGYtPs2ErR83xIREX0OgwRBTZo0AQCcOXPGEC9HRERE2sgtUK7hr9jYYS9GCW6w1kj3DzoZH4q221tg7bnfoNbwAxwiQyjh4Ygd39bCuGbFYSl//+25KAIrz4Wg6dwAnHv42kgVEhFRdmKQU8OCgoJQsWJF2Nvb4/Lly+kdQmSeeGoYEVE2IIoIv7oC0y7PwllL7Z8LlVY44pcGc1EyL0e7iQzlYUQsvtt6A1efvNW63qN6foxvXgIO1jzxj4goJzDbU8OKFSuG9evXIyEhATVq1MD69euRkpJiiJcmIiIibQQBXpX6Y1HPM5hhXwYuWsa3b6ti0P1QP/x58BskpMQboUiinKeouwO2flMLP7UsCSuF9K36+vNP0HROAE49eGWE6oiIKDswSEdQgwYNAAChoaEIDg6GIAiwtLREsWLFkCtXLsjlct0FCgKOHTuW1SXSJ2BHEBFR9hP9+DjmHB+LbRZKrev5YIEJNSbAr3hHA1dGlHMFv47H91tv4EKI9g2ju1TxwoSWpeBkw+4gIqLsKivuvw0SBMlkMghC6mkImX05QRAgiiIEQeAm0yaGQRARUTalVuLy8Z8xOXQXgi0UWi9pal8YPzRdDDd7DwMXR5QzaTQi1vwXihkH7yEhRfqeOI+jFX5rXxYNS+YxQnVERJTVzDYIqlevXnoQ9DlOnDihx2roSzEIIiLK3lJeB2HZ/gH4V/MaSi3/fjuIwKiSfdGx2ijIBINMmRPleGFRCfh+2w2cexSpdb19RU/80roUnG0tDVwZERFlJbMNgih7YRBERJQDiCKCLy/B1KvzcNFS+wh3RUUu/NLobxTJU97AxRHlTKIoYsOFMPy2/y7ikqXHybvZW2Fau9JoViavEaojIqKsYLabRRMREZGZEQQUqjIIy3qexRS7UnDSMqZ9VfUGnQ70xN+HhyFZlWSEIolyFkEQ0KN6fhwa5Qc/n9yS9ddxyfhm7RUMWX8FkXHJRqiQiIjMAYMgIiIi0kmwzYX2nTZhl988tFRJ9w1SCQKWPD+Jjutq43zQHiNUSJTzeDrbYFW/qvijUzk4WEu/L/fdeI7GcwKw5/qzTO/PSUREOYfRgqDw8HBcunQJAQEBSExMNFYZRERElAmuRRvj974XsNijCbxU0u6gUKTg63M/YsKOjngTz2OtibKaIAjoUsUbR0bVRcMS7pL1qPgUDNtwFd+svYyIWHbsERHR/zNoEBQbG4uff/4Z3t7eKFCgAKpXr4769esjODj4ves2btyILl26YMCAAYYsj4iIiDIit0CtprOwvfV2fAVnKLR0GuyOeYA2Wxpi98W/2IlAZAAeTtZY2qcK5natAGdb6THyh26/ROPZAdh+JZzfk0REBMCAm0UHBQWhRYsWePz48Xv/CAmCgJs3b6JUqVLpvxYSEoKiRYtCFEWcOnUKderUMUSJlEncLJqIiCCKeHBxISbfWIgbFto/V6pu4YqfmyxEAbdSWteJSL8iYpMwcedtHLz9Qut6gxLu+K19WXg4WRu4MiIi+lxmu1l0UlISWrZsiUePHsHW1hbjxo3D3r17dV5fsGBB1K9fHwCwe/duQ5RIREREn0IQ4FNtCNb0CMAEWx/YazSSS84rI9FhbxcsOTYGSrXSCEUS5SzuDtZY1KsSFvSoBFc76THyx+9FoPHsU9h08Qm7g4iIcjCDBEGLFi3Cw4cPYWdnh9OnT+P3339HixYtMnxM8+bNIYoiAgMDDVEiERERfQaZrSu6dd6GXXVmobFKesx8iiDg7/DD6LKuJq49PmyEColyFkEQ0LJcXhwe5YfW5fNJ1mOTVfh+2030Xn4B4W8SjFAhEREZm0GCoO3bt0MQBIwYMQIVKlTI1GPKly8PIHWkjIiIiEybe7FmmN3nAv7O0wAeWjaTfigmw//0GEzd2Q0xSW+MUCFRzuJqb4W/u1fEYv/KcLO3kqyfDnqNpnMCsOa/UGg07A4iIspJDBIE3b17FwDQpEmTTD/G1dUVAPD27dusKImIiIj0TWGJes3mYVerzegFJ8i0jJ5sjr6Nthvr4tDVxRxNITKApqU9cHS0HzpU8pSsxaeo8fPOW+i59DyeRLI7iIgopzBIEBQXFwcAsLe3z/RjkpOTAQAWFtLTD4iIiMh02eYpg+97n8b6EgNQUindO+i1IGLsjfkYurERnkWx85coqznbWmJ2lwpY3rcKPBylG0UHPo5E07kBWHE2mN1BREQ5gEGCoLTunpCQkEw/5vbt2wAADw+PrCiJiIiIspIgoHSNEVjf7QTG2hSBjZbNpANSItBud3usOvEDVBqVEYokylkalMiDw6P90LWKt2QtUanG5D130GVxIB6/ijNCdUREZCgGCYIqVaoEAAgICMj0Y1avXg1BEFCzZs2sKouIiIiymMLeHX267MSOmtPhq5K+7UgUBPz5ZB96rK2F26EnDV8gUQ7jaG2BGZ3KYc1X1eDpbCNZvxT6Bs3nncaSgEdQszuIiChbMkgQ1KlTJ4iiiCVLluDJkycfvX7u3LnpoVH37t2zujwiIiLKYp4l2mBB7wuYmdsPbmrpZtJ3xUT0ODEUM3b3QkJyrBEqJMpZfIvlxqFRfuhVI79kLVmlwW/776HDonMIesnvRyKi7MYgQZC/vz/KlSuHpKQk1KtXDwcOHHhvg0hBECCKIi5evIiePXtizJgxEAQBvr6+aN68uSFKJCIioiwmWFihWYsF2NViAzrDQbKuEQSsfXMdbTfUwcnrK4xQIVHOYm+lwLR2ZbF+QHXkd7GVrF8Pe4uWf53BghMPoVRLxzuJiMg8CaKBjux48uQJ6tSpg/DwcAiCAFtbWyQkpJ5O4ObmhtjY2PQNokVRRJEiRXD27Fm4u7sbojz6BOHh4fD2Tp0tDwsLg5eXl5ErIiIisyOKuHpuFibfW4FHCu2fSzW2yosfmi2Gu3MhAxdHlPMkpKgw89B9rDwXAm13B2U8HTGzU3mUzOto+OKIiHKwrLj/NkhHEADkz58f165dQ/fu3SGTyRAfHw9RFCGKIl69eoWkpKT0LqEuXbrgwoULDIGIiIiyK0FAxdpjsaXrcQyzLghLLXuRHEl+jrY7WmPTqYnQiOxGIMpKtpYK/NK6NLYMqonCbnaS9VtPY9D67zOYc+QBUlT8fiQiMmcG6wh6V2hoKPbt24dLly4hIiICarUarq6uqFixIlq3bg0fHx9Dl0SfgB1BRESkb6F3tmNq4GScV2i/wSwv2OKXBnNQzKuWgSsjynmSlGrMOfIA/55+DG37RZfwcMDMTuVR1svJ8MUREeUwWXH/bZQgiMwbgyAiIsoKojIJew6PwMyXZ/BWLm1aVogi+rpVwaAmC2BtKe1YICL9uvrkDcZtvYGgCOlx8nKZgEF+hTGiUTFYKeRGqI6IKGcw69EwIiIioowIFtZo03IxdjdbgzaiNOhRCQKWRl5Gh/W1EHhrvREqJMpZKubPhb3D62BI/SKQy4T31tQaEQtPPkLLv87g6pM3RqqQiIg+B4MgIiIiMim58lXCr30CsbRwD+TXshdJmKDBwMvTMX5zc0TFhBuhQqKcw0ohx3dNS2DXkNoo4SE97e9hRBw6LjqH3/bfRZJSbYQKiYjoUzEIIiIiItMjCKjuOx7bOx/BQCtvKLRMsu9NDEebbc2x48w0cNKdKGuV8XTC7qF1MLJRMSg+6A7SiMCSgMdoMe80LoVEGalCIiLKLAZBREREZLKsHPNhWLf92FLlZ1RUCZL1aBkw8dEmfLW2FoKfXjBChUQ5h6VChpGNfLBnWB2U8ZQeI//4dTw6Lw7E1L13kJjC7iAiIlPFIIiIiIhMXtEyXbGy93n87FINDhrpuNhFTRw6HumPRfu+Rooy0QgVEuUcJfM6Yue3tfFd0+Kw/GBjd1EElp0JRvN5AbgQzO4gIiJTxCCIiIiIzILMwgZdWi/DrkbL0VS0lawrBQELX59Hp3U1cfnuFiNUSJRzKOQyDKlfFPuG10EFb2fJekhkArouCcSk3beRkKIyfIFERKQTgyAiIiIyK7m9q+PPPv9hQcHOyKeWdgcFC2r0vTAFkza3QnTccyNUSJRzFMvjgG2Da2F88xKwVEi7g1aeC0Gzuafx3+NII1VIREQfYhBk5p4+fYq5c+eiSZMmyJ8/PywtLeHh4YGOHTvi/Pnzxi6PiIgoawgC/OpOxI6OB9HH0hMyLZtFb0sMRdstTXAg8A9uJk2UheQyAYPqFsH+4b6omN9Zsv4kKgHdlvyHX3bdQnwyu4OIiIxNEPnOyKz98MMPmDFjBooUKYJ69eohd+7cCAoKws6dOyGKItavX4+uXbvq9TXDw8Ph7e0NAAgLC4OXl5den5+IiOhT3b2xFpMu/oE7Cu1va2rLHPFT4/nw8qho4MqIcha1RsSKs8GYeeg+klXSjj1vFxvM6FgOtYq4GaE6IiLzkxX33wyCzNz27dvh6uqKunXrvvfrp0+fRsOGDWFvb4/nz5/DyspKb6/JIIiIiEyROjkeGw5+i7+iLiFRJm16ttaI+NbDF70az4WFQn//LhKR1ONXcRi39QYuhb7Ruu5fowB+aF4CdlYKA1dGRGResm0QdP36dTx8+BCCIKBw4cKoUKGCsUvKFpo2bYrDhw/j4sWLqFKlit6el0EQERGZshdPzuLX4yNxUkjSul5cVOCXmpNQtnhbA1dGlLOoNSJWngvBzEP3kKSUdgd55UrtDqpdlN1BRES6ZMX9d5bsEfTgwQM8ePAAERERGV53/PhxlChRApUqVUKXLl3QuXNnVK5cGYULF8auXbuyorT3REREYO/evZg4cSKaN28ONzc3CIIAQRDQt2/fT3qu0NBQjBkzBiVKlICdnR1cXFxQtWpVzJw5EwkJCVnzBXyEhYUFAECh4CctRESUc3jkr42//P/DnPzt4K5lM+n7ggo9Aydg+tZ2iIvP+L0KEX0+uUzAV3UK4eAIP1Qr6CJZD3+TiJ5Lz+PHHTcRm6Q0QoVERDmT3juCbty4gQoVKkAQBKxYsQK9e/fWet2hQ4fQunVrqNVqrRs4ymQyrF69Gj169NBnee8RBEHnWp8+fbBy5cpMPc+ePXvQq1cvxMTEaF338fHBvn37ULRo0c8p87M8efIEPj4+cHFxQVhYGORyud6emx1BRERkLmLfBOOvAwOxKeU5RC3/7rtrgB9L9kfDGqOMUB1RzqHRiFgdGIIZB+8jUamWrHs62+D3jmXhWyy3EaojIjJdZtERdPjwYQCAk5MTunfvrvWahIQE9O/fHyqVCqIowsXFBf7+/vj+++/RsGFDAIBGo8HQoUMRFRWl7xK1yp8/P5o0afLJj7t69Sq6du2KmJgY2Nvb49dff8W5c+dw7NgxDBgwAEBqh1TLli0RGxur77K1UiqV8Pf3R3JyMmbMmKHXEIiIiMicOOQqhAk9jmBNhTEoJr33RIQMGHl/OUas9cWLl7cMXyBRDiGTCehbuxAOjvRF9ULS7qCnbxPhv+wCxm+/gRh2BxERZSm9zwxduHABgiCgZcuW6aNJH1q/fj2eP38OQRBQunRpHD58GB4eHunrK1euRP/+/REdHY1169Zh2LBh+i4TADBx4kRUrVoVVatWRZ48eRASEoJChQp90nOMGDECiYmJUCgUOHz4MGrWrJm+1qBBAxQrVgzjxo3DgwcPMGvWLEyaNEnyHGPGjEFycvInvWaxYsW0rmk0GvTt2xcBAQEYMGAA/P39P+nrISIiyo7KV+iHTSU7Y/WBb7DozTUky97vDjqufovz+7thWL766NZwNuQK7e9hiOjLFHC1w4YBNbD2fCh+P3APCSnvJ7QbLoTh5P1X+L1jOdT1YXcQEVFW0PtoWNmyZXHnzh38888/6R0xH2ratCmOHDkCQRBw7Ngx1KtXT3JNy5YtceDAAbRs2RJ79uzRZ4k6vRsEZWY07MKFC6hevToAYNCgQfjnn38k12g0GpQpUwZ3796Fs7MzIiIiJAGZvb094uPjM13niRMntP6eaTQa9O/fH6tWrUKvXr2watUqyLScmvKlOBpGRETmLCz4JKaeGotAQfuHMGVFC/xSexqKF2th4MqIcpYnkQn4ftsNBD6O1LretYo3JrQqCUdrBrNElHOZxWhYeHg4AKBkyZJa1zUaDc6dOwdBEODl5aU10ACALl26AABu3TLdNu2dO3em/7xfv35ar5HJZOn7JL19+xYnTpyQXBMXFwdRFDP9Q1cI1K9fP6xatQrdu3fHypUrsyQEIiIiMnfehephsf95TPdqCRctm0nfFJToenYcZm/riMRE7UdfE9GXy+9qi3VfV8e0dmVgZyndymDTpTA0mR2AE/e4qTsRkT7pPSmIi4sDADg6Ompdv337dnr3S926dXU+T4kSJQAAkZHaPyEwBWfOnAEA2NnZoXLlyjqve/frPHv2rN7rSAuBVq9eja5du2LNmjXcF4iIiCgDglyOVg1/x662u9BeIR0/UQsCVsQ9QPsNfjh7cb4RKiTKGWQyAb1qFMDBkX6oXdRVsv4iJgn9Vl7E2C3XEZ3AvYOIiPRB70GQtbU1AOjcGPn8+fPpP88oPEl7nqSkJD1Wp193794FABQtWjTDI9rTQq13H6MvaeNgq1evRufOnbF27VqGQERERJnk7FoUU3oex/Kyw1BQy2bST+XAN3cWY9y6enj9+p7hCyTKIbxdbLH2q+r4rX1Zrd1BWy+Ho8ncUzh296URqiMiyl70vll03rx58ejRI1y7dg21a9eWrJ8+fTr95zVq1ND5PG/epLZi29vb67tEvUhKSsLr168B4KMzerly5YKdnR3i4+MRFham1zqmTJmCVatWwd7eHj4+Ppg2bZrkmnbt2qFChQqZfs608T5dnj9//qllEhERmbSqlQZiW6muWLp/IJbG3Ibyg6PmD6gicWZPJ4zxaoL2Df6ATK73t1BEOZ4gCOhRPT/8fNwwfvtNnA56/d76y5hkfLXqEjpU9MTE1qXgbGtppEqJiMyb3t/FVKlSBQ8fPsSKFSswZMiQ99bi4+PTN352cHBAlSpVdD7P/fv3AXw8ZDGWdzueMhNWpQVBaaNz+hISEgIgdSTv119/1XpNwYIFPykIStuIioiIKCextHbCtx02odnjo5gS8D0uCynvrcfKBEx6dgS711THL34zULhwIyNVSpS9eeWyxer+1bDpYhim7buLuGTVe+vbrz7F6Yev8Vv7smhcKo+RqiQiMl96Hw3r3r07AODq1asYMGAAYmJiAKRulNy3b1+8ffsWgiCgU6dOGY4wBQQEAABKly6t7xL14t2RNUvLj38aYWVlBQBITEzUax0rV6786ObSffv21etrEhERZWeFCzfCcv/zmJyvCRw10s2krwgp6BgwEgu2d0Vy0lvDF0iUAwiCgG7V8uPwKD+tx8i/ik3GgNWXMHLjVbyJT9HyDEREpIveg6DWrVujdu3aEEURy5cvR+7cueHp6Qk3Nzds374dAGBhYYHvv/9e53MkJCRgz549EARB63iZKUjbwwgAUlI+/o9PcnLqEbU2NjZZVpO+hIWFZfjjwoULxi6RiIgoS8nkCnRoPAu7Wm9DC4V0A1uVIOCf2DvotN4XFy8vMUKFRDlDPmcbrOxXFX90KgcHa+kww85rz9B4TgAO3X5hhOqIiMxTlpwvvm3bNpQpUwaiKEKpVOL58+fQaDQQRREymQwLFy5EsWLFdD5+1apV6SNUTZs2zYoSv5iDg0P6zzMz7pV2Upqp7nn0Li8vrwx/5M2b19glEhERGYSbWwnM6HkS/5QaBE8tm0mHyIH+t/7Gz+sa4G3kQ8MXSJQDCIKALlW8cXiUH+oXl3YHvY5LxqA1lzF8w1VEsTuIiOijsiQIcnd3x+XLl/H333+jcePGKF68OMqWLYs+ffogMDAQ/fv3z/DxR44cQeXKldGmTZsMAyNjsra2hqtr6ieEH9tc+c2bN+lBEPffISIiMj+1qw7Fjm6n0N++OOSiKFnfqXqFtrvaYs+x7yGqtSRGRPTF8jrZYHnfqvizc3mt3UG7rz9DkzmncOAmDzYhIspIlh15YWFhgSFDhkg2jM6MtBEyU1eqVCmcPn0aDx8+hEql0nmE/L17/3/cbMmSJQ1VHhEREemRja0LRnXcihZB+zH57E+4KSjfW4+Sy/Bj+H7sWXMMP9f9E96F6hmlTqLsTBAEdKrshTpF3TBhx00cuxfx3vrruBQMXncFLcvlxZQ2peFqb2WkSomITFeWdATlFHXq1AGQOvZ1+fJlndedOnUq/eemuucRERERZU7xYi2wptd5jPeoDzuNtDsoUEhG+5NDsXRnTyiT9XtaKBGl8nCyxtI+VTCna3k42VhI1vfdeI7GcwKw7wa7g4iIPsQg6Au0a9cu/ecrVqzQeo1Go8Hq1asBAM7Ozqhfv74hSiMiIqIsJFdYoEfTv7CzxUY0lDtL1pNlAuZF30DXdTVx/Zr29whE9GUEQUD7il44MsoPjUpKj5GPik/BkPVX8O26y3gdl2yEComITBODoC9QrVo1+Pr6AgCWLVuGwMBAyTWzZs3C3bt3AQAjRoyAhYX0EwsiIiIyTx55ymBur9OYV6I/3KUnzSNIDvhfm4Vp6xsj9m2IwesjygncHa3xb+/KmNetApxtpe+19998gcazT2HP9WcQtezxRUSU0wiinv82nDJlij6fDgAwceJEvT8nAJw5cwYPH/7/CR+vX7/Gd999ByB1hOvrr79+7/q+fftKnuPq1auoXbs2EhMTYW9vjx9//BH169dHYmIiNm7ciCVLUo+U9fHxwaVLl947bcxchYeHp296HRYWBi8vLyNXREREZHzx8a/w94EBWB/3EKIgSNZzqzUYX6g9GtWdAkHGz+KIskJEbBJ+2nELh++81LretHQeTG1XBu4O1gaujIjo82TF/bfegyCZTAZBy5ufL6HOotM3+vbti1WrVmX6el2/VXv27EGvXr0QExOjdd3Hxwf79u1D0aJFP6tOU8MgiIiISLdb93dhUuAk3BdUWtfridb4scFc5M3PfQOJsoIoithz4zl+2XULbxKUknVnWwtMblMabcrn0/t9CxGRvmXF/XeWfRwliqJefpiD1q1b48aNGxg1ahR8fHxga2sLZ2dnVKlSBTNmzMDVq1ezTQhEREREGStTvC029jqPMe51YKNlM+mTQhLaHhuINbt6Q50cb4QKibI3QRDQpnw+HB5VF83LeEjW3yYoMWLjNQxccxkRMUlGqJCIyLiyrCPI2toabdu2Re/evb/4yPQCBQroqTrSB3YEERERZc7TF9cw7cgQnNFo7xoupRYwqer3KFm2p4ErI8o59t14jp933UJUfIpkzcnGApPalEK7Cp7sDiIik2QWo2GNGzfGiRMnoNFo0v8yrVy5Mvz9/dGtWzfkzp1bny9HRsAgiIiIKPNEUcShwD8w/f5aRGnpxZaJInpZeWJIi6WwdfI2fIFEOUBkXDIm7r6t8zj5RiXd8Wv7ssjjyL2DiMi0mEUQBADPnj3DunXrsHbtWty8eTP1hQQBCoUCTZs2Ra9evdC2bVtYWVnp+6XJABgEERERfbrouOeYs38AtiWGal3Pq9bgpyJd4Oc3EWBnAlGWOHAztTvodZy0O8jRWoGJrUujYyV2BxGR6TCbIOhdN27cwKpVq7Bhwwa8ePEi9UUFAY6OjujcuTN69eoFPz+/rCyB9IxBEBER0ee7fHcLppz/FY8F7YdhNBFt8UOjv5Dbq7qBKyPKGaLiU/DL7tvYc/2Z1vX6xXNjeody8HBidxARGZ9ZBkFpNBoNjh49itWrV2Pnzp1ISEhIT9rz588Pf39/9OrVCz4+PoYoh74AgyAiIqIvk6JMxPLDw7HkVSCUWjoPHDQajHSrjk7NFkBmYWOEComyv4O3XuCnnbfwOi5ZsuZgrcDPrUqhc2UvdgcRkVGZdRD0rvj4eGzfvh2rVq3CyZMn39tPqFatWjh9+rShS6JPwCCIiIhIP0KeXsSU48NxUROndb2CSsAv1SegaJmuBq6MKGd4E5+CyXtuY+c17d1Bfj658XuHssjnzECWiIwj2wRB73r27BmWL1+O3377DUlJSbC2tkZCQoIxS6KPYBBERESkP6IoYteZafjz4SZEy6SdBwpRRD/r/BjYchmsHfIaoUKi7O/w7ReYsPMWXsVKu4PsrRT4qWVJdK3qze4gIjK4rLj/1nJ2heEEBgZi2rRpmDt3LpKTpX/pEhEREWV3giCgne/P2N3hAFpbS9/cqQQB/yaHoePmRvjv9G+AcT/DI8qWmpT2wJFRfuhQyVOyFpeswg/bb6L38gt4+jbRCNUREemXwYOgR48eYfLkyShWrBjq1KmDxYsXIyoqClZWVujSpQu2bt1q6JKIiIiIjM7FyRu/dT2AJZV/gLcofYv2RCHDgMcbMGF1Tbx5dsUIFRJlb862lpjdpQKW9amCPI7S041PB71G0zkBWH/+CYw8VEFE9EUMMhr25s0bbNy4EWvWrMH58+cBpLZBC4IAX19f+Pv7o3PnznB0dMzqUkgPOBpGRESUtZJS4rHk8BCseH0JKi2jKM5qDcbmqYM2TeZBsODJRkT6Fp2gxNR9d7D1crjW9TpF3TC9Q1l4u9gauDIiymnMao8gpVKJPXv2YM2aNThw4ACUSmV6cu7j4wN/f3/4+/sjf/78WfHylIUYBBERERlGUNhZTD4xGtdF7fsnVlfJ8HPNX1CgVAcDV0aUM5y4F4Hx22/iRUySZM3OUo4fWpREz2r5IdOyvxcRkT6YRRB05swZrF27Flu2bMHbt2/Twx9XV1d069YNvXv3RtWqVfX5kmRgDIKIiIgMRyNqsCXgF8x9vANxWm42LTUiBtoWQv+WS2Fhn8cIFRJlb9GJSvy67w42X9LeHVSzsCv+6FSO3UFElCXMIgiSyWQQBAGiKMLKygpt2rSBv78/mjVrBoVCoc+XIiNhEERERGR4EW+D8fvBQTiS/FzrehGVBr+U6IeKtcYAPNmISO9O3k/tDnoeLe0OsrWU44fmJdCregF2BxGRXplVEGRtbY2mTZvC2dn5i55PEAQsW7ZMP8XRZylduvR7/61UKhEUFASAQRAREZGhnbq+AtOuzsULQaN1vRMcMKrpIjh6lDdwZUTZX0ySEtP338WGC2Fa16sXcsEfncqhgKudgSsjouzKrIIgfVKr1Xp9Pvo0DIKIiIhMS0JyLOYfGox1Udeg0fK+y1Wtxg8e9dC08RwIFtLTj4joywQ8eIXx229qPU7exkKO75sVR++aBdkdRERfzGyCIH3TaLR/4kXGwdEwIiIi03A79CQmnxqHu6L0ZhQAfFUyTKg9DZ4lWhu4MqLsLzZJiekH7mH9+Sda16sVTO0OKujG7iAi+nxZcf+t99RGo9Ho/QcRERERSZUuUA/re53Dd/lbwkbLZ3unFRq0D/wBqza3gyouwggVEmVfDtYW+K19Waz9qjo8nW0k6xdCotBsXgCWnQmGWpMlBzUTEX0W/bfvEBEREZHBKGQK9K7/O3a22Y66lu6S9USZDH8mPkKPjfVx+7+5gH6bwYlyvDrF3HBolB/8axSQrCUpNZi69w66Lg7E41dxRqiOiEiKQRARERFRNpDPxQd/dzuKWeWGwk2U7kty10KGHveWYsZqX8S9uGmEComyL3srBaa2K4P1A6rD20XaHXQp9A2azzuNpacfszuIiIzO5IOgy5cvG7sEIiIiIrMgCAKaVByEXV1PoYtTKcm6RhCwFtFos68r9u8ZADE53ghVEmVftYq44eAIP/SpKe0OSlZpMG3fXXT+5xwesTuIiIzIZIOgc+fOoXnz5qhevbqxSyEiIiIyK442ufBzu01Y4zsLRQXpqWGvFHJ8H/Uf+q+pjoeXl3JcjEiP7KwUmNy2DDYOrIH8LraS9StP3qL5vNNYfOoRu4OIyChMLgg6duwY6tevD19fXxw+fBh6PtSMiIiIKMeoULgJNvcMxHCvJrDU8p7qkoWAzjfn4s81dRH/8pYRKiTKvmoUdsXBkb7oW6ugZC1FpcH0A/fQcdE5PIyINXxxRJSj6f34+DSiKGLHjh04evQowsLCYGFhgYIFC6JTp06oVauW5PqTJ0/ixx9/xPnz59MfDwBNmjTBwYMHs6JE+kw8Pp6IiMj8PHl9F9OPDMWZFO2nh+VWqTHWvTaaN5kLwYrHXRPp04XgKIzbeh0hkQmSNUuFDKMa+WCAbyEo5Cb3OT0RGVlW3H9nSRAUGhqKtm3b4uZN7RsRdu7cGevWrYNcLkdkZCS+/vpr7N69G0BqACQIAtq0aYMJEyagSpUq+i6PvhCDICIiIvMkiiKO31qNPy7PwTNBrfWaqkrgx4ojUbRSf0CQbjpNRJ8nMUWNPw/fx/KzwVqnMct7OWFm5/LwyeNg+OKIyGSZRRCUkpKCypUr4/bt27pfVBAwZswYDBs2DHXr1kVoaChEUYRcLkeXLl3w448/onTp0vosi/SIQRAREZF5S1QmYOnxsVjxPABKLWGPQhTRU+aCwU0Xwi5PGSNUSJR9XQqJwndbbyD4tXSzdku5DCMaFcMgv8LsDiIiAFlz/633v13WrVuH27dvQxAEFCxYEEuXLsX58+dx9epVrF+/HhUrVoQoili0aBF69OiBkJAQiKKIjh074s6dO1i3bh1DICIiIqIsZGNhi2FNF2JHy02obekuWVcJAlaJb9Bmbxcc2DuQp4sR6VGVgi44MMIXA3wLSZruUtQazDx0H+0XnsO9FzHGKZCIsj29dwS1bt0a+/btg7e3N27fvg17e/v31jUaDfz8/HDu3DkAgFwux7Jly9C7d299lkFZiB1BRERE2Ycoijh+czVmXJmD5zrGxaopgfGVRqFoxX4cFyPSo8uhb/Dd1ut4/EoatlrIBQxvUAzf1CsCC3YHEeVYZtERdP36dQiCgO+++04SAgGATCbDlClTAKSOiPn7+zMEIiIiIjISQRDQsFwf7Op5DgPz+MJCy2eEFyyAzjdmY9aaejxdjEiPKhfIhf3DfTGobmHIPshYlWoRs448QLsFZ3H3ObuDiEh/9B4ERUZGAgDKlNE9T16uXLn0n3fq1EnfJRARERHRJ7KxsMWwZhmPi60UozguRqRn1hZyjG9eEtsG10JRd+kH6befxaD132cw9+gDpKg0RqiQiLIbvQdBiYmJAAB3d+kbiDRubm7pP+dYEREREZHpKJC7NBZ1O4q5FcciryiXrEco5BgXGYiv19TAoyvLofX4IyL6ZBXz58LeYXUwuF4RSXeQSiNi7tEgtF1wFrefRRunQCLKNow+bKpQKIxdAhERERG9491xsQF56ugcF+uUNi4WwXExIn2wtpDj+2YlsOPb2vDJI+0Ouvs8Bm3nn8XsI+wOIqLPZ/QgiIiIiIhMk42FLYY3W/S/cbHckvX0cbE9HBcj0qfy3s7YM6wOhtYvCvkH7UEqjYi/jgWhzfwzuPWU3UFE9On0fmqYTCaDIAgYPHhwhuNhkyZNytR1ADBx4kR9lkhfiKeGERER5Typp4utxIwr8zI8XezHSqNRpGJfni5GpCc3w6Mxdst13H8ZK1mTywR8W68IhjYoCiuFdJSTiMxfVtx/Z1kQpE9qtfY3G2QcDIKIiIhyrkRlAv49NgYrX5yGUst7PoUoopfMFd80WwA7d92HhxBR5qWoNJh/PAgLTj6CWiO9fSuexwEzO5dDOS9nwxdHRFnKLI6PB1I/MdLXDyIiIiIyHf8/Lrbxo+NiB/cOgpiSYIQqibIXS4UMo5sUx64htVHCw0Gyfv9lLNovPIc/Dt5DsoofohNRxvTeEXTq1Cl9Ph0AoG7dunp/Tvp87AgiIiIiIHPjYtWVwPjKY1LHxYjoi6WoNFhw4iEWnHgIlZbuoGLu9pjZuTwqeDsbvjgi0juzGA2j7I9BEBEREb0rdVxsNFa+OKNzXMxf5opBHBcj0pvbz6Lx3ZYbuPM8RrImE4ABfoUxqpEPrC24dxCROTOb0TAiIiIiyjlSx8X+wfaWG1BLx7jYivRxsW84LkakB6XzOWHX0NoY3dgHFvL3A1iNCCw+9Rgt/zqNK0/eGKlCIjJVDIKIiIiISC8K5i6Lf7odw5wKo5BXlHYhRCjk+C7yLAasro7H11YZoUKi7MVCLsPwhsWwe2gdlPF0lKw/ehWPTovO4bf9d5Gk5N5BRJSKQRARERER6Y0gCGhUvj929jiLAXlqQaFlF4LzFkDHazMxe009xEfcNkKVRNlLybyO2PFtbYxtor07aEnAY7SYdxqXQ6OMVCERmRIGQURERESkd7aWdhjebDF2tNyAWpZuknWVIGCFJhJt9nTGwX0cFyP6UhZyGYY2KIa9w3xR1tNJsv74dTw6/ROIaXvvIDGF3UFEORk3i6aPKl269Hv/rVQqERQUBICbRRMREdHHiaKIYzdWYMbVv/Aig9PFfqwyFoUr9DFwdUTZj0qtwZLTjzH3SBBS1BrJeiE3O/zRqRyqFnQxQnVE9Cm4WTQRERERmZ20cbFdmRwXS3h11whVEmUfCrkM39Yrin3D66C8lmPkg1/Ho8viQEzecxsJKSrDF0hERsWOIPpkPD6eiIiIvkRIxA1MPzYC51Jea113V6nxnUddNG08C4KlrYGrI8peVGoNlp4JxuwjD5CiknYHFXC1xR8dy6F6YVcjVEdEH8OOICIiIiIyewXdy+Gfbscxu/xIeOg6Xez1mf+dLrbaCBUSZR8KuQzf1C2C/cProIKW7qDQyAR0XfIfJu1mdxBRTsEgiIiIiIgMThAENK7wFXb1OIuv3TMaF/sDs9fU57gY0Rcq6u6AbYNr4ccWJWClkN4GrjwXgmZzTyPwUaQRqiMiQ2IQRERERERGY2tphxHNF2N78/WoaaHrdLHXaL27Iw7u+xZiSqIRqiTKHuQyAQP9imD/CF9ULpBLsv4kKgHd//0PP++8hfhkdgcRZVcMgoiIiIjI6ArlKYfF3T82LnYaA1ZXw+Pra4xQIVH2USS3PTYPqomfWpbU2h205r9QNJ0bgHMPte/jRUTmjUEQEREREZmE9HGx7mfwtXtN3eNiV2dwXIzoC8llAr72LYyDI/1QtaC0Oyj8TSJ6LD2PCTtuIo7dQUTZCoMgIiIiIjIptlb2GNF8yf/GxaQnGb07LnZo/xCOixF9gUJudtg0sCZ+aV0K1hbS28N155+g6ZwAnAlidxBRdsEgiIiIiIhMUuq42AnMLj8CeUTp29YIhRxjXwVgIMfFiL6ITCagX+1CODjCD9UKuUjWn75NRK9l5zF++03EJimNUCER6RODICIiIiIyWanjYl9jd/ez+ErHuNh//xsXm7OmARJe3TNClUTZQ0E3O2wcUAOT25SGjYV0r64NF1K7gwIevDJCdUSkLwyCiIiIiMjk2VrZY+RHxsWWa16hze4OHBcj+gIymYA+tQri0Eg/1Cgs7Q56Fp2E3ssv4PutNxDD7iAis8QgiIiIiIjMRtq42Kxyw7WOi718b1xsrREqJMoe8rvaYv3XNTC1bWnYWkq7gzZdCkPTOQE4cT/CCNUR0ZdgEEREREREZkUQBDSpOCB1XCx3jQzGxX7HnLUcFyP6XDKZAP+aqd1BtYpIO/GeRyeh34qL+G7LdUQnsjuIyFwwCCIiIiIis2RrZY+RLf7F9ubrUEPXuJia42JEX8rbxRbrvq6OX9uXgZ2W7qAtl8PRZM4pHL/30gjVEdGnYhBERERERGatUJ7yWJKpcbHqeHx9nREqJDJ/giCgZ/UCODTKD3WKuknWX8Yko//KSxiz+TqiE9gdRGTKGAQRERERkdnL3LiYiI5Xp2Pu2gZIeHXfCFUSmT+vXLZY81U1TO9QFvZWCsn6tivhaDznFI7eYXcQkaliEERERERE2UbauNi25mt1jostU79Cm93tcfjAUIjKJCNUSWTeBEFA92r5cWiUH3yLSbuDImKT8fXqSxi16RreJqQYoUIiygiDICIiIiLKdgrnqYAl3U/gz3JDdY6LjYk4hUGrqiH4xnojVEhk/jydbbC6fzX80bEcHLR0B+24+hSNZgfg0O0XRqiOiHRhEERERERE2ZIgCGhacRB2dz+L/rmrax0XC7QQ0eHKb6njYq8fGKFKIvMmCAK6VPXG4dF+qFc8t2T9dVwyBq25jOEbriIqnt1BRKaAQRARERERZWu2VvYY1WLp/8bFXCTraeNibXe147gY0WfK62SDFX2rYmancnCwlnYH7b7+DE3mnMLBW8+NUB0RvYtBEBERERHlCKnjYifxZ7mhcNcyLvaC42JEX0QQBHSu4o0jo+qiQQl3yfrruBR8s/YKhq6/gsi4ZCNUSEQAgyAiIiIiykHSxsX2ZGpcrCESXgcZoUoi8+bhZI1lfapgVufycNTSHbT3xnM0mROA/TfZHURkDAyCiIiIiCjHSR8Xa7Ya1XWOi0Wg7a62OHJgGMfFiD6RIAjoWNkLR0bXRaOS0u6gyPgUfLvuCoasu4LX7A4iMigGQURERESUYxX2qIR/u5/EzLJDdI6LjY44yXExos+Ux9Ea//augrldK8DJxkKyvu9manfQnuvPIGrp0CMi/WMQREREREQ5miAIaFbpG+zpfhb93KplOC42j+NiRJ9MEAS0q+iJI6P90KRUHsl6VHwKhm24isFrr+BVLLuDiLIagyAiIiIiIqSOi41uuSx1XEyhfVxsKcfFiD6bu4M1FvtXxl/dKyKXrbQ76ODtF2g85xR2XXvK7iCiLMQgiIiIiIjoHYU9KuHfHh8fF/tmdTWE3NxghAqJzJcgCGhTPh8Oj6qLZqU9JOtvE5QYsfEaBq25jIhYhq1EWYFBEBERERHRBzIzLnZOIaL95V85Lkb0GXI7WGFRr0qY36MiXOwsJeuH77xE49kB2HmV3UFE+sYgiIiIiIhIh/fHxXJJ1t8dFzt6cDhEJfc3IcosQRDQqlw+HB7lh5Zl80rWoxOVGLnpGoZtuIroBKURKiTKnhgEERERERF9ROq42CnMLPutznGxUS9P4JvVVRFya6MRKiQyX272VljQsxIW9KgEVy3dQXtvPEezeQE49/C1Eaojyn4YBBERERERZULquNhg7O5+Bv3cquoeF7s0DX+ta4SEyIdGqJLIfLUslxeHR/mhdfl8krXn0UnosfQ8ft13B8kqtRGqI8o+BJEDl/QRpUuXfu+/lUolgoJS5+DDwsLg5eVljLKIiIiIjOrRiyuYfmwkzqveaF33UKnxvWcjNGw4E4KFlYGrIzJv+248x487biI6UToSVsLDAfO6VURxDwcjVEZkWOHh4fD29gagv/tvdgQREREREX2GImnjYmUyHhcbvLoqQm5tMkKFROarZbm8ODTSD3WKuknW7r2IRev5Z7DsTDA0GvY1EH0qdgTRJ8uKRJKIiIjInMUnx2Lx0eFY8+oiVIIgWbcQRfS18MDXzRbD1rWIESokMk8ajYgV50Iw4+A9pKg0kvU6Rd3wZ+fy8HCyNkJ1RFmPHUFERERERCbIzsoBo1uuwNZmq1FNy+liSkHAv6qXaLezNY4eHMHTxYgySSYT8FWdQtg9tDZKaBkFO/PwNZrODcD+m8+NUB2ReWIQRERERESkJ0U8KmFpj1OYWWYw3EVpZ9BzhRyjXh7H4NVVEcpxMaJMK+HhiJ1DamOAbyHJWnSiEt+uu4KxW64jNonHzBN9DIMgIiIiIiI9EgQBzSp/m3q6mGsVraeLnVWIaH9pKv5a1xgJkY+MUCWR+bG2kGNCy1JY93V1eDhKR8G2Xg5Hi79O41JIlBGqIzIfDIKIiIiIiLKAnZUjRrdKGxdzlqynjou9QLudrXHs4EiOixFlUu2ibjg40hcty+WVrIVFJaLL4kDMOnwfSrV0TyEiYhBERERERJSlUsfFAjIcFxv58hgGr66G0NubjVAhkflxtrXE/O4VMadreThYKd5b04jA38cfouOic3j8Ks5IFRKZLgZBRERERERZ7N1xsb6ulXWMi2nQ/uIU/LWuMRI5Lkb0UYIgoH1FL+wf4YtqBV0k6zfCo9HyrzNYdz4UPCyb6P8xCCIiIiIiMhA7K0eMabUSW5uuynBcrO3O1jh2aBTHxYgywdvFFhsG1sC4ZsWhkL3fdZeoVGPCjlv4etUlvI7j9xMRwCCIiIiIiMjgiuStjKU9AvBH6W+QW9e42IujHBcjyiS5TMC39Ypix7e1UTi3nWT92L0INJsbgGN3XxqhOiLTwiCIiIiIiMgIBEFA8ypDsCdT42JNkBj12AhVEpmXsl5O2DfMF/41CkjWXsel4KtVlzBhx00kpKiMUB2RaWAQRERERERkRGnjYluarERVneNiz9FuRyuOixFlgo2lHFPblcGKvlXhZm8pWV93/gla/XUGN8LfGr44IhPAIIiIiIiIyAQUzVcFyzIYF3v23rjYFiNUSGRe6pdwx8GRfmhUMo9k7fHreHRYeA7zjwdBreFG0pSzMAgiIiIiIjIR746L9XGplMG42GT8vZ7jYkQf42ZvhX97V8b0DmVhYyF/b02lEfHn4QfoujgQYVEJRqqQyPAYBBERERERmRg7K0eMbb0qw3GxJcq0cbHREFUphi+SyEwIgoDu1fJj/whflPdykqxfCn2D5vNOY+vlcB4zTzkCgyAiIiIiIhOVNi42o9TADMbFjuDbVVXx5M5WI1RIZD4Kudlh6+BaGN6gKD44ZR5xySqM3XIdQ9ZfwZt4BquUvTEIIiIiIiIyYYIgoEXVYdjd7TT6uFSCXEvHwhmFBu0uTOK4GNFHWMhlGN2kOLZ8UxPeLjaS9f03X6DZvACcDnplhOqIDINBEBERERGRGbC3dvrfuNgKVJFLx1veHRc7fngMx8WIMlC5gAsOjPBD58pekrWXMcnwX3YBU/bcQZJSbYTqiLIWgyAiIiIiIjNSLF9VLO95OsNxsRHPD3NcjOgj7K0UmNm5PBb1rARnWwvJ+vKzwWg7/yzuPo8xQnVEWYdBEBERERGRmXl3XKy3S8UMx8Xmr2+KxKhgI1RJZB6al82LQyP94FvMTbJ2/2Us2s4/i6WnH0PDY+Ypm2AQRERERERkpuytnfBd69UZjostVj5Dux0tOS5GlIE8jtZY1a8afmldCpaK92+TU9QaTNt3F72Wncfz6EQjVUikPwyCiIiIiIjMXNq42O+lBsAtg3GxIauq4smdbUaokMj0yWQC+tUuhL3D6qBkXkfJ+rlHkWg6JwB7bzwzQnVE+sMgiIiIiIgoGxAEAS2rDseeDMbFTis0aHfhF8xf34zjYkQ6+ORxwM4htTDIrzCED3LVmCQVhq6/itGbriEmSWmcAom+EIMgIiIiIqJsJHPjYk//Ny42luNiRFpYKeQY36Ik1n1dHXmdrCXr268+RfO5p3EhOMoI1RF9GQZBRERERETZUObGxQ5hyKpqCLu7wwgVEpm+WkXccHCEH1qXzydZe/o2EV2XBOKPg/eQotIYoTqiz8MgiIiIiIgom8rcuJga7f77OXVc7E2I4YskMnFOthb4u3tFzOtWAQ5WivfWRBFYePIROiw6i4cRcUaqkOjTMAgiIiIiIsrm0sfFGi9HZS3jYimy1HGx9ttb4MSR7zguRqRF2wqeODDSF9UKuUjWbj2NQau/T2NNYAhELYErkSlhEERERERElEMU86yGFRmMiz1VyDH82cH/jYvtNHyBRCbOK5ctNgyogR+al4CF/P3voSSlBj/vuo1+Ky8iIjbJSBUSfRyDICIiIiKiHOTdcTF/lwoZjIv9hAXrmyHpbagRqiQyXXKZgG/qFsGOb2ujqLu9ZP3k/VdoNvc0Dt9+YYTqiD6OQRARERERUQ5kb+2Eca3XYHPjZagsd5Ssp8gE/KN8inbbmv9vXIxHZRO9q4ynE/YOq4M+NQtI1qLiUzBwzWWM334D8ckqI1RHpBuDICIiIiKiHMzHszpW9DyD6SW/znBcbOiqqhwXI/qAtYUck9uWwYp+VZHbwUqyvuFCGFr+dRpXn7wxQnVE2jEIIiIiIiLK4QRBQKtqI7CnawD8c2kfFwtIGxfb0JzjYkQfqF/cHQdH+KJJqTyStZDIBHT6JxDzjgZBpeYx82R8DIKIiIiIiAgAYG/jjHFt1mBz46WopGtcLCUc7bY1x8kj4wA1R16I0rjaW2Gxf2XM6FgWtpby99bUGhFzjj5A58WBCI2MN1KFRKkYBBERERER0Xt8PGtgZc8zmF7yK53jYsOeHcCQlVUQdm+XESokMk2CIKBr1fzYP9wXFbydJetXn7xFi3mnsfliGI+ZJ6NhEERERERERBKp42IjsbvrKfTKaFwscAIWbGiOhNdBRqiSyDQVdLPD1m9qYmSjYpDL3g9T41PUGLftBgavvYKo+BQjVUg5GYMgIiIiIiLSycEmF77PxLhYk93t8Nf6Jnj1+LgRqiQyPQq5DCMb+WDLNzVRwNVWsn7w9gs0mxuAUw9eGaE6yskEkf1o9BGlS5d+77+VSiWCglI/8QkLC4OXl5cxyiIiIiIiAxNFEXsvzsOsO8sRKWi/jVCIIlrADr3LDUTxCv0AGT97JopPVmHq3jvYeDFM63rfWgXxQ/MSsLaQa12nnCs8PBze3t4A9Hf/zb+ViYiIiIgoUwRBQOtqI7Gn6yn0ylVe67iYShCwW0hAp5tzMWB5BZw58RPEZG6OSzmbnZUCv3csh8X+lZHL1kKyvvJcCFr/fQa3n0UboTrKadgRRJ8sKxJJIiIiIjI/958G4q+ACQhIyXi0pYhSjd7uNdDSbzKsnL0NVB2RaYqIScJ3W29oHQmzkAsY26Q4vvYtLNlbiHKmrLj/ZhBEn4xBEBERERG96/HLa1hzdir2xNxHsqD75tVFrUY3u8LoVusn5PKuYcAKiUyLKIpY818oft13F8kqjWS9RmEXzOpSAZ7ONkaojkwJgyAyCQyCiIiIiEibqLiX2HRmMja+OIMoHXsIAYCVRoM2Mmf4VxqKQmW6ARmER0TZ2cOIWIzYeA23n8VI1hysFZjWrgzaVvA0QmVkKhgEkUlgEEREREREGUlWJWHvhdlYHbQNj5Hx8dh11Qr08emCKtVHQ7CwMlCFRKYjRaXBnKMP8M+pR9B2d962Qj5MaVsGTjbSvYUo+2MQRCaBQRARERERZYZG1ODsnY1YdW0RzqveZnhtSaUGvfPVRVO/SbCwdzdMgUQm5PzjSIzefB1P3yZK1vI5WWNWlwqoWcTVCJWRMTEIIpPAIIiIiIiIPtX98ECs/u837I8LhiqDUTB3lRo9HUuik98vcMxTzoAVEhlfTJISv+y6jR1Xn0rWBAEY6FcYoxv7wErBY+ZzCgZBZBIYBBERERHR54qIfoINZyZh86uLiMlgayBbjQYdLNzRs+poePm04j5ClKPsvv4MP+24iZgklWStVF5HzOtWAcXyOBihMjI0BkFkEhgEEREREdGXSkiOw67/ZmBN8F6ECdKb3TQyUURD0Rp9SvqjfNUhgFxhwCqJjOfZ20SM2XwdgY8jJWtWChnGNy+BPrUKQmBImq0xCCKTwCCIiIiIiPRFrVHj5I3lWH1zOa5o4jK8trwK6OPdBA18f4bcxtkwBRIZkUYjYumZx5h56D6Uaumtu59PbvzZqRzcHa2NUB0ZAoMgMgkMgoiIiIgoK9wMPopVF2biSOJTaDLocvBUaeDvUh7t/abA1rWoASskMo47z2IwctNVPHgpDUtz2VpgeodyaFbGwwiVUVZjEEQmgUEQEREREWWlp5H3se7MFGx/cwPxGUy9OKg16Gztie41foBH4QaGK5DICJKUasw4eA8rzoZoXe9SxQsTW5eGvRXHJ7MTBkFkEhgEEREREZEhxCa+wfZz07A27CheCBqd1ylEEc1gjz7lBqBEhX6ATGbAKokMK+DBK4zdch0RscmStfwutpjTtQIqF8hlhMooKzAIIpPAIIiIiIiIDEmpTsGRy4uw6t463BETM7y2mkqGPoXboE7tHyCztDNQhUSG9SY+BeO338TB2y8kazIBGNqgGIY1KAoLOUNRc8cgiEwCgyAiIiIiMgZRFHE5aA9WXZ6LU8kREDPYR6iQSgN/9+po7TcF1k58v0rZjyiK2Ho5HJN230Z8ilqyXsHbGXO6VkAhNwai5oxBEJkEBkFEREREZGwhL69j7dkp2BVzH0kZBEK51Bp0tSuEbrUnwtWrmgErJDKMJ5EJGLX5Gi6HvpGs2VrKMbFVKXSt6s1j5s0UgyAyCQyCiIiIiMhUvI1/ic2np2D989OIlOm+tbHUiGgtd4Z/xaEoUqYrwJtiykZUag0WnXyEuceCoNZIvw8al8qD3zuUhau9lRGqoy/BIIhMAoMgIiIiIjI1Kapk7D8/C6sebsNDpGR4bR21An18uqB69dEQLHhjTNnHtbC3GLXpGoJfx0vW3OytMLNzOdQv7m6EyuhzMQgik8AgiIiIiIhMlSiKCLyzAauuLcI51dsMry2uEtE7nx+a+06ChT1vjil7SEhRYdq+u1h//onW9d41C2B885KwsZQbuDL6HAyCyCQwCCIiIiIicxAUHojVgb9hX3wwlBmMguVWa9DDsQQ6+06CU56yBqyQKOscvfMS32+7gch4aYdckdx2mNetIsp4OhmhMvoUDILIJDAIIiIiIiJz8jr6CTacnoTNry7ibQanadtoNGhn4Q7/qmPgXbyV4QokyiKvYpPx/bYbOH4vQrKmkAkY3cQHg/yKQC7jnlmmikEQmQQGQURERERkjhJT4rAncAZWB+9FqKDSeZ0gimggWqNPqd6oUOVbCHKFAask0i9RFLHu/BNM23cHSUqNZL1aQRfM6lIe3i62RqiOPoZBEJkEBkFEREREZM40oganri3H6lvLcEkTl+G1ZVVA7/xN0ajOT1DYOBumQKIs8OhVHEZuvIabT6Mlaw5WCkxpVxrtKnjymHkTwyCITAKDICIiIiLKLm4HH8PqCzNxKDEc6gxugPOpNOjpUh4d/CbD3rWYASsk0p8UlQZ/HQvCwpMPoeWUebQqlxe/tisLJ1sLwxdHWjEIIpPAIIiIiIiIspsXUQ+w7vRkbI26jrgM9kux12jQ0coLPWv+gLyF6huwQiL9uRgShVGbriH8TaJkLa+TNWZ1Lo9aRd2MUBl9iEEQmQQGQURERESUXcUlvsGOc79ibdgRPBOk+6mkkYsimgj26FN2AEpX6AfIMtiFmsgExSQpMWn3bWy/8lTr+gDfQhjbtDisFDxm3pgYBJFJYBBERERERNmdSq3E0csLsfreOtwUpV0T76qskqFP4TaoW3s8ZJbccJfMy94bzzBhxy1EJyolayU8HDCvW0UU93AwQmUEMAgiE8EgiIiIiIhyClEUcS1oD1ZfnotjyREQM9hHqIBKA3/3GmjjNxk2TnyPTObjeXQixm65jrMPIyVrlgoZvm9WAv1qFYSMx8wbHIMgMgkMgoiIiIgoJwp7eQNrzk7Bzph7SMwgEHJSa9DFrjB61P4Zbl7VDFgh0efTaEQsPxuMPw7eR4paOhZZp6gb/uxcHh5O1kaoLudiEEQmgUEQEREREeVk0fER2HJ6MjY8P40Ime7bKQtRREtZLvhXGgKf0l0BHstNZuDeixiM3HgN917EStacbCwwvUNZtCib1wiV5UwMgsgkMAgiIiIiIgKUqhQcPD8Lqx5uxX2kZHhtLbUCfXy6omaN0RAUlgaqkOjzJCnV+PPQfSw9E6x1vVNlL/zSuhQcrHnMfFZjEEQmgUEQEREREdH/E0UR5+9sxOpri3Ba9SbDa4uqRPTOVxct/SbB0i63gSok+jxnH77GmM3X8SImSbLm7WKDOV0qoEpBFyNUlnNkxf03zzgkMjN9+/aFIAgoWLCgsUsxK5GRkRg7dixKliwJGxsbCIIAQRAwd+5cY5dGREREZk4QBNQo3R0LewZgZ8Ml6GhbEJY6Pm9/qBAwMSIATTbVw+KtnfD25U0DV0uUebWLuuHgSF+01DIKFhaViC6LAzHr8H0otewpRKaLQRCRFtHR0ViwYAFatGiBggULwtbWFk5OTvDx8UHPnj2xadMmqNVqY5dJmRQdHY2aNWti1qxZuHfvHpKSpJ9oEBEREelDEa+amNR5Dw6334/BrtWQS8f9caRchvnx99F4fzdMXVcfIQ/2GbZQokxytrXE/B4VMbtLedhbKd5b04jA38cfouOic3j8Ks5IFdKn4mgYfbLsPhr277//Yvz48YiMlB6d+K5SpUph8eLFqFOnjoEqS9W3b1+sWrUKBQoUQEhIiEFf21z99ttvmDBhAgBg3LhxaN26NZydnQEAefPmhaurqxGrIyIiouwsKSUeewN/x+rgPQgWdH+QKIgi6sIGvUv6o0qVbyHIFTqvJTKWsKgEjN58DRdDpCOQNhZy/NSqJHpUyw+BG6PrDfcIIpOQnYOgsWPHYtasWQAAhUKBbt26oU2bNihQoABSUlJw//59rF+/HsePHwcAWFlZYe3atejUqZPBamQQ9OkaNGiAEydOoEqVKrh48aKxyyEiIqIcSCNqcObacqy+tQznNRl3TpRSAX3yN0Vj359hYe1koAqJMketEfHPqUeYc+QBVBppnNCwhDtmdCoHN3srI1SX/XCPIKIstHDhwvQQyMvLCxcvXsSaNWvQuXNnVKtWDXXq1MFXX32FY8eOYd26dbC0tERycjJ69eqFa9euGbd4ytDTp08BAD4+PkauhIiIiHIqmSCDX8WvsdQ/EJvrzkVra08odHwmf0cBfP/sEFqsr4WVO3siNuqRgasl0k0uEzCkflHs+LY2Cue2k6wfuxeBZnMDcOzuSyNUR5nBIIgIQGhoKMaMGQMAsLOzw7Fjx1ChQgWd1/fo0QPLly8HACQnJ8Pf3x9srjNdycnJAAALCx5vSURERMZXsmBD/Nb1IA623oH+zuXgoKWrAgBeyGWYFX0DjXa1wYwNTfE0+ISBKyXSrayXE/YN84V/jQKStddxKfhq1SVM2HETCSkqI1RHGWEQRARg7ty56RsIT5w4MVOdIz179kSzZs0AALdu3cLevXsl19SrVw+CIKBevXoAgKCgIAwdOhTFihWDra0tBEGQjHfdvXsXffv2hbe3N6ytreHt7Y0ePXp88kjTixcvMGHCBFSpUgUuLi6wsrKCt7c3unTpgqNHj+p8XEhISPqJWitXrgQAbN++HS1atEC+fPmgUCjSv540Dx48wLBhw1CmTBk4ODjA0tIS+fLlQ4UKFdC/f39s2rQpPYz5XHv27EGnTp3g5eUFKysruLq6ombNmvj9998RFydtrz558mT61xEaGgoAWLVqVfqvvfv/5VNpNBps2LABHTt2RP78+WFjYwMbG5v0zcS3bt0KpVKp9bEpKSlYuHAh6tevj9y5c8PS0hIeHh5o0aIF1q5dC41G94kLH54Y9/btW0ycOBGlS5eGnZ0dnJ2d4efnh3Xr1ml9/JQpU9K/9qCgoI9+nU2bNoUgCMibN6/OzdF37tyJzp07I3/+/LC2toazszOqVKmCyZMn480b3cfnfvi1PH/+HN9//z1Kly4NBwcHCIKAkydPvveYJ0+eYPDgwShUqBCsra2RL18+tGvXDidOpL4pnjRpUvrXl5Ho6GhMnz4dtWvXTv9/kDdvXrRu3Rpbt27NMNRNe/5JkyYBAC5evIju3bun/7n09PSEv78/7t69m2ENaW7duoVhw4ahbNmyyJUrFywsLODh4YFGjRrhjz/+wPPnz3U+9nO/x4mIyHTkcS2GUW3X4Wi3M/jBsym8NNpvzxJkMqxNeYYWp4Zh7KoauHF1GcAPIckE2FjKMbVdGSzvWwVu9paS9XXnn6DVX2dwI/yt4Ysj3USiTxQWFiYCEAGIYWFhxi7ni2k0GtHFxUUEINrY2Ihv377N9GMPHjyY/nvRvn17yXrdunVFAGLdunXFnTt3inZ2dunXp/0IDg5Ov37Tpk2ilZWV5BoAokKhEJcuXSr26dNHBCAWKFBAZ11r167V+lrv/vjqq69EpVIpeWxwcHD6NcuXLxf9/f0lj61bt2769Zs3bxYtLS0zfC0A4s2bNzP9+/quxMREsX379hk+d758+cSrV6++97gTJ058tKZ3v47MCg4OFitUqPDR5z5x4oTWx5YoUSLDx9WpU0eMjIzU+trv/r+/d++eWLBgQZ3PM2TIEMnjg4KC0tcnTZqU4df54sULUS6XiwDEkSNHStajoqLEBg0aZPi1uLu7i4GBgR/9WgIDA0U3N7cMfw+PHTsm2tvba30dQRDEX3/9Vfzll1/Sf02Xo0ePiq6urhnW3aJFCzE2Nlbr49Ou+eWXX8QFCxaICoVC63PY2tqKp06d0lmHSqUSR40aJQqCkGEtffr00fr4L/keF0VRXLFixXtfCxERmQaVSikeuTBP7LWyqlhmZZkMf/gvKycePfGzqEqON3bZRKIoiuKr2CTxq5UXxALf75X8KDJ+n/j3sQeiSq0xdplmJyvuv7kVPeV4t2/fRlRUFADA19cXTk5OmX5so0aNYGNjg8TERJw5c0bndU+ePEGvXr1ga2uLn3/+Gb6+vpDL5bh48SLs7e0BpHYW9OzZEyqVClZWVhg1ahRatGgBKysrnD9/Hr/99hsGDx6MUqVKZVjT5s2b00fVChcujKFDh6JUqVLInTs3QkJCsGzZMuzfvx/Lli2Do6MjZs+erfO55s6dixs3bsDX1xeDBw+Gj48P3r59m97F9PLlS/Tr1w8pKSlwd3fH0KFDUaNGDbi5uSExMREPHz7EqVOnsHPnzkz/nn6oT58+2LFjBwCgfPnyGDNmDEqWLImoqChs3LgRK1euxLNnz9CwYUPcuHEDnp6eAICqVavi5s2bAFI7W549e4a2bdti2rRp6c9tZyedac7Iy5cvUbt2bTx79gxA6ibUffr0QYkSJSAIAoKDg3H8+HFs2bJF8ti4uDg0bNgQjx8/BgC0a9cO/fv3R758+RAcHIz58+fj1KlTOHPmDFq3bo2AgADI5XKtdSQkJKB169aIjIzETz/9hEaNGsHe3h5Xr17F5MmTER4ejgULFqB169Zo2rRp+uOKFi2K6tWr4/z581i/fj1++eUXnV/rpk2b0ruAevbs+d5acnIyGjVqhCtXrkAul6NHjx5o0aIFChUqBKVSiYCAAMyePRsRERFo0aIFrl69igIFCmh9nbi4OHTs2BFJSUmYMGECGjduDFtbW9y8eRN58+YFADx+/Bht2rRBfHw8FAoFBg8ejHbt2sHR0RG3bt3CzJkzMWHCBFSvXl3n1wMAZ8+eRfPmzaFUKpEnTx4MGzYM5cuXR758+fDs2TNs2rQJa9euxf79+9GnTx9s27ZN53MdOnQIFy5cQNmyZTFixAiULVsWiYmJ2LFjB+bNm4eEhAT4+/sjKCgIlpbST8cGDhyYPl6aN29eDB06FLVq1YKTkxNevXqFCxcuYOvWrVpfW5/f40REZFrkcgUaVR2ORlWH4/qDPVh9aQ6OpkRAo6Xb9apcg6uhO+D9eDt6uVdHO7/JsHXMPoe4kPlxs7fCv72rYMOFMEzdeweJyv/vKFdpRPx5+AFO3n+FOV0rwNvF1oiVEjuC6JNlt46gtWvXpn89P/zwwyc/vkaNGumPf/r06XtraR1B+F/XSmhoqM7nqVKlighAtLCw0NpJEB4eLnp5eaU/n7aOoFevXolOTk4iALF///46uwF+/PFHEYAok8nEe/fuvbf2bkcQALF3796iRqM9uV+2bFmmOn4SEhLEhIQEneu67N27N/35GzZsKCYnJ0uuWbJkSfo1Xbp00fo8BQoUyLC7IrPe7UyaMWOGzutiY2PFqKio935t7Nix6Y/96aefJI/RaDRiz549069ZuHCh5Jq0LhoAopOTk3jr1i3JNUFBQaK1tbUIQGzTpo1k/a+//kp/josXL+r8GqpXry4CEH18fCRraX9+nJ2dxUuXLml9fEhIiJg3b14RgNijR48MvxZ7e3vx2rVrOmtp165d+rU7duyQrMfHx4vVqlV778/th1JSUtI7qJo1aybGx2v/9PTdP0+HDx+WrL/7Gi1atND6Z3LatGnp12zfvl2yvmvXrvT1mjVrhuat+gAAkF9JREFUim/evNH5tT958uS9/9bH97gosiOIiMichL24Lv6+raNYbXnpDDuEai0rJc7Z2Ep8GX7B2CUTiY8iYsU2f5/W2h1UeuJBcculMJ33GPS+rLj/5h5BlOO9fv06/eceHh6f/Pg8efKk/zwyMlLndb///jvy58+vde3ixYu4dOkSAGDQoEHw8/OTXOPp6Zl+qpkuixYtQnR0NDw9PbFw4UIoFNqb/iZPngxPT09oNBqsXr1a5/M5Oztj/vz5OvdcefHiBQAgV65cKFOmjM7nSdtD51MtWLAAQOomzytWrNDaWTFgwAA0atQIQOpeRhntqfIl7t+/n97Z1K5dO4wbN07ntfb29siVK1f6fycnJ2Pp0qUAgNKlS6fvL/MuQRCwcOFCuLq6AgDmz5+fYT1Tp05F6dKlJb9etGhRtGvXDgC0dql17do1vdNI115Cjx49wvnz5wFIu4Hi4uLS/79MnToVlStX1vocBQoUwM8//wwA2LJlC+Lj43V+LePGjUP58uW1rj179gx79uwBAHTq1Cn9a3uXra0tlixZovP5AWDjxo0ICQmBtbU1Vq9eDVtb7Z9CDRgwANWqVQOA9D2ytLG2ttb5Z3L48OHpv3769GnJ+u+//55e99atW+Hs7KzzddKOCk2j7+9xIiIyfV55yuH7DltxpMtxjM7jhzw6thOMkcuwLCkETY/0w4Q1vrh/exO4jxAZS+Hc9tg6uBaGNygK2Qe3EnHJKozdch1D1l/Bm/gU4xSYwzEIoo8qXbr0ez8aNGhg7JL0KjY2Nv3naWNan+Ldx8TExGi9xtLSEp07d9b5HO9u7NqvXz+d17Vv3z7Dm8bdu3cDAFq1agUrKyud1ykUCtSsWRMAEBgYqPO61q1bw8HBQed62ujOmzdvsGvXLp3XfQ6VSoVTp04BAJo0aSK5IX7XgAED0h/z4QbD+rJv3770TYRHjRr1SY+9fPky3r59CyB1k2RdI1+Ojo7o0qULAODOnTs6Qy1BENCjRw+dr5cWzkRFRaW/bhp3d3c0btwYQOr4l7bNqdevX5/+8w9f59SpU4iOjgaQGsxkJC3QVCqVuHz5ss7rPgyb3nXixIn0ETV/f3+d15UvX15nmAT8//dG3bp1kTt37kzVndH3RuPGjeHu7q51zcHBAcWKFQOA9FHANJGRkfjvv/8ApIZy+fLly7CWD+nre7xv374QRRGiKGoNJomIyPQ42rmjX7MFOOB/Cb8X7YGSovTDCABQCQJ2a96i06Vp+HplZQScnQ6NijfbZHgWchlGNymOLd/UhLeL9EPh/TdfoNm8AJwOemWE6nI2BkGU470bdGg7fepj3n2Mo6Oj1muKFSsGa2trnc+RtpeNpaVlhjezFhYWqFixotY1tVqNa9euAQAWL1783ulY2n6k7T+S1tWjTbly5XSuAUCbNm3Sg6n27dujQYMGmDNnDi5fvqzzlKnMevz4MRISEgDgo3u/vLt+69atL3pdXa5evQog9f9BjRo1Pumx79akj6/Fzc0tvXNIGxcXl/Sfvxt0pkkLXp4/f47jx49L1tOCoOrVq6No0aLvraV1rgGpQWBGf8be7RLT9efM3t4ehQsX1vm1vPt7oKv7KE2VKlV0rqXVfejQoY9+b/z5558Z1gwAJUqUyLCWtP8HH/7+X7t2LT1Q9PX1zfA5PpQV3+NERGR+LBRWaFl7PDb1uYTllX9EPXkundeelykx5OF6tF9dCdsODEFyPG+4yfAqF3DBgRF+6FxZuofVy5hk+C+7gCl77iBJ+WX3D5R5DILoo27fvv3eD203jubMzc0t/eefc8P08uXL9J/rujl/d0xIm7TNql1cXHR2i6R5dxTtw+dQqVQZPlabtLBFm4/V7erqit27d8PT0xOiKOLEiRMYPXp0+nHWHTp0wN69ez+5JuD/f08A6Oy8SPPuSN+7j9OntBFCFxcXreNAGdH316JrrCmNTPb/f7VrC+TatWuX/hwfjodduXIF9+7dA6C9UyciIiLD19ZF15+zjDrcALx3BP3HOnkyWv+cuhMTE3WuZfb/wYe//++OoqZ11GVWVnyPExGR+RIEAVXLdMffvQKwu+G/6GxbEFY6RsEeywVMighAk031sGhbJ0S9zJoPzoh0sbdSYGbn8ljYsxKcbS0k68vPBqPt/LO4+1z7hAXpF08Noxzv3a6XtK6PzFKr1bhx4waA1JtQXWMeHwt30ujaiyeztaT5+uuvMWLEiEw9LqNQIzN1+/r64uHDh9i2bRv279+PgIAAhIeHIyYmBjt27MCOHTvQtGlTbN++/aM3z7p8ye+LqTGFr8Xe3h5t27bFhg0bsH37dixatCi9Yy2tG0gul6Nr166Sx7775+zKlSuwsJD+Q66Nl5f2U0wy+73xpdLqbt68Of744w+DvKa+ZcX3OBERZQ+FvGpgYuc9GBYThk0Bk7Dh1QVEafnIP0ouw8K4+1h6oBtaW7ijd9XRKOzTyuD1Us7VomxeVMqfC99tvY7TQa/fW7v/MhZt55/FuGbF0b92Icg+3FyI9IZBEOV4ZcqUgYuLC6KiohAQEIDo6OhMHyF/9OjR9E/bP3XM411pnTeRkZFQq9UZ3hy/24H0rnfHgURRzHDzZn2ztrZGz5490ztIgoODsW/fPvz999948OABDh06hAkTJmDOnDmZfs53vx5dX3Oadzu53n2cPqV1jkVFRSElJeWTbq4//Fp8fHx0XmuIrwVI7fbZsGEDYmJisHfvXnTq1AkajQYbN24EoHsPnHe73nLnzq0z4NGXd7vSXr16BU9PT53Xvnqlu93d1dUVz549Q0pKikG/Nz70bgfip25sbszvcSIiMg+5HL3xTatl6JcSj/2Bv2P14z14KJN2B6cIArapXmFb4Hj4nZuM3iX9Ua3KEAgG+oCGcjYPJ2us6lcNK8+F4PeD95Ci+v89K1PUGkzbdxfH70VgVpfyyOv06QfO0MdxNIxyPEEQ0Lt3bwCpoyD//vtvph/7999/p/+8b9++n11D2bJlAQApKSm4fv26zutUKlX6HiEfsrS0TD9F6uzZs59diz4UKlQIQ4cOxcWLF9ODgs2bN3/ScxQuXDi9gyjtBCtdLly4kP7zrLo5rlSpEoDUjY8z2kRYm3drMoWvBQCaNm2aHkqkdQGdOnUKT58+BaB7A+d396gyxJ+zd09Gy2jDaeD9/Ys+lFb3pUuXkJJivA0zK1asmN4VFhAQ8EmPNaXvcSIiMm1WlnZoX3cqtve9in/KjURNme4DUQKEJHx97190WVURe46MhTIx2oCVUk4lkwnoX6cQ9gytgxIe0sNpzj2KRNM5Adh745kRqsv+GAQRARgxYkT6CTyTJ0/Gw4cPP/qYjRs3Yt++fQBSb9hbtfr8ttq0488BYNWqVTqv27Fjx3t7pnyoTZs2AIB79+7h0KFDn12Pvjg6OqJq1aoA3t8bJTMUCgXq1q0LADhy5AjCw8N1Xpt2NLtCoUC9evU+r9iPaNmyZfoN/Ny5cz/psZUrV07fC2fVqlVaT+oCUjcWTgvMSpUq9cl7yHwKhUKRfkLZ/v378fbt2/RAyNbWVusx7UDqn9W0gO6vv/5K3/g4q9SrVy99v501a9bovO769esZhqhp3xvR0dFYsWKFfov8BC4uLqhVqxaA1HD02bNPe3Njat/jRERk2gRBQO2KX2GJfyC21p2HtlaeUOj4t/ueXMSPzw6h2YZaWLazF6KjHhm4WsqJins4YNfQ2hjoVxgf7qAQk6TC0PVXMXrTNcQkKY1TYDbFIIgIQMGCBTFz5kwAqaeANWzYMMObys2bN6NPnz4AUj+lX7NmzRft/VKtWrX0jpNFixbhzJkzkmueP3+OsWPHZvg8I0aMSD/Ovl+/frh9+3aG1+/bty99j6PPcejQoQzHW6Kjo9M7XAoVKvTJzz9kyBAAqZ1SX331f+3dd1gUV9sG8HuXLgqICjYELNi7osYCNiyAXbEDdqPGaDTWWGKMMRqjscRYwIrRWMHesGBHUUHFCgpWrChF2nx/8O28C1tY6qJ7/65rr4ycM2eeGZbgPp7znKFITlb8BeDj44OjR48CAHr06JFvyRMHBwd0794dALB3717x/aJMXFxchoSdkZERhg0bBiB9F6x58+YpnCMIAsaOHSsmzMaOHZuX4Sslm/Xz+fNn+Pn5YdeuXQCArl27iu+jzCwsLMTYzp8/jwkTJqhMbAHpS+FkibqcKF++PFxdXQEAO3fuxN69exX6JCQkYMSIEWrH8fT0hI2NDQBg0qRJWc7GCQoKwunTp3MWdBamTJkCIL2Ic+/evfHhg+p/ec2cAM2rn/ENGzaIu4tx+3giIt1Q1a4Nful7GEfd92C4RR2YpSlPCL3Sk2Lphxtov68LFmzrgKjIUwUbKOkcI309TO9cHVuHNkEZc8WdlneHPEWnpWdxOSJ/NoXRRUwEEf2/cePGicVXnzx5gkaNGmHw4MHYuXMnrly5gvPnz8PHxwft2rWDh4cHkpKSYGRkhK1bt6JevXq5vv6qVaugr6+P5ORktG/fHtOnT0dQUBCuXLmCFStWoGHDhnj+/Lna7eWtra2xceNGSCQSPH/+HI0aNcLo0aPh7++Pa9eu4dKlS9i1axemTJmCSpUqwc3NDU+ePMlxzNu2bYOtrS1cXV2xbNkynDhxAiEhIThz5gxWrVqFZs2aiUuNRo0ale3xXV1d0bt3bwDA0aNH0bRpU2zduhVXr17F8ePHMWzYMDHBYmlpiSVLluT4XjSxatUqsSD4jz/+iLZt22Lz5s24cuUKgoODsXPnTowZMwYVKlRQSCTOmjVL3CZ9zpw56NWrFw4cOIBr165h165daNOmDTZt2gQAaNasWZaJjbzwzTffiAm6GTNmiMkrVcvCZH7++Wdxm/tly5ahQYMGWLlyJc6dO4fr168jMDAQK1asQLdu3VChQgWsXr06V3EuWbJEnIXUu3dvfPfddwgMDMTVq1exceNGNGrUCJcvXxZnnyljZGSEHTt2wMjICJ8+fUKbNm0wcOBA7Ny5E1evXsWVK1fg7++P2bNno06dOmjZsiVCQ0NzFbcq7u7uGDp0KID0ZFqNGjWwYMECnDlzBtevX8fx48fx22+/oX79+pg5c2aGcwv6Z5yIiL4+pUpUwXddt+JYv3OYUa4DKqQp/0iYIJXCL+kZ3E6NxYSNzXA9ZD2QzzOBSbd9U7kkDo9vBbc6iv+w+/R9AjzWXMDvmWoKUQ4JRNkUFRUlABAACFFRUdoOJ8/9/fffgqWlpXiPql7Vq1cXzpw5o3YsJycnAYDg5OSk0bX9/PwEQ0NDpdfT19cX1qxZI3h6egoABFtbW5Xj+Pv7a3QPUqlUOHnyZIZzIyIixHZfX1+18cpiyeo1atQoITU1VaNnkFlCQoLQvXt3teOXLVtWCAkJUTmGra2tAEDw9PTMUQzyHj58KNSqVSvLew4MDFQ4NyIiQqhWrZra85o3by68efNG6bU1+d4LgiD4+vqK40VERKjtO2PGjAzXL1mypJCcnJzlc4iNjRV69Oih0fe/devWOb4XmaNHjwqmpqYqrzF79mzhp59+EgAIxsbGKse5cOGCYGNjo1HcGzduVDhf/nrqZPWzn5KSIowdO1aQSCRqY1D1ns3Nz7ggZHyPZHUvRET0dUtJSRZOXFoqDN7QWKi1oZbaV//1dYUjgT8JyZ/jtB02fcXS0tKEPdeihVqzDgu2U/YrvFz/OiPcf/lR22EWmPz4/M0ZQUSZjBo1Cg8fPsTy5cvRsWNH2NjYwNjYGEWLFkWlSpXQt29fbNu2DaGhobnaKUyZfv36ISQkBIMGDULZsmVhaGiIcuXKoU+fPggKCsLw4cM1Gsfd3R0RERFYvHgx2rRpA2traxgYGMDExAT29vZwc3PDkiVLEBkZidatW+c43j///BNbtmzBkCFD0KhRI5QrVw6GhoYwMTGBg4MDPD09cfbsWfz9999inZfsMjY2xu7du+Hv748ePXqIz6V48eJo0qQJFixYgLt37+bJrCxNVKxYEdevX8eGDRvg6uqKMmXKiM/WwcEBgwcPxr59+5S+N+zs7HDjxg2sWLECTk5OKFGiBAwMDGBtbY2OHTti8+bNOHPmTL7uFpZZ5tk/ffr0gb5+1htKFitWDLt27cLZs2cxbNgwVK1aFcWKFYO+vj4sLS3RuHFjjBkzBgcPHsSxY8dyHWf79u0RFhaGkSNHwtbWFoaGhrC2toarqysOHz6MOXPmIDY2FgDU7vrXtGlT3L9/H6tXr4arq6v4fjI2NoaNjQ1cXFwwf/58hIeHi0Xk84Oenh6WL1+O4OBgjBgxAg4ODjA1NYWBgQFKly4NFxcXLFmyBIsXL1Z6fkH9jBMR0ddPT08fbRzHY6PnZWxr9is6GZSCnoqZPzf1UvHD4z1w2+KILfuHIu6D6hqORDklkUjQrX45HPq+JRztFf9eHPY0Fm7Lz2Lzhch8r1f5tZIIfHKUTdHR0WKtjaioqHzfPpqISBPt2rXDiRMn0KJFC5w9e1bb4RAREX2xnr28Cb9zP2Pnh3DESVXXwSyWloZeppXQv/lMlC7nWIARkq5ITROw5swjLDl2F8mpiqkL56ql8HuvOrAqplhb6GuRH5+/OSOIiIi+eM+ePRMLQDdt2lTL0RAREX3ZylrXwaQeO3G8z0lMsm6FMipKsnyUSuGbEIFOx4ZgyuaWuH1rO+sIUZ7Sk0ow2rkS9nzbHJVKmSq0n7obg45Lz+LorRdaiO7LxUQQEREVeg8ePFDZlpCQAC8vL3FXufxc0kVERKRLippawbPjShwcdBWLKvdHLcFQab8UiQQH097DI/gXDNnQEKfOLUBaSlIBR0tfs1rlzLF/XEt4NrNVaHsbl4QRm69i2u6biPucooXovjxcGkbZxqVhRFTQnJ2dERcXhz59+qBhw4awtLTEx48fERwcjFWrVomJoqFDh+Zqu3oiIiJSTRAEhNzejo0hqxCY8haCRPWyMbtUYFDZVnBvOQcmpqUKMEr62gXefYXJ/93E60+fFdrsShTB0r71Uc/GouADyyf58fmbiSDKNiaCiKigOTs74/Tp02r7dO/eHVu3boWJiUkBRUVERKS7HkdfxJYL87HvUwQS1NQRskhNg4d5dfRtMQclrWsVYIT0NXvz6TOm7g7FsdsvFdr0pBKMb1sF3zpXgr7el78IiokgKhSYCCKignbt2jXs2bMHJ0+eRHR0NGJiYiAIAqysrNC0aVN4enqic+fO2g6TiIhI53z4EI0dZ2fBL+YyXqtJCBkIAtwMrDDY8QdUruJagBHS10oQBOwIjsLcgNuIT0pVaG9QwQJ/etSDbQnF2kJfEiaCqFBgIoiIiIiIiOQlJcXj0IUF2PgoAPelih/K5TUXjDG4xmA0a/gtJHp6BRQhfa0iX8fh++3XcT3qvUKbqaEeZnepid4Ny0OiZiljYcZEEBUKTAQREREREZEyQloaLt7wxcbQdTgnfFLbt0qqBINtXNC55U8wNDYvoAjpa5SSmoblJx9gReADpKYppjg61iyNBT1qo7ip8oLnhRkTQVQoMBFERERERERZeRAZiM0Xf0NA4lMkq5mNUTJVQH/Leujdcg4sSlQuwAjpa3P18TtM3HEdj9/EK7RZFTPC4t510crhyypezkQQFQpMBBERERERkaZev32A7WdnY/vbG3inpo6QSZqALiblMajpVNjaORdcgPRV+fQ5BfMCbmN7cJTSdq9v7DC1UzUYG3wZyxKZCKJCgYkgIiIiIiLKrsTED/APmofNUUcRKVX9MVQiCHCWFoNn7WFoUNcbEumXv/MTFbzDYS8wbfdNvItPVmirYlUUS/vWQ82yhX9JYn58/uZPFFEecnZ2hkQigbOzs7ZDISIiIiIqVIyNzdGn3WLs8wzBiurD0BgmSvsJEgkChU/wurkUwzc54k1MeAFHSl+DjrVK48j3rZQuBbv/6hO6rTyHI7deaCEy7WMiiEhOXFwcVq9ejc6dO6NcuXIwNjaGkZERSpUqhcaNG2PIkCFYu3YtoqKUTzP8kp06dQoSiUTpq0iRIrCxsYGbmxt8fHzw+fPnLMeTnZtVUiwlJQUeHh5i/6ZNm+L9+/d5c1MqnD9/HgMHDoStrS2MjY1RunRpdOjQAdu2bcuT8b28vFQ+y8yvyMjIbI8/ZcqUDGOcOnUq22PEx8ejYsWK4hh2dnYanbd//3706tUL5cuXh5GREUqWLImmTZti8eLFiIuLy3YcREREpHukUj04OY6Hj+dlbG+2AK4GpaCvYqHKJcln9A/ohXv3DxRwlPQ1sDIzxkbvxpjbpSaM9DOmP8xNDNHItriWItMufW0HQFRYXLhwAX379sWTJ08U2l6/fo3Xr18jODgYvr6+sLa2xosXupM9TkhIQHR0NKKjo3HgwAEsWbIE+/fv1zh5oEpycjI8PDywZ88eAECLFi1w8OBBFCtWLA+iVm7OnDmYN28e0tLSxK+9fPkSR48exdGjR7F161bs3LkTxsbG+RZDbly/fh1LlizJ9TizZs1CRESExv0/fvyIAQMGICAgIMPX37x5gzdv3uDSpUv4559/4O/vj+rVq+c6PiIiItINNRzc8JuDG75/GQq/oLnYGRuOj5nqCD3Tk2BQ0BT8/voOnJpN0lKk9KWSSCTw/MYO31QqgfH/Xsft57EAgEW96qBEUSMtR6cdTAQRAbh37x46dOiAjx8/AgC6dOmCXr16wcHBAYaGhnj9+jVu3LiBY8eOITAwUMvR5r/Ro0fj22+/Ff/86tUrhIWFYdGiRYiOjsatW7fQpUsXhISEQE8vZ0XWPn/+jF69emH//v0A0pfV7d+/H6ampnlyD8r8888/mDt3LgCgUqVKmD59OmrXro1nz55h2bJlCAwMxIEDBzBkyBD4+fnl+nply5bFkSNH1PYpV66cxuOlpaVhxIgRSElJgZWVFV69epWjuEJCQrB06VIYGxvDwMBAfN+rIggC+vTpg8OHDwMAGjZsiAkTJqBatWr4+PEjDhw4gOXLl+PBgwfo1KkTgoODUbJkyRzFRkRERLqptHVtTOy5EyPjXmHv2bnweX4ar+QSQvFSCcbd3YAf3tzB4M5rWTeIsq2KdTHsHdMcS47dQ0JSClpXs9J2SFrDRBARgBkzZogfhn19feHl5aXQp3379pg0aRJiYmKwY8eOAo6wYFlZWaFWrVoZvtamTRt4e3ujTp06iIyMRGhoKPbs2YNevXple/zExER069ZNTJK0b98e+/btg4mJ8nXieeHt27eYMmUKAKBChQq4ePFihmSFm5sbunfvjoCAAGzbtg0jRozIda0nAwMDheeYG3/99ReuXLmCatWqoXv37liwYEG2x0hNTcXw4cORmpqK2bNnY/369Vkmgnbt2iUmgdq3b4/9+/fD0NBQbHd2dkaHDh3QsWNHPH78GHPmzMGKFSuyHRsRERGRqakVBnRciQ5v7uH7gH64IUkS2wSJBIvfXEbEvy6Y0XMvDIyKajFS+hIZ6ksxtVM16PqeWUyjks5LTU3FgQPpa44bNWqkNAkkr1SpUhgzZkwBRFb4FCtWDDNnzhT/fPz48WyPER8fDzc3NzEJ1KlTJ/j7++drEggA1q1bhw8fPgAAFi5cqDBjRU9PD6tWrRJnOC1atChf48muJ0+e4KeffgIArF69OkMiJjuWLVuGq1evomrVqmJiLCsbNmwQj1euXKn02u3atUPfvn0BAGvWrMHbt29zFB8RERERAJQs4YD1/QLhaqBY6HdX8kuM2OaE928faiEy+hpIJJKsO33FmAginRcTE4OEhAQAQOXKlQvkmkFBQRg0aBDs7OxgbGwMCwsL1K9fHzNnzkRMTIzScxYvXgyJRAIDAwN8+vRJoT0xMRHGxsZi8d/r168rHadatWqQSCTih/bsql27tnic3aLZnz59QufOnXHixAkA6Uvw9u7dWyD1ePbu3QsAMDMzQ48ePZT2KV++PNq1awcAOHHiRJYzZQrSmDFj8OnTJ3h6esLJySlHYzx+/BizZs0CkL1kUnBwMID0n48qVaqo7NexY0cA6bWf/P39cxQjERERkYyRkRkW9D2O7ywbKrQFS5LQf283PHp4TAuREX3ZmAginSf/YfjOnTv5eq20tDSMHTsWLVu2xJYtW/D48WN8/vwZHz58wPXr1zF//nxUqVIFx44p/kKTffhPSUlBUFCQQvulS5cy7OalbCeply9f4u7duwCQ42VP8s/LwMBA4/NiY2PRsWNHnD59GgDQq1cv7Ny5M8tkhPxuZlnN1lIlKSkJly9fBgA0a9ZM7TVlz/nz589iAkTbduzYgf3798PS0hKLFy/O8Tjffvst4uLiMGjQoGx9/9+8eQMAsLa2VttPvv3MmTM5ipGIiIhInkQqxXD3DVhSZQCM0zIu54nSAwae+R7nLy/XUnREXyYmgkjnWVpawtbWFgBw48YNLFy4MMOOUnlp6tSpWLlyJQDA3t4eq1evxuXLlxEYGIgJEybAwMAAHz58gJubG27cuJHh3AYNGoi7aSlL8mT+WlZ9cjqrRD5ZpumuYR8+fICLiwvOnTsHAOjXrx+2bduWrURSbty7dw+pqakA0mdEqSPfntvE4Js3b+Dk5IQSJUrAyMgIZcqUQYcOHbBixQrEx8drNMb79+8xfvx4AMqXtGnq33//xcGDB1G8eHH88ccf2Tq3aNH09feypXWqyLffvn07+0ESERERqdD+m6nY+M18WKVmTAZ9lErx7e1/sO3Qt4CO130h0hQTQUQAxo0bJx5PnToVlSpVwvjx47F9+/ZsbbGtTmhoqPgBvFatWrh27RpGjhyJxo0bw9nZGUuWLIG/vz+kUimSkpIwYsSIDOfr6emhRYsWAJQneWQzbdzd3QGkz8jInNCS9bG2ts7RFt+pqakZaudoUij6w4cPaNeuHS5dugQAGDx4MLZs2QJ9/YKrVR8dHS0ely9fXm1fGxsb8Ti7S98y+/TpE86cOYO3b98iKSkJL168wNGjRzFu3Dg4ODjg/PnzWY7x448/4sWLF2jevDmGDh2aozjevXuH77//HgDw22+/oVQpxbX26sjeK3fu3FG5dBHIOAvoyZMn2Q+UiIiISI0aVbtim9t21EzL+PfIVIkEv746i1+2d0JKkmb/2Eaky5gIIgIwYcIEDBkyRPxzZGQk/vrrL/Tt2xcVK1ZE6dKl0bdvXwQEBOS4wvzff/8tJmbWrVsHCwsLhT4dO3YU47h8+TKuXLmSoV22nOfq1asZ6gR9/vwZFy9eBABMmTIFJiYmePfuHW7evJnhfFkCqVWrVtmKPSYmBidPnoSTkxNCQkIApCeBZIkpda5fvy4userXrx98fX0hLeDtPuVr/chmt6giv329slpMmpBIJGjatCnmz5+PQ4cO4dq1azh//jz++ecfODo6AgCePn0KFxcX8Xkqc/bsWaxbtw76+vpYvXp1jovaTZ48GS9fvkSzZs0wfPjwbJ/fpUsXAOmJQPli4fLu378PX19f8c+Fqb4SERERfT2srGrCt28gXPQsFdq2f36K0X5OiH3/WAuREX05mAgiAiCVSrF+/XocPXoUHTt2VJit8vLlS2zfvh1dunSBo6MjHj7M/g4Fsh22atasiSZNmqjsJ/9BPfOuXKrqBF2+fBkJCQkwNzdH06ZN0bRpUwAZZw69evVKXOqUVX2YuXPninV5JBIJrKys0LZtW5w7dw5FihTBxIkT4efnl/VNI2NF/gsXLuDZs2canSfj7OwMQRAgCEKG3auyIzExUTzOqiaRkZGReCwrIp5df/75Jy5cuIDp06ejY8eOqF+/Ppo1a4YRI0bg4sWLmD59OgAgLi4Ow4YNU5pclM0KEwQBEyZMyPE29GfOnIGPj0+ukkmjR49GuXLlAKTvCDZo0CDcvHkTSUlJePPmDTZv3oxWrVohLi5OXO6X02dHRERElBUTEwss6n8So8zrKLRdlCRiwG43PHnMeoVEqjARRCSnffv2OHToEN68eYODBw9i7ty5cHd3h7m5udgnODgYLVu2xPPnzzUe9/Pnz7h//z4AqE0CAUD9+vXFD9NhYWEZ2ho2bCjOaJFP8siOW7RoAT09PTHRI99HtiwMyHl9IACoV68evvvuO43r+7Ro0ULcoSwyMhJt27bFixcvcnz9nJDflSwpKUltX/mC2znd0l7ZbC8ZiUSC+fPno23btgAgzhbK7Ndff0V4eDgqVKiA2bNn5yiOz58/i8mk8ePHo04dxb8sacLc3Bz79u2DlZUVAGDLli2oW7cujIyMULJkSQwePBgvXrzAr7/+Kt67rJ4VERERUX6QSvUwpttWLLTvDcNM/6gWqQf0OzkaV66u0VJ0RIUbE0FESpiZmaFTp06YNWsW/P398fLlS/j4+KB48eIAgOfPn+Onn37SeLx3796Jx7IP06oYGBigRIkSAIC3b99maNPX10fz5s0BKE/yyBJAsv/K1wmS9SlVqhRq1qypNobRo0cjNDQUoaGhCAkJQUBAADw9PSGVSnH+/Hk4OzurrRUjTyqVYvPmzejWrRuA9MLN7du3F3eiKgjySYmslnvFxcWJx1ktI8uNkSNHisfySToACA8Px4IFCwAAy5cvz7BcLTvmz5+Pu3fvwsbGBnPnzs15sEhPQl6/fh1jx45V2D2scePG2L9/P6ZOnSouCZP9rBARERHlp86tZsHXcTZKZNpRLFYqxYjQv7Dr6PfaCYyoECu4aq1EXzAjIyN4e3ujbNmy6NixIwBg9+7dWLNmTbbr3eS0zouMs7Mzjhw5ItYJMjIywoULF8Q2IH3WkbGxsVgnqF69emKyQZP6QFZWVhmWItWrVw9ubm5o3bo1vLy8EBkZiWHDhmHfvn0axayvr4/t27eja9euOHz4MMLCwuDi4oKTJ09mmG2VX+QLRMsXjlZGvkC0fOHovFajRg3x+OnTpxna/vzzTyQlJaFixYqIj4/Hv//+q3C+/GyxkydPirOs3N3dxcTRwoULAQDt2rVDQECA0jhkia+4uDjxOlZWVmjTpo1C3zJlymD58uVYvnw5Xrx4gdjYWFhbW4vfw+joaHEZXlbJRiIiIqK8UqdGb2yzrIJxh7xwV5oqfj1FIsGc5yfwcIcbfui+C3oGRmpGIdIdTAQRZUOHDh1gY2ODqKgovHv3Dm/evNFoByb52REvX75U2zclJUWcLWNpqVgEL3OdoGLFiiE+Ph7m5uaoX78+gPTEVdOmTXHq1CmcOnUK5cuXx61btwBkXR9IHU9PTwQEBGDXrl3w9/fHyZMnlSYMlDE0NMTu3bvh6uqKwMBAXLt2DZ06dcLRo0fzdeYNADg4OEBPTw+pqakIDw9X21e+PSc7q2lKXUJQtjzt0aNH6NevX5ZjzZs3TzyOiIgQE0GyZXC+vr4ZCjkr8/r1a/FaTk5OWX5fS5cujdKlS2f42tWrV8VjWVFsIiIiooJQpnQ9bOpzHFN3d0Ng2ocMbZsTHiPSrxV+774bRc3KaSlCosKDS8OIsqls2bLisaaze4yMjFClShUAELdRVyUkJATJyckAoLRAcOPGjcUP+qdOnRJn+sjqA8nI1wk6c+aMWJA4N/WBgPTaNbLryIoea8rExAT+/v5o1qwZgPTi0e7u7vleWNjQ0FBMTFy4cEFtnSDZ8zQyMkKjRo3yLabbt2+Lx/LvqS/Zf//9Jx57eHhoMRIiIiLSRUVMS2Jp/1MYUqyaQttZxGPQzk54GqVYm5FI1zARRJQN8fHx4gd4MzMzsZaPJtq1awcAuHXrFi5fvqyy37p16xTOkaevr49vvvkGAMQZP4DiTB/5OkEnT54EAJQoUSLHu0/JODg4oE+fPgDSk1rHjh3L1vlFixbFoUOH0LBhQwDp99CjR48sizjnlqxGUWxsLHbv3q20T3R0tLhTW9u2bfO14PE///wjHmdOzm3YsEHcKU3VS76AdGBgoPh1Ozs78etZjSEIAmxtbQEAtra24tfk609p6vbt29i+fTuA9Petg4NDtscgIiIiyi2pnj4m9PgP8yp0gX6mItIP9AT0PzYcITc2aik6osKBiSDSeZ8+fUKTJk2wf/9+sbCyMmlpaRg3bpxYDLdLly7ZqvczevRosZ7QiBEjEBsbq9Dn6NGjWL9+PYD0pTWNGzdWOpYsyXP16lWcO3cuw9dkmjRpAiMjI7x79w5btmwBkF4fKLc1ioD0mUCycX755Zdsn29ubo4jR46gdu3aAIDDhw/Dw8MDKSkpCn1PnTolbmPv5eWV45iHDRsm1rKZOnWqQrHq1NRUfPvtt0hNTV9XPnnyZKXjzJkzR4xH2Xb2Fy9eVLujnCAImDlzpphwqlu3rlgAvDDLXMdIXlRUFLp27YqUlBQYGRlh+fLlBRgZERERkaJuredjfcNpKJ6piPRbPSmGhixCwIkpWoqMSPuYCCICcPnyZbi7u6NChQoYO3Ystm7diqCgINy4cQOnT5/G0qVLUa9ePfj4+ABIT2TI12XRRO3atfHDDz8AAG7cuIEGDRpg7dq1CA4OxunTpzFp0iS4ubkhNTUVhoaGGWaMZCZfJyhzfSAZY2NjNG3aFADw4UP6Ounc1AeSV6tWLXTp0gVA+oyjoKCgbI9RokQJHDt2DFWrVgUA7N27F4MHD1abjMsNS0tLsXjy48eP0aRJE/j6+iI4OBj+/v5o3769WFC5X79+OX5Whw8fhr29Pbp164aVK1ciMDAQISEhuHjxItasWYNmzZph/vz5AIAiRYpg7dq1eZKcy2+jRo1C06ZNsWjRIvGejh49iilTpqBWrVp48OABpFIp1qxZg2rVFKdjExERERW0BrUHYKvLelRKy/ixN1kiwfTog1i2swfSUpO1FB2R9rBYNOk8fX19lC5dGi9evMDTp0+xcuVKrFy5UmX/KlWqYNu2bRmW4Gjqt99+Q1xcHFatWoWHDx9ixIgRCn3Mzc2xY8cO1KtXT+U4jo6OKFKkCOLj4wEo1geScXZ2zrA1eW7rA8mbMWOGuGvYvHnzcOTIkWyPYW1tjRMnTqBly5aIiIjAtm3bYGJignXr1uVLcmTkyJF49uwZ5s2bh4cPH2LIkCEKfTp37iwm/HLq8+fP2Ldvn9pd1SpUqAA/Pz+Vs74KG0EQcOnSJZU1riwtLbFq1SrWBiIiIqJCxaZcE2zpfRSTd3dDkPApQ9u6uPuI2OqEX3vsQZGi1lqKkKjgcUYQ6TxjY2M8ffoU586dw9y5c9GpUydUrFgRpqam0NPTg5mZGapVqwYPDw/4+fkhLCxMrG+TXVKpFCtXrsSZM2cwYMAAVKhQAUZGRjAzM0O9evUwffp03L9/Hy4uLmrHMTAwEAsuA6pn+sh/3dLSEnXq1MlR3Mo0btwY7du3B5C+pO3KlSs5GqdcuXI4efKkuFW7j48Pxo0bl2dxZjZ37lwEBQWhf//+sLGxgaGhIaysrNC+fXv4+fnhwIEDMDY2zvH43t7eWLVqFQYNGoS6deuiTJkyMDQ0RJEiRVChQgV069YN69evx927d7+IJWEy06ZNw8SJE9G4cWOULl0aBgYGKFWqFJo2bYoFCxbg7t27TAIRERFRoVS0qDWW9z+FgaaVFNpOCB/hucMFL54FayEyIu2QCEKmClpEWYiOjhY/tEdFRaF8+fJajoiIiIiIiChrO45PwoLow0jJNPu8ZKqAvxxnoHatflqKjEi5/Pj8zRlBREREREREpBP6tFuM1fUmwixTEenXehJ4X5mPw6dmaSkyooLDRBARERERERHpjCb1hmBr29WwS8s4K+izVILJj/dg1e4+EP5/J1mirxETQURERERERKRT7Cq0wJaeB9AERRTa/v54Bz/6OSEx/o0WIiPKf0wEERERERERkc4xN7PB3/1PoY+JrULb4bQP8N7eFjEvb2ohMqL8xUQQERERERER6SQDAxPM7B2AqaWdIc20j1KYNBX9DvTHnTu7tRQdUf5gIoiIiIiIiIh0lkQiwYAOy7Gy1hgUzVRE+qWeBJ4XZ+HE2V+0FB1R3mMiiIiIiIiIiHRei0ajscV5OcqnZfx6glSC7x9tx7q9AyGkpSk/megLwkQQEREREREREYBK9q3h190fDWGs0Lbsww3M9GuDpMQPWoiMKO8wEURERERERET0/4pb2GNtv9PoZlRGoc0/9Q2GbXPGm5g7WoiMKG/oazsAKvxq1qyZ4c/JyclaioSIiIiIiCj/GRgWwc99DqPS4dFY8uocBIlEbAuRpmBAQG8sb/Ebqji4aTFKopzhjCAiOXFxcVi9ejU6d+6McuXKwdjYGEZGRihVqhQaN26MIUOGYO3atYiKilJ6vpeXFyQSicJLKpXCwsICdevWxZgxY3D9+vV8iT8yMlLp9SUSCYyNjVG2bFm4uLhg2bJliI2NzXI8Ozs7SCQS2NnZZdl34sSJ4rWqVKmi8hnllbCwMIwcORKVKlWCiYkJSpUqhZYtW2L16tVISUnJ9fhz5sxR+Swzv06dOqVynNu3b2PRokVwc3ODnZ0djI2NUaRIEdjb26Nv3744ePBglrFoGof8S50LFy5gyJAhqFq1KooWLQojIyOUKVMGHTp0wNq1a5GUlJTdx0VERET01ZFIpfDq/A/+qj4MRTIVkX6qJ8HAc1Nw5sIiLUVHlHMSQci0Rx5RFqKjo2FjYwMAiIqKQvny5bUcUd64cOEC+vbtiydPnmTZ19raGi9evFD4upeXFzZu3Jjl+VKpFFOnTsX8+fNzFKsqkZGRsLe316ivjY0N9u7diwYNGqjsY2dnh8ePH8PW1haRkZFK+wiCgO+++w4rVqwAAFSrVg0nTpxA2bJlsx2/ptauXYuxY8eqTFg4OjriwIEDKFmyZI6vMWfOHMydO1ejvoGBgXB2dlb4uqenJzZt2pTl+R06dMC///4LCwsLpe1ZJXYyc3BwwN27dxW+LggCxo8fj+XLl6s9v2bNmjh48CAqVKiQresSERERfa3uPjiCcWcn4XmmqRRSQcAPJRwxyHUdJFLOs6C8lx+fv7k0jAjAvXv30KFDB3z8+BEA0KVLF/Tq1QsODg4wNDTE69evcePGDRw7dgyBgYEajXnkyBExGZKWloaXL1/iwIEDWLlyJVJSUvDrr7+iXLly+Pbbb/Plnrp27YpffvnfNpfv3r1DeHg4/vzzT9y5cwdRUVFwdXXF3bt3YWZmlqNrCIKAUaNGYc2aNQDSEwgnTpyAtbV1ntyDMgcPHsSoUaOQlpYGa2trzJgxA02aNMHbt2+xdu1a7N69G5cvX0b37t1x6tQp6Onp5fqaoaGhattVJd+ePn0KALC0tESvXr3g7OwMOzs76OvrIyQkBEuWLMHdu3dx5MgRuLu74/Tp05Aq+QtEVtcHgI0bN2Lx4sUA0hNQyvz2229iEqhYsWKYOHEimjdvjqJFi+Lu3bv4448/EBYWhlu3bsHV1RUhISHQ1+evCSIiIqKqlTvAr3hFfB/QFzck//vHyDSJBIveXsGjf10wo+deGBgV1WKURBoSiLIpKipKACAAEKKiorQdTp7o1auXeE++vr5q+7569UpYsWKF0jZPT09xnIiICKV9/P39xT6lSpUSUlJSchn9/0RERIhje3p6Ku2TlJQkNG3aVOy3aNEilePZ2toKAARbW1uFttTUVMHb21scp27dukJMTEwe3YlySUlJQsWKFQUAgpmZmfDgwQOFPt9++63G30t1Zs+eLY6TU15eXsI///wjJCYmKm2Pi4sTWrRoIV5n48aNOb6Wo6OjAECQSCTC48ePFdqTkpIECwsLAYBgaGgohISEKPRJTk4WmjRpIsbz33//5TgeIiIioq9RYmKsMGVrG6HWhloKL2+fBsK71/e1HSJ9ZfLj8zfnrpHOS01NxYEDBwAAjRo1gpeXl9r+pUqVwpgxY3J8PXd3d7Rs2RIAEBMTg2vXruV4rJwwMDDIMFPo+PHj2R4jNTUVnp6e8PX1BQA0bNgQgYGBuVqKpYk9e/bg0aNHAIBp06ahUqVKCn0WLVqE4sWLi8fa5OvrixEjRsDIyEhpe5EiRfD333+Lf965c2eOrnP37l1cvnwZAODs7Kx0SdedO3fw/v17AICbmxvq1aun0EdfXx/Tp08X/3zhwoUcxUNERET0tTIyKoYFfY9hXIlGCm1XpEkYsK87Hj08qoXIiDTHRBDpvJiYGCQkJAAAKleuXCDXdHR0FI8fP34sHj969Ah//PEH3N3dYWdnBxMTE5iYmMDW1hYeHh44fPhwnly/du3a4nF2izqnpKRgwIAB2LJlCwCgadOmOHHihJh8yU979+4Vj1Ul7IoUKYI+ffoASC/UfO/evXyPKzdq1aolJtAePnyYozHk6xCpWhYmX0+pYsWKKseST66xaDQRERGRIolUihFuvlhSZRCMMxWRfqIHDDwzAecv/aWl6IiyxkQQ6TxDQ0Px+M6dOwVyTQMDA/E4NTUVABAREYFKlSph0qRJ2L9/Px4/fozExEQkJibiyZMn2LFjBzp16oRBgwblelcs+XuWjyUrycnJ8PDwwPbt2wEALVq0wNGjR2Fubq72PPndzJQVVdZUUFAQAKBq1aooXbq0yn5OTk7i8blz53J8vYIiS7jkpJ6RIAhiUs7U1BQ9e/ZU2q9KlSpi0WnZrCpl5JNRVatWzXY8RERERLqi/Tc/YsM3v8IqUzLoo1SKb++swb8HRwPcm4kKISaCSOdZWlrC1tYWAHDjxg0sXLgQaWlp+XpN+eK/soLSqampMDQ0hLu7O/766y8cP34c165dw/Hjx7Fq1SrUrFkTALBlyxbMmzcvV9eXT3hpsjU8kJ6s6NWrF3bv3g0AaN26NQ4fPoxixYrlKhZNffr0SZy9VK1aNbV95dvzIrnn4uICKysrGBoawsrKCs7Ozvjtt9/w7t27XI8dEhKC2NhYAED16tWzff6pU6fEne569OiBokWVFyg0NzdHv379AAD79+/HzZs3FfqkpKRgwYIFCv2JiIiISLmaVbtgm9sO1EjLuMFGqkSC+TFBmP9vR6QkxWspOiLlmAgiAjBu3DjxeOrUqahUqRLGjx+P7du3IyIiIk+vdePGDXGJV5EiRdC4cWMAQJkyZRAZGQl/f3+MGzcObdu2Rf369dG2bVuMHj0aoaGh4nKoP/74Ax8+fMhxDLIP+wDQq1evLPsnJyeje/fu8Pf3BwC0b98eBw4cgKmpaY5jyK7o6GjxOKstE2XbKwLZX/qmzLFjxxATE4Pk5GTExMTg9OnTmDZtGipWrIh9+/blauxff/1VPJYtacsO+WVhgwcPVtt3yZIlaNCgAZKSktCyZUv8/PPPOH78OC5evIiNGzeiUaNGuHjxIooUKYJNmzahRIkS2Y6HiIiISNdYlaqBDX0D0V7PUqHt36RnGOPnhNj3j5WcSaQdTAQRAZgwYQKGDBki/jkyMhJ//fUX+vbti4oVK6J06dLo27cvAgICIORgeqcgCHjx4gXWrVuHdu3aicvBvvvuOxgbGwNIX9ZTpkwZlWNIJBL88ccf0NPTQ1xcXLaLPL9//x4XLlxAly5dEBAQAABo1qwZPDw8sjz32bNnOHjwIID0ZVf+/v4wMTHJ1vVz6+PHj+KxqlkvMvIJqk+fPuX4mrVr18ZPP/2EgIAAXL16VUyYuLi4AEh/pj179sShQ4dyNP6uXbvEAtENGzZEjx49snV+fHw8du3aBSA9OdamTRu1/a2trXH27FksXboUJiYmmD17Ntq3b49mzZrBy8sLN2/exLBhw3D16lV06dIlR/dEREREpItMTCywuP9JjLSoo9B2XpKIgbvd8CTyVMEHRqQEE0FEAKRSKdavX4+jR4+iY8eO0NfPOLXz5cuX2L59O7p06QJHR0eNivra29uLdXGkUinKlCmD4cOH4/Xr1wAAV1dX/PzzzyrPT05ORnR0NO7cuYOwsDCEhYXh2bNn4iyNGzduqL3+xo0bxetLJBIUL14c33zzDQICAmBgYAAvLy8cPnxYoxpBstoyQPqytuwWYLazs4MgCBAEAadOncrWuTKJiYnisXyNI2Xkd+mSFQLPru+//x43b97Ezz//DDc3NzRo0ABNmjTB4MGDceTIEaxevRpA+pK+YcOGZYhPE3fu3IG3tzcAwMTEBJs3b87wnDWxd+9eMUE2cOBASKVZ/y/95MmT2LJlC16+fKnQJggC9u3bhw0bNrBQNBEREVE2SaV6GNt1K36r2BuGmf7xOEIP6B84Bleu/qOl6Ij+h4kgIjnt27fHoUOH8ObNGxw8eBBz586Fu7t7hmLIwcHBaNmyJZ4/f57t8Q0NDdG8eXNs3LhRTMjIS05OxsqVK9G0aVMULVoUNjY2qFGjBmrXri2+Xr16BQBiQiknqlSpggkTJsDMzEyj/hUqVMDkyZMBAG/fvkX79u0RHh6e4+vnhGzmFJD1blafP38Wj3M6c8nCwkJt+8iRIzF06FAA6TOmZDNzNPHs2TN07twZHz9+hEQigY+PT47qA2myW5i8ZcuWoUuXLggODkarVq1w7NgxfPjwAZ8/f8bt27cxadIkvH37FgsXLkSbNm1yNZuKiIiISFe5tpwFH8c5KJGpiPQHqRQjQpdj95HxWoqMKB0TQURKmJmZoVOnTpg1axb8/f3x8uVL+Pj4iFukP3/+HD/99JPaMY4cOYLQ0FCEhobi1q1bePz4MT5+/IigoCAMHjxYYfbH27dv0axZM4wdOxaXLl3KMtmR1UyXrl27ite/ceMGDh06hPHjx8PY2Bi3b9+Gs7Mz7t69q8HTSPf7779j7NixAIBXr16hXbt2anefymvyRamzSlDExcWJx1ktI8uNkSNHisenT5/W6Jy3b9/CxcUFkZGRAIDly5ejb9++2b728+fPxeWBjRs3zrKA9s2bNzFx4kQIgoB27drh5MmTaNeuHczMzGBoaIjq1atj0aJFWLNmDYD03dZmz56d7biIiIiICKhboxe2ddqKqmkZd4VNkUgw+8VJLN7hhtTk7M0oJ8orTAQRacDIyAje3t7Ytm2b+LXdu3er3V3MwcEBtWrVQq1atVCjRg1UqFBB7ZKm8ePH4+rVqwCAbt26wd/fH5GRkYiPj0daWpq4tEpWCDmrWkUWFhbi9evUqYOOHTti6dKl2L9/P/T19fHu3Tv0799frFekib/++kuspfT06VO0bds2T4oxa6JcuXLisXzhaGXkY5IvHJ3XatSoIR4/ffo0y/4fP35Ex44dcevWLQDAvHnzMGbMmBxde+vWreL3Lqsi0QDg6+srvl/nzp2rcqv6IUOGoEqVKgCADRs25KgmFhEREREBZUrXxaY+x+GsZ67QtjHhMcb7OSEuNuu/QxLlNSaCiLKhQ4cOYmLh3bt3ePPmTZ6MGxsbi+3btwMABgwYgD179sDd3R22trYwMTHJMHsot1uWt23bFuPHp09HvXbtGjZs2KDxuRKJBGvXrkX//v0BpBfVbtu2LV68eJGrmDRRrFgx8dlntSxNvj0nS640lZ2aPgkJCXB3d8eVK1cAAJMnT8bMmTNzfO3NmzcDSF9uqMk273fu3BGPGzRooLavrP3t27fiUkQiIiIiyr4ipiWxtG8gvIsp/p30NOIxcGcnPI06r4XISJcxEUSUTWXLlhWPs1vcV5X79+8jOTkZANTu4hUeHp4ndVumT58u1geaO3dutgoDS6VSbNy4Udzh6v79+2jXrl2eJcXUadGiBQDg7t27apNP8su0mjdvnm/x3L59WzyWf19klpycjJ49e4pxjRo1Cr///nuOr3v9+nXcvHkTANC5c2eNtnmXL4CekpKitq/svZj5PCIiIiLKPj19A0zssQPzbLtCP9Ns6wd6AvofG47r1zdoJzjSSUwEEWVDfHy8+OHfzMxMow/gmpD/YC5f3yYz2U5VuWVpaSkuSYqKisLGjRuzdb6+vj62bduGTp06AQBu3boFFxcXfPjwIU/iU6Vbt27isaqZTPHx8dixYweA9KVbDg4O+RbPP//8b9cHJycnpX1SU1PRv39/cYv5QYMGYdWqVbm6bnaLRAPpu9jJnD17VmW/5ORkXLhwAQBgbm4OS0vLHEZJRERERPK6Of+CdQ2nwyJTEem3elIMub4YAccnayky0jVMBJHO+/TpE5o0aYL9+/errfmTlpaGcePGidt1d+nSJc9mBFWuXFkca+PGjUrrsgQEBGDFihV5cj0AmDBhAooUKQIA+O2337JVKwhIX5K0e/dutGnTBkD6MrOOHTsqnbEUGRkpbmPv7Oyc45i7d++OihUrAgAWLFiAhw8fKvSZPHmyuHxOttNZZhs2bBDjmTNnjkJ7aGgoHjx4oDaWNWvWYN26dQCA0qVLo3v37gp9BEHA8OHDsXPnTgBAz5494evrm6v3TWpqKvz8/AAAJUqUgKurq0bnubu7i8dTp05FbGys0n6zZ88Wd8Tr3Llznr3HiYiIiAhoWLs//Dr4olJaxo/iyRIJpj89jGX/dUdaarKKs4nyBuf8EwG4fPky3N3dUa5cOXTr1g3NmjWDra0tihUrhvfv3yMkJAQ+Pj4IDQ0FkD5TYt68eXl2/RIlSqBz5844cOAADh8+DBcXF4wePRq2trZ49eoVdu3ahQ0bNqBixYp4//49YmJicn3NUqVKYfjw4Vi2bBkePXoEPz8/DBo0KFtjGBsbw9/fHx06dMC5c+dw8eJFuLm54dChQznetl0dAwMDLF++HO7u7oiNjUXz5s0xc+ZMODo64t27d1i7dq24jXuLFi2yfT8yV69exbBhw9C6dWt06tQJtWvXRokSJZCSkoLw8HBs3boVR48eBQDo6elhzZo1MDU1VRhn0qRJ8PX1BQDUqlUL06dPz1CrR5latWqpbT9y5AhevnwJAOjXrx8MDAw0uicXFxe0adMGJ0+exM2bN1GvXj2MHz8ejo6OMDY2xoMHD+Dj44PDhw8DAExNTblrGBEREVE+sCnbGJt7HcHkPT1wTviYoW1d/ANEbGmFX3vuRZGi1lqKkL56AlE2RUVFCQAEAEJUVJS2w8m1hIQEoXTp0uI9ZfWqUqWKEBwcrHQsT09PsV9ERES24njy5IlQoUIFldetUKGCcOvWLcHW1lYAIHh6eiqMERERIfZX1p5ZVFSUYGhoKAAQqlevLqSmpmZol13L1tZW7TgfPnwQGjVqJF67Q4cOwufPn5XG5eTkpMHTUG/NmjVi3Mpejo6OQkxMjMrzfX19xb6zZ89W267uVaJECWHv3r0qryN7ftl5ZcXDw0Pse/nyZY2el8zbt2+F1q1bZxlDqVKlhGPHjmVrbCIiIiLKnuTkz8Jv/3UTam2opfDqvb628PzpFW2HSIVAfnz+5tIw0nnGxsZ4+vQpzp07h7lz56JTp06oWLEiTE1NoaenBzMzM1SrVg0eHh7w8/NDWFgYGjZsmOdx2NjY4Nq1a5g8eTIcHBxgZGQEc3Nz1K1bF7Nnz8b169czbFeeF8qXLy/WmLlz5444mya7zMzMcOTIEdSpUwdA+qwVDw+PLIsS59Tw4cNx9epVDB8+HBUrVoSxsTFKlCiBFi1a4O+//8a5c+dQsmTJHI/fuXNnrF+/HsOGDUPDhg1Rvnx5mJiYwNjYGGXLlkWnTp3EmVRdu3bNwztTLzY2Fv7+/gDSd0Nr3Lhxts4vXrw4Tpw4gb1796JPnz6wt7eHiYkJDAwMUKpUKTg7O2PhwoUIDw9Hu3bt8uMWiIiIiOj/6esbYkqvPfipfEeFItJ39AT0O+yJsNBtWoqOvmYSQVBSjIRIjejoaHEb76ioKJQvX17LEREREREREX25Ll33xYSQP/BRmrE+o1GagF/suqFj61+0FBlpW358/uaMICIiIiIiIiItalLPG37t/oFtWsZE0GepBJOf7MPfu3tDyObmLkSqMBFEREREREREpGV2Ns2xtedBNJEUUWhb9TEcU/yckBj3WguR0deGiSAiIiIiIiKiQsDcrDz+7n8avU1sFdoOpX3AkO3tEPPihhYio68JE0FEREREREREhYSBvjF+6h2AqWVaQ5qppG+oXir6HRyAO7d3aik6+howEURERERERERUiEgkEgxw+Qsra49F0bSMyaCXehJ4XpqDE2dZQJpyhokgIiIiIiIiokKoRcNR2NJ6BcqlZfx6glSCCQ//xfq9A1lEmrKNiSAiIiIiIiKiQqqSnTO2dQ9AAxhn+LogkWDphxuYua0NkhLeayc4+iIxEURERERERERUiBW3sMPafqfRzaicQpt/6lsM+7c13sbc1kJk9CViIoiIiIiIiIiokDM0LIKfPQ5holVzSDIVkQ6RpqB/QB/cvxegpejoS8JEEBEREREREdEXQCKRwLvTaiyrPhwmmYpIP9WTYNC5qThz7nctRUdfCiaCiKjQsrOzg0QigZeXV75dw8vLCxKJBHZ2dvl2DSIiIiKivNS6yXhsbrUEpTMVkY6TSjHu/iZs8veGkJam/GTSeUwEESmRkpKCXbt2YcSIEahduzasrKxgYGAAc3NzVK5cGd27d8eiRYsQERGh7VB1niAI8Pf3R79+/VClShUULVoU+vr6sLCwQK1atdC7d28sWrQIN27cKPDYWrduDYlEAolEAhcXF43Pc3Z2Fs+Tf+np6cHS0hINGzbE+PHjcevWrSzHmjNnjnj+qVOn1PYNCgqCmZkZJBIJ9PX1sWXLFo1jzon4+Hj8/vvvaNy4MSwtLWFqaopq1arhhx9+wOPHj3M9fmRkpNLnqOylLtn44cMHbN26Fd7e3qhbty7Mzc1hYGCAUqVKoXXr1vjjjz/w/v17jWJKS0vD9u3b0a1bN9jY2MDY2BhFihSBvb09PDw8cOjQIY3v7/z58xg4cCBsbW1hbGyM0qVLo0OHDti2bZvGYxAREdGXq2olF2zruht1BMMMX0+TSLDoXTDmbmuH5MSPWoqOCjWBKJuioqIEAAIAISoqStvh5Ll9+/YJlStXFu8xq5erq6sQGhqq7bC/Sra2tgIAwdPTU2n7ixcvhBYtWmj8vbpz547CGJ6engIAwdbWNk9jj4yMFCQSiXhtqVQqPH36VKNznZycNLofPT09YcGCBWrHmj17ttg/MDBQZb/AwEDB1NRUACDo6+sL27dvz87tZtv9+/eFKlWqqLw3MzMzISAgIFfXiIiI0Pi9oeo9dvDgQcHIyCjL80uXLi2cPHlSbTxv374VWrZsmeVYPXv2FBITE9WONXv2bEEqlar9/1JCQkJOHx0RERF9QRITY4Uf/doItTbUUnh5+9QX3r+5r+0QKRfy4/O3fk6SR0Rfq19++QWzZs2C8P/F15ydneHm5oY6deqgRIkSiI+Px/Pnz3HmzBns378fkZGROHDgAMqXL4/Vq1drOXrdkpSUhPbt2yM0NBQAUL9+fXh7e6NevXooVqwYYmNjcefOHZw5cwYHDhzAhw8fCjS+zZs3QxAEGBkZITU1FSkpKdiyZQt+/PHHbI0juz8g/Z4fPXqEvXv3YuvWrUhNTcW0adNQqVIl9O7dO8exHj9+HF26dEFCQgIMDAywfft2dO/ePcfjZeXjx49wdXXF/fv3AQDDhw9H3759YWJigsDAQCxYsACxsbHw8PDAuXPnUK9evVxf85dffkHXrl1VthcvXlzp19+8eYPPnz9DKpWiffv26NixI+rWrQsLCwtER0dj69at2L59O168eAE3Nze18fbt2xdnz54FANjb22Py5MmoXbs2kpOTcfXqVSxcuBCvX7/Grl27ULJkSZX/T/nnn38wd+5cAEClSpUwffp01K5dG8+ePcOyZcsQGBiIAwcOYMiQIfDz88vGUyIiIqIvkZFRMfzmcQyVDg7D8jdXMrRdkSaj/95uWOH0B+wrddBShFTo5Ek6iXTK1zojaP369eJ9WVtbq509IQiCkJKSImzZskWoUKGCMHLkyIIJUseomxG0YsUK8fvl7e0tpKamqhwnMTFR8PX1FZ4/f67Qll8zghwcHMTZHZ07dxYACLVq1dLoXPkZQar89ddfYp+aNWuq7JfVjKCDBw8KxsbGAgDByMgo17NwNPHTTz+JMf3+++8K7efOnRP09fUFAIKTk1OOryM/I8jX1zdHY/z777/CyJEjhcePH6vsI/+9aN26tdI+V65cEftUrFhRiI2NVejz+PFjwcLCQpxB9vLlS4U+b968EczNzQUAQoUKFYSYmJgM7SkpKYK7u7tGs8CIiIjo63Pk/EKhkU9NhZlBzXxqCOcvLtV2eJQD+fH5mzWCiABERUVhzJgxAAAzMzMEBQXB2dlZ7Tl6enoYMGAAbty4AVdX1wKIkuTt27cPAKCvr48lS5ZAKlX9vzMjIyN4eXmhdOnSBRLbxYsXce/ePQDAgAEDMHDgQABAWFgYrl27lifXGDNmDCpUqAAAuHXrFl68eJHtMQICAtCtWzckJibCxMQE+/btg5ubW57Ep0pycjL++usvAED16tXxww8/KPT55ptvMHToUADA6dOnceXKFYU+BcXDwwOrV68Wn7Uy48aNQ6NGjQCkx/v69WuFPufPnxePv//+exQrVkyhT4UKFeDt7Q0gvZbQpUuXFPqsW7dOnN22cOFClCxZMkO7np4eVq1aBT09PQDAokWLsrpFIiIi+oq4NPsRG5ovgFWmOtEfpVKMvrMW2w+OBDJtPU+6h4kgIgBLlixBYmIiAGD+/PmoXLmyxudaWFjA3d1dbZ8XL15gxowZaNSoESwtLWFkZAQbGxv06dMHx48fV3mefLHbDRs2AACOHTsGd3d3lC5dGkZGRrC3t8fo0aMRHR2tUbyBgYHw9PRExYoVUaRIEZiZmaF27dqYPHkynj17pvI8+aLDQHoB3Xnz5qF+/fqwsLDIECMAxMXFYfv27Rg2bBjq1auXocCuk5MTFi9ejE+fPmkUszJPnjwBAJQsWRIWFhY5Hiez9+/fY9asWahZsyZMTU1hYWGBVq1aYevWrRqPsWnTJgDpy41cXV3RrVs38YO/rC23pFIpatasKf45KioqW+fv3r0bPXv2RFJSEooUKYL9+/ejQ4f8ny4cGBgoJjI8PT1VJvDkizfv2bMn3+PKLVniOC0tTWkR+aSkJPG4YsWKKsepVKmS0nNk9u7dCyA9Yd2jRw+lY5QvXx7t2rUDAJw4cQIfP7JIJBERkS6p6eAOP7ftqCEYZPh6qkSCX2LO49d/OyIlKV5L0VFhwEQQ6TxBELB582YAQLFixcR/kc8rW7duReXKlfHrr7/i6tWrePfuHZKSkhAdHY3//vsP7du3x7Bhw5CSkpLlWNOmTYOLiwv279+Ply9fIikpCZGRkVi9ejUaNGiAO3fuqDw3MTER/fr1Q5s2bbBp0yZEREQgISEBHz9+RFhYGBYvXgwHBwcEBARkGcf9+/dRr149zJo1C9evX1daf8fV1RV9+/bF+vXrcePGDcTGxiIlJQWvX7/GmTNnMHnyZNSpUwfh4eFZXk8ZQ8P03RFevnyJt2/f5miMzO7evYv69etj3rx5uH37NuLj4/HhwwecPXsWAwcOxNixY7McIykpCdu3bwcA9O7dG4aGhjAxMRE/tG/btk2j77UmZM8AAAwMDNT0zGj79u3w8PBAcnIyihYtisOHD6NNmzZZnie/m1lkZGROQkZQUJB47OTkpLJfo0aNUKRIEQDAuXPncnStgvT582fxWDYbR17VqlXF40ePHqkc5+HDh0rPAdLfW5cvXwYANGvWLMP3PzPZs/38+TOCg4OziJ6IiIi+NtalamBD30C01y+h0LYt6RnG+LVC7LvIgg+MCgUmgkjnhYWF4c2bNwCAli1bwtTUNM/G3rFjBwYNGoS4uDhUrFgRS5YsweHDh3H16lXs2rULnTt3BgCsX78+yyLCa9euxW+//QYnJyf4+fkhODgYx48fx+DBgwEAMTExGDJkiNJzBUFAr1698O+//wIA3N3dsXnzZpw7dw4XLlzAsmXLUKFCBcTFxaFXr15ZfnDs1asXnj59inHjxuHYsWMIDg7Gtm3bMnxwTUlJQe3atTFjxgzs2bMHly5dwsWLF7F9+3b07dsXUqkUERER4tKk7GrQoIF4b8OHD8/V7CIgfStzd3d3vHnzBjNnzsSpU6cQHByMtWvXonz58gCAlStX4siRI2rH2b9/v5iYki0Jkz9+9eoVDh8+nKtYZeQTf7a2thqds3XrVgwYMAApKSkwMzPD0aNH0bJlyzyJRxO3b98Wj6tVq6ayn76+vjgzT12CU1PLly9H5cqVYWxsDHNzc9SsWROjRo3Ks6V6p0+fBpCekFM2o7BDhw6wt7cHACxbtgxxcXEKfaKjo8VZdS1atECtWrUytN+7dw+pqakA1D+7zO158fyIiIjoy2NibI7F/U5gRPG6Cm3nJZ8xcI87oiJOFXxgpH15UmmIdMrXVix6y5Yt4v3MnDkzz8aNiYkRi7oOGTJESE5OVtpv+vTpYnHY8PDwDG2Zt78ePny4kJaWpjDGsGHDxD7Xrl1TaF+zZo0AQDAwMBAOHTqkNI63b98KNWvWFAAIzZs3V2iXLzoslUqFI0eOqL3/e/fuqW0/duyYuP31unXrlPZRVyz60qVLGbbPtrCwEAYNGiSsWbNGuHHjhpCSkqL2+jKyYtEABHNzcyEsLEyhz/3798WCyl26dFE7XteuXQUAgp2dXYbvVWpqqlC2bFkBgNC7d2+1Y2hSLHrXrl1in7Zt26rsJ/998/b2Fp9Z8eLFhcuXL6uNQ11cERER2TpXpkmTJgIAwdTUNMu+rq6u4vWy2k5dGU23jx85cmSOxpfZv3+/OJabm5vKfhcuXBBKliwpABAqVaokrF69WggKChICAwOFxYsXC1ZWVmIxaWU/P4cOHRKvs2jRIrUxyRennjp1ao7vjYiIiL4O+8/MExr4KhaRbu5TQ7h85W9th0dqsFg0UT6QL+xaqlQplf3S0tIQFham8pWcnJyh/99//40PHz6gXLlyWLVqFfT19ZWOO3fuXJQrVw5paWlq68eUKVMGy5cvF2v0yJs0aZJ4LNueWkYQBCxcuBAA8N1336Fjx45Kxy9evLhYWPbcuXPi1t7KeHl5wcXFRWU7AFSpUkVte7t27dClSxcA/6t7kh2Ojo74559/xCVR79+/x+bNmzFixAjUrVsX5ubmcHFxwdq1a5XOvlBm3rx5GeruyFSuXBndunUDkHFpU2Zv3rzBwYMHAQD9+/fP8L2SSqXo378/gPQize/fv9coJnlJSUkIDw/HggULMGjQIABAkSJFMH/+fI3O9/X1RVpaGkxMTHDixAk0btw42zHklqxeTdGiRbPsKz87L6czviwsLODt7Y2NGzfi/PnzuHbtGg4cOIDx48eLMfzzzz8qZ9Nl5e3bt2KheT09Pfz8888q+zZt2hQhISGYNGkSnjx5glGjRqFFixZo3bo1Jk2ahPj4eMybNw9XrlxR+vMjX+snq+eXF8+OiIiIvh6uLWdifZM5sEzLWCj6g1SKEWErsOfIdywirUOYCCKdJ//hSt2ysNjYWNSuXVvl6+nTpxn6+/v7AwDc3NxgZGSkclx9fX00a9YMAHDhwgWV/Xr16qVynKpVq4ofDDPXH7l9+7ZYd6RXr14qxweAVq1aicfqYhkwYIDacZSJiYnB/fv3MyTPZIm3GzduZHs8ABg2bBhCQ0Ph7e2tsAtTXFwcjh07hhEjRqBKlSpZLseSSCRiokaZhg0bAkj/4K8qibNt2zYxISi/LExG9rXExET8999/auORj0v2MjIyQvXq1TF9+nTEx8ejQYMGOHr0KJo0aaLxWACQkJCAAwcOaHSOvFOnTkEQBAiCADs7u2yfD0BcBqiuvo2M/Ps9ISEh29cqW7Ysnj59Ch8fHwwePBjNmjVD/fr10blzZyxduhTXrl0TdwPz8/MTf2Y1lZqaigEDBuDx48cAgJkzZ6J+/foq+wuCgH///Rc7duxQSBwD6QmbrVu3qkyMyi+hzOr55fbZERER0denXvVe2NZ5KxzSMtYzTJFIMOtFIP7Y4YbU5OyXbKAvDxNBpPPkEwiazhzJSmpqKq5fvw4gfbaB/Id5Za+dO3cCgNotwLOqCVK8eHEAUNghSL7eT7NmzdTGIT/LQF0sderUURuLzLlz5+Dh4YESJUrAysoKDg4OGZJna9euBQCl221rqmrVqvDx8cGbN29w/vx5LFmyBAMGDBDr+gDA8+fP4ebmpnaHtpIlS6JECcViejKWlpbisapdmDZu3AggvX5R9erVFdrr1q0r1n3J7e5hhoaGGDp0KJo3b67xOb/++qv4fv/pp5/w559/5iqGnDA2NgagfEeszOQLMJuYmGT7WoaGhmLBaWWqVKmCLVu2iH9evnx5tsb/9ttvxQSjm5sbfvrpJ5V909LS4OHhgcmTJ+PJkycYOnQorl27hoSEBHz69AlBQUHo0qULwsPDMXToUHz//fcKY8ieHZD188vtsyMiIqKvU1nrutjkcRzOeuYKbRsSn2C8XyvEfdBsN2L6cjERRDpP/sN/TEyMyn4WFhbibAjZy9PTU2nft2/f5mhnqPh41ds4qvtAC0DchltWTFbm1atX2Y4jq1hkSSd15syZgxYtWmDHjh1Z7uqVFzMWDAwM0KxZM0yYMAFbtmxBVFQUTpw4IS71Sk1NxbfffgtBxZRXTZ+vbKzM7ty5IybdlM0GkpEt6Tp37pzSbcYzCw0NFV9nzpzBihUrUKlSJSQlJWHMmDHicj5NNG3aFPv37xfvdeLEiVi9erXG5+cFWSJKk+VK8olZTZaS5UTLli1Ro0YNAOnL/tLS0jQ6b9q0aVizZo04xo4dO5TuFibz999/i7PA5syZg3Xr1qF+/fowNjaGqakpmjdvjn379onvj2XLlins4CeftM7q+RXEsyMiIqIvk2mRklja7xS8zWootJ1GAgbt6oRnT85rITIqKEwEkc6rW/d/VfRDQkLyZEz5RIFs+ZImr6NHj+bJ9VXFEhAQoHEs3377rcox1X3gBYATJ05g7ty5AICKFSti1apVuHnzJt6/f4/k5GQxkaZuBkVeaNOmDY4dOybO5rl//744Uyuvyc/wmThxospZV1OmTAGQvkxIk1lBtWrVEl8tW7bEmDFjEBISIs7Kmj59Oq5cuaJxnK1atcLevXvFpUPffvttrmcnZYdsplZcXFyWdZKioqIApNfuUre8MrdkiaDExERxB0F1Fi5ciN9++w1A+uyv/fv3ZznrZt26dQDSkzlTp05V2e/XX38Vj318fDK0yc9yi45W/y91smcHADY2Nmr7EhERke7R09PHxO7b8bNtN+hn+ofS+3pAv+PDcf26j4qz6UunvHotkQ6pVasWSpQogTdv3uDs2bOIj4/PcnZIVuSXEQmCoLANdEGSn/FkYWFRILHIlnwVL14cFy9eVFmEO6uZQnmhTJkycHV1xebNmwEADx48UFvHJSfS0tKwdevWbJ+3efNmzJ49O9vnFStWDJs2bUKDBg2QkpKCH374AWfOnNH4/Pbt2+O///5Dz549kZycjCFDhsDY2Bh9+vTJdizZVaNGDezatQsAEB4ejqZNmyrtl5KSIta2UrbMLi8pK8CuyqpVq8RETvXq1XHkyBGYmZlleZ5sC/caNWqoTWqVL18e1tbWePnyJcLDwzO0OTg4QE9PD6mpqQptmcm35/fzIyIioi9Xd+d5sAmrjglXfsV76f/+TvRWT4oh15dgbsxtuLdfrMUIKT9wRhDpPIlEIi7liY2NFeu85IahoaG4JOncuXO5Hi835JMeBRXLrVu3AACtW7dWuxObfP2i/FS2bFnxODsf+jUVGBgozsAYN24ctm3bpvYlq//y8OHDHH9P6tatKxa3Pnv2bJbFsDNzd3fH1q1bxcTCwIEDFZYi5YcWLVqIx6dPn1bZLzg4WFzelJ06SDlx+/ZtAOkFltXVidq8eTPGjh0LIH2m2/Hjx1GyZEmNriHbNVCTJaOyQtKZdxo0NDSEo6MjgPRi7urqBMmerZGRERo1aqRRjERERKSbGtXqD78OG1AxLWN6IFkiwfRnR/DXf92QlpJ1fUf6cjARRIT0pTyyQqzTpk3TqHZLVmRbo4eHh+PIkSO5Hi+nGjRoIC4pWbNmTYadh/KL7MOuuuLbISEhuHTpUo6voarWjzLyCaeKFSvm+JqqyJZW6enpYebMmejbt6/a14wZM8QP+blZljVjxgyxdtEvv/yS7fN79+4NHx8fSCQSJCcno3fv3jh27FiO49GEs7MzzM3TixNu3LhR5fdxw4YN4nH37t3zLZ5z586JicsWLVpkqAUlb/fu3fD29oYgCChfvjxOnDiRIcGYFXt7ewBAWFiY2iVxYWFh4kw52TnyunXrBiA9ab17926lY0RHR4uF0du2bauwox4RERFRZjZlG2FL76NoLlX8e8Pa+If4YWsrxH9SvZkMfVmYCCICUKFCBfz1118AgA8fPqBFixYICgpSe44gCGo/0I0fP14s0urt7S1+2FTlwIEDuHnzZvYC14BUKsX06dMBpG8tP3jw4Aw7CmUWGxuLFStW5OqaVapUAZBefPfBgwcK7TExMWJR3Jzq0aMHVq1aleVObxs2bMCJEycApH+f83pZWFxcnPiBvGXLlrCyssrynJIlS8LJyQkAsGPHDrXfD3WqVauGHj16AEhPaAQGBmZ7jMGDB+Pvv/8GkL7TVLdu3VQuM3N2dhZrHUVGRuYoZkNDQ3z33XcA0pdLLV6sONX4woULWL9+PQDAyckJjRs3VjqWLBZVW9nv3btXbcLwwYMH4qwqACrrYh09ehT9+vVDamoqrKyscPz4cZXXVMXd3R1A+jOeOHGi0rgSExPFZwOk70SW2bBhw8RE2tSpUxVqGsmKostqg02ePDlbcRIREZHuKlbUGiv6ncKAopUV2o4jDl47XPDyqea1KanwYo0gov83fPhwPH36FHPnzsWzZ8/QsmVLtGnTBu7u7qhduzYsLS2RmpqKFy9e4Nq1a9ixY4eY3NHT04OhoWGG8aytrbFx40b06tULz58/R6NGjeDl5YVOnTqhfPnySE5ORnR0NC5fvoydO3fi0aNHCAgI0Hhr9uwYNWoUjh07hj179uC///7DtWvXMHLkSDg6OsLc3ByxsbEIDw/HqVOn4O/vD2NjY3EJTE4MHjwYAQEBiIuLg5OTE6ZOnYqGDRsCgLjF+4sXL9CsWTNcuHAhR9eIiorCmDFjMGXKFLi7u6NVq1aoWrUqihcvjsTERISHh+O///7DwYMHAaQnDf788888Xxq2e/ducQennj17anxez549ceLECbx//x7+/v7o3bt3jq4/ffp07Ny5E0D6rKDWrVtne4yRI0ciISEBEyZMQHx8PNzc3HD8+HFxGVJemzx5MrZv34579+7hxx9/xIMHD9C3b1+YmJggMDAQv/76K1JSUmBiYoKlS5fm+Drdu3dH5cqV0aNHDzg6OqJ8+fIwMjLC8+fPceTIEaxfv1783vXp00dMqsm7ePEiunfvjqSkJBgYGODPP/9EcnIywsLCVF63fPnysLCwyPC1iRMnYv369Xj16hV8fX1x//59jBo1CtWqVUNqaipCQkLw119/icvUqlevDi8vL4WxLS0tsXDhQowaNQqPHz9GkyZNMGPGDNSuXRvPnj3D0qVLxYRgv3794OzsnLOHR0RERDpJX98QU3vuQcUTP+LXqINIlfu78x09Af2OeGF54xmoWbu/mlGo0BOIsikqKkoAIAAQoqKitB1Ontu9e7dQsWJF8R7VvSQSidCxY0chNDRU5Xj+/v6CpaVllmNJpVLh5MmTGc6NiIgQ2319fdXGbWtrKwAQPD09lbYnJSUJo0ePFiQSSZax2NvbK5w/e/ZssV0T3t7eKsfX09MTli5dmuWY6u6pa9euGn2PAAjm5ubCpk2blF7D09NTACDY2tqqvR9fX19xvIiICPHr7dq1E98LT58+1ejZCIIgvHjxQpBKpQIAwc3NLUObk5NTtp51586dxf4XLlzI0Cb/jAMDA9WOM3/+fLFv8eLFhZCQEJVxyT+DnLh//75QpUoVld8zMzMzISAgQO0Ysr6qvneavj9Gjx4tJCYmKh1D/vlp+lL1sxoSEiLY29tneX69evWEyMhItfc+a9YstT/LnTt3FhISEtSOQURERKTOheu+QjOfmkKtDbUyvBr61BQOn5yh7fB0Rn58/ubSMKJMunfvjrt372LHjh0YOnQoatSogZIlS0JfXx9mZmawt7dHly5dsGDBAjx8+BCHDh1SuxOXu7s7IiIisHjxYrRp0wbW1tYwMDCAiYkJ7O3t4ebmhiVLliAyMjJHszk0ZWBggFWrVuHGjRsYN24cateuDXNzc+jp6cHc3Bz16tXD0KFDsXPnTnGHo9zw8fHB5s2b0bJlSxQrVgxGRkawtbXFoEGDcP78eYwfPz5X4+/duxfh4eFYtmwZ+vTpg5o1a4r3Y2pqigoVKqBz585YunQpHjx4kOulaMo8ffoUJ0+eBAA0a9YsWzVjrK2txSLIhw8fRkxMTI7jmDFjhng8b968HI8zffp0zJw5EwDw7t07uLi45Ml7QZnKlSsjJCQECxcuRKNGjWBhYYEiRYqgatWqmDBhAm7evKl0aVR2+Pv7Y9q0aWjTpg0qVaoEc3Nz6Ovrw9LSEo0aNcKECRMQGhqKVatW5ev29DL16tVDaGgoVq5cCRcXF5QuXRqGhoYwMjKCjY0NunTpgs2bN+Py5cuwtbVVO9bcuXMRFBSE/v37w8bGBoaGhrCyskL79u3h5+eHAwcOiHXPiIiIiHKiaV0vbG23BrZpGWfUf5ZKMOnJPqze1RtCatYbYVDhIxGEbFRcJUJ6IVIbGxsA6ctzZIWIiYiIiIiI6OvyIfYpftjbA5eEeIW2ThJz/NxzL4xNNdtJlbIvPz5/c0YQERERERERESllblYOf/c/jV5FFGcsHxI+YOj2tnj94nrBB0Y5xkQQEREREREREalkoG+MWb0CMKVMG0gzLSq6qZeGfgcHIPz2Ti1FR9nFRBARERERERERqSWRSDDQZRlW1B4L07SMyaAXelIMvjQbJ8/8rKXoKDuYCCIiIiIiIiIijbRsOApb2qxAuUxFpBOkUnz/aAfW7+kPITVVS9GRJpgIIiIiIiIiIiKNVbZ1hl93fzRAxl1KBYkES2NDMdOvNZIS3mkpOsoKE0FERERERERElC2WFnZY2+80uhiVU2jzT3uH4f+2wdtXt7QQGWWFiSAiIiIiIiIiyjZDwyL4xeMQJlg1hyRTEelr0hT03++BB3f9tRQdqcJEEBERERERERHliEQiwZBOq7G0xnCYZCoi/VRPgoHnp+HsuYVaio6UYSKIiIiIiIiIiHKljeN4bGq1BKXTMn49TirF2PubsdnfC0JamvKTqUAxEUREREREREREuVatkgu2dd2NOoJhhq+nSST4/d1V/LytHZI/f9RSdCTDRBARERERERER5YmSllXg0/80OhlaK7TtTInBKD8nfHhzTwuRkQwTQURERERERESUZ4wMi2Khx1GMLemo0HZZmowB+3og4sFhLURGABNBRERERERERJTHJFIpRrqux2KHwTDOVET6sZ4EA87+gAsXl2onOB2nr+0AqPCrWbNmhj8nJydrKRIiIiIiIiL6knRoNhnlS1bHd0HT8EpuKspHqRSjw9dh+ps76NN5NSCRaC9IHcMZQURERERERESUb2pWcYOf+w5UFwwyfD1VIsG81+ex4N8OSEmK01J0ukciCIKQdTei/4mOjoaNjQ0AICoqCuXLl9dyRERERERERFTYxSe+x8xd3XAs5Y1CW/M0IyzqvhPFLOwKPrBCLD8+f3NGEBERERERERHluyLGFljc7wSGF6+n0HZO+hkDd7sjKuJkwQemY5gIIiIiIiIiIqICIZXq4bsum7GgkgcMMy1QeqQH9A8ch+Dgv7UUnW5gIoiIiIiIiIiICpRbi5lY33QuLDPtKPZeT4rhYSux5/A4gJVs8gUTQURERERERERU4OpV64ltnf1QRci4oXmKRIJZL0/hj+2uSE1K0FJ0Xy8mgoiIiIiIiIhIK8pa18HmPsfhpGeh0LbhcxS+93NC3Ifogg/sK8ZEEBERERERERFpjWmREljWLxDeZjUV2k5JEjB4Vyc8exKkhci+TkwEEREREREREZFW6enpY2L3f/GzXTfoZ6oNdE8P6Hd8JK6HrNdSdF8XJoKIiIiIiIiIqFDo7jQPaxpNh0WmItJv9aQYeuNP7D/2g5Yi+3owEUREREREREREhUbjWv3h12EDKqZlTFkkSSSY9uwo/trRFWkpSVqK7svHRBARERERERERFSo2ZRthS+9jaC4tptC2NuERJm1thfiPz7UQ2ZePiSAiIiIiIiIiKnSKFbXCin6n0L9oFYW2Y4iD138d8DL6shYi+7IxEUREREREREREhZK+viGm9dyNmeU7Qy9TEek7egL6HfXGrZtbtBTdl4mJICIiIiIiIiIq1DzaLsTf9SehWKYi0jF6UnhdXYCjJ6drKbIvDxNBRERERERERFToNavrha3t1sI2UxHpRKkUP0QF4J9dvSGkpmgpui8HE0FERERERERE9EWwt2mGrT0PwlFiqtC24lM4pm51wue4GC1E9uVgIoiIiIiIiIiIvhjmZuWwuv8p9Cxip9B2UIjFkO3t8Pr59QKP60vBRBARERERERERfVEM9I0xu5c/fizTFtJMRaRv6qWh36EBuHv7Py1FV7gxEUREREREREREXxyJRIJBLkuxvM44mGYqIv1CT4pBl+Yg8PRcLUVXeDERRERERERERERfrFYNRmJLm5UolybJ8PUEqRTjI/6Dz55+EFJTtRRd4cNEEBERERERERF90SrbOsGvuz8awCTD1wWJBH/GhuEnv9ZIin+rpegKFyaCiIiIiIiIiOiLZ2lhh7X9TqGLcXmFtn1p7zD83zZ49+qWFiIrXJgIIiIiIiIiIqKvgqFhEfzS5yAmWLeAJFMR6Wt6qei33wMPw/21FF3hwEQQEREREREREX01JBIJhnT8G3/WGAGTTEWkn+pJMPDCNAQF/aal6LSPiSAiIiIiIiIi+uq0dfwOm1r9Ceu0jF//JJVizIMtOHJ6tnYC0zImgoiIiIiIiIjoq1StUnv823UPasMow9dLCxI0quOlnaC0jIkgIiIiIiIiIvpqlbSsDJ++gehkaA0AKJImYLnzUpQobq/lyLRDX9sBEBERERERERHlJ2OjYljocRQVD41ADesGcLBvq+2QtIaJICIiIiIiIiL66kmkUoxyXaftMLSOS8OIiIiIiIiIiHQEE0FERERERERERDqCiSAiIiIiIiIiIh3BRBARERERERERkY5gIoiIiIiIiIiISEcwEUREREREREREpCOYCCIiIiIiIiIi0hFMBBERERERERER6QgmgoiIiIiIiIiIdAQTQUREREREREREOoKJICIiIiIiIiIiHcFEEBERERERERGRjmAiiIiIiIiIiIhIRzARRERERERERESkI5gIIiIiIiIiIiLSEUwEERERERERERHpCCaCiIiIiIiIiIh0BBNBREREREREREQ6gokgIiIiIiIiIiIdwUQQEREREREREZGOYCKIiIiIiIiIiEhHMBFERERERERERKQjmAgiIiIiIiIiItIRTAQREREREREREekIJoKIiIiIiIiIiHQEE0FERERERERERDqCiSAiIiIiIiIiIh3BRBARERERERERkY7Q13YA9OVJSUkRj58/f67FSIiIiIiIiIi+XvKfueU/i+cGE0GUbTExMeKxo6OjFiMhIiIiIiIi0g0xMTGws7PL9ThcGkZEREREREREpCMkgiAI2g6CviyJiYkIDQ0FAJQqVQr6+pxYVhDatGkDADh58qSWI/k66fLz/dLvvTDHX1hi00YcBXHN/LrG8+fPxRmvly9fRpkyZfJ0fNJtheX/C18jXX62X/q9F/b4C0N8X+vv8vy8ztfy+zwlJUVclVO7dm0YGxvnekx+gqdsMzY2RuPGjbUdhs4xMDAAAJQvX17LkXyddPn5fun3XpjjLyyxaSOOgrhmQVyjTJkyWv/+0delsPx/4Wuky8/2S7/3wh5/YYjva/1dXlDX+dJ/n+fFcjB5XBpGRERERERERKQjmAgiIiIiIiIiItIRTAQREREREREREekIFosmIiIikhMdHQ0bGxsAQFRU1BddU4CIiEhX8fe5apwRRERERERERESkI5gIIiIiIiIiIiLSEUwEERERERERERHpCNYIIiIiIiIiIiLSEZwRRERERERERESkI5gIIiIiIiIiIiLSEUwEERERERERERHpCCaCiIiIiIiIiIh0BBNBREREREREREQ6gokgIiIiIiIiIiIdwUQQERERUR55+vQpli5dChcXF1SoUAGGhoYoXbo0evbsiUuXLmk7PCIiIspCYmIiJk6ciFatWqFs2bIwNjZG6dKl0bx5c/j6+iI5OVnbIeaaRBAEQdtBEBEREX0Npk6dioULF6JSpUpwdnZGqVKlcP/+fezduxeCIMDPzw8eHh7aDpOIiIhUeP36NWxsbODo6AgHBweUKlUK7969w6FDh/D48WO4uLjg0KFDkEq/3Hk1TAQRERER5ZHdu3ejRIkScHJyyvD1s2fPom3btihatCieP38OIyMjLUVIRERE6qSlpSElJQWGhoYZvp6SkoL27dvj1KlT2L9/P1xdXbUUYe59uSksIiIiokKmR48eCkkgAGjZsiVat26Nd+/eITQ0VAuRERERkSakUqlCEggA9PX10b17dwDAgwcPCjqsPMVEEBEREX0VXr16hf3792PWrFno1KkTSpYsCYlEAolEAi8vr2yN9fjxY/zwww+oVq0aTE1NYWlpicaNG2PRokWIj4/PUXwGBgYA0v8iSURERIoK8+/ytLQ0HD58GABQq1atbJ9fmHBpGBEREX0VJBKJyjZPT09s2LBBo3ECAgIwcOBAxMbGKm13cHDAgQMHULlyZY1je/LkCRwcHGBpaYmoqCjo6elpfC4REZGuKEy/y5OSkvDrr79CEAS8efMGJ06cQHh4OLy9veHj46NRHIUV/0mKiIiIvjoVKlRAtWrVcPTo0WydFxISAg8PDyQkJKBo0aKYNm0aWrdujYSEBPz7779Yu3Yt7t27B1dXVwQHB6NYsWJZjpmcnIxBgwbh8+fPWLhwIZNAREREGtD27/KkpCTMnTtX/LNEIsGkSZOwYMGCXN1XYcBEEBEREX0VZs2ahcaNG6Nx48awtrZGZGQk7O3tszXG+PHjkZCQAH19fRw9ehTNmjUT29q0aYMqVargxx9/xL179/DHH39gzpw5asdLS0uDl5cXzpw5g+HDh2PQoEE5uTUiIiKdUJh+lxctWhSCICAtLQ3Pnj1DQEAApk+fjgsXLuDgwYMwMzPLza1qFZeGERER0VdJ/i+Pmkwnv3z5Mpo0aQIAGDlyJFavXq3QJy0tDbVq1cKdO3dgYWGBV69eibV/lPUdMmQINm7ciIEDB2Ljxo1f9FazREREBU3bv8sz+++//9CnTx/8+OOPWLhwYfZuphDh30aIiIiIAOzdu1c89vb2VtpHKpVi8ODBAID3798jMDBQab+0tDR4e3tj48aN6NevHzZs2MAkEBERUT7Ly9/lyri4uAAATp06leMYCwP+jYSIiIgIQFBQEADA1NQUDRs2VNlPfnv4c+fOKbTLkkCbNm2Ch4cHNm/ezLpAREREBSCvfper8uzZMwDQeAZRYcVEEBERERGAO3fuAAAqV66sdov3atWqKZwjI1sOtmnTJvTu3RtbtmxhEoiIiKiA5MXv8tu3byvdXj4+Ph4TJ04EAHTu3DkvwtUaFosmIiIinZeYmIjXr18DAMqXL6+2b/HixWFqaoq4uDhERUVlaPv555+xceNGFC1aFA4ODvjll18Uzu/WrRvq1auXZ7ETERFR3v0u37FjB5YsWYIWLVrAzs4OZmZmePr0KQ4dOoQ3b96gZcuWmDBhQr7dR0FgIoiIiIh03sePH8XjokWLZtlf9pfHT58+Zfh6ZGQkAODTp0+YP3++0nPt7OyYCCIiIspjefW73M3NDc+ePcP58+dx4cIFfPr0Cebm5qhTpw769u2LIUOGqJ1t9CX4sqMnIiIiygOJiYnisaGhYZb9jYyMAAAJCQkZvr5hw4YsdzQhIiKivJdXv8sbNWqERo0a5W1whQxrBBEREZHOMzY2Fo+TkpKy7P/582cAgImJSb7FRERERJrj73LNMRFEREREOq9YsWLiceYp4srExcUB0GzqOREREeU//i7XHBNBREREpPOMjY1RokQJAEB0dLTavu/evRP/8mhjY5PvsREREVHW+Ltcc0wEEREREQGoUaMGAODBgwdISUlR2S88PFw8rl69er7HRURERJrh73LNMBFEREREBKBFixYA0qeKX716VWW/06dPi8fNmzfP97iIiIhIM/xdrhkmgoiIiIgAdOvWTTz29fVV2ictLQ2bNm0CAFhYWKB169YFERoRERFpgL/LNcNEEBEREREAR0dHtGzZEgCwfv16XLhwQaHPH3/8gTt37gAAxo8fDwMDgwKNkYiIiFTj73LNSARBELQdBBEREVFuBQUF4cGDB+KfX79+jcmTJwNIn/Y9bNiwDP29vLwUxggJCUHz5s2RkJCAokWLYvr06WjdujUSEhLw77//Ys2aNQAABwcHBAcHZ9ihhIiIiHKHv8sLBhNBRERE9FXw8vLCxo0bNe6v6q9AAQEBGDhwIGJjY5W2Ozg44MCBA6hcuXKO4iQiIiLl+Lu8YHBpGBEREZEcd3d33Lx5ExMmTICDgwOKFCkCCwsLNGrUCAsXLkRISIjO/sWRiIjoS8Df5epxRhARERERERERkY7gjCAiIiIiIiIiIh3BRBARERERERERkY5gIoiIiIiIiIiISEcwEUREREREREREpCOYCCIiIiIiIiIi0hFMBBERERERERER6QgmgoiIiIiIiIiIdAQTQUREREREREREOoKJICIiIiIiIiIiHcFEEBERERERERGRjmAiiIiIiIiIiIhIRzARRERERERERESkI5gIIiIiIiIiIiLSEUwEERERERERERHpCCaCiIiIiIiIiIh0BBNBREREREREREQ6gokgIiIiIiIiIiIdwUQQEREREank7OwMiUQCZ2dnbYdSqB04cAAdOnRAyZIloaenB4lEAgsLC22HRUREpEBf2wEQEREREX3JVq1ahTFjxmg7DCIiIo1wRhARERERUQ7Fx8dj+vTpAIBq1aph586dCAkJQWhoKC5cuFAgMXh5eUEikcDOzq5ArkdERF82zggiIiIiIsqh4OBgfPjwAQCwePFiuLq6ajkiIiIi9TgjiIiIiIgoh54+fSoeOzg4aDESIiIizTARRERERESUQ58/fxaPDQwMtBgJERGRZpgIIiIiyqWwsDD88ssv6NChA8qXLw8jIyMULVoUVapUgaenJy5evKj0vPj4eBQrVgwSiQQDBgzI8joXLlyARCKBRCLBqlWrlPZ58eIFZsyYgUaNGsHS0hJGRkawsbFBnz59cPz4cZVjR0ZGimNv2LABALB792507twZZcuWhb6+vsKuURcvXsTMmTPh7OyM0qVLw9DQEGZmZqhRowZGjx6N27dvZ3lPAPDkyROMHj0a9vb2MDY2RtmyZdGtWzcEBgYCAObMmSPGps6HDx+wYMECNG/eHKVKlYKhoSHKlCkDd3d37Ny5E4IgaBSPMsqez7Fjx+Du7o7SpUvDyMgI9vb2GD16NKKjo1WOo2ktlw0bNojXi4yMVGi3s7ODRCKBl5cXAODatWsYMGAAbGxsYGJigsqVK2PixIl4/fp1hvPOnz+P3r17o0KFCjA2NkalSpUwZcoUfPz4UeNncffuXYwYMUL8fpUpUwZ9+vRR+T7PrCDfo5qKiYnBzJkzUb9+fVhYWMDY2Bh2dnYYNGgQgoKClJ4j203N29tb/Jq9vb0Yo0QiwalTp7IVR2JiIv766y84OzujVKlSMDAwgKWlJapWrYpOnTphyZIlGd4Psp+NjRs3AgAeP36c4frqfm4SExOxYsUKtG3bVvz5tbKyQrt27bB+/XqkpKSojDPz++/KlSvo168fbGxsYGxsDBsbG3h7eyM8PDxP75eIiPKIQERERDkWGBgoAMjyNXXqVKXnDxw4UAAgmJqaCp8+fVJ7rTFjxggABH19fSEmJkahfcuWLYKpqanaOIYOHSokJycrnBsRESH28fHxEQYNGqRwrpOTk9jf19c3y3vW09MTVq5cqfaeTpw4IRQtWlTp+RKJRJg/f74we/Zs8WuqHD9+XChRooTaeDp37ix8/PhRbTyqyD8fX19fYerUqSqvU6pUKeH27dtKx/H09BQACLa2tmqvJ/98IyIiFNptbW0FAIKnp6ewadMmwdDQUGksDg4OwvPnzwVBEIRFixYJEolEab8GDRqofDZOTk7i9//gwYMq32NSqVT4888/1d5XQb5HNXXkyBHBzMxMbUxjxowRUlNTlT4Xda/AwECN43j27JlQo0aNLMf84YcfxHPkfzbUvTK7fv26+B5S9WrcuLHw4sULpbHKv//Wr18v6OvrKx3DyMhI2LFjR57dLxER5Q0mgoiIiHLh2LFjgqmpqdCnTx9h9erVwqlTp4Rr164Jhw8fFv74448MH7Z8fHwUzj906JDYvnXrVpXXSU5OFqysrAQAgqurq0L79u3bxQ/5FStWFJYsWSIcPnxYuHr1qrBr1y6hc+fO4nUmTJigcL78h+w6deoIAISWLVsKfn5+QnBwsHD8+HFh3bp1Yv+1a9cKxYsXF7y8vAQfHx/h7NmzwrVr14T9+/cLP//8s1CyZEkxmXPixAml9/Tw4UMxKaCvry+MGzdOOHHihHDlyhXB19dX/JDYpEkTtYmgoKAgwcDAQAAgWFtbC7/88osQEBAgXL16VQgICBCTbQCEHj16qHzG6sg/n2+++UZMOsg/n8GDB4t9mjZtqnScvE4E1atXTzA0NBRq1Kgh+Pj4CFeuXBFOnjyZ4Z4HDBgg7Nq1S4xr69atQnBwsHD48OEM74spU6YojUWW8KhSpYpgYWEhmJubC7/++qtw/vx54fz588L8+fMzJFL27NmjdJyCfo9qIiQkREyiGRgYCBMmTBACAwOFy5cvC//8849gb28vXvPHH3/McO6jR4+E0NBQ4ZdffhH7HDlyRAgNDRVfWSV35fXs2VMcZ+DAgcLu3buFixcvCleuXBH8/f2FWbNmCXXr1s2QGHn58qUQGhoqdO3aVQAglC1bNsP1ZS959+/fF8zNzQUAgpmZmTBt2jRhz549QnBwsHDkyBFhzJgxYmKnSZMmQlJSkkKssvdf3bp1BQMDA6Fs2bLC8uXLhUuXLgmnT58WpkyZIhgZGYnP9cqVK3lyv0RElDeYCCIiIsqFmJgY4d27dyrbP3/+LLRv31788J+SkpKhPasEj4x8wsjPz08hBtkHuyFDhiidTSEIgjB9+nRx5kZ4eHiGNvkP2QCEwYMHC2lpaSrjiY6OFuLi4lS2v3//Xvyw3qJFC6V9unXrpjZ5EBcXJzg6Oqqd2ZCUlCTY2dkJAISOHTuqjGnNmjXiGEePHlUZtyqZn8/w4cOVPp9hw4aJfa5du6bQnteJIFliStl99+rVSwDSZ2ZZWloKPXv2VHj/paSkCE2bNhUACCVKlFD63pGf+WJubq50tlNYWJiYDCpXrpxC8kAb71FNNG7cWHxGR44cUWh/+/atmJCUSqVCWFiYQp+svleaSEhIEJOZWSU+3rx5o/A1Td9XgiCIicz69esrnVkoCOn/v5FKpQIAYc2aNQrt8u8/W1tbcdaZvJMnT4oJpcaNG2doy+39EhFR7jARRERElM+uX78ufmgKDg5WaB83bpz4L+evX79WOoZshkfRokUVPvT//PPP4gfwxMRElXEkJycL5cqVEwAI06dPz9Am/yHbwsJCiI2NzcGdZrR3715xzMz39fTpU0FPT08AIPTq1UvlGPLPTlkiaNOmTQIAwdjYWHj16pXaeGRJpf79+2f7XuSfT5kyZVQ+5/DwcLHfsmXLFNrzOhEkkUhULkM7efKkOEaRIkVUfqD28fER+924cUOhXT4RtHjxYpUxL1y4UOz333//ZWgrjO/RS5cuieONGjVKZb+goCCx37fffqvQnheJoKdPn4pj7Nu3L9vna/q+OnPmjHidmzdvqu3bp08fMdGYmXwiaOfOnSrHGD16tNhPflZQbu+XiIhyh8WiiYiI8tDnz5/x5MkT3L59G2FhYQgLC8tQpPjGjRsK58gKRScnJ2PHjh0K7QkJCdi7dy8AoFu3bihSpEiGdn9/fwCAm5sbjIyMVMamr6+PZs2aAUgvPK2Ku7s7ihUrprJdmbi4OERGRuLWrVvifcvvoJT5vgMDA5GamgoAGDRokMpx69ati7p166psl927k5MTSpUqpTbGVq1aAVB/75ro1auXyudctWpVFC1aFADw6NGjXF1HE3Xq1EH16tWVtsk/t/bt28PS0jLLfupilkgk8PT0VNnu7e0tFibOXPS5MLxHM5OPcejQoSr7NW/eXHzG6opZ50aJEiVgaGgIANi8ebPaQs25Ifs+VK1aFbVr11bbV/bzcuXKFZXxFC9eHF27dlU5xpAhQ8Rj+WdXUPdLRETKMRFERESUS3FxcViwYAHq1q0LU1NT2NraombNmqhduzZq166N+vXri30z7+IEAE2aNEGlSpUAAFu3blVo9/f3x6dPnwBAYXex1NRUXL9+HQDwzz//KN0xSP61c+dOAOk7N6lSp04dje779evXmD59OqpWrYpixYrB3t4etWrVEu/b1dVV5X2HhYWJxw0bNlR7nUaNGqlsCw4OBgAcOXIky3tfvHgxAPX3rolq1aqpbS9evDgAZGsnrpxycHBQ2WZhYZHtfupitre3R8mSJVW2lypVStwNLTQ0VPy6Nt+j6sjeg4aGhqhXr57avk2aNAEA3L9/H0lJSbm+dmZGRkbw8PAAAOzcuROVK1fGjz/+iIMHD+L9+/d5dh3Zz8vdu3ez/D6MHTsWQHqC+u3bt0rHq1+/PvT19VVer169emLCR/49UVD3S0REyjERRERElAuRkZGoXbs2pk+fjps3b4qzXFRJSEhQ+nVZguf8+fMK2yXLkkOyrZ3lvX37Nkf/mh4fH6+yTZbIUOfq1auoVq0aFixYgHv37mW5NXvm+3737p14nNVMHnXtr169yjLWrGLJrswzsjKTStP/epXVeyEvqItFFkd2+qmL2crKKst4rK2tASBD4kBb79GsyGK0tLRUm8wAgNKlSwMABEHI8N7NSytWrIC7uzuA9G3gFy1aBFdXV5QoUQKNGzfGokWL8OHDh1xdIyc/L4Dq70VW7wl9fX1xJlrmZFJB3C8RESmn/rceERERqTVo0CBERERAIpHA29sbffv2RfXq1VGqVCkYGhpCIpEgLS0Nenp6AKAyYTJgwAD8/PPPEAQB27Ztw7Rp0wCkf3g6cuQIAMDDw0PhA6v8B/dhw4Zh/PjxGsUt+1d6ZWSxqpKUlIQ+ffrgzZs3MDAwwLhx49C1a1c4ODigePHi4tKfR48eiTOdskoU5ZTs/jt16oTff/89X65B6WTLvrJLG+/R7MjpfeU1MzMz+Pv74/Lly9ixYwdOnTqF69evIzU1FcHBwQgODsbixYuxd+9ecflcdsm+F3Xr1sWWLVs0Pq9cuXJKv56bZ1cQ90tERMoxEURERJRD4eHhCAoKAgBMnz4dv/zyi9J+qpZVyHNwcECjRo0QHBwMPz8/MRG0c+dOcSlK5mVhADLUfREEAbVq1cr2fWTXyZMnxVoyq1atwrBhw5T2U3ff8jM6YmJiVH7QlLWrUqJECTx79gxJSUkFcu+5JZt9k5aWprZfXFxcQYSTLS9fvtS4j/z7UhvvUU3I4nrz5g1SUlLUzgqSLVOTSCR5MhtJHUdHRzg6OgJIX6p36tQpbNiwAbt378arV6/Qs2dPPHz4ECYmJtkeu0SJEgCAT58+5cn3Iav3REpKSoaZV8rk5/0SEZFyXBpGRESUQ7du3RKPZfUulJHV5ciKLNETFhaGmzdvAvjfsrBKlSqJdUrkGRoaombNmgCAc+fOaRZ4LuXFfctiBtKXmamjbhxZ/aXg4OB8qd2S12QFjrOqg3Lv3r0CiCZ7IiIi8ObNG5XtMTEx4rJG+SSDNt6jmpDFmJSUJNYwUuXy5csAgCpVqqidqZTXihUrBnd3d+zatQvfffcdAOD58+diAlpG05k5sp+XR48e5bpWFgBcv35d7bK/GzduiD+XmiSeNL1fIiLKHSaCiIiIckj+A5C6GRyrV6/WaLy+ffuKS162bt2K6OhonD17FoDy2UAyXbp0AZA+Q0m2jCw/aXLfaWlpWLt2rcoxnJ2dxdkxmzdvVtnvxo0bSndak5Hd+4cPH+Dr66s27sLA3t4eQPrMh7t37yrtk5SUhF27dhVkWBoRBAGbNm1S2b5hwwZxCWDmWlYF/R7VhHyMPj4+KvtduHABt2/fVjinoLVt21Y8zlx83djYGED6roXqyL4PgiBg2bJluY7p7du3CAgIUNku/1yz++zU3S8REeUOE0FEREQ5VKVKFfF4w4YNSvv8/fff2Ldvn0bjlS5dGm3atAEAbNu2DX5+fuIHa3WJoPHjx4tblnt7e2eYsaPMgQMHxBlHOaHJfU+bNg3Xrl1TOUb58uXFXcV27tyJvXv3KvRJSEjAiBEj1Mbi6ekJGxsbAMCkSZNw5swZtf2DgoJw+vRptX3yk5OTk3j8xx9/KO0zceJEPH36tKBCypZ58+YpTWDduXMH8+fPBwCUKVNGYUvxgn6PasLR0VHckW7t2rU4ceKEQp8PHz5g5MiRANKX9Y0ePTpfYnn06FGW78ujR4+Kx7KEokyZMmUApBeDVrfzm4uLi7gMa9GiRdixY4faa4aGhqpN9ADp71dlS8ROnz6NNWvWAEjfGbBx48ZiW27vl4iIcoc1goiIiHKofv36qFWrFsLCwvDPP//g3bt3GDRoEMqUKYPo6Ghs2bIFO3fuRPPmzTVeEjNgwAAcO3YMUVFRWLBgAYD07dPVbf9tbW2NjRs3olevXnj+/DkaNWoELy8vdOrUCeXLl0dycjKio6Nx+fJl7Ny5E48ePUJAQECOt+Du0KEDrKys8OrVK8ycORORkZHo3r07SpYsiQcPHogfqrO67yVLluDEiROIj49H7969MXr0aHTv3h1mZmYICwvD77//jtu3b6Nx48a4cuWK0jGMjIywY8cOODs749OnT2jTpg369u2Lbt26wd7eHmlpaXj+/DmuXr2KPXv2IDQ0FMuXL8+QkClI9evXR7NmzXDhwgWsXbsWSUlJ8PT0hLm5Oe7fv481a9bg5MmT+Oabb3D+/HmtxKhK5cqVERMTg6ZNm2LKlClwdnYGAJw6dQq//fabuMPT8uXLFZZPFfR7VFNr165FkyZNkJSUhM6dO2PcuHFwd3eHqakpQkJC8Ntvv4n1sCZNmpRv9Y2ePHmC1q1bo0aNGujevTsaNWok1s2KiorC9u3bxaRNvXr1FJaJfvPNNwDSZ+KNGjUK48aNQ8mSJcX2ypUri8d+fn5wdHTE27dv4eHhgS1btsDDwwNVqlSBnp4eXr16hZCQEAQEBODixYv44YcfxN29Mqtbty5u376Nhg0bYtq0aXB0dMTnz59x8OBB/Pnnn2LtpZUrV+bp/RIRUS4JRERElGMhISFC8eLFBQBKX7Vr1xaePXsm/nn27Nlqx4uNjRVMTEwyjPHnn39qFIu/v79gaWmpMhbZSyqVCidPnsxwbkREhNju6+ub5bUOHz4sGBsbq7yGs7OzEBYWluWYR48eFUxNTVWOM3v2bOGnn34SAAjGxsYq47lw4YJgY2OT5b0DEDZu3KjR88zp87G1tRUACJ6enkrb79y5I1hZWamMb9KkSYKvr6/454iIiGxfQ0aT911W9+bk5CQAEJycnIT9+/cLRYoUUfm+Wrx4sdp4CvI9qqkjR44IZmZmauMZM2aMkJqaqvT8rL5XmggMDNTovVutWjXh0aNHCuenpqYKTZs2VXleZnfv3hVq1aql0TXnzp2rcL78+2/t2rWCvr6+0nMNDQ2Fbdu25fn9EhFR7nBpGBERUS7Uq1cP169fx6hRo2BrawsDAwNYWlrC0dERixcvxuXLl8VlG5qQFUuV0dPTQ9++fTU6193dHREREVi8eDHatGkDa2trGBgYwMTEBPb29nBzc8OSJUsQGRmJ1q1bZ/te5XXo0AHBwcEYOHAgypYtCwMDA5QqVQpOTk5Ys2YNTpw4AVNT0yzHad++PcLCwjBy5EjY2trC0NAQ1tbWcHV1xeHDhzFnzhzExsYCAMzNzVWO07RpU9y/fx+rV6+Gq6srypYtC0NDQxgbG8PGxgYuLi6YP38+wsPDMXjw4Fzde25Vq1YN165dw+jRo8V7LlWqFDp27IgDBw5g0aJFWo1PHVdXVwQHB8Pb21uM3crKCj179kRQUBB++OEHtecX5HtUUy4uLnjw4AGmT5+OevXqwczMDEZGRqhQoQIGDBiAs2fPYsWKFWJNq/zQsmVLnDp1CtOmTUPr1q1RuXJlFCtWDAYGBrC2toaLiwtWr16N69evK10mJZVKcfToUcycORN169ZF0aJF1RaQdnBwwPXr1+Hn54eePXuiQoUKMDExgaGhIcqUKQNnZ2fMnDkTV69exaxZs9TGPmzYMJw9exZ9+vQRf+7KlSuHwYMHIyQkROn/v3J7v0RElDsSQfj/4gNEREREhVC7du1w4sQJtGjRQiyeTUTaY2dnh8ePH8PT01NlnTAiIiq8OCOIiIiICq1nz56JBaCbNm2q5WiIiIiIvnxMBBEREZHWPHjwQGVbQkICvLy8kJycDABaX9JFRERE9DXgrmFERESkNcOGDUNcXBz69OmDhg0bwtLSEh8/fkRwcDBWrVolJoqGDh2K2rVrazlaIiIioi8fE0FERESkVcHBwQgODlbZ3r17dyxfvrwAIyIiIiL6ejERRERERFqzZMkS7NmzBydPnkR0dDRiYmIgCAKsrKzQtGlTeHp6onPnztoOk4iIiOirwV3DiIiIiIiIiIh0BItFExERERERERHpCCaCiIiIiIiIiIh0BBNBREREREREREQ6gokgIiIiIiIiIiIdwUQQEREREREREZGOYCKIiIiIiIiIiEhHMBFERERERERERKQjmAgiIiIiIiIiItIRTAQREREREREREekIJoKIiIiIiIiIiHQEE0FERERERERERDqCiSAiIiIiIiIiIh3BRBARERERERERkY5gIoiIiIiIiIiISEcwEUREREREREREpCOYCCIiIiIiIiIi0hFMBBERERERERER6QgmgoiIiIiIiIiIdAQTQUREREREREREOuL/AO1SQogrFlycAAAAAElFTkSuQmCC" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "draw_order_multiple(\n", + " [out_SLOWRK_mlp_sde, out_SPaRK_mlp_sde, out_GenShARK_mlp_sde],\n", + " [\"SlowRK\", \"SPaRK\", \"GeneralShARK\"],\n", + " title=\"Order of convergence on a general SDE\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a260a1022d30c8", + "metadata": { + "collapsed": false + }, + "source": [ + "## Commutative noise SDEs" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "4c72a44488366e9d", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-02T15:06:28.473708Z", + "start_time": "2024-04-02T15:06:26.187052Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAGeCAYAAACpVGq5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAChDklEQVR4nOzddXhU19bA4d9M3D0hIQnu7u6UUiilRo0K1Fvqt+1Xt9te6re3rrTU3WgLLYWipbi7BUJC3D2Zme+PPZqZhMhMMknW+zx55tics4NkVvZee22NwWAwIIQQQgjRDLTN3QAhhBBCtF0SiAghhBCi2UggIoQQQohmI4GIEEIIIZqNBCJCCCGEaDYSiAghhBCi2UggIoQQQohmI4GIEEIIIZqNBCJCCCGEaDaezd2A2uj1elJTUwkKCkKj0TR3c4QQQghRBwaDgcLCQuLi4tBqa+/zcFkgotPpeOKJJ/j0009JS0sjLi6OefPm8cgjj9Q5qEhNTSUhIcFVTRRCCCGECyUnJxMfH1/rNS4LRJ577jneeustFi9eTJ8+fdiyZQvz588nJCSEO+64o073CAoKAtQ3Ehwc7KqmCiGEEMKJCgoKSEhIMH+O18Zlgcjff//N7NmzmTlzJgAdO3bkiy++YNOmTXW+h6nnJDg4WAIRIYQQooWpywiIy5JVR48ezYoVKzh06BAAO3fuZN26dZxzzjk1vqe8vJyCggKbLyGEEEK0Xi7rEXnggQcoKCigZ8+eeHh4oNPpeOaZZ5g7d26N71m4cCFPPvmkq5okhBBCCDfjsh6Rr7/+ms8++4zPP/+cbdu2sXjxYl588UUWL15c43sefPBB8vPzzV/Jycmuap4QQggh3IDGYDAYXHHjhIQEHnjgARYsWGA+9vTTT/Ppp59y4MCBOt2joKCAkJAQ8vPzJUdECCGEaCHq8/ntsh6RkpISu7nDHh4e6PV6Vz1SCCGEEC2My3JEZs2axTPPPENiYiJ9+vRh+/btvPzyy1x77bWueqQQQgghWhiXDc0UFhby6KOP8sMPP5CRkUFcXByXX345jz32GN7e3nW6hwzNCCGEEC1PfT6/XRaIOIMEIkIIIUTL4xY5IkIIIYQQZyKBiBBCCCGajQQiQgghhGg2EogIIYQQotlIICKEEKLVSs0r5Y2/jpBXUtHcTRE1cFkdESGEEKK5Xb94C/tOF3A0o4iXLx3Y3M0RDkiPiBBCiFZr32m1ivsvu083c0tETSQQEUII0Srp9JYyWT6e8nHnruRvRgghRKt0PKvIvK3TG3Dj+p1tmgQiQgghWqUdyfnm7ZIKHZlF5c3YGlETCUSEEEK0OhmFZdz7zU6bY0lZJc3UGlEbCUSEEEK0OuuPZJm3vY35IdZDNcJ9SCAihBCi1TmVUwrA4MRQrhieCMCxzOLmbJKogQQiQgghWqQ9Kfk8+P0uMgvtcz+Sc9UwzMQe0fRsFwTA1hO5Tdo+UTdS0EwIIYTLVOr0vLbiMEezinns3N7EBPs65b5llTrOfW0dAD6eHjxxXh+b86dyVY9IfJgfwzqGA7A9OY+CskqCfb2c0gbhHNIjIoQQwiWyi8p5468jvLryCL/uOs1Xm5Oddu9/jmWbt/86mIHBYOCrzSfZeiIHsA5E/EkI9yc2xBed3sDhdMkTcTfSIyKEEMLpft11mgWfb7M5djijcUGAXm/gWFYRXaIC2ZGcZz5+IruEP/al83/f7QZgSs9oUvNUIJIQ7gdATLAvp/PLyJIpvG5HAhEhhBBOVz0IATjSyEDkxT8O8uaqo/zf9J42gQjAJxtOmLdXHMgAwMtDQ3SQGgqKCvIBcJhPIpqXBCJCCCGc6i9jIFDdscwiTueXEhPki1arqfd931x1FIDnlh0g1N82z2Od1XRdk7hQPzyMz5FAxH1JjogQQgin+nqL41yQ8io9oxau5Jnf9jf6GXkllXh7ahnVOcLm+KwBcebt+DA/83ZUoApE/tiXLqXe3YwEIkIIIZzqYFqhzb5WA12jA837H29Icspz+sYF2+R8TOsdw9Re0VbnQ8zbph6R/acL2HZSpvG6EwlEhBBCOE1ZpY6kbNvCYSF+XszoF2ve7xAR0KB7+3rZfmQNTAjjmtEdAQj19+L5i/szqaclELlwcLzVtaHm7X2pBQ16vnANyRERQgjhNEnZxeirjXyE+Xtzw7hOnM4r5Zutpygsq2zQvQN9vCirtPSA9GgXyIWD4+kcFcCwjuF4eahA5YsbRlJcXkUPYyEzgL7tQ7hocDzfbTvF4YwiSit0+Hl7NKgdwrmkR0QIIYTTnM4vA6BDhL/5mI+XB0G+XtwxpRug8jsawt8qcPDUapjWux1eHlpGd4k0ByEAo7pEMLV3jN37e8cFA/DxhhNMe2U12TKV1y1IICKEELU4klHI2sOZzd2MFiPdGIh0jrQMv7QPVVNowwK8AZW0Wlapq/e9K6r0APx7dh+O/GeG+X51ZZ28mpxTyvvrjte7DcL5JBARQohqTmQXs/jvJMoqdVzx3kau+mAT39QwE0TYMvWItAuxfOgP6aBKrAd4e+BpnE7bkF6R4ooqAMZ0jWxQ23paDdWAfVKtaB6SIyKEENVMeGEVAGsPZ5JhrDtx37e7SM0ro2/7YN746wgA14/rbJOEKSC9wBiIBPsyf0xHtp7I5apRHQDQaDSE+nuRVVRBXmkFTy7Zy+n8Mj65bjhp+WV8vOEEt03uarMezdHMIm7+ZCuFZVUUlqlAJMCnYR9dHSICuHZMJxatVz0hKw9k8PWWZOYMiUejqX9dE+EcEogIIYRRaYWOWz/bat7/c79tYa7//nnIZv8/v+13i0Dkz33pfPLPCV64uD/RTlpUriFeXXGYL43rycSG+HLn1G5214T4qUAkJbeUpXvSAHjpj0OsPJDByZwSjmUV8dn1I83XrzucZVca3r8RSaaPntuLi4fEM+PVtQDc/+0uvD20nD+ofYPvKRpHhmaEEAIwGAxMfXk1fx2sez7IqdxSc66DTm8w9wY0tes/3sLqQ5k8/WvjC4U1xsvLLYGaKTG0OlM9jz0plim0H/2dxMmcEgDWH8m2ub6ovMruHv7eDf8dWqPR0LNdEN6elo+/JTtTG3w/0XgSiAghBFBQVkWKcaG06jpF2ta98PHUEmgcHkg2foDe8/UORvxnBVuSclzbUKOSiipKKmw/pI9lNd/KspU6vc1+j2r5GCaxxtyRXafyaryXdeVT03CMNY8GlIe3ptVqiLRKdD2SKSvyNicJRIQQAsyrtToyODHMZj8y0MccnFzx/kaOZxXz0w71W7VpPRRXKq3QMeGFVZz9yhqbyqLllfpa3uVa1mu4TO0VYzOd1pop/2Pnqfwa7/XFpmTz1NpiY4/I1F4x9I4N5pKh8TW+rz6sa4icyC5BV734iWgyEogIIQRwOt82EOkdaxlasC5PDhAZ6M3UXqpORWZhuc3aKtV7Blxh3+l8MgvLSc4pZejTfzrl2dlF5Y36ME6zGpZ65bKBNV4XG6ICEVMANbN/LE/M6m1zzUM/7Gbaf9dQWFZpHpoZ3imM3+4cx/MXD2hwG61V71XZcDS7hiuFq0kgIoQQQGqebX7Hb3eOY/nd4/nj7vF2v4UH+Xpx++Su5rLhOUUV5nNVOtf/Zr3vtONpp/ml9Z8Sm1FYxnUfbWbI03/ywHe76v3+Q+mFLFy6n6PGhNKhHcLMw1aOtAuxTaaNDvLhkmEJjK02JTe7uIJBTy1n6wm1Lkygj+1qu42lrTZL5sEf6v+9C+eQQEQIIcBhfki3mCC6xwQREejD3ifPNh+vqNKj1Wq4aLCaaZGcW2I+1xQ9IgdOO14rJa+0Er2DXo2KKr1dj4/JO6uPseKAmh30zdZTDpNDa3PRW3/zzupj3Pet+iBPtKqo6sjQDrbDXH3jQvD39uTT60fw30ttezuq9AZzEmuAj3PLsVefrpucU2oumCaalgQiQggBHDpDcasAH08+mj+MhHA/bpnUBYAQf5XweNQq2TG3pMLh+50pvcBxaXKDAYqqJbCm5JUyauEKRi1cyd7UfJtAxWAwsMw4hdZk8/H6JdtWTyY9p2/t05kjAn345Lrh5v3x3aPM2+f2j6NHjOMk1yBf51absE5hMQ3T5BS7/u9O2JNARAjRZry24jBXfbCR8ir78uL7rXoZ3pw72OH7J/aIZu39k5nUQ63wGuqnhgusA4OaggRnqr5o3N1Tu5unoxZYDc9U6fSMe24l2cYP2JmvruOCN9ebc0FySyrNPUGDE0OB+g3vVP9z9PbQMr77maueju4SyVUjO3Df2T3M03kBvDy0fHfraLY8MpXbJnW1eY+zh2ZirSq/hhtn0Nz+xTabfweiaUggIoRoE55deoCXlh9i7eEsVlQrVJZfUkmqsTT5riem1blIWai//YdjUXlVvYc36qt6L0RMsA/BvqotBaXqXFp+GQs+32a3Eu7OU/n8vlf1gpjqnoQHeBMZqAKC4oq6t/1kdonNfoVOj4/nmYdQPLQa/n1+XxZUCzYAAn08iQz0sesBcfbQzBPn9WFEp3DevnKI+XvfnJTL3Pc3OvU54swkEBFCtHoFZZW8vdoyrbb67JD0QvWBHObvZf5Ar4tQP8eLrqXlu7awWWG5ba9FZKAPwcYP7hmvrqWsUsftX2zj973pDt+/yTj8YprpEhPsay6bbpouq9cbeHPVEfO11X2x6STXLNpkc+ymCZ0b+B3ZC6r291BbAmxDtA/146ubRjG9bzsiAy1/jzI80/SkxLsQotUzFR0zKag2tJFtnPVS39VcQ6r1iPh7e1BSoSMtv8xuyq8zFVXrEfH18qDcKtHy/m93sTkpt8b3m4ZqMszrwviYexyKy9Vwy/fbU3h+2UEAkp6daX7v15uTub/a7JoZ/dpxVu8YpvVu19BvyU6wn+3HU01BnzOYekRE85AeESFEq3cq13bGiPV025ziCtYeVmXdI+oZiARV+y29l7H2yJUfbCSjsIzrF29m8d9JDWhxzQwGg3loxt/bAy8PDQMSQsi0Kmz28xlKlpuKhW07kQcYe0SMZdNN1Vr3pFgKjpl6CY5mFtkFIQB94kK4YFB8gxejc8S6R8RDq3F6sqo161V5/bycOwQkzkx6RIQQrdarKw6z8kCGObnU5KXlh/hpZypHM4uwqiZOmH/9AhFttaJY3WOCzHUvpry0msKyKv7cn8E1ozs2qP2OlFXqqTIOLa29fxJenlqCfL1qnXp61cgOfPLPCfN+dlEFOcUVfLXFtECdJXGzuEL1iBzPKjYf23kqj0k9oknJdTwFuHO1EvjOEGwVeIT6edn9WTvTvDEd+XjDCVLySm2SZ0XTkB4RIUSrVFap4+Xlh9iRnMf7a48B4OVh+TA7kmEbhABEBNa/+9/6nrdP7mr+zd3RGinOsPqQ6r3RalSSqSmnpaZpr89d1I8bx6vcDVOeRXZxubk+B8ClwxLMQzO7TuUx/ZU15ucA7EzOo6JKzyM/7nH4jPiw2muHNIR1j4ijpGBn8vH04P1rhgLYrd8jXE8CESFEq7TZavG5QmMCZtdoxx/WJsF+9f/A87Xqyo8L9eOlOc4pQe7Ikp2p3PzpVkAFFdZFud66cjBXj+rA5J62vT/ju0eREO7Ppoen8Pvd4wE11GKqd9KzXRDtQizJqntSCjhQrabKzuQ8Pt94why8nNU7xuZ89WqpzmAdfFivC+MqpiDNVQGkqJkEIkKIVmmXg0XV/nVWd4Z0CGPe6I6EOAo6GlCd3VTm3SQ62P5DuaEVO3/fm8bCpfvNRche+uOg+Vz1oKlzVCBPze5r1x5Tkmd0kC8xQT5oNKA3wLFMNfRi+nPwd/Bh/8YVqp7K+iPZ/Lr7tPl4XIgvswbEmffrm1tTF9b3dPUsJLAEIuVVeqqaoDqusJAcESFEq3Q43b5S6pAOYXx3y2gA7pjSjcH/Xg6oD+HEcH/mjuhQ7+c8f3F/nv51P/ONeSAxwfY5BsXlVXh71v/D+qZPVO/HwPhQzukXS5ZVkm37UD+H74m2ynHw9tTi62X5fdPTQ0v36CAOphfymTFnxBSIWE+P1Whg+6NnEeTrxeDEULadzLOZhWMAyistxcxckb9h3dtj/X27inWibXG5jhB/+T29qciftBCiVTqcUWSz3zkywGZ6bphV1/8rlw5k2V3jz7hOiiOxIX68ccVghnYMB9RU0GrLmNSrSJiJda5CSl4pOr3B5lhcDYGI9fEwfy+7NVXGGBeXO2ZMRg0294hYPoh7xAQR6u+Nh1bD8xf3t3tGfmklEU0w5fWes7oD8NCMni5/lrenFm9j3ffqZfKFa0mPiBCi1TEYDOahB5PHz+tjs6/RaHhz7mCSc0rsch4aw8tDS0SAt81v8abaHPVxwqpqaUFpJbklFeYqqSF+XuYP6ep6WE1F7RBhP5vloiHtWbT+uHnf1CMSH2YJYKyHd7pGBxHq70VeiaX2yiVDE+gWE8ip3BIuH55Yv2+sHhZM6sr0vu3oGuW6mizWAnw8qCjRm4u6iaYhPSJCiFanoKyK0krbD//qy8wDzOgXy00Tutj1GjRWdJBtnkhDSr4fserRSc4tJctY+yM8wJudj08jIdxx74310IyjvI8+cSF8eeNIu+MJ4f48eE5PooJ87IIL64Jfr18xiDFdI4kO8uWT60bUuRx+Q3hoNXSPCXLp1F1rph6zzELXrxckLCQQEUK0OpnGku0+nlpm9o/lhYv7m1dYbQrR1fJEGvIb9scbkszbS/ec5qBxJsuZEkM1Gg0J4ap3Y/bAOIfXjOwcYd7OtiqEdtOELmx+eCoDqiW8Whf5ql6TpTXpYux5sQ4CP9t4gsve3VCvxQBF/bg0EElJSeHKK68kIiICPz8/+vXrx5YtW1z5SCGEIMO4Am5CuD9vXDGYOUMTmvT53h62P1q/3HyyXr0iZZU6m+TQsko9d365A6hbOfKvbxrFu1cN4fyB7Wu8xtQJVH2WjSPWs36cWT3V3ZjK8h/OsCQ6P/zDHv45lsN7a441V7NaPZf9i8rNzWXMmDFMmjSJpUuXEhUVxeHDhwkLC3PVI4UQAoAMY9d6dDNVyayqtqjeb7vT8PX04OVLB9bp/aahAV8vLXNHdOCDdfY5HbWJDfGzqZbqyMp/TWTd4UwuHXbmHI/yqvrnuLRE3YyByME0+xlXppWKhfO5LBB57rnnSEhI4MMPPzQf69Spk6seJ4QQZqYPjeYKRBytV/L99pQ6ByKmQCoqyMc8XGAyo79zcjI6RQbQqY6l2RMjAkjKLjnzhS2cqXdoZ3I+pRU6m6nPFVJbxGVcNjTz888/M3ToUObMmUN0dDSDBg3ivffeq/U95eXlFBQU2HwJIUR9peSpNVHanaFXwFX+Na070UE+NgFJffJhTTkuUYE+5nwPgM0PT+W8AY7zPlzpmfP7MrFHFJ9fP6LJn92UOkUG0C7Ylwqdnm0nc20SnhtalE6cmcsCkWPHjvHWW2/RrVs3fv/9d2655RbuuOMOFi9eXON7Fi5cSEhIiPkrIaFpx3WFEK2DqWu9e0zTTPusrnNUIBsfmsKzF/UzH6u+Um9tMq16RIZ3CmdSjyhuGt+52RZkSwj356P5wxntYOZRa6LRaOgVq6Y/n8otIddqynJDZj6JunHZ0Ixer2fo0KH85z//AWDQoEHs2bOHt99+m2uuucbhex588EHuuece835BQYEEI0KIejEYDBxKNwUita8t40oajYZEqym21ou41WbNoUwe/WkvoAIRH08PPpw/3CVtFPZMKzDnllSSV2KpBWNKgBbO57IekdjYWHr37m1zrFevXpw8ebLG9/j4+BAcHGzzJYQQ9ZFeUE5uSSVajWUWRHPp1z7EXDq9Sl9z135yTgn3fLWDA2kF3PP1DvPxmlbUFa4TagxEnl16gKd/2W8+nl4oyaqu4rJAZMyYMRw8eNDm2KFDh+jQof5rOQghRF3tSFbTXrvHBNmsjNscPD20LL1zHIBNZVJQPTcmr688wvfbU5j+ylqbiqwXDI5vmoYKM+vS/xuOZZu380oqKatsG7OHmprLApG7776bf/75h//85z8cOXKEzz//nHfffZcFCxa46pFCCMH25DwABiW6R6kA03L25VV68wfZygPpDH36T77fdgqA7GL7Rd1+v2u8zUJ0ommE1lIwTiquuobLApFhw4bxww8/8MUXX9C3b1/+/e9/88orrzB37lxXPVIIIdh+Mg+AQYmhzdoOk0AfTzyNVV1P55cx/vm/uPajLWQXV/D+WlUfxLp3BCA2xNdmzRjRdKx7RC4fnsD1YzvRLliV7M9owPDMpuM5/HUgw2nta41cGm6fe+65nHvuua58hBBCmFXp9Ow+lQ/AoDpUDG0KGo2GuFA/TuaUsPJABidzLPU4gv3Uj+DqPSJ3TunWpG0UFp5ay+/nD83oRZCvFzuS80grKCO9WsJqpU7Pr7tOM6lHNCH+Xqw/koW3p5ZhxpWYy6t0XPLOBgB2PHaWOf9E2JJ+PyFEq3Eks4jSSh1BPp52hcCaU8fIAE7mlLDtRK7N8dIKNVSTYwxEPr9+BJ2iAs5YFVW4jvUqxKaZTrGhfnAi12YNmq+3JPP5xpPsSM7jrN4x/OeCfsx9fyMAR545B08PLYfTLdcXlFZJIFIDWfROCNFqZBsTPeNC/Zpsxda66BihpvFuNQYi3p7qR29BWRVllTpzL0m7EF8JQppZ3/Yh/O+ygfy4YIz52MjOqodj1UE1xLL1RC73f7uLHcZ8pOX70tl/2lKA09TDtS/VcqykUuqQ1EQCESFEq2EqOhXg07yzZarrEKFKqacZS893NpZWLyit5JlfLVNEI+qwoJ1wvdkD29ssBmhacXjbyTx2JOex8Xi23XuuXrTJvG2qOWKqZwNQUiEzbmoiQzNCiFajqEwFIoF1LB7WVKrXA+kYEcCBtELySyvZZ/WbdLCv/Eh2R3GhfgT7elJQVsX5b6xnep92tV7/+l+HSQz3JzW/1HysTAKRGsm/eiFEq2HqEalPOfWmYCobbtLR2CNSpTeYh2u+v3U0mvosSCOaVLeYIPPf1bK9abVe+/vedLtj0iNSMxmaEUK0Gu46NBMR6ENMsGXYxXohOxN3Sq4V9h6f1bvGc/ed3YNu0YFM6B5V4zUlUgytRhKICCFaDVMgEujjXkMzgE3OQd+4EEL8bNtYfV+4l/7xocwb3dHhuQWTurL8ngkM6VBzEb3SCkuyql5v4FhmET9sP2VXQ6Ytcq/+SyGEaARLjoj7/Wgb2TnC3GXfr30I14zqwJrDWQT6eHLFiMRmbp2oiy5RAbWeH9stkpeXHwLUTKmkbEvNGNPQzIaj2Vz70WZKjT0kfl4eTO8b66IWtwzu979VCCEaqNjcI+JeQzMAc0d0IC2/jNFdI9FqNdwzrQf3TOvR3M0S9TCjX6x5ZWST7jGWIbWB8aF0jgogq7Cc2QPb878Vh83nTIHIh+uPm4MQgDWHs+oUiKTklVJQWkmv2Na3GKwEIkKIFu9oZhFPLdnHrlN5gHsOzXh7anlwRq/mboZohIhAH96/eig7kvOY1DOat1Yd4YFzLH+nWq2GH24ZQ4VObzfF11S8LjrYdor2UasiaTVJLyhj+itrKK/Us/r+ia2u1owEIkKIFu/2z7fbTIN1x6EZ0TpM7R3D1N4xALx/zTC78yHGtWoGV1t00dQjUlqhByA8wJuc4gqbaq2OVOr0XPjm3xQahx03Hc9h9sD2jfsm3IwkqwohWjzrIAQgSAIR0cziQv1YcttYLhysgoZSY2XVovJKAK4d0xFQVVjLaplRcyK7hJQ8Sz0S06KOrYkEIkKIVic2xLe5myAE/eJD6G3M6TD1iJh6NhLC/fH1Uh/Bp3JL+HLTSU5bFUAzScoqttk3DT+2JvJrgxCiRdPp7ac/xga3rjF00XKZFrozLWxoLrrn60lciB/Hsoq5/9tdbDuZR7/2ISy5fSwAf+5LZ1dKPj9sPwVAt+hADmcUcTCtEL3e4FZrKTWWBCJCiBYtr6TC7liwn/xoE+4hKkglp2YVVZCcU0KycYHDQB8vYkN9OZZVzDbjcMvulHwAKqr0XP/xFpv7TO4ZzYmcEoordCTnlpjXL2oNZGhGCNGimbq6rUmpdOEuIgNVj8j+0wVMfXk1uSUqRyTQR/WIOLJkZ6rdsXP7x5mn7q7Yn+Gi1jYPCUSEEC1a9UBkcs/oZmqJEPZMPSIA5VV683aQryexofaByNdbknn61302xw4/cw794kOYMyQegG+3nnJRa5uH9F8KIVq0wjL1G2bX6EDumtqNQYk1l9kWoqmF+3uj0UD1Su4qR8Q+qfr+b3fZ7PdrH4KXh+ozGNctElB1c1pTnogEIkKIFq3AGIiE+Hlxbv+4Zm6NELY8PbREBHiTVWTJZRqcGEqwr5fDHhGTjhH+vDBngM1iiHGhfnhoNZRX6ckoLKddK5kdJkMzQogWraDMMgtBCHcUE2wJGG6d2IXvbx2DVqupdZq53gDDOoYTHuBtPubloaW9MXh57Kc9DhO1WyIJRIQQLVqhORBxv7LuQgDm4AFsa9x0igygX/sQRneJsMklAcgoLHN4rw4R/gD8sS+dqz7Y5ILWNj0JRIQQLVpBqRqaCZYeEeGm4sP8zduJVtNuvTy0LLl9LJ/fMJJJPaJs3nPrxK4O7zW0Q7h5e3dKPgs+38bfR7Kc3OKmJf9zhRAtmvSICHfn7Wn5nX9QYqjDax6e0ZvsogrGdI2kS3QgozpHOLxuRr92/PfPQ+b9X3ed5tddp0l6dqZT29yUJBARQrRopmRVKWIm3FXvuGDzdnANAXOIvxcfzLNfRK+6bjFBfDR/GA99v5vUfMvwTUlFFf7eLfP/QMtstRBCGOUaS2eH+3uf4Uohmses/rGUlFcxtKNzppZP7BHNxUMTeHXFYfOxA2mFdiv+thSSIyKEaNFyjDMHwgIkEBHuSaPRcNnwRLpGBzntnhO62+aUHMkoctq9m5oEIkKIFs3cIyKBiGhDhnQI45GZvcz7mYXlzdiaxpFARAjRomVLICLaqOvHdebWiV0AyChwPN23JZBARAjRYlXq9OZZM5IjItqiaGP9kcUbTpiXO2hpJBARQrRYucb8EK0Ggv1k+q5oe6KtqrY+v+ygw2t+2pHCLZ9urbFIWnOTQEQI0WLllVjWmfFoJQuACVEf1kOSf+xLszt/NLOIO7/cwdI9aSxal9SELas7CUSEEC1WobmGiPSGiLbJukaJTm/AUG2Z38Pphebtt1cf5dutp5qsbXUlgYgQosUyLXgX6CMlkUTbFOzrxa4npgGQVVRBQWmVzXlTr6HJvd/sZF9qQZO1ry4kEBFCtFhFsvKuEAT7epmD8exi22m8eaX2Caxfb0luknbVlQQiQogWq9DcIyJDM6JtCwtQ/wdMCdwmpv0rRiQyo187AA6kSY+IEEI4RVG5+m1PekREWxceoKbxZhdZApHHftrDO6uPARAb7Mt1YzsB8M+xHLeaQSOBiBCixSqUoRkhAAj3t+0RqdTp+XjDCfP5UH8vEsMDzPvDn1nBFe/94xal4SUQEUK0WIWSrCoEYNUjYqw0nJZv2+MR6u9NZKBt0b+/j2bz5JK9TdPAWkggIoRosSw9IpIjItq2cGOOSI5xaOa0XSDihUajYe6IRJvjOcW2OSXNQQIRIUSLZaojEihDM6KNiwxUPSKZRWrWzOn8Upvz3WPUyr/PXNCPmyd0MR+PsarM2lwkEBFCtEiFZZVsTsoBIEIWvBNtXEK4PwDJOSUApOTZBiLWAceVIy29IpU6fRO0rnYSiAghWqT1R7LJLakkJtiHST2im7s5QjSrRGMgcjJHBSCn8yxDM32sqq8CxIf588LF/QEZmhFCiAYzdT0P7RCOn7dHM7dGiOZl6hHJKipn+8lc8/+PKT2j+eS6EXbXm4ZqciUQEUIIWwfTCpn12jpW7E+v9TrTrAB3GOMWormFWK239OSSfaQYe0SuHNnBZmE8E9OxnBIJRIQQwsZDP+xmd0o+1y3eUut1aQXqB227EJ+maJYQbu/OKd0AyCwsN/eIxIX6Obw2ItAbjQbKKvVkFpY7vKapSCAihHAreTX8hvb+2mOc/8Z68w/Y09IjIoSNucYk1JS8UvNid7Ghjv9/+Ht70jUqEIAdyXlN0r6aSCAihHAr1l3M1kuaP/3rfnYk5zFq4Uq2JOWYh2baSSAiBABRgT74eFo+1vu2Dya4lho7gxPDANh2MtflbauNBCJCCLcSYFUlNcPYZVxRZTvF8Ikle0k1Tk+MNybpCdHWaTQa2odZhmKuGdWx1usHdwjF10tLWaXOxS2rnVQBEkK4laLyKvN2ck4JMcG+nMotsblmT4paPVSjkR4RIazFh/lzLLMYgPY15IeYzB7YngsHx+Pl0bx9Ek329GeffRaNRsNdd93VVI8UQrRAprFtsPSI7DvteNlygwE8tJomaZcQLUGCVY9IdHDtidy+Xh7NHoRAEwUimzdv5p133qF///5N8TghRAuWa5Wsmm6cGfPmX0ebqzlCtCimUu8AUUEto7fQ5YFIUVERc+fO5b333iMsLMzVjxNCtGB6vYH8UkuPSHpBOZU6fY09Ii9fMqCpmiZEi+DjZflYD24hazC5PBBZsGABM2fOZOrUqWe8try8nIKCApsvIUTbUVBWidVEGd5efZQHvtsNgJeHhh8XjCE6SP3Gd9/ZPbhwcHxzNFMItzUk0fILv0bTMoYtXRouffnll2zbto3NmzfX6fqFCxfy5JNPurJJQgg3lmuVH2Ly3bZTAEQH+TIwIZRV901kw9FsxnSNbOrmCeH2RnSO4M25g+kcFdDcTakzl/WIJCcnc+edd/LZZ5/h61u3caoHH3yQ/Px881dycrKrmieEcBMH0woZ8+xKvtx00lzMLMDbg5Gdw22uK6lQs2n8vT2Z0isGXy9ZX0YIR2b0i6Vnu+AzX+gmXBaIbN26lYyMDAYPHoynpyeenp6sXr2aV199FU9PT3Q6+3nLPj4+BAcH23wJIVq3p3/dR0peKQ98v9s8Y6ZTVABf3jiKAKvF7Bz1lgghWj6XBSJTpkxh9+7d7Nixw/w1dOhQ5s6dy44dO/DwkN9mhBBQbFU35K3VanZMqJ9akOvD+cPN564Z1aFpGyaEaBIuyxEJCgqib9++NscCAgKIiIiwOy6EaLush1g2Hc8BINRflaUe3imcI8+cw+pDmYzqEtEs7RNCuFbzVzIRQri9/NJKFq07TmGZ84dHSh2Ulw7ztyxb7umhZUqvGPy9W8ZURCFE/TTp/+xVq1Y15eOEEE7y+E97+HFHKn8dzOCT60Y47b6peaVsP5lnd9zUIyKEaP2kR0QIcUY/7kgFYO3hLKfed/HfSQ6Px4bUvkaGEHW243N4KhKOrmzulrgl6xWum4sEIkKIM4oIsAyVZBjLrjtDUrZanGt6n3ZEBlqeER8mgYhwkh9vAX0lfHdDc7fE7aQWpXLl0is5kHOgWdshgYgQolZF5VU2678cNa7s6QwpeaUAXDwknvAACUSEk+mscpoqipqvHW7IYDDw4NoH2ZW5i+c3P9+sPSMSiAghapRRUEbfx39Hb/Uz6niWEwORXBWItA/zs1kFNO4My5cLUScp2yzb+iqoqqj52lYiqzSL8348j8fWP1bjNQaDgfd2v8e2jG14a715avRTzVoOXgIRIUSNft6ZanfseJZzfrMsLq8yFylrH+bHvWf3wNtTy4D4EKmaKhqvsgwWTbPs66sgc3/ztaeJfLz3Y47nH+eHIz/w3Kbn7Ho6juYd5aqlV/Ha9tcAGNZuGPFBzbtmkwQiQggzg8HAxxuS+PuISkqt0tt3157ILnHKs7KKygHw8/Ig2NeLST2i2fzwVL66aZRT7i/auIO/2R87vbPp29GEkguT+XDvh+b9T/d/Sv+P+/PkhifRG/SUVpVy3e/XsTPT8ucwu+vs5miqDZmYL4Qw23oil8d+2gvA8YUzSDXmcFjLNAYQjWUq5x5mNVU3xE+m7QonyTxo3NDAqAWw4XU4tRkGX92szXKVcl051/5+LQBhPmFE+EVwJO8IAN8e+haANafWkF2WDcADwx9gcsJk2gW0a54GW5EeESGEmXX+x+GMIj7ecMK8/8A5PQHILHRSIFKqApEQq+JlQtTL3h/g/amQvs/+XKZxJshZT0GXSWp7/5JWmyey9PhS0orTAHh18qu8NfUtm/PfHvqWjJIMPDQevDLpFeb2mktsYGyz5oaYSCAihDBLy7dMzT37lTXm7U+vG8HMfrGACkSckWFvWmk3VHpBWr/SPMg+Cs6cmZGfAt/MU70cqxaqY+n7oFB9GJN1SL1G9YROE8EvHEpzIX2389rgRlacWAHAgoELGBg9kHYB7fjknE+4ru915mvGtR/HkvOXMCVxSnM10yEZmhFCUKXTozdAklX+h+kzY+GF/RjTNYLyKj0A5VV6CsqqGj2Mkm/sEZEqqq3cyY3w6UVQUQiXfAK9z2vc/fR60Grh0DLLscN/wMl/4MNzwKCHHjMh67A6F9UdPDwhshskb4S8k9B+SOPa4EYMBgMvb32ZVadWATApYZL53MDogQyMHsjY9mMxYGBozFC36AGpTgIRIdo4g8HArNfXU1xeZRcU+Ht7cNmwBDQaDb5eHgT5elJYVkVmYXmjAxFTjogEIq1U5kH4+Q5I/sdy7PAfjQtE9nwHP98Jcz6yrZRaVQZ/PqGCEICDv6pXTz8ISVTbIQnGQCS54c93QxvTNvLR3o8A6B/Vn+5h3e2uGdpuaBO3qn5kaEaINi6toIz9pws4mVPCrlP5NudKKnQ2v0ElhPkDsHT36UY/1xSIhPhJjkir9MZw2yAEbOt6NMS316qelc8ugmOrbc+d3GB/fWQ31XsCEGoMSNa+pHpVWriyqjKe/udpbvjDUjH2+fHPu2WPx5lIICJEG3fcQaXUeaM7AnDHlG42x68d2wmAH7anoHcwtbc+8kqNOSLSI9L6bH7f8fGMvfDB2XBwmePz9VFRCP4R0K6f5Vh4Z5WcahLb37IdmqBey/LgWMted2Zr+lau/+N6vjr4lfnYlzO/pH1g+2ZsVcNJICJEG3ck07ZAWUSAN4+d25uvbhzJrRO72JwblBgKwLGsYq75cFOjnpuco/JR2gX7Nuo+wg3t/dH+mF+4ek3+B/56pmH31VT7yOo8CbpPt+yf9xoMnAuexn9Tw2+0nOsy2bKduh10VQ1rQzNbcnQJ85bNs6kFsmDgAvpE9mnGVjWO5IgI0QYdzyrmlk+3ct7AOD7756TNuY6RAWi1GkZ0jrB7X3ur0utrD2eh1xvQauvfFazXG9h/uhCA3nHB9X6/cGO6KkjZanvMJxgGXqFqeYBlZkt9+YVBSbZlv9M46HMB+IbCgMsgIFIdv2aJWlsmdoDl2rCOMPkRWPm0+tr5FVz9E4S0nF6EksoSXtryknn/pv43sWDgghY5HGNNAhEh2qCvNidzIK2QA8sO2p0blBBa4/uql17PL60kLKD+OR7JuSUUlVfh7amlc2RAvd8v3Fj2YaisVn03MBrG3gP5ybDvJzDoGnZv70BLIOIfCb1ng28IjL7N9rqE4Y7fH9XTtp1/vwrnPNewtjiZwWDg79S/Kasq48ejP3LHoDtIK04jNiCW0qpS+kb25ZmNz5Bdlk37wPYsuWAJXtrWMawpgYgQbdCGY9k2++O6RbL2sCrrPigxrM73ySwqb1Agcsq42F2HcH88PWSEuFXJOaZeo3pZ1nbxDoSACJjxkgpESnLUyrge9fwgrTDmM132OfScWf+2JYyw3T+0DPpfCu0H1/9eTlJYUYiHxoO3d75tU559VfIqh9d7aDz495h/t5ogBCRHRIg2Z/2RLHYm59kc6x4TxFtzB3P75K6c07f2ks/3TrNMD1x9MNPhNUlZxXyw7jhllY5/8801FjNrSBAj3FzOcfUa3ctyzPRB7x8OGg/AAMVZkHHAMpNGV6kKklUvelaWDz8tgE8vhhIVLNOuPw0SGA1X/wwxfdV+bhK8NwmOrGjY/erhaN5RCisKbY7lleVx3o/nMfqL0TZBSG3uHHwnw9oNc0UTm430iAjRxny/LQWAi4fEk5RVzJYTuYztFsmkHtGcY6yeWpvbJnfjh+0pHM0s5pnf9jN3ZCL+3rY/Ss7672oqdQb0egM3jO9sdw/TqrtSVbUVylbrmxDeCbqfAylbYPz96pjWQ+VxFKVD4Wn47GI11HLZF5C+RyWxznwZhlmqgbLne9j+qe0zfBuRV9R5Aty0Bv4Tp+qPAKz/H3R1XbXRv1P+5uY/bybMN4zZXWfTLbQbs7rM4qejP5FVmmVz7X1D72NE7Aj2Zauy9VM6TOHPE39yJO8I7QPbc0XPK1zWzuYigYgQbczWEzkAzOwfy5gukRzOKKR3bP1+sJdVWuownMotRac3EBXkQ2SgDzq9gUqd+q12T2q+w/fnFRt7RFrCOjMGg1pCvr7DCG3RTwssQUN4Z5j8KFSVg5fVzKjAGBWIpG6z5Ht8ebnl/K/32AYixbYf1AB4BzWunVoP1Y4841pKja1vcgaL9i7CgIGcshw+3PMhGjSMihtlnn7bP6o/qUWpXN/veub2mgtAj/Ae5vdf2O1Cl7avucnQjBBtSH5JpbmM++DEMLw9tfSJC6l31v3js3qbt1cfzOSc/63lvNfWAbDzVJ75XEwNU3PNPSIBbvrhXlkKGftVEPLJ+fBSD1Wq3F2lboc/HoHyojNf60rWPRcdRoNGYxuEgOopATi2qub76CotQzRlebbngttbipQ1hqePZbuiEA793vh7OpBWnMam07ZT3Q0YmPT1JJILkwnyCuK9s97jr0v+MgchbY0EIkK0UqdyS/h+2ynzAnV7U/N5+lfV3Rvm79WoEu3T+rRjaq8YABZvSAIgNb+M9IIy7v3aUt+gsKzS7r3PLTvAovXHje1wwx4RgwG+vQ7eHAl//Ud9YJZkw6Jp8NWVsOxBeLYDPBkOvz/c3K1V3p0If7/W8PoczlCSY7sfbj8kp44ba9McXVXzvf4dCd9drxbLK81VxyY/Ald8rb6cQVttQOCn2xxf1wgnC07y7aFvMWBgYNRAvj/ve6YmTrW55paBt+Dv5e/0Z7ckMjQjRCt1ydsbSM0vo7CsirkjEpn56jrzuYhAn1reWTfxYaqmiGkGDMBHfydxLMtSqTW/tJKSiirKKvWE+nlxMqeEt1YdNZ8Pc7eqqllHVN5CrjHhcs3ztuf3L7Hd37oYpj2tfvPXVar6GKYKnq529C9VuKvDKMux6mXPm5IpSRXg2j9qvs4UoJgSNwddCSMXwFujbK/b8636CjQmT/tHQveznddeje1UdIozoKygcfknRgaDgeP5x5mzZA4VejUMOSRmCN3CunHXkLvQGXQMiRlC19CujI4b3ejntXQSiAjRSqXmq0S8V/48xOM/77U5F+GE2So929mP01sHGQDHMouZt2gze1LzOat3DD/tSLU571Y9Ino9vF7PVVkrCtXMi/BO8MPN6oNz3m/QcYxLmmhWmKZWtDXoVCVRk6rSmt9TE4NB9foc+AX6XFj/th/4VSV7djau+tphLCSOqPn6uIG2+9F9IKY33LYFDv4Gyx+zPV9kLH7mF1q/dp2J1ioQ8Q1VQ0CfXwLnPG9bGr4eiiuL+dfqf3Ew5yDlunJzEALQJVT1BHUI7sCrk19tRMNbHxmaEaIVyi4qN2+b8jGsRTqhR+SiIfHEhTjOAelkLFJ2IK2QTUk5lFTo7IKQjhH+jOka2eh2OE1BSs3nuk2z3Q+Ks0whfXWg6onY863aX/eyS5pn48TflqJgOz6zHC8vdHx9bbYsUnkwm9+H766DqoozvsXGl1eoVW1XP6v248+w0mu7fjDuXst+nwvUa2Q3GHU79D5fVWL1qdYz4Vf3+jZ1MuNFNTwz4f8gJF4dO7lB9YjVQVZpFh/s/oCiiiKSC5P5aM9HnPvDuaxPWU9WaZbdVN2uoV2d2/5WRHpEhGiFDmfUnrQYEdj4nggvDy29YoPNPS/W4sP8OJ5lv5ieSftQP365YxwBPm70IyjnaM3nOo1XS9ibhLRXH6Bpu9T+wd8s56rKcbmT/zg+XpwFep3tb/tV5SrIOLgULv3UtmfBYICNb1v2C09D0hroapvHUCN99ToxGhh+g8NLbUx+RCWxBsZAsNWUca0WLlmstte+DCuetJzzDcWpEobBAyfBOwBSd6jpw6Bm9NSgXFfOwo0L2Ze9j/05qljbK9tewc/Tj9JqvVEXdbuIy3tezsPrHqZKXyWBSC3c6KeAEMJZMgpr/zCs1DlnGXTryTYDE0LZYSyUFh9We/Lduf1jCXSnIAQs9S9M7j+upph6+qk1Sza/r4ZhQP0mPWqB+tBa/z/b/AhdPXsUGsI07dSOQSV3mtZc0evho5lwarPaf64DxA+Heb+CpzcUZ0LWIUCjejJObYYix0XqHKr+Z+YdYOldqI1GA+Pvq/0a/2prHTm7RwRUe8Eyk6cWuWW5XPzzxWSUZtidsw5Cov2j+e3C3/DxUL2OX89SybXa6gv2CTP5kxGiFcovtR+Oee3yQVw3Vv3AndkvzinPObe/uk9iuL85eRWgX/sQ7pzSjWcu6Mv903vYvS8qqPFDQ06XccB23z9c9Qx0HAM+gXDnThhgrHdhSpo0Dc+c3mF5X5H9B5XTVZ+hMu0ZS49BgXEILOsIvDvBEoSYnNoE2z9W26bAKri9JYAoM9Z+0etVr8SJv+2fn7IV3hgBb1Rb02XiAw35bhyrHogExjjv3tWNvMWyHWy/CF5ZVRm3/nmrOQgJ8QmxOX9+1/PZcdUOVl2yiu9mfWcOQkAFIBKE1M7NfiURQjhDgYNAJDzAm0dm9mLBpK6EO6m0+nkD4vDx1DIoMYyP/k4yHw/y9eSKEaoU/Ndbku3e12yBSPZRtfBa54n256yHO7qf4/j9578FI26ylAgPiFKv1t35+clqBVoPF/54LbFdK4jQBBU4leXBO+NU0PTlFZBlv6ghoIYiwBKIhHVUi8eBJRDZ9aVlaOTxPNWLUVECL/e0XGMy/CYYOh8iu+M01oGIb4h9PRJnCusIN6+Ht8fYL9gHvLb9NfZkq6Gba/tey12D72J96nqi/aMJ8wkj0i8SjUZDhJ/9itXizCQQEaIVMvWIhPh5mbfDA7zRaDROC0IAtFqNuSx85yjLKrpBvpYfLTP7xbJifzqxIX7mYCXKCcmy9XZ4uSUR8ZYNaqaGSVmBJUdgzkfQ9SzH99BoIG6QZT8w2v4afZUKRurQ3d9gpdV6RPzCbYODt8Zapsc6Yuo12f6JenUUiCStt1yfdQiieqien+pBCKg/E+u1ZZzBOhAJrH39I6cwDmeVledzz5+3AhATEMPW9K0cz1dDb/+b9D8mJ04GYGz7sa5vUxshgYgQrVC+cabM6C4RLN2jpj86MwBxpHuMZTpvkK+lPkiAjyfvXDWU41nFlkCkqXtEKkpsZ0Ok77ENRPJOAgb14WeaxVEXAdUDEY26z9EVEH69GtYITaxb3kRd6XWq0BdAl8nqe0scadtLUj0I8fS1rKsCaoZQ3kk4vkbtR/UAvbEXLfc4/PEo7LCqknpkhar/8f2Njtvk7CAEIKgJgg9rxqGttb4+rE1Za3c6xj+GiQkTm7ZNbYQMXAnRCpl6QYZ1DCc+zI/EcH+n1A6pTdfoQPO2p9a+ZHywVS+JMwqq1Uv2Ydv98gLbfVMPQVA9c2eqJ1Cahnx+/ResfAY+PAf+N7B+9zyT9D2Asfz55V/Bdb+rdXBqqmQ65i41VOMTbMlpKUiFPKshs2HXWXpEDv4Gf1erc3HwNzW7Jt/4nh4zbM+7ovfHurBYTUNMzuTlC55+/BFgSbS2XuV2XPw4yfVwEekREaIVKjCWVo8I9ObPeyZgMICnh2t/iAb6eDKjXzuOZhTTM9a+2FlEoA+XD09A6+ThoVrlnlAzO6rnVGxepKa6pu5Q504Z1wKp72/h1dc86TULjv2ltk1VWfX2+ToNdmw1fHye2vYOUjNfTOYshrUvQUURHPnTcnzUAjWEdO9hNWy0sL0KxEwzXhJGqtkjjqbHXvSBqi2StFZ9mYQkQM9zVRE0sAQxzjbsBtj8Hkx80DX3r2ZtcDjLAtXU58XTFzM4ZjBv73yb5SeWc3P/m5ukDW2RBCJCtCIGg4Fnlx7g76PqgzfYzwtfL48zvMt53pw7BIPBUOMiegsvbFjFygYxGOB/xud1qbbEe8Ze9VWddU2LuvLytyQ4DrtOJY1+M8++LfVcWBCAXd+oYmHnPKdqg/xoNbvDNEXXJLa/qsGx4U1LIKLRWnItTMme4Z0h5xgsucP2PtbBhKcf3HdEzRZK3wPr/mv7rLJ8S6KuK53zHPQ+DxJqqdTqRK+H+AEVXBkxmMExgwG4OSOVmw/uhAlVTdKGtkj6mYRoRX7emco7a46Z9109HONIfVfydZlTWyzbR1eoV9Nsl5p4B9Z+3pFuxsRW0we5aVE3a5UNLL3+/fWqR2DPd+pYcZblfE15J9ZDNP6RtsXNQJVxt2YKRPzDLce6TFJBCMDkR+2rnPabo6bqdhoPFy+q2/fTEFoP9QxP1w/lZZdmsw9VA+a6Sqvnrf+fqrey/hWXt6GtkkBEiFaivErHv3/Zb3PMOoG0zTENkVg760mV4Dn6Dug4zv58+3quNQMw82UYfqMqEgYQ1sH+mupL2ddF4WnLdv4p1QuhsypUN6KGoYL2gy3bEQ6ColELbGtlmIKv2EGW2UIDr7Cc13rYXn/Bu9B1ihrGumYJ9L2obt+PG9uesZ2JX08070ee2KACQWumBGHhdDI0I0QrkZRVQlaRbUXVphyWcTsp22z3NR7QYYylfLmuUi03DzDpETXzo0cN9UNqExAJM16w7PuGqCRW0/L1oD7EguuZCGvd/swDkG+1Fs7VP0PnCY7fZz2lOKqn/Xn/cLh7LzwZqvZNNVC0Wpj7jUpkDalW1Mu6l6vvRQ0bZnJDKUUpRPtF88aON8zHzioph7x0VVU3ymo2UEMWFBR1IoGIEK1EeoFlemb7UD/uPsuJxaVaotTttvvDbwAvS/VXPLxg4kNqRsbo251bMCu0g20gUt8eEYMBVj9n2d/1lWWKbEzfmoMQk9lvwM4vYdLDjs9rNGoGTdou6D3b9nj1IARsh5ZcWaitCa1LWcctf95ic2xW51ncdHwXkA7pe21rmVSvZiucpnX8ixJCmAOR8d2j+Pja4We4upUrL7QsH28y9Qn76yb+n2ueHxRrW/Z95xcQOxC8a1+Dx6wsz7KgnsZDrbT75xNq30EJcjuDrlRftbnmZ/Vh22HMme+XMELVF/Fww9L8DfTSlpds9i/qdhFPjH4CchYAa2Hvj7Y5Obk1re8jGksCESFasLJKHd9tO8W5/ePMC93FuOM6Lk0t76R69fKH7tPVcIx1b4irGaqtSrvtY/DwhpkvOb6+OlOND/8IVQdk+aOWc+36OaWJ+IVBxzpWBz37GbVq76CrnPPsZlRSWcKW9C2kFFmGumZ0mqGCELBUcT26wpLkDKoIXFWF7ZRp4RQSiAjRgj320x6+3nKKP/amkxiuftuOCXbhmhwthSkQiewGcz5s+ucHOZgGvPn9ugcipsJhoYm2gcfFi6DXeY1vX30FRKqptC1cfnk+l/96OcmFlmJuN/a/kRv7W1WMrbGWjEH9vThKAK7N7m9VD9ew6+vd3rZCZs0I0YJ9veUUAKsPZXIiR9WyiAmRQMTcjR7qYAZLU5j0kJqVYz0Lp67DGlUVcNxYPCwkQfVa9L9MTaPte5HKbWkjNqdtZu5vc9mWvu3MF59BSWUJ1yy9xhyEhPiEsHj6Ym4fdLvNark2yb4T/g+mPAYR3dS+aZFAk8oy2PwBFGWq/e2fWaZag5rp9N11qtKu6RphR3pEhGgl1hxSP+iGJIad4co2wFQSPKxj8zw/qB3M+8W2Eqp1nY7qdFUqpyRukBqG2fi2Oh6aqAKPC99xeZPdTbmunGt/vxaAZzc9y9ezvq7zeyt1legMOnw9LUH5N4e+4Wj+UQCeHvM0Z3c82+a8mafVEN74+9Sff/JmtUxA+l41dVmvh7+eVr1cZfmqwuyMF+EntVge3c9R+UAnN1ruVZYHgU1QBK4Fkh4RIVqonOIKu2ORgd70bNeGa4eYmIqZNaQuiDMljlTVTQEM+pqv2/4xvD8Fvr7aEoSAKhzWRu3K3GXe3p+zn5yyHHLKcigxVbGtQVpxGrN/ms3oL0bz/eHvAdUb8vVBFchM7zid87qc5zgIATUjqdN4GHuPpfepq7Ey7y5jMHTwV1VO37QS8dGV8JpV/ZZ81VNJ8j+WY9XXNxJm0iMiRAv1844Uu2Mz+8WidbDgXJtSXqR+cwWIH1b7ta7m6QN374OXe6oZGHqdKhCWugP+fBzG/Ut96B1erq43rd0CMP05iBvYHK1uVhklGezJ2sObO960OX7Fr1eQUpRCn4g+LDp7EQUVBbQLaMfRvKMs3LiQgooC+kb2ZWv6VvPwy+N/P06lrpKkgiROFp4k3DecR0Y+Unv1X08fVajNWt+LYOn9kL4bXhuqiuLV5tv5qjfOesXj8sIaL2/rJBARogU6llnEE0v2ATCuWyTn9o9l9aFM7jmrRzO3zA2cWK9mrYQmOq6J0dQCIgGNalNBCnx0LuQZc1j0OhWI6KvNsonsDiPb1iJrFboKdmTs4LaVt1FqVTzMU+NJlaHKPMtlb/ZeRnyu1p65e8jdbE3fysY0NQSyP0dVFg7yCiLYJ5iUohSe3vi0+V7PjH2GEJ8GLNDnH65K52cfUUM0OUdrvz59j3GVZCsSiNRIAhEh3JhOb+CfY9n0iw8h2Fd1E5dV6jj7lTXmawYnhnHpsEQuHZbYXM10L0eMUy6rL3TXXDy8IDRBzeTZ+4MlCAG1Ui6oaqYA0b1VMbNz/2t/n1akQlfBWzvfwkvrxfy+89mQuoF/rfoXVQbLwnLRftHc0P8GOoV04oY/biDEJ4S88jyb+/x3q+XP6bwu51GlryLMN4wrel6BRqPhhj9uMAcwI9qNYExcHWqm1GTqk/DVXLVdfZhtxouw7ye1EODe7x2/3xSI6HWqx64oQy1DMOWxJllLx51JICKEG/t4QxJPLtnHkA5hfHfLaAA+/ecElTq1DoanVsP8MR2bsYVuyFT7oaubBCIAMf1UIHLoD9vjVRVq5kX6brV/0fsQ06fp29eEMkoyuOyXy8gsVcnVb+18y+Z8qE8oi89ZTIegDngYF+xbf/l6/Dz9+L81/8cfJ9Sf4S0DbuG9Xe9RZajiur7XcdeQu+yeteyiZfxz+h9+PvIztwy8pXELMvY6F8bcqRbBq274Deprz/eWQMTTz7YsvCkQ2f0t/GA1XTi0A4yw2q9J6nYoybYsUdCKSLKqaFMWblzI/avvp0rfMpb0/mKTqoex9YSlXPgP2y25IZsfnkqovxRYMss7qbrPNcZVW92FqRbIiXXq1TStuCRLzb4wcVR/xA3szdrLtG+n8c2hbyiuLAYgKT+Js789m0V76rf67pcHvjQHIdYi/SLpGd6Tt6e+TeeQzuYgBCDIOwhPrScPjXiIe4feyz9X/MOtA2/lj4v/YMn5SxwGISYjY0fyn3H/ISEooV7tdMhREBBoVXek57lqyrV3oFo2wFqZMVn1cLVg9PjqMz93/xJ4dyJ8ejFkHalXk1sCl/aILFy4kO+//54DBw7g5+fH6NGjee655+jRQ8axRdNLK07j8wOfA3BJj0sY2m5oM7fozHR62xVAi8qr2Hda/UDb9PAUwgIkCLGRc1y9RnRVi8+5C+sVcQFi+6shmuIsNcUXAI2qduqGXtn2CqeLT/PUhqd4asNTPDHqCZ7Y8ASghke0aPk79W9md53NOZ3OQatx/DuuwWDgt+O/AfDcuOfYnrGdtJI07ht6H4nBZx5ajPCL4Jo+15j3o/yjiKIJp8R2Gq9Wbj65Qa24vPYlOO91y3lPb7jhL6gsUSXx11i91zRrpnqFX9MMr5qU5sI384w7BjVDJ7JrI78R9+LSHpHVq1ezYMEC/vnnH5YvX05lZSXTpk2juLjYlY8VwiHrokh/p/7djC2xVaXTs+CzbUx+aRU7k/NszlUPRHafysdggLgQX6KDpHCZHdNCc7XV7GgOCcMBq2GBdv3Vq0FnWVPm9q3NvqqtwWCwmR5rMBh4bftr/HP6H5vrTEGIyUtbX2LD6Q08sPYBBnw8gM/3f47BYODHIz/y7w3/JrMkk9NFp7n+j+tJKUrBS+vFxISJPDzyYV6b/FqdghC3Me3fcP2f0P8SWLAREqrNzAqMgrAO0HE8DLjCctw0NGP6NzrlcfValKbOlReq+iTVpe0B6x7cuvSgtDAu7RFZtmyZzf5HH31EdHQ0W7duZfx4N+o2FW3C1vStDreb24G0Qn7dfRqAR37cg5+XB1eMSOT8Qe3RGWwDkb2pqm5Bv3g3+m3fnZhWufUNbc5W2PMLg3Z9Ic2YCxLZ3bKYnUlYpyZtUrmunBMFJ+geplZprtRXcvvK2/k75W8SgxPpHd6bbRnbSC9JB2BA1AB2Zu60u0+wdzAFFbY1MhZuWoiPhw/PbX6O0qpSDuQewFPjybYM9cvAsHbD8Peq4wKALZVWCxe8pYblfn/Q8m/TFIiEdVTJrcWZai2iPx5VlVxNCzHu+R6OrYJti9V+YDsVtCSttUwDbyWaNFk1P1/9EA0Pd/zbSnl5OeXl5eb9ggIpACMaxmAw2CSmVegq+Cv5L/P+toxtpBenExMQ0xzNs5FbYilMtjtF/R/ZlJSjAhGdbSByKlclv3WKDGy6BrYkph/y7jjE0f0cSyDS7SzoNQv2/aj2pz6pPriaSG5ZLlcvvZqkgiTenvo2Y9qPYePpjaxPWQ/AiYITnCiwzO65a/BdXNv3WjQaDQdzDqLVaGkf2B4/Tz/SitNYm7KWWV1m8fKWl/ny4JeAba+JdXGyK3tdyRW9rHoKWjvT2jUFp9VQnKnYmX+4GkIszoTfH1LHVv1HBSJl+aoWibXBV8HGd9S5k/9Ax0bMAHIzTfYvX6/Xc9dddzFmzBj69u3r8JqFCxcSEhJi/kpIcEJykWhTCisKmb9sPuO/Gs/+bFVTIK04jfFfjTcnyGmMXeT3rr6XP5L+4M8TfzZLW0/llpCUVUxuSWWN11RZDc0YDAZS8lQg0j6sCVeSbUlK89SrX2hztsKx0bdB79kw+w3wDoA5H8FjufBEPoy9q8masT5lPVO/mUpSQRIAN/95M0UVReb/L0FeQUxOmMzVva8m2i+a87uebw5CAHqE96BbWDf8vfzRaDTEBsZySY9L8PP04+GRD7Plyi30DO/p8NmPjXqM/xv+f85JHG0pQo3DTqnb4ZX+lunbfmEQ7mABvfIi+Ok222PdzoaRt0Jv43IBa15wXXubQZP1iCxYsIA9e/awbt26Gq958MEHueeee8z7BQUFEoyIevnm0DdsSVfJX5f8cgmTEiYR5htmzvS/qvdVhPuG879t/2NH5g52rN4BwPKLl9MuoKZVN52voKySWa+to6xSz9WjHS/Mdv3iLWQUWnoIv9iUTLJxYbv2oS0zP+RAzgE0aOgR7qKEdXfuEfENgUs+tuxrNC7LCUkuTGbT6U3M7DzTXMpcp9dRqa/kgbUPUKG3XR7g9pW3k12WDajVaOf1nQfAvUPvrfeUVx8PH96f9j4/HfmJ1OJUpnecTnZpNvFB8a77e3dnIfHqVVeuvkz8wh2v5LvQQRG+ucbS8qNug+2fQvLGVjU80ySByG233cYvv/zCmjVriI+Pr/E6Hx8ffHzadmEX0TjLjtvmJVkPx9w5+E6u73c9Or2O/22zrQWw8uTKJu0u/npzsrkn5J3Vxxxe8+f+dJv9h37Ybd6OC21ZPSJ6g57CikKu+u0qynRlrJyzkgi/iBpnV9RbcRZseB1SjQnJ7pYj0oRKKku49vdrSStOY2XySl6d9Cqrkldx/5r7zQFIlF8Uyy5axkPrHuL3pN/NwTvAyLiR5u2G1t0I8Qnh6j5XN+r7aDUCosHDG3TV1obyj1BDMzUJbKeGY7qdbTkW2R08fdWsnMPLofvZzZ7g7AwuDUQMBgO33347P/zwA6tWraJTp6ZNxhJtR3JBMvN+n0dGSQYAr0x6hS8OfMHG05bVL0e0U2WhPbQe+Hn62ZSR/uLAF1za41Kb2gWutHRPWqPenxDmnol+KUUpFFUU2fzmeyDnAPOXzaeossh8bPI3kwn3DeezGZ8RH1TzLyd19vMdaiEyE3fsEanmaN5R/kj6g20Z2wjzCeOeoffQLqAd5bpy22XpUavJemg9KKsqqzHJ02AwUFBRwG/HfyOtWP37WnNqDQM/GWh37fRO0/H28ObFCS/SJ6IP3x3+jmj/aC7veXmNwyqigbRaVVvEVBY+KA7OeU6tzuuoR8Ta5Eeq3csDonupYZ4vLoWh18G5L7um3U3IpYHIggUL+Pzzz/npp58ICgoiLU395wgJCcHPr2X9RifcV05ZDhctucgmsJiSOIUpiVOY+s1Uc9Z/74je5vOvT36dpUlLubHfjcz5ZQ5JBUksS1rGzM4zXd7e/NJKtp/MtTt+w7hO+Hp58NHfSRSW1Vxwbf6YjgT4uF9R5LyyPOYsmUN5VTm/XPALH+79kPUp6zlZeNLh9TllOfx09CcWDFzQ+IdbByEAARGNv6eTFVYU8uaON0kMTuSyHpdx65+3klqcaj6/NGkp/SP7sytLJXZOjFfTWw0GA5f/ejlFlUWU68q5steVDIoexKTESXhpvczv/+LAFyzctPCM7Tin0zncMuAW8/78vvOZ33d+Le8QjRbZ3RKIjL3bkusR2UMtReAbDMmb1FpEJuE1/OKeOEoFIgBbPlCVhEfdpiq7tlAag6Ha/EBn3ryGLqMPP/yQefPmnfH9BQUFhISEkJ+fT3BwsJNbJ1qqSn0l962+j6zSLB4a8RDH84/zwNoHzOf/NeRf5jHuP0/8yd2r7mZW51n8Z9x/HN7vvV3v8er2Vwn0CmTJBUuI9It0afv3pRYw49W1RAZ607d9CKsOqiTaa0Z14MnZfdl/uoA5b2+gqNxxMPLrHWPpE+c+03e3Z2ynS2gXnt/0PD8d/QmAwdGDzVM1TWZ2nkl5VTkJQQl8efBLSqtKCfMJ45tZ3zR+9tIT1f487typpke6CYPBwK0rbmVdisqR6xnekwM5B874Pq1GS7hvOFmlWXbnpneczlkdziKzNJMrel7BeT+eZ05A7RLShbuG3MVHez8iwCuAjJIMxseP55YBt+Cpdb8gttX7/WE1dAgwfxl0GGV/TdI6+Hg2BMaoNYfOec5xj0nWEXhzJOirJbk/nudWwzT1+fx2aSDSWBKIiOoKKgoY84Vl2lq0XzR69GSVZjGvzzxuGXALfp5+NkHw3qy9dAzpSIBXgMN7phSlMP276QAkBCXw6wW/1mts3GAwYMBQ53yHdYezuPKDjXSPCeSX28fx4Pe7WbIzlR8XjKF3XLD5nltO5PL7njRmD2zP8v3pvLriMAAH/j0dXy/3SFL74fAPPPb3Y7Ve89+J/yXaP5r+Uf3Nx0oqS7j0l0tJKkgyrxNyuug0b+58kx5hPbiy95Xmayt1laxPXc+I2BH4efqRUpTCmlNr8PP0Y0riFIL0eni2WkGsR7PBw30+cFecXMFdf91ld3xYu2EsOnuR+c9xSMwQZneZzdLjS9lweoP5Ok+NJw+NfIj/bvkvhZW1r+I6v+98bh90u01viWhma1+CFU+p7dr+beaeULkjPmeYnn96Fyx/VNUZMVmwCaLcJxm4Pp/f7vM/VYgaVOoqOV18Gg0aXtzyos25jNIM8/a49uMcjp/3iax9EbH2gZYs9eTCZN7b/R439q/DIlSoD9Q5S+aYhx8SghL45JxPiPCreWggu1hlzocHeOPtqeXFOf15anYfm+EWjUbDsI7hDOuoau50iwnk10Nr6Bwe7TZByIbUDXZBSKBXICE+IeYVT58a/RRTO9ivz+Hv5c/NA27mgbUP8MeJP7ht0G08uv5R83LufSL70Cu8F4v2LDIvitYvsh+BXoE2H9DfHPqGj0c8gd2fiBsFIQaDwbxK7Jzuc1iXso6iiiJiAmKY12ceABd0u4Dx8eMJ9w1Ho9FwQbcL+ObQNzy1QX14vTr5VcbFj2NC/AR+PvqzXbK1SZhPGPcMucfhOdGMhsyHIyuh7wW1/9sMczyDzk5sfxh6rW0gcmqLWwUi9eE+/1uFcOBw7mGuWXYNhRW2vwWOiB3BM2Oe4f4197MtYxuJQYmNWjvm1UmvcsdfdwDw2vbXCPMNY073OWd83/ITy21yIJILk/n+8Pf0DO/JipMrmNtrLt3Cutm8xzQUExGgEhI1Go3DnA+dXkdKUQrZZdks3ruYjMAVZFTAg2s38MzYZ5w346QB1p5ay60rbgUgwCsAvUHPkJghvDnlTTQaDWnFaWxL38a0jtNqvMf4+PH4e/qTXJjMoE8G2Zx7ccuLpBal2gxJ7M7aXf0W7Mrcxb+3vcJjmIoiadRMAzdRWlXK9O+mk1OWg5fWi38N/RcPjXjI4fBI9eB1Tvc5DG83nJTCFEa3VysvR/tHc32/6ynXlfP2zrft7nH3kLtd842IxvEPh/m/nvm6+uh1Hoy/z1JTJM9xLlZLIEMzwm29sPkFPt73sd3xcN9wllywhGDvYPQGPb8n/U6P8B50DuncqOclFybzzD/PsD5VVZe8steVXNbzMrJLsxkQNQC9QU96iarGuvLkStoHtufyXy+v9Z7hvuG8OOFFyqrKeG31VjbtSQSDD1DFVSO78O/z7Yv75Zfn88GeD/hwz4c13veOQXcwOm40Xh5edAvtxs7MnWxJ38KE+Al2gY+zfXfoO5uqmX9d8hchPiF4ajzrPd3zs/2f8eymZ837nlpPu5WRE4MSCfUNtanO+Z+x/2F/zn4+2fcJAP9Nz2RqSHe45BNVydKzecsAlFaV8u6udzmWd4yVySsB6BbWje/P+94p96/SV/H+7vfx1HpyTe9rqNBXcCj3EAOjBjZuqXvR8piHfTTwSHrN//Z1VU3aUyg5IqLFyy/PZ+yXY837z417jhMFJ0gpSuGxUY/h7eGaVWcr9ZVcsuQSjuTZLrU9rcM0juQd4Xj+cTqGdOR4/nGb8+9Pex9vD2+e2vCU3XurM+i9QFPFuIDHefXC81iZvJLNaZu5steVfLr/U5YlLSO/PN/ufb4evvh5+pFbbj/jxsRT48kbU97gdPFpRsSOcM7UWCtZpVlM+nqSef+j6R8xJGZIo+655OgSHl73MMNjh/Pi+BeZ9t008wyod6a+Y+4N+PPEn2w8vZHLe11O55DOlFWVMfqL0VTqK7mooIgnQgfCVT80qi3O8u6ud3lt+2s2x16f/DoTEiY0U4tEq7X7W/juOrU9+RHVS2JNVwVfXqGSYa/8FjqMbpJmSSAiWrzfjv3G/639P/P+zqt3NtlQRH55PhO/mkiVoeYptNaGtxvOB2d/AKjfVI/nHyenLIdNaZtYtGeR3W/4JhGePYgO1rI/Z7/D8z4ePpRbVWLcfc1udHodL255kc8PfI7e4GClTishPiF8c+43PLHhCTw0Hrw2+TU+2fcJq06t4olRT9AxpKPN9Tq9Dr1Bj5dHzUmO85bNMy8YeNfgu7iu33W1tqGuSipLzEnGj65/lB+P/IiX1ostV26p9e99dfJqblt5G+0rq1gaMgrNnEVOaU9jGAwGzv7ubE4XnzYf++WCX+gQXMfxfyHq4+Q/sMhY9Mw3BB6oNkRzdCV8coFl/8rvoesUlzdLklVFi3Yo95BNEDKvz7wmzYcI8Qnhw+kfctXSq4j0iySvLK/GoGRah2m8MMGy7oOn1tM8NDIidgS3D7qdbenb2Jh8mFf/+RZP3ywiArzJrkglu+og2TmO2zCz80wWjl3Ir8d/5cG1D3JJ90sAVYzt/4b/H/83/P/IKcth2fFlvLvrXW4bdBtj24/lvB/PM/cm5Jfnc/GSi80ro1oXtnpz55s8N+45vjr4FVvTtzKryywW7VnEvux9PDn6Scp15by/+33GxI2hS2gXAr0C6RbWzRyEfDHzC/pGOl4zqiGsk4zvGnwX3lpvLu156Rn/3oe1G4afxoMUL9jl7cEAp7Wo7gwGAzsyd9A7ord5yMQ6CAE1vCSES8QNVhVas49ARQlUldsOzxTY/lvk0wvhxtUQN7BJm1kb6RERbmf+svnmktOLpy9mcMzgZmnHrsxdtA9sz7H8Y5RUljAhYQJ/nfyLYJ9gViWv4lDuIV6Y8ALB3rX/21y2J42bP1Uf4H3igvn1jnF8sPsD3tv9Hv6e/tw+6Hb+s/E/VOgruGPQHUztMJV2Ae3M1TX3Zu+lU3CnWitqmvICUopS2J21m1+P/sqqU6uc9mdhbXz8eN6Y8oZL7t0QD34xlV8q0pnh35G4LlOZmjj1jDOlnMmU53Jel/Pw0HjwwxE1PDSr8yyKK4sZHjucub3mNll7RBtkMMAz7aCqzL6GztqXYcWTttcPuBwusE92diYZmhEtTkllCQUVBTy/+XmWn1gOwKDoQXx49odNVnbdFTIKyxj+zArz/qQeUXw4fzighnEMGPDSepFTloPeoHdaMbXkgmRm/TgLnUHH0JihRPlHUamrJDE4keLKYr46+NUZ76HVaB0O/zRncOjIsk/P4T7dKfN+lF8Uv1zwS42BmzPp9DqHJdRnd5nNIyMfMS84J4TL/W8A5CbZFkw7+Q98f4OaUTNygSo3//drgEYN4fi67nNVhmZEi2IwGLhy6ZUczj1sPnb7oNvrXMvDnW1Nsk0sjQqydJlaT+EM9w136nMTghP44OwPWHNqDdf1u86u1+bavtey9PhShsQMYWD0QL448AUVugrO7ng2Ny6/ke5h3Xlh/AvoDXq+PvQ1K0+uZOPpjczrM8+tghCAkcVFYPV5n1mayarkVczoPMPlz3Y0pTguII5/j/m3zF4RTSsoTgUihcZlA/R6S+4IQEh76HmuMRAxwLMJEDsALnwforo3R4vNJBARzUan1/HHiT/46uBXNkHI7C6zW0UQArD/dIHN/thuUU327CExQ2qc0RIXGGeTaHp5T8s05J/P/9m87aHx4PKel3N5z8vJL88nxMd9SsubhBZmcGdJCf8LDzUf25K+pd6BSFlV2Rl7MPZk7WHJ0SVc1fsqjuQd4faVtwNwVoezOJhzkJOFJ7m+//UShIimFxyrXk05IYWptucDoiG0Wq7S6Z3w6z0w7xfXt68WEog4ydurjxLk68ncEZIZfyam1UUfXv8wvx6zLfIzIX4C9w+/v5la5lw7kvN4daVlKu/cEYnM6h/bjC1qHHcMQtBVQXEG1xv0zJq7jD2lady16i6WHl/K1MSp5qm/NXn6n6fZeHojSQVJeGo9eXvq24yIHWF3XVlVGfesuoe1KWsB+PzA5+ZzQV5BXNbjMuKD4tmesZ1zOp3j3O9RiLoIMv5sKTQGItlHq51vp9aiOXsh/P6g5biuomnaVwsJRJwgo6CMZ5eqBazOGxBHkG/bWeNh/RFV+XJM15pzGx7/+3FSilJ4dtyz/HD4B17d/qrD6+4beh9X97naJe10NoPBgN4AHtqaf/N96HtLt/2f90yga/QZ1o8Q9VecCQY9aDyIiexJmKEbA6IGsDNzJwtWLuDds95Fq9Hi6+FL74jeaDQaNqdtZkvaFt7d/a7N1OoqfRXPb36eAVEDuLDbhZRWlfLilhfpHtad9oHtzUGItZ7hPflsxmfmujZxgXFN9q0LYSMkQb3mJqnX7Gr1jNobh1RH3apKwX96odr3bv6fSxKIOEF2sSWiPJhWyNCOzh3vd1efbTzBwz/swctDw9ZHzyLYKgDT6XW8seMNSqtK+f6wqiZpXQjLpH1gex4d+ShFlUWc1eGsJmt7Y12/eAv7Txew7O7xNt+3icFg4EhGEQC/3D5WghBXKTauNRQQBVoPvPFg0dmLeGDtAyw/sZxrf7/WfGmv8F4MbTfUXI3VmqfGkypDFYdyD3Eo9xDfHPrGfG5f9j6ba0fFjiLUN5ShMUOZ2Xmmy4rrCVEv0T3Va4bx32v1HhFvq0U/u06BYdfD5vehKL1p2lcLCUScINcqENl/uqDVBSIGg4EKnR4fTw+bYx+tTwKgUmcgLb/M5gN5fep63tv9Xq33HRg1kBv738iY9mNqvc4drTigPgB/3J7C1aM62p0vKKuiQqdmnEgQ4kIlxkIs/pb/c94e3jw28jHWp6ynpKrEfHx/zn6b4nEXd7+YPhF9OK/LeXh7ePPpvk95bvNzDh+j1Wi5pvc1XN7zcmIDW+7wmmjFoo1T1nOOwaHfLT0iQXFwzc/21w+9VgUihaftzzUxCUScIKfEEogczSxuxpY4n8Fg4M4vd7B8Xzq/3zWexAg1JTI1v4zDxt/4ATIKyukaHYAGDQ+sfYDfjv9mPhfuG84FXS/ggz0f0CG4A29NeQtfT1+i/JsucdOZyqt05u3D6UV25zMKLUN1QT6ebrNabqtUapyV5BdmczjUN5SnxjzFA2seQGfQ8X/D/89mTZsvZ35pV2tkbq+5HMg5wE9Hf6J3RG8i/SJ5fNTjeGg80Gg0Tp/ZJIRTBUaBRquGKlf+GypVYUMueBsiHaw/ZcopKcmGsnxVlbWZSCDiBDlWPSL5pZXN2JLGsS6MZbJ0Txo/71TZ1xuPZ5sDkfSCMpvrvj76AQs2qAQ+69oTUxOn8sKEF/DUejIxYSKJwYkt+gd6fkklheWWv+OkbPvA8/rFW9h1Sq0VExnUvIuvtXqlxh6RaoEIwNkdz6ZPRB9yy3LpF9WPbqHdOJx3mK6hXR0WPNNoNDw99mmeHvu0q1sthGtc8I6qG5JmNa08oovja/3DLRVZj62C3rObpImONN864q2IdSBS0EIDkSVHlzD8s+GsOLnC5vgH6yyLu+WWVJCWX8bVizbxzZZTVlfp2ZD5M3qD3iYICfIK4t5h95rrZQyMHtiig5Cc4gpGPbuCyS+tNh9Ly7cNyFYdzDAHIQCRgZI/4FI19IiYxAfF0y+qH4C5wqmjWTFCtAq9zweNVQ9scHs1NFOTbtPU6+E/XNqsM5EeESewCUTKWl4gUlxZzEPrHgLgf9v+x5REtSCSTm9g16k883WZheU89cte1hzKtHm/1jeVMr2ql3Fd3+uI8o9iTvc5aNDUuoBaS3MovZCSCp3NscMZRXy4/jjzx3QCMA/JmEQGSo+IS5UYAxH/lhvgCuE0nt6qByTrkNofe7eqplqTnjNV1dUurl8ErzYSiDRSaYWO5fssWcfuOjSTV5ZHpb7SLi+juLKY+9dY6nZklGSg0+vw0HpwOr+USp1pBQADG3O/IKP0NFrfYRj0Phgq1L08Aw8CMDF+MncNuaspvp0mdTq/lCqdwSYp2dqTS/Yxf0wncosrOJBWCECgjyflVTom9Yhuyqa2PWfoERGizYnqYQlEYgfWfm3HseqrmUkg0kjrjmRx2qp7vqC0bkvHO8PRvKP8c/of5nSf43AKod6g51DuITJKMrh39b2UVpUyIX4CL054EV9PX5Lyk5j/+3yySrPM7ymuLObDvR9yfb/rOZltmXGg9T3Fcd2P4A0BnTaq+1dEMDH0XlYXrgego797lf52Br3ewLT/rqGwrIr7zu5R43UGg4Gdxt6jzlEBrPzXRIc5N8KJDAbINPZA+UmPiBAAhFvlhFSvpOqmJEekEXR6Aws+2wbAmK4RQN16RF5efogxz67k76NZZ7y2JgUVBdy4/Eae3fQsL215CYBP9n3CZb9cxppTa6jSV3HHyjuYs2QOC1YsMC8Nv/rUal7f/joA7+1+j6zSLDy1niw6exGPjnwUUMMz/Rb3476Nl+Db/jO8PPT4RK60a4PWO5s1JQ+i8SjBoPOlo9/wBn8/7qqgrJLCMhVcvvD7wRqvKyyvIrtI9Zi0D/UDkCDE1ZI3Qqr6/0fcoOZtixDuwifIsh3YMnpkpUekEf7Ym2auFTGlZwzrj2RTWqmjokqPt2fNMd4bfx1Bpzcwb9FmDj1zDhkFZSTnljCkQ91/q3tpy0tklKhaFp8f+Jzssmx+T/odgAUrFjh8j5fWi0p9JV8c+ILzu55vvv7Dsz9kYPRABkYN5PXtr5Nbrrq7C6oy8QrOJDJYTz77MRi0aHTB4Jlnd++ytAuoqgywO95SGQwGVh7IwNOjbrF6Sm4pb61WBYTCAyRBtUmYZgZ0mQKx/Zu3LUK4i96z1fTdqJ6qpHsLID0iDVRWqePRn/aa968YYekCqy1htUqnR6dXeRemIGbqy6u56K0NbD+ZW+P7rO3N3muuVjq2vRrfMwUV1Z3f9Xzm95nPbxf+xtYrtzIwaiAV+gou+PkCynXl9AzvyYCoAQB4eXjx0sSXGNN+DOG+4fhg7OVBfZ+VeUMpOnYblfkDKTl5LdrS3gAEaOKoKuxNTrF75sc0xIaj2Vy3eAvXLNpU4zUz+rWjU6QKvh7/aa+5kqoEIo2UfRQ+uRCOr6n9OlMp6+heLm+SEC1GZDe4bStc6/gzwR1Jj0gDPbv0AFlF5QDcO607vl4eBPl4UlheRUFpZY2zJd5cZSm7a+o1KTB2/f91IINBifZJd5X6SpYdX0ZicCJ/p/7NmzveBKBjcEfenPImK0+u5LnNz1GuK+f1ya+TVJDEX8l/4aHx4NGRj9rkjzw66lHm/nolZTo1VHN+50tshhCGtRvGxn2hbDx4nNLKKoj6DK+QXYCGyryhGHSBlKVeBsA5vafwwMxEXl1+ikWGE+SWNP/iSc6yPTnPZn9AQig7jceuHtWBUZ0jGNstkms/2szxrGI2JeWYrw33l0CkUZbeD0dXqK8nLFOhyTkO6XtVpr9Go/YBwjo2SzOFcFuRXZu7BfUigUgDffrPCfN2Qrgq8hXs56UCkbKaE1ZfXn7IvF1Rpaes0jIdtKxKb3Ntpb6SO1bewbqUdXb30aDhxv43otFomNJhCpMTJ1Olr8LLw4t+Uf2Y1WWWw+d3D+vOIwMX8dCG2zDofegVNNHmvMFg4MU/LG0k9QoWz/43Kw+d5p39eTbXnjcgjlDfUCICVa5LTg2zSlqiYF/b/xrh/l48dm5vvth0kpsndCHOmAcS5aBgWaj0iDROQarj42+OhCpjYvidOyHXFIh0app2CSFcQgKRBtDpDTZDb9FBvoAKRFLySmtMWC2tVoMC4FSuZWZKcbltALMmeY1dEBLjH8Ntg25jdpfZNj0ZGk3da3YUFAVRfPRfgJZTuZUMtkqsXn8k2+76AXEdOJnpBeSZj31xw0hGdVFDN6ahiJqmt7ZEReW2f1dhAd5cO7YT1461/dCLDfGzf7PBYH9M1J1/hGW7rAB8g0FXaQlCAH6+w7K4V02VI4UQLYLkiDRAck6Jub7GdWM7MbKzSjIN8VNx3UfrjzP9lTU2QQZgHsqxlpRlueazjSc5YSwZrtPrWLRnkflctF80G6/YyJ9z/uT8ruc3akbG0YwiTH/1J7JsS5Q//avtSqPL7hqHr5cH0cG+5mNjukaYgxCAMONQRE4rGprJK7X9XmoabjH1jFgLlaGZxtFZBfKmegg5x2yvOW6pbktoB9e3SQjhMtIj0gCH0lXRqt6xwTx6bm/zcdPqs38dVJVHxz73F5/fMAJvDy2dowLJNvYYxIX4Ul6lJ7u4wm6tkm93HCCy3R4+2/8Zp4vVqojj2o/jvmH34e/l3+i25xZX8PWWZPN+crVgKSVX5Y5cNiyB8we1p2e7YABigi1DED1igm3e0xp7RKxL9Ws0MKGH4wX62of62uxfMSKRc/q2c2nbWj3rZclzkyB+KGxdXPP1tVWOFEK4PQlEGsC06mz3GNvl3YP97IdGrnhPFf+a0iuaS4aomv+RQT7klZZR6L+U5zasBsYCGnyil/LRqbVgXMbFU+vJC+NfYGqHqU5r+85TeTZlytMKLL00ZZU6Co3DQw+e04sQf8v30z7UDz8vD0ordVw9yvY30PAAdV1uSSUZhWVoNRo8tRqOZhYzpEPLrHiZV6ICkbN6x3DrxC4Ok4jB0hsEsOOxs6Q3xBmKMizbuUmg18OWRY6vHXVbkzRJCOE6EojUQ2FZJXd/tYM/96sflL1ibXsGQhwEIgAajyI2lP2PjVvS8Qi4Bj+vUeR4rcLHV3Uv+8YstXuPrjyKOwc+5NQgBOBYpuqBCfTxpKi8inSrqrCmHhsvDw3Bfrb/NIJ8vfj2llH4eGrpGGlbL8T0YZxfWsnUl1ZTWqkzD139cvtY+rZvvuWlG8rU63Vu/9gagxCAHu0sxYNq+vsX9VBeBJVWvYQr/w1Zh8FYkI8Rt8DGt9R2j5kw+ZGmb6MQwqkkEKnGUVnu4spiDuYcZP3eAHMQAjCtj20XfIC3B7Z0aDyL8I5ciYevGmbxT/wQb69wKsp3OXx+Wdq5VOaOAjx4+lg5Vw7S4etV/b4NdzRT9eaM6hLB8n3ppBVYBSLGHJaIAB+HOSh94hwHFCF+Xmg0Kkez+oyhpOxicyCSUVBGkK8XfnZ/Tu7l7yNZHDUGbGcKLkL9vVlz3yR8vbVSSbUxirNUkmpxhv25XV+q1+D2cM6zMOF+OLoSuk8HLwfJwkKIFkUGV628svUVxn01jn3ZKmHzcO5hbvnzFiZ8NYFrll3D2qzPzNcmhvubi1mZ5JZUgkcxUIVP9K8E9XqYwG4L8Q7baHPdjsqXKNWqqYflmZZVD6cnziawfBJg+aBeuue0U7/H48bk1FGdLSXpTVOITcm0kUH1G17w9NDW+IFdZAxMjmQUMerZlSz4fFuD2t2Uftlt+TMf2vHM1W4TI/zNM6dEAxxbDS90gT+fsAzL+EeoJc2tFaQYz4VDv4vBx3ZoVAjRMrXZQKRKp6e8ypIrUa4r54M9H5Bfns+9q+/FYDBw+8rbWZeyjnKd+oA+XLYUjbG8eai//QfvkB45BHV7mqBej+AdsdbmnK60PSUn59scM+i9qMiawh19n+SKnlfw9PhHWXXfJD6/YYR52OfXXc4NREwL9PWOC8bf2DNhOpZVqIZmIgLqv3R9QQ1Tlk3rtPy0IwWdXpVNP5JRaJ4d5I7+PqLqonxwzVACfaTT0OX+fEK9rn8FDi9X2xFd4ZLFahlzk4QRTd0yIUQTaLM/Ze/6agerD2byyx2jKNCfYPnJ5eZzyYXJLEtaRkqR+g0s2i+aCL8I9ufsJ7Dbs+grwkF/h9099xWsA41tDYmrel3Dyt16Dp5qj6EqlE662zjuoRadU0MwWi7qeS7hARcC4OMHo7tE8sLF/Tn3tXX8cyyH5JwS4sP8Gt31bzAYOJ2vxtpjQ3yJD/PjUHoRyTkldIoMMM/gcTQl9UwGJISy/WSe3fFCY7l7D62l7VNfVqW72wX7ctOEzswf4z4FqcoqdSQZVx0emBDavI1pKzytAt+1L6pX02Jdkx6GrlNh748w6tYmb5oQwvXaZCCSVVTOL8aehsfXvMbWgi/trrl/zf0AJAYl8uH0D8kvz+fCn1WwoPXO4ThP8MNhLR1DOjIoWq38uT1jOwAzO8/E39Ofq3pfRaeQTpSm7+XA4SQAugSM4LuLr+V4ThbTXtpGVJAPYQ56V7rHBOGp1VBUXsW45//inrO6c8eUbnX+HnV6A1qN7QqwahhGVW+NCfYlMdyfQ+lFnMxRH7wH01SCZk+rBMy6untqd37ZlcrdZ3Vn1ELLSr2vrjzCrAFxnM4rs3tPWkEZTy7Z51aBSJqxd8jPy0PWjKlOrwOtC/J79A4qEQfGqFcPL+g4Vn0JIVqlNhmIrNiv6hRovTPsgpBLe1zKD4eWUGEowYsgPpz+IdH+0UT7RxNRfAPZAe+Zr33s78cAGBg1kAdHPMihXFV86Z4h9xDtb1l+uVc7y+yaqCAfvDy86B4Vy7r/m4S3p+MkR29PLZ2jAjiUrpJLX15+qM6BiMFg4PJ3/yE1v5Tld08wJ4eahmDCA7zx9fIwl6Y/kFZgfG14IDK+exTju6taG59eN4KHf9zNCWPPwln/rX3xsrWHM9FqNIzpGlnv5zpbap7qMYoL9ZXkU2t7f4Bv5sNF76v8DGc5uhJObbY/HiS1WIRoK9pkjkiwnwbf2G8J6PIyAFqNB73DhhDsGcmCgQvIPXITFTmjyDt2LVF+6sP1VG4JJ5I7U5J8De9M/Ir+UZZlx3dk7uDSXy5Fb9ATHxhvE4QANtNXO0ZYElzjw2pPcuwWU/+AAOBkTgmbknI4lVvK3lTLomGm3/bbGaukJhoDkU//OcmRjCJSjB/CnaICaIyx3SK5flznOl9/7Uebmfv+RpbsrGGNkSaUYg5EZDaGjZ/vBAzw3XXOve+e7yzbGqsfRxF17/0TQrRsbTIQSWcFXqFbzPvFKeez8e+LSdl9L4dPG/DUxVKePht9eSwHjfUk/tyXjsGgYWjUGEZ36M0H0z7guXHPcX7X823uPThmsN3zescF89KcATx/cX8uHZZQ53bGh9l+GJZV6th+MhfDGdYyWbYnzby96mAm7645SnF5lTkHxHTfyT0tAdMrf1oWunNGPQz7qcw1M9UcWXc4q9HPbaxTxsqycY7WkBFKforz7pV1WL1e8A6EW60ZE9XDec8QQri1Njk0MzB6IAOiBrArLYmiE/PRl1u6gf/Ym0aFzrIK7o6TefRsF2ye9jowMRQAX09fZnSewYzOM7im9zVc8PMFAMzpPsfhMy8aEl/vdlb/MOz56DIAnr2wH5cNT7Q599hPeziYVsi/z+/LS1Yr/L7+1xEAlu9LJ91YRbVzlJr22CEigMXXDueaRZvMOTPenlp8PBufB1DsYIG/XrHB7D+thoHC/L0oq9RTarX6cE2LBTYlUw9StxiZGgqo4jAbXodyS88ap3dCSHvn3DvzoNqO6QN+VoXjwuveoyaEaNnaZI/IwOiBfDrjU76Z8ZtNEAKw61S+zf6O5DwAThgTOjuE2w9bdA3rylOjn+K5cc8xMHqg09oZG+J42OaB73dz91c72GlsW0peKR9vOMHG4zlc/NbfVFTp7d6zOSnXnJTa2WroZXy3SMZ1s+RmBPs6JzbtFm3/QT44MZQrRqgA6s4p3WyCEFDr3pgCleZQXqVjrbFXZnALLU3vdIeXwx/Vqpem7XbOvUtzoSxPbYd3gWlPQ7v+MPga25k0QohWrU0GIiY92gXbfdhvSsoBMBcrO2bsCTlpTLzsGOF44bkLul3AjM4znNq+djUEIgA/bE/hnq93AKoXx8RU2fTKkYmO3gZAlyhLkKDRaJjQ3bKgW5Cvc8qUj+gUzmuXD+LDecPMx3rHBfP4rN58f+torhrVkQHxtpVa96YWcM7/1prLqze1N1YeobxKT4C3B72rle9vUw4vh/S9ajtjr/359D3OeU6hsUaOfwR4+0PiCLh5LZz3qnPuL4RoEdp0IALUOEVzkHEIZtPxHH7akWLOHTDNNGkKsQ7yFG6aYOmyNpUhX38k2+66i4c4zkVpH+rHYOP3ZtLVqvfCWQW8NBoNswbEMapLhPlYx4gAfDw9GJwYhodWw1Oz+zp879YTuU5pQ33tSlG9YXNHdnBqWf0W5fRO+OxieGu0GjrJOW5/Td5J5zzLVEU1ILr264QQrVqbD0QCvB1/8FpPub3zyx1U6PRoNbX3UjhbVJAP7141hOcu6sfgxFBevXwQd0/tziTjkvQdIvyp1OnZdNw+EOkRE8RVIzvYHX/h4v5201KtAxF/J68D4+vlwcx+sQxODGVYtXLpAxJC2f7oWdw03jYfILekwqltqKtc44q7Q9vysEzKVsv2+ldg22LL/sUfqtf8U855VnGmeg2Mqv06IUSr1iaTVa2dyi1xeLyHg1oaqgZI08ZupoX1Lh1mGWp5bFYf/jq4iqzCcpbvS6egrMq8mi5AkI8nft4ePDW7D5cMTSCruJz5H6paDYkOhpask2JdkTD6xlz7mUQmYQHeXD+uM++sOWY+5qj4WVPIMwZAYW25kFnWEcu2qfS6fwTcvB68jEF4SRZUlKjhlMaQHhEhBNIjwp1Tu6HRQEK47TCIoyGYdm4ypTMqSCXyFVfo+GmHmkp5pVXvh6ntGo2GfvEhTOwexdWjOjB3RCLtHdTH0FqVX08vaPogINjPNh7+5J8T6PS1T1F2ttdWHDYXYHNU6bbV++Vu+O4GSHeQiHrPfgiOBd9Q8DYG6FmH7K+rL9NKuwHSIyJEW9bmA5FLhyWy98mzue/snjbHHeWOxAa7xwqrAd4e+BlzGP45ppJrByaEmiuiWueRgApInprdl2cu6HfGaqGmEvBNydF04W+3JjdpG6ynPIf5t7EekcI02LIIdn8Nx6tVwfXyt8xg0WggdoDaXv+/xj+3SIZmhBASiADg7+1JhFXg4anVEOzryb/O6m5zXVPmh9RGo9EQE6w+HExDKV2iAlg0bxiLrx3O7IH1r/HwxhWD8fPy4JXLBjqzqXX230sH2JSwX3Uws1naAc4p6NainN5V7YAGZv1PvU77t+2p8feq173fw+rnG/dcU+9LmPusNSSEaHptPkfEJDLQUrcgLMAbjUbD7VO6cdOELnyzNZkV+zPqVRXV1XrHBZtXidVqVO6Hj6dHg0uTz+wfy/S+7WxWyW1KFwxSBd8mdI/iorf+ZuPxHAwGQ7Os9+LZxHlAzaKqHLZ+BJ0nqVdr4Z1hyDzoNwe8q9XNaW+V7/PXMzDuXtA24M+rLB/SjNOAO4yu//uFEK2Gy3/ivvHGG3Ts2BFfX19GjBjBpk2bXP3IBokMtPSImHobQFUanTuiA4vmDaOXG9WWGBAfat5OCPd3SjXU5gpCrJnqd+QUV/DMr/ub5Jn6Js5HcQtbPoSl98Mbw+Dgr7bnEkao1+pBCICvbe0X88yX+krfCxggJFEWuBOijXNpIPLVV19xzz338Pjjj7Nt2zYGDBjA2WefTUZGhisf2yDWeQH+Xu7fUTQwIdS8bV2grKXz8/bA10v9s3x/3fEmSZ4tLLcsQ7/2/kkuf16zMxhg9ze2x2Ksarp0P7v293tY5dA4Sm4FKEhVPS26KsfnC41F+ELdp5dRCNE8XBqIvPzyy9xwww3Mnz+f3r178/bbb+Pv78+iRYtc+dgGsZ454uPl/l3z1iv6BjipCJm7sE6YzSoqd/nzCox5Nj6e2iYtWNdsdn4JKVtsj/WYAVd8A1OfgF7n1f7+G/6ybH96kRrmqe61obDkTtj2keN7FMmMGSGE4rJP3IqKCrZu3crUqVMtD9NqmTp1Khs2bHD4nvLycgoKCmy+mkNLSFa0Dj5qWpOmpQqy+t5yil1f3KygTAUiwS3g790pdn2lXsfdC2PvUcMj/S6G7tNg7N1nzvlo11etCWOSfUTNojnyp9ovK4BKVfWXU1vt3w9QlK5eA2Ma/n0IIVoFlwUiWVlZ6HQ6YmJsf9DExMSQlpbm8D0LFy4kJCTE/JWQ0LTdtndM7qpmy0xrGUuQL5o3lFkD4rh9ctfmbopTvX/NUPN2dlETBCKlavjAWQv+ubWKEjixXm0PuAymPg5374aoev6bP+spy/a6V2D5Y8bekQr47V7LOV8HeVVrXoR1L6vtQClmJkRb51ZjEA8++CD5+fnmr+Tkpq0lcc+0Hux4bJp5wTt3N7lnDK9dPshpC9W5ixGdIzhvQBwAn288SUpeqUuf12p7RPJOws6vVE4IqBoh384HXQX4hEBEIwLYLpNgwOVqe/fXluOb3rH0uIAagslNghN/Q2kebH4fVlpNCZYeESHaPJf9ChgZGYmHhwfp6ek2x9PT02nXznGWvI+PDz4+zbv8t9YNZo4IS0G5TUk5THpxFYeePsdlz8osVDkOEa2ttPvi8yD3OJTlQUwfWDzLci6kvSpQ1hgxDhYt/OMR2/2936uvmoR3rvmcEKJNcFmPiLe3N0OGDGHFihXmY3q9nhUrVjBq1ChXPVa0EtbTqSuqXFvtNdXY49LQGixuSa9TQQioabofzbQ9HxzX+GcMmVdzj0Z0nzO/f+w9UkNECOHaoZl77rmH9957j8WLF7N//35uueUWiouLmT9/visfK1qBiMCm6xlrlYFIxhlqsDgjEPEJtE1atRbZzfFxkxG3qPyUZihYJ4RwLy7Nzrv00kvJzMzkscceIy0tjYEDB7Js2TK7BFYhqmvKYZLUfFWrpFUFItlHaj/vbb+6dIO0HwJHlqvtHjPg4G9qO3EUFGeBX6h6Tf7H9n2hiQghBDRBiffbbruN2267zdWPEa1MRKBtIFJWqcPXy756rE5v4K1VRxjZOYKhHcMb9CxTj0j70FY0DdpRxdM5H0HmITVjpfds5zxn7N1QUQS9ZsGhZZZAJDAa5hsrth76Az6fY/u+kHjnPF8I0eK1gfmKoiWKCLAdmskvreTB73ez/3QB390y2lxH5dutybz4h1o5N+nZmXb3OROd3kBaa+wRKcm2PxYUC30uUAvXaRu/JAAAXr5w9jNqu9BqWr6/VVDYdSrMeBH8wuC769QxqagqhDByq+m7QphU7xFJzSvlh+0pHEgrZO3hLPPxrSdyG/WczMJyqvQGPLQaooNaYY9I/HDLMdOaLs4KQqprP8Sy7WU1BV6rheE3QKfxlmPB9V8hWgjROkkgItxSYLWy9f8cyzFv/7wzxbxdUqEzb5vqgdSHqUZJu2Bft1j0z2mKjcFazxnq1S8MgpyQoFqbkHhVpdU70HGBtIAoGHglDL5aSrsLIcxkaEa4JU212RTvrT1m3v5tdxrbT+YyID6UQ+mF5uOn88oIbmdblMxgMLA3tYAuUYH4edv3BJzON82YaUW9ISU5llkzIQlwzwE1O8XTxQnAGg0s+AcqyxxXVNVo4Pw3XNsGIUSLIz0iwm29feVg83b1NWcOphWydE8ah9KLzMdS8+0rsP66+zTnvraOGz62XeRtc1IOj/20hyMZ6v2xIa0kPyT/FLwxArIOqv2ASAiOtQzLuJp3AARENM2zhBCtgvSICLc1vW8sD8/oxTO/2dfEyCoqZ+epfJtj6cakU2uL/04CYN2RLJvjc962XXgxKqh5K/o6hV4P38yH4gzLsciWsW6SEKLtkh4R4dam922HKXVjeMdwbp3YBYCUvDK2nlB5I0HGxepyS+xzREzLrJxJeGso7559BE5tsj0WHNs8bRFCiDqSHhHh1hLC/Vn5r4kE+noSGejDxxuSAPhi00nzNZN6RPPzzlTySu1X6q1jHEKofytY8C43ybId2gEmPtBsTRFCiLqSQES4vY5WqyGHVFshN8TPi27RgQDkFTvqETHYbFdPgjUJ92/BPSLlhWqmSt4Jtd9jJlz+efO2SQgh6kiGZkSL0ifOdjZG56gAQo3DKidzSnh1xWFOZpeYz1v3iBSUVdV439CWGojs/REWxsPWDy09ImEdmrNFQghRLxKIiBala3QQ143tZN7vEhVImHFYZcOxbF5efogrP9hoPl9SbqkzkldiGbrx8rDtGWmxOSLfXKNef7kbMg+o7VAJRIQQLYcEIqLFuXSYpTz4uG6RhFXrzTiZY+kRybEKPvKMyawVVXoqdbbZI2EtMUekeibukT/Va+eJTd4UIYRoKMkRES1O58gAxnaNxN/bg1n949h3usDhdVU6PdlF5eb9XGNQUlJhGaKJCvLBx1PbMntEco7ZH4vsAdE9m74tQgjRQBKIiBbH00PLp9ePMO93skpmtZZVVIHeqtPA1CNSbCwL7+2pZe39k9AbDHh6tLDOwZIceM1Y8C2sE1QUq4XuznmuedslhBD11MJ++gphL8DHk+9vHW13/HS1SqvmHpFy1SMS4O2Br5cH/t5uHI/r9XByowo0TMqL4CWrQmU9Z8KdO9RXl0lN3UIhhGgUCUREqzAgPpSYYEt1VJ3eQFJ2sc01poJnRcZAxK0DEJNti2HRNPjqKsuxv18DnVXNlJG3qNLqoYlN3z4hhGgkCUREq+Ch1bD6PktvwMvLD3L3VzttrjHNmskrVQFJgI/9InhuZ+Pb6vXoCvVqMMDWj9R2cHt46LRa9VYIIVooCUREq+Hr5UGQj+rl+Gpzsvl4ZKBKRDXliHy79RQAcaFutNDdyY0qwKistl6Ovlrtk8wDUJQGnn5wx3bw9m+yJgohhCu0gL5pIeouLMCbwvIqsopU78fT5/fF21PL/d/u4uedqdw2uStbk3IBmDe6YzO21Er+KVh0NmCA1O3g4QN9LoAOo+wDkaR16jVxJHi2goX6hBBtnvSIiFbFOk/E20PLpcMSiLCamjv3/Y3mHJGOEY5n2zS57KOYa8Bu/Qg2vQNfXKb2dQ56RADiBjZR44QQwrWkR0S0KtHBvubtjpH+eHlo6RMXYj6WWWipKxLo28z//A//CRVF9r0eAGV5kL4X9Fbr51SUqBV2ASK6NUkThRDC1SQQEa1KTJAlEEkIU/kT7UJ82f3ENPo98YfNtUHNGYhk7IfP54BBD92mOb7mrWpTkv8Ta9mO6Oq6tgkhRBOSoRnRqlgPzSSEWxI5g3y9CLUq4+7tocXHsxlnzax7RQUhAIf/qPVShyKlR0QI0TpIICJalRiroZn4MNtZMaF+lkCkWXtDKkpg/8/2x3vNgph+Z35/vzngH+78dgkhRDOQQES0KuO7R6E1Lqzbs12wzbkQq0CkWfNDco5CZQn4hUPCSMvxyO5wyzq4cRUEtqv5/RMfdHkThRCiqUiOiGhVwgO82fDgFHadymdM1wibcyFWq/Q2a49IYZp6DWkPV30P6/4LB36D4Teq43GD4N6DkHsCkjdC9+kqeHl3ojof3rlZmi2EEK4ggYhodWKCfTmrt6/dcZseEZ9mTlQFCIpVpdknP6K+qgvroL5ABSfXrwDfUNBomqypQgjhahKIiDbDNkfEq5YrXWTlM7DnO9W7ARBUy/CLI/FDnd8mIYRoZhKIiDbDukckuKkDEV0VrHne9php1owQQrRhkqwq2gzr6bvW202iMNX+WLv+TdsGIYRwQ9IjItqMYKseEethGpfLS4Y/jDkgoYlw9U9wcBkMvbbp2iCEEG5KAhHRZlgHHyFN1SNSVgBvDFfTdQHCOqpZL6NubZrnCyGEm5OhGdFmhFpN3w1pqh6RY6ssQQhASGLTPFcIIVoICUREm2EdfFgHJS6j18Ffz9ge6zXL9c8VQogWRIZmRJthnaAa3BQFzX5/GDIPqO2EkTBoLvSY7vrnCiFECyKBiGgzrHtEvDwa2Rm4+1vIOQ7j7625wNiOz9Vrjxlw8YfgZV9kTQgh2joJRJxBr4cvrwC/ULjg7eZuTdNJ3wffzgcvP7juT/Bw739Ovl4edIzwJ7u4gq7RgY272XfXqdfono6HW8oKoDxfbV/4ngQhQghRA/f+5Ggp8k7AoaVqe/IjEBLfvO1pKkvvtww95J2AiC7N2546+OPuCej0Bny9PBp+k6oKy/bRv+wDkaR1KjAF8A0Bn0YGPUII0YpJsqozFGVYtk/+03ztcAWDAX69F/7bF/JTLMcLUiFprWU/P7np29YA3p5a/LwbEYRsWQR7v7fsmwIxk+Js+GgmlBl7Q4LbSFAqhBANJIGIMxSlWbZTtzdfO1xhx2ew+T0VaBxfrY7lHIecY7bX5bWMQKRRMvbDL3fDDzdZjp1YD6e2Wvatz4FaYVcIIUSNZGjGGQrTLdvWvSOtwZoXLdv5KWq5+i8vB89qOQ/5p5q2XU2pogT0lZB7wvH59yfDE/lQmgvH/rI9Fz/c9e0TQogWTAIRZ7DuESnJar52OFtpHuQet+wXnIL9P6vtqjLba/d+DxPuB20jhj3ckcEA706Esjzoc0HN1+n1kLwJ9FUQ0Q1uXqfyZsLdP29GCCGakwzNNJZer6pnmpRkN1tTnC7zoO1+/inAYHus57mABrIOQcpWWp2SHMg6CEXpsLGWGVElWZC+V21HdFGzZKJ6uP1MIiGEaG7yU7KxktbYfgAX1yEQKc5WH26Jo2quQeEOMvapV41WLVmfcwy8/G2v6TAaijMheaNKYG0tygrgi8shvFPdrt/xOax4Um0HxriuXUII0cpIj0hjlBfCx7PVdsdx6rUuPSLfXQcfngM/3ab2DQaoKndNG2uiq1IlyGtjmhHS/zLw8FGBSPoe22u6TLF88Bal02rs/BJOrIPtn9ifSxgJcYPVKrrth6hjfz5uOR8U2zRtFEKIVkACkcbY+4Nle8g89VpVChXFNb/HYLAkNO74VL1+fyO80A0KTrukmXaqKuCtUfDuBDW0VF3yJlWePGWb2u84BnrOtL+u3xxV0CuondovTLO/pqUyVPtz6TIZRt4KWi847zW48S/oPNFxzZjAqCZpohBCtAYSiDTUkRXw8+1qOzge+l6keg2g9l6Rv1+zbHsYF17b/bWqwrn1Q9e0tbr0PSqnI223bTKqyQdnwYbX4dQmtR/VCyK721/X3bhuSmvsEamsFkwGtoNpT8NDKRBl9WfhKBnVuuCZEEKIWkkg0lA/LbBsT7hP5Xr4R6j94hpmzhgMsPxRy76uwrb3pLaeFGcy5X5U3wbYv8T++qgeEFpt+fpz/6uCL2idPSLV/w5D4tWMIE8f2+MRXe3fKyvsCiFEnUkg0hC6Kii0GkYx/VYcYAxE8k/ByY0q8LBWXmB/rwyrypwbXlf3drWkdZZt00wPkz+fsN2fv1SVKLcORHrMgKHXWhJtA42BSGvqEbEOROIGwbDrHV9nPTRzwTtw31EITXBt24QQohVxSSCSlJTEddddR6dOnfDz86NLly48/vjjVFS0ki5r67Leo26DDmPUtqlH5OurYNE0+HCGKvVdXqSCEutZJV4B6jWjWiBgyhtxlZxjsPMLy372Ecu2Xg95J9X2jBfhti1qVgzYBiLdptneM8g4NFOYpu6fm6Sqr+7/xT4Yawms83jOfxtuXGX5Hqtr10+9ar2g/6UQENkkTRRCiNbCJdN3Dxw4gF6v55133qFr167s2bOHG264geLiYl588cUz38DdpZqSOMfB2c9YjvtX+xA6+Tc8m6hyQQbOhV7nquPRfaCiEPKKLXkmJvt+UkmrA6+AsA7Ob3vyZtv93CTLdkm2Gi5CA4OvAU9vy7nQRBV0eXipc9ZMPSIlWfDGCOM9jK79AxJHOPM7cL0Dv6gpyQCB0bVf6x8Od2xX05rdeSq2EEK4KZcEItOnT2f69Onm/c6dO3Pw4EHeeuutlh2IVFXAX8/Arq/VfvvBtudNPSLV6SpUIqopGTUgwliB9KTlmsgeqrbI0ZXqa/WzMPxGmPGCc7+HtF3qNWGEqv1hXba8wLioXWC0bRAC6kPWOuiy5h8BWk9VVVRXrdcr74QKRCqK4Yebocc5KshyZzu/VK/eQZYeodqEd3Zte4QQohVrshyR/Px8wsPDa72mvLycgoICmy+3sv0TWP8KFBqHWHrNtj3v5Wv3Focie4BfqO2xSz6GHtWmyG561/mL6JnqgJgSKoszLEmypqGj4Lj63VOrBe8Ax+dKc9Xrjs9Vefgfb7EM37gjvV4FggDzfwMvv+ZtjxBCtHJNEogcOXKE1157jZtuuqnW6xYuXEhISIj5KyHBzZL+so9atsM6QvwQ2/O+IXW7z7h7bIdxpj+n6nFc+gmM+5fttf+81aCm1ijbuGpu/HDwDVXbppV0TfkhwQ1YMTaghiEMU3Bjnfz56iD43wB4KhJWPVv/Z7lSQQpUlqicj+jezd0aIYRo9eoViDzwwANoNJpavw4cOGDznpSUFKZPn86cOXO44YYbar3/gw8+SH5+vvkrOdmFS8tnHVEryTaUX5j9sWHXw+Cr4brlcMkn0O8SGHYDhFmVCU8crXocEkdZjvU+T71qPWDKY2ol12v/UMd2fQWbP1BrnjRWVYVauA5UIBXVU22b1pQx5b6YEjDrY+QtENMP5v1qe3z9K/DrvfbThEGtaLtqIVSWuk/tjRxjsBnWUdaJEUKIJlCvn7T/+te/mDdvXq3XdO5sGS9PTU1l0qRJjB49mnffffeM9/fx8cHHx+eM1zWaXg/vTVLTaa/+SVXItGYwQH4yhCTYJiDmWwVGjj6sfUNU1U0TU4Dxy92wxVg4LMTY29BrFiy9TxVDc1QSvP0QlQBZWQK/3qOGNa7+qe7f41//UdNpZ/5XDZ2A6vEw6NV9A6NVfZDkf1Rxtn4Xq4qqAPFD6/4ck2HXqS9Qgdgfj6p7A2x+r/b3vj4MKorgss/rlpPhSqZZRI7qgwghhHC6egUiUVFRREXVrXx1SkoKkyZNYsiQIXz44YdotW5UsuT4KktNjz3f2QciO79QuQxTHrMdKjENXWg94ayn6v68+OGwZZHaNg17BMeq2Raevo5nW3h4qvoVJ9arfesVfs+kIBVWP6e2B19jSao1VVEN66ieGd1L7e/8HIbOt5xv17/uz3IkYTgMvsoSiJyJKcDb9VXzByKnjcm8kd2atx1CCNFGuCQ6SElJYeLEiSQmJvLiiy+SmZlJWloaaWluUnkzdYdle9vHsPtbOLwc1r6sekN+uUedW/GUpcBY+l44vQPQqPoajoZmatJ5gqobovW0DXrCO9eeGBrT13a/JAeO/uW4NkdZARQaC4qtfdlyfM0L8PH5qv2m4RHTh2zv2WpmCMCXVjNZ/GpPKq4TR4v4nWl2SU0VaZtKZSns/VFtmxYxFEII4VIuGQRfvnw5R44c4ciRI8TH2y4KZnCHAlfj7lFDI68bhyC+u85yLqqnWrjOJGULJI6Eg8Z8kh7n1H1peJPgOLhzB6Cp34JoEdXWMXne+NwJ/wcjblbBkEajhpo+nKF6NKY+YTsUYmr3W1Y9DaYej+A4uGW9Ghox1c3wDXFOboSjoasuU1Qeze8PqVlC38wHg9UKwAd+gc8vhYsX1TwLx5X+fEKt+ePlrxb6E0II4XIu6RGZN28eBoPB4ZfbiOymSpVXt+db233TkEjGfvWaMLxhzwuMrv+qrDX1IKx+TgUla4w1Ro6uhPTdKs/it3vVMVNPhyOxAyzbYR1g7F2W/fr09NQmYThc/iXMX2Y5Fj9MrWB772HodZ5avbe6Q8vg0O/OaUN9mZJ2h17bPIGQEEK0QW6UuNEMLnjb/tjhP233kzeqV9OaMFG9XNsma46WmI8fZtn+y1hgbM939tddWEtycPU8jHir4MpZgQio3qPEkZb9jmNUD05gtHo9/02Y/Sb4BNu+L9+Fs6VqU5ShXrtMbp7nCyFEG9S2AxHfEFVu3Vp5vno1/bZ+dCX88zZkHVL70T2brn2RPdSHoilnY/iNMOtVy/ngeKgsg0NL7d/baRwMvNL++Lxf7X/bt+4h8ahWUbWxNMacmpvW2AdWWg8YNBdu32Z73FTXpClt+dCy7k9gDevKCCGEcDoplKD1cHw8fhjs/kZtL/s/47VeakpvU9Fq4aof1Hb+KVU0zNNbzbZ5dZBaG2bbx6p6aUCUbZ6HT5CaSjziJpWH8e5Edc5RkS7rIaM8F/RGnGkGin+15NitH6mgK6aPw8td4pe7LNsSiAghRJNp2z0iUPNqqXGD7Y+Fdag5cHG1kHjL+i+mKcBVpWrKK8CYOy3XdpmiXrVaiO2vpgFf/TNc+Z39h351ppLsTUnrYT9leNMZao+4Uk1rBgkhhHA6CUSmPK7WeBl1m+1x62XvTcLqOVvGVTx9VA8IqFk9oPI8Zr6kpp2e85z9ezpPgK5Ta77n7DcBDcx+3enNrZNrf4f7j8NM49Tjg785nqbsCtWf4041b4QQopWTn7hxA+Hyz6HzJMsxjVb1lPS/zPZad1pl1aYHQaOGMYZdD/N+OfPS9Y4MmgsPpagKq83B21/11gy8AtCoqrDvjAe97oxvbbTKEsv21Cdc/zwhhBBmEoiYWBcWC4xRwwXnvwX3HICz/g19LlT5Fu7CujBaRBfwCWz8Pd1hyqqXH2DsoUjb1TSr9JZkq1cPHxhzl+ufJ4QQwkySVU1CrFacNa39otWqUuxj7mieNtXGOhCJasKZPE2t8LR9YTdnMwUi/hGOy+0LIYRwGekRMfENsWy7Q8/AmViXf/d2Qm+IO5n9pmW74LTrn2cdiAghhGhSEog4Ur3AljvSamHcvaqCqvXCfK3BwCtUAjHA99fDnu9d+zzTGj01zaASQgjhMhKIWBtxsyroNfmR5m5J3Ux5FB46BVHdm7slzqXRQGRXy/638137vJyj6rW+awgJIYRoNMkRsTb9WTWd19u/uVsigmpZldjZsk2BiItzUYQQQtiRHhFrGo0EIe4iuFogUltNkcZO8TX1iLg6KVYIIYQdCUSEe6oeiFQUQ8o2OPCb7fH8U/B8Z/j13oY9x2CAbOPaNtIjIoQQTU4CEeGeTFOoTYoz4L1J8OXlkLbbcnz9q1CWB5sbWBK+KB0qi1URu7CODW2tEEKIBpJARLin6gvPHVtt2V7zgmXbtNAfqF6T+jLlh4QkWNbyEUII0WQkEBHuyaNaHrX16rj7foKT/0BZgXo1yTlmf5+SHPjhZji+xvZ4eSGc3Cj5IUII0cwkEBHua/YbNZ9L3wPbFkNhquWYqXfD2oqnYOcXsHiW7fGvroRF02Ddf9W+DMsIIUSzkEBEuK9BV8KF79se8wtXr3nJcHi57blCB1VYMw84vvexVerV1IsS2K7BzRRCCNFwEogI99b7POh3CUT1gjkfWarIrn8FjhvzRrpMVq9FGfbvN+jr9pygmDNfI4QQwumkoJlwb54+cJHVjJg939me9w2BxNFwdKWaWaPXqZWTTazrj+iq7HNPTKonxwohhGgS0iMiWpb2Q2332/WHwGi1ved7WBgPWxZZzusrLdslWTXf13QPIYQQTUoCEdGyhHWA89+27McNsvRmVJaor1/utpwvtgo+ioyL2+msghMT6RERQohmIUMzouXpOcOyQvLIWxwnqYIairE+V2SsOVJeaDkW3kUN5UiyqhBCNAsJRETL4xsCN68FDx8IjrUEJdUVnAJ9lWXf1CNSlqdevfzhlr9VIFJT7ogQQgiXkqEZ0TKFdVRBCIBPINy+zf6anOO2+8XGWTVlBcb3BYOXL3h4uayZQgghaieBiGgdIrrAhP+z7O/5Hj453/Ya0/ReU8+Ibw09KUIIIZqMBCKi9Zj4oBpuAfjjUctxT1/1agpENhuLpEk1VSGEaHYyMC5aD41GzX7JPa7yQwDG3g2hHdRaNXu+haiecHqnOjfipmZrqhBCCEV6RETrEppo2fYNgcmPQWQ3y7FVCy1TeqN6Nm3bhBBC2JFARLQu1oFITF/QaiFxFIxcoI4ZdOoLwC+s6dsnhBDChgQionUJ62DZNvWEaD1g+n/AN9RyzsPbkk8ihBCi2UggIlqXsE6W7cgetucCIi3bfmEqp0QIIUSzkkBEtC49zoHYAYAGOo6xPRcQZdn2C2/SZgkhhHBMZs2I1sU7AG74C0pyIDDK9lz1HhEhhBDNTnpEROuj9bAPQqBaj0hokzVHCCFEzSQQEW2HdSDiH9F87RBCCGEmgYhoO6wDkQAHPSZCCCGanAQiou2wzhGRQEQIIdyCBCKi7fCXQEQIIdyNBCKi7bAZmpEcESGEcAcSiIj/b+9uY9oq3zCAX4VCQWRVWHg5bh0wUZQhQwvEsUjMiGiWzbko0zCGLjGZ6TLKDLLF1H1Qx5hvc7jAMMbsg/MlRnDMzA2RoYsylFoncbKphKFkVI3SjmVI2uf/wdD/mJPX0z5wev2S8+GctvS6c+DpzTnP6QkelzciEUZ5OYiIyIeNCAWPy787JDpRXg4iIvLhF5pR8AgJATYcAobdwDxFdhoiIgIbEQo2KfmyExAR0WV4aoaIiIikYSNCRERE0rARISIiImnYiBAREZE0fm9EhoeHsXTpUuh0OjgcDn+/HREREc0hfm9EnnrqKSgKL5UkIiKif/NrI3LkyBEcO3YML774oj/fhoiIiOYov32PyMDAAB5//HE0NjbimmuumdRrhoeHMTw87Ft3uVz+ikdERESzgF+OiAgh8Oijj2LTpk0wm82Tfl1VVRWMRqNvWbhwoT/iERER0SwxpUZk27Zt0Ol04y4//PADampq4Ha7sX379imF2b59OwYHB31LX1/flF5PREREc4tOCCEm++TffvsNf/zxx7jPSUlJQVFREZqamqDT6XzbPR4PQkNDUVxcjAMHDkzq/VwuF4xGIwYHBzFv3rzJxiQiIiKJpvL5PaVGZLLOnTs3Zn5Hf38/CgsL8f777yM3NxcLFiyY1M9hI0JERDT3TOXz2y+TVU0m05j1a6+9FgCwePHiSTchREREpH2z+u67owdrePUMERHR3DH6uT2Zky4BaUSSkpImFeZKbrcbAHj1DBER0RzkdrthNBrHfY5f5oioxev1or+/H9HR0WMmvs6Uy+XCwoUL0dfXp9m5J1qvkfXNfVqvUev1AdqvUev1Af6rUQgBt9sNRVEQEjL+Bbqz+tRMSEiIX+eUzJs3T7O/XKO0XiPrm/u0XqPW6wO0X6PW6wP8U+NER0JG8e67REREJA0bESIiIpImKBsRg8GAHTt2wGAwyI7iN1qvkfXNfVqvUev1AdqvUev1AbOjxlk9WZWIiIi0LSiPiBAREdHswEaEiIiIpGEjQkRERNKwESEiIiJpgrIR2bdvH5KSkhAREYHc3Fx0dHTIjqSKqqoqZGdnIzo6GnFxcVizZg26u7tlx/KbXbt2QafTwWq1yo6iql9//RXr169HbGwsIiMjkZGRga+//lp2LFV4PB7YbDYkJycjMjISixcvxrPPPjutW0DMFp999hlWrVoFRVGg0+nQ2Ng45nEhBJ555hkkJiYiMjISBQUFOHv2rJyw0zBefSMjI6isrERGRgaioqKgKAo2bNiA/v5+eYGnYaJ9eLlNmzZBp9Nhz549Acs3U5Op7/Tp01i9ejWMRiOioqKQnZ2Nc+fOBSRf0DUi7777LrZu3YodO3bAbrcjMzMThYWFcDqdsqPNWFtbGywWC9rb29Hc3IyRkRHcc889GBoakh1NdV999RX279+P2267TXYUVf3555/Iy8tDWFgYjhw5gu+//x4vvfQSrr/+etnRVFFdXY3a2lq89tprOH36NKqrq7F7927U1NTIjjZtQ0NDyMzMxL59+676+O7du7F3717U1dXh5MmTiIqKQmFhIS5duhTgpNMzXn0XL16E3W6HzWaD3W7HBx98gO7ubqxevVpC0umbaB+OamhoQHt7OxRFCVAydUxU308//YTly5cjLS0Nx48fx6lTp2Cz2RARERGYgCLI5OTkCIvF4lv3eDxCURRRVVUlMZV/OJ1OAUC0tbXJjqIqt9stUlNTRXNzs8jPzxdlZWWyI6mmsrJSLF++XHYMv1m5cqXYuHHjmG1r164VxcXFkhKpC4BoaGjwrXu9XpGQkCBeeOEF37a//vpLGAwG8fbbb0tIODNX1nc1HR0dAoDo7e0NTCiV/VeNv/zyi7jhhhtEV1eXWLRokXjllVcCnk0NV6tv3bp1Yv369XICCSGC6ojI33//jc7OThQUFPi2hYSEoKCgAF9++aXEZP4xODgIAIiJiZGcRF0WiwUrV64csx+14tChQzCbzXjooYcQFxeHrKwsvP7667JjqWbZsmVoaWnBmTNnAADffvstTpw4gfvuu09yMv/o6enB+fPnx/yuGo1G5ObmanLMAf4Zd3Q6Ha677jrZUVTj9XpRUlKCiooKpKeny46jKq/Xi48++gg33XQTCgsLERcXh9zc3HFPT6ktqBqR33//HR6PB/Hx8WO2x8fH4/z585JS+YfX64XVakVeXh6WLFkiO45q3nnnHdjtdlRVVcmO4hc///wzamtrkZqaiqNHj+KJJ57Ali1bcODAAdnRVLFt2zY8/PDDSEtLQ1hYGLKysmC1WlFcXCw7ml+MjivBMOYAwKVLl1BZWYlHHnlEUzeJq66uhl6vx5YtW2RHUZ3T6cSFCxewa9cu3HvvvTh27BgeeOABrF27Fm1tbQHJMKvvvkvTZ7FY0NXVhRMnTsiOopq+vj6UlZWhubk5cOcuA8zr9cJsNmPnzp0AgKysLHR1daGurg6lpaWS083ce++9h7feegsHDx5Eeno6HA4HrFYrFEXRRH3BbGRkBEVFRRBCoLa2VnYc1XR2duLVV1+F3W6HTqeTHUd1Xq8XAHD//fejvLwcALB06VJ88cUXqKurQ35+vt8zBNURkfnz5yM0NBQDAwNjtg8MDCAhIUFSKvVt3rwZhw8fRmtrKxYsWCA7jmo6OzvhdDpx++23Q6/XQ6/Xo62tDXv37oVer4fH45EdccYSExNx6623jtl2yy23BGz2ur9VVFT4jopkZGSgpKQE5eXlmj3CNTquaH3MGW1Cent70dzcrKmjIZ9//jmcTidMJpNv3Ont7cWTTz6JpKQk2fFmbP78+dDr9VLHnaBqRMLDw3HHHXegpaXFt83r9aKlpQV33nmnxGTqEEJg8+bNaGhowKeffork5GTZkVS1YsUKfPfdd3A4HL7FbDajuLgYDocDoaGhsiPOWF5e3r8uuT5z5gwWLVokKZG6Ll68iJCQscNOaGio778yrUlOTkZCQsKYMcflcuHkyZOaGHOA/zchZ8+exSeffILY2FjZkVRVUlKCU6dOjRl3FEVBRUUFjh49KjvejIWHhyM7O1vquBN0p2a2bt2K0tJSmM1m5OTkYM+ePRgaGsJjjz0mO9qMWSwWHDx4EB9++CGio6N956CNRiMiIyMlp5u56Ojof813iYqKQmxsrGbmwZSXl2PZsmXYuXMnioqK0NHRgfr6etTX18uOpopVq1bh+eefh8lkQnp6Or755hu8/PLL2Lhxo+xo03bhwgX8+OOPvvWenh44HA7ExMTAZDLBarXiueeeQ2pqKpKTk2Gz2aAoCtasWSMv9BSMV19iYiIefPBB2O12HD58GB6PxzfuxMTEIDw8XFbsKZloH17ZXIWFhSEhIQE333xzoKNOy0T1VVRUYN26dbjrrrtw99134+OPP0ZTUxOOHz8emIDSrteRqKamRphMJhEeHi5ycnJEe3u77EiqAHDV5c0335QdzW+0dvmuEEI0NTWJJUuWCIPBINLS0kR9fb3sSKpxuVyirKxMmEwmERERIVJSUsTTTz8thoeHZUebttbW1qv+3ZWWlgoh/rmE12azifj4eGEwGMSKFStEd3e33NBTMF59PT09/znutLa2yo4+aRPtwyvNtct3J1PfG2+8IW688UYREREhMjMzRWNjY8Dy6YSYw19pSERERHNaUM0RISIiotmFjQgRERFJw0aEiIiIpGEjQkRERNKwESEiIiJp2IgQERGRNGxEiIiISBo2IkRERCQNGxEiIiKSho0IERERScNGhIiIiKRhI0JERETS/A+HQONdx2ahnwAAAABJRU5ErkJggg==" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# A plot of the solution of the commutative-noise SDE used to compare the methods\n", + "# A plot of the solution of the SDE\n", + "# We will use this to compare the methods\n", + "sol_commutative = diffeqsolve(\n", + " terms_commutative_sde,\n", + " GeneralShARK(),\n", + " t0,\n", + " t1,\n", + " dt0=0.02,\n", + " y0=commutative_sde.y0,\n", + " args=commutative_sde.args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + ")\n", + "plot_sol_general(sol_commutative)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3f6e04f29792d26a", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-02T15:07:43.564266Z", + "start_time": "2024-04-02T15:06:28.474256Z" + } + }, + "outputs": [], + "source": [ + "# A comparison of SlowRK, SPaRK and GeneralShARK for commutative noise SDEs\n", + "# We compute their orders and plot them on the same graph\n", + "out_SLOWRK_commutative_sde = constant_step_strong_order(\n", + " keys, commutative_sde_short, SlowRK(), levels=(4, 10)\n", + ")\n", + "out_SPaRK_commutive_sde = constant_step_strong_order(\n", + " keys, commutative_sde_short, SPaRK(), levels=(4, 10)\n", + ")\n", + "out_GenShARK_commutative_sde = constant_step_strong_order(\n", + " keys, commutative_sde_short, GeneralShARK(), levels=(4, 10)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9a887a880a90ecd4", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-02T15:07:44.059675Z", + "start_time": "2024-04-02T15:07:43.564994Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAABIMAAAOOCAYAAACTMtKnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd3gUVdvA4d+mN5IQegk9gIQSQEKH0AMkdEFBmiCogIBiQ0V4UZQXFUV67wjSO9I70kvoLZBAgCSk9zLfH/ky7y7ZbDadmOe+rr2yu3PmzJndyezss+c8R6MoioIQQgghhBBCCCGEKBRM8rsBQgghhBBCCCGEECLvSDBICCGEEEIIIYQQohCRYJAQQgghhBBCCCFEISLBICGEEEIIIYQQQohCRIJBQgghhBBCCCGEEIWIBIOEEEIIIYQQQgghChEJBgkhhBBCCCGEEEIUIhIMEkIIIYQQQgghhChEJBgkhBBCCCGEEEIIUYhIMEgIIYQQQgghhBCiEJFgkBBCCCGEEEIIIUQhIsEgIYQQQgghhBBCiEJEgkFCCCGEEEIIIYQQhYgEg4QQQgghhBBCCCEKEQkGCSGEEEIIIYQQQhQiEgwSQgghhBBCCCGEKEQkGCSEEEIIIYQQQghRiEgwSAghhBBCCCGEEKIQkWCQEEIIIYQQQgghRCEiwSAhhEHLly9Ho9Gg0WgYMmRIfjenQAgODmbKlCk0btyYokWLYmpqqr6Gy5cvz+/mCSGEEMJIlSpVUj/DfX1987s5AvD19VXfk0qVKuV3c4QosMzyuwFCFBaPHj1i27Zt7Nmzh/v37/P8+XPi4+MpVaoU5cqVo127dnh7e9OoUaP8bqrIhgcPHtCqVSuePHmS300RQgghhBD/AkFBQaxatYrDhw9z9epVgoODiYmJwcbGhpIlS1K5cmXq1atH48aNadeuHU5OTunWpdFo0l1ma2uLvb09Dg4OVKhQgYYNG9KwYUM6dOiAvb290e09cuQIbdq0ydQ+anv48KEE+vKABIOEyGVBQUFMnjyZBQsWkJiYmGb5o0ePePToEadOnWLq1Kl06tSJn3/+mdq1a+dDa0V2jRw5Ug0EWVtb0759e8qVK4epqSkAb7zxRn42TwghhMhz2l8MW7duzZEjR/KlHZUqVeLRo0eAfNkUBUNycjIzZsxg8uTJxMbGplkeERFBREQE9+/f58CBAwCYmJiwbt06+vbtm+ntRUVFERUVRUBAALdu3eLvv/8GUoJEb7/9Np988gm1atXK3k6J14YEg4TIRTdu3MDT0xM/Pz/1OTMzM5o0aULFihWxtLTk6dOnnDp1ivDwcAD27dvHkSNHWLNmDb17986vpossCAgIUD+ILS0tuXLlCi4uLvncKiGEEEIIURANGzZMJ8WAiYkJDRo0oHr16hQpUoTIyEj8/f25dOmS+l0iOTlZvZ+RHj16UK5cOfVxYmIiISEhBAcHc+nSJV6+fAmkBImWLFnC6tWrmTZtGuPHjzfYw+hVo0aNMroskKleSCLrJBgkRC65ceMGLVq0ICQkBABzc3MmTJjAp59+SrFixXTKxsXFsX79eiZMmEBgYCBxcXH07duXlStXMmDAgPxovsiCS5cuqfdbtmwpgSAhhBCigJM8Qa+fSpUqoShKfjcj161cuVInEDR06FB++OEHypQpk6ZscnIyp0+fZuPGjaxYscLobYwdOxYPD490l9+4cYP58+ezdOlSoqKiiIuL49NPP+XBgwfMnj3b6O1kpqzIO5JAWohcEBsbS79+/dRAkI2NDfv372fatGlpAkGQ0otk0KBBOj1JkpOTGTlyJHfu3MnTtousS32/Ab0f1EIIIYQQQhhj+vTp6v1hw4axdOnSdK8vTUxMaN68OTNnzuTJkyd06dIlR9pQq1YtZs2axeXLl6lTp476/Jw5c5g3b16ObEPkHwkGCZELpk2bho+Pj/p41apVtG7dOsP1ypQpw/79+ylSpAiQ0iXz/fffz7V2ipyVkJCg3jcxkdOrEEIIIYTIvCdPnnDjxg318WeffWb0utbW1pQtWzZH21OtWjUOHz6Ms7Oz+tzEiRONHo4mXk/ybUWIHBYdHc2cOXPUxz169KBXr15Gr1+xYkWmTJmiPj527Bhnz57VW9bDw0OdWjM1GWNAQADTpk3D3d2d0qVLY2pqiqOjo971L168yPvvv0+VKlWwtramRIkSuLu789///lcdI5wV586dY/z48bi5uVGiRAksLCwoXbo0rVu3Zvr06To9aNKjbyrX+/fv8/XXX1O/fn1KlCiBiYkJbm5uWW5nqsjISGbNmkWnTp0oX748VlZWFC1alNq1azN69Gj++eefdNc9cuSI2s6hQ4eqz69YsUJ9PvU2ZMiQbLcVYM+ePYwcOZLatWtTrFgxzM3NcXR0pEGDBowcOZLt27frTVauTVEU/vrrL9555x2qVq2KnZ0ddnZ2VK1alf79+7Nx40ajumDrOwZfvnzJ9OnTadSoEcWLF8fa2poqVaowbNgwnSDpq3799Ve1rk6dOhn9epw6dUpdz8nJibi4uHTLRkVFMW/ePLy9valYsSI2NjYUKVIEFxcX3nvvPQ4dOpTh9pYvX57mPU1KSuLPP/+ke/fu6v+TRqNh69atadYPCQnh+++/580336Ro0aLY2dlRo0YNhg8fzrlz59Ry2seOMYKDg/nll1/o0KEDzs7OWFlZ4ejoSK1atRg1ahTnz5/PsI7Jkyer25w8eTKQkj9g5cqVajJ0S0tLypQpQ48ePdi5c6dRbdN29epVvvzySxo3bkzp0qWxsLBQX4N+/fqxZMkSwsLC8mR/s+r69et89tln1K9fn+LFi2NpaUnZsmXx8PBg+vTpBAcHZ1iHvuMIYMuWLXh7e1OhQgUsLS0pWbIkHTt2ZPXq1Tk+LCImJoatW7fy8ccf06JFC0qVKqW+H5UqVaJnz54sWbKE+Pj4HN3uq173c1pAQABTpkyhfv36ODk5YWVlRc2aNfnyyy/1flb6+/szceJE6tevT9GiRSlSpAhubm5MmzaNmJgYg9tPb8rs48ePM3DgQKpVq4aNjQ0ODg54eHiwdu1avft1+PBh3nrrLVxcXLC2tqZkyZJ07dqVPXv2ZPgaDBkyRG2D9hCV9KR3LMP/zinaswodPXo0zedjelOEJycnc/z4cSZNmkTHjh2pUKECNjY26nmobdu2/PDDDwQFBaXbPu3XNDV5NEDlypX1tuPV5NaGppb/+OOP1WUjR47M8LVKtXbtWnU9V1dXg2Xz+lynb3/9/f359ttvqVevHo6Ojtja2lKzZk3GjBmj85oa49GjR0yaNIkmTZqo55xSpUrRpEkTvvvuO508m+nJzNTyfn5+TJkyhVatWlGqVCksLS2xsLCgWLFi1KtXj/79+zNv3jyePXuW4XYVRWHLli0MHjyY6tWr4+DggJWVFc7OzvTo0YMVK1ZkeK4y1quz0lasWDFH6s2OYsWKsWTJEvVxaGioznceUQApQogctXz5cgVQb8eOHct0HREREYqdnZ1ax5AhQ/SWa926tVrm8OHDytatW5WiRYvqbB9QHBwc0qz79ddfK6ampmnKpt7Kly+vnD59Wlm2bJn63ODBgw22++XLl0rv3r3TrTP15ujoqPz1118G66pYsaJa/uHDh8qCBQsUKyurNHXVq1fPyFdVvx07diilS5fOsM39+/dXoqKi0qx/+PDhDNc19vXLiI+Pj/Lmm28ata1+/fqlW8+dO3eU+vXrZ1hHw4YNlfv37xts06vH4IkTJ5Ry5cqlW6epqamycOFCvXU9ffpUPSZNTU2VgIAAo16XDz/8UK1/xIgR6ZbbsGGDUe+1l5eXEhoamm49r/5PPHnyRGnRooXeurZs2aKz7qFDh5RSpUqlu20TExNl8uTJiqIoOs9nZPbs2YqDg4PB/dJoNMp7772nxMXFpVvPd999p5b/7rvvFH9/f6VZs2YG6x06dKiSlJSUYRtDQkKUfv36KRqNJsP3oFSpUnmyv5mVkJCgjBkzxuC5M/Uct3z5coN1vXochYaGKt26dTNYr6enpxIdHZ0j+3LmzBmdzxlDt0qVKikXL17Mke1qKwjntH379inFihVLt86KFSsqvr6+6vpLlixRLC0t0y3v6uqqvHjxIt3tP3z4UKfuxMREZdy4cRn+DyYnJyuKoihRUVFK9+7dDZb/7LPPDL4GgwcPVssuW7bMYFlFSXssa9M+p2R0q1ixos668fHxBj9PtG+2trbKqlWrMnxNjbkdPnxYZ/1Xr0e0/fPPP+qyokWLGn2+6dy5s7retGnT0i2XH+e6V/d3y5YtBttgbW2t7Ny506i6v//+e73Xcto3Kysr5aeffjJYz6v/J+lZsGCBYm1tbdT73rx5c4PbvHLliuLm5pZhPTVq1FCuX79u1OthyLlz53TqvXHjRrbrTGXoeDdGnTp11PXr1q2rt8yr18fi9SQJpIXIYYcPH1bvOzs707Jly0zXYWdnR/fu3VmzZg2AUVOwnjp1ismTJ5OQkECxYsVo1aoVxYsX58WLFzqJjSGlW+ePP/6oPraxsaFt27aUKVOGZ8+ecejQIfz9/enSpQvjxo0zqs3Pnj2jbdu23Lx5U33O1dWVevXqYWdnx4sXLzh+/DjBwcGEhobSt29fVq1aZVSC7L/++ovPP/8cgLJly9K8eXMcHBx4+vRptnowrV+/ngEDBpCUlASAqakpLVq0oFq1akRGRnL8+HGePn0KpPyK9/DhQw4dOoSVlZVaR7ly5dQZEm7dusXBgwcBqFmzJu3atdPZXpMmTbLc1iNHjtCtWzciIiLU5ypUqIC7uztOTk5ERUVx+/Ztrly5QkJCgt7pRwFu3rxJ69atCQwMVJ+rU6cObm5uaDQaLl26xLVr1wC4cOECzZo149ixY1SvXj3DNvr4+PDVV18RGRlJyZIladmyJcWKFePJkyccOnSImJgYkpKS+OCDD6hTp06a1yP1V979+/erPW0yOv4SEhLYsGGD+njgwIF6y82cOZNPP/1U/QXd3t6epk2bUr58eZKSkrh+/Trnz59HURR27tyJh4cHJ0+exMbGxuD24+Li6NatGxcuXMDMzIxmzZpRtWpV4uLiuHjxok7ZM2fO4OXlRXR0NJDS86dRo0a4uroSHx/P2bNnuXv3LpMnT6Z48eIGt6tt3Lhx/P777+rj4sWL07RpU0qXLk1sbCyXLl3Cx8cHRVFYunQpT58+ZdeuXRkOZYyMjMTT0xMfHx9sbGxo2bIlzs7OREREcPjwYV68eAHAsmXLqFGjBl988UW6dT19+pS2bdty+/Zt9TlHR0eaN29OmTJlSEhI4PHjx1y4cIHw8PB0j9/c3N+MJCcn07t3b7Zv364+5+TkhIeHB05OTvj5+XH48GHi4+MJDQ1lyJAhhIaGMnbs2AzrTkxMpHfv3hw8eBALCwv1OIqNjeX48eM8fvwYgL179/LJJ5/kSJ6GkJAQIiMjAShZsiSurq6UL18eW1tboqOjuXfvHmfPniUxMRFfX19at27NxYsXqVatWra3DQXjnHb58mUmTpxITEwM5cuXp3nz5hQpUoQ7d+5w/PhxFEXh0aNHdO7cmWvXrrF+/XqGDRsGgIuLC+7u7lhZWXHt2jW1h+/169cZOHAge/fuNep1+uabb/jtt98wMTGhUaNG1KpVi8TERI4fP6722li2bBkuLi588cUX9OnThz179mBmZkbz5s2pVq0a0dHRHD58WO35MGPGDBo0aMDbb79tVBuyw93dnVGjRvHkyRO1p2TZsmXp2bNnmrKv5lRMSkpSe0fY2dnh6upKlSpVsLe3JyEhAX9/f86cOUN4eDhRUVEMHDgQc3Nz+vXrp1OPvb29+jm9cuVK9ZgbNGiQOixfm/bMSsbsX/Xq1blz5w4hISHs3r2bHj16GFwnMDCQ/fv3AymfA+ldB+XXuU7bgQMH+OCDD0hKSqJChQo0bdoUe3t7Hj58yJEjR0hMTCQmJoa+ffvi4+ND5cqV061r9OjROr1I7OzsaNOmDaVLl+bZs2ccPnyYyMhIYmNj+fLLL3n27BkzZ87Mctu3bt2q01tL+3PfzMyMsLAw7ty5g4+PT4a9H48dO4a3t7c6JMrc3JxGjRrh4uKCubk5vr6+nDhxgtjYWG7fvk2zZs04ffo0b7zxRpbbX6VKFTQajXrd8tNPP2UqMXRueuutt9Tzqo+PD6GhoemOQhCvuXwMRAnxr1S1alU1Ct6nT58s1zNr1iydiLq/v3+aMtq/YJqZmSkajUaZOnWqEh8fr1MuNjZWvX/06FGdX+b79OmjvHz5Uqd8aGio8vbbbyuAYmFhke4vfqmSkpKUNm3aqOXc3d31/oocExOjTJ48Wd2+ra2t8uDBA711av8yZWZmplhYWCgLFy5Uf/3Ut2+Zce/ePZ1fxd3d3ZW7d++m2a9ffvlFMTExUcuNGTMm3Toz04sqMx4/fqwUL15crbty5crKnj179JZ9+fKlMn/+fGXChAlplsXFxSn16tVT6ylZsqSyf//+NOX27duns70GDRqkOaZSaR+DlpaWiqmpqfLLL78oCQkJafahdu3aatk2bdrorW/FihVqmYYNG2b00ijbt2/XeV1ePT4URVEOHDigvocWFhbKTz/9pLeX16VLl5RatWqp9X344Yd6t6n9PpuZmSmA0rp16zS/GCvK/47PmJgYxcXFRaetZ8+eTVN+/fr1io2NTZqeBelZsmSJWsbe3l5ZtGiR3vfq0KFDOr+wT58+XW992r/ip7Zh8ODBSnBwsE65qKgo5Z133lHL2tnZKZGRkXrrTEhIUJo3b67zK/Ls2bP1tjMuLk7Zvn270qNHjzzZ38yYPn26znvy5ZdfpvklPiAgQOnYsaPO8XHmzBm99WkfR6mvdefOndOc6xMSEpQJEybo9ALQd6xl1pkzZ5SJEycq165dS7fM8+fPlYEDB6rbbteuXba3qygF65xmbm6uzJkzJ03vtyNHjii2trZq2WnTpil2dnaKvb29snHjxjT1rl+/XqdH2dGjR/VuX7vHg7m5uaLRaJSaNWsqly5d0imXkJCg02OoaNGiypQpUxRAadGiRZrP1ujoaKVv375q+SpVqug9XypKzvYMSqXdS6B169YZ1qkoKe/v0KFDlcOHD6f7fsXGxir//e9/1XOxo6OjEhERkW6dhnr5ZHWd1NcdUHr37p1hfdrXd+m9Fvl5rtPeX0tLS7XX1avHi4+Pj862hw4dmm6d69ev1zl/DhkyRAkLC9MpExYWprz77rs65TZt2qS3PmN6Bmn34hk9erTez31FSemNv2HDBuWLL77QuzwgIEApWbKkWtegQYOUp0+fpin37NkzpWfPnmq5OnXqKImJiem+JsZo1aqVzuvRtWtX5cCBA2musTJLu86s9Azat2+fTh379u1LU0Z6BhUM8s4IkcNSL0gAdbhHVhw6dEjnJHrixIk0ZbQvWgHl+++/z7Be7SEf7dq1S/eDKikpSedLjaGLvJUrV6plmjRpkuEwBu0vnB988IHeMtoXI4CyevXqDPctMwYNGqTWXa1aNYPDgn799Ve1rImJSboBrNwKBg0YMEDnoufZs2dZqmfp0qU6XzIMDfs4e/aszrG8YsUKveVePQYXLFiQbp3Xrl1TA4EajUbvxVRERIRiY2Oj1nfr1i2D+9SvXz+17DfffJNmeVJSkk4QZvPmzQbrCwgIUIdxmZubK35+fmnKaL/PqRd8GR3z8+bNU8vb2Ngo9+7dS7fs5s2bdepP7yIqPDxccXR0VCAlyJVe0CHVjRs31O75xYoV03th/OqQjnfeeSfd+mJiYhRnZ2e17J9//qm33KJFi3SOu6wMnVWU3NlfY4WFhekEj/UFJlLFxsYqjRo1UsumF/h89Thq2bJluhf4ycnJOnVmNIQip2kPacmJoQoF6Zy2ePHidOv8/vvvdcpqNBrl4MGD6ZYfPny4Wja9YPOrQ5pKlCiR7pDZxMREpUaNGjrl33jjjXTPR+Hh4YqTk5Na9p9//tFb7nUJBmXGTz/9pNY/d+7cdMvlRjDo/v376nJLS0uD1xOKoiiNGzc2eHzl57lOUXT3V6PRpBuoVRRF2blzp1rWzs5O7zksKSlJqVy5slrurbfeSjcQmZycrDPMsWrVqnqHIWcUDIqIiFCXOzs7p7s9Y7z33ntqXR9//LHBsomJiUrbtm0z/Fw01unTp3V+lE29FSlSRGnXrp3yxRdfKBs3bjR6WH0q7bqyEgzy9fXVqWPlypVpyrwaDBo1apTRt/SGfIqcJ8EgIXJQWFiYzonv999/z3Jdly5d0qlr+/btacpoX7SWLVs2w18Kbty4oVNnRmOab9++rdOLKL2LPO1fXy5fvpzhvsXExKgXOg4ODno/6LUvRtzd3TOsMzNCQkJ0el5kFCBISkpSXF1d1fJffvml3nK5EQzy9/fX+QJj6KIsI9oXoBld0CiKbi6eJk2a6C2jfQzWqVMnwzrd3d0NHtOKoij9+/dXy+gL8KQKDw/XyQWgL3C0detWdXl6vU1e9eOPP6rr/PLLL2mWv/olfvfu3RnWqZ0XZeLEiRmW1+5pB/o/qn/77Td1+bhx4zLeMUVRRo4cqa6j7xdX7WCQhYVFhheYn3/+uVr+k08+0VumZs2aapn0fnk1Rm7sr7G0g3mlSpXKMPinnUckvWPz1ePowoULBuucO3euWrZXr15Z3pes0P5Vf9asWdmqqyCd0zLKSacdBDDmHHPw4EG1bHo9H18NBs2cOdNgnd9++61O+a1btxosr93T648//tBbpiAGg54/f27U/0duBIMURfeHNkMBxLt376rlrKys9AaO8vNcpyi6++vt7W2wbHJysk4uvqtXr6Yps2fPnkx9rvj7+yvm5ubqOnv37k1TJqNg0JMnT9Tlbm5uhnfYgBcvXqjBmNKlSysxMTEZrnP69GmjXz9j7Ny5U71mNnRzdXVVpk2bpoSEhGRYp/Z6WQkGhYSE6NSh7/tOZnJqvnrLyR9UhWEym5gQOUg79wGAra1tluuys7PTeZzR1I19+vTBzMxwGjDtfEYNGzakVq1aBstXr149wzw3AQEBXL58GYBatWpRr149g+UBrKysaNq0KQBhYWEGZ5gCcjyvwalTp9QZp4oXL463t7fB8iYmJrz33nvqY+3XMbcdOHBAnZnCxcUFT0/PLNUTERGhM+OI9v6kZ/jw4er9c+fOERUVZbD8W2+9lWGd9evXV++/OitLqnfffVe9v3bt2nTr2rx5szozz5tvvkmNGjXSlNm9e7d6v3///hm2D6Bt27bq/RMnThgsW7RoUTp27GiwTEREhE7+IO39S48xZXJ731q0aEHp0qUNlsno/Xz06BG3bt1SH48ePdqoduqT2/triPYsc++88w7W1tYGy7u7u1OnTh31cUbnjCpVqtCgQQODZYz538mq6OhoDh06xO+//84333zD2LFjGT16tHpbt26dWjb1fJ9VBemc1qdPH4PLq1SpovM5n1H52rVrq/cfPnxosKyxdWofZ9bW1nTp0iXH2/A6SE5O5ty5cyxatIjvvvuO8ePH6xyj//nPf9Sy2T1Gs0L7nJ2a71Ef7WVeXl44ODikKZOf57pXZfS5rtFodK779J2btM+fXbp0yfBzpVy5cjrnhaxccxUvXlzN7+jj48PJkyczXQeknK9S8wn16tVLJ2dkeho3bqyeF3LivejatSt3797lyy+/NDhd/PXr15k4cSJVq1bVyaWYG179jvLq9x9RcEgCaSFy0KuJCDO60DQkNbFnKnt7e4PlGzZsmGGd2omkU4MxGWnatCmnT59Od7n2spiYGKO/7N2/f1+97+fnR926ddMta8y+ZYb26+Du7p5hEA2gefPmOusrimL0lN/ZcebMGfW+h4dHluu5evWqmijbzs7O4Oudys3NDVtbW6KiokhKSuLKlSs0a9Ys3fLaX0rSo50gNL0AZ4cOHShZsiQvXrzgwYMHnDp1Su92V69erd5PL3G09vG5adMmjh49mmEbtac1z2iKWzc3N0xNTQ2WuXr1KsnJyUDK/3HNmjUzbEPjxo0zLKO9bwsXLjQqsaS/v796P6N9y4n3U/v4dXFxoXz58hnWmZ7c3l9DtM8Zhv4HtDVv3lxNsPlqMvFX5dT/Tma9fPmSSZMm6STVzYihKbyNUZDOadqBk/Q4Ojqqn/UZTRHu5OSk3jfmPXRwcMjwf6Zo0aLq/erVq2Nubp6jbchviYmJzJo1i5kzZ+r8PxuS3WM0K/r27cvYsWNJSEjg6NGj+Pv7633vtINBxnxu5fW57lU5cW7K6vlzx44dQMbnT30sLCzo0aMHf/75J4mJibRt25Z+/frRp08fWrVqZXSyY+334urVq5n+QSMkJISoqKhs/TgMKcGtH3/8kR9++IELFy5w7Ngxzp49y8WLF7l3755O2ZcvX9KvXz8iIyONCpJnxaufFxl9RwHURNji9SLBICFykL29PWZmZuqvntmZ6SokJETnsfYFnD4lSpTIsE7t2VYqVKhgVDsyKpc62xak/MqoPVOEsV7d11cZs2+Zof06VKxY0ah1KlWqpN6Pj48nIiLCqA+/7Hr+/Ll6v0qVKlmuR3ufnZ2djQpkmZiY4OzsrPbsyOgCW98vnK/S/qKSkJCgt4yZmRlvv/02s2bNAlIunl+9gAwICFB/bUwtr4/28bl+/foM2/eqnDg2tV/78uXLG/XaZ/QFMDIyUudibPHixRnW+aqM9i0n3s+cOn7zYn8Nye45I6/+dzLj0aNHtGrVSp2pzFjZ/QX433ZO0/4xIaPy2mVTrxPycvuvls+J4yg3pc7W+Pfff2dqvfzopVCsWDE6d+7M9u3bSU5OZt26dXz22Wc6ZVJnjNQu/6r8Pte9KifOTbl9/kzPzJkzuXDhAnfv3iU+Pp5Vq1axatUqTExMcHV1pWXLlnTo0IHOnTtjaWmptw7ta4gTJ05kqadPSEhItoNBqVJnFWzUqJH63LNnz9i2bRu//vord+7cUZ8fNWoUnTp1ytTseMbS/tEMMv6OIl5fMkxMiBym/UGX0fAnQ15dV/uDUZ+Mhi2Abm+jjKbMTpXRB9irHwhZkdFFsTH7lhnar4OxH9Cvlsuri03t7bzaLTczsrLPr5bNaJ9zsqeU9i+mGzZsSHOBuW7dOrW3TceOHSlZsqTeerJ7fObEsZmV/7uM3uu8+L/Lifczp47fvNhfQ7J7zsjL/x1j9e/fXw0EFSlShPHjx7N3714ePHhAZGQkSUlJKCm5JXWGaaT+32XVv/mcltPvY35vP79NmTJFDQRpNBr69evHhg0buHnzJmFhYcTHx6vHqHavg/zqgaD9uaXdc1Xfc/369dPbiyu/z3WvyoljKrfPn+kpXbo058+f55tvvqFUqVLq88nJyVy7do25c+fSs2dPypQpw08//aT2NNT2ur0f+pQuXZqRI0dy7do13nnnHfX52NhYFi5cmCvb1B7+ndoGUTBJzyAhcljz5s3VIVD//PNPluvRXrdSpUo5EtnXvvCOjo42ap2Mhrppf2B369aNbdu2Za1xeUj7dTB2KN+r5V4dEphbtLfz6tDBzMjKPr9aNq/2GVJyANWsWZNbt24RFBTEvn378PLyUpdrd7U3lF/H1tZWvZi7ePGiTt6VvJLb/3eQ0gtRe7jI6yKnjt/83l87Ozv1OMrKOSMv/3eMcerUKU6dOgWk7NuZM2cM5pDLyeB3YT2n/ZtlN0CoT1xcHH/88Yf6ePny5QwaNCjd8q9DzhJvb28cHBwICwvj6tWr+Pj4qEMNk5KSdHqnpve5ld/nutyQ3Wuu7Pyf2tvbM3XqVCZPnsz58+c5fvw4J0+e5MSJE2qPo5CQEL766ivOnDnDli1bdAJg2u/Hr7/+yvjx47PcltxmYWHBokWL2Ldvnzoy4fjx47myLe3vKKampjo9lUTBIj2DhMhhbdq0Ue/7+/tz7NixTNcRGRmpE1TRrjM7tIe0GDs0IKOx59q/tjx79ixrDctjWXkdtJMiWlhY5NmXCO3XNzvJPrX32d/f36hfTpOTk3Xe/+LFi2d5+1kxYMAA9b72L6o3b95UcwgUKVKEHj16pFvH63B8ar9uT548MWqdjHJjODo66nRrf13/93Lq+M3v/c3uOSOv/3cycvDgQfX+4MGDM5xM4NGjRzm27cJ8TisotHusGNOrISd6T7zq7NmzarDQ1dXVYCAIcvYYzSpLS0udhN/an1t///03L168AKBatWrp5m3M73Ndbngdzp+mpqY0btyYCRMmsGXLFp4/f87x48fp1q2bWmbbtm1s2rRJZ73X4RoiM2xtbWnRooX6OCAgIFe2s3HjRvV+vXr18iRtgsgdEgwSIoe99dZbOr/i/Prrr5muY9GiRTq/mH7wwQc50jbtXhHaSTwNMZQ8GnQT3V6+fDlbSbPzivbrcPbsWb1dg1+V+it66vp51R1feza37MxiVrduXTXRcUREhJrY1pArV66o76epqalRM8XlpAEDBqiv8/bt29VffrV7BfXq1cvgUC3t4zOrs4lkV926dTExSfm4DQsLS9O9Wp+zZ89mWMbd3V29n1/7lhHt4/fOnTtGJ4DVJz/3V/ucoX0uMES7XEYzheU17TwYxiSIzcqPGukpzOe0gkL7i11wcHCG5Y157TP7mZlbx2huf3Zr9/hZt26dGqTU/tzS/qFDn4Jwbs+M1/H8aWJiQosWLdi6dSsdOnRQn9++fbtOudfhGiKztGc8Sy8XUnb8/fffOqkscnrGX5G3JBgkRA6ztbXlo48+Uh9v27aNLVu2GL3+o0ePmDRpkvq4VatWOhcG2aHdw+j8+fMZfim9d+9ehsGgKlWq8MYbbwApiZWXLFmS/YbmsmbNmqkfkIGBgezatctg+eTkZJYtW6Y+1p7CNbd16NBBTfh59+5d9u3bl6V6ihQpwptvvqk+Xr58eYbraL+X7u7uOZYA0ViVK1dWE0fHxMSwefNmFEXRmW4+oynYtYeWLV26lNjY2NxprAH29vY6F8OGph1OpS/fxKu0923evHmv5UwdFStWVM8PQJYSzKfKz/3V/p//888/MzyOzp8/z9WrV9XHOdW7M6ekBich46GLT58+zdHhv4X5nFZQaOcozGia9tjYWHXWJ0O0v6Aak7g6M8docnKy0blRMtuOzGrdujXOzs5ASi+YY8eOERUVxdatW9Uymfncel3P7Zmhff7cvXu32kMqPU+fPmXPnj16189pGo0Gb29v9bF2gnuATp06qeerU6dOceXKlVxrS07RbqOxk8UYKzg4mOHDh6uPixUrxocffpij2xB5S4JBQuSCiRMn6nS7f/fdd4361erZs2d07NhR7RVka2vLokWLcqxdb7zxhs6sTGPHjk13rH9ycjIff/yxURchX3zxhXr/m2++MepXwlT50e3W0dGRfv36qY8/++wzg/kGZs+ere6TiYkJI0aMyPU2pipbtqxOW0eOHJnmYsVYI0eOVO/PmTNH58vqqy5cuMCCBQvUxznVOy2ztBNyrlmzhlOnTqlDS8qVK5fhRWLv3r2pVq0akNJd+qOPPjL6wjoyMjLHerppT+/622+/GRwes337dp1hPOkZOXKkOj3uxYsXmTJlitHtCQoKMqpHXE745JNP1Pu//PJLlnMY5Of+9u/fX817ERAQYHDb8fHxjBkzRn3cpk0batSokeVt5wbtWbxe/SVcW1JSEiNGjCA+Pj7Htl3Yz2kFgXZviJ07dxqczWnSpElGzfakPf24McNltY/Ro0ePGhyKNmPGDKO/pGe2HZml0Wh0ev6sWbOGrVu3qp8lTZo0UT+T0lNQzu3G6tixI5UrVwZSckGNGzcu3bKKojBmzBg1UFe1alXat2+f6W1GREQYfd7SHjr66mQU5cqVU4N3iqIwaNAgwsPDjao3OTlZZya1zPL19eWbb77J1MzEq1at4vbt2+pjT0/PLG//Vffu3aNt27Y6r9d///vfbE0EIF4DihAiV1y7dk1xcHBQAAVQzM3NlYkTJypBQUFpysbFxSkrVqxQSpYsqZY3MTFRVq9ebXAbrVu3VssfPnzYqHYdPnxY0Wg06nr9+vVTQkJCdMqEhYUp/fv3VwDFwsJCLTt48GC9dSYmJipt27ZVy9nb2yvz589X4uLi9JYPCwtTVq9erbRu3Vrp06eP3jIVK1ZU63v48KFR+5YZ9+7dU+zs7NRtNG3aVLl//75OmaSkJOW3335TTE1N1XJjxoxJt85ly5Zl+FplxePHjxUnJye17sqVKyt79+7VWzYkJERZsGCB8tlnn6VZFhcXp9SrV0+tp3Tp0sqhQ4fSlNu/f79SokQJtVyDBg2U+Ph4vdvL7DH43XffqeW/++67DMu/fPlSPQZNTU2VHj16qOtPmDAhw/VT90f7PezcubNy48aNdMtfunRJ+fzzzxVHR0fl2rVraZZn5X2Ojo5WqlWrpq5XpUoV5fz582nK/fXXX4qtra1iaWmpljX0Ua3dFkAZNGiQ8ujRI71lk5OTlRMnTigffvihYm1trURERKQpk9n35/Dhw2r51q1b6y2TkJCgNGvWTC1nY2OjzJkzR+8xFRcXp2zfvl3p0aNHnuxvZkyfPl1n2998802ac9yzZ88UT09PtYyZmZly5syZDPfFmOPo4cOHavmKFStma19u3ryp8znw6aefKtHR0TplAgIClO7duyuAYmtrm+H7nBn/pnNaZj+rMvq/zuz7bMz/oDZjjrvk5GSlatWqarn27dsrL1++1CkTFRWlTJgwQQF0zleGrhNsbGzUcmfPnjXYzqSkJKVcuXJq+Q4dOihPnjzRKRMbG6t8++23aY5RQ+fMkSNHqmU++ugjg21Ildn32MfHRy1ftGhRneuj2bNnG7XN/DzXZXZ/Bw8erJZftmyZ3jLr16/X2Z/hw4enaWd4eLhOXYCyadMmvfVl9H9y+PBhpUyZMsp3332nXL9+XW8diYmJyp9//qlYWVmpda1ZsyZNuSdPnihlypRRy9SoUUPZt29fuq+Hn5+f8uuvvyrVqlVT/vjjj3TLZeTmzZsKoBQpUkR5//33lePHjysJCQl6y4aHhys//PCDYmZmprazXLly6R4L2q9xRue5GzduKB9//HGa/7FPPvnE4Hra5yZD/5Mif8lsYkLkktq1a3PixAk6d+6Mv78/CQkJTJs2jf/+9780bdqUihUrYmFhQUBAAKdOndL51cvS0pJVq1bx1ltv5Xi7PDw8mDBhAjNmzABg/fr17Ny5k7Zt21K6dGmeP3/OoUOHiIyMpGjRoowdO5bJkycbrNPU1JQNGzbQoUMHLl26RHh4OB988AGff/45TZs2pVy5cpiamhISEsLt27e5efOmmpSyd+/eOb6PxqhatSqLFy9mwIABJCUlcfr0aWrUqEHLli2pWrUqkZGRHD9+XOeXwyZNmvDf//43z9vq7OzMhg0b6NGjB5GRkTx8+BBPT08qVqyIu7s7Tk5OREZGcufOHS5fvkxCQgLdu3dPU4+FhQXr1q2jdevWBAYG8uzZM9q2bUu9evVwc3MDUoYEaP/CWrJkSdatW6d3Cty8ULRoUbp06cLWrVtJSkrS6Wqv3WvIkPbt2zNv3jw+/PBDkpKS2LNnD3v37qVWrVrUrVsXe3t7oqOjCQgI4MqVK9n6JS891tbWLF++nA4dOhATE8ODBw9o1KgRjRs3platWsTHx3P27Fnu3LkDpPRGGz16NGA4x8WQIUN48OABU6dOBWDlypWsWbMGNzc3atasiZ2dHZGRkfj7+3P58uVcSfSaETMzM9avX0/btm25e/cu0dHRjBo1iq+//prmzZtTpkwZEhMTefToERcuXCA8PBwHBwe9deXn/k6YMIETJ06oQ2K+//575s2bR5s2bShatCh+fn4cPnyYuLg4dZ0ZM2bo9LJ4XdSsWZOBAweycuVKIKXH1tq1a2nUqBElS5bE19eXY8eOER8fT5EiRZgxY0aO9qQpzOe0gkCj0fDjjz/St29fAA4cOEDlypVp164dxYsX59mzZxw7dozQ0FDKli2r/j8bYmpqSo8ePdRhvh4eHnh6elKhQgU195OTkxMTJ04EUnrhTp06Ve1VuX//fqpXr06zZs2oWLEiwcHBHDlyhJCQEAAWLlyYYS4eSLnmSO0hNnfuXC5cuECDBg2wsbFRy3z44YdUrVo1My+ZDldXV9zc3Lh8+TIhISEcOnQISEnMrd0rzpCCcG7PjL59+3Ls2DF1qPDixYtZv349bdq0oVSpUrx48YKDBw/q5MscN24cvXr1yvI2U3txTpkyhdKlS+Pm5kbp0qUxMzPj+fPnXLhwQSc3VcuWLfXmvylbtizbtm2jS5cuBAUFcfv2bTp16kS5cuVwd3enRIkSJCQkEBQUhI+PT7YS4+sTERHBokWLWLRoEba2tjRo0IBy5crh6OioXk+cO3dOZ/iyra0ta9euNarXzu+//66TEDoxMZHQ0FCCg4O5dOlSmrxh1tbWTJ8+Xb1GMVZmy/fp0wcPD49MrSOyIL+jUUL82z1//lz54IMPdKL1hm4dO3ZUrl69alTdWekZlOrLL79UTExM0m1H2bJllVOnTmXq1+vo6OhM7au1tbUybdo0vXXlds+gVDt27FBKlSqVYVvfeecdJSoqymBdudUzKNXly5d1fgU3dBswYEC69dy+fVupX79+hnU0aNBAuXfvnsE25XbPIEVRlI0bN6ZpW506dYxaV9uhQ4cUFxcXo14/QHF1dU3zS7SiZO99frWHwqs3ExMTZfLkyUp8fLz6nIODQ4b1rl+/XilbtqzR++bu7q7ExsamqSc3egalCg4OVnr27GlU+8qVK5cn+5tZCQkJyujRo3V6mum7OTg4pPsLear87BmkKCk9Ozp27GhwP8qXL6+cOHEi071PjPVvOKf9G3sGpZoyZYrB17NGjRqKj4+P0XX6+voqpUuXTrc+ffs7ceJEg22wsrJS5s+fryhKxq9tqnfeecdgna++71m5Hvn555/T1Ovt7W3Uutry41yXGz2DUk2dOjVNz1d972l614apMvo/OXPmjNHXooDSp08fJTw83OA2fX19lXbt2hldZ6lSpdLt8WiMFy9eKJ07d9bpoW/MrXHjxsrly5cN1p2Z+lJvdnZ2yvvvv6/cunXLqPa/2jMos7eZM2dm+bUTxpOeQULkspIlSzJv3jy++OILtm7dyt69e7l37x4vXrwgISGBEiVKUL58edq2bUv37t1p1KhRnrTrxx9/pE+fPsydO5dDhw4REBCAnZ0dlSpVolevXowYMYLixYvrjD3OiLW1tbqvq1ev5tChQ9y5c4fg4GCSk5NxcHCgSpUq1KtXj3bt2uHp6Znv01F6eXlx7949li5dys6dO7l+/TpBQUFYW1tTtmxZ2rRpw6BBg16LX/fr1avHpUuX2Lp1K1u3buX06dM8f/6cqKgo7O3tqVKlCu7u7nh7e9OpU6d066levTrnz59n48aNbNq0ibNnz6oJHUuWLEnjxo3p06cPvXv3zrNZ0wzx8vLC0dGR0NBQ9bmMEnDq06ZNG27evMnWrVvZtWsXZ86c4dmzZ4SHh2NjY0OpUqWoWbMmzZo1o3PnzmrPgpzUvn17bt26xezZs9m6dSsPHjwgISGBcuXK0apVK0aOHEmjRo10cqik5o4wpG/fvnTv3p0///yTffv2ce7cOQIDA4mMjMTW1pZy5crxxhtv0LJlS7p06UL16tVzfN8y4uTkxObNmzl37hxr167lyJEj+Pv7ExISgrW1NeXLl8fNzQ1PT0+d6Zn1ya/9NTMz448//uCDDz5g6dKlHDx4ED8/PyIiInBycqJ69ep06dKF999/Xyc3yevIxsaGPXv2sHbtWlasWKH26ixevDhVqlShd+/eDBkyhKJFi3LkyJFcaUNhPacVFJMmTaJDhw788ccfHD9+nBcvXmBvb0+1atV4++23GTZsGHZ2dpw7d86o+ipWrMiVK1eYPXs2f//9N3fu3CEiIsLg9PU//PADnTt3Zvbs2Zw4cYLAwECKFClC+fLl8fT0ZNiwYbi4uGRqv9asWYOXlxfr1q3j8uXLBAUF5fjkAv379+eLL77Qyd+Tlc+tgnBuz4xvvvmGgQMHsnjxYvbt28fDhw8JDQ3F0dGRKlWq0KlTJ4YPH57txMeNGzfmxYsXHDhwgBMnTnDp0iXu379PcHAwSUlJ2NvbU7VqVZo0acK7775r1EQtFStW5MCBA5w+fZq//vqLY8eO4efnR0hICGZmZhQrVgwXFxfefPNNOnbsiIeHh5p8OitKlCjB7t27CQ8P5+jRo5w4cYIrV65w9+5dAgMDiY6OxsrKCkdHR1xcXGjYsCG9e/emadOmWd4mpFzLOzg4YG9vT8WKFWnYsCGNGjWiQ4cOFClSJFt1i9ePRlEKeIp6IYQQ4l9m//79dOzYEUhJAKk9s4oQQgghhBDZJbOJCSGEEK+Z9evXq/fzqregEEIIIYQoPKRnkBBCCPEa+eeff2jZsqU6te7NmzepWbNmPrdKCCGEEEL8m0jPICGEECIPPH78mLfeeosTJ06g73eYpKQkVq9eTadOndRAULdu3SQQJIQQQgghcpz0DBJCCCHygK+vL5UrVwZSkto2bNiQMmXKYGpqyvPnzzl9+rTOtPZlypThwoULlClTJr+aLIQQQggh/qUkGCSEEELkAe1gUEbefPNNNm7cSMWKFXO5VUIIIYQQojCSYJAQQgiRR86ePcuOHTs4c+YM/v7+BAUFERoaip2dHaVKlaJp06b06tULb2/v/G6qEEIIIYT4F5NgkBBCCCGEEEIIIUQhIgmkhRBCCCGEEEIIIQoRCQYJIYQQQgghhBBCFCISDBJCCCGEEEIIIYQoRCQYJIQQQgghhBBCCFGISDBICCGEEEIIIYQQohCRYNC/0PTp09FoNGg0Gs6cOZPfzRFCCCGEEEIIIcRrxCy/GyBylo+PD9999x22trZERUXl2nZiY2O5du0aACVKlMDMTA4lIYQQQgghhBAipyUmJhIYGAhAnTp1sLKyynad8g3+XyQhIYHBgwfj5uaGi4sLq1evzrVtXbt2DXd391yrXwghhBBCCCGEELrOnj1Lo0aNsl2PDBP7F/nhhx+4fv06S5cuxdTUNL+bI4QQQgghhBBCiNdQoe4Z9OLFC86ePcvZs2c5d+4c586dIzg4GIDBgwezfPlyo+t69OgRs2bNYteuXfj5+WFpaUnVqlXp27cvo0aNwsbGJpf2IsXFixf54Ycf+M9//kOtWrVydVuQMjQs1dmzZylTpkyub1MIIYQQQgghhChsAgIC1JE52t/Fs6NQB4NKlSqVI/Xs2LGDd999l/DwcPW56Ohozp8/z/nz51m8eDG7du2iWrVqObK9V8XFxTFo0CDc3Nz4/PPPc2Ubr9LOEVSmTBnKly+fJ9sVQgghhBBCCCEKq5zK1yvDxP5fhQoV6NixY6bXu3TpEv369SM8PBw7Ozt++OEHTp06xcGDB3n//fcBuHPnDl27diUiIiKnmw3ApEmTuHv3LsuWLZPhYUIIIYQQQgghhDCoUPcMmjRpEo0aNaJRo0aUKlUKX19fKleunKk6xo4dS0xMDGZmZvz99980bdpUXda2bVtcXFz4/PPPuXPnDr/88guTJ09OU8enn35KXFxcprbp4uICwOnTp/n555+ZPHkytWvXzlTbhRBCCCGEEEIIUfgU6mDQlClTsrX+2bNnOX78OADDhg3TCQSl+vTTT1m2bBk3b97k999/5+uvv8bc3FynzIIFCzI1DXyfPn1wcXEhMTGRwYMHU7duXb788sts7YsQQgghhBBCCCEKBxkmlg1bt25V7w8dOlRvGRMTEwYNGgRAaGgohw8fTlMmMjISRVGMvnl4eKjr3b17l8uXL2NhYYFGo1FvK1asAKBp06ZoNBqdtgohhBBCCCGEEKLwKtQ9g7LrxIkTANja2tKwYcN0y7Vu3Vq9f/LkySzlJtLH0tKSYcOG6V127Ngx7t69S7du3ShRogSVKlXKkW0KIYQQQgghhBCiYJNgUDbcvHkTgGrVqhnM6F2zZs006+QEa2trFi9erHfZkCFDuHv3Ll999RVNmjTJdN3+/v4GlwcEBGS6TiGEEEIIIYQQQuQ/CQZlUWxsLEFBQQAZTqtetGhRbG1tiYqKws/PLy+al23Ozs753QQhhBBCCCGEEELkAskZlEXa08Tb2dllWN7W1hZIyfMjhBBCCCGEEEIIkV+kZ1AWxcbGqvctLCwyLG9paQlATExMrrVJ2/Lly1m+fHmW18+oB1NAQADu7u5Zrl8IIYQQQgghhBD5Q4JBWWRlZaXej4+Pz7B8XFwckJLnpyDIaOibEEIIIYQQQgghCiYJBmVRkSJF1PvGDP2KiooCjBtSJoQQQgghhMi82NhYQkNDiY6OJikpKb+bI4QQmJqaYmFhgb29PXZ2dpiYvB7ZeiQYlEVWVlYUK1aM4ODgDGfeCgkJUYNBkphZCCGEEEKInKUoCgEBAYSFheV3U4QQQkdiYiJxcXFERESg0WgoV66cTueS/CLBoGyoVasWx48f5969eyQmJqY7vfytW7fU+2+88UZeNU8IIYQQQohCITg4OE0gKL1rcyGEyEtJSUkoigKkBK6fPHnyWgSE5AyZDS1atOD48eNERUVx4cIFGjdurLfc0aNH1fvNmzfPq+YJIYQQQgjxrxcfH09gYKD6uGTJkjg6OmJqapqPrRJCiBSKohAdHc3Lly+JjIxUA0LVq1fP1yFjr8dgtQKqR48e6v1ly5bpLZOcnMzKlSsBcHR0pE2bNnnRNCGEEEIIIQoF7fydxYoVo1ixYhIIEkK8NjQaDba2tpQvX17NIawoilG5h3OTBIOywd3dnZYtWwKwZMkSTp8+nabML7/8ws2bNwEYO3Ys5ubmedpGIYQQQggh/s1Sc3MC2Nvb52NLhBAifRqNBicnJ/VxeHh4PramkA8TO3HiBPfu3VMfBwUFqffv3bvH8uXLdcoPGTIkTR2///47zZs3JyYmho4dOzJx4kTatGlDTEwMf/75JwsXLgSgevXqfPrpp7myH3nB1dVV53FCQkI+tUQIIYQQQoj/iY+PB1K+aFlaWuZza4QQIn02NjZoNBoURVHPXfmlUAeDFi9ezIoVK/QuO3nyJCdPntR5Tl8wqH79+qxfv553332X8PBwJk6cmKZM9erV2bVrV74niBJCCCGEEOLfJjk5GUiZvlmj0eRza4QQIn0ajQZTU1MSExNJSkrK17YU6mBQTvH29ubq1av8/vvv7Nq1C39/fywsLKhWrRpvvfUWo0ePxsbGJr+bmS3Xr1/Xeezv74+zs3M+tUYIIYQQQgghhBBZpVFS5zgTIhO0g0F+fn6UL18+n1skhBBCCCEKo7t375KYmIiZmRkuLi753RwhhDAoK+es3Pj+LQmkhRBCCCGEEEIIIQoRCQYJIYQQQgghhBBCFCISDBJCCCGEEEIIIYQoRCQYJIQQQgghhBAiT3l4eKDRaPDw8MjvpghRKEkwSAghhBBCCCFEpkVFRTF//ny6dOlCuXLlsLKywtLSkhIlStCoUSPee+89Fi1ahJ+fX343NUcdOXIEjUaj92ZjY4OzszNeXl4sXbqUuLi4DOtLXTejwFhiYiL9+vVTyzdp0oTQ0NCc2SlR6MjU8kIIIYQQQgghMuX06dO8/fbbPH78OM2yoKAggoKCOH/+PMuWLaNUqVI8e/YsH1qZ92JiYvD398ff359du3bx66+/snPnTipVqpStehMSEujXrx9btmwBoEWLFuzevZsiRYrkQKtFYSTBICGEEEIIIYQQRrtz5w6dOnUiIiICgG7dutGnTx+qV6+OhYUFQUFBXLlyhf3793P48OF8bm3u+vDDD/noo4/Uxy9evMDHx4cZM2bg7+/P9evX6datG5cuXcLU1DRL24iLi6NPnz7s3LkTSBlit3PnTmxtbXNkH0ThJMEgYRRXV1edxwkJCfnUEiGEEEIIIUR++vrrr9VA0LJlyxgyZEiaMh06dGDChAkEBgayYcOGPG5h3ilZsiS1a9fWea5t27YMHTqUunXr4uvry7Vr19iyZQt9+vTJdP2xsbH06NGDffv2ASmv67Zt27C2ts6R9ovCS3IGiULjTMAZpp+dTlBMUH43RQghhBBCiAIpKSmJXbt2AfDmm2/qDQRpK1GiBKNGjcqDlr1eihQpwjfffKM+PnDgQKbriI6OxsvLSw0Ede7cme3bt0sgSOQICQYJo1y/fl3ndujQofxuUqYoisIfl/5g9c3VdN7UmV/O/8LL2Jf53SwhhBBCCCEKlMDAQGJiYgCoVq1arm/vxIkTDBw4kEqVKmFlZYWjoyP169fnm2++ITAwUO86P//8MxqNBnNzcyIjI9Msj42NxcrKSk3EfPnyZb311KxZE41Gw9tvv52lttepU0e9n9kk2pGRkXTp0oWDBw8CKUPxtm7dipWVVZbaIsSrJBgkCoXjT45zNfAqALFJsSy/vhzPTZ78duE3QmND87dxQgghhBBCFBAWFhbq/Zs3b+badpKTkxk9ejQtW7Zk9erVPHr0iLi4OMLCwrh8+TI//PADLi4u7N+/P826rVu3BlJm3zpx4kSa5f/884/OLF9HjhxJU+b58+fcvn0bIMNZvtKj/VqZm5sbvV54eDienp4cPXoUgD59+rBx40ad+oTILgkGiUJh4dWFaZ6LSYxhic8SPDd78selPwiLC8uHlgkhhBBCCFFwODk5UbFiRQCuXLnC9OnTSU5OzvHtfPnll8yZMweAypUrM3/+fM6ePcvhw4cZP3485ubmhIWF4eXlxZUrV3TWbdCggTrLlr5Az6vPZVQmNbiUWdrBMmNnEwsLC6Njx46cPHkSgHfeeYd169ZlKpgkhDEkgbQoFKa3ms7CqwvZdm8bSUqSzrKohCgWXl3IupvrGOg6kHffeJciFjJFoxBCCCHEv0VyskJIdHx+NyNPFbWxwMREkyt1jxkzhgkTJgApQZv58+fTrVs3mjVrhru7O5UrV85W/deuXeOXX34BoHbt2hw/fhxHR0d1uYeHBx07dqRr167Ex8czYsQI/vnnH3W5qakpLVq0YM+ePXoDPak9bry9vdmxYwfHjh0jOTkZExOTNGVKlSrFG2+8kel9SEpKYsaMGepjY5JHh4WF0b59e86fPw/AoEGDWLZsmU67hMgpEgwShUI5u3JMaTaF4bWHM//qfHY+2EmyovsLRkRCBHMvz2X1jdUMcR1C/zf6Y2su0zUKIYQQQhR0IdHxNPw+8wl8C7IL37SnmJ1lrtQ9fvx4bty4wdKlSwHw9fVl1qxZzJo1C0gJoHh4eDBgwAC8vLzQaDIXlJo3b57a22jx4sU6gaBUnp6evPfeeyxevJizZ89y7tw5GjVqpC738PBgz549XLhwgcjISOzs7ICUadrPnDkDwBdffMGBAwcICQnh6tWruLm5qeunBpFatWqVqbYHBgZy7do1Jk2axKVLl4CUQFCLFi0yXFc7d9E777wjgSCRq+TIEoWKs70zP7T4gW3dt9G1Slc0pP1gCo8PZ9alWXhu8mSpz1KiE6LzoaVCCCGEEEK8nkxMTFiyZAl///03np6emJnp9jF4/vw569evp1u3bri7u3P//v1M1Z8685arqyuNGzdOt9z777+fZp1U6eUNOnv2LDExMTg4ONCkSROaNGkC6A4Le/HihTrEK6N8QVOmTFETUWs0GkqWLEm7du04efIkNjY2fPLJJ6xduzbjnQadoNnp06d5+vSpUesJkRUSDBKFUiWHSvzU8ie2dN+CZyVPvWVC40KZeWEmnTd3ZsX1FcQkxuRxK4UQQgghhHh9dejQgT179hAcHMzu3buZMmUK3t7eODg4qGXOnz9Py5YtCQgIMKrOuLg47t69C2AwEARQv359NZeOj4+PzrKGDRuqvYG0Az2p91u0aIGpqaka7NEukzpEDLKeLwjAzc2Njz/+2Oh8Py1atFBnLvP19aVdu3Y8e/Ysy9sXwhAJBolCrapjVWa0nsGmbpvoULGD3jIvY1/y8/mf6bK5C2turiEuKU5vOSGEEEIIIQoje3t7OnfuzKRJk9i+fTvPnz9n6dKlFC1aFICAgAC+/fZbo+oKCQlR75csWdJgWXNzc4oVKwbAy5cvdZaZmZnRvHlzQH+gJzUIlPo3NW+QdpkSJUrg6upqsA0ffvgh165d49q1a1y6dIkdO3YwePBgTExMOHXqFB4eHgQGBhqsI5WJiQmrVq2iR48eANy5c4cOHToQHBxs1PpCZIbkDBICqF60Or96/Mqtl7eYc3kOR/yOpCkTFBPET2d/YqnPUt6v8z69XHphYSrTOwohhBBCvO6K2lhw4Zv2+d2MPFXUJv+uUy0tLRk6dChly5bF0zOlF/7mzZtZuHBhpnLgZDbX0Ks8PDzYt2+fmjfI0tKS06dPq8sgpfeRlZWVTt6g1GCQMfmCSpYsSe3atdXHbm5ueHl50aZNG4YMGYKvry/Dhw9n27ZtRrXZzMyM9evX0717d/bu3YuPjw8dO3bk0KFDOj2uhMguCQaJwuPiSnh6GZp8BMWr6S1S06kmf7T9g+tB15lzeQ7HnxxPU+ZF9At++OcHlvgsYUTdEfSo2gNzU5nqUQghhBDidWViosm1ZMoifZ06dcLZ2Rk/Pz9CQkIIDg6mRIkSBtdJ7U0EKbmHDElMTFR7zTg5OaVZ/mreoCJFihAdHY2DgwP169cHUgJXTZo04ciRIxw5coTy5ctz/fp1ION8QYYMHjyYHTt2sGnTJrZv386hQ4do27atUetaWFiwefNmunbtyuHDh7l48SKdO3fm77//Voe+CZFdMkxMFA7JyXBiJpxfArMbwtp+8PAYKIre4q7FXZnbfi6ru6ymWdlmess8i3rGf07/B++t3my5u4XE5MTc3AMhhBBCCCEKnLJly6r3jenpY2lpiYuLC4DOdPH6XLp0iYSEBACd3jmpGjVqhK1tyuzAR44cUXv8pOYLSqWdN+jYsWMo//8dITv5ggCmTZumbmfixImZWtfa2prt27fTtGlTICWhtLe3NzExksdU5AwJBgmjuLq66tyMjWq/Nu7shZcPdB+v8Ib5LeHyOkiM17tavRL1WNBhASs8V9C4tP4Edk8inzDp1CS6be3Gjvs7SEpOyo09EEIIIYQQokCJjo7mxo0bQEpeodT8Phlp3z5lSN/169c5e/ZsuuUWL16cZh1tZmZmNGuW8sNuas8fSNvjRztv0KFDhwAoVqyY3gBTZlSvXp2+ffsCKYGt/fv3Z2p9Ozs79uzZQ8OGDYGUfejVqxfx8fq/uwiRGRIMEoXDP/P1P//8Gmz9AH6rDcdmQJT+5GwNSjVgcafFLO20lAYlG+gt4xfhx8QTE+mxrQe7H+yWoJAQQgghhPjXiYyMpHHjxuzcuVNNuKxPcnIyY8aMISIiAoBu3boZnQPoww8/VHMLjRgxgvDw8DRl/v77b5YsWQKAu7s7jRo10ltXaqDnwoULnDx5Uue5VI0bN8bS0pKQkBBWr14NpOQLym7OIkjpEZRaz/fff5/p9R0cHNi3bx916tQBYO/evfTr14/ERBmVILJHgkHCKNevX9e5pUbMC4we86DFeLBKJ+la5HM49D3MrAU7xkHgHb3FGpVuxHLP5SzquIh6JerpLeMb7ssXx7+g9/be7PPdR7KS/oekEEIIIYQQBc3Zs2fx9vamQoUKjB49mjVr1nDixAmuXLnC0aNH+e2333Bzc2Pp0qVASkBj6tSpRtdfp04dPv30UwCuXLlCgwYNWLRoEefPn+fo0aNMmDABLy8vkpKSsLCwYMGCBenWpZ036NV8QamsrKxo0qQJAGFhYUD28gVpq127Nt26dQNSeh6dOHEi03UUK1aM/fv3U6NGDQC2bt3KoEGDDAbjhMiIJJAWhYNDOWg/GVpOgCvr4Mxc3WFjqRJj4cKylJtLR2g6Ciq3Bq1fBTQaDU3KNKFx6cacenqK2Zdm4xPsk6aq+2H3mXB0Ai5FXRhVbxRtK7TNkV8XhBBCCCGEyC9mZmaULl2aZ8+e8eTJE+bMmcOcOXPSLe/i4sK6deuoVKlSprbz008/ERUVxdy5c7l//z4jRoxIU8bBwYENGzbg5uaWbj3u7u7Y2NgQHR0NpM0XlMrDw0PNKQTZzxek7euvv1ZnE5s6dSr79u3LdB2lSpXi4MGDtGzZkocPH7Ju3Tqsra1ZvHixfMcQWSI9g0ThYmkH7u/D6Avw9jqo1DL9snf/hpXdYX4LuLQGEuN0Fms0GpqXa87armuZ3XY2bzi9ob+akLuMOzKOfjv7ccTviJqQTgghhBBCiILGysqKJ0+ecPLkSaZMmULnzp2pUqUKtra2mJqaYm9vT82aNenXrx9r167Fx8dHzXmTGSYmJsyZM4djx44xYMAAKlSogKWlJfb29ri5uTFx4kTu3r1Lx44dDdZjbm6uJmGG9Hv8aD/v5ORE3bp1M93m9DRq1IgOHToAKcPbzp07l6V6ypUrx6FDh3B2dgZg6dKljBkzJsfaKQoXjSLfTEUW+Pv7qychPz8/ypcvn88tyoanl+HMPPDZCIZmBLMtmRJIevM9sC2eZrGiKBzyO8Tcy3O5E6J/mBmAazFXRrmNokW5FhLFF0IIIYTIprt375KYmIiZmZk6C5UQQryusnLOyo3v39IzSIiybtBrAYzzgZafgnVR/eWiXsDhH2CmK2z/GF7c0lms0WhoV6Edf3n/xS+tf6GqQ1W91VwPvs5HBz9i4J6BnH56WnoKCSGEEEIIIYTIUxIMEiKVfRloNwnG34Cuv0KxavrLJcbCxRUwtzGs7g33D4FWQMdEY0LHSh3Z1G0T/231XyrZV9JbzZXAK4zYP4Ihe4dw7lnWuooKIYQQQgghhBCZJcEgIV5lYQONhsGoc/DOeqjcKv2y9w7Aqp4wrxlcXAUJseoiUxNTOlfuzNbuW5nWYhoVilTQW8XFFxd5b997DNs3jIvPL+b03gghhBBCCCGEEDokGCREekxMoIYnDN4BI49Dvf5gYq6/7IsbsH00/FYbjvwEkYHqIlMTU7yrerOtxzamNp9KObtyeqs4++wsg/cOZsTfI7j84nIu7JAQQgghhBBCCCHBICGMU6Yu9JwH432g1Wdg7aS/XFQgHPkxJa/QttHw4qa6yMzEjB7VerCj5w6+a/odZWzL6K3idMBpBu4ZyIcHPsQnKO2U9UIIIYQQQgghRHZIMEiIzChSGtp+A+Ovg9dMKF5df7mkOLi0CuY2SRlGdveAmlfI3MScPtX7sLPnTr5p/A0lbUrqreLEkxO8s+sdxhwcw83gm3rLCCGEEEIIIYQQmSXBICGywsImZYr5j/6B/n9BFY/0y94/BGt6pwSGLiyHhJiUKkwt6FezH7t77eZL9y8pbp12unqAI/5H6LuzL+MOjzM4Zb0QQgghhBBCCGEMCQYJkR0mJlC9IwzaBh+cBLcBYGqhv2zgLdgxNmUI2eFpEPkCAEtTSwa8MYA9vfbw2Zuf4WSlfwjawccH6b29NxOOTuB+6P3c2iMhhBBCCCGEEP9yGkXRmhNbiHS4urrqPE5ISODu3bsA+Pn5Ub58+fxo1usp4jmcWwznl0B0cPrlTC2gTl9o+hGU+t/rG50Qzfrb61nqs5TQuFC9q2rQ0LlyZz6o9wGVHSrn8A4IIYQQQhQcd+/eJTExETMzM1xcXPK7OUIIYVBWzln+/v44OzsDOff9W3oGCZHTipSCtl+n5BXyngXFa+gvlxQPl1enTEu/sjvc3Q/JydiY2zC09lD29t7L2AZjsbewT7OqgsLuh7vpsa0HX5/4msfhj3N5p4QQQgghhBBC/FtIzyCRJbkRmfzXUhS4dxDOzEnJH2RI8erQ5COo9zaYWwMQER/BmptrWHl9JREJEXpXM9WY0q1qN0bUHUH5IvJeCCGEEKLwkJ5BQoiCRHoGCVFYaDTg0h4GboEPT0P9gWBqqb9s0B3YOQ5+rQWHvoeI5xSxKMIH9T5gT+89jKw7Eltz2zSrJSlJbLm3Be8t3kw5PYWAyIDc3SchhBBCCCGEEAWWBIOEyEulakH32SlDyDy+Ahv9M4gR8xKOzUhJNr3lQ3h2DQdLB0bXH83eXnsZXmc41mbWaVZLVBLZeGcjXbZ04fsz3/M86nku75AQQgghhBBCiIJGgkFC5Ae7EuDxZUpQqNsfUOIN/eWSE+DKWpjfAlZ4w+29OFrYM7bBWPb23stQ16FYmVqlWS0xOZH1t9fTZXMXpp+dTlBMUC7vkBBCCCGEEEKIgkKCQULkJ3MraDAIPjoN726Gau3TL/vwGKzrB3MawbnFOJlY8cmbn7Cn9x4G1hqIpZ6hZ/HJ8ay+uZrOmzrz87mfCY4xMLuZEEIIIYQQQohCQYJBQrwONBqo1g7e3QQfnUkJEKWXVyj4Huz6FGbWgoP/oXhCAp83+pzdvXbTv2Z/zE3M06wSmxTLihsr6Ly5MzMvzCQkNiSXd0gIIYQQQgghxOtKgkFCvG5KvpEydGz8dfCYCLYl9JeLCYHjv8BvdWDzSEqGBfBV46/Y3Ws3/Wr0w8zELO0qiTEs9VmK5yZPZl2cRVhcWC7vjBBCCCGEEEKI140Eg4R4XdmVAI8vYJwPdJ8DJV31l0tOgKt/woJWsNyL0v6X+MZ9Irt67qK3S2/MNGmDQtGJ0Sy6tgjPTZ7MvTyX8PjwXN4ZIYQQQgghhBCvCwkGCfG6M7eC+u/Chydh4FZw6Zh+Wd/jsO5tmP0mZW/sYvKbn7G953a6V+2OiSbtv3tkQiTzrszDc5MnC64sIDI+Mvf2QwghhBBCCCHEa0GCQUIUFBoNVG0DA/6CUWeh4VAwSzuTGAAv78PuCfBrLZz/WcL3dT9ie4/teFXx0hsUioiPYPbl2Xhu9mTJtSVEJ0Tn8s4IIYQQQgghhMgvEgwSoiAqUQO8f4PxN6DtN2BXSn+52FA4MRN+q0PF/d/zY9W+bOm2hc6VOqNBk6Z4WFwYv138jc6bO7Pi+gpiEmNydTeEEEIIIUTBFRUVxfz58+nSpQvlypXDysoKS0tLSpQoQaNGjXjvvfdYtGgRfn5+adYdMmQIGo0mzc3ExARHR0fq1avHqFGjuHz5cq603dfXV+/2NRoNVlZWlC1blo4dO/L7778THp5xSoVKlSqh0WioVKlShmU/+eQTdVsuLi56Xx8hcptGURQlvxshCh5/f3+cnZ0B8PPzo3z58vncokIuMQ58NsHpufD8muGyFZtDk4+4W9KFedcWsP/R/nSLFrMqxvA6w+lTvQ9W6fVCEkIIIYTIR3fv3iUxMREzMzNcXFzyuzmFxunTp3n77bd5/PhxhmVLlSrFs2fPdJ4bMmQIK1asyHBdExMTvvzyS3744Ycst1UfX19fKleubFRZZ2dntm7dSoMGDdItU6lSJR49ekTFihXx9fXVW0ZRFD7++GNmz54NQM2aNTl48CBly5bNdPtFwZWVc1ZufP9Om1lWCFHwmFmCW3+o9w48PAZn5sKdvfrLPjoJj07iUrQyvzb5kNudBjL3xnIO+R1KUzQ4Npjp56azzGcZw+sOp7dLbyxMLXJ5Z4QQQgghxOvszp07dOrUiYiICAC6detGnz59qF69OhYWFgQFBXHlyhX279/P4cOHM6xv3759akAkOTmZ58+fs2vXLubMmUNiYiLTpk2jXLlyfPTRR7myP927d+f7779XH4eEhHDr1i1mzpzJzZs38fPzo2vXrty+fRt7e/ssbUNRFD744AMWLlwIgKurKwcPHqRUqXR6+AuRyyQYJMS/iUYDVVqn3ILupgSFLq8DfcO9Qh7Cns+pYeXA7w2HcN1jFvPubeKo/9E0RV/EvGDaP9NYcm0JI+qOoGe1npibmufBDgkhhBBCiNfN119/rQaCli1bxpAhQ9KU6dChAxMmTCAwMJANGzYYrK969epphld16NCBdu3a0a1bNwAmT57MyJEjMTU1zZF90Obo6Ejt2rV1nmvZsiVDhgyhVatWnDlzhmfPnrFw4UImTJiQ6fqTk5MZPnw4y5YtA6BevXocOHCA4sWL50j7hcgKyRkkxL9VcRfwmgmf3IC234Jdaf3lYsPg5O+4Lu/N7JAY1jaaRPNyzfUWfR79nKlnpuK1xYvNdzeTkJyQizsghBBCCCFeN0lJSezatQuAN998U28gSFuJEiUYNWpUlrbl7e1Ny5YtAQgMDOTixYtZqierzM3NdXoMHThwINN1JCUlMXjwYDUQ1LBhQw4fPiyBIJHvJBgkjOLq6qpza9u2bX43SRjLxglaTYBx16DnAihdR385JQl8NlJnw3DmP/Zl1RsjaVK6sd6iT6Oe8t2p7+i+tTvb728nMTkxF3dACCGEEEK8LgIDA4mJSel1Xq1atVzfnru7u3r/0aNH6v0HDx7wyy+/4O3tTaVKlbC2tsba2pqKFSvSr18/9u5NJ2VCJtWp879r58wmek5MTGTAgAGsXr0agCZNmnDw4EGKFi2aI20TIjtkmJgQhYWZBdR7G+r2A98TKUPIbu8B9OSQf3wat8enWVS0Eufr9GBO7CPOB15KU8wvwo+vT3zNoquL+KDeB3hW8sTUJOe77gohhBBCiNeDhcX/8kfevHkz17dnbv6/1ARJSUkAPHz4kKpVq+ot//jxYx4/fsyGDRt49913WbZsGWZmWf/aq72/2m3JSEJCAm+//TabN28GoEWLFuzevZsiRYpkuS1C5CTpGSSMcv36dZ3boUNpkw2LAkKjgcot4Z11MOYCNHofzG30lw3x5c1js1h65TCLizalvlMtvcV8w3358viX9Nrei72+e0lWknNxB4QQQgghRH5xcnKiYsWKAFy5coXp06eTnJx7137Xrv1vptzUJNNJSUlYWFjg7e3NrFmzOHDgABcvXuTAgQPMnTsXV1dXAFavXs3UqVOztX3tgJcx08YDxMfH06dPHzUQ1KZNG/bu3SuBIPFakZ5BQhRmxapC15+hzUS4sBzOLoSIgDTFNHHhNL64HneNKadreDDHMomr4Q/SlHsQ9oDPjn7GAscFjHIbRdsKbTHRSMxZCCGEEPksORliXuZ3K/KWtROY5M512JgxY9REyl9++SXz58+nW7duNGvWDHd3d6OnbM/IlStX1OFeNjY2NGrUCIAyZcrg6+tLmTJl0qzTrl07PvjgA9577z2WL1/OL7/8wieffIKDg0OW2vDjjz+q9/v06ZNh+YSEBHr27Mnu3buBlETY27Ztw9raOkvbFyK3SDBICJGSV6jlJ9B0NNzYCqdnQ8CVNMU0ShLNbh2kKXC8Qj3mFLHmRvTTNOXuhd5j/JHx1HSqyUf1PsLD2QONRpP7+yGEEEIIoU/MS5ihf1jRv9Zn98E2d5IUjx8/nhs3brB06VIAfH19mTVrFrNmzQKgVKlSeHh4MGDAALy8vDJ1HagoCs+fP2fnzp189dVX6tCwjz/+GCsrKwBsbW2xtbVNtw6NRsMvv/zCqlWriIqK4sCBA/Tu3dvoNoSGhnLz5k1+/PFHduzYAUDTpk3p169fhus+ffqUp09Tro9bt27N9u3b1XYL8TqRn+yFEP9jZgF1+8KIozBkN9T0AtJ+eGuAVo+v8Of1M8yKhBqWxfRWd+vlLT4+/DFv73qbY/7HUBQ9+YmEEEIIIUSBYmJiwpIlS/j777/x9PRMk5Pn+fPnrF+/nm7duuHu7s79+/cN1le5cmU0Gg0ajQYTExPKlCnD+++/T1BQEABdu3blP//5T7rrJyQk4O/vz82bN/Hx8cHHx4enT59SrFjKNeqVK2l/5NS2YsUKdfsajYaiRYvSrFkzduzYgbm5OUOGDGHv3r1G5QzSDnxdu3aNO3fuZLiOEPlBgkFCiLQ0GqjUHN5ek5JXyH0EmKf99UUDtAl8zIZbl5gZHEk1M/3joG8E32DUwVG8u/tdTj05JUEhIYQQQoh/gQ4dOrBnzx6Cg4PZvXs3U6ZMwdvbW2dI1vnz52nZsiUBAWlTERhiYWFB8+bNWbFihRqU0ZaQkMCcOXNo0qQJdnZ2ODs7U6tWLerUqaPeXrx4AaAGlbLCxcWF8ePHY29vb1T5ChUq8NlnnwHw8uVLOnTowK1bt7K8fSFyiwwTE0IYVqwqdJnx/3mFVqTkFQp/olPEBGgf/pK24S/529aWeaXK8UCJTVPV1aCrjDwwkvol6zPKbRSNy+iful4IIYQQQhQc9vb2dO7cmc6dOwMQFxfH2rVr+fTTTwkJCSEgIIBvv/2WxYsX611/3759anJoExMT7OzsKF26tM5MXtpevnxJx44duXDhglHti4mJMbi8e/fufP/99wAkJyfz9OlT9u7dy4IFC7hx4wYeHh6cPn2aGjVqGLW9//73v8TExDB79mxevHhB+/btOXbsGFWqVDFqfSHyggSDhBDGsS4KLcZB01FwY1tKXqGnutPNmwCeUVF0eHCHvbY2zCtRikeapDRVXXpxieF/D6dR6UZ8VO8j3iz9Zt7sgxBCCCEKJ2unlBw6hYm1U75t2tLSkqFDh1K2bFk8PT0B2Lx5MwsXLsRET1Lr6tWrGz1TF8DYsWPVQFCPHj147733qFu3LiVLlsTKykodqlWhQgX8/Pwy7JXu6OhI7dq11cd169bF09MTb29vPD09CQkJoX///pw9exZTU1Oj2jhr1iyio6NZunQpT548oV27dhw7dgxnZ2ej91OI3CTBICFE5piaQ50+ULs3PD6TEhS6tQv434esKdA1KppOUQ/ZZWfLfCcn/E3T5h469+wcQ58NpUmZJoxyG4VbSbc82w0hhBBCFCImJrmWTFmkr1OnTjg7O+Pn50dISAjBwcGUKFEiW3WGh4ezfv16AAYMGMDq1avTLRsSEpKtbbVr146xY8fyyy+/cPHiRZYvX86wYcOMWlej0bBo0SJiY2NZu3Ytvr6+akCodOnS2WqXEDlBcgYJIbJGo4GKTVPyCn18ERp/ABZ2OkXMgO6RUWx/7MeUwGDKJqbtJQRwJuAMA/cM5IMDH3At8FoeNF4IIYQQQuSF1OFfQI7MLnv37l0SEhIADM7udevWLSIjI7O9vYkTJ6r5gqZMmUJ8fLzR65qYmLBixQp69eoFpLS9ffv2BAcHZ7tdQmSXBIOEENnnVAU6T4fx16HDVLAvr7PYHOgVGcVOvyd8G/SSUomJeqs5+eQk/Xf3Z/TB0dwIvpEHDRdCCCGEELklOjqaGzdSruns7e3V2b2yI1HrOjIqKirdcvPnz8/2tgCcnJwYNWoUAH5+fqxYsSJT65uZmbFu3To1n9L169fp2LEjYWFhOdI+IbJKgkFCiJxj7QjNP4axV6DPUijXUGexOdA3IpLdfk+ZGPSSEukEhY76H6Xfzn6MPTSW2y9v5367hRBCCCGEUSIjI2ncuDE7d+4kOTk53XLJycmMGTOGiIgIALp165YjPYOqVaum1rNixQq9+YB27NjB7Nmzs72tVOPHj8fGxgaAn376iaQk/b3d02NhYcHmzZtp27YtABcvXsTT0zNHei4JkVUSDBJC5DxTs5ScQsMPwnt/Q63uoPnf6cYCeCcikt3+AXwRHEKxdIaPHfI7RJ8dffjkyCfcC7mXR40XQgghhBCGnD17Fm9vbypUqMDo0aNZs2YNJ06c4MqVKxw9epTffvsNNzc3li5dCoCDgwNTp07NkW0XK1aMLl26ALB37146duzI5s2buXDhAnv27GH48OH07NmTKlWqZDs/UaoSJUrw/vvvA/DgwQPWrl2b6TqsrKzYvn07zZs3B+DMmTN4eXllONOZELlFEkgLIXKPRgMVGqfcQnzhnwVwcRXEp/xCZKUovBseQe+ISDYUsWOJoz0hemZo2P9oPwceHcCzkicfuH1AFQeZllMIIYQQIj+YmZlRunRpnj17xpMnT5gzZw5z5sxJt7yLiwvr1q3L1GxhGZk3bx4tWrTg8ePHHDhwgAMHDugsr1ChAlu3blWDRjlhwoQJzJs3j/j4eH788UcGDBigd2Y0Q2xtbdm9ezft2rXj/PnzHD16lJ49e7J9+3YsLCxyrK1CGEN6Bgkh8kbRSuD5I3xyHTr+AA4V1EXWisLg8Aj2+j1l7MtQHPR0vVVQ2OO7h57bevLV8a94FP4oDxsvhBBCCCEgpYfLkydPOHnyJFOmTKFz585UqVIFW1tbTE1Nsbe3p2bNmvTr14+1a9fi4+NDw4YNM644E5ydnbl48SKfffYZ1atXx9LSEgcHB+rVq8d3333H5cuXqVWrVo5us3z58gwePBiAmzdvsmnTpizVY29vz759+6hbty4A+/bto1+/fjq5kITICxpF3yBLITLg7++Ps7MzkJJIrXz58hmsIcQrkhLh1g44PRf8z+ositRoWONQhBX29kSY6o9Zm2pM8a7qzYi6I3Au4pwXLRZCCCHEa+ju3bskJiZiZmaGi4tLfjdHCCEMyso5Kze+f0vPICFE/jA1A9eeMHw/DNsPtXqoeYXsFIWRoeHs9X/ChyFh2OlJTpikJLH13la6benG5FOTeRr5NI93QAghhBBCCCEKJgkGCSHyn7M79F0BH1+GpqPB0h4A+2SFj0LD2Ov3lPdDw7DRExRKVBLZdHcTXbd05fsz3/Ms6lkeN14IIYQQQgghChYJBgkhXh9FK0KnH2D8dej0Izim5BVySE7m45CUoNB7oWFY6wsKJSey/vZ6umzuwo///EhgdGBet14IIYQQQgghCgQJBgkhXj9W9tD0o5SeQn1XgnMTAIomJzM+JIw9fk8ZHBaOpZ6gUEJyAmtvraXz5s7899x/CYoJyuPGCyGEEEIIIcTrTaaWF0ZxdXXVeZyQkJBPLRGFiokp1OqecvO/AGfmwPWtFEtOYsLLUAaHhbPEwYG/itgRb6LRWTUuKY5VN1bx1+2/eKfmOwypPQQnK6d82hEhhBBCCCGEeH1IzyAhRMFQviH0WQpjr0CzMWDpQImkZL58GcIu/6f0C4/ATM/kiLFJsSy7vgzPTZ78fvF3QmND877tQgghhBBCCPEakanlRZbI1PIi38VFwKU18M88CPEFIMDUlIWO9mwtYkeiRqN3NVtzWwbWGsjAWgOxt7DPwwYLIYQQIjfI1PJCiIJEppYXQojssCwCTT6AMReh32qo0IwySUl8FxzCDv+n9IyIxFRPrDsqIYr5V+bjudGT+VfmExkfmQ+NF0IIIYQQQoj8I8EgIUTBZmIKb3jDe3vg/UNQuw/lk+A/QS/Z7h9At4hITPQEhSISIphzeQ6emz1ZfG0x0QnR+dB4IYQQQgghhMh7EgwSQvx7lGsIfZbAuKvQfCwVzGz5IeglW58E0CUyCo2eoFBYXBi/X/wdz02eLPNZJkEhIYQQQgghxL+eBIOEEP8+DuWhw39g/A3oPIPKds5MDwxmy5MAOkVG6V0lJC6EXy/8SufNnVl5fSWxibF53GghhBBCCCGEyBsSDBJC/HtZ2kHjETDmAry9lqplm/BzYDAb/QNoF6W/B9DL2JfMOD+DLps7s+bmGuKS4vK40UIIIYQQQgiRuyQYJIT49zMxhZpdYeguGHGUGm/04regUNY/CcAjnaBQYEwQP539ia6bOrPh9gYSkhLyuNFCCCGEEEIIkTskGCSEKFzKukGvhTDuGrXcR/NHeCLrnjyjRXSM3uLPYwKZemYqXTd2ZNOdTSQkS1BICCGEEEIIUbBJMEgIUTjZl4X2k+GTG9RuP415CfasevqMpjH6g0IBsUFMPj0Z7w0d2Hp3K4nJiXnbXiGEEEIIIYTIIRIMEkIUbha24P4+jD6PW6+VLLSqyfKnz3GP0Z9A+klcMN+e+pbu69uw485mkpKT8rjBQgghhBBCCJE9EgwSQggAExOo0RmG7KThkP0sKdmWJc+CaRCrPyj0OD6Uiae/o+e6Vuy5uZ5kJTmPGyyEEEIIIYQQWSPBICGEeFWZetBrAe4fXmB51XdZ+DKaurH6ZxV7mBjO52e/p/fqpqw/9zthcWF53FghhBBCCCGEyBwJBgkhRHrsy6Bp/x1NR11jdf3PmRtthmuc/qDQveRovr+xGI8/WzB2R38O+O4nPik+jxsshBBCCCEKm0qVKqHRaBgyZEiubWPIkCFoNBoqVaqUa9sQeUuCQUIIkRELGzTuw2k58gLrWs/iD0pSM05/oCcROPTyGuOPfoLH2mZMOTmJi88vyjAyIYQQQvwrJSYmsmnTJkaMGEGdOnUoWbIk5ubmODg4UK1aNXr27MmMGTN4+PBhfje1UFMUhe3bt/POO+/g4uKCnZ0dZmZmODo6Urt2bd566y1mzJjBlStX8rxtbdq0QaPRoNFo6Nixo9HreXh4qOtp30xNTXFycqJhw4aMHTuW69evZ1jX5MmT1fWPHDlisOyJEyewt7dHo9FgZmbG6tWrjW7z68QsvxsgRF6KS0zC0sw0v5shCioTEzQ1PPGo4UnrgKscOj6VOWHXuGthrrd4RHIcG+9tYeO9LZSzKU3Xat3wquJFZYfKedxwIYQQQoict337dj799FPu3buXZll4eDjh4eHcv3+frVu38vnnn9O1a1d++uknateunQ+tLbyeP39Onz59OHHiRJplYWFhhIWFcf36dTZu3Mjnn3/OzZs3qVmzZp607dGjRxw9elR9fPDgQZ4+fUrZsmWzXGdycjIhISGEhIRw8eJF5syZw/fff8+XX36Z7fYeOXIELy8voqKiMDMzY82aNfTt2zfb9eYHCQaJQuPWs3AGLz3Ld96udKlTJr+bIwo4TZm6tOv7F23Cn3L8+Pds9zvEEUsz4k00ess/iX7GwqsLWXh1IbWL1carqheelTwpZl0sj1suhBBCCJF933//PZMmTUJRFCCll4aXlxd169alWLFiREdHExAQwLFjx9i5cye+vr7s2rWL8uXLM3/+/HxufeERHx9Phw4duHbtGgD169dn6NChuLm5UaRIEcLDw7l58ybHjh1j165dhIXlbf7LVatWoSgKlpaWJCUlkZiYyOrVq/n8888zVU/q/kHKPj948ICtW7eyZs0akpKS+Oqrr6hatSpvvfVWltt64MABunXrRkxMDObm5qxfv56ePXtmub78JsEgUSjEJSYx7s/LPA+P46M1F+ndoDyTu9WiiJX+Hh1CGMvEviytu86ldXw04ZdWsP/SQnYSyXlrq3TX8Qn2wSfYhxnnZtC8XHO8q3jj4eyBlVn66wghhBBCvC6WLl3Kt99+C0CpUqX4888/8fDw0Fv2rbfe4rfffuPPP/9k4sSJedhKAbBo0SI1UDJ06FAWL16MiYlutphWrVoxcuRI4uLiWLduHY6OjnnWvlWrVgHg5eVFTEwMu3fvZtWqVZkOBr3a26xBgwb06dOHxo0b8/HHHwMwZcqULAeD9uzZQ69evYiNjcXS0pKNGzfi5eWVpbpeF5IzSBQKv/x9h1vPItTHmy760/n345zzfZmPrRL/KhY22Df+kN4jLrCs7Rz2UZ6PX4ZSOT4h3VWSlCSO+R/js2Of4bHBg29PfsvZgLOSX0gIIYQQry0/Pz9GjRoFgL29PSdOnEg3EJTK1NSUAQMGcOXKFbp27ZoHrRSptm3bBoCZmRm//vprmkCQNktLS4YMGULp0qXzpG1nzpzhzp07AAwYMIB3330XAB8fHy5evJgj2xg1ahQVKlQA4Pr16zx79izTdezYsYMePXoQGxuLtbU127ZtK/CBIJBgkCgEkpMVnoTGpHnePySGfgtO89+9t4hPlC/fIoeYmED1jpQdvIf339nNtmIe/BkQyLth4TglJaW7WlRCFFvvbWXY38PouLEjMy/M5F5I2vH3QgghhBD56ddffyU2NhaAH374gWrVqhm9rqOjI97e3ukuf/bsGV9//TVvvvkmTk5OWFpa4uzsTN++fTlw4EC66/n6+qrJf5cvXw7A/v378fb2pnTp0lhaWlK5cmU+/PBD/P39jWrr4cOHGTx4MFWqVMHGxgZ7e3vq1KnDZ599xtOnT9NdTzsRMaTk5Jk6dSr169fH0dFRp40AUVFRrF+/nuHDh+Pm5oaDgwPm5uaUKFGC1q1b8/PPPxMZGWlUm/V5/PgxAMWLF8/RHj+hoaFMmjQJV1dXbG1tcXR0pFWrVqxZs8boOlauXAlA0aJF6dq1Kz169KBIkSI6y7LLxMQEV1dX9bGfn1+m1t+8eTO9e/cmPj4eGxsbdu7cSadOnXKkbflNhomJfz0TEw2z36lP2xol+W77dSLjEtVlyQrMPXKfY3cD+a2fG9VKFsnHlop/ndJ10PSaj2vEZFzPLuLT80s4TQw77Ww5ZGNNbDq/zDyPfs5Sn6Us9VlKTaeaeFXxokvlLpSwKZHHOyCEEEII8T+KoqjDeooUKcLQoUNzrO41a9YwcuRIoqKidJ739/fnr7/+4q+//mLYsGHMnz8fMzPDX2O/+uorfvrpJ53nfH19mT9/Pps2beLo0aO88cYbeteNjY1l6NCh/Pnnn2mW+fj44OPjw7x581i3bp3BwBbA3bt36dixI76+vumW6dq1q04C5VRBQUEcO3aMY8eOMXfuXHbv3p2lpM4WFhZAShLply9f4uTklOk6XnX79m08PT3T7Nfx48c5fvw4p0+fZvbs2QbriI+PZ/369UDKUMLUdvbq1YsVK1awbt06fv755wzfa2Ok1g1gbm58mpD169fz7rvvkpiYiJ2dHbt376Zly5bZbs/rQnoGiUJBo9HQu2F59oxtiXultCdAnyfhdJ11gpWnfdUkeELkmCKlod23mI2/Qcu205ieXJQjj5/wfWAwTWJi0Bg45m69vMXP53+m/cb2jNw/kh33dxCdEJ2HjRdCCCGESOHj40NwcDAALVu2xNbWNkfq3bBhAwMHDiQqKooqVarw66+/snfvXi5cuMCmTZvo0qULAEuWLMkwl8yiRYv46aefaN26NWvXruX8+fMcOHCAQYMGARAYGMh7772nd11FUejTp48aCPL29mbVqlWcPHmS06dP8/vvv1OhQgWioqLo06cP58+fN9iWPn368OTJE8aMGcP+/fs5f/4869ato0aNGmqZxMRE6tSpw9dff82WLVv4559/OHPmDOvXr+ftt9/GxMSEhw8fqsOUMqtBgwbqvr3//vvZ6mUEEB0djbe3N8HBwXzzzTccOXKE8+fPs2jRIsqXLw/AnDlz2Ldvn8F6du7cycuXKSk7UoeHad9/8eIFe/fuzVZbU928eVO9X7FiRaPWWbNmDQMGDCAxMRF7e3v+/vvvf1UgCKRnkChknJ1sWDeiCQuO3Wfm/jskJP3vS3hcYjKTtl3n4M0XzOhTl5L2ksxX5DALG2g0DBoOxfbu33Q/M4fuD4/x3NSU3XY27LS15Y6lhd5Vk5VkTj09xamnp7A2s6ZdhXZ4V/GmcZnGmJqY5vGOCCGEEAVLspJMaFxofjcjTzlaOmKiydnf/q9evareTw0yZFdQUBAjRoxAURTee+89FixYoNMbpEGDBvTq1Yuvv/6aadOm8fvvvzNy5EidgIq2U6dO8f7777NgwQJ1qBZAu3btsLCwYPHixZw5c4ZLly5Rv359nXUXL17Mrl27MDc3Z/v27Xh6euosb9KkCQMHDqRly5Zcv36dcePG6Z2uPZWPjw979uyhY8eO6nMNGzbUKbNs2TJcXFzSrNu4cWP69u3LsGHD6NSpE7dv32bNmjUMGzYs3e3p89FHH7Fq1SqSk5PZvHkzhw4dwtvbm5YtW9K4cWNcXV0xNTX+WjIwMJD4+HhOnz6tM/yqYcOGeHh4UKdOHWJjY5k7d67B4VSpw8AqVapEixYt1Ofbtm1L2bJlefr0KStXrsx2bp7NmzereYnatWtH0aJFM1xn5cqVrFixguTkZIoWLcq+ffto1KhRttrxOpJgkCh0TE00fORRjVYuJRi3/jL3XuhGx4/eCaTTb8f4sVcdPGvLFPQiF5iYQA3PlFvAVUqdmcvQaxsZGhbBbXNzdtrZstvOhhfpdIuNSYxh54Od7Hywk+LWxelSuQveVb2pUbSGzkWPEEIIIVKExoXSen3r/G5Gnjra7yhOVtkfEqQtKChIvV+iRPrD15OTk7lx40a6y2vUqKEO15k3bx5hYWGUK1eOuXPnpjssaMqUKaxYsYInT56wcuVKfvjhB73lypQpwx9//KH3mmjChAksXrwYSBnSpB0MUhSF6dOnA/Dxxx+nCQSlKlq0KDNmzKBLly6cPHmSu3fv6g3mAAwZMkQnEKRPeuumat++Pd26dWPr1q1s3bo108Egd3d3FixYwEcffURCQgKhoaGsWrVKHe5na2tLs2bNeOutt+jfv79Rvb2mTp2qEwhKVa1aNXr06MGff/5pMEgWHBzM7t27Aejfv7/Oe2ViYkL//v35+eef2bFjB6GhoZnOdZQ6tfyWLVv4/vvvAbCxsUn3mHnVsmXLALC2tubgwYNpgob/FjJMTBRatcs5sHNMC4Y0q5RmWUh0Ah+svshnf13RyTEkRI4rUxd6zodx16Dlp9QwteXTkFD+9nvKwoDndIuIxCY5/QTnQTFBrLyxkrd2vEWv7b1Ycm0Jz6IyP0uCEEIIIURGIiL+NzuvoaBBeHg4derUSff25MkTtez27duBlKnFLS0t063TzMyMpk2bAnD69Ol0y/Xp0yfdemrUqIGdnR0ADx480Fl248YN7t+/r9ZhSKtWrdT7htoyYMAAg/XoExgYyN27d9X8RD4+Pmrg7cqVK5muD2D48OFcu3aNoUOHqgmaU0VFRbF//35GjBiBi4tLhkOzNBoN/fv3T3d5as+nly9fEhoaqrfMunXrSEhImXFXe4hYqtTnYmNj+euvvwy2R7tdqTdLS0veeOMNJk6cSHR0NA0aNODvv/+mcePGRtcFEBMTw65du4xapyCSYJAwiqurq86tbdu2+d2kHGFlbsrkbq6seM+dkkXSfmj8dcGfzr8f47xMQS9ym30ZaDcJPrkBXX/B1KkqTWPj+CHoJYcfP+GnF0G0iI7B1EB+oXuh9/jt4m903NiRYfuGseXuFiLjszcuXAghhBAilXYg4dVEz1mRlJTE5cuXAdRhXYZuGzduBDA4PXhGSZZThwlpB7YAnfw/TZs2NdiO1IBSRm2pW7euwbakOnnyJP369aNYsWKULFmS6tWr6wTPFi1aBOj2zMqsGjVqsHTpUoKDgzl16hS//vorAwYMUPP8AAQEBODl5WVw5rbixYtTrFixdJdrJ6h+9TVOtWLFCiBlCKC+RN716tWjdu3aQPZnFbOwsGDYsGE0b97c6HWmTZumHuvffvstM2fOzFYbXlcSDBICaF29BPvGtaJz7dJplvm9jKHvgtP8vO82CUkyBb3IZRa20Gg4jD4P7/wJlVpioyh0jYpm3vNADjx+wufBIdSKi0u3CgWFs8/OMunUJDw2ePDZ0c845n+MhOSEPNwRIYQQQvzbaAcBAgMD0y3n6OiIoig6t8GDB6cp9/LlSxITM98LPzo6/ck0bGxsDK5r8v+zuSYlJek8/+LFi0y3I6O2GJOfZvLkybRo0YINGzaoCZXTExMTk+n2vcrc3JymTZsyfvx4Vq9ejZ+fHwcPHlSHfSUlJfHRRx+lO6mOsa9val2vunnzphp409crKNXAgQOBlEDZw4cPDe8UcO3aNfV27NgxZs+eTdWqVYmPj2fUqFHMmDEjwzpSNWnShJ07d6r7+sknnzB//nyj1y8oJGeQMMr169d1Hvv7++Ps7JxPrckdRW0tmDugAZsuPmGyninoZx++x7G7gczs50bVEnYGahIiB5iYQI3OKbeAK3B6LvhspHhyIgPDIxgYHsF9czN22tmyy86WgHTG18clxbHXdy97fffiZOWEZyVPvKt641rMVfILCSGEKDQcLR052i/t9N3/Zo6WjjleZ7169dT7ly5dynZ92sGC4cOHM3bsWKPW054qPKdot2XHjh1UqlTJqPVKliyZ7rKMEjMfPHiQKVOmAFClShUmTJhAixYtqFChAra2tmr+pEmTJjF16lSj2pMVbdu2Zf/+/dSuXZuXL19y9+5dLl++nCu5crR7+nzyySd88sknBssrisLKlSv57rvvDJZL7UmUqmXLlgwaNIgWLVpw9epVJk6ciIeHh9GJoFu1asXWrVvx9vYmLi6Ojz76CBsbG3VWun8DCQYJoUWj0dCnYXkaV3Zi/PrLnH8UorP8qn8YXWcd5+sub/Buk4ryZVrkjTL1oNcCaP8dnF0E55dCbChVExIZGxLGmJAwLlhZstPOlr9tbYg00d/p82XsS9beWsvaW2upZF8JrypeeFX1opxduTzeISGEECJvmWhMcjyZcmFUu3ZtihUrRnBwMMePHyc6OjrDniKGaA8pUhQlzRf6vKTd68nR0TFP2pI6/Kto0aKcOXMm3aTcGfUYygllypSha9euamLpe/fu5XgwKDk5mTVr1mR6vVWrVmUYDNKnSJEirFy5kgYNGpCYmMinn37KsWPHjF6/Q4cO/PXXX/Tu3ZuEhATee+89rKys6Nu3b6bb8jqSYWJC6OHsZMP6kU35rFMNzEx0Az6xCcl8u+06Q5ef40VEbD61UBRK9mVTAkKf3IAuP4NTFSDlRN4oNo4pQS858tifn58H4hEVjZmB/EK+4b7Mvjwbz02eDN4zmL/u/EVYXFge7YgQQgghCiKNRqMO7QkPD1dzv2SVhYWFOjzp5MmT2W5fdmgHPvKqLamjL9q0aWNwdjbtfEa5qWzZsur93PjR+/Dhw/j5+QEwZswY1q1bZ/A2btw4AO7fv5/l96RevXpqwuvjx49nmCD7Vd7e3qxZswZTU1OSkpJ499132bFjR5ba8rqRYJAQ6TA10TCqTTW2fNScqiXSzpZw5HYgnr8dZ991mblJ5DELW3B/H0ZfgLfXQcUW6iJLBTpFx/DHiyAOPX7CxKCX1I1NP78QwMUXF/nP6f/QZkMbxh8ez8HHB0lIkvxCQgghhEjrk08+wcrKCoCvvvrKqHwuhnTr1g2AW7dusW/fvmy3L6saNGigJlNeuHAhsbG5/6Nvar4kQ8m4L126xD///JPlbaSX+0cf7aBTlSpVsrzN9KQOETM1NeWbb77h7bffNnj7+uuv1aFy2Ukk/fXXX6u5jFKnms+Mt956i6VLl6LRaEhISOCtt95i//79WW7P60KCQUJkoE55B3aOacmgphXTLHsZFc/IVRf4YuNVmYJe5D0TE6jZBYbughFHoE5fMPnf6N+iycm8ExHJmoDn7PR7ygchYZRPSv+CICE5gQOPDzDu8Dja/NWGqaencvnF5UxdRAghhBDi361ChQrMmjULgLCwMFq0aMGJEycMrqMoSrrTjI8dO1adnWvo0KFpcpW+ateuXVy9ejXzDc+AiYkJEydOBFKmnR80aBBxBibsCA8PZ/bs2dnapouLCwAnTpzg3r17aZYHBgaqiZSzqlevXsydOzfD2d+WL1/OwYMHgZT3OKeHiEVFRbF582YgJZ+PoVxLqYoXL07r1q0B2LBhg8H3w5CaNWvSq1cvIKXX1+HDhzNdx6BBg5g3bx4AcXFx9OjRI1NDzl5HEgwSwgjWFqb8p3ttlg9tRAk9U9CvP+9Hl9+Pc+GVHENC5Jmy9aH3Ihh7FZqPAysHncUVExMZFRrG7sd+rHr6jL6RMdhr0k8bFxYXxoY7Gxi4ZyBdNndhzuU5PAp/lMs7IYQQQoiC4P3331dzuDx9+pSWLVvSrl07fvvtNw4ePMilS5c4f/48O3fu5D//+Q916tRh27ZtQEqvEO0E0KVKlWLFihVoNBoCAgJ48803+fDDD9m+fTsXL17kn3/+YdOmTXzxxRdUrVoVLy8vHj9+nCv79cEHH9CzZ08A/vrrL1xdXZkxYwZHjx7l8uXLHDt2jIULF9K/f3/Kli3L5MmTs7W91GTEUVFRtG7dmj/++INTp05x6tQpfv75Z+rVq8eNGzdo2rRplrfh5+fHqFGjKF26NP3792f+/PkcPnyYy5cvc+bMGZYvX07Xrl0ZOnQokDI8bObMmTk+TGzz5s1ERkYC0Lt3b6PXSy0bGhrK9u3bs7z91EAfZK13EMDIkSPVaeajo6Px8vLi7NmzWW5TfpME0kJkgkeNkuwb14qvNl9l3/XnOssev4zmrfmnGN2mGmPauWBuKrFWkQ8cykGHKdDqM7iyDk7PgZD/dd/WAG5x8bgFBvJlYCDHbGzYWboyR4khQdHfu80/0p/5V+Yz/8p86havi1dVLzwreVLUKuPpUoUQQgjx7zR58mTq1avHhAkTePDgAYcOHeLQoUPpltdoNHTq1IkZM2bo5KaBlN4r27ZtY8iQIbx8+ZL58+enO5W3iYkJtrZpUzjkBI1Gw/r16xk7dizz58/n/v37fP755+mWN6Z3iyF9+vRh6NChLFu2jKdPn/Lxxx/rLDc1NWXmzJmEhIRw+vTpLG2jfPnyXLhwgcjISDUXT3ocHBz4448/1F40OSl1mJdGo8lU/b169WL06NEkJyezcuVK3nrrrSxtv379+nTp0oXdu3dz6NAhzpw5Q5MmTTJdz7hx44iOjubrr78mIiICT09PDh06hJubW5balZ/k26oQmeRka8H8dxvy3z51sbXQnS4yWYFZh+7RZ94pHgRG5lMLhQAs7VLyCo25AG+vhYrN0xQxB9pFRzPzwXUO+z5kUrw1DWydDVZ7Negq0/6ZRtsNbRlzcAz7fPcRl5S1LrtCCCGEKNh69uzJ7du32bBhA8OGDaNWrVoUL14cMzMz7O3tqVy5Mt26dePHH3/k/v377NmzJ91Zury9vXn48CE///wzbdu2pVSpUpibm2NtbU3lypXx8vLi119/xdfXlzZt2uTaPpmbmzN37lyuXLnCmDFjqFOnDg4ODpiamuLg4ICbmxvDhg1j48aN3Lx5M9vbW7p0KatWraJly5YUKVIES0tLKlasyMCBAzl16hRjx47NVv1bt27l1q1b/P777/Tt2xdXV1d1f2xtbalQoQJdunTht99+4969e9kelqbPkydP1EBh06ZN0wQDDSlVqhTNm6dcx+7du5fAwMAst+Prr79W70+dOjXL9UycOJFvvvkGgJCQEDp27Jgjx0Je0yiSDEJkgb+/P87OKV8a/fz81GRrhc3j4GjGb7isd3iYtbkpX3d9gwGNK8gU9OL18OQinJkL17dAcvo5rvwdyrCrqjs7E4LwjfTLsNoi5kXoUKkDXlW8aFiqISYa+Z1BCCFE3rl79y6JiYmYmZmpOViEEOJ1lZVzVm58/5ZgkMgSCQb9T2JSMvOP3ue3A3dJTE7779S2Zkmm966rN9eQEPkizB/OLoTzy8HAdPKKuQ3XXb3YUbQYe5+d5mXsywyrLmNbhq5VuuJdxZsqjjk/C4UQQgjxKgkGCSEKEgkGiQJNgkFpXfUPZdz6yzwITJupv5itBT/1rkuHWqXyoWVCpCMuEi6vSektFOJroKCGhOqdOV2jNTsi73PY77BRQ8NqFauFVxUvOlfuTHHr4jnWbCGEEEKbBIOEEAWJBINEgSbBIP1i4pOYtvsmq87on3XpHXdnvulaC1tLyd0uXiPJSXB7N5yeC49PGS5bxo1I9/fZb2vFzod7OPfsHAqGP0ZMNaY0LdsUrypetK3QFmsz6xxsvBBCiMJOgkFCiIJEgkGiQJNgkGGHb73gs41XCYpM23uiUjEbfu3nRoMKMhOTeA09uZASFLq+BZSk9MsVKQuNR/Dsjc7senqSnQ92ci/0XobV25jZ0L5ie7yqeOFe2h1TE9MM1xFCCCEMkWCQEKIgkWCQKNAkGJSx4Mg4vtp8jb9vPE+zzNREw+g21RjdtppMQS9eT2H+8M8CuLDCYF4hzG2g/rso7iO5bZLIjvs72P1wN0ExQRluoqR1SbpU6YJXFS9qONXIwcYLIYQoTCQYJIQoSCQYJAo0CQYZR1EUNpz3Y8qOG0THp+1lUc/Zkd/6uVG5uG0+tE4II8RFwKX/zysUqn/4YwoN1OwKTUeRVN6df56dZceDHRx8fJCYxJgMN1O9aHW8qnjRpXIXStlKbi0hhBDGk2CQEKIgkWCQKNAkGJQ5j4KjGL/+Mhcfh6ZZZm1uyrdetXjH3VmmoBevLzWv0Bx4fNpw2bL1oeloqNWd6OQEDj4+yM4HOzkTcIZkJdngqho0NC7TGK8qXrSv2B5bcwmUCiGEMEyCQUKIgkSCQaJAk2BQ5iUmJTP3yH1+P3iXJD1T0LerWZLpfepS3E6moBevOf8LcGYOXN9qOK+QfTlwHwENB4N1UQKjA9n9cDc7H+zk1stbGW7GytSKthXa4lXFi6Zlm2JmIonXhRBCpCXBICFEQSLBIFGgSTAo6y77hTJ+/WUeBumfgn5677q0lynoRUEQ6gdnU/MKhadfztwW6r8LTT4ApyoA3A25y84HO9n1YBfPo9Pm1XqVk5UTXSp3wauqF7WcakkvOiGEECoJBgkhChIJBokCTYJB2RMdn8gPu26y5p/Hepe/416Bb73ewMZCekKIAiAuAi6t/v+8QvqP6RSpeYVGQ4UmoNGQrCRz/tl5djzYwf5H+4lKSBskfVUVhyp4V/WmS+UulLUrm3P7IYQQokCSYJAQoiCRYJAo0CQYlDMO3XrO5xuvEhQZn2ZZ5eK2zOznhpuzY943TIisSE6CWztT8gr5/WO4bNkG0HQU1OoOpuYAxCbGcsTvCDse7ODkk5MkGRqC9v/eLPUm3lW9aV+xPfYW9jmwE0IIIQoaCQYJIQoSCQaJAk2CQTknKDKOLzdd48BN/VPQj2lbjdFtqmEmU9CLgsTvXEpeoRvbM8grVB4aj4QGg8DaUX06OCaYvb572Xl/Jz7BPhluzsLEAg9nD7yqeNGiXAvM/z/AJIQQ4t9PgkFCiILkzp07JCUlSTBIFEwSDMpZiqLw5zk/pu7UPwW92/9PQV9JpqAXBU3oY/hnAVxcaTivkIVdSl6hxh+AU2WdRQ/DHqr5hZ5EPslwk46WjnhW8sSrqhd1i9eV/EJCCPEv9+DBA+Li4tBoNFSvXh0TE/kBTQjxekpKSuLOnTsAWFpaUqVKFaPWk2CQeG1IMCh3+AZFMW79ZS77haZZZmNhyiSvWvRrJFPQiwIoNvz/8wrNgzADeYU0Jv/LK+TcGLSO9WTl/9i77/iqq/uP4697b/YmYc8QsiAJG5QNguy9iUwV25/aqtXaWltrba2rtq6qdQCiMkQ2KKgIylL2CCOEvTfZO/f+/rhCpXeEcZPcm7yfjwePR+Cc7/f7uX3UJPd9zzkfMzvO72Dp4aWsPLqSrMKsUh/bKKQRA6IGMLDxQBqENHDFKxERETdz+vRpMjIyAKhXrx4hIdo2LCLu6dKlS5w/fx6AatWqUbt27Ru6TmGQuA2FQWWnuMTMW6sP8ua3B+22oO/VtBYvjkhSC3rxTCXF/z1X6OQm53PrtbGeK9R0CJiuP0y9oKSAtSfXsvTQUr4/9T3F5uJSH92yRksGNRlEn8g+hPqG3s6rEBERN5Kdnc2JEycAMJlM1K5dm6CgIK0QEhG3YLFYKCgoIDMzk0uXLl3796ioKHx9b+w9ncIgcRsKg8re9uNXeGzuDo5eyrUZqx7kw8sjm3NXvFrQiwc7sckaCu1bAhaz43mhDf57rpCfbYiTUZDByqMrWXpoKTsu7Cj1sV5GL7rW68qgJoPoWr8rPiaf23gRIiJS0SwWC8eOHSMvL+/avxkMBkwmUwVWJSJiVVJSwv/GLqGhodSte+NdcRUGidtQGFQ+cgqK+dvyfczeZH9bzT13NOTpAWpBLx7uyrH/nivkbOuXT5A1ELrjF1At0u6UE5knWHZkGcsOLeN4lrM291bBPsH0iezDwKiBtKrZCqNBnyKLiHgis9nM8ePHrwuERETcUY0aNYiIiLipoz8UBonbUBhUvr7Ze47fzd/FpRzbFvRRP7Wgb6EW9OLp8jNh+8fww7ulnyvUdNBP5wq1tzvFYrGw++Julh5ayoqjK0gvSC/18fWC6lnPF4oaSOPQxqXOFxER92KxWMjJySErK4u8vDxKSpx0sxQRKSdGoxEfHx8CAwMJCgrCx+fmV6UrDBK3oTCo/F3IKuD383exav95mzGT0cAjPWN4sHsTtaAXz1dSDPuXwoa34NQW53Prtf3pXKHBNucKXVVUUsT60+tZemgpa06sodBsG6r+r8SIRAY2GUjfyL5E+EfcwosQEREREXENhUHiNhQGVQyLxcLsTdYW9HlFtp92tW4Yxr/GtKRRhFrQSyVxYhNsfAv2LS3lXKGGP50rNMHuuUJXZRZm8s2xb1h6aClbzpUSNAEmg4lO9ToxKGoQ3Rt0x8/L71ZehYiIiIjILVMYJG5DYVDFOnwhm8c+28lOBy3o/zyoGaPbqgW9VCJXjv7sXKFsx/N8gn92rlAjp7c8k32G5UeWs/TQUg5nHC61hEDvQO5udDeDogbRtnZbnS8kIiIiIuVCYZC4DYVBFa+oxMxb3x7krdX2W9D3blaLF4YnEaEW9FKZ5GdYA6Ef/wMZJxzPMxitW8c6PAwN2jm9pcViYe/lvSw7tIwvjnzB5fzLpZZRK6AWA6IGMChqENHVom/2VYiIiIiI3DCFQeI2FAa5j20/taA/ZrcFvS+vjGxOj/iaFVCZSBkqKba2pN/4Fpza6nxu/fbQ4UGIH+TwXKGris3F/HDmB5YeWsq3x78lvyS/1FLiw+MZGDWQ/o37UyOgxs28ChERERGRUikMErehMMi9WFvQ72X2JvsrJSbc2Yg/9G+Kv4+pnCsTKWMWy3/PFdq/rPRzhe78JbSaAH4hpd46pyiHVcdXsfTQUn488yMWnP+4NBqM3FnnTgZGDaRnw54EeAfc7KsREREREbGhMEjchsIg9/T1Ty3oL9trQV8jkNfGtKR5/bDyL0ykPFw+Yt0+tv3j0s8VajPJeq5QWMMbuvW5nHN8eeRLlh5eyoErB0qd7+/lT8+GPRkUNYg76tyByaggVkRERERujcIgcRsKg9zXhawCfjd/F9/aaUHvZTTwaK8YftlNLeilEstL/++5QpknHc8zmKDZT+cK1W97w7dPvZzK8sPLWX54OefzbP87+1/V/avTv3F/BjUZRFy1OB3sLiIiIiI3RWGQuA2FQe7NYrHw6Y/H+dvyveQX2W6badOoGv8a3ZKGEdrGIpVYSRHsXWzdQnZ6u/O5De6ADg9B/EC4wVU8JeYSNp3dxLLDy/jm2DfkFtue2/W/osOiGRg1kAFRA6gdWPuGniMiIiIiVZvCIKkwCQkJ1/29qKiItLQ0QGGQOzt0IZvH5u5g18kMm7FAHxN/HpzAqDb1tVJBKjeLBY7/AD/8G/YtA2dn/4Q1gjv/D1qNB9/gG35EblEua06sYenhpWw8vZESS4nT+QYMtKvdjoFRA7m70d0E+QTd8LNEREREpGpRGCQVRmGQ5yoqMfPmqjTeWn0QOx3o6ZNQixeGNyc80Kf8ixMpb5cPW7ePbfsYinIcz/MNgdYT4Y5fQliDm3rExbyLrDiygqWHl7L30t5S5/uafOnRoAeDmgyiQ90OeBu9b+p5IiIiIlK5KQwSt6FtYp5n6zFrC/rjl223stQItrag7x6nFvRSReSlw7aPfjpX6JTjeQYTNBvy07lCbW76MYfTD7Ps8DKWHV7GmZwzpc4P9wunb2RfBjUZREJEglbtiYiIiIjCIHEfCoM8U3ZBMX9dupe5W+y3oJ/YoRFP9VMLeqlCbupcoTt/OldowA2fK3SV2WJm27ltLDu8jK+OfkVWUVap10SGRDIwaiBDoofofCERERGRKkxhkLgNhUGebeWeszy1YLfdFvRNagTy+thWJNYLrYDKRCqIxQLHN8LGf8P+5ZR+rtCD0OqemzpX6KqCkgLWnFjDssPLWHdyHcWWYqfzvQxe9I/qz+SEycRUi7np54mIiIiIZ1MYJG5DYZDnO5+Vz5Of72JN6gWbMS+jgcfujuWX3ZpgMmqbilQxlw5Zt49t/6SUc4VCoc0kuOMXEHpr3wOv5F9h5dGVLD28lF0XdpU6v2v9rkxJmEKbWm20hUxERESkilAYJG5DYVDlYLFY+OSHYzz/xT67LejbRVbjn6Nb0iBcLeilCsq7Alt/Olco67TjeQYTJAy1biGrd/PnCl11LPMYyw8vZ+mhpZzMPul0bvMazbk38V56NOiB0WC85WeKiIiIiPtTGCRuQ2FQ5XLoQjaPztnB7lO2LeiDfL14dnACI1rX00oEqZpKimDPItj4JpzZ6Xxuww7WUCiu/02fK3SVxWJh54WdLD60mGWHlpFfku9wbmRIJFMSpzAwaiA+JnUEFBEREamMFAaJ21AYVPkUlZh5/Zs03l5jvwV9v8Ta/H1YEtXUgl6qKosFjm2wniuU+gVOzxWqFmk9V6jlPeAbdMuPvJJ/hdn7ZzNr/ywyCmzD2qtq+NdgfLPxjIodRbDPzZ9jJCIiIiLuS2GQuA2FQZXXlqOXeeyzHZy4nGczVjPYl1dGtaBbbI0KqEzEjVw6BD+8Azs+haJcx/N8Q6HtZGj/Cwitd8uPyy3KZeHBhczcM5PTOY63rAV5BzE6bjTjm46nRoD+OxURERGpDBQGidtQGFS5ZRcU85cle5i31f65JZM7RvL7fvH4easFvVRxuZdh29Vzhc44nmf0goRh1tVC9Vrf8uOKzEWsPLqS6SnTOXDlgMN53kZvBjcZzKSESTQObXzLzxMRERGRiqcwSNyGwqCqYUXKGZ5asJsruUU2Y9E1g3htTEu1oBcBKC6EvYtgw5twtpSuYI06WUOhuH63da7Q+tPrmZ4ynU1nNzmcZ8DAXQ3v4t7Ee2leo/ktPUtEREREKpbCIHEbCoOqjvOZ+fz28118d8C2Bb23ydqC/hdd1YJeBPjpXKH1P50r9CXOzxVq/NO5Qsm3da7Q7gu7mb5nOt8c+waLk+e1rdWWKYlT6FKviw6DFxEREfEgCoPEbSgMqlosFgsf/3CM55fvo6DYtgV9+8hwXh3dQi3oRX7u4kH48R3Y/ikU257BdY1fKLSZAu0fuK1zhY5mHOWjvR+x+OBiisy2q/muig6L5t7Ee+nbuC/eRu9bfp6IiIiIlA+FQeI2FAZVTQfPZ/Ho3B2knMq0GQvy9eK5IQkMa6UW9CLXyb0MW6fDj+9B9lnH84xekDAcOjwIdVvd8uMu5l3k032fMnf/XLKKshzOqxNYh4nNJjI8ZjgB3gpyRURERNyVwiBxGwqDqq7CYjOvfXOAd787ZLcF/YCkOjw/LJGwALWgF7lOcSHsWQAb34Kzu53PbdQZOjwEsX3BaLylx2UXZvP5gc/5eO/HnM8773BeqG8oY+PGktw0mXC/8Ft6loiIiIiUHYVB4jYUBsnmo5d5bO4OTl6x3f5SK8SXf4xqQZcYtbYWsWGxwNG1sPFtOPCl87nhUdbtYy3Ggn+1W3pcYUkhyw8vZ/qe6RzJOOJwnq/Jl6HRQ5mUMIkGwQ1u6VkiIiIi4noKg8RtKAwSgKz8Iv6ydC+fO2hBP6VTJL/rqxb0Ig5dTIMf3oEds5yfK+TlB82GQJvJ0LAD3MJWTLPFzHcnvmNayjR2XNjhcJ7RYKRPoz5MSZxC04imN/0cEREREXEthUHiNhQGyc99ufsMTy3cTbqdFvQxNYN4bWxLEuqqBb2IQ7mXYcs02PQeZJ9zPrd6LLSeBC3GQWDELT1u27ltTE+ZzpqTa5zO61CnA/cm3csdte/QWWAiIiIiFURhkLgNhUHyv85l5vPEvJ2sTbtoM+ZtMvB47zimdolSC3oRZ4oLIGWBtTX9uVLOFTL5QNNB1tVCkV1uabXQwSsHmb5nOl8c/oJiS7HDec0imjElcQp3N7wbk1Er/URERETKk8IgcRsKg8Qes9nCzI1HeeHL/fZb0DcO55+jW1C/mjoXiTh19VyhLdNg3zJw0ioesJ4t1HoStLwHgm7+rK6zOWf5eO/HfH7gc3KLcx3Oqx9Un8kJkxkSPQQ/L7+bfo6IiIiI3DyFQeI2FAaJM2nnsnhkzg72nrFtQR/s68VzQxMY2lIt6EVuSM5F65lCW2fA5UPO5xq9IH6AdbVQ4+433YksoyCDz1I/45N9n3A5/7LDeeF+4STHJzM2fiyhvtoCKiIiIlKWFAaJ21AYJKUpLDbzr59a0Nv7LjOweR2eH5pEaIB3+Rcn4oksFji23hoK7V0CJQXO54c1gtYTodV4CK59U4/KL85nyaElzNgzgxNZJxzO8/fyZ2TsSCY2m0jtwJt7hoiIiIjcGIVB4jYUBsmN+vHwJX7z2U5Opdt2Sqod4sero1vQKbp6BVQm4sFyL8POObDtI7iw3/lcgwni+lm3kUX3hJs486fEXMI3x79hWso09l7a63Cel8GL/lH9mZwwmZhqMTd8fxEREREpncIgcRsKg+RmZOYX8eySPSzYdsru+H2dG/PbPnFqQS9ysywWOPEjbP0I9iyA4nzn80PqQ+sJ1tVCoTf+fdtisbDp7CampUxjw+kNTud2rd+VexPvpXXN1toKKiIiIuICCoPEbSgMkluxfNcZ/rBwNxl5tofhxtUK5l9jWtKsbkgFVCZSCeRdgV3zrNvIzu9xPtdghJje1tVCMb3B5HXDj9l/eT/TUqax8uhKzBbbg+KvalGjBVMSp9CjQQ+Mhps7u0hERERE/kthkLgNhUFyq85m5PPbz+23oPcxGXmiTyz3d47CqBb0IrfGYoFTW62hUMp8KHLcHQyA4DrQaoJ1xVBYwxt+zMmsk8zcO5OFaQvJL3G8IqlxaGOmJExhQNQAfEw+N3x/EREREbFSGCRuQ2GQ3A6z2cKMDUd5ccV+Cu20oL8zKpxXR7ekXph/BVQnUonkZ8Luedazhc7sLGWywXqmUOtJ1jOGTDd2uPvl/MvM3j+b2ftnk1GQ4XBeTf+ajG82nlGxowjyCbqJFyEiIiJStSkMErehMEhc4cC5LB511ILez4u/DU1kSMt6FVCZSCV0erv1bKHd86Aw2/ncwJrWc4VaT4Twxjd0+9yiXBYeXMhHez7iTM4Zh/OCvIMYHTea8U3HUyOgxs28AhEREZEqSWGQuA2FQeIqBcUl/OvrNP7zvf0W9INa1OVvQxLVgl7EVQqyrYdNb51h3U5Wmqju1tVC8QPBq/RtXkXmIlYeXcm0lGmkXUlzOM/b6M3gJoOZnDCZyNDIGy5fREREpKpRGCRuQ2GQuNoPhy/xuIMW9HVC/Xh1VAs6qgW9iGud3W1dLbRrLhTYrtC7TkB1aDkOWk+G6tGl3tpisbDu1Dqm75nO5rObHc4zYKBnw55MSZxC8xrNb/IFiIiIiFR+CoPEbSgMkrKQmV/EnxfvYeF2+y3o7+/cmCfUgl7E9QpzYe8i62qhEz+WPr9RZ2gzGZoOAm+/UqfvurCLGXtm8M2xb7Dg+NeOtrXacm/ivXSu11lt6UVERER+ojBI3IbCIClLS3ee5umFu8nML7YZi68dzGtjWxJfWy3oRcrE+X3W1UI7Z0N+uvO5/tWgxTjrNrKa8aXe+mjGUWbsmcGSQ0soMhc5nBdTLYYpCVPo27gv3kZtERUREZGqTWGQuA2FQVLWzmTk8cS8naw/eMlmzMdk5Mm+cdzbqbFa0IuUlaJ82LfEulro2PrS5ze407paKGEoeDvvBHgh9wKf7vuUualzyS5yfJh1ncA6TGw2keExwwnwDrip8kVEREQqC4VB4jYUBkl5MJstTFt/hJdXptptQd+xSQT/GNWCumpBL1K2LhywtqffORtybQPa6/iFQvMx1tVCtROdTs0uzObzA5/z8d6POZ933uG8UN9QxsWPY1z8OML9wm/lFYiIiIh4LIVB4jYUBkl5Sj2bxSNztrP/bJbNWIifF38blsTgFnUroDKRKqa4APYvs64WOvJ96fPrtbWuFkocDj6BDqcVlhSy/PBypu+ZzpGMIw7n+Zn8GBo9lIkJE2kQ3ODm6xcRERHxQAqDxG0oDJLyVlBcwqtfHeD9tYfttqAf0rIuzw1JJNRf54uIlItLh2DbTNjxKeRccD7XJxiaj7KuFqrb0uE0s8XMmhNrmJYyjZ0XdjqcZzQY6dOoD1MSp9A0oumt1S8iIiLiIRQGidtQGCQVZeOhSzz+2Q5OZ+TbjNUN9eMfo1vQsYla0IuUm+JCOPCl9dDpQ9+Ck25hANRpaV0tlDQSfIPtTrFYLGw/v51pKdP47uR3Tm/XsW5HpiRO4Y7ad6gDmYiIiFRKCoPEbSgMkoqUkVfEM4tTWLzjtM2YwQBTu0TxeO9YfL3Ugl6kXF05Bts/hm0fQ/ZZ53O9AyFphDUYqtva+h+vHWlX0pixZwZfHP6CYotth8GrmkU0Y0riFO5ueDcmo/7bFxERkcpDYZC4DYVB4g6W7DzNH520oH99bCviattfeSAiZaikGNJWWlcLHfwaLLYHwF+nVhK0mQTNR1sPoLbjbM5ZZu6dyecHPievOM/hrRoEN2BywmQGNxmMn5ff7bwKEREREbegMEjchsIgcRen060t6DccstOC3svIk33Ugl6kQmWchO2fWM8XyjzlfK6Xv/Ww6daToEF7u6uFMgoymJs6l0/3fcrl/MsObxXuF849Te9hTNwYQn3tB0wiIiIinkBhkLgNhUHiTq61oF+RSmGJ7QqETtHWFvR1QtWCXqTCmEvg4DfW1UIHVoClxPn8Gk1/Wi00BgJs28nnF+ez5NASZuyZwYmsEw5v4+/lz8jYkUxsNpHagbVv91WIiIiIlDuFQeI2FAaJO9p3JpPH5u5w2IL++WFJDFILepGKl3kGdnwCW2dCxnHnc02+kDDUulqoUUeb1UIl5hK+Pv4103ZPY9/lfQ5v42Xwon9Uf6YkTCG6WrQLXoSIiIhI+VAYJG5DYZC4q/yiEl79KpX31x6xOz6sVT3+MiSBED+1oBepcGYzHF4NW2dA6hdgdnxANAARMdbVQi2SITDiuiGLxcKPZ39kesp0Npze4PQ23ep3Y0riFFrXbK0OZCIiIuL2FAaJ21AYJO5uw8GLPD5vJ2fstKCvF+bPq6NbcGdUhJ0rRaRCZJ+HHZ9at5FdsR/mXmPygaaDrKuFIruA0Xjd8L5L+5ieMp2Vx1ZidnJ4dYsaLbg38V66N+iO0WB0OE9ERESkIikMErehMEg8QUZuEX9anMKSnfZb0D/QNYrf3K0W9CJuxWyGo2utq4X2LQVzkfP54VHQeiK0vAeCal43dDLrJB/t+YhFBxeRX2IbDF/VOLQxUxKmMCBqAD4mHxe8CBERERHXURgkbkNhkHiSxTtO8cdFKWTZaUHftE4Ir49tSWwttaAXcTs5F2HnbGswdOmg87lGL4jrD20mQ1SP61YLXc6/zOz9s5m9fzYZBRkOb1HTvybjm41nVOwognyCXPMaRERERG6TwiBxGwqDxNOcSs/j8c928MNh21bUPl5Gft83nskdI9WCXsQdWSxwbIM1FNq7GEoKnM8Pa/jTaqHxEFLn2j/nFuWyIG0BM/fO5EzOGYeXB3sHMzpuNPc0vYcaATVc9CJEREREbo3CIHEbCoPEE5nNFj5cd4RXVtpvQd8lpjqvjGxB7VC/CqhORG5I7mXYNdcaDF3Y73yuwQSxfa2rhaJ7gtG6JbTIXMSKIyuYvmc6aVfSHF7ubfRmcJPBTE6YTGRopMtegoiIiMjNUBgkbkNhkHiyvaczeXTudg6cy7YZC/X35u/DkhjQvI6dK0XEbVgscGKTNRTasxCK85zPD6kPrSdAq/EQWv+nW1hYd2od01KmseXcFoeXGjDQs2FP7k28l6QaSS58ESIiIiKlUxgkbkNhkHi6/KISXlmZyofr7HctGt66Hs8OVgt6EY+Qlw6751mDoXMpzucajBB9t3W1UExvMHkBsOvCLqanTGfV8VVYcPyrUbva7ZiSMIXO9TqrLb2IiIiUC4VB4jYUBkllsf7gRR7/bCdnM+23oP/XmJa0bxxeAZWJyE2zWODUNtg2A3bPh6Ic5/OD61hXCrWaANUaAXAk4wgf7fmIJYeWUOSkk1lMtRimJEyhb+O+eBsVGouIiEjZURgkDkVGRnLs2DG7Y926dWPNmjUufZ7CIKlM0nMLeXpRCst32R4oazDAL7s14bFesfh4Ge1cLSJuKT8TUuZbVwud2VHKZAM0ucu6WiiuH5i8uZB7gU/2fcJnqZ+RXWS7pfSqOoF1mJQwiWHRwwjwDnDhCxARERGxUhgkDkVGRpKens6jjz5qd2zy5MkufZ7CIKlsLBYLi3ac4plFe8gqsG1Bn1A3hNfGtCRGLehFPM/pHbDtI9g1DwqznM8NrAmt7rF2IwuPIrswm3kH5vHx3o+5kHfB4WWhvqGMix/HuPhxhPtpNaGIiIi4jsIgcSgyMhKAo0ePlsvzFAZJZXXySi6Pf7aTH4/YtqD39TLyVL94JnZQC3oRj1SQDXsWwNaP4JTjA6OvadzNuloofgCFBgPLDy9nWso0jmYedXiJn8mPodFDmZQwifrB+tkoIiIit09hkDikMEjEdUrMFt5fe5hXv0qlqMT2W2SXmOr8Y1QLaoWoBb2IxzqbYl0ttHMuFGQ4nxsQAS2TofVkzBFRrD6xmmkp09h1YZfDS0wGE70je3Nv4r3Eh8e7uHgRERGpShQGlYHz58+zadMmNm3axObNm9m8eTOXLl0CYNKkScyYMeOG73Xs2DHeeOMNli9fzokTJ/D19aVJkyaMHj2ahx56iICAsjtLIDIykoKCAl544QVOnz5NSEgI7dq144477iiT5ykMkqpgz+kMHpu7w24L+rAAawv6/klqQS/i0QpzYe9i69lCJ34ofX6jztBmMpb4gWy7spdpKdP4/uT3Ti/pWLcj9ybeS/va7dWBTERERG6awqAy4OyXspsJg5YuXcr48ePJzMy0Ox4bG8vy5cuJjo6+lTJL5egA6Xbt2jF79myaNGni0ucpDJKqIr+ohJdW7Gf6+qN2x0e0rs+zg5sRrBb0Ip7v/H7raqEdsyA/3flc/2rQYhy0nkSat4kZe2bwxeEvKLbYnjl2VbOIZtybeC+9GvbCZDS5tnYRERGptMri/bda4/xMw4YN6d27901ft337dsaMGUNmZiZBQUE8//zzbNiwgVWrVjF16lQADhw4wIABA8jKKuXgyls0ZcoUVq1axblz58jJyWH79u1MmDCBzZs307NnzzJ7rkhl5+dt4s+DEvj4vvbUCvG1GZ+/7ST9Xl/L5qO2ZwyJiIepGQ99X4DHU2H4B9ZVQI7kXYEf3oa37yBmwcM8H5TAF4PmM6HZBPy9/O1esvfSXp747gkGLRrEZ6mfkV+cX0YvRERERMS5Kr8y6M9//jPt2rWjXbt21KpVi6NHj9K4cWPgxlcGde3albVr1+Ll5cX3339Phw4drht/5ZVXePLJJ68979lnn7W5x+OPP05BQcEN1/3II48QExNT6ryJEyfy8ccf8+qrr/Kb3/zmhu9fGq0MkqooPbeQpxemsHy3bQt6owEe6NqER3vF4OetT/xFKo2Laf9dLZR7yflc31BoPpqMpJHMubKTWftncTnfcVAc7hfOPU3vYUzcGEJ9Q11cuIiIiFQW2iZWDm42DNq0adO1c3l+8Ytf8O6779rMMZvNJCYmsm/fPsLCwjh//jze3tdvKQkKCiInJ+eG61y9ejXdu3cvdd769evp3Lkzw4cPZ/78+Td8/9IoDJKqymKxsHD7KZ5ZvIdsOy3om9QI5OWRLWjTqFoFVCciZaa4APYvt54tdOS70ufXa0t+q2QW+3kzI3U2J7NPOpwa4BXAyNiRTGg2gdqBtV1Xs4iIiFQK2ibmhhYtWnTt6ylTptidYzQamThxIgDp6emsXr3aZk52djYWi+WG/9xIEARQvXp1gJsKmkTEMYPBwPDW9fnykS60jwy3GT90IYeR727g+eV7yS8qqYAKRaRMePlC4nCYtAR+vR06PwaBNR3PP7UFv2W/YcySP7DMO5ZXkh6maXhTu1Nzi3OZuXcm/eb34+l1T3PwysEyehEiIiIiVgqDbtO6desACAwMpE2bNg7ndevW7drX69evL/O6rvrxxx+B/7aeFxHXaBAewOwH7uR3fePxNl1/EL3FAu+vPaKzhEQqq/Ao6PUs/GYvjP4YonsBDhpSFGZh2jaDvkueZO7ps7zfaAQdarWzO7XYUsySQ0sYtmQYD696mG3ntpXZSxAREZGqTWHQbdq3bx8A0dHReHl5OZwXHx9vc42r7N+/n9zcXLv//rvf/Q6A5ORklz5TRMBkNPB/3Zuw7FddaF7f9ryPIxdzGP2fjTy7ZA+5hY47DImIhzJ5Q7PBMH4+PLITuv4Wgus4nG44s5M71/yL97auYG5Qa/rVugOjwf6vYt+d/I5JKyYx4YsJfHv8W8wWc1m9ChEREamCHKcXUqr8/HwuXrwIUOqevWrVqhEYGEhOTg4nTpxwaR1z5szhn//8J127dqVRo0YEBgZy4MABvvjiC4qKinjqqafo2rXrTd3z5EnHZxsAnDlje4CuSFUVVzuYBf/XkffXHuFf3xygsPi/b9osFpix4Sjf7j/PSyOa06FJRAVWKiJlplojuOuP0O33kPaV9dDptK/AXohTlEOz3Yt4GfhV7WZ8VLcxizJTKSgptJm648IOHln9CI1DGzMlYQoDogbgY/Ip+9cjIiIilZrCoNvw83btQUFBpc6/GgZlZ2e7tI4ePXqwb98+tm/fztq1a8nNzaV69er079+fBx98kN69e9/0Pa8eTiUiN8bLZOT/ujfh7mY1+e3nu9h+PP268eOXcxn3/g+Mv7Mhv+/XlCBfffsVqZRMXhDf3/on4xRs/wS2zYRM+x+yNDi7lz+e3cuDPgHMatyK2eZLZBbbrvY9knGEZzY8w1vb32JCswmMjB1JkE/pv3uIiIiI2KN3I7chPz//2tc+PqV/Sufr6wtAXl6eS+vo1q3bdWcSiUjFia4ZzOe/7Mi0dUf4x1epFBRfvyrgkx+Os3r/BV4a0ZzOMdUrqEoRKReh9aD776DrE3BwlXW1UOqXYLE9XD68MJeHU9dzr8HAgjpRfBToy9li2w+Pzued59Wtr/LervcYHTea8c3GU91f30tERETk5igMug1+fn7Xvi4stF3a/b8KCgoA8Pf3L7OaXKW0rWxnzpyhffv25VSNiGcxGQ1M7RpFz6Y1+d38XWw+euW68VPpeYz/8EfGtW/AU/2bEuLnXUGViki5MJogtrf1T+YZ2PGpNRhKP24zNcBiYfzpQ4wBVgSHMq1mbQ6abT9EyirK4sOUD5m5dyaDmwxmcsJkIkMjy/61iIiISKWgMOg2BAcHX/v6RrZ+XW3vfiNbyipaaWcgiUjpomoEMfeBDny08Sgvr0gl739azc/edII1qRd4YXgS3eOctKgWkcojpI51pVDn38Dh1dZQaP9yMF9/yLw3MCgrg4FZGaz192N69Vps8bLY3K7IXMT8tPksSFtAr0a9mJIwhaQaSeX0YkRERMRTqZvYbfDz8yMiwnoYbGkHLl+5cuVaGKTzeESqDqPRwJROjVnxaBfujAq3GT+Tkc/k6Zt5Yt5OMnKLKqBCEakQRiNE94TRM+E3+6DXX6wt6/+HAeial8/0E8f49PRZeubm221ib8HC18e+JvmLZO5deS9rT67FYrENj0RERERAYdBta9asGQAHDx6kuNhx6+j9+/df+7pp06ZlXpeIuJdGEYHMuv9O/jokgQAfk83451tPcve/vuObvecqoDoRqVBBNaHzo/DwVpi0FBJHgJ2OYc0LCnnt3HkWnzzNiKxsvB1kPZvPbubBVQ8yculIlh1eRpFZQbOIiIhcT2HQbercuTNg3QK2detWh/O+++67a1936tSpzOsSEfdjNBqY0CGSlY92pXO07YGv57MKuH/mFh6ds50rOaWfQyYilYzRCI27wshp8Jv90Pt5iIixmda4qJhnL15mxYlT3JueQZDZTvt64MCVAzy19ikGLhjIp/s+JbfItkuZiIiIVE0Kg27T0KFDr309ffp0u3PMZjMzZ84EICwsjB49epRHaSLiphqEB/Dxfe15YXiS3Rbzi3ac5u5/fc+KlLMVUJ2IuIXACOj4MDy8GaZ8Cc3HgMn3uik1S0p47EoGXx0/xWOXr1DDwQrl0zmneXHTi/Se35t/7/g3V/Kv2J0nIiIiVYfCoNvUvn17unTpAsCHH37Ixo0bbea8+uqr7Nu3D4BHHnkEb291DhKp6gwGA+PaN+Srx7rSLbaGzfjF7AJ++clWHp61jUvZBRVQoYi4BYMBGnWE4e/B4/uh70tQ4/rt5sEWC/dmZLHixGn+cuESkYX2t4VlFGTw7s536f15b57/4XmOZBwpj1cgIiIibshgqeKnC65bt46DBw9e+/vFixf57W9/C1i3c91///3XzZ88ebLNPbZv306nTp3Iy8sjKCiIP/zhD/To0YO8vDzmzJnDe++9B0BsbCxbtmy5rguZpzp58uS1g7BPnDih7mMit8FisfD51pM8t2wvWfm2n+xHBPrwlyEJDEiqg8Fg7+hYEalSLBY4uRm2zoCUBVB8fet5M7A6wJ9poSHs8vO1e4urOtbtSHJ8Ml3qd8Fo0GeEIiIi7qgs3n9X+TBo8uTJfPTRRzc839H/XEuXLmX8+PFkZmbaHY+NjWX58uVER0ffUp0VLSEh4bq/FxUVkZaWBigMEnGVc5n5/GHBblbtP293vG9Cbf46NJEawc7f3IlIFZKfAbs+g60fwbnd1w1ZgK1+vkwPDeH7AH+nt2kQ3ICxcWMZGjOUEJ+QMixYREREbpbCoDLgqjAI4NixY7z++ussX76ckydP4uPjQ3R0NKNGjeLhhx8mICDAFSVXCIVBIuXDYrGwaMcpnl2yl4w8260eYQHePDsogSEt62qVkIj8l8UCp7dZQ6Hdn0NRznXDB7y9mREawpdBARQ7+d7h7+XPoKhBjIsfR3Q1z/wAS0REpLJRGCRuQ9vERMrW+ax8/rQohZV77Lea79W0Fn8flkjNEL9yrkxE3F5BljUQ2vYRnN5+3dAZk4nZIcEsCA4kw2Ryeps7at/BuKbj6F6/Oyaj87kiIiJSdhQGidtQGCRS9iwWC8t2neHPS/Zw2U6r+RA/L54ZlMCI1vW0SkhE7Duz07paaNdnUJh17Z/zDAa+DAxgVkgwqb4+Tm9RN7AuY+LHMDx6OGF+YWVcsIiIiPwvhUHiNhQGiZSfS9kFPLNkD8t3nbE73j2uBi8MT6JOqPMzQUSkCivMgT0LYct0OLXl2j9bgG2+vswKCWJVYAAlToJlX5MvA6IGkByfTFx4XDkULSIiIqAwSNyIwiCR8vfl7jP8aXEKF7NtVwkF+3rx9ICmjGnXQKuERMS5U1thw1uwdzFYSq7981mTic9CgpgfHMTlUraQta7ZmuSmydzV8C68jd5lXbGIiEiVpjBI3IbCIJGKcSWnkL8s3cOiHaftjneJqc4Lw5OoX81zD6wXkXKSfhx+eNd6tlBh9rV/LjDAisBAZoUEsdfXeffCmgE1GRM3hpGxIwn3Cy/rikVERKokhUHiNhQGiVSsr/ee4+mFuzmfVWAzFuhj4vf9m3JP+4YYjVolJCKlyM+wniv047uQeeraP1uAnb4+zAoJ5utA513IvI3e9Gvcj+T4ZBKqJzicJyIiIjdPYZC4DYVBIhUvI7eI55btZf62k3bHO0RF8NKI5jSM0CohEbkBJUXWc4U2vAlnd103dMFkZF5wMPOCg7jo5XwLWYsaLUiOT+buRnfjbdIWMhERkdulMEjchsIgEfexOvU8T83fzdnMfJsxf28Tv+sbx8QOkVolJCI3xmKBo2ut5wqlrbxuqAj46qcuZLv8nG8hq+5fndGxoxkVN4rq/tXLsGAREZHKTWGQuA2FQSLuJTO/iL8v38eczSfsjrePDOelkc1pXD2wnCsTEY92IRU2vgU750LJ9dtS9/j4MCskiC+DAilysoXMy+hF70a9SW6aTPPqzXXIvYiIyE1SGCQVJiHh+v3/RUVFpKWlAQqDRNzJ2rQL/H7+bk6l59mM+XkbeaJ3HFM6NcakVUIicjOyz8PmD6x/ci9dN3TJaGR+cBBzQ4I47+Xl9DYJEQkkN02mb2RffEw+ZVmxiIhIpaEwSCqMwiARz5FdUMyLX+7jkx+O2x1v3TCMl0e2ILpmUDlXJiIerygPds6Gjf+GSwevHwK+DfBnVmgw2/z8nN4m3C+cETEjGBM3hlqBtcqwYBEREc+nMEjchraJibi/DYcu8rv5uzhx2XaVkI+Xkd/cHcv9nRvjZTJWQHUi4tHMZjiwwrqF7Nh6m+H9Pt7MDglmeWAABUbH32NMBhM9G/YkuWkyrWu21hYyEREROxQGidtQGCTiGXIKinllZSozNhy1O96ifiivjGpBbK3g8i1MRCqPU9usodCeRWApuW4o3WhkQXAgc0KCOVPKFrL48HjGxY+jf+P++Hk5X1kkIiJSlSgMErehMEjEs2w6cpknP9/J0Uu5NmM+JiO/7hnNL7o1wVurhETkVqUfhx/ehW0zoTDruqESYE2AP7NCgtnk7zzoCfUNvbaFrG5Q3TIsWERExDMoDBK3oTBIxPPkFZbw6lepfLj+CPa+8yfUDeEfo1rQtE5I+RcnIpVHfgZs/Qh+fBcyT9kMp3l7MzskiGVBgeQ52UJmNBjp0aAHyfHJtKvdTlvIRESkylIYJG5DYZCI59p67Aq//Xwnhy/k2Ix5GQ081COah3pE4+OlVUIichtKiqxbxza8AWd32QxnGA0sCgpiTkgQJ729nd4qOiyacfHjGBg1kADvgDIqWERExD0pDBK3oTBIxLPlF5Xwr28O8P73hzHb+SkQXzuYf4xqQWK90PIvTkQqF4sFjq6FDW9B2kqb4RJgnb8fs0KC2RDg7/RWwT7BDIsextj4sTQIblBGBYuIiLgXhUHiNhQGiVQOO06k89t5O0k7n20zZjIa+L9uTfhVz2h8vUwVUJ2IVDoXUq1t6XfOgZICm+Ej3l7MDg5mcXAQuUbH28IMGOhavyvJ8cl0qNtBW8hERKRSUxgkbkNhkEjlUVBcwpurDvLOd4cosbNMKLZWEK+MbEGLBmHlX5yIVE7ZF2Dz+7D5A8i9ZDtsMLA4OJA5IaEc9XYeRkeGRDIufhxDoocQ6B1YVhWLiIhUGIVB4jYUBolUPimnMnhi3k72n82yGTMaYGrXKB7rFYtfKW/MRERuWFEe7JxtXS106aDNsBnY+NMWsrUB/jj7pTXQO5AhTYYwLn4ckaGRZVWxiIhIuVMYJG5DYZBI5VRYbObtNQd569uDFNtZJRRVI5BXRragTaNqFVCdiFRaZrP1PKENb8GxdXanHPfyYk5IEItCQskyOP/1tVPdTiQ3TaZzvc4YDToMX0REPJvCIKkwCQkJ1/29qKiItLQ0QGGQSGW093Qmv/18J3tOZ9qMGQxwX6fGPN47Dn8frRISERc7tQ02vmXtRGYpsRnONRhYFhTIrGrhHCrlW1CD4AaMjRvL0JihhPiElE29IiIiZUxhkFQYhUEiVU9RiZn/fHeIN1YdpLDEbDMeGRHAyyNb0L5xeAVUJyKVXvpx+PE/sPUjKLTdvmoBNvn5MqtaOGv8vLH9LvVf/l7+DG4ymHHx42gS1qTMShYRESkLCoPEbWibmEjVceBcFr+dt5OdJzNsxgwGmNQhkif7xhHg41UB1YlIpZefYQ2EfnwXMk/ZnXLKy8Tc0FDmh4aRaSl2ers76txBcnwy3ep3w2TU6kYREXF/CoPEbSgMEqlaikvMfLDuCP/8+gCFxbafvzcI9+elEc3p2KR6BVQnIlVCSZF169jGN+HMTrtT8gwGvggMYFaNOhyg0Ont6gXVY0zcGIbHDCfUN7QMChYREXENhUHiNhQGiVRNB89n8+TnO9l2PN3u+D13NOSp/k0J8tUqIREpIxYLHF1nPVfowAr7U4Ctfr7MqlGPb71KKHHSh8zP5MeAqAGMix9HXHhcGRUtIiJy6xQGidtQGCRSdZWYLUxff4R/fJVKfpHtKqF6Yf68OCKJLjE1KqA6EalSLhyAH/4NO2ZDSYHdKWdNJj6LqM3nwQFcMdufc1WbWm1Ijk/mroZ34WVUqC0iIu5BYZC4DYVBInLkYg6/+3wXm45etjs+tl0D/jCgKSF+3uVcmYhUOdkXYPMHsPl9yL1kd0qBAVaEhDOrZl32Ftt2Svy5WgG1GBM3hhGxIwj30yH5IiJSsRQGidtQGCQiAGazhZkbj/LSilTyimxbQNcJ9ePvw5PoEVezAqoTkSqnKA92zoGN/4ZLaXanWICdfn7MqhfL15Ysiu20r7/Kx+hDv8b9SG6aTLOIZmVUtIiIiHMKg8RtKAwSkZ87fimX383fxcbD9j+RH9G6Ps8MbEZogFYJiUg5MJsh7SvY8CYcW+dw2nmTiXn1Ypnna+BScbbTW7as0ZLkpsn0atQLb6O+l4mISPlRGCRuQ2GQiPwvs9nCrE3HeeGLfeQU2n7SXjPYl+eHJXF3s1oVUJ2IVFmnt8OGt2DPQnCwCqgIWFmzEbMjarIr/5zT29Xwr8GouFGMih1FdX91UBQRkbKnMEjchsIgEXHk5JVcnlqwm7VpF+2OD2lZl2cHJVAt0KecKxORKi39BPz4Lmz9CAqzHE5LCQ5nVoNmrCg4Q5G5yOE8L6MXfSL7kByfTPMazcuiYhEREUBhkLgRhUEi4ozFYmHu5hM8v3wfWQXFNuPVg3z429BE+ibWqYDqRKRKy8+AbTPhh3ch86TDaZe8fPk8uh2fkcX5gitOb5kYkUhy02T6RPbBx6SgW0REXEthkLgNhUEiciPOZOTx1ILdrEm9YHd8QPM6PDc4gYgg33KuTESqvJIi2LvYeq7QmR0OpxUBq6LaMTs4kG2Zh53eMtwvnJGxIxkdO5pagdoSKyIirqEwSNyGwiARuVEWi4X5207x3NI9ZObbrhIKD/ThL4MTGNi8DgaDoQIqFJEqzWKBY+utodCBFU6n7qsdz+y60XyRmUpBSYHDeV4GL3o26klyfDKtarbS9zYREbktCoOkwiQkJFz396KiItLSrC1bFQaJyI04l5nP0wtT+Gaf/cNZ+yTU4q9DE6kZ7FfOlYmI/OTCAfjh39b29MX5DqelB9dmfmwH5uaf5Eyu8wOn48PjSY5Ppl/jfvh56fubiIjcPIVBUmEUBomIK1gsFpbsPM2fl+whPdf2YNawAG+eHZTAkJZ19Um6iFScnIuw+QPY9D7k2j8MH6DYO4Dvmt3NLO9iNl3a7fSWYb5hDI8Zzti4sdQJ0nlpIiJy4xQGidvQNjERuR0Xsgr406IUVuw5a3e8V9OaPD8siVoh+hRdRCpQUZ51ldDGf8OlNCcTDaTF9WR29TosO7+JvOI8hzONBiN3NbiL5KbJtK3VVsG3iIiUSmGQuA2FQSJyuywWC8t3n+GZxXu4nFNoMx7i58WfBjZjZJv6erMkIhXLbIa0r2DjW3B0rdOpGfXbsiiqNbOv7OZU9imnc6PDoklumsyAxgMI8A5wZcUiIlKJKAwSt6EwSERc5VJ2AX9esodlu87YHe8eV4O/D0uibph/OVcmImLH6e3WlUIpC8BS4nBaSbVGrE3oz6ziC2w8t8npLYN9ghkePZyx8WOpH6zfqURE5HoKg8RtKAwSEVdbkXKWPy5K4WK2bYeeIF8vnh7QlLHtGmiVkIi4h/QT8OO7sPUjKMxyPM8vjMMtRjA7yJ8lx78htzjX4VQDBrrV78a4puPoUKeDvt+JiAigMEjciMIgESkLV3IKeW7ZXhZut7+1onN0dV4YnkSDcG2nEBE3kZ8J22Zag6GME47nGb3JThzG4npxzD69hmOZx5zetnFoY8bFj2Nwk8EEege6uGgREfEkCoPEbSgMEpGytGrfOf6wcDfnMm1XCQX6mPh9/6bc074hRqM+NRcRN1FSBHsXw4Y34cwOp1PNjbuzoVkvZmXsY+0p52cQBXkHMSR6COPix9EopJHr6hUREY+hMEjchsIgESlrGXlF/G3ZXuZtPWl3/M6ocF4e0YKGEVolJCJuxGKBY+thw1tw4Evnc2smcLx1MrMNmSw6tJTsomyn0zvX60xyfDKd6nXCaDC6sGgREXFnCoPEbSgMEpHysib1PE8t2M2ZjHybMX9vE0/2jWNSh0itEhIR93MxzXrY9M7ZUGz7PeyaoFrktp3C0ohazDq0mMMZh53etlFII8bGjWVI9BCCfYJdXLSIiLgbhUHiNhQGiUh5yswv4oUv9jF7k/3zONpFVuPlkS1oXF3naoiIG8q5CJs/gE3vQ+5Fx/O8A7C0SObHmM7MOrWaNSfWYMHxr+r+Xv4MbjKY5PhkosKiXF+3iIi4BYVB4jYUBolIRViXdpHfzd/FqfQ8mzFfLyNP9I7j3s6NMWmVkIi4o6I82DXXulro4gEnEw3QdCCnWo5jblYq89Pmk1mY6fTWd9a5k+T4ZLrW74rJaHJt3SIiUqEUBonbUBgkIhUlu6CYl77cz8c/2O/E06phGK+MbE50TW2dEBE3ZTbDwa+th00fdX6ANPXbkdf+AZb7wqzUOaRdSXM6vV5QPcbGjWVYzDBCfUNdWLSIiFQUhUHiNhQGiUhF23joEr+bv4vjl3Ntxny8jDzWK5apXRrjZdIhqyLixk7vgI1vQcoCsJQ4nhfWCMudD7KlXjNmH1rEt8e/pcTJfD+THwOiBpDcNJnYarGur1tERMqNwiBxGwqDRMQd5BYW88rKVGZsOIq9n2Yt6ofy8sgWxNXWKiERcXMZJ+HHd2HrR1DgZEuYXyi0vZezicOYe3oN8w/M50rBFae3blurLclNk+nRoAdeRi8XFy4iImVNYZC4DYVBIuJONh+9zJOf7+LIxRybMW+TgV/fFcMvuzfBW6uERMTd5WfC9o/hh3cgw/6h+QAYvSFpJAXtH+DLvOPM2jeLfZf3Ob117cDajIkbw4iYEVTzq+biwkVEpKwoDBK3oTBIRNxNXmEJ//w6lQ/XHcFs5ydbQt0QXhnZgmZ1Q8q/OBGRm1VSDHsXWbeQnd7ufG5UDywdHmJnSHVm7Z/N18e+pthS7HC6j9GHfo37kdw0mWYRzVxbt4iIuJzCIKkwCQkJ1/29qKiItDTrAYYKg0TEnWw7foXfztvJoQu2q4S8jAYe6hHNQz2i8fHSKiER8QAWCxzbYA2FUr8EJ63mqdkMOjzE+SbdmXd4MfNS53Ep/5LT27eq2Yrk+GR6NuqJt9HbtbWLiIhLKAySCqMwSEQ8SX5RCa+vSuM/3x2yu0oovnYwr4xsQVJ9ddoREQ9yMc3aln7nbCjOdzwvqBa0f4DCVuP56sIWZu+bza6Lu5zeuqZ/TUbFjWJk7Eiq+1d3ceEiInI7FAaJ29A2MRHxBLtOpvPbebtIPZdlM2YyGvhltyh+3TMGXy9TBVQnInKLci7C5g9h03uQe9HxPO8AaHkPdHiQFHMus/bNYsXRFRSZixxfYvSmT2QfkuOTSaqRVAbFi4jIzVIYJG5DYZCIeIqC4hLe+vYgb685RImdZUIxNYN4ZVQLWjYIK//iRERuR1E+7JprXS10MdXJRAPED4COv+ZijSbMPzCfz1I/43zeeae3T6qexLj4cfSJ7IOPyce1tYuIyA1TGCRuQ2GQiHialFMZ/PbzXew7Y9uy2WiAqV2ieOzuWPy8tUpIRDyM2QwHv4YNb8LRtc7n1m8HHR6mKK4vq06uYfa+2Ww7v83pJRF+EYyMHcnouNHUDKjpwsJFRORGKAwSt6EwSEQ8UWGxmXfWHOKt1WkUldj++IuqEcgrI5vTplF4BVQnIuICp3dYVwrtWQBmxx3FCGsEdz4IrcazL/sEs/bP4ovDX1BoLnR4iZfBi16NenFP03toUaMFBoPB9fWLiIgNhUHiNhQGiYgn2382kyfm7STllO0qIYMB7u3UmCd6x+Hvo1VCIuKhMk7Cj/+BrTOgwPZ73TV+odBmCtzxC674+DM/bT5zU+dyNues09s3DW9KctNk+jXuh6/J17W1i4jIdRQGidtQGCQinq64xMx/vj/M69+kUVhithmPjAjgpRHNuSMqogKqExFxkfxM2P4x/PAOZJxwPM/oDYkjoOPDFNdsypoTa5i1fxabz252evtqvtUYETuCMXFjqB1Y27W1i4gIoDBI3IjCIBGpLNLOZfHE57vYeSLd7vikDo14sm88gb5e5VuYiIgrlRTDvsWw4S047fyMIKK6Q4dfQXRPDqSnMXv/bJYdWkZ+ieN29iaDibsa3sXYuLG0q91OW8hERFxIYZC4DYVBIlKZFJeY+XDdEV79+gCFxbarhBqE+/PS8OZ0jK5eAdWJiLiQxQLHNsDGtyD1S8DJW4EaTaHDQ9B8NBkl+Sw6uIjZ+2dzKvuU00dEhUYxJm4Mg5oMItgn2LX1i4hUQQqDxG0oDBKRyujQhWye/HwXW49dsTuefEdDnuoXT7CfdzlXJiJSBi4ehB/+DTtmQbHjVT8E1YL2U6HtfZT4hfL9ye+ZtX8WP5z5went/b38GRg1kLHxY4mtFuvi4kVEqg6FQeI2FAaJSGVVYrYwY8NRXlm5n/wi21VC9cL8eWF4El1ja1RAdSIiZSDnEmz5EDa9BzkXHM/z8odW91i7kEU04XD6YWbtn8WSQ0vIK85z+ojWNVszNn4svRr2wtukQF1E5GYoDBK3oTBIRCq7oxdzeHL+LjYduWx3fEzbBjw9sCkhWiUkIpVFUT7smmttTX8x1clEA8QPgI6/ggZ3kFWUzZJDS5ibOpcjGUecPiLcL5wRMSMYFTuKOkF1XFu/iEglpTBI3IbCIBGpCsxmCx//cIyXVuwnt7DEZrx2iB8vDE+iR3zNCqhORKSMmM1w8BvY+CYc+d753HptoePDED8Ii9HEprObmLN/DqtPrKbEYvt98yqjwUj3+t0ZEz+GO+vcidFgdPGLEBGpPBQGidtQGCQiVcmJy7n8bv4uNhy6ZHd8eOt6/HlgAqEBWiUkIpXMmZ3WDmR7FoC52PG8sIbW7WMtk8EvlLM5Z5mfNp/PD3zOxbyLTh8RGRLJ6LjRDG4ymFDfUBe/ABERz+exYVBUVBQAv/nNb3j44YfL+nFSDhQGiUhVY7FYmLXpOC98sZ/sAts3RDWCfXl+aCK9E2pXQHUiImUs4xT8+C5snQEFmY7neQdA4ghoOwXqtqbIUsyq46uYu38uW85tcfoIP5Mf/aP6MzZuLE0jmrq2fhERD+axYZCPjw8lJSV89913dO7cuawfJ+VAYZCIVFWn0vP4/fxdrE2z/0n34BZ1eXZwAuGBPuVcmYhIOSjIgm0fww/vQMZx53NrN4e290LSSPAN5uCVg8xJncPSQ0vJLc51emnzGs0ZGzeW3pG98TX5uvAFiIh4Ho8Ngxo2bMipU6fYtGkTbdq0KevHSTlQGCQiVZnFYmHelpP8dflesvJtVwlVD/Lhr0MS6Zekw1FFpJIqKYZ9i61byE5vcz7XJwiSRllXC9VpQU5RDssOLWNO6hwOph90emk132oMixnGqNhR1A/W75siUjV5bBg0atQoFixYwPTp05k4cWJZP07KgcIgERE4m5HPHxbu5tv95+2OD0iqw1+GJFA9SJ9qi0glZbHA8Y3ww9uw/wtwcmg0AHVbW0OhxBFYvAPYem4rc1Pn8s2xbyi2OD6TyICBrvW7MiZuDJ3qddKB0yJSpXhsGPTtt9/Sq1cvWrRowaZNm/D21gGbnk5hkIiIlcViYcG2U/xl6R4y7awSCg/04dnBCQxqXgeDwVABFYqIlJOss7D9Y9g6s/QtZL4h0HyMNRiqlcCF3AvMT5vPvAPzOJ9rP2C/qn5QfcbEjWFo9FDC/MJcV7+IiJvy2DAI4Omnn+aFF17g7rvv5oMPPrj2QsQzKQwSEbne+cx8nl6Uwtd7z9kd792sFn8blkjNYL9yrkxEpJyZS+DQt7BlOhz4Eixm5/Prt7eGQgnDKDZ5s+bEGuakzuHHMz86vczH6EPfxn0ZFz+OxOqJrqtfRMTNeGwY9NxzzwEwf/58du/ejclkolOnTjRv3pxq1aphMpmcXv/MM8+UdYlSioSEhOv+XlRURFpaGqAwSETkKovFwpKdp3l2yR6u5BbZjIf6e/PnQc0Y1qqeVgmJSNWQccq6WmjbTMg85XyuXyi0SLYGQzXiOJxxmHmp81h8cDFZRVlOL02ISGBM3Bj6Ne6Hn5dCdxGpXDw2DDIajdf90muxWG7ql+CSklL2HkuZUxgkInLjLmQV8OclKXyx+6zd8Z7xNXl+WBK1Q/WGRUSqiJJiOPg1bJkGaV8DpbwFadjRGgo1HUwuZr448gVz9s8h9Uqq08tCfEIYFj2M0XGjaRjS0HX1i4hUII8Og26H2VzK0lIpd9omJiJSui92n+FPi1K4lFNoMxbs58WfBjZjVJv6WiUkIlVL+nHrSqFtH0O2/dD8Gv9waJkMbSZjiYhm54WdzEmdw1dHv6LIbLsC8+c61e3E2PixdKnXBZPR+U4EERF35rFhkFQ+CoNERG7M5ZxC/rxkD0t3nrY73jW2Bi8OT6JumH85VyYiUsFKiuDACuvZQoe+pdTVQpFdrKuF4gdxqSiLhQcX8lnqZ5zJOeP0srqBdRkVN4ph0cOI8I9wXf0iIuVEYZC4DYVBIiI3Z+Weszy9MIWL2QU2Y0G+Xvyhf1PGtW+gVUIiUjVdPgLbPoLtn0DOBedzA6pDq3ugzWRKwhqx9tRa5uyfw/rT651e5m30pndkb8bGjaVFjRb6fisiHkNhkLgNhUEiIjcvPbeQ55buZcF2+4eodoqO4MXhzWkQHlDOlYmIuIniQkhdbl0tdOS70udHdYc2UyB+AMdzzvBZ6mcsPLiQzMJMp5fFh8czJm4M/Rv3J8Bb33NFxL0pDBK3oTBIROTWfbv/HE8t2M25TNtVQgE+Jn7fL57xdzTCaNSn1iJShV06BFunw45ZkHvJ+dygWtBqPLSeRF5wTVYcWcGc1DnsvbTX6WXB3sEMiR7C6LjRNA5t7MLiRURcp1KEQZcvX2b69Ol88803pKSkcPnyZQDCw8NJTEykV69eTJkyhfDw8PIsS26SwiARkduTkVfE88v38tmWk3bH72gczssjm9MoIrCcKxMRcTPFBbBvqXW10LF1pUw2QHRP62qh2L7svryPOalzWHFkBYVm28P8f+6OOncwLm4c3Rp0w8vo5br6RURuk8eHQf/5z3944oknyM3NBawt5q8r5qd9uwEBAbz66qs88MAD5VWa3CSFQSIirvHdgQs8NX8XpzPybcb8vU080SeOSR0a4WW6vc6cIiKVwoUDsHUG7PgU8tOdzw2uC60nQOuJpPsGsejgIuamzuVktv0Q/qqaATUZFTuKkbEjqe5f3WWli4jcKo8Og1588UWefvrpawFQaGgorVq1onbt2gCcPXuW7du3k5GRYS3MYOCFF17gySefLI/y5CYpDBIRcZ2s/CJe+HI/s348bnc8oW4Izw9LomWDsPItTETEXRXlwd7F1tVCJ35wPtdghJje0GYK5uierD/zA3NS57D25FosTjqYeRm86NWoF2PixtCmVhsdOC0iFcZjw6CUlBRatWpFSUkJderU4ZVXXmHUqFF4e3tfN6+4uJh58+bx29/+ltOnT+Pl5cX27dtJSEgo6xLlJikMEhFxvfUHL/K7+bs4eSXPZsxggHvuaMhv+8QT6u9t52oRkSrq3F7raqGdc6Agw/nckPrQZhK0msBJQwnzDsxjQdoC0gvSnV4WHRbN2LixDGwykEBvbd8VkfLlsWHQL3/5S9577z1q1KjB5s2badiwodP5J06coF27dly4cIEHHniAd955p6xLlJukMEhEpGzkFBTz0or9zNx4zO549SAf/jigGUNa1tWn1CIiP1eYC3sWWFcLndrifK7BBHH9oM0UCiI78dXxb5iTOoddF3Y5vSzQO5BBUYMYEzeG6GrRLixeRMQxjw2DYmNjOXToEK+++iqPPvroDV3zr3/9i8cff5zo6GgOHDhQtgXKTVMYJCJStjYfvczTC3dz4Fy23fGOTSL469BEmtQIKufKREQ8wNnd1lBo12dQmOV8blhDaG1dLbS38BJzU+fyxeEvyC+xPcvt59rWasvY+LHc1fAuvI1asSkiZcdjw6DAwEDy8/PZuHEj7du3v6FrNm3axJ133klAQADZ2fZ/EZaKozBIRKTsFZWYmbbuCK99k0ZeUYnNuI/JyC+6RfFQj2j8vE0VUKGIiJsryIaU+dYW9ae3O59r9IL4AdBmChl1W7HkyFLmps7lWKb9lZpX1fCvwYjYEYyMGUmtwFouLF5ExMpjw6CQkBBycnJYu3YtHTt2vKFrNm7cSKdOnQgKCiIzM7OMK5SbpTBIRKT8nLySy1+W7uXrvefsjjcMD+C5IQl0j6tZzpWJiHiQ09utq4V2fw5FOc7nVmsMbSZjbpnMD5kHmbt/LmtOrsFsMTu8xGQwcVfDuxgTN4b2tdtrK6+IuExZvP8ulz61V88IWrVq1Q1fc3VuaecLiYiIVHb1qwXw/sS2vD+xLfXC/G3Gj1/OZfL0zTz46VbO2mlRLyIiQN1WMPgNeHw/DPgn1EpyPPfKEfjmzxj/2YyOGz7g9cYjWTl8BVOTphLuF273khJLCV8f+5r7v7qfIYuH8Om+T8kqbYuaiEgFKZeVQY899hivv/46wcHBrFu3jqQkJ994sXYf69SpE9nZ2TzyyCP885//LOsS5SZpZZCISMXILSzm9VVpfLj2CMVm2x/hgT4mftM7jkkdGuFlKpfPfEREPJPFAqe2WlcLpcyHYttOjteJiIE2kylqPoqvz29lbupctp3f5vQSfy9/BkYNZEzcGOLC41xYvIhUJR67TezYsWPEx8dTWFhIUFAQf/rTn5gyZQoRERHXzbt06RLTp0/n+eefJyMjAz8/P/bv36/VQW5IYZCISMVKPZvFHxftZvPRK3bHm9UJ4flhibRqWK2cKxMR8UB56dbDprdOh/N7nc81+UKzIdB2CqnBEXyW+hlLDy8lr5QwqVXNVoyJG8Pdje7Gx+TjutpFpNLz2DAIYObMmUyZMuW/DzYYaNy4MTVr1sRgMHDu3DmOHDmCxWLBYrFgMBiYMWMGEyZMKI/y5CYpDBIRqXhms4XPt53khS/2cSW3yGbcYIDk9g15sk88oQHqdCMiUiqLBU5sgi3TYM9CKClwPr9GPLSZQlbTASw9vZa5qXM5nHHY6SXhfuGMiBnBqNhR1Amq48LiRaSy8ugwCGD58uX84he/4PTp0/8t4KeD1X5eRt26dXnvvffo379/eZUmN0lhkIiI+7iSU8iLX+5n7pYTdserB/nw9ICmDG1ZTweaiojcqNzLsHOOdbXQxQPO53r5QcJwLG0ms9kb5qTO5dvj31Jise0EeZXRYKRb/W6MjRvLnXXvxGjQ1l4Rsc/jwyCA4uJiFi5cyDfffENKSgqXL18GIDw8nMTERHr16sXQoUPx9tYnmO5MYZCIiPvZcvQyTy9MIfWc/QNLO0RF8NehiUTXDCrnykREPJjFAsc2WFcL7VsCJYXO59dKhDaTORfdg/nHv+LzA59zIe+C00sahTRidOxohkQPIdQ31IXFi0hl4LFh0PHjxwEICgoiPNz+6fviWRQGiYi4p6ISM9PXH+FfX6eRV2T7ibS3ycAvujbh4bui8fM2VUCFIiIeLOcS7PgUts6Ay4ecz/UOgMQRFLWewOqSdOakzmXz2c1OL/Ez+dGvcT/Gxo+lWUQz19UtIh7NY8Mgo9GIwWDgzTff5MEHHyzrx0k5UBgkIuLeTqXn8Zcle/hq7zm74w3C/XlucCI94muWc2UiIpWA2QxH11q3kO1bBmbbc9uuU7s5tL2XQw3bMPfIMpYcWkJOUY7TS5pXb86Y+DH0ieyDr8nXhcWLiKfx2DAoMDCQ/Px8fvjhB9q1a1fWj5NyoDBIRMQzfLP3HH9esodT6fa73PRLrM0zg5pRJ9S/nCsTEakkss/D9k+sq4XSjzmf6xMESaPIaTmW5bnHmb1/NgfTDzq9JMw3jGExwxgdO5r6wfqdW6Qq8tgwKDY2lkOHDrF+/XruvPPOsn6clAOFQSIiniO3sJg3vz3I+98fpths+2M/0MfEY3fHMrljJF4mHWAqInJLzGY4vNq6Wmj/F+Dk8GgA6rbG0mYy22pFM/fQYr4+/jXF5mKH0w0Y6FyvM2Pjx9KpbidMRm31FakqyuL9d7n8xte7d28A1q1bVx6PExERkZ8J8PHid33j+eKRLrSPtD27L6ewhL8t38egt9az7fiVCqhQRKQSMBohuieM+QQe2wM9/gihDRzPP70Nw9Jf0+bjsbycbebrbm/xcMuHqRVQy+50CxbWnlrLQ6seYsDCAUxPmc6VfH3PFpFbUy4rg9LS0mjVqhVBQUFs3bqVevXqlfUjpYxpZZCIiGeyWCx8vvUkL3y5n8s5th1xDAYY264hv+sbR1iATwVUKCJSiZhL4OAq62qhAyvAYnY+v357ittM4rvQCOYeXMjGMxudTvcx+tC3cV/GxI0hqXoSBoPBhcWLiLvw2G1iAEuWLGH8+PGEhoby0ksvMXLkSHx89Eump0hISLju70VFRaSlpQEKg0REPNGVnEJeXrmf2ZtO2B2PCPThD/2bMrx1Pb25EBFxhYxTsP1j2DYTMk85n+sXCi2SORLfm88ubGbxwcVkFWU5vaRZRDPGxo2lb+O++HvpHDiRysRjw6C77roLgGPHjnHkyBEMBgM+Pj7ExMRQrVo1TCbH+10NBgOrVq0q6xKlFAqDREQqp63HLvP0whT2n7X/JuOOxuE8PyyR6JrB5VyZiEglVVIMaV9ZVwulfQ2U8nasYUdyW93Dl35ezDk4n/2X9zudHuITwtDooYyOG02jkEauq1tEKozHhkFXW8uDdXn6jTAYDFgsFgwGAyUlpRy+JuVO28RERCqPohIzM9Yf5V/fHCC30PZnrrfJwNQuUfzqrhj8fXRgqYiIy6Qft64U2jYTss85n+sfjqXFOHY16cics+tZeXQlRaW0tO9YtyNj48bStX5XHTgt4sE8Ngzq3r37bS0xX716tQurEVdQGCQiUvmcTs/jL0v3sHKP/Tck9av589yQBO6Kt3+4qYiI3KKSIkj90rpa6NC3pc+P7MLlFqNZaMzns7T5nM457XR6ncA6jIodxfCY4UT4R7ioaBEpLx4bBknlozBIRKTy+nb/OZ5ZvIeTV/LsjvdJqMWfByVQN0xnUoiIuNzlI7DtI9j+CeRccD43oDolLZNZ1yCROafWsP7UeixOtp15Gb3o3ag3Y+PH0rJGS50JJ+IhPDYMOn78OABBQUGEh9u2tBXPozBIRKRyyyss4c1v03h/7WGKSmx/VQjwMfFYr1gmd4rE22SsgApFRCq54kJIXQ5bpsGR70ufH9WdE4lD+Kz4IgsPLyGjIMPp9LhqcYyJH8OAxgMI8A5wUdEiUhY8Ngy6embQm2++yYMPPljWj5NyoDBIRKRqSDuXxdOLUth05LLd8fjawTw/LIk2jaqVc2UiIlXIpUPWLWTbP4U8+9+PrwmqRX7Lcayo0ZC5J74h5VKK8+neQQyJHsLouNFEhUa5sGgRcRWPDYMCAwPJz8/nhx9+oF27dmX9OCkHCoNERKoOi8XCgm2neP6LfVzOKbQ7Z1z7BvyubzxhAT7lXJ2ISBVSXAD7llpXCx1bX8pkA0T3JCWuF3PyT7Li2EoKSgqcXnFH7TsYEz+GHg164GX0cl3dInJbPDYMio2N5dChQ6xfv54777yzrB8n5UBhkIhI1ZOeW8hLK1KZvem43fHwQB/+0L8pI1rX0zkUIiJl7UIqbJ0BO2ZBfrrzucF1SW8xmsVh4cw9vpITWSecTq8ZUJORsSMZGTOSGgE1XFayiNyasnj/XS6b/Hv37g3AunXryuNxIiIiUgbCAnx4YXgS8/+vI/G1g23GL+cU8sS8nYx57wfSzmVVQIUiIlVIjTjo+wI8vh+G/QcaOPnQPes0YeteY9LyZ1mWF8Q78ffRvX43DNgP7s/nnuftHW/T+/PePPHdE2w+uxn1HRKpXMplZVBaWhqtWrUiKCiIrVu3Uq9evbJ+pJQxrQwSEanaikvMzNhwlH99fYCcwhKbcS+jgaldo/j1XTH4+5gqoEIRkSro3F7r2UI750IpB0gTUp9TzUcwL9CHBce+4krBFafTo8OiGRM3hoFRAwnyCXJh0SJSGo/dJgawZMkSxo8fT2hoKC+99BIjR47Ex0fnCngqhUEiIgJwJiOP55bu5cuUs3bH64X589yQBHo2rVXOlYmIVGGFubBnAWyZDqe2OJ9rMFEY24evGrVkTvpudl7Y6XR6gFcAg5oMYkzcGGKqxbiwaBFxxGPDoLvuuguAY8eOceTIEQwGAz4+PsTExFCtWjVMJsefGBoMBlatWlXWJcpNUhgkIiI/t3r/eZ5ZksKJy3l2x3s3q8WfBydQL8y/nCsTEanizuyyrhbaNQ8KS9nCG9aQfYmDmOtdwhcnV5NXbP97+lVtarVhbNxYejbsibfJ24VFi8jPeWwYdLW1PHDDe00NBgMWiwWDwUBJie3yc6lYCoNEROR/5RWW8NbqNN77/jBFJbY/7wN8TDzaK4YpnRrjbSqXYwtFROSqgmxI+dy6WujMDudzjV5kxvVhSZ0Y5l7aytHMY06nV/evzoiYEYyMHUntwNquq1lEAA8Og7p3735bXUVWr17twmrEFRQGiYiIIwfPZ/HHRSn8cPiy3fH42sH8bWgibSPDy7kyEREB4PR2ayi0+3MoynE61VItkh+a9mauIYfVZ9ZjtpgdzjUZTPRo0IMx8WO4o/Yd6iwp4iIeGwZJ5aMwSEREnLFYLCzcfornl+/jUk6h3Tlj2jbg9/3iqRaoMwRFRCpEfibs/gy2zIBzu53PNXpzNr4P82rUY/75TVzKv+R0emRIJGPjxzKoySBCfEJcV7NIFaQwSNyGwiAREbkRGblFvLRyP7M3HcfebxzVArx5qn9TRrWpr0+QRUQqisUCp7ZaVwulzIdSzgoqiohmVWwX5hRfZOtF5wdO+3v5079xf8bGjyU+PN6VVYtUGQqDxG0oDBIRkZux7fgV/rgwhb1nMu2Ot4usxt+GJhFXO7icKxMRkevkpcOuudZg6MI+53NNvhyIu5vPqlVj6fnN5BbnOp3eskZLxsSPoXej3viYtCpU5EZVqjDo5MmTnD17ltzcXNq1a4e/v7qLeBKFQSIicrOKS8x8tPEY//wqlZxC2+YQXkYD93VpzCM9Ywjw8aqACkVE5BqLBU78aA2F9iyEkgKn07NrxLE0qh1z849zKPOo07nhfuEMjxnOqNhR1A2q68KiRSonjw+DsrKyePnll5kxYwanT5++9u+7d++mWbNm1/4+Z84cFixYQGhoKO+//355lSc3QWGQiIjcqrMZ+Ty3bA9f7D5rd7xemD/PDk7g7ma1yrkyERGxK/cy7JwDW6bBpTSnUy1efmyJ68mcIF++vbCDYkuxw7lGg5Gu9bsyNm4sHep2wGhQp0kRezw6DEpLS6N///4cPnz4uvbyBoPBJgw6evQo0dHRWCwWvvvuOzp37lweJcpNUBgkIiK3a3XqeZ5ZnMKJy/bPpri7WS2eHZxAvTCtHhYRcQsWCxxbb10ttG8JlNhvEHDV+drNmN8ggc9zDnE+76LTuQ2CGzAmbgxDo4cS6hvqyqpFPF5ZvP8ul+g1Pz+fAQMGcOjQIQICAnjyySdZtmyZw/mRkZH06NEDgCVLlpRHiSIiIlLOesTV5OvHuvGru6LxNtkeHv313nP0evU73v3uEEUljlsZi4hIOTEYILIzjPwQfrMP7v4rhEc5nF7z7F7+b/M8VqTt558BzbijWlOHc09kneAfW/5Bz3k9+dP6P7Hn4p6yeAUi8pNyCYPeeecdDh48SGBgIGvXruXFF1+kf//+Tq/p168fFouFjRs3lkeJIiIiUgH8vE083juOLx/pSoeoCJvxvKISXvxyPwPfWMfmo5croEIREbErsDp0+jU8vBUmLoZmQ8Fo/7w376Jc7t6zgg+2rWRxfjDJ4S0J8g60O7egpIBFBxcxdvlYkpcns/jgYgpKOa9IRG5euYRBCxYswGAw8Mgjj9CyZcsbuqZFixaAdXuZiIiIVG7RNYOYNfUOXhvTkupBth1mUs9lMerdjTz5+U4u5zjfliAiIuXIaISo7jD6I+tqoZ5/hrBGDqdHndnDU1uXsOrYCf4UEEdsUEOHc3df3M0f1/+RPp/3YUbKDHKLnHcrE5EbVy5h0L591paEvXv3vuFrIiKsnw6mp6eXRUkiIiLiZgwGA0Nb1WPVb7pzzx0NMdjuHOOzLSfp+eoa5m4+jtlcIQ1RRUTEkaCa0OU38OsdMH4BxA8Eg8nu1ICCbEbv+ZrPd69jZmEo/UKb4uVgZdGl/Eu8uvVVes/vzTs73yGjIKMMX4RI1VAuYVB2djYAQUFBN3xNQYF1KaC3t3eZ1CQiIiLuKTTAm+eHJbHg/zqSUDfEZvxKbhG/m7+b0f/ZyP6zmRVQoYiIOGU0QnRPGPspPLYHevwRQhvYnWoAWp3azcs7VvL16Uv8OiCG2n6224YBMgoyeHvH2/SZ34fXtr7GpbxLZfgiRCq3cgmDrq7yOXr06A1fs2eP9cCw2rVrl0VJIiIi4uZaNazG4oc68czAZgT52n5avOXYFQa+sY4XvthHbqHj1sUiIlKBQupAt9/CIzsheR7E9gMHLeSr52Uwdc8qvty3ndeLw7gjKNLuvJyiHD5M+ZC+8/vy4qYXOZtztgxfgEjlVC5hUOvWrQH4/vvvb/iamTNnYjAY6NChQ1mVJSIiIm7Oy2Tk3s6N+eY33RiQVMdmvNhs4T/fH+buf37PV3v0ZkBExG0ZTRDbG5LnwKO7odvvIbiu3alewF0ndvHB7u+ZfTGbHr72Fwjkl+Tz6b5P6begH89ueJYTmSfK8AWIVC7lEgaNHDkSi8XCe++9x/Hjx0ud/9prr10LjsaNG1fW5YmIiIibqx3qx7/vac2MKe1oFBFgM34qPY8HPt7K/R9t5uQVHTAqIuLWQutDj6esodDY2RB9N9YNY7YSsy7zxv5NzD99gX7eNTDamVdsLmZ+2nwGLhrI79f+nkPph8r4BYh4PoPFYinz0xfNZjOtW7dm165dREZG8u9//5u+fftiMpkwGAykpKQQHx/Pli1beO2115gzZw4AXbp0Yc2aNWVdntyCkydP0qCBdd/viRMnqF+/fgVXJCIiVUV+UQlvrz7Iu98dprDEbDPu723i1z1juL9LY7xN5fK5l4iI3K4rx2DbTNj+MWSfczjtmJcXHzZsxlIyKbbY/gy4qlfDXkxtPpVmEc3KolqRclUW77/LJQwCOH78OJ07d+bkyZMYDAYCAgLIzbV+cle9enWysrKuHRptsVho0qQJ69evp2bNmuVRntwkhUEiIlLRDl3I5k+LUthwyP4BorG1gvjb0CTaNw4v58pEROSWlRRB6pewdToc+tbhtNNeJqbXjWGBVyGFFsfnxnWq14lfNP8FrWq2KotqRcqFR4dBAJcvX+ZXv/oVn332GSUlJfYLMhgYNWoU77zzDtWqVSuv0uQmKQwSERF3YLFYWLLzNH9dtpeL2YV254xqU5+n+jclPNCnnKsTEZHbcukQbPw37PgUivPtTrloMjKzViPm+BvJMxc5vFXbWm2Z2nwqHep0wGCwvyVNxF15fBh01bFjx1i+fDlbtmzh/PnzlJSUEBERQatWrRg0aBCxsbHlXZLcJIVBIiLiTjLyivjHylQ++fEY9n6zCQvw5vd94xndtgFGo94EiIh4lOwLsOk965/8dLtT0o1GPq1em0+DA8gy2/9wACCpehJTk6bSrUE3jA66mom4m0oTBonnUxgkIiLuaOeJdJ5etJuUU5l2x9s0qsbzwxKJrx1SzpWJiMhtK8i2nim08d+QYb9zWLbBwNxqEcysVo3L5gKHt4qpFsPUpKn0btQbk9FUVhWLuITCIHEbCoNERMRdlZgtfLzxKP/46gDZBbbnSJiMBu7r3JhHesYQ6OtVARWKiMhtKSmCPQth/etwLsXulDyDgQUhIUyPqME5i+OVQo1CGnFf4n0MbDIQb6N3WVUsclsUBonbUBgkIiLu7lxmPn9dtpdlu87YHa8b6sczgxLok1BL50eIiHgiiwUOrbKGQke+tzulEFgaFMgHNWpzEsdnCtUJrMOUxCkMix6Gn5dfGRUscmsUBonbUBgkIiKe4vsDF/jT4hSOXcq1O94zvibPDk6gQXhAOVcmIiIuc2qbNRTatwTstJwvBlYEBvBBjdocMjjuPhbhF8HkhMmMjhtNgLd+Loh7UBgkbkNhkIiIeJL8ohLeWXOId9YcorDE9k2Cn7eRX/eM4f7OUfh46UBRERGPVUoHMjPwbYA/71WvyT6T47fCob6h3NP0HpLjkwn1DS3DgkVKpzBIKkxCQsJ1fy8qKiItLQ1QGCQiIp7j8IVsnlm8h3UHL9odj6kZxN+GJnJHVEQ5VyYiIi5VSgcyC7De34/3I6qzzdvxhwCB3oGMjRvLhGYTiPDXzwapGAqDpMIoDBIRkcrCYrGwZOdp/rZ8Hxey7HeaGdG6Pn/oH09EkG85VyciIi5VkA3bP4GNbznsQLbFz5f3qoWz0c/xAdJ+Jj9GxI5gcsJkagfWLqtqRexSGCRuQ9vERETE02XmF/HqylRm/nAMe78Nhfp78/t+8Yxp2wCjUQdMi4h4tBvoQLbbx4f3qoWxJsDxAdJeRi+GNBnCfYn30SCkQVlVK3IdhUHiNhQGiYhIZbHrZDpPL0xh96kMu+OtG4bx/LAkmtYJKefKRETE5W6gA9kBb28+CAthZVAgtqfMWRkNRvo17sfUpKk0CWtSdvWKoDBI3IjCIBERqUxKzBY++eEY/1iZSlaBbZcZk9HAlI6RPHp3LEG+XhVQoYiIuNypbbDhDdi72G4HsmNeXnwYFsLSoCCKnSwQ7dWwF1ObT6VZRLMyLFaqMoVB4jYUBomISGV0PjOfvy7fx9Kdp+2O1wn148+DmtEnoTYGg7aOiYhUCpcPWzuQbf/Ebgey014mpoeGsCA4mEIn3/o71evEL5r/glY1W5VhsVIVVdowaOfOnRw8eBCDwUBUVBQtW7as6JKkFAqDRESkMlubdoE/LUrh6KVcu+N3xdfkL4MTaBAeUM6ViYhImSmlA9lFk5GZISHMCQ0mz8kHAm1rtWVq86l0qNNBHxyIS3hMGHTgwAEAwsLCqFmzpsN53377LQ8++OC1rlRXNWrUiH/9618MGTLE1aWJiygMEhGRyi6/qIR3vzvE26sPUVhiu33Az9vIr+6KYWqXKHy8HLclFhERD1NKB7J0o5FPQ4L5NDSYLKPj7/9J1ZOYmjSVbg26YTTo54TcOo8Ig3bt2kXLli0xGAxMnz6diRMn2p23cuVKBg0aRElJCfZKMBqNzJw5k+TkZFeWJy6iMEhERKqKIxdzeGZxCmvTLtodj64ZxF+HJNKhSUQ5VyYiImWqpAj2LPqpA9lum+Fsg4E5IcF8HBrMZZPJ4W1iqsUwNWkqvRv1xmR0PE/EEY8Ig/7xj3/w5JNPEhYWxrlz5/D29raZk5ubS0xMDGfOnAEgPDycAQMGULduXbZs2cKqVasA68qigwcPEh4e7soSxQUUBomISFVisVhYtusMzy3by4WsArtzhreuxx/6N6V6kG85VyciImWqlA5keQYDC4IDmR4awjkvx00GIkMiuTfxXgY2GYi30fZ9sogjZfH+2+Vr1TZt2oTBYGDAgAF2gyCAWbNmcebMGQwGA4mJiaSkpPDRRx/xwgsv8PXXXzNt2jQAMjIy+PTTT11dooiIiMhNMRgMDGpRl1WPd2Nyx0iMdo6AWLDtFD1f/Y5ZPx7HbK7wIxlFRMRVDAaI7gWTlsLU1ZAwDH627cvfYuGezGy+OHGaP1+8RP2iIru3OZp5lGc2PMOABQOYvX82+XYOqxYpLy4Pg/bt2wdA165dHc6ZN2/eta/feOMNateufd345MmT6devHxaLha+++srVJYqIiIjckhA/b54dnMDihzrTvH6ozXhGXhF/WLibEe9uYM/pjAqoUEREylS91jBqBvxqK7S7H7z8rg35ACOzclh68gwvnL9Ik8JCu7c4k3OGv//4d/ot6MeMlBnkFtlvViBSllweBp08eRKApk2b2h03m81s2LABg8FA/fr16d69u915o0ePBiAlJcXVJYqIiIjclqT6oSx8sBPPDUkg2Nd2S8D24+kMenMdf122l+yC4gqoUEREylR4FAx4FR5Nga5Pgl/YtSEvYGBOLgtOneVf5y7QtMB+KHQx7yKvbn2V3vN7887Od8go0IcIUn5cHgZlZ2cDEBISYnd8z5495OTkANCtWzeH94mPjwfg0qVLLq5QRERE5PaZjAYmdohk1RPdGNKyrs242QIfrjtCr1e/48vdZ+w2zBAREQ8XVAPuehoe2wN9X4LQBteGjECv3Dzmnj7L22fP0yrf/rawjIIM3t7xNn3m9+G1ra9xKU/vgaXsuTwM8vOzLpPLysqyO/7jjz9e+7pNmzal3iffwX8wIiIiIu6gZrAfr49txSf33UHj6oE242cz8/m/T7cxZcZmjl/SVgARkUrJNwju/CX8ejsM/wBqJV0bMgBd8vL56Mx5pp05R4e8PLu3yCnK4cOUD+k7vy8vbXqJszlny6l4qYpcHgbVqVMHgB07dtgdX7t27bWv77zzTof3uXLlCgBBQUGuK05ERESkjHSOqc6Xj3ThsV6x+HjZ/oq1JvUCd//rO976No2C4pIKqFBERMqcyRuaj4JfroXx86Hxf8/SNQDt8gt47+wFZp06S/cc+x8Q5Jfk88m+T+i3oB/PbniWE5knyql4qUpcHga1bdsWi8XC9OnTbcZycnJYunQpAMHBwbRt29bhfVJTUwHUslxEREQ8hp+3iUd6xfDVo13pElPdZryg2Mw/vjpAv9fXsuHQxQqoUEREykUpHciSCgt58/xFPj95hn7ZORjsbCUuNhczP20+AxcN5Pdrf8+h9EPl+QqkknN5GDRu3DgAtm/fztSpU8nMzAQgPT2dyZMnk56ejsFgYOTIkZhMJof3+f777wFISEhwdYkiIiIiZSqyeiAz723PW8mtqBnsazN++EIOye//yGNzd3Ahq6ACKhQRkXLjpANZXFERL1+4xJKTZxialY2XnVDIbDGz/PByhi4eymOrH2Pvpb3lWLxUVgZLGZxm2KVLF9avX4/BYMDLy4vq1atz7tw5LBYLFosFHx8fdu/eTUxMjN3rc3NzqV27Njk5Obz++us8/PDDri5RbtPJkydp0MB6ONqJEye0gktERMSBrPwiXv3qADM3HsVs57euED8vnuwbT3L7hhiNhvIvUEREylfORdj0nvVP3pXrhk57mZgeGsKCoCAKnfxM6FyvMw80f4BWNVuVdbXiBsri/bfLVwYBzJ8/n8TERCwWC0VFRZw5cwaz2YzFYsFoNPL22287DIIAPvroo2tdyfr06VMWJYqIiIiUi2A/b54dnMCShzvTon6ozXhmfjF/XJTCsHc2kHJKbYVFRCq9wOrQ4w92O5DVLS7h6UtXWHHyFJPTM/E3m+3eYt2pdUz8ciJTVkxhw+kN6lgpN61MVgYBFBUV8d5777FkyRKOHz+Oj48PrVu35sEHH6Rdu3ZOrx0+fPi1tGvhwoVlUZ7cJq0MEhERuXklZguzNh3n5RX7ycovthk3GmByx8b8pncsQb5eFVChiIiUu5Ii2LMI1r8O53ZfN5RuNPJpSDCfhgSTZXK8liOpehJTk6bSvUF3DAatMq1syuL9d5mFQVK5KQwSERG5deez8vn78n0s2nHa7nitEF/+PCiBfom19Uu9iEhVYbHAoW+todCR764byjYYmBMSzMehwVx2cvZuTLUYpiZNpXej3piMjueJZ1EYJG5DYZCIiMjtW3/wIn9alMLhizl2x7vF1uC5IQk0iggs58pERKRCnd4O69+AvYvA8t+tYnkGA/ODg5geGsx5L8crSCNDIrk38V4GNhmIt9G7HAqWsqQwSNyGwiARERHXKCgu4b3vDvPm6oMUFtueDeHrZeThHtE80C0KXy99yisiUqVcPgwb/w3bP4Hi/Gv/XAgsCQ7kw9AQTno7DnvqBNbh3sR7GRYzDF+TbXdL8QwKg8RtKAwSERFxrWOXcnhm8R6+O3DB7nhUjUD+NiSRjtHVy7kyERGpcA46kBUDKwIDeD8slMM+jkOh6v7VmdRsEqPjRhPgHVAOBYsreUQY9Nxzz7nydgA888wzLr+n3B6FQSIiIq5nsVj4YvdZnlu2h3OZBXbnDG1Zl6cHNKNGsD7hFRGpcgpzrKuENrwFGcev/bMZ+DbAn/fCQtnn6+Pw8lDfUO5peg/J8cmE+tp2uBT35BFhkNFodPlBhyUlJS69n9w+hUEiIiJlJyu/iH9+fYCPNhzFbOc3tWA/L57sG09y+4aYjDpgWkSkyikptp4ntO616zqQWYB1/n68HxbCdj8/h5cHegcyNm4sE5pNIMI/oszLldvjUWGQq25rMBgUBrkhhUEiIiJlL+VUBk8vSmHniXS74y3qh/L8sCQS6+nTXRGRKslBBzILsMXPl/fCQvjB39/h5X4mP0bGjmRSwiRqB9Yuh4LlVnhUGOTn58eQIUOYOHEiTZs2va17NmrUyEXViasoDBIRESkfJWYLszcd5+UV+8nML7YZNxpgYodIHu8dS7CfOsaIiFRZDjqQ7fbx4b2wENYEOj4ryMvoxZAmQ7gv8T4ahDQoh2LlZnhEGHT33XezevVqzGbzte1ibdq0YcKECYwdO5YaNWq48nFSQRQGiYiIlK8LWQX8/Yt9LNx+yu54zWBfnhnUjAFJdVy+ZV9ERDzI5SOw8S2bDmSp3t58EBbCysAALA5+ThgNRvo37s/9SffTJKxJeVUspfCIMAjg9OnTfPrpp3zyySfs3m3dv2gwGPDy8qJPnz6MHz+eIUOG4Ourgw89lcIgERGRirHh4EX+uDiFwxdy7I53ianOX4ckElk9sJwrExERt+KgA9lRLy8+DAthWVAgxQ5CIQMGejbsydTmU2kW0ay8KhYHPCYM+rldu3bx0UcfMXv2bM6ePWt9qMFASEgIo0aNYvz48XTt2rUsS5AyoDBIRESk4hQUl/D+94d589uDFBSbbcZ9vIw81D2aX3aPwtfLVAEVioiI23DQgey0l4lpoSEsDAqi0Ekzgs71OvNA8wdoVbNVeVQrdnhkGHSV2Wzmm2++YebMmSxatIjc3NxrS5gbNmzIhAkTGD9+PLGxseVRjtwmhUEiIiIV7/ilXJ5ZksKa1At2x6OqB/LXoYl0iq5ezpWJiIjbcdCB7ILJyMyQEOaGBJFnNDq8vG2ttjzQ/AHurHOntiOXM48Og34uJyeHBQsW8NFHH7FmzZrrzhfq2LEja9euLe+S5CYpDBIREXEPFouFFSln+cvSvZzNzLc7Z3CLuvxxYFNqBjtuMywiIlWEgw5k6UYjn4QEMyskmCyT41AoqXoSU5Om0r1Bd4VC5aTShEE/d/r0aaZNm8bf//538vPz8fPzIzc3tyJLkhugMEhERMS9ZBcU86+vDzBjw1FKzLa/3gX7efHbPnHcc0cjTE62A4iISBVipwNZtsHAnJBgPg4N5rLJ8VbjmGoxTE2aSu9GvTEZtSW5LFW6MGjjxo18/PHHfPbZZ1y5cgWLxaIwyEMoDBIREXFPe05n8PTCFHacSLc73rx+KM8PTSKpfmj5FiYiIu7r8hHY+G/Y/vG1DmR5BgPzg4OYHhrMeS8vh5dGhkRyb+K9DGwyEG+jd3lVXKVUijDo0KFDfPLJJ3zyySccPnwY4FoINHjwYCZOnEj//v3LsyS5BQqDRERE3JfZbGH25uO89OV+MvOLbcaNBnioRzS/7hmDt5OtACIiUsXY6UBWCCwJDuTD0BBOejsOe+oG1mVK4hSGxQzD16TO4a7ksWHQlStXmDNnDh9//DE//vgjYA2ADAYDXbp0YcKECYwaNYqQkJCyLkVcRGGQiIiI+7uYXcDfv9jHgm2n7I63aBDGa2Na0lht6EVE5OfsdCArBr4MDOCDsFAO+zgOhar7V2dSs0mMjhtNgHdAORVcuXlUGFRUVMTSpUv5+OOP+fLLLykqKuLqo2JjY5kwYQITJkygYcOGZfF4KWMKg0RERDzHxkOX+OOi3Ry6kGMzFuBj4s+DmjG6bQMdBCoiIte72oFs/Wtw1tqBzAx8G+DPe2Gh7PP1cXhpqG8o45uOJ7lpMiE+WvhxOzwiDFq3bh2ffPIJ8+bNIz09/VoAFBERwdixY5k4cSLt2rVz5SOlAigMEhER8SyFxWb+890hXl+VRrGdA6b7JNTixeHNqRbo+Bd7ERGpoiwWOLza2pb+pw5kFmCdvx/vh4Ww3c9xt8pA70DGxo1lQrMJRPhHlE+9lYxHhEFGoxGDwYDFYsHX15fBgwczYcIE+vbti5eTQ6fEsygMEhER8Uy7T2bwyNztHLazSqhWiC//GNWCLjE1KqAyERHxCP/TgcwCbPHz5b2wEH7w93d4mZ/Jj5GxI5mUMInagbXLrdzKwKPCID8/P/r06UNYWNht3c9gMPDhhx+6pjhxGYVBIiIiniu3sJjnl+/j0x+P2x2/r3NjftsnDj9vtQoWEREHrnUg+wSK8wDY7ePDe2EhrAl0fFaQl9GLIU2GcF/SfTQIblBe1Xo0jwqDXKmkpMSl95PbpzBIRETE83299xy/m7+LyzmFNmPxtYN5Y1wrYmsFV0BlIiLiMXIuwqb3YdN/rnUgS/X25oOwEFYGBmBxkA+YDCb6Ne7H/Un30ySsSXlW7HE8JgxyNbPZ7PJ7yu1RGCQiIlI5nM/K54l5u/j+wAWbMR8vI0/1i2dyx0gdLi0iIs7Z6UB21MuLD8NCWBYUSLGDnyMGDPRs2JOpzafSLKJZeVbsMTwiDJKqQWGQiIhI5WE2W/ho41Fe+HI/hcW2H8J1i63BK6OaUzPY8QGhIiIigN0OZKe9TEwLDWFhUBCFRscfLnSu15kHmj9Aq5qtyqdWD6EwSNyGwiAREZHKJ/VsFo/M2c7+s1k2Y+GBPrw0ojl3N6tVAZWJiIjHudqBbP3rcHgNABdMRmaGhDA3JIg8J7uK2tVux9SkqdxZ506tTEVhkLgRhUEiIiKVU35RCa+sTOXDdUfsjiff0ZA/DmhKgI+6xIqIyA06vQM2vAF7FoLFTLrRyCchwcwKCSbL5DgUSqqexNSkqXRv0L1Kh0JVMgzaunUrbdq0qegy5H8oDBIREanc1qZd4PHPdnI+q8BmLKp6IK+PbUVS/dAKqExERDzW/3QgyzIYmBsSzMehwVw2Oe5gGVMthgeSHuDuRndjMla9TpdVKgzasGEDf/3rX/n6668pLi6u6HLkfygMEhERqfyu5BTy+wW7WLnnnM2Yl9HAb3rH8ouuTTA5Of9BRETExv90IMszGJgfHMT00GDOezleeRoZEsm9ifcysMlAvI3e5VhwxaoSYdCqVav429/+xvfff3/t39Ra3v0oDBIREakaLBYLn205wbNL9pJXZPs72R2Nw/nnmJbUC/OvgOpERMSjFebA9k9h45uQfpxCYHFwIB+GhnLK23EoVDewLlMSpzAsZhi+Jt/yq7eCeFQYZLFYWLhwId988w0nTpzA29ubyMhIRo4cSceOHW3mr1mzhj/84Q/8+OOP164H6N27NytWrCiLEuU2AcZWQgAAtYBJREFUKAwSERGpWo5czOHROdvZeTLDZizYz4vnhyUxuEXdCqhMREQ83v90ICsGvgwM4IOwUA77OF4BVN2/OpOaTWJ03GgCvAPKq9py5zFh0LFjxxgyZAi7d++2Oz5q1Cg+/fRTTCYTly5d4v7772fJkiWANQQyGAwMHjyYp59+mrZt27q6PHEBhUEiIiJVT1GJmTdWpfHv1Qcx2/kNcnirevxlSALBflVn6b6IiLjQ/3QgMwOrAvx5PyyUfb4+Di8L9Q1lfNPxJDdNJsQnpPzqLSceEQYVFhbSpk0b9uzZ4/ihBgOPP/44v/rVr+jWrRvHjh3DYrFgMpkYPXo0f/jDH0hISHBlWVXCwoULefvtt9m2bRs5OTnUqVOHO++8k5dffvna/3FcRWGQiIhI1bX56GUenbODU+l5NmP1q/nz2piWtI0Mr4DKRESk0vhZBzKLxcw6fz/eCwtlh5/jbWGB3oGMjRvLhGYTiPCPKL9ay5hHhEHTp0/nvvvuw2Aw0KhRI/74xz+SlJSEj48P+/bt45VXXmH79u0EBgbSsmVL1q9fD8CIESP4+9//TkxMjCvLqRIsFgu//OUvee+992jSpAl9+vQhODiY06dP89133/Hpp5/SuXNnlz5TYZCIiEjVlplfxDOLUli047TNmNEAD/eI5lc9Y/B20jJYRESkVD/rQGYpzmOLny/vhYXwg7/js+r8TH6MjB3JpIRJ1A6sXY7Flg2PCIMGDRrE8uXLadCgAXv27CEoKOi6cbPZTNeuXdmwYQMAJpOJDz/8kIkTJ7qyjCrl9ddf59FHH+XBBx/kjTfewPQ/LfmKi4vxcnIi+61QGCQiIiIAi3ec4o+LUsjKt+3+2qJBGK+PaUlk9cAKqExERCqVax3I3oO8y+zy9eH90BDWBDo+K8jL6MWQJkO4L+k+GgS7drdMefKIMKhhw4acOnWK119/nYcfftjunG+//ZZevXphMBiYNGkS06ZNc2UJVUpeXh716tWjWrVqpKamujz0cURhkIiIiFx18kouv/lsJ5uOXLYZC/Ax8eygBEa1rY/BoBb0IiJym/6nA1mqtzcfhIWwMjAAi4OfMyaDiX6N+3F/0v00CWtSzgXfvrJ4/+3ydbuXLl0CIDEx0eGc5s2bX/t65MiRri7hhp0/f55ly5bxzDPP0K9fP6pXr47BYMBgMDB58uSbutexY8d4/PHHiY+PJzAwkPDwcNq1a8crr7xCbm5u2bwA4KuvvuLKlSsMHTqUkpISFixYwIsvvsi7777LwYMHy+y5IiIiIlfVrxbA7Kl38ts+cXgZr/9FPLewhCfn7+L/PtnGlZzCCqpQREQqDZ9AuOMB+NV2GPEhcRHxvHLhEotPnWFIVjZedta7lFhKWHZ4GcMWD+Ox1Y+x99LeCijcvbh8GUleXh4Gg4GaNWs6nFO9evVrX1fkipJatWq55D5Lly5l/PjxZGZmXvu33NxctmzZwpYtW/jggw9Yvnw50dHRLnnez23duhWwbrdr3rw5Bw4cuDZmNBp57LHH+Mc//uHy54qIiIj8nMlo4KEe0XSJqc6jc3Zw+GLOdeMr9pxl+4krvDqqJZ1jqju4i4iIyA0yeUHSSEgcAYdX03j96/zt8BoeTM9gWmgIC4OCKPyfDygsWPjm+Dd8c/wbOtfrzAPNH6BVzVYV9AIqVoWf6Fde25pK07BhQ3r37n3T123fvp0xY8aQmZlJUFAQzz//PBs2bGDVqlVMnToVgAMHDjBgwACysrJcXTbnz58H4J///CehoaFs2rSJrKwsvv/+e2JjY3n11Vd55513XP5cEREREXua1w9j2a87M659Q5uxc5kFjP/wR55fvpeC4pIKqE5ERCodgwGa3AUTF8MD31E3fih/vJzBipOnmJSRib/ZbPeydafWMfHLiWw/v72cC3YP7pHEVJBnnnmGdu3a0a5dO2rVqsXRo0dp3LjxTd3jkUceIS8vDy8vL7766is6dOhwbeyuu+4iJiaGJ598kgMHDvDqq6/y7LPP2tzj8ccfp6Cg4KaeebXrmvmn/2P7+PiwaNEi6tatC0CXLl2YN28eLVq04NVXX+X//u//bup1iYiIiNyqAB8vXhieRI+4Gvxu/i6u5BZdN/7+2iOsO3iJ18e2JLZWcAVVKSIilU7dljByGvR8hhob/80T2z7m/vRMPgkJZlZIMFn/0+GyaQm0NHtXTK0VrMzCoLffftvpVrGbmffMM8+4qqzr/OUvf7mt6zdt2sTatWsBuO+++64Lgq56/PHHmT59Ovv27eP111/n6aefxtv7+v+z/ec//yEnJ8fmWkdGjhx5LQwKDQ0FoG3btteCoKsSExOJiori4MGDpKenExYWdjMvT0REROS29E6oTcsGYTzx+S6+P3DhurF9ZzIZ9OY6nuoXz6SOkTpcWkREXKdaJPR/Bbr9nrBN7/HwpveYdOIUc0OCmRkazJWfOnBPzczFUC2yQkutKGUWBpW2NenqD/wb2cJUVmHQ7Vq0aNG1r6dMmWJ3jtFoZOLEiTz11FOkp6ezevVqm+1o2dnZt1xDXFwcgMOg5+q/5+XlKQwSERGRclczxI8Zk9vx0cajvPDlfgqL/7tcv6DYzLNL97LmwAVeHtmcmsF+FVipiIhUOoER0OMp6PRrgrd/yv0b3+SeEyeYHxzEdwF+9EycBL5BFV1lhSiTM4MsFovL/rizdevWARAYGEibNm0czuvWrdu1r9evX+/SGnr06AHAvn37bMaKioo4ePAggYGB1KhRw6XPFREREblRRqOBKZ0as+ThTsTXtt0Wtib1Av1eW8uqfecqoDoREan0ftaBzH/4B4wPaMz7FzIx3ll1j1Nx+cqg1atXu/qWbutqABMdHe30IOz4+Hiba1ylSZMm9O7dm6+++ooPPviA+++//9rYiy++SHp6OuPHj7/pg7pPnjzpdPzMmTO3VK+IiIhUXfG1Q1j0UCdeXpHKtPVHrhu7lFPIfR9t4Z47GvLHAc3w9zFVUJUiIlJp/bwD2cU0CKq6iyZcHgb9fBVMZZafn8/FixcBqF+/vtO51apVIzAwkJycHE6cOOHyWt5++206duzI1KlTWbRoEfHx8Wzfvp1vv/2WRo0a8corr9z0PRs0aODyOkVERET8vE08M6gZ3eNq8MS8nZzPur6Jxqc/Hmfj4Uu8MbYVifVCK6hKERGp1AwGqBFb0VVUqApvLe+pft4mPiio9D2GgYGBwO2dD+RIkyZN2LJlC5MnT2br1q288cYbpKWl8dBDD7Fp0yZq167t8meKiIiI3I6usTVY8WhXejerZTN2+EIOw95ez7vfHaLE7N7HBoiIiHiiKt1a/nbk5+df+9rHx6fU+b6+voD1IOey0KBBA6ZPn+6y+5W2gunMmTO0b9/eZc8TERGRqic80If/TGjD3M0n+MvSveQVlVwbKyqx8OKX+1mTep5/jm5J3TD/CqxURESkclEYdIv8/P7b7aKwsLDU+QUF1iXQ/v6e8YtMaVvfRERERFzBYDAwtn1D2jcO59G5O9h1MuO68R8OX6bva9/z/LAkBrWoW0FVioiIVC7aJnaLgoP/2wnjRrZ+5eTkADe2pUxERESkqomqEcT8/+vIwz2iMRiuH8vML+ZXs7fzm892kJVfVDEFioiIVCIKg26Rn58fERERQOmdt65cuXItDNLBzCIiIiL2eZuMPNEnjrkPdKCenW1hC7adov8ba9l67HIFVCciIlJ5KAy6Dc2aNQPg4MGDFBcXO5y3f//+a183bdq0zOsSERER8WTtG4fzxSNdGNLSdlvYict5jHp3I//8+gDFJeYKqE5ERMTzKQy6DZ07dwasW8C2bt3qcN5333137etOnTqVeV0iIiIini7U35vXx7bi9bEtCf5/9u47KqqrawP4M0PviIKKIIqKBbE3VMSCWLEFexR7iRprYiyxJDGJscQeY++9BXtFERsWVFBRVFCwIiooIPV+f/DNfRmnMNRBfH5rzcrAOffcfWcG4t2cs4+BfJnLdAFYcjoM3isv4UlMvJYiJCIi+nIxGZQLXbp0EZ+r2skrPT0dmzZtAgBYWlqiRYsWBREaERERUZHQuVYZHBnrhvrliim03Yx8j/aLz2PXtUgIAregJyIi0hSTQbnQoEEDuLm5AQDWrl2LS5cuKfRZsGAB7t27BwAYO3Ys9PT0CjRGIiIioi+dvZUxdgxzxQ9tKkNXKl9dOj45DT/uuY3vtt7A+4Ssd3glIiKir3xr+YCAADx8+FD8+s2bN+Lzhw8fYsOGDXL9BwwYoDDG4sWL0aRJEyQmJsLT0xNTp05FixYtkJiYiB07dmDVqlUAACcnJ0ycODFfrqMgODs7y32dksKdPIiIiKjg6EglGNWiIppWLIFxO28i/I388rCjIS8R9PQ9FvaoicYVS2gpSiIioi+DRPiK59QOGDAAGzdu1Li/qpfq4MGD+PbbbxEXF6e03cnJCYcPH0bFihVzFGdhoCwZFBYWBgCIjIyEnZ2dNsIiIiKir1B8Uip+O3wX2wMjlbYPa+aIiZ5OMNDVKeDIiIiI8l5UVJS4M3le3X9/1TOD8oqXlxdu376NxYsX4/Dhw4iKioK+vj4qVqyI7t27Y/To0TA2NtZ2mLly584dua8zfxiJiIiICpKJgS7+6FYD7k42mLLvNt4lyM9YXuX/GAFhb7C4Vy1UKmmmpSiJiIgKr696ZhDlXH5kJomIiIiy61XcJ0zafQvnw94otBnoSjGtQ1X0a+QAiUSi5GgiIqLCLz/uv1lAmoiIiIi+WCXNDbFxYAP83LEa9HXk/2mblJqOGf/dwaANVxH9IUlLERIRERU+TAYRERER0RdNKpVgcNPy+G90E1RWsizM73402i7yx+l7r7QQHRERUeHDZBARERERFQlVS5vjv9FNMLBJOYW2mPhkDN54DdMPBCMxOa3ggyMiIipEmAwiIiIioiLDUE8HM72csXFQA1ibGSi0b7n8FB2XnkfIs1gtREdERFQ4MBlEREREREWOu5M1jo11Q+tqJRXaHkXHo+uKC1h57hHS07mXChERfX2YDCIiIiKiIqm4qQFW9auL37u6wEhPR64tJU3An0dD0XfNFTx/n6ilCImIiLSDySAiIiIiKrIkEgn6NCyLw983hUsZC4X2S49j0HaRPw7ffqGF6IiIiLRDV9sB0JfB2dlZ7uuUlBQtRUJERESUfY7Wptg7sjEWnXqAf849gpBpdVjcp1SM2nYDZ0LtMKtTNZgZ6mkvUCIiogLAmUFERERE9FXQ15Xix7ZVsH1oI9haGCq0770RhfZLzuP6k3daiI6IiKjgSARBYNU8yraoqCjY29sDACIjI2FnZ6fliIiIiIg0F5uYgp8PhMD31nOFNh2pBKNbVMSYlhWhq8O/nRIRkXblx/03/+9GRERERF8dCyM9LOldG4t61oKZgXzlhLR0AYtPh6HHv5fwNCZBSxESERHlHyaDiIiIiOir1aV2GRwZ64Z6DsUU2m48fY92i/2x+1okOJmeiIiKEiaDiIiIiOirZm9ljB3DGmFiayfoSCVybfHJafhhz22M3haE9wnJWoqQiIgobzEZRERERERfPV0dKca0qoS9IxujXHFjhfbDwS/QdtF5XHz0RgvRERER5S0mg4iIiIiI/l8te0sc/t4NPevZK7S9jPuEvmuu4I8j95CUmqaF6IiIiPIGk0FERERERJmYGOhirncNrPy2DiyN9eTaBAH41/8xui6/iIevP2gpQiIiotxhMoiIiIiISIm21Uvj+LhmaFqxhELb3Rdx6LAkAJsvRbC4NBERfXGYDCIiIiIiUqGkuSE2DWqA6R2qQl9H/p/OSanp+Pm/Oxi88RqiPyRpKUIiIqLsYzKIiIiIiEgNqVSCIW6OODCqCZxKmiq0nwl9jXaL/eEX+loL0REREWUfk0FERERERBqoZmsO39FNMaBxOYW2Nx+TMXDDVcz4LwSfUlhcmoiICjddbQdAXwZnZ2e5r1NSUrQUCREREZH2GOrpYFYnZzSvbI1Ju2/jzUf55WGbLj3BxUcxWNyrFpxtLbQUJRERkXqcGURERERElE3NK9vg+Dg3eFS1UWh7+Pojuiy/gFX+j5CezuLSRERU+EgEbn9AORAVFQV7e3sAQGRkJOzs7LQcEREREVHBEwQB2wKf4tdDd/EpJV2hvXGF4ljQoyZKWxhpIToiIioK8uP+mzODiIiIiIhySCKRoG9DBxwa44bqZcwV2i8+ikHbRedxJPiFFqIjIiJSjskgIiIiIqJcqmhjin0jm2Bk8wqQSOTbYhNT8N3WG5i0+xY+JqVqJ0AiIqJMmAwiIiIiIsoD+rpSTG5bBduGNIKthaFC+57rUWi/+DxuPH2nheiIiIj+h8kgIiIiIqI85FqhOI6ObYaONUortD19m4DuKy9h8akwpKYp1hgiIiIqCEwGERERERHlMQtjPSztXRsLe9SEqYGuXFtauoC/Tz1Az1WX8TQmQUsREhHR14zJICIiIiKifCCRSNCtjh2OjnVDXYdiCu3Xn7xD+yXnsfd6FLjBLxERFSQmg4iIiIiI8pG9lTF2DmuECa2doCOVry79MSkVE3ffwujtQYhNSNFShERE9LVhMoiIiIiIKJ/p6kjxfatK2DPCFQ7FjRXaD99+gbaL/XHx0RstREdERF8bJoOIiIiIiApI7bLFcPh7N/SoZ6fQ9iL2E/quuYI/jt5DciqLSxMRUf5hMoiIiIiIqACZGujiL++a+KdvHVgY6cm1CQLw77nH6PbPBTx8/VFLERIRUVHHZBARERERkRa0cymN4+OaoUnF4gptIc/i0HHpeWy+/ITFpYmIKM/pZt2FCHB2dpb7OiWFBQ6JiIiIcquUhSE2D2qItQHhmHf8PpLT/rc87FNKOn4+EIKzoa8x17sGSpgaaDFSIiIqSjgziIiIiIhIi6RSCYY2c8SBUU1QycZUof106Gu0XeQPv/uvtRAdEREVRRKB804pB6KiomBvbw8AiIyMhJ2dYhFEIiIiIsqeTylp+OPIPWy89ERpu4+rA6a0rwpDPZ0CjoyIiLQlP+6/OTOIiIiIiKiQMNTTwezO1bF+QH2UMNVXaN946Qm8lgbgzvNYLURHRERFBZNBRERERESFTIsqNjg2rhlaVbFRaAt7/RFdl1/Eav/HSE/nJH8iIso+JoOIiIiIiAqhEqYGWONTD792qQ5DPfl/tienpWPOkXvot+4KXsZ+0lKERET0pWIyiIiIiIiokJJIJOjXyAGHxrjB2dZcof3Cwxi0XeyPo8EvtBAdERF9qZgMIiIiIiIq5CramGL/d00w3N0REol82/uEFIzcegM/7rmF+KRU7QRIRERfFCaDiIiIiIi+APq6UkxpVxVbhzREaQtDhfZd16LQfsl5BD19p4XoiIjoS8JkEBERERHRF6RxhRI4NrYZOtQordD2JCYB3isvYcnpMKSmpWshOiIi+hIwGURERERE9IWxMNbDst61saB7TZjo68i1paULWHjyAXquuozItwlaipCIiAozJoOIiIiIiL5AEokE39S1w9GxzVCnrKVC+/Un79Bu8XnsuxEFQeAW9ERE9D9MBhERERERfcHKFjfGruGuGO/hBB2pfHXpj0mpmLDrFsZsD0JsQoqWIiQiosKGySAiIiIioi+cro4UYz0qYfcIV5S1MlZoP3T7Bdot9selRzFaiI6IiAobJoOIiIiIiIqIOmWL4chYN3jXtVNoex77CX3WXMbcY6FITmVxaSKirxmTQURERERERYipgS7md6+JFX3rwMJIT65NEIB/zj5Ct38u4FH0Ry1FSERE2sZkEBERERFREdTepTSOjXND4wrFFdpCnsWhw5Lz2HL5CYtLExF9hXS1HQB9GZydneW+TklhAUIiIiKiwq60hRG2DG6INQGPMe/4faSk/S/x8yklHdMPhODs/deY+00NFDc10GKkRERUkDgziIiIiIioCJNKJRjWrAL2f9cEFW1MFdpP3XuNNovO4+z911qIjoiItEEicF4o5UBUVBTs7e0BAJGRkbCzUyxSSERERESFS2JyGv44eg+bLj1R2j6gcTn81K4KDPV0CjgyIiJSJT/uvzkziIiIiIjoK2Gkr4NfOlfHugH1UMJUX6F9w8UIdFoWgLvP47QQHRERFRQmg4iIiIiIvjItq5TE0bHN0LKKjULbg1cf0WX5Baw5/xjp6VxEQERUFDEZRERERET0FbI2M8Ban3r4tUt1GOjK3xYkp6Xjt8P30H9dIF7FfdJShERElF+YDCIiIiIi+kpJJBL0a+SAw983RbXS5grtAQ/foM0ifxwLeaGF6IiIKL8wGURERERE9JWraGOG/aMaY3gzR0gk8m3vE1IwYssNTN5zG/FJqdoJkIiI8hSTQUREREREBANdHUxpXxVbBzdEKXNDhfad1yLRYcl53Ix8X/DBERFRnmIyiIiIiIiIRI0rlsCxcW7o4FJaoS0iJgHf/HMRS0+HIY3FpYmIvlhMBhERERERkRxLY30s61Mb87vXhIm+jlxbWrqABScfoOe/lxD5NkFLERIRUW4wGURERERERAokEgm869rhyFg31C5rqdB+7ck7tF98HgeCnhV8cERElCtMBhERERERkUoOxU2we7grxnlUgvSz4tIfklIxbudNfL89CLGJKdoJkIiIso3JICIiIiIiUktXR4pxHk7YPaIx7K2MFNp9bz1H+8XncT4sWgvRERFRdjEZREREREREGqnrUAxHvnfDN3XsFNqevU9Ev7WBGLnlOp69T9RCdEREpCkmg4iIiIiISGNmhnpY0KMmlvepA3NDXYX2oyEv4bHgHJb7PURSapoWIiQioqwwGURERERERNnWoUZpHBvXDI0crRTaElPSMO/4fbRddB7nHnDpGBFRYcNkEBERERER5YitpRG2DWmEP7u5oJixnkJ7+Jt4+KwLxIjNXDpGRFSYMBlEREREREQ5JpVK0KtBWfhNao6+DctCIlHsc+zOS7RacJZLx4iICgkmg4iIiIiIKNcsjfUxp6sL/hvVBLXsLRXaP6Wki0vHzt5/XfABEhGRiMkgIiIiIiLKMzXsLLFvZGPM/cYFVib6Cu3hb+IxYP1VDN98DVHvErQQIRERKZb/J1LC2dlZ7uuUlBQtRUJEREREhZ1UKkHP+mXRxrkUFpx4gC1XnkAQ5Pscv/MK5x5EY1TzihjazBGGejraCZaI6CvEmUFERERERJQvLI318WuX6jg4uilql7VUaP+Uko4FJx+g7SJ/+HHpGBFRgZEIwuc5eqKsRUVFwd7eHgAQGRkJOzs7LUdERERERIVZerqAPTei8OfRULyNT1bap3W1kpjRsRrsrYwLODoiosIrP+6/OTOIiIiIiIjynVQqQY969vCb2Bz9XR0gVbLr2Mm7r+Cx8ByWnA7DpxTuOkZElF+YDCIiIiIiogJjYayHXzpXh+/opqijZOlYUmo6Fp58gDaL/HEm9FXBB0hE9BVgMoiIiIiIiApc9TIW2DOiMeZ510BxJbuOPYlJwKAN1zBk4zVEvuWuY0REeYnJICIiIiIi0gqpVILu9exxZlJz+KhYOnbqXsbSsUWnHnDpGBFRHmEyiIiIiIiItMrCSA+zO1fHwTFNUc+hmEJ7Umo6Fp0Kg+ff/jh9j0vHiIhyi8kgIiIiIiIqFJxtLbB7hCsWdK+JEqaKS8eevk3A4I3XMHjDVTyN4dIxIqKcYjKIiIiIiIgKDYlEgm/q2uH0xOYY0Lic0qVjp0Nfw+Pvc/j7JJeOERHlBJNBRERERERU6FgY6WFWJ2ccGuOG+uUUl44lp6Zj8ekwtP77HE7d5dIxIqLsYDKIiIiIiIgKrWq25tg13BULe9RECVMDhfbIt4kYsukaBm24iicx8VqIkIjoy8NkEBERERERFWoSiQTd6tjhzCR3DGxSDjpK1o6dCX2N1n/7Y+GJ+0hM5tIxIiJ1mAwiIiIiIqIvgrmhHmZ6OePQmKZoUM5KoT05NR1LzjxE67/P4cSdlxAEQQtREhEVfkwGERERERHRF6VqaXPsHN4Ii3rWgrWZ4tKxqHeJGLb5OgZuuIqIN1w6RkT0OSaDiIiIiIjoiyORSNCldhmcmeiOwU3LK106dvZ+NDz/9scCLh0jIpLDZBAREREREX2xzAz18HPHajjyvRsalFeydCwtHUvPPITHwnM4zqVjREQAmAwiIiIiIqIioHIpM+wc1giLe9WCjZKlY8/eJ2L45usYsP4qwrl0jIi+ckwGERERERFRkSCRSNC5VhmcnuiOISqWjp17EI02f/tj3vFQJCSnaiFKIiLtYzKIiIiIiIiKFDNDPUzvWA1Hx7qhkaPypWPL/R6h9UJ/HAt5waVjRPTVYTKIiIiIiIiKJKeSZtg+tBGW9K6tcunYiC030H9dIB5Hf9RChERE2sFkENEXaMCAAZBIJChXrpy2Q/mixMTEYNKkSahatSqMjIwgkUggkUiwaNEibYdGRERE+UQikaBTTVucmdQcw5o5QlfJ0rHzYW/QZpE//jrGpWNE9HXQ1XYARIVVbGwstmzZgsOHD+Pu3bt4/fo19PT0ULJkSdSvXx+dOnWCt7c3dHR0tB0qaSA2Nhaurq4ICwvTdihERESkBaYGupjaviq617XDjP/u4NLjGLn2lDQBK84+woGgZ/i5YzW0rV4KEoli4oiIqChgMohIidWrV2PKlCmIiZH/R0JiYiLi4uIQFhaGbdu2oVq1avj333/RtGlTLUVKmlq+fLmYCPrxxx/h5eUFS0tLAEDp0qW1GBkREREVpEolzbBtaEMcuv0Cvx2+i1dxSXLtz2M/YeTWG3CrVAKzOjmjgrWpliIlIso/TAYRfWbSpElYsGABAEBXVxe9evVCp06d4ODggOTkZNy/fx/btm3DmTNncPfuXXh4eGDLli3w9vbWcuSkzqlTpwAA9erVw9y5c7UcDREREWmTRCKBV01btKhig6Wnw7A2IByp6fJFpM+HvUHbRf4Y3NQRY1pWhIkBb52IqOhgzSCiTFasWCEmguzs7HD16lVs3rwZ3bt3R4MGDdC0aVMMHjwYp0+fxtatW6Gvr4+kpCR8++23uHnzpnaDJ7WePXsGAHByctJyJERERFRYmBroYkr7qjg2zg1NKhZXaE9JE7Dy3CN4LDyHw7e56xgRFR1Mb5NGnJ2d5b5OSUnRUiT558mTJ5g4cSIAwMTEBKdPn1abOOjTpw8EQcC3336LpKQk9OvXD7dv3+ba8kIqKSljCrienp6WIyEiIqLCpqKNGbYMbogjwS/x66G7eBn3Sa79RewnjNp2A00qFsfsTtVR0YZLx4joy8aZQUT/b9GiRfj0KeN//DNmzNBoBknfvn3Rtm1bAEBISAgOHTqk0Kd58+aQSCRo3rw5ACAsLAyjR49GpUqVYGxsDIlEgoiICLlj7t27hwEDBsDe3h6Ghoawt7dHnz59cPXq1Wxd08uXLzFt2jTUq1cPVlZWMDAwgL29PXr06CEum1ImIiJC3Glrw4YNAIB9+/ahffv2sLW1ha6urng9Mg8ePMCYMWNQvXp1mJmZQV9fH7a2tqhVqxYGDRqEnTt3igmZnDp48CC8vb1hZ2cHAwMDFC9eHK6urvjzzz/x8aPidrBnz54Vr+PJkycAgI0bN4rfy/y+ZFd6ejq2b9+Ob775BmXLloWRkRGMjIzg5OSEvn37Ys+ePSqTpsnJyVixYgVatGgBa2tr6Ovro1SpUmjfvj22bNmC9PR0lef9fCe59+/fY8aMGXB2doaJiQksLS3RrFkzbN26Venxv/zyi3jtmhTTbtOmDSQSCUqXLo20tDSlfQ4cOIDu3bujbNmyMDQ0hKWlJerVq4fZs2fj3bt3Gl/LixcvMHnyZDg7O8PMzAwSiQRnz56VO+bp06cYOXIkypcvD0NDQ9ja2qJLly7w8/MDAMyaNUu8PnViY2Pxxx9/oEmTJuJ7ULp0aXh5eWHPnj1q//IrG3/WrFkAgKtXr6J3797i57JMmTLo168f7t27pzYGmZCQEIwZMwYuLi4oVqwY9PT0UKpUKXh4eOCvv/7CixcvVB6b059xIiJSJJFI0KFGaZye6I4R7hWgp6P4/5ILD2PQbrE//jh6D/FJ3HWMiL5gAlEOREZGCgAEAEJkZKS2w8m19PR0wcrKSgAgGBkZCe/fv9f42GPHjomvRdeuXRXa3d3dBQCCu7u7cODAAcHExETsL3uEh4eL/Xfu3CkYGBgo9AEg6OrqCmvWrBF8fHwEAIKDg4PKuLZs2aL0XJkfgwcPFlJSUhSODQ8PF/usW7dO6Nevn8Kx7u7uYv9du3YJ+vr6as8FQAgODtb4dc0sMTFR6Nq1q9qxbW1thaCgILnj/Pz8sowp83VoKjw8XKhVq1aWY/v5+Sk9tkqVKmqPa9q0qRATE6P03Jnf+9DQUKFcuXIqxxk1apTC8WFhYWL7rFmz1F7ny5cvBR0dHQGAMG7cOIX2t2/fCi1btlR7LTY2NsKlS5eyvJZLly4JJUqUUPsanj59WjA1NVV6HolEIsyZM0eYOXOm+D1VTp06JRQvXlxt3O3btxc+fPig9HhZn5kzZwrLly8XdHV1lY5hbGwsnDt3TmUcqampwvjx4wWJRKI2Fh8fH6XH5+ZnXBAEYf369XLXQkRE8sJefRD6rr4sOEw+pPTRcM4p4eCtZ0J6erq2QyWiIi4/7r+5TIwIwJ07d/D27VsAgJubGywsLDQ+1sPDA0ZGRkhMTERAQIDKfk+fPsW3334LY2Nj/Pzzz3Bzc4OOjg6uXr0KU9OMqcZXr15F3759kZqaCgMDA4wfPx7t27eHgYEBrly5gt9//x0jR45EtWrV1Ma0a9cu9OvXD4IgwNHREaNHj0a1atVgbW2NiIgIrF27FkeOHMHatWthbm6OhQsXqhxr0aJFuH37Ntzc3DBy5Eg4OTnh/fv34mymV69eYeDAgUhOToaNjQ1Gjx6NRo0aoUSJEkhMTMTDhw9x7tw5HDhwQOPX9HM+Pj7Yv38/AKBmzZqYOHEiqlatirdv32LHjh3YsGEDnj9/jlatWuH27dsoU6YMAKB+/foIDg4GkDHD5fnz5+jcuTN+++03cWwTE5NsxfLq1Ss0adIEz58/BwC0bNkSPj4+qFKlCiQSCcLDw3HmzBns3r1b4diPHz+iVatWePz4MQCgS5cuGDRoEGxtbREeHo5ly5bh3LlzCAgIgJeXF/z9/aGjo6M0joSEBHh5eSEmJgbTp0+Hh4cHTE1NERQUhNmzZyMqKgrLly+Hl5cX2rRpIx5XsWJFNGzYEFeuXMG2bdswc+ZMlde6c+dOcTZQ37595dqSkpLg4eGBGzduQEdHB3369EH79u1Rvnx5pKSkwN/fHwsXLsTr16/Rvn17BAUFwcHBQel5Pn78iG+++QafPn3CtGnT0Lp1axgbGyM4OFjc6e3x48fo1KkT4uPjoauri5EjR6JLly4wNzdHSEgI5s2bh2nTpqFhw4YqrwcALly4gHbt2iElJQUlS5bEmDFjULNmTdja2uL58+fYuXMntmzZgiNHjsDHxwd79+5VOdbx48cRGBgIFxcXjB07Fi4uLkhMTMT+/fuxePFiJCQkoF+/fggLC4O+vr7C8cOGDcO6desAZOxoN3r0aDRu3BgWFhaIjo5GYGAg9uzZo/TcefkzTkREylW0McXmwQ1wNCRj6diLWPmlYy/jPmH0tiBsq/AUv3R2RkUbMy1FSkSUA3mSUqKvTlGbGbRlyxbxen766adsH9+oUSPx+GfPnsm1yWYG4f9nrzx58kTlOPXq1RMACHp6ekpnFERFRQl2dnbieMpmBkVHRwsWFhYCAGHQoEEqZwVMnTpVACBIpVIhNDRUri3zzCAAQv/+/VX+1Wvt2rUazfxJSEgQEhISVLarcujQIXH8Vq1aCUlJSQp9Vq1aJfbp0aOH0nEcHBzUzrLQVOYZSnPnzlXZ78OHD8Lbt2/lvjdp0iTx2OnTpysck56eLvTt21fss2LFCoU+stk0AAQLCwshJCREoU9YWJhgaGgoABA6deqk0L5kyRJxjKtXr6q8hoYNGwoABCcnJ4U22efH0tJSuHbtmtLjIyIihNKlSwsAhD59+qi9FlNTU+HmzZsqY+nSpYvYd//+/Qrt8fHxQoMGDeQ+t59LTk4WZ1K1bdtWiI+PV3quzJ+nEydOKLRnPkf79u2VfiZ/++03sc++ffsU2v/77z+x3dXVVXj37p3Ka3/69Knc13nxMy4InBlERJQd8Ukpwp9H7wkVpx5WOkuowpTDwpzDd4UPn5T/TiYiyo38uP9mMohypKglgxYtWiRez6JFi7J9fOfOncXjb9++LdeWORm0adMmlWMEBgaK/UaPHq2y386dO9Umg3755RcBgFCmTBnh06dPKsdJSUkRypQpIwAQpk6dKteWORlkaWkpxMXFqRxnzpw5AgChWLFiKvvkRrt27cQE2ec3xZl5eHgIQMZSuufPnyu050UyKDQ0VFzS06VLl2wd++nTJ8HS0lIAIDg7OwupqalK+8XGxopLmKpVq6bQnjmBsmTJEpXn69WrlwBAsLKyUmh79eqV2uVfgiAIDx8+FM8ze/ZsubYPHz6IyYilS5equ2xhxYoV4vv38eNHldfyyy+/qBzj2bNnYrze3t4q+928eVNtMmjTpk0CAMHQ0FB4/fq12rhliSVlSSzZ+IaGhsKrV6+UHh8XFycunRw/frxCu6urq7iU7PMEclby4mdcEJgMIiLKiYevPwjfrlG9dKzBnJPCfze5dIyI8lZ+3H+zgDQRgA8fPojPZUu2siPzMXFxcUr76Ovro3v37irHyFzsdeDAgSr7de3aFZaWlirbfX19AQAdO3aEgYGByn66urpwdXUFAFy6dEllPy8vL5iZqZ72LFvG8+7dO/z3338q++VEamoqzp07BwDw9PSEvb29yr5Dhw4Vj/m86HBeOXz4sFhYePz48dk69vr163j//j2AjMLJqpZ/mZubo0ePHgCAu3fvqiweLJFI0KdPH5Xnq1u3LgDg7du34nllbGxs0Lp1awAZS8GUFazetm2b+Pzz85w7dw6xsbEAAG9vb5UxAECzZs0AZOxAeP36dZX9Pl+Glpmfn5+4XK1fv34q+9WsWRM1a9ZU2S772XB3d4e1tbVGcav72WjdujVsbGyUtpmZmaFSpUoAIC4LlImJicHly5cBAD179oStra3aWD6XVz/jAwYMgJDxRyGxGDYREalXwdoUmwY1wMpv68DWwlCh/VVcEr7fHoQ+q6/gwasPSkYgIiocmAwiAuSSHcp2pcpK5mPMzc2V9qlUqRIMDRX/0SAjq22jr6+v9oZWT08PtWvXVtqWlpaGmzdvAgD+/fdfuV2zlD1k9Uhevnyp8nw1atRQ2QYAnTp1EpNTXbt2RcuWLfH333/j+vXrKnef0tTjx4+RkJAAAFnWgsncHhISkqvzqhIUFAQg4z1o1KhRto7NHFNeXEuJEiVQvHhxlWNYWVmJzzMnO2VkyZcXL17gzJkzCu2yZFDDhg1RsWJFubZr166Jz0uXLq32M1a9enWxr6rPmampKRwdHVVeS+bXQJbkUqVevXoq22RxHz9+PMufjfnz56uNGQCqVKmiNhbZe/D563/z5k0xqejm5qZ2jM/lx884ERFlj0QiQdvqpXFqojtGtagAfR3FW6pLj2PQfvF5zDl8Fx+56xgRFUJMBhEh48ZaJic3Ta9evRKfq7pBL1asmNoxZAWsraysVM4akSlZsqTKMVJTs/8PDlnCRZms4i5evDh8fX1RpkwZCIIAPz8/TJgwQdzqulu3bjh06FC2YwL+95oAUDkDQ6ZUqVJKj8tLb968AZDxHikrCKxOXl+LsbGx2jGk0v/9eleWlOvSpYs4xufb0N+4cQOhoaEAlM/Yef36tdpzq6Lqc6ZuphsAue3ps5rRo649J3EnJiaqbNP0Pfj89Zd9joD/zazTVH78jBMRUc4Y6+vihzZVcGycG5o5Kf7/JzVdwOrz4Wg5/yz+u/lM/EMAEVFhwN3EiCA/+0U2+0NTaWlpuH37NoCMG1FVSz6ySvDISCSSbJ3/81hkhgwZgrFjx2p0nLrEhiZxu7m54eHDh9i7dy+OHDkCf39/REVFIS4uDvv378f+/fvRpk0b7Nu3L8sbaFVy87oUNoXhWkxNTdG5c2ds374d+/btwz///CPOXJPNCtLR0UHPnj0Vjs38Obtx4wb09PQ0OqednZ3S72v6s5FbsrjbtWuHv/76q0DOmdfy42eciIhyx9HaFBsH1sfxO6/w66G7ePZe/g8Jrz8kYeyOm9h25Sl+6VwdlUtx1zEi0j4mg4gAVK9eHVZWVnj79i38/f0RGxur8fbyp06dEv/qnt0lH5nJZuDExMQgLS1N7Q1y5plImWVeGiQIgtwSnfxmaGiIvn37ijNJwsPDcfjwYSxduhQPHjzA8ePHMW3aNPz9998aj5n5elRds0zmGV2Zj8tLshlkb9++RXJycrZusD+/FicnJ5V9C+JagIxZP9u3b0dcXBwOHToEb29vpKenY8eOHQBU18TJPPvN2tpaZZInr2SenRYdHY0yZcqo7BsdHa2yrXjx4nj+/DmSk5ML9Gfjc5lnIqqqCaWKNn/GiYhItYylY6Xg7mSN5X4Pscr/MZLT5GvyXQl/i/ZLzmNA43IY51EJZoaa/TGFiCg/cJkYETL+B96/f38AGctCVq9erfGxS5cuFZ8PGDAgxzG4uLgAAJKTk3Hr1i2V/VJTU8WaIZ/T19eHs7MzAODChQs5jiUvlC9fHqNHj8bVq1fFZMGuXbuyNYajo6M4k+jKlStq+wYGBorP8+sGuU6dOgAyiiGrKyysTOaYCsO1AECbNm3ExIRsNtC5c+fw7NkzAKqLOmeuWVUQnzPZZxqA2iLUgHw9o8/J4r527RqSk5PzJrgcqF27tjg7zN/fP1vHFqafcSIiUmSkr4NJbSrj+PhmaF5ZcelYWrqAtQHhaLngHA4EcekYEWkPk0FE/2/s2LHizjyzZ8/Gw4cPszxmx44dOHz4MICMm/aOHTvm+PweHh7i840bN6rst3//frkaKp/r1KkTACA0NBTHjx/PcTx5xdzcHPXr1wcgXytFE7q6unB3dwcAnDx5ElFRUSr7rlmzRjymefPmOQs2Cx06dBBv4hctWpStY+vWrSvWxtm4caPSHbyAjGLDsqRZtWrVsl1TJjt0dXXFncuOHDmC9+/fi0khY2NjdOnSRelxHh4eYpJuyZIl+f4P2ebNm4v1dzZv3qyy361bt9QmUmU/G7GxsVi/fn3eBpkNVlZWaNy4MYCMBOnz58+zdXxh+xknIiJF5UuYYP2A+ljVry7KWBoptEd/SMK4nTfRc9VlhL5UvhMtEVF+YjKI6P+VK1cO8+bNA5CxO1irVq3U3lju2rULPj4+ADL+Wr958+Zc1YJp0KCBOPPkn3/+QUBAgEKfFy9eYNKkSWrHGTt2rLjV/cCBA3Hnzh21/Q8fPizWPMqJ48ePq13qEhsbK850KV++fLbHHzVqFICMGVODBw9GSkqKQp9169bhxIkTAIBu3brlWwLFyckJXbt2BQAcOHBA/LwoEx8fL5e0MzAwwJAhQwBk7I7166+/KhwjCAJGjx4tJs1Gjx6dl+ErJZv9k5SUhG3btmHv3r0AgM6dO4ufo89ZWlqKsV28eBHjx49XmdwCMpbFyZJ1OWFnZ4cOHToAAPbs2YMDBw4o9ElMTMSwYcPUjuPj4wN7e3sAwKRJk7KclRMQEIBz587lLOgsTJ48GUBGYefu3bsjNjZWZd/Pk6B59TO+YcMGcdcxbi1PRJT3JBIJPJ1L4dQEd3zfsqLSXccCw9+iw5IA/HLwLuI+Kf4bh4govzAZRJTJmDFjxIKsT58+Rb169dC/f3/s2bMHV69excWLF7Fu3Tp4eHigZ8+eSE5OhoGBAbZu3YpatWrl+vwrVqyArq4uUlJS0Lp1a0ydOhUBAQG4evUqli1bhrp16+LFixdqt54vWbIkNm7cCIlEghcvXqBevXoYOXIkfH19cePGDVy5cgV79+7F5MmTUaFCBXTs2BFPnz7Ncczbt2+Hg4MDOnTogMWLF+P06dMICgqCv78/VqxYAVdXV3HZ0YgRI7I9focOHdC9e3cAwIkTJ9CoUSNs3boV169fx6lTpzBkyBAxyWJlZYWFCxfm+Fo0sWLFCrFI+I8//ohWrVph8+bNuHr1Kq5du4Y9e/Zg1KhRKFu2rEIyccaMGeIW6rNmzYK3tzcOHz6MGzduYO/evWjZsiU2bdoEAHB1dc0yuZEXGjduLCbppk2bJiawVC0Rk/nll1/QsGFDAMDixYtRp04dLF++HBcuXMDNmzfh5+eHZcuWoUuXLihbtixWrlyZqzgXLlwozkbq3r07vv/+e/j5+eH69evYuHEj6tWrh8DAQHEWmjIGBgbYtWsXDAwM8PHjR7Rs2RLffvst9uzZg+vXr+Pq1avw9fXFzJkzUaNGDbi5uSE4ODhXcavi5eWFwYMHA8hIqFWrVg1//PEH/P39cfPmTZw6dQp//vknateujenTp8sdW9A/40RElDtG+jqY4FkZJ8Y3QwsVS8fWXQhHy/nnsD8oikvHiKhgCEQ5EBkZKQAQAAiRkZHaDifP/fPPP4KVlZV4jaoeVatWFfz9/dWO5e7uLgAQ3N3dNTr3tm3bBH19faXn09XVFVatWiX4+PgIAAQHBweV4/j6+mp0DVKpVDhz5ozcseHh4WL7+vXr1cYriyWrx4gRI4S0tDSNXoPPJSYmCl27dlU7vq2trRAUFKRyDAcHBwGA4OPjk6MYMnv06JFQvXr1LK/Zz89P4djw8HChSpUqao9r0qSJEBMTo/Tcmrz3giAI69evF8cLDw9X23fatGly5y9RooSQkpKS5esQFxcndOvWTaP3v0WLFjm+FpkTJ04IJiYmKs8xc+ZM4eeffxYACIaGhirHuXTpkmBvb69R3Bs3blQ4PvP51MnqZz81NVUYPXq0IJFI1Mag6jObm59xQZD/jGR1LURElDfS09OFE3deCk3+PC04TD6k9NH9n4vC3eex2g6ViAqR/Lj/5swgIiVGjBiBR48eYenSpWjbti3s7e1haGgIU1NTVKhQAb169cL27dsRHBycqx3ElOnduzeCgoLQr18/2NraQl9fH2XKlEGPHj0QEBCAoUOHajSOl5cXwsPDMX/+fLRs2RIlS5aEnp4ejIyMUL58eXTs2BELFy5EREQEWrRokeN4//77b2zZsgWDBg1CvXr1UKZMGejr68PIyAhOTk7w8fHB+fPn8c8//4h1X7LL0NAQ+/btg6+vL7p16ya+LsWKFUPDhg3xxx9/4P79+3kyO0sTjo6OuHnzJjZs2IAOHTqgdOnS4mvr5OSE/v3747///lP62ShXrhxu3bqFZcuWwd3dHcWLF4eenh5KliyJtm3bYvPmzfD398/XXcQ+9/ksoB49ekBXN+vNJs3MzLB3716cP38eQ4YMQeXKlWFmZgZdXV1YWVmhfv36GDVqFI4cOYKTJ0/mOs7WrVsjJCQEw4cPh4ODA/T19VGyZEl06NABx44dw6xZsxAXl1F3Qd1ugI0aNUJYWBhWrlyJDh06iJ8nQ0ND2Nvbw9PTE3PmzEFoaKhYWD4/6OjoYOnSpbh27RqGDRsGJycnmJiYQE9PD6VKlYKnpycWLlyI+fPnKz2+oH7GiYgo70gkErSuVjJj6VirStDXVbJ0LOItOi4NwCzfO1w6RkT5RiIInIdI2RcVFSXW3oiMjMz3raWJiDTh4eGB06dPo2nTpjh//ry2wyEiIlLrSUw8Zh+8izOhr5W2lzA1wJR2VdCtTplc1aYkoi9bftx/c2YQEREVCc+fPxeLQjdq1EjL0RAREWXNobgJ1g2ojzX968HeSnHXsTcfkzBx9y10X3kJd59z1zEiyjtMBhER0Rfh4cOHKtsSExMxYMAAcbe5/FzeRURElNc8qpXEyfHuGOehfOnYtSfv0HHpeczyvYPYRC4dI6Lcy7ooBBERUSEwZMgQxMfHo0ePHqhbty6srKzw4cMHXLt2DStWrBCTRYMHD4aLi4uWoyUiIsoeQz0djPNwQrfadvjl0B2cuie/dCxdADZcjMCh28/xU7uq6Fa7DKRSLh0jopxhMoiIiL4Y165dw7Vr11S2d+3aFUuXLi3AiIiIiPJW2eLGWONTH6fvvcLsg3fx9G2CXPubj8mYtPsWtgc+xS+dneFsq3rTBCIiVVhAmnKEBaSJqKDduHED+/fvx5kzZxAVFYXo6GgIggAbGxs0atQIPj4+aN++vbbDJCIiyjOfUtLw77nHWHH2IZJS0xXapRKgXyMHTPCsDAsjPS1ESEQFIT/uv5kMohxhMoiIiIiIqGBEvk3A7IN3cereK6XtxU30MbldFXjXsePSMaIiiLuJERERERERfWXsrYyxxqce1g+oD4fixgrtMfHJ+HHPbXivvIiQZ7FaiJCIvjRMBhEREREREX0BWlSxwfFxzTCxtRMMlOw6duPpe3RaFoCfD4QgNoG7jhGRakwGERERERERfSEM9XQwplUlnJrgDs9qJRXa0wVg8+UnaLHgLHZdjUR6OquCEJEiJoOIiIiIiIi+MPZWxljVvx7WD6yPckqWjr2NT8aPe2+j2z9cOkZEipgMIspjzZs3h0QiQfPmzbUdChEREREVcS0q2+DYuGaY5OkEQz3F27ubke/htSwA0w8E431CshYiJKLCiMkgos/Ex8dj5cqVaN++PcqUKQNDQ0MYGBjA2toa9evXx6BBg7B69WpERkZqO9Q8d/bsWUgkEqUPY2Nj2Nvbo2PHjli3bh2SkpKyHE92bFaJsdTUVPTs2VPs36hRI7x//z5vLuozDx8+xPbt2zF+/Hg0adIExsbG4nk3bNiQZ+eRJQU1eaiSnp4Of39/TJ06Fc2bN0epUqWgr68Pc3NzVK9eHd999x1u376drbi2b98OT09PlCpVCoaGhnBwcMC3336LS5cuZXns4cOHMWvWLHTo0AFVq1ZFiRIloKenh2LFiqFu3bqYOHEi7t+/n614iIiIKPcM9XQwumXG0rE2zopLxwQB2HL5KVrMP4sdgU+5dIyIuLU85UxR3Vr+0qVL6NWrF54+fZpl35IlS+Lly5cK32/evDnOnTsHd3d3nD17Nh+izD9nz55FixYtNOrr7OyMQ4cOoVy5cir7yBId6l6LlJQU9OzZE/v37wcANG3aFEeOHIGZmVm2YtfEuXPn1Cam1q9fjwEDBuTJuWSfA02o+jVctmzZLJOOUqkUkyZNwp9//qk2sZSYmAhvb28cOXJE5TgzZszAzJkzlbanpqZCT09PbSwAoKenh19++QU//fRTln2JiIgof5x7EI1ZvncQ/iZeaXtNe0v82tkZNewsCzYwIsqR/Lj/1s31CERFxIMHD9CmTRt8+PABANCpUyd4e3vDyckJ+vr6ePPmDW7duoWTJ0/Cz89Py9Hmv5EjR+K7774Tv379+jVCQkIwb948REVF4c6dO+jUqROCgoKgo6OTo3MkJSXB29sbhw4dApCRQDl06BBMTEzy5Bo+lznpIpVKUbVqVZiYmCAwMDBfzgcA9erVw/r163N07PPnzwEAFStWxDfffIMmTZrA1tYWiYmJ8PPzw99//413797hr7/+go6ODn7//XeVYw0aNEhMBLVo0QJjx46Fra0tgoOD8fvvv+PRo0eYNWsWSpcujWHDhikdw8LCAs2bN0fDhg3h6OiI0qVLw9jYGM+fP8fZs2exbt06xMbGYsqUKbC0tMSIESNydN1ERESUO+5O1jg2zg1rzodj6ZkwfEpJl2u/FfkenZdfQO8GZfGDZ2UUM9HXUqREpDUCUQ5ERkYKAAQAQmRkpLbDyRPe3t7iNa1fv15t39evXwvLli1T2ubu7i4AENzd3fM+yHzm5+cnvgYzZ85U2icuLk4oV66c2G/37t0qx5P1UfZaJCYmCm3atBH7tG7dWkhISMijK1HuwYMHwrx584SzZ88KHz58EARBENavX6/x+54defE5cHV1FY4dOyakp6crbX/48KFgbW0tABB0dXWFR48eKe13+vRp8Rq9vLyE1NRUufbo6GihbNmyAgDB0tJSePv2rdJxPj/uc48fPxaKFSsmABCsra2z7E9ERET5L+pdgjBi8zXBYfIhpY9as48L2648EdLSlP97g4i0Lz/uv1kziAhAWloaDh8+DCBjJkdWS4Wsra0xatSoAois8DEzM8P06dPFr0+dOpXtMRISEtCxY0ccP34cANCuXTv4+vrCyMgoz+JUplKlSpg0aRLc3d1hamqar+fKCxcvXkSbNm1ULv+qUKECZsyYASBjGdeBAweU9ps/fz4AQFdXFytWrFCYyVWiRAnMnTsXAPD+/XusWbNG6ThZzQArX748evToAQCIjo5GaGio2v5ERESU/8pYGuGfb+ti06AGcCyhOPv6XUIKpuwLRtcVF3Ar8n3BB0hEWsFkEBEyblwTExMBZCzJKQgBAQHo168fypUrB0NDQ1haWqJ27dqYPn06oqOjlR4zf/58SCQS6Onp4ePHjwrtnz59gqGhoViY+ObNm0rHqVKlCiQSCXr16pWj2F1cXMTn2S2k/fHjR7Rv3x6nT58GkLEc78CBAzA0NMxRLF+7zDWeHj16pND+4cMH8bX28PBQub64W7duMDc3BwCxflNOZK719OnTpxyPQ0RERHmrmZM1jo5zw49tK8NIT/EPPLeiYtFlxQVM2ReMd/HcdYyoqGMyiAiAvv7/1knfu3cvX8+Vnp6O0aNHw83NDVu2bMGTJ0+QlJSE2NhY3Lx5E3PmzEGlSpVw8uRJhWPd3d0BZMwCCQgIUGi/cuWK3C5fyoo2v3r1StzxKatdvlTJ/HppUlRYJi4uDm3bthULK3t7e2PPnj1y4ymTeZezvCrwXFRkfr+Vzdy5evUqkpMz/kEn+/woo6+vj0aNGonHpKSkZDuWxMRE/PfffwAyajI5OTllewwiIiLKPwa6OviueUWcnuiO9i6lFNoFAdge+BQtFpzF1itPkMZdx4iKLCaDiABYWVnBwcEBAHDr1i3MnTsX6enpWRyVMz/99BOWL18OIGNZzcqVKxEYGAg/Pz+MHz8eenp6iI2NRceOHXHr1i25Y+vUqSPOvFCW6Pn8e1n1UZccUCdzwkzdbmKZxcbGwtPTExcuXAAA9O7dG9u3b89WMulLFBoaioYNG8LS0hKGhoaws7ND586dsWnTphwlXD6XeceyqlWrKrTfvXtXfF6lShW1Y8naU1NTERYWptH5U1JS8PTpU+zYsQONGzcWjxs0aFC+7AhHREREuWdraYQVfeti8+AGcLRWXDr2PiEF0/aHoOuKC7jJpWNERRKTQUT/b8yYMeLzn376CRUqVMDYsWOxc+dOhIeH58k5goODsWDBAgBA9erVcePGDQwfPhz169dH8+bNsXDhQvj6+kIqlSI5OVlhVycdHR00bdoUgPJEjywx4OXlBQDw9/dXSGrJ+pQsWVJp8iAraWlpmDdvnvi1t7d3lsfExsbCw8MDV65cAQD0798fW7Zsga5u0d/Q8NWrVwgMDERsbCySkpLw7Nkz+Pr6wsfHB7Vq1crVTLSEhAQsWrQIAGBgYIDOnTsr9ImKihKfZ7UFpWy7SkD98r+IiAhxppa+vj4cHBzQu3dvcVlimzZtxM85ERERFV5ulaxxbGwz/NSuCoz1FWcY346KRdcVF/DT3tt4y6VjREUKk0GkEWdnZ7lHy5YttR1Snhs/fjwGDRokfh0REYElS5agV69ecHR0RKlSpdCrVy8cPHhQbovy7Pjnn3/E5MyaNWtgaWmp0Kdt27ZiHIGBgbh69apcu2xp1/Xr1+XqBiUlJeHy5csAgMmTJ8PIyAjv3r3D7du35Y6XJZGaNWuWrdijo6Nx5swZuLu7IygoCEBGIkiWnFLn5s2buHbtGoCMGUHr16+HVFq0f/1IpVK0atUKCxYswKlTpxAUFAR/f38sWrRITMLdvXsXLVq0wNOnT3N0jsmTJ4vHjho1Cra2tgp9Pnz4ID7Pqmi2icn//jKorCZVVkqUKIGdO3fi8OHDYv0hIiIiKtz0daUY4V4Bpye6o0ON0grtggDsuBqJFvPPYstlLh0jKiqK9t0YUTZIpVKsXbsWJ06cQNu2bRVmrbx69Qo7d+5Ep06d0KBBA6XFerMi23nL2dkZDRs2VNlv6NChCsfIqKobFBgYiMTERFhYWKBRo0Zi/ZfMM4hev34tzkTJql7Q7NmzxdkfEokENjY2aNWqFS5cuABjY2NMmDAB27Zty/qiAbndsC5duoTnz59rdJxM8+bNIQgCBEHAhg0bsnWstuzbtw+nTp3ChAkT0KpVK9SqVQtubm4YO3Ysbt26BR8fHwAZn6tx48Zle/ytW7di2bJlADKWh/32229K+2Uu4pxVbSYDAwPxuaygujJlypRBcHAwgoODERQUhEOHDmH06NGIj4/HiBEj8Ndff2XnUoiIiKgQKG1hhOV96mDrkIaooGTpWGxiCqYfCEHn5QEIevpOCxESUV5iMog0cufOHbnHmTNntB1SvmndujWOHj2KmJgYHDlyBLNnz4aXlxcsLCzEPteuXYObmxtevHih8bhJSUliPRV1iSAAqF27tlhLJyQkRK6tbt264gyPzIke2fOmTZtCR0dHTPZk7pO5vkxO6wUBQK1atfD9999rXO+nadOm4s5lERERaNWqFV6+fJnj838JlM36ktHT08OaNWtQuXJlABm7dz179kzjsc+ePYvBgwcDyKh3tXfvXhgZGSntm3mXNlkhaVUyF6NWNZ4s/urVq6N69eqoVasWOnTogKVLl+Ly5cuQSCSYOnWq3Cw7IiIi+nI0qVgCR8c2wxQVS8dCnsWh64qLmLznNmI+JikZgYi+BEwGEalgbm6Odu3aYcaMGfD19cWrV6+wbt06FCtWDADw4sUL/PzzzxqP9+7d//6CYmNjo7avnp4eihcvDgB4+/atXJuuri6aNGkCQHmiR5YEkv03c90gWR9ra2s4OzurjWHkyJFysz8OHjwIHx8fSKVSXLx4Ec2bN0d0dLTaMWSkUik2b96MLl26AAAePHiA1q1bIyYmRqPjiyJdXV0xoQPIJ+rUuXbtGjp16oSkpCSYmpriyJEjams/ZS7inNXSr/j4ePF5VkvKlKlRo4Y4Q2n9+vU4ceJEtscgIiIi7dPXlWL4/y8d66hk6RgA7LyWsXRs86UILh0j+gIxGUSkIQMDAwwcOBDbt28Xv7dv374c7TqWedlUTnxeNyglJQWXLl2Sa2vYsCEMDQ3l6gbJEg6a1AuysbGRm/3RsWNHbNiwAevWrQOQMcNnyJAhGsesq6uLnTt3om3btgAyZjx5enoiNjZW4zGKmmrVqonPNZkZdOfOHbRt2xYfPnyAgYEBDhw4kOUss8xFozMXk1Ymc9HozMWksyNzEes9e/bkaAwiIiIqHEpbGGFZnzrYNqQhKtoo/qEo7lMqfv7vDjovD8D1J1w6RvQlYTKIKJvatGkj3ii/e/dO49ktshlFQEadGHVSU1PFca2srBTaP68bFBgYiISEBFhYWKB27doAMpJXmesGvXnzBnfu3AGQdb0gdXx8fPDNN98AAHx9fbO1ZFBfXx/79u1DixYtAAA3btxAu3btclSsuCjITlLw0aNH4mwqWWKtVatWWR6XOeEUGhqqtq+sXVdXF5UqVdI4tsysra3F50+ePMnRGERERFS4NK5YAkfHumFa+6owUbF07Jt/LuKH3bfwhkvHiL4ITAYR5UDmXZs0vaE3MDAQb7BlW6yrEhQUhJSUFAAZW9B/rn79+uLOT2fPnhVn/MjqBclkrhvk7+8v7oKWm3pBAPD777+L55k6dWq2jjUyMoKvry9cXV0BZBSU9vLyUluwuKi6e/eu+FzZTmAyUVFR8PDwwIsXLyCVSrFx40al28grU79+fbFwtLqlaMnJyeJudPXr19e4HtTnMs9wyslSMyIiIiqc9HSkGNrMEacnNkenmsr/3bL7ehRazj+LTVw6RlToMRlElE0JCQniTby5ublY20cTHh4eADKW+wQGBqrst2bNGoVjMtPV1UXjxo0BZCR6ZLWDPp/xk7lukGwGT/HixZUmmLLDyckJPXr0AJCR2Dp58mS2jjc1NcXRo0dRt25dABnX0K1btywLHBclqamp4pI7QPXSvdevX8PDwwMREREAgJUrV6JPnz4an8fMzEycQXTq1CmVS8X27duHuLg4AEDXrl01Hv9zu3fvFp+7uLjkeBwiIiIqnEpZGGJJ79rYPrQRKqlYOjbjvzvwWhqA60/eKhmBiAoDJoOIkFFYt2HDhjh06JDaGkDp6ekYM2YMPnz4AADo1KlTtpb6jBw5ElJpxo/dsGHDxJvvzE6cOIG1a9cCABo0aID69esrHStz3aALFy7IfU+mYcOGMDAwwLt377BlyxYAGUmH3NYsAjJmBMnGUbWtuToWFhY4fvy4mDA4duwYevbsidTUVIW+Z8+eFbe4HzBgQK7izgsDBgwQ48lcxFvGz88P79+/V3l8SkoKhgwZgnv37gEAvLy8lNboef/+Pdq0aYP79+8DAP7++28MHTo02/FOmjQJQEYCatSoUUhLS5Nrf/PmDSZPngwgYxc0ZbWgDhw4kOXuef7+/vjll18AZCQse/fune1YiYiI6MvgWqE4jox1w/QOypeO3X0Rh2/+uYRJXDpGVCjpajsAosIiMDAQXl5eKFOmDLp06QJXV1c4ODjAzMwM79+/R1BQENatW4fg4GAAGcmMX3/9NVvncHFxwcSJEzFv3jzcunULderUweTJk1G7dm3Ex8fj4MGDWLJkCdLS0qCvr49///1X5ViZ6walpqbK1QuSMTQ0RKNGjXDu3DmxUHNu6gVlVr16dXTq1An//fcf/P39ERAQgKZNm2ZrjOLFi+PkyZNwd3fH/fv3ceDAAfTv3x9btmwRk2Z5bc+ePXI1igICApQ+B4BSpUqJBa+zY+PGjejUqRM6deqE5s2bo3LlyjA3N8fHjx9x/fp1rFq1SpxdZmNjg8WLFyuMkZSUhA4dOuDmzZsAgL59+8LDwwMhISEqz2tiYoLy5csrfL9ly5bo1asXduzYAV9fX7Ru3Rrjxo2Dra0tgoODMWfOHDx9+hQAMHfuXLn6VjIHDhxAz5490aFDB7Rq1QrOzs6wtLREUlISHj16hIMHD2LXrl1iMnXGjBmoXLlytl87IiIi+nLo6UgxxM0RXjVt8fuRe/jv5nOFPnuuR+H4nZeY2NoJ3zZygK4O5yMQFQoCUQ5ERkYKAAQAQmRkpLbDybXExEShVKlS4jVl9ahUqZJw7do1pWO5u7sLAAR3d3el7WlpacJ3332ndnwLCwvh+PHjamNOTk4WjI2NxWM6dOigtN/MmTPlxr5586bKMf38/MR+M2fOVHt+QRCEwMBAsb+np6dCu6xN1WshExUVJZQvX17sP2jQICE9PV1pXD4+PlnGpY6Dg4PG77OquH18fMQ+fn5+atvVPVxcXIQ7d+4oPUd4eLjGcWryOickJAjt27dXeaxUKlX7nmt6TUZGRsKCBQvUvANERERUVF169EZovfCs4DD5kNJH20X+wtXwGG2HSfTFyY/7b6ZliZAxg+bZs2e4cOECZs+ejXbt2sHR0REmJibQ0dGBubk5qlSpgp49e2Lbtm0ICQkR691kl1QqxfLly+Hv74++ffuibNmyMDAwgLm5OWrVqoWpU6ciLCwMnp6easfR09MTizADqmf8ZP6+lZUVatSokaO4lalfvz5at24NIGN529WrV3M0TpkyZXDmzBlxqdS6deswZsyYPIuzoE2ePBl///03evTogerVq6NkyZLQ09ODqakpKlSogJ49e2L37t0ICgqS2+0rPxkZGeHw4cPYunUrWrduDRsbG+jr68Pe3h59+vRBQEAAZs2apfL4v/76C+vWrcOAAQNQr1492Nvbw8DAAEZGRihTpgw8PT3x559/4tGjR5gwYUKBXBMREREVLo0ci+Pw9xlLx0wNFBeh3HsRB++VlzBh101Ef+DSMSJtkgiCwDLvlG1RUVHijXtkZCTs7Oy0HBERERERERUWr+M+4Y+jodgf9Expu5mBLiZ4OqEfl44RZSk/7r/5U0dERERERER5ysbcEH/3rIWdwxqhSikzhfYPSamYffAuOi4NQGA4dx0jKmhMBhEREREREVG+aOhYHIfGNMWMjtVgpmTpWOjLD+jx7yVM2HkTrz980kKERF8nJoOIiIiIiIgo3+jqSDGoaXmcnuSObrXLKO2zL+gZWs0/h7UB4UhNSy/gCIm+PkwGERERERERUb6zMTPEwp61sGu4q8qlY78eylg6duVxjBYiJPp6MBlEREREREREBaZBeSscGtMUM71ULx3rueoyxu0Iwus4Lh0jyg9MBhEREREREVGB0tWRYmCT8jgzqTm+qaN8Z6QDN5+j5YJzWBcQjrR0boJNlJeYDCIiIiIiIiKtsDYzwIIeNbFnhCuqljZXaP+YlIpfDt3FN/9cxINXH7QQIVHRxGQQERERERERaVW9clY4OLoJZndyhpmh4tKxm5Hv0WHJeSw5HYbkVBaYJsotJoOIiIiIiIhI63R1pPBpXA5nJjaHd13FpWMpaQIWnnyATssCcCvyfcEHSFSEMBlEREREREREhYa1mQHmd6+JncMaoXwJE4X20Jcf0HXFBfxx5B4Sk9O0ECHRl4/JIKLPxMfHY+XKlWjfvj3KlCkDQ0NDGBgYwNraGvXr18egQYOwevVqREZGKj1+wIABkEgkCg+pVApLS0vUrFkTo0aNws2bN/Ml/oiICKXnl0gkMDQ0hK2tLTw9PbF48WLExcVlOV65cuUgkUhQrly5LPtOmDBBPFelSpVUvka5FRkZib179+Knn35Cy5YtYWFhIZ531qxZeX6+1NRUrFy5Em5ubrC2toaRkREqVKiA4cOH486dOxqP8+bNG8yYMQM1atSAubk5zM3NUaNGDcyYMQMxMZpvnxoSEoLhw4ejQoUKMDIygrW1Ndzc3LBy5Uqkpqbm5BKRkJAAR0dH8XXU5P0mIiIiyk8NHYvj6Fg3jHCvAB2pRK4tXQD+9X+Mdov9cZnb0BNlm0QQBJZlp2yLioqCvb09gIwbczs75TsAfGkuXbqEXr164enTp1n2LVmyJF6+fKnw/QEDBmDjxo1ZHi+VSvHTTz9hzpw5OYpVlYiICJQvX16jvvb29jhw4ADq1Kmjsk+5cuXw5MkTODg4ICIiQmkfQRDw/fffY9myZQCAKlWq4PTp07C1tc12/Fl58uSJ2kTFzJkz8zQh9ObNG7Rv3x5Xr15V2m5gYIBly5ZhyJAhase5cuUKunTpovQzAwClS5fGgQMH0KBBA7XjrF69GqNHj0ZycrLS9gYNGuDw4cMoUaKE2nE+N2nSJCxYsED8Wt37TURERFTQgqNi8ePe27j3QvkfM/s2LIuf2lWBmaFeAUdGlP/y4/6bM4OI/t+DBw/Qpk0bMRHUqVMnbNq0CZcvX8aNGzdw4sQJzJs3D56entDT0+x/MsePH0dwcDCCg4Nx69YtnDhxAmPHjoWuri7S09Px+++/Y8WKFfl2TZ07dxbPHxwcDH9/f6xatQpVq1YFkPGLpEOHDhrNEFJFEASMGDFCTAQ5Ozvj7Nmz+ZIIkp1PRiKRoGLFimjWrFm+nCstLQ1du3YVE0HdunXD0aNHceXKFSxZsgQ2NjZISkrC8OHDcfToUZXjREZGwsvLCy9fvoSuri5+/PFH+Pv7w9/fHz/++CN0dXXx4sULeHl5ISoqSuU4R44cwYgRI5CcnIySJUtiyZIluHLlCo4ePYpu3boBAAIDA9G1a1ekpWk+ZTooKAiLFi2CoaEhzMzMND6OiIiIqKC42FnAd3QTTPJ0gr6O4m3s1itP4fm3P86EvtJCdERfIIEoByIjIwUAAgAhMjJS2+HkCW9vb/Ga1q9fr7bv69evhWXLlilt8/HxEccJDw9X2sfX11fsY21tLaSmpuYy+v8JDw8Xx/bx8VHaJzk5WWjUqJHYb968eSrHc3BwEAAIDg4OCm1paWnCwIEDxXFq1qwpREdH59GVKPfmzRvht99+E06cOCG8fftWEARB8PPzE2OYOXNmnp1r7dq14rjfffedQntYWJhgbm4uABAqVqwopKSkKB2nX79+4ji7du1SaN+5c6dG75mjo6MAQDA3NxcePnyo0Oe7777T+DMsk5qaKtStW1cAIPzyyy9q328iIiKiwiDsVZzQdXmA4DD5kNLH2O03hJiPSdoOkyjP5Mf9N2cGESFjBsjhw4cBAPXq1cOAAQPU9re2tsaoUaNyfD4vLy+4ubkBAKKjo3Hjxo0cj5UTenp6+O2338SvT506le0x0tLS4OPjg/Xr1wMA6tatCz8/v2wvT8qu4sWLY9q0aWjdujWKFSuWr+eaP38+AMDKygrz5s1TaK9YsSKmTJkCAHj48CH279+v0Ofly5fYunUrAKBNmzbo3r27Qp8ePXqgTZs2AIDNmzcrXUq2f/9+PH78GAAwZcoUVKhQQaHPvHnzxNdEWbzKLF68GNevX0flypUxefJkjY4hIiIi0qaKNmbYPaIxZnpVg5GejkL7gZvP4bHwHHxvPZebVU5E/8NkEBEyEjKJiYkAMm7wC0Lm2jBPnjwRnz9+/BgLFiyAl5cXypUrByMjIxgZGcHBwQE9e/bEsWPH8uT8Li4u4vPsFnpOTU1F3759sWXLFgBAo0aNcPr06XxPzhSkBw8e4N69ewAykjXGxsZK+2VOHCpLBvn6+iI9PR0AMHDgQJXnk42Tnp4OX19fhfYDBw4oPWdmxsbG6NGjBwDg7t27ePDggcrzARmfuxkzZgAAVq5cCX19fbX9iYiIiAoLHakEA5uUx4nxzdC0ouIfI9/GJ+P77UEYuuk6XsZ+0kKERIUbk0FEgNxNsCwBkN8y1x2S1XcJDw9HhQoVMGnSJBw6dAhPnjzBp0+f8OnTJzx9+hS7du1Cu3bt0K9fvxzvGiWT+Zo1rYEEACkpKejZsyd27twJAGjatClOnDgBCwsLtcdl3uWsefPmOYq5IAUEBIjP3d3dVfYrVaoUnJycAAAXLlzI8TiZ29SNU7lyZZQqVSrH42T23XffIT4+Hv369fsi3hMiIiKiz9lbGWPz4Ab4y7sGzA11FdpP3XuF1gvPYXvgU84SIsqEySAiZCwDcnBwAADcunULc+fOFWdz5Jfg4GDxuazYclpaGvT19eHl5YUlS5bg1KlTuHHjBk6dOoUVK1bA2dkZALBlyxb8+uuvuTp/5qSXptuIJycnw9vbG/v27QMAtGjRAseOHSuSRYfv3r0rPq9SpYravrL2yMhIxMfHKx3HwsJCbRKndOnSMDc3B6CYkPz48aM4e0vTWJSNk9mOHTtw5MgRFCtWTG4XMSIiIqIvjUQiQY969jg1wR1tnEsqtH9ISsWUfcHos/oKnsTEKxmB6OvDZBDR/xszZoz4/KeffkKFChUwduxY7Ny5E+Hh4Xl6rlu3bonLvYyNjVG/fn0AGQmBiIgI+Pr6YsyYMWjVqhVq166NVq1aYeTIkQgODhaXCC1YsACxsbE5juGPP/4Qn3t7e2fZPyUlBV27dhWXMLVu3RqHDx+GiYlJjmMozDLv6pXV1o2ybR4FQVDYDUz2tSbbP2beLjK3sSgbR+bdu3cYN24cAODPP/+EtbV1lrERERERFXY25oZY+W1drOhbByVMFZe/X3ocgzaL/LHm/GOkpXOWEH3dmAwi+n/jx4/HoEGDxK8jIiKwZMkS9OrVC46OjihVqhR69eqFgwcP5miKqSAIePnyJdasWQMPDw9xadj3338PQ0NDAICJiQlKly6tcgyJRIIFCxZAR0cH8fHx2S78/P79e1y6dAmdOnXCwYMHAQCurq7o2bNnlsc+f/4cR44cAZCxFMnX1xdGRkbZOv+X5MOHD+JzU1NTtX0zJ8Q+fvyodJysxsg8jqoxchuLzA8//IBXr17B1dUVQ4cOzTIuIiIioi+FRCJBe5fSODneHd3qlFFo/5SSjt8O30O3fy7i/ssPSkYg+jowGUT0/6RSKdauXYsTJ06gbdu20NWVX3P86tUr7Ny5E506dUKDBg3w6NGjLMcsX768WCdHKpWidOnSGDp0KN68eQMA6NChA3755ReVx6ekpCAqKgr37t1DSEgIQkJC8Pz5cxQvXhxAxgwjdTZu3CieXyKRoFixYmjcuDEOHjwIPT09DBgwAMeOHdOoZpBEIhGfBwcHZ1mc+HPlypWDIAgQBAFnz57N1rHa8OnT/woNZlVY2cDAQHwuK0T++TiaFGeWjaNqjNzGAgD+/v5Yt24ddHV1sXLlSrn3lYiIiKioKGaij4U9amHDwPooY6n4B8xbke/Rcel5LDr1AMmp+VsegqgwYjKI6DOtW7fG0aNHERMTgyNHjmD27Nnw8vKSK5B87do1uLm54cWLF9keX19fH02aNMHGjRvFpExmKSkpWL58ORo1agRTU1PY29ujWrVqcHFxER+vX78GADGplBOVKlXC+PHjxTo1WSlbtix++OEHAMDbt2/RunVrhIaG5vj8hZ1sthaQUStJnaSkJPH557OlZONkNUbmcVSNkdtYkpKSMGzYMAiCgLFjx6JGjRpZxkRERET0JWte2QbHxzdDf1cHhbaUNAGLToXBa2kAbkW+L/jgiLRIsdw6EQEAzM3N0a5dO7Rr1w5Axo30tm3bMHHiRLx79w4vXrzAzz//jDVr1qgc4/jx42JxaKlUClNTU5QqVUrl7I63b9/C09MT169f1yhGZTM/MuvcuTN+++03ABlblj9//hzHjh3Dv//+i7t376J58+a4dOkSKleurNH5/vrrLyQmJmLZsmV4/fo1PDw84O/vD0dHR42O/5JkLor98eNHuYTM5zIXjf58GZeZmRkSEhJULtlSNo6yMTLHoskYysaZM2cO7t+/D3t7e8yePTvLeIiIiIiKAlMDXfzSuTo61rDFT3tv4/Eb+SLS9199QNcVFzC4aXlMaF0ZRvo6WoqUqOAwGUSkIQMDAwwcOBC2trZo27YtAGDfvn1YtWoVpFLlk+ycnJw03qkLAMaOHSsmgrp06YJBgwahRo0asLGxgaGhobikp2zZsoiMjMyydpGlpSWqV68ufl2jRg20bdsWXl5eaNu2Ld69e4c+ffogMDAQOjqa/U9vyZIlSEhIwLp16/Ds2TO0atUK/v7+coWLi4LMhZqjoqJQokQJlX1lhZolEolCgWc7Ozu8evVKobC0unE+fy3LlPnfevesxslcNPrzcebOnQsA8PDwEGtGfU6WTIqPj8eOHTsAADY2NmjZsmWW8RMREREVZg3KW+HIWDcsPh2GVf7yRaTTBWD1+XCcuPsKf3arAdcKxbUYKVH+YzKIKJvatGkDe3t7REZG4t27d4iJicmT3Zji4uKwc+dOAEDfvn2xZcsWlX3fvXuXq3O1atUKY8eOxYIFC3Djxg1s2LABgwcP1uhYiUSC1atX49OnT9i2bRsiIiLEhJC6rdO/NNWqVROfh4aGolatWir7ypbL2dvbK+yuVq1aNVy/fh2xsbF4+fKlytfoxYsXiIuLAwBUrVpVrs3MzEz8zGW1NC9z++fjyJaYrV+/HuvXr1c7zps3b9C7d28AGQXDmQwiIiKiosBQTweT21ZBB5fS+HHPbdx9ESfX/iQmAb1XX0bvBmUxpX0VmBtmXVuT6EvEmkFEOSBb+gUgzwrwhoWFISUlBQDU7u4VGhqq0ZKjrEydOlWsFzR79myNatrISKVSbNy4Ed26dQOQEbuHhwdiYmJyHVdh0bRpU/H5uXPnVPZ7+fKlWEy7SZMmOR4nc5u6ce7fv4+XL1/meBwiIiIiAqqXscB/o5vghzaVoa+jeFu8PfApWi88h1N3X2khOqL8x2QQUTYlJCTg7t27ADLqCsl29sqt1NRU8Xnmui+fW7lyZZ6cz8rKCqNGjQKQsbRo48aN2TpeV1cX27dvF2sq3blzB56enoiNjc2T+LTNyclJnFmza9cuJCQkKO23YcMG8XnXrl0V2jt16iQuI1Q3G0c2jlQqRadOnRTau3TpovScmSUkJGDXrl0AMmYkOTk5ybXLdnNT93BwyCiu6ODg8EXt/kZERESUXXo6UoxqURFHxrqhrkMxhfZXcUkYsukavt8ehJiPSUpGIPpyMRlEhIyivA0bNsShQ4eQnq56a8n09HSMGTMGHz58AJBxo59XM4MqVqwojrVx40al9YAOHjyIZcuW5cn5AGD8+PEwNjYGAPz5559IS0vL1vH6+vrYt2+fuIToxo0baNu2rdKZSxEREeIW982bN8917Lk1a9YsMR5VyZVJkyYByCjs/eOPPyq0P3r0CH/88QeAjPdPWTKoVKlS6Nu3L4CMguJ79uxR6LN7924cP34cANCvXz+lS8m6du0qFur+448/8OjRI4U+P/zwg7iEULbzGxERERGpV9HGFLuHu2J2J2cYKyke7XvrOTwWnsN/N59lWbOT6EvBmkFE/y8wMBBeXl4oU6YMunTpAldXVzg4OMDMzAzv379HUFAQ1q1bh+DgYACAhYUFfv311zw7f/HixdG+fXscPnwYx44dg6enJ0aOHAkHBwe8fv0ae/fuxYYNG+Do6Ij3798jOjo61+e0trbG0KFDsXjxYjx+/Bjbtm1Dv379sjWGoaEhfH190aZNG1y4cAGXL19Gx44dcfToUYWtzfPKsWPH5JZKZa6Tc/PmTbnkjqmpKby9vXN0Hh8fH6xbtw4XLlzA8uXL8fLlSwwdOhTFihVDYGAgfv31V8TFxUEqlWLJkiXQ1VX+K3XOnDk4duwYoqOj0bt3b1y7dg0dO3YEABw6dAgLFiwAkPF+yHZ/+5yenh6WLl0KLy8vxMXFoUmTJpg+fToaNGiAd+/eYfXq1di7dy+AjCVl2X0fiYiIiL5mUqkEPo3LoWUVG0zdH4zzYW/k2t8lpGDsjpvwvfkcv3WtjtIW+fPvXKICIxDlQGRkpABAACBERkZqO5xcS0xMFEqVKiVeU1aPSpUqCdeuXVM6lo+Pj9gvPDw8W3E8ffpUKFu2rMrzli1bVrhz547g4OAgABB8fHwUxggPDxf7K2v/XGRkpKCvry8AEKpWrSqkpaXJtcvO5eDgoHac2NhYoV69euK527RpIyQlJSmNy93dXYNXQzV3d3eN3ytVcc+cOVPss379epXnio6OFurXr69yfAMDA2H16tVZxnz58mW1n7FSpUoJly9fznKcVatWie+XskeDBg2E6OjoLMdRRdP3m4iIiKioSk9PF3ZfixRqzDouOEw+pPBwnnFM2HI5QkhLS9d2qPSVyI/7by4TI0LG7JZnz57hwoULmD17Ntq1awdHR0eYmJhAR0cH5ubmqFKlCnr27Ilt27YhJCQEdevWzfM47O3tcePGDfzwww9wcnKCgYEBLCwsULNmTcycORM3b96U2+UqL9jZ2cHHxwcAcO/ePXF2SXaZm5vj+PHjqFGjBoCMJVE9e/aUq4X0JSpRogQuXryIFStWoGnTpihevDgMDQ3h6OiIoUOH4vr16xgyZEiW4zRs2BDBwcGYPn06qlevDlNTU5iamsLFxQXTp09HSEgIGjZsmOU4snMOHToUjo6OMDQ0RPHixdG0aVP8888/uHDhAkqUKJEXl05ERET0VZJIJPCua4eTE5qhXXXF5fsfk1IxbX8I+qy5jIg3qmt9EhVmEkHgokfKvqioKNjb2wPIKD5sZ2en5YiIiIiIiIjy3tHgF/j5vzt4o6SItIGuFBM9nTCoSXnoKtmVjCgv5Mf9Nz+tRERERERERCq0cymNUxOawbuu4g14Umo6fj8Sim/+uYjQl3FaiI4oZ5gMIiIiIiIiIlLD0lgf87vXxKZBDVDGUrF49K2oWHRcEoCFJx8gKTV7O/QSaQOTQUREREREREQaaOZkjRPjm2FA43KQSOTbUtMFLDkdBq+lAQh6+k47ARJpiMkgIiIiIiIiIg2ZGOhiVidn7B7uCkdrE4X2B68+ots/F/HrobtISP6yN1OhoovJICIiIiIiIqJsqlfOCke+d8OoFhWgI5WfJiQIwNqAcLRddB4XH77RUoREqjEZRERERERERJQDhno6+KFNFfiObgJnW3OF9qdvE9BnzRX8tPc2YhNTtBAhkXJMBhERERERERHlgrOtBQ6MaoIf21aGvq7ibfaOq5Hw/PscTt59pYXoiBQxGVQEbNiwARKJRO2jVatW2g6TiIiIiIioyNLTkeK75hVxdKwb6pcrptD+Ki4JQzddw+htN/DmY5IWIiT6H11tB0C5V6tWLcycOVNp2549e3Dnzh20adOmgKMiIiIiIiL6+lSwNsXOYa7YcuUJ5h4NRXyy/Fbzh26/wIWHbzDTyxmda9lC8vm2ZEQFQCIIgqDtICh/JCcnw9bWFrGxsYiKikLJkiXzbOyoqCjY29sDACIjI2FnZ5dnYxMRERERERUFUe8SMHV/CPwfRCttb1nFBr91qQ5bS6MCjoy+JPlx/81lYkXYgQMHEBMTg44dO+ZpIoioIJUrVw4SiQQDBgzIt3MMGDAAEokE5cqVy7dzEBEREdHXx66YMTYOrI8F3WvCwkhPof1M6Gt4/u2PLZefID2d8zSo4Hz1yaDXr1/j0KFDmDFjBtq1a4cSJUqIdXaye/P55MkTTJw4EVWqVIGJiQmsrKxQv359zJs3DwkJCflzAWqsWbMGADBkyJACP3dRkJqair1792LYsGFwcXGBjY0N9PT0YGFhgYoVK6Jr166YN28ewsPDtR3qV08QBPj6+qJ3796oVKkSTE1NoaurC0tLS1SvXh3du3fHvHnzcOvWrQKPrUWLFuLvFE9PT42Pa968udL6Xzo6OrCyskLdunUxduxY3LlzJ8uxZs2aJR5/9uxZtX0DAgJgbm4OiUQCXV1dbNmyReOYs+Pjx4/w9/fH/Pnz0aNHD5QvX16MMb+SchcvXsS3334LBwcHGBoaolSpUmjTpg22b9+e5bGpqakICgrCv//+iyFDhqBGjRrQ1dUVY46IiMhyjKSkJOzfvx9TpkyBh4cHnJycYGVlBT09PRQvXhyNGzfGjBkzEBUVleNr7Nmzp9znRZO4iIiIKH9JJBJ8U9cOpya4o71LKYX2j0mpmH4gBL1XX0b4m3gtREhfJeErB0Dlw8fHR+NxfH19BXNzc5VjOTk5CWFhYfl3IZ+JiIgQpFKpYGdnJ6Smpub5+JGRkeK1RUZG5vn42vbff/8JFStWVPv5yPzo0KGDEBwcrO2wiyQHBwe1P48vX74UmjZtqvF7de/ePYUxfHx8BACCg4NDnsYeEREhSCQS8dxSqVR49uyZRse6u7trdD06OjrCH3/8oXasmTNniv39/PxU9vPz8xNMTEwEAIKurq6wc+fO7FxutjRv3lzlNeX1+yAIGa+BVCpV+zOcmJio8vhZs2apfR/Cw8OzjCEsLEyj99TExETYsGFDtq/x4MGDOYqLiIiICtbR4OdCvd9OCg6TDyk8nKYdEVaefSikpKZpO0wqRPLj/vurnxmUWdmyZbP1l3uZoKAg9OzZE3FxcTA1NcWcOXNw8eJFnD59GkOHDgUAPHjwAB06dMCHDx/yOmyl1q9fj/T0dAwYMAA6OjoFcs6i4rfffkOXLl3w8OFDABkzNObPn48TJ07g+vXrOH/+PHbt2oXRo0eLMxgOHz6MZcuWaTHqr1NycjJat26NgIAAAEDt2rWxZMkS+Pv7IygoCOfOncPKlSvRp08fWFhYFHh8mzdvhiAIMDAwgK6uLtLT03M00yY4OFh8XL9+Hbt370bfvn0BAGlpaZgyZQp2796dq1hPnTqF9u3bIz4+Hnp6eti1axd69OiRqzHVETKVq7OysoKnpydMTU3z5Vz//vsvZs+ejfT0dFSoUAFr165FYGAgDhw4gBYtWgDI+BkeNGiQRvEaGhqiUaNGqFChQrZjsbGxQc+ePTF//nzs3r0bAQEBuHLlCvbt24fBgwfD0NAQ8fHxGDhwII4cOaLxuB8/fsSoUaPEcxAREVHh1bZ6aZwa747udRXrviSlpuOPo6Ho9s9F3HsRp4Xo6Gvx1e8mNmPGDNSvXx/169dHyZIlERERgfLly2drjLFjxyIxMRG6uro4ceIEXF1dxbaWLVuiUqVK+PHHH/HgwQMsWLAAs2bNUhhj4sSJSErSfHvBsWPHolKlSkrb0tPTsX79ekgkErU3N6Ro3bp1+PnnnwEAJUuWxI4dO9C8eXOlfbt3745FixZhx44dmDp1agFGSTKrV69GcHAwAGDgwIFYs2YNpFL5HHezZs0wfPhwJCUlYfv27bC0tCyw+DZv3gwA6NixIxITE3HkyBFs3rwZP/74Y7bGqV69utzXderUgbe3Nxo2bIjvv/8eADB79mx07949R3EePXoU3bp1w6dPn2BgYIA9e/agY8eOORpLU3369MHw4cNRv359VKxYEUBGfaiPHz/m6Xnevn2LyZMnA8hI+F++fBklSpQQ2zt27IiuXbvi4MGD2L59O4YNG6b0Z97V1RUrV65EgwYN4OLiAl1dXQwYMACPHj3SOBZHR0e8fPlS5Y4hXbt2xbBhw9C0aVOkpKRg+vTpaN++vUZjT58+HU+fPkWrVq1gZ2eHjRs3ahwXERERFTwLYz3M614TnWrZYsq+YES9S5Rrvx0VC6+lAfiueQWMalkRBrr8Az/lsTyZX1SEhIeHZ2uZ2JUrV8T+w4cPV9onLS1NqFq1qgBAsLS0FJKTkxX6yJZmaPpQt9Tj+PHjAgChVatWml52thXFZWJPnz4VDA0NBQCCubl5tpb1vXv3TvD19c3H6L5e6paJtW7dWlzS9O7duxyfIz+WiV26dEn8Gdm3b5+wbds28evr169neXzmZWKqpKWlCWXLlhX7vXjxQmk/dcvEfH19BX19fQGAYGRkJBw7dixb15mXZO91Xr4Pc+fOFa99+/btSvtERkYKOjo6AgChffv2Go8t+9wgj5djtWnTRhz3w4cPWfa/evWqIJVKBQMDA+H+/fv5FhcRERHlj4+fUoSZ/4UI5X5SXDbmMPmQ4LHgrHD9yVtth0laxGVihdCBAwfE5wMHDlTaRyqVon///gCA9+/fw8/PT6HPx48fIQiCxg9Vs1UAFo7OqYULF+LTp08AgDlz5oizFTRhaWkJLy8vtX1evnyJadOmoV69erCysoKBgQHs7e3Ro0cPnDp1SuVxERERYjHYDRs2AABOnjwJLy8vlCpVCgYGBihfvjxGjhypceFZPz8/+Pj4wNHREcbGxjA3N4eLiwt++OEHPH/+XOVxmQsRA0BsbCx+/fVX1K5dG5aWlnIxAkB8fDx27tyJIUOGoFatWrCwsICenh6sra3h7u6O+fPn52omyNOnTwEAJUqUyNMZP+/fv8eMGTPg7OwMExMTWFpaolmzZti6davGY2zatAkAUKxYMXTo0AFdunSBmZmZXFtuSaVSODs7i19HRkZm6/h9+/bhm2++QXJyMoyNjXHo0CG0adMmT2IrLGS/o83NzdGtWzelfezs7ODh4QEAOH36dIEt51VF9jkBkOWM0dTUVAwdOhTp6en46aef4OTklN/hERERUR4zMdDFrE7O2DPCFRWsTRTaw15/xDf/XMQvB+8iITlVCxFSUcRkUC7JapWYmJigbt26Kvu5u7uLzy9cuJBv8cTExOC///6DlZUVunbtmm/nKWoEQRCX9JiZmalM7OXU1q1bUbFiRfz++++4fv063r17h+TkZERFRWH37t1o3bo1hgwZgtTUrH+5T5kyBZ6enjh06BBevXqF5ORkREREYOXKlahTpw7u3bun8thPnz6hd+/eaNmyJTZt2oTw8HAkJibiw4cPCAkJwfz58+Hk5ISDBw9mGUdYWBhq1aqFGTNm4ObNm4iNjVXo06FDB/Tq1Qtr167FrVu3EBcXh9TUVLx58wb+/v744YcfUKNGDYSGhmZ5PmX09fUBAK9evcLbt29zNMbn7t+/j9q1a+PXX3/F3bt3kZCQgNjYWJw/fx7ffvstRo8eneUYycnJ2LlzJ4CM5YT6+vowMjISkxHbt2/X6L3WhOw1AAA9PcXtSlXZuXMnevbsiZSUFJiamuLYsWNo2bJllsdl3uWssO9UlZycjMDAQAAZy7wyv1afk/2OTkpKwrVr1wokPmWio6Nx+vRpABlJzuLFi6vtv3DhQty8eROVKlXClClTCiJEIiIiyid1Haxw+Hs3jG5REbpS+WXlggCsuxCONov8ERD2RksRUlHCZFAuyW68K1asCF1d1SWYqlSponBMfti8eTOSk5Px7bffwsDAIN/OU9SEhIQgJiYGAODm5gYTE8WMfE7t2rUL/fr1Q3x8PBwdHbFw4UIcO3YM169fx969e8WaIGvXrs2ylszq1avx559/wt3dHdu2bcO1a9dw6tQpceZZdHS0yjpRgiDA29sbO3bsAAB4eXlh8+bNuHDhAi5duoTFixejbNmyiI+Ph7e3d5Y3xN7e3nj27BnGjBmDkydP4tq1a9i+fTsqV64s9klNTYWLiwumTZuG/fv348qVK7h8+TJ27tyJXr16QSqVIjw8HF26dBFnZWVHnTp1xGsbOnRoruvNJCQkwMvLCzExMZg+fTrOnj2La9euYfXq1bCzyyjwt3z5chw/flztOIcOHRKTU99++634fdnz169f49ixY7mKVSbz7xMHBweNjtm6dSv69u2L1NRUmJub48SJE3Bzc8uTeAqTBw8eIC0tDYD872BlCup3tDJJSUkIDw/H6tWr4erqinfv3gEAxo0bp/a48PBwzJ49GwCwYsUK/s4nIiIqAgz1dDCpTWX8N7oJqpcxV2iPfJuIb9dewY97biE2MUULEVKRkSeLzYqQ7NQMSkxMlNuWOCuyukCNGjXKo2gVVa9eXQAg3L59O1fjREZGqn0EBgYWqZpBW7ZsEa9n+vTpeTZudHS0YGFhIQAQBg0aJKSkpCjtN3XqVHHr8dDQULm2zJ9JAMLQoUOF9PR0hTGGDBki9rlx44ZC+6pVqwQAgp6ennD06FGlcbx9+1ZwdnYWAAhNmjRRaM9ce0YqlQrHjx9Xe/0PHjxQ237y5Elxu+81a9Yo7aOuZtCVK1fktgu3tLQU+vXrJ6xatUq4deuWkJqaqvb8MplrrFhYWAghISEKfcLCwsSaUp06dVI7XufOnQUAQrly5eTeq7S0NMHW1lYAIHTv3l3tGJrUDNq7d6/YR12NsMzv28CBA8XXrFixYkJgYKDaONTFldf1aPK6ZtDRo0fFWOfNm6e279WrV8W+P/30k0bj56Y2j5+fn9q6cP379xeSkpLUjuHp6SkAEHr37p1ncREREVHhkZKaJqzweyhUmnZEaS2h+r+dFI6FKK8ZSUULawYVMpnrSmiyJbJstkle75YjExgYiJCQEHG3m9ywt7dX+2jQoEEeRV04vHnzv6mW1tbWKvulp6cjJCRE5SMlRT47/88//yA2NhZlypTBihUrVM4emz17NsqUKYP09HS19WRKly6NpUuXKt2NaNKkSeLz8+fPy7UJgoC5c+cCAL7//nu0bdtW6fjFihXDvHnzAGQsZwwLC1MZy4ABA+Dp6amyHYDKHe9kPDw80KlTJwDy9bc01aBBA/z777/i8qj3799j8+bNGDZsGGrWrAkLCwt4enpi9erViI+P12jMX3/9Va4Oj0zFihXRpUsXAP9bHqpMTEyMuCV4nz595N4rqVSKPn36AAAOHjyI9+/faxRTZsnJyQgNDcUff/yBfv36AQCMjY0xZ84cjY5fv3490tPTYWRkhNOnT6N+/frZjuFLkZ3f0ZlnA+bX72hNlCtXDidOnMDGjRvVLmvbsmULTpw4AQsLC/z9998FGCEREREVFF0dKUY2r4BjY93QoJyVQvvrD0kYvvk6Rm29gegPmu9MTQRwmViuZF7Wou4f7TKyKfyJiYlZ9MyZBg0aQBAEXLlyJV/GL8oy3zSqWyIWFxcHFxcXlY9nz57J9ff19QWQsX21uiUcurq6cHV1BQBcunRJZT9vb2+V41SuXFm84X38+LFc2927d8UtsL29vVWOD2RsxS6jLpa+ffuqHUeZ6OhohIWFySXQZMm3W7duZXs8IKNQenBwMAYOHChXeBfIKGB98uRJDBs2DJUqVcpyaZZEIhGTNcrI6oK9fftWZSJn+/btYlIw8xIxGdn3Pn36hN27d6uNJ3NcsoeBgQGqVq2KqVOnIiEhAXXq1MGJEyfQsGFDjccCMn4PHT58WKNjMjt79qxYyL5cuXLZPr4gZed3dOafq/z6HZ1Z/fr1ERwcjODgYFy7dg379u3DgAEDEBkZCR8fH6xdu1blsTExMZgwYQIA4Pfff0fJkiXzPV4iIiLSHkdrU+wY1gi/dnaGib7iFvOHg1+g9d/nsO9GFARB0EKE9CViMigXDA0NxefJyclZ9pftCmNkZJRvMeWVyMhItQ9ZUdaiInMSQdMZJFlJS0vDzZs3AQD//vuv3A29sseePXsAZOw6pkpWdU+KFSsGAAq7IWWu/+Pq6qo2jswzKNTFUqNGDbWxyFy4cAE9e/ZE8eLFYWNjAycnJ7kE2urVqwHIz87KrsqVK2PdunWIiYnBxYsXsXDhQvTt21es8wMAL168QMeOHdXu3JZVwV4rq//9RUbVjlMbN24EkFHPqGrVqgrtNWvWRPXq1QHkflcxfX19DB48GE2aNNH4mN9//138vP/8889FelZJdn5HZ961qyB+R5uYmKB69eqoXr066tati65du2L9+vU4fvw43r59iyFDhuCXX35ReuzEiRMRHR2NBg0aYMSIEfkeKxEREWmfVCpBP9dyODHBHe5OiisZ3iekYMKuWxi44Sqevc//P2zRl4/JoFzInEDQZFmBLMmgyZIybbOzs1P7KF26tLZDzFOZEwDR0dEq+1laWoqzImQPHx8fpX3fvn2box2jEhISVLYZGxurPVYqzfiRlhXNlXn9+nW248gqFlniSZ1Zs2ahadOm2LVrV5a7feXFbAw9PT24urpi/Pjx2LJlCyIjI3H69Glx2VdaWhq+++47lX8x0fT1lY31uXv37omJN2WzgmRky7suXLiA8PBw9RcFiDNIgoOD4e/vj2XLlqFChQpITk7GqFGjxKV9mmjUqBEOHTokXuuECROwcuVKjY//kmTnd3TmJLA2f0e3atUKY8eOBZCxfPTznfbOnDmDjRs3QkdHBytXrpT7TBIREVHRV8bSCBsG1sfCHjVhaay4k+zZ+9HwXHgOmy9FID2ds4RINdXbX1GWDA0NUbx4ccTExCAqKkpt33fv3ok3G/b29gURHmVDzZo1xedBQUF5MmbmZMGQIUPEG7ysaLLkMDexHDx4UOPlPTY2NirbdHQUp6hmdvr0aXGnI0dHR0yaNAlNmzZF2bJlYWJiItZPmjFjBn799VeN4smJli1b4uTJk6hevTrevn2LsLAw3Lx5E7Vr187zc2We6TNhwgRxKY8qgiBg06ZNmDlzptp+splEMm5ubujfvz+aNm2K27dvY+rUqWjevLnG9X+aNWuGAwcOwMvLC0lJSfjuu+9gbGws7kpXVGSeGZbV7+jIyEjxubZ/R3fu3Bl//fUX0tPTsW/fPkydOlVsk9X+qlevHu7fv4/79+8rHJ85wXjw4EFxKWavXr3yOXIiIiIqCBKJBN3q2MGtkjVmHbyDw7dfyLXHJ6fh5//u4OCtF/jzGxc4Whf+yQhU8JgMyqVq1arh/PnzePjwIVJTU1UWCM78111lS0dIu6pXry4m9s6fP4+EhIQsZ4lkJfOSIkEQFG7oC1LmmU+WlpYFEots+VexYsVw+fJllYW5s5oxlBdKly6NDh06YPPmzQCAhw8f5nkyKD09HVu3bs32cZs3b84yGaSMmZkZNm3ahDp16iA1NRUTJ06Ev7+/xse3bt0au3fvxjfffIOUlBQMGjQIhoaG6NGjR7ZjKaycnJygo6ODtLQ0hRk2nytMv6Mz/6w8efJErk22nO3KlSvo3bt3lmN9//334nMmg4iIiIoWazMDLO9TB51qvsTPB0Lw+rMi0oERb9F28XmM93DCULfy0NXhjGL6H34acqlp06YAMpYYXL9+XWW/c+fOic+zU9+DCoZEIhGX9cTFxYl1X3JDX19fXJ504cKFXI+XG5kTHwUVy507dwAALVq0ULtDW+Z6RvnJ1tZWfK5sN7bc8vPzE2eXjBkzBtu3b1f7GDduHADg0aNHOX5PatasKRa8Pn/+fJYFsj/n5eWFrVu3igmTb7/9FgcPHsxRLIWRvr6+uPPhpUuX1NYNkv2ONjAwQL169QokPlUyF6L/EpYVExERkXa1cS6FkxPc0bOe4uzm5NR0zD0Wii4rLuDu8zgtREeFFZNBuSTbahrI2LJZmczbhVtaWqJFixYFERpl04QJE8SCs1OmTNGolktWZNumh4aG4vjx47keL6fq1KkjLplZtWqV3C5L+UVWL0ldQe6goKBc7X6Xnd0SMiedHB0dc3xOVWQ/4zo6Opg+fTp69eql9jFt2jRxJmFuCklPmzZNrBvz22+/Zfv47t27Y926dZBIJEhJSUH37t1x8uTJHMdT2Mh+R8fFxWHfvn1K+0RFRYmFxVu1aqWwK11By7zLnIuLi1xb5t3cVD0y1zELDw8Xv09ERERFl4WRHuZ618DWIQ1hb6W4GUbIszh0WhaA+cfv41OKYu1L+vowGZRLDRo0gJubGwBg7dq1SrfiXrBgAe7duwcAGDt2LPT0FAt9kfaVLVsWS5YsAQDExsaiadOmCAgIUHuMIAgqtxgHMt5v2V/2Bw4cKM6WUeXw4cO4fft29gLXgFQqFeuOPH78GP3795fbPelzcXFxWLZsWa7OWalSJQBAQEAAHj58qNAeHR0tFlLOqW7dumHFihVZ7gC3YcMGnD59GkDG+5zXS8Ti4+PFRIObm5vaWksyJUqUgLu7OwBg165dat8PdapUqYJu3boByJj15efnl+0x+vfvj3/++QdAxjKkLl26qFxy1rx5c3HnuYiIiBzFnFciIiLEWJo3b660z5AhQ2BhYQEA+OmnnxATEyPXLisqLqur9cMPP+RbvNu3b0dsbKzaPrt27cK///4LALCwsBATykRERESaaFKxBI6Pa4ZBTcrj88nwqekClvk9RIcl53H9Sf6XaqDC7auvGfT5jWrm7a0fPnyIDRs2yPUfMGCAwhiLFy9GkyZNkJiYCE9PT0ydOhUtWrRAYmIiduzYgVWrVgHIqF8xceLEfLkOyhtDhw7Fs2fPMHv2bDx//hxubm5o2bIlvLy84OLiAisrK6SlpeHly5e4ceMGdu3aJSZ4dHR0FIo/lyxZEhs3boS3tzdevHiBevXqYcCAAWjXrh3s7OyQkpKCqKgoBAYGYs+ePXj8+DEOHjyo8bbt2TFixAicPHkS+/fvx+7du3Hjxg0MHz4cDRo0gIWFBeLi4hAaGoqzZ8/C19cXhoaGGD16dI7P179/fxw8eBDx8fFwd3fHTz/9hLp16wKAuP37y5cv4erqqjSJqonIyEiMGjUKkydPhpeXF5o1a4bKlSujWLFi+PTpE0JDQ7F7924cOXIEQMbysL///jvPl4nt27dP3K3qm2++0fi4b775BqdPn8b79+/h6+uL7t275+j8U6dOxZ49ewBkzA7KyezD4cOHIzExEePHj0dCQgI6duyIU6dOicus8trDhw8Vkq2y1/Djx48Kv3vbtm2LUqVKZfs8VlZWmDt3LkaMGIEnT56gYcOGmDZtGlxcXPD8+XMsWrRITKD17t1bZVLp48eP4muc+Rpk9uzZgxIlSohf16pVC7Vq1ZLr/++//2LYsGHo0qWL+Fm1sLBAfHw87t+/jz179sh9VhcvXixXe4yIiIhIE8b6upjhVQ0da5bG5D23EfZaflfVR9Hx8F55CT6u5fBDm8owMfjq0wJfJ+Er5+PjIwDQ+KGKr6+vYG5urvI4JycnISwsrACvLG9Vq1ZN7lGpUiXx2iIjI7UdXp7bt2+f4OjoqNFnQiKRCG3bthWCg4NVjufr6ytYWVllOZZUKhXOnDkjd2x4eLjYvn79erVxOzg4CAAEHx8fpe3JycnCyJEjBYlEkmUs5cuXVzh+5syZWf4sZDZw4ECV4+vo6AiLFi3Kckx119S5c2eNf3YtLCyETZs2KT2H7PeAg4OD2utZv369OF54eLj4fQ8PD/Gz8OzZM41eG0EQhJcvXwpSqVQAIHTs2FGuzd3dPVuvdfv27cX+ly5dkmvL/Br7+fmpHWfOnDli32LFiglBQUEq48r8GmRX5tdSk4eyuDP/bLi7u6s934wZM9R+7tu3by8kJiaqPD7zuTR5zJw5U2GMzK+dukexYsWErVu3ZvMV/Z/M/1/LzXtEREREX75PKanCguOhQoUphwWHyYcUHk3+PC34P3it7TApC5GRkXl+/81lYnnEy8sLt2/fxvjx4+Hk5ARjY2NYWlqiXr16mDt3LoKCglCxYkVth0ka6tq1K+7fv49du3Zh8ODBqFatGkqUKAFdXV2Ym5ujfPny6NSpE/744w88evQIR48eVbtDl5eXF8LDwzF//ny0bNkSJUuWhJ6eHoyMjFC+fHl07NgRCxcuRERERL7WlNLT08OKFStw69YtjBkzBi4uLrCwsICOjg4sLCxQq1YtDB48GHv27BGXNubGunXrsHnzZri5ucHMzAwGBgZwcHBAv379cPHiRYwdOzZX4x84cAChoaFYvHgxevToAWdnZ/F6TExMULZsWbRv3x6LFi3Cw4cPc70sTZlnz57hzJkzAABXV1e5QtVZKVmypFhQ/tixY4iOjs5xHNOmTROf//rrrzkeZ+rUqZg+fToA4N27d/D09MyTz4K2zZ49GwEBAejTpw/s7e2hr68PGxsbtG7dGtu2bcPhw4fFmmH5ZdOmTVi+fDl69+6N2rVrw9bWFnp6ejAxMYGDgwM6duyIZcuW4dGjR2JhcCIiIqLcMNDVwQTPyjg4pilcylgotEe9S0S/tYH4YfctxCakaCFC0haJILCqJGVfVFQU7O0zqtVHRkaKxYmJiIiIiIio8ElNS8fagHAsPPkASanpCu3WZgb4tXN1tK2e/aX5lL/y4/6bM4OIiIiIiIiIijhdHSmGu1fAsXHN0KC8Yl3C6A9JGLHlOr7beh2vP+T/7sOkXUwGEREREREREX0lypcwwY6hjfBbl+owVVI8+kjwS7Re6I+916PAhURFF5NBRERERERERF8RqVSCbxs54MT4ZmhR2VqhPTYxBRN334LP+quIepeghQgpvzEZRERERERERPQVsrU0wroB9bGoZy0UM9ZTaPd/EI02f/tj06UIpKdzllBRwmQQERERERER0VdKIpGgS+0yODnBHV41FXfGjU9Ow4z/7qDnqkt4FP1RCxFSfmAyiIiIiIiIiOgrV8LUAEt718bq/vVQ0txAof1qxDu0W3weK84+REqa4m5k9GVhMoiIiIiIiIiIAACtq5XEifHu6N3AXqEtOTUdfx27jy7LLyDkWawWoqO8wmQQEREREREREYksjPTwR7ca2DakIcpaGSu033keh87LL2De8VB8SknTQoSUW4r7yBEp4ezsLPd1SkqKliIhIiIiIiKigtC4YgkcG+eGhSceYN2FcGSuIZ2WLmC53yMcDXmJv76pgXrlrLQXKGUbZwYRERERERERkVLG+rqY3rEa9o5sjEo2pgrtj6Pj0f3fS5jlewfxSalaiJByQiIIAveHo2yLioqCvX3GGtLIyEjY2dlpOSIiIiIiIiLKT0mpaVjh9wjL/R4iVclW82UsjfB7Nxe4O1lrIbqiKz/uvzkziIiIiIiIiIiyZKCrg/GtnXDo+6aoYWeh0P7sfSJ81gVi4q5beJ+QrIUISVNMBhERERERERGRxqqUMse+kY0xrX1VGOgqphX23oiCx0J/HA1+oYXoSBNMBhERERERERFRtujqSDG0mSOOj2uGhuUVi0e/+ZiEkVtvYMTm63gd90kLEZI6TAYRERERERERUY6UK2GC7UMb4feuLjA1UNyw/Nidl/BYeA67r0WCJYsLDyaDiIiIiIiIiCjHpFIJ+jQsi5MTmqFlFRuF9rhPqfhhz230XxeIyLcJWoiQPsdkEBERERERERHlWmkLI6z1qYfFvWrBykRfof182Bu0WeSPDRfCka5kNzIqOEwGEREREREREVGekEgk6FyrDE6Ob4ZONW0V2hOS0zDr4F10//cSHr7+oIUICWAyiIiIiIiIiIjyWHFTAyzpXRtr+tdDKXNDhfbrT96h/eIALPd7iJS0dC1E+HVjMoiIiIiIiIiI8oVHtZI4MaEZejcoq9CWnJaOecfvo/OyCwh5FquF6L5eEoHlvEkDzs7Ocl+npKQgLCwMABAZGQk7OztthEVERERERERfiIuP3mDKvmA8iVEsIq0jlWBYM0eMbVUJhno6Woiu8IqKioK9vT2AvLv/5swgIiIiIiIiIsp3jSuUwLGxzTDUrTykEvm2tHQB/5x9hPaLz+NqxFvtBPgV4cwgypH8yEwSERERERHR1+Fm5HtM3nMb918pLyLd39UBP7atAlMD3QKOrPDhzCAiIiIiIiIi+uLVsrfEwTFNMd7DCXo6EoX2TZeeoM3f/jh7/7UWoiv6mAwiIiIiIiIiogKnryvFWI9KODTGDTXtLRXan71PxID1VzFh1028i08u+ACLMCaDiIiIiIiIiEhrKpcyw76RjTG9Q1UY6immKfbdeIbWf5/DkeAXYKWbvMFkEBERERERERFplY5UgiFujjg+rhlcHYsrtL/5mIzvtt7AiC3X8TrukxYiLFqYDCIiIiIiIiKiQsGhuAm2DW2IP7u5wExJ8ejjd17BY+E57LoWyVlCucBkEBEREREREREVGhKJBL0alMXJCe7wqGqj0B73KRU/7rmNfmsDEfk2QQsRfvmYDCIiIiIiIiKiQqeUhSFW96+Hpb1ro7iJvkJ7wMM38PzbH+svhCMtnbOEsoPJICIiIiIiIiIqlCQSCbxq2uLkBHd0qWWr0J6YkobZB++i+8qLePj6gxYi/DIxGUREREREREREhZqViT4W9aqNdQPqobSFoUL7jafv0X5xAJadCUNKWroWIvyyMBlERERERERERF+EllVK4sT4ZujbsKxCW3JaOuafeACvpQEIjorVQnRfDiaDiIiIiIiIiOiLYWaohzldXbBjWCOUK26s0B768gO6rLiAP4+G4lNKmhYiLPwU92kjUsLZ2Vnu65SUFC1FQkRERERERAQ0ciyOo2ObYdGpB1h9/jEy15BOSxew8twjHL/zEn92c0FDx+LaC7QQ4swgIiIiIiIiIvoiGenrYEr7qjgwqgmqlDJTaA9/E4+eqy5j+oFgfPjESQ0yEkEQuP8aZVtUVBTs7e0BAJGRkbCzs9NyRERERERERPQ1S05Nx8pzj7D0TBhS0hRTHbYWhpjT1QUtqthoIbqcy4/7b84MIiIiIiIiIqIvnr6uFN+3qoTD37uhlr2lQvvz2E8YuOEqxu+8ibfxyQUfYCHCZBARERERERERFRlOJc2wd2Rj/NyxGoz0dBTa9wc9Q+uF53Do9nN8rYulmAwiIiIiIiIioiJFRyrB4KblcXxcMzSpqFg8OiY+GaO3BeFM6GstRKd9TAYRERERERERUZFUtrgxtgxuiLnfuMDMUH5D9frliqFF5S+rflBeYTKIiIiIiIiIiIosiUSCnvXL4tQEd7SuVhJARn2hP7+pAalUouXotEM36y5ERERERERERF+2kuaGWNWvLg4Hv8C7+GRUsDbVdkhaw2QQEREREREREX0VJBIJOtaw1XYYWsdlYkREREREREREXxEmg4iIiIiIiIiIviJMBhERERERERERfUWYDCIiIiIiIiIi+oowGURERERERERE9BVhMoiIiIiIiIiI6CvCZBARERERERER0VdEV9sB0JfB2dlZ7uuUlBQtRUJEREREREREucGZQUREREREREREXxHODCKN3LlzR+7rqKgo2NvbaykaIiIiIiIiIsopzgwiIiL6v/buPa6qKv//+PsgNwURUVRUVCYlM6+ToKWOaGUlOWmW2ngByi6O4/jwMpqM3y6OZqbZNF3GdPKWWRmVaZZaXkrT0qNYkFqaV9QSL3lFEVm/P+jsH8g5h6MgoOf1fDx4PA6stdf+rH3WYbM/rL02AAAA4EVIBgEAAAAAAHgRkkEAAAAAAABehGQQAAAAAACAFyEZBAAAAAAA4EVIBgEAAAAAAHgRkkEAAAAAAABehGQQAAAAAACAFyEZBAAAAAAA4EVIBgEAAAAAAHgRkkEAAAAAAABehGQQAAAAAACAFyEZBAAAAAAA4EVIBgEAAAAAAHgRkkEAAAAAAABehGQQAAAAAACAFyEZBAAAAAAA4EVIBgEAAAAAAHgR37IOANeGm2++ucD32dnZ1utDhw6VdjgAAAAAAHiF/NfcOTk5JdImySBckfwDMDY2tgwjAQAAAADAO2RmZqpBgwbFbodkEDzyww8/FPh+48aNJIEAAAAAALgG2YwxpqyDwLXn3LlzSktLkySFh4fL15e8Ymno3LmzJGnlypVlHMn1yZuP77Xe9/Icf3mJrSziKI19Xq19HDp0yPqnx4YNGxQREVGi7cO7lZffC9cjbz6213rfy3v85SG+6/VcfjX3c72cz3NycpSZmSlJatasmQIDA4vdJlfwuCKBgYGKiYkp6zC8jp+fnySpbt26ZRzJ9cmbj++13vfyHH95ia0s4iiNfZbGPiIiIsr8/cP1pbz8XrgeefOxvdb7Xt7jLw/xXa/n8tLaz7V+Pi+JW8Py42liAAAAAAAAXoRkEAAAAAAAgBchGQQAAAAAAOBFWEAaAADgEhkZGYqMjJQk7d+//5peYwAAAG/F+dw1ZgYBAAAAAAB4EZJBAAAAAAAAXoRkEAAAAAAAgBdhzSAAAAAAAAAvwswgAAAAAAAAL0IyCAAAAAAAwIuQDAIAAAAAAPAiJIMAAAAAAAC8CMkgAAAAAAAAL0IyCAAAAAAAwIuQDAIAAChBBw4c0L///W916dJF9erVk7+/v2rVqqWePXvq22+/LevwAABAEc6dO6fhw4frT3/6k2rXrq3AwEDVqlVL7dq106xZs3ThwoWyDrHYbMYYU9ZBAAAAXC+efPJJTZo0STfccIPi4uIUHh6uHTt2aOHChTLGaP78+erdu3dZhwkAAFw4cuSIIiMjFRsbq+joaIWHh+v48eP67LPPtHfvXnXp0kWfffaZfHyu3fk1JIMAAABK0Icffqhq1aqpY8eOBX6+Zs0a3X777QoODtahQ4cUEBBQRhECAAB3cnNzlZOTI39//wI/z8nJ0Z133qnVq1frk08+UXx8fBlFWHzXbhoLAACgHLr//vsLJYIkqUOHDurUqZOOHz+utLS0MogMAAB4wsfHp1AiSJJ8fX3Vo0cPSdLOnTtLO6wSRTIIAABcNw4fPqxPPvlETz31lO655x5Vr15dNptNNptNiYmJl9XW3r17NWLECDVu3FhBQUEKCwtTTEyMJk+erLNnz15RfH5+fpLy/pgEAACFledzeW5urpYuXSpJatq06WVvX55wmxgAALhu2Gw2l2UJCQmaPXu2R+0sXrxY/fr108mTJ52WR0dHa8mSJWrYsKHHse3bt0/R0dEKCwvT/v37VaFCBY+3BQDAW5Snc3l2draee+45GWN09OhRrVixQtu3b1dSUpJmzpzpURzlFf+WAgAA16V69eqpcePGWr58+WVtl5qaqt69eysrK0vBwcEaM2aMOnXqpKysLL377ruaMWOGfvrpJ8XHx8tut6ty5cpFtnnhwgX1799f58+f16RJk0gEAQDggbI+l2dnZ+vZZ5+1vrfZbBo5cqQmTpxYrH6VBySDAADAdeOpp55STEyMYmJiVLNmTe3Zs0dRUVGX1cbQoUOVlZUlX19fLV++XLfeeqtV1rlzZzVq1EijRo3STz/9pBdffFHPPPOM2/Zyc3OVmJior776So8++qj69+9/JV0DAMArlKdzeXBwsIwxys3N1cGDB7V48WIlJydr/fr1+vTTTxUSElKcrpYpbhMDAADXrfx/QHoytXzDhg1q06aNJOnxxx/XtGnTCtXJzc1V06ZNtW3bNoWGhurw4cPWWkDO6j788MOaM2eO+vXrpzlz5lzTj6EFAKC0lfW5/FLvv/++evXqpVGjRmnSpEmX15lyhL9GAAAAfrdw4ULrdVJSktM6Pj4+GjBggCTpt99+06pVq5zWy83NVVJSkubMmaOHHnpIs2fPJhEEAMBVVpLncme6dOkiSVq9evUVx1ge8BcJAADA79auXStJCgoK0i233OKyXv5Hx3/99deFyh2JoLlz56p379566623WCcIAIBSUFLnclcOHjwoSR7PJCqvSAYBAAD8btu2bZKkhg0bun38e+PGjQtt4+C4NWzu3Ll68MEHNW/ePBJBAACUkpI4l2/dutXpo+fPnj2r4cOHS5K6du1aEuGWGRaQBgAAkHTu3DkdOXJEklS3bl23datWraqgoCCdOXNG+/fvL1A2btw4zZkzR8HBwYqOjtb48eMLbd+9e3e1bNmyxGIHAAAldy5fsGCBpk6dqvbt26tBgwYKCQnRgQMH9Nlnn+no0aPq0KGDhg0bdtX6URpIBgEAAEg6deqU9To4OLjI+o4/IE+fPl3g53v27JEknT59WhMmTHC6bYMGDUgGAQBQwkrqXH7vvffq4MGDWrdundavX6/Tp0+rSpUqat68ufr06aOHH37Y7ayja8G1HT0AAEAJOXfunPXa39+/yPoBAQGSpKysrAI/nz17dpFPOgEAACWvpM7lrVu3VuvWrUs2uHKGNYMAAAAkBQYGWq+zs7OLrH/+/HlJUsWKFa9aTAAAwHOcyz1HMggAAEBS5cqVrdeXThd35syZM5I8m4YOAACuPs7lniMZBAAAoLz/JlarVk2SlJGR4bbu8ePHrT8gIyMjr3psAACgaJzLPUcyCAAA4HdNmjSRJO3cuVM5OTku623fvt16fdNNN131uAAAgGc4l3uGZBAAAMDv2rdvLylv2vimTZtc1vvyyy+t1+3atbvqcQEAAM9wLvcMySAAAIDfde/e3Xo9a9Ysp3Vyc3M1d+5cSVJoaKg6depUGqEBAAAPcC73DMkgAACA38XGxqpDhw6SpDfffFPr168vVOfFF1/Utm3bJElDhw6Vn59fqcYIAABc41zuGZsxxpR1EAAAACVh7dq12rlzp/X9kSNH9I9//ENS3hTwgQMHFqifmJhYqI3U1FS1a9dOWVlZCg4OVnJysjp16qSsrCy9++67mj59uiQpOjpadru9wJNLAABA8XAuLx0kgwAAwHUjMTFRc+bM8bi+qz+DFi9erH79+unkyZNOy6Ojo7VkyRI1bNjwiuIEAADOcS4vHdwmBgAAcIlu3brp+++/17BhwxQdHa1KlSopNDRUrVu31qRJk5Samuq1fzwCAHAt4FzuHjODAAAAAAAAvAgzgwAAAAAAALwIySAAAAAAAAAvQjIIAAAAAADAi5AMAgAAAAAA8CIkgwAAAAAAALwIySAAAAAAAAAvQjIIAAAAAADAi5AMAgAAAAAA8CIkgwAAAAAAALwIySAAAAAAAAAvQjIIAAAAAADAi5AMAgAAAAAA8CIkgwAAAAAAALwIySAAAAAAAAAvQjIIAAAAAADAi5AMAgAAAAAA8CIkgwAAAAAAALwIySAAAAC4FRcXJ5vNpri4uLIOpVxbsmSJ7rrrLlWvXl0VKlSQzWZTaGhoWYcFAEAhvmUdAAAAAHCte/311zV48OCyDgMAAI8wMwgAAAAohrNnzyo5OVmS1LhxY6WkpCg1NVVpaWlav359qcSQmJgom82mBg0alMr+AADXNmYGAQAAAMVgt9t14sQJSdKUKVMUHx9fxhEBAOAeM4MAAACAYjhw4ID1Ojo6ugwjAQDAMySDAAAAgGI4f/689drPz68MIwEAwDMkgwAAKAHp6ekaP3687rrrLtWtW1cBAQEKDg5Wo0aNlJCQoG+++cbpdmfPnlXlypVls9nUt2/fIvezfv162Ww22Ww2vf76607r/PLLL/rnP/+p1q1bKywsTAEBAYqMjFSvXr30xRdfuGx7z549VtuzZ8+WJH344Yfq2rWrateuLV9f30JPk/rmm280duxYxcXFqVatWvL391dISIiaNGmiQYMGaevWrUX2SZL27dunQYMGKSoqSoGBgapdu7a6d++uVatWSZKeeeYZKzZ3Tpw4oYkTJ6pdu3YKDw+Xv7+/IiIi1K1bN6WkpMgY41E8zjg7Pp9//rm6deumWrVqKSAgQFFRURo0aJAyMjJctuPp2i6zZ8+29rdnz55C5Q0aNJDNZlNiYqIkafPmzerbt68iIyNVsWJFNWzYUMOHD9eRI0cKbLdu3To9+OCDqlevngIDA3XDDTdo9OjROnXqlMfH4scff9Rjjz1mvV8RERHq1auXy3F+qdIco57KzMzU2LFj1apVK4WGhiowMFANGjRQ//79tXbtWqfbOJ6ylpSUZP0sKirKitFms2n16tWXFce5c+f0n//8R3FxcQoPD5efn5/CwsJ044036p577tHUqVMLjAfHZ2POnDmSpL179xbYv7vPzblz5/Tqq6/q9ttvtz6/NWrU0B133KE333xTOTk5LuO8dPxt3LhRDz30kCIjIxUYGKjIyEglJSVp+/btJdpfAEAJMQAAoFhWrVplJBX59eSTTzrdvl+/fkaSCQoKMqdPn3a7r8GDBxtJxtfX12RmZhYqnzdvngkKCnIbxyOPPGIuXLhQaNvdu3dbdWbOnGn69+9faNuOHTta9WfNmlVknytUqGBee+01t31asWKFCQ4Odrq9zWYzEyZMME8//bT1M1e++OILU61aNbfxdO3a1Zw6dcptPK7kPz6zZs0yTz75pMv9hIeHm61btzptJyEhwUgy9evXd7u//Md39+7dhcrr169vJJmEhAQzd+5c4+/v7zSW6Ohoc+jQIWOMMZMnTzY2m81pvT/+8Y8uj03Hjh2t9//TTz91OcZ8fHzMSy+95LZfpTlGPbVs2TITEhLiNqbBgwebixcvOj0u7r5WrVrlcRwHDx40TZo0KbLNESNGWNvk/2y4+7rUli1brDHk6ismJsb88ssvTmPNP/7efPNN4+vr67SNgIAAs2DBghLrLwCgZJAMAgCgmD7//HMTFBRkevXqZaZNm2ZWr15tNm/ebJYuXWpefPHFAhdcM2fOLLT9Z599ZpW//fbbLvdz4cIFU6NGDSPJxMfHFyp/7733rAv9P/zhD2bq1Klm6dKlZtOmTeaDDz4wXbt2tfYzbNiwQtvnv9Bu3ry5kWQ6dOhg5s+fb+x2u/niiy/M//73P6v+jBkzTNWqVU1iYqKZOXOmWbNmjdm8ebP55JNPzLhx40z16tWthM6KFSuc9unnn3+2EgO+vr5myJAhZsWKFWbjxo1m1qxZ1oVimzZt3CaD1q5da/z8/IwkU7NmTTN+/HizePFis2nTJrN48WIr4SbJ3H///S6PsTv5j89tt91mJR7yH58BAwZYddq2beu0nZJOBrVs2dL4+/ubJk2amJkzZ5qNGzealStXFuhz3759zQcffGDF9fbbbxu73W6WLl1aYFyMHj3aaSyOpEejRo1MaGioqVKlinnuuefMunXrzLp168yECRMKJFM++ugjp+2U9hj1RGpqqpVI8/PzM8OGDTOrVq0yGzZsMG+88YaJioqy9jlq1KgC2+7atcukpaWZ8ePHW3WWLVtm0tLSrK+iErz59ezZ02qnX79+5sMPPzTffPON2bhxo1m0aJF56qmnTIsWLQokR3799VeTlpZm7rvvPiPJ1K5du8D+HV/57dixw1SpUsVIMiEhIWbMmDHmo48+Mna73SxbtswMHjzYSu60adPGZGdnF4rVMf5atGhh/Pz8TO3atc0rr7xivv32W/Pll1+a0aNHm4CAAOu4bty4sUT6CwAoGSSDAAAopszMTHP8+HGX5efPnzd33nmnlQDIyckpUF5Ukschf9Jo/vz5hWJwXNw9/PDDTmdVGGNMcnKyNYNj+/btBcryX2hLMgMGDDC5ubku48nIyDBnzpxxWf7bb79ZF+zt27d3Wqd79+5uEwhnzpwxsbGxbmc4ZGdnmwYNGhhJ5u6773YZ0/Tp0602li9f7jJuVy49Po8++qjT4zNw4ECrzubNmwuVl3QyyJGcctbvBx54wEh5M7TCwsJMz549C42/nJwc07ZtWyPJVKtWzenYyT8DpkqVKk5nPaWnp1sJoTp16hRKIJTFGPVETEyMdYyWLVtWqPzYsWNWUtLHx8ekp6cXqlPUe+WJrKwsK6FZVPLj6NGjhX7m6bgyxljJzFatWjmdYWhM3u8bHx8fI8lMnz69UHn+8Ve/fn1r9ll+K1eutJJKMTExBcqK218AQPGQDAIAoBRs2bLFunCy2+2FyocMGWL9B/3IkSNO23DM9AgODi504T9u3DjrIvzcuXMu47hw4YKpU6eOkWSSk5MLlOW/0A4NDTUnT568gp4WtHDhQqvNS/t14MABU6FCBSPJPPDAAy7byH/snCWD5s6daySZwMBAc/jwYbfxOBJLf/nLXy67L/mPT0REhMvjvH37dqveyy+/XKi8pJNBNpvN5S1pK1eutNqoVKmSy4vqmTNnWvW+++67QuX5k0FTpkxxGfOkSZOseu+//36BsvI4Rr/99lurvSeeeMJlvbVr11r1/vrXvxYqL4lk0IEDB6w2Pv7448ve3tNx9dVXX1n7+f77793W7dWrl5VsvFT+ZFBKSorLNgYNGmTVyz87qLj9BQAUDwtIAwBQws6fP699+/Zp69atSk9PV3p6eoGFi7/77rtC2zgWj75w4YIWLFhQqDwrK0sLFy6UJHXv3l2VKlUqUL5o0SJJ0r333quAgACXsfn6+urWW2+VlLcYtSvdunVT5cqVXZY7c+bMGe3Zs0c//PCD1e/8T1a6tN+rVq3SxYsXJUn9+/d32W6LFi3UokULl+WOvnfs2FHh4eFuY/zTn/4kyX3fPfHAAw+4PM433nijgoODJUm7du0q1n480bx5c910001Oy/IftzvvvFNhYWFF1nMXs81mU0JCgsvypKQka7HiSxeCLg9j9FL5Y3zkkUdc1mvXrp11jN0tcF0c1apVk7+/vyTprbfecrt4c3E43ocbb7xRzZo1c1vX8XnZuHGjy3iqVq2q++67z2UbDz/8sPU6/7Errf4CAJwjGQQAQAk4c+aMJk6cqBYtWigoKEj169fXzTffrGbNmqlZs2Zq1aqVVffSpztJUps2bXTDDTdIkt5+++1C5YsWLdLp06clqdBTxy5evKgtW7ZIkt544w2nTxLK/5WSkiIp74lOrjRv3tyjfh85ckTJycm68cYbVblyZUVFRalp06ZWv+Pj4132Oz093Xp9yy23uN1P69atXZbZ7XZJ0rJly4rs+5QpUyS577snGjdu7La8atWqknRZT+i6UtHR0S7LQkNDL7ueu5ijoqJUvXp1l+Xh4eHWU9LS0tKsn5flGHXHMQb9/f3VsmVLt3XbtGkjSdqxY4eys7OLve9LBQQEqHfv3pKklJQUNWzYUKNGjdKnn36q3377rcT24/i8/Pjjj0W+D3/7298k5SWpjx075rS9Vq1aydfX1+X+WrZsaSV98o+J0uovAMA5kkEAABTTnj171KxZMyUnJ+v777+3Zru4kpWV5fTnjiTPunXrCj1K2ZEgcjz2Ob9jx45d0X/Vz54967LMkcxwZ9OmTWrcuLEmTpyon376qcjHtl/a7+PHj1uvi5rR46788OHDRcZaVCyX69KZWZfy8cn7E6uosVAS3MXiiONy6rmLuUaNGkXGU7NmTUkqkDwoqzFaFEeMYWFhbhMaklSrVi1JkjGmwNgtSa+++qq6desmKe8R8ZMnT1Z8fLyqVaummJgYTZ48WSdOnCjWPq7k8yK5fi+KGhO+vr7WjLRLE0ql0V8AgHPuz3oAAKBI/fv31+7du2Wz2ZSUlKQ+ffropptuUnh4uPz9/WWz2ZSbm6sKFSpIksukSd++fTVu3DgZY/TOO+9ozJgxkvIuoJYtWyZJ6t27d6GL1vwX7wMHDtTQoUM9itvx33pnHLG6kp2drV69euno0aPy8/PTkCFDdN999yk6OlpVq1a1bgPatWuXNeOpqGTRlXL0/5577tELL7xwVfaBPI5bwC5XWYzRy3Gl/SppISEhWrRokTZs2KAFCxZo9erV2rJliy5evCi73S673a4pU6Zo4cKF1q10l8vxXrRo0ULz5s3zeLs6deo4/Xlxjl1p9BcA4BzJIAAAimH79u1au3atJCk5OVnjx493Ws/VLRb5RUdHq3Xr1rLb7Zo/f76VDEpJSbFuS7n0FjFJBdaBMcaoadOml92Py7Vy5UprbZnXX39dAwcOdFrPXb/zz+zIzMx0ebHpKHelWrVqOnjwoLKzs0ul78XlmIWTm5vrtt6ZM2dKI5zL8uuvv3pcJ/+4LIsx6glHXEePHlVOTo7b2UGOW9ZsNluJzEpyJzY2VrGxsZLybttbvXq1Zs+erQ8//FCHDx9Wz5499fPPP6tixYqX3Xa1atUkSadPny6R96GoMZGTk1NgBpYzV7O/AADnuE0MAIBi+OGHH6zXjvUvnHGs01EUR7InPT1d33//vaT/f4vYDTfcYK1bkp+/v79uvvlmSdLXX3/tWeDFVBL9dsQs5d1y5o67dhzrMdnt9quylktJcyx6XNS6KD/99FMpRHN5du/eraNHj7osz8zMtG5xzJ9oKIsx6glHjNnZ2daaRq5s2LBBktSoUSO3M5ZKWuXKldWtWzd98MEH+vvf/y5JOnTokJWEdvB0ho7j87Jr165ir50lSVu2bHF7C+B3331nfS49ST552l8AQPGQDAIAoBjyXwS5m8kxbdo0j9rr06ePdfvL22+/rYyMDK1Zs0aS81lBDn/+858l5c1UctxSdjV50u/c3FzNmDHDZRtxcXHWLJm33nrLZb3vvvvO6RPYHBx9P3HihGbNmuU27vIgKipKUt4MiB9//NFpnezsbH3wwQelGZZHjDGaO3euy/LZs2dbtwNeurZVaY9RT+SPcebMmS7rrV+/Xlu3bi20TWm7/fbbrdeXLsgeGBgoKe9phu443gdjjF5++eVix3Ts2DEtXrzYZXn+43q5x85dfwEAxUMyCACAYmjUqJH1evbs2U7r/Pe//9XHH3/sUXu1atVS586dJUnvvPOO5s+fb11cu0sGDR061HqceVJSUoGZO84sWbLEmnl0JTzp95gxY7R582aXbdStW9d62lhKSooWLlxYqE5WVpYee+wxt7EkJCQoMjJSkjRy5Eh99dVXbuuvXbtWX375pds6V1PHjh2t1y+++KLTOsOHD9eBAwdKK6TL8q9//ctpEmvbtm2aMGGCJCkiIqLQ48ZLe4x6IjY21npS3YwZM7RixYpCdU6cOKHHH39cUt4tfoMGDboqsezatavIcbl8+XLrtSOp6BARESEpb4Fod0+E69Kli3VL1uTJk7VgwQK3+0xLS3Ob7JHyxquz28W+/PJLTZ8+XVLeEwNjYmKssuL2FwBQPKwZBABAMbRq1UpNmzZVenq63njjDR0/flz9+/dXRESEMjIyNG/ePKWkpKhdu3Ye3x7Tt29fff7559q/f78mTpwoKe/R6u4eDV6zZk3NmTNHDzzwgA4dOqTWrVsrMTFR99xzj+rWrasLFy4oIyNDGzZsUEpKinbt2qXFixdf8eO577rrLtWoUUOHDx/W2LFjtWfPHvXo0UPVq1fXzp07rQvrovo9depUrVixQmfPntWDDz6oQYMGqUePHgoJCVF6erpeeOEFbd26VTExMdq4caPTNgICArRgwQLFxcXp9OnT6ty5s/r06aPu3bsrKipKubm5OnTokDZt2qSPPvpIaWlpeuWVVwokZUpTq1atdOutt2r9+vWaMWOGsrOzlZCQoCpVqmjHjh2aPn26Vq5cqdtuu03r1q0rkxhdadiwoTIzM9W2bVuNHj1acXFxkqTVq1fr+eeft5789MorrxS6laq0x6inZsyYoTZt2ig7O1tdu3bVkCFD1K1bNwUFBSk1NVXPP/+8tT7WyJEjr9p6R/v27VOnTp3UpEkT9ejRQ61bt7bW0dq/f7/ee+89K3HTsmXLQreM3nbbbZLyZuQ98cQTGjJkiKpXr26VN2zY0Ho9f/58xcbG6tixY+rdu7fmzZun3r17q1GjRqpQoYIOHz6s1NRULV68WN98841GjBhhPfXrUi1atNDWrVt1yy23aMyYMYqNjdX58+f16aef6qWXXrLWYnrttddKtL8AgGIyAACgWFJTU03VqlWNJKdfzZo1MwcPHrS+f/rpp922d/LkSVOxYsUCbbz00ksexbJo0SITFhbmMhbHl4+Pj1m5cmWBbXfv3m2Vz5o1q8h9LV261AQGBrrcR1xcnElPTy+yzeXLl5ugoCCX7Tz99NPm//7v/4wkExgY6DKe9evXm8jIyCL7LsnMmTPHo+N5pcenfv36RpJJSEhwWr5t2zZTo0YNl/GNHDnSzJo1y/p+9+7dl70PB0/GXVF969ixo5FkOnbsaD755BNTqVIll+NqypQpbuMpzTHqqWXLlpmQkBC38QwePNhcvHjR6fZFvVeeWLVqlUdjt3HjxmbXrl2Ftr948aJp27aty+0u9eOPP5qmTZt6tM9nn3220Pb5x9+MGTOMr6+v0239/f3NO++8U+L9BQAUD7eJAQBQTC1bttSWLVv0xBNPqH79+vLz81NYWJhiY2M1ZcoUbdiwwbqFwxOOBVQdKlSooD59+ni0bbdu3bR7925NmTJFnTt3Vs2aNeXn56eKFSsqKipK9957r6ZOnao9e/aoU6dOl93X/O666y7Z7Xb169dPtWvXlp+fn8LDw9WxY0dNnz5dK1asUFBQUJHt3HnnnUpPT9fjjz+u+vXry9/fXzVr1lR8fLyWLl2qZ555RidPnpQkValSxWU7bdu21Y4dOzRt2jTFx8erdu3a8vf3V2BgoCIjI9WlSxdNmDBB27dv14ABA4rV9+Jq3LixNm/erEGDBll9Dg8P1913360lS5Zo8uTJZRqfO/Hx8bLb7UpKSrJir1Gjhnr27Km1a9dqxIgRbrcvzTHqqS5dumjnzp1KTk5Wy5YtFRISooCAANWrV099+/bVmjVr9Oqrr1prXF0NHTp00OrVqzVmzBh16tRJDRs2VOXKleXn56eaNWuqS5cumjZtmrZs2eL0likfHx8tX75cY8eOVYsWLRQcHOx2Ueno6Ght2bJF8+fPV8+ePVWvXj1VrFhR/v7+ioiIUFxcnMaOHatNmzbpqaeechv7wIEDtWbNGvXq1cv63NWpU0cDBgxQamqq099fxe0vAKB4bMb8vhABAABAOXXHHXdoxYoVat++vbWgNoCy06BBA+3du1cJCQku1w0DAJRfzAwCAADl2sGDB61Fodu2bVvG0QAAAFz7SAYBAIAytXPnTpdlWVlZSkxM1IULFySpzG/vAgAAuB7wNDEAAFCmBg4cqDNnzqhXr1665ZZbFBYWplOnTslut+v111+3kkWPPPKImjVrVsbRAgAAXPtIBgEAgDJnt9tlt9tdlvfo0UOvvPJKKUYEAABw/SIZBAAAytTUqVP10UcfaeXKlcrIyFBmZqaMMapRo4batm2rhIQEde3atazDBAAAuG7wNDEAAAAAAAAvwgLSAAAAAAAAXoRkEAAAAAAAgBchGQQAAAAAAOBFSAYBAAAAAAB4EZJBAAAAAAAAXoRkEAAAAAAAgBchGQQAAAAAAOBFSAYBAAAAAAB4EZJBAAAAAAAAXoRkEAAAAAAAgBchGQQAAAAAAOBFSAYBAAAAAAB4EZJBAAAAAAAAXoRkEAAAAAAAgBchGQQAAAAAAOBFSAYBAAAAAAB4EZJBAAAAAAAAXoRkEAAAAAAAgBchGQQAAAAAAOBF/h+K1eIJjWVrQQAAAABJRU5ErkJggg==" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "draw_order_multiple(\n", + " [\n", + " out_SLOWRK_commutative_sde,\n", + " out_SPaRK_commutive_sde,\n", + " out_GenShARK_commutative_sde,\n", + " ],\n", + " [\"SlowRK\", \"SPaRK\", \"GeneralShARK\"],\n", + " title=\"Order of convergence on a commutative noise SDE\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "604ba9f83e626b75", + "metadata": { + "collapsed": false + }, + "source": [ + "## Additive noise SDEs" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9b82db9458a6d31a", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-02T15:07:46.653744Z", + "start_time": "2024-04-02T15:07:44.060391Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGdCAYAAAA8F1jjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOxdZZgcZdY9Ve020+OSmcnEhbgSJQQSkiABAsF9kUU+3BZZ3H0J7osTLCEhEE8gxJ34SMa93buqvh9vacsIyUTYOs+TJ9NlXV1dXe957z33XIrjOA4qVKhQoUKFChXHIOijfQIqVKhQoUKFChXJoBIVFSpUqFChQsUxC5WoqFChQoUKFSqOWahERYUKFSpUqFBxzEIlKipUqFChQoWKYxYqUVGhQoUKFSpUHLNQiYoKFSpUqFCh4piFSlRUqFChQoUKFccstEf7BA4VLMuipqYGNpsNFEUd7dNRoUKFChUqVLQDHMfB4/EgPz8fNJ08bnLcE5WamhoUFhYe7dNQoUKFChUqVPwFVFZWoqCgIOn6456o2Gw2AOSDpqSkHOWzUaFChQoVKlS0B263G4WFheI4ngzHPVER0j0pKSkqUVGhQoUKFSqOM7Ql21DFtCpUqFChQoWKYxYqUVGhQoUKFSpUHLNQiYoKFSpUqFCh4piFSlRUqFChQoUKFccsVKKiQoUKFSpUqDhm0alE5emnn8bIkSNhs9mQnZ2Ns88+G3v37lVsEwwGcdNNNyEjIwNWqxWzZs1CfX19Z56WChUqVKhQoeI4QacSlZUrV+Kmm27C2rVrsXjxYkQiEUydOhU+n0/c5vbbb8f8+fPxzTffYOXKlaipqcG5557bmaelQoUKFSpUqDhOQHEcxx2pN2tsbER2djZWrlyJiRMnwuVyISsrC59//jnOO+88AMCePXvQr18//PHHHzjxxBPbPKbb7UZqaipcLpfqo6JChQoVKlQcJ2jv+H1ENSoulwsAkJ6eDgDYtGkTIpEITj31VHGbvn37oqioCH/88UfCY4RCIbjdbsU/FSpUqFChQsXfE0eMqLAsi9tuuw3jxo3DgAEDAAB1dXXQ6/Ww2+2KbXNyclBXV5fwOE8//TRSU1PFf2qfHxUqVKhQoeLviyNGVG666Sbs3LkTX3755SEd5/7774fL5RL/VVZWHqYzVKFChQoVKlQcazgivX5uvvlm/PTTT1i1apWiQ2Jubi7C4TCcTqciqlJfX4/c3NyExzIYDDAYDJ19yipUqFChQoWKYwCdGlHhOA4333wzvv/+eyxbtgzdunVTrB8+fDh0Oh2WLl0qLtu7dy8qKiowZsyYzjw1FSpUHAdwBB34dNen2NG442ifigoVKo4SOjWictNNN+Hzzz/Hjz/+CJvNJupOUlNTYTKZkJqaimuuuQZ33HEH0tPTkZKSgltuuQVjxoxpV8WPChUq/t64dfmt2NKwBYW2Qiw8d+HRPh0VKlQcBXQqUXnzzTcBAJMmTVIs//DDD3HllVcCAF5++WXQNI1Zs2YhFArhtNNOwxtvvNGZp6VChYrjBHta9gAAKj2V8Ef8MOvMR/mMVKhQcaTRqUSlPRYtRqMRc+bMwZw5czrzVFSoUHGcgWEZhJmw+Pqg+yD6ZfSL284T9uDiBRcj3ZiO1ya/hlRD6pE8TRUqVHQy1F4/KlSoOCbRFGgCwzHi61JXacLttjRsQbm7HJsbNuOrvV8dqdNToULFEYJKVFSoUHFMot6v7PlV461JuN0B5wHpb8eBhNuoUKHi+IVKVFSoUHFMos6nNH18bctr2NKwRXwdYSL495p/4+VNL4vLSlwlR+z8VKhQcWSgEhUVKlQck0hEOm5ddqv4948lP+K7/d8p1pe7ysGwTOxuKlSoOI6hEhUVKlQck9jv2A8AGJY9TFzmCDmwoHQBAKkiSI4wG0a1txoAsKFuA2b+MBNbG7Z2/smqUKGi06ASFRUqVByTELQnY/PHKpbft/o+AKRkWY4u1i4AgBInicRc/cvVKHWV4v+W/R++2/8dZs+fjXJXeSeftQoVKg43VKKiQoWKYxKCRmVk7siE6yvcFYrXg7IGAQB2Nu9UWCM4Qg78e82/sbtlN65YdEUnna0KFSo6CypRUaFCxTGHQDSAQDQAAOid1huPjHlEsT4YDaLGp6wCGpg5EADw2e7PkpYytwRb1FSQChXHGVSiokJFKwhXVcE5dy64aPRon8r/FJxBJwBAS2th0Vlwbq9zFevP+fEcsBwLAJjSdQpeOfkVXNDnAuhoHXwRH9bWrk167Mt+vgyOoKPTzl2FChWHFypRUaGiFRy85FLUPvgQWj7++Gifyv8UHCFCJNIMaaAoChRF4eszvhbXV3mrAAC90nrhpUkv4ZSiU6DX6FFoKwQA7HPsa/X4C8vUvkEqVBwvUImKChWtIFpPTMdc8+Yf5TP534IQUbEb7eKyfhn98NmMz3Bx34uT7peiTwEA7G3ZCwDon9Ef6cZ0cf2UrlMAAIsPLj7MZ6xChYrOgkpUVKhoB6LNzUf7FP6nIERU0g3piuWDsgbh/tH3i68LrAWK9SkGnqg4CFGZ0GUC3pnyDnS0Dmf3PBt3jrgTALCpfhMWlS3qtPNXoULF4UOnNiVUoeLvAqap6Wifwv8UnCEnAGVERY7vzvoO72x/B7cNv02xXIioRFmiKcqx5KBPeh8sn70cZp0ZOlqH3mm9sc+xD1/s+QLTuk3rrI+gQoWKwwQ1oqJCRRKw4bDiNceojqdHCoLY1W6wJ1zfK60Xnj/pedE7RYBAVARkGjMBAKmGVOhoHQDgrhF3ASAVQCpUqDj2oRIVFSqSgHE4Fa+d33xzdE7kfxBCRCXNmNah/YTUj4AMU0bcNjnmHABAc1CZzmMDAYX/igoVKo4NqERFhYokYJxOxeu6Rx4FGwgcnZP5H0NbEZVkiI2oyIW0scs8YQ8iTAQAEKmpwb6x41B21ky4f/4Z/i1b4vZToULF0YFKVFSoSALGEZ8aYNyeo3Am/3uQlyd3BEJ5soBERCXFkAINpVG8j2vePHCBAEL796P69jtw8KKL44iqChUqjg5UoqJCRRJ4li2LW8b6fcptli9H/dNPg/X54rb9uyDEhNAcOLJVT2JEJYmYNhkmdJmgiKqYdea4bWiKFiM1wvvQ5vjtwpVVHXpvFSpUdA5UoqJCRRJ4V64EAGRce624jPV6FTPt2gceRMvHn6DqttuP9OkdMdy/+n5MmTslYbfizsB+x36xIWFHIyoaWoPrB13f5nbpJhJpaQqQaq5ERNM5d26H3luFChWdA5WoqFCRBKzTBQBInXkW9N27AwCq77wL+8aOQ3D3bgAA00LSQ77Vq4/OSR4mOINOzCuZB3/EH7du8cHFiLARnD//fDz0+0PwR/xYULogring4cKjfzwq/t1RMS0AXNj3QpzZ/Uw8Nf6ppNsUpxQDAHY170K0qQmBHTvjtnF+9RX8mzZ1+P1VqFBxeKESFRUqEoDjODBeLwCAtqWAtlgAAJGKCoBl0fDCi0fz9A477lp1Fx747QE8t+E5xXLBj0TADwd+wOjPR+O+1ffhqkVXHfbzYDkW2xq3ia8TaUzagl6jx1MTnsKZPc5Muo3QkXlj9VqUX3AhvHyaL2XGdBj69hW3c379dcL9VahQceSgEhUVKhKA8/sB3jdFY7PGaRjClZVgQ6GjcWqdgnW16wAA3+7/VrG8Na+RhkBDx8t5WRZwHEy6utJTKf793+n/hV6j79jx24lh2cMAAPS6bYhUV4vLrSefjJz77hNfB7Zui9v3SMK/cSN8a9cd1XNQoeJoQyUqKlQkgBBNgUYDymQSIyoCorW1YN1uxTL5oM0xDMKVlTge8N6O95Kua/Q3AgCyTdnYdvk2PHTiQ4r1QtVMUoS8wJ4FQDQElCwHHksDXh1EliWAIG7tYu2CIdlD2v8hOohuqd2goTTot0eZ6tJmZcFy4mikz/sKwNElpGw4jIOXXoaKK69EpK7uqJyDChXHAlSiouKoYGvDVpz89cnHZHO4oC+Cbcuq0ZzWFxqbDRRFxUVUOIYB41GWKnMyj5WmOXNQMmUqXPN/OiLnfCh4dfOrSdcJXYizzFmgKRqz+8zG4+MeF9eXOEtaP/jih4AvLwZ+eQBYeLe0fO2bABvv9CtoZCw6S9y6wwm9Ro+ilCL0qSbkUpuTg/SrroJ55EgsLF2IU9dcjIjFALAswqWlnXouyRBtaBD/9q9Toyoq/nehEhUVRwwcx2HJwSWo9dbi8bWPoynQhDtW3HG0TysO25ZVYt1KJ7YNvgXR1CwAiIuogGVROuN05SK+coRjGDS98SYAoPYhZQTieECUjeLFjS9iQekCMdpSlFIkrj+759mYXDgZAPDprk9bP9jGD8j/G94FmveDCVPwVBvAlawGnikCfntZsbkvSq5hrk+PCN+5urPQz1iMAhIwQrdv5yLn3nvA0RTuXX0vQFEoSSORlHBZWaeeRzJEZZ/fv2lz3HqOZY/k6ahQcdSgEhUVRwzLK5fj9hW3Y9a8WdDSUj/MRJUmAoLR4BHvyeJtDop/R1JyASQgKgkQ+PNPlF9wIRpefElcxgWD4CKRw3+SnYjnNzyPj/78CPetvg8VHlLZc+PgGxXb3DiEvF5WuQxVnlb8Rqy5ipcVKzNQtToDjgMWIOwFljyiWO+L+KBhOFzzwk4cmHwKPEuWHPoHSoLemnzQAKI6GtpM0hOo3FUurm+2UQCAg6Vbsb1xe6edRyIEduxE9Z13ia9jzQejLS3YP/Ek1D322BE9LxUqjgZUoqLiiGFD3QYAgCfiERvEAcC7O94FQGbyq6pWYX3tery+5XUEogHcvOxmnDb3NJS5jtysNuCTiEWII4QqkSFYLOr+/QgC27ah5YMPFMuPVYfTCBvB/JL5ccs/3/O54nWuJRfFqcWKZX3S+8CmswEApn83Hetr1yd+E0uW4mWwmYhj3ZXGhJv7Ij5kOwGTLwowDDzLlsOzfDlCnRDVKNbnAQDCOkpcJq84cljJ/ws3fIpLFl6CMKNsUtmZOHjJJYjKdCmMx6tY7/zqKzBNTXB8/sUROycVKo4WVKKi4ohBTk7kA8KWBtJX5fsD3+OmpTfhml+vwdvb38abW9/Eutp1CDJB/GfLf47YeQY90oC0IfM87F1Xh1BKDtqqb4kmSVUwLtdhPLvDh3kH5uFfv/2rze1MWlPC5dO6TRP//uDPDxJug4Akto0GpMeNITUKjgPcFUaESiSdiz/iR5ZLutLeVatQ9c8bUTp9Rpvn2VEU6bLJe2oZTPt2GpZWLBVN7c7sfiYcVkJg0niOcCQje1xM5242Vg+lpn1U/A9BJSoqjggibCRupi7AGyYjwS9lvyiWL6uULOxLnUdO0BhwK6s8lny4Cz+uSUdd7olx22bffTf0PXu0erxjNaKyoX5D3LJTik4R/75v1H3INGXi9mEJXHeZKG7td4WYAvq9+nec8f0ZeGHDC8rtAtLg7m+WlRpTgLfGgOo16Sg9/QyyrHoTvFXrkC3jdUxTU8c/WDuRpyFmckEdUO2txoc7PxSrjvqm94XfbgBwdIhKLBiv2mNKxf8uVKKi4ojg57KfEWISl3l6I2QkyLPmKZYfdEt+GxRF4a+A47gOe30EfdGEy6vzJ8Qts005Fdr0jITba7KI7qEjEZV3t7/btkD1MGFvy964ZTQlPRIu6nsRls9ejpOLTlZu5KkDXh+B1NeG4QZNLvIt+QDI9/Xxro8RYfnU2b5fAJn+KNAkERU2QsFXb5COufm/qPtgCj5oXKeIqHQmuADRImWlFQAg10MgI2nGNBiyib7G7iXnI5CYIwFNmtKRl/Um7yXVYS8bFSqOM6hE5W+MD3d+iJc3vdz2hkcAicSIs3rNAgB4wm3PFgPRQJvbxIKLRlE+6zxUXtd27xcBDMMiHEocVnenFCOYUYS8J6TyXG12NugUW8LtDT16kmO2M6LSHGjGa1tew/Mbn09K6g4XIkwEpS5llOqxsY8hzyKRRTlpUaBiLeAoA9goqJ/vxpndldVP+x37UeGuwMofr1YsD7ZIqT82SoGNyI4/72Z8ZyOikGxn4rcV0h0cE1/W/Fewb+dWLO9bBCo1C2atGUEmiM0NpLom1ZAKXRGpdMp1AKYghxuW3IBlFfGNKjsTBW++ASA+9aPAcSbWVqGio1CJyt8UYSaMlza9hA92ftBpPVk6glgrdgD45+B/AiACSo7j4Aw5k+7vi7SvOzEXDsOzYgVYnw+h0lIEd+2Cb/XquJx/IkTCDILe1h/6kTteQeq55yL9iiuQddutoI1GkZDEQptNhKSMs30RlcYAqZVlObbTZ++1vlqwHAujxoi1F6/Ftsu34Zxe5+Dagdfi5MKT8cqkV5LvHJINmr5GzNBlKlZva9yG078/HTdnp+NzmxV1Gg0AgAlLjxs2QoONSFEyjgG0fGQgWUSl5cOPsLtvP+wfNx7RQ0wJsZEIVqxegoBBhwNUBAMzBwKAGA2yG+ywFXZDTTqgZYEBB8k53br81kN633afn59EonR5hDhy4TBY+T0su0RsoOMkXoWK4wmdSlRWrVqFM888E/n5+aAoCj/88INi/ZVXXgmKohT/pk2blvhgKtqNCBPBrHmzxNftiVh0NoSB964Rd2FI1hBc1v8ypBhSAAAMxyAQDcAdcifdv71EpeGll1F1wz9R+/C/QdHS7R1rzhaLxgoP3r1tFX55lzSn00Z8GBxYGb9dfQQUTSPn/vuQecMNAABjn94Jj6mx28l7dyCiIqCziUq1l9jG51vzYdFZxOiJ3WjHa5NfwyldT0m+c0h5LbuV/K54vb1ihfj305npuOGEMQBIFEUAa+2qiKgwERoe/vvKTsLrGp5/nmzrdMKz7NAiG/W/fyP+rdVoMCZ/jGK93WBHriUXW7uTcx5SeuTSK1w0Co53w9VmSVVT8qgKF5UItUpUVPzd0alExefzYfDgwZgzZ07SbaZNm4ba2lrx3xdfqOV2h4p1detQ7i4XXzcHm5NvfIQg5P7zrfn474z/4p6R98CoMUJLkfJfT9gDV4iMUF2sXXDXiLsU+0fYSJvloYEdO9Hy0UcAAPeCBWCDUvqEcRESxPp8cC1YoJydAijd2giO5VB7gJyDMeRAV2M9zKlEV1E8kOhQmio9Ck1A0BuBJ00S06ZfcQUAwDZlCjSpqfx7Jx5568vdWPDGdjjqCAlrCkhRgiNJVDoMXvyMdPK5qT0/Ye7pX2FS4SQAwLYapYtqiacCG2a9Ab/GKh2i0Qe/U0qZOSMaNGo10Ec42NvBSeUk9K/AV1cu/h2J+nFa8WmgIBGpVEMqTFqTkqgcIS2IYBwIABqbTSyNVxAVGTlh/SpRUfH3RqcSlenTp+OJJ57AOeeck3Qbg8GA3Nxc8V9ajIhMRcexumq14rV8pn60IPSESTNI3y9FUbDqyeDlCXvE1M+rJ7+KS/pdohg4gLajKtV3KF1uuYAk5GQ9hKjUPPAgau68C40vKrsfR4JK3YM+5AKdkoJZdw/H5Mv7YvzsXgAAryOE1V/vJ/uEGHzx2Dr88HE19Hc/idTHXkDWvfeg56qV6PLyS9DYyEDMepUeGALmPrsR5dubsPDNHQCUhLIl1LkVJkLzvy7WLh3fWYio9JoKaE1A0Ik+0OORMY+QY1PxGpKrNz+DSEAy0mM9HnAhiSweWJOBiEeDzORBNQWizYd2ffxNUiPCUMSHAlsBrht0nbjMprch15KL3YXkHsxyA+YQMGEHi31TpiC0n9wDZa4ysWrtcEFI+0CnA6XXg7YTwhttcci2kchJaM/uw/r+KlQcazjqGpUVK1YgOzsbffr0wT//+U80N7c+qIZCIbjdbsU/FRJKXaVxZcBHs6xSPAe+TDXdlK5YbtURotIYaBSJSpoxDVpai8XnLcYzE54Rt22LqERkTQA1drsiJM643QiVlsKzaBE5n48/UaQPvA5pEAUAfcQLjc2KsMWHmoLdsKZLFSo7lhMn1pItDfC7yWC7aIMd3y8zYd/6euiys0FptaAt5LMxviQDGT9Bd9aTgelIpn4EH5u+6X07vrMwMJvsQDa/f8Of5HujNIn34TgYW5H/pDpoTF+ogY3/yqg2DPaizYemUQk4pD46IY4D1r2NK0+4EsUpxZhcOBk0RWN8l/G4aOhVYE3ku0/xA7f8xIKprMK+e2/HPsc+nPXDWfjHr/84pHOJhRBR0fDXQN+1KwCllT8blO7X6jvuVLxWoeLvhqNKVKZNm4ZPPvkES5cuxbPPPouVK1di+vTpYFpR9T/99NNITU0V/xUWFh7BMz72UemO79h7tIlKiAnBEyGz8AyjspTXpidRh1VVq8BwDHItucgykbx8jiUHp3c/HelGQm7aIiqGPn3Ev7U5OWBlM3jG7Ub9E08otq+68SYE9xCDr5Jq5XWL6KxARjpuWHwDblt+G+bun4v+46U0STgQhac5fnAo2yoNoLSVEJXWSksFvL3tbfxW/Zv4ujOJSpWnCpvrSXXL8JzhHT9AiCcqeiuQfQL5u3oz6ANLkR5JnJ7TMQDdRuYkr4GCJUA20uXktLot03QIUcKgGwFZOi5EaYGf74HV24B5Z8/Dq5NJk0aaonHniDthyCTGcJlBqWqpobEcyyuWAwD+bP7zsH5fAlGhLISoGLqTFJvcGI8NKNtOeFcro6gqVPydcFSJyoUXXoizzjoLAwcOxNlnn42ffvoJGzZswIoVK5Luc//998Plcon/KivjB+b/ZQiEYFj2MIzLHwdAqiY5Wmjwk9mrUWNEij5FsU4gIZ/uJt4hI3NGxnmmCJ10/dHkPYEAgDLKfDk0tOJhzrrd8K35I26fSHU1ItEoAi3K6b7FV40HvJ9hr4N4jXy972ucfGlfWHjNyru3r8KBTQ1xx6svkwZA2krOOzb1w0RZ/D53v2LZ61tfV5QLdxa5bPA3YPZPs8GBQ7fUbihOKe74QYTUj8EKFJN7DBveAz47D5lMYg8aYzvc54MGIDtASpC1SYiKoS+J4ETbiLwmxf7FmHfzVGyqlDxd/JQeVf4UwF2b0K9Hm07u0dyw5NDLsgy+P/C9+Hp9XZIWAn8BYkSF7y9l6CkQlQMAgMY5c+BdslSxT2DL1sP2/ipUHGs46qkfObp3747MzEwcOHAg6TYGgwEpKSmKfyokCBU+GaYMXNDnAgBAibOktV06HfU+Yi2fbc6OGwgyTMoIS7fUbnH7C+mhtiIq8sgFFwgqBIeM2wN9j3gHWY5hsG7pPpiiNgQ1PuTNqEdRxa8oqPwVWzIl8aIgdM3plioua6mRnQ8FUBTgc4XhbgrgpznbsG0XzZ+XkqjsXVuHrUtaJ9idFVE54Dgg3iNvnfrWXzPSC8siKoMuAIx2cVkmo/SgOangJACAgeeBrE4D6KTIRIukr0VAD3T3kv01iVsBIfvOOwG0nvpx//wzqu+5B5GGeCLpWvAE9nsy45Z/dXAw3NWJnzsanqiYPRKZpThJkAwAtd7apOfTEXAcJ1aJ0SnkXhPIWXD7DkQdDjT95/W4/RiX87C8f3sQjTCY9+oWbFxYfsTeU8X/No4polJVVYXm5mbk5eW1vbGKhBCEfTa9DX3SSSqk1FWKCHP0TKGEiEqOJX6WHJsKyjTFDyJC1KXGWwOA9IOZVzIP+xz7FNvJCQEbDCoEh4zLBcZNoh1yy/uw04sdP5Hmb5sKf4Hfsx09S3+EwxpARCsN4kKEo3hQ/PkNObUQlz85FqnZJFT/89s7cHBHM7ZuCoClaLBeL/wbN6LyjjuwffdKeJxtV2kI4uPDDXnE7S9V/ACyiIoNoDVAiiTIzZKlbftn9Me9I+/FWT3OEiMqYT2NsFkiKvu7SNc4oAfyWZIK1CCelFomToA2h6RhWkv9NLzwItzz5qPiaqXhHJyVKC1V9mNKCUiVYRW79yQ8niaD3H+X/iTdX7H0riEQT4r+Ciqvux7VdxAyJpS3G084AZTBAMbhgOPzxG0o2COo1Tu4oxmVux1YN69UdcVVcUTQqUTF6/Vi69at2Lp1KwCgrKwMW7duRUVFBbxeL+6++26sXbsW5eXlWLp0KWbOnImePXvitNNO68zT+ltDmC1bdVbkWfJg09sQZaNxLqRHEvV+KaISi9iISuxrABiQOQCAJAC9d9W9eOC3B3DRTxdhztY5WFW1Cv6IX1HWyfn9YIMSIYjU1IB1k2tT8Oqr4nJnYwBsBAhrgtiRuxKeSnKdGlPjIw0RJoLeo3PQY5jyc/QelQtbuhGp2SQ10FQpDWh+cy4Yvx8HL70M3oU/Y+s9/8QGx9q4Y8eisyIqQlRK0Ab9JcgjKgBgkiq5BoSkgX9X8y4UphTiyfFP4rLuswEAbm0EDVopJVcp430hHZDBknSHpnE9ci4/FannnIPib76G/YILkP/ss9Bm8m0JnE4wbnechwjHcYhUk0hH+EBMJPHbf6A5rBTpjiqpgU1PyFVVeVXCj5uwRULM+NzoP/T0Ksey8Mm0Jpo0OwCA1uth7NcPABJGUwCp/P5IgNbIyKVHdcVV0fnoVKKyceNGDB06FEOHDgUA3HHHHRg6dCgefvhhaDQabN++HWeddRZ69+6Na665BsOHD8fq1athMBjaOLKKRJhXMg8f/vkhADIQURQlRijc4SNbHcVyUgqg1kfC4jnm+IiKJqZKJBFRGZI9BAD5fLctvw3bm4gdf5gN461tb+GmpTfhxfUvKIgKG1SmfjyLFkkmWtnZSL2ApMW8TqKpcBobAAooWkSO3ZQgo1jrq4VGQ2PadQPQfxyJ+plS9LDnkMEvJTO+y7Db1hUMI41qhY0c1lati9suFp2lURGIrKD76TA4DvDwUQkLzzJMdnH1eR4f+qaQ9N2w7GHi8hSGaEJCOsArS+us6yM9gnRRwM4S8kNrI0gPf4L8B++AaeBA5D36CLRpacSbhvdQ2T9hIkpmnK7wxGEcEsETxMwiKtfCFZbePNPjh57lMP6kwQAAjytxalGXHx95EoTBdi+H0XtYNHrr8e62d3D1giuxpnpNwuO0Bfn9CwBaPqICKI3f5Mj5F+l+3Zah4eECx3Io3Sal3YSKNRUqOhPazjz4pEmTWg0N/vLLL0nXqeg4HvjtAfFvYcZs5JP9weiRK19cWbkSd6+6G4+NfQzTuk0TozmJhJtmnXKGG5sKAoBRuaNQYC1AlbcKSyuWxq0HgI0Hf8N5snuNC4XA+BIMPBSFxgYGCxrHoqigCWYPmU27jU0oruPQnWSB0Jwg4FDlqUJRCun/MunSvhgwqQCpWSboDIRsDZvaVSxdFrCn76Uo6T4TY9Y9Ai0ThJYBdNEkAgwZ3GE3ImwEOlrX5rYdwSFHVLz1QNgDUDSQVkyWmZTeRx9PnoNPD/6MKV2niMvsLPnMIR2xpBdQkw7cf4UGT3/MwByhYGEtCALQ6PiN9iwgOhgd2Z/SaKBJTwfT1AQuFEK0thbhsnIY+/RGcO9eHLz4EvHYXDRG2Ks1wskLYsd1M8AyrxTa7GzosgsA7EQgiXGavlu8bkqIKTz3CQW7i8GXdBm0OzfimmoOd/xjI16b+QFG5Y1Kfh0TIFbLJG9MqJGRFjl0RaTqke1A48tDwb4N9dizRtLjOOv9yO9lPyLvreJ/F8eURkXFX4c8ggFIA5FBQ6JTnd3kTo6bl92MQDSAu1fdDUAS8/a0x/fEmdFthiIllIio6DV6fDjtw1bf09FSE7eMaYlPn9A2G0q3NSHManGg57nY7yIdcj2GFuQ4JaKzpzA+9VPllUgIRVHIKrRBb5S4vjXNgP4T4mffEb0NjjSiF9IygJ5JTFSG5wzHfaPuE43uBKfejsIf8ePHAz9iU/2muHWHHFFp5gWn9q6Alo98xhAVszUX1w26TiGM7qshOhafgVI4z0Z0FIJ8AU4ma0a0gqRQtCZe6zL//4Bf7lccX5uhvEcEf5HqO+9Upv+CQbGRIQAw1i5wR8g5dx11IbQsB21ODoz5pAVCIMQkdJ/VdyuOWyZobuwuQoa673TgxL0c0r3AqH0cllcuj9unLcRGReTkJLabsnhuBaTzM3OENCpl25QproC3HeVcKlQcIlSi8jdBbHdhm44nKvxgEmSOjiGUK+QSK2a627vHrTdqjVh83mKcVnwaLu13KXSaxBGEXEsu8i1KEiD3ADHzPIy2SZECT4KIHRcKKfQFAY6QBrexGSl8FLskF9jRTfppCIO6nKgkQ7IamjBfuaRlExMViqPwn8n/wSX9LhFFrruad7X5fonwxtY38ODvD+K6X6+LS/kJERWhkqrDEIhKhpx0xgzuCb5DtoroRtj8LNhj/O/uO+lR8ofHh2htHWirFZYLbpc22PiBYvvYNEi4vAyR2tp4TQqUVvN1bgosaJhMBhhChAhpMzNhyueraqIawBNfvZMo7WIJAla/9Ln1YYkQpfiRkCS2hVi/HUVEJQlREZazXu9h6yrdGmKrxKJhNsmWKlQcPqhE5ThHrbcWC0sXKkp3zVqz6DgqpH5C0c6LqAgOnYIBlhwLyxYCAPIseUln8TRF44WTXsC9o+5t9X28EWmEu2rAVXjhpBdww+Ab0DWlKwqayKChLywEZVQSgR5Llojhe+PAAQgH470++pY34dpfyEO3NFf5ML5x8I0AgGpPddx+sWCZxKlOPx81MrEa6BIQFS2jh1lL0mCC/81NS2/6S1GVFVUrABANT2xp+vZGosERWhd0GC7+GthlRouRtquYwgcPAgD6DJqEd6eRx87X48n/E3qTFFFN3ljs6ns5zJMmgx4U03ZD5s9iHqE0qfMsXoIDJ09O+L5sIAD4W4D3T0NlA/kNFPbrB4bvvqzNyoLRTiI0UU6DSOUW5QE4DhRFofD992A9+WRxsZYFnjRdJL4uapS+94Imrl2kNu5cvcqIiq5Lgfi3JkUi4IZevWTLJTHVkYiqxFazR0KdT45UqFCJynEMjuMw9dupuHf1vVh8cDEAMvtffeFq5FmJ2FNI/XRmROXFjS+izFWG/1v+f3Hrnlr3FIDE0ZSOQk5Ubht2GzJNmbhpyE0otBWiex0ZKCry9eD00ozeeMIJ0Bd0QdF77yL3sUfR5dlnEQ7EE5WZf0ghbXeMe3uBjQwYct+MZBhyahFoDYX+4/Jw4tnSZw6YCFHRhBlkeuLF4imUHRqaaF3O6HGGuPyP2niTutZQ76vHQfdB8XWpU6r22li3ESUuQlz+ckTFy4t4rLnSshNvlP7urex+HiotQ90TT8K/fgMAwFjcHcsH07jxRg3mjqeQa8kVm+7t6XMJ6nJHo6nHSUBWH6DrOOlArgrxz7RLLoGhd2+RfAb//DPp6bJ+P7DnJ6ByLRy8PiWnR29EBaKSmQm9yQyaH4CD6z8FFt4D7J5PyrD/Mxz44SZYx41D4ZtvoMtLUo+o/jXSqJ0uixIVNnLwhD0dtgSQa1QMvXsn78qdKnn5UDodaN4Yrr1dug8FFK1kKpGwSlRUdD5UonIc481tb4p/b2sgpbsWrQV6jeS6adTyEZVO1KgYZe5cyRog9kyN16d0FJf0I0LJU4tOBU1Jt26qIRU9eYnKD9rtqKOlmWnx3G8AALouXZA2ezZ0XbogFCAPV4tXMl0zyqps3Gblw1ho3NeeWXJ6vgVXvzABky7ti2FTu2LoVCK+DRns4jYZwXiSYKWkwWdo9lCcUnQKACXRaA+agkojNIGYAMCaGqkaRUv/RR29UPFjk1VwZfQAHnEBd5cCF3yq2PzgpZfC8emnYHgnWVsfvsw2lTjkvXjSi6C0WugKpOgBCnuSqfvlP4odmrHvV3G1xmZD93k/ovvCBdB1iW+qWPTRh6jsmo+l/buioawEqCVRpCBDPrPRlo5weTm5DlmZoCgKRjO5hwN7VwLr3wa+uhTY8hnQUgJs/RRgyT2TMmMGzCNHkuPtTEyQCpoBiuU67IXDeAhRoS0WFH/1pWKdnJzQsr8ByZBOXvHUWYiV8ETViIqKIwCVqByncIfdeHfHu4rXQHwVjSim7cTUj0CGAOClTS8l3KZXWq+EyzuCW4beghdPehGPj3tcsTydMaJfJXmC7uxKYU1/QjQokymh82qET/1ogjvEZTQnPXDPKZiOM7ufCQCY0nWKSFRcIRf8keTlmAzLYEvDFrC6CCiKAkVT6DWCDOghWesAmk0QUYFy8BmaTUr6O+oq7A4pw/+CKzCgJCcn5p2YcP+ow4GK66+He1GCiryy1UAJ38hRHlERYMmI06cwLRIB1OXnI6vnQPG1TWfDoKxBAID8Z6Xmk2LqTqMDhl1O/t74ftzbURQVX5Gj08Fy4onYYTchpNPi9x8+BzaQ30mAIefm/fRz+NeREnFtNp+Ss2cqtgEA1O+U/nZJJFUoVw7ukO4fOfRRIMcJOAIdKzEXUj+2U08BbVKWultPPhn2889H3pNPIOe+e6Hv2QN5T5LeVYLF/19uK9ABRGLSppEQg6Avohq/qehUqETlOMWqqlWIstJDQ0hLJCMqnZn6cYUlHcW8knkAiBX+rcNuFZdPKpx0yO9j0powtXhqnL4iv9QNLQvUpgF16RS+GU/j+8kmFH/xecIHaDBAQvIBTQmGbH0VdMujivVDz70OT45/Et+d9R2enfCs4pq21m/os92f4fKfL1eUiUeN5LqH9SngeKktS5FQfYVdEstaoTRu6WEnkYSOGvXFimfl6TJB73LtwGuRZkwszmx64034Vq5C9W23KVd4G4CPzwBYPp1hjTfvi0XstTePGgWTziRWdsnTgcahw5AITO8Z+LmmN7aWJml2WFigfJ2VpXhfpkVKGQX5687t2i29b//+AABLGt/4svh06WAVsrSbICIGYOhFooNcJD61U8HrbqduZtG8/auE55wMjFeIqMRH3CiNBnmPPwb7rFnQFxaix08/wT5rFgBZRCVBldvhRigmbVqxqwXv37kaa384eoaSKv7+UInKcYoqjzINIRIVrZKomLRkZtaZqZ/YWTxA7NPP730+RuaOxO3Db0eqITXBnocHtJsIiQU3WUZD4YvREbzj/RVzn9mIn17fptg+5CcDjNsUQLpzHyZtJ/bn2vw8dP95IYy9e4OiKPRK6wWdRgeaosXr2Jofzfs7yaxf0AsBwP+tvQksWICiETp3Jhhah6iWDJhLen2CRisZSM1Q+poU2UjKqMpT1aHZahxRCUtExRlyAkCr3wWTbFYeWw2T0rb9vjyaAkik4tMZn+KM7mfg0bESQYzKtA7ySvuqyibscuVgaXUhVnz4JoI+ZcmQvrBI8Tr/+efgl/W90fARMO663xGk7eQ8ZP2IdLkkMmSx80Sl8BSgG+lPJCcnrgObMPfJh7Bv/nuwRRL7P7lNAIaRyOEZGzjonv6iQ98dy7vL0qkd61+mSSekk3F0Xpd0juOw8PUXUbf3KyURjJBrufmXg8l2BcsyiIbDKNm0Hi01beu8/i4IR8m1iTAsWJYDw6pRp7+KTjV8U9F5EKp80gxpcIQciPAz3djKGjGi0omGb4kqUwZkDECqIRUfnPZBgj0OLygvqToJ8BmVXEsu6nx12Fa6G5qDxH4/HIyKnifhIBkUnWblNdHY7TAkMPcCiA4nEA20eh1jxZMcx2GPYzdG6tywROyozJ+NzHTSn4ihwghrAgjR5HhmKL+3PEseaIpGkAmiOdis6IG0snIlAOCkwpPizkHwScm35KPGVwNPxANXyIWNdRtFW/7WiIq8YopjWYA3TaNCMoIw/nbAliD1E4NwqXKWLZCCAlsBnp7wtGKdvHpEPhD6A1IkZdOiBQj4A5h+0x3iMoNMcNpz+TLo8vJQsXO7uCzAagGNHtUvfoigxw1QFHRR8l6ZN0oiYCGisnnRfIyYnC2WmbMc8G3FQFTsXgUAqNu5ATf3+QMpQ4bBvbVO8RlSoMVAbg0OgoRVrBUUwhUVMHTtKm7D+nyofeRR2KZOQcqUKYr9Gd60TZPaMVIvpX46j6iUbd2I3atJVZ8hdRxAtd8w8LunH8HB7VI11Z1f/XTYz+9Yw+YKBy56Zy1CURY6DYUIw6F3jhU/3zoRGjo+Ha2idagRleMUQkhfqEgREBtREfQjnZH64TgOc7bOQYWHRAXO7XWuuG5at2nJdjvsKKRJKsFvICTl+YnPA1A29vM5SUSJZVhEQ2Sm02JVXpPQvv1J30O4jrF+NXKEWWV6QrjmPj2ZKVfudGDLgJvIOq0LoICIhmxjZJVERafRIddMBvYdjZIWwhv24uZlN+PmZTcn1MsI0S3Bi8UTcuO5dU/hthW3iRVEdpmwNxaUrLNxtKkJ5Zddhr2jRsO9jO9Bkz8UOPWRpPvLEeKN2ARoc5KTG3lERU5a/DFdgXetVpbAW8aORcGbb6DnsqXQ8c1MW6olkbQnYkBLWSa2rV8Djtcr6RgWKWecgaz/u0XczsoTFW9zE/Y0SGJ0Z9iECr9dOjeWPDJzpsanvqhwGC6TEXtPNMHJS0xc376q2Kb5gw/hnj8f1bf8H+lLVF8vEjORqCRxoU0GDW+A5/ruu07RqXiam/D9M1L0i2M9mHbdgHbtu3v1cgVJAYD/3nsrqvfuTrLH3wObyh0IiREV8v3uq/eiph0NSVXEQyUqxyl8YRJRiSMqR1BMu7F+I97a9pb4+rpB12F279l4dOyjCbsgdxZ66YnYdUyvU/H9Wd+L3ZabvNIM0+ciJKKpygtwQEjjR0OK0mDLNHAgkkFM/bRC+MKMkqgIpMGvj484MRQ5TkRD9snRx3cMz7WQgV1e9i10ogaA5qA0KIWZMD7b/Rl2t5ABQCAqXn8T5vFeNgJai6jInV2b33kXwW3bwQWDcC76jSw0tH8mHS4rV7zW5STXtciNwwRSSf6O0V1wHNyN0jWgKAq2k09W9ONprpZ0KWGtFh/6emBPvuRma+reHTkP/EtxWI1OIifVDimi00QpU1wmLYmaaW1mpF16KfQpUhRNZ2HwefkQlATysbo/Sce4ly8US2U4joPja0m34vnlVxw4aRIaXyQlz0J5cUcjKtaJJwE0DdbnQ/Wtt3Vo3/agvkwp6I745sOSFm/qF5vmCvn9WPj6i3HbNZSX4MuH7z68J3mMwRVIXJoukBcVHYNKVI5TiBEVq5KoxA5CnSWmdYVcWFW1SrEs25yNh8Y8pIisHAkI/hMFub1h1VtFoqJlpYfpis/2gOM4VO4ng3u9rRwp3ZQl0/nPPZv0PdoTUeFkDq0cx4lpGEFQK4dPS3L6QkSln+2EuG1O7y4JOwXhtNCJGgCafQ3A769i347PccXPV+CZ9c+IJchCpVKQppAT0/NGuD6JwMh6xjg+lUqNoy6+5NvQPv0Ex3Hwb9yoWKbNiW9KKcArIye719SiscKDSDCIhvJ4keYPzz3W6nvLIyqxyM7MQbfvv4M2xum1aMBg8W9jhvSbaqKV5c+BqI7wDk8dch98AD1mNKJoUhOM6WHkn+gUtwvzKVhXBbB2zw/k7+++B9MolY+75hHhefN776P8kksRriTnrUm1t/r5YmHo3g0Z110LAPBv2nRYoirhYBQHNjUgHIzC09SgWMexPhxYtyRun1jzt5DfG7eNHEw0Co5lFcTz74JkRMUTVLtN/xWoROU4hVBpU2grVCyPtZnvLB+VSxZego/+/Eix7HA30GsvWMF/wkaqJQSdjoaVJFiuhgDWzivB2rkkHdFgrYDDIhELQ+/e0Bcqr6UcbTV3ZFjlQzoQDcATIYM7Z4p/OFWbvgUARDTke6Gj8XKxc3pJ7qyCKFYRUdn8EXxL/o1Zm5/Gzuadin3l90WET3lkm7Lxf0P/D11TuiIZGHdiJ1zGyQ867XS0DWzZiuCOHaBMJuS/8AKKPvxA4aIqx67fa/DTf5SC5x0rqvD14/8S0wYnZkhRksaK8rhjcByHX956FUs/eAuuBkLm0r1KUmm3peLS198DrdfH7Z+e3wX9JxDn2SAtRY2aOaI3GZdF3pMFjQirATw1oreKJTeMblObYCqWiJiGM8BtAswhCs999xAAwPH558pzDkr3UmDTJrGxoMbeceF59m23kQomjoN3xcoO7x+LtT+U4pd3d2L5p7ux7MO349aXbY03Igx6lfd5KFFTUBkaykuw+ed5ePfmq7Hik/f+ViXOTp6oPHRGf8VyTwJXbBVtQyUqxyHml8wXrdBzzDkKfwzBkVaAkLIQZvdy1Hhr8NS6p1DpST4DTYQwE1a4nwJQlCIfCTARFhyvomd4/wkN3+dH8E6RR1QAYPPP0mDn0ztxQGaGxnjjr48cwnVMFlGJvb5/1P4hpn50ungSMiGDlONGaEJUElmR62idqDlyh92IMBE8+PuD4vrmvfPQqNEkPJ8+6X1g4p/7Lfw2r5/yOq4ddG3iD8iDdSqJimCoFvUESCShnamfSDWpSjMNHozUM06HZcyYpNtuXxZ//1nsBtQd2Ce+zjO1bg/vaWrEzuWLsfWXn5ISlbxuPRL66gjI6U4ibEF/AJh4D9BtIpp85LvLHTweWv5pGWC0gLuWdJKWgbNJkwR9ENjfhbxXrxryRWhzlRqd4N69Cc+jo6kfAYLFv3fFir+0vxw7VpDvb++ateKyzK7Sd+iqr4Y9x4Ts4hRY7CRq21RFyCwTjWDtt1+iPEabEova/Xuxbx2JAG5a8AM2/vT9IZ/3sQIhopJqUj6DVKLy16ASleMQ//pNyq9b9VakGaQwdmxEpQ/ftXefY19c6erDax7GF3u+wE1Lb+rQ+8d6e9wy9BZcPeDqDh3jUOBqDOCDu1dj8YfEh0SMqMT4T2jY5BEev86DYdnDxId7+iWXtPqeAlHZVL8J7+14T6yyEiCU/gq4bflt+PUgcVPVa6UZvA5hDNz5Nsa8sxYf156OVCsZ+CPBxA6fKXyqxR12iz18BDRTLLx04p9wN0MWUlhlPrytEnEuGkWkQRmGF+ztwXJgwlS7iYqQfhAqUlpDapY5bhmtVc6ui1Ik0qEzmoB1bwOvDgF2/QgAcDfFpg84pPmV0a+RM85u9TwM/P0T9HmByQ8getG3cNSTyp7M854CSxHC91tjMcCEEH2hPyEtPMImpQZnXx75bgSiEmtxL5SCUzHmbskiT23BOnECAMSl3P7SsdIJ+WCjUlfy1LyJMNhvAUAh5Pdh5q19cN49w5GWS74/gdzs+X0Vfv/6U6z6VFnx12fMBEy89Gp0Hz4KAFB3YB8ssjRXzV7yez7Q4MFVH67HhvLOq2LqbLj8RHsWT1TU1M9fgUpUjnNYdVZQsp69sRGVPGseilOKwXIsttQrZzjraok7Z5lLWZ3RFnY3KxX75/c+X2Fp39nYtrQS4SCD/RvqEfJHwHpINENI/QDAl2d8iZnFZyc9xlkDZ+CpCU8h//nnUfjee0i//PJW31NIoX1/4Hu8uvlVzN03FwDpn/PO9ncS2qUL5neRnk3Qm7ToNTIHM8xLkNVEomGmj37E9SNIhCPW8VOATU+IgTvkxq/lvyrWNWs08KQSYtpTY0XfKBkQ+4dC0K1+ATkxhmQpmnhCIEdwz15wgQBomw2ZN/4TAJB9772g+YGTCWraTVQE8zGhIqU1JOoXE/JJFU3dh4+C1p6HK7qRjsRcyAv8fA/gKAO+uRKIBOBuaow5AgVzSPr8I0tqkNG7T6vnYeRJY9DrhauhDq9edi44loXRlgKLPQ1aPRm897gJIVlQ3Rfv7h8FR9go7idHdQYZpHrVcAhGg4jWEi8ay0kTFdvpiyQvGPPo0YrKq45Ax6cuGYcDXPTQZu4mKyHXHEPOecp1tyASNIKidDCnEvLprKsBRVMYOIloehx15DvzNDclOCJwxm33YuSZ56LvOFJW73W0IBSQvmcfX+E1Z3kJlu9txPlv/QF/+PiMQAgRFbtZjagcDqhE5TiEIJQESJVPQ0CaTSbqUCxoEpoCygeIMAgCHfNZiW2Ul8zltLNQXy5Fhqr2OKSyzhQpYnBCxgmYkDcxbl8BFw49D12sXaCxWmAdP67NwUHeJgAADjiIGdhVv1yF/2z5D77Y80XSfc2pelz13DhMubo/NBYlWdAZyCw9WRfaFN563x1xi2Xgo3PJjNRD0/AWkb9tHIe36pvxUFMLnm5sBv54HTmM8pgWrvWKg8C2rQAA06CByDT9iNyzWCz9YTe297gILKVBNEi3P6LSIkRU2r43Qn7p4c1xUYTcX2Dv75+Jy06/5S6A1sCqI7PUKKdBlOXJOccCruqEgkxTRDquJRwRm/clg0BU6kv3o2bfHnH5sGlngqIoTLj4Stl5Age8mYhwGqxp7AqAQrDPLMXxnBYdWADZLqD+4B4xWpVzr7JLuF7ms2I/77xWz7E1aFJTAT7Cdqh9f4I+MtCyDDnOgc0s/B5y/W0ZhKh5WsjzJLPAqthHm0ADNOa8i8W/TTZyTwfcLoRlRCXA/453Vkvpx101nd8RuqOYv60GX66vaHWb5KmfYzei8lPpT7h31b2dag76V6ESleMQWaYs8e9UQ6pIOIpTihNuL1jOy+3UOY5TWPDHak6SIcpG8Vv1b+JreWXKkYLPIZEqV6Nf6oSbrQy9CyWvLaYYV1UA5pT4h2lrEFI/AvQavUL8t6UheT4+25wNrU4DiqJAm5XH0RkJUQknS/0IRCXkFhs+FvEW9D5aAy9vZW+NhJAR9mO2x4vu/ACtiREnUuHWqzAi1STMb8gygKreiLWhy1Hl7Y6G1EHYOPwehEN6ID2xIV4sxIhKetsRFaGbdVaRDdlFLeCYWjhriTjYYk+D3mQGmg/AQEcBvrIqxGqBFL46x1Mbl/oZllYNmgMGVjSgb00TzO2YmZtsEgn7Y64kfD3x3AsAAD1GEFJIgVP0BNrjzkbj5esQyFCWt4/kMlDNV+k7v/wPEI2C0uuh79oVtOy99F2liIouN3llVFugNBpo+GqmaMuhpU0CnjA4jgU4QiRq9kfgdxOiYuY1NAE+kmm0kmvBRFhEwgxCMe7B3YaOwJjzLhJfm1OE/d0IB6SUns/lhCcYQWmTJMItaWz9nj3SYFgOt3yxBfd9t6NVTxQhcmIzavHAjH7icvcxHFG5f/X9WFi2ED8e+PFon0ocVKJyHEIgGHcMvwMmrQlvnPIGTi48Ge9MeSfh9lYdISpywacj5FAIQ8vc7Uv/rKxaCU/Yg1RDKlZdsApPjX+qfefc0oLqe+6B97ffE/ZIaS84jkPAI+3va3ADLAvQNLQZSj2EYO/tMCkdRAFAq0ssQk0GeYdogBAVoaoHkK7tiJwRcfvmmKXBhzIrIypaDRl424yohN1o4Ts8d9XwxFNnhIf3/7AGZCLYPFJqK6cpNzmcQKh1wXC0kaRPtFWLEGLNKAuNEtd5rQXYT50GdJvU6jEECPoLTRsRlcYKD5z1ZDCcfHlfaLXKSIDexF+vGS+AogAjTe79YK9zJdLkqUOAr1Y65Zobcd2lwzAxh9zPhQ4PujcmrmSKRVpuPvJ6kfSQo5aQtv4TTgbFRykMZr5XECg0hZTfY21ZGdwNSnHtWDoPISuvT/ltBQBAV1QISqMR05UAoOsilUPHku2OQtAEJW2F0A5Ewgwh+VwA5C6iAMoMNsoBFGBNtwMgRAMgUUFaQyJcQW8krs1BZmFXhYhZjKh43Aj7JVISCQYwb+NBhdV8SWPrlUNHGoGI9Dtt8SXuPxVhWET5z2DSaXDtxO64fiLpa3WspX4YlkFToEnpBt1K49WjBZWoHIcQCMYJGcR7Y0j2ELw2+bU4fYoAIeIi2O4DyjJXALh75d1t3qD+iB93rrgTADAufxzSjGnt1qY0PPss3PPmo/If/0DdY637YLSGkD8KVvYg8zeTh6I2IwOUVlldwwjOkJow6q3lAID1hQtw4UOj0FHEGulx4BSdiYVr2zutNx4c/aBiW8G4DZCJU3nU3XIDOcckDzAhrVbqLBUFvIVr3iDvqdHAy19/myCcTSkAhl2BIGvBOU4yOIyKUrjB6W4XUYlqDNhkvBB/+k8DByWZO6gbD2ja7rrBRaMIl5cDAHRtDLzfvbAJLOMAx3rgaa5G2eYFivUG4XqNIGJto4Zcp7lrw/Dr+WN7asVSWGPJAtg2vQoNxaHo1lPaPFc5KJrGBY88K6aAAEAv+750BiMo/no3BJXC7bDfj/JtmwEAGj6NGIjSCPIl8Lpact30xcUAgJQZMwAAqTPPAmSW6tosKVr6VyBogg7FTt/nIKF/juUJB2UWP7fJqoMlVYqIAKTKzmghnznojcRpdQwxIncTH1FhGQbemP5Em/aQlEq2jeiBSo8xoiLXzCQzbwvKyIyRnxAVpJP76FhL/czZOgcnf30y5mydIy6TP+tWVq7EXSvviivEONJQicpxCIGoxKYjkkHQrcgjKrFEBQA2N2xu9TiVnkowHPkRdrTKJ1wtNSNzfjO3Q/vKEfAoZzEH9gQRNKQlNBOL8hEVho6g10Um7Jm8ENNnjURGl/Z5gcgR67Trj/gV5msCCmwFSDUqq2taIypUC4lihJM8wHqnkV42QrrNqjUjndee7KEZVEXJoGAViIolA6wpE980v4B91c9irmEkXo/y59MWUamvw8Gi07Bfeyr+8BJxcb5uJ/r5SSjYQ7UvLeHfuAmMywWN3Q7jCfFGdnJEgh6E3R8i5P4E5Vt+j1svRlRoDTD7Exh4ouJtacK7C5qwy5UFeOrEWbyxROojYxnSD71+/w322bNR/M037Tp3jVYrRlUU7w8yIOv51F1dUKnVCQUCqNlPdC19xxJtVCDMIWIh34s2SB61gh4l58EHkPfkE8h7/HGFn0rs/dFRaA9Dg8KWGkIOOJb8T9GStodlOBitJCIS9EiDl5D+CfoiCPmV5CKjSyHARIDF/wb2LIDWoXS6BQAtX/vtd5A07tAiOwCg2Xds6SWCMgdlt+w3y3Ec1tSsgSPoQDAibWPgP1cK32fsWIqocByHd3e8CwB4e7vklSN32L552c34pfwXvLf9vSN+fnKoROU4hD9KIh/tJSo2HXmoyjUqgrBWLswVfD9YjsW8knmodCv9LRoDZFDtk9YHfdJbr6CIhcamLLnk2MSzkbYQS1QAYHefSxLORBl+ZjO7//m4bPhF+M/sF3Bh3wv/0vvGln37Ij5FRAUgPXTO6XmOmGoTICcqsaF9DSP4qLAJDa9OyCQDvVBVlK41wyqLKM0vJ/b4NmGZOQPOYDrcTC78bDoM3GiYhLLkBF2u5YjW18IdYwZn19agD0e6QQc5g+hd0xqCf/L6krFj4qJccnAsBzbKE1guBFdjvJZISLcAAPqdBWPxcOl8oyx+rukLb0MVgh6S3hGIDABAZ4I2IwN5jz0K08D29aYBgLyeMqJiVP7GhPOpoXsAAGjeoyYc8ItRnYxCcg0DrAExLZxg7Ev0Ctr0dNhnzQKl14seMx3t8ZMINB/tYFztnwGXbm3E54+sRX2ZGyzD4ue3SW8pmiLVaXKiEgkzMKVIqRsBBrMWkcAa7P1jWVzqJ7tbD2DdW8DvrwBfXgysjHeAztUTYhVtIb+pHlnkN+RIkl45WvDLBNruQASukAuLyhZhXsk8XL/4ety98m4xomLU0WLKyyYQFb4S7Yv1Ffhhy9HtJF3uLk+4XD5OCKj2Ht1zVbsnH0dwBp14ev3TomdHbDoiGRKJaYWIyol5J8IddmPxwcX4/sD36GLrggp3BR747QEAwBPjnsDU4qkwaU1o9BOikmnueB+faIw/B9Pc/JfC3H43+aFTbBQcb3Tnt+RCVxivQxAiKgZDx4SziZBnUabVvBFvXETl+kHXw6q3KojKw2MeVlRixbrfakSFPYX188uQ0y0FxSekAX9+DxRPQHFKMbS0VtQl5dFGKXoig23IJcDe34Ghl6GpzgqAEJsaTxGKeeHmjlX1cK5dgfFXnwQqpoMrGwyC9YfAUspHQpq2ClaDA+BYcBSNgDfSphA5Uks0QXLtRSL4XGGwjHRflGxcG7eNPKIBioIxPRfAPsU2FRVN8DoIYTDKiQrz18LsaXky47aYCIdwPh4nGaTzevVF9Z4/EfS4EQ2T79KeTYhpMMRCN2wGsFwqKzf26w2EvIDeAvCDmKFXL3SfPy/p74GJsIiEGDFq0RponlixwfY3v1vx2R4EPBHMfXYjzr59KACA40IIeYlfkkYviYRPnNkDJht5DsiJCsfUgQmuxbZf1iItT9l2wLb0DmD3PGnBn99hen4Wfq7pKy7KMXpR5bcj2FQDZHRHd56oNB9rREVWSu8ORHD3yrsVVZDr6tah2U+eRUaZDs5mJN+dJxhFiy+M+78jZHBPnQf3TZeuw5FEMqNPf8SPD3Z+gDqfpO3TaY6O67gANaJyHOGZDc9goazBXHsjKsLA6Q17EWWj8Ef8IunIMmeJgs21tWtx6cJLFT18Hvz9Qby7nYQHhYhKtqnjgr9IvXJQj9TGz55bQ8NBN5Z+tAu1B5wAgFSZ+NcYbIGhW3w1iiCm1eoO/TbPsSjTHr6ITyQqZ3Q/A2+c8gYu6UdM4wZnDcYtQ2/Bqye/ivN7n6/YT5euHPg0TIiU2ALYuLAcC+ZsBz6fDXx7DfDlxaApGvNmSg/50TDBkoCopHWbDNywGhhwLppbpIdKdZ0ZMNrBccCq7X2xfSOLivW74vYXKqeCJmWVTl/TMmh1Uej5HLVw/VtDpI58t9q85N2Sfa4Qvnx8HTgmXvSpMQwT/26pt6Bil7SN0RKftvt5By1GowxdRwBTHgPyhwEDz4/btj2wZUr3t8Gk/L4MMuJCazQo6EciXl6HdI4p2eReCXjciAw/ATW8nlhjZKD/70jg6S7AfKnRJEDISrKIyvzXt+H9u1bD09K2hQBtIqJvLtB+uwFaI/0+6svd4DgWYfeXZAFlgUYv9cQaPLlA7DTtqKtByO8HE43i4NZ3xW0ctWT2Pfa8i3HRGAqUnKTw6H/a+bji6tPE1+l6QqyGuncgM9SEHll8ujoYRYQ5dhr5BWRExRWIxFk1AMCOJtIOwqiVExUp9VPvlr6bt1aW4EBD6ynZzoAr5Epq9OkMOfHyppcVlgvuNqKxnQ2VqBxHiDVsazdRkUVUrv31Wpz89cnY79wPgJQ6xzqWxrquCuRoUfkiso+5Y5EQNhQCww+EgpgwUtMxovL73APYs7YO25cT98tU1wF0KyN6hIjWgmBGMaIRZeWMEFHRHAaiotfooaGkB4837BVTP6NyR2FCwQQxzEtRFK4bdB0mF02OO44mqtQOUJBHVXgc4Bu+VW9E0BtB41oWFxdfhl5pvXBeiIOZ46CPGjF5/2UodJDZmNzLxtkiPdgbaxkE+8xGkJM0FT99VC8KjQVEGxsRNNgRkrkcn3//CBhvWgL6rBcR4QXZi97ZCSbKIpwg1y6QhagQUclNTlTWzysl/ilczIyZoqDRS+WcLbW5mP+a1AdIZ1RWX8XCmJEPjLsVuG45YG7bFTcRUjKl+1tniKn2khGXcRdchvR8EjXyCE68egMsdnINg14vLOZU3HaDFt9eRKHr5GaIxS+bP2n3+VTvJdGxPX/E/2aE5oFC1RglRlTaT1TkRH7PH7VgoxXgWIF4Sb+pnG4poDU0MouKkZ5fgGgohP3r16B003pwrPJ+oLUmZGQMQL5T2bhUxKmPIvO0W3D+Q0/iiudfR8Qs3XcX1XyD4gyLqDF2+I+dqIqcqCQrNW7iu5qb9IkiKhE0eJS/99jXRwKf7/k86Tp5JEVc5o9fdiShEpXjCPJyWEDqjNwWhKqfg+6D2Fi/Ef6oX/T9yLfmxxGVPS17FK/tBjt2NO7AfgchN7F6jbYQ4bvC0lYrjP3JINTRiIqrURnKNvsbkN1IxL9+Sy6++y6Aea9sVWwj+KgcjogKAPww8wexp5E8oiLXoLQJVxUyByhnJ9qYztZRTkqtfPXUevzxXQkmNJyD7876DmmuWkRZM67e8Cx6N43A6XuIg2wG763CsRzczdLxOA5o1g6Bl1FGSip3KwlTpKICrlSiu8jICOOaFycgu2sKkNMfVLdRyJL5xLx/52p8dO/vCr3Q2h9K8PF9v8PrCCFS1zZR8fMl5pyMqGgNBpz3r8cx+YqJ0FtnQ2+7QKGPaA9oe+vppvbAIuusHAkpv5uuA0lqJKu4O0aceY6UCuLNz/QmE8ypqaA1WnAcC2OAjLYlI4bCMHoGYJKVa/vaLiGWa4LcjfHpnGWf7MEv7+7EH98TgSrNEzmOT/0Ed+1C3WOPtdpRWfBHAYi7LBveL77W6CW9jlCCTFEUCvoTzY+nqRGuxnhROZCNxV/HONRa+fuhaAzAt5UoGjAYmUXFCMdE8lJNOtjNZBuH79iplPHLJkMuf+LzagmQ35YgpAWkiEqE4eK0KY1Hgai0ZvCZiJSUucriDEOPJFSicpyA4zhFeTGAVhusydHaQJpvyUeqXklUYoVTUTaKtbVEP2DT2TC92/R2va+AcAUpOdQXFUGXT0hOpLamtV3iIDi4CjAHGsV0hIDaEqVOJcrPfmL3/asoTi3GxAJS0eGNeFHrI2RL7pPSJtzVyBrgRXofSS8U+zmCLO/0SWfD20IeYnWlLmDfr0DjbvzqvEOxvSlsQ5oxDUs+2oUP7v4NTZXk2CZeS+JpCcLHKHVFXlkagXE6UXPvffBYiX4mvwvEclMAoPQG9Cr5FhQ/a46EGERCDCp2SWRn06KD8LnC2Pxzqejhoc1LXC4PEPElAIAjD/szb78P17z6LroOGgKL3QBaVwBa2yVuv9Y67Fq1ISD10IkKTUv3S2ZRsWLd8NNn4qYPvsSlT78MmtaIRCUaIt+TwWwGTWuQyoum9z/zCQwhGi4mAFzwX+DecsDOC5Y3fgBUxOty5JBHrhKlfko2E43PzpUk0kjxqR+WT/2UnTsLjs+/QNOcOXH7CseP9fDhWPI7yuneE1rjWHF5Qi8UrxvOuvhJB6VJYPR34efAkEuB2f+NWxWMidKy0QjSePv5ZH4lRwMBWXmyJ5SYqAgaQrlGxarXitG0748BoiLvDxeLRBGVKBvFt/u+7cxTahUqUTlO4I/6wfJahmnF03D78Nvbva+gQUmEXEsuLPrEs9bpxYSQ7HXsFcvYbhp6U7tFvALCBwlR0XUtEgevaAcjKvJZHwBYfLXQRgPQtDIzEB7AWv3hISqApPdpCbbAF/FBR+tQaCtsYy8ZXGRAobMlN1J9KJaokAiYBxLBtOq9wOfnI8DaUBEertg+z9MDhqgZe9fWiTbmANClt50cpzkI74j7Fft4HNLDMbCdVHf4+EhZeo5SLEsb9DCE3ehd/oNiuYafMSrMonhfG8pgUGgu1n3/Nb5/7jEwfA8aluHARCrAsYTspGTliNqHWLGuSGogVdkAwCVPvoTrz+2KK7pvwsXFW3FR8dbDQlQA4OpX3sasBx5HdnH3uHVGi1UkM4YkYlv5NRm2zy5qwgAAOXzJ9vIngA9OAyLJ72F5e4FQ1T5g2ZPS64CsRYA5AjgrQXuIwzQbDCjOQW4PIMfu38nvkKYp2fdJzmfc7Etx+dNS+lIuwBa0QmWbN2LbYl43R1lk29qVb5TRCygYDpw9B7DGp45bDMrJVMDrRrqFj6j8hdRPIMzAHYzgy/UVaHC3Pw3WnuMK8Mf0qBqdNxoA4I44ARCzNwE0TcGqT1y7cjSISrIu8InW5ZhzMCRrSIcrPQ8nVKJynEDwQNFSWjw38bnD0q04zZAGs86M/un944zbKFB4aMxD4mvh5h2dO7rD7xOuIA9PfVFX6PiKCs/iJQju3Zdwe9+69aj51wNg3GQAj0YY0WZdgJYJggJgjqm8kT+cBaJyuCIqQHwvpV5pvTqmiHeTAUNTJPmL0JzygRfiIyqBoPRgC/K9W9zR+OhYlrcQVbuVjq7mFD3S83hBYksQHq1SbOxtCaLhIClHDVcR8uTlK5uE/QRQBpJitDmUXbNDfOhbPphGPKR0XpebK87AOY7Db19+gtJN61G+jTQWDHrDiHglPx259iSWqIQDUfF7HT5jJiz2NIw481zk9uwN6wVzkFlYjDyTBym68GEjKml5XVA8aGib28WamdF8OTYli8rkNRvhDDmlASC7v/IgnuTRRQVRCXHAqucAlmiEvvz3anGdJVwKvDIA9JrnARAxrdASAQD0CSqwyrY34bdvSJonNceMkWcUk89Ak4HdZEtBSoakg0vLkUiZYIrnrJcmHBqD9LkojZ2cBweS7rr4q6SfEQAqoSQvQa8XaXzqp6MRFac/jInPL8egR37Ffd/twAu/7u3Q/q1BnvrxhiUClGnKxClFxGCw0U8iisaYlHNKTN+frhnkeh4NoiJYXPS092xjS+DSfpfivzP+i0mFkzr5rJJDJSrHCbx8nxar3trulI8ciXQlhSkkElCUUoRfZ/2KsflSmDfNmAab3oZPpkuiv572nuiZ1vaNHYtIBdGo6IuKoC+WfDo8vyq7ATNeLxr/8zoqrrgCru++Q8PLLwMAfvvmgGK74uY14t8WvzIyIw9jC3/rjZ1HVPql90uyZRK4CFGh7ckrpwTha4CRNY3kn4luhjzQ83qkot85dgBAf3YYmquU3gfWNANSMsjgX7a1CQ18I8csHdEy7Ftfj2+e3ohdqw4icvAgohojQkYS0UgvtCuORfEkwuKrhlYvPTJ2rqoGx3HwyqIzHicZVDS5edi6pAINB92KhoGC5iLgVYbN5X4l5hQ9bBlGMZLCcdJ3abGn4fq3PsFJl8qIer5UJYSU+HRRZ0IuvAWA2v1kUJx63c3Q8Z8pxa+DIURLIfWcGKLiTh5dlEdNwoIpi/Mgqvc64HXHp8EoLW/Z73IhuHOHuDy2bUXQG8HCN7aLr9NyzRg2tSum3zAAFMUTFd4vZebtQ9F3TC5OPFuKLsndewVoDUOk86BJOjkKPXDeB0BGj6SfEQDKY1KTQY8sotJBorJ0d4Ni8P96Y1WH9m8N8oiKNyJFQheeuxDZZvKbLneQ+90Y06YjTZZOvXpcN9x+KjFz3N9wePoZrdjbgPPfWoPSdvRHElzIp3adih6prX83R7rpbCKoROU4geCBIu943BG8eeqbcct62XuJf+dYchTNDtP5QSuw2obn9B/i4j4X49Gxj3boPSP19dg3bjx8vxPHUX3XIhi6d4eugMzuYt0zG197TZFLD+3Zi6o9LfhzFRncU7NMOO++Eei+R3IZtcTkUxWzezGicvjsgrS0VlFtFVu23Cb41A9lk651rEYlxKUAFI0AK6XsQryzqafv9QAAW6YRA/oS0pjqzUFLrbL9gd6kRdcBmbCk6hH0RVC1h0RcumYqr9f6H/fBX1qGVRNeBABY6GYYY8pkab4broaN4oxreyONj7g0VXqx67ca+JzSoOB0AVGNEQ3pg/D73AP4+qn1eO+Wa8T10XAIDMOisdKpeA95RIXW0Ljo4dG47MmxoLWElMu/1ziiniULSce4Anc25KkoAMjuSgbzgn4D8H8ffyNWBWU5DRJRKRipPMhHMwBP4qqKkCyVF+LM4DgKaNgtistNNPleQzyJofneUeGyMlTfJqWH2RgTtvULlL29UjJNaCgvwbwXrkc0TIiBoEMp6JOGU86xwwBJzG+0KZ9DQ6dfBIq2QWMYDo1+oJj6CbMWILXt1OjeoHIwDHg9SOOJSktM6ofjOFS2+BPqlTiOw887lcTPqKMVbTcOBfJ0jz9KtDxphjSYtCZR90Fp/Pz7xhAVsxQpPH9EAcb0IDqenTWuQza2C0YYXPnhBmwod+CD39vu2+aLEr2jWWfGyye/jKldp+L1ya8n3FZDH76J3l9FpxKVVatW4cwzz0R+fj4oisIPP/ygWM9xHB5++GHk5eXBZDLh1FNPxf79+xMf7H8cQuon1vW0vehu747F5y1W6ClitRXdUqX0QIYpAz5XCFsWV6B0pRvnUldhUNagDr2n86uvFM3RdEVEl5F+xRVgKRpNX85FxbXXiaWUgW3bFPtzkQh2/UbC1xotjan/OAGZ6QAi0o/a7lTeL3KNRmekfgBlVCXd0IESWJYBPOQhKicqxQd/RoqrFBRHBmMH1RMcy6KlSZplBiNkNuaOkGUpGSak5ZLzCHojKN9OFPmF/dOhN2ow/vxeMFp1OOUq5ey9x5nTMNomheFzuU1oPCDpF3SaSHxZr04nmpPl5OrQe6REzlZ8thcNB2XGX5QGtbmjEcgiA7YgzBQ/h8+HnSuqwTFKYXhsGbDOoIHBpIXBREhmbOpPgeFXAl2GAyfeCPyFaOPhxIxb7lS8zutNyseznAbUePlUjL0IuOBT5Y5LH094vFBduewVjTBnwrp35+H3uSTKmKMj/4c4nqhoE3uOeH9fA8+KFeA4DvXPPY+GLUob++HTuuLLf98Lho+8aLRaMSKEuh3AK4OAN8cCHAeO4xBeskzc9/yHnsSw6ecAAHTmk6CzTBHJZJgzt5mO+7PGhd1OGud32ykuC279EWdVPAsNmLhB/LWlBzDhueX4Yn28Ydn2KheW7FaaSwYjLCodEpH3hqJ/ibiUNHrx/m8SCQhwJO1cYCOfT9QC8qmzWKIiT/2kW/TISTGie6YFHAfsqG5f48xkWL4nviVKaxAiKmadGd1Su+HFSS9iWM6whNvK3cuPFjqVqPh8PgwePBhzkijOn3vuObz22mt46623sG7dOlgsFpx22mkIdsAD4H8FAlH5qxEVgAhnF5yzAGPyxoCmaIzvMl6xfka3GeLfeZY8RUh/yYe74G5uv9slEG/yJjhv0mnpWDfyIawb+QDqNx3A3k9+AQCw5hQwtDTr4CIRVPLai7PvGIrsrilimSVtJYQthRcPChBmoBzHdYqYFlCSxXRTB4iKpw7gGIDWQlsohVsNYTdGbHkRo1M3AACqDFPhrTbC2SIRlShnRLTXTLgDZECyZRihM2hgTVeWqE+95gRc+8pJYj+jwr7pipB9Sv+hGPH48ygybgUAcBwNP2MX15sKugExmhuKosT0DxsKxfnVrJ+vnMGVdp8JuidJicUaugW9HmIaxymJiiaJ1b6eJyryiEocjCnAtcuAaU8n36YTMf6iKwAAU2/4P2QUFCnW5fcSiIoeB5yyFGa/M4GcARBpRe1WyUk3EiDpoO1fw7/sbfnhEGIt2OibLb7OFomKDRwHUJrEAzDrcqHqhn/C/dNPaPngA2h2bRDXnXbtABjMWtFZFwCMthRCNup2Am+NB9gIIdn1fyKwZSvcb0nnld2tB8ypid2KQ4YugK51v6fP11UAoECnpGGgnRB597ZF6FfzHabRG+LcaV9eQrRt//p+R+yhsLeOPCfH9shA+TOnY3hXEuX45U8SsdpS4cCQR3/Fs7/sidu3LZwzR9mLKghyrsIET/CrojSk63SsRkUjI9FCdCWLb77oDPz1EuyKZj/++ZnUp82XpBO7HIJGxayVdEfyvwEiDn5i3BMYmt22Vquz0alEZfr06XjiiSdwzjnnxK3jOA6vvPIKHnzwQcycORODBg3CJ598gpqamrjIiwop9fNXIyoCKIrC66e8jkXnLopTcedZ8/DcxOdwef/Lcf2g6xUhfQCo3KVM1bSF4M4/xb/TzzkFVPlqIOxDxJyGgDkbQVMW1o96EMs22/D2rSuxyHAR/hj9b5GsREKMGCER0g1ChEaTkY7cRx4BBQ5nTJQIlM8ZQjTMgImyoh5Cdxg1KoCydUFrZX5x4IW0sOXDNHIU0q+6SrE6I40ct6mBg8+fjoAxphHilDfg4cuVUzLJw1+IqgDAyZf1VZQVCxh2WlecdHEfnHpVf+iNWsCYgr7nnQkACEYtCpO3ky4bGLc/IKV/uFAoaYRq4M63YQg5wGgM2L6MpLiEqh4BQZ8XTdVeseFdW2hXROUoY9RZs3DNa+9hwKQpcesyi4gmy+bXiSaLAv4YfSXGdi3AfKsZqN8JPNcdqN4MfHwW8FJf4LtrURZSpok24kbF664GIk5mOS2i0IupHwEruwxWvPZvItvLJwTFAzMUdvgAYE61kz82faj8QCXLwPp8MEQZ9KlpxqRZF8NosZL7KgHC1l4JlwuIMCx+3EoiTYYug5BtJPeF0PRxAr29Q2LTlftIdVV33tn23GEkGrB4F5k03fnNNkRZDm+vLE18gCRo9obiDN4YLTmmQFSEiApFsQAVwcl9lDo0Vpaq0vMVVql8lMX9F4lKIMxgYUyqqy3xMcdxKHeVA1CSEw2tUaS1pxRNwcyeM//SeR1uHDWNSllZGerq6nDqqaeKy1JTUzF69Gj88Ue8LbGAUCgEt9ut+Pe/ADH1oz80ogIQl9U8a2KPi+ndpuPukXejwFaArUsqFOtiTcJaA+vzIcSn8XJGOJGp+wz4+Ezgp9vBmON1BFF+FhA22OHnfUk8LvJg0MtSAEL7em16BtIuvAC9N25A14tPR9cBJN+75KPd+PThtQgHpFmFTn94b3M5WTTSyUu/4yB4ZqR2AUVRyLn3Htgvkpok6l0NREDKAe7BN8HFPwAF/5LKvS4468lMSBDK9h9HRNJ6kxa9R8XrZfybN6PhhRfQf3Qm+oyWKoYMvKlZiLEgyBOVoVOLkJGf+P4SKn+4UAgDTipA9yFZGHeqHXpKeiimukqR1ah0T+YYJ9mf78sUcLvhagy0m6gIJDPW6+NYAkXTsOfkJhS5m1PJtTWGaext2QuGlT7HzX++BR9N419ZPCENuYFF9wNV6wEAUU6HhggRXOoM5B7e3TQEAKBBBFdnX44sbalIHPf2egvUGc8o3n9PmiRe5wCU1JrgsRYgyg9IJ1/WF1q9Bp6mRsV+Qqk4qBhS2rCLpDAB9Gh0oqcp/v7XUgF00ZNoR8BQFLdejoPNfnhDUZj1GmScMBm5RvKcqw3YEGFpnKdZhWD9PjFS0hpW7mvEgh1k0C7iW1X0yiaEp9ETwjcbK1HaKN13HUn/nP9WzJhERUGZyfNN6HBu0poAjnxPb1/eHxN7K4XWid5NSAe5k3RPbw0sy2HGa6vxzM8kOiR8Zmcb5dwrKleIZpWxVhPyZ5tcCnC0cdSISh3vXpmTo3y45uTkiOsS4emnn0Zqaqr4r7CwbaHW3wGHKqZtD1iGRW2JCyzLoa7MhdoDJG+akkUeaiWbG1Ff1j5iGNy1C2BZUGYaC1JfxHznI6RUcftXiBha/wwhQypqc0Zh3eC7AQC2dDIoe1asgPun+QAAbSYhJho+BSSf5fucITgbyICu0dGKXiaHA3JH4Ive3IlNB2UErmoT8MlMYAuvQfjzB2DVC8SFdDFf7i1LF2XfeRcM/UiaJNrcDDtfAlrmyQNH62AItiCzmTz0V3wmlVla08g59ByejXPvGoZz7hwKrS4+0nHw4kvQ8v4HaPngA8VyPf+AdFDFCJgy+WMmt6cXiAobDCG0YjH6r30Jhgcvgb2eVI5oI17oI17k1a1T7MexZIDRGcnv3FHnADhAb2rfLFkQQkfCxy5RaQ1mvpuxjqHh9bmwTnZ9wmyCAaVSMoAT/HQoCkiPIZDmNBNMxQNBTX0Mw04jZGRHaQHoMVKUbs3AAdiU0wcBQzpC+lS0pPXFNmokNoy4Hx5enyZMACr+3K44vs5gBDb/F1ivTD2hcQ/YgBTBFKwHAEDL/wbTNDVI1ZAoiSNjWrJLAwB44RdyT/fKtoIuGoUsow9WbQhhVovNLV2gpVgMoUqwer9EpOT+JL6QFOX4eoOkWclJIfdyhpVEjsqb/bh7rvIzlrSjOgYgEYjSpphUpbkEtNaLTGMWxuST7tcURQEseV9aFy9fKEyL958SIiouWUTFF4qivKltIu8KRFAm225cT/JMdCRxzBXw/YHvxb+N2pgWERpZpC21uM1zOFI47qp+7r//frhcLvFfZWXiDpB/NxyqmLYtcByH5Z/uwXfPb8Ku1dUo3ybZJQ88qQvS80kodcXn7cvthkpIaDWSkwU3k4vq8EDURvgBWdN6zjpksKMpUxLuGsxaRBsbUXXDP+FZTPrgaNKVzpex6QhPE3mY6g6zPgWQvAfYqBmegA6z3vwD1c4AEA2RqFHpCnAL78bqHQeAb64Alj0OPC8zDusiidY0VgtyHyIExvfHH7CZiWqh3kkeIGZ/fVxlk8mmU5CvvJ52ZBbISpn37kW4SmnwFZCl4QAoQvXNGSTdk5ab3MiPNvIRlXAI1bffAf9aMqD23j8XXapXYcTmFwAANm8VeurLkF2cgpMv7QuOJYOB0UqISu2+zYiGtsNsIw9Ta0YmLn36laTvK3yvkeDxSVR0BiO0enLtjGEN1lST0vrWHHYFCH46RosWFrtSi2RKMQFX/wyM+z8MnNQFtIZCS40PbkcUXf85HF1PbURB3xZENSb8MeZxrB31EBxpUpfesMEOgETi/G4XVn2qJLJanRaYd3P8STXuU1QQRRslAjHr7mHokXEAU+wvI0NLorHNzcmHmIpmPxbx2pEuaSYgrRgaisPwdHLv1gTIPd2NrsP8bTV4Z1UJftxaDYNM+/HTdskrptYlEaiT+GhGpiW+zYhgbX/hO2vhD7edUmzyxhNKjY48j4tTekJHS+lWjicqDYGDcfvcdHIPzBySj/evGCEuSzEKqR9yHjXOAE56fgVOfnEFfj+Q2LI+wrB4+ufdeHOlUhA9qhvfLLKN1I9Q0QkgrjRZ7kourwI92jhqRCWX7wNSHyO4rK+vF9clgsFgQEpKiuLf/wI6O6Ky5rsS7PmDPDRWfrEPmxaRH1p6vgX9x+dj/Pkk1+xpjp8pJALjJCLYsFn6fg6GyAAtGIVlZGmRm6D76N7eF4HNLRZf9xmdi3B5uWIbbYZSxBqrQ3Hy5ZuHu+IHAG4ddituG/AsAhXXQfgJnfbyKmD3fIBvc0BF/Djw1b/id7YXAWOUXUtNQ4dA36MHEI0CawkR80QJmTOGnEh1KR9IVzw9Lum5RRsbUTbzbJSceio4WZdlLqx8eAkiVTlyuycv7aX4wbbiSqWuRh/xoM/+r2AOSAPW4Jx6nH/fCPQfnw+dnr9faJkw2L8EFEjEa+z5FyOne3JvHpGoHMOpn9ZAUZQYVTGFaNT4yMBa5lYKkBPRFsFPx2AhvjJymGSmeAazDtld+RRHpQfmUWNgzozgNM1G9NaR753RmlBZGN8kk92/C3t+j28cqOOS/M4jPrDf3SK9lJnKZRbYMO3EP5GmrUaGljw/mquTRy1Km6R1s6vWoebRJ8FGKWQbyXIHbwLXlarDtioXnlq4B7d+uRVOWcTglSX78fGacjAshz18emjJHRPFPkEpJi20MkfdR87sj9cuGgqaApp9Yfy0rW2H7IPNQtQiCnPxHPQeMA8GPYkIbj8YFjUhHMeB41Ohz2x+IO44NqMOr144FKf0k7IIKfzvUEj9bChvQZM3BI4Dvtuc2E34j5JmvL2yFO+sUupsJvQi18sTiiIc03RUDsHi/4HRD7QqJfgrfl2dhaNGVLp164bc3FwsXbpUXOZ2u7Fu3TqMGTPmaJ3WUQHHcdjbshcRJnHI7sOdH2JB6QIAnRNRYRkWWxdXJFw3+fJ+0Bu1yCoiD8KQPxpX9ZEIjMMJAAjL+gj5mAyA1iLEz1BSMw3ov/dTnLT535jxz4Ho1vybuG0zQ3L7o7Vr0X98PsIVyshZWxEVB+8rYrAcPg8V8b1pDczMALAhiVB7Q1GgUpn2uEr7S/zOs96Pq4KgKEqMqmhrlQOYxUoj3bEHxhBJLw2dWiRanSdCqExWPrl1q/h3nOHXcuW5FfS1t0rqhNRPe0BbSfQtEg4hHCQPeSZiV56n3wkAsNpbr5oSImLHK1EBAHMK+Q0YwxqxRPmXMuX1d878T9x+QurHaNHFufWarErRtCCqdtT6gOFXocVItCEjxS7IpHQ8Fs1PPoKavbsAAGPH98egYh10WhojG2S+SzkDiWmbhQyEHCMNYL7Vq+GUFz+ECFnI0JHniac5GNdpm2E5/FHSTHQnHIc5699Bzof/gWvut2jcYUO6nkwynD4OUZZCMZWo6SFBrSuIf8/7E2+vKhE9ToS0D0B+W4JxHABk2gw47YRcXD6mGABQ0tR2+ueHrbybtKkKGlMlapk10BlIWtwb0OOlxSR9FYywoLXS8WK70CdCrJi2ziURxBV7GxJG3mqc8dWXpw/MQ4ZFDz0faf1qQ+LnOQA4gmQSKY+sCLi036UAgKfGP9XmuR9JdCpR8Xq92Lp1K7byD8yysjJs3boVFRUVoCgKt912G5544gnMmzcPO3bswOWXX478/HycffbZnXlaRxWN/ka8tvk11HoJk+c4Dm9sewPnzT8PT657UrEtx3Go9FTipU0vicsOh5g2FIgqZjqt6U4E0abBrBUHyNi+O4nAOJ3kvfR2cZmfTQPYKIJuPh3Al+ZpPM3I8uxDZtnq2MOAriCCtbiISmYbRKXOx593B+zt24sdc+Et2xS3OFBFcuBNg5WVGZ5eZwMXfQU83AIUjkp4SPOokaCMRuiDSiv8jBMHgQKHwdtex7Sr+2DMOa27SDIt0v7uBQul5S6lT0NonRTJGtPXhan/GNDqcYXUT3ug4Qdmb4swSGoRiSgdfd2NpCrIktYGUTkOxLRtQaigMYUIUeE4DovKFym2aUqgo5KnfmLJqcmqJC4CUWms9AJ6M5ZkXAKG08IQUopZM5p3Kl5rfS3w7yfpKHvpdzjVuAw39liFVA3/fBh8EfDP34ABswBeEM1GlTPtuof/Lb0IOMk5n/ucWLLcUqPUW/zz00246N21ePrnPbCHvOheI7XScFXYYNaEodNS4DgO7ogRxVS8ZlGvpWGX/bafW0TIgpamYI0xeMywSveuQFoKeeFpVUvrlgscx+HbTYSoXDlRigrSZjIh4BgTShrI5/OHo+A46TkUiratwxJSPxvKW7Cnzo1aGVFp9oWxen8TmBjRb71bedwMix4vzh5MBPqp5LM+9OOfWLA9cbRIiKgkcpy9ddit+HHmjzizx5ltnvuRRKcSlY0bN2Lo0KEYOpTUYd9xxx0YOnQoHn74YQDAPffcg1tuuQXXXXcdRo4cCa/Xi0WLFsFoTC7qO97x2B+P4d0d7+LaxdeC5Vh8sPMDvLXtLQDAt/ul7pR3r7wbZ3x/BpoDSh8Km+7QUz/zXtmCLx9fLxp1tdSSH1pB3/gb18jP3CiKEmd1HSEqQa1dXObjCLkIuQSiwn/PHAfXDz/A5q1C14MxUYjKA2BDIYRKlWFOTXpM6ifm4SR8pkTluoeEuh3At9fg2t1X4gqNdK5pcJPSUgCledPxcVQqVa3NnwL0mQa04vBI0TT0hQUwhJWEIrVLGmibDRZ/PfK09W2GY+WaAbmBXrRGCtGz4TD869Zi8LbXMWFoAMNuOydu4Is7P1PbjSiLPvwA5unTsdHThP/edyu2/kKigFpDGigq8W/a2hZREVM/x255cltIySKhfqtfC0fIgX2OfSh1lUJH62DntSI+rez65xDSGBSJig7dhyj1Akab8r7OKiLblm9vQmOFB2UROxoiPRFmldf3xC7fQMubffXa9wUq0/SobCIk0KyNgKIALS0bGAtlvb34iC8XQ1S4cFiK2AV4omxKQwava2uStXf4s8aFX3dJEZJCjzJawgQ4hHr/HyzppLQ3wOhg4QJIZ52K7QrSTPj0mtGYPkApE7CbdXG/kZwUiahk8qSlII1ENascxOE2kESs7QszCPAR5FpGmkiFaBLh5RgzglGy3h9mEKiQWjsEmbbT5AXp5DyCERbTXlmNj9aUK9Zf/sF6PPPzbsWyupgmi6cPyoszlwOAj/8oj1sGSBEV4d6Tw6g1ors9vhHn0UanEpVJkyaRvF3Mv48++ggAGfwee+wx1NXVIRgMYsmSJejdu3dnntJRx8b6jQCAg+6DmDJ3Cl7Z/ErcNlE2ikXli1DhqcCyimWKdYeqUQl6I2g4SMKzJVsa4az3Y9968rBIz7fEGXvKf/TCDMkRY9eeCAJR8WokLwEfS4iKv4U8uEwyFXxoDxHpZk5Wpv30ITeY5max1FmANkMZUYnr58M/a43Ww0xUGqXKm0d1H+M0egMGUqX4t+4TmKgwytgc/PeACR8w08XtKox9Ex0pDrqirjCEJKJi8xxEfncLrCedBADwLFmabFcRcqIS2iudK+NygfUR8tb44kuI1tUhw7EbPYcl7zkkR8qM6W1uYxkzBvtP6Imdv69AQ1kJNi/8keyblQtQ8REZWqMRbdqTQSIqyXPuxzrsOWQwTQ8SsiY8A7qmdEWuhazzaWREexAxdJM0KjrY0o04+VLpPjLblMSyS5805HQj13Lbskpwu4vwp38qAMBPcSjUb8VI6xfomrodYzc8jImr74DWtwV78qUogUUTBnRmYOS1/Il3BQZL5fPgUxksE0+Wa+7ju3PLiEpWGvnODmySXFMX7lDO8gu9ZJ31pJNgmTCBfG7DIJhS7ACAFU198NrecTi7dr6ieefU/rkY0CUVb146HJP7SvdwbOM/ACjOkDlJCxEV/tmzrcqF6a+uxsgnl6BeRgB8ER84jhOFqUZrDVZXr4w7NseaUO0gUZlAhAHj7wGw5D2EJpQsx+Kg+2DCNE7f3BTRlE6ONFm06N3VynRwbDfo3FRpEuCW+Q2VJOghxLAMXPxk6Fjo4dNeHHdVP8c7Mk3Sg6HBH2977Aq5FC3hXTEz7ENN/VTvl1IDzdVefPHoOtTsdwIA7NlmnH2nVJEidzQFIFYe/PG9sklgIjBOJ1iKRjMtpSpCjBleJgPOJvJDsx/8HLSZPDBCB8gxT5hUCJNstqhhwwiVlCISU92liylLT6avMJoPs0bFo3zQvq1/GfMND+JsDQmfPxW9BPN3NuAgl4P3o9PxTvR0VETb516rLyyELiI9XPJr1kBjNsE8glQJhMvKku0qQk5UYnUpEb7sv+Xjj8Vl2rz4ZpWJkDJjBnIffwyF77yN1JlnicszriWDWsoZZyDgcWPHsnhdTmp2Higq/lFjtqeBolt/BB3vYloASM0hnkV2nqhsbdgKgFRVCIZbXr0FGHg+MPRS0goAUv8eoTmjXdG9ON49uKAPGXj2rq2DLWLE3uDJAIBqbQhnpT+KUdavQVFAaoYDWiYEh0UZ5TIgAoy+ATj9BeDuEuD6lQo9FTv9ZXiqjIgGyHdiGSeJut0LSPTMX+6Gt9aAiJeF/uXbyPvvdYjf364aZZq50EPuV323btB1IfdipKZGbIhY5+MbYoa96O4vhxV+ABxmDJQiKY+eJXUiTxQZkWtUBEdYIZIBAHvqPPCGotjK959aU70GE7+ciOc2PCe64tpsSXQyrAGN3hCiDCtqZCiOfDdC6ueVza/gjO/PwBd7vkh4iFcvHIIHT++HFFklnlxwCxDDOQB4/7cyLI2xy++fJ5H9DFlktNkXjjPL84Q9YDlCIFMNR7Yv1qFAJSpHGHKikgj7HPsw/Ttp9lrpUQ7Qh1oy5nNKaZuDO5oVpkf2XDPye9px01uTcdNbkzF8WrFi30GTSE+LgCfSpq9FyOXBjgHXIQoj9EaN6IWyyTcLLid50NnLPgJlVgpLTfvexZVPj0XPEdno5SE6isprr+X7xUsQnFIFJCUqhzuiwnc/fj86HfVc/IxkFSuUVVN4PHoZnopegqZ2NhzTdy0CBQ7dsR8Zzt3IrV8H2miEJp28D+NwtLo/x7IKAW0sIrV1YEPKB5cup30RFYqikHb++bBOnAg6RXrAZd5yMwrefAO5jzyC+rISsAwDW0aWollfanbiKj6rve0Z3d8h9WPPJoOOyUsiEZvrSYow25wtTjx8UR8w6z1g5hxAo0P1OW+gxUr8UYRrIK9sM9niU3Wp2YnTc5RGObPOG+FC2unj4TMoj1H7e29w4/leRZZMwKT8fppX16Hqt3R4qshvVpuljGpyoRAO/mxA5coM+HeVwxRoAs2Qe9/P6yp215Jo7l1Te4OigP40OTddUSF0+cRFNlpTA5MtfhDtG9yP7YZr8Yz2XQzsIq0X9CYA0OSN14X0yJImdxq+AkjQhshRxUdGrl9yPcJsGJ/u/hQtPnI8gymxjo/jNOA4UsLs5UXDNMh1FVI/H+4k7r7Pb3w+4TEK0sz4x4Tu+PzaE3FK32y8euEQjOmuvLYljT6wLIfHf9oVt/8J+amArwkoWYYXzx+M3jnS593foDTKc4TIM8SmsynKqo91qETlCIPhWh/gF5QuUGxz0CXV408unHxIEZWAJ4zVX+1Luj6vlfJUAMjvbYeW9zDwu+IfCH53GPVlbmxfVoF9WaeK/hzp+VYMPoVEQGrC/RHl9KDAwKZpAG2Q/VhoDpp9X4D+/h84baoPfei9ce8BAMZBCZojJpFuHHYxrZuIQA9y2TigUUacHrI+ihDiB5D2WoDrConwsXflfAze/gY0bASU0QgN3804sHUryi+8SIyMxMK/bl2rUZdITbUYuQKArDvvAKXr+PXJvOF6GAcMQO6/Hwat18N28snQWC1oLCc6orxefWDPkZyPU2KEzwIsaYmXy6H9G0RUrOlkcqIJs9AwQEOAzIizzdmwaEnUxBfxodxVjgd+ewBLDy7FtK3P4I8oGeQFoiJPb8ZW/QCAPTuxP5GFVurcdBYGOc+8Do9Rea8yjQ6wkeT+Lq7585THSVO+X6RKKqP/fv5GUJA6g7c0B/DBb2WivuKyMcVYdOtE9Kf9aMoYgJ+35aKWIhOhSHUNNLr4SOgglICmOFyoXZFUqxVh4s9/2oBcXDK6CM+cm7g9hIDnFu1BlFGmGFv4Fh60jgzw/dL7KdZnacizaPKLK7DoTxJt1fBpTiH1I0BuT58IA7qk4v0rR2LmkC7IsyujXdVOv6Ln0asXDoHdrMOQfAvpF/TGicB/z8HQyFb8evtJuD7vAB7S/hcOj1LILAhp7UZ7q+dyrEElKkcYgnFbLB4d+ygA4NfyXxXLhYdaF2sXvDDpBcU6jmHA+tvWiwj45b0/k64bNLkgafM+juFDmhQl6lR8LmWUoLHSg4/v+x1zn92I1V8fQHWXk8R11jQDUnl325YomSWmaauhoRjQfqmMLmuAh2hk/vwO+O/ZYH3KWYw1P4ickX4UPhPvUSB3ZdXKDKFim/YdMpwkwlXDZcKrkw20N63Heo2y++iEXmSAOtjsb5fTpL4rISrh0lKA90ChTSZo06SZbWDrVlTffkfC/f1biH298YQTFMsNfUhPp9CevYjwolrT4MHI5NM2HQVrNmHH2KGoK1KmjRorygEA2cXdpV4xAFIyyd9a0yTF9raMtonK3yH1Y7BYxAiTISzdp9nmbFj0hKg8t+E53LLsFswrmYfbVtwGANCx5N4VIinyezxRRCW7OLHeJ11DooC72UI8HrkUc7q/ibqGZgT54w2obMS4fbw4tJWGsNpMZTSXZpW/z2i5LCXMV+npI+R5d8N7G/AYHw0YXGhHqkmHPrk2ROsbsH3gP+F0AXuqSGQkUlsLv8sZ9/4Uk1yndN1EMmm4/qR4IaiGpvDkOQNx4ajW7fxDURZXfqLUodz1DYl+sRpivtY/Q+pG/tzE55CTQr4/f5jBp2vJs0zLE5VgVHktjZr2F4nkpypJTY0ziAYPOV6mVY+ZQ7pg7YwGfO88D1j/LuDjU74HfwfCftzveBjXaH+GvexnxXEEIe3xpE8BVKJyRFHpqVR2UAVxOX107KMYlk0GOU8kMZGZ0W1GXKiu8vobsG/sOETbSAkAwI4VVajem3i7yZf3xbhZiU23vKtWYe/wEaJXgqBT8ccQlYZyd9LeGRa7IS4Fk68jpElrlB4+pkzZMSN+sA5l/xF9ShTpPZzQro4nKnk9UzF0ShFOvbIfRp7RDUUnZGDCBb1R0Psw/yAd5QCACi4bm1NOIcsGXQhk9UFUZrA25+JhuIL3alhf3oJJL6zA7lpyjWJFdfXuIPbUuaHLywO0WkWaizYYxIiKgMCWLfDI/IfCBw8iUlOD4HZitW89aaJi+4yriUmbZ9kyVN/yfwAAbXb7Uj6JsG/t79i/bg0WvvY8Gsqlaiyfg8zcUzKzFCLZlCyi0dEah8Fgvx39J05G8ZDhGDaj7YZnfweiQirmSLRyhI3MwDWUBgMzByp8kcrd5Yr9dAxPVPgJhMVuwKDJBRg+vWvCVKdGS2P2U0oxepqmEmcb5wIAAtpUvM/MwJI/yvD5HdcBFAUaFIqhRWqA/PZYGVFh/X60fPKJ6HKszVSmrY26KsXrwHapymx0PalU0fMTM6Ps67v3NEKcGY8HvrD0XHA6OXAg+rYufZVkGwCsuuTD1d2n9cHcG8bgzil9km4Ti9MHkqifXMOypqxcsQ2l9QJUBB6ORLaH5wwX16XqU5Fji58I6WieqDBBxW891q6+NcRGVKocATTwkdksvlrS+NNNoNgIsPAuaUONHtizQHypdSv9VMTS5I40Uz0GcPjdsFQkxdW/XB237PmJz6NnWk+wHAuz1iy2345FIuGT7zdikOb55VekXXgB6p9+BqGSEhS+9SYorfKrXRWT8knJMsHNu7dmFdmS9sOpvvsecMEgau+7HykzZsDMl/p5HUFEI4w4y2utZFln0MQRlbxMNxABtCbpCaY1KGdMnKcBwi1qzg4hox9P4spWkciGXRLUUhSFsTKyJfQ/OawIOICgEwBQyWWhNr07cNEWIJXM1OR+B6cPysOOKqUQevqrq1GUboaGprDw/ybAxA9AF76zFmVNPiz4v/EwdMlH5CD/cNFoAJ0OmtT4777m7ntQPHcutOlpKDnjTFAUBQ0foTCPPhH6BQsRbW5GzxXLwfG9WaKylNFfJSosy8DTLFl7L5rzEi5//nVyefgGoaaUVBit0gBsy5DOn6IoTL8pcUQoEYQBOXocExUAMKXa4XW04K4TbsWNhUak6lNRYCvAssplSfcRiYqMlEyY3XpVZFQDfGgLQk9TeK/bfnSpeh2pGiIEpTTkNzi+RfLQScnJRbdn/4OSaUQXJ+/j0/z+B2iaMwdNb76F3n+sEY38AMDaJYBIcAvyn38JNXffAwBoePvzuPMRUj8WTkrVdOFTRqH9BxCQae7CIRYhgx1GnxODJ58GrU6P4n594P7PZHxTMQghuU4pEgR00mCu09AYUdw+0bqAp84ZiPG9MjFjQB4u/3A9tlU6QWmVE0VK68b0YRRWeyLINmdjaPZQcV2KIQU5CYJYetoAMCSi8urmV8XllZ5KbK7fjGE5w+J3ioFBq8Hqe07GNxsr8dqyA6hy+NHI63yyBXJE68RKLBG+RsArCX+NXiVRaQkS48hEpcnHMtSIyhFCmAmjzhevLRBq1mmKRr8MKf8ZGz3JsShV4Ap7dL66o+Xjj+H77Tf41qyJex9jjFYjv6c0eCQKIwuQ54LrHnoYFjvZ9ve5B/DhPb/D1UiIVaAVosKyXLyT5uSLAI1BQU40134H3HUAGE8GsryRTlAaDrnDneg6uRnai96RDrBC2SX2iMBBZlU+XToCMJKZWHp3gC8tjc2PZyWYbVW0+FHW5MP0V1fBF4qC4zixsdg3G6ugL5IIFm00gqIoUHq9Qkti6N0brN8Px+efI7BjBxCJgAuHEa0lOXJtVhaKv52LnksWQ2O1QpOZCTqm1YQ2q3VRdyKwDIPPH7gLa775TFzWVFkBJkoGEL+HJyq2FNCycluT7a+X1MsjKu3pj3OsQoio+F1OnJBxAgpsRI/hCrmS7qPjy1xjPYJagy/EoEnDwWfWoP+5U5BqcIrr8tn4548tMwv64mLo8kkaT5768a8nXZwFETcXkn7jXU50whppAk6Z2ur5CESlp01KZeTxaQ3/+nUKogIAHiuZfFDBIIacdjrsRT1htJJ7NyjXzwSdOFSkmnW4aFQRUs063DGFEMBY4XHXnADG9SPPqEGZgxQDvEFjUJQGS8vJskA0gKUVSkuBKxZd0e7zK0w3ix2YSxq8YupHJCr2wvidfE1A1UbpM/qVPYec/HVTUz8qEmJvS7wwNNOUCVpWtjkiR2pWNS5/HCiZQrS3XTmTkmtTuEhE8RCPtrQotmWiLEIBZdVERhdpxptImCeAlg0yrh9/RH5Pu/g6HIhi62KS2/Z7yENswgW9cMk0NwZvnyOd+8icuN4yqUPHAw/Wg+p9qrhM03ssYM0C+AZ2lpww+syqRVovP5DZm/hLXPw12bhqfdJz7jTwaZ8mHXmo22PI3xmDSChZUN1ntGKiVt7sx1MLdyu6pm6pdGKxU7pOHmhx79ztuOKD9eBk90nqOecAIANI8M/4KgBNmp0QFD5lRFEU9F2VESYu3LG28kw0ipcvnon6UqWfDcexWPLeHOxY/qsioiI35NFodQl7C7UHAlHhOICJHL9eKoJmx+9WEhNfJF67tPDchbhv1H3QJoiotAWhyZ5FrwGy+wG3bASmPQuAwry0+EEyJYtE1oTqOzYgERVtlkQiau67H2yQRFtqrrwZ0AJaisXmHfH3nxwmvgdUgUb6regFh+vNm+E3Kydg3nRiZyB3UjYahVSK7PfmVwqEDxVCiW9sROXkwX64Q+S+TjOmwaKTokr51nyJNMhg5EWzwWgQXIIuTu5w+zrQA0CvHPL8rXEFcYD3RRFbBEQSuOr++R3QIH0nxb7tYH55AJyLpOmEqh81oqIiIYRmZABwUd+L8NjYx/D5DGWo9MS8E8W/003pipupKEUpBGM90g+K9fkUnhmsW/ljczcFwLEctAYN+p6YC3uOGSdM7IJJl/TB1H+ckDTtAwC0RWl9XjQgQzHoVPG6FyH1Y7LpwTgcyGjZhTPKrselt+cho4s1TqWvN2oBigLVTXK+pLT8g8gizfap0x4DLvwCuJyYhyFvMPm/+UDiH2pngicqtTR5uGZYlETk9im98dx5g/DpP8hn0mlosRwyET5bV4Ehjy0WX2+rdGKdVhocPByNrzZWYuW+RtSapRkQbSGiQzYQgPe3mLYDNA1Ngkad+iLl/WMaMjjpeSWCozZxgzQA2Ll8MX596zWwDBkkTSkpccaB3Yd0PIIDQCHwPp51Kmb+O4kViV4z4Jq4bbNMWbDpbVLqJ9bMsBX4+GtkFqIw9iLgxBuAB+qwPWMaACBMSQO+0BCSNgpEhTi1ctGoom+U64cfwAVJ6qGF1aIJJEKUvVlq7yHAn27EVVPuh2H2hbDyzz3aHcEFwwvw0mzpvguXH4TbRu5LoSGmN5UQatbtBuMlpmtGnqgxHI0Iyz+r3j2l3dekPciyGXDL5J6g9YQAsWGSRtrWtFmqlDHYQVEUVsxegSXnLYFFZ1H0FRJg5rUoQSaIlkBL3Pp9LckrL2ORatIhj4/a/HaAnFt2ioGI7QUBbVo3YKzUKBKMMrqt+eN1/PifOxGKMq3a5x/LUInKEYKQ9plePB3/Gv0vnNPrHORZ8xTbDMkeIv7tj/jxzMRnoKW1GJM3BlpaOSNlZESFcbQoQraMR8nYnQ1kQLdnm3DKlf1x8SOjodNrcMKELug1QjmjiQOjjMTo9Bqcc+dQjDmBsHuBBAU8hCiZU/RgyolQ1mDmkNpLigT1H08iEYX9pVyy/YILocnKhH32bOlNZEQFRWOAvjOAFL7CxJoDmDMAjgUa97R+7ocbPFEpYwiZyItR5ht1GsweUYhsm/TwWnTrBHz2j9G4aFQhXpo9GOv+dQom9UnuhbM5S7peWlYamDfbu4l/07ydffDPPxHYqOw5RBmNoDTxA5u+h1QNYb/wAtEFtL0IyO6pWf96DKffeg+6Dhoat53OYIROb0BBP2Up6PjZvdFrRDZOvzFBaXkroGkKWj15TB3PRMVoITPjoFeZWuiZ1hPLZy9Hhq8LxpfOQiZyYNQaYdVYoeWNw/5yREUOnVFcFqJlTfoKiwGQNCMARGtrsX/8BOwZMBCh3ZJ1uzYrS4yo/Li7GQ2cHQBwQv085AxTRokqi7Phy8hB5oSxMPvrQHEsQv4oHjq1L84dRlJeXCSCUG09vFbindJ7FHkOBXmRZ9Obb2Hf6NFoeP4F6M58DlqaRNPcET6CEQ0ALcq2GoeKO6b0xsjehIxFXERHUuYqEytlBJ1ghilDTMVnp8RHVMy8SZ4r5EpYHFHra7tjsxzds/hWBLxHTKHWCeycC7D8s/mmdcCIeP2jHN3C+1DlCIipHy2sKG1suyHjsQKVqBwhCERFsMxOBC2txUuTXkLXlK64esDVGJs/Fr/O+hWvTX4tblt5RCXa4lCo9Zv+8zqa3pb0HM56kiYSnC070r47ylvhC+DCYWQW2GB4435QLAOW4bDqy33ie1jsBgS3kzJZQ88eiv42ky7ugxn/HIjJl0lW4Nq0NPRauRJ5jz0qvYncaCovZuZPUUAmr+xvLsERBU9U9gaJaDVWmZ8IvXJsGNczE0+fOwjnDitATooRH101CpedmFjs22S2i39nBiVy8FH/6diU3RsvDrsA+1xRcADqfG6ENTSMg6XBn0tSrm4/91zpuNde2+EW7gJRye/TH8WDh6Hv2IlIyYwnXIKjaO8Tx2H6zXfiqpdJHyuDSYup/xiA4kEdj6z8HSp/jHwVVNAbP3BlGDNw/vZ7MKB+IkaXkEooE6TUbEeIio83YjTrYyY20ShSytbBGvWAlXVR1uaR+5AykcHVvfBnMM3xaRU2HBYjKiGNTmF2mN7bh7Te0qC3nu6LPjk20CYzNGwUJobcO64G6d6M1NbCa8wBR+tgtGjFiEqY95XxrlgBMAxaPvkE6H4ScnqT/ke1Nlnvof1L4s6zOdCMy3++HB//+XHcuvagwkOiSOf1mwINpUWEjWCPg0yIEqVL5JMSAakG8l3vd5A0qZbWKqqFhA7abWFj3UZsqNuALnblhGj80nOB73hrgdQiQGsgWrliafLBapXmfzowcAUiYurngbllmPziSlQ52m9vcTShEpUjhHo/UWLHimJjMaXrFPx0zk+isDbLnJWwrE0RUWluBheO8TV5+WVEqkm43sk/IOxJnCuTgeM4sE7lbCnqcAIAaI6FMUTCmjtX8eWLOhq2DCMCpYSUmUadpNiXoil0G5wFa5oxZnnMbZg7iAhqZ75BfoSxSOMHeefB+HWdhUgAqCPdkXcFSUQo1uugI0gkwkuG7lkWePVmPDj2OiwpGomDAQ41aVas75GPTd1yoe9S0OYxdLm56Pr55+jy2qvQdenS4fMN8vebsuw4/l62pJIBjKIo9J9wMtLz2z63tvC3ICp8FVQioiInjVlOcm+bODJgsxQT1zm5Nfj5yhhLDLnZ+ssCGDb8iMsrP4eOJc+KL/LPR4Ql2wkRFcEzKRas1ytGVEIavRhREUBrJS1GGZ2Pa8Z3A20ixzSFyeDoapJSteGKSrhtxQCI/4sg6A9TRnBy98ZoFKH9+5HXi0xualLGAhl8dV/M798T9mDS15OwpWELXtio9JyKxe7m3XENX5dXLocj5IBJa8LDp01GIS94LnMR8pKIqKRb9Hh85gmKFG/vdHJ+6+uIji7dmI4PT/sQN/Jd1dsTUfGEPbjql6tw9S9XIztF+f3rgrLznnC79Pfgi8Q/aZOyUrCAaoTTHxajQ94Aea5uKI9PTR2LUInKEUIFX8+eb2lfb5W2oIyotCQ0aqp77HG0zFuAqt3kZkzL7SBR8ftF7QvF9+RhHC2icNcYaFJsn24PIrL+V7AhFpSGhWHirA69nwiKAk79NzD0ksTr7bzewnEEiArHAcufBt6bAvibwZgysZPrBpNOEyem7QjkIrw+OTZ8f+NYvH/FCPxrRl/sSZPU/PNuHofnZg1StK5vjFAo403UHBYTdAUFSLuYPKT03aQUUSzMw4YiZWrrVRrJEBAreiRxtSDEVLxHO2zxOwqh6uV4Jioma/KIihwGju9tEyXXOawJIMK0X/icLKIiiKA1YGFmybMiTOvhjxBiQ/GkItoQ338MAMAwYPhJS1ijQ9dipbHaekqKkl44oSemD8wDxZMfU4g8JwQ7BACI1tfBlUruVUJUyG+JA41oTDQgUlWNrgNIZHX/lq1ghlxOVsjKcAHg233fKl4zbOL7pdRZitk/zcbMH2diTc0ahHlNh9AA9vze58OoNcbpApP1xrlsTDGuGS/97gbnkEmm4DCeYcwARVHIt5Jnf3siKutrpWIBvVF+z8SIc3vItDpW2e/RqDzXFMoPtnINvHwvMS5KiPCc5SUIR499kbpKVI4A/BE/9jvJg0LubHgoUERUWlrABuNt2r0rV2LbC1/A3RSEOUXf4bC7oLyndDro+YZhjIwUdStfqNg+27sUvldJrtScx4HK7qRO2HY+orJvEbD6RWDpY3G9gA4b9v4MrHwGqCdmalVD74AfRmTa9B1On8hhk/Ua+fbGsRhalIZT+uXglH45eH74xai0ZuHlobPRxW7CiOJ0bP/3VDx4OnkA7l67GG6zRHRCWTnIvuceZN15BwrffOMvn1NrCHjIvaCMqMQTFWtaZxCVv4FGJUFEZd0P32D9j3MV29EMIRicg0QYnKYGVHqV/b5aQ7KIiiFGFA8AYVonNvETdE/y9gw9ly5Bz5UrxNdCSojRG9C3Vy/FsbbJmo9aUm2KY5p40eemRQexbRn5LJ6qJjRmEY1TYd90aLS02HzR8o8bkX3XnTCfSIoLWK8HhQMGw2ixIuj1YOVa/nrIiMrC0oV4fevrinNqiplICdjauBUA0ZBcv/h6fLablNsLUe8+6SS1LC9uAFoXoAYj0r05OKeHQlOYbiIR2CwzSZU2BROfl4DmQLPoTgwAtKwS6V+TldFQL5uFBXO2oeaAE7DIUrGGFGD87YptB68lz2YKAFhCIg80ePHtZqVx37EIlah0MkqcJZhfMh8sxyLbnN1m6qe9YGXle4zTCdaf2KJdeBj0G5dHKm06AIbXp2jsdmh4V8poU5NIkuzuUpzev5w0ECyoxSjbl/A3kgespVcW4ko/DheyeI2Kt56QlNUvAo4EPW7cNcDGDw+tOqhyneLlD+FRAABzgl4kHYE87yyPltQs/xFDfJtw78Rr8WvXUWK3V5qmkJtqhIEJomtIGUl6ch8D2mhE5rXXQl9cfEjnlQwB/js3WqWIii09nvha7B0z3WoP/h6pH3LdAl4POI6D3+XEb198jNWff4SgT9J3cCzgaQnCUUfStS2mOjEa2x54RTGt8v70u1xx24ZpPQL8ACukfsB74tCXXAFdly7Q5eSIHc6FyUC3IiA9R9JYOTkLqjXSvZCSSkiZkPoxOqTz/+1rMmGrqAJYWoc0YwB5vKeT0Ppi0Z5irPGPQMRGBl7G44FGq8WIs4jOqryET514SfSnJdiCe1ffixCjnKwlS7EI3YMFvLntTQBSN/scvmT63F7nomsK+ZxXnnAlimzJLfiFhoQAYNTpxWMAJKIi/3+/Yz/qffV4ceOLuGnpTYiyyoKFLQ1bFK+7ZkdFMf51I2SREls+Vn51AOU7mvH9C5uVERW9GTjl38Cd+7DfRp5ZDbzI3hqlIB/6jwediupMe5jxyJpH0OBvwOunvI6mQBPO/vFscV3vtMMXYYg2yfKUHIdofeI25K4UEqItHthxEaNEVFKh451MI/X1YL0SKTJHnDjtHwOA+e8i8EsA3mrycDEVdmIL8S7Dgd7TgX2yPhZNB4igTI4PppE8tvMgcOojHXuPkBdY+SywWSnKe/k38jAzJumL1F4MLEjFo2edgEJZu3m/y4n1336BHgB8GgtWZk4ALct9Z1gM0HHxaYBF4cMbxYiEQ5j34lNIzc7FwMlT8cOzj8LrIOlDs6xzckpWNnqPmYB9f0gl0pZOiaj8DVI/fMqMiUQQDYfgrJciF94WpVZi65IKuJtI1NJpqsdBd+IUpz/iR0uwBV/s+QKX9rsUedY8+GPLk4VtY8qiWYoGS2nEiIqQ+hHwztYm9Ftdin9M6A7aYlH4NpVb3gds/xFfL2RGwydruJeaRj6rkPqx+uJL22s9JMJTlMdgT8seMByDtC5msYdYxa4WaI0noBjzwXoIkcvjm/K5m51ALsSIipyQmLVm5FvzccB5ALW+WgzBkLj3FnQaAgLRAMpd5SJRyTaTZ51FZ8GPM38ETdFtRk89Md29cy25qPaSzy0QlHSjROLP+uEs0YV8fd16jM0fC4DoAnc07VAcqznYiKfPPZ28qJJV+V27DM4XpXsjasiSBvSaLWSiaMuBy9oD8KxHOW8amR2hIU8+6TUafL6uAtMH5CItxnLhWIEaUTmMYFgG3+7/FqurV2N3y258vlvpk5JjPjzRFIBENuSI1MTPHlhKg7CePDRSMlsXfgZ2/on6558HIyMhQupHk2oXLdejDY1gZeFrVpgNOg6icpX0QzRMOq8Dn6aDoCjiDSFH8/747QSx3b5fkh+LiQCVG8QmgCJ+fQBY8xoQTOwcamql70h7ccXYYkzuK90TPqf0ADUx8bOcYV3tGJEXrzPScRE0NTZh4X9eQNWunYd8XpU7t6N86yZs+3UBPr3vVpGkAEBqjlS1RlEUzrztXpx9z0PiskRRlkOF0DU4HGNaeDxBZzSJjQkDHg+c9dLv1dWo7GlVX+YWdWU1KQcSRgYa/A046auTMP276fhk1ye4YwVxc27hB/NUk1I/FUdUNHyHXz6iIjd3AwCPzownFuyGNxSN81Jya9yATdLafcacgqBGGuAEoiK0fjAGlYJNlmHhYO0AAHtXCrN/mo3Lfr4MJYN/R1nadukcOLK/8LyJLCINWxmaQpjRkJYWYT+aNr4n7vP9zO/RPZVMWFqCiYWiwvIhWUPEZbcuv1XUb8if0xpa064Ur1DFJzQiVURUTISoyDsWy1ulOGUuuw/89gA+2PkBAEK6AIiEB75mYOkj5O+8IUBKHlhZo8bGSh+Qz9vzyyqAPHwX+4M8UekdCUADifS/vGQf/vX9Dtz+9dY2P+fRgkpUDiPknZFLnaX4au9XivWHK+0DJCAqtfECrZDBDlA0aBqiWC0Zys87Dy3vf4Cm16U8rxBRoe2p0GaTc482NICVeUEwwt8tpWBknWE1J7bfKvovoWgM0E3WfK+pFROlZM3AmAiJurx/KvElkKMkvgfLF9GTxb+NukOLqCSCPDxvYoP4h0ygB5D+H0+dGR+VM7AhLP3wHez+bQW+evS+Qz4PeaPBWNhz8+KWdR82CtNuvB0nzroQRQM7ZiTXHuh57ULIf/wSFYqixPRP0OuBs04iH866eKISjbCgU6NoslQlbL0xr2QegowkoN/ZTAhqvZssy40xIosjKlpCLISIirGfUjvn1ZNBckeVC7Ssb1NYC/iNAFLygJPuAyY/hH7DJiAsc55NEyIqNI0ur70KCkBWREr/rP+pDGGaP76eeC5F2Sg+LH8Hv/R9H98OfBEA4GF4AT8fUWErKqDlB2avjk9zrHoOjTvJc3ZiwUTkW/NF0WtsewKOIw1BBaIyvst4cV2pi9zzNp0NZl3Hig4AYGLvLCy/axLev2IkAKUNhRBJiW2LIuDHkh9x54o74Qg6ML90vrj8vN5ksje/dD4ZW366lfQ5AwBzBtzNAbibZbYUVV5ijDn5IWDG8+LyrJHno5LNQjmfru4X9eECk9J/CQBW7G2MW3asQE39HEbIrZH/vebfiMQ0jDpcEZVIXR0Cm5Q3WrQuPvUT4svpLOZ47xSOZVF1403QZKQj/8knxeVBmcmTGFGxyyIq9fUSOQFIWLZ6E+A8CFqbBzZKofCdtxX+KYcCv8sJv8uJzKJi5QqtAbhiPrDtK+D760jqh2Wk95WLa3VJoklfXw5U830x9i4kFv3iB5OlGXqdhpYuJ+OJRdKs09QZRMXtFP8+IV2Dy6f1jdsmHIzX2xjYMJrKDsQt/6tobIWoWBNoUCiKwgknHV6nUDkMfHVVbBuI4w1Gqw1+lxPuxga4ZBEVd2MjgPgy7rT+5PFc74v/bf8/e9cdHkd1fc/MbO/q1bIk9947BgM2tjG99x46BAghQAJJ+CVAQiCQEEggEDqmJvRqwOAO7r2r97a9Tfn98abuzqpZtiRb5/v8eXd2dnZ2NfPeeefee67NoD+Z1klExa0t64+GtQqdYNQqKpaR2q7DdTbyd/7+wF6cCD+kzKSadECQxpIT7wMAPCYIqM/wolVsMWa0K+dmSCfHmVD9Hr4uJsmd6z9TwhUHrC1AQiFUm4V83whrAMtYEK+rRdUdd4Ktr4fZMwgsY4LfORLp3lpg3b/RaCP3YpaZfJZURiy5sAJE7b7is8vB8iziArmO8h35uGjERVi6e6m835jM5K7NnUVJpqI8FTiUpFcp9JMKq2rIDyflw0i4ZeIt+OzgZ2gMN+Kg9yDG7/xYedGVh91r6jRFQI2VfmBeIXC80k2ZF3iMHpSJC21/RJnhGQAtKI6zOGlkEJSpCK+v7Xz+U29iQFHpQagVFYmkMCpzJSn2eaho+te/5MfGQjLAcQnGbJaxY2WXR5s5ObYfO3AAge++g/e99zUeLOHNm+Wmhr4viNTKuN0w5pOVdKyyUo4ZA8RfAfuWgecAniUDmHXixEP8hgrefOCXePmXt6K5KsUNlSl6KpSvAB4pBOrJCg1RlTuvnqJSsZaQEwlxVXl3LAj4RLmVYoAlj6NyyMUIQiE8h4WoqBQVb00lIm3J1QHxsA5R4aKIx5KrvroLb6N+iWpO6dBkz5sjALPYsiEW6lp/or4GKU/lg7/8AdV7lAXBxs+/hiDwcKSZNXb5RWPIxFsXSlZUzEyyv5AgCGgQO+yqrd05ltU0MQUASERFVFSqI0BYFb4pcxFF4I3K+7AByr1XnZEcBqEoCk5VqIkyK+cm9ZtCayMcadpzNsQDaLUkj01xQxRml9iV3ZaD4PLv4f/8cwCAVSynro6IFTRsUE4SzaTJ8T2MAxQv4M1db6Ip3IQ4F8eZ7y3Glqat2NGyUzZiy7Xn4tczf41sqzIuT8sligjPcfj82Sex+SttZWNnsbhkMc4cciZm5s3UOI6PSh+V8j3SeUmwGW3ynNESadEaYaYPQWMFmW/yh3kAAM3V2oKKOBfHuR+ei1u+uQWW7GKETWTcLo7H4eD8yE8wkqMoALWbga8eBFRmk30BA0SlB+GNJeczSBc+0HMeKuGNm+THUrM5qcOpBNuM6cAoEq+0GpNXouqBi/MpF6UQiaDimmvh//Zb2ULbkJUF89ChAE2Da25G9IDiCBuvrwd81eCi4qVkNGoaGR4KBEGQY/k7fvhWf6cMVZlkPETUHQDwq1ah+5cB2/+rfV/5CvK/6CIpuc4CUBxvrenAA02AZxBaQ1pDvUNNplUjFg7h06cfx5avP9Ns//jJPyXvq6OoWPkIuB4kKnGdz0jPL8Sim+/U2fvww3wUhH4AwGxXQiheVTKtwLdA4JpRPC4TNpdCFgryiALbHG6WvVSkCpEwm/w3ag5GERNDI2rHVDbBDBIAaJPY4VdUVD7ZWouPS0lCZ701DTzNAODBM02oSVfIiURUEtVidY6LmswyYnI17/PBZNBaCIzb9jxCnPI91A3/rBnkGL6ESpuCFjI5b9zVhE9DeViYV4R3XWS8yeIFxOvrMeXKp3DLx+R3uPGLa7G6ZhXKQ9o8n4lZEzE5m4yP90y/B26zG8PThuP00tMBAGWbN2D7d1/j638/061FgNPkxB+O+wOeP+V5TSjpuQXP6fZ3AoCVNSvlxy8vIkn8UtioNdKaQFRK4RV9aQaPI4qN2qcGALY3b8e+tn1YUb0C6a4AKDoGSqBQGGdh5bxJbRbsJgPwr+OBlU+RQoI+hAGi0oNI7IrpMXs03VFL3KnNuLoCg1gqnHbppWAyyIUcK9dWBqRdcAFi4g1ippJXomqDOLYxOTbZ9M9/yo89Z50F2mqVS1/9n30uvxavqABXXykTFYPHc0j+ImqoJ8yUTfEsCQ34wiJh8yckIL5zlfZ5zSby/1jRWr5xp9I7pFkMo2QQcgYgmagYuk9UDm5ajxVLXwEvhpd2r1mBnT98i5YarZ9B7d7kjtsxUcK3uT0onkhsuW1cCEJcOT+9iakriEWTzQPHz1+MzEH6tv+HG1ITzEg/Jyo8m3z+Riu5fgUhgjnnD5VLdAEgLysTRtoIAQIaw414c9ebmPnGTKytXavbdfnjbSRPK9dlkTsUAwCrM9HSJq2iUtkSwsujFmPNKZfigdnXAQAoAxnPalXRvu2Dyb2dmP9hHjoUub/7HQqf0fr4SAm1EATwBxQVaej+95Hm3ScnlT583MP48rwv5cpIJo0QjT3DL0LISkgQZTQiS3SqjkTi2Fk+FCduUAhSZtiP4MpVoCMxHL9dQFabgN3eA3js+/uSvv+loy8FI4aJFxYvxIqLVuC9M96T+6+pe1u9ft+dmu70qVC7bzd+/Oh9+b7Wg8fiwcUjL9Z9TSJ/Qz1DMTlnMsCxSIuRe7G5ebdmrBOcefCKPdyKRhOiEgnG5fBo5c4WHPgiAEog18EuliQcW+I2GAGYY17YEyrD1N47BzYt7/D7HkkMEJUexLt7tAmZQzxDcN04ctNfO/baHpvAJQXEPme23PVUQsZ112L4j+tgGjwYcYYQFROSByo+qAx0iSQHAKI7yKBiGTNGlm9tU8gKJF6jTdwN76sCGyEXOdOD5anqUIhPh0zJuFiJMctEJZAc10e4jfwvCIryMvpMQEpy+9skoHKdoqhIVt0AWoNastdOw+kO8f4jv8Xa/76NvWtJiK25qvOGXjEx9FMyaSrcouFaUbgSlMob4ucvfY9GfwTlWzbJxKYjlG3ZiNfuuwP1B/eD1XE5Nlm73y7gUCEpKg1lPhzYRK6Dxgo/PvnHZngb+74HhAS1XwpAyKZRzKcYMd0DQ0I40WQyynltdcE6PLz2YUS5KH676re6ROUfP5Br+qo5xZrtesSVMWsVlXpfBBzNwHTp5ah0ks+cWCqGhVThnj1i6kVjKPl+TLvoQjhPOlGzjTIYQIv9n4wRJTRuibTAMf9khOJiew+zBy6TS84vEVzKmFWfTQh5HByEhMk1r0VRjrJ++g+Eg6vk55P2E3JRxmqTYK7In4f5RfOTzl8NX5MS/myuqsD+9etQtWMb/C2pzdre+PUv8P1rL2Ld/95NuQ9AiirunX4vfjHlF7qvl/vKyRi08q/I2LsMAPDU7tdRKVUKOfPgt44Dx/KgDRTS8+1yscSBjQ2IRVh8+NQm1K6MoaSF9AGriZJk64xmUgFER9s0/k0AYFApYfUBDg3+5HGgtzBAVHoILZEWrK3VmoNNzJqIEwediC/P/RK3T769Rz4nuG4dYgfIyp9xOkEnTCDpV18NRgy9xClyExuF5AtO7YsQK0smKpJ1vjrjP+vOOwGdzryxuhawEVFRyWw/cUzCxi8+xj9vvAK7V69IuY86uTTY1k5PihGLgXn3k8epFBVA6bbcvJ/koDAmYNBM0pFZwhsXAGWiL4iqk3FNm1ZW5RJC/t1BUCz7bU1QUobNIBJ8RmGywZQU+jFZrbC5PQCA0lCZZp/SZU/ilevOw7t//A1WvfN6x+fR1or3/vgA6g/sw2v3/jxpQgUAo0Unz+cIwaxqVfDDW0Q1ePdPP6FsazNee2ANIsH+kbvCJ/TRmXr6OQCIQsAYpO+gXcxIlYKSayoARNiILlFpjTYiz23BVbOLNdv1iYrYhDBMPrdeym1xWvDOjbNw9qQCnD2N/M0rsyn87XQav72UAceQ80v0+mgPkhJrU+XaOBw08h/9k6yoSGEfqWKHH67c761pYqIvx6MWLaAS1A1a/FmzWA7cbiV8MqgpWQV5tq4Bv8ycmdSNXo2yzRuw6m3tffPBY/+Ht35/L5Y+eE+H5H/78uRmiYm4dNSluGrsVfLzOQVz5FDU2YMXAX+fDHzzB5hV3/UVWvybX/wmvI1kTHdnWkHTlNxw9ptXduHz5xSLAndEsQuIeyeC9k4EAFChliSvHUqlkrGg8VNZq9hC5GFga/vk63BjgKj0ENS18BKOKzgOFEUhz5EHmjr0nzqwYiUqrrhSLg+mXS6NUROTkQFDhkIUYuIgaOKS49lqRSW8cWPS6xLURMWQng7rBKX81FhAZFLWG0E8YBC3ddyELh6J4JsX/4lgawu+fuEZrHrnjaRBHABCqtyZYFtrckKgGlL8dv1LwNJLgRV/Td7nxYXA3q+AMrHEb9AM4uCoKuVDuBU4KMqeOePkzQebtBMD303Lfo5VJlWDmHTYXK1VVMbPXwxAPx9FUlRMFhts7o7Vq/WffNDhPm/85m7d7RZVToXJ0vuKCkC6cwMAzym//7t/+kn2k2itC+Lfv/geG744gg0rO4kTr7pe83ziwiXgOPJ9KFoiE9rrSipzrfQr1whDMwiy5Hq8e+rdsjpAGb2YWuTGZ3/9I/77p9/L4Qq90E+GmxCDz7fXoTkQVcqa3RZMK07HXy+ciBiU+2/FWBo7ixQStb4+ubw1FcxDiL2+068k5eacOg+Mwy4rKlIeh5ST0YwGXHA/ye8L2shvQPNAi5mFMaE3jSdA8noyOQ5cUCFlgxq1v+XVbT4cF44AgeTkZDW+ePZJ+XHxhMma13yNDdjxw3ftvl9det5ZnF56Ol5c+CKePPFJ3JYzR95uUymle6QQvtklN5p1i41mp5+mpBVU7lD5HoWV0BgXKkWrIN7T4RY4E9osuGKKiuSgItha7SXO3Mv/BLx3LRBtv0/V4cQAUekhqFc4Dx/3MB474TFMzZ3ao5/R+Pe/aZ4zLpfcTwNQGgdKiApkJWpkk1dfge++Ux4vTx2PZBxasycp/AMAFjtRL9gwg7iJDEbGQYPQESIhZcUe8fuw+t03sP37ZajYtgV71iorIrX3g8DzCPn0zdcAaBPNdn2sKCtjzgFyxiqvvXMVUCeuOPJE0jXyVOCBZhIGUiM3NVFh2yNN7SCsIl+MwYh4LApvgzZM5Ugjg3U8mjzBRMVWCWabTe4fI8F92nVJ+0uqS3vwNeq7GnvylOTv3lRUHGlmmO2ErJisBgS92t/F2xDGxq/IJLjuo4OIBlms/u/+pOP0NgpHjsHNL7yJweMnYcH1t0HgGLAx8r1omnynOeeR5PCJC4iaJoV+pE68AEmklcYbu9EuV4Y4Mj/C4I8ewoH163Bgw48ItBLHW0lRUZv1ZZs4DM6wIRLnsbW6FX7Ps7AWvowsp5LM61NXzomQSmjX16/vVN4GAJiHEAM2j5fkfhljfuRedSEAxfhMKrfOs5PFT12wDs4Mcs3FTS5wtBE0gPo0wJiwqMnwmeA0kNwLzqtMpkVN0NgU5Ek5Qn796x0A4tGIbHBIUTSmn5lsXFm1Q1GT4k1hcMG4ZgHSlcq4pUuW4lfTfoXFJYvB0AxOLjoZnrAyzl3oC2CSGIrdbjIizJvhD1vx/VKiLEpKSuHIdMw4szTp+CMbZyIjSOJ184ono0UqNOdiGJtB4YwJ+bhiaC5yWAq5gkJUsqg21PsiYMt2Yv+nWahakQZhZ/cqoHoCA0SlhyDdcEM9Q3H6kNOxqHhRjx4/sHw5Ipu3aLbRTpfcT6M5bRS+KbkD+9YrF1tMbHBmjGmZcHjbdvg+1VaYpAJt106GnnOVjshWG7nh4yEGMY4w948agLd/rEQoljrxMaZTYttaW4N3/u9+fPTEI/LEHU4gJmrn1iRYU6gLpScAk69QfXgA+OkF8jitWNnOGDRum8gYCoixepbjUdGilXu7G/pRJ+mxsShaa6oBQYDF7sDxl12Dc3/9fzCK+QN6uSLSb2D3pCEtT/FqqLAUYj2rlFnyYvlpyOfVDKKdBc0Y4MpSQmK9qahQFIWTryBlndFgHG31ydL7uo8Oork6gPqDyu/rb+k7MXYJVocT5/36/zD+5IWoPeAFKPJ34lhyroNGpeOavxyH2ecQ4i+REHVY2R/zyyZwdqMdefY8GOMURpU7NblKIdGyYOMXxH/DaFLKg4NtrRiUZgPA4dql78Lg2AODcydsZuX9icUBAHDZqMtgoAyoD9WjJqjfBTjOx/HspmfxY92PAADPRWJn73gQs1f/GvN/UQKD00l6HiUoKpKC9NGBj9AqNMFgIVNURFRa6tIoGBNuvgyvCXaTAzE/A+8+ZUpzhAG3dKkIAF+Rg9WNRUA4dRhZqsYy2+24880PUDhKWeScfM1NAIDq3TsAAJwvivq//IS6v/yIiMpbqiu5iGMyx+Cy0ZcRxZ2Lk07t/7tJft02YgleQh5yglmYUHkG/tPwCl75wy759bRcZXGaXaRfbXn+lnvgDmfj8bMW4qM7F8ihbpP3IP58+lhkr6vC5T4OpVCUphy0otXnR+PrHyLmM8JfZUXjC0t1j38kMEBUegjSCqc7roadgU9VaSOBtttAiTkqmyfcCpa24KsXiY9IxY5mRFki7ZkSVkahtdpcGgCgTCZkXH89su74ecJnaBUVU3ExCp58Ej+efxMiYp8alnUgXksI0huVHO55bwuue/mnlN8lrjMBc6o4utT7JNFNU1oh6iIVURl9FjDlKtKgK6H1OdISqrBcKtfVE34lP2wLx8Hy2tUjz3cv9KNWheLRqBz2SS8swrTTz0Hx+EmyesHGY0kVBFJei82ThpySITjjrvthufBefJB3Oj7brRC5VtgQpwyAIKC6ogaCIGBHjQ8vrDiIi59bA6+Ym5AqnGa0mGVlhzzvPUUFUPJUGsr9Gmm7eHwmikang+cELP2/dRpysv6zsiN9mp2GIAhY/2kZKDGPbPOXn2DF0lcBAFaH0plbytlIxO4WUhFmN9qxqGQRFq/JxZTd2nsg0NoCf0sTdq8ioU7aYEDpFNKgbsKCU+Fy+OEY9kfYihVfpiivLCKkyp6LRlwEhmKQZc3CWUPPwuhM4mKbKvyzqnoVntn8DK754hrsbd0LxmFH+tVXAwAs0TY8VEGcZyv8FRDEUFeiogIAj/30GAwucn2GLSTXok7nNi+uM2BMnRlNe5J9qq7/jEeaX4A7aEBdXT5WNQ1G3J+aqLTWEfLlyckDRVGgaBqXPfIkzvv1H+TcsUBrC3ieQ6yakBMhzGHp3SqTNY7r1uIAPzxBOrVLGHkacNHroG2ZOK5yASbXLIAAbS+eoVOU79xem5SJNSfCaTLDWB/F0pqHUB8fCjTtRWNlM6Le5xH1/gf5UeVaM1Ic7K17ENmvVFvaLtQPER8JDBCVHkKihHmo4Hw+1PzqXtl8LahHLigqqerHYGIQCcbx0d82AwDswRoYE/rVqEtZJZiGDEH2XXci4zpt+IBOCC80+CJ4yViKB+NDYLSQ48SDFNgWcvM3WcnFvmp/s6b1uRp6Ph1NKkM36SZPDPU0V7bjopg5LHnbLT8CVg9xsZ17F3D5/7SvqxUVADCpvmvJCfLDVlX/lKtmF8NqZHDziUPQHahVouWvvoCNn34IAMguVmRbNSlgE8I/QZG8OTxkxB42YzYyByUn3XIUg5BY9fXo/37Cm+sqcerffsD/fbwDqw8049OtJI7+8VN/1j9RARgxay7saenIKirWqCu9AXWeyvrPSf7JmOMLsOTm8ZiyWL9sunxbO8S2lxEJxlG73ysTFQBY+1+l5QbL8WjwR+A26RMVaYJ3m93IsWYj3Z/cTO5/f35IUy0Xj0Zx+p334aonnsWQKdMhWA6CMmjVKbU/i+QLNTlnMt5Y8gZePfVVWAwWTMkhVThqorK9aTtu++Y2/OK7X2i6/57z4TmYu3QudoWV7ubrGzbgnu/vkfsTAYBFNGVUW89/Vf4VKkHet2X8zWhOH42adApBs9aKPruNxg3/OIhQizIWmj1kDJm2V8D5K3j8ab/yPfVCyIIg4JuX/oUPH38YAOBOVxhRTulQDB4/UQm1CgKioRCEuELy3YK2x5Weatwh1vxD+9zqIf9b3EiLupJ2nzB/EEwW5b6QQmV6GBuYhdX/249PntmC5lA2Pmn9NdC0B1WyE3kM/pYsHNg+BMF6ci0t8L0r+3MNuu8yOI6bk+Lohx8DRKWHIEmYatOiQ0HDXx6H94MPUHHNteCjUbC1ZGJxnEysyrPvJSt+2mZFzKhMsGm5NgTblMktv2Yl+EhYs3KOVyd7kliGk4meMhg05IROyFG58+1NeOKrPQAEeKziiiIaAzgOPCh4Tcr+UkVBIqQkUWemkuilbqb35b/+Bp7n5AElvYDkvdQfbCfvwOICrlum3ZZIXgomA+e/TJxq8yYkExVVObIU9gGAVtERNc1mxO/OGINNv12AwRnd+zurQz8A8V4AgClLzpK3GYwm0SYSiKnUJ46NIyK+3+ZRBtI0W/IkxVEMYjTZvq2sAX/9WtsLyWygIQgC9qzRr7qKhoLIHz4SNzz7Mq547GkYjO33ijrcUFf+SHBlkoE5a3DyIA4AIW8MQjeVr8ONSIBcUxbXEE3/JKni6pHPdmH6H5fhyn/v0LxvqGeo5nm6JR3NKZyEAeDAemWBEwuHYDAakSHeTwYjuQ/5uPL7SUQlzsXl0I3b5MbojNGyLfzUHJJ7903FN2gONyMUD+GiTy7Cd5Xf4cvyL/HCthc059AWbcOf7aSSrlH8qM8OfoY9rco1KRUb5NnzNF2GKw1Ka4gdo65EINOG0oYWQBBQ0tBGzlX0CoiLfW+GLKlHxgQlUfTU+gj2Nym/ccifnBTaXFWBjZ8pfXbGBT/RulWD5JRJYdloIAA+oIxv6aZczb6xcAhrPzyAtx/+EWF/J3yNtr2X3ABVqkyyemCPO5Le4kzTEhPGkHo650IUNn6pLPTCvAd7d3FoLCuTt7GhZdjdNAwV3xLSdSazSvbHMo6d2fF3OIwYICo9BDnWqqOoCIKA8JYtmh45HSG8RclH4ZrJypAyGpH/p0dR/M7byLjqKrLNYoHPVSzvS9OU7OBJUUBh9XJwjU3YP38BDl5wIdiWFsSqtOWwAGAervT6UCfEMgmKysp95FzsiMBk5EGbFALkNdtFR0uCtpREhQwAabl5OP8BsoLhOSWnxdtQj10rv0dYVA9KRGOz2r272k/gK5wKnPM8eTz1Wnmy12DMWcD9NcD1y0leiholxwNnPA3c8L1ms9SRVmqBbj4Eszd1JZMEq9MFT0JXYmlAjKvM14JSk0jGAKtDiUenq1qzb3OSXI61adMRo8jkbqXiGJat/TtGWb79UJrqXPoC1IoKANAGSja6MpoYjT37efdOBUWR8FyoM5NEF9DZBNKOEBYnOZsrDdc+9TwsTjKDS/lZL6wgSoLAaceTCVnapo8ZlgxUlyffzxJq9ylkIHGV77CS34YLjEKe6JodZsPY3bIbk19TkuYTw08z8magwFGAtmgbPi/7HBsaNnTwbYGKbAp3XcfgnmuS7527pyohBYqi8Oz8Z+XnLTZVBY3NgdL0oRAMPizcehCDWsi9FDcw8JuNEADYJ4+B6bS74b7/PygQiw+qGDcao8r1Hwomqx3q+6zU0Yxi7CZ28gkwi+NhJOAHF1CureHuqSh1jJefR0Mh/PRpGRor/Hjxlys0JEEXn+h4qkhEyeIGxZH7/f2xT8gvMypzwBgXw9WfX93+ZyTgy63Ho/6gtq/XjuwgOPGeFziAZ2kIAAwlk7p07J7GAFHpIUjlgno5KoFvvkHZBRei4hp962Q9cCp5UuqUzGRmgnE4YB2nVKPQVhu8LiVsEIuwsjthRgYNCgJi5eWI19QgsmULWt98E5x4PMl+nzxWwge5DzxAHCVpGpaxY1HdFsZJj3+HV9coJZ8uEGLGqJLv/EYb5gzNwOAM8ht4UxAVKfRjtFiRP2IUrK5kebulugq+ZiJbD506E0azBb7GBjmRLSXGnQ/c8AOwOEVIAyCNC/UmYIoCJl+uVAOJaBNdafWUi64iUVEByKCWCKNYuqzO52mqLAMAuLKyNJUFaSqi8m3GCXhx0BWotBaCEzvkMmwMLKedYINRFi3VyROcFIfvazCYlO97wf3TcOPf5iGzUJl8HKrVZVquDVbRir6pMtBjSbUN5T689sBqfP2fHYes1MiKioOQyTRRVfE21IFVJYsmEpVsk6ISWhgLrAYranUWHhKqdipKpTrkKggCmlii5l01czTs4rgVZsN4etPTmmO4TFrFysyYMW/QPADE+E2vu7PuuWRRCFq1992svFm4coy207q6J1qbVanQsXqsKHQU4kAuBUYQ5E7KLEPjh5FFqE5zIP2WO4ET7wcGz4YhPR37sz34KaNAc/xQKKZtOgptiHVBrthz58VTgLX/0uxnFYsL/K3NCNZqif60rMWwMERpbatr07y26v12moYKAhDXCRWJCdaC2YMoT4jKuRYBm/O+BesMYehk5XdaXrUchf9dhzHbiZq1ovhd7MvomEAGE3rEAUB1ZiEEAWBj5J4rLzoFf/vjFrz3zs6kfY8Uep2o/O53vyNJS6p/I0cmd43t60jMXlej7d33AACRLVvQ8PgTaPrXcx0ej/eqqkNEYiFZ56thzMuF36koILEwJzdvk5q5qRHduw9cGyFB+U88Lk/YlvHKasA2eRKGfr8cw9etg3nIELz9YyUONAbxwP9URkIUIWZxv/IZRYEG5Lgs8IgNyrwpmsgpXiBWGIxGzL3kyqR9Nn7+IaKi14s7JxdDps4AAFRsTV7laEBRQN74ZLXkECCFfjw64YeuIrGSCQCmLDkzaZtUZSOVKLPxOP776O8BQFOJAADpagJFUQgayGBpFI9h5GPYWaslSMEoh71iKXjJpKk44+5f47JHnsSCn92KIVNn4oy7f92dr3fYQFEUTr1pHE68fCSyipygaO2EN2RyFigKmH56CUwWA+xi9+CPn96MV+5fhaYqP+LR1NbmHaGpyo/3/7IBvqYIdq+tkx1yuwuJqFgloiJWcDWWH0R5SwgQBNAU8NBpWh+Pxz9VEkHZuA1Pf7MPLSqn6FZHDB/NrsXpd94LIHWy9BdlX2BN7RoAQJ4zA1YDuVbC8TDW1Sql0PMK56HAWZD0/jSx4ekL215IcuTuCKcMPkV+LDnRqqEO/TTZq1DhIYuTWIhFrj0X78ylsWMQkqp/dg8vhmOOkkfBuN3Ynaf4SpVOIspsmDUkhVnidYS05Vj8cBhV49Zn92hKnC2ikvnhX/6Iyp+Sx6IT8kgH9qaq5ArFlOTWWymTEg3SSbJ/hEqDAKJEDbEasbr4f/hy1rPY4P0RvFjlxQs8LlnOI6dxA8Zu/CW25f2AA5NW4IL7p8ml7vJ5CAIysBKCwIHX8dgKWdPRtt+Glt2ElDWlDYcpyKO+F6voep2oAMCYMWNQW1sr/1uxIrVbaV+F7Gtg0MldUA0Wzc8/j8a//lXTsTgRgiBoDNnidWTFojZzk2DMy0PAoZisqRUVky15so7t3wdOtKY3eDwYuuxrlH70IYzZ2ox52myWPVRMOrFPF8j50XnKzbcjfTByXBa4RKJS2arv4CipBFLSaE7J0KR91DK1PS0NmUXFAIC2Ov2SyMMJqc9Pek8oKglEZcH1t2H2BZcl7WcS/XGiYs5C2WZldZRoQuW26hMoo0VsocDH4Y+QwbdA7JgaCgax9RvSHXvyotMxbNos5JQOhdXpwlm//A2GTZvV5e92uFEyIQuj5+g39pw4vwg/e+oETFtCBvfExMK3/vAjvn1Vf0XIxfUnczW2Lq/W7NckVXwIQpeccdtCMeytaMPuteSelhSVgpGkkmbNe0vx4f0349qKlzEpy4B8j3bhw8eUSTwateLxr/bA10iOtW5UCz44vhbNnhgos841oVIRX9z2ovzYbXYrRIUNy/ki75/xPv5+8t91zSrTLEqO1PZmUmk4K097zdw3/T6cPfRsXDP2Gnmb3WhHvkP5G+pVNdEUjctHX07y/Shg7ej3AZAEZNfmofDaKTx2LgNaEECpJv9oJIx9P65Rvq6qOeriW+5CeiFRkEOcEQhpK3/iBwhpN1A614LK5VpqLGllnMi1kmuN5ZWx3GMk46i/mdznJtVicd+GFLlELWKisbrBauYI4DiSbBzmyW9kpvwYZiNzQIW/Ajd8fQM+PUi8TWKccg4cTcbdxcMWIavIiezB2rLl7MIKVLeuRSzwLiAkE5XKguNQvqUILbscEEAh4CTfM6tUPxfsSKBPEBWDwYDc3Fz5X6aOctDXIVf96CgqeqsaTiehS34tQY6LHSAXsiEr+XcJ+WKIqaTZaIjF9h/IZK6XgBjdu08mSbTbA2N+PszDdCpmVEgM4eS4zHj5IlL1Eppihf3Sy7B18kn4y+SLkeM0w2Uhn/v7j3bIE6Tm+8g28OS3aq/Z3YLrbwNNM0gTzcd2rvgOr/zq9m51NO0uWhNyVA4FiaGf8Scv1E1UlXJWpEaFQVU+yfAZ2ux7mqZgEhMKc1xKroaXI9vG+7bh2oqXkRupQ66bTOCR2jLwHAdHRqbc3LC/w6jqBqtHaPb+lDxRbFtehX/d/h22fNt+v6WDm4mqWTDCAwAIiImbK9/dhxfu/gF1B9sxI1ThyhfX4f1H1qNmbxsAwOIg19Sg0Uo4l/e1wMaHMRvlyHaZEalfAoBYoAuscq/z4uOQaNjnsyl5XjFD8phz/m/+KD8ucimrbKfRKROVhlADAnFCwgqdqV2m9ZSQkelaJXx2/mw8NOchLC5ZLG/LsmYhy6ok0XssyccBgHum3YM1l6zBv0/5N94853Wk55NFU/RHB3L8xQhaSLMBY8LYqm4ZEZNUN0HA0DETYBXzgHw+C+J/PUETbmFriTeJkdZR3RqUcLOkqGRZCsHQBnhjjfhv+VOa3SlQiATJfDB8upKU/+W/t8tKmgZb3ib/pxUDN64Ezn0BuHWd3IQwJP6dbbQXRbYcGGllvPj0ACEq6iaRUSOw4bINuHHCjQCAoZOzMWXRYGQPDuPyP0yC3U2uZYGthsCLhIpiwfDk/hHA4kDJGQCIKzBvsCAGASVDPcnnfoTQJ4jK3r17kZ+fj9LSUlx66aWoqEideBSNRuHz+TT/+gIkF0enKdl0R9Cpqee8qc87nhBzlkqUjTplqH5xwDRZlD9lS43oXuo0J+0vw2gEbe9cKXVbQufgPLcV1jBZxdVYMxG74ed4a+6lqHVkIttlwZ56hYT9Z2UZar1a1i7nqIgJozTDYOFNdwAUhcmnnqkphbWKqyK1uVlj2QHUqZIEDzd6MvTTrruuCumFJJzXIvqsSA0ax5+8SNf58ss7j8fcYZl45JxxKM0kg3pBlgcAYOPDsPFhnNL4NXJd5DcX6gn5LRg+qvtfpg+jaEy67nZ1MmwswmL5m3sgCMAPb+1FOKCvckaCcYR95LWhU8i16W8l993mZZWAAKx4e2+H5+SLxFFb5oNDUJQNdehn2Ky5mv1L08zIcpoRb5mLwP5fIFJ7LmwmIzzhsxD3TkS07mxAEBBvJQSsNFcJIUYMWsPFaWeci6KxSng306oseiiKkonKfi+prEu3pMvb9KBWVCSMSB+heZ5hJav/kekj8ejcR7Fg8ALcP+N+2d0W0Cc8aszIm4E8Rx4u+PU0FI0hxxtdNwegKHy7OA9UAslvq6vF7tU/oLWuBmHRxdnE8oDfD7ON3BfeZisqv7Eg+vkz5Hrg4oi3kjHXSOsoKg2KEmcTGyxaGTIutcUaAIZCmV8Ji1sYByIBQlQsdu35SdeNjKr1wKbXyGN3IZA7FhindcP1xclv7WQaYIiHNaEx6bE32iZvEyxmGBnlcymaQsl4ARWbnsXzt1yOXSuTnciPL3LCERdVSCGOMBPHD8MLsWmQCzzXgkaGR0H64fEI6wx6najMmDEDL730Ej7//HM8++yzOHjwIObOnQt/CsXhkUcegdvtlv8N6oRl+5FAU5iwVPVqQQLXnGwwxLczYcUSiJrUhNA+JznRkRW9Smzu5Bp6syM1UWHc7k5VdAiCgC1V2nPNdVkAL7mxa4RMtARjcq+QHJcZJ49SiMYTX+3BnEe/0bxfSh41qyz/x86bj5+/8h5OvPJnKBw1Rt5udZCBIS2vAHaVAVlDQrb64URPhX4EntcoKury7ERIJdlbv/kSP7zxktyg0ebW99UozrTj1Wtn4KSROXj3ptm4b/FIzBunJbY2LowckaiwtYSo5I8Y3e3v05eR6tqOhZUJPDHJtnKnvhGYr4kQa5vLhPQ8Mtm11oaw5gOlXD4aYiHwAgRekO/JRGyt8iKf0w657iyFDEy59AbNa+ZIGzLs5B4WYlmAYMRlMwdjWvq5iNRcBIF1wcqHYeDjEACcNuEKFIsVgBFauzgyWbWTTESVE3FcwXEyKakNkDCHejLUg17IZkbeDM1zh8o2YUnpEjwx7wnMyp+F4WlKs8/O+k4xDI2JJ5N7It83DJ5QDn5aWAzGpShMJqsN8WgEHz/5J7zxm7vRWE6ucRPLIlZZJfugxBka0VYTDvzy3/C/8x+gbgtYsX+QwZ0DnK5tVaImKi6xY7lVDPEXTB4Hq8uNtU2fwB9vFb+3W25cmEhU1NYRAICt7yiPJd8UFdoaQvj2v+S4TqYBGDwH149XekZJ7sGhFiVnqjQ7OcezpTbZkkINw37AHBMXlEIcLcYy+K1m+CwsYr6X0EaHke3sPdPHXicqixcvxvnnn4/x48dj4cKF+PTTT9HW1oa3335bd//77rsPXq9X/ldZ2b5k2xMIrl2HeENqrwJAISrSKkINtjm5DLS90E9sf/IkTJlMsIxSVr8ttUG8/5f1OLhFTLQ10TBZtKV/FpdVv7oFhKh0Bu9vqMauOu255roVolIrZKDRH0WD2H0122nBrSdpc04Sc8gkrwhzgj2/wUSIgDNDWe1Jg4vBaMTVT/xT7DgLVO7QthPoScRYHpzqpCWi4jlEohINheQw4CV/fBxX/PnvKffNHqxUcq374F0EReMlq8vT4eek20244YQhcLm0v69RYOFc9i8c17wSDvHvVzDi6FRUAGDyouSQYtCrqCYhr1ZB+eqFHdixMjkHyttABnB3tlXOfQn5Ylj/mVIF11YfwjM3f4tnbv4W/7ptORrKkxXTnbU+pCcQlZiVxrZqshDwczS+y1BUFX9DLUwGWs5Bcse9mDfYprk23XHyOX7GgTGFGXCZycQdpLUTojmhD5jkl3L31LthYkyy4VpzRLQf6MAPSk8JUas0QGqyqDZ1i/Odz++RfntnLA0Xbb4f1rJcLLzp53CkZ2D8/EWYdd7F8r4Rvw9f//sZAICZ5RD66Sd5vGEZ5W/Q9MTDwPMnIS6GPYwjFhCbAjVUoR9XJiEqFoYcK2PYYHBip/kgS/6OdoNHRVS0eYKB1ihJzv34LuDLB4CK1cqL029I6nStLmt2TD4FGHM2zh9+Ps4YQkIzDSEyL7EqLx2LkJybGNGpNpRACQzo7dtgixCizkbWIp6gLA1uegX8EQy3J6LXiUoiPB4Phg8fjn379Mu5zGYzXC6X5t/hRHD1alRceSX2nTw/5T4cz6E1SiaSxJsV0HYqlt/TTugnsJIkE1vGKMoC7XZpJP+v/7MDtfu82Pw1IWqMgYbRor1A3dlW2WIf0JKTzhKVv3y5O2lbjssC+AhDrxEysGxnPWJiBn6W0wyH2YBbEpxbY6qOp1HRT8aSQFQkONIUsif5SwBksB0x8zgAwP6f1qKhrOdVlUicw7zHvsWF/1ot2+S3SYZv9kML/UhhH5PViryhI1J+fwBIT+hC3SquiFIpKnqwOJLDkHztfkzybYGZj4FnjMgaXKLzzqMD008vQd4Q7e8VUjU0DInhHCmhFQC+fXUXyrY2ad7jbRQ71WZZNQ0S28PaD8m12VITxO61dfA2hlHXFEIGr528z311Hc5+ZiXqvBG0BGPY6hqL98QQjtRK4mdzS5DDRHB5zVtY/9gvcfu8UuS4zFg8NheeOLmmItY0DEq3yk62ASEMSpUEm0pRkZQU6f8WcbLqSOnItmXjl1N/iSWlJH/mwhEXdvibSKAoClePvRrFrmIsLF7Y6fepvXIAwFmXi5KJUzH9nIcwcu4lmLLkLFz420ex4PpbASjeKGaWQ3jTJpjE8TPOKAs6iib3eFwQjc0sZsCZBw0adsnlzLKiIhIVxmUCJzY7DMbbABCiIn/2it9rDhVojQCtZaTf2Kq/AU1iCPuGH1BV1Yy/X3UB1n1AqqjaGkLYsUIhzuZBIwGaQVtdDU6OEgsFiagw1Yqiwuv4dUUSttlVYywjkN/DIOf7cBB4rbpo4CNYqcr/OdLoc0QlEAhg//79yMvL63jnI4DgSrGbbzw182+NtoIXeNAULZftSRA4DoJOb5uau+/WrfxhW1oQ2UK6c+Y+pFzkjGpSa6kJorFCq3IYTDQYg3YQzMh3gFbZsTOqJGXnySel/D5qDM1OnEwFXLbnNqCKuFbWCul4fyOZRD02IyxGctHne7Tx7aaAMkFIikqqidqkIleJk23u0OHIF/MqXv3V7Umdhw8VP5W1osYbwU/lrRj/+y/xwoqDco7OoYZ+pLCPnm9MIiiKwsSFp8nPpRYDtk68V0Lx+ElIz0+dEMlmDALNMClf7+9gGBpLbiEW+zY3+dupwz1SF+ai0ek45VplUfDJP7bA16yykpcUlSwbKIrSLftPRMX2Fnz/1h68+dBafP2fHXjtgdWw/dCEjARFJcQLiHMCdtb50CpWD4UZsQJHzL+79aRhePeSYaB4DrFwCLF9G7H2/vl49rIp+Pl0EqKZOXE4KIqSFRVfzAeTTWUpb9MqJJKiIhEUKdQjE5VO9Cy7YswVeHTuo/ji3C9wz7R7AGjVkvZw15S78NHZH6XsZaQHg0l7rQ6qG4sDGxux6r39+PDJTeBZAYWjx2LsiQtgUDVgtMRYsA0NoJvJYjKuUlQk0YflafEzLIAxIcTBhuXKHElRsRrIuMQ4TcgpJYuysFgJ6TC4wYphFEuLth9SsC0KBFTqvGhrAWcuVrz1CniOxQ9vvAQA2LlSqTZyplswajaZE1+84wZsfO4V5Dda0BhuxI7mHbBUKsfUJyra+SI9X8n5Y0CIOpOifNrnOAVxRyaGTpmh+/qRQK8TlbvvvhvLly9HWVkZVq1ahbPPPhsMw+Diiy/u+M1HAp3I42gOk5VPmjkNDK29mXgdMy8J4W3bNM+FeBx7Z88BBAHm0aNgHqoKoajcUD96elPSsRgDk2ShbHObNMlmgqrkN/2aa9AZ2E3KoHzx9CK8fEY6nDUr5W1eszIw5ahimFJsXYKUwwKoQj8OfaKizkXRq4hRt6v//rUXk14/FKirlAJRFv/38Q45dHWooR+pNNnq7JwKeMLlKoNAMQnUkd75ijij2YJLH/krRiQkaUrw6vhjHG0w24yYeeYQjJhOrpm6/Uq+lRT6sbnNGDZN28vo1V+vlsv8vY1K6AcAJpxMcn/UnWvVsIukaOu32qR4c4iHU0yk/cAWxYtO5Z4obwrKIcYZo8nxo6Gg3Pcqqrp3m6uUcLchSu6lnFxy/lJeSJANalSUVKEfiagkOt52pWdZviMfJoZ858eOfwx2ox2/mfGbTr//UPDFv7fLj8u3k3GYphlkFCq5i7YYC7a5GZTY94hjaEjagdRsWg79mFPk9dWRULPRYsGcMy+D00gWpIYcOxbfchfGnXQKJp1LwjF2oxtsnJBgC62QhnjoB+xZ9SziTWXaY1MMBGsGjGblN3/rDx9rcqYu/8OsJKI2qZzsf+HHF8JaoaiAegp+MlFRFjAWUf1hdKpTP827DG9lDEHTKbejcPTYpNePFHqdqFRVVeHiiy/GiBEjcMEFFyAjIwNr1qxBVlbqRMMji46JyrYmQjj0yvnaIyqxg2Wa59G9SuWAc96JoHVumpp9bQi0JMcKDUZaY6k84aRBxEDPohyDVyk7nbVGl0qTHz57HB45ZxxOMO3SvH7CRCXHIVtVGpueUMr7xtoKROIc8Z0Qbxq1ovLBpmp8t5usCorGjMeEBYtx4lXXQw/OdEW29Dc36e7TXTQF9Ss/HGaDrp9MVyCFfjqrihiMRuQN1VZSOLtYum+yWHHaHb/SqFQSymypy8KPNuQN8wAAavaRvwHH8nLTQoeHXLeTF2p/j6pdLRAEAW0SURETX8edUIBLfz8TF/92Bs6/b2rSZw1ROYbqwWg3YI+JRzOjrGDLW0JyGbzHo4R5pWsmFlImn2CbYiYmq5Oi8igRjwgbgVlFVJJCP5w29DMmY4ymyqe7XeAnZk/EqotX4cKRnQ8FdQVkXFOeq03U1HlBabkKUTEIJvA+H4QKhTjGxYUfGyH/x3NJib6sxJz+N6BoFjBarKSq2yq/d/w40m/NkGMDYzfClZmNU264HWnDyWfaDR7wHBmjTaIxpiBw4KI/ItR2EHt/+AK+Sgv8NeJnObKx58cGVOxQyET19ndk1fzMOyYmGRwCwOzNcdjDAu59i8NxO5TfgQsGIQgCuFgUZZdciuq7fiFfJxLUlhAu0VMmSVGhKBw02xChAYruXeW114nK0qVLUVNTg2g0iqqqKixduhRDhnSvM+1hgequkEobvyr/Crctuw07mkmS1XdV3wEAji88PuntfDA1UYnu3y//f+Ccc9Dy8ivyaxnX/0y7s3gNSTkpiWCMtKbN93EXEG8Udcgo98EHAQCZt9yS8pwSIRGVfI+olpRpzfhmD1MG5TxV5VFaQinvO+ur8MbaCsRVDRIlovL5tlr8fOkmXPWfHxGJc6BoGvOvuwWTF5+he07qJFz6EPru6KG2Tb/raY+40kqhH2fn5W6rKgfL6nLDaEqx4usAPKtUojAmM8qsRdjNp/dY75q+juwiMpF7G0LgWB5bv6tCW30IFocRQ6eSa3jGmaUakvH5v7bhmZu+lUuTJaJC0RQ8OSQMlD3YhTPvmAizzQCzzYAz7piIWecMwXm/moqbnjkRp948Holw5ySTgG93NaBGLOPPcFpk1U0K/0RUq+SQlxCVWCSM/WLTQauoTkpJsWE2DIvK8KwqrrW5T1RUjIwR5w1XymIPpQu8nkFcT2HO+UNx7RPH4/TbJiS91lZPxtpIMI6KXcp3ry1cBACI790Lo3gfRDMkokKD44Ddu0n42igt7KZcCVzzOVAgElGvMu7Gq0RFuFirjDJiDo3N4AQtkM8xUhGcPmsDBE4pqKhevhnVK9NR9X0GBB4oj+XjqxfXJpivkXFX4CPY9MVLsuGj+n6lIOAPr3CYfCDhHo7H0bJ3N/557SVY11AJ36efJhlNSgaDgFi+DYBNsD1wpKXj1vkj4bIYcOO83p2Te52o9CcIkQg+2PcB7vruLnxX9R0u/PhC8AKPMm8ZAGBSdnLjJj0ZLvtXpPNxdD9JGPZ/9TWiO3bC+8EHAADL2LGa3BLy4eRilLweFl47UiNXG4w0jjt/GPKGunHKdaokXLsSm3YtWohhP3yPzFu7TlTcViM5hwSiMq1YCdNcOE1ZxeiZo322rVZm9ozBABiM2FbtxY2vKa6r22s69sVRq0FRnd83FQRBQDDKtrtPnVffJron+vxInij2tGQPilRQ5+jwbPvn3h54VW+Tq55+CV/ln4rWUBx7GzrfKLM/w+Y2wWCkIQjEe2jXajJxzzijVLbbp2kKJ14+EkMmJau5OSUuXQNFACgcmY7rnjge1z1xPAaNTIfByCCnxAWaplAyPhPzrx6NU+6YgH0GDiwEjJqRi0fPGQeaAp64YAKsRgZlzSG8t55Mlrkui6y66SkqgVayAv74yT/JFSdSGNXCiIZ+bAR2t3KdPbb5r5pzDse1RAUAjss/Tn5sNab2UOlNSDlC9rRkwi4RlaZKPzi+BBSdAYp2oy19OgAguns3zBQxAlxXMhE/TbwNe0vPxb56RaW0exLuTZdoHOhXiF68jvwtjHna0DVtNwJinqCdMUEQWHxeU4JA0wZMWqD8nvUh5Z4u93vw7iYbot6XIQjKolYQOKTl2uDJ2oTdq77Bew8/iO9e+TdWvf2a5jMLEirqDWJu54d/egiRWBSVGaJpXKvW0j9D5cnFi+NpIlFxZWbjrgXDsenBUzAkK3Xi/5HAAFHpAvhwGL9ZqY297mzZKSegpVvS4du2FdsXLUTrm29i/+JTUXFlch8byxjCZmP7iKIiJCTqqslFIuLiqtz4/mUYkq2KVRtpONMtOOfuKRg2VSEwiccyZGV1qSOuTyQqLquRlCQHVMmr065Dut2Ef142BU9fMglTBiukxaNj6z44wy5nn5vtDlz10o847e9a4rOpsq3DcxozT6nAaqosR9Wu7e3sreCxL3ZjzG+/wJxHv8GWKv3PSWX777R0nEDZHjiWlVe/JRM67wSbXayUKecP734PLJ5TiIrH7cSc4eQa+Xqn8vd856dKLF3XQZfXfgqKouASFZHW+hBaxcmmaLTWL8RsNeCkK5NLtsfNS52U3BFGzMhF3GXEfx0xvFrAY/wJhbhoehEOPLIE50wuxJyhZKKUKudyXBa5usvfRPIqomHlugyJoZ+DG3+St1nsZPJTKypm1b1fx2lDpGGxx4u0P6BtBngoisqRgLoJpYS2hjAEXkCgLQqKYmByXQqT6yq4WPJ7RWvrEE8/EwCDWLgZrXYTKgedjGXBSwEAZrMVJRMTQnlOMR/Op1TfxGoloqIdWymKAuUm48QI13ggvh17/Fn4bJ0PoTblPmt02sHSFChrOuL0H1BkHwUgDoFXFBWGYXHJ72bC36iURq//5H9Y8/5b8vM9eRn4fngh/BZlEWUZMxoCgDaVgiJACRcazGaMmTcflCqlwREli98cn3bRJ1U40TphpyONAaLSAQRVV029fJPVNatl0500Sxpe/eNv8LnbiMr/+wMEqhiUcwiY3AmwTLkWTCYxOpIs6+PV1eBDIfAJ7dfbJSqi7GukQrC2Khnl6vyUzh4rJUItgCCA4wX4RQXCbTXKJcnwDAZu3yh3KF40NhenjddalhtUmfVZokNuIMLKUjYsdqzcl+wvs7+x4xW+zeXGXUs/QtG4iQCATV980qmvJU3K1W1hvLyqXHef8mZ9okJ3gdzpwdtQh2gwCKPZgvyRnfcumXzqmbjgwYcxafHpmHvJVd3+/Mmnkli71B15RC6Z2CT/m1CMxS/f3YJ739+KOY9+o9v6oKchCAI+31aLA40BhGIsQrHuK0adgRS6Kd/WDJ4TYLQwST2BAMBkMWDwOCUP6pTrxmis0LsD2RBRx5gx121OeG5B/giiin75r7+hta5G7vkEAEFvm0YhAxTlTSIYES6iyUsJ8xGwvPL7JoZ+ACDLpihJBrrnmnoeDpgsDGiG3JMZBXbQDAUuzuPTZ7dg748k142iDKAoBn5jJlZP/y287iGgaAdog0g6RWIQo8lEnWnWqYKTSpX9dYAggI9x4MVQoDE7mcwxYtfuUucYMJTyex9c+4X8mKcprBpZCsfCR2E1FmNWthjiVikqHBtFLBJGyNvW7u8QsJqxM9+JrydS2Hb/2bCOHYeNg3PAq8hF0GwEy5Fzufnfb2DRTXfA9+lnmLOnEiNqmlHQQnJhXJEYLnvkSfl9nlz9vlq9gQGi0gHUCah6+Sbbm8hqngIFB21HRBxAQnkjYBl/MWyzboNt5i0wDpoB67QrkHH99TCkpYFJJyu5yltuAR/WHpdWZejbxW6gaZdcAgCIs2K9PxWGefeb8n6JRGXvulV45Ve3IzydNLAzDe5k4uSeL4A/lwDfPYpmsaSYohKIiisfSC8FxAQrX1MjPn/2yZS+JtkSUYmyCPkJ02/j9UMpqUIviaAoCrPOvQgAULFts24/JUEQ8NsPtuHJr4lXQYsqUXbNgeak/IxwjEOD//CYGgVaiOrmyMgE3YXENIqiMGjMeJx01Q2H5Hky9+Ircebdv8Gim+4AANjNZCKSyIH6d69uC+PfPxzs9md1Ft/tacSNr23ASY8vx+l/X4H5jy9HJIWra09AUlS2f0+u48wCR0p1ceHPxmLmWaU4/76pGDY1p0sqZCKCURZXv0TK+aUWBmrkJDh+5rosGHfiAvn51m++lN2cAeJw7EswoLQk5KhE2Iim0ifEhnDq+6ciFA9hdc1qsDwLA23QlAe7VD3DArG+HRKkKArTTy9B8bgMLLp+nExCy7Y2o0Ks/hkxU6kODNuyUVVwgvhmsY2EQK55nm+FmbZhjPMkNL++E7y6y7ZEVOJBIOoD7ydjCGWkQVmS72PPEiWXwySownVB7bgyNEfPHkI7Hj1381Upvr0WbXYjnlvMwH7ccfBceAHqPNowjddKxl+L1SbnuMWrKuEOxzCksU1DAnJKh+Lqv/4Tcy+5CpMWnYa+ggGi0gHUJCIaTLa9lzqHus1ucKIvSq61BFljL0/al8kqRPZddwIALONIqVfsYBmEUGpFpfDvf8Pg119D2kgO+OFxxDkSUjFSEU3pm0FFVKKhID58/GE0lh3Ator9GPzmGyh+a2nnvvDHpGMnlj8K3w/PYRa9HcUZdhgZGvCJdf0uLdP++vmnsf27r/Hqr27XbP/9GWMwe0gGrj+ehC/8UVZO6qqK6E/WNSmSWfWQN2wEaIZB2OeFvyVZndla7cXLq8vx5Nd74Q3H0aTq5VLdFsY7P2nLRytaUic+H6KggoDYVNCZ3r41+eGCwWTC0Gkz5VW2VfS7CcU4BKIsTnpc2//jqWV7ccnzaw6rsrK+TImb728MosYbQWU7f4NDhTtTm3cx4eTU7TeMJgZTFhUje/ChG0pKRBlQwjtqqFUWk4FGhsMMV1Y2Ji06HQDgratFNKSV5at2aq0NpMR0dY5K4QwSxqhPIxNybbAWX5R9gV99T3LkzhxypkZR0eR+cb3nQtpZTFlUjCW3TIAnxwaOTf5dC0do803a3MTuwZlOyFmJ5SCyGjeCidZgSuYCpJncCG9tQugnVeKxyQbYxZBY7RZwIlGJCnLaoAaWQjdiPLmGjXTyb0gbyTkMsuuHcdXhus7m38UZE0qr7ciKOSHo9G9rs5NrwmZWrjNapwqQEqtM0/MLMf3M87rk2XS4MUBUUuC/G6tw/Ss/obJGyVYK+1qT9qsNksnbbXYj3ETiwCfkXgCzjkOtusQs76H/AwCwjY3gEkrH1IoKbbPBNmkSqI9vg/D1Q4gL5GIyUlGYaeVC5jhy1wg8j9ZaJZ5au3c3rBMngvF4OvfFVWZHQ398AG+a/oipGaKEKcVpE4hKY6V+GOXK2cV442czkeOygBY4WOt2o62e/F4BaOVuaeKs7aSiAgCMwSjfTIlZ7QA01v8vrSwDANhNDG4WM9g/2lKTsP/haXAZbGvFZ08/Tj5f5QjZm7CbFaKyen8yyQOAVfub8fKqssN2DpzOSN94mBQtQNtXx5FmRqlO0uzhgPo6PGdysndNjkpluWLmYDDiOCGFNr2N9fA2kMlTIprlWzfJ78kqLpVDFnKOChdGzEHjzfmV+GKGkh/x4KoHZRftO6fcmXQul426DG6zGxeMuKDL37M3MXy61mguq8iZ5I0TFxvG2j1kzDAWZmDkzhdhMbtRaFNsANiE1goYSnLiuO3L4FtFxi9/hE0y3ZQ/R1RqGC55LDNYZiA7YIEvrr3nKHEqtum0yJD6fqUCJRhx/OZM7F/6sS65Kc8k39equt04Hzl3U0kJSt5/D7Zp01D00n/a/ZzeRN8ORPYi9tYH8OWOelzkC0LilZFAW8r9eYGH/8d17R5TUMnahox0gKYBjkO8StswSu19AgAIEQLEwgSJWxqpMAyUckNxMR4f/OUP2PfjGo3VfsjbhkBLs6Z/TrvQ6Zj6WPm5QMWXgF+c2J1aomK22tCeUOwwGzDZuwnTWtdho7i4jDBaufvP543HbW9uhDccRyjGwmbq3KVpc6ch0NqCoDeZRG5VNVP8q7iqzfNYsWhsLp75bj+21/ggCIK8ktxcmUx2Mh0mNAViuHxm9z1Hti9f1u33Hi5YTUroR+0PU5ppx4EmZbD7Yns9bj1p2GE5B17HCXPV/mbsqffjkhmDD9m3JhEuFVEpGZ95SOGcrsAfIUT/z+eOx9mTkpNyJxV5MDTbgTH5LvxykTJherLJRFt/QGknMuq4E7D5q8/kDriDRo/DOfc/JL8uKSTheBiNoUZETTrdgAEUOYt0XWF/Nf1XuHvq3UnGlX0dkxYUwZFmxrCpOfC3ROBItyQZYEpwZngAAJzDgbZBecgw52uuBSGxMnDYfGDzG2haNwrxmJjczAONFX7kFCcrbjwleagkj2EUbQdv9CDKaZVDC2NDmAsk+d0ApJRYqhhsD9U7t2s8dhJRAiPi1dWo/uU9CG8glZaO44+HZfRoDH71lZTv6wsYUFRSQLKCp6MKK4762+THy85fhny7MmGfM+wc+LZuBaNzcUrgIwpRoQwGGDLI6jpWrlUkEkNB8BMWzwrK5G64fTXgUFYRsUgU+35cQ96fkK+x/LUXO+eX0XIQqN+q/9qrZwNh8SawaVUBtT03p1NC67QYMN6nPa5kEy5hUpEHDjFvoiuqilQdEfISkvHBpmqc8fQKfL+nUVcheeiMMRie4wRDU2gJxvDUMsVkb2t1GwDgxhOUOPPS62fh8zvmYsHo7idTVu5QvntGB6ujIwW7SVFUQqqBORLnsPT6mThnEln5b632Hra8EV7nmnz623343Uc78Mrqsh7/PFeGBa5MC2xuE6acWtzjx08FKXxWlKFfSeOyGPH1XSfgqYsmwazyBXJna1UCd05uUp7S1DPO0bg3y4ZvXAT1odTtJdprOtjfSAoAmKwGjJlbAJPVgIwCh9zmoHBkshWAO4eEX3ev/gE/eiywGz2a14VowvU+iFjHx2PKeB8RBN3GkwDAiwvIWdlngAKFaRmVmJJeBZe5CBTtQNyUAQOlrYq0GcQyYrFLuuZ8s1KPPUyCc/fOH76VH8/ZoyU3OaEoqu68SyYpQDeLLXoBA0QlBSxizget6hgZ30eSRQscBci2ZeO/Z/4Xry5+FRsu24Drxl2HaGsLzHQ7/gOcAEEVSzVkk9inkFD1wyVmevvJgBPnCVExUFHQmaVA8Rx5l5C3Eamwe9X3qHz+emBV6m69AIAPb0v9WjwIHPiOPLZoV2KMQSFnek6xDrMBjKAlT+GE36nAY5UN4zqbUAsANrcHAFGOGnwR/HzpJmyp8uLd9VXYU6/VeaaXpGP20ExYjAzGFZDvoM5TaW5oRGG4CqeMzsbF04uwZFwehmTZMTLXdUir75Zq8hnFE6dg8pIzu32cnoRVRVR8qjyUCMtjZmkGHr9ggnwPvLDioKZXU09BJ11DxoaK1CvD7oJmaFz04Axc+ruZsnfKkUBAJIISEe8sjBYL8lQl6UOnzYJD5cpcNHZ8UjmtOpm2Ptg9onI0YfGVo3DuuHTMHqfkhqXnk8cOQxqmZi5CjrUYAOm7BECbTAsAroKkRoURHoiG9KvUBChK9wUl98BEmTEv5yDSvYT8xyy5MNDaYoLZ2WfBwth1x091S5FEJPZwWv/J/8gphyJwh7UhLK6lFZEt2q7ztE5OS1/EAFFJAUlRMYVVMb91mwAo9tI2ow0TsyfCyBBWG/X5YGba/8OrbwJjgTZeHTUwEAAYmTbtm0RFJS4qKkZGvEEcuTBQZFJ3prc/kUR3LQO+/A3QuEf39U2VbeDLVrV7DBkJREXt8bBnzYrEveGwGGSnRgkBgx0ZKlM4iqKQKxKVriTUqonKFzuUgXlbtVc2q5MwZ4gS/nrtOrJKqm4Lo7IlBEEQcMqO/+Dsuo+AugN45Jxx+Melkw85PNBSUwVfIzmvBT+7tdvOsj0NqYdTOMbBF1YG3LhIpCmKkitUHvtiNx7+ZGePn4OeoiLBQB+eocloYmDqRFPBnkRADP10x4tnzPHErj2jsAgzz7lQ01m8cPS4pOtTSqYNxAN4bafWHEwNI33oTsv9AYEvy8FW+pFV6Uf+UDeKxmQgt5SomrOyz8AQ5wRkW8hzr5Tnl0hUKApC6SLNpqgggE3cT4QfZZrntoJFEHjAHCOqr9uahxwrCSVvatuMICfAZnCi0DYcJ1x2DRb87FbN++2eNLjdycoQAHhy9Jv3WhxOFDz5JLKKigEAaYEwuNZk8j+gqPRzWAwMIAiwqfJS6LJqQBA0mfJqxAIBDVHxG9qS9hEiyqTgPvss+bHPYsKyMcXYMiYT6ZMTJjPRZC0uiHbX0vzuzMFFGbcg0/AfbPvm1c59sX9MU6p3VDjrHyuxne9kWCKRqKgqEg5s+DFpd7OBSVJU/AYnHj5nHM6dXIjXriWkId9Nvl93FZWNqlW4Os9CwpyhyiDvMBswqYi8d93BFrQGojAK5G8Trtrf6c/vCP+580blM9tZGR1p2ERFJRhjNYrKXy+cKD9WJ7tuSmGQdyhol6gwvW8y1RPgeAHBGJnQuqqoAMD4kxfiot//GZc+/AQsdodGUZE6+aqhNnCTjCj1wKMdOesoQlQVnjntylE47dbxsEddGDNnPtLN2tCan9MqKnwoDv+KarS8vRsRj7adR4QH4jF9osIihA8rnpGfGwvngL3gI5jiROE9IU1ZoIbAw1hCwj7TF5+HKaeeifHzF+H8B/4o72P3pGHh+OkoalJy6IZOm4nr/v5vLLj+VuSUDsWZv3wAOaVKE1vPtOlwLVqIc+77PSbOW4BJ5fX6RMXWPxSVgWTaFDAbadjYCIysIp9RcRbmOKNxbRRYHoG1tbAMT0MgEoLZQ16rC5ehKVKFsWnHaY6rVlRskyfLj8szycVabXCDqf+R1L5RFAn7rHwKALA3Qo5lNIoDvCMXbVEeVY36MnnJxCk4uImYwsUFVdz5iZEAKGDOz4EFv5c3x8XLYS0/EjNobfNBDRKJiqqteLBNf3BkEgbGIGPDhEIPFo5RBgtZUekCUbGKRlctLW14v7o66fVTRudgSLYDLYEYJhVpVyXTS9KxsaINK/c1wVmhmOe5umBx3xUkmUn1ImzipNkWimOVWPVz3XElmK/Kxan3KSpdnTeiSTzuCcQSSkpH5bmws1Z0Xj5MisqRRlBlYufohqJC0bSmL4u6ZFTdB0r+DKMDDqMDgXj7PijHQo8nQRDAeZVrOFbpR6zch9Z392Jy8YmI2YJASBmPA2Loh21qQ/TAQQTXRRHaRELqIWgVqIggwBzTJ3s8b0aYC6MquAeF9uEwUCawQjrsgZqkfcdnR5Fe7EKgwg+rxSkXQqjzk2xuD9jaOgytb0WFWMEz/uRF8j6SSRsXj+HjJ/8EQGmp4EjPwEnX3YxdT/1T91z7C1E5OkaDwwCLkUF6hJRw0Q4HKBORMYbUChqiElhZA+9HB9Dwt43Y4zDBRJPJNsaFEWKTy9fUsiLtdJLKHyR0rmzYDpT9QB6/di4gmi+VRYn1eslo0dDHmobWaOqcmFNv+6X8OCa2MRcEoDbsRJSjgZVPavZ3gYRw/sqeh9Ojf8AT8fPwD+sNOj+OMlgKPK8xowq2tSXtzsa0sdI4ZYBA0bIRnASp8WGtt/OhH8mRc8uBZJUIAIbnOPGrRSPxp/PGy2WfEmaUEIVj3aq12Lr0eeV84/odlLuKvjwZ2IwKaVpfTohuToIZ2ZJxiqwcinEaH5qeQChhRepUKQ5Hi6IiVfyYDLQmUba7oGgaExacisJRYzFYLF9Wg6EZfHPBN3h07qO4btx1eHnRy1h62lJMzp6MVxcrqisvHP2KCh+MA5xyD7LNEQRWELIQK/PB6NFe70HxcuS8QRw86yyZpOghygOsnqIiCIiyZpjdP0OMJwsuIcqDbWxEWttuTPF+qt3/008gRIn6y6tCsK7MLAwaMx6DRo+D3e0B19wCk6qxqMmWHLJRV3bmlCgFAZTRCNqtjNnGIqXPT18eo9QYUFRSgBAVsrozZGWB8/vBNTXhd2/weGOWwq4j+9sAAEKcxxDnRJhpQmJYIY4Qm5wVzqtCPxRNy05iaqIS5RiY9y0DSo6Xq3BqYyPh48jEMWwBCZXA7EBLTJ+ojDpuHiwOB8accDK2L1+GOE8DEy/FD9uj+HFXPUa5GnBqwW7yeeIN4KII4fAJNuwQirGVK8UiJw+E/6U9uEm5ScIBPwTVoBePhBGLhGGykPMSeB6V27UJXBHajKtmFyf1kMjtRuhHcuQ0sfrvGZKdOgYr9SYaFNYav8XCnSdK7UHdX2f2+Zf2yDF7CjZz8qTpSsjdePC00Rid58JjX+5GjOVR0RKS2yH0BNSW+UaGgjlFG4j+DDk/pRthn1SYf93N7b5uNVixpHSJZtvLi1/WPBfQPyaoQwGX4IfCNYfBhZQwJ2XU3gMRcdKmDGYIsRj4SDNoi77vUVQATKKiEquoQMPjTyDjuutgHTUMEc4OirYiLhIVphEINLKg7dkoyEho7sdG4PvwfTA5J2nSAiiaxgUPPiyrmJzfDxrAuDGTELZZkD9sBBLhzFB8gYrGartLGzwexMTKSPuMGWirID29hFjPLj4OF46+kaGHYDHQGqIiqFSD/EZVq21Vrf7UzIUYk0Z6qZhcdoQ4hagwHjLAq1kzAECczATVnB1gTUTt+Pp35D0CjfdbHpFfl5MBzU60xAgxOumaG5GXrhCoxbf+AgCpHABERcVThB+3kHyXnT4xvs2xcr8XNwiz9wrK5G6weZJ+G7VNq6+BHM+RngGj6HwYUqkq37/xEt5/9Heat6/JnovfnTEGicjvIJlWj/3v95LBwsIrEu/1x5fCZTHgtPF5WDxWP9kMENsCAHJuioR4pPNEqT1wKmVm6hnn9MgxewomJvnWz3RoSUia3YSfHV+K8WKFVFcIZGegVlRMjFZxCB9GK/0jCak0uTthn8OJY0FRUYd9AIBtiRCVRUQsobyYFYcXymABZbWBD+gbIR4QVXFWvEZrfnUv/F98gbLzzwdiQUR5QkYEkXibvSZwXgssEy8H5ThZcyyBjSJWSapJk+YGqN2C80G7CjFn/mKcefevNV5ZEhzpGZi06HRMWXJmUp8eRuWIbRk/Tu43Z585U/c79jUMEJUUsBgZOOOEnDAej6YhYU6dMmAnmQNJ73fZNYqKMYcQCl6npC1O0wiZFJLhj1uwPzITke+fw3Lf9dgYPEuzv8kqDugmB1rE0E92mhmz7aQ+ftTYEvkCN4rKRpxnEGWUmLaZFm/YqA8N/ijMiMFMkW0+KETFalNakuvBK1a0uLJy5BbpatOh+v2KT0mdORtv5Z+LXRb9njV5HnKuvgiLYMLvGoiyOP3pFbj032sQjnG4+fX1eHV1GR79hqwMzHxU9rS+Y/4wbPndQjx9yWS5eisV7l88AukxMiDxYkfRWKRnbNzVIS+D0dTOnkceFEXhDbHySUJxpr76JFm8S431JKwvb8Wtb2zoUpWWhD31fjk3BiChEYtKUVm1rxmcjiFcf8OBRkL+9Xr89AbOHELK428cf2MHe/Z/SESFSSe/Pdsc1oSCEsGElFCP8awrQTHJ9+yGIIuymFiFKSkqB1V9saJ+RAQyZg6ft1jzXkNWsgoCNgpBnGfUPltqxGqDMAw6HfaTHgRlTT0eUxSFk66+AfOu+FlSLpmaqJgKClDy3rsYvma17OXV1zFAVFLAYmRg5sjEndgXYdRz34BtbYUgCGCb9Adpe1YGOIHFNuMaZF4zVr5Z1IweIG2ovh09WNNIalPgJHze9iu82PAytoUWY01A2zfIKMr2cZjgZ8lx02q/RrGjDdcPXYtFIxSCZBJVjjjPwMcquTWswJB5PeJFTVtYzk/hBQoBKIOqx24CbGLsc+x5wIXakkevqKi4s7LhEC/61jpVS/SI8vusTJuFBnNypYIEh9kgS+SJpm/PfX8A26p9WLmvGS+vLsOnW+vwwAfbsaeN3Nw0BJgEQgw662oLAMMOfoNBEXK+IbFiq6dCP1KuC2M0HjEX1K5g1hDtIDUoTT+xTmqapyYqoRiLc59dhY+31OI/K7vWwDDKcjjlr99rtpkMNArSlPuszhfBv3/Qb3LZn7Cxsg0AMFGsMOttPDTnIXxz/jeYXTC7t0/lsINtJPexZZgHAMAH2+/OXXLwI9lLZXPbGMQtZHJvjCvqU4AHXAxpCMnGOAiCoC3xjXgR5clzi6f9RNWsmyfANmO6QlR0FBUAciNEAIhV6e7SIQzpSoGAISsLlMnU+bYqfQADRCUFLEYaFrHih7JZYR41SvN66McfEdnenBQHleDIJpP7/ur1sAxPA20jigkf0hKVqIEBmyDDl0cnAgAE6KsB0qTXKrbntjBx2DaSrG6nMQa6dpO8r1FcpcZ4Bj5GIQmcQJNwUKQNeesfw48WEvdugAeC6rJwW43AzzcBd+0CznsBGHW6/JogCCjbTFQcd04u8oaSFUPN7h3K9xN7T+ybfhVqrESOPG9Kso24BGmyOtCorVr4brfSLfZjVY8ejjaAFd2Ah3tonDs59bH1sOGzD+XHQZGo9FToR1JUDKa+paZIoCgKaTZFyUtlWZ8rmqOpicoPexVjqq5UaQFELUmE2cDg5nlDNdse+aydyrN+gs0iUZk0yNOr5yGBpmhk2Y5Mj6MjCSHOI7ytSWOoGW8kBMBU6ATtaN83xjbBgNyG9WBaCTk2sDR4M1Ev9kZ5HIhy2BBk0coJcFLiWCQAHMtrKme4UJtMVNhdG9AezEUu0A4HECPnqc5RkRCrqkZkj9JCIbjB360EWNquLISZzE62U+lDGCAqKWAxMrBwZKKhLRYU/u0prD1O/APTBtTc8yBCm/Sb8QGA2UUu1mgoBIHnwdjIZJoY+rHccE3ym4XOdaxtbSSTRZopQQFoLQPE/BpjSJQpaTt8Ia20GOaMwOvnY2rFi/K2z7lpGJKlrBDcNhNgdgKu5FyP5spyVG7fApphMOq4eXIZZc0eZYKJiB4rPl5ROf549tiU32lCoQcA8OHmGvywtxFbqtoQjLLYXqOoRNuqtbHluOgd8e8Lx+DxC7RJZKnAsXGs/e/bmm05uWQAV6tAhwKZqPSxsI8ab90wCyNzSWVUKkjVQHUiUflyex1ueFUp5+5qI8GfypNL2E0GGm6rES9cOVXnHf0THC9gn0i4R+Udehfm/gyBF+BbVoHInp53HAYA3zcVaH5tJ1qWKmNPvJ4QAEO2DYaM1NWR+Q/OhH0GKcunvMR23sFQ8jIxwgvY529DZZwQBGNMqeZkYzwom3LscFUdQJFpte3FR8G1VbR73ozDASFOxhs+zEJQhTsje/bgwGmnofGviqO4EObaDV+lhConqT8pKRIGiEoKWAwMzCJREcxWmAYNwjunurByFAXzuItgX/gnhLfpJ1sBgNEuVr0IPGKRCGi7qKgkhH64Ij0TsPYlSgkhsRGfwyCqOo4cYvcMAPXbAQCmauI2GzdnIOTTTvAhzggElbjsWn4kNhZegfOnKsZvavfYRDSUE8k/b9gIpOcXwpNDFJPmqgq8cs9tKN+6CTGRqHhZhai0V6Y5ebAHAPDxllpc/sI6nPH0Sizf09huvkKEJiv+SLB97wg1tiz7AiuWahtxjRtK1JieCv1wcfK3ZvowURme48TndxyPm+YNSbmPVOkjEZLrVSQFAKpaupbT0yLeA+qiLym515qQU1TR3DP5Qr2BipYQYiwPi5FGYYqw2rGC0OZG+L4qR9OL2zr9HoEX4P2irFPkJrCSqKzhbc3wfn4QDf/cDN4XA2gKxmwbDGmpq9UoiwGmkhJk3XEHHCeSvC0LTcEgXp9pTZtx+g3D5f1tXBQ0yDXMxjgIMWVMr3njawAAw0ZAs1GEVvwFkU3acLl7SSmyf048tGiHQw79QAAEVYK5/+uvIUQioAzac0+y+O8EBFVpc18MQ3eEAaKSAmZV6IczkwvFF/PBZwNMJceDomgAhIyYSpK7kBpsJrlhVDQUUIV+tCSkZpfOjSt0kqj4SLmZjRGJitkF5I4jj+u3AREfTPVkUokY0xELawf9+rAiB94bvw7clZ/gyRtOw2UzB+O8KYW4ePogzBuRWiZuFImK1ChNKhWWXnv3D7+RS3QvPI5kmS8c035zv3kjknNYnv2OOMWqwxTXHack5AYpkagE9Nuu66G5MnmlI/XUSPydugtW7BPVV0M/nYXHSs5/f2MQ17yU7Dxc64t0qXFhW4hcr789fQxuOL4UAPCbJSS0GktoAPTuhm4G5fsA9taT63FIliPJw+dYQ1S0cegKwlsb4f+2slPkxqDqjO3/rgqxMrIoc55QCNpqkMffRJiHp4GiKVAUhcwbb4BtIlF7i0w0GHFCX/T0NciaNhqLx1Sg9MCHyArVyZ3r2SgHXrUAbNtOCJOBDUGIRAA2gnj1T8oHMhQcx+XDlEdUa9rhAPg4IBpiqu0r5ErThKReIYUjbntwHD+XfJ6z/eKIvooBopICZgMth344kwW8wMMX84HWWdg7jytAFNrVPGVi5IZR0WBQVlQ4lWlWyOfFto27k44nCCxGzc5D5iBtzX08vAJR3+uywZpUBmwziIze4gJyxLBK3RZg/zI4aLJvwB+SJ2BpQt7SpoRzDvB5GJLlAEVRcJgN+Mv5E/DIOePbTUyVWo9nDiJ9K8x2R8p9L5k9DP+7ZQ6eumhSyn0AEmZ45Jxxmm1bqwkhu3CaYlS0eFyeTFaK8khILhLovKJi1ulxIRGu5qoKBFpT24+nAs9zqNqxTc5xYUVFpS+HfjoDt2qQ/2aXkiv0yjXTke00QxBIr6jOok3M0/LYjLh38Uhs+d0pmD2U/A0ld2IJVa39W1EBgJIU1VTHErg2JTwodDJ0oc7/k/IyBJZH6//2wb9C60ItsPrl1s65RGGmdHyDACDzaq1NAqXj5cOIYfycKcNRXPEF4nVBGGhybvGq7eBU406MI/e6Ma5q4RFXXcO81t2ZdkjXBrkn+BCL1jffxP5Tl6Dtv/8j59QDRMV+/PEY9PzzKP344y6/ty9ggKikAEVRsInKRnmQx7OfbsJJrdOQRSX3w2FcJpTH1mnfryIqn/ztMdAucrHxwbh8U/mbGjWJUUptfBwjZubiwl9Px+m3k5wLgfeBi6yDwNWjbDNRSaSW4HYp9GNxK4rKlneAdf+Gy0gGiJC3DSGxK/OYE0gtf3PMgZXCeGznB2OzMARZjq6ZeUkutI50MskYjEbZSyURDENj4iBPh+XCAHDx9CKUPboET100UbN9UpEHT144EfcsGoHJRR7cs2gk3r95NkYNJipNVxQVtRmbhOzBpcgbPhICz+PAhnU672ofXzzzJN76/b34YSkx1zpaFBXJb0aNsQUuHD88C9NEd98fD3ae2LWFJaJiAkVRcFmU44/MdeGZSyfjfDHhOtyNQbmvoKaNENYCTzsd1Y8BCIIAtkGZrPlw53Lw1KQhesALIc6j9uG1CK6phffjA4jXK2RAXRkDALYpOci6fryspNAWZdwxDnKCththm5KTFAZJNIFTbzMNIeHReE0t7BbyeU17y8H7lXEnLjasNbBagh1e9y8AAtyLtdYMjKhC8yGiyggRFq1vvIHYgQPgmsXUgsTQT5iFb1kFoge86CwoioJj7nEw5qSuuuzLGCAq7cAl5or856c6HP+DD7+ovRInDPpN0n60w4hoTNsEjzLSiIsTVXNVBby+BoChSKa4j1zkYT+5ODPNQdx+13mYde7F5M0CC1cmGdwyC4lUx8XL5GNLBCEoEg8bI974ZhdQegJgdgNsGChfATPNwmg0yOcBAOn5hTBZbRAEAbf5r8eS2COIwpTkFNsRIkGx6sihyInqx4eKxAF+UpEHZ00qwM3zhoKiKJgMNCYXpcEqypldISp6+xotFmSLqoq/OXX+kR7YWAw7fvgWALBnzUqyTVJUTO1XHPR12E3Jg7dUyjxF7J+0pbrzg6YU+vHoECAAOHVcHqaLBCjRZr8/QWoFkefuGx4qvQW2KSyPeQDg+yJ1EYIaAqss4pqe34rG57ZoQudtHx1A9KAXAssnhdTTzx8Oc6kSkqdUzsDmUjfy7p+O9POHIxF6igoltnNgpB5gPI9BOWT8OLgzDvCKmsOK7VWMCUSFrVkPU8F2OOYWaLbTIlERxOIH3zfbEd2nbYpKMVqi0vreXpLv88oOHCsYICrtwCFW37hNbqSlKBUGgA/+/kfUerUTG8VQCLQo2xrL9oMRyzwlI6KwOFlamTiMOcPAceJnUCwcopOtzWXCBfdPw4STlFwRX1MDoqEQmirIDa+EftyANQ24VKlmoSjAJdbQS54nJpsNGYVEGUqPk0S17qz6pMlenZti0QmpjD7+pC4fG4DGVyPfbUG2U3/Al8jR+k/+p+nk3B5SERWb2E491Na16gRJ3QKAYGsL6g/sAxeTfFT6t6KSuOp0W4248QSyupSqWaRmgp2BFPpJs6X+XaSQY/9WVAhRyT/GFZVYhfZeC/5Yh/qnNiC4sSHFOwj4hHLdWKX2ONF9bWj81xY0v75T3mabmIWs68cnHUutqDB2IygdZ2YAoEypp0TaZJKJRXERIVFlDdkIWZWxOW4k458m9CPCXDxIJ5GVPBc4VvxOPChzQs6jUUtUJH8YvXJmPsIisKZGk2JwNGCAqLQDK08uBJuOdM8270eZbwXWN32Jsu0b4Ys3g+UVSZO2G2WnVACoP5hMVFqryI1qYTggvQQcS24mg1EApVI3soqcYAzKRelrrMf2775CPBKG1WpGlkWMkVrEEshBWsdRV47WTrklRsNnJPs6xcaJL141rTM/CQCgqbIc1bt3yh4pVpWKYlaRlimnnY3bX3kXi26+s9PHVkNNTEbnpy7vVJOjLV9/3qljh0W5dsntv0TukGEoHDUWRrMFDjF/J5CiC3QqhBKaMf7vzw/1eR+V7sBtNWLNfSdjgugLMiqP/O2rWsPwRTqW9CNxTrbHd6dIcAQAm6jihOKdSyzvi5D8ZY51oqLOT5EQrw2i9a3dYFtSe/AIKQzQ7NNzNc8jO8m9ahzkRPpFIzVKigRKpQqmSqwF9EM/akiqittlRb5xOwTQaHMPBSW2KpEUlcTQDwCYigcnbbOMIrYAtEMJyTDppZp9bFO147l8rjp5N63v7UXb//aj5Y3+70GkxgBRaQdmltxgnqyNSa/FgzVY27wS+/zkNU5g8UHF04hk7kLW9ePBOEwoUnU3DXnbwLjJhBXZ04roAS/WfUoSaQPCIMCRCzZOyAnNaFeR/pYmeOvr5Oct1VVoqiRqyvhZU2CiRelRYuIJrD29sFjz/LZ3d2BNJSEZRj6OK2cNxohchWwIgoC9P66Gv7kJiRAEAS/ffQuWPqh0ZlaHe4wmhf3POf9SGM2WbpfDMTSFqYPTYKAp3L1Qx35aOidV6TJj6JwrraSoWF1uXPLHJ3DBbx8heUme7ikqUhhOQqC1RXam7e/JtGpkOEywqgZ9j82ETAf5fp0pJfaK+QkMTcHVTv8b6TP6a+gnynJyOfexHvpJ7LmjRrw2tQKaqKhIkBZ8iTDlpk5aVisqtD31dZcY+rFO1FY9SkSFY03wGEhCb9TsAW23g3Y6EZdCP3EdojI4maiYhw5F0YsvALRCnpg0LVEBpU+shDifZP4W3krG7K7kr/QHDBCVdmAQc0zSaLLK/tK9Gg2ONgBAW9WKpP1ZIQb3uEEyoz/11l/Ak0sqaziWBSMy+dCGBjQ+twVWcf5uYMeC5QRwcfLnoChlcI4EA3jupquwe/UP8rbmqgqUbSEEKX3EVGDYKYDRBgzWt8VOy9e6tcYoI2I0mVxMfBxjCrQrkF0rvsOHf/kj/nPXTUnHklQUCSarDTSjDAJj5s2H1eXGyDknyA0RDwUvXj0N399zIkbmplZURh53gvxYnSS7ffkyPH/rtWgoS7Zil/KDrE4XKIqSyZRD7lfU1qXzlEI/Um8liqYRj4rdU48iojJ1cFrStgw7mTjaElyXfZE4+AT/m1YxP8Vtbb+tgKSo9NfQT704OZsNNNLb8SI6msG2ROD9qhzxutRkhPOnDlGk6n1DO41SxEQDY35qoqLOUZEqMHX3U5Fwx9wCpF+oXSAxaR4AABcGHAwJ7UfNHtAmAxinE5yB3P8GLtmLiXHpj2H22bPhOllxi6WdebJCI27RP1leUDopHuUYICrtgI5FQNmzsCh8CgCg3tiEj2euh3MuD1+sUfc9GfMXyI/tnjRMPe1sAADPsqBtWiZvFyd4irJi/WfliEco8blyg9btTS5fBkjFEACkFQwGLn0HuL8GKJ6ju296vjaBK0abEBdZupmPIqtpt1wRBAD7NxCvjLiOQ2vQq1UaEpNnR8yai5uffx1Lbv8legIui7FD6dxid2D8/EUAgJjK/v7zZ/4KX2M9vn3pOc3+B66+bQAAayFJREFUPM8hLHof2NwezWt2sXmXv7kR/pZkRSkVpNDP0GkzQVE0BJ6XVTCTtf9L/+/dNBsXTy/Cr08dnfRamjjwt4SUSae8OYipf/gat72pVSPVpcntwdbPFZUar5Kf0h8NtnoCza/vhH9ZhZyjknHFaLjmF2n2aZeoiKEfQ7bWLI9xmHRVFev41J5PiTkqKfdTKSq0LZlMG8TQMOvzw24l5x41p4EKVoF22sGJpcSZZ5wEy/jxKHjqKZT8778Y+s2ylJ8JALaJ+QitfBIAQNnSteoLT85JTwniUzTFld3qjhIMEJUUEHgedCwGY6ESH6w1NYGxm2HMsSMiVtJYnVqWrFYXAIAWQxEcGwdt1V5oNloiKjZs+bYKMWmOVVnoN1drDa9oRnsMSbFJDPeo4crSlqTFaBNiotQ4OrALG/7zJN76/X3y6wKv70kAAGGvVlJ0ZfaNviFSWbSkYqjBJ3yfiN8PQeABioLNpVWTHGkZMrF47qardI+nB18TyTdypGfALlpUS1VWZpst1dv6DaYMTsMj54zTzSuRFIOWgCLxf7ylFjGWxydbazWqikxUUlT8SLD282RaKZH2WA77xKu1vkbGHBscx2kXTYllxWpIyaK2CaoxhiLEJf3ikWA8ZqRfMhKuBYORceXodgkIlUBAOrOfrqeKm4wXvNcLh5tcoxGzB7RBgMkSBidW6DjGj0LJ22/BtfAUWEaOhDE/P+lYahiys8CHiUJDWzNgKlIIndhrFYwn+VoSVC616jAQ3cH91d8wQFRSQBBt1A2igdp+cyVWODfCbcoG7XIhLA6kY48/GXmtZMUwddHpScdhDKLRG8smSY5WSVGhnXBlWhAVBQxB5UzbVFmmeY8nR0kkM5otnSoHtji0ZIqlDIjT2nORzNuaqyqwZ01yWEtCYi7GtDPP7fDzjwSkkEs8SiZLNdlKrEQKivknVqcriVhSFCW3AgA6V6YsCAIObCTukwUjRsGRTroSN1eR31Ty0zlaIVXvtKhCP2pPnt99tF1OtJVLk9up+AEAm5jUGON4sFxq4txXIXX/znP3fzWtq2Cbw5q8MQCwTc6GIcMK2mJA2nnDwGSQSbczigrjUa4l67hMGDOtMA92Ie/e6bCNz4Lr5CJYR2WkOgwAgHaaYJ2YBdvUnKQFoxrq0A+lY9cgdUrmgkG4ssgCJGzNgmBg4HZuASeG1I0JSm1HoE0mWIaTSkzKaIGxSMlT4YLk+jcNSh7r1eExNWlJ/I4CL8D/fVVS5VR/wQBRSQFeJCqUlcTkn8p7HXGaxYysU2AeMgRBO7lILf4AJlQ04OTtZTj+0quTjiMld/Isq6OoiDcg7UTYF0NUzL/iWGVlGkxI6pQVFACOjMzUsvJ5/wEMVuDC15NDDxSFOJU8UYT9Pnzw+MP6xxMRSgj9lEyY0u7+RwpGsc2BFK5SV+1IJEaCRLbsKQaTeVdeJz+OdqJ/UFt9LQLNTTAYSQK1FE6S1BjTUaCotAepH1RrUJl0IqreIq+sLsfSdURdUszeOlJUlAmjoYtND/sCpETaHFfXTBT7O0JbGlH32E9o+3C/JvyQdrbSGds+NReeJWQibi/RVkqmVRMVy9DkHKnOgKIoZFw0EunnJXunaMBQ+o9FSESFDwbhLsyFRWgFz5jQ5iiFszAC2kGuW2M3/KQGPfs0eNFPhbKQEBNlcpImhBRgHZ1MxgRV6CfRS0aN8JZGeD89iIZ/bOpW9+XeRp8gKv/4xz9QXFwMi8WCGTNmYN26rruC9jR4MdeBEp0G/QxJCBM4O0DT8NnIzbN7mx0UAAtFg9IpQ5VCPywbRzyBVFgNpJSXou0IemMI+ci+0ZBfTgpNTF51qxQVZ3o7q4ix5wD3VwOjTtOSGfFxjE6eKGr37UZrjTbUxPNa6T2UoKgobrq9CylxNx6NIOTz4uW7b5FfSwzfSN8hMT9FwqDR42Q7/c4QFX8TyWVxZWXDaDInhQMtttStBY4GpEmhH1WOSiAhdl7dSghkq2z21r6iYjYo19W972/tkfM8kmgRSduxlkjr+5JUIwbX1MqJnrm/nJpU9mvMIeNqvD6ka38v8IKsEBjSlJCHmrQcDqjHSl1FRXKSDYZAZQ1HNks8XNrsIvESc1SMHVzfemDcbggxMt6Yh5AkXsqWLn6uCaYiFfkR7w+1QsKrFM1Em/14o5Jv2FnDvb6EXp9l3nrrLdx111347W9/iw0bNmDChAlYuHAhGhraNwM63OBDIYA2yJ0rA0wIAmuHNxyHv6kRMYEHQMPLD0bc6JSZdiLUikooodzORFtgM7AwWkR/Fc4CgAIEAUFvKwTxfzXUYQlnRibaBa1jBy22IE8M/QDAwY3rk7bFI9oVT1jV1+Kka25s//OPIEyialK+ZSPefOBuDcGr2LoZZZs3yM8lIz67J/XqzCL2LepMR+aAmHTrEP8eieG4o11RkSbjRpXyEUro8ColxXpls7f2FRX1hPH9Hv3E9b4MiZAda0QlMa+DMtFg0pNzK5h0CyirAeAEeD87mPS6EOMAceHPOE1yCxLNZH2YIX2mZpukqAQCQPZIOFgyT0VsJPeGFch3NbRjHNcezCUkN8VUOgzDVq5AwV+fAUAIGm0xwHniIFgnZsGQTuYM72dl8nt5le+MENeSP15lABdYWd3vVJVeJypPPPEEfvazn+Hqq6/G6NGj8c9//hM2mw0vvvhir56XEA7LagoPHkE6ApqzEKIiTnQU7QRFMYiaPeAc6UmlmADAMEqOSiyhps5EW+C0MXCJ8VqKogCK3AjBlhYse/GfcuXI8FlzMfX0c5BTqkiozsyu923gxAs0phP62fbdV0nbEit/JIXhhMuvxaSFp3X58w8XpNBPLBxGW12t5jU2FsV7Dz+Ip644F2G/D211pMOpOoyWCLNMVDp2upX8Zpxiz6NEReVoz1GR8jDqfYpyJSkqUohHylGRTNCynB2vjJ1iSalFJ6mxr6NZnBjSjnGiwrjMuuFpiqJkVSWwqiZp4pQnXQMNykgj964pyHtgJuh2vHd6CukXjoBjbgEsw5MXMrQY8ueDQSBnLCw2cu1H7IXANV8iLpDr2qDTdqIzoF3k+EKMgyEjA2DI8RkHuY/cC4uRcdFImEuUIgDpt1OHfhIVFbZZuTeFOA8+2Ll+S30FvToCxGIxrF+/HvPnz5e30TSN+fPnY/Xq1brviUaj8Pl8mn+HA3xjGSiTSBroMARKQGbjFHjDcSUcQBGC0eYegm+G/gIf/31T0nEYueqHRTAYx+feOFb4yQVlYqxwuh2wqZg7RZPPDLS2YPOXn8jbj7/kSpxw2TVwZ+fIqojUXLBL3wupQz9sNDlerC73BSBb1Ot1H+5NJOah6IGNRrH1my/RWkuISlpeQcp9pbYAnQr9iMRVSqJNIip97LfqaUiVLbXeiDxoBkWikusir/nEiae8mVw/xZ3oKPz+zcQXqDONLPsaJEUl41gjKoZEopL6+6edM4w8EJJdaCWiQlvJ3562GNqt6ulJ2CZlw7OkVJdgqXNUQFGwZJOQT9iQDr5gGniQczSm6NbcEWiR4EhEQxCTZRMJmuc0Jdk2vK0J/u+rtKGfBDM4tkm74ORa+1feV68SlaamJnAch5ycHM32nJwc1NXV6b7nkUcegdvtlv8NGpTczbgnwG/5EMaSEwGQsM9HlTVYFv83vOG4HA6gKMKeywcvBABU7kx2M6Xl0E8cQW8UET4Gb4TYG5toC5wZOYio2K2Ud9JYrpVDzWKegyMtHRc8+DCuevxZTQVQR9hjJ0rMevckAMDcsckuiXpIVlTIRNPX8i4kRaUj/PDGS6jauQ1A+0TF3InQz8q3X8er9/4cW5d9AaAdonKUKyo5IhmJsTxaxcEyGCMTjeSB0xiI4u2fKlEuutcWZ3T8m0iVQd5wsnFcX4YgCGgWc1Ta62d0NIJPIBztERVjtk32llI3LQSU0uQjoaB0BRqiAsDKk/GhLWbHe39WQufdDf1IVUfhrU0QeEH+PSlLAvEx0ICYQ9Py+i54Pz2IyG7t/COFf4Q4LyctG0T1nm3tnO1CX0G/01Tvu+8+eL1e+V9lZeVh+RyBN8JUOg8AkBfPQjHLggLgC8VUiooYbjApE5NQuwWoVJKBNYqKNwY29B3CQTKx0RQNV04JSkSfAJqmMOaEyQCAVe+8rjkfk01RDApHj5WbCnYWyzLn4f3cM/CThxzf4Uwd61WX7MYSiUpfVVTMXfSroKgkIzw1LCK5iAT0iYrA81jz3ptoOLgfvNhQTKoiohJKno92omIy0MgUy5GljsEBMUdFUlv2NQRwz7tb5PdkdyL04xa9IAQB8KewUz+SEAQBH2yqxrPf7UcknuzvIggCVu9vRlVrGDExQTTDcYwRlYSQAt1B1ZNEZBKJCh/WVxJ6G4yUTCuOC6awUl3YUE4SWykKYAzdJCpi6CyyuxWhDQ1y5VNixShFURoTOwCI12rHKkmVYVsjgEB6A0klzokeN30dvUpUMjMzwTAM6uvrNdvr6+uRm6uvFpjNZrhcLs2/wwGO1x9gwiGfSlFJnhw3P/EYfnzqOaCVZFYzRiVHJdASARfbAU5g5QaGzpzhmLJoMOZdOgJXPjoHM84+L8nUDQBoncTYroCljai2FoCnyHHcNiPOuucBGMxmnH7nvfJ+rqxs3PH6/5BVVAwASa6uEknra5NvWl77hkpJEASYrKmTXKXmitGQ/g0d1um+LCkpuao8oisee7rT/Yf6M6RO1/sayO8VimoVFc2+HitonYqKRJgMtOxQ2xbu/W6w26p9+PnSTfjT57sw5f++QoNfuyr9aEstLn5+Deb++VsA5Hta+2HYqrsQBAFcQhuF9hQVAEmNWiVIz6l2PE96A7KiEgpB4HkwLbVJ+xjMTLfdiNU+LpFdzSlDPwBAJWzjvNp7JLy5EQ3/2ozoXqK0GDIssIglzoFVNRB0yHZfRa8SFZPJhClTpmDZMsVemOd5LFu2DLNmzerFMwN4VrnB/puxVNneXKlUx1DJq4WV/quxLnAJ2jYR0zR11U9LTVDOa4nxZJBLzx8BxkBjzNwC2FwmMAYDJp96Ro9+lyibfEG6LAYMmTIDP3/lPQyfeZy83WQhlt9SBUtrnTbRLSIrKn0r9GOy2rD41l8kbR8zbz4W3vhz3PLiUtz15ocYOo1cV9POPK/d40lVP6lyVMK+5KZfVpE0O9IzcO1Tz+PmF96UCd/RjhklpIzyh70ksVhKptVzZn35mumdPq7kYCs1M+xN7GtUyGkwxuGllWWa119brS37XDQ295iyzxdifFLvmVQNBOXXXfpEJbSZVHpZhiR3Qu5NqKs7gytXgmtMrkjrDAlPBUoVMqKMjCpXR2fx2kEeTNtHBxA76EPbR6TXmSHTCuu4TFAWA4Q4r0mw7evo9dDPXXfdheeffx4vv/wydu7ciZtuugnBYBBXX51snnYkIREVTohjtesr8GLyaeba91C7l+SYMAINW0g/lyZeQbwfJGfaeDQGb2NYVmEkouJ0JHuhzDznIs3z2edfekjfpTmQvBp1pbBYlnJqpF49bDQqh3s4lpUTbvuiLfzouSdi1nmXyM+tLjdmn38Jxp64ABa7AxRNY/Etd2LRzXditmo/PcjlyQH9qh+9poVWlR2/JzcP1m6YPvVXzBlKiO368lbwvIBq0UK+NCuZ0Lq7YO/tFnM8WkO9T1SqWrRh0K3VWrLaFNBOtrOHtO+WerRBr5JEyolIBVqsZuFU7xU4HrFKsWnouL7RokMCZbHIZKXyZ9cjuncfHAGt9xQb676TstpvhjLScugnKUclxbb2YMggi1BDFlE5ExNs+zJ6XVe78MIL0djYiAcffBB1dXWYOHEiPv/886QE2yMNnhON2vgobBlDEQ1Uw+ovgyGmtO828ICJ8yFkSw5TxcSJTE6mFfMYJBUmzPrhMWXpOjOabTakFwySbe1nnXfxIX0XXaJi0Z8spPCF2WaH1elC2O+Dv6kRFrtDJizS630R0886H7lDhqFw1BgYLckN4UxWW6eqpdTJtKveeR3p+YUYOUfp0pzo0Aso5OZYRFE6Ia4NvghW7W+GP8LCamQwNt+FexePxKOf7ZL37QpRcYirxlCq5mtHEBL5clkM8EVY8AkltY0JRGVsQd9SAw43dIlKZvvVeFJbEfV7ubYowAMw0Ifd4K2roCgKTEaGnEwLAGN2/Adrpz8gPx83L3XuW4fHV6kxlJFWFBWd0E+q/B1TsQuxsuRqWIk0GjOtiFf6EW8Ko780eOh1RQUAbr31VpSXlyMajWLt2rWYMWNGx286zIjHRWM0IYrbChZAsLiR6JHDCAxMMf3eCdEgGbSU/AQBscAnEDgS0wxz5H2JcUUJna1i6QyaxHNR+1G4EqTEk665EY6MTJx45fXyNqfYcFDyCfE2EPXI6nIn9cjpKzAYjSidPA0mq+2QZHepPNnXWI/V776JT/72GARBQDwWxTt/+A2+ffn5pPf01d/kSEDyRQnGOFz2wloAwLhCNwwMjWuPK9Hsa+pCoqFN7KkV7APNCatEd92TR5FFVJOqT40gCJqE3wy7qVMJw0cTEvNTAKXcNhUYHaIihSQMGRZdd9heR0KTU3u8GSUTFfPNUXO6mC+ngsal10ArOSpdCP2Yh3h0txvSCS2RyGO8rmOPqL6CPkFU+iLoU08CALRaoxifNgK01YM4r1wYtHEozHx6SqISEW9aVsVD+Phu+XGRbT+A1L0uDDp2/N2FpKiML/DI2xJJ16SFp+GGZ17SVBNJnZF9jcR9sbGMlExnF5fiaIeeOhIJ+LFr5XJUbN2U1ErgWIfdxCQljt50whAAgJHp/jAjJdOGY72vqFS1EjV1QiFRStShHkltkZDl1Dc6O5ohkQ0pJGEq7rjQQU9RaftEzKnQcbTtC8j97W81z405ORrfFPMhJACriUpgeZU8P0ghMjXUybQmlQGcKV9f7TaIBnvS3yW8qRGhrf3D9XmAqKRARCAXXpSJgjI7wNg8iPJSt2MGRvvpsJeUwjRjju77o1EaYGNorArpvm6myQXCtukTldLJJOGwy2W3OmgWB1SpMgMABmd0nGMi9cIJiYmjDWVkADkWiIpe+bW3oR4Rv5aYXvnY0yidMh0Lrr/tSJ1anwRFURq32fGFbpw4UnFOdnazzLSvKCo8L6Cmjaz0JxYRx9KWUEzu7LzmQItm//5oUneo4IOETFpGpCPn7qnIunZsh++RiAoXIEQl3hgCW0/GTFNh38zxcsw9DkOXL5efMxnpmtdNtp4hKjJoCoyOOsc4lcWsulRZL9yWdsFwWb0yD1YIpP+7KghxDuHtzeCjva9apkKv56j0VUTiAkwAWCoC0AYYbGmI8uTnYowWUBQFa2Eu0godQDXJ9j/J9TesjvwM4ZgVUd4BBBvRXJ1CMaEbwAmpFZUpS86C2WbH4HETD/m7SCu/DLsJn98xF7XeCIZmdzwIJFa++JqIstKeUdrRAsZghMFs1rj1+hrrNQ0OjWYLMouKcfY9D/bGKfY5ZDnNqGghk8yYfO1q+hcLhuN3H+3A5CJPl44pKSqhXiYqjYEoYhwPhqYwKs8JhqbA8QJqvRGs2NeE+xIaJ/ZH23+ANAPsbrhFUkUYuxHGDnJTJEhKAR+II7CmFoJI/GinCc55h8fMsydgzFFIuMGTBp5TJGrDIfzt7VNyEFytLXlm3CZQOp2cGbdCVCzD0hDZScgyo+NdYx2jJHZTBhqu+UXwfV0BzhuFb1kF/N9VwTohCxkXjwQf4zoM2R1pDBCVFIjGOAA0ISqCAMqRjSgnKiqM0s9hYtEOhK2rMNy6HAWm7fA5p+Gn6lmI8nYgUIfQnnIQ4UrLlE10C6KcfgIaQHJbJixY3CPfRQr9ZDrNGJnrwsjcznnPqBNKeY6T/WOk/I2jHRa7AwEVUfE21MPf3Cw/pw1962bubQxKs2J9OUkyXjRW20fpytnFGJRuw5j8riWY2vpIMu3+RnLt57osMBsYjC1wY3NlG95cV4FnviNhXIm8AIC5H14b8bogGp7dDOfxhXCdXNSp9/ARFr6vK2CbmCXnPDBpnVeB1bb4gRXVspuqc26B7uTcl5Bx4w1oeeVVZP/qHmz/Ssn3OJSQn6nQCee8QfB/pxiZGlL8nmpCYhrsQtb148C4zUnVQMZCB2izdqq3Tc6B7+sK8IE4/N+RqqXw5kb4sqzwfVOJrBvHw1x0eDzKuoP+SfuPAGLiBMVTUSB/IuDKQ0xUVNgo+d9gomH54DKc6H4GBabtAACLmygVEcEB1O9A6OAu6P3MZppMeHwwDs4bRcs7ezQtu3sSTcHu9R2RCEnljq34+9UXoHYPqdww9zH7/MOFRA8Ub2MD/M1KTFcqPR8AwR3zh2N6STrOmVyA44dpO3tTFIWTR+UgV8dXpT3YjOReCx0hc6p9DQFNc0UJD320AwBQKIZP54rl2J9tU+wJOJXNf39UVLxflkOIcvB9Vd7xziJ831YisKIaDU9vQrSMhIjNJZ2f4CgDDadIitimMDhvFLTNAPuMzrcH6S1k33EHRm5YD3NpKTiu51o8SOXDEvS6TwNaRYUyMzCXeuQSZGUfM7Jvmpj8Xqf+2OX7ugLgBbT9d183zvzwof/dTUcI8YiYBZuWBZjsgDMPUbFkmRI7DxuZ5FWebcoSAECQywAOfIcgnwZAu1+WOQAT1UaeCEDtI+sQWl+Phmc2HY6vIueoSDbnnYUU+vE3NWpCIMeKojLu5IWa576GOjRVKoP4oNHjjvQp9WkUZ9rx9g2z8MQFE3sskVQO/RwBRaXeF8H8J5ZjxsPLNNsrW0LYVUcWET+bS/KzxokJtQeblJX0vYtHyo9Tlf/3aaiIVmI3Yz2EtzchsFzxEBEiHCgjDWNe18YH17xBmpnIWOhMUgD6OoQe7EWV2IHaMVu/ikhtppfYDNJxXAFohxGZ147VVaaoDnKoErsv9zb619VwBMFFxZCMmMwHZ56cTCt5oRi4hFr1BQ/BmZcJoBx+Lgs48B1C3ALNLhOGu3Hi3Kmg2wYDGxI+9DD1XZNCP13tO5LKfbaveqj0NIZM0ZbJV+/eiViY5GCMO+kUHHfxlb1xWscU5NDPERg4N1e2yY95XpAdRiW33WnFaZg/mpQml+p0f/7Z3FJwvIA31lbgzgXDD/v59jhUExrvj2vs76MH2uD9qhxpZw2FMccOLhhH86s7kw5hyLB2OWRDGWkYMqxgG0nlVF+t9mkPeUPdKN/W3PGOnYCadGReOxamfP1xmLYYYJuSAyHKatQVgHRXdi8p6faCgY9xxGKfopJIUG+g98+gj6I8uxHPZ7+HpgKxaseZh6DoViu5yzIRVY+iwmnApMvhFG+yIJ+OoJ+Dn88GY54q72axWsCc8jvggpePyPcgnVwPTVFJ2n6MKCo0w+Dc+36P3KFk0pFISlZRMU654XbYXMeWoVdvQFJU9jcG8PCnO+US4Z6EpB5EVBUXQVU5dHWblCCs/L2LMmxQzwEnj8wGQ1O45cShWHnvSbo9jvo6+IDKy0TlWioIAhqf24rYQR+8XxBFMVVuXUdOtKmgrlTpj0Rl4slFOO78Ybjkd4fuAaYmesYUJEVC+vnDkXHZaF1C0hFJsU3OTvka74+j+vdrUPPHtSkrU48kBohKCsyZeTJmn7MY02bOJRvsWWiMkouGYkh8mt/zDXlt8pXAdV8DtnTYXCbQDCCAwUuNLwIAnDalSR1jav8mjFX5OyW7dha+MIu4GD9N72aOSiJMlv43CHcXxROn4KLf/wnqWSktv7AXz+jYglSevL8xiOe+P4AbXl3fY8fmeAHnPrsKFz+/BpE4hxaVL4ravC0olm06VOEIs4FBgYqMOPpYl9/2EC33ofH5LYjVaPtYqSsQ1Zb2ktIBKARFckxNBJPRvbHBoHqfIbvvtefoCIyRxoSTByEt99DVZkHVL+lwdo9OO0+r+lnHafPKwPIQwiwCq2sO2zl0FgNEJQWGeIZgYfFCjMoYRTbQNBpjJFFWIipxQVQohpwov4+iKdg9WjIywf61/DguKBee88Tk8ruGpzchuretJ74CAMWV1mk2dNnbQbLTTwRFH1uXDWMwwpmu3MTOjMx29h5AT8KWUCa5vcaHt3+sTLF311Dni2B9eSvWHGjBI5/uxO/EhFkA8EWUiVpqsGhLcALNdSn3eXd9YnoDjc9uRnS/Fy2vK6EbPswS63oRgoqoqbdzohtvKkXFmNU9oqJO7rSkcFY9dqAQlcNZ+UTRFCyjiAeMfXpuUm6MBKnsuTdxbM04hwA2Hoc3Sm4mWiQqrCAOVEPna/YdNVtbmplurJYfx1XysuuUwbqfFfyRVBKQDpdh8IeQSNjd/BSA9MU5+1e/7XjHYwDubKX3lDOjbzVKO5qhl5R6z3tbeuTYAdVk/HJi5+Mnf8B/VhIn5qB4/zkSEjzVBncOc/9LnuV8im12rEpbcciriYpaaWmNQGB5XUWFthmSV+WdhH16Hiwj0pB24YiUE+axAsvwdBgLHEek8intnGFIO284PGcOSfm7s00hfSO6I4j+swzoZQSamwBQABiAItKkhfYDl7wNmLXmaZMXDsb6z8rBiX9cG90GgEx0DpW9cqoYonTBNL++E5FdhM2mXzIStvFdnyDbQmQw8ti6Z8lfOnkaJi48DfUH9sKZnoniSVO6dZz+DldWNiAuQJ2ZA4rKkcLwHCcMNAU2oarCF4kfcmWNP6KvCkj4/Uc7cNXsYllRsZu0w6W6l09/UlQkCHEewR/rYJuag1i1NgykJiJqQgMBaHl3D8KblDL9rJsmAJwAQ5a126EK2mpA5tUdO9keC6CMNHJum3REPotxmmCfSuYmddKs1AySD8YhxHnEKnwwl3qOyDnpof/dXb0EXxO5MWnaDoqi4KAbMdn+PlDy56R9GQON4y8ajm9fI74jNroVZw9qxF5fJiZf8VyHnyUNDBJJAYCWN3Z1i6iERf+JxD4sXcHJ19zY7fceLfDkKCqZa0BROWKwmhiMyHVie422wu6LbXWobA3j8pmDNcpGV+DvhFK5o9YnKyr2dhSV/kJUEstoW9/bS7r0JjQUlJrhAQDn0yZTqkmKfWaexpJ9AP0X6sTd3F9OBXig6dUdiO5pRduHB5D980m91r+qf9xdfQCS0RdNE/XkjPTfwUIHAKN+TFbdpMrGeFHqiKH05peB9NSdNY25dsTrgmCbw7qvd8faOCo6PfZHA6q+hDHz5qOpshxWpxM5pUM7fsMAegzHD89KIiq/fJeEf3bX+fCvy6fqva1DqBNmU2FPvV8ujbYn5KioVcq0biqWRxKtH+xDaH1D0vbgT/VypQ1lYiDEOI2iwramrvpQVwcNoH/DNikbnD8Gc7ELFEMDDClzrn9qA+J1QcQq/L1GSgdmr05C6nMj0KRE0UBFAVtGyv2dqjI9w5XvAKf+BSg9od3PyLyGSJ9cW1QTI5bAh7qeqyIrKn2sd0N/gyszC6ffeS/mX3fLMZdM3Nu4a8FwXDlLP5/r+z1N3T5uQOceYxL63OxrCCihnwRFpUTlpXLSyNSlnn0FwdW1ukZe8dqAPN5Ikn9oYwOEOIe2Tw8guqc16T3mUjIOpjIjG0D/A0VTcM0bBHOxUoZvzLbBNoEoyKH19aneetgxoKh0Ei3VxIGRoknnVCMVA4pSE4/cUjfmnDcU7mwbUJoJlM7r8DNopxGUkSYxQR07fT4YBzxdk7kjIlGx9MPeIwMYAAAYGRq/P3Ms7GaD3FdHQjjOYUNFKyaLHY27gkA0OUflrIkFeG+D4rb6zHf7YRDJS2Iy7ewhGXjozDGYOjg9icT0NbRnecAHWfBiqIfxmME2EN8Y37eVCIlhHspiAG03gGsm7QUyLhsFzheDIaf/lRIPoGtwzM6HaZBTJiy9gb59d/UiDm5uROWOFuQOceDAT29j10rS1psykFwFJncYcOpj7R5j4vzONfaSQFEUDBkWxOtCiJUTqVtqMsXWh5LiyJ1BRAz9mI/BtvMDOLpwz6KRmD86B+c8s0qz/ZxnVqHs0SVdPp5e6OfB00fjoumDYGRonPWPlRAEyD5EiWSEoihcMau4y5/bGxBi7VdtcG2EgNA25Tv6v1HKwDOvGYOWpbvl55TFAKOt/1U6DaDrMBU6YSp0drzjYcSAhp0CdQd82Lq8Gpu//ATbl5PeHxRFgTbkQoAA7mdfAa6elz2ZdJLzEtnXBoC0QafFAaG90E+U1bcYj7CHnkw7gAH0FeS79XPC9tZ3vaGnHlFxW42YVpyOiYM8WDRGWx7qMPXfdV0q3xMJbAPJNbGKvhqJMBU4Nb2AKLp3kioHcGxigKikgMFEfpqgV+mOanG6QVEmxAFE44enMY9luAcAECsTFRWXCYy4ykmlqLy8qgxjf/sFVu5LjteHxZj0QDLtAI4GpKrwefzLPV0+VkfJtHNUHaDdVmNSMm1/AheIdbwTAEO6FXm/TraBpxjqsPUiG8AAOsLA7JUCBjH5lI0pLd+NFhKPjVNKkuqhQkpGk1xq7VNzNV0xDdk20HZRUUmxKvrth9sR5wTc/ubGpNckpaWrrrQDGEBfRGKyq4S9DV1XVLxhMnkfN5QQkvmjcjSvD8tWyjV/f8YYGJj+O1x2pKhIoMwMGKd+BVNPtvYYwAC6gv6rZR5mtNVsRzy0CpFWxQXTaLIgFgdYCHIlwKHCvaQUtsnZcg07ZaBhn54L31fEKdOYbZPLlSX76lSIJrgHVjSH8OY6EmceCP0M4GjERdMGYemPlWjwd71xWpOoMlw2czB+vWSUpooHAKYOTsO5kwsxNNuBsyYV9Mj59hbUDQfbQ7uGbQNEZQC9hP67RDjMaKvfDy6aqFCQnytOAff/dytq2sKo80bwwaZq3PLGBoRiXScvFEPBVOjUxHzt05XYuDHHJnscsC2RpPerEUsgKrctVc5/IPQzgKMFT18yCdOL07H6vpNw/xLSi8sfYeUwZ2fREiREJdNhwqg8V5LqaGBoPH7BBNw0b0jPnHgvQs/uQA+UhfwGaecOk7dl/myceJAeP60BDKBTGFBUUsDqcidtk9YTLAVsqmjDKX/9HhYjLa/Mxua7e2RQY5wmpF80AmxLBMYCh5yxzzZ3QFQ4ZSR5d30VNle2yc8Hqn4GcLTgtPH5OG08CZkKggCLkUYkzqPBH8HgjM53r20WuyVnOLrnbNufIJUfW8dlwjYxG82v7kjaxzjIKRtK2qbmwDIyHbTDqLiRDigqA+glDCyzU8CmR1TErPc4BMQ4HoEoK5MUAKhqDekei+cF2c+k058/MRuuk4pIyXImUVS4lkinV0a//u9WzfOB0M8AjkZQFIUcsYtxV8I/kTiHoKjApNv7vqvsoULqhmzIsMA6JgOUKXnoTz9/uPyYoigwTpPGMp1JtyS9ZwADOBIYICopYHN7krZJCwo2RWVeYuhFwnWv/IRZjyyDtxs+KABAO02gHSSh1vdleQd7E+R7tGWcxn6cCDiAAbQHiajUtHXezr1ZDPsYGQquftKn51AgKSqU6AUjhXbU5INxtO+Lkn7RSJiHpyHrhvGH6SwHMAB9DMxeKWBP03G6FJlKPAVRqWpVBspAlIUgCBAEAd/sakBrKI5lu7pnQUxRFJzzSFVQYFUN4o2KchPZ24q7YEHiWmdwhtYxMthDyb8DGEBfQ7F4rR9sCnb6PXVecq+m20291mjtSEJSVGgxB8U2IRt590+H62TFlJKytk/YjJlWZF0zFuaSZLV5AAM4nBggKing0CEqgkAGtEGZNpgMNN6+YZbm9b0NAQiCgLKmIKb831f4+dJN2FLlVb2/++djGaGcT/2TGyCIoaSmF7bhHJhwJoh87RNb1yeGmkyGgT/1AI5OlGaRirn9jZ0nKst2kt5dkwZ13Xq/P4KPkvFAXdXDuMwwF4tN5gz0MUHYBtA/cfRrnt2E3eNJ2paWPwZ1ZcCUoRn48ZzpcFuNeP/m2fhiWx3+9f0BNAWiqG4L46sd9YiyPD7cXIMPN9fI728Ldy/0A0Cu/AEAcALq/7YR2bdOlDelgwwyy3c34vQJ+QiKA9PisbmwGhksGqt12RzAAI4WlIplxQcaA53aPxBl8fZPpGz/1PF5h+28+hKk3DapqkeCIcOKnDsny15NAxhAX8QAUUkBk8UIxjwRXGwHpp9xBjIGDUZTdR7qympgsRngtpIbe3JRGiYXpWHV/mZsrfZiU2VbygZlDf72q3baA8XQoF0m8D4SW2cbw2h9W3HjDIk1SevLW0WiQgamq+eUYHqJvi32AAZwNKAgjeRj1fs6d3/987v9aArEUJJpT7LJP1ohiDkqtM7YZMzpfKXUAAbQGxiIB6SAwUTDaDsJFs+tmH3hFRg990Swom2+USf5bnQekVD3NwThTaGcLF1XeUjujtkJSWzh7c3yY6eoqEgExS+3ph+o9hnA0Q2pq3Gokz4qS38kaso9C0ccMyHRVIrKAAbQH3Bs3KXdgEFVvseJPiaxMBkITTpEpUhM6KtoCaUkKt5wHD+Vt3b/nDKsyL51IhwnFIJKICASUWkOxvBjWQsaxVLNxNb0AxjA0QabSSEq/kj74dVglEWT6J+i7uVztEOq+mnXeXYAA+ijGCAqKcAYaIhzP9g4ISpxcVVisiavSorSCVGp1CEqF0wtlB9XNOt7rXQWpkInPItLYMzXyrVu8WS/2dWA8/+5Wt4+QFQGcLRDrRpO/cPXiHOpLVSrxRJmt9UIl+XYyMsQWB4QrRPoAUVlAP0QvUpUiouLQVGU5t+jjz7am6ckg6IomMQBMCr6n8QkoqIz+UtEpbwlCJ9IVH57+miUPboEfz5vAk4Tk/ZSqS1dheRWK8EF/Yz9VPkyAxjA0QKLQZl8oyyPHTW+lPtKpoyFadaU+xxtkCp+AIAaUFQG0A/R64rKQw89hNraWvnfbbfd1tunJMOZQQYzbyNZhcVE+dSoo6iUZNlBUUC9L4r9YvWBlHALAC7x8V++3I1X13TOtK3dczuBqDSxNGL/nYqomI+RGPwAjl3QCR2V1x1sSblvZQu5l48loiJ5qFAmWtNTbAAD6C/o9VnM6XQiNzdX/me3950MdHcWGcx8TSJRCYuKis6qxGUxYmQuSajdVUdazquJivQ4FOPwwP+2gW1HngaAcIzDo5/t0vTrUcM6LhM5d05G9fGkaiEHNHJA4TnYsRxO3FeYiRtPGDLgjTCAYw576v0pX5NyxEbkOI/U6RxxtL6/Fy3v7pET92VX2gE1ZQD9FL1OVB599FFkZGRg0qRJeOyxx8Cy7TuoRqNR+Hw+zb/DBZdIVLwNYXBxHoFWkoRnShHnnTLYo3nusakUlYR4eDDafoXCX7/eg38u348z/7FS93WKomDMscNvMyACAVZQ+AtsGA0GDCgsqYrh3sUj2/2MAQzgaISULJsInhewYm8jAGDu8KwjeUpHDHyYRXBdHUI/1ctWBtF9hJwN5KcMoL+iV4nK7bffjqVLl+Lbb7/FDTfcgIcffhj33HNPu+955JFH4Ha75X+DBg06bOeXlkPyTvZvbMTqD/bL2/UUFQDIdlpSPncl2FMHYqkJ2cp9TXju+wOdOscoz+OA2H+9BNqBiFM1TBzAAI4VfLu7ET+VJYd/moMxtIbioChgQqHnyJ/YEYA6H6X2kXXgvFF4PysDAFCmAaIygP6JHicq9957b1KCbOK/Xbt2AQDuuusuzJs3D+PHj8eNN96Ixx9/HH//+98RjabugnrffffB6/XK/yorK3v6K8gYOjUbZrsBwbYoNn+tfI7Npd9tNbG5WbbLrHotUVFJTVQue2Ftp88xEuexH/rqTJuKXA1gAMcSzlNVvkmQDOEyHeaj1j9FSOiuHtraJD/mQwP9vgbQP9HjQctf/OIXuOqqq9rdp7S0VHf7jBkzwLIsysrKMGLECN19zGYzzGaz7ms9DZPFgMxCB6p3t8nbLv/DrJQJaU4VGUmzGWFWVSPYElYzN762Hg+fPQ4zSzM02+99b0uXegJF4xz2pSAq4W1NEDgekV0tMBY4YPAMtGkfwLELyRk6x3Vkxo/eAJ8YUuaVwYT3DyisA+if6HGikpWVhays7sV/N23aBJqmkZ2d3cNn1X1YVD0wzHYDnBmpJ3uXKnlWaj0vIc+trTI40BjERc+tQdmjS1TbArJrZmcRZXnsR4rEXAEIrKiB97ODoJ1G5P96ZpeOPYABHC2IshzWHSS5GjnOo5ewJyoqnE8hJ0K8/QT+AQygr6LX9M/Vq1fjySefxObNm3HgwAG8/vrruPPOO3HZZZchTadzcW/B4lDCPI40S7tVNE5V6CeRqIzOd+Gfl01p14CtJai/4nniqz262wFCVPamUFQAwP9DFQCA98cPyb5/AAPoy5hc5EnaFlLlgT34v+3453ISCs12Hb1EJVFR4XxKGN154uHL5xvAAA4neo2omM1mLF26FCeccALGjBmDP/7xj7jzzjvx3HPP9dYp6cJiV4iFw9O+ZKzOQ8n3JPs0LBqbiymDtSQsyioDS0DMWynOsGHHQwvl7X9btjflZ0ZZDkEA701w4WIE8Dli4GkKxgIHAIAPKAZzXGvq3J8BDKA/44Urp+HvF0/Cp7fPlbc1qCbpt35SlMqCozgEKjUflMB5xcUPQ8E1f3AvnNEABnDo6LXC+smTJ2PNmjW99fGdhlWlqNjT2icqakWlJNOmu4+R0XLDypYwhmYTUiE1Vct2WuT+JakgCAIEAYiKcm7IY0IlePwBEQy/dAwmVIQRr9a2vY9VB8A4TaCMR2ci4QCOXaTZTTh9Qj4AoCTTjoNNQVS2hlCcaUeM5cHQFDhewKnjcnHR9KJePtvDBz4x9OMlZM08xAOKGfBUGkD/xIADUAfokqKiylEp8OgTlVpvWPO8pk0hKoF2Oh4LgiCHnQRBwAX/Wv3/7d15eFvllT/w71202vIir3FiZ99jEichIQllGTIkLYVS2rCHddItHQjQDPDrJEyhEKAtU2AyUDqFwBQKTFtoWVpIKaS0EAIxJoTsu7PYjlfZkrXde39/XN0rXe2yJF9JPp/n8fNEsmS/im3p6LznPQdun4jGMaUAABPP4ZZ/moSDp51YMr0Gbld7xNfofn43AKByVSNM40vh7xoEX2mhpnCkoMxtKMfhTif+0HISX5pchdYeFwRRgtXIYePVcwv69z1i6ycQqLA0RZ3kMQpUErCEHEWOV0gLaAcA1pZGv234UEJHyLRXVyBQsUapY3F6BfXr97h8+PiIXBio9Gcx8SxWnz9JvT0fZ62dv/wcpinl8OzrQfk3p6Bofk3cx0VIPrm0qQ6/az6OjwO9VJS/ubEVRQUdpAAhxbQsABFAoCyNeqiQfEZ7AAmMmVqOxvPGYO6yBkycG/80EscyWPWl8bhkdh2a6sui3uaGJeM0lx2DwVStM7D1UxzY9rlmYTBFPRB4ApIkCf/6m2b1+q5AU7fwmT58ZfxZJp59cqDT92ZyjeUIyRdKIbsyHFR5M1BmKfxpyYJTfqy8Xfv3zxjpqZ7kL8qoJMByLM65ckrSt//hRTPifv77/zQJ88fZ8fLHrXjj81OaacpONaMiv/v58aWz8PvmExj0CTh4egBun4DeQR/+caBLvc/pfjm1azZo3zGxNiNYmwFif/xpzdQEihQapai9x+VDj9Or1n6NhEni/g45e2Sst8HfGdxmpq0fks8ozB5mJp7DuVOq1K61yru9ltZe/O+H8lTlokBGhWEYVNrkradr/ucjfOWx99ETdoS5yxk9o8IwDGpunYuKa6ej/LLJ2XtAhOSY0KL2pvs2q28AotV+FRJJlODrkIMT41jt0EXa+iH5jAIVnSjv+j461IXfbj+OSzf+A/3qE2rwidZmCqarXV45sxKNyRD5RMQVG2GZVQm2OPg1zFNzp0cNIdkQ3gX6aKBGpdAzKv5uN+AXAZ6BYVSx5nNMgQdppLAV9l9uDlNOCDUf60XzsV7N50Lf+RWHzQ9q7dYW46r3ifOOyTylHMZxJTCOsUGMMqhQ9Apg6R0XKRDhBbOfn+gDgLjNFguBZ79cd2YcYwNr1T5W1lr49TmkcFFGRSelcQr7PCGtrsMHHbb2DIbfHEDsU0YAwPAsqr8zG2VfnQAmyuTn0KZwhBSanYFAJTzTUmjc+3sBAOZpdrBh09rDLxOSTyhQ0Ul4AMIywLlTqmA2sLhgevB00eiwDrexMiq1SbYFj1ZUJ9CwMlLA/IHBfIWeUfG3OwHIhbRs+POLtbAfOyls9NurkzHlwYZwZ02w44V/OQsMAwiiBD6ke+0ZY8oAHFUvHw9kVIw8C68/mHmxFwX7vcQTLaNy+onPYL9mOqyNlSk+CkLyR6Juz/lMEkT4e+Tp0IYqCxieBXhWrlkBwNHWD8ljlFHRyYy6Elw+fwwAYO2yqWBZBgzDaIIUAJgTNmxt0Ccftbx+kXZuR7KNrFhz9PS30rWWkEIwO9CxOVQhn/px7+8FRLlfChs4KciG9E6hjArJZxSo6Oihb5yBHf9xIeaNtce8zcSqYjx65RzMrCvRXF9mNeLer80EAEypKY5216iiZVQUkiSh/x8n4HivNeZtCMkHm25cgCevnQeeDQbwRQWaUZFECd0v7gUA8FXW4KgNMTgtPd7fPSG5jn57dcQwjGbicixfmzMabp+AO3/3uXqdzcxj5VljUW0zYdboyHePsfBxBisOvH8CfW8eBgBY51SBL+Aps6SwlRcZsXxWLerKLDjWXdjHkz0HetXW+aVfGa9ez3CM0kEfDFvYowNIYaOMSp646Iw6zeWqYhMYhsHyWaM09S6J8FXB25oml6Hqu7PVy47NwVoY6lhLCkG1LRiYV9niDxXNV742uYjWMrsK5oll6vX2K6YBHAPT5LLodyQkT1CgkieKTTweuTwYVJw3Nf7coVi4omAGhzXzMI0tQeVNswAAUsixaMlNgQrJf6GlWw325AP6fCIG5oWF16GYp5Rj1F0LUHn9TD2WRUjGFGYutEBdPLsOe9v6sWhiBSxp9IQwTSqD50AvihfLWRrTxDL5fHTInrYy3IyQfOb0COq/jXxhvC8Lb9AoBt5UROuVwtmSOw1ISC6jQCWPGDgWd39letpfp+La6RD6PDDUFAGQ97K5MhOEbrd6G9FJGRWS/+rtFuw65RjSfSVRQt/rh8CWGGE7e7R85Fdng190out/d8M8zY7yb0wGZzNCUjIqVDBLChT9Zo9ArJmPeFLj7eawQIUyKiT/rb94JnyChJuWjE984zC+kwMY+OAkAMC9uxtV3z5D16LUwV1d6PpfuY2Ae083et84hIorp0F0y1kj6j5LCpX+bxFITig6s0ZzmQIVUghGl1nw9A1n4uzJqTczDO3Y7D3qwODOzoT38bU7MfDhSc3R4EzwHHOg67ldmusGW04DCNao0BFkUqjoN5sAACyzqgDsVS8LLgpUyMgWPgOr99UDME8qizvgr2PjZ5C8AiS/BNuXRmdmHW4/Tv/3ZxHXK1PR1RqVGM0cCcl3lFEhAOQ6lVC+U04M7urK+DtDQnKVMOCFq6UDgsOjXgYA84wKAPKRfe9JZ9yvIXnlbRhXS0fG1tX/t+NRrxedPkiCpJ7Qo60fUqgoUCFR+dtd6HpuFwY/P633UggZFn2vH0L3i3vR/mgzfB0uuJrlYMNQZYF5ajkAwN8tz9oS+r3wd7shunxw7++BJGkDet+JAYgZykoKjihDQ1kAEiA6vRAHAzUqtPVDChT9ZpO43Pt7YZ09tJ4thOQT74kBAPKJt/ZHtqvXs8VGcHY5GOj9/QFYZlSg/efNmjou08RSVKycofl6zo/bYTt3TNrrighUOAZckQGCwwt/r0fN4jCUUSEFijIqJC56l0ZGAkmSomcuAHAlBs04id7XDkUUm3sO9qFz0xdAyBaq79RARtbm73BpLjMcq9an+LuCJ/WoRoUUKgpUiMrSGHkygp78yEggDfohhTSHU5in22GeaodlZoV63eBn0bdDvUccgBDcAvK1u6LeLhWixw+h16O5zthgAxOYBC30yZ9jDCwYjp7OSWGi32yiKl8xBRU3hrXbpmFmZATw93giritZPg6V188Ea+bBV1pQc8e8iNuwJbE7v/o6XDGzNEmvq0OuiWFtBlR9bzYss6tQ/s0pamdaNVChbR9SwChQISrWyMEy1a65LnT+DyGFSuhxR1zHlWqHGEacquEYVH9nNiLwDIwNNkCQMPDBibTW5Qts+xiqrTA1lKDiqmngy0xqRsX54Sl5bbRFSwoYBSokQuioeKVQj5BC5g9sr3D2YC0KF5YtCQ8GWBMH3m5GzQ/mwziuRHN90YJRAADP0aG171cogQpfrR2oyBi0W7J0NJkUMgpUSATbOWPUPXnKqJCRQMmomEICDq5Y29gtfNaP6JL7lxgqLXIGRbmdiYexvhiAfEw5nV5E/kCdi6FGG6iwJm2gwpiolowULgpUSFTKO0TKqJCRQKlRMdbbYKi1gis3ga+wxL1P6VcnqP/mQ7aJWCMHvsoKxsBC8opRt5WSFbr1E4oJm56eqZ4thOQiyheSqJTUskgZFTICKMEEV25G9b82AVJkBiVc8ZI69d+hRbWMiQPDMmCLDBB6PRCcvoRBTzSiV1DXFbH1Y9KuLfwIMyGFJGsZlfvvvx+LFy+G1WpFWVlZ1NscO3YMF110EaxWK6qrq7F27Vr4/f5sLYmkgDHIvxq09UNGAmWwH2vlwXBswiCFrzCDYYIn4jhbMFBRalvYosAsHtfQntP8pwcBCWCLeHDF2nqZ8BqV4kV1IKRQZS1Q8Xq9WLFiBb773e9G/bwgCLjooovg9XrxwQcf4Nlnn8WmTZuwfv36bC2JpIAxBgIVrwD3wV44m9t1XhEh2SO6U2tDz1dpMxxcSXDrR8mesFb5a/W+egDufT0prylWIW04+1XTULK0IeWvT0i+yFqg8qMf/Qi33XYbGhsbo37+7bffxq5du/DrX/8ac+bMwZe//GXcd9992LhxI7ze9HoPkPQp79gkn4jOX36Onpf3wdcWfyAbIflIkiRInuQmEFf+yyyYZ1Sg7OuTNNeHZlSUrrHKSRyh14POp3emvC5/4O8tvD5F/qLBAl3LGZURGRZCColuxbQffvghGhsbUVNTo163bNkyOBwOfPHFFzHv5/F44HA4NB8k89Stn5Bunf7eyKZYhOQ7ySsCgdd9JkFGxTypHJXXzdAUzwLBvxcgZOvHqj01lOrpH/fBXgCAcYwt4nOSGNySDd2CIqQQ6RaotLW1aYIUAOrltra2mPfbsGEDSktL1Y/6+vqsrnOkUjpf+jsHdV4JIdkluQM1JKw24EhV6VfGw3JGpXq0X9n6UQiO5AN9wemD77g8K8g8zR7x+ahZFkIKVEp/lXfddRcYhon7sWfPnmytFQBw9913o6+vT/1obW3N6vcbqZhoKXB//hXW+jpc6Pr1rhG5beXvGoTnUK/ey8h5YiBQYUx8WtkJ2zljUHH1dHXmjuTWHu0XupM/pqy03meLDZptJYV5RgXKLpmI6tVzhrxeQvJFSseT77jjDtxwww1xbzNhwoS4n1fU1tZi27Ztmuva29vVz8ViMplgMplifp5kRrROl2KUoW25ruu5XfB3DsJz2IG6dWfpvZxh1faTTwAA1f/aBOPoYp1Xk7uU3+tMD+C0NlVj4IOT6mXHX46hYmVxUl1klf5FsRq5MQyD4sV00oeMDCkFKlVVVaiqqsrIN160aBHuv/9+dHR0oLq6GgCwefNmlJSUYMaMGRn5HmToWIsh4jo1RZ4HBj46hd4/HlSLDkWnD4O7umCaUDri5qJ4DvZSoBKHlOKJn2QZ622o+49F8Lb2o+u5XfAc6sPAP06gZOnYxGtSgicjFckSkrUalWPHjqGlpQXHjh2DIAhoaWlBS0sLBgbkfdcLL7wQM2bMwMqVK/HZZ5/hrbfewr//+79j9erVlDHJAQwXmQLv/9txnPzxVvjac38bpfeVA5qTEYCcXel8dpdOKxpeUshjH85MmOgR4PrsNERP/gS1wa2fzAcFrJmHeXI5Si4cBwDwnUrub0f5mVFrfEKyGKisX78eTU1NuOeeezAwMICmpiY0NTXhk0/kdDTHcXj99dfBcRwWLVqEa6+9Ftdddx3uvffebC2JpEno80Ic8GHgHycT3zhHeQ/3QRgo7OPvolfA4I7T6uXhzIT1vnYQ3b/Zg57/2zds3zNd2cqohDJUy71VfEkWp6tbP5RRISR7LfQ3bdqETZs2xb3N2LFj8eabb2ZrCSRLuLLcznhJkgSwDBDjOGjHxhaMunPBMK8qu/y9HvS/14qieTXof/84Bnd0qp8T+ocvMHN9IteZDe7sGrbvmS4lo5LpGpVQfKUcqPi7BiGJEhg2ftGuEqiEDx8kZCQaWZv1ZESQ3ELMIAUAhB4PJJ+Y1lHUXCIJEtp/9gkknzwAz71X2wVVGbg3HBgLD2kwf7Z9gJCtnyxmVLhyMxijPKTQc7gP5ollMW87uLMTvX84KK+JMiqEUKBCYmPMXMQRSwCQcvyYsuiMPknWftU09PxuHySvCH+vG4YqK/xdg3A2d8C2pC6iQVcu6/vTYfT//UREHU54kAIAvtZ+uA/0wDypPOvr4ooN8OdZoCJl6dRPKIZlYJ1bA+fWU3Btb48ZqEiShK5f71YvU0aFEB0bvpHcV/29OSg+ezRqfzAfljlVMNQWAdAWauYiIUqgUrt2Pqyzq8CVm+XbBLIM7Y9/iv53jsHxzrFhXWO6+rccjwhS4un8n53ylliWKYP4gOCgv1ynzPnJZkYFAMxT5EAxXkGtc5u22SVlVAihQIXEYai2ouyrE8BXWlBx5bRgh8wcn6gsBoplDWOKUbJ8HMq+PkkdFMcHApXOp3di4KNTasbIe2IArh2n0f3iHoje3O4Xk2ordkXv7w9oghVfhwve4/2ZWlYEoS8/Ri4oxcbZzl4YauRusr7Trpg/w95XDmguMyZ6iiaEtn5I0hheLgCUhNwMVJzN7eDtZgxsPQUAMNQUoeQ87YgF3m5W/x36ouA94kD3EXlulHFsCYoX5V4zLcHhBVvEQxpioOj8uA3Oj9sw+r7FAM/i9FM7IA74YJ1bjeKzR8NYl36vFSkkyBMcXjULl8uCDd+y+3TIlZvBGFhIPhH+zsGINvjRtlQpo0IIZVRIKvhAa3B/7m39eE8OoOflfTj95A54DvQCAGznR86BMk1NXKcheXMvEBvc041TD25Dz2/3a/qiFC0IdnEuWdoQcb+ySyfBNKlMc533lBOiyw9xQN4iczV3qMWb6Qr9v1PawA835/Z2dD63K+ppp+6X96L7pb2azFKwmDa7QQHDMjA2yAMGXdvbIz4frd8N1agQQoEKSYE6wyQHi2mFnpA5KhLAlhhhCBwJDWWeWJawiVa2X7CGomvTF4AowfVpR3Crwsqj+Jwx6m2KFo6CpbFScz9jvQ3ll02GdW61ep3Q7Ubfnw5rbufvyszwSU1GpV+frZ+e3+6De1cXTv9ih+Z6wemDq7kDrk87NAXXah8VU/YTzNa58uDV/i3H4T3eD3/XIE49/DG6fr0Lnf/zecTtw4NMQkYi2vohSWMMgd4PORiohBduxmoZz/AsRt15Jk7euzXm1xrq1kq2hBfBhhZ/8hVmFJ01CqyZB2czonzFFNjOq4erpQOST4ShrggMw8B++VSAYeDa3o7uF/eqX0s52SUO+CD5RTB8eu9dxLCtH10E/rv8nYPwHu+HcYycxQj9HZE8AhD4FRmujAoQmP/z/gn42pxw7+2B77QLQrcbgyEDCxkDC2ODDeZpFeDLzHG+GiEjAwUqJGm5nFHxBOpLFJYzYs+kYq0G2K+ehu4Xok/6lnJs+KKmLwkbcpzWxIFhGJRfOin4aSMH4+jiqIEaXx7ZqI9hGUgcAwgSBIdXU8MzpLWGbP2IOgUqrM0IMbDt420NCVRCsihKsCeJkhrAJDMsMF0My8A6rxp9bxyG7+QAom2isiVGVK06I+trISRf0NYPSZrybtt7cgBtP/0Erk87dF6RzNfmVDuiAkDRwlpYZ8cfnmmNE8jk2qkfTWZCBESX/IKbagaAr7JGXCe6/OBKTYHvk95WjeQXNY329DqeHJoRC92K0gYq8tqEHrd8zJtn1P+HbDMEipa9JweiZrBoECEhWpRRIckLPKmK/T6I/T50v7QX1qbqBHfKPs+RPvXfJUsbkppOCwCVN86Ea0cnWCsPrtgA0Sui/51jOZdR8fdqAwhfhwtA6qdUzJPLIq+bUQHR5YPQ7Ya/xwPTuKGuUhsUAHoGKiHBSWiGxxValyKvTfm/NFRaE7a1zxTlmLLQ4wHqopz0oQJaQjQoUCFJU44n55yQ/Hl4MWk85ql2mKfa1cv9W47LXy7HApXwfiT9f20FkPqJENZqgHVuNVwtHSi/bDKEAR+K5teg/6+t8B5xwHvMgaI0Ak/RrX+gIomSphGeJmhxBtcjDsrX+zvkImK+OrLwOlu4YiNYKw/R5Ye3NbKPTXjAR8hIR4EKSVq6hZbZorwgWpuqYagZet8O5Z1srm39+APv+sOxNmPKX6v8m1NQftlkzc/SOK4E+OAkvGF1PqkK3zrSJVAJK4TWHJd2abd+PEcd6umn8J4m2cZXW+E94ohacCz0FfZ0b0JSlZuvPCQnKcW0CtaaG3GuWgxpS29WjxKo5FpGxdcePVCJdbIpHoZlIgJO07hS+fu0OdXajaFQMj984EVf8ghD7qKbKtfnnWh75BN4j2mDLSU7IQmSprhXcvtx+qng8WV+mAMVZftHUbw42GAw1qwqQkYqClRI0sKnDYsuP1wtw1dQK4kS/F2Dkcd1M3RqQylizLXUe8xAJXCaJV1ciRFchRmQAPe+yKGGiUiSBF+bE/7T8jaKoTb4IjxcWZXu53fD3zGInrAW9ILDi/bHmnHqga2abRZx0K/ZIhrujIoSHCqKFtbCPLMCAHKi7ouQXJIbb4lJfuAia1S6X9wL65zheWId/Ow0ul/aC77SgqpVjeopDeX4brot0JW5KuG1FnqSfIJ61LZk2Ti4mtvBGDnwFWY5uMgQ88QyOLva0Penw7DMqkypsNS9uxtdz+1SL/PlZjBGDpJXkH82RcM4lTrs6LzSpRjQ1qh4TwxobqfMghouxgnaQIW1GmBfMQWDM7tgmVExrGshJNdRoEKSFqtGpev53Sj9ynh14F+29L11BIDcyKv3jUOouHo6gMxlVLhiueZDGWqYC4T+wDYAz8J23hiURBkLkAklFzTAua0NQo8Hosun/l8kwxnWDp4rM4G18BC8AgSXDzyyGwSETstmbcakGs2F1uNU3DgzIluYbVyJ9v+XtfJgOBZFgc61hJAg2vohyZOi1xsMft6Jzk1fZKUeIXSbxzAqWCg7uKMTgHzk1HtcfnfMpBuoBDI0osufMwW1yrwazmYAw2Tv1BVXagoWE6eYUeKs2owJV2oCXykHrb6wzEWmiS4f2h7+WL2cqBkhW8Rrfo8ss6tgCTn5NVzCf5bh9V+EkCD66yBJ4yst4GusMEcZ7Odvd2W8HsFzpA+n7tsK5ydtmh4YgPyOVJIktD/+qVpTkkoWIBrGzKnTanVr/x5GVAOV9B5bMthAAzkpxYLa8KJqrtQE08QyANqtl2zof/+Epvg5USEqY+Y12Qz7FVOztjZCSGZQoEKSxnAsatbMReWNs9Tr7FdPU9PmmT4t0/Pb/RBdfvT8dj9O/vgjuHd3q58TXD6I/V65aRbk+g1D3dCPJgPyu1yuVH4RE3r1GagXTsmoDOUocqqYQI1Pqid/mLBOqlypEabxcg2G97hcwCq6fOjc9AVcO05nYKVB4ROS1UAlRvKJNXIoWTYO4BgUL6kbtiZvhJChoxoVkhIlZV25qhHe1n5YGivR+8eDkHxi1DH16dBkaMK3lfwSvCedAOSaiEzVbnBlJvhPD0Y0WdOLMKwZlUCgMpjazzF8u4UtMsAQqGcS+rwQnD443joC955uuPd0xx1fkLLwrZ7ArwlrNUTNrjAmDsa6Yoz+j0WA3tstDBB12A8hRIMyKmRIzBPLUHJePRiGCek/kuGjqFFOGYVSembwGTz9wgZOqIgufdq/h1O2oHJ56yc0UOGrLWAYRp7mHPi5DO7sjHnEOl3+nugBJRejp47SzZcxcLpnU4bjZ0pIIaBAhaRN7T+SoYyKe283/H0eQIhRGBl4gVFObvD2zJ0qYYe4/ZEtw1mjEtz6GVpGhTGyKL9ssnq9sv3Tv+V4RI3RUEmShP73j6s9Ufzdcu+W8Lop49iSqPfPpTk6FStngK+0oOK6GXovhZCcRls/JG3qaZEMBCqDu7rQ9dwuMGY+5jt741gbvIcd8ByShxHytZlr1sXmWHdaYUB+gWdLhi+jkmqQprStt53foGlkZjuvHq5P2iH0edLucaNw7+5G3xty2/u6exdDDBzfLlo4Cu69wWZ1htro9Urh9TR6MtbbUPuD+Xovg5CcRxkVkrZMvri7muWeHGqQwgC2c8eoR0pZKw/zFO1xUvPkyFNIQzXUgtJscH3WoR7v5Yqz3zRNeexD3foJ77PDKaeBBCljbeFDv073C3vk72vm1aPlilgN3FId5EgI0R9lVEjaMplRURucBRhGFaH0y+NhO3cM3Ad7YZpQBqHXA0eg+Zuhtgh8VSa3foZWp5FposeP7t/sVS+HNwjLBqVhXspBRSCjwhjCeoPEyF5IkjTknjCh9TDuPfIpML7CDMaoDZJ4uxmMhVe7FqtrokCFkLxDgQpJG2sKvBMfYqAiiRIcfzkKzmaMOG6qdLtlrQZYG+XTIlyRAfarpsK9rxclSxsy2ghNzajovPXjPdqvucwWZT9QMVTJW2ipFr7GyqjEnLbtlwDD0H5m0Xr18HazNihiAa7chKpVjXD85Ri4UiOcH56SP2WipzxC8g391ZK0qRmVIXRzlUQJru3t6P9ra+CLaT/P2aOf6LHOroZ1duZnDClbA579vZB8AhiDPu/AQ49HFy2oBZPgBFQmKMMEfaeccB/shTnQtC2RWIFKLKJXADfElvXRTmNxdrOaCQMCJ3o4Fsa6YlReNwPO7e1wQg5UKKNCSP6hGhWSNiaNGpXeVw+g53f7g1eE9ZXgK4d3WFzoC55DCZ6ySHB40fP7/fCeDLaalwQJfZuPAgCs82s0J2myiSsPZiY6n/ki6TodyS//0BIGKoFgK53p1LEyKqGZkvDfQ4YPBnnscA5IJIRkBAUqJG1KC/VUh/lJkgTntra4tzFUZ+5ETzKYkNMpAx+czPr3635xD5zb2tD17Bfqdf1bWiEq/VPCikSziWEZlH8jEBT5RXgO9iV1P8knBwaJBvupRddZCFQAwNokZ9isc7WZNr4y8DvEM7BMG/65PoSQ9NDWD0mbUkcSq/lWLEISt+eHOVAJPRUSOrwuW5Qj1kJfMMgLDd446/D+iVpnV8HV3A733h6Ig8kV1aoFrgkyKoyRA1x+SN74gwPjUQIV47iSkD468u9f+TcmwzS5DJbpFZr7GEcXo+o7Z8BQbR32KcmEkPRRoELSxgUCFaHHndL9fCfjT9blK8zghjlVz9nNYMwcJLcAZHFaMYCY06ZDswaMDsWfTIqt9NWtnwRBgLKtlE6hsnIiyXbuGAwYTwGSpP7+MTyLork1Ue8X2t+FEJJfsvb24v7778fixYthtVpRVlYW9TYMw0R8vPjii9laEskS3i5vT4guP/y97ojZL7H4AwWjlsbKiM9V39KEylWNmVtkkhiGQcU10wEAUpIZhaFSesYAAFsUODkliGqNhWVWBaxNGZyLkyT1mHKyNSq+5IpplSPEQ9366Xn1APwdLoAJZElumoWqmxt1b4VPCMmurL1d83q9WLFiBRYtWoRf/epXMW/3zDPPYPny5erlWEENyV2siQdr5SG6/Gh78GMA8gkS+5XTYnYIBQBR6bpabIB5ml3ti2GeZoexrjj7C49BeaEWsjzvx/VZcJKw6BEhSZKmfb39qunDctonXKq9ZJKuUVECoCh1Jgm/hxisZzI2lIArGb7aHUKIvrIWqPzoRz8CAGzatCnu7crKylBbW5utZZBhwhYbNEdHfW0udD63C7V3zAMTY0qtksbnio0ouaABruYOcOVmmKdkrtPsULBWZTChL63mZPFIgqjWWAAA/CIkr6g2KGNMnC5BChC69ZM4oJBESd4mA+K2yWfMXMj/6xACFa+gTtC2Xzk15fsTQvKX7pVlq1evRmVlJRYsWICnn34akhR/7rnH44HD4dB8EP0pL0KhhG43/KcHo95ecPrUxmJssQFcsRG2c8bA2lipe5tzZRsGfknd1sgkf58HJ374D0g+EYyFBwLHZ0WnT91uydRsnKEIbv0ICf8eQ7Muyv1CVVw3A1yZCZU3zAyeDhvCgEI1uOEZtXibEDIy6Bqo3HvvvXj55ZexefNmfOMb38D3vvc9PP7443Hvs2HDBpSWlqof9fX1w7RaEk9o0avtn+rV0zq+NmfEbSW/iI7HmuE96oi4by5gjJza8yNTM2pCDbx/Qv23cVQRuEDXWdHpU7MYrEW/YE3Z+nHv6sKpH2/F4M7OmLdV1ssY2Kg1KpYZFRh11wKYxpWqway/243+vx2P6EIcT/D/Jbd+Vwgh2ZdSoHLXXXdFLYAN/dizZ0/SX2/dunVYsmQJmpqacOedd+Lf/u3f8JOf/CTufe6++2709fWpH62t2W/KRRILbaTFV1jA2eTL3S/uhb9Lm1Vx7+vRHMdlbdlvD58KhmHAFctrSuXFNFma01Eco2ZwhJCMCqNnRiXke4tOP7p+vTvmbYMBROL1KhmVwc9Oo+/Nw3C8cyyp9Uh+ER2Pf5r09yGEFJaU/urvuOMO3HDDDXFvM2HChCEvZuHChbjvvvvg8XhgMkUvljOZTDE/R/QTuvVjqLGCLQ4GH87mDpT+81gA8rvprud2ae7LV+ReKp8tMULo86iN1zJF8glwHwo2UgstIhYHfGodhp4vyKl0b1UzKkmslwvbHnR92oHySyclvJ/naMj2boKtKEJI4Unp2bCqqgpVVdk7LtnS0oLy8nIKRPKQcvIDkCcaQwjWdjA8C0mU4DnQC39Yr5XKm2ep2YtcwtmM8AEQUuy2m8jg7m61YLbsskkoml8LX6s8gDB0m0nPGhVDXTEMo4rgOxW5bRduKBmVUMkUK4euQxjI7pFxQkjuydqz4bFjx9Dd3Y1jx45BEAS0tLQAACZNmoTi4mK89tpraG9vx1lnnQWz2YzNmzfjgQcewA9+8INsLYlkUWhGheFZCM5gkaXY74Vre7t2pg/kOT7myfqe8IlF2boSMpxR8XfIBcRFZ9aieMEoAMEMhuPdVnUwoJ4ZFYZlYJ1Xg77XDyW8bUqBSlhAKnkECA4v+ARjAkIbA0pDONpMCMlvWXs2XL9+PZ599ln1clNTEwDg3XffxXnnnQeDwYCNGzfitttugyRJmDRpEh555BGsWrUqW0siWVS8pA7+rkFYZ8sZN9u5Y9B1WN7i8Pd54DnUG3Ef0ZO7LzpcoG5G7I98By+JEiSPMKRgwt8rN7njyoIvzmyxHKhIbr96ZDmZrZRs4sOmVseaJJ1KoGKoiRyHILr8QIJAJTTTZBxXkvD7EEIKS9aeDTdt2hS3h8ry5cs1jd5IfmPNPOyXB/tbWKbZUfqV8eh78zD8Ha6odQ/KELlcpLz7F6Kc+un/23E4/nwElTfPSjkjJAS68YYOG4z2f6Pn1g8QOQxRcHjBV0ROslaOGkc7nh4u2qkgKYlgVWm5z1daqIcKISOQ7n1USOGyzqsBY2DhPz2obW4GwDSpDCVLx+q0ssTU7qxR2r07/nwEAND9m+RPuCmEKBkV4xhb5PfX8XgyINcZGcYEuwOHntIKFdpdOBlV3zoD5qnlanCWzNwfpaFc2dcmgi/LvcJrQkh2UaBCsoYrMqDoTG3X4erVc1B60QRU3jQLrFHfF+N4Yg3QC22AluwsnFBqRiU0UKkrxqj/t0DTgl7vjArDMahePQfG8fJWi9AffdK1UtzKJRmomCaUovLGWWqfHcmdOFBRtggZnRsBEkL0QYEKyaqiBcFAhS3iYay3wfal0Tk/SE55UQyfd6NkRAAAYuwJyNGIXgGSVz4NpRTrKrgSk/riDehfowIE+skEZurEzKg4lYxKaie3lO7DUhIZFTGJFv2EkMJFgQrJqtChhMqLdD6I9UKqtP1XpDIJWC0K5Rg1YxOKC2l8lyuNzbjSQK1OjNNPYuD4dqrdhZVAMNHWjyRJah2Lsh1HCBlZKFAhWVf6VbkJYFkSzb1yRawX0vCRAMlkBBRq9qHIELV3SGjju1zJHnC2QEbFEbn1I0mSuvWTbI2KQq0BSlBMK/lEIBDfMqbc+D8hhAwv+ssnWVe8pA6WWZXqu/N8wAZeFCWvoGlK5g8LVESPgGTf5yuD9WJlH4wNJcA/TsrfP0cyKuogwSi1JJJHAIRAJ91sZVSUzzMAY6T3VYSMRPSXT7KOYRjwZaaEHUhzCaNsM0jQTFD2dWrnFqWy9SOEZFSiMU0oDX7/HCkcVbaoomWOlGwKY+RSLoxm43zdUOrsIxOXV78/hJDMyY23bYTkGMbAAgzkQMUjAMopoLAGcPEyApIkwX96EHylBQzLBLd+orSSB+QalepbmsBwTM4UG6u1OlECMqU+JdVtHyA4dDFRRkXJQrG07UPIiEV//YREwTBywavkESC6/eBsxkBNRuDF2WaA2O+LmxFwNXeg5//2RVwfb5vEWFcc83N6ULZbogcqqR1NDqVsbflODsD1aQcss6uiBmfuXV0AAMPo3Pp/IYQMH9r6ISQGJZvQ/rPt6Hxul9zzI1CTwdvlLq3RAhVvaz9Et18zoyaUct98oPaTiXJiK9FWVjym8fI2l//0ILpf2ovBHaej3s59oBcAYJ1dmfL3IIQUBgpUCImBD5lN497VpU5+ZkxczM6qg3u70bGxBaef3KF+jgk7VmuoK0K+UAOVfi86/rsF/t7g9OtgRiX1ImneboZhVPD/wROYCxVO6VvDV0XOCSKEjAwUqBASQ+X1MzWXvUflMQCczRizz4rrk3YA8jFmpb6idPl4zW2Mo/InUGFDTtp4j/XDvadbvRxs9pZ6RgUA7FdNU//tOeLAYGCbRyH5RfV7hM8eIoSMHBSoEBIDw7OouX2eern3DwcByC/MyqkcpSW+IrSWQ50sbOVhnlkBAOCrLUkN8MsV4Y3pel89iNNP74QkSsGtnxjFwYkYqq0ou1juseNvd6HruV0Y3C0HK/5eNzo2tgQWMfTvQQjJfxSoEBKHodoaMeXZWG9Ti0EHPjgJf7cbg7u6IImSppZDGgy8kFt4lC4bh5KlDaj+zuzhW3wGROug69nXA3+HSx0vkE7PF9am3TZyftwOSZLQ8/I++E4FetZIoKPJhIxg9DaFkATCX0zNk8rA2c3of7cVAND1wm74jkcWzgpOJaNigKHaCkMOT4uOJdYxadHly8gMnvD6FsHhQd+fj8BzKFizMpRiXUJI4aBAhZAEwrcdTJPLwbAMuAozhC531CAFkAtQgdzpMptJgsOrbm2lM0CRDRvO6Ds+oP5/Fp8zGgzDqNtmhJCRqfCeQQnJsNBApezrk9QsA2vmkUxf2kKsrxAcXrVrbFoZFVv0E0N8pQWly8fnTOM7Qoh+qEaFkARYS/Bdf2jQEZ4pCW2BrzBPt+fMgMFMEvo8kAbTn2oca1SAZVYFBSmEEAAUqBCSkDY4CQlawl6gy742EZX/0qi5zr5iSnYXNwyqvnWGPFIghNDnUWcgpbO1FatItujM2iF/TUJIYSm8t3qEZFjoC3Hov5mwTAlr5mGoCfZIYUxcXh1FjsU0oRQl/zwWfW8eVq8THF7130ym5vCwQN09i2iuDyFEgzIqhCQQGmxogpaw2hMlcDHUysGK7fz6YVjd8GD4sIxKIFBhTBwYLkNbNCINHySERKJnBUIS0Gz9hP47NKPCBgf4VVw3A95WByxnVA3bGrMtIlAJNLrLZP1NIRYdE0LSR88MhCTAGjlUXDcDECXNC7NmG8jEq/UWvN0M3m4e9nVmVViNCuTZjGCL0n8KqVzViL7XD6Hs0klpfy1CSOGhQIWQJFhmRPbyMIwuVv8dXmxaaGJt72SiGZt5YhnMt85N++sQQgpTYT+7EpJFpoYSGMeXAAAMNYU93Td860dRCMXChJDcRhkVQtJQdXMjPIf71ALaQhUrY8RRe3tCSJZRoEJIGhiehXlyud7LyDqGi5VRoacQQkh20dYPISShWBkVGhhICMk2ClQIIYnFaJVCNSqEkGyjQIUQkpDSLh8AwAejFs5GgQohJLsoUCGEJGSst4GvtMhDFkMGCbIxph8TQkimUKBCCEmI4VnU3D4PFdfN0BTWchSoEEKyLGuBypEjR3DzzTdj/PjxsFgsmDhxIu655x54vV7N7Xbs2IEvfelLMJvNqK+vx8MPP5ytJRFC0sCwDBiG0Zz0YUxcnHsQQkj6sna2cM+ePRBFEb/4xS8wadIk7Ny5E6tWrYLT6cRPf/pTAIDD4cCFF16IpUuX4sknn8Tnn3+Om266CWVlZfjWt76VraURQtLAFhsBuABAHRtACCHZkrVAZfny5Vi+fLl6ecKECdi7dy+eeOIJNVB5/vnn4fV68fTTT8NoNGLmzJloaWnBI488QoEKITnKMqsCngO9BT82gBCSG4a1W1NfXx/sdrt6+cMPP8Q555wDozG4z71s2TI89NBD6OnpQXl5ZCMtj8cDj8ejXnY4HNldNCFEo2jBKAAMTIHxAYQQkk3D9pbowIEDePzxx/Htb39bva6trQ01NTWa2ymX29raon6dDRs2oLS0VP2or6/P3qIJIREYlkHxWaNgqCnssQGEkNyQcqBy1113gWGYuB979uzR3OfEiRNYvnw5VqxYgVWrVqW14Lvvvht9fX3qR2tra1pfjxBCCCG5K+WtnzvuuAM33HBD3NtMmDBB/ffJkydx/vnnY/HixXjqqac0t6utrUV7e7vmOuVybW1t1K9tMplgMplSXTYhhBBC8lDKgUpVVRWqqqqSuu2JEydw/vnnY968eXjmmWfAstoEzqJFi/DDH/4QPp8PBoPc4XLz5s2YOnVq1PoUQgghhIwsWatROXHiBM477zw0NDTgpz/9KU6fPo22tjZN7cnVV18No9GIm2++GV988QVeeuklPProo7j99tuztSxCCCGE5JGsnfrZvHkzDhw4gAMHDmDMmDGaz0mSBAAoLS3F22+/jdWrV2PevHmorKzE+vXr6WgyIYQQQgAAjKREDXnK4XCgtLQUfX19KCmh45KEEEJIPkj29Zs6NhFCCCEkZ1GgQgghhJCcRYEKIYQQQnIWBSqEEEIIyVkUqBBCCCEkZ1GgQgghhJCcRYEKIYQQQnJW1hq+DRelDYzD4dB5JYQQQghJlvK6naidW94HKv39/QCA+vp6nVdCCCGEkFT19/ejtLQ05ufzvjOtKIo4efIkbDYbGIbJ2Nd1OByor69Ha2trQXa8LfTHBxT+Yyz0xwcU/mMs9McHFP5jpMc3dJIkob+/H3V1dRFDi0PlfUaFZdmIWUKZVFJSUpC/fIpCf3xA4T/GQn98QOE/xkJ/fEDhP0Z6fEMTL5OioGJaQgghhOQsClQIIYQQkrMoUInBZDLhnnvugclk0nspWVHojw8o/MdY6I8PKPzHWOiPDyj8x0iPL/vyvpiWEEIIIYWLMiqEEEIIyVkUqBBCCCEkZ1GgQgghhJCcRYEKIYQQQnIWBSpRbNy4EePGjYPZbMbChQuxbds2vZeUMRs2bMCZZ54Jm82G6upqXHrppdi7d6/ey8qaBx98EAzDYM2aNXovJaNOnDiBa6+9FhUVFbBYLGhsbMQnn3yi97IyQhAErFu3DuPHj4fFYsHEiRNx3333JZwHksv+9re/4eKLL0ZdXR0YhsGrr76q+bwkSVi/fj1GjRoFi8WCpUuXYv/+/fosdgjiPT6fz4c777wTjY2NKCoqQl1dHa677jqcPHlSvwUPQaKfYajvfOc7YBgGP//5z4dtfelK5vHt3r0bl1xyCUpLS1FUVIQzzzwTx44dy/raKFAJ89JLL+H222/HPffcg+bmZsyePRvLli1DR0eH3kvLiC1btmD16tXYunUrNm/eDJ/PhwsvvBBOp1PvpWXcxx9/jF/84hc444wz9F5KRvX09GDJkiUwGAz405/+hF27duFnP/sZysvL9V5aRjz00EN44okn8F//9V/YvXs3HnroITz88MN4/PHH9V7akDmdTsyePRsbN26M+vmHH34Yjz32GJ588kl89NFHKCoqwrJly+B2u4d5pUMT7/G5XC40Nzdj3bp1aG5uxu9//3vs3bsXl1xyiQ4rHbpEP0PFK6+8gq1bt6Kurm6YVpYZiR7fwYMHcfbZZ2PatGl47733sGPHDqxbtw5mszn7i5OIxoIFC6TVq1erlwVBkOrq6qQNGzbouKrs6ejokABIW7Zs0XspGdXf3y9NnjxZ2rx5s3TuuedKt956q95Lypg777xTOvvss/VeRtZcdNFF0k033aS57rLLLpOuueYanVaUWQCkV155Rb0siqJUW1sr/eQnP1Gv6+3tlUwmk/Sb3/xGhxWmJ/zxRbNt2zYJgHT06NHhWVSGxXqMx48fl0aPHi3t3LlTGjt2rPSf//mfw762TIj2+K644grp2muv1WU9lFEJ4fV6sX37dixdulS9jmVZLF26FB9++KGOK8uevr4+AIDdbtd5JZm1evVqXHTRRZqfZaH44x//iPnz52PFihWorq5GU1MTfvnLX+q9rIxZvHgx3nnnHezbtw8A8Nlnn+Hvf/87vvzlL+u8suw4fPgw2traNL+rpaWlWLhwYUE/7zAMg7KyMr2XkjGiKGLlypVYu3YtZs6cqfdyMkoURbzxxhuYMmUKli1bhurqaixcuDDu9lcmUaASorOzE4IgoKamRnN9TU0N2tradFpV9oiiiDVr1mDJkiWYNWuW3svJmBdffBHNzc3YsGGD3kvJikOHDuGJJ57A5MmT8dZbb+G73/0ubrnlFjz77LN6Ly0j7rrrLlx55ZWYNm0aDAYDmpqasGbNGlxzzTV6Ly0rlOeWkfK843a7ceedd+Kqq64qqCF+Dz30EHiexy233KL3UjKuo6MDAwMDePDBB7F8+XK8/fbb+PrXv47LLrsMW7Zsyfr3z/vpyWToVq9ejZ07d+Lvf/+73kvJmNbWVtx6663YvHnz8Oyd6kAURcyfPx8PPPAAAKCpqQk7d+7Ek08+ieuvv17n1aXv5ZdfxvPPP48XXngBM2fOREtLC9asWYO6urqCeHwjmc/nw+WXXw5JkvDEE0/ovZyM2b59Ox599FE0NzeDYRi9l5NxoigCAL72ta/htttuAwDMmTMHH3zwAZ588kmce+65Wf3+lFEJUVlZCY7j0N7errm+vb0dtbW1Oq0qO77//e/j9ddfx7vvvosxY8bovZyM2b59Ozo6OjB37lzwPA+e57FlyxY89thj4HkegiDovcS0jRo1CjNmzNBcN3369GGpvh8Oa9euVbMqjY2NWLlyJW677baCzZApzy2F/ryjBClHjx7F5s2bCyqb8v7776OjowMNDQ3q887Ro0dxxx13YNy4cXovL22VlZXgeV635x0KVEIYjUbMmzcP77zzjnqdKIp45513sGjRIh1XljmSJOH73/8+XnnlFfz1r3/F+PHj9V5SRl1wwQX4/PPP0dLSon7Mnz8f11xzDVpaWsBxnN5LTNuSJUsijpTv27cPY8eO1WlFmeVyucCy2qcmjuPUd3WFZvz48aitrdU87zgcDnz00UcF87yjBCn79+/HX/7yF1RUVOi9pIxauXIlduzYoXneqaurw9q1a/HWW2/pvby0GY1GnHnmmbo979DWT5jbb78d119/PebPn48FCxbg5z//OZxOJ2688Ua9l5YRq1evxgsvvIA//OEPsNls6h54aWkpLBaLzqtLn81mi6i3KSoqQkVFRcHU4dx2221YvHgxHnjgAVx++eXYtm0bnnrqKTz11FN6Ly0jLr74Ytx///1oaGjAzJkz8emnn+KRRx7BTTfdpPfShmxgYAAHDhxQLx8+fBgtLS2w2+1oaGjAmjVr8OMf/xiTJ0/G+PHjsW7dOtTV1eHSSy/Vb9EpiPf4Ro0ahW9+85tobm7G66+/DkEQ1Ocdu90Oo9Go17JTkuhnGB58GQwG1NbWYurUqcO91CFJ9PjWrl2LK664Aueccw7OP/98/PnPf8Zrr72G9957L/uL0+WsUY57/PHHpYaGBsloNEoLFiyQtm7dqveSMgZA1I9nnnlG76VlTaEdT5YkSXrttdekWbNmSSaTSZo2bZr01FNP6b2kjHE4HNKtt94qNTQ0SGazWZowYYL0wx/+UPJ4PHovbcjefffdqH93119/vSRJ8hHldevWSTU1NZLJZJIuuOACae/evfouOgXxHt/hw4djPu+8++67ei89aYl+huHy7XhyMo/vV7/6lTRp0iTJbDZLs2fPll599dVhWRsjSXnc7pEQQgghBY1qVAghhBCSsyhQIYQQQkjOokCFEEIIITmLAhVCCCGE5CwKVAghhBCSsyhQIYQQQkjOokCFEEIIITmLAhVCCCGE5CwKVAghhBCSsyhQIYQQQkjOokCFEEIIITmLAhVCCCGE5Kz/D0n7pBLJHJcvAAAAAElFTkSuQmCC" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# A plot of the solution of the additive-noise SDE used to compare the methods\n", + "# A plot of the solution of the SDE\n", + "# We will use this to compare the methods\n", + "sol_additive = diffeqsolve(\n", + " terms_time_sde,\n", + " ShARK(),\n", + " t0,\n", + " t1,\n", + " dt0=0.02,\n", + " y0=time_sde.y0,\n", + " args=time_sde.args,\n", + " saveat=diffrax.SaveAt(steps=True),\n", + ")\n", + "plot_sol_general(sol_additive)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a8aeb3aa7e69b296", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-02T15:10:45.713553Z", + "start_time": "2024-04-02T15:07:46.654180Z" + } + }, + "outputs": [], + "source": [ + "# A comparison of SRKs for additive noise SDEs\n", + "# We compute their orders and plot them on the same graph\n", + "out_SLOWRK_time_sde = constant_step_strong_order(\n", + " keys, time_sde_short, SlowRK(), levels=(4, 10)\n", + ")\n", + "out_SPaRK_time_sde = constant_step_strong_order(\n", + " keys, time_sde_short, SPaRK(), levels=(4, 10)\n", + ")\n", + "out_GenShARK_time_sde = constant_step_strong_order(\n", + " keys, time_sde_short, GeneralShARK(), levels=(4, 10)\n", + ")\n", + "out_ShARK_time_sde = constant_step_strong_order(\n", + " keys, time_sde_short, ShARK(), levels=(4, 10)\n", + ")\n", + "out_SRA1_time_sde = constant_step_strong_order(\n", + " keys, time_sde_short, SRA1(), levels=(4, 10)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "78fed5faa530d9eb", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-02T15:10:45.891614Z", + "start_time": "2024-04-02T15:10:45.715098Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAABIMAAAOOCAYAAACTMtKnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAB7CAAAewgFu0HU+AAEAAElEQVR4nOzdd1xT1/sH8E9C2HupLEXcG1FRVNxbQdxb0bpHraO2VeuoraNqW63WPXDjxL3BgYgIKuJegIDI3jPj/v7gy/0lZEMAlef9euVl4jn33HOTcJM895zncBiGYUAIIYQQQgghhBBCqgRuZXeAEEIIIYQQQgghhFQcCgYRQgghhBBCCCGEVCEUDCKEEEIIIYQQQgipQigYRAghhBBCCCGEEFKFUDCIEEIIIYQQQgghpAqhYBAhhBBCCCGEEEJIFULBIEIIIYQQQgghhJAqhIJBhBBCCCGEEEIIIVUIBYMIIYQQQgghhBBCqhAKBhFCCCGEEEIIIYRUIRQMIoQQQgghhBBCCKlCKBhECCGEEEIIIYQQUoVQMIgQQgghhBBCCCGkCqFgECGEEEIIIYQQQkgVQsEgQgghhBBCCCGEkCqEgkGEEEIIIYQQQgghVQgFgwghhBBCCCGEEEKqEAoGEUIIIYQQQgghhFQhFAwihBBCCCGEEEIIqUIoGEQI+aLt378fHA4HHA4H3t7eld2dr0JKSgpWrlyJtm3bwtzcHFpaWuxzuH///sruHiGEkHLg6OjInuujoqI00qa3t7fSz48VK1awdVasWKGR/QLArVu32Ha7dOmisXZJ2ZTX600IqXi8yu4AIaRiREdH4+zZs7h8+TLev3+PhIQEFBYWonr16rCzs0P37t3h4eGBNm3aVHZXSRl8+PABnTp1QlxcXGV3hRBCCCHkqxcWFobDhw8jODgY79+/R3p6OjgcDoyMjGBvb4969eqhVatW6NChA9zc3MDjyf6JvX//fkycOFFmGZfLhYmJCUxNTWFmZobGjRujVatWaNeuHdq3bw8Oh6Nyf729veHj41OqY+3cuTNu3bpVqm3J14eCQYR845KTk7FixQrs2LEDAoFAqjw6OhrR0dEICgrCqlWr0Lt3b2zYsAFNmzathN6Sspo2bRobCNLX10ePHj1gZ2cHLS0tAECjRo0qs3uEEEKqoFu3bqFr164A6Mcm+XrExsZi2rRpuHTpkszygoICpKSkIDw8HCdPngQAmJqaIiUlhf3epSqRSIT09HSkp6cjOjoa4eHhOHr0KACgbt26mDp1KmbNmgUDA4OyHRQhYigYRMg37MWLF+jTpw9iYmLY/+PxeGjXrh1q1aoFXV1dfPr0CUFBQcjMzAQAXL16Fbdu3cLhw4cxZMiQyuo6KYX4+HjcuHEDAKCrq4vw8HDUq1evkntFCCGEEPJ1iY6ORqdOnfDx40f2/8zMzODq6go7Oztoa2sjNTUVr169wsuXLyEUCgEAGRkZYBhGafvGxsYYP368xP/l5OQgPT0dsbGxCA8PB5/PBwC8e/cOixYtwp49e3Do0CG0bt1a5eNo2LAhunfvrnJ9+t5YtVAwiJBv1IsXL9CxY0ekpaUBALS1tbFw4UIsWLAAlpaWEnULCgrg6+uLhQsXIikpCQUFBRg+fDgOHDiAMWPGVEb3SSk8fvyYve/u7k4f6IQQQsrdihUryiV3TJcuXVT6UU0qVnm93l+aSZMmsYEgExMTbNq0CWPGjIG2trZU3czMTFy4cAFHjhzB5cuXVWrfwsICW7ZskVuen5+PGzdu4O+//4a/vz8A4PXr1+jYsSNu3ryJDh06qLSftm3bKtwPqdoogTQh36D8/HyMGDGCDQQZGBjg+vXrWL16tVQgCCgaRTJ+/HiJkSQikQjTpk3DmzdvKrTvpPSKX28AsLGxqcSeEEIIIYR8nR4+fMgGYDgcDs6fPw9vb2+ZgSCgKFg0evRoXLhwAW/fvlV7ipgsenp6GDBgAG7evIl9+/ZBX18fQNEF3EGDBkmMWCKktCgYRMg3aPXq1Xj27Bn7+ODBg+jcubPS7WxsbHD9+nUYGxsDKBquOmXKlHLrJ9Gs4uHEQFEiQkIIIYQQop5r166x99u2bYtOnTqpvK2Tk5NayZ5VUTIhdFJSUpUYnUXKH/1aIOQbk5ubi61bt7KPvby8MHjwYJW3r1WrFlauXMk+vnPnDkJCQmTW7dKlC7u8aHEyyPj4eKxevRqurq6oUaMGtLS0YGZmJnP7R48eYcqUKXBycoK+vj6sra3h6uqKP//8E6mpqSr3uaSHDx9i3rx5cHZ2hrW1NXR0dFCjRg107twZ69atkxhBI4+sJXrfv3+PJUuWoGXLlrC2tgaXy4Wzs3Op+1ksOzsbmzdvRu/evWFvbw89PT2Ym5ujadOmmD17Nh48eCB3W/Gld8VXqPDx8WH/v/jm7e1d5r4CwOXLlzFt2jQ0bdoUlpaW0NbWhpmZGVxcXDBt2jScO3dOZrJycQzD4MSJExg1ahTq1KkDIyMjGBkZoU6dOhg9ejROnjyp0vQAWe/B1NRUrFu3Dm3atIGVlRX09fXh5OSE7777TiJIWtJff/3FttW7d2+Vn4+goCB2OwsLCxQUFMitm5OTg23btsHDwwO1atWCgYEBjI2NUa9ePUyaNIm9EqnI/v37pV5ToVCIY8eOYeDAgezfE4fDgZ+fn9T2aWlp+P3339G6dWuYm5vDyMgIDRo0wOTJk/Hw4UO2nvh7RxUpKSnYuHEjevbsCQcHB+jp6bErosyaNQuhoaFK25C1ZLFAIMCBAwfYZOi6urqwsbGBl5cXLly4oFLfxD19+hQ///wz2rZtixo1akBHR4d9DkaMGIE9e/YgIyOjQo63tJ4/f44ff/wRLVu2hJWVFXR1dWFra4suXbpg3bp1SElJUdqGrPcRAJw5cwYeHh6oWbMmdHV1Ua1aNfTq1QuHDh3S+JSdvLw8+Pn54fvvv0fHjh1RvXp19vVwdHTEoEGDsGfPHhQWFiptS94y5P7+/hg5ciScnJygp6cHS0tLdOrUCVu2bJEIoGsCn8/H1atXsWjRInTt2hW2trbQ09ODvr4+7O3t0bdvX/zzzz/Izs5Wq92CggL8+++/cHd3h7W1NfT19dlzZUBAQKn76+fnh4EDB7J/V/b29ujZsycOHjyo9BwuTtFS48VlxcmjAeD27dtSn08cDgeOjo4S2ypaWp7P58PKyootv3//vsr97dWrF7vd+vXrFdbVxPcJVVXEe/jq1auYNGkS6tevDxMTE+jr66NWrVoYNGgQ9u/fr1J76iwt7+/vj++++w7NmjWDmZkZeDweDAwMYG9vD3d3d/zwww+4cOGCSn/jFXnOFV+RtVatWhprtyyGDRuG0aNHs48PHTqE6OjoSuwR+SYwhJBvyv79+xkA7O3OnTtqt5GVlcUYGRmxbXh7e8us17lzZ7ZOQEAA4+fnx5ibm0vsHwBjamoqte2SJUsYLS0tqbrFN3t7e+b+/fvMvn372P+bMGGCwn6npqYyQ4YMkdtm8c3MzIw5ceKEwrZq1arF1o+MjGR27NjB6OnpSbXVokULFZ9V2c6fP8/UqFFDaZ9Hjx7N5OTkSG0fEBCgdFtVnz9lnj17xrRu3VqlfY0YMUJuO2/evGFatmyptI1WrVox79+/V9inku/BwMBAxs7OTm6bWlpazM6dO2W29enTJ/Y9qaWlxcTHx6v0vMyYMYNtf+rUqXLrHT9+XKXXesCAAUx6errcdkr+TcTFxTEdO3aU2daZM2cktvX392eqV68ud99cLpdZsWIFwzCMxP8rs2XLFsbU1FThcXE4HGbSpElMQUGB3HaWL1/O1l++fDkTGxvLtG/fXmG7EydOZIRCodI+pqWlMSNGjGA4HI7S16B69eoVcrzq4vP5zJw5cxSeO4vPcfv371fYVsn3UXp6OuPp6amw3T59+jC5ubkaOZbg4GCJzxlFN0dHR+bRo0cK2xM/F3bu3JkpKChgpkyZorBdFxcXJikpSSPH8/HjR8bS0lKl47G0tGSuXbumUrsvXrxgGjRooLC96dOnM4WFhVKfW/JkZWUx/fr1U9hmx44dmfj4eGbChAns/+3bt09meyX/buWVKbvVqlVLYtuSr2lJ4ufemTNnqvR8ip/nuVwuExsbK7OeJr9PqKo838MJCQlM9+7dlR5PvXr1mIcPHypsS9HrXSw7O1vp+UT8tmvXLoX7rOhz7qxZs9h2XV1dy9xeMfHzbsn3uyrCwsIkjvmvv/6SWU/877as3/3It40SSBPyjRG/Sujg4AB3d3e12zAyMsLAgQNx+PBhAFBpCdigoCCsWLECfD6fvWplZWWFxMREicTGALB48WKsWbOGfWxgYIBu3brBxsYGnz9/hr+/P2JjY9GvXz/88MMPKvX58+fP6NatG16+fMn+X5MmTdCiRQsYGRkhMTERd+/eRUpKCtLT0zF8+HAcPHhQpQTZJ06cwKJFiwAAtra26NChA0xNTfHp06cyjWDy9fXFmDFj2BUotLS00LFjR9StWxfZ2dm4e/cuPn36BAA4cuQIIiMj4e/vDz09PbYNOzs7zJo1CwDw6tUr3Lx5E4Ds1SPatWtX6r7eunULnp6eyMrKYv+vZs2acHV1hYWFBXJycvD69Wt29Yv8/HyZ7bx8+RKdO3dGUlIS+3/NmjWDs7MzOBwOHj9+jIiICABAWFgY2rdvjzt37qB+/fpK+/js2TP88ssvyM7ORrVq1eDu7g5LS0vExcXB398feXl5EAqFmD59Opo1ayb1fNjY2KBbt264fv06O9JG2fuPz+fj+PHj7ONx48bJrPf3339jwYIF7MgKExMTuLm5wd7eHkKhEM+fP0doaCgYhsGFCxfQpUsX3Lt3T+kSsgUFBfD09ERYWBh4PB7at2+POnXqoKCgAI8ePZKoGxwcjAEDBiA3NxdA0cifNm3aoEmTJigsLERISAjevn2LFStWwMrKSuF+xf3www/YtGkT+9jKygpubm6oUaMG8vPz8fjxYzx79gwMw2Dv3r349OkTLl68qHQqY3Z2Nvr06YNnz57BwMAA7u7ucHBwQFZWFgICApCYmAgA2LdvHxo0aICffvpJblufPn1Ct27d8Pr1a/b/zMzM0KFDB9jY2IDP5+Pjx48ICwtDZmam3PdveR6vMiKRCEOGDMG5c+fY/7OwsECXLl1gYWGBmJgYBAQEoLCwEOnp6fD29kZ6ejrmzp2rtG2BQIAhQ4bg5s2b0NHRYd9H+fn5uHv3Lpub4sqVK5g/fz62bdtWpmMBikaoFY+QqVatGpo0aQJ7e3sYGhoiNzcX7969Q0hICAQCAaKiotC5c2c8evQIdevWVan9qVOnwsfHB1wuF23btkXDhg0hEokQHBzMvg8ePXqE8ePHy102Wh05OTnsiCxzc3M0adIEtWrVgpGREQoLCxEZGYng4GDk5+cjJSUF/fr1w+3bt9G+fXu5bUZHR6N79+6Ij49n/69JkyZwcXEBh8PBo0eP8OzZM2zfvl3l5ab5fD769++PO3fusP9Xo0YNdOrUCcbGxnj37h0CAwMRGBiIQYMGwcnJqZTPSBFXV1fMmjULcXFx7EhFW1tbDBo0SKqurJyGiowdO5Z9Lx4/fhybNm0Cj6f4p82xY8fYz9yuXbvCzs5Oqk55fp9Qh6bewwkJCejQoQPev3/P/l+dOnXQtm1b6Orq4sWLF+wI5Ldv36Jr1664cuWKygmKZRk7dqzEuapu3bpo2bIlLCwswOfzkZSUhIiICHbktSKVcc6tU6cOez80NBQ3b95Ua0Wu8uLi4gInJyd8+PABAHD37l3MmzevkntFvmqVGooihGhcnTp12KsBQ4cOLXU7mzdvlrj6IOvqmfioDB6Px3A4HGbVqlVMYWGhRL38/Hz2/u3btyWuzA8dOpRJTU2VqJ+ens6MHDmSAcDo6OgovbohFAqZrl27SlzFkXUVOS8vj1mxYgW7f0NDQ+bDhw8y2xS/wsrj8RgdHR1m586djEgkknts6nj37p3EVXFXV1fm7du3Use1ceNGhsvlsvXmzJkjt011RlGp4+PHj4yVlRXbdu3atZnLly/LrJuamsps376dWbhwoVRZQUEB06JFC7adatWqMdevX5eqd/XqVYn9ubi4SL2niom/B3V1dRktLS1m48aNDJ/PlzqGpk2bsnW7du0qsz0fHx+2TqtWrZQ9Ncy5c+cknpeS7w+GYZgbN26wr6GOjg6zdu1amaO8Hj9+zDRu3Jhtb8aMGTL3Kf4683g89iqyrJEAxe/PvLw8pl69ehJ9DQkJkarv6+vLGBgYMLq6uhJ///Ls2bOHrWNiYsLs2rVL5mvl7+8vMWJr3bp1MtsTv+Jc3IcJEyYwKSkpEvVycnKYUaNGsXWNjIyY7OxsmW3y+XymQ4cObF19fX1my5YtMvtZUFDAnDt3jvHy8qqQ41XHunXrJF6Tn3/+WeoKeHx8PNOrVy+J90dwcLDM9sTfR8XPdd++faXO9Xw+n1m4cKHE1XdFo05UFRwczCxevJiJiIiQWychIYEZN24cu+/u3bvLrSs+qqL4eNq0acO8fPlSop5IJGL++ecfiefy9u3bZT6eqKgoZs6cOcyDBw/kjlTLyMhgFixYwO63fv36Cke1iY/kMDU1Zc6fPy9V59KlS+yIXG1tbba+vNfot99+k3gt//jjD0YgEEjUef36NXuuFv8MLs3IoGLKRvmUdhsnJye2jqznpyQXFxeFx1Me3ydUVV7v4b59+7L1DA0NmaNHj0rVefjwocRz6eDgwKSlpclsT9nr/eTJE4lz86VLl+T27f3798zvv//OnDt3TmZ5ZZ1z3759K/Hdy8DAgPn111+lvqepq6wjgxiGkfjss7GxkVmHRgYRVVEwiJBvTPGPQwDsdI/S8Pf3l/iiERgYKFVH/Ic4AOb3339X2q74lI/u3btLfQktJhQKJX7UKPpAO3DgAFunXbt2SqcxiH+RmT59usw64sEgAMyhQ4eUHps6xo8fz7Zdt25dhdOC/vrrL7Yul8uV+4WzvIJBY8aMkfjy8vnz51K1s3fvXrYdbW1thdM+QkJCJN7LPj4+MuuVfA/u2LFDbpsRERHsF3cOh8N8+vRJqk5WVhZjYGDAtvfq1SuFxzRixAi27tKlS6XKhUKhRBDm9OnTCtuLj49np3Fpa2szMTExUnXEX2cATLNmzZS+57dt2ybxpfbdu3dy654+fVqifUD2V4XMzEzGzMyM/cEoL+hQ7MWLF+xUS0tLS5kBsZJTSkaNGiW3vby8PMbBwYGte+zYMZn1du3aJfG+K83UWYYpn+NVVUZGhkTwWFawtVh+fj7Tpk0btq68wGfJ95G7u7tUELWYSCSSaHPt2rWlPpbSEP8x++LFC5l1Sk6ZrVevHpOVlSW3zaFDhyr9HCgv06dPZ/ct74fytWvXJII2/v7+ctu7c+eO1PRHWcGg9PR0ifObou8IiYmJjI2NjUSbX2IwaNmyZWydkSNHKmzv5cuXbF19fX0mMzNTqk55fJ9QVXm8h0t+l7tw4YLc9iIjIyWmYq1cuVJmPWWv97///suWL1myRPFBK1CZ51yGYZg5c+ZIfRYCRdNWR4wYwaxfv565c+eOWhcFNREMWrFiBdsGj8eTWUc8GNSwYUNm1qxZKt/evHlTqn6RrxMFgwj5hmRkZEh8YG3atKnUbT1+/FiiLVlXbcR/iNva2sr9IVHsxYsXEm0+f/5cYf3Xr19LfMGVF+BwdnZm6zx58kTpseXl5bFfMExNTWVemRUPBmlyvjjDFOUuER95oSxAIBQKmSZNmrD1f/75Z5n1yiMYFBsbKxGUkTciSBVt27Zl2/n++++V1hfPB9GuXTuZdcTfg82aNVPapqurq8L3NMMwzOjRoxUGeIplZmYy+vr6bF1ZgSM/Pz+2XN5ok5LWrFnDbrNx40ap8pI/4hVddS0mnutp8eLFSuuLXxkHZH9VEL8y/cMPPyg/MIZhpk2bxm5z6tQpqXLxHxk6OjpK8zYtWrSIrT9//nyZdRo2bMjW+emnn1TqpyzlcbyqEg/mVa9eXekP1AcPHki8frLemyXfR2FhYQrb/O+//9i6gwcPLvWxlIavry+7782bN8usU/KHtLLn+9KlS2xdFxeX8ui2XOKvj7z37fDhw9k6w4YNU9qm+HkLkB0MEn8N7e3tleZW2blzp0SbX2Iw6M2bN2wdAwMDmQGeYkuWLGHrygsclcf3CVWVx3tY/IKFp6en0j6Ij0C0sbGROdpV2ev9xx9/sOX//POP0n3KU5nnXIYpGhUp/j1E3k1XV5cZMGCAVI4+WTQRDPr7778l9p+RkSFVRzwYpO4tICCgVP0iXydaTYyQb4h4PhcAMDQ0LHVbRkZGEo8zMzMV1h86dKjSufri+YxatWqFxo0bK6xfv359pXlu4uPj8eTJEwBA48aN0aJFC4X1AUBPTw9ubm4AgIyMDIUrTAHAyJEjlbapjqCgIHbFKSsrK3h4eCisz+VyMWnSJPZxWVaPUdeNGzfYVWXq1auHPn36lKqdrKwsiZU+xI9HnsmTJ7P3Hz58iJycHIX1hw0bprTNli1bsvfl5SoYO3Yse//IkSNy2zp9+jTy8vIAAK1bt0aDBg2k6ojncRBfBUSRbt26sfcDAwMV1jU3N0evXr0U1snKypLIHyR+fPKoUqe8j61jx46oUaOGwjrKXs/o6Gi8evWKfTx79myV+ilLeR+vIuKrzI0aNQr6+voK67u6uqJZs2bsY2XnDCcnJ7i4uCiso8rfTmnl5ubC398fmzZtwtKlSzF37lzMnj2bvR09epStW3y+V0RPT0/pebU8j4fP5yMwMBBbt27Fr7/+ih9++EHieMRX/JR3POKv2fjx45Xuc8KECUrriLc5YsQI6OjoKKw/cuRIpXUqW7169eDq6gqg6H105swZuXXFz+eyznEV8X1CVZp6D4u/5qp87k6cOJHNtRMfHy+RZ01VDg4O7P0DBw6weerUVZnnXADg8Xj477//EBQUBC8vL7l/CwUFBbhw4QIGDRqEDh06sDnWykvJ7+clv/sTog5KIE3IN8TY2FjisbIfz4qUXPrWxMREYf1WrVopbVM8kXTxlydl3NzcFC4ZK16Wl5en8o898USKMTExaN68udy6qhybOsSfB1dXV6VBNAASiRwfP34MhmFUXvK7LIKDg9n7JZf2VcfTp0/ZpJ1GRkYKn+9izs7OMDQ0RE5ODoRCIcLDwxUmWxX/8SuPeIJSeQHOnj17olq1akhMTMSHDx8QFBQkc7+HDh1i78tLHC3+/jx16hRu376ttI/iy5rHxMQorOvs7AwtLS2FdZ4+fQqRSASg6O+4YcOGSvvQtm1bpXXEj23nzp3w8fFRuk1sbCx7X9mxaeL1FH//1qtXD/b29krblKe8j1cR8XOGor8BcR06dGCTsZdMJl6Spv521JWamoply5bhwIEDKv+gSU5OVlqnQYMG0NbWVlinPI4nLy8Pq1evxvbt21XqJyD7eOLi4iSS7KuS/L9du3bgcDhsknpZ1P0MNjY2RtOmTZW+fyrb2LFjERISAqDovCwreHbv3j1ERkYCAKytrdG7d2+pOhXxfUJVmngPx8XFsYn2AdXOHdbW1qhfvz4bRH/06JFKnxni+vXrx352F2//3XffoX///mjZsqXSz6xilXnOFefm5oYzZ84gIyMDd+7cQWBgIMLCwvDo0SOkpaVJ1A0KCkK7du3w8OFDmcnJNaHkuVLZ9/MJEyZg//795dIX8vWjYBAh3xATExPweDx2JEdZVroq+QFnYWGhsL61tbXSNsW/3NasWVOlfiirV7zaFgBERkZKXHFVVcljLUmVY1OH+PNQq1YtlbZxdHRk7xcWFiIrK0vpFwBNSEhIYO+XZVUZ8WN2cHBQKZDF5XLh4ODAfilV9uPK1NRUaZviX675fL7MOjweDyNHjsTmzZsBAIcPH5b6Eh0fH8+O1iiuL4v4+9PX11dp/0rSxHtT/Lm3t7dX6blXFjTJzs6W+EK6e/dupW2WpOzYNPF6aur9WxHHq0hZzxkV9bejjujoaHTq1Entq+iqBI3UPZ7iz8yySEtLQ7du3VQauSRO1vGIv94GBgYqre5nYmICU1NTpKeny61T2s/gLz0YNHLkSMyfPx8CgQD+/v74/Pmz1KjC4tVRi+vLughTEd8nVKWJ97D4662vr6/ydxlHR0eVP3dlsbS0xO7duzF+/Hjw+XzExMRgxYoVWLFiBYyMjNC2bVt07twZHh4ecHZ2ltlGZZ9zZTE1NYWHhwc7YothGDx+/BiHDh3Cjh072BFQ8fHxmDFjhsRqapokfsFIW1tb6kIwIeqgaWKEfGPEfyiUZbhyyW3Ff1jIomzaAiA52kjVZXCVTXUT/1AsLWU/BFQ5NnWIPw+qTuUrWa+ihgWL76fk0GR1lOaYS9ZVdsyaHCklPtLn+PHjUj9+jx49yo626dWrF6pVqyaznbK+PzXx3izN352y17oi/u408Xpq6v1bEcerSFnPGRX5t6Oq0aNHs4EgY2NjzJs3D1euXMGHDx+QnZ0NoVAIpii3pcRUl+K/O0Uq43hmzZrFBoJ0dHQwefJknD17Fm/evEFWVhYEAgF7PMUjVADZx1Oav1lA+XujPD6DvwTiI32EQqHEtEKgKHh5/Phx9rG8kZyV/XcuThPv4Yr43JVn5MiRCAkJwaBBgySCVtnZ2bh58yaWLVuGli1bonXr1rh7967U9l/SayEPh8OBi4sL/vrrLzx69Ag2NjZs2fnz5yX+zjVJfOpz9erVy2UfpOqgkUGEfGM6dOjADll+8OBBqdsR39bR0VEjw13Ff4ypOodc2VQ38S8tnp6eOHv2bOk6V4HEnwdVp/KVrFdRV4LE91Ny6qA6SnPMJetW5NWv1q1bo2HDhnj16hWSk5Nx9epVDBgwgC0Xv8qsKL+OoaEh+6X20aNHEjkeKkp5/90BRaMQzc3N1e9cOdPU+7eyj9fIyIh9H5XmnPGlXTkOCgpCUFAQgKJjCw4OVphD7kvPiREXF4djx44BKBrReOXKFXTt2lVufWXHU5q/WUD5e0P8faSpc8GXYuzYsbh48SKAovPzvHnz2LIrV64gJSUFQNH0qzZt2shs42v8PqFIZX/uOjs74/Tp00hPT2enVwUGBiI0NJS9wBIWFoauXbvi6NGjEnn/Kvucq64GDRpg48aNErmN7t69i9q1a2t8X+Lfz1WZQkqIIjQyiJBvjPgX0NjYWNy5c0ftNrKzsyW+BCn6UqsO8SHKqk4NUDbnW/yqyOfPn0vXsQpWmudBPDmkjo5Ohf24E39+y3KVS/yYY2NjFea1KCYSiSRef1WmSmjSmDFj2Pvi+YFevnzJTpswNjaGl5eX3Da+hPen+PMWFxen0jbieRdkMTMzg66uLvv4S/3b09T7t7KPt6znjIr+21Hm5s2b7P0JEyYoXUwgOjq6vLtUJv7+/uw5rW/fvko/M5Udj/jrnZubywYyFMnKylI6mqI8PoO/FAMHDmQ/F8PCwiRGT4ifvxUF77+E87Umib/eeXl5Kk/50vS5w8zMDJ6envjzzz8RFBSE5ORk7Nu3j52qKBQKMXPmTHZBhuJtvobPGHElF9iIj4/X+D5CQ0MlXp9OnTppfB+kaqFgECHfmGHDhklcPfnrr7/UbmPXrl0SV9GnT5+ukb6Jj4oQT+yqiKLk0YBkotsnT558FVcxxZ+HkJAQNrGyIsVX0Yu3r6hpEOJXncqyilnz5s3ZpJFZWVlsYltFwsPD2ddTS0tLpZVdNGnMmDHs83zu3Dn2ar74qKDBgwcrnKol/v68d+9eOfVUsebNm7Orw2RkZEj8SJKnOBmrIsUr+ACVd2zKiL9/37x5ozTIpUhlHq/4OUP8XKCIeD1lK4VVNPHcLKokry7NRY2KpOnjsbOzk/ghr8rnZXBwsNIgu7qfwdnZ2RpbHau8P7P09fUxePBg9nHxeTorKwvnz59n+yAe5C/pa/w+oYidnZ3EFGZVzh3Jycl48+YN+7g8zh0mJibw9vaGv78/G/BJTk6W+r73NXzGiNPT05N4LB7M0hTx7/Q6OjoS73lCSoOCQYR8YwwNDTFz5kz28dmzZxUutVpSdHQ0li1bxj7u1KmTxAdyWYhfLQ0NDVX6o/Tdu3dKg0FOTk5o1KgRgKLEynv27Cl7R8tZ+/bt2S8JSUlJ7NB2eUQiEfbt28c+Fl86tbz17NmTTbT59u1bXL16tVTtGBsbo3Xr1uxjVVa2EH8tXV1dKzx3Re3atdnE0Xl5eTh9+jQYhlG6PLE48alle/fuRX5+fvl0VgETExOJH4HiwSx5xK+kyyN+bNu2bVNptFdFq1WrFnt+AFCqhLDFKvN4xf/mjx07pvR9FBoaiqdPn7KPNTW6U1OKg5OA8ulKnz59+uKn66hzPLm5uThw4IDSNsVfs4MHDyqtr26bvr6+ShOB+/r6oqCgQGm7qhD/oayJBOSyiJ+Pi89zp06dYkectG/fXuG0na/x+4Qy4q+5Kp+7+/fvZ/NY2draokGDBuXVNdSpUwdNmjRhH4sn/Ae+js8YceHh4RKPVU3SrqoTJ05I5MPy9vYutxXLSNVBwSBCvkGLFy+WGHY/duxYla6sfv78Gb169WJHBRkaGmLXrl0a61ejRo0kVmWaO3eu3GSgIpEI33//vUof/j/99BN7f+nSpSqNOilWGUOPzczMMGLECPbxjz/+qDCHxJYtW9hj4nK5mDp1arn3sZitra1EX6dNmyb1hU1V06ZNY+9v3bpV4sdqSWFhYdixYwf7WFOj09Qlnmj08OHDCAoKYqcb2dnZKQ3MDRkyBHXr1gVQNGR85syZKn+hzc7O1tiV6UmTJrH3//nnH4VTps6dOycxjUeeadOmwczMDEBRPqSVK1eq3J/k5GSVRsRpwvz589n7GzdulJmsVBWVebyjR49m83/Ex8cr3HdhYSHmzJnDPu7atWu5/qArDfGV3RStuCMUCjF16lQUFhZWRLdKTfx4Ll26pPC1XrBggUrn0MmTJ7P3jx8/rvAz/N69exJBanlGjx7NJo6OiYnBunXr5NZNSUmRuDBUVuLLoKs6XVVd3bp1Y38cR0ZGIigoSCL4LS9xtLiv7fuEMuKfu2fOnFF4QSc6Ohp//PGHxLalGdGl6nQ0oVAoMZWq5EIMlXnO3b9/P06ePKny57VQKMTy5cvZx9ra2ujevXup91+Sj48PJkyYwD6uUaOGRv8+SRXGEEK+SREREYypqSkDgAHAaGtrM4sXL2aSk5Ol6hYUFDA+Pj5MtWrV2PpcLpc5dOiQwn107tyZrR8QEKBSvwICAhgOh8NuN2LECCYtLU2iTkZGBjN69GgGAKOjo8PWnTBhgsw2BQIB061bN7aeiYkJs337dqagoEBm/YyMDObQoUNM586dmaFDh8qsU6tWLba9yMhIlY5NHe/evWOMjIzYfbi5uTHv37+XqCMUCpl//vmH0dLSYuvNmTNHbpv79u1T+lyVxsePHxkLCwu27dq1azNXrlyRWTctLY3ZsWMH8+OPP0qVFRQUMC1atGDbqVGjBuPv7y9V7/r164y1tTVbz8XFhSksLJS5P3Xfg8uXL2frL1++XGn91NRU9j2opaXFeHl5sdsvXLhQ6fbFxyP+Gvbt25d58eKF3PqPHz9mFi1axJiZmTERERFS5aV5nXNzc5m6deuy2zk5OTGhoaFS9U6cOMEYGhoyurq6bF1FXxXE+wKAGT9+PBMdHS2zrkgkYgIDA5kZM2Yw+vr6TFZWllQddV+fgIAAtn7nzp1l1uHz+Uz79u3ZegYGBszWrVtlvqcKCgqYc+fOMV5eXhVyvOpYt26dxL6XLl0qdY77/Pkz06dPH7YOj8djgoODlR6LKu+jyMhItn6tWrXKdCwvX76U+BxYsGABk5ubK1EnPj6eGThwIAOAMTQ0VPo6q/JeKEmV97gqUlNTGQMDA7atMWPGyPxcmzJlitTxKHouu3btytYzNzdnLl68KFXn6tWr7PlZW1tb6efWihUr2DocDodZu3YtIxAIJOq8efOGadmypdRn8L59+2S2qcrfrUAgkHiOQkJC5B53sdK8pgsXLmS38fLyYrhcLnscqampSrcvj+8Tqiqv93Dfvn3ZOkZGRszx48el6oSGhkp8Rjg4OEi9h4spe729vb0Zd3d3xsfHR24bycnJzMSJEyWe55LnAIapvHPuTz/9xABgGjZsyKxbt46JioqSW/f58+dMr169JPr5/fffy6wrfjzKzqN5eXnMhQsXmB49eki0ra+vzzx48EDhthMmTCiX74Pk20OriRHyjWratCkCAwPRt29fxMbGgs/nY/Xq1fjzzz/h5uaGWrVqQUdHB/Hx8QgKCpJIPKmrq4uDBw9KrOygKV26dMHChQuxfv16AEXD0C9cuIBu3bqhRo0aSEhIgL+/P7Kzs2Fubo65c+dixYoVCtvU0tLC8ePH0bNnTzx+/BiZmZmYPn06Fi1aBDc3N9jZ2UFLSwtpaWl4/fo1Xr58yS45OmTIEI0foyrq1KmD3bt3Y8yYMRAKhbh//z4aNGgAd3d31KlTB9nZ2bh7967EFdR27drhzz//rPC+Ojg44Pjx4/Dy8kJ2djYiIyPRp08f1KpVC66urrCwsEB2djbevHmDJ0+egM/nY+DAgVLt6Ojo4OjRo+jcuTOSkpLw+fNndOvWDS1atICzszOAojwN4kOtq1WrhqNHj0osTVuRzM3N0a9fP/j5+UEoFMLPz48tU+UqMwD06NED27Ztw4wZMyAUCnH58mVcuXIFjRs3RvPmzWFiYoLc3FzEx8cjPDwcSUlJGj8OfX197N+/Hz179kReXh4+fPiANm3aoG3btmjcuDEKCwsREhLC5orYsmULZs+eDUBxrg9vb298+PABq1atAlA0VeXw4cNwdnZGw4YNYWRkhOzsbMTGxuLJkycaWS5YXTweD76+vujWrRvevn2L3NxczJo1C0uWLEGHDh1gY2MDgUCA6OhohIWFITMzE6ampjLbqszjXbhwIQIDA9n8J7///ju2bduGrl27wtzcHDExMQgICJCY1rN+/XqJPChfioYNG2LcuHHs1KaNGzfiyJEjaNOmDapVq4aoqCjcuXMHhYWFMDY2xvr16yttdKAqzM3NsXDhQvz2228AikYRXr58GW3btoWdnR3i4+Nx69Yt5OTkgMfj4b///pO4yi/Pnj174ObmhoSEBKSlpaF///5o2rQpXFxcwOFw8PjxY3aE5fz583Hq1Cmlyal/+eUXXL9+Hffu3QPDMPj555+xadMmdO7cGUZGRnj37h3u3r0LoVCItm3bok6dOiqNOlJGS0sLXl5ebFtdunRBnz59ULNmTTafnIWFBRYvXlym/YwdOxYbNmwAAInzdf/+/VVajepr/D6hzL59+9iVZrOzszF8+HDUq1cPbdu2hY6ODl68eIEHDx6wo2AMDQ1x9OhRdlSOuhiGwd27d3H37l1oaWmhYcOGaNSoEczNzZGXl4e4uDjcu3dPYsTfhg0bZObfq+zPmFevXuGnn37CTz/9BHt7ezRv3hzW1tbQ09NDamoqnj17hpcvX0ps0759e4kRVvKkpqayn7PFcnNzkZ6ezh5PySmVTZo0waFDh9jvTKp48OCB1H6UWb9+vcJ8iOQbUsnBKEJIOUtISGCmT5/O8Hg8iSsL8m69evVinj59qlLbpRkZVOznn39mr9jJutna2jJBQUFqXb3Ozc1V61j19fWZ1atXy2yrvEcGFTt//jxTvXp1pX0dNWoUk5OTo7Ct8hoZVOzJkycSI3sU3caMGSO3ndevX7NXnRXdXFxcmHfv3insU3mPDGIYhjl58qRU35o1a6bStuL8/f2ZevXqqfT8AWCaNGnCxMXFSbVTlte55Kirkjcul8usWLGCKSwsZP/P1NRUabu+vr6Mra2tysfm6urK5OfnS7VTHiODiqWkpDCDBg1SqX92dnYVcrzq4vP5zOzZsyVGmsm6mZqayh3FUawyRwYxDMPk5ORIXU0vebO3t2cCAwNVep0rc2QQwxSNKBk/frzC4zEzM2POnDmj1nP57NkzpeeNKVOmMIWFhSp/bmVkZEiMIJN1a9++PfPp0yeJEQZlGRnEMAwTFRXF1KhRQ+4+Sz4XpXlNGYZhmjZtKtX2qVOnVN6eYTT7fUJV5fke/vz5s8SIJ3m3unXrKh21pez1nj17tsrnRmNjY2bnzp1Kj7Oiz7l+fn5M48aNVd4fUDQSc+7cuUx2drbcdkuOdFL1Vr9+fWbjxo1MXl6eSv0X/7stzU3eiC7y7aGRQYR846pVq4Zt27bhp59+gp+fH65cuYJ3794hMTERfD4f1tbWsLe3R7du3TBw4EC0adOmQvq1Zs0aDB06FP/99x/8/f0RHx8PIyMjODo6YvDgwZg6dSqsrKzw+vVrldvU19dnj/XQoUPw9/fHmzdvkJKSApFIBFNTUzg5OaFFixbo3r07+vTpAxMTk3I8SuUGDBiAd+/eYe/evbhw4QKeP3+O5ORk6Ovrw9bWFl27dsX48eO/iKv7LVq0wOPHj+Hn5wc/Pz/cv38fCQkJyMnJgYmJCZycnODq6goPDw/07t1bbjv169dHaGgoTp48iVOnTiEkJASJiYkAit6vbdu2xdChQzFkyJAKWzVNkQEDBsDMzAzp6ens/ylLHC1L165d8fLlS/j5+eHixYsIDg7G58+fkZmZCQMDA1SvXh0NGzZE+/bt0bdvX7Wu/KmqR48eePXqFbZs2QI/Pz98+PABfD4fdnZ26NSpE6ZNm4Y2bdpI5DRR5erw8OHDMXDgQBw7dgxXr17Fw4cPkZSUhOzsbBgaGsLOzg6NGjWCu7s7+vXrh/r162v82JSxsLDA6dOn8fDhQxw5cgS3bt1CbGws0tLSoK+vD3t7ezg7O6NPnz4YOnSowrYq63h5PB7+/fdfTJ8+HXv37sXNmzcRExODrKwsWFhYoH79+ujXrx+mTJkikaPlS2RgYIDLly/jyJEj8PHxYUdhWFlZwcnJCUOGDIG3tzfMzc1x69atyu6uUlpaWvDx8cGwYcOwc+dOPHjwAGlpaTA3N0fNmjUxcOBATJo0Cba2thJLQyvTpEkTPH36FDt37oSvry9evXqF3Nxc2NjYoE2bNpg8eTJ69uypVl9NTExw+fJlnD59Gvv378fDhw+RmpoKKysrNGrUCGPGjMHYsWM1PiKzVq1aCA8Px5YtW3Dt2jW8efMGWVlZ7MgaTRk3bpxE7h8zMzP0799frTa+xu8TilSvXh03b97ElStX4Ovri8DAQHz+/Bl8Ph/VqlVDy5Yt4eXlpZHX/d9//8XMmTNx48YNBAcH4/nz5/j48SOysrLA4/FgaWmJJk2aoFevXhg3bpxUriBZKvqcO3DgQAwcOBDv37/HrVu3EBQUhJcvXyIyMhLp6ekQCAQwMjJCtWrV0LRpU3Ts2BEjR46EjY1NqffJ5XJhbGwMExMTmJubo3HjxmjdujXat28PNze3Mh0PIfJwGOYLT81OCCGEkAp1/fp19OrVCwDQp08fXL58uZJ7RAghhBBCNIlWEyOEEEKIBF9fX/Z+RY0WJIQQQgghFYdGBhFCCCGE9eDBA7i7u7OJK1++fImGDRtWcq8IIYQQQogm0cggQgghpAr4+PEjhg0bhsDAQMi6DiQUCnHo0CH07t2bDQR5enpSIIgQQggh5BtEI4MIIYSQKiAqKgq1a9cGUJSou1WrVrCxsYGWlhYSEhJw//59iWXtbWxsEBYWVqaEmIQQQggh5MtEwSBCCCGkChAPBinTunVrnDx5ErVq1SrnXhFCCCGEkMpAwSBCCCGkiggJCcH58+cRHByM2NhYJCcnIz09HUZGRqhevTrc3NwwePBgeHh4VHZXCSGEEEJIOaJgECGEEEIIIYQQQkgVQgmkCSGEEEIIIYQQQqoQCgYRQgghhBBCCCGEVCEUDCKEEEIIIYQQQgipQigYRAghhBBCCCGEEFKFUDCIEEIIIYQQQgghpAqhYBAhhBBCCCGEEEJIFcKr7A6Qr09+fj4iIiIAANbW1uDx6G1ECCGEEEIIIYRomkAgQFJSEgCgWbNm0NPT00i79CueqC0iIgKurq6V3Q1CCCGEEEIIIaTKCAkJQZs2bTTSFk0TI4QQQgghhBBCCKlCaGQQUZu1tTV7PyQkBDY2NpXYG0IIIYQQQggh5NsUHx/PzswR/y1eVhQMImoTzxFkY2MDe3v7SuwNIYQQQgghhBDy7dNkvl6aJkYIIYQQQgghhBBShVAwiBBCCCGEEEIIIaQKoWAQIYQQQgghhBBCSBVCwSBCCCGEEEIIIYSQKoSCQYQQQgghhBBCCCFVCAWDCCGEEEIIIYQQQqoQCgYRQgghhBBCCCGEVCEUDCKEEEIIIYQQQgipQigYRAghhBBCCCGEEFKFUDCIEEIIIYQQQgghpArhVXYHCCGEEEIIIaSsRCIRsrOzkZmZicLCQgiFwsruEiGEQEtLCwYGBjAzM4Oenl5ld4dFwSBCCCGEEELIVy0rKwtxcXFgGKayu0IIIRIEAgEKCgqQlpYGU1NT2NjYgMPhVHa3KBhECCGEEEII+XrJCgRxOBxoaWlVYq8IIaSIQCBg72dkZEBHRwdWVlaV2KMiFAwihBBCCCGEfJVEIpFEIMjIyAgWFhYwMDD4Iq68E0KIUChEeno6EhMTAQBJSUkwMTGBjo5OpfaLEkgTQgghhBBCvkrZ2dkSgSB7e3sYGhpSIIgQ8sXQ0tKCpaUlLC0t2f/Lzs6uxB4VoWAQIYQQQggh5KuUmZnJ3rewsKAgECHki2ViYsLez8nJqcSeFKFgECGEEEIIIeSrVFhYCKAoR5CBgUEl94YQQuTT1dVlA9bF567KRMEgQgghhBBCyFepePl4LS0tGhVECPmiiSe2F4lEldwbCgYRQgghhBBCCCGEVCkUDCKEEEIIIYQQQgipQigYRAghhBBCCCGEEFKFUDCIEEIIIYQQQgghpAqhYBAhhBBCCCGEEEJIFULBIEIIIYQQQgghFaZLly7gcDjo0qVLZXeFkCqLgkGEEEIIIYQQQtSSk5OD7du3o1+/frCzs4Oenh50dXVhbW2NNm3aYNKkSdi1axdiYmIqu6sadevWLXA4HJk3AwMDODg4YMCAAdi7dy8KCgqUtle8rbLAmEAgwIgRI9j67dq1Q3p6umYOilRJvMruACGEEEIIIYSQr8f9+/cxcuRIfPz4UaosOTkZycnJCA0Nxb59+1C9enV8/vy5EnpZ8fLy8hAbG4vY2FhcvHgRf/31Fy5cuABHR8cytcvn8zFixAicOXMGANCxY0dcunQJxsbGGug1qaooGEQIIYQQQgghRCVv3rxB7969kZWVBQDw9PTE0KFDUb9+fejo6CA5ORnh4eG4fv06AgICKrm35WvGjBmYOXMm+zgxMRHPnj3D+vXrERsbi+fPn8PT0xOPHz+GlpZWqfZRUFCAoUOH4sKFCwCKpthduHABhoaGGjkGUnVRMIgQQgghhBBCiEqWLFnCBoL27dsHb29vqTo9e/bEwoULkZSUhOPHj1dwDytOtWrV0LRpU4n/69atGyZOnIjmzZsjKioKEREROHPmDIYOHap2+/n5+fDy8sLVq1cBFD2vZ8+ehb6+vkb6T6o2yhlEqgSGYbDoZDhOhsWCYZjK7g4hhBBCCCFfHaFQiIsXLwIAWrduLTMQJM7a2hqzZs2qgJ59WYyNjbF06VL28Y0bN9RuIzc3FwMGDGADQX379sW5c+coEEQ0hoJBpErwfRiD46GxWHgiHPN8nyArn1/ZXSKEEEIIIeSrkpSUhLy8PABA3bp1y31/gYGBGDduHBwdHaGnpwczMzO0bNkSS5cuRVJSksxtNmzYAA6HA21tbWRnZ0uV5+fnQ09Pj03E/OTJE5ntNGzYEBwOByNHjixV35s1a8beVzeJdnZ2Nvr164ebN28CKJqK5+fnBz09vVL1hRBZKBhEvnnvErOw4vxzcP43IMjvySf03xyIJzHpldovQgghhBBCviY6Ojrs/ZcvX5bbfkQiEWbPng13d3ccOnQI0dHRKCgoQEZGBp48eYI//vgD9erVw/Xr16W27dy5M4Ci1bcCAwOlyh88eCCxytetW7ek6iQkJOD169cAoHSVL3nEnyttbW2Vt8vMzESfPn1w+/ZtAMDQoUNx8uRJifYI0QQKBpFv3tXnCSgsFGF0ti7a5/PAZYCPqbkYui0I22+/h0hE08YIIYQQQghRxsLCArVq1QIAhIeHY926dRCJRBrfz88//4ytW7cCAGrXro3t27cjJCQEAQEBmDdvHrS1tZGRkYEBAwYgPDxcYlsXFxd2lS1ZgZ6S/6esTnFwSV3iwTJVVxPLyMhAr169cO/ePQDAqFGjcPToUbWCSYSoihJIk2/erK51Yfo+FxmhybAVclGbr4VLBoVIA4O1l1/h3rtkbBzeAtWMadglIYQQQsi3RiRikJZbWNndqFDmBjrgcjnl0vacOXOwcOFCAEVBm+3bt8PT0xPt27eHq6srateuXab2IyIisHHjRgBA06ZNcffuXZiZmbHlXbp0Qa9evdC/f38UFhZi6tSpePDgAVuupaWFjh074vLlyzIDPcUjbjw8PHD+/HncuXMHIpEIXC5Xqk716tXRqFEjtY9BKBRi/fr17GNVkkdnZGSgR48eCA0NBQCMHz8e+/btk+gXIZpEwSDyzfscmYHMRynsY1shFxOydHFLn48nOkLcfZuMvv/cxYbhLdC1QbVK7CkhhBBCCNG0tNxCtPpd/QS+X7OwpT1gaaRbLm3PmzcPL168wN69ewEAUVFR2Lx5MzZv3gygKIDSpUsXjBkzBgMGDACHo15Qatu2bexoo927d0sEgor16dMHkyZNwu7duxESEoKHDx+iTZs2bHmXLl1w+fJlhIWFITs7G0ZGRgCKlmkPDg4GAPz000+4ceMG0tLS8PTpUzg7O7PbFweROnXqpFbfk5KSEBERgWXLluHx48cAigJBHTt2VLqteO6iUaNGUSCIlDt6d5FvXvLHLKn/0wYHPfN0MDRHB4YiICWnEBP3PcTvF16gQCCshF4SQgghhBDy5eNyudizZw+uXbuGPn36gMeTHF+QkJAAX19feHp6wtXVFe/fv1er/eKVt5o0aYK2bdvKrTdlyhSpbYrJyxsUEhKCvLw8mJqaol27dmjXrh0AyWlhiYmJ7BQvZfmCVq5cySai5nA4qFatGrp374579+7BwMAA8+fPx5EjR5QfNCARNLt//z4+ffqk0naElBYFg8g3r2lnewz5sRXMqhtIldUWaGFilh7qFxb9KewOjMSQbUH4kCS98gAhhBBCCCGkSM+ePXH58mWkpKTg0qVLWLlyJTw8PGBqasrWCQ0Nhbu7O+Lj41Vqs6CgAG/fvgUAhYEgAGjZsiWbS+fZs2cSZa1atWJHA4kHeorvd+zYEVpaWmywR7xO8RQxoPT5ggDA2dkZ33//vcr5fjp27MiuXBYVFYXu3bvj8+fPpd4/IcpQMIhUCdVrm2D4kjZo1tlOqkyf4WBgri7652hDVwQ8i8vEgH8DcSosFgxDyaUJIYQQQgiRx8TEBH379sWyZctw7tw5JCQkYO/evTA3NwcAxMfH49dff1WprbS0NPZ+tWqK0zdoa2vD0tISAJCamipRxuPx0KFDBwCyAz3FQaDif4vzBonXsba2RpMmTRT2YcaMGYiIiEBERAQeP36M8+fPY8KECeByuQgKCkKXLl2QlJSksI1iXC4XBw8ehJeXFwDgzZs36NmzJ1JSUhRvSEgpUc6gb8TDhw+xfPlyBAUFgc/no1mzZpg/fz6GDx9e2V37YmjraKHTqAZwbG6FmwdeIjdDMpFgYz4P9gIuLhvw8RFCLDgRjrtvk7DKqymM9SiDPyGEEELI18jcQAdhS3tUdjcqlLlB5S1Drquri4kTJ8LW1hZ9+vQBAJw+fRo7d+5UKweOurmGSurSpQuuXr3K5g3S1dXF/fv32TKgaPSRnp6eRN6g4mCQKvmCqlWrhqZNm7KPnZ2dMWDAAHTt2hXe3t6IiorC5MmTcfbsWZX6zOPx4Ovri4EDB+LKlSt49uwZevXqBX9/f4kRV4RoAgWDvgEBAQHo3bs39PT0MHLkSBgbG+PUqVMYMWIEYmJisGDBgsru4helZhNLjPq1LW4deY33jxIlykwYLkbk6CJUV4C7enz4PfmERx/TsXlUSzg7mFVOhwkhhBBCSKlxuZxyS6ZM5OvduzccHBwQExODtLQ0pKSkwNraWuE2xaOJgKLcQ4oIBAJ21IyFhYVUecm8QcbGxsjNzYWpqSlatmwJoChw1a5dO9y6dQu3bt2Cvb09nj9/DkB5viBFJkyYgPPnz+PUqVM4d+4c/P390a1bN5W21dHRwenTp9G/f38EBATg0aNH6Nu3L65du8ZOfSNEE2ia2FdOIBBgypQp4HK5uHPnDnbu3ImNGzciPDwc9evXx+LFixEdHV3Z3fzi6Blpo/eUJugxsTF09KVjoq0LeBiXpYtqAg4+puZi6LYgbL/9HiIRTRsjhBBCCCFEFba2tux9VUb66Orqol69egAgsVy8LI8fPwafzwcAidE5xdq0aQNDQ0MARVPFikf8FOcLKiaeN+jOnTtsmoiy5AsCgNWrV7P7Wbx4sVrb6uvr49y5c3BzcwNQlFDaw8MDeXl5ZeoTIeIoGPSV8/f3x/v37zF69GiJ5RBNTU2xePFiFBYWwsfHp/I6+AXjcDho0LYGRv7qCrsG5lLlViIuxmbrom0+D0Ihg7WXX2HCvhAkZuVXQm8JIYQQQgj5euTm5uLFixcAivIKFef3UaZHj6Ipfc+fP0dISIjcert375baRhyPx0P79u0BgB35A0iP+BHPG+Tv7w8AsLS0lBlgUkf9+vXZlB0PHjzA9evX1dreyMgIly9fRqtWrQAUHcPgwYNRWFioZEtCVEPBoDJITEzEhQsXsGzZMvTt2xdWVlbssoLe3t5qtRUdHY0FCxagYcOGMDQ0hIWFBdq0aYP169cjNzdX7nbFJ7VevXpJlfXu3RuAZEZ8Is3YQg8D5zqj47B60OJJ/klogYNO+doYla0DMyEHd98mo+8/dxHwOlFOa4QQQgghhHybsrOz0bZtW1y4cIFNuCyLSCTCnDlzkJWVBQDw9PRUOQfQjBkz2NxCU6dORWZmplSda9euYc+ePQAAV1dXtGnTRmZbxYGesLAw3Lt3T+L/irVt2xa6urpIS0vDoUOHABTlCyprziKgaERQcTu///672tubmpri6tWraNasGQDgypUrGDFiBAQCQZn7RgjlDCqD6tWra6Sd8+fPY+zYsRInutzcXISGhiI0NBS7d+/GxYsXUbduXalti5deLB5OKa5GjRowMjJi61R1/MRE8KyswJGRuI7D5aBFdwc4NLLA9X3PkRwjubS8nVALE7K4CNDn4ylTiIn7HmJyx9r4sU8D6PK0pNojhBBCCCHkWxQSEgIPDw/Y2dnBy8sLbm5uqFWrFoyNjZGeno7Hjx9j7969iIiIAFAU0Fi1apXK7Tdr1gwLFizA+vXrER4eDhcXF/z0009o2bIlcnJycP78eWzevBlCoRA6OjrYsWOH3LbE8wYJBAKJfEHF9PT00K5dO9y+fRsZGRkAypYvSFzTpk3h6emJs2fP4s6dOwgMDETHjh3VasPS0hLXr19H586d8fr1a/j5+WH8+PE4dOiQWgm5CSmJ3j0aUrNmTZmjc5R5/PgxRowYgczMTBgZGeGPP/5AUFAQbt68iSlTpgAoWlawf//+bGRdXPEJS152eRMTE7ZOVSbMzkH0uHGImT4dghJLT4qzsDXE0J9ao1XfWih5MUAHHPTO08GgHB0YioDdgZEYsi0IH5KyZTdGCCGEEELIN4TH46FGjRoAgLi4OGzduhVjx46Fu7s7nJ2d0aVLF8ybN48NBNWrVw83b96Eo6OjWvtZu3YtZs6cCQB4//49pk6dijZt2qBLly7YuHEj+Hw+TE1Ncf78eYlUGSW5urrCwMCAfVwyX1CxksGfsuYLErdkyRL2vjpBMXHVq1fHzZs3Ubt2bQDA0aNHMWXKFDa/ESGlQcGgMli2bBnOnz+Pz58/Izo6WmFUWp65c+ciLy8PPB4P165dw+LFi+Hm5oZu3bph586d+PPPPwEUBYQ2btyo6UOoEhiGweeVK8GP/oicO3cROdALOcHyE9Jp8bhoN7AOBi1sBRNrfanyugIteGfpoV4hF8/iMjHg30CcCoulkzEhhBBCCPmm6enpIS4uDvfu3cPKlSvRt29fODk5wdDQEFpaWjAxMUHDhg0xYsQIHDlyBM+ePWNz3qiDy+Vi69atuHPnDsaMGYOaNWtCV1cXJiYmcHZ2xuLFi/H27VulF+O1tbXZJMyA/BE/4v9vYWGB5s2bq91nedq0aYOePXsCKJre9vDhw1K1Y2dnB39/fzg4OAAA9u7dizlz5misn6Tq4TD0C1ZjoqKi2GjthAkTsH//foX1Q0JC0LZtWwDAtGnTsH37dqk6IpEITZs2xcuXL2FmZobExERoa2uz5cOGDcPJkycRGhoq80RrbGwMc3NzfPz4sQxHJik2NpY9CcXExMDe3l5jbZeH9NNnEF8ygz+HA6sZ02E1cyY4PPmzJQvzBbh36h1e3P0kszxCRwB/fT4KOYCXsy1WeTWFsZ62zLqEEEIIIUSz3r59C4FAAB6PJzNtAiGEfElKc84qr9/fNDKoEvn5+bH3J06cKLMOl8vF+PHjAQDp6ekICAiQKC9+A8nKC/T582dkZ2dX+Q9GbVsbaFlbSf4nwyD5v22I9vYG//Nnudvq6PHQdUxD9J/ZHPomOlLlzQp58M7Uhb2AC78nn9B/cyCexKRr+AgIIYQQQgghhBDNoWBQJQoMDAQAGBoaKhw+KT5ntTgLfsmya9euSW139epVqe2rIsN27eDk5wdDGcna8kLDEDnQC1n+ATK2/H+Oza0w6ldXODlbS5WZMlyMzNZB5zwe4lJyMXRbELbffg+RiAbdEUIIIYQQQgj58lAwqBK9fPkSAFC3bl3wFExVatiwodQ2xbp37w4nJyccOXIET548Yf8/IyMDq1evho6ODjuyqCrjWVrCYecOVFu4ACjxXAszMhA7cyYS1qyBqLBQbhv6xjroM60puk9oBG09ycRzHHDgWqCNcVm6MOcDay+/wvi9IUjMzC+X4yGEEEIIIYQQQkqLlpavJPn5+UhOTgYApXP+zM3NYWhoiJycHMTExEiU8Xg87N69G71790anTp0wcuRIGBsb49SpU4iOjsaGDRvUzt4fGxursDw+Pl6t9r4UHC4XlpMnw6B1a8QtWAh+XJxEearPAeSGhsHur43QqVVLdhscDhq62cC2nhlu+rzEp7fpEuXWIi7GZukiUE+Ae2+T0XfTXWwY3gJdG1Qrr8MihBBCCCGEEELUQiODKon4MvFGRkZK6xsaGgIAsrOllzHv2rUrAgMD0aFDB/j6+mLbtm2oXr06jh07hgULFqjdNwcHB4U3V1dXtdv8kug7O6P2mdMwlrH6QP7z54gcPAQZFy4qbMPESh8D57VE+8F1weVJrkHPAwdd8rUxIlsH/Ew+Ju57iFUXXqBAINTocRBCCCGEEEIIIaVBI4MqSX7+/08f0tGRTkxckq6uLgAgLy9PZrmrqysuX76smc5VAVomJrDb9A/SfX2RsHoNGLHpYaKcHHxauBA594NQY8kScA0MZLbB5XLQsldNODS2wI19L5ASJxmocxBqwTuLi5v6fOy5G4kHkSnYPLIlnKyVB/8IIYQQQgghhJDyQiODKomenh57v1BBnppiBQUFAAB9ff1y61OxmJgYhbeQkJBy70NF4HA4MB85Eo4njkPHyUmqPOPUaUQOG478128UtmNlb4RhP7dGy141AclBQtAFB/3ydOCVq4P3MZkY8G8gTobFgmEouTQhhBBCCCGEkMpBwaBKYmxszN6XNfWrpJycHACqTSkrK3t7e4U3Gxubcu9DRdJr0AC1T56A6aBBUmWF798javhwpB3zVRjA0dLmov3guhg0vyWMLfWkyuvxtTAxSw81chgsPBGOH3yfICufr9HjIIQQQgghhBBCVEHBoEqip6cHS0tLAMoTNqelpbHBIAcHh3LvW1XENTCA7ZrVsP1zndS0MKagAJ9XrEDcvPkQZmYqbMe2njlGLnVFw/bSATNDhoMhObrolauNS48/of/mQDyJSdfkYRBCCCGEEEIIIUpRMKgSNW7cGADw7t07CAQCufVevXrF3m/UqFG596sqM/X0RO3Tp6DbWPp5zrpyBZGDBiPv6VOFbejo89B9fCP0nd4MekbaUuUtCnnwztKFIDEfQ7cFYfvt9xCJaNoYIYQQQgghhJCKQcGgStSxY0cARVPAwsLC5Na7ffs2e79Dhw7l3q+qTsfREY7HjsF83DipMn5cHKJGj0HKnr1gRCKF7Tg5W2PUsrZwbGYpVWYm4mJUtg7a5Wjhz0uvMH5vCBIz82W0QgghhBBCCCGEaBYFgyqRl5cXe3/fvn0y64hEIhw4cAAAYGZmhq5du1ZE16o8ro4OaixZDPv/toJraipZKBAgcf16xEyfDkFqqsJ2DEx00G9mc3Qd2xA8XS3JfYADtwJtjM3WxcvXKei76S4CXidq+lAIIYQQQgghhBAJFAyqRK6urnB3dwcA7NmzB/fv35eqs3HjRrx8+RIAMHfuXGhrS087IuXHuFs3OPmdgb6Li1RZzp27iBzohZzgBwrb4HA4aNzRFiOXtkENJ1Op8upCLsZn6cIxRYiJex9i1YUXKBAINXYMhBBCCCGEEEKIOF5ld+BrFhgYiHfv3rGPk5OT2fvv3r3D/v37Jep7e3tLtbFp0yZ06NABeXl56NWrFxYvXoyuXbsiLy8Px44dw86dOwEA9evXx4IFC8rlOIhi2jY2qHXAB0lbtyJl+w5AbFUxQVISPk6cCKsZM2A1cwY4PPl/UqbWBhi00AWPr0Uj5FykRJ4gHjjolq+DOgIhjt+JwoPIFGwe2RJO1uW/ehwhhBBCCCGEkKqFwyhaL5so5O3tDR8fH5Xry3uqz58/j7FjxyJTzkpV9evXx8WLF1G3bt1S9VPTYmNj2VXNYmJiYG9vX8k9qjg59+8jbtEiCJOSpcr0W7eC3YYN0K5RQ2k7SR+zcH3fC6TF50iV5YPBTQM+ogyB37yaYoiLHTgcjkb6TwghhBDyLXn79i0EAgF4PB7q1atX2d0hhBCFSnPOKq/f3zRN7Avg4eGBp0+fYt68eahfvz4MDAxgZmaG1q1bY926dXj8+PEXEwiq6gzd3OB05gwM/5f8W1xeaBgivQYhKyBAaTvWNY0xfHFrtOjuIFWmBw765+qgezoXS33D8YPvE2Tl8zXSf0IIIYQQQgghhEYGEbVV5ZFBxRiRCKl79yLxn02AQCBVbjFhPKwXLABXR0dpW7GvUnHT5yWy0wqkyrI5DC4bFEJUXQ+bR7WEs4OZJrpPCCGEEPJNoJFBhJCvCY0MIl+VJk2aSNy6detW2V2qdBwuF5aTJ8Px0EFo29pKlaf6HED0qNEojI5W2pZ9QwuM/NUV9dtWlyozYjgYlqOL+rF8jPwvCNtvv5fINUQIIYQQQgghhKiLgkGkysgqzNJ4m/rOzqjtdwbGvXpJleU/f47IwUOQceGi0nZ0DbTRc2IT9J7SFLqG0kmoWxbyMCZDBz7nX2P83hAkZuZrpP+EEEIIIYQQQqoeCgYRpZ4/fy5x8/f3r+wuqS2zMBNefl74Pfh3jQeFtExMYLfpH9RYvgycEtPCRDk5+LRwIT4tWQJRbq7Stuq2qoZRv7ZFzcYWUmUWIi5GZ+tCFJGOfv/cRcDrRI0dAyGEEEIIIYSQqoOCQaRK+DvsbyTmJcL3tS+8znrB/6NmA1ocDgfmo0bB8cRx6Dg5SZVnnDqNyGHDkf/mjdK2DM10MWBOC3QeVR88bck/US446FCgjT4JHMzfHYpVF16gQCDU2HEQQgghhBCiipycHGzfvh39+vWDnZ0d9PT0oKurC2tra7Rp0waTJk3Crl27EBMTI7Wtt7c3OByO1I3L5cLMzAwtWrTArFmz8OTJk3Lpe1RUlMz9czgc6OnpwdbWFr169cKmTZvkrvgsztHRERwOB46Ojkrrzp8/n91XvXr1ZD4/hFQESiBN1Pa1JZB++PkhJl2dJPX/PWv1xC+uv8DawFqj+xPl5uLzqt+RceaMVBlHVxfVFy+G2fBhKi0Xn56Qi+v7XiAxSvpDiA8Gt/X54Nc2wOZRLnCyNtJI/wkhhBBCvhaUQLpy3L9/HyNHjsTHjx+V1q1evTo+f/4s8X/e3t7w8fFRui2Xy8XPP/+MP/74o9R9lSUqKgq1a9dWqa6DgwP8/Pzg4uIit46joyOio6NRq1YtREVFyazDMAy+//57bNmyBQDQsGFD3Lx5E7Yy8o+SbxclkCakAn3O+Qx9nr7U/1+Pvo6BfgNx6s0paDImyjUwgO2a1bD9cx24BgYSZUxBAT4vX464efMhzFI+Xc2sugGG/OgCV4/a4JT4a9UGBz3ydNDwTQFG/BOIk2GxGj0OQgghhBBCSnrz5g169+7NBoI8PT1x4MABBAcH49GjR7h27RrWr1+PXr16QVtbW2l7V69eRUREBCIiIhAeHo5r165h7ty54PF4EIlEWL16Nf77779yO56BAwey+4+IiMCdO3ewc+dONGrUCEDRj+/+/furNEJIHoZhMH36dDYQ1KRJE9y6dYsCQaRS0cggoravbWQQAMRmxWJV8CoEfQqSWd6mRhssd1uOWia1NLrfgshIxC1YgIIXL6XKtO3tYffXRug3b65SWwlRmbix7wXSE6RzD+VxGFzX56Oha3X87tUUxnrKP3gJIYQQQr52NDKo4g0bNgwnT54EAOzbtw/e3t5y6yYlJeH48eOYNWuWxP+LjwyKjIyUOb3q/Pnz8PT0BABYW1sjPj4eWlpaGjkG8ZFBEyZMwP79+6Xq8Pl8dOrUCcHBwQCA9evXY+HChTLbUzQySCQSYfLkydi3bx8AoEWLFrhx4wasrKw0cizk60IjgwipYPbG9tjeYztWd1wNM10zqfKHnx9i8NnB2B2xG3wRX2P71a1dG47HjsF83DipMn5sLKJGj0HKnr1gRCKlbVV3NMHwJW3QrIv0H78+w4Fnrg4E95Lg9fddPIlJ10T3CSGEEEIIYQmFQly8WLRSbuvWrRUGgoCiIE7JQJCqPDw84O7uDqAoqPTo0aNStVNa2tra+P3339nHN27cULsNoVCICRMmsIGgVq1aISAggAJB5ItAwSBSZXA4HHjU8cBZr7MY4DRAqrxQVIhNjzZh5IWReJb8TGP75erooMaSxbDfugVcU1PJQoEAievXI2b6dAhSU5W2pa2jhU4j68Pj+xYwMNWRKm/M56HnRwbzNwdj++33EIlo4B8hhBBCCNGMpKQk5OXlAQDq1q1b7vtzdXVl70dHR7P3P3z4gI0bN8LDwwOOjo7Q19eHvr4+atWqhREjRuDKlSsa2X+zZs3Y++omehYIBBgzZgwOHToEAGjXrh1u3rwJc3NzjfSNkLKiYBCpciz0LLDGfQ229dgGW0Ppebpv0t5gzKUxWP9wPXL5ypeDV5Vx9+5w8jsDfRnJ53Lu3EXkQC/kBD9Qqa2ajS0xallb1G1dTXo/DAdDsnTw4NR7eO9+gMTM/DL3nRBCCCGEEB2d/78Y+fKldBoETRPPOSQUFq2gGxkZiTp16mDhwoW4cOECoqOjkZ+fj/z8fHz8+BHHjx9H3759MW7cOAgEgjLtX/x4Vcl/VIzP52PEiBHw9fUFAHTs2BHXrl2DackLw4RUIgoGkSqro11HnBl4BmMbjQW3RHZmESPCgRcHMPjcYNyLu6exfWrb2KDWAR9YTp8GlFhNTJCUhI8TJyJp879gVPjg0jPURu/JTdHzu8bQ1pOeP926kIc6T7Ixdv1dBLxO1NgxEEIIIYSQqsnCwgK1ahXl2AwPD8e6desgUiHdQWlFRESw94uTLQuFQujo6MDDwwObN2/GjRs38OjRI9y4cQP//fcfmjRpAgA4dOgQVq1aVab9iwe8VFk2HgAKCwsxdOhQnD59GgDQtWtXXLlyBcbGxmXqCyGaRgmkidq+xgTSykQkRWD5/eV4m/ZWZrmHkwd+bPMjzPU0N6wz5/59xC1aBGFSslSZQevWsN2wHto1aqjUVnZaPm7sf4G41+lSZUIwCNIToEk3eyzq1xC6PM0k3iOEEEIIqWwqJWMViYA85dPxvyn6FgC3fK77b9y4USKRsqOjIzw9PdG+fXu4urqqtGS7Kgmkw8PD0apVKwiFQhgYGCAlJQV6enrIyclBZmYmbGxsZLbNMAwmTZqE/fv3w9DQEHFxcVIjclRJIA0UrZR2/vx5AMDBgwcxduxYmfWKE0jb2trC2dkZly5dAgD07NkTZ8+ehb6+9MrGpGr6khJIUzCIqO1bDAYBAF/Ex/5n+7E9fDsKRYVS5ea65ljkugj9a/cHp8SontISJCfj008/I+ee9OgjLTMz2KxZDeOuXVVqixExeHorFvdOvwMjkP6z/qQlwtvaOlg7wQVO1kZl7jshhBBCSGVT6YdVTjKwvk7Fdqyy/fgeMCyfJMUikQhTpkzB3r17ZZZXr14dXbp0wZgxYzBgwACZ35vlBYMYhkFCQgIuXLiAX375BcnJRRdNf/75Z6xZs0blPqampqJatWoQCoU4efIkhgwZIlGuKBiUnp6Oly9fYs2aNWwgyM3NDbdv35Y7Vaw4GCSuc+fOuHLlCvT09FTuN/n2fUnBIJomRsj/aHO1MaX5FJzyPIXW1VtLlacVpOGXu79gxs0ZiMuO08g+eVZWcNi1E9UWLgB4PIkyYXo6YmfMRMKaNRAVSgenSuJwOWjRzQEjl7jC3N5QqtxWyIXbOz5++jMIJ0JjQHFgQgghhBCiLi6Xiz179uDatWvo06cPeCW+wyYkJMDX1xeenp5wdXXF+/fvFbZXu3ZtcDgccDgccLlc2NjYYMqUKWwgqH///vjtt9/kbs/n8xEbG4uXL1/i2bNnePbsGT59+gRLS0sARSOMFPHx8WH3z+FwYG5ujvbt2+P8+fPQ1taGt7c3rly5olLOIPHAV0REBN68eaN0G0IqCwWDCCnB0dQRe3rvwXK35TDWlp7bey/uHgadHYSDLw5CKBKWeX8cLheWkyfD8dBBaNtKJ7RO9TmA6FGjUVjiaoM8FjaGGPFzG7TqWwsocSFGBxx0yeLh/v5XWHDgEbLy+WXuPyGEEEIIqXp69uyJy5cvIyUlBZcuXcLKlSvh4eEhMSUrNDQU7u7uiI+PV6ttHR0ddOjQAT4+PmxQRhyfz8fWrVvRrl07GBkZwcHBAY0bN0azZs3YW2JiUc7M4qBSadSrVw/z5s2DiYmJSvVr1qyJH3/8EUDR6KSePXvi1atXpd4/IeWJgkFEqSZNmkjcunXrVtldKndcDhdD6w/FWa+z6Fmrp1R5niAPfz78E2MvjcXr1Nca2ae+szNqnzkN457S+8t//hyRg4cg48JFldrS4nHRbmAdDPmxFfQtdKXK6wi0YBOcjqlr7uJJTHpZu04IIYQQQqooExMT9O3bF8uWLcO5c+eQkJCAvXv3skuox8fH49dff5W7/dWrVxEREYGIiAg8f/4c0dHRyMrKQmBgIMaPHy81zSw1NRVubm6YPXs2Hjx4gEIlI+jz8vIUlg8cOJDdf3h4OC5fvoy5c+dCT08PL168QJcuXfD6terf9//880/Mnj0bAJCYmIgePXrgw4cPKm9PSEXhKa9CSNVlbWCNv7r8hZvRN/HHgz+QlJckUf4s5RlGXhiJiU0nYlqLadDVkg68qEPL1BR2mzch/dgxJKxZC0bsw02Uk4NPCxciJ/g+aixeDK6BgdL2ajiZYuyytrhz4i1e35O8ImPAcNAxAdi5/iGaeThiWve64HI1kwuJEEIIIeSLoW9RlEOnKtG3qLRd6+rqYuLEibC1tUWfPn0AAKdPn8bOnTvBlZHUun79+iqv1AUAc+fORVhYGADAy8sLkyZNQvPmzVGtWjXo6emxwaOaNWsiJkZ5agQzMzM0bdqUfdy8eXP06dMHHh4e6NOnD9LS0jB69GiEhIRAS0u1hVg2b96M3Nxc7N27F3FxcejevTvu3LnD5n0h5EtAwSCi1PPnzyUeiyewqiq61+oOVxtX/BP2D46/OS5RJmAE2BWxC9ejr2OZ2zK0qdGmTPvicDgwHzUK+i1bIm7efBRGRkqUZ5w8hbzHT2D391/Qq19faXs6ejz0GNcIdZ2tcWXvcwjzJKe2NSnUQtrpaMx+loQV37mgmgkluSOEEELIN4TLLbdkykS+3r17w8HBATExMUhLS0NKSgqsra3L1GZmZiZ8fX0BAGPGjMGhQ4fk1k1LSyvTvrp37465c+di48aNePToEfbv34/vvvtOpW05HA527dqF/Px8HDlyBFFRUWxAqIaKqwUTUt5omhghKjLWMcavbr9if5/9cDRxlCqPyozCpKuTsCJoBTILM8u8P72GDVH71EmYDhokVVb4/j2ihg1Hmu9xlRNBOzazwoRVbrBpIn2lyJThovHrPCz7LRA3n38uc98JIYQQQgixFcuHqYnVeN++fQs+vyjn5YgRI+TWe/XqFbKzs8u8v8WLF7P5glauXKl0Spo4LpcLHx8fDB48GEBR33v06IGUlJQy94sQTaBgECFqalW9FU56nsTU5lPB40gPrjv19hS8/LxwI/pGmffFNTCA7ZrVsP1zHTglpoUxBQX4vHw54ubNhzArS6X29I10MGh2C3Qd3xAMT/IDmQMOmmVzcXfrM6w+Eo4CQdmTYxNCCCGEkKopNzcXL168AFCUV6h4da+yEAgE7P2cnBy59bZv317mfQGAhYUFZs2aBaBoSW8fHx+1tufxeDh69Cj69u0LoGjGRa9evZCRkaGR/hFSFhQMIqQUdLV0MaflHPh6+KK5VXOp8qS8JMy7NQ9z/eciISehzPsz9fRE7VMnodu4kVRZ1pUriBw0GHlPn6rUFofDQeP2thi/sh2MHKSXoLcWcWF0Jxk//nYX7xNUCzIRQgghhJBvX3Z2Ntq2bYsLFy5AJBLJrScSiTBnzhxk/e+Cpaenp0ZGBtWtW5dtx8fHR+YI+fPnz2PLli1l3lexefPmweB/F2XXrl0LoVC9C6Y6Ojo4ffo0uwjPo0eP0KdPH42MXCKkLCgYREgZ1DevjwN9D+Bn15+hz9OXKveP8YfXWS8cf30cIkb+B6YqdGvXhuOxYzAfN06qjB8bi6jRY5Cydx8YBR/M4kws9TH+F1e09HCEqMRnsxY4aJAowt7fH8A3IFLlqWiEEEIIIeTbFhISAg8PD9SsWROzZ8/G4cOHERgYiPDwcNy+fRv//PMPnJ2dsXfvXgCAqakpVq1apZF9W1paol+/fgCAK1euoFevXjh9+jTCwsJw+fJlTJ48GYMGDYKTk1OZ8xMVs7a2xpQpUwAAHz58wJEjR9RuQ09PD+fOnUOHDh0AAMHBwRgwYIDSlc4IKU8UDCKkjLS4WhjTaAz8BvrB3c5dqjybn41Vwasw8cpEfMgo27KSXB0d1FiyGPZbt4BraipZKBAg8c8/ETN9OgSpqSq1x+Fy0L6/E0YtaQOuuY5UuS2fi0++H7DsnwfIzFN9jjQhhBBCCPn28Hg8NgFyXFwctm7dirFjx8Ld3R3Ozs7o0qUL5s2bh4iICABAvXr1cPPmTbVWC1Nm27ZtqFmzJgDgxo0bGDJkCFq3bo1+/fphz549sLOzg5+fHzuaRxMWLlwIHZ2i78pr1qxROCpKHkNDQ1y6dAmtW7cGANy+fRuDBg1SKw8RIZpEwSBCNMTWyBZbu2/FOvd1sNCTTtL8KPERhp4bih3hO8AX8su0L+Pu3eF05jT0XVykynLu3EXkQC/kBD9QuT0re2NM/a09anaogZJjgHTAQY3XuVizJBAhr5LK1G9CCCGEEPL10tPTQ1xcHO7du4eVK1eib9++cHJygqGhIbS0tGBiYoKGDRtixIgROHLkCJ49e4ZWrVpptA8ODg549OgRfvzxR9SvXx+6urowNTVFixYtsHz5cjx58gSNGzfW6D7t7e0xYcIEAMDLly9x6tSpUrVjYmKCq1evonnzojQTV69exYgRIyRyIRFSUTgMzf8gahJfWj4mJgb29vaV3KMvT3p+OtaHrse59+dkltc1q4uV7VeiubV0viF1MAIBkrZsQcqOnUDJP2UOB1YzZsBq5gxweNKJruWJfJmCczsiwMuXvuKRw2Fg3LEapo1qCi637PO+CSGEEELK4u3btxAIBODxeKhXr15ld4cQQhQqzTmrvH5/08ggQsqBmZ4Z/uj4B3b03AE7Izup8nfp7zD20lisDVmLHL78lRCU4fB4qPbDD6i5Zze0rKwkCxkGyf/9h4/eE8H/rPpy8bUbWWLG2o4wamQqVWbIcCC6m4QVy+4iLrH0/SaEEEIIIYQQUnkoGERIOWpv2x6nPU/Du4k3uBzJPzcGDA6/PIxBZwfhTuydMu3HsH17OPmdgeH/ktKJyw0NRaTXIGQFBKjcno4eDxPmtoLruPoolDGoqHqyAAdXBuOif2RZuk0IIYQQQgghpBJQMIhUHfc2A8nvKny3BtoGWNB6AY70P4KGFg2lyuNz4jHr5iwsurMIKXkppd4Pz8oKDrt2wnrBfEBLS6JMmJ6O2BkzkbBmDRg1ktS16WCPyb+3h9BGT6rMWMjBh+MfsGH9A+Tmly0HEiGEEEIIIYSQikPBIFI1vL0BXP8V2N4BuLcJEFZ8krYmlk1wpP8R/ODyA3S1dKXKL0dexsCzA3H23dlSL+XO4XJhNWUKah06CG1bW6nyVJ8DiBo1GoXR0Sq3aWymhznL3FCztz34HMl+ccGB/vscbPr5LsKfJZaqz4QQQgghhBBCKhYFg8i3Lz8TOD+36L4gH7i+DNjTA0h4XuFd0eZq47tm3+G052m41nCVKs8oyMDSe0sx9fpUxGTFlHo/Bi1bovaZ0zDu2VOqLP/5c0QOHoKMCxdVbo/D4cBjUH0M+aUNso21pMpN8oFbWyLgsy8CIqH6S20SQgghhBBCCKk4FAwi3757m4DMWMn/+/QY2NEZCFgDCFSfNqUpNU1qYnev3fit/W8w1jGWKg+OD8bgs4Ph89wHAlHpRjFpmZrCbvMm1Fi+DBwdHYkyUU4OPi1ciE9Ll0KUl6dymw41TbBwrTt0XCwgLLEIPQ8cZD9IwsYlgYiPzypVnwkhhBBCCCGElD8KBpFvX8d5QNvpAEoshS7iA7fXAju7AHFhFd4tDoeDQfUG4ZzXOfR27C1Vni/Mx4bQDRhzaQxeprws9T7MR42C43Ff6NSuLVWecfIUIocNQ/6bNyq3qaXFxZSpzmg3rQkydaTLDdIFOPbbQ1y79L7U090IIYQQQgghhJQfDkO/1ogSTZo0kXjM5/Px9u1bAEBMTAzs7e0ro1vq+xgMnJ0NpLyVLuNwAbfZQNfFgLZ+xfcNwK2YW/g9+Hck5CZIlWlxtDChyQTMaDEDejzpZM6qEOXm4vOq35Fx5oxUGUdXF9UXL4bZ8GHgcDgytpYtK4ePrZtDYRSdC07JYBsAOBhg4hwXGJjIiBoRQgghhJTR27dvIRAIwOPxUK9evcruDiGEKFSac1ZsbCwcHBwAaPb3N40MIlVHzXbA9ECg43yAUyLvDSMCgjYD2zoA0UGV0r0uDl3gN9APIxuMlAqsCBkh9j7bi8HnBuNB/INStc81MIDtmtWw/XMdOAYGEmVMQQE+L1+OuPnzIcxSfYqXsaE2fv7FDTYDayGLKyOuHJOLHUvuITwkvlR9JoQQQgghhBCieTQyiKitvCKTFerTk6JRQgkRssvbTAF6LAd0pfP5VIQniU+wPGg5PmR8kFk+qO4gLGi9AKa6pqVqvyAyEnHzF6DgpfT0M217e9j9tRH6zZur1WZkfCb2bH4EmzTZCaTNm1lg2ORm0NaVTkBNCCGEEFIaNDKIEPI1oZFBhFQ2W2dgagDQdSmgJWMK08NdwH9uwLubFd41AHCu5owTHicws8VM8Lg8qfIz787A088TV6KulCovj27t2nD0PQbzsWOlyvixsYgaPQYpe/eBEam+MlhtGxOs+L0T+G0tkMeR7lNaRCq2LQ5EzLs0tftLCCGEEEIIIURzKBhEqi4tbaDzj8C0u4Bda+nyjBjg0GDAbxaQV/EBDB0tHcxwnoGTHifRwrqFVHlqfip+vP0jvvf/Hp9zPqvdPldHBzWWLoH91i3gmpYYYSQQIPHPPxEzfToEqakqt8nT4uKHic5wnd4EcXrSASGtHCH8NjzGlWOvIKQl6AkhhBBCCCGkUlAwiJBqDYHvrgG9/gB4MpJHPzkEbG0LvLxQ8X0DUMesDg70PYAlbZfAUNtQqvxW7C14nfXC0VdHIWLUD7AYd+8OpzOnoe/iIlWWc+cuIr0GIedBiFptdmpRAwt/74gYJz0UlliCngvg/a1P2L38PlLjs9XuLyGEEEIIIYSQsqFgECEAwNUC2s8GZtwDHN2ly7MTAN8xwAlvIDup4rvH4WJkw5HwG+iHLvZdpMpz+DlY/WA1JlyegPfp79VuX9vWFrUO+MBy2jSgxGpigsREfPT2RtLmf8EIBCq3aWGkizU/usF8UE3E86SDVILkAhxeFYKQa1FgRJS6jBBCCCGEEEIqCgWDCBFnWQcYfw4Y8DegIyN59PMzwFZX4OkJoBJyr9cwrIHN3TZjQ+cNsNCzkCp/kvQEQ88PxbYn21AoLFSrbQ6Ph2rzfkDNPbuhZWUlWcgwSP7vP3z0ngj+Z9WnpHE4HHj3rodRv7TBC0sOhCVHCYmAh6c/4PD6UGSnFajVX0IIIYQQQgghpUPBIEJK4nKB1pOAWcFAvV7S5XmpwOnJwNGRQOanCu8eh8NBb8feOOd1DoPqDpIqF4gE+C/8Pww7PwxPEp+o3b5h+/Zw8jsDww4dpMpyQ0MR6TUIWQEBarXZxM4UG5Z3QnJbc6RwpUcJZURmwWfZfbwOUT/3ESGEEEIIIYQQ9VAwiBB5TO2B0ceBQTsAfXPp8jdXinIJhe2vlFFCprqm+K3Db9jdazccjB2kyj9kfMD4y+PxR/AfyC5ULzcPz8oKDrt2wnrBfEBLcil4YXo6YmfMRMKaNWAKVR99pK+jhd8musBlSiNEGAilK/BFuLH3Bc5te4r8HL5a/SWEEEIIIYQQojoKBhGiCIcDtBgJzAoBGg+ULi/IBM7PBQ54AqmRFd8/AG1t2uK052lMajoJWhzJwA0DBsdeH8PAswNxK+aWWu1yuFxYTZmCWocOQtvWVqo81ecAokaNRuHHj2q1O6ClHZb+2gFPnLSRJWMJ+pjwZPgsv4+YF6qvYkYIIYQQQgghRHUUDCJEFUbVgOEHgOEHAcNq0uWRd4Bt7YHgbYBIxqiXcqbH08O8VvNwtP9RNLZsLFWemJuIOf5zsPD2QiTnJavVtkHLlqh95jSMe/aUKst//hyRgwYj4+JFtdq0NzfAtgUdoNPfFi+1pZNSC7IFOLf5CQKOvga/sOKfT0IIIYQQQr4Wjo6ORbk6vb3LbR/e3t7gcDhwdHQst32QikXBIELU0dgTmPUAaDFKuoyfC1z5GdjbB0h6XfF9A9DIshEO9zuMha0XQk9LT6r8atRVePp54szbM2DUmNqmZWoKu82bUH3Zr+Do6EiUiXJy8GnBQnxauhSivDyV2+RpcTF/QCOM+aEV7loyyJcxSujF7Tgc/u0BEqMzVW6XEEIIIYRUHIFAgFOnTmHq1Klo1qwZqlWrBm1tbZiamqJu3boYNGgQ1q9fj8jIyhlFT4owDINz585h1KhRqFevHoyMjMDj8WBmZoamTZti2LBhWL9+PcLDwyu8b127dgWHwwGHw0GvXjJytsrRpUsXdjvxm5aWFiwsLNCqVSvMnTsXz58/V9rWihUr2O1v3bqlsG5gYCBMTEzA4XDA4/Fw6NAhlfv8JaFgECHqMrAABm0HxpwETOyly2NDgO0dgTsbAGHF577hcXmY0GQCTg88DTcbN6nyrMIsLAtahinXpuBjpupTvDgcDixGj4bjcV/o1K4tVZ5x8hQihw1D/ps3avXXrY4lti12x3sXY0TxpEcB5STn48TaUDy8GAmRUDr5NCGEEEIIqRznzp1Do0aNMHToUOzatQvPnj1DUlISBAIBMjMz8f79e/j5+WHRokVwcnLCgAED8OzZs8rudpWTkJCATp06YeDAgTh27BjevXuHnJwcCIVCZGRk4Pnz5zh58iQWLVoEZ2dnvHr1qsL6Fh0djdu3b7OPb968iU+fyrZIj0gkQlpaGh49eoTNmzejRYsWWLt2bVm7CgC4desW+vTpg6ysLPB4PBw5cgRjx47VSNsVjYJBhJRWvZ7AzPtA6++ky4SFgP8qYFc3IP5pxfcNgIOxA3b03IE/Ov4BU11TqfIHnx9g8LnB2BOxB3yR6kErvYYNUfvkCZgOkl7JrPDde0QNG4403+NqjTwyN9TB1slt0HR0Xdwy5INfYgl6MEDI+Uic+DMM6Qm5KrdLCCGEEELKx++//w4vLy+8e/cOQNEojQ0bNuDatWsICwvD3bt3cfz4ccyePZudWnTx4kVs2bKlEntd9RQWFqJnz54IDAwEALRs2RKbN2/GnTt38PjxY9y+fRvbt2/H6NGjYWoq/ZuhvB08eBAMw0BXVxc8Hg8ikahUI20iIiLYW1hYGE6cOIExY8YAAIRCIX755RecOHGiTH29ceMG+vXrh5ycHGhra+P48eMYPnx4mdqsTBQMIqQs9EyAAX8B3hcBCyfp8s9PgV1dgZurAEFBhXePw+HAs44nzg48i361+0mVFwgL8M+jfzD64mg8T1E+fLIY19AQtmtWw/bPdeAYGEiUMQUF+Lx8OeLmz4cwK0utvk5oXxu/LXTDLUctfNaSHgWUHJ2Fo7+H4NmdOLWCTYQQQgghRHP27t2LX3/9FQzDoHr16ggICEBAQAAWLFiAnj17wsXFBR07dsSwYcPw77//4t27dzh06BBq1qxZ2V2vcnbt2oWIiAgAwMSJExEaGoo5c+bA3d0dzs7O6NSpE6ZNm4bDhw8jISEB+/btg5mZWYX17+DBgwCAAQMGsFPEiv9PHU2bNmVvLi4uGDp0KA4dOoTNmzezdVauXFnqfl6+fBkeHh7Iy8uDrq4uTp8+jUEyLo5/TSgYRIgmOHYEpt8D3GYDnBJ/ViIBcHcDsN0diHlYKd2z1LfEuk7rsLX7VtQwrCFV/ir1FUZfHI2NoRuRJ1A974+ppydqnzoJ3UaNpMqyLl9B5KDByHuq3sioRjYmODrfHYKu1RCky4eoxCghEV+E20de48KWcORkVHyAjRBCCCGkKouJicGsWbMAACYmJggMDESXLl0UbqOlpYUxY8YgPDwc/fv3r4BekmJnz54FAPB4PPz111/gcuWHAHR1deHt7Y0aNaR/L5SH4OBgvPlfiokxY8aw062ePXuGR48eaWQfs2bNYoOQz58/x+fPn9Vu4/z58/Dy8kJ+fj709fVx9uxZDBgwQCP9q0wUDCJEU3QMgN5/AN/dAKylgyNIfg3s6Qlc+QUozKn4/gHoZN8JfgP9MKbRGHDAkSgTMSLsf74fg84OQtCnIJXb1K1dG47HjsJcxlxZfmwsokaPQcrefWBEquf70dfRwpqhzTFmcnOctRAijSu97cfnqTiy8gHeP0pUuV1CCCGEEFI2f/31F/Lz8wEAf/zxB+rWravytmZmZvDw8JBb/vnzZyxZsgStW7eGhYUFdHV14eDggOHDh+PGjRtyt4uKimKT/+7fvx8AcP36dXh4eKBGjRrQ1dVF7dq1MWPGDMTGxqrU14CAAEyYMAFOTk4wMDCAiYkJmjVrhh9//FFhThvxRMQAkJGRgVWrVqFly5YwMzOT6CMA5OTkwNfXF5MnT4azszNMTU2hra0Na2trdO7cGRs2bEB2drZKfZbl48eiHKFWVlYaHfGTnp6OZcuWoUmTJjA0NISZmRk6deqEw4cPq9zGgQMHAADm5ubo378/vLy8YGxsLFFWVlwuF02aNGEfx8TEqLX96dOnMWTIEBQWFsLAwAAXLlxA7969NdK3ykbBIEI0zb4VMO020PkngMsrUcgAwf8VLUMfeadSumeobYifXX/GoX6HUNdM+sM7LjsO065Pw5LAJUjPT1epTa6uLmosXQL7Lf+CW3KusUCAxD//RMyMGRCkpqrV137NbLD/x4542lgfj3Wkl6AvzBXgys5nuLHvBQrypMsJIYQQQojmMAzDTuExNjbGxIkTNdb24cOHUbduXaxevRphYWFIS0tDYWEhYmNjceLECfTs2ROTJ0+GQKD8O98vv/yCXr164cKFC0hISEBhYSGioqKwfft2uLi44OXLl3K3zc/Px6hRo9CtWzccOHAAkZGRyMvLQ1ZWFp49e4YNGzagfv36OH/+vNJ+vH37Fs7Ozli2bBmePHmCjIwMqTr9+/fHyJEjsWfPHoSHhyMzMxMCgQDJycm4c+cOfvzxRzRv3rzUSZ11/rcScEJCAlLV/C4uz+vXr9GyZUusWrUKL168QG5uLjIyMnD37l2MHTsWs2fPVtpGYWEhfH19AQDDhg2Djo4O9PX1MXjwYADA0aNHVXqtVaEjthqytra2ytv5+vpixIgR4PP5MDIywpUrV9CtWzeN9OlLQMEgolSTJk0kbt/SH0C54ekCXRcDU28BNs7S5WlRgI8HcH4ukC/9oVARmls3x/EBxzHbeTa0udInxXPvz2Hg2YG49OGSyvl5jHv0gNOZ09B3cZEqy7l9B5Feg5DzIEStftqbG+DIDDc0HlALp4wKkC1jCfrXDz7j2G8PEPc6Ta22CSGEEEKI6p49e4aUlBQAgLu7OwwNDTXS7vHjxzFu3Djk5OTAyckJf/31F65cuYKwsDCcOnUK/foV5b7cs2cPFi1apLCtXbt2Ye3atejcuTOOHDmC0NBQ3LhxA+PHjwcAJCUlYdKkSTK3ZRgGQ4cOxbFjxwAAHh4eOHjwIO7du4f79+9j06ZNqFmzJnJycjB06FCEhoYq7MvQoUMRFxeHOXPm4Pr16wgNDcXRo0fRoEEDto5AIECzZs2wZMkSnDlzBg8ePEBwcDB8fX0xcuRIcLlcREZGstOU1OXyv+/lDMNgypQpZRplBAC5ubnw8PBASkoKli5dilu3biE0NBS7du2CvX3RSstbt27F1atXFbZz4cIFNjglvhpX8f3ExERcuXKlTH0tJh78q1WrlkrbHD58GGPGjIFAIICJiQmuXbsGd3d3jfTnS8FhKAsrUUJ8WB0A8Pl8vH37FkDRMLviP3oih1AA3P8XCFgDCGXkuDG2BTz+AepX3nDDDxkfsDJoJR4lyp6b627njl/b/QobIxuV2mMEAiT9uwUpO3cCJU8xHA6sZsyA1ayZ4GhpqdXP++9T8PORJ2iRKEIDvoxtOUCL7g5oN9AJPG312iaEEELI1+ft27cQCATg8XioV6+ezDoiRoT0gvSK7VglM9M1A7dkHksNOHz4MPtjfenSpVi1alWZ20xOTkbdunWRkZGBSZMmYceOHeDxSo6uB5YsWYLVq1eDy+XixYsXEgGVqKgo1K5dm308ZcoU7Nixg52qJf7/u3fvBgA8evQILVu2lCjftWsXpk6dCm1tbZw7dw59+vSR6kdaWhrc3d3x/PlzdOjQgV2lq9iKFSvYRMVcLheXL19mEyPL8vbtW7nvXaBoBavevXtDJBJh9+7d+O476ZWMHR0dER0djQkTJkhMQQOAkJAQuLm5QfS/lA3FU/Xc3d3Rtm1bNGnSBFoqfCf39vaGj48PAMDU1BT37t2T+p347t07NGvWDPn5+fD09GTzFcni5eWFs2fPwtHRER8+fGBfK5FIBAcHB3z69AnDhg3D8ePH5bbRpUsXdll6eWGN4mleANC9e3e50w3FX7eJEyfCx8cHIpEI5ubmuHr1Ktq0aSO3H+pQ5ZxVUmxsLBwcHABo9ve39F8ZISU8fy65ypT4m5GoQIsHdJwHNBwAnJ0NxARLlmd9Ao4MB5qPAPqsBQwsKryLTqZO2NdnH06+OYm/w/5GNl/yisHduLsYeHYg5rrMxcgGI6HFVfyBweHxUG3eDzBs64q4RT9BmJz8/4UMg+T//kNuSAhsN6yHthoJ6tzqWMJvvjsWnQzHxfAU9MjVhq547iMGCL8Rg5jnqegxqTGsHYxVbpsQQggh36b0gnR09u1c2d2oULdH3IaFnua/UyaLfaeztraWW08kEuHFixdyyxs0aMBO19m2bRsyMjJgZ2eH//77T2YgCChaCcrHxwdxcXE4cOAA/vjjD5n1bGxs8O+//0oFggBg4cKFbDDo7t27EsEghmGwbt06AMD3338vMxAEFOW3Wb9+Pfr164d79+4pDOZ4e3srDAQBUBoQ6NGjBzw9PeHn5wc/Pz+ZwSBFXF1dsWPHDsycORN8Ph/p6ek4ePAgO93P0NAQ7du3x7BhwzB69GiVRnutWrVKKhAEAHXr1oWXlxeOHTsmFSQTl5KSgkuXLgEARo8eLfFacblcjB49Ghs2bMD58+eRnp6udq6jwsJCfPjwAWfOnMHvv/8OADAwMJD7nilp3759AAB9fX3cvHlTKmj4raBpYoRUFKt6wMTLQN8/AW0ZJ9mnvsBWV+C5X4V3DQC4HC6GNxgOv4F+6OYgPRUwT5CHtSFrMf7yeLxNe6tSm4bt28PJ7wwMO3SQKssNDUWk1yBkBQSo1U9zQx3sHN8aI4c1xGHzQnzkCaXqpMbn4OTaUIRdiYJIRIMfCSGEEEI0ISsri72vKGiQmZmJZs2ayb3FxcWxdc+dOwegaGlxXV1duW3yeDy4ubkBAO7fvy+33tChQ+W206BBAxgZGQEAPnz4IFH24sULvH//nm1DkU6dOrH3FfVlzJgxCtuRJSkpCW/fvsWzZ8/YW3HgLTw8XO32AGDy5MmIiIjAxIkT2QTNxXJycnD9+nVMnToV9erVUzo1i8PhYPTo0XLLW7VqBQBITU1Fenq6zDpHjx4Fn88HIDlFrFjx/+Xn5+PEiRMK+yPer+Kbrq4uGjVqhMWLFyM3NxcuLi64du0a2rZtq3JbAJCXl4eLFy+qtM3XiIJBhFQkLhdoOw2YeR9w6iJdnpMEnJgA+I4FshIqvHsAUN2wOjZ124S/u/wNK30rqfKnyU8x/Pxw/Pv4XxTImvZWAs/KCg67dsJ6wXygxBBUYXo6YmfMRMKatWAKC1XuI4fDwXg3Rxz6vgPCamvDX68QgpJL0AsZBPt9gN/GR8hIylO5bUIIIYQQIpt4ICEnp+yr4wqFQjx58gQA2Gldim4nT54EAIXLgzds2FDhPs3NzQFIBrYASOT/cXNzU9iP4oCSsr40b95cYV+K3bt3DyNGjIClpSWqVauG+vXrSwTPdu3aBUByZJa6GjRogL179yIlJQVBQUH466+/MGbMGIkpR/Hx8RgwYIDCldusrKxgaWkpt9zC4v9HpJV8josVTzdzcXFBo0bSqzC3aNECTZs2BVD2VcV0dHTw3XffoYOMi9PyrF69mn2v//rrr/j777/L1IcvFQWDCKkM5rWAcX6A57+Arql0+cvzRaOEnhyRzrlTQXrU6oGzXmcxtL70lREBI8DOpzsx9NxQhCWEKW2Lw+XCasoU1Dp0ENq2tlLlqT4+iBo1GoX/W/pSVY1sTHB+jjvqd7LDQeMCJMpYgj7+fQZ8fw/Bi3ufVE6ETQghhBBCpIkHAZKSkuTWMzMzA8MwErcJEyZI1UtNTS3VilG5ublyywwMDBRuy+UW/QQWCiVHlycmJqrdD2V9KQ48KbJixQp07NgRx48fV7raV15e2S9wamtrw83NDfPmzcOhQ4cQExODmzdvstO+hEIhZs6cKfd7s6rPb3FbJb18+ZINvMkaFVRs3LhxAIoCZZGRkYoPCkBERAR7u3PnDrZs2YI6deqgsLAQs2bNwvr165W2Uaxdu3a4cOECe6zz58/H9u3bVd7+a0E5gwipLBwO4DIeqNsTuDgfeH1Jsjw/HfCbATw7BQz4BzCr+DxNJjomWO62HP1q98PK+ysRnRktUR6VGQXvK94YVn8Y5rWaB2MdxTl6DFq2RO0zpxG/9FdkXb8uUZb//DkiBw1Gjd9WwrR/f5X7qK+jhTWDm8G9nhUWn3yKZulCuBbwwBXLJcQvECLg4CtEhiej69iGMDDRUdAiIYQQQr4lZrpmuD3idmV3o0KZ6ZqVS7stWrRg7z9+/LjM7YkHCyZPnoy5c+eqtJ34UuGaIt6X8+fPw9HRUaXtqlWrJrdMWWLmmzdvskmLnZycsHDhQnTs2BE1a9aEoaEhmz9p2bJlGknWLU+3bt1w/fp1NG3aFKmpqXj79i2ePHlSLrlyxEf6zJ8/H/Pnz1dYn2EYHDhwAMuXL1dYr3gkUTF3d3eMHz8eHTt2xNOnT7F48WJ06dJF5UTQnTp1gp+fHzw8PFBQUICZM2fCwMCAXZXuW0DBIEIqm4kNMPJIUdDn8iIgN0Wy/N0N4L92QM+VQKtJRVPNKlibGm1wyvMUdoTvwL5n+yBgJK/gnHhzArdjbmNx28XoXqu7wra0TE1ht3kT0o4eReLadRLTw0Q5Ofi0YCFyg4NRffFicPX1Ve5jv2Y2aG5vih+OPcGx9xnol6sNM5HkcxX1NBnHVj1AlzEN4eQsP+khIYQQQr4dXA63XJIpV0VNmzaFpaUlUlJScPfuXeTm5iodKaKI+JQihmGkftBXJPFRT2ZmZhXSl+LpX+bm5ggODpablFvZiCFNsLGxQf/+/dnE0u/evdN4MEgkEuHw4cNqb3fw4EGlwSBZjI2NceDAAbi4uEAgEGDBggW4c+eOytv37NkTJ06cwJAhQ8Dn8zFp0iTo6elh+PDhavflS0TTxAj5EnA4QLOhwKwQoKmMhHWF2cDFBYDPACDlfcX3D4Culi6+d/kexwYcQ1NL6Q/HxLxE/HDrB8wLmIfEXMXDbDkcDixGj4aj7zHoiC0DWiz9xElEDhuG/Ddv1OqjvbkBjk1thyG9nOBjUoBwHelhx3lZfFzeHgH/Ay9RmK/+sGRCCCGEkKqKw+GwU3syMzPZ3C+lpaOjw05PunfvXpn7VxbigY+K6kvxqs1du3ZVuDqbeD6j8mQrls5B1mpsZRUQEICYmBgAwJw5c3D06FGFtx9++AEA8P79+1K/Ji1atGATXt+9e1dpguySPDw8cPjwYWhpaUEoFGLs2LE4f/58qfrypaFgECFfEkMrYOgeYORRwEjGkuvR94Bt7YF7mwGR9BzcitDAogEO9TuERW0WQZ8nPXLnxscb8PLzwsk3JyFipHP4iNNr1Ai1T56AqZeXVFnhu/eIGjYcacePq5Xrh6fFxfxeDeAzpR2e1tDCacMC5HCkt38ZFA/f30Pw6V26ym0TQgghhFR18+fPh56eHgDgl19+USmfiyKenp4AgFevXuHq1atl7l9pubi4sMmUd+7cifz8/HLfZ3G+JEXJuB8/fowHDx6Ueh/qfI8WDzo5OTmVep/yFE8R09LSwtKlSzFy5EiFtyVLlrBT5cqSSHrJkiVsLqPipebVMWzYMOzduxccDgd8Ph/Dhg3D9RIpL75GFAwi5EvUsB8w6wHQcpx0mSAfuP4rsLsHkPCi4vsGQIurhXGNx+HMwDPoYCedmT+Ln4WV91di0tVJiMxQ/AWBa2gI27VrYLtuLTglhhkzBQX4vGw54ubPh1DOagTyuNWxxOW57nBqYY19xvl4qy0dPMtMzseZjY9w/8w7CPmKA1eEEEIIIQSoWbMmNm/eDADIyMhAx44dERgYqHAbhmHkLjM+d+5cdnWuiRMnsqNl5Ll48SKePn2qfseV4HK5WLx4MYCiZefHjx+PggL5K+dmZmZiy5YtZdpnvXr1AACBgYF49+6dVHlSUhKbSLm0Bg8ejP/++0/p6m/79+/HzZs3ARS9xpqeIpaTk4PTp08DKMrnoyjXUjErKyt07twZAHD8+HGFr4ciDRs2xODBgwEUjfoKCAhQu43x48dj27ZtAICCggJ4eXmpNeXsS0TBIEK+VPpmwMAtRauOmdWULv/0CNjRCbi1DhCoviy7JtkZ2WFb921Y674W5rrSqyWEJYRh6Lmh2PV0F/givsK2TAcORO1TJ6ErY3nJrMtXEDloMPLU/OA3N9TBznGt8MugJrhkIsBl/UIUlFiCHgzw6OpHnFgXipS4bLXaJ4QQQgipiqZMmcLmcPn06RPc3d3RvXt3/PPPP7h58yYeP36M0NBQXLhwAb/99huaNWuGs2fPAigaFSKeALp69erw8fEBh8NBfHw8WrdujRkzZuDcuXN49OgRHjx4gFOnTuGnn35CnTp1MGDAAHxUcwVaVU2fPh2DBg0CAJw4cQJNmjTB+vXrcfv2bTx58gR37tzBzp07MXr0aNja2mLFihVl2l9xMuKcnBx07twZ//77L4KCghAUFIQNGzagRYsWePHiBdzc3Eq9j5iYGMyaNQs1atTA6NGjsX37dgQEBODJkycIDg7G/v370b9/f0ycOBFA0fSwv//+W+PTxE6fPo3s7KLv2kOGDFF5u+K66enpOHfuXKn3XxzoA0o3OggApk2bxi4zn5ubiwEDBiAkJKTUfapslECakC9dna7AjPvAzd+AkJ2AeDBDxAdurQZenC0KHNm5VHj3OBwO+jv1R3vb9lj/cD3Of5CcQ1soKsTmx5txOeoyVrqtRDPrZnLb0q1dG47HjiLxz/VIK5Fcjh8bi6jRY1Bt/nxYeE8AR8VE2hwOB+PdHNHG0QJzjj6GT3wO+uZqw0EoubpDSmw2jq95iHYD68C5uwM4XM3PkyaEEEII+VasWLECLVq0wMKFC/Hhwwf4+/vD399fbn0Oh4PevXtj/fr1ErlpgKLRK2fPnoW3tzdSU1Oxfft2uUt5c7lcGBoaavRYxPvo6+uLuXPnYvv27Xj//j0WLVokt74qo1sUGTp0KCZOnIh9+/bh06dP+P777yXKtbS08PfffyMtLQ33798v1T7s7e0RFhaG7OxsNhePPKampvj333/ZUTSaVDzNi8PhqNX+4MGDMXv2bIhEIhw4cADDhg0r1f5btmyJfv364dKlS/D390dwcDDatWundjs//PADcnNzsWTJEmRlZaFPnz7w9/eHs7NzqfpVmWhkECFfA10joN+fwKQrgGU96fLE58Du7sD1ZQA/r+L7B8Bczxyr3Vdje4/tsDOykyp/m/YWYy+PxbqQdcjl58pth6urixq/LoX9ln/BNTWVLBQIkPjnn4iZMQMCNVdVaGRjgvOzO6KfmwN8jQpxW48PYYlRQiIBg6BT7+D392NkplTO80gIIYQQ8rUYNGgQXr9+jePHj+O7775D48aNYWVlBR6PBxMTE9SuXRuenp5Ys2YN3r9/j8uXL8tdpcvDwwORkZHYsGEDunXrhurVq0NbWxv6+vqoXbs2BgwYgL/++j/27jM6qqqLw/gzJb0RIITeexMEpAuK9KqIqKiACojYRawgWLAAVrADSkd6LxaK9N5bgAAhtBBaepmZ90Mkr+EOkMYkgf9vrawFs8+9Zye6mJs9Z5/zBceOHeO+++67Zd+Tm5sb3333HTt37uTFF1+kRo0aBAQEYLFYCAgIoFatWjzzzDPMnDmT/fv3Z3m+cePGMXHiRJo2bYqfnx8eHh6UKlWKJ598knXr1vHyyy9n6f5z587lwIEDfP311zzyyCNUq1Yt9fvx8fGhZMmStGvXjq+++orDhw9nuS3NmfDw8NRCYcOGDQ3FwBsJDg6mceOUbSmWLl1KREREpvN49913U//84YcfZvo+77zzDu+99x4AFy9epFWrVtny/4KrmRwZ2VFKBDh58iQlSpQAUpYdXt1oTVwkKR5WfZqyibTDySbSBcpDp9FQKvPLSbMqNimWMTvGMGn/JKebSBf1KcrghoNpUqzJDe+TdOoU4a8PJG77dkPMWqgQRUeMwKf+PRnOb/Hu07w1axceMTbax7gTZDfWxd09LTR9tCKV6he+JacpiIiISNaFhISQnJyM1WpN3X9FRCS3ysy/Wbfq92+tDBLJa9w84YGh0OcvCHbyyUrkYRjfFha/AQk5sweOt5s3b9R7g8ntJlMxsKIhfirmFP3/7M9b/7zFhfjrr/BxK1qUUhMnUKBfP7imIJN87hwnevUi4ptvcdgydrJauxpFWPxyU0qVzcdEvwQ2eSThuGaVUGK8jb9+3c/Sn/YQF50zezKJiIiIiIjcCioGieRVRWtDnxVw37tgdrsm6EjZX+i7hnDk+r3bt1r1gtWZ1mEaL9/9Mu5md0N80dFFdJ7bmQVHFlz32EuT1UqhV1+h5NhfsBQsmDbocHD+u+840bMXSWfOZCi34oHeTOvbgAEtyrPaO5lpvolcNhlXMR3dHsHUDzZxbPf5DN1fREREREQkt1IxSCQvs7pDs0Hw3D9QrI4xfvkETHwQ5g6AuIuuzw9wM7vxbI1nmd15NvUK1zPELyVc4p017/Dcn89xMurkde/j06gRZefOwadRI0MsdssWQrs8SFQGj4m0Wsy81qoSU55tQHJ+d371T2C3e7JhXNyVRBaN2cXKyQdIjDfGRURERERE8hIVg0RuB4WqwDN/QKuPwepljO+YBGMawIFFrs/tX6X8SzG21ViGNRqGn5ufIb7u1Doemv8QE/ZOwGZ33vZlLViQEr/8TNBrr4El7WlgtkuXONn/ec5+8imOxIy1dTUsV4AlLzelWbVglnonMdc7gViTcaXS3n9O8fvHmzlz9HKG7i8iIiIiIpKbqBgkcrswW6DRC9B/LZRysjFz9BmY9jjM6A0xOdPyZDKZeKjCQ8zrMo+WpVoa4nHJcYzYMoIei3tw8MJB5/cwmynYtw+lJk3EWrSIIX7ht9849ngPEk+cyFBugT7u/PRkHT7oXI3j3vCrXzxHrMai1OWIOGaP2MrG+Uex2YxtZSIiIiIiIrmdikEit5sC5aDnAmj/BbgbV+CwdzaMrge7ZkAOHSYY5B3EF82/4Ov7vqaQVyFDfG/kXrov7M7X274mPjne6T28a9em7Jw5+LV8wBCL37OH0Acf4vKijK2EMplMPNWwNPMGNKZoYV9m+ySyzCuRxGs2l3Y4YMviY8z6bCsXTsdkaA4REREREZGcpmKQyO3IbIZ6z8CADVDeuAKHuAsw+1mY+hhcOeX6/P51f8n7mdtlLt0rdTfEbA4bv+z+hYcXPMzmM5udXm8JCKDYN98QPGQwJve0G1TbY2I49fpATg8ejD0uLkN5VSniz/wXmvB4g5Ls8rDxm18C4RbjKqGIE1H8PnwzO/8Ow2HPmcKaiIiIiIhIRpkc1zvCR+Rf1apVS/P3pKQkQkJCAAgLC6N48eI5kZakl8MBu6bDkjch/pIx7hEArT6Eu58yHN/uStvObmPo+qGEXg51Gu9aoSuv1nmVAI8Ap/H4/fsJf/U1Eo8dM8Tcy5ej2Bdf4FnReMz9zSzefZq3Zu0iKi6ZexKsNI63YsH4cypWMR9NHqlIweK+GZ5DREREMickJITk5GSsVisVKlTI6XRERG4oM/9mnTx5khIlSgDZ+/u3VgaJ3O5MJrjrURiwCap0MsYTLsOCl2BiF7h4zNXZpbo7+G5mdpzJc3c9h9VsNcRnhcyi89zOLD+23Okx9J5VqlBm1kwCunQxxBIPH+FYt0e4+Pvv1z3C/nra1SjC4pebUqd0IBs9k5nkm8B5s3GvoPBDl/j9402smLifmMsJGZpDRERERETElbQySDLsVlUmxUX2zoXFAyEmwhhz84YW78M9fVNazXJIyMUQhq4fyq6IXU7j95W4j3frv0uwT7DT+OV58zg97AMcsbGGmH+7thQeNgyLn5P9lG4g2Wbnm78PM/rvEMx2aBrvRt0EY9EKwM3Dwt2tS1HrgRJY3S1Ox4iIiEjWaWWQiOQlWhkkIjmnWpeUVUI1HzXGkmJh6Zswvg1EHHJ5aldVCKzAhDYTePuet/G2ehviK8JW0GVeF6YfmI7dYVylE9C5M2VmzcSjShVD7MriJYQ+1JW43bszlJPVYua1lhWZ0qcBBQM8WeGVxHSfBK6YjPMnJdjYOP8ok9/fwKFNZ7SfkIiIiIiI5CoqBoncibzzw0M/wuMzwL+YMR62EX5oAv98AbZk1+cHWMwWHq/yOPO6zOPe4vca4tFJ0Xy08SN6Le3F0UtHDXGPMmUoPW0qgT16GGJJYWEce+xxIseNx2HP2PHwDcoWYMnLTWlZNZgTbnbG+iewxjPJcOIYQPTFBP4Yt4+Zn2/l9OFLGZpHRERERETkVlExSOROVrEVPL8B6j5tjNkS4K9h8Mv9cCZjq2iyU2Gfwoy+fzQj7h1Bfs/8hvj2c9t5eMHD/LDzB5JsSWliZg8PCg9+j2LffoPZ3z/thcnJnPv8c8L69yf5woUM5RTo485PT9bhw87VMLuZWe+ZzC/+8exyT8bhpCh07tgVZo/cxtKf9nA5ImMnm4mIiIiIiGQ3FYNE7nSe/tDhS+i5EALLGOOnd8JPzeHvjyA5ZzZGNplMtCnThnmd59G5XGdDPMmexJgdY3hk4SPsOLfDEPdv2ZKyc2bjVbu2IRazajWhXR4kZuOmDOf0ZMPSLHvlXu6tGESMGZZ5JzHBN4HjVuMx9ABHtp1jyrANrJt1mIS4nFlxJSIiIiIiomKQiKQo0xT6r4OGL4Dpmn8a7MmwegT8eC+c3JIz+QH5PPPxUZOP+KnlTxT3NW6cdvjSYZ5a8hSfbPyEmKSYNDG3YsUoNeE3CvTrl3LC2n8knzvHid69ifh2NA6b80LO9ZQp6MNvvevxwxN1KJbPi3NWB7/7JDLbJ4ELTk4dsyc72P7HCSYNXs+eVSex2zLWpiYiIiIiIpJVKgaJyP+5e0Prj+GZPyCosjEecQDGtoRl70Ki8aQuV2lYtCGzO8+md7XemK8pXDlwMOXAFLrM68Lqk6vTxExubhR69RVK/PIzloIF097Ubuf8mDGc6NmLpLNnM5SPyWSiTfXC/PlaM164rzzuVjNH3OyM90vgL69E4kzG1rH46CRWTT3EtA83cXxPZIaPvBcREREREcksFYNExKh4Xei3Gu59A8zXHJ/usMP60fB9Iwj9J2fyA7ysXrxW9zWmtp9KlfzGU8POxJxhwF8DGLRqEOfjzqeJ+TZuTNk5s/Fp1MhwXeyWLYR27kLUypUZz8ndwsDWlVj2akrrmN0E2zxs/OIXzxaPZGxO9hO6eCaWhaN3suDbnUSGR2d4ThERERERkYxSMUhEnLN6wP3vQd+VUOQuY/xiKPzWARa8AvFXXJ1dqqoFqjKl/RReq/MaHhYPQ3zJsSV0ntuZuYfnpll9Yw0KosQvPxP02mtgsaS5xnbpEief68/ZTz7FkZiY4ZyubR2LN8MKryTG+SVwyM15G1rYvgtM/2gTKycfIPZKxucUERERERFJLxWDROTGCteAZ/+GFu+Dk2ILW8fDdw3g0HLX5/Yvq9lK7+q9mdNpDvWL1DfEryReYfDawfT5ow9hV8JSXzeZzRTs24dSEydiLVrEcN2F337j2OM9SDxxIsM5GVrHLGYuWRzM80lkmk8CZy3GvYIcDtj7zykmDVnP1qXHSE7K2P5FIiIiIiIi6aFikIjcnMUKTV+D59ZACWOxhSvhMKUbzO4HsRk7pj07lfAvwc8tf+bDxh/i7+5viG88vZGH5j/E+D3jSbb//zQv77trU3bOHPxaPmC4Jn7PHkIffIgrixdnKqdrW8cAwtzsTPBNYLF3IlFO9hNKirexYe5Rpry/kZDNZ7WfkIiIiNxWSpcujclkolevXjmdisgdS8UgEUm/oIrQewm0+QzcvI3xXdNgzD2wd67LU7vKZDLRpXwX5nWZR9vSbQ3xeFs8X2z9gscXPc6+yH2pr1sCAij2zTcED34Pk5tbmmvsMTGEv/Y6pwcPxh4Xl6m8rm0dwwR73W2M9Y9nrWcSiU72E4q6EM/ysXuZ9flWzhy9nKl5RURERG6FmJgYfvjhB9q1a0exYsXw9PTEw8ODoKAg6tWrx9NPP83PP/9MWFjYzW+WjTn5+flhMpkwmUwMHz483ddevebaL3d3d4KDg2nWrBkff/wx586du+m9mjdvnnr9zXz99depYwsVKsSuXbvSnbNIVpgc+shZMujkyZOUKFECgLCwMIoXNx7xLXeAi8dg/ksQusp5vEpHaDcK/IJdmta1VoWt4sMNH3I21nhCmMVk4amqT9G/Vn+8rF6pr8fv30/4q6+ReOyY4Rr38uUo9sUXeFasmOmc4hJtjFlxmJ9WHyXx36Plfe3QJN6N6okWTDh/cChftxANu5TDv6CX07iIiMidJiQkhOTkZKxWKxUqVMjpdO4Y69ev59FHH+VEOlrpg4ODOXPmTJrXSpcuzfHjx+nZsye//vprtuU1YcIEevbsmfr3ypUrs3///nRdm57CDUD+/PmZOnUqrVq1uu6Y5s2bs2pVyjPyjX7dHjFiBIMGDQKgcOHC/PXXX1StWjVdeUjelJl/s27V799aGSQimRNYGp6aB52+BY8AY3z/gpRVQjumpmyGk0OalWjGvC7zeKzyY4Yii81hY/ze8Tw07yE2nN6Q+rpnlSqUmTWTgM6dDfdLPHyEY90e4eLvv2e6fctZ61i0GZZ6JzHBN4ETFud7BR3eco4pQzeyfs4REuOSnY4RERERuZUOHTpE69atUwtBnTp1YsKECWzYsIFt27axfPlyRowYQatWrXC7ZrX1rTZhwgQAfH19AThw4ACbNm3K0D3q1q3L7t27U7/WrVvHhAkTaNCgAQAXLlzgoYceIjQ0NEu5fvzxx6mFoGLFirFq1SoVgsSlVAwSkcwzmeDup2DABqhobMki/hLMfQ4mPwyXXLdE+Fo+bj68U/8dJrSdQLmAcob4yeiT9Fneh/fWvMflhJR2LLOPD0U/+5Qin36CyTttS5wjIYEzQ97n1OuvY4uKynRehtYx4JzVwXTfROZ4J3DRbNxk2pZsZ9uy40wasp49q8Ox24xjRERERG6Vd999l6h/n3/Gjx/PvHnzePLJJ6lfvz61a9emZcuWDBw4kGXLlhEeHs7gwYNdktfJkydZsWIFAEOHDiUwMBD4f4EovXx8fKhevXrqV8OGDXnyySdZt24dDz/8MJDSjjZq1KhM5zp06FDee+89AEqWLMmqVauomIVV5yKZoWKQiGSdf1F4bCp0HQveBYzxw3+mnDi2eSzYc654UatQLWZ0nMHztZ7HzWz8pGrekXl0mtuJpaFLU1f95OvShTIzZ+JRubJh/JXFSwh9qCtxu3dnOidnp45hgsPudsb5JfC3ZyLxTjaZjotKYtWUg0z/eDMn9kZmen4RERGR9LLZbCxatAhIWUFzsw2gg4KCGDBggAsyg0mTJmG327FarTz11FN069YNgGnTppGUlJTl+5tMJj799NPUv//555+Zus8777zDsGHDAChTpgyrV6+mXDnjh5Uit5qKQSKSPUwmqPEwDNgE1bsa44nRsOg1+K0jRB5xfX7/crO40f+u/szsOJPahWob4hfiL/DG6jd44e8XOB19GgCPsmUoPX0agT16GMYnhYVx7LHHiRw7Foct80fBO2sds5tgq6eNn/3i2eKejM3JJtMXTsWw4NudLPh2BxdOxWR6fhEREZGbiYiIIO7fwzTKly+fbfc9ePAgffr0oXTp0nh4eBAcHMyDDz7Ihg0bbn7xvyZOnAhAq1atCAoK4sknnwQgMjIytYCVVWXLlsXHxwcgUxtjDxw4kE8++QSAChUqsHr1akqVKpUtuYlklIpBIpK9fArCw+Pg0angW9gYP74Gvm8M674Fe+aLJ1lVNl9Zfm3zK+/Vfw8fNx9DfPXJ1XSZ14Up+6dgs9swe3hQePB7FPv2G8z+1xxbn5zMuREjOdb9UeL37TPcKyOctY7Fm2GFdxLj/RIIsTr/mZ3Ye4FpH21i1ZSDxF5JzFIOIiIiIs64u7un/jm9GzPfzJw5c7j77rv55ZdfOH78OImJiZw7d465c+fSpEkTpk+fftN7bNmyhX3/PoM98cQTADRu3JgyZcoAGW8Vux6TyYTVagXI8H5IL7/8cmprWeXKlVm1apUO4pEcpWKQiNwaldvBgI1Q+wljLDkOlr8HY1vBuex5kMgMs8lM98rdmdt5Ls1LNDfEY5Nj+WTTJ/Rc2pPDFw8D4N+yJWXnzMartnFVUfyePYR2e4Szn32OPTY203k5bR0DLloczPVNZLpPAmctxnY7h93BntXhTB6ynm3LjpOclHPFNhEREbn95M+fP3Uly86dO/nss8+wZ2ELgN27d/P4448THBzM6NGj2bBhA+vXr2fo0KF4enpis9no27cvERERN7zP1WKPn58fnf89AMRkMvH4448DsGjRIi5cuJDpPK86ffo0ly+n7C9ZunTpdF3jcDh4/vnn+eabbwCoXr06K1eupEiRIlnORyQrdLS8ZJiOlpcMO/I3zH8ZLjs5ftTsBs0GQZNXweLaEyf+y+Fw8MfxP/hk0yecjztviFvNVp6t8Sx9avTB3eKOIymJiG9HE/nzz05PS3MrWpTC7w/Bt1mzLOcWej6G9+fvZfWh/z8ImRxQLdFC03g3fB3Oj0L1K+BJwwfLUb5OoXQflyoiIpKXpOeYZofdju3SJdcmlsMs+fJhMt+az/1HjRrFwIEDU/9eunRpOnXqRKNGjbjnnntSV+PcyNWj5QHq1KnD33//jf81K68nT56cusrniy++4NVXX3V6r6SkJIoVK0ZERARPPfUUv/32W2rswIEDVKlSBYAxY8bw/PPPXzenq89KzZo1Y+XKlU7HvPTSS3z77bcAfPjhh6mbQF/rv0fLP/vss/zyyy8A1KpViz/++IOCBQteNw+5veWmo+VVDJIMUzFIMiUhGv76ADb96DweXB06j4aixhU3rnQ54TJfbv2SWSGznMbLBJRhWKNhqfsNxW7Zwukh75N49KjT8X5t21D4nXewBgVlKS+Hw8GyvWf5cOE+wi/Fpb7u5oB74q3US7DihvOCT+GyATTuVp7CZQKylIOIiEhuk55frJIvXCCkUWMXZ5azKqxbizV//ltyb7vdTp8+fRg3bpzTeHBwMM2bN6dHjx506NDB6QdS/y0G7dy5k5o1axrGOBwOihcvzqlTp3jwwQeZPXu20/nmz5+fuhpo+fLltGzZMk28Xr16bNmyhfr1699wD6LrFYOioqIICQnh22+/5bfffsPhcFChQgU2btyYemLZtf5bDLqqSpUqrF279rrXyJ0hNxWD1CYmIq7h4QvtPofeS6GAkw0Hz+6Bn1vAn0MhKd7l6V0V4BHA0EZDGdtqLCX9ShrioZdDeWrJU3y04SOiE6PxrluXMnPnUPDFFzA56R2PWrKUI+3ac3H67ziysIz6eq1jSSZY65XMWP8E9rolO732zNHLzPpsK8vH7iXqQs79bEVERCTvM5vNjB07luXLl9OmTZvUPXSuOnv2LNOnT6dTp07cc889HDly/YNDatSo4bQQBCnPPrX/bcs/ep0P3eD/LWJFihShRYsWhvjV1UUbN24kJCTkxt8csGrVKkwmU+qXv78/derU4ddffwWgc+fOrFy5Mt1FnatFptDQULZt25aua0RcQcUgEXGtUg3huTXQ+BUwXfNPkMMGa76EH5rAifSfHnEr3FPkHmZ1msWzNZ7FYrIY4tMPTqfzvM78feJvzO7uBA0YQJl58/CuV88w1h4VxZn33+f4E0+ScPhwlvJyduoYQJTZwWKfJCb4xhNmcb5XUMjms0x+fwMb5h4hMd554UhEREQkPVq2bMmSJUuIjIxk8eLFDBs2jI4dOxIQ8P+VyFu2bKFp06acPn3a6T0qV658wzny/7u6KSoqymn84sWLLFiwAIDHHnsMs5PWuMceewyLJeVZLqsbSRctWpRXXnmFokWLpvuaq21l8fHxdOrUiTVr1mQpB5HsomKQiLiemxe0HAbP/gWFqhnjkSEwrg0sHpTSXpZDPK2evHz3y0zvMJ1qBYx5nos9x8srXua1la8RERuBR9kylJzwG0U+/ghzgLElK27bNo4++BDnvv4ae0JClnJzduoYwFmrg2m+icz1TuCS2bgSyZZkZ+vS40wavJ69/4Rjt6tTWERERDLP39+ftm3bMmTIEObPn8/Zs2cZN25c6sqZ06dPM3jwYKfXent73/DeV4s7NpvzD7qmTZtGYmLKKapXVwBdq1ChQrRq1QqASZMmcbNdUurWrcvu3bvZvXs3u3btYvny5QwePJiAgADCw8Np06YN//zzzw3v8V8DBgxgxIgRAMTGxtK+fXs2b96c7utFbhXtGSQZpj2DJFslJ6asBlo9AuxJxni+ktDxGyh3n+tz+49kezKT909mzI4xxCXHGeK+br4MqDWA7pW742Z2IzkykrOffsaVfz+tupZ7qVIUHjYUnwYNspxbXKKNMSsO89PqoyTa/l8AsjigdoKVhglWPK+zyXSBYj407lqBElVvzb4CIiIit5I2kHbuVm4gnV7Lli2jTZs2AAQGBnL+/PnU4s7VPYN69uyZ2n7lTK9evfjtt98oVaoUx44dM8QbNmx4w32AnFm5ciXNnBzwcbMNpPfs2UOjRo2IioqiRIkS7Nmzx7Dp9VX/3TPo6q/bH3zwAe+//z6Q8vNYuXLldVvk5PalPYNERK6yukPzN6Hfaih6tzF+6QRM7ALzXoC4S67OLpXVbKVntZ7M7jSbRkUbGeLRSdF8tvkzHlnwCJtOb8JaoADFRnxOiV9+we3ff7z/K/H4cU706s2pt94m+eLFLOV2vdYxmwm2eCbzi188W92TsWOs/UeGxzD/mx0sHL2TC6djspSHiIhIbmQym7Hmz39HfeV0IQigdevWqb/AXrx4kcjIyGy9f0hISIYLQZD5VrHq1aszfPhwIOUX8qurfdJryJAhvPnmm0DKz6Nly5YcOHAgU7mIZIec/1dCcr1q1aql+br//vtzOiW5HQVXhWf+gJYfgtXTGN8+Eb5rAAcWuz63/yjuV5wfHviB4U2Gk88jnyF++NJhnln+DANXDeRMzBl8mzSm7Px5FOjTB67ZYBHg8ty5HG3Xnsvz5t102fLNXK91LM4Mf3snMd4vgcNW58usj++JZNqHm1g99SBx0YlZykNEREQESLO3jrNTxbLiv0Wd77//nqlTp97w62qr2MyZM4mLM67yTo9+/fpRpkwZAL788kvOnz+foes//fRTXnrpJQDOnTtHixYtbrg5tsitpGKQiOQeFis0fgn6r4NSTo6AjToN0x6DmU9DTMbefLOTyWSiY7mOzOsyj45lOzods+zYMjrN7cTPu34m2d1Coddfo8ysmXjeZVwObLt4kVNvvsWJp58m8d9jVrOSm7NTxwAuWBzM8U3kd58EIizG/YQcdge7V4UzafAGti8/gS0p86efiYiIyJ0tNjaWffv2ASn7ChUoUCDb7u1wOJg0aRKQsmLnueee49FHH73hV//+/QG4cuUKc+fOzdS8bm5uvPXWWwDExMTw5ZdfZvgeX331FX369AHg1KlT3H///YSFhWUqH5GsUDFIbmrv3r1pvv7++++cTkludwXKQc+F0H4UuPsa43tmwZh7YPdMyMFtz/J75md40+FMaDuByvmNp2HEJcfxzfZveHDeg6w+uRrPSpUoPWUKwYPfw+zjYxgfu34DRzt15vwPP+JIzNrqnOu1jgEcd7Pzm28CS70SiTMbf36Jccmsm32YKcM2cHjruSyvWBIREZHbQ3R0NPXr12fhwoXY7df/0Mhut/Piiy+mngLWqVOnbF0ZtHr16tQ9hB5++OF0XdOmTRt8fVOeK7NyqlivXr0oVqwYAGPGjOHy5csZut5kMvHDDz/Qo0cPAI4fP06LFi04c+ZMpnMSyQwVg0QkdzKbod6z8PwGKP+AMR4bCbOegWmPwxXnx5W6Su1CtZnWfhrv1X8Pf3fjRoInok4w4K8BvPDXC5yMOUX+Hj0ou3gRfi1bGsY6EhKI+OorQrs+TOy27VnO7XqtYw4T7Paw8aNfPOs8krA5eT67cj6eZT/vYc6obZw9diXLuYiIiEjet2nTJjp27EjJkiV54YUXmDx5MmvWrGHnzp2sWrWKr776ilq1ajFu3DgAAgIC+PDDD7M1h/8Wc7p27Zquazw9PWnXrh0Af/zxR6aLL+7u7gwcOBCAy5cv880332T4Hmazmd9++y0195CQEB544IEMt52JZIWKQSKSu+UrAT1mQpfvwTOfMX5wMYypD9sm5OgqIYvZQvfK3Vn44EK6VeyGCWN1ZdXJVXSe15lvtn1DUn4/in/7DcW/G4O1cGHD2ISQEI736MHpoUOxXclaIeZGrWNJJljrlczPfvHsd3e+n9Dpw5eZ+ekW/hi/l6gL8VnKRURERPIuq9VK4X+fW8LDwxkzZgxPPPEETZs2pVatWjRv3pxXX32V3bt3A1ChQgX++usvSpcunW05xMXFMXPmTAAqVapE9erV033t1VVENpuNyZMnZzqHvn37UrBgQQC+/vproqOjM3wPi8XC1KlTad++PZDSjdGqVSsu3WEn30nOUTFIRHI/kwlqPQ4DNkEVJ3v0JFyG+S+mnDp28Zirs0sj0DOQIQ2HMK3DNO4KussQT7In8fPun+k8rzPLji3D9777KLtwIYFPPZnyff6Xw8GladM50r49V5YuzXK71o1ax6LMDhZ6JzLRN55zHs7nObTxLJPf38DG+UdJjE/OUi4iIiKS93h6ehIeHs7atWsZNmwYbdu2pWzZsvj4+GCxWPD396dy5cp0796dKVOmsGfPHurUqZOtOcydO5cr/35Qlt5VQVe1a9cOL6+UldJZaRXz9vbm1VdfBSAyMpLvv/8+U/dxc3Nj5syZPPBAyir47du307Zt20wVl0QyyuTQZhCSQSdPnkw9JjIsLIzixYvncEZyx9k7FxYPhJgIY8zNBx54H+r1SWk1y0F2h52FRxfyxZYviIx3fpxq/cL1eeuetygfWJ643bs5PeR9EvbvdzrWt3lzCg8ZjNt/TubILIfDwbK9Z/lw4T7CL11zooYDKiaZaZ3siWei87cIb3936ncuS+WGRTCbs/d0EBERkfQKCQkhOTkZq9VKhQoVcjodEZEbysy/Wbfq928VgyTDVAySXCH2Aix9C3ZNdx4v0QA6j4aCOf9gGJUYxQ87f2Dy/snYHMZWLIvJwuNVHqf/Xf3xNXtx4bcJRIwejcPJsacmb2+CXnqR/E88gcnJUfUZFZdoY8yKw/y0+iiJtrQbQVocUDfRStMkd0zJzt8qChT3pfHD5SlROX+WcxEREckoFYNEJC9RMUjyNBWDJFc5tBwWvgJXwo0xiwfc9zY0fDHl2PocduTSET7Z9AkbT290Gi/gWYBX67xKx3IdSQ4/zZkPhhGz+h+nYz2rVqXwhx/gVa1atuQWej6G9+fvZfUh42orLzs8YPekcowJrvOOUbpmQRo9VI7AwsZT0kRERG4VFYNEJC9RMUjyNBWDJNeJvwJ/DIGt453Hi9SCzmOgcPo3GLxVHA4Hfxz/gxFbRnAmxvkpFjWDavJO/Xeomr8qUUuWcGb4J9icnS5hNpP/yScJeulFp0fVZya367aOAQVsJrqYvMl/xflRsmaziWrNinFP+zJ4+rplOR8REZGbUTFIRPKS3FQM0gbSIpL3efpDx6+g5wIILG2Mn94BPzWDvz+G5AQXJ5eWyWSiVelWzOs8j741++JudjeM2RWxi8cWPsYHGz7Adn9Dyi1aSL5HHjHezG7nwm+/caRjR6JWrsyW3K536hhApMXBWHMMs/wSsfkZV1rZ7Q52rzjJpCHr2fHnCWzJzotGIiIiIiKSs7QySDJMK4MkV0uMSSn6bPgOpz1NQVVSVgkVz96TLTIr7EoYn2/+nJUnVzqN+7v782LtF+lWsRsJ23dwesj7JB454nSsX5s2BL/zNm6FCmVLbjdqHTM5oLHFg8ZxbtjjnB9J7x/kRaOHylG2VhCma09KExERyQZaGSQieYlWBomI3CruPtBmODzzBxSsZIxH7IexD8CydyEx1vX5XaOEfwm+bfEtY1qMoaRfSUP8SuIVPt74MY8uepQDxU2UmTObgi+9iMnN2IYVtXQpR9t34OK0aTjsWV+VU6agD7/1rscPT9ShWD6vNDGHCdbYE/jSPZqwYCtmq7HYcyUijqU/7mHuF9s5d/xKlvMREREREZHsoZVBkmFaGSR5RnICrB4Ba74Ee7Ixnr8sdPoWSjdxfW5OJNoSmbBvAj/t+om4ZOOePQAdynbgtTqv4X82mjPvDyV20yan47xq16bIB8PwyKZPSW906hhAoMPEU775cA+Pv+49KjUoTIPOZfEN9MyWnERERLQySETyEq0MEhFxBasH3P8e9FkBhWsa4xeOwq/tYeFrKZtQ5zB3izvP1niW+V3m07Z0W6djFh5dSIc5HZgas4oi436myPDhWAICDOPitm/n6ENdOffVV9jjr1+gSS8vdwsDW1di2av3cm/FIEP8osnB1zEXWRxsxz3Yy8kd4OCGM0wesoGNC46SGO+kOCciIiIiIi6hYpCI3P6K1IQ+f0OLIWAxbtjMlrHwXUMI+dP1uTlR2Kcwnzf7nHGtx1E+X3lDPDY5llFbR9F1QVf21Q+m7JLFBHTuZLxRUhKRP/zI0c6diVm/Pltyu1HrGMDehAQ+jr/AgTLueOYz/qyTk+xsWXSMye9vYP+6UzjsWpwqIiIiIuJqahOTDFObmORpEQdh3gtw0nl7FXc9Bq2Hg3d+1+Z1Hcn2ZKYfnM6Y7WOISopyOqZFyRa8Ue8N8u06zumhw0g6ccLpuIDOnSn01ptYAwOzJbebtY55mEz0KRyE/9FYkhKcbzJdsIQvjR+uQPFK2ZOTiIjcWdQmJiJ5idrERERySlAleHoptPkU3LyN8Z1TYUx92DfP9bk5YTVb6VGlBwseXMCD5R90OuavE3/ReW5nJnjvpOjs6RTo2xesxqPfL8+bx9G27bg0Zy7Z8TnA1daxpa80ddo6luBwMPr0OSYXTMKnaj6cHSh2PiyaeV9uZ/H3u7h0Nuc39BYRERERuRNoZZBkmFYGyW3jQigseAlCVzuPV+0M7UaCb/Yc1Z4ddkXsYvjG4eyN3Os0Xsy3GIPqDaJhXDHOvj+UuB07nI7zbtCAIkPfx7106WzJy+FwsGzvWT5cuI/wS843v25RJJCWSR5EHr7sNG42m6jevBj12pfB08d4WpqIiMi1tDJIRPKS3LQySMUgyTAVg+S24nDAtgmw/D1IcLKJtFdgyiqimt1xurQlB9gdduaEzOHrbV9zMeGi0zGNizZmUN03CFy6iXOjvsAeHW0YY3J3p2D/5yjwzDOY3J3spZQJN2sds5pN9K1YhKLHErh8xvlKIA9vK/Xal6F6s2JYrFrAKiIi16dikIjkJbmpGKSnbBG5s5lMUKcnPL8BKrYxxuMuwpx+MLkbXD7p+vycMJvMdK3YlQUPLuCxyo9hNhn/KV97ai1dFz7M+IqnKTxvBn6tWxvGOBITifj6G44+9BCx27ZlS243ax1Ltjv47sApRrtFE3hvMF6+xhVACbHJrJkRwtQPNnJ0R0S2tLSJiIiIiMj/aWWQZJhWBslty+GA3TNhySCIu2CMu/tBqw/g7l5gzj219IMXDjJ843C2nXNe0CnkVYjX6r5G02NenP3wI5JPn3Y6Ll/37hR6/TUs/v7Zkld6WsealMrPY34BnNhwFnuy87ejYhXz0fjhCgSV9MuWvERE5PahlUEikpfkppVBKgZJhqkYJLe96IiUgtDe2c7jpZtCp28gf1nX5nUDDoeDJaFLGLVlFOfizjkdc3ehu3m7xqsETlzChYmTwG5s47IEFaTwO+/g16YNpmxqi0tP69gztUtQ8wIc2x7h/CYmqNygMA06l8Mnn0e25CUiInmfikEikpfkpmJQ7vloW0Qkt/ANgm7joftk8A02xo/9A981gnWjwe78yHRXM5lMtCvbjvkPzufp6k9jNRtPE9t2bhvd/+7J2PsdFJw8Do+qVQxjbBHnCX/1NcKee46k8PBsyS09rWM/bj3BB5FnKfRgSYJLO1mZ5IAD688wach6Ni0Mve5R9SIiIiIicnNaGSQZppVBckeJuwjL3oMdk5zHi9WFzqOhkLGwkpNCL4fy2ebPWBu+1mk80COQl+96gWYbojn/zWgcccY2LpOXF0Evvkj+p57E5OSo+sxIT+tYwzL56V+uCKF/hxN9IcHpGJ98HjToUpZK9xTGZM4dG3uLiIjraWWQiOQluWllkIpBkmEqBskd6fBfsOAVuHzCGLO4Q9OB0OQVsOaeFiaHw8HKsJV8tvkzwqOdr/KpVqAa75R8lnxjZhCzarXTMR5Vq1Dkgw/xql4t23JLT+vY0w1Lc5/Jk91/hF13JVBQST8aP1yeYhUDsy03ERHJO1QMEpG8RMUgydNUDJI7VkIU/DkMNv/sPJ6/HLQbAeVbuDavm4hPjmf83vGM3T2WBJvzlTZdynXmuciaxIz8FlvEeeMAs5n8Tz5B0EsvYfbxybbcjkZEM3TBPlYfcr5XULC/B2/dV5F8R2LZv/YU13vHKlsriIYPlSNfIe9sy01ERHI/FYNEJC/JTcUg7RkkIpJeHn7QfiT0XpJS+LnWhSMw6SGY/mSuOYYewNPqSf+7+jOvyzweKPmA0zFzj8yjW8xXbP2iJ/6PdDMOsNu58NsEjnToSNTfK7Itt7JBvvzWux4/PFGHYvm8DPGzVxJ4dd5ufoq9RIP+1ShRxfkKoKM7Ipg6bCNrZoQQH5OUbfmJiIiIczExMfzwww+0a9eOYsWK4enpiYeHB0FBQdSrV4+nn36an3/+mbCwMMO1vXr1wmQyGb7MZjP58uWjRo0a9OnThw0bNmQ4r2HDhqXeLyAggPj4+HRfe/jwYaZOncqrr75K48aN8fb2Tr3Xr7/+muFcRHIzrQySDNPKIBEgKQ5WfgLrvgWHsc0JN29oNggaDACru+vzu4F1p9bx6aZPCb0c6jRePl953vN+mHxfTyXx8BGnY/xatSL43XdxCy6UbXmlq3WscWkeKlKQbfOPcvFMrNP7ePhYqde+DNWbFcNi0WceIiK3M60Myhnr16/n0Ucf5cQJJ+3z1wgODubMmTNpXuvVqxe//fZbuuZ64YUX+Oabb9J9ymm5cuU4evRo6t+nTp3Ko48+etPrVq1aRfPmza8bHz9+PL169UpXDiLXo5VBIiJ5nZsXtPwA+q5M2UT6Wkmx8OdQ+KEJHF3l6uxuqFHRRszqOIuBdQfibTW2VR2+dJhepz7lp1cr4fn8M5jcjcWsqOXLOdq+PRenTsXh5Ij6zEjPqWM//RNKz6W78OlUgnsfrYinr5thXEJMMmt+D2HaB5sI3XUefeYhIiKSfQ4dOkTr1q1TC0GdOnViwoQJbNiwgW3btrF8+XJGjBhBq1atcHMzvk9fa9myZezevZvdu3ezfft25syZw4ABA7D+e3jF6NGjGTlyZLpyW7NmTWohyNfXF4AJEyak69r/Pi+YzWaqVavGPffck65rRfIirQySDNPKIJFr2O0pp4398T7EXXA+pnpXaPUx+BdxbW43EREbwZdbv2TB0QVO415WL14K6kaTyXuI27jJ+ZhatSj8wTA8K1bMtrzSdepY2QIMblWJS1sj2bkiDHuy87ezYpUCafxweYJK+GVbfiIikjtoZZDrdevWjZkzZwI3Xy0TERHB77//zoABA9K8/t+VQaGhoZQuXdpw7YIFC+jUqRMA+fLl49y5czctLvXt25eff/6Z4OBgXnnlFd5++20sFgvh4eEEBwff8NqQkBDmzZtHvXr1qFOnDr6+vvz666/07t07Xd+rSHpoZZCIyO3EbIa7n4IXt0KdXoCTZcx7ZsHourBuNNhyz542Qd5BDG86nAltJ1A5f2VDPC45js9OT2BApwiuDOqFJV8+45gdOwh9qCvnvvgSewb68m/EZDLRpnph/njtXgbcVw53J+1e649G0umn9az0SqLL23Upd7dxNRFA+MGL/D58M39P2E/MZecbaIuIiMjN2Ww2Fi1aBEDdunVvWhwJCgoyFILSq2PHjjRp0gSAS5cusXXr1huOj4+PZ8aMGQB0796dJ598ErPZjM1mY/LkyTedr0KFCgwcOJBmzZqlrioSuZ2pGCQikl2880PHr+HZv6BILWM8MRqWvws/3gvH1ro8vRupXag209pP47367+Hv7m+In4gO41nLJH5+uybWdk42oU5OJvKnnzjaqTMx69ZlW17e7lbeaF35xq1jq4/S5beNJNYvwIOv16ZQKScrgBywf91pJg3ZwJbFoSQlOj+qXkRERK4vIiKCuLiUFbvly5e/5fPVqFEj9c/ONqL+r/nz53Pp0iUAnnjiCYoVK8Z9990HpL9VTOROomKQiEh2K14H+vwN7b8Az3zG+Ll98Gs7mN0Pos+5PL3rsZgtdK/cnYUPLqRbxW6YnKxwWnppHY/WXsfWdzthLVnCEE86cYITTz9D+KBBJF+4TstcJqTn1LGXpm7n9RUHuatXJR7oXRXfQA/DuOQEGxvnhzLl/Q0c3HgGh12d0iIiIunl/p99BPfv3+/S+W7WIna14FOpUiXq1asHpBSFAHbu3Mnu3btvUZYieZOKQSIit4LZAvWeSWkdq/2E8zG7psG3dWHjT2BLdm1+NxDoGciQhkOY2mEqdwXdZYgn2hP5zL6Yl3vbufxoK/h3g8f/ujJ/AUfbtuPSrNnZtoFzelvH2n27hlmRF+n8Tl3qdyqL1cNiGBd9MYE/x+9j5mdbOBVyKVvyExERud3lz5+fUqVKASkFls8++wx7Nh0k4cx/C07O9hW66ty5cyxbtgyAHj16pL7etWtXvLxSPkRK7+llIncKFYNERG4ln4LQeQw8vRyCaxjjCZdhyRvwc3MIc75Bc06pVqAaE9pO4OMmH1PAs4AhHpZ4lj5l/ubXgTUw1ahiiNsuX+b0u+9yomcvEkKdH2OfGeltHWv97T+cLubOE8PqU7VxEadbOZ07HsWcUdtY+uNuLkc4P6peRETyNofdQVxU4h31dStXvr744oupf37rrbcoV64cL7/8MtOnTyc0G9/vt2zZwp9//gmk7OdTs2bN646dMmUKyckpH6xdXQ0E4Ofnl7oJ9ZQpU7DZ1CYucpVOE5ObqlatWpq/JyUlERISAug0MZEMsSXDlrHw90eQcMX5mNpPwAPDUopIuUhUYhTf7/yeKfunYHMYH6TcsPDm6bu5a9ZuHFHRhrjJ3Z0Cz/Wj4LPPOj2qPrPSc+pYo3IFGNapGoFJsHbmYU4euOh0nNlqouZ9JajbthQe3jc/CldERHJeek7miYtKZNwba1ycWc56ekQTvPyy7/32v+x2O3369GHcuHFO48HBwTRv3pwePXrQoUMHTCbjpzHXO03MZrNx7Ngxli9fzuDBg4mMjMRisTBr1iw6d+583Zzuvvtutm/fTqNGjVi7Nu2+jIsWLaJDhw4ALFmyhDZt2qT7e9VpYpLddJqYiMidyGKF+v3ghS1Q81HnY7ZPgm/rwJZxYM89n175ufsxqN4gZnWaRf0i9Q3xJGx8VGQzg/p5EtXEuALKkZjI+W++5eiDDxF7k9NAMiI9rWPrjkTS9ut/+GnXSe5/rjrtB9QkX7C3YZw92cGOP04wafAGdq88ic1265a9i4iI5FVms5mxY8eyfPly2rRpg/WadvGzZ88yffp0OnXqxD333MORI0dueL8yZcpgMpkwmUxYrVbKly/P888/T2RkJBUqVGD27Nk3LATt2bOH7du3A2lXBV3VunVrgoJSVhJPnDgxo9+uyG1LxSC5qb1796b5+vvvv3M6JZG8zS8YHvoRei2GIGN7FfGXYOGr8EsLCM++wkl2KJevHD+3/JlRzUZR2KewIR7qdolnmu5n+rPlcQQbVzclHjnC8R5PcHrwEGyXL2dbXultHXvgi1XstifSfXA9mnaviIePcb+j+JgkVk87xPQPN3Fs9/ls2/NIRETkdtKyZUuWLFlCZGQkixcvZtiwYXTs2JGAgIDUMVu2bKFp06acPn06w/c3mUx07949dVXP9VxdYeTm5sYjjzxiiFutVrp37w7A3LlziYqKynAuIrcjFYNERHJK6cbw3D/Q6mNw9zXGT22Hn1vAglcgNvtO5soqk8lEq9KtmNd5Hn1r9sXdbFyGPivoGD2fuMTB1pXAbHyruTRjBkfad+DyokXZWmxJz6ljL07dzlPjN+NVNYAnPmhIrQdKYLYYl7BfPBPLojG7mP/1Ds6fNLa+iYiICPj7+9O2bVuGDBnC/PnzOXv2LOPGjSMwMBCA06dPM3jw4Otev2zZMnbv3s3u3btZv34948aNo1atWjgcDj766KM0exRdy2azMWXKFADatm1LgQLGPQ4BnnzySQBiY2OZOXNmZr9VkduK9gySDLtVPYsid7Qrp2H5u7BnlvO4V35oOQxqPeG0uJKTwq6E8fnmz1l5cqXTeI3z3rz+lw/eR51/Kuhzb1MKD3kf9+LFsjWv2MRkxqw4zM+rQ0l00vJlNZt4ukkZXmpRAduVRNbPPsKR7RFO72UyQZVGRbinU1l8AoxH1ouISM5Iz/4bDruD+JgkF2eWszx93DCZnZyc4ELLli1L3Z8nMDCQ8+fPY/73GeZ6ewZdlZSURIcOHVi+fDkAs2fP5sEHH7zhHOnVvHlzVqxYka6x2jNIsltu2jNIxSDJMBWDRG6hoyth8Rtw/pDzePF7oP1IKGI88j2nrT65ms82fcaJqBOGmNnuoOe+INr8eRFTXIIhbvLyIuiFF8jf8ylMTo6qz4qjEdEMXbCP1YecF3qC/T14r31VOtQswunDl1gz4zARJ5wvIXfzsHB3m1LUalECq7vxyHoREXGtzPxiJa5TsmRJwsLCgJTj36/u3XOzYhCkrCiqVKkSUVFRlC1blgMHDuDmlvaAhx49eqSuDEovk8nEsWPHKFmy5E3Hqhgk2S03FYNy18fLIiJ3urLN4bm18MBQcDNucszJTfBT85SCUdwl1+Z2E/cWv5c5nefw8t0v42VN26JlN5sYX/08z/dOJqxmsOFaR1wc50aMILTbI8Tt3p2teaW3dazHLxuJ8bfS7a26PNCrCj75jCuAkhJsbJx3lMnvb+DQpjO39OheERGRvK5o0aKpf3Z2qtiNFClShJdffhmAo0ePMnbs2DTxqKgo5s6dC0CLFi2YOnXqDb9+/PFHIOUkUm0kLaKVQZIJWhkk4iKXT8LSt2H/fOdxnyBo+SHc9WhKH1MucibmDKO2jGLpsaXGoMPBvSHu9P3bgvtFJ3vxmM0E9uhB0MsvY/H1yda8MtI65mEyseOPE2xbdpzkROcnixUq7U+Th8tTpHy+bM1TRETSRyuDcq/Y2FgKFy5MVFQU/v7+XLp0KbUglJ6VQQAXLlygVKlSREdHU6pUKQ4fPpx6etn48eN5+umnAZg5cyZdu3a9aU61atVi586dVKpUiQMHDtx0vFYGSXbTyiAREbm5gOLQfSI8MQvylzPGYyJg7nMwvi2c3ev6/G6gsE9hRjQbwbjW4yifr3zaoMnE6opJ9OkVx/oGATiuLWTZ7VycOJGjHToQlc2nF/731LGmFYynnV09dazFqJUs3X+Wuu1K88QHDancqAg4qbedO3aF2SO3seznPVw5H5etuYqIiOQ20dHR1K9fn4ULF2K3O/+gBMBut/Piiy+mntzVqVOnDK8MAsifPz/PPfccAMePH0+zomfChAkAeHt707Zt23Td7+GHHwbg4MGDbNy4McP5iNxOtDJIMkwrg0RyQHICrP0G/hkJyfHGuMkC9Z+D5m+Bp7/r87uBZHsy0w9OZ8z2MUQlGffiqXDSwet/eZH/lPMTu/xatiT4vXdxCza2l2WFw+Fg2d4zfLhwP+GXnBdyGpUrwLBO1agQ7EfEiSjWzgoh/OAlp2PNVhN33V+COm1L4+GVvfseiYiIc1oZ5FrR0dH4+fkBUKxYMbp06ULDhg0pVaoUfn5+XLp0ie3btzNu3Dh2/9v2HRAQwI4dO9Ks/knvyiCAM2fOUKZMGeLj46lYsSL79+/n5MmTlC5dGofDQdeuXdN9Qtj+/fupWrUqAM8//zxjxoxJE585cybR0f9/HlmzZk1qe9ozzzxDkyZNUmOFCxfO8ObVIrlpZZCKQZJhKgaJ5KCLx2HpW3BwsfO4b2Fo/TFU75rrWsci4yL5etvXzDk8xxCz2Bw8uNnMw2vsmJNshrjZx4eg114l8NFHMVmyd+Pm9LSOPdOkDC+2qICPu4Vju86zdtZhLp9zXkDy9HWjfscyVG1SFLNFC3BFRG4lFYNcKz4+njJlynDmzJl0ja9QoQJTp06lTp06aV7PSDEIYMCAAXz33XcATJkyhaNHj/Lee+8BMHXqVB599NF0fw/VqlVj3759FChQgFOnTuHu7p4aK126NMePH0/XfZo1a8bKlSvTPa8I5K5ikJ5SRUTyksBS8NhUeGw65CtljEefgVnPwG8d4dzNe+FdqYBXAT5o/AGT202mWoFqaWI2i4mZDRy8/DQcKudpuNYeE8PZDz/i2OOPE3/wYLbmlZ7WsR//bR1buOs0pWsW5LEh9WnySAU8vI0rgOKjk1g19RDTPtzE8T2R2ZqriIhITvL09CQ8PJy1a9cybNgw2rZtS9myZfHx8cFiseDv70/lypXp3r07U6ZMYc+ePYZCUGYMGjQo9SSx4cOHp7aIeXh40L59+wzd6+reQpGRkSxatCjLuYnkVVoZJBmmlUEiuURSHKz5EtZ8BTbjce2YrdBwANw7CDx8XZ7ejdgdduaEzOHrbV9zMeFi2qDDQdM9Dp5ZacE7Osl4sdVKgd69KPj885i9jKeDZUVGW8fiY5LYsvgYu1ecxH6dk8VKVs1Po67lKVAsd/03EBG5HWhlkIjkJblpZZCKQZJhKgaJ5DKRR2DJIDj8p/O4fzFo8wlU6ZTrWscuJ1xmzI4xTD84HbsjbYuWX6yDp1ZAs13GtjEAtxIlKPz++/g2aZzteWWkdczXw8qls7Gsn3OEozsinN7PZIKqTYpyT8eyePu7Ox0jIiIZp2KQiOQlKgZJnqZikEgu5HDAgYUpR9FfDnM+ptz90G4kFHByMlkOO3jhIMM3DmfbuW2GWLVjdvovN1Eo0nlRyL9DB4LffgtrgQLZntfRiGjen7+Xf0LOO40H+3vwXvuqdKhZBJPJRPjBi6yZGcL5MOebYbt5WqjTphR3tSiB1S179z4SEbkTqRgkInmJikGSp6kYJJKLJcbA6pGw7luwO2mxsrhDo5eg6evg7u36/G7A4XCwOHQxX2z5gnNx59LE3JIdPLTWTpeNKZtNX8scEEDwoDcIeOihTB1de7O8MtI65rA7OLjxDBvmHiHmcqLT8X75PWn4YDnK1y2U7fmKiNxJVAwSkbxExSDJ01QMEskDzofA4oFwdKXzeEBJaPspVGqX61rHYpJi+GnXT0zYN4Fke3KaWLHzDvotsVP5pPO3Lu969Sg8bBgeZctke14ZbR1LSrCx/Y8TbF9+nORE43iA4DL+NOlWgcJlA7I9XxGRO4GKQSKSl6gYJHmaikEieYTDAXvnwLJ3IOq08zEVWkPbzyB/9hdPsir0ciifbfqMtafWpnnd5HBw/w4HT65w4J1gfAszublRoF8/CvTtg9k9+/fnyWjrWPTFBDbOO8KBDdc/hrd83UI07FIO/4LZuyG2iMjtTsUgEclLVAySPE3FIJE8JiEKVn0GG76Ha1baAGDxgKavQeNXwM14rHtOcjgcrAhbweebPyc8OjxNLCDaQa8/7TTe7/xtzL1sWYp8MAzvunVvSV7paR37oHM1yhfyA+Dc8SusnXmYUyGXnI63WM3c1aIEddqUwt3LeGS9iIgYqRgkInmJikGSp6kYJJJHndsPiwbC8TXO44Gloe0IqNjKpWmlR3xyPOP3jmfs7rEk2BLSxGoftvPMcjuFLju/Nl+3hyk0cCCWgOxvxbraOvbT6qMkOdnL6NrWMYfDQejO86ybdZjLEc6LSF5+btzTsSxVGxfBbDFne84iIrcTFYNEJC9RMUjyNBWDRPIwhwN2z4Tl70L0WedjKndIOYo+X0nX5pYO4dHhjNw8kj9P/JnmdY9EB4/8Y6fdZgcWJ+9qlgIFCH77bfzbt7slGzZntHXMlmxnz6pwNi8KJSHWyWotIH9RHxo/XJ6SVbP/lDQRkduFikEikpeoGCR5mopBIreB+Muw8lPY+CM4nBzZbvWCewdCoxfB6uH6/G5i3al1fLrpU0Ivh6Z5vfQZB/2W2Ch3ne15fJo0ofDQ93G/Bf9uZaZ1LD46ic2LQtmzKhy73fnbcY3mxWncrTwWrRISETFQMUhE8hIVgyRPUzFI5DZyZndK61jYBufxAuWh3Qgod79r80qHJFsSk/dP5vud3xObHJv6usnuoM1WB4+tsuOZZLzO5OlJ0IsvkP+ppzC5uWV7XhltHQO4eCaGdbOPcGyX85VFxSrmo3Xf6nj5Zv+G2CIieZmKQSKSl6gYJHmaikEitxm7HXZNg+WDIdZ5MYKqXaD1cAgo5tLU0iMiNoIvtn7BwqML07xe4IqDp5fbqRfi/G3Oo3JlinwwDK+aNW9JXhltHQM4eeACa2YeJvJktGG8XwFP2vWvQcHifrckXxGRvEjFIBHJS1QMkjxNxSCR21TcRfj7Y9gyFhx2Y9zNB5q/CfX7gzX3rVDZdnYbn2z6hAMXDvz/RYeDew6lFIXyG+srYDIR2KMHQa+8jMXXN9tzykzrmN3uYN+aU6z5PQRbctr/DlZ3My16VqV8nULZnquISF509OhREhISMJlMVKpU6ZbsCycikh0cDgcHDx7E4XDg7u5OuXLl0nWdikGSa6gYJHKbO7UDFr0O4VucxwtWgvYjocy9Lk0rPWx2GzMPzeSb7d9wJfFK6ute8Q4eX2Wn5TYHznbesQYHU3jwe/g98MAtySszrWNnQ6+w5IddxFxONIyv264093Qog8msX3pE5M528uRJoqKiAChZsiQ+Pj45nJGIiHPx8fGEhqbsd+nr65v6O/XNqBgkuYaKQSJ3ALsdtk+EP4dC3AXnY6o/DK0/Br/CLk0tPS7GX+Tb7d8y89BMHPz/ba7CSQf9ltooGeH8Ot8HWlD4vfdwK3xrvqeMto7FXE5gyQ+7ORt6xTC2dM2CtOxdFXcv6y3JVUQkL7hy5Qrh4eFAyi9XxYsX1+ogEcmVzp07R2RkJADBwcHkz58/XdepGCS5hopBIneQ2Avw1zDY+hvg5O3C3Q/uewfu6QuW3FeU2Bu5l+Ebh7MrYlfqaxabgw6bHHRbY8fdyanuZh8fgl55hcDHH8NksWR7ThltHbMl2Vk59SAH1p02jAss4kO7/jXIV8g72/MUEckL7HY7hw4d4uqvNL6+vuTPnx9vb28VhUQkV7DZbFy6dIlz586lvlauXDnc3dO37YKKQZJrqBgkcgc6uRUWvQandziPF6oG7UdBqYYuTSs97A47C44s4MutXxIZH5n6evBFB88utXPXMedvg541a1Lkg2F4Vq58S/JKb+vYqy0r4mE1s3vlSdbMOIzjmiPoPbyttHqmGiWrFbgleYqI5HZRUVGEh4fz319rTCYTlltQ0BcRyQiHw4HNZkvzWlBQEAULFkz3PVQMklxDxSCRO5TdBlvHw18fQPxl52PuehxaDgPf3LfBcVRiFN/v/J4p+6dgc/z7puxw0GSvg15/2vF3tkjHYqFA714UHDAAs5fXLcnrZq1jVYv488MTdShZwJuTBy6w7Oe9xMckpRljMkHDh8pT64ES+iRcRO5IzgpCIiK5TUBAAEWKFMnQ85qKQZJrqBgkcoeLOQ9/vA87JjmPewRAi8FQ92kw575PZQ9fPMynmz5l45mNqa/5xjp4coWd+3Y5f0t0K16cwu+/j2/TJrckp5u1jvl5WvnykVo8UDWYK+fjWPz9LiLDYwzjKtYP5r4elbG6576fu4jIrWa324mOjubKlSskJiYaPo0XEckJFosFb29v8uXLh6enZ4avVzFIcg0Vg0QEgBMbUk4dO7vHebxwTWj/BZSo59q80sHhcPDH8T8YsWUEZ2LOpL5e9bidvkvtFL3Ontn+7dsT/PZbWDOwtDcjbtY69nzzcrzWsiL2JDt//7afI9uNO2EXKuVH2+dq4BuY8YcNEREREcldVAySXEPFIBFJZUuGzb/Aio8hwXjiFQB3PwUthoJP7tvTJjYplrF7xjJ+z3iS7CmtV27JDrqss/PgegdWu/Eac0AAhQa+Tr6uXTGZnR1Un3UhZ6PoP3kbh89FG2KNyhXgm8dqU8DbnS1LjrFpQahhjJe/O2371aBIuYBbkp+IiIiIuIaKQZJrqBgkIgZRZ+CPIbBruvO4VyC0eB/u7gm3qICSFWFXwvh88+esPLky9bVi5x30WWqjapjza7zr1qXwB8PwKFv2luQUk5DMm7N2sXCX8RSxYH8Pxjx+N3VL5+fojgj+HL+PpIS07RBmi4lmj1eiauOityQ/EREREbn1VAySXEPFIBG5rmNrYNFAiNjvPF707pRTx4rd7dq80mn1ydV8tukzTkSdAMDkcHD/Tgc9VtjxjTeON7m5UaBvXwr064s5nceDZoTD4eC3dcf4aNF+kq85RcxqNvFOuyr0blyaC6djWPz9bq5EGPcbqtG8OI27lcdiyX1FOBERERG5MRWDJNdQMUhEbsiWBBt/gJWfQqKxzQlMULc33D8YvPO7PL2bSbQlMmHfBH7a9RNxySnFlYDolBPHGu93/pbpXqYMhYcNxeeee25JTluPX2TA5G2cuWKsSLWvWYTPutbEmuxg+S97CNt/0TCmWMV8tO5THS+/7C9YiYiIiMito2KQ5BoqBolIulw5Bcvehb2znce9C0DLD1KOo8+FrWNnYs4wassolh5bmvraXUfs9Flmp9Bl59cEdH2I4DfewJIvX7bncz46gZenbWft4UhDrGyQDz8+UYdyBX1YP+cIO/409rb55fek3fM1KFjcL9tzExEREZFbQ8UgyTVUDBKRDDmyAha/AZEhzuMl6kO7kVCkpmvzSqdNpzfxyaZPOHzpMAAeiQ4eXmOnwyYHFifvoJb8+Ql++238O7THZDJlay42u4Mv/zjE6BWHDTFvdwufPFSDzrWKcXDDaVZMOogtOe0O2FZ3My16VqV8nULZmpeIiIiI3BoqBkmuoWKQiGRYciKsHw2rR0BSrDFuMkO9PnD/u+CZ+07ASrYnM/3gdEZvH010UkrrW6mzDvotsVHeuL8zAD6NG1N46Pu4//vvZXb6a/9ZXp2+gyvxyYZYz4aleLd9VS6ejGbJ97uIuZxoGFO3XWnu6VAGkzl7i1UiIiIikr1UDJJcQ8UgEcm0S2Gw7G3Yv8B53KcQtPoQanaHbF5Vkx3Ox53n621fM/fwXABMdgettzl4bJUdL2PNBZOnJwUHPE+BXr0wubllay5hF2J5btJW9p66YojVLpmPMY/fTYDJzNIfd3PmqHFM6ZoFadm7Ku5e1mzNS0RERESyj4pBkmuoGCQiWRbyJyweCBdDncdLNU5pHQuu6tq80mlXxC6GbxzO3si9ABS44uDp5XbqhTh/S/WoVIkiHwzD6667sjWP+CQbQ+fvZdpm4x5B+X3c+ebR2jQsnZ9V0w6yf61xCVNgYW/a9a9JvmDvbM1LRERERLKHikGSa6gYJCLZIike1n0D/4yCZGfntlugQX9o/hZ45L5Nj+0OO3NC5vD1tq+5mJBygle9g3aeWW4nv9ND1EwEPvYYQa+9isXXN1tz+X1LGIPn7iHhmj2CTCZ47YGKPN+8HHtXn2LNjBAc1xxR7+FtpdUz1ShZrUC25iQiIiIiWadikOQaKgaJSLa6eAyWvAWHljiP+xWBVh9B9a65snXscsJlxuwYw/SD07E77HglOHh0lZ3WWx04OyPNWqgQwYPfw79ly2zNY++py/SftI0TF4x7Mt1XKYgvu9ciOiyGZT/tIT4mKU3cZIKGD5anVssS2b7ptYiIiIhknopBkmuoGCQit8TBJbBkEFw64Txe5t6U1rGgSq7NK50OXjjI8I3D2XZuGwDlwx30XWqj9Dnn431btKDwe+/iVqRItuVwOS6J13/fyZ/7zxpixQO9+L5HHUp5urP4+91EhhuXL1W8J5j7nqiM1d2SbTmJiIiISOapGCS5hopBInLLJMXBP1/A2q/A5mRHZrMbNBwAzQaBu4/L07sZh8PB4tDFfLHlC87FncNic9B+s4Nu/9jxMB78hdnbm6BXXiGwx+OYLNlTgLHbHfy4+igjlh3gmo4w3C1mhnWuRtcaRfh7wgGObI8wXB9U0o92/WvgG+iZLfmIiIiISOapGCS5hopBInLLRR5JWSV0+E/ncf/i0OYTqNIxV7aOxSTF8OOuH5m4byLJ9mQKXXTw7DI7tUKdv+V61qhBkQ+G4VmlSrblsO7IeV6aup3z0caiWte7i/Nh52rs+yuMjfONm3h7+bvTtm91ipTPl235iIiIiEjGqRgkuYaKQSLiEg4HHFiYsp/QlZPOx5RrAe1GQIFyrs0tnUIvh/LZps9Ye2otOBw03ueg1592Aozb+oDFQv6ePQl6YQBm7+w53evM5XhemLKNLccvGmKVC/vxwxN1cITH8sf4fSTF29LEzRYT9z5akWpNi2VLLiIiIiKScSoGSa6hYpCIuFRiDKweCeu+BXuSMW5xh8YvQ5PXwD33HZHucDhYEbaCzzd/Tnh0OD5xDp5YYafFTudvv27FilH4/SH43ntvtsyfZLPz6ZIDjF1jXAHk52Fl5CN3US/Qj8Xf7+JyRJxhTI1mxWj8SAUsFmfbYYuIiIjIraRikOQaKgaJSI6IOASLB0LoKufxfCWh7edQqa1r80qn+OR4xu8dz9jdY0mwJVDlhIO+S2wUu+B8vH+7dgS/8zbWggWzZf5Fu04zaOZOYhJthli/ZmV5qUlZ/hq/n7B9xoSKVshHm77V8fJzz5ZcRERERCR9VAySXEPFIBHJMQ4H7J0Dy96BqNPOx1RsA20/g8DSLk0tvcKjwxm5eSR/nvgTa7KDB9fb6bLegZuxRoPZ358iH3yAf5vW2TL3kYhonpu4lZBzxpPEGpTNz1eP1OLIX+Hs+MN4optffk/a9q9BUAm/bMlFRERERG5OxSDJNVQMEpEclxAFqz6DDd+D3ckxXVbPlLaxxi+DW+48FWvdqXV8uulTQi+HUjQyZZVQ1TDnYwMff5xCbw7C7OGR5XljEpJ5Z85u5u04ZYgV8vNgTI+78T+byIqJB7Al29PEre5mWvSsSvk6hbKch4iIiIjcnIpBkmOqVauW5u9JSUmEhIQAKgaJSA47tx8WDYTja5zHA8ukbDBdoaVr80qnJFsSk/dP5vud3xOXFEPzXQ6e/NuOb7xxrEfVKhT/8kvcS5XK8rwOh4NJG47zwcJ9JNnSPgZYzCbebluZDsULsPTHPcRcSjBcX6dtKep3LIvJnPtOchMRERG5ndyqYpB2gxQRkbyrUBXotRAe+hl8nKxWuRgKkx+GaT3g0nWW3eQgN4sbvar3YsGDC2hfriMr7jLzal8La6sYiywJ+/YT+lBXrixZkuV5TSYTTzYsze/9GlI0IO3KKZvdwUeL9jNs7WHavVaLwmUDDNdvXXKcxT/sJjHOyaosEREREcn1tDJIMkxtYiKSK8VfhhXDYdNP4LAb41YvaPYGNHwRrLlzI+RtZ7cxfONwDl44wH27HDyz3I67k3pLvsceJfitt7KlbexCTCIvT9vOPyHnDbGyBX0Y82htTq84xf61xj2aAgt7065/TfIF575T3ERERERuB2oTk1xDxSARydXO7IZFr0PYRufxAhVSWsfK3efavNLJZrcx7eA0vtz6JYVOx/PqXBvFI43jPKpUofiXX+BeunQ2zOng679C+OavEEPMy83C8AerU/4K/PN7CA572scGdy8rrZ6tRqlqBbKch4iIiIikpTYxERGR9ChcA3ovhc7fgbeTY9kjQ2BiF5jRCy6Huzq7m7KYLfSo0oMp7adgrVCWt3tZWFXdSdvY/v2Edn2YK4sXZ8OcJl5rWZHxveoR4OWWJhaXZOPV33fy+6VLtH2hJp4+aeOJccksGr2TbcuPo8+XRERERPIGFYNEROT2YzZD7R7w4hao+wzgZKPjvXNgdD1Y+w3Yklye4s1UDKzItPbTaFOlC2M6mPmuvZkEa9ox9pgYwl97ndNDh2JPMG70nFH3VS7EwhebUKOYcZ+giRuO88qKAzQdUJ0CxXzTxBwOWD/7CH+M20dyoi3LeYiIiIjIraU2MckwtYmJSJ5zantK61j4VufxoMrQbiSUaeravNJpwZEFfLjhQwqcjuW1OddpG6tcmWJffoFHmTJZni8+ycYHC/cxZeMJQyzQ240vut5FwtoIjmw7Z4gHlfSj7XM18MvvaYiJiIiISMaoTUxERCSzitaGZ/6Ejl+DV6AxHnEAfusAs/pA1BnX53cTHct1ZHqH6XhXrMTbvSysrOGkbezAAY51fZjLCxdleT5PNwvDH6zByG534WFN+6hwMTaJpydt4WAZd+7pVMaw6CriRBQzPtnM6cOXspyHiIiIiNwaKgaJiMidwWyGOr3gha1wd0/nY3b/ntI6tuF7sOWuY9PLBJRhSvspdKnxKN91sDCmvZn4tNv3YI+N5dTAgZwe8j72+Pgsz/lwneLMHdCY0gXSnhbmcMCXf4Uw+kwE9/augpunJU08LiqJuV9uZ+8/uW9PJhERERFRm5hkgtrEROS2cHILLHoNTu90Hg+uDu1HQckGrs0rHZYfW877694n4HQUr82xUcJ4KjwelSpR7Msv8Sib9baxK/FJDPx9J8v3nTXEiuXzYmTrKhyZfYzLEXGGePVmxWjySAUsFn3+JCIiIpJRahMTERHJTsXrQp8VKXsFeRo3TObsHhjXGuY+D9ERrs/vBlqVbsXvHX8nsFIN3ulpYUVNJ21jBw8S+vDDXF6wIMvz+Xu68eOTdXi7bWUs5rRzhV+Ko+fMHThaBlOian7DtXtWhTP/qx3ERSVmOQ8RERERyR4qBomIyJ3LbIF7+qS0jtXq4XzMjskwug5s+hnsueekrBJ+JZjQdgKP1urF9+0tjO5gbBtzxMZy6o1BnB48OMttYyaTiX7NyjH52foU9PVIE0u02Rm8eD8rCzmo0cL4adWpkEv8/slmIsKispSDiIiIiGQPFYNERER8g6DLd9B7aUp72LXiL8PigfDzfSntZbmEm8WNgfUGMvr+0eysG8jbvSycKGgcd2nGTI490p2Eo0ezPGeDsgVY/FIT7iltXAU0e8cphoef4a6Hy2FxS/uIEX0hgdmfbyVki7HVTERERERcS8UgERGRq0o1hL6roM2n4O5njJ/eCb88APNfgtgLrs/vOpqVaMbMjjMJrlqHd3pZ+OsuJ21jhw4R+nA3Ls+fn+X5Cvl7MrlPffreW9YQO3Amiv7/HKBwl5L4BqZdQZScZGf5L3vZMPcIDru2LBQRERHJKSoGiYiI/JfFCg36w4tboMYjTgY4YNtv8O3dsPVXsNtdnaFThX0KM7b1WHrV6cdP7ax82/E6bWOD3uTUe+9hjzNu9pwRbhYz77Srwvc97sbXw5omFpWQzCt/7OdUvQCCy/obrt269DiLv99FQlzuOrFNRERE5E6h08Qkw3SamIjcUUL/SWkRizjgPF6sTsqpY0VruzavG1h/aj1v//M2HuHneXWOjVJO9r/2qFCBYl99iUe5clme72hENP0nbePgWeOeQPVLBdLTw5/QTecMsXzB3rR/vib5gr0NMRERERHRaWIiIiI5o0xTeG4NtPwQ3HyM8fCt8NN9sOh1iLvo+vycaFi0ITM7zaRk9Ya809PCn7WctI2FhBD6cDcuzZ2b5fnKBvkyZ0AjHqxdzBDbePwiQ8LPULJlMczXnER26WwsMz7dwvE9kVnOQURERETST8UgERGRm7G4QeOX4IXNUO1BJwMcsPkX+LYubJ+cK1rHCnoV5MeWP9L/npcZ286drzuZiXNPO8YRF8fpt97m1DvvZrltzNvdyheP3MVHXarjbkn7eHE+JoFXth7BdF8hPH3T9q4lxiWzcMxOti07jhYri4iIiLiG2sQkw9QmJiJ3vCMrYPEbEBniPF6iAbQfCYVruDav69h2dhuDVg/CHHaGV+faKG3s2MK9fDmKf/UVHuXLZ3m+nWGXeH7yNsIvGQtMHcoF0ei0g4unYgyxCvWCuf/JyljdLVnOQUREROR2oDYxERGR3KLcfdB/LbQYAlYvYzxsA/x4Lyx5K+VY+hx2d/DdzOw4k4p3Nefdpyz8UdvYNpZ4+Aih3R7h0uw5WZ7vrhL5WPhiE+6tGGSILTwSwViPGApVCzTEQjafZfbIbURdiM9yDiIiIiJyfSoGiYiIZIbVA5q+ntI6VrmDMe6ww8bvYXQ92PU75PBC3Hye+fj2/m95tdGbjG/ncf22sXfe4dRbb2OPjc3SfIE+7ozvVY9XHqiA6Zra05GLcQw+ewa/egXhmljEiShmfLKZU4cvZWl+EREREbk+FYNERESyIl8JeHQy9JgJgWWM8eizMLsP/NoBzu13fX7/YTKZeLLqk0xqO4nj9UvyVi8LxwoZx12eO5fQbo+QEHKdNrh0sphNvPJARX7tfQ/5vNPuFRSfbGdISBiRtfxx80zbFhYXlcS8L7ez95/wLM0vIiIiIs6pGCQiIpIdKrSE5zdA83fA6mmMH18DPzSBZe9CgvEIdleqVrAav3f8nZp12vBuTwvLnbWNHfm3bWzW7Cxv7NysYhALX2zCXcUDDLFxoWdZUdyMT4G0PzO7zcHKyQdZNeUgtuSc35BbRERE5HaiDaQlw7SBtIjITVwIhSVvQsgy53G/ItD6Y6j2EIYeKhdyOBzMODSDzzZ9Rp098fRbYsc70TguoHMnCg8ZgtnHJ0vzJSTb+HDhPiZtOGGIFfJ043mPfMQcjzbEilbIR+s+1fH2dzfERERERG5n2kBaREQkr8hfBnr8Do9OhYCSxnjUaZj5NEzsAhGHXJ7eVSaTiUcqPcKU9lM407Acb/a2cDTYOO7yvPmEdnuE+ENZy9XDauGjLjX4svtdeLqlfQQ5F5/EB5cjSK7oa7juVMglZny6mYgTObuiSkREROR2oWKQiIjIrVK5HQzYCPe+ARYnq1qOroTvG8GfQyHReNS6q1TKX4lp7adRv15nBj9lYdndTtrGjh7l2CPduTRzZpbbxh6sXZy5AxpTpmDalUZ24MtzEYSU9cBiTfuIEn0hgdkjthKy5WyW5hYRERERtYlJJqhNTEQkEyKPwOI34MhfzuMBJaDNJyknk+Vg69j8I/P5aMNH3LU7hucWO28b8+/UkSLvv5/ltrGo+CTemLGLpXvPGGLVPT3pHONOYlSSIXZ3m1LU71QWsznnfk4iIiIirqA2MRERkbysQDl4YhY8MgH8ixnjl8Ng+hMwuVtK4SiHdCrXiWkdpnGhUeWUtrHCxjFX5i8g9OFuxB88mKW5/Dzd+P6Ju3mvfRUs1xR29sTHM9oahaWQh+G6bUuPs/i7XSTEJWdpfhEREZE7lYpBIiIirmIyQdXO8MJmaPwKmK3GMYf/gO8aworhkBTn8hQBygaUZXK7yTRr0J33nrSwpI6TtrHQUI490p2Lv/+epbYxk8nEs03LMrVPAwr5pS38XHY4+DzhEpeLGFvsju+JZOanW7h4Jufa60RERETyKrWJSYapTUxEJJtEHITFAyF0tfN4vlLw0E9QsoFr8/qPZceWMXTdUKrtvkL/xXa8E4xj/Dt0oPDQoVh8s9Y2di4qnhenbGdj6IW0AQe08vCm1nlw2NM+trh7WWn1TDVKVS+QpblFREREciO1iYmIiNxugirBU/Ph4XEpx81f69JxGN8O1nwFdrvL0wNoXbo1v3f8nejGNXizt4UjztrGFi7k2MMPE3/gQJbmKuTnyeRn69OvWdm0ARMsT4xljn8iZk9LmlBiXDILx+xk27LjWd7YWkREROROoWKQiIhITjKZoHrXlNaxhi+AKW2xA4cN/nwfpj4KsRec3+MWK+FXgoltJ9K6cU8GX69t7NixlLax6VlrG7NazLzdtgo/PlkHP4+0bXQhJPO9ewzJ/te01zlg/Zwj/DFuH0mJtkzPLSIiInKnUDFIREQkN/Dwg9Yfw3NroGRDYzxkGfx4L4Rtdn1ugJvFjTfqvcGXrUYzu2N+Rj5kJuaavZ0diYmcef99Tr0+EFt0dJbma12tMPNfbELlwn5pXr9idjDaFMXZAOMjTMjms8wZuY2oC/FZmltERETkdqdikIiISG4SXBV6LYJmbwHXrMC5HAbj28D6MZBDLVHNSzRnZseZJDety6CnLRx20t12ZfFijnV9mPj9+7M0V5mCPsx5vjEP3Z329LUkE0wghi3+xta5iBNRzPhkM6dCLmVpbhEREZHbmYpBIiIiuY3ZAve9DU/OBu+CaWP2ZFj2Tsox9HGXciS9wj6FGdt6LF2a9mXIk1YW1XPSNnb8OMe6P8rFadOy1Dbm5W5hVLe7GP5gDdwt/3lsMcEKcwJzfBLAmnb+uKgk5n25nT2rwzM9r4iIiMjtTMUgERGR3Krc/SltY6UaG2MHFqa0jYVvc31egNVs5aW7X+K7Nj+xsEMhRnQ1E+2ZdowjMZEzQ4dx6vXXs9Q2ZjKZeLx+SWb2b0ixfF5pYofd7Iz1iiPeM+0jjd3uYNWUg6ycchBbcs5svi0iIiKSW6kYJCIikpv5F0k5cazp68bYpeMwrjVs/CnH2sYaFm3IzE4zMTdryJu9LYQ4bRtbQmjXrsTv25eluWoWz8eil5rQvFJQmtcvWBz85B7DaS/jNXtXhzPvq+3EXknM0twiIiIitxMVg0RERHI7ixVaDIEes8Arf9qYLRGWvAEzekH85RxJr6BXQX584Eceve8lhj7lxkInbWNJx09wrPujXJgyJUttY/m83RnXsx6vtayI6T/TJJhhsnscW7ySDdecPnyZGZ9sJuJEVKbnFREREbmdqBgkIiKSV1R4AJ77B0rUN8b2zYWfmsPpXa7OCgCL2ULfmn35ud14lnYqwufO2saSkjj7wYeEv/IqtqjMF2bMZhMvtajAhKfvIdDb7f/3N8EKjyQWeCfiuOYJJ/piArNHbCVk89lMzysiIiJyu1AxSEREJC8JKJ5y2lijl4yxC0fhlwdgy7gcaxurE1yHmR1n4nP/fQx62sKhosYxUcuWEfpQV+L27M3SXE0rBLHwpabUKpEvzesH3G1M9I4nzpp2fHKSneVj97J+zhHs9pz5+YiIiIjkBioGiYiI5DUWN2j1ITw2DTzzpY3ZEmDhqzC7DyRkftPmrAj0DOTb+7+ld4tBfPiUBwvucdI2FhbGscce48KkyVlqGyuWz4vf+zWkZ8NSaV4/a3UwzjuO027Ge29bdpzF3+0iIc7YUiYiIiJyJ1AxSEREJK+q1DalbaxYHWNs94yUtrGzWVt9k1kmk4mnqj3F+PYT+btzST572Ng2RlISZz/6iPCXX8lS25i71cywztX5+tFaeLlZUl+PNcMU73h2uhuLPsf3RDLz0y1cPBOT6XlFRERE8ioVg0RERPKyfCWh91Jo8LwxFhkCP7eAbRNzrG2sRlANZnScQf4HWvPG9drGli9PaRvbvSdLc3WuVYx5LzSmbJBP6mt2Eyz3TmK5VyL2axYoXToby8xPt3Bs9/kszSsiIiKS16gYJCIiktdZ3aHNJ9B9EngEpI0lx8H8F2Buf0jMmVUwfu5+jGo2iudbD+Hjp7yYV/8GbWMTJ2WpbaxisB/zX2hC+xppz7jf6WFjuk8C8ea0906Mt7Hou11sW3Y8S/OKiIiI5CUqBomIiNwuqnSEfqugSC1jbOdU+Pl+OHfA5WlBStvYI5UeYWKnqaztUo5PupmJurZtLDmZsx9/TPhLL2O7ciXTc/l6WBn9eG0Gd6iK1fz/wtNJq53ffBM4Z7GnvcAB6+cc4Y+xe0lKtGV6XhEREZG8QsUgERGR20n+MvDMcqjXxxiLOAA/3wc7p7k+r39Vyl+J6R2mU6J1FwY9Y+FAMeOYqD/++LdtbHem5zGZTDzTpAzT+jYg2N8j9fUrZgeTfRPY72bcRyhkyzlmj9hK1IX4TM8rIiIikheoGCQiInK7sXpA+5Hw8Hhw90sbS4qFOf1g3guQFJcj6Xm7efNxk495te3HfN7Tl3kNnLSNnTzJscce58KECVlq36pbOj8LX2xKw7IFUl9LNsFC7yRWeSZx7Z3Ph0Uz45PNnAq5mOk5RURERHI7FYNERERuV9UfSmkbK1zDGNs+EX55AM4fdn1e/+pcvjOTO09n00OV+aSbmSte1wxITubs8E84+eKL2C5fzvQ8QX4eTHzmHvo3L/f/F02wyTOZ2T4JJJjSloTiopKY9+UO9qwOz/ScIiIiIrmZikEiIiK3swLl4Jk/oU5vY+zsHvipGeye6fq8/lU2oCxT2k2hQrtHGfS0hQPFjWOi//yLow8+RNyuXZmex2ox82abyvz8VF38PK2prx91szPJN4EL5rT7CNntDlZNOcjKyQewJduvvZ2IiIhInuaSYlDZsmUpW7Yso0ePdsV0IiIi8l9untDxK3joF3DzSRtLjIZZz8DCVyEpZ/bK8bR6MrjhYN7uOJKRvQKY09DYNpZ86hTHHn+cyF9/zVLbWMuqwSx8sQlVivinvnbB4mCSbwJHrMbNo/f+c4p5X20n9kpipucUERERyW1cUgw6efIkx48fp1atWq6YTkRERJyp2Q36roRCVY2xLeNgbEuIPOLytK5qU7oN0zrPYFfXmnz8iLO2MRvnPv2MkwNewHbpUqbnKVXAhznPN6Jbnf8vQ0owwxyfRDZ4JBnGnz58mRmfbCbiRFSm5xQRERHJTVxSDCpcuDAAXl7XPtWJiIiISwVVhGf/glpPGGNndsFPzWHfPJendVUJvxJMbDuRmh178sYzFvaVMI6J/vvvlLaxHTsyPY+nm4UR3e7is641cLemPA45TPCPVzILvBO5dmvp6IsJzB6xlUObz2R6ThEREZHcwiXFoPr16wOwd+9eV0wnIiIiN+LuDV3GQJfvwXrNBzUJV+D3p2DxIEhOyJH03CxuDKo3iA86j+brXvmZ3chJ29jp0xx74gkix2etbax7vZLM7t+IEvn//3M44G5jim8CV0xp9wpKTrLzx9h9rJ9zGLs983OKiIiI5DSXFIP69++Pw+Hgyy+/JCnJuPxaREREckCtx6HvCihYyRjb9COMawMXj7k8raual2jO711mcbBbXT7ufp22sc8+4+TzA7LUNla9WAALX2hKi8qFUl87Z3Uw0S+BMItxH6Fty06waMwuEmL1TCMiIiJ5k0uKQffffz9vv/02O3fupEOHDoSFhbliWhEREbmZQlWgz99Q81Fj7NQ2+PFeOLDI9Xn9q7BPYca1Hsc9nfsx6Bmr87axFSs4+uBDxG7fnul5Arzd+PmpurzRuhLmfxcixZrhd99EdrgnG8af2BvJzM+2cvFMTKbnFBEREckpJkdW1lan0wcffADArFmz2L17NxaLhcaNG1OzZk0CAwOxWCw3vH7IkCG3OkXJgJMnT1KiRMrTeFhYGMWLOzkHWERE8haHA7ZNgCWDINnJqWINX4AHhoLFzeWpXbXu1DreXfkWLf48z4PrHMZPtKwWCr36Kvl798ZkzvznXWsPn+elqduJjPn/CWJ3JVhoEeeGhbQta+6eFlo+U43SNQpmej4RERGR67lVv3+7pBhkNpsxmf7/8ORwONL8/WZsNuMSbck5KgaJiNzGzuxJ2TPogpNTxYrXg4fHQz4ny3Nc5Hzced765y1i163npfl2AmKNY3ybNaPIp59gDQzM9DynL8cxYPI2tp24lPpasWQznWPc8XFc8wxjggady3J361IZer4RERERuZlb9fu3S9rEIKUAdPXr2r/f7EtERERcpHB16LcKqnc1xk5uhh+bwqFlrs/rXwW9CvLjAz/S/KGXeesZN/aUNBZfoletIvTBB4ndlvm2sSIBXkzr25BejUqnvhZutTPRL4GzlrQbS+OADXOPsnzsXpIS9QGWiIiI5H4uKQbZ7fYsfYmIiIgLefhB17HQfhRY3NPG4i7ClEfgj/fBZtxLxxUsZgt9a/ZlZNdx/PhMYWY0NnHt00LymbMcf/IJIn/5BUcmnyXcrWaGdqrGt4/Vxts9paU9yuxgim8C+92M3/vhLeeYPWIrVyLjMjWfiIiIiKu4bGWQiIiI5CEmE9R7Fp75AwJLG+Nrv4LfOsCVU67OLFXdwnWZ0XkWZx+/j48fNXPJ+5oBNjvnRo4i7LnnSL54MdPzdLyrKPNfaEz5Qr4AJJtgoXcSqzyTcJB2BfP5sGhmfLKFUyGZn09ERETkVlMxSERERK6vaC3otxqqdDLGTqyHH5rA4T9dntZVgZ6BjL5/NK27DeKdZz3YU8rYNhaz+h+OdulC7NatmZ6nfCE/5g1oTIeaRVJeMMEmz2Rm+SQSf01BKD46iXlf7mDPqpOZnk9ERETkVlIxSERERG7MMwAemQBtPwfzNaeJxUbCpIfh74/AnjP75ZhMJnpW68k33SYy7tkSzGhibBuznT3H8aee4vxPP2e6bczHw8q3j9VmaMeqWP89fz7Uzc4kvwQizWnvabc7WDX1ECsnH8CWrJZ3ERERyV1cXgy6cOECo0aNom3btpQoUQIfHx98fHwoUaIEbdu2ZdSoUVy4cMHVaYmIiMiNmExQvx88vQwCSl4TdMDqETChM0SdyZH0AGoE1WB65xlceqINHz1m5pLPNQNsdiK++IKwfs+RnMlnDZPJRK/GZZjeryGF/T0BuGhxMMkvgSNWYzFs7z+nmPfVdmKvJBpiIiIiIjnFJUfLX/Xjjz8ycOBAYmNTzoG9duqrx7F6e3szatQo+vbt66rUJAN0tLyIyB0u7iLMfR4OLjbGfApB11+gbDPX5/Uvh8PB7wd/58cVn9FvXjw1jxkfdSzBhSg+ahTedetmep7z0Qm8NHU7645EAmByQJN4Kw0S3AxjfQM9aNe/JkEl/TI9n4iIiNx5btXv3y4rBn366ae8++67qQWggIAAateuTeHChQE4c+YM27dv5/LlyymJmUx88sknDBo0yBXpSQaoGCQiIjgcsH40/DkU7NeerGWC5m/DvQPBbMmJ7AA4eOEgb6x4nbuXhtJtjR3ztU88FjNBL75Egb59MJkzt1jaZnfw5R+HGL3icOprlRMttIl1w420+xdZ3Mzc/1RlKtYrnKm5RERE5M6Tp4tBe/bsoXbt2thsNooUKcKIESPo1q0bbm5pPzlLTk5mxowZvPHGG5w6dQqr1cr27dupVq3arU5RMkDFIBERSRW2CWb0givhxljZ5vDQL+Ab5OqsUsUmxfLRho84umIeL82zExhjHOPTuDFFP/8Ma4ECmZ7nr/1neXX6Dq7EpxTGCiWb6BLjToDDWGSq3aokDbqUw2w2bnYtIiIi8l+36vdvl+wZNHr0aGw2G0FBQaxfv57HH3/cUAgCsFqtPPbYY6xfv55ChQphs9kYPXq0K1IUERGRzChxDzy3Biq0MsaOrkw5bezYGpendZW3mzfDmw7nsceHM6SvD7tKOzltbO1ajnbpQsymTZmep0WVYBa+2JRqRf0BOGdN2UcozGLcR2j78hMsGrOThNikTM8nIiIikhUuKQb9/fffmEwm3n77bUqWvHbTSaMSJUrw5ptv4nA4+Ouvv1yQoYiIiGSad354bDo8MBRM17SFRZ+B3zrC6pGQyVO8skPn8p35qfvvTO9XkWn3mrFfUxOyRZznRK/enP/hh0yfNlaygDez+jfi0Xopn97FmuF330S2u1/bRgcn9l5gxqdbuHjGyVIlERERkVvMJcWg8PCUpeONGjVK9zWNGzcG4NSpU7ckJxEREclGZjM0eRV6LQS/ImljDjv8/SFM6QYxkTmTH1A2X1mmdJyGpXd3hj1u4YLvNQPsdiK++poTz/YhOTJzeXq6Wfi0a00+f7gmHtaUotOf3kks90rERtrO/Mvn4pj56RaO7Tqfye9IREREJHNcUgyyWFI+JUxONn4ydj02W8qyanMmN3QUERGRHFCqEfT7B8rdb4wd/hN+bAonNrg+r395Wj0Z0nAITz8xkqH9/NjppG0sdt06jnTpQszGzLeNPVK3BLOfb0TJ/N4A7PSwMd03kRhT2oJQYryNRd/vYuvSY4ZTVkVERERuFZdUWq62hmWk5evq2PS0lYmIiEgu4hsEPWbBfe+B6ZpHjSvhML4drP0m5USyHNKmTBvGdZ/F3OerM7WZsW3MHnGe4717EfHddzhsxn1/0qNa0QAWvNiEB6oEAxButTPRL4Ezlmva0BywYe5Rlo/dS1JC5uYSERERyQiXFINatmyJw+Fg5MiR7N69+6bj9+zZw4gRIzCZTLRq5WRDShEREcndzGZo9gY8NQ98CqWNOWzwx2CY+hjEXsiZ/IAS/iWY0H4SPs885bRtzGR3cP6bb1Paxs5nrpUrwMuNn56sw5ttKmM2QZTZwVTfBPa5GVdLH95yjtkjt3IlMi5Tc4mIiIikl0uKQa+88goeHh5ER0fTpEkTRo4cSaSTXvzIyEhGjhxJ06ZNiYqKwsPDg1deecUVKYqIiMitUObelNPGSjc1xg4tgR+bwcktrs/rX+4Wd96850369/yWD57Lx44yTtrG1q9PaRvbsDFTc5jNJvo3L8ekZ+tT0NedZBMs8k5ipWcSjmv2ETofFs2MT7YQfuhipuYSERERSQ+Tw0UN6hMmTKB3797/n9hkokyZMhQqVAiTycTZs2cJDQ3F4XDgcDgwmUz8+uuvPPnkk65IL0+bNGkS//zzD1u3bmX37t0kJiYyfvx4evXqdUvmO3nyJCVKpJyUEhYWRvHixW/JPCIichux22DVZ7Dqc7imAILZDVp+AA36g8lYjHGV09GnGbTyDUot3M6jq+xYrknTYTYRNGAABZ97DpPF4vwmN3HmcjwDpmxj6/GUYk+ZJDMdYtzxJO33bTabaNq9AtXuLYYpB38mIiIikrNu1e/fLisGASxatIh+/fqlOSHs6gPOf9MoWrQoP/30E+3atXNVanla6dKlOX78OAULFsTHx4fjx4+rGCQiIrnTkb9hVh+IddJ2VbkDdB4DXvlcntZVSfYkvtvxHWsW/8zL82wUiDKO8W5Qn2IjRmANCsrcHDY7ny45wNg1oQAE2kw8GONOAbtxwXbVpkW5t3tFLFYdqCEiInInulW/f7v0yaJ9+/YcO3aM6dOn06dPHxo2bEjFihWpWLEiDRs2pE+fPkyfPp3Q0FAVgjLgl19+4dixY0RERPDcc8/ldDoiIiLXV+7+lLaxko2MsQML4cd74dR21+f1LzezGy/f/TKv9v6J4f0Lsq2sk7axDRs53KULMevXZ24Oi5nBHaoy5vG78XG3cNHiYJJfAkesxs2j9/1zinlfbif2SmKm5hIRERFxxiUrg06cOAGAr68v+fPnv9XT3dE+/fRT3n77ba0MEhGR3M2WDCs+hjVfGGMWd2g9HOo9m6NtYxGxEby96i0Kz9vgvG3MZCKof38KDng+021jmk+pKQABAABJREFUh89F03/SVkLORWNyQJN4Kw0S3AzjfAM9aPtcDQqV8s/UPCIiIpI35emVQaVLl6ZMmTJMmzbNFdOl27lz51i4cCFDhgyhbdu2FCxYEJPJhMlkynAh5fjx47z++utUrlwZHx8f8ufPT7169RgxYgSxsbG35hsQERHJqyxWeOB9eHwGeAWmjdkSYfFAmNkb4q/kTH5AkHcQP7b+iRL9X2LYE1bO+6WNmxwOzn/3Hceffpqkc+cyNUf5Qr7MHdCYTncVxWGCf7ySme+dSNI1+ypFX0xg9shtHNx4JrPfjoiIiEgqqysm8fLyIj4+nnr16rliunQLDg7OlvssWLCAJ554gitX/v/AGhsby5YtW9iyZQu//PILixYtonz58tkyn4iIyG2jYquUtrEZveHkprSxvXPg9C545DcoXCNH0rOYLfS7qx91guvwQdGBdJ95jruPpC3UxG3cxJEuXSgxciQ+jZy0v92Ej4eVrx+tRd3SgXy4cB8H3W1cNNvpEuNOgOP/n9vZkuz8OX4fkSejafBgOcxmbSwtIiIimeOSlUHFihUDwGYz9sLnFiVLlqRVq1YZvm779u10796dK1eu4Ovry8cff8y6dev466+/6NOnDwCHDh2iffv2REU52YVSRETkThdQHHovhkYvGmMXjsDPLWDLeHDdmRcGdQvX5bfH5rD25WZMvM+M7Zo6jOPCRY4/8ywR33yDIxPPOyaTiacalub3fg0pEuDJOauDiX4JhFmM99r+xwkWjdlJfExSZr8dERERucO5ZGVQq1at+P7771mzZg0NGjRwxZTpMmTIEOrVq0e9evUIDg7m2LFjlClTJkP3ePnll4mLi8NqtbJ8+XIaNmyYGrv//vupUKECgwYN4tChQ4waNYqhQ4ca7vH666+TkJCQoTkrVKiQoTxFRERyNYsbtPoISjWGOc9B/KX/x2wJsPAVOL4WOnwFHr45kmKgZyCjW45hYrGJfFDiC16cm0jB/3SxpbSNfU/05s0UHzkKt+BCGZ6jdslAFr7YhFem7+CfkPP87pvI/XFu1E5M+8h2Yu8FZn62hXb9a5K/iE9WvzURERG5w7hkA+mQkBBq166Nr68vW7duTV0plNv8txjUs2dPfv311xuO37RpE/Xr1wegX79+/PDDD4Yxdrud6tWrs3//fvLly8e5c+dwc0u7MaSvry8xMTHpznPFihU0b97caUwbSIuISJ538XjKfkHhW42xghWh228QXNX1ef3HrohdDF36Og9ND6fOYeOjlCkwgOIjRuHbpHGm7m+zO/j6rxC++SsEgJoJFh6Ic8NC2iVJbp4WWj1djdI1C2ZqHhEREcnd8vQG0hUqVGDKlCnExsbSoEEDpkyZQmJi3j8ide7cual/7t27t9MxZrOZp556CoBLly6xYsUKw5jo6GgcDke6v65XCBIREbktBJaC3kuhfn9j7Pwh+Pl+2D7J9Xn9R82gmvzafRbbXm/NhPvNJF/zROW4eJkTffpw7quvcCQnZ/j+FrOJ11pWZHyvegR4ubHLw8Z030RiTGkLT0nxNhZ9v4stS47hgs/3RERE5Dbhkjax+++/H4CgoCBCQ0N58skneeaZZ6hQoQKBgYFYbnAcq8lk4q+//nJFmhm2Zs0aAHx8fKhTp851xzVr1iz1z2vXrs3U3kQiIiJ3FKs7tP0USjWEeS9Awn/6sZLjYN4AOL4O2o0Ed+8cSdHf3Z9Rzb9gepHpfFjyU16YnUDQNW1jkT/8SMyWLRQfNQq3TBxccV/lQix8sQnPT97G7vDLTPRLoEuMO4Vt/6k+OWDjvKOcD4umRc8quHlk7ph7ERERuXO4pBi0cuVKTKb/L2t2OBwkJCSwZ8+e615jMplwOBxprstt9u/fD0D58uWxWq//o6xcubLhGhEREUmHqp1TThKb0QtO70wb2zEZwrelnDYWVClH0jOZTDxa+VFq9a3FkBKv0XHqMepe0zYWv2Urhzt3osSIUfg2bZLhOUrk92bGcw35YOE+pmw8wVTfBFrHulE1Ke2zx5Ft57h0LpZ2z9XAv6BXlr4vERERub25pBh077335uqiTmbEx8dz/vx5gJv27AUGBuLj40NMTAxhYWHZnssvv/ySukpp9+7dqa+tXLkSgCZNmvDss8+m+34nT568Yfz06dOZS1RERCQz8peFp5fD8ndh8y9pYxH74af7oMOXcFf3nMkPqJy/MuMfmcFHJT5k7/QF9Fhhx2r/z4BLVwjr04cCffsS9NKLmG7wIZIznm4Whj9Yg7tLBvLunN0s8k7iXIKDZvFWTP/ZRyjyZDQzPt1Cm77VKVYxMJu+OxEREbnduGxl0O3mv8fE+/re/FSTq8Wg6OjobM9lzZo1/Pbbb2leW7t2LWvXrk39e0aKQVc3pxIREck13Dyh/Sgo1QjmvwyJ/38fJikG5vSF42ug7efgljOrYnzcfBje9BPmFW3Ax6U+pP+sWApdTjsm8qefiN6yiRJffIlb4cIZnuPhOsWpVtSf/pO2sjkylgiLnY4x7nj+pyAUH53E/K920OSRClRvVuy2+0BOREREss4lG0ifOHGCEydOcOHCBVdM5xLx8fGpf3Z3d7/peA8PDwDi4uKyPZdff/31hhtO3+xUNBERkTyjelfouxKCqxtj2ybALw/A+cMuT+sqk8lEl/Jd+Ljv7/zwSgU2VzAWYhK27eBw505Er16dqTmqFPFn/otNaFU1mGNudib5JRBptqcZY7c7WD3tECsnH8SWbL/OnURERORO5ZJiUOnSpSlTpgzTpk1zxXQu4enpmfrn9JyMlpCQAICXV+7v4Q8LC7vh16ZNm3I6RRERuZMVLA/P/gl39zTGzu6Bn5rBnlmuz+s/yuUrx/huvxP69iOMf8B42hiXowjr24+zo0Zl6rQxf083fnyyDm+3rcwVN5jkl8ARq80wbt+aU8z7cjuxV/L+Ka4iIiKSfVxSDLpaAKlXr54rpnMJPz+/1D+np/UrJiYGSF9LWU4rXrz4Db+KFCmS0ymKiMidzs0LOn0DD/4EbtecJpYYDTOfhoWvQVK88+tdwNPqyfuNhnLf6yMZ3tuHcwHGMRd+/oWjTz5BUib24zOZTPRrVo7Jz9bH38+DOT6JrPdIMow7feQyMz7ZzLnjV5zcRURERO5ELikGFStWDACbzfiJVV7l6elJgQIFgJtvuHzx4sXUYpD24xEREclGd3VPaRsLqmKMbRkL41rBhaMuT+u/2pZpy2f9ZjNuYDU2VjS2jSVu30lIl85Er1qVqfs3KFuAxS81oV6Z/KzxSma+dyKJpD3RLPpiArNHbuPgxjOZmkNERERuLy4pBrVq1Qog9cSr20XVqlUBOHz4MMk3WOJ94MCB1D9XqeLkYVVEREQyL6gS9Pkbaj1hjJ3eCT82g33zXJ/Xf5T0L8nPD03h7LtPOW0bM12OIqzfc5wdMRJHknF1z80U8vdkcp/69L23LAfdbUz1TeCyKe1eQbYkO3+O38faWYex2x3XuZOIiIjcCVxSDHr55Zfx8vJi5MiRhIeHu2JKl2jSpAmQ0gK2devW645b9Z9P+ho3bnzL8xIREbnjuHtDlzHQ+TuwXrM/X8IV+P0pWPImJOfc3jnuFnferP8WbQZ9y6dP+3M2n3HMhbFjOfJED5JOncrw/d0sZt5pV4Xve9xNrI+FiX4JnLAYV2Xv+OMEC7/dQVy09hESERG5U7mkGFShQgWmTJlCbGwsDRo0YMqUKenadDm369KlS+qfx48f73SM3W5nwoQJAOTLl4/77rvPFamJiIjcmWr3SFklVLCiMbbxBxjXGi4ed31e/3F/yfsZ+dxcJr1Ri42VjG1jSTt3E9KlM1ErVmTq/m1rFGH+C40pWcSPGb6JbHM3rl4O23+R34drHyEREZE7lcnhcNzydcL3338/AMePHyc0NBSTyYS7uzsVKlQgMDAQi8Vy/QRNJv7666//sXffYVFdWwOHfzMMQy+CCoq9K/besPeuscZurFGTa+I1RZOYZmKKabYUO/bee8MWxY4dK9gRFZQ6MOf7g8t8wAxIGQbQ9T4PT/Dsfc5ZByEya/ZeK7tDBOD27duULFkSgMGDB6erJXuTJk04dOgQGo0GPz8/GjRokGz8xx9/ZNKkSQB88cUXTJ061dxhW9zdu3cNtY+Cg4MpUqRIDkckhBBCpBDzErZMgIBVxmO2LtBtLlToYPm4ktDpdcw6PZO7i/5h0D491iZKK+YbNhSPCRNQWVtn+PqRsXFMXn+B9WfuUTXGilZR1liRPPmk1qho2rc8lRoXzuxjCCGEECIbZdfrb4skg9RqNSpVwi8f6b2dSqVCURRUKlW2FZ4+fPgw169fN/z5yZMn/Pe//wUStnMNHz482fwhQ4YYXePMmTM0atSIqKgoHB0d+fTTT2nevDlRUVGsWLGCv/76C4By5cpx8uTJZF3I8ipJBgkhhMgTFAVOL4JtkyA+xni8wThoNRWsMp5oMacj944wd9V/GbbyGZ7Pjcc1Vbwp8etvWP+vIUdGKIqC7/Egvt58iQIx0CVCi6NivBqpYqNCNOlbDo116m/QCSGEEMLy8nQyqFmzZoZkUGbsz+Qy6VcZMmQIixYtSvf81L5UmzdvZsCAAYSHm15qXa5cObZu3UqZMmUyFWduI8kgIYQQecrDAFg1GJ7eMB4rUhd6LQCXnP23LCQyhM93TaTOQn8aXDH+fUPv5ECx6T/i1CJz283PBj9n7NLTPH8aRecILUXjjZM+BYo50W5kZZzz25m4ghBCCCFyQp5OBuVW5koGQcIWuN9++42tW7dy9+5dtFotZcqUoVevXowbNw57e3tzhJwrSDJICCFEnhMdDpvfh4vrjMfs3KD7n1CujeXjSiJeH89f5//kxoLZDNoTb3LbmOvgQXh++CEqrTbD138WEcv7K89y+GoITaOtqR2jMZpj46ChzTBvinm7Z+YRhBBCCGFmkgwSuYYkg4QQQuRJigL+/8DOTyHeRCOLxhOg+RSwMk6SWJL/Q39mrviAISuemNw2ZlW5IiV/+yNT28bi9Qq/7bnG7/uuUz7WinaR1mhT1BFCBXU7laR2+xKo1Jlf2S2EEEKIrJNkkMgx3t7eyf6s0+kIDAwEJBkkhBAiD7p/BlYPgWe3jceKNYSe88A5ZwsqP41+ype7P6Lq/CM0vGxi25ijPcWm/4BTy5aZuv6+K4/4z4qzWEfE0y1Ci5veuMFs8SrutBpSCVuHnK2pJIQQQrzJsisZZJHW8qbcvXuXkydP4ufnR1RUVE6FIYQQQog3TeEaMPIgVOhkPBZ0FOb6wHXLdDJNjZutG790+hOrr/7LvHYaYlOU+FG/jOTu2HHcnzYNJdbEKqdXaFHBgy3jffAo4sQSpxiumdiTdicglNXf+fPk7ovMPoYQQgghcimLrgx68eIFP/zwAwsXLuT+/fuG4wEBAVSqVMnw5xUrVrBu3TpcXFz4+++/LRWeSCfZJiaEEOK1oChwfC7s+gz0uhSDKmjyX2j2MahztsPW+ZDz/Lr8fQYsf0ihZ8bjau8KlPxtJtoiGd82Fq2LZ8qGC6w5eZc6MRqaRGtQp9g2ZmWtpnn/8pSvXyizjyCEEEKITMrz28QCAwPp0KEDN2/eTFaIWaVSGSWDbt++TZkyZVAUhYMHD9K4cWNLhCjSSZJBQgghXit3TyVsGwsLMh4r4QNvzQMnD4uHlVRYTBjf7p1Cub/30sjktjE7in3/A06tWmX42oqisPxEMFM3XcQzGjpFanEw0X6+clMvGvcqi5UmxxaWCyGEEG+cPL1NLDo6mo4dO3Ljxg3s7e2ZNGkSW7ZsSXV+iRIlaN48oXXqpk2bLBGiEEIIId5URWrBqINQrr3x2O1DMLcx3Dxo+biScLFxYXr733Gc9hnz2lub2DYWxd1x47n3zdcZ3jamUql4u14xVo9uQHwBGxY7RXPfSm8078LBe6z/+TQvn0Vn5VGEEEIIkQtYJBk0Z84crl+/joODA4cOHeL777+nQ4cOaZ7Tvn17FEXh2LFjlghRCCGEEG8yezfotxzafAOqFJmWiMewpBsc/AH0Jvq9W4hKpaJvxX68M2UlM8cW4b6b8Zxw32Vc7dOT2ODgDF+/WlFXNo9vTI3y+VnhGMMZbZzRnEe3wlk1zZ+7V03sVxNCCCFEnmGRZNC6detQqVS8//77VK9ePV3nVKtWDcDQtUoIIYQQIlupVNBwPAzdDs4p6u8oetj/Lfi+BS9Dcia+/6noXpFfR2xg/5cdOFzJeDuXcjmQa926EL5rV4av7eagZeHQurzbsgx77HVss49FR/JtaVEvdGz69Qynd95BmtIKIYQQeZNFkkGXL18GoE2bNuk+x93dHYDnz59nR0hCCCGEEKYVqwejDkGZ1sZjN/fDnz5w+4jl40rCwdqBr1v9RP7vv2F+BxtiNcnHrSKiuffe+9z9air6DG4bs1Kr+LBNeeYNrk2ws4pljjE8VyffNqYocGz9DXb8dYHYKOMVREIIIYTI3SySDHr58iUAjo6O6T4nJiYGAGtr62yJSQghhBAiVQ7u8PYqaPk5qFL8uvTiASzqDIdmgN64to6lqFQqupfrwejP1zJ3XAmT28ZeLFvJ1d49MrVtrGXFhPbz7kWdWOwYww2N8Ra5m2dCWP39SZ7ej8jMIwghhBAih1gkGZS4yuf27dvpPufixYsAeHp6ZkdIQgghhBBpU6vB50MYvAUcU/w+osTD3i9heR+IfJoz8f1PadfSzBi+jmNf9+CQt/G2Ma7c4FrXzoTt2JHhaxdzt2f9uw3pVLsI6xxiOWSrQ0mxbez5o0hWTz9J4MlHmX0EIYQQQliYRZJBNWvWBMDPzy/d5yxevBiVSkWDBg2yKywhhBBCiFcr0QhGH4ZSzY3HAncldBsLOm75uJKw09gxpeW3FPnhR+Z3sjXeNhYZw/3/TCD4yy8yvG3M1tqKn3pV5dselTntoGeNQyxRquQJobiYeHb9c5HDqwOJj8+51VJCCCGESB+LJIN69uyJoij89ddfBAUFvXL+r7/+akgc9evXL7vDE0IIIYRIm2MBGLAWmk8GUqy+Cb8HCzvA0T8SiunkoA6lOzL+8438/V5Z7robj79cvoorPbsTdeFihq6rUqnoX684q0c3IDa/lsWOMTw00X7+3N5gNv5yhoiwmMw+ghBCCCEswCLJoIEDB1K1alWio6Np1qwZ27dvT9Z9QqVSoSgK/v7+9O/fnw8//BCVSoWPjw/t27e3RIhCCCGEEGlTW0HTSTBoIzgUTD6mj4NdU2DF2xCVs23XizkXY8Y7azj73dscrGy8bUx17Sa3e/bk2tt9ebFvP0oG6h5VK+rKlvd8qFbBnWWOMZw30X7+wfUwVk3z5/7151l5DCGEEEJkI5VioZ6gQUFBNG7cmLt376JSqbC3tycyMhKA/Pnz8+LFC0PRaEVRKF26NEeOHKFgwYJpXVbkgLt371K0aFEAgoODKVKkSA5HJIQQQljYi0ew9h24fch4zKUY9FoIRWpZPKyU9t7Zw45ZH9F/WyQ2qTT9UooWotA7I3Hp2hW1nV26rhuvV/h1zzX+2HedKjFWtIqyRpNixZRaraJhzzJUbV4ElcpELSMhhBBCvFJ2vf62yMoggGLFinH27Fn69euHWq0mIiICRVFQFIWQkBCio6MNq4V69+7NiRMnJBEkhBBCiNzJySNhhVCTSRhtGwsLgvlt4d+5Ob5trGXxVkz4fDMLP6hkctsYgCr4AQ+nfsmlpj48+u034p48eeV1k7afv+OiYrljDGGq5CuM9HqFw6sC2T3/EroY405kQgghhMg5FlsZlNSdO3fYunUrJ0+e5PHjx8THx+Pu7k6NGjXo3Lkz5cqVs3RIIg3e3t7J/qzT6QgMDARkZZAQQgjB9b2wbgREhhqPVewCXWeCrYvl40pCp9cx599febJ4EW1PxuP2MvW5emsrnDp1wmPYO9iULfvKaweFRjLK9xS374XTKVJLiTgrozluhR1oP6oKrh72WXkMIYQQ4o2TXSuDciQZJPIWSQYJIYQQrxB+H9a8A0FHjcfylYBei6BwdUtHZSQ4PJhlFxbzcMMa2h6LpnhI2vM1DetSePgo7Bs0SHOrV7QunsnrL7Du1F0aRWtoEGNtNEdra0XLIZUoVb1AVh9DCCGEeGNIMkjkGlIzSAghhDAhPg72fwOHfzEes9JC22lQZzjkgvo54bHhrL26hpNb59PI7yk1bqb966BSpjiFh4/GpUMHVFqt6TmKwrITQXy56RJFo6FjhBablFvogJpti1OvS0nUVharViCEEELkWZIMErmGJIOEEEKINFzbBetHmu4q5t0DOv8Gts6Wj8sEnV7H7tu72bbvTyruCsTnooJ1GuV94t2c8Rg0FLe+fbFydTU551zwc8b4niLyaQzdIrQU0BsnfYpUyEebd7yxczKdWBJCCCFEAkkGiVxDkkFCCCHEKzwPhjXD4O4J4zG30tB7EXhWsXxcqVAUhdOPT7P66N84bDlEm9N6nKNSnx9vY41Lj+54DH0HbbFiRuNPI2J5f8UZ/r32hDaR1lTSaYzmOOazod3IKniUzB2JMSGEECI3kmSQyDUkGSSEEEKkQ7wO9kyFYzONxzS20H461BycK7aNJRUUHsSKs4sI3bCWNv/GUPhp6nMVFWib+1B4+GjsatRIVlfI0H5+73VqxFrRPMoaq5Tt5zUqfHqXw9unsLSfF0IIIUyQZJDINSQZJIQQQmTAla2wYQxEhxmPVe0DHWeAjaPl43qFsJgw1l5dw/kN8/E59JRKwWnPVyqVpcjId3Fq1QqV5v9XAu259IgJq87i9FJPlwgtTopx0qdCA0+a9iuPRmvciUwIIYR4k0kySOQakgwSQgghMujZHVg9BO6fNh7LXw56L4aCFS0eVnro9Dp23d7F7h1z8N5zkwaXFazS+O0xzsOdQsOGk++tXlg5OgBwJzSC0b6nuXMvnM4RWorFGyd98hd1pN3IKrgUsMuuRxFCCCHyHEkGiVxDkkFCCCFEJsTFwu7P4Phc4zGNHXT8GWr0t3xc6aQoCqcenWLd4T9x3XSUluf02MekPj/O3ga3Pn0oOHgo1p6eRMXGM2VDQvv5JtEa6ppoP29jr6HV0EqUqJI/G59ECCGEyDskGSRyDUkGCSGEEFlwaSNsHAcx4cZj1ftDh59Aa2/5uDLgTvgdVpyaz8t1G2l9PIYCJh4lkd5KhW2blniNGINNxYqG9vMloqB9pBZtyvbzKqjToQR1OpZEpZY6QkIIId5skgwSuYYkg4QQQogsenoTVg2Gh+eNxwpWgl6LoEA5y8eVQWExYay+vJKraxfS5PAzyjxIe75Sw5uiI8dyo1Q13l12huj/tZ93N9F+vpi3O62HVcLWwXgFkRBCCPGmkGSQyDUkGSSEEEKYgS4adn4KJ+cZj1k7QOdfoWpvi4eVGbp4HTtubefAtrlU3XOb2tcUjNM7/y+uqAeuA0fwWUQxjtwOp32klvI64zpCzvltaTeyCgWKOWVf8EIIIUQu9long86dO8f169dRqVSUKlWK6tWr53RIIg2SDBJCCCHMKGANbH4fYl8aj9UaAu2+B+u8UVRZURROPjrJxv1zyb/5GM3OK9jEpT5f52zPjfrtmaquTFny0TRagzrFtjErazVN+5WnYsNC2Ry9EEIIkfvkqWTQtWvXAHB1daVgwYKpztu3bx/vvvsugYGByY4XL16cX375ha5du5o7NGEGkgwSQgghzOxJYEK3sUcXjMc8qkDvReBe2uJhZcWtsFus9p9P9JpNtPKPJV9E6nPjNGoOFq/O/uIdqad44GCi/by3T2F8epfDyjqtNUdCCCHE6yXPJIPOnz9P9erVUalULFiwgEGDBpmct3PnTjp37kx8fDymQlCr1SxevJi3337bnOEJM5BkkBBCCJENdFGwfRKcXmw8pnWCLr9D5R6WjyuLnkc/Z83F5dxYs5Bmh8MpFpL2/ONFq3Cv5EBc1Q5GYwWLO9FuVBWc3GyzKVohhBAid8kzyaCffvqJSZMm4erqyqNHj7C2Ni76FxkZSdmyZXnwIKHKoJubGx07dqRw4cKcPHmSvXv3Agkri65fv46bm5s5QxRZJMkgIYQQIhudWwlb/gO6SOOxOsOh7TTQ2Fg8rKyKjY9lx63tHNkwh+p7g6h+K/VfQfUqK05V6seLAg2MxmwdrWnzjjdFK8rvh0IIIV5/2fX6W2OWqyRx4sQJVCoVHTt2NJkIAli2bBkPHjxApVLh7e3Nrl278PT0NIwvXLiQYcOGERYWxtKlSxk/fry5wxQZ4O3tnezPOp0uhyIRQggh3gDV+kCharB6MIRcST7m/w/c9U/oNuZWMmfiyyStlZYuZbrS+cMunOh/gmV7ZlNosz+NLypYxyefq1biqXPRl4cFr3K5/NsoVlrDWPRLHZt/P0u9rqWo2bY4KpW0nxdCCCEyyuybri9fvgxAkyZNUp2zevVqw+e///57skQQwJAhQ2jfvj2KorBr1y5zhyiEEEIIkbsVrAAj9kE1E9vlH5yDP5vC5c2Wj8sMVCoV9QrV4+uBi2j39xa2/9yNjY2seWFi55fnY3/qnP4Ru8jHyY4rCvy74Sbb5wYQE5VGhWohhBBCmGT2bWL58uUjPDycgwcP0rhxY6NxvV6Pi4sLkZGRFClShDt37pi8zqJFixg6dCjFixfn1q1b5gxRZJFsExNCCCEs6IwvbP0Q4qKNx+qNgdZfgUZrPJaHPIt+xtrzywhetZimR8Ip9Cz5uE5jx+UKg3iSv6rRuS4F7Wg/qgruXo4WilYIIYSwnOx6/W32lUEvXya0RXV2djY5fvHiRSIiEtpJNG3aNNXrVKhQAYDQ0FAzRyiEEEIIkYfUGJCwSsi9rPHY8TmwoB08D7J8XGaUzzYfw+uOZfL0Q7xYOA3fwUW5VPT/x63joqhy4S9K3dwEij7ZuWGPo1gz/STX/B9aOGohhBAi7zJ7MsjWNmGN74sXL0yOHz9+3PB5rVq1Xnmd6GgT74IJIYQQQrxJPLxh5AGo0st47N4pmOsDV7dbPCxz01pp6Vq+O998vBOvxQtY+d9aHKmoIl4FKhRKBO2k2vlZWOteJjsvLlbP7nmX8Ft6kfg4fSpXF0IIIUQisyeDChUqBMDZs2dNjh86dMjwef369VO9zrNnCeuDHR1lya8QQgghBDaO0ONv6PQrWKXoJhb9HJb3hV1TID7vN3pQqVTUL1Sfqe/40mLBFrb/0Jlt9TREasH92RVqn5yOU7hxqYGAQ49Y+eEmngfezYGohRBCiLzD7Mmg2rVroygKCxYsMBqLiIhg8+aEYodOTk7Url071etcvXoVQOrRCCGEEEIkUqmg9lAYvgfcShmPH/0DFnaEsNcnGVLKtRT/7fwDQ/88yOk/R7OmjSMRNk+peXYGhe8fNpr/LMaZVd+d5Oj4KUT/r7GJEEIIIZIzezKoX79+AJw5c4YRI0YQHh4OwPPnzxkyZAjPnz9HpVLRs2dPrKysUr2On58fYNzWXAghhBDijVeoKow8CJW6GY8FH0/YNha42+JhZSc3WzdGNHifSb8c5smir1nY3Q1N+HIqXPFFrU++GkqndeZsbFMOjv+Ds/26EX7gAIpeto8JIYQQiczeTQzAx8eHI0eOoFKp0Gg05M+fn0ePHqEoCoqioNVqCQgIoGxZE4UQgcjISDw9PYmIiOC3335j3Lhx5g5RZIF0ExNCCCFyCUUB/39g56cQH2s83vgDaD4ZrDSWjy2bKYqC79ldbF33G20vxhPvPJwYW3ejeQUfn6LC1aXEFXbG653R5O/+FmobGxNXFEIIIXKfPNNNDGDt2rVUrlwZRVHQ6XQ8ePAAvV6Poiio1Wpmz56daiIIEtrKJ3Yla9u2bXaEKIQQQgiR96lUUHcEDNsJrsWNxw/PgMVdIPyB5WPLZiqVioE12jLro/Vs7TCK1cW2otcZbwt7XLAWJ2v+l/hQNaFTvybApwF3fv2RuKdPcyBqIYQQInfIlpVBADqdjr/++otNmzYRFBSEVqulZs2avPvuu9SpUyfNc3v06GHIeK1fvz47whNZICuDhBBCiFwo6jlsHAtXthiP2eeHt/6B0s0tHpYlxOsVZuy+ypyD52ihC6HGi0pGc6zioql4ZQkFn5wFIM5ajXWH1pQc9R42pUzUXxJCCCFygex6/Z1tySDx+pJkkBBCCJFLKQr8Owd2fwb6uBSDKmg6CZp+BOrU6zbmZbsvPeKDVWcp8EJHx0hrbBVroznFgnZT6tYm1Mr/1xCKrVeF0mM+wKFePVQqlSVDFkIIIdKUp7aJCSGEEEKIHKBSQYN3E7aNuRRNMajAwemwpBu8eJQT0WW71pU82DyuMdqizixxjOex2rhodFCx1pytOo5Ya0fDMe3xAIKHDOVMx5Y82bgORaczOk8IIYR4nUgySAghhBDidVOkNozyg3LtjMdu+cGchnB1h+XjsoAS+R1Y/24jWtQuzFKnGC5ap1whBc/zlce/9seEOZdIdtzu5gNCPprM2aYNCJrzG/H/64orhBBCvG4kGSSEEEII8Tqyd4O+y6H1V6BKsS0s8gks7wNbPoDYyJyJLxvZaa34uXc1vuhemd1Ocey2iyWe5JURYmzycbr6BO4W9iFlzQTbpxFE/DaXiz6NuPr5R8TevWu54IUQQggLMHvNoK+++sqclwPg888/N/s1ReZJzSAhhBAijwn6F1YPhRf3jcfyl4Mef0Ph6hYPyxLOBD3j3aWnITSGrhE2OCnGNYFcnh6n+oXlWOlNbw/TqyDOpxZlx0zEvkb1bI5YCCGE+H95poC0Wq02e+G9+Ph4s15PZI0kg4QQQog8KCIUNo2Hq1uNx9TW0GIyNHzvtSwuHfoyhvdWnOHMtVA6RWopHmf8jPHaJ3he+4Mqt5+kea3IisUoOep98rVui8rq9ftaCSGEyF3yXDLIXJdVqVSSDMplJBkkhBBC5FGKAqcXwY5PQGdie1jxxtDjT3B5/f5tj9cr/LzrKnP238AnWkO9GONOYxpbFdElD+N8YCV1r8SjTuPX2SgPFwoMGYZX7/6oHRyyMXIhhBBvsjyXDLK1taVr164MGjSIihUrZumaxYsXN1N0whwkGSSEEELkcU+uw7rhcP+M8ZitC3T6BSq/Zfm4LGDXxYd8uOocni/0tI/UYoPxivZKbQpy3WET0ctX0+h0NHaxqV8vxkGL3VtdKPXOeKw9CmZj5EIIId5EeSYZ1Lp1a/bv349erzdsF6tVqxYDBw6kb9++FChQwJy3EzlAkkFCCCHEayBeBwe+g0MzwKiEMlC1L3T4EWydLR5adrv9JILRvqd4dO8l3SK05Ncb91QpVskNn0Gl2XtzPbd8/6beoSfkf5H6NeOtVMS1bED5MROxy+IboUIIIUSiPJMMArh//z5Lly7F19eXgICAhBupVGg0Gtq2bcuAAQPo2rUrNjY25r61yAbe3t7J/qzT6QgMDAQkGSSEEELkebePwPpREBZsPOZaHHr8BcXqWz6ubBYVG8+n6wPYcvoebSOtqajTGM1xcrOl3ajK5C/myOE7Bzmx7Fcq7rpGqYdpXzuiWmlKj/kA16bNzV5LUwghxJslTyWDkjp//jyLFi1i+fLlPHyY8C+nSqXC2dmZXr16MWDAAJo0aZKdIYgskmSQEEII8ZqLeg7bJkLAauMxlRp8JkLTSWBlXGcnL1MUBd9/7/DV5ktUjVTTLNoadYptY1YaNU36laNSo8IAXAm9wq71M3DbeIRagfo0rx9RxI1C74ykUI++qOVNUCGEEJmQZ5NBifR6PXv27GHx4sVs2LCByMhIwzslxYoVY+DAgQwYMIBy5cpZIhyRBbJNTAghhHhNnV8FWz+EmHDjMa/aCauE3EtbPq5sdjroGWOXnkYdGkuXCC2OJtrPV2pUCJ++5dBYJ3QQC4kMYeO+2cQtX0+DczFo41K/frSzDQ59elJq2Fg0+fJl12MIIYR4DeX5ZFBSERERrFu3jkWLFnHgwIFk9YUaNmzIoUOHLB2SyABJBgkhhBCvsWd3EraNBR0zHrN2gPbTocYAeM22P4W+jGH88jOcCwylS4SWIvHGbeMLFHOi3ajKOLvbGY5FxUWx7fQK7i7+h7pHn+JqoklbojhrNfHtmlBhzH+xLVUqOx5DCCHEa+a1SgYldf/+febPn8+0adOIjo7G1taWyMg0/hUVOU6SQUIIIcRrTh8Ph39JKDCtN7HkpWJn6Pw72LtZPrZsFBev5+fd1/hz/w2aRllTO9a4jpCNg4Y273hTrJJ7suN6Rc+hm/s4veRXvPfcoOiTtO/1sm5Fyo35EOf6DaWukBBCiFS9lsmgY8eOsWTJElatWsWzZ89QFEWSQXmAJIOEEEKIN8S9U7B2BDy9YTzmVAi6zYHSzS0fVzZLbD/v9UKhbaQ12pTt51VQr3MparUrjkptnMi5HHKRvat/xmPTcSrfTruu0IuSBSk64l08O/dAZf161WQSQgiRda9NMujGjRv4+vri6+vLzZs3AQxJoC5dujBo0CA6dOhgyZBEBkkySAghhHiDxEbAjk/g9CLT4w3GQcvPQfN6FUi+9SSCMb6neHLvJV0jtLiZaD9fomp+Wg2piI296STO48jHbNkxE2XlRuoGxKJJIy8Umc8e5/59KTV4NFZOTuZ6DCGEEHlcnk4GPXv2jBUrVrBkyRKOHz8OJCSAVCoVPj4+DBw4kF69euHs7JzdoQgzkGSQEEII8Qa6vAU2jYeop8ZjHpXhrX+gYEXLx5WNImPjmLz+AltP36NDpJayOuM6Qs4F7Gg/qgr5izimfh1dJNtPLOXh4vnUPf4cx+jU7xlrYwWdW1Jh9H+xkd+xhBDijZfnkkE6nY7NmzezZMkStm/fjk6nI/FW5cqVY+DAgQwcOJBixYplx+1FNpJkkBBCCPGGevEQNoyBG/uMxzS20PorqDvytSourSgKS/69w9ebL1Ej0gqfaI1R+3mNtZpmAypQvp5nmtfSK3oOBe7i/KLfqLz3Np7P05irgojGVan47kc41ahphicRQgiRF+WZZNDhw4fx9fVl9erVPH/+3JAAcnd3p2/fvgwaNIg6deqY85bCwiQZJIQQQrzB9Ho48Sfs/gLiY4zHy7SCrrPBycPysWWj00HPeNf3NNrQWDpHarE30X6+SlMvGvUqi5XGeEtZSpceB3Bw+U8U2uxP+btp/zoeVsGLEiPH4dG2Myor49VJQgghXl95JhmkVqtRqVQoioKNjQ1dunRh4MCBtGvXDo3GuCODyHskGSSEEEIIHl1MKC79+KLxmL07dJkJFV6vOpBPXsbw3vIzBAQ+pUuElsLxxkkfz1LOtB1RBcd86auh9CjiEds2/4pm1VZqXtahTuM385ceTuQbOJBSb7+D2t4+s48hhBAiD8lzySBbW1vatm2Lq6trlq6nUqmYN2+eeYITZiHJICGEEEIAoIuGvV/Cv7NNj9caCm2/Ba2DZePKRont5//af4MWUdZUN9F+3s7JmrbDK+NVPl+6rxupi2T70YU8WbSIuv7h2OpSnxttr0HdvT2VRk7E2qNgZh5DCCFEHpHnkkHmFB8fb9briayRZJAQQgghkrm+Fza8Cy8fGo+5l0koLl24huXjykaJ7eeLhSu0jrLGOkUdIZUaGnQrQ/XWRTP0u3G8Ph6/K9u5vOB3qh4Ixv1F6nPjrFRENquJ99hPcKzkndlHEUIIkYvlqWSQuen1afThFBYnySAhhBBCGIkIhc3vwZUtxmNqDTSfDI3eB/XrU/Mmsf3807sv6RqpxdVE+/nSNQrQYnBFtLYZL5dw8cFZDvv+SNEtZyj5KO1f2Z9VLU7p0RMo2LyN2d+YFUIIkXPyTDJIvP4kGSSEEEIIkxQFTi+GHR+DLtJ4vHgj6P4nuBa1fGzZJLH9/LZT9+gYqaV0nHGyK5+nPe1GVcGtUOa2yz14+YCd637GdvUuqgWmsX8MCPNyocDQYZTsPQS1Vpup+wkhhMg9JBkkcg1JBgkhhBAiTaE3YO1wuH/aeMzGBTrNgCo9LR9XNknafr5OpBUNozWoUmwbs7axosWgipSplfkaP5G6SLbv/5uwxb7UOvMSbRqVFCKdtWh6dsZ7+Ido3NJfu0gIIUTuIskgkWtIMkgIIYQQrxSvgwPfw+EZoJjY8l+1D3T4EWxdLB9bNklsP28XGkvHSC12JtrPV29VlAbdS6O2ynxphXh9PH4BmwmcP5Oqh+7hYmIRViKdtYqo1vWpPO4THEqVzfQ9hRBC5Izsev1t/gI/Znbq1KmcDkEIIYQQQmSUlTW0/AyGbAOXYsbj51fCnMZw55jlY8smNYvlY8t7jSlUIR9LHGN4ZGWcBDu7J5iNv54lIiwm0/exUlvRvFo3Rv62B8fNvhwZWJV77qbrBFnrFJy3HeN2xy4cGdCZx0cPIO8FCyGEyLXJoKNHj9K+fXvq1auX06EIIYQQQojMKt4AxhxOWAmUUlgQLOwA+75JWEn0GsjvaMPiYXV5u2UpljnGEKCNM5pzP/A5q6b58+D68yzfr7JXLYZPXkml7bs5ObEtl0uYLlStVsDt5HVCh43hWAcfbq5ejKJ7Pb7mQgghMi7XbRPbu3cv33zzDX5+foZj0lo+d5FtYkIIIYTIlIA1sOUDiAkzHvOqBT3+BvfSlo8rm+y8+JCJK89R8oVCyyhrNCnqCKnVKhr1KkOVZkXM1gEsQhfBzl1ziViynOrnI9Ck0ZT3hZsttn17UGnI+2icnc1yfyGEEOaV52oGKYrC+vXr2bNnD8HBwVhbW1OiRAl69uxJw4YNjeYfOHCATz/9lOPHjxvOB2jTpg07duzIjhBFJkkySAghhBCZ9jwI1o+GO0eMx6wdoP33UGMgvCbt0W89iWD0klOE3XtJ1wgtzorxwvyydTxoPqAC1jbGncgyK14fz8HT67g9fzZVjzzEIY1daTE2aqI7+lD13U+xL2JiS58QQogck6eSQXfu3KFr164EBASYHO/VqxdLly7FysqK0NBQhg8fzqZNm4CEJJBKpaJLly5MnjyZ2rVrmzs8kUWSDBJCCCFElujj4chvsP9b0Btvo6JCJ+jyB9i7WT62bBAZG8en6wLYefo+nSO1FDfRft6tsAPtR1XB1cPe7Pc/f+cEp+b/QMmdF/F4nvq8eDU8r1+BCuM+In/N+maPQwghRMblmWRQbGwstWrV4uLFi6nfVKXiww8/ZPz48TRt2pQ7d+6gKApWVlb07t2bTz/9FG9vb3OGJbIg5d+FTqcjMDAQkGSQEEIIIbLg3mlYNwJCrxuPOXpC9zlQuoXl48oGie3nv9l8iXoRVtSPsTaao7W1ouWQSpSqXiBbYrgXFsx+3+9xWXeQMvfSLsMQWt4DrxFjKNG+Jyor861YEkIIkTF5Jhm0YMEC3nnnHVQqFcWLF2fKlClUqVIFrVbL5cuX+fHHHzlz5gwODg5Ur16dI0cSlgi/9dZbTJs2jbJlpeVlbiPJICGEEEJkm9gI2DkZTi0wPV7/XWj5BVjbWjaubHLqzjPGLj2NY2gsHSK02GC8Ha5Wu+LU7VIKtTp7tsq9jH3Jrq1/EOu7hiqXIlGn8WogrIA99gP64D1wHFb25l+1JIQQIm15JhnUuXNntm7dStGiRbl48SKOjo7JxvV6PU2aNOHo0aMAWFlZMW/ePAYNGmTOMEQ2km1iQgghhDC7K1th03iIDDUeK+gNb/0DHpUsH1c2ePIyhvHLznA58CndIrQU0BvXESpSIR9thntj56jNtjji9HEcPL6Ke/PmUuV4CLZpNBeLtLdC16UF1cd8iq2HZ7bFJIQQIrnsev1t9tby586dQ6VS8d///tcoEQSgVqv56quvgITtYgMHDpREkBBCCCHEm65CRxhzDEq3NB57fBH+agb/zgF9Gu2x8oj8jjYseacufVuUYqlTDJesjesm3b3yjFXf+vPoVni2xaFRa2jZ4G0G/eOHasPfnOhenqfGv74DYB8Zj8uK3QS2aI7fu715evFstsUlhBAi+5k9GRQamvBuTuXKlVOdU7VqVcPnPXv2NHcIQgghhBAiL3LygP5roN10sLJJPhYfAzs+hqU94cXDnInPjDRWaj5uX4GZg2rhl09hj10s8SRfsP/yWQzrfj7FxUP3yKYGwAbVSzdm8HcbKL57O2dG+RDkYfplgiYeCuwL4NFb/TjYsyW3d2/I9tiEEEKYn9mTQVFRUQAULFgw1Tn58+c3fC5bjIQQQgghhIFaDfVHw8gD4GHizcUbe2F2g4RtZa+Btt6ebBzfiMhi9qx0jOWlKnliRR+ncGDpVfYvuUJcbNpFn82hSL4SvD3hLxru+pcrX/TjYrnUazUVvHCfqPGfcKRVPc4v+g19bGy2xyeEEMI8zJ4MyiiNRpPTIQghhBBCiNzGoxKM2AcNxhmPRT2FFW/DpvcSClDncaUKOLJ+bENq1/ZksVM0wVbGSZ/LRx+w9sdThD+JskhMTjZOdO/3Od02+HP/z485U8+d2FSairnfe4H1d3M52q4hYcE3LRKfEEKIrMnxZJAQQgghhBAmaWyg7bcwcAM4FTIeP70I/myS0KI+j7PXavi1T3UmdfNmnbMOfxvjas5Pgl+yapo/dy6YKLKdTTRqDS2bDubtRYdR1s7hVMcyhNuZnut+P4KLfbrz9NYVi8UnhBAic8zeTUytVqNSqRgzZkyaW8WmTp2arnkAn3/+uTlDFFkk3cSEEEIIYXGRT2Hze3B5s/GYWgPNPoHGE0CdyvKVPOTUnWe8u/QUrk/iaBdpjTZl+3kV1O1UktrtS6DKpvbzaQkOucGR+d9ScNNxCoUaF/R+7mpNOd/luJfxtnhsQgjxuskzreUTk0HmFB+f/fujRfpJMkgIIYQQOUJR4IwvbP8IdCa2hxVrCD3+BNdilo/NzEJexDB++WkCA5/RNUKLu4n288Uru9NqaCVsHaxzIEIIjw5j/8qfcP5rPZ6hyX9fD3fWUHLxYgpWqJEjsQkhxOsiz7SWB1AUxWwfQgghhBBCAKBSQc2BMPoQeNUyHg86CnMaw/nVlo/NzAo42eD7Tj16tiiJr1MM16yN3xy9cyGU1d/5ExL8IgciBGdbF7oO/pqiSxZxzyN5HVDn8DhuDxjIw4ATORKbEEKItJl9ZdDBgwfNeTkAmjZtavZrisyTlUFCCCGEyHHxOjj4Axz6CRTjrUpU6QUdfwZbF8vHZmY7Ljxk4qpzVAyHJtEa1Cm2jVlZq2nWvzwV6puoq2Qht4MCuDa4P0UfJK919NJBTeF//sKrRqMcikwIIfK2PLNNTLz+JBkkhBBCiFwj6F9YNwKeBxmPuRRL2DZWvKHl4zKzmyEvGeN7msi7EXSO1OKgGJdlqNzEi8a9ymJlnTM9YoLvXebi4L4Uv5u8xXyknZoCf86kWN3mORKXEELkZXlqm5gQQgghhBAWUaw+jD4MVfsaj4UFwcKOsPerhJVEeVhi+/madRLaz9+3Ml4NdcHvHut+Ps2Lp9E5ECEU9apIVd+13Cpuk+y4fZSeJyPGcvvwzhyJSwghhDFJBgkhhBBCiLzN1iVhBVDP+WCTYluYoodDP8O8NvDkes7EZyaJ7ef/282b1c6xnNbGGc15fDucVdP8Cb7yNAcihMKeZai1dAM3SiXvP28Xo/Ds3QncOLApR+ISQgiRnCSDhBBCCCHE66HyWzDmCBRvbDx2/zT86QOnFiZ0JcujVCoVgxuWYPnoBlzwVLPVPhYdyZ8n+qWOzb+d5fTOOznSkMUjfwnq+W4isKx9suO2sQovx3/EtV1rLR6TEEKI5CQZJIQQQgghXh+uRWHwJmg1FdQpWq7rImHz+7CiP0SE5kh45lKreD62jPfBuYILS51ieK5Ovm1MUeDY+hvs+PMCsVHGK4iyWwG3IjT23crVik7Jjmt1ED1hCle2LrN4TEIIIf6fJIOEEEIIIcTrRW0FjSfA8D3gXtZ4/OpWmNMAru+xfGxmlNh+vkeLkix2jOG6xrj9/M2zIaz+/iSh919aPD43F0+aLdnG5crJt+5Zx0PspK+5uGGhxWMSQgiRQJJBQgghhBDi9VS4Oozyg9rDjMdePgLft2D7x6DLmYLL5qCxUvNJ+4r8NqgWu930HLLVoaTYNvb8USRrvj9JoP8ji8fn6pifVou2c6l6vmTHreNB/+l0zq/60+IxCSGEkGSQEEIIIYR4nWntodMv0G8F2Oc3Hj8+B/5uDo8uWj42M2pX2ZNN4xvxtLgtaxxiiVIlTwjFxerZNe8ih1cFEh9v3IksOzk75KPtgu1cqJ3866/Rg/qLXznj+7tF4xFCCCHJICGEEEII8SYo3x7GHIUyrY3HHl+Cv5rBsdmgt2yixJxKFXBkw9hGVK3tyWLHGB6aaD9/bl8wG385Q0RYjEVjc7RzoeO87QTU90h23EoB7TdzOLXgJ4vGI4QQbzpJBgkhhBBCiDeDkwf0Xw3tfwSNbfKx+FjY+Qn49oDwBzkTnxnYazX81rc6H3SryCrnWM6baD//4HoYq771537gc8vGZuNIl7+3c97HK9lxNWA/fR7+f06zaDxCCPEmk2SQEEIIIYR4c6hUUG8kjDwAHlWMx2/uhzkN4fJmi4dmLiqViiGNSrJsdH3OeqrZYRdLXIo6QpHhsWz45Qzn9gZbtP28rbUd3eZs5VyL4kZjjr8s4fjvn1ssFiGEeJNJMkgIIYQQQrx5ClaEEXuh4XjjsainsHIAbBoPMZbvwmUutYq7sWW8Dw4VXVjmGEOYKkX7eb3C4dWB7Jp3kdhoy7Wft9HY0OOPTZxrV9pozHn2ao799LHFYhFCiDeVJIOEEEIIIcSbSWMDbb6BQRvBqbDx+OnF8KcP3Dtl+djMJLH9fLcWJVniFMMtE+3nr598zJrpp3j2MMJicWmttLz183rOdq5gNOb6z0aOfDvBYrEIIcSbSJJBQgghhBDizVaqGYw5ApW6Go89vQnz2oDfj6A3TqTkBRorNZ90qMgvg2qy003PURud0ZxnDyJY/f1Jbp4JsVhc1lbW9J6+hnM9KhuNuS3ZwaEv3rXoFjYhhHiTSDJICCGEEEIIezfotQi6zgatY/IxfRzs+wYWdoRnd3ImPjNoV7kQG8Y3IqSELescYohOUUdIFx3P9j8DOLb+OnoLtZ+3UlvR+9uVnO9b02gs/8r9HJo8QhJCQgiRDSQZJIQQQgghBCQUl67RH0YfgiJ1jMeDjsHcxnB+leVjM5PSBRxZ/24jvGt7ssQphhC1cdLn9M4gNv1+jsjwWIvEpFap6f2FL+cH1TMaK7DuCH7/HSIJISGEMDOVIv9nFa/g7e2d7M86nY7AwEAAgoODKVKkSE6EJYQQQgiRfeLjEraG+f0AiolVMpV7Qsefwc7V4qGZg6IoLDx6m+lbLtPipQZvncZojkM+G9qNrIxnSReLxbTmp1FUmnfI6B3rR22q0/TXpajU8l62EOLNcvfuXYoWLQqY9/W3/N9UCCGEEEKIlKw00PwTGLoDXI3boHNhTcIqodtHLB+bGahUKoY2KonvqPqc8lSzxy6W+BTbxiKexbD+p9Nc8LtnkZU5KpWKnhP/5Mq7rdCrko957DrLgXd7ocTnzbpNQgiR28jKIJFh2ZWZFEIIIYTIlaLDYftHcG6ZiUEVNP4PNPsUNFpLR2YWIS9iGLfsNMGBz+kSocVJURnNKV/fk6Zvl8daa2WRmNbP+ZCyv2/DKsUrlQeNytJs7lrU1tYWiUMIIXKarAwSQgghhBAiJ9g6Q/c50HMB2KbcMqXA4V9gXmt4Epgj4WVVAScblg6vR+cWJVjsFE2QifbzV/99yNofThEWEmmRmLqP+ZmbE7sTl+LVSqEjgRx4pyv6WMvUMxJCiNeVJIOEEEIIIYRIj8o9YMxRKOFjPPbgLPzZBE7Ohzy48D6x/fzPA2uyzU3PCRPt50PvvmTVtJPcDnhikZi6vDONoE/6okuxGKnQiVscGNKJ+JgYi8QhhBCvI0kGCSGEEEIIkV4uRWDQJmj9FahTbFXSRcKWCbDibYiwTMLE3NpXSWg/f7+4LRvtY4hNUUcoNiqOrbPOc3zzTfT67E96dRz4Bfc+G0RsivrWhU4Hc3BAe+KiLLNSSQghXjeSDBJCCCGEECIj1Gpo9D6M2Av5yxmPX90GcxpC4B7Lx2YGpQs4smFsI8rXSWg//8RE+/mTW2+zddY5oiOMVxCZW/u+n/D4qxFEp8i9FQp4gN/b7dFFvMz2GIQQ4nUjySAhhBBCCCEyo1A1GHkQ6gw3Hnv5CJa+lVB4Whdl+diyyMFGw+99qzO+W0VWuMRyxTrOaE7QxaesmuZPSNCLbI+ndY8PeDZtHJEpanQXuvyYw33bEfsiLNtjEEKI14kkg4QQQgghhMgsrT10/BneXgUOBYzHj8+Fv5rDwwuWjy2LEtvPLxlVnxOeavbb6tCn2Db2IjSatT+c5PLR+9keT4vOY3n544dE2CQ/7hkYytE+7Yh+/jTbYxBCiNeFJIOEEEIIIYTIqnJtE4pLl21jPBZyGf5uDsdmgd54y1VuV7uEG1ve88GqojMrHWOJUCVPCMXHKexbfIV9Sy6jizHuRGZOTdsOJ2bGJ7ywS37c4+ZzjvduT1RoSLbeXwghXheSDBJCCCGEEMIcHAsmrBDq8BNobJOPxcfCzk/BtzuEP8iZ+LIgsf18x/+1n79nZZz0uXzkASu/OcHDm9m7ZatRy0Eov00l3D758YJB4Zzo04GIkLz39RVCCEuTZJAQQgghhBDmolJB3REwyg88qxiP3zwAcxrApU0WDy2rEtvP/ziwJlvc9ZzSGtcRCguJYt2Ppzi+6Sbx8dm3Cqpekz5YzZrGc0dVsuMF777kdK9OvHgQnG33FkKI14Ekg4QQQgghhDC3AuVh+F5o+B6QPGFB1DNYNRA2joWYvNcJq32VQqwf34jgEjZssY81aj+vKHBy223WTj/Fs4cR2RZH7QbdsZ37I0+dkn998z+M5FyfLoTdvZVt9xZCiLxOkkFCCCGEEEJkB40NtPkaBm8Cp8LG42d8YW5juHvS8rFlUWL7+TJ1PFjkFMN9K+NVQCFBL1j5rT/n999F0SsmrpJ1NWp3xOmf33jikvxljfvjaC706c7TO9ey5b5CCJHXSTJICCGEEEKI7FSyCYw5ApW6GY89uwXz2sDBHyDeeNtVbpbYfn5Sz8qszxfHYVsd8SlWCcXr9BxaeY3Nf5zl5bOYbImjarXWuM+fxeN8yV/auIXGcKVvT0JvXM6W+wohRF4mySAhhBBCCCGym70b9FoI3eaA1jH5mBIP+7+FhR3h2e2ciC7TVCoVb9crxrb/+BBdzpGljjGEqo1XCQVffsaKr48TePJRtsRRybsZHgv/5mF+q2TH8z3TEfh2bx5fPZct9xVCiLxKkkFCCCGEEEJYgkoF1d+G0YegSF3j8eB/YU5jOLciofBOHlIivwOrRzVgYIeyLHOJNVlcOiYyjl3/XGTXvItER+jMHkOF8g0ptmghDwpokh13CYvj1oD+PLzgb/Z7CiFEXiXJICGEEEIIISzJrRQM3Q7NPgFV8pUsxL6A9aNgzbCEQtN5iMZKzbgWZVk9thFBJWxY5RDDC5VxUivQ/xErvzlB8JWnZo+hdOnalPL15a6ndbLjzi/iCR48lHtnj5r9nkIIkRdJMkgIIYQQQghLs9JAs49h2A7IV8J4/OK6hFVCtw5ZPLSsqlLEhS3jG9OiWTEWOkVz2dp4ldDLZzFs+vUsh1ZdIy423qz3L1G8GhV8VxDkpU123DEingdDRxB88oBZ7yeEEHmRJIOEEEIIIYTIKUXrwujDUH2A8Vj4XVjUGXZ/AXGxlo8tC2ytrfiiszf/jKjHqcJWbLaPJdrEKqHz++6yapo/IUEvzHr/okUqUWXpGm4Xs0l23CFKz+Ph73L76C6z3k8IIfIaSQYJIYQQQgiRk2ycoNss6LUIbF1TDCpw5FeY1wpC8l6b9MZl87Pj/SaUr+vBQqcY7miMVwE9exjJ6u9PcnL7bfTxxsWnM6uwZ1lqLt3AzZJ2yY7bRys8HfMfbhzcYrZ7CSFEXiPJICGEEEIIIXID724w5mhCK/qUHpyDP5uA/7w8V1zaxd6a3/rWYFr/6uxw17PXLhZdihb0il7h+MabrP/5NGEhkWa7t0eBEtRdupHrZeyTHbeLUQgfP4lre9aZ7V5CCJGXSDJICCGEEEKI3MLFCwZuhDbfgDp5EWTiomDrB7C8H7wMyZn4sqBztcLs+qApDpXzscQphkdWxquAHt4MZ8XXJ7h46B6KmZJeBdyK0sh3K4HlHZMdt41ViPrPZK5sX2GW+wghRF4iySAhhBBCCCFyE7UaGo6HEfsgf3nj8WvbYU5DCNxt+diyyNPFlkVD6/Je90qsdtVxzEaHPsUqobhYPQeWXmXb7PNEhpunVpKbqydNfLdx1ds52XFtHMRO/JKLGxeZ5T5CCJFXSDJICCGEEEKI3KhQVRh1EOqONB6LeAxLe8K2/4IuyvKxZYFarWJwwxJsft+H8LIOLHeM5bnaeJXQ7YBQln91nJtnzbMKytWpAC0Wb+dytXzJjlvHg/6T7zm/5m+z3EcIIfICSQYJIYQQQgiRW1nbQYcf4e3V4FDAePzEX/BXM3hw3uKhZVWZgo6sHdOQnm1Ls8QllnNa4xb00S91bJ8bwL7Fl4mNMh7PKGcHN9ou3MGlWvmTHdfoQf3ZDM4sm5nlewghRF4gySAhhBBCCCFyu3JtYMwxKNfOeCzkCvzTEo7+AXrzdeOyBGsrNR+0LsfyMQ0ILKZlrUMMESZa0F8++oAV35zgfuDzLN/Twc6Z9vO3c7GeR7LjVgpov57FqUUzsnwPIYTI7SQZJEQeNGTIEFQqFSVKlMjpUPKU0NBQJk6cSMWKFbGzs0OlUqFSqfj1119zOjQhhBDi1RwLQL8V0PFn0CRvl058LOyaAku6Qfj9HAkvK2oUy8fW9xrT0KcIC5yiCbQ2bkH/IjSa9TNOc2z9deJ1WUt62ds40umf7QQ0KpzsuFoB2+/+5sTf32Xp+kIIkdtJMkiIVISFhTFr1iw6dOhAiRIlsLe3x8XFhXLlytG/f39WrlxJfLzxLyoidwoLC6NBgwb8/PPPXLlyhejo6JwOSQghhMg4lQrqDIdRfuBZ1Xj81kGY3QAubrB4aFllr9XwTbcqzH6nDkc8VGy3iyU2RXFpFDi9M4jV008Seu9llu5na21Htz+3E9CsWLLjasDp58X8O3Nqlq4vhBC5mSSDhDDh77//pnTp0owbN47t27dz584doqKiCA8PJzAwkGXLltG3b1+qVq3K4cOHczpckQ6zZs0iMDAQgEmTJnHo0CECAgIICAhg4MCBORydEEIIkUEFysHwvdDoP4Aq+Vj0c1g9GDaMhZgXORBc1jQvX5BdE5pQtFYBFjrFcNfK+M230LsvWfWdP2f3BKHoM9+CXqvR0m3WZs63KWU05jJzJUdnfJLpawshRG4mySAhUpg4cSIjR44kNDQUjUbDgAEDWLVqFcePH+fQoUP8888/tGjRAoBLly7RqlUr1qxZk8NRi1fZs2cPALVr12b69Ok0btyYypUrU7lyZdzd3XM4OiGEECITNFpo/SUM3gzOXsbjZ31hrg8E+1s+tizK56Bldv+afNGvKlvd9Ry01RGfYpWQPk7hyJrrbPz1DC+eZn7Fr9ZKS49fNnC+Y3njOP7awOHvP8z0tYUQIreSZJAQScyePZuff/4ZgCJFiuDv78+SJUvo1asXdevWpXHjxrzzzjvs3buXpUuXotVqiYmJYcCAAZw9ezZngxdpunfvHgDlypXL4UiEEEIIMyvpA2OOgHcP47Fnt2B+WzgwHeKz3o3LklQqFT1qFmH7B01QV3LG1zGGEBMt6O9de86Kr45z9d8HKErmVglZW1nT88e1nO9e2WjMfeE2/L4al+lrCyFEbiTJICH+586dO3z4YcI7Pw4ODuzdu5fq1aunOv/tt99m/vz5AMTExDBw4ED5JSEXi4mJAcDa2jqHIxFCCCGygV0+6Dkfuv8JWqfkY0o8HJgGCzvA01s5E18WeLnasWx4fUZ2rcDKfDr8bXQoKVYJxUbHs2fhZXb+fYHol7pM3cdKbUWvaSsJ6F3DaKzAsr34fTZKftcTQrw2JBkkxP/8+uuvhqLCn3/+ebpWkPTv35927RJavF64cIEtW7YYzWnWrBkqlYpmzZoBEBgYyLhx4yhbtiz29vaoVCpu376d7JzLly8zZMgQihYtiq2tLUWLFuXtt9/G3z9jy7wfPnzI5MmTqV27Nm5ubtjY2FC0aFF69+5t2DZlyu3btw2dthYuXAjAunXr6NChA4ULF0aj0RieJ9G1a9cYP348lStXxsnJCa1WS+HChalevTrDhg1j5cqVhoRMZm3evJmePXtSpEgRbGxscHd3p0GDBnz//fe8fGlcRPLAgQOG57hz5w4AixYtMhxL+veSUXq9nuXLl/PWW29RrFgx7OzssLOzMxQYX7NmDTqd6V9GY2NjmT17Ns2bN6dAgQJotVo8PT3p0KEDvr6+6NNoC5yyk9zz58/5/PPP8fb2xsHBAVdXV5o0acLSpUtNnv/VV18Znj2xhlJa2rZti0qlolChQqkWTN+wYQO9evWiWLFi2Nra4urqSu3atfnyyy959uxZup/lwYMHfPTRR3h7e+Pk5IRKpeLAgQPJzgkKCmLMmDGULFkSW1tbChcuTLdu3di/fz8AU6dONTxfWsLCwvjuu+9o1KiR4e+gUKFCdO7cmTVr1qT5y37i9adOnQqAv78//fr1M3xfenl5MXDgQC5fvpxmDIkuXLjA+PHjqVKlCvny5cPa2hpPT09atWrFDz/8wIMHD1I9N7M/40KI15RKBdX6wpjDULSe8Xjw8YRtY2eXQx5LaqjVKob7lGLD+MY8KmXPSodYwlTG/17eOB3C8q+Pc+diaObuo1LT68ulBAyoazRWcM0hDn40VBJCQojXgyJEBgUHByuAAijBwcE5HY5Z6PV6xc3NTQEUOzs75fnz5+k+d8eOHYavR/fu3Y3GmzZtqgBK06ZNlQ0bNigODg6G+Ykft27dMsxfuXKlYmNjYzQHUDQajfLPP/8ogwcPVgClePHiqcbl6+tr8l5JP9555x1Fp9MZnXvr1i3DnPnz5ysDBw40Ordp06aG+atWrVK0Wm2a9wKUgICAdH9dk4qKilK6d++e5rULFy6snDlzJtl5+/fvf2VMSZ8jvW7duqVUr179ldfev3+/yXMrVKiQ5nmNGzdWQkNDTd476d/9lStXlBIlSqR6nbFjxxqdHxgYaBifOnVqms/58OFDxcrKSgGU//znP0bjT58+VVq0aJHmsxQsWFA5duzYK5/l2LFjSv78+dP8Gu7du1dxdHQ0eR+VSqV8++23yhdffGE4lpo9e/Yo7u7uacbdoUMH5cWLFybPT5zzxRdfKLNmzVI0Go3Ja9jb2ysHDx5MNY64uDhlwoQJikqlSjOWwYMHmzw/Kz/jiqIoCxYsSPYsQojXTJxOUQ5MV5Sp+RTlC2fjj1WDFSXyaU5HmSkxunhl+vbLSvmPtiijx+1UZo7aa/LjwLIrSmx0XKbuodfrldXTRyiXylcw+tj3fj9FHx9v5qcSQgjTsuv1twYhBBcvXuTp06cA+Pj44OLiku5zW7VqhZ2dHVFRUWl2FgsKCmLAgAHY29vz2Wef4ePjg5WVFf7+/jg6OgIJKwz69+9PXFwcNjY2TJgwgQ4dOmBjY8Px48eZNm0aY8aMoVKlSmnGtGrVKsO2tVKlSjFu3DgqVapEgQIFuH37NvPmzWPbtm3MmzcPZ2dnZsyYkeq1fv31V86fP4+Pjw9jxoyhXLlyPH/+3LCa6dGjRwwdOpTY2FgKFizIuHHjqF+/Pvnz5ycqKorr169z8OBBNmzYkO6vaUqDBw9m/fr1AFSrVo0PP/yQihUr8vTpU1asWMHChQu5f/8+LVu25Pz583h5JRTRrFOnDgEBAUDCCpf79+/TtWtXvvnmG8O1HRwcMhTLo0ePaNSoEffv3wegRYsWDB48mAoVKqBSqbh16xb79u1j9erVRue+fPmSli1bcvPmTQC6devGsGHDKFy4MLdu3WLmzJkcPHiQw4cP07lzZ/z8/LCysjIZR2RkJJ07dyY0NJQpU6bQqlUrHB0dOXPmDF9++SV3795l1qxZdO7cmbZt2xrOK1OmDPXq1eP48eMsW7aML774ItVnXblypWE1UP/+/ZONxcTE0KpVK06fPo2VlRVvv/02HTp0oGTJkuh0Ovz8/JgxYwaPHz+mQ4cOnDlzhuLFi5u8z8uXL3nrrbeIjo5m8uTJtG7dGnt7ewICAihUqBAAN2/epEuXLkRERKDRaBgzZgzdunXD2dmZCxcu8OOPPzJ58mTq1TPxTngSR44coX379uh0Ojw8PBg/fjzVqlWjcOHC3L9/n5UrV+Lr68u2bdsYPHgwa9euTfVaO3fu5MSJE1SpUoX333+fKlWqEBUVxfr16/ntt9+IjIxk4MCBBAYGotVqjc4fOXKkYatpoUKFGDduHA0bNsTFxYWQkBBOnDiRanF6c/6MCyFeU1YaaDoJSreAtcMTagcldXE9BJ+A7nOhZJOciTGTtBo1k9pVoEWFgnyw6hw3HkbTJkqLnZJ8VeiFg/cIvvyUVkMr4Vky/b/bQcIq0Lf++yfrteMpP3dvsu0UnjvOsD+2N81nrkKllo0WQog8ymxpJfHaqlSpUrKPsmXLZktmMif5+voanunjjz/O8Pn169c3nH/v3r1kY4krg/jf6pU7d+6kep3atWsrgGJtbW1yRcHdu3eVIkWKGK5namVQSEiI4uLiogDKsGHDUl0V8OmnnyqAolarlStXriQbS7oyCFAGDRqk6PV6k9eZN29eulb+REZGKpGRkamOp2bLli2G67ds2VKJiYkxmvPXX38Z5vTu3dvkdYoXL57mKov0SrpCafr06anOe/HihfL0afJ3XCdOnGg4d8qUKUbn6PV6pX///oY5s2fPNpqTuJoGUFxcXJQLFy4YzQkMDFRsbW0VQOnSpYvR+O+//264hr+/f6rPUK9ePQVQypUrZzSW+P3j6uqqnDx50uT5t2/fVgoVKqQAyttvv53mszg6Oipnz55NNZZu3boZ5q5fv95oPCIiQqlbt26y79uUYmNjDSup2rVrp0RERJi8V9Lvp127dhmNJ71Hhw4dTH5PfvPNN4Y569atMxrfuHGjYbxBgwbKs2fPUn32oKCgZH82x8+4osjKICHeKNHhirLhXdMrhL5wUZRdnymKzvj/ZXnBi2id8tGac0ql/25RPhi7y+QKoVlj9irHN91Q4uIyt5pnw6wPlAsVjFcI7RneRYlP5f/BQghhLtm1MkhS2UIAT548MXzu6emZ4fM9PDwMn4eGpr5H/fvvv6dYsWImx/z9/Tl58iQAo0aNokkT43fpvLy8DN3OUjNnzhzCwsLw8vJi9uzZaDSmFwB++eWXeHl5odfrWbx4carXc3V1ZebMmanWYHn48CEA+fLlo3Jl4w4ciRJr6mTUrFmzgITCzwsWLDC5wmLEiBG0atUKSKhtlFaNlay4evWqYYVTt27dmDRpUqpzHR0dyZcvn+HPMTEx/PPPPwB4e3sb6s0kpVKpmD17tqHV/cyZM9OM5+uvv8bb29voeJkyZejWrRuAydVqffr0Maw4Sq220I0bNzh+/DhgvCro5cuXhr+Xr7/+mlq1apm8RvHixfnss88AWL16NREREak+y6RJk6hWrZrJsfv377N582YAevbsaXi2pOzt7fnrr79SvT7AihUruH37Nra2tixevBh7e3uT80aMGEHdugm1IhJrZplia2ub6vfke++9Zzh+6NAho/Hvv//eEPeaNWtwdXVN9T5FixZN9mdz/4wLId4ANk7QdRb0Xgy2rikGFTjyG/zTEkKu5kR0WeJoo+H7t6ry65DaHCgIu+1i0aEkm6PowX/rbdb9cIpnD1P/tyg1Xd/9mZsfdCMuxSunwoeusX94V/SxsVl5BCGEyBGSDBKvdPHixWQf+/bty+mQzO7FixeGzxO3bGVE0nPCw8NNztFqtfTq1SvVayQt9jp06NBU53Xv3j3NF46bNm0CoFOnTtjY2KQ6T6PR0KBBAwCOHTuW6rzOnTvj5OSU6njiNp5nz56xcePGVOdlRlxcHAcPHgSgTZs2Ri+KkxoxYoThnJRFh81l69athqKREyZMyNC5p06d4vnz50BC4eTUtn85OzvTu3dvAC5dupRqYkulUvH222+ner/EBM3Tp08N901UsGBBWrduDSRsBTNVsHrZsmWGz1Pe5+DBg4SFhQEJyZm0JCY1dTodp06dSnVeyoRTUvv37zdsVxs4cGCq86pVq5ZqQgn+/2ejadOmFChQIF1xp/Wz0bp1awoWLGhyzMnJibJlywIYtgUmCg0N5d9//wUSEnOFCxdOM5aUzPUzPmTIEBRFQVEUk8lJIcRrqFJXePcYlGxqPPbwPPzZFE78TV4rLg3QupIHOz9oQv7q+VnkFMN9K+N/2x7fecHKb/0JOHA3w0WgO4/4jqCP+hgnhP69yf6hndFnsUmGEEJYmiSDhIBkyQ5TXaleJek5zs7OJueULVsWW1vbVK+RWNtGq9Wm+YLW2tqaGjWMW54CxMfHc/bsWQD+/PPPZF2zTH0k1iNJXN1jStWqVVMdA+jSpYshOdW9e3datGjBL7/8wqlTp1LtPpVeN2/eJDIyEuCVtWCSjl+4cCFL903NmTNngIS/g/r162fo3KQxmeNZ8ufPb1hBZIqbm5vh86TJzkSJyZcHDx6YTPAmJoPq1atHmTJlko0lrmCDhGRgWt9jSVeLpfZ95ujoSKlSpVJ9lqRfg9RWISWqXbt2qmOJce/cufOVPxs//fRTmjEDVKhQIc1YEv8OUn79z549a3gR4uPjk+Y1UsqOn3EhxBvGuTAM3ABtvgWrFCsb46Jg20RY1gdehuRIeFmR39GGvwfV4pPeVdjoFsdhWx36FKuE4nV6/FZcY8sf54h4nrEETsfBU7n/2SBiU7yfU/hUEPsHdSAuOiqrjyCEEBYjySAhSHhhnSgzL5oePXpk+Dy1F+hJtwyZkljA2s3NLdVVI4mSbktLeY24uLg0zzUlMeFiyqvidnd3Z9OmTXh5eaEoCvv37+eDDz4wtLru0aMHW7ZsyXBM8P9fEyDVFRiJkm7vS3qeOSVuJ3RzczO5NSgt5n6W1LY4JVInKWhpKinXrVs3wzVSbhU7ffo0V65cAUyv2Hn8+HGa905Nat9naa10A5K1p3/Vip60xjMTd1RU6r/Yp/fvIOXXP+m21MSVdemVHT/jQog3kFoNDcfBiH1QwERiO3AnzGkA13ZZPrYsUqlU9K5TlG3/aUJcBSeWOsbwVG28Sijo0lOWf32c66cy9m9D236f8PjL4cSk2KFb+Nx9DvZvhy4i428qCiFETpBuYkKQfPVL4uqP9IqPj+f8+fNAwgvR1LZ8vCrBkyi12jzpjSXR8OHDef/999N1XlqJjfTE7ePjw/Xr11m7di3btm3Dz8+Pu3fvEh4ezvr161m/fj1t27Zl3bp1r3wBnZqsfF1ym9zwLI6OjnTt2pXly5ezbt065syZY1i5lrgqyMrKij59+hidm/T77PTp01hbW6frnkWKFDF5PL0/G1mVGHf79u354YcfLHJPc8uOn3EhxBvMswqMPAC7v4ATfyYfiwiBZb2gznBo/TVoM/fvd04p5m7PipEN+MvvJr/vukrDlxpqxiZ/6RMTEcfOvy9w67wHTfqUw8Y+ff+ete75IfutbXCdMgtb3f8fL3zxMYfebk/jpVvROppeKS6EELmFJIOEACpXroybmxtPnz7Fz8+PsLCwdLeX37Nnj+Fd94xu+UgqcQVOaGgo8fHxab5ATroSKamkW4MURUmzoLO52dra0r9/f8NKklu3brF161b++OMPrl27xs6dO5k8eTK//PJLuq+Z9HlSe+ZESVd0JT3PnBJXkD19+pTY2NgMvcBO+SzlypVLda4lngUSVv0sX76c8PBwtmzZQs+ePdHr9axYsQJIvSZO0tVvBQoUSDXJYy5JV6eFhITg5eWV6tyQkNS3Nbi7u3P//n1iY2Mt+rORUtKViBktdp6TP+NCiNeUtR10+AHKtoENYyAixUoZ/3/g1iF46x8olPbW8dzGSq1iTLPSNCmXnw9WnuPG3QjaR2pxTNGC/trxR9y/9pyWgytSpEL6/t1t3nUcftY26D+ZgX2S3WaFrj7hSN92NFy6FRuXtFdXCyFETpJtYkKQsFJj0KBBQMK2kL///jvd5/7xxx+Gz4cMGZLpGKpUqQJAbGws586dS3VeXFycoWZISlqt1tBd6siRI5mOxRxKlizJuHHj8Pf3NyQLVq1alaFrlCpVyrCSKLGzVWpOnDhh+Dy7XiDXrFkTSCiGnFZhYVOSxpQbngWgbdu2hsRE4mqggwcPcu/ePSD1os5Ja1ZZ4vssace0tIpQQ/J6Riklxn3y5Elic7DzS40aNQyrw/z8/DJ0bm76GRdCvGbKtkooLl2+g/HYk6vwd4uErmMmmg7kdt6FXdg4rhGtWxRnoXM0V6yNt9u+fBbDxl/Pcnh1IHG69NU8bNJhBNE/f0xEipKQntefcbRPO6KePTF9ohBC5AKSDBLif95//31DZ54vv/yS69evv/KcFStWsHXrViDhRXunTp0yff/E1ugAixYtSnXe+vXrk9VQSalLly4AXLlyhZ07d2Y6HnNxdnamTp06QPJaKemh0Who2jSh48nu3bu5e/duqnMT27ZrNBqaNWuWuWBfoWPHjoYX8b/++muGzq1Vq5ahNs6iRYtMdvCChGLDiUmzSpUqZbimTEZoNBpD57Jt27bx/PlzQ1LI3t7eZAt3SPheTUzS/f777xnuyJJRzZo1M9TfWbJkSarzzp07l2YiNfFnIywsjAULFpg3yAxwc3OjYcOGQEKC9P79+xk6P7f9jAshXiMO+aHvMuj0C2jsko/pdbD7c1jSFcLu5Ux8WWBrbcXkjpWYP7I+pwtr2GIfSzTG/36d2xvM6u9OEhJk3HzBlEatBhP/2xe8SPHl8rwdzvHe7Yl4IgX8hRC5kySDhPifEiVK8OOPPwIJ3cFatmyZ5gvLVatWMXjwYCDh3folS5ZkqRZM3bp1DStP5syZw+HDh43mPHjwgIkTJ6Z5nffff9/Q6n7o0KFcvHgxzflbt2411DzKjJ07d6a51SUsLMyw0qVkyZIZvv7YsWOBhBVT77zzDjqdzmjO/Pnz2bUrochljx49si2BUq5cObp37w7Ahg0bDN8vpkRERCRL2tnY2DB8+HAgoTvW119/bXSOoiiMGzfOkDQbN26cOcM3KXH1T0xMDMuWLWPt2rUAdO3a1fB9lJKrq6shtqNHjzJhwoRUk1uQsC0uMVmXGUWKFKFjx44ArFmzhg0bNhjNiYqKYuTIkWleZ/DgwRQtWhSAiRMnvnJVzuHDhzl48GDmgn6Fjz76CEgo7NyrVy/CwsJSnZsyCWqun/GFCxcauo5Ja3khhIFKBbWHwehDUKi68fgtP5jTEC6ut3ho5tCgtDvbJ/hQsb4nC51juKMxXgX09H4Ea6af5NSO2+j1r37Do17TvqhnfUuYQ/LfAz2CX3KqV0dePEz9zSwhhMgpkgwSIonx48cbCrIGBQVRu3ZtBg0axJo1a/D39+fo0aPMnz+fVq1a0adPH2JjY7GxsWHp0qVUr149y/efPXs2Go0GnU5H69at+fTTTzl8+DD+/v7MnDmTWrVq8eDBgzRbz3t4eLBo0SJUKhUPHjygdu3ajBkzhk2bNnH69GmOHz/O2rVr+eijjyhdujSdOnUiKCgo0zEvX76c4sWL07FjR3777Tf27t3LmTNn8PPzY/bs2TRo0MCw7Wj06NEZvn7Hjh3p1asXALt27aJ+/fosXbqUU6dOsWfPHoYPH25Isri5uTFjxoxMP0t6zJ4921AkfNKkSbRs2ZIlS5bg7+/PyZMnWbNmDWPHjqVYsWJGycTPP//c0EJ96tSp9OzZk61bt3L69GnWrl1LixYtWLx4MQANGjR4ZXLDHBo2bGhI0k2ePNmQwEpti1iir776inr16gHw22+/UbNmTWbNmsWRI0c4e/Ys+/fvZ+bMmXTr1o1ixYoxd+7cLMU5Y8YMw2qkXr168d5777F//35OnTrFokWLqF27NidOnDCsQjPFxsaGVatWYWNjw8uXL2nRogUDBgxgzZo1nDp1Cn9/fzZt2sQXX3xB1apV8fHxISAgIEtxp6Zz58688847QEJCrVKlSnz33Xf4+flx9uxZ9uzZw/fff0+NGjWYMmVKsnMt/TMuhHhD5S8L7+yGxh8AKd7sin4Oq4fA+jEQk74VNLmJs601M3pXZ/rAGuwqoLDPNpa4FKuE9PEK/264yYafTxMW8uqW8bUb9sBmzg88c0z+tSrwIJKzfToTfv+OWZ9BCCGyTBEig4KDgxVAAZTg4OCcDidbzJkzR3FzczM8Z2ofFStWVPz8/NK8VtOmTRVAadq0abruvWzZMkWr1Zq8n0ajUf766y9l8ODBCqAUL1481ets2rQpXc+gVquVffv2JTv31q1bhvEFCxakGW9iLK/6GD16tBIfH5+ur0FKUVFRSvfu3dO8fuHChZUzZ86keo3ixYsrgDJ48OBMxZDUjRs3lMqVK7/ymffv32907q1bt5QKFSqkeV6jRo2U0NBQk/dOz9+9oijKggULDNe7detWmnMnT56c7P758+dXdDrdK78O4eHhSo8ePdL199+8efNMP0uiXbt2KQ4ODqne44svvlA+++wzBVBsbW1Tvc6xY8eUokWLpivuRYsWGZ2f9H5pedXPflxcnDJu3DhFpVKlGUNq37NZ+RlXlOTfI696FiHEG+7WIUX5uZKifOFs/PFrVUUJOp7TEWbao7AoZfD840rNiVuVz97drcwctdfo48/3DigXD99T9Hr9K693/vQu5VCdisql8hWSfRxqXE15eueaBZ5ICPG6ya7X37IySAgTRo8ezY0bN/jjjz9o164dRYsWxdbWFkdHR0qXLk3fvn1Zvnw5AQEBWeogZkq/fv04c+YMAwcOpHDhwmi1Wry8vOjduzeHDx9mxIgR6bpO586duXXrFj/99BMtWrTAw8MDa2tr7OzsKFmyJJ06dWLGjBncvn2b5s2bZzreX375BV9fX4YNG0bt2rXx8vJCq9ViZ2dHuXLlGDx4MIcOHWLOnDmGui8ZZWtry7p169i0aRM9evQwfF3y5ctHvXr1+O6777h69apZVmelR6lSpTh79iwLFy6kY8eOFCpUyPC1LVeuHIMGDWLjxo0mvzdKlCjBuXPnmDlzJk2bNsXd3R1ra2s8PDxo164dS5Yswc/PL1u7iKWUchVQ79690Whe3WzSycmJtWvXcujQIYYPH0758uVxcnJCo9Hg5uZGnTp1GDt2LNu2bWP37t1ZjrN169ZcuHCBUaNGUbx4cbRaLR4eHnTs2JEdO3YwdepUwsPDAdLsBli/fn0CAwOZO3cuHTt2NHw/2draUrRoUdq0acO3337LlStXDIXls4OVlRV//PEHJ0+eZOTIkZQrVw4HBwesra3x9PSkTZs2zJgxg59++snk+Zb6GRdCCEo0hjFHoPJbxmPPbsP8drD/O4g3Lsyc2xV0tmXBkDpMeMubtfni+NdGhz7FKiFdTDz7l1xh25wAIsPTbkBQpUZr3ObPJsQ1+e887iExXO7bk9Cbl83+DEIIkRkqRcnmyp/itXP37l1D3Y3g4OBsbysthBDp1apVK/bu3Uvjxo05dOhQTocjhBCvF0WB86tg20SICTceL1Qd2v8AxepZPDRzuPUkggkrzxJyK5wOkda46o3fxLJ1tKbFwAqUrFYgzWtdvXKEB++MwiM0eU2i567WlF7sS8FyVc0auxDi9ZVdr79lZZAQQojXwv379w1FoevXr5/D0QghxGtIpYJqfWD0YSjWwHj8wVmY3wbWvAPPgy0eXlaVzO/AmtEN6NO+DL4usZzTGq90in6pY9ucAPYvuUxsdOorocpXaETRRQt5UCD5SlvX5zpuDujPg4snzR6/EEJkhCSDhBBC5AnXr19PdSwqKoohQ4YYus1l5/YuIYR44+UrDkO2QospoLIyHr+wBmbWgf3TIDbC8vFlgcZKzXsty7JybEOuF9eyziGGCJXxRopLRx6w8psTPLj+PNVrlS5Tm1JLfLnnkTwh5BIeR/DgIdw7d8zc4QshRLrJNjGRYbJNTAiRE5o1a0ZERAS9e/emVq1auLm58eLFC06ePMns2bMNyaJ33nknS63shRBCZMDdU7BhNDy5ZnrcqTC0mgpVekEmawfmlKjYeKbvuMKqw7dpE6mlbJxx4kulghpti1O3U0msNKafLzj4EpcH96Po/eT1hl7aq/H8ew5FazXJlviFEK+H7Hr9LckgkWGSDBJC5IRmzZpx8ODBNOd0796dpUuXYmdnZ6GohBBCEK8D/3/gwHcQHWZ6jldtaPc9FK1j2djM4FBgCBNXnaPgkzhaRFmjRWU0J39RR1oNrYR7YUeT17j3IJALA3tR7G5MsuMRdmryz/mdEvVbZkvsQoi8T5JBIteQZJAQIiecPn2a9evXs2/fPu7evUtISAiKolCwYEHq16/P4MGD6dChQ06HKYQQb66IUDgwDU7OB0Vvek6V3gkrhVy8LBpaVj2PjOWzjRfxO/2ADpHWFIk3XiVkpVHRoHsZqjYvgkptnDB6+PgWpwf2oOSd6GTHo2xUuMz6idKN5d8wIYQxSQaJXEOSQUIIIYQQIlWPLsHOT+HmftPjGjto/B9o+B5o7S0aWlZtPHuPz9dfoHwYNI7WYGVilZBX+Xy0HFwRJzdbo7GQ0CCOD+pO6RuRyY5Ha1U4/DqNci26ZVfoQog8SrqJCSGEEEIIIXI/j0owcD30WwFupY3H46IStpTNrA3nVye0rM8julb3YscHTdBWdsXXMYYnauMVUPeuPmPF18e5evwhKd93L+BejIa+W7heLvl2MttYhaj3PuHKzlXZGr8QQiSSZJAQQgghhBDCvFQqKN8e3v0X2nwLNi7Gc8LvwbrhMK8N3Dtl+RgzqZCLHYuH1eXd7hVZmU/HSRvjFvOxUfHsWXCJXf9cJDpCl2zMLV8hfHy3cq2ic7Lj2jiI/eALLm5ekq3xCyEESDJICCGEEEIIkV00Wmg4Dt47DbWGgsrEy4+7J+DvFrB+NITft3yMmaBWqxjSqCSb3m/MkzL2rHSIIVxlvEro+qnHrPjqOEEXQ5Mdd3UuSAvfHVyt4prsuHU8xH88jfNrpSumECJ7STJICDNr1qwZKpWKZs2a5XQoQgghhBC5g0N+6PwrjDoEJVNppX5uOfxRCw7+CLooi4aXWWUKOrHu3YZ0a1uKxc4xXLQ2XiUUERbL5j/O4bfiGrrYeMNxJ4d8tF60g8s13JPNt44H1Wc/c2bFrGyPXwjx5pJkkBApREREMHfuXDp06ICXlxe2trbY2NhQoEAB6tSpw7Bhw/j7778JDg7O6VDN7sCBA6hUKpMf9vb2FC1alE6dOjF//nxiYmJeeb3Ec1+VGIuLi6NPnz6G+fXr1+f58+fmeagUrl+/zvLly5kwYQKNGjXC3t7ecN+FCxea7T63b9/mjz/+4K233qJs2bLY29tja2tLkSJF6NatGytWrCAuzvgXxvRK+vVSqVTcvn07XecpisLatWvp1asXJUuWxM7ODjc3NypWrMiAAQNYsGAB8fHxaV7jyZMnfP7551StWhVnZ2ecnZ2pWrUqn3/+OaGhoWmeK4QQ4g3nWRkGbYK+yyBfSeNxXSTs/wZm1oELa/NEPSFrKzUftCmP77sNuVTMmk32sUSpjOMOOHCXVd/68+h2uOGYg70L7Rbu4FKdgsnmavSg/XImJ5f8ku3xCyHeTNJNTGTY69xN7NixY/Tt25egoKBXzvXw8ODhw4dGx5s1a8bBgwdp2rQpBw4cyIYos8+BAwdo3rx5uuZ6e3uzZcsWSpQokeoclSqhw0ZaXwudTkefPn1Yv349AI0bN2bbtm04OTllKPb0OHjwYJqJqQULFjBkyJAs3+ezzz7j22+/NSoamVKdOnVYs2YNxYoVy9D1t2zZQufOnZMdu3XrVpp/FwBBQUH079+fw4cPpznv2bNnuLq6mhw7fvw43bp1M/m9D1CoUCE2bNhA3bp107yHEEIIQVwMHJ+bsBIo9oXpOcUaQLvvoHANy8aWSRExcXy77TIbjwXRPlJLyTjjFvQqNdTpWJJa7Yqjtkp4bz46NpItozvhffRBsrl6IGLiYOoO/9gS4QshciHpJiZENrt27Rpt27Y1JIK6dOnC4sWL+ffffzl9+jS7du3ixx9/pE2bNlhbW+dwtNlvzJgxBAQEGD727t3Lb7/9Zvifz8WLF+nSpcsrV5GkJSYmhh49ehgSQc2aNWPHjh3ZkggCkiVn1Go13t7e2ZK0ePDgAYqi4ODgYFhtc/jwYU6ePMmSJUuoU6cOAP7+/rRq1YqXL1+m+9ovX75k7NixABQsWPAVs/9fcHAwzZo14/Dhw1hZWTF48GDWrFmDv78/x48fZ8WKFQwfPhx3d/c0r9G5c2cePnyIRqNh0qRJ+Pn54efnx6RJk9BoNDx48IDOnTtz9+7ddMcmhBDiDaWxgUbvJ9QTqjkYTLRpJ+gY/NUcNrwLL0y/EZGbONhomNa9Cn8Mq80BDxV77GLRkfzNIUUPJzbfYt1Pp3n+KKHFvK3Wni5/budik6LJ5qoBp58W8e/sLy31CEKIN4QmpwMQIreYPHkyL14kvCuV2gqR1q1bM3HiREJCQli16vVu/VmwYEEqV66c7FiLFi0YOnQoVatW5fbt2wQEBLB+/Xp69uyZ4etHR0fTrVs3du7cCSR8bTdu3IidnZ1Z4jfFy8uLH3/8kTp16lCrVi0cHR1ZuHAhJ06cMOt93N3dmT59OmPGjDFKbNWqVYt+/frx9ttvs2rVKgIDA5kxYwaff/55uq49ZcoUgoKCaNmyJUWKFGHRokWvPEdRFAYMGMCtW7fIly8f27Zto379+snm1K1blz59+jBnzhysrIzfxYSEn5GQkBAAli1bRq9evQxjPj4+1KpViz59+vD48WOmTJli1m13QgghXmOOBaHL71BnOOz8FG4fSjFBgbNL4eIG8PkAGowDa9uciDTdWlTwYNcH+fh0XQCLzj+iY6SWQvHJ34d/dCucld+eoNFbZfBu4oXW2oYuc7aw8b1uVN57K9lcl99XcCQ2hkb/mWbJxxBCvMZkZZAQQHx8PFu3bgWgdu3ar9wqVKBAAcPqjDeNk5MTU6ZMMfx5z549Gb5GZGQknTp1MiSC2rdvz6ZNm7I1EQRQtmxZJk6cSNOmTXF0dMy2+0yfPp1JkyalusLJysqK2bNno9VqAVizZk26rnvy5En++OMPbGxsmD17drrjWbp0KX5+fgD89ddfRomgpDQajWF7X1IPHz5k6dKlALRt2zZZIihR7969adu2LQBLlixJdSuZEEIIYVKhqjB4M/ReAq7Fjcd1EbDva5hVJyExlMurXbg5aJkzoCZT+lZlk3s8R2x06FOsEoqL1XNw+TW2zDxPRFgMWist3X7fSED7csbXm7uewz9MtFT4QojXnCSDhABCQkKIikroWlGmTBmL3PPw4cMMHDiQEiVKYGtri6urKzVq1GDKlCmG1Rcp/fTTT6hUKqytrU1uLYqOjsbW1tZQVPjs2bMmr1OhQgVUKhV9+/bNVOxVqlQxfJ7RQtovX76kQ4cO7N27F0jYjrdhwwZsbXP3O3zm5u7uTtWqVQG4cePGK+fHxcUxYsQI9Ho9H3/8MeXKGf+SmJqZM2cCUL58+Uyt4gLYtGkTen1Cy9yhQ4emOi8xkarX69m0aVOm7iWEEOINplJBpS4w9gS0mgpaE2/ePA+C1YNhYUd4cM7iIWaESqXirVpF2DbBh/hKzixzjOGp2rgFfdDFUFZ8dYIbpx9jbWVNj5/WEtDV22ie+/yt+H0z3hKhCyFec5IMEgIMKzQALl++nK330uv1jBs3Dh8fH3x9fblz5w4xMTGEhYVx9uxZvv32W8qWLcvu3buNzm3atCmQkBgwVQT4+PHjybp8mSra/OjRI65evQrwyi5fqUn69cpI/aTw8HDatWvHwYMHAejZsydr1qxJdj1TknY5M0eB59wi8e8qtW1ZSc2YMYOzZ89StmxZPvnkk3TfIygoiOPHjwMkKzqt0+m4ffs2wcHB6HS6V14n6fdb4vehKUnHjhw5ku44hRBCiGSsbaHxBBh/GmoMxGQ9oTtH4M+msHEsvHhk8RAzokg+e5aPqM+wLhVY7qrjjNa4o2h0hI4df11gz8JLxMdCz+9XcaFnNaN5BXz3cPDzUa9sVCGEEGmRZJAQgJubG8WLJyxHPnfuHNOnTzesgjC3jz/+mFmzZgFQsmRJ5s6dy4kTJ9i/fz8TJkzA2tqasLAwOnXqxLlzyd/tqlmzpmHrkalET8pjr5qT1ov6tCRNmL2qg1WisLAw2rRpY0gQ9OvXj+XLl78RxbhNefz4seHrWLFixTTn3rp1iy+/TCgcOXv2bGxsbNJ9n8REECSs6Hr48CFDhw7F1dWVkiVLUqxYMVxdXenevXuqK8kALl26BICLiwuenp6pzitUqBDOzs5A9idWhRBCvAGcPKDrTBh5AIo1NDFBgTO+8EctOPwL6KItHWG6qdUqRjQpxbr3GnG3lC1rHGJ4aaIF/dV/H7Li6+M8uBZGz6+Xc/Ft42YXBVf5cfCTdyQhJITINEkGCfE/48f//5Lbjz/+mNKlS/P++++zcuVKbt26lcaZ6RcQEMDPP/8MQOXKlTl9+jSjRo2iTp06NGvWjBkzZrBp0ybUajWxsbGMHDky2flWVlY0btwYMJ3oSVxxk7gCxM/PzyiplTjHw8PjlUkIU+Lj4/nxxx8Nf07PtqOwsDBatWplSEwMGjQIX19fNJo3t4b9jz/+SFxcwruCvXv3TnPu6NGjiYyMpF+/frRq1SpD90lM4gA8ffqUqlWrsnDhQiIjIw3HIyMjDe3gfX19TV4nsTtYelpZJm19KYQQQphF4eowdBv0WgSuxYzHY1/Anqkwqy5c2pSr6wlV8HRm47hGtG1VkkXO0Vy1Nu7M+vJpDBt+PcPRtdfp+vE8Lg9pbDTHY8Mx9n84UBJCQohMkWSQEP8zYcIEhg0bZvjz7du3+f333+nbty+lSpXC09OTvn37snnz5kz/oztnzhxDcuaff/7B1dXVaE67du0McZw4cQJ/f/9k44lbu06dOpWsblBMTAz//vsvAB999BF2dnY8e/aM8+fPJzs/MYnUpEmTDMUeEhLCvn37aNq0KWfOnAESEkGJyam0nD17lpMnTwIJK4IWLFiAWv3m/u/n+PHj/Prrr0BCcmXMmDGpzvX19WXXrl24uLjwyy+/ZPheT58+NXz+ySefEBISwoABAwgICCAmJoa7d+/y3XffodVq0el0DBs2jFOnThldJ7HTXnoKbzs4OACYrGslhBBCZJpKBd7dYKw/tPwcrB2M5zy/A6sGwqLO8OC88XguYaOx4uP2FVg4ugGnvazYah9LTIri0ihwdk8wa74/hU//GVwd2YKU69YLbTvFvvG9UbJpRbsQ4vX15r4aEyIFtVrNvHnz2LVrF+3atTNatfLo0SNWrlxJly5dqFu3brqK/qaU2HnL29ubevXqpTpvxIgRRuckSq1u0IkTJ4iKisLFxYX69esbOkYlXUGUdGvSq+oFffnll4Y6PSqVioIFC9KyZUuOHDmCvb09H3zwAcuWLXv1Q0Oy7lTHjh3j/v376TovUbNmzVAUBUVR8ny78kePHtGzZ0/i4uJQqVQsWrQIe3t7k3NDQ0P54IMPAJg2bRoeHh4Zvl9ERITh8+joaIYNG8aSJUuoXLkyWq0WLy8vPv74Y8PXVafTJesWl/Rc4JX1nQDDNrbEouxCCCGEWVnbgs+H8N5pqD4Ak/WEbh+CP5vApvfg5WOLh5hedUu6sf0/PlRuWIiFzjEEaYxXCT29H8Ga709SzHsSgWPboU/xuIX3XGDvmLfQxxnXIRJCiNRIMkiIFFq3bs327dsJDQ1l27ZtfPnll3Tu3BkXFxfDnJMnT+Lj48ODBw/Sfd2YmBgCAwMB0kwEAdSoUcNQS+fChQvJxmrVqmVYnZE00ZP4eePGjbGysjIke5LOSdwiBpmvFwRQvXp13nvvvXTX+2ncuLGhc9nt27dp2bLlG9l2/MWLF3Ts2NGw5er777+nRYsWqc7/8MMPCQkJoW7duowePTpT90zapU2j0TBt2jST8/r160ft2rUB2LVrF8+fPzd5ndjY2FfeM7Ewtp2dXWZCFkIIIdLHyRO6zYKR+6FofRMTFDi9CH6vCUd+g7gYE3NynpOtNT/0rMZPg2uypyDst9URl2KVkD5e4dj6G+hjBnDt3b7Ep0gIeR28wr6R3dCnoymEEEKAJIOESJWzszPt27fn888/Z9OmTTx69Ij58+eTL18+AB48eMBnn32W7us9e/bM8HnBggXTnGttbY27uzuQfJsPJLygb9SoEWA60ZOYBEr8b9K6QYlzChQogLe3cbvSpMaMGUNAQAABAQGcOXOGzZs3M3jwYNRqNUePHqVZs2aEhISkeY1EarWaJUuW0K1bNwCuXbtG69atCQ0NTdf5r4Po6Gi6du1q2II1ceJEJk2alOr8ffv2sWjRIqysrJg7d26mt9UlFhyHhCReWquL2rZtCyR0vEu5VSzxOunZ+pW4Gik9W8qEEEKILCtcA4btgJ4LwKWo8XjsC9j9OcyqB5e35Np6Qm28PdkxoQmuNdxY4hTDYxMt6B9cDyPkejMuDR2PLsWvBl5Hb7BvWGf06XjjRgghJBkkRDrZ2NgwdOhQli9fbji2bt26THUdS7ptKjNS1g3S6XQcO3Ys2Vi9evWwtbVNVjcoMRmUnnpBBQsWpHLlylSuXJnq1avTqVMnFi5cyPz584GEFT7Dhw9Pd8wajYaVK1fSrl07IGHFU5s2bQgLC0v3NfKquLg4evfuzf79+wEYPnx4siLcpkyfPh2A2rVrc/XqVVasWGH0kbSw+ebNmw3Hk0os5pzyc1OSjqdM9CUWjk5c1ZSWxMLRr7qfEEIIYTYqFVTuAeP8ofkUsDaxBfvZLVjZHxZ3gYcXjMdzgQJONvw9qDYTe1VmrXscx210KClWCeli4gm5WYHT3b8mwjb5Gy9e/nfYN6gDcTG5t6uaECJ3kGSQEBnUtm1bw4vcZ8+epXt1S+KKIkioG5OWuLg4w3Xd3NyMxlPWDTpx4gSRkZG4uLhQo0YNICF5lbRu0JMnT7h48SLw6npBaRk8eDBvvfUWAJs2bWLfvn3pPler1bJu3TqaN28OwOnTp2nfvv1rXWhYr9czcOBANm/eDECfPn34888/X3le4lar48eP069fP5Mffn5+hvnvvfee4XhSSVeAxccb1yFIKul4yppZlSpVAhI6w6W1xe/BgweEh4cDZKpbnRBCCJEl1nbQ9L8w/jRU62d6zi0/+NMHNv8HIp5YNLz0UKlU9K1bjK3/8SGighPLHWN5bmKVUESoGyd9vuFhgcrJjnudvceB/u3QRUYYnSOEEIkkGSREJhQuXNjweXpX+djY2FC2bFkAQ4v11Jw5cwbd//Z8V65c2Wi8Tp06ho5NBw4cMKz4SawXlChp3SA/Pz9DF7Ss1AuChGLGiff59NNPM3SunZ0dmzZtokGDBkBCQenOnTu/tsWGR40aZVit07lzZ3x9fS3aSa127dqG2j03b95Mc27SouheXl7JxpJ2jUtaeyqlpGOJ2xmFEEIIi3MuBN3nwvB9UKSu8biih1ML4PcacPQPiMt9W6uKuzuwalQD3u5YlqUusZzXGheIjtdZc8l7DBcr9CPOysZw3OvCIw72b0fsy3BLhiyEyEMkGSREBkVGRnLp0iUgoa5QYm2f9GjVqhUAFy9e5MSJE6nO++eff4zOSUqj0dCwYUMgIdGTWDso5YqfpHWDElfwuLu7m0wwZUS5cuXo3bs3kJDY2r17d4bOd3R0ZPv27dSqVQtIeIYePXqkqzhxXvLBBx8Y/i5btmzJ6tWrjVbcpObAgQOGDmqpfQwePNgw/9atW4bjSTk4OBi25l28eNFQxDwlvV7Pxo0bAbC3t6dmzZrJxrt06WJIYi1YsCDVuBO7kqnVarp06ZKuZxVCCCGyTZFa8M4ueGseOHsZj8eEw64pMLs+XN2e6+oJWalVjG1ehlXjGnGzhA3r7WOIVBnH+MizMcdrf0KYc0nDMa/LTzjcrx0x4c8tGLEQIq+QZJAQJBTFrVevHlu2bEmzBpBer2f8+PG8ePECSHiBnJH6P2PGjDG8oB45cqRhO01Su3btYt68eQDUrVuXOnXqmLxW0rpBR44cSXYsUb169bCxseHZs2f4+voCCfWCslqzCBJWBCVe55tvvsnw+S4uLuzcuZMqVaoAsGPHDvr06UOcibaoBw4cMLS4HzJkSJbiNochQ4YY4klaxDupqVOn8ssvvwDQsGFDNm7caGi5bmkff/wxAIqiMHbsWMOqs6SmTZtmWBk0dOhQo1g9PT3p378/ADt37mTNmjVG11i9ejU7d+4EYODAgXh6epr1OYQQQohMUamgSk8YdxKafQoaE90un96A5X1hSXd4dMnyMb5CZS8XNo9vTNMWxVngFM11Ey3oY+wKcKrGB9wo2Qm9KmEFd6HAZxzt046o529O0w4hRPqk7y1qId4AJ06coHPnznh5edGtWzcaNGhA8eLFcXJy4vnz55w5c4b58+cTEBAAJCQzvv766wzdo0qVKnz44Yf8+OOPnDt3jpo1a/LRRx9Ro0YNIiIi2Lx5M7///jvx8fFotdo0a8skrRsUFxeXrF5QIltbW+rXr8/BgwcNhZqzUi8oqcqVK9OlSxc2btyIn58fhw8fTraVKD3c3d3ZvXs3TZs25erVq2zYsIFBgwZl61aqNWvWJKtRdPjwYZOfQ0ICJHFVTUb88ccffPnll0DCdqsffvghWbFnU8qXL4+1tXWG75UedevW5d1332X27Nns3r2bxo0bM2HCBMqVK0dISAi+vr6GZGHRokWZOnWqyet8++237Nixg5CQEPr168fJkyfp1KkTAFu2bOHnn38GErrVZSZBKIQQQmQrrT00+whqDIA9UyFglfGcm/thbiOoPSwhceSQ/hXg2c3W2orPOlWiZYWCTFx1jushsbSIskZLkjf5VGruFG/PUzdvKl1eiEPkIzxvhXG8d3vqLN+Mg3vqXUWFEG8WSQYJQcK2K09PTx4+fMi9e/eYNWsWs2bNSnV+2bJlWb58OSVKlMjwvb7//nsiIiKYPXs2N27cYOTIkUZzXFxcWLVqFdWrV0/1OnXr1sXe3p7IyEjAuF5QombNmiWr45LVekFJTZ482bC16OuvvzasCskIDw8P9u7di4+PD7du3WL58uXY2dnxzz//mGUFU0oTJ07kzp07JsfmzZtnWJUFCV+rzCSD1q5da/j83r176UqS3bp1K1PfT+n1+++/8/LlSxYvXsyJEyeMCk0DlClThi1btpA/f36T1yhatCibN2+mW7duPHz4kOnTpxu6niXy9PRkw4YNhu5jQgghRK7j4gVv/Q11R8KOj+HeyeTjih78/4GA1dD0Y6g7Aqyy5w2bzGhYJj/bJzRh6qaLLDp5nw6R1njFJ/8d8IVTMfxrfUzpmxspcu8gHkEvONm7IzWXbcTJw8R2OSHEG0e2iQlBwgqae/fuceTIEb788kvat29PqVKlcHBwwMrKCmdnZypUqECfPn1YtmwZFy5cMNS7ySi1Ws2sWbPw8/Ojf//+FCtWDBsbG5ydnalevTqffvopgYGBtGnTJs3rWFtbG4owQ+orfpIed3Nzo2rVqpmK25Q6derQunVrIGF7m7+/f6au4+Xlxb59+wxd2ubPn8/48ePNFqcAKysrFi1axPbt23nrrbfw8vJCq9Xi5uaGj48Pv/76KwEBAZQvXz7N69SrV4+AgACmTJlC5cqVcXR0xNHRkSpVqjBlyhQuXLhAvXr1LPRUQgghRBYUrQPv7IYef4NTYePx6DDY+QnMbgDXduaqekIudtb80qc63w6swfaCCn62OuJTtKDXW2kJLNuLs1XHEm3jSsF7EZzp05nwB6bfFBNCvFlUSspqo0K8wt27dw0v2oODg2UFgBBCCCGEyNtiI+DIbwkfcdGm55RuCW2nQcEKlo3tFR6FR/PfNee5cukJHSO15Ncbv9+v0UVSLnAFno9P8aSgDZWXriFf0TI5EK0QIqOy6/W3rAwSQgghhBBCvNm0DtD804Qi05V7mp5zYy/MaQjb/guRTy0bXxo8nG1ZNLQO49+qxCo3HSdNtKCPs7bnUqVhXKg4FJenai72fYvQW1dyIFohRG4hySAhhBBCCCGEAHAtCj3nwbBdULim8bgSDyf+gt9rwL9zId64Q2dOUKlUDGxQgk3v+/C0nD2rHGIIVxl3yH3sUZsTdSaj0pfi6tu9eRx4PgeiFULkBpIMEkIIIYQQQoikitWD4Xuh21xw9DQej34OOz6COY0gcI/Fw0tN6QKOrBnTkO7tS7PEJZZL1sarhGJs8nG22nhC3LsSOHAIDy+fzoFIhRA5TZJBQgghhBBCCJGSWg3V+8H4U9Dkv6CxNZ7z5CosfQt8e0LINcvHaIK1lZr/tCrH8rENuVxcy2b7WKJVxmVi7xZpzrWyE7k0ejL3Av7NgUiFEDlJkkFCCCGEEEIIkRobR2gxBcb5g3cP03Ou74Y5DWD7R7mmnlC1oq5sfc+Hek2LsMApmtuaeKM5kQ6eXCn/Af6frOL2Kb8ciFIIkVOkm5h4JW9v72R/1ul0BAYGAtJNTAghhBBCvGHuHIMdH8ODs6bH7fJB88lQayhYaSwaWmoOXgvhv6vO4RUaR7MoDRoTawIcX9yi9tCSeLdqkQMRCiFSI93EhBBCCCGEECKnFW8AI/ZD19ng6GE8HvUMtk2EuY3g+l7Lx2dC03IF2PVBEwrXKcgip1geW8UYzXnpVBK/lTHsX7ATWS8gxOtPVgaJDMuuzKQQQgghhBB5SswLOPwLHJ0J8cYJFgDKtYM230D+spaNzQRFUdh07j6fr79AjWdR1IlxBJXx+oD8heLo9J+mOLjY5ECUQoikZGWQEEIIIYQQQuQmNk7Q8nMYdwIqdTM959oOmF0fdnyasGooB6lUKrpW92L7hCbEVvFklWMo6tgQo3lPHmjw/ewQN848zoEohRCWIMkgIYQQQgghhMiKfCWg9yIYsg08qxqP6+Pg31nwe03w/wfijVu+W1JhVzt836nH4O51+dNDQRt+1GhOXKyaHX9eYO+iS8RG5Wy8Qgjzk2SQEClEREQwd+5cOnTogJeXF7a2ttjY2FCgQAHq1KnDsGHD+PvvvwkODjZ5/pAhQ1CpVEYfarUaV1dXqlWrxtixYzl79my2xH/79m2T91epVNja2lK4cGHatGnDb7/9Rnh4+CuvV6JECVQqFSVKlHjl3A8++MBwr7Jly6b6Ncqq4OBg1q5dy8cff0yLFi1wcXEx3Hfq1Klmu09MTAz//vsvf/zxBwMHDqR8+fKo1WrDvdKjWbNmqf59pPZx4MABo+ssXLgw3ecvXLjQZCxTp07NcCzm/HoKIYQQr70SjWDkAegyExwKGo9HPYWtH8KfPnBjv8XDS0qtVjGscUlW/6cjG+vUxObxPLSxxr8bXjn2kBVfn+B+YM6uahJCmFfuKG8vRC5x7Ngx+vbtS1BQkNHYkydPePLkCSdPnmTBggV4eHjw8OHDdF9bURTCwsI4f/4858+fZ+7cuXz88cd8++235nyENMXExPDgwQMePHjA7t27+fnnn9mwYQM1a9bM0nUVReG9995j5syZAFSoUIG9e/dSuHBhc4SdzJ07d9KVmDKH0aNHp5pYyS5qtZqyZXO+pkCi8uXL53QIQgghRN6itoKaA6FSVzj0M/w7G+Jjk895fAmWdIPyHRLqCbmXzpFQAcp6OLFufDtmlC1I/r8+pIB1V0IKVE8258XTaNbPOEONVsWo16UUVtaypkCIvE6SQUL8z7Vr12jbti0vXrwAoEuXLvTs2ZNy5cqh1Wp58uQJ586dY/fu3ezfn753cnbu3GlIiOj1eh49esTWrVuZNWsWcXFxTJs2DS8vL959991seaauXbvyzTffGP787Nkzrly5wi+//MLly5cJDg6mY8eOXL16FWdn50zdQ1EURo8ezV9//QWAt7c3e/fuxcPDRHcNM0ha816lUlG6dGkKFy6Mn59ftt7LycmJmjVrcvXq1QwlARcsWEBERESacy5dukSfPn0AaNmyJV5eXmnOT/p9ZUpqReXeffddevbsmea14+PjadKkCeHh4Tg7O9OtW7c05wshhBAiFbbO0PpLqDUYdn8Olzcbz7m6DQJ3Q71R0HQS2LpYPk5Aq1HzcafaHC3vy8X/9qLi5fNcK9uLeI3d/09S4MzuIIIuPaXV0ErkL+KYI7EKIcxDkkFC/M/kyZMNiaAFCxYwZMgQozmtW7dm4sSJhISEsGrVqldes1y5ckarWFq3bk3Lli3p0qULkLB1Z9SoUVhZWWX5GVJydXWlcuXKyY75+PgwZMgQmjRpwr///svDhw/566+/mDhxYoavr9frGT58OAsWLACgWrVq7Nmzh/z585slflOcnJz45ptvqFu3LrVr1yZfvnwcOHCA5s2bm/1e7du3p1mzZtSpU4eKFSuiVqtp1qxZhpJBJUuWfOWcJUuWGD4fNGjQK+eb+r5Kj4IFC1KwoIkl60ls377dsH2wV69e2NnZpTlfCCGEEK/gVgr6+MItv4Qi0o8Cko/rdXBsJpxbAS0mQ83BCauLckDDskUpv2QXm97pSN2T07hcYRDPXZOvWA6995LV3/lTv2tpqrUq+n/s3Xd0VNXXxvHvpJOEEHrvvfcmhN4h9CogRQEVERHkRUQEERURVEQUkN470qv0KhAh9BoIJXQIpJA27x8x80uYSTJptDyftWatSc695+6bBCazc87e2NhYt3VeRF4tWt8nQuRqiPXr1wNQqVIli4mg6DJnzkz//v0TfT1PT088PDwAuHv3LseOHUv0XIlhb28fY8XQtm3bEjxHeHg4PXr0MCWCKlasyI4dO1I0EQSQMWNGvvjiCxo2bEj69OlT9FqdOnWiZ8+elCxZEhublPnvMiIiggULFgDg6upK27ZtU+Q61po7d67puTWJKREREbFS/lrQbxd4/gLOFn5fCrwH6wbB1FqRiaOXJGNaN7ou2Mbpso6U//cXCl1aiSEiNMYxEeFG9q+8yF8/eeF/P+glRSoiSaFkkAiRCZmgoMgXskKFCr2Qa1apUsX0/OrVq6bnly9fZsKECXh6epIvXz7SpElDmjRpyJs3L506dWLTpk3Jcv3SpUubnie00HNYWBhdu3Zl/vz5AFSrVo3t27eneHLmTbR9+3Zu3LgBQPv27XF2dn5psfj7+/PXX38BkSuaohKWIiIikkxsbKFiT/j4GLz1MdjYmx9z+yTM8YTFXeHB5RceIoCDvSMdZ27Cu04e8vhup/LRH3B5esPsuJsXHrF4zGHOHrwVY3u9iLz6lAwSARwcHEzPz5w580KuaW//vxf/8PBwAK5cuULBggUZMmQI69at4+rVqwQHBxMcHMy1a9dYunQpTZs2pXv37oSFJa3FZ/R7jh5LfEJDQ+nUqRNLliwBoGbNmmzZsoV06eLe4x69y1mdOnUSFfOb6FVaibNs2TJTUrR79+5Wd0wTERGRBHJKB43GQP9DULS55WPOroPfqkbWGwqOvwNscnOwdaDdb+s42bgQrgE3qXz0B/Jc2wrGiBjHhQaHs332GTZNO0nQ05BYZhORV42SQSJAhgwZyJs3LwDHjx9n3LhxRERExHNW0nh7/2+/eFQx4PDwcBwcHPD09GTSpEls27aNY8eOsW3bNqZMmULJkiUBmD9/PmPGjEnS9aMnvaytPxMSEkL79u1ZuXIlAHXr1mXTpk2kTZs2SbGkVk+fPmXVqlUA5M2b1+okWa9evciRIwcODg5kypSJatWqMWLECNMKo8R6lRJTIiIiqULGgtBlIbzzF2QpaT4eHgL7foFfK8DRORAR/kLDs7O1o83EVZzyLIGNMYxCl1dT/t9fcAq+b3bsZa+7LPr6MD7e915ojCKSOEoGifxnwIABpufDhg2jYMGCDBw4kCVLlnDlypVkvdbx48dN272cnZ2pXLkyANmzZ8fHx4c1a9YwYMAA6tevT/ny5alfvz4ffPAB3t7epnpGEyZM4PHjx4mO4bvvvjM9j6/DFESuCGrTpg1r1qwBIgthr1+/HhcXl0THkNqtWLHC1GmsW7duVq/E2blzJ7du3SI0NJT79+9z6NAhxo4dS6FChZg6dWqiYvHx8WHPnj0A1KhRg4IFX16LWxERkVSnQB3otxta/ATOGc3HA+7C2o9hWm3w2ftCQ7OztaPtD8s41bYsAOkfX6TKP9+S7dYBs2OD/ENY/9sJdi48R+izF5u4EpGEUTJI5D+DBg2id+/epo99fHyYNGkSnTt3pkCBAmTLlo3OnTuzdu3aRO2JNhqN+Pn58eeff9KgQQPT1rCPP/4YJycnAFxcXMiePXuscxgMBiZMmICtrS0BAQEJLvz86NEjDhw4QMuWLVm7NrK9afXq1U1tzeNy8+ZNNmzYAEDt2rVZs2aNOk0lUUJX4hQoUIAhQ4awYsUKDh8+zOHDh1m8eDEdOnTAYDAQHBzM+++/z7Rp0xIcy7x580w/1z169Ejw+SIiIpJEtnZQqTcMOAbVPwIbC42f/bxhdnNY0h0e+ryw0GwMNrQbu4jTnSP/gGkXHkyJc/MpfXIa9iFPzI4/tfsGS745jN+VxP/hUkRSlsGoSl+SQNevXyd37txAZOHhXLlyveSIktfWrVuZOHEi27Zti7UuT6VKlVi8eLHF1RM9e/Zkzpw5Vl2refPmrFq1KtaaPaGhody+fZsnT56YkkcA9evX586dO3z55Zd8/fXXMc7x8fGxqp25vb09Xbt25ZdffsHNzS3W4/Lly8fVq1cxGAymZEGGDBnYsWMHZcqUseY2U1z01vJfffUVo0aNSrFr1alTh127dgEkqVDi9evXyZs3LxEREVSrVo0DB8z/uhbd48ePcXNzi3X10Lp162jbti2hoaE4Oztz6dIlsmXLZnU8RYsW5fz58zg5OeHn5xdvDSgRERFJYfcuwpYRcH6j5XFbB6jeHzwGg+OL2bJvNBpZ/V0fis3dZ/rcMwc3zhZ9m/sZS5sdb7CBik3yUal5PmxttQ5BJDFS6v23/kWKPKdhw4Zs3LiR+/fvs2HDBkaPHo2np2eMN8dHjhzBw8ODW7duJXh+BwcHatSowZw5c1i7dq1ZIig0NJTffvuNatWq4erqSu7cuSlRogSlS5c2Pe7cuQPAvXuJ35NduHBhBg0aFGciKLo8efLw2WefAfDgwQMaNmzI2bNnE3391G7+/PmmulTWrMRJly5dnNvIWrRowciRIwEIDAxkxowZVsdy8OBBzp8/D0CrVq2UCBIREXkVZCoEby+G7qsgc3Hz8fAQ2PsTTKoAx+ZBCte7hMhV6q0/n8659+qaPucY4k8Z7z8oem4hRmPMAtLGCDiywYcV447y0C8gxeMTEespGSQSCzc3N5o2bcrIkSNZs2YNt2/fZubMmab26bdu3eLLL7+Mc47Nmzfj7e2Nt7c3p06d4urVqzx58oS9e/fyzjvvmL25f/DgAdWrV+ejjz7i0KFDhITE3ZEhqvNTbFq1amW6/vHjx9m4cSMDBw7EycmJ06dPU6dOHc6dO2fFVyPSDz/8wEcffQTAnTt3aNCgAZcvv5yWp6+7efPmAeDo6GjVNj1r9O3b1/QzFbV6yRoqHC0iIvIKK1gP3t8LzX6ENBnMxwPuwJqPYHoduLo/xcMxGAy0HjKFCx82Jir9ZABy3tpH9cPfYgy5bnbO3WtPWDL2H07suI4xQhtTRF4FSgaJWMnR0ZFevXqxaNEi0+dWrlwZZ9exIkWKUKpUKUqVKkWJEiXIkydPjJbuzxs4cCBHjx4FoHXr1qxZswYfHx8CAwOJiIjAaDRiNBpNywTj26bk7u5uun6ZMmVo0qQJP//8M+vWrcPOzo6HDx/y9ttvx9iCFp9JkyaZaivduHGD+vXr4+vra/X5Ermy7PTp00Dkip6oBGNSZcmShYwZI4tOWttZLCQkhCVLlgCQNWtWGjdunCyxiIiISDKytYMqfeDjY1DtQ8v1hG4dh1lNYWkPeHg1xUNq+fHPXPmkJeHR/rbpHHSXugfGgf8Owon5e2p4aAR7lpxn7eTjPH34LMXjE5G4KRkkkkCNGzc2JWMePnzI/fvmrTUTw9/f3/SmvGvXrqxatQpPT0/y5s1LmjRpYqwievjwYZKuVb9+fQYOHAjAsWPHmD17ttXnGgwGpk+fzttvvw1E1iiqX78+fn5+SYopNYm+Eie5izVb25Esyrp163jw4AEQ+XNna2ubrPGIiIhIMkqTHpp8Bx8cgMKx/AHn9GqYXBm2fw3PnqZoOC3eH4fvZx0Ii/au0sYYQb1jy3H0m8p9G/M/OPqefsDiMYe4cOR2isYmInFTMkgkEXLkyGF6ntA337G5cOECoaGhAHFuGzp79ixPnyb9hX348OGmekGjR4+Od0tadDY2NsyZM4e2bdsCkbE3aNAg2RJjb7LQ0FAWL14MQObMmWnatGmyzX337l1THanoP6Nx0RYxERGR11DmItB1KXRbAZmKmo+HP4M9E+DXivDvwhStJ9S099fc+rxrjIQQQK2z3uS++g1eDuargJ4FhrHlz1NsmXGK4IDQFItNRGKnZJBIAgUGBpq2+Li5uZm25SRV9M5lAQGxF9j7448/kuV6GTJkoH///kBkVXprO6BFsbOzY9GiRaZkxqlTp2jUqBGPH6uFaFw2btzI3bt3AXj77bexs7OwzDuRpk2bZto6WLt27XiPjyqSDlC2bFnKli2bbLGIiIjIC1CoAXywD5qOj1w19LynfrD6A/izHlw7mGJhNOo+gjtf9SbkuQXG1S7docHFMaxM85QnBvPyBhf+uc3iMYfxPfsgxWITEcuUDBIBnj59StWqVVm3bl2cNYAiIiIYMGAAT548AaBly5bJtjKoUKFCprnmzJljsR7Q2rVrmTx5crJcD2DQoEE4OzsD8P333yeodhBEdkZbuXIl9erVAyK3nDVp0sTiyiUfHx8MBgMGg4E6deokOfakGjVqlCmehGyTS6rErMTx8fHBy8srzmPWrVvH119/DUCaNGno1atXvPMuWrTItBpNq4JEREReU7b2ULUvDDgGVd8Hg4Ut3ze9YGZjWN4bHqVMrcf6nT7j4Tcf8Oy5v3OVvfyQjy98z1LXx5yxDzM7L+DRM9b8/C97lp4nLCRhv4uKSOIl35+kRV5zhw8fxtPTk5w5c9K6dWuqV69O3rx5SZs2LY8ePcLLy4uZM2fi7e0NRLb6HjNmTLJdP2PGjDRr1oz169ezadMmGjVqxAcffEDevHm5c+cOK1asYPbs2RQoUIBHjx6ZVpckRebMmenTpw+//PILly9fZuHChXTv3j1Bczg5ObFmzRoaN27Mvn37OHjwIC1atGDjxo2kSZMmyTFasmnTphg1iqK3uP/3339jJHdcXV1p3759oq7j5+fHpk2bzD4X5fkkUs2aNSlUqFCs8z18+JB169YBUKpUKSpUqGBVHD4+PtStW5fq1avj6elJ2bJlyZIlCwCXL19m+fLlLF++3JRA/PHHH8mZM2e880Ylpuzs7OjatatVsYiIiMgryjkDNB0HlXrD5uFwcZv5MSdXwNn18NbHUPMTcHBJ1hDqtPmYPfaORAz/mTTRKhAUufiQ7/iRnz1GcPF2BA2D7HEyxvyD6om/r+N7+gENe5ckc560yRqXiJhTMkiEyDfD2bJlw8/Pjxs3bvDbb7/x22+/xXp84cKFWbRoEfny5UvWOH7//Xdq1qzJtWvX2LZtG9u2xXwRz5MnD6tXr6ZZs2bJds0hQ4bw+++/ExISwnfffUfXrl2xsUnYokEXFxc2bNhA/fr1OXLkCLt27aJNmzasWbMmzu5pifX999/H2jr9r7/+4q+//jJ9nDdv3kQng86ePRvnCpvnx2bNmhVnMmjJkiU8exa5bz4xK3EOHDjAgQMHYh13dnbmp59+om/fvvHOdfbsWf755x8AGjVqRNasWRMcj4iIiLyCMheNrCV0YSts+hzuX4g5HhYMu38Ar3nQYBSU7ggJ/N0vLh4t+rHfwQHjZz/gHK1cUJ6LDxkS8TX7e/7ErCMPaBboQN6wmKuYHvoFsuz7I1TxzE+FRnmwsdVGFpGUon9dIkSubrlx4wb79u1j9OjRNG3alAIFCuDi4oKtrS1ubm4UK1aMTp06sXDhQk6ePEnFihWTPY7cuXNz7NgxPvvsM4oUKYKjoyPp0qWjbNmyfPXVV/z777+UKFEiWa+ZK1cuU0erM2fOsGLFikTN4+bmxubNmylTpgwAmzdvplOnTjFqIaV28+bNA8DW1jZBK3EqVqzI/Pnz6d+/P1WrViVPnjw4Ozvj4OBA1qxZqVevHmPHjuXKlStWJYKixwLaIiYiIvJGKtwQPjwATb4Hp3Tm409uwap+MKMB+B5O1ku/1agXYT9/yVOnmJ/PcfkxNWcMZHKnguzJYcP2NCGEPteC3hhh5NBfl1k1wYvHdwOTNS4R+R+D0VJhEpE4XL9+3dRa3dfXl1y5cr3kiEREREREJFYB92Hnd3BkJhhjqctTukPkSqF0yfe7/dE9ywkbOBK3wJhvOW/ncqHIzFWM23+PvUdv0SzQgWzh5usU7BxsqNmhMCVq5ki2Op0ir5uUev+tlUEiIiIiIiJvMpeM0PzHyM5jBepaPsZ7GfxaCXZ8ByHJsyKnokd7HKd8z2OXmImcrNcDuNCjFV/XzsiobuVYlzmCA46hRDy3SigsJIKdC86x4XdvAv1DEJHko2SQiIiIiIhIapClOHRfBV2WQIaC5uNhQbDre5hcCU4shWTYRFKuWktcpk7kYdqYCaEst4I43rkVNTMFs3FQLQxl3FnkGsJDG/POvj4n7rF4zCF8TtxLcjwiEknJIBERERERkdTCYICiTeDDg9D4W3C0UE/I/was7AMzGsL1o0m+ZOlKTUj/52QeuMV8+5n5djAnO7XB/tE15vauwvvtirMkQyjHHcxrTgY9CWX9lBPsWHCW0GdqQS+SVEoGiYiIiIiIpDZ2DlC9P3x8DCq9CwYLbw2v/wN/1oOV/cD/ZpIuV7xsPTLN/oO76WNeJ9O9EM52ac/dy6d4p3o+Vg/04HYRZ1a4PCPAYL4y6fSemyz55jB+Vx4nKR6R1E7JIBERERERkdTKJRO0mAjv74X8tS0fc2Ix/FoRdv0AoUGJvlTREh7knDOL2xljtpTP8CCUS293we+sFwUzu7L8g7do0aQgc9M946Kd+Sqgx3eDWDn+KIfXXSEi3HxbmYjET8kgERERERGR1C5rSXjnL+i8CDIUMB8PDYQdYyOLTHsvT3Q9oYJFqpBv3jz8MtvF+Lz74zB83unOjZOHsbe14dOGRZj3YXW88tiyOU0IIWYt6OGfdVdY+eMxHt1WC3qRhFIySERERERERCLrCRVrFllPqNE34Ohmfoz/dVjxLsxsDDcSV08oX4HyFFqwiJtZ7WN8Pp1/ODd69OLav3sBKJ8nPRs+qUXpWjmZk/YZN23NVwHdvuLPkrGHObXnBsZkKHgtklooGSQiIiIiIiL/Y+cIbw2AAcegYk/L9YR8D8H0erDqA/C/leBL5M5TiuILl3E9h0OMz6cNiOB2735c+edvAJwd7BjbpjQT363Ilmyw10kt6EWSg5JBIiIiIiIiYs41M3j+Av12Qz4Py8ccXxhZT2j3+ATXE8qRsyhlFq3iWm6nmJcNjOBBn4+4eGCT6XP1imVl0yAP0lbIyELXZzyIpQX9oq8PcUUt6EXipWSQiIiIiIiIxC5baeixFjrNh/T5zMdDA+Dvb2ByFTi5MkH1hLJmLUCFhX/hky9NjM87Bxvxf/9Tzu9aY/pcRldHpnavyKDOpVmeMYx/LbSgD34ayob/WtCHBJuPi0gkJYNE5JWWL18+DAYDPXv2TLFr9OzZE4PBQL58+VLsGiIiIiKvNYMBintC/8PQYDQ4pDU/5vE1WN4LZjWDm/9aPXXmzHmounAdlwu6xPh8mmdGAj76P/6Z/aOpHpDBYKBjpdysG1SLRyVc42xBv3TsP2pBLxILJYNEYhEWFsaKFSvo27cvpUuXJkuWLNjb25MuXToKFSpEmzZtGD9+PFeuXHnZoaZ6RqORNWvW0KVLFwoXLoyrqyt2dna4u7tTqlQpOnTowPjx4zl+/PgLj61u3boYDAYMBgONGjWy+rw6deqYzov+sLW1JUOGDFSsWJGBAwdy6tSpeOcaNWqU6fydO3fGeezevXtxc3PDYDBgZ2fH/PnzrY45IZ4+fcru3bv58ccf6dixI/nz5zfFmNxJuYsXL7Jo0SIGDRpEjRo1cHZ2Nl1r9uzZ8Z4/e/Zsi9+LuB6Wkpf+/v4sXryYwYMHU7t2bQoVKkS6dOlwcHAgS5Ys1KlThx9++IH79+/HGU9oaCibNm1i0KBBvPXWW2TKlAl7e3vc3d2pUKECn332GZcvX07kV0tERCQedo5Q8xP4+BhUeAcwmB9zbT9MqwOr+8OT21ZNmyFDDmos3MClIjGTTE6h4Pr9DLa93RD/W9dMn8+dwZnFfavT3rMw892fcSGuFvRrL6sFvchz7OI/RCT1WbNmDYMHD+bixYtmY/7+/vj7+3Pp0iVWr17N0KFDad68Od9//z2lSpV6CdGmbrdv36Z9+/bs3bvXbOzx48c8fvyYU6dOsXz5coYOHcqZM2coVqzYC4nt6tWr7Nq1y/Tx9u3buXnzJjly5Ej0nBERETx8+JCHDx9y7NgxfvvtN7755huGDRuW5Hh37txJixYtCAgIwM7OjgULFtCxY8ckz2uJp6dnvImp5LBr1y7q1KmT4td5XtGiRc0+d/jwYbp06WLx+Lt377Jr1y527drF+PHjmT9/Po0bN7Z4XPHixS0mjB4/foyXlxdeXl5MmjSJH374gYEDByb9ZkRERCxxzQItf4XK78Gmz+HqvucOMMK/8+H0avD4FKr1B3snSzOZuKfLQu0FG9nRozmFT8dc0ZPL6wbnmjfFZsj7VHx7AAC2NgY+rFOIWoUzM2ixF5euB1MvyB6HaAkqYwT8s96Hq6ce0LBXCdyzOifH3Yu89pQMEnnON998w8iRI01LUevUqUOLFi0oU6YMGTNmJDAwkFu3brF7927WrVuHj48P69evJ1euXPzxxx8vOfrUJSQkhIYNG+Lt7Q1A+fLl6dWrF+XKlSNt2rT4+/tz5swZdu/ezfr163n8+MUuE543bx5GoxFHR0fCw8MJCwtj/vz5DB06NEHzRN0fRN7z5cuXWb16NQsWLCA8PJzPP/+cggUL0qFDh0THum3bNlq2bElQUBD29vYsWbKENm3aJHq++ERv/ZohQwYqVarE/v37efr0aYpdx8bGhuLFi+Pi4sLhw4etnqN169ZUqlQp3uPatm3LhQsXsLGxoXv37haPyZ07N3Xr1qVixYrkzp2b7NmzExERwfXr11m+fDkrV67k3r17tGzZksOHD1O2bNkY5z979syUCCpXrhytWrWiatWqZM2alcePH7Nx40Z+/fVXgoOD+eSTT0iTJg19+/a1+l5FREQSLHtZ6LkezqyBLSPg0bWY4yFPYfvXcHQONBoDxVtGbjmLRdq0Gak/fzObP2hN8UN+McZcAyPg6yls3biOahNmkjZLTgBK5UzH2o89GL/5HHN2+dA80IEc4TE3wdzx8Wfx2MN4dChMiZo5MMQRg0hqoGSQSDQzZ87kyy+/BCBr1qwsXrw41lUFHTp04Oeff2bx4sUMHz78BUYpUaZPn25KlPTq1Ys///wTG5uYL/y1atWiX79+PHv2jEWLFuHu7v7C4ps3bx4ALVq0ICgoiA0bNjBv3rwEJ4OeX3FWoUIF2rdvT9WqVfn4448BGD16dKKTQRs3bqRt27YEBwfj6OjI8uXLadGiRaLmstbbb79Nv379qFy5MoUKFQIi60MldzIoZ86cjB8/nsqVK1OxYkVcXV2ZPXt2gpJB7u7u8f7cnDlzhgsXLgCRWwNz5cpldkzdunW5du2a2eejdOzYkdWrV9OmTRtCQkIYPXo0K1eujHGMwWCgYcOGfP3111SrVs3iNdq1a0fdunUJCgpi6NChdOnShbRpLdR1EBERSS4GA5RoBYUbw8EpsGdCZBIoukdXYek7kLcmNPk2MokUCxfndLSZ/TebZ48m/a9LcQuMWRMo1z/XONO0EQ7DBlCuw/sAONnb8mWLEtQrloXPlh4n751Q3gq2wybaKqHw/1rQ+5y4R93uxXF2i9nWXiQ1Uc0gkf/4+vrSv39/ANzc3Ni7d2+820tsbW3p2rUrx48fp3nz5i8gSonur7/+AsDOzo6JEyeaJYKic3R0pGfPnmTLlu2FxHbw4EHOnz8PQNeuXenWrRsAJ0+e5NixY8lyjf79+5MnTx4ATp06hZ+fXzxnmFu7di2tW7cmODiYNGnS8Ndff6V4Igigb9++dOnSxZQISimFCxdmyJAh1K5dG1dX1xS7zty5c03P33nnHYvH2NraxjtP69atTVvM9uzZYzaeM2dOtmzZYjERFKVq1ap8+OGHQOTWsa1bt8Z7XRERkWRh7xS5JWzAUSjfDYv1hK7uham1Yc0AeHon1qkMBgNNeo0i++qlnC2bwWw8bUAEjl/+wtbezQm4/7+6RDUKZWLjJ7XIUi0LC2JrQe99Xy3oJdVTMkjkPxMnTiQ4OBiAsWPHJuhNqru7O56ennEe4+fnxxdffEGlSpXIkCEDjo6O5M6dm44dO7Jt27ZYz/Px8TErdrt161Y8PT3Jli0bjo6O5M+fnw8++IDr169bFe+OHTvo0aMHBQoUwNnZGTc3N0qXLs1nn33GzZs3Yz0veiFiiHyjOWbMGMqXL4+7u7tZQd6AgACWLFnCe++9R7ly5UiXLh329vZkzpyZ2rVr8+OPPyZpJUjUKotMmTIl64qfR48eMXLkSEqWLImLiwvu7u7UqlWLBQsWWD1HVHIgffr0NG/enNatW5tWZ0RPHCSFjY0NJUuWNH3s6+uboPNXrlxJu3btCAkJwdnZmXXr1lmsUyNxi4iIMP1suLq60q5duyTNF/VzEvX/UWLUrVvX9PzSpUtJikdERCTB0maDVr9B3x2Qp7qFA4xwbC5MqgB7f4awZ7FOlSdPKVou2s2VT9vwNI35eK79lznZpD4n/ppl+lw6Z3t+6Vye4d3LsSpTeNwt6OerBb2kTkoGiRBZVyRqS0/atGnp1atXss6/YMECChUqxLfffsvRo0d5+PAhISEhXL9+nWXLltGwYUPee+89wsLifyH6/PPPadSoEevWreP27duEhITg4+PDH3/8QYUKFThz5kys5wYHB9OlSxfq1avH3LlzuXLlCkFBQTx58oSTJ0/y448/UqRIEdauXRtvHBcuXKBcuXKMHDmSf//912I9nubNm9O5c2dmzJjB8ePH8ff3JywsjHv37rF7924+++wzypQpw9mzZ+O9niUODpFLe2/fvs2DBw8SNcfzzp07R/ny5RkzZgynT58mMDCQx48fs2fPHrp168ZHH30U7xwhISEsWbIEiNxO6ODgQJo0aWjbti0AixYtsup7bY2orwGAvb291ectWbKETp06ERoaiqurK5s2baJevXrxnhe9y5mPj09iQn7j7Nixw5SIa9u2LS4uLvGcEbtz587x77//AiSp0PmzZ//7pdqaFUkiIiIpIkd56LUR2s+CdHnMx0OewLav4LcqcGYtGM1bxAPY2tjSrO+3ZF6xkPMl3c3G3Z6EY/9/P7C1XyuCHv1vtU/LsjlYP7gWgWXcYm9Bv/cmS9SCXlIhJYNEiNy6E1WU1cPDI0lv5p63dOlSunfvTkBAAAUKFGDixIls2rSJo0ePsmLFCpo1awbAjBkz4q0lM336dL7//ntq167NwoULOXLkCNu2bTNtS7l79y69e/e2eK7RaKR9+/YsXrwYiOzmNG/ePPbt28eBAwf45ZdfyJMnDwEBAbRv354jR47EGUv79u25ceMGAwYMYOvWrRw5coRFixbF6KIUFhZG6dKl+eKLL1i1ahWHDh3i4MGDLFmyhM6dO2NjY8OVK1dM25QSqkKFCqZ769OnT5LrzQQGBuLp6cn9+/cZMWIEO3fu5MiRI0yfPt1UA+a3335j8+bNcc6zbt06U3IqantY9Od37txh06ZNSYo1SvTkX968ea06Z8GCBXTt2pWwsDDc3NzYsmULHh4eyRJPahR9pVePHj0SfH5gYCAXLlxg4sSJ1K5d25Qo/OSTTxIdU/QudsWLF0/0PCIiIklmMECptvDRYag3Auwt/J790AeWdIM5nuDnbT7+n3wFytN86R4uDWhBgKP5eK5d5/m3cV1Obfjfau7s6dIwr3dVurUpxoL0IRZb0PtHa0Efrhb0kloYRRLI19fXCBgBo6+v78sOJ1nMnz/fdE8jRoxItnnv3r1rTJcunREw9u7d2xgaGmrxuOHDhxsBo42NjfHs2bMxxq5cuWKKDTD26dPHGBERYTbHe++9Zzrm2LFjZuPTpk0zAkZ7e3vjxo0bLcbx4MEDY8mSJY2AsUaNGmbjX331lekaNjY2xs2bN8d5/+fPn49zfOvWrUYbGxsjYPzzzz8tHpM3b14jYOzRo4fZ2KFDh0znA0Z3d3dj9+7djdOmTTMeP37cGBYWFuf1o/To0cM0R7p06YwnT540O+bChQtGJycnI2Bs2bJlnPO1atXKCBjz5csX43sVHh5uzJEjhxEwdujQIc45ateubYopNitWrDAdU79+/ViPi/5969Wrl+lrlj59euPhw4fjjCOuuK5cuZKgc+MT9b3Omzdvss77vFmzZpnuYdasWUma6+nTp0ZXV1cjYMydO7fFf5vxxWDpMWzYMKvnet7NmzeNadOmNQLGzJkzG4OCghI1j4iISIp4fNNoXPm+0fiVm+XHKHejcc3HRuOTO3FOc/H8IeOaVpWNp4sWs/jY8mFbY/DjhzHOOe/nb2z28y5ji082Gif222ac3G+72WPpd4eND/0CUvALIJIwKfX+WyuDRIB79/63nDRz5syxHhcREcHJkydjfYSGhsY4/vfff+fx48fkzJmTKVOmYGdnuYHf6NGjyZkzJxEREXHWk8mePTu//vqrxVaYQ4YMMT1/vvCs0Whk3LhxAHz88cc0adLE4vzp06dn/PjxAOzbt8/UHcmSnj170qhRo1jHIbJ4b1waNGhAy5YtAVi9enWcx1pSpUoVpk6datoe9ejRI+bNm0ffvn0pW7Ys6dKlo1GjRkyfPp2AgACr5hwzZkyMOjxRChUqROvWrQHYu3dvrOffv3+fDRs2AJEds6J/r2xsbHj77beByMLNjx49siqm6EJCQjh79izfffedqX25s7MzY8eOter8WbNmERERQZo0adi+fTuVK1dOcAzyPytXrjStSOvevXuS29SWK1eOw4cP89133yVqLqPRSL9+/Xjy5AkAX375JU5OTkmKSUREJFm5ZYc2v0OfvyF3VfNxYwQcnQ2/VoB9kyAsxOI0BQtXocmKfVx4vxGBFpqC5dp+mqONa3Fm6zLT5wpnTcuq/jXxaJyPeW7PuGlrvgrojs8TFn9zmJO7b2CMZduayJtAySARML1xAuLcIubv70/p0qVjfdy4cSPG8WvWrAEiW4s7OlpYy/ofOzs7qlePLK534MCBWI9r3759rPMULVrU1C3p8uXLMcZOnz5tKiLbvn37WOeHyFbsUeKKpWvXrnHOY8ndu3e5cOFCjARaVPLt+PHjCZ4P4L333sPb25tevXqZtc8OCAhg69at9O3bl8KFC8e7NctgMJiSNZZUrFgRgAcPHsSayFm0aJEpKRh9i1iUqM8FBwezbNkys/HY4op6ODo6Urx4cYYPH05gYCAVKlRgy5YtVK1q4ZepWOYCCAoKYv369VadE93OnTsxGo0YjUby5cuX4PPfNFG1xiD2LmKWtG7dGm9vb7y9vTl8+DCLFi2iTZs2/Pvvv3Tp0oV169YlKp5vv/3WVPOrbt26pg6JIiIir5ycFaH3Zmg3A9xymY8/84etX8KUanDDcidWe1t7Wn7yC66Lp3OpkPnv8OkfhsKAkWz9pDPPnkbWBHKws2Fok2JM/bAae/LYsNcplAhiJn3CQyPYtfAcG6acINDfcjJK5HWnZJAIxEgiWLuCJD7h4eGmQrBTp06N8Ybe0mP58uUAcbYHj6+gbPr06YGYyS0gRv2f6tWrxxlH9PbbccVSpkyZOGOJsm/fPjp16kTGjBnJkiULRYoUiZFAmz59OhBzdVZCFS1alJkzZ3L//n3279/PxIkT6dq1q6nOD8CtW7do0aJFnJ3bMmXKRMaMGWMdz5Dhf21Nn/8aR5kzZw4QWc/IUq2WsmXLUqpUKSDpXcUcHBx49913qVGjhtXnfPvtt6af9y+//JKffvopSTGkZjdv3mT79u1AZDv36PWy4uPu7k6pUqUoVaoUlStXpnPnzqxcuZK5c+dy+fJlWrVqFaMznzUWLFjAl19+CUD+/PlZuHAhNjZ6mRcRkVeYwQCl28NH/0DdL8De2fyYB5dgZmP4589YC0wXLVGTRqv3cb53HYIt9NPItek4/zSpxfmdf5k+VzlfBjZ8UovcNbPF34L++N1E36LIq0q/JYpAjATA3bux/2fv7u5uWhUR9YitYOyDBw8S1TEqMDAw1jFnZwsvkNFEvfELD49ZGO/OnTsJjiO+WKIST3EZNWoUNWvWZOnSpfF2+woKCkpwfM+zt7enevXqDBo0iPnz5+Pr68v27dtN277Cw8P58MMPY13ya+3XN2qu5505c8aUeLO0KihK1Pauffv2ceXKlbhvCkwrSLy9vdm9ezeTJ0+mYMGChISE0L9/f9PWPmtUq1aNdevWme71008/5Y8//rD6fPmf+fPnExER+YtjQlYFxaV79+506NCBiIgIPvroI6u75K1fv55evXphNBrJli0bW7duJVu2bMkSk4iISIpzcIbaQ+GjI1Cms/l4eAisHwwr+8Azyw1DHOwcaTX0d5wWTOFKfvPf6TLeCyH0/WFs+6wbIYGRc7g62vFD+7KM7lWBtVmNeMXWgv53b7WglzeOkkEiRK7WiOLl5ZUsc0ZPFkRtZbLmsWXLlmS5fmyxrF271upYPvzww1jnjK9d9fbt2xk9ejQABQoUYMqUKZw4cYJHjx4RGhpqSqZFrWRIKfXq1WPr1q2mVT0XLlwwrdhKbtFX+nz66aexrr76v//7PyCyvos1q4OiVpCUKlUKDw8P+vfvj5eXl2l11vDhw/nnn3+sjrNWrVqsXr3atOXwww8/TPIqpdQoaouYg4MDnTtb+MU1kVq1agVErlK0puvczp07ad++PaGhoaRPn57NmzdTsGDBZItHRETkhUmXE9pOhfe2Q44K5uPey2B6PbhzNtYpipepS4M1+znbvQYhz5XrtAFyrj3KoSY1ubhvg+nzjUtmY/2ntYiokJ7l8bWgv6wW9PJmsFzNViSVKVWqFBkzZuT+/fvs2bOHwMDAeFeJxCf6liKj0WjaGvQyRF/5FLU9JaVFbf9Knz49Bw8ejLUwt7UrH5Iie/bsNG/e3PTm/eLFi5QvXz5ZrxEREcGCBQviP/A58+bN46uvvkrweWnTpmXu3LlUqFCBsLAwBg8ezO7du60+v2HDhixbtox27doRGhpK7969cXJyomPHjgmOJTU6duwYJ0+eBCJrgkX/955U0f+tXL16Nc5jDx8+jKenJ8HBwbi6urJx40art3CKiIi8snJVgne3wLZRcGByzLF752B6XfCcBGU6WDzdwd6RNl/8ycmmW7g67P/Iey04xnimO8949t5gtrdeRq2vpmDvlIbMaR2Z0aMSiw77Mn7NaWo/tqVwWMw/fvrfDWLlj0ep2DQflZrlw9ZWayvk9aWfXhEii+pGbevx9/c31X1JCgcHB9P2pH379iV5vqSInvh4UbGcOnUKiCxiG1eHtuj1jFJSjhw5TM+T2vHJkh07duDr6wvAgAEDWLRoUZyPTz75BIBLly4l+ntStmxZU8HrPXv2WLWKJDpPT08WLFiAra0t4eHhdOvWzVR8WOIWfSVVcm0RixK9EH30Gl7PO3HiBE2aNOHp06c4OTmxdu1aqwuJi4iIvPJs7aHxWOg0HxzdYo6FBsLK92DdpxD2LNYpSlVoRJ11+zjbuQohzy1qtzFCjlUH2d+0BlcOR9YANBgMvF01D6sGeeBTLA2b0oQQ8lxxaWMEHFnvw8ofjvLoduwlFURedUoGifzn008/NbVg/vzzz62q5RKfqLbpZ8+eZfPmzUmeL7EqVKhgKqY8bdo0goOD4zkj6aLqJcVVkNvLy4tDhw4l+hoJafcZPelUoECBRF8zNlHJAVtbW0aMGEHnzp3jfHzxxRfY2dnFODcxvvjiC1Mto2+++SbB53fo0IGZM2diMBgIDQ2lQ4cObN26NdHxpAZhYWEsWrQIiCw63qxZs2SdP3qXudKlS1s85vz58zRq1IiHDx9ib2/PihUrqFOnTrLGISIi8koo7gn9dkE2C6+JR2ZEFpd+GPtKWicHZ9qMmgOzxuOb07wrb5ZbQQT0+Ii/R/Yl/Fnk78j5Mrmw/IO3aNiiIPPShXDD1rxW5J2rakEvrzclg0T+kydPHiZNmgTA48ePqVmzJnv37o3zHKPRGGuLcYCBAwea/rLfq1cv02qZ2Kxfv54TJ04kLHAr2NjYMHz4cCCy7fw777zDs2ex/xXF39+fyZMnxzpujcKFCwOwd+9eLl68aDZ+9+5dUyHlxGrbti1TpkyJtwPc7NmzTV2f8uTJk+xbxAICAli5ciUAHh4eZMmSJd5zMmXKRO3atQFYunRpnN+PuBQrVoy2bdsCkau+duzYkeA53nnnHX7//XcAnj17RuvWrWPdclanTh1T7SMfH59ExZxcfHx8TLG8yETIpk2bTEXZu3Tpgr29hbYlFsyePTveROxPP/3Ehg2RNQzy58+Ph4eH2THXrl2jQYMG3L59G1tbWxYuXJjsCSkREZFXSoYC8O5WqGChcctNL5haC87FvUK6bJUW1Fy/hzPtyxP23LtgWyNkX7qHPc1qcPVY5O9AdrY2fFy/MLM+qs6h/HbsiaMF/Xq1oJfXkGoGiUTTp08fbty4wejRo7l58yYeHh7Uq1cPT09PSpcuTYYMGQgPD8fPz49jx46xdOlSU4LH1tYWBweHGPNlzZqVOXPm0L59e27dukWlSpXo2bMnTZs2JVeuXISGhnL9+nUOHz7M8uXLuXz5MmvXrk2Rmh/vv/8+W7duZdWqVSxbtoxjx47Rr18/qlSpQrp06fD39+fs2bPs3LmTNWvW4OTkxEcffZTo673zzjusXbuWgIAAateuzbBhw6hYsSKAqf27n58f1atX58CBA4m6hq+vL/379+f//u//8PT0pFatWhQtWpT06dMTHBzM2bNnWbZsmenNtcFg4Keffkr2bWIrV67k6dPIrhTt2rWz+rx27dqxfft2Hj16xJo1a+jQwfK+9/gMHz6c5cuXA5Grg+rWrZvgOfr160dQUBCDBg0iMDCQFi1asG3bNqpUqZKomOJz8eJFs2Rr1Nfw6dOnZm3VmzRpkujuWMuXLzfNDcS47vMxZMuWjSZNmsQ5X/SVXLF1E7Rk1KhRDB48mHbt2lGzZk0KFiyIq6srT548wdvbmwULFpi2DDo4ODBt2jSzQu3379+nQYMGpi2JgwcPplixYqb6RZakT5+enDlzWh2niIjIK8k+DbScBHmq/bc9LFon2uBHsKgT1BwEdUeAreW3uc5OaWn7zUKONl2J34jR5LoVM4GT9UYg/t36saNLPWoP+xkbe3vK5nZn3ce1+G7jGRbsuUbzQAcyRMTMJl31vs/Crw9Rv3sx8peNvTyCyCvFKJJAvr6+RsAIGH19fV92OCli5cqVxgIFCpjuM66HwWAwNmnSxOjt7R3rfGvWrDFmyJAh3rlsbGyMf//9d4xzr1y5YhqfNWtWnHHnzZvXCBh79OhhcTwkJMT4wQcfGA0GQ7yx5M+f3+z8r776yjRujV69esU6v62trfHnn3+Od8647qlVq1ZWfY8AY7p06Yxz5861eI0ePXoYAWPevHnjvJ9Zs2aZ5rty5Yrp8w0aNDD9LNy4ccOqr43RaDT6+fkZbWxsjICxRYsWMcZq166doK91s2bNTMcfOHAgxlj0r/GOHTvinGfs2LGmY9OnT2/08vKKNa7oX4OEiv61tOZhKe7o/zZq164d67WifoasecQ1j9FoND58+NDo5ORkBIwlS5ZM0D1bG0euXLmMW7ZssTjHjh07EvR1i+v/AxERkdeW30mjcVIFo/ErN/PHzGZGo/+teKd4GvDIuHxoB6N3sWLG00XNH383rGS87n0wxjk7zt42Vv96q/HdAZuNk/ttt/j4e+5p47Og0JS6c0mFUur9t7aJiVjQpk0bzp07x9KlS3n33XcpUaIEmTJlws7ODjc3N/Lnz0/Lli357rvvuHTpEhs3boyzQ5enpydXrlzhxx9/pF69emTNmhV7e3vSpElD/vz5adGiBRMnTsTHxydRqzqsZW9vz5QpUzh+/DgDBgygdOnSpEuXDltbW9KlS0e5cuV49913Wb58OWfOnEny9WbOnMm8efPw8PAgbdq0ODo6kjdvXrp3787+/fsZOHBgkuZfvXo1Z8+e5ZdffqFjx46ULFnSdD8uLi7kyZOHZs2a8fPPP3Px4sUkb0uz5MaNG/z9998AVK9ePUah6vhkzZqVGjVqAJFbj+7evZvoOL744gvT8zFjxiR6nuHDhzNixAgAHj58SKNGjZLlZ+FNsWzZMtNWr4T+PG3evJkJEybQtm1bypQpQ9asWbGzsyNt2rQULFiQdu3aMWvWLM6dO0fDhg1TInwREZE3Q9aS0GcHlGxjPnZ1L/zhAVf2xDmFi3M62o1bSuDvX3Eri/lKomzXnnK/U092jRtExH+1MOsUzcL6T2thXyVj7C3o991iyTeH1YJeXnkGo1HVriRhrl+/Tu7cuYHIbTpRhYlFREREREReGKMRDk+DzV9ARGjMMYMN1BsBNQaBTdxrIJ48fcCWke9RbOMZbCy8O76V343iE6eQvXjF/y5rZJXXDb5bdYoaD23MWtBHXh8qNc1HpeZqQS9Jk1Lvv/VTKSIiIiIiIq8fgwGq9oPem8DtuTfIxgjY/jUs6gyBD+KcJq1rBtpNXMmTScPwy2S+Sij7FX9ud+jG7on/R0R4OAaDgbYVcrHqUw9ulnKx2IIeIxzZoBb08upSMkhEREREREReX7kqwft7oFAD87ELm2FqbbhxLN5pqjXsQYUN2zndsBARz405hkHmaWvY0cqD2xciu//mSu/Mor7V8WxTmIXuakEvrxclg0REREREROT15pwB3l4WuTXM8Nzb3MfXYGZj+OfPyK1lcUjnloV2v67l4cRPuZPBfPtXjosPudm2E3snf4nRaMTWxkDfWgWZO7AGXoUc2OMUSnhsLeh/Uwt6eXUoGSQiIiIiIiKvPxsbqPUZdF8FLs+1eA8PgfWDYcV78OxpvFPVbNaH0uu3crpuPrMxp1DIOHk529t4cO9KZKON4tndWPVRTco0ysPCtM+4b/P82iK4evI+C0cf5PK/iW8aIpJclAwSERERERGRN0eBOtBvN+Spbj52cjlMrwd3zsY7TYb02Wn3+0bujvuIe+7mb51znr3PtdbtODB1DEajESd7W4Y3K85P71dlWy4bvBzCzM55FhDGxj+8+XveGUKCzcdFXhQlg0REREREROTN4pYDeqyFtwaYj907B9PrwomlVk1Vq1V/iq3bxBmP3GZjaZ4Zcf9pIds61OHBtQsAVC+YkXWfeuBWI0usLejP7LvFYrWgl5dIySARERERERF589jaQ6NvoNMCcEwXcyw0EFb2gXWDIDQ43qkyZ8pNm2mb8fumHw/czN9G5zp5hystW3Fo1g8YjUbcnOyZ2KkcQ3qUY2WWcC7YmxeXfnIvmBXjj3JozWXCw823lYmkJCWDRERERERE5M1VvAX02wnZypiPHZkZWVz6oU+80xgMBuq2/4TC69Zxpnp2s3HnYCNu42axtUsDHt28AkDzMtn5a3At7pdNy8Y4WtCv+OEoD/0CEnFzIomjZJCIiIiIiIi82TIUgHe3QsWe5mO3/oWpteDcRqumypIlP21mbufmVz155GowG8/9700utGjBP/N/BiCrmxNz361Kx47FWJjecgv6u1efsHjsP5zcdV0t6OWFUDJIJJnly5cPg8FAz549X3YoIiIiIiISxd4JPH+B1n+AXZqYY8GPYVFn2PoVhMdf2NlgMFC/y/+Rd81qzlbOYjbuGhiB6zdT2dK9Ef63r2MwGHinej4WDfLgZFEniy3oI0Ij2LXoPOt/O0HA42dJulWR+CgZJPKcgIAA/vjjD5o1a0bOnDlxcnLC0dGRzJkzU7lyZXr37s306dPx9fV94XGlTZsWg8GAwWDg22+/tfrcqHOefzg4OJA1a1Zq167N2LFjuXPnTrxz1alTx3R+fH755RfTsVmyZOHEiRNWx5wQd+7cYd26dYwcOZKmTZuSKVMm03WTMyn3+PFjFixYQK9evShbtizp0qXD3t6ezJkzU7duXSZMmMCjR4+snu/evXuMHDmSMmXK4ObmhpubG2XKlGHkyJHcv38/znOjko4Jefj4+KT4PYmIiIi88sp1gT5/Q8bC5mP7foa5LeGJn1VT5chRhNZzd3J9+Nv4O1tYJfSPL2ebNebY0ikAFMriyor+NajSLB+L4mhBv+jrQ2pBLynKYNQaNEmg69evkzt3ZCV9X19fcuXK9ZIjSj4HDhygc+fOXLt2Ld5js2bNip+f+YtEvnz5uHr1Kj169GD27NnJFtvcuXPp0aOH6eNixYpx5swZq861JnEDkCFDBhYtWkSjRo1iPaZOnTrs2rULIM4lrOPHj2fo0KEAZMuWje3bt1OiRAmr4kiouO4vub4PGzdupE2bNjx7FvdfabJly8bChQupW7dunMcdOnSI1q1bW/wZAsiePTurV6+mSpUqFsejfs6slS5dOvz8/HBycjJ9LrnvSUREROS18uwJrPkYTq00H3PJAu1nQP5aVk933fc0Rz7rS9F/Lf9R71r1/Lw1fgZpM0XWGzp69SGfLfaiwI0wyofYWTyneI3s1OxQGAcny+Py5kup999aGSTyn/Pnz9O4cWNTIqhly5bMnTuXgwcPcuzYMbZs2cL48eNp1KgR9vb2Lzy+uXPnAuDq6grA2bNnOXz4cILmqFSpEt7e3qbH/v37mTt3LtWqVQPgwYMHtG3blitXriQp1rFjx5oSQTlz5mTXrl0plgh6Xp48eeJMZiXW/fv3efbsGTY2NjRu3JiffvqJv//+m2PHjrFmzRo6deoEgJ+fHy1atODff/+NdS5fX188PT3x8/PDzs6OoUOHsnv3bnbv3s3QoUOxs7Pj1q1beHp6cv36dYtzbNmyJcb30tLjp59+Mh3fsWPHGImg5L4nERERkdeOY1poPxOajgeb536/D7gDc1vBngkQYV2nr1y5S9Bq0R6uDmnP0zTm43kOXOF004b8u+pPACrmTc/aT2qRuU52lrs842lsLejHqAW9pACjvNauX79u/Omnn4wNGzY05s6d22hvb2/MmjWrsW3btsaDBw+myDV9fX2NgBEw+vr6psg1Xob27dub7mvWrFlxHnvnzh3j5MmTLY7lzZvXCBh79OiRbLH5+voabWxsjIDxxx9/NKZPn94IGPv372/V+VH3Vbt2bYvjERERMe4/rnlr165tOs6Sr776yjSeJ08e48WLF62KMSlGjhxpXLt2rdHPz89oNBqNV65cMcWQXN+HxYsXG/v162e8evVqrMdMmjTJdN26devGelz37t1Nxy1dutRsfMmSJckSf8eOHU3z7Nmzx2w8Oe9JRERE5LXm+4/ROLGk0fiVm/ljfnujMeB+gqbzufyv8a921Y2nixaz+Nj8nqcx4MEd0/HbTvsZ3xq1xfjxR1uMk/ttN3+8v9148K9LxrCw8OS+c3nFpdT7b20Te80NGzaMcePGUbBgQerUqUPmzJm5cOECq1evxmg0snDhQtNf95PLm7hNLDw8nLRp0xIUFESlSpX4559/Ej1XSmwT+/777/n888+xs7Pj5s2bjBgxgmnTppExY0Zu3boV70qlqG1UtWvXZufOnRaPuXTpEoUKFQKgaNGinD171uJxcW0TGz58ON999x0A+fPnZ8eOHeTNm9fq+0wuPj4+5M+fH0i+bWLWqly5MkeOHMHGxobbt2+TKVOmGON+fn7kzJmTiIgIGjduzKZNmyzO06RJEzZv3oyNjQ03btwgW7ZsCYrj8ePHZMuWjeDgYAoUKMClS5dS7J5ERERE3giBD2BlX7i41XwsXR7oOBtyVrR6uvCIcDb/MZws09bgEmw+/iidHa4jh1K6eXcA7j19xrDlJ/A7fp/6QfY4YF4KIXOetDTsXYL02VysjkNeb9omJhZVqVKFnTt3cvHiRf7880++++47li9fzo4dO7C1teWDDz6Itx6IwN27dwkKCgIwJUSSy7lz5+jTpw/58uXD0dGRrFmz0qZNGw4ePGj1HPPmzQOgUaNGZM6cme7dI18w7t+/z/r165MlzgIFCuDiEvmikpji2EOGDDElggoXLszu3btfSiLoZatTpw4AERERFrfbrVmzhoj/lhr36tUr1nmiCl9HRESwZs2aBMexdOlSgoMjf+t45513Enx+dPHdk4iIiMgbwTkDvL0U6o0Aw3NvlR9fg5lN4PB0sHI9ha2NLc0+HEfGZfO4WMzNbNz9cRh2g79ly4dtCfZ/SCZXR6b3qMQ7b5dkacZQrltqQX9NLegleSgZ9Jpr27YttWvXNvu8h4cHdevW5eHDh3h7e7+EyF4vDg4OpufWFmW2xqpVq6hQoQJ//vknV69eJSQkhDt37rB69Wpq1qzJkiVL4p3jyJEjnD59GoBu3boBUKNGDdPKl6haQkllMBiws4ssTJfQmkgDBw5kwoQJQGRh6127dlmVsY7qcpUvX74Ex/uqip58tbW1NRvfu3ev6bmlf7uWxvbt25fgOKJ+LgwGgyl5mFjx3ZOIiIjIG8PGBmp9Bt1Xg0vmmGPhIbBhCKx4F549tXrK/IUr0XTFPi5/0IQgB/Px3H+f4Wjj2pzasgSDwUDHyrlZ8mktzpd0ZnccLejXTT6uFvSSaKk6GZSc7aivXr3K4MGDKVasGC4uLmTIkIHKlSszfvx4AgMDU+YG4hH1hj7qDb7ELkOGDKZVLMePH2fcuHGm1RuJ5e3tzdtvv03WrFmZPHkyBw8e5MCBA4waNQonJyfCw8Pp27cvd+/G3TIy6k192rRpadWqFRD5Bv/tt98GYP369Tx48CBJsQLcunWLx48jC9NZm5wxGo18+OGHTJo0CYBSpUqxc+dOsmfPnuR4XldRW+js7e0trjKLSuylS5cuzq1f2bNnx80t8i9ICU1QXrlyxZRAqlmzJgUKFEjQ+c+L755ERERE3jgFakO/PZDnLfOxkytgel24Y/3vaHa2djQf+BNuS2dyubCr2XiGh6HYfDyKLR935NmTx+TJ6MzS99+iVquCLHELsdiC/tqpByxUC3pJpFSdDMqaNSuenp6MGTOGTZs2cf++5RaA8Vm7di1lypRh4sSJnDt3jsDAQB4+fMiRI0cYOnQo5cuX5+LFi8kcfdyuXbvGtm3byJ49O6VLl36h135dDRgwwPR82LBhFCxYkIEDB7JkyZJEbY05duwYJUuW5N9//6V///5UrVqVatWq8dVXX/Hnn5EdBPz9/Zk/f36sc4SGhrJ48WIA2rRpg7Ozs2ksapVQSEiI6ZikiNriBdC+fXurzunbty+///47AOXKlWPHjh1kzZo1ybG8rtavX8+JEycAaNy4sSmZE11UdzBrVk5F3xucEHPnzjUtG07qFjFr7klERETkjeSWHXqshRoDzcfunYfp9eB4/Cv9oytUrDqNVu/n4nv1CbawGD/3Fm/+aVKLsztWYWtjoH/dQkwd+Bb7CthxzCHM7PiQgDA2/uHN33PPEBJsPi4Sm1SdDIouse2ovby86NSpE/7+/ri6ujJ27Fj279/P9u3b6dOnDxDZsrx58+Y8efIkucO2KDQ0lO7du/Ps2TPGjRunbR1WGjRoEL179zZ97OPjw6RJk+jcuTMFChQgW7ZsdO7cmbVr11q9P3fmzJkW3zy//fbb5MiRA4A9e/bEev7GjRtNK4eikj9RihUrRqVKlYDEbxV78uQJx44do1evXkyePBmIrPfTv39/q86PSmoVL16cv//+O1UXFn7w4IHp62Zra8vXX39t8bio/wdcXc3/IvS8qBpOT59avwwZ/ldjKk2aNHTs2DFB50Zn7T2JiIiIvLFs7aDh19B5ITimizkWGgir+sLaTyDUQoXoWNjb2uM5ZDLOi6ZypYB5IeiM90MI/2A4Wwa/zbPAJ5TKmY6/BnqQp2FOlsXWgn7/LRaNOcytS2pBL9ZJ1fuHRo4cSeXKlalcuTJZs2aN0YHIWgMHDiQoKAg7Ozu2bNlC9erVTWP16tWjcOHCDB06lPPnzzNhwgRGjRplNsfgwYMTVOR54MCBFC5c2OJYREQEPXv2ZPfu3fTp0yfJtUJSExsbG2bMmEHnzp2ZOHEi27ZtIyzsf9n127dvs2TJEpYsWUKlSpVYvHgxBQsWjHW+0qVLU6ZMGYtjBoOB8uXLc/PmTS5fvhzrHFFJnuzZs1O/fn2z8W7dunHkyBEOHTrEhQsXYv25iLJr1y5TZzFLMbVq1YopU6aQPn36OOeJfo7RaOTKlSscO3bMYoxxeVOK3oWHh9O1a1euXr0KwIgRIyhfvrzFY6OKOkevUxUbR0dHAFNxc2vs37/f1DmsVatWiV7Jk5B7EhEREXnjFWsO/XbB0nfA70TMsaOz4OYx6DAHMlj/frJoqVrkX7OPjeM/Ju+C3ThGW9hjA+Re78Xhwx5k/W4sRWo25yvPktQrloUvFh+n/J0IioTG/KP/0/vBrPzxKJWa5qNS83zY2mrth8QuVSeDRo8enaTzDx8+bFrV8e6778ZIBEUZPHgws2bN4syZM/zyyy988cUXZsV5p06dSkBAgNXXbd++vcU3/REREfTu3ZuFCxfSrVs3/vjjjwTekQA0bNiQhg0b4u/vz759+/jnn384cuQIu3fvNtXUOXLkCB4eHhw9ejTW+jjFihWL8zoZMmQAiHXF2MOHD1m7di0AXbp0wcbG/D/zLl26MHjwYMLDw5k7dy5jxoyx+j6flyNHDj755BPTiiVr/Prrr3z00UcEBwfTsmVLNm/eTM2aNRMdw+vqww8/NLWIb9GiBV9++WWsxzo5OREYGEhISEi880YlidOkSWN1LNFXifXo0cPq856XkHsSERERSRUy5Id3t8Km/4Ojs2OO3ToO02pD6z+gWDOrp3Swc6TV51M53WQbV/9vKHmvxfwjYKa7zwjtM4StrZZSZ9QfeBTOzJrBtRixypsNR+7S4PkW9EY4ssEHn5P3aaQW9BIHpQqTYPXq1abnsbWItrGxMdXsePToETt27DA75unTpxiNRqsfUW2eo4uIiKBXr17MmTOHLl26MHv2bIvJA7Gem5sbTZs2ZeTIkaxZs4bbt28zc+ZM06qZW7duxfkGOXp9H0uivj/h4eYtIwEWL15sShg8v0UsSpYsWUzbG+fPnx/vSptKlSrh7e2Nt7c3J06cYMuWLXz55ZekS5eOGzdu0KRJkzi3rT2vf//+jB8/HoDAwECaN2/OP//8Y/X5b4LPP/+cadOmAZFd/JYuXRrn1sy0adMC1m39ikoSW7OlDCKTR0uXLgUiV5M1bNjQqvOel9B7EhEREUk17J3A8xdoMxXsn/t9P/gxLO4CW0dCeMLq95Qo34C66/Zx/u3qhDz3a5eNEXKtPsz+Jm9x6eAW3J0dmNy1In17lGZZpjCLLejvXXvC4m/+wXunWtCLZcoWJEFUi2gXFxcqVqwY63FJbREdn6hE0Ny5c+nUqRPz5s3TG7cU4OjoSK9evVi0aJHpcytXrkxy17HYRF/hUaFCBVOnu+cfGzduBCJrHO3evTvOOV1cXChVqhSlSpWidOnSNGzYkK+//pq9e/eSNm1anj17RteuXfH397c6ziFDhphW2fn7+9O4cWNTweE33bhx4/j++++ByO/RunXr4l3FE1U4OqqQdFyiCkdHFZKOz9q1a3n48CEQWZcqMf8PJOaeRERERFKdsp2hz9+Q0UKZhn2/wNyW8MQvQVM6OqSh1ciZ2MyaiG8uR7PxLH7BBPUayLYv3yMsOJhW5XKyfEgtrpZ1tdyCPiyC3YvPs1Yt6MUCJYOSIKrdc6FCheJs3x59u1BCW0THJ2pr2Ny5c+nQoQPz589XIiiFNW7c2PTm/OHDh4nuQheXCxcucPDgwQSfl9hC0qVKleLbb78FIhMQUat9rDVy5Ej+7//+D4j8mjRs2JCzZ88mKpbXxZQpUxg2bBgQWUB78+bNVtXnKVGiBACPHz/Gzy/2XxBu3bplSsoVL17cqpiSukUssfckIiIikiplKQ59d0CpduZjV/fBHx5wJe4/1lpSukpTaq3fx7kOlQh77h27rRFyLtvH3uY18Dm6ixzuaZj/XjUativMUvdQiy3ofU89YOHoQ1z2Ugt6+Z9UXTMoKYKDg7l37x4Qf4vo9OnT4+LiQkBAQIJbRMfn66+/Zs6cObi6ulKkSBG++eYbs2Nat25NuXLlrJ4zvhULt27dSmiYb5wcOXKYvpexFWROiuhv6n///Xfc3d3jPH7WrFls2bKF5cuXM3ny5ESt5OjXrx8TJ07kypUr/PTTTwwcODBB3cG+//57goKCmDRpEnfu3KF+/frs2bOHAgUKJDiWV928efP46KOPAChQoADbtm2z+mtVs2ZNU7evXbt20alTJ4vH7dq1y/S8Ro0a8c579+5dU42fcuXKUbp0aaviiZKUexIRERFJtRzTQrsZkKc6bPocIkL/NxZwB+a2grpfQM1PIQFlPJwcXWg9Zh7/NvuL28NHkvNWzHqTWW8E8qT7+/zdpS61h/3Mex4F8CicmU8XeZHtSjAVQmK+1Q8JDGPjVG+KVc+GR6ciODgpFZDa6ScgkaIX/bW2RXRAQECCW0THx8fHB4isPzJ27FiLx+TLly9BySBrt6SkVoGBgZw+fRqIrCuUMWPGZJ3faDQyf/58IHLFzvvvvx/vOU5OTmzZsgV/f39Wr15Nly5dEnxde3t7hg0bRr9+/QgICOCnn36K9WcqNj///DNBQUFMnz6dmzdvUq9ePfbs2fNG/UytXLmSXr16YTQayZUrF9u3b09Q0e2WLVvywQcfEBERwaxZs2JNBs2ePRuIrC3VsmXLeOddtGgRoaGRv3wkdFVQUu9JREREJFUzGKBKH8hRAZb1gMfRFgAYI+DvMeB7KLLOkHOGBE1drnorAjfUZfPovhRacxy7aAt/7CIg+4Id7N5dg4ITfqFombdYOaAGP229wPLtV2gS4ICrMeYfrs8e8MP33EMa9y5J9kLuSbhped1pm1giRbWHhpRrEW2N2bNnx1twumfPnsl6zTfR06dPqVq1KuvWrYuzBlBERAQDBgwwJQNbtmyZ7CuDdu/ebUrytW/f3qpzmjRpYkpKJnarGEDPnj3JmTMnAL/99pupe5q1DAYDf/zxB127dgXg6tWr1K9fP9btUFF1j/Lly5fomJPL7NmzTfGMGjXK4jFbtmyhS5cuhIeHkyVLFrZt25bg2LNly2b6+mzevJnly5ebHbNs2TI2b94MQPfu3cmWLVu880Z93+3s7Hj77betjic57klEREREgFwVod9uKNzIfOzCFphaC64fTfC0zmncaPP9YkL++JqbWe3NxrP5PuVRl3fZ8f1A7DEyrGkxxvWvyubcBs7bmxeXDnjwjJUTjnHwr0uEh6dM/VN59WllUCI5OTmZnqdUi+iXJb6tbLdu3aJKlSovKJoX5/Dhw3h6epIzZ05at25N9erVyZs3L2nTpuXRo0d4eXkxc+ZMvL29AUiXLl2SWrnHJnoyp107C/uPLXBycqJZs2YsXbqUrVu34ufnZ1UC4XkODg4MGTKEQYMG8fjxYyZNmpTgluI2NjbMmTOH4OBgVqxYwYULF2jQoAE7d+5MsW1He/fu5eLFi6aPo7ZwAly8eNG0yiZKYhKkBw8epE2bNoSEhGBvb89PP/1EaGgoJ0+ejPWcXLlyWdziN3bsWDZt2sTdu3fp0qULR44coUWLFgCsW7eOCRMmAJA5c2aLWz+fd/r0aY4ejfzFokmTJmTJkuWF35OIiIiIELnyp8sS2DsRdoyNXBkU5bEvzGwMTb6Dyu9FrihKgIq1OhCwoSFbvnqPIutPYROtXrR9OGSbvYUdu2pQbMJkqpSozF+fejB6zSk2HPCz2IL+6Mar+Jy8T+N3S6oFfSqkZFAiRbWHhpRpEf0yxVcD6U1kZ2dHtmzZ8PPz48aNG/z222/89ttvsR5fuHBhFi1alOwrKIKCgkwrRYoWLUqpUqWsPrd9+/YsXbqU8PBwFixYwODBgxMVQ9++fRk7diz37t3jl19+YdCgQQn+ubW1tWXRokW0adOG9evXc+rUKRo1asTff/+dIomEP//8kzlz5lgc27dvn1kXv8QkgzZt2kRgYCAAoaGhptU9cZk1a5bFa+XOnZu1a9fSunVr/Pz8GDduHOPGjYtxTLZs2Vi9erVV/x6jJxDfeeedeI+Pkpz3JCIiIiL/sbGBWkMgdxVY3hsCohVujgiFDUPg2oHIFvWOaWOfxwIXF3fa/Licw00XEDDqe7LdjdnCPscVf+51fIeLvZrj8ck4fuxYjk0lbvHtUm9q3DeQKzxms6H7vk9Z/M1harYvTKnaOVOkHqq8mrRNLJGcnJxMtWLiK7j88OFDUzLoTaqd8iZxcnLixo0b7Nu3j9GjR9O0aVMKFCiAi4sLtra2uLm5UaxYMTp16sTChQs5efIkFStWTPY4Vq9ebeogZe2qoCjNmjUzrTxLylYxZ2dnBg0aBMD9+/f5/fffEzWPvb09y5cvp0GDBgB4eXnRtGnTZK+b9bqqWrUq3t7ejBgxglKlSuHq6oqrqyulS5dmxIgRnDx5kqpVq8Y7T0REBAsWLADA3d3dqvpCIiIiIvIC5K8F/fZAnrfMx06ugGl14U7iuk1Xqd+Viuv/5mzjIjy/0cshDLJMX8/frWrid/44TUplZ/lntbhV0S2WFvTGyBb0v6oFfWpiMBqNxvgPSx18fHzInz8/EFmA9fmtJc+rVasWe/bswcXFhUePHsXaXv7AgQO89VbkfwAjR45k9OjRyRr3i3b9+nVTUsvX1zdVriQSERERERGxSnhYZBHpfT+bj9k7Q4ufoGznRE9/YNMsQr6eQJYH5vWBgu3haZ821Og/BoONDQsOXWPa6rM08LclY4T52hB7ZzsadC9OgfKZEx2PJK+Uev+tlUFJULNmTSByC1hUvQ5LEtoiWkRERERERN4QtnbQcDR0XgRO6WKOhQbCqn6wdiCEBls+Px7Vm/SizPptnK1X0GzMKRQyTVnF9ra1uHv5NN2q5WX24Bp4FXfimEOY2fGh/7Wg3z7nNCHB5uPy5lAyKAlat25tej5r1iyLx0RERJi27Li7u1O3bt0XEZqIiIiIiIi8Soo1i+w2lr2s+djR2TCjITy4kqip06fPRpsp67g/fiD33M3f5uc694DrbTqw9/dR5M/kwtIP36J0y3yscA3hqcF8s9DZA34s/PoQty4+SlQ88upTMigJqlSpgoeHBwAzZszgwIEDZsdMmDCBM2ci94EOHDgQe3vzVoAiIiIiIiKSCqTPB723QKXe5mN+J2BqbTi7PtHT1/R8n5Lrt3C2Vl6zsTQhRjL+soStHWrz+PpFPmlQhIkDq/F3Phu1oE+FUnXNIEvtqD/77DMgcjvXe++9F+N4Sx10vLy8qFGjBkFBQbi6ujJ8+HDq1q1LUFAQixcvZtq0aQAUKVKEI0eOxOhC9rpSzSAREREREZEkOr4E1n0SuVXseW99DPW/itxilki7Vk7C7vupZPA3T+QEOhkIGdCdar2HERQazrfrT+O1+yb1g+xxxLyjWMbcrmpB/5Kk1PvvVJ0M6tmzZ6ztqC2J7Uu1du1aunXrZuoC9bwiRYqwfv16ChUqlKg4XzVKBomIiIiIiCSDO2dg6Ttw77z5WJ63oP1McMue6Onv3rnK3mHvUmz/DYvjvmWzUWniDDLkLMCOc3f4evEJqt0xmrWgB7CxM6gF/UugAtKvME9PT06cOMGgQYMoUqQIzs7OuLu7U6lSJcaNG4eXl9cbkwgSERERERGRZJKlOPTZAaXamY9d2w9TPeDyLvMxK2XOkpfWM7biN+pdHrmaJ3ByH/fjcosWHJo3gTpFMrNiSC0eVHWPswX9GrWgfyOk6pVBkjhaGSQiIiIiIpKMjEb450/Y9DlEhMYcM9hA3eFQczDYJH49x+1blzgw9D2K/uNncfxaxVxUnTCDdFlzs+LYDSavOEXdR7ZkiqUFff3uxShYPkui4xHraJuYvDQlS5aM8XFoaCgXLlwAlAwSERERERFJNjeOwtKe8Pia+VihhtB2GjhnSPT0RqORv+d9i8svC0gXYJ4KeOJig+Gz96nceQC+DwL5bPG/OJ99SsUQy7WLilbLRq1ORXBIk/jaRhI3bRMTEREREREReZPlrAj9dkHhxuZjF7fC1Fpw/UiipzcYDNR/5wvy/LWK8+Uzm42nDYjAddQUNvVsQrqIRyx4vzpV2hVkVVrLLejPHfRjwRi1oH8daWWQJJi2iYmIiIiIiKSgiAjY9zP8PQaMz3UDs7GHxt9ClT6QhELORqORrbNG4f7rMtIGmacF/F1tsR/+MRXa9uXUzccMXehFwauhFA01Ly6NASo2zkvlFvmxtdOak+SklUEiIiIiIiIiqYGNDXh8Cu+sAZfn6vJEhMLGz2B5b3j2JNGXMBgMNOo9mhyrl3KhtPnWM7en4aQZ/hOb32tBXsdgVgz0IEujHGxwDuHZc8WlMcLRTVdZ+v0RHtwKSHRM8uIoGSTynICAAP744w+aNWtGzpw5cXJywtHRkcyZM1O5cmV69+7N9OnT8fX1tXh+z549MRgMZg8bGxvc3d0pXbo0ffr04eDBgwmObfTo0ab50qVLR3BwsNXnXrx4kUWLFjFo0CBq1KiBs7Ozaa7Zs2cnOJakiIiI4PTp08yePZsPP/yQypUr4+joaIpn586dyXat06dPM378eFq0aEG+fPlwcnLC2dmZ/Pnz07lzZzZs2JCk+6hevXqM73NcYvvZsPTw8fGJ9/pXr15l8ODBFCtWDBcXFzJkyEDlypUZP348gYGBib4vEREREXlF5PeA9/dA3hrmY6dWwrS6cPt0ki6RK28pmi/ZzdVP2hDgZD6eZ+8lvJvV59zmeYzwLMkXH1ZiQw64bhtuduyD609ZPPYwJ3ZcR5uQXm3aJiYJ9iZvEztw4ACdO3fm2jULBduekzVrVvz8zCvx9+zZkzlz5lh1vY8++ohJkybFm0SIUrBgQS5fvmz6eNGiRXTu3Dne83bt2kWdOnViHZ81axY9e/a0KobkMGfOnDivt2PHjjjjtVaPHj2YO3duvMc1btyYxYsX4+7unqD5J0+ezIABA2J8Lq7/UhPys3HlyhXy5csX6/jatWvp1q0b/v7+FseLFCnC+vXrKVSokFXXExEREZFXWHhY5JaxfT+bj9mlAc+foWz87wvic/XSv5wY8j6Fzjy2OH6tTlE8xs0g1MGNr1Z7c/PQXWoG22GL+fuZnMXT07BHCVzcHZMcV2qmbWIiKez8+fM0btzYlAhq2bIlc+fO5eDBgxw7dowtW7Ywfvx4GjVqhL29vVVzbt68GW9vb7y9vfHy8mLVqlX0798fO7vIavuTJ0/mxx9/tGquvXv3mhJBrq6uAFYlOiBmgsLGxoaSJUtSpUoVq85NCdHjsbe3p0KFCpQuXTrZr3Pjxg0AMmTIQN++fVm4cCH79+/n8OHDTJ06laJFiwKR3ydPT08iIiLims5s7uHDh2MwGMiUKVOC4sqRI4fp5yK2R86cOWM938vLi06dOuHv74+rqytjx45l//79bN++nT59+gCRP8/NmzfnyZPELx0WERERkVeErR00HA1dFoNTuphjYUGwqh+s+RhCrd85YEneguVotmIfV/o3J9BCDifPznN4Na7DtV3L+LlLBbr3Ks2qjGHcszH/PfrGmYfMH32QS153khSTpAz1fxP5zxdffGF64xzbSpmGDRsyZMgQ7t69y9KlS+Ods0iRIjFWd5QrV47WrVvTuHFjWrZsCcC3337LJ598Em+CKSrxkzVrVj755BM+//xztmzZwu3bt8maNWuc5+bMmZPx48dTuXJlKlasiKurK7Nnz+bw4cPx3kNKKFGiBJMmTaJy5cqUK1cOJycnRo0ahbe3d7JeJ3fu3EydOpUePXrg6Bjz1axy5cp069aNxo0bs3fvXvbu3cv8+fN55513rJr7o48+4smTJ/Tu3ZtLly6xa9cuq+Oyt7enVKlSCbqX6AYOHEhQUBB2dnZs2bKF6tWrm8bq1atH4cKFGTp0KOfPn2fChAmMGjUq0dcSERERkVdI0abQbzcs7QG3/o05dmwO3PSCjnMgQ4FEX8LWxpZmA37kcqMOnPlsAAXOx/zjYoZHYfDJGDY3XEnDb/+k8tDaDF36L1e9/c1a0IcFhbNp6km1oH8FaWWQCBAeHs769esBqFSpUrxbpjJnzkz//v0TfT1PT09q1qwJwKNHjzh69GicxwcHB7Ns2TIAOnXqRPfu3bGxsSE8PJwFCxbEe73ChQszZMgQateubVpV9DJVqVKFAQMGUK1aNZycLGxMTiazZs2ib9++ZomgKM7Ozvz++++mj5cvX27VvCtXrmT16tVkypSJH374IVlitdbhw4fZs2cPAO+++26MRFCUwYMHU7x4cQB++eUXQkNDX2iMIiIiIpKC0ueD3puh0rvmY34nYGodOLMuyZcpULQqjVft51LfhgRb+Lt1nq2nONKkFg+Ormf2u1Wp1akwq91ib0E//+tD3FQL+leGkkEiwN27dwkKCgJ4YTVWom+Liq0YdZQ1a9bw6NEjALp160bOnDmpW7cuYP1WMbGsVKlSpm1ely5divd4f39/U52g8ePHkzFjxhSN73mrV682Pe/Vq5fFY2xsbEwrnB49esSOHTteRGgiIiIi8qLYO0GLidD2T7B3jjn27DEs6QpbRkB40v4oaGdrR4tPJ5F2yZ9cKehiNp7xfijGD0ewdUgXupTLxJTBNThYxJ5z9ubFpYMePmPVhGMcWHWR8DDryzNIylAySARwcHAwPT9z5swLv6a1W8SKFi1K5cqVgcikEMDx48eTfXtVbGbPnm3qdvUmbT0KCQkBwNbWNt5jhw0bxs2bN6lVq9YLLbodZe/evQC4uLhQsWLFWI+rXbu26fm+fftSPC4REREReQnKdIA+OyBTUfOx/b/CHE/wv5XkyxQqUYOGf+3nUq86PHtup5cNkGfDcQ43qUn4uR0sGVCTPM1zszGWFvTHNl9jyXf/qAX9S6ZkkAiRBYbz5s0LRCZXxo0bl6BiwokRPekUV9eoO3fusHnzZgC6du1q+ny7du1IkyYNgNUdqsScl5eXqSNX1Naq2Bw4cIA//vgDe3v7GNvLEur+/fvUrl2bjBkz4ujoSPbs2WncuDGTJ0+OtyV81M9NoUKFTIXILSlWrJjZOSIiIiLyBspSDPr8DaU7mI9dOwBTPeDyziRfxt7OgRb/9ztOC6ZwNW8as/FMd0MI7/d/7PyiBx/XycOoj6uwJbfBYgv6hzcCWPzNYU7s8FUL+pdEySCR/0RvET5s2DAKFizIwIEDWbJkCVeuXEnWax05coRt27YBkfV8ypQpE+uxCxcuJCwsDPjfaiCAtGnTmopQL1y4kPBw8/9kJX7ffvut6XnHjh1jPS40NJQ+ffpgNBoZMmQIJUqUSPQ1nz59yu7du3nw4AEhISH4+fmxZcsWBgwYQJEiRdi/f7/F84KDg7l37x5AvC0l06dPj4tL5FLe+LYhioiIiMhrztEV2k6H5hPA1iHmWMBdmNcGdo2HZPiDd7Gydam3dj8XutUg5LmF9TZGyP3XEQ40rkE6v39YPqQWxrpZ2OUUSvhzq4SM4Ub2LLnA6l/+JeDRsyTHJQmjZJDIfwYNGkTv3r1NH/v4+DBp0iQ6d+5MgQIFyJYtG507d2bt2rWJyl6Hh4dz6dIlfv/9d5o0aUJERAS2traMHz8eG5vY/ylGbRF76623yJ8/f4yx7t27A3Dr1i22bt2a4JhSuxUrVpiKRlesWJG2bdvGeuy4ceM4deoU+fPn58svv0zU9QwGA9WqVWPs2LFs3LiRY8eOsX//fqZOnUqVKlWAyJb1jRo1wsvLy+z86G3irSkEHpUMevr0aaLiFREREZHXiMEAld+LLC6dLk/MMWME7PgGFnaEwAdJvpSDgxMtR/yJ3dyf8c1l3hAmy+1gnvX6hANj+vJNi6L06VuOtZkjLLagv3n2vxb0x9SC/kVSMkjkPzY2NsyYMYMtW7bQpEkTsy04t2/fZsmSJbRs2ZIqVapYVWw4f/78pho7dnZ2FCpUiA8//JD79+9TuHBhVq5cSatWrWI9/+TJk6akQPRVQVEaN25M5syZAZg3b15CbjdRevbsidFoxGg0vvY1g86cOWMqwJwmTRrmzZuHwWCweOyFCxcYO3YsAJMnTzZtz0uon376iQMHDjB8+HCaNGlC+fLlqV69On379uXgwYMMHz4cgICAAN577z2zpGNwcLDpefSaU7GJ6qIWVRxdRERERFKBnBWg3y4o0sR87OJW+MMDrh9JlkuVrNiY2uv3caFTFcKeyy7YGiHX8gPsaf4WBQNOsXBoLXwruXHUIcxsnrCgcDZNO8nWWacICTIfl+SnZJDEq2TJkjEe9erVe9khpaiGDRuyceNG7t+/z4YNGxg9ejSenp6kS5fOdMyRI0fw8PDg1q3EFWMzGAx06tSJFi1axHlcVC0ge3t7i1uY7Ozs6NSpExDZZSr6yhGJ3c2bN2nWrBlPnjzBYDAwc+bMOOsF9evXj+DgYNq1a0ezZs0SfV13d/dYxwwGA2PHjqV+/foAplVD0Tk5/e+vLlFFr+Py7FnkctvEJq9ERERE5DXlnAE6L4L6X4Hhubf9/tdhZhM4NBWSoV6Po6MzLUfPwThzPDdyOJqNZ7sRRECP/hyf+DF/dC1Do27FWJMu1GIL+vOHbjN/9EFuXniU5LgkbkoGicTCzc2Npk2bMnLkSNasWcPt27eZOXMm6dOnByK3ZsW3XWjz5s14e3vj7e3NgQMHmDlzJuXKlcNoNPLNN9/EqFP0vPDwcBYuXAhA06ZNY21hHrVVLDAw0LTlSWL34MEDGjVqhI+PDwC//vornTt3jvX4mTNnsmPHDtKmTcsvv/yS4vH169fP9HzXrl0xxtKmTWt6bs3Wr4CAyA4N1mwpExEREZE3jI0NeHwK76wBlywxxyJCYeNQWNYTgv2T5XJlqrWgxvo9nG9T3myVkF0E5Fy0m93Na1DN/ipTP6vJ0eKOllvQPwph1YRj7FcL+hQVeysakf+cOnUqxsfXr18nd+7cLymal8fR0ZFevXqRI0cOmjSJXHK5cuVKpk2bFmvNnyJFisToFFatWjW6detGixYt2LJlC1OmTKFBgwa0adPG7Nxt27Zx8+ZNANasWRPrFqbo5s6da9r6JOaePHlCkyZNTD/TY8aMoX///nGeM27cOCCyVfuePXssHnPnzv/2Ny9evBiIrNfj6emZ4BijF6a+ceNGjDEnJycyZszI/fv3uX79epzzPHz40JQMSo3/XkVERETkP/k94P09sPxduLo35tjp1XD7JHScC1lLJvlSadKkpdV3C/FqtoI7X44mh19ojPHs1wPx79aXu10bMG/wj0zfe5VNGy5TN8AeR2K+3/HafI0r3vdp+l4pMuRwSXJsEpNWBokkUOPGjU1vrh8+fMj9+/cTdL69vT2zZ882rfIYMmQIoaGhZsdFFY5OiF27dnHt2rUEn5caBAUF4enpyT///APAZ599xogRI+I9L2qr1bp16+jSpYvFR/TW7VGfi2vVV1ziS/pFJYsuXrxo6jJnydmzZ03P49oCJyIiIiKpQNps8M5fUPNT87H7F2F6ffh3UbJdrrxHO6qt3805z9JEPPfrrX045Ji7jb2tPGiV9SFjB1Xj73w2+FpoQf/oZgCLx6oFfUpQMkgkEXLkyGF6bs2Knedlz56dgQMHAnD58mVmzJgRY/zJkyesXr0agPr167No0aI4H1OnTgXAaDS+kELSr5vQ0FDatWtn2nb1/vvv88MPP7zkqCw7ffq06Xn0n7MoNWvWBCK3gB09ejTWeaJvMatRo0YyRigiIiIiryVbO2jwFXRZAk7uMcfCgmD1+7DmYwgNtnh6Qrm4uNN6/FKCp3zF7czmm5Jy+DzhQeee3F80hiUDa+DUKHucLehX/eylFvTJSMkgkQQKDAw0vWF3c3OLtZZPfAYNGmSq5fL999/HWOWxfPlyAgMDAfjggw/o3LlznI++fftStmxZ4MV0FXudhIeH8/bbb7Nx40YgssbSlClTrD7fx8fH1EEttkft2rVNx0d9LqomUUJFJfaAGPNGad26ten5rFmzLM4RERFhWlnm7u5O3bp1ExWLiIiIiLyBijaBfrshR3nzsWNzYEYDeHA52S5XsW5nKm3cxbkmxXm+ApBDGGSbsYED7WvRp+gz+n9YgQ1ZjRZb0N8694h5ow5y8aha0CcHJYNEiCzGW7VqVdatW0dEROxFyiIiIhgwYICpa1fLli0TtTIIIEOGDLz//vsAXL16NUYSJ+qNvLOzM02bNrVqvvbt2wNw7tw5Dh06lKiY4jN79mwMBgMGg+GVaC1fp04dUzyWki9Go5E+ffqYCmu3a9eOWbNmJfp7lhQHDx6Ms/uc0WhkxIgRbNu2DYCyZctaXNFTpUoVPDw8AJgxYwYHDhwwO2bChAmmrWsDBw7E3t4+OW5BRERERN4U6fNC781Q+T3zMT9vmFoHzqxLtsu5umag9c8reTrpc+5kNF8llPPSY+526IZx4wQWDK7B3WrpLbagDw8OZ/P0k2xRC/okUwFpkf8cPnwYT09PcubMSevWralevTp58+Ylbdq0PHr0CC8vL2bOnIm3tzcA6dKlY8yYMUm65uDBg5k8eTLBwcF8//339OjRg+vXr5u2+DRt2hRnZ2er5mrXrp2pu9ncuXOpWrVqjPHly5fH6EC1d+9ei88BsmXLZiqSnVJmz54d4+N///3X9HzTpk0xkjuFChUybY9KiCFDhphWz5QqVYrhw4fHqO9jSalSpRJ8HWts2rSJ77//niZNmtCwYUNKlCiBu7s7z54948SJE8ycOdOUxHN2dmb69OmxJq1++eUXatSoQVBQEI0aNWL48OHUrVuXoKAgFi9ezLRp04DIAuaDBw9OkfsRERERkdecnSM0nwC5q8HagRAa8L+xZ49hSVeo/hE0GAW2yfPHxaqN3uFx1absGN6botsvxhhzDIWsv//F0b/38MWE3zlYugTTlpym9mNb0hpj/l584dBtfM89pOm7pchR2D1ZYkttDEZVYZIEit5NzNfXl1y5cr3kiJIuODiY/Pnz4+fnZ9XxhQsXZtGiRVSsWNFsrGfPnsyZMweAK1euxOgmZkn//v1N25YWLlzI5cuXTYWNFy1aFGfb8+eVLFmS06dPkzFjRm7evImDg4NpLF++fFy9etWqeWrXrs3OnTvNPj979mxTt7KvvvoqSauDErI6p0ePHmbJI4hcGRSVOLP0tU7IPUdJzH+J0eOI7fxRo0YxevToeOfKkycPCxcujLfOz9q1a+nWrRv+/pZbgRYpUoT169dTqFCheK8pIiIiIqnc3XOwpDvcO2c+lrsadJgFbub1LJPiwPo/CfvmZzI9tNBe3sHA0/fbk7/z//H54uOkP/OUYqGW17KUb5SHqi0LYGv3Zm58Sqn332/mV0skgZycnLhx4wb79u1j9OjRNG3alAIFCuDi4oKtrS1ubm4UK1aMTp06sXDhQk6ePGkxEZQYQ4cONW3j+fbbb01bxBwdHWnevHmC5mrXrh0A9+/fZ/369ckSnySPXr16MWXKFLp3707ZsmXJnj07Dg4OODs7kydPHlq3bs2MGTM4d+6cVQWfPT09OXHiBIMGDaJIkSI4Ozvj7u5OpUqVGDduHF5eXkoEiYiIiIh1MheFPn9D6Y7mY74H4Q8PuLwzWS9Zvfl7lNqwlXO185mNpQkxknnSMs70acKPTdwp3b4gm11DeYb5H169tlxj0bf/8OBmgNmYxE4rgyTB3sSVQSIiIiIiIqme0QhHZsKmYRAe8tygAep+AR6DwSZ515XsXTUZw/e/k+Gxef3WQEcDQQPeJmPzAQyf/y9FfULIHW5rPomtgZrtClGmTi4MNi++RmhK0cogEREREREREUk5BgNUfhfe3QLueZ4bNMKOb2BhBwi4n6yXrdnmI4qt28T5GrnNxpyfGcn44wKuDWzFlI7ZSdcsJ7vTmLegJ9zI3qUXWPWLWtBbQ8kgEREREREREfmfHOUj288XsdDZ+OI2mFoLfP9J1ktmzJybVjO2cOfrvjxMa56qyHPiNtdbt6bOvXUM+qgiW7Ib1II+CZQMEhEREREREZGY0qSHzguhwWgwPJc68L8Os5rCwT8it5Ylo9odB1Fo3TrOVcluNuYaZMT9u1k8GtmJKT3z4O+RkSNxtKDfPPMUz9SC3iIlg0RERERERETEnI0N1PwEeqwF16wxxyJCYdP/wbKeEGy5w21iZcman1ZztnPryx74u5jX/8l79CZ+bVvR0biLDn1KsyFDOE8M5kmpi4dvM2/UQW5eeJSs8b0JlAwSERERERERkdjlqwn99kA+D/Ox06thWh24fSpZL2kwGKjXdRh516zmfMUsZuNpAyJw+3oqNj/25Nf3CnC+vCtn7c1XAT17HMLKCcfYt+IC4WHm28pSK3UTk3iVLFkyxsehoaFcuHABUDcxERERERGRVCM8DHZ+C3smmI/ZpYEWE6Hc28l+WaPRyPY53+D2yyLSBpmnMB672mD7fx9xJmcTVqw6T+0ndjhivqLILbszzfuUJkMOl2SPMaWom5iIiIiIiIiIvDy2dlB/JLy9FJzcY46FBcHqD2DNAAgNStbLGgwGGvT8klx/LYLKLVIAACrRSURBVOdimYxm4+meRuD65SQyz/iIsX0LcbCIA7624WbH+d8KZNHYwxzf7osxInWvi9HKIEmwlMpMioiIiIiIyGvi4dXIekE3j5mPZSsNHedChgLJftmIiAi2/TmS9FNW4BpsPv7IzRb74Z+w29mDQxt9qBFkh62FVUJZC7vTpHdJXNM7JnuMyUkrg0RERERERETk1ZA+L/TeBJX7mI/5ecPU2nBmbbJf1sbGhkZ9vyHbqsVcKpHebNzdPxyXYRMovnIon7xXiO25DNy10IL+9oVHrJqXvHWOXidKBomIiIiIiIhIwtk5QvMfod0MsH+uDs8zf1jSDTZ/AeGhyX7p3PnL0nT5Hq5+5EmghcU9+XaeJ+Ld1owo50tIvSwccYxZXPqJwYh/0dendlByUzJIRERERERERBKvdHvouwMyFzMfOzAZZrcA/5vJfllbG1uafPQDGZfP43JRN7PxDI/CcB3yPR67RtPlnXxsyvS/FvSXCzjwfsMiyR7T60LJIBERERERERFJmsxFoc/fULqj+ZjvQfjDAy7tSJFL5ytcicYr9+HTrwnB9hbGt57GeUB7vqh+h6uV3NieNowv362IrY15LaHUQskgEREREREREUk6BxdoOw1a/AS2DjHHAu/BvDawcxxEmNfwSSo7WzuaDvoJt6Uz8Cnoajae8UEoaQd/Q8uTP/L9gDLkzuCc7DG8TpQMEhEREREREZHkYTBApd7w7hZwz/vcoBF2fgsL2kPA/RS5fMHib9Hgr31c7l3PbJWQDZBv4wn83m7MmV2rU+T6rwslg0REREREREQkeeUoD/12QdFm5mOXtsNUD/D9J0UubW/nQPOhv+G6YCrX8pmvAMp8L4SI9z9ny5gPUuT6rwMlg0REREREREQk+aVJD50XQsOvwWAbc8z/BsxqAgd/B6MxRS5fuEwt6q7dx6XutQh57vI2RnAqVDBFrvs6UDJIRERERERERFKGwQA1BkKPteCaLeZYRBhsGgbLekCwf4pc3sHeiRZfTMVx3mR8c6cxff5i1ZzU6jIkRa75OlAySERERERERERSVr4a0G835PMwHzv9F0yrA34nU+zyxSrUp866fVzsXI176W15a/yMFLvW60DJIBERERERERFJeWmzQvfV4DHYfOzBJfizPngtSLHLOzimwXPULCps20vGLM8Xt05dlAwSERERERERkRfD1g7qj4S3l4KTe8yxsGD460P46yMIDUqxEFxc3OM95k2nZJCIiIiIiIiIvFhFGsP7eyBHBfMxr3nwZ0O4f+nFx5VKKBkkIiIiIiIiIi+eex7ovQmq9DUfu+0dWUfo9JoXHlZqoGSQiIiIiIiIiLwcdo7QbDy0nwkOrjHHnvnD0u6w+QsID3058b2hlAwSERERERERkZerVDvoswMyFzcfOzAZZjeHxzdefFxvKCWDREREREREROTly1wE+myHMp3Nx3wPwVQPuPT3i4/rDWT3sgOQV1/JkiVjfBwaquV5IiIiIiIikgIcXKDNH5CnGmwcCuEh/xsLvA/z2kKdYVDrM7CxfXlxvua0MkhEREREREREXh0GA1TqBe9uBfe8zw0aYed3sKA9BNx/KeG9CQxGo9H4soOQ18v169fJnTs3AL6+vuTKleslRyQiIiIiIiJvpKCHsPpDOLfBfMwtJ3SYDbmrvPCwXpSUev+tlUEiIiIiIiIi8mpKkx46L4SGX4PhuW1h/jdgVlM4+DtonUuCKBkkIiIiIiIiIq8ugwFqDISe68A1W8yxiDDYNAyWvgPB/i8nvteQkkEiIiIiIiIi8urL+xa8vwfy1zIfO7MGptUBv5MvPKzXkZJBIiIiIiIiIvJ6cM0C3VdHdhN73oNL8Gd98Jr/wsN63SgZJCIiIiIiIiKvDxtbqDcCui6PrCkUXVgw/NU/8hEa9HLiew0oGSQiIiIiIiIir5/CDaHfHshZ0XzMaz782QDuX3rxcb0GlAwSERERERERkdeTe27otQmq9DMfu30yso7Q6TUvPKxXnZJBIiIiIiIiIvL6snOAZj9A+5ng4Bpz7Jk/LO0Om4ZDeOjLie8VpGSQiIiIiIiIiLz+SrWDvjshc3HzsYO/wezm8PjGCw/rVaRkkIiIiIiIiIi8GTIVhj7boUxn8zHfQzDVAy5uf/FxvWKUDBIRERERERGRN4eDC7T5Azx/AVvHmGOB92F+O9jxHUSEv5z4XgFKBomIiIiIiIjIm8VggIo94b2tkD7fc4NG2PU9rP/0JQT2alAySERERERERETeTNnLQt9dUKxFzM/bOkKld19OTK8AJYNERERERERE5M2Vxh06zYdG34DBNvJzzX+E7GVealgvk93LDkBEREREREREJEUZDPDWAMhZEc5thPLdX3ZEL5WSQSIiIiIiIiKSOuR9K/KRymmbmIiIiIiIiIhIKqJkkIiIiIiIiIhIKqJkkIiIiIiIiIhIKqKaQRKvkiVLxvg4NDT0JUUiIiIiIiIiIkmllUEiIiIiIiIiIqmIVgZJvE6dOhXj4+vXr5M7d+6XFI2IiIiIiIiIJIVWBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJKBomIiIiIiIiIpCJ2LzsAef2EhYWZnt+6deslRiIiIiIiIiLy5or+njv6e/GkUjJIEuzu3bum51WqVHmJkYiIiIiIiIikDnfv3iVfvnzJMpe2iYmIiIiIiIiIpCIGo9FofNlByOslODgYb29vADJnzoydnRaYvSj16tUD4O+//37Jkbx5UvPX9nW/91c9/lchvpcRw4u6Zkpd59atW6bVr4cPHyZ79uzJOr+kXq/C/wlvstT69X0T7vtVvodXITa9lifcm/JaHhYWZtqdU7p0aZycnJJlXr2LlwRzcnKicuXKLzuMVMne3h6AXLlyveRI3jyp+Wv7ut/7qx7/qxDfy4jhRV3zRVwne/bsr+zPl7x+XoX/E95kqfXr+ybc96t8D69CbHotT5rX/bU8ubaGRadtYiIiIiIiIiIiqYiSQSIiIiIiIiIiqYiSQSIiIiIiIiIiqYgKSIuIiIg85/r16+TOnRsAX1/f17rOgIiISGqk1/K4aWWQiIiIiIiIiEgqomSQiIiIiIiIiEgqomSQiIiIiIiIiEgqoppBIiIiIiIiIiKpiFYGiYiIiIiIiIikIkoGiYiIiIiIiIikIkoGiYiIiIiIiIikIkoGiYiIiIiIiIikIkoGiYiIiIiIiIikIkoGiYiIiIiIiIikIkoGiYiIiCSjGzdu8PPPP9OoUSPy5MmDg4MD2bJlo127dhw6dOhlhyciIiLxCA4O5tNPP6VWrVrkyJEDJycnsmXLRo0aNZg1axahoaEvO8QkMxiNRuPLDkJERETkTTFs2DDGjRtHwYIFqVOnDpkzZ+bChQusXr0ao9HIwoUL6dSp08sOU0RERGJx7949cufOTZUqVShSpAiZM2fm4cOHbNy4katXr9KoUSM2btyIjc3ru75GySARERGRZLRy5UoyZsxI7dq1Y3x+z5491K9fH1dXV27duoWjo+NLilBERETiEhERQVhYGA4ODjE+HxYWRsOGDdm5cyfr1q2jefPmLynCpHt901giIiIir6C2bduaJYIAPDw8qFu3Lg8fPsTb2/slRCYiIiLWsLGxMUsEAdjZ2dGmTRsALl68+KLDSlZKBomIiMgb486dO6xbt46RI0fStGlTMmXKhMFgwGAw0LNnzwTNdfXqVQYPHkyxYsVwcXEhQ4YMVK5cmfHjxxMYGJio+Ozt7YHIXyZFRETE3Kv8Wh4REcGmTZsAKFWqVILPf5Vom5iIiIi8MQwGQ6xjPXr0YPbs2VbNs3btWrp164a/v7/F8SJFirB+/XoKFSpkdWzXrl2jSJEiZMiQAV9fX2xtba0+V0REJLV4lV7LQ0JC+PbbbzEajdy/f5/t27dz9uxZevXqxcyZM62K41WlP0uJiIjIGylPnjwUK1aMLVu2JOg8Ly8vOnXqRFBQEK6urnz++efUrVuXoKAgFi9ezPTp0zl//jzNmzfnyJEjpE2bNt45Q0ND6d69O8+ePWPcuHFKBImIiFjhZb+Wh4SEMHr0aNPHBoOBIUOG8N133yXpvl4FSgaJiIjIG2PkyJFUrlyZypUrkzVrVnx8fMifP3+C5hg4cCBBQUHY2dmxZcsWqlevbhqrV68ehQsXZujQoZw/f54JEyYwatSoOOeLiIigZ8+e7N69mz59+tC9e/fE3JqIiEiq8Cq9lru6umI0GomIiODmzZusXbuW4cOHc+DAATZs2ICbm1tSbvWl0jYxEREReWNF/wXSmqXlhw8fpmrVqgD069ePP/74w+yYiIgISpUqxZkzZ3B3d+fOnTumWkCWju3duzdz5syhW7duzJkz57VuQysiIvKivezX8uctW7aMjh07MnToUMaNG5ewm3mF6LcRERERkf+sXr3a9LxXr14Wj7GxseGdd94B4NGjR+zYscPicREREfTq1Ys5c+bQpUsXZs+erUSQiIhICkvO13JLGjVqBMDOnTsTHeOrQL+RiIiIiPxn7969ALi4uFCxYsVYj4veOn7fvn1m41GJoLlz59KpUyfmzZunOkEiIiIvQHK9lsfm5s2bAFavJHpVKRkkIiIi8p8zZ84AUKhQoTjbvxcrVszsnChRW8Pmzp1Lhw4dmD9/vhJBIiIiL0hyvJafPn3aYuv5wMBAPv30UwCaNWuWHOG+NCogLSIiIgIEBwdz7949AHLlyhXnsenTp8fFxYWAgAB8fX1jjH399dfMmTMHV1dXihQpwjfffGN2fuvWrSlXrlyyxS4iIiLJ91q+dOlSJk6cSM2aNcmXLx9ubm7cuHGDjRs3cv/+fTw8PBg0aFCK3ceLoGSQiIiICPDkyRPTc1dX13iPj/oF8unTpzE+7+PjA8DTp08ZO3asxXPz5cunZJCIiEgyS67X8hYtWnDz5k3279/PgQMHePr0KenSpaNMmTJ07tyZ3r17x7nq6HXwekcvIiIikkyCg4NNzx0cHOI93tHREYCgoKAYn589e3a8nU5EREQk+SXXa3mlSpWoVKlS8gb3ilHNIBERERHAycnJ9DwkJCTe4589ewZAmjRpUiwmERERsZ5ey62nZJCIiIgIkDZtWtPz55eLWxIQEABYtwxdREREUp5ey62nZJCIiIgIkX9NzJgxIwDXr1+P89iHDx+afoHMnTt3iscmIiIi8dNrufWUDBIRERH5T4kSJQC4ePEiYWFhsR539uxZ0/PixYuneFwiIiJiHb2WW0fJIBEREZH/1KxZE4hcNn706NFYj9u1a5fpeY0aNVI8LhEREbGOXsuto2SQiIiIyH9at25tej5r1iyLx0RERDB37lwA3N3dqVu37osITURERKyg13LrKBkkIiIi8p8qVarg4eEBwIwZMzhw4IDZMRMmTODMmTMADBw4EHt7+xcao4iIiMROr+XWMRiNRuPLDkJEREQkOezdu5eLFy+aPr537x6fffYZELkE/L333otxfM+ePc3m8PLyokaNGgQFBeHq6srw4cOpW7cuQUFBLF68mGnTpgFQpEgRjhw5EqNziYiIiCSNXstfDCWDRERE5I3Rs2dP5syZY/Xxsf0atHbtWrp164a/v7/F8SJFirB+/XoKFSqUqDhFRETEMr2WvxjaJiYiIiLyHE9PT06cOMGgQYMoUqQIzs7OuLu7U6lSJcaNG4eXl1eq/eVRRETkdaDX8rhpZZCIiIiIiIiISCqilUEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIiIiIqmIkkEiIiIiIvL/7d17VJTH+Qfw73JZUBQVBARFoCoSIwJVLtGkoIkaJVStBm1RgWgTqbU5iZ4YqY2p1ZoGoifVWJUG0CieWLwEogWVi5GikRVQNkSjAWJQE/ESL3gBcX5/8Nu3C7vvu8tFSMP3c86es+7MO+8z786avI/zzhARURfCZBARERERyQoLC4NKpUJYWFhnh/Kjtn//fkycOBF9+/aFpaUlVCoVevfu3dlhERERGWXV2QEQEREREf0v27hxIxYuXNjZYRAREZmNM4OIiIiIiFrp7t27iI+PBwD4+PggPT0dJSUlKCsrw7FjxzokhpiYGKhUKnh6enbI+YiI6H8fZwYREREREbWSRqPBzZs3AQCJiYkIDw/v5IiIiIhM48wgIiIiIqJWunjxovTe29u7EyMhIiIyH5NBRERERESt9ODBA+m9tbV1J0ZCRERkPiaDiIiI2kir1WLVqlWYOHEiBgwYABsbG/To0QNDhgxBdHQ0jh8/bvS4u3fvomfPnlCpVIiKijJ5nmPHjkGlUkGlUmHjxo1G63z33Xf44x//iFGjRsHBwQE2NjZwd3dHZGQkDh8+LNt2VVWV1HZqaioAYM+ePZg8eTLc3NxgZWVlsJvU8ePHsXz5coSFhaFfv35Qq9Wwt7fHsGHDEBcXh/LycpN9AoALFy4gLi4OXl5esLW1hZubG6ZOnYq8vDwAwNtvvy3FpuTmzZtYs2YNxowZAycnJ6jVari6uiIiIgLp6ekQQpgVjzHGrs+hQ4cQERGBfv36wcbGBl5eXoiLi0N1dbVsO+au7ZKamiqdr6qqyqDc09MTKpUKMTExAIDi4mJERUXB3d0d3bp1w+DBg/H666/j6tWrTY4rLCzEiy++iIEDB8LW1haDBg3C0qVLcfv2bbOvxdmzZ/Hyyy9L35erqysiIyNlx3lzHTlGzVVTU4Ply5cjICAAvXv3hq2tLTw9PTFnzhwUFBQYPUa3y1psbKz0mZeXlxSjSqVCfn5+i+K4f/8+/v73vyMsLAxOTk6wtraGg4MDhg4dikmTJmHt2rVNxoPut7F161YAwDfffNPk/Eq/m/v372PDhg149tlnpd+vs7MznnvuOXz44Yd4+PChbJzNx19RURF+/etfw93dHba2tnB3d0dsbCzOnDnTrv0lIqJ2JIiIiKjV8vLyBACTrzfffNPo8bNnzxYAhJ2dnbhz547iuRYuXCgACCsrK1FTU2NQvn37dmFnZ6cYx7x580R9fb3BsZWVlVKd5ORkMWfOHINjQ0NDpfopKSkm+2xpaSk++OADxT7l5OSIHj16GD1epVKJ1atXixUrVkifyTl8+LBwdHRUjGfy5Mni9u3bivHI0b8+KSkp4s0335Q9j5OTkygvLzfaTnR0tAAgPDw8FM+nf30rKysNyj08PAQAER0dLbZt2ybUarXRWLy9vcXly5eFEEIkJCQIlUpltN7Pf/5z2WsTGhoqff8HDhyQHWMWFhZi3bp1iv3qyDFqruzsbGFvb68Y08KFC0VDQ4PR66L0ysvLMzuOS5cuiWHDhplsc/HixdIx+r8NpVdzpaWl0hiSewUGBorvvvvOaKz64+/DDz8UVlZWRtuwsbERu3btarf+EhFR+2EyiIiIqA0OHTok7OzsRGRkpNi0aZPIz88XxcXFIisrS7z33ntNbriSk5MNjv/3v/8tle/YsUP2PPX19cLZ2VkAEOHh4QblH3/8sXSj/7Of/UysXbtWZGVliZMnT4rdu3eLyZMnS+d57bXXDI7Xv9EeMWKEACCeeeYZkZaWJjQajTh8+LD45z//KdVPSkoSffr0ETExMSI5OVkcPXpUFBcXi08//VSsXLlS9O3bV0ro5OTkGO3T119/LSUGrKysxKJFi0ROTo4oKioSKSkp0o1icHCwYjKooKBAWFtbCwDCxcVFrFq1SmRmZoqTJ0+KzMxMKeEGQPzqV7+SvcZK9K/P6NGjpcSD/vWZO3euVCckJMRoO+2dDPL39xdqtVoMGzZMJCcni6KiIpGbm9ukz1FRUWL37t1SXDt27BAajUZkZWU1GRdLly41Gosu6TFkyBDRu3dv0atXL/HXv/5VFBYWisLCQrF69eomyZS9e/cabaejx6g5SkpKpESatbW1eO2110ReXp44ceKE2Lx5s/Dy8pLO+cYbbzQ5tqKiQpSVlYlVq1ZJdbKzs0VZWZn0MpXg1Td9+nSpndmzZ4s9e/aI48ePi6KiIpGRkSHeeust4efn1yQ58v3334uysjIxZcoUAUC4ubk1Ob/upe/cuXOiV69eAoCwt7cXy5YtE3v37hUajUZkZ2eLhQsXSsmd4OBgUVdXZxCrbvz5+fkJa2tr4ebmJtavXy8+//xzceTIEbF06VJhY2MjXdeioqJ26S8REbUfJoOIiIjaoKamRty4cUO2/MGDB2L8+PFSAuDhw4dNyk0leXT0k0ZpaWkGMehu7l566SWjsyqEECI+Pl6awXHmzJkmZfo32gDE3LlzxaNHj2Tjqa6uFrW1tbLlP/zwg3TD/vTTTxutM3XqVMUEQm1trQgKClKc4VBXVyc8PT0FAPH888/LxrRlyxapjYMHD8rGLaf59fntb39r9PrMnz9fqlNcXGxQ3t7JIF1yyli/Z8yYIYDGGVoODg5i+vTpBuPv4cOHIiQkRAAQjo6ORseO/gyYXr16GZ31pNVqpYRQ//79DRIInTFGzREYGChdo+zsbIPy69evS0lJCwsLodVqDeqY+q7Mce/ePSmhaSr5ce3aNYPPzB1XQggpmRkQEGB0hqEQjX/fWFhYCABiy5YtBuX648/Dw0OafaYvNzdXSioFBgY2KWtrf4mIqO2YDCIiInrMSktLpRsnjUZjUL5o0SLpX9CvXr1qtA3dTI8ePXoY3PivXLlSugm/f/++bBz19fWif//+AoCIj49vUqZ/o927d29x69atVvS0qX379kltNu/XxYsXhaWlpQAgZsyYIduG/rUzlgzatm2bACBsbW3FlStXFOPRJZZ+85vftLgv+tfH1dVV9jqfOXNGqvf+++8blLd3MkilUsk+kpabmyu10b17d9mb6uTkZKneqVOnDMr1k0GJiYmyMf/tb3+T6v3rX/9qUvZjHKOff/651N6CBQtk6xUUFEj1fve73xmUt0cy6OLFi1Ibn3zySYuPN3dcffbZZ9J5Tp8+rVg3MjJSSjY2p58MSk9Pl20jLi5Oqqc/O6it/SUiorbjAtJERETt6MGDB7hw4QLKy8uh1Wqh1WqbLFx86tQpg2N0i0fX19dj165dBuX37t3Dvn37AABTp05F9+7dm5RnZGQAAF544QXY2NjIxmZlZYWnnnoKQONi1HIiIiLQs2dP2XJjamtrUVVVhS+++ELqt/7OSs37nZeXh4aGBgDAnDlzZNv18/ODn5+fbLmu76GhoXByclKM8Re/+AUA5b6bY8aMGbLXeejQoejRowcAoKKiok3nMceIESPwxBNPGC3Tv27jx4+Hg4ODyXpKMatUKkRHR8uWx8bGSosVN18I+scwRpvTj3HevHmy9caMGSNdY6UFrtvC0dERarUaAPDRRx8pLt7cFrrvYejQofD19VWsq/u9FBUVycbTp08fTJkyRbaNl156SXqvf+06qr9ERCSPySAiIqI2qq2txZo1a+Dn5wc7Ozt4eHjgySefhK+vL3x9fREQECDVbb67EwAEBwdj0KBBAIAdO3YYlGdkZODOnTsAYLDrWENDA0pLSwEAmzdvNrqTkP4rPT0dQOOOTnJGjBhhVr+vXr2K+Ph4DB06FD179oSXlxeGDx8u9Ts8PFy231qtVno/cuRIxfOMGjVKtkyj0QAAsrOzTfY9MTERgHLfzeHj46NY3qdPHwBo0Q5dreXt7S1b1rt37xbXU4rZy8sLffv2lS13cnKSdkkrKyuTPu/MMapENwbVajX8/f0V6wYHBwMAzp07h7q6ujafuzkbGxvMnDkTAJCeno7BgwfjjTfewIEDB/DDDz+023l0v5ezZ8+a/B5+//vfA2hMUl+/ft1oewEBAbCyspI9n7+/v5T00R8THdVfIiKSx2QQERFRG1RVVcHX1xfx8fE4ffq0NNtFzr1794x+rkvyFBYWGmylrEsQ6bZ91nf9+vVW/av63bt3Zct0yQwlJ0+ehI+PD9asWYOvvvrK5Lbtzft948YN6b2pGT1K5VeuXDEZq6lYWqr5zKzmLCwa//fK1FhoD0qx6OJoST2lmJ2dnU3G4+LiAgBNkgedNUZN0cXo4OCgmNAAgH79+gEAhBBNxm572rBhAyIiIgA0bhGfkJCA8PBwODo6IjAwEAkJCbh582abztGa3wsg/12YGhNWVlbSjLTmCaWO6C8REclT/i8fERERKZozZw4qKyuhUqkQGxuLWbNm4YknnoCTkxPUajVUKhUePXoES0tLAJBNmkRFRWHlypUQQmDnzp1YtmwZgMYbqOzsbADAzJkzDW5a9W/e58+fj1dffdWsuHX/Wm+MLlY5dXV1iIyMxLVr12BtbY1FixZhypQp8Pb2Rp8+faTHgCoqKqQZT6aSRa2l6/+kSZPw7rvvPpZzUCPdI2At1RljtCVa26/2Zm9vj4yMDJw4cQK7du1Cfn4+SktL0dDQAI1GA41Gg8TEROzbt096lK6ldN+Fn58ftm/fbvZx/fv3N/p5W65dR/SXiIjkMRlERETUSmfOnEFBQQEAID4+HqtWrTJaT+4RC33e3t4YNWoUNBoN0tLSpGRQenq69FhK80fEADRZB0YIgeHDh7e4Hy2Vm5srrS2zceNGzJ8/32g9pX7rz+yoqamRvdnUlctxdHTEpUuXUFdX1yF9byvdLJxHjx4p1qutre2IcFrk+++/N7uO/rjsjDFqDl1c165dw8OHDxVnB+keWVOpVO0yK0lJUFAQgoKCADQ+tpefn4/U1FTs2bMHV65cwfTp0/H111+jW7duLW7b0dERAHDnzp12+R5MjYmHDx82mYFlzOPsLxERyeNjYkRERK30xRdfSO91618Yo1unwxRdsker1eL06dMA/vuI2KBBg6R1S/Sp1Wo8+eSTAID//Oc/5gXeRu3Rb13MQOMjZ0qU2tGtx6TRaB7LWi7tTbfosal1Ub766qsOiKZlKisrce3aNdnympoa6RFH/URDZ4xRc+hirKurk9Y0knPixAkAwJAhQxRnLLW3nj17IiIiArt378Yf/vAHAMDly5elJLSOuTN0dL+XioqKNq+dBQClpaWKjwCeOnVK+l2ak3wyt79ERNR2TAYRERG1kv5NkNJMjk2bNpnV3qxZs6THX3bs2IHq6mocPXoUgPFZQTq//OUvATTOVNI9UvY4mdPvR48eISkpSbaNsLAwaZbMRx99JFvv1KlTRndg09H1/ebNm0hJSVGM+8fAy8sLQOMMiLNnzxqtU1dXh927d3dkWGYRQmDbtm2y5ampqdLjgM3XturoMWoO/RiTk5Nl6x07dgzl5eUGx3S0Z599VnrffEF2W1tbAI27GSrRfQ9CCLz//vttjun69evIzMyULde/ri29dkr9JSKitmMyiIiIqJWGDBkivU9NTTVa5x//+Ac++eQTs9rr168fxo0bBwDYuXMn0tLSpJtrpWTQq6++Km1nHhsb22TmjjH79++XZh61hjn9XrZsGYqLi2XbGDBggLTbWHp6Ovbt22dQ5969e3j55ZcVY4mOjoa7uzsAYMmSJfjss88U6xcUFODIkSOKdR6n0NBQ6f17771ntM7rr7+OixcvdlRILfKXv/zFaBLryy+/xOrVqwEArq6uBtuNd/QYNUdQUJC0U11SUhJycnIM6ty8eROvvPIKgMZH/OLi4h5LLBUVFSbH5cGDB6X3uqSijqurK4DGBaKVdoSbMGGC9EhWQkICdu3apXjOsrIyxWQP0DhejT0uduTIEWzZsgVA446BgYGBUllb+0tERG3HNYOIiIhaKSAgAMOHD4dWq8XmzZtx48YNzJkzB66urqiursb27duRnp6OMWPGmP14TFRUFA4dOoRvv/0Wa9asAdC4tbrS1uAuLi7YunUrZsyYgcuXL2PUqFGIiYnBpEmTMGDAANTX16O6uhonTpxAeno6KioqkJmZ2ertuSdOnAhnZ2dcuXIFy5cvR1VVFaZNm4a+ffvi/Pnz0o21qX6vXbsWOTk5uHv3Ll588UXExcVh2rRpsLe3h1arxbvvvovy8nIEBgaiqKjIaBs2NjbYtWsXwsLCcOfOHYwbNw6zZs3C1KlT4eXlhUePHuHy5cs4efIk9u7di7KyMqxfv75JUqYjBQQE4KmnnsKxY8eQlJSEuro6REdHo1evXjh37hy2bNmC3NxcjB49GoWFhZ0So5zBgwejpqYGISEhWLp0KcLCwgAA+fn5eOedd6Sdn9avX2/wKFVHj1FzJSUlITg4GHV1dZg8eTIWLVqEiIgI2NnZoaSkBO+88460PtaSJUse23pHFy5cwNixYzFs2DBMmzYNo0aNktbR+vbbb/Hxxx9LiRt/f3+DR0ZHjx4NoHFG3oIFC7Bo0SL07dtXKh88eLD0Pi0tDUFBQbh+/TpmzpyJ7du3Y+bMmRgyZAgsLS1x5coVlJSUIDMzE8ePH8fixYulXb+a8/PzQ3l5OUaOHIlly5YhKCgIDx48wIEDB7Bu3TppLaYPPvigXftLRETtQBAREVGrlZSUiD59+ggARl++vr7i0qVL0p9XrFih2N6tW7dEt27dmrSxbt06s2LJyMgQDg4OsrHoXhYWFiI3N7fJsZWVlVJ5SkqKyXNlZWUJW1tb2XOEhYUJrVZrss2DBw8KOzs72XZWrFgh/vSnPwkAwtbWVjaeY8eOCXd3d5N9ByC2bt1q1vVs7fXx8PAQAER0dLTR8i+//FI4OzvLxrdkyRKRkpIi/bmysrLF59AxZ9yZ6ltoaKgAIEJDQ8Wnn34qunfvLjuuEhMTFePpyDFqruzsbGFvb68Yz8KFC0VDQ4PR4019V+bIy8sza+z6+PiIiooKg+MbGhpESEiI7HHNnT17VgwfPtysc/75z382OF5//CUlJQkrKyujx6rVarFz58527y8REbUdHxMjIiJqA39/f5SWlmLBggXw8PCAtbU1HBwcEBQUhMTERJw4cUJ6hMMcugVUdSwtLTFr1iyzjo2IiEBlZSUSExMxbtw4uLi4wNraGt26dYOXlxdeeOEFrF27FlVVVRg7dmyL+6pv4sSJ0Gg0mD17Ntzc3GBtbQ0nJyeEhoZiy5YtyMnJgZ2dncl2xo8fD61Wi1deeQUeHh5Qq9VwcXFBeHg4srKy8Pbbb+PWrVsAgF69esm2ExISgnPnzmHTpk0IDw+Hm5sb1Go1bG1t4e7ujgkTJmD16tU4c+YM5s6d26a+t5WPjw+Ki4sRFxcn9dnJyQnPP/889u/fj4SEhE6NT0l4eDg0Gg1iY2Ol2J2dnTF9+nQUFBRg8eLFisd35Bg114QJE3D+/HnEx8fD398f9vb2sLGxwcCBAxEVFYWjR49iw4YN0hpXj8MzzzyD/Px8LFu2DGPHjsXgwYPRs2dPWFtbw8XFBRMmTMCmTZtQWlpq9JEpCwsLHDx4EMuXL4efnx969OihuKi0t7c3SktLkZaWhunTp2PgwIHo1q0b1Go1XF1dERYWhuXLl+PkyZN46623FGOfP38+jh49isjISOl3179/f8ydOxclJSVG//5qa3+JiKjtVEL8/2IERERERD9Czz33HHJycvD0009LC2oTUefx9PTEN998g+joaNl1w4iI6MeNM4OIiIjoR+vSpUvSotAhISGdHA0RERHRTwOTQURERNRpzp8/L1t27949xMTEoL6+HgA6/fEuIiIiop8K7iZGREREnWb+/Pmora1FZGQkRo4cCQcHB9y+fRsajQYbN26UkkXz5s2Dr69vJ0dLRERE9NPAZBARERF1Ko1GA41GI1s+bdo0rF+/vgMjIiIiIvppYzKIiIiIOs3atWuxd+9e5Obmorq6GjU1NRBCwNnZGSEhIYiOjsbkyZM7O0wiIiKinxTuJkZERERERERE1IVwAWkiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi6EySAiIiIiIiIioi7k/wD8i/mVLIkWYQAAAABJRU5ErkJggg==" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "draw_order_multiple(\n", + " [\n", + " out_SLOWRK_time_sde,\n", + " out_SPaRK_time_sde,\n", + " out_GenShARK_time_sde,\n", + " out_ShARK_time_sde,\n", + " out_SRA1_time_sde,\n", + " ],\n", + " [\"SlowRK\", \"SPaRK\", \"GeneralShARK\", \"ShARK\", \"SRA1\"],\n", + " title=\"Order of convergence on an additive noise SDE\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "12b1cd557c8975e5", + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-04-02T15:10:45.893943Z", + "start_time": "2024-04-02T15:10:45.892518Z" + } + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/usage/how-to-choose-a-solver.md b/docs/usage/how-to-choose-a-solver.md index 713b4cc8..669ef2e9 100644 --- a/docs/usage/how-to-choose-a-solver.md +++ b/docs/usage/how-to-choose-a-solver.md @@ -73,9 +73,10 @@ For Itô SDEs: For Stratonovich SDEs: -- If cheap low-accuracy solves are desired then [`diffrax.EulerHeun`][] is a good choice. -- Otherwise, and if the noise is commutative, then [`diffrax.StratonovichMilstein`][] is a typical choice. -- Otherwise, and if the noise is noncommutative, then [`diffrax.Heun`][] is a typical choice. +- If cheap low-accuracy solves are desired then [`diffrax.EulerHeun`][] is a typical choice. +- Otherwise, and if the noise is commutative, then [`diffrax.SlowRK`][] has the best order of convergence, but is expensive per step. [`diffrax.StratonovichMilstein`][] is a good cheap alternative. +- If the noise is noncommutative, [`diffrax.GeneralShARK`][] is the most efficient choice, while [`diffrax.Heun`][] is a good cheap alternative. +- If the noise is noncommutative and an embedded method for adaptive step size control is desired, then [`diffrax.SPaRK`][] is the recommended choice. ### Additive noise @@ -85,10 +86,10 @@ $\mathrm{d}y(t) = μ(t, y(t))\mathrm{d}t + σ(t, y(t))\mathrm{d}w(t)$ Then the diffusion matrix $σ$ is said to be additive if $σ(t, y) = σ(t)$. That is to say if the diffusion is independent of $y$. -In this case then the Itô solution and the Stratonovich solution coincide, and mathematically speaking the choice of Itô vs Stratonovich is unimportant. +In this case the Itô solution and the Stratonovich solution coincide, and mathematically speaking the choice of Itô vs Stratonovich is unimportant. Special solvers for additive-noise SDEs tend to do particularly well as compared to the general Itô or Stratonovich solvers discussed above. -- The cheapest (but least accurate) solver is [`diffrax.Euler`][]. -- Otherwise [`diffrax.Heun`][] is a good choice. It gets first-order strong convergence and second-order weak convergence. +- The cheapest (but least accurate) solver is [`diffrax.SEA`][]. +- Otherwise [`diffrax.ShARK`][] or [`diffrax.SRA1`][] are good choices. --- diff --git a/mkdocs.yml b/mkdocs.yml index 968104d6..03b85737 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -138,3 +138,4 @@ nav: - Developer Documentation: - 'devdocs/predictor_dirk.md' - 'devdocs/adjoint_commutative_noise.md' + - 'devdocs/srk_example.ipynb' diff --git a/test/helpers.py b/test/helpers.py index 87b94b50..ed20ac46 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,4 +1,5 @@ -from typing import Callable +import dataclasses +from typing import Callable, Optional, Union import diffrax import equinox as eqx @@ -7,7 +8,16 @@ import jax.random as jr import jax.tree_util as jtu import optimistix as optx -from jaxtyping import Array, PRNGKeyArray, PyTree, Shaped +from diffrax import ( + AbstractBrownianPath, + AbstractTerm, + ControlTerm, + MultiTerm, + ODETerm, + VirtualBrownianTree, +) +from jax import Array +from jaxtyping import PRNGKeyArray, PyTree, Shaped all_ode_solvers = ( @@ -81,7 +91,7 @@ def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False): return eqx.tree_equal(x, y, typematch=True, rtol=rtol, atol=atol) -def _path_l2_dist( +def path_l2_dist( ys1: PyTree[Shaped[Array, "repeats times ?*channels"], " T"], ys2: PyTree[Shaped[Array, "repeats times ?*channels"], " T"], ): @@ -102,32 +112,57 @@ def sum_square_diff(y1, y2): return dist +def _get_minimal_la(solver): + while isinstance(solver, diffrax.HalfSolver): + solver = solver.solver + return getattr(solver, "minimal_levy_area", diffrax.BrownianIncrement) + + +def _abstract_la_to_la(abstract_la): + if issubclass(abstract_la, diffrax.AbstractSpaceTimeTimeLevyArea): + return diffrax.SpaceTimeTimeLevyArea + elif issubclass(abstract_la, diffrax.AbstractSpaceTimeLevyArea): + return diffrax.SpaceTimeLevyArea + elif issubclass(abstract_la, diffrax.AbstractBrownianIncrement): + return diffrax.BrownianIncrement + else: + raise ValueError(f"Unknown levy area {abstract_la}") + + @eqx.filter_jit -@eqx.filter_vmap(in_axes=(0, None, None, None, None, None, None, None, None, None)) +@eqx.filter_vmap( + in_axes=(0, None, None, None, None, None, None, None, None, None, None, None, None) +) def _batch_sde_solve( key: PRNGKeyArray, get_terms: Callable[[diffrax.AbstractBrownianPath], diffrax.AbstractTerm], - levy_area: type[diffrax.AbstractBrownianReturn], - solver: diffrax.AbstractSolver, w_shape: tuple[int, ...], t0: float, t1: float, - dt0: float, y0: PyTree[Array], args: PyTree, + solver: diffrax.AbstractSolver, + levy_area: Optional[type[diffrax.AbstractBrownianIncrement]], + dt0: Optional[float], + controller: Optional[diffrax.AbstractStepSizeController], + bm_tol: float, + saveat: diffrax.SaveAt, ): - # TODO: add a check whether the solver needs levy area + abstract_levy_area = _get_minimal_la(solver) if levy_area is None else levy_area + concrete_la = _abstract_la_to_la(abstract_levy_area) dtype = jnp.result_type(*jtu.tree_leaves(y0)) struct = jax.ShapeDtypeStruct(w_shape, dtype) bm = diffrax.VirtualBrownianTree( t0=t0, t1=t1, shape=struct, - tol=2**-14, + tol=bm_tol, key=key, - levy_area=levy_area, # pyright: ignore + levy_area=concrete_la, # pyright: ignore ) terms = get_terms(bm) + if controller is None: + controller = diffrax.ConstantStepSize() sol = diffrax.diffeqsolve( terms, solver, @@ -136,61 +171,293 @@ def _batch_sde_solve( dt0=dt0, y0=y0, args=args, - max_steps=None, + max_steps=2**19, + stepsize_controller=controller, + saveat=saveat, ) - return sol.ys + return sol.ys, sol.stats["num_accepted_steps"] + + +def _resulting_levy_area( + levy_area1: type[diffrax.AbstractBrownianIncrement], + levy_area2: type[diffrax.AbstractBrownianIncrement], +) -> type[diffrax.AbstractBrownianIncrement]: + """A helper that returns the stricter Levy area. + **Arguments:** + - `levy_area1`: The first Levy area type. + - `levy_area2`: The second Levy area type. + + **Returns:** + + `BrownianIncrement`, `SpaceTimeLevyArea`, or `SpaceTimeTimeLevyArea`. + """ + if issubclass(levy_area1, diffrax.AbstractSpaceTimeTimeLevyArea) or issubclass( + levy_area2, diffrax.AbstractSpaceTimeTimeLevyArea + ): + return diffrax.SpaceTimeTimeLevyArea + elif issubclass(levy_area1, diffrax.AbstractSpaceTimeLevyArea) or issubclass( + levy_area2, diffrax.AbstractSpaceTimeLevyArea + ): + return diffrax.SpaceTimeLevyArea + elif issubclass(levy_area1, diffrax.AbstractBrownianIncrement) or issubclass( + levy_area2, diffrax.AbstractBrownianIncrement + ): + return diffrax.BrownianIncrement + else: + raise ValueError("Invalid levy area types.") + + +@eqx.filter_jit def sde_solver_strong_order( + keys: PRNGKeyArray, get_terms: Callable[[diffrax.AbstractBrownianPath], diffrax.AbstractTerm], w_shape: tuple[int, ...], - solver: diffrax.AbstractSolver, - ref_solver: diffrax.AbstractSolver, t0: float, t1: float, - dt_precise: float, y0: PyTree[Array], args: PyTree, - num_samples: int, - num_levels: int, - key: PRNGKeyArray, + solver: diffrax.AbstractSolver, + ref_solver: diffrax.AbstractSolver, + levels: tuple[int, int], + ref_level: int, + get_dt_and_controller: Callable[ + [int], tuple[float, diffrax.AbstractStepSizeController] + ], + saveat: diffrax.SaveAt, + bm_tol: float, ): - dtype = jnp.result_type(*jtu.tree_leaves(y0)) - # TODO: add a check whether the solver needs levy area - levy_area = diffrax.BrownianIncrement - keys = jr.split(key, num_samples) # deliberately reused across all solves + levy_area1 = _get_minimal_la(solver) + levy_area2 = _get_minimal_la(ref_solver) + # Stricter levy_area requirements inherit from less strict ones + levy_area = _resulting_levy_area(levy_area1, levy_area2) + + level_coarse, level_fine = levels - correct_sols = _batch_sde_solve( + dt, step_controller = get_dt_and_controller(ref_level) + correct_sols, _ = _batch_sde_solve( keys, get_terms, - levy_area, - ref_solver, w_shape, t0, t1, - dt_precise, y0, args, + ref_solver, + levy_area, + dt, + step_controller, + bm_tol, + saveat, ) - dts = 2.0 ** jnp.arange(-3, -3 - num_levels, -1, dtype=dtype) - @jax.jit - @jax.vmap - def get_single_err(dt): - sols = _batch_sde_solve( + errs_list, steps_list = [], [] + for level in range(level_coarse, level_fine + 1): + dt, step_controller = get_dt_and_controller(level) + sols, steps = _batch_sde_solve( keys, get_terms, - levy_area, - solver, w_shape, t0, t1, - dt, y0, args, + solver, + levy_area, + dt, + step_controller, + bm_tol, + saveat, + ) + errs = path_l2_dist(sols, correct_sols) + errs_list.append(errs) + steps_list.append(jnp.average(steps)) + errs_arr = jnp.array(errs_list) + steps_arr = jnp.array(steps_list) + order, _ = jnp.polyfit(jnp.log(1 / steps_arr), jnp.log(errs_arr), 1) + return steps_arr, errs_arr, order + + +@dataclasses.dataclass(frozen=True) +class SDE: + get_terms: Callable[[AbstractBrownianPath], AbstractTerm] + args: PyTree + y0: PyTree[Array] + t0: float + t1: float + w_shape: tuple[int, ...] + + def get_dtype(self): + return jnp.result_type(*jtu.tree_leaves(self.y0)) + + def get_bm( + self, + bm_key: PRNGKeyArray, + levy_area: type[Union[diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea]], + tol: float, + ): + shp_dtype = jax.ShapeDtypeStruct(self.w_shape, dtype=self.get_dtype()) + return VirtualBrownianTree(self.t0, self.t1, tol, shp_dtype, bm_key, levy_area) + + +# A more concise function for use in the examples +def simple_sde_order( + keys, + sde: SDE, + solver, + ref_solver, + levels, + get_dt_and_controller, + saveat, + bm_tol, +): + _, level_fine = levels + ref_level = level_fine + 2 + return sde_solver_strong_order( + keys, + sde.get_terms, + sde.w_shape, + sde.t0, + sde.t1, + sde.y0, + sde.args, + solver, + ref_solver, + levels, + ref_level, + get_dt_and_controller, + saveat, + bm_tol, + ) + + +def simple_batch_sde_solve( + keys, sde: SDE, solver, levy_area, dt0, controller, bm_tol, saveat +): + return _batch_sde_solve( + keys, + sde.get_terms, + sde.w_shape, + sde.t0, + sde.t1, + sde.y0, + sde.args, + solver, + levy_area, + dt0, + controller, + bm_tol, + saveat, + ) + + +def _squareplus(x): + return 0.5 * (x + jnp.sqrt(x**2 + 4)) + + +def drift(t, y, args): + mlp, _, _ = args + return 0.25 * mlp(y) + + +def diffusion(t, y, args): + _, mlp, noise_dim = args + return 1.0 * mlp(y).reshape(3, noise_dim) + + +def get_mlp_sde(t0, t1, dtype, key, noise_dim): + driftkey, diffusionkey, ykey = jr.split(key, 3) + drift_mlp = eqx.nn.MLP( + in_size=3, + out_size=3, + width_size=8, + depth=2, + activation=_squareplus, + final_activation=jnp.tanh, + key=driftkey, + ) + diffusion_mlp = eqx.nn.MLP( + in_size=3, + out_size=3 * noise_dim, + width_size=8, + depth=2, + activation=_squareplus, + final_activation=jnp.tanh, + key=diffusionkey, + ) + args = (drift_mlp, diffusion_mlp, noise_dim) + y0 = jr.normal(ykey, (3,), dtype=dtype) + + def get_terms(bm): + return MultiTerm(ODETerm(drift), ControlTerm(diffusion, bm)) + + return SDE(get_terms, args, y0, t0, t1, (noise_dim,)) + + +# This is needed for time_sde (i.e. the additive noise SDE) because initializing +# the weights in the drift MLP with a Gaussian makes the SDE too linear and nice, +# so we need to use a Laplace distribution, which is heavier-tailed. +def lap_init(weight: jax.Array, key) -> jax.Array: + stddev = 1.0 + return stddev * jax.random.laplace(key, shape=weight.shape, dtype=weight.dtype) + + +def init_linear_weight(model, init_fn, key): + is_linear = lambda x: isinstance(x, eqx.nn.Linear) + + def get_weights(model): + list = [] + for x in jax.tree_util.tree_leaves(model, is_leaf=is_linear): + if is_linear(x): + list.extend([x.weight, x.bias]) + return list + + weights = get_weights(model) + new_weights = [ + init_fn(weight, subkey) + for weight, subkey in zip(weights, jax.random.split(key, len(weights))) + ] + new_model = eqx.tree_at(get_weights, model, new_weights) + return new_model + + +def get_time_sde(t0, t1, dtype, key, noise_dim): + y_dim = 7 + driftkey, diffusionkey, ykey = jr.split(key, 3) + + def ft(t): + return jnp.array( + [jnp.sin(t), jnp.cos(4 * t), 1.0, 1.0 / (t + 0.5)], dtype=dtype ) - return _path_l2_dist(sols, correct_sols) - errs = get_single_err(dts) - order, _ = jnp.polyfit(jnp.log(dts), jnp.log(errs), 1) - return dts, errs, order + drift_mlp = eqx.nn.MLP( + in_size=y_dim + 4, + out_size=y_dim, + width_size=16, + depth=5, + activation=_squareplus, + key=driftkey, + ) + + # The drift weights must be Laplace-distributed, + # otherwise the SDE is too linear and nice. + drift_mlp = init_linear_weight(drift_mlp, lap_init, driftkey) + + def _drift(t, y, _): + mlp_out = drift_mlp(jnp.concatenate([y, ft(t)])) + return (0.01 * mlp_out - 0.5 * y**3) / (jnp.sum(y**2) + 1) + + diffusion_mx = jr.normal(diffusionkey, (4, y_dim, noise_dim), dtype=dtype) + + def _diffusion(t, _, __): + # This needs a large coefficient to make the SDE not too easy. + return 1.0 * jnp.tensordot(ft(t), diffusion_mx, axes=1) + + args = (drift_mlp, None, None) + y0 = jr.normal(ykey, (y_dim,), dtype=dtype) + + def get_terms(bm): + return MultiTerm(ODETerm(_drift), ControlTerm(_diffusion, bm)) + + return SDE(get_terms, args, y0, t0, t1, (noise_dim,)) diff --git a/test/test_brownian.py b/test/test_brownian.py index 73f56240..ac227859 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -103,7 +103,7 @@ def is_tuple_of_ints(obj): with context: out = path.evaluate(t0, t1, use_levy=use_levy) if use_levy: - assert isinstance(out, diffrax.AbstractBrownianReturn) + assert isinstance(out, diffrax.AbstractBrownianIncrement) w = out.W if isinstance(out, diffrax.SpaceTimeLevyArea): h = out.H @@ -136,7 +136,7 @@ def _eval(key): values = jax.vmap(_eval)(keys) if use_levy: - assert isinstance(values, diffrax.AbstractBrownianReturn) + assert isinstance(values, diffrax.AbstractBrownianIncrement) w = values.W if isinstance(values, diffrax.SpaceTimeLevyArea): diff --git a/test/test_integrate.py b/test/test_integrate.py index acd14b72..b2611867 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -191,21 +191,34 @@ def f(t, y, args): assert -0.9 < order - solver.order(term) < 0.9 -def _squareplus(x): - return 0.5 * (x + jnp.sqrt(x**2 + 4)) +def _solvers_and_orders(): + # solver, noise, order + # noise is "any" or "com" or "add" where "com" means commutative and "add" means + # additive. + yield diffrax.Euler, "any", 0.5 + yield diffrax.EulerHeun, "any", 0.5 + yield diffrax.Heun, "any", 0.5 + yield diffrax.ItoMilstein, "any", 0.5 + yield diffrax.Midpoint, "any", 0.5 + yield diffrax.ReversibleHeun, "any", 0.5 + yield diffrax.StratonovichMilstein, "any", 0.5 + yield diffrax.SPaRK, "any", 0.5 + yield diffrax.GeneralShARK, "any", 0.5 + yield diffrax.SlowRK, "any", 0.5 + yield diffrax.ReversibleHeun, "com", 1 + yield diffrax.StratonovichMilstein, "com", 1 + yield diffrax.SPaRK, "com", 1 + yield diffrax.GeneralShARK, "com", 1 + yield diffrax.SlowRK, "com", 1.5 + yield diffrax.SPaRK, "add", 1.5 + yield diffrax.GeneralShARK, "add", 1.5 + yield diffrax.ShARK, "add", 1.5 + yield diffrax.SRA1, "add", 1.5 + yield diffrax.SEA, "add", 1.0 -def _solvers(): - # solver, commutative, order - yield diffrax.Euler, False, 0.5 - yield diffrax.EulerHeun, False, 0.5 - yield diffrax.Heun, False, 0.5 - yield diffrax.ItoMilstein, False, 0.5 - yield diffrax.Midpoint, False, 0.5 - yield diffrax.ReversibleHeun, False, 0.5 - yield diffrax.StratonovichMilstein, False, 0.5 - yield diffrax.ReversibleHeun, True, 1 - yield diffrax.StratonovichMilstein, True, 1 +def _squareplus(x): + return 0.5 * (x + jnp.sqrt(x**2 + 4)) def _drift(t, y, args): @@ -218,24 +231,27 @@ def _diffusion(t, y, args): return 0.25 * diffusion_mlp(y).reshape(3, noise_dim) -@pytest.mark.parametrize("solver_ctr,commutative,theoretical_order", _solvers()) -def test_sde_strong_order(solver_ctr, commutative, theoretical_order): +@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders()) +def test_sde_strong_order(solver_ctr, noise, theoretical_order): key = jr.PRNGKey(5678) driftkey, diffusionkey, ykey, bmkey = jr.split(key, 4) + num_samples = 20 + bmkeys = jr.split(bmkey, num_samples) - if commutative: + if noise == "com": noise_dim = 1 + elif noise == "any": + noise_dim = 7 + elif noise == "add": + return else: - noise_dim = 5 - - key = jr.PRNGKey(5678) - driftkey, diffusionkey, ykey, bmkey = jr.split(key, 4) + assert False drift_mlp = eqx.nn.MLP( in_size=3, out_size=3, width_size=8, - depth=1, + depth=2, activation=_squareplus, key=driftkey, ) @@ -244,7 +260,7 @@ def test_sde_strong_order(solver_ctr, commutative, theoretical_order): in_size=3, out_size=3 * noise_dim, width_size=8, - depth=1, + depth=2, activation=_squareplus, final_activation=jnp.tanh, key=diffusionkey, @@ -268,19 +284,36 @@ def get_terms(bm): else: assert False + if theoretical_order == 0.5: + levels = (3, 8) + ref_level = 10 + elif theoretical_order == 1.0: + levels = (1, 6) + ref_level = 12 + elif theoretical_order == 1.5: + levels = (0, 4) + ref_level = 12 + else: + assert False + + def get_dt_and_controller(level): + return 2**-level, diffrax.ConstantStepSize() + hs, errors, order = sde_solver_strong_order( + bmkeys, get_terms, (noise_dim,), - solver_ctr(), - ref_solver, t0, t1, - dt_precise=2**-12, - y0=y0, - args=args, - num_samples=20, - num_levels=7, - key=bmkey, + y0, + args, + solver_ctr(), + ref_solver, + levels, + ref_level, + get_dt_and_controller, + diffrax.SaveAt(t1=True), + bm_tol=2.0 ** -(ref_level + 2), ) assert -0.2 < order - theoretical_order < 0.2 diff --git a/test/test_sde.py b/test/test_sde.py new file mode 100644 index 00000000..5ffb3a6d --- /dev/null +++ b/test/test_sde.py @@ -0,0 +1,269 @@ +from typing import Literal + +import diffrax +import jax +import jax.numpy as jnp +import jax.random as jr +import jax.tree_util as jtu +import pytest +from diffrax import ControlTerm, MultiTerm, ODETerm, WeaklyDiagonalControlTerm + +from .helpers import ( + get_mlp_sde, + get_time_sde, + path_l2_dist, + simple_batch_sde_solve, + simple_sde_order, +) + + +def _solvers_and_orders(): + # solver, noise, order + # noise is "any" or "com" or "add" where "com" means commutative and "add" means + # additive. + yield diffrax.SPaRK, "any", 0.5 + yield diffrax.GeneralShARK, "any", 0.5 + yield diffrax.SlowRK, "any", 0.5 + yield diffrax.SPaRK, "com", 1 + yield diffrax.GeneralShARK, "com", 1 + yield diffrax.SlowRK, "com", 1.5 + yield diffrax.SPaRK, "add", 1.5 + yield diffrax.GeneralShARK, "add", 1.5 + yield diffrax.ShARK, "add", 1.5 + yield diffrax.SRA1, "add", 1.5 + yield diffrax.SEA, "add", 1.0 + + +# For solvers of high order, comparing to Euler or Heun is not sufficient, +# because they are substantially worse than e.g. ShARK. ShARK is more precise +# at dt=2**-4 than Euler is at dt=2**-14 (and it takes forever to run at such +# a small dt). Hence , the order of convergence of ShARK seems to plateau at +# discretisations finer than 2**-4. +# Therefore, we use two separate tests. First we determine how fast the solver +# converges to its own limit (i.e. using itself as reference), and then in a +# different test check whether that limit is the same as the Euler/Heun limit. +@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders()) +def test_sde_strong_order_new( + solver_ctr, noise: Literal["any", "com", "add"], theoretical_order +): + bmkey = jr.PRNGKey(5678) + sde_key = jr.PRNGKey(11) + num_samples = 100 + bmkeys = jr.split(bmkey, num=num_samples) + t0 = 0.3 + t1 = 5.3 + + if noise == "add": + sde = get_time_sde(t0, t1, jnp.float64, sde_key, noise_dim=7) + else: + if noise == "com": + noise_dim = 1 + elif noise == "any": + noise_dim = 5 + else: + assert False + sde = get_mlp_sde(t0, t1, jnp.float64, sde_key, noise_dim=noise_dim) + + ref_solver = solver_ctr() + level_coarse, level_fine = 1, 7 + + # We specify the times to which we step in way that each level contains all the + # steps of the previous level. This is so that we can compare the solutions at + # all the times in saveat, and not just at the end time. + def get_dt_and_controller(level): + step_ts = jnp.linspace(t0, t1, 2**level + 1, endpoint=True) + return None, diffrax.StepTo(ts=step_ts) + + saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, 2**level_coarse + 1, endpoint=True)) + + hs, errors, order = simple_sde_order( + bmkeys, + sde, + solver_ctr(), + ref_solver, + (level_coarse, level_fine), + get_dt_and_controller, + saveat, + bm_tol=2**-14, + ) + # The upper bound needs to be 0.25, otherwise we fail. + # This still preserves a 0.05 buffer between the intervals + # corresponding to the different orders. + assert -0.2 < order - theoretical_order < 0.25 + + +# Make variables to store the correct solutions in. +# This is to avoid recomputing the correct solutions for every solver. +solutions = { + "Ito": { + "any": None, + "com": None, + "add": None, + }, + "Stratonovich": { + "any": None, + "com": None, + "add": None, + }, +} + + +# Now compare the limit of Euler/Heun to the limit of the other solvers, +# using a single reference solution. We use Euler if the solver is Ito +# and Heun if the solver is Stratonovich. +@pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders()) +def test_sde_strong_limit( + solver_ctr, noise: Literal["any", "com", "add"], theoretical_order +): + bmkey = jr.PRNGKey(5678) + sde_key = jr.PRNGKey(11) + num_samples = 100 + bmkeys = jr.split(bmkey, num=num_samples) + t0 = 0.3 + t1 = 5.3 + + if noise == "add": + sde = get_time_sde(t0, t1, jnp.float64, sde_key, noise_dim=3) + level_fine = 12 + if theoretical_order <= 1.0: + level_coarse = 11 + else: + level_coarse = 8 + else: + level_coarse, level_fine = 7, 11 + if noise == "com": + noise_dim = 1 + elif noise == "any": + noise_dim = 5 + else: + assert False + sde = get_mlp_sde(t0, t1, jnp.float64, sde_key, noise_dim=noise_dim) + + # Reference solver is always an ODE-viable solver, so its implementation has been + # verified by the ODE tests like test_ode_order. + if issubclass(solver_ctr, diffrax.AbstractItoSolver): + sol_type = "Ito" + ref_solver = diffrax.Euler() + elif issubclass(solver_ctr, diffrax.AbstractStratonovichSolver): + sol_type = "Stratonovich" + ref_solver = diffrax.Heun() + else: + assert False + + ts_fine = jnp.linspace(t0, t1, 2**level_fine + 1, endpoint=True) + ts_coarse = jnp.linspace(t0, t1, 2**level_coarse + 1, endpoint=True) + contr_fine = diffrax.StepTo(ts=ts_fine) + contr_coarse = diffrax.StepTo(ts=ts_coarse) + save_ts = jnp.linspace(t0, t1, 2**5 + 1, endpoint=True) + assert len(jnp.intersect1d(ts_fine, save_ts)) == len(save_ts) + assert len(jnp.intersect1d(ts_coarse, save_ts)) == len(save_ts) + saveat = diffrax.SaveAt(ts=save_ts) + levy_area = diffrax.SpaceTimeLevyArea # must be common for all solvers + + if solutions[sol_type][noise] is None: + correct_sol, _ = simple_batch_sde_solve( + bmkeys, sde, ref_solver, levy_area, None, contr_fine, 2**-10, saveat + ) + solutions[sol_type][noise] = correct_sol + else: + correct_sol = solutions[sol_type][noise] + + sol, _ = simple_batch_sde_solve( + bmkeys, sde, solver_ctr(), levy_area, None, contr_coarse, 2**-10, saveat + ) + error = path_l2_dist(correct_sol, sol) + assert error < 0.05 + + +def _solvers(): + yield diffrax.SPaRK + yield diffrax.GeneralShARK + yield diffrax.SlowRK + yield diffrax.ShARK + yield diffrax.SRA1 + yield diffrax.SEA + + +# Define the SDE +def dict_drift(t, y, args): + pytree, _ = args + return jtu.tree_map(lambda _, x: -0.5 * x, pytree, y) + + +def dict_diffusion(t, y, args): + pytree, additive = args + + def get_matrix(y_leaf): + if additive: + return 2.0 * jnp.ones(y_leaf.shape + (3,), dtype=jnp.float64) + else: + return 2.0 * jnp.broadcast_to( + jnp.expand_dims(y_leaf, axis=y_leaf.ndim), y_leaf.shape + (3,) + ) + + return jtu.tree_map(get_matrix, y) + + +@pytest.mark.parametrize("shape", [(), (5, 2)]) +@pytest.mark.parametrize("solver_ctr", _solvers()) +def test_sde_solver_shape(shape, solver_ctr): + pytree = ({"a": 0, "b": [0, 0]}, 0, 0) + dtype = jnp.float64 + key = jr.PRNGKey(0) + y0 = jtu.tree_map(lambda _: jr.normal(key, shape, dtype=dtype), pytree) + t0, t1, dt0 = 0.0, 1.0, 0.3 + + # Some solvers only work with additive noise + additive = solver_ctr in [diffrax.ShARK, diffrax.SRA1, diffrax.SEA] + args = (pytree, additive) + solver = solver_ctr() + bmkey = jr.PRNGKey(1) + struct = jax.ShapeDtypeStruct((3,), dtype) + bm_shape = jtu.tree_map(lambda _: struct, pytree) + bm = diffrax.VirtualBrownianTree( + t0, t1, 0.1, bm_shape, bmkey, diffrax.SpaceTimeLevyArea + ) + terms = MultiTerm(ODETerm(dict_drift), ControlTerm(dict_diffusion, bm)) + solution = diffrax.diffeqsolve( + terms, solver, t0, t1, dt0, y0, args, saveat=diffrax.SaveAt(t1=True) + ) + assert jtu.tree_structure(solution.ys) == jtu.tree_structure(y0) + for leaf in jtu.tree_leaves(solution.ys): + assert leaf[0].shape == shape + + +def _weakly_diagonal_noise_helper(solver): + dtype = jnp.float64 + w_shape = (3,) + args = (0.5, 1.2) + + def _diffusion(t, y, args): + a, b = args + return jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype) + + def _drift(t, y, args): + a, b = args + return -a * y + + y0 = jnp.ones(w_shape, dtype) + + bm = diffrax.VirtualBrownianTree( + 0.0, 1.0, 0.05, w_shape, jr.PRNGKey(0), diffrax.SpaceTimeLevyArea + ) + + terms = MultiTerm(ODETerm(_drift), WeaklyDiagonalControlTerm(_diffusion, bm)) + saveat = diffrax.SaveAt(t1=True) + solution = diffrax.diffeqsolve( + terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat + ) + assert solution.ys is not None + assert solution.ys.shape == (1, 3) + + +@pytest.mark.parametrize("solver_ctr", _solvers()) +def test_weakly_diagonal_noise(solver_ctr): + _weakly_diagonal_noise_helper(solver_ctr()) + + +def test_halfsolver_term_compatible(): + _weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK())) diff --git a/test/test_solver.py b/test/test_solver.py index 56656ac6..f10b8b39 100644 --- a/test/test_solver.py +++ b/test/test_solver.py @@ -194,7 +194,7 @@ class Term(diffrax.AbstractTerm): def vf(self, t, y, args): return {"f": -self.coeff * y["y"]} - def contr(self, t0, t1): + def contr(self, t0, t1, **kwargs): return {"t": t1 - t0} def prod(self, vf, control): diff --git a/test/test_term.py b/test/test_term.py index 0f73b5e7..420dc122 100644 --- a/test/test_term.py +++ b/test/test_term.py @@ -53,7 +53,7 @@ def derivative(self, t, left=True): # `# type: ignore` is used for contrapositive static type checking as per: # https://github.com/microsoft/pyright/discussions/2411#discussioncomment-2028001 _: diffrax.ControlTerm[PyTree[Array], Array] = term - __: diffrax.ControlTerm[PyTree[Array], diffrax.AbstractBrownianReturn] = term # type: ignore + __: diffrax.ControlTerm[PyTree[Array], diffrax.BrownianIncrement] = term # type: ignore term = term.to_ode() dt = term.contr(0, 1)