diff --git a/docs/task_on_kart.rst b/docs/task_on_kart.rst index ce52c6d5..09e7a59f 100644 --- a/docs/task_on_kart.rst +++ b/docs/task_on_kart.rst @@ -286,3 +286,24 @@ If you want to dump csv file with other encodings, you can use `encoding` parame def output(self): return self.make_target('file_name.csv', processor=CsvFileProcessor(encoding='cp932')) # This will dump csv as 'cp932' which is used in Windows. + +Cache output in memory instead of dumping to files +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +You can use :class:`~InMemoryTarget` to cache output in memory instead of dumping to files by calling :func:`~gokart.target.make_inmemory_target`. + +.. code:: python + + from gokart.in_memory.target import make_inmemory_target + + def output(self): + unique_id = self.make_unique_id() if use_unique_id else None + # TaskLock is not supported in InMemoryTarget, so it's dummy + task_lock_params = make_task_lock_params( + file_path='dummy_path', + unique_id=unique_id, + redis_host=None, + redis_port=None, + redis_timeout=self.redis_timeout, + raise_task_lock_exception_on_collision=False, + ) + return make_inmemory_target('dummy_path', task_lock_params, unique_id) \ No newline at end of file diff --git a/gokart/in_memory/__init__.py b/gokart/in_memory/__init__.py new file mode 100644 index 00000000..a935fc20 --- /dev/null +++ b/gokart/in_memory/__init__.py @@ -0,0 +1 @@ +from .repository import InMemoryCacheRepository # noqa:F401 diff --git a/gokart/in_memory/cacheable.py b/gokart/in_memory/cacheable.py new file mode 100644 index 00000000..05664e49 --- /dev/null +++ b/gokart/in_memory/cacheable.py @@ -0,0 +1,94 @@ +from typing import Any + +from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock +from gokart.in_memory.repository import InMemoryCacheRepository + + +class CacheNotFoundError(OSError): ... 
+ + +class CacheableMixin: + def exists(self) -> bool: + return self._cache_exists() or super().exists() + + def load(self) -> bool: + def _load(): + cached = self._cache_exists() + if cached: + return self._cache_load() + else: + try: + loaded = super(CacheableMixin, self)._load() + except FileNotFoundError as e: + raise CacheNotFoundError from e + self._cache_dump(loaded) + return loaded + + return wrap_load_with_lock(func=_load, task_lock_params=super()._get_task_lock_params())() + + def dump(self, obj: Any, lock_at_dump: bool = True, also_dump_to_file: bool = False): + # TODO: how to sync cache and file + def _dump(obj: Any): + self._cache_dump(obj) + if also_dump_to_file: + super(CacheableMixin, self)._dump(obj) + + if lock_at_dump: + wrap_dump_with_lock(func=_dump, task_lock_params=super()._get_task_lock_params(), exist_check=self.exists)(obj) + else: + _dump(obj) + + def remove(self, also_remove_file: bool = False): + def _remove(): + if self._cache_exists(): + self._cache_remove() + if super(CacheableMixin, self).exists() and also_remove_file: + super(CacheableMixin, self)._remove() + + wrap_remove_with_lock(func=_remove, task_lock_params=super()._get_task_lock_params())() + + def last_modification_time(self): + if self._cache_exists(): + return self._cache_last_modification_time() + try: + return super()._last_modification_time() + except FileNotFoundError as e: + raise CacheNotFoundError from e + + @property + def data_key(self): + return super().path() + + def _cache_exists(self): + raise NotImplementedError + + def _cache_load(self): + raise NotImplementedError + + def _cache_dump(self, obj: Any): + raise NotImplementedError + + def _cache_remove(self): + raise NotImplementedError + + def _cache_last_modification_time(self): + raise NotImplementedError + + +class InMemoryCacheableMixin(CacheableMixin): + _repository = InMemoryCacheRepository() + + def _cache_exists(self): + return self._repository.has(self.data_key) + + def _cache_load(self): + return 
self._repository.get_value(self.data_key) + + def _cache_dump(self, obj): + return self._repository.set_value(self.data_key, obj) + + def _cache_remove(self): + self._repository.remove(self.data_key) + + def _cache_last_modification_time(self): + return self._repository.get_last_modification_time(self.data_key) diff --git a/gokart/in_memory/data.py b/gokart/in_memory/data.py new file mode 100644 index 00000000..a01c3ad2 --- /dev/null +++ b/gokart/in_memory/data.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Protocol + + +class BaseData(Protocol): ... + + +@dataclass +class InMemoryData(BaseData): + value: Any + last_modification_time: datetime + + @classmethod + def create_data(self, value: Any) -> 'InMemoryData': + return InMemoryData(value=value, last_modification_time=datetime.now()) diff --git a/gokart/in_memory/repository.py b/gokart/in_memory/repository.py new file mode 100644 index 00000000..a0ef0e76 --- /dev/null +++ b/gokart/in_memory/repository.py @@ -0,0 +1,103 @@ +from abc import ABC, abstractmethod +from typing import Any, Iterator + +from .data import InMemoryData + + +class CacheScheduler(ABC): + def __new__(cls): + if not hasattr(cls, '__instance'): + setattr(cls, '__instance', super().__new__(cls)) + return getattr(cls, '__instance') + + @abstractmethod + def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): ... + + @abstractmethod + def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): ... + + @abstractmethod + def clear(self): ... 
+ + +class DoNothingScheduler(CacheScheduler): + def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): + pass + + def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): + pass + + def clear(self): + pass + + +# TODO: ambiguous class name +class InstantScheduler(CacheScheduler): + def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): + repository.remove(key) + + def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): + pass + + def clear(self): + pass + + +class InMemoryCacheRepository: + _cache: dict[str, InMemoryData] = {} + _scheduler: CacheScheduler = DoNothingScheduler() + + def __new__(cls): + if not hasattr(cls, '__instance'): + cls.__instance = super().__new__(cls) + return cls.__instance + + @classmethod + def set_scheduler(cls, cache_scheduler: CacheScheduler): + cls._scheduler = cache_scheduler + + def get_value(self, key: str) -> Any: + data = self._get_data(key) + self._scheduler.get_hook(self, key, data) + return data.value + + def get_last_modification_time(self, key: str): + return self._get_data(key).last_modification_time + + def _get_data(self, key: str) -> InMemoryData: + return self._cache[key] + + def set_value(self, key: str, obj: Any) -> None: + data = InMemoryData.create_data(obj) + self._scheduler.set_hook(self, key, data) + self._set_data(key, data) + + def _set_data(self, key: str, data: InMemoryData): + self._cache[key] = data + + def has(self, key: str) -> bool: + return key in self._cache + + def remove(self, key: str) -> None: + assert self.has(key), f'{key} does not exist.' 
+ del self._cache[key] + + def empty(self) -> bool: + return not self._cache + + @classmethod + def clear(cls) -> None: + cls._cache.clear() + cls._scheduler.clear() + + def get_gen(self) -> Iterator[tuple[str, Any]]: + for key, data in self._cache.items(): + yield key, data.value + + @property + def size(self) -> int: + return len(self._cache) + + @property + def scheduler(self) -> CacheScheduler: + return self._scheduler diff --git a/gokart/target.py b/gokart/target.py index 88b3c942..0fdaa958 100644 --- a/gokart/target.py +++ b/gokart/target.py @@ -14,6 +14,8 @@ from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock from gokart.file_processor import FileProcessor, make_file_processor +from gokart.in_memory.cacheable import InMemoryCacheableMixin +from gokart.in_memory.repository import InMemoryCacheRepository from gokart.object_storage import ObjectStorage from gokart.zip_client_util import make_zip_client @@ -164,6 +166,12 @@ def _make_temporary_directory(self): os.makedirs(self._temporary_directory, exist_ok=True) +class CacheableSingleFileTarget(InMemoryCacheableMixin, SingleFileTarget): ... + + +class CacheableModelTarget(InMemoryCacheableMixin, ModelTarget): ... 
+ + class LargeDataFrameProcessor(object): def __init__(self, max_byte: int): self.max_byte = int(max_byte) @@ -216,12 +224,14 @@ def make_target( processor: Optional[FileProcessor] = None, task_lock_params: Optional[TaskLockParams] = None, store_index_in_feather: bool = True, + cacheable: bool = False, ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather) file_system_target = _make_file_system_target(file_path, processor=processor, store_index_in_feather=store_index_in_feather) - return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params) + cls = CacheableSingleFileTarget if cacheable else SingleFileTarget + return cls(target=file_system_target, processor=processor, task_lock_params=_task_lock_params) def make_model_target( @@ -231,14 +241,62 @@ def make_model_target( load_function, unique_id: Optional[str] = None, task_lock_params: Optional[TaskLockParams] = None, + cacheable: bool = False, ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) temporary_directory = os.path.join(temporary_directory, hashlib.md5(file_path.encode()).hexdigest()) - return ModelTarget( + cls = CacheableModelTarget if cacheable else ModelTarget + return cls( file_path=file_path, temporary_directory=temporary_directory, save_function=save_function, load_function=load_function, task_lock_params=_task_lock_params, ) + + +class InMemoryTarget(TargetOnKart): + def __init__(self, data_key: str, task_lock_param: TaskLockParams): + if task_lock_param.should_task_lock: + logger.warning(f'Redis in {self.__class__.__name__} is not supported now.') 
+ + self._data_key = data_key + self._task_lock_params = task_lock_param + self._repository = InMemoryCacheRepository() + + def _exists(self) -> bool: + return self._repository.has(self._data_key) + + def _get_task_lock_params(self) -> TaskLockParams: + return self._task_lock_params + + def _load(self) -> Any: + return self._repository.get_value(self._data_key) + + def _dump(self, obj: Any) -> None: + return self._repository.set_value(self._data_key, obj) + + def _remove(self) -> None: + self._repository.remove(self._data_key) + + def _last_modification_time(self) -> datetime: + if not self._repository.has(self._data_key): + raise ValueError(f'No object(s) which id is {self._data_key} are stored before.') + time = self._repository.get_last_modification_time(self._data_key) + return time + + def _path(self) -> str: + # TODO: this module name `_path` migit not be appropriate + return self._data_key + + +def _make_data_key(data_key: str, unique_id: Optional[str] = None): + if not unique_id: + return data_key + return data_key + '_' + unique_id + + +def make_inmemory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None): + _data_key = _make_data_key(data_key, unique_id) + return InMemoryTarget(_data_key, task_lock_params) diff --git a/gokart/task.py b/gokart/task.py index f577f64b..66f8c65f 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -20,7 +20,7 @@ import gokart import gokart.target -from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run +from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params, make_task_lock_params_for_run from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock from gokart.file_processor import FileProcessor from gokart.pandas_type_config import PandasTypeConfigMap @@ -105,6 +105,9 @@ class TaskOnKart(luigi.Task, Generic[T]): default=True, description='Check if output file exists at run. 
If exists, run() will be skipped.', significant=False ) should_lock_run: bool = ExplicitBoolParameter(default=False, significant=False, description='Whether to use redis lock or not at task run.') + cache_in_memory_by_default: bool = ExplicitBoolParameter( + default=False, significant=False, description='If `True`, output is stored on a memory instead of files unless specified.' + ) @property def priority(self): @@ -134,11 +137,13 @@ def __init__(self, *args, **kwargs): task_lock_params = make_task_lock_params_for_run(task_self=self) self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params) # type: ignore + self.make_default_target = self.make_target if not self.cache_in_memory_by_default else self.make_cache_target + def input(self) -> FlattenableItems[TargetOnKart]: return super().input() def output(self) -> FlattenableItems[TargetOnKart]: - return self.make_target() + return self.make_default_target() def requires(self) -> FlattenableItems['TaskOnKart']: tasks = self.make_task_instance_dictionary() @@ -209,7 +214,9 @@ def clone(self, cls=None, **kwargs): return cls(**new_k) - def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart: + def make_target( + self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None, cacheable: bool = False + ) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl') ) @@ -226,8 +233,28 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b ) return gokart.target.make_target( - file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather + file_path=file_path, + unique_id=unique_id, + processor=processor, + 
task_lock_params=task_lock_params, + store_index_in_feather=self.store_index_in_feather, + cacheable=cacheable, + ) + + def make_cache_target(self, data_key: Optional[str] = None, use_unique_id: bool = True): + _data_key = data_key if data_key else os.path.join(self.__module__.replace('.', '/'), type(self).__name__) + unique_id = self.make_unique_id() if use_unique_id else None + # TODO: combine with redis + task_lock_params = TaskLockParams( + redis_host=None, + redis_port=None, + redis_timeout=None, + redis_key='redis_key', + should_task_lock=False, + raise_task_lock_exception_on_collision=False, + lock_extend_seconds=-1, ) + return gokart.target.make_inmemory_target(_data_key, task_lock_params, unique_id) def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: formatted_relative_file_path = ( @@ -254,7 +281,12 @@ def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, ) def make_model_target( - self, relative_file_path: str, save_function: Callable[[Any, str], None], load_function: Callable[[str], Any], use_unique_id: bool = True + self, + relative_file_path: str, + save_function: Callable[[Any, str], None], + load_function: Callable[[str], Any], + use_unique_id: bool = True, + cacheable: bool = False, ): """ Make target for models which generate multiple files in saving, e.g. gensim.Word2Vec, Tensorflow, and so on. 
@@ -283,6 +315,7 @@ def make_model_target( save_function=save_function, load_function=load_function, task_lock_params=task_lock_params, + cacheable=cacheable, ) @overload diff --git a/test/in_memory/test_cacheable_target.py b/test/in_memory/test_cacheable_target.py new file mode 100644 index 00000000..7dd55830 --- /dev/null +++ b/test/in_memory/test_cacheable_target.py @@ -0,0 +1,307 @@ +import pickle +from time import sleep + +import luigi +import pytest + +from gokart.in_memory import InMemoryCacheRepository +from gokart.in_memory.cacheable import CacheNotFoundError +from gokart.target import CacheableModelTarget, CacheableSingleFileTarget +from gokart.task import TaskOnKart + + +class DummyTask(TaskOnKart): + namespace = __name__ + param = luigi.IntParameter() + + def run(self): + self.dump(self.param) + + +class TestCacheableSingleFileTarget: + @pytest.fixture + def task(self, tmpdir): + task = DummyTask(param=100, workspace_directory=tmpdir) + return task + + @pytest.fixture(autouse=True) + def clear_repository(self): + InMemoryCacheRepository.clear() + + def test_exists_when_cache_exists(self, task: TaskOnKart): + cacheable_target = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(cacheable_target.path(), use_unique_id=False) + assert not cacheable_target.exists() + cache_target.dump('data') + assert cacheable_target.exists() + + def test_exists_when_file_exists(self, task: TaskOnKart): + cacheable_target = task.make_target('sample.pkl', cacheable=True) + target = task.make_target('sample.pkl') + assert not cacheable_target.exists() + target.dump('data') + assert cacheable_target.exists() + + def test_load_without_cache_or_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + with pytest.raises(FileNotFoundError): + target.load() + cacheable_target = task.make_target('sample.pkl', cacheable=True) + with pytest.raises(CacheNotFoundError): + cacheable_target.load() + + def test_load_with_cache(self, 
task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + with pytest.raises(CacheNotFoundError): + cacheable_target.load() + cache_target.dump('data') + with pytest.raises(FileNotFoundError): + target.load() + assert cacheable_target.load() == 'data' + assert cache_target.load() == 'data' + + def test_load_with_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + with pytest.raises(CacheNotFoundError): + cacheable_target.load() + target.dump('data') + assert target.load() == 'data' + with pytest.raises(KeyError): + cache_target.load() + assert cacheable_target.load() == 'data' + assert cache_target.load() == 'data' + + def test_load_with_cache_and_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + with pytest.raises(CacheNotFoundError): + cacheable_target.load() + target.dump('data_in_file') + cache_target.dump('data_in_memory') + assert cacheable_target.load() == 'data_in_memory' + + def test_dump(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + cacheable_target.dump('data') + assert not target.exists() + assert cache_target.exists() + assert cacheable_target.exists() + + def test_dump_with_dump_to_file_flag(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = 
task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + cacheable_target.dump('data', also_dump_to_file=True) + assert target.exists() + assert cache_target.exists() + assert cacheable_target.exists() + + def test_remove_without_cache_or_file(self, task: TaskOnKart): + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cacheable_target.remove() + cacheable_target.remove(also_remove_file=True) + assert True + + def test_remove_with_cache(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + cache_target.dump('data') + assert cache_target.exists() + cacheable_target.remove() + assert not cache_target.exists() + + def test_remove_with_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + target.dump('data') + assert target.exists() + cacheable_target.remove() + assert target.exists() + cacheable_target.remove(also_remove_file=True) + assert not target.exists() + + def test_remove_with_cache_and_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + target.dump('file_data') + cache_target.dump('inmemory_data') + cacheable_target.remove() + assert target.exists() + assert not cache_target.exists() + + target.dump('file_data') + cache_target.dump('inmemory_data') + cacheable_target.remove(also_remove_file=True) + assert not target.exists() + assert not cache_target.exists() + + def test_last_modification_time_without_cache_and_file(self, task: TaskOnKart): + target = 
task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + with pytest.raises(FileNotFoundError): + target.last_modification_time() + with pytest.raises(ValueError): + cache_target.last_modification_time() + with pytest.raises(CacheNotFoundError): + cacheable_target.last_modification_time() + + def test_last_modification_time_with_cache(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + cache_target.dump('data') + assert cacheable_target.last_modification_time() == cache_target.last_modification_time() + + def test_last_modification_time_with_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + target.dump('data') + assert cacheable_target.last_modification_time() == target.last_modification_time() + + def test_last_modification_time_with_cache_and_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + target.dump('file_data') + sleep(0.1) + cache_target.dump('inmemory_data') + assert cacheable_target.last_modification_time() == cache_target.last_modification_time() + + +class DummyModule: + def func(self): + return 'hello world' + + +class DummyModuleA: + def func_a(self): + return 'hello world' + + +class DummyModuleB: + def func_b(self): + return 'hello world' + + +def _save_func(obj, path): + with open(path, 'wb') as f: + pickle.dump(obj, f) + + +def _load_func(path): + with open(path, 'rb') as f: + return pickle.load(f) + + +class 
TestCacheableModelTarget: + @pytest.fixture + def task(self, tmpdir): + task = DummyTask(param=100, workspace_directory=tmpdir) + return task + + @pytest.fixture(autouse=True) + def clear_repository(self): + InMemoryCacheRepository.clear() + + def test_exists_without_cache_or_file(self, task: TaskOnKart): + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + assert not target.exists() + assert not cache_target.exists() + assert not cacheable_target.exists() + + def test_exists_with_file(self, task: TaskOnKart): + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) + assert not cacheable_target.exists() + module = DummyModule() + target.dump(module) + assert cacheable_target.exists() + + def test_exists_with_cache(self, task: TaskOnKart): + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + assert not cacheable_target.exists() + module = DummyModule() + cache_target.dump(module) + assert not target.exists() + assert cacheable_target.exists() + + def test_load_without_cache_or_file(self, task: TaskOnKart): + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', 
save_function=_save_func, load_function=_load_func, cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + with pytest.raises(FileNotFoundError): + target.load() + with pytest.raises(KeyError): + cache_target.load() + with pytest.raises(CacheNotFoundError): + cacheable_target.load() + + def test_load_with_cache(self, task: TaskOnKart): + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + module = DummyModule() + cache_target.dump(module) + with pytest.raises(FileNotFoundError): + target.load() + assert isinstance(cache_target.load(), DummyModule) + assert isinstance(cacheable_target.load(), DummyModule) + + def test_load_with_file(self, task: TaskOnKart): + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + module = DummyModule() + target.dump(module) + assert isinstance(target.load(), DummyModule) + assert not cache_target.exists() + assert isinstance(cacheable_target.load(), DummyModule) + assert cache_target.exists() + + def test_load_with_cache_and_file(self, task: TaskOnKart): + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + inmemory_module_cls, file_module_cls = 
DummyModule, DummyModuleA + inmemory_module, file_module = inmemory_module_cls(), file_module_cls() + target.dump(file_module) + cache_target.dump(inmemory_module) + assert isinstance(target.load(), file_module_cls) + assert isinstance(cache_target.load(), inmemory_module_cls) + assert isinstance(cacheable_target.load(), inmemory_module_cls) + + def test_dump(self, task: TaskOnKart): + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target: CacheableModelTarget = task.make_model_target( + relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True + ) + module = DummyModule() + cacheable_target.dump(module) + assert not target.exists() + assert cacheable_target.exists() + + def test_dump_with_dump_to_file_flag(self, task: TaskOnKart): + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target: CacheableModelTarget = task.make_model_target( + relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True + ) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + module = DummyModule() + cacheable_target.dump(module, also_dump_to_file=True) + assert target.exists() + assert cache_target.exists() + assert cacheable_target.exists() diff --git a/test/in_memory/test_in_memory_target.py b/test/in_memory/test_in_memory_target.py new file mode 100644 index 00000000..0716c514 --- /dev/null +++ b/test/in_memory/test_in_memory_target.py @@ -0,0 +1,57 @@ +from datetime import datetime +from time import sleep + +import pytest + +from gokart.conflict_prevention_lock.task_lock import TaskLockParams +from gokart.in_memory import InMemoryCacheRepository +from gokart.target import InMemoryTarget, make_inmemory_target + + +class TestInMemoryTarget: + @pytest.fixture + def task_lock_params(self): + return TaskLockParams( + 
redis_host=None, + redis_port=None, + redis_timeout=None, + redis_key='dummy', + should_task_lock=False, + raise_task_lock_exception_on_collision=False, + lock_extend_seconds=0, + ) + + @pytest.fixture + def target(self, task_lock_params: TaskLockParams): + return make_inmemory_target(data_key='dummy_key', task_lock_params=task_lock_params) + + @pytest.fixture(autouse=True) + def clear_repo(self): + InMemoryCacheRepository().clear() + + def test_dump_and_load_data(self, target: InMemoryTarget): + dumped = 'dummy_data' + target.dump(dumped) + loaded = target.load() + assert loaded == dumped + + def test_exist(self, target: InMemoryTarget): + assert not target.exists() + target.dump('dummy_data') + assert target.exists() + + def test_last_modified_time(self, target: InMemoryTarget): + input = 'dummy_data' + target.dump(input) + time = target.last_modification_time() + assert isinstance(time, datetime) + + sleep(0.1) + another_input = 'another_data' + target.dump(another_input) + another_time = target.last_modification_time() + assert time < another_time + + target.remove() + with pytest.raises(ValueError): + assert target.last_modification_time() diff --git a/test/in_memory/test_repository.py b/test/in_memory/test_repository.py new file mode 100644 index 00000000..14449937 --- /dev/null +++ b/test/in_memory/test_repository.py @@ -0,0 +1,95 @@ +import time + +import pytest + +from gokart.in_memory import InMemoryCacheRepository +from gokart.in_memory.repository import InstantScheduler + +dummy_num = 100 + + +class TestInMemoryCacheRepository: + @pytest.fixture + def repo(self): + repo = InMemoryCacheRepository() + repo.clear() + return repo + + def test_set(self, repo: InMemoryCacheRepository): + repo.set_value('dummy_key', dummy_num) + assert repo.size == 1 + for key, value in repo.get_gen(): + assert (key, value) == ('dummy_key', dummy_num) + + repo.set_value('another_key', 'another_value') + assert repo.size == 2 + + def test_get(self, repo: 
InMemoryCacheRepository): + repo.set_value('dummy_key', dummy_num) + repo.set_value('another_key', 'another_value') + + """Raise Error when key doesn't exist.""" + with pytest.raises(KeyError): + repo.get_value('not_exist_key') + + assert repo.get_value('dummy_key') == dummy_num + assert repo.get_value('another_key') == 'another_value' + + def test_empty(self, repo: InMemoryCacheRepository): + assert repo.empty() + repo.set_value('dummmy_key', dummy_num) + assert not repo.empty() + + def test_has(self, repo: InMemoryCacheRepository): + assert not repo.has('dummy_key') + repo.set_value('dummy_key', dummy_num) + assert repo.has('dummy_key') + assert not repo.has('not_exist_key') + + def test_remove(self, repo: InMemoryCacheRepository): + repo.set_value('dummy_key', dummy_num) + + with pytest.raises(AssertionError): + repo.remove('not_exist_key') + + repo.remove('dummy_key') + assert not repo.has('dummy_key') + + def test_last_modification_time(self, repo: InMemoryCacheRepository): + repo.set_value('dummy_key', dummy_num) + date1 = repo.get_last_modification_time('dummy_key') + time.sleep(0.1) + repo.set_value('dummy_key', dummy_num) + date2 = repo.get_last_modification_time('dummy_key') + assert date1 < date2 + + +class TestInstantScheduler: + @pytest.fixture(autouse=True) + def set_scheduler(self): + scheduler = InstantScheduler() + InMemoryCacheRepository.set_scheduler(scheduler) + + @pytest.fixture(autouse=True) + def clear_cache(self): + InMemoryCacheRepository.clear() + + @pytest.fixture + def repo(self): + repo = InMemoryCacheRepository() + return repo + + def test_identity(self): + scheduler1 = InstantScheduler() + scheduler2 = InstantScheduler() + assert id(scheduler1) == id(scheduler2) + + def test_scheduler_type(self, repo: InMemoryCacheRepository): + assert isinstance(repo.scheduler, InstantScheduler) + + def test_volatility(self, repo: InMemoryCacheRepository): + assert repo.empty() + repo.set_value('dummy_key', 100) + assert repo.has('dummy_key') + 
repo.get_value('dummy_key') + assert not repo.has('dummy_key') diff --git a/test/in_memory/test_task_cached_in_memory.py b/test/in_memory/test_task_cached_in_memory.py new file mode 100644 index 00000000..e57d5444 --- /dev/null +++ b/test/in_memory/test_task_cached_in_memory.py @@ -0,0 +1,118 @@ +from typing import Optional, Type, Union + +import luigi +import pytest + +import gokart +from gokart.in_memory import InMemoryCacheRepository +from gokart.target import InMemoryTarget, SingleFileTarget + + +class DummyTask(gokart.TaskOnKart): + task_namespace = __name__ + param: str = luigi.Parameter() + + def run(self): + self.dump(self.param) + + +class DummyTaskWithDependencies(gokart.TaskOnKart): + task_namespace = __name__ + task: list[gokart.TaskOnKart[str]] = gokart.ListTaskInstanceParameter() + + def run(self): + result = ','.join(self.load()) + self.dump(result) + + +class DumpIntTask(gokart.TaskOnKart[int]): + task_namespace = __name__ + value: int = luigi.IntParameter() + + def run(self): + self.dump(self.value) + + +class AddTask(gokart.TaskOnKart[Union[int, float]]): + a: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter() + b: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter() + + def requires(self): + return dict(a=self.a, b=self.b) + + def run(self): + a = self.load(self.a) + b = self.load(self.b) + self.dump(a + b) + + +class TestTaskOnKartWithCache: + @pytest.fixture(autouse=True) + def clear_repository(self): + InMemoryCacheRepository().clear() + + @pytest.mark.parametrize('data_key', ['sample_key', None]) + @pytest.mark.parametrize('use_unique_id', [True, False]) + def test_key_identity(self, data_key: Optional[str], use_unique_id: bool): + task = DummyTask(param='param') + ext = '.pkl' + relative_file_path = data_key + ext if data_key else None + target = task.make_target(relative_file_path=relative_file_path, use_unique_id=use_unique_id) + cached_target = task.make_cache_target(data_key=data_key, use_unique_id=use_unique_id) + + target_path 
= target.path().removeprefix(task.workspace_directory).removesuffix(ext).strip('/') + assert cached_target.path() == target_path + + def test_make_cached_target(self): + task = DummyTask(param='param') + target = task.make_cache_target() + assert isinstance(target, InMemoryTarget) + + @pytest.mark.parametrize(['cache_in_memory_by_default', 'target_type'], [[True, InMemoryTarget], [False, SingleFileTarget]]) + def test_make_default_target(self, cache_in_memory_by_default: bool, target_type: Type[gokart.target.TargetOnKart]): + task = DummyTask(param='param', cache_in_memory_by_default=cache_in_memory_by_default) + target = task.output() + assert isinstance(target, target_type) + + def test_complete_with_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=tmpdir) + assert not task.complete() + file_target = task.make_target() + file_target.dump('data') + assert not task.complete() + cache_target = task.make_cache_target() + cache_target.dump('data') + assert task.complete() + + def test_complete_without_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', workspace_directory=tmpdir) + assert not task.complete() + cache_target = task.make_cache_target() + cache_target.dump('data') + assert not task.complete() + file_target = task.make_target() + file_target.dump('data') + assert task.complete() + + def test_dump_with_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=tmpdir) + file_target = task.make_target() + cache_target = task.make_cache_target() + task.dump('data') + assert not file_target.exists() + assert cache_target.exists() + + def test_dump_without_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', workspace_directory=tmpdir) + file_target = task.make_target() + cache_target = task.make_cache_target() + task.dump('data') + assert file_target.exists() + assert not cache_target.exists() + + 
def test_gokart_build(self): + task = AddTask( + a=DumpIntTask(value=2, cache_in_memory_by_default=True), b=DumpIntTask(value=3, cache_in_memory_by_default=True), cache_in_memory_by_default=True + ) + output = gokart.build(task, reset_register=False) + assert output == 5