Skip to content

Commit

Permalink
Merge pull request #41 from m3dev/local_cache
Browse files Browse the repository at this point in the history
[draft] add local cache
  • Loading branch information
vaaaaanquish authored Jan 28, 2021
2 parents ea13421 + eb452ee commit 4faba6f
Show file tree
Hide file tree
Showing 12 changed files with 187 additions and 54 deletions.
2 changes: 1 addition & 1 deletion test/test_gcs_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class TestGCSClient(unittest.TestCase):
def setUp(self):
self.base_path = 'gs://bucket/prefix/'
self.client = GCSClient(self.base_path, None, None)
self.client = GCSClient(self.base_path, None, None, use_cache=False)

def test_to_absolute_path(self):
source = 'gs://bucket/prefix/hoge/piyo'
Expand Down
28 changes: 28 additions & 0 deletions test/test_local_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest
import os
from pathlib import Path

from thunderbolt.client.local_cache import LocalCache


class TestLocalCache(unittest.TestCase):
    """Unit tests for LocalCache: directory creation, dump/get round trip and path mapping."""

    def setUp(self):
        # A fresh cache (use_cache=True creates the directory) is built for every test.
        self.base_path = './resources'
        self.local_cache = LocalCache(self.base_path, True)

    def test_init(self):
        # The cache root is the hidden '.thunderbolt' directory under the current
        # working directory; './thunderbolt' (without the leading dot) is a
        # different path and would not verify that the cache was created.
        self.assertTrue(os.path.exists('./.thunderbolt'))

    def test_dump_and_get(self):
        # A dumped object must come back unchanged under the same file name.
        target = {'foo': 'bar'}
        self.local_cache.dump('test.pkl', target)
        output = self.local_cache.get('test.pkl')
        self.assertDictEqual(target, output)

    def test_convert_file_path(self):
        # Expected layout: <cwd>/.thunderbolt/<last component of base_path>/<file name>.
        output = self.local_cache._convert_file_path('test.pkl')
        target = Path(os.path.join(os.getcwd(), '.thunderbolt', self.base_path.split('/')[-1], 'test.pkl'))
        self.assertEqual(target, output)

    def tearDown(self):
        # Remove the cache directory so tests do not leak state into each other.
        self.local_cache.clear()
31 changes: 31 additions & 0 deletions test/test_local_cache_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest
from os import path
import pandas as pd
import pickle

from thunderbolt import Thunderbolt
from thunderbolt.client.local_cache import LocalCache
"""
requires:
python sample.py test.TestCaseTask --param=sample --number=1 --workspace-directory=./test_case --local-scheduler
"""


class LocalCacheTest(unittest.TestCase):
    """Regression test: tasks loaded through the local cache match an uncached load."""

    def test_running(self):
        # Reference load that bypasses the cache entirely.
        target = Thunderbolt(self._get_test_case_path(), use_cache=False)

        # First cached load; the client dumps each task it reads into the cache.
        warm_up = Thunderbolt(self._get_test_case_path())
        # Register cleanup immediately so the cache is removed even when a later
        # construction or assertion fails.
        self.addCleanup(warm_up.client.local_cache.clear)

        # Second cached load; the client takes cache entries when they exist.
        output = Thunderbolt(self._get_test_case_path())

        # `tasks` is keyed by integer index, so 'last_modified' has to be skipped
        # per task field rather than per task key.
        self.assertEqual(set(target.tasks), set(output.tasks))
        for index, expected in target.tasks.items():
            actual = output.tasks[index]
            self.assertEqual(set(expected), set(actual))
            for field, value in expected.items():
                if field == 'last_modified':
                    continue
                self.assertEqual(value, actual[field])

    def _get_test_case_path(self, file_name: str = ''):
        """Return the absolute path of the `test_case` directory next to this file.

        When `file_name` is non-empty, the path of that file inside the directory
        is returned instead.
        """
        p = path.abspath(path.join(path.dirname(__file__), 'test_case'))
        if file_name:
            return path.join(p, file_name)
        return p
2 changes: 1 addition & 1 deletion test/test_local_directory_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class TestLocalDirectoryClient(unittest.TestCase):
def setUp(self):
self.client = LocalDirectoryClient('.', None, None)
self.client = LocalDirectoryClient('.', None, None, use_cache=False)

def test_to_absolute_path(self):
source = './hoge/hoge/piyo'
Expand Down
2 changes: 1 addition & 1 deletion test/test_s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
class TestS3Client(unittest.TestCase):
def setUp(self):
self.base_path = 's3://bucket/prefix/'
self.client = S3Client(self.base_path, None, None)
self.client = S3Client(self.base_path, None, None, use_cache=False)

