From c554e1099998aa77f75e0d34362145894dd19238 Mon Sep 17 00:00:00 2001 From: amirreza Date: Sat, 28 Dec 2024 17:54:05 +0330 Subject: [PATCH] fixed some type hints --- django_valkey/async_cache/client/default.py | 12 ++-- django_valkey/base.py | 77 ++++++++++++--------- django_valkey/base_client.py | 12 ++-- django_valkey/base_pool.py | 24 +++---- django_valkey/pool.py | 13 ++-- django_valkey/typing.py | 7 ++ 6 files changed, 81 insertions(+), 64 deletions(-) diff --git a/django_valkey/async_cache/client/default.py b/django_valkey/async_cache/client/default.py index 1b11d06..262bb86 100644 --- a/django_valkey/async_cache/client/default.py +++ b/django_valkey/async_cache/client/default.py @@ -1,5 +1,5 @@ import builtins -from collections.abc import Iterable, AsyncIterable +from collections.abc import Iterable, AsyncIterator import contextlib from contextlib import suppress from typing import Any, cast, TYPE_CHECKING @@ -454,7 +454,7 @@ async def amget( map_keys = {await self.make_key(k, version=version): k for k in keys} try: - results = await client.mget(*map_keys) + results: list[bytes] = await client.mget(*map_keys) except _main_exceptions as e: raise ConnectionInterrupted(connection=client) from e @@ -698,7 +698,7 @@ async def aiter_keys( itersize: int | None = None, client: AValkey | Any | None = None, version: int | None = None, - ) -> AsyncIterable[KeyT]: + ) -> AsyncIterator[KeyT]: """ Same as keys, but uses cursors for make memory efficient keys iteration. @@ -974,7 +974,7 @@ async def asrandmember( version: int | None = None, client: AValkey | Any | None = None, return_set: bool = True, - ) -> list | Any: + ) -> builtins.set | list | Any: client = await self._get_client(write=False, client=client) key = await self.make_key(key, version=version) result = await client.srandmember(key, count) @@ -1034,7 +1034,7 @@ async def asscan_iter( count: int = 10, version: int | None = None, client: AValkey | Any | None = None, - ) -> AsyncIterable[Any]: + ) -> AsyncIterator[Any]: if self._has_compression_enabled() and match: error_message = "Using match with compression is not supported." raise ValueError(error_message) @@ -1251,7 +1251,7 @@ async def ahmget( keys: list, version: int | None = None, client: AValkey | Any | None = None, - ) -> list: + ) -> list[Any]: client = await self._get_client(write=False, client=client) nkeys = [await self.make_key(key, version=version) for key in keys] try: diff --git a/django_valkey/base.py b/django_valkey/base.py index 6a29e14..032f19f 100644 --- a/django_valkey/base.py +++ b/django_valkey/base.py @@ -1,14 +1,21 @@ +import builtins +from collections.abc import Iterator, AsyncIterator, Callable import contextlib import functools import logging from asyncio import iscoroutinefunction -from typing import Any, TypeVar, Generic, Iterator, AsyncGenerator, Set, Callable +from typing import Any, TypeVar, Generic, TYPE_CHECKING from django.conf import settings from django.core.cache.backends.base import BaseCache from django.utils.module_loading import import_string from django_valkey.exceptions import ConnectionInterrupted +from django_valkey.typing import KeyT + +if TYPE_CHECKING: + from valkey.lock import Lock + from valkey.asyncio.lock import Lock as ALock Client = TypeVar("Client") Backend = TypeVar("Backend") @@ -91,19 +98,19 @@ def client(self) -> Client: return self._client @omit_exception - def mset(self, *args, **kwargs): + def mset(self, *args, **kwargs) -> bool: return self.client.mset(*args, **kwargs) @omit_exception - async def amset(self, *args, **kwargs): + async def amset(self, *args, **kwargs) -> bool: return await self.client.amset(*args, **kwargs) @omit_exception - def mget(self, *args, **kwargs): + def mget(self, *args, **kwargs) -> dict | list[Any]: return self.client.mget(*args, **kwargs) @omit_exception - async def amget(self, *args, **kwargs): + async def amget(self, *args, **kwargs) -> dict | list[Any]: return await self.client.amget(*args, **kwargs) @omit_exception @@ -147,13 +154,13 @@ async def apexpire_at(self, *args, **kwargs) -> bool: return await self.client.apexpire_at(*args, **kwargs) @omit_exception - def get_lock(self, *args, **kwargs): + def get_lock(self, *args, **kwargs) -> "Lock": return self.client.get_lock(*args, **kwargs) lock = get_lock @omit_exception - async def aget_lock(self, *args, **kwargs): + async def aget_lock(self, *args, **kwargs) -> "ALock": return await self.client.aget_lock(*args, **kwargs) alock = aget_lock @@ -169,7 +176,7 @@ async def adelete_pattern(self, *args, **kwargs) -> int: return await self.client.adelete_pattern(*args, **kwargs) @omit_exception - def ttl(self, *args, **kwargs) -> int | None: + def ttl(self, *args, **kwargs) -> int: return self.client.ttl(*args, **kwargs) @omit_exception @@ -185,29 +192,29 @@ async def apttl(self, *args, **kwargs) -> int: return await self.client.apttl(*args, **kwargs) @omit_exception - def iter_keys(self, *args, **kwargs) -> Iterator: + def iter_keys(self, *args, **kwargs) -> Iterator[KeyT]: return self.client.iter_keys(*args, **kwargs) @omit_exception - async def aiter_keys(self, *args, **kwargs) -> AsyncGenerator: + async def aiter_keys(self, *args, **kwargs) -> AsyncIterator[KeyT]: async with contextlib.aclosing(self.client.aiter_keys(*args, **kwargs)) as it: async for key in it: yield key @omit_exception - def keys(self, *args, **kwargs) -> list[Any]: + def keys(self, *args, **kwargs) -> list: return self.client.keys(*args, **kwargs) @omit_exception - async def akeys(self, *args, **kwargs) -> list[Any]: + async def akeys(self, *args, **kwargs) -> list: return await self.client.akeys(*args, **kwargs) @omit_exception - def scan(self, *args, **kwargs) -> tuple[int, list[str]]: + def scan(self, *args, **kwargs) -> tuple[int, list[Any]]: return self.client.scan(*args, **kwargs) @omit_exception - async def ascan(self, *args, **kwargs) -> tuple[int, list[str]]: + async def ascan(self, *args, **kwargs) -> tuple[int, list[Any]]: return await self.client.ascan(*args, **kwargs) @omit_exception @@ -227,11 +234,11 @@ async def ascard(self, *args, **kwargs) -> int: return await self.client.ascard(*args, **kwargs) @omit_exception - def sdiff(self, *args, **kwargs) -> Set[Any]: + def sdiff(self, *args, **kwargs) -> builtins.set[Any] | list[Any]: return self.client.sdiff(*args, **kwargs) @omit_exception - async def asdiff(self, *args, **kwargs) -> Set[Any]: + async def asdiff(self, *args, **kwargs) -> builtins.set[Any] | list[Any]: return await self.client.asdiff(*args, **kwargs) @omit_exception @@ -243,11 +250,11 @@ async def asdiffstore(self, *args, **kwargs) -> int: return await self.client.asdiffstore(*args, **kwargs) @omit_exception - def sinter(self, *args, **kwargs) -> Set[Any]: + def sinter(self, *args, **kwargs) -> builtins.set[Any] | list[Any]: return self.client.sinter(*args, **kwargs) @omit_exception - async def asinter(self, *args, **kwargs) -> Set[Any]: + async def asinter(self, *args, **kwargs) -> builtins.set[Any] | list[Any]: return await self.client.asinter(*args, **kwargs) @omit_exception @@ -283,11 +290,11 @@ async def asismember(self, *args, **kwargs) -> bool: return await self.client.asismember(*args, **kwargs) @omit_exception - def smembers(self, *args, **kwargs) -> Set[Any]: + def smembers(self, *args, **kwargs) -> builtins.set[Any] | list[Any]: return self.client.smembers(*args, **kwargs) @omit_exception - async def asmembers(self, *args, **kwargs) -> Set[Any]: + async def asmembers(self, *args, **kwargs) -> builtins.set[Any] | list[Any]: return await self.client.asmembers(*args, **kwargs) @omit_exception @@ -299,19 +306,19 @@ async def asmove(self, *args, **kwargs) -> bool: return await self.client.asmove(*args, **kwargs) @omit_exception - def spop(self, *args, **kwargs) -> Set | Any: + def spop(self, *args, **kwargs) -> builtins.set | list | Any: return self.client.spop(*args, **kwargs) @omit_exception - async def aspop(self, *args, **kwargs) -> Set | Any: + async def aspop(self, *args, **kwargs) -> builtins.set | list | Any: return await self.client.aspop(*args, **kwargs) @omit_exception - def srandmember(self, *args, **kwargs) -> list | Any: + def srandmember(self, *args, **kwargs) -> builtins.set | list | Any: return self.client.srandmember(*args, **kwargs) @omit_exception - async def asrandmember(self, *args, **kwargs) -> list | Any: + async def asrandmember(self, *args, **kwargs) -> builtins.set | list | Any: return await self.client.asrandmember(*args, **kwargs) @omit_exception @@ -323,29 +330,31 @@ async def asrem(self, *args, **kwargs) -> int: return await self.client.asrem(*args, **kwargs) @omit_exception - def sscan(self, *args, **kwargs) -> Set[Any]: + def sscan(self, *args, **kwargs) -> tuple[int, builtins.set[Any] | list[Any]]: return self.client.sscan(*args, **kwargs) @omit_exception - async def asscan(self, *args, **kwargs) -> Set[Any]: + async def asscan( + self, *args, **kwargs + ) -> tuple[int, builtins.set[Any] | list[Any]]: return await self.client.asscan(*args, **kwargs) @omit_exception - def sscan_iter(self, *args, **kwargs) -> Iterator: + def sscan_iter(self, *args, **kwargs) -> Iterator[Any]: return self.client.sscan_iter(*args, **kwargs) @omit_exception - async def asscan_iter(self, *args, **kwargs) -> AsyncGenerator: + async def asscan_iter(self, *args, **kwargs) -> AsyncIterator[Any]: async with contextlib.aclosing(self.client.asscan_iter(*args, **kwargs)) as it: async for key in it: yield key @omit_exception - def sunion(self, *args, **kwargs) -> Set[Any]: + def sunion(self, *args, **kwargs) -> builtins.set[Any] | list[Any]: return self.client.sunion(*args, **kwargs) @omit_exception - async def asunion(self, *args, **kwargs) -> Set[Any]: + async def asunion(self, *args, **kwargs) -> builtins.set[Any] | list[Any]: return await self.client.asunion(*args, **kwargs) @omit_exception @@ -389,19 +398,19 @@ async def ahdel_many(self, *args, **kwargs) -> int: return await self.client.ahdel_many(*args, **kwargs) @omit_exception - def hget(self, *args, **kwargs) -> bytes | None: + def hget(self, *args, **kwargs) -> Any | None: return self.client.hget(*args, **kwargs) @omit_exception - async def ahget(self, *args, **kwargs) -> str | None: + async def ahget(self, *args, **kwargs) -> Any | None: return await self.client.ahget(*args, **kwargs) @omit_exception - def hgetall(self, *args, **kwargs) -> dict: + def hgetall(self, *args, **kwargs) -> dict[str, Any]: return self.client.hgetall(*args, **kwargs) @omit_exception - async def ahgetall(self, *args, **kwargs) -> dict: + async def ahgetall(self, *args, **kwargs) -> dict[str, Any]: return await self.client.ahgetall(*args, **kwargs) @omit_exception diff --git a/django_valkey/base_client.py b/django_valkey/base_client.py index cbc9d5c..5fc7def 100644 --- a/django_valkey/base_client.py +++ b/django_valkey/base_client.py @@ -584,7 +584,7 @@ def mget( map_keys = {self.make_key(k, version=version): k for k in keys} try: - results = client.mget(map_keys) + results: list[bytes] = client.mget(map_keys) except _main_exceptions as e: raise ConnectionInterrupted(connection=client) from e @@ -1073,7 +1073,7 @@ def srandmember( version: int | None = None, client: Backend | Any | None = None, return_set: bool = True, - ) -> list | Any: + ) -> builtins.set | list | Any: client = self._get_client(write=False, client=client) key = self.make_key(key, version=version) @@ -1102,7 +1102,7 @@ def sscan( version: int | None = None, client: Backend | Any | None = None, return_set: bool = True, - ) -> tuple[int, builtins.set[Any]] | tuple[int, list[Any]]: + ) -> tuple[int, builtins.set[Any] | list[Any]]: if self._has_compression_enabled() and match: err_msg = "Using match with compression is not supported." raise ValueError(err_msg) @@ -1275,7 +1275,7 @@ def hdel_many( client: Backend | Any | None = None, ) -> int: client = self._get_client(write=True, client=client) - nkeys = [self.make_key(key) for key in keys] + nkeys = [self.make_key(key, version=version) for key in keys] return client.hdel(name, *nkeys) def hget( @@ -1315,7 +1315,7 @@ def hmget( keys: list, version: int | None = None, client: Backend | Any | None = None, - ) -> list: + ) -> list[Any]: client = self._get_client(write=False, client=client) nkeys = [self.make_key(key, version=version) for key in keys] try: @@ -1373,7 +1373,7 @@ def hkeys( self, name: str, client: Backend | Any | None = None, - ) -> list[Any]: + ) -> list[str]: """ Return a list of keys in hash name. """ diff --git a/django_valkey/base_pool.py b/django_valkey/base_pool.py index d753312..e7dd5c4 100644 --- a/django_valkey/base_pool.py +++ b/django_valkey/base_pool.py @@ -16,13 +16,13 @@ class BaseConnectionFactory(Generic[Base, Pool]): _pools: dict[str, Pool | Any] = {} - def __init__(self, options: dict): - pool_cls_path = options.get("CONNECTION_POOL_CLASS", self.path_pool_cls) - self.pool_cls: type[Pool] | type = import_string(pool_cls_path) + def __init__(self, options: dict) -> None: + pool_cls_path: str = options.get("CONNECTION_POOL_CLASS", self.path_pool_cls) + self.pool_cls: type[Pool] = import_string(pool_cls_path) self.pool_cls_kwargs = options.get("CONNECTION_POOL_KWARGS", {}) - base_client_cls_path = options.get("BASE_CLIENT_CLASS", self.path_base_cls) - self.base_client_cls: type[Base] | type = import_string(base_client_cls_path) + base_client_cls_path: str = options.get("BASE_CLIENT_CLASS", self.path_base_cls) + self.base_client_cls: type[Base] = import_string(base_client_cls_path) self.base_client_cls_kwargs = options.get("BASE_CLIENT_KWARGS", {}) self.options = options @@ -38,7 +38,7 @@ def make_connection_params(self, url: str | None) -> dict: "parser_class": self.get_parser_cls(), } - socket_timeout = self.options.get("SOCKET_TIMEOUT", None) + socket_timeout: int | float | None = self.options.get("SOCKET_TIMEOUT", None) # TODO: do we need to check for existence? if socket_timeout: if not isinstance(socket_timeout, (int, float)): @@ -46,20 +46,20 @@ def make_connection_params(self, url: str | None) -> dict: raise ImproperlyConfigured(error_message) kwargs["socket_timeout"] = socket_timeout - socket_connect_timeout = self.options.get("SOCKET_CONNECT_TIMEOUT", None) + socket_connect_timeout: int | float | None = self.options.get("SOCKET_CONNECT_TIMEOUT", None) if socket_connect_timeout: if not isinstance(socket_connect_timeout, (int, float)): error_message = "Socket connect timeout should be float or integer" raise ImproperlyConfigured(error_message) kwargs["socket_connect_timeout"] = socket_connect_timeout - password = self.options.get("PASSWORD", None) + password: str | None = self.options.get("PASSWORD", None) if password: kwargs["password"] = password return kwargs - def get_connection_pool(self, params: dict) -> Pool | Any: + def get_connection_pool(self, params: dict) -> Pool: """ Given a connection parameters, return a new connection pool for them. @@ -77,7 +77,7 @@ def get_connection_pool(self, params: dict) -> Pool | Any: return pool - def get_or_create_connection_pool(self, params: dict) -> Pool | Any: + def get_or_create_connection_pool(self, params: dict) -> Pool: """ Given a connection parameters and return a new or cached connection pool for them. @@ -90,7 +90,7 @@ def get_or_create_connection_pool(self, params: dict) -> Pool | Any: self._pools[key] = self.get_connection_pool(params) return self._pools[key] - def get_connection(self, params: dict) -> Base | Any: + def get_connection(self, params: dict) -> Base: """ Given a now preformatted params, return a new connection. @@ -100,7 +100,7 @@ def get_connection(self, params: dict) -> Base | Any: """ raise NotImplementedError - def connect(self, url: str) -> Base | Any: + def connect(self, url: str) -> Base: """ Given a basic connection parameters, return a new connection. diff --git a/django_valkey/pool.py b/django_valkey/pool.py index d69b60e..7c242a4 100644 --- a/django_valkey/pool.py +++ b/django_valkey/pool.py @@ -9,14 +9,15 @@ from valkey.sentinel import Sentinel from valkey._parsers.url_parser import to_bool -from django_valkey.base_pool import BaseConnectionFactory, Base +from django_valkey.base_pool import BaseConnectionFactory +from django_valkey.typing import DefaultParserT class ConnectionFactory(BaseConnectionFactory[Valkey, ConnectionPool]): path_pool_cls = "valkey.connection.ConnectionPool" path_base_cls = "valkey.client.Valkey" - def disconnect(self, connection: type[Valkey] | type) -> None: + def disconnect(self, connection: type[Valkey]) -> None: """ Given a not null client connection it disconnects from the Valkey server. @@ -24,17 +25,17 @@ def disconnect(self, connection: type[Valkey] | type) -> None: """ connection.connection_pool.disconnect() - def get_parser_cls(self) -> type[DefaultParser] | type: + def get_parser_cls(self) -> DefaultParserT: cls = self.options.get("PARSER_CLASS", None) if cls is None: return DefaultParser return import_string(cls) - def connect(self, url: str) -> Valkey | Any: + def connect(self, url: str) -> Valkey: params = self.make_connection_params(url) return self.get_connection(params) - def get_connection(self, params: dict) -> Base | Any: + def get_connection(self, params: dict) -> Valkey: pool = self.get_or_create_connection_pool(params) return self.base_client_cls(connection_pool=pool, **self.base_client_cls_kwargs) @@ -78,7 +79,7 @@ def get_connection_pool(self, params: dict) -> ConnectionPool | Any: # convert "is_master" to a boolean if set on the URL, otherwise if not # provided it defaults to True. - is_master: list[str] = parse_qs(url.query).get("is_master") + is_master: list[str] | None = parse_qs(url.query).get("is_master", None) if is_master: pool.is_master = to_bool(is_master[0]) diff --git a/django_valkey/typing.py b/django_valkey/typing.py index 451f4df..c4850fa 100644 --- a/django_valkey/typing.py +++ b/django_valkey/typing.py @@ -1 +1,8 @@ +from valkey._parsers import _RESP2Parser, _RESP3Parser, _LibvalkeyParser +from valkey.utils import LIBVALKEY_AVAILABLE + KeyT = int | float | str | bytes | memoryview + + +DefaultParserT = type[_RESP2Parser | _RESP3Parser | _LibvalkeyParser] +