diff --git a/Makefile b/Makefile index 89b4b179..3bce77a7 100644 --- a/Makefile +++ b/Makefile @@ -50,7 +50,6 @@ flake8: .PHONY: autoflake-check autoflake-check: poetry run autoflake --version - poetry run autoflake $(AUTOFLAKE_ARGS) elastica tests examples poetry run autoflake --check $(AUTOFLAKE_ARGS) elastica tests examples .PHONY: autoflake-format diff --git a/elastica/__init__.py b/elastica/__init__.py index 21dfddb7..5f070306 100644 --- a/elastica/__init__.py +++ b/elastica/__init__.py @@ -75,12 +75,9 @@ from elastica.utils import isqrt from elastica.timestepper import ( integrate, - PositionVerlet, - PEFRL, - RungeKutta4, - EulerForward, extend_stepper_interface, ) +from elastica.timestepper.symplectic_steppers import PositionVerlet, PEFRL from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody from elastica.memory_block.memory_block_rod import MemoryBlockCosseratRod from elastica.restart import save_state, load_state diff --git a/elastica/timestepper/explicit_steppers.py b/elastica/experimental/timestepper/explicit_steppers.py similarity index 96% rename from elastica/timestepper/explicit_steppers.py rename to elastica/experimental/timestepper/explicit_steppers.py index 5fb1531f..3705ccd2 100644 --- a/elastica/timestepper/explicit_steppers.py +++ b/elastica/experimental/timestepper/explicit_steppers.py @@ -8,12 +8,15 @@ from elastica.typing import ( SystemType, SystemCollectionType, - OperatorType, + StepType, SteppersOperatorsType, StateType, ) -from elastica.systems.protocol import ExplicitSystemProtocol -from .protocol import ExplicitStepperProtocol, MemoryProtocol +from elastica.experimental.timestepper.protocol import ( + ExplicitSystemProtocol, + ExplicitStepperProtocol, + MemoryProtocol, +) """ @@ -166,10 +169,10 @@ class EulerForward(ExplicitStepperMixin): Classical Euler Forward stepper. Stateless, coordinates operations only. """ - def get_stages(self) -> list[OperatorType]: + def get_stages(self) -> list[StepType]: return [self._first_stage] - def get_updates(self) -> list[OperatorType]: + def get_updates(self) -> list[StepType]: return [self._first_update] def _first_stage( @@ -198,7 +201,7 @@ class RungeKutta4(ExplicitStepperMixin): to be externally managed and allocated. """ - def get_stages(self) -> list[OperatorType]: + def get_stages(self) -> list[StepType]: return [ self._first_stage, self._second_stage, @@ -206,7 +209,7 @@ def get_stages(self) -> list[OperatorType]: self._fourth_stage, ] - def get_updates(self) -> list[OperatorType]: + def get_updates(self) -> list[StepType]: return [ self._first_update, self._second_update, diff --git a/elastica/systems/memory.py b/elastica/experimental/timestepper/memory.py similarity index 93% rename from elastica/systems/memory.py rename to elastica/experimental/timestepper/memory.py index c669be9b..b63931aa 100644 --- a/elastica/systems/memory.py +++ b/elastica/experimental/timestepper/memory.py @@ -1,6 +1,10 @@ from typing import Iterator, TypeVar, Generic, Type -from elastica.timestepper.protocol import ExplicitStepperProtocol from elastica.typing import SystemCollectionType +from elastica.experimental.timestepper.explicit_steppers import ( + RungeKutta4, + EulerForward, +) +from elastica.experimental.timestepper.protocol import ExplicitStepperProtocol from copy import copy @@ -12,11 +16,6 @@ def make_memory_for_explicit_stepper( ) -> "MemoryCollection": # TODO Automated logic (class creation, memory management logic) agnostic of stepper details (RK, AB etc.) - from elastica.timestepper.explicit_steppers import ( - RungeKutta4, - EulerForward, - ) - # is_this_system_a_collection = is_system_a_collection(system) memory_cls: Type diff --git a/elastica/experimental/timestepper/protocol.py b/elastica/experimental/timestepper/protocol.py new file mode 100644 index 00000000..b8d489fa --- /dev/null +++ b/elastica/experimental/timestepper/protocol.py @@ -0,0 +1,86 @@ +from typing import Protocol + +from elastica.typing import StepType, StateType +from elastica.systems.protocol import SystemProtocol, SlenderBodyGeometryProtocol +from elastica.timestepper.protocol import StepperProtocol + +import numpy as np + + +class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol): + # TODO: Temporarily made to handle explicit stepper. + # Need to be refactored as the explicit stepper is further developed. + def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ... + @property + def state(self) -> StateType: ... + @state.setter + def state(self, state: StateType) -> None: ... + @property + def n_elems(self) -> int: ... + + +class MemoryProtocol(Protocol): + @property + def initial_state(self) -> bool: ... + + +class ExplicitStepperProtocol(StepperProtocol, Protocol): + """symplectic stepper protocol.""" + + def get_stages(self) -> list[StepType]: ... + + def get_updates(self) -> list[StepType]: ... + + +# class _LinearExponentialIntegratorMixin: +# """ +# Linear Exponential integrator mixin wrapper. +# """ +# +# def __init__(self): +# pass +# +# def _do_stage(self, System, Memory, time, dt): +# # TODO : Make more general, system should not be calculating what the state +# # transition matrix directly is, but rather it should just give +# Memory.linear_operator = System.get_linear_state_transition_operator(time, dt) +# +# def _do_update(self, System, Memory, time, dt): +# # FIXME What's the right formula when doing update? +# # System.linearly_evolving_state = _batch_matmul( +# # System.linearly_evolving_state, +# # Memory.linear_operator +# # ) +# System.linearly_evolving_state = np.einsum( +# "ijk,ljk->ilk", System.linearly_evolving_state, Memory.linear_operator +# ) +# return time + dt +# +# def _first_prefactor(self, dt): +# """Prefactor call to satisfy interface of SymplecticStepper. Should never +# be used in actual code. +# +# Parameters +# ---------- +# dt : the time step of simulation +# +# Raises +# ------ +# RuntimeError +# """ +# raise RuntimeError( +# "Symplectic prefactor of LinearExponentialIntegrator should not be called!" +# ) +# +# # Code repeat! +# # Easy to avoid, but keep for performance. +# def _do_one_step(self, System, time, prefac): +# System.linearly_evolving_state = np.einsum( +# "ijk,ljk->ilk", +# System.linearly_evolving_state, +# System.get_linear_state_transition_operator(time, prefac), +# ) +# return ( +# time # TODO fix hack that treats time separately here. Shuold be time + dt +# ) +# # return time + dt diff --git a/elastica/memory_block/protocol.py b/elastica/memory_block/protocol.py index 5da6331b..51cd57d5 100644 --- a/elastica/memory_block/protocol.py +++ b/elastica/memory_block/protocol.py @@ -1,14 +1,22 @@ from typing import Protocol -from elastica.systems.protocol import SystemProtocol - -import numpy as np - from elastica.rod.protocol import CosseratRodProtocol from elastica.rigidbody.protocol import RigidBodyProtocol -from elastica.systems.protocol import SymplecticSystemProtocol +from elastica.systems.protocol import SystemProtocol -class BlockSystemProtocol(SystemProtocol, Protocol): +class BlockProtocol(Protocol): @property def n_systems(self) -> int: """Number of systems in the block.""" + + +class BlockSystemProtocol(SystemProtocol, BlockProtocol, Protocol): + pass + + +class BlockRodProtocol(BlockProtocol, CosseratRodProtocol, Protocol): + pass + + +class BlockRigidBodyProtocol(BlockProtocol, RigidBodyProtocol, Protocol): + pass diff --git a/elastica/modules/base_system.py b/elastica/modules/base_system.py index 3ac2eaba..22ed05cb 100644 --- a/elastica/modules/base_system.py +++ b/elastica/modules/base_system.py @@ -5,7 +5,7 @@ Basic coordinating for multiple, smaller systems that have an independently integrable interface (i.e. works with symplectic or explicit routines `timestepper.py`.) """ -from typing import Type, Generator, Iterable, Any, overload +from typing import TYPE_CHECKING, Type, Generator, Any, overload from typing import final from elastica.typing import ( SystemType, @@ -27,6 +27,7 @@ from .memory_block import construct_memory_block_structures from .operator_group import OperatorGroupFIFO +from .protocol import ModuleProtocol class BaseSystemCollection(MutableSequence): @@ -55,10 +56,18 @@ def __init__(self) -> None: # Collection of functions. Each group is executed as a collection at the different steps. # Each component (Forcing, Connection, etc.) registers the executable (callable) function # in the group that that needs to be executed. These should be initialized before mixin. - self._feature_group_synchronize: Iterable[OperatorType] = OperatorGroupFIFO() - self._feature_group_constrain_values: list[OperatorType] = [] - self._feature_group_constrain_rates: list[OperatorType] = [] - self._feature_group_callback: list[OperatorCallbackType] = [] + self._feature_group_synchronize: OperatorGroupFIFO[ + OperatorType, ModuleProtocol + ] = OperatorGroupFIFO() + self._feature_group_constrain_values: OperatorGroupFIFO[ + OperatorType, ModuleProtocol + ] = OperatorGroupFIFO() + self._feature_group_constrain_rates: OperatorGroupFIFO[ + OperatorType, ModuleProtocol + ] = OperatorGroupFIFO() + self._feature_group_callback: OperatorGroupFIFO[ + OperatorCallbackType, ModuleProtocol + ] = OperatorGroupFIFO() self._feature_group_finalize: list[OperatorFinalizeType] = [] # We need to initialize our mixin classes super().__init__() @@ -104,11 +113,11 @@ def _check_type(self, sys_to_be_added: Any) -> bool: def __len__(self) -> int: return len(self.__systems) - @overload - def __getitem__(self, idx: int, /) -> SystemType: ... + @overload # type: ignore + def __getitem__(self, idx: slice, /) -> list[SystemType]: ... # type: ignore - @overload - def __getitem__(self, idx: slice, /) -> list[SystemType]: ... + @overload # type: ignore + def __getitem__(self, idx: int, /) -> SystemType: ... # type: ignore def __getitem__(self, idx, /): # type: ignore return self.__systems[idx] @@ -266,3 +275,40 @@ def apply_callbacks(self, time: np.float64, current_step: int) -> None: """ for func in self._feature_group_callback: func(time=time, current_step=current_step) + + +if TYPE_CHECKING: + from .protocol import SystemCollectionProtocol + from .constraints import Constraints + from .forcing import Forcing + from .connections import Connections + from .contact import Contact + from .damping import Damping + from .callbacks import CallBacks + + class BaseFeature(BaseSystemCollection): + pass + + class PartialFeatureA( + BaseSystemCollection, Constraints, Forcing, Damping, CallBacks + ): + pass + + class PartialFeatureB(BaseSystemCollection, Contact, Connections): + pass + + class FullFeature( + BaseSystemCollection, + Constraints, + Contact, + Connections, + Forcing, + Damping, + CallBacks, + ): + pass + + _: SystemCollectionProtocol = FullFeature() + _: SystemCollectionProtocol = PartialFeatureA() # type: ignore[no-redef] + _: SystemCollectionProtocol = PartialFeatureB() # type: ignore[no-redef] + _: SystemCollectionProtocol = BaseFeature() # type: ignore[no-redef] diff --git a/elastica/modules/callbacks.py b/elastica/modules/callbacks.py index de1e5091..f65dfd82 100644 --- a/elastica/modules/callbacks.py +++ b/elastica/modules/callbacks.py @@ -9,10 +9,12 @@ from elastica.typing import SystemType, SystemIdxType, OperatorFinalizeType from .protocol import ModuleProtocol +import functools + import numpy as np from elastica.callback_functions import CallBackBaseClass -from .protocol import SystemCollectionProtocol +from .protocol import SystemCollectionWithCallbackProtocol class CallBacks: @@ -27,15 +29,13 @@ class CallBacks: List of call back classes defined for rod-like objects. """ - def __init__(self: SystemCollectionProtocol) -> None: + def __init__(self: SystemCollectionWithCallbackProtocol) -> None: self._callback_list: list[ModuleProtocol] = [] - self._callback_operators: list[tuple[int, CallBackBaseClass]] = [] super(CallBacks, self).__init__() - self._feature_group_callback.append(self._callback_execution) self._feature_group_finalize.append(self._finalize_callback) def collect_diagnostics( - self: SystemCollectionProtocol, system: SystemType + self: SystemCollectionWithCallbackProtocol, system: SystemType ) -> ModuleProtocol: """ This method calls user-defined call-back classes for a @@ -54,31 +54,26 @@ def collect_diagnostics( sys_idx: SystemIdxType = self.get_system_index(system) # Create _Constraint object, cache it and return to user - _callbacks: ModuleProtocol = _CallBack(sys_idx) - self._callback_list.append(_callbacks) + _callback: ModuleProtocol = _CallBack(sys_idx) + self._callback_list.append(_callback) + self._feature_group_callback.append_id(_callback) - return _callbacks + return _callback - def _finalize_callback(self: SystemCollectionProtocol) -> None: + def _finalize_callback(self: SystemCollectionWithCallbackProtocol) -> None: # dev : the first index stores the rod index to collect data. - self._callback_operators = [ - (callback.id(), callback.instantiate()) for callback in self._callback_list - ] + for callback in self._callback_list: + sys_id = callback.id() + callback_instance = callback.instantiate() + + callback_operator = functools.partial( + callback_instance.make_callback, system=self[sys_id] + ) + self._feature_group_callback.add_operators(callback, [callback_operator]) + self._callback_list.clear() del self._callback_list - # First callback execution - time = np.float64(0.0) - self._callback_execution(time=time, current_step=0) - - def _callback_execution( - self: SystemCollectionProtocol, - time: np.float64, - current_step: int, - ) -> None: - for sys_id, callback in self._callback_operators: - callback.make_callback(self[sys_id], time, current_step) - class _CallBack: """ diff --git a/elastica/modules/connections.py b/elastica/modules/connections.py index 44a59073..0f59cb2b 100644 --- a/elastica/modules/connections.py +++ b/elastica/modules/connections.py @@ -18,7 +18,7 @@ import functools from elastica.joint import FreeJoint -from .protocol import SystemCollectionProtocol, ModuleProtocol +from .protocol import ConnectedSystemCollectionProtocol, ModuleProtocol class Connections: @@ -33,13 +33,13 @@ class Connections: List of joint classes defined for rod-like objects. """ - def __init__(self: SystemCollectionProtocol) -> None: + def __init__(self: ConnectedSystemCollectionProtocol) -> None: self._connections: list[ModuleProtocol] = [] super(Connections, self).__init__() self._feature_group_finalize.append(self._finalize_connections) def connect( - self: SystemCollectionProtocol, + self: ConnectedSystemCollectionProtocol, first_rod: "RodType | RigidBodyType", second_rod: "RodType | RigidBodyType", first_connect_idx: ConnectionIndex = (), @@ -81,7 +81,7 @@ def connect( return _connect - def _finalize_connections(self: SystemCollectionProtocol) -> None: + def _finalize_connections(self: ConnectedSystemCollectionProtocol) -> None: # From stored _Connect objects, instantiate the joints and store it # dev : the first indices stores the # (first rod index, second_rod_idx, connection_idx_on_first_rod, connection_idx_on_second_rod) diff --git a/elastica/modules/constraints.py b/elastica/modules/constraints.py index 029ed961..3686a135 100644 --- a/elastica/modules/constraints.py +++ b/elastica/modules/constraints.py @@ -7,6 +7,8 @@ from typing import Any, Type, cast from typing_extensions import Self +import functools + import numpy as np from elastica.boundary_conditions import ConstraintBase @@ -16,9 +18,9 @@ ConstrainingIndex, RigidBodyType, RodType, - BlockSystemType, ) -from .protocol import SystemCollectionProtocol, ModuleProtocol +from elastica.memory_block.protocol import BlockRodProtocol +from .protocol import ConstrainedSystemCollectionProtocol, ModuleProtocol class Constraints: @@ -33,15 +35,13 @@ class Constraints: List of boundary condition classes defined for rod-like objects. """ - def __init__(self: SystemCollectionProtocol) -> None: + def __init__(self: ConstrainedSystemCollectionProtocol) -> None: self._constraints_list: list[ModuleProtocol] = [] super(Constraints, self).__init__() - self._feature_group_constrain_values.append(self._constrain_values) - self._feature_group_constrain_rates.append(self._constrain_rates) self._feature_group_finalize.append(self._finalize_constraints) def constrain( - self: SystemCollectionProtocol, system: "RodType | RigidBodyType" + self: ConstrainedSystemCollectionProtocol, system: "RodType | RigidBodyType" ) -> ModuleProtocol: """ This method enforces a displacement boundary conditions to the relevant user-defined @@ -62,24 +62,29 @@ def constrain( # Create _Constraint object, cache it and return to user _constraint: ModuleProtocol = _Constraint(sys_idx) self._constraints_list.append(_constraint) + self._feature_group_constrain_values.append_id(_constraint) + self._feature_group_constrain_rates.append_id(_constraint) return _constraint - def _finalize_constraints(self: SystemCollectionProtocol) -> None: + def _finalize_constraints(self: ConstrainedSystemCollectionProtocol) -> None: """ In case memory block have ring rod, then periodic boundaries have to be synched. In order to synchronize periodic boundaries, a new constrain for memory block rod added called as _ConstrainPeriodicBoundaries. This constrain will synchronize the only periodic boundaries of position, director, velocity and omega variables. """ - from elastica._synchronize_periodic_boundary import _ConstrainPeriodicBoundaries for block in self.block_systems(): # append the memory block to the simulation as a system. Memory block is the final system in the simulation. if hasattr(block, "ring_rod_flag"): + from elastica._synchronize_periodic_boundary import ( + _ConstrainPeriodicBoundaries, + ) + # Apply the constrain to synchronize the periodic boundaries of the memory rod. Find the memory block # sys idx among other systems added and then apply boundary conditions. memory_block_idx = self.get_system_index(block) - block_system = cast(BlockSystemType, self[memory_block_idx]) + block_system = cast(BlockRodProtocol, self[memory_block_idx]) self.constrain(block_system).using( _ConstrainPeriodicBoundaries, ) @@ -89,31 +94,38 @@ def _finalize_constraints(self: SystemCollectionProtocol) -> None: # dev : the first index stores the rod index to apply the boundary condition # to. - self._constraints_operators = [ - (constraint.id(), constraint.instantiate(self[constraint.id()])) - for constraint in self._constraints_list - ] - # Sort from lowest id to highest id for potentially better memory access # _constraints contains list of tuples. First element of tuple is rod number and # following elements are the type of boundary condition such as # [(0, ConstraintBase, OneEndFixedBC), (1, HelicalBucklingBC), ... ] # Thus using lambda we iterate over the list of tuples and use rod number (x[0]) # to sort constraints. - self._constraints_operators.sort(key=lambda x: x[0]) + self._constraints_list.sort(key=lambda x: x.id()) + for constraint in self._constraints_list: + sys_id = constraint.id() + constraint_instance = constraint.instantiate(self[sys_id]) + + constrain_values = functools.partial( + constraint_instance.constrain_values, system=self[sys_id] + ) + constrain_rates = functools.partial( + constraint_instance.constrain_rates, system=self[sys_id] + ) + + self._feature_group_constrain_values.add_operators( + constraint, [constrain_values] + ) + self._feature_group_constrain_rates.add_operators( + constraint, [constrain_rates] + ) # At t=0.0, constrain all the boundary conditions (for compatability with # initial conditions) - self._constrain_values(time=np.float64(0.0)) - self._constrain_rates(time=np.float64(0.0)) - - def _constrain_values(self: SystemCollectionProtocol, time: np.float64) -> None: - for sys_id, constraint in self._constraints_operators: - constraint.constrain_values(self[sys_id], time) + self.constrain_values(time=np.float64(0.0)) + self.constrain_rates(time=np.float64(0.0)) - def _constrain_rates(self: SystemCollectionProtocol, time: np.float64) -> None: - for sys_id, constraint in self._constraints_operators: - constraint.constrain_rates(self[sys_id], time) + self._constraints_list = [] + del self._constraints_list class _Constraint: diff --git a/elastica/modules/contact.py b/elastica/modules/contact.py index 76491664..c2443adb 100644 --- a/elastica/modules/contact.py +++ b/elastica/modules/contact.py @@ -9,14 +9,13 @@ from typing_extensions import Self from elastica.typing import ( SystemIdxType, - OperatorFinalizeType, + OperatorType, StaticSystemType, SystemType, ) -from .protocol import SystemCollectionProtocol, ModuleProtocol +from .protocol import ContactedSystemCollectionProtocol, ModuleProtocol import logging -import functools import numpy as np @@ -40,13 +39,13 @@ class Contact: List of contact classes defined for rod-like objects. """ - def __init__(self: SystemCollectionProtocol) -> None: + def __init__(self: ContactedSystemCollectionProtocol) -> None: self._contacts: list[ModuleProtocol] = [] super(Contact, self).__init__() self._feature_group_finalize.append(self._finalize_contact) def detect_contact_between( - self: SystemCollectionProtocol, + self: ContactedSystemCollectionProtocol, first_system: SystemType, second_system: "SystemType | StaticSystemType", ) -> ModuleProtocol: @@ -73,23 +72,12 @@ def detect_contact_between( return _contact - def _finalize_contact(self: SystemCollectionProtocol) -> None: + def _finalize_contact(self: ContactedSystemCollectionProtocol) -> None: # dev : the first indices stores the # (first_rod_idx, second_rod_idx) # to apply the contacts to - def apply_contact( - time: np.float64, - contact_instance: NoContact, - first_sys_idx: SystemIdxType, - second_sys_idx: SystemIdxType, - ) -> None: - contact_instance.apply_contact( - system_one=self[first_sys_idx], - system_two=self[second_sys_idx], - ) - for contact in self._contacts: first_sys_idx, second_sys_idx = contact.id() contact_instance = contact.instantiate() @@ -98,12 +86,11 @@ def apply_contact( self[first_sys_idx], self[second_sys_idx], ) - func = functools.partial( - apply_contact, - contact_instance=contact_instance, - first_sys_idx=first_sys_idx, - second_sys_idx=second_sys_idx, + func: OperatorType = lambda time: contact_instance.apply_contact( + system_one=self[first_sys_idx], + system_two=self[second_sys_idx], ) + self._feature_group_synchronize.add_operators(contact, [func]) if not self._feature_group_synchronize.is_last(contact): diff --git a/elastica/modules/damping.py b/elastica/modules/damping.py index ea9ea5ea..0d2c3e64 100644 --- a/elastica/modules/damping.py +++ b/elastica/modules/damping.py @@ -12,11 +12,13 @@ from typing import Any, Type, List from typing_extensions import Self +import functools + import numpy as np from elastica.dissipation import DamperBase from elastica.typing import RodType, SystemType, SystemIdxType -from .protocol import SystemCollectionProtocol, ModuleProtocol +from .protocol import DampenedSystemCollectionProtocol, ModuleProtocol class Damping: @@ -31,13 +33,14 @@ class Damping: List of damper classes defined for rod-like objects. """ - def __init__(self: SystemCollectionProtocol) -> None: + def __init__(self: DampenedSystemCollectionProtocol) -> None: self._damping_list: List[ModuleProtocol] = [] super().__init__() - self._feature_group_constrain_rates.append(self._dampen_rates) self._feature_group_finalize.append(self._finalize_dampers) - def dampen(self: SystemCollectionProtocol, system: RodType) -> ModuleProtocol: + def dampen( + self: DampenedSystemCollectionProtocol, system: RodType + ) -> ModuleProtocol: """ This method applies damping on relevant user-defined system or rod-like object. You must input the system or rod-like @@ -57,28 +60,29 @@ def dampen(self: SystemCollectionProtocol, system: RodType) -> ModuleProtocol: # Create _Damper object, cache it and return to user _damper: ModuleProtocol = _Damper(sys_idx) self._damping_list.append(_damper) + self._feature_group_constrain_rates.append_id(_damper) return _damper - def _finalize_dampers(self: SystemCollectionProtocol) -> None: + def _finalize_dampers(self: DampenedSystemCollectionProtocol) -> None: # From stored _Damping objects, instantiate the dissipation/damping # inplace : https://stackoverflow.com/a/1208792 - self._damping_operators = [ - (damper.id(), damper.instantiate(self[damper.id()])) - for damper in self._damping_list - ] - # Sort from lowest id to highest id for potentially better memory access # _dampers contains list of tuples. First element of tuple is rod number and # following elements are the type of damping. # Thus using lambda we iterate over the list of tuples and use rod number (x[0]) # to sort dampers. - self._damping_operators.sort(key=lambda x: x[0]) + self._damping_list.sort(key=lambda x: x.id()) + for damping in self._damping_list: + sys_id = damping.id() + damping_instance = damping.instantiate(self[sys_id]) + + dampen_rate = functools.partial(damping_instance.dampen_rates, self[sys_id]) + self._feature_group_constrain_rates.add_operators(damping, [dampen_rate]) - def _dampen_rates(self: SystemCollectionProtocol, time: np.float64) -> None: - for sys_id, damper in self._damping_operators: - damper.dampen_rates(self[sys_id], time) + self._damping_list = [] + del self._damping_list class _Damper: diff --git a/elastica/modules/forcing.py b/elastica/modules/forcing.py index e36b0899..8b7f6282 100644 --- a/elastica/modules/forcing.py +++ b/elastica/modules/forcing.py @@ -14,7 +14,7 @@ from elastica.external_forces import NoForces from elastica.typing import SystemType, SystemIdxType -from .protocol import SystemCollectionProtocol, ModuleProtocol +from .protocol import ForcedSystemCollectionProtocol, ModuleProtocol logger = logging.getLogger(__name__) @@ -31,13 +31,13 @@ class Forcing: List of forcing class defined for rod-like objects. """ - def __init__(self: SystemCollectionProtocol) -> None: + def __init__(self: ForcedSystemCollectionProtocol) -> None: self._ext_forces_torques: List[ModuleProtocol] = [] super().__init__() self._feature_group_finalize.append(self._finalize_forcing) def add_forcing_to( - self: SystemCollectionProtocol, system: SystemType + self: ForcedSystemCollectionProtocol, system: SystemType ) -> ModuleProtocol: """ This method applies external forces and torques on the relevant @@ -62,7 +62,7 @@ def add_forcing_to( return _ext_force_torque - def _finalize_forcing(self: SystemCollectionProtocol) -> None: + def _finalize_forcing(self: ForcedSystemCollectionProtocol) -> None: # From stored _ExtForceTorque objects, and instantiate a Force # inplace : https://stackoverflow.com/a/1208792 diff --git a/elastica/modules/operator_group.py b/elastica/modules/operator_group.py index 3b01225d..d1905ad5 100644 --- a/elastica/modules/operator_group.py +++ b/elastica/modules/operator_group.py @@ -1,11 +1,12 @@ -from typing import TypeVar, Generic, Iterator - -from collections.abc import Iterable +from typing import TYPE_CHECKING, TypeVar, Generic, Callable, Any +from collections.abc import Iterable, Iterator import itertools -T = TypeVar("T") -F = TypeVar("F") +from .protocol import ModuleProtocol + +T = TypeVar("T", bound=Callable) +F = TypeVar("F", bound=ModuleProtocol) class OperatorGroupFIFO(Iterable, Generic[T, F]): @@ -78,3 +79,9 @@ def add_operators(self, feature: F, operators: list[T]) -> None: def is_last(self, feature: F) -> bool: """Checks if the feature is the last feature in the FIFO.""" return id(feature) == self._operator_ids[-1] + + +if TYPE_CHECKING: + from elastica.typing import OperatorType + + _: Iterable[OperatorType] = OperatorGroupFIFO[OperatorType, Any]() diff --git a/elastica/modules/protocol.py b/elastica/modules/protocol.py index eb1b5d84..fe7cfea5 100644 --- a/elastica/modules/protocol.py +++ b/elastica/modules/protocol.py @@ -1,8 +1,7 @@ from typing import Protocol, Generator, TypeVar, Any, Type, overload +from typing import TYPE_CHECKING from typing_extensions import Self # python 3.11: from typing import Self -from abc import abstractmethod - from elastica.typing import ( SystemIdxType, OperatorType, @@ -10,6 +9,8 @@ OperatorFinalizeType, StaticSystemType, SystemType, + RodType, + RigidBodyType, BlockSystemType, ConnectionIndex, ) @@ -20,10 +21,16 @@ import numpy as np -from .operator_group import OperatorGroupFIFO +if TYPE_CHECKING: + from .operator_group import OperatorGroupFIFO + + +class MixinProtocol(Protocol): + # def finalize(self) -> None: ... + ... -M = TypeVar("M", bound="ModuleProtocol") +M = TypeVar("M", bound=MixinProtocol) class ModuleProtocol(Protocol[M]): @@ -47,106 +54,89 @@ def __getitem__(self, i: slice) -> list[SystemType]: ... def __getitem__(self, i: int) -> SystemType: ... def __getitem__(self, i: slice | int) -> "list[SystemType] | SystemType": ... - @property - def _feature_group_synchronize(self) -> OperatorGroupFIFO: ... + def __delitem__(self, i: slice | int) -> None: ... + def __setitem__(self, i: slice | int, value: SystemType) -> None: ... + def insert(self, i: int, value: SystemType) -> None: ... - def synchronize(self, time: np.float64) -> None: ... + def get_system_index( + self, sys_to_be_added: "SystemType | StaticSystemType" + ) -> SystemIdxType: ... - @property - def _feature_group_constrain_values(self) -> list[OperatorType]: ... + # Operator Group + _feature_group_synchronize: "OperatorGroupFIFO[OperatorType, ModuleProtocol]" + _feature_group_constrain_values: "OperatorGroupFIFO[OperatorType, ModuleProtocol]" + _feature_group_constrain_rates: "OperatorGroupFIFO[OperatorType, ModuleProtocol]" + _feature_group_callback: "OperatorGroupFIFO[OperatorCallbackType, ModuleProtocol]" + def synchronize(self, time: np.float64) -> None: ... def constrain_values(self, time: np.float64) -> None: ... - - @property - def _feature_group_constrain_rates(self) -> list[OperatorType]: ... - def constrain_rates(self, time: np.float64) -> None: ... - - @property - def _feature_group_callback(self) -> list[OperatorCallbackType]: ... - def apply_callbacks(self, time: np.float64, current_step: int) -> None: ... - @property - def _feature_group_finalize(self) -> list[OperatorFinalizeType]: ... + # Finalize Operations + _feature_group_finalize: list[OperatorFinalizeType] - def get_system_index( - self, sys_to_be_added: "SystemType | StaticSystemType" - ) -> SystemIdxType: ... + def finalize(self) -> None: ... + +# Mixin Protocols (Used to type Self) +class ConnectedSystemCollectionProtocol(SystemCollectionProtocol, Protocol): # Connection API - _finalize_connections: OperatorFinalizeType _connections: list[ModuleProtocol] - @abstractmethod + def _finalize_connections(self) -> None: ... + def connect( self, - first_rod: SystemType, - second_rod: SystemType, + first_rod: "RodType | RigidBodyType", + second_rod: "RodType | RigidBodyType", first_connect_idx: ConnectionIndex, second_connect_idx: ConnectionIndex, - ) -> ModuleProtocol: - raise NotImplementedError + ) -> ModuleProtocol: ... - # CallBack API - _finalize_callback: OperatorFinalizeType - _callback_list: list[ModuleProtocol] - _callback_operators: list[tuple[int, CallBackBaseClass]] - @abstractmethod - def collect_diagnostics(self, system: SystemType) -> ModuleProtocol: - raise NotImplementedError +class ForcedSystemCollectionProtocol(SystemCollectionProtocol, Protocol): + # Forcing API + _ext_forces_torques: list[ModuleProtocol] + + def _finalize_forcing(self) -> None: ... + + def add_forcing_to(self, system: SystemType) -> ModuleProtocol: ... + + +class ContactedSystemCollectionProtocol(SystemCollectionProtocol, Protocol): + # Contact API + _contacts: list[ModuleProtocol] + + def _finalize_contact(self) -> None: ... - @abstractmethod - def _callback_execution( - self, time: np.float64, current_step: int, *args: Any, **kwargs: Any - ) -> None: - raise NotImplementedError + def detect_contact_between( + self, first_system: SystemType, second_system: SystemType + ) -> ModuleProtocol: ... + +class ConstrainedSystemCollectionProtocol(SystemCollectionProtocol, Protocol): # Constraints API _constraints_list: list[ModuleProtocol] - _constraints_operators: list[tuple[int, ConstraintBase]] - _finalize_constraints: OperatorFinalizeType - @abstractmethod - def constrain(self, system: SystemType) -> ModuleProtocol: - raise NotImplementedError + def _finalize_constraints(self) -> None: ... - @abstractmethod - def _constrain_values(self, time: np.float64) -> None: - raise NotImplementedError + def constrain(self, system: "RodType | RigidBodyType") -> ModuleProtocol: ... - @abstractmethod - def _constrain_rates(self, time: np.float64) -> None: - raise NotImplementedError - # Forcing API - _ext_forces_torques: list[ModuleProtocol] - _finalize_forcing: OperatorFinalizeType +class SystemCollectionWithCallbackProtocol(SystemCollectionProtocol, Protocol): + # CallBack API + _callback_list: list[ModuleProtocol] - @abstractmethod - def add_forcing_to(self, system: SystemType) -> ModuleProtocol: - raise NotImplementedError + def _finalize_callback(self) -> None: ... - # Contact API - _contacts: list[ModuleProtocol] - _finalize_contact: OperatorFinalizeType + def collect_diagnostics(self, system: SystemType) -> ModuleProtocol: ... - @abstractmethod - def detect_contact_between( - self, first_system: SystemType, second_system: SystemType - ) -> ModuleProtocol: - raise NotImplementedError +class DampenedSystemCollectionProtocol(SystemCollectionProtocol, Protocol): # Damping API _damping_list: list[ModuleProtocol] - _damping_operators: list[tuple[int, DamperBase]] - _finalize_dampers: OperatorFinalizeType - @abstractmethod - def dampen(self, system: SystemType) -> ModuleProtocol: - raise NotImplementedError + def _finalize_dampers(self) -> None: ... - @abstractmethod - def _dampen_rates(self, time: np.float64) -> None: - raise NotImplementedError + def dampen(self, system: RodType) -> ModuleProtocol: ... diff --git a/elastica/systems/protocol.py b/elastica/systems/protocol.py index 254cdcaf..89d52d92 100644 --- a/elastica/systems/protocol.py +++ b/elastica/systems/protocol.py @@ -68,15 +68,3 @@ def kinematic_rates( def dynamic_rates( self, time: np.float64, prefac: np.float64 ) -> NDArray[np.float64]: ... - - -class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol): - # TODO: Temporarily made to handle explicit stepper. - # Need to be refactored as the explicit stepper is further developed. - def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ... - @property - def state(self) -> StateType: ... - @state.setter - def state(self, state: StateType) -> None: ... - @property - def n_elems(self) -> int: ... diff --git a/elastica/timestepper/__init__.py b/elastica/timestepper/__init__.py index 25f88341..264c9d8b 100644 --- a/elastica/timestepper/__init__.py +++ b/elastica/timestepper/__init__.py @@ -1,23 +1,21 @@ __doc__ = """Timestepping utilities to be used with Rod and RigidBody classes""" -from typing import Tuple, List, Callable, Type, Any, overload, cast -from elastica.typing import SystemType, SystemCollectionType, SteppersOperatorsType +from typing import Callable +from elastica.typing import SystemCollectionType, SteppersOperatorsType import numpy as np from tqdm import tqdm from elastica.systems import is_system_a_collection -from .symplectic_steppers import PositionVerlet, PEFRL -from .explicit_steppers import RungeKutta4, EulerForward -from .protocol import StepperProtocol, SymplecticStepperProtocol +from .protocol import StepperProtocol # Deprecated: Remove in the future version # Many script still uses this method to control timestep. Keep it for backward compatibility def extend_stepper_interface( stepper: StepperProtocol, system_collection: SystemCollectionType -) -> Tuple[ +) -> tuple[ Callable[ [StepperProtocol, SystemCollectionType, np.float64, np.float64], np.float64 ], @@ -31,32 +29,10 @@ def extend_stepper_interface( return do_step_method, stepper_methods -@overload -def integrate( - stepper: StepperProtocol, - systems: SystemType, - final_time: float, - n_steps: int, - restart_time: float, - progress_bar: bool, -) -> float: ... - - -@overload def integrate( stepper: StepperProtocol, systems: SystemCollectionType, final_time: float, - n_steps: int, - restart_time: float, - progress_bar: bool, -) -> float: ... - - -def integrate( - stepper: StepperProtocol, - systems: "SystemType | SystemCollectionType", - final_time: float, n_steps: int = 1000, restart_time: float = 0.0, progress_bar: bool = True, @@ -67,7 +43,7 @@ def integrate( ---------- stepper : StepperProtocol Stepper algorithm to use. - systems : SystemType | SystemCollectionType + systems : SystemCollectionType The elastica-system to simulate. final_time : float Total simulation time. The timestep is determined by final_time / n_steps. @@ -85,13 +61,12 @@ def integrate( time = np.float64(restart_time) if is_system_a_collection(systems): - systems = cast(SystemCollectionType, systems) for i in tqdm(range(n_steps), disable=(not progress_bar)): time = stepper.step(systems, time, dt) else: - systems = cast(SystemType, systems) + # Typing is ignored since this part only exist for unit-testing for i in tqdm(range(n_steps), disable=(not progress_bar)): - time = stepper.step_single_instance(systems, time, dt) + time = stepper.step_single_instance(systems, time, dt) # type: ignore[arg-type] print("Final time of simulation is : ", time) return float(time) diff --git a/elastica/timestepper/protocol.py b/elastica/timestepper/protocol.py index 1a64c725..18a92fc4 100644 --- a/elastica/timestepper/protocol.py +++ b/elastica/timestepper/protocol.py @@ -3,11 +3,11 @@ from typing import Protocol from elastica.typing import ( - SystemType, SteppersOperatorsType, - OperatorType, + StepType, SystemCollectionType, ) +from elastica.systems.protocol import SymplecticSystemProtocol import numpy as np @@ -29,80 +29,13 @@ def step( ) -> np.float64: ... def step_single_instance( - self, SystemCollection: SystemType, time: np.float64, dt: np.float64 + self, System: SymplecticSystemProtocol, time: np.float64, dt: np.float64 ) -> np.float64: ... class SymplecticStepperProtocol(StepperProtocol, Protocol): """symplectic stepper protocol.""" - def get_steps(self) -> list[OperatorType]: ... + def get_steps(self) -> list[StepType]: ... - def get_prefactors(self) -> list[OperatorType]: ... - - -class MemoryProtocol(Protocol): - @property - def initial_state(self) -> bool: ... - - -class ExplicitStepperProtocol(StepperProtocol, Protocol): - """symplectic stepper protocol.""" - - def get_stages(self) -> list[OperatorType]: ... - - def get_updates(self) -> list[OperatorType]: ... - - -# class _LinearExponentialIntegratorMixin: -# """ -# Linear Exponential integrator mixin wrapper. -# """ -# -# def __init__(self): -# pass -# -# def _do_stage(self, System, Memory, time, dt): -# # TODO : Make more general, system should not be calculating what the state -# # transition matrix directly is, but rather it should just give -# Memory.linear_operator = System.get_linear_state_transition_operator(time, dt) -# -# def _do_update(self, System, Memory, time, dt): -# # FIXME What's the right formula when doing update? -# # System.linearly_evolving_state = _batch_matmul( -# # System.linearly_evolving_state, -# # Memory.linear_operator -# # ) -# System.linearly_evolving_state = np.einsum( -# "ijk,ljk->ilk", System.linearly_evolving_state, Memory.linear_operator -# ) -# return time + dt -# -# def _first_prefactor(self, dt): -# """Prefactor call to satisfy interface of SymplecticStepper. Should never -# be used in actual code. -# -# Parameters -# ---------- -# dt : the time step of simulation -# -# Raises -# ------ -# RuntimeError -# """ -# raise RuntimeError( -# "Symplectic prefactor of LinearExponentialIntegrator should not be called!" -# ) -# -# # Code repeat! -# # Easy to avoid, but keep for performance. -# def _do_one_step(self, System, time, prefac): -# System.linearly_evolving_state = np.einsum( -# "ijk,ljk->ilk", -# System.linearly_evolving_state, -# System.get_linear_state_transition_operator(time, prefac), -# ) -# return ( -# time # TODO fix hack that treats time separately here. Shuold be time + dt -# ) -# # return time + dt + def get_prefactors(self) -> list[StepType]: ... diff --git a/elastica/timestepper/symplectic_steppers.py b/elastica/timestepper/symplectic_steppers.py index 3937e337..4bd355af 100644 --- a/elastica/timestepper/symplectic_steppers.py +++ b/elastica/timestepper/symplectic_steppers.py @@ -1,15 +1,12 @@ __doc__ = """Symplectic time steppers and concepts for integrating the kinematic and dynamic equations of rod-like objects. """ -from typing import Any, Final +from typing import TYPE_CHECKING, Any from itertools import zip_longest from elastica.typing import ( - SystemType, SystemCollectionType, - # StepOperatorType, - # PrefactorOperatorType, - OperatorType, + StepType, SteppersOperatorsType, ) @@ -33,22 +30,22 @@ class SymplecticStepperMixin: def __init__(self: SymplecticStepperProtocol): - self.steps_and_prefactors: Final[SteppersOperatorsType] = self.step_methods() + self.steps_and_prefactors: SteppersOperatorsType = self.step_methods() def step_methods(self: SymplecticStepperProtocol) -> SteppersOperatorsType: # Let the total number of steps for the Symplectic method # be (2*n + 1) (for time-symmetry). - _steps: list[OperatorType] = self.get_steps() + _steps: list[StepType] = self.get_steps() # Prefac here is necessary because the linear-exponential integrator # needs only the prefactor and not the dt. - _prefactors: list[OperatorType] = self.get_prefactors() + _prefactors: list[StepType] = self.get_prefactors() assert int(np.ceil(len(_steps) / 2)) == len( _prefactors ), f"{len(_steps)=}, {len(_prefactors)=}" # Separate the kinematic and dynamic steps - _kinematic_steps: list[OperatorType] = _steps[::2] - _dynamic_steps: list[OperatorType] = _steps[1::2] + _kinematic_steps: list[StepType] = _steps[::2] + _dynamic_steps: list[StepType] = _steps[1::2] def no_operation(*args: Any) -> None: pass @@ -164,14 +161,14 @@ class PositionVerlet(SymplecticStepperMixin): includes methods for second-order position Verlet. """ - def get_steps(self) -> list[OperatorType]: + def get_steps(self) -> list[StepType]: return [ self._first_kinematic_step, self._first_dynamic_step, self._first_kinematic_step, ] - def get_prefactors(self) -> list[OperatorType]: + def get_prefactors(self) -> list[StepType]: return [ self._first_prefactor, self._first_prefactor, @@ -218,7 +215,7 @@ class PEFRL(SymplecticStepperMixin): lambda_dash_coeff: np.float64 = 0.5 * (1.0 - 2.0 * λ) xi_chi_dash_coeff: np.float64 = 1.0 - 2.0 * (ξ + χ) - def get_steps(self) -> list[OperatorType]: + def get_steps(self) -> list[StepType]: operators = [ self._first_kinematic_step, self._first_dynamic_step, @@ -228,7 +225,7 @@ def get_steps(self) -> list[OperatorType]: ] return operators + operators[-2::-1] - def get_prefactors(self) -> list[OperatorType]: + def get_prefactors(self) -> list[StepType]: return [ self._first_kinematic_prefactor, self._second_kinematic_prefactor, @@ -308,3 +305,10 @@ def _third_kinematic_step( System.omega_collection, ) # System.kinematic_states += prefac * System.kinematic_rates(time, prefac) + + +if TYPE_CHECKING: + from .protocol import StepperProtocol + + _: StepperProtocol = PositionVerlet() + _: StepperProtocol = PEFRL() # type: ignore [no-redef] diff --git a/elastica/typing.py b/elastica/typing.py index 79cba003..f77cd015 100644 --- a/elastica/typing.py +++ b/elastica/typing.py @@ -4,7 +4,7 @@ """ from typing import TYPE_CHECKING -from typing import Callable, Any, ParamSpec, TypeAlias +from typing import Callable, Any, TypeAlias, Protocol import numpy as np @@ -23,12 +23,10 @@ SystemProtocol, StaticSystemProtocol, SymplecticSystemProtocol, - ExplicitSystemProtocol, ) from .timestepper.protocol import ( StepperProtocol, SymplecticStepperProtocol, - MemoryProtocol, ) from .memory_block.protocol import BlockSystemProtocol @@ -45,8 +43,8 @@ StateType: TypeAlias = "State" # TODO: Maybe can be more specific. Up for discussion. -OperatorType: TypeAlias = Callable[..., Any] -SteppersOperatorsType: TypeAlias = tuple[tuple[OperatorType, ...], ...] +StepType: TypeAlias = Callable[..., Any] +SteppersOperatorsType: TypeAlias = tuple[tuple[StepType, ...], ...] RodType: TypeAlias = "CosseratRodProtocol" @@ -62,10 +60,16 @@ int | np.int32 | list[int] | tuple[int, ...] | np.typing.NDArray[np.int32] ) + # Operators in elastica.modules -# TODO: can be more specific. -OperatorParam = ParamSpec("OperatorParam") -OperatorCallbackType: TypeAlias = Callable[..., None] -OperatorFinalizeType: TypeAlias = Callable[..., None] +class OperatorType(Protocol): + def __call__(self, time: np.float64) -> None: ... + + +class OperatorCallbackType(Protocol): + def __call__(self, time: np.float64, current_step: int) -> None: ... + + +OperatorFinalizeType: TypeAlias = Callable[[], None] MeshType: TypeAlias = "MeshProtocol" diff --git a/pyproject.toml b/pyproject.toml index bbd4e9ce..eef6872b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,6 @@ warn_unused_configs = true warn_unused_ignores = false exclude = [ - "elastica/systems/analytical.py", "elastica/experimental/*", ] @@ -183,7 +182,6 @@ branch = true omit = [ "*/.local/*", "setup.py", - "elastica/systems/analytical.py", "elastica/experimental/*", "elastica/**/protocol.py", ] diff --git a/elastica/systems/analytical.py b/tests/analytical.py similarity index 100% rename from elastica/systems/analytical.py rename to tests/analytical.py diff --git a/tests/test_math/test_timestepper.py b/tests/test_math/test_timestepper.py index 2afa6de6..62f79a45 100644 --- a/tests/test_math/test_timestepper.py +++ b/tests/test_math/test_timestepper.py @@ -3,19 +3,8 @@ import pytest from numpy.testing import assert_allclose -from elastica.systems.analytical import ( - ScalarExponentialDecaySystem, - # UndampedSimpleHarmonicOscillatorSystem, - SymplecticUndampedSimpleHarmonicOscillatorSystem, - # DampedSimpleHarmonicOscillatorSystem, - # MultipleFrameRotationSystem, - # SecondOrderHybridSystem, - SymplecticUndampedHarmonicOscillatorCollectiveSystem, - ScalarExponentialDampedHarmonicOscillatorCollectiveSystem, -) from elastica.timestepper import integrate, extend_stepper_interface - -from elastica.timestepper.explicit_steppers import ( +from elastica.experimental.timestepper.explicit_steppers import ( RungeKutta4, EulerForward, ExplicitStepperMixin, @@ -25,10 +14,20 @@ PEFRL, SymplecticStepperMixin, ) - - from elastica.utils import Tolerance +from tests.analytical import ( + ScalarExponentialDecaySystem, + # UndampedSimpleHarmonicOscillatorSystem, + SymplecticUndampedSimpleHarmonicOscillatorSystem, + # DampedSimpleHarmonicOscillatorSystem, + # MultipleFrameRotationSystem, + # SecondOrderHybridSystem, + SymplecticUndampedHarmonicOscillatorCollectiveSystem, + ScalarExponentialDampedHarmonicOscillatorCollectiveSystem, + make_simple_system_with_positions_directors, +) + class TestExtendStepperInterface: """TODO add documentation""" @@ -245,7 +244,9 @@ def test_explicit_steppers(self, explicit_stepper): # Before stepping, let's extend the interface of the stepper # while providing memory slots - from elastica.systems.memory import make_memory_for_explicit_stepper + from elastica.experimental.timestepper.memory import ( + make_memory_for_explicit_stepper, + ) memory_collection = make_memory_for_explicit_stepper(stepper, collective_system) from elastica.timestepper import extend_stepper_interface @@ -307,9 +308,6 @@ class TestSteppersAgainstRodLikeSystems: @pytest.mark.parametrize("symplectic_stepper", SymplecticSteppers) def test_symplectics_against_ellipse_motion(self, symplectic_stepper): - from elastica.systems.analytical import ( - make_simple_system_with_positions_directors, - ) random_start_position = np.random.randn(3, 1) random_end_position = np.random.randn(3, 1) diff --git a/tests/test_modules/test_base_system.py b/tests/test_modules/test_base_system.py index 694bd84f..9d4d9c3b 100644 --- a/tests/test_modules/test_base_system.py +++ b/tests/test_modules/test_base_system.py @@ -186,7 +186,16 @@ def test_constraint(self, load_collection, legal_constraint): simulator_class.finalize() # After finalize check if the created constrain object is instance of the class we have given. assert isinstance( - simulator_class._constraints_operators[-1][-1], legal_constraint + simulator_class._feature_group_constrain_values._operator_collection[-1][ + -1 + ].func.__self__, + legal_constraint, + ) + assert isinstance( + simulator_class._feature_group_constrain_rates._operator_collection[-1][ + -1 + ].func.__self__, + legal_constraint, ) # TODO: this is a dummy test for constrain values and rates find a better way to test them @@ -225,7 +234,12 @@ def test_callback(self, load_collection, legal_callback): simulator_class.collect_diagnostics(rod).using(legal_callback) simulator_class.finalize() # After finalize check if the created callback object is instance of the class we have given. - assert isinstance(simulator_class._callback_operators[-1][-1], legal_callback) + assert isinstance( + simulator_class._feature_group_callback._operator_collection[-1][ + -1 + ].func.__self__, + legal_callback, + ) # TODO: this is a dummy test for apply_callbacks find a better way to test them simulator_class.apply_callbacks(time=0, current_step=0) diff --git a/tests/test_modules/test_callbacks.py b/tests/test_modules/test_callbacks.py index a58901e2..d5830229 100644 --- a/tests/test_modules/test_callbacks.py +++ b/tests/test_modules/test_callbacks.py @@ -161,10 +161,13 @@ def mock_init(self, *args, **kwargs): def test_callback_finalize_correctness(self, load_rod_with_callbacks): scwc, callback_cls = load_rod_with_callbacks + callback_features = [d for d in scwc._callback_list] scwc._finalize_callback() - for x, y in scwc._callback_operators: + for _callback in callback_features: + x = _callback.id() + y = _callback.instantiate() assert type(x) is int assert type(y) is callback_cls diff --git a/tests/test_modules/test_constraints.py b/tests/test_modules/test_constraints.py index c6072c39..6ff952eb 100644 --- a/tests/test_modules/test_constraints.py +++ b/tests/test_modules/test_constraints.py @@ -315,24 +315,20 @@ def constrain_rates(self, *args, **kwargs) -> None: def test_constrain_finalize_correctness(self, load_rod_with_constraints): scwc, bc_cls = load_rod_with_constraints + bc_features = [bc for bc in scwc._constraints_list] scwc._finalize_constraints() + assert not hasattr(scwc, "_constraints_list") - for x, y in scwc._constraints_operators: - assert type(x) is int - assert type(y) is bc_cls + for _constraint in bc_features: + x = _constraint.id() + y = _constraint.instantiate(scwc[x]) + assert isinstance(x, int) + assert isinstance(y, bc_cls) - def test_constraint_properties(self, load_rod_with_constraints): - scwc, _ = load_rod_with_constraints - scwc._finalize_constraints() - - for i in [0, 1, -1]: - x, y = scwc._constraints_operators[i] - mock_rod = scwc[i] # Test system - assert type(x) is int - assert type(y.system) is type(mock_rod) - assert y.system is mock_rod, f"{len(scwc)}" + assert type(y.system) is type(scwc[x]) + assert y.system is scwc[x], f"{len(scwc)}" # Test node indices assert y.constrained_position_idx.size == 0 # Test element indices. TODO: maybe add more generalized test diff --git a/tests/test_modules/test_damping.py b/tests/test_modules/test_damping.py index 2634da84..d3b23dc8 100644 --- a/tests/test_modules/test_damping.py +++ b/tests/test_modules/test_damping.py @@ -180,26 +180,18 @@ def dampen_rates(self, *args, **kwargs) -> None: return scwd, MockDamper - def test_dampen_finalize_correctness(self, load_rod_with_dampers): + def test_dampen_finalize_clear_instances(self, load_rod_with_dampers): scwd, damper_cls = load_rod_with_dampers + damping_features = [d for d in scwd._damping_list] scwd._finalize_dampers() + assert not hasattr(scwd, "_damping_list") - for x, y in scwd._damping_operators: - assert type(x) is int - assert type(y) is damper_cls - - def test_damper_properties(self, load_rod_with_dampers): - scwd, _ = load_rod_with_dampers - scwd._finalize_dampers() - - for i in [0, 1, -1]: - x, y = scwd._damping_operators[i] - mock_rod = scwd[i] - # Test system - assert type(x) is int - assert type(y.system) is type(mock_rod) - assert y.system is mock_rod, f"{len(scwd)}" + for _damping in damping_features: + x = _damping.id() + y = _damping.instantiate(scwd[x]) + assert isinstance(x, int) + assert isinstance(y, damper_cls) @pytest.mark.xfail def test_dampers_finalize_sorted(self, load_rod_with_dampers): diff --git a/tests/test_modules/test_feature_grouping.py b/tests/test_modules/test_feature_grouping.py index 76a281a8..8fa3c59e 100644 --- a/tests/test_modules/test_feature_grouping.py +++ b/tests/test_modules/test_feature_grouping.py @@ -1,4 +1,5 @@ from elastica.modules.operator_group import OperatorGroupFIFO +import functools def test_add_ids(): @@ -65,3 +66,77 @@ def test_is_last(): assert group.is_last(1) == False assert group.is_last(2) == True + + +class TestOperatorGroupingWithCallableModules: + class OperatorTypeA: + def __init__(self): + self.value = 0 + + def apply(self) -> None: + self.value += 1 + + class OperatorTypeB: + def __init__(self): + self.value2 = 0 + + def apply(self) -> None: + self.value2 -= 1 + + # def test_lambda(self): + # feature_group = OperatorGroupFIFO() + + # op_a = self.OperatorTypeA() + # feature_group.append_id(op_a) + # op_b = self.OperatorTypeB() + # feature_group.append_id(op_b) + + # for op in [op_a, op_b]: + # func = functools.partial(lambda t: op.apply()) + # feature_group.add_operators(op, [func]) + + # for operator in feature_group: + # operator(t=0) + + # assert op_a.value == 1 + # assert op_b.value2 == -1 + + # def test_def(self): + # feature_group = OperatorGroupFIFO() + + # op_a = self.OperatorTypeA() + # feature_group.append_id(op_a) + # op_b = self.OperatorTypeB() + # feature_group.append_id(op_b) + + # for op in [op_a, op_b]: + # def func(t): + # op.apply() + # feature_group.add_operators(op, [func]) + + # for operator in feature_group: + # operator(t=0) + + # assert op_a.value == 1 + # assert op_b.value2 == -1 + + def test_partial(self): + feature_group = OperatorGroupFIFO() + + op_a = self.OperatorTypeA() + feature_group.append_id(op_a) + op_b = self.OperatorTypeB() + feature_group.append_id(op_b) + + def _func(t, op): + op.apply() + + for op in [op_a, op_b]: + func = functools.partial(_func, op=op) + feature_group.add_operators(op, [func]) + + for operator in feature_group: + operator(t=0) + + assert op_a.value == 1 + assert op_b.value2 == -1 diff --git a/tests/test_rigid_body/test_rigid_body_data_structures.py b/tests/test_rigid_body/test_rigid_body_data_structures.py index 0ca7715a..0ede2bd8 100644 --- a/tests/test_rigid_body/test_rigid_body_data_structures.py +++ b/tests/test_rigid_body/test_rigid_body_data_structures.py @@ -5,13 +5,11 @@ from elastica.utils import Tolerance from elastica.rigidbody.data_structures import _RigidRodSymplecticStepperMixin from elastica._rotations import _rotate -from elastica.timestepper import ( - RungeKutta4, - EulerForward, +from elastica.timestepper.symplectic_steppers import ( PEFRL, PositionVerlet, - integrate, ) +from elastica.timestepper import integrate def make_simple_system_with_positions_directors(start_position, start_director): @@ -92,7 +90,6 @@ def analytical_solution(self, type, time): return analytical_solution -ExplicitSteppers = [EulerForward, RungeKutta4] SymplecticSteppers = [PositionVerlet, PEFRL]