diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e2a80d..60e9c8b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,16 @@
 
 ### Changed
 
+## [0.5.0] – 2020-01-22
+
+### Added
+
+- `AioBaseKiwiCache` corresponding to `BaseKiwiCache`
+
+### Changed
+
+- rework `AioKiwiCache` to correspond to `KiwiCache`
+
 ## [0.4.5] – 2019-10-08
 
 ### Added
diff --git a/kw/cache/aio.py b/kw/cache/aio.py
index 939da49..2840fdc 100644
--- a/kw/cache/aio.py
+++ b/kw/cache/aio.py
@@ -1,166 +1,197 @@
 import asyncio
-from datetime import datetime, timedelta
-import logging
+from datetime import datetime
+from typing import Any, Dict, ItemsView, KeysView, Optional, ValuesView
 
 import aioredis
+import attr
 
-from . import json
+from . import utils
+from .base import BaseKiwiCache, CACHE_RECORD_ATTRIBUTES, CacheRecord, KiwiCache
 from .helpers import CallAttempt, CallAttemptException
 
 
-class AioKiwiCache:  # pylint: disable=too-many-instance-attributes
-    """Caches data from expensive sources to Redis and to memory."""
+@attr.s
+class AioBaseKiwiCache(BaseKiwiCache):
+    """Helper class for loading data from cache using asyncio and aioredis."""
 
-    instances = []  # type: List[AioKiwiCache]
-    reload_ttl = timedelta(minutes=1)
-    cache_ttl = reload_ttl * 10
-    refill_lock_ttl = timedelta(seconds=5)
-    resources_redis = None
+    resources_redis = attr.ib(None, type=aioredis.Redis, validator=attr.validators.instance_of(aioredis.Redis))
 
-    def __init__(self, resources_redis=None, logger=None, statsd=None):
-        # type: (redis.Connection, logging.Logger, datadog.DogStatsd) -> None
+    async def load_from_cache(self) -> Optional[CacheRecord]:
+        try:
+            value = await self.resources_redis.get(self._cache_key)
+        except aioredis.RedisError:
+            self._process_cache_error("kiwicache.load_failed")
+            return None
 
-        self.instances.append(self)
+        if value is None:
+            return None
 
-        if resources_redis is not None:
-            self.resources_redis = resources_redis
+        cache_data = self.json.loads(value)
+        if set(cache_data.keys()) != CACHE_RECORD_ATTRIBUTES:
+            self._log_warning("kiwicache.malformed_cache_data")
+            return None
+        return CacheRecord(**cache_data)
 
-        self.check_initialization()
+    async def save_to_cache(self, data: dict) -> None:
+        cache_record = CacheRecord(data=data)
+        try:
+            await self.resources_redis.set(
+                self._cache_key, self.json.dumps(attr.asdict(cache_record)), expire=int(self._cache_ttl.total_seconds())
+            )
+        except aioredis.RedisError:
+            self._process_cache_error("kiwicache.save_failed")
+        else:
+            self._increment_metric("success")
 
