feat(wrappers): make wrappers more general and refactor a lot
Rizhiy committed Jun 2, 2024
1 parent a1a7273 commit debe119
Showing 11 changed files with 263 additions and 186 deletions.
8 changes: 4 additions & 4 deletions benchmark/main.py
@@ -2,17 +2,17 @@
import math
import random
import string
from functools import partial
from time import monotonic

import numpy as np
from replete.logging import setup_logging

from class_cache import Cache
from class_cache.backends import BaseBackend, BrotliCompressWrapper, PickleBackend, SQLiteBackend
from class_cache.backends import BaseBackend, PickleBackend, SQLiteBackend
from class_cache.wrappers import BrotliCompressWrapper

LOGGER = logging.getLogger("class_cache.benchmark.main")
SIZE = 512
SIZE = 1024
NP_RNG = np.random.default_rng()


@@ -84,7 +84,7 @@ def main():
for name, backend_type in {
"pickle": PickleBackend,
"sqlite": SQLiteBackend,
"brotli_pickle": partial(BrotliCompressWrapper, backend_type=PickleBackend),
"brotli_pickle": lambda id_: BrotliCompressWrapper(PickleBackend(id_)),
}.items():
evaluate(name, backend_type)

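The benchmark change above shows the new wrapper API: instead of receiving a backend_type keyword and constructing the backend itself, a wrapper now wraps an already-constructed backend instance. A minimal sketch of the two styles, assuming the relocated wrapper keeps the transparent encode/decode behaviour of the old BrotliCompressWrapper shown further down this page (the id string is arbitrary):

from class_cache.backends import PickleBackend
from class_cache.wrappers import BrotliCompressWrapper

# Before this commit: BrotliCompressWrapper("bench-demo", backend_type=PickleBackend)
# After this commit: wrap a concrete backend instance; the result is still a MutableMapping.
backend = BrotliCompressWrapper(PickleBackend("bench-demo"))
backend["key"] = {"some": "value"}          # pickled and brotli-compressed on write
assert backend["key"] == {"some": "value"}  # decompressed and unpickled on read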
1 change: 1 addition & 0 deletions class_cache/__init__.py
@@ -2,5 +2,6 @@

from .backends import BaseBackend
from .core import Cache, CacheWithDefault
from .wrappers import BaseWrapper

__version__ = "0.5.0"
144 changes: 28 additions & 116 deletions class_cache/backends.py
@@ -1,63 +1,38 @@
# ruff: noqa: S608
import codecs
import json
import logging
import pickle # noqa: S403
import sqlite3
from abc import ABC
from collections import defaultdict
from pathlib import Path
from typing import Any, Iterable, Iterator, MutableMapping, TypeAlias, cast
from typing import Any, Iterable, Iterator, cast

import brotli
from marisa_trie import Trie
from replete.consistent_hash import consistent_hash
from replete.flock import FileLock

from class_cache.types import KeyType, ValueType
from class_cache.types import CacheInterface, IdType, KeyType, ValueType
from class_cache.utils import get_class_cache_dir

ID_TYPE: TypeAlias = str | int | None
LOGGER = logging.getLogger(__name__)


class BaseBackend(ABC, MutableMapping[KeyType, ValueType]):
def __init__(self, id_: ID_TYPE = None) -> None:
self._id = id_

@property
def id(self) -> ID_TYPE:
return self._id

# Override these methods to allow getting results in a more optimal fashion
def contains_many(self, keys: Iterable[KeyType]) -> Iterator[tuple[KeyType, bool]]:
for key in keys:
yield key, key in self

def get_many(self, keys: Iterable[KeyType]) -> Iterator[tuple[KeyType, ValueType]]:
for key in keys:
yield key, self[key]

def set_many(self, items: Iterable[tuple[KeyType, ValueType]]) -> None:
for key, value in items:
self[key] = value

def del_many(self, keys: Iterable[KeyType]) -> None:
for key in keys:
del self[key]

def clear(self) -> None:
self.del_many(self)
class BaseBackend(CacheInterface[KeyType, ValueType]):
pass


class PickleBackend(BaseBackend[KeyType, ValueType]):
ROOT_DIR = get_class_cache_dir() / "PickleBackend"
BLOCK_SUFFIX = ".block.pkl"
META_TYPE = dict[str, Any]

def __init__(self, id_: ID_TYPE = None, max_block_size=1024 * 1024) -> None:
def __init__(self, id_: IdType = None, max_block_size=1024 * 1024) -> None:
super().__init__(id_)
self._dir = self.ROOT_DIR / str(self.id)
self._dir.mkdir(exist_ok=True, parents=True)
self._max_block_size = max_block_size

self._meta_path = self._dir / "meta.json"
self._lock = FileLock(self._meta_path)
self._check_meta()
@@ -84,6 +59,7 @@ def get_path_for_block_id(self, block_id: str) -> Path:
def _get_key_hash(self, key: KeyType) -> str:
return f"{consistent_hash(key):x}"

# TODO: Add caching for this
def _get_block_id_for_key(self, key: KeyType, prefix_len=1) -> str:
key_hash = self._get_key_hash(key)

@@ -93,6 +69,7 @@ def _get_block_id_for_key(self, key: KeyType, prefix_len=1) -> str:
raise ValueError("Got prefix_len that is larger than key_hash len.")
return key_hash[:prefix_len] if not prefixes else max(prefixes, key=len)

# TODO: Add caching for this
def _get_block(self, block_id: str) -> dict[KeyType, ValueType]:
try:
with self.get_path_for_block_id(block_id).open("rb") as f:
@@ -125,6 +102,17 @@ def _check_meta(self) -> None:
with self._lock.write_lock():
self._write_clean_meta()

def _split_block(self, block_id: str) -> None:
with self._lock.write_lock():
block = self._get_block(block_id)
self.get_path_for_block_id(block_id).unlink()
new_prefix_len = len(block_id) + 1
new_blocks = defaultdict(dict)
for key, value in block.items():
new_blocks[self._get_block_id_for_key(key, new_prefix_len)][key] = value
for new_block_id, new_block in new_blocks.items():
self._write_block(new_block_id, new_block)

def __contains__(self, key: KeyType) -> bool:
with self._lock.read_lock():
return key in self._get_block_for_key(key)
@@ -134,7 +122,6 @@ def __len__(self) -> int:
return self._read_meta()["len"]

def __iter__(self) -> Iterator[KeyType]:
# TODO: Optimise this
with self._lock.read_lock():
for block_id in self.get_all_block_ids():
yield from self._get_block(block_id).keys()
@@ -153,19 +140,13 @@ def __setitem__(self, key: KeyType, value: ValueType, prefix_len=1) -> None:
self._update_length(change)

if self.get_path_for_block_id(block_id).stat().st_size > self._max_block_size:
if len(block) == 1:
LOGGER.warning(
"Got a block that is larger than max_block_size with single item, please increase max_block_size!",
)
return
self._split_block(block_id)

def _split_block(self, block_id: str) -> None:
with self._lock.write_lock():
block = self._get_block(block_id)
self.get_path_for_block_id(block_id).unlink()
new_prefix_len = len(block_id) + 1
new_blocks = defaultdict(dict)
for key, value in block.items():
new_blocks[self._get_block_id_for_key(key, new_prefix_len)][key] = value
for new_block_id, new_block in new_blocks.items():
self._write_block(new_block_id, new_block)

def __delitem__(self, key: KeyType) -> None:
with self._lock.write_lock():
block_id = self._get_block_id_for_key(key)
@@ -187,7 +168,7 @@ class SQLiteBackend(BaseBackend[KeyType, ValueType]):
ROOT_DIR.mkdir(parents=True, exist_ok=True)
DATA_TABLE_NAME = "data"

def __init__(self, id_: ID_TYPE = None) -> None:
def __init__(self, id_: IdType = None) -> None:
super().__init__(id_)
self._db_path = self.ROOT_DIR / f"{self.id}.db"
self._con = sqlite3.connect(self._db_path)
@@ -204,7 +185,7 @@ def _check_table(self):
self._cursor.execute(f"CREATE TABLE {self.DATA_TABLE_NAME}(key TEXT, value TEXT)")
self._cursor.execute(f"CREATE UNIQUE INDEX key_index ON {self.DATA_TABLE_NAME}(key)")

# TODO: Can probably cache these
# TODO: Add caching for keys
def _encode(self, obj: KeyType | ValueType) -> str:
return codecs.encode(pickle.dumps(obj), "base64").decode()

@@ -253,72 +234,3 @@ def __del__(self):
self._con.close()

# TODO: implement *_many methods


class BackendWrapper(BaseBackend[KeyType, ValueType]):
"""
:param backend: backend to be wrapped
"""

def __init__(self, *args, backend_type: type[BaseBackend], **kwargs) -> None:
super().__init__()
self._backend = backend_type(*args, **kwargs)

def __contains__(self, key: KeyType) -> bool:
return key in self._backend

def __len__(self) -> int:
return len(self._backend)

def __iter__(self) -> Iterator[KeyType]:
yield from self._backend

def __getitem__(self, key: KeyType) -> ValueType:
return self._backend[key]

def __setitem__(self, key: KeyType, value: ValueType) -> None:
self._backend[key] = value

def __delitem__(self, key: KeyType) -> None:
del self._backend[key]

def contains_many(self, keys: Iterable[KeyType]) -> Iterator[tuple[KeyType, bool]]:
yield from self._backend.contains_many(keys)

def get_many(self, keys: Iterable[KeyType]) -> Iterator[tuple[KeyType, ValueType]]:
yield from self._backend.get_many(keys)

def set_many(self, items: Iterable[tuple[KeyType, ValueType]]) -> None:
self._backend.set_many(items)

def del_many(self, keys: Iterable[KeyType]) -> None:
self._backend.del_many(keys)

def clear(self) -> None:
self._backend.clear()


class BrotliCompressWrapper(BackendWrapper[KeyType, ValueType]):
def _encode(self, obj: KeyType | ValueType) -> bytes:
return brotli.compress(pickle.dumps(obj, pickle.HIGHEST_PROTOCOL))

def _decode(self, stored: bytes) -> KeyType | ValueType:
return pickle.loads(brotli.decompress(stored)) # noqa: S301

def __contains__(self, key: KeyType) -> bool:
return super().__contains__(key)

def __iter__(self) -> Iterator[KeyType]:
yield from super().__iter__()

def __getitem__(self, key: KeyType) -> ValueType:
return self._decode(super().__getitem__(key))

def __setitem__(self, key: KeyType, value: ValueType) -> None:
super().__setitem__(key, self._encode(value))

def __delitem__(self, key: KeyType) -> None:
return super().__delitem__(key)

def set_many(self, items: Iterable[tuple[KeyType, ValueType]]) -> None:
return super().set_many((key, self._encode(value)) for key, value in items)
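For context on the relocated _split_block: PickleBackend buckets keys into block files named by a hex prefix of consistent_hash(key), and a block that outgrows max_block_size is re-bucketed under prefixes one character longer. A self-contained illustration of that re-bucketing, using made-up hashes and a plain dict in place of the real block files (not part of class_cache):

from collections import defaultdict

def split_by_longer_prefix(block: dict[str, int], block_id: str) -> dict[str, dict[str, int]]:
    # Mirrors the _split_block logic above: each item moves to the block
    # whose id is one more character of its key hash. Here the dict keys
    # stand in for the key hashes themselves.
    new_prefix_len = len(block_id) + 1
    new_blocks: dict[str, dict[str, int]] = defaultdict(dict)
    for key_hash, value in block.items():
        new_blocks[key_hash[:new_prefix_len]][key_hash] = value
    return dict(new_blocks)

# Block "a" holding hashes "a1f3" and "a2c9" splits into blocks "a1" and "a2".
assert split_by_longer_prefix({"a1f3": 1, "a2c9": 2}, "a") == {"a1": {"a1f3": 1}, "a2": {"a2c9": 2}}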
14 changes: 7 additions & 7 deletions class_cache/core.py
@@ -1,17 +1,19 @@
from abc import ABC, abstractmethod
from typing import Any, ClassVar, Iterable, MutableMapping
from abc import abstractmethod
from typing import Any, ClassVar, Iterable

from replete.consistent_hash import consistent_hash

from class_cache.backends import BaseBackend, PickleBackend
from class_cache.types import KeyType, ValueType
from class_cache.types import CacheInterface, KeyType, ValueType

DEFAULT_BACKEND_TYPE = PickleBackend


class Cache(ABC, MutableMapping[KeyType, ValueType]):
class Cache(CacheInterface[KeyType, ValueType]):
def __init__(self, id_: str | int | None = None, backend_type: type[BaseBackend] = DEFAULT_BACKEND_TYPE) -> None:
super().__init__(id_)
self._backend = backend_type(id_)
# TODO: Implement max_size logic
self._data: dict[KeyType, ValueType] = {}
self._to_write = set()
self._to_delete = set()
@@ -102,7 +104,5 @@ def __setattr__(self, key: str, value: Any) -> None:
and getattr(self, "_backend_set", None)
and key not in self.NON_HASH_ATTRIBUTES
):
raise TypeError(
f"Trying to update hash inclusive attribute after hash has been decided: {key}",
)
raise TypeError(f"Trying to update hash inclusive attribute after hash has been decided: {key}")
object.__setattr__(self, key, value)
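Cache keeps its dict-like surface through CacheInterface (defined in types.py below), and its backend_type parameter is simply called with the id: self._backend = backend_type(id_). The annotation says type[BaseBackend], but at runtime any callable taking an id works, which is exactly how the benchmark above slots in a wrapped backend. A small usage sketch under those assumptions, with the id and value invented for illustration:

from class_cache import Cache
from class_cache.backends import PickleBackend
from class_cache.wrappers import BrotliCompressWrapper

# Compose a wrapped backend via a factory, as in the benchmark above.
cache = Cache("usage-demo", backend_type=lambda id_: BrotliCompressWrapper(PickleBackend(id_)))
cache["answer"] = 42
assert "answer" in cache and cache["answer"] == 42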
70 changes: 69 additions & 1 deletion class_cache/types.py
@@ -1,6 +1,74 @@
from typing import TypeAlias, TypeVar
from abc import ABC, abstractmethod
from collections.abc import MutableMapping
from typing import Iterable, Iterator, TypeAlias, TypeVar

KeyType = TypeVar("KeyType")
ValueType = TypeVar("ValueType")
JsonType: TypeAlias = dict[str, "JsonType"] | list["JsonType"] | str | int | float | bool | None
MetaType = dict[str, JsonType]
IdType: TypeAlias = str | int | None


class CacheInterface(ABC, MutableMapping[KeyType, ValueType]):
def __init__(self, id_: IdType = None) -> None:
self._id = id_

@property
def id(self) -> IdType:
return self._id

def __hash__(self) -> int:
return hash(self.id)

def __eq__(self, other) -> bool:
if not isinstance(other, self.__class__):
return False
return hash(self) == hash(other)

@abstractmethod
def __len__(self) -> int:
pass

@abstractmethod
def __iter__(self) -> Iterable[KeyType]:
pass

@abstractmethod
def __setitem__(self, key: KeyType, value: ValueType) -> None:
pass

@abstractmethod
def __getitem__(self, key: KeyType) -> ValueType:
raise KeyError

@abstractmethod
def __delitem__(self, key: KeyType) -> None:
pass

def __contains__(self, key: KeyType) -> bool:
try:
self[key]
except KeyError:
return False
else:
return True

# Override these methods to allow getting results in a more optimal fashion
def contains_many(self, keys: Iterable[KeyType]) -> Iterator[tuple[KeyType, bool]]:
for key in keys:
yield key, key in self

def get_many(self, keys: Iterable[KeyType]) -> Iterator[tuple[KeyType, ValueType]]:
for key in keys:
yield key, self[key]

def set_many(self, items: Iterable[tuple[KeyType, ValueType]]) -> None:
for key, value in items:
self[key] = value

def del_many(self, keys: Iterable[KeyType]) -> None:
for key in keys:
del self[key]

def clear(self) -> None:
self.del_many(self)
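As the comment inside CacheInterface notes, a concrete cache only has to supply the five abstract methods; contains_many, get_many, set_many and del_many then work via the defaults above (or can be overridden for speed). A hypothetical minimal in-memory implementation, purely to show the inherited behaviour; DictCache is not part of class_cache:

from typing import Iterator
from class_cache.types import CacheInterface

class DictCache(CacheInterface[str, int]):  # hypothetical example class
    def __init__(self, id_=None) -> None:
        super().__init__(id_)
        self._store: dict[str, int] = {}

    def __len__(self) -> int:
        return len(self._store)

    def __iter__(self) -> Iterator[str]:
        return iter(self._store)

    def __getitem__(self, key: str) -> int:
        return self._store[key]  # KeyError here makes __contains__ return False

    def __setitem__(self, key: str, value: int) -> None:
        self._store[key] = value

    def __delitem__(self, key: str) -> None:
        del self._store[key]

cache = DictCache("demo")
cache.set_many([("a", 1), ("b", 2)])  # inherited default
assert dict(cache.get_many(["a", "b"])) == {"a": 1, "b": 2}
assert dict(cache.contains_many(["a", "c"])) == {"a": True, "c": False}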
(diffs for the remaining 6 changed files, including the new class_cache/wrappers.py, are not shown)