Skip to content

Commit

Permalink
Simplify imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomperez98 committed Jan 8, 2025
1 parent 92d8c01 commit 7954af2
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 47 deletions.
12 changes: 10 additions & 2 deletions src/resonate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
from __future__ import annotations

from . import random
from resonate.dataclasses import DurablePromise
from resonate.handle import Handle
from resonate.promise import Promise
from resonate.resonate import Resonate

__all__ = ["random"]
__all__ = [
"Resonate",
"Handle",
"DurablePromise",
"Promise",
]
8 changes: 5 additions & 3 deletions src/resonate/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

from typing_extensions import ParamSpec, assert_never

from resonate.promise import Promise
from resonate.result import Err, Ok, Result

if TYPE_CHECKING:
from collections.abc import Generator

from resonate.record import Handle, Record
from resonate.handle import Handle
from resonate.record import Record
from resonate.scheduler.traits import IScheduler
from resonate.typing import Data, DurableCoro, DurableFn, Headers, Tags, Yieldable

Expand Down Expand Up @@ -93,7 +95,7 @@ def _send(self, v: T) -> Yieldable:
self._next_child_to_yield += 1
if child.done():
continue
return child.promise
return Promise[Any](child.id)

assert all(child.done() for child in self._record.children)
assert not self._coro_active
Expand All @@ -117,7 +119,7 @@ def _throw(self, error: Exception) -> Yieldable:
self._next_child_to_yield += 1
if child.done():
continue
return child.promise
return Promise[Any](child.id)
assert all(
child.done() for child in self._record.children
), "All children promise must have been resolved."
Expand Down
20 changes: 20 additions & 0 deletions src/resonate/handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, TypeVar, final

if TYPE_CHECKING:
from concurrent.futures import Future


T = TypeVar("T")


@final
@dataclass(frozen=True)
class Handle(Generic[T]):
id: str
f: Future[T] = field(repr=False)

def result(self, timeout: float | None = None) -> T:
return self.f.result(timeout=timeout)
12 changes: 12 additions & 0 deletions src/resonate/promise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Generic, TypeVar, final

T = TypeVar("T")


@final
@dataclass(frozen=True)
class Promise(Generic[T]):
id: str
29 changes: 5 additions & 24 deletions src/resonate/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from asyncio import iscoroutinefunction
from concurrent.futures import Future
from dataclasses import dataclass, field
from inspect import isfunction, isgeneratorfunction
from typing import TYPE_CHECKING, Any, Generic, TypeVar, final

Expand All @@ -25,22 +24,6 @@
T = TypeVar("T")


@final
@dataclass(frozen=True)
class Promise(Generic[T]):
id: str


@final
@dataclass(frozen=True)
class Handle(Generic[T]):
id: str
_f: Future[T] = field(repr=False)

def result(self, timeout: float | None = None) -> T:
return self._f.result(timeout=timeout)


