Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use OperatorGroup for constrain and callback features #404

Open
wants to merge 10 commits into
base: update-0.3.3
Choose a base branch
from
Open
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ flake8:
.PHONY: autoflake-check
autoflake-check:
poetry run autoflake --version
poetry run autoflake $(AUTOFLAKE_ARGS) elastica tests examples
poetry run autoflake --check $(AUTOFLAKE_ARGS) elastica tests examples

.PHONY: autoflake-format
Expand Down
5 changes: 1 addition & 4 deletions elastica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,9 @@
from elastica.utils import isqrt
from elastica.timestepper import (
integrate,
PositionVerlet,
PEFRL,
RungeKutta4,
EulerForward,
extend_stepper_interface,
)
from elastica.timestepper.symplectic_steppers import PositionVerlet, PEFRL
from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody
from elastica.memory_block.memory_block_rod import MemoryBlockCosseratRod
from elastica.restart import save_state, load_state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from elastica.typing import (
SystemType,
SystemCollectionType,
OperatorType,
StepType,
SteppersOperatorsType,
StateType,
)
from elastica.systems.protocol import ExplicitSystemProtocol
from .protocol import ExplicitStepperProtocol, MemoryProtocol
from elastica.experimental.timestepper.protocol import (
ExplicitSystemProtocol,
ExplicitStepperProtocol,
MemoryProtocol,
)


"""
Expand Down Expand Up @@ -166,10 +169,10 @@ class EulerForward(ExplicitStepperMixin):
Classical Euler Forward stepper. Stateless, coordinates operations only.
"""

def get_stages(self) -> list[OperatorType]:
def get_stages(self) -> list[StepType]:
return [self._first_stage]

def get_updates(self) -> list[OperatorType]:
def get_updates(self) -> list[StepType]:
return [self._first_update]

