Skip to content

Commit

Permalink
typing(): more typing
Browse files Browse the repository at this point in the history
  • Loading branch information
PaulRenvoise committed Jun 3, 2024
1 parent dcd967b commit ef73f0d
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 75 deletions.
32 changes: 19 additions & 13 deletions flashback/caching/adapters/disk_adapter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from __future__ import annotations

from collections.abc import Sequence, Generator
from contextlib import contextmanager
from datetime import datetime, timedelta
from fcntl import flock, LOCK_SH, LOCK_EX, LOCK_UN
from typing import Any
import shelve
from shelve import Shelf
import tempfile
import uuid

Expand All @@ -15,12 +20,12 @@ class DiskAdapter(BaseAdapter):
See: https://docs.python.org/3/library/shelve.html.
"""

def __init__(self, **_kwargs):
def __init__(self, **_kwargs) -> None:
super().__init__()

self._store_path = f"{tempfile.gettempdir()}/{uuid.uuid4()}"

def set(self, key, value, ttl):
def set(self, key: str, value: Any, ttl: int) -> bool:
if ttl == -1:
expiry = None
else:
Expand All @@ -31,8 +36,9 @@ def set(self, key, value, ttl):

return True

def batch_set(self, keys, values, ttls):
def batch_set(self, keys: Sequence[str], values: Sequence[Any], ttls: Sequence[int]) -> bool:
now = datetime.now()
# TODO: use relativedelta
expiries = [None if ttl == -1 else datetime.timestamp(now + timedelta(seconds=ttl)) for ttl in ttls]

values = zip(values, expiries)
Expand All @@ -42,53 +48,53 @@ def batch_set(self, keys, values, ttls):

return True

def get(self, key):
def get(self, key: str) -> Any | None:
self._evict()

with self._open_locked_store(LOCK_SH) as store:
return store.get(key, (None,))[0]

def batch_get(self, keys):
def batch_get(self, keys: Sequence[str]) -> Sequence[Any | None]:
self._evict()

with self._open_locked_store(LOCK_SH) as store:
return [store.get(key, (None,))[0] for key in keys]

def delete(self, key):
def delete(self, key: str) -> bool:
self._evict()

with self._open_locked_store(LOCK_EX) as store:
return bool(store.pop(key, False))

def batch_delete(self, keys):
def batch_delete(self, keys: Sequence[str]) -> bool:
self._evict()

with self._open_locked_store(LOCK_EX) as store:
res = [bool(store.pop(key, False)) for key in keys]

return False not in res

def exists(self, key):
def exists(self, key: str) -> bool:
self._evict()

with self._open_locked_store(LOCK_SH) as store:
return key in store

def flush(self):
def flush(self) -> bool:
with self._open_locked_store(LOCK_EX) as store:
store.clear()

return True

def ping(self):
def ping(self) -> bool:
return True

@property
def connection_exceptions(self):
def connection_exceptions(self) -> tuple[Exception, ...]:
return ()

@contextmanager
def _open_locked_store(self, mode):
def _open_locked_store(self, mode: int) -> Generator[Shelf[Any], None, None]:
with open(f"{self._store_path}.lock", "w", encoding="utf-8") as lock:
flock(lock.fileno(), mode) # blocking until lock is acquired

Expand All @@ -98,7 +104,7 @@ def _open_locked_store(self, mode):
finally:
flock(lock.fileno(), LOCK_UN)

def _evict(self):
def _evict(self) -> None:
now = datetime.timestamp(datetime.now())

expired_keys = set()
Expand Down
27 changes: 16 additions & 11 deletions flashback/caching/adapters/memcached_adapter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from pymemcache.client.base import Client
from pymemcache.exceptions import * # noqa: F403

Expand All @@ -11,18 +16,18 @@ class MemcachedAdapter(BaseAdapter):
Exposes `pymemcache`'s exceptions.
"""

def __init__(self, host="localhost", port=11211, **kwargs):
def __init__(self, host: str = "localhost", port: int = 11211, **kwargs) -> None:
super().__init__()

self.store = Client((host, port), **kwargs)

def set(self, key, value, ttl):
def set(self, key: str, value: Any, ttl: int) -> bool:
if ttl == -1:
ttl = 0

return self.store.set(key, value, expire=ttl)

def batch_set(self, keys, values, ttls):
def batch_set(self, keys: Sequence[str], values: Sequence[Any], ttls: Sequence[int]) -> bool:
# There's two reasons to recode pymemcache.set_multi():
# - It returns a list of keys that failed to be inserted, and the base expects a boolean
# - It only allows a unique ttl for all keys
Expand All @@ -45,17 +50,17 @@ def batch_set(self, keys, values, ttls):

return all(line != b"NOT_STORED" for line in results)

def get(self, key):
def get(self, key: str) -> Any | None:
return self.store.get(key)

def batch_get(self, keys):
def batch_get(self, keys: Sequence[str]) -> Sequence[Any | None]:
key_to_value = self.store.get_multi(keys)
return [key_to_value.get(key, None) for key in keys]

def delete(self, key):
def delete(self, key: str) -> bool:
return self.store.delete(key, noreply=False)

def batch_delete(self, keys):
def batch_delete(self, keys: Sequence[str]) -> bool:
# Here as well, pymemcache.delete_multi() always returns True
commands = []

Expand All @@ -69,16 +74,16 @@ def batch_delete(self, keys):