-        self.name = self.__class__.__name__
-        self.expires_at = datetime.utcnow()
-        self._data = {}  # type: dict
-        self.logger = logger if logger else logging.getLogger(__name__)
-        self.statsd = statsd
-        self.call_attempt = CallAttempt("{}.load_from_source".format(self.name.lower()))
-        self.initialized = False
+    async def _get_refill_lock(self) -> Optional[bool]:
+        try:
+            return bool(
+                await self.resources_redis.set(
+                    self._refill_lock_key,
+                    "locked",
+                    expire=int(self.refill_ttl.total_seconds()),
+                    exist=self.resources_redis.SET_IF_NOT_EXIST,
+                )
+            )
+        except aioredis.RedisError:
+            self._process_cache_error("kiwicache.refill_lock_failed")
+            return None
+
+    async def _wait_for_refill_lock(self) -> Optional[bool]:
+        start_timestamp = utils.get_current_timestamp()
+        lock_check_period = 0.5
+        while True:
+            has_lock = await self._get_refill_lock()
+            if has_lock is None or has_lock is True:
+                return has_lock
+
+            self._log_warning("kiwicache.refill_locked")
+            # let the lock owner finish, backing off exponentially up to the lock TTL
+            lock_check_period = min(lock_check_period * 2, self.refill_ttl.total_seconds())
+            await asyncio.sleep(lock_check_period)
+
+            if await self._is_refilled(start_timestamp):
+                return False
+
+    async def _is_refilled(self, timestamp: float) -> bool:
+        cache_record = await self.load_from_cache()
+        return bool(cache_record and cache_record.timestamp > timestamp)
+
+    async def _release_refill_lock(self) -> Optional[bool]:
+        try:
+            return bool(await self.resources_redis.delete(self._refill_lock_key))
+        except aioredis.RedisError:
+            self._process_cache_error("kiwicache.release_lock_failed")
+            return None
+
+    async def _prolong_cache_expiration(self) -> None:
+        try:
+            await self.resources_redis.expire(self._cache_key, timeout=int(self._cache_ttl.total_seconds()))
+        except aioredis.RedisError:
+            self._process_cache_error("kiwicache.prolong_expiration_failed")
 
-    def check_initialization(self):
-        if self.resources_redis is None:
-            raise RuntimeError("You must set a redis.Connection object")
-        if self.cache_ttl < self.reload_ttl:
-            raise RuntimeError("The cache_ttl has to be greater then reload_ttl.")
 
+@attr.s
+class AioKiwiCache(AioBaseKiwiCache, KiwiCache):
+    """Caches data from expensive sources to Redis and to memory using asyncio."""
 
-    async def acheck_initialization(self):
-        if await self.resources_redis.ttl(self.redis_key) > int(self.reload_ttl.total_seconds()):
-            await self.resources_redis.expire(self.redis_key, int(self.reload_ttl.total_seconds()))
+    instances: Dict[str, "AioKiwiCache"] = {}
 
-    @property
-    def redis_key(self):
-        return "resource:" + self.name
+    def __attrs_post_init__(self) -> None:
+        super().__attrs_post_init__()
+        self._add_instance()
+        self._call_attempt = CallAttempt("{}.load_from_source".format(self.name.lower()), self.max_attempts)
 
-    async def getitem(self, key):
+    async def getitem(self, key: Any) -> Any:
         data = await self.get_data()
         if key not in data:
             return self.__missing__(key)
         return data[key]
 
-    def __missing__(self, key):
+    def __missing__(self, key: Any) -> None:
         raise KeyError
 
-    async def get(self, key, default=None):
+    async def get(self, key: Any, default: Any = None) -> Any:
         return (await self.get_data()).get(key, default)
 
-    async def contains(self, key):
+    async def contains(self, key: Any) -> bool:
         return key in await self.get_data()
 
-    async def keys(self):
+    async def keys(self) -> KeysView:
         return (await self.get_data()).keys()
 
-    async def values(self):
+    async def values(self) -> ValuesView:
         return (await self.get_data()).values()
 
-    async def items(self):
+    async def items(self) -> ItemsView:
         return (await self.get_data()).items()
 
-    async def get_data(self):
+    async def get_data(self) -> dict:
         await self.maybe_reload()
         return self._data
 
