feat(backend): add block splitting to PickleBackend
Rizhiy committed May 27, 2024
1 parent f95ca70 commit a8fcb94
Showing 8 changed files with 203 additions and 41 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -57,6 +57,7 @@ pip install class-cache
NON_HASH_ATTRIBUTES = frozenset({*CacheWithDefault.NON_HASH_ATTRIBUTES, "_misc"})
def __init__(self, name: str):
# Attributes which affect default value generation should come before super().__init__()
# They will be used to generate a unique id
self._name = name
super().__init__()
# Other attributes should not affect how default value is generated, add them to NON_HASH_ATTRIBUTES
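
For context, a minimal sketch of how such a subclass could look end to end. It assumes CacheWithDefault is generic over key/value types and that missing keys are produced by a _get_data hook; both details are assumptions for illustration, not taken from this diff.

class MyCache(CacheWithDefault[int, str]):
    NON_HASH_ATTRIBUTES = frozenset({*CacheWithDefault.NON_HASH_ATTRIBUTES, "_misc"})

    def __init__(self, name: str):
        # Attributes which affect default value generation should come before super().__init__()
        self._name = name
        super().__init__()
        # Attributes listed in NON_HASH_ATTRIBUTES do not influence the cache id
        self._misc = "bookkeeping only"

    def _get_data(self, key: int) -> str:  # hook name assumed for illustration
        return f"{self._name}_{key}"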
Empty file added benchmark/__init__.py
73 changes: 73 additions & 0 deletions benchmark/main.py
@@ -0,0 +1,73 @@
import logging
import math
import random
import string
from time import monotonic
from typing import cast

import numpy as np
from replete.logging import setup_logging

from class_cache import Cache
from class_cache.backends import PickleBackend

LOGGER = logging.getLogger("class_cache.benchmark.main")
SIZE = 512
NP_RNG = np.random.default_rng()


class MyObj:
def __init__(self, name: str, number: int):
self._name = name
self._number = number


def get_random_string(size: int) -> str:
return "".join(random.choice(string.ascii_letters + string.digits) for _ in range(size)) # noqa: S311


def convert_size(size_bytes):
if size_bytes == 0:
return "0B"
size_name = ("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB")
i = int(math.floor(math.log(size_bytes, 1024)))
p = math.pow(1024, i)
s = round(size_bytes / p, 2)
return f"{s} {size_name[i]}"


def main():
setup_logging(print_level=logging.INFO)
arrays = {f"arr_{idx}": NP_RNG.standard_normal((SIZE, 32)) for idx in range(SIZE)}
strings = {f"str_{idx}": get_random_string(SIZE) for idx in range(SIZE)}
objects = {f"obj_{idx}": MyObj(str(idx) * SIZE, idx) for idx in range(SIZE)}

data = arrays | strings | objects
LOGGER.info(f"Got {len(data)} elements")

cache = Cache()
cache.clear()
start_write = monotonic()
cache.update(data)
cache.write()
end_write = monotonic()
LOGGER.info(f"Write took {end_write - start_write:3f} seconds")

del cache
read_cache = Cache()
start_read = monotonic()
for key in read_cache:
read_cache[key]
end_read = monotonic()
LOGGER.info(f"Read took {end_read - start_read:3f} seconds")

total_size = 0
backend = cast(PickleBackend, read_cache.backend)
for block_id in backend.get_all_block_ids():
total_size += backend.get_path_for_block_id(block_id).stat().st_size
LOGGER.info(f"Size on disk: {convert_size(total_size)}")
LOGGER.info(f"{len(list(backend.get_all_block_ids()))} total blocks")


if __name__ == "__main__":
main()
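
Since benchmark/ ships an __init__.py, the script can presumably be run as a module from the repository root, e.g. python -m benchmark.main; the exact invocation is an assumption, as the commit does not document one.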
139 changes: 102 additions & 37 deletions class_cache/backends.py
@@ -1,17 +1,20 @@
import json
import pickle # noqa: S403
from abc import ABC
from collections import defaultdict
from pathlib import Path
from typing import Iterable, Iterator, MutableMapping
from typing import Any, Iterable, Iterator, MutableMapping

from fasteners import InterProcessReaderWriterLock
from marisa_trie import Trie
from replete.consistent_hash import consistent_hash

from class_cache.types import KeyType, ValueType
from class_cache.utils import get_user_cache_dir
from class_cache.utils import get_class_cache_dir


class BaseBackend(ABC, MutableMapping[KeyType, ValueType]):
def __init__(self, id_: str | int = None) -> None:
def __init__(self, id_: str | int | None = None) -> None:
self._id = id_

@property
@@ -40,72 +43,134 @@ def clear(self) -> None:


class PickleBackend(BaseBackend[KeyType, ValueType]):
ROOT_DIR = get_user_cache_dir()
ROOT_DIR = get_class_cache_dir() / "PickleBackend"
BLOCK_SUFFIX = ".block.pkl"
META_TYPE = dict[str, Any]

def __init__(self, id_: str | int, target_block_size=1024**2) -> None:
def __init__(self, id_: str | int, target_block_size=1024 * 1024) -> None:
super().__init__(id_)
self._dir = self.ROOT_DIR / str(self._id)
self._dir.mkdir(exist_ok=True, parents=True)
self._target_block_size = target_block_size
self._meta_path = self._dir / "meta.json"
self._lock = InterProcessReaderWriterLock(self._dir / "lock.file")
self._check_meta()

@property
def dir(self) -> Path:
return self._dir

# Helper methods, don't acquire locks inside them
def _read_meta(self) -> META_TYPE:
with self._meta_path.open() as f:
return json.load(f)

def _write_meta(self, meta: META_TYPE) -> None:
with self._meta_path.open("w") as f:
json.dump(meta, f)

def _write_clean_meta(self) -> None:
self._write_meta({"len": 0})

def get_path_for_block_id(self, block_id: str) -> Path:
return self._dir / f"{block_id}{self.BLOCK_SUFFIX}"

def _get_key_hash(self, key: KeyType) -> str:
return f"{consistent_hash(key):x}"

def _get_all_block_ids(self) -> Iterable[str]:
yield from (path.name.split(".")[0] for path in self._dir.glob(f"*{self.BLOCK_SUFFIX}"))
def _get_block_id_for_key(self, key: KeyType, prefix_len=1) -> str:
key_hash = self._get_key_hash(key)

def _get_path_for_block_id(self, block_id: str) -> Path:
return self._dir / f"{block_id}{self.BLOCK_SUFFIX}"
blocks_trie = Trie(self.get_all_block_ids())
prefixes = blocks_trie.prefixes(key_hash)
return key_hash[:prefix_len] if not prefixes else max(prefixes, key=len)

def _get_block(self, block_id: str) -> dict[KeyType, ValueType]:
try:
with self._get_path_for_block_id(block_id).open("rb") as f:
with self.get_path_for_block_id(block_id).open("rb") as f:
return pickle.load(f) # noqa: S301
except FileNotFoundError:
return {}

def _write_block(self, block_id: str, block: dict[KeyType, ValueType]) -> None:
with self._get_path_for_block_id(block_id).open("wb") as f:
pickle.dump(block, f)
with self.get_path_for_block_id(block_id).open("wb") as f:
pickle.dump(block, f, pickle.HIGHEST_PROTOCOL)

def _get_block_id_for_key(self, key: KeyType) -> str:
key_hash = self._get_key_hash(key)

blocks_trie = Trie(self._get_all_block_ids())
prefixes = blocks_trie.prefixes(key_hash)
return key_hash[:1] if not prefixes else max(prefixes, key=len)
def _update_length(self, change: int) -> None:
meta = self._read_meta()
meta["len"] += change
self._write_meta(meta)

def _get_block_for_key(self, key: KeyType) -> dict[KeyType, ValueType]:
return self._get_block(self._get_block_id_for_key(key))

def get_all_block_ids(self) -> Iterable[str]:
yield from (path.name.split(".")[0] for path in self._dir.glob(f"*{self.BLOCK_SUFFIX}"))

# Helper methods end

def _check_meta(self) -> None:
with self._lock.read_lock():
if self._meta_path.exists():
return
if list(self.get_all_block_ids()):
raise ValueError(f"Found existing blocks without meta file in {self._dir}")
with self._lock.write_lock():
self._write_clean_meta()

def __contains__(self, key: KeyType) -> bool:
return key in self._get_block_for_key(key)
with self._lock.read_lock():
return key in self._get_block_for_key(key)

def __len__(self) -> int:
total_items = 0
# TODO: Optimise this
for block_id in self._get_all_block_ids():
total_items += len(self._get_block(block_id))
return total_items
with self._lock.read_lock():
return self._read_meta()["len"]

def __iter__(self) -> Iterator[KeyType]:
# TODO: Optimise this
for block_id in self._get_all_block_ids():
# TODO: This should also use a read lock, but it seems to be not working, see:
# https://github.com/harlowja/fasteners/issues/115
for block_id in self.get_all_block_ids():
yield from self._get_block(block_id).keys()

def __getitem__(self, key: KeyType) -> ValueType:
return self._get_block_for_key(key)[key]

def __setitem__(self, key: KeyType, value: ValueType) -> None:
block_id = self._get_block_id_for_key(key)
block = self._get_block(block_id)
block[key] = value
# TODO: Measure block size here and split if necessary
self._write_block(block_id, block)
with self._lock.read_lock():
return self._get_block_for_key(key)[key]

def __setitem__(self, key: KeyType, value: ValueType, prefix_len=1) -> None:
with self._lock.write_lock():
block_id = self._get_block_id_for_key(key, prefix_len=prefix_len)
block = self._get_block(block_id)
change = 0 if key in block else 1
block[key] = value
self._write_block(block_id, block)
self._update_length(change)

if self.get_path_for_block_id(block_id).stat().st_size > self._target_block_size:
self._split_block(block_id)

def _split_block(self, block_id: str) -> None:
with self._lock.write_lock():
block = self._get_block(block_id)
self.get_path_for_block_id(block_id).unlink()
new_prefix_len = len(block_id) + 1
new_blocks = defaultdict(dict)
for key, value in block.items():
new_blocks[self._get_block_id_for_key(key, new_prefix_len)][key] = value
for new_block_id, new_block in new_blocks.items():
self._write_block(new_block_id, new_block)

def __delitem__(self, key: KeyType) -> None:
block_id = self._get_block_id_for_key(key)
block = self._get_block(block_id)
del block[key]
self._write_block(block_id, block)
with self._lock.write_lock():
block_id = self._get_block_id_for_key(key)
block = self._get_block(block_id)
del block[key]
self._write_block(block_id, block)
self._update_length(-1)

def clear(self) -> None:
with self._lock.write_lock():
for block_id in self.get_all_block_ids():
self.get_path_for_block_id(block_id).unlink()
self._meta_path.unlink()
self._write_clean_meta()
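
The heart of the change: keys hash to hex strings, each block id is a hash prefix, lookups pick the longest block id that prefixes the key's hash (via the trie), and a block whose file grows past target_block_size is redistributed into blocks with a one-character-longer prefix. A self-contained sketch of that splitting idea, simplified and independent of the library's actual classes:

from collections import defaultdict
from hashlib import sha1


def key_hash(key: str) -> str:
    # Stand-in for class_cache's consistent_hash; any stable hex digest illustrates the idea.
    return sha1(key.encode()).hexdigest()


def assign_blocks(keys: list[str], prefix_len: int) -> dict[str, list[str]]:
    # Group keys into blocks keyed by the first prefix_len characters of their hash.
    blocks: dict[str, list[str]] = defaultdict(list)
    for key in keys:
        blocks[key_hash(key)[:prefix_len]].append(key)
    return blocks


keys = [f"key_{i}" for i in range(5000)]
blocks = assign_blocks(keys, prefix_len=1)  # 16 top-level blocks: "0" .. "f"
big_id, big_keys = max(blocks.items(), key=lambda kv: len(kv[1]))
# Splitting keeps every old id a valid prefix of the new ids ("a" -> "a0" .. "af"),
# which is why the longest-prefix lookup in the trie still resolves every key.
split = assign_blocks(big_keys, prefix_len=len(big_id) + 1)
print(big_id, len(big_keys), {bid: len(ks) for bid, ks in sorted(split.items())})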
12 changes: 10 additions & 2 deletions class_cache/core.py
@@ -10,12 +10,16 @@


class Cache(ABC, MutableMapping[KeyType, ValueType]):
def __init__(self, id_: str | int = None, backend: type[BaseBackend] = DEFAULT_BACKEND_TYPE) -> None:
def __init__(self, id_: str | int | None = None, backend: type[BaseBackend] = DEFAULT_BACKEND_TYPE) -> None:
self._backend = backend(id_)
self._data: dict[KeyType, ValueType] = {}
self._to_write = set()
self._to_delete = set()

@property
def backend(self) -> BaseBackend:
return self._backend

def __contains__(self, key: KeyType) -> bool:
if key in self._data:
return True
@@ -31,9 +35,11 @@ def __getitem__(self, key: KeyType) -> ValueType:
return self._data[key]

def __iter__(self) -> Iterable[KeyType]:
self.write()
return iter(self._backend)

def __len__(self) -> int:
self.write()
return len(self._backend)

def __delitem__(self, key: KeyType) -> None:
@@ -96,5 +102,7 @@ def __setattr__(self, key: str, value: Any) -> None:
and getattr(self, "_backend_set", None)
and key not in self.NON_HASH_ATTRIBUTES
):
raise TypeError(f"Trying to update hash inclusive attribute after hash has been decided: {key}")
raise TypeError(
f"Trying to update hash inclusive attribute after hash has been decided: {key}",
)
object.__setattr__(self, key, value)
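
A short usage sketch of the Cache front end as these changes leave it: __len__ and __iter__ now flush pending writes first, and the new backend property exposes the underlying store. The id string below is arbitrary; otherwise the calls mirror the benchmark and the tests in this commit.

from class_cache import Cache

cache = Cache("usage-example")
cache.clear()
cache.update({"a": 1, "b": 2})
assert len(cache) == 2  # __len__ writes pending items to the backend before counting
cache.write()

reopened = Cache("usage-example")  # same id resolves to the same on-disk blocks
assert dict(reopened.items()) == {"a": 1, "b": 2}
print(reopened.backend)  # the new property; a PickleBackend by default, per the benchmark's cast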
4 changes: 4 additions & 0 deletions class_cache/utils.py
@@ -7,3 +7,7 @@ def get_user_cache_dir() -> Path:
if cache_home:
return Path(cache_home)
return Path(os.environ["HOME"]) / ".cache"


def get_class_cache_dir() -> Path:
return get_user_cache_dir() / "class_cache"
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -9,14 +9,14 @@ readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
dynamic = ["description", "version"]
dependencies = ["Pympler", "marisa-trie", "replete"]
dependencies = ["Pympler", "fasteners", "marisa-trie", "replete"]

[project.urls]
Home = "https://github.com/Rizhiy/class-cache"

[project.optional-dependencies]
test = ["pytest", "pytest-cov", "replete[testing]"]
dev = ["black", "class-cache[test]", "pre-commit", "ruff"]
dev = ["black", "class-cache[test]", "numpy", "pre-commit", "ruff"]

[tool.flit.sdist]
include = ["README.md"]
11 changes: 11 additions & 0 deletions tests/test_core.py
@@ -111,3 +111,14 @@ def test_items():

read_cache = get_new_cache(clear=False)
assert dict(read_cache.items()) == TEST_DICT


def test_len():
cache = get_new_cache()
cache.update(TEST_DICT)
assert len(cache) == len(TEST_DICT)
cache.write()

del cache
cache = get_new_cache(clear=False)
assert len(cache) == len(TEST_DICT)