@final
class Record(Generic[T]):
def __init__(
Expand All @@ -55,7 +38,7 @@ def __init__(
self.is_root: bool = (
True if self.parent is None else isinstance(invocation, RFI)
)
self._f = Future[T]()
self.f = Future[T]()
self.children: list[Record[Any]] = []
self.invocation: LFI | RFI = invocation
self.retry_policy: retry_policy.RetryPolicy | None
Expand All @@ -77,8 +60,6 @@ def __init__(
)

self._attempt: int = 1
self.promise = Promise[T](id=id)
self.handle = Handle[T](id=self.id, _f=self._f)
self.durable_promise: DurablePromiseRecord | None = None
self._task: TaskRecord | None = None
self.ctx = ctx
Expand Down Expand Up @@ -168,21 +149,21 @@ def set_result(self, result: Result[T, Exception], *, deduping: bool) -> None:
r.done() for r in self.children
), "All children record must be completed."
if isinstance(result, Ok):
self._f.set_result(result.unwrap())
self.f.set_result(result.unwrap())
elif isinstance(result, Err):
self._f.set_exception(result.err())
self.f.set_exception(result.err())
else:
assert_never(result)

def safe_result(self) -> Result[Any, Exception]:
assert self.done()
try:
return Ok(self._f.result())
return Ok(self.f.result())
except Exception as e: # noqa: BLE001
return Err(e)

def done(self) -> bool:
return self._f.done()
return self.f.done()

def next_child_name(self) -> str:
return f"{self.id}.{self._num_children+1}"
Expand Down
2 changes: 1 addition & 1 deletion src/resonate/resonate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from resonate import retry_policy
from resonate.context import Context
from resonate.record import Handle
from resonate.handle import Handle
from resonate.scheduler.traits import IScheduler
from resonate.stores.local import LocalStore
from resonate.task_sources.traits import ITaskSource
Expand Down
25 changes: 16 additions & 9 deletions src/resonate/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@
ResonateCoro,
)
from resonate.encoders import JsonEncoder
from resonate.handle import Handle
from resonate.logging import logger
from resonate.processor.processor import Processor
from resonate.promise import Promise
from resonate.queue import DelayQueue
from resonate.record import Promise, Record
from resonate.record import Record
from resonate.result import Err, Ok, Result
from resonate.scheduler.traits import IScheduler
from resonate.stores.record import (
Expand All @@ -49,7 +51,6 @@
if TYPE_CHECKING:
from resonate.collections import FunctionRegistry
from resonate.dependencies import Dependencies
from resonate.record import Handle
from resonate.stores.local import LocalStore
from resonate.stores.record import TaskRecord
from resonate.task_sources.traits import ITaskSource
Expand Down Expand Up @@ -117,7 +118,7 @@ def run(
# If there's already a record with this ID, dedup.
record = self._records.get(id)
if record is not None:
return record.handle
return Handle[T](record.id, record.f)

# Get function name from registry
fn_name = self._registry.get_from_value(func)
Expand Down Expand Up @@ -162,7 +163,7 @@ def run(
else:
self._cmd_queue.put(Invoke(record.id))

return record.handle
return Handle[T](record.id, record.f)

def _heartbeat(self) -> None:
assert isinstance(self._store, RemoteStore)
Expand Down Expand Up @@ -612,7 +613,9 @@ def _process_rfi(self, record: Record[Any], rfi: RFI) -> list[Command]:
child_record = self._records.get(child_id)
if child_record is not None:
record.add_child(child_record)
loopbacks.extend(self._handle_continue(record.id, Ok(child_record.promise)))
loopbacks.extend(
self._handle_continue(record.id, Ok(Promise[Any](child_record.id)))
)
else:
child_record = record.create_child(id=child_id, invocation=rfi)
self._records[child_id] = child_record
Expand All @@ -636,7 +639,9 @@ def _process_rfi(self, record: Record[Any], rfi: RFI) -> list[Command]:
if durable_promise.is_completed():
value = durable_promise.get_value(self._encoder)
child_record.set_result(value, deduping=True)
loopbacks.extend(self._handle_continue(record.id, Ok(child_record.promise)))
loopbacks.extend(
self._handle_continue(record.id, Ok(Promise[Any](child_record.id)))
)

return loopbacks

Expand All @@ -646,7 +651,9 @@ def _process_lfi(self, record: Record[Any], lfi: LFI) -> list[Command]:
child_record = self._records.get(child_id)
if child_record is not None:
record.add_child(child_record)
loopbacks.extend(self._handle_continue(record.id, Ok(child_record.promise)))
loopbacks.extend(
self._handle_continue(record.id, Ok(Promise[Any](child_record.id)))
)
else:
child_record = record.create_child(id=child_id, invocation=lfi)
self._records[child_id] = child_record
Expand All @@ -672,12 +679,12 @@ def _process_lfi(self, record: Record[Any], lfi: LFI) -> list[Command]:
else:
loopbacks.append(Invoke(child_id))
loopbacks.extend(
self._handle_continue(record.id, Ok(child_record.promise))
self._handle_continue(record.id, Ok(Promise[Any](child_record.id)))
)
else:
loopbacks.append(Invoke(child_id))
loopbacks.extend(
self._handle_continue(record.id, Ok(child_record.promise))
self._handle_continue(record.id, Ok(Promise[Any](child_record.id)))
)

return loopbacks
Expand Down
2 changes: 1 addition & 1 deletion src/resonate/scheduler/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import ParamSpec

if TYPE_CHECKING:
from resonate.record import Handle
from resonate.handle import Handle
from resonate.typing import DurableCoro, DurableFn

P = ParamSpec("P")
Expand Down
9 changes: 9 additions & 0 deletions src/resonate/stores/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from .local import LocalStore
from .remote import RemoteStore

__all__ = [
"LocalStore",
"RemoteStore",
]
5 changes: 5 additions & 0 deletions src/resonate/task_sources/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from .poller import Poller

__all__ = ["Poller"]
2 changes: 1 addition & 1 deletion src/resonate/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
RFI,
)
from resonate.context import Context
from resonate.record import Promise
from resonate.promise import Promise

T = TypeVar("T")
P = ParamSpec("P")
Expand Down
10 changes: 4 additions & 6 deletions tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@

import pytest

from resonate.dataclasses import DurablePromise
from resonate.record import Handle, Promise
from resonate.resonate import Resonate
from resonate import DurablePromise, Handle, Resonate
from resonate.promise import Promise
from resonate.retry_policy import constant, exponential, linear, never
from resonate.stores.local import LocalStore
from resonate.stores.remote import RemoteStore
from resonate.stores import LocalStore, RemoteStore
from resonate.targets import poll
from resonate.task_sources.poller import Poller
from resonate.task_sources import Poller

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down

0 comments on commit 7954af2

Please sign in to comment.