Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added result parsing on return. #415

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions taskiq/abc/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Optional,
TypeVar,
Union,
get_type_hints,
overload,
)
from uuid import uuid4
Expand Down Expand Up @@ -326,12 +327,18 @@ def inner(
inner_task_name = f"{fmodule}:{fname}"
wrapper = wraps(func)

sign = get_type_hints(func)
return_type = None
if "return" in sign:
return_type = sign["return"]

decorated_task = wrapper(
self.decorator_class(
broker=self,
original_func=func,
labels=inner_labels,
task_name=inner_task_name,
return_type=return_type, # type: ignore
),
)

Expand Down
1 change: 1 addition & 0 deletions taskiq/brokers/shared_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def kicker(self) -> AsyncKicker[_Params, _ReturnType]:
task_name=self.task_name,
broker=broker,
labels=self.labels,
return_type=self.return_type,
)


Expand Down
8 changes: 4 additions & 4 deletions taskiq/compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# flake8: noqa
from functools import lru_cache
from typing import Any, Dict, Optional, Type, TypeVar, Union
from typing import Any, Dict, Hashable, Optional, Type, TypeVar, Union

import pydantic
from importlib_metadata import version
Expand All @@ -12,13 +12,13 @@
IS_PYDANTIC2 = PYDANTIC_VER >= Version("2.0")

if IS_PYDANTIC2:
T = TypeVar("T")
T = TypeVar("T", bound=Hashable)

@lru_cache()
def create_type_adapter(annot: T) -> pydantic.TypeAdapter[T]:
def create_type_adapter(annot: Type[T]) -> pydantic.TypeAdapter[T]:
return pydantic.TypeAdapter(annot)

def parse_obj_as(annot: T, obj: Any) -> T:
def parse_obj_as(annot: Type[T], obj: Any) -> T:
return create_type_adapter(annot).validate_python(obj)

def model_validate(
Expand Down
5 changes: 5 additions & 0 deletions taskiq/decor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
Callable,
Dict,
Generic,
Optional,
Type,
TypeVar,
Union,
overload,
Expand Down Expand Up @@ -50,11 +52,13 @@ def __init__(
task_name: str,
original_func: Callable[_FuncParams, _ReturnType],
labels: Dict[str, Any],
return_type: Optional[Type[_ReturnType]] = None,
) -> None:
self.broker = broker
self.task_name = task_name
self.original_func = original_func
self.labels = labels
self.return_type = return_type

# Docs for this method are omitted in order to help
# your IDE resolve correct docs for it.
Expand Down Expand Up @@ -172,6 +176,7 @@ def kicker(self) -> AsyncKicker[_FuncParams, _ReturnType]:
task_name=self.task_name,
broker=self.broker,
labels=self.labels,
return_type=self.return_type,
)

def __repr__(self) -> str:
Expand Down
4 changes: 4 additions & 0 deletions taskiq/kicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Dict,
Generic,
Optional,
Type,
TypeVar,
Union,
overload,
Expand Down Expand Up @@ -46,12 +47,14 @@ def __init__(
task_name: str,
broker: "AsyncBroker",
labels: Dict[str, Any],
return_type: Optional[Type[_ReturnType]] = None,
) -> None:
self.task_name = task_name
self.broker = broker
self.labels = labels
self.custom_task_id: Optional[str] = None
self.custom_schedule_id: Optional[str] = None
self.return_type = return_type

def with_labels(
self,
Expand Down Expand Up @@ -169,6 +172,7 @@ async def kiq(
return AsyncTaskiqTask(
task_id=message.task_id,
result_backend=self.broker.result_backend,
return_type=self.return_type, # type: ignore # (pyright issue)
)

async def schedule_by_cron(
Expand Down
19 changes: 17 additions & 2 deletions taskiq/task.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
from logging import getLogger
from time import time
from typing import TYPE_CHECKING, Any, Generic, Optional
from typing import TYPE_CHECKING, Any, Generic, Optional, Type

from typing_extensions import TypeVar

from taskiq.compat import parse_obj_as
from taskiq.exceptions import (
ResultGetError,
ResultIsReadyError,
Expand All @@ -15,6 +17,8 @@
from taskiq.depends.progress_tracker import TaskProgress
from taskiq.result import TaskiqResult

logger = getLogger("taskiq.task")

_ReturnType = TypeVar("_ReturnType")


Expand All @@ -25,9 +29,11 @@ def __init__(
self,
task_id: str,
result_backend: "AsyncResultBackend[_ReturnType]",
return_type: Optional[Type[_ReturnType]] = None,
) -> None:
self.task_id = task_id
self.result_backend = result_backend
self.return_type = return_type

async def is_ready(self) -> bool:
"""
Expand All @@ -53,10 +59,19 @@ async def get_result(self, with_logs: bool = False) -> "TaskiqResult[_ReturnType
:return: task's return value.
"""
try:
return await self.result_backend.get_result(
res = await self.result_backend.get_result(
self.task_id,
with_logs=with_logs,
)
if self.return_type is not None:
try:
res.return_value = parse_obj_as(
self.return_type,
res.return_value,
)
except ValueError:
logger.warning("Cannot parse return type into %s", self.return_type)
return res
except Exception as exc:
raise ResultGetError from exc

Expand Down
73 changes: 73 additions & 0 deletions tests/test_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import uuid
from dataclasses import dataclass
from typing import Dict, TypeVar

import pytest

from taskiq import serializers
from taskiq.abc import AsyncResultBackend
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
from taskiq.result.v1 import TaskiqResult
from taskiq.task import AsyncTaskiqTask

_ReturnType = TypeVar("_ReturnType")


class SerializingBackend(AsyncResultBackend[_ReturnType]):
def __init__(self, serializer: TaskiqSerializer) -> None:
self._serializer = serializer
self._results: Dict[str, bytes] = {}

async def set_result(
self,
task_id: str,
result: TaskiqResult[_ReturnType], # type: ignore
) -> None:
"""Set result with dumping."""
self._results[task_id] = self._serializer.dumpb(model_dump(result))

async def is_result_ready(self, task_id: str) -> bool:
"""Check if result is ready."""
return task_id in self._results

async def get_result(
self,
task_id: str,
with_logs: bool = False,
) -> TaskiqResult[_ReturnType]:
"""Get result with loading."""
data = self._results[task_id]
return model_validate(TaskiqResult, self._serializer.loadb(data))


@pytest.mark.parametrize(
"serializer",
[
serializers.MSGPackSerializer(),
serializers.CBORSerializer(),
serializers.PickleSerializer(),
serializers.JSONSerializer(),
],
)
@pytest.mark.anyio
async def test_res_parsing_success(serializer: TaskiqSerializer) -> None:
@dataclass
class MyResult:
name: str
age: int

res = MyResult(name="test", age=10)
res_back: AsyncResultBackend[MyResult] = SerializingBackend(serializer)
test_id = str(uuid.uuid4())
await res_back.set_result(
test_id,
TaskiqResult(
is_err=False,
return_value=res,
execution_time=0.0,
),
)
sent_task = AsyncTaskiqTask(test_id, res_back, MyResult)
parsed = await sent_task.wait_result()
assert isinstance(parsed.return_value, MyResult)
Loading