-    async def load_from_source(self):  # type: () -> dict
-        """Get the full data bundle from our expensive source."""
-        raise NotImplementedError()
-
-    async def load_from_cache(self):  # type: () -> str
-        """Get the full data bundle from cache."""
-        return await self.resources_redis.get(self.redis_key)
-
-    async def save_to_cache(self, data):  # type: (dict) -> None
-        """Save the provided full data bundle to cache."""
-        try:
-            await self.resources_redis.set(
-                self.redis_key, json.dumps(data), expire=int(self.cache_ttl.total_seconds()) if self.cache_ttl else 0
-            )
-        except aioredis.RedisError:
-            self.statsd and self.statsd.increment("kiwicache", tags=["name:" + self.name, "status:redis_error"])
-            self.logger.exception("kiwicache.redis_exception")
-
-    async def reload(self):
-        """Load the full data bundle, from cache, or if unavailable, from source."""
-        try:
-            cache_data = await self.load_from_cache()
-        except aioredis.RedisError:
-            self.logger.exception("kiwicache.redis_exception")
-            self.statsd and self.statsd.increment("kiwicache", tags=["name:" + self.name, "status:redis_error"])
-            return
-
-        if cache_data:
-            self._data = json.loads(cache_data)
-            self.expires_at = datetime.utcnow() + self.reload_ttl
-            self.statsd and self.statsd.increment("kiwicache", tags=["name:" + self.name, "status:success"])
-        else:
-            await self.refill_cache()
-            await self.reload()
-
-    async def maybe_reload(self):  # type: () -> None
-        """Load the full data bundle if it's too old."""
-        if not self.initialized:
-            await self.acheck_initialization()
-            self.initialized = True
-
-        if not self._data or self.expires_at < datetime.utcnow():
+    async def reload(self) -> None:
+        successful_reload = await self.reload_from_cache()
+        while not successful_reload:
             try:
-                await self.reload()
+                await self.refill_cache()
             except CallAttemptException:
+                self._prolong_data_expiration()
                 raise
-            except Exception:
-                self.logger.exception("kiwicache.reload_exception")
 
-    async def get_refill_lock(self):  # type: () -> bool
-        """Lock loading from the expensive source.
+            successful_reload = await self.reload_from_cache()
+            if self.max_attempts < 0 and not successful_reload:
+                # unlimited attempts (max_attempts < 0): give up this round and keep serving stale data
+                self._prolong_data_expiration()
+                self._log_error("kiwicache.reload_failed")
+                break
 
-        This lets us avoid all workers hitting database at the same time.
+    async def reload_from_cache(self) -> bool:
+        cache_data = await self.load_from_cache()
 
-        :return: Whether we got the lock or not
-        """
-        try:
-            return bool(
-                await self.resources_redis.set(
-                    self.redis_key + ":lock",
-                    "locked",
-                    expire=int(self.refill_lock_ttl.total_seconds()),
-                    exist=self.resources_redis.SET_IF_NOT_EXIST,
-                )
-            )
-        except aioredis.RedisError:
-            pass
+        if not cache_data:
+            return False
+
+        self._data = cache_data.data
+        self._prolong_data_expiration()
+        return True
+
+    async def maybe_reload(self) -> None:
+        if self.expires_at <= datetime.utcnow() or (not self._data and not self.allow_empty_data):
+            await self.reload()
 
-    async def refill_cache(self):
-        """Cache the full data bundle in Redis."""
-        if not await self.get_refill_lock():
-            await asyncio.sleep(self.refill_lock_ttl.total_seconds())  # let the lock owner finish
+    async def _prolong_cache_expiration(self) -> None:
+        await super()._prolong_cache_expiration()
+        successful_reload = await self.reload_from_cache()
+        if not successful_reload and self._data:
+            # the cached record disappeared; re-save in-memory data so it stays available to other workers
+            await self.save_to_cache(self._data)
+
+    async def _process_refill_error(self, msg: str, exception: Exception = None) -> None:
+        await self._prolong_cache_expiration()
+        self._increment_metric("load_error")
+        self._log_exception(msg)
+        self._call_attempt.countdown()
+
+    async def refill_cache(self) -> None:
+        has_lock = await self._wait_for_refill_lock()
+        if not has_lock:
+            if has_lock is None:
+                # Redis error while acquiring the lock; count it as a failed attempt
+                self._call_attempt.countdown()
             return
 
         try:
