From 4c238ca61e9fb86d025f5c9ebb6e398efd0b7b20 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Wed, 18 Oct 2023 18:04:27 +0200 Subject: [PATCH 01/13] Return a new entity from entity operations --- link/domain/link.py | 27 ++++++++-- link/domain/state.py | 70 ++++++++++--------------- tests/unit/entities/test_state.py | 85 +++++++++---------------------- 3 files changed, 74 insertions(+), 108 deletions(-) diff --git a/link/domain/link.py b/link/domain/link.py index 1339d852..282a866d 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -7,6 +7,7 @@ from .custom_types import Identifier from .state import ( STATE_MAP, + TRANSITION_MAP, Components, Entity, EntityOperationResult, @@ -14,6 +15,7 @@ Operations, PersistentState, Processes, + Transition, Update, ) @@ -128,7 +130,9 @@ def create_link_operation_result(results: Iterable[EntityOperationResult]) -> Li def process(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResult: """Process all entities in the link producing appropriate updates.""" _validate_requested(link, requested) - return create_link_operation_result(entity.process() for entity in link if entity.identifier in requested) + return create_link_operation_result( + _create_update(entity, Operations.PROCESS) for entity in link if entity.identifier in requested + ) def _validate_requested(link: Link, requested: Iterable[Identifier]) -> None: @@ -136,13 +140,30 @@ def _validate_requested(link: Link, requested: Iterable[Identifier]) -> None: assert set(requested) <= link.identifiers, "Requested identifiers not present in link." +def _create_update(current: Entity, operation: Operations) -> EntityOperationResult: + operations_map = { + Operations.START_PULL: "start_pull", + Operations.START_DELETE: "start_delete", + Operations.PROCESS: "process", + } + new = getattr(current, operations_map[operation])() + if current.state is new.state: + return InvalidOperation(operation, current.identifier, current.state) + transition = Transition(current.state, new.state) + return Update(operation, current.identifier, transition, TRANSITION_MAP[transition]) + + def start_pull(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResult: """Start the pull process on the requested entities.""" _validate_requested(link, requested) - return create_link_operation_result(entity.start_pull() for entity in link if entity.identifier in requested) + return create_link_operation_result( + _create_update(entity, Operations.START_PULL) for entity in link if entity.identifier in requested + ) def start_delete(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResult: """Start the delete process on the requested entities.""" _validate_requested(link, requested) - return create_link_operation_result(entity.start_delete() for entity in link if entity.identifier in requested) + return create_link_operation_result( + _create_update(entity, Operations.START_DELETE) for entity in link if entity.identifier in requested + ) diff --git a/link/domain/state.py b/link/domain/state.py index 313a8c00..178072fa 100644 --- a/link/domain/state.py +++ b/link/domain/state.py @@ -1,7 +1,7 @@ """Contains everything state related.""" from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum, auto from typing import Optional, Union @@ -12,35 +12,19 @@ class State: """An entity's state.""" @classmethod - def start_pull(cls, entity: Entity) -> EntityOperationResult: + def start_pull(cls, entity: Entity) -> Entity: """Return the command needed to start the pull process for the entity.""" - return cls._create_invalid_operation_result(Operations.START_PULL, entity.identifier) + return entity @classmethod - def start_delete(cls, entity: Entity) -> EntityOperationResult: + def start_delete(cls, entity: Entity) -> Entity: """Return the commands needed to start the delete process for the entity.""" - return cls._create_invalid_operation_result(Operations.START_DELETE, entity.identifier) + return entity @classmethod - def process(cls, entity: Entity) -> EntityOperationResult: + def process(cls, entity: Entity) -> Entity: """Return the commands needed to process the entity.""" - return cls._create_invalid_operation_result(Operations.PROCESS, entity.identifier) - - @classmethod - def _create_invalid_operation_result(cls, operation: Operations, identifier: Identifier) -> EntityOperationResult: - return InvalidOperation(operation, identifier, cls) - - @classmethod - def _create_valid_operation_result( - cls, operation: Operations, identifier: Identifier, new_state: type[State] - ) -> EntityOperationResult: - transition = Transition(cls, new_state) - return Update( - operation, - identifier, - transition, - command=TRANSITION_MAP[transition], - ) + return entity class States: @@ -66,9 +50,9 @@ class Idle(State): """The default state of an entity.""" @classmethod - def start_pull(cls, entity: Entity) -> EntityOperationResult: + def start_pull(cls, entity: Entity) -> Entity: """Return the command needed to start the pull process for an entity.""" - return cls._create_valid_operation_result(Operations.START_PULL, entity.identifier, Activated) + return replace(entity, state=Activated, current_process=Processes.PULL) states.register(Idle) @@ -78,16 +62,15 @@ class Activated(State): """The state of an activated entity.""" @classmethod - def process(cls, entity: Entity) -> EntityOperationResult: + def process(cls, entity: Entity) -> Entity: """Return the commands needed to process an activated entity.""" - new_state: type[State] if entity.is_tainted: - new_state = Deprecated + return replace(entity, state=Deprecated, current_process=None) elif entity.current_process is Processes.PULL: - new_state = Received + return replace(entity, state=Received) elif entity.current_process is Processes.DELETE: - new_state = Idle - return cls._create_valid_operation_result(Operations.PROCESS, entity.identifier, new_state) + return replace(entity, state=Idle, current_process=None) + raise RuntimeError states.register(Activated) @@ -97,17 +80,16 @@ class Received(State): """The state of an received entity.""" @classmethod - def process(cls, entity: Entity) -> EntityOperationResult: + def process(cls, entity: Entity) -> Entity: """Return the commands needed to process a received entity.""" - new_state: type[State] if entity.current_process is Processes.PULL: if entity.is_tainted: - new_state = Tainted + return replace(entity, state=Tainted, current_process=None) else: - new_state = Pulled + return replace(entity, state=Pulled, current_process=None) elif entity.current_process is Processes.DELETE: - new_state = Activated - return cls._create_valid_operation_result(Operations.PROCESS, entity.identifier, new_state) + return replace(entity, state=Activated) + raise RuntimeError states.register(Received) @@ -117,9 +99,9 @@ class Pulled(State): """The state of an entity that has been copied to the local side.""" @classmethod - def start_delete(cls, entity: Entity) -> EntityOperationResult: + def start_delete(cls, entity: Entity) -> Entity: """Return the commands needed to start the delete process for the entity.""" - return cls._create_valid_operation_result(Operations.START_DELETE, entity.identifier, Received) + return replace(entity, state=Received, current_process=Processes.DELETE) states.register(Pulled) @@ -129,9 +111,9 @@ class Tainted(State): """The state of an entity that has been flagged as faulty by the source side.""" @classmethod - def start_delete(cls, entity: Entity) -> EntityOperationResult: + def start_delete(cls, entity: Entity) -> Entity: """Return the commands needed to start the delete process for the entity.""" - return cls._create_valid_operation_result(Operations.START_DELETE, entity.identifier, Received) + return replace(entity, state=Received, current_process=Processes.DELETE) states.register(Tainted) @@ -288,14 +270,14 @@ class Entity: current_process: Optional[Processes] is_tainted: bool - def start_pull(self) -> EntityOperationResult: + def start_pull(self) -> Entity: """Start the pull process for the entity.""" return self.state.start_pull(self) - def start_delete(self) -> EntityOperationResult: + def start_delete(self) -> Entity: """Start the delete process for the entity.""" return self.state.start_delete(self) - def process(self) -> EntityOperationResult: + def process(self) -> Entity: """Process the entity.""" return self.state.process(self) diff --git a/tests/unit/entities/test_state.py b/tests/unit/entities/test_state.py index dd59505b..1f5a7743 100644 --- a/tests/unit/entities/test_state.py +++ b/tests/unit/entities/test_state.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import replace from typing import Iterable import pytest @@ -7,14 +8,9 @@ from link.domain.custom_types import Identifier from link.domain.link import create_link from link.domain.state import ( - Commands, Components, - InvalidOperation, - Operations, Processes, State, - Transition, - Update, states, ) from tests.assignments import create_assignments, create_identifier, create_identifiers @@ -31,7 +27,7 @@ (create_identifier("6"), states.Deprecated, ["start_pull", "start_delete", "process"]), ], ) -def test_invalid_transitions_produce_no_updates(identifier: Identifier, state: type[State], methods: str) -> None: +def test_invalid_transitions_returns_unchanged_entity(identifier: Identifier, state: type[State], methods: str) -> None: link = create_link( create_assignments( { @@ -44,42 +40,29 @@ def test_invalid_transitions_produce_no_updates(identifier: Identifier, state: t processes={Processes.PULL: create_identifiers("2", "3")}, ) entity = next(entity for entity in link if entity.identifier == identifier) - method_operation_map = { - "start_pull": Operations.START_PULL, - "start_delete": Operations.START_DELETE, - "process": Operations.PROCESS, - } - assert all( - getattr(entity, method)() == InvalidOperation(method_operation_map[method], entity.identifier, state) - for method in methods - ) + assert all(getattr(entity, method)() == entity for method in methods) -def test_start_pulling_idle_entity_returns_correct_commands() -> None: +def test_start_pulling_idle_entity_returns_correct_entity() -> None: link = create_link(create_assignments({Components.SOURCE: {"1"}})) entity = next(iter(link)) - assert entity.start_pull() == Update( - Operations.START_PULL, - create_identifier("1"), - Transition(states.Idle, states.Activated), - command=Commands.START_PULL_PROCESS, - ) + assert entity.start_pull() == replace(entity, state=states.Activated, current_process=Processes.PULL) @pytest.mark.parametrize( - ("process", "tainted_identifiers", "new_state", "command"), + ("process", "tainted_identifiers", "new_state", "new_process"), [ - (Processes.PULL, set(), states.Received, Commands.ADD_TO_LOCAL), - (Processes.PULL, create_identifiers("1"), states.Deprecated, Commands.DEPRECATE), - (Processes.DELETE, set(), states.Idle, Commands.FINISH_DELETE_PROCESS), - (Processes.DELETE, create_identifiers("1"), states.Deprecated, Commands.DEPRECATE), + (Processes.PULL, set(), states.Received, Processes.PULL), + (Processes.PULL, create_identifiers("1"), states.Deprecated, None), + (Processes.DELETE, set(), states.Idle, None), + (Processes.DELETE, create_identifiers("1"), states.Deprecated, None), ], ) -def test_processing_activated_entity_returns_correct_commands( +def test_processing_activated_entity_returns_correct_entity( process: Processes, tainted_identifiers: Iterable[Identifier], new_state: type[State], - command: Commands, + new_process: Processes | None, ) -> None: link = create_link( create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}}), @@ -87,25 +70,20 @@ def test_processing_activated_entity_returns_correct_commands( tainted_identifiers=tainted_identifiers, ) entity = next(iter(link)) - assert entity.process() == Update( - Operations.PROCESS, - create_identifier("1"), - Transition(states.Activated, new_state), - command=command, - ) + assert entity.process() == replace(entity, state=new_state, current_process=new_process) @pytest.mark.parametrize( - ("process", "tainted_identifiers", "new_state", "command"), + ("process", "tainted_identifiers", "new_state", "new_process"), [ - (Processes.PULL, set(), states.Pulled, Commands.FINISH_PULL_PROCESS), - (Processes.PULL, create_identifiers("1"), states.Tainted, Commands.FINISH_PULL_PROCESS), - (Processes.DELETE, set(), states.Activated, Commands.REMOVE_FROM_LOCAL), - (Processes.DELETE, create_identifiers("1"), states.Activated, Commands.REMOVE_FROM_LOCAL), + (Processes.PULL, set(), states.Pulled, None), + (Processes.PULL, create_identifiers("1"), states.Tainted, None), + (Processes.DELETE, set(), states.Activated, Processes.DELETE), + (Processes.DELETE, create_identifiers("1"), states.Activated, Processes.DELETE), ], ) -def test_processing_received_entity_returns_correct_commands( - process: Processes, tainted_identifiers: Iterable[Identifier], new_state: type[State], command: Commands +def test_processing_received_entity_returns_correct_entity( + process: Processes, tainted_identifiers: Iterable[Identifier], new_state: type[State], new_process: Processes | None ) -> None: link = create_link( create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}), @@ -113,25 +91,15 @@ def test_processing_received_entity_returns_correct_commands( tainted_identifiers=tainted_identifiers, ) entity = next(iter(link)) - assert entity.process() == Update( - Operations.PROCESS, - create_identifier("1"), - Transition(states.Received, new_state), - command=command, - ) + assert entity.process() == replace(entity, state=new_state, current_process=new_process) -def test_starting_delete_on_pulled_entity_returns_correct_commands() -> None: +def test_starting_delete_on_pulled_entity_returns_correct_entity() -> None: link = create_link( create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}) ) entity = next(iter(link)) - assert entity.start_delete() == Update( - Operations.START_DELETE, - create_identifier("1"), - Transition(states.Pulled, states.Received), - command=Commands.START_DELETE_PROCESS, - ) + assert entity.start_delete() == replace(entity, state=states.Received, current_process=Processes.DELETE) def test_starting_delete_on_tainted_entity_returns_correct_commands() -> None: @@ -140,9 +108,4 @@ def test_starting_delete_on_tainted_entity_returns_correct_commands() -> None: tainted_identifiers={create_identifier("1")}, ) entity = next(iter(link)) - assert entity.start_delete() == Update( - Operations.START_DELETE, - create_identifier("1"), - Transition(states.Tainted, states.Received), - command=Commands.START_DELETE_PROCESS, - ) + assert entity.start_delete() == replace(entity, state=states.Received, current_process=Processes.DELETE) From 7ebdd6bbf262267effefc9816586aa4016755630 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Thu, 19 Oct 2023 11:16:51 +0200 Subject: [PATCH 02/13] Add apply method to entity --- link/domain/link.py | 7 +---- link/domain/state.py | 15 ++++++++-- tests/unit/entities/test_state.py | 47 ++++++++++++++++++------------- 3 files changed, 40 insertions(+), 29 deletions(-) diff --git a/link/domain/link.py b/link/domain/link.py index 282a866d..902d8094 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -141,12 +141,7 @@ def _validate_requested(link: Link, requested: Iterable[Identifier]) -> None: def _create_update(current: Entity, operation: Operations) -> EntityOperationResult: - operations_map = { - Operations.START_PULL: "start_pull", - Operations.START_DELETE: "start_delete", - Operations.PROCESS: "process", - } - new = getattr(current, operations_map[operation])() + new = current.apply(operation) if current.state is new.state: return InvalidOperation(operation, current.identifier, current.state) transition = Transition(current.state, new.state) diff --git a/link/domain/state.py b/link/domain/state.py index 178072fa..967e1328 100644 --- a/link/domain/state.py +++ b/link/domain/state.py @@ -270,14 +270,23 @@ class Entity: current_process: Optional[Processes] is_tainted: bool - def start_pull(self) -> Entity: + def apply(self, operation: Operations) -> Entity: + """Apply an operation to the entity.""" + if operation is Operations.START_PULL: + return self._start_pull() + if operation is Operations.START_DELETE: + return self._start_delete() + if operation is Operations.PROCESS: + return self._process() + + def _start_pull(self) -> Entity: """Start the pull process for the entity.""" return self.state.start_pull(self) - def start_delete(self) -> Entity: + def _start_delete(self) -> Entity: """Start the delete process for the entity.""" return self.state.start_delete(self) - def process(self) -> Entity: + def _process(self) -> Entity: """Process the entity.""" return self.state.process(self) diff --git a/tests/unit/entities/test_state.py b/tests/unit/entities/test_state.py index 1f5a7743..2259f59f 100644 --- a/tests/unit/entities/test_state.py +++ b/tests/unit/entities/test_state.py @@ -7,27 +7,28 @@ from link.domain.custom_types import Identifier from link.domain.link import create_link -from link.domain.state import ( - Components, - Processes, - State, - states, -) +from link.domain.state import Components, Operations, Processes, State, states from tests.assignments import create_assignments, create_identifier, create_identifiers @pytest.mark.parametrize( - ("identifier", "state", "methods"), + ("identifier", "state", "operations"), [ - (create_identifier("1"), states.Idle, ["start_delete", "process"]), - (create_identifier("2"), states.Activated, ["start_pull", "start_delete"]), - (create_identifier("3"), states.Received, ["start_pull", "start_delete"]), - (create_identifier("4"), states.Pulled, ["start_pull", "process"]), - (create_identifier("5"), states.Tainted, ["start_pull", "process"]), - (create_identifier("6"), states.Deprecated, ["start_pull", "start_delete", "process"]), + (create_identifier("1"), states.Idle, [Operations.START_DELETE, Operations.PROCESS]), + (create_identifier("2"), states.Activated, [Operations.START_PULL, Operations.START_DELETE]), + (create_identifier("3"), states.Received, [Operations.START_PULL, Operations.START_DELETE]), + (create_identifier("4"), states.Pulled, [Operations.START_PULL, Operations.PROCESS]), + (create_identifier("5"), states.Tainted, [Operations.START_PULL, Operations.PROCESS]), + ( + create_identifier("6"), + states.Deprecated, + [Operations.START_PULL, Operations.START_DELETE, Operations.PROCESS], + ), ], ) -def test_invalid_transitions_returns_unchanged_entity(identifier: Identifier, state: type[State], methods: str) -> None: +def test_invalid_transitions_returns_unchanged_entity( + identifier: Identifier, state: type[State], operations: list[Operations] +) -> None: link = create_link( create_assignments( { @@ -40,13 +41,15 @@ def test_invalid_transitions_returns_unchanged_entity(identifier: Identifier, st processes={Processes.PULL: create_identifiers("2", "3")}, ) entity = next(entity for entity in link if entity.identifier == identifier) - assert all(getattr(entity, method)() == entity for method in methods) + assert all(entity.apply(operation) == entity for operation in operations) def test_start_pulling_idle_entity_returns_correct_entity() -> None: link = create_link(create_assignments({Components.SOURCE: {"1"}})) entity = next(iter(link)) - assert entity.start_pull() == replace(entity, state=states.Activated, current_process=Processes.PULL) + assert entity.apply(Operations.START_PULL) == replace( + entity, state=states.Activated, current_process=Processes.PULL + ) @pytest.mark.parametrize( @@ -70,7 +73,7 @@ def test_processing_activated_entity_returns_correct_entity( tainted_identifiers=tainted_identifiers, ) entity = next(iter(link)) - assert entity.process() == replace(entity, state=new_state, current_process=new_process) + assert entity.apply(Operations.PROCESS) == replace(entity, state=new_state, current_process=new_process) @pytest.mark.parametrize( @@ -91,7 +94,7 @@ def test_processing_received_entity_returns_correct_entity( tainted_identifiers=tainted_identifiers, ) entity = next(iter(link)) - assert entity.process() == replace(entity, state=new_state, current_process=new_process) + assert entity.apply(Operations.PROCESS) == replace(entity, state=new_state, current_process=new_process) def test_starting_delete_on_pulled_entity_returns_correct_entity() -> None: @@ -99,7 +102,9 @@ def test_starting_delete_on_pulled_entity_returns_correct_entity() -> None: create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}) ) entity = next(iter(link)) - assert entity.start_delete() == replace(entity, state=states.Received, current_process=Processes.DELETE) + assert entity.apply(Operations.START_DELETE) == replace( + entity, state=states.Received, current_process=Processes.DELETE + ) def test_starting_delete_on_tainted_entity_returns_correct_commands() -> None: @@ -108,4 +113,6 @@ def test_starting_delete_on_tainted_entity_returns_correct_commands() -> None: tainted_identifiers={create_identifier("1")}, ) entity = next(iter(link)) - assert entity.start_delete() == replace(entity, state=states.Received, current_process=Processes.DELETE) + assert entity.apply(Operations.START_DELETE) == replace( + entity, state=states.Received, current_process=Processes.DELETE + ) From a5fb4481674df9f1256f30c6c8e59407bc780bd3 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:02:52 +0200 Subject: [PATCH 03/13] Store entity operation results on entities --- link/domain/link.py | 3 +- link/domain/state.py | 54 ++++++++++++++------ tests/unit/entities/test_link.py | 2 +- tests/unit/entities/test_state.py | 84 ++++++++++++++++++++++++------- 4 files changed, 106 insertions(+), 37 deletions(-) diff --git a/link/domain/link.py b/link/domain/link.py index 902d8094..833af8d5 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -70,8 +70,9 @@ def create_entity(identifier: Identifier) -> Entity: return Entity( identifier, state=state, - current_process=processes_map.get(identifier), + current_process=processes_map.get(identifier, Processes.NONE), is_tainted=is_tainted(identifier), + operation_results=tuple(), ) return {create_entity(identifier) for identifier in assignments[Components.SOURCE]} diff --git a/link/domain/state.py b/link/domain/state.py index 967e1328..7b453232 100644 --- a/link/domain/state.py +++ b/link/domain/state.py @@ -3,7 +3,8 @@ from dataclasses import dataclass, replace from enum import Enum, auto -from typing import Optional, Union +from functools import partial +from typing import Union from .custom_types import Identifier @@ -14,17 +15,34 @@ class State: @classmethod def start_pull(cls, entity: Entity) -> Entity: """Return the command needed to start the pull process for the entity.""" - return entity + return cls._create_invalid_operation(entity, Operations.START_PULL) @classmethod def start_delete(cls, entity: Entity) -> Entity: """Return the commands needed to start the delete process for the entity.""" - return entity + return cls._create_invalid_operation(entity, Operations.START_DELETE) @classmethod def process(cls, entity: Entity) -> Entity: """Return the commands needed to process the entity.""" - return entity + return cls._create_invalid_operation(entity, Operations.PROCESS) + + @staticmethod + def _create_invalid_operation(entity: Entity, operation: Operations) -> Entity: + updated = entity.operation_results + (InvalidOperation(operation, entity.identifier, entity.state),) + return replace(entity, operation_results=updated) + + @classmethod + def _transition_entity( + cls, entity: Entity, operation: Operations, new_state: type[State], *, new_process: Processes | None = None + ) -> Entity: + if new_process is None: + new_process = entity.current_process + transition = Transition(cls, new_state) + updated_results = entity.operation_results + ( + Update(operation, entity.identifier, transition, TRANSITION_MAP[transition]), + ) + return replace(entity, state=transition.new, current_process=new_process, operation_results=updated_results) class States: @@ -52,7 +70,7 @@ class Idle(State): @classmethod def start_pull(cls, entity: Entity) -> Entity: """Return the command needed to start the pull process for an entity.""" - return replace(entity, state=Activated, current_process=Processes.PULL) + return cls._transition_entity(entity, Operations.START_PULL, Activated, new_process=Processes.PULL) states.register(Idle) @@ -64,12 +82,13 @@ class Activated(State): @classmethod def process(cls, entity: Entity) -> Entity: """Return the commands needed to process an activated entity.""" + transition_entity = partial(cls._transition_entity, entity, Operations.PROCESS) if entity.is_tainted: - return replace(entity, state=Deprecated, current_process=None) + return transition_entity(Deprecated, new_process=Processes.NONE) elif entity.current_process is Processes.PULL: - return replace(entity, state=Received) + return transition_entity(Received) elif entity.current_process is Processes.DELETE: - return replace(entity, state=Idle, current_process=None) + return transition_entity(Idle, new_process=Processes.NONE) raise RuntimeError @@ -82,13 +101,14 @@ class Received(State): @classmethod def process(cls, entity: Entity) -> Entity: """Return the commands needed to process a received entity.""" + transition_entity = partial(cls._transition_entity, entity, Operations.PROCESS) if entity.current_process is Processes.PULL: if entity.is_tainted: - return replace(entity, state=Tainted, current_process=None) + return transition_entity(Tainted, new_process=Processes.NONE) else: - return replace(entity, state=Pulled, current_process=None) + return transition_entity(Pulled, new_process=Processes.NONE) elif entity.current_process is Processes.DELETE: - return replace(entity, state=Activated) + return transition_entity(Activated) raise RuntimeError @@ -101,7 +121,7 @@ class Pulled(State): @classmethod def start_delete(cls, entity: Entity) -> Entity: """Return the commands needed to start the delete process for the entity.""" - return replace(entity, state=Received, current_process=Processes.DELETE) + return cls._transition_entity(entity, Operations.START_DELETE, Received, new_process=Processes.DELETE) states.register(Pulled) @@ -113,7 +133,7 @@ class Tainted(State): @classmethod def start_delete(cls, entity: Entity) -> Entity: """Return the commands needed to start the delete process for the entity.""" - return replace(entity, state=Received, current_process=Processes.DELETE) + return cls._transition_entity(entity, Operations.START_DELETE, Received, new_process=Processes.DELETE) states.register(Tainted) @@ -196,8 +216,9 @@ class InvalidOperation: class Processes(Enum): """Names for processes that pull/delete entities into/from the local side.""" - PULL = 1 - DELETE = 2 + NONE = auto() + PULL = auto() + DELETE = auto() class Components(Enum): @@ -267,8 +288,9 @@ class Entity: identifier: Identifier state: type[State] - current_process: Optional[Processes] + current_process: Processes is_tainted: bool + operation_results: tuple[EntityOperationResult, ...] def apply(self, operation: Operations) -> Entity: """Apply an operation to the entity.""" diff --git a/tests/unit/entities/test_link.py b/tests/unit/entities/test_link.py index 086d185a..3f78e484 100644 --- a/tests/unit/entities/test_link.py +++ b/tests/unit/entities/test_link.py @@ -79,7 +79,7 @@ def test_entities_get_correct_process_assigned() -> None: (create_identifier("2"), Processes.DELETE), (create_identifier("3"), Processes.PULL), (create_identifier("4"), Processes.DELETE), - (create_identifier("5"), None), + (create_identifier("5"), Processes.NONE), } assert {(entity.identifier, entity.current_process) for entity in link} == set(expected) diff --git a/tests/unit/entities/test_state.py b/tests/unit/entities/test_state.py index 2259f59f..c74e7c79 100644 --- a/tests/unit/entities/test_state.py +++ b/tests/unit/entities/test_state.py @@ -7,7 +7,17 @@ from link.domain.custom_types import Identifier from link.domain.link import create_link -from link.domain.state import Components, Operations, Processes, State, states +from link.domain.state import ( + Commands, + Components, + InvalidOperation, + Operations, + Processes, + State, + Transition, + Update, + states, +) from tests.assignments import create_assignments, create_identifier, create_identifiers @@ -41,31 +51,44 @@ def test_invalid_transitions_returns_unchanged_entity( processes={Processes.PULL: create_identifiers("2", "3")}, ) entity = next(entity for entity in link if entity.identifier == identifier) - assert all(entity.apply(operation) == entity for operation in operations) + for operation in operations: + result = InvalidOperation(operation, identifier, state) + assert entity.apply(operation) == replace(entity, operation_results=(result,)) def test_start_pulling_idle_entity_returns_correct_entity() -> None: link = create_link(create_assignments({Components.SOURCE: {"1"}})) entity = next(iter(link)) assert entity.apply(Operations.START_PULL) == replace( - entity, state=states.Activated, current_process=Processes.PULL + entity, + state=states.Activated, + current_process=Processes.PULL, + operation_results=( + Update( + Operations.START_PULL, + entity.identifier, + Transition(states.Idle, states.Activated), + Commands.START_PULL_PROCESS, + ), + ), ) @pytest.mark.parametrize( - ("process", "tainted_identifiers", "new_state", "new_process"), + ("process", "tainted_identifiers", "new_state", "new_process", "command"), [ - (Processes.PULL, set(), states.Received, Processes.PULL), - (Processes.PULL, create_identifiers("1"), states.Deprecated, None), - (Processes.DELETE, set(), states.Idle, None), - (Processes.DELETE, create_identifiers("1"), states.Deprecated, None), + (Processes.PULL, set(), states.Received, Processes.PULL, Commands.ADD_TO_LOCAL), + (Processes.PULL, create_identifiers("1"), states.Deprecated, Processes.NONE, Commands.DEPRECATE), + (Processes.DELETE, set(), states.Idle, Processes.NONE, Commands.FINISH_DELETE_PROCESS), + (Processes.DELETE, create_identifiers("1"), states.Deprecated, Processes.NONE, Commands.DEPRECATE), ], ) def test_processing_activated_entity_returns_correct_entity( process: Processes, tainted_identifiers: Iterable[Identifier], new_state: type[State], - new_process: Processes | None, + new_process: Processes, + command: Commands, ) -> None: link = create_link( create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}}), @@ -73,20 +96,29 @@ def test_processing_activated_entity_returns_correct_entity( tainted_identifiers=tainted_identifiers, ) entity = next(iter(link)) - assert entity.apply(Operations.PROCESS) == replace(entity, state=new_state, current_process=new_process) + updated_results = entity.operation_results + ( + Update(Operations.PROCESS, entity.identifier, Transition(entity.state, new_state), command), + ) + assert entity.apply(Operations.PROCESS) == replace( + entity, state=new_state, current_process=new_process, operation_results=updated_results + ) @pytest.mark.parametrize( - ("process", "tainted_identifiers", "new_state", "new_process"), + ("process", "tainted_identifiers", "new_state", "new_process", "command"), [ - (Processes.PULL, set(), states.Pulled, None), - (Processes.PULL, create_identifiers("1"), states.Tainted, None), - (Processes.DELETE, set(), states.Activated, Processes.DELETE), - (Processes.DELETE, create_identifiers("1"), states.Activated, Processes.DELETE), + (Processes.PULL, set(), states.Pulled, Processes.NONE, Commands.FINISH_PULL_PROCESS), + (Processes.PULL, create_identifiers("1"), states.Tainted, Processes.NONE, Commands.FINISH_PULL_PROCESS), + (Processes.DELETE, set(), states.Activated, Processes.DELETE, Commands.REMOVE_FROM_LOCAL), + (Processes.DELETE, create_identifiers("1"), states.Activated, Processes.DELETE, Commands.REMOVE_FROM_LOCAL), ], ) def test_processing_received_entity_returns_correct_entity( - process: Processes, tainted_identifiers: Iterable[Identifier], new_state: type[State], new_process: Processes | None + process: Processes, + tainted_identifiers: Iterable[Identifier], + new_state: type[State], + new_process: Processes, + command: Commands, ) -> None: link = create_link( create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}), @@ -94,7 +126,10 @@ def test_processing_received_entity_returns_correct_entity( tainted_identifiers=tainted_identifiers, ) entity = next(iter(link)) - assert entity.apply(Operations.PROCESS) == replace(entity, state=new_state, current_process=new_process) + operation_results = (Update(Operations.PROCESS, entity.identifier, Transition(entity.state, new_state), command),) + assert entity.apply(Operations.PROCESS) == replace( + entity, state=new_state, current_process=new_process, operation_results=operation_results + ) def test_starting_delete_on_pulled_entity_returns_correct_entity() -> None: @@ -102,8 +137,17 @@ def test_starting_delete_on_pulled_entity_returns_correct_entity() -> None: create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}) ) entity = next(iter(link)) + transition = Transition(states.Pulled, states.Received) + operation_results = ( + Update( + Operations.START_DELETE, + entity.identifier, + transition, + Commands.START_DELETE_PROCESS, + ), + ) assert entity.apply(Operations.START_DELETE) == replace( - entity, state=states.Received, current_process=Processes.DELETE + entity, state=transition.new, current_process=Processes.DELETE, operation_results=operation_results ) @@ -113,6 +157,8 @@ def test_starting_delete_on_tainted_entity_returns_correct_commands() -> None: tainted_identifiers={create_identifier("1")}, ) entity = next(iter(link)) + transition = Transition(states.Tainted, states.Received) + operation_results = (Update(Operations.START_DELETE, entity.identifier, transition, Commands.START_DELETE_PROCESS),) assert entity.apply(Operations.START_DELETE) == replace( - entity, state=states.Received, current_process=Processes.DELETE + entity, state=transition.new, current_process=Processes.DELETE, operation_results=operation_results ) From cd12f75ebb1bd1ea26b161a81ba91720a99b83d0 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Thu, 19 Oct 2023 15:58:44 +0200 Subject: [PATCH 04/13] Remove old update creating function --- link/domain/link.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/link/domain/link.py b/link/domain/link.py index 833af8d5..82055a01 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -7,7 +7,6 @@ from .custom_types import Identifier from .state import ( STATE_MAP, - TRANSITION_MAP, Components, Entity, EntityOperationResult, @@ -15,7 +14,6 @@ Operations, PersistentState, Processes, - Transition, Update, ) @@ -132,7 +130,7 @@ def process(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResu """Process all entities in the link producing appropriate updates.""" _validate_requested(link, requested) return create_link_operation_result( - _create_update(entity, Operations.PROCESS) for entity in link if entity.identifier in requested + entity.apply(Operations.PROCESS).operation_results[0] for entity in link if entity.identifier in requested ) @@ -141,19 +139,11 @@ def _validate_requested(link: Link, requested: Iterable[Identifier]) -> None: assert set(requested) <= link.identifiers, "Requested identifiers not present in link." -def _create_update(current: Entity, operation: Operations) -> EntityOperationResult: - new = current.apply(operation) - if current.state is new.state: - return InvalidOperation(operation, current.identifier, current.state) - transition = Transition(current.state, new.state) - return Update(operation, current.identifier, transition, TRANSITION_MAP[transition]) - - def start_pull(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResult: """Start the pull process on the requested entities.""" _validate_requested(link, requested) return create_link_operation_result( - _create_update(entity, Operations.START_PULL) for entity in link if entity.identifier in requested + entity.apply(Operations.START_PULL).operation_results[0] for entity in link if entity.identifier in requested ) @@ -161,5 +151,5 @@ def start_delete(link: Link, *, requested: Iterable[Identifier]) -> LinkOperatio """Start the delete process on the requested entities.""" _validate_requested(link, requested) return create_link_operation_result( - _create_update(entity, Operations.START_DELETE) for entity in link if entity.identifier in requested + entity.apply(Operations.START_DELETE).operation_results[0] for entity in link if entity.identifier in requested ) From c6c1a8d658d84bba055e843e9d03603c9b145406 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Thu, 19 Oct 2023 16:52:59 +0200 Subject: [PATCH 05/13] Start operations from link object --- link/domain/link.py | 67 +++++++++++-------- link/service/services.py | 8 +-- .../integration/test_datajoint_persistence.py | 66 +++++++++++++----- tests/unit/entities/test_link.py | 18 ++--- 4 files changed, 101 insertions(+), 58 deletions(-) diff --git a/link/domain/link.py b/link/domain/link.py index 82055a01..29e16609 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -99,6 +99,44 @@ def identifiers(self) -> frozenset[Identifier]: """Return the identifiers of all entities in the link.""" return frozenset(entity.identifier for entity in self) + def apply(self, operation: Operations, *, requested: Iterable[Identifier]) -> LinkOperationResult: + """Apply an operation to the requested entities.""" + if operation is Operations.START_PULL: + return self._start_pull(requested) + if operation is Operations.START_DELETE: + return self._start_delete(requested) + if operation is Operations.PROCESS: + return self._process(requested) + + def _process(self, requested: Iterable[Identifier]) -> LinkOperationResult: + """Process all entities in the link producing appropriate updates.""" + self._validate_requested(requested) + return create_link_operation_result( + entity.apply(Operations.PROCESS).operation_results[0] for entity in self if entity.identifier in requested + ) + + def _start_delete(self, requested: Iterable[Identifier]) -> LinkOperationResult: + """Start the delete process on the requested entities.""" + self._validate_requested(requested) + return create_link_operation_result( + entity.apply(Operations.START_DELETE).operation_results[0] + for entity in self + if entity.identifier in requested + ) + + def _start_pull(self, requested: Iterable[Identifier]) -> LinkOperationResult: + """Start the pull process on the requested entities.""" + self._validate_requested(requested) + return create_link_operation_result( + entity.apply(Operations.START_PULL).operation_results[0] + for entity in self + if entity.identifier in requested + ) + + def _validate_requested(self, requested: Iterable[Identifier]) -> None: + assert requested, "No identifiers requested." + assert set(requested) <= self.identifiers, "Requested identifiers not present in link." + @dataclass(frozen=True) class LinkOperationResult: @@ -124,32 +162,3 @@ def create_link_operation_result(results: Iterable[EntityOperationResult]) -> Li updates=frozenset(result for result in results if isinstance(result, Update)), errors=frozenset(result for result in results if isinstance(result, InvalidOperation)), ) - - -def process(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResult: - """Process all entities in the link producing appropriate updates.""" - _validate_requested(link, requested) - return create_link_operation_result( - entity.apply(Operations.PROCESS).operation_results[0] for entity in link if entity.identifier in requested - ) - - -def _validate_requested(link: Link, requested: Iterable[Identifier]) -> None: - assert requested, "No identifiers requested." - assert set(requested) <= link.identifiers, "Requested identifiers not present in link." - - -def start_pull(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResult: - """Start the pull process on the requested entities.""" - _validate_requested(link, requested) - return create_link_operation_result( - entity.apply(Operations.START_PULL).operation_results[0] for entity in link if entity.identifier in requested - ) - - -def start_delete(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResult: - """Start the delete process on the requested entities.""" - _validate_requested(link, requested) - return create_link_operation_result( - entity.apply(Operations.START_DELETE).operation_results[0] for entity in link if entity.identifier in requested - ) diff --git a/link/service/services.py b/link/service/services.py index 4dddd1e2..7fe3f6a6 100644 --- a/link/service/services.py +++ b/link/service/services.py @@ -6,8 +6,6 @@ from enum import Enum, auto from link.domain.custom_types import Identifier -from link.domain.link import process as process_domain_service -from link.domain.link import start_delete, start_pull from link.domain.state import InvalidOperation, Operations, Update, states from .gateway import LinkGateway @@ -129,7 +127,7 @@ def start_pull_process( output_port: Callable[[OperationResponse], None], ) -> None: """Start the pull process for the requested entities.""" - result = start_pull(link_gateway.create_link(), requested=request.requested) + result = link_gateway.create_link().apply(Operations.START_PULL, requested=request.requested) link_gateway.apply(result.updates) output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) @@ -148,7 +146,7 @@ def start_delete_process( output_port: Callable[[OperationResponse], None], ) -> None: """Start the delete process for the requested entities.""" - result = start_delete(link_gateway.create_link(), requested=request.requested) + result = link_gateway.create_link().apply(Operations.START_DELETE, requested=request.requested) link_gateway.apply(result.updates) output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) @@ -164,7 +162,7 @@ def process( request: ProcessRequest, *, link_gateway: LinkGateway, output_port: Callable[[OperationResponse], None] ) -> None: """Process entities.""" - result = process_domain_service(link_gateway.create_link(), requested=request.requested) + result = link_gateway.create_link().apply(Operations.PROCESS, requested=request.requested) link_gateway.apply(result.updates) output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) diff --git a/tests/integration/test_datajoint_persistence.py b/tests/integration/test_datajoint_persistence.py index 34725e48..c9e91319 100644 --- a/tests/integration/test_datajoint_persistence.py +++ b/tests/integration/test_datajoint_persistence.py @@ -17,8 +17,8 @@ from link.adapters import PrimaryKey from link.adapters.gateway import DJLinkGateway from link.adapters.identification import IdentificationTranslator -from link.domain.link import create_link, process, start_delete, start_pull -from link.domain.state import Components, Processes +from link.domain.link import create_link +from link.domain.state import Components, Operations, Processes from link.infrastructure.facade import DJLinkFacade, Table @@ -370,7 +370,9 @@ def test_add_to_local_command() -> None: ), ) - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link().apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}).updates + ) assert has_state( tables, @@ -408,7 +410,11 @@ def test_add_to_local_command_with_error() -> None: tables["local"].children(as_objects=True)[0].error_on_insert = RuntimeError try: - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .updates + ) except RuntimeError: pass @@ -425,7 +431,9 @@ def test_add_to_local_command_with_external_file(tmpdir: Path) -> None: tables["source"].insert([{"a": 0, "external": insert_filepath}]) os.remove(insert_filepath) tables["outbound"].insert([{"a": 0, "process": "PULL", "is_flagged": "FALSE", "is_deprecated": "FALSE"}]) - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link().apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}).updates + ) fetch_filepath = Path(tables["local"].fetch(as_dict=True, download_path=str(tmpdir))[0]["external"]) with fetch_filepath.open(mode="rb") as file: assert file.read() == data @@ -444,7 +452,11 @@ def test_remove_from_local_command() -> None: ) with as_stdin(StringIO("y")): - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .updates + ) assert has_state( tables, @@ -460,7 +472,11 @@ def test_start_pull_process() -> None: "link", primary={"a"}, non_primary={"b"}, initial=State(source=TableState([{"a": 0, "b": 1}])) ) - gateway.apply(start_pull(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.START_PULL, requested={gateway.translator.to_identifier({"a": 0})}) + .updates + ) assert has_state( tables, @@ -485,7 +501,11 @@ def initial_state() -> State: def test_state_after_command(initial_state: State) -> None: tables, gateway = initialize("link", primary={"a"}, non_primary={"b"}, initial=initial_state) - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .updates + ) assert has_state( tables, @@ -503,7 +523,9 @@ def test_rollback_on_error(initial_state: State) -> None: tables["outbound"].error_on_insert = RuntimeError try: gateway.apply( - process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .updates ) except RuntimeError: pass @@ -526,7 +548,9 @@ def test_state_after_command(initial_state: State) -> None: tables, gateway = initialize("link", primary={"a"}, non_primary={"b"}, initial=initial_state) gateway.apply( - start_delete(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.START_DELETE, requested={gateway.translator.to_identifier({"a": 0})}) + .updates ) assert has_state( @@ -545,7 +569,9 @@ def test_rollback_on_error(initial_state: State) -> None: tables["outbound"].error_on_insert = RuntimeError try: gateway.apply( - start_delete(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.START_DELETE, requested={gateway.translator.to_identifier({"a": 0})}) + .updates ) except RuntimeError: pass @@ -564,7 +590,9 @@ def test_finish_delete_process_command() -> None: ), ) - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link().apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}).updates + ) assert has_state(tables, State(source=TableState([{"a": 0, "b": 1}]))) @@ -582,7 +610,11 @@ def initial_state() -> State: def test_state_after_command(initial_state: State) -> None: tables, gateway = initialize("link", primary={"a"}, non_primary={"b"}, initial=initial_state) - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .updates + ) assert has_state( tables, @@ -599,7 +631,9 @@ def test_rollback_on_error(initial_state: State) -> None: tables["outbound"].error_on_insert = RuntimeError try: gateway.apply( - process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .updates ) except RuntimeError: pass @@ -626,7 +660,9 @@ def test_applying_multiple_commands() -> None: with as_stdin(StringIO("y")): gateway.apply( - process(gateway.create_link(), requested=gateway.translator.to_identifiers([{"a": 0}, {"a": 1}])).updates + gateway.create_link() + .apply(Operations.PROCESS, requested=gateway.translator.to_identifiers([{"a": 0}, {"a": 1}])) + .updates ) assert has_state( diff --git a/tests/unit/entities/test_link.py b/tests/unit/entities/test_link.py index 3f78e484..176cbd8a 100644 --- a/tests/unit/entities/test_link.py +++ b/tests/unit/entities/test_link.py @@ -6,8 +6,8 @@ import pytest from link.domain.custom_types import Identifier -from link.domain.link import Link, create_link, process, start_delete, start_pull -from link.domain.state import Components, Processes, State, states +from link.domain.link import Link, create_link +from link.domain.state import Components, Operations, Processes, State, states from tests.assignments import create_assignments, create_identifier, create_identifiers @@ -183,7 +183,7 @@ def test_process_produces_correct_updates() -> None: ) actual = { (update.identifier, update.transition.new) - for update in process(link, requested=create_identifiers("1", "2", "3", "4")).updates + for update in link.apply(Operations.PROCESS, requested=create_identifiers("1", "2", "3", "4")).updates } expected = { (create_identifier("1"), states.Received), @@ -202,7 +202,7 @@ def link() -> Link: @staticmethod def test_idle_entity_becomes_activated(link: Link) -> None: - result = start_pull(link, requested=create_identifiers("1")) + result = link.apply(Operations.START_PULL, requested=create_identifiers("1")) update = next(iter(result.updates)) assert update.identifier == create_identifier("1") assert update.transition.new is states.Activated @@ -210,12 +210,12 @@ def test_idle_entity_becomes_activated(link: Link) -> None: @staticmethod def test_not_specifying_requested_identifiers_raises_error(link: Link) -> None: with pytest.raises(AssertionError, match="No identifiers requested."): - start_pull(link, requested={}) + link.apply(Operations.START_PULL, requested={}) @staticmethod def test_specifying_identifiers_not_present_in_link_raises_error(link: Link) -> None: with pytest.raises(AssertionError, match="Requested identifiers not present in link."): - start_pull(link, requested=create_identifiers("2")) + link.apply(Operations.START_PULL, requested=create_identifiers("2")) @pytest.fixture() @@ -228,7 +228,7 @@ def link() -> Link: class TestStartDelete: @staticmethod def test_pulled_entity_becomes_received(link: Link) -> None: - result = start_delete(link, requested=create_identifiers("1")) + result = link.apply(Operations.START_DELETE, requested=create_identifiers("1")) update = next(iter(result.updates)) assert {update.identifier} == create_identifiers("1") assert update.transition.new is states.Received @@ -236,9 +236,9 @@ def test_pulled_entity_becomes_received(link: Link) -> None: @staticmethod def test_not_specifying_requested_identifiers_raises_error(link: Link) -> None: with pytest.raises(AssertionError, match="No identifiers requested."): - start_delete(link, requested={}) + link.apply(Operations.START_DELETE, requested={}) @staticmethod def test_specifying_identifiers_not_present_in_link_raises_error(link: Link) -> None: with pytest.raises(AssertionError, match="Requested identifiers not present in link."): - start_delete(link, requested=create_identifiers("2")) + link.apply(Operations.START_DELETE, requested=create_identifiers("2")) From 234b8fa8db9312f04b1aa2d9c5f1029bbf67c1b5 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Thu, 19 Oct 2023 17:32:52 +0200 Subject: [PATCH 06/13] Convert link from frozenset to set --- link/domain/link.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/link/domain/link.py b/link/domain/link.py index 29e16609..19431a0c 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, FrozenSet, Iterable, Mapping, Optional, TypeVar +from typing import Any, Iterable, Iterator, Mapping, Optional, Set, TypeVar from .custom_types import Identifier from .state import ( @@ -91,9 +91,13 @@ def assign_to_component(component: Components) -> set[Entity]: return Link(entity_assignments[Components.SOURCE]) -class Link(FrozenSet[Entity]): +class Link(Set[Entity]): """The state of a link between two databases.""" + def __init__(self, entities: Iterable[Entity]) -> None: + """Initialize the link.""" + self._entities = set(entities) + @property def identifiers(self) -> frozenset[Identifier]: """Return the identifiers of all entities in the link.""" @@ -137,6 +141,18 @@ def _validate_requested(self, requested: Iterable[Identifier]) -> None: assert requested, "No identifiers requested." assert set(requested) <= self.identifiers, "Requested identifiers not present in link." + def __contains__(self, entity: object) -> bool: + """Check if the link contains the given entity.""" + return entity in self._entities + + def __iter__(self) -> Iterator[Entity]: + """Iterate over all entities in the link.""" + return iter(self._entities) + + def __len__(self) -> int: + """Return the number of entities in the link.""" + return len(self._entities) + @dataclass(frozen=True) class LinkOperationResult: From 0ca302f8ad39ec9a71b65c20779f11c4874e6d2a Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Thu, 19 Oct 2023 17:49:22 +0200 Subject: [PATCH 07/13] Add operation results attribute to link --- link/domain/link.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/link/domain/link.py b/link/domain/link.py index 19431a0c..24ba5918 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Iterable, Iterator, Mapping, Optional, Set, TypeVar +from typing import Any, Iterable, Iterator, Mapping, Optional, Set, Tuple, TypeVar from .custom_types import Identifier from .state import ( @@ -94,15 +94,23 @@ def assign_to_component(component: Components) -> set[Entity]: class Link(Set[Entity]): """The state of a link between two databases.""" - def __init__(self, entities: Iterable[Entity]) -> None: + def __init__( + self, entities: Iterable[Entity], operation_results: Tuple[LinkOperationResult, ...] = tuple() + ) -> None: """Initialize the link.""" self._entities = set(entities) + self._operation_results = operation_results @property def identifiers(self) -> frozenset[Identifier]: """Return the identifiers of all entities in the link.""" return frozenset(entity.identifier for entity in self) + @property + def operation_results(self) -> Tuple[LinkOperationResult, ...]: + """Return the results of operations performed on this link.""" + return self._operation_results + def apply(self, operation: Operations, *, requested: Iterable[Identifier]) -> LinkOperationResult: """Apply an operation to the requested entities.""" if operation is Operations.START_PULL: From 6f0f18ecb441590d65977441017d358069753b63 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Thu, 19 Oct 2023 17:52:24 +0200 Subject: [PATCH 08/13] Simplify operation application logic in link --- link/domain/link.py | 54 ++++++++++----------------------------------- 1 file changed, 12 insertions(+), 42 deletions(-) diff --git a/link/domain/link.py b/link/domain/link.py index 24ba5918..f29b648e 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -113,41 +113,22 @@ def operation_results(self) -> Tuple[LinkOperationResult, ...]: def apply(self, operation: Operations, *, requested: Iterable[Identifier]) -> LinkOperationResult: """Apply an operation to the requested entities.""" - if operation is Operations.START_PULL: - return self._start_pull(requested) - if operation is Operations.START_DELETE: - return self._start_delete(requested) - if operation is Operations.PROCESS: - return self._process(requested) - - def _process(self, requested: Iterable[Identifier]) -> LinkOperationResult: - """Process all entities in the link producing appropriate updates.""" - self._validate_requested(requested) - return create_link_operation_result( - entity.apply(Operations.PROCESS).operation_results[0] for entity in self if entity.identifier in requested - ) - - def _start_delete(self, requested: Iterable[Identifier]) -> LinkOperationResult: - """Start the delete process on the requested entities.""" - self._validate_requested(requested) - return create_link_operation_result( - entity.apply(Operations.START_DELETE).operation_results[0] - for entity in self - if entity.identifier in requested - ) - def _start_pull(self, requested: Iterable[Identifier]) -> LinkOperationResult: - """Start the pull process on the requested entities.""" - self._validate_requested(requested) - return create_link_operation_result( - entity.apply(Operations.START_PULL).operation_results[0] - for entity in self - if entity.identifier in requested - ) + def create_operation_result(results: Iterable[EntityOperationResult]) -> LinkOperationResult: + """Create the result of an operation on a link from results of individual entities.""" + results = set(results) + operation = next(iter(results)).operation + return LinkOperationResult( + operation, + updates=frozenset(result for result in results if isinstance(result, Update)), + errors=frozenset(result for result in results if isinstance(result, InvalidOperation)), + ) - def _validate_requested(self, requested: Iterable[Identifier]) -> None: assert requested, "No identifiers requested." assert set(requested) <= self.identifiers, "Requested identifiers not present in link." + return create_operation_result( + entity.apply(operation).operation_results[0] for entity in self if entity.identifier in requested + ) def __contains__(self, entity: object) -> bool: """Check if the link contains the given entity.""" @@ -175,14 +156,3 @@ def __post_init__(self) -> None: assert all( result.operation is self.operation for result in (self.updates | self.errors) ), "Not all results have same operation." - - -def create_link_operation_result(results: Iterable[EntityOperationResult]) -> LinkOperationResult: - """Create the result of an operation on a link from results of individual entities.""" - results = set(results) - operation = next(iter(results)).operation - return LinkOperationResult( - operation, - updates=frozenset(result for result in results if isinstance(result, Update)), - errors=frozenset(result for result in results if isinstance(result, InvalidOperation)), - ) From fed91c3c79d9b59e45b446a5ac5932d66a5a0121 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Fri, 20 Oct 2023 13:53:29 +0200 Subject: [PATCH 09/13] Return a updated link when applying an operation --- link/domain/link.py | 9 ++++--- link/service/services.py | 6 ++--- .../integration/test_datajoint_persistence.py | 25 ++++++++++++++++--- tests/unit/entities/test_link.py | 23 +++++++++-------- 4 files changed, 43 insertions(+), 20 deletions(-) diff --git a/link/domain/link.py b/link/domain/link.py index f29b648e..d6205f24 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -111,7 +111,7 @@ def operation_results(self) -> Tuple[LinkOperationResult, ...]: """Return the results of operations performed on this link.""" return self._operation_results - def apply(self, operation: Operations, *, requested: Iterable[Identifier]) -> LinkOperationResult: + def apply(self, operation: Operations, *, requested: Iterable[Identifier]) -> Link: """Apply an operation to the requested entities.""" def create_operation_result(results: Iterable[EntityOperationResult]) -> LinkOperationResult: @@ -126,9 +126,12 @@ def create_operation_result(results: Iterable[EntityOperationResult]) -> LinkOpe assert requested, "No identifiers requested." assert set(requested) <= self.identifiers, "Requested identifiers not present in link." - return create_operation_result( - entity.apply(operation).operation_results[0] for entity in self if entity.identifier in requested + changed = {entity.apply(operation) for entity in self if entity.identifier in requested} + unchanged = {entity for entity in self if entity.identifier not in requested} + operation_results = self.operation_results + ( + create_operation_result(entity.operation_results[0] for entity in changed), ) + return Link(changed | unchanged, operation_results) def __contains__(self, entity: object) -> bool: """Check if the link contains the given entity.""" diff --git a/link/service/services.py b/link/service/services.py index 7fe3f6a6..1defeda3 100644 --- a/link/service/services.py +++ b/link/service/services.py @@ -127,7 +127,7 @@ def start_pull_process( output_port: Callable[[OperationResponse], None], ) -> None: """Start the pull process for the requested entities.""" - result = link_gateway.create_link().apply(Operations.START_PULL, requested=request.requested) + result = link_gateway.create_link().apply(Operations.START_PULL, requested=request.requested).operation_results[0] link_gateway.apply(result.updates) output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) @@ -146,7 +146,7 @@ def start_delete_process( output_port: Callable[[OperationResponse], None], ) -> None: """Start the delete process for the requested entities.""" - result = link_gateway.create_link().apply(Operations.START_DELETE, requested=request.requested) + result = link_gateway.create_link().apply(Operations.START_DELETE, requested=request.requested).operation_results[0] link_gateway.apply(result.updates) output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) @@ -162,7 +162,7 @@ def process( request: ProcessRequest, *, link_gateway: LinkGateway, output_port: Callable[[OperationResponse], None] ) -> None: """Process entities.""" - result = link_gateway.create_link().apply(Operations.PROCESS, requested=request.requested) + result = link_gateway.create_link().apply(Operations.PROCESS, requested=request.requested).operation_results[0] link_gateway.apply(result.updates) output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) diff --git a/tests/integration/test_datajoint_persistence.py b/tests/integration/test_datajoint_persistence.py index c9e91319..0de89752 100644 --- a/tests/integration/test_datajoint_persistence.py +++ b/tests/integration/test_datajoint_persistence.py @@ -371,7 +371,10 @@ def test_add_to_local_command() -> None: ) gateway.apply( - gateway.create_link().apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates ) assert has_state( @@ -413,6 +416,7 @@ def test_add_to_local_command_with_error() -> None: gateway.apply( gateway.create_link() .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] .updates ) except RuntimeError: @@ -432,7 +436,10 @@ def test_add_to_local_command_with_external_file(tmpdir: Path) -> None: os.remove(insert_filepath) tables["outbound"].insert([{"a": 0, "process": "PULL", "is_flagged": "FALSE", "is_deprecated": "FALSE"}]) gateway.apply( - gateway.create_link().apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates ) fetch_filepath = Path(tables["local"].fetch(as_dict=True, download_path=str(tmpdir))[0]["external"]) with fetch_filepath.open(mode="rb") as file: @@ -455,6 +462,7 @@ def test_remove_from_local_command() -> None: gateway.apply( gateway.create_link() .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] .updates ) @@ -475,6 +483,7 @@ def test_start_pull_process() -> None: gateway.apply( gateway.create_link() .apply(Operations.START_PULL, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] .updates ) @@ -504,6 +513,7 @@ def test_state_after_command(initial_state: State) -> None: gateway.apply( gateway.create_link() .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] .updates ) @@ -525,6 +535,7 @@ def test_rollback_on_error(initial_state: State) -> None: gateway.apply( gateway.create_link() .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] .updates ) except RuntimeError: @@ -550,6 +561,7 @@ def test_state_after_command(initial_state: State) -> None: gateway.apply( gateway.create_link() .apply(Operations.START_DELETE, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] .updates ) @@ -571,6 +583,7 @@ def test_rollback_on_error(initial_state: State) -> None: gateway.apply( gateway.create_link() .apply(Operations.START_DELETE, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] .updates ) except RuntimeError: @@ -591,7 +604,10 @@ def test_finish_delete_process_command() -> None: ) gateway.apply( - gateway.create_link().apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates ) assert has_state(tables, State(source=TableState([{"a": 0, "b": 1}]))) @@ -613,6 +629,7 @@ def test_state_after_command(initial_state: State) -> None: gateway.apply( gateway.create_link() .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] .updates ) @@ -633,6 +650,7 @@ def test_rollback_on_error(initial_state: State) -> None: gateway.apply( gateway.create_link() .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] .updates ) except RuntimeError: @@ -662,6 +680,7 @@ def test_applying_multiple_commands() -> None: gateway.apply( gateway.create_link() .apply(Operations.PROCESS, requested=gateway.translator.to_identifiers([{"a": 0}, {"a": 1}])) + .operation_results[0] .updates ) diff --git a/tests/unit/entities/test_link.py b/tests/unit/entities/test_link.py index 176cbd8a..9df845a3 100644 --- a/tests/unit/entities/test_link.py +++ b/tests/unit/entities/test_link.py @@ -167,7 +167,7 @@ def test_can_get_identifiers_of_entities_in_component( assert set(link.identifiers) == create_identifiers("1", "2") -def test_process_produces_correct_updates() -> None: +def test_link_is_processed_correctly() -> None: link = create_link( create_assignments( { @@ -182,14 +182,15 @@ def test_process_produces_correct_updates() -> None: }, ) actual = { - (update.identifier, update.transition.new) - for update in link.apply(Operations.PROCESS, requested=create_identifiers("1", "2", "3", "4")).updates + (entity.identifier, entity.state) + for entity in link.apply(Operations.PROCESS, requested=create_identifiers("1", "2", "3", "4")) } expected = { (create_identifier("1"), states.Received), (create_identifier("2"), states.Pulled), (create_identifier("3"), states.Idle), (create_identifier("4"), states.Activated), + (create_identifier("5"), states.Received), } assert actual == expected @@ -202,10 +203,10 @@ def link() -> Link: @staticmethod def test_idle_entity_becomes_activated(link: Link) -> None: - result = link.apply(Operations.START_PULL, requested=create_identifiers("1")) - update = next(iter(result.updates)) - assert update.identifier == create_identifier("1") - assert update.transition.new is states.Activated + link = link.apply(Operations.START_PULL, requested=create_identifiers("1")) + entity = next(iter(link)) + assert entity.identifier == create_identifier("1") + assert entity.state is states.Activated @staticmethod def test_not_specifying_requested_identifiers_raises_error(link: Link) -> None: @@ -228,10 +229,10 @@ def link() -> Link: class TestStartDelete: @staticmethod def test_pulled_entity_becomes_received(link: Link) -> None: - result = link.apply(Operations.START_DELETE, requested=create_identifiers("1")) - update = next(iter(result.updates)) - assert {update.identifier} == create_identifiers("1") - assert update.transition.new is states.Received + link = link.apply(Operations.START_DELETE, requested=create_identifiers("1")) + entity = next(iter(link)) + assert entity.identifier == create_identifier("1") + assert entity.state is states.Received @staticmethod def test_not_specifying_requested_identifiers_raises_error(link: Link) -> None: From e0f3a4936db000993033d43b3bf22d88199a441f Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Mon, 23 Oct 2023 15:50:24 +0200 Subject: [PATCH 10/13] Add unit of work --- link/infrastructure/link.py | 16 +++--- link/service/services.py | 34 +++++------ link/service/uow.py | 86 ++++++++++++++++++++++++++++ tests/integration/test_services.py | 92 ++++++++++++++++++------------ 4 files changed, 164 insertions(+), 64 deletions(-) create mode 100644 link/service/uow.py diff --git a/link/infrastructure/link.py b/link/infrastructure/link.py index 9c509bfd..53b727a1 100644 --- a/link/infrastructure/link.py +++ b/link/infrastructure/link.py @@ -24,6 +24,7 @@ start_delete_process, start_pull_process, ) +from link.service.uow import UnitOfWork from . import DJConfiguration, create_tables from .facade import DJLinkFacade @@ -54,17 +55,16 @@ def inner(obj: type) -> Any: ) facade = DJLinkFacade(tables.source, tables.outbound, tables.local) gateway = DJLinkGateway(facade, translator) + uow = UnitOfWork(gateway) source_restriction: IterationCallbackList[PrimaryKey] = IterationCallbackList() idle_entities_updater = create_idle_entities_updater(translator, create_content_replacer(source_restriction)) operation_presenter = create_operation_response_presenter(translator, create_operation_logger()) - process_service = partial( - make_responsive(partial(process, link_gateway=gateway)), output_port=operation_presenter - ) + process_service = partial(make_responsive(partial(process, uow=uow)), output_port=operation_presenter) start_pull_process_service = partial( - make_responsive(partial(start_pull_process, link_gateway=gateway)), output_port=operation_presenter + make_responsive(partial(start_pull_process, uow=uow)), output_port=operation_presenter ) start_delete_process_service = partial( - make_responsive(partial(start_delete_process, link_gateway=gateway)), output_port=operation_presenter + make_responsive(partial(start_delete_process, uow=uow)), output_port=operation_presenter ) process_to_completion_service = partial( make_responsive(partial(process_to_completion, process_service=process_service)), output_port=lambda x: None @@ -82,10 +82,8 @@ def inner(obj: type) -> Any: start_delete_process_service=start_delete_process_service, output_port=lambda x: None, ), - Services.PROCESS: partial(process, link_gateway=gateway, output_port=operation_presenter), - Services.LIST_IDLE_ENTITIES: partial( - list_idle_entities, link_gateway=gateway, output_port=idle_entities_updater - ), + Services.PROCESS: partial(process, uow=uow, output_port=operation_presenter), + Services.LIST_IDLE_ENTITIES: partial(list_idle_entities, uow=uow, output_port=idle_entities_updater), } controller = DJController(handlers, translator) source_restriction.callback = controller.list_idle_entities diff --git a/link/service/services.py b/link/service/services.py index 1defeda3..a3237f6c 100644 --- a/link/service/services.py +++ b/link/service/services.py @@ -8,7 +8,7 @@ from link.domain.custom_types import Identifier from link.domain.state import InvalidOperation, Operations, Update, states -from .gateway import LinkGateway +from .uow import UnitOfWork class Request: @@ -123,12 +123,13 @@ class StartPullProcessRequest(Request): def start_pull_process( request: StartPullProcessRequest, *, - link_gateway: LinkGateway, + uow: UnitOfWork, output_port: Callable[[OperationResponse], None], ) -> None: """Start the pull process for the requested entities.""" - result = link_gateway.create_link().apply(Operations.START_PULL, requested=request.requested).operation_results[0] - link_gateway.apply(result.updates) + with uow: + result = uow.link.apply(Operations.START_PULL, requested=request.requested).operation_results[0] + uow.commit() output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) @@ -142,12 +143,13 @@ class StartDeleteProcessRequest(Request): def start_delete_process( request: StartDeleteProcessRequest, *, - link_gateway: LinkGateway, + uow: UnitOfWork, output_port: Callable[[OperationResponse], None], ) -> None: """Start the delete process for the requested entities.""" - result = link_gateway.create_link().apply(Operations.START_DELETE, requested=request.requested).operation_results[0] - link_gateway.apply(result.updates) + with uow: + result = uow.link.apply(Operations.START_DELETE, requested=request.requested).operation_results[0] + uow.commit() output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) @@ -158,12 +160,11 @@ class ProcessRequest(Request): requested: frozenset[Identifier] -def process( - request: ProcessRequest, *, link_gateway: LinkGateway, output_port: Callable[[OperationResponse], None] -) -> None: +def process(request: ProcessRequest, *, uow: UnitOfWork, output_port: Callable[[OperationResponse], None]) -> None: """Process entities.""" - result = link_gateway.create_link().apply(Operations.PROCESS, requested=request.requested).operation_results[0] - link_gateway.apply(result.updates) + with uow: + result = uow.link.apply(Operations.PROCESS, requested=request.requested).operation_results[0] + uow.commit() output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) @@ -182,15 +183,14 @@ class ListIdleEntitiesResponse(Response): def list_idle_entities( request: ListIdleEntitiesRequest, *, - link_gateway: LinkGateway, + uow: UnitOfWork, output_port: Callable[[ListIdleEntitiesResponse], None], ) -> None: """List all idle entities.""" - output_port( - ListIdleEntitiesResponse( - frozenset(entity.identifier for entity in link_gateway.create_link() if entity.state is states.Idle) + with uow: + output_port( + ListIdleEntitiesResponse(frozenset(entity.identifier for entity in uow.link if entity.state is states.Idle)) ) - ) class Services(Enum): diff --git a/link/service/uow.py b/link/service/uow.py new file mode 100644 index 00000000..3ede3823 --- /dev/null +++ b/link/service/uow.py @@ -0,0 +1,86 @@ +"""Contains the unit of work for links.""" +from __future__ import annotations + +from abc import ABC +from collections import defaultdict, deque +from types import TracebackType +from typing import Callable + +from link.domain.custom_types import Identifier +from link.domain.link import Link +from link.domain.state import TRANSITION_MAP, Entity, Operations, Transition, Update + +from .gateway import LinkGateway + + +class UnitOfWork(ABC): + """Controls if and when updates to entities of a link are persisted.""" + + def __init__(self, gateway: LinkGateway) -> None: + """Initialize the unit of work.""" + self._gateway = gateway + self._link: Link | None = None + self._updates: dict[Identifier, deque[Update]] = defaultdict(deque) + self._entities: dict[Identifier, Entity] = {} + + def __enter__(self) -> UnitOfWork: + """Enter the context in which updates to entities can be made.""" + + def track_entity(entity: Entity) -> None: + apply = getattr(entity, "apply") + augmented = augment_apply(entity, apply) + object.__setattr__(entity, "apply", augmented) + self._entities[entity.identifier] = entity + + def augment_apply(current: Entity, apply: Callable[[Operations], Entity]) -> Callable[[Operations], Entity]: + def track_and_apply(operation: Operations) -> Entity: + if self._link is None: + raise RuntimeError + new = apply(operation) + store_update(operation, current, new) + track_entity(new) + return new + + return track_and_apply + + def store_update(operation: Operations, current: Entity, new: Entity) -> None: + assert current.identifier == new.identifier + if current.state is new.state: + return + transition = Transition(current.state, new.state) + self._updates[current.identifier].append( + Update(operation, current.identifier, transition, TRANSITION_MAP[transition]) + ) + + link = self._gateway.create_link() + for entity in link: + track_entity(entity) + self._link = link + return self + + def __exit__( + self, exc_type: type[BaseException] | None, exc: BaseException | None, traceback: TracebackType | None + ) -> None: + """Exit the context rolling back any not yet persisted updates.""" + self.rollback() + + @property + def link(self) -> Link: + """Return the link object that is governed by this unit of work.""" + if self._link is None: + raise RuntimeError + return self._link + + def commit(self) -> None: + """Persist updates made to the link.""" + while self._updates: + identifier, updates = self._updates.popitem() + while updates: + self._gateway.apply([updates.popleft()]) + self._link = None + + def rollback(self) -> None: + """Throw away any not yet persisted updates.""" + self._link = None + self._entities = {} + self._updates = defaultdict(deque) diff --git a/tests/integration/test_services.py b/tests/integration/test_services.py index 12766b6d..1582d58a 100644 --- a/tests/integration/test_services.py +++ b/tests/integration/test_services.py @@ -30,6 +30,7 @@ start_delete_process, start_pull_process, ) +from link.service.uow import UnitOfWork from tests.assignments import create_assignments, create_identifier, create_identifiers @@ -92,7 +93,7 @@ def __call__(self, response: T) -> None: self._response = response -def create_gateway(state: type[State], process: Processes | None = None, is_tainted: bool = False) -> FakeLinkGateway: +def create_uow(state: type[State], process: Processes | None = None, is_tainted: bool = False) -> UnitOfWork: if state in (states.Activated, states.Received): assert process is not None else: @@ -112,22 +113,26 @@ def create_gateway(state: type[State], process: Processes | None = None, is_tain processes = {} assignments = {Components.SOURCE: {"1"}} if state is states.Idle: - return FakeLinkGateway( - create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes + return UnitOfWork( + FakeLinkGateway( + create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes + ) ) assignments[Components.OUTBOUND] = {"1"} if state in (states.Deprecated, states.Activated): - return FakeLinkGateway( - create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes + return UnitOfWork( + FakeLinkGateway( + create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes + ) ) assignments[Components.LOCAL] = {"1"} - return FakeLinkGateway( - create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes + return UnitOfWork( + FakeLinkGateway(create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes) ) -def create_process_to_completion_service(gateway: FakeLinkGateway) -> Callable[[ProcessToCompletionRequest], None]: - process_service = partial(make_responsive(partial(process, link_gateway=gateway)), output_port=lambda x: None) +def create_process_to_completion_service(uow: UnitOfWork) -> Callable[[ProcessToCompletionRequest], None]: + process_service = partial(make_responsive(partial(process, uow=uow)), output_port=lambda x: None) return partial( make_responsive( partial( @@ -139,10 +144,10 @@ def create_process_to_completion_service(gateway: FakeLinkGateway) -> Callable[[ ) -def create_pull_service(gateway: FakeLinkGateway) -> Service[PullRequest, PullResponse]: - process_to_completion_service = create_process_to_completion_service(gateway) +def create_pull_service(uow: UnitOfWork) -> Service[PullRequest, PullResponse]: + process_to_completion_service = create_process_to_completion_service(uow) start_pull_process_service = partial( - make_responsive(partial(start_pull_process, link_gateway=gateway)), output_port=lambda x: None + make_responsive(partial(start_pull_process, uow=uow)), output_port=lambda x: None ) return partial( pull, @@ -151,10 +156,10 @@ def create_pull_service(gateway: FakeLinkGateway) -> Service[PullRequest, PullRe ) -def create_delete_service(gateway: FakeLinkGateway) -> Service[DeleteRequest, DeleteResponse]: - process_to_completion_service = create_process_to_completion_service(gateway) +def create_delete_service(uow: UnitOfWork) -> Service[DeleteRequest, DeleteResponse]: + process_to_completion_service = create_process_to_completion_service(uow) start_delete_process_service = partial( - make_responsive(partial(start_delete_process, link_gateway=gateway)), output_port=lambda x: None + make_responsive(partial(start_delete_process, uow=uow)), output_port=lambda x: None ) return partial( delete, @@ -203,18 +208,21 @@ class EntityConfig(TypedDict): ], ) def test_deleted_entity_ends_in_correct_state(state: EntityConfig, expected: type[State]) -> None: - gateway = create_gateway(**state) - delete_service = create_delete_service(gateway) + uow = create_uow(**state) + delete_service = create_delete_service(uow) delete_service(DeleteRequest(frozenset(create_identifiers("1"))), output_port=lambda x: None) - assert next(iter(gateway.create_link())).state is expected + with uow: + assert next(iter(uow.link)).state is expected def test_correct_response_model_gets_passed_to_delete_output_port() -> None: - gateway = FakeLinkGateway( - create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}) + uow = UnitOfWork( + FakeLinkGateway( + create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}) + ) ) output_port = FakeOutputPort[DeleteResponse]() - delete_service = create_delete_service(gateway) + delete_service = create_delete_service(uow) delete_service(DeleteRequest(frozenset(create_identifiers("1"))), output_port=output_port) assert output_port.response.requested == create_identifiers("1") @@ -237,13 +245,14 @@ def test_correct_response_model_gets_passed_to_delete_output_port() -> None: ], ) def test_pulled_entity_ends_in_correct_state(state: EntityConfig, expected: type[State]) -> None: - gateway = create_gateway(**state) - pull_service = create_pull_service(gateway) + uow = create_uow(**state) + pull_service = create_pull_service(uow) pull_service( PullRequest(frozenset(create_identifiers("1"))), output_port=lambda x: None, ) - assert next(iter(gateway.create_link())).state is expected + with uow: + assert next(iter(uow.link)).state is expected @pytest.mark.parametrize( @@ -272,7 +281,7 @@ def test_correct_response_model_gets_passed_to_pull_output_port(state: EntityCon } else: errors = set() - gateway = create_gateway(**state) + gateway = create_uow(**state) output_port = FakeOutputPort[PullResponse]() pull_service = create_pull_service(gateway) pull_service( @@ -283,28 +292,33 @@ def test_correct_response_model_gets_passed_to_pull_output_port(state: EntityCon def test_entity_undergoing_process_gets_processed() -> None: - gateway = FakeLinkGateway( - create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}}), - processes={Processes.PULL: create_identifiers("1")}, + uow = UnitOfWork( + FakeLinkGateway( + create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}}), + processes={Processes.PULL: create_identifiers("1")}, + ) ) process( ProcessRequest(frozenset(create_identifiers("1"))), - link_gateway=gateway, + uow=uow, output_port=FakeOutputPort[OperationResponse](), ) - entity = next(entity for entity in gateway.create_link() if entity.identifier == create_identifier("1")) - assert entity.state is states.Received + with uow: + entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1")) + assert entity.state is states.Received def test_correct_response_model_gets_passed_to_process_output_port() -> None: - gateway = FakeLinkGateway( - create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}}), - processes={Processes.PULL: create_identifiers("1")}, + uow = UnitOfWork( + FakeLinkGateway( + create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}}), + processes={Processes.PULL: create_identifiers("1")}, + ) ) output_port = FakeOutputPort[OperationResponse]() process( ProcessRequest(frozenset(create_identifiers("1"))), - link_gateway=gateway, + uow=uow, output_port=output_port, ) assert output_port.response.requested == create_identifiers("1") @@ -312,9 +326,11 @@ def test_correct_response_model_gets_passed_to_process_output_port() -> None: def test_correct_response_model_gets_passed_to_list_idle_entities_output_port() -> None: - link_gateway = FakeLinkGateway( - create_assignments({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) + uow = UnitOfWork( + FakeLinkGateway( + create_assignments({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) + ) ) output_port = FakeOutputPort[ListIdleEntitiesResponse]() - list_idle_entities(ListIdleEntitiesRequest(), link_gateway=link_gateway, output_port=output_port) + list_idle_entities(ListIdleEntitiesRequest(), uow=uow, output_port=output_port) assert set(output_port.response.identifiers) == create_identifiers("1") From 2d4a498dd620478e037d5d48b39633216a843f2f Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Mon, 23 Oct 2023 16:21:26 +0200 Subject: [PATCH 11/13] Move fake link gateway to new module --- tests/integration/gateway.py | 52 ++++++++++++++++++++++++++++++ tests/integration/test_services.py | 51 ++--------------------------- 2 files changed, 55 insertions(+), 48 deletions(-) create mode 100644 tests/integration/gateway.py diff --git a/tests/integration/gateway.py b/tests/integration/gateway.py new file mode 100644 index 00000000..e8c1eec9 --- /dev/null +++ b/tests/integration/gateway.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Iterable + +from link.domain.custom_types import Identifier +from link.domain.link import Link, create_link +from link.domain.state import Commands, Components, Processes, Update +from link.service.gateway import LinkGateway + + +class FakeLinkGateway(LinkGateway): + def __init__( + self, + assignments: Mapping[Components, Iterable[Identifier]], + *, + tainted_identifiers: Iterable[Identifier] | None = None, + processes: Mapping[Processes, Iterable[Identifier]] | None = None, + ) -> None: + self.assignments = {component: set(identifiers) for component, identifiers in assignments.items()} + self.tainted_identifiers = set(tainted_identifiers) if tainted_identifiers is not None else set() + self.processes: dict[Processes, set[Identifier]] = {process: set() for process in Processes} + if processes is not None: + for entity_process, identifiers in processes.items(): + self.processes[entity_process].update(identifiers) + + def create_link(self) -> Link: + return create_link(self.assignments, tainted_identifiers=self.tainted_identifiers, processes=self.processes) + + def apply(self, updates: Iterable[Update]) -> None: + for update in updates: + if update.command is Commands.START_PULL_PROCESS: + self.processes[Processes.PULL].add(update.identifier) + self.assignments[Components.OUTBOUND].add(update.identifier) + elif update.command is Commands.ADD_TO_LOCAL: + self.assignments[Components.LOCAL].add(update.identifier) + elif update.command is Commands.FINISH_PULL_PROCESS: + self.processes[Processes.PULL].remove(update.identifier) + elif update.command is Commands.START_DELETE_PROCESS: + self.processes[Processes.DELETE].add(update.identifier) + elif update.command is Commands.REMOVE_FROM_LOCAL: + self.assignments[Components.LOCAL].remove(update.identifier) + elif update.command is Commands.FINISH_DELETE_PROCESS: + self.processes[Processes.DELETE].remove(update.identifier) + self.assignments[Components.OUTBOUND].remove(update.identifier) + elif update.command is Commands.DEPRECATE: + try: + self.processes[Processes.DELETE].remove(update.identifier) + except KeyError: + self.processes[Processes.PULL].remove(update.identifier) + else: + raise ValueError("Unsupported command encountered") diff --git a/tests/integration/test_services.py b/tests/integration/test_services.py index 1582d58a..444eab47 100644 --- a/tests/integration/test_services.py +++ b/tests/integration/test_services.py @@ -1,15 +1,12 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Mapping +from collections.abc import Callable from functools import partial from typing import Generic, TypedDict, TypeVar import pytest -from link.domain.custom_types import Identifier -from link.domain.link import Link, create_link -from link.domain.state import Commands, Components, InvalidOperation, Operations, Processes, State, Update, states -from link.service.gateway import LinkGateway +from link.domain.state import Components, InvalidOperation, Operations, Processes, State, states from link.service.io import Service, make_responsive from link.service.services import ( DeleteRequest, @@ -33,49 +30,7 @@ from link.service.uow import UnitOfWork from tests.assignments import create_assignments, create_identifier, create_identifiers - -class FakeLinkGateway(LinkGateway): - def __init__( - self, - assignments: Mapping[Components, Iterable[Identifier]], - *, - tainted_identifiers: Iterable[Identifier] | None = None, - processes: Mapping[Processes, Iterable[Identifier]] | None = None, - ) -> None: - self.assignments = {component: set(identifiers) for component, identifiers in assignments.items()} - self.tainted_identifiers = set(tainted_identifiers) if tainted_identifiers is not None else set() - self.processes: dict[Processes, set[Identifier]] = {process: set() for process in Processes} - if processes is not None: - for entity_process, identifiers in processes.items(): - self.processes[entity_process].update(identifiers) - - def create_link(self) -> Link: - return create_link(self.assignments, tainted_identifiers=self.tainted_identifiers, processes=self.processes) - - def apply(self, updates: Iterable[Update]) -> None: - for update in updates: - if update.command is Commands.START_PULL_PROCESS: - self.processes[Processes.PULL].add(update.identifier) - self.assignments[Components.OUTBOUND].add(update.identifier) - elif update.command is Commands.ADD_TO_LOCAL: - self.assignments[Components.LOCAL].add(update.identifier) - elif update.command is Commands.FINISH_PULL_PROCESS: - self.processes[Processes.PULL].remove(update.identifier) - elif update.command is Commands.START_DELETE_PROCESS: - self.processes[Processes.DELETE].add(update.identifier) - elif update.command is Commands.REMOVE_FROM_LOCAL: - self.assignments[Components.LOCAL].remove(update.identifier) - elif update.command is Commands.FINISH_DELETE_PROCESS: - self.processes[Processes.DELETE].remove(update.identifier) - self.assignments[Components.OUTBOUND].remove(update.identifier) - elif update.command is Commands.DEPRECATE: - try: - self.processes[Processes.DELETE].remove(update.identifier) - except KeyError: - self.processes[Processes.PULL].remove(update.identifier) - else: - raise ValueError("Unsupported command encountered") - +from .gateway import FakeLinkGateway T = TypeVar("T", bound=Response) From ce06b1c420c37c9f5a3bba4e043af925a0285053 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Mon, 23 Oct 2023 18:14:24 +0200 Subject: [PATCH 12/13] Add tests for unit of work --- link/domain/link.py | 2 +- link/service/uow.py | 12 +++-- tests/integration/test_uow.py | 86 +++++++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 5 deletions(-) create mode 100644 tests/integration/test_uow.py diff --git a/link/domain/link.py b/link/domain/link.py index d6205f24..6485d21d 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -129,7 +129,7 @@ def create_operation_result(results: Iterable[EntityOperationResult]) -> LinkOpe changed = {entity.apply(operation) for entity in self if entity.identifier in requested} unchanged = {entity for entity in self if entity.identifier not in requested} operation_results = self.operation_results + ( - create_operation_result(entity.operation_results[0] for entity in changed), + create_operation_result(entity.operation_results[-1] for entity in changed), ) return Link(changed | unchanged, operation_results) diff --git a/link/service/uow.py b/link/service/uow.py index 3ede3823..34ba9e8d 100644 --- a/link/service/uow.py +++ b/link/service/uow.py @@ -30,12 +30,14 @@ def track_entity(entity: Entity) -> None: apply = getattr(entity, "apply") augmented = augment_apply(entity, apply) object.__setattr__(entity, "apply", augmented) + object.__setattr__(entity, "_is_expired", False) self._entities[entity.identifier] = entity def augment_apply(current: Entity, apply: Callable[[Operations], Entity]) -> Callable[[Operations], Entity]: def track_and_apply(operation: Operations) -> Entity: - if self._link is None: - raise RuntimeError + assert hasattr(current, "_is_expired") + if current._is_expired is True: + raise RuntimeError("Can not apply operation to expired entity") new = apply(operation) store_update(operation, current, new) track_entity(new) @@ -68,7 +70,7 @@ def __exit__( def link(self) -> Link: """Return the link object that is governed by this unit of work.""" if self._link is None: - raise RuntimeError + raise RuntimeError("Not available outside of context") return self._link def commit(self) -> None: @@ -77,10 +79,12 @@ def commit(self) -> None: identifier, updates = self._updates.popitem() while updates: self._gateway.apply([updates.popleft()]) - self._link = None + self.rollback() def rollback(self) -> None: """Throw away any not yet persisted updates.""" self._link = None + for entity in self._entities.values(): + object.__setattr__(entity, "_is_expired", True) self._entities = {} self._updates = defaultdict(deque) diff --git a/tests/integration/test_uow.py b/tests/integration/test_uow.py new file mode 100644 index 00000000..4e3fab70 --- /dev/null +++ b/tests/integration/test_uow.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from typing import Iterable, Mapping + +import pytest + +from link.domain.state import Components, Operations, states +from link.service.uow import UnitOfWork +from tests.assignments import create_assignments, create_identifier, create_identifiers + +from .gateway import FakeLinkGateway + + +def initialize(assignments: Mapping[Components, Iterable[str]]) -> tuple[FakeLinkGateway, UnitOfWork]: + gateway = FakeLinkGateway(create_assignments(assignments)) + return gateway, UnitOfWork(gateway) + + +def test_updates_are_applied_to_gateway_on_commit() -> None: + gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) + with uow: + link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + link = link.apply(Operations.START_DELETE, requested=create_identifiers("2")) + link = link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.commit() + actual = {(entity.identifier, entity.state) for entity in gateway.create_link()} + expected = {(create_identifier("1"), states.Pulled), (create_identifier("2"), states.Idle)} + assert actual == expected + + +def test_updates_are_discarded_on_context_exit() -> None: + gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) + with uow: + link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + link = link.apply(Operations.START_DELETE, requested=create_identifiers("2")) + link = link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + actual = {(entity.identifier, entity.state) for entity in gateway.create_link()} + expected = {(create_identifier("1"), states.Idle), (create_identifier("2"), states.Pulled)} + assert actual == expected + + +def test_updates_are_discarded_on_rollback() -> None: + gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) + with uow: + link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + link = link.apply(Operations.START_DELETE, requested=create_identifiers("2")) + link = link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.rollback() + actual = {(entity.identifier, entity.state) for entity in gateway.create_link()} + expected = {(create_identifier("1"), states.Idle), (create_identifier("2"), states.Pulled)} + assert actual == expected + + +def test_link_can_not_be_accessed_outside_of_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with pytest.raises(RuntimeError, match="outside"): + uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + + +def test_no_more_operations_can_be_applied_after_commit() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + uow.commit() + with pytest.raises(RuntimeError, match="expired"): + link.apply(Operations.PROCESS, requested=create_identifiers("1")) + + +def test_no_more_operations_can_be_applied_after_rollback() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + uow.rollback() + with pytest.raises(RuntimeError, match="expired"): + link.apply(Operations.START_PULL, requested=create_identifiers("1")) + + +def test_no_more_operations_can_be_applied_after_exiting_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + with pytest.raises(RuntimeError, match="expired"): + link.apply(Operations.START_PULL, requested=create_identifiers("1")) From 508f348a76ad28150872e311b58748192e94a073 Mon Sep 17 00:00:00 2001 From: Christoph Blessing <33834216+cblessing24@users.noreply.github.com> Date: Tue, 24 Oct 2023 16:23:30 +0200 Subject: [PATCH 13/13] Fix some uow bugs --- link/service/uow.py | 68 +++++++++++++++------- tests/integration/test_uow.py | 104 ++++++++++++++++++++++++++-------- 2 files changed, 130 insertions(+), 42 deletions(-) diff --git a/link/service/uow.py b/link/service/uow.py index 34ba9e8d..34bbf0ef 100644 --- a/link/service/uow.py +++ b/link/service/uow.py @@ -4,7 +4,7 @@ from abc import ABC from collections import defaultdict, deque from types import TracebackType -from typing import Callable +from typing import Callable, Iterable, Protocol from link.domain.custom_types import Identifier from link.domain.link import Link @@ -13,6 +13,13 @@ from .gateway import LinkGateway +class SupportsLinkApply(Protocol): + """Protocol for an object that supports applying operations to links.""" + + def __call__(self, operation: Operations, *, requested: Iterable[Identifier]) -> Link: + """Apply the operation to the link.""" + + class UnitOfWork(ABC): """Controls if and when updates to entities of a link are persisted.""" @@ -21,29 +28,48 @@ def __init__(self, gateway: LinkGateway) -> None: self._gateway = gateway self._link: Link | None = None self._updates: dict[Identifier, deque[Update]] = defaultdict(deque) - self._entities: dict[Identifier, Entity] = {} def __enter__(self) -> UnitOfWork: """Enter the context in which updates to entities can be made.""" - def track_entity(entity: Entity) -> None: - apply = getattr(entity, "apply") - augmented = augment_apply(entity, apply) + def augment_link(link: Link) -> None: + original = getattr(link, "apply") + augmented = augment_link_apply(link, original) + object.__setattr__(link, "apply", augmented) + object.__setattr__(link, "_is_expired", False) + + def augment_link_apply(current: Link, original: SupportsLinkApply) -> SupportsLinkApply: + def augmented(operation: Operations, *, requested: Iterable[Identifier]) -> Link: + assert hasattr(current, "_is_expired") + if current._is_expired: + raise RuntimeError("Can not apply operation to expired link") + self._link = original(operation, requested=requested) + augment_link(self._link) + object.__setattr__(current, "_is_expired", True) + return self._link + + return augmented + + def augment_entity(entity: Entity) -> None: + original = getattr(entity, "apply") + augmented = augment_entity_apply(entity, original) object.__setattr__(entity, "apply", augmented) object.__setattr__(entity, "_is_expired", False) - self._entities[entity.identifier] = entity - def augment_apply(current: Entity, apply: Callable[[Operations], Entity]) -> Callable[[Operations], Entity]: - def track_and_apply(operation: Operations) -> Entity: + def augment_entity_apply( + current: Entity, original: Callable[[Operations], Entity] + ) -> Callable[[Operations], Entity]: + def augmented(operation: Operations) -> Entity: assert hasattr(current, "_is_expired") if current._is_expired is True: raise RuntimeError("Can not apply operation to expired entity") - new = apply(operation) + new = original(operation) store_update(operation, current, new) - track_entity(new) + augment_entity(new) + object.__setattr__(current, "_is_expired", True) return new - return track_and_apply + return augmented def store_update(operation: Operations, current: Entity, new: Entity) -> None: assert current.identifier == new.identifier @@ -54,10 +80,10 @@ def store_update(operation: Operations, current: Entity, new: Entity) -> None: Update(operation, current.identifier, transition, TRANSITION_MAP[transition]) ) - link = self._gateway.create_link() - for entity in link: - track_entity(entity) - self._link = link + self._link = self._gateway.create_link() + augment_link(self._link) + for entity in self._link: + augment_entity(entity) return self def __exit__( @@ -65,6 +91,7 @@ def __exit__( ) -> None: """Exit the context rolling back any not yet persisted updates.""" self.rollback() + self._link = None @property def link(self) -> Link: @@ -75,6 +102,8 @@ def link(self) -> Link: def commit(self) -> None: """Persist updates made to the link.""" + if self._link is None: + raise RuntimeError("Not available outside of context") while self._updates: identifier, updates = self._updates.popitem() while updates: @@ -83,8 +112,9 @@ def commit(self) -> None: def rollback(self) -> None: """Throw away any not yet persisted updates.""" - self._link = None - for entity in self._entities.values(): + if self._link is None: + raise RuntimeError("Not available outside of context") + object.__setattr__(self._link, "_is_expired", True) + for entity in self._link: object.__setattr__(entity, "_is_expired", True) - self._entities = {} - self._updates = defaultdict(deque) + self._updates.clear() diff --git a/tests/integration/test_uow.py b/tests/integration/test_uow.py index 4e3fab70..651e8e44 100644 --- a/tests/integration/test_uow.py +++ b/tests/integration/test_uow.py @@ -19,10 +19,10 @@ def initialize(assignments: Mapping[Components, Iterable[str]]) -> tuple[FakeLin def test_updates_are_applied_to_gateway_on_commit() -> None: gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) with uow: - link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) - link = link.apply(Operations.START_DELETE, requested=create_identifiers("2")) - link = link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) - link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + uow.link.apply(Operations.START_DELETE, requested=create_identifiers("2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) uow.commit() actual = {(entity.identifier, entity.state) for entity in gateway.create_link()} expected = {(create_identifier("1"), states.Pulled), (create_identifier("2"), states.Idle)} @@ -32,10 +32,10 @@ def test_updates_are_applied_to_gateway_on_commit() -> None: def test_updates_are_discarded_on_context_exit() -> None: gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) with uow: - link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) - link = link.apply(Operations.START_DELETE, requested=create_identifiers("2")) - link = link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) - link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + uow.link.apply(Operations.START_DELETE, requested=create_identifiers("2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) actual = {(entity.identifier, entity.state) for entity in gateway.create_link()} expected = {(create_identifier("1"), states.Idle), (create_identifier("2"), states.Pulled)} assert actual == expected @@ -44,10 +44,10 @@ def test_updates_are_discarded_on_context_exit() -> None: def test_updates_are_discarded_on_rollback() -> None: gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) with uow: - link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) - link = link.apply(Operations.START_DELETE, requested=create_identifiers("2")) - link = link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) - link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + uow.link.apply(Operations.START_DELETE, requested=create_identifiers("2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) uow.rollback() actual = {(entity.identifier, entity.state) for entity in gateway.create_link()} expected = {(create_identifier("1"), states.Idle), (create_identifier("2"), states.Pulled)} @@ -56,31 +56,89 @@ def test_updates_are_discarded_on_rollback() -> None: def test_link_can_not_be_accessed_outside_of_context() -> None: _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + pass with pytest.raises(RuntimeError, match="outside"): - uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + uow.link -def test_no_more_operations_can_be_applied_after_commit() -> None: +def test_unable_to_commit_outside_of_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with pytest.raises(RuntimeError, match="outside"): + uow.commit() + + +def test_unable_to_rollback_outside_of_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with pytest.raises(RuntimeError, match="outside"): + uow.rollback() + + +def test_entity_expires_when_committing() -> None: _, uow = initialize({Components.SOURCE: {"1"}}) with uow: - link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1")) uow.commit() - with pytest.raises(RuntimeError, match="expired"): - link.apply(Operations.PROCESS, requested=create_identifiers("1")) + with pytest.raises(RuntimeError, match="expired entity"): + entity.apply(Operations.START_PULL) -def test_no_more_operations_can_be_applied_after_rollback() -> None: +def test_entity_expires_when_rolling_back() -> None: _, uow = initialize({Components.SOURCE: {"1"}}) with uow: - link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1")) uow.rollback() - with pytest.raises(RuntimeError, match="expired"): + with pytest.raises(RuntimeError, match="expired entity"): + entity.apply(Operations.START_PULL) + + +def test_entity_expires_when_leaving_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1")) + with pytest.raises(RuntimeError, match="expired entity"): + entity.apply(Operations.START_PULL) + + +def test_entity_expires_when_applying_operation() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1")) + entity.apply(Operations.START_PULL) + with pytest.raises(RuntimeError, match="expired entity"): + entity.apply(Operations.PROCESS) + + +def test_link_expires_when_committing() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + link = uow.link + uow.commit() + with pytest.raises(RuntimeError, match="expired link"): + link.apply(Operations.START_PULL, requested=create_identifiers("1")) + + +def test_link_expires_when_rolling_back() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + link = uow.link + uow.rollback() + with pytest.raises(RuntimeError, match="expired link"): link.apply(Operations.START_PULL, requested=create_identifiers("1")) -def test_no_more_operations_can_be_applied_after_exiting_context() -> None: +def test_link_expires_when_exiting_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + link = uow.link + with pytest.raises(RuntimeError, match="expired link"): + link.apply(Operations.START_PULL, requested=create_identifiers("1")) + + +def test_link_expires_when_applying_operation() -> None: _, uow = initialize({Components.SOURCE: {"1"}}) with uow: - link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) - with pytest.raises(RuntimeError, match="expired"): + link = uow.link link.apply(Operations.START_PULL, requested=create_identifiers("1")) + with pytest.raises(RuntimeError, match="expired link"): + link.apply(Operations.PROCESS, requested=create_identifiers("1"))