Skip to content

Commit

Permalink
Decorator to register options independent of register function (#133)
Browse files Browse the repository at this point in the history
* Add syntax sugar to create a promise without any invocation

* Fighting with the type system

* property promises to resonate

* Need to get type hint from registered fn

* Test passing

* Expose all low lvl api methods

* Untyped register that fits us all

* return registered fn in all overloads

* fix lint
  • Loading branch information
Tomperez98 authored Jan 28, 2025
1 parent 2bda1c3 commit 32cd7ef
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 70 deletions.
8 changes: 7 additions & 1 deletion src/resonate/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any, Callable, TypeVar, final, overload

from typing_extensions import Concatenate, ParamSpec
Expand All @@ -18,7 +19,7 @@
from collections.abc import Coroutine, Generator

from resonate.dependencies import Dependencies
from resonate.typing import DurableCoro, DurableFn, Yieldable
from resonate.typing import Data, DurableCoro, DurableFn, Yieldable

P = ParamSpec("P")
T = TypeVar("T")
Expand All @@ -42,6 +43,11 @@ def sleep(self, secs: int) -> RFC:
)
)

def promise(
self, id: str | None = None, data: Data = None, timeout: int = sys.maxsize
) -> RFI:
return self.rfi(DurablePromise(id=id, data=data, timeout=timeout))

@overload
def rfc(self, cmd: DurablePromise, /) -> RFC: ...
@overload
Expand Down
112 changes: 64 additions & 48 deletions src/resonate/resonate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, TypedDict, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
from uuid import uuid4

from typing_extensions import Concatenate, ParamSpec
Expand All @@ -18,22 +18,19 @@
if TYPE_CHECKING:
from collections.abc import Coroutine, Generator

from resonate import retry_policy
from resonate.context import Context
from resonate.handle import Handle
from resonate.retry_policy import RetryPolicy
from resonate.scheduler.traits import IScheduler
from resonate.stores.local import LocalStore
from resonate.stores.traits import IPromiseStore
from resonate.task_sources.traits import ITaskSource
from resonate.typing import DurableCoro, DurableFn, Yieldable
from resonate.typing import Yieldable

P = ParamSpec("P")
T = TypeVar("T")


class _RunOptions(TypedDict):
version: int


