Skip to content

Commit

Permalink
add local cache regression test
Browse files Browse the repository at this point in the history
  • Loading branch information
vaaaaanquish committed Jan 26, 2021
1 parent 7391c66 commit eb452ee
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
2 changes: 1 addition & 1 deletion test/test_local_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_dump_and_get(self):

def test_convert_file_path(self):
output = self.local_cache._convert_file_path('test.pkl')
target = Path(os.path.join(os.getcwd(), '.thunderbolt', self.base_path.split('/')[-1], 'test.json'))
target = Path(os.path.join(os.getcwd(), '.thunderbolt', self.base_path.split('/')[-1], 'test.pkl'))
self.assertEqual(target, output)

def tearDown(self):
Expand Down
31 changes: 31 additions & 0 deletions test/test_local_cache_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest
from os import path
import pandas as pd
import pickle

from thunderbolt import Thunderbolt
from thunderbolt.client.local_cache import LocalCache
"""
requires:
python sample.py test.TestCaseTask --param=sample --number=1 --workspace-directory=./test_case --local-scheduler
"""


class LocalCacheTest(unittest.TestCase):
def test_running(self):
target = Thunderbolt(self._get_test_case_path(), use_cache=False)
_ = Thunderbolt(self._get_test_case_path())
output = Thunderbolt(self._get_test_case_path())

for k, v in target.tasks.items():
if k == 'last_modified': # cache file
continue
self.assertEqual(v, output.tasks[k])

output.client.local_cache.clear()

def _get_test_case_path(self, file_name: str = ''):
p = path.abspath(path.join(path.dirname(__file__), 'test_case'))
if file_name:
return path.join(p, file_name)
return p
15 changes: 8 additions & 7 deletions thunderbolt/client/local_cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import json
import pickle
import shutil
from pathlib import Path
from typing import Optional
Expand All @@ -9,7 +9,7 @@ class LocalCache:
def __init__(self, workspace_directory: str, use_cache: bool):
"""Log file cache.
dump file: ./.thunderbolt/resources/{task_hash}.json
dump file: ./.thunderbolt/resources/{task_hash}.pkl
"""
self.cache_dir = Path(os.path.join(os.getcwd(), '.thunderbolt', workspace_directory.split('/')[-1]))
if use_cache:
Expand All @@ -18,19 +18,20 @@ def __init__(self, workspace_directory: str, use_cache: bool):
def get(self, file_name: str) -> Optional[dict]:
cache_file_path = self._convert_file_path(file_name)
if cache_file_path.exists():
with cache_file_path.open(mode='r') as f:
params = json.load(f)
with cache_file_path.open(mode='rb') as f:
params = pickle.load(f)
return params
return None

def dump(self, file_name: str, params: dict):
cache_file_path = self._convert_file_path(file_name)
with cache_file_path.open(mode='w') as f:
json.dump(params, f)
with cache_file_path.open(mode='wb') as f:
pickle.dump(params, f)

def clear(self):
shutil.rmtree(os.path.join(os.getcwd(), '.thunderbolt'))

def _convert_file_path(self, file_name: str) -> Path:
file_name = file_name.split('/')[-1]
cache_file_path = self.cache_dir.joinpath(file_name)
return cache_file_path.with_suffix('.json')
return cache_file_path.with_suffix('.pkl')

0 comments on commit eb452ee

Please sign in to comment.