diff --git a/diffrax/__init__.py b/diffrax/__init__.py
index 9b13acdf..3ce7f85d 100644
--- a/diffrax/__init__.py
+++ b/diffrax/__init__.py
@@ -7,6 +7,7 @@
 )
 from .autocitation import citation, citation_rules
 from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree
+from .delays import Delays
 from .event import (
     AbstractDiscreteTerminatingEvent,
     DiscreteTerminatingEvent,
diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py
index b13c5903..20e2c7de 100644
--- a/diffrax/adjoint.py
+++ b/diffrax/adjoint.py
@@ -13,9 +13,8 @@
 
 from .ad import implicit_jvp
 from .heuristics import is_sde, is_unsafe_sde
-from .saveat import SaveAt, SubSaveAt, save_y
-from .solver import (AbstractItoSolver, AbstractRungeKutta,
-                     AbstractStratonovichSolver)
+from .saveat import save_y, SaveAt, SubSaveAt
+from .solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver
 from .term import AbstractTerm, AdjointTerm
 
 
@@ -532,6 +531,7 @@ def _loop_backsolve_bwd(
     max_steps,
     throw,
     init_state,
+    y0_history,
 ):
 
     #
diff --git a/diffrax/delays.py b/diffrax/delays.py
new file mode 100644
index 00000000..036dab3a
--- /dev/null
+++ b/diffrax/delays.py
@@ -0,0 +1,383 @@
+from typing import Callable, Optional, Sequence, Type, Union
+
+import equinox as eqx
+import jax
+import jax.lax as lax
+import jax.numpy as jnp
+import jax.tree_util as jtu
+from equinox.internal import unvmap_any
+from optimistix import fixed_point, FixedPointIteration
+
+from .custom_types import Array, Bool, Int, PyTree, Scalar
+from .global_interpolation import DenseInterpolation
+from .local_interpolation import AbstractLocalInterpolation
+from .nonlinear_solver import NewtonNonlinearSolver
+from .term import VectorFieldWrapper
+
+
+class Delays(eqx.Module):
+    """Module that incorportes all the information needed for integrating DDEs"""
+
+    delays: PyTree[Callable]
+    initial_discontinuities: Union[None, Array, Sequence[Scalar]] = jnp.array([0.0])
+    max_discontinuities: Int = 100
+    recurrent_checking: Bool = False
+    sub_intervals: int = 10
+    max_steps: int = 20
+    rtol: float = 1e-3
+    atol: float = 1e-6
+
+
+class HistoryVectorField(eqx.Module):
+    """VectorField equivalent for a DDE solver that incorporates former
+    estimated values of y(t).
+
+    **Arguments:**
+        - `vector_field`: vector field of the delayed differential equation.
+        - `t0`: global integration start time
+        - `tprev`: start time of current integration step
+        - `tnext`: end time of current integration step
+        - `dense_info` : dense_info from current integration step
+        - `y0_history` : DDE's history function
+        - `delays` : DDE's different deviated arguments
+    """
+
+    vector_field: Callable
+    t0: float
+    tprev: float
+    tnext: float
+    dense_info: PyTree[Array]
+    dense_interp: Optional[DenseInterpolation]
+    interpolation_cls: Type[AbstractLocalInterpolation]
+    y0_history: Callable
+    delays: PyTree[Callable]
+
+    def __call__(self, t, y, args):
+        history_vals = []
+        delays, treedef = jtu.tree_flatten(self.delays)
+        if self.dense_interp is None:
+            assert self.dense_info is None
+            for delay in self.delays:
+                delay_val = delay(t, y, args)
+                alpha_val = t - delay_val
+                y0_val = self.y0_history(alpha_val)
+                history_vals.append(y0_val)
+        else:
+            assert self.dense_info is not None
+            for delay in delays:
+                delay_val = delay(t, y, args)
+                alpha_val = t - delay_val
+
+                is_before_t0 = alpha_val < self.t0
+                is_before_tprev = alpha_val < self.tprev
+                at_most_t0 = jnp.where(alpha_val < self.t0, alpha_val, self.t0)
+                t0_to_tprev = jnp.clip(alpha_val, self.t0, self.tprev)
+                at_least_tprev = jnp.maximum(self.tprev, alpha_val)
+                step_interpolation = self.interpolation_cls(
+                    t0=self.tprev, t1=self.tnext, **self.dense_info
+                )
+                switch = jnp.where(is_before_t0, 0, jnp.where(is_before_tprev, 1, 2))
+                history_val = lax.switch(
+                    switch,
+                    [
+                        lambda: self.y0_history(at_most_t0),
+                        lambda: self.dense_interp.evaluate(t0_to_tprev),
+                        lambda: step_interpolation.evaluate(at_least_tprev),
+                    ],
+                )
+                history_vals.append(history_val)
+
+        history_vals = jtu.tree_unflatten(treedef, history_vals)
+        history_vals = tuple(history_vals)
+        return self.vector_field(t, y, args, history=history_vals)
+
+
+def bind_history(
+    terms,
+    delays,
+    dense_info,
+    dense_interp,
+    solver,
+    direction,
+    t0,
+    tprev,
+    tnext,
+    y0_history,
+):
+    delays_fn = jtu.tree_map(
+        lambda x: (lambda t, y, args: x(t, y, args) * direction), delays.delays
+    )
+
+    is_vf_wrapper = lambda x: isinstance(x, VectorFieldWrapper)
+
+    def _apply_history(
+        x,
+    ):
+        if is_vf_wrapper(x):
+            vector_field = HistoryVectorField(
+                x.vector_field,
+                t0,
+                tprev,
+                tnext,
+                dense_info,
+                dense_interp,
+                solver.interpolation_cls,
+                y0_history,
+                delays_fn,
+            )
+            return VectorFieldWrapper(vector_field)
+        else:
+            return x
+
+    terms_ = jtu.tree_map(_apply_history, terms, is_leaf=is_vf_wrapper)
+    return terms_
+
+
+def history_extrapolation_implicit(
+    implicit_step,
+    terms,
+    dense_interp,
+    solver,
+    delays,
+    t0,
+    y0_history,
+    state,
+    args,
+):
+    def fn(dense_info, args):
+        (
+            terms,
+            _,
+            dense_interp,
+            solver,
+            delays,
+            t0,
+            y0_history,
+            state,
+            vf_args,
+        ) = args
+        terms_ = bind_history(
+            terms,
+            delays,
+            dense_info,
+            dense_interp,
+            solver,
+            1,
+            t0,
+            state.tprev,
+            state.tnext,
+            y0_history,
+        )
+        (y, y_error, new_dense_info, solver_state, solver_result) = solver.step(
+            terms_,
+            state.tprev,
+            state.tnext,
+            state.y,
+            vf_args,
+            state.solver_state,
+            state.made_jump,
+        )
+
+        return new_dense_info, (y, y_error, solver_state, solver_result)
+
+    # unwrapped_buffer = jtu.tree_leaves(
+    #     eqx.filter(state.dense_infos, eqx.is_inexact_array),
+    #     is_leaf=eqx.is_inexact_array,
+    # )
+    # aux_dense_infos = dict(zip(state.dense_infos.keys(), unwrapped_buffer))
+
+    # def get_dense_info(dense_infos, idx):
+    #     return jtu.tree_map(lambda x: x[idx], dense_infos)
+
+    # # dense_info_aux = dict(zip(["k","y0", "y1"],
+    # [jnp.zeros((4,)), jnp.zeros((1,)), jnp.zeros((1,))]))
+    # jax.debug.breakpoint()
+    # struct_dense_info = eqx.filter_eval_shape(get_dense_info,  state.dense_infos, 0)
+    # infos = jtu.tree_map(lambda _, x: x[...], struct_dense_info, state.dense_infos)
+    # jax.debug.print("infos integrate {} ", infos)
+    # print("infos,", infos)
+    # print("struct_dense_info", struct_dense_info)
+
+    jax.debug.print("state.dense_infos {}", state.dense_infos)
+    init_guess = jtu.tree_map(
+        lambda x: x[state.dense_save_index - 1], state.dense_infos
+    )
+    alg = FixedPointIteration(rtol=delays.rtol, atol=delays.atol)
+    nonlinear_args = (
+        terms,
+        implicit_step,
+        dense_interp,
+        solver,
+        delays,
+        t0,
+        y0_history,
+        state,
+        args,
+    )
+    sol = fixed_point(fn, alg, init_guess, nonlinear_args, has_aux=True)
+    dense_info, (y, y_error, solver_state, solver_result) = sol.value, sol.aux
+    return y, y_error, dense_info, solver_state, solver_result
+
+
+def maybe_find_discontinuity(
+    tprev,
+    tnext,
+    dense_info,
+    state,
+    delays,
+    solver,
+    args,
+    keep_step,
+    sub_tprev,
+    sub_tnext,
+):
+    dense_discont = solver.interpolation_cls(t0=tprev, t1=tnext, **dense_info)
+    flat_delays = jtu.tree_leaves(delays.delays)
+    _gs = []
+
+    def make_g(delay):
+        # Creating the artifical event functions g that is used to
+        # detect future breaking points.
+        # http://www.cs.toronto.edu/pub/reports/na/hzpEnrightNA09Preprint.pdf
+        # page 7
+        def g(t):
+            return (
+                t
+                - delay(t, dense_discont.evaluate(t), args)
+                - state.discontinuities[...]
+            )
+
+        return g
+
+    for delay in flat_delays:
+        _gs.append(make_g(delay))
+
+    def _find_discontinuity():
+        # Start by doing a cheap bisection search to reduce
+        # over the stored-discontinuity dimension.
+
+        def _cond_fun(_val):
+            _, _, _pred, _ = _val
+            return _pred
+
+        def _body_fun(_val):
+            _ta, _tb, _, _step = _val
+            _step = _step + 1
+            _tmid = _ta + 0.5 * (_tb - _ta)
+            _gas = jnp.stack([jnp.sign(g(_ta)) for g in _gs])
+            _gmids = jnp.stack([jnp.sign(g(_tmid)) for g in _gs])
+            _gbs = jnp.stack([jnp.sign(g(_tb)) for g in _gs])
+            _any_left = jnp.any(_gas != _gmids)
+            _next_ta = jnp.where(_any_left, _ta, _tmid)
+            _next_tb = jnp.where(_any_left, _tmid, _tb)
+            _pred = (
+                jnp.any(jnp.sum(_gas != _gbs, axis=1) > 1) | _step > delays.max_steps
+            )
+            return _next_ta, _next_tb, _pred, _step
+
+        _init_val = (sub_tprev, sub_tnext, True, 0)
+        _final_val = lax.while_loop(_cond_fun, _body_fun, _init_val)
+        _ta, _tb, _, _ = _final_val
+
+        # Then do a more expensive Newton search
+        # to find the first discontinuity.
+        _discont_solver = NewtonNonlinearSolver(rtol=delays.rtol, atol=delays.atol)
+        _disconts = []
+        for g, delay in zip(_gs, flat_delays):
+            changed_sign = jnp.sign(g(_ta)) != jnp.sign(g(_tb))
+            _i = jnp.argmax(changed_sign)
+            _d = state.discontinuities[_i]
+            _h = (
+                lambda t, args, delay=delay, _d=_d: t
+                - delay(t, dense_discont.evaluate(t), args)
+                - _d
+            )
+            _discont = _discont_solver(_h, _tb, args).root
+            _disconts.append(_discont)
+        _disconts = jnp.stack(_disconts)
+
+        best_candidate = jnp.where(
+            (_disconts > sub_tprev) & (_disconts < sub_tnext),
+            _disconts,
+            jnp.inf,
+        )
+        best_candidate = jnp.min(best_candidate)
+        discont_update = jnp.where(
+            jnp.isinf(best_candidate),
+            False,
+            True,
+        )
+        return best_candidate, discont_update
+
+    def _find_discontinuity_wrapper():
+        return lax.cond(
+            jnp.any(init_discont & jnp.invert(keep_step)),
+            _find_discontinuity,
+            lambda: (sub_tnext, False),
+        )
+
+    init_discont = jnp.stack(
+        [jnp.sign(g(sub_tprev)) != jnp.sign(g(sub_tnext)) for g in _gs]
+    )
+    # We might have rejected the step for normal reasons;
+    # skip looking for a discontinuity if so.
+    return lax.cond(
+        unvmap_any((init_discont & jnp.invert(keep_step))),
+        _find_discontinuity_wrapper,
+        lambda: (sub_tnext, False),
+    )
+
+
+Delays.__init__.__doc__ = """
+**Arguments:**
+
+- `delays`: A `PyTree` where the leaves are the DDE's different scalar 
+  deviated arguments.
+- `initial_discontinuities`: Discontinuities given by the initial point time
+and history function.
+- `max_discontinuities`: Array length that tracks the discontinuity jumps 
+during integration (only relevant when `recurrent_checking` is True). If 
+`recurrent checking` is set to `True`, the computation quits unconditionally
+when the total number of discontinuities detected is larger 
+than `max_discontinuities`.
+- `recurrent_checking` : If `True`, there will be a systematic check at 
+integration step for potential discontinuities (this involves nonlinear solves 
+hence expensive). If `False`, discontinuities will only be checked when a step 
+is rejected. This allows to integrate faster but can also impact 
+the accuracy of the DDE solution.
+- `sub_intervals` : Number of subintervals of the integration step where 
+discontinuity tracking is done.
+- `rtol` : Relative  tolerance for the nonlinear solver for the DDE's 
+implicit stepping and dichotomy for detecting breaking points.
+- `atol` : Absolute tolerance for the nonlinear solver for the DDE's
+implicit stepping and dichotomy for detecting breaking points.
+- `max_steps` : Max iteration of the dichotomy algorithm to 
+find a discontinuity.
+
+!!! example
+    To integrate `y'(t) = - y(t-1)`, we need to define the vector 
+    field and the `Delays` object. 
+    ```py
+    y0 = lambda t: 1.2
+    def vector_field(t, y, args, history):
+        return - history[0]
+
+    delays = Delays(
+        delays=[lambda t, y, args: 1.0],
+        initial_discontinuities=jnp.array([0.0])     
+    )
+    t0, t1 = 0.0, 50.0
+    ts = jnp.linspace(t0, t1, 500)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0 = y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays
+        )
+    ```
+"""
diff --git a/diffrax/integrate.py b/diffrax/integrate.py
index 7c40688d..073d71c3 100644
--- a/diffrax/integrate.py
+++ b/diffrax/integrate.py
@@ -1,8 +1,7 @@
 import functools as ft
 import typing
 import warnings
-from typing import (Any, Callable, Optional, Sequence, Tuple, get_args,
-                    get_origin)
+from typing import Any, Callable, get_args, get_origin, Optional, Tuple, Union
 
 import equinox as eqx
 import equinox.internal as eqxi
@@ -13,20 +12,29 @@
 
 from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint
 from .custom_types import Array, Bool, Int, PyTree, Scalar
+from .delays import bind_history, Delays, history_extrapolation_implicit
 from .event import AbstractDiscreteTerminatingEvent
 from .global_interpolation import DenseInterpolation
 from .heuristics import is_sde, is_unsafe_sde
 from .misc import static_select
 from .saveat import SaveAt, SubSaveAt
-from .solution import RESULTS, Solution, is_okay, is_successful
-from .solver import (AbstractItoSolver, AbstractSolver,
-                     AbstractStratonovichSolver, Euler, EulerHeun, ItoMilstein,
-                     StratonovichMilstein)
-from .step_size_controller import (AbstractAdaptiveStepSizeController,
-                                   AbstractStepSizeController,
-                                   ConstantStepSize, StepTo)
-from .term import (AbstractTerm, MultiTerm, ODETerm, VectorFieldWrapper,
-                   WrapTerm)
+from .solution import is_okay, is_successful, RESULTS, Solution
+from .solver import (
+    AbstractItoSolver,
+    AbstractSolver,
+    AbstractStratonovichSolver,
+    Euler,
+    EulerHeun,
+    ItoMilstein,
+    StratonovichMilstein,
+)
+from .step_size_controller import (
+    AbstractAdaptiveStepSizeController,
+    AbstractStepSizeController,
+    ConstantStepSize,
+    StepTo,
+)
+from .term import AbstractTerm, MultiTerm, ODETerm, WrapTerm
 
 
 class SaveState(eqx.Module):
@@ -48,11 +56,15 @@ class State(eqx.Module):
     num_steps: Int
     num_accepted_steps: Int
     num_rejected_steps: Int
+    num_dde_implicit_step: Int
+    num_dde_explicit_step: Int
     # Output that is .at[].set() updated during the solve (and their indices)
     save_state: PyTree[SaveState]
     dense_ts: Optional[Array["times + 1"]]  # noqa: F821
     dense_infos: Optional[PyTree[Array["times", ...]]]  # noqa: F821
     dense_save_index: Int
+    discontinuities: Optional[Array["times"]]  # noqa: F821
+    discontinuities_save_index: Optional[Int]
 
 
 def _is_none(x):
@@ -103,6 +115,7 @@ def _outer_buffers(state):
         [s.ts for s in save_states]
         + [s.ys for s in save_states]
         + [state.dense_ts, state.dense_infos]
+        + [state.discontinuities]
     )
 
 
@@ -164,6 +177,7 @@ def loop(
     init_state,
     inner_while_loop,
     outer_while_loop,
+    y0_history,
 ):
 
     if saveat.dense:
@@ -209,7 +223,6 @@ def body_fun(state):
         # Actually do some differential equation solving! Make numerical steps, adapt
         # step sizes, all that jazz.
         #
-
         if delays is None:
             (y, y_error, dense_info, solver_state, solver_result) = solver.step(
                 terms,
@@ -220,42 +233,48 @@ def body_fun(state):
                 state.solver_state,
                 state.made_jump,
             )
+            num_dde_explicit_step = num_dde_implicit_step = 0
         else:
-            # TODO: double-check that these are the correct `ts_size` and
-            # `direction`.
-            history = DenseInterpolation(
-                ts=state.dense_ts,
+            min_delay = []
+            flat_delays = jtu.tree_leaves(delays.delays)
+            for delay in flat_delays:
+                min_delay.append(delay(state.tprev, state.y, args))
+            min_delay = jnp.stack(min_delay).min()
+            implicit_step = min_delay < (state.tnext - state.tprev)
+
+            dense_interp = DenseInterpolation(
+                ts=state.dense_ts[...],
                 ts_size=state.dense_save_index + 1,
                 interpolation_cls=solver.interpolation_cls,
                 infos=state.dense_infos,
                 direction=1,
+                y0_if_trivial=y0_history(t0),
+                t0_if_trivial=t0,
+            )
+
+            (
+                y,
+                y_error,
+                dense_info,
+                solver_state,
+                solver_result,
+            ) = history_extrapolation_implicit(
+                implicit_step,
+                terms,
+                dense_interp,
+                solver,
+                delays,
+                t0,
+                y0_history,
+                state,
+                args,
             )
-            history_vals = []
-            for delay in delays:
-                delay_val = delay(state.tprev, state.y, args)
-                history_val = history.evaluate(delay_val)
-                history_val.append(history_val)
-            history_vals = tuple(history_vals)
-
-            is_vf_wrapper = lambda x: isinstance(x, VectorFieldWrapper)
-
-            def _apply_history(x):
-                if is_vf_wrapper(x):
-                    vector_field = jtu.Partial(x.vector_field, history=history_vals)
-                    return VectorFieldWrapper(vector_field)
-                else:
-                    return x
-
-            terms_ = jtu.tree_map(_apply_history, terms, is_leaf=is_vf_wrapper)
-            # TODO: write down implicit problem wrt dense_info, using `terms_`
-            (y, y_error, dense_info, solver_state, solver_result) = terms_  # ...
 
         # e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
         # we get a negative value for y, and then get a NaN vector field. (And then
         # everything breaks.) See #143.
         y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)
 
-        # TODO: handle discontinuity detection for delays
         error_order = solver.error_order(terms)
         (
             keep_step,
@@ -274,12 +293,66 @@ def _apply_history(x):
             error_order,
             state.controller_state,
         )
-        assert jnp.result_type(keep_step) is jnp.dtype(bool)
+
+        # Finding all of the potential discontinuity roots
+        # if delays is not None:
+        #     _part_maybe_find_discontinuity = ft.partial(
+        #         maybe_find_discontinuity,
+        #         tprev,
+        #         tnext,
+        #         dense_info,
+        #         state,
+        #         delays,
+        #         solver,
+        #         args,
+        #     )
+
+        #     tsearch = jnp.linspace(tprev, tnext, delays.sub_intervals)
+        #     batch_tprev, batch_tnext = tsearch[:-1], tsearch[1:]
+        #     vmap_maybe_find_discontinuity_wrapper = jax.vmap(
+        #         _part_maybe_find_discontinuity, (None, 0, 0)
+        #     )
+        #     if delays.recurrent_checking:
+        #         (
+        #             tnext_candidate,
+        #             batch_discont_update,
+        #         ) = vmap_maybe_find_discontinuity_wrapper(
+        #             False, batch_tprev, batch_tnext
+        #         )
+        #     else:
+        #         (
+        #             tnext_candidate,
+        #             batch_discont_update,
+        #         ) = vmap_maybe_find_discontinuity_wrapper(
+        #             keep_step, batch_tprev, batch_tnext
+        #         )
+
+        #     proxy_tnext = jnp.where(batch_discont_update, tnext_candidate, jnp.inf)
+        #     proxy_tnext = jnp.min(proxy_tnext)
+
+        #     tnext, discont_update = jax.lax.cond(
+        #         jnp.isinf(proxy_tnext),
+        #         lambda: (tnext, False),
+        #         lambda: (proxy_tnext, True),
+        #     )
+
+        #     # Count the number of steps in DDEs, just for statistical purposes
+        #     num_dde_implicit_step = state.num_dde_implicit_step + (
+        #         keep_step & implicit_step
+        #     )
+        #     num_dde_explicit_step = state.num_dde_explicit_step + (
+        #         keep_step & jnp.invert(implicit_step)
+        #     )
+
+        #     assert jnp.result_type(discont_update) is jnp.dtype(bool)
+
+        # assert jnp.result_type(keep_step) is jnp.dtype(bool)
 
         #
         # Do some book-keeping.
         #
-
+        # discont_update = False
+        num_dde_explicit_step = num_dde_implicit_step = 0
         tprev = jnp.minimum(tprev, t1)
         tnext = _clip_to_end(tprev, tnext, t1, keep_step)
 
@@ -320,6 +393,8 @@ def _apply_history(x):
         dense_ts = state.dense_ts
         dense_infos = state.dense_infos
         dense_save_index = state.dense_save_index
+        discontinuities = state.discontinuities
+        discontinuities_save_index = state.discontinuities_save_index
 
         def save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
             if subsaveat.ts is not None:
@@ -374,6 +449,17 @@ def maybe_inplace(i, u, x):
             else:
                 return x.at[i].set(u, pred=keep_step)
 
+        # def maybe_inplace_delay(i, u, x):
+        #     # Annoying hack. We normally call this with `x` wrapped into a buffer
+        #     # (from Equinox's while loops). However we do also first trace through to
+        #     # see if we can resolve some values statically, in which case normal JAX
+        #     # arrays don't support the extra `pred` argument. We don't then use the
+        #     # result of this so we just skip it.
+        #     if _filtering:
+        #         return x
+        #     else:
+        #         return x.at[i].set(u, pred=discont_update)
+
         def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
             if subsaveat.steps:
                 ts = maybe_inplace(save_state.save_index, tprev, save_state.ts)
@@ -403,6 +489,21 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
             )
             dense_save_index = dense_save_index + keep_step
 
+        # Updating discontinuity
+        # if delays is not None:
+        #     if delays.recurrent_checking:
+        #         eqxi.error_if(
+        #             discontinuities_save_index,
+        #             discontinuities_save_index >= delays.max_discontinuities,
+        #             "the number of discontinuities detected reached the number of"
+        #             " `max_discontinuities`, please raise its value.",
+        #         )
+
+        #     discontinuities = maybe_inplace_delay(
+        #         discontinuities_save_index + 1, tnext, discontinuities
+        #     )
+        #     discontinuities_save_index = discontinuities_save_index + discont_update
+
         new_state = State(
             y=y,
             tprev=tprev,
@@ -415,9 +516,13 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
             num_accepted_steps=num_accepted_steps,
             num_rejected_steps=num_rejected_steps,
             save_state=save_state,
+            num_dde_explicit_step=num_dde_explicit_step,
+            num_dde_implicit_step=num_dde_implicit_step,
             dense_ts=dense_ts,
             dense_infos=dense_infos,
             dense_save_index=dense_save_index,
+            discontinuities=discontinuities,
+            discontinuities_save_index=discontinuities_save_index,
         )
 
         if discrete_terminating_event is not None:
@@ -496,7 +601,7 @@ def diffeqsolve(
     t0: Scalar,
     t1: Scalar,
     dt0: Optional[Scalar],
-    y0: PyTree,
+    y0: Union[PyTree, Callable],
     args: Optional[PyTree] = None,
     *,
     saveat: SaveAt = SaveAt(t1=True),
@@ -504,8 +609,8 @@ def diffeqsolve(
     adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(),
     discrete_terminating_event: Optional[AbstractDiscreteTerminatingEvent] = None,
     max_steps: Optional[int] = 4096,
-    delays: Optional[Sequence[Callable]] = None,
     throw: bool = True,
+    delays: Optional[Delays] = None,
     solver_state: Optional[PyTree] = None,
     controller_state: Optional[PyTree] = None,
     made_jump: Optional[Bool] = None,
@@ -631,7 +736,7 @@ def diffeqsolve(
         )
         with jax.ensure_compile_time_eval():
             pred = (t1 - t0) * dt0 < 0
-        dt0 = eqxi.error_if(dt0, pred, msg)
+        dt0 = eqxi.error_if(jnp.array(dt0), pred, msg)
 
     # Backward compatibility
     if isinstance(
@@ -712,6 +817,18 @@ def _promote(yi):
         _dtype = jnp.result_type(yi, *timelikes)  # noqa: F821
         return jnp.asarray(yi, dtype=_dtype)
 
+    if delays is None:
+        if callable(y0):
+            raise ValueError("`y0` is passed as a callable and should be an array.")
+        new_discontinuities = None
+        discontinuities_save_index = None
+        y0_history = None
+    else:
+        if not callable(y0):
+            raise ValueError("`y0` should be a callable.")
+        y0_history = y0
+        y0 = y0_history(t0)
+
     y0 = jtu.tree_map(_promote, y0)
     del timelikes, dtype
 
@@ -736,9 +853,6 @@ def _wrap(term):
         terms,
         is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm),
     )
-    if delays is not None:
-        delays = [lambda t, y, args, fn=fn: fn(t, y, args) * direction for fn in delays]
-
     # Stepsize controller gets an opportunity to modify the solver.
     # Note that at this point the solver could be anything so we must check any
     # abstract base classes of the solver before this.
@@ -777,9 +891,27 @@ def _check_subsaveat_ts(ts):
         else:
             tnext = t0 + dt0
     tnext = jnp.minimum(tnext, t1)
+
+    if delays is not None:
+        terms_ = bind_history(
+            terms,
+            delays,
+            None,
+            None,
+            solver,
+            direction,
+            t0,
+            tprev,
+            tnext,
+            y0_history,
+        )
+    else:
+        terms_ = terms
+
+    # Got to init the solver
     if solver_state is None:
         passed_solver_state = False
-        solver_state = solver.init(terms, t0, tnext, y0, args)
+        solver_state = solver.init(terms_, t0, tnext, y0, args)
     else:
         passed_solver_state = True
 
@@ -814,6 +946,8 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
     num_steps = 0
     num_accepted_steps = 0
     num_rejected_steps = 0
+    num_dde_explicit_step = 0
+    num_dde_implicit_step = 0
     made_jump = False if made_jump is None else made_jump
     result = RESULTS.successful
     if saveat.dense:
@@ -822,8 +956,22 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
                 "`max_steps=None` is incompatible with `saveat.dense=True`"
             )
         (_, _, dense_info, _, _,) = eqx.filter_eval_shape(
-            solver.step, terms, tprev, tnext, y0, args, solver_state, made_jump
+            solver.step, terms_, tprev, tnext, y0, args, solver_state, made_jump
         )
+        if delays is not None:
+            if delays.initial_discontinuities is not None:
+                buffer = jnp.full(
+                    delays.max_discontinuities - len(delays.initial_discontinuities),
+                    jnp.inf,
+                )
+                new_discontinuities = jnp.concatenate(
+                    [delays.initial_discontinuities, buffer]
+                )
+                discontinuities_save_index = len(delays.initial_discontinuities) - 1
+            else:
+                new_discontinuities = jnp.full(delays.max_discontinuities, jnp.inf)
+                discontinuities_save_index = 0
+
         dense_ts = jnp.full(max_steps + 1, jnp.inf)
         _make_full = lambda x: jnp.full((max_steps,) + jnp.shape(x), jnp.inf)
         dense_infos = jtu.tree_map(_make_full, dense_info)
@@ -849,6 +997,11 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
         dense_ts=dense_ts,
         dense_infos=dense_infos,
         dense_save_index=dense_save_index,
+        # dde specific State arguments
+        num_dde_explicit_step=num_dde_explicit_step,
+        num_dde_implicit_step=num_dde_implicit_step,
+        discontinuities=new_discontinuities,
+        discontinuities_save_index=discontinuities_save_index,
     )
 
     #
@@ -871,6 +1024,7 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
         throw=throw,
         passed_solver_state=passed_solver_state,
         passed_controller_state=passed_controller_state,
+        y0_history=y0_history,
     )
 
     #
@@ -920,6 +1074,8 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
         "num_accepted_steps": final_state.num_accepted_steps,
         "num_rejected_steps": final_state.num_rejected_steps,
         "max_steps": max_steps,
+        "num_dde_implicit_step": final_state.num_dde_implicit_step,
+        "num_dde_explicit_step": final_state.num_dde_explicit_step,
     }
     result = final_state.result
     sol = Solution(
diff --git a/docs/api/delays.md b/docs/api/delays.md
new file mode 100644
index 00000000..694d2afa
--- /dev/null
+++ b/docs/api/delays.md
@@ -0,0 +1,42 @@
+# Delays
+
+Delays allow to model a broader class of differential equations, Delay Differential Equations (DDEs). 
+Compared to ODEs, DDEs vector fields need a new argument `history` that integrate delayed states.
+
+At the moment only DDEs with known delays is supported.
+
+!!! example
+    The equation's RHS $y'(t) = y(t) + y(t-1) - y(t-2)$ is modelled by
+    ```
+    def vf(t,y,args,history):
+        return y + history[0] - history[1]
+    ```
+
+The first element of a `PyTree` of delayed states in a vector field's definition would be `history[0]`. If the state is multidimensional, then `history[0][i]` refers to the $i$th dimension of the first delayed state.
+    
+
+
+::: diffrax.Delays
+    selection:
+        members:
+            - __init__
+
+!!! example
+    Pytree of three `Delays.delays` : 
+    ``` 
+    delays=[lambda t, y, args: 1.0, (lambda t, y ,args: min(t,2), lambda t, y, args : 1/2 * jnp.cos(y))]
+    ``` 
+ 
+!!! info
+    If `recurrent_checking=True`, then at integration step, a so-called artificial event function $g$ checks for any new discontinuity jumps unconditionally (which are $g_i$'s new odd multiplicity roots). Let $\lambda_i$ be the combined intial discontinuity jumps and the ones found up to the current integration step.
+
+    $\begin{align}
+    g_i(t,y(t)) = t - \tau(t,y(t)) - \lambda_i, \quad i = \dots, -2,-1,-,1,2,\dots \\
+    \end{align}$
+
+    This can be prohibitely expensive but in some cases can speed up the integration and its accuracy. 
+    We refer to this [paper](http://www.cs.toronto.edu/pub/reports/na/hzpEnrightNA09Preprint.pdf) for more information on how to detect new discontinuity jumps in a DDE.
+
+!!! note
+    If the history function is continuous and the initial time point induces a discontinuity `max_discontinuities = jnp.array([t0])`.
+
diff --git a/examples/dde.ipynb b/examples/dde.ipynb
new file mode 100644
index 00000000..5546d4d6
--- /dev/null
+++ b/examples/dde.ipynb
@@ -0,0 +1,261 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Delay Differential Equations : DDEs"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This example demonstrates the use of DDE solvers to handle delayed system of equations. In this case, we model the dimensionless [delay-logistic equation](https://www.math.fsu.edu/~bertram/lectures/delay.pdf). \n",
+    "\n",
+    "$$ \\frac{dy}{dt} = \\alpha y \\left(1 - y(t-\\tau) \\right) $$\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/dde.ipynb)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 64,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import time\n",
+    "\n",
+    "import diffrax\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import matplotlib.pyplot as plt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Using 64-bit precision is important when solving problems with tolerances of 1e-8 (or smaller)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 65,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "jax.config.update(\"jax_enable_x64\", True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In order to model a delayed system we need to instantiate its vector field and its history function $\\phi(t) = y(t<0)$ (i.e. initial condition)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 66,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def delay_logistic_vf(t, y, args, *, history):\n",
+    "    alpha = args\n",
+    "    return alpha * y * (1 - history[0])\n",
+    "\n",
+    "\n",
+    "def history_function(t):\n",
+    "    return 2.0"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A DDE's vector field holds an additional argument `history` compared to an ODE's. The `history` variable refers to the delayed terms of your equation (here $y(t-\\tau)$). "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We need to instantiate the `Delays` object that encompasses all of the information needed to integrate properly our DDE. The arguments `initial_discontinuities` also need to be given. It accounts for the discontinuities found in the history function.   \n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In our case we only have 0 since $y\\prime(t=0^{-}) \\neq y\\prime(t=0^{+})$ because $y\\prime(t=0^{+}) = - 2 \\alpha$ and   $y\\prime(t=0^{-}) = 0$.  \n",
+    "We choose $\\tau=1$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 67,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "delays = diffrax.Delays(\n",
+    "    delays=[lambda t, y, args: 1], initial_discontinuities=jnp.array([0.0])\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Since DDEs require dense solutions we need to set `diffrax.SaveAt`'s argument `dense=True` in `diffrax.diffeqsolve`. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 68,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@jax.jit\n",
+    "def main(alpha):\n",
+    "    terms = diffrax.ODETerm(delay_logistic_vf)\n",
+    "    t0 = 0.0\n",
+    "    t1 = 20.0\n",
+    "    ts = jnp.linspace(0, 20, 200)\n",
+    "    solver = diffrax.Bosh3()\n",
+    "    stepsize_controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)\n",
+    "    sol = diffrax.diffeqsolve(\n",
+    "        terms,\n",
+    "        solver,\n",
+    "        t0,\n",
+    "        t1,\n",
+    "        ts[1] - ts[0],\n",
+    "        y0=history_function,\n",
+    "        saveat=diffrax.SaveAt(ts=ts, dense=True),\n",
+    "        stepsize_controller=stepsize_controller,\n",
+    "        delays=delays,\n",
+    "        args=alpha,\n",
+    "    )\n",
+    "    return sol"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Do one iteration to JIT compile everything. Then time the second iteration."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 69,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Integration took in 0.0004687309265136719 seconds.\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGeCAYAAABGlgGHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB2aUlEQVR4nO29eZhcZZn3/z219lZdvaW3pLNANkL2CCFBDWIkZFCJCyI6BhxgXIID4owOvirz6vubOMMgzigCLhAVWZXFAQRDMARIWLJBErJv3Unva1VXd9f6/P449Zyq7nRV16k6+7k/19XXBZ3q7tP91DnP97nv733fAmOMgSAIgiAIQiccel8AQRAEQRD2hsQIQRAEQRC6QmKEIAiCIAhdITFCEARBEISukBghCIIgCEJXSIwQBEEQBKErJEYIgiAIgtAVEiMEQRAEQegKiRGCIAiCIHTFpfcF5EIikUBrayt8Ph8EQdD7cgiCIAiCyAHGGILBIBobG+FwZIl/MBn84he/YAsWLGA+n4/5fD52ySWXsBdeeCHr1zzxxBNszpw5zOv1svnz57Pnn39ezo9kjDHW0tLCANAHfdAHfdAHfdCHCT9aWlqy7vOyIiNTpkzBj3/8Y8yaNQuMMfz2t7/F1VdfjT179uDCCy885/Xbt2/Hddddh40bN+LjH/84HnnkEaxbtw67d+/G/Pnzc/65Pp8PANDS0oLy8nI5l0wQBEEQhE4EAgE0NTVJ+3gmBMYKG5RXVVWFu+66CzfeeOM5/3bttdciFArhueeekz53ySWXYPHixbj//vtz/hmBQAB+vx8DAwMkRgiCIAjCJOS6f+dtYI3H43jssccQCoWwYsWKcV+zY8cOrF69etTn1qxZgx07dmT93uFwGIFAYNQHQRAEQRDWRLYY2bdvH8rKyuD1evHVr34VTz/9NObNmzfua9vb21FXVzfqc3V1dWhvb8/6MzZu3Ai/3y99NDU1yb1MgiAIgiBMgmwxMmfOHOzduxdvvfUWvva1r+H666/H+++/r+hF3XHHHRgYGJA+WlpaFP3+BEEQBEEYB9mlvR6PBzNnzgQALFu2DO+88w7++7//Gw888MA5r62vr0dHR8eoz3V0dKC+vj7rz/B6vfB6vXIvjSAIgiAIE1Jw07NEIoFwODzuv61YsQJbtmwZ9bnNmzdn9JgQBEEQBGE/ZEVG7rjjDqxduxZTp05FMBjEI488gq1bt+Kll14CAKxfvx6TJ0/Gxo0bAQC33norVq1ahbvvvhtXXXUVHnvsMezcuRO//OUvlf9NCIIgCIIwJbLESGdnJ9avX4+2tjb4/X4sXLgQL730Ej72sY8BAJqbm0d1WFu5ciUeeeQRfO9738N3v/tdzJo1C88884ysHiMEQRAEQVibgvuMaAH1GSEIgiAI86F6nxGCIAiCIAglIDFCEARBEISukBghCIIgCEJXSIwQBEEQBKErJEYIgiAU5s/vtmLLwY6JX0ioTiyewNbDnegLRfS+FCILJEYIgiAUZMvBDvzTo3vwld/vwsBQVO/LsTX9QxF8edM7uOGhd/C9Z/frfTlEFkiMEARBKERgJIr/87S46cUSDNuPd+t8RfblVHcIn/z5G3jtqLgGO473wASdLGwLiRGCsAD7zw7gl9uOIxJL6H0ptmbjCwfRHhiR/n/bURIjenHPy0fQ3DuEpqpieFwO9IYiONUzpPdlERkgMUIQJocxhn96dA/+/YVD2LT9pN6XY1veO9OPR98WJ4x/ZdV5AIDXj3XpeUm2Zt+ZAQDA/7duAeY3is229jT36XlJRBZIjBCEydl/NoAT3SEAwK9eO4mRaFznK7Inb57oAQCsvqAO/3T5LLidAlp6h3G6J6TzldmPUDiGk8m/+7zGciydWgkA2E1ixLCQGCEIk/Ps3rPSf3cFw3hiZ4uOV2NfjnQMAgDmTy5HqdclbYCUqtGeQ+1BMAbU+ryoKfNi6bSkGDndr++FERkhMUIQJiaeYPjf91oBAKtmTwIAPPDqCfKO6MDRjiAAYHadDwDwoVk1AIDXj1KqRmvebwsAEKMiACRheKg9gFA4ptt1EZkhMUIQJubtk73oCIRRXuTCz76wBDVlXpztHx4VLSHUJ5FgONopRkZm15UBAD44SxSH24/3IBYncagl77cmxUiDKEbq/UVo8BchwYD3kl4SwliQGCEIE/Pnd8WoyNr5DSgvcuMLy6cCEDdAQjvO9g9jKBKH2ylgWnUpAGDBZD/8xW4ER2LYd5Y2QC0ZGxkBQL4Rg0NihCBMSjzB8Jf9bQCATy5uBJA6lTf3UgmjlhztFFM059WUwe0UH6tOhyCdzE9TSalmxOIJHGobHRkBgCVTKwBQRY1RITFCECaltX8Y/UNReFwOLJ9RBQCYViWeymnz0xZuXp2VFIOchooiAEDbwMg5X0Oow6meEMKxBEo8TilKBQBLkpGRPc39Ol0ZkQ0SIwRhUng577SqEriSp/Gp1SUAgO7BMBn1NORI+2jzKqfBL4qR9oFhza/JrrzfJq7F3HofnA5B+vycenFtekIRujcMCIkRgjApp5JiZHpN6vTnL3bDX+wGALT0UXREK450ji9G6v3FACgyoiWSeTXNLwIAZV4XSj1OAEBnMKz5dRHZITFCECblZFKMzEgTIwAwLRkdoVSNNiQSDMfGVNJwGsopTaM13Lx6QUP5Of9Wm1yPjgCth9EgMUIQJuVUssPk9OrRYqSpShQjLWRi1YSWviGMRBPwuByjPAqAWFIKkBjRkkPZxIjPC4AiI0aExAhBmJRUmqZk1OenVVFkREu4efX8SWWjPApAyjPSPRimRnQaEIklJKHB74N06pKRkU6KjBgOEiMEYUKi8QRa+kRT5Ng0zdTkQ5jKe7XhiNR5teycf6sq9cCTNBdTakB9+N/Y43KgqtRzzr/XlXtHvY4wDiRGCMKEnOkbRjzBUOR2oM5XNOrfeEUNiRFtONElRqhmTjpXjAiCIKVq2mkDVJ3WflGgN/iLIAjCOf9eJ3lGKE1jNEiMEIQJkVI01aVwjEkN8MjImb4hxBNM82uzG+0BcQNsrCge99/JN6Id/G/M02NjIQOrcSExQhAmJFMlDQA0+IvhdgqIxhnaqL+F6rRPsAE2Uq8RzWhN/o0b/eMLwzoysBoWEiMEYUKkSppxxIjTIaCpMpmqIROr6nAxUp9BjFCvEe1o608Kw4rskREysBoPEiMEYUKkyEj1uWIESJX3km9EXYIjUYQicQCZxUiqCyttgGrDI4ENGSIjvLQ3FIljkLqwGgoSIwRhQrJFRoC0xmckRlSFC4zyIhdKPK5xX0OeEe3gf+PGDJGRUq8LPq+4TuQbMRYkRgjCZERiCZxNlvWO7THCkcp7KU2jKm0TpGgAioxoScrAOn5kBABqqbzXkJAYIQiT0dw7hAQDSj1OTCrzjvsa6jWiDbxctz7L5seFSmdwBLE4NT5Ti5FoHL2hCIDMBlYgvfEZmViNBIkRgjAZp5MpmmnVpeP2UgBSJ8POIJ3+1ESqpCnPHBmpKfXC5RCQYFTFoSY8KlLicaK8ePyUGZAmRujeMBQkRoiC2XW6F++29IMx6mmhBfw0nqmvBQBUloqTe/tCUVoXFeFrUZclTeNwCNIGSL4R9WiboOEZh5tYqfGZscgsHwkiBw62BfDZ+3eAMWBuvQ83f+g8fGbZFL0vy9LwhyhvbT0e1aXiv0XiCQyGY/AVuTW5NrsxUY8RToO/CGf7h8k3oiKtOfhFAGp8ZlQoMkIUxCuHOsEP3ofag/jWk+9K3UEJdeA9EuqypAaKPU4UucXbuy8U1eS67EguBlYAaKjgvUao8ZlapEdGssFFPHlGjAWJEaIgth3pAgD8y5o50sjuvS39Ol6R9emQxEjmyAiQio70hOihqxZ8LeqzCEPx36mCQ23aArzhWfbIiDSfhjwjhoLECJE3g+EYdjf3AQA+vrABy2dUAQDeOzOg52VZHp6mqZ1gA5R8I0MR1a/JjqRXb0x0Gq8oESfI9g9RlEoteGSkcaLIiC9VTUN+KuNAYoTIm7dO9CAaZ5hWXYJp1aVYOMUPANh3tl/fC7M4vCJj7LTesVTxyMggiRE14GF+r8sBf3F2T04lFyPDJEbUQuoxMkFkhPcZGY7GEaQurIaBxAiRNzxF86FZNQAgiZH9ZwM0LVYlovGElHaZKE1TVUKRETVJtR7PXr0BABXJtRigyIhqtOYYGSlyO1FeJNZu0Iwa40BihMib1452AwA+NGsSAGBGTRlKPU4MR+M43jWo56VZlu7BMBgD3E5BOm1nQoqMhEiMqEGq4Vn2zQ8AKopJGKpJKBxDYESMckwUGQGASTS913CQGCHyoqV3CCe6Q3A6BKw4vxqAOC32wkYxOkK+EXWQ/CK+Ijgc2U/jVVKvEdoA1UCa1juBdwcA/MnICKVp1IGnaHxeF8q8E3es4EI+QOthGEiMEHnxxjExKrJ0agXK03pYLJBSNSRG1IBXY9ROkKIBUpGRXhIjqpAq6534JM4NrAND1IRODbqSEY5JOdwXACSPDxmKjQOJESIvDrUHAQBLplaO+jz3jbx3pl/rS7IFPMfNu0hmo6pU3ABJjKhDrg3PgFSaJhJPYDgaV/W67Ej3oChGajLMahoLFyMDFBkxDCRGiLzgI+xnjBlhv2CyKEYOtAZoKJgKpLqvTrwBkhhRl/Ycms9xSjxOeJzi45ZO48rDxUimwZFjobSZ8SAxQuTFyWSX1enVo8XI9OpS+LwuhGMJHO0kE6vSdMjYAEmMqEuuzecAQBAEaQMkE6vypCIj2U3dHIqMGA8SI4RsovEEzvSJZXRjIyMOh4ALJ4udWMk3ojwdQW5gzT1NExiJIUpRKkVhjEn9WyblsBZAKlVD5b3K0x0U14LSNOaFxAghm5beIcQTDMVu57inQi5QzvbTHA6lyWUuDcdf7AYvuKHTuLIEwzFEkgIv1w2wglIDqsF771TLXAsShsaBxAghG+4XmVZdMm6zJ75R0oRS5ZGTpnE6BKmKg1I1ytKdjFCVeV0ocjtz+hp/MbWEV4uuQR4ZoTSNWSExQsjmZPcQgHNTNBzed6GduhsqSjgWR19yI8vFpwCQb0QteCO56hw3PyA9MkJroTRcHNbkmDIjMWI8ZImRjRs34qKLLoLP50NtbS3WrVuHw4cPZ/2aTZs2QRCEUR9FRROf6gjjcoqbVzOIkTo/RUbUgM9C8eQwC4VTRZERVZA2vxzTAkDKM0KREWVhjMmvppGiVHRfGAVZYuTVV1/Fhg0b8Oabb2Lz5s2IRqO44oorEAqFsn5deXk52trapI/Tp08XdNGEvmQq6+Xwvgs0Ll1ZOoOp6o2JZqFweGSEurAqSzePjJTmHhmpLKUNUA0GwzGEY/L8O1zMB8MxmqNlECbum5vGiy++OOr/N23ahNraWuzatQsf/vCHM36dIAior6/P7woJw8HLeidK0/QNRTESjeecUyeyI/UYmWBabzp8A6T5NMrSMyjPMAlQ10+16E76RUo9ThR7cvXviGvBGBAciUreKkI/CvKMDAyIpZtVVVVZXzc4OIhp06ahqakJV199NQ4cOJD19eFwGIFAYNQHYQzCsbg0HXNsjxGOv9gNr0t8a/HUAlE4csyrnGqKjKhCKi2Qj2eExIiSdOchDD0uB0qSwoV8I8YgbzGSSCRw22234dJLL8X8+fMzvm7OnDl48MEH8eyzz+Lhhx9GIpHAypUrcebMmYxfs3HjRvj9fumjqakp38skFKaldwgJJlYRZHKuC4IgTTIlE6tySEPycjSvAhQZUQveY0TOBlhRnJpPQyhHj8yGZxwysRqLvMXIhg0bsH//fjz22GNZX7dixQqsX78eixcvxqpVq/DUU09h0qRJeOCBBzJ+zR133IGBgQHpo6WlJd/LJBSGV9JMrxm/rJdTRxU1iiN3/gaQFhkhn4KipMQIVdPoTaqsN/f7AqC0mdGQ5Rnh3HLLLXjuueewbds2TJkyRdbXut1uLFmyBMeOHcv4Gq/XC69X3huL0IaT3WKL90wpGg73jXRQRY1i9MisGADSIiODtAEqST7CkG9+fcnJvbmakInsyC3r5VBkxFjIiowwxnDLLbfg6aefxiuvvIIZM2bI/oHxeBz79u1DQ0OD7K8l9GeiHiMcnqZpIzGiGPn0tqDIiDrInYUCpIRhJJbASJTa8ytFPsIQIDFiNGRFRjZs2IBHHnkEzz77LHw+H9rb2wEAfr8fxcXFAID169dj8uTJ2LhxIwDghz/8IS655BLMnDkT/f39uOuuu3D69GncdNNNCv8qhBY09/Luq9nFCE/TUHmvcuTjU6hMa3pGp3FlCMfiCIzEAMjbAEs9TrgcAmIJhv7hCIo9xWpdoq3IRxgCaS3hSYwYAlli5L777gMAXHbZZaM+/9BDD+GGG24AADQ3N8PhSAVc+vr6cPPNN6O9vR2VlZVYtmwZtm/fjnnz5hV25YQu8EjH5IrsD1Lqwqos6Y2d5PS24E3PonGGwXAMvqLcmqURmeEN5FwOAeUy/p6CIKCixI3uwQj6h6Jo8JMYUYLuAj0jJEaMgSwxwtjEzWG2bt066v/vuece3HPPPbIuijAmjDGpqypvbJaJer/4YKAurMoQisSlxk5y0jTFHieK3A6MRBPoH4qSGFEAHqGqKvXA4ZAXafIXp8QIoQyFpmmoCZ0xoNk0RM4ERmIYisQBpDwhmeBpms7gCBLU4bBguHm12O1EiUee75yf3gMjtAEqQVeemx8AqbkWbYDK0SNzSB7Hn1wLiowYAxIjRM7wKEdFiXvCrqq1yS6h0ThDLz14CyYf8yrHVySKl8BwTNFrsiv5lPVypPk0tAEqwkg0jsFw0r9D1TSmhsQIkTPc/1GfQwdQj8shnVQoVVM4+ZhXOeV8DgdFRhQh37QAkB4ZobVQgq5ganikzysvYkh9RowFiREiZ9oHxDbwE6VoOPU0ME8xpC6TMsyrHJ+UpqHIiBLk2/EToMZnStOddl/IrRTjUaoARUYMAYkRImfacjSvcqiiRjkKSdOUJ9M0FBlRhkKiVHwDpJbwyiBV0shM0QCUpjEaJEaInOmQ0jS5lSRKLeEpTVMw+QwD40iREfKMKEJhBlbehZUiI0pQSMqMi5FQJI5onJrQ6Q2JESJneGSEl+1ORD2JEcWQTuN5pGnKiykyoiSFGFj95BlRFN7zpSqv+yJV5k7REf0hMULkTLskRnKLjPDpsvz0QuRPTyj/EyCV9ipLyqdAqQG94WIkH5HudAhSpRmJQ/0hMULkjFzPCK8c6KMbvWDSG23JxSd5RihNUyiJBJM2wBofrYXe9BUQGQFIHBoJEiNETgxH4tINm2s1TWUJDWlTisIMrBQZUYrASBSxZBO/vFIDSTHCe2MQhcHvi8o8xQj38FBFjf6QGCFyglfElHicOdfzV3KzXojESCGMOo3nZWCl07hS8OoNX5ELXlf2xn/jUeYV74nBcCyn8RpEdvhBh89gkovUa4RKrXWHxAiRE21pPUZyrefnp5XASAwxcqvnzcBwFPHkabwyj4duOfVTUAy++eXjUQBSwjCeYBiOxhW7LrvSW2BkxE+l1oaBxAiRE7kOyEungtzqisDNq/5iNzwu+bcsRUaUo9DNr8TjBNfyg7QeBVO4ZyRZ3UTPJ90hMULkBE/T1OXQCp7jcjqkjZBMrPnTXUApKZDyjARHKDVQKFIpaZ5pAUEQUJZMc1JH3MIYicYRSg7uzFeMSB4eWgvdITFC5EQ+kREglVagKaX5I00lzaOUFEhFRiLxBMIxSpcVQqGRESAlDsnEWhg8ZeZ0CJKokAsXhrQW+kNihMiJNpk9RjjcxNpLJta84WmafCMjpR4XHMnUAFXUFEahaQEgtQFSE7rCkIRhify5NJwynsIkMaI7JEaInOiQMbE3HX6CpKZC+VNomsbhSEsNUEv4gugdKlyM+Cg1oAh9IfGZUlXqnuCVmZEiI7QWukNihMgJuQ3PONRrpHB6eWQkzzQNkKqoodN4YRTqGQHoNK4UXBjmU2HG8VHfF8NAYoSYkGg8IbXAlmNgBdIHg9EmmC+FzELhSMPy6ARYEH0KeEZSaRpai0LoK6ARIEfq+0JroTskRogJ6Q1FwJhoFJMbniYDa+GkhuQVEBkpIp+CEqTSNPmnBrgwpA2wMHpChUdGyigyYhhIjBAT0hXkaQIPnA55RrFKGpleMN3JNE1hPgXe+IweuoXAfQpKpAZIGBaGEmZiWgvjQGKEmBAuRib55J/MeTibP8QJ+SgRji4vpoduoYRjcekEXUiUykflpIqgiGckbS2oB4++kBghJqQgMUIG1oKIJ5jUHbKQEyANyyscLqjTR8/nAxlYlUGRMuvkWiQYqD2/zpAYISakK2lerc1DjJCBtTD6h0S/DjC6vb5cyqklfMGk+lq44ZCZrkzHl9YRl8ifXgXESLHbKfXgIQ+PvpAYISZEiciIuKlSGFQu/IFbUeKGy5n/7UobYOH0KZAWANJ7W5BALwQlxAi15zcOJEaICekMij1GJuUxvp4/uGMJRjnyPFCirwWQMurR5N78UWLzA2hwoRIwxlLisOD1oPb8RoDECDEhqciIvB4jAFDscaLILb7NqAurfJR64KaantEDN1+UFiO0+eXPYDiGaFyMtBYq1KkLqzEgMUJMSCFpGiAVHaH5NPLpUXgDJANr/igxJA+gzU8JuJm4yO1AscdZ0PdK9Rqhe0NPSIwQE1KoGKmgipq86VMoTVNOnpGC4e/fwlNmybRAJIZEgnxU+cDLegspseZQR1xjQGKEyEooHEMoIpa85R8ZER++lKaRT49Cp3HyjBSOUpERvhaMAaEIbYD5wOc1VRbQCZdDXViNAYkRIit8Jk2x24nSPMOh1Gskf6SGZwp5Rug0nj99Q8qshdflgCtZT0qn8fzoVaATLsdHaTNDQGKEyEp6ikYQ8uutwE8v1GtEPr3Jv5mSp3FqtpUffEZQoWshCAKZWAtEiYZnnDLqiGsISIwQWSnULwLQsLxC4OHowk/jTnhd4u1OLeHzQynPCJDWhZVO43mhRCt4DnXENQYkRois8O6r+fQY4VRQNU3eSIPZFDgB0rC8/GGMpa1F4T4Fn5cbikkY5oNS6UuAqpuMAokRIivKREbIwJovPXxirwInQBqWlz+hSByReAKAQhUclKYpCKWM3QD1fTEKJEaIrCiZpiEDqzyGI3GMRMUNsKqAib0cKTJCJ0DZ8JO4En0tgJRpktI0+aGsZyRp7qa10BUSI0RWFBEjpdwzQidyOfC8uMfpyLuSKZ3UsDxaB7n0KNTvhSOdxmkDzAslPSM+8owYAhIjRFaU8IxUSpN7KTIih16pesOddyVTOhSOzp8+BdMCAJkmC0XRyAh1YDUEJEaIrCgRGeHdP4ciccSSeXdiYvjpr0oBjwIAlHpIjOSLUnNpOKkpyrQByiWeYOhPNu9TYj2oz4gxIDFCZCSRYFLTs0LECD+RA5QjlwMv661SoHoDSDsB0hrIRirrVSoyQhtg3vQPRcCSffsqSpTtwMoYNQTUCxIjREYGhqPSZMzqAgyUrjTPAw1qyx3eZVKpyIiPmjvljVS9obBnhMS5fLgwLC9ywe0sfAvjwjAaZwjHKHKrFyRGiIxwv0hFiRteV2EGSt6OnHpc5E5qSJ4ykZFSEiN5o6RHASD/TiGkRLoya8HTlwCJQz0hMUJkhPtFagowr3LKpbJSiozkipK9FABK0xSCUkPyOLyclAys8lF6LRwOgVrCGwASI0RGuhWopOHwhls0NTZ3lOwyCaTC0TQpVj5KtoIH0tM0dD/IRamBhemQh0d/SIwQGZEiIwWYVzk+iozIRvnTOD1w80Xpahpai/zpVdi/A6SXWtPzSS9IjBAZ6VHwZM4bbpFnJHd6FT6N8w2QUgPyUb60l9IC+aL0WgAkDo0AiREiI90K9BjhSAZWiozkjPTQVaAVPJAysIZoA5RFel8LJYbkAalIIfXekY/SDegAEodGgMQIkRHuGalRYDMsl5o80c2eC/EEQ79KPgU6/cljYDgq9bVQKjXAT+IAbYByUTpiCIAMrAaAxAiRkVSahgysWhMYjiIhNXZS2sAaRyJBzZ1yhUeolOprAQAelwOe5PeiDVAeSpdZA2kpTBLquiHrztq4cSMuuugi+Hw+1NbWYt26dTh8+PCEX/fkk09i7ty5KCoqwoIFC/DCCy/kfcGEdnQraGCl0l55cCHoK3LB41JmAyxNO41TRU3uKN19lVPqFXv3hMJxRb+v1VG65B0Y3YWV0AdZT7lXX30VGzZswJtvvonNmzcjGo3iiiuuQCgUyvg127dvx3XXXYcbb7wRe/bswbp167Bu3Trs37+/4Isn1IMxhu7koDZF0jTU9EwWvQqX9QKA1+WA2ykO3KOHbu70DCq/+QG0AeaLGpERmk+jP66JX5LixRdfHPX/mzZtQm1tLXbt2oUPf/jD437Nf//3f+PKK6/Ev/zLvwAAfvSjH2Hz5s34+c9/jvvvvz/Py1aG91sD6AmFMa+hHNUK9NKwEsFwDJGksY6anmmP0mW9ACAIAkq9LvQPRcnEKgOle4xweOdPWovcGYnGEYqIkSRFPSMkDHWnoPjvwMAAAKCqqirja3bs2IHVq1eP+tyaNWuwY8eOjF8TDocRCARGfajBvz71Hr70m7ext6Vfle9vZniKpszrQpG7sFbwAHlG5KJGZASg3Hg+qCEMgTQPD22AOdM/JD4/nA5h1ADOQvGRwV538hYjiUQCt912Gy699FLMnz8/4+va29tRV1c36nN1dXVob2/P+DUbN26E3++XPpqamvK9zKzQSPXMKJmiAdIjI/S3zgU+sVfJxk4AVQ3kg9KdcDk0K0g+6Q3PHA5Bse+bEul0WNKLvMXIhg0bsH//fjz22GNKXg8A4I477sDAwID00dLSovjPANL7LpCBbCypsl5l0lfcMzIYjlFfhRyQhoEpJAY5dBqXD0VGjEOq4Zky/V44NCpBf/KKc91yyy147rnnsG3bNkyZMiXra+vr69HR0THqcx0dHaivr8/4NV6vF16v+h6OVKMbUsNj6UmKkWqFNsP0kOpgOKZYuapV4ZERxdM0NLpeNmr0tQDSqmkidBjKFb4WSkcM6WCqP7IiI4wx3HLLLXj66afxyiuvYMaMGRN+zYoVK7Bly5ZRn9u8eTNWrFgh70pVgD8MBukNeA5dUppGGVHodjpQnPSeUEXNxPQmc+PqPXRpDXJFjY6fQGotSBjmjhqVNED6XkBroReyIiMbNmzAI488gmeffRY+n0/yffj9fhQXFwMA1q9fj8mTJ2Pjxo0AgFtvvRWrVq3C3XffjauuugqPPfYYdu7ciV/+8pcK/yryKaVyrowonaYBRBPrcDROFTU5IEVGFE7T+MinIJtelfqMUJpGPmrMpQFoNo0RkBUZue+++zAwMIDLLrsMDQ0N0sfjjz8uvaa5uRltbW3S/69cuRKPPPIIfvnLX2LRokX44x//iGeeeSar6VUrfPQwyEiPgq3gOZKJlSpqJqR3kD90lU1X0rA8+fRx/45KkRF6/uSOeg3oxLUYjsYRp+7EuiArMsLYxIu0devWcz53zTXX4JprrpHzozRBioyQaekcuhVO0wDpw/Lo7z0R6vkUaAOUQzgWl6JIaq0FRalypyekjmekbEx3Yn5wIrTD1rNpKDSXGSlNo0AreE550jxJaZrsDEViGImKFUdKV9PQsDx58KiI0n0tgLTILB2GckYtz4jX5YAzWSpMQl0fSIyA3nzjwVtgK1nNkWoJT2IkG/xv73E5UOopvOFcOqnTOJm2cyHV18KtaF8LgNYiH9QqsxYEgfYDnbG1GKEw6fiMRFOhaWUjI5SmyYX09uOCoOwGmGp6RoIwF9TyKADpg/LofsgVvh5Kl7wD6fcGiUM9sLUYoXkE49OVbAXvcTmkULISUEv43OhRKRQN0HteLmp5FACKzMqFMSalzZSOjAAkDvXG3mKEHgbjwh/ANaXKnsxpWF5uqJUXB9Lf83T6ywU114Iis/IYTBveqbSZGKD10BsSI6A331j4kDwlUzRAumeE/t7ZUKuXAkCD8uSilkcBGH0YyqVS0e7wqEiR24Fihb1UAB1O9cbWYoQr4WicIRyjkyJHjYZnAEVGckXVNA09cGXRp1KJNZB6/iQYpOopIjNqlbtz+OBUujf0wd5iJE1dU9g6hZSmUbislDwjuaFFmmY4GqeBhTmgZpSqxJ16/gTJUDwh0n2h8HOJU0oNAXXF1mLElTYvhfoupOAG1mqFIyO+ZGSEUgTZUTMyUpre3IkE+ISoKUYcDkE6ENFaTEyvimZiACgjA6uu2FqMAGRaGo8u1dI0FBnJBf7QVaN80eNywOMSb3vqPDwxanpGAOqIKwc1hSFAk3v1xvZixEeljufADayTVDKwDkZiSND8h4yoNSWW46POwzmjpmcEIBO9HLhnRLXICO0FumJ7MUK15efSrcKQPCAl/BijvGw2elSMjAAUDcyV9L4WavkU+AZIz5+JUdNLBZC5W29IjHjowTwWPiRvksJpGq/LiSK3+JajVM34ROMJDAyr19gJoNN4roQicVX7WgD0/JGD6mkaWgtdsb0YoTTNaMKxuLQZKu0ZAVLlvQMkRsalf0j8uwiCBuFoStNkpTcpytXqawGQT0EOarbmB8i/oze2FyP0BhwNH9LmcgjwFys/RpuLP6qoGR9++qsodktTRJWGwtG5oXZfC4AqOOSgZmt+gLoT6w2JEepIOYr0hmdKTykF0st7KTIyHmpXbwBpXVhpA8yK2kZigPw7clDbM8L9g7QW+mB7MeKjU+IoJDHiU6mSgyIjWVGzrJdD0cDcUNujAFCUKlfiCYZ+yUulfMQWIC+V3thejEgPZuq5AADoDvLuq8r7RYCUZ4QiI+PTGxLFoJobIPmkckMLMULPn9wYGI6Cj+9RK01TSrOCdMX2YoQGh41GrYZnHIqMZKeXl5JqkaahNciK2n0tAEoT5woXhuVFLrid6mxbfC1iCYZwjEYlaA2JEQqTjkKtIXkcSYzQ33tctIiMUJomN9T2KABkYM0VLVNmAK2HHthejJCBbDR8Lo3SDc84ZGDNjtoVA0BaB1Z6z2dF0zQNVXBkRQtjt9MhSLPKaD20x/ZiJNUCmN58QCoyonQreA6PjAQoLD0uvJdCtUpiEEgT4LQGWVG7rwVAh6FcUbstP4fWQz9IjFCYdBRqdV/l0OTe7PA+L1Wl6vz9AZrBkStaRqnIwJodLaJUQNp+QOuhOSRGvMnhbfRgBpBe2qu2gZXSNOOhxQmQShhzQwvPCPl3ckOLtQAoMqInthcj1OgmRSSWkNqRUzWN9jDGUidAFdM0ZNqeGC36WgAkDHNFC88IQOJQT2wvRvjDIBJLIGLzcq6eZCWH0yGgQoVW8AD1GclGMBxDNC72N1A1MkIVTROiRV8LILX5jUQTiMXt/fzJhhat+YE0cUiHJc2xvRgppXIuiVTDM48qreABioxkg4eii91O1QazASTAc0GLvhZAKjILiFOCifHRojU/QGkaPbG9GHE7HfC6xD+D3d+AavcYAVKRkaFInE6CY+jRKi+eJnTsLsAzoZVh0utywu0UhT+tRWZ6NahsAtILGkgYao3txQhA7bE5andfBVIpAoD+3mPpHdTmgetyOqR+CrQG46OVRwEgn0Iu9GnQmRgASj1U3aQXJEZADwOOFpERd9pGGBi29997LFqd/gAKR0+EVn0tgNQGSB6e8QnH4tL7VHXPCB1MdYPECOhhwJE8IypN7OWkGp+RiTUdLSb2cigamB2t0jQAVTdNBI+KOB2C9L5VC1oL/SAxgpQatvsbkKdp1Gp4xiET6/homRqgqoHsaClGSqnxYlak+6LErZqxnkNRcv0gMQJSw5zuoLqt4Dk0n2Z89NgAKTIyPlpVbwDpKTMyTY5HrwadcDmUvtQPEiOgMd4cLTwjAEVGMqFlmoY6D2dHq74WQOp+sPthKBO8/5Ga85o4VE2jHyRGkB4ZsfcbUCsxQo3PxkfbNA2lBrKhVftxIOVZI2E4PnxeU7XKzyUgrZqG1kJzSIyAhiMBQDSeQJ/UCl4bAytFRkajaWSE1iArPDJCpb36wyMjNRquhd2LGfSAxAhSHobAsH1P6vz04XQIqudmfdSOfFy0jYxQmiYbWvV8AcizNhEpL5X6kRFaC/0gMQKx5TNg71MiT9FUl6rXCp5DBtZzSe+loI1nhNI0mRiOxKXW7Fr4FMjAmp1uKU2j3VoMReJIJJjqP49IQWIEQHlyKNyAjSMjWnRf5aT6jNBGyEnvpcA9NWpSRuHojPC0gMfpgM+rbl8LgIThROjRfwewd9peD0iMAPAnxYidm3Dxst4alct6gfTICN3sHL4BatFLASCfQjakza/MA0HQbi0oZTY+PTxqq8FByetywOngs4IoUqUlJEaQiozY2TPCQ6Fqm1eBdAOrff/eY9Fq9gZH6sBKgvAcejRMCwAkRiaiJ6TdegiCIA2SpPXQFhIjSEVGbJ2mCWrTfRWgaprx4JERrcQIGVgzw/1TWhgmATJNZiMci0vPCS3SNACth16QGEGq70VgJAbG7Gla0qrHCEB9RsZDy+6rAHVgzQY/iWtRSgpQyiwb/L5waeSlAmg99ILECIDyYvHNF08wyUVvN7gYUbsVPECRkfHQsskWQIPyspHyKGh7Eqe1OBeeMqvUoMqPQ2kzfSAxAqDY7YTbKb7R7eob0TIywg2sQ5E4YvGE6j/PDPRo2EsBGH36s2s0MBNadvwE0tICkTitxRh6NKyk4ZA41AcSIxBNS6lUjV3FSDI07dO2fI5ueBEpTVOiTSiaP3CjcYZwjARhOlpvgDxlFk/QWoylN6TdIYlDU5T1gcRIEqnXyJD9xIjYCp5X06h/07udDhS5xbcepWpEpMiUBmkyIDWDAyBBOJYejTdAWovM9GjYCZdDTej0gcRIEqm814abY28oAsYAh6DNmG4grQW/TSNRY0mVVmuzATocqRJGOgGORuvSXodDQAmtxbhoWdbLoWoafZAtRrZt24ZPfOITaGxshCAIeOaZZ7K+fuvWrRAE4ZyP9vb2fK9ZFXhLeDt6RnhZb3WZV2r4ozZkYh2Nlp4dDg3LOxfGmM6ncVqLdHrSxlRoBXlG9EG2GAmFQli0aBHuvfdeWV93+PBhtLW1SR+1tbVyf7Sq2LklvB4bIXVhTTESTfVS0KLpHIdKGM8lGI4hkjRVV2tkJgbST+OUGkhHazMxQPeFXsgevLB27VqsXbtW9g+qra1FRUWF7K/TCju3hNey+yqnnLqwSvBQtNspSO9DLfDRCfAc+OZX6nGiOJk60YJU3xe6H9LRs5qGZtNoi2aekcWLF6OhoQEf+9jH8MYbb2R9bTgcRiAQGPWhNlI1zbD93oBSjxENTx/lFBmRSIWivZrMQuGUUa+Rc+DVG1qexIGUiZVMk6PpCWnb8wUgA6teqC5GGhoacP/99+NPf/oT/vSnP6GpqQmXXXYZdu/enfFrNm7cCL/fL300NTWpfZm2bgnfpeGQPA7Np0nRrXGTLU5qAyQxwtFyXH06ZJocn16+HpqmzMhMrAeqz8eeM2cO5syZI/3/ypUrcfz4cdxzzz34/e9/P+7X3HHHHbj99tul/w8EAqoLEt6F1Z5pGu4Z0e4BzMWIHauXxqJ1JQ2njIblnUOPDpsfQD6F8RiOxKWO2FV6REbovtAU1cXIeFx88cV4/fXXM/671+uF16vtwyCVprGvGNGiFTzHR/NpJPQwEANUNTAePToIc4BSZuPBUzQep0PyN2kBVTbpgy59Rvbu3YuGhgY9fnRG7Jym6Q5qfzKnyEiK1N9fn9QAPXRT9Gg8I4hDaZpz6U3rMaKpl4oMrLogW24ODg7i2LFj0v+fPHkSe/fuRVVVFaZOnYo77rgDZ8+exe9+9zsAwE9/+lPMmDEDF154IUZGRvDrX/8ar7zyCv76178q91soAC/ttaOhkkp79UXrjp8cStOcS8q/QwZWvdGj3wtAKTO9kC1Gdu7ciY985CPS/3Nvx/XXX49Nmzahra0Nzc3N0r9HIhF861vfwtmzZ1FSUoKFCxfi5ZdfHvU9jIBdm57F4gn0atgKnkMG1hSpVvA6ncbpBCjBT+NaR6loHsq5pLqv6pO+FOc2xeF1aVfibWdki5HLLrss62TJTZs2jfr/b3/72/j2t78t+8K0hqdpguEY4gmmWSdSvUlvBa/lCYQ6sKbgaRqtTZP8oUtrkEIvAyulac5Fj+6rAKQxCYDYhI7EiDbQbJokPG0A2Ou03hFImVe1FGDlZGCV0CtNQ0a9c9GjrwVAazEeejQ8AwBX2iBPEofaQWIkicflQLFbVMB2anzWERgBANT6ijT9uRQZEYknWCo1oHGaxken8VGkr4VufUYoZSbB+x9pWeXHIXO39pAYScOOFTWdyRu+rlzbG55HooYiccSSs0DsSN9QBAkGCAJQpdHEZA4ZWEfTn1wLQPu1oN4W56JHywEOmVi1h8RIGnZsfMYjI5N0iowA9j598AduZYkHLqe2tyOlBkbD0wIVJW4d1oLPpqFqGo7UGVrj9CVA3Yn1gMRIGtKwPIqMqI47LS9r51SNXj1GgNGD8rKZ0u1Cd1AfwyRABtbxoDSNvSAxkgY3VdoqTaOTZwRIpWrsFIkai2SY1Lh6A0ilaRIMGI7SibwrGaXS417gUarhaBzxBAlDvVoOcKjUWntIjKTBG5/ZaXPUKzICkIkV0GdIIafY7QQvoKIToDFO4gCZWAH9Wg5waHKv9pAYSSOVprHPw0CvahqAurAC6UPytH/gCoJAxsk0unQ0THpdDriSypBO46m1qCrVtuUAhx+UaC20g8RIGrwLq13SNPEEkwyUekRGyqkLa9pgNu3//kC6V4FOgHpGRtKFIW2A+q4FkDKw0lpoB4mRNOyWpukZDCORDIVq3XIZoDQNkD4XSPvICJDWhTVsj/d8NqQNUGdhSKkBfSOGAFWa6QGJkTTKbdZnhHdfrSnTKRTqpS6sqYeuThsg9RqR0NO/A5BpMh29IyNU3aQ9JEbS4NU0dint7QyKfpG6cu39IgBFRoC0+Rt6p2nINJlqsqXTWtBpPIVx1oKiVFpBYiQN3vTMbpGRWp1OH6nSXns+fBljupomgbTUgE3XgBOLJ6SmZ3qvBZ3G9Y+MUJRKe0iMpMFLyPqH7CFGeGSkVvfIiD3+3mPpG4oiGhd7StAJUF/0LiUFqOtnOt0GMXbTWmgHiZE0+DyKvqEIEjZoPKR/ZMTeaRouBitL3PC49LkVUw9dewpCDu+3U62TfwqgNE06+kdGKEqlNSRG0qhIipEEs0dFDe++qp9nxN4G1s6Afh0/OT4ysAJI6zGi00kcAMooNSBBkRH7QWIkDY/LIc3r4KPErQw/DeoVGSm3fWQk+ffXoccLh9I0It06n8SBVGWT3Xu+ROMJ9CVT5RQZsQ8kRsZQWZpK1VidDp0jI7yU2r5ihE9M1vM0TmkaQN/uqxy+Adr1fuD0JMvdnQ4BFclnhNakqszitkjZGwESI2PgYqQ3ZO2Hc3r3Vb1O5jxFYIeU2Hh0BfVP01AHVhE9x9VzeNrS9sIwmGoE6NDJv5M+K2iIhkhqAomRMVSViA+EPounaUZ1X9WpeoA/fIciccTiCV2uQU/0TpMB6R1Y7X0a19swCUBKEdvdp6C3XwQAitwOaYgkpWq0gcTIGKTIiMXTNJ1pJ0GXU5+3AY+MAPZ8AHcF9PeMpDqw0mkc0FmM2NxDxTHCWowaImnDZ5MekBgZg1Tea/HIiDStV8eN0O10oMgtvgXt+ACW+rxQmkZ3jFFNQ2IESK2FnpERgJrQaQ2JkTGkPCPWFiPt3Lyq40YIpHdhtd/J3EhpGruf/oxwGk+VutNaAPquBZBWaWbz9dAKEiNjqLJJNU37gChGGir0FiP2PA0OhmMYiojRCCOU9oYiMdtWDYxE49L7zxhpGvsJ83SMFhmxu1DXChIjY6gssUdkpLU/KUb8xbpeh11Pg7zhXJnXhRKPa4JXqwffABmzb9UAP4l7XA6p940e8LUIxxKIxOxn6OZ0p1XT6IldD0p6QWJkDKnIiLVPJ+2BYQBAg1/fyEi5TU+DnQYJRXtdDriSZQN2DUen+0UEQZ9SUmB0OamdT+P83tCr/xGHIlXaQmJkDFWl4knd6pGRtmSapl5nMWLX04dRxAhVDaSdxHVeC5fTgRKP2BLerhsgY0z3Zowcn5f3fbHnfaE1JEbGwNM0A8NRy/a+YIylPCN6p2m89pxPw9M0eppXOXbPjRuhkoZj94qadC9VnY5eKsC+ByW9IDEyBn+xGzxS2z9szQ0yMJK64ev1Pn3Y9IY3QvdVjt1LGPn0ar2jVIB97wcOj4r4ivT1UonXwCv97LkWWkNiZAwupwP+Ymt3YeVRkYoSN4qTYWG9sOsNb4QheZwym2+APEqltzAHgLIie6cGuDDUO0UDkGdEa0iMjEOVxStq2gZE86oRHr52veFTDc8MIEZsnqbhPXfq/fqvhV0N3ZyUX0T/tbB7lEprSIyMg9Un96b8IkYSI/a64SlNYxz4/WCs07g918JYkRF7+tn0gsTIOHATq1XLe1ulhmf6mlcBoLzYnh1YDZWmsXlkxCjVGwCthZHWotzmwlBrSIyMg9XLe9uTaZoGA9zw3J8zYFGz8HiEY3H0J4WuEdI0pTau4AjH4tKhwxhpS3uKc44kRgxwX9jdS6U1JEbGQUrTWFSMGKXHCJASIwEbiZH0jp/899cT/tC1Y5qmM5Bai4oSA6yFzeehGCkyQmkabSExMg6SgdXynhH90zTpkRHG7DEbpSOtx4ieHT85PhunBtrTKmkMsRY2P41LnhEDHJT4WoQiccRtOrdJS0iMjIPVIyPtBoyMROMMwzaZjcIjU40GEIMAbN2BVboXDHASB4ByG5/GEwkmVZkZIzKS1p7fpuJQS0iMjEMqMmK9B0JwJIpgctMxghgp8Til2Sj9Fvx7j4eRxCCQStPY8YErRakMYCQG0tbChsKwbyiCaFyMQBihG67X5YTHJW6RdvXwaAmJkXGwcmQkvcNh+mAuvRAEwXYm1jYDlVYDQJlXbHxnxw2ww0ANzwB7p2l4iqa61COJAL2hihrtMMaKG4wqC4uR1n5jpQgA+1XUGC4ykpwPZEcDa3tyAzTKWqRMk/ZbCyOZVzk+m3fE1RISI+PA0zTBcAyRmLWG5RltIwRSvUbsIkZ4B1zjREaSpz8bPnA7DNTwDEgflGePeyEdI3Vf5di1Q7QekBgZB1+RC86kj8FqXViNliIA7BwZMUZ0ys4dWNsNdhovT/OM2KW6jGOk7qscO6fNtIbEyDg4HAKqk6ka3hPCKvBTuZFueDv1GoknGDqS7ymjCEJumhyyWQkjY8xwnhG+FgkGabK2Xejg85oMshYA4PPat7pJa0iMZICPE7eaGDnbL4qRKZXGOJUD9oqMdA+GEU8wOB0CagxQMQAApd7U5GY75cYHhqMIJ9OwRqmmKXY7pais3U7jHQYrswZSkRG7TRXXAxIjGeBtunndu1U40yeKkckkRnSBp8nqfF5p09Ebr8sJj1N8FNgpVcNTNJUlbhS5nRO8WhsEQZA2wMGw9e+HdDqCRvSM2NdQrDUkRjJgxchIIsGkyEhTZYnOV5OCt+G2gxjhc4GMZCAG7DmHw0jTetPhHh67ncaN7Rmx/rNJb0iMZICPdu+0kBjpHgwjEkvAIRhrM7RTNU2bgVrxp1Nuw4dupwE3P8Cep/FoPIHuQeNMsuaQgVU7SIxkwIqRkTPJqEh9eRHcTuMsvZ3SNEYsrQbsJQg57QYzr3J8NhyW1xEYAWOA2ymgptQ4YsTO7fm1xjg7ksGotaIY6ePmVeOkaAB7iREjllYD9loDjlTWa7C1sGNqID1i6DCIlwqgyIiWyBYj27Ztwyc+8Qk0NjZCEAQ888wzE37N1q1bsXTpUni9XsycORObNm3K41K1ZZJkYLWOGDlrQPMqYK/SXoqMGAcjVm8ASDOw2mcDbE1GbRsrjLUWdvRS6YVsMRIKhbBo0SLce++9Ob3+5MmTuOqqq/CRj3wEe/fuxW233YabbroJL730kuyL1RLuGekKhi3TfOhM3xAAY5X1AqNP5Vb5W2eiLWCs7qscHo4ODNvnodtq0ChVmQ3LSbmx3khjKoB0/459RLpeyJ6UtnbtWqxduzbn199///2YMWMG7r77bgDABRdcgNdffx333HMP1qxZI/fHa0aNT2x6NhyNYzAck96UZobf8JMrjHXDczESjTMMR+Mo8eg/wE8NEgmGjgE+C8WYa2CnyMjZpDg3WqRQmodiIzGSiowYbS3sOypBa1T3jOzYsQOrV68e9bk1a9Zgx44dGb8mHA4jEAiM+tCaEk9qqq1VfCNG9YyUeJxwJfPEVt4Me4ciiMQTEISUJ8ko2E2MBEeiUuTBaBugHefTSAM8DbYW6SmzhI26E+uB6mKkvb0ddXV1oz5XV1eHQCCA4eHhcb9m48aN8Pv90kdTU5PalzkutRbyjTDGDOsZEQTBFpsh94tMKvMaqpoJAMqLeWrAun//dPjm5y92S5u/UbDj2HqjekZ4+pIxIBSxz3rogbGeiEnuuOMODAwMSB8tLS26XEeNhSpqekMRDEfFWRdGu+GBtJP5kHU3Q6NW0gD2i4y0GjRlCdhzbL1RU8helwNupz3b82uN6keC+vp6dHR0jPpcR0cHysvLUVw8/hvP6/XC69U/jG2lyAi/2Wt9Xnhdxmh9nY4dqjmMOKSQY6eKJiDVc8doaQHAfmma4EhU2ugbDLYeYnt+N3pDERIjKqN6ZGTFihXYsmXLqM9t3rwZK1asUPtHF4yVGp+l/CLGutk5djiZGzVNBtjj759OqwEHRnLs1tuCRwyNmDID7Nn3RQ9ki5HBwUHs3bsXe/fuBSCW7u7duxfNzc0AxBTL+vXrpdd/9atfxYkTJ/Dtb38bhw4dwi9+8Qs88cQT+OY3v6nMb6Ai6eW9Zie1ERrLvMqxw2ZoVAMxkF7aa92/fzr8fjBiypJHCe3i3zlr4CgVYD9xqBeyxcjOnTuxZMkSLFmyBABw++23Y8mSJfjBD34AAGhra5OECQDMmDEDzz//PDZv3oxFixbh7rvvxq9//WtDl/VyJllocq9Re4xw7JAm4GvQZMA14H//UCSOaDyh89WoT8ozYjxhmD440up9d4D0tTCeMAQAn9de4lAvZMfELrvssqw3yHjdVS+77DLs2bNH7o/SHSu1hDeqQYxDkRF94ac/QBSE1WX6e7bU5KxBqzeA0X13hiJxlBowdaEkRu0xwqHIiDYYsprGKJBnRDv4A7jfomJkKBJDTygCwJieEZfTYZvR9dF4Ah3JuTRGXItitxOeZOm3Ve+HdIzaY4RjxynKekBiJAtcjPQORUwdumaM4XRPMkVQZbxTOWD9yAj3KJQXuaTf1WhYfQ047QMjSDDA43QYakIsRxAE+EusX+rOMY9nxPproSckRrJQVeKB0yGAMaBnMKL35eRNVzCM4WgcDgFoMmCKALB+aa+RUzQcq68BJz1FY6QJsemkIoXmfe7kCi95bzRg/x0g1YSOPCPqQmIkCw6HgJoycUaNmVM1p3pSMzg8LmMuudVP5S0GNxADaQ9di64Bx+geBQCosEETQACIJ5jUmdio6+EvEfeAARsNkdQDY+5MBkIq7x00b0XNqZ4QAGB6danOV5IZq1fTmCEyYnVByJHK3A26+QGjK2qsTPdgGNE4g9MhGG5eE6cyuRb9Q9aPUukJiZEJ4DdI+4B5IyOnk2JkWrWBN0KLlzMavbQasI8YaR0wbvM5TrnFDd0cnjKrLy+Cy2DzmjgVkhix9lrojTFX30A0JEv/eF7TjPA0jZEjI/z0EY0zhCJxna9GeYxezQRYPzrFOdNnhjSNmBqw+gZ4xsDN5zh+vhY28O/oCYmRCeAPLK7gzUgqMmJcMVLsdsKb9LP0hax305shTWOXzp9SK3gDixG7RKlaesWD0tQq4z6bKDKiDSRGJoDnlVtNKkYYYzjdzSMjxt0IBUFAdal4AumxmBgJhWPoTf5OU6poA9QTxpjh+1oA6Z4Ra90LY2nu4WLEuM+myqSBNTgSQ8zELR6MDomRCWiUxIg5Day9oQiC4RgEwbg9RjhVycolq0VGeFTNX+yWZsAYETuIka7BVJl7g4FTA3Y5jZ/uNb6frTytO7GV7w29ITEyAVyMtA0MI5Ewn7GS+0UayotQ5HbqfDXZ4ScQq0VGzGBeBYDyYl7aa90SRn4Sb/AXw+sy7v1gl54vLb2iUDfyQcnldEiNz6xuKNYTEiMTUOfzwiGIxsruQfNV1JzqNr5fhMPTNL0h8/2ds2EG8ypgj8gI70Rs5JM4kOozYuXISDgWlyqbDL8eNolU6QmJkQlwOR2oLxfDuWY0sXLz6vQaY9/sAFApiRFr3fBmMK8CNhEjvSYRI1KjLeuuxdm+YTAGlHic0kHEqFRK62GtqK2RIDGSA6lUjfl8I6ekkyBFRvSCVwwYuckWAMnPEhiJmjIlmQtmqCwDUsJwMBwz9VysbDT3psyrgmDMtvwcvh59FjsoGQkSIznQYOKKmtMm6L7KqUoOLbNaZETq82Lw6BT3KTAGDEas6RuR0jQG9igAo02TVu37ki5GjA6PVJFnRD1IjOQAb8hjxjSNWTZCAKgqFTdDK0VGxInJ5hCERWm9Xqw6E0XaAA2eprGDadIMZb2c1KwgStOoBYmRHDBrr5H+oYiUczbDDZ+KjFjnhu8MhjEUicPpEAxdMcCxchVHcCQqvbeMnqYBrO/hMYt/B0ibT2PRtTACJEZyoNFvzl4jx7sGAYhzH0o8rglerT+pyIh1xMjJZDXTlMpiuA06eyMdv4W7sPIUTXWpB2Ve498PUuMzi0apuJfKDCKdT+7ts+haGAHjPx0NQKNJIyPHOkUxMquuTOcryQ0eGQmMWMe0x0urjZ6i4Vh5Pk2ziU7iQNp8GgtWcDDGzOUZkUqtrbcWRoHESA7wNE1PKIKRqHmGuB3tEMXIzFpziBF/sRuOpKm+zyI3/cmkX2RGjbnEiBVTA6dMUknDkdbCgqfx7sEIhiJxCILxS96B9Pb81lsLADjYFsDZfn0be5IYyYHyYhdKPWK3RjNFR451mUuMOB2C5Fq3SqrmZBePjBj/gQukNkArNncyk2ESAPwW9inwqEijvxgel/G3IamaxoL3BQB88/G9uPTHr+DVI126XYPx3wUGQBAEU86o4ZGRWbU+na8kd6pKrSVGTklN58xxGrfa3z8ds3Rf5VhZGLaYKEUDpCIjVonYphNPMMnbdt4k/Z5TJEZyxGy+kVA4JpUimyUyAgBVFoqMJBJM2gDNkqapsujkZCDdM2KOtaiwsH/ntMmiVHwtrDi5t7V/GOFYAh6nQ9eUGYmRHDFbr5ETyfRAdalH2mDMgJVO5m2BEYRjCbgcguG7r3KqLfT3T8dMc1A4FRZO0/CIodH7vXB4lAoQDfZWglddTq8pgdOhXydcEiM5wst7zSJGjnYGAQDnmygqAgBVZdbZDHklzdSqErhMUNYLANVlYkVTjwmHQmbjTHIOSqkJ5qBw/Bau4DhuMj/bqCZ0FlsPfnA9r0bftTDHE9IAcAXPTXBGRyrrNcnNzrFSmobnYc2SogGsm6aRhGF1qeHnoHD8xdYclscYw/Hk8+n8SeZ5PqV8I9ZajxPd4lro6RcBSIzkDO8TwcOLRudop7lOHhwrpWmkHiMmEiNWTdMcM+H9YNVy0vbACEKROFwOwTQpMyDV98Vqk3uPd3LzKkVGTAEXI53BMEJh4+cMj3ear5IGsJYYOWlCMcLTZEOROIYj5umpMxGSODfRSTy9moYx60xR5pvf1OoSU3Ql5kgeHotGRs6nyIg58Je4pfkEpw2eqgnH4lIExyzdVzmWEiO84ZlJqjcAwOd1wZPcIHosNLDQzJGRWIJhyELCUPKLmEgYAtbsNTIYjqEjIN7nFBkxEfyEa/RUzanuISSYuLHU+rx6X44srCJGovGE1EvBDBOTOYIgWGYNOOkeBTOJkWK3UxKGVupvwcWI2cz1VmwJz5sy1pR5RlUM6QGJERmYxTfCK2lm1pWZxqzH4Rth31DE1KHpU90hROMMpR6nVIllFqxmYu0MhhEMx+AQzCcMKy04PPKYCc2rgDVLrSXzqs6VNACJEVlIYqTb2GLkYFsAADDbZH4RILURRuMMQRN4czJxhM8FqvPBoWPtfj5UJ30jPYPW2AD55je9uhRel1Pnq5FHTbLUuttCpdZSZERnj4JcrJimkaqaavVfCxIjMuCnqlMG94y83yqKkQsnl+t8JfIpcjtRkpwD1GvizfBIhxidmmMyzw6QXlFjjQ1QOombLC0ApImRoHnvhXSCI1HDeBTkIqVpLBQZOd5tjB4jAIkRWZglMvJ+MjJyYaP5xAhgjTQBFyOz68wYnUo2PjPx3z8dM5pXOVyMdFkkMsIbbE3yeXX3KMglVU1jjfsCSGt4ZoAoFYkRGaSX9w5FjJlC6B4MoyMQhiAAc+vNKUaqLRCa5mJklgnFCE/TmDkylc4xE5b1cmp84lqY+V5Ix6wpGgCoLLVW+jKRYDgpNTzT/94gMSKD9PLeU93GTNXwFM2M6lKUel06X01+1JeLYqQzYJ4JyemIpdXi+2OOGcWIBSJT6Zi1ASAATJKEuTXWIiVGzLwWYVOb6zmtA8MYiSbgdgpoqtTfZE9iRCa8vPe0QStqDiTFyAUmTdEAQF25OJSw3aRi5ERXCPEEg6/Ihbpyc5VWA9ZIk3EGhqJSVMGMnpFJPu4ZsUZkxMwpM74W4VjCEsPyDrcn55dNKjPE7Cz9r8Bk8FTNSYOKEbP7RYCUGOFGN7ORMq/6TFdaDaSlaSxgYD3WJa5Fg78IZSaMFFqtmoaLESOkBeRS5HZKw/K6LCAODyXFyAUNxtgrSIzIxOgm1gOtAwCAeQZ5g+VDSoyYMzJiZr8IkGZgtUBqwMwnccBaYmQkGpdGJFxQb857gzeRtIIY4QfXuQZZCxIjMjFyee9QJCbd7Bc2+nW+mvypN70YETfA2SYs6wVSkZGhSBwjUXO3IT/aYV6PAiB2xgTESbHReELnqymMw+1BJJjoSZpkss7QHH7dVqhuOsTFiEEOriRGZCKlaQwYGTnYFgRj4g1j1psdgOSzaB8wqxhJpWnMiM/rgtspppfM7hvhHqp5Jk1bVpZ44Ew2zTN7F1bejPGChnJTpi8BYJJPPCiZPTJixCgViRGZcBNcVzCMPoM9HKzgFwGAOr94wwdGYqabHDsciaM5OZPGrGmaUfNpTJyqYYxJ94RZ05YOR2otzL4BpsSIOe8LIFVRY/a1ONoxiAQTzepGObiSGJFJmdeFpiqxDIobgIzC+0m/iNnFiM/rkrqwmi1Vc6xzEIwBlSVuKcRuRqqTvpFuE5tYWwdGMDAchcshmG56dTpW8Y0cbDOWYTIf+MbdGTTXc2ksB9tTfhGjRKlIjOTBnDrxZjqcXFCjsKe5HwCwYLJ5/SKAeDI3a3nvfm4gbjRvKBqwRuOzA2fFtZhV5zPdTJp0uKg1c68RxtioNI1ZmWQRA+uhpDA0UmNMEiN5wN3HhzuMExkZDMckr8LSqZU6X03hcN+I2SIj+5Ib4HyTC0IpTWOwVKQczJ6i4VghNXCmbxjBcAxup2BaMzFgITHCIyMGSpmRGMmD2VyMGChN825LPxIMmFxRjNpkVMHMmLW8d39SjJg9OmWFxmfvm9y8yqnxmT9Nw4XhzFofPC7zbju1FliLUVEqioyYGx4ZOdIxaJi2wLtO9wEAlk0zf1QESC/vNc9NH4klpPDnwskV+l5MgXCfQo+JH7q8ksbsHqpUmsa8a3HQKlEqX2qIZMykpdZdwTD6hqJwCDCUl4rESB7MqCmF2ylgMBzDmb5hvS8HALC7WRQjS6dW6HshCmFGz8iRjiAi8QT8xW7J5GxWzB4ZGRiK4my/eG+a2aMAWMPAaoVKGiBVas2YeVOYB5MR/Rk1pShyG8dLlZcYuffeezF9+nQUFRVh+fLlePvttzO+dtOmTRAEYdRHUZG50whup0PKexohVZNIMMm8utQikREpTWOiXiMpv4i5zatAKjLVZqK/fzoH2sS1mFJZbLpR9WNJzacx5+YHpCppzB4ZcToEaZBkp0l9IzyVbJRmZxzZYuTxxx/H7bffjjvvvBO7d+/GokWLsGbNGnR2dmb8mvLycrS1tUkfp0+fLuiijcAcA5lYT3SHMDAcRZHbYfpTIKfenzSwmqiEzirmVQBoqEhGpgaMEfmTi+QXscD9YPbISHAkKvXescLzyewm1r0t/QCAxVMqdL2OscgWIz/5yU9w880348tf/jLmzZuH+++/HyUlJXjwwQczfo0gCKivr5c+6urqCrpoI8DFiBF6jfAUzcLJFXAbYPqiEtT6Up4Ro/hyJsIq5lUAaPCLaaa+oajpGs8B6Q0Azb8WXIz0DpnTp7DvjHhfTK4oRmWpeXvvcMwsRhhjKTFisJS+rJ0rEolg165dWL16deobOBxYvXo1duzYkfHrBgcHMW3aNDQ1NeHqq6/GgQMHsv6ccDiMQCAw6sNoSCZWI4iRpHl1ybQKfS9EQWqTpb2RWAL9Q1Gdr2ZirGReBYDyolTjuTYTRke4MDR7JQ0g+nccAkSfwpD5UjV7kpvfEoNtfvkilVqbMFLVNjCCrmAYToeA+QYT6rLESHd3N+Lx+DmRjbq6OrS3t4/7NXPmzMGDDz6IZ599Fg8//DASiQRWrlyJM2fOZPw5GzduhN/vlz6amprkXKYmzEmWRB3vGkQkpu9phUdGllmgvwjH63JKJkozmFitZF4FxGhmg9+cvpGB4SiOJqf1Lm6q0PdiFMCZ1hLejL6RPcnn0xKLPJ/MHBnhUZE5dT4Ue4xjXgU0qKZZsWIF1q9fj8WLF2PVqlV46qmnMGnSJDzwwAMZv+aOO+7AwMCA9NHS0qL2Zcqm0V8En9eFWIJJY8r1oHswLE2JtUpZL8dMvUasZF7lNFaIoqq131yRkb0t/WAMmFpVYpi5G4ViVt8IYylzvVUiI7UmFiPvGjRFA8gUIzU1NXA6nejo6Bj1+Y6ODtTX1+f0PdxuN5YsWYJjx45lfI3X60V5efmoD6MhCAIWTBHDXO+e6dftOrYf7wEgGsOqy6zx4OXUm6gLKz/9LTSYKawQeGTEbNOTd1us5w5gXjHS0juMnlAEHqfD9P1eOGae3LvHoOZVQKYY8Xg8WLZsGbZs2SJ9LpFIYMuWLVixYkVO3yMej2Pfvn1oaGiQd6UGZFEyBLw3qfz14I2j3QCAD86s1u0a1ELqNTJg/Jv+nVPiBnjRdOtsgNzE2mo2MWKxnjtA6l4wW8psT4u4Fhc0lpt6PlA6Zh2WF4snJDOxESMjLrlfcPvtt+P666/HBz7wAVx88cX46U9/ilAohC9/+csAgPXr12Py5MnYuHEjAOCHP/whLrnkEsycORP9/f246667cPr0adx0003K/iY6wPPRekVGGGN4/ZgoRi6dWaPLNajJ5GSa4EzfkM5Xkp3O4AhOdocgCMCyaVV6X45iNFbwDdA8aZpEgkmHA6v03AHEfikADNNkMVekFI0FvDscs3pGjnYOYjgaR6nHacj5QLLFyLXXXouuri784Ac/QHt7OxYvXowXX3xRMrU2NzfD4UgFXPr6+nDzzTejvb0dlZWVWLZsGbZv34558+Yp91voBL/BjnQEEQrHUOqV/ecsiNM9QzjbPwy3U8DFM6yzCXKm1ZQCEH9PI7MzGRWZU+czfYOtdOqTkZG2fvOcAI92DiIYjqHE48ScOnN3+0wnJUaMfS+MxWqVNEBKjIQicV2e+/nC/SILp1TA6TCery2vv+Itt9yCW265Zdx/27p166j/v+eee3DPPffk82MMT215ERr9RWgdGMF7Zwaw4nxtUyU8KrJ0aiVKPOa4IeQwvboEAHCqJ6TzlWTn7ZO9AGA5QdiY9Iy0migywmc0LW6qgMsiPXcAYEqleC+cNVFkZCQax/utYlrACpPEOaUeJ0o9ToQicbQNjGBmrfGiDOPBK2kWGTRKZZ27VSck30hyobXkjWPcL2K9FA0ATKsSIyOdwTCGIjGdryYzO0+LYuSi6dYSIw3JNFlwJIbBsHH//umk/CLW2fyAtMhI/zASCXM0ATzQGkA0zlBT5pGu3woIgoCmKlEctpgoUvVW8tBkVGM3iZECkXwjGouReIJJlTSXzrKmGPGXuFFZIqY9jJqqCY5EpdbjVhMjZV4XfEVixM0sbeF5Jc1SCzUABIB6fxEcgthczywVNTxiuLip0jLl7pypXIz0GvO5NJb2AdHX5hCMG8ElMVIgi3WKjOxt6cfAcBQ+rwsLLdB+PBPTqrlvxJipmt3N/UgwoKmqGPV+cw+AHI9GXlFjAt9I92AYJ7rF98mSJmOe/vLF7XRI1U0tJknVbD8uRm5Xapy+1gIuRpoNekgay44T4lpc2Og3rK+NxEiBLJjih0MQu4Rq2Y/hrwfEjreXza21VG58LNw3crLbmDf9OyetmaLhNJiooia9544VZqCMZbKJTKzhWBzvnBLvDStW+k1NPpeaTRIZ2ZG8N7T2NcrBuruYRpR4XJiddO3vTdbUqw1jDH/ZL4qRtfNzazZnVqbXGDsy8tZJ8Sa/2KpihJtYTRAZef1oFwBr9twBzFXeu7e5HyPRBGrKPJhdZw6DpxxSnhHjrwUA7DiRFCPnGffeIDGiALyfwdsntREjB9uCaO4dgtflwKrZkzT5mXoxPZmmMWJFzcBwFLuTfRSsePoDUo3PjN6FlTGG13kDwFnWvCd4RY0ZxMgb0km8xnJ+EQBoqkx5Row+VfxM3xBaeofhdAi4yKB+EYDEiCLwnCjPkarNi/vbAACrZk8yTY17vkxLhkONaGDdfqwb8QTD+ZNKpZOS1WgwSXnvqZ4htA6MwON0WKoLbjpm6jWynTdjNHBaoBD4WgyGY+gz+FRxnqJZMNmPMgPvFyRGFICHvg61BzVxur+Y9ItcafEUDZCKjLQNjGAkGtf5akaz9bCYFlg1u1bnK1EPPizP6G3IeYpm6bQKS/bcAVIboNF7jYTCMcnQv/J8a0YMi9xO1Cdb9BvdNyKlaAwuDEmMKEB1mRcXNIhDoLiJTi2Odw3iSMcgXA4BH72gTtWfZQQqStwoT5aXGummZ4zh1SPiBnjZHGumBYB0z8iwocPRvAHghyyaogFSqQGj9xp5+1QvYgmGKZXFktHTikgVNQZ6Lo2FMYY3jxvfLwKQGFEMHo7k4Um1eO5dMUWzcmaNYUu0lEQQBMnEeqrbOL6RQ+1BtAdGUOx2GrZuXwmmVJbA6RAwFImjI2DM/haxeCLVc8ei3h3APL1GUika664FkGZiNbAYOdQeROvACIrcDsNX/JEYUQj+EHxDRd9IPMHwxM4WAMCnl0xW7ecYDSOaWHmKZsX51ShyW2Ma6Xh4XA6pvPpIR1Dnqxmf984OIDgSQ3mRCwss3HPHDL1GGGN4+WAnAOBDs60tRszQa+Tl9zsAAB+cOQnFHmM/p0iMKMTFM6rgcgho6R1W7c257WgXzvYPw1/stoVfhJOaUWOcm/7VI+ID1+rVTACk0vWjnYM6X8n4SA/cWTWGHACmJEbvNXK8axAnu0PwOK1f6Te1mgtDY64FALx8ULw3Vl9gfF8biRGFKPW6pG6sakVHHn2rGQDwmaVTLH0aHwvvwmqUNM3AUFSa1GtlvwhnVnIQ2LFO40VGGGN4MdlzZ82F1hfoUyqM3Wvkr0lhuHJmNXxF1k4jcw+PUT0jHYERvHtGHFR4OYkRe7Eymarh/Q6UpCMwgi2HxNP4dRc3Kf79jQyfinmoPWgIE+VLB9oRSzDMrfdJQsnKzExGRo50GC8ycqxzECeSJ/HL5xr/gVsoRi/v/esBUYxcMc/6wpCnaVr7hxGNJ3S+mnPZkkyXLW6qQK3P+KMqSIwoCH8Ybj3cqXgZ6hPvtCCeYLhoeiVmJTcHuzC3wQe3U0BvKGKIE+H/vtcKAPjEokadr0QbeGTkaIcxxGA6PCrywVk1lj+JA6ko4fEuY0QJ0+kIjGBvSz8EAVg9z/rCcJLPC6/LgQQTBYnR4Cmaj80zR9UliREFWTTFj0Z/EUKROLYlyz6VYDgSx293nAIAXHfxVMW+r1nwupyYWy+WTu87O6DrtXQPhqXKjY8vbND1WrRiRk0pHAIQGImhK2isKg6p544NUjQApBYCB1sDhhOGm5MpmiUmOYkXiiAIhi3vHYrE8Eayqmm1SVpAkBhREEEQcOV8cYPiJzYlePjN0+gejKCpqtg2p/GxLJgiVkm8e6Zf1+v4y/52xBMMC6f4bZGiAcQGT7yiyUgm1pbeIRxoDcAhAKtNcvorlJm1ZfA4HQiGY4aIEqbD/SJX2EQYAqlI1TED3RcA8PLBToRjCTRVFZtmNhCJEYX5uwXijbj5YAfCscJTNUORGB7YdhwA8I2PzILbwhN6s7EoKUbea9E3MvLcu8kUzUJ7iULu2zFSeS8X/MtnVKPKglN6x8PjcmBWcnM50BrQ+WpSdAXD0kn8CpsIQwC4sFGMVO0/a5y1AICndp8BAKxbPNk0s4HsubOpyNKplaj1eREciWH7scK7sfKoyNSqEnxqqX16i4xl4ZQKAMD+swO6dZ/sCIzg7eRY9KtskqLh8A3QKJERxhj+lHzg8gOAXZiXTNW832acDfDZvWcRTzAsmVqB8yaZ4ySuBLyvzb6z/fpeSBqdwRG8liyi+JSJ+lGRGFEYh0PA2mQPkBf2tRX0vfpCEdz/6gkAwC2Xz7RtVAQQTZRFbjE8fUKnEt8/7joDxoAPTKuUZrbYhVm1omn6mEEqavadHcCh9iA8Lgc+ucg8D1wlmJc8jb9vkMgIYwxP7hSF4WeXTdH5arSFp4+PdQ5iKBLT+WpE/ry31ZTC0L67m4qsXSCeml860I7BcP5v0B89/z56QxHMqi0zlcJVA5fTgQsb9TuFxBMMf3jzNADgi5fYz0QspWk6jVFR8/g7YifiKy+sh7/E+lU06fDIyEGDREYOtAZwuEMUhh+3WfqyrrwItT4vEsw46/HU7rMAgE8vNZcwJDGiAhdNr8KMmlIERmJ4OLmByWXr4U48tfssBAH4j88utHVUhLOQm1h18I28cqgTrQMjqCxxY+18e6VoAFGMCALQPxRF92BE12sZjsTx572id+fai+zVcwcALkhGRs72D6N/SN+1AMSIISA2nbPDvKyx8FTNe2f09bMBwKH2AN5vC8DtFPDxBeZ6TtEOpwJOh4ANH5kJAPjVthMYjsgzsgZGovg/T+8HANywcjqWTq1U/BrNyKKkb+Q9HSpqfp8UlZ+7qMlW3W85RW4npiXLGPX2KvxlfxuC4RiaqooNP4lUDcqL3GiqEtOEeqdqwrE4ntkrnsTtlqLhzJd8I/qLEd6l+yNzalFpMlM3iRGVuHpxI5qqitETiuAPb+UeHYnGE/j6w7txtn8YUyqL8c9XzFHxKs0Fz88eaA1o2vHwVHcI2450QRCAL148TbOfazSWTROnfr55onBjdiE8lkzRXLOsCQ6Lz6LJhFFMrH/e24r+oSjqy4vwQQtPTM4Gj4zs11mMDAxH8WQySrV+xXRdryUfSIyohNvpwIbLxOjIA9tOIJSDd4Qxhjue2ofXj3WjxOPE/X+/DKVel9qXahpmVJeiutSDcCyBt0/2avZzN20/BQC4bPYkTE0O7bMjl5ynvxjZ09yHt0/2wuUQcM0H7HkSByD5p/SMjDDG8KvXRIP9DZdOt/yQwkwYxcT62NvNGIrEMbfeh0tnmi9iSGJERT69dAomVxSjKxjG1/6wG5FY5tN8OBbHd5/ehz/uOgOnQ8C9X1gqhf8IEYdDwEeTA594t0e1aR8YwSNvi6HPmz50niY/06hckkyJvHdmoCBjdiHc/6rYc+fqxZPR4LdXRVM6PDKiZ6+RrUe6cKRjEGVeF76w3H6mbk5deREmJU2seonDaDwhHZr+4dIZpuktkg6JERXxuBz4n+uWoNjtxLYjXbj9ib2Ij9Mjo6V3CJ974E08+nYLBAH4f+vm4yM2GPqVDx9LDuDa/H6HJlUdv9h6DJFYAhdPr8LK88132lCSpqoSTKksRjzBsPOUdpEpzrHOIF5KDmL76ip7C0N+Gj/SGURvSB8T66+2iVGRz1/UhHIbzAXKxkKdfSN/2d+OtoER1JR58MnF5qxoIjGiMsumVeL+Ly2D2ynguffacNl//Q2/3HYcWw524Nm9Z/G1h3dh1V1/w7st/fAXu/HQDRfZcv5MrnxwZg2K3A6c7R9WPV/e2j+Mx94W/Qm3fWyWKU8bSsMNozt0SNU8kOy5c8W8OtsNixxLXXkR5tb7wBgUnYOVK/vODGD78R64HAL+4YMzNP/5RoNHsXc392v+sxMJhvu2ihHDLy6fZlqDPYkRDVg1exJ+dt0S+IvdaOkdxr+/cAg3/nYnbn1sL/6yvx0JJm6yz33jg7hsDkVEslHsceLDsyYBSI0rV4t7/3YMkXgCl5xXhZXn29OcNxaeqnnzhLaRkeaeIalq46uXna/pzzYqPHr6t8Odmv/s/3zpEABxcrXdGgCOx4dmic+HrYc7NTXXA8Cf323FwbYAfF4Xblg5XdOfrSTkjtSIK+c3YNXsWjy79yye2nMW4WgcxR4nZtaW4UuXTMecenuf9OTwsXl1+Ov7Hdj8fge++bHZqvyM/WcH8GjSK/LN1er8DDOyIpmq2n92AMGRKHwahec3/uUgonGGD82qoVL3JB+ZU4v7th7Hq0e6EE8wzQykrx7pwmtHu+F2CnRvJFkytRI1ZR50D0bw1olefHCWNoeXcCyO//rrYQCiSDdbOW86JEY0pNjjxOcvnorPUxqmID56QR0cgljWeKZvCFMqla1wiScYvvv0PiSYePJbbsNeFplorCjGtOoSnO4Zws5TfZp4m9480YO/7G+HQwC+d9U81X+eWVg6tQLlRS70D0Wxt6VPKr1Wk3iCYeMLBwGI5aN2ri5Lx+kQ8NG5dXh8Zwv++n67ZmLkD28240zfMGp9Xnz50uma/Ey1oDQNYTqqSj24aLr44OWeDiX5/Y5TeO/MAHxFLnz/4xco/v3NDveNbD6ofkVTPMHwo+feBwB8YflUiiCm4XI68OHZYsrylUPapGr+uKsFh9qDKC9y4RuXz9TkZ5qFKy4UpxVrZa7vGQzjZ68cBQDctno2Sjzmji2QGCFMCT8F/HbHKQRGoop935PdIdz1khj2/M6Vc1HrK1Lse1sF7tb/897WnPrnFMLDb57GgdYAfEUuSgmMw+XcN3JIfRNrZ2AE//6C6BX5xuWzUFFi3pSAGlw6swYlHifaBkaw/6z6Jb53/vkA+oaimFvvw+cs0HOHxAhhSq6YV4+ZtWUIjsTw+x35zf8Zy1Akhq/+fhdCkTgunl6FL1A6bVxWnFeN6dUlGAzH8Nx7rar9nGOdQfx7MiXwL2vmoLrMq9rPMiurZk+CkExZnu0fVu3n8IaMA8NRzJ9cjhtMnhJQgyJ3yly/+f12VX/Wi/vb8dx7bXA6BNz12UVwWWB2mfl/A8KWOBwCvp6sqvjN6ycL7nzIGMN3n9qHwx1B1JR58fMvLLFtq/GJEARB8j09qkKaDAAisQRue3wvwrEEPjx7Er50iX3b8GejusyL5TPElOVvk02v1OCPu85gy6FOeJwO3H3NYhrcmQGeqnl+XxsS4/SUUoLeUATff1acXfaPHz5P6jljdugdRZiWTy4S5//0hiJS98F8YIzhJ5uP4Jm9rcnut0tQW07pmWx8dtkUuJ0C9rb0q9J18sd/OYT9ZwOoLHHjrs8upB4vWbg52Rn4kbeaMTCsXMqSc7AtgH/78wEAwDc/Npt8O1n46AV1KPO6cLwrhL+q0CU6Gk/g63/Yha5gGOdPKsWtH52l+M/QCxIjhGlxOR34xuXizfiTvx7BrtPye18wxnDP5iP42SvHAAB3fmIeVc/kQE2ZF1cku+HKGQSZC7/bcQoPvnESALDx0wtRR8IwKx+ZU4vZdWUYDMfw8JvKrkVXMIybfrsToUgcK8+vxj9+2N6dbyfCX+yW/Gw/ffmI4tGR//fc+3jzRC9KPU7c9/fLTNvgbDxIjBCm5pplU3DVwgbEEgxf/8NudAXDOX9tOBbHnX8+gP9JCpHvXXWBKadd6sUXLxFTNY+906LYxNLN73dIp/B/WTMHV86vV+T7WhmHQ8BXPiymLB964xRGonFFvu9QJIav/H4nzvYPY0ZNKX7xxaW2HYYnhxs/OANlXhcOtQfxVwW9Iw++fhK/Tfrj7rl2MWZbrAsxiRHC1AiCgP/4zELMrC1DRyCMGx56G809QxN+3fGuQXz6F9vxu+TN/b2rLrD9IDy5rDy/BlctaEA8wfCdP72HWIGdJ59/rw0b/rAbCSbOO/k6dVrNmU8ubkSjvwjdg2H85vWTBX+/wEgU63/zNnY3i2MqfnP9B6h6JkcqSjxp0ZGj484jk8uvXzuBHyZL3L/1sdm44kLriXQSI4TpKfO6cP/fL4O/2I0DrQH83f+8hsffacZwZPQJkTGGox1BfOuJd3HFPdtwoDWAqlIPHvryRSRE8uTfPnmh9Hf/1Wv5bYKMMfx2+ync8uhuROIJ/N2Cevxo3XzyicjA7XTgtmTp8082H8FbBcwO6gqG8YVfvYmdp/tQXuTCQ1++COdNKlPqUm3BTR88D75kdIS3zs+HREJMI/+/58Wqsm9cPhO3WLS/i8C06M5SIIFAAH6/HwMDAygvL9f7cgiDcrZ/GLc9tgfvnOoDABS5HVhxXjV8RW7EEgnsae5H28CI9PrL59bi3z+1APV+8iQUwh93ncE/P/kuXA4B//nZhfj00tx7HvSFIvg/z+zDC/vEcPaXLpmGf/vkhZQOyAPGGL75+F48s7cVk3xePP+ND8o2Yr92tAvffPxddA+GUV3qwe9vXI55jfTMzYdn957FrY/tBQD81zWL8Nll8nqB9IUi+OYTe7H1sNhD5rbVsyTBaSZy3b9JjBCWIhZP4FevncQf3jqNM33n9l1wOwWsml2Lb1w+E4uaKrS/QAvCGMO3nnwXT+0WB9ndsXYubv7QeVlLo8OxOP646wx++vJRdAXDcDkEfOuKOfjqqvMoIlIAQ5EY1t37Bo50DOK8SaW49wtLcUHDxM/MzuAI/mfLUTz8pjiPaU6dD/f9/VKKiBTI3X89jJ+9cgwepwM//swCfGrJ5Anf3/EEw592n8F/vXQYncEwvC4HfnT1fHzuoiaNrlpZSIwQtoYxhgOtAexu7kM0Lr7F59b7sHRqJYo91nGgG4VEguH/e+Gg5FeYXVeGr3z4fHxwVg1qfV4IgoBILIH3zvTjb4c78cddZ9AREM3GM2vLcM/nFlumX4LenOgaxBd+9RbaAyPwuBz45urZuOYDU1AzpmkcYwz7zwbw7N6z+MNbzRhOGl///pKp+N5V8yxVqaEXiQTDLY/uliJ/H5tXh+9cOQcza881n/YPRfC/77XhD2+exqH2IADgvJpS/PwLS00dnSIxQhCE5jz0xknc/dcjGExrE+/zijMzgmNax9eXF+Erq87DdRdPpY1PYXpDEfzzk+9KM2ucDgHLplai3l+EUq8TZ/tHcKwjiNa0tOXipgp8+8o5WHm+NkPe7EIsnsAD207gpy8fkQ5G59WUYsEUP0o8TgxH4jjSMYijnUHp331FLvzT5bOwfuU0eF3mvjdIjBAEoQsDw1H84a3T+OOuMzjVHUJ6MYG/2I1VsyfhoxfU4sr59aZ/0BoZxhie2NmCR95uwbst/eO+psjtwEfn1uHTSyfj8rm1lCJTkYNtAfzni4fwxrEeRDJUnl3QUI7PLJ2MzyydgspSa1QvkRghCEJ3wrE4TvcMwekQUFXigb/YTW32deBkdwh7mvvQG4pgMBxDo78YU6tLsGCyH6Vec097NRvBkSheO9qN1v5hDEXicDoEzKotwwUN5WiqKtH78hSHxAhBEARBELqS6/5NfUYIgiAIgtAVEiMEQRAEQehKXmLk3nvvxfTp01FUVITly5fj7bffzvr6J598EnPnzkVRUREWLFiAF154Ia+LJQiCIAjCesgWI48//jhuv/123Hnnndi9ezcWLVqENWvWoLOzc9zXb9++Hddddx1uvPFG7NmzB+vWrcO6deuwf//+gi+eIAiCIAjzI9vAunz5clx00UX4+c9/DgBIJBJoamrCN77xDfzrv/7rOa+/9tprEQqF8Nxzz0mfu+SSS7B48WLcf//9Of1MMrASBEEQhPlQxcAaiUSwa9curF69OvUNHA6sXr0aO3bsGPdrduzYMer1ALBmzZqMrweAcDiMQCAw6oMgCIIgCGsiS4x0d3cjHo+jrq5u1Ofr6urQ3t4+7te0t7fLej0AbNy4EX6/X/poajJnT36CIAiCICbGkNU0d9xxBwYGBqSPlpYWvS+JIAiCIAiVkNV6r6amBk6nEx0dHaM+39HRgfr6+nG/pr6+XtbrAcDr9cLr9Wb8d4IgCIIgrIOsyIjH48GyZcuwZcsW6XOJRAJbtmzBihUrxv2aFStWjHo9AGzevDnj6wmCIAiCsBeyhxLcfvvtuP766/GBD3wAF198MX76058iFArhy1/+MgBg/fr1mDx5MjZu3AgAuPXWW7Fq1SrcfffduOqqq/DYY49h586d+OUvf6nsb0IQBEEQhCmRLUauvfZadHV14Qc/+AHa29uxePFivPjii5JJtbm5GQ5HKuCycuVKPPLII/je976H7373u5g1axaeeeYZzJ8/X7nfgiAIgiAI00KD8giCIAiCUIVc929TzI7meon6jRAEQRCEeeD79kRxD1OIkWAwCADUb4QgCIIgTEgwGITf78/476ZI0yQSCbS2tsLn80EQBMW+byAQQFNTE1paWiyb/rH670i/n/mx+u9Iv5/5sfrvqObvxxhDMBhEY2PjKD/pWEwRGXE4HJgyZYpq37+8vNySb7B0rP470u9nfqz+O9LvZ36s/juq9ftli4hwDNmBlSAIgiAI+0BihCAIgiAIXbG1GPF6vbjzzjst3Xre6r8j/X7mx+q/I/1+5sfqv6MRfj9TGFgJgiAIgrAuto6MEARBEAShPyRGCIIgCILQFRIjBEEQBEHoCokRgiAIgiB0hcQIQRAEQRC6Ynkxcu+992L69OkoKirC8uXL8fbbb2d9/ZNPPom5c+eiqKgICxYswAsvvKDRlcpn48aNuOiii+Dz+VBbW4t169bh8OHDWb9m06ZNEARh1EdRUZFGVyyPf/u3fzvnWufOnZv1a8y0fgAwffr0c35HQRCwYcOGcV9v9PXbtm0bPvGJT6CxsRGCIOCZZ54Z9e+MMfzgBz9AQ0MDiouLsXr1ahw9enTC7yv3PlaLbL9fNBrFd77zHSxYsAClpaVobGzE+vXr0dramvV75vM+V5OJ1vCGG24453qvvPLKCb+vGdYQwLj3oyAIuOuuuzJ+TyOtYS77wsjICDZs2IDq6mqUlZXhM5/5DDo6OrJ+33zv3VyxtBh5/PHHcfvtt+POO+/E7t27sWjRIqxZswadnZ3jvn779u247rrrcOONN2LPnj1Yt24d1q1bh/3792t85bnx6quvYsOGDXjzzTexefNmRKNRXHHFFQiFQlm/rry8HG1tbdLH6dOnNbpi+Vx44YWjrvX111/P+FqzrR8AvPPOO6N+v82bNwMArrnmmoxfY+T1C4VCWLRoEe69995x//0///M/8T//8z+4//778dZbb6G0tBRr1qzByMhIxu8p9z5Wk2y/39DQEHbv3o3vf//72L17N5566ikcPnwYn/zkJyf8vnLe52oz0RoCwJVXXjnqeh999NGs39Msawhg1O/V1taGBx98EIIg4DOf+UzW72uUNcxlX/jmN7+J//3f/8WTTz6JV199Fa2trfj0pz+d9fvmc+/KglmYiy++mG3YsEH6/3g8zhobG9nGjRvHff3nPvc5dtVVV4363PLly9lXvvIVVa9TKTo7OxkA9uqrr2Z8zUMPPcT8fr92F1UAd955J1u0aFHOrzf7+jHG2K233srOP/98lkgkxv13M60fAPb0009L/59IJFh9fT276667pM/19/czr9fLHn300YzfR+59rBVjf7/xePvttxkAdvr06Yyvkfs+15Lxfsfrr7+eXX311bK+j5nX8Oqrr2aXX3551tcYeQ3H7gv9/f3M7XazJ598UnrNwYMHGQC2Y8eOcb9HvveuHCwbGYlEIti1axdWr14tfc7hcGD16tXYsWPHuF+zY8eOUa8HgDVr1mR8vdEYGBgAAFRVVWV93eDgIKZNm4ampiZcffXVOHDggBaXlxdHjx5FY2MjzjvvPHzxi19Ec3Nzxteaff0ikQgefvhh/MM//EPW6dRmWr90Tp48ifb29lFr5Pf7sXz58oxrlM99bCQGBgYgCAIqKiqyvk7O+9wIbN26FbW1tZgzZw6+9rWvoaenJ+NrzbyGHR0deP7553HjjTdO+FqjruHYfWHXrl2IRqOj1mPu3LmYOnVqxvXI596Vi2XFSHd3N+LxOOrq6kZ9vq6uDu3t7eN+TXt7u6zXG4lEIoHbbrsNl156KebPn5/xdXPmzMGDDz6IZ599Fg8//DASiQRWrlyJM2fOaHi1ubF8+XJs2rQJL774Iu677z6cPHkSH/rQhxAMBsd9vZnXDwCeeeYZ9Pf344Ybbsj4GjOt31j4OshZo3zuY6MwMjKC73znO7juuuuyTkKV+z7XmyuvvBK/+93vsGXLFvzHf/wHXn31VaxduxbxeHzc15t5DX/729/C5/NNmMIw6hqOty+0t7fD4/GcI5An2hv5a3L9Grm4FPkuhO5s2LAB+/fvnzBPuWLFCqxYsUL6/5UrV+KCCy7AAw88gB/96EdqX6Ys1q5dK/33woULsXz5ckybNg1PPPFETicVs/Gb3/wGa9euRWNjY8bXmGn97Ew0GsXnPvc5MMZw3333ZX2t2d7nn//856X/XrBgARYuXIjzzz8fW7duxUc/+lEdr0x5HnzwQXzxi1+c0CRu1DXMdV8wApaNjNTU1MDpdJ7jEO7o6EB9ff24X1NfXy/r9UbhlltuwXPPPYe//e1vmDJliqyvdbvdWLJkCY4dO6bS1SlHRUUFZs+enfFazbp+AHD69Gm8/PLLuOmmm2R9nZnWj6+DnDXK5z7WGy5ETp8+jc2bN2eNiozHRO9zo3HeeeehpqYm4/WacQ0B4LXXXsPhw4dl35OAMdYw075QX1+PSCSC/v7+Ua+faG/kr8n1a+RiWTHi8XiwbNkybNmyRfpcIpHAli1bRp0s01mxYsWo1wPA5s2bM75ebxhjuOWWW/D000/jlVdewYwZM2R/j3g8jn379qGhoUGFK1SWwcFBHD9+POO1mm390nnooYdQW1uLq666StbXmWn9ZsyYgfr6+lFrFAgE8NZbb2Vco3zuYz3hQuTo0aN4+eWXUV1dLft7TPQ+NxpnzpxBT09Pxus12xpyfvOb32DZsmVYtGiR7K/Vcw0n2heWLVsGt9s9aj0OHz6M5ubmjOuRz72bz4Vblscee4x5vV62adMm9v7777N//Md/ZBUVFay9vZ0xxtiXvvQl9q//+q/S69944w3mcrnYf/3Xf7GDBw+yO++8k7ndbrZv3z69foWsfO1rX2N+v59t3bqVtbW1SR9DQ0PSa8b+jv/3//5f9tJLL7Hjx4+zXbt2sc9//vOsqKiIHThwQI9fISvf+ta32NatW9nJkyfZG2+8wVavXs1qampYZ2cnY8z868eJx+Ns6tSp7Dvf+c45/2a29QsGg2zPnj1sz549DAD7yU9+wvbs2SNVk/z4xz9mFRUV7Nlnn2Xvvfceu/rqq9mMGTPY8PCw9D0uv/xy9rOf/Uz6/4nuY6P8fpFIhH3yk59kU6ZMYXv37h11T4bD4Yy/30Tvc63J9jsGg0H2z//8z2zHjh3s5MmT7OWXX2ZLly5ls2bNYiMjI9L3MOsacgYGBlhJSQm77777xv0eRl7DXPaFr371q2zq1KnslVdeYTt37mQrVqxgK1asGPV95syZw5566inp/3O5dwvB0mKEMcZ+9rOfsalTpzKPx8Muvvhi9uabb0r/tmrVKnb99dePev0TTzzBZs+ezTweD7vwwgvZ888/r/EV5w6AcT8eeugh6TVjf8fbbrtN+nvU1dWxv/u7v2O7d+/W/uJz4Nprr2UNDQ3M4/GwyZMns2uvvZYdO3ZM+nezrx/npZdeYgDY4cOHz/k3s63f3/72t3Hfk/x3SCQS7Pvf/z6rq6tjXq+XffSjHz3n9542bRq78847R30u232sJdl+v5MnT2a8J//2t79J32Ps7zfR+1xrsv2OQ0ND7IorrmCTJk1ibrebTZs2jd18883niAqzriHngQceYMXFxay/v3/c72HkNcxlXxgeHmZf//rXWWVlJSspKWGf+tSnWFtb2znfJ/1rcrl3C0FI/lCCIAiCIAhdsKxnhCAIgiAIc0BihCAIgiAIXSExQhAEQRCErpAYIQiCIAhCV0iMEARBEAShKyRGCIIgCILQFRIjBEEQBEHoCokRgiAIgiB0hcQIQRAEQRC6QmKEIAiCIAhdITFCEARBEISu/P+oGOKQF6FigQAAAABJRU5ErkJggg==",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "alpha = 2.0\n",
+    "main(alpha)\n",
+    "\n",
+    "start = time.time()\n",
+    "sol = main(alpha)\n",
+    "end = time.time()\n",
+    "print(f\"Integration took in {end - start} seconds.\")\n",
+    "\n",
+    "plt.plot(sol.ts, sol.ys)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We can also change the `alpha` term and see its influence on the dynamics."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 70,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Integration took in 0.000225067138671875 seconds.\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABHoklEQVR4nO3deXxU1d0G8OfOnnVC9hUS9j0gSgy4ACIQaJTaV3FpQVHrEi1K66tUhdq+NW2t1toiVCuiVRS1gAtUi6yyyxJlh5BAAtkTMpNMkpnJzH3/mIVEsk2YO3cmeb6fz3zCTO6de643cZ6c8zvnCqIoiiAiIiKSiULuBhAREVHvxjBCREREsmIYISIiIlkxjBAREZGsGEaIiIhIVgwjREREJCuGESIiIpIVwwgRERHJSiV3A7rCbrejpKQEYWFhEARB7uYQERFRF4iiiLq6OiQmJkKhaL//IyDCSElJCVJSUuRuBhEREXVDcXExkpOT2/1+QISRsLAwAI6TCQ8Pl7k1RERE1BVGoxEpKSnuz/H2BEQYcQ3NhIeHM4wQEREFmM5KLFjASkRERLJiGCEiIiJZMYwQERGRrBhGiIiISFYMI0RERCQrhhEiIiKSFcMIERERyYphhIiIiGTFMEJERESy8iiM5Obm4pprrkFYWBhiY2Mxe/ZsnDx5stP9Pv74YwwdOhQ6nQ6jRo3Chg0but1gIiIi6lk8CiPbtm1DTk4O9uzZg40bN8JqtWLatGkwmUzt7rNr1y7cdddduP/++3Ho0CHMnj0bs2fPxpEjR6648URERBT4BFEUxe7uXFlZidjYWGzbtg033HBDm9vMmTMHJpMJX3zxhfu1a6+9FmPGjMHy5cu7dByj0Qi9Xg+DwcB70xAREQWIrn5+X1HNiMFgAABERka2u83u3bsxderUVq9Nnz4du3fvbncfs9kMo9HY6iGFNQfPY8mnR7C3oFqS9yciIqLOdTuM2O12PPHEE5g4cSJGjhzZ7nZlZWWIi4tr9VpcXBzKysra3Sc3Nxd6vd79SElJ6W4zO7TlZCXe2X0OR0qkCTtERETUuW6HkZycHBw5cgQffvihN9sDAFi0aBEMBoP7UVxc7PVjAECoVgkAMJmbJXl/IiIi6pyqOzs99thj+OKLL7B9+3YkJyd3uG18fDzKy8tbvVZeXo74+Ph299FqtdBqtd1pmkdCtY7TZxghIiKSj0c9I6Io4rHHHsPatWuxefNmpKWldbpPZmYmNm3a1Oq1jRs3IjMz07OWSiDEGUbqGUaIiIhk41HPSE5ODlatWoVPP/0UYWFh7roPvV6PoKAgAMDcuXORlJSE3NxcAMCCBQtw44034uWXX8asWbPw4YcfYv/+/XjjjTe8fCqeY88IERGR/DzqGVm2bBkMBgMmTZqEhIQE92P16tXubYqKilBaWup+PmHCBKxatQpvvPEG0tPT8cknn2DdunUdFr36CntGiIiI5OdRz0hXliTZunXrZa/dfvvtuP322z05lE8wjBAREcmvV9+b5tJsGpvMLSEiIuq9enkYUQNgzQgREZGcenUYCXH2jHCYhoiISD69OoxwNg0REZH8enUYcRWwmiw22O3dvl8gERERXYFeHUZcPSMAYLKwd4SIiEgOvTqMaFUKKBUCAM6oISIikkuvDiOCILh7R1jESkREJI9eHUYAFrESERHJrdeHEU7vJSIikhfDCIdpiIiIZNXrwwiHaYiIiOTV68NIiIZhhIiISE4MI+5hGk7tJSIikkOvDyNhOvaMEBERyanXhxHOpiEiIpIXwwhn0xAREcmq14cRzqYhIiKSV68PI67ZNOwZISIikgfDCHtGiIiIZNXrw4hrNg17RoiIiOTR68PIpZ4RrjNCREQkh14fRkI5tZeIiEhWvT6MtKwZEUVR5tYQERH1PgwjzjDSbBdhbrbL3BoiIqLeh2HEObUX4IwaIiIiOfT6MKJUCAjWsG6EiIhILr0+jABcEp6IiEhODCNouSQ8p/cSERH5GsMILt25lzUjREREvscwAt6fhoiISE4MI+Cde4mIiOTEMAIWsBIREcmJYQRAKG+WR0REJBuGEXCYhoiISE4eh5Ht27cjOzsbiYmJEAQB69at63Sf999/H+np6QgODkZCQgLmz5+P6urq7rRXEpcKWDm1l4iIyNc8DiMmkwnp6elYunRpl7bfuXMn5s6di/vvvx9Hjx7Fxx9/jH379uHBBx/0uLFS4dReIiIi+ag636S1rKwsZGVldXn73bt3IzU1Fb/4xS8AAGlpaXjooYfwxz/+0dNDSyaMNSNERESykbxmJDMzE8XFxdiwYQNEUUR5eTk++eQTzJw5U+pDd1moVg0AqG9iGCEiIvI1ycPIxIkT8f7772POnDnQaDSIj4+HXq/vcJjHbDbDaDS2ekjJNZvG2GSV9DhERER0OcnDyLFjx7BgwQIsXrwYBw4cwJdffomzZ8/i4Ycfbnef3Nxc6PV69yMlJUXSNnKYhoiISD6Sh5Hc3FxMnDgRTz31FEaPHo3p06fj9ddfx4oVK1BaWtrmPosWLYLBYHA/iouLJW1jmHNqbx2HaYiIiHzO4wJWTzU0NEClan0YpdIxe0UUxTb30Wq10Gq1UjfNLUznrBkxN0MURQiC4LNjExER9XYe94zU19cjLy8PeXl5AIDCwkLk5eWhqKgIgKNXY+7cue7ts7OzsWbNGixbtgwFBQXYuXMnfvGLX2D8+PFITEz0zllcIdcwjc0uotHKtUaIiIh8yeOekf3792Py5Mnu5wsXLgQAzJs3DytXrkRpaak7mADAvffei7q6Ovz973/HL3/5S0RERGDKlCl+NbU3WKOEQgDsomNGTbBG8g4jIiIichLE9sZK/IjRaIRer4fBYEB4eLgkxxj9m69gbGrG1wtvxMDYUEmOQURE1Jt09fOb96Zxalk3QkRERL7DMOLkqhup41ojREREPsUw4hTK6b1ERESyYBhxci98xjBCRETkUwwjTqHOmhEuCU9ERORbDCNOXBKeiIhIHgwjTlwSnoiISB4MI06sGSEiIpIHw4iTa52ROjNrRoiIiHyJYcSJU3uJiIjkwTDidGnRM4YRIiIiX2IYcQrlbBoiIiJZMIw4hbtqRrjOCBERkU8xjDixZoSIiEgeDCNOrpqRBosNNrsoc2uIiIh6D4YRJ1fNCMC1RoiIiHyJYcRJq1JCo3L85+BaI0RERL7DMNICl4QnIiLyPYaRFnizPCIiIt9jGGkhjNN7iYiIfI5hpAVO7yUiIvI9hpEWuCQ8ERGR7zGMtBDKMEJERORzDCMtuJaEr+fUXiIiIp9hGGmBNSNERES+xzDSgntqL8MIERGRzzCMtOCqGTEyjBAREfkMw0gLYawZISIi8jmGkRa4HDwREZHvMYy0wOXgiYiIfI9hpIVLy8EzjBAREfkKw0gLlxY9s0IURZlbQ0RE1DswjLQQ7gwjVpuIJqtd5tYQERH1DgwjLYRqVVAIjn8beedeIiIin2AYaUEQBIQHOepGjI0MI0RERL7gcRjZvn07srOzkZiYCEEQsG7duk73MZvNePbZZ9GvXz9otVqkpqZixYoV3Wmv5Fz3p2HPCBERkW+oPN3BZDIhPT0d8+fPx2233dalfe644w6Ul5fjrbfewsCBA1FaWgq73T9rMsKDnKuwNnJGDRERkS94HEaysrKQlZXV5e2//PJLbNu2DQUFBYiMjAQApKamenpYn2HPCBERkW9JXjPy2Wef4eqrr8af/vQnJCUlYfDgwfjVr36FxsZGqQ/dLe4wwpoRIiIin/C4Z8RTBQUF2LFjB3Q6HdauXYuqqio8+uijqK6uxttvv93mPmazGWaz2f3caDRK3Uw39zANFz4jIiLyCcl7Rux2OwRBwPvvv4/x48dj5syZeOWVV/DOO++02zuSm5sLvV7vfqSkpEjdTDf2jBAREfmW5GEkISEBSUlJ0Ov17teGDRsGURRx/vz5NvdZtGgRDAaD+1FcXCx1M93cU3tZM0JEROQTkoeRiRMnoqSkBPX19e7XTp06BYVCgeTk5Db30Wq1CA8Pb/XwFdcqrJxNQ0RE5Bseh5H6+nrk5eUhLy8PAFBYWIi8vDwUFRUBcPRqzJ0717393XffjaioKNx33304duwYtm/fjqeeegrz589HUFCQd87Ci9gzQkRE5Fseh5H9+/dj7NixGDt2LABg4cKFGDt2LBYvXgwAKC0tdQcTAAgNDcXGjRtRW1uLq6++Gvfccw+ys7Px2muveekUvEvPFViJiIh8yuPZNJMmTerwjrYrV6687LWhQ4di48aNnh5KFpd6RjhMQ0RE5Au8N80PcDYNERGRbzGM/IBrnRFDo7XDHiAiIiLyDoaRH3D1jDTbRTRabTK3hoiIqOdjGPmBYI0SSoUAgNN7iYiIfIFh5AcEQbi01gin9xIREUmOYaQN4ZzeS0RE5DMMI21wz6hhzwgREZHkGEba4L5zL2tGiIiIJMcw0gb2jBAREfkOw0gbuPAZERGR7zCMtEEfzCXhiYiIfIVhpA3uqb3sGSEiIpIcw0gbLt0sj2GEiIhIagwjbbhUM8JhGiIiIqkxjLTBPbWXPSNERESSYxhpA2fTEBER+Q7DSBtcNSMGhhEiIiLJMYy04dKiZ80QRVHm1hAREfVsDCNtcNWM2OwiGiw2mVtDRETUszGMtCFIrYRKIQBgESsREZHUGEbaIAjCpbVGOL2XiIhIUgwj7dCziJWIiMgnGEbawRk1REREvsEw0o4IZxipbbDI3BIiIqKejWGkHRHB7BkhIiLyBYaRdrBmhIiIyDcYRtpxaZiGYYSIiEhKDCPt0AdrAAC17BkhIiKSFMNIO1jASkRE5BsMI+1wFbDyzr1ERETSYhhph6uAlcM0RERE0mIYaYerZ4QFrERERNJiGGmHPshRwGpsssJuF2VuDRERUc/FMNIO1zCNKAJ1TbxZHhERkVQYRtqhUSkQrFECAGobOaOGiIhIKh6Hke3btyM7OxuJiYkQBAHr1q3r8r47d+6ESqXCmDFjPD2sLLjwGRERkfQ8DiMmkwnp6elYunSpR/vV1tZi7ty5uOmmmzw9pGy48BkREZH0VJ7ukJWVhaysLI8P9PDDD+Puu++GUqn0qDdFTlz4jIiISHo+qRl5++23UVBQgCVLlvjicF7Dhc+IiIik53HPiKdOnz6NZ555Bt988w1Uqq4dzmw2w2w2u58bjUapmtchPWtGiIiIJCdpz4jNZsPdd9+NF154AYMHD+7yfrm5udDr9e5HSkqKhK1snz6Yq7ASERFJTdIwUldXh/379+Oxxx6DSqWCSqXCb3/7W3z33XdQqVTYvHlzm/stWrQIBoPB/SguLpayme2KcC58xp4RIiIi6Ug6TBMeHo7Dhw+3eu3111/H5s2b8cknnyAtLa3N/bRaLbRarZRN6xJXzYiBPSNERESS8TiM1NfXIz8/3/28sLAQeXl5iIyMRN++fbFo0SJcuHAB7777LhQKBUaOHNlq/9jYWOh0uste90eumhEDFz0jIiKSjMdhZP/+/Zg8ebL7+cKFCwEA8+bNw8qVK1FaWoqioiLvtVBGXPSMiIhIeoIoin5/Fzij0Qi9Xg+DwYDw8HCfHfdoiQGzXtuBmDAtvn12qs+OS0RE1BN09fOb96bpQIRzBVZDgxUBkNmIiIgCEsNIB1zDNBabHU1Wu8ytISIi6pkYRjoQrFFCpRAA8M69REREUmEY6YAgCO7pvSxiJSIikgbDSCe4JDwREZG0GEY64S5i5TANERGRJBhGOnFp4TP2jBAREUmBYaQTXPiMiIhIWgwjnXDdubemgcM0REREUmAY6USks2ak1sSeESIiIikwjHSiT4gjjFxkzwgREZEkGEY6EckwQkREJCmGkU64Fj2rMTGMEBERSYFhpBOunhHOpiEiIpIGw0gnXAWsFxsssNt5514iIiJvYxjphGsFVrsIGJvYO0JERORtDCOd0KgUCNWqAAAXOVRDRETkdQwjXdAnhEWsREREUmEY6YI+rroRhhEiIiKvYxjpgj7BXGuEiIhIKgwjXcCFz4iIiKTDMNIFlxY+YwErERGRtzGMdEEka0aIiIgkwzDSBbxZHhERkXQYRrqABaxERETSYRjpAq4zQkREJB2GkS7gzfKIiIikwzDSBX14szwiIiLJMIx0gWtqL2+WR0RE5H0MI12gVSl5szwiIiKJMIx00aWFz1jESkRE5E0MI13kXhKeYYSIiMirGEa6iGuNEBERSYNhpIt4szwiIiJpMIx0EW+WR0REJA2Pw8j27duRnZ2NxMRECIKAdevWdbj9mjVrcPPNNyMmJgbh4eHIzMzEV1991d32yoY3yyMiIpKGx2HEZDIhPT0dS5cu7dL227dvx80334wNGzbgwIEDmDx5MrKzs3Ho0CGPGysn3iyPiIhIGipPd8jKykJWVlaXt3/11VdbPX/xxRfx6aef4vPPP8fYsWM9PbxsXAWsnNpLRETkXT6vGbHb7airq0NkZKSvD31FokIZRoiIiKTgcc/Ilfrzn/+M+vp63HHHHe1uYzabYTab3c+NRqMvmtahKOcwTTXDiFcdKzHizW8KUFhlgrHRivAgNWaNSkB2eiLi9Tq5m0dERD7g0zCyatUqvPDCC/j0008RGxvb7na5ubl44YUXfNiyzkWFagEAhkYrLM12aFSciHQlKuqa8MLnx7D++9LLvpdXXIs/fHkCj08ZiMenDIJSIcjQQiIi8hWfhZEPP/wQDzzwAD7++GNMnTq1w20XLVqEhQsXup8bjUakpKRI3cQORQSpoRAcN8u72GBBXDj/au+uktpG3PPPvSisMgEAstMTMWtUAiKC1civqMe6Qxew/9xFvPr1aewpqMZrd45FLP97ExH1WD4JIx988AHmz5+PDz/8ELNmzep0e61WC61W64OWdZ1CISAyRIuqejOq6s0MI91UXNOAu97cg/MXG5EUEYQ35o7DiES9+/vX9o/CT6/thzUHz+O5dUewp6AGd725Bx8/PMG98BwREfUsHo811NfXIy8vD3l5eQCAwsJC5OXloaioCICjV2Pu3Lnu7VetWoW5c+fi5ZdfRkZGBsrKylBWVgaDweCdM/AhV90Ii1i7p8lqw/yV3+L8xUakRgXjo4czWwWRlm67KhmfP34dEvQ6nKk04b6V38JkbvZxi4mIyBc8DiP79+/H2LFj3dNyFy5ciLFjx2Lx4sUAgNLSUncwAYA33ngDzc3NyMnJQUJCgvuxYMECL52C77hm1FTXM4x0x+/XH8fpinrEhGmx+qFMJEUEdbj9gJhQ/Ov+8egTrMZ3xbV49P2DsNtFH7WWiIh8xeNhmkmTJkEU2/9AWLlyZavnW7du9fQQfstVxFpVb+5kS/qhr4+V4197zgEAXr49vcvDXANjw/D2feNx1xt7sO1UJd7aUYgHb+gvZVOJiMjHOCXEAxym6R5jkxVP//t7AMAD16XhhsExHu0/JiUCi7OHAwD+9NUJHC0JvCE+IiJqH8OIB9xrjXCYxiPLtp5BtcmC/jEheGrGkG69x53XpGDa8DhYbSJ+8cEhNFltXm4lERHJhWHEA65hmmoTh2m6qqS2ESt2FAIAFmUNg1al7Nb7CIKAP/xkNGLDtDhTacLybWe82UwiIpIRw4gHXFNLq9gz0mUv//cUzM12jE+LxNRh7S901xWRIRosyR4BAHh96xkUVTd4o4lERCQzhhEPRPP+NB45WVaHNYfOAwB+PXMYBOHKV1KdOSoeEwdGwdJsxwufH73i9yMiIvkxjHjAPUzD2TRd8sb2AogiMGNEPMakRHjlPQVBwAu3jIRaKWDTiQpsPlHulfclIiL5MIx4wLXOiMliYwFlJ8qNTfjsuwsAgIdu9O5U3IGxoZg/MQ0A8If/nICNa48QEQU0hhEPhGlVUCsdQw28e2/H3tl1FlabiKv79cHYvn28/v6PTh4IfZAap8rrsfbQBa+/PxER+Q7DiAcEQUBUCIdqOtNgacb7ex2r8Eq1QJk+SI1HJw0AAPxl4yn2VBERBTCGEQ9xSfjO/fvAeRgarUiNCsbUYXGSHWfehFQk6HW4UNuI95yruxIRUeBhGPHQpem97BlpiyiKWLWvGABw74RUKBVXPoOmPTq1Ek9MHQQAWL7tDHtHiIgCFMOIh6KdM2o4vbdtRy4YcbzUCI1KgR+PTZb8eLddlYzkPkGoqrfgw31Fne9ARER+h2HEQ+4l4RlG2rR6vyMQzBgRD32wWvLjqZUKPOKsHfnH9gKYm9k7QkQUaBhGPMQ797avyWrDp3klAIA7rk7x2XH/Z1wy4sK1KDU0Yc1BzqwhIgo0DCMe4s3y2vflkTLUNTUjKSIIEwZE+ey4WpUSD93g6B15fWs+rDa7z45NRERXjmHEQ1FcEr5dH+13FK7efnUyFBIWrrblrvF9ERWiQXFNIz5z9s4QEVFgYBjxEJeEb1uFsQm7C6oBAD+5SvrC1R8K0ijxwPWONU2Wbs3nqqxERAGEYcRDrmGaKpMFosgPPJf1h0shisDYvhFIiQyWpQ0/vbYv9EFqFFSa8J8jpbK0gYiIPMcw4iHXMI2l2Y46c7PMrfEfX3zv+PDPHp0oWxvCdGrcNzEVAPD3zfmws3eEiCggMIx4KFijQqhWBQCoquNQDQCcv9iAA+cuQhCAWaMTZG3LvRNSEapV4URZHTadqJC1LURE1DUMI90QE+aoG6lkGAEArHf2imSkRSIuXCdrWyKCNfjptf0AOGbWcCiNiMj/MYx0Q4yziLWSRawAgM+/d8xe+ZGMQzQtzb8uFRqVAoeKarGnoEbu5hARUScYRrqBPSOXFFU34MgFI5QKAVkj4+VuDgAgNkyHO652zOh5fWu+zK0hIqLOMIx0A8PIJf89VgbAMUTjmvbsDx66YQCUCgHfnK7C4fMGuZtDREQdYBjpBoaRS/57tBwAMG14nMwtaS0lMhjZzmLaZdvYO0JE5M8YRrqBNSMOVfVm7D/nqMm4eYR/DNG09MikgQCA/xwpw5nKeplbQ0RE7WEY6Qb2jDhsOl4OuwiMStIjKSJI7uZcZkh8GKYOi4UoAsu3npG7OURE1A6GkW5gGHHw1yGally9I2sPXUBJbaPMrSEiorao5G5AIHKFkWqTBTa7CKWPbwrnD0zmZnyTXwUAmOaHQzQu4/r1QUZaJPYW1uDNbwqwJHuE3E0KGKIoosZkwdlqEyqMZtSZm2G22qBRKaBTKxEXrkNSRBAS9DqolPy7hoi6j2GkGyJDNBAEwGYXcbHBgmg/mkXiK9+croSl2Y5+UcEYHBcqd3M69OjkgdhbuA8f7ivG41MGIdJ5fyG6XH5FPTYdL8e3Z2tw4NxFXGywdrqPTq3A0PhwpCfrkTkgGtf2j0REMP8bE1HXMYx0g1qpQGSwBtUmCyrrzL0yjGw5UQkAmDI0FoLg3z1DNwyKxojEcBwtMWLlzkIsnDZE7ib5lYsmC1bvL8aag+dxqvzyQt+kiCDE63UI06mgUylhsdlhMjejos6MC7WNaLLakVdci7ziWryz+xwUAjBhQDR+NDoBM0bGM5gQUacYRropJkzrDiPD5L0di8+JoohtpxxhZNKQWJlb0zlBEPDopIHIWXUQK3edxc9vHOC+v1BvVlzTgKVb8rH20AWYm+0AALVSwIQB0bhuYDTGpfbB8IRw6NTKdt/DbhdxttqEIyVG7D9bg535VThTacKO/CrsyK/Cc+uO4LpB0fjx2CTMGBkPrar99yKi3ov/R+6mmDAtTpTV9coi1hNldSgzNkGnViAjLVLu5nTJjJHx6B8dgoIqE1btPYef3zBA7ibJ5qLJgr98fQof7CuC1ea4d8+IxHDMzeyHGSMSoA9Wd/m9FAoB/WNC0T8mFLekO24HcK7ahC++L8UX35fieKkRW09WYuvJSkSFaHDHNSm4e3xfpEQGS3JuRBSYGEa6qTevNbL1pKNXJLN/VId/NfsTpULAQzf2x9P/Pox/flOIuZmpAdN2b7HbRXy0vxh//PKEuxZk4sAoPDF1MK7u18drw239okKQM3kgciYPxJnKenyWV4KP9hej1NCEZVvPYPm2M5gyJBb3TkzFdQOj/X6Yj4ikxzDSTb15eu+2UxUAAmOIpqUfj03GX78+jRJDEz7YV4T7JqbJ3SSfKa5pwFOffOe+ceDQ+DAszh6OCQOiJT3ugJhQPHnzYDw+ZSA2najAe3vO4ZvTVdh0ogKbTlRgcFwo5k9Mw+yxSb0uHBLRJZyP1029NYzUNVmx/+xFAMCkITEyt8YzGpUCOVMc6468vvUMmqw2mVvkG2sOnkfWX7/BnoIaBKmVeG7WMHzx+HWSB5GWVEoFpo+Ix7/uz8CWX03CvRNSEaJR4lR5PZ5ZcxgT/rAZL//3JCqMTT5rExH5D497RrZv346XXnoJBw4cQGlpKdauXYvZs2d3uM/WrVuxcOFCHD16FCkpKXjuuedw7733drPJ/qG3hpGd+dVototIiw5Bv6gQuZvjsdvHpeD1LWdwobYR7+05hweu7y93kyRjbrbht58fw/t7iwA41lx5+fZ0pEbLe93SokPwm1tGYOG0wfjo22K8vfMsLtQ24m+b87F82xlkpydi/sQ0jEzSy9rOnqrB0gxDoxV1Tc2oa7LC2NgMk6UZlma742FzfAUcw5tKhQBBEKBRCgjWqBCqVSFYo0SIVuV8KBGuU7Nni66Ix2HEZDIhPT0d8+fPx2233dbp9oWFhZg1axYefvhhvP/++9i0aRMeeOABJCQkYPr06d1qtD/orTUjriGaGwcHVq+Ii0alwONTBuKZNYexfFsB7snohyBNz/ufaKmhEY+8dxB5xbUQBGDBTYPw+JRBfrVAX7hOjQeu7497J6Tiv8fKsWJHIfafu4g1By9gzcELyEiLxP3XpeGmYXF+1W5/12BpRkGlCWcq65FfUY/zFxtRbmxyPsyoNzdLctwgtRJ9gtWICNagT4jja2SwptVrfYI17kdEiBphWhVrhghAN8JIVlYWsrKyurz98uXLkZaWhpdffhkAMGzYMOzYsQN/+ctfAjuM9MKeEVEU3cWrgTZE09JPxiVj6dZ8FNc04u1dhXjUuWR8T7ErvwqPf3AI1SYL9EFqvHrnGEz24/oelVKBmaMSMHNUAvKKa7FiRyE2HC7F3sIa7C2sQb+oYPw0ox9+fFVSr1zTpz3mZhvyK+pxqrwOJ8rqcKqsDqfK63GhC7c9UCoEhOtUCNOpER6kQrBGBa1KAY1SAY1KAbVzRV2bKEIURdjsIizNdjRYbDBZmtFgdnw1Ob+KItBotaHRYEOJoetDbSqF4AgqwY6gEhGsRmSIpsPXIoI1DKceEEURJosNhkYrahssMDRaYWiwOp43Or82WGFstOLBG/pjTEqELO2UvIB19+7dmDp1aqvXpk+fjieeeKLdfcxmM8zmSx/yRqNRquZ1myuMGBqtMDfbesX6CafK61FqaIJWpcC1/aPkbk63qZUKLLx5MJ5c/R2WbT2Du67piz49ZFXW9/eew/PrjsAuAsMTwrH8p+PQNypwptGOSYnAa3eNxaKZQ/Hu7nNYtbcI56ob8PsNx/HHL0/gpmGxuOPqFNw4OKbXLEHfZLWhqKYBBZUmnCqvw8myOpwsr0NhlQk2u9jmPpEhGgyMCcWAWMdwany4DrHhWsSF6xAbpkWoF3skRFFEnbkZtSYrLjZYLj1Mjg+/iw1W1DRYHP92vlbTYEGT1Y5mu4iqejOqPOxhDtepWoeWEA2iQjSIDNE6v2rcr4XqHENLWpUioHthLM12R5BotLgDRMuvBnewsLiDhtH5/eZ2fk5+aNqIuJ4bRsrKyhAX1/pGanFxcTAajWhsbERQ0OV3e83NzcULL7wgddOuiD5IDY1SAYvNjup6CxL98K613rb1pGOI5toAmtLbnlvTk/Dm9kIcKzXi71vy8fyPhsvdpCsiiiJe/fo0/rrpNADgtrFJePG2UQF7nRL0QXh6xlA8PmUg1h0qwer9xfiuuBZfHS3HV0fLEROmxbThcbh5eBwyB0QF/B8D5mYbimsacLaqAWerTSisMuFstQlnqxpQYmiE2M5nSbhOhaHx4RgcH4oh8eEYEheGgbGhPr3lgSAICNepEa5TexR8m6y2VqGlxhlcak2Or5eCjfP7JgvqmhxDTMamZhibmoHqhi4fT6kQEKJRIkynRohWiWCNChqVAlpnT5CrV8j9ULb+qlIKUAqOGhqVQoBSqXB8dT1v9e/W3xMEAVabHWZ3XY4NZqujPqfJakN9UzPqzM2ob2pGvbnZUc9jbkZ9k9X9vMFyZQX3GqUC+mA19EFqRAQ5vl56roE+SCVrnZZfTu1dtGgRFi5c6H5uNBqRkpIiY4suJwgCYsK0uFDbiMo6c68II5dWXQ3cIRoXhULAM1lDMXfFPry7+yzunZAasAtx2ewinlt3BB/scxSq/mLKQDx58+CA/ivQJVijwt0ZfXF3Rl+cLKvDR/uLsfbQBVTWmfH+3iK8v7cIoVoVbhwSgxsHxeCqfhHoHx0KhR9149udf/2XGJpQWtuIEkMTSmobUWpoREltE0oNjaioM7cbOAAgTKtCanQIBsWFYmh8GAbHhWFofDjiwrUBe511aiUS9EFI0Hf9/53NNjtqG13hxBlYTI4gU1Pv/GpyPKrrHWHG9SFus4uXQkyAEgRHrZU+SI0IZ5BwPSJaBovgy1/Tqf27Z0jyMBIfH4/y8vJWr5WXlyM8PLzNXhEA0Gq10Gr9f2w42hlGynvBdMR6czO+PetYoyLQ1hdpzw2DY3DdwGjsyK/CH/5zAkvvuUruJnmsyWrDLz44hP8eK4cgAL+9dSR+dm0/uZsliSHxYXj+R8Px9Iyh2HWmChuPlWPjsXJU1Jmx/vtSrP++FICj13Js3wiMTemD/jEh6BcVjH5RIdAHdX1l2fZYmh335ak3O2ag1DU1o7re9QFoRnWLD8Jq52s1Jot7pduOhGiUSI0OcTyigpEaFYI05/OoEI1ff5D4ikqpQHSo1qPaIZtddNa3NDuvnQ0ms6OnwdVL4fgqXppR5Hzd6nzN3GyHze4YVrK1eDS3+mq//HWbc1tRhFrp6IVx9cZc+rcSYc6hpFCdCmHur2r3EFOYToWIIA3CdCq/CtreJHkYyczMxIYNG1q9tnHjRmRmZkp9aMnFOetGyntBEeuu/CpYbSL6RQUjTeapod7065nD8KO/fYP1h0txT34VJgz03dobV8rQYMUD736Lb89ehEalwGt3jsGMkT3/RkkalQKThsRi0pBY/O7Wkfj+ggFfHyvHvrM1+P58LQyNVvcS9C1FBKvd9RKhOjVCtUoEqVWwiyIsNjuabXY020RY7SKanV3q7uBhdhRrWmz2brVZIQBx4Tok6HVIiAhCol6HBH0QEiNcX4MQHcrAIQVHsa5jGIn8l8dhpL6+Hvn5+e7nhYWFyMvLQ2RkJPr27YtFixbhwoULePfddwEADz/8MP7+97/jf//3fzF//nxs3rwZH330EdavX++9s5BJXLgOAHrFQk1bXUM0ATqltz3DE8Px02v74d3d57Dks6PYsOB690wCf1ZqaMS8FftwqrweYToV3px7dUAXFXeXQiFgTEqEu+jOarPjeKkRB89dxOELRhTVmHC2ugGVdWbUNjiK+bxBq1K4/5KNdBdPXiqgjArVOF/XIipUg9gwba8puCXqDo/DyP79+zF58mT3c1dtx7x587By5UqUlpaiqKjI/f20tDSsX78eTz75JP76178iOTkZ//znPwN6Wq9LvN4RRso8mMoWiERRxLaTgXOXXk/98uYh+OL7UpyuqMc7u876/UJop8vrMG/FPpQYmhAbpsW794/H0PhwuZvlF9RKBUYnR2B0ckSr103mZhTVNOCiydKqULDBYoNKIUClFKBSKqBRClApHMWKWpXCvbBXqOurRoVgrTIgAitRIPE4jEyaNAliB5VWK1eubHOfQ4cOeXoov+fqGSnr4T0j56obcKG2EWqlgIz+gXGXXk/og9V4esYQPP3vw3hl4ylMHxHvt8WsB87VYP7K/TA0WtE/JgTvzh+P5D7+2VZ/EqJVYVgCAxuRv2K8vwJx4Y6akQpjz64Z2ZFfBQC4qm8fBGv8cgLWFbt9XArGp0WiwWLDM2u+7zBwy2XjsXLc/eZeGBqtGJMSgU8ensAgQkQ9AsPIFYjvJT0jO51h5PpBgVPc6SmFQsCffjIaOrUCO/Or8cG+Yrmb1MqH+4rw0L/2w9xsx+QhMVj1YIZP15IgIpISw8gViHWGEUOjtcfeAdZmF7HrTDUAYGIAzTTpjtToEDw1fSgA4Pfrj6GwyiRzixz1On/9+jSeWXMYdhH4n3HJeGPu1T22h4qIeieGkSsQrlMhyLnCZU9da+RoiQGGRivCdCqM6gV3Ub13QirGp0XCZLEh5/2DsoZMS7Mdv/z4O/zl61MAgEcmDcBL/zOaxZNE1OPw/2pXQBAEd91IT51R46oXyewf1SumJioVAl67cywiQzQ4VmrEb784Jks7ahssmLtiL9YcvAClQsD/zR6Jp2cM5ToURNQj9fxPF4m5ZtT01IXPXPUi1/XgepEfitfr8OqcMRAEYNXeIvcy675yrtqE25btwp6CGoRqVXhr3tX4aQ9dVZWICGAYuWKutUbKe2DPSJPVhm/PXgTQ8+tFfuiGwTFYcNMgAMCzaw9j47HyTvbwjp35Vfjx67tQUGlCol6HTx7J7JFruxARtcQwcoV68loj356tgaXZjgS9Dv170BLwXbXgpkG4fVwy7CLw2KqD7nvzSMFuF7F0Sz5+9tZe1JgsGJkUjrU5E7mYGRH1CgwjV8g9TNMDw4irXmTiwOheWasgCAJevG0UJg+JgbnZjp+9tRebT3i/h6TU0Ii5K/bhpa9Owi4Ct49LxicPT3D/bBER9XQMI1fIVcDaE8OIu16klw3RtKRWKrD0nqswaUgMmqx2PPjuAXy4r8gri6KJooh/HziP6X/Zjh35VdCpFfjDbaPw0u3p0DlnaRER9QYMI1co3t0z0rMKWGtMFhwtMQIAJgzsfTdgaylY47gR3Y/HJsFmF/HMmsN4/INDMFzBTdeOXDDg9uW78cuPv4OxqRnpyXqs/8X1uHN8Xy+2nIgoMHDlpCvUsmZEFMUeM5yx+0w1RBEYEheG2DAOF6iVCrx8ezoGxITg1a9P44vvS7GvsAaPTRmIOdekQKvqvCdDFEV8d96ApVvy3QWxQWolHr9pIB68vj/XDyGiXoth5ArFOodpLM121DZY0aeHLNHdsl6EHBQKAY9NGYTrB8XgydV5KKgyYfGnR/H6ljO4ZUwipg2Pw8gkfashliarDUdLjNhTUI1P8y7gVHk9AEAQgOzRiXgmaygSI4LkOiUiIr/AMHKFtColIkM0qDFZUF7X1GPCyKX1RXr3EE1b0lMi8J8nrsdH3xZj6ZYzKDM24Y3tBXhjewEARx1RkFqJBosNNSYLmu2X6ks0SgV+lJ6ARycNxMDYULlOgYjIrzCMeEFsmBY1JgvKDE09YipmUXUDimoaoFIIyEhjGGmLVqXEzzJTcfvVKdh8ogL/PVqGLScrYWi0XlY/FB2qQXpyBKaPiMf0kfHQB6llajURkX9iGPGCeL0OJ8rqesyMmp1nHL0iY/tGIETLH5GO6NRKzByVgJmjEiCKIi42WFFU04Bmmx1BGkevWXy4rsfUEhERSYGfNF7Q02bUsF6kewRBQGSIBpE9ZKiOiMhXWL7vBa4ZNaU9YEl4u13ELq4vQkREPsQw4gWJEa4w0ihzS67csVIjLjZYEaJRIj0lQu7mEBFRL8Aw4gWuqZkltYEfRnY560Wu7R/FdS+IiMgn+GnjBQl6RxgprQ38YZod+dUAgAkcoiEiIh9hGPEC1zBNnbkZxqbuLxEuN3OzDfsKHWGE9SJEROQrDCNeEKxRISLYsXZEIPeOHDxXiyarHdGhWgyO44JcRETkGwwjXuIaqgnkuhFXvch1A6O4LgYREfkMw4iXJDmHakoCeEaNa30R1osQEZEvMYx4SaD3jBibrPiuuBYAFzsjIiLfYhjxEtf03kCtGdlzphp2EegfHYIk3kWWiIh8iGHES1wzai4EaM/ITi4BT0REMmEY8RJ3z0iALgm/84xjSu/EgbxLLxER+RbDiJck6C8tCW+3izK3xjNlhibkV9RDEIDM/uwZISIi32IY8ZK4cB0UAmC1iagyBdbde11DNKOT9NA710shIiLyFYYRL1ErFYgNc07vDbAi1p2c0ktERDJiGPEi9917A6iIVRRF7HQvdsYwQkREvscw4kUJziLWQJpRc6ayHuVGM7QqBcb16yN3c4iIqBdiGPGipACcUbPjtKNX5JrUSOjUSplbQ0REvVG3wsjSpUuRmpoKnU6HjIwM7Nu3r8PtX331VQwZMgRBQUFISUnBk08+iaamwPnA7irXjJpAWoV1R75jSu8ETuklIiKZeBxGVq9ejYULF2LJkiU4ePAg0tPTMX36dFRUVLS5/apVq/DMM89gyZIlOH78ON566y2sXr0av/71r6+48f4m0JaEb7bZsbfAEUZYL0JERHLxOIy88sorePDBB3Hfffdh+PDhWL58OYKDg7FixYo2t9+1axcmTpyIu+++G6mpqZg2bRruuuuuTntTAlFyn8CqGfn+ggF15mbog9QYkaiXuzlERNRLeRRGLBYLDhw4gKlTp156A4UCU6dOxe7du9vcZ8KECThw4IA7fBQUFGDDhg2YOXNmu8cxm80wGo2tHoEgpU8wAKCq3oIGS7PMrencTme9yIQBUVAqBJlbQ0REvZVHYaSqqgo2mw1xcXGtXo+Li0NZWVmb+9x999347W9/i+uuuw5qtRoDBgzApEmTOhymyc3NhV6vdz9SUlI8aaZs9MFqhOlUAIALF/2/d2QH1xchIiI/IPlsmq1bt+LFF1/E66+/joMHD2LNmjVYv349fve737W7z6JFi2AwGNyP4uJiqZvpNcnO3pHiiw0yt6RjDZZmHCy6CID1IkREJC+VJxtHR0dDqVSivLy81evl5eWIj49vc5/nn38eP/vZz/DAAw8AAEaNGgWTyYSf//znePbZZ6FQXJ6HtFottFqtJ03zGyl9gnC81Ijzft4zsregBlabiKSIIKRGBcvdHCIi6sU86hnRaDQYN24cNm3a5H7Nbrdj06ZNyMzMbHOfhoaGywKHUulYz0IUA+uGcl2REunsGanx756RbacqAQA3DI6BILBehIiI5ONRzwgALFy4EPPmzcPVV1+N8ePH49VXX4XJZMJ9990HAJg7dy6SkpKQm5sLAMjOzsYrr7yCsWPHIiMjA/n5+Xj++eeRnZ3tDiU9iWtGTXGNf/eMbD3pmIp94+AYmVtCRES9ncdhZM6cOaisrMTixYtRVlaGMWPG4Msvv3QXtRYVFbXqCXnuuecgCAKee+45XLhwATExMcjOzsbvf/97752FH3HNqDlf6789I2erTDhb3QCVQsBELnZGREQyE8QAGCsxGo3Q6/UwGAwIDw+XuzkdOlFmxIxXv4E+SI3vlkyTuzltemfXWSz57Cgy0iKx+qG2h9eIiIiuVFc/v3lvGi9z9YwYGq0wNlllbk3bXPUik4bEytwSIiIihhGvC9GqEBmiAQCc98O6kSarDbvOONYXmTSE9SJERCQ/hhEJuItY/XCtkX2FNWiy2hEXrsXQ+DC5m0NERMQwIgV3EasfrjXiGqK5kVN6iYjITzCMSCA50jW91/96Ri5N6WW9CBER+QeGEQkk+2nPSHFNA85UmqBUCLhuEJeAJyIi/8AwIoEUZ83IeT+rGXEN0YxNiYA+SC1za4iIiBwYRiTgvlleTYNfLXm/9aRrSi9n0RARkf9gGJGAazaNyWJDjckic2scLM32FlN6WS9CRET+g2FEAjq1EkkRjkBSWGWSuTUO+8/WoMFiQ3SoBsMT/HsVWyIi6l0YRiTSPyYEAHCmsl7mlji0vEuvQsEpvURE5D8YRiQyICYUAHCm0j96Rjaf4F16iYjIPzGMSGRArDOMVMjfM3K2yoTTFfVQKQTWixARkd9hGJHIAD8apvn6eDkAIKN/JKf0EhGR32EYkchA5zBNUU0DzM02Wduy8ZgjjEwdFidrO4iIiNrCMCKRmDAtwrQq2EXgbJV8i59dNFnw7dkaAAwjRETknxhGJCIIwqW6ERmHaracrIBdBIbGhyElMli2dhAREbWHYURC7hk1MhaxuoZopg1nrwgREfknhhEJDYiVt4i1yWrDduf6IlMZRoiIyE8xjEhI7rVGvjldBZPFhgS9DqOS9LK0gYiIqDMMIxK6FEbqZblh3n8OlwIAskYmQBC46ioREfknhhEJ9YsKhkohoMFiQ5mxyafHNjfb3PUiM0fF+/TYREREnmAYkZBaqUDfKMcMlnwfF7HuzK9CnbkZceFaXNW3j0+PTURE5AmGEYkNjg0DABwvNfr0uBsOlwFwDNHwxnhEROTPGEYklp4SAQDIK6712TEtzXb896grjHCIhoiI/BvDiMTSUxyzWPKKan12zJ1nqmBsakZ0qBZXp0b67LhERETdwTAisdHJERAEoMTQhAofFbF+eugCAOBHoxOg5BANERH5OYYRiYVqVe66EV8M1ZjMzfjqqGMWzeyxSZIfj4iI6EoxjPiAa6jmu/O1kh/rv8fK0Gi1IS06BOnJXOiMiIj8H8OID4xJcUyt9UXPyNpDJQCAW8ckcqEzIiIKCAwjPjDGOaPm+2ID7HbpVmKtrDNjx2nHvWhmj+EQDRERBQaGER8YHBeKILUSdeZmFFRJt/jZZ9+VwC4CY/tGIDU6RLLjEBEReRPDiA+olAr3jeoOSTTFVxRFrP62CABw21XJkhyDiIhICgwjPjKmbwQA4KBEYeTAuYs4VV6PILUSt45JlOQYREREUmAY8ZFr+zsWH9t2skKSO/iu2ufoFclOT0C4Tu319yciIpJKt8LI0qVLkZqaCp1Oh4yMDOzbt6/D7Wtra5GTk4OEhARotVoMHjwYGzZs6FaDA9WEAdHQqRUoMTThZHmdV9/b0GDF+u9LAQB3je/r1fcmIiKSmsdhZPXq1Vi4cCGWLFmCgwcPIj09HdOnT0dFRUWb21ssFtx88804e/YsPvnkE5w8eRJvvvkmkpJ612wPnVqJCQOiAQCbjrf936q71hw6D3OzHUPjw9wzd4iIiAKFx2HklVdewYMPPoj77rsPw4cPx/LlyxEcHIwVK1a0uf2KFStQU1ODdevWYeLEiUhNTcWNN96I9PT0K258oJkyNBYAsOWE98KI3S7ivT3nAAD3ZPTl2iJERBRwPAojFosFBw4cwNSpUy+9gUKBqVOnYvfu3W3u89lnnyEzMxM5OTmIi4vDyJEj8eKLL8Jms7V7HLPZDKPR2OrRE0x2hpGDRRdx0WTxyntuOlGBM5UmhOlUXP6diIgCkkdhpKqqCjabDXFxca1ej4uLQ1lZWZv7FBQU4JNPPoHNZsOGDRvw/PPP4+WXX8b//d//tXuc3Nxc6PV69yMlJcWTZvqtpIggDI0Pg10Etp2q9Mp7/mPbGQDAPRn9EMbCVSIiCkCSz6ax2+2IjY3FG2+8gXHjxmHOnDl49tlnsXz58nb3WbRoEQwGg/tRXFwsdTN95qZhjt6RTV4Yqjlwrgb7z12ERqnAfRNTr/j9iIiI5OBRGImOjoZSqUR5eXmr18vLyxEfH9/mPgkJCRg8eDCUSqX7tWHDhqGsrAwWS9tDFVqtFuHh4a0ePcWUoY5epc3Hy1HXZL2i9/rHtgIAwOyxiYgL111x24iIiOTgURjRaDQYN24cNm3a5H7Nbrdj06ZNyMzMbHOfiRMnIj8/H3a73f3aqVOnkJCQAI1G081mB66r+kZgQEwITBYb1hy80O33OXLBgI3HHaHw5zf091bziIiIfM7jYZqFCxfizTffxDvvvIPjx4/jkUcegclkwn333QcAmDt3LhYtWuTe/pFHHkFNTQ0WLFiAU6dOYf369XjxxReRk5PjvbMIIIIgYN6EVADAO7vPduvGeaIo4vfrj0MUgVvSEzEwNszLrSQiIvIdlac7zJkzB5WVlVi8eDHKysowZswYfPnll+6i1qKiIigUlzJOSkoKvvrqKzz55JMYPXo0kpKSsGDBAjz99NPeO4sAc9tVyfjTlydRUGnCjvwq3DA4xqP9Nx2vwO6CamhUCjw1fYhErSQiIvINQZRibXIvMxqN0Ov1MBgMPaZ+5DefHcXKXWdx09BYvHXvNV3ez2qzY/qr21FQacJDN/bHoqxhEraSiIio+7r6+c1708hkbmY/AMDmkxU4csHQ5f1e33IGBZUmRIZokDN5oFTNIyIi8hmGEZn0jwlFdnoiRBF46pPvYWm2d7rPnoJq/HXTKQDA4h8N5w3xiIioR2AYkdGS7OGIDNHgeKkRy7ae6XDbGpMFT3yYB7sI/M+4ZK62SkREPQbDiIyiQ7X4zS0jAAB/33Iaewqq29yu1NCIe/65F2XGJgyICcFvbx3hy2YSERFJimFEZtmjEzB9RBysNhE//ede/Gv3WbhqikVRxLdna/DjpbtwvNSI6FAtlv10HII1Hk+CIiIi8lucTeMHGi02/O+/v8fn35UAAPpGBmNEYjhOldfhTKUJADAoNhRv33cNkvsEy9lUIiKiLuvq5zf/xPYDQRolXrtzDEYlheNPX55EUU0DimoaAABalQKzRidgSfYI6INYsEpERD0Pw4ifEAQBP79hAO64OgVHLhhxvNSIqFANbh4ex7vxEhFRj8Yw4mcigjW4blA0rhsULXdTiIiIfIIFrERERCQrhhEiIiKSFcMIERERyYphhIiIiGTFMEJERESyYhghIiIiWTGMEBERkawYRoiIiEhWDCNEREQkK4YRIiIikhXDCBEREcmKYYSIiIhkxTBCREREsgqIu/aKoggAMBqNMreEiIiIusr1ue36HG9PQISRuro6AEBKSorMLSEiIiJP1dXVQa/Xt/t9QewsrvgBu92OkpIShIWFQRAEr72v0WhESkoKiouLER4e7rX39Sc9/Rx5foGvp59jTz8/oOefI8+v+0RRRF1dHRITE6FQtF8ZEhA9IwqFAsnJyZK9f3h4eI/8AWupp58jzy/w9fRz7OnnB/T8c+T5dU9HPSIuLGAlIiIiWTGMEBERkax6dRjRarVYsmQJtFqt3E2RTE8/R55f4Ovp59jTzw/o+efI85NeQBSwEhERUc/Vq3tGiIiISH4MI0RERCQrhhEiIiKSFcMIERERyarHh5GlS5ciNTUVOp0OGRkZ2LdvX4fbf/zxxxg6dCh0Oh1GjRqFDRs2+KilnsvNzcU111yDsLAwxMbGYvbs2Th58mSH+6xcuRKCILR66HQ6H7XYM7/5zW8ua+vQoUM73CeQrh8ApKamXnaOgiAgJyenze39/fpt374d2dnZSExMhCAIWLduXavvi6KIxYsXIyEhAUFBQZg6dSpOnz7d6ft6+nsslY7Oz2q14umnn8aoUaMQEhKCxMREzJ07FyUlJR2+Z3d+zqXU2TW89957L2vvjBkzOn3fQLiGANr8fRQEAS+99FK77+lP17ArnwtNTU3IyclBVFQUQkND8ZOf/ATl5eUdvm93f3e7qkeHkdWrV2PhwoVYsmQJDh48iPT0dEyfPh0VFRVtbr9r1y7cdddduP/++3Ho0CHMnj0bs2fPxpEjR3zc8q7Ztm0bcnJysGfPHmzcuBFWqxXTpk2DyWTqcL/w8HCUlpa6H+fOnfNRiz03YsSIVm3dsWNHu9sG2vUDgG+//bbV+W3cuBEAcPvtt7e7jz9fP5PJhPT0dCxdurTN7//pT3/Ca6+9huXLl2Pv3r0ICQnB9OnT0dTU1O57evp7LKWOzq+hoQEHDx7E888/j4MHD2LNmjU4efIkbrnllk7f15Ofc6l1dg0BYMaMGa3a+8EHH3T4noFyDQG0Oq/S0lKsWLECgiDgJz/5SYfv6y/XsCufC08++SQ+//xzfPzxx9i2bRtKSkpw2223dfi+3fnd9YjYg40fP17MyclxP7fZbGJiYqKYm5vb5vZ33HGHOGvWrFavZWRkiA899JCk7fSWiooKEYC4bdu2drd5++23Rb1e77tGXYElS5aI6enpXd4+0K+fKIriggULxAEDBoh2u73N7wfS9QMgrl271v3cbreL8fHx4ksvveR+rba2VtRqteIHH3zQ7vt4+nvsKz88v7bs27dPBCCeO3eu3W08/Tn3pbbOcd68eeKtt97q0fsE8jW89dZbxSlTpnS4jT9fwx9+LtTW1opqtVr8+OOP3dscP35cBCDu3r27zffo7u+uJ3psz4jFYsGBAwcwdepU92sKhQJTp07F7t2729xn9+7drbYHgOnTp7e7vb8xGAwAgMjIyA63q6+vR79+/ZCSkoJbb70VR48e9UXzuuX06dNITExE//79cc8996CoqKjdbQP9+lksFrz33nuYP39+hzeEDKTr11JhYSHKyspaXSO9Xo+MjIx2r1F3fo/9icFggCAIiIiI6HA7T37O/cHWrVsRGxuLIUOG4JFHHkF1dXW72wbyNSwvL8f69etx//33d7qtv17DH34uHDhwAFartdX1GDp0KPr27dvu9ejO766nemwYqaqqgs1mQ1xcXKvX4+LiUFZW1uY+ZWVlHm3vT+x2O5544glMnDgRI0eObHe7IUOGYMWKFfj000/x3nvvwW63Y8KECTh//rwPW9s1GRkZWLlyJb788kssW7YMhYWFuP7661FXV9fm9oF8/QBg3bp1qK2txb333tvuNoF0/X7IdR08uUbd+T32F01NTXj66adx1113dXjzMU9/zuU2Y8YMvPvuu9i0aRP++Mc/Ytu2bcjKyoLNZmtz+0C+hu+88w7CwsI6HcLw12vY1udCWVkZNBrNZQG5s89G1zZd3cdTAXHXXupcTk4Ojhw50uk4ZWZmJjIzM93PJ0yYgGHDhuEf//gHfve730ndTI9kZWW5/z169GhkZGSgX79++Oijj7r0l0qgeeutt5CVlYXExMR2twmk69ebWa1W3HHHHRBFEcuWLetw20D7Ob/zzjvd/x41ahRGjx6NAQMGYOvWrbjppptkbJn3rVixAvfcc0+nReL+eg27+rngD3psz0h0dDSUSuVlFcLl5eWIj49vc5/4+HiPtvcXjz32GL744gts2bIFycnJHu2rVqsxduxY5OfnS9Q674mIiMDgwYPbbWugXj8AOHfuHL7++ms88MADHu0XSNfPdR08uUbd+T2WmyuInDt3Dhs3bvT4luyd/Zz7m/79+yM6Orrd9gbiNQSAb775BidPnvT4dxLwj2vY3udCfHw8LBYLamtrW23f2Weja5uu7uOpHhtGNBoNxo0bh02bNrlfs9vt2LRpU6u/LFvKzMxstT0AbNy4sd3t5SaKIh577DGsXbsWmzdvRlpamsfvYbPZcPjwYSQkJEjQQu+qr6/HmTNn2m1roF2/lt5++23ExsZi1qxZHu0XSNcvLS0N8fHxra6R0WjE3r17271G3fk9lpMriJw+fRpff/01oqKiPH6Pzn7O/c358+dRXV3dbnsD7Rq6vPXWWxg3bhzS09M93lfOa9jZ58K4ceOgVqtbXY+TJ0+iqKio3evRnd/d7jS8x/rwww9FrVYrrly5Ujx27Jj485//XIyIiBDLyspEURTFn/3sZ+Izzzzj3n7nzp2iSqUS//znP4vHjx8XlyxZIqrVavHw4cNynUKHHnnkEVGv14tbt24VS0tL3Y+Ghgb3Nj88xxdeeEH86quvxDNnzogHDhwQ77zzTlGn04lHjx6V4xQ69Mtf/lLcunWrWFhYKO7cuVOcOnWqGB0dLVZUVIiiGPjXz8Vms4l9+/YVn3766cu+F2jXr66uTjx06JB46NAhEYD4yiuviIcOHXLPJvnDH/4gRkREiJ9++qn4/fffi7feequYlpYmNjY2ut9jypQp4t/+9jf3885+j/3l/CwWi3jLLbeIycnJYl5eXqvfSbPZ3O75dfZz7msdnWNdXZ34q1/9Sty9e7dYWFgofv311+JVV10lDho0SGxqanK/R6BeQxeDwSAGBweLy5Yta/M9/PkaduVz4eGHHxb79u0rbt68Wdy/f7+YmZkpZmZmtnqfIUOGiGvWrHE/78rv7pXo0WFEFEXxb3/7m9i3b19Ro9GI48ePF/fs2eP+3o033ijOmzev1fYfffSROHjwYFGj0YgjRowQ169f7+MWdx2ANh9vv/22e5sfnuMTTzzh/u8RFxcnzpw5Uzx48KDvG98Fc+bMERMSEkSNRiMmJSWJc+bMEfPz893fD/Tr5/LVV1+JAMSTJ09e9r1Au35btmxp82fSdQ52u118/vnnxbi4OFGr1Yo33XTTZefdr18/ccmSJa1e6+j32Jc6Or/CwsJ2fye3bNnifo8fnl9nP+e+1tE5NjQ0iNOmTRNjYmJEtVot9uvXT3zwwQcvCxWBeg1d/vGPf4hBQUFibW1tm+/hz9ewK58LjY2N4qOPPir26dNHDA4OFn/84x+LpaWll71Py3268rt7JQTnQYmIiIhk0WNrRoiIiCgwMIwQERGRrBhGiIiISFYMI0RERCQrhhEiIiKSFcMIERERyYphhIiIiGTFMEJERESyYhghIiIiWTGMEBERkawYRoiIiEhWDCNEREQkq/8HpMEAHqUKbNQAAAAASUVORK5CYII=",
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "main(1)\n",
+    "\n",
+    "start = time.time()\n",
+    "sol = main(1)\n",
+    "end = time.time()\n",
+    "print(f\"Integration took in {end - start} seconds.\")\n",
+    "\n",
+    "plt.plot(sol.ts, sol.ys)\n",
+    "plt.show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.9.16 ('documentation_diffrax')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.16"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "d644d6de157dae45f4c41ea729963c4364743d9466df2c5e80d490cf71a3f866"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/neural_dde.ipynb b/examples/neural_dde.ipynb
new file mode 100644
index 00000000..990c5e1b
--- /dev/null
+++ b/examples/neural_dde.ipynb
@@ -0,0 +1,319 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Neural DDE"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This example demonstrates how to use Diffrax in order to solve a Delay Differential Equation (DDE) with known delays.  \n",
+    "Unlike ODEs that are identified by their vector field $f(t, y(t))$ and initial condition $y(0)=y_0$, DDEs are specified by their vector field $f$, deviated arguments $y(t-\\tau)$ and history function $\\phi(t)=y(t<0)$.\n",
+    "\n",
+    "We will model the [Lotka Volterra](https://en.wikipedia.org/wiki/Lotka%E2%80%93Volterra_equations) (LK) equations with one constant time delay defined as \n",
+    "\n",
+    "$$\n",
+    "\\begin{align}\n",
+    "& y_1'(t) = \\frac{1}{2} y_1(t) ( 1  - y_2(t-0.2)) \\\\\n",
+    "& y_2'(t) = -\\frac{1}{2} y_2(t)( 1  - y_1(t-0.2)) \\\\\n",
+    "& \\phi(t) = y(t<0) = (y_1, y_2) \n",
+    "\\end{align}\n",
+    "$$\n",
+    "\n",
+    "where $x_0, y_0$ are uniformly sampled in $[0.1,2]$."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This example is available as a Jupyter notebook [here](url)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import time\n",
+    "\n",
+    "import diffrax\n",
+    "import equinox as eqx  # https://github.com/patrick-kidger/equinox\n",
+    "import jax\n",
+    "import jax.nn as jnn\n",
+    "import jax.numpy as jnp\n",
+    "import jax.random as jrandom\n",
+    "import matplotlib.pyplot as plt\n",
+    "import optax  # https://github.com/deepmind/optax"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In order to model our problem as a DDE $y'(t) = f_{\\theta}(t, y(t), y(t-\\tau_1), \\dots, y(t-\\tau_d))$, we first need to define a `Delays` object that incorporates deviated arguments in our vector field $f$.  \n",
+    "\n",
+    "LK's initial time point $t=0$ has a derivative jump because $\\phi^{\\prime}(t=0^{-}) \\neq  y^{\\prime}(t=0^{+})$ and the history function $\\phi(t)$ has `None`.  \n",
+    "The DDE model only has one time delay so $d=1$ and our vector field will be $y'(t) = f_{\\theta}(t, y(t), y(t-\\tau))$"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "delays = diffrax.Delays(\n",
+    "    delays=[lambda t, y, args: 0.2], initial_discontinuities=jnp.array([0.0])\n",
+    ")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Below is defined the vector field $f_{\\theta}$. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class Func(eqx.Module):\n",
+    "    mlp: eqx.nn.MLP\n",
+    "\n",
+    "    def __init__(self, data_size, width_size, depth, *, key, **kwargs):\n",
+    "        super().__init__(**kwargs)\n",
+    "        self.mlp = eqx.nn.MLP(\n",
+    "            in_size=data_size,\n",
+    "            out_size=data_size,\n",
+    "            width_size=width_size,\n",
+    "            depth=depth,\n",
+    "            activation=jnn.relu,\n",
+    "            key=key,\n",
+    "        )\n",
+    "\n",
+    "    def __call__(self, t, y, args, history):\n",
+    "        return self.mlp(y, history[0])"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The `history` variable inside the network's `__call__`  is a tuple of deviated arguments. For example, if we possess a `Delays` object with 2 delays then the first element of tuple would be the first deviated argument $y(t-\\tau_1)$ and the second one $y(t-\\tau_1)$.  \n",
+    "In our case, `history[0]` corresponds to $y(t-0.2)$ and by extension `history[0][0]` is $y_1(t-0.2)$."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here we wrap up the entire DDE solve into a model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class NeuralDDE(eqx.Module):\n",
+    "    func: Func\n",
+    "\n",
+    "    def __init__(self, data_size, width_size, depth, *, key, **kwargs):\n",
+    "        super().__init__(**kwargs)\n",
+    "        self.func = Func(data_size, width_size, depth, key=key)\n",
+    "\n",
+    "    def __call__(self, ts, y0):\n",
+    "        solution = diffrax.diffeqsolve(\n",
+    "            diffrax.ODETerm(self.func),\n",
+    "            diffrax.Tsit5(),\n",
+    "            t0=ts[0],\n",
+    "            t1=ts[-1],\n",
+    "            dt0=ts[1] - ts[0],\n",
+    "            y0=y0,\n",
+    "            delays=delays,\n",
+    "            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),\n",
+    "            saveat=diffrax.SaveAt(ts=ts, dense=True),\n",
+    "        )\n",
+    "        return solution.ys"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We generate the LK dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def _get_data(ts, *, key):\n",
+    "    y0 = jrandom.uniform(key, (2,), minval=0.1, maxval=2.0)\n",
+    "\n",
+    "    def vector_field(t, y, args, history):\n",
+    "        return jnp.array(\n",
+    "            [\n",
+    "                1 / 2 * y[0] * (1 - history[0][1]),\n",
+    "                -1 / 2 * y[1] * (1 - history[0][0]),\n",
+    "            ]\n",
+    "        )\n",
+    "\n",
+    "    sol = diffrax.diffeqsolve(\n",
+    "        diffrax.ODETerm(vector_field),\n",
+    "        diffrax.Dopri5(),\n",
+    "        t0=ts[0],\n",
+    "        t1=ts[-1],\n",
+    "        dt0=ts[1] - ts[0],\n",
+    "        y0=lambda t: y0,\n",
+    "        adjoint=diffrax.NoAdjoint(),\n",
+    "        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),\n",
+    "        saveat=diffrax.SaveAt(ts=ts, dense=True),\n",
+    "        delays=delays,\n",
+    "    )\n",
+    "\n",
+    "    return sol.ys\n",
+    "\n",
+    "\n",
+    "def get_data(dataset_size, *, key):\n",
+    "    ts = jnp.linspace(0, 15, 200)\n",
+    "    key = jrandom.split(key, dataset_size)\n",
+    "    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)\n",
+    "    return ts, ys\n",
+    "\n",
+    "\n",
+    "def dataloader(arrays, batch_size, *, key):\n",
+    "    dataset_size = arrays[0].shape[0]\n",
+    "    assert all(array.shape[0] == dataset_size for array in arrays)\n",
+    "    indices = jnp.arange(dataset_size)\n",
+    "    while True:\n",
+    "        perm = jrandom.permutation(key, indices)\n",
+    "        (key,) = jrandom.split(key, 1)\n",
+    "        start = 0\n",
+    "        end = batch_size\n",
+    "        while end < dataset_size:\n",
+    "            batch_perm = perm[start:end]\n",
+    "            yield tuple(array[batch_perm] for array in arrays)\n",
+    "            start = end\n",
+    "            end = start + batch_size"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Main entry point. Try runnning `main()`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def main(\n",
+    "    dataset_size=256,\n",
+    "    batch_size=32,\n",
+    "    width_size=32,\n",
+    "    depth=2,\n",
+    "    tot_steps=500,\n",
+    "    lr=10e-3,\n",
+    "    seed=5678,\n",
+    "    plot=True,\n",
+    "    print_every=100,\n",
+    "):\n",
+    "    key = jrandom.PRNGKey(seed)\n",
+    "    data_key, model_key, loader_key = jrandom.split(key, 3)\n",
+    "\n",
+    "    ts, ys = get_data(dataset_size, key=data_key)\n",
+    "    _, _, data_size = ys.shape\n",
+    "\n",
+    "    model = NeuralDDE(data_size, width_size, depth, key=model_key)\n",
+    "\n",
+    "    @eqx.filter_value_and_grad\n",
+    "    def grad_loss(model, ti, yi):\n",
+    "        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])\n",
+    "        return jnp.mean((yi - y_pred) ** 2)\n",
+    "\n",
+    "    @eqx.filter_jit\n",
+    "    def make_step(ti, yi, model, opt_state):\n",
+    "        loss, grads = grad_loss(model, ti, yi)\n",
+    "        updates, opt_state = optim.update(grads, opt_state)\n",
+    "        model = eqx.apply_updates(model, updates)\n",
+    "        return loss, model, opt_state\n",
+    "\n",
+    "    optim = optax.adabelief(lr)\n",
+    "    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))\n",
+    "    for step, (yi,) in zip(\n",
+    "        range(tot_steps), dataloader((ys,), batch_size, key=loader_key)\n",
+    "    ):\n",
+    "        start = time.time()\n",
+    "        loss, model, opt_state = make_step(ts, yi, model, opt_state)\n",
+    "        end = time.time()\n",
+    "        if (step % print_every) == 0 or step == tot_steps - 1:\n",
+    "            print(f\"Step: {step}, Loss: {loss}, Computation time: {end - start}\")\n",
+    "\n",
+    "    if plot:\n",
+    "        plt.plot(ts, ys[0, :, 0], c=\"dodgerblue\", label=\"Real\")\n",
+    "        plt.plot(ts, ys[0, :, 1], c=\"dodgerblue\")\n",
+    "        model_y = model(ts, ys[0, 0])\n",
+    "        plt.plot(ts, model_y[:, 0], c=\"crimson\", label=\"Model\")\n",
+    "        plt.plot(ts, model_y[:, 1], c=\"crimson\")\n",
+    "        plt.legend()\n",
+    "        plt.tight_layout()\n",
+    "        plt.savefig(\"neural_ode.png\")\n",
+    "        plt.show()\n",
+    "\n",
+    "    return ts, ys, model"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.9.13 ('dde')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.13"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "85a82c4ca03817695851f28a9a8c882825d82c078a665316024297a0baa050af"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/mkdocs.yml b/mkdocs.yml
index 4d8d3fa5..50a7e4cc 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -98,6 +98,7 @@ nav:
         - 'usage/extending.md'
     - Examples:
         - Basic ODE/SDE/CDE examples: 'other_examples/basic-examples.md'
+        # - DDE : 'examples/dde.ipynb' to add in Basic examples
         - Coupled ODEs: 'examples/coupled_odes.ipynb'
         - Stiff ODE: 'examples/stiff_ode.ipynb'
         - Forcing terms: 'examples/forcing.ipynb'
@@ -106,6 +107,7 @@ nav:
             - Neural CDE: 'examples/neural_cde.ipynb'
             - Neural SDE: 'examples/neural_sde.ipynb'
             - Latent ODE: 'examples/latent_ode.ipynb'
+            # - Neural DDE: 'examples/neural_dde.ipynb'
             - Continuous normalising flow: 'examples/continuous_normalising_flow.ipynb'
         - Symbolic regression: 'examples/symbolic_regression.ipynb'
         - Steady state: 'examples/steady_state.ipynb'
@@ -131,6 +133,7 @@ nav:
         - 'api/interpolation.md'
         - 'api/brownian.md'
         - 'api/nonlinear_solver.md'
+        # - 'api/delays.md'
     - Further details:
         - 'further_details/faq.md'
         - 'further_details/acknowledgements.md'
diff --git a/test/julia_dde/dde.jl b/test/julia_dde/dde.jl
new file mode 100644
index 00000000..acf20d7c
--- /dev/null
+++ b/test/julia_dde/dde.jl
@@ -0,0 +1,315 @@
+# File that generates DDE dynamics of several systems
+
+using DifferentialEquations
+using DelimitedFiles
+using PyPlot
+
+# Basic check 1 =========================================
+function basic_check_1(du, u, h, p, t)
+    tau = p
+    hist1 = h(p, t-tau)[1]
+    du[1] =  u[1] * (1 - hist1)
+end
+
+h(p,t) = 1.2 * ones(1)   
+u0 = [1.2]
+tau = 1.0
+lags = [tau]
+p = (tau)
+tspan = (0.0, 50.0)
+
+
+prob = DDEProblem(basic_check_1, u0, h, tspan, p; constant_lags=lags)
+alg = MethodOfSteps(Tsit5()) # doesn't work with DP5 DP8 but works with Tsit5 and Bosh3
+sol = solve(prob,alg, saveat=0.05)
+usol = transpose(hcat(sol.u...))
+time = sol.t
+
+plot(sol.t, usol)
+plt.xlabel("Time")
+plt.show()
+
+writedlm("test_basic_check_1.txt", [sol.t usol])
+
+# Basic check 2 =========================================
+function basic_check_2(du, u, h, p, t)
+    tau = p
+    hist1 = h(p, t-tau)[1]
+    du[1] =  u[1] * (1 - hist1)
+end
+
+h(p,t) = 1.2 * ones(1)   
+u0 = [1.2]
+tau = 2.0
+lags = [tau]
+p = (tau)
+tspan = (0.0, 50.0)
+
+
+prob = DDEProblem(basic_check_2, u0, h, tspan, p; constant_lags=lags)
+alg = MethodOfSteps(BS3()) # doesn't work with DP5 DP8 but works with Tsit5 and Bosh3
+sol = solve(prob,alg, saveat=0.05)
+usol = transpose(hcat(sol.u...))
+time = sol.t
+
+plot(sol.t, usol)
+plt.xlabel("Time")
+plt.show()
+
+writedlm("test_basic_check_2.txt", [sol.t usol])
+
+
+# Basic check 3 & 4 =========================================
+function basic_check_3(du, u, h, p, t)
+    tau = p
+    hist1 = h(p, t-tau)[1]
+    du[1] =  u[1] * (1 - hist1)
+end
+
+h(p,t) = 1.2 * ones(1)   
+u0 = [1.2]
+tau = 3.0
+lags = [tau]
+p = (tau)
+tspan = (0.0, 50.0)
+
+
+prob = DDEProblem(basic_check_3, u0, h, tspan, p; constant_lags=lags)
+alg = MethodOfSteps(BS3()) # doesn't work with DP5 DP8 but works with Tsit5 and Bosh3
+sol = solve(prob,alg, saveat=0.05)
+usol = transpose(hcat(sol.u...))
+
+alg = MethodOfSteps(Tsit5()) # doesn't work with DP5 DP8 but works with Tsit5 and Bosh3
+sol2 = solve(prob,alg, saveat=0.05)
+usol2 = transpose(hcat(sol2.u...))
+time = sol.t
+
+plot(sol.t, usol2)
+plot(sol.t, usol)
+plt.xlabel("Time")
+plt.show()
+
+writedlm("test_basic_check_3.txt", [sol.t usol])
+writedlm("test_basic_check_4.txt", [sol.t usol2])
+
+# Basic check 5 & 6 =========================================
+function basic_check_5(du, u, h, p, t)
+    tau = p
+    hist1 = h(p, t-tau)[1]
+    du[1] =  u[1] * (1 - hist1)
+end
+
+h(p,t) = 1.2 * ones(1)   
+u0 = [1.2]
+tau = 4.0
+lags = [tau]
+p = (tau)
+tspan = (0.0, 50.0)
+
+
+prob = DDEProblem(basic_check_5, u0, h, tspan, p; constant_lags=lags)
+alg = MethodOfSteps(BS3()) # doesn't work with DP5 DP8 but works with Tsit5 and Bosh3
+sol = solve(prob,alg, saveat=0.05)
+usol = transpose(hcat(sol.u...))
+
+alg = MethodOfSteps(Tsit5()) # doesn't work with DP5 DP8 but works with Tsit5 and Bosh3
+sol2 = solve(prob,alg, saveat=0.05)
+usol2 = transpose(hcat(sol2.u...))
+time = sol.t
+
+plot(sol.t, usol2)
+plot(sol.t, usol)
+plt.xlabel("Time")
+plt.show()
+
+writedlm("test_basic_check_5.txt", [sol.t usol])
+writedlm("test_basic_check_6.txt", [sol.t usol2])
+
+# Basic check 7 =========================================
+function basic_check_7(du, u, h, p, t)
+    tau = p
+    hist1 = h(p, t-tau)[1]
+    du[1] =  u[1] * (1 - hist1)
+end
+
+h(p,t) = 1.2 * ones(1)   
+u0 = [1.2]
+tau = 4.0
+lags = [tau]
+p = (tau)
+tspan = (0.0, 50.0)
+
+
+prob = DDEProblem(basic_check_7, u0, h, tspan, p; constant_lags=lags)
+alg = MethodOfSteps(Kvaerno5()) # doesn't work with DP5 DP8 but works with Tsit5 and Bosh3
+sol = solve(prob,alg, saveat=0.05)
+usol = transpose(hcat(sol.u...))
+time = sol.t
+
+plot(sol.t, usol)
+plt.xlabel("Time")
+plt.show()
+
+writedlm("test_basic_check_7.txt", [sol.t usol])
+
+# Basic check 8 =========================================
+function basic_check_8(du, u, h, p, t)
+    tau1, tau2 = p
+    hist1 = h(p, t-tau1)[1]
+    hist2 = h(p, t-tau2)[1]
+    du[1] = - hist1 - hist2   
+end
+
+h(p,t) = 1.2 * ones(1)     
+lags = [1.0/3  1.0/5]
+p    = (1.0/3, 1.0/5)
+tspan = (0.0, 10.0)
+u0 = [1.2]
+
+prob = DDEProblem(basic_check_8, u0, h, tspan, p; constant_lags=lags)
+alg = MethodOfSteps(Tsit5())
+sol = solve(prob,alg, saveat=0.1)
+usol = transpose(hcat(sol.u...))
+time = sol.t
+
+plot(sol.t, usol)
+plt.xlabel("Time")
+plt.show()
+
+writedlm("test_basic_check_8.txt", [sol.t usol])
+
+# Basic check 9 =========================================
+function basic_check_9(du, u, h, p, t)
+    tau = p
+    hist1 = h(p, t-tau)[1]
+    du[1] =  0.2 * hist1 / (1+ hist1^10) - 0.1 * u[1] 
+end
+
+
+h(p,t) = 1.2 * ones(1)   
+u0 = [1.2]
+tau = 6.0
+lags = [tau]
+p = (tau)
+tspan = (0.0, 50.0)
+
+prob = DDEProblem(basic_check_9, u0, h, tspan, p; constant_lags=lags)
+alg = MethodOfSteps(Tsit5())
+sol = solve(prob,alg, saveat=0.1)
+usol = transpose(hcat(sol.u...))
+time = sol.t
+
+plot(sol.t, usol)
+plt.xlabel("Time")
+plt.show()
+
+writedlm("test_basic_check_9.txt", [sol.t usol])
+
+# Basic check 10 =========================================
+function basic_check_10(du, u, h, p, t)
+    hist1 = h(p, t- 2 -sin(t))[1]
+    du[1] =  u[1] * (1 - hist1)
+end
+
+function h_basic_check_10(p, t)
+    1.2 * ones(1)  
+end
+
+prob = DDEProblem(basic_check_10,  h_basic_check_10,  (0.0, 40.0) ; dependent_lags = ((u, p, t) -> 2 + sin(t),))
+alg = MethodOfSteps(BS3())
+sol = solve(prob,alg, saveat=0.1)
+usol = transpose(hcat(sol.u...))
+time = sol.t
+
+plot(sol.t, usol)
+plt.xlabel("Time")
+plt.show()
+
+writedlm("test_basic_check_10.txt", [sol.t usol])
+
+Basic check 11 =========================================
+function basic_check_11(du, u, h, p, t)
+    hist1 = h(p, t- 1/2*(exp(-u[1]^2) + 1))[1]
+    du[1] =  -10*  hist1
+end
+
+function h_basic_check_11(p, t)
+    1.0 * ones(1)  
+end
+
+prob = DDEProblem(basic_check_11,  h_basic_check_11,  (0.0, 5.0) ; dependent_lags = ((u, p, t) -> 1/2*(exp(-u[1]^2) + 1),))
+alg = MethodOfSteps(Kvaerno5())
+sol = solve(prob,alg, saveat=0.01)
+usol = transpose(hcat(sol.u...))
+
+alg = MethodOfSteps(Kvaerno4())
+sol = solve(prob,alg, saveat=0.01)
+usol2 = transpose(hcat(sol.u...))
+
+time = sol.t
+
+plot(sol.t, usol2, label="Kv4")
+plot(sol.t, usol, label="Kv5")
+plt.xlabel("Time")
+plt.legend()
+plt.show()
+
+writedlm("test_basic_check_11.txt", [sol.t usol])
+
+
+# Numerical check 1 =========================================
+function numerical_check_1(du, u, h, p, t)
+    hist1 = h(p, u[1])[1]
+    du[1] =   hist1
+end
+
+function h_numerical_check_1(p, t)
+    if t < 2.0
+        1/2 * ones(1)  
+    else 
+        1 * ones(1)
+    end
+end
+
+prob = DDEProblem(numerical_check_1,  h_numerical_check_1, u0=[1.0 * ones(1)], (2.0, 5.5) ; dependent_lags = ((u, p, t) -> u[1],))
+alg = MethodOfSteps(BS3())
+sol = solve(prob,alg, saveat=0.01)
+usol = transpose(hcat(sol.u...))
+
+alg = MethodOfSteps(Tsit5())
+sol = solve(prob,alg, saveat=0.01)
+usol2 = transpose(hcat(sol.u...))
+
+time = sol.t
+
+plot(sol.t, usol2, label="Kv4")
+plot(sol.t, usol, label="Kv5")
+plt.xlabel("Time")
+plt.legend()
+plt.show()
+
+writedlm("test_basic_numerical_check_1.txt", [sol.t usol])
+
+# Numerical check 2 =========================================
+function numerical_check_2(du, u, h, p, t)
+    hist1 = h(p, log(u[1]))[1]
+    du[1] =   hist1 * u[1] / t 
+end
+
+function h_numerical_check_2(p, t)
+    1 * ones(1)
+end
+
+prob = DDEProblem(numerical_check_2,  h_numerical_check_2, (1.0, 10.0) ; dependent_lags = ((u, p, t) -> log(u[1]),))
+alg = MethodOfSteps(BS3())
+sol = solve(prob,alg, saveat=0.01)
+usol = transpose(hcat(sol.u...))
+time = sol.t
+
+plot(sol.t, usol2, label="Kv4")
+plot(sol.t, usol, label="Kv5")
+plt.xlabel("Time")
+plt.legend()
+plt.show()
+
+writedlm("test_basic_numerical_check_2.txt", [sol.t usol])
diff --git a/test/julia_dde/test_basic_check_1.txt b/test/julia_dde/test_basic_check_1.txt
new file mode 100644
index 00000000..ea119434
--- /dev/null
+++ b/test/julia_dde/test_basic_check_1.txt
@@ -0,0 +1,1001 @@
+0.0	1.2
+0.05	1.1880598004963274
+0.1	1.1762384079684047
+0.15	1.1645346400763932
+0.2	1.1529473230412208
+0.25	1.141475299112477
+0.3	1.1301174233925737
+0.35	1.1188725616239938
+0.4	1.1077395901892904
+0.45	1.0967173961110868
+0.5	1.0858048770520767
+0.55	1.075000941315024
+0.6	1.064304507842764
+0.65	1.0537145062182005
+0.7	1.043229876664309
+0.75	1.0328495700441358
+0.8	1.0225725478607945
+0.85	1.012397782257473
+0.9	1.0023242560174284
+0.95	0.9923509625639848
+1.0	0.9824769054640597
+1.05	0.9729919100045454
+1.1	0.9641710833105949
+1.15	0.9559924042823509
+1.2	0.9484352472028569
+1.25	0.9414803817380578
+1.3	0.9351099729367993
+1.35	0.9293075812308284
+1.4	0.9240579895194752
+1.45	0.9193469718969539
+1.5	0.9151617604103639
+1.55	0.9114905049979656
+1.6	0.908322242295648
+1.65	0.9056468956369282
+1.7	0.9034552750529523
+1.75	0.9017390772724944
+1.8	0.9004908857219571
+1.85	0.8997041705253715
+1.9	0.8993732885043974
+1.95	0.8994933756505719
+2.0	0.9000599782649152
+2.05	0.9010645757958058
+2.1	0.9024828681252601
+2.15	0.9042879518774135
+2.2	0.9064536858175863
+2.25	0.9089546908522841
+2.3	0.9117663500291971
+2.35	0.9148648085372008
+2.4	0.9182269737063555
+2.45	0.9218305150079066
+2.5	0.9256538640542846
+2.55	0.9296759042659053
+2.6	0.933874940282238
+2.65	0.9382300491809107
+2.7	0.9427203181020574
+2.75	0.947324743221118
+2.8	0.9520222297488384
+2.85	0.9567915919312706
+2.9	0.9616115530497718
+2.95	0.966460745421006
+3.0	0.9713177103969424
+3.05	0.9761609834122871
+3.1	0.9809690973066327
+3.15	0.9857212903649177
+3.2	0.9903978262709598
+3.25	0.994979994107456
+3.3	0.9994501083559825
+3.35	1.0037915088969953
+3.4	1.0079885610098294
+3.45	1.012026655372699
+3.5	1.0158922080626982
+3.55	1.0195726605557998
+3.6	1.0230564797268564
+3.65	1.0263331578495996
+3.7	1.0293932125966407
+3.75	1.0322281870394703
+3.8	1.0348306496484578
+3.85	1.0371941942928529
+3.9	1.0393134402407835
+3.95	1.0411842638198525
+4.0	1.0428051134855039
+4.05	1.0441745684936956
+4.1	1.0452929807393159
+4.15	1.0461626207646408
+4.2	1.0467866844066491
+4.25	1.0471692927970233
+4.3	1.0473154923621484
+4.35	1.0472312548231126
+4.4	1.046923477195708
+4.45	1.046399981790429
+4.5	1.045669516212473
+4.55	1.0447417533617414
+4.6	1.0436272914328375
+4.65	1.042337653915069
+4.7	1.0408852368203725
+4.75	1.0392819068558548
+4.8	1.0375400168437656
+4.85	1.035672502303835
+4.9	1.0336924804820984
+4.95	1.0316132503508986
+5.0	1.0294482926088842
+5.05	1.0272106941313845
+5.1	1.0249138335503054
+5.15	1.022571403740183
+5.2	1.0201966152372772
+5.25	1.0178021962395725
+5.3	1.0154003926067763
+5.35	1.0130029678603216
+5.4	1.0106212031833641
+5.45	1.0082658974207848
+5.5	1.0059473670791879
+5.55	1.0036754463269015
+5.6	1.001459486993979
+5.65	0.9993083585721964
+5.7	0.9972304482150548
+5.75	0.9952332208235575
+5.8	0.9933231149285477
+5.85	0.9915061922710027
+5.9	0.9897878493607076
+5.95	0.9881728170111881
+6.0	0.98666516033971
+6.05	0.9852690478823175
+6.1	0.9839874271383723
+6.15	0.982821719353966
+6.2	0.9817729283832497
+6.25	0.9808416406884336
+6.3	0.9800280253397873
+6.35	0.9793318340156394
+6.4	0.9787524010023783
+6.45	0.9782886431944512
+6.5	0.9779390600943649
+6.55	0.9777017338126854
+6.6	0.977574329068038
+6.65	0.9775540931871072
+6.7	0.9776378561046368
+6.75	0.9778220303634302
+6.8	0.9781026111143497
+6.85	0.9784751761163168
+6.9	0.9789348857363132
+6.95	0.9794764829493784
+7.0	0.9800942933386124
+7.05	0.9807830491124823
+7.1	0.9815387397262052
+7.15	0.9823554024900034
+7.2	0.9832270447806774
+7.25	0.9841477660896762
+7.3	0.9851117580230986
+7.35	0.9861133043016919
+7.4	0.9871467807608523
+7.45	0.9882066553506255
+7.5	0.9892874881357057
+7.55	0.9903839312954362
+7.6	0.9914907291238099
+7.65	0.9926027180294678
+7.7	0.9937148265357003
+7.75	0.9948220752804471
+7.8	0.9959195770162965
+7.85	0.9970025366104858
+7.9	0.9980662510449018
+7.95	0.9991061094160795
+8.0	1.0001175929352037
+8.05	1.0010964714906228
+8.1	1.002039554761376
+8.15	1.0029438160382957
+8.2	1.0038064492120649
+8.25	1.0046248990358608
+8.3	1.0053968611253554
+8.35	1.0061202819587138
+8.4	1.0067933588765958
+8.45	1.007414540082155
+8.5	1.007982524641039
+8.55	1.0084962624813893
+8.6	1.008954954393842
+8.65	1.0093580520315268
+8.7	1.0097052579100678
+8.75	1.0099965254075827
+8.8	1.0102320587646836
+8.85	1.0104123130844762
+8.9	1.0105379943325612
+8.95	1.0106100593370324
+9.0	1.010629715788478
+9.05	1.0105984222399802
+9.1	1.0105178881071157
+9.15	1.0103900736679543
+9.2	1.010217190063061
+9.25	1.0100016992954937
+9.3	1.0097463142308054
+9.35	1.0094539985970423
+9.4	1.0091273664549263
+9.45	1.0087662220550175
+9.5	1.008373565013879
+9.55	1.0079529315784397
+9.6	1.007507748750831
+9.65	1.0070413342883853
+9.7	1.0065568967036356
+9.75	1.0060575352643175
+9.8	1.005546239993367
+9.85	1.005025891668922
+9.9	1.0044992618243216
+9.95	1.003969012748106
+10.0	1.0034376974840173
+10.05	1.0029077598309983
+10.1	1.0023815343431939
+10.15	1.0018612463299494
+10.2	1.0013490118558124
+10.25	1.0008468377405315
+10.3	1.0003566215590562
+10.35	0.9998801516415383
+10.4	0.9994191070733301
+10.45	0.9989750576949857
+10.5	0.9985494641022605
+10.55	0.9981436776461111
+10.6	0.9977589404326955
+10.65	0.9973963853233733
+10.7	0.9970570359347051
+10.75	0.9967418066384531
+10.8	0.9964515025615808
+10.85	0.996186819586253
+10.9	0.9959483443498359
+10.95	0.9957365542448972
+11.0	0.9955518174192055
+11.05	0.995396301931505
+11.1	0.9952711486235987
+11.15	0.9951750094554872
+11.2	0.9951065518609477
+11.25	0.9950644588892086
+11.3	0.9950474292049508
+11.35	0.9950541770883066
+11.4	0.9950834324348602
+11.45	0.9951339407556477
+11.5	0.9952044631771567
+11.55	0.9952937764413271
+11.6	0.9954006729055503
+11.65	0.9955239605426696
+11.7	0.9956624629409799
+11.75	0.9958150193042283
+11.8	0.9959804844516134
+11.85	0.9961577288177856
+11.9	0.9963456384528475
+11.95	0.9965431150223529
+12.0	0.9967490758073081
+12.05	0.9969624537041705
+12.1	0.9971821972248498
+12.15	0.9974072704967074
+12.2	0.9976366532625566
+12.25	0.997869340880662
+12.3	0.9981043443247409
+12.35	0.9983406901839615
+12.4	0.9985774206629445
+12.45	0.998813593581762
+12.5	0.999048282375938
+12.55	0.9992805760964485
+12.6	0.9995095794097211
+12.65	0.9997344125976353
+12.7	0.9999542115575222
+12.75	1.000168127802165
+12.8	1.000375328459799
+12.85	1.0005749962741102
+12.9	1.0007663296042377
+12.95	1.0009485424247715
+13.0	1.001120864325754
+13.05	1.001282540512679
+13.1	1.0014328318064922
+13.15	1.0015710146435914
+13.2	1.001696381075826
+13.25	1.0018080953978006
+13.3	1.0019058056111938
+13.35	1.0019899586093672
+13.4	1.0020610167310116
+13.45	1.0021194420712354
+13.5	1.0021656964815653
+13.55	1.002200241569946
+13.6	1.0022235387007397
+13.65	1.0022360489947266
+13.7	1.002238233329105
+13.75	1.0022305523374908
+13.8	1.0022134664099176
+13.85	1.0021874356928375
+13.9	1.00215292008912
+13.95	1.0021103792580526
+14.0	1.0020602726153405
+14.05	1.002003059333107
+14.1	1.0019391983398935
+14.15	1.001869148320659
+14.2	1.0017933677167798
+14.25	1.0017123147260512
+14.3	1.0016264473026857
+14.35	1.0015362231573137
+14.4	1.0014420997569835
+14.45	1.0013445343251617
+14.5	1.0012439838417322
+14.55	1.0011409050429971
+14.6	1.0010357544216761
+14.65	1.0009289882269072
+14.7	1.000821062464246
+14.75	1.0007124328956658
+14.8	1.0006035550395582
+14.85	1.0004948841707324
+14.9	1.0003868753204155
+14.95	1.0002799832762528
+15.0	1.0001746625823067
+15.05	1.0000713675390585
+15.1	0.9999705522034064
+15.15	0.9998726703886672
+15.2	0.9997781756645753
+15.25	0.9996875213572829
+15.3	0.9996011605493601
+15.35	0.9995195460797951
+15.4	0.9994431305439937
+15.45	0.9993723662937797
+15.5	0.9993078801202896
+15.55	0.9992505744009643
+15.6	0.9992002709754936
+15.65	0.9991566855771745
+15.7	0.9991195378792506
+15.75	0.9990885514949103
+15.8	0.9990634539772882
+15.85	0.999043976819465
+15.9	0.9990298554544664
+15.95	0.9990208292552645
+16.0	0.9990166415347769
+16.05	0.9990170395458667
+16.1	0.9990217744813433
+16.15	0.9990306014739615
+16.2	0.9990432795964218
+16.25	0.9990595718613707
+16.3	0.9990792452214002
+16.35	0.9991020705690484
+16.4	0.9991278227367987
+16.45	0.9991562804970805
+16.5	0.9991872265622691
+16.55	0.9992204475846853
+16.6	0.9992557341565957
+16.65	0.9992928808102128
+16.7	0.9993316860176947
+16.75	0.9993719521911454
+16.8	0.9994134856826143
+16.85	0.9994560967840972
+16.9	0.999499599727535
+16.95	0.9995438126848146
+17.0	0.9995885577677688
+17.05	0.999633661028176
+17.1	0.9996789524577604
+17.15	0.9997242659881919
+17.2	0.9997694394910861
+17.25	0.9998143147780046
+17.3	0.9998587376004546
+17.35	0.999902557649889
+17.4	0.9999456285577065
+17.45	0.9999878078952515
+17.5	1.0000289571738143
+17.55	1.000068941844631
+17.6	1.0001076312988828
+17.65	1.0001448988676978
+17.7	1.000180621822149
+17.75	1.000214681373255
+17.8	1.0002469626719812
+17.85	1.0002773548092376
+17.9	1.0003057508158808
+17.95	1.0003320476627124
+18.0	1.0003561462604804
+18.05	1.0003779514598783
+18.1	1.000397372051545
+18.15	1.0004143207660663
+18.2	1.0004287142739723
+18.25	1.0004404731857395
+18.3	1.0004495220517904
+18.35	1.0004556403120812
+18.4	1.0004586682601238
+18.45	1.0004587976048693
+18.5	1.0004562199748748
+18.55	1.0004511224400474
+18.6	1.0004436875116447
+18.65	1.0004340931422746
+18.7	1.0004225127258954
+18.75	1.000409115097816
+18.8	1.0003940645346958
+18.85	1.0003775207545444
+18.9	1.0003596389167222
+18.95	1.0003405696219398
+19.0	1.0003204589122578
+19.05	1.0002994482710883
+19.1	1.000277674623193
+19.15	1.000255270334684
+19.2	1.0002323632130246
+19.25	1.0002090765070275
+19.3	1.0001855289068569
+19.35	1.0001618345440266
+19.4	1.0001381029914014
+19.45	1.0001144392631962
+19.5	1.0000909438149763
+19.55	1.0000677125436577
+19.6	1.0000448367875068
+19.65	1.0000224033261402
+19.7	1.0000004943805254
+19.75	0.9999791876129798
+19.8	0.9999585561271713
+19.85	0.9999386684681189
+19.9	0.9999195886221911
+19.95	0.9999013760171075
+20.0	0.999884085521938
+20.05	0.9998677674471027
+20.1	0.9998524675443724
+20.15	0.9998382270068683
+20.2	0.999825082469062
+20.25	0.9998130660067754
+20.3	0.9998022051371811
+20.35	0.9997925228188019
+20.4	0.9997840374515113
+20.45	0.9997767628765329
+20.5	0.9997707083764411
+20.55	0.9997658786751604
+20.6	0.999762273937966
+20.65	0.9997598897714833
+20.7	0.9997587172236886
+20.75	0.999758742783908
+20.8	0.9997599483828186
+20.85	0.9997623113924474
+20.9	0.9997658046261724
+20.95	0.9997703963387217
+21.0	0.9997760502261739
+21.05	0.999782725425958
+21.1	0.9997903765168537
+21.15	0.9997989535189906
+21.2	0.9998084018938495
+21.25	0.9998186625442608
+21.3	0.9998296718144059
+21.35	0.9998413614898165
+21.4	0.9998536587973748
+21.45	0.9998664864053133
+21.5	0.9998797624232149
+21.55	0.9998934004020132
+21.6	0.9999073093339921
+21.65	0.9999213936527858
+21.7	0.999935553233379
+21.75	0.9999496833921071
+21.8	0.9999636748866557
+21.85	0.9999774139160608
+21.9	0.9999907821207089
+21.95	1.0000036565823371
+22.0	1.0000159098240327
+22.05	1.0000273831324133
+22.1	1.0000378192552613
+22.15	1.0000472193179122
+22.2	1.000055628495489
+22.25	1.000063091144495
+22.3	1.0000696508028153
+22.35	1.0000753501897153
+22.4	1.0000802312058417
+22.45	1.0000843349332225
+22.5	1.0000877016352665
+22.55	1.0000903707567637
+22.6	1.0000923809238846
+22.65	1.0000937699441819
+22.7	1.0000945748065881
+22.75	1.0000948316814176
+22.8	1.0000945759203652
+22.85	1.0000938420565073
+22.9	1.0000926638043013
+22.95	1.000091074059585
+23.0	1.000089104899578
+23.05	1.0000867875828805
+23.1	1.000084152549474
+23.15	1.0000812294207209
+23.2	1.0000780469993646
+23.25	1.0000746332695298
+23.3	1.000071015396722
+23.35	1.000067219727828
+23.4	1.000063271791115
+23.45	1.0000591962962322
+23.5	1.000055017134209
+23.55	1.0000507573774566
+23.6	1.0000464392797663
+23.65	1.0000420842763114
+23.7	1.0000377129836457
+23.75	1.000033345199704
+23.8	1.0000289999038028
+23.85	1.0000246952566387
+23.9	1.0000204486002902
+23.95	1.0000162764582161
+24.0	1.0000121945352567
+24.05	1.0000082177176333
+24.1	1.0000043600729482
+24.15	1.0000006348501846
+24.2	0.9999970544797072
+24.25	0.9999936305732612
+24.3	0.999990373923973
+24.35	0.9999872945063504
+24.4	0.9999844014762815
+24.45	0.9999817031710364
+24.5	0.9999792071092655
+24.55	0.9999769199910005
+24.6	0.9999748476976542
+24.65	0.9999729952920203
+24.7	0.9999713670182737
+24.75	0.9999699663019703
+24.8	0.999968795750047
+24.85	0.9999678571508217
+24.9	0.9999671514739935
+24.95	0.9999666788706424
+25.0	0.9999664386732294
+25.05	0.9999664293955969
+25.1	0.9999666487329679
+25.15	0.9999670935619466
+25.2	0.9999677599405185
+25.25	0.9999686431080497
+25.3	0.9999697374852876
+25.35	0.9999710366743607
+25.4	0.9999725334587785
+25.45	0.9999742198034313
+25.5	0.9999760868545908
+25.55	0.9999781249399096
+25.6	0.9999803235684213
+25.65	0.9999826714305405
+25.7	0.9999851563980632
+25.75	0.9999877655241658
+25.8	0.9999904850434064
+25.85	0.9999933003717237
+25.9	0.9999961961064378
+25.95	0.9999991560262493
+26.0	1.0000021630912406
+26.05	1.0000051994428745
+26.1	1.0000082464039952
+26.15	1.0000112844788278
+26.2	1.0000142814570883
+26.25	1.000017051088491
+26.3	1.0000195436035837
+26.35	1.0000217692482254
+26.4	1.000023738117155
+26.45	1.0000254601539926
+26.5	1.0000269451512378
+26.55	1.0000282027502714
+26.6	1.0000292424413542
+26.65	1.0000300735636278
+26.7	1.0000307053051143
+26.75	1.0000311467027156
+26.8	1.0000314066422147
+26.85	1.0000314938582748
+26.9	1.0000314169344398
+26.95	1.0000311843031333
+27.0	1.0000308042456605
+27.05	1.000030284892206
+27.1	1.0000296342218356
+27.15	1.0000288600624951
+27.2	1.0000279700910109
+27.25	1.0000269718330896
+27.3	1.0000258726633189
+27.35	1.0000246798051662
+27.4	1.0000234003309798
+27.45	1.0000220411619882
+27.5	1.0000206090683008
+27.55	1.0000191106689067
+27.6	1.0000175524316763
+27.65	1.0000159406733597
+27.7	1.0000142815595883
+27.75	1.0000125811048726
+27.8	1.0000108451726049
+27.85	1.0000090794750573
+27.9	1.0000072895733827
+27.95	1.000005480877614
+28.0	1.0000036586466647
+28.05	1.000001827988329
+28.1	0.9999999938592812
+28.15	0.9999981610650764
+28.2	0.9999963342601499
+28.25	0.9999945179478175
+28.3	0.9999927164802754
+28.35	0.9999909340586004
+28.4	0.9999891747327497
+28.45	0.9999874424015609
+28.5	0.999985740812752
+28.55	0.9999840735629216
+28.6	0.9999824440975487
+28.65	0.9999808557109925
+28.7	0.9999793115464931
+28.75	0.9999778145961706
+28.8	0.9999763677010259
+28.85	0.9999749735509403
+28.9	0.9999736346846753
+28.95	0.999972353489873
+29.0	0.9999711322030561
+29.05	0.9999699729096274
+29.1	0.9999688775438706
+29.15	0.9999678478889495
+29.2	0.9999668855769084
+29.25	0.9999659920886722
+29.3	0.9999651687540461
+29.35	0.9999644167517158
+29.4	0.9999637371092475
+29.45	0.9999631307030877
+29.5	0.9999625982585636
+29.55	0.9999621403498825
+29.6	0.9999617574001326
+29.65	0.9999614496812822
+29.7	0.9999612173141801
+29.75	0.9999610602685557
+29.8	0.9999609783630186
+29.85	0.9999609712650591
+29.9	0.9999610384910479
+29.95	0.999961179406236
+30.0	0.9999613932247549
+30.05	0.9999616790096166
+30.1	0.9999620356727138
+30.15	0.9999624619748192
+30.2	0.999962956525586
+30.25	0.9999635177835483
+30.3	0.9999641440561201
+30.35	0.9999648334995962
+30.4	0.9999655841191517
+30.45	0.9999663937688423
+30.5	0.9999672601516039
+30.55	0.9999681808192531
+30.6	0.9999691531724867
+30.65	0.9999701744608822
+30.7	0.9999712417828974
+30.75	0.9999723520858707
+30.8	0.9999735021660205
+30.85	0.9999746886684464
+30.9	0.9999759080871279
+30.95	0.9999771567649249
+31.0	0.9999784308935782
+31.05	0.9999797265137086
+31.1	0.9999810395148175
+31.15	0.999982365635287
+31.2	0.9999837004623792
+31.25	0.999985039432237
+31.3	0.9999863778298836
+31.35	0.9999877107892228
+31.4	0.9999890332930386
+31.45	0.9999903401729955
+31.5	0.9999916261096387
+31.55	0.9999928856323935
+31.6	0.999994113119566
+31.65	0.9999953027983425
+31.7	0.9999964487447898
+31.75	0.9999975448838552
+31.8	0.9999985849893664
+31.85	0.9999995626840317
+31.9	1.0000004714394395
+31.95	1.0000013045760592
+32.0	1.0000020552632398
+32.05	1.0000027165192118
+32.1	1.0000032812110853
+32.15	1.0000037420548513
+32.2	1.0000040916153812
+32.25	1.0000043223064266
+32.3	1.0000044263906196
+32.35	1.000004399006739
+32.4	1.0000043128626332
+32.45	1.000004203336026
+32.5	1.0000040723902253
+32.55	1.0000039219328472
+32.6	1.0000037538158149
+32.65	1.0000035698353584
+32.7	1.0000033717320158
+32.75	1.000003161190632
+32.8	1.000002939840359
+32.85	1.0000027092546566
+32.9	1.0000024709512918
+32.95	1.0000022263923385
+33.0	1.0000019769841784
+33.05	1.0000017240775
+33.1	1.0000014689672998
+33.15	1.000001212892881
+33.2	1.0000009570378545
+33.25	1.0000007025301378
+33.3	1.0000004504419566
+33.35	1.0000002017898435
+33.4	0.9999999575346381
+33.45	0.9999997185814878
+33.5	0.9999994857798472
+33.55	0.9999992599234779
+33.6	0.9999990417504491
+33.65	0.9999988319431372
+33.7	0.999998631128226
+33.75	0.9999984398767062
+33.8	0.9999982587038765
+33.85	0.9999980880693422
+33.9	0.9999979283770163
+33.95	0.9999977799751192
+34.0	0.9999976431561781
+34.05	0.9999975181570281
+34.1	0.9999974051588111
+34.15	0.9999973042869765
+34.2	0.9999972156112813
+34.25	0.9999971391457892
+34.3	0.9999970748488717
+34.35	0.9999970226232074
+34.4	0.9999969823157822
+34.45	0.9999969537178893
+34.5	0.9999969365651292
+34.55	0.9999969305374098
+34.6	0.9999969352589462
+34.65	0.9999969502982609
+34.7	0.9999969751681834
+34.75	0.9999970093258509
+34.8	0.9999970521727077
+34.85	0.9999971030545055
+34.9	0.9999971612613032
+34.95	0.999997226027467
+35.0	0.9999972965316705
+35.05	0.9999973718968944
+35.1	0.9999974511904269
+35.15	0.9999975334238635
+35.2	0.9999976175531069
+35.25	0.9999977024783672
+35.3	0.9999977870441616
+35.35	0.999997870039315
+35.4	0.999997950196959
+35.45	0.999998026194533
+35.5	0.9999980966537837
+35.55	0.9999981601407648
+35.6	0.9999982151658375
+35.65	0.9999982601836702
+35.7	0.9999982935932387
+35.75	0.9999983137378259
+35.8	0.9999983189050224
+35.85	0.9999983073267258
+35.9	0.9999982771791409
+35.95	0.99999822658278
+36.0	0.9999981536024628
+36.05	0.9999980562473161
+36.1	0.9999979324707741
+36.15	0.9999977993487448
+36.2	0.9999976838090294
+36.25	0.9999975853411772
+36.3	0.9999975031965993
+36.35	0.9999974366391983
+36.4	0.9999973849453675
+36.45	0.9999973474039922
+36.5	0.999997323316448
+36.55	0.9999973119966025
+36.6	0.999997312770814
+36.65	0.9999973249779324
+36.7	0.9999973479692984
+36.75	0.9999973811087441
+36.8	0.9999974237725932
+36.85	0.9999974753496597
+36.9	0.9999975352412496
+36.95	0.9999976028611599
+37.0	0.9999976776356788
+37.05	0.9999977590035856
+37.1	0.9999978464161507
+37.15	0.9999979393371363
+37.2	0.9999980372427951
+37.25	0.9999981396218716
+37.3	0.9999982459756009
+37.35	0.99999835581771
+37.4	0.9999984686744164
+37.45	0.9999985840844295
+37.5	0.9999987015989494
+37.55	0.9999988207816676
+37.6	0.9999989412087669
+37.65	0.999999062468921
+37.7	0.9999991841632953
+37.75	0.9999993059055459
+37.8	0.9999994273218205
+37.85	0.9999995480507576
+37.9	0.9999996677434876
+37.95	0.9999997860636314
+38.0	0.9999999026873014
+38.05	1.0000000173031014
+38.1	1.000000129612126
+38.15	1.000000239327961
+38.2	1.0000003461766842
+38.25	1.0000004498968635
+38.3	1.000000550239559
+38.35	1.0000006469683214
+38.4	1.0000007398591928
+38.45	1.0000008287007063
+38.5	1.0000009132938865
+38.55	1.0000009934522494
+38.6	1.0000010690018015
+38.65	1.000001139781041
+38.7	1.0000012056409577
+38.75	1.0000012664450315
+38.8	1.0000013220692348
+38.85	1.00000137240203
+38.9	1.0000014173443716
+38.95	1.0000014568097049
+39.0	1.0000014907239667
+39.05	1.0000015190255847
+39.1	1.0000015416654777
+39.15	1.0000015586070563
+39.2	1.0000015698262217
+39.25	1.0000015753113667
+39.3	1.0000015750633753
+39.35	1.0000015690956223
+39.4	1.0000015574339742
+39.45	1.0000015401167883
+39.5	1.0000015171949137
+39.55	1.00000148873169
+39.6	1.0000014548029486
+39.65	1.0000014154970114
+39.7	1.0000013709146924
+39.75	1.0000013211692962
+39.8	1.000001266386619
+39.85	1.0000012067049475
+39.9	1.0000011422750608
+39.95	1.0000010732602278
+40.0	1.0000009998362098
+40.05	1.000000922191259
+40.1	1.000000840526118
+40.15	1.0000007550540218
+40.2	1.000000666000696
+40.25	1.0000005736043571
+40.3	1.000000478115714
+40.35	1.000000379797965
+40.4	1.0000002789268014
+40.45	1.0000001757904047
+40.5	1.0000000706894476
+40.55	0.9999999639370944
+40.6	0.9999998558590008
+40.65	0.9999997467933128
+40.7	0.9999996370906686
+40.75	0.9999995271141972
+40.8	0.9999994172395185
+40.85	0.9999993078547441
+40.9	0.9999991993604767
+40.95	0.99999909216981
+41.0	0.9999989867083292
+41.05	0.9999988834141104
+41.1	0.9999987827377211
+41.15	0.9999986851422201
+41.2	0.9999985911031573
+41.25	0.9999985011085736
+41.3	0.9999984156590015
+41.35	0.9999983352674645
+41.4	0.9999982604594773
+41.45	0.9999981917730458
+41.5	0.9999981297586673
+41.55	0.9999980749793301
+41.6	0.9999980280105139
+41.65	0.9999979894401892
+41.7	0.9999979598688182
+41.75	0.9999979399093543
+41.8	0.9999979301872416
+41.85	0.9999979313404158
+41.9	0.9999979440193039
+41.95	0.9999979688868238
+42.0	0.999998006618385
+42.05	0.9999980574639027
+42.1	0.9999981166370969
+42.15	0.9999981821903582
+42.2	0.9999982535581772
+42.25	0.999998330189797
+42.3	0.9999984115492135
+42.35	0.9999984971151756
+42.4	0.9999985863811847
+42.45	0.9999986788554951
+42.5	0.9999987740611138
+42.55	0.9999988715358007
+42.6	0.9999989708320682
+42.65	0.9999990715171818
+42.7	0.9999991731731596
+42.75	0.9999992753967724
+42.8	0.9999993777995437
+42.85	0.9999994800077502
+42.9	0.9999995816624208
+42.95	0.9999996824193376
+43.0	0.9999997819490352
+43.05	0.999999879936801
+43.1	0.9999999760826754
+43.15	1.0000000701014513
+43.2	1.0000001617226746
+43.25	1.0000002506906436
+43.3	1.0000003367644097
+43.35	1.000000419717777
+43.4	1.0000004993393021
+43.45	1.000000575432295
+43.5	1.0000006478148178
+43.55	1.0000007163196856
+43.6	1.0000007807944662
+43.65	1.0000008411014807
+43.7	1.000000897117802
+43.75	1.0000009487352566
+43.8	1.0000009958604232
+43.85	1.000001038414634
+43.9	1.000001076333973
+43.95	1.0000011095692776
+44.0	1.0000011380861378
+44.05	1.0000011618648965
+44.1	1.0000011809006493
+44.15	1.0000011952032442
+44.2	1.0000012047972826
+44.25	1.000001209722118
+44.3	1.0000012100318574
+44.35	1.00000120579536
+44.4	1.0000011970962381
+44.45	1.0000011840328562
+44.5	1.0000011667183324
+44.55	1.000001145280537
+44.6	1.0000011198620933
+44.65	1.0000010906203771
+44.7	1.0000010577275171
+44.75	1.0000010213703951
+44.8	1.0000009817506452
+44.85	1.0000009390846543
+44.9	1.0000008936035625
+44.95	1.0000008455532623
+45.0	1.0000007951943988
+45.05	1.0000007428023703
+45.1	1.0000006886673278
+45.15	1.0000006330941746
+45.2	1.0000005764025675
+45.25	1.0000005189269152
+45.3	1.00000046101638
+45.35	1.0000004030348766
+45.4	1.0000003453610722
+45.45	1.0000002883883872
+45.5	1.0000002325249948
+45.55	1.0000001781938204
+45.6	1.0000001258325426
+45.65	1.000000075893593
+45.7	1.0000000288441553
+45.75	0.9999999851661665
+45.8	0.9999999453563162
+45.85	0.9999999099260467
+45.9	0.9999998794015532
+45.95	0.9999998543237836
+46.0	0.9999998352484385
+46.05	0.9999998216916034
+46.1	0.9999998105971786
+46.15	0.9999998016152832
+46.2	0.9999997946264213
+46.25	0.9999997895139021
+46.3	0.9999997861638408
+46.35	0.9999997844651585
+46.4	0.9999997843095817
+46.45	0.9999997855916432
+46.5	0.999999788208681
+46.55	0.9999997920608393
+46.6	0.999999797051068
+46.65	0.9999998030851226
+46.7	0.9999998100715645
+46.75	0.9999998179217608
+46.8	0.9999998265498845
+46.85	0.9999998358729142
+46.9	0.9999998458106344
+46.95	0.9999998562856354
+47.0	0.999999867223313
+47.05	0.9999998785518691
+47.1	0.9999998902023113
+47.15	0.9999999021084529
+47.2	0.9999999142069128
+47.25	0.9999999264371161
+47.3	0.9999999387412932
+47.35	0.9999999510644807
+47.4	0.9999999633545207
+47.45	0.9999999755620611
+47.5	0.9999999876405556
+47.55	0.9999999995462638
+47.6	1.000000011238251
+47.65	1.0000000226783878
+47.7	1.0000000338313517
+47.75	1.0000000446646247
+47.8	1.0000000551484953
+47.85	1.0000000652560577
+47.9	1.0000000749632119
+47.95	1.0000000842486632
+48.0	1.0000000930939232
+48.05	1.0000001014833093
+48.1	1.0000001094039441
+48.15	1.0000001168457566
+48.2	1.0000001238014815
+48.25	1.0000001302666586
+48.3	1.0000001362396342
+48.35	1.0000001417215603
+48.4	1.0000001467163944
+48.45	1.0000001512308996
+48.5	1.0000001552746454
+48.55	1.0000001588600065
+48.6	1.0000001620021635
+48.65	1.0000001647191032
+48.7	1.0000001670316176
+48.75	1.000000168963305
+48.8	1.0000001705405686
+48.85	1.0000001717926184
+48.9	1.0000001727514696
+48.95	1.0000001734519433
+49.0	1.0000001739316662
+49.05	1.0000001742310711
+49.1	1.0000001743933966
+49.15	1.0000001744646865
+49.2	1.0000001744937907
+49.25	1.0000001745323654
+49.3	1.0000001746348717
+49.35	1.000000174858577
+49.4	1.0000001752635541
+49.45	1.0000001759126822
+49.5	1.0000001768716456
+49.55	1.0000001782089347
+49.6	1.0000001799958456
+49.65	1.0000001823064804
+49.7	1.0000001852177465
+49.75	1.0000001888093573
+49.8	1.000000193163832
+49.85	1.0000001983664961
+49.9	1.0000002045054797
+49.95	1.0000002116717195
+50.0	1.000000219958958
diff --git a/test/julia_dde/test_basic_check_10.txt b/test/julia_dde/test_basic_check_10.txt
new file mode 100644
index 00000000..0e15e27c
--- /dev/null
+++ b/test/julia_dde/test_basic_check_10.txt
@@ -0,0 +1,401 @@
+0.0	1.2
+0.1	1.1762383903532672
+0.2	1.1529472850205345
+0.3	1.1301172850650132
+0.4	1.1077393734477934
+0.5	1.0858044886142186
+0.6	1.0643037416988483
+0.7	1.0432288615243335
+0.8	1.022571356528724
+0.9	1.0023223752836392
+1.0	0.9824742404177691
+1.1	0.963019381199433
+1.2	0.9439501263044995
+1.3	0.9252577408761488
+1.4	0.9069349374873376
+1.5	0.8889749354496113
+1.6	0.8713709540745148
+1.7	0.8541162126735939
+1.8	0.8372028671950351
+1.9	0.8206232664837655
+2.0	0.8043714637653621
+2.1	0.7884415190806816
+2.2	0.7728274924705816
+2.3	0.757523443975919
+2.4	0.7425233077800755
+2.5	0.7278203057112125
+2.6	0.7137323858717671
+2.7	0.7025150957051455
+2.8	0.694484707545044
+2.9	0.6894972844620731
+3.0	0.6874088895268431
+3.1	0.6880755858099645
+3.2	0.6913925755975374
+3.3	0.6973222618621986
+3.4	0.7058067257179844
+3.5	0.7168115471678291
+3.6	0.7303370069135824
+3.7	0.7463846145141237
+3.8	0.7649558795283328
+3.9	0.7860623562038731
+4.0	0.8095410256788728
+4.1	0.8348171340036061
+4.2	0.8612982116589256
+4.3	0.8884619443461924
+4.4	0.9158450419219439
+4.5	0.9430276803629207
+4.6	0.9696472893588215
+4.7	0.9954055016902975
+4.8	1.02006858215819
+4.9	1.0434549257284909
+5.0	1.0654985042705631
+5.1	1.0861544105253624
+5.2	1.1054892732450592
+5.3	1.1236480400185191
+5.4	1.1407769852205385
+5.5	1.1570306028256245
+5.6	1.1725947516443396
+5.7	1.1876404586150953
+5.8	1.2023387498561318
+5.9	1.2168526222926397
+6.0	1.2312873402797389
+6.1	1.2457399798526765
+6.2	1.260310280205678
+6.3	1.2750676414654825
+6.4	1.2899834074452599
+6.5	1.3050178098521605
+6.6	1.3201310803933353
+6.7	1.3352945550759119
+6.8	1.350350773348611
+6.9	1.3650579417149595
+7.0	1.3791742173697994
+7.1	1.392457757507972
+7.2	1.4046657660937962
+7.3	1.4154832030220348
+7.4	1.4245446545864353
+7.5	1.4314992491448382
+7.6	1.4359911635746783
+7.7	1.4376417983266423
+7.8	1.43625424520483
+7.9	1.4316782724704902
+8.0	1.4237636483848721
+8.1	1.4124176896707952
+8.2	1.3977578549156462
+8.3	1.3799340478096036
+8.4	1.3590957461794575
+8.5	1.3353924278519964
+8.6	1.3089802658668372
+8.7	1.2799746596859716
+8.8	1.2484500523482862
+8.9	1.2144805927509215
+9.0	1.1781404107123883
+9.1	1.1394589183846158
+9.2	1.0988328136288612
+9.3	1.0568777052684357
+9.4	1.014209202126653
+9.5	0.9715695193343649
+9.6	0.930032296455194
+9.7	0.8906351599134852
+9.8	0.854409165123903
+9.9	0.8220512297459162
+10.0	0.794002306444414
+10.1	0.770649736832742
+10.2	0.7521142528390301
+10.3	0.738432692185377
+10.4	0.7296563251431871
+10.5	0.7257749973000727
+10.6	0.7264364129282308
+10.7	0.731245671383987
+10.8	0.7398036635388009
+10.9	0.7516141391602215
+11.0	0.7661804613272529
+11.1	0.7830273114885461
+11.2	0.8017662055697486
+11.3	0.8221064904542812
+11.4	0.8437618854277524
+11.5	0.8665189094969064
+11.6	0.890256414798055
+11.7	0.9148586032887032
+11.8	0.940235001645523
+11.9	0.9663699334619459
+12.0	0.993260294587582
+12.1	1.0209029808720407
+12.2	1.0493020447199954
+12.3	1.0784799284506588
+12.4	1.1084608906984939
+12.5	1.1392691895837797
+12.6	1.170929083226794
+12.7	1.2034648297478152
+12.8	1.2369006872671222
+12.9	1.2712609139049917
+13.0	1.306569767781703
+13.1	1.3428558472689696
+13.2	1.3801674068749505
+13.3	1.418434768431265
+13.4	1.457559911388622
+13.5	1.4974448151977326
+13.6	1.5379914593093074
+13.7	1.5791068886029291
+13.8	1.620441928950408
+13.9	1.661439393998374
+14.0	1.701539941267787
+14.1	1.7401842282796052
+14.2	1.7767153880380508
+14.3	1.8102339971667682
+14.4	1.8398266864159183
+14.5	1.8645794114950862
+14.6	1.8835528394067071
+14.7	1.89585305747116
+14.8	1.9006113130381221
+14.9	1.8969562831461415
+15.0	1.8841144991168566
+15.1	1.8615707087722606
+15.2	1.8288415838383896
+15.3	1.7854422595678603
+15.4	1.7313654916031211
+15.5	1.6671129005732073
+15.6	1.5932016378115792
+15.7	1.5102209083685956
+15.8	1.4197593505303363
+15.9	1.3241170196184957
+16.0	1.2256975472498846
+16.1	1.12751003048568
+16.2	1.0326262210874273
+16.3	0.9439184319210752
+16.4	0.8634363752312201
+16.5	0.792881976065694
+16.6	0.7329295576461717
+16.7	0.6837732515298223
+16.8	0.6452103507276176
+16.9	0.616424326376609
+17.0	0.5965433416856748
+17.1	0.5845934321642667
+17.2	0.5794777718697854
+17.3	0.5801388198005561
+17.4	0.5856713201165599
+17.5	0.5952318062141689
+17.6	0.6081861042382376
+17.7	0.6239883435268924
+17.8	0.6422245660405075
+17.9	0.6625875134112515
+18.0	0.6847998907124471
+18.1	0.7086945321200175
+18.2	0.7341302256701652
+18.3	0.7609816885603836
+18.4	0.7891865076296669
+18.5	0.8186944899168314
+18.6	0.84944806970619
+18.7	0.881418302111781
+18.8	0.9146354592666212
+18.9	0.9491319288165144
+19.0	0.9849388298458869
+19.1	1.0220756401808062
+19.2	1.060697143076142
+19.3	1.101004243000086
+19.4	1.1431978444208233
+19.5	1.1874788518065464
+19.6	1.2340892473253846
+19.7	1.283306461498362
+19.8	1.3353667307271526
+19.9	1.390506963677426
+20.0	1.4490221304592847
+20.1	1.5108272412683106
+20.2	1.5756498889935524
+20.3	1.6432176665240672
+20.4	1.7132416554555883
+20.5	1.7850171819600342
+20.6	1.8574028668512483
+20.7	1.9292366271068422
+20.8	1.9993208745275475
+20.9	2.0660940346468717
+21.0	2.1277942682708333
+21.1	2.18265314305678
+21.2	2.2288070409036846
+21.3	2.264357181243495
+21.4	2.2874131605429233
+21.5	2.296180291918321
+21.6	2.289083976539992
+21.7	2.264579750279618
+21.8	2.221242053015983
+21.9	2.1581384971851687
+22.0	2.07444797128032
+22.1	1.9696986400326784
+22.2	1.8459166858706375
+22.3	1.7060336169944468
+22.4	1.5545508647268924
+22.5	1.3981481091710037
+22.6	1.243669657943292
+22.7	1.0974201713894012
+22.8	0.964549182681927
+22.9	0.8480951425988938
+23.0	0.7495259504671893
+23.1	0.6687072156759692
+23.2	0.6047275358053016
+23.3	0.5557981694927575
+23.4	0.520032785124636
+23.5	0.4953215039502459
+23.6	0.4796601630143706
+23.7	0.47130329126658715
+23.8	0.4687770419922361
+23.9	0.47091064222701956
+24.0	0.4768103187855602
+24.1	0.48575861705853646
+24.2	0.4972403844167946
+24.3	0.5108326949708443
+24.4	0.5262293106651805
+24.5	0.543175416289768
+24.6	0.5614679838095132
+24.7	0.5809580899799074
+24.8	0.6015015508328674
+24.9	0.623007587183179
+25.0	0.6454124949179368
+25.1	0.6686432982995325
+25.2	0.6927052230934189
+25.3	0.7176774326690653
+25.4	0.7436396766404005
+25.5	0.7706835033848395
+25.6	0.7990581815613631
+25.7	0.8290855970006543
+25.8	0.861087863650821
+25.9	0.8954400914658145
+26.0	0.9326041690150388
+26.1	0.9730441645817506
+26.2	1.0172399395188414
+26.3	1.065569839590821
+26.4	1.1183161295913722
+26.5	1.1757632224307901
+26.6	1.2379911151210596
+26.7	1.3047453154097126
+26.8	1.3757479168945408
+26.9	1.4506832758146886
+27.0	1.5288462263691007
+27.1	1.6093018750635322
+27.2	1.691112364236993
+27.3	1.773266150378219
+27.4	1.85440674901966
+27.5	1.9330975695504788
+27.6	2.007902903975058
+27.7	2.0772538147725146
+27.8	2.1394285584834822
+27.9	2.1927291079959237
+28.0	2.2352747709368015
+28.1	2.2643775373143513
+28.2	2.2772046606659635
+28.3	2.270371666387715
+28.4	2.240382229032466
+28.5	2.183817013271277
+28.6	2.0996186194131172
+28.7	1.9889581713834459
+28.8	1.8548677487477967
+28.9	1.703957917620767
+29.0	1.5435292751040297
+29.1	1.3815972010655604
+29.2	1.2254114531926712
+29.3	1.0805567268892087
+29.4	0.9511333493643609
+29.5	0.8394367089280085
+29.6	0.7461691629316033
+29.7	0.6706201995489124
+29.8	0.6110545713997341
+29.9	0.5651135619273114
+30.0	0.5304043138475901
+30.1	0.5046901164257204
+30.2	0.4860516734132419
+30.3	0.4729451664825794
+30.4	0.4641226497085326
+30.5	0.45860609382817763
+30.6	0.45562999174204727
+30.7	0.4545785378505115
+30.8	0.4549700104034184
+30.9	0.4564136674254572
+31.0	0.45859190947887885
+31.1	0.4612647227777334
+31.2	0.4642195168290886
+31.3	0.46733578319192937
+31.4	0.47051085348953503
+31.5	0.47369653104529064
+31.6	0.47694780188259134
+31.7	0.48032737920880336
+31.8	0.48393406020116203
+31.9	0.4879958270293249
+32.0	0.4927648310304586
+32.1	0.49850016915329765
+32.2	0.5055299771348553
+32.3	0.514212751113701
+32.4	0.5249162352984859
+32.5	0.5380007457734904
+32.6	0.5537808708532589
+32.7	0.5725699252989864
+32.8	0.5945802471045728
+32.9	0.6199204504554963
+33.0	0.6486843005529993
+33.1	0.6808362961459824
+33.2	0.7162623863958516
+33.3	0.754847425572777
+33.4	0.7964517347646926
+33.5	0.8408642326447557
+33.6	0.8879545098743513
+33.7	0.9376010447031385
+33.8	0.9896758291024246
+33.9	1.0441143467687217
+34.0	1.100934501377058
+34.1	1.1601563039320102
+34.2	1.2218364273501483
+34.3	1.2860192060615765
+34.4	1.3518544837209678
+34.5	1.4183399818551776
+34.6	1.484367994942448
+34.7	1.5477486535655423
+34.8	1.6057626950609103
+34.9	1.6555401367937403
+35.0	1.6943490739387572
+35.1	1.7196126570359898
+35.2	1.7296043493154853
+35.3	1.7230350377951533
+35.4	1.6994149752243763
+35.5	1.6592004165101453
+35.6	1.6031180156382323
+35.7	1.5336010639481854
+35.8	1.453880224829799
+35.9	1.3677145954776975
+36.0	1.2791702291355362
+36.1	1.1918999909843286
+36.2	1.1084209775689737
+36.3	1.0306049336493708
+36.4	0.9591511775580596
+36.5	0.894362286039932
+36.6	0.8358741307917384
+36.7	0.7831805232808203
+36.8	0.7356600549113779
+36.9	0.692610171844157
+37.0	0.6533777163294723
+37.1	0.6173497423372787
+37.2	0.5840517122093792
+37.3	0.5530311659487448
+37.4	0.5239646180249706
+37.5	0.4966071041378533
+37.6	0.47073320198676105
+37.7	0.446230433501325
+37.8	0.4230170679039287
+37.9	0.40101471390406984
+38.0	0.38021124302352716
+38.1	0.36062021933329635
+38.2	0.3422537314080519
+38.3	0.3251123738494721
+38.4	0.3092450084104864
+38.5	0.2947145342214175
+38.6	0.2815838504125872
+38.7	0.26991585611431695
+38.8	0.2597754033409548
+38.9	0.2512283964460365
+39.0	0.24433732917493914
+39.1	0.23916679234203433
+39.2	0.23578638032699598
+39.3	0.23420246767962388
+39.4	0.2344077213818149
+39.5	0.23639464632285906
+39.6	0.24013070927175628
+39.7	0.24561478245105658
+39.8	0.25287467987631057
+39.9	0.2619382155630708
+40.0	0.272833203526888
diff --git a/test/julia_dde/test_basic_check_11.txt b/test/julia_dde/test_basic_check_11.txt
new file mode 100644
index 00000000..c1948729
--- /dev/null
+++ b/test/julia_dde/test_basic_check_11.txt
@@ -0,0 +1,351 @@
+2.0	0.5
+2.01	0.5049999999999999
+2.02	0.51
+2.03	0.5149999999999999
+2.04	0.52
+2.05	0.5249999999999999
+2.06	0.53
+2.07	0.5349999999999999
+2.08	0.54
+2.09	0.5449999999999999
+2.1	0.55
+2.11	0.5549999999999999
+2.12	0.56
+2.13	0.565
+2.14	0.5700000000000001
+2.15	0.575
+2.16	0.5800000000000001
+2.17	0.585
+2.18	0.5900000000000001
+2.19	0.595
+2.2	0.6000000000000001
+2.21	0.605
+2.22	0.6100000000000001
+2.23	0.615
+2.24	0.6200000000000001
+2.25	0.625
+2.26	0.6299999999999999
+2.27	0.635
+2.28	0.6399999999999999
+2.29	0.645
+2.3	0.6499999999999999
+2.31	0.655
+2.32	0.6599999999999999
+2.33	0.665
+2.34	0.6699999999999999
+2.35	0.675
+2.36	0.6799999999999999
+2.37	0.685
+2.38	0.69
+2.39	0.6950000000000001
+2.4	0.7
+2.41	0.7050000000000001
+2.42	0.71
+2.43	0.7150000000000001
+2.44	0.72
+2.45	0.7250000000000001
+2.46	0.73
+2.47	0.7350000000000001
+2.48	0.74
+2.49	0.7450000000000001
+2.5	0.75
+2.51	0.7549999999999999
+2.52	0.76
+2.53	0.7649999999999999
+2.54	0.77
+2.55	0.7749999999999999
+2.56	0.78
+2.57	0.7849999999999999
+2.58	0.79
+2.59	0.7949999999999999
+2.6	0.8
+2.61	0.8049999999999999
+2.62	0.81
+2.63	0.815
+2.64	0.8200000000000001
+2.65	0.825
+2.66	0.8300000000000001
+2.67	0.835
+2.68	0.8400000000000001
+2.69	0.845
+2.7	0.8500000000000001
+2.71	0.855
+2.72	0.8600000000000001
+2.73	0.865
+2.74	0.8700000000000001
+2.75	0.875
+2.76	0.8799999999999999
+2.77	0.885
+2.78	0.8899999999999999
+2.79	0.895
+2.8	0.8999999999999999
+2.81	0.905
+2.82	0.9099999999999999
+2.83	0.915
+2.84	0.9199999999999999
+2.85	0.925
+2.86	0.9299999999999999
+2.87	0.935
+2.88	0.94
+2.89	0.9450000000000001
+2.9	0.95
+2.91	0.9550000000000001
+2.92	0.96
+2.93	0.9650000000000001
+2.94	0.97
+2.95	0.9750000000000001
+2.96	0.98
+2.97	0.9850000000000001
+2.98	0.99
+2.99	0.9950000000000001
+3.0	1.0
+3.01	1.005
+3.02	1.01
+3.03	1.015
+3.04	1.02
+3.05	1.025
+3.06	1.03
+3.07	1.035
+3.08	1.04
+3.09	1.045
+3.1	1.05
+3.11	1.055
+3.12	1.06
+3.13	1.065
+3.14	1.07
+3.15	1.075
+3.16	1.08
+3.17	1.085
+3.18	1.09
+3.19	1.095
+3.2	1.1
+3.21	1.105
+3.22	1.11
+3.23	1.115
+3.24	1.12
+3.25	1.125
+3.26	1.13
+3.27	1.135
+3.28	1.14
+3.29	1.145
+3.3	1.15
+3.31	1.155
+3.32	1.16
+3.33	1.165
+3.34	1.17
+3.35	1.175
+3.36	1.18
+3.37	1.185
+3.38	1.19
+3.39	1.195
+3.4	1.2
+3.41	1.205
+3.42	1.21
+3.43	1.215
+3.44	1.22
+3.45	1.225
+3.46	1.23
+3.47	1.235
+3.48	1.24
+3.49	1.245
+3.5	1.25
+3.51	1.255
+3.52	1.26
+3.53	1.265
+3.54	1.27
+3.55	1.275
+3.56	1.28
+3.57	1.285
+3.58	1.29
+3.59	1.2950000000000002
+3.6	1.3000000000000003
+3.61	1.3050000000000002
+3.62	1.3100000000000003
+3.63	1.3150000000000002
+3.64	1.3200000000000003
+3.65	1.3250000000000002
+3.66	1.3300000000000003
+3.67	1.3350000000000002
+3.68	1.3400000000000003
+3.69	1.3450000000000002
+3.7	1.3500000000000003
+3.71	1.3550000000000002
+3.72	1.3600000000000003
+3.73	1.3650000000000002
+3.74	1.3700000000000003
+3.75	1.3750000000000002
+3.76	1.3800000000000001
+3.77	1.3850000000000002
+3.78	1.3900000000000001
+3.79	1.3950000000000002
+3.8	1.4000000000000001
+3.81	1.4050000000000002
+3.82	1.4100000000000001
+3.83	1.4150000000000003
+3.84	1.4200000000000002
+3.85	1.4250000000000003
+3.86	1.4300000000000002
+3.87	1.4350000000000003
+3.88	1.4400000000000002
+3.89	1.4450000000000003
+3.9	1.4500000000000002
+3.91	1.4550000000000003
+3.92	1.4600000000000002
+3.93	1.4650000000000003
+3.94	1.4700000000000002
+3.95	1.4750000000000003
+3.96	1.4800000000000002
+3.97	1.4850000000000003
+3.98	1.4900000000000002
+3.99	1.4950000000000003
+4.0	1.5000000000000002
+4.01	1.5050000000000001
+4.02	1.51
+4.03	1.5150000000000003
+4.04	1.5199982174807718
+4.05	1.5249893583619993
+4.06	1.5299732881826311
+4.07	1.5349502477752208
+4.08	1.539920477972321
+4.09	1.5448842196064865
+4.1	1.5498417135102704
+4.11	1.5547932005162266
+4.12	1.5597389214569082
+4.13	1.5646791171648688
+4.14	1.5696140284726625
+4.15	1.574543896212843
+4.16	1.5794689612179629
+4.17	1.5843894643205765
+4.18	1.5893056463532376
+4.19	1.5942177481484996
+4.2	1.5991260105389156
+4.21	1.6040306743570396
+4.22	1.6089319804354254
+4.23	1.6138301696066268
+4.24	1.6187254827031965
+4.25	1.6236181605576885
+4.26	1.6285084440026567
+4.27	1.633396573870654
+4.28	1.6382827909942352
+4.29	1.643167336205953
+4.3	1.6480504503383608
+4.31	1.6529323742240127
+4.32	1.6578133486954625
+4.33	1.662693614585263
+4.34	1.6675734127259685
+4.35	1.6724529839501319
+4.36	1.6773325690903078
+4.37	1.682212408979049
+4.38	1.6870927444489094
+4.39	1.6919738163324425
+4.4	1.696855865462202
+4.41	1.7017391326707412
+4.42	1.706623858790614
+4.43	1.7115102846543737
+4.44	1.7163986510945748
+4.45	1.7212891989437695
+4.46	1.726182169034512
+4.47	1.7310778021993563
+4.48	1.735976339270856
+4.49	1.7408780210815642
+4.5	1.7457830884640342
+4.51	1.7506917822508201
+4.52	1.7556043432744757
+4.53	1.760521012367555
+4.54	1.7654420303626102
+4.55	1.7703676380921956
+4.56	1.775298076388865
+4.57	1.7802335860851721
+4.58	1.7851744080136702
+4.59	1.7901207830069126
+4.6	1.795072951897453
+4.61	1.8000311555178463
+4.62	1.8049956347006442
+4.63	1.8099666302784012
+4.64	1.8149443830836707
+4.65	1.8199291339490067
+4.66	1.8249211237069622
+4.67	1.8299205931900915
+4.68	1.8349277832309474
+4.69	1.8399429346620844
+4.7	1.8449662883160551
+4.71	1.8499980850254136
+4.72	1.8550385656227135
+4.73	1.8600879709405087
+4.74	1.8651465418113524
+4.75	1.870214519067798
+4.76	1.8752921435423993
+4.77	1.88037965606771
+4.78	1.8854772974762841
+4.79	1.8905853086006743
+4.8	1.8957039302734346
+4.81	1.9008334033271188
+4.82	1.9059739685942807
+4.83	1.911125866907473
+4.84	1.91628933909925
+4.85	1.9214646260021648
+4.86	1.926651968448772
+4.87	1.9318516072716239
+4.88	1.9370637833032749
+4.89	1.9422887373762783
+4.9	1.9475267103231884
+4.91	1.9527779429765575
+4.92	1.95804267616894
+4.93	1.9633211507328892
+4.94	1.9686136075009595
+4.95	1.9739202873057033
+4.96	1.9792414309796749
+4.97	1.9845772793554277
+4.98	1.989928073265516
+4.99	1.9952940535424921
+5.0	2.000675461018911
+5.01	2.0060725365273244
+5.02	2.0114855209002878
+5.03	2.016914654970354
+5.04	2.0223601795700765
+5.05	2.027822335532009
+5.06	2.0333013636887047
+5.07	2.0387975048727185
+5.08	2.0443109999166023
+5.09	2.049842089652911
+5.1	2.055391014914197
+5.11	2.0609580165330157
+5.12	2.0665433353419194
+5.13	2.0721472121734617
+5.14	2.0777698878601956
+5.15	2.083411603234677
+5.16	2.089072599129457
+5.17	2.09475311637709
+5.18	2.1004533958101304
+5.19	2.106173678261131
+5.2	2.111914204562645
+5.21	2.117675215547227
+5.22	2.1234569520474302
+5.23	2.1292596548958085
+5.24	2.1350835649249142
+5.25	2.1409289229673023
+5.26	2.146795969855526
+5.27	2.1526849464221387
+5.28	2.1585960934996953
+5.29	2.164529651920746
+5.3	2.1704858625178485
+5.31	2.176464966123554
+5.32	2.1824672035704165
+5.33	2.1884928156909895
+5.34	2.194542043317827
+5.35	2.200615127283482
+5.36	2.206712308420509
+5.37	2.212833827561461
+5.38	2.2189799255388913
+5.39	2.225150843185354
+5.4	2.2313468213334033
+5.41	2.2375681008155914
+5.42	2.2438149224644723
+5.43	2.2500875271126004
+5.44	2.256386155592529
+5.45	2.2627110487368114
+5.46	2.2690624473780003
+5.47	2.275440592348651
+5.48	2.2818457244813164
+5.49	2.28827808460855
+5.5	2.294737913562905
diff --git a/test/julia_dde/test_basic_check_2.txt b/test/julia_dde/test_basic_check_2.txt
new file mode 100644
index 00000000..66a0924a
--- /dev/null
+++ b/test/julia_dde/test_basic_check_2.txt
@@ -0,0 +1,1001 @@
+0.0	1.2
+0.05	1.1880597992376376
+0.1	1.1762383903532672
+0.15	1.1645346054176793
+0.2	1.1529472850205345
+0.25	1.1414752222367555
+0.3	1.1301172850650132
+0.35	1.1188723699778464
+0.4	1.1077393734477934
+0.45	1.0967171444550854
+0.5	1.0858044886142186
+0.55	1.0750003653428541
+0.6	1.0643037416988483
+0.65	1.0537135847400563
+0.7	1.0432288615243335
+0.75	1.0328485303971229
+0.8	1.022571356528724
+0.85	1.0123962869380496
+0.9	1.0023223752836392
+0.95	0.9923486752240321
+1.0	0.9824742404177691
+1.05	0.9726981245233894
+1.1	0.963019381199433
+1.15	0.9534370641044397
+1.2	0.9439501263044995
+1.25	0.9345573120423534
+1.3	0.9252577408761488
+1.35	0.9160505652198292
+1.4	0.9069349374873376
+1.45	0.8979100100926173
+1.5	0.8889749354496113
+1.55	0.8801288659722627
+1.6	0.8713709540745148
+1.65	0.8627003521703108
+1.7	0.8541162126735939
+1.75	0.8456175814341691
+1.8	0.8372034571683109
+1.85	0.8288730339155551
+1.9	0.8206255087352973
+1.95	0.8124600786869333
+2.0	0.8043759408298582
+2.05	0.7966049878719209
+2.1	0.7893708372331322
+2.15	0.7826602730411311
+2.2	0.7764600794235568
+2.25	0.7707570405080486
+2.3	0.7655379404222455
+2.35	0.7607895632937867
+2.4	0.7564986932503112
+2.45	0.7526521144194582
+2.5	0.7492366109288668
+2.55	0.7462389669061763
+2.6	0.7436465940570072
+2.65	0.741455112858613
+2.7	0.7396597834399526
+2.75	0.7382545843876566
+2.8	0.7372334942883548
+2.85	0.7365904917286781
+2.9	0.7363191175378985
+2.95	0.7364135450318054
+3.0	0.7368722358337684
+3.05	0.7376939906458355
+3.1	0.7388776101700548
+3.15	0.7404218951084749
+3.2	0.742325646163144
+3.25	0.7445876640361099
+3.3	0.7472067494294214
+3.35	0.7501817030451264
+3.4	0.7535113255852732
+3.45	0.7571944509563947
+3.5	0.7612334422829883
+3.55	0.7656327089678326
+3.6	0.7703966636321324
+3.65	0.7755297188970925
+3.7	0.7810362873839177
+3.75	0.7869207817138124
+3.8	0.7931876145079815
+3.85	0.7998411983876295
+3.9	0.8068859459739612
+3.95	0.8143262698881814
+4.0	0.8221665827514945
+4.05	0.8304103241453805
+4.1	0.8390503787717005
+4.15	0.8480753269322309
+4.2	0.8574737489287485
+4.25	0.8672342250630293
+4.3	0.8773453356368505
+4.35	0.8877956609519881
+4.4	0.8985737813102191
+4.45	0.9096682770133198
+4.5	0.9210677283630667
+4.55	0.9327613644666946
+4.6	0.9447382041491337
+4.65	0.9569838960216988
+4.7	0.9694839130856466
+4.75	0.9822237283422343
+4.8	0.9951888147927191
+4.85	1.008364645438358
+4.9	1.0217366932804082
+4.95	1.0352904313201265
+5.0	1.04901133255877
+5.05	1.0628853355824783
+5.1	1.0768949374628292
+5.15	1.0910168480680784
+5.2	1.1052275566412175
+5.25	1.1195035524252361
+5.3	1.1338213246631248
+5.35	1.1481573625978743
+5.4	1.1624881554724753
+5.45	1.1767901925299176
+5.5	1.191039963013192
+5.55	1.2052140962437965
+5.6	1.2192851717972115
+5.65	1.2332183443340552
+5.7	1.246978228490729
+5.75	1.2605294389036346
+5.8	1.2738365902091744
+5.85	1.286864297043751
+5.9	1.299577174043766
+5.95	1.3119398358456216
+6.0	1.3239168663975618
+6.05	1.3354681693230959
+6.1	1.3465516997729066
+6.15	1.3571276627017965
+6.2	1.3671568094140312
+6.25	1.376600463027355
+6.3	1.3854207406401668
+6.35	1.3935812044918732
+6.4	1.4010455467051504
+6.45	1.4077795499344365
+6.5	1.4137513166849074
+6.55	1.41892900858906
+6.6	1.4232812704960902
+6.65	1.4267827653222673
+6.7	1.429409968893713
+6.75	1.431139357036549
+6.8	1.4319472554482031
+6.85	1.4318171740583696
+6.9	1.4307390762174144
+6.95	1.428703024470033
+7.0	1.42569908136092
+7.05	1.4217165437466892
+7.1	1.4167471473666722
+7.15	1.410800513227501
+7.2	1.4038886531213686
+7.25	1.3960235788404673
+7.3	1.3872173021769896
+7.35	1.3774818349231273
+7.4	1.3668292087293314
+7.45	1.3552778263113545
+7.5	1.3428615353140843
+7.55	1.3296164771187597
+7.6	1.3155787931066185
+7.65	1.3007846246588997
+7.7	1.2852701131568416
+7.75	1.269071399981683
+7.8	1.252228812462711
+7.85	1.234795626158659
+7.9	1.216822773260856
+7.95	1.1983609437575724
+8.0	1.1794608276370777
+8.05	1.1601753466433684
+8.1	1.1405600662416302
+8.15	1.1206696420019118
+8.2	1.1005587294942656
+8.25	1.0802819842887406
+8.3	1.0598940619553883
+8.35	1.0394506878014353
+8.4	1.0190050205908718
+8.45	0.9986058410438103
+8.5	0.9783018719261123
+8.55	0.9581418360036411
+8.6	0.9381744916273773
+8.65	0.918447522241859
+8.7	0.8989988519527116
+8.75	0.8798636601114229
+8.8	0.8610771260694831
+8.85	0.8426744291783823
+8.9	0.8246907487896087
+8.95	0.8071595445403982
+9.0	0.790103625750701
+9.05	0.7735417306334335
+9.1	0.7574925892616298
+9.15	0.741974931708322
+9.2	0.7270074880465447
+9.25	0.71260890886523
+9.3	0.6987925775459839
+9.35	0.6855636953491089
+9.4	0.6729267876146086
+9.45	0.6608863796824879
+9.5	0.6494469968927505
+9.55	0.6386131645854012
+9.6	0.6283894081004446
+9.65	0.6187799376548058
+9.7	0.6097843760508866
+9.75	0.6013989615583346
+9.8	0.5936198578099892
+9.85	0.5864432284386905
+9.9	0.5798652370772771
+9.95	0.573882047358589
+10.0	0.5684898229154657
+10.05	0.5636847273807462
+10.1	0.5594629243872706
+10.15	0.5558207587612087
+10.2	0.5527556940533047
+10.25	0.5502643788856739
+10.3	0.5483432549029339
+10.35	0.5469887637497023
+10.4	0.5461973470705961
+10.45	0.5459654465102328
+10.5	0.5462895037132297
+10.55	0.5471659603242042
+10.6	0.5485912579877733
+10.65	0.550562243050093
+10.7	0.553079622010286
+10.75	0.5561444560988579
+10.8	0.559757507016326
+10.85	0.5639195364632071
+10.9	0.5686313061400197
+10.95	0.57389357774728
+11.0	0.5797071129855064
+11.05	0.5860726735552156
+11.1	0.5929910211569253
+11.15	0.6004629174911527
+11.2	0.6084891922610574
+11.25	0.6170718662445521
+11.3	0.6262137967331077
+11.35	0.6359178351663541
+11.4	0.6461868329839221
+11.45	0.6570236416254412
+11.5	0.6684311125305419
+11.55	0.6804120971388541
+11.6	0.6929694468900075
+11.65	0.7061060132236332
+11.7	0.7198246475793599
+11.75	0.7341282013968191
+11.8	0.7490195261156405
+11.85	0.7645014731754531
+11.9	0.7805768940158886
+11.95	0.7972486400765756
+12.0	0.8145195627971458
+12.05	0.8323908015688225
+12.1	0.8508545340275291
+12.15	0.8699000842401812
+12.2	0.8895167776464826
+12.25	0.9096939396861401
+12.3	0.9304208957988584
+12.35	0.9516869714243422
+12.4	0.9734814920022978
+12.45	0.9957937829724295
+12.5	1.0186131697744438
+12.55	1.0419289778480454
+12.6	1.0657305326329392
+12.65	1.0900066173006855
+12.7	1.1147321345325476
+12.75	1.1398673063212037
+12.8	1.1653716523320719
+12.85	1.1912046922305697
+12.9	1.2173259456821182
+12.95	1.2436949323521345
+13.0	1.2702711719060389
+13.05	1.2970141840092497
+13.1	1.3238834883271844
+13.15	1.3508386045252643
+13.2	1.3778325503119393
+13.25	1.4047954094509787
+13.3	1.4316525203717476
+13.35	1.4583292228383717
+13.4	1.4847508566149794
+13.45	1.5108427614656945
+13.5	1.5365302771546454
+13.55	1.561738743445957
+13.6	1.5863935001037555
+13.65	1.610414202912512
+13.7	1.6337068375202188
+13.75	1.6561767438550645
+13.8	1.6777292647126993
+13.85	1.6982697428887752
+13.9	1.7177035211789442
+13.95	1.7359359423788572
+14.0	1.7528723492841671
+14.05	1.7684149463881755
+14.1	1.7824618980487044
+14.15	1.7949338999216151
+14.2	1.8057539943026195
+14.25	1.814845223487432
+14.3	1.8221306297717643
+14.35	1.82753325545133
+14.4	1.8309761428218418
+14.45	1.8323794069570287
+14.5	1.8316821288780631
+14.55	1.8288675527016813
+14.6	1.823921596703502
+14.65	1.8168301791591426
+14.7	1.8075792183442216
+14.75	1.7961546325343571
+14.8	1.782542340005167
+14.85	1.766728773141058
+14.9	1.7487367633658564
+14.95	1.7286453731881093
+15.0	1.7065382792707242
+15.05	1.6824991582766124
+15.1	1.6566116868686838
+15.15	1.6289595417098466
+15.2	1.5996263994630127
+15.25	1.568695936791089
+15.3	1.536253349782111
+15.35	1.5024473904652802
+15.4	1.4674397105285182
+15.45	1.4313735875857403
+15.5	1.3943922992508566
+15.55	1.3566391231377808
+15.6	1.3182573368604271
+15.65	1.279400575486758
+15.7	1.2402295487275168
+15.75	1.200883163043972
+15.8	1.1614997797914177
+15.85	1.1222177603251473
+15.9	1.0831754660004496
+15.95	1.0445113833795212
+16.0	1.0063533457305782
+16.05	0.9688002204454318
+16.1	0.9319467464640154
+16.15	0.8958876627262662
+16.2	0.8607177081721128
+16.25	0.8265304313107716
+16.3	0.7933861546212205
+16.35	0.7613223598864833
+16.4	0.7303785371044332
+16.45	0.7005941762729371
+16.5	0.6720087673898681
+16.55	0.6446617413965466
+16.6	0.6185608877455635
+16.65	0.5936873817798612
+16.7	0.5700343087160399
+16.75	0.5475947537707044
+16.8	0.5263618021604588
+16.85	0.506328539101906
+16.9	0.4874878506593186
+16.95	0.4698031885394685
+17.0	0.4532304550269939
+17.05	0.4377398368134014
+17.1	0.42330152059019815
+17.15	0.4098856930488919
+17.2	0.3974625408809876
+17.25	0.38600225077799316
+17.3	0.3754750094314155
+17.35	0.3658472166938709
+17.4	0.3570792892826172
+17.45	0.3491421915672284
+17.5	0.34200751387452194
+17.55	0.3356468465313144
+17.6	0.3300317798644222
+17.65	0.32513517181013774
+17.7	0.320936382123747
+17.75	0.3174167617663639
+17.8	0.31455766170401606
+17.85	0.3123404329027309
+17.9	0.31074642632853605
+17.95	0.30975744733860155
+18.0	0.3093624455811279
+18.05	0.30955427514298167
+18.1	0.31032556501786496
+18.15	0.3116689441994796
+18.2	0.31357704168152767
+18.25	0.3160424864577112
+18.3	0.31905885102136716
+18.35	0.32262653815988585
+18.4	0.3267487371484904
+18.45	0.331428641219179
+18.5	0.33666944360394907
+18.55	0.3424743375347985
+18.6	0.34884651624372476
+18.65	0.35578917296272516
+18.7	0.3633037997046326
+18.75	0.37139594954858146
+18.8	0.38007859693560425
+18.85	0.38936483532409294
+18.9	0.3992677581724386
+18.95	0.40980045893903455
+19.0	0.4209760310822719
+19.05	0.43280756806054255
+19.1	0.4453081633322383
+19.15	0.4584905563115862
+19.2	0.4723626480734233
+19.25	0.48693998242042613
+19.3	0.5022411124923893
+19.35	0.5182845914291078
+19.4	0.5350889723703751
+19.45	0.5526728084559884
+19.5	0.5710546528257416
+19.55	0.5902460724801303
+19.6	0.6102565122900444
+19.65	0.6311048001124743
+19.7	0.6528098103792354
+19.75	0.6753904175221382
+19.8	0.6988654959729957
+19.85	0.7232539201636203
+19.9	0.7485745645258226
+19.95	0.7748416206442341
+20.0	0.8020539900156893
+20.05	0.8302186203948585
+20.1	0.8593433564222365
+20.15	0.8894360427383163
+20.2	0.9205045239835973
+20.25	0.952556644798572
+20.3	0.9856002498237355
+20.35	1.019643419454444
+20.4	1.054684371182126
+20.45	1.0906953470085856
+20.5	1.127645333405545
+20.55	1.1655033168447293
+20.6	1.204238283797864
+20.65	1.2438192207366707
+20.7	1.2842151141328804
+20.75	1.325394950458215
+20.8	1.3673277161843997
+20.85	1.40998239778316
+20.9	1.45332607021001
+20.95	1.4972802123053142
+21.0	1.5417363102303465
+21.05	1.5865881045653027
+21.1	1.6317293358903808
+21.15	1.6770537447857734
+21.2	1.7224550718316847
+21.25	1.7678270576083077
+21.3	1.813063391469394
+21.35	1.858038694147211
+21.4	1.9025805571632746
+21.45	1.946509478758231
+21.5	1.9896459571727148
+21.55	2.0318104906473637
+21.6	2.0728235774228154
+21.65	2.1125057157397067
+21.7	2.1506735060182325
+21.75	2.1871176192097326
+21.8	2.221622329637741
+21.85	2.253972346363065
+21.9	2.2839523784465077
+21.95	2.3113471349488814
+22.0	2.3359413249309893
+22.05	2.357514255440453
+22.1	2.3758638619860677
+22.15	2.390821392124329
+22.2	2.402218885405468
+22.25	2.4098883813797203
+22.3	2.4136619195973172
+22.35	2.413371104664108
+22.4	2.408848887946326
+22.45	2.400042852614458
+22.5	2.386940003070496
+22.55	2.369527343716431
+22.6	2.3477918789542556
+22.65	2.321720613185963
+22.7	2.291301001022439
+22.75	2.256588251369826
+22.8	2.2177738244483276
+22.85	2.1750653122035586
+22.9	2.1286703065811365
+22.95	2.078796399526671
+23.0	2.0256511829857793
+23.05	1.9694422489040764
+23.1	1.910392381274867
+23.15	1.8489014542521118
+23.2	1.7853456340537126
+23.25	1.720072739617695
+23.3	1.6534305898820802
+23.35	1.5857821505131908
+23.4	1.517512562204179
+23.45	1.4489708606272123
+23.5	1.3805043089562519
+23.55	1.3124601703652552
+23.6	1.2451837286755356
+23.65	1.178976848347684
+23.7	1.1140994754669653
+23.75	1.0508098710579707
+23.8	0.9893662961452859
+23.85	0.9299774904521412
+23.9	0.8727644229741919
+23.95	0.8178614884362746
+24.0	0.765403137830142
+24.05	0.7155214609248933
+24.1	0.6682234985113544
+24.15	0.6234852944314109
+24.2	0.5813223508384042
+24.25	0.5417501698856858
+24.3	0.5047809410656791
+24.35	0.4702490710053619
+24.4	0.43801699461677357
+24.45	0.4080314436584836
+24.5	0.3802391498890684
+24.55	0.3545868450671021
+24.6	0.33102126095115864
+24.65	0.3094891292998137
+24.7	0.2898915704868967
+24.75	0.27197102298581677
+24.8	0.25561759602685336
+24.85	0.2407409190854474
+24.9	0.22725062163704088
+24.95	0.21503913050897844
+25.0	0.20399995045085811
+25.05	0.1940584465447139
+25.1	0.18514026104802944
+25.15	0.17717065444092198
+25.2	0.1700719915589609
+25.25	0.16378471035329328
+25.3	0.1582546649027792
+25.35	0.15342770928627864
+25.4	0.1492516107167886
+25.45	0.14568659056752503
+25.5	0.14269758042836034
+25.55	0.14024951590303492
+25.6	0.1383082040963155
+25.65	0.13684863446175974
+25.7	0.1358509334570259
+25.75	0.13529525099259698
+25.8	0.13516176358430795
+25.85	0.13543452018721888
+25.9	0.13610459877564784
+25.95	0.13716371337688696
+26.0	0.13860357801822823
+26.05	0.1404159274138563
+26.1	0.14259542888441518
+26.15	0.1451422764500677
+26.2	0.1480572549370119
+26.25	0.15134114917144528
+26.3	0.15499474397956553
+26.35	0.15901776789984617
+26.4	0.16341001635660593
+26.45	0.16818063165382552
+26.5	0.1733396236123495
+26.55	0.17889700205302284
+26.6	0.18486277679669036
+26.65	0.19124695766419658
+26.7	0.1980595544763873
+26.75	0.20531057705410694
+26.8	0.2130091142070617
+26.85	0.22116534951706623
+26.9	0.229796060099537
+26.95	0.23891841604515068
+27.0	0.24854958744458225
+27.05	0.25870514396269567
+27.1	0.2693950295254785
+27.15	0.2806409279239909
+27.2	0.2924663956486647
+27.25	0.30489498918992936
+27.3	0.31795026503821516
+27.35	0.3316557796839525
+27.4	0.34602783572910606
+27.45	0.3610861632218346
+27.5	0.37685988559580935
+27.55	0.39337812689988166
+27.6	0.4106700111829034
+27.65	0.42876428338912337
+27.7	0.4476772357046438
+27.75	0.46743472263991304
+27.8	0.48806998680966823
+27.85	0.5096162708286465
+27.9	0.532106817311583
+27.95	0.5555748688732185
+28.0	0.580042180769571
+28.05	0.6055277327197993
+28.1	0.6320645363227207
+28.15	0.6596856403147902
+28.2	0.6884240934324688
+28.25	0.718312944412212
+28.3	0.7493783137861545
+28.35	0.7816277086807297
+28.4	0.8150838833899493
+28.45	0.84977094159701
+28.5	0.8857129869851015
+28.55	0.9229341232374154
+28.6	0.9614584540371438
+28.65	1.0012960084357674
+28.7	1.042434191698033
+28.75	1.084867942427294
+28.8	1.1285922642482498
+28.85	1.1736021607856004
+28.9	1.219892635664043
+28.95	1.2674586925082838
+29.0	1.3162955954614826
+29.05	1.366384218775142
+29.1	1.4176537098098794
+29.15	1.470024025055333
+29.2	1.5234151210011524
+29.25	1.5777469541369766
+29.3	1.6329394809524487
+29.35	1.6889126579372107
+29.4	1.745586441580901
+29.45	1.8028807883731701
+29.5	1.8606869422412853
+29.55	1.9188124651114538
+29.6	1.9770712864256343
+29.65	2.0352780490422004
+29.7	2.093247395819539
+29.75	2.150793986091502
+29.8	2.2077273965704394
+29.85	2.26379597002396
+29.9	2.318724822178128
+29.95	2.372239068759022
+30.0	2.424063825492708
+30.05	2.4739242081052555
+30.1	2.521541879078581
+30.15	2.5666087871829926
+30.2	2.608804939012352
+30.25	2.6478106447471013
+30.3	2.6833062145676867
+30.35	2.7149719586545533
+30.4	2.7424831828312306
+30.45	2.7655497943065854
+30.5	2.7839072180315863
+30.55	2.7972908995247177
+30.6	2.8054362843044625
+30.65	2.8080780063772517
+30.7	2.8049721156381238
+30.75	2.7960145621013934
+30.8	2.7811342928097287
+30.85	2.7602602548057984
+30.9	2.733321395132273
+30.95	2.7002425895698785
+31.0	2.661010897286886
+31.05	2.6158558522281976
+31.1	2.565042645827734
+31.15	2.508836469519418
+31.2	2.44750251473716
+31.25	2.3813059729148827
+31.3	2.3105120354865054
+31.35	2.2354381815525945
+31.4	2.156702857013973
+31.45	2.074847163363706
+31.5	1.9903873003802015
+31.55	1.9038394678418598
+31.6	1.8157266294671082
+31.65	1.726665652702312
+31.7	1.6372154898323137
+31.75	1.5479042121397886
+31.8	1.4592598909074068
+31.85	1.3718102277748077
+31.9	1.2860343923060953
+31.95	1.2023234419946591
+32.0	1.1210600894128138
+32.05	1.042627047132873
+32.1	0.9673711168885645
+32.15	0.8954221158018899
+32.2	0.8269441495689339
+32.25	0.7621168184912944
+32.3	0.7011197228705403
+32.35	0.6439504928080839
+32.4	0.5904545775039313
+32.45	0.5406363914651543
+32.5	0.4945008513328609
+32.55	0.45205138591710897
+32.6	0.41297136079719193
+32.65	0.3769634573831701
+32.7	0.3439481455494416
+32.75	0.3138458951704241
+32.8	0.28657717612052125
+32.85	0.2620624582741381
+32.9	0.2401151572629131
+32.95	0.22032450109746207
+33.0	0.20253057814315262
+33.05	0.18658620632366715
+33.1	0.17229779547947946
+33.15	0.15952363705820832
+33.2	0.14814102239923868
+33.25	0.13801915689759164
+33.3	0.12902471282767
+33.35	0.12106032996087016
+33.4	0.11402943747572661
+33.45	0.1078358333358692
+33.5	0.10240566310739144
+33.55	0.09766839349651717
+33.6	0.093556088974998
+33.65	0.09001594248840473
+33.7	0.08699908835325497
+33.75	0.08445851749320848
+33.8	0.08235822535567427
+33.85	0.08066622669548484
+33.9	0.07935103919208544
+33.95	0.07838826188643352
+34.0	0.07775859907300949
+34.05	0.07744285332490262
+34.1	0.07742388269245039
+34.15	0.07769081097631007
+34.2	0.0782339208570094
+34.25	0.0790434930171005
+34.3	0.08011146527505596
+34.35	0.0814344072701136
+34.4	0.08300965401547097
+34.45	0.08483454052432648
+34.5	0.08690572741142413
+34.55	0.0892224396700846
+34.6	0.09178819383204259
+34.65	0.09460660300308572
+34.7	0.09768128028900307
+34.75	0.10101583879558208
+34.8	0.1046126355346496
+34.85	0.10847468024545727
+34.9	0.11260988048505169
+34.95	0.11702630302400858
+35.0	0.12173201463290118
+35.05	0.1267350455054865
+35.1	0.13204083666707644
+35.15	0.13765937152111515
+35.2	0.14360371884981668
+35.25	0.14988694743539177
+35.3	0.15652212606005367
+35.35	0.16352111830607446
+35.4	0.17089283892258667
+35.45	0.1786542730445768
+35.5	0.18682343112728708
+35.55	0.19541832362596295
+35.6	0.20445695660360355
+35.65	0.21395201349236334
+35.7	0.22392028747010215
+35.75	0.2343850069281529
+35.8	0.24536940025785275
+35.85	0.2568966958505391
+35.9	0.26898848831843564
+35.95	0.28165915465934627
+36.0	0.2949350215362678
+36.05	0.3088446286795436
+36.1	0.3234165158195173
+36.15	0.33867922268652634
+36.2	0.35465440510621266
+36.25	0.37136276064704243
+36.3	0.38883754006782434
+36.35	0.40711212934018437
+36.4	0.4262199144357409
+36.45	0.4461938607580755
+36.5	0.4670529562513633
+36.55	0.4888258095760073
+36.6	0.5115489518842431
+36.65	0.5352589143282963
+36.7	0.5599922280604058
+36.75	0.5857837136601636
+36.8	0.6126493857030743
+36.85	0.6406189576299421
+36.9	0.6697279112125369
+36.95	0.7000117282226439
+37.0	0.7315058904320323
+37.05	0.7642448000840963
+37.1	0.7982397277487387
+37.15	0.8335082418617364
+37.2	0.8700752909894879
+37.25	0.9079658236983714
+37.3	0.9472047885547811
+37.35	0.9878171341251111
+37.4	1.0298190058868348
+37.45	1.0731970913876419
+37.5	1.1179436784343202
+37.55	1.164051763131001
+37.6	1.2115143415818161
+37.65	1.2603244098908775
+37.7	1.3104749641623246
+37.75	1.3619591457597686
+37.8	1.414754965797372
+37.85	1.4687851470571525
+37.9	1.5239616681850294
+37.95	1.580196507826953
+38.0	1.6374016446288429
+38.05	1.6954890572366415
+38.1	1.7543707242962916
+38.15	1.8139586244537123
+38.2	1.874164404596641
+38.25	1.9348440814597052
+38.3	1.9957894689132378
+38.35	2.0567989937263
+38.4	2.117671082667927
+38.45	2.178204162507187
+38.5	2.238197183286882
+38.55	2.2974316404495956
+38.6	2.3556212624114763
+38.65	2.4124679411940515
+38.7	2.4676735688188804
+38.75	2.520940037307493
+38.8	2.5719692386814406
+38.85	2.6204526770931507
+38.9	2.6660503670780007
+38.95	2.708418215847076
+39.0	2.7472121441010793
+39.05	2.7820880725407315
+39.1	2.812700506066235
+39.15	2.8387062536976684
+39.2	2.859815440291729
+39.25	2.8757485559026637
+39.3	2.8862260905847275
+39.35	2.8909685343921754
+39.4	2.8896933381033625
+39.45	2.88220596576251
+39.5	2.8684351383696116
+39.55	2.84831516462723
+39.6	2.8217803532379255
+39.65	2.7887650129042707
+39.7	2.749215555301837
+39.75	2.7032821300300465
+39.8	2.6512432260603562
+39.85	2.5933784331749163
+39.9	2.5299673411559005
+39.95	2.461289539785448
+40.0	2.387624618845738
+40.05	2.309382639251262
+40.1	2.2271973082832233
+40.15	2.1416137243817324
+40.2	2.0531755284644544
+40.25	1.9624263614491042
+40.3	1.8699734826454582
+40.35	1.776467194124794
+40.4	1.6824960181922362
+40.45	1.5886482887552287
+40.5	1.4955123562642696
+40.55	1.4036516983023986
+40.6	1.313545514019804
+40.65	1.2256556255962074
+40.7	1.1404438552112806
+40.75	1.0583176995673573
+40.8	0.9794957543860676
+40.85	0.9042278716201069
+40.9	0.8327684501847347
+40.95	0.7653463683969849
+41.0	0.7018998122276514
+41.05	0.6424328873878251
+41.1	0.5870019836690359
+41.15	0.5356634908628372
+41.2	0.48825021680397124
+41.25	0.4444176188995021
+41.3	0.4041012732938408
+41.35	0.36723954572406914
+41.4	0.33377080192728387
+41.45	0.30362243113711357
+41.5	0.2763600899345818
+41.55	0.2516837927967325
+41.6	0.22945591738713703
+41.65	0.209538841369376
+41.7	0.1917732786524751
+41.75	0.1758746023709538
+41.8	0.161679472765694
+41.85	0.14905252460195836
+41.9	0.13784632627830154
+41.95	0.1278934252693722
+42.0	0.11908368836717981
+42.05	0.1113120483830949
+42.1	0.10446591631136029
+42.15	0.09845121853294315
+42.2	0.09319158283219851
+42.25	0.08861028241281234
+42.3	0.08463916201098227
+42.35	0.08122521407234237
+42.4	0.07831643580787666
+42.45	0.0758654739235994
+42.5	0.0738369170869481
+42.55	0.07219721838365807
+42.6	0.07091436466961498
+42.65	0.06996518321232907
+42.7	0.06932951585690098
+42.75	0.06898728566784557
+42.8	0.0689224950991947
+42.85	0.06912436978630924
+42.9	0.06958240841323958
+42.95	0.07028628355396403
+43.0	0.07122916682911698
+43.05	0.07240743899420646
+43.1	0.07381759669134402
+43.15	0.07545611389150828
+43.2	0.07731999614336757
+43.25	0.07941037597110506
+43.3	0.08172942465086322
+43.35	0.08427931345878466
+43.4	0.08706221367101083
+43.45	0.09007942741927127
+43.5	0.0933320376795295
+43.55	0.0968267173438848
+43.6	0.10057056512561802
+43.65	0.10457067973800833
+43.7	0.1088341598943372
+43.75	0.11336779680339376
+43.8	0.11817639945891993
+43.85	0.12327021349330179
+43.9	0.12866087131634213
+43.95	0.13436000533784684
+44.0	0.1403792406614977
+44.05	0.14672649199817114
+44.1	0.15341347456049814
+44.15	0.16045650394823952
+44.2	0.16787189576115993
+44.25	0.1756759655990199
+44.3	0.18388423564914563
+44.35	0.19250714637235597
+44.4	0.20156381087418795
+44.45	0.21107548925462793
+44.5	0.22106344161365696
+44.55	0.23154892805126012
+44.6	0.2425478510679057
+44.65	0.25407777368999906
+44.7	0.26616515101800964
+44.75	0.2788364706814432
+44.8	0.2921182203098104
+44.85	0.30603590204296366
+44.9	0.32060522227450095
+44.95	0.335853692676233
+45.0	0.3518128834096593
+45.05	0.368514364636286
+45.1	0.38598970651761955
+45.15	0.40426508617347806
+45.2	0.42335962560974005
+45.25	0.4433081691013613
+45.3	0.4641464116639815
+45.35	0.48591004831324086
+45.4	0.5086347740647702
+45.45	0.5323444819374843
+45.5	0.5570619791610821
+45.55	0.5828243056176808
+45.6	0.6096685214828106
+45.65	0.6376316869319897
+45.7	0.6667508583257383
+45.75	0.6970469198564483
+45.8	0.7285382233719169
+45.85	0.7612566207404455
+45.9	0.7952339638303215
+45.95	0.8305021045098512
+46.0	0.8670928946473209
+46.05	0.9050288958409014
+46.1	0.9443084734854684
+46.15	0.9849427581970677
+46.2	1.0269438048874155
+46.25	1.0703236684682036
+46.3	1.1150944038511432
+46.35	1.1612680659479446
+46.4	1.2088513185420762
+46.45	1.2578138237029648
+46.5	1.3081097097408725
+46.55	1.3596930515954415
+46.6	1.412517924206315
+46.65	1.4665384025131136
+46.7	1.5217085614554886
+46.75	1.57798247597306
+46.8	1.635314221005471
+46.85	1.693648128720972
+46.9	1.7528516409084787
+46.95	1.8127693370016478
+47.0	1.8732473851598435
+47.05	1.9341319535424553
+47.1	1.9952692103088734
+47.15	2.056505323618462
+47.2	2.1176842349547416
+47.25	2.1785955657369658
+47.3	2.238974989903395
+47.35	2.2985559866122367
+47.4	2.357072035021676
+47.45	2.4142566142899287
+47.5	2.46984320357518
+47.55	2.523551093513213
+47.6	2.5750499388741495
+47.65	2.6239989599470066
+47.7	2.670057377545862
+47.75	2.7128844124847675
+47.8	2.7521392855777935
+47.85	2.7874730009690016
+47.9	2.818551126305275
+47.95	2.84505942377113
+48.0	2.8666837044996707
+48.05	2.883109779624009
+48.1	2.8940232674323916
+48.15	2.89911804042681
+48.2	2.8982129832124115
+48.25	2.8911703578744428
+48.3	2.8778524264981473
+48.35	2.858121451168768
+48.4	2.8318383387096255
+48.45	2.798910443156489
+48.5	2.7594878519329833
+48.55	2.7137733542889904
+48.6	2.6619697394743893
+48.65	2.604279796739081
+48.7	2.5409063153329368
+48.75	2.472062012396174
+48.8	2.3981945608017954
+48.85	2.3198469923931833
+48.9	2.237533411313962
+48.95	2.1517679217077106
+49.0	2.0630646277180564
+49.05	1.9719517906749484
+49.1	1.8791201824433603
+49.15	1.7851986868000143
+49.2	1.6907801653481012
+49.25	1.596457479690865
+49.3	1.502824893692576
+49.35	1.4104503139302969
+49.4	1.3198167142749073
+49.45	1.2313959565175732
+49.5	1.145659902449512
+49.55	1.0630293467942362
+49.6	0.9837145222214397
+49.65	0.9079623667441789
+49.7	0.8360276106751893
+49.75	0.768158735848344
+49.8	0.7043253633633086
+49.85	0.6444797188885385
+49.9	0.5886725346776338
+49.95	0.536954542984162
+50.0	0.4893316606737419
diff --git a/test/julia_dde/test_basic_check_3.txt b/test/julia_dde/test_basic_check_3.txt
new file mode 100644
index 00000000..20ab7e04
--- /dev/null
+++ b/test/julia_dde/test_basic_check_3.txt
@@ -0,0 +1,1001 @@
+0.0	1.2
+0.05	1.1880597992376376
+0.1	1.1762383903532672
+0.15	1.1645346054176793
+0.2	1.1529472850205345
+0.25	1.1414752222367555
+0.3	1.1301172850650132
+0.35	1.1188723699778464
+0.4	1.1077393734477934
+0.45	1.0967171444550854
+0.5	1.0858044886142186
+0.55	1.0750003653428541
+0.6	1.0643037416988483
+0.65	1.0537135847400563
+0.7	1.0432288615243335
+0.75	1.0328485303971229
+0.8	1.022571356528724
+0.85	1.0123962869380496
+0.9	1.0023223752836392
+0.95	0.9923486752240321
+1.0	0.9824742404177691
+1.05	0.9726981245233894
+1.1	0.963019381199433
+1.15	0.9534370641044397
+1.2	0.9439501263044995
+1.25	0.9345573120423534
+1.3	0.9252577408761488
+1.35	0.9160505652198292
+1.4	0.9069349374873376
+1.45	0.8979100100926173
+1.5	0.8889749354496113
+1.55	0.8801288659722627
+1.6	0.8713709540745148
+1.65	0.8627003521703108
+1.7	0.8541162126735939
+1.75	0.8456174480352269
+1.8	0.8372028671950351
+1.85	0.8288717208428457
+1.9	0.8206232664837655
+1.95	0.8124567616229019
+2.0	0.8043714637653621
+2.05	0.7963666304162529
+2.1	0.7884415190806816
+2.15	0.7805953872637555
+2.2	0.7728274924705816
+2.25	0.7651370922062669
+2.3	0.757523443975919
+2.35	0.7499858052846445
+2.4	0.7425231592892392
+2.45	0.7351344369656738
+2.5	0.7278189831388171
+2.55	0.7205761452365398
+2.6	0.7134052706867127
+2.65	0.7063057069172068
+2.7	0.6992768013558928
+2.75	0.6923179014306418
+2.8	0.6854283545693244
+2.85	0.6786075081998116
+2.9	0.6718547097499745
+2.95	0.6651693066476835
+3.0	0.6585506463208098
+3.05	0.6521868090091297
+3.1	0.6462596541180435
+3.15	0.6407591978238596
+3.2	0.6356754563028846
+3.25	0.6309984457314268
+3.3	0.6267181822857937
+3.35	0.6228246821422926
+3.4	0.6193079614772313
+3.45	0.6161580364669172
+3.5	0.6133649232876582
+3.55	0.6109186381157617
+3.6	0.6088091971275353
+3.65	0.6070266164992867
+3.7	0.6055610062310477
+3.75	0.6044091849610194
+3.8	0.6035702667113718
+3.85	0.6030411400229324
+3.9	0.6028186934365286
+3.95	0.6028998154929878
+4.0	0.6032813947331372
+4.05	0.6039603196978046
+4.1	0.6049334789278171
+4.15	0.6061969408560554
+4.2	0.607747767703372
+4.25	0.6095875189135027
+4.3	0.61171794179823
+4.35	0.6141407836693366
+4.4	0.6168577918386053
+4.45	0.6198707136178186
+4.5	0.6231812963187596
+4.55	0.6267912872532108
+4.6	0.6307024337329545
+4.65	0.6349164830697739
+4.7	0.6394351825754515
+4.75	0.6442602795617699
+4.8	0.649393521340512
+4.85	0.6548363986236847
+4.9	0.6605923625204214
+4.95	0.6666687077262323
+5.0	0.6730729245090687
+5.05	0.6798125031368819
+5.1	0.6868949338776228
+5.15	0.694327706999243
+5.2	0.702118312769693
+5.25	0.7102742414569245
+5.3	0.7188029833288886
+5.35	0.7277120286535362
+5.4	0.7370088676988189
+5.45	0.7467009907326875
+5.5	0.7567958783893151
+5.55	0.7673009954492572
+5.6	0.7782282717605841
+5.65	0.7895913565176023
+5.7	0.8014038989146166
+5.75	0.813679548145934
+5.8	0.8264319534058602
+5.85	0.8396747638887014
+5.9	0.8534216287887638
+5.95	0.867686197300353
+6.0	0.8824821186177751
+6.05	0.8978210479139723
+6.1	0.9137051417826871
+6.15	0.930133801549427
+6.2	0.9471064285396992
+6.25	0.9646224240790109
+6.3	0.9826811894928694
+6.35	1.0012821261067824
+6.4	1.0204246352462572
+6.45	1.0401081182368006
+6.5	1.06033197640392
+6.55	1.0810959571500407
+6.6	1.1024051083860738
+6.65	1.1242536823420526
+6.7	1.146631632226644
+6.75	1.1695289112485163
+6.8	1.192935472616337
+6.85	1.2168412695387738
+6.9	1.2412362552244953
+6.95	1.2661103828821678
+7.0	1.2914536057204598
+7.05	1.3172558769480385
+7.1	1.3435071497735722
+7.15	1.3701973774057286
+7.2	1.3973165130531748
+7.25	1.4248545099245788
+7.3	1.4528013212286086
+7.35	1.4811469001739317
+7.4	1.509881199118115
+7.45	1.5389782800159062
+7.5	1.5683961525798626
+7.55	1.598100100657414
+7.6	1.628055408095991
+7.65	1.658227358743024
+7.7	1.688581236445943
+7.75	1.719082325052178
+7.8	1.7496959084091595
+7.85	1.780387270364318
+7.9	1.811122925033841
+7.95	1.8418669475319058
+8.0	1.8725597086609542
+8.05	1.90313851346047
+8.1	1.933540666969933
+8.15	1.9637034742288284
+8.2	1.9935642402766356
+8.25	2.0230602701528393
+8.3	2.052128868896921
+8.35	2.0807073415483615
+8.4	2.108732993146645
+8.45	2.1361393502296684
+8.5	2.16284287320635
+8.55	2.1887552138951114
+8.6	2.2137880239094327
+8.65	2.2378529548627957
+8.7	2.260861658368677
+8.75	2.2827257860405585
+8.8	2.3033569894919195
+8.85	2.3226668900090086
+8.9	2.3405616664134223
+8.95	2.3569434633821165
+9.0	2.371716076157124
+9.05	2.3847799002672945
+9.1	2.396042281634715
+9.15	2.4054174410912736
+9.2	2.412819599468856
+9.25	2.4181629775993505
+9.3	2.4213617963146428
+9.35	2.422333801532281
+9.4	2.421017966043976
+9.45	2.4173610505395455
+9.5	2.4113098248196683
+9.55	2.402811058685023
+9.6	2.3918073853741615
+9.65	2.3782692601247457
+9.7	2.3622002696856885
+9.75	2.3436045090838595
+9.8	2.3224860733461314
+9.85	2.2988490574993756
+9.9	2.2726975565704626
+9.95	2.244039548654731
+10.0	2.212937352669323
+10.05	2.179483711456287
+10.1	2.1437715190596918
+10.15	2.105893669523604
+10.2	2.065943056892092
+10.25	2.0240125752092206
+10.3	1.980195118519059
+10.35	1.9345835808656762
+10.4	1.8872723502754671
+10.45	1.8384294978875548
+10.5	1.7882366416651543
+10.55	1.7368513259590423
+10.6	1.6844310951199983
+10.65	1.6311335343071072
+10.7	1.5771337374081937
+10.75	1.5226097217672243
+10.8	1.4677278553411066
+10.85	1.41265450608675
+10.9	1.3575560419610566
+10.95	1.3025988309209373
+11.0	1.2479491282494664
+11.05	1.1937607528931462
+11.1	1.1401771120697988
+11.15	1.087341445996801
+11.2	1.0353969948915365
+11.25	0.984484065667813
+11.3	0.9347008324906576
+11.35	0.8861377047519156
+11.4	0.8388897689338305
+11.45	0.7930521115186531
+11.5	0.748719818988627
+11.55	0.705959243095179
+11.6	0.6647650478857721
+11.65	0.6251755908862296
+11.7	0.5872322722497952
+11.75	0.5509764921297068
+11.8	0.5164492863725721
+11.85	0.48361792647027246
+11.9	0.4524348830442019
+11.95	0.42289362720225393
+12.0	0.39498763005231813
+12.05	0.3686315661870969
+12.1	0.34373279718633004
+12.15	0.3202738246384638
+12.2	0.2982371501319477
+12.25	0.2776052752552281
+12.3	0.2583607015967536
+12.35	0.24048593074497288
+12.4	0.223950831846511
+12.45	0.20858331100426564
+12.5	0.19429923256729392
+12.55	0.18105445919832033
+12.6	0.16880485356006925
+12.65	0.15750114166684823
+12.7	0.1470540258920014
+12.75	0.13740996022300644
+12.8	0.12852624259140769
+12.85	0.12036017092874914
+12.9	0.1128609745101247
+12.95	0.10596741395436253
+13.0	0.09964203786227739
+13.05	0.09384868822581734
+13.1	0.08855032604103662
+13.15	0.08370226381849066
+13.2	0.07927243004574394
+13.25	0.07523257594725748
+13.3	0.07155441352980291
+13.35	0.06820745069454974
+13.4	0.0651675270046265
+13.45	0.06241399580003661
+13.5	0.05992621042078313
+13.55	0.05768323302303655
+13.6	0.055668129355264065
+13.65	0.053866461849464135
+13.7	0.05226379382060764
+13.75	0.05084595459195789
+13.8	0.04960182363429139
+13.85	0.04852201561788314
+13.9	0.047597159022623084
+13.95	0.046818043268843194
+14.0	0.046177674337955096
+14.05	0.0456706587399853
+14.1	0.04529163747434971
+14.15	0.045035254325706556
+14.2	0.04489705421988383
+14.25	0.04487463111043636
+14.3	0.044965848174891764
+14.35	0.045168568590777634
+14.4	0.04548065560180838
+14.45	0.04590066908055203
+14.5	0.046428999081977486
+14.55	0.047066322023146555
+14.6	0.047813314321121014
+14.65	0.048670652392962706
+14.7	0.04963895651796998
+14.75	0.050718620023989656
+14.8	0.05191253849246556
+14.85	0.05322426299196401
+14.9	0.054657344591051496
+14.95	0.05621533435829428
+15.0	0.057901783362258846
+15.05	0.059720242671511534
+15.1	0.06167423564461651
+15.15	0.06376812031817305
+15.2	0.06600782089469139
+15.25	0.06839938058178958
+15.3	0.07094884258708545
+15.35	0.07366224585439077
+15.4	0.07654376689032617
+15.45	0.07959974929474373
+15.5	0.08283893126579933
+15.55	0.0862700510016486
+15.6	0.08990184670044707
+15.65	0.09374305656035072
+15.7	0.09780070119392023
+15.75	0.10208288948700774
+15.8	0.10660157739539043
+15.85	0.1113687519719273
+15.9	0.11639640026947794
+15.95	0.12169626466104969
+16.0	0.12727628804408414
+16.05	0.13315014663393196
+16.1	0.139333946326771
+16.15	0.1458437930187787
+16.2	0.1526957926061338
+16.25	0.15990482372249143
+16.3	0.16748183360191107
+16.35	0.17544672115696183
+16.4	0.18382073593523884
+16.45	0.192625127484339
+16.5	0.2018811453518574
+16.55	0.21160507934505363
+16.6	0.22181536124463952
+16.65	0.23253963155802002
+16.7	0.24380557729343716
+16.75	0.2556408854591307
+16.8	0.26807233665258995
+16.85	0.2811177402975892
+16.9	0.2948082752268757
+16.95	0.3091795563883086
+17.0	0.32426719872974363
+17.05	0.3401068171990377
+17.1	0.35672771274342574
+17.15	0.37415694197152904
+17.2	0.39243917447317206
+17.25	0.41161962131703544
+17.3	0.43174349357180103
+17.35	0.45285515799569886
+17.4	0.4749833351492887
+17.45	0.49817534283248655
+17.5	0.5224876697127654
+17.55	0.5479768044575993
+17.6	0.574699235734463
+17.65	0.6027031591307277
+17.7	0.6320269190642993
+17.75	0.6627379423398659
+17.8	0.6949054091335899
+17.85	0.7285984996216338
+17.9	0.7638860935554969
+17.95	0.8008124933089992
+18.0	0.8394402408300372
+18.05	0.8798515650441616
+18.1	0.922128694876923
+18.15	0.9663538592538687
+18.2	1.0126032288425226
+18.25	1.0609217181003978
+18.3	1.1113944695913622
+18.35	1.1641149553854777
+18.4	1.219176647552802
+18.45	1.2766730181634047
+18.5	1.336680549796806
+18.55	1.3992478308498963
+18.6	1.464470198118454
+18.65	1.5324458195708326
+18.7	1.6032728631754014
+18.75	1.6770494969005143
+18.8	1.7538494409075045
+18.85	1.8337108248219487
+18.9	1.9167185876634334
+18.95	2.002959529098503
+19.0	2.092520448793682
+19.05	2.1854881464155036
+19.1	2.28193664578704
+19.15	2.3818692470976592
+19.2	2.485317656588622
+19.25	2.5923208770612622
+19.3	2.7029179113169235
+19.35	2.8171477621569476
+19.4	2.935049432382669
+19.45	3.0566494211038315
+19.5	3.1818830679881756
+19.55	3.3106478005311515
+19.6	3.4428410937337786
+19.65	3.578360422597069
+19.7	3.717103262122061
+19.75	3.8589670873097677
+19.8	4.003849373161208
+19.85	4.151647594677404
+19.9	4.302193278789322
+19.95	4.455127233993672
+20.0	4.610063803797267
+20.05	4.766617343784246
+20.1	4.924402209538742
+20.15	5.083032756644884
+20.2	5.242123340686828
+20.25	5.401272677380016
+20.3	5.559889523919631
+20.35	5.7172568196444855
+20.4	5.8726552958575144
+20.45	6.025365683861684
+20.5	6.174668714959923
+20.55	6.319844708945797
+20.6	6.460096425070274
+20.65	6.594460469514813
+20.7	6.721952145798012
+20.75	6.841586757438446
+20.8	6.952379607954696
+20.85	7.0533440911423195
+20.9	7.143446676186029
+20.95	7.2216660505591665
+21.0	7.2869970757373
+21.05	7.3384346131959965
+21.1	7.374973524410826
+21.15	7.395597049856846
+21.2	7.399417900209098
+21.25	7.385817836321792
+21.3	7.354196382550761
+21.35	7.303953063251845
+21.4	7.234487402780882
+21.45	7.145263304461381
+21.5	7.0363647566654155
+21.55	6.908128584835803
+21.6	6.760891630301893
+21.65	6.594990734393041
+21.7	6.410762738438569
+21.75	6.2088450726321724
+21.8	5.990732640705275
+21.85	5.75791155432729
+21.9	5.511864196359442
+21.95	5.2540729496629055
+22.0	4.986082433415764
+22.05	4.7101639066643335
+22.1	4.42856676538218
+22.15	4.143442902024981
+22.2	3.8569464178882313
+22.25	3.5713552672555364
+22.3	3.28889884770391
+22.35	3.0117352798858366
+22.4	2.7420203611280627
+22.45	2.481463293669825
+22.5	2.231572779218588
+22.55	1.9940537506509244
+22.6	1.7704738014123098
+22.65	1.5610608918512747
+22.7	1.3666090012348524
+22.75	1.188181307974745
+22.8	1.0257875495979925
+22.85	0.8783253325850336
+22.9	0.7462653264465491
+22.95	0.6298607329726912
+23.0	0.526689471408246
+23.05	0.4362126970096439
+23.1	0.3586789044286951
+23.15	0.2930769392946884
+23.2	0.23644639121785105
+23.25	0.1887843992470036
+23.3	0.15020916432014636
+23.35	0.11841625816064075
+23.4	0.09193646297837396
+23.45	0.07084691797318869
+23.5	0.0546440589583355
+23.55	0.04135889403855052
+23.6	0.030871343784575907
+23.65	0.023187432031184143
+23.7	0.01726905105966836
+23.75	0.01266949856717107
+23.8	0.009344910618525257
+23.85	0.006786482524138971
+23.9	0.004816778288678049
+23.95	0.0034500849725069157
+24.0	0.0025160101833220463
+24.05	0.0018312388039325286
+24.1	0.0013290771755335164
+24.15	0.0009639976527466686
+24.2	0.0006989804814600982
+24.25	0.0005067026621806486
+24.3	0.0003677309832842101
+24.35	0.00026755767682035294
+24.4	0.0001951844558401153
+24.45	0.00014284822334131387
+24.5	0.00010515507647664639
+24.55	7.778055265189298e-5
+24.6	5.793493283808354e-5
+24.65	4.3496595960544536e-5
+24.7	3.290011776746757e-5
+24.75	2.5142190984993185e-5
+24.8	1.940443720424485e-5
+24.85	1.5121385224876476e-5
+24.9	1.1965876035510612e-5
+24.95	9.54987103802627e-6
+25.0	7.73200127909234e-6
+25.05	6.365973014560343e-6
+25.1	5.2963544793087956e-6
+25.15	4.470689445671548e-6
+25.2	3.841277717496045e-6
+25.25	3.346995535139201e-6
+25.3	2.954501620571546e-6
+25.35	2.6471951701062873e-6
+25.4	2.4084753800566474e-6
+25.45	2.2215394495255226e-6
+25.5	2.073562751279236e-6
+25.55	1.9594262288370855e-6
+25.6	1.8745654488899708e-6
+25.65	1.8144159781287961e-6
+25.7	1.7744133832444533e-6
+25.75	1.7502211729269242e-6
+25.8	1.7400672049990554e-6
+25.85	1.743237785350034e-6
+25.9	1.758945971985891e-6
+25.95	1.786404822912663e-6
+26.0	1.8248273961363818e-6
+26.05	1.8734267496630814e-6
+26.1	1.9314159414987944e-6
+26.15	1.9980080296495497e-6
+26.2	2.0722150882898906e-6
+26.25	2.1536714723203395e-6
+26.3	2.2425495205690797e-6
+26.35	2.3390216235075385e-6
+26.4	2.4432601716071337e-6
+26.45	2.555437555339307e-6
+26.5	2.6757261651754785e-6
+26.55	2.804298391587074e-6
+26.6	2.941326625045521e-6
+26.65	3.086983256022234e-6
+26.7	3.2414406749886615e-6
+26.75	3.40487127241622e-6
+26.8	3.577447438776335e-6
+26.85	3.7587795135315416e-6
+26.9	3.947161901513964e-6
+26.95	4.143253394892425e-6
+27.0	4.347951835206379e-6
+27.05	4.56215506399529e-6
+27.1	4.786760922798623e-6
+27.15	5.0226672531558285e-6
+27.2	5.270771896606404e-6
+27.25	5.531972694689799e-6
+27.3	5.8071674889454775e-6
+27.35	6.097254120912909e-6
+27.4	6.403130432131533e-6
+27.45	6.72569426414086e-6
+27.5	7.065843458480336e-6
+27.55	7.424475856689423e-6
+27.6	7.802489300307588e-6
+27.65	8.20078163087427e-6
+27.7	8.620250689928989e-6
+27.75	9.061794319011182e-6
+27.8	9.526310359660316e-6
+27.85	1.0013348352028656e-5
+27.9	1.0521789101765655e-5
+27.95	1.1053647587290105e-5
+28.0	1.1611018389299699e-5
+28.05	1.2195996088492152e-5
+28.1	1.2810675265565194e-5
+28.15	1.3457150501216494e-5
+28.2	1.413751637614387e-5
+28.25	1.4853867471045e-5
+28.3	1.5608298366617607e-5
+28.35	1.640290364355941e-5
+28.4	1.7239777882568074e-5
+28.45	1.812101566434144e-5
+28.5	1.9048711569577178e-5
+28.55	2.0024960178973003e-5
+28.6	2.104681917801384e-5
+28.65	2.2112082089476636e-5
+28.7	2.3225295086541718e-5
+28.75	2.4391013570882673e-5
+28.8	2.5613792944173156e-5
+28.85	2.6898188608086833e-5
+28.9	2.8248755964297266e-5
+28.95	2.9670050414478296e-5
+29.0	3.116662736030351e-5
+29.05	3.274304220344654e-5
+29.1	3.4403850345581075e-5
+29.15	3.615360718838063e-5
+29.2	3.79968681335191e-5
+29.25	3.993818858267005e-5
+29.3	4.1982123937507134e-5
+29.35	4.4133229599704e-5
+29.4	4.639235532224559e-5
+29.45	4.875279781024011e-5
+29.5	5.122277540031674e-5
+29.55	5.3811850799978914e-5
+29.6	5.6529586716730105e-5
+29.65	5.938554585807354e-5
+29.7	6.238929093151308e-5
+29.75	6.555038464455198e-5
+29.8	6.887838970469367e-5
+29.85	7.238286881944164e-5
+29.9	7.607338469629903e-5
+29.95	7.995950004276982e-5
+30.0	8.40507775663572e-5
+30.05	8.835677997456466e-5
+30.1	9.288123963748541e-5
+30.15	9.761299340089545e-5
+30.2	0.0001025676957395089
+30.25	0.00010776399463869746
+30.3	0.00011322053808383316
+30.35	0.00011895597406028808
+30.4	0.0001249889505534338
+30.45	0.00013133811554864319
+30.5	0.00013802211703128791
+30.55	0.00014505960298673997
+30.6	0.00015246922140037142
+30.65	0.00016026962025755373
+30.7	0.00016847944754366006
+30.75	0.0001771109229141103
+30.8	0.00018615038032333557
+30.85	0.00019562494484115931
+30.9	0.0002055686493533325
+30.95	0.0002160155267456082
+31.0	0.00022699960990373743
+31.05	0.00023855493171347184
+31.1	0.0002507155250605632
+31.15	0.0002635154228307622
+31.2	0.0002769886579098222
+31.25	0.0002911692631834942
+31.3	0.0003060912715375298
+31.35	0.0003217548445953388
+31.4	0.00033817214436745693
+31.45	0.00035540139228868584
+31.5	0.00037350083780491173
+31.55	0.00039252873036202174
+31.6	0.0004125433194059033
+31.65	0.000433602854382442
+31.7	0.0004557655847375281
+31.75	0.0004790897599170474
+31.8	0.0005036336293668872
+31.85	0.0005294390884839935
+31.9	0.0005565016569481556
+31.95	0.0005849046396030645
+32.0	0.0006147424821662765
+32.05	0.0006461096303553476
+32.1	0.0006791005298878432
+32.15	0.0007138096264813124
+32.2	0.0007503313658533232
+32.25	0.0007887601937214239
+32.3	0.0008291876039328545
+32.35	0.0008716237833675735
+32.4	0.0009161628363820181
+32.45	0.0009629511909588468
+32.5	0.0010121352750806925
+32.55	0.001063861516730208
+32.6	0.0011182763438900472
+32.65	0.0011755261845428407
+32.7	0.0012357574666712514
+32.75	0.001299077958835768
+32.8	0.0013655343314847815
+32.85	0.0014353329672678646
+32.9	0.001508692841793293
+32.95	0.0015858329306693837
+33.0	0.0016669722095044113
+33.05	0.0017523296539066827
+33.1	0.001842124239484507
+33.15	0.0019365050661177198
+33.2	0.0020355859489668334
+33.25	0.0021396782998149874
+33.3	0.0022491003006402093
+33.35	0.0023641701334205286
+33.4	0.0024852059801339266
+33.45	0.002612526022758449
+33.5	0.002746419722344061
+33.55	0.0028870158376160682
+33.6	0.0030347071119279914
+33.65	0.0031899453963282207
+33.7	0.003353182541865236
+33.75	0.0035248703995874224
+33.8	0.003705460820543237
+33.85	0.0038953545607776047
+33.9	0.0040947707431942424
+33.95	0.004304283228705645
+34.0	0.004524522077522654
+34.05	0.004756117349856203
+34.1	0.004999699105917233
+34.15	0.0052558974059165806
+34.2	0.005525138328425961
+34.25	0.005807923134528589
+34.3	0.006105115939931086
+34.35	0.006417582537711775
+34.4	0.006746188720948853
+34.45	0.007091800282720693
+34.5	0.00745516258459202
+34.55	0.007836800333211071
+34.6	0.008237832852801386
+34.65	0.008659435913645637
+34.7	0.009102785286026734
+34.75	0.009569056740227337
+34.8	0.010059313773287167
+34.85	0.010574260646268591
+34.9	0.011115366044177084
+34.95	0.011684211579467325
+35.0	0.012282378864593673
+35.05	0.012911449512010727
+35.1	0.013572815103629234
+35.15	0.014267500540935528
+35.2	0.014997517875816616
+35.25	0.01576497806639549
+35.3	0.016571992070795443
+35.35	0.017420670847139805
+35.4	0.018312685173292334
+35.45	0.01924968415309058
+35.5	0.020234442905654234
+35.55	0.02126976089459592
+35.6	0.02235843758352828
+35.65	0.023503249379662153
+35.7	0.024706012728242802
+35.75	0.025969650272638308
+35.8	0.027297854679857802
+35.85	0.028694318616910446
+35.9	0.030162734750804814
+35.95	0.031706415455700944
+36.0	0.03332787329061813
+36.05	0.03503172850735315
+36.1	0.03682282936766161
+36.15	0.038706024133298336
+36.2	0.04068615069732654
+36.25	0.042766552443595805
+36.3	0.04495204172481472
+36.35	0.04724895832464577
+36.4	0.04966364202675038
+36.45	0.05220243261479136
+36.5	0.05487092642302335
+36.55	0.05767364436346181
+36.6	0.060618563381571004
+36.65	0.06371394333229173
+36.7	0.06696804407056656
+36.75	0.07038898289926133
+36.8	0.07398221637130845
+36.85	0.07775676609611071
+36.9	0.08172335523389757
+36.95	0.08589270694490071
+37.0	0.09027554438934955
+37.05	0.09488037833913766
+37.1	0.09971621831715513
+37.15	0.10479687889109346
+37.2	0.11013623039567051
+37.25	0.11574814316560113
+37.3	0.12164536652631897
+37.35	0.1278373938059498
+37.4	0.13434107976601245
+37.45	0.1411742954228076
+37.5	0.14835491179263216
+37.55	0.1559005710011262
+37.6	0.16382324101807627
+37.65	0.172141933997992
+37.7	0.18087957335219007
+37.75	0.19005908249198256
+37.8	0.1997033848286852
+37.85	0.20983053544200236
+37.9	0.2204598323038906
+37.95	0.2316204338146686
+38.0	0.24334158791252208
+38.05	0.25565254253564157
+38.1	0.26857984083615244
+38.15	0.2821435148894166
+38.2	0.2963788455731146
+38.25	0.31132290012334685
+38.3	0.3270127457762196
+38.35	0.34348480535160253
+38.4	0.3607633025216382
+38.45	0.37888741969303963
+38.5	0.39790372047370615
+38.55	0.417858768471545
+38.6	0.43879912729446385
+38.65	0.4607610734299169
+38.7	0.48378225644585704
+38.75	0.5079202865847827
+38.8	0.5332329579841185
+38.85	0.5597780647812899
+38.9	0.5876089915982601
+38.95	0.6167621673008777
+39.0	0.6473032833149648
+39.05	0.6793028656945395
+39.1	0.712831440493621
+39.15	0.7479594396605547
+39.2	0.7847342129135348
+39.25	0.823217622113749
+39.3	0.8634935941531566
+39.35	0.9056460559237182
+39.4	0.949758934317376
+39.45	0.9959079700090568
+39.5	1.0441410750223796
+39.55	1.09454823488251
+39.6	1.1472254344279125
+39.65	1.2022686584970301
+39.7	1.2597738919283366
+39.75	1.3198093040896974
+39.8	1.3824366531315018
+39.85	1.4477581076715695
+39.9	1.5158760735523693
+39.95	1.5868929566164067
+40.0	1.660910668944245
+40.05	1.7379829126208655
+40.1	1.8181742280585573
+40.15	1.9015790392263767
+40.2	1.9882917700934268
+40.25	2.078406844628762
+40.3	2.172018656805829
+40.35	2.2691717752327363
+40.4	2.369888098993709
+40.45	2.474221547636366
+40.5	2.582226040708271
+40.55	2.69395549775703
+40.6	2.809463838330252
+40.65	2.928799560415366
+40.7	3.0519314275712772
+40.75	3.1787907877067165
+40.8	3.309312404645847
+40.85	3.4434310422128362
+40.9	3.581081464231792
+40.95	3.722198434526901
+41.0	3.866716716922271
+41.05	4.014568827675874
+41.1	4.165580609355552
+41.15	4.3194294195811285
+41.2	4.475781793533374
+41.25	4.63430426639297
+41.3	4.794663373340665
+41.35	4.956525649557209
+41.4	5.119557630223281
+41.45	5.283420651179304
+41.5	5.447611150542289
+41.55	5.61144856088327
+41.6	5.774245796587521
+41.65	5.9353157720402505
+41.7	6.093971401626758
+41.75	6.249525678979698
+41.8	6.401248640407215
+41.85	6.5482069219410635
+41.9	6.689413242551435
+41.95	6.823880321208602
+42.0	6.950620876882759
+42.05	7.06864762854416
+42.1	7.176938262153528
+42.15	7.27440600752035
+42.2	7.359976359975245
+42.25	7.432574929134884
+42.3	7.491127324615977
+42.35	7.534553691114843
+42.4	7.5617939728915635
+42.45	7.572045793659248
+42.5	7.56455785483789
+42.55	7.538578857847491
+42.6	7.49335750410804
+42.65	7.428151725357762
+42.7	7.3426726058789225
+42.75	7.237025708571567
+42.8	7.111324833182236
+42.85	6.965683779457459
+42.9	6.800216347143829
+42.95	6.615284588450397
+43.0	6.412099000434837
+43.05	6.191997586548051
+43.1	5.956317372524124
+43.15	5.706395384097242
+43.2	5.443568647001452
+43.25	5.16954597937931
+43.3	4.88686121253996
+43.35	4.597719201100378
+43.4	4.304310549145162
+43.45	4.008958598265056
+43.5	3.7140880725611183
+43.55	3.4220094283648006
+43.6	3.1350328703342174
+43.65	2.8552853236801443
+43.7	2.584580943289718
+43.75	2.3248627259415255
+43.8	2.0780007719371825
+43.85	1.8447916139291527
+43.9	1.6263053444703741
+43.95	1.4238591429225103
+44.0	1.2378495508062841
+44.05	1.0676117660836748
+44.1	0.9137937748400329
+44.15	0.7766775562581366
+44.2	0.6541446509533947
+44.25	0.5460258275341494
+44.3	0.4526329589292429
+44.35	0.37208805481763735
+44.4	0.30225547898343
+44.45	0.24331791807279535
+44.5	0.19511542400617302
+44.55	0.15449031008368086
+44.6	0.12071355788444883
+44.65	0.0938826957051435
+44.7	0.07263525682076719
+44.75	0.05511310640146764
+44.8	0.0413911144043884
+44.85	0.031291435522553986
+44.9	0.023320331978155463
+44.95	0.017200308819630094
+45.0	0.012731689746122843
+45.05	0.009224318145506102
+45.1	0.006608549070533346
+45.15	0.004813506971795548
+45.2	0.0034975470646108546
+45.25	0.002538104683370468
+45.3	0.0018352837245791953
+45.35	0.0013238772501159243
+45.4	0.0009533027362415938
+45.45	0.0006857490475925572
+45.5	0.0004931924134454181
+45.55	0.0003549056604967315
+45.6	0.0002558520517525481
+45.65	0.00018496607420581378
+45.7	0.00013420141629981423
+45.75	9.773882246850954e-5
+45.8	7.162571556538256e-5
+45.85	5.283877574358323e-5
+45.9	3.921777643831624e-5
+45.95	2.9411155708976258e-5
+46.0	2.2204639441713795e-5
+46.05	1.6983161478800986e-5
+46.1	1.3102655516730324e-5
+46.15	1.0227383487653142e-5
+46.2	8.113981211825818e-6
+46.25	6.493126781687679e-6
+46.3	5.27542763847095e-6
+46.35	4.366588897499584e-6
+46.4	3.654188041375601e-6
+46.45	3.1018881618384487e-6
+46.5	2.679925024487994e-6
+46.55	2.3546173667701558e-6
+46.6	2.095266698315522e-6
+46.65	1.8914372455343665e-6
+46.7	1.7336011148483766e-6
+46.75	1.6122304126793446e-6
+46.8	1.5179976240508768e-6
+46.85	1.4453581337628131e-6
+46.9	1.3920033773926119e-6
+46.95	1.3557218835400264e-6
+47.0	1.3343021808048348e-6
+47.05	1.3255327977867997e-6
+47.1	1.327202263085691e-6
+47.15	1.3374259128759728e-6
+47.2	1.3558202866841001e-6
+47.25	1.3821522496182739e-6
+47.3	1.4161683210163368e-6
+47.35	1.4576150202161354e-6
+47.4	1.5062388665554982e-6
+47.45	1.5617863793722795e-6
+47.5	1.624004078004304e-6
+47.55	1.6926384817894189e-6
+47.6	1.767436110065476e-6
+47.65	1.8478936884213871e-6
+47.7	1.9329092003405298e-6
+47.75	2.0226895515575455e-6
+47.8	2.117589842249513e-6
+47.85	2.217965172593514e-6
+47.9	2.3241706427665866e-6
+47.95	2.436561352945827e-6
+48.0	2.5554924033082717e-6
+48.05	2.6813188940310033e-6
+48.1	2.814395925291107e-6
+48.15	2.955078597265612e-6
+48.2	3.1037220101316256e-6
+48.25	3.260681264066172e-6
+48.3	3.42631145924634e-6
+48.35	3.6009676958492197e-6
+48.4	3.78500507405183e-6
+48.45	3.978778694031291e-6
+48.5	4.182643654185989e-6
+48.55	4.3958396573664185e-6
+48.6	4.6180155849332155e-6
+48.65	4.850131498355723e-6
+48.7	5.093147459103417e-6
+48.75	5.348023528645634e-6
+48.8	5.615719768451817e-6
+48.85	5.897196239991412e-6
+48.9	6.193413004733749e-6
+48.95	6.5053301241483195e-6
+49.0	6.833907659704442e-6
+49.05	7.180105672871568e-6
+49.1	7.544884225119155e-6
+49.15	7.929203377916511e-6
+49.2	8.334023192733152e-6
+49.25	8.76030373103837e-6
+49.3	9.209005054301629e-6
+49.35	9.68098136958124e-6
+49.4	1.0174990535730868e-5
+49.45	1.0691885478597057e-5
+49.5	1.1233653340982156e-5
+49.55	1.1802281265688716e-5
+49.6	1.2399756395519322e-5
+49.65	1.3028065873276285e-5
+49.7	1.3689196841762275e-5
+49.75	1.4385136443779596e-5
+49.8	1.5117871822130826e-5
+49.85	1.588939011961856e-5
+49.9	1.6701678479045073e-5
+49.95	1.7556724043213083e-5
+50.0	1.845651395492484e-5
diff --git a/test/julia_dde/test_basic_check_4.txt b/test/julia_dde/test_basic_check_4.txt
new file mode 100644
index 00000000..e61f2232
--- /dev/null
+++ b/test/julia_dde/test_basic_check_4.txt
@@ -0,0 +1,1001 @@
+0.0	1.2
+0.05	1.1880598004963274
+0.1	1.1762384079684047
+0.15	1.1645346400763932
+0.2	1.1529473230412208
+0.25	1.141475299112477
+0.3	1.1301174233925737
+0.35	1.1188725616239938
+0.4	1.1077395901892904
+0.45	1.0967173961110868
+0.5	1.0858048770520767
+0.55	1.075000941315024
+0.6	1.064304507842764
+0.65	1.0537145062182005
+0.7	1.043229876664309
+0.75	1.0328495700441358
+0.8	1.0225725478607945
+0.85	1.012397782257473
+0.9	1.0023242560174284
+0.95	0.9923509625639848
+1.0	0.9824768980307879
+1.05	0.9727010597135578
+1.1	0.9630224771641795
+1.15	0.9534401895136975
+1.2	0.9439532442848286
+1.25	0.9345606973919633
+1.3	0.9252616131411647
+1.35	0.916055064230169
+1.4	0.9069401317483856
+1.45	0.8979159051768966
+1.5	0.888981482388457
+1.55	0.8801359696474954
+1.6	0.8713784816101126
+1.65	0.8627081413240832
+1.7	0.8541240802288541
+1.75	0.8456254381555455
+1.8	0.837211363326951
+1.85	0.8288810123575361
+1.9	0.8206335502534408
+1.95	0.8124681504124766
+2.0	0.80438399462413
+2.05	0.7963802730695569
+2.1	0.7884561843215901
+2.15	0.7806109353447339
+2.2	0.7728437414951651
+2.25	0.765153826520734
+2.3	0.7575404225609632
+2.35	0.7500027701470497
+2.4	0.7425401175213142
+2.45	0.735151718621093
+2.5	0.7278368354138617
+2.55	0.7205947369425942
+2.6	0.7134246992198385
+2.65	0.7063260052277169
+2.7	0.6992979449179254
+2.75	0.6923398152117342
+2.8	0.6854509199999872
+2.85	0.6786305701431024
+2.9	0.6718780834710726
+2.95	0.6651927847834634
+3.0	0.6585740058494151
+3.05	0.6522129149573156
+3.1	0.6462924386898269
+3.15	0.6407999646879741
+3.2	0.6357234315305058
+3.25	0.6310513287338946
+3.3	0.626772696752336
+3.35	0.6228771269777498
+3.4	0.619354761739779
+3.45	0.6161962943057904
+3.5	0.6133929688808745
+3.55	0.6109365806078452
+3.6	0.60881947556724
+3.65	0.6070345507773204
+3.7	0.605575254194071
+3.75	0.6044355847112005
+3.8	0.6036100921601407
+3.85	0.6030938773100476
+3.9	0.6028825918678004
+3.95	0.6029724384780017
+4.0	0.6033601707229788
+4.05	0.604043093122781
+4.1	0.6050190611351829
+4.15	0.6062864811556813
+4.2	0.6078443105174975
+4.25	0.609692057491576
+4.3	0.6118297812865848
+4.35	0.6142580920489168
+4.4	0.6169781508626868
+4.45	0.619991669749734
+4.5	0.6233009116696199
+4.55	0.6269085710336688
+4.6	0.630815543628207
+4.65	0.6350235971587848
+4.7	0.6395355437952476
+4.75	0.6443546093224879
+4.8	0.6494844331404451
+4.85	0.6549290682641051
+4.9	0.6606929813235011
+4.95	0.666781052563712
+5.0	0.6731985758448654
+5.05	0.6799512586421336
+5.1	0.6870452220457377
+5.15	0.694487000760944
+5.2	0.7022835431080668
+5.25	0.7104422110224664
+5.3	0.7189707800545505
+5.35	0.7278774393697731
+5.4	0.7371707917486356
+5.45	0.746859958736378
+5.5	0.7569547677070346
+5.55	0.7674652488443671
+5.6	0.7784019827980168
+5.65	0.7897761339397672
+5.7	0.8015994503635439
+5.75	0.813884263885415
+5.8	0.8266434900435903
+5.85	0.8398906280984233
+5.9	0.8536397610324082
+5.95	0.8679055555501825
+6.0	0.8827032620785261
+6.05	0.8980442119576897
+6.1	0.9139302097477254
+6.15	0.9303622713080193
+6.2	0.9473409029130228
+6.25	0.9648661012522547
+6.3	0.9829373534303
+6.35	1.0015536369668097
+6.4	1.0207134197965013
+6.45	1.0404146602691589
+6.5	1.060654807149632
+6.55	1.0814307996178367
+6.6	1.1027390672687578
+6.65	1.1245755301124434
+6.7	1.1469355985740082
+6.75	1.1698141734936365
+6.8	1.1932056461265734
+6.85	1.2171038981431364
+6.9	1.2415023016287001
+6.95	1.266393719083723
+7.0	1.2917705034237088
+7.05	1.317622676238042
+7.1	1.3439344875291743
+7.15	1.3706943818076935
+7.2	1.397889317915369
+7.25	1.4255041476036272
+7.3	1.4535216155335473
+7.35	1.4819223592758635
+7.4	1.5106849093109642
+7.45	1.5397856890288912
+7.5	1.5691990147293418
+7.55	1.598897095621667
+7.6	1.6288500338248735
+7.65	1.65902582436762
+7.7	1.6893903551882197
+7.75	1.7199074071346396
+7.8	1.750538653964509
+7.85	1.7812436623450993
+7.9	1.811979891853345
+7.95	1.8427026949758254
+8.0	1.8733655204108828
+8.05	1.9039160239594701
+8.1	1.9342967182776467
+8.15	1.9644477008356354
+8.2	1.9943066682588895
+8.25	2.023808916328101
+8.3	2.052887339979193
+8.35	2.0814724333033228
+8.4	2.1094922895468846
+8.45	2.136872601111504
+8.5	2.163536659554045
+8.55	2.1894053555866013
+8.6	2.2143971790765016
+8.65	2.238428219046314
+8.7	2.261412163673834
+8.75	2.2832603002920964
+8.8	2.3038815153893686
+8.85	2.323182294609149
+8.9	2.341066722750178
+8.95	2.3574364837664246
+9.0	2.3721908607670934
+9.05	2.385243561217768
+9.1	2.3965045430152734
+9.15	2.4058713525386093
+9.2	2.4132491205087794
+9.25	2.418550561988795
+9.3	2.4216959763836727
+9.35	2.4226132474404354
+9.4	2.4212378432481154
+9.45	2.4175128162377466
+9.5	2.4113888031823736
+9.55	2.402824025197046
+9.6	2.391784287738818
+9.65	2.3782429806067538
+9.7	2.3621810779419197
+9.75	2.3435871382273943
+9.8	2.3224573042882537
+9.85	2.2987953032915898
+9.9	2.2726124467464954
+9.95	2.243936662874981
+10.0	2.212829407232541
+10.05	2.1793603051700225
+10.1	2.1436071549890467
+10.15	2.1056560286629353
+10.2	2.0656012718367216
+10.25	2.023545503827138
+10.3	1.979599617622623
+10.35	1.9338827798833258
+10.4	1.8865224309410902
+10.45	1.837654284799473
+10.5	1.787422329133732
+10.55	1.7359788252908328
+10.6	1.6834843082894428
+10.65	1.6301075868199335
+10.7	1.5760257432443914
+10.75	1.5214241335965866
+10.8	1.4664962694610577
+10.85	1.4114290106972407
+10.9	1.356387562738108
+10.95	1.3015295749133389
+11.0	1.2470056475865348
+11.05	1.192959332155231
+11.1	1.1395271310508905
+11.15	1.0868384977388998
+11.2	1.0350158367185822
+11.25	0.9841745035231799
+11.3	0.9344228047198724
+11.35	0.8858619979097635
+11.4	0.8385862917278817
+11.45	0.7926828458431956
+11.5	0.7482317709585881
+11.55	0.7053061288108814
+11.6	0.6639719321708164
+11.65	0.6242881448430789
+11.7	0.5863066816662622
+11.75	0.5500724085129048
+11.8	0.5156056263225116
+11.85	0.48282765766606284
+11.9	0.45171302639042277
+11.95	0.4222417601688053
+12.0	0.39438628899652595
+12.05	0.3681135967598563
+12.1	0.343382325689573
+12.15	0.3201451306927898
+12.2	0.2983525808812395
+12.25	0.2779531595712671
+12.3	0.25889326428383524
+12.35	0.24111720674452214
+12.4	0.22456721288351963
+12.45	0.2091834228356374
+12.5	0.19490389094029853
+12.55	0.18166458574154323
+12.6	0.16940455502062668
+12.65	0.15806942563057208
+12.7	0.14760164949795584
+12.75	0.1379459766189818
+12.8	0.12904954090635606
+12.85	0.1208618601892862
+12.9	0.11333483621347977
+12.95	0.1064227546411474
+13.0	0.10008352079000768
+13.05	0.09427729770589059
+13.1	0.0889638838519627
+13.15	0.08410564922735048
+13.2	0.07966758244263973
+13.25	0.07561729071987416
+13.3	0.07192499989255652
+13.35	0.0685635544056481
+13.4	0.06550841731556832
+13.45	0.06273767029019571
+13.5	0.060232013608867055
+13.55	0.057974766162377775
+13.6	0.05595154540390375
+13.65	0.054143728150752184
+13.7	0.05253542096143592
+13.75	0.05111376431213964
+13.8	0.049867039237643705
+13.85	0.04878466733132397
+13.9	0.04785721074515173
+13.95	0.04707637218969389
+14.0	0.0464349949341128
+14.05	0.04592701769836477
+14.1	0.045546624739730736
+14.15	0.04528931834969655
+14.2	0.0451514955794
+14.25	0.04513008817380427
+14.3	0.045222562571697905
+14.35	0.04542691990569489
+14.4	0.045741696002234565
+14.45	0.04616596138158169
+14.5	0.04669932125782645
+14.55	0.047341915538884324
+14.6	0.04809441882649631
+14.65	0.04895804028474974
+14.7	0.049934201531537686
+14.75	0.051024408893837934
+14.8	0.05223068084161456
+14.85	0.053555451081960144
+14.9	0.05500156855909605
+14.95	0.05657229745437189
+15.0	0.05827131718626629
+15.05	0.060102892037134606
+15.1	0.062071401418132756
+15.15	0.06418132740669906
+15.2	0.06643757679748863
+15.25	0.06884548110237429
+15.3	0.07141079655044563
+15.35	0.07413970408800942
+15.4	0.07703880937859015
+15.45	0.08011514280292893
+15.5	0.08337615945898437
+15.55	0.08682973916193217
+15.6	0.09048418644416534
+15.65	0.0943482305552937
+15.7	0.09843102546214476
+15.75	0.10274214984876331
+15.8	0.10729171804825564
+15.85	0.11209077099320668
+15.9	0.11715043463737092
+15.95	0.12248247023798853
+16.0	0.12809937089528062
+16.05	0.13401436155244756
+16.1	0.1402413989956704
+16.15	0.1467951718541095
+16.2	0.15369110059990698
+16.25	0.1609453375481831
+16.3	0.1685747668570395
+16.35	0.17659700452755775
+16.4	0.1850304633897114
+16.45	0.19389507162903813
+16.5	0.203211036660323
+16.55	0.2129994958726723
+16.6	0.22328282470582111
+16.65	0.23408463665013263
+16.7	0.24542978324660109
+16.75	0.25734435408684836
+16.8	0.26985567681312306
+16.85	0.2829923171183062
+16.9	0.296784078745904
+16.95	0.31126200349005567
+17.0	0.32645837119552795
+17.05	0.34240669975771165
+17.1	0.35914174512263297
+17.15	0.3766995493478325
+17.2	0.39511866943589197
+17.25	0.41443937956042687
+17.3	0.43470351596029744
+17.35	0.4559547854382354
+17.4	0.47823876536084137
+17.45	0.5016029036585927
+17.5	0.526096518825831
+17.55	0.5517707999207738
+17.6	0.5786788065655069
+17.65	0.6068754689459871
+17.7	0.636417587812045
+17.75	0.6673638344773806
+17.8	0.6997747508195656
+17.85	0.73371274928004
+17.9	0.769242112864117
+17.95	0.8064298972342026
+18.0	0.8453463489441032
+18.05	0.8860624040359305
+18.1	0.9286512860147209
+18.15	0.9731882620893852
+18.2	1.019750141598271
+18.25	1.0684152760091428
+18.3	1.1192635589191955
+18.35	1.1723764260550513
+18.4	1.2278368552727539
+18.45	1.2857284043169876
+18.5	1.3461360686558452
+18.55	1.4091461867043567
+18.6	1.4748444613953504
+18.65	1.5433159537011103
+18.7	1.6146450826333911
+18.75	1.6889156252433966
+18.8	1.7662107166217869
+18.85	1.84661082758134
+18.9	1.9301906122177452
+18.95	2.017031141511684
+19.0	2.107207032862616
+19.05	2.2007850152801156
+19.1	2.2978239293838714
+19.15	2.398374727403668
+19.2	2.5024804731794337
+19.25	2.6101763421611723
+19.3	2.721489621409019
+19.35	2.836439709593211
+19.4	2.955029423231612
+19.45	3.0772339327692704
+19.5	3.203040909069934
+19.55	3.3324137121670563
+19.6	3.465288240102433
+19.65	3.6015729289261924
+19.7	3.741148752696839
+19.75	3.8838692234811942
+19.8	4.029560391354422
+19.85	4.178020844400036
+19.9	4.32902170870988
+19.95	4.482306648384192
+20.0	4.63759182333648
+20.05	4.794549975980078
+20.1	4.952790701757586
+20.15	5.1118775686148
+20.2	5.271328344787154
+20.25	5.430614998799649
+20.3	5.589163699466898
+20.35	5.746354815893115
+20.4	5.901522917472108
+20.45	6.053956773887321
+20.5	6.202899355111759
+20.55	6.347547831408041
+20.6	6.4870535733283985
+20.65	6.620522151714647
+20.7	6.74701333769822
+20.75	6.86554110270015
+20.8	6.975097585988036
+20.85	7.074756987241683
+20.9	7.163345419608717
+20.95	7.23970926129966
+21.0	7.302777059566318
+21.05	7.351559530701824
+21.1	7.385149560040613
+21.15	7.402722201958439
+21.2	7.403534679872361
+21.25	7.38692638624075
+21.3	7.352318882563292
+21.35	7.299215899380968
+21.4	7.227203336276111
+21.45	7.135949261872304
+21.5	7.0252039138344955
+21.55	6.8947996988688915
+21.6	6.74465119272308
+21.65	6.575224674265068
+21.7	6.387733512213771
+21.75	6.183130324475875
+21.8	5.962536763414633
+21.85	5.727249092536215
+21.9	5.478738186489729
+21.95	5.218649531067155
+22.0	4.9488032232034325
+22.05	4.67119397097642
+22.1	4.387991093606868
+22.15	4.10153852145851
+22.2	3.814354796037904
+22.25	3.5290313268426377
+22.3	3.247680664717435
+22.35	2.972145181905108
+22.4	2.7041601490817597
+22.45	2.445353592778882
+22.5	2.197246295383466
+22.55	1.9612517951379447
+22.6	1.7386763861401908
+22.65	1.5307191183435263
+22.7	1.3384717975567046
+22.75	1.1626849204819312
+22.8	1.002545500134845
+22.85	0.8580447324446863
+22.9	0.7290660272529799
+22.95	0.6151088661398314
+23.0	0.5152888024239812
+23.05	0.42833746116276894
+23.1	0.3528700782886587
+23.15	0.2882686813177955
+23.2	0.23370795504952072
+23.25	0.18814604167223206
+23.3	0.1503427800968641
+23.35	0.11902811119375684
+23.4	0.0934657882479476
+23.45	0.07288677871920558
+23.5	0.056463495964596405
+23.55	0.04336624786792155
+23.6	0.033048393231938075
+23.65	0.025044693612199312
+23.7	0.018879588046535827
+23.75	0.014117205126969798
+23.8	0.010494168668146536
+23.85	0.007771952430576503
+23.9	0.00572659237794399
+23.95	0.004196649465176822
+24.0	0.0030682684605941022
+24.05	0.0022382433124819743
+24.1	0.0016268421273769097
+24.15	0.0011818193528510545
+24.2	0.0008587195343808228
+24.25	0.0006233612455862654
+24.3	0.0004533600206809531
+24.35	0.0003306483306907093
+24.4	0.00024162783734463168
+24.45	0.00017738064466219696
+24.5	0.0001309525689239069
+24.55	9.716850425920782e-5
+24.6	7.26240966192432e-5
+24.65	5.47478818928715e-5
+24.7	4.1621424946029056e-5
+24.75	3.1947149008192264e-5
+24.8	2.47950842119474e-5
+24.85	1.9461844891531977e-5
+24.9	1.545870390331348e-5
+24.95	1.2439828811011545e-5
+25.0	1.014770216875457e-5
+25.05	8.387099067448246e-6
+25.1	7.023043078151283e-6
+25.15	5.9697343991657785e-6
+25.2	5.149996286641158e-6
+25.25	4.511792414769304e-6
+25.3	4.012664558191549e-6
+25.35	3.6184372626452725e-6
+25.4	3.303217844963916e-6
+25.45	3.049396393076911e-6
+25.5	2.847645766009793e-6
+25.55	2.69587554230337e-6
+25.6	2.5835641587810274e-6
+25.65	2.5028908244649546e-6
+25.7	2.4494096105647085e-6
+25.75	2.419391019410189e-6
+25.8	2.4098219844516108e-6
+25.85	2.418405870259513e-6
+25.9	2.443562472524755e-6
+25.95	2.4839527078495604e-6
+26.0	2.5384142814260535e-6
+26.05	2.606273821426257e-6
+26.1	2.6869379843119175e-6
+26.15	2.7798933189022552e-6
+26.2	2.8847062663740076e-6
+26.25	3.001023160261364e-6
+26.3	3.128570226456017e-6
+26.35	3.2671535832071314e-6
+26.4	3.4166592411213753e-6
+26.45	3.57705310316289e-6
+26.5	3.7483809646533127e-6
+26.55	3.930768513271749e-6
+26.6	4.124421329054799e-6
+26.65	4.32960786038784e-6
+26.7	4.546517541656452e-6
+26.75	4.775505565572347e-6
+26.8	5.017017532787742e-6
+26.85	5.271549794558421e-6
+26.9	5.539649452743732e-6
+26.95	5.821914359806641e-6
+27.0	6.118993118813638e-6
+27.05	6.431585083434791e-6
+27.1	6.760440357943753e-6
+27.15	7.106359797217719e-6
+27.2	7.470195006737535e-6
+27.25	7.852848342587543e-6
+27.3	8.255272911455645e-6
+27.35	8.678472570633381e-6
+27.4	9.123501928015798e-6
+27.45	9.591466342101632e-6
+27.5	1.008352192199302e-5
+27.55	1.060087552739576e-5
+27.6	1.1144784768619385e-5
+27.65	1.1716558006576588e-5
+27.7	1.2317554352783923e-5
+27.75	1.2949184065999924e-5
+27.8	1.3613223077415422e-5
+27.85	1.4311399433303857e-5
+27.9	1.504526308691353e-5
+27.95	1.581650858713442e-5
+28.0	1.6626975078497833e-5
+28.05	1.7478646301176728e-5
+28.1	1.837365059098551e-5
+28.15	1.9314260879380068e-5
+28.2	2.0302894693457978e-5
+28.25	2.1342114155958202e-5
+28.3	2.243462598526123e-5
+28.35	2.358328149538912e-5
+28.4	2.4791076596005323e-5
+28.45	2.6061151792415053e-5
+28.5	2.739679218556477e-5
+28.55	2.880142747204276e-5
+28.6	3.0278631944078403e-5
+28.65	3.18321244895429e-5
+28.7	3.346576859194911e-5
+28.75	3.518357233045125e-5
+28.8	3.6989688379844895e-5
+28.85	3.8888414010567146e-5
+28.9	4.0884191088696255e-5
+28.95	4.298160607595395e-5
+29.0	4.518565941340856e-5
+29.05	4.750323389336468e-5
+29.1	4.993968133981481e-5
+29.15	5.2500387459309735e-5
+29.2	5.519125036977669e-5
+29.25	5.801868060051815e-5
+29.3	6.098960109221252e-5
+29.35	6.411144719691421e-5
+29.4	6.739216667805312e-5
+29.45	7.084021971043556e-5
+29.5	7.446457888024339e-5
+29.55	7.827472918503394e-5
+29.6	8.228066803374081e-5
+29.65	8.64929052466731e-5
+29.7	9.092246305551653e-5
+29.75	9.558087610333142e-5
+29.8	0.00010048019144455492
+29.85	0.00010563296854499985
+29.9	0.00011105227928185373
+29.95	0.00011675170794368223
+30.0	0.0001227453512304237
+30.05	0.00012904781825339625
+30.1	0.00013567423053529063
+30.15	0.0001426402220101734
+30.2	0.00014996193902348996
+30.25	0.00015765604033205773
+30.3	0.00016574058516240938
+30.35	0.00017424111370541984
+30.4	0.0001831776763267837
+30.45	0.00019257013950133904
+30.5	0.00020244023055680148
+30.55	0.00021281153767376562
+30.6	0.00022370950988570654
+30.65	0.0002351614570789764
+30.7	0.0002471965499928098
+30.75	0.00025984582021931773
+30.8	0.0002731421602034928
+30.85	0.000287120323243203
+30.9	0.00030181692348919944
+30.95	0.00031727043594511185
+31.0	0.00033352119646744673
+31.05	0.00035061140176559293
+31.1	0.00036858510940181437
+31.15	0.0003874882377912571
+31.2	0.0004073685662019494
+31.25	0.0004282757347547921
+31.3	0.0004502612444235679
+31.35	0.0004733784570349454
+31.4	0.0004976825952684588
+31.45	0.0005232307426565268
+31.5	0.0005500818435844629
+31.55	0.0005782967032904355
+31.6	0.0006079503852070661
+31.65	0.0006391300928337179
+31.7	0.0006719048829063695
+31.75	0.0007063497942577817
+31.8	0.0007425463982876225
+31.85	0.0007805827989624643
+31.9	0.0008205536328157828
+31.95	0.0008625600689479657
+32.0	0.0009067098090262956
+32.05	0.0009531170872849661
+32.1	0.0010019026705250823
+32.15	0.0010531938581146301
+32.2	0.0011071244819885395
+32.25	0.0011638349066485994
+32.3	0.001223472029163538
+32.35	0.0012861892791689834
+32.4	0.0013521466188674495
+32.45	0.001421510543028392
+32.5	0.0014944540789881082
+32.55	0.0015711567866498796
+32.6	0.0016518047584838263
+32.65	0.0017365906195270112
+32.7	0.0018257135273834148
+32.75	0.0019193791722238486
+32.8	0.002017800891978172
+32.85	0.0021212672343752857
+32.9	0.002230045548369633
+32.95	0.0023443806222218783
+33.0	0.0024645392476519467
+33.05	0.002590810219839141
+33.1	0.002723504337422094
+33.15	0.002862954402498706
+33.2	0.003009515220626284
+33.25	0.003163563600821381
+33.3	0.0033254983555599292
+33.35	0.003495740300777223
+33.4	0.003674732255867753
+33.45	0.0038629390436855106
+33.5	0.004060847490543645
+33.55	0.004268966426214731
+33.6	0.004487826683930703
+33.65	0.004717981100382675
+33.7	0.004960004515721289
+33.75	0.005214493773556361
+33.8	0.005482067720957016
+33.85	0.005763367208451878
+33.9	0.006059055090028736
+33.95	0.0063698162231347515
+34.0	0.006696362113585118
+34.05	0.00703964166167945
+34.1	0.007400529648686494
+34.15	0.007779842808177503
+34.2	0.008178469684576748
+34.25	0.008597370633161057
+34.3	0.009037577820060243
+34.35	0.009500195222256945
+34.4	0.009986398627586372
+34.45	0.010497435634736808
+34.5	0.011034625653249098
+34.55	0.01159935990351698
+34.6	0.012193101416787191
+34.65	0.012817385035158838
+34.7	0.013473817411584306
+34.75	0.01416407700986826
+34.8	0.014889914104668726
+34.85	0.015653150781496157
+34.9	0.016455680936713825
+34.95	0.017299470277538028
+35.0	0.01818655632203741
+35.05	0.019119048399133965
+35.1	0.020099127648602316
+35.15	0.021129047021069634
+35.2	0.02221127816980577
+35.25	0.02334899952533738
+35.3	0.0245448536781086
+35.35	0.025801588371398444
+35.4	0.027122178838323182
+35.45	0.028509827801837626
+35.5	0.029967965474733606
+35.55	0.0315002495596414
+35.6	0.03311056524902905
+35.65	0.03480302522520191
+35.7	0.03658196966030365
+35.75	0.038451966216315185
+35.8	0.040417810045055846
+35.85	0.042484523788182765
+35.9	0.04465735757719016
+35.95	0.04694178903341079
+36.0	0.049343523268014736
+36.05	0.05186849288201033
+36.1	0.05452285796624317
+36.15	0.057313006101397376
+36.2	0.06024555235799432
+36.25	0.06332733929639267
+36.3	0.06656543696679054
+36.35	0.06996714922746461
+36.4	0.07354161492253536
+36.45	0.07729797038358993
+36.5	0.0812447315842601
+36.55	0.08539109106659043
+36.6	0.08974691794103712
+36.65	0.09432275788646544
+36.7	0.09912983315015493
+36.75	0.1041800425477933
+36.8	0.10948596146348222
+36.85	0.11506084184973359
+36.9	0.12091861222746997
+36.95	0.12707387768602704
+37.0	0.13354191988314948
+37.05	0.14033869704499438
+37.1	0.14748084396613156
+37.15	0.1549856720095387
+37.2	0.1628711691066082
+37.25	0.1711559997571415
+37.3	0.179859505029352
+37.35	0.18900170255986756
+37.4	0.19860328655372045
+37.45	0.20868562778436117
+37.5	0.21927077359364575
+37.55	0.2303825623341242
+37.6	0.2420484518884242
+37.65	0.2542949883150577
+37.7	0.2671501517028943
+37.75	0.28064354236051037
+37.8	0.29480638081620114
+37.85	0.30967150781797603
+37.9	0.3252733843335506
+37.95	0.34164809155036563
+38.0	0.35883333087556424
+38.05	0.3768684239360106
+38.1	0.39579431257828335
+38.15	0.4156535588686653
+38.2	0.4364903450931678
+38.25	0.45835047375749827
+38.3	0.4812813675870958
+38.35	0.5053320695271012
+38.4	0.5305532427423685
+38.45	0.5569971706174758
+38.5	0.5847177567567053
+38.55	0.6137705249840516
+38.6	0.6442126193432424
+38.65	0.6761028040976818
+38.7	0.709499798733684
+38.75	0.7444528628440026
+38.8	0.7810352776827908
+38.85	0.8193266781646816
+38.9	0.859407153064521
+38.95	0.9013572450174041
+39.0	0.9452579505186354
+39.05	0.9911907199237645
+39.1	1.0392374574485708
+39.15	1.0894805211690428
+39.2	1.1420027230214302
+39.25	1.1968873288021742
+39.3	1.2542180581679718
+39.35	1.3140790846357546
+39.4	1.3765550355826464
+39.45	1.4417309922460484
+39.5	1.5096924897235524
+39.55	1.5805255169729964
+39.6	1.6543165168124576
+39.65	1.7311523859202103
+39.7	1.8111204748347953
+39.75	1.8943085879549524
+39.8	1.9808049835396564
+39.85	2.070698373708153
+39.9	2.163983109791232
+39.95	2.2606853338691906
+40.0	2.360938720294395
+40.05	2.4648534690798374
+40.1	2.5725162507935737
+40.15	2.683990206558674
+40.2	2.7993149480533326
+40.25	2.918506557510744
+40.3	3.0415575877192116
+40.35	3.1684370620220994
+40.4	3.299090474317784
+40.45	3.4334397890597734
+40.5	3.5713834412565633
+40.55	3.712796336471769
+40.6	3.8575298508240716
+40.65	4.0054118309871445
+40.7	4.156246594189814
+40.75	4.309814928215874
+40.8	4.465874091404255
+40.85	4.6241578126489316
+40.9	4.784376291398901
+40.95	4.946216197658313
+41.0	5.109340671986252
+41.05	5.273404182823946
+41.1	5.43813381246399
+41.15	5.602840033479278
+41.2	5.766713870720248
+41.25	5.928933535352453
+41.3	6.088664424856716
+41.35	6.245059123029065
+41.4	6.397257399980665
+41.45	6.544386212137977
+41.5	6.685559702242575
+41.55	6.8198791993513055
+41.6	6.946433218836213
+41.65	7.064297462384493
+41.7	7.172534817998619
+41.75	7.2701953599961975
+41.8	7.356316349010094
+41.85	7.429922231988369
+41.9	7.490024642194255
+41.95	7.5356223992062175
+42.0	7.565701508917923
+42.05	7.579235163538244
+42.1	7.575183741591243
+42.15	7.552494807916206
+42.2	7.5101031136676
+42.25	7.446960211990248
+42.3	7.362653511188749
+42.35	7.257378130597485
+42.4	7.131478200218654
+42.45	6.985446256627586
+42.5	6.819923242972875
+42.55	6.635698508976241
+42.6	6.433709810932609
+42.65	6.215043311710154
+42.7	5.98093358075015
+42.75	5.732763594067173
+42.8	5.472064734248904
+42.85	5.200516790456206
+42.9	4.919947958423249
+42.95	4.632334840457234
+43.0	4.339802445438737
+43.05	4.044624188821392
+43.1	3.749221892632005
+43.15	3.4561794196288083
+43.2	3.1684159295351932
+43.25	2.888051061812437
+43.3	2.616777437601601
+43.35	2.356145147941265
+43.4	2.107561753767621
+43.45	1.872292285914226
+43.5	1.6514592451122792
+43.55	1.446042601990393
+43.6	1.25687979707469
+43.65	1.084665740788878
+43.7	0.929626080669022
+43.75	0.7904892915534079
+43.8	0.6669630322147689
+43.85	0.5585713669155822
+43.9	0.46444522877890587
+43.95	0.3833224197882812
+44.0	0.31357782352246616
+44.05	0.2541601641675317
+44.1	0.20432561528432963
+44.15	0.1630559586008962
+44.2	0.12916402730886625
+44.25	0.10137107882488659
+44.3	0.07886507863046813
+44.35	0.06091603556284321
+44.4	0.04673787065327124
+44.45	0.03554801391736649
+44.5	0.02682478669927672
+44.55	0.0201299600180282
+44.6	0.015021743551476268
+44.65	0.011118602685664251
+44.7	0.008185699330205028
+44.75	0.006004595922205878
+44.8	0.004378886320825617
+44.85	0.003179006285911467
+44.9	0.0023036198045184044
+44.95	0.001663712986568561
+45.0	0.0011983802002857918
+45.05	0.0008633333324131829
+45.1	0.0006214365576144882
+45.15	0.0004473583556963225
+45.2	0.000322925952415157
+45.25	0.00023356670130283947
+45.3	0.00016948971525873436
+45.35	0.0001236970332192665
+45.4	9.075149291610927e-5
+45.45	6.700338161970473e-5
+45.5	4.989738095078384e-5
+45.55	3.748216280361419e-5
+45.6	2.8409702490601546e-5
+45.65	2.177101260041849e-5
+45.7	1.687692531558942e-5
+45.75	1.3236778688421806e-5
+45.8	1.0515831989377071e-5
+45.85	8.470422021611415e-6
+45.9	6.9156793082844705e-6
+45.95	5.720431452806873e-6
+46.0	4.799861531694991e-6
+46.05	4.071769606983338e-6
+46.1	3.499905323891162e-6
+46.15	3.0582739139824986e-6
+46.2	2.723401134855897e-6
+46.25	2.474333270144931e-6
+46.3	2.292637129517839e-6
+46.35	2.16240004867769e-6
+46.4	2.0702298893624183e-6
+46.45	2.0052550393446896e-6
+46.5	1.9591244124320165e-6
+46.55	1.9260074484667103e-6
+46.6	1.9025941133258457e-6
+46.65	1.8883429362876368e-6
+46.7	1.8877296320074688e-6
+46.75	1.9010805379502938e-6
+46.8	1.9271814081456527e-6
+46.85	1.965004977494437e-6
+46.9	2.0137109617688604e-6
+46.95	2.072646057612502e-6
+47.0	2.1413439425402585e-6
+47.05	2.2194306520794387e-6
+47.1	2.306425913502996e-6
+47.15	2.40217065928247e-6
+47.2	2.506573645865822e-6
+47.25	2.619586373274399e-6
+47.3	2.7412030851030265e-6
+47.35	2.8714607685199803e-6
+47.4	3.010439154266917e-6
+47.45	3.158260716659016e-6
+47.5	3.315090673584805e-6
+47.55	3.481136986506317e-6
+47.6	3.6566503604590372e-6
+47.65	3.841924244051815e-6
+47.7	4.037294829467039e-6
+47.75	4.243141052460408e-6
+47.8	4.459884592361183e-6
+47.85	4.687989872072064e-6
+47.9	4.927964058069078e-6
+47.95	5.180357060401818e-6
+48.0	5.4457615326931705e-6
+48.05	5.724812872139598e-6
+48.1	6.0181892195110635e-6
+48.15	6.326611459150686e-6
+48.2	6.650900668532434e-6
+48.25	6.991902988490036e-6
+48.3	7.350371648743385e-6
+48.35	7.727128171109574e-6
+48.4	8.12306528188901e-6
+48.45	8.539146911865828e-6
+48.5	8.976408196307391e-6
+48.55	9.43595547496474e-6
+48.6	9.918966292072427e-6
+48.65	1.0426689396348247e-5
+48.7	1.0960444740993775e-5
+48.75	1.1521623483693771e-5
+48.8	1.211168798661677e-5
+48.85	1.2732171816414726e-5
+48.9	1.3384679744222858e-5
+48.95	1.4070887745660146e-5
+49.0	1.4792543000828803e-5
+49.05	1.5551463894314893e-5
+49.1	1.6349540015187564e-5
+49.15	1.7188732156999628e-5
+49.2	1.8071072317787616e-5
+49.25	1.8998663700070797e-5
+49.3	1.9973680710852777e-5
+49.35	2.0998368961620643e-5
+49.4	2.2075105909974988e-5
+49.45	2.3206955462057086e-5
+49.5	2.4396816795292184e-5
+49.55	2.5647631879502024e-5
+49.6	2.6962524475982816e-5
+49.65	2.8344800137504537e-5
+49.7	2.9797946208312406e-5
+49.75	3.13256318241251e-5
+49.8	3.293170791213642e-5
+49.85	3.4620207191014525e-5
+49.9	3.639534417090125e-5
+49.95	3.826151515341381e-5
+50.0	4.022329823164275e-5
diff --git a/test/julia_dde/test_basic_check_5.txt b/test/julia_dde/test_basic_check_5.txt
new file mode 100644
index 00000000..db42a42d
--- /dev/null
+++ b/test/julia_dde/test_basic_check_5.txt
@@ -0,0 +1,1001 @@
+0.0	1.2
+0.05	1.1880597992376376
+0.1	1.1762383903532672
+0.15	1.1645346054176793
+0.2	1.1529472850205345
+0.25	1.1414752222367555
+0.3	1.1301172850650132
+0.35	1.1188723699778464
+0.4	1.1077393734477934
+0.45	1.0967171444550854
+0.5	1.0858044886142186
+0.55	1.0750003653428541
+0.6	1.0643037416988483
+0.65	1.0537135847400563
+0.7	1.0432288615243335
+0.75	1.0328485303971229
+0.8	1.022571356528724
+0.85	1.0123962869380496
+0.9	1.0023223752836392
+0.95	0.9923486752240321
+1.0	0.9824742404177691
+1.05	0.9726981245233894
+1.1	0.963019381199433
+1.15	0.9534370641044397
+1.2	0.9439501263044995
+1.25	0.9345573120423534
+1.3	0.9252577408761488
+1.35	0.9160505652198292
+1.4	0.9069349374873376
+1.45	0.8979100100926173
+1.5	0.8889749354496113
+1.55	0.8801288659722627
+1.6	0.8713709540745148
+1.65	0.8627003521703108
+1.7	0.8541162126735939
+1.75	0.8456174480352269
+1.8	0.8372028671950351
+1.85	0.8288717208428457
+1.9	0.8206232664837655
+1.95	0.8124567616229019
+2.0	0.8043714637653621
+2.05	0.7963666304162529
+2.1	0.7884415190806816
+2.15	0.7805953872637555
+2.2	0.7728274924705816
+2.25	0.7651370922062669
+2.3	0.757523443975919
+2.35	0.7499858052846445
+2.4	0.7425231000475476
+2.45	0.7351341865132892
+2.5	0.7278184247305727
+2.55	0.7205751779086018
+2.6	0.7134038092565801
+2.65	0.7063036819837109
+2.7	0.6992741592991979
+2.75	0.6923146044122449
+2.8	0.6854243805320552
+2.85	0.6786028508678325
+2.9	0.6718493786287804
+2.95	0.6651633270241023
+3.0	0.658544059263002
+3.05	0.6519909385546828
+3.1	0.6455033281083487
+3.15	0.6390802274436268
+3.2	0.6327205188637759
+3.25	0.6264236614309651
+3.3	0.620189119515113
+3.35	0.6140163574861381
+3.4	0.6079048397139594
+3.45	0.6018540305684947
+3.5	0.595863394419663
+3.55	0.589932395637383
+3.6	0.5840604985915729
+3.65	0.5782471676521516
+3.7	0.5724918671890373
+3.75	0.566794061572149
+3.8	0.5611532151714047
+3.85	0.5555687923567235
+3.9	0.5500402574980238
+3.95	0.5445670749652239
+4.0	0.5391485415659062
+4.05	0.533941825286989
+4.1	0.5290979030935984
+4.15	0.5246066587037986
+4.2	0.5204579758356546
+4.25	0.5166417382072302
+4.3	0.5131478295365905
+4.35	0.5099661335417995
+4.4	0.5070865425485469
+4.45	0.5045013413138684
+4.5	0.50220474475892
+4.55	0.5001901862647773
+4.6	0.4984507696525758
+4.65	0.4969812968856559
+4.7	0.4957776883424407
+4.75	0.49483586440135285
+4.8	0.4941517454408154
+4.85	0.49372125183925114
+4.9	0.4935403082579171
+4.95	0.4936057326752097
+5.0	0.4939159902049112
+5.05	0.49446967873339276
+5.1	0.495265396147025
+5.15	0.49630174033217894
+5.2	0.49757730917522547
+5.25	0.4990903623716535
+5.3	0.500840165913952
+5.35	0.5028281061034408
+5.4	0.5050556308997887
+5.45	0.5075241882626647
+5.5	0.5102352261517374
+5.55	0.513190192526676
+5.6	0.516390535347149
+5.65	0.5198377025728254
+5.7	0.5235331421633739
+5.75	0.5274782732641654
+5.8	0.5316753705349432
+5.85	0.5361289906293275
+5.9	0.540843962783196
+5.95	0.545825116232426
+6.0	0.5510772802128953
+6.05	0.5566052839604811
+6.1	0.5624139567110612
+6.15	0.5685081277005132
+6.2	0.5748926261647143
+6.25	0.5815722278711621
+6.3	0.5885516376851783
+6.35	0.5958386412284099
+6.4	0.6034418101727644
+6.45	0.6113697161901499
+6.5	0.6196309309524738
+6.55	0.6282340261316438
+6.6	0.6371875733995677
+6.65	0.6465001444281533
+6.7	0.656180310889308
+6.75	0.6662355248448237
+6.8	0.6766749685128675
+6.85	0.6875119127809535
+6.9	0.6987596953139646
+6.95	0.7104316537767829
+7.0	0.7225411258342908
+7.05	0.7351014491513711
+7.1	0.7481259613929059
+7.15	0.761628000223778
+7.2	0.7756206149741648
+7.25	0.7901141516640606
+7.3	0.8051262630768021
+7.35	0.820676871069783
+7.4	0.8367858975003976
+7.45	0.8534732642260386
+7.5	0.8707588931041004
+7.55	0.8886627059919765
+7.6	0.9072046247470607
+7.65	0.9264045642825595
+7.7	0.9462803312576148
+7.75	0.9668554155299864
+7.8	0.9881574048590489
+7.85	1.010213887004177
+7.9	1.0330524497247455
+7.95	1.0567006807801287
+8.0	1.0811861679297008
+8.05	1.1065310334843907
+8.1	1.1327491923150275
+8.15	1.1598559210208417
+8.2	1.1878664962010594
+8.25	1.2167961944549117
+8.3	1.2466602923816255
+8.35	1.2774740665804287
+8.4	1.309252793650552
+8.45	1.3420070425387494
+8.5	1.3757421114772663
+8.55	1.4104712624753173
+8.6	1.4462079661311253
+8.65	1.482965693042917
+8.7	1.520757913808914
+8.75	1.5595980990273437
+8.8	1.5994997192964286
+8.85	1.640476245214392
+8.9	1.6825399326817119
+8.95	1.7256890050713114
+9.0	1.769923239018476
+9.05	1.8152445536479256
+9.1	1.8616548680843785
+9.15	1.9091561014525582
+9.2	1.957750172877182
+9.25	2.0074390014829735
+9.3	2.0582245063946507
+9.35	2.110108606736934
+9.4	2.1630934961721984
+9.45	2.2171858103843394
+9.5	2.2723618450389838
+9.55	2.3285863608705677
+9.6	2.3858241186135283
+9.65	2.44403987900231
+9.7	2.5031984027713463
+9.75	2.563264450655082
+9.8	2.624202783387952
+9.85	2.685978161704395
+9.9	2.7485553463388537
+9.95	2.8118990980257634
+10.0	2.8759741774995664
+10.05	2.9407453454947006
+10.1	3.0061773627456025
+10.15	3.0722169760776383
+10.2	3.138751761019326
+10.25	3.205682404718919
+10.3	3.272911124072332
+10.35	3.340340135975479
+10.4	3.4078716573242818
+10.45	3.475407905014652
+10.5	3.54285109594251
+10.55	3.61010529461556
+10.6	3.677064458301878
+10.65	3.7435742366658236
+10.7	3.8094743727175246
+10.75	3.8746046094671227
+10.8	3.9388046899247495
+10.85	4.001914357100539
+10.9	4.0637733540046295
+10.95	4.124221423647152
+11.0	4.183098242224349
+11.05	4.240228435912087
+11.1	4.295404399767815
+11.15	4.348414550926566
+11.2	4.399047306523364
+11.25	4.447091083693243
+11.3	4.492334299571228
+11.35	4.534565371292346
+11.4	4.573572522527261
+11.45	4.609133847302006
+11.5	4.641028606338098
+11.55	4.669040645161524
+11.6	4.692953809298272
+11.65	4.7125519442743355
+11.7	4.727618895615703
+11.75	4.737938508848365
+11.8	4.743294935484707
+11.85	4.743497907544847
+11.9	4.738383718400927
+11.95	4.727786559379314
+12.0	4.711540621806375
+12.05	4.689464016482542
+12.1	4.661453493741227
+12.15	4.6274612042082595
+12.2	4.587439298509476
+12.25	4.541339927270708
+12.3	4.489115241117789
+12.35	4.4307172613748325
+12.4	4.366152947696486
+12.45	4.295579883345401
+12.5	4.219179589544553
+12.55	4.137133587516926
+12.6	4.049623398485506
+12.65	3.956830543673269
+12.7	3.8589365443032033
+12.75	3.756145319766686
+12.8	3.648869729310099
+12.85	3.5375251983771396
+12.9	3.4225079622443286
+12.95	3.3042142561882004
+13.0	3.1830403154852744
+13.05	3.0593960993312876
+13.1	2.933816013450895
+13.15	2.8068086805267716
+13.2	2.678865716643506
+13.25	2.55047873788567
+13.3	2.422139360337847
+13.35	2.294343383418942
+13.4	2.167580602916567
+13.45	2.042318269765592
+13.5	1.919022785288025
+13.55	1.7981605508058898
+13.6	1.6801719372487585
+13.65	1.565335718273593
+13.7	1.4539913178733794
+13.75	1.3464966609101616
+13.8	1.2432096722459953
+13.85	1.1444242451454565
+13.9	1.0500847213972058
+13.95	0.9603510880513052
+14.0	0.8754326000622558
+14.05	0.7955385123845679
+14.1	0.720696469775138
+14.15	0.6504975497817775
+14.2	0.5849952624191386
+14.25	0.5242811611033897
+14.3	0.4684467992507062
+14.35	0.41726535797670977
+14.4	0.3700374540598733
+14.45	0.3267654592649285
+14.5	0.2874842428742905
+14.55	0.252228674170379
+14.6	0.22084391633542563
+14.65	0.19231099117816666
+14.7	0.16650333730367056
+14.75	0.14345684638823014
+14.8	0.12320741010814093
+14.85	0.10579092013969796
+14.9	0.09078324196296068
+14.95	0.07743617399287633
+15.0	0.06572024904646447
+15.05	0.05561660389438912
+15.1	0.047103029194414746
+15.15	0.039766204265192555
+15.2	0.033325287757735945
+15.25	0.027773907099514918
+15.3	0.023105689718000144
+15.35	0.019305981367019328
+15.4	0.016073857939862792
+15.45	0.01326114603654852
+15.5	0.010868315626638229
+15.55	0.008895836679693946
+15.6	0.007343500058411132
+15.65	0.006096238394876711
+15.7	0.0050542978188134816
+15.75	0.0041923015533235535
+15.8	0.003476097092756846
+15.85	0.0028813782437393854
+15.9	0.0023888728728829003
+15.95	0.0019812266106934537
+16.0	0.0016446691307760468
+16.05	0.0013667252933375887
+16.1	0.001137223143648161
+16.15	0.0009476667665580653
+16.2	0.0007911660128428781
+16.25	0.0006619157697303199
+16.3	0.0005551482755670173
+16.35	0.00046686305999317653
+16.4	0.0003938118906203424
+16.45	0.00033333752281953003
+16.5	0.00028314314129438897
+16.55	0.0002414949526376132
+16.6	0.00020681097760988972
+16.65	0.00017793754480651896
+16.7	0.00015379407553580107
+16.75	0.00013362671622470384
+16.8	0.00011668323770843165
+16.85	0.00010247121104970752
+16.9	9.048881425541913e-5
+16.95	8.037068610024978e-5
+17.0	7.183231507164557e-5
+17.05	6.457846620963938e-5
+17.1	5.8423691774176164e-5
+17.15	5.320333154527555e-5
+17.2	4.8749047466969784e-5
+17.25	4.495521420840009e-5
+17.3	4.173221077606829e-5
+17.35	3.8989124277101314e-5
+17.4	3.665353887474203e-5
+17.45	3.467921548110942e-5
+17.5	3.302089298979895e-5
+17.55	3.1633462324064184e-5
+17.6	3.0481783106832075e-5
+17.65	2.9544896465280214e-5
+17.7	2.8802842987084214e-5
+17.75	2.823566325991989e-5
+17.8	2.782362227232167e-5
+17.85	2.7553331975005054e-5
+17.9	2.7418045539681683e-5
+17.95	2.741127837896782e-5
+18.0	2.7526545905479753e-5
+18.05	2.775736353183376e-5
+18.1	2.8097246619573764e-5
+18.15	2.8540884873025393e-5
+18.2	2.9087698313496896e-5
+18.25	2.9738289467161208e-5
+18.3	3.0493260860191267e-5
+18.35	3.135321501876001e-5
+18.4	3.2318754469040314e-5
+18.45	3.3390481737205245e-5
+18.5	3.45689993494277e-5
+18.55	3.585296333408846e-5
+18.6	3.723333571393522e-5
+18.65	3.871384458144151e-5
+18.7	4.030036358771365e-5
+18.75	4.1998766383857675e-5
+18.8	4.3814926620979705e-5
+18.85	4.575471795018588e-5
+18.9	4.782401402258214e-5
+18.95	5.0028688489274946e-5
+19.0	5.2374615001370255e-5
+19.05	5.4867667209974195e-5
+19.1	5.7513718766192896e-5
+19.15	6.031864332113228e-5
+19.2	6.328831452589888e-5
+19.25	6.64286060315986e-5
+19.3	6.974508877438047e-5
+19.35	7.323652707221297e-5
+19.4	7.69101582178799e-5
+19.45	8.077772518257386e-5
+19.5	8.485097093748668e-5
+19.55	8.914163845381041e-5
+19.6	9.366147070273715e-5
+19.65	9.842221065545859e-5
+19.7	0.00010343560128316748
+19.75	0.00010871122321593115
+19.8	0.00011424243465606984
+19.85	0.00012004376195416888
+19.9	0.00012613565222541643
+19.95	0.00013253855258500192
+20.0	0.00013927291014811352
+20.05	0.0001463591720299398
+20.1	0.0001538177853456694
+20.15	0.00016166919721049034
+20.2	0.00016993385473959227
+20.25	0.00017863220504816328
+20.3	0.00018778113899536105
+20.35	0.00019737873541757228
+20.4	0.0002074521753784754
+20.45	0.00021803485355739299
+20.5	0.0002291601646336456
+20.55	0.00024086150328655437
+20.6	0.00025317226419544056
+20.65	0.0002661258420396244
+20.7	0.00027975563149842904
+20.75	0.0002940922218916396
+20.8	0.00030913370525589325
+20.85	0.00032491782563353656
+20.9	0.0003414973154406396
+20.95	0.0003589249070932758
+21.0	0.00037725333300751525
+21.05	0.0003965353255994291
+21.1	0.00041682361728508847
+21.15	0.00043817094048056304
+21.2	0.0004606289870918508
+21.25	0.00048420720969698033
+21.3	0.0005089558956894124
+21.35	0.0005349558824716729
+21.4	0.0005622880074462859
+21.45	0.0005910331080157811
+21.5	0.0006212720215826827
+21.55	0.0006530855855495162
+21.6	0.000686554637318808
+21.65	0.0007217244440419502
+21.7	0.0007586369140928553
+21.75	0.0007974114970944676
+21.8	0.0008381688900945253
+21.85	0.0008810297901407669
+21.9	0.0009261148942809271
+21.95	0.0009735448995627507
+22.0	0.0010234405030339728
+22.05	0.0010758674103834008
+22.1	0.0011309082234625218
+22.15	0.001188738483430932
+22.2	0.0012495341070863388
+22.25	0.0013134710112264387
+22.3	0.0013807251126489322
+22.35	0.0014514723281515186
+22.4	0.0015258645492803558
+22.45	0.0016039735836754021
+22.5	0.0016860273377255132
+22.55	0.0017722770527852946
+22.6	0.0018629739702093525
+22.65	0.001958369331352286
+22.7	0.0020587143775687147
+22.75	0.0021642341928720787
+22.8	0.002275043297578281
+22.85	0.00239145778948531
+22.9	0.0025138302786009356
+22.95	0.0026425133749329545
+23.0	0.0027778596884891356
+23.05	0.0029202218292772577
+23.1	0.0030698672976881033
+23.15	0.0032270262551232975
+23.2	0.003392177892495248
+23.25	0.0035658100112012153
+23.3	0.0037484104126384687
+23.35	0.003940466898204278
+23.4	0.004142445044319549
+23.45	0.004354601692589874
+23.5	0.004577503910252191
+23.55	0.004811818459224796
+23.6	0.005058212101425987
+23.65	0.005317351598774041
+23.7	0.005589899287648707
+23.75	0.005876252279009922
+23.8	0.00617708605838838
+23.85	0.0064933030482079695
+23.9	0.0068258056708925575
+23.95	0.007175496348866081
+24.0	0.007543273442447884
+24.05	0.007929700826037223
+24.1	0.008335682764424885
+24.15	0.008762430679453235
+24.2	0.009211155992964739
+24.25	0.009683070126801762
+24.3	0.010179359311351058
+24.35	0.010700759853161369
+24.4	0.011248606748867872
+24.45	0.01182451534257886
+24.5	0.012430100978402499
+24.55	0.013066979000447005
+24.6	0.013736632347962078
+24.65	0.014440093184907035
+24.7	0.015179353711805838
+24.75	0.015956555916169984
+24.8	0.01677384178551103
+24.85	0.017633353307340515
+24.9	0.018536759268335617
+24.95	0.019485827543527734
+25.0	0.02048337239871802
+25.05	0.02153222107898219
+25.1	0.022635200829395945
+25.15	0.023795049919238722
+25.2	0.025013578259642445
+25.25	0.026294018030455907
+25.3	0.02764008669994533
+25.35	0.029055501736376957
+25.4	0.030543980608016903
+25.45	0.032108513400120824
+25.5	0.03375212946457666
+25.55	0.03547966677107697
+25.6	0.037295997416007365
+25.65	0.03920599349575325
+25.7	0.04121428515888784
+25.75	0.04332408455708571
+25.8	0.045541163334195785
+25.85	0.04787188927593277
+25.9	0.050322630168011216
+25.95	0.05289975223037868
+26.0	0.05560786074318585
+26.05	0.058453069323036504
+26.1	0.06144368075112385
+26.15	0.06458799780864083
+26.2	0.06789432327678108
+26.25	0.07136988227873912
+26.3	0.07502086916643964
+26.35	0.078857816259233
+26.4	0.08289152572268692
+26.45	0.08713279972237008
+26.5	0.09159200978152159
+26.55	0.09627637399515697
+26.6	0.10119846524538026
+26.65	0.1063723090652302
+26.7	0.11181193098774658
+26.75	0.11753132961245914
+26.8	0.12354047406374256
+26.85	0.12985336249326254
+26.9	0.13648817059730076
+26.95	0.14346307407214032
+27.0	0.15079624861406296
+27.05	0.1585027374101037
+27.1	0.16659727962320772
+27.15	0.17510316634668957
+27.2	0.18404390407275775
+27.25	0.19344299929361908
+27.3	0.20332202693815007
+27.35	0.21369744874628444
+27.4	0.22459800654555656
+27.45	0.2360540335811343
+27.5	0.2480958630981832
+27.55	0.2607530894413613
+27.6	0.2740460699238659
+27.65	0.288008693254334
+27.7	0.3026800019456106
+27.75	0.3180990385105376
+27.8	0.3343048274751558
+27.85	0.3513257561305068
+27.9	0.36919936291811095
+27.95	0.38797574684328856
+28.0	0.40770500691135625
+28.05	0.4284372421276319
+28.1	0.45021428222314047
+28.15	0.47307577260762895
+28.2	0.49708503969237244
+28.25	0.522306126774
+28.3	0.548803077149142
+28.35	0.5766349527857709
+28.4	0.6058456423288524
+28.45	0.6365119398467898
+28.5	0.6687153176492591
+28.55	0.702537248045939
+28.6	0.738057654914667
+28.65	0.7753305110693685
+28.7	0.8144434231982368
+28.75	0.8554993742873835
+28.8	0.8986013473229225
+28.85	0.9438523252909677
+28.9	0.9913297868875838
+28.95	1.041123710097386
+29.0	1.0933630948430608
+29.05	1.14817699015449
+29.1	1.2056944450615554
+29.15	1.2660294097558396
+29.2	1.3292704696127888
+29.25	1.3955718731975337
+29.3	1.4650933956963939
+29.35	1.5379948122956881
+29.4	1.6144322187413243
+29.45	1.6945059343288849
+29.5	1.7783813246656435
+29.55	1.8662528139459742
+29.6	1.9583148263642518
+29.65	2.0547617861148435
+29.7	2.1557460904638592
+29.75	2.261409543140288
+29.8	2.3719812068007204
+29.85	2.4876916053935085
+29.9	2.608771262866994
+29.95	2.7354418607787947
+30.0	2.867833268001007
+30.05	3.006172393970795
+30.1	3.150721427559324
+30.15	3.301742557637748
+30.2	3.4594979730772537
+30.25	3.6242036779718427
+30.3	3.796000312947196
+30.35	3.9751569991656543
+30.4	4.161950665497836
+30.45	4.3566582408144034
+30.5	4.559556653985974
+30.55	4.770833894645508
+30.6	4.9906199335286585
+30.65	5.219162942569686
+30.7	5.456712185271295
+30.75	5.703516925136135
+30.8	5.959826425666878
+30.85	6.225812985552521
+30.9	6.501494551293352
+30.95	6.786991830496306
+31.0	7.0824301646721395
+31.05	7.387934895331626
+31.1	7.703631363985543
+31.15	8.02964429581182
+31.2	8.365969601129745
+31.25	8.712378897019077
+31.3	9.068630145944798
+31.35	9.43448131037188
+31.4	9.809690352765271
+31.45	10.194015235590006
+31.5	10.58721392131103
+31.55	10.989044372393318
+31.6	11.399185932112827
+31.65	11.816814655568894
+31.7	12.240913599537175
+31.75	12.670465711290694
+31.8	13.104453938102514
+31.85	13.54186122724569
+31.9	13.981670525993247
+31.95	14.422860229961994
+32.0	14.864066241211134
+32.05	15.303350875437815
+32.1	15.738721891484401
+32.15	16.168187048193033
+32.2	16.589754104406108
+32.25	17.001430818965762
+32.3	17.40112875741172
+32.35	17.786206289383145
+32.4	18.153837422401434
+32.45	18.501196153278787
+32.5	18.825456478827164
+32.55	19.12379239379405
+32.6	19.39328018340708
+32.65	19.630858578672516
+32.7	19.83349310673894
+32.75	19.9981492947548
+32.8	20.12179266986864
+32.85	20.201367882770075
+32.9	20.234028960220357
+32.95	20.217639157306632
+33.0	20.150142310639445
+33.05	20.029482256829326
+33.1	19.85360283248679
+33.15	19.620653866292756
+33.2	19.330483574902434
+33.25	18.983603862670492
+33.3	18.580526687952972
+33.35	18.121764009105878
+33.4	17.607850902026595
+33.45	17.041004012197757
+33.5	16.425347738748968
+33.55	15.764958338679975
+33.6	15.06391206899048
+33.65	14.3262851866805
+33.7	13.556558450649579
+33.75	12.76200526834196
+33.8	11.949478426283411
+33.85	11.125590413384847
+33.9	10.297086836606853
+33.95	9.471288292197361
+34.0	8.655321266099904
+34.05	7.85627615438624
+34.1	7.0807582268509135
+34.15	6.33403337646641
+34.2	5.622072991058431
+34.25	4.950691141209821
+34.3	4.321840083274608
+34.35	3.7386615478986283
+34.4	3.2054279738738485
+34.45	2.7218318310108915
+34.5	2.2861233536241135
+34.55	1.9007836831588563
+34.6	1.5638096599568296
+34.65	1.2691004030964574
+34.7	1.0179631521778372
+34.75	0.8080319622887391
+34.8	0.6302047852419099
+34.85	0.4848581465843025
+34.9	0.3694811097467595
+34.95	0.2752330911034964
+35.0	0.2021828861211352
+35.05	0.14699573418067538
+35.1	0.10393108273236383
+35.15	0.07300848717904075
+35.2	0.05020999035277301
+35.25	0.0339934650647469
+35.3	0.02279095129967895
+35.35	0.015050341092619643
+35.4	0.009760658950261831
+35.45	0.006210064349417115
+35.5	0.003874772476991086
+35.55	0.002369912511200858
+35.6	0.0014203851182303866
+35.65	0.0008338273726116825
+35.7	0.00047894603427704654
+35.75	0.0002689999388510345
+35.8	0.00014760542925836503
+35.85	7.900586327484483e-5
+35.9	4.1131483924148396e-5
+35.95	2.0729000341789616e-5
+36.0	9.830979593492913e-6
+36.05	3.5098735344186295e-6
+36.1	9.189269223395096e-8
+36.15	-1.887279534205756e-8
+36.2	7.252943304034224e-9
+36.25	1.309220307737597e-7
+36.3	3.2908909610467193e-7
+36.35	5.558436852339584e-7
+36.4	7.652753440987047e-7
+36.45	9.11473618636111e-7
+36.5	9.485280547832864e-7
+36.55	8.305281984774123e-7
+36.6	5.115635956555821e-7
+36.65	-2.684841319306035e-8
+36.7	-4.2542990904938504e-7
+36.75	-1.8039786009714827e-7
+36.8	5.050825789153971e-7
+36.85	9.619825794053776e-7
+36.9	5.522756124634739e-7
+36.95	-5.213684458838995e-7
+37.0	-1.1544929381068978e-6
+37.05	-4.2904562784102766e-7
+37.1	7.240489451915882e-7
+37.15	1.0156385295113247e-6
+37.2	5.6764655700949686e-8
+37.25	-7.626496617237029e-7
+37.3	-5.457360385386776e-7
+37.35	1.9635653595418105e-7
+37.4	6.015566327139623e-7
+37.45	2.6082534887574583e-7
+37.5	-2.0843505408453615e-7
+37.55	-4.178403152306908e-7
+37.6	-1.8978099597290364e-7
+37.65	9.232247537872038e-8
+37.7	2.6057409141174634e-7
+37.75	1.9902803094234795e-7
+37.8	5.326044251702247e-8
+37.85	-6.912452446119937e-8
+37.9	-1.3466479341773114e-7
+37.95	-1.1304973175336538e-7
+38.0	-6.388287646301767e-8
+38.05	-2.3025689624820396e-8
+38.1	7.693319930149252e-9
+38.15	2.6445643370801556e-8
+38.2	3.140277186606055e-8
+38.25	2.5649430393586305e-8
+38.3	2.0702626884890685e-8
+38.35	1.699555299463395e-8
+38.4	1.436124797887245e-8
+38.45	1.2632751093660736e-8
+38.5	1.1643101595054601e-8
+38.55	1.1225338739109073e-8
+38.6	1.121250178187951e-8
+38.65	1.1437629979421218e-8
+38.7	1.1733762587789548e-8
+38.75	1.1934312844372769e-8
+38.8	1.2080665187396904e-8
+38.85	1.2294693397805885e-8
+38.9	1.2575534983202852e-8
+38.95	1.2922327451191093e-8
+39.0	1.3334208309373706e-8
+39.05	1.3810315065353941e-8
+39.1	1.4349785226735088e-8
+39.15	1.495175630112019e-8
+39.2	1.5615365796112624e-8
+39.25	1.63397512193154e-8
+39.3	1.7124050078331818e-8
+39.35	1.7967399880765203e-8
+39.4	1.8868938134218517e-8
+39.45	1.982780234629522e-8
+39.5	2.084313002459824e-8
+39.55	2.191405867673092e-8
+39.6	2.303972581029662e-8
+39.65	2.4188896662382705e-8
+39.7	2.5213032377264123e-8
+39.75	2.6103107130769886e-8
+39.8	2.6871752587985e-8
+39.85	2.7531600413994427e-8
+39.9	2.8095282273882808e-8
+39.95	2.8575429832735164e-8
+40.0	2.898467475563623e-8
+40.05	2.9335648707670923e-8
+40.1	2.964098335392413e-8
+40.15	2.991331035948064e-8
+40.2	3.0165261389425356e-8
+40.25	3.040946810884307e-8
+40.3	3.0658562182818704e-8
+40.35	3.092517527643709e-8
+40.4	3.1221939054783084e-8
+40.45	3.156148518294154e-8
+40.5	3.195644532599731e-8
+40.55	3.241945114903522e-8
+40.6	3.296313431714023e-8
+40.65	3.360012649539703e-8
+40.7	3.434305934889068e-8
+40.75	3.52045645427058e-8
+40.8	3.619727374192734e-8
+40.85	3.7333818611640384e-8
+40.9	3.862683081692937e-8
+40.95	4.008894202287959e-8
+41.0	4.173278389457546e-8
+41.05	4.357098809710202e-8
+41.1	4.561618629554443e-8
+41.15	4.788101015498703e-8
+41.2	5.037809134051519e-8
+41.25	5.312006151721311e-8
+41.3	5.6119552350166066e-8
+41.35	5.938919550445932e-8
+41.4	6.294162264517676e-8
+41.45	6.678946543740433e-8
+41.5	7.094535554622576e-8
+41.55	7.542192463672646e-8
+41.6	8.023180437399186e-8
+41.65	8.538762642310565e-8
+41.7	9.090202244915397e-8
+41.75	9.678762411722021e-8
+41.8	1.0305706309238992e-7
+41.85	1.0972297103974902e-7
+41.9	1.1679797962438039e-7
+41.95	1.2429472051137097e-7
+42.0	1.322258253658033e-7
+42.05	1.406039258527635e-7
+42.1	1.4944165363733767e-7
+42.15	1.5875164038460794e-7
+42.2	1.6854651775966214e-7
+42.25	1.7883891742758207e-7
+42.3	1.8964147105345398e-7
+42.35	2.0096681030236458e-7
+42.4	2.1282756683939527e-7
+42.45	2.252363723296344e-7
+42.5	2.3820585843816323e-7
+42.55	2.517486568300685e-7
+42.6	2.658773991704367e-7
+42.65	2.806047171243493e-7
+42.7	2.959432423568947e-7
+42.75	3.119056065331539e-7
+42.8	3.285044413182136e-7
+42.85	3.457523783771612e-7
+42.9	3.6366204937507665e-7
+42.95	3.821067535802948e-7
+43.0	4.008052150485976e-7
+43.05	4.198588644770537e-7
+43.1	4.393868359168193e-7
+43.15	4.5950826341904257e-7
+43.2	4.803422810348826e-7
+43.25	5.020080228154869e-7
+43.3	5.24624622812012e-7
+43.35	5.483112150756151e-7
+43.4	5.731869336574427e-7
+43.45	5.993709126086556e-7
+43.5	6.269822859803998e-7
+43.55	6.561401878238322e-7
+43.6	6.869637521901112e-7
+43.65	7.195721131303816e-7
+43.7	7.540844046958056e-7
+43.75	7.906197609375275e-7
+43.8	8.292973159067054e-7
+43.85	8.70236203654499e-7
+43.9	9.135555582320497e-7
+43.95	9.59374513690524e-7
+44.0	1.0078122040810616e-6
+44.05	1.0589877634548231e-6
+44.1	1.1130203258629698e-6
+44.15	1.1700290253566394e-6
+44.2	1.230132995987002e-6
+44.25	1.2934513718051938e-6
+44.3	1.360103286862377e-6
+44.35	1.430207875209715e-6
+44.4	1.5038842708983416e-6
+44.45	1.5812516079794313e-6
+44.5	1.662429020504115e-6
+44.55	1.7454278033768966e-6
+44.6	1.8285439608609973e-6
+44.65	1.9123650296215106e-6
+44.7	1.9974785680304374e-6
+44.75	2.0844721344597293e-6
+44.8	2.1739332872813744e-6
+44.85	2.266449584867362e-6
+44.9	2.3626085855896422e-6
+44.95	2.462997847820218e-6
+45.0	2.5682049299310377e-6
+45.05	2.678817390294091e-6
+45.1	2.7954227872813705e-6
+45.15	2.9186086792648208e-6
+45.2	3.048962624616451e-6
+45.25	3.187072181708202e-6
+45.3	3.333524908912067e-6
+45.35	3.488908364600045e-6
+45.4	3.6538101071440676e-6
+45.45	3.828817694916159e-6
+45.5	4.014518686288244e-6
+45.55	4.211500639632323e-6
+45.6	4.4203511133204055e-6
+45.65	4.641657665724405e-6
+45.7	4.876007855216361e-6
+45.75	5.123989240168185e-6
+45.8	5.386189378951883e-6
+45.85	5.663195829939474e-6
+45.9	5.955596151502852e-6
+45.95	6.26397790201408e-6
+46.0	6.588928639845044e-6
+46.05	6.931035923367765e-6
+46.1	7.290887310954268e-6
+46.15	7.669070360976432e-6
+46.2	8.066172631806335e-6
+46.25	8.482781681815845e-6
+46.3	8.91948506937699e-6
+46.35	9.375968886600003e-6
+46.4	9.847023565583588e-6
+46.45	1.0333786668904147e-5
+46.5	1.0838770265523423e-5
+46.55	1.1364486424403357e-5
+46.6	1.1913447214505912e-5
+46.65	1.2488164704792812e-5
+46.7	1.3091150964226096e-5
+46.75	1.3724918061767474e-5
+46.8	1.439197806637891e-5
+46.85	1.509484304702238e-5
+46.9	1.5836025072659568e-5
+46.95	1.661803621225256e-5
+47.0	1.744338853476302e-5
+47.05	1.8314594109152927e-5
+47.1	1.9234165004384297e-5
+47.15	2.020461328941875e-5
+47.2	2.122845103321844e-5
+47.25	2.230819030474495e-5
+47.3	2.3446343172960313e-5
+47.35	2.464542170682657e-5
+47.4	2.5907937975305263e-5
+47.45	2.7236404047358632e-5
+47.5	2.8633331991948176e-5
+47.55	3.009806972638282e-5
+47.6	3.162218202776161e-5
+47.65	3.321109283517312e-5
+47.7	3.487170254963638e-5
+47.75	3.6610911572169504e-5
+47.8	3.84356203037913e-5
+47.85	4.0352729145520636e-5
+47.9	4.236913849837555e-5
+47.95	4.4491748763375194e-5
+48.0	4.672746034153756e-5
+48.05	4.9083173633881515e-5
+48.1	5.1565789041426e-5
+48.15	5.418220696518889e-5
+48.2	5.6939327806189526e-5
+48.25	5.98440519654457e-5
+48.3	6.290327984397639e-5
+48.35	6.612391184280061e-5
+48.4	6.951284836293609e-5
+48.45	7.306821154518534e-5
+48.5	7.678319205978394e-5
+48.55	8.067208960716033e-5
+48.6	8.474979700493381e-5
+48.65	8.903120707072185e-5
+48.7	9.353121262214436e-5
+48.75	9.826470647681877e-5
+48.8	0.00010324658145236433
+48.85	0.00010849173036640049
+48.9	0.00011401504603654443
+48.95	0.00011983142128041642
+49.0	0.00012595574891563344
+49.05	0.00013240292175981498
+49.1	0.00013918783263058076
+49.15	0.00014631064161143507
+49.2	0.00015376893550346217
+49.25	0.00016158994876125146
+49.3	0.00016980134443617313
+49.35	0.00017843078557959753
+49.4	0.0001875059352428914
+49.45	0.00019705445647742652
+49.5	0.00020710401233456927
+49.55	0.0002176822658656901
+49.6	0.00022881688012215982
+49.65	0.00024053551815534442
+49.7	0.00025286475519168463
+49.75	0.00026581658796611703
+49.8	0.00027942372663386186
+49.85	0.0002937271246721932
+49.9	0.0003087677355583789
+49.95	0.00032458651276969517
+50.0	0.00034122440978340957
diff --git a/test/julia_dde/test_basic_check_6.txt b/test/julia_dde/test_basic_check_6.txt
new file mode 100644
index 00000000..eb94fce5
--- /dev/null
+++ b/test/julia_dde/test_basic_check_6.txt
@@ -0,0 +1,1001 @@
+0.0	1.2
+0.05	1.1880598004963274
+0.1	1.1762384079684047
+0.15	1.1645346400763932
+0.2	1.1529473230412208
+0.25	1.141475299112477
+0.3	1.1301174233925737
+0.35	1.1188725616239938
+0.4	1.1077395901892904
+0.45	1.0967173961110868
+0.5	1.0858048770520767
+0.55	1.075000941315024
+0.6	1.064304507842764
+0.65	1.0537145062182005
+0.7	1.043229876664309
+0.75	1.0328495700441358
+0.8	1.0225725478607945
+0.85	1.012397782257473
+0.9	1.0023242560174284
+0.95	0.9923509625639848
+1.0	0.9824768980307879
+1.05	0.9727010597135578
+1.1	0.9630224771641795
+1.15	0.9534401895136975
+1.2	0.9439532442848286
+1.25	0.9345606973919633
+1.3	0.9252616131411647
+1.35	0.916055064230169
+1.4	0.9069401317483856
+1.45	0.8979159051768966
+1.5	0.888981482388457
+1.55	0.8801359696474954
+1.6	0.8713784816101126
+1.65	0.8627081413240832
+1.7	0.8541240802288541
+1.75	0.8456254381555455
+1.8	0.837211363326951
+1.85	0.8288810123575361
+1.9	0.8206335502534408
+1.95	0.8124681504124766
+2.0	0.80438399462413
+2.05	0.7963802730695569
+2.1	0.7884561843215901
+2.15	0.7806109353447339
+2.2	0.7728437414951651
+2.25	0.765153826520734
+2.3	0.7575404225609632
+2.35	0.7500027701470497
+2.4	0.7425401095139714
+2.45	0.7351516745473683
+2.5	0.7278367335692332
+2.55	0.7205945629397542
+2.6	0.7134244451629572
+2.65	0.7063256688867061
+2.7	0.6992975289027018
+2.75	0.6923393261464837
+2.8	0.6854503676974283
+2.85	0.6786299667787499
+2.9	0.6718774427575008
+2.95	0.6651921211445705
+3.0	0.6585733335946865
+3.05	0.6520204179064143
+3.1	0.6455327180221563
+3.15	0.639109584028153
+3.2	0.6327503721544829
+3.25	0.626454444775062
+3.3	0.6202211704076434
+3.35	0.614049923713819
+3.4	0.6079400854990171
+3.45	0.6018910427125049
+3.5	0.5959021884473867
+3.55	0.589972921940604
+3.6	0.5841026485729373
+3.65	0.5782907798690035
+3.7	0.5725367334972585
+3.75	0.5668399332699937
+3.8	0.5611998091433409
+3.85	0.5556157972172671
+3.9	0.5500873397355796
+3.95	0.5446138850859217
+4.0	0.5391948877998171
+4.05	0.5339894476498435
+4.1	0.5291484887570944
+4.15	0.5246599102269127
+4.2	0.5205124110692027
+4.25	0.5166954624161663
+4.3	0.5131993075223042
+4.35	0.5100148786651322
+4.4	0.507133789063348
+4.45	0.5045484127257454
+4.5	0.5022516488470562
+4.55	0.500236916528291
+4.6	0.4984981547767386
+4.65	0.49702982250596606
+4.7	0.49582689853581874
+4.75	0.49488488159242033
+4.8	0.494199790308173
+4.85	0.4937680930953992
+4.9	0.4935865081959021
+4.95	0.49365234000660846
+5.0	0.4939632697256549
+5.05	0.4945173295813525
+5.1	0.49531290283218715
+5.15	0.49634872376681904
+5.2	0.49762387770408306
+5.25	0.4991378009929886
+5.3	0.5008902810127198
+5.35	0.5028814561726352
+5.4	0.5051118159122677
+5.45	0.5075822007013254
+5.5	0.5102938125293998
+5.55	0.5132481420160805
+5.6	0.5164468664365349
+5.65	0.5198919874314458
+5.7	0.5235858405469669
+5.75	0.5275310952347235
+5.8	0.531730754851812
+5.85	0.5361881566607996
+5.9	0.5409069718297252
+5.95	0.5458912054320991
+6.0	0.5511451964469019
+6.05	0.5566736177585864
+6.1	0.5624814761570759
+6.15	0.5685741123377654
+6.2	0.5749572009015202
+6.25	0.5816367503546787
+6.3	0.5886191031090475
+6.35	0.5959109354819085
+6.4	0.603519763311911
+6.45	0.6114556167254628
+6.5	0.6197256590045511
+6.55	0.6283370718597721
+6.6	0.637297663258921
+6.65	0.6466158674269936
+6.7	0.6563007448461833
+6.75	0.6663619822558845
+6.8	0.6768098926526902
+6.85	0.6876554152903934
+6.9	0.6989101156799858
+6.95	0.7105861855896597
+7.0	0.7226964430448051
+7.05	0.7352543323280116
+7.1	0.7482739239790708
+7.15	0.7617699147949715
+7.2	0.7757576278299017
+7.25	0.7902530123952489
+7.3	0.8052726440596026
+7.35	0.8208337246487463
+7.4	0.8369540822456707
+7.45	0.8536521711905581
+7.5	0.8709470720807941
+7.55	0.888858491770967
+7.6	0.9074067633728512
+7.65	0.926612846255443
+7.7	0.946498326044915
+7.75	0.9670854146246521
+7.8	0.9883986283036008
+7.85	1.0104650299017568
+7.9	1.0333110834586596
+7.95	1.056964355445862
+8.0	1.081453529577144
+8.05	1.1068040247077164
+8.1	1.133031675252411
+8.15	1.160151390386478
+8.2	1.1881777648723224
+8.25	1.217125079059513
+8.3	1.247007298884773
+8.35	1.2778380758719852
+8.4	1.3096307471321937
+8.45	1.3423983353635964
+8.5	1.376153548851558
+8.55	1.4109087814685923
+8.6	1.4466761126743772
+8.65	1.483467307515752
+8.7	1.5212938166267091
+8.75	1.5601667762283993
+8.8	1.6000970081291368
+8.85	1.6410950197243965
+8.9	1.6831710039968013
+8.95	1.72633148245242
+9.0	1.7705794905938943
+9.05	1.8159224788729322
+9.1	1.862365635256938
+9.15	1.9099117265643932
+9.2	1.9585610984648445
+9.25	2.008311675478917
+9.3	2.059158960978304
+9.35	2.111096037185769
+9.4	2.1641135651751533
+9.45	2.2181997848713633
+9.5	2.2733405150503807
+9.55	2.329519153339259
+9.6	2.386716676216126
+9.65	2.4449116390101793
+9.7	2.504080175901684
+9.75	2.564192442298866
+9.8	2.6252063573836484
+9.85	2.6870886538434977
+9.9	2.7498014288308483
+9.95	2.8133006339811084
+10.0	2.877536075412671
+10.05	2.942451413726905
+10.1	3.0079841640081573
+10.15	3.074065695823763
+10.2	3.1406212332240275
+10.25	3.2075698547422418
+10.3	3.2748244933946777
+10.35	3.3422919366805766
+10.4	3.409872826582172
+10.45	3.4774616595646712
+10.5	3.5449467865762605
+10.55	3.612210413048112
+10.6	3.679128598894362
+10.65	3.745571258512154
+10.7	3.811402160781577
+10.75	3.8764789290657355
+10.8	3.9406530412106795
+10.85	4.0037752560274695
+10.9	4.065710361208003
+10.95	4.126263189406429
+11.0	4.185228039219931
+11.05	4.242399038561617
+11.1	4.297570144660533
+11.15	4.350535144061664
+11.2	4.401087652625922
+11.25	4.44902111553016
+11.3	4.4941288072671615
+11.35	4.5362038316456434
+11.4	4.575039121790263
+11.45	4.6104274401416045
+11.5	4.642161378456192
+11.55	4.670033357806482
+11.6	4.6938356285808664
+11.65	4.713360270483668
+11.7	4.728399192535151
+11.75	4.738744133071507
+11.8	4.744186659744865
+11.85	4.744518169523291
+11.9	4.739529888690779
+11.95	4.729012872847263
+12.0	4.712758006908607
+12.05	4.690644772827069
+12.1	4.662607874771771
+12.15	4.628544701418221
+12.2	4.588384372705109
+12.25	4.5420877398343
+12.3	4.489647385270846
+12.35	4.431087622742975
+12.4	4.36646449724209
+12.45	4.295865785022786
+12.5	4.219410993602822
+12.55	4.137251361763148
+12.6	4.0495698595478915
+12.65	3.956581188264356
+12.7	3.8585317804830335
+12.75	3.755699800037576
+12.8	3.6483951420248397
+12.85	3.5369598085500358
+12.9	3.421872986351
+12.95	3.3036137528179275
+13.0	3.1825896906344413
+13.05	3.059219238591574
+13.1	2.9339316915877616
+13.15	2.807167200628823
+13.2	2.679376772828
+13.25	2.55102227140591
+13.3	2.4225764156905853
+13.35	2.2945227811174655
+13.4	2.1673557992293646
+13.45	2.0415807576765275
+13.5	1.9177138002165603
+13.55	1.7962819267145127
+13.6	1.6778229931428135
+13.65	1.5628857115812622
+13.7	1.451775448594975
+13.75	1.3444872104184455
+13.8	1.241323137203346
+13.85	1.142544019405883
+13.9	1.0483619993004278
+13.95	0.958940570979545
+14.0	0.8743945803539613
+14.05	0.7947902251525917
+14.1	0.720145054922527
+14.15	0.6504279710290286
+14.2	0.5855592266555394
+14.25	0.5254104268036786
+14.3	0.46980452829324654
+14.35	0.41854245182074584
+14.4	0.3715213248496608
+14.45	0.3286413774786371
+14.5	0.2897610736192788
+14.55	0.25469718925321655
+14.6	0.2232248124321029
+14.65	0.19507734327761048
+14.7	0.16994649398143608
+14.75	0.14748476071965655
+14.8	0.12747860910971498
+14.85	0.10980739774121695
+14.9	0.09430132081329244
+14.95	0.08077609315674421
+15.0	0.06903295023404325
+15.05	0.058858648139332344
+15.1	0.050025463598425
+15.15	0.0423515352891361
+15.2	0.035748634408577284
+15.25	0.030106283070165652
+15.3	0.025312577530941235
+15.35	0.02125428291269547
+15.4	0.017816833201970878
+15.45	0.014890207823781922
+15.5	0.01241435562312446
+15.55	0.01033522210714558
+15.6	0.008597913655595604
+15.65	0.007150103019257152
+15.7	0.005942029319945532
+15.75	0.004929855341324498
+15.8	0.004086937387096473
+15.85	0.003388335741636043
+15.9	0.00281091700782403
+15.95	0.0023337264921647967
+16.0	0.0019381060012871984
+16.05	0.0016105870898081322
+16.1	0.001340102298681946
+16.15	0.0011167692494021184
+16.2	0.0009323877459983964
+16.25	0.0007801253909870112
+16.3	0.0006543465182712573
+16.35	0.0005503754813700738
+16.4	0.0004643447916936142
+16.45	0.00039309405049839006
+16.5	0.00033400315527697057
+16.55	0.00028491725176067826
+16.6	0.00024408189440663954
+16.65	0.00021005101985087374
+16.7	0.00018162516020069216
+16.75	0.00015783592362906903
+16.8	0.00013788844694569702
+16.85	0.00012112660968516564
+16.9	0.00010700211647825708
+16.95	9.507347081214746e-5
+17.0	8.498512339869922e-5
+17.05	7.643876452321336e-5
+17.1	6.918767485250488e-5
+17.15	6.302051055846575e-5
+17.2	5.776129353098915e-5
+17.25	5.326941137797212e-5
+17.3	4.943961742531317e-5
+17.35	4.619114694863392e-5
+17.4	4.343365485645419e-5
+17.45	4.110144366726172e-5
+17.5	3.913880532586793e-5
+17.55	3.7498404263323886e-5
+17.6	3.614127739692056e-5
+17.65	3.503683413018882e-5
+17.7	3.4158658117854505e-5
+17.75	3.348423836993994e-5
+17.8	3.299722596485071e-5
+17.85	3.2683202960132985e-5
+17.9	3.252967918350988e-5
+17.95	3.252609223288135e-5
+18.0	3.266380747632436e-5
+18.05	3.293611805209274e-5
+18.1	3.3338244868617236e-5
+18.15	3.3867324482219054e-5
+18.2	3.452074947804436e-5
+18.25	3.529683044759008e-5
+18.3	3.6195209108088085e-5
+18.35	3.721598242829983e-5
+18.4	3.835970262851611e-5
+18.45	3.962737718055762e-5
+18.5	4.102046880777432e-5
+18.55	4.254089548504581e-5
+18.6	4.41910304387811e-5
+18.65	4.5973702146918646e-5
+18.7	4.789219433892721e-5
+18.75	4.995024599580391e-5
+18.8	5.215213905850741e-5
+18.85	5.450275727598156e-5
+18.9	5.700644253905166e-5
+18.95	5.9667952253505815e-5
+19.0	6.249261610419027e-5
+19.05	6.54863360550103e-5
+19.1	6.865558634892979e-5
+19.15	7.200741350797115e-5
+19.2	7.554943633321627e-5
+19.25	7.928984590480496e-5
+19.3	8.323740558193607e-5
+19.35	8.740145100286727e-5
+19.4	9.179189008491467e-5
+19.45	9.641920302445356e-5
+19.5	0.0001012944422969172
+19.55	0.00010642923265679917
+19.6	0.00011183577113764963
+19.65	0.00011752682705207853
+19.7	0.0001235157419917554
+19.75	0.00012981642982740708
+19.8	0.00013644337670881972
+19.85	0.00014341238042792347
+19.9	0.0001507419950778271
+19.95	0.00015845037624489302
+20.0	0.00016655656364450173
+20.05	0.0001750807873136931
+20.1	0.00018404401543587213
+20.15	0.0001934681944893066
+20.2	0.00020337666682029103
+20.25	0.0002137941706431422
+20.3	0.00022474684004020258
+20.35	0.00023626220496183865
+20.4	0.00024836919122644117
+20.45	0.0002610981205204284
+20.5	0.00027448071039823944
+20.55	0.00028855007428233895
+20.6	0.00030334074363392726
+20.65	0.00031889009257348225
+20.7	0.00033523692131091513
+20.75	0.0003524216342851605
+20.8	0.0003704871119082912
+20.85	0.0003894787105655175
+20.9	0.00040944426261518504
+20.95	0.00043043407638878114
+21.0	0.00045250093619092653
+21.05	0.00047570010229937927
+21.1	0.0005000893109650367
+21.15	0.0005257287744119279
+21.2	0.0005526824540478568
+21.25	0.0005810188336191905
+21.3	0.0006108070107982503
+21.35	0.0006421205340998608
+21.4	0.0006750375944100534
+21.45	0.0007096410249860753
+21.5	0.0007460183014563766
+21.55	0.000784261541820617
+21.6	0.0008244675064496669
+21.65	0.0008667375980856045
+21.7	0.0009111778618417234
+21.75	0.0009578989852025179
+21.8	0.0010070162980236952
+21.85	0.0010586497725321665
+21.9	0.0011129277097948057
+21.95	0.0011699900591849392
+22.0	0.0012299755817579254
+22.05	0.0012930319611769592
+22.1	0.001359316386819383
+22.15	0.001428995553776683
+22.2	0.0015022456628545052
+22.25	0.0015792524205726272
+22.3	0.0016602110391649818
+22.35	0.0017453262365796482
+22.4	0.0018348122364788412
+22.45	0.0019288927682389522
+22.5	0.0020278010669504873
+22.55	0.0021317798734181216
+22.6	0.0022410814341606625
+22.65	0.002355984639422683
+22.7	0.0024767808780344143
+22.75	0.002603761980892953
+22.8	0.0027372401948221586
+22.85	0.0028775481866526414
+22.9	0.003025039043221745
+22.95	0.0031800862713735876
+23.0	0.0033430837979590115
+23.05	0.003514445969835603
+23.1	0.003694607553867713
+23.15	0.0038840237369264087
+23.2	0.0040831701258895415
+23.25	0.004292542747641672
+23.3	0.004512658049074163
+23.35	0.004744052897085041
+23.4	0.004987284580402943
+23.45	0.005242969271549389
+23.5	0.005511761054674322
+23.55	0.00579432197338657
+23.6	0.006091355491220548
+23.65	0.0064036064916362335
+23.7	0.006731861278019287
+23.75	0.007076947573680842
+23.8	0.00743973452185771
+23.85	0.007821132685712216
+23.9	0.00822209404833235
+23.95	0.008643612012731632
+24.0	0.00908672140184931
+24.05	0.00955254692791327
+24.1	0.010042253834152701
+24.15	0.0105570296819283
+24.2	0.011098143201901847
+24.25	0.011666944294036008
+24.3	0.012264864027594422
+24.35	0.012893414641141748
+24.4	0.013554189542543494
+24.45	0.014248863308966384
+24.5	0.014979191686877847
+24.55	0.015747011592046525
+24.6	0.01655424110954179
+24.65	0.01740287949373414
+24.7	0.018295007168295047
+24.75	0.01923278572619698
+24.8	0.02021853997412642
+24.85	0.021254828633957477
+24.9	0.022344159029441403
+24.95	0.023489200541525417
+25.0	0.02469279655942265
+25.05	0.025957964480612526
+25.1	0.02728789571084064
+25.15	0.02868595566411857
+25.2	0.030155683762724525
+25.25	0.0317007934372025
+25.3	0.03332517212636277
+25.35	0.03503288127728171
+25.4	0.03682815634530202
+25.45	0.03871540679403269
+25.5	0.04069921609534854
+25.55	0.042784373528663246
+25.6	0.0449763279076481
+25.65	0.04728044488017886
+25.7	0.04970227955968355
+25.75	0.052247770966094
+25.8	0.05492324202584642
+25.85	0.05773539957188114
+25.9	0.06069133434364258
+25.95	0.06379852098707998
+26.0	0.06706481805464633
+26.05	0.07049846800529865
+26.1	0.0741080972044986
+26.15	0.07790271592421177
+26.2	0.08189171834290815
+26.25	0.08608488254556189
+26.3	0.09049237052365182
+26.35	0.09512472817515968
+26.4	0.09999314447706532
+26.45	0.10511041700287879
+26.5	0.11048880924430914
+26.55	0.11614123078552946
+26.6	0.12208146938430069
+26.65	0.1283241909719708
+26.7	0.1348849396534774
+26.75	0.14178013770734382
+26.8	0.14902708558568242
+26.85	0.15664396191419241
+26.9	0.16464982349216084
+26.95	0.1730646052924643
+27.0	0.18190912046156457
+27.05	0.1912050603195135
+27.1	0.20097499435994792
+27.15	0.21124237025009449
+27.2	0.22203151383076938
+27.25	0.23336763052011422
+27.3	0.24527919741825396
+27.35	0.25779578606893544
+27.4	0.2709467295786414
+27.45	0.2847633563531029
+27.5	0.2992789900972936
+27.55	0.31452894981543345
+27.6	0.3305505498109882
+27.65	0.34738309968666686
+27.7	0.36506790434442815
+27.75	0.3836482639854726
+27.8	0.40316947411024917
+27.85	0.42367882551844804
+27.9	0.44522560430900615
+27.95	0.46786109188011127
+28.0	0.491638564929192
+28.05	0.5166132954529155
+28.1	0.5428425507472129
+28.15	0.5703855934072375
+28.2	0.5993056758137147
+28.25	0.6296730287713358
+28.3	0.661557134535638
+28.35	0.6950312476576611
+28.4	0.7301727245459128
+28.45	0.767063023466374
+28.5	0.8057877045424929
+28.55	0.8464364297551855
+28.6	0.889102962942839
+28.65	0.9338851698013031
+28.7	0.9808850178839114
+28.75	1.0302085766014537
+28.8	1.0819660172221945
+28.85	1.1362716128718702
+28.9	1.1932437385336734
+28.95	1.2530048710482837
+29.0	1.3156815891138394
+29.05	1.3814045732859581
+29.1	1.4503086059776964
+29.15	1.5225304355713316
+29.2	1.5982148520056785
+29.25	1.6775248621592675
+29.3	1.7606287388881923
+29.35	1.8476994258653576
+29.4	1.9389145375804722
+29.45	2.034456359340072
+29.5	2.1345118472674884
+29.55	2.2392726283028628
+29.6	2.348935000203153
+29.65	2.4636999315421133
+29.7	2.5837730617103425
+29.75	2.709364700915209
+29.8	2.8406898301809043
+29.85	2.9779681013484613
+29.9	3.1214238370756484
+29.95	3.271286030837136
+30.0	3.4277883469243466
+30.05	3.5911691204455147
+30.1	3.7616712492812883
+30.15	3.9394477800127183
+30.2	4.124721521955342
+30.25	4.31780877756426
+30.3	4.519005519333036
+30.35	4.728587389793698
+30.4	4.946809701516717
+30.45	5.173907437111068
+30.5	5.410095249224141
+30.55	5.655567460541805
+30.6	5.91049806378836
+30.65	6.175040721726584
+30.7	6.449328767157758
+30.75	6.733475202921545
+30.8	7.027572701896115
+30.85	7.331693606998084
+30.9	7.645889931182482
+30.95	7.970193357442919
+31.0	8.304615238811323
+31.05	8.649146598358202
+31.1	9.003758129192397
+31.15	9.368301795032494
+31.2	9.742340807509574
+31.25	10.125707606463338
+31.3	10.518149593152962
+31.35	10.919290853552916
+31.4	11.328632158352924
+31.45	11.745550962958097
+31.5	12.16930140748874
+31.55	12.599014316780478
+31.6	13.033697200384234
+31.65	13.472234252566171
+31.7	13.91338635230787
+31.75	14.355791063306063
+31.8	14.797962633972846
+31.85	15.238291997435569
+31.9	15.675046771536875
+31.95	16.106371258834777
+32.0	16.530286446602425
+32.05	16.944690006828345
+32.1	17.347356296216514
+32.15	17.73593635618577
+32.2	18.1079723996273
+32.25	18.460967843671416
+32.3	18.791663532227535
+32.35	19.09670346609908
+32.4	19.372860196652287
+32.45	19.61703482581648
+32.5	19.826257006083754
+32.55	19.997684940509277
+32.6	20.12860538271109
+32.65	20.216433636870143
+32.7	20.258713557730346
+32.75	20.253117550598528
+32.8	20.197446571344457
+32.85	20.089630126400827
+32.9	19.92772627276328
+32.95	19.709921617990318
+33.0	19.434531320203547
+33.05	19.100017086239927
+33.1	18.70717440425866
+33.15	18.257944167333243
+33.2	17.753943711112875
+33.25	17.197386025691845
+33.3	16.59107975560906
+33.35	15.938429199848189
+33.4	15.243434311838062
+33.45	14.510690699451878
+33.5	13.745389625008139
+33.55	12.953318005269765
+33.6	12.140858411444603
+33.65	11.314989069185586
+33.7	10.483283858589981
+33.75	9.653912314200417
+33.8	8.834944747288745
+33.85	8.032595116068496
+33.9	7.252574989480735
+33.95	6.500244894133464
+34.0	5.780611806633342
+34.05	5.098329153584972
+34.1	4.457696811591198
+34.15	3.8626611072533747
+34.2	3.31681481717077
+34.25	2.8227900565193895
+34.3	2.3782600656369945
+34.35	1.9832044136900482
+34.4	1.6370099334569301
+34.45	1.3374100032609235
+34.5	1.0804845469706035
+34.55	0.8613445322670071
+34.6	0.6776876583085046
+34.65	0.5267122520657747
+34.7	0.40441627905775024
+34.75	0.3058944183001442
+34.8	0.22781720254468069
+34.85	0.167362080767307
+34.9	0.12112943508036349
+34.95	0.08606884444489188
+35.0	0.06025894955071106
+35.05	0.041575547523421456
+35.1	0.02807503179683125
+35.15	0.018661435450158345
+35.2	0.012187514371274674
+35.25	0.00778543573291238
+35.3	0.004900744392348273
+35.35	0.0030109136488610075
+35.4	0.001818415564176398
+35.45	0.0010735325005964665
+35.5	0.0006215527791510491
+35.55	0.00035179031575440834
+35.6	0.00019521699086201097
+35.65	0.00010567915314407212
+35.7	5.624620480371376e-5
+35.75	2.9060952855810056e-5
+35.8	1.4872013925480651e-5
+35.85	7.309603619828211e-6
+35.9	3.6172062933089507e-6
+35.95	1.6900745319711885e-6
+36.0	7.99061722219662e-7
+36.05	3.804043291337521e-7
+36.1	1.4637335962133365e-7
+36.15	9.604956964715575e-8
+36.2	3.1204373674468655e-8
+36.25	2.1585977569879193e-9
+36.3	2.1295697894631186e-8
+36.35	3.2435323272394713e-9
+36.4	-2.1120310656368706e-8
+36.45	-3.4820157243245986e-9
+36.5	2.5034223156922792e-8
+36.55	-2.4483048590997665e-8
+36.6	-7.752088542315618e-8
+36.65	-3.2288506757808416e-8
+36.7	6.575294392447037e-8
+36.75	2.0398273024218855e-8
+36.8	-1.885092372732254e-7
+36.85	-1.7935095917065193e-7
+36.9	8.463580903736217e-8
+36.95	2.103667032664853e-7
+37.0	-1.0902159040207543e-7
+37.05	-2.3792501267866805e-7
+37.1	2.949933213772468e-8
+37.15	2.605462472391962e-7
+37.2	3.579643931507328e-8
+37.25	-7.719883832340324e-8
+37.3	1.4474087423205393e-8
+37.35	1.1692222224812147e-7
+37.4	4.542260321956686e-8
+37.45	4.663159569340446e-9
+37.5	3.98437435540024e-9
+37.55	2.1542306063229875e-8
+37.6	1.7081577073831208e-8
+37.65	9.916058553643511e-9
+37.7	6.177405840400442e-9
+37.75	4.036295013754806e-9
+37.8	2.4981203133085326e-9
+37.85	1.4029941386133586e-9
+37.9	9.775671336281986e-10
+37.95	7.13443018982795e-10
+38.0	5.474637691993512e-10
+38.05	4.519928975853745e-10
+38.1	4.0371527249345375e-10
+38.15	3.8363711732123077e-10
+38.2	3.770860105113808e-10
+38.25	3.7371088555161584e-10
+38.3	3.67482030974705e-10
+38.35	3.5669109035838606e-10
+38.4	3.4404046252714365e-10
+38.45	3.351363584149498e-10
+38.5	3.306599912751868e-10
+38.55	3.301494468402913e-10
+38.6	3.331801929177388e-10
+38.65	3.3936507939003944e-10
+38.7	3.4835433821474265e-10
+38.75	3.5983558342442964e-10
+38.8	3.735338111267219e-10
+38.85	3.892113995042786e-10
+38.9	4.0666810881478823e-10
+38.95	4.257410813909854e-10
+39.0	4.4630484164063106e-10
+39.05	4.682712960465296e-10
+39.1	4.91589733166524e-10
+39.15	5.162468236334813e-10
+39.2	5.422666201553219e-10
+39.25	5.697105575149844e-10
+39.3	5.986774525704591e-10
+39.35	6.311202142544551e-10
+39.4	6.698607232832231e-10
+39.45	7.143352357870469e-10
+39.5	7.639584348814852e-10
+39.55	8.181804114395342e-10
+39.6	8.764866640916286e-10
+39.65	9.38398099225592e-10
+39.7	1.003471030986719e-9
+39.75	1.0712971812776804e-9
+39.8	1.1415036797586178e-9
+39.85	1.2137530638470843e-9
+39.9	1.2877432787180312e-9
+39.95	1.3632076773038862e-9
+40.0	1.4399150202944342e-9
+40.05	1.517669476136966e-9
+40.1	1.5963106210361297e-9
+40.15	1.6757134389540376e-9
+40.2	1.7557883216102534e-9
+40.25	1.836481068481641e-9
+40.3	1.917772886802624e-9
+40.35	1.9996803915650825e-9
+40.4	2.0822556055181207e-9
+40.45	2.1655859591684613e-9
+40.5	2.2497942907801726e-9
+40.55	2.335038846374777e-9
+40.6	2.4215132797311464e-9
+40.65	2.509446652385729e-9
+40.7	2.599103433632205e-9
+40.75	2.690783500521884e-9
+40.8	2.7848221378632998e-9
+40.85	2.881590038222633e-9
+40.9	2.981493301923223e-9
+40.95	3.08497343704608e-9
+41.0	3.1925073594295738e-9
+41.05	3.304607392669293e-9
+41.1	3.4218212681186887e-9
+41.15	3.544732124888063e-9
+41.2	3.673958509845681e-9
+41.25	3.8101543776169834e-9
+41.3	3.95400909058484e-9
+41.35	4.106247418889438e-9
+41.4	4.267629540428591e-9
+41.45	4.438951040857695e-9
+41.5	4.6210429135889836e-9
+41.55	4.814771559792631e-9
+41.6	5.0210387883960985e-9
+41.65	5.240781816084172e-9
+41.7	5.47497326729939e-9
+41.75	5.724621174241144e-9
+41.8	5.990768976866703e-9
+41.85	6.274495522890672e-9
+41.9	6.5769150677850095e-9
+41.95	6.899177274779381e-9
+42.0	7.242467214860197e-9
+42.05	7.608005366772091e-9
+42.1	7.997047617016514e-9
+42.15	8.410885259852642e-9
+42.2	8.850844997297112e-9
+42.25	9.318288939123777e-9
+42.3	9.814614602864361e-9
+42.35	1.0341254913807295e-8
+42.4	1.0899678204998906e-8
+42.45	1.1491388217243294e-8
+42.5	1.2117924099100853e-8
+42.55	1.2780860406890458e-8
+42.6	1.3481807104688487e-8
+42.65	1.4222409564327528e-8
+42.7	1.5004348565398855e-8
+42.75	1.5829340295250504e-8
+42.8	1.6699136348988185e-8
+42.85	1.7615523729475743e-8
+42.9	1.85803248473321e-8
+42.95	1.9595397520936415e-8
+43.0	2.066263497642389e-8
+43.05	2.1783965847686735e-8
+43.1	2.296135417637582e-8
+43.15	2.4196799411898173e-8
+43.2	2.5492336411419803e-8
+43.25	2.6850035439861983e-8
+43.3	2.827200216990519e-8
+43.35	2.9760377681986907e-8
+43.4	3.131733846430158e-8
+43.45	3.294509641280208e-8
+43.5	3.464589883119767e-8
+43.55	3.6809923607612946e-8
+43.6	4.154499690698694e-8
+43.65	4.898439529619013e-8
+43.7	5.891486049350068e-8
+43.75	7.11289627430776e-8
+43.8	8.542510081497004e-8
+43.85	1.0160750200511346e-7
+43.9	1.1948622213532442e-7
+43.95	1.3887714555331176e-7
+44.0	1.5960198513266106e-7
+44.05	1.8148828227285304e-7
+44.1	2.0436940689925326e-7
+44.15	2.2808455746310648e-7
+44.2	2.5247876094155054e-7
+44.25	2.774028728376016e-7
+44.3	3.027135771801683e-7
+44.35	3.282733865240484e-7
+44.4	3.539506419499113e-7
+44.45	3.796195130643286e-7
+44.5	4.0515999799974224e-7
+44.55	4.304579234144961e-7
+44.6	4.5540494449280557e-7
+44.65	4.798985449447842e-7
+44.7	5.038420370064247e-7
+44.75	5.27144561439597e-7
+44.8	5.497210875320751e-7
+44.85	5.714924130975111e-7
+44.9	5.923851644754372e-7
+44.95	6.123317965312756e-7
+45.0	6.31270592656341e-7
+45.05	6.491456647678196e-7
+45.1	6.659069533088062e-7
+45.15	6.815102272482466e-7
+45.2	6.959170840810043e-7
+45.25	7.090949498278235e-7
+45.3	7.210170790353243e-7
+45.35	7.31662554776009e-7
+45.4	7.410162886482758e-7
+45.45	7.490690207764042e-7
+45.5	7.558173198105756e-7
+45.55	7.612635829268149e-7
+45.6	7.654160358271001e-7
+45.65	7.682887327392191e-7
+45.7	7.699015564169263e-7
+45.75	7.702802181397586e-7
+45.8	7.694562577132557e-7
+45.85	7.674670434687431e-7
+45.9	7.643557722635116e-7
+45.95	7.601714694806681e-7
+46.0	7.549689890292627e-7
+46.05	7.488090133441886e-7
+46.1	7.417580533862704e-7
+46.15	7.338884486421561e-7
+46.2	7.252783671244582e-7
+46.25	7.160118053715835e-7
+46.3	7.061785884479207e-7
+46.35	6.95874369943652e-7
+46.4	6.852006319749456e-7
+46.45	6.742646851837244e-7
+46.5	6.63179668737944e-7
+46.55	6.520645503313282e-7
+46.6	6.410441261835761e-7
+46.65	6.302490210402114e-7
+46.7	6.198156881726391e-7
+46.75	6.09886409378209e-7
+46.8	6.006092949800714e-7
+46.85	5.921382838273863e-7
+46.9	5.84633143295059e-7
+46.95	5.782594692839692e-7
+47.0	5.73188686220875e-7
+47.05	5.695980470584003e-7
+47.1	5.676706332750411e-7
+47.15	5.675953548752083e-7
+47.2	5.695669503891852e-7
+47.25	5.737859868731545e-7
+47.3	5.804588599091835e-7
+47.35	5.897977936052373e-7
+47.4	6.020208405950911e-7
+47.45	6.173518820384442e-7
+47.5	6.360206276209592e-7
+47.55	6.582626155541462e-7
+47.6	6.843192125752839e-7
+47.65	7.14437613947704e-7
+47.7	7.488708434605792e-7
+47.75	7.878777534288256e-7
+47.8	8.317230246934425e-7
+47.85	8.806771666212422e-7
+47.9	9.350165171049236e-7
+47.95	9.95023242562971e-7
+48.0	1.060985337939922e-6
+48.05	1.1331966267061466e-6
+48.1	1.2119567608577905e-6
+48.15	1.2975712209170121e-6
+48.2	1.3903513159318535e-6
+48.25	1.4906141834762016e-6
+48.3	1.5986827896497617e-6
+48.35	1.714885929078196e-6
+48.4	1.8395582249131437e-6
+48.45	1.973040128832015e-6
+48.5	2.11567792103805e-6
+48.55	2.2678237102605558e-6
+48.6	2.429835433754578e-6
+48.65	2.6020768573010654e-6
+48.7	2.784917575206933e-6
+48.75	2.978733010304936e-6
+48.8	3.18390441395349e-6
+48.85	3.4008188660375673e-6
+48.9	3.6298692749671633e-6
+48.95	3.871454377678801e-6
+49.0	4.1259787396345065e-6
+49.05	4.393852754822395e-6
+49.1	4.67549264575661e-6
+49.15	4.971320463476581e-6
+49.2	5.281764087548395e-6
+49.25	5.607257226063441e-6
+49.3	5.948239415639302e-6
+49.35	6.305156021419535e-6
+49.4	6.678458237073032e-6
+49.45	7.068603084795207e-6
+49.5	7.476053415307059e-6
+49.55	7.901277907855351e-6
+49.6	8.344751070213068e-6
+49.65	8.806953238678842e-6
+49.7	9.288370578077374e-6
+49.75	9.789495081758916e-6
+49.8	1.0310824571599971e-5
+49.85	1.0852862698002855e-5
+49.9	1.1416118939895385e-5
+49.95	1.2002207279899062e-5
+50.0	1.2617573124471475e-5
diff --git a/test/julia_dde/test_basic_check_7.txt b/test/julia_dde/test_basic_check_7.txt
new file mode 100644
index 00000000..cb4d1299
--- /dev/null
+++ b/test/julia_dde/test_basic_check_7.txt
@@ -0,0 +1,1001 @@
+0.0	1.2
+0.05	1.1880597989704629
+0.1	1.1762384068305827
+0.15	1.1645346214892631
+0.2	1.1529468631577875
+0.25	1.1414739115822856
+0.3	1.1301147478486349
+0.35	1.118868353042714
+0.4	1.1077337082504002
+0.45	1.0967097945575714
+0.5	1.0857955930501062
+0.55	1.0749900848138823
+0.6	1.0642922509347774
+0.65	1.0537010724986695
+0.7	1.0432155305914366
+0.75	1.032834606298957
+0.8	1.0225572807071084
+0.85	1.0123825349017685
+0.9	1.0023093499688158
+0.95	0.9923367069941277
+1.0	0.9824635870635824
+1.05	0.9726889712630578
+1.1	0.9630118406784319
+1.15	0.9534311763955828
+1.2	0.9439459595003884
+1.25	0.9345551710787264
+1.3	0.9252577922164751
+1.35	0.916052803999512
+1.4	0.9069391875137156
+1.45	0.8979159238449635
+1.5	0.8889819940791337
+1.55	0.8801351528810064
+1.6	0.8713702819563883
+1.65	0.8626865462276534
+1.7	0.8540835081720922
+1.75	0.8455607302669946
+1.8	0.837117774989651
+1.85	0.8287542048173518
+1.9	0.8204695822273872
+1.95	0.8122634696970475
+2.0	0.804135429703623
+2.05	0.7960850247244037
+2.1	0.7881118172366801
+2.15	0.7802153697177425
+2.2	0.7723952446448812
+2.25	0.7646510044953863
+2.3	0.7569822117465481
+2.35	0.7493884288756572
+2.4	0.7418692183600033
+2.45	0.7344241426768768
+2.5	0.7270527643035685
+2.55	0.7197546457173682
+2.6	0.7125293493955662
+2.65	0.705376437815453
+2.7	0.6982954734543185
+2.75	0.6912860187894534
+2.8	0.6843476362981477
+2.85	0.6774798884576916
+2.9	0.6706823377453756
+2.95	0.6639545466384897
+3.0	0.6572960776143245
+3.05	0.65070649315017
+3.1	0.6441853557233166
+3.15	0.6377322278110547
+3.2	0.6313466718906741
+3.25	0.6250282504394656
+3.3	0.6187765259347192
+3.35	0.6125910608537252
+3.4	0.6064714176737739
+3.45	0.6004171588721555
+3.5	0.5944278469261605
+3.55	0.5885030443130788
+3.6	0.582642313510201
+3.65	0.5768408697967465
+3.7	0.5710794766847663
+3.75	0.5653610009353052
+3.8	0.5596913989056138
+3.85	0.5540766269529438
+3.9	0.5485226414345459
+3.95	0.5430353987076714
+4.0	0.5376208551295714
+4.05	0.5324279779879286
+4.1	0.5275956230773843
+4.15	0.5231145795156263
+4.2	0.5189756364203432
+4.25	0.5151695829092232
+4.3	0.5116872080999545
+4.35	0.5085193011102256
+4.4	0.5056566510577244
+4.45	0.5030900470601397
+4.5	0.5008102782351593
+4.55	0.49880813370047183
+4.6	0.4970748553643428
+4.65	0.49560655605664106
+4.7	0.49440098579159875
+4.75	0.4934557002893139
+4.8	0.4927682552698849
+4.85	0.49233620645340975
+4.9	0.49215710955998665
+4.95	0.49222852030971387
+5.0	0.4925479944226895
+5.05	0.4931130876190118
+5.1	0.49392135561877887
+5.15	0.49497035414208906
+5.2	0.4962576389090403
+5.25	0.49778076563973095
+5.3	0.49953729005425906
+5.35	0.501524941318734
+5.4	0.5037458970491127
+5.45	0.506203842831906
+5.5	0.5089017677147981
+5.55	0.5118426607454732
+5.6	0.515029510971616
+5.65	0.5184653074409105
+5.7	0.5221530392010407
+5.75	0.5260956952996914
+5.8	0.5302962647845464
+5.85	0.5347577367032903
+5.9	0.5394831001036072
+5.95	0.5444753440331814
+6.0	0.5497374575396972
+6.05	0.5552724296708389
+6.1	0.5610832494742907
+6.15	0.5671730058296118
+6.2	0.5735480139444015
+6.25	0.5802171634943021
+6.3	0.5871891995100587
+6.35	0.5944728670224162
+6.4	0.6020769110621197
+6.45	0.6100100766599139
+6.5	0.618281108846544
+6.55	0.6268987526527547
+6.6	0.6358717531092911
+6.65	0.6452088552468984
+6.7	0.6549188040963212
+6.75	0.6650103446883043
+6.8	0.675492222053593
+6.85	0.6863731812229321
+6.9	0.6976619672270669
+6.95	0.7093672886207111
+7.0	0.721498859885574
+7.05	0.7340731174836238
+7.1	0.7471081152386012
+7.15	0.7606219069742478
+7.2	0.7746325465143046
+7.25	0.7891580876825128
+7.3	0.804216584302614
+7.35	0.8198260901983494
+7.4	0.8360046591934602
+7.45	0.8527703451116875
+7.5	0.8701412017767727
+7.55	0.888135283012457
+7.6	0.9067706426424818
+7.65	0.9260653344905886
+7.7	0.9460374123805181
+7.75	0.9667049301360117
+7.8	0.9880875594749563
+7.85	1.0102237295962704
+7.9	1.0331448908278946
+7.95	1.0568783428825315
+8.0	1.0814513854728836
+8.05	1.1068902728897445
+8.1	1.133211944911305
+8.15	1.1604300012191169
+8.2	1.18855804149473
+8.25	1.2176096654196973
+8.3	1.247598472675569
+8.35	1.2785380629438947
+8.4	1.3104420359062283
+8.45	1.3433239912441184
+8.5	1.3771979415201407
+8.55	1.412080842329271
+8.6	1.4479844428163409
+8.65	1.4849191054493085
+8.7	1.5228951926961265
+8.75	1.5619230670247535
+8.8	1.6020130909031434
+8.85	1.6431756267992503
+8.9	1.6854210371810325
+8.95	1.7287596845164428
+9.0	1.77320758665951
+9.05	1.8187801150494953
+9.1	1.8654706014211884
+9.15	1.9132715169522916
+9.2	1.9621753328205
+9.25	2.012174520203516
+9.3	2.063261550279037
+9.35	2.1154288942247583
+9.4	2.1686690232183836
+9.45	2.2229744084376075
+9.5	2.278337521060133
+9.55	2.3347508322636554
+9.6	2.3922068132258727
+9.65	2.450698013919172
+9.7	2.5102301062290664
+9.75	2.570768322579658
+9.8	2.6322517060627697
+9.85	2.694619299770224
+9.9	2.757810146793851
+9.95	2.8217632902254697
+10.0	2.8864177731569103
+10.05	2.9517126386799952
+10.1	3.017586929886546
+10.15	3.083979689868393
+10.2	3.150829961717356
+10.25	3.218076788525264
+10.3	3.2856592133839397
+10.35	3.353516279385206
+10.4	3.421593224048038
+10.45	3.48979601343226
+10.5	3.557969743792511
+10.55	3.6259577332031125
+10.6	3.6936032997383843
+10.65	3.760749761472654
+10.7	3.827240436480242
+10.75	3.892918642835474
+10.8	3.957627698612673
+10.85	4.0212109218861585
+10.9	4.083511630730259
+10.95	4.144373143219292
+11.0	4.203638777427587
+11.05	4.261151851429465
+11.1	4.316755482109749
+11.15	4.370260250551552
+11.2	4.421437319670552
+11.25	4.470060803896637
+11.3	4.51590481765969
+11.35	4.5587434753895915
+11.4	4.59835089151623
+11.45	4.634501180469488
+11.5	4.666968456679249
+11.55	4.695526834575399
+11.6	4.71995042858782
+11.65	4.740013353146399
+11.7	4.755489722681018
+11.75	4.766153651621562
+11.8	4.771779254397915
+11.85	4.77214064543996
+11.9	4.767019317339158
+11.95	4.756319793371234
+12.0	4.739918594440316
+12.05	4.717649596734286
+12.1	4.689408949409466
+12.15	4.655135230673881
+12.2	4.614767018735563
+12.25	4.5682428918025355
+12.3	4.515501428082827
+12.35	4.456503109477955
+12.4	4.391345694980259
+12.45	4.320170073351286
+12.5	4.24311636922747
+12.55	4.16032470724525
+12.6	4.071935212041071
+12.65	3.9780967020752507
+12.7	3.879096453017905
+12.75	3.775302980147889
+12.8	3.66708157895344
+12.85	3.554797544922798
+12.9	3.4388161735441893
+12.95	3.3195027603058573
+13.0	3.1972275910455767
+13.05	3.0725085685956763
+13.1	2.9458780493346794
+13.15	2.817829840620449
+13.2	2.6888577498108677
+13.25	2.559455584263797
+13.3	2.4301171513371154
+13.35	2.3013362583886994
+13.4	2.1736458275079675
+13.45	2.047628348742059
+13.5	1.9236743230534912
+13.55	1.802158771791302
+13.6	1.6834567163045295
+13.65	1.5679431779421988
+13.7	1.4559931780533526
+13.75	1.347981737987016
+13.8	1.2442830365268889
+13.85	1.1451348486975603
+13.9	1.0506150157980325
+13.95	0.96079463250776
+14.0	0.8757447935061845
+14.05	0.7955365934727581
+14.1	0.7202411270869309
+14.15	0.6499294890281462
+14.2	0.5846413868944909
+14.25	0.5242148233560262
+14.3	0.4684608017345736
+14.35	0.41719644937801403
+14.4	0.3702388936342235
+14.45	0.3274052618510851
+14.5	0.28851268137647507
+14.55	0.25337827955827474
+14.6	0.22180905141412283
+14.65	0.19350867526965712
+14.7	0.16823527867063232
+14.75	0.14577140563248472
+14.8	0.12589960017065335
+14.85	0.10840240630057683
+14.9	0.09306236803769201
+14.95	0.07966202939743805
+15.0	0.06798393439525205
+15.05	0.05780832744586894
+15.1	0.04897753895854812
+15.15	0.041368396188088095
+15.2	0.034857728459313314
+15.25	0.029322365097047245
+15.3	0.024639135426114043
+15.35	0.020684868771337753
+15.4	0.017336394457542002
+15.45	0.014474135340200821
+15.5	0.012032821889186085
+15.55	0.009972844893206104
+15.6	0.008252148671182253
+15.65	0.006828677542035726
+15.7	0.005660375824687943
+15.75	0.00470518783806014
+15.8	0.003921057901073675
+15.85	0.0032659303326498866
+15.9	0.002706408787715835
+15.95	0.0022421878225317676
+16.0	0.0018592343012498686
+16.05	0.0015411039192114238
+16.1	0.0012772980130345102
+16.15	0.001060686702426913
+16.2	0.000884140107096374
+16.25	0.0007405283467506819
+16.3	0.0006227215410976081
+16.35	0.0005237670300543738
+16.4	0.00044054154764645257
+16.45	0.00037151598796316283
+16.5	0.0003147536398515534
+16.55	0.0002683177921586665
+16.6	0.00023027173373154434
+16.65	0.00019867875341723128
+16.7	0.0001717211482542478
+16.75	0.00014882785931608067
+16.8	0.0001296151169741785
+16.85	0.00011360376098475126
+16.9	0.00010031463110401002
+16.95	8.926856708816294e-5
+17.0	7.998640869342122e-5
+17.05	7.199146609673012e-5
+17.1	6.505037401007346e-5
+17.15	5.911176459089205e-5
+17.2	5.4070918416871555e-5
+17.25	4.9823116065698875e-5
+17.3	4.626363811506052e-5
+17.35	4.328776514264296e-5
+17.4	4.079077772613283e-5
+17.45	3.866795644321627e-5
+17.5	3.6822689845990906e-5
+17.55	3.524866954756011e-5
+17.6	3.3938708195953574e-5
+17.65	3.2874606120415507e-5
+17.7	3.203816365018986e-5
+17.75	3.1411181114520857e-5
+17.8	3.097545884265261e-5
+17.85	3.0712797163829244e-5
+17.9	3.060499640729489e-5
+17.95	3.0633856902293646e-5
+18.0	3.0781178978069665e-5
+18.05	3.103509304805598e-5
+18.1	3.1406304151324396e-5
+18.15	3.189659161775605e-5
+18.2	3.250657615472168e-5
+18.25	3.3236878469591956e-5
+18.3	3.408811926973756e-5
+18.35	3.506091926252916e-5
+18.4	3.615589915533737e-5
+18.45	3.737367965553303e-5
+18.5	3.871488147048675e-5
+18.55	4.0180125307569196e-5
+18.6	4.1770031874151066e-5
+18.65	4.348522187760291e-5
+18.7	4.532631602529567e-5
+18.75	4.7293935024599884e-5
+18.8	4.938869958288625e-5
+18.85	5.1615427351473687e-5
+18.9	5.3983610978448554e-5
+18.95	5.6500343553329005e-5
+19.0	5.917271560407602e-5
+19.05	6.200781765865075e-5
+19.1	6.501274024501433e-5
+19.15	6.819457389112769e-5
+19.2	7.156040912495243e-5
+19.25	7.511733647444949e-5
+19.3	7.887244646757999e-5
+19.35	8.283282963230508e-5
+19.4	8.700557649658562e-5
+19.45	9.139777758838334e-5
+19.5	9.601761403886084e-5
+19.55	0.00010087832563687495
+19.6	0.00010599456638455815
+19.65	0.00011138098882895671
+19.7	0.0001170522455171181
+19.75	0.00012302298899608856
+19.8	0.00012930787181291483
+19.85	0.00013592154651464355
+19.9	0.0001428786656483209
+19.95	0.00015019388176099456
+20.0	0.00015788184739971067
+20.05	0.00016594355676053039
+20.1	0.00017439137835178976
+20.15	0.00018325759864471072
+20.2	0.00019257450411051715
+20.25	0.00020237438122043113
+20.3	0.00021268951644567522
+20.35	0.00022355219625747212
+20.4	0.0002349947071270435
+20.45	0.00024704933552561375
+20.5	0.0002597483679244046
+20.55	0.0002731240907946388
+20.6	0.0002872087906075389
+20.65	0.00030203475383432647
+20.7	0.0003176342669462262
+20.75	0.0003340396164144598
+20.8	0.00035128308871024984
+20.85	0.00036939697030481883
+20.9	0.0003884135476693882
+20.95	0.00040836510727518315
+21.0	0.0004292711037861321
+21.05	0.00045116093863702047
+21.1	0.0004741187449434656
+21.15	0.0004982304888727387
+21.2	0.0005235821365921165
+21.25	0.0005502596542688701
+21.3	0.0005783490080702727
+21.35	0.0006079361641635971
+21.4	0.0006391070887161142
+21.45	0.0006719477478951015
+21.5	0.0007065441078678298
+21.55	0.0007429821348015722
+21.6	0.0007813477948636015
+21.65	0.0008217270542211883
+21.7	0.0008642058790416108
+21.75	0.0009088702354921396
+21.8	0.0009558060897400474
+21.85	0.0010050994079526076
+21.9	0.0010568361562970892
+21.95	0.0011111023009407728
+22.0	0.0011679092830706055
+22.05	0.0012273558045181427
+22.1	0.0012896840216456182
+22.15	0.0013551362682000785
+22.2	0.0014239548779285845
+22.25	0.0014963821845781823
+22.3	0.0015726605218959232
+22.35	0.001653032223628858
+22.4	0.0017377396235240326
+22.45	0.0018270250553285097
+22.5	0.001921130852789334
+22.55	0.002020299349653557
+22.6	0.0021247728796682307
+22.65	0.0022347937765803973
+22.7	0.0023506043741371233
+22.75	0.0024724470060854533
+22.8	0.002600564006172437
+22.85	0.002735197708145127
+22.9	0.0028765904457505637
+22.95	0.003024984552735818
+23.0	0.003180622362847931
+23.05	0.003343726234633874
+23.1	0.0035144820812906107
+23.15	0.003693488052770074
+23.2	0.003881418770958153
+23.25	0.004078948857740697
+23.3	0.004286752935003569
+23.35	0.00450550562463263
+23.4	0.004735881548513729
+23.45	0.004978555328532758
+23.5	0.005234201586575566
+23.55	0.005503494944528014
+23.6	0.005787110024275965
+23.65	0.00608572144770526
+23.7	0.006400003836701804
+23.75	0.0067306318131514385
+23.8	0.007078279998940026
+23.85	0.0074436230159534295
+23.9	0.007827335486077482
+23.95	0.008230092031198104
+24.0	0.008652567273201128
+24.05	0.009094862549674452
+24.1	0.009557864678319622
+24.15	0.010043428286458735
+24.2	0.010553408001413979
+24.25	0.011089658450507441
+24.3	0.011654034261061243
+24.35	0.012248390060397515
+24.4	0.012874580475838329
+24.45	0.013534460134705896
+24.5	0.014229883664322299
+24.55	0.014962705692009656
+24.6	0.015734780845090095
+24.65	0.016547963750885675
+24.7	0.017404109036718636
+24.75	0.01830507132991105
+24.8	0.019252705257785028
+24.85	0.020248865447662703
+24.9	0.021295406526866123
+24.95	0.022394183122717553
+25.0	0.02354704986253904
+25.05	0.024755861373652717
+25.1	0.026020606014804316
+25.15	0.027343296023248972
+25.2	0.028729465615056828
+25.25	0.030184653662472128
+25.3	0.03171439903773921
+25.35	0.033324240613102435
+25.4	0.035019717260806014
+25.45	0.03680636785309454
+25.5	0.03868973126221223
+25.55	0.04067534636040344
+25.6	0.04276875201991251
+25.65	0.04497548711298361
+25.7	0.047301090511861425
+25.75	0.04975110108879013
+25.8	0.05233105771601409
+25.85	0.05504649926577762
+25.9	0.05790296461032489
+25.95	0.060905992621900606
+26.0	0.06406112217274897
+26.05	0.06737389213511429
+26.1	0.07084984138124091
+26.15	0.07449450878337292
+26.2	0.07831321218376372
+26.25	0.0823016199870987
+26.3	0.0864691673727658
+26.35	0.0908341991822645
+26.4	0.09541506025709397
+26.45	0.10023009543875436
+26.5	0.10529764956874482
+26.55	0.11063606748856489
+26.6	0.116263694039714
+26.65	0.12219887406369125
+26.7	0.128459952401997
+26.75	0.1350652738961303
+26.8	0.1420331833875907
+26.85	0.14938202571787756
+26.9	0.15713014572849
+26.95	0.16529588826092845
+27.0	0.17389759815669195
+27.05	0.18295362025728
+27.1	0.19248229940419204
+27.15	0.2025019804389269
+27.2	0.21303100820298543
+27.25	0.22408772753786652
+27.3	0.23569048328506956
+27.35	0.2478576202860941
+27.4	0.2606074833824388
+27.45	0.2739584174156048
+27.5	0.28791877866299176
+27.55	0.30247293391773145
+27.6	0.3176785915165655
+27.65	0.3336030577799969
+27.7	0.35031363902853185
+27.75	0.3678776415826731
+27.8	0.38636237176292476
+27.85	0.40583513588979075
+27.9	0.42636324028377365
+27.95	0.4480139912653802
+28.0	0.47085469515511325
+28.05	0.4949526582734765
+28.1	0.5203751869409741
+28.15	0.547189587478108
+28.2	0.575463166205386
+28.25	0.6052632294433102
+28.3	0.6366570835123847
+28.35	0.6697120347331132
+28.4	0.7044953894259977
+28.45	0.7410744539115467
+28.5	0.779516534510262
+28.55	0.8198889375426474
+28.6	0.8622589693292068
+28.65	0.9066939361904413
+28.7	0.9532611444468608
+28.75	1.0020279004189667
+28.8	1.0530615104272625
+28.85	1.1064292807922524
+28.9	1.1621985178344363
+28.95	1.220442820838801
+29.0	1.2813105969028458
+29.05	1.344994275368992
+29.1	1.4116863321217337
+29.15	1.4815792430455592
+29.2	1.554865484024973
+29.25	1.6317375309444637
+29.3	1.7123878596885258
+29.35	1.7970089461416523
+29.4	1.8857932661883314
+29.45	1.9789332957130696
+29.5	2.0766215106003543
+29.55	2.17905038673468
+29.6	2.2864124000005392
+29.65	2.3989000262824196
+29.7	2.516705741464829
+29.75	2.6400220214322556
+29.8	2.769041342069192
+29.85	2.903956179260132
+29.9	3.0449590088895606
+29.95	3.1922423068419903
+30.0	3.3459985490019064
+30.05	3.5064202112538014
+30.1	3.67369976948217
+30.15	3.8481258257842943
+30.2	4.030315851668986
+30.25	4.220491831202417
+30.3	4.418821900949248
+30.35	4.62547419747414
+30.4	4.840616857341738
+30.45	5.064418017116733
+30.5	5.29704581336377
+30.55	5.538668382647509
+30.6	5.789453861532612
+30.65	6.049570386583722
+30.7	6.319186094365533
+30.75	6.59846912144269
+30.8	6.887587604379853
+30.85	7.186709679741682
+30.9	7.496003484092816
+30.95	7.815637153997958
+31.0	8.145778826021749
+31.05	8.48659663672885
+31.1	8.838381447172468
+31.15	9.202540719144691
+31.2	9.578657865631007
+31.25	9.965746982683546
+31.3	10.362822166354473
+31.35	10.768897512695936
+31.4	11.182987117760067
+31.45	11.604105077599081
+31.5	12.031265488265106
+31.55	12.4634824458103
+31.6	12.899770046286816
+31.65	13.339142385746783
+31.7	13.78061356024242
+31.75	14.22319766582585
+31.8	14.665908798549234
+31.85	15.107761054464726
+31.9	15.547768529624452
+31.95	15.984945320080634
+32.0	16.418305521885394
+32.05	16.846853130042344
+32.1	17.26790395904011
+32.15	17.67779114457618
+32.2	18.073459942546645
+32.25	18.451855608847392
+32.3	18.80992339937447
+32.35	19.144608570023912
+32.4	19.452856376691617
+32.45	19.731612075273663
+32.5	19.977820921665952
+32.55	20.188428171764528
+32.6	20.3603790814654
+32.65	20.49061890666451
+32.7	20.57609290325788
+32.75	20.613746327141484
+32.8	20.60052443421132
+32.85	20.533372480363372
+32.9	20.40923572149365
+32.95	20.225059413498105
+33.0	19.977788812272784
+33.05	19.66676008778088
+33.1	19.296097281984412
+33.15	18.868237530680005
+33.2	18.38557022157783
+33.25	17.850484742388332
+33.3	17.265370480821748
+33.35	16.63261682458828
+33.4	15.954613161398411
+33.45	15.233748878962253
+33.5	14.472631126227892
+33.55	13.679991845426963
+33.6	12.864348805807072
+33.65	12.032403438973647
+33.7	11.190857176531647
+33.75	10.346411450086508
+33.8	9.505767691243308
+33.85	8.675627331607124
+33.9	7.862691802783389
+33.95	7.073662536377071
+34.0	6.315240963993592
+34.05	5.595309345647246
+34.1	4.921859167529581
+34.15	4.295483750671969
+34.2	3.7162475510341326
+34.25	3.1842150245761363
+34.3	2.699450627257795
+34.35	2.2620188150389513
+34.4	1.8719840438796285
+34.45	1.529410769739616
+34.5	1.2343634485789126
+34.55	0.9869006397989931
+34.6	0.7815809844341698
+34.65	0.6109891743701285
+34.7	0.4712762378333173
+34.75	0.35859320305027276
+34.8	0.26909109824747063
+34.85	0.19873408214310903
+34.9	0.14397332902647025
+34.95	0.10246996892547482
+35.0	0.07192728091301866
+35.05	0.05004854406197936
+35.1	0.03453703744523951
+35.15	0.023182054003042168
+35.2	0.014938404781953317
+35.25	0.009290470880887124
+35.3	0.005678771789230219
+35.35	0.003543826996370389
+35.4	0.0023261559916961877
+35.45	0.0014718069186605
+35.5	0.0008193068414955804
+35.55	0.0004119109807310751
+35.6	0.00019087475332052196
+35.65	9.745357621753578e-5
+35.7	7.290286637563057e-5
+35.75	5.847804074836544e-5
+35.8	2.841247648856679e-5
+35.85	1.0361935180105783e-5
+35.9	1.8160283331824833e-6
+35.95	-3.549595404496574e-7
+36.0	7.192560709664645e-7
+36.05	1.9089596791869836e-6
+36.1	9.923000893459772e-7
+36.15	1.6212626992685651e-7
+36.2	-2.766217959471011e-7
+36.25	-4.181047554768823e-7
+36.3	-3.5648325586367654e-7
+36.35	-1.8591794430855596e-7
+36.4	-5.694680126816988e-10
+36.45	1.0540152582286967e-7
+36.5	5.877327065986533e-8
+36.55	-6.239019919981164e-9
+36.6	-4.9123791863240605e-8
+36.65	-7.319477942358836e-8
+36.7	-8.176571685471941e-8
+36.75	-7.81503384103188e-8
+36.8	-6.566237834407583e-8
+36.85	-4.761557090967482e-8
+36.9	-2.7323650360808442e-8
+36.95	-8.100350951159016e-9
+37.0	6.740593065580493e-9
+37.05	1.3885447435725838e-8
+37.1	1.0067323603178002e-8
+37.15	1.697301468326906e-9
+37.2	-5.00787655272815e-9
+37.25	-1.0185988185352414e-8
+37.3	-1.397481115491411e-8
+37.35	-1.651212318678064e-8
+37.4	-1.7935702006318462e-8
+37.45	-1.8383325338895163e-8
+37.5	-1.7992770909877763e-8
+37.55	-1.6901816444633553e-8
+37.6	-1.52482396685294e-8
+37.65	-1.3169818306932935e-8
+37.7	-1.0804330085210747e-8
+37.75	-8.289552728730676e-9
+37.8	-5.763263962859547e-9
+37.85	-3.363241512964171e-9
+37.9	-1.2272631044124011e-9
+37.95	5.068935374291914e-10
+38.0	1.7014506871929536e-9
+38.05	2.218630619512017e-9
+38.1	1.9716664405816655e-9
+38.15	1.5657788077311757e-9
+38.2	1.2119602340273852e-9
+38.25	9.068606515725605e-10
+38.3	6.471299924688145e-10
+38.35	4.294181888182857e-10
+38.4	2.503751727232019e-10
+38.45	1.0665087628567785e-10
+38.5	-5.1047683920800536e-12
+38.55	-8.824182920792646e-11
+38.6	-1.4611037405970277e-10
+38.65	-1.8206047084522614e-10
+38.7	-1.9944218746234075e-10
+38.75	-2.0160559180887615e-10
+38.8	-1.9190075178266892e-10
+38.85	-1.7367773528154975e-10
+38.9	-1.5028661020335788e-10
+38.95	-1.2507744444592053e-10
+39.0	-1.0140030590707907e-10
+39.05	-8.26052624846633e-11
+39.1	-7.204238207650585e-11
+39.15	-7.30617325804437e-11
+39.2	-8.901338189431143e-11
+39.25	-1.1843887220022445e-10
+39.3	-1.4832320983357539e-10
+39.35	-1.777387661420899e-10
+39.4	-2.0674635271727452e-10
+39.45	-2.354067811506525e-10
+39.5	-2.637808630337305e-10
+39.55	-2.919294099580277e-10
+39.6	-3.1991323351506286e-10
+39.65	-3.4779314529634303e-10
+39.7	-3.756299568933911e-10
+39.75	-4.0348447989771404e-10
+39.8	-4.314175259008308e-10
+39.85	-4.594899064942602e-10
+39.9	-4.877624332695093e-10
+39.95	-5.162959178181008e-10
+40.0	-5.451511717315421e-10
+40.05	-5.743890066013516e-10
+40.1	-6.040702340190488e-10
+40.15	-6.3425566557614e-10
+40.2	-6.650061128641489e-10
+40.25	-6.963823874745815e-10
+40.3	-7.284453009989573e-10
+40.35	-7.612556650287956e-10
+40.4	-7.948742911556024e-10
+40.45	-8.293619909709019e-10
+40.5	-8.647795760661994e-10
+40.55	-9.011878580330148e-10
+40.6	-9.38647648462868e-10
+40.65	-9.772197589472638e-10
+40.7	-1.0169650010777275e-9
+40.75	-1.0579441864457631e-9
+40.8	-1.1002181266428913e-9
+40.85	-1.1438476332606328e-9
+40.9	-1.1888935178904905e-9
+40.95	-1.2354165921239922e-9
+41.0	-1.2834776675526395e-9
+41.05	-1.3331375557679534e-9
+41.1	-1.3854527349412998e-9
+41.15	-1.4424430509508524e-9
+41.2	-1.5045369479863534e-9
+41.25	-1.5721299108955784e-9
+41.3	-1.6456174245263288e-9
+41.35	-1.7253949737264096e-9
+41.4	-1.8118580433435922e-9
+41.45	-1.905402118225693e-9
+41.5	-2.0064226832204798e-9
+41.55	-2.115315223175759e-9
+41.6	-2.23247522293934e-9
+41.65	-2.3582981673589822e-9
+41.7	-2.4931795412825165e-9
+41.75	-2.637514829557696e-9
+41.8	-2.7916995170323346e-9
+41.85	-2.956129088554248e-9
+41.9	-3.131199028971183e-9
+41.95	-3.317304823130983e-9
+42.0	-3.5148419558813874e-9
+42.05	-3.724205912070216e-9
+42.1	-3.945792176545295e-9
+42.15	-4.179996234154352e-9
+42.2	-4.42721356974525e-9
+42.25	-4.6878396681657115e-9
+42.3	-4.962270014263561e-9
+42.35	-5.250900092886636e-9
+42.4	-5.55412538888265e-9
+42.45	-5.872341387099479e-9
+42.5	-6.205943572384828e-9
+42.55	-6.555327429586531e-9
+42.6	-6.920888443552437e-9
+42.65	-7.303022099130237e-9
+42.7	-7.702123881167829e-9
+42.75	-8.118589274512894e-9
+42.8	-8.552813764013282e-9
+42.85	-9.005192834516851e-9
+42.9	-9.476121970871267e-9
+42.95	-9.965996657924456e-9
+43.0	-1.0518853981869387e-8
+43.05	-1.1234800074879036e-8
+43.1	-1.2131045103151822e-8
+43.15	-1.3221917892308879e-8
+43.2	-1.4521747267971889e-8
+43.25	-1.6044862055761884e-8
+43.3	-1.7805591081300405e-8
+43.35	-1.9818263170209163e-8
+43.4	-2.2097207148108987e-8
+43.45	-2.4656751840621858e-8
+43.5	-2.7511226073368482e-8
+43.55	-3.067495867197057e-8
+43.6	-3.416227846205004e-8
+43.65	-3.798751426922728e-8
+43.7	-4.216499491912474e-8
+43.75	-4.67090492373627e-8
+43.8	-5.163400604956304e-8
+43.85	-5.695419418134794e-8
+43.9	-6.268394245833733e-8
+43.95	-6.88375797061541e-8
+44.0	-7.542943475041806e-8
+44.05	-8.247383641675137e-8
+44.1	-8.998511353077642e-8
+44.15	-9.797759491811257e-8
+44.2	-1.0646560940438346e-7
+44.25	-1.1546348581520803e-7
+44.3	-1.24985552976209e-7
+44.35	-1.3504613971300902e-7
+44.4	-1.4565957485122662e-7
+44.45	-1.5684018721648633e-7
+44.5	-1.6860230563440633e-7
+44.55	-1.8096025893060949e-7
+44.6	-1.939283759307191e-7
+44.65	-2.0752098546035293e-7
+44.7	-2.2175241634513624e-7
+44.75	-2.3663699741068627e-7
+44.8	-2.521890574826264e-7
+44.85	-2.684229253865804e-7
+44.9	-2.8535292994816517e-7
+44.95	-3.0299339999300667e-7
+45.0	-3.213586643467214e-7
+45.05	-3.4046305183493326e-7
+45.1	-3.603208912832664e-7
+45.15	-3.8094651151733657e-7
+45.2	-4.023542413627711e-7
+45.25	-4.245584096451854e-7
+45.3	-4.475733451902035e-7
+45.35	-4.7141337682345047e-7
+45.4	-4.960928333705409e-7
+45.45	-5.216315967722299e-7
+45.5	-5.48180386179763e-7
+45.55	-5.759023510773636e-7
+45.6	-6.049319141910751e-7
+45.65	-6.354034982469291e-7
+45.7	-6.674515259709735e-7
+45.75	-7.012104200892394e-7
+45.8	-7.368146033277705e-7
+45.85	-7.743984984126117e-7
+45.9	-8.140965280697923e-7
+45.95	-8.560431150253629e-7
+46.0	-9.003726820053514e-7
+46.05	-9.472196517358032e-7
+46.1	-9.96718446942765e-7
+46.15	-1.0490034903522621e-6
+46.2	-1.1042092046903495e-6
+46.25	-1.1624700126830507e-6
+46.3	-1.2239203370564133e-6
+46.35	-1.288694600536486e-6
+46.4	-1.3569272258492901e-6
+46.45	-1.4287526357208848e-6
+46.5	-1.504305252877289e-6
+46.55	-1.5837195000445524e-6
+46.6	-1.6671297999487263e-6
+46.65	-1.7546705753158275e-6
+46.7	-1.8464762488719199e-6
+46.75	-1.942681243343017e-6
+46.8	-2.0434199814551715e-6
+46.85	-2.1488268859344384e-6
+46.9	-2.25903459843318e-6
+46.95	-2.374282022967995e-6
+47.0	-2.495024163628297e-6
+47.05	-2.62173662456376e-6
+47.1	-2.7548950099240605e-6
+47.15	-2.894974923858819e-6
+47.2	-3.0424519705177325e-6
+47.25	-3.1978017540504166e-6
+47.3	-3.3614998786065494e-6
+47.35	-3.5340219483358142e-6
+47.4	-3.7158435673878194e-6
+47.45	-3.907440339912274e-6
+47.5	-4.10928787005878e-6
+47.55	-4.321861761977024e-6
+47.6	-4.545637619816695e-6
+47.65	-4.781091047727388e-6
+47.7	-5.028697649858826e-6
+47.75	-5.288933030360596e-6
+47.8	-5.562272793382392e-6
+47.85	-5.849192543073912e-6
+47.9	-6.150167883584733e-6
+47.95	-6.465685738400373e-6
+48.0	-6.7968501097875715e-6
+48.05	-7.145681758174973e-6
+48.1	-7.5142753168289336e-6
+48.15	-7.904725419015655e-6
+48.2	-8.319126698001555e-6
+48.25	-8.75957378705282e-6
+48.3	-9.22816131943581e-6
+48.35	-9.726983928416907e-6
+48.4	-1.0258136247262273e-5
+48.45	-1.0823712909238363e-5
+48.5	-1.1425808547611325e-5
+48.55	-1.2066517795647538e-5
+48.6	-1.2747935286613406e-5
+48.65	-1.3472155653775048e-5
+48.7	-1.4241273530398967e-5
+48.75	-1.505738354975125e-5
+48.8	-1.5922580345098313e-5
+48.85	-1.6838958549706593e-5
+48.9	-1.7808612796842136e-5
+48.95	-1.883363771977152e-5
+49.0	-1.991612795176076e-5
+49.05	-2.10581781260763e-5
+49.1	-2.226188287598463e-5
+49.15	-2.35293368347517e-5
+49.2	-2.486263463564419e-5
+49.25	-2.626387091192802e-5
+49.3	-2.773514029686967e-5
+49.35	-2.9278537423735685e-5
+49.4	-3.0896156925791924e-5
+49.45	-3.259009343630514e-5
+49.5	-3.4362441588541186e-5
+49.55	-3.621529601576658e-5
+49.6	-3.815075135124791e-5
+49.65	-4.017090222825095e-5
+49.7	-4.2277843280042545e-5
+49.75	-4.447366913988843e-5
+49.8	-4.676100787079892e-5
+49.85	-4.9158332010406336e-5
+49.9	-5.1678658852762816e-5
+49.95	-5.4328420253736615e-5
+50.0	-5.711404806919452e-5
diff --git a/test/julia_dde/test_basic_check_8.txt b/test/julia_dde/test_basic_check_8.txt
new file mode 100644
index 00000000..03be058d
--- /dev/null
+++ b/test/julia_dde/test_basic_check_8.txt
@@ -0,0 +1,101 @@
+0.0	1.2
+0.1	0.9599999999995766
+0.2	0.7199999999999984
+0.3	0.4920000000275244
+0.4	0.29333333335795136
+0.5	0.1409333335308199
+0.6	0.033896296463601816
+0.7	-0.03317518502318162
+0.8	-0.06821925912167397
+0.9	-0.08014353322730915
+1.0	-0.07701983861422906
+1.1	-0.0654722229623731
+1.2	-0.05045085349638549
+1.3	-0.035289552691150715
+1.4	-0.02195877308646844
+1.5	-0.011381793510271429
+1.6	-0.0037517722095308094
+1.7	0.0011898937000099639
+1.8	0.003933781500679573
+1.9	0.005044443226068331
+2.0	0.005058526960046457
+2.1	0.004428643610077466
+2.2	0.003501092643886467
+2.3	0.0025162283890649896
+2.4	0.0016220398912455607
+2.5	0.0008937555703458587
+2.6	0.0003545026061102214
+2.7	-6.021088477003101e-6
+2.8	-0.00021639564671320946
+2.9	-0.00031229906801823225
+3.0	-0.00032906674579850654
+3.1	-0.000297375207694315
+3.2	-0.00024130136430728218
+3.3	-0.00017802281541228886
+3.4	-0.00011850790244016358
+3.5	-6.870395777800505e-5
+3.6	-3.08729627261438e-5
+3.7	-4.8219479059382635e-6
+3.8	1.1040713112133792e-5
+3.9	1.893868175148974e-5
+4.0	2.1183148188046327e-5
+4.1	1.98275447496031e-5
+4.2	1.6521478954025237e-5
+4.3	1.2505551567813808e-5
+4.4	8.575476638211887e-6
+4.5	5.191726552570755e-6
+4.6	2.5552410383856707e-6
+4.7	6.890011193102985e-7
+4.8	-4.90632770177942e-7
+4.9	-1.1199076718777521e-6
+5.0	-1.3467796787535478e-6
+5.1	-1.3094949262271898e-6
+5.2	-1.1233718764641545e-6
+5.3	-8.74180306246676e-7
+5.4	-6.181169983854771e-7
+5.5	-3.883781367637394e-7
+5.6	-2.0528204214353039e-7
+5.7	-7.218293218941051e-8
+5.8	1.495077570387544e-8
+5.9	6.409244617426073e-8
+6.0	8.481977044724017e-8
+6.1	8.636673420998494e-8
+6.2	7.646609915620887e-8
+6.3	6.098239819884763e-8
+6.4	4.425022722636792e-8
+6.5	2.883985554104184e-8
+6.6	1.613411924581046e-8
+6.7	6.637122673450674e-9
+6.8	2.1373987855654678e-10
+6.9	-3.628443633800087e-9
+7.0	-5.521024228007085e-9
+7.1	-6.08050850644319e-9
+7.2	-5.81787672661891e-9
+7.3	-5.111981266902379e-9
+7.4	-4.24678014085044e-9
+7.5	-3.5124055601286853e-9
+7.6	-3.306632059341181e-9
+7.7	-2.9351111500057813e-9
+7.8	-2.2808466389605966e-9
+7.9	-1.5908742427106886e-9
+8.0	-9.981002021452451e-10
+8.1	-5.552644429934129e-10
+8.2	-2.6376382441024117e-10
+8.3	-9.733547567272958e-11
+8.4	-2.0600221028350734e-11
+8.5	-2.4660926630226497e-12
+8.6	-2.4391931740811084e-11
+8.7	-8.351107766773117e-11
+8.8	-1.9061514543636015e-10
+8.9	-3.6299789097840117e-10
+9.0	-6.121591648798676e-10
+9.1	-9.259727156607687e-10
+9.2	-1.0058033091378092e-9
+9.3	-8.732061539788326e-10
+9.4	-7.068814738245494e-10
+9.5	-5.39648760551056e-10
+9.6	-3.608859726598022e-10
+9.7	-1.7608730711911611e-10
+9.8	-2.3539544665121147e-11
+9.9	5.18830314483592e-11
+10.0	6.780514325394838e-11
diff --git a/test/julia_dde/test_basic_check_9.txt b/test/julia_dde/test_basic_check_9.txt
new file mode 100644
index 00000000..36c43baf
--- /dev/null
+++ b/test/julia_dde/test_basic_check_9.txt
@@ -0,0 +1,501 @@
+0.0	1.2
+0.1	1.1913803336218909
+0.2	1.1828464342012461
+0.3	1.1743974465820717
+0.4	1.1660325266930673
+0.5	1.1577508392776654
+0.6	1.1495515569159094
+0.7	1.1414338600244531
+0.8	1.1333969368565608
+0.9	1.125439983502108
+1.0	1.11756220388758
+1.1	1.1097628097760737
+1.2	1.1020410207672957
+1.3	1.0943960642975639
+1.4	1.0868271756398067
+1.5	1.079333597903563
+1.6	1.0719145820349825
+1.7	1.0645693851178095
+1.8	1.0572972581510276
+1.9	1.0500974756894166
+2.0	1.042969323512299
+2.1	1.0359120935999364
+2.2	1.0289250841335298
+2.3	1.0220075994952196
+2.4	1.0151589502680851
+2.5	1.0083784532361455
+2.6	1.0016654313843587
+2.7	0.9950192138986226
+2.8	0.9884391361657737
+2.9	0.9819245397735885
+3.0	0.9754747725107822
+3.1	0.9690891883670097
+3.2	0.9627671475328651
+3.3	0.9565080163998818
+3.4	0.9503111675605325
+3.5	0.9441759798082294
+3.6	0.9381018381373233
+3.7	0.9320881337431056
+3.8	0.9261342640218061
+3.9	0.9202396325705939
+4.0	0.9144036491875777
+4.1	0.9086257298718055
+4.2	0.9029052968232649
+4.3	0.8972417784428806
+4.4	0.8916346093325216
+4.5	0.8860832302949908
+4.6	0.8805870872809749
+4.7	0.8751456306653099
+4.8	0.8697583172256367
+4.9	0.8644246088405793
+5.0	0.8591439724658375
+5.1	0.8539158801341864
+5.2	0.8487398089554763
+5.3	0.8436152411166331
+5.4	0.838541663881658
+5.5	0.8335185695916273
+5.6	0.8285454556646932
+5.7	0.8236218245960826
+5.8	0.8187471839580989
+5.9	0.8139210464001188
+6.0	0.8091429296485967
+6.1	0.8045000663328911
+6.2	0.8000856864444652
+6.3	0.7959078461489202
+6.4	0.7919739863961017
+6.5	0.7882909329201001
+6.6	0.7848648962392509
+6.7	0.7817014716561339
+6.8	0.7788056392575742
+6.9	0.7761817639146412
+7.0	0.7738335952826495
+7.1	0.771764267801158
+7.2	0.7699763006939712
+7.3	0.7684715979691374
+7.4	0.7672514484189504
+7.5	0.7663165256199485
+7.6	0.7656668879329147
+7.7	0.765301978502877
+7.8	0.7652206252591082
+7.9	0.7654210409151257
+8.0	0.7659008229686917
+8.1	0.7666569537018133
+8.2	0.7676858001807424
+8.3	0.7689831142559754
+8.4	0.7705440325622539
+8.5	0.7723630765185638
+8.6	0.7744341523281365
+8.7	0.7767505509784474
+8.8	0.7793049482412178
+8.9	0.7820894046724118
+9.0	0.7850953656122401
+9.1	0.7883136611851577
+9.2	0.7917345062998644
+9.3	0.7953475006493039
+9.4	0.7991422217672364
+9.5	0.8031121244753257
+9.6	0.8072463572221119
+9.7	0.8115331047858
+9.8	0.8159608297236955
+9.9	0.8205182723722055
+10.0	0.8251944508468378
+10.1	0.8299786610422017
+10.2	0.8348604766320074
+10.3	0.8398297490690659
+10.4	0.8448766075852897
+10.5	0.8499914591916924
+10.6	0.8551649886783884
+10.7	0.8603881586145933
+10.8	0.8656522093486239
+10.9	0.870948659007898
+11.0	0.8762693034989344
+11.1	0.8816062165073529
+11.2	0.8869517494978751
+11.3	0.8922985317143232
+11.4	0.8976394701796201
+11.5	0.9029677496957899
+11.6	0.9082768328439581
+11.7	0.9135604599843521
+11.8	0.9188126492562988
+11.9	0.9240284001518926
+12.0	0.9292034205027914
+12.1	0.9343321860472654
+12.2	0.9394104742561977
+12.3	0.9444359312455402
+12.4	0.9494065831882625
+12.5	0.9543208363143533
+12.6	0.9591774769108184
+12.7	0.9639756713216824
+12.8	0.9687149659479878
+12.9	0.9733952872477949
+13.0	0.9780169417361826
+13.1	0.9825806159852475
+13.2	0.9870873766241047
+13.3	0.9915386703388873
+13.4	0.995936323872746
+13.5	1.0002825440258503
+13.6	1.0045799176553876
+13.7	1.0088314116755634
+13.8	1.0130402258394566
+13.9	1.0172070501687922
+14.0	1.0213339854530747
+14.1	1.0254238598495375
+14.2	1.0294791499638911
+14.3	1.0335019808503252
+14.4	1.0374941260115065
+14.5	1.0414570073985805
+14.6	1.0453916954111704
+14.7	1.049298908897378
+14.8	1.0531790151537825
+14.9	1.0570320299254414
+15.0	1.0608576174058906
+15.1	1.0646550902371434
+15.2	1.0684234095096918
+15.3	1.072161184762506
+15.4	1.0758666739830334
+15.5	1.0795377836072
+15.6	1.0831720685194102
+15.7	1.086766754596584
+15.8	1.0903207124475685
+15.9	1.0938297506834724
+16.0	1.0972874877497316
+16.1	1.1006875993433998
+16.2	1.1040238184131483
+16.3	1.1072899351592664
+16.4	1.1104797970336613
+16.5	1.1135873087398578
+16.6	1.1166064322329987
+16.7	1.119531186719844
+16.8	1.1223556486587727
+16.9	1.1250739517597803
+17.0	1.1276802869844813
+17.1	1.1301689025461068
+17.2	1.1325341039095067
+17.3	1.1347702537911484
+17.4	1.1368717721591168
+17.5	1.1388331362331152
+17.6	1.1406488804844641
+17.7	1.1423135966361022
+17.8	1.1438219336625859
+17.9	1.1451685977900892
+18.0	1.1463483524964047
+18.1	1.1473572134649492
+18.2	1.1481942735667123
+18.3	1.1488591062777442
+18.4	1.1493514867783459
+18.5	1.1496713919530686
+18.6	1.1498190003907138
+18.7	1.1497946923843334
+18.8	1.1495990499312294
+18.9	1.149232856732955
+19.0	1.1486970981953126
+19.1	1.1479929614283562
+19.2	1.1471218352463888
+19.3	1.1460853101679647
+19.4	1.1448851784158884
+19.5	1.1435234339172147
+19.6	1.1420022723032486
+19.7	1.1403240909095458
+19.8	1.1384914887759119
+19.9	1.1365072666464031
+20.0	1.1343744269693263
+20.1	1.1320961738972384
+20.2	1.129675913286946
+20.3	1.1271172526995077
+20.4	1.1244240014002305
+20.5	1.1216001502040933
+20.6	1.1186480466392204
+20.7	1.1155699713855358
+20.8	1.1123691395004491
+20.9	1.1090488335390565
+21.0	1.10561240355414
+21.1	1.1020632670961685
+21.2	1.0984049092132975
+21.3	1.0946408824513678
+21.4	1.090774806853908
+21.5	1.0868103699621317
+21.6	1.0827513268149398
+21.7	1.0786014999489193
+21.8	1.0743647793983433
+21.9	1.0700451226951717
+22.0	1.0656465548690504
+22.1	1.0611731684473116
+22.2	1.0566291234549747
+22.3	1.0520186474147442
+22.4	1.047346035347012
+22.5	1.0426156497698558
+22.6	1.0378319206990394
+22.7	1.0329993456480147
+22.8	1.028122489627917
+22.9	1.023205985147571
+23.0	1.0182545322134855
+23.1	1.0132728983298565
+23.2	1.0082659184985674
+23.3	1.0032384952191855
+23.4	0.9981955984889676
+23.5	0.9931422658028544
+23.6	0.9880836021534728
+23.7	0.9830247800311392
+23.8	0.9779710394238517
+23.9	0.9729276878172999
+24.0	0.9679000962768328
+24.1	0.9628924611125935
+24.2	0.957908768050749
+24.3	0.9529536831617005
+24.4	0.9480317902552851
+24.5	0.9431475954471874
+24.6	0.9383055271589413
+24.7	0.9335099371954515
+24.8	0.9287651807614069
+24.9	0.9240754814595606
+25.0	0.9194449749786636
+25.1	0.9148777842689583
+25.2	0.910378019542178
+25.3	0.9059497782715468
+25.4	0.90159714519178
+25.5	0.8973241922990837
+25.6	0.8931349788511549
+25.7	0.8890335513671822
+25.8	0.8850239436278442
+25.9	0.8811101766753112
+26.0	0.8772962588132442
+26.1	0.8735861856067951
+26.2	0.8699839398826074
+26.3	0.8664934917288143
+26.4	0.8631187984950413
+26.5	0.8598638047924039
+26.6	0.8567324424935098
+26.7	0.8537283642735556
+26.8	0.8508523315430913
+26.9	0.8481088039119128
+27.0	0.8455033320072018
+27.1	0.8430411823122855
+27.2	0.8407273371666362
+27.3	0.8385664947658713
+27.4	0.8365630691617543
+27.5	0.8347211902621933
+27.6	0.8330447038312422
+27.7	0.8315371714891004
+27.8	0.8302018707121122
+27.9	0.8290417948327675
+28.0	0.8280596530397015
+28.1	0.8272578703776948
+28.2	0.8266385877476735
+28.3	0.8262036619067089
+28.4	0.8259546654680174
+28.5	0.8258928869009613
+28.6	0.8260193305310479
+28.7	0.8263347165399301
+28.8	0.8268394809654055
+28.9	0.8275337757014182
+29.0	0.8284174684980568
+29.1	0.8294901429615552
+29.2	0.8307510985542932
+29.3	0.8321993505947959
+29.4	0.833833630257733
+29.5	0.8356523845739209
+29.6	0.8376537764303199
+29.7	0.8398356844692543
+29.8	0.8421930058904772
+29.9	0.8447191335522195
+30.0	0.8474090406532736
+30.1	0.8502579280959757
+30.2	0.8532600214445994
+30.3	0.8564084516120837
+30.4	0.8596963022925561
+30.5	0.8631166099613334
+30.6	0.8666623638749217
+30.7	0.8703265060710157
+30.8	0.8741019313684997
+30.9	0.8779814873674467
+31.0	0.8819579744491193
+31.1	0.8860241457759687
+31.2	0.8901727072916352
+31.3	0.8943963177209486
+31.4	0.8986875885699275
+31.5	0.9030390841257798
+31.6	0.9074433214569023
+31.7	0.9118929132684104
+31.8	0.9163818251206785
+31.9	0.9209038559739097
+32.0	0.9254528809327589
+32.1	0.9300230686386423
+32.2	0.9346088812697381
+32.3	0.939205074540986
+32.4	0.9438066977040886
+32.5	0.9484090935475091
+32.6	0.9530078983964734
+32.7	0.9575990421129682
+32.8	0.9621787480957427
+32.9	0.9667435332803087
+33.0	0.9712902081389382
+33.1	0.9758158766806666
+33.2	0.9803179364512893
+33.3	0.984794078533365
+33.4	0.9892422875462145
+33.5	0.9936604951529576
+33.6	0.9980450902090258
+33.7	1.0023952171143335
+33.8	1.0067105329541524
+33.9	1.0109906369754293
+34.0	1.0152350705867872
+34.1	1.0194433173585231
+34.2	1.0236148030226107
+34.3	1.0277488954726983
+34.4	1.0318449047641105
+34.5	1.0359020831138466
+34.6	1.0399196249005815
+34.7	1.0438966666646656
+34.8	1.0478322871081245
+34.9	1.05172550709466
+35.0	1.0555752896496482
+35.1	1.0593805399601417
+35.2	1.063140105374868
+35.3	1.0668527754042294
+35.4	1.0705172817203046
+35.5	1.0741322981568484
+35.6	1.0776964407092886
+35.7	1.0812082675347303
+35.8	1.084666278951954
+35.9	1.088068917441416
+36.0	1.0914145676452458
+36.1	1.094701875725707
+36.2	1.097927448184176
+36.3	1.101086590800616
+36.4	1.1041746413655324
+36.5	1.1071869696799714
+36.6	1.1101189775555211
+36.7	1.1129660988143113
+36.8	1.115723799289013
+36.9	1.1183875768228393
+37.0	1.1209529612695446
+37.1	1.1234155144934246
+37.2	1.1257708303693166
+37.3	1.1280145347826
+37.4	1.1301422856291954
+37.5	1.1321497728155647
+37.6	1.134032718258712
+37.7	1.1357868758861822
+37.8	1.1374080316360624
+37.9	1.1388920034569812
+38.0	1.1402346413081081
+38.1	1.141431827159155
+38.2	1.142479474990375
+38.3	1.1433735307925628
+38.4	1.1441099725670543
+38.5	1.1446848103257277
+38.6	1.1450945626546407
+38.7	1.1453381815859058
+38.8	1.1454154230864997
+38.9	1.1453262358394716
+39.0	1.145070766406141
+39.1	1.1446493592260973
+39.2	1.1440625566172
+39.3	1.143311098775579
+39.4	1.1423959237756336
+39.5	1.1413181675700332
+39.6	1.140079163989718
+39.7	1.1386804447438974
+39.8	1.1371237394200515
+39.9	1.1354109754839297
+40.0	1.1335442782795522
+40.1	1.1315259710292085
+40.2	1.1293585748334591
+40.3	1.1270448086711335
+40.4	1.124587589399332
+40.5	1.1219900317534244
+40.6	1.1192554483470514
+40.7	1.1163873496721226
+40.8	1.1133894440988188
+40.9	1.1102656378755893
+41.0	1.1070200351291546
+41.1	1.1036569378645058
+41.2	1.1001808459649032
+41.3	1.0965964571918767
+41.4	1.092908667185227
+41.5	1.0891221654555732
+41.6	1.085238871914666
+41.7	1.081262592540593
+41.8	1.077197701502359
+41.9	1.0730485720762972
+42.0	1.0688195766460709
+42.1	1.064515086702672
+42.2	1.0601394728444222
+42.3	1.0556971047769725
+42.4	1.0511923513133024
+42.5	1.0466295803737211
+42.6	1.042013158985867
+42.7	1.0373474532847085
+42.8	1.032636828512542
+42.9	1.0278856490189936
+43.0	1.0230982782610187
+43.1	1.0182790788029026
+43.2	1.0134324123162584
+43.3	1.0085626395800302
+43.4	1.0036741204804898
+43.5	0.9987712140112388
+43.6	0.9938582782732083
+43.7	0.9889396704746587
+43.8	0.9840197469311793
+43.9	0.9791028630656887
+44.0	0.9741933734084348
+44.1	0.9692956315969946
+44.2	0.9644139903762744
+44.3	0.9595528015985111
+44.4	0.9547164162232679
+44.5	0.94990918431744
+44.6	0.9451354550552505
+44.7	0.9403995767182515
+44.8	0.9357058966953263
+44.9	0.9310587614826841
+45.0	0.9264625166838675
+45.1	0.9219215070097437
+45.2	0.9174400762785138
+45.3	0.9130225674157044
+45.4	0.9086733224541736
+45.5	0.9043966825341077
+45.6	0.9001940922795764
+45.7	0.8960654669696551
+45.8	0.8920165942169074
+45.9	0.8880532253925764
+46.0	0.8841809479279281
+46.1	0.8804051853142504
+46.2	0.876731197102853
+46.3	0.8731640789050679
+46.4	0.8697087623922485
+46.5	0.866370015295771
+46.6	0.8631524414070331
+46.7	0.8600604805774548
+46.8	0.8570984087184782
+46.9	0.8542703378015669
+47.0	0.8515802158582071
+47.1	0.8490318269799068
+47.2	0.8466287913181959
+47.3	0.8443745650846267
+47.4	0.842272440550773
+47.5	0.8403255460482308
+47.6	0.8385368459686185
+47.7	0.8369091407635763
+47.8	0.8354450669447662
+47.9	0.8341470970838722
+48.0	0.833017539812601
+48.1	0.8320585398226801
+48.2	0.8312720778658602
+48.3	0.830659970753914
+48.4	0.8302238713586347
+48.5	0.8299652686118398
+48.6	0.829885487505367
+48.7	0.8299856890910766
+48.8	0.8302668704808515
+48.9	0.8307298648465955
+49.0	0.8313753414202355
+49.1	0.83220380549372
+49.2	0.8332155984190188
+49.3	0.8344108976081249
+49.4	0.8357897165330536
+49.5	0.8373519047258399
+49.6	0.8390971477785436
+49.7	0.8410249673432448
+49.8	0.8431346986042163
+49.9	0.8454183595200698
+50.0	0.84786698489988
diff --git a/test/julia_dde/test_basic_numerical_check_1.txt b/test/julia_dde/test_basic_numerical_check_1.txt
new file mode 100644
index 00000000..a0f079db
--- /dev/null
+++ b/test/julia_dde/test_basic_numerical_check_1.txt
@@ -0,0 +1,351 @@
+2.0	1.0
+2.01	1.005
+2.02	1.01
+2.03	1.015
+2.04	1.02
+2.05	1.025
+2.06	1.03
+2.07	1.035
+2.08	1.04
+2.09	1.045
+2.1	1.05
+2.11	1.055
+2.12	1.06
+2.13	1.065
+2.14	1.07
+2.15	1.075
+2.16	1.08
+2.17	1.085
+2.18	1.09
+2.19	1.095
+2.2	1.1
+2.21	1.105
+2.22	1.11
+2.23	1.115
+2.24	1.12
+2.25	1.125
+2.26	1.13
+2.27	1.135
+2.28	1.14
+2.29	1.145
+2.3	1.15
+2.31	1.1550000000000002
+2.32	1.16
+2.33	1.165
+2.34	1.17
+2.35	1.175
+2.36	1.18
+2.37	1.185
+2.38	1.19
+2.39	1.195
+2.4	1.2
+2.41	1.205
+2.42	1.21
+2.43	1.215
+2.44	1.22
+2.45	1.225
+2.46	1.23
+2.47	1.235
+2.48	1.24
+2.49	1.245
+2.5	1.25
+2.51	1.255
+2.52	1.26
+2.53	1.265
+2.54	1.27
+2.55	1.275
+2.56	1.28
+2.57	1.285
+2.58	1.29
+2.59	1.295
+2.6	1.3
+2.61	1.305
+2.62	1.3100000000000003
+2.63	1.3150000000000002
+2.64	1.3200000000000003
+2.65	1.325
+2.66	1.33
+2.67	1.335
+2.68	1.34
+2.69	1.345
+2.7	1.35
+2.71	1.355
+2.72	1.3599999999999999
+2.73	1.3650000000000002
+2.74	1.37
+2.75	1.375
+2.76	1.38
+2.77	1.3850000000000002
+2.78	1.39
+2.79	1.395
+2.8	1.4
+2.81	1.405
+2.82	1.41
+2.83	1.415
+2.84	1.4199999999999997
+2.85	1.425
+2.86	1.43
+2.87	1.435
+2.88	1.44
+2.89	1.445
+2.9	1.45
+2.91	1.455
+2.92	1.46
+2.93	1.4649999999999999
+2.94	1.47
+2.95	1.475
+2.96	1.48
+2.97	1.485
+2.98	1.49
+2.99	1.495
+3.0	1.5
+3.01	1.505
+3.02	1.51
+3.03	1.515
+3.04	1.52
+3.05	1.525
+3.06	1.53
+3.07	1.535
+3.08	1.54
+3.09	1.545
+3.1	1.55
+3.11	1.555
+3.12	1.56
+3.13	1.565
+3.14	1.57
+3.15	1.575
+3.16	1.58
+3.17	1.585
+3.18	1.59
+3.19	1.5949999999999998
+3.2	1.6000000000000003
+3.21	1.605
+3.22	1.61
+3.23	1.6150000000000002
+3.24	1.62
+3.25	1.625
+3.26	1.63
+3.27	1.635
+3.28	1.64
+3.29	1.645
+3.3	1.65
+3.31	1.6550000000000002
+3.32	1.66
+3.33	1.665
+3.34	1.67
+3.35	1.675
+3.36	1.68
+3.37	1.685
+3.38	1.69
+3.39	1.695
+3.4	1.7
+3.41	1.705
+3.42	1.71
+3.43	1.715
+3.44	1.72
+3.45	1.725
+3.46	1.73
+3.47	1.735
+3.48	1.74
+3.49	1.745
+3.5	1.75
+3.51	1.755
+3.52	1.76
+3.53	1.765
+3.54	1.77
+3.55	1.7749999999999997
+3.56	1.78
+3.57	1.785
+3.58	1.7900000000000003
+3.59	1.795
+3.6	1.8
+3.61	1.8049999999999997
+3.62	1.81
+3.63	1.815
+3.64	1.82
+3.65	1.825
+3.66	1.83
+3.67	1.835
+3.68	1.84
+3.69	1.845
+3.7	1.85
+3.71	1.855
+3.72	1.86
+3.73	1.865
+3.74	1.87
+3.75	1.875
+3.76	1.88
+3.77	1.8850000000000002
+3.78	1.89
+3.79	1.895
+3.8	1.9000000000000001
+3.81	1.9050000000000002
+3.82	1.91
+3.83	1.915
+3.84	1.92
+3.85	1.925
+3.86	1.93
+3.87	1.935
+3.88	1.94
+3.89	1.945
+3.9	1.9499999999999997
+3.91	1.955
+3.92	1.9599999999999997
+3.93	1.965
+3.94	1.97
+3.95	1.975
+3.96	1.98
+3.97	1.985
+3.98	1.99
+3.99	1.995
+4.0	2.0
+4.01	2.0100250414697625
+4.02	2.02010033362734
+4.03	2.0302261206108767
+4.04	2.04040265467344
+4.05	2.050630196098499
+4.06	2.060909005169523
+4.07	2.07123934216998
+4.08	2.081621467383338
+4.09	2.092055626503915
+4.1	2.102542060618594
+4.11	2.1130810415242953
+4.12	2.1236728415794506
+4.13	2.134317733142494
+4.14	2.145015988571857
+4.15	2.1557678802259743
+4.16	2.1665736804632765
+4.17	2.1774336616421968
+4.18	2.1883480757652336
+4.19	2.199317136842254
+4.2	2.2103411299215328
+4.21	2.2214203456065267
+4.22	2.2325550745006892
+4.23	2.2437456072074764
+4.24	2.2549922343303406
+4.25	2.266295246472737
+4.26	2.2776549342381207
+4.27	2.2890715882299455
+4.28	2.300545499051667
+4.29	2.3120769573067386
+4.3	2.323666253598615
+4.31	2.335313657991987
+4.32	2.347019339067631
+4.33	2.3587835885299824
+4.34	2.3706067225386542
+4.35	2.382489057253257
+4.36	2.3944309088334017
+4.37	2.4064325934386983
+4.38	2.4184944272287576
+4.39	2.4306167263631906
+4.4	2.44279980700161
+4.41	2.4550439853036226
+4.42	2.467349577428842
+4.43	2.4797168995368777
+4.44	2.492146267787343
+4.45	2.5046379983398452
+4.46	2.517192407353997
+4.47	2.5298098109894083
+4.48	2.5424905070170105
+4.49	2.5552346012827103
+4.5	2.568042377913275
+4.51	2.5809141881657314
+4.52	2.5938503832971085
+4.53	2.6068513145644343
+4.54	2.619917333224734
+4.55	2.633048790535037
+4.56	2.646246037752371
+4.57	2.6595094261337646
+4.58	2.6728393069362437
+4.59	2.6862360314168363
+4.6	2.6996999508325707
+4.61	2.713231416440476
+4.62	2.7268307794975772
+4.63	2.740498391260903
+4.64	2.7542346029874816
+4.65	2.7680397659343416
+4.66	2.7819142313585083
+4.67	2.795858350517011
+4.68	2.809872474666877
+4.69	2.8239569550651353
+4.7	2.83811194357926
+4.71	2.85233754544585
+4.72	2.8666341570171343
+4.73	2.881002176398268
+4.74	2.8954420016944025
+4.75	2.909954031010692
+4.76	2.9245386624522913
+4.77	2.939196294124353
+4.78	2.9539273241320316
+4.79	2.968732150580479
+4.8	2.9836111715748492
+4.81	2.9985647852202963
+4.82	3.013593389621975
+4.83	3.0286973828850363
+4.84	3.043877163114635
+4.85	3.059133128415924
+4.86	3.074465676894059
+4.87	3.0898752066541912
+4.88	3.1053621158014746
+4.89	3.1209268024410632
+4.9	3.1365696646781114
+4.91	3.1522911006177714
+4.92	3.168091508365196
+4.93	3.1839712860255402
+4.94	3.1999308317039588
+4.95	3.2159704686922903
+4.96	3.232090069241931
+4.97	3.248289980212273
+4.98	3.2645706615158128
+4.99	3.280932573065044
+5.0	3.2973761747724644
+5.01	3.31390192655057
+5.02	3.3305102883118582
+5.03	3.3472017199688255
+5.04	3.363976681433966
+5.05	3.3808356326197773
+5.06	3.397779033438756
+5.07	3.4148073438033997
+5.08	3.431921023626201
+5.09	3.4491205328196584
+5.1	3.4664063312962683
+5.11	3.4837788789685287
+5.12	3.5012386357489316
+5.13	3.5187860615499758
+5.14	3.536421616284158
+5.15	3.5541457598639754
+5.16	3.5719589522019217
+5.17	3.5898616532104937
+5.18	3.607854322802189
+5.19	3.6259374208895045
+5.2	3.644111407384934
+5.21	3.6623767422009754
+5.22	3.680733885250124
+5.23	3.699183296444879
+5.24	3.717725069576682
+5.25	3.7363305415419448
+5.26	3.754987432320445
+5.27	3.7737041768170903
+5.28	3.7924892099367904
+5.29	3.811350966584451
+5.3	3.830297881664982
+5.31	3.8493383900832914
+5.32	3.8684809267442897
+5.33	3.887733926552881
+5.34	3.9071058244139762
+5.35	3.9266050552324834
+5.36	3.946240053913313
+5.37	3.9660192553613696
+5.38	3.9859510944815626
+5.39	4.006044006178801
+5.4	4.026306425357996
+5.41	4.04674678692405
+5.42	4.067373525781875
+5.43	4.088195076836379
+5.44	4.109219874992471
+5.45	4.130456355155058
+5.46	4.151912952229049
+5.47	4.173598101119351
+5.48	4.195520236730876
+5.49	4.217687793968527
+5.5	4.240109207737216
diff --git a/test/julia_dde/test_basic_numerical_check_2.txt b/test/julia_dde/test_basic_numerical_check_2.txt
new file mode 100644
index 00000000..3cc1ca7b
--- /dev/null
+++ b/test/julia_dde/test_basic_numerical_check_2.txt
@@ -0,0 +1,901 @@
+1.0	1.0
+1.01	1.01
+1.02	1.02
+1.03	1.03
+1.04	1.04
+1.05	1.05
+1.06	1.06
+1.07	1.07
+1.08	1.08
+1.09	1.09
+1.1	1.1
+1.11	1.11
+1.12	1.12
+1.13	1.13
+1.14	1.14
+1.15	1.15
+1.16	1.16
+1.17	1.17
+1.18	1.18
+1.19	1.19
+1.2	1.2
+1.21	1.21
+1.22	1.22
+1.23	1.23
+1.24	1.24
+1.25	1.25
+1.26	1.26
+1.27	1.27
+1.28	1.28
+1.29	1.29
+1.3	1.3
+1.31	1.31
+1.32	1.32
+1.33	1.33
+1.34	1.34
+1.35	1.35
+1.36	1.36
+1.37	1.37
+1.38	1.38
+1.39	1.39
+1.4	1.4
+1.41	1.41
+1.42	1.42
+1.43	1.43
+1.44	1.44
+1.45	1.45
+1.46	1.46
+1.47	1.47
+1.48	1.48
+1.49	1.49
+1.5	1.5
+1.51	1.51
+1.52	1.52
+1.53	1.53
+1.54	1.54
+1.55	1.55
+1.56	1.56
+1.57	1.57
+1.58	1.58
+1.59	1.59
+1.6	1.6
+1.61	1.61
+1.62	1.62
+1.63	1.63
+1.64	1.64
+1.65	1.65
+1.66	1.66
+1.67	1.67
+1.68	1.68
+1.69	1.69
+1.7	1.7
+1.71	1.71
+1.72	1.72
+1.73	1.73
+1.74	1.74
+1.75	1.75
+1.76	1.76
+1.77	1.77
+1.78	1.78
+1.79	1.79
+1.8	1.8
+1.81	1.81
+1.82	1.82
+1.83	1.83
+1.84	1.84
+1.85	1.85
+1.86	1.86
+1.87	1.87
+1.88	1.88
+1.89	1.89
+1.9	1.9
+1.91	1.91
+1.92	1.92
+1.93	1.93
+1.94	1.94
+1.95	1.95
+1.96	1.96
+1.97	1.97
+1.98	1.98
+1.99	1.99
+2.0	2.0
+2.01	2.01
+2.02	2.02
+2.03	2.03
+2.04	2.04
+2.05	2.05
+2.06	2.06
+2.07	2.07
+2.08	2.08
+2.09	2.09
+2.1	2.0999998807417106
+2.11	2.1099990125109187
+2.12	2.119997338303108
+2.13	2.129994899434663
+2.14	2.13999173722197
+2.15	2.1499878929814136
+2.16	2.15998340802938
+2.17	2.1699783236822543
+2.18	2.179972681256422
+2.19	2.1899665220682683
+2.2	2.1999598874341793
+2.21	2.2099528186705397
+2.22	2.2199453570937355
+2.23	2.2299375440201517
+2.24	2.2399294207661744
+2.25	2.2499210286481883
+2.26	2.259912408982579
+2.27	2.2699036030857327
+2.28	2.279894652274034
+2.29	2.2898855978638686
+2.3	2.299876481171622
+2.31	2.30986734351368
+2.32	2.3198582262064273
+2.33	2.32984917056625
+2.34	2.3398402179095323
+2.35	2.3498314095526625
+2.36	2.3598227868120225
+2.37	2.369814391004
+2.38	2.37980626344498
+2.39	2.3897984454513477
+2.4	2.399790978339489
+2.41	2.409783903425789
+2.42	2.4197772620266322
+2.43	2.429771095458406
+2.44	2.439765445037495
+2.45	2.449760352080284
+2.46	2.459755857903159
+2.47	2.469752003822506
+2.48	2.4797488311547093
+2.49	2.4897463812161553
+2.5	2.499744695323229
+2.51	2.5097438147923152
+2.52	2.5197437809398013
+2.53	2.5297446350820705
+2.54	2.53974641853551
+2.55	2.549749172616504
+2.56	2.5597529386414393
+2.57	2.5697577579266997
+2.58	2.579763671788672
+2.59	2.5897707215437404
+2.6	2.599778948508292
+2.61	2.6097883939987105
+2.62	2.6197990993313827
+2.63	2.6298111058226934
+2.64	2.639824454789028
+2.65	2.6498391875467724
+2.66	2.659855345412312
+2.67	2.6698729697020314
+2.68	2.679892101732317
+2.69	2.689912782819554
+2.7	2.6999350542801275
+2.71	2.7099589574304237
+2.72	2.7199845335868273
+2.73	2.7300149666457774
+2.74	2.740076185977242
+2.75	2.7501740836611828
+2.76	2.7603088296354907
+2.77	2.770480593838056
+2.78	2.7806895462067707
+2.79	2.790935856679524
+2.8	2.8012196951942063
+2.81	2.81154123168871
+2.82	2.8219006361009233
+2.83	2.832298078368739
+2.84	2.842733728430047
+2.85	2.853207756222737
+2.86	2.8637203316847
+2.87	2.8742716247538276
+2.88	2.8848618053680095
+2.89	2.8954910434651366
+2.9	2.9061595089830994
+2.91	2.916867371859789
+2.92	2.927614802033095
+2.93	2.9384019694409096
+2.94	2.9492290440211217
+2.95	2.960096195711623
+2.96	2.971003594450303
+2.97	2.981951410175054
+2.98	2.9929398128237654
+2.99	3.0039689723343277
+3.0	3.0150390586446325
+3.01	3.0261502416925694
+3.02	3.0373026914160297
+3.03	3.0484964862986823
+3.04	3.05973087738779
+3.05	3.0710058212414117
+3.06	3.0823215184164074
+3.07	3.0936781694696354
+3.08	3.1050759749579546
+3.09	3.116515135438225
+3.1	3.1279958514673045
+3.11	3.1395183236020525
+3.12	3.1510827523993283
+3.13	3.1626893384159906
+3.14	3.1743382822088986
+3.15	3.1860297843349117
+3.16	3.197764045350888
+3.17	3.209541265813687
+3.18	3.221361646280168
+3.19	3.2332253873071886
+3.2	3.2451326894516104
+3.21	3.2570837532702903
+3.22	3.2690787793200884
+3.23	3.2811179681578624
+3.24	3.2932015203404728
+3.25	3.3053296364247777
+3.26	3.3175025169676364
+3.27	3.329720362525908
+3.28	3.3419833736564515
+3.29	3.354291750916126
+3.3	3.36664569486179
+3.31	3.3790454060503032
+3.32	3.391491085038524
+3.33	3.4039829323833124
+3.34	3.4165211486415257
+3.35	3.4291059343700248
+3.36	3.4417374901256674
+3.37	3.4544160164653133
+3.38	3.4671417139458205
+3.39	3.4799147831240496
+3.4	3.492735424556858
+3.41	3.5056038388011057
+3.42	3.5185202264136515
+3.43	3.5314847879513547
+3.44	3.544497723971073
+3.45	3.557559235029667
+3.46	3.570669497262101
+3.47	3.583827683490499
+3.48	3.597033472871828
+3.49	3.6102870954560555
+3.5	3.623588781293147
+3.51	3.6369387604330705
+3.52	3.6503372629257926
+3.53	3.6637845188212803
+3.54	3.6772807581695015
+3.55	3.6908262110204215
+3.56	3.704421107424008
+3.57	3.7180656774302268
+3.58	3.7317601510890475
+3.59	3.745504758450434
+3.6	3.7592997295643547
+3.61	3.773145294480776
+3.62	3.787041683249666
+3.63	3.80098912592099
+3.64	3.8149878525447165
+3.65	3.8290380931708103
+3.66	3.8431400778492404
+3.67	3.857294036629972
+3.68	3.8715001995629743
+3.69	3.8857587966982114
+3.7	3.900070058085652
+3.71	3.9144342137752624
+3.72	3.9288514938170107
+3.73	3.9433221282608617
+3.74	3.9578463471567837
+3.75	3.972424380554743
+3.76	3.987056458504707
+3.77	4.001742811056642
+3.78	4.016483668260516
+3.79	4.031279260166296
+3.8	4.046129816823948
+3.81	4.0610355682834385
+3.82	4.075996744594735
+3.83	4.091013575807805
+3.84	4.1060862919726135
+3.85	4.12121512313913
+3.86	4.1364002993573195
+3.87	4.151642050677149
+3.88	4.166940607148587
+3.89	4.182295790533781
+3.9	4.197706378888936
+3.91	4.213172511990118
+3.92	4.228694463093983
+3.93	4.24427250545719
+3.94	4.259906912336394
+3.95	4.275597956988256
+3.96	4.29134591266943
+3.97	4.307151052636578
+3.98	4.323013650146354
+3.99	4.338933978455416
+4.0	4.354912310820422
+4.01	4.37094892049803
+4.02	4.387044080744897
+4.03	4.403198064817683
+4.04	4.419411145973042
+4.05	4.435683597467634
+4.06	4.452015692558114
+4.07	4.468407704501144
+4.08	4.4848599065533765
+4.09	4.501372571971471
+4.1	4.517945974012085
+4.11	4.53458038593188
+4.12	4.551276080987509
+4.13	4.56803333243563
+4.14	4.5848524135329
+4.15	4.601733597535981
+4.16	4.618677157701525
+4.17	4.635683367286193
+4.18	4.652752499546642
+4.19	4.669884827739529
+4.2	4.687080625121512
+4.21	4.704340164949247
+4.22	4.721663720479394
+4.23	4.7390515649686105
+4.24	4.756503971673553
+4.25	4.774021213850878
+4.26	4.791603564757244
+4.27	4.809251297649309
+4.28	4.826964685783732
+4.29	4.844744002417169
+4.3	4.862589520806277
+4.31	4.8805015142077135
+4.32	4.898480255878139
+4.33	4.916526019074206
+4.34	4.9346390770525765
+4.35	4.952819703069906
+4.36	4.971068170382854
+4.37	4.9893847522480765
+4.38	5.007769510122513
+4.39	5.026220866411237
+4.4	5.044738705633241
+4.41	5.0633233533753454
+4.42	5.08197513522438
+4.43	5.100694376767169
+4.44	5.119481403590542
+4.45	5.138336541281321
+4.46	5.157260115426332
+4.47	5.1762524516124016
+4.48	5.19531387542636
+4.49	5.214444712455029
+4.5	5.233645288285233
+4.51	5.252915928503802
+4.52	5.272256958697559
+4.53	5.2916687044533335
+4.54	5.311151491357949
+4.55	5.330705644998231
+4.56	5.350331490961006
+4.57	5.370029354833103
+4.58	5.389799562201344
+4.59	5.409642438652556
+4.6	5.429558309773565
+4.61	5.4495475011512005
+4.62	5.4696103383722825
+4.63	5.489747147023642
+4.64	5.509958252692101
+4.65	5.530243980964491
+4.66	5.550604657427632
+4.67	5.571040607668354
+4.68	5.5915521572734805
+4.69	5.612139631829841
+4.7	5.632803356924257
+4.71	5.6535436581435565
+4.72	5.674360861074567
+4.73	5.695255291304114
+4.74	5.716227274419021
+4.75	5.737277136006116
+4.76	5.758405201652225
+4.77	5.779611796944175
+4.78	5.80089724746879
+4.79	5.822261878812898
+4.8	5.843706016563322
+4.81	5.865229986306891
+4.82	5.886834113630431
+4.83	5.908518724120764
+4.84	5.930284143364721
+4.85	5.952130696949125
+4.86	5.974058710460806
+4.87	5.996068509486584
+4.88	6.018160419613288
+4.89	6.040334600260398
+4.9	6.0625891346687535
+4.91	6.084923695666421
+4.92	6.107338677599926
+4.93	6.129834474815789
+4.94	6.152411481660532
+4.95	6.175070092480674
+4.96	6.197810701622736
+4.97	6.220633703433242
+4.98	6.243539492258711
+4.99	6.266528462445665
+5.0	6.289601008340625
+5.01	6.312757524290112
+5.02	6.3359984046406455
+5.03	6.359324043738752
+5.04	6.382734835930947
+5.05	6.406231175563754
+5.06	6.429813456983694
+5.07	6.453482074537289
+5.08	6.477237422571058
+5.09	6.501079895431524
+5.1	6.525009887465207
+5.11	6.549027793018631
+5.12	6.573134006438314
+5.13	6.597328922070776
+5.14	6.621612934262543
+5.15	6.6459864373601345
+5.16	6.670449825710068
+5.17	6.695003493658868
+5.18	6.719647835553055
+5.19	6.744383245739153
+5.2	6.769210118563676
+5.21	6.794128848373153
+5.22	6.8191398295141
+5.23	6.844243456333042
+5.24	6.869440123176496
+5.25	6.894730224390985
+5.26	6.920114154323032
+5.27	6.945592307319155
+5.28	6.971165077725879
+5.29	6.996832859889722
+5.3	7.022596048157205
+5.31	7.048455036874851
+5.32	7.074410220389184
+5.33	7.100461993046717
+5.34	7.126610749193977
+5.35	7.152856883177484
+5.36	7.179200789343762
+5.37	7.205642862039326
+5.38	7.232183495610702
+5.39	7.258823084404409
+5.4	7.28556202276697
+5.41	7.312400705044905
+5.42	7.339339525584734
+5.43	7.36637887873298
+5.44	7.393518738450745
+5.45	7.4207566120939985
+5.46	7.44809241728296
+5.47	7.475526634517007
+5.48	7.503059744295525
+5.49	7.530692227117892
+5.5	7.558424563483486
+5.51	7.5862572338916925
+5.52	7.6141907188418925
+5.53	7.642225498833466
+5.54	7.670362054365791
+5.55	7.698600865938251
+5.56	7.726942414050225
+5.57	7.755387179201102
+5.58	7.783935641890251
+5.59	7.81258828261706
+5.6	7.841345581880909
+5.61	7.870208020181179
+5.62	7.89917607801725
+5.63	7.928250235888502
+5.64	7.957430974294317
+5.65	7.986718773734078
+5.66	8.016114114707165
+5.67	8.045617477712955
+5.68	8.075229343250834
+5.69	8.104950191820182
+5.7	8.134780503920377
+5.71	8.1647207600508
+5.72	8.194771440710834
+5.73	8.224933026399864
+5.74	8.255205997617265
+5.75	8.285590834862417
+5.76	8.316088018634703
+5.77	8.346698029433508
+5.78	8.37742134775821
+5.79	8.408258454108186
+5.8	8.43920982898282
+5.81	8.470275952881495
+5.82	8.501457306303594
+5.83	8.532754369748488
+5.84	8.564167623715566
+5.85	8.595697548704207
+5.86	8.627344625213793
+5.87	8.659109333743702
+5.88	8.690992154793316
+5.89	8.72299356886202
+5.9	8.755114056449191
+5.91	8.787354098054207
+5.92	8.819714174176454
+5.93	8.852194765315314
+5.94	8.884796351970165
+5.95	8.917519414640386
+5.96	8.95036443382536
+5.97	8.98333189002447
+5.98	9.016422263737097
+5.99	9.049636035462617
+6.0	9.082973685700413
+6.01	9.116434969439744
+6.02	9.1500167512682
+6.03	9.183719112838643
+6.04	9.21754264736258
+6.05	9.251487948051537
+6.06	9.285555608117033
+6.07	9.319746220770588
+6.08	9.35406037922372
+6.09	9.388498676687947
+6.1	9.423061706374792
+6.11	9.457750061495775
+6.12	9.492564335262406
+6.13	9.527505120886216
+6.14	9.562573011578714
+6.15	9.597768600551431
+6.16	9.633092481015876
+6.17	9.668545246183573
+6.18	9.70412748926604
+6.19	9.7398398034748
+6.2	9.775682782021368
+6.21	9.811657018117263
+6.22	9.847763104974005
+6.23	9.88400163580312
+6.24	9.920373203816117
+6.25	9.956878402224518
+6.26	9.993517824239847
+6.27	10.03029206307362
+6.28	10.067201711937361
+6.29	10.104247364042582
+6.3	10.141429612600804
+6.31	10.178749050823551
+6.32	10.21620627192234
+6.33	10.253801869108687
+6.34	10.291536435594114
+6.35	10.32941056459014
+6.36	10.367424849308287
+6.37	10.405579882960073
+6.38	10.443876258757014
+6.39	10.48231456991063
+6.4	10.520895409632447
+6.41	10.559619371133975
+6.42	10.598487047626739
+6.43	10.637499032322257
+6.44	10.67665591843205
+6.45	10.715958299167633
+6.46	10.75540676774053
+6.47	10.795001917362256
+6.48	10.834744341244338
+6.49	10.874634632598285
+6.5	10.914673384635622
+6.51	10.954861190567868
+6.52	10.99519864360654
+6.53	11.035686336963167
+6.54	11.076324863849255
+6.55	11.11711481747633
+6.56	11.158056791055907
+6.57	11.199151377799515
+6.58	11.240399170918664
+6.59	11.281800763624878
+6.6	11.323356566310819
+6.61	11.365063021057098
+6.62	11.406919107387543
+6.63	11.448925564848842
+6.64	11.491083132987699
+6.65	11.533392551350804
+6.66	11.575854559484844
+6.67	11.618469896936515
+6.68	11.661239303252511
+6.69	11.704163517979527
+6.7	11.747243280664248
+6.71	11.790479330853373
+6.72	11.833872408093592
+6.73	11.877423251931605
+6.74	11.921132601914092
+6.75	11.965001197587757
+6.76	12.009029778499285
+6.77	12.053219084195373
+6.78	12.097569854222716
+6.79	12.142082828128004
+6.8	12.186758745457928
+6.81	12.231598345759183
+6.82	12.276602368578462
+6.83	12.321771553462456
+6.84	12.367106639957859
+6.85	12.412608367611364
+6.86	12.45827747596967
+6.87	12.504114704579454
+6.88	12.550120792987423
+6.89	12.596296480740264
+6.9	12.642642507384673
+6.91	12.68915961246734
+6.92	12.735848535534956
+6.93	12.782710016134217
+6.94	12.82974479381182
+6.95	12.876953608114448
+6.96	12.924337198588798
+6.97	12.971896304781565
+6.98	13.019631666239444
+6.99	13.067544022509122
+7.0	13.115634113137293
+7.01	13.16390267767065
+7.02	13.212350455655887
+7.03	13.2609781866397
+7.04	13.309786610168775
+7.05	13.358776465789807
+7.06	13.40794849304949
+7.07	13.457303431494523
+7.08	13.506842020671588
+7.09	13.55656500012738
+7.1	13.606473109408597
+7.11	13.65656708806193
+7.12	13.706847675634068
+7.13	13.757315611671709
+7.14	13.80797163572154
+7.15	13.858816487330262
+7.16	13.90985090604456
+7.17	13.961075631411129
+7.18	14.012491402976663
+7.19	14.06409896028786
+7.2	14.115899042891401
+7.21	14.167891946795114
+7.22	14.220071999102633
+7.23	14.272438056698665
+7.24	14.324991146682857
+7.25	14.377732296154873
+7.26	14.430662532214372
+7.27	14.483782881961005
+7.28	14.537094372494439
+7.29	14.590598030914316
+7.3	14.644294884320306
+7.31	14.698185959812058
+7.32	14.752272284489239
+7.33	14.806554885451492
+7.34	14.861034789798483
+7.35	14.915713024629863
+7.36	14.9705906170453
+7.37	15.025668594144438
+7.38	15.080947983026942
+7.39	15.136429810792466
+7.4	15.19211510454067
+7.41	15.248004891371208
+7.42	15.304100198383734
+7.43	15.360402052677909
+7.44	15.416911481353397
+7.45	15.473629511509843
+7.46	15.530557170246906
+7.47	15.587695484664247
+7.48	15.645045481861525
+7.49	15.70260818893839
+7.5	15.7603846329945
+7.51	15.818375841129518
+7.52	15.876582840443094
+7.53	15.935006658034894
+7.54	15.993648321004562
+7.55	16.052508856451766
+7.56	16.111589291476157
+7.57	16.1708906531774
+7.58	16.23041396865514
+7.59	16.29016026500904
+7.6	16.350130569338756
+7.61	16.410325908743953
+7.62	16.470747310324278
+7.63	16.531395801179386
+7.64	16.59227240840894
+7.65	16.653378159112602
+7.66	16.714714080390017
+7.67	16.776281199340847
+7.68	16.838080543064752
+7.69	16.900113138661393
+7.7	16.962380013230412
+7.71	17.024882193871473
+7.72	17.087620707684238
+7.73	17.150596581768365
+7.74	17.2138108432235
+7.75	17.277264519149305
+7.76	17.340958636645443
+7.77	17.404894222811564
+7.78	17.469072304747332
+7.79	17.53349390955239
+7.8	17.59816006432641
+7.81	17.663071796169046
+7.82	17.72823013217995
+7.83	17.793636099458777
+7.84	17.85928695670301
+7.85	17.925178028875727
+7.86	17.99131065735145
+7.87	18.05768625685671
+7.88	18.12430624211806
+7.89	18.191172027862045
+7.9	18.258285028815216
+7.91	18.325646659704105
+7.92	18.393258335255265
+7.93	18.461121470195238
+7.94	18.529237479250586
+7.95	18.59760777714783
+7.96	18.66623377861352
+7.97	18.735116898374212
+7.98	18.804258551156448
+7.99	18.873660151686764
+8.0	18.943323114691715
+8.01	19.01324885489784
+8.02	19.083438787031685
+8.03	19.1538943258198
+8.04	19.224616885988723
+8.05	19.29560788226502
+8.06	19.366868729375206
+8.07	19.43840084204584
+8.08	19.510205635003466
+8.09	19.58228452297463
+8.1	19.654638920685876
+8.11	19.72727024286375
+8.12	19.800179904234795
+8.13	19.873369319525576
+8.14	19.946839903462607
+8.15	20.020593070772442
+8.16	20.09463023618164
+8.17	20.168952814416727
+8.18	20.24356222020426
+8.19	20.318459868270782
+8.2	20.393647173342842
+8.21	20.469125550146995
+8.22	20.544896413409756
+8.23	20.620961177857687
+8.24	20.697321258217332
+8.25	20.773978069215243
+8.26	20.850933025577955
+8.27	20.928187542032017
+8.28	21.005743033303975
+8.29	21.083600914120375
+8.3	21.16176259920777
+8.31	21.24022950329269
+8.32	21.319003041101677
+8.33	21.39808462736129
+8.34	21.47747567679807
+8.35	21.55717760413856
+8.36	21.637191824109305
+8.37	21.717519751436853
+8.38	21.798162800847763
+8.39	21.87912238706855
+8.4	21.960399924825776
+8.41	22.04199682884598
+8.42	22.123914513855713
+8.43	22.20615439458152
+8.44	22.288717885749946
+8.45	22.371606402087533
+8.46	22.45481729706235
+8.47	22.53834326234479
+8.48	22.622185915562632
+8.49	22.70634720663077
+8.5	22.790829085464104
+8.51	22.87563350197752
+8.52	22.960762406085927
+8.53	23.046217747704212
+8.54	23.13200147674727
+8.55	23.21811554313002
+8.56	23.304561896767314
+8.57	23.391342487574068
+8.58	23.478459265465183
+8.59	23.56591418035555
+8.6	23.65370918216006
+8.61	23.74184622079361
+8.62	23.830327246171112
+8.63	23.91915420820745
+8.64	24.008329056817512
+8.65	24.097853741916186
+8.66	24.187730213418394
+8.67	24.27796042123901
+8.68	24.368546315292935
+8.69	24.45948984549507
+8.7	24.550792961760305
+8.71	24.64245761400355
+8.72	24.734485752139676
+8.73	24.826879326083592
+8.74	24.919640285750187
+8.75	25.01277058105436
+8.76	25.10627216191101
+8.77	25.200146978235033
+8.78	25.294396979941318
+8.79	25.389024116944764
+8.8	25.48403033916028
+8.81	25.579417596502733
+8.82	25.675187838887034
+8.83	25.771343016228077
+8.84	25.867885078440757
+8.85	25.96481597543997
+8.86	26.06213765714061
+8.87	26.159852073457575
+8.88	26.25796117430578
+8.89	26.356466909600076
+8.9	26.455371229255388
+8.91	26.554676083186607
+8.92	26.654383421308623
+8.93	26.754495193536336
+8.94	26.855013349784645
+8.95	26.955939839968437
+8.96	27.05727661400263
+8.97	27.159025621802087
+8.98	27.261188813281716
+8.99	27.36376813835641
+9.0	27.466765546941076
+9.01	27.5701829889506
+9.02	27.674022414299873
+9.03	27.778285772903804
+9.04	27.88297501467728
+9.05	27.988092089535215
+9.06	28.09363894739247
+9.07	28.199617538163956
+9.08	28.306025762051267
+9.09	28.41285297194165
+9.1	28.52010091033562
+9.11	28.627772306134688
+9.12	28.73586988824035
+9.13	28.84439638555414
+9.14	28.953354526977527
+9.15	29.062747041412024
+9.16	29.17257665775915
+9.17	29.28284610492041
+9.18	29.3935581117973
+9.19	29.504715407291336
+9.2	29.616320720304017
+9.21	29.728376779736873
+9.22	29.840886314491367
+9.23	29.95385205346903
+9.24	30.067276725571368
+9.25	30.18116305969988
+9.26	30.295513784756075
+9.27	30.41033162964146
+9.28	30.52561932325754
+9.29	30.641379594505825
+9.3	30.757615172287835
+9.31	30.874328785505043
+9.32	30.991523163058968
+9.33	31.109201033851114
+9.34	31.227365126782992
+9.35	31.346018170756114
+9.36	31.465162894671977
+9.37	31.58480202743209
+9.38	31.704938297937975
+9.39	31.825574435091102
+9.4	31.946713167793003
+9.41	32.06835722494517
+9.42	32.190509335449114
+9.43	32.31317222820635
+9.44	32.436348632118374
+9.45	32.56004127608669
+9.46	32.68425288901284
+9.47	32.80898619979828
+9.48	32.93424393734452
+9.49	33.06002883055309
+9.5	33.18634360832549
+9.51	33.31319099956321
+9.52	33.44057373316778
+9.53	33.56849453804069
+9.54	33.696956143083455
+9.55	33.8259612771976
+9.56	33.95551266928458
+9.57	34.08561304824593
+9.58	34.216265142983154
+9.59	34.34747168239776
+9.6	34.47923539539126
+9.61	34.611559010865136
+9.62	34.744445257720926
+9.63	34.87789686486013
+9.64	35.011916561184236
+9.65	35.14650707559475
+9.66	35.28167113699318
+9.67	35.41741147428105
+9.68	35.55373081635985
+9.69	35.69063163117055
+9.7	35.82810999919812
+9.71	35.96616622778142
+9.72	36.10480369622355
+9.73	36.24402578382771
+9.74	36.38383586989705
+9.75	36.524237333734746
+9.76	36.66523355464397
+9.77	36.80682791192787
+9.78	36.94902378488961
+9.79	37.09182455283237
+9.8	37.235233595059334
+9.81	37.37925429087361
+9.82	37.523890019578396
+9.83	37.66914416047685
+9.84	37.81502009287214
+9.85	37.96152119606743
+9.86	38.10865084936588
+9.87	38.25641243207066
+9.88	38.40480932348496
+9.89	38.55384490291189
+9.9	38.70352254965464
+9.91	38.85384564301639
+9.92	39.00481756230029
+9.93	39.156441686809494
+9.94	39.308721395847186
+9.95	39.461660068716526
+9.96	39.6152610847207
+9.97	39.769527823162825
+9.98	39.92446366334608
+9.99	40.08007198457365
+10.0	40.23635616614868
diff --git a/test/test_delays.py b/test/test_delays.py
new file mode 100644
index 00000000..da85900d
--- /dev/null
+++ b/test/test_delays.py
@@ -0,0 +1,1109 @@
+import diffrax
+import jax.lax as lax
+import jax.numpy as jnp
+import jax.random as jrandom
+import pytest
+from diffrax.delays import Delays
+
+
+def open_process_file(path):
+    ts, ys = [], []
+    with open(path, "r", encoding="utf-8") as infile:
+        for line in infile:
+            data = line.split()
+            ts.append(float(data[0])), ys.append(float(data[1]))
+    return jnp.array(ts), jnp.array(ys)
+
+
+def test_dde_solver_with_ode():
+    # testing that dde solver solves ode
+    # when history not specific
+    key = jrandom.PRNGKey(5678)
+    akey, ykey = jrandom.split(key, 2)
+
+    A = jrandom.normal(akey, (10, 10), dtype=jnp.float64) * 0.5
+
+    def dde_f(t, y, args, history):
+        return A @ y
+
+    def ode_f(t, y, args):
+        return A @ y
+
+    dde_term = diffrax.ODETerm(dde_f)
+    ode_term = diffrax.ODETerm(ode_f)
+    t0 = 0
+    t1 = 4
+    ts = jnp.linspace(t0, t1, int(10 * (t1 - t0)))
+    y0 = jrandom.normal(ykey, (10,), dtype=jnp.float64)
+    delays = diffrax.Delays(
+        delays=[lambda t, y, args: 0.2],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    dt0 = 0.1
+    dde_sol = diffrax.diffeqsolve(
+        dde_term,
+        diffrax.Dopri5(),
+        t0,
+        t1,
+        dt0,
+        y0=lambda t: y0,
+        delays=delays,
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+    )
+    ode_sol = diffrax.diffeqsolve(
+        ode_term,
+        diffrax.Dopri5(),
+        t0,
+        t1,
+        dt0,
+        y0,
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+    )
+
+    error = jnp.mean(jnp.abs(dde_sol.ys - ode_sol.ys))
+    assert error < 10**-5
+
+
+def test_jump_ts_dde_solver():
+    # test jump ts with dde solver
+    # when t=2.0 the vf changes
+
+    key = jrandom.PRNGKey(5678)
+
+    def vf(t, y, args, history):
+        sign = jnp.where(t < 2, 1, -1)
+        return sign * history[0]
+
+    def first_part_vf(t, y, args, history):
+        return history[0]
+
+    def second_part_vf(t, y, args, history):
+        return -history[0]
+
+    t0, t1 = 0.0, 4.0
+    dt0 = 0.1
+    ts_first = jnp.linspace(0, 2.0, 20)
+    ts_second = jnp.linspace(2.0, 4.0, 20)
+    ts = jnp.concatenate([ts_first, ts_second[1:]])
+    y0 = jrandom.normal(key, (1,), dtype=jnp.float64)
+    delays = diffrax.Delays(
+        delays=[lambda t, y, args: 1.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=10,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    delays2 = diffrax.Delays(
+        delays=[lambda t, y, args: 1.0],
+        initial_discontinuities=jnp.array([2.0]),
+        max_discontinuities=10,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    first_part_dde = diffrax.diffeqsolve(
+        diffrax.ODETerm(first_part_vf),
+        diffrax.Dopri5(),
+        t0,
+        ts_first[-1],
+        dt0,
+        y0=lambda t: y0,
+        delays=delays,
+        saveat=diffrax.SaveAt(ts=ts_first, dense=True),
+    )
+    second_part_dde = diffrax.diffeqsolve(
+        diffrax.ODETerm(second_part_vf),
+        diffrax.Dopri5(),
+        ts_first[-1],
+        t1,
+        dt0,
+        y0=lambda t: first_part_dde.interpolation.evaluate(t),
+        delays=delays2,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-3),
+        saveat=diffrax.SaveAt(ts=ts_second, dense=True),
+    )
+    complete_dde = diffrax.diffeqsolve(
+        diffrax.ODETerm(vf),
+        diffrax.Dopri5(),
+        t0,
+        t1,
+        dt0,
+        y0=lambda t: y0,
+        delays=delays,
+        stepsize_controller=diffrax.PIDController(
+            rtol=1e-9, atol=1e-6, jump_ts=jnp.array([2.0])
+        ),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+    )
+
+    error = jnp.mean(
+        jnp.abs(
+            complete_dde.ys
+            - jnp.concatenate([first_part_dde.ys, second_part_dde.ys[1:]])
+        )
+    )
+    assert error < 1**-5
+
+
+def test_smooth_dde():
+    # testing a smooth dde with no initial discontinuities
+    # y' = y(t-1), phi = t + 1
+    # for t in [0,1], y(t) = t**2/2 + 1
+    # for t in [1,2], y(t) = (t-1)**3/(2*3) + t + 1/2
+    # we compare the values at t=1,2 with their analytical
+    # solution
+    def dde_f(t, y, args, history):
+        return history[0]
+
+    dde_term = diffrax.ODETerm(dde_f)
+    t0, t1 = 0.0, 2.0
+    ts = jnp.linspace(t0, t1, 100)
+    delays = diffrax.Delays(
+        delays=[lambda t, y, args: 1.0],
+        initial_discontinuities=None,
+        max_discontinuities=10,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    dt0 = 0.1
+    dde_sol = diffrax.diffeqsolve(
+        dde_term,
+        diffrax.Dopri5(),
+        t0,
+        t1,
+        dt0,
+        y0=lambda t: t + 1,
+        delays=delays,
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+    )
+
+    error1 = jnp.mean(jnp.abs(dde_sol.ys[50] - 3 / 2))
+    error2 = jnp.mean(jnp.abs(dde_sol.ys[100] - 8 / 3))
+
+    assert error1 < 10**-1
+    assert error2 < 10**-3
+
+
+def _test_exceed_max_discontinuities():
+    # we recurrent_checking to True and
+    # integrate a DDE with a delay equal
+    # to 1 and hence from t > 10.0 we
+    # should have RunTimeError picked
+    # up
+    def dde_f(t, y, args, history):
+        return -history[0]
+
+    t0, t1 = 0.0, 12.0
+    ts = jnp.linspace(t0, t1, 120)
+    delays = diffrax.Delays(
+        delays=[lambda t, y, args: 1.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=10,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    dt0 = 0.1
+
+    return diffrax.diffeqsolve(
+        diffrax.ODETerm(dde_f),
+        diffrax.Dopri5(),
+        t0,
+        t1,
+        dt0,
+        y0=lambda t: 1.0,
+        delays=delays,
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+    )
+
+
+def test_exceed_max_discontinuities():
+    with pytest.raises(RuntimeError):
+        _test_exceed_max_discontinuities()
+
+
+def test_only_explicit_stepping():
+    # we check that we only do explicit
+    # stepping here by putting
+    # dt = 0.9 < delays
+    def dde_f(t, y, args, history):
+        return -history[0]
+
+    t0, t1 = 0.0, 12.0
+    ts = jnp.linspace(t0, t1, 120)
+    delays = diffrax.Delays(
+        delays=[lambda t, y, args: 1.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=10,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    dt0 = 0.1
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(dde_f),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=dt0,
+        y0=lambda t: 1.0,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9, dtmax=0.9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+    )
+
+    assert sol.stats["num_dde_implicit_step"] == 0
+    assert sol.stats["num_dde_explicit_step"] > 0
+
+
+def test_hit_explicit_and_implicit_stepping():
+    # we check that we only do implicit
+    # stepping here by putting
+    # dt=1.1 > delays
+    def dde_f(t, y, args, history):
+        return -history[0]
+
+    t0, t1 = 0.0, 25.0
+    ts = jnp.linspace(t0, t1, 120)
+    delays = diffrax.Delays(
+        delays=[lambda t, y, args: 1.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=10,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(dde_f),
+        diffrax.Dopri5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=0.1,
+        y0=lambda t: 1.0,
+        stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+    )
+    assert sol.stats["num_dde_implicit_step"] > 0
+    assert sol.stats["num_dde_explicit_step"] > 0
+
+
+def test_basic_check_1():
+    def vector_field(t, y, args, *, history):
+        return y * (1 - history[0])
+
+    y0_history = lambda t: 1.2
+
+    delays = Delays(
+        delays=[lambda t, y, args: 1.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    made_jump = delays.initial_discontinuities is None
+    t0, t1 = 0.0, 50.0
+    ts = jnp.linspace(t0, t1, 1001)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: 1.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    _, juliays = open_process_file("test/julia_dde/test_basic_check_1.txt")
+    error1, error2, error3 = (
+        jnp.mean(jnp.abs(juliays - sol.ys)),
+        jnp.mean(jnp.abs(juliays - sol2.ys)),
+        jnp.mean(jnp.abs(sol.ys - sol2.ys)),
+    )
+
+    assert error1 < 1e-4
+    assert error2 < 1e-4
+    assert error3 < 1e-5
+
+
+def test_basic_check_2():
+    def vector_field(t, y, args, *, history):
+        return y * (1 - history[0])
+
+    y0_history = lambda t: 1.2
+
+    delays = Delays(
+        delays=[lambda t, y, args: 2.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    made_jump = delays.initial_discontinuities is None
+    t0, t1 = 0.0, 50.0
+    ts = jnp.linspace(t0, t1, 1001)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: 2.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    _, juliays = open_process_file("test/julia_dde/test_basic_check_2.txt")
+    error1, error2, error3 = (
+        jnp.mean(jnp.abs(juliays - sol.ys)),
+        jnp.mean(jnp.abs(juliays - sol2.ys)),
+        jnp.mean(jnp.abs(sol.ys - sol2.ys)),
+    )
+
+    assert error1 < 1e-2
+    assert error2 < 1e-2
+    assert error3 < 1e-5
+
+
+def test_basic_check_3():
+    # same test as test_basic_check_2 but we
+    # have a larger delay =3
+    def vector_field(t, y, args, *, history):
+        return y * (1 - history[0])
+
+    y0_history = lambda t: 1.2
+
+    delays = Delays(
+        delays=[lambda t, y, args: 3.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    made_jump = delays.initial_discontinuities is None
+    t0, t1 = 0.0, 50.0
+    ts = jnp.linspace(t0, t1, 1001)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: 3.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    _, juliays = open_process_file("test/julia_dde/test_basic_check_3.txt")
+    error1, error2, error3 = (
+        jnp.mean(jnp.abs(juliays - sol.ys)),
+        jnp.mean(jnp.abs(juliays - sol2.ys)),
+        jnp.mean(jnp.abs(sol.ys - sol2.ys)),
+    )
+
+    assert error1 < 0.15
+    assert error2 < 0.15
+    assert error3 < 1e-5
+
+
+def test_basic_check_4():
+    # same experiment as test_basic_check_3
+    # but solver is Tsit5
+    def vector_field(t, y, args, *, history):
+        return y * (1 - history[0])
+
+    y0_history = lambda t: 1.2
+
+    delays = Delays(
+        delays=[lambda t, y, args: 3.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    made_jump = delays.initial_discontinuities is None
+    t0, t1 = 0.0, 50.0
+    ts = jnp.linspace(t0, t1, 1001)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: 3.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    _, juliays = open_process_file("test/julia_dde/test_basic_check_4.txt")
+    error1, error2, error3 = (
+        jnp.mean(jnp.abs(juliays - sol.ys)),
+        jnp.mean(jnp.abs(juliays - sol2.ys)),
+        jnp.mean(jnp.abs(sol.ys - sol2.ys)),
+    )
+
+    assert error1 < 1e-2
+    assert error2 < 1e-2
+    assert error3 < 1e-5
+
+
+def test_basic_check_5():
+    # same test as test_basic_check_3 but we
+    # have a larger delay =4
+    def vector_field(t, y, args, *, history):
+        return y * (1 - history[0])
+
+    y0_history = lambda t: 1.2
+
+    delays = Delays(
+        delays=[lambda t, y, args: 4.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=10,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    made_jump = delays.initial_discontinuities is None
+    t0, t1 = 0.0, 50.0
+    ts = jnp.linspace(t0, t1, 1001)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: 4.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    _, juliays = open_process_file("test/julia_dde/test_basic_check_5.txt")
+    error1, error2, error3 = (
+        jnp.mean(jnp.abs(juliays - sol.ys)),
+        jnp.mean(jnp.abs(juliays - sol2.ys)),
+        jnp.mean(jnp.abs(sol.ys - sol2.ys)),
+    )
+
+    assert error1 < 0.2
+    assert error2 < 0.2
+    assert error3 < 1e-5
+
+
+def test_basic_check_6():
+    # same experiment as test_basic_check_5
+    # but solver is Tsit5
+    def vector_field(t, y, args, *, history):
+        return y * (1 - history[0])
+
+    y0_history = lambda t: 1.2
+
+    delays = Delays(
+        delays=[lambda t, y, args: 4.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    made_jump = delays.initial_discontinuities is None
+    t0, t1 = 0.0, 50.0
+    ts = jnp.linspace(t0, t1, 1001)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: 4.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    _, juliays = open_process_file("test/julia_dde/test_basic_check_6.txt")
+    error1, error2, error3 = (
+        jnp.mean(jnp.abs(juliays - sol.ys)),
+        jnp.mean(jnp.abs(juliays - sol2.ys)),
+        jnp.mean(jnp.abs(sol.ys - sol2.ys)),
+    )
+
+    assert error1 < 1e-2
+    assert error2 < 1e-2
+    assert error3 < 1e-5
+
+
+def test_basic_check_7():
+    # same experiment as test_basic_check_7
+    # but solver is implicit
+    def vector_field(t, y, args, *, history):
+        return y * (1 - history[0])
+
+    y0_history = lambda t: 1.2
+
+    delays = Delays(
+        delays=[lambda t, y, args: 4.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    made_jump = delays.initial_discontinuities is None
+    t0, t1 = 0.0, 50.0
+    ts = jnp.linspace(t0, t1, 1001)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Kvaerno5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: 4.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Kvaerno5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    _, juliays = open_process_file("test/julia_dde/test_basic_check_7.txt")
+    error1, error2, error3 = (
+        jnp.mean(jnp.abs(juliays - sol.ys)),
+        jnp.mean(jnp.abs(juliays - sol2.ys)),
+        jnp.mean(jnp.abs(sol.ys - sol2.ys)),
+    )
+
+    assert error1 < 1e-1
+    assert error2 < 1e-1
+    assert error3 < 1e-2
+
+
+def test_basic_check_8():
+    # new system with 2 delays
+
+    def vector_field(t, y, args, *, history):
+        return -history[0] - history[1]
+
+    y0_history = lambda t: 1.2
+
+    delays = Delays(
+        delays=[lambda t, y, args: 1 / 5, lambda t, y, args: 1 / 3],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    made_jump = delays.initial_discontinuities is None
+    t0, t1 = 0.0, 10.0
+    ts = jnp.linspace(t0, t1, 101)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: 1 / 5, lambda t, y, args: 1 / 3],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=1000,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+        made_jump=made_jump,
+    )
+
+    _, juliays = open_process_file("test/julia_dde/test_basic_check_8.txt")
+    error1, error2, error3 = (
+        jnp.mean(jnp.abs(juliays - sol.ys)),
+        jnp.mean(jnp.abs(juliays - sol2.ys)),
+        jnp.mean(jnp.abs(sol.ys - sol2.ys)),
+    )
+
+    assert error1 < 1e-6
+    assert error2 < 1e-6
+    assert error3 < 1e-6
+
+
+def test_basic_check_9():
+    # new system ie Mackey Glass
+
+    def vector_field(t, y, args, *, history):
+        return 0.2 * (history[0]) / (1 + history[0] ** 10) - 0.1 * y
+
+    y0_history = lambda t: 1.2
+
+    delays = Delays(
+        delays=[lambda t, y, args: 6.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    t0, t1, nb_points = 0.0, 50.0, 501
+    ts = jnp.linspace(t0, t1, nb_points)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: 6.0],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Tsit5(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+    )
+
+    _, juliays = open_process_file("test/julia_dde/test_basic_check_9.txt")
+    error1, error2, error3 = (
+        jnp.mean(jnp.abs(juliays - sol.ys)),
+        jnp.mean(jnp.abs(juliays - sol2.ys)),
+        jnp.mean(jnp.abs(sol.ys - sol2.ys)),
+    )
+
+    assert error1 < 1e-4
+    assert error2 < 1e-4
+    assert error3 < 1e-4
+
+
+def test_basic_check_10():
+    # testing a time dependent equation
+
+    def vector_field(t, y, args, *, history):
+        return y * (1 - history[0])
+
+    y0_history = lambda t: 1.2
+
+    delays = Delays(
+        delays=[lambda t, y, args: 2 + jnp.sin(t)],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    t0, t1, nb_points = 0.0, 40.0, 401
+    ts = jnp.linspace(t0, t1, nb_points)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: 2 + jnp.sin(t)],
+        initial_discontinuities=jnp.array([0.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=10e-3,
+        atol=10e-6,
+    )
+
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+    )
+
+    _, juliays = open_process_file("test/julia_dde/test_basic_check_10.txt")
+    error1, error2, error3 = (
+        jnp.mean(jnp.abs(juliays - sol.ys)),
+        jnp.mean(jnp.abs(juliays - sol2.ys)),
+        jnp.mean(jnp.abs(sol.ys - sol2.ys)),
+    )
+
+    assert error1 < 1e-3
+    assert error2 < 1e-3
+    assert error3 < 1e-5
+
+
+def test_basic_numerical_check_1():
+    # testing a dde where we know its analytical value
+    # http://www.cs.toronto.edu/pub/reports/na/hzpEnrightNA09Preprint.pdf
+    # test problem 1
+
+    def vector_field(t, y, args, history):
+        return history[0]
+
+    y0_history = lambda t: lax.cond(t < 2.0, lambda: 0.5, lambda: 1.0)
+
+    delays = Delays(
+        delays=[lambda t, y, args: t - y],
+        initial_discontinuities=jnp.array([2.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=1e-3,
+        atol=1e-6,
+    )
+
+    t0, t1, nb_points = 2.0, 5.5, 350
+    ts = jnp.linspace(t0, t1, nb_points)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: t - y],
+        initial_discontinuities=jnp.array([2.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=1e-3,
+        atol=1e-6,
+    )
+
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+    )
+
+    def f1(t):
+        return t / 2
+
+    def f2(t):
+        return 2 * jnp.exp(t / 2 - 2)
+
+    # y(t) = t/2 for 2 <= t <= 4
+    # y(t) = 2 exp(t/2-2) for 4 <= t <= 5.386
+    # y(t) = 4 - 2 ln(1+ 5.386 -t )] for 5.386 <= t <= 5.5
+    error1 = jnp.sum(jnp.abs(sol.ys[:200] - f1(sol.ts[:200])))
+    error2 = jnp.sum(jnp.abs(sol.ys[202:300] - f2(sol.ts[202:300])))
+
+    error11 = jnp.sum(jnp.abs(sol2.ys[:200] - f1(sol2.ts[:200])))
+    error21 = jnp.sum(jnp.abs(sol2.ys[202:300] - f2(sol2.ts[202:300])))
+    juliats, juliays = open_process_file(
+        "test/julia_dde/test_basic_numerical_check_1.txt"
+    )
+
+    error3 = jnp.sum(jnp.abs(juliays[:200] - f1(juliats[:200])))
+    error4 = jnp.sum(jnp.abs(juliays[200:300] - f2(juliats[200:300])))
+    assert error2 < error4
+    assert error1 < 1e-5
+    assert error2 < 1e-2
+    assert error11 < 1e-5
+    assert error21 < 1e-2
+    assert error3 < 1e-5
+
+
+def test_basic_numerical_check_2():
+    # testing a dde where we know its analytical value
+    # http://www.cs.toronto.edu/pub/reports/na/hzpEnrightNA09Preprint.pdf
+    # test problem 3
+
+    def vector_field(t, y, args, history):
+        return y * history[0] / t
+
+    y0_history = lambda t: 1.0
+
+    delays = Delays(
+        delays=[lambda t, y, args: t - jnp.log(y)],
+        initial_discontinuities=jnp.array([1.0]),
+        max_discontinuities=100,
+        recurrent_checking=False,
+        rtol=1e-3,
+        atol=1e-6,
+    )
+
+    t0, t1, nb_points = 1.0, 10.0, 901
+    ts = jnp.linspace(t0, t1, nb_points)
+    sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+    )
+
+    delays = Delays(
+        delays=[lambda t, y, args: t - jnp.log(y)],
+        initial_discontinuities=jnp.array([1.0]),
+        max_discontinuities=100,
+        recurrent_checking=True,
+        rtol=1e-3,
+        atol=1e-6,
+    )
+
+    sol2 = diffrax.diffeqsolve(
+        diffrax.ODETerm(vector_field),
+        diffrax.Bosh3(),
+        t0=ts[0],
+        t1=ts[-1],
+        dt0=ts[1] - ts[0],
+        y0=y0_history,
+        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-9),
+        saveat=diffrax.SaveAt(ts=ts, dense=True),
+        delays=delays,
+    )
+
+    def f1(t):
+        return t
+
+    def f2(t):
+        return jnp.exp(t / jnp.exp(1))
+
+    # y(t) = t for 1 <= t <= e
+    # y(t) = exp(t/e) for e <= t <= e^2
+    error1 = jnp.sum(jnp.abs(sol.ys[:100] - f1(sol.ts[:100])))
+    error2 = jnp.sum(jnp.abs(sol.ys[300:400] - f2(sol.ts[300:400])))
+
+    error11 = jnp.sum(jnp.abs(sol2.ys[:100] - f1(sol2.ts[:100])))
+    error21 = jnp.sum(jnp.abs(sol2.ys[300:400] - f2(sol2.ts[300:400])))
+    juliats, juliays = open_process_file(
+        "test/julia_dde/test_basic_numerical_check_2.txt"
+    )
+
+    error3 = jnp.sum(jnp.abs(juliays[:100] - f1(juliats[:100])))
+    error4 = jnp.sum(jnp.abs(juliays[300:400] - f2(juliats[300:400])))
+
+    assert error1 < 1e-6
+    assert error11 < 1e-6
+    assert error3 < 1e-6
+    assert error2 < 1e-2
+    assert error21 < 1e-2
+    assert error2 < error4