class Resonate:
"""
The Resonate class serves as the main API interface for Resonate Application Nodes.
Expand Down Expand Up @@ -80,11 +77,13 @@ def __init__(
self._deps = Dependencies()
self._registry = FunctionRegistry()

self._store = store or RemoteStore()

self._scheduler: IScheduler = Scheduler(
deps=self._deps,
pid=pid or uuid4().hex,
registry=self._registry,
store=store or RemoteStore(),
store=self._store,
task_source=task_source or Poller(),
)

Expand Down Expand Up @@ -114,40 +113,79 @@ def register(
Concatenate[Context, P],
Generator[Yieldable, Any, T],
],
/,
*,
name: str | None = None,
version: int = 1,
retry_policy: retry_policy.RetryPolicy | None = None,
retry_policy: RetryPolicy | None = None,
) -> RegisteredFn[P, T]: ...
@overload
def register(
self,
func: Callable[Concatenate[Context, P], Coroutine[Any, Any, T]],
/,
*,
name: str | None = None,
version: int = 1,
retry_policy: retry_policy.RetryPolicy | None = None,
retry_policy: RetryPolicy | None = None,
) -> RegisteredFn[P, T]: ...
@overload
def register(
self,
func: Callable[Concatenate[Context, P], T],
/,
*,
name: str | None = None,
version: int = 1,
retry_policy: retry_policy.RetryPolicy | None = None,
retry_policy: RetryPolicy | None = None,
) -> RegisteredFn[P, T]: ...
@overload
def register(
self,
func: DurableCoro[P, T] | DurableFn[P, T],
*,
name: str | None = None,
version: int = 1,
retry_policy: retry_policy.RetryPolicy | None = None,
) -> RegisteredFn[P, T]:
if name is None:
name = func.__name__
self._registry.add(
name,
(func, Options(version=version, durable=True, retry_policy=retry_policy)),
)
return RegisteredFn[P, T](self._scheduler, func)
retry_policy: RetryPolicy | None = None,
) -> Callable[
[Callable[Concatenate[Context, P], Any]],
RegisteredFn[P, Any],
]: ...
def register(
self, *args: Any, **kwargs: Any
) -> (
RegisteredFn[P, T]
| Callable[
[Callable[Concatenate[Context, P], Any]],
RegisteredFn[P, T],
]
):
name: str | None = kwargs.get("name")
version: int = kwargs.get("version", 1)
retry_policy: RetryPolicy | None = kwargs.get("retry_policy")
if args and callable(args[0]):
func = args[0]
self._registry.add(
name or func.__name__,
(
func,
Options(version=version, durable=True, retry_policy=retry_policy),
),
)
return RegisteredFn[P, T](self._scheduler, func) # type: ignore[arg-type, unused-ignore]

def wrapper(
func: Callable[Concatenate[Context, P], Any],
) -> RegisteredFn[P, Any]:
self._registry.add(
name or func.__name__,
(
func,
Options(version=version, durable=True, retry_policy=retry_policy),
),
)
return RegisteredFn(self._scheduler, func)

return wrapper

@overload
def run(
Expand All @@ -172,23 +210,6 @@ def run(
*args: P.args,
**kwargs: P.kwargs,
) -> Handle[T]: ...
@overload
def run(
self,
id: str,
func: tuple[
Callable[
Concatenate[Context, P],
Generator[Yieldable, Any, T],
]
| Callable[Concatenate[Context, P], T]
| Callable[Concatenate[Context, P], Coroutine[Any, Any, T]],
_RunOptions,
],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> Handle[T]: ...
def run(
self,
id: str,
Expand All @@ -198,16 +219,7 @@ def run(
Generator[Yieldable, Any, T],
]
| Callable[Concatenate[Context, P], T]
| Callable[Concatenate[Context, P], Coroutine[Any, Any, T]]
| tuple[
Callable[
Concatenate[Context, P],
Generator[Yieldable, Any, T],
]
| Callable[Concatenate[Context, P], T]
| Callable[Concatenate[Context, P], Coroutine[Any, Any, T]],
_RunOptions,
],
| Callable[Concatenate[Context, P], Coroutine[Any, Any, T]],
/,
*args: P.args,
**kwargs: P.kwargs,
Expand All @@ -217,3 +229,7 @@ def run(
if isinstance(func, RegisteredFn):
func = func.fn
return self._scheduler.run(id, func, *args, **kwargs)

@property
def promises(self) -> IPromiseStore:
return self._store.promises
44 changes: 23 additions & 21 deletions tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pytest

from resonate import DurablePromise, Handle, Resonate
from resonate import Handle, Resonate
from resonate.promise import Promise
from resonate.retry_policy import constant, exponential, linear, never
from resonate.stores import LocalStore, RemoteStore
Expand Down Expand Up @@ -459,7 +459,7 @@ def test_golden_device_rfi_and_lfc_with_decorator() -> None:
task_source=Poller("http://localhost:8002", group=group),
)

@resonate.register
@resonate.register()
def foo(ctx: Context, n: str) -> Generator[Yieldable, Any, str]:
v: str = yield ctx.lfc(bar, n).options(
id="bar",
Expand All @@ -472,7 +472,7 @@ def bar(ctx: Context, n: str) -> Generator[Yieldable, Any, str]:
v: str = yield p
return v

@resonate.register
@resonate.register()
def baz(ctx: Context, n: str) -> str: # noqa: ARG001
return n

Expand Down Expand Up @@ -501,9 +501,9 @@ def bar_golden_device_rfc(ctx: Context, n: str) -> str: # noqa: ARG001
store=RemoteStore(url=os.environ["RESONATE_STORE_URL"]),
task_source=Poller("http://localhost:8002", group=group),
)
resonate.register(foo_golden_device_rfc)
rf = resonate.register(foo_golden_device_rfc)
resonate.register(bar_golden_device_rfc)
p: Handle[str] = resonate.run(f"{group}-foo", foo_golden_device_rfc, "hi")
p: Handle[str] = rf.run(f"{group}-foo", "hi")
assert isinstance(p, Handle)
assert p.result() == "hi"
resonate.stop()
Expand Down Expand Up @@ -611,7 +611,7 @@ def test_golden_device_rfc_and_lfc_with_decorator() -> None:
task_source=Poller("http://localhost:8002", group=group),
)

@resonate.register
@resonate.register()
def foo(ctx: Context, n: str) -> Generator[Yieldable, Any, str]:
v: str = yield ctx.lfc(bar, n).options(
id="bar",
Expand All @@ -623,7 +623,7 @@ def bar(ctx: Context, n: str) -> Generator[Yieldable, Any, str]:
v: str = yield ctx.rfc(baz, n).options(id="baz", send_to=poll(group))
return v

@resonate.register
@resonate.register()
def baz(ctx: Context, n: str) -> str: # noqa: ARG001
return n

Expand Down Expand Up @@ -686,16 +686,16 @@ def foo_retry_policy(ctx: Context, n: str) -> str: # noqa: ARG001
def test_human_in_the_loop() -> None:
group = "test-human-in-the-loop"

def _user_manual_completion(id: str) -> DurablePromise:
return DurablePromise(id=id)

def human_in_the_loop(ctx: Context) -> Generator[Yieldable, Any, str]:
name: str = yield ctx.rfc(
_user_manual_completion("test-human-in-loop-question-to-answer-1")
p_name: Promise[str] = yield ctx.promise(
"test-human-in-loop-question-to-answer-1"
)
age: int = yield ctx.rfc(
_user_manual_completion(id="test-human-in-loop-question-to-answer-2")
name: str = yield p_name

p_age: Promise[int] = yield ctx.promise(
id="test-human-in-loop-question-to-answer-2"
)
age: int = yield p_age
return f"Hi {name} with age {age}"

store = RemoteStore(url=os.environ["RESONATE_STORE_URL"])
Expand Down Expand Up @@ -730,17 +730,19 @@ def test_sleep() -> None:
group = "test-sleep"

store = RemoteStore(url="http://localhost:8001")
s = Resonate(store=store, task_source=Poller("http://localhost:8002", group=group))
resonate = Resonate(
store=store, task_source=Poller("http://localhost:8002", group=group)
)

@s.register
@resonate.register()
def foo_sleep(ctx: Context, n: int) -> Generator[Yieldable, Any, int]:
yield ctx.sleep(n)
return n

n = 1
p = foo_sleep.run(f"{group}-{n}", n)
p: Handle[int] = foo_sleep.run(f"{group}-{n}", n)
assert p.result() == n
s.stop()
resonate.stop()


@pytest.mark.skipif(
Expand Down Expand Up @@ -792,7 +794,7 @@ def test_golden_device_detached_with_registered() -> None:
task_source=Poller("http://localhost:8002", group=group),
)

@resonate.register
@resonate.register()
def foo_golden_device_detached_with_registered(
ctx: Context, n: str
) -> Generator[Yieldable, Any, str]:
Expand All @@ -805,11 +807,11 @@ def foo_golden_device_detached_with_registered(
v: str = yield p
return v

@resonate.register
@resonate.register()
def bar_golden_device_detached_with_registered(ctx: Context, n: str) -> str: # noqa: ARG001
return n

@resonate.register
@resonate.register()
def baz_golden_device_detached_with_registered(
ctx: Context, # noqa: ARG001
promise_id: str,
Expand Down

0 comments on commit 32cd7ef

Please sign in to comment.