Skip to content

Commit

Permalink
set methods can now return a set or a lis, depending on user's choice
Browse files Browse the repository at this point in the history
the `return_set` parameters has been added to the methods that return a value, so it can either return a set (default) or return a list (valkey's default behaviour)
  • Loading branch information
amirreza8002 committed Dec 23, 2024
1 parent 6fcbecf commit adddf8e
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 37 deletions.
49 changes: 36 additions & 13 deletions django_valkey/async_cache/client/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,11 +806,17 @@ async def scard(
ascard = scard

async def sdiff(
self, *keys, version: int | None = None, client: AValkey | Any | None = None
) -> Set[Any]:
self,
*keys,
version: int | None = None,
client: AValkey | Any | None = None,
return_set: bool = True,
) -> Set[Any] | list[Any]:
client = await self._get_client(write=False, client=client)
nkeys = [await self.make_key(key, version=version) for key in keys]
return {await self.decode(value) for value in await client.sdiff(*nkeys)}
return await self._decode_iterable_result(
await client.sdiff(*nkeys), convert_to_set=return_set
)

asdiff = sdiff

Expand All @@ -830,11 +836,17 @@ async def sdiffstore(
asdiffstore = sdiffstore

async def sinter(
self, *keys, version: int | None = None, client: AValkey | Any | None = None
self,
*keys,
version: int | None = None,
client: AValkey | Any | None = None,
return_set: bool = True,
) -> Set[Any]:
client = await self._get_client(write=False, client=client)
nkeys = [await self.make_key(key, version=version) for key in keys]
return {await self.decode(value) for value in await client.sinter(*nkeys)}
return await self._decode_iterable_result(
await client.sinter(*nkeys), convert_to_set=return_set
)

asinter = sinter

Expand Down Expand Up @@ -899,12 +911,18 @@ async def sismember(
asismember = sismember

async def smembers(
self, key, version: int | None = None, client: AValkey | Any | None = None
) -> Set[Any]:
self,
key,
version: int | None = None,
client: AValkey | Any | None = None,
return_set: bool = True,
) -> Set[Any] | list[Any]:
client = await self._get_client(write=False, client=client)

key = await self.make_key(key, version=version)
return {await self.decode(value) for value in await client.smembers(key)}
return await self._decode_iterable_result(
await client.smembers(key), convert_to_set=return_set
)

asmembers = smembers

Expand All @@ -930,11 +948,12 @@ async def spop(
count: int | None = None,
version: int | None = None,
client: AValkey | Any | None = None,
) -> Set | Any:
return_set: bool = True,
) -> Set | list | Any:
client = await self._get_client(write=True, client=client)
nkey = await self.make_key(key, version=version)
result = await client.spop(nkey, count)
return await self._decode_iterable_result(result)
return await self._decode_iterable_result(result, convert_to_set=return_set)

aspop = spop

Expand All @@ -944,11 +963,12 @@ async def srandmember(
count: int | None = None,
version: int | None = None,
client: AValkey | Any | None = None,
return_set: bool = True,
) -> list | Any:
client = await self._get_client(write=False, client=client)
key = await self.make_key(key, version=version)
result = await client.srandmember(key, count)
return await self._decode_iterable_result(result, convert_to_set=False)
return await self._decode_iterable_result(result, convert_to_set=return_set)

asrandmember = srandmember

Expand Down Expand Up @@ -1030,11 +1050,14 @@ async def sunion(
*keys,
version: int | None = None,
client: AValkey | Any | None = None,
) -> Set[Any]:
retrun_set: bool = True,
) -> Set[Any] | list[Any]:
client = await self._get_client(write=False, client=client)

nkeys = [await self.make_key(key, version=version) for key in keys]
return {await self.decode(value) for value in await client.sunion(*nkeys)}
return await self._decode_iterable_result(
await client.sunion(*nkeys), convert_to_set=retrun_set
)

asunion = sunion

Expand Down
43 changes: 28 additions & 15 deletions django_valkey/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def encode(self, value: EncodableT) -> bytes | int | float:

def _decode_iterable_result(
self, result: Any, convert_to_set: bool = True
) -> List[Any] | Any | None:
) -> list[Any] | Set[Any] | Any | None:
if result is None:
return None
if isinstance(result, list):
Expand Down Expand Up @@ -931,11 +931,14 @@ def sdiff(
*keys: KeyT,
version: int | None = None,
client: Backend | Any | None = None,
) -> Set[Any]:
return_set: bool = True,
) -> Set[Any] | list[Any]:
client = self._get_client(write=False, client=client)

nkeys = [self.make_key(key, version=version) for key in keys]
return {self.decode(value) for value in client.sdiff(*nkeys)}
return self._decode_iterable_result(
client.sdiff(*nkeys), convert_to_set=return_set
)

def sdiffstore(
self,
Expand All @@ -956,11 +959,14 @@ def sinter(
*keys: KeyT,
version: int | None = None,
client: Backend | Any | None = None,
) -> Set[Any]:
return_set: bool = True,
) -> Set[Any] | list[Any]:
client = self._get_client(write=False, client=client)

nkeys = [self.make_key(key, version=version) for key in keys]
return {self.decode(value) for value in client.sinter(*nkeys)}
return self._decode_iterable_result(
client.sinter(*nkeys), convert_to_set=return_set
)

