From 9068899d38d994067ffdb850bed8eb1c17cbd3fc Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Sun, 9 Feb 2025 23:00:41 +0900 Subject: [PATCH 01/16] feat: add in-memory target --- gokart/in_memory/__init__.py | 2 + gokart/in_memory/repository.py | 38 +++++++++++++++++ gokart/in_memory/target.py | 50 +++++++++++++++++++++++ test/in_memory/test_in_memory_target.py | 32 +++++++++++++++ test/in_memory/test_repository.py | 54 +++++++++++++++++++++++++ 5 files changed, 176 insertions(+) create mode 100644 gokart/in_memory/__init__.py create mode 100644 gokart/in_memory/repository.py create mode 100644 gokart/in_memory/target.py create mode 100644 test/in_memory/test_in_memory_target.py create mode 100644 test/in_memory/test_repository.py diff --git a/gokart/in_memory/__init__.py b/gokart/in_memory/__init__.py new file mode 100644 index 00000000..dfab283d --- /dev/null +++ b/gokart/in_memory/__init__.py @@ -0,0 +1,2 @@ +from .repository import InMemeryCacheRepository +from .target import InMemoryTarget, make_inmemory_target \ No newline at end of file diff --git a/gokart/in_memory/repository.py b/gokart/in_memory/repository.py new file mode 100644 index 00000000..efb07929 --- /dev/null +++ b/gokart/in_memory/repository.py @@ -0,0 +1,38 @@ +import abc +from typing import Any + +class BaseRepository(abc.ABC): + ... 
+ +class InMemeryCacheRepository(BaseRepository): + _cache: dict[str, Any] = {} + def __init__(self): + pass + + def get(self, id: str): + return self._cache[id] + + def set(self, id: str, obj: Any): + assert not self.has(id) + self._cache[id] = obj + + def has(self, id: str): + return id in self._cache + + def remove_by_id(self, id: str): + assert self.has(id) + del self._cache[id] + + def empty(self): + return not self._cache + + def clear(self): + self._cache.clear() + + def get_gen(self): + for key, value in self._cache.items(): + yield key, value + + @property + def size(self): + return len(self._cache) \ No newline at end of file diff --git a/gokart/in_memory/target.py b/gokart/in_memory/target.py new file mode 100644 index 00000000..8d9dda3b --- /dev/null +++ b/gokart/in_memory/target.py @@ -0,0 +1,50 @@ +from gokart.target import TargetOnKart, TaskLockParams +from gokart.in_memory.repository import InMemeryCacheRepository +from datetime import datetime + +_repository = InMemeryCacheRepository() + +# TODO: unnecessary params in task_lock_param expecially regarding redies +class InMemoryTarget(TargetOnKart): + def __init__( + self, + id: str, + task_lock_param: TaskLockParams + ): + self._id = id + self._task_lock_params = task_lock_param + self._last_modification_time_value: None | datetime = None + + def _exists(self): + # import pdb;pdb.set_trace() + return _repository.has(self._id) + + def _get_task_lock_params(self): + return self._task_lock_params + + def _load(self): + # import pdb + # pdb.set_trace() + return _repository.get(self._id) + + def _dump(self, obj): + return _repository.set(self._id, obj) + + def _remove(self) -> None: + _repository.remove_by_id(self._id) + + def _last_modification_time(self) -> datetime: + if self._last_modification_time_value is None: + raise ValueError(f"No object(s) which id is {self._id} are stored before.") + self._last_modification_time_value + + def _path(self): + # TODO: this module name `_path` migit not be 
appropriate + return self._id + + @property + def id(self): + return self._id + +def make_inmemory_target(target_key: str, task_lock_params: TaskLockParams | None = None): + return InMemoryTarget(target_key, task_lock_params) diff --git a/test/in_memory/test_in_memory_target.py b/test/in_memory/test_in_memory_target.py new file mode 100644 index 00000000..4b173610 --- /dev/null +++ b/test/in_memory/test_in_memory_target.py @@ -0,0 +1,32 @@ +from gokart.conflict_prevention_lock.task_lock import TaskLockParams +from gokart.in_memory import make_inmemory_target, InMemoryTarget, InMemeryCacheRepository +import pytest + +class TestInMemoryTarget: + @pytest.fixture + def task_lock_params(self): + return TaskLockParams( + redis_host=None, + redis_port=None, + redis_timeout=None, + redis_key='dummy', + should_task_lock=False, + raise_task_lock_exception_on_collision=False, + lock_extend_seconds=0 + ) + @pytest.fixture + def target(self, task_lock_params: TaskLockParams): + return make_inmemory_target(target_key='dummy_task_id', task_lock_params=task_lock_params) + + @pytest.fixture(autouse=True) + def clear_repo(self): + InMemeryCacheRepository().clear() + + def test_dump_and_load_data(self, target: InMemoryTarget): + dumped = 'dummy_data' + target.dump(dumped) + loaded = target.load() + assert loaded == dumped + + with pytest.raises(AssertionError): + target.dump('another_data') \ No newline at end of file diff --git a/test/in_memory/test_repository.py b/test/in_memory/test_repository.py new file mode 100644 index 00000000..2610b80a --- /dev/null +++ b/test/in_memory/test_repository.py @@ -0,0 +1,54 @@ +from gokart.in_memory import InMemeryCacheRepository as Repo +import pytest + +dummy_num = 100 + +class TestInMemoryCacheRepository: + @pytest.fixture + def repo(self): + repo = Repo() + repo.clear() + return repo + + def test_set(self, repo: Repo): + repo.set("dummy_id", dummy_num) + assert repo.size == 1 + for key, value in repo.get_gen(): + assert (key, value) == 
("dummy_id", dummy_num) + + with pytest.raises(AssertionError): + repo.set('dummy_id', "dummy_value") + + repo.set('another_id', 'another_value') + assert repo.size == 2 + + def test_get(self, repo: Repo): + repo.set('dummy_id', dummy_num) + repo.set('another_id', 'another_val') + + """Raise Error when key doesn't exist.""" + with pytest.raises(KeyError): + repo.get('not_exist_id') + + assert repo.get('dummy_id') == dummy_num + assert repo.get('another_id') == 'another_val' + + def test_empty(self, repo: Repo): + assert repo.empty() + repo.set("dummmy_id", dummy_num) + assert not repo.empty() + + def test_has(self, repo: Repo): + assert not repo.has('dummy_id') + repo.set('dummy_id', dummy_num) + assert repo.has('dummy_id') + + def test_remove_by_id(self, repo: Repo): + repo.set('dummy_id', dummy_num) + + with pytest.raises(AssertionError): + repo.remove_by_id('not_exist_id') + + assert repo.has('dummy_id') + repo.remove_by_id('dummy_id') + assert not repo.has('dummy_id') From 58aa9a56e59504b709a3bb9ad3c4ec7a22d7a794 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Mon, 10 Feb 2025 10:42:15 +0900 Subject: [PATCH 02/16] feat: add a data format to be stored in `Repository` feat: last_modification_time feature in `InMemoryTarget` style: add some type hints fix: fix typo in `InMemoryCacheRepository` test: add some tests for `InMemoryTarget` and `InMemoryCacheRepository` --- gokart/in_memory/__init__.py | 2 +- gokart/in_memory/data.py | 16 ++++++++ gokart/in_memory/repository.py | 39 +++++++++++-------- gokart/in_memory/target.py | 51 ++++++++++++------------ test/in_memory/test_in_memory_target.py | 34 ++++++++++++---- test/in_memory/test_repository.py | 52 ++++++++++++++----------- 6 files changed, 120 insertions(+), 74 deletions(-) create mode 100644 gokart/in_memory/data.py diff --git a/gokart/in_memory/__init__.py b/gokart/in_memory/__init__.py index dfab283d..f3685297 100644 --- a/gokart/in_memory/__init__.py +++ b/gokart/in_memory/__init__.py @@ -1,2 +1,2 @@ -from 
.repository import InMemeryCacheRepository +from .repository import InMemoryCacheRepository from .target import InMemoryTarget, make_inmemory_target \ No newline at end of file diff --git a/gokart/in_memory/data.py b/gokart/in_memory/data.py new file mode 100644 index 00000000..5c26998e --- /dev/null +++ b/gokart/in_memory/data.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from typing import Any +from abc import ABC +from datetime import datetime + +class BaseData(ABC): + ... + +@dataclass +class InMemoryData(BaseData): + value: Any + last_modified_time: datetime + + @classmethod + def create_data(self, value: Any) -> 'InMemoryData': + return InMemoryData(value=value, last_modified_time=datetime.now()) \ No newline at end of file diff --git a/gokart/in_memory/repository.py b/gokart/in_memory/repository.py index efb07929..fc871366 100644 --- a/gokart/in_memory/repository.py +++ b/gokart/in_memory/repository.py @@ -1,38 +1,45 @@ import abc -from typing import Any +from typing import Any, Iterator +from .data import InMemoryData class BaseRepository(abc.ABC): ... 
-class InMemeryCacheRepository(BaseRepository): - _cache: dict[str, Any] = {} +class InMemoryCacheRepository(BaseRepository): + _cache: dict[str, InMemoryData] = {} def __init__(self): pass - def get(self, id: str): - return self._cache[id] + def get_value(self, key: str) -> Any: + return self._get_data(key).value - def set(self, id: str, obj: Any): - assert not self.has(id) - self._cache[id] = obj + def get_last_modification_time(self, key: str): + return self._get_data(key).last_modified_time + + def _get_data(self, id: str) -> InMemoryData: + return self._cache[id] + + def set_value(self, id: str, obj: Any) -> None: + data = InMemoryData.create_data(obj) + self._cache[id] = data - def has(self, id: str): + def has(self, id: str) -> bool: return id in self._cache - def remove_by_id(self, id: str): + def remove(self, id: str) -> None: assert self.has(id) del self._cache[id] - def empty(self): + def empty(self) -> bool: return not self._cache - def clear(self): + def clear(self) -> None: self._cache.clear() - def get_gen(self): - for key, value in self._cache.items(): - yield key, value + def get_gen(self) -> Iterator[tuple[str, Any]]: + for key, data in self._cache.items(): + yield key, data.value @property - def size(self): + def size(self) -> int: return len(self._cache) \ No newline at end of file diff --git a/gokart/in_memory/target.py b/gokart/in_memory/target.py index 8d9dda3b..9b910c68 100644 --- a/gokart/in_memory/target.py +++ b/gokart/in_memory/target.py @@ -1,50 +1,47 @@ from gokart.target import TargetOnKart, TaskLockParams -from gokart.in_memory.repository import InMemeryCacheRepository +from gokart.in_memory.repository import InMemoryCacheRepository from datetime import datetime +from typing import Any +from logging import warning -_repository = InMemeryCacheRepository() +_repository = InMemoryCacheRepository() -# TODO: unnecessary params in task_lock_param expecially regarding redies class InMemoryTarget(TargetOnKart): def __init__( self, - id: str, 
+ data_key: str, task_lock_param: TaskLockParams ): - self._id = id + if task_lock_param.should_task_lock: + warning(f'Redis in {self.__class__.__name__} is not supported now.') + + self._data_key = data_key self._task_lock_params = task_lock_param - self._last_modification_time_value: None | datetime = None - def _exists(self): - # import pdb;pdb.set_trace() - return _repository.has(self._id) + def _exists(self) -> bool: + return _repository.has(self._data_key) - def _get_task_lock_params(self): + def _get_task_lock_params(self) -> TaskLockParams: return self._task_lock_params - def _load(self): - # import pdb - # pdb.set_trace() - return _repository.get(self._id) + def _load(self) -> Any: + return _repository.get_value(self._data_key) - def _dump(self, obj): - return _repository.set(self._id, obj) + def _dump(self, obj: Any) -> None: + return _repository.set_value(self._data_key, obj) def _remove(self) -> None: - _repository.remove_by_id(self._id) + _repository.remove(self._data_key) def _last_modification_time(self) -> datetime: - if self._last_modification_time_value is None: - raise ValueError(f"No object(s) which id is {self._id} are stored before.") - self._last_modification_time_value - - def _path(self): - # TODO: this module name `_path` migit not be appropriate - return self._id + if not _repository.has(self._data_key): + raise ValueError(f"No object(s) which id is {self._data_key} are stored before.") + time = _repository.get_last_modification_time(self._data_key) + return time - @property - def id(self): - return self._id + def _path(self) -> str: + # TODO: this module name `_path` migit not be appropriate + return self._data_key def make_inmemory_target(target_key: str, task_lock_params: TaskLockParams | None = None): return InMemoryTarget(target_key, task_lock_params) diff --git a/test/in_memory/test_in_memory_target.py b/test/in_memory/test_in_memory_target.py index 4b173610..125b634d 100644 --- a/test/in_memory/test_in_memory_target.py +++ 
b/test/in_memory/test_in_memory_target.py @@ -1,7 +1,8 @@ from gokart.conflict_prevention_lock.task_lock import TaskLockParams -from gokart.in_memory import make_inmemory_target, InMemoryTarget, InMemeryCacheRepository +from gokart.in_memory import make_inmemory_target, InMemoryTarget, InMemoryCacheRepository import pytest - +from datetime import datetime +from time import sleep class TestInMemoryTarget: @pytest.fixture def task_lock_params(self): @@ -14,13 +15,14 @@ def task_lock_params(self): raise_task_lock_exception_on_collision=False, lock_extend_seconds=0 ) + @pytest.fixture def target(self, task_lock_params: TaskLockParams): - return make_inmemory_target(target_key='dummy_task_id', task_lock_params=task_lock_params) - + return make_inmemory_target(target_key='dummy_key', task_lock_params=task_lock_params) + @pytest.fixture(autouse=True) def clear_repo(self): - InMemeryCacheRepository().clear() + InMemoryCacheRepository().clear() def test_dump_and_load_data(self, target: InMemoryTarget): dumped = 'dummy_data' @@ -28,5 +30,23 @@ def test_dump_and_load_data(self, target: InMemoryTarget): loaded = target.load() assert loaded == dumped - with pytest.raises(AssertionError): - target.dump('another_data') \ No newline at end of file + def test_exist(self, target: InMemoryTarget): + assert not target.exists() + target.dump('dummy_data') + assert target.exists() + + def test_last_modified_time(self, target: InMemoryTarget): + input = 'dummy_data' + target.dump(input) + time = target.last_modification_time() + assert isinstance(time, datetime) + + sleep(0.1) + another_input = 'another_data' + target.dump(another_input) + another_time = target.last_modification_time() + assert time < another_time + + target.remove() + with pytest.raises(ValueError): + assert target.last_modification_time() diff --git a/test/in_memory/test_repository.py b/test/in_memory/test_repository.py index 2610b80a..480910d3 100644 --- a/test/in_memory/test_repository.py +++ 
b/test/in_memory/test_repository.py @@ -1,5 +1,6 @@ -from gokart.in_memory import InMemeryCacheRepository as Repo +from gokart.in_memory import InMemoryCacheRepository as Repo import pytest +import time dummy_num = 100 @@ -11,44 +12,49 @@ def repo(self): return repo def test_set(self, repo: Repo): - repo.set("dummy_id", dummy_num) + repo.set_value("dummy_key", dummy_num) assert repo.size == 1 for key, value in repo.get_gen(): - assert (key, value) == ("dummy_id", dummy_num) + assert (key, value) == ("dummy_key", dummy_num) - with pytest.raises(AssertionError): - repo.set('dummy_id', "dummy_value") - - repo.set('another_id', 'another_value') + repo.set_value('another_key', 'another_value') assert repo.size == 2 def test_get(self, repo: Repo): - repo.set('dummy_id', dummy_num) - repo.set('another_id', 'another_val') + repo.set_value('dummy_key', dummy_num) + repo.set_value('another_key', 'another_value') """Raise Error when key doesn't exist.""" with pytest.raises(KeyError): - repo.get('not_exist_id') + repo.get_value('not_exist_key') - assert repo.get('dummy_id') == dummy_num - assert repo.get('another_id') == 'another_val' + assert repo.get_value('dummy_key') == dummy_num + assert repo.get_value('another_key') == 'another_value' def test_empty(self, repo: Repo): assert repo.empty() - repo.set("dummmy_id", dummy_num) + repo.set_value("dummmy_key", dummy_num) assert not repo.empty() def test_has(self, repo: Repo): - assert not repo.has('dummy_id') - repo.set('dummy_id', dummy_num) - assert repo.has('dummy_id') + assert not repo.has('dummy_key') + repo.set_value('dummy_key', dummy_num) + assert repo.has('dummy_key') + assert not repo.has('not_exist_key') - def test_remove_by_id(self, repo: Repo): - repo.set('dummy_id', dummy_num) + def test_remove(self, repo: Repo): + repo.set_value('dummy_key', dummy_num) with pytest.raises(AssertionError): - repo.remove_by_id('not_exist_id') - - assert repo.has('dummy_id') - repo.remove_by_id('dummy_id') - assert not 
repo.has('dummy_id') + repo.remove('not_exist_key') + + repo.remove('dummy_key') + assert not repo.has('dummy_key') + + def test_last_modification_time(self, repo: Repo): + repo.set_value('dummy_key', dummy_num) + date1 = repo.get_last_modification_time('dummy_key') + time.sleep(0.1) + repo.set_value('dummy_key', dummy_num) + date2 = repo.get_last_modification_time('dummy_key') + assert date1 < date2 \ No newline at end of file From 4c311bb439b0dbd19f2f57d81632173e4fc275f7 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Mon, 10 Feb 2025 11:50:56 +0900 Subject: [PATCH 03/16] fix: fix linting errors --- gokart/file_processor.py | 10 ++++----- gokart/in_memory/__init__.py | 4 ++-- gokart/in_memory/data.py | 10 ++++----- gokart/in_memory/repository.py | 18 ++++++++------- gokart/in_memory/target.py | 29 ++++++++++++------------- test/in_memory/test_in_memory_target.py | 16 +++++++++----- test/in_memory/test_repository.py | 23 +++++++++++--------- 7 files changed, 58 insertions(+), 52 deletions(-) diff --git a/gokart/file_processor.py b/gokart/file_processor.py index 87958327..21b8b77f 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -166,9 +166,9 @@ def load(self, file): return pd.DataFrame() def dump(self, obj, file): - assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), ( - f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.' - ) + assert ( + isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict) + ), f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.' if isinstance(obj, dict): obj = pd.DataFrame.from_dict(obj) obj.to_json(file) @@ -263,10 +263,8 @@ def dump(self, obj, file): if self._store_index_in_feather: index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}' - assert index_column_name not in dump_obj.columns, ( - f'column name {index_column_name} already exists in dump_obj. 
\ + assert index_column_name not in dump_obj.columns, f'column name {index_column_name} already exists in dump_obj. \ Consider not saving index by setting store_index_in_feather=False.' - ) assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.' dump_obj[index_column_name] = dump_obj.index diff --git a/gokart/in_memory/__init__.py b/gokart/in_memory/__init__.py index f3685297..69e7e4c3 100644 --- a/gokart/in_memory/__init__.py +++ b/gokart/in_memory/__init__.py @@ -1,2 +1,2 @@ -from .repository import InMemoryCacheRepository -from .target import InMemoryTarget, make_inmemory_target \ No newline at end of file +from .repository import InMemoryCacheRepository # noqa:F401 +from .target import InMemoryTarget, make_inmemory_target # noqa:F401 diff --git a/gokart/in_memory/data.py b/gokart/in_memory/data.py index 5c26998e..9362cfe6 100644 --- a/gokart/in_memory/data.py +++ b/gokart/in_memory/data.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import Any -from abc import ABC from datetime import datetime +from typing import Any + + +class BaseData: ... -class BaseData(ABC): - ... @dataclass class InMemoryData(BaseData): @@ -13,4 +13,4 @@ class InMemoryData(BaseData): @classmethod def create_data(self, value: Any) -> 'InMemoryData': - return InMemoryData(value=value, last_modified_time=datetime.now()) \ No newline at end of file + return InMemoryData(value=value, last_modified_time=datetime.now()) diff --git a/gokart/in_memory/repository.py b/gokart/in_memory/repository.py index fc871366..6babaed7 100644 --- a/gokart/in_memory/repository.py +++ b/gokart/in_memory/repository.py @@ -1,18 +1,20 @@ -import abc from typing import Any, Iterator + from .data import InMemoryData -class BaseRepository(abc.ABC): - ... + +class BaseRepository: ... 
+ class InMemoryCacheRepository(BaseRepository): _cache: dict[str, InMemoryData] = {} + def __init__(self): pass def get_value(self, key: str) -> Any: return self._get_data(key).value - + def get_last_modification_time(self, key: str): return self._get_data(key).last_modified_time @@ -22,17 +24,17 @@ def _get_data(self, id: str) -> InMemoryData: def set_value(self, id: str, obj: Any) -> None: data = InMemoryData.create_data(obj) self._cache[id] = data - + def has(self, id: str) -> bool: return id in self._cache - + def remove(self, id: str) -> None: assert self.has(id) del self._cache[id] def empty(self) -> bool: return not self._cache - + def clear(self) -> None: self._cache.clear() @@ -42,4 +44,4 @@ def get_gen(self) -> Iterator[tuple[str, Any]]: @property def size(self) -> int: - return len(self._cache) \ No newline at end of file + return len(self._cache) diff --git a/gokart/in_memory/target.py b/gokart/in_memory/target.py index 9b910c68..067e5299 100644 --- a/gokart/in_memory/target.py +++ b/gokart/in_memory/target.py @@ -1,41 +1,39 @@ -from gokart.target import TargetOnKart, TaskLockParams -from gokart.in_memory.repository import InMemoryCacheRepository from datetime import datetime -from typing import Any from logging import warning +from typing import Any + +from gokart.in_memory.repository import InMemoryCacheRepository +from gokart.target import TargetOnKart, TaskLockParams _repository = InMemoryCacheRepository() + class InMemoryTarget(TargetOnKart): - def __init__( - self, - data_key: str, - task_lock_param: TaskLockParams - ): + def __init__(self, data_key: str, task_lock_param: TaskLockParams): if task_lock_param.should_task_lock: warning(f'Redis in {self.__class__.__name__} is not supported now.') self._data_key = data_key self._task_lock_params = task_lock_param - + def _exists(self) -> bool: return _repository.has(self._data_key) - + def _get_task_lock_params(self) -> TaskLockParams: return self._task_lock_params - + def _load(self) -> Any: return 
_repository.get_value(self._data_key) - + def _dump(self, obj: Any) -> None: return _repository.set_value(self._data_key, obj) - + def _remove(self) -> None: _repository.remove(self._data_key) - + def _last_modification_time(self) -> datetime: if not _repository.has(self._data_key): - raise ValueError(f"No object(s) which id is {self._data_key} are stored before.") + raise ValueError(f'No object(s) which id is {self._data_key} are stored before.') time = _repository.get_last_modification_time(self._data_key) return time @@ -43,5 +41,6 @@ def _path(self) -> str: # TODO: this module name `_path` migit not be appropriate return self._data_key + def make_inmemory_target(target_key: str, task_lock_params: TaskLockParams | None = None): return InMemoryTarget(target_key, task_lock_params) diff --git a/test/in_memory/test_in_memory_target.py b/test/in_memory/test_in_memory_target.py index 125b634d..c20bfa3b 100644 --- a/test/in_memory/test_in_memory_target.py +++ b/test/in_memory/test_in_memory_target.py @@ -1,8 +1,12 @@ -from gokart.conflict_prevention_lock.task_lock import TaskLockParams -from gokart.in_memory import make_inmemory_target, InMemoryTarget, InMemoryCacheRepository -import pytest from datetime import datetime from time import sleep + +import pytest + +from gokart.conflict_prevention_lock.task_lock import TaskLockParams +from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget, make_inmemory_target + + class TestInMemoryTarget: @pytest.fixture def task_lock_params(self): @@ -13,7 +17,7 @@ def task_lock_params(self): redis_key='dummy', should_task_lock=False, raise_task_lock_exception_on_collision=False, - lock_extend_seconds=0 + lock_extend_seconds=0, ) @pytest.fixture @@ -34,13 +38,13 @@ def test_exist(self, target: InMemoryTarget): assert not target.exists() target.dump('dummy_data') assert target.exists() - + def test_last_modified_time(self, target: InMemoryTarget): input = 'dummy_data' target.dump(input) time = target.last_modification_time() 
assert isinstance(time, datetime) - + sleep(0.1) another_input = 'another_data' target.dump(another_input) diff --git a/test/in_memory/test_repository.py b/test/in_memory/test_repository.py index 480910d3..30c5b033 100644 --- a/test/in_memory/test_repository.py +++ b/test/in_memory/test_repository.py @@ -1,9 +1,12 @@ -from gokart.in_memory import InMemoryCacheRepository as Repo -import pytest import time +import pytest + +from gokart.in_memory import InMemoryCacheRepository as Repo + dummy_num = 100 + class TestInMemoryCacheRepository: @pytest.fixture def repo(self): @@ -12,11 +15,11 @@ def repo(self): return repo def test_set(self, repo: Repo): - repo.set_value("dummy_key", dummy_num) + repo.set_value('dummy_key', dummy_num) assert repo.size == 1 for key, value in repo.get_gen(): - assert (key, value) == ("dummy_key", dummy_num) - + assert (key, value) == ('dummy_key', dummy_num) + repo.set_value('another_key', 'another_value') assert repo.size == 2 @@ -27,21 +30,21 @@ def test_get(self, repo: Repo): """Raise Error when key doesn't exist.""" with pytest.raises(KeyError): repo.get_value('not_exist_key') - + assert repo.get_value('dummy_key') == dummy_num assert repo.get_value('another_key') == 'another_value' def test_empty(self, repo: Repo): assert repo.empty() - repo.set_value("dummmy_key", dummy_num) + repo.set_value('dummmy_key', dummy_num) assert not repo.empty() - + def test_has(self, repo: Repo): assert not repo.has('dummy_key') repo.set_value('dummy_key', dummy_num) assert repo.has('dummy_key') assert not repo.has('not_exist_key') - + def test_remove(self, repo: Repo): repo.set_value('dummy_key', dummy_num) @@ -57,4 +60,4 @@ def test_last_modification_time(self, repo: Repo): time.sleep(0.1) repo.set_value('dummy_key', dummy_num) date2 = repo.get_last_modification_time('dummy_key') - assert date1 < date2 \ No newline at end of file + assert date1 < date2 From da3ae69d99818520ea814da21808155271062b77 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Mon, 10 Feb 
2025 12:10:10 +0900 Subject: [PATCH 04/16] fix: update type union shorthand to make compatible with py39 --- gokart/in_memory/target.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gokart/in_memory/target.py b/gokart/in_memory/target.py index 067e5299..82979205 100644 --- a/gokart/in_memory/target.py +++ b/gokart/in_memory/target.py @@ -1,6 +1,6 @@ from datetime import datetime from logging import warning -from typing import Any +from typing import Any, Optional from gokart.in_memory.repository import InMemoryCacheRepository from gokart.target import TargetOnKart, TaskLockParams @@ -42,5 +42,5 @@ def _path(self) -> str: return self._data_key -def make_inmemory_target(target_key: str, task_lock_params: TaskLockParams | None = None): +def make_inmemory_target(target_key: str, task_lock_params: Optional[TaskLockParams] = None): return InMemoryTarget(target_key, task_lock_params) From c09040d95f34989830f833d4bb3e784794a861b7 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Mon, 10 Feb 2025 12:20:11 +0900 Subject: [PATCH 05/16] style: refactor some base classes to inherit from --- gokart/in_memory/data.py | 4 ++-- gokart/in_memory/repository.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gokart/in_memory/data.py b/gokart/in_memory/data.py index 9362cfe6..0729ff80 100644 --- a/gokart/in_memory/data.py +++ b/gokart/in_memory/data.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from datetime import datetime -from typing import Any +from typing import Any, Protocol -class BaseData: ... +class BaseData(Protocol): ... @dataclass diff --git a/gokart/in_memory/repository.py b/gokart/in_memory/repository.py index 6babaed7..c4ea8485 100644 --- a/gokart/in_memory/repository.py +++ b/gokart/in_memory/repository.py @@ -1,9 +1,9 @@ -from typing import Any, Iterator +from typing import Any, Iterator, Protocol from .data import InMemoryData -class BaseRepository: ... +class BaseRepository(Protocol): ...
class InMemoryCacheRepository(BaseRepository): From 333986df6eba69342e121a233b3c3ed3f327c587 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Mon, 10 Feb 2025 13:03:24 +0900 Subject: [PATCH 06/16] fix: remove unnessesary optional type --- gokart/in_memory/target.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gokart/in_memory/target.py b/gokart/in_memory/target.py index 82979205..b47e96a7 100644 --- a/gokart/in_memory/target.py +++ b/gokart/in_memory/target.py @@ -1,6 +1,6 @@ from datetime import datetime from logging import warning -from typing import Any, Optional +from typing import Any from gokart.in_memory.repository import InMemoryCacheRepository from gokart.target import TargetOnKart, TaskLockParams @@ -42,5 +42,5 @@ def _path(self) -> str: return self._data_key -def make_inmemory_target(target_key: str, task_lock_params: Optional[TaskLockParams] = None): +def make_inmemory_target(target_key: str, task_lock_params: TaskLockParams): return InMemoryTarget(target_key, task_lock_params) From b0c6d5a2973fa0f87591d43aec2c59fe392aafdd Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Mon, 10 Feb 2025 20:27:16 +0900 Subject: [PATCH 07/16] fix: fix format error --- gokart/file_processor.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gokart/file_processor.py b/gokart/file_processor.py index 21b8b77f..87958327 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -166,9 +166,9 @@ def load(self, file): return pd.DataFrame() def dump(self, obj, file): - assert ( - isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict) - ), f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.' + assert isinstance(obj, pd.DataFrame) or isinstance(obj, pd.Series) or isinstance(obj, dict), ( + f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.' 
+ ) if isinstance(obj, dict): obj = pd.DataFrame.from_dict(obj) obj.to_json(file) @@ -263,8 +263,10 @@ def dump(self, obj, file): if self._store_index_in_feather: index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}' - assert index_column_name not in dump_obj.columns, f'column name {index_column_name} already exists in dump_obj. \ + assert index_column_name not in dump_obj.columns, ( + f'column name {index_column_name} already exists in dump_obj. \ Consider not saving index by setting store_index_in_feather=False.' + ) assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.' dump_obj[index_column_name] = dump_obj.index From a53378a12b90259c034344f39ddcfd547031d694 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Tue, 11 Feb 2025 00:05:47 +0900 Subject: [PATCH 08/16] chore: add an assertion error message style: update variable name from `id` to `key` --- gokart/in_memory/repository.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/gokart/in_memory/repository.py b/gokart/in_memory/repository.py index c4ea8485..727935ca 100644 --- a/gokart/in_memory/repository.py +++ b/gokart/in_memory/repository.py @@ -18,19 +18,19 @@ def get_value(self, key: str) -> Any: def get_last_modification_time(self, key: str): return self._get_data(key).last_modified_time - def _get_data(self, id: str) -> InMemoryData: - return self._cache[id] + def _get_data(self, key: str) -> InMemoryData: + return self._cache[key] - def set_value(self, id: str, obj: Any) -> None: + def set_value(self, key: str, obj: Any) -> None: data = InMemoryData.create_data(obj) - self._cache[id] = data + self._cache[key] = data - def has(self, id: str) -> bool: - return id in self._cache + def has(self, key: str) -> bool: + return key in self._cache - def remove(self, id: str) -> None: - assert self.has(id) - del self._cache[id] + def remove(self, key: str) -> None: + assert self.has(key), 
f'{key} does not exist.' + del self._cache[key] def empty(self) -> bool: return not self._cache From 4d9b033b9c0329faab0b26d134b42a1ee3a771f6 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Tue, 11 Feb 2025 09:09:11 +0900 Subject: [PATCH 09/16] style: update the variable name to for code consistency --- gokart/in_memory/data.py | 4 ++-- gokart/in_memory/repository.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gokart/in_memory/data.py b/gokart/in_memory/data.py index 0729ff80..a01c3ad2 100644 --- a/gokart/in_memory/data.py +++ b/gokart/in_memory/data.py @@ -9,8 +9,8 @@ class BaseData(Protocol): ... @dataclass class InMemoryData(BaseData): value: Any - last_modified_time: datetime + last_modification_time: datetime @classmethod def create_data(self, value: Any) -> 'InMemoryData': - return InMemoryData(value=value, last_modified_time=datetime.now()) + return InMemoryData(value=value, last_modification_time=datetime.now()) diff --git a/gokart/in_memory/repository.py b/gokart/in_memory/repository.py index 727935ca..4d1920b6 100644 --- a/gokart/in_memory/repository.py +++ b/gokart/in_memory/repository.py @@ -16,7 +16,7 @@ def get_value(self, key: str) -> Any: return self._get_data(key).value def get_last_modification_time(self, key: str): - return self._get_data(key).last_modified_time + return self._get_data(key).last_modification_time def _get_data(self, key: str) -> InMemoryData: return self._cache[key] From b9dbb4d9756aa8c381fbd4faea422c0eef848378 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Wed, 12 Feb 2025 09:40:40 +0900 Subject: [PATCH 10/16] docs: add a document of how to create InMemoryTarget --- docs/task_on_kart.rst | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/docs/task_on_kart.rst b/docs/task_on_kart.rst index ce52c6d5..09e7a59f 100644 --- a/docs/task_on_kart.rst +++ b/docs/task_on_kart.rst @@ -286,3 +286,24 @@ If you want to dump csv file with other encodings, you can use `encoding` parame def 
output(self): return self.make_target('file_name.csv', processor=CsvFileProcessor(encoding='cp932')) # This will dump csv as 'cp932' which is used in Windows. + +Cache output in memory instead of dumping to files +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +You can use :class:`~InMemoryTarget` to cache output in memory instead of dumping to files by calling :func:`~gokart.target.make_inmemory_target`. + +.. code:: python + + from gokart.in_memory.target import make_inmemory_target + + def output(self): + unique_id = self.make_unique_id() if use_unique_id else None + # TaskLock is not supported in InMemoryTarget, so it's dummy + task_lock_params = make_task_lock_params( + file_path='dummy_path', + unique_id=unique_id, + redis_host=None, + redis_port=None, + redis_timeout=self.redis_timeout, + raise_task_lock_exception_on_collision=False, + ) + return make_inmemory_target('dummy_path', task_lock_params, unique_id) \ No newline at end of file From 177e72c041098bb657e7cb0849bb05da2f335646 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Tue, 11 Feb 2025 12:47:54 +0900 Subject: [PATCH 11/16] feat: add the new `TargetOnKart` entrypoint `make_cache_target` feat: add the new parameter `cache_in_memory_by_default` to switch default Target style: update the variable name from `target_key` to `data_key` for code consistency test: add tests for `TaskOnKart`s with the `cache_in_memory` parameter --- gokart/in_memory/target.py | 13 +- gokart/task.py | 33 +++++- test/in_memory/test_in_memory_target.py | 2 +- test/in_memory/test_task_cached_in_memory.py | 118 +++++++++++++++++++ 4 files changed, 160 insertions(+), 6 deletions(-) create mode 100644 test/in_memory/test_task_cached_in_memory.py diff --git a/gokart/in_memory/target.py b/gokart/in_memory/target.py index b47e96a7..620fd556 100644 --- a/gokart/in_memory/target.py +++ b/gokart/in_memory/target.py @@ -1,6 +1,6 @@ from datetime import datetime from logging import warning -from typing import Any +from typing import 
Any, Optional from gokart.in_memory.repository import InMemoryCacheRepository from gokart.target import TargetOnKart, TaskLockParams @@ -42,5 +42,12 @@ def _path(self) -> str: return self._data_key -def make_inmemory_target(target_key: str, task_lock_params: TaskLockParams): - return InMemoryTarget(target_key, task_lock_params) +def _make_data_key(data_key: str, unique_id: Optional[str] = None): + if not unique_id: + return data_key + return data_key + '_' + unique_id + + +def make_inmemory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None): + _data_key = _make_data_key(data_key, unique_id) + return InMemoryTarget(_data_key, task_lock_params) diff --git a/gokart/task.py b/gokart/task.py index f577f64b..6e5eb024 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -20,9 +20,10 @@ import gokart import gokart.target -from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run +from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params, make_task_lock_params_for_run from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock from gokart.file_processor import FileProcessor +from gokart.in_memory.target import make_inmemory_target from gokart.pandas_type_config import PandasTypeConfigMap from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter from gokart.target import TargetOnKart @@ -105,6 +106,9 @@ class TaskOnKart(luigi.Task, Generic[T]): default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False ) should_lock_run: bool = ExplicitBoolParameter(default=False, significant=False, description='Whether to use redis lock or not at task run.') + cache_in_memory_by_default: bool = ExplicitBoolParameter( + default=False, significant=False, description='If `True`, output is stored on a memory instead of files unless specified.' 
+ ) @property def priority(self): @@ -134,11 +138,13 @@ def __init__(self, *args, **kwargs): task_lock_params = make_task_lock_params_for_run(task_self=self) self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params) # type: ignore + self.make_default_target = self.make_target if not self.cache_in_memory_by_default else self.make_cache_target + def input(self) -> FlattenableItems[TargetOnKart]: return super().input() def output(self) -> FlattenableItems[TargetOnKart]: - return self.make_target() + return self.make_default_target() def requires(self) -> FlattenableItems['TaskOnKart']: tasks = self.make_task_instance_dictionary() @@ -210,11 +216,19 @@ def clone(self, cls=None, **kwargs): return cls(**new_k) def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart: + # if self.cache_in_memory and processor: + # logger.warning(f"processor {type(processor)} never used.") formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl') ) file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None + # if self.cache_in_memory: + # from gokart.target import _make_file_path + # return make_inmemory_target( + # target_key=_make_file_path(file_path, unique_id), + # task_lock_params=TaskLockParams(None, None, None, "hoge", False, False, 100) + # ) task_lock_params = make_task_lock_params( file_path=file_path, @@ -229,6 +243,21 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather ) + def make_cache_target(self, data_key: Optional[str] = None, use_unique_id: bool = True): + _data_key = data_key if 
data_key else os.path.join(self.__module__.replace('.', '/'), type(self).__name__) + unique_id = self.make_unique_id() if use_unique_id else None + # TODO: combine with redis + task_lock_params = TaskLockParams( + redis_host=None, + redis_port=None, + redis_timeout=None, + redis_key='redis_key', + should_task_lock=False, + raise_task_lock_exception_on_collision=False, + lock_extend_seconds=-1, + ) + return make_inmemory_target(_data_key, task_lock_params, unique_id) + def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip') diff --git a/test/in_memory/test_in_memory_target.py b/test/in_memory/test_in_memory_target.py index c20bfa3b..5c08634b 100644 --- a/test/in_memory/test_in_memory_target.py +++ b/test/in_memory/test_in_memory_target.py @@ -22,7 +22,7 @@ def task_lock_params(self): @pytest.fixture def target(self, task_lock_params: TaskLockParams): - return make_inmemory_target(target_key='dummy_key', task_lock_params=task_lock_params) + return make_inmemory_target(data_key='dummy_key', task_lock_params=task_lock_params) @pytest.fixture(autouse=True) def clear_repo(self): diff --git a/test/in_memory/test_task_cached_in_memory.py b/test/in_memory/test_task_cached_in_memory.py new file mode 100644 index 00000000..a874ee6b --- /dev/null +++ b/test/in_memory/test_task_cached_in_memory.py @@ -0,0 +1,118 @@ +from typing import Optional, Type, Union + +import luigi +import pytest + +import gokart +from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget +from gokart.target import SingleFileTarget + + +class DummyTask(gokart.TaskOnKart): + task_namespace = __name__ + param: str = luigi.Parameter() + + def run(self): + self.dump(self.param) + + +class DummyTaskWithDependencies(gokart.TaskOnKart): + 
task_namespace = __name__ + task: list[gokart.TaskOnKart[str]] = gokart.ListTaskInstanceParameter() + + def run(self): + result = ','.join(self.load()) + self.dump(result) + + +class DumpIntTask(gokart.TaskOnKart[int]): + task_namespace = __name__ + value: int = luigi.IntParameter() + + def run(self): + self.dump(self.value) + + +class AddTask(gokart.TaskOnKart[Union[int, float]]): + a: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter() + b: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter() + + def requires(self): + return dict(a=self.a, b=self.b) + + def run(self): + a = self.load(self.a) + b = self.load(self.b) + self.dump(a + b) + + +class TestTaskOnKartWithCache: + @pytest.fixture(autouse=True) + def clear_repository(slef): + InMemoryCacheRepository().clear() + + @pytest.mark.parametrize('data_key', ['sample_key', None]) + @pytest.mark.parametrize('use_unique_id', [True, False]) + def test_key_identity(self, data_key: Optional[str], use_unique_id: bool): + task = DummyTask(param='param') + ext = '.pkl' + relative_file_path = data_key + ext if data_key else None + target = task.make_target(relative_file_path=relative_file_path, use_unique_id=use_unique_id) + cached_target = task.make_cache_target(data_key=data_key, use_unique_id=use_unique_id) + + target_path = target.path().removeprefix(task.workspace_directory).removesuffix(ext).strip('/') + assert cached_target.path() == target_path + + def test_make_cached_target(self): + task = DummyTask(param='param') + target = task.make_cache_target() + assert isinstance(target, InMemoryTarget) + + @pytest.mark.parametrize(['cache_in_memory_by_default', 'target_type'], [[True, InMemoryTarget], [False, SingleFileTarget]]) + def test_make_default_target(self, cache_in_memory_by_default: bool, target_type: Type[gokart.TaskOnKart]): + task = DummyTask(param='param', cache_in_memory_by_default=cache_in_memory_by_default) + target = task.output() + assert isinstance(target, target_type) + + def 
test_complete_with_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=tmpdir) + assert not task.complete() + file_target = task.make_target() + file_target.dump('data') + assert not task.complete() + cache_target = task.make_cache_target() + cache_target.dump('data') + assert task.complete() + + def test_complete_without_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', workspace_directory=tmpdir) + assert not task.complete() + cache_target = task.make_cache_target() + cache_target.dump('data') + assert not task.complete() + file_target = task.make_target() + file_target.dump('data') + assert task.complete() + + def test_dump_with_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=tmpdir) + file_target = task.make_target() + cache_target = task.make_cache_target() + task.dump('data') + assert not file_target.exists() + assert cache_target.exists() + + def test_dump_without_cache_in_memory_flag(self, tmpdir): + task = DummyTask(param='param', workspace_directory=tmpdir) + file_target = task.make_target() + cache_target = task.make_cache_target() + task.dump('data') + assert file_target.exists() + assert not cache_target.exists() + + def test_gokart_build(self): + task = AddTask( + a=DumpIntTask(value=2, cache_in_memory_by_default=True), b=DumpIntTask(value=3, cache_in_memory_by_default=True), cache_in_memory_by_default=True + ) + output = gokart.build(task, reset_register=False) + assert output == 5 From fa0cbd381e17ca733d7c87f120798ea13349f249 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Tue, 11 Feb 2025 12:54:22 +0900 Subject: [PATCH 12/16] chore: delete unecessary comments --- gokart/task.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/gokart/task.py b/gokart/task.py index 6e5eb024..8eac941d 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -216,19 +216,11 @@ def clone(self, cls=None, 
**kwargs): return cls(**new_k) def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart: - # if self.cache_in_memory and processor: - # logger.warning(f"processor {type(processor)} never used.") formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl') ) file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None - # if self.cache_in_memory: - # from gokart.target import _make_file_path - # return make_inmemory_target( - # target_key=_make_file_path(file_path, unique_id), - # task_lock_params=TaskLockParams(None, None, None, "hoge", False, False, 100) - # ) task_lock_params = make_task_lock_params( file_path=file_path, From 1f37d0971def938d2d6b3e0f9dd6f921ad9b99c6 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Tue, 11 Feb 2025 22:18:41 +0900 Subject: [PATCH 13/16] feat: add cache scheduler --- gokart/in_memory/repository.py | 72 +++++++++++++++++++++++++++---- gokart/in_memory/target.py | 13 +++--- test/in_memory/test_repository.py | 48 +++++++++++++++++---- 3 files changed, 111 insertions(+), 22 deletions(-) diff --git a/gokart/in_memory/repository.py b/gokart/in_memory/repository.py index 4d1920b6..a0ef0e76 100644 --- a/gokart/in_memory/repository.py +++ b/gokart/in_memory/repository.py @@ -1,19 +1,65 @@ -from typing import Any, Iterator, Protocol +from abc import ABC, abstractmethod +from typing import Any, Iterator from .data import InMemoryData -class BaseRepository(Protocol): ... +class CacheScheduler(ABC): + def __new__(cls): + if not hasattr(cls, '__instance'): + setattr(cls, '__instance', super().__new__(cls)) + return getattr(cls, '__instance') + @abstractmethod + def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): ... 
-class InMemoryCacheRepository(BaseRepository): - _cache: dict[str, InMemoryData] = {} + @abstractmethod + def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): ... + + @abstractmethod + def clear(self): ... + + +class DoNothingScheduler(CacheScheduler): + def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): + pass + + def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): + pass + + def clear(self): + pass + + +# TODO: ambiguous class name +class InstantScheduler(CacheScheduler): + def get_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): + repository.remove(key) - def __init__(self): + def set_hook(self, repository: 'InMemoryCacheRepository', key: str, data: InMemoryData): pass + def clear(self): + pass + + +class InMemoryCacheRepository: + _cache: dict[str, InMemoryData] = {} + _scheduler: CacheScheduler = DoNothingScheduler() + + def __new__(cls): + if not hasattr(cls, '__instance'): + cls.__instance = super().__new__(cls) + return cls.__instance + + @classmethod + def set_scheduler(cls, cache_scheduler: CacheScheduler): + cls._scheduler = cache_scheduler + def get_value(self, key: str) -> Any: - return self._get_data(key).value + data = self._get_data(key) + self._scheduler.get_hook(self, key, data) + return data.value def get_last_modification_time(self, key: str): return self._get_data(key).last_modification_time @@ -23,6 +69,10 @@ def _get_data(self, key: str) -> InMemoryData: def set_value(self, key: str, obj: Any) -> None: data = InMemoryData.create_data(obj) + self._scheduler.set_hook(self, key, data) + self._set_data(key, data) + + def _set_data(self, key: str, data: InMemoryData): self._cache[key] = data def has(self, key: str) -> bool: @@ -35,8 +85,10 @@ def remove(self, key: str) -> None: def empty(self) -> bool: return not self._cache - def clear(self) -> None: - self._cache.clear() + @classmethod + def clear(cls) 
-> None: + cls._cache.clear() + cls._scheduler.clear() def get_gen(self) -> Iterator[tuple[str, Any]]: for key, data in self._cache.items(): @@ -45,3 +97,7 @@ def get_gen(self) -> Iterator[tuple[str, Any]]: @property def size(self) -> int: return len(self._cache) + + @property + def scheduler(self) -> CacheScheduler: + return self._scheduler diff --git a/gokart/in_memory/target.py b/gokart/in_memory/target.py index 620fd556..84ce468d 100644 --- a/gokart/in_memory/target.py +++ b/gokart/in_memory/target.py @@ -15,26 +15,27 @@ def __init__(self, data_key: str, task_lock_param: TaskLockParams): self._data_key = data_key self._task_lock_params = task_lock_param + self._repository = InMemoryCacheRepository() def _exists(self) -> bool: - return _repository.has(self._data_key) + return self._repository.has(self._data_key) def _get_task_lock_params(self) -> TaskLockParams: return self._task_lock_params def _load(self) -> Any: - return _repository.get_value(self._data_key) + return self._repository.get_value(self._data_key) def _dump(self, obj: Any) -> None: - return _repository.set_value(self._data_key, obj) + return self._repository.set_value(self._data_key, obj) def _remove(self) -> None: - _repository.remove(self._data_key) + self._repository.remove(self._data_key) def _last_modification_time(self) -> datetime: - if not _repository.has(self._data_key): + if not self._repository.has(self._data_key): raise ValueError(f'No object(s) which id is {self._data_key} are stored before.') - time = _repository.get_last_modification_time(self._data_key) + time = self._repository.get_last_modification_time(self._data_key) return time def _path(self) -> str: diff --git a/test/in_memory/test_repository.py b/test/in_memory/test_repository.py index 30c5b033..14449937 100644 --- a/test/in_memory/test_repository.py +++ b/test/in_memory/test_repository.py @@ -2,7 +2,8 @@ import pytest -from gokart.in_memory import InMemoryCacheRepository as Repo +from gokart.in_memory import 
InMemoryCacheRepository +from gokart.in_memory.repository import InstantScheduler dummy_num = 100 @@ -10,11 +11,11 @@ class TestInMemoryCacheRepository: @pytest.fixture def repo(self): - repo = Repo() + repo = InMemoryCacheRepository() repo.clear() return repo - def test_set(self, repo: Repo): + def test_set(self, repo: InMemoryCacheRepository): repo.set_value('dummy_key', dummy_num) assert repo.size == 1 for key, value in repo.get_gen(): @@ -23,7 +24,7 @@ def test_set(self, repo: Repo): repo.set_value('another_key', 'another_value') assert repo.size == 2 - def test_get(self, repo: Repo): + def test_get(self, repo: InMemoryCacheRepository): repo.set_value('dummy_key', dummy_num) repo.set_value('another_key', 'another_value') @@ -34,18 +35,18 @@ def test_get(self, repo: Repo): assert repo.get_value('dummy_key') == dummy_num assert repo.get_value('another_key') == 'another_value' - def test_empty(self, repo: Repo): + def test_empty(self, repo: InMemoryCacheRepository): assert repo.empty() repo.set_value('dummmy_key', dummy_num) assert not repo.empty() - def test_has(self, repo: Repo): + def test_has(self, repo: InMemoryCacheRepository): assert not repo.has('dummy_key') repo.set_value('dummy_key', dummy_num) assert repo.has('dummy_key') assert not repo.has('not_exist_key') - def test_remove(self, repo: Repo): + def test_remove(self, repo: InMemoryCacheRepository): repo.set_value('dummy_key', dummy_num) with pytest.raises(AssertionError): @@ -54,10 +55,41 @@ def test_remove(self, repo: Repo): repo.remove('dummy_key') assert not repo.has('dummy_key') - def test_last_modification_time(self, repo: Repo): + def test_last_modification_time(self, repo: InMemoryCacheRepository): repo.set_value('dummy_key', dummy_num) date1 = repo.get_last_modification_time('dummy_key') time.sleep(0.1) repo.set_value('dummy_key', dummy_num) date2 = repo.get_last_modification_time('dummy_key') assert date1 < date2 + + +class TestInstantScheduler: + @pytest.fixture(autouse=True) + def 
set_scheduler(self): + scheduler = InstantScheduler() + InMemoryCacheRepository.set_scheduler(scheduler) + + @pytest.fixture(autouse=True) + def clear_cache(self): + InMemoryCacheRepository.clear() + + @pytest.fixture + def repo(self): + repo = InMemoryCacheRepository() + return repo + + def test_identity(self): + scheduler1 = InstantScheduler() + scheduler2 = InstantScheduler() + assert id(scheduler1) == id(scheduler2) + + def test_scheduler_type(self, repo: InMemoryCacheRepository): + assert isinstance(repo.scheduler, InstantScheduler) + + def test_volatility(self, repo: InMemoryCacheRepository): + assert repo.empty() + repo.set_value('dummy_key', 100) + assert repo.has('dummy_key') + repo.get_value('dummy_key') + assert not repo.has('dummy_key') From f255aaf5cf8fd7a2bc6d35ec0bffc8deb4672568 Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Wed, 12 Feb 2025 11:26:50 +0900 Subject: [PATCH 14/16] feat: add Cacheable Mixin and inhereted Targets feat: add cacheable flag to `make_target` test: add test fo Cacheable Targets chore: move `InMemoryTarget` to `gokart.target` to avoid circular dependencies --- gokart/in_memory/__init__.py | 1 - gokart/in_memory/cacheable.py | 88 ++++++++++ gokart/in_memory/target.py | 54 ------ gokart/target.py | 55 +++++- gokart/task.py | 7 +- test/in_memory/test_cacheable_target.py | 173 +++++++++++++++++++ test/in_memory/test_in_memory_target.py | 4 +- test/in_memory/test_task_cached_in_memory.py | 4 +- 8 files changed, 322 insertions(+), 64 deletions(-) create mode 100644 gokart/in_memory/cacheable.py delete mode 100644 gokart/in_memory/target.py create mode 100644 test/in_memory/test_cacheable_target.py diff --git a/gokart/in_memory/__init__.py b/gokart/in_memory/__init__.py index 69e7e4c3..a935fc20 100644 --- a/gokart/in_memory/__init__.py +++ b/gokart/in_memory/__init__.py @@ -1,2 +1 @@ from .repository import InMemoryCacheRepository # noqa:F401 -from .target import InMemoryTarget, make_inmemory_target # noqa:F401 diff --git 
a/gokart/in_memory/cacheable.py b/gokart/in_memory/cacheable.py new file mode 100644 index 00000000..986ae285 --- /dev/null +++ b/gokart/in_memory/cacheable.py @@ -0,0 +1,88 @@ +from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock +from typing import Any +from gokart.in_memory.repository import InMemoryCacheRepository + +class CacheNotFoundError(OSError): + ... + +class CacheableMixin: + def exists(self) -> bool: + return self._cache_exists() or super().exists() + + def load(self) -> bool: + def _load(): + cached = self._cache_exists() + if cached: + return self._cache_load() + else: + try: + loaded = super(CacheableMixin, self)._load() + except FileNotFoundError: + raise CacheNotFoundError + self._cache_dump(loaded) + return loaded + return wrap_load_with_lock(func=_load, task_lock_params=super()._get_task_lock_params())() + + def dump(self, obj: Any, lock_at_dump: bool = True, also_dump_to_file: bool = False): + # TODO: how to sync cache and file + def _dump(obj: Any): + self._cache_dump(obj) + if also_dump_to_file: + super(CacheableMixin, self)._dump(obj) + if lock_at_dump: + wrap_dump_with_lock(func=_dump, task_lock_params=super()._get_task_lock_params(), exist_check=self.exists)(obj) + else: + _dump(obj) + + def remove(self, also_remove_file: bool = False): + def _remove(): + if self._cache_exists(): + self._cache_remove() + if super(CacheableMixin, self).exists() and also_remove_file: + super(CacheableMixin, self)._remove() + wrap_remove_with_lock(func=_remove, task_lock_params=super()._get_task_lock_params())() + + def last_modification_time(self): + if self._cache_exists(): + return self._cache_last_modification_time() + try: + return super()._last_modification_time() + except: + raise CacheNotFoundError + + @property + def data_key(self): + return super().path() + + def _cache_exists(self): + raise NotImplementedError + + def _cache_load(self): + raise NotImplementedError + + def 
_cache_dump(self, obj: Any): + raise NotImplementedError + + def _cache_remove(self): + raise NotImplementedError + + def _cache_last_modification_time(self): + raise NotImplementedError + +class InMemoryCacheableMixin(CacheableMixin): + _repository = InMemoryCacheRepository() + + def _cache_exists(self): + return self._repository.has(self.data_key) + + def _cache_load(self): + return self._repository.get_value(self.data_key) + + def _cache_dump(self, obj): + return self._repository.set_value(self.data_key, obj) + + def _cache_remove(self): + self._repository.remove(self.data_key) + + def _cache_last_modification_time(self): + return self._repository.get_last_modification_time(self.data_key) diff --git a/gokart/in_memory/target.py b/gokart/in_memory/target.py deleted file mode 100644 index 84ce468d..00000000 --- a/gokart/in_memory/target.py +++ /dev/null @@ -1,54 +0,0 @@ -from datetime import datetime -from logging import warning -from typing import Any, Optional - -from gokart.in_memory.repository import InMemoryCacheRepository -from gokart.target import TargetOnKart, TaskLockParams - -_repository = InMemoryCacheRepository() - - -class InMemoryTarget(TargetOnKart): - def __init__(self, data_key: str, task_lock_param: TaskLockParams): - if task_lock_param.should_task_lock: - warning(f'Redis in {self.__class__.__name__} is not supported now.') - - self._data_key = data_key - self._task_lock_params = task_lock_param - self._repository = InMemoryCacheRepository() - - def _exists(self) -> bool: - return self._repository.has(self._data_key) - - def _get_task_lock_params(self) -> TaskLockParams: - return self._task_lock_params - - def _load(self) -> Any: - return self._repository.get_value(self._data_key) - - def _dump(self, obj: Any) -> None: - return self._repository.set_value(self._data_key, obj) - - def _remove(self) -> None: - self._repository.remove(self._data_key) - - def _last_modification_time(self) -> datetime: - if not self._repository.has(self._data_key): - 
raise ValueError(f'No object(s) which id is {self._data_key} are stored before.') - time = self._repository.get_last_modification_time(self._data_key) - return time - - def _path(self) -> str: - # TODO: this module name `_path` migit not be appropriate - return self._data_key - - -def _make_data_key(data_key: str, unique_id: Optional[str] = None): - if not unique_id: - return data_key - return data_key + '_' + unique_id - - -def make_inmemory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None): - _data_key = _make_data_key(data_key, unique_id) - return InMemoryTarget(_data_key, task_lock_params) diff --git a/gokart/target.py b/gokart/target.py index 88b3c942..56041480 100644 --- a/gokart/target.py +++ b/gokart/target.py @@ -16,6 +16,8 @@ from gokart.file_processor import FileProcessor, make_file_processor from gokart.object_storage import ObjectStorage from gokart.zip_client_util import make_zip_client +from gokart.in_memory.cacheable import InMemoryCacheableMixin +from gokart.in_memory.repository import InMemoryCacheRepository logger = getLogger(__name__) @@ -163,7 +165,11 @@ def _remove_temporary_directory(self): def _make_temporary_directory(self): os.makedirs(self._temporary_directory, exist_ok=True) +class CacheableSingleFileTarget(InMemoryCacheableMixin, SingleFileTarget): + ... +class CacheableModelTarget(InMemoryCacheableMixin, ModelTarget): + ... 
class LargeDataFrameProcessor(object): def __init__(self, max_byte: int): self.max_byte = int(max_byte) @@ -216,12 +222,14 @@ def make_target( processor: Optional[FileProcessor] = None, task_lock_params: Optional[TaskLockParams] = None, store_index_in_feather: bool = True, + cacheable: bool = False ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather) file_system_target = _make_file_system_target(file_path, processor=processor, store_index_in_feather=store_index_in_feather) - return SingleFileTarget(target=file_system_target, processor=processor, task_lock_params=_task_lock_params) + cls = CacheableSingleFileTarget if cacheable else SingleFileTarget + return cls(target=file_system_target, processor=processor, task_lock_params=_task_lock_params) def make_model_target( @@ -242,3 +250,48 @@ def make_model_target( load_function=load_function, task_lock_params=_task_lock_params, ) + +class InMemoryTarget(TargetOnKart): + def __init__(self, data_key: str, task_lock_param: TaskLockParams): + if task_lock_param.should_task_lock: + logger.warning(f'Redis in {self.__class__.__name__} is not supported now.') + + self._data_key = data_key + self._task_lock_params = task_lock_param + self._repository = InMemoryCacheRepository() + + def _exists(self) -> bool: + return self._repository.has(self._data_key) + + def _get_task_lock_params(self) -> TaskLockParams: + return self._task_lock_params + + def _load(self) -> Any: + return self._repository.get_value(self._data_key) + + def _dump(self, obj: Any) -> None: + return self._repository.set_value(self._data_key, obj) + + def _remove(self) -> None: + self._repository.remove(self._data_key) + + def _last_modification_time(self) -> datetime: + if not 
self._repository.has(self._data_key): + raise ValueError(f'No object(s) which id is {self._data_key} are stored before.') + time = self._repository.get_last_modification_time(self._data_key) + return time + + def _path(self) -> str: + # TODO: this module name `_path` migit not be appropriate + return self._data_key + + +def _make_data_key(data_key: str, unique_id: Optional[str] = None): + if not unique_id: + return data_key + return data_key + '_' + unique_id + + +def make_inmemory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None): + _data_key = _make_data_key(data_key, unique_id) + return InMemoryTarget(_data_key, task_lock_params) \ No newline at end of file diff --git a/gokart/task.py b/gokart/task.py index 8eac941d..b55aeb90 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -23,7 +23,6 @@ from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params, make_task_lock_params_for_run from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock from gokart.file_processor import FileProcessor -from gokart.in_memory.target import make_inmemory_target from gokart.pandas_type_config import PandasTypeConfigMap from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter from gokart.target import TargetOnKart @@ -215,7 +214,7 @@ def clone(self, cls=None, **kwargs): return cls(**new_k) - def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None) -> TargetOnKart: + def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None, cacheable: bool = False) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl') ) @@ -232,7 +231,7 @@ def make_target(self, relative_file_path: 
Optional[str] = None, use_unique_id: b ) return gokart.target.make_target( - file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather + file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather, cacheable=cacheable ) def make_cache_target(self, data_key: Optional[str] = None, use_unique_id: bool = True): @@ -248,7 +247,7 @@ def make_cache_target(self, data_key: Optional[str] = None, use_unique_id: bool raise_task_lock_exception_on_collision=False, lock_extend_seconds=-1, ) - return make_inmemory_target(_data_key, task_lock_params, unique_id) + return gokart.target.make_inmemory_target(_data_key, task_lock_params, unique_id) def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: formatted_relative_file_path = ( diff --git a/test/in_memory/test_cacheable_target.py b/test/in_memory/test_cacheable_target.py new file mode 100644 index 00000000..5cf1e3cf --- /dev/null +++ b/test/in_memory/test_cacheable_target.py @@ -0,0 +1,173 @@ +import pytest +from gokart.target import CacheableSingleFileTarget +from gokart.task import TaskOnKart +from gokart.in_memory import InMemoryCacheRepository +from gokart.in_memory.cacheable import CacheNotFoundError +import luigi +from time import sleep + +class DummyTask(TaskOnKart): + namespace = __name__ + param = luigi.IntParameter() + + def run(self): + self.dump(self.param) + +class TestCacheableSingleFileTarget: + @pytest.fixture + def task(self, tmpdir): + task = DummyTask(param=100, workspace_directory=tmpdir) + return task + + @pytest.fixture(autouse=True) + def clear_repository(self): + InMemoryCacheRepository.clear() + + def test_exists_when_cache_exists(self, task: TaskOnKart): + cacheable_target = task.make_target('sample.pkl', cacheable=True) + cache_target = 
task.make_cache_target(cacheable_target.path(), use_unique_id=False) + assert not cacheable_target.exists() + cache_target.dump('data') + assert cacheable_target.exists() + + def test_exists_when_file_exists(self, task: TaskOnKart): + cacheable_target = task.make_target('sample.pkl', cacheable=True) + target = task.make_target('sample.pkl') + assert not cacheable_target.exists() + target.dump('data') + assert cacheable_target.exists() + + def test_load_without_cache_or_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + with pytest.raises(FileNotFoundError): + target.load() + cacheable_target = task.make_target('sample.pkl', cacheable=True) + with pytest.raises(CacheNotFoundError): + cacheable_target.load() + + def test_load_with_cache(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + with pytest.raises(CacheNotFoundError): + cacheable_target.load() + cache_target.dump('data') + with pytest.raises(FileNotFoundError): + target.load() + assert cacheable_target.load() == 'data' + assert cache_target.load() == 'data' + + def test_load_with_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + with pytest.raises(CacheNotFoundError): + cacheable_target.load() + target.dump('data') + assert target.load() == 'data' + with pytest.raises(KeyError): + cache_target.load() + assert cacheable_target.load() == 'data' + assert cache_target.load() == 'data' + + def test_load_with_cache_and_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = 
task.make_cache_target(target.path(), use_unique_id=False) + with pytest.raises(CacheNotFoundError): + cacheable_target.load() + target.dump('data_in_file') + cache_target.dump('data_in_memory') + assert cacheable_target.load() == 'data_in_memory' + + def test_dump(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + cacheable_target.dump('data') + assert not target.exists() + assert cache_target.exists() + assert cacheable_target.exists() + + def test_dump_with_dump_to_file_flag(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + cacheable_target.dump('data', also_dump_to_file=True) + assert target.exists() + assert cache_target.exists() + assert cacheable_target.exists() + + def test_remove_without_cache_or_file(self, task: TaskOnKart): + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cacheable_target.remove() + cacheable_target.remove(also_remove_file=True) + assert True + + def test_remove_with_cache(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + cache_target.dump('data') + assert cache_target.exists() + cacheable_target.remove() + assert not cache_target.exists() + + def test_remove_with_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + target.dump('data') + assert target.exists() + cacheable_target.remove() + assert target.exists() + 
cacheable_target.remove(also_remove_file=True) + assert not target.exists() + + def test_remove_with_cache_and_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + target.dump('file_data') + cache_target.dump('inmemory_data') + cacheable_target.remove() + assert target.exists() + assert not cache_target.exists() + + target.dump('file_data') + cache_target.dump('inmemory_data') + cacheable_target.remove(also_remove_file=True) + assert not target.exists() + assert not cache_target.exists() + + def test_last_modification_time_without_cache_and_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + with pytest.raises(FileNotFoundError): + target.last_modification_time() + with pytest.raises(ValueError): + cache_target.last_modification_time() + with pytest.raises(CacheNotFoundError): + cacheable_target.last_modification_time() + + def test_last_modification_time_with_cache(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + cache_target.dump('data') + assert cacheable_target.last_modification_time() == cache_target.last_modification_time() + + def test_last_modification_time_with_file(self, task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + target.dump('data') + assert cacheable_target.last_modification_time() == target.last_modification_time() + + def test_last_modification_time_with_cache_and_file(self, 
task: TaskOnKart): + target = task.make_target('sample.pkl') + cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + target.dump('file_data') + sleep(0.1) + cache_target.dump('inmemory_data') + assert cacheable_target.last_modification_time() == cache_target.last_modification_time() diff --git a/test/in_memory/test_in_memory_target.py b/test/in_memory/test_in_memory_target.py index 5c08634b..987af21a 100644 --- a/test/in_memory/test_in_memory_target.py +++ b/test/in_memory/test_in_memory_target.py @@ -4,8 +4,8 @@ import pytest from gokart.conflict_prevention_lock.task_lock import TaskLockParams -from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget, make_inmemory_target - +from gokart.in_memory import InMemoryCacheRepository +from gokart.target import InMemoryTarget, make_inmemory_target class TestInMemoryTarget: @pytest.fixture diff --git a/test/in_memory/test_task_cached_in_memory.py b/test/in_memory/test_task_cached_in_memory.py index a874ee6b..fb5ac389 100644 --- a/test/in_memory/test_task_cached_in_memory.py +++ b/test/in_memory/test_task_cached_in_memory.py @@ -4,8 +4,8 @@ import pytest import gokart -from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget -from gokart.target import SingleFileTarget +from gokart.in_memory import InMemoryCacheRepository +from gokart.target import SingleFileTarget, InMemoryTarget class DummyTask(gokart.TaskOnKart): From 5a1572bce97a394c44a8b79f678779cc9e03f53d Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Wed, 12 Feb 2025 12:29:29 +0900 Subject: [PATCH 15/16] feat: add cacheable flag to `make_model_target` test: add tests for `CacheableModelTarget` --- gokart/target.py | 4 +- gokart/task.py | 3 +- test/in_memory/test_cacheable_target.py | 204 +++++++++++++++++++++++- 3 files changed, 207 insertions(+), 4 deletions(-) diff --git a/gokart/target.py b/gokart/target.py index 
56041480..e959b12e 100644 --- a/gokart/target.py +++ b/gokart/target.py @@ -239,11 +239,13 @@ def make_model_target( load_function, unique_id: Optional[str] = None, task_lock_params: Optional[TaskLockParams] = None, + cacheable: bool = False, ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) temporary_directory = os.path.join(temporary_directory, hashlib.md5(file_path.encode()).hexdigest()) - return ModelTarget( + cls = CacheableModelTarget if cacheable else ModelTarget + return cls( file_path=file_path, temporary_directory=temporary_directory, save_function=save_function, diff --git a/gokart/task.py b/gokart/task.py index b55aeb90..9bb13c89 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -274,7 +274,7 @@ def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, ) def make_model_target( - self, relative_file_path: str, save_function: Callable[[Any, str], None], load_function: Callable[[str], Any], use_unique_id: bool = True + self, relative_file_path: str, save_function: Callable[[Any, str], None], load_function: Callable[[str], Any], use_unique_id: bool = True, cacheable: bool = False ): """ Make target for models which generate multiple files in saving, e.g. gensim.Word2Vec, Tensorflow, and so on. 
@@ -303,6 +303,7 @@ def make_model_target( save_function=save_function, load_function=load_function, task_lock_params=task_lock_params, + cacheable=cacheable ) @overload diff --git a/test/in_memory/test_cacheable_target.py b/test/in_memory/test_cacheable_target.py index 5cf1e3cf..eefdd7ae 100644 --- a/test/in_memory/test_cacheable_target.py +++ b/test/in_memory/test_cacheable_target.py @@ -1,11 +1,11 @@ import pytest -from gokart.target import CacheableSingleFileTarget +from gokart.target import CacheableSingleFileTarget, CacheableModelTarget from gokart.task import TaskOnKart from gokart.in_memory import InMemoryCacheRepository from gokart.in_memory.cacheable import CacheNotFoundError import luigi from time import sleep - +import pickle class DummyTask(TaskOnKart): namespace = __name__ param = luigi.IntParameter() @@ -171,3 +171,203 @@ def test_last_modification_time_with_cache_and_file(self, task: TaskOnKart): sleep(0.1) cache_target.dump('inmemory_data') assert cacheable_target.last_modification_time() == cache_target.last_modification_time() + +class DummyModule: + def func(self): + return 'hello world' + +class DummyModuleA: + def func_a(self): + return f'hello world' + +class DummyModuleB: + def func_b(self): + return f'hello world' + +def _save_func(obj, path): + with open(path, 'wb') as f: + pickle.dump(obj, f) + +def _load_func(path): + with open(path, 'rb') as f: + return pickle.load(f) + +class TestCacheableModelTarget: + @pytest.fixture + def task(self, tmpdir): + task = DummyTask(param=100, workspace_directory=tmpdir) + return task + + @pytest.fixture(autouse=True) + def clear_repository(self): + InMemoryCacheRepository.clear() + + def test_exists_without_cache_or_file(self, task: TaskOnKart): + target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func + ) + cacheable_target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func, 
+ cacheable=True + ) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + assert not target.exists() + assert not cache_target.exists() + assert not cacheable_target.exists() + + def test_exists_with_file(self, task: TaskOnKart): + target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func + ) + cacheable_target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func, + cacheable=True + ) + assert not cacheable_target.exists() + module = DummyModule() + target.dump(module) + assert cacheable_target.exists() + + def test_exists_with_cache(self, task: TaskOnKart): + target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func + ) + cacheable_target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func, + cacheable=True + ) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + assert not cacheable_target.exists() + module = DummyModule() + cache_target.dump(module) + assert not target.exists() + assert cacheable_target.exists() + + def test_load_without_cache_or_file(self, task: TaskOnKart): + target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func + ) + cacheable_target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func, + cacheable=True + ) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + with pytest.raises(FileNotFoundError): + target.load() + with pytest.raises(KeyError): + cache_target.load() + with pytest.raises(CacheNotFoundError): + cacheable_target.load() + + def test_load_with_cache(self, task: TaskOnKart): + target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + 
load_function=_load_func + ) + cacheable_target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func, + cacheable=True + ) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + module = DummyModule() + cache_target.dump(module) + with pytest.raises(FileNotFoundError): + target.load() + assert isinstance(cache_target.load(), DummyModule) + assert isinstance(cacheable_target.load(), DummyModule) + + def test_load_with_file(self, task: TaskOnKart): + target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func + ) + cacheable_target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func, + cacheable=True + ) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + module = DummyModule() + target.dump(module) + assert isinstance(target.load(), DummyModule) + assert not cache_target.exists() + assert isinstance(cacheable_target.load(), DummyModule) + assert cache_target.exists() + + def test_load_with_cache_and_file(self, task: TaskOnKart): + target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func + ) + cacheable_target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func, + cacheable=True + ) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + inmemory_module_cls, file_module_cls = DummyModule, DummyModuleA + inmemory_module, file_module = inmemory_module_cls(), file_module_cls() + target.dump(file_module) + cache_target.dump(inmemory_module) + assert isinstance(target.load(), file_module_cls) + assert isinstance(cache_target.load(), inmemory_module_cls) + assert isinstance(cacheable_target.load(), inmemory_module_cls) + + def test_dump(self, task: TaskOnKart): + target = 
task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func + ) + cacheable_target: CacheableModelTarget = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func, + cacheable=True + ) + module = DummyModule() + cacheable_target.dump(module) + assert not target.exists() + assert cacheable_target.exists() + + def test_dump_with_dump_to_file_flag(self, task: TaskOnKart): + target = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func + ) + cacheable_target: CacheableModelTarget = task.make_model_target( + relative_file_path='model.zip', + save_function=_save_func, + load_function=_load_func, + cacheable=True + ) + cache_target = task.make_cache_target(target.path(), use_unique_id=False) + module = DummyModule() + cacheable_target.dump(module, also_dump_to_file=True) + assert target.exists() + assert cache_target.exists() + assert cacheable_target.exists() \ No newline at end of file From abb217aafb0f3a929b375e41626a0e680aef320a Mon Sep 17 00:00:00 2001 From: utotsubasa Date: Wed, 12 Feb 2025 12:38:52 +0900 Subject: [PATCH 16/16] fix: fix linting errors --- gokart/in_memory/cacheable.py | 40 +++-- gokart/target.py | 19 ++- gokart/task.py | 20 ++- test/in_memory/test_cacheable_target.py | 162 ++++++------------- test/in_memory/test_in_memory_target.py | 1 + test/in_memory/test_task_cached_in_memory.py | 2 +- 6 files changed, 100 insertions(+), 144 deletions(-) diff --git a/gokart/in_memory/cacheable.py b/gokart/in_memory/cacheable.py index 986ae285..05664e49 100644 --- a/gokart/in_memory/cacheable.py +++ b/gokart/in_memory/cacheable.py @@ -1,14 +1,16 @@ -from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock from typing import Any + +from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, 
wrap_load_with_lock, wrap_remove_with_lock from gokart.in_memory.repository import InMemoryCacheRepository -class CacheNotFoundError(OSError): - ... + +class CacheNotFoundError(OSError): ... + class CacheableMixin: def exists(self) -> bool: return self._cache_exists() or super().exists() - + def load(self) -> bool: def _load(): cached = self._cache_exists() @@ -17,29 +19,32 @@ def _load(): else: try: loaded = super(CacheableMixin, self)._load() - except FileNotFoundError: - raise CacheNotFoundError + except FileNotFoundError as e: + raise CacheNotFoundError from e self._cache_dump(loaded) return loaded + return wrap_load_with_lock(func=_load, task_lock_params=super()._get_task_lock_params())() - + def dump(self, obj: Any, lock_at_dump: bool = True, also_dump_to_file: bool = False): # TODO: how to sync cache and file def _dump(obj: Any): self._cache_dump(obj) if also_dump_to_file: super(CacheableMixin, self)._dump(obj) + if lock_at_dump: wrap_dump_with_lock(func=_dump, task_lock_params=super()._get_task_lock_params(), exist_check=self.exists)(obj) else: _dump(obj) - + def remove(self, also_remove_file: bool = False): def _remove(): if self._cache_exists(): self._cache_remove() if super(CacheableMixin, self).exists() and also_remove_file: super(CacheableMixin, self)._remove() + wrap_remove_with_lock(func=_remove, task_lock_params=super()._get_task_lock_params())() def last_modification_time(self): @@ -47,8 +52,8 @@ def last_modification_time(self): return self._cache_last_modification_time() try: return super()._last_modification_time() - except: - raise CacheNotFoundError + except FileNotFoundError as e: + raise CacheNotFoundError from e @property def data_key(self): @@ -56,19 +61,20 @@ def data_key(self): def _cache_exists(self): raise NotImplementedError - + def _cache_load(self): raise NotImplementedError - + def _cache_dump(self, obj: Any): raise NotImplementedError - + def _cache_remove(self): raise NotImplementedError - + def 
_cache_last_modification_time(self): raise NotImplementedError + class InMemoryCacheableMixin(CacheableMixin): _repository = InMemoryCacheRepository() @@ -77,12 +83,12 @@ def _cache_exists(self): def _cache_load(self): return self._repository.get_value(self.data_key) - + def _cache_dump(self, obj): return self._repository.set_value(self.data_key, obj) - + def _cache_remove(self): self._repository.remove(self.data_key) - + def _cache_last_modification_time(self): return self._repository.get_last_modification_time(self.data_key) diff --git a/gokart/target.py b/gokart/target.py index e959b12e..0fdaa958 100644 --- a/gokart/target.py +++ b/gokart/target.py @@ -14,10 +14,10 @@ from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock from gokart.file_processor import FileProcessor, make_file_processor -from gokart.object_storage import ObjectStorage -from gokart.zip_client_util import make_zip_client from gokart.in_memory.cacheable import InMemoryCacheableMixin from gokart.in_memory.repository import InMemoryCacheRepository +from gokart.object_storage import ObjectStorage +from gokart.zip_client_util import make_zip_client logger = getLogger(__name__) @@ -165,11 +165,13 @@ def _remove_temporary_directory(self): def _make_temporary_directory(self): os.makedirs(self._temporary_directory, exist_ok=True) -class CacheableSingleFileTarget(InMemoryCacheableMixin, SingleFileTarget): - ... -class CacheableModelTarget(InMemoryCacheableMixin, ModelTarget): - ... +class CacheableSingleFileTarget(InMemoryCacheableMixin, SingleFileTarget): ... + + +class CacheableModelTarget(InMemoryCacheableMixin, ModelTarget): ... 
+ + class LargeDataFrameProcessor(object): def __init__(self, max_byte: int): self.max_byte = int(max_byte) @@ -222,7 +224,7 @@ def make_target( processor: Optional[FileProcessor] = None, task_lock_params: Optional[TaskLockParams] = None, store_index_in_feather: bool = True, - cacheable: bool = False + cacheable: bool = False, ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) @@ -253,6 +255,7 @@ def make_model_target( task_lock_params=_task_lock_params, ) + class InMemoryTarget(TargetOnKart): def __init__(self, data_key: str, task_lock_param: TaskLockParams): if task_lock_param.should_task_lock: @@ -296,4 +299,4 @@ def _make_data_key(data_key: str, unique_id: Optional[str] = None): def make_inmemory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None): _data_key = _make_data_key(data_key, unique_id) - return InMemoryTarget(_data_key, task_lock_params) \ No newline at end of file + return InMemoryTarget(_data_key, task_lock_params) diff --git a/gokart/task.py b/gokart/task.py index 9bb13c89..66f8c65f 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -214,7 +214,9 @@ def clone(self, cls=None, **kwargs): return cls(**new_k) - def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None, cacheable: bool = False) -> TargetOnKart: + def make_target( + self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, processor: Optional[FileProcessor] = None, cacheable: bool = False + ) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl') ) @@ -231,7 +233,12 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b ) return 
gokart.target.make_target( - file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather, cacheable=cacheable + file_path=file_path, + unique_id=unique_id, + processor=processor, + task_lock_params=task_lock_params, + store_index_in_feather=self.store_index_in_feather, + cacheable=cacheable, ) def make_cache_target(self, data_key: Optional[str] = None, use_unique_id: bool = True): @@ -274,7 +281,12 @@ def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, ) def make_model_target( - self, relative_file_path: str, save_function: Callable[[Any, str], None], load_function: Callable[[str], Any], use_unique_id: bool = True, cacheable: bool = False + self, + relative_file_path: str, + save_function: Callable[[Any, str], None], + load_function: Callable[[str], Any], + use_unique_id: bool = True, + cacheable: bool = False, ): """ Make target for models which generate multiple files in saving, e.g. gensim.Word2Vec, Tensorflow, and so on. 
@@ -303,7 +315,7 @@ def make_model_target( save_function=save_function, load_function=load_function, task_lock_params=task_lock_params, - cacheable=cacheable + cacheable=cacheable, ) @overload diff --git a/test/in_memory/test_cacheable_target.py b/test/in_memory/test_cacheable_target.py index eefdd7ae..7dd55830 100644 --- a/test/in_memory/test_cacheable_target.py +++ b/test/in_memory/test_cacheable_target.py @@ -1,11 +1,15 @@ +import pickle +from time import sleep + +import luigi import pytest -from gokart.target import CacheableSingleFileTarget, CacheableModelTarget -from gokart.task import TaskOnKart + from gokart.in_memory import InMemoryCacheRepository from gokart.in_memory.cacheable import CacheNotFoundError -import luigi -from time import sleep -import pickle +from gokart.target import CacheableModelTarget, CacheableSingleFileTarget +from gokart.task import TaskOnKart + + class DummyTask(TaskOnKart): namespace = __name__ param = luigi.IntParameter() @@ -13,6 +17,7 @@ class DummyTask(TaskOnKart): def run(self): self.dump(self.param) + class TestCacheableSingleFileTarget: @pytest.fixture def task(self, tmpdir): @@ -36,7 +41,7 @@ def test_exists_when_file_exists(self, task: TaskOnKart): assert not cacheable_target.exists() target.dump('data') assert cacheable_target.exists() - + def test_load_without_cache_or_file(self, task: TaskOnKart): target = task.make_target('sample.pkl') with pytest.raises(FileNotFoundError): @@ -44,7 +49,7 @@ def test_load_without_cache_or_file(self, task: TaskOnKart): cacheable_target = task.make_target('sample.pkl', cacheable=True) with pytest.raises(CacheNotFoundError): cacheable_target.load() - + def test_load_with_cache(self, task: TaskOnKart): target = task.make_target('sample.pkl') cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) @@ -79,7 +84,7 @@ def test_load_with_cache_and_file(self, task: TaskOnKart): target.dump('data_in_file') cache_target.dump('data_in_memory') assert 
cacheable_target.load() == 'data_in_memory' - + def test_dump(self, task: TaskOnKart): target = task.make_target('sample.pkl') cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) @@ -97,7 +102,7 @@ def test_dump_with_dump_to_file_flag(self, task: TaskOnKart): assert target.exists() assert cache_target.exists() assert cacheable_target.exists() - + def test_remove_without_cache_or_file(self, task: TaskOnKart): cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) cacheable_target.remove() @@ -138,7 +143,7 @@ def test_remove_with_cache_and_file(self, task: TaskOnKart): cacheable_target.remove(also_remove_file=True) assert not target.exists() assert not cache_target.exists() - + def test_last_modification_time_without_cache_and_file(self, task: TaskOnKart): target = task.make_target('sample.pkl') cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) @@ -162,7 +167,7 @@ def test_last_modification_time_with_file(self, task: TaskOnKart): cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) target.dump('data') assert cacheable_target.last_modification_time() == target.last_modification_time() - + def test_last_modification_time_with_cache_and_file(self, task: TaskOnKart): target = task.make_target('sample.pkl') cacheable_target: CacheableSingleFileTarget = task.make_target('sample.pkl', cacheable=True) @@ -172,26 +177,32 @@ def test_last_modification_time_with_cache_and_file(self, task: TaskOnKart): cache_target.dump('inmemory_data') assert cacheable_target.last_modification_time() == cache_target.last_modification_time() + class DummyModule: def func(self): return 'hello world' + class DummyModuleA: def func_a(self): - return f'hello world' + return 'hello world' + class DummyModuleB: def func_b(self): - return f'hello world' - + return 'hello world' + + def _save_func(obj, path): with open(path, 'wb') as f: 
pickle.dump(obj, f) + def _load_func(path): with open(path, 'rb') as f: return pickle.load(f) + class TestCacheableModelTarget: @pytest.fixture def task(self, tmpdir): @@ -203,70 +214,34 @@ def clear_repository(self): InMemoryCacheRepository.clear() def test_exists_without_cache_or_file(self, task: TaskOnKart): - target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func - ) - cacheable_target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func, - cacheable=True - ) + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) cache_target = task.make_cache_target(target.path(), use_unique_id=False) assert not target.exists() assert not cache_target.exists() assert not cacheable_target.exists() - + def test_exists_with_file(self, task: TaskOnKart): - target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func - ) - cacheable_target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func, - cacheable=True - ) + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) assert not cacheable_target.exists() module = DummyModule() target.dump(module) assert cacheable_target.exists() def test_exists_with_cache(self, task: TaskOnKart): - target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func - ) - cacheable_target = task.make_model_target( - 
relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func, - cacheable=True - ) + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) cache_target = task.make_cache_target(target.path(), use_unique_id=False) assert not cacheable_target.exists() module = DummyModule() cache_target.dump(module) assert not target.exists() assert cacheable_target.exists() - + def test_load_without_cache_or_file(self, task: TaskOnKart): - target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func - ) - cacheable_target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func, - cacheable=True - ) + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) cache_target = task.make_cache_target(target.path(), use_unique_id=False) with pytest.raises(FileNotFoundError): target.load() @@ -274,19 +249,10 @@ def test_load_without_cache_or_file(self, task: TaskOnKart): cache_target.load() with pytest.raises(CacheNotFoundError): cacheable_target.load() - + def test_load_with_cache(self, task: TaskOnKart): - target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func - ) - cacheable_target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func, - cacheable=True - ) + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = 
task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) cache_target = task.make_cache_target(target.path(), use_unique_id=False) module = DummyModule() cache_target.dump(module) @@ -296,17 +262,8 @@ def test_load_with_cache(self, task: TaskOnKart): assert isinstance(cacheable_target.load(), DummyModule) def test_load_with_file(self, task: TaskOnKart): - target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func - ) - cacheable_target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func, - cacheable=True - ) + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) cache_target = task.make_cache_target(target.path(), use_unique_id=False) module = DummyModule() target.dump(module) @@ -314,19 +271,10 @@ def test_load_with_file(self, task: TaskOnKart): assert not cache_target.exists() assert isinstance(cacheable_target.load(), DummyModule) assert cache_target.exists() - + def test_load_with_cache_and_file(self, task: TaskOnKart): - target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func - ) - cacheable_target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func, - cacheable=True - ) + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) + cacheable_target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True) cache_target = task.make_cache_target(target.path(), use_unique_id=False) inmemory_module_cls, file_module_cls = 
DummyModule, DummyModuleA inmemory_module, file_module = inmemory_module_cls(), file_module_cls() @@ -337,16 +285,9 @@ def test_load_with_cache_and_file(self, task: TaskOnKart): assert isinstance(cacheable_target.load(), inmemory_module_cls) def test_dump(self, task: TaskOnKart): - target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func - ) + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) cacheable_target: CacheableModelTarget = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func, - cacheable=True + relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True ) module = DummyModule() cacheable_target.dump(module) @@ -354,20 +295,13 @@ def test_dump(self, task: TaskOnKart): assert cacheable_target.exists() def test_dump_with_dump_to_file_flag(self, task: TaskOnKart): - target = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func - ) + target = task.make_model_target(relative_file_path='model.zip', save_function=_save_func, load_function=_load_func) cacheable_target: CacheableModelTarget = task.make_model_target( - relative_file_path='model.zip', - save_function=_save_func, - load_function=_load_func, - cacheable=True + relative_file_path='model.zip', save_function=_save_func, load_function=_load_func, cacheable=True ) cache_target = task.make_cache_target(target.path(), use_unique_id=False) module = DummyModule() cacheable_target.dump(module, also_dump_to_file=True) assert target.exists() assert cache_target.exists() - assert cacheable_target.exists() \ No newline at end of file + assert cacheable_target.exists() diff --git a/test/in_memory/test_in_memory_target.py b/test/in_memory/test_in_memory_target.py index 987af21a..0716c514 100644 --- 
a/test/in_memory/test_in_memory_target.py +++ b/test/in_memory/test_in_memory_target.py @@ -7,6 +7,7 @@ from gokart.in_memory import InMemoryCacheRepository from gokart.target import InMemoryTarget, make_inmemory_target + class TestInMemoryTarget: @pytest.fixture def task_lock_params(self): diff --git a/test/in_memory/test_task_cached_in_memory.py b/test/in_memory/test_task_cached_in_memory.py index fb5ac389..e57d5444 100644 --- a/test/in_memory/test_task_cached_in_memory.py +++ b/test/in_memory/test_task_cached_in_memory.py @@ -5,7 +5,7 @@ import gokart from gokart.in_memory import InMemoryCacheRepository -from gokart.target import SingleFileTarget, InMemoryTarget +from gokart.target import InMemoryTarget, SingleFileTarget class DummyTask(gokart.TaskOnKart):