Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Type] Check instantiation for system-collection and timestepper #405

Draft
wants to merge 12 commits into
base: update-0.3.3
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ flake8:
.PHONY: autoflake-check
autoflake-check:
poetry run autoflake --version
poetry run autoflake $(AUTOFLAKE_ARGS) elastica tests examples
poetry run autoflake --check $(AUTOFLAKE_ARGS) elastica tests examples

.PHONY: autoflake-format
Expand Down
5 changes: 1 addition & 4 deletions elastica/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,9 @@
from elastica.utils import isqrt
from elastica.timestepper import (
integrate,
PositionVerlet,
PEFRL,
RungeKutta4,
EulerForward,
extend_stepper_interface,
)
from elastica.timestepper.symplectic_steppers import PositionVerlet, PEFRL
from elastica.memory_block.memory_block_rigid_body import MemoryBlockRigidBody
from elastica.memory_block.memory_block_rod import MemoryBlockCosseratRod
from elastica.restart import save_state, load_state
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
from elastica.typing import (
SystemType,
SystemCollectionType,
OperatorType,
StepType,
SteppersOperatorsType,
StateType,
)
from elastica.systems.protocol import ExplicitSystemProtocol
from .protocol import ExplicitStepperProtocol, MemoryProtocol
from elastica.experimental.timestepper.protocol import (
ExplicitSystemProtocol,
ExplicitStepperProtocol,
MemoryProtocol,
)


"""
Expand Down Expand Up @@ -166,10 +169,10 @@ class EulerForward(ExplicitStepperMixin):
Classical Euler Forward stepper. Stateless, coordinates operations only.
"""

def get_stages(self) -> list[OperatorType]:
def get_stages(self) -> list[StepType]:
return [self._first_stage]

def get_updates(self) -> list[OperatorType]:
def get_updates(self) -> list[StepType]:
return [self._first_update]

