Skip to content

Commit

Permalink
Merge pull request #743 from rootart/issue/597
Browse files Browse the repository at this point in the history
Add support for the set functions from issue #597
  • Loading branch information
WisdomPill authored Jun 16, 2024
2 parents ce47b30 + f34935c commit e11150a
Show file tree
Hide file tree
Showing 5 changed files with 657 additions and 3 deletions.
1 change: 1 addition & 0 deletions changelog.d/730.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support for sets and support basic operations, sadd, scard, sdiff, sdiffstore, sinter, sinterstore, smismember, sismember, smembers, smove, spop, srandmember, srem, sscan, sscan_iter, sunion, sunionstore
68 changes: 68 additions & 0 deletions django_redis/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,74 @@ def close(self, **kwargs):
def touch(self, *args, **kwargs):
return self.client.touch(*args, **kwargs)

@omit_exception
def sadd(self, *args, **kwargs):
return self.client.sadd(*args, **kwargs)

@omit_exception
def scard(self, *args, **kwargs):
return self.client.scard(*args, **kwargs)

@omit_exception
def sdiff(self, *args, **kwargs):
return self.client.sdiff(*args, **kwargs)

@omit_exception
def sdiffstore(self, *args, **kwargs):
return self.client.sdiffstore(*args, **kwargs)

@omit_exception
def sinter(self, *args, **kwargs):
return self.client.sinter(*args, **kwargs)

@omit_exception
def sinterstore(self, *args, **kwargs):
return self.client.sinterstore(*args, **kwargs)

@omit_exception
def sismember(self, *args, **kwargs):
return self.client.sismember(*args, **kwargs)

@omit_exception
def smembers(self, *args, **kwargs):
return self.client.smembers(*args, **kwargs)

@omit_exception
def smove(self, *args, **kwargs):
return self.client.smove(*args, **kwargs)

@omit_exception
def spop(self, *args, **kwargs):
return self.client.spop(*args, **kwargs)

@omit_exception
def srandmember(self, *args, **kwargs):
return self.client.srandmember(*args, **kwargs)

@omit_exception
def srem(self, *args, **kwargs):
return self.client.srem(*args, **kwargs)

@omit_exception
def sscan(self, *args, **kwargs):
return self.client.sscan(*args, **kwargs)

@omit_exception
def sscan_iter(self, *args, **kwargs):
return self.client.sscan_iter(*args, **kwargs)

@omit_exception
def smismember(self, *args, **kwargs):
return self.client.smismember(*args, **kwargs)

@omit_exception
def sunion(self, *args, **kwargs):
return self.client.sunion(*args, **kwargs)

@omit_exception
def sunionstore(self, *args, **kwargs):
return self.client.sunionstore(*args, **kwargs)

@omit_exception
def hset(self, *args, **kwargs):
return self.client.hset(*args, **kwargs)
Expand Down
285 changes: 283 additions & 2 deletions django_redis/client/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,26 @@
import socket
from collections import OrderedDict
from contextlib import suppress
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
cast,
)

from django.conf import settings
from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func
from django.core.exceptions import ImproperlyConfigured
from django.utils.module_loading import import_string
from redis import Redis
from redis.exceptions import ConnectionError, ResponseError, TimeoutError
from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT
from redis.typing import AbsExpiryT, EncodableT, ExpiryT, KeyT, PatternT

from django_redis import pool
from django_redis.exceptions import CompressorError, ConnectionInterrupted
Expand Down Expand Up @@ -66,6 +77,14 @@ def __init__(self, server, params: Dict[str, Any], backend: BaseCache) -> None:
def __contains__(self, key: KeyT) -> bool:
return self.has_key(key)

def _has_compression_enabled(self) -> bool:
return (
self._options.get(
"COMPRESSOR", "django_redis.compressors.identity.IdentityCompressor"
)
!= "django_redis.compressors.identity.IdentityCompressor"
)

def get_next_client_index(
self, write: bool = True, tried: Optional[List[int]] = None
) -> int:
Expand Down Expand Up @@ -498,6 +517,17 @@ def encode(self, value: EncodableT) -> Union[bytes, int]:

return value

def _decode_iterable_result(
self, result: Any, covert_to_set: bool = True
) -> Union[List[Any], None, Any]:
if result is None:
return None
if isinstance(result, list):
if covert_to_set:
return {self.decode(value) for value in result}
return [self.decode(value) for value in result]
return self.decode(result)

