Skip to content

Commit

Permalink
fixed some type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
amirreza8002 committed Dec 28, 2024
1 parent 8336b54 commit c7017d3
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 64 deletions.
12 changes: 6 additions & 6 deletions django_valkey/async_cache/client/default.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import builtins
from collections.abc import Iterable, AsyncIterable
from collections.abc import Iterable, AsyncIterator
import contextlib
from contextlib import suppress
from typing import Any, cast, TYPE_CHECKING
Expand Down Expand Up @@ -454,7 +454,7 @@ async def amget(
map_keys = {await self.make_key(k, version=version): k for k in keys}

try:
results = await client.mget(*map_keys)
results: list[bytes] = await client.mget(*map_keys)
except _main_exceptions as e:
raise ConnectionInterrupted(connection=client) from e

Expand Down Expand Up @@ -698,7 +698,7 @@ async def aiter_keys(
itersize: int | None = None,
client: AValkey | Any | None = None,
version: int | None = None,
) -> AsyncIterable[KeyT]:
) -> AsyncIterator[KeyT]:
"""
Same as keys, but uses cursors
for make memory efficient keys iteration.
Expand Down Expand Up @@ -974,7 +974,7 @@ async def asrandmember(
version: int | None = None,
client: AValkey | Any | None = None,
return_set: bool = True,
) -> list | Any:
) -> builtins.set | list | Any:
client = await self._get_client(write=False, client=client)
key = await self.make_key(key, version=version)
result = await client.srandmember(key, count)
Expand Down Expand Up @@ -1034,7 +1034,7 @@ async def asscan_iter(
count: int = 10,
version: int | None = None,
client: AValkey | Any | None = None,
) -> AsyncIterable[Any]:
) -> AsyncIterator[Any]:
if self._has_compression_enabled() and match:
error_message = "Using match with compression is not supported."
raise ValueError(error_message)
Expand Down Expand Up @@ -1251,7 +1251,7 @@ async def ahmget(
keys: list,
version: int | None = None,
client: AValkey | Any | None = None,
) -> list:
) -> list[Any]:
client = await self._get_client(write=False, client=client)
nkeys = [await self.make_key(key, version=version) for key in keys]
try:
Expand Down
77 changes: 43 additions & 34 deletions django_valkey/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import builtins
from collections.abc import Iterator, AsyncIterator, Callable
import contextlib
import functools
import logging
from asyncio import iscoroutinefunction
from typing import Any, TypeVar, Generic, Iterator, AsyncGenerator, Set, Callable
from typing import Any, TypeVar, Generic, TYPE_CHECKING

from django.conf import settings
from django.core.cache.backends.base import BaseCache
from django.utils.module_loading import import_string

from django_valkey.exceptions import ConnectionInterrupted
from django_valkey.typing import KeyT

if TYPE_CHECKING:
from valkey.lock import Lock
from valkey.asyncio.lock import Lock as ALock

Client = TypeVar("Client")
Backend = TypeVar("Backend")
Expand Down Expand Up @@ -91,19 +98,19 @@ def client(self) -> Client:
return self._client

@omit_exception
def mset(self, *args, **kwargs):
def mset(self, *args, **kwargs) -> bool:
return self.client.mset(*args, **kwargs)

@omit_exception
async def amset(self, *args, **kwargs):
async def amset(self, *args, **kwargs) -> bool:
return await self.client.amset(*args, **kwargs)

@omit_exception
def mget(self, *args, **kwargs):
def mget(self, *args, **kwargs) -> dict | list[Any]:
return self.client.mget(*args, **kwargs)

@omit_exception
async def amget(self, *args, **kwargs):
async def amget(self, *args, **kwargs) -> dict | list[Any]:
return await self.client.amget(*args, **kwargs)

@omit_exception
Expand Down Expand Up @@ -147,13 +154,13 @@ async def apexpire_at(self, *args, **kwargs) -> bool:
return await self.client.apexpire_at(*args, **kwargs)

@omit_exception
def get_lock(self, *args, **kwargs):
def get_lock(self, *args, **kwargs) -> "Lock":
return self.client.get_lock(*args, **kwargs)

lock = get_lock

@omit_exception
async def aget_lock(self, *args, **kwargs):
async def aget_lock(self, *args, **kwargs) -> "ALock":
return await self.client.aget_lock(*args, **kwargs)

alock = aget_lock
Expand All @@ -169,7 +176,7 @@ async def adelete_pattern(self, *args, **kwargs) -> int:
return await self.client.adelete_pattern(*args, **kwargs)

@omit_exception
def ttl(self, *args, **kwargs) -> int | None:
def ttl(self, *args, **kwargs) -> int:
return self.client.ttl(*args, **kwargs)

@omit_exception
Expand All @@ -185,29 +192,29 @@ async def apttl(self, *args, **kwargs) -> int:
return await self.client.apttl(*args, **kwargs)

@omit_exception
def iter_keys(self, *args, **kwargs) -> Iterator:
def iter_keys(self, *args, **kwargs) -> Iterator[KeyT]:
return self.client.iter_keys(*args, **kwargs)

@omit_exception
async def aiter_keys(self, *args, **kwargs) -> AsyncGenerator:
async def aiter_keys(self, *args, **kwargs) -> AsyncIterator[KeyT]:
async with contextlib.aclosing(self.client.aiter_keys(*args, **kwargs)) as it:
async for key in it:
yield key

@omit_exception
def keys(self, *args, **kwargs) -> list[Any]:
def keys(self, *args, **kwargs) -> list:
return self.client.keys(*args, **kwargs)

@omit_exception
async def akeys(self, *args, **kwargs) -> list[Any]:
async def akeys(self, *args, **kwargs) -> list:
return await self.client.akeys(*args, **kwargs)

@omit_exception
def scan(self, *args, **kwargs) -> tuple[int, list[str]]:
def scan(self, *args, **kwargs) -> tuple[int, list[Any]]:
return self.client.scan(*args, **kwargs)

@omit_exception
async def ascan(self, *args, **kwargs) -> tuple[int, list[str]]:
async def ascan(self, *args, **kwargs) -> tuple[int, list[Any]]:
return await self.client.ascan(*args, **kwargs)

@omit_exception
Expand All @@ -227,11 +234,11 @@ async def ascard(self, *args, **kwargs) -> int:
return await self.client.ascard(*args, **kwargs)

@omit_exception
def sdiff(self, *args, **kwargs) -> Set[Any]:
def sdiff(self, *args, **kwargs) -> builtins.set[Any] | list[Any]:
return self.client.sdiff(*args, **kwargs)

@omit_exception
async def asdiff(self, *args, **kwargs) -> Set[Any]:
async def asdiff(self, *args, **kwargs) -> builtins.set[Any] | list[Any]:
return await self.client.asdiff(*args, **kwargs)

@omit_exception
Expand All @@ -243,11 +250,11 @@ async def asdiffstore(self, *args, **kwargs) -> int:
return await self.client.asdiffstore(*args, **kwargs)

@omit_exception
def sinter(self, *args, **kwargs) -> Set[Any]:
def sinter(self, *args, **kwargs) -> builtins.set[Any] | list[Any]:
return self.client.sinter(*args, **kwargs)

@omit_exception
async def asinter(self, *args, **kwargs) -> Set[Any]:
async def asinter(self, *args, **kwargs) -> builtins.set[Any] | list[Any]:
return await self.client.asinter(*args, **kwargs)

@omit_exception
Expand Down Expand Up @@ -283,11 +290,11 @@ async def asismember(self, *args, **kwargs) -> bool:
return await self.client.asismember(*args, **kwargs)

@omit_exception
def smembers(self, *args, **kwargs) -> Set[Any]:
def smembers(self, *args, **kwargs) -> builtins.set[Any] | list[Any]:
return self.client.smembers(*args, **kwargs)

@omit_exception
async def asmembers(self, *args, **kwargs) -> Set[Any]:
async def asmembers(self, *args, **kwargs) -> builtins.set[Any] | list[Any]:
return await self.client.asmembers(*args, **kwargs)

@omit_exception
Expand All @@ -299,19 +306,19 @@ async def asmove(self, *args, **kwargs) -> bool:
return await self.client.asmove(*args, **kwargs)

@omit_exception
def spop(self, *args, **kwargs) -> Set | Any:
def spop(self, *args, **kwargs) -> builtins.set | list | Any:
return self.client.spop(*args, **kwargs)

@omit_exception
async def aspop(self, *args, **kwargs) -> Set | Any:
async def aspop(self, *args, **kwargs) -> builtins.set | list | Any:
return await self.client.aspop(*args, **kwargs)

@omit_exception
def srandmember(self, *args, **kwargs) -> list | Any:
def srandmember(self, *args, **kwargs) -> builtins.set | list | Any:
return self.client.srandmember(*args, **kwargs)

@omit_exception
async def asrandmember(self, *args, **kwargs) -> list | Any:
async def asrandmember(self, *args, **kwargs) -> builtins.set | list | Any:
return await self.client.asrandmember(*args, **kwargs)

@omit_exception
Expand All @@ -323,29 +330,31 @@ async def asrem(self, *args, **kwargs) -> int:
return await self.client.asrem(*args, **kwargs)

@omit_exception
def sscan(self, *args, **kwargs) -> Set[Any]:
def sscan(self, *args, **kwargs) -> tuple[int, builtins.set[Any] | list[Any]]:
return self.client.sscan(*args, **kwargs)

@omit_exception
async def asscan(self, *args, **kwargs) -> Set[Any]:
async def asscan(
self, *args, **kwargs
) -> tuple[int, builtins.set[Any] | list[Any]]:
return await self.client.asscan(*args, **kwargs)

@omit_exception
def sscan_iter(self, *args, **kwargs) -> Iterator:
def sscan_iter(self, *args, **kwargs) -> Iterator[Any]:
return self.client.sscan_iter(*args, **kwargs)

@omit_exception
async def asscan_iter(self, *args, **kwargs) -> AsyncGenerator:
async def asscan_iter(self, *args, **kwargs) -> AsyncIterator[Any]:
async with contextlib.aclosing(self.client.asscan_iter(*args, **kwargs)) as it:
async for key in it:
yield key

@omit_exception
def sunion(self, *args, **kwargs) -> Set[Any]:
def sunion(self, *args, **kwargs) -> builtins.set[Any] | list[Any]:
return self.client.sunion(*args, **kwargs)

@omit_exception
async def asunion(self, *args, **kwargs) -> Set[Any]:
async def asunion(self, *args, **kwargs) -> builtins.set[Any] | list[Any]:
return await self.client.asunion(*args, **kwargs)

@omit_exception
Expand Down Expand Up @@ -389,19 +398,19 @@ async def ahdel_many(self, *args, **kwargs) -> int:
return await self.client.ahdel_many(*args, **kwargs)

@omit_exception
def hget(self, *args, **kwargs) -> bytes | None:
def hget(self, *args, **kwargs) -> Any | None:
return self.client.hget(*args, **kwargs)

@omit_exception
async def ahget(self, *args, **kwargs) -> str | None:
async def ahget(self, *args, **kwargs) -> Any | None:
return await self.client.ahget(*args, **kwargs)

@omit_exception
def hgetall(self, *args, **kwargs) -> dict:
def hgetall(self, *args, **kwargs) -> dict[str, Any]:
return self.client.hgetall(*args, **kwargs)

@omit_exception
async def ahgetall(self, *args, **kwargs) -> dict:
async def ahgetall(self, *args, **kwargs) -> dict[str, Any]:
return await self.client.ahgetall(*args, **kwargs)

@omit_exception
Expand Down
12 changes: 6 additions & 6 deletions django_valkey/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,7 @@ def mget(
map_keys = {self.make_key(k, version=version): k for k in keys}

try:
results = client.mget(map_keys)
results: list[bytes] = client.mget(map_keys)
except _main_exceptions as e:
raise ConnectionInterrupted(connection=client) from e

Expand Down Expand Up @@ -1073,7 +1073,7 @@ def srandmember(
version: int | None = None,
client: Backend | Any | None = None,
return_set: bool = True,
) -> list | Any:
) -> builtins.set | list | Any:
client = self._get_client(write=False, client=client)

key = self.make_key(key, version=version)
Expand Down Expand Up @@ -1102,7 +1102,7 @@ def sscan(
version: int | None = None,
client: Backend | Any | None = None,
return_set: bool = True,
) -> tuple[int, builtins.set[Any]] | tuple[int, list[Any]]:
) -> tuple[int, builtins.set[Any] | list[Any]]:
if self._has_compression_enabled() and match:
err_msg = "Using match with compression is not supported."
raise ValueError(err_msg)
Expand Down Expand Up @@ -1275,7 +1275,7 @@ def hdel_many(
client: Backend | Any | None = None,
) -> int:
client = self._get_client(write=True, client=client)
nkeys = [self.make_key(key) for key in keys]
nkeys = [self.make_key(key, version=version) for key in keys]
return client.hdel(name, *nkeys)

def hget(
Expand Down Expand Up @@ -1315,7 +1315,7 @@ def hmget(
keys: list,
version: int | None = None,
client: Backend | Any | None = None,
) -> list:
) -> list[Any]:
client = self._get_client(write=False, client=client)
nkeys = [self.make_key(key, version=version) for key in keys]
try:
Expand Down Expand Up @@ -1373,7 +1373,7 @@ def hkeys(
self,
name: str,
client: Backend | Any | None = None,
) -> list[Any]:
) -> list[str]:
"""
Return a list of keys in hash name.
"""
Expand Down
Loading

0 comments on commit c7017d3

Please sign in to comment.