-            source_data = await self.load_from_source()
-            if not source_data:
-                raise RuntimeError("load_from_source returned empty response!")
-
-            self.call_attempt.reset()
-            await self.save_to_cache(source_data)
-            self.statsd and self.statsd.increment("kiwicache", tags=["name:" + self.name, "status:success"])
-        except Exception:
-            self.logger.exception("kiwicache.source_exception")
-            self.call_attempt.countdown()
-            self.statsd and self.statsd.increment("kiwicache", tags=["name:" + self.name, "status:load_error"])
+            try:
+                source_data = await self.load_from_source()
+            except Exception as e:
+                await self._process_refill_error("kiwicache.source_exception", e)
+                return
+
+            if source_data or self.allow_empty_data:
+                await self.save_to_cache(source_data)
+            else:
+                await self._process_refill_error("load_from_source returned empty response!")
+        finally:
+            await self._release_refill_lock()
+
+    async def load_from_source(self) -> dict:
+        raise NotImplementedError()
diff --git a/kw/cache/base.py b/kw/cache/base.py
index 82eb887..c967278 100644
--- a/kw/cache/base.py
+++ b/kw/cache/base.py
@@ -250,7 +250,7 @@ def _increment_metric(self, status):
         :param status: metric status
         """
         if self.statsd:
-            self.statsd.increment(self.metric, tags=["name:{}".format(self.name), "status:{}".format(status)])
+            self.statsd.increment(self.metric, tags=["cache_name:{}".format(self.name), "status:{}".format(status)])
 
 
 @attr.s
diff --git a/kw/cache/helpers.py b/kw/cache/helpers.py
index f37bbbb..354689f 100644
--- a/kw/cache/helpers.py
+++ b/kw/cache/helpers.py
@@ -20,7 +20,7 @@ class CallAttempt(object):
 
     name = attr.ib(None, type=str)
     max_attempts = attr.ib(3, type=int)
-    counter = attr.ib(None, type=int)
+    counter = attr.ib(None, init=False, type=int)
 
     def __attrs_post_init__(self):
         self.reset()
diff --git a/setup.py b/setup.py
index 2acaf72..86f4462 100644
--- a/setup.py
+++ b/setup.py
@@ -8,7 +8,7 @@
 
 setup(
     name="kiwi-cache",
-    version="0.4.5",
+    version="0.5.0",
     url="https://github.com/kiwicom/kiwi-cache",
     author="Stanislav Komanec",
     author_email="platform@kiwi.com",
diff --git a/test-requirements.txt b/test-requirements.txt
index 115ff7f..38b90a7 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -6,7 +6,7 @@
 #
 atomicwrites==1.3.0  # via pytest
 attrs==19.3.0
-coverage==4.5.4
+coverage==5.0.3
 freezegun==0.3.12
 future==0.18.2
 importlib-metadata==0.23  # via pluggy, pytest
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 0000000..986a2d6
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,8 @@
+from freezegun import freeze_time
+import pytest
+
+
+@pytest.fixture
+def frozen_time():
+    with freeze_time("2000-01-01 00:00:00", ignore=["_pytest.runner"]) as ft:
+        yield ft
diff --git a/test/integration/conftest.py b/test/integration/conftest.py
index 6f5341a..ab2e3d4 100644
--- a/test/integration/conftest.py
+++ b/test/integration/conftest.py
@@ -1,6 +1,5 @@
 import os
 
-from freezegun import freeze_time
 import pytest
 import redis as redislib
 import testing.redis
@@ -20,10 +19,3 @@ def redis(redis_url):  # pylint: disable=redefined-outer-name
     client = redislib.StrictRedis.from_url(redis_url)
     yield client
     client.flushall()
-
-
-@pytest.fixture
-def frozen_time():
-    ft = freeze_time("2000-01-01 00:00:00")
-    yield ft.start()
-    ft.stop()
diff --git a/test/integration/py3/conftest.py b/test/integration/py3/conftest.py
index af51d2d..40e6130 100644
--- a/test/integration/py3/conftest.py
+++ b/test/integration/py3/conftest.py
@@ -24,8 +24,8 @@ async def get_refill_lock(self):
 
 @pytest.fixture
 def get_cache(get_aioredis, mocker):  # pylint: disable=redefined-outer-name
-    async def coroutine():
-        cache_instance = ArrayCache(await get_aioredis())
+    async def coroutine(**params):
+        cache_instance = ArrayCache(resources_redis=await get_aioredis(), **params)
         mocker.spy(cache_instance, "load_from_cache")
         mocker.spy(cache_instance, "load_from_source")
         return cache_instance
diff --git a/test/integration/py3/test_aiokiwicache.py b/test/integration/py3/test_aiokiwicache.py
index 2891ed0..c898c67 100644
--- a/test/integration/py3/test_aiokiwicache.py
+++ b/test/integration/py3/test_aiokiwicache.py
@@ -1,9 +1,11 @@
-from datetime import timedelta
+from datetime import datetime, timedelta
 import sys
 
+import aioredis
 import pytest
 
 from kw.cache.aio import AioKiwiCache as uut
+from kw.cache.helpers import CallAttemptException
 
 pytestmark = pytest.mark.skipif(sys.version_info < (3, 5), reason="requires Python 3.5+")
 
@@ -11,19 +13,20 @@
 @pytest.fixture(autouse=True)
 def clean_instances_list():
     yield None
-    uut.instances = []
+    uut.instances = {}
 
 
 @pytest.mark.asyncio
-async def test_instances_list(get_aioredis):
+async def test_instances_list(get_aioredis, get_cache):
     redis_client = await get_aioredis()
-    instance_one = uut(redis_client)
-    instance_two = uut(redis_client)
-    assert uut.instances == [instance_one, instance_two]
+    instance_one = uut(resources_redis=redis_client)
+    await get_cache()
+    instance_three = await get_cache()
+    assert uut.instances == {"resource:AioKiwiCache": instance_one, "resource:ArrayCache": instance_three}
 
 
 @pytest.mark.asyncio
-async def test_init(get_cache, mocker):
+async def test_init(get_cache):
     cache = await get_cache()
 
     assert await cache.get("a") == 101
@@ -40,53 +43,80 @@
         "2rd call: Cache is filled, so the call succeeds"
     )
 
-    # Check RuntimeError when initial values are wrong
-    # The reload_ttl has to be greater then cache_ttl
-    cache.cache_ttl = timedelta(seconds=5)
-    cache.reload_ttl = timedelta(seconds=10)
-    with pytest.raises(RuntimeError):
-        cache.check_initialization()
 
+@pytest.mark.parametrize(
+    ("invalid_params", "error"),
+    [
+        ({"resources_redis": None}, TypeError),
+        ({"resources_redis": "redis"}, TypeError),
+        ({"cache_ttl": 5}, TypeError),
+        ({"refill_ttl": None}, TypeError),
+        ({"refill_ttl": 10}, TypeError),
+        ({"metric": None}, TypeError),
+        ({"metric": 1}, TypeError),
+        ({"reload_ttl": None}, TypeError),
+        ({"reload_ttl": 5}, TypeError),
+        ({"expires_at": None}, TypeError),
+        ({"expires_at": timedelta(seconds=5)}, TypeError),
+        ({"max_attempts": None}, TypeError),
+        ({"max_attempts": "3"}, TypeError),
+        ({"cache_ttl": timedelta(seconds=5), "reload_ttl": timedelta(seconds=10)}, AttributeError),
+    ],
+)
+@pytest.mark.asyncio
+async def test_validators(get_aioredis, invalid_params, error):
+    params = {"resources_redis": await get_aioredis()}
+    params.update(invalid_params)
+    with pytest.raises(error):
+        uut(**params)
 
 
 @pytest.mark.asyncio
-async def test_error(get_cache, mocker):
-    async def noop():
-        return
+async def test_allow_empty(get_cache, mocker):
+    cache = await get_cache(allow_empty_data=True)
 
-    mocker.patch("asyncio.sleep", noop)
+    async def empty_load():
+        return {}
 
-    cache = await get_cache()
+    mocker.patch.object(cache, "load_from_source", empty_load)
+    mocker.spy(cache, "load_from_source")
+
+    assert await cache.get("a") is None
+    with pytest.raises(KeyError):
+        assert not await cache.getitem("b")
+    assert cache.load_from_source.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_error(get_cache, mocker):
+    cache = await get_cache(max_attempts=2)
     cache.load_from_source = mocker.Mock(side_effect=[Exception("Mock error"), cache.load_from_source()])
 
     assert await cache.get("a") == 101
     assert cache.load_from_source.call_count == 2, "Load should be called a second time after first call fails"
-    assert cache.load_from_cache.call_count == 3, (
+    assert cache.load_from_cache.call_count == 4, (
         "1st call: Cache is empty, so try loading from source, which fails, "
-        "2nd call: Cache is empty, so try loading from source, which succeeds and fills cache, "
-        "3rd call: Cache is filled, so the call succeeds"
succeeds" + "2nd call: Prolong cache expiration after loading fail, " + "3rd call: Cache is empty, so try loading from source, which succeeds and fills cache, " + "4th call: Cache is filled, so the call succeeds" ) @pytest.mark.asyncio -async def test_ttl(get_cache, mocker): - cache = await get_cache() - cache.load_from_source = mocker.Mock(side_effect=[cache.load_from_source()]) - - cache.cache_ttl = timedelta(hours=1) +async def test_ttl(get_cache): + cache = await get_cache(cache_ttl=timedelta(hours=1)) await cache.reload() - ttl = await cache.resources_redis.ttl(cache.redis_key) + ttl = await cache.resources_redis.ttl(cache._cache_key) assert ttl == timedelta(hours=1).total_seconds() cache.cache_ttl = timedelta(minutes=1) - await cache.reload() - await cache.acheck_initialization() - ttl = await cache.resources_redis.ttl(cache.redis_key) + await cache.refill_cache() + ttl = await cache.resources_redis.ttl(cache._cache_key) assert ttl == timedelta(minutes=1).total_seconds() @pytest.mark.asyncio -async def test_maybe_reload(get_cache, mocker, frozen_time): +async def test_maybe_reload(get_cache, frozen_time): cache = await get_cache() await cache.maybe_reload() @@ -110,10 +140,31 @@ async def test_maybe_reload(get_cache, mocker, frozen_time): @pytest.mark.asyncio async def test_missing(get_cache, mocker): cache = await get_cache() - cache.load_from_source = mocker.Mock(side_effect=[cache.load_from_source()]) with pytest.raises(KeyError): - await cache.getitem("misisng-key") + await cache.getitem("missing-key") missing = mocker.patch.object(cache, "__missing__") - await cache.getitem("misisng-key") + await cache.getitem("missing-key") assert missing.call_count == 1 + + +@pytest.mark.parametrize("max_attempts", [-1, 3]) +@pytest.mark.usefixtures("frozen_time") +@pytest.mark.asyncio +async def test_redis_error(get_cache, mocker, max_attempts): + cache = await get_cache(max_attempts=max_attempts) + cache._data = {"a": 213} + cache.expires_at = datetime.utcnow() + + mocker.patch.object(cache.resources_redis, "set", side_effect=aioredis.RedisError) + mocker.patch.object(cache.resources_redis, "get", side_effect=aioredis.RedisError) + + if max_attempts < 0: + assert await cache.get("a") == 213 + else: + with pytest.raises(CallAttemptException): + await cache.get("a") + + assert cache.load_from_source.call_count == 0 + assert cache.load_from_cache.call_count == (2 if max_attempts < 3 else max_attempts) + assert cache.expires_at == datetime.utcnow() + cache.reload_ttl diff --git a/test/unit/conftest.py b/test/unit/conftest.py index f6907d9..b66bb40 100644 --- a/test/unit/conftest.py +++ b/test/unit/conftest.py @@ -1,5 +1,4 @@ import attr -from freezegun import freeze_time import pytest from kw.cache import KiwiCache @@ -38,10 +37,3 @@ def load_from_source(self): @pytest.fixture def cache(redis): # pylint: disable=redefined-outer-name return UUTResource(resources_redis=redis) - - -@pytest.fixture -def frozen_time(): - ft = freeze_time("2000-01-01 00:00:00") - yield ft.start() - ft.stop()