def test_to_absolute_path(self):
source = 's3://bucket/prefix/hoge/piyo'
Expand Down
35 changes: 16 additions & 19 deletions test/test_thunderbolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_tasks():
with ExitStack() as stack:
for module in ['local_directory_client.LocalDirectoryClient', 'gcs_client.GCSClient', 's3_client.S3Client']:
stack.enter_context(patch('.'.join([module_path, module, 'get_tasks']), side_effect=get_tasks))
self.tb = thunderbolt.Thunderbolt(None)
self.tb = thunderbolt.Thunderbolt(None, use_cache=False)

def test_get_client(self):
source_workspace_directory = ['s3://', 'gs://', 'gcs://', './local', 'hoge']
Expand All @@ -27,27 +27,24 @@ def test_get_client(self):
target = [S3Client, GCSClient, GCSClient, LocalDirectoryClient, LocalDirectoryClient]

for s, t in zip(source_workspace_directory, target):
output = self.tb._get_client(s, source_filters, source_tqdm_disable)
output = self.tb._get_client(s, source_filters, source_tqdm_disable, False)
self.assertEqual(type(output), t)

def test_get_tasks_dic(self):
tasks_list = [
{
'task_name': 'task',
'last_modified': 'last_modified_2',
'task_params': 'task_params_1',
'task_hash': 'task_hash_1',
'task_log': 'task_log_1'
},
{
'task_name': 'task',
'last_modified': 'last_modified_1',
'task_params': 'task_params_1',
'task_hash': 'task_hash_1',
'task_log': 'task_log_1'
}
]

tasks_list = [{
'task_name': 'task',
'last_modified': 'last_modified_2',
'task_params': 'task_params_1',
'task_hash': 'task_hash_1',
'task_log': 'task_log_1'
}, {
'task_name': 'task',
'last_modified': 'last_modified_1',
'task_params': 'task_params_1',
'task_hash': 'task_hash_1',
'task_log': 'task_log_1'
}]

