Feat/in memory/cacheable mixin #440
Draft · wants to merge 16 commits into base: master · changes from all commits
21 changes: 21 additions & 0 deletions docs/task_on_kart.rst
@@ -286,3 +286,24 @@ If you want to dump csv file with other encodings, you can use `encoding` parameter
    def output(self):
        return self.make_target('file_name.csv', processor=CsvFileProcessor(encoding='cp932'))
        # This will dump csv as 'cp932' which is used in Windows.

Cache output in memory instead of dumping to files
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
You can use :class:`~gokart.target.InMemoryTarget` to cache output in memory instead of dumping it to a file. Create the target with :func:`~gokart.target.make_inmemory_target`.

.. code:: python

    from gokart.conflict_prevention_lock.task_lock import make_task_lock_params
    from gokart.target import make_inmemory_target

    def output(self):
        unique_id = self.make_unique_id() if use_unique_id else None
        # TaskLock is not supported by InMemoryTarget, so dummy lock parameters are sufficient.
        task_lock_params = make_task_lock_params(
            file_path='dummy_path',
            unique_id=unique_id,
            redis_host=None,
            redis_port=None,
            redis_timeout=self.redis_timeout,
            raise_task_lock_exception_on_collision=False,
        )
        return make_inmemory_target('dummy_path', task_lock_params, unique_id)
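
The diff below also adds a ``cacheable`` flag to ``make_target`` and ``make_model_target`` in ``gokart/target.py``. A minimal sketch of how a task might opt into the in-memory cache through that flag, assuming it lands as shown in this PR (``ExampleTask`` and the file name are hypothetical):

.. code:: python

    import os

    import gokart
    from gokart.target import make_target

    class ExampleTask(gokart.TaskOnKart):
        def output(self):
            # cacheable=True returns CacheableSingleFileTarget instead of SingleFileTarget,
            # so repeated load() calls in this process are served from the in-memory cache.
            return make_target(
                file_path=os.path.join(self.workspace_directory, 'example.pkl'),
                unique_id=self.make_unique_id(),
                cacheable=True,
            )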
1 change: 1 addition & 0 deletions gokart/in_memory/__init__.py
@@ -0,0 +1 @@
from .repository import InMemoryCacheRepository # noqa:F401
94 changes: 94 additions & 0 deletions gokart/in_memory/cacheable.py
@@ -0,0 +1,94 @@
from typing import Any

from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock
from gokart.in_memory.repository import InMemoryCacheRepository


class CacheNotFoundError(OSError): ...


class CacheableMixin:
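    """Mixin that layers a cache in front of a concrete target.

    exists/load/dump/remove consult the cache (via the _cache_* hooks) first and
    fall back to the wrapped target through cooperative super() calls.
    """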
def exists(self) -> bool:
return self._cache_exists() or super().exists()

    def load(self) -> Any:
def _load():
cached = self._cache_exists()
if cached:
return self._cache_load()
else:
try:
loaded = super(CacheableMixin, self)._load()
except FileNotFoundError as e:
raise CacheNotFoundError from e
self._cache_dump(loaded)
return loaded

return wrap_load_with_lock(func=_load, task_lock_params=super()._get_task_lock_params())()

def dump(self, obj: Any, lock_at_dump: bool = True, also_dump_to_file: bool = False):
# TODO: how to sync cache and file
def _dump(obj: Any):
self._cache_dump(obj)
if also_dump_to_file:
super(CacheableMixin, self)._dump(obj)

if lock_at_dump:
wrap_dump_with_lock(func=_dump, task_lock_params=super()._get_task_lock_params(), exist_check=self.exists)(obj)
else:
_dump(obj)

def remove(self, also_remove_file: bool = False):
def _remove():
if self._cache_exists():
self._cache_remove()
if super(CacheableMixin, self).exists() and also_remove_file:
super(CacheableMixin, self)._remove()

wrap_remove_with_lock(func=_remove, task_lock_params=super()._get_task_lock_params())()

def last_modification_time(self):
if self._cache_exists():
return self._cache_last_modification_time()
try:
return super()._last_modification_time()
except FileNotFoundError as e:
raise CacheNotFoundError from e

@property
def data_key(self):
return super().path()

def _cache_exists(self):
raise NotImplementedError

def _cache_load(self):
raise NotImplementedError

def _cache_dump(self, obj: Any):
raise NotImplementedError

def _cache_remove(self):
raise NotImplementedError

def _cache_last_modification_time(self):
raise NotImplementedError


class InMemoryCacheableMixin(CacheableMixin):
_repository = InMemoryCacheRepository()

def _cache_exists(self):
return self._repository.has(self.data_key)

def _cache_load(self):
return self._repository.get_value(self.data_key)

def _cache_dump(self, obj):
return self._repository.set_value(self.data_key, obj)

def _cache_remove(self):
self._repository.remove(self.data_key)

def _cache_last_modification_time(self):
return self._repository.get_last_modification_time(self.data_key)
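
For readers unfamiliar with this cooperative-mixin pattern, here is a minimal, self-contained sketch of the same cache-read-through idea; the ``Toy*`` classes are purely illustrative and not part of gokart:

.. code:: python

    from typing import Any


    class ToyCacheMixin:
        """Serve load() from an in-process dict, falling back to the wrapped target."""

        _cache: dict[str, Any] = {}

        def load(self) -> Any:
            key = self.path()
            if key in self._cache:
                return self._cache[key]
            loaded = super().load()     # delegate to the concrete target
            self._cache[key] = loaded   # populate the cache for subsequent loads
            return loaded


    class ToyFileTarget:
        def __init__(self, path: str, data: Any) -> None:
            self._path, self._data = path, data

        def path(self) -> str:
            return self._path

        def load(self) -> Any:
            return self._data


    class ToyCacheableTarget(ToyCacheMixin, ToyFileTarget):
        """The mixin comes first in the MRO, as CacheableSingleFileTarget does later in this PR."""


    target = ToyCacheableTarget('out.pkl', data=42)
    assert target.load() == 42  # first call reaches ToyFileTarget.load
    assert target.load() == 42  # second call is answered from ToyCacheMixin._cache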
16 changes: 16 additions & 0 deletions gokart/in_memory/data.py
@@ -0,0 +1,16 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Protocol


class BaseData(Protocol): ...


@dataclass
class InMemoryData(BaseData):
value: Any
last_modification_time: datetime

    @classmethod
    def create_data(cls, value: Any) -> 'InMemoryData':
        return cls(value=value, last_modification_time=datetime.now())
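
``InMemoryData.create_data`` simply timestamps the stored value, e.g.:

.. code:: python

    from gokart.in_memory.data import InMemoryData

    data = InMemoryData.create_data({'a': 1})
    assert data.value == {'a': 1}
    # data.last_modification_time was set with datetime.now() at creation time.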
103 changes: 103 additions & 0 deletions gokart/in_memory/repository.py
@@ -0,0 +1,103 @@
from abc import ABC, abstractmethod
from typing import Any, Iterator

from .data import InMemoryData


class CacheScheduler(ABC):
def __new__(cls):
if not hasattr(cls, '__instance'):
setattr(cls, '__instance', super().__new__(cls))
return getattr(cls, '__instance')

@abstractmethod
def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): ...

@abstractmethod
def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): ...

@abstractmethod
def clear(self): ...


class DoNothingScheduler(CacheScheduler):
def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData):
pass

def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData):
pass

def clear(self):
pass


# TODO: ambiguous class name
class InstantScheduler(CacheScheduler):
def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData):
repository.remove(key)

def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData):
pass

def clear(self):
pass


class InMemoryCacheRepository:
_cache: dict[str, InMemoryData] = {}
_scheduler: CacheScheduler = DoNothingScheduler()

    def __new__(cls):
        # Use setattr/getattr with a fixed attribute name so the singleton check is
        # not defeated by name mangling of `cls.__instance` inside the class body.
        if not hasattr(cls, '__instance'):
            setattr(cls, '__instance', super().__new__(cls))
        return getattr(cls, '__instance')

@classmethod
def set_scheduler(cls, cache_scheduler: CacheScheduler):
cls._scheduler = cache_scheduler

def get_value(self, key: str) -> Any:
data = self._get_data(key)
self._scheduler.get_hook(self, key, data)
return data.value

def get_last_modification_time(self, key: str):
return self._get_data(key).last_modification_time

def _get_data(self, key: str) -> InMemoryData:
return self._cache[key]

def set_value(self, key: str, obj: Any) -> None:
data = InMemoryData.create_data(obj)
self._scheduler.set_hook(self, key, data)
self._set_data(key, data)

def _set_data(self, key: str, data: InMemoryData):
self._cache[key] = data

def has(self, key: str) -> bool:
return key in self._cache

def remove(self, key: str) -> None:
assert self.has(key), f'{key} does not exist.'
del self._cache[key]

def empty(self) -> bool:
return not self._cache

@classmethod
def clear(cls) -> None:
cls._cache.clear()
cls._scheduler.clear()

def get_gen(self) -> Iterator[tuple[str, Any]]:
for key, data in self._cache.items():
yield key, data.value

@property
def size(self) -> int:
return len(self._cache)

@property
def scheduler(self) -> CacheScheduler:
return self._scheduler
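
A short usage sketch of the repository API introduced above, based only on this diff; it shows the class-level cache shared across instances and ``InstantScheduler``'s evict-on-first-read behaviour:

.. code:: python

    from gokart.in_memory.repository import InMemoryCacheRepository, InstantScheduler

    repo = InMemoryCacheRepository()
    repo.set_value('example-key', {'rows': 3})

    # The cache dict is a class attribute, so a second instance sees the same entry.
    assert InMemoryCacheRepository().has('example-key')
    assert repo.get_value('example-key') == {'rows': 3}

    # With InstantScheduler, get_hook removes the entry as soon as it is read once.
    InMemoryCacheRepository.set_scheduler(InstantScheduler())
    repo.set_value('one-shot', 'value')
    assert repo.get_value('one-shot') == 'value'
    assert not repo.has('one-shot')

    InMemoryCacheRepository.clear()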
62 changes: 60 additions & 2 deletions gokart/target.py
@@ -14,6 +14,8 @@
from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock
from gokart.file_processor import FileProcessor, make_file_processor
from gokart.in_memory.cacheable import InMemoryCacheableMixin
from gokart.in_memory.repository import InMemoryCacheRepository
from gokart.object_storage import ObjectStorage
from gokart.zip_client_util import make_zip_client

@@ -164,6 +166,12 @@ def _make_temporary_directory(self):
os.makedirs(self._temporary_directory, exist_ok=True)


class CacheableSingleFileTarget(InMemoryCacheableMixin, SingleFileTarget): ...


class CacheableModelTarget(InMemoryCacheableMixin, ModelTarget): ...


class LargeDataFrameProcessor(object):
def __init__(self, max_byte: int):
self.max_byte = int(max_byte)
@@ -216,12 +224,14 @@ def make_target(
processor: Optional[FileProcessor] = None,
task_lock_params: Optional[TaskLockParams] = None,
store_index_in_feather: bool = True,
cacheable: bool = False,
) -> TargetOnKart:
_task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id)
file_path = _make_file_path(file_path, unique_id)
processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather)
file_system_target = _make_file_system_target(file_path, processor=processor, store_index_in_feather=store_index_in_feather)
return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params)
cls = CacheableSingleFileTarget if cacheable else SingleFileTarget
return cls(target=file_system_target, processor=processor, task_lock_params=_task_lock_params)


def make_model_target(
@@ -231,14 +241,62 @@ def make_model_target(
load_function,
unique_id: Optional[str] = None,
task_lock_params: Optional[TaskLockParams] = None,
cacheable: bool = False,
) -> TargetOnKart:
_task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id)
file_path = _make_file_path(file_path, unique_id)
temporary_directory = os.path.join(temporary_directory, hashlib.md5(file_path.encode()).hexdigest())
return ModelTarget(
cls = CacheableModelTarget if cacheable else ModelTarget
return cls(
file_path=file_path,
temporary_directory=temporary_directory,
save_function=save_function,
load_function=load_function,
task_lock_params=_task_lock_params,
)


class InMemoryTarget(TargetOnKart):
def __init__(self, data_key: str, task_lock_param: TaskLockParams):
if task_lock_param.should_task_lock:
            logger.warning(f'Task locking via Redis is not currently supported by {self.__class__.__name__}.')

self._data_key = data_key
self._task_lock_params = task_lock_param
self._repository = InMemoryCacheRepository()

def _exists(self) -> bool:
return self._repository.has(self._data_key)

def _get_task_lock_params(self) -> TaskLockParams:
return self._task_lock_params

def _load(self) -> Any:
return self._repository.get_value(self._data_key)

def _dump(self, obj: Any) -> None:
return self._repository.set_value(self._data_key, obj)

def _remove(self) -> None:
self._repository.remove(self._data_key)

def _last_modification_time(self) -> datetime:
if not self._repository.has(self._data_key):
            raise ValueError(f'No object with id {self._data_key} has been stored.')
time = self._repository.get_last_modification_time(self._data_key)
return time

def _path(self) -> str:
        # TODO: this method name `_path` might not be appropriate
return self._data_key


def _make_data_key(data_key: str, unique_id: Optional[str] = None):
if not unique_id:
return data_key
return data_key + '_' + unique_id


def make_inmemory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None):
_data_key = _make_data_key(data_key, unique_id)
return InMemoryTarget(_data_key, task_lock_params)
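
To make the new public surface concrete, a small round trip through ``make_inmemory_target`` as defined above; the dummy lock parameters mirror the docs example earlier in this PR, and the key and unique id here are illustrative:

.. code:: python

    from gokart.conflict_prevention_lock.task_lock import make_task_lock_params
    from gokart.target import make_inmemory_target

    # InMemoryTarget does not use task locking, so dummy lock parameters are enough here.
    task_lock_params = make_task_lock_params(
        file_path='dummy_path',
        unique_id='0123abcd',
        redis_host=None,
        redis_port=None,
        redis_timeout=None,
        raise_task_lock_exception_on_collision=False,
    )

    target = make_inmemory_target('dummy_path', task_lock_params, unique_id='0123abcd')
    target.dump({'score': 0.9})
    assert target.exists()
    assert target.load() == {'score': 0.9}
    target.remove()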