diff --git a/django_valkey/async_cache/client/default.py b/django_valkey/async_cache/client/default.py index e5882e7..915daec 100644 --- a/django_valkey/async_cache/client/default.py +++ b/django_valkey/async_cache/client/default.py @@ -806,11 +806,17 @@ async def scard( ascard = scard async def sdiff( - self, *keys, version: int | None = None, client: AValkey | Any | None = None - ) -> Set[Any]: + self, + *keys, + version: int | None = None, + client: AValkey | Any | None = None, + return_set: bool = True, + ) -> Set[Any] | list[Any]: client = await self._get_client(write=False, client=client) nkeys = [await self.make_key(key, version=version) for key in keys] - return {await self.decode(value) for value in await client.sdiff(*nkeys)} + return await self._decode_iterable_result( + await client.sdiff(*nkeys), convert_to_set=return_set + ) asdiff = sdiff @@ -830,11 +836,17 @@ async def sdiffstore( asdiffstore = sdiffstore async def sinter( - self, *keys, version: int | None = None, client: AValkey | Any | None = None + self, + *keys, + version: int | None = None, + client: AValkey | Any | None = None, + return_set: bool = True, ) -> Set[Any]: client = await self._get_client(write=False, client=client) nkeys = [await self.make_key(key, version=version) for key in keys] - return {await self.decode(value) for value in await client.sinter(*nkeys)} + return await self._decode_iterable_result( + await client.sinter(*nkeys), convert_to_set=return_set + ) asinter = sinter @@ -899,12 +911,18 @@ async def sismember( asismember = sismember async def smembers( - self, key, version: int | None = None, client: AValkey | Any | None = None - ) -> Set[Any]: + self, + key, + version: int | None = None, + client: AValkey | Any | None = None, + return_set: bool = True, + ) -> Set[Any] | list[Any]: client = await self._get_client(write=False, client=client) key = await self.make_key(key, version=version) - return {await self.decode(value) for value in await client.smembers(key)} + return await self._decode_iterable_result( + await client.smembers(key), convert_to_set=return_set + ) asmembers = smembers @@ -930,11 +948,12 @@ async def spop( count: int | None = None, version: int | None = None, client: AValkey | Any | None = None, - ) -> Set | Any: + return_set: bool = True, + ) -> Set | list | Any: client = await self._get_client(write=True, client=client) nkey = await self.make_key(key, version=version) result = await client.spop(nkey, count) - return await self._decode_iterable_result(result) + return await self._decode_iterable_result(result, convert_to_set=return_set) aspop = spop @@ -944,11 +963,12 @@ async def srandmember( count: int | None = None, version: int | None = None, client: AValkey | Any | None = None, + return_set: bool = True, ) -> list | Any: client = await self._get_client(write=False, client=client) key = await self.make_key(key, version=version) result = await client.srandmember(key, count) - return await self._decode_iterable_result(result, convert_to_set=False) + return await self._decode_iterable_result(result, convert_to_set=return_set) asrandmember = srandmember @@ -1030,11 +1050,14 @@ async def sunion( *keys, version: int | None = None, client: AValkey | Any | None = None, - ) -> Set[Any]: + retrun_set: bool = True, + ) -> Set[Any] | list[Any]: client = await self._get_client(write=False, client=client) nkeys = [await self.make_key(key, version=version) for key in keys] - return {await self.decode(value) for value in await client.sunion(*nkeys)} + return await self._decode_iterable_result( + await client.sunion(*nkeys), convert_to_set=retrun_set + ) asunion = sunion diff --git a/django_valkey/base_client.py b/django_valkey/base_client.py index 146dec2..858d1c4 100644 --- a/django_valkey/base_client.py +++ b/django_valkey/base_client.py @@ -553,7 +553,7 @@ def encode(self, value: EncodableT) -> bytes | int | float: def _decode_iterable_result( self, result: Any, convert_to_set: bool = True - ) -> List[Any] | Any | None: + ) -> list[Any] | Set[Any] | Any | None: if result is None: return None if isinstance(result, list): @@ -931,11 +931,14 @@ def sdiff( *keys: KeyT, version: int | None = None, client: Backend | Any | None = None, - ) -> Set[Any]: + return_set: bool = True, + ) -> Set[Any] | list[Any]: client = self._get_client(write=False, client=client) nkeys = [self.make_key(key, version=version) for key in keys] - return {self.decode(value) for value in client.sdiff(*nkeys)} + return self._decode_iterable_result( + client.sdiff(*nkeys), convert_to_set=return_set + ) def sdiffstore( self, @@ -956,11 +959,14 @@ def sinter( *keys: KeyT, version: int | None = None, client: Backend | Any | None = None, - ) -> Set[Any]: + return_set: bool = True, + ) -> Set[Any] | list[Any]: client = self._get_client(write=False, client=client) nkeys = [self.make_key(key, version=version) for key in keys] - return {self.decode(value) for value in client.sinter(*nkeys)} + return self._decode_iterable_result( + client.sinter(*nkeys), convert_to_set=return_set + ) def sintercard( self, @@ -993,12 +999,11 @@ def smismember( *members: Any, version: int | None = None, client: Backend | Any | None = None, - ) -> List[bool]: + ) -> list[bool]: client = self._get_client(write=False, client=client) key = self.make_key(key, version=version) encoded_members = [self.encode(member) for member in members] - return [bool(value) for value in client.smismember(key, *encoded_members)] def sismember( @@ -1019,11 +1024,14 @@ def smembers( key: KeyT, version: int | None = None, client: Backend | Any | None = None, - ) -> Set[Any]: + return_set: bool = True, + ) -> Set[Any] | list[Any]: client = self._get_client(write=False, client=client) key = self.make_key(key, version=version) - return {self.decode(value) for value in client.smembers(key)} + return self._decode_iterable_result( + client.smembers(key), convert_to_set=return_set + ) def smove( self, @@ -1046,12 +1054,13 @@ def spop( count: int | None = None, version: int | None = None, client: Backend | Any | None = None, - ) -> Set | Any: + return_set: bool = True, + ) -> Set | list | Any: client = self._get_client(write=True, client=client) nkey = self.make_key(key, version=version) result = client.spop(nkey, count) - return self._decode_iterable_result(result) + return self._decode_iterable_result(result, convert_to_set=return_set) def srandmember( self, @@ -1059,12 +1068,13 @@ def srandmember( count: int | None = None, version: int | None = None, client: Backend | Any | None = None, - ) -> List | Any: + return_set: bool = True, + ) -> list | Any: client = self._get_client(write=False, client=client) key = self.make_key(key, version=version) result = client.srandmember(key, count) - return self._decode_iterable_result(result, convert_to_set=False) + return self._decode_iterable_result(result, convert_to_set=return_set) def srem( self, @@ -1132,11 +1142,14 @@ def sunion( *keys: KeyT, version: int | None = None, client: Backend | Any | None = None, - ) -> Set[Any]: + return_set: bool = True, + ) -> Set[Any] | list[Any]: client = self._get_client(write=False, client=client) nkeys = [self.make_key(key, version=version) for key in keys] - return {self.decode(value) for value in client.sunion(*nkeys)} + return self._decode_iterable_result( + client.sunion(*nkeys), convert_to_set=return_set + ) def sunionstore( self, diff --git a/django_valkey/client/sharded.py b/django_valkey/client/sharded.py index 80795ef..7df0131 100644 --- a/django_valkey/client/sharded.py +++ b/django_valkey/client/sharded.py @@ -468,11 +468,14 @@ def smembers( key: KeyT, version: int | None = None, client: Valkey | Any | None = None, - ) -> Set[Any]: + return_set: bool = True, + ) -> Set[Any] | list[Any]: if client is None: key = self.make_key(key, version=version) client = self.get_server(key) - return super().smembers(key=key, version=version, client=client) + return super().smembers( + key=key, version=version, client=client, return_set=return_set + ) def smove( self, @@ -551,11 +554,14 @@ def srandmember( count: int | None = None, version: int | None = None, client: Valkey | Any | None = None, - ) -> Set | Any: + return_set: bool = True, + ) -> Set | list | Any: if client is None: key = self.make_key(key, version=version) client = self.get_server(key) - return super().srandmember(key=key, count=count, version=version, client=client) + return super().srandmember( + key=key, count=count, version=version, client=client, return_set=return_set + ) def sismember( self, @@ -575,11 +581,14 @@ def spop( count: int | None = None, version: int | None = None, client: Valkey | Any | None = None, - ) -> Set | Any: + return_set: bool = True, + ) -> Set | list | Any: if client is None: key = self.make_key(key, version=version) client = self.get_server(key) - return super().spop(key=key, count=count, version=version, client=client) + return super().spop( + key=key, count=count, version=version, client=client, return_set=return_set + ) def smismember( self, @@ -587,7 +596,7 @@ def smismember( *members, version: int | None = None, client: Valkey | Any | None = None, - ) -> List[bool]: + ) -> list[bool]: if client is None: key = self.make_key(key, version=version) client = self.get_server(key) diff --git a/tests/test_backend.py b/tests/test_backend.py index 13b9c5e..bb8890e 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1238,7 +1238,8 @@ def test_srandmember_default_count(self, cache: ValkeyCache): def test_srandmember(self, cache: ValkeyCache): cache.sadd("foo", "bar1", "bar2") - assert cache.srandmember("foo", 1) in [["bar1"], ["bar2"]] + assert cache.srandmember("foo", 1) in [{"bar1"}, {"bar2"}] + assert cache.srandmember("foo", 1, return_set=False) in [["bar1"], ["bar2"]] def test_srem(self, cache: ValkeyCache): cache.sadd("foo", "bar1", "bar2") diff --git a/tests/tests_async/test_backend.py b/tests/tests_async/test_backend.py index d2f1897..9ab7bed 100644 --- a/tests/tests_async/test_backend.py +++ b/tests/tests_async/test_backend.py @@ -1141,7 +1141,11 @@ async def test_srandmember_default_count(self, cache: AsyncValkeyCache): async def test_srandmember(self, cache: AsyncValkeyCache): await cache.asadd("foo", "bar1", "bar2") - assert await cache.asrandmember("foo", 1) in [["bar1"], ["bar2"]] + assert await cache.asrandmember("foo", 1, return_set=False) in [ + ["bar1"], + ["bar2"], + ] + assert await cache.asrandmember("foo", 1) in [{"bar1"}, {"bar2"}] async def test_srem(self, cache: AsyncValkeyCache): await cache.asadd("foo", "bar1", "bar2")