Skip to content

Commit

Permalink
Fix some uow bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
christoph-blessing committed Oct 24, 2023
1 parent ce06b1c commit 508f348
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 42 deletions.
68 changes: 49 additions & 19 deletions link/service/uow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC
from collections import defaultdict, deque
from types import TracebackType
from typing import Callable
from typing import Callable, Iterable, Protocol

from link.domain.custom_types import Identifier
from link.domain.link import Link
Expand All @@ -13,6 +13,13 @@
from .gateway import LinkGateway


class SupportsLinkApply(Protocol):
"""Protocol for an object that supports applying operations to links."""

def __call__(self, operation: Operations, *, requested: Iterable[Identifier]) -> Link:
"""Apply the operation to the link."""


class UnitOfWork(ABC):
"""Controls if and when updates to entities of a link are persisted."""

Expand All @@ -21,29 +28,48 @@ def __init__(self, gateway: LinkGateway) -> None:
self._gateway = gateway
self._link: Link | None = None
self._updates: dict[Identifier, deque[Update]] = defaultdict(deque)
self._entities: dict[Identifier, Entity] = {}

def __enter__(self) -> UnitOfWork:
"""Enter the context in which updates to entities can be made."""

def track_entity(entity: Entity) -> None:
apply = getattr(entity, "apply")
augmented = augment_apply(entity, apply)
def augment_link(link: Link) -> None:
original = getattr(link, "apply")
augmented = augment_link_apply(link, original)
object.__setattr__(link, "apply", augmented)
object.__setattr__(link, "_is_expired", False)

def augment_link_apply(current: Link, original: SupportsLinkApply) -> SupportsLinkApply:
def augmented(operation: Operations, *, requested: Iterable[Identifier]) -> Link:
assert hasattr(current, "_is_expired")
if current._is_expired:
raise RuntimeError("Can not apply operation to expired link")
self._link = original(operation, requested=requested)
augment_link(self._link)
object.__setattr__(current, "_is_expired", True)
return self._link

return augmented

def augment_entity(entity: Entity) -> None:
original = getattr(entity, "apply")
augmented = augment_entity_apply(entity, original)
object.__setattr__(entity, "apply", augmented)
object.__setattr__(entity, "_is_expired", False)
self._entities[entity.identifier] = entity

def augment_apply(current: Entity, apply: Callable[[Operations], Entity]) -> Callable[[Operations], Entity]:
def track_and_apply(operation: Operations) -> Entity:
def augment_entity_apply(
current: Entity, original: Callable[[Operations], Entity]
) -> Callable[[Operations], Entity]:
def augmented(operation: Operations) -> Entity:
assert hasattr(current, "_is_expired")
if current._is_expired is True:
raise RuntimeError("Can not apply operation to expired entity")
new = apply(operation)
new = original(operation)
store_update(operation, current, new)
track_entity(new)
augment_entity(new)
object.__setattr__(current, "_is_expired", True)
return new

return track_and_apply
return augmented

def store_update(operation: Operations, current: Entity, new: Entity) -> None:
assert current.identifier == new.identifier
Expand All @@ -54,17 +80,18 @@ def store_update(operation: Operations, current: Entity, new: Entity) -> None:
Update(operation, current.identifier, transition, TRANSITION_MAP[transition])
)

link = self._gateway.create_link()
for entity in link:
track_entity(entity)
self._link = link
self._link = self._gateway.create_link()
augment_link(self._link)
for entity in self._link:
augment_entity(entity)
return self

def __exit__(
self, exc_type: type[BaseException] | None, exc: BaseException | None, traceback: TracebackType | None
) -> None:
"""Exit the context rolling back any not yet persisted updates."""
self.rollback()
self._link = None

@property
def link(self) -> Link:
Expand All @@ -75,6 +102,8 @@ def link(self) -> Link:

def commit(self) -> None:
"""Persist updates made to the link."""
if self._link is None:
raise RuntimeError("Not available outside of context")
while self._updates:
identifier, updates = self._updates.popitem()
while updates:
Expand All @@ -83,8 +112,9 @@ def commit(self) -> None:

