From e99799ad7d1d7818d27c28c90f9f06bae18749de Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sun, 2 Mar 2025 21:15:02 +0100 Subject: [PATCH 1/3] Added result parsing on return. --- taskiq/abc/broker.py | 8 ++++++++ taskiq/brokers/shared_broker.py | 1 + taskiq/decor.py | 5 +++++ taskiq/kicker.py | 5 ++++- taskiq/task.py | 16 +++++++++++++++- 5 files changed, 33 insertions(+), 2 deletions(-) diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 69c9e25c..898e300e 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -18,10 +18,12 @@ Optional, TypeVar, Union, + get_type_hints, overload, ) from uuid import uuid4 +from pydantic import TypeAdapter from typing_extensions import ParamSpec, Self, TypeAlias from taskiq.abc.middleware import TaskiqMiddleware @@ -326,12 +328,18 @@ def inner( inner_task_name = f"{fmodule}:{fname}" wrapper = wraps(func) + sign = get_type_hints(func) + return_type = None + if "return" in sign: + return_type = TypeAdapter(sign["return"]) + decorated_task = wrapper( self.decorator_class( broker=self, original_func=func, labels=inner_labels, task_name=inner_task_name, + return_type=return_type, ), ) diff --git a/taskiq/brokers/shared_broker.py b/taskiq/brokers/shared_broker.py index d6574e60..def2c797 100644 --- a/taskiq/brokers/shared_broker.py +++ b/taskiq/brokers/shared_broker.py @@ -30,6 +30,7 @@ def kicker(self) -> AsyncKicker[_Params, _ReturnType]: task_name=self.task_name, broker=broker, labels=self.labels, + return_type=self.return_type, ) diff --git a/taskiq/decor.py b/taskiq/decor.py index dcbb2de0..268cb8eb 100644 --- a/taskiq/decor.py +++ b/taskiq/decor.py @@ -7,11 +7,13 @@ Callable, Dict, Generic, + Optional, TypeVar, Union, overload, ) +from pydantic import TypeAdapter from typing_extensions import ParamSpec from taskiq.kicker import AsyncKicker @@ -50,11 +52,13 @@ def __init__( task_name: str, original_func: Callable[_FuncParams, _ReturnType], labels: Dict[str, Any], + return_type: Optional[TypeAdapter[_ReturnType]] = None, ) -> None: self.broker = broker self.task_name = task_name self.original_func = original_func self.labels = labels + self.return_type = return_type # Docs for this method are omitted in order to help # your IDE resolve correct docs for it. @@ -172,6 +176,7 @@ def kicker(self) -> AsyncKicker[_FuncParams, _ReturnType]: task_name=self.task_name, broker=self.broker, labels=self.labels, + return_type=self.return_type, ) def __repr__(self) -> str: diff --git a/taskiq/kicker.py b/taskiq/kicker.py index d2ff8e6e..ed57c4c8 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -14,7 +14,7 @@ overload, ) -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from typing_extensions import ParamSpec from taskiq.abc.middleware import TaskiqMiddleware @@ -46,12 +46,14 @@ def __init__( task_name: str, broker: "AsyncBroker", labels: Dict[str, Any], + return_type: Optional[TypeAdapter[_ReturnType]] = None, ) -> None: self.task_name = task_name self.broker = broker self.labels = labels self.custom_task_id: Optional[str] = None self.custom_schedule_id: Optional[str] = None + self.return_type = return_type def with_labels( self, @@ -169,6 +171,7 @@ async def kiq( return AsyncTaskiqTask( task_id=message.task_id, result_backend=self.broker.result_backend, + return_type=self.return_type, # type: ignore # (pyright issue) ) async def schedule_by_cron( diff --git a/taskiq/task.py b/taskiq/task.py index c54e9d46..69d91e63 100644 --- a/taskiq/task.py +++ b/taskiq/task.py @@ -1,7 +1,9 @@ import asyncio +from logging import getLogger from time import time from typing import TYPE_CHECKING, Any, Generic, Optional +from pydantic import TypeAdapter from typing_extensions import TypeVar from taskiq.exceptions import ( @@ -15,6 +17,8 @@ from taskiq.depends.progress_tracker import TaskProgress from taskiq.result import TaskiqResult +logger = getLogger("taskiq.task") + _ReturnType = TypeVar("_ReturnType") @@ -25,9 +29,11 @@ def __init__( self, task_id: str, result_backend: "AsyncResultBackend[_ReturnType]", + return_type: Optional[TypeAdapter[_ReturnType]] = None, ) -> None: self.task_id = task_id self.result_backend = result_backend + self.return_type = return_type async def is_ready(self) -> bool: """ @@ -53,10 +59,18 @@ async def get_result(self, with_logs: bool = False) -> "TaskiqResult[_ReturnType :return: task's return value. """ try: - return await self.result_backend.get_result( + res = await self.result_backend.get_result( self.task_id, with_logs=with_logs, ) + if self.return_type is not None: + try: + res.return_value = self.return_type.validate_python( + res.return_value, + ) + except ValueError: + logger.warning("Cannot parse return type into %s", self.return_type) + return res except Exception as exc: raise ResultGetError from exc From ba55f43672e97537ee5cbd96b58c0147260a4bfe Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sun, 2 Mar 2025 22:36:51 +0100 Subject: [PATCH 2/3] Fixed pydantic v1 support. --- taskiq/abc/broker.py | 5 ++--- taskiq/compat.py | 8 ++++---- taskiq/decor.py | 4 ++-- taskiq/kicker.py | 5 +++-- taskiq/task.py | 9 +++++---- 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/taskiq/abc/broker.py b/taskiq/abc/broker.py index 898e300e..9c7fbe86 100644 --- a/taskiq/abc/broker.py +++ b/taskiq/abc/broker.py @@ -23,7 +23,6 @@ ) from uuid import uuid4 -from pydantic import TypeAdapter from typing_extensions import ParamSpec, Self, TypeAlias from taskiq.abc.middleware import TaskiqMiddleware @@ -331,7 +330,7 @@ def inner( sign = get_type_hints(func) return_type = None if "return" in sign: - return_type = TypeAdapter(sign["return"]) + return_type = sign["return"] decorated_task = wrapper( self.decorator_class( @@ -339,7 +338,7 @@ def inner( original_func=func, labels=inner_labels, task_name=inner_task_name, - return_type=return_type, + return_type=return_type, # type: ignore ), ) diff --git a/taskiq/compat.py b/taskiq/compat.py index 1858d2c8..ce54bb9a 100644 --- a/taskiq/compat.py +++ b/taskiq/compat.py @@ -1,6 +1,6 @@ # flake8: noqa from functools import lru_cache -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Any, Dict, Hashable, Optional, Type, TypeVar, Union import pydantic from importlib_metadata import version @@ -12,13 +12,13 @@ IS_PYDANTIC2 = PYDANTIC_VER >= Version("2.0") if IS_PYDANTIC2: - T = TypeVar("T") + T = TypeVar("T", bound=Hashable) @lru_cache() - def create_type_adapter(annot: T) -> pydantic.TypeAdapter[T]: + def create_type_adapter(annot: Type[T]) -> pydantic.TypeAdapter[T]: return pydantic.TypeAdapter(annot) - def parse_obj_as(annot: T, obj: Any) -> T: + def parse_obj_as(annot: Type[T], obj: Any) -> T: return create_type_adapter(annot).validate_python(obj) def model_validate( diff --git a/taskiq/decor.py b/taskiq/decor.py index 268cb8eb..ba774506 100644 --- a/taskiq/decor.py +++ b/taskiq/decor.py @@ -8,12 +8,12 @@ Dict, Generic, Optional, + Type, TypeVar, Union, overload, ) -from pydantic import TypeAdapter from typing_extensions import ParamSpec from taskiq.kicker import AsyncKicker @@ -52,7 +52,7 @@ def __init__( task_name: str, original_func: Callable[_FuncParams, _ReturnType], labels: Dict[str, Any], - return_type: Optional[TypeAdapter[_ReturnType]] = None, + return_type: Optional[Type[_ReturnType]] = None, ) -> None: self.broker = broker self.task_name = task_name diff --git a/taskiq/kicker.py b/taskiq/kicker.py index ed57c4c8..9583df5d 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -9,12 +9,13 @@ Dict, Generic, Optional, + Type, TypeVar, Union, overload, ) -from pydantic import BaseModel, TypeAdapter +from pydantic import BaseModel from typing_extensions import ParamSpec from taskiq.abc.middleware import TaskiqMiddleware @@ -46,7 +47,7 @@ def __init__( task_name: str, broker: "AsyncBroker", labels: Dict[str, Any], - return_type: Optional[TypeAdapter[_ReturnType]] = None, + return_type: Optional[Type[_ReturnType]] = None, ) -> None: self.task_name = task_name self.broker = broker diff --git a/taskiq/task.py b/taskiq/task.py index 69d91e63..006dc520 100644 --- a/taskiq/task.py +++ b/taskiq/task.py @@ -1,11 +1,11 @@ import asyncio from logging import getLogger from time import time -from typing import TYPE_CHECKING, Any, Generic, Optional +from typing import TYPE_CHECKING, Any, Generic, Optional, Type -from pydantic import TypeAdapter from typing_extensions import TypeVar +from taskiq.compat import parse_obj_as from taskiq.exceptions import ( ResultGetError, ResultIsReadyError, @@ -29,7 +29,7 @@ def __init__( self, task_id: str, result_backend: "AsyncResultBackend[_ReturnType]", - return_type: Optional[TypeAdapter[_ReturnType]] = None, + return_type: Optional[Type[_ReturnType]] = None, ) -> None: self.task_id = task_id self.result_backend = result_backend @@ -65,7 +65,8 @@ async def get_result(self, with_logs: bool = False) -> "TaskiqResult[_ReturnType ) if self.return_type is not None: try: - res.return_value = self.return_type.validate_python( + res.return_value = parse_obj_as( + self.return_type, res.return_value, ) except ValueError: From 18b7ea3462a01be4b98d9e912538850da7661ef5 Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Tue, 4 Mar 2025 00:04:37 +0100 Subject: [PATCH 3/3] Added tests. --- tests/test_task.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 tests/test_task.py diff --git a/tests/test_task.py b/tests/test_task.py new file mode 100644 index 00000000..a976f71c --- /dev/null +++ b/tests/test_task.py @@ -0,0 +1,73 @@ +import uuid +from dataclasses import dataclass +from typing import Dict, TypeVar + +import pytest + +from taskiq import serializers +from taskiq.abc import AsyncResultBackend +from taskiq.abc.serializer import TaskiqSerializer +from taskiq.compat import model_dump, model_validate +from taskiq.result.v1 import TaskiqResult +from taskiq.task import AsyncTaskiqTask + +_ReturnType = TypeVar("_ReturnType") + + +class SerializingBackend(AsyncResultBackend[_ReturnType]): + def __init__(self, serializer: TaskiqSerializer) -> None: + self._serializer = serializer + self._results: Dict[str, bytes] = {} + + async def set_result( + self, + task_id: str, + result: TaskiqResult[_ReturnType], # type: ignore + ) -> None: + """Set result with dumping.""" + self._results[task_id] = self._serializer.dumpb(model_dump(result)) + + async def is_result_ready(self, task_id: str) -> bool: + """Check if result is ready.""" + return task_id in self._results + + async def get_result( + self, + task_id: str, + with_logs: bool = False, + ) -> TaskiqResult[_ReturnType]: + """Get result with loading.""" + data = self._results[task_id] + return model_validate(TaskiqResult, self._serializer.loadb(data)) + + +@pytest.mark.parametrize( + "serializer", + [ + serializers.MSGPackSerializer(), + serializers.CBORSerializer(), + serializers.PickleSerializer(), + serializers.JSONSerializer(), + ], +) +@pytest.mark.anyio +async def test_res_parsing_success(serializer: TaskiqSerializer) -> None: + @dataclass + class MyResult: + name: str + age: int + + res = MyResult(name="test", age=10) + res_back: AsyncResultBackend[MyResult] = SerializingBackend(serializer) + test_id = str(uuid.uuid4()) + await res_back.set_result( + test_id, + TaskiqResult( + is_err=False, + return_value=res, + execution_time=0.0, + ), + ) + sent_task = AsyncTaskiqTask(test_id, res_back, MyResult) + parsed = await sent_task.wait_result() + assert isinstance(parsed.return_value, MyResult)