def sintercard(
self,
Expand Down Expand Up @@ -993,12 +999,11 @@ def smismember(
*members: Any,
version: int | None = None,
client: Backend | Any | None = None,
) -> List[bool]:
) -> list[bool]:
client = self._get_client(write=False, client=client)

key = self.make_key(key, version=version)
encoded_members = [self.encode(member) for member in members]

return [bool(value) for value in client.smismember(key, *encoded_members)]

def sismember(
Expand All @@ -1019,11 +1024,14 @@ def smembers(
key: KeyT,
version: int | None = None,
client: Backend | Any | None = None,
) -> Set[Any]:
return_set: bool = True,
) -> Set[Any] | list[Any]:
client = self._get_client(write=False, client=client)

key = self.make_key(key, version=version)
return {self.decode(value) for value in client.smembers(key)}
return self._decode_iterable_result(
client.smembers(key), convert_to_set=return_set
)

def smove(
self,
Expand All @@ -1046,25 +1054,27 @@ def spop(
count: int | None = None,
version: int | None = None,
client: Backend | Any | None = None,
) -> Set | Any:
return_set: bool = True,
) -> Set | list | Any:
client = self._get_client(write=True, client=client)

nkey = self.make_key(key, version=version)
result = client.spop(nkey, count)
return self._decode_iterable_result(result)
return self._decode_iterable_result(result, convert_to_set=return_set)

def srandmember(
self,
key: KeyT,
count: int | None = None,
version: int | None = None,
client: Backend | Any | None = None,
) -> List | Any:
return_set: bool = True,
) -> list | Any:
client = self._get_client(write=False, client=client)

key = self.make_key(key, version=version)
result = client.srandmember(key, count)
return self._decode_iterable_result(result, convert_to_set=False)
return self._decode_iterable_result(result, convert_to_set=return_set)

def srem(
self,
Expand Down Expand Up @@ -1132,11 +1142,14 @@ def sunion(
*keys: KeyT,
version: int | None = None,
client: Backend | Any | None = None,
) -> Set[Any]:
return_set: bool = True,
) -> Set[Any] | list[Any]:
client = self._get_client(write=False, client=client)

nkeys = [self.make_key(key, version=version) for key in keys]
return {self.decode(value) for value in client.sunion(*nkeys)}
return self._decode_iterable_result(
client.sunion(*nkeys), convert_to_set=return_set
)

def sunionstore(
self,
Expand Down
23 changes: 16 additions & 7 deletions django_valkey/client/sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,11 +468,14 @@ def smembers(
key: KeyT,
version: int | None = None,
client: Valkey | Any | None = None,
) -> Set[Any]:
return_set: bool = True,
) -> Set[Any] | list[Any]:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().smembers(key=key, version=version, client=client)
return super().smembers(
key=key, version=version, client=client, return_set=return_set
)

def smove(
self,
Expand Down Expand Up @@ -551,11 +554,14 @@ def srandmember(
count: int | None = None,
version: int | None = None,
client: Valkey | Any | None = None,
) -> Set | Any:
return_set: bool = True,
) -> Set | list | Any:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().srandmember(key=key, count=count, version=version, client=client)
return super().srandmember(
key=key, count=count, version=version, client=client, returnSet=return_set
)

def sismember(
self,
Expand All @@ -575,19 +581,22 @@ def spop(
count: int | None = None,
version: int | None = None,
client: Valkey | Any | None = None,
) -> Set | Any:
return_set: bool = True,
) -> Set | list | Any:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
return super().spop(key=key, count=count, version=version, client=client)
return super().spop(
key=key, count=count, version=version, client=client, return_set=return_set
)

def smismember(
self,
key: KeyT,
*members,
version: int | None = None,
client: Valkey | Any | None = None,
) -> List[bool]:
) -> list[bool]:
if client is None:
key = self.make_key(key, version=version)
client = self.get_server(key)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,8 @@ def test_srandmember_default_count(self, cache: ValkeyCache):

def test_srandmember(self, cache: ValkeyCache):
cache.sadd("foo", "bar1", "bar2")
assert cache.srandmember("foo", 1) in [["bar1"], ["bar2"]]
assert cache.srandmember("foo", 1) in [{"bar1"}, {"bar2"}]
assert cache.srandmember("foo", 1, return_set=False) in [["bar1"], ["bar2"]]

def test_srem(self, cache: ValkeyCache):
cache.sadd("foo", "bar1", "bar2")
Expand Down
6 changes: 5 additions & 1 deletion tests/tests_async/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,11 @@ async def test_srandmember_default_count(self, cache: AsyncValkeyCache):

async def test_srandmember(self, cache: AsyncValkeyCache):
await cache.asadd("foo", "bar1", "bar2")
assert await cache.asrandmember("foo", 1) in [["bar1"], ["bar2"]]
assert await cache.asrandmember("foo", 1, return_set=False) in [
["bar1"],
["bar2"],
]
assert await cache.asrandmember("foo", 1) in [{"bar1"}, {"bar2"}]

async def test_srem(self, cache: AsyncValkeyCache):
await cache.asadd("foo", "bar1", "bar2")
Expand Down

0 comments on commit adddf8e

Please sign in to comment.