target = {
0: {
'task_name': 'task',
Expand Down
2 changes: 1 addition & 1 deletion test/test_thunderbolt_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class SimpleLocalTest(unittest.TestCase):
def setUp(self):
self.tb = thunderbolt.Thunderbolt(self.get_test_case_path())
self.tb = thunderbolt.Thunderbolt(self.get_test_case_path(), use_cache=False)

def test_init(self):
self.assertEqual(self.tb.client.workspace_directory, self.get_test_case_path())
Expand Down
19 changes: 16 additions & 3 deletions thunderbolt/client/gcs_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@
from datetime import datetime
from typing import List, Dict, Any

from thunderbolt.client.local_cache import LocalCache

from gokart.gcs_config import GCSConfig


class GCSClient:
def __init__(self, workspace_directory: str = '', task_filters: List[str] = [], tqdm_disable: bool = False):
def __init__(self, workspace_directory: str = '', task_filters: List[str] = [], tqdm_disable: bool = False, use_cache: bool = True):
"""must set $GCS_CREDENTIAL"""
self.workspace_directory = workspace_directory
self.task_filters = task_filters
self.tqdm_disable = tqdm_disable
self.gcs_client = GCSConfig().get_gcs_client()
self.local_cache = LocalCache(workspace_directory, use_cache)
self.use_cache = use_cache

def get_tasks(self) -> List[Dict[str, Any]]:
"""Load all task_log from GCS"""
Expand All @@ -26,15 +30,24 @@ def get_tasks(self) -> List[Dict[str, Any]]:
continue
n = n.split('_')

if self.use_cache:
cache = self.local_cache.get(x)
if cache:
tasks_list.append(cache)
continue

try:
meta = self._get_gcs_object_info(x)
tasks_list.append({
params = {
'task_name': '_'.join(n[:-1]),
'task_params': pickle.load(self.gcs_client.download(x.replace('task_log', 'task_params'))),
'task_log': pickle.load(self.gcs_client.download(x)),
'last_modified': datetime.strptime(meta['updated'].split('.')[0], '%Y-%m-%dT%H:%M:%S'),
'task_hash': n[-1].split('.')[0]
})
}
tasks_list.append(params)
if self.use_cache:
self.local_cache.dump(x, params)
except Exception:
continue

Expand Down
37 changes: 37 additions & 0 deletions thunderbolt/client/local_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import pickle
import shutil
from pathlib import Path
from typing import Optional


class LocalCache:
    """Pickle-backed on-disk cache for task log entries.

    Entries are stored as ``<cwd>/.thunderbolt/<workspace name>/<file stem>.pkl``,
    where ``<workspace name>`` is the last path component of the workspace
    directory and ``<cwd>`` is the working directory at construction time.
    """

    def __init__(self, workspace_directory: str, use_cache: bool):
        """Resolve the cache directory and create it when caching is enabled.

        Args:
            workspace_directory: Workspace path or URL; only its last component
                is used to name the cache sub-directory.
            use_cache: When True the cache directory is created immediately.
        """
        # Drop trailing slashes first: 's3://bucket/prefix/'.split('/')[-1] is '',
        # which would make every such workspace share the cache root directory.
        workspace_name = workspace_directory.rstrip('/').split('/')[-1]
        self.cache_dir = Path(os.path.join(os.getcwd(), '.thunderbolt', workspace_name))
        if use_cache:
            self.cache_dir.mkdir(parents=True, exist_ok=True)

    def get(self, file_name: str) -> Optional[dict]:
        """Return the cached object for `file_name`, or None on a cache miss.

        A missing, empty or truncated cache file counts as a miss so that the
        caller falls back to loading from the original source.
        """
        cache_file_path = self._convert_file_path(file_name)
        try:
            with cache_file_path.open(mode='rb') as f:
                # NOTE(review): pickle.load can execute code from a crafted file;
                # this assumes the cache directory is only written by `dump`.
                return pickle.load(f)
        except (FileNotFoundError, NotADirectoryError):
            return None
        except (pickle.UnpicklingError, EOFError):
            # Unreadable entry, e.g. left behind by an interrupted dump.
            return None

    def dump(self, file_name: str, params: dict) -> None:
        """Pickle `params` into the cache entry for `file_name`."""
        cache_file_path = self._convert_file_path(file_name)
        # The directory may be absent after `clear()` or when the instance was
        # created with use_cache=False.
        cache_file_path.parent.mkdir(parents=True, exist_ok=True)
        with cache_file_path.open(mode='wb') as f:
            pickle.dump(params, f)

    def clear(self) -> None:
        """Delete the whole ``.thunderbolt`` directory under the current working directory.

        This removes the caches of every workspace, not only this instance's
        sub-directory. Calling it when nothing has been cached is a no-op.
        """
        cache_root = os.path.join(os.getcwd(), '.thunderbolt')
        if os.path.exists(cache_root):
            shutil.rmtree(cache_root)

    def _convert_file_path(self, file_name: str) -> Path:
        """Map a source file path to its cache path.

        Only the part after the last '/' is kept, and its suffix is replaced
        by '.pkl'.
        """
        file_name = file_name.split('/')[-1]
        cache_file_path = self.cache_dir.joinpath(file_name)
        return cache_file_path.with_suffix('.pkl')
31 changes: 22 additions & 9 deletions thunderbolt/client/local_directory_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
import pickle
from typing import List, Dict, Any

from thunderbolt.client.local_cache import LocalCache

from tqdm import tqdm


class LocalDirectoryClient:
def __init__(self, workspace_directory: str = '', task_filters: List[str] = [], tqdm_disable: bool = False):
def __init__(self, workspace_directory: str = '', task_filters: List[str] = [], tqdm_disable: bool = False, use_cache: bool = True):
self.workspace_directory = os.path.abspath(workspace_directory)
self.task_filters = task_filters
self.tqdm_disable = tqdm_disable
self.local_cache = LocalCache(workspace_directory, use_cache)
self.use_cache = use_cache

def get_tasks(self) -> List[Dict[str, Any]]:
"""Load all task_log from workspace_directory."""
Expand All @@ -24,23 +28,32 @@ def get_tasks(self) -> List[Dict[str, Any]]:
continue
n = n.split('_')

if self.use_cache:
cache = self.local_cache.get(x)
if cache:
tasks_list.append(cache)
continue

try:
modified = datetime.fromtimestamp(os.stat(x).st_mtime)
with open(x, 'rb') as f:
task_log = pickle.load(f)
with open(x.replace('task_log', 'task_params'), 'rb') as f:
task_params = pickle.load(f)

params = {
'task_name': '_'.join(n[:-1]),
'task_params': task_params,
'task_log': task_log,
'last_modified': modified,
'task_hash': n[-1].split('.')[0],
}
tasks_list.append(params)
if self.use_cache:
self.local_cache.dump(x, params)
except Exception:
continue

tasks_list.append({
'task_name': '_'.join(n[:-1]),
'task_params': task_params,
'task_log': task_log,
'last_modified': modified,
'task_hash': n[-1].split('.')[0],
})

if len(tasks_list) != len(files):
warnings.warn(f'[NOT FOUND LOGS] target file: {len(files)}, found log file: {len(tasks_list)}')

Expand Down
34 changes: 21 additions & 13 deletions thunderbolt/client/s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,24 @@
import warnings
from typing import List, Dict, Any

from thunderbolt.client.local_cache import LocalCache

import boto3
from boto3 import Session
from tqdm import tqdm


class S3Client:
def __init__(self, workspace_directory: str = '', task_filters: List[str] = [], tqdm_disable: bool = False):
def __init__(self, workspace_directory: str = '', task_filters: List[str] = [], tqdm_disable: bool = False, use_cache: bool = True):
self.workspace_directory = workspace_directory
self.task_filters = task_filters
self.tqdm_disable = tqdm_disable
self.bucket_name = workspace_directory.replace('s3://', '').split('/')[0]
self.prefix = '/'.join(workspace_directory.replace('s3://', '').split('/')[1:])
self.resource = boto3.resource('s3')
self.s3client = Session().client('s3')
self.local_cache = LocalCache(workspace_directory, use_cache)
self.use_cache = use_cache

def get_tasks(self) -> List[Dict[str, Any]]:
"""Load all task_log from S3"""
Expand All @@ -28,19 +32,23 @@ def get_tasks(self) -> List[Dict[str, Any]]:
continue
n = n.split('_')

if self.use_cache:
cache = self.local_cache.get(x)
if cache:
tasks_list.append(cache)
continue

try:
tasks_list.append({
'task_name':
'_'.join(n[:-1]),
'task_params':
pickle.loads(self.resource.Object(self.bucket_name, x['Key'].replace('task_log', 'task_params')).get()['Body'].read()),
'task_log':
pickle.loads(self.resource.Object(self.bucket_name, x['Key']).get()['Body'].read()),
'last_modified':
x['LastModified'],
'task_hash':
n[-1].split('.')[0]
})
params = {
'task_name': '_'.join(n[:-1]),
'task_params': pickle.loads(self.resource.Object(self.bucket_name, x['Key'].replace('task_log', 'task_params')).get()['Body'].read()),
'task_log': pickle.loads(self.resource.Object(self.bucket_name, x['Key']).get()['Body'].read()),
'last_modified': x['LastModified'],
'task_hash': n[-1].split('.')[0]
}
tasks_list.append(params)
if self.use_cache:
self.local_cache.dump(x, params)
except Exception:
continue

Expand Down
18 changes: 12 additions & 6 deletions thunderbolt/thunderbolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@


class Thunderbolt:
def __init__(self, workspace_directory: str = '', task_filters: Union[str, List[str]] = '', use_tqdm: bool = False, tmp_path: str = './tmp'):
def __init__(self,
workspace_directory: str = '',
task_filters: Union[str, List[str]] = '',
use_tqdm: bool = False,
tmp_path: str = './tmp',
use_cache: bool = True):
"""Thunderbolt init.
Set the path to the directory or S3.
Expand All @@ -21,20 +26,21 @@ def __init__(self, workspace_directory: str = '', task_filters: Union[str, List[
Load only tasks that contain the specified string here. We can also specify the number of copies.
use_tqdm: Flag for using tqdm. If False, the tqdm progress bar is not displayed (default=False).
tmp_path: Temporary directory when use external load function.
use_cache: Flag for using the local log cache (default=True).
"""
self.tmp_path = tmp_path
if not workspace_directory:
env = os.getenv('TASK_WORKSPACE_DIRECTORY')
workspace_directory = env if env else ''
self.client = self._get_client(workspace_directory, [task_filters] if type(task_filters) == str else task_filters, not use_tqdm)
self.client = self._get_client(workspace_directory, [task_filters] if type(task_filters) == str else task_filters, not use_tqdm, use_cache)
self.tasks = self._get_tasks_dic(tasks_list=self.client.get_tasks())

def _get_client(self, workspace_directory, filters, tqdm_disable):
def _get_client(self, workspace_directory, filters, tqdm_disable, use_cache):
if workspace_directory.startswith('s3://'):
return S3Client(workspace_directory, filters, tqdm_disable)
return S3Client(workspace_directory, filters, tqdm_disable, use_cache)
elif workspace_directory.startswith('gs://') or workspace_directory.startswith('gcs://'):
return GCSClient(workspace_directory, filters, tqdm_disable)
return LocalDirectoryClient(workspace_directory, filters, tqdm_disable)
return GCSClient(workspace_directory, filters, tqdm_disable, use_cache)
return LocalDirectoryClient(workspace_directory, filters, tqdm_disable, use_cache)

def _get_tasks_dic(self, tasks_list: List[Dict]) -> Dict[int, Dict]:
return {i: task for i, task in enumerate(sorted(tasks_list, key=lambda x: x['last_modified']))}
Expand Down

0 comments on commit 4faba6f

Please sign in to comment.