Commit

feat(backend): add SQLiteBackend
Rizhiy committed May 27, 2024
1 parent 7b6aa1d commit f5640ce
Showing 6 changed files with 148 additions and 58 deletions.
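Not part of the commit message: a minimal usage sketch of the new backend, pieced together from the benchmark, core, and test changes below. The id string is illustrative, and PickleBackend appears to remain the default backend type.

from class_cache import Cache
from class_cache.backends import SQLiteBackend

# Select the new backend via the renamed keyword argument on Cache.
cache = Cache("example_id", backend_type=SQLiteBackend)
cache.clear()
cache.update({"key_1": [1, 2, 3], "key_2": "value"})

# A second instance with the same id_ and backend type reads the same data,
# mirroring the write/read split used in benchmark/main.py.
read_cache = Cache("example_id", backend_type=SQLiteBackend)
for key in read_cache:
    print(key, read_cache[key])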
32 changes: 21 additions & 11 deletions benchmark/main.py
@@ -3,13 +3,12 @@
import random
import string
from time import monotonic
from typing import cast

import numpy as np
from replete.logging import setup_logging

from class_cache import Cache
from class_cache.backends import PickleBackend
from class_cache.backends import BaseBackend, PickleBackend, SQLiteBackend

LOGGER = logging.getLogger("class_cache.benchmark.main")
SIZE = 512
@@ -36,16 +35,16 @@ def convert_size(size_bytes):
    return f"{s} {size_name[i]}"


def main():
    setup_logging(print_level=logging.INFO)
def evaluate(backend_type: type[BaseBackend] = PickleBackend):
    LOGGER.info(f"Evaluating {backend_type.__name__} backend")
    arrays = {f"arr_{idx}": NP_RNG.standard_normal((SIZE, 32)) for idx in range(SIZE)}
    strings = {f"str_{idx}": get_random_string(SIZE) for idx in range(SIZE)}
    objects = {f"obj_{idx}": MyObj(str(idx) * SIZE, idx) for idx in range(SIZE)}

    data = arrays | strings | objects
    LOGGER.info(f"Got {len(data)} elements")

    cache = Cache()
    cache = Cache(backend_type=backend_type)
    cache.clear()
    start_write = monotonic()
    cache.update(data)
@@ -54,19 +53,30 @@ def main():
    LOGGER.info(f"Write took {end_write - start_write:3f} seconds")

    del cache
    read_cache = Cache()
    read_cache = Cache(backend_type=backend_type)
    start_read = monotonic()
    for key in read_cache:
        read_cache[key]
    end_read = monotonic()
    LOGGER.info(f"Read took {end_read - start_read:3f} seconds")

    total_size = 0
    backend = cast(PickleBackend, read_cache.backend)
    for block_id in backend.get_all_block_ids():
        total_size += backend.get_path_for_block_id(block_id).stat().st_size
    backend = read_cache.backend
    match backend:
        case PickleBackend():
            total_size = 0
            for block_id in backend.get_all_block_ids():
                total_size += backend.get_path_for_block_id(block_id).stat().st_size
            LOGGER.info(f"{len(list(backend.get_all_block_ids()))} total blocks")
        case SQLiteBackend():
            total_size = backend.db_path.stat().st_size

    LOGGER.info(f"Size on disk: {convert_size(total_size)}")
    LOGGER.info(f"{len(list(backend.get_all_block_ids()))} total blocks")


def main():
    setup_logging(print_level=logging.INFO)
    for backend_type in {PickleBackend, SQLiteBackend}:
        evaluate(backend_type)


if __name__ == "__main__":
87 changes: 82 additions & 5 deletions class_cache/backends.py
@@ -1,9 +1,12 @@
# ruff: noqa: S608
import codecs
import json
import pickle # noqa: S403
import sqlite3
from abc import ABC
from collections import defaultdict
from pathlib import Path
from typing import Any, Iterable, Iterator, MutableMapping
from typing import Any, Iterable, Iterator, MutableMapping, TypeAlias, cast

from fasteners import InterProcessReaderWriterLock
from marisa_trie import Trie
@@ -12,13 +15,15 @@
from class_cache.types import KeyType, ValueType
from class_cache.utils import get_class_cache_dir

ID_TYPE: TypeAlias = str | int | None


class BaseBackend(ABC, MutableMapping[KeyType, ValueType]):
    def __init__(self, id_: str | int | None = None) -> None:
    def __init__(self, id_: ID_TYPE = None) -> None:
        self._id = id_

    @property
    def id(self) -> str | int | None:
    def id(self) -> ID_TYPE:
        return self._id

    # Override these methods to allow getting results in a more optimal fashion
@@ -47,9 +52,9 @@ class PickleBackend(BaseBackend[KeyType, ValueType]):
    BLOCK_SUFFIX = ".block.pkl"
    META_TYPE = dict[str, Any]

    def __init__(self, id_: str | int, max_block_size=1024 * 1024) -> None:
    def __init__(self, id_: ID_TYPE = None, max_block_size=1024 * 1024) -> None:
        super().__init__(id_)
        self._dir = self.ROOT_DIR / str(self._id)
        self._dir = self.ROOT_DIR / str(self.id)
        self._dir.mkdir(exist_ok=True, parents=True)
        self._max_block_size = max_block_size
        self._meta_path = self._dir / "meta.json"
@@ -174,3 +179,75 @@ def clear(self) -> None:
            self.get_path_for_block_id(block_id).unlink()
        self._meta_path.unlink()
        self._write_clean_meta()


class SQLiteBackend(BaseBackend[KeyType, ValueType]):
    ROOT_DIR = get_class_cache_dir() / "SQLiteBackend"
    ROOT_DIR.mkdir(parents=True, exist_ok=True)
    DATA_TABLE_NAME = "data"

    def __init__(self, id_: ID_TYPE = None) -> None:
        super().__init__(id_)
        self._db_path = self.ROOT_DIR / str(self.id)
        self._con = sqlite3.connect(self._db_path)
        self._cursor = self._con.cursor()
        self._check_table()

    @property
    def db_path(self) -> Path:
        return self._db_path

    def _check_table(self):
        tables = self._cursor.execute("SELECT name FROM sqlite_master LIMIT 1").fetchone()
        if tables is None:
            self._cursor.execute(f"CREATE TABLE {self.DATA_TABLE_NAME}(key TEXT, value TEXT)")
            self._cursor.execute(f"CREATE UNIQUE INDEX key_index ON {self.DATA_TABLE_NAME}(key)")

    # TODO: Can probably cache these
    def _encode(self, obj: KeyType | ValueType) -> str:
        return codecs.encode(pickle.dumps(obj), "base64").decode()

    def _decode(self, stored: str) -> KeyType | ValueType:
        return pickle.loads(codecs.decode(stored.encode(), "base64"))  # noqa: S301

    def __contains__(self, key: KeyType) -> bool:
        key_str = self._encode(key)
        sql = f"SELECT EXISTS(SELECT 1 FROM {self.DATA_TABLE_NAME} WHERE key=? LIMIT 1)"
        value = self._cursor.execute(sql, (key_str,)).fetchone()[0]
        return value != 0

    def __len__(self) -> int:
        return self._cursor.execute(f"SELECT COUNT(key) FROM {self.DATA_TABLE_NAME}").fetchone()[0]

    def __iter__(self) -> Iterator[KeyType]:
        for key_str in self._cursor.execute(f"SELECT key FROM {self.DATA_TABLE_NAME}").fetchall():
            yield cast(KeyType, self._decode(key_str[0]))

    def __getitem__(self, key: KeyType) -> ValueType:
        key_str = self._encode(key)
        sql = f"SELECT value FROM {self.DATA_TABLE_NAME} WHERE key=? LIMIT 1"
        res = self._cursor.execute(sql, (key_str,)).fetchone()
        if res is None:
            raise KeyError(key)
        return cast(ValueType, self._decode(res[0]))

    def __setitem__(self, key: KeyType, value: ValueType) -> None:
        key_str = self._encode(key)
        value_str = self._encode(value)
        self._cursor.execute(f"INSERT INTO {self.DATA_TABLE_NAME} VALUES (?, ?)", (key_str, value_str))
        self._con.commit()

    def __delitem__(self, key: KeyType) -> None:
        key_str = self._encode(key)
        self._cursor.execute(f"DELETE FROM {self.DATA_TABLE_NAME} WHERE key=?", (key_str,))
        self._con.commit()

    def clear(self) -> None:
        self._cursor.execute(f"DROP TABLE IF EXISTS {self.DATA_TABLE_NAME}")
        self._check_table()

    def __del__(self):
        self._con.commit()
        self._con.close()

    # TODO: implement *_many methods
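The TODO above is left open in this commit. As an illustration only, a batched write could lean on sqlite3's executemany; the base-class *_many hooks are not visible in this diff, so the method name and signature below are assumptions, not the project's actual API:

    # Hypothetical sketch (not part of the commit): assumes a set_many hook
    # taking an iterable of (key, value) pairs exists on BaseBackend.
    def set_many(self, items: Iterable[tuple[KeyType, ValueType]]) -> None:
        rows = [(self._encode(key), self._encode(value)) for key, value in items]
        self._cursor.executemany(f"INSERT INTO {self.DATA_TABLE_NAME} VALUES (?, ?)", rows)
        self._con.commit()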
4 changes: 2 additions & 2 deletions class_cache/core.py
@@ -10,8 +10,8 @@


class Cache(ABC, MutableMapping[KeyType, ValueType]):
    def __init__(self, id_: str | int | None = None, backend: type[BaseBackend] = DEFAULT_BACKEND_TYPE) -> None:
        self._backend = backend(id_)
    def __init__(self, id_: str | int | None = None, backend_type: type[BaseBackend] = DEFAULT_BACKEND_TYPE) -> None:
        self._backend = backend_type(id_)
        self._data: dict[KeyType, ValueType] = {}
        self._to_write = set()
        self._to_delete = set()
43 changes: 43 additions & 0 deletions tests/test_backends.py
@@ -0,0 +1,43 @@
import random

import pytest

from class_cache.backends import BaseBackend, PickleBackend, SQLiteBackend

TEST_ID = "class_cache.tests.backends.id"
TEST_KEY = "class_cache.tests.backends.key"
TEST_VALUE = "class_cache.tests.backends.value"


@pytest.mark.parametrize(("backend_type"), [PickleBackend, SQLiteBackend])
class TestCore:
    def test_basic(self, backend_type: type[BaseBackend]):
        backend = backend_type(TEST_ID)
        backend.clear()
        assert TEST_KEY not in backend
        backend[TEST_KEY] = TEST_VALUE
        assert TEST_KEY in backend
        assert backend[TEST_KEY] == TEST_VALUE
        assert len(backend) == 1
        assert list(backend) == [TEST_KEY]
        del backend[TEST_KEY]
        assert TEST_KEY not in backend

    def test_write_read(self, backend_type: type[BaseBackend]):
        write_backend = backend_type(TEST_ID)
        write_backend.clear()
        assert TEST_KEY not in write_backend
        write_backend[TEST_KEY] = TEST_VALUE

        read_backend = backend_type(TEST_ID)
        assert TEST_KEY in read_backend
        assert read_backend[TEST_KEY] == TEST_VALUE


def test_max_block_size():
    size = 256
    backend = PickleBackend(TEST_ID, 1024)
    backend.clear()
    for i in range(size):
        backend[i] = random.sample(list(range(size)), size)
    assert len(list(backend.get_all_block_ids())) > 100
Empty file removed tests/test_backends/__init__.py
Empty file.
40 changes: 0 additions & 40 deletions tests/test_backends/test_pickle.py

This file was deleted.