def get_many(
self,
keys: Iterable[KeyT],
Expand Down Expand Up @@ -778,6 +808,257 @@ def make_pattern(

return CacheKey(self._backend.key_func(pattern, prefix, version_str))

def sadd(
self,
key: KeyT,
*values: Any,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> int:
if client is None:
client = self.get_client(write=True)

key = self.make_key(key, version=version)
encoded_values = [self.encode(value) for value in values]
return int(client.sadd(key, *encoded_values))

def scard(
self,
key: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> int:
if client is None:
client = self.get_client(write=False)

key = self.make_key(key, version=version)
return int(client.scard(key))

def sdiff(
self,
*keys: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Set[Any]:
if client is None:
client = self.get_client(write=False)

nkeys = [self.make_key(key, version=version) for key in keys]
return {self.decode(value) for value in client.sdiff(*nkeys)}

def sdiffstore(
self,
dest: KeyT,
*keys: KeyT,
version_dest: Optional[int] = None,
version_keys: Optional[int] = None,
client: Optional[Redis] = None,
) -> int:
if client is None:
client = self.get_client(write=True)

dest = self.make_key(dest, version=version_dest)
nkeys = [self.make_key(key, version=version_keys) for key in keys]
return int(client.sdiffstore(dest, *nkeys))

def sinter(
self,
*keys: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Set[Any]:
if client is None:
client = self.get_client(write=False)

nkeys = [self.make_key(key, version=version) for key in keys]
return {self.decode(value) for value in client.sinter(*nkeys)}

def sinterstore(
self,
dest: KeyT,
*keys: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> int:
if client is None:
client = self.get_client(write=True)

dest = self.make_key(dest, version=version)
nkeys = [self.make_key(key, version=version) for key in keys]
return int(client.sinterstore(dest, *nkeys))

def smismember(
self,
key: KeyT,
*members,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> List[bool]:
if client is None:
client = self.get_client(write=False)

key = self.make_key(key, version=version)
encoded_members = [self.encode(member) for member in members]

return [bool(value) for value in client.smismember(key, *encoded_members)]

def sismember(
self,
key: KeyT,
member: Any,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> bool:
if client is None:
client = self.get_client(write=False)

key = self.make_key(key, version=version)
member = self.encode(member)
return bool(client.sismember(key, member))

def smembers(
self,
key: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Set[Any]:
if client is None:
client = self.get_client(write=False)

key = self.make_key(key, version=version)
return {self.decode(value) for value in client.smembers(key)}

def smove(
self,
source: KeyT,
destination: KeyT,
member: Any,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> bool:
if client is None:
client = self.get_client(write=True)

source = self.make_key(source, version=version)
destination = self.make_key(destination)
member = self.encode(member)
return bool(client.smove(source, destination, member))

def spop(
self,
key: KeyT,
count: Optional[int] = None,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Union[Set, Any]:
if client is None:
client = self.get_client(write=True)

nkey = self.make_key(key, version=version)
result = client.spop(nkey, count)
return self._decode_iterable_result(result)

def srandmember(
self,
key: KeyT,
count: Optional[int] = None,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Union[List, Any]:
if client is None:
client = self.get_client(write=False)

key = self.make_key(key, version=version)
result = client.srandmember(key, count)
return self._decode_iterable_result(result, covert_to_set=False)

def srem(
self,
key: KeyT,
*members: EncodableT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> int:
if client is None:
client = self.get_client(write=True)

key = self.make_key(key, version=version)
nmembers = [self.encode(member) for member in members]
return int(client.srem(key, *nmembers))

def sscan(
self,
key: KeyT,
match: Optional[str] = None,
count: Optional[int] = 10,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Set[Any]:
if self._has_compression_enabled() and match:
err_msg = "Using match with compression is not supported."
raise ValueError(err_msg)

if client is None:
client = self.get_client(write=False)

key = self.make_key(key, version=version)

cursor, result = client.sscan(
key,
match=cast(PatternT, self.encode(match)) if match else None,
count=count,
)
return {self.decode(value) for value in result}

def sscan_iter(
self,
key: KeyT,
match: Optional[str] = None,
count: Optional[int] = 10,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Iterator[Any]:
if self._has_compression_enabled() and match:
err_msg = "Using match with compression is not supported."
raise ValueError(err_msg)

if client is None:
client = self.get_client(write=False)

key = self.make_key(key, version=version)
for value in client.sscan_iter(
key,
match=cast(PatternT, self.encode(match)) if match else None,
count=count,
):
yield self.decode(value)

def sunion(
self,
*keys: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> Set[Any]:
if client is None:
client = self.get_client(write=False)

nkeys = [self.make_key(key, version=version) for key in keys]
return {self.decode(value) for value in client.sunion(*nkeys)}

def sunionstore(
self,
destination: Any,
*keys: KeyT,
version: Optional[int] = None,
client: Optional[Redis] = None,
) -> int:
if client is None:
client = self.get_client(write=True)

destination = self.make_key(destination, version=version)
encoded_keys = [self.make_key(key, version=version) for key in keys]
return int(client.sunionstore(destination, *encoded_keys))

def close(self) -> None:
close_flag = self._options.get(
"CLOSE_CONNECTION",
Expand Down
Loading

0 comments on commit e11150a

Please sign in to comment.