From d072bcc13f982e0b53d00d8d63e3b7ce5c417fe3 Mon Sep 17 00:00:00 2001 From: Dima Tisnek Date: Tue, 26 Nov 2024 17:20:09 +0900 Subject: [PATCH] chore: pull facade type improvements from #1104 --- juju/client/facade.py | 68 ++++++++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/juju/client/facade.py b/juju/client/facade.py index f0ee75130..7ab4fddd0 100644 --- a/juju/client/facade.py +++ b/juju/client/facade.py @@ -14,7 +14,7 @@ from collections import defaultdict from glob import glob from pathlib import Path -from typing import Any, Mapping, Sequence +from typing import Any, Mapping, Sequence, TypeVar, overload import packaging.version import typing_inspect @@ -183,7 +183,7 @@ def ref_type(self, obj): return self.get_ref_type(obj["$ref"]) -CLASSES = {} +CLASSES: dict[str, type[Type]] = {} factories = codegen.Capture() @@ -479,37 +479,48 @@ def ReturnMapping(cls): # noqa: N802 def decorator(f): @functools.wraps(f) async def wrapper(*args, **kwargs): - nonlocal cls reply = await f(*args, **kwargs) - if cls is None: - return reply - if "error" in reply: - cls = CLASSES["Error"] - if typing_inspect.is_generic_type(cls) and issubclass( - typing_inspect.get_origin(cls), Sequence - ): - parameters = typing_inspect.get_parameters(cls) - result = [] - item_cls = parameters[0] - for item in reply: - result.append(item_cls.from_json(item)) - """ - if 'error' in item: - cls = CLASSES['Error'] - else: - cls = item_cls - result.append(cls.from_json(item)) - """ - else: - result = cls.from_json(reply["response"]) - - return result + return _convert_response(reply, cls=cls) return wrapper return decorator +@overload +def _convert_response(response: dict[str, Any], *, cls: type[SomeType]) -> SomeType: ... + + +@overload +def _convert_response(response: dict[str, Any], *, cls: None) -> dict[str, Any]: ... + + +def _convert_response(response: dict[str, Any], *, cls: type[Type] | None) -> Any: + if cls is None: + return response + if "error" in response: + cls = CLASSES["Error"] + if typing_inspect.is_generic_type(cls) and issubclass( + typing_inspect.get_origin(cls), Sequence + ): + parameters = typing_inspect.get_parameters(cls) + result = [] + item_cls = parameters[0] + for item in response: + result.append(item_cls.from_json(item)) + """ + if 'error' in item: + cls = CLASSES['Error'] + else: + cls = item_cls + result.append(cls.from_json(item)) + """ + else: + result = cls.from_json(response["response"]) + + return result + + def make_func(cls, name, description, params, result, _async=True): indent = " " args = Args(cls.schema, params) @@ -663,7 +674,7 @@ async def rpc(self, msg: dict[str, _RichJson]) -> _Json: return result @classmethod - def from_json(cls, data): + def from_json(cls, data: Type | str | dict[str, Any] | list[Any]) -> Type | None: def _parse_nested_list_entry(expr, result_dict): if isinstance(expr, str): if ">" in expr or ">=" in expr: @@ -742,6 +753,9 @@ def get(self, key, default=None): return getattr(self, attr, default) +SomeType = TypeVar("SomeType", bound=Type) + + class Schema(dict): def __init__(self, schema): self.name = schema["Name"]