def rollback(self) -> None:
"""Throw away any not yet persisted updates."""
self._link = None
for entity in self._entities.values():
if self._link is None:
raise RuntimeError("Not available outside of context")
object.__setattr__(self._link, "_is_expired", True)
for entity in self._link:
object.__setattr__(entity, "_is_expired", True)
self._entities = {}
self._updates = defaultdict(deque)
self._updates.clear()
104 changes: 81 additions & 23 deletions tests/integration/test_uow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def initialize(assignments: Mapping[Components, Iterable[str]]) -> tuple[FakeLin
def test_updates_are_applied_to_gateway_on_commit() -> None:
gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}})
with uow:
link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1"))
link = link.apply(Operations.START_DELETE, requested=create_identifiers("2"))
link = link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
uow.link.apply(Operations.START_PULL, requested=create_identifiers("1"))
uow.link.apply(Operations.START_DELETE, requested=create_identifiers("2"))
uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
uow.commit()
actual = {(entity.identifier, entity.state) for entity in gateway.create_link()}
expected = {(create_identifier("1"), states.Pulled), (create_identifier("2"), states.Idle)}
Expand All @@ -32,10 +32,10 @@ def test_updates_are_applied_to_gateway_on_commit() -> None:
def test_updates_are_discarded_on_context_exit() -> None:
gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}})
with uow:
link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1"))
link = link.apply(Operations.START_DELETE, requested=create_identifiers("2"))
link = link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
uow.link.apply(Operations.START_PULL, requested=create_identifiers("1"))
uow.link.apply(Operations.START_DELETE, requested=create_identifiers("2"))
uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
actual = {(entity.identifier, entity.state) for entity in gateway.create_link()}
expected = {(create_identifier("1"), states.Idle), (create_identifier("2"), states.Pulled)}
assert actual == expected
Expand All @@ -44,10 +44,10 @@ def test_updates_are_discarded_on_context_exit() -> None:
def test_updates_are_discarded_on_rollback() -> None:
gateway, uow = initialize({Components.SOURCE: {"1", "2"}, Components.OUTBOUND: {"2"}, Components.LOCAL: {"2"}})
with uow:
link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1"))
link = link.apply(Operations.START_DELETE, requested=create_identifiers("2"))
link = link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
uow.link.apply(Operations.START_PULL, requested=create_identifiers("1"))
uow.link.apply(Operations.START_DELETE, requested=create_identifiers("2"))
uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
uow.link.apply(Operations.PROCESS, requested=create_identifiers("1", "2"))
uow.rollback()
actual = {(entity.identifier, entity.state) for entity in gateway.create_link()}
expected = {(create_identifier("1"), states.Idle), (create_identifier("2"), states.Pulled)}
Expand All @@ -56,31 +56,89 @@ def test_updates_are_discarded_on_rollback() -> None:

def test_link_can_not_be_accessed_outside_of_context() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with uow:
pass
with pytest.raises(RuntimeError, match="outside"):
uow.link.apply(Operations.START_PULL, requested=create_identifiers("1"))
uow.link


def test_no_more_operations_can_be_applied_after_commit() -> None:
def test_unable_to_commit_outside_of_context() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with pytest.raises(RuntimeError, match="outside"):
uow.commit()


def test_unable_to_rollback_outside_of_context() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with pytest.raises(RuntimeError, match="outside"):
uow.rollback()


def test_entity_expires_when_committing() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with uow:
link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1"))
entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1"))
uow.commit()
with pytest.raises(RuntimeError, match="expired"):
link.apply(Operations.PROCESS, requested=create_identifiers("1"))
with pytest.raises(RuntimeError, match="expired entity"):
entity.apply(Operations.START_PULL)


def test_no_more_operations_can_be_applied_after_rollback() -> None:
def test_entity_expires_when_rolling_back() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with uow:
link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1"))
entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1"))
uow.rollback()
with pytest.raises(RuntimeError, match="expired"):
with pytest.raises(RuntimeError, match="expired entity"):
entity.apply(Operations.START_PULL)


def test_entity_expires_when_leaving_context() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with uow:
entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1"))
with pytest.raises(RuntimeError, match="expired entity"):
entity.apply(Operations.START_PULL)


def test_entity_expires_when_applying_operation() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with uow:
entity = next(entity for entity in uow.link if entity.identifier == create_identifier("1"))
entity.apply(Operations.START_PULL)
with pytest.raises(RuntimeError, match="expired entity"):
entity.apply(Operations.PROCESS)


def test_link_expires_when_committing() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with uow:
link = uow.link
uow.commit()
with pytest.raises(RuntimeError, match="expired link"):
link.apply(Operations.START_PULL, requested=create_identifiers("1"))


def test_link_expires_when_rolling_back() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with uow:
link = uow.link
uow.rollback()
with pytest.raises(RuntimeError, match="expired link"):
link.apply(Operations.START_PULL, requested=create_identifiers("1"))


def test_no_more_operations_can_be_applied_after_exiting_context() -> None:
def test_link_expires_when_exiting_context() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with uow:
link = uow.link
with pytest.raises(RuntimeError, match="expired link"):
link.apply(Operations.START_PULL, requested=create_identifiers("1"))


def test_link_expires_when_applying_operation() -> None:
_, uow = initialize({Components.SOURCE: {"1"}})
with uow:
link = uow.link.apply(Operations.START_PULL, requested=create_identifiers("1"))
with pytest.raises(RuntimeError, match="expired"):
link = uow.link
link.apply(Operations.START_PULL, requested=create_identifiers("1"))
with pytest.raises(RuntimeError, match="expired link"):
link.apply(Operations.PROCESS, requested=create_identifiers("1"))

0 comments on commit 508f348

Please sign in to comment.