Skip to content

Commit

Permalink
Untyped register that fits us all
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomperez98 committed Jan 27, 2025
1 parent e10fcfe commit a329bcb
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 27 deletions.
57 changes: 38 additions & 19 deletions src/resonate/resonate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
if TYPE_CHECKING:
from collections.abc import Coroutine, Generator

from resonate import retry_policy
from resonate.context import Context
from resonate.handle import Handle
from resonate.retry_policy import RetryPolicy
from resonate.scheduler.traits import IScheduler
from resonate.stores.local import LocalStore
from resonate.stores.traits import IPromiseStore
from resonate.task_sources.traits import ITaskSource
from resonate.typing import DurableCoro, DurableFn, Yieldable
from resonate.typing import Yieldable

P = ParamSpec("P")
T = TypeVar("T")
Expand Down Expand Up @@ -113,47 +113,66 @@ def register(
Concatenate[Context, P],
Generator[Yieldable, Any, T],
],
/,
*,
name: str | None = None,
version: int = 1,
retry_policy: retry_policy.RetryPolicy | None = None,
retry_policy: RetryPolicy | None = None,
) -> None: ...
@overload
def register(
self,
func: Callable[Concatenate[Context, P], Coroutine[Any, Any, T]],
/,
*,
name: str | None = None,
version: int = 1,
retry_policy: retry_policy.RetryPolicy | None = None,
retry_policy: RetryPolicy | None = None,
) -> None: ...
@overload
def register(
self,
func: Callable[Concatenate[Context, P], T],
/,
*,
name: str | None = None,
version: int = 1,
retry_policy: retry_policy.RetryPolicy | None = None,
retry_policy: RetryPolicy | None = None,
) -> None: ...
@overload
def register(
self,
func: DurableCoro[P, Any] | DurableFn[P, Any],
*,
name: str | None = None,
version: int = 1,
retry_policy: retry_policy.RetryPolicy | None = None,
) -> None:
return self._registry.add(
name or func.__name__,
(func, Options(version=version, durable=True, retry_policy=retry_policy)),
)

def __call__(
self,
name: str | None = None,
version: int = 1,
retry_policy: retry_policy.RetryPolicy | None = None,
retry_policy: RetryPolicy | None = None,
) -> Callable[
[Callable[Concatenate[Context, P], Any]],
RegisteredFn[P, Any],
]:
]: ...
def register(
self, *args: Any, **kwargs: Any
) -> (
None
| Callable[
[Callable[Concatenate[Context, P], Any]],
RegisteredFn[P, Any],
]
):
name: str | None = kwargs.get("name")
version: int = kwargs.get("version", 1)
retry_policy: RetryPolicy | None = kwargs.get("retry_policy")
if args and callable(args[0]):
func: Callable = args[0] # type: ignore[type-arg]
self._registry.add(
name or func.__name__,
(
func,
Options(version=version, durable=True, retry_policy=retry_policy),
),
)
return None

def wrapper(
func: Callable[Concatenate[Context, P], Any],
) -> RegisteredFn[P, Any]:
Expand Down
16 changes: 8 additions & 8 deletions tests/test_functionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def test_golden_device_rfi_and_lfc_with_decorator() -> None:
task_source=Poller("http://localhost:8002", group=group),
)

@resonate()
@resonate.register()
def foo(ctx: Context, n: str) -> Generator[Yieldable, Any, str]:
v: str = yield ctx.lfc(bar, n).options(
id="bar",
Expand All @@ -472,7 +472,7 @@ def bar(ctx: Context, n: str) -> Generator[Yieldable, Any, str]:
v: str = yield p
return v

@resonate()
@resonate.register()
def baz(ctx: Context, n: str) -> str: # noqa: ARG001
return n

Expand Down Expand Up @@ -611,7 +611,7 @@ def test_golden_device_rfc_and_lfc_with_decorator() -> None:
task_source=Poller("http://localhost:8002", group=group),
)

@resonate()
@resonate.register()
def foo(ctx: Context, n: str) -> Generator[Yieldable, Any, str]:
v: str = yield ctx.lfc(bar, n).options(
id="bar",
Expand All @@ -623,7 +623,7 @@ def bar(ctx: Context, n: str) -> Generator[Yieldable, Any, str]:
v: str = yield ctx.rfc(baz, n).options(id="baz", send_to=poll(group))
return v

@resonate()
@resonate.register()
def baz(ctx: Context, n: str) -> str: # noqa: ARG001
return n

Expand Down Expand Up @@ -734,7 +734,7 @@ def test_sleep() -> None:
store=store, task_source=Poller("http://localhost:8002", group=group)
)

@resonate()
@resonate.register()
def foo_sleep(ctx: Context, n: int) -> Generator[Yieldable, Any, int]:
yield ctx.sleep(n)
return n
Expand Down Expand Up @@ -794,7 +794,7 @@ def test_golden_device_detached_with_registered() -> None:
task_source=Poller("http://localhost:8002", group=group),
)

@resonate()
@resonate.register()
def foo_golden_device_detached_with_registered(
ctx: Context, n: str
) -> Generator[Yieldable, Any, str]:
Expand All @@ -807,11 +807,11 @@ def foo_golden_device_detached_with_registered(
v: str = yield p
return v

@resonate()
@resonate.register()
def bar_golden_device_detached_with_registered(ctx: Context, n: str) -> str: # noqa: ARG001
return n

@resonate()
@resonate.register()
def baz_golden_device_detached_with_registered(
ctx: Context, # noqa: ARG001
promise_id: str,
Expand Down

0 comments on commit a329bcb

Please sign in to comment.