Skip to content

Commit

Permalink
more solvers to register
Browse files Browse the repository at this point in the history
  • Loading branch information
vdesmond committed Jan 7, 2025
1 parent d1e18b3 commit 84f52fe
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/attractors/solvers/euler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from attractors.solvers.registry import SolverRegistry
from attractors.type_defs import ParamVector, StateVector, SystemCallable


@SolverRegistry.register("euler")
def euler(
system_func: SystemCallable, state: StateVector, params: ParamVector, dt: float
) -> StateVector:
return state + dt * system_func(state, params)
11 changes: 11 additions & 0 deletions src/attractors/solvers/rk2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from attractors.solvers.registry import SolverRegistry
from attractors.type_defs import ParamVector, StateVector, SystemCallable


@SolverRegistry.register("rk2")
def rk2(
system_func: SystemCallable, state: StateVector, params: ParamVector, dt: float
) -> StateVector:
k1 = system_func(state, params)
k2 = system_func(state + dt * k1, params)
return state + dt * (k1 + k2) / 2
12 changes: 12 additions & 0 deletions src/attractors/solvers/rk3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from attractors.solvers.registry import SolverRegistry
from attractors.type_defs import ParamVector, StateVector, SystemCallable


@SolverRegistry.register("rk3")
def rk3(
system_func: SystemCallable, state: StateVector, params: ParamVector, dt: float
) -> StateVector:
k1 = system_func(state, params)
k2 = system_func(state + dt * k1 / 2, params)
k3 = system_func(state - dt * k1 + 2 * dt * k2, params)
return state + dt * (k1 + 4 * k2 + k3) / 6
19 changes: 19 additions & 0 deletions src/attractors/solvers/rk5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from attractors.solvers.registry import SolverRegistry
from attractors.type_defs import ParamVector, StateVector, SystemCallable


@SolverRegistry.register("rk5")
def rk5(
system_func: SystemCallable, state: StateVector, params: ParamVector, dt: float
) -> StateVector:
k1 = system_func(state, params)
k2 = system_func(state + dt * k1 / 4, params)
k3 = system_func(state + dt * (k1 + k2) / 8, params)
k4 = system_func(state + dt * (k3 - k2 / 2 + k3), params)
k5 = system_func(state + dt * (-3 * k1 / 16 + 9 * k4 / 16), params)
k6 = system_func(
state
+ dt * (-3 * k1 / 7 + 2 * k2 / 7 + 12 * k3 / 7 - 12 * k4 / 7 + 8 * k5 / 7),
params,
)
return state + dt * (7 * k1 + 32 * k3 + 12 * k4 + 32 * k5 + 7 * k6) / 90
13 changes: 13 additions & 0 deletions src/attractors/solvers/stormer_verlet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from attractors.solvers.registry import SolverRegistry
from attractors.type_defs import ParamVector, StateVector, SystemCallable


@SolverRegistry.register("stormer_verlet")
def stormer_verlet(
system_func: SystemCallable, state: StateVector, params: ParamVector, dt: float
) -> StateVector:
k1 = system_func(state, params)
half_state = state + 0.5 * dt * k1

k2 = system_func(half_state, params)
return state + dt * k2

0 comments on commit 84f52fe

Please sign in to comment.