def _first_stage(
Expand Down Expand Up @@ -198,15 +201,15 @@ class RungeKutta4(ExplicitStepperMixin):
to be externally managed and allocated.
"""

def get_stages(self) -> list[OperatorType]:
def get_stages(self) -> list[StepType]:
return [
self._first_stage,
self._second_stage,
self._third_stage,
self._fourth_stage,
]

def get_updates(self) -> list[OperatorType]:
def get_updates(self) -> list[StepType]:
return [
self._first_update,
self._second_update,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Iterator, TypeVar, Generic, Type
from elastica.timestepper.protocol import ExplicitStepperProtocol
from elastica.typing import SystemCollectionType
from elastica.experimental.timestepper.explicit_steppers import (
RungeKutta4,
EulerForward,
)
from elastica.experimental.timestepper.protocol import ExplicitStepperProtocol

from copy import copy

Expand All @@ -12,11 +16,6 @@ def make_memory_for_explicit_stepper(
) -> "MemoryCollection":
# TODO Automated logic (class creation, memory management logic) agnostic of stepper details (RK, AB etc.)

from elastica.timestepper.explicit_steppers import (
RungeKutta4,
EulerForward,
)

# is_this_system_a_collection = is_system_a_collection(system)

memory_cls: Type
Expand Down
86 changes: 86 additions & 0 deletions elastica/experimental/timestepper/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Protocol

from elastica.typing import StepType, StateType
from elastica.systems.protocol import SystemProtocol, SlenderBodyGeometryProtocol
from elastica.timestepper.protocol import StepperProtocol

import numpy as np


class ExplicitSystemProtocol(SystemProtocol, SlenderBodyGeometryProtocol, Protocol):
# TODO: Temporarily made to handle explicit stepper.
# Need to be refactored as the explicit stepper is further developed.
def __call__(self, time: np.float64, dt: np.float64) -> np.float64: ...
@property
def state(self) -> StateType: ...
@state.setter
def state(self, state: StateType) -> None: ...
@property
def n_elems(self) -> int: ...


class MemoryProtocol(Protocol):
@property
def initial_state(self) -> bool: ...


class ExplicitStepperProtocol(StepperProtocol, Protocol):
"""symplectic stepper protocol."""

def get_stages(self) -> list[StepType]: ...

def get_updates(self) -> list[StepType]: ...


# class _LinearExponentialIntegratorMixin:
# """
# Linear Exponential integrator mixin wrapper.
# """
#
# def __init__(self):
# pass
#
# def _do_stage(self, System, Memory, time, dt):
# # TODO : Make more general, system should not be calculating what the state
# # transition matrix directly is, but rather it should just give
# Memory.linear_operator = System.get_linear_state_transition_operator(time, dt)
#
# def _do_update(self, System, Memory, time, dt):
# # FIXME What's the right formula when doing update?
# # System.linearly_evolving_state = _batch_matmul(
# # System.linearly_evolving_state,
# # Memory.linear_operator
# # )
# System.linearly_evolving_state = np.einsum(
# "ijk,ljk->ilk", System.linearly_evolving_state, Memory.linear_operator
# )
# return time + dt
#
# def _first_prefactor(self, dt):
# """Prefactor call to satisfy interface of SymplecticStepper. Should never
# be used in actual code.
#
# Parameters
# ----------
# dt : the time step of simulation
#
# Raises
# ------
# RuntimeError
# """
# raise RuntimeError(
# "Symplectic prefactor of LinearExponentialIntegrator should not be called!"
# )
#
# # Code repeat!
# # Easy to avoid, but keep for performance.
# def _do_one_step(self, System, time, prefac):
# System.linearly_evolving_state = np.einsum(
# "ijk,ljk->ilk",
# System.linearly_evolving_state,
# System.get_linear_state_transition_operator(time, prefac),
# )
# return (
# time # TODO fix hack that treats time separately here. Shuold be time + dt
# )
# # return time + dt
20 changes: 14 additions & 6 deletions elastica/memory_block/protocol.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from typing import Protocol
from elastica.systems.protocol import SystemProtocol

import numpy as np

from elastica.rod.protocol import CosseratRodProtocol
from elastica.rigidbody.protocol import RigidBodyProtocol
from elastica.systems.protocol import SymplecticSystemProtocol
from elastica.systems.protocol import SystemProtocol


class BlockSystemProtocol(SystemProtocol, Protocol):
class BlockProtocol(Protocol):
@property
def n_systems(self) -> int:
"""Number of systems in the block."""


class BlockSystemProtocol(SystemProtocol, BlockProtocol, Protocol):
pass


class BlockRodProtocol(BlockProtocol, CosseratRodProtocol, Protocol):
pass


class BlockRigidBodyProtocol(BlockProtocol, RigidBodyProtocol, Protocol):
pass
64 changes: 55 additions & 9 deletions elastica/modules/base_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Basic coordinating for multiple, smaller systems that have an independently integrable
interface (i.e. works with symplectic or explicit routines `timestepper.py`.)
"""
from typing import Type, Generator, Iterable, Any, overload
from typing import TYPE_CHECKING, Type, Generator, Any, overload
from typing import final
from elastica.typing import (
SystemType,
Expand All @@ -27,6 +27,7 @@

from .memory_block import construct_memory_block_structures
from .operator_group import OperatorGroupFIFO
from .protocol import ModuleProtocol


class BaseSystemCollection(MutableSequence):
Expand Down Expand Up @@ -55,10 +56,18 @@ def __init__(self) -> None:
# Collection of functions. Each group is executed as a collection at the different steps.
# Each component (Forcing, Connection, etc.) registers the executable (callable) function
# in the group that that needs to be executed. These should be initialized before mixin.
self._feature_group_synchronize: Iterable[OperatorType] = OperatorGroupFIFO()
self._feature_group_constrain_values: list[OperatorType] = []
self._feature_group_constrain_rates: list[OperatorType] = []
self._feature_group_callback: list[OperatorCallbackType] = []
self._feature_group_synchronize: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_constrain_values: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_constrain_rates: OperatorGroupFIFO[
OperatorType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_callback: OperatorGroupFIFO[
OperatorCallbackType, ModuleProtocol
] = OperatorGroupFIFO()
self._feature_group_finalize: list[OperatorFinalizeType] = []
# We need to initialize our mixin classes
super().__init__()
Expand Down Expand Up @@ -104,11 +113,11 @@ def _check_type(self, sys_to_be_added: Any) -> bool:
def __len__(self) -> int:
return len(self.__systems)

@overload
def __getitem__(self, idx: int, /) -> SystemType: ...
@overload # type: ignore
def __getitem__(self, idx: slice, /) -> list[SystemType]: ... # type: ignore

@overload
def __getitem__(self, idx: slice, /) -> list[SystemType]: ...
@overload # type: ignore
def __getitem__(self, idx: int, /) -> SystemType: ... # type: ignore

def __getitem__(self, idx, /): # type: ignore
return self.__systems[idx]
Expand Down Expand Up @@ -266,3 +275,40 @@ def apply_callbacks(self, time: np.float64, current_step: int) -> None:
"""
for func in self._feature_group_callback:
func(time=time, current_step=current_step)


if TYPE_CHECKING:
from .protocol import SystemCollectionProtocol
from .constraints import Constraints
from .forcing import Forcing
from .connections import Connections
from .contact import Contact
from .damping import Damping
from .callbacks import CallBacks

class BaseFeature(BaseSystemCollection):
pass

class PartialFeatureA(
BaseSystemCollection, Constraints, Forcing, Damping, CallBacks
):
pass

class PartialFeatureB(BaseSystemCollection, Contact, Connections):
pass

class FullFeature(
BaseSystemCollection,
Constraints,
Contact,
Connections,
Forcing,
Damping,
CallBacks,
):
pass

_: SystemCollectionProtocol = FullFeature()
_: SystemCollectionProtocol = PartialFeatureA() # type: ignore[no-redef]
_: SystemCollectionProtocol = PartialFeatureB() # type: ignore[no-redef]
_: SystemCollectionProtocol = BaseFeature() # type: ignore[no-redef]
43 changes: 19 additions & 24 deletions elastica/modules/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
from elastica.typing import SystemType, SystemIdxType, OperatorFinalizeType
from .protocol import ModuleProtocol

import functools

import numpy as np

from elastica.callback_functions import CallBackBaseClass
from .protocol import SystemCollectionProtocol
from .protocol import SystemCollectionWithCallbackProtocol


class CallBacks:
Expand All @@ -27,15 +29,13 @@ class CallBacks:
List of call back classes defined for rod-like objects.
"""

def __init__(self: SystemCollectionProtocol) -> None:
def __init__(self: SystemCollectionWithCallbackProtocol) -> None:
self._callback_list: list[ModuleProtocol] = []
self._callback_operators: list[tuple[int, CallBackBaseClass]] = []
super(CallBacks, self).__init__()
self._feature_group_callback.append(self._callback_execution)
self._feature_group_finalize.append(self._finalize_callback)

def collect_diagnostics(
self: SystemCollectionProtocol, system: SystemType
self: SystemCollectionWithCallbackProtocol, system: SystemType
) -> ModuleProtocol:
"""
This method calls user-defined call-back classes for a
Expand All @@ -54,31 +54,26 @@ def collect_diagnostics(
sys_idx: SystemIdxType = self.get_system_index(system)

# Create _Constraint object, cache it and return to user
_callbacks: ModuleProtocol = _CallBack(sys_idx)
self._callback_list.append(_callbacks)
_callback: ModuleProtocol = _CallBack(sys_idx)
self._callback_list.append(_callback)
self._feature_group_callback.append_id(_callback)

return _callbacks
return _callback

def _finalize_callback(self: SystemCollectionProtocol) -> None:
def _finalize_callback(self: SystemCollectionWithCallbackProtocol) -> None:
# dev : the first index stores the rod index to collect data.
self._callback_operators = [
(callback.id(), callback.instantiate()) for callback in self._callback_list
]
for callback in self._callback_list:
sys_id = callback.id()
callback_instance = callback.instantiate()

callback_operator = functools.partial(
callback_instance.make_callback, system=self[sys_id]
)
self._feature_group_callback.add_operators(callback, [callback_operator])

self._callback_list.clear()
del self._callback_list

# First callback execution
time = np.float64(0.0)
self._callback_execution(time=time, current_step=0)

def _callback_execution(
self: SystemCollectionProtocol,
time: np.float64,
current_step: int,
) -> None:
for sys_id, callback in self._callback_operators:
callback.make_callback(self[sys_id], time, current_step)


class _CallBack:
"""
Expand Down
Loading
Loading