return all(line != b"NOT_FOUND" for line in results)

def exists(self, key):
def exists(self, key: str) -> bool:
# Can't just cast to bool since we can store falsey values
return self.store.get(key) is not None

def flush(self):
def flush(self) -> bool:
return self.store.flush_all(noreply=False)

def ping(self):
def ping(self) -> bool:
return bool(self.store.stats())

@property
def connection_exceptions(self):
def connection_exceptions(self) -> tuple[Exception, ...]:
return (MemcacheUnexpectedCloseError, MemcacheServerError, MemcacheUnknownError) # noqa: F405
29 changes: 17 additions & 12 deletions flashback/caching/adapters/memory_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from __future__ import annotations

from collections.abc import Sequence
from datetime import datetime, timedelta
from threading import RLock
from typing import Any

from .base import BaseAdapter

Expand All @@ -9,24 +13,25 @@ class MemoryAdapter(BaseAdapter):
Exposes a cache store using a in-memory dict.
"""

def __init__(self, **_kwargs):
def __init__(self, **_kwargs) -> None:
super().__init__()

self._lock = RLock()
self.store = {}

def set(self, key, value, ttl):
def set(self, key: str, value: Any, ttl: int) -> bool:
if ttl == -1:
expiry = None
else:
# TODO: use relativedelta
expiry = datetime.timestamp(datetime.now() + timedelta(seconds=ttl))

with self._lock:
self.store[key] = (value, expiry)

return True

def batch_set(self, keys, values, ttls):
def batch_set(self, keys: Sequence[str], values: Sequence[Any], ttls: Sequence[int]) -> bool:
now = datetime.now()
expiries = [None if ttl == -1 else datetime.timestamp(now + timedelta(seconds=ttl)) for ttl in ttls]

Expand All @@ -37,50 +42,50 @@ def batch_set(self, keys, values, ttls):

return True

def get(self, key):
def get(self, key: str) -> Any | None:
self._evict()

return self.store.get(key, (None,))[0]

def batch_get(self, keys):
def batch_get(self, keys: Sequence[str]) -> Sequence[Any | None]:
self._evict()

return [self.store.get(key, (None,))[0] for key in keys]

def delete(self, key):
def delete(self, key: str) -> bool:
self._evict()

with self._lock:
value = self.store.pop(key, False)

return bool(value)

def batch_delete(self, keys):
def batch_delete(self, keys: Sequence[str]) -> bool:
self._evict()

with self._lock:
res = [bool(self.store.pop(key, False)) for key in keys]

return False not in res

def exists(self, key):
def exists(self, key: str) -> bool:
self._evict()

return key in self.store

def flush(self):
def flush(self) -> bool:
self.store.clear()

return True

def ping(self):
def ping(self) -> bool:
return True

@property
def connection_exceptions(self):
def connection_exceptions(self) -> tuple[Exception, ...]:
return ()

def _evict(self):
def _evict(self) -> None:
now = datetime.timestamp(datetime.now())

expired_keys = set()
Expand Down
35 changes: 21 additions & 14 deletions flashback/caching/adapters/redis_adapter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from redis import Redis
from redis.exceptions import * # noqa: F403
from redis.exceptions import ConnectionError as RedisConnectionError
Expand All @@ -23,51 +28,53 @@ def __init__(self, host="localhost", port=6379, db="0", encoding="utf-8", **kwar
self._encoding = encoding
self.store = Redis(host=host, port=port, db=db, encoding=encoding, **kwargs)

def set(self, key, value, ttl):
def set(self, key: str, value: Any, ttl: int) -> bool:
if ttl == -1:
ttl = None
converted_ttl = None
else:
converted_ttl = ttl

return self.store.set(key, value, ex=ttl)
return self.store.set(key, value, ex=converted_ttl)

def batch_set(self, keys, values, ttls):
ttls = [None if ttl == -1 else ttl for ttl in ttls]
def batch_set(self, keys: Sequence[str], values: Sequence[Any], ttls: Sequence[int]) -> bool:
converted_ttls = [None if ttl == -1 else ttl for ttl in ttls]

pipe = self.store.pipeline()

pipe.mset(dict(zip(keys, values)))
for key, ttl in zip(keys, ttls):
for key, ttl in zip(keys, converted_ttls):
if ttl is not None:
pipe.expire(key, ttl)

return pipe.execute()

def get(self, key):
def get(self, key: str) -> Any | None:
value = self.store.get(key)

return value.decode(self._encoding) if value is not None else None

def batch_get(self, keys):
def batch_get(self, keys: Sequence[str]) -> Sequence[Any | None]:
values = self.store.mget(keys)

return [value.decode(self._encoding) if value is not None else None for value in values]

def delete(self, key):
def delete(self, key: str) -> bool:
return bool(self.store.delete(key))

def batch_delete(self, keys):
def batch_delete(self, keys: Sequence[str]) -> bool:
res = self.store.delete(*keys)

return res == len(keys)

def exists(self, key):
def exists(self, key: str) -> bool:
return self.store.exists(key)

def flush(self):
def flush(self) -> bool:
return self.store.flushdb()

def ping(self):
def ping(self) -> bool:
return self.store.ping()

@property
def connection_exceptions(self):
def connection_exceptions(self) -> tuple[Exception, ...]:
return (RedisConnectionError, RedisTimeoutError, ResponseError) # noqa: F405
Loading

0 comments on commit ef73f0d

Please sign in to comment.