def _first_stage(
Expand Down Expand Up @@ -198,15 +201,15 @@ class RungeKutta4(ExplicitStepperMixin):
to be externally managed and allocated.
"""

def get_stages(self) -> list[OperatorType]:
def get_stages(self) -> list[StepType]:
return [
self._first_stage,
self._second_stage,
self._third_stage,
self._fourth_stage,
]

def get_updates(self) -> list[OperatorType]:
def get_updates(self) -> list[StepType]:
return [
self._first_update,
self._second_update,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Iterator, TypeVar, Generic, Type
from elastica.timestepper.protocol import ExplicitStepperProtocol
from elastica.typing import SystemCollectionType
from elastica.experimental.timestepper.explicit_steppers import (
RungeKutta4,
EulerForward,
)
from elastica.experimental.timestepper.protocol import ExplicitStepperProtocol

from copy import copy

Expand All @@ -12,11 +16,6 @@ def make_memory_for_explicit_stepper(
) -> "MemoryCollection":
# TODO Automated logic (class creation, memory management logic) agnostic of stepper details (RK, AB etc.)

from elastica.timestepper.explicit_steppers import (
RungeKutta4,
EulerForward,
)

# is_this_system_a_collection = is_system_a_collection(system)

memory_cls: Type
Expand Down
86 changes: 86 additions & 0 deletions elastica/experimental/timestepper/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Protocol

from elastica.typing import StepType, StateType
from elastica.systems.protocol import SystemProtocol, SlenderBodyGeometryProtocol
from elastica.timestepper.protocol import StepperProtocol

import numpy as np


class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol):
# TODO: Temporarily made to handle explicit stepper.
# Need to be refactored as the explicit stepper is further developed.
def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ...
@property
def state(self) -> StateType: ...
@state.setter
def state(self, state: StateType) -> None: ...
@property
def n_elems(self) -> int: ...


class MemoryProtocol(Protocol):
@property
def initial_state(self) -> bool: ...


class ExplicitStepperProtocol(StepperProtocol, Protocol):
"""symplectic stepper protocol."""

def get_stages(self) -> list[StepType]: ...

def get_updates(self) -> list[StepType]: ...


# class _LinearExponentialIntegratorMixin:
# """
# Linear Exponential integrator mixin wrapper.
# """
#
# def __init__(self):
# pass
#
# def _do_stage(self, System, Memory, time, dt):
# # TODO : Make more general, system should not be calculating what the state
# # transition matrix directly is, but rather it should just give
# Memory.linear_operator = System.get_linear_state_transition_operator(time, dt)
#
# def _do_update(self, System, Memory, time, dt):
# # FIXME What's the right formula when doing update?
# # System.linearly_evolving_state = _batch_matmul(
# # System.linearly_evolving_state,
# # Memory.linear_operator
# # )
# System.linearly_evolving_state = np.einsum(
# "ijk,ljk->ilk", System.linearly_evolving_state, Memory.linear_operator
# )
# return time + dt
#
# def _first_prefactor(self, dt):
# """Prefactor call to satisfy interface of SymplecticStepper. Should never
# be used in actual code.
#
# Parameters
# ----------
# dt : the time step of simulation
#
# Raises
# ------
# RuntimeError
# """
# raise RuntimeError(
# "Symplectic prefactor of LinearExponentialIntegrator should not be called!"
# )
#
# # Code repeat!
# # Easy to avoid, but keep for performance.
# def _do_one_step(self, System, time, prefac):
# System.linearly_evolving_state = np.einsum(
# "ijk,ljk->ilk",
# System.linearly_evolving_state,
# System.get_linear_state_transition_operator(time, prefac),
# )
# return (
# time # TODO fix hack that treats time separately here. Shuold be time + dt
# )
# # return time + dt
19 changes: 14 additions & 5 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Basic coordinating for multiple, smaller systems that have an independently integrable
interface (i.e. works with symplectic or explicit routines `timestepper.py`.)
"""
from typing import Type, Generator, Iterable, Any, overload
from typing import Type, Generator, Any, overload
from typing import final
from elastica.typing import (
SystemType,
Expand All @@ -27,6 +27,7 @@

from .memory_block import construct_memory_block_structures
from .operator_group import OperatorGroupFIFO
from .protocol import ModuleProtocol


class BaseSystemCollection(MutableSequence):
Expand Down Expand Up @@ -55,10 +56,18 @@ def __init__(self) -> None:
# Collection of functions. Each group is executed as a collection at the different steps.
# Each component (Forcing, Connection, etc.) registers the executable (callable) function
# in the group that that needs to be executed. These should be initialized before mixin.
self._feature_group_synchronize: Iterable[OperatorType] = OperatorGroupFIFO()
self._feature_group_constrain_values: list[OperatorType] = []
self._feature_group_constrain_rates: list[OperatorType] = []
self._feature_group_callback: list[OperatorCallbackType] = []
self._feature_group_synchronize: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_constrain_values: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_constrain_rates: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_callback: OperatorGroupFIFO[
OperatorCallbackType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_finalize: list[OperatorFinalizeType] = []
# We need to initialize our mixin classes
super().__init__()
Expand Down
34 changes: 16 additions & 18 deletions elastica/modules/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from elastica.typing import SystemType, SystemIdxType, OperatorFinalizeType
from .protocol import ModuleProtocol

import functools

import numpy as np

from elastica.callback_functions import CallBackBaseClass
Expand All @@ -29,9 +31,7 @@ class CallBacks:

def __init__(self: SystemCollectionProtocol) -> None:
self._callback_list: list[ModuleProtocol] = []
self._callback_operators: list[tuple[int, CallBackBaseClass]] = []
super(CallBacks, self).__init__()
self._feature_group_callback.append(self._callback_execution)
self._feature_group_finalize.append(self._finalize_callback)

def collect_diagnostics(
Expand All @@ -54,30 +54,28 @@ def collect_diagnostics(
sys_idx: SystemIdxType = self.get_system_index(system)

# Create _Constraint object, cache it and return to user
_callbacks: ModuleProtocol = _CallBack(sys_idx)
self._callback_list.append(_callbacks)
_callback: ModuleProtocol = _CallBack(sys_idx)
self._callback_list.append(_callback)
self._feature_group_callback.append_id(_callback)

return _callbacks
return _callback

def _finalize_callback(self: SystemCollectionProtocol) -> None:
# dev : the first index stores the rod index to collect data.
self._callback_operators = [
(callback.id(), callback.instantiate()) for callback in self._callback_list
]
for callback in self._callback_list:
sys_id = callback.id()
callback_instance = callback.instantiate()

callback_operator = functools.partial(
callback_instance.make_callback, system=self[sys_id]
)
self._feature_group_callback.add_operators(callback, [callback_operator])

self._callback_list.clear()
del self._callback_list

# First callback execution
time = np.float64(0.0)
self._callback_execution(time=time, current_step=0)

def _callback_execution(
self: SystemCollectionProtocol,
time: np.float64,
current_step: int,
) -> None:
for sys_id, callback in self._callback_operators:
callback.make_callback(self[sys_id], time, current_step)
self.apply_callbacks(time=np.float64(0.0), current_step=0)


class _CallBack:
Expand Down
48 changes: 30 additions & 18 deletions elastica/modules/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Any, Type, cast
from typing_extensions import Self

import functools

import numpy as np

from elastica.boundary_conditions import ConstraintBase
Expand Down Expand Up @@ -36,8 +38,6 @@ class Constraints:
def __init__(self: SystemCollectionProtocol) -> None:
self._constraints_list: list[ModuleProtocol] = []
super(Constraints, self).__init__()
self._feature_group_constrain_values.append(self._constrain_values)
self._feature_group_constrain_rates.append(self._constrain_rates)
self._feature_group_finalize.append(self._finalize_constraints)

def constrain(
Expand All @@ -62,6 +62,8 @@ def constrain(
# Create _Constraint object, cache it and return to user
_constraint: ModuleProtocol = _Constraint(sys_idx)
self._constraints_list.append(_constraint)
self._feature_group_constrain_values.append_id(_constraint)
self._feature_group_constrain_rates.append_id(_constraint)

return _constraint

Expand All @@ -71,11 +73,14 @@ def _finalize_constraints(self: SystemCollectionProtocol) -> None:
periodic boundaries, a new constrain for memory block rod added called as _ConstrainPeriodicBoundaries. This
constrain will synchronize the only periodic boundaries of position, director, velocity and omega variables.
"""
from elastica._synchronize_periodic_boundary import _ConstrainPeriodicBoundaries

for block in self.block_systems():
# append the memory block to the simulation as a system. Memory block is the final system in the simulation.
if hasattr(block, "ring_rod_flag"):
from elastica._synchronize_periodic_boundary import (
_ConstrainPeriodicBoundaries,
)

# Apply the constrain to synchronize the periodic boundaries of the memory rod. Find the memory block
# sys idx among other systems added and then apply boundary conditions.
memory_block_idx = self.get_system_index(block)
Expand All @@ -89,31 +94,38 @@ def _finalize_constraints(self: SystemCollectionProtocol) -> None:

# dev : the first index stores the rod index to apply the boundary condition
# to.
self._constraints_operators = [
(constraint.id(), constraint.instantiate(self[constraint.id()]))
for constraint in self._constraints_list
]

# Sort from lowest id to highest id for potentially better memory access
# _constraints contains list of tuples. First element of tuple is rod number and
# following elements are the type of boundary condition such as
# [(0, ConstraintBase, OneEndFixedBC), (1, HelicalBucklingBC), ... ]
# Thus using lambda we iterate over the list of tuples and use rod number (x[0])
# to sort constraints.
self._constraints_operators.sort(key=lambda x: x[0])
self._constraints_list.sort(key=lambda x: x.id())
for constraint in self._constraints_list:
sys_id = constraint.id()
constraint_instance = constraint.instantiate(self[sys_id])

constrain_values = functools.partial(
constraint_instance.constrain_values, system=self[sys_id]
)
constrain_rates = functools.partial(
constraint_instance.constrain_rates, system=self[sys_id]
)

self._feature_group_constrain_values.add_operators(
constraint, [constrain_values]
)
self._feature_group_constrain_rates.add_operators(
constraint, [constrain_rates]
)

# At t=0.0, constrain all the boundary conditions (for compatability with
# initial conditions)
self._constrain_values(time=np.float64(0.0))
self._constrain_rates(time=np.float64(0.0))

def _constrain_values(self: SystemCollectionProtocol, time: np.float64) -> None:
for sys_id, constraint in self._constraints_operators:
constraint.constrain_values(self[sys_id], time)
self.constrain_values(time=np.float64(0.0))
self.constrain_rates(time=np.float64(0.0))

def _constrain_rates(self: SystemCollectionProtocol, time: np.float64) -> None:
for sys_id, constraint in self._constraints_operators:
constraint.constrain_rates(self[sys_id], time)
self._constraints_list = []
del self._constraints_list


class _Constraint:
Expand Down
Loading
Loading