From a329bcb9605f79eec845e12b620ceff42a2e57b2 Mon Sep 17 00:00:00 2001 From: Tomperez98 Date: Mon, 27 Jan 2025 17:07:43 -0500 Subject: [PATCH] Untyped register that fits us all --- src/resonate/resonate.py | 57 ++++++++++++++++++++++++------------- tests/test_functionality.py | 16 +++++------ 2 files changed, 46 insertions(+), 27 deletions(-) diff --git a/src/resonate/resonate.py b/src/resonate/resonate.py index f3ccf006..d77f063e 100644 --- a/src/resonate/resonate.py +++ b/src/resonate/resonate.py @@ -18,14 +18,14 @@ if TYPE_CHECKING: from collections.abc import Coroutine, Generator - from resonate import retry_policy from resonate.context import Context from resonate.handle import Handle + from resonate.retry_policy import RetryPolicy from resonate.scheduler.traits import IScheduler from resonate.stores.local import LocalStore from resonate.stores.traits import IPromiseStore from resonate.task_sources.traits import ITaskSource - from resonate.typing import DurableCoro, DurableFn, Yieldable + from resonate.typing import Yieldable P = ParamSpec("P") T = TypeVar("T") @@ -113,47 +113,66 @@ def register( Concatenate[Context, P], Generator[Yieldable, Any, T], ], + /, + *, name: str | None = None, version: int = 1, - retry_policy: retry_policy.RetryPolicy | None = None, + retry_policy: RetryPolicy | None = None, ) -> None: ... @overload def register( self, func: Callable[Concatenate[Context, P], Coroutine[Any, Any, T]], + /, + *, name: str | None = None, version: int = 1, - retry_policy: retry_policy.RetryPolicy | None = None, + retry_policy: RetryPolicy | None = None, ) -> None: ... @overload def register( self, func: Callable[Concatenate[Context, P], T], + /, + *, name: str | None = None, version: int = 1, - retry_policy: retry_policy.RetryPolicy | None = None, + retry_policy: RetryPolicy | None = None, ) -> None: ... + @overload def register( self, - func: DurableCoro[P, Any] | DurableFn[P, Any], + *, name: str | None = None, version: int = 1, - retry_policy: retry_policy.RetryPolicy | None = None, - ) -> None: - return self._registry.add( - name or func.__name__, - (func, Options(version=version, durable=True, retry_policy=retry_policy)), - ) - - def __call__( - self, - name: str | None = None, - version: int = 1, - retry_policy: retry_policy.RetryPolicy | None = None, + retry_policy: RetryPolicy | None = None, ) -> Callable[ [Callable[Concatenate[Context, P], Any]], RegisteredFn[P, Any], - ]: + ]: ... + def register( + self, *args: Any, **kwargs: Any + ) -> ( + None + | Callable[ + [Callable[Concatenate[Context, P], Any]], + RegisteredFn[P, Any], + ] + ): + name: str | None = kwargs.get("name") + version: int = kwargs.get("version", 1) + retry_policy: RetryPolicy | None = kwargs.get("retry_policy") + if args and callable(args[0]): + func: Callable = args[0] # type: ignore[type-arg] + self._registry.add( + name or func.__name__, + ( + func, + Options(version=version, durable=True, retry_policy=retry_policy), + ), + ) + return None + def wrapper( func: Callable[Concatenate[Context, P], Any], ) -> RegisteredFn[P, Any]: diff --git a/tests/test_functionality.py b/tests/test_functionality.py index bb7ec474..0df86ad0 100644 --- a/tests/test_functionality.py +++ b/tests/test_functionality.py @@ -459,7 +459,7 @@ def test_golden_device_rfi_and_lfc_with_decorator() -> None: task_source=Poller("http://localhost:8002", group=group), ) - @resonate() + @resonate.register() def foo(ctx: Context, n: str) -> Generator[Yieldable, Any, str]: v: str = yield ctx.lfc(bar, n).options( id="bar", @@ -472,7 +472,7 @@ def bar(ctx: Context, n: str) -> Generator[Yieldable, Any, str]: v: str = yield p return v - @resonate() + @resonate.register() def baz(ctx: Context, n: str) -> str: # noqa: ARG001 return n @@ -611,7 +611,7 @@ def test_golden_device_rfc_and_lfc_with_decorator() -> None: task_source=Poller("http://localhost:8002", group=group), ) - @resonate() + @resonate.register() def foo(ctx: Context, n: str) -> Generator[Yieldable, Any, str]: v: str = yield ctx.lfc(bar, n).options( id="bar", @@ -623,7 +623,7 @@ def bar(ctx: Context, n: str) -> Generator[Yieldable, Any, str]: v: str = yield ctx.rfc(baz, n).options(id="baz", send_to=poll(group)) return v - @resonate() + @resonate.register() def baz(ctx: Context, n: str) -> str: # noqa: ARG001 return n @@ -734,7 +734,7 @@ def test_sleep() -> None: store=store, task_source=Poller("http://localhost:8002", group=group) ) - @resonate() + @resonate.register() def foo_sleep(ctx: Context, n: int) -> Generator[Yieldable, Any, int]: yield ctx.sleep(n) return n @@ -794,7 +794,7 @@ def test_golden_device_detached_with_registered() -> None: task_source=Poller("http://localhost:8002", group=group), ) - @resonate() + @resonate.register() def foo_golden_device_detached_with_registered( ctx: Context, n: str ) -> Generator[Yieldable, Any, str]: @@ -807,11 +807,11 @@ def foo_golden_device_detached_with_registered( v: str = yield p return v - @resonate() + @resonate.register() def bar_golden_device_detached_with_registered(ctx: Context, n: str) -> str: # noqa: ARG001 return n - @resonate() + @resonate.register() def baz_golden_device_detached_with_registered( ctx: Context, # noqa: ARG001 promise_id: str,