diff --git a/link/domain/link.py b/link/domain/link.py index 1339d852..6485d21d 100644 --- a/link/domain/link.py +++ b/link/domain/link.py @@ -2,7 +2,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, FrozenSet, Iterable, Mapping, Optional, TypeVar +from typing import Any, Iterable, Iterator, Mapping, Optional, Set, Tuple, TypeVar from .custom_types import Identifier from .state import ( @@ -68,8 +68,9 @@ def create_entity(identifier: Identifier) -> Entity: return Entity( identifier, state=state, - current_process=processes_map.get(identifier), + current_process=processes_map.get(identifier, Processes.NONE), is_tainted=is_tainted(identifier), + operation_results=tuple(), ) return {create_entity(identifier) for identifier in assignments[Components.SOURCE]} @@ -90,14 +91,60 @@ def assign_to_component(component: Components) -> set[Entity]: return Link(entity_assignments[Components.SOURCE]) -class Link(FrozenSet[Entity]): +class Link(Set[Entity]): """The state of a link between two databases.""" + def __init__( + self, entities: Iterable[Entity], operation_results: Tuple[LinkOperationResult, ...] = tuple() + ) -> None: + """Initialize the link.""" + self._entities = set(entities) + self._operation_results = operation_results + @property def identifiers(self) -> frozenset[Identifier]: """Return the identifiers of all entities in the link.""" return frozenset(entity.identifier for entity in self) + @property + def operation_results(self) -> Tuple[LinkOperationResult, ...]: + """Return the results of operations performed on this link.""" + return self._operation_results + + def apply(self, operation: Operations, *, requested: Iterable[Identifier]) -> Link: + """Apply an operation to the requested entities.""" + + def create_operation_result(results: Iterable[EntityOperationResult]) -> LinkOperationResult: + """Create the result of an operation on a link from results of individual entities.""" + results = set(results) + operation = next(iter(results)).operation + return LinkOperationResult( + operation, + updates=frozenset(result for result in results if isinstance(result, Update)), + errors=frozenset(result for result in results if isinstance(result, InvalidOperation)), + ) + + assert requested, "No identifiers requested." + assert set(requested) <= self.identifiers, "Requested identifiers not present in link." + changed = {entity.apply(operation) for entity in self if entity.identifier in requested} + unchanged = {entity for entity in self if entity.identifier not in requested} + operation_results = self.operation_results + ( + create_operation_result(entity.operation_results[-1] for entity in changed), + ) + return Link(changed | unchanged, operation_results) + + def __contains__(self, entity: object) -> bool: + """Check if the link contains the given entity.""" + return entity in self._entities + + def __iter__(self) -> Iterator[Entity]: + """Iterate over all entities in the link.""" + return iter(self._entities) + + def __len__(self) -> int: + """Return the number of entities in the link.""" + return len(self._entities) + @dataclass(frozen=True) class LinkOperationResult: @@ -112,37 +159,3 @@ def __post_init__(self) -> None: assert all( result.operation is self.operation for result in (self.updates | self.errors) ), "Not all results have same operation." - - -def create_link_operation_result(results: Iterable[EntityOperationResult]) -> LinkOperationResult: - """Create the result of an operation on a link from results of individual entities.""" - results = set(results) - operation = next(iter(results)).operation - return LinkOperationResult( - operation, - updates=frozenset(result for result in results if isinstance(result, Update)), - errors=frozenset(result for result in results if isinstance(result, InvalidOperation)), - ) - - -def process(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResult: - """Process all entities in the link producing appropriate updates.""" - _validate_requested(link, requested) - return create_link_operation_result(entity.process() for entity in link if entity.identifier in requested) - - -def _validate_requested(link: Link, requested: Iterable[Identifier]) -> None: - assert requested, "No identifiers requested." - assert set(requested) <= link.identifiers, "Requested identifiers not present in link." - - -def start_pull(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResult: - """Start the pull process on the requested entities.""" - _validate_requested(link, requested) - return create_link_operation_result(entity.start_pull() for entity in link if entity.identifier in requested) - - -def start_delete(link: Link, *, requested: Iterable[Identifier]) -> LinkOperationResult: - """Start the delete process on the requested entities.""" - _validate_requested(link, requested) - return create_link_operation_result(entity.start_delete() for entity in link if entity.identifier in requested) diff --git a/link/domain/state.py b/link/domain/state.py index 313a8c00..7b453232 100644 --- a/link/domain/state.py +++ b/link/domain/state.py @@ -1,9 +1,10 @@ """Contains everything state related.""" from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, replace from enum import Enum, auto -from typing import Optional, Union +from functools import partial +from typing import Union from .custom_types import Identifier @@ -12,35 +13,36 @@ class State: """An entity's state.""" @classmethod - def start_pull(cls, entity: Entity) -> EntityOperationResult: + def start_pull(cls, entity: Entity) -> Entity: """Return the command needed to start the pull process for the entity.""" - return cls._create_invalid_operation_result(Operations.START_PULL, entity.identifier) + return cls._create_invalid_operation(entity, Operations.START_PULL) @classmethod - def start_delete(cls, entity: Entity) -> EntityOperationResult: + def start_delete(cls, entity: Entity) -> Entity: """Return the commands needed to start the delete process for the entity.""" - return cls._create_invalid_operation_result(Operations.START_DELETE, entity.identifier) + return cls._create_invalid_operation(entity, Operations.START_DELETE) @classmethod - def process(cls, entity: Entity) -> EntityOperationResult: + def process(cls, entity: Entity) -> Entity: """Return the commands needed to process the entity.""" - return cls._create_invalid_operation_result(Operations.PROCESS, entity.identifier) + return cls._create_invalid_operation(entity, Operations.PROCESS) - @classmethod - def _create_invalid_operation_result(cls, operation: Operations, identifier: Identifier) -> EntityOperationResult: - return InvalidOperation(operation, identifier, cls) + @staticmethod + def _create_invalid_operation(entity: Entity, operation: Operations) -> Entity: + updated = entity.operation_results + (InvalidOperation(operation, entity.identifier, entity.state),) + return replace(entity, operation_results=updated) @classmethod - def _create_valid_operation_result( - cls, operation: Operations, identifier: Identifier, new_state: type[State] - ) -> EntityOperationResult: + def _transition_entity( + cls, entity: Entity, operation: Operations, new_state: type[State], *, new_process: Processes | None = None + ) -> Entity: + if new_process is None: + new_process = entity.current_process transition = Transition(cls, new_state) - return Update( - operation, - identifier, - transition, - command=TRANSITION_MAP[transition], + updated_results = entity.operation_results + ( + Update(operation, entity.identifier, transition, TRANSITION_MAP[transition]), ) + return replace(entity, state=transition.new, current_process=new_process, operation_results=updated_results) class States: @@ -66,9 +68,9 @@ class Idle(State): """The default state of an entity.""" @classmethod - def start_pull(cls, entity: Entity) -> EntityOperationResult: + def start_pull(cls, entity: Entity) -> Entity: """Return the command needed to start the pull process for an entity.""" - return cls._create_valid_operation_result(Operations.START_PULL, entity.identifier, Activated) + return cls._transition_entity(entity, Operations.START_PULL, Activated, new_process=Processes.PULL) states.register(Idle) @@ -78,16 +80,16 @@ class Activated(State): """The state of an activated entity.""" @classmethod - def process(cls, entity: Entity) -> EntityOperationResult: + def process(cls, entity: Entity) -> Entity: """Return the commands needed to process an activated entity.""" - new_state: type[State] + transition_entity = partial(cls._transition_entity, entity, Operations.PROCESS) if entity.is_tainted: - new_state = Deprecated + return transition_entity(Deprecated, new_process=Processes.NONE) elif entity.current_process is Processes.PULL: - new_state = Received + return transition_entity(Received) elif entity.current_process is Processes.DELETE: - new_state = Idle - return cls._create_valid_operation_result(Operations.PROCESS, entity.identifier, new_state) + return transition_entity(Idle, new_process=Processes.NONE) + raise RuntimeError states.register(Activated) @@ -97,17 +99,17 @@ class Received(State): """The state of an received entity.""" @classmethod - def process(cls, entity: Entity) -> EntityOperationResult: + def process(cls, entity: Entity) -> Entity: """Return the commands needed to process a received entity.""" - new_state: type[State] + transition_entity = partial(cls._transition_entity, entity, Operations.PROCESS) if entity.current_process is Processes.PULL: if entity.is_tainted: - new_state = Tainted + return transition_entity(Tainted, new_process=Processes.NONE) else: - new_state = Pulled + return transition_entity(Pulled, new_process=Processes.NONE) elif entity.current_process is Processes.DELETE: - new_state = Activated - return cls._create_valid_operation_result(Operations.PROCESS, entity.identifier, new_state) + return transition_entity(Activated) + raise RuntimeError states.register(Received) @@ -117,9 +119,9 @@ class Pulled(State): """The state of an entity that has been copied to the local side.""" @classmethod - def start_delete(cls, entity: Entity) -> EntityOperationResult: + def start_delete(cls, entity: Entity) -> Entity: """Return the commands needed to start the delete process for the entity.""" - return cls._create_valid_operation_result(Operations.START_DELETE, entity.identifier, Received) + return cls._transition_entity(entity, Operations.START_DELETE, Received, new_process=Processes.DELETE) states.register(Pulled) @@ -129,9 +131,9 @@ class Tainted(State): """The state of an entity that has been flagged as faulty by the source side.""" @classmethod - def start_delete(cls, entity: Entity) -> EntityOperationResult: + def start_delete(cls, entity: Entity) -> Entity: """Return the commands needed to start the delete process for the entity.""" - return cls._create_valid_operation_result(Operations.START_DELETE, entity.identifier, Received) + return cls._transition_entity(entity, Operations.START_DELETE, Received, new_process=Processes.DELETE) states.register(Tainted) @@ -214,8 +216,9 @@ class InvalidOperation: class Processes(Enum): """Names for processes that pull/delete entities into/from the local side.""" - PULL = 1 - DELETE = 2 + NONE = auto() + PULL = auto() + DELETE = auto() class Components(Enum): @@ -285,17 +288,27 @@ class Entity: identifier: Identifier state: type[State] - current_process: Optional[Processes] + current_process: Processes is_tainted: bool - - def start_pull(self) -> EntityOperationResult: + operation_results: tuple[EntityOperationResult, ...] + + def apply(self, operation: Operations) -> Entity: + """Apply an operation to the entity.""" + if operation is Operations.START_PULL: + return self._start_pull() + if operation is Operations.START_DELETE: + return self._start_delete() + if operation is Operations.PROCESS: + return self._process() + + def _start_pull(self) -> Entity: """Start the pull process for the entity.""" return self.state.start_pull(self) - def start_delete(self) -> EntityOperationResult: + def _start_delete(self) -> Entity: """Start the delete process for the entity.""" return self.state.start_delete(self) - def process(self) -> EntityOperationResult: + def _process(self) -> Entity: """Process the entity.""" return self.state.process(self) diff --git a/link/infrastructure/link.py b/link/infrastructure/link.py index 9c509bfd..53b727a1 100644 --- a/link/infrastructure/link.py +++ b/link/infrastructure/link.py @@ -24,6 +24,7 @@ start_delete_process, start_pull_process, ) +from link.service.uow import UnitOfWork from . import DJConfiguration, create_tables from .facade import DJLinkFacade @@ -54,17 +55,16 @@ def inner(obj: type) -> Any: ) facade = DJLinkFacade(tables.source, tables.outbound, tables.local) gateway = DJLinkGateway(facade, translator) + uow = UnitOfWork(gateway) source_restriction: IterationCallbackList[PrimaryKey] = IterationCallbackList() idle_entities_updater = create_idle_entities_updater(translator, create_content_replacer(source_restriction)) operation_presenter = create_operation_response_presenter(translator, create_operation_logger()) - process_service = partial( - make_responsive(partial(process, link_gateway=gateway)), output_port=operation_presenter - ) + process_service = partial(make_responsive(partial(process, uow=uow)), output_port=operation_presenter) start_pull_process_service = partial( - make_responsive(partial(start_pull_process, link_gateway=gateway)), output_port=operation_presenter + make_responsive(partial(start_pull_process, uow=uow)), output_port=operation_presenter ) start_delete_process_service = partial( - make_responsive(partial(start_delete_process, link_gateway=gateway)), output_port=operation_presenter + make_responsive(partial(start_delete_process, uow=uow)), output_port=operation_presenter ) process_to_completion_service = partial( make_responsive(partial(process_to_completion, process_service=process_service)), output_port=lambda x: None @@ -82,10 +82,8 @@ def inner(obj: type) -> Any: start_delete_process_service=start_delete_process_service, output_port=lambda x: None, ), - Services.PROCESS: partial(process, link_gateway=gateway, output_port=operation_presenter), - Services.LIST_IDLE_ENTITIES: partial( - list_idle_entities, link_gateway=gateway, output_port=idle_entities_updater - ), + Services.PROCESS: partial(process, uow=uow, output_port=operation_presenter), + Services.LIST_IDLE_ENTITIES: partial(list_idle_entities, uow=uow, output_port=idle_entities_updater), } controller = DJController(handlers, translator) source_restriction.callback = controller.list_idle_entities diff --git a/link/service/services.py b/link/service/services.py index 4dddd1e2..a3237f6c 100644 --- a/link/service/services.py +++ b/link/service/services.py @@ -6,11 +6,9 @@ from enum import Enum, auto from link.domain.custom_types import Identifier -from link.domain.link import process as process_domain_service -from link.domain.link import start_delete, start_pull from link.domain.state import InvalidOperation, Operations, Update, states -from .gateway import LinkGateway +from .uow import UnitOfWork class Request: @@ -125,12 +123,13 @@ class StartPullProcessRequest(Request): def start_pull_process( request: StartPullProcessRequest, *, - link_gateway: LinkGateway, + uow: UnitOfWork, output_port: Callable[[OperationResponse], None], ) -> None: """Start the pull process for the requested entities.""" - result = start_pull(link_gateway.create_link(), requested=request.requested) - link_gateway.apply(result.updates) + with uow: + result = uow.link.apply(Operations.START_PULL, requested=request.requested).operation_results[0] + uow.commit() output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) @@ -144,12 +143,13 @@ class StartDeleteProcessRequest(Request): def start_delete_process( request: StartDeleteProcessRequest, *, - link_gateway: LinkGateway, + uow: UnitOfWork, output_port: Callable[[OperationResponse], None], ) -> None: """Start the delete process for the requested entities.""" - result = start_delete(link_gateway.create_link(), requested=request.requested) - link_gateway.apply(result.updates) + with uow: + result = uow.link.apply(Operations.START_DELETE, requested=request.requested).operation_results[0] + uow.commit() output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) @@ -160,12 +160,11 @@ class ProcessRequest(Request): requested: frozenset[Identifier] -def process( - request: ProcessRequest, *, link_gateway: LinkGateway, output_port: Callable[[OperationResponse], None] -) -> None: +def process(request: ProcessRequest, *, uow: UnitOfWork, output_port: Callable[[OperationResponse], None]) -> None: """Process entities.""" - result = process_domain_service(link_gateway.create_link(), requested=request.requested) - link_gateway.apply(result.updates) + with uow: + result = uow.link.apply(Operations.PROCESS, requested=request.requested).operation_results[0] + uow.commit() output_port(OperationResponse(result.operation, request.requested, result.updates, result.errors)) @@ -184,15 +183,14 @@ class ListIdleEntitiesResponse(Response): def list_idle_entities( request: ListIdleEntitiesRequest, *, - link_gateway: LinkGateway, + uow: UnitOfWork, output_port: Callable[[ListIdleEntitiesResponse], None], ) -> None: """List all idle entities.""" - output_port( - ListIdleEntitiesResponse( - frozenset(entity.identifier for entity in link_gateway.create_link() if entity.state is states.Idle) + with uow: + output_port( + ListIdleEntitiesResponse(frozenset(entity.identifier for entity in uow.link if entity.state is states.Idle)) ) - ) class Services(Enum): diff --git a/link/service/uow.py b/link/service/uow.py new file mode 100644 index 00000000..34bbf0ef --- /dev/null +++ b/link/service/uow.py @@ -0,0 +1,120 @@ +"""Contains the unit of work for links.""" +from __future__ import annotations + +from abc import ABC +from collections import defaultdict, deque +from types import TracebackType +from typing import Callable, Iterable, Protocol + +from link.domain.custom_types import Identifier +from link.domain.link import Link +from link.domain.state import TRANSITION_MAP, Entity, Operations, Transition, Update + +from .gateway import LinkGateway + + +class SupportsLinkApply(Protocol): + """Protocol for an object that supports applying operations to links.""" + + def __call__(self, operation: Operations, *, requested: Iterable[Identifier]) -> Link: + """Apply the operation to the link.""" + + +class UnitOfWork(ABC): + """Controls if and when updates to entities of a link are persisted.""" + + def __init__(self, gateway: LinkGateway) -> None: + """Initialize the unit of work.""" + self._gateway = gateway + self._link: Link | None = None + self._updates: dict[Identifier, deque[Update]] = defaultdict(deque) + + def __enter__(self) -> UnitOfWork: + """Enter the context in which updates to entities can be made.""" + + def augment_link(link: Link) -> None: + original = getattr(link, "apply") + augmented = augment_link_apply(link, original) + object.__setattr__(link, "apply", augmented) + object.__setattr__(link, "_is_expired", False) + + def augment_link_apply(current: Link, original: SupportsLinkApply) -> SupportsLinkApply: + def augmented(operation: Operations, *, requested: Iterable[Identifier]) -> Link: + assert hasattr(current, "_is_expired") + if current._is_expired: + raise RuntimeError("Can not apply operation to expired link") + self._link = original(operation, requested=requested) + augment_link(self._link) + object.__setattr__(current, "_is_expired", True) + return self._link + + return augmented + + def augment_entity(entity: Entity) -> None: + original = getattr(entity, "apply") + augmented = augment_entity_apply(entity, original) + object.__setattr__(entity, "apply", augmented) + object.__setattr__(entity, "_is_expired", False) + + def augment_entity_apply( + current: Entity, original: Callable[[Operations], Entity] + ) -> Callable[[Operations], Entity]: + def augmented(operation: Operations) -> Entity: + assert hasattr(current, "_is_expired") + if current._is_expired is True: + raise RuntimeError("Can not apply operation to expired entity") + new = original(operation) + store_update(operation, current, new) + augment_entity(new) + object.__setattr__(current, "_is_expired", True) + return new + + return augmented + + def store_update(operation: Operations, current: Entity, new: Entity) -> None: + assert current.identifier == new.identifier + if current.state is new.state: + return + transition = Transition(current.state, new.state) + self._updates[current.identifier].append( + Update(operation, current.identifier, transition, TRANSITION_MAP[transition]) + ) + + self._link = self._gateway.create_link() + augment_link(self._link) + for entity in self._link: + augment_entity(entity) + return self + + def __exit__( + self, exc_type: type[BaseException] | None, exc: BaseException | None, traceback: TracebackType | None + ) -> None: + """Exit the context rolling back any not yet persisted updates.""" + self.rollback() + self._link = None + + @property + def link(self) -> Link: + """Return the link object that is governed by this unit of work.""" + if self._link is None: + raise RuntimeError("Not available outside of context") + return self._link + + def commit(self) -> None: + """Persist updates made to the link.""" + if self._link is None: + raise RuntimeError("Not available outside of context") + while self._updates: + identifier, updates = self._updates.popitem() + while updates: + self._gateway.apply([updates.popleft()]) + self.rollback() + + def rollback(self) -> None: + """Throw away any not yet persisted updates.""" + if self._link is None: + raise RuntimeError("Not available outside of context") + object.__setattr__(self._link, "_is_expired", True) + for entity in self._link: + object.__setattr__(entity, "_is_expired", True) + self._updates.clear() diff --git a/tests/integration/gateway.py b/tests/integration/gateway.py new file mode 100644 index 00000000..e8c1eec9 --- /dev/null +++ b/tests/integration/gateway.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Iterable + +from link.domain.custom_types import Identifier +from link.domain.link import Link, create_link +from link.domain.state import Commands, Components, Processes, Update +from link.service.gateway import LinkGateway + + +class FakeLinkGateway(LinkGateway): + def __init__( + self, + assignments: Mapping[Components, Iterable[Identifier]], + *, + tainted_identifiers: Iterable[Identifier] | None = None, + processes: Mapping[Processes, Iterable[Identifier]] | None = None, + ) -> None: + self.assignments = {component: set(identifiers) for component, identifiers in assignments.items()} + self.tainted_identifiers = set(tainted_identifiers) if tainted_identifiers is not None else set() + self.processes: dict[Processes, set[Identifier]] = {process: set() for process in Processes} + if processes is not None: + for entity_process, identifiers in processes.items(): + self.processes[entity_process].update(identifiers) + + def create_link(self) -> Link: + return create_link(self.assignments, tainted_identifiers=self.tainted_identifiers, processes=self.processes) + + def apply(self, updates: Iterable[Update]) -> None: + for update in updates: + if update.command is Commands.START_PULL_PROCESS: + self.processes[Processes.PULL].add(update.identifier) + self.assignments[Components.OUTBOUND].add(update.identifier) + elif update.command is Commands.ADD_TO_LOCAL: + self.assignments[Components.LOCAL].add(update.identifier) + elif update.command is Commands.FINISH_PULL_PROCESS: + self.processes[Processes.PULL].remove(update.identifier) + elif update.command is Commands.START_DELETE_PROCESS: + self.processes[Processes.DELETE].add(update.identifier) + elif update.command is Commands.REMOVE_FROM_LOCAL: + self.assignments[Components.LOCAL].remove(update.identifier) + elif update.command is Commands.FINISH_DELETE_PROCESS: + self.processes[Processes.DELETE].remove(update.identifier) + self.assignments[Components.OUTBOUND].remove(update.identifier) + elif update.command is Commands.DEPRECATE: + try: + self.processes[Processes.DELETE].remove(update.identifier) + except KeyError: + self.processes[Processes.PULL].remove(update.identifier) + else: + raise ValueError("Unsupported command encountered") diff --git a/tests/integration/test_datajoint_persistence.py b/tests/integration/test_datajoint_persistence.py index 34725e48..0de89752 100644 --- a/tests/integration/test_datajoint_persistence.py +++ b/tests/integration/test_datajoint_persistence.py @@ -17,8 +17,8 @@ from link.adapters import PrimaryKey from link.adapters.gateway import DJLinkGateway from link.adapters.identification import IdentificationTranslator -from link.domain.link import create_link, process, start_delete, start_pull -from link.domain.state import Components, Processes +from link.domain.link import create_link +from link.domain.state import Components, Operations, Processes from link.infrastructure.facade import DJLinkFacade, Table @@ -370,7 +370,12 @@ def test_add_to_local_command() -> None: ), ) - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates + ) assert has_state( tables, @@ -408,7 +413,12 @@ def test_add_to_local_command_with_error() -> None: tables["local"].children(as_objects=True)[0].error_on_insert = RuntimeError try: - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates + ) except RuntimeError: pass @@ -425,7 +435,12 @@ def test_add_to_local_command_with_external_file(tmpdir: Path) -> None: tables["source"].insert([{"a": 0, "external": insert_filepath}]) os.remove(insert_filepath) tables["outbound"].insert([{"a": 0, "process": "PULL", "is_flagged": "FALSE", "is_deprecated": "FALSE"}]) - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates + ) fetch_filepath = Path(tables["local"].fetch(as_dict=True, download_path=str(tmpdir))[0]["external"]) with fetch_filepath.open(mode="rb") as file: assert file.read() == data @@ -444,7 +459,12 @@ def test_remove_from_local_command() -> None: ) with as_stdin(StringIO("y")): - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates + ) assert has_state( tables, @@ -460,7 +480,12 @@ def test_start_pull_process() -> None: "link", primary={"a"}, non_primary={"b"}, initial=State(source=TableState([{"a": 0, "b": 1}])) ) - gateway.apply(start_pull(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.START_PULL, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates + ) assert has_state( tables, @@ -485,7 +510,12 @@ def initial_state() -> State: def test_state_after_command(initial_state: State) -> None: tables, gateway = initialize("link", primary={"a"}, non_primary={"b"}, initial=initial_state) - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates + ) assert has_state( tables, @@ -503,7 +533,10 @@ def test_rollback_on_error(initial_state: State) -> None: tables["outbound"].error_on_insert = RuntimeError try: gateway.apply( - process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates ) except RuntimeError: pass @@ -526,7 +559,10 @@ def test_state_after_command(initial_state: State) -> None: tables, gateway = initialize("link", primary={"a"}, non_primary={"b"}, initial=initial_state) gateway.apply( - start_delete(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.START_DELETE, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates ) assert has_state( @@ -545,7 +581,10 @@ def test_rollback_on_error(initial_state: State) -> None: tables["outbound"].error_on_insert = RuntimeError try: gateway.apply( - start_delete(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.START_DELETE, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates ) except RuntimeError: pass @@ -564,7 +603,12 @@ def test_finish_delete_process_command() -> None: ), ) - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates + ) assert has_state(tables, State(source=TableState([{"a": 0, "b": 1}]))) @@ -582,7 +626,12 @@ def initial_state() -> State: def test_state_after_command(initial_state: State) -> None: tables, gateway = initialize("link", primary={"a"}, non_primary={"b"}, initial=initial_state) - gateway.apply(process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates) + gateway.apply( + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates + ) assert has_state( tables, @@ -599,7 +648,10 @@ def test_rollback_on_error(initial_state: State) -> None: tables["outbound"].error_on_insert = RuntimeError try: gateway.apply( - process(gateway.create_link(), requested={gateway.translator.to_identifier({"a": 0})}).updates + gateway.create_link() + .apply(Operations.PROCESS, requested={gateway.translator.to_identifier({"a": 0})}) + .operation_results[0] + .updates ) except RuntimeError: pass @@ -626,7 +678,10 @@ def test_applying_multiple_commands() -> None: with as_stdin(StringIO("y")): gateway.apply( - process(gateway.create_link(), requested=gateway.translator.to_identifiers([{"a": 0}, {"a": 1}])).updates + gateway.create_link() + .apply(Operations.PROCESS, requested=gateway.translator.to_identifiers([{"a": 0}, {"a": 1}])) + .operation_results[0] + .updates ) assert has_state( diff --git a/tests/integration/test_services.py b/tests/integration/test_services.py index 12766b6d..444eab47 100644 --- a/tests/integration/test_services.py +++ b/tests/integration/test_services.py @@ -1,15 +1,12 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Mapping +from collections.abc import Callable from functools import partial from typing import Generic, TypedDict, TypeVar import pytest -from link.domain.custom_types import Identifier -from link.domain.link import Link, create_link -from link.domain.state import Commands, Components, InvalidOperation, Operations, Processes, State, Update, states -from link.service.gateway import LinkGateway +from link.domain.state import Components, InvalidOperation, Operations, Processes, State, states from link.service.io import Service, make_responsive from link.service.services import ( DeleteRequest, @@ -30,51 +27,10 @@ start_delete_process, start_pull_process, ) +from link.service.uow import UnitOfWork from tests.assignments import create_assignments, create_identifier, create_identifiers - -class FakeLinkGateway(LinkGateway): - def __init__( - self, - assignments: Mapping[Components, Iterable[Identifier]], - *, - tainted_identifiers: Iterable[Identifier] | None = None, - processes: Mapping[Processes, Iterable[Identifier]] | None = None, - ) -> None: - self.assignments = {component: set(identifiers) for component, identifiers in assignments.items()} - self.tainted_identifiers = set(tainted_identifiers) if tainted_identifiers is not None else set() - self.processes: dict[Processes, set[Identifier]] = {process: set() for process in Processes} - if processes is not None: - for entity_process, identifiers in processes.items(): - self.processes[entity_process].update(identifiers) - - def create_link(self) -> Link: - return create_link(self.assignments, tainted_identifiers=self.tainted_identifiers, processes=self.processes) - - def apply(self, updates: Iterable[Update]) -> None: - for update in updates: - if update.command is Commands.START_PULL_PROCESS: - self.processes[Processes.PULL].add(update.identifier) - self.assignments[Components.OUTBOUND].add(update.identifier) - elif update.command is Commands.ADD_TO_LOCAL: - self.assignments[Components.LOCAL].add(update.identifier) - elif update.command is Commands.FINISH_PULL_PROCESS: - self.processes[Processes.PULL].remove(update.identifier) - elif update.command is Commands.START_DELETE_PROCESS: - self.processes[Processes.DELETE].add(update.identifier) - elif update.command is Commands.REMOVE_FROM_LOCAL: - self.assignments[Components.LOCAL].remove(update.identifier) - elif update.command is Commands.FINISH_DELETE_PROCESS: - self.processes[Processes.DELETE].remove(update.identifier) - self.assignments[Components.OUTBOUND].remove(update.identifier) - elif update.command is Commands.DEPRECATE: - try: - self.processes[Processes.DELETE].remove(update.identifier) - except KeyError: - self.processes[Processes.PULL].remove(update.identifier) - else: - raise ValueError("Unsupported command encountered") - +from .gateway import FakeLinkGateway T = TypeVar("T", bound=Response) @@ -92,7 +48,7 @@ def __call__(self, response: T) -> None: self._response = response -def create_gateway(state: type[State], process: Processes | None = None, is_tainted: bool = False) -> FakeLinkGateway: +def create_uow(state: type[State], process: Processes | None = None, is_tainted: bool = False) -> UnitOfWork: if state in (states.Activated, states.Received): assert process is not None else: @@ -112,22 +68,26 @@ def create_gateway(state: type[State], process: Processes | None = None, is_tain processes = {} assignments = {Components.SOURCE: {"1"}} if state is states.Idle: - return FakeLinkGateway( - create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes + return UnitOfWork( + FakeLinkGateway( + create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes + ) ) assignments[Components.OUTBOUND] = {"1"} if state in (states.Deprecated, states.Activated): - return FakeLinkGateway( - create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes + return UnitOfWork( + FakeLinkGateway( + create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes + ) ) assignments[Components.LOCAL] = {"1"} - return FakeLinkGateway( - create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes + return UnitOfWork( + FakeLinkGateway(create_assignments(assignments), tainted_identifiers=tainted_identifiers, processes=processes) ) -def create_process_to_completion_service(gateway: FakeLinkGateway) -> Callable[[ProcessToCompletionRequest], None]: - process_service = partial(make_responsive(partial(process, link_gateway=gateway)), output_port=lambda x: None) +def create_process_to_completion_service(uow: UnitOfWork) -> Callable[[ProcessToCompletionRequest], None]: + process_service = partial(make_responsive(partial(process, uow=uow)), output_port=lambda x: None) return partial( make_responsive( partial( @@ -139,10 +99,10 @@ def create_process_to_completion_service(gateway: FakeLinkGateway) -> Callable[[ ) -def create_pull_service(gateway: FakeLinkGateway) -> Service[PullRequest, PullResponse]: - process_to_completion_service = create_process_to_completion_service(gateway) +def create_pull_service(uow: UnitOfWork) -> Service[PullRequest, PullResponse]: + process_to_completion_service = create_process_to_completion_service(uow) start_pull_process_service = partial( - make_responsive(partial(start_pull_process, link_gateway=gateway)), output_port=lambda x: None + make_responsive(partial(start_pull_process, uow=uow)), output_port=lambda x: None ) return partial( pull, @@ -151,10 +111,10 @@ def create_pull_service(gateway: FakeLinkGateway) -> Service[PullRequest, PullRe ) -def create_delete_service(gateway: FakeLinkGateway) -> Service[DeleteRequest, DeleteResponse]: - process_to_completion_service = create_process_to_completion_service(gateway) +def create_delete_service(uow: UnitOfWork) -> Service[DeleteRequest, DeleteResponse]: + process_to_completion_service = create_process_to_completion_service(uow) start_delete_process_service = partial( - make_responsive(partial(start_delete_process, link_gateway=gateway)), output_port=lambda x: None + make_responsive(partial(start_delete_process, uow=uow)), output_port=lambda x: None ) return partial( delete, @@ -203,18 +163,21 @@ class EntityConfig(TypedDict): ], ) def test_deleted_entity_ends_in_correct_state(state: EntityConfig, expected: type[State]) -> None: - gateway = create_gateway(**state) - delete_service = create_delete_service(gateway) + uow = create_uow(**state) + delete_service = create_delete_service(uow) delete_service(DeleteRequest(frozenset(create_identifiers("1"))), output_port=lambda x: None) - assert next(iter(gateway.create_link())).state is expected + with uow: + assert next(iter(uow.link)).state is expected def test_correct_response_model_gets_passed_to_delete_output_port() -> None: - gateway = FakeLinkGateway( - create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}) + uow = UnitOfWork( + FakeLinkGateway( + create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}) + ) ) output_port = FakeOutputPort[DeleteResponse]() - delete_service = create_delete_service(gateway) + delete_service = create_delete_service(uow) delete_service(DeleteRequest(frozenset(create_identifiers("1"))), output_port=output_port) assert output_port.response.requested == create_identifiers("1") @@ -237,13 +200,14 @@ def test_correct_response_model_gets_passed_to_delete_output_port() -> None: ], ) def test_pulled_entity_ends_in_correct_state(state: EntityConfig, expected: type[State]) -> None: - gateway = create_gateway(**state) - pull_service = create_pull_service(gateway) + uow = create_uow(**state) + pull_service = create_pull_service(uow) pull_service( PullRequest(frozenset(create_identifiers("1"))), output_port=lambda x: None, ) - assert next(iter(gateway.create_link())).state is expected + with uow: + assert next(iter(uow.link)).state is expected @pytest.mark.parametrize( @@ -272,7 +236,7 @@ def test_correct_response_model_gets_passed_to_pull_output_port(state: EntityCon } else: errors = set() - gateway = create_gateway(**state) + gateway = create_uow(**state) output_port = FakeOutputPort[PullResponse]() pull_service = create_pull_service(gateway) pull_service( @@ -283,28 +247,33 @@ def test_correct_response_model_gets_passed_to_pull_output_port(state: EntityCon def test_entity_undergoing_process_gets_processed() -> None: - gateway = FakeLinkGateway( - create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}}), - processes={Processes.PULL: create_identifiers("1")}, + uow = UnitOfWork( + FakeLinkGateway( + create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}}), + processes={Processes.PULL: create_identifiers("1")}, + ) ) process( ProcessRequest(frozenset(create_identifiers("1"))), - link_gateway=gateway, + uow=uow, output_port=FakeOutputPort[OperationResponse](), ) - entity = next(entity for entity in gateway.create_link() if entity.identifier == create_identifier("1")) - assert entity.state is states.Received + with uow: + entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1")) + assert entity.state is states.Received def test_correct_response_model_gets_passed_to_process_output_port() -> None: - gateway = FakeLinkGateway( - create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}}), - processes={Processes.PULL: create_identifiers("1")}, + uow = UnitOfWork( + FakeLinkGateway( + create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}}), + processes={Processes.PULL: create_identifiers("1")}, + ) ) output_port = FakeOutputPort[OperationResponse]() process( ProcessRequest(frozenset(create_identifiers("1"))), - link_gateway=gateway, + uow=uow, output_port=output_port, ) assert output_port.response.requested == create_identifiers("1") @@ -312,9 +281,11 @@ def test_correct_response_model_gets_passed_to_process_output_port() -> None: def test_correct_response_model_gets_passed_to_list_idle_entities_output_port() -> None: - link_gateway = FakeLinkGateway( - create_assignments({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) + uow = UnitOfWork( + FakeLinkGateway( + create_assignments({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) + ) ) output_port = FakeOutputPort[ListIdleEntitiesResponse]() - list_idle_entities(ListIdleEntitiesRequest(), link_gateway=link_gateway, output_port=output_port) + list_idle_entities(ListIdleEntitiesRequest(), uow=uow, output_port=output_port) assert set(output_port.response.identifiers) == create_identifiers("1") diff --git a/tests/integration/test_uow.py b/tests/integration/test_uow.py new file mode 100644 index 00000000..651e8e44 --- /dev/null +++ b/tests/integration/test_uow.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from typing import Iterable, Mapping + +import pytest + +from link.domain.state import Components, Operations, states +from link.service.uow import UnitOfWork +from tests.assignments import create_assignments, create_identifier, create_identifiers + +from .gateway import FakeLinkGateway + + +def initialize(assignments: Mapping[Components, Iterable[str]]) -> tuple[FakeLinkGateway, UnitOfWork]: + gateway = FakeLinkGateway(create_assignments(assignments)) + return gateway, UnitOfWork(gateway) + + +def test_updates_are_applied_to_gateway_on_commit() -> None: + gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) + with uow: + uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + uow.link.apply(Operations.START_DELETE, requested=create_identifiers("2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.commit() + actual = {(entity.identifier, entity.state) for entity in gateway.create_link()} + expected = {(create_identifier("1"), states.Pulled), (create_identifier("2"), states.Idle)} + assert actual == expected + + +def test_updates_are_discarded_on_context_exit() -> None: + gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) + with uow: + uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + uow.link.apply(Operations.START_DELETE, requested=create_identifiers("2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + actual = {(entity.identifier, entity.state) for entity in gateway.create_link()} + expected = {(create_identifier("1"), states.Idle), (create_identifier("2"), states.Pulled)} + assert actual == expected + + +def test_updates_are_discarded_on_rollback() -> None: + gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}}) + with uow: + uow.link.apply(Operations.START_PULL, requested=create_identifiers("1")) + uow.link.apply(Operations.START_DELETE, requested=create_identifiers("2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2")) + uow.rollback() + actual = {(entity.identifier, entity.state) for entity in gateway.create_link()} + expected = {(create_identifier("1"), states.Idle), (create_identifier("2"), states.Pulled)} + assert actual == expected + + +def test_link_can_not_be_accessed_outside_of_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + pass + with pytest.raises(RuntimeError, match="outside"): + uow.link + + +def test_unable_to_commit_outside_of_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with pytest.raises(RuntimeError, match="outside"): + uow.commit() + + +def test_unable_to_rollback_outside_of_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with pytest.raises(RuntimeError, match="outside"): + uow.rollback() + + +def test_entity_expires_when_committing() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1")) + uow.commit() + with pytest.raises(RuntimeError, match="expired entity"): + entity.apply(Operations.START_PULL) + + +def test_entity_expires_when_rolling_back() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1")) + uow.rollback() + with pytest.raises(RuntimeError, match="expired entity"): + entity.apply(Operations.START_PULL) + + +def test_entity_expires_when_leaving_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1")) + with pytest.raises(RuntimeError, match="expired entity"): + entity.apply(Operations.START_PULL) + + +def test_entity_expires_when_applying_operation() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1")) + entity.apply(Operations.START_PULL) + with pytest.raises(RuntimeError, match="expired entity"): + entity.apply(Operations.PROCESS) + + +def test_link_expires_when_committing() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + link = uow.link + uow.commit() + with pytest.raises(RuntimeError, match="expired link"): + link.apply(Operations.START_PULL, requested=create_identifiers("1")) + + +def test_link_expires_when_rolling_back() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + link = uow.link + uow.rollback() + with pytest.raises(RuntimeError, match="expired link"): + link.apply(Operations.START_PULL, requested=create_identifiers("1")) + + +def test_link_expires_when_exiting_context() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + link = uow.link + with pytest.raises(RuntimeError, match="expired link"): + link.apply(Operations.START_PULL, requested=create_identifiers("1")) + + +def test_link_expires_when_applying_operation() -> None: + _, uow = initialize({Components.SOURCE: {"1"}}) + with uow: + link = uow.link + link.apply(Operations.START_PULL, requested=create_identifiers("1")) + with pytest.raises(RuntimeError, match="expired link"): + link.apply(Operations.PROCESS, requested=create_identifiers("1")) diff --git a/tests/unit/entities/test_link.py b/tests/unit/entities/test_link.py index 086d185a..9df845a3 100644 --- a/tests/unit/entities/test_link.py +++ b/tests/unit/entities/test_link.py @@ -6,8 +6,8 @@ import pytest from link.domain.custom_types import Identifier -from link.domain.link import Link, create_link, process, start_delete, start_pull -from link.domain.state import Components, Processes, State, states +from link.domain.link import Link, create_link +from link.domain.state import Components, Operations, Processes, State, states from tests.assignments import create_assignments, create_identifier, create_identifiers @@ -79,7 +79,7 @@ def test_entities_get_correct_process_assigned() -> None: (create_identifier("2"), Processes.DELETE), (create_identifier("3"), Processes.PULL), (create_identifier("4"), Processes.DELETE), - (create_identifier("5"), None), + (create_identifier("5"), Processes.NONE), } assert {(entity.identifier, entity.current_process) for entity in link} == set(expected) @@ -167,7 +167,7 @@ def test_can_get_identifiers_of_entities_in_component( assert set(link.identifiers) == create_identifiers("1", "2") -def test_process_produces_correct_updates() -> None: +def test_link_is_processed_correctly() -> None: link = create_link( create_assignments( { @@ -182,14 +182,15 @@ def test_process_produces_correct_updates() -> None: }, ) actual = { - (update.identifier, update.transition.new) - for update in process(link, requested=create_identifiers("1", "2", "3", "4")).updates + (entity.identifier, entity.state) + for entity in link.apply(Operations.PROCESS, requested=create_identifiers("1", "2", "3", "4")) } expected = { (create_identifier("1"), states.Received), (create_identifier("2"), states.Pulled), (create_identifier("3"), states.Idle), (create_identifier("4"), states.Activated), + (create_identifier("5"), states.Received), } assert actual == expected @@ -202,20 +203,20 @@ def link() -> Link: @staticmethod def test_idle_entity_becomes_activated(link: Link) -> None: - result = start_pull(link, requested=create_identifiers("1")) - update = next(iter(result.updates)) - assert update.identifier == create_identifier("1") - assert update.transition.new is states.Activated + link = link.apply(Operations.START_PULL, requested=create_identifiers("1")) + entity = next(iter(link)) + assert entity.identifier == create_identifier("1") + assert entity.state is states.Activated @staticmethod def test_not_specifying_requested_identifiers_raises_error(link: Link) -> None: with pytest.raises(AssertionError, match="No identifiers requested."): - start_pull(link, requested={}) + link.apply(Operations.START_PULL, requested={}) @staticmethod def test_specifying_identifiers_not_present_in_link_raises_error(link: Link) -> None: with pytest.raises(AssertionError, match="Requested identifiers not present in link."): - start_pull(link, requested=create_identifiers("2")) + link.apply(Operations.START_PULL, requested=create_identifiers("2")) @pytest.fixture() @@ -228,17 +229,17 @@ def link() -> Link: class TestStartDelete: @staticmethod def test_pulled_entity_becomes_received(link: Link) -> None: - result = start_delete(link, requested=create_identifiers("1")) - update = next(iter(result.updates)) - assert {update.identifier} == create_identifiers("1") - assert update.transition.new is states.Received + link = link.apply(Operations.START_DELETE, requested=create_identifiers("1")) + entity = next(iter(link)) + assert entity.identifier == create_identifier("1") + assert entity.state is states.Received @staticmethod def test_not_specifying_requested_identifiers_raises_error(link: Link) -> None: with pytest.raises(AssertionError, match="No identifiers requested."): - start_delete(link, requested={}) + link.apply(Operations.START_DELETE, requested={}) @staticmethod def test_specifying_identifiers_not_present_in_link_raises_error(link: Link) -> None: with pytest.raises(AssertionError, match="Requested identifiers not present in link."): - start_delete(link, requested=create_identifiers("2")) + link.apply(Operations.START_DELETE, requested=create_identifiers("2")) diff --git a/tests/unit/entities/test_state.py b/tests/unit/entities/test_state.py index dd59505b..c74e7c79 100644 --- a/tests/unit/entities/test_state.py +++ b/tests/unit/entities/test_state.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import replace from typing import Iterable import pytest @@ -21,17 +22,23 @@ @pytest.mark.parametrize( - ("identifier", "state", "methods"), + ("identifier", "state", "operations"), [ - (create_identifier("1"), states.Idle, ["start_delete", "process"]), - (create_identifier("2"), states.Activated, ["start_pull", "start_delete"]), - (create_identifier("3"), states.Received, ["start_pull", "start_delete"]), - (create_identifier("4"), states.Pulled, ["start_pull", "process"]), - (create_identifier("5"), states.Tainted, ["start_pull", "process"]), - (create_identifier("6"), states.Deprecated, ["start_pull", "start_delete", "process"]), + (create_identifier("1"), states.Idle, [Operations.START_DELETE, Operations.PROCESS]), + (create_identifier("2"), states.Activated, [Operations.START_PULL, Operations.START_DELETE]), + (create_identifier("3"), states.Received, [Operations.START_PULL, Operations.START_DELETE]), + (create_identifier("4"), states.Pulled, [Operations.START_PULL, Operations.PROCESS]), + (create_identifier("5"), states.Tainted, [Operations.START_PULL, Operations.PROCESS]), + ( + create_identifier("6"), + states.Deprecated, + [Operations.START_PULL, Operations.START_DELETE, Operations.PROCESS], + ), ], ) -def test_invalid_transitions_produce_no_updates(identifier: Identifier, state: type[State], methods: str) -> None: +def test_invalid_transitions_returns_unchanged_entity( + identifier: Identifier, state: type[State], operations: list[Operations] +) -> None: link = create_link( create_assignments( { @@ -44,41 +51,43 @@ def test_invalid_transitions_produce_no_updates(identifier: Identifier, state: t processes={Processes.PULL: create_identifiers("2", "3")}, ) entity = next(entity for entity in link if entity.identifier == identifier) - method_operation_map = { - "start_pull": Operations.START_PULL, - "start_delete": Operations.START_DELETE, - "process": Operations.PROCESS, - } - assert all( - getattr(entity, method)() == InvalidOperation(method_operation_map[method], entity.identifier, state) - for method in methods - ) + for operation in operations: + result = InvalidOperation(operation, identifier, state) + assert entity.apply(operation) == replace(entity, operation_results=(result,)) -def test_start_pulling_idle_entity_returns_correct_commands() -> None: +def test_start_pulling_idle_entity_returns_correct_entity() -> None: link = create_link(create_assignments({Components.SOURCE: {"1"}})) entity = next(iter(link)) - assert entity.start_pull() == Update( - Operations.START_PULL, - create_identifier("1"), - Transition(states.Idle, states.Activated), - command=Commands.START_PULL_PROCESS, + assert entity.apply(Operations.START_PULL) == replace( + entity, + state=states.Activated, + current_process=Processes.PULL, + operation_results=( + Update( + Operations.START_PULL, + entity.identifier, + Transition(states.Idle, states.Activated), + Commands.START_PULL_PROCESS, + ), + ), ) @pytest.mark.parametrize( - ("process", "tainted_identifiers", "new_state", "command"), + ("process", "tainted_identifiers", "new_state", "new_process", "command"), [ - (Processes.PULL, set(), states.Received, Commands.ADD_TO_LOCAL), - (Processes.PULL, create_identifiers("1"), states.Deprecated, Commands.DEPRECATE), - (Processes.DELETE, set(), states.Idle, Commands.FINISH_DELETE_PROCESS), - (Processes.DELETE, create_identifiers("1"), states.Deprecated, Commands.DEPRECATE), + (Processes.PULL, set(), states.Received, Processes.PULL, Commands.ADD_TO_LOCAL), + (Processes.PULL, create_identifiers("1"), states.Deprecated, Processes.NONE, Commands.DEPRECATE), + (Processes.DELETE, set(), states.Idle, Processes.NONE, Commands.FINISH_DELETE_PROCESS), + (Processes.DELETE, create_identifiers("1"), states.Deprecated, Processes.NONE, Commands.DEPRECATE), ], ) -def test_processing_activated_entity_returns_correct_commands( +def test_processing_activated_entity_returns_correct_entity( process: Processes, tainted_identifiers: Iterable[Identifier], new_state: type[State], + new_process: Processes, command: Commands, ) -> None: link = create_link( @@ -87,25 +96,29 @@ def test_processing_activated_entity_returns_correct_commands( tainted_identifiers=tainted_identifiers, ) entity = next(iter(link)) - assert entity.process() == Update( - Operations.PROCESS, - create_identifier("1"), - Transition(states.Activated, new_state), - command=command, + updated_results = entity.operation_results + ( + Update(Operations.PROCESS, entity.identifier, Transition(entity.state, new_state), command), + ) + assert entity.apply(Operations.PROCESS) == replace( + entity, state=new_state, current_process=new_process, operation_results=updated_results ) @pytest.mark.parametrize( - ("process", "tainted_identifiers", "new_state", "command"), + ("process", "tainted_identifiers", "new_state", "new_process", "command"), [ - (Processes.PULL, set(), states.Pulled, Commands.FINISH_PULL_PROCESS), - (Processes.PULL, create_identifiers("1"), states.Tainted, Commands.FINISH_PULL_PROCESS), - (Processes.DELETE, set(), states.Activated, Commands.REMOVE_FROM_LOCAL), - (Processes.DELETE, create_identifiers("1"), states.Activated, Commands.REMOVE_FROM_LOCAL), + (Processes.PULL, set(), states.Pulled, Processes.NONE, Commands.FINISH_PULL_PROCESS), + (Processes.PULL, create_identifiers("1"), states.Tainted, Processes.NONE, Commands.FINISH_PULL_PROCESS), + (Processes.DELETE, set(), states.Activated, Processes.DELETE, Commands.REMOVE_FROM_LOCAL), + (Processes.DELETE, create_identifiers("1"), states.Activated, Processes.DELETE, Commands.REMOVE_FROM_LOCAL), ], ) -def test_processing_received_entity_returns_correct_commands( - process: Processes, tainted_identifiers: Iterable[Identifier], new_state: type[State], command: Commands +def test_processing_received_entity_returns_correct_entity( + process: Processes, + tainted_identifiers: Iterable[Identifier], + new_state: type[State], + new_process: Processes, + command: Commands, ) -> None: link = create_link( create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}), @@ -113,24 +126,28 @@ def test_processing_received_entity_returns_correct_commands( tainted_identifiers=tainted_identifiers, ) entity = next(iter(link)) - assert entity.process() == Update( - Operations.PROCESS, - create_identifier("1"), - Transition(states.Received, new_state), - command=command, + operation_results = (Update(Operations.PROCESS, entity.identifier, Transition(entity.state, new_state), command),) + assert entity.apply(Operations.PROCESS) == replace( + entity, state=new_state, current_process=new_process, operation_results=operation_results ) -def test_starting_delete_on_pulled_entity_returns_correct_commands() -> None: +def test_starting_delete_on_pulled_entity_returns_correct_entity() -> None: link = create_link( create_assignments({Components.SOURCE: {"1"}, Components.OUTBOUND: {"1"}, Components.LOCAL: {"1"}}) ) entity = next(iter(link)) - assert entity.start_delete() == Update( - Operations.START_DELETE, - create_identifier("1"), - Transition(states.Pulled, states.Received), - command=Commands.START_DELETE_PROCESS, + transition = Transition(states.Pulled, states.Received) + operation_results = ( + Update( + Operations.START_DELETE, + entity.identifier, + transition, + Commands.START_DELETE_PROCESS, + ), + ) + assert entity.apply(Operations.START_DELETE) == replace( + entity, state=transition.new, current_process=Processes.DELETE, operation_results=operation_results ) @@ -140,9 +157,8 @@ def test_starting_delete_on_tainted_entity_returns_correct_commands() -> None: tainted_identifiers={create_identifier("1")}, ) entity = next(iter(link)) - assert entity.start_delete() == Update( - Operations.START_DELETE, - create_identifier("1"), - Transition(states.Tainted, states.Received), - command=Commands.START_DELETE_PROCESS, + transition = Transition(states.Tainted, states.Received) + operation_results = (Update(Operations.START_DELETE, entity.identifier, transition, Commands.START_DELETE_PROCESS),) + assert entity.apply(Operations.START_DELETE) == replace( + entity, state=transition.new, current_process=Processes.DELETE, operation_results=operation_results )