Merge pull request #6 from m3dev/use_abs
Using abspath, and add docstring, add tqdm flag
vaaaaanquish authored Oct 25, 2019
2 parents 8723c18 + b552f6b commit 4d5b25f
Showing 2 changed files with 72 additions and 22 deletions.
14 changes: 10 additions & 4 deletions test/test_thunderbolt.py
@@ -14,11 +14,10 @@

class SimpleLocalTest(unittest.TestCase):
def setUp(self):
self.here = path.abspath(path.dirname(__file__))
self.tb = thunderbolt.Thunderbolt(path.join(self.here, 'test_case'))
self.tb = thunderbolt.Thunderbolt(self.get_test_case_path())

def test_init(self):
self.assertEqual(self.tb.file_path, path.join(self.here, 'test_case'))
self.assertEqual(self.tb.workspace_directory, self.get_test_case_path())
self.assertEqual(self.tb.task_filters, [''])
self.assertEqual(self.tb.bucket_name, None)
self.assertEqual(self.tb.prefix, None)
@@ -30,6 +29,13 @@ def test_init(self):
self.assertListEqual(task['task_log']['file_path'], ['./test_case/sample/test_case_c5b4a28a606228ac23477557c774a3a0.pkl'])
self.assertDictEqual(task['task_params'], {'param': 'sample', 'number': '1'})

def get_test_case_path(self, file_name: str = ''):
p = path.abspath(path.join(path.dirname(__file__), 'test_case'))
if file_name:
return path.join(p, file_name)
print(p)
return p

def test_get_task_df(self):
df = self.tb.get_task_df(all_data=True)
df = df.drop('last_modified', axis=1)
@@ -49,6 +55,6 @@ def test_get_task_df(self):

def test_load(self):
x = self.tb.load(0)
with open(path.join(self.here, 'test_case/sample/test_case_c5b4a28a606228ac23477557c774a3a0.pkl'), 'rb') as f:
with open(self.get_test_case_path('sample/test_case_c5b4a28a606228ac23477557c774a3a0.pkl'), 'rb') as f:
target = pickle.load(f)
self.assertListEqual(x, [target])
80 changes: 62 additions & 18 deletions thunderbolt/thunderbolt.py
@@ -2,6 +2,7 @@
import os
from pathlib import Path
import pickle
from typing import Union, List

import boto3
from boto3 import Session
@@ -11,21 +12,35 @@


class Thunderbolt():
def __init__(self, file_path: str, task_filters=''):
def __init__(self, workspace_directory: str = '', task_filters: Union[str, List[str]] = '', use_tqdm=False):
"""Thunderbolt init.
Set the path to the directory or S3.
Args:
workspace_directory: Gokart's TASK_WORKSPACE_DIRECTORY. If empty, $TASK_WORKSPACE_DIRECTORY from the environment is used.
task_filters: Filter for task names.
Only tasks whose names contain the specified string are loaded. A list of strings can also be given.
use_tqdm: Flag for using tqdm. If False, the tqdm progress bar is not displayed (default=False).
"""
self.tqdm_disable = not use_tqdm
self.s3client = None
self.file_path = file_path
if not workspace_directory:
env = os.getenv('TASK_WORKSPACE_DIRECTORY')
workspace_directory = env if env else ''
self.workspace_directory = workspace_directory if workspace_directory.startswith('s3://') else os.path.abspath(workspace_directory)
self.task_filters = [task_filters] if type(task_filters) == str else task_filters
self.bucket_name = file_path.replace('s3://', '').split('/')[0] if file_path.startswith('s3://') else None
self.prefix = '/'.join(file_path.replace('s3://', '').split('/')[1:]) if file_path.startswith('s3://') else None
self.resource = boto3.resource('s3') if file_path.startswith('s3://') else None
self.s3client = Session().client('s3') if file_path.startswith('s3://') else None
self.tasks = self._get_tasks_from_s3() if file_path.startswith('s3://') else self._get_tasks()
self.bucket_name = workspace_directory.replace('s3://', '').split('/')[0] if workspace_directory.startswith('s3://') else None
self.prefix = '/'.join(workspace_directory.replace('s3://', '').split('/')[1:]) if workspace_directory.startswith('s3://') else None
self.resource = boto3.resource('s3') if workspace_directory.startswith('s3://') else None
self.s3client = Session().client('s3') if workspace_directory.startswith('s3://') else None
self.tasks = self._get_tasks_from_s3() if workspace_directory.startswith('s3://') else self._get_tasks()

def _get_tasks(self):
"""Get task parameters."""
files = {str(path) for path in Path(os.path.join(self.file_path, 'log/task_log')).rglob('*')}
"""Load all task_log from workspace_directory."""
files = {str(path) for path in Path(os.path.join(self.workspace_directory, 'log/task_log')).rglob('*')}
tasks = {}
for i, x in enumerate(tqdm(files)):
for i, x in enumerate(tqdm(files, disable=self.tqdm_disable)):
n = x.split('/')[-1]
if self.task_filters and not [x for x in self.task_filters if x in n]:
continue
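
For context, a minimal usage sketch of the constructor documented above (the workspace path, environment value, and task name below are hypothetical, not part of this diff):

import os
import thunderbolt

# With this change, an empty workspace_directory falls back to $TASK_WORKSPACE_DIRECTORY,
# and local paths are resolved with os.path.abspath.
os.environ['TASK_WORKSPACE_DIRECTORY'] = './resources'  # hypothetical workspace
tb = thunderbolt.Thunderbolt()  # picks up the environment variable

# An explicit path, a list of task-name filters, and the new use_tqdm flag.
tb = thunderbolt.Thunderbolt('./resources', task_filters=['SampleTask'], use_tqdm=True)
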
@@ -45,10 +60,10 @@ def _get_tasks(self):
return tasks

def _get_tasks_from_s3(self):
"""Get task parameters from S3."""
"""Load all task_log from S3"""
files = self._get_s3_keys([], '')
tasks = {}
for i, x in enumerate(tqdm(files)):
for i, x in enumerate(tqdm(files, disable=self.tqdm_disable)):
n = x['Key'].split('/')[-1]
if self.task_filters and not [x for x in self.task_filters if x in n]:
continue
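
For an S3 workspace, the constructor splits the s3:// URL into bucket and prefix before the task logs are fetched here; a standalone illustration (the URL is a hypothetical example):

workspace_directory = 's3://my-bucket/gokart/workspace'
bucket_name = workspace_directory.replace('s3://', '').split('/')[0]        # 'my-bucket'
prefix = '/'.join(workspace_directory.replace('s3://', '').split('/')[1:])  # 'gokart/workspace'
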
@@ -62,17 +77,35 @@ def _get_tasks_from_s3(self):
}
return tasks

def _get_s3_keys(self, keys: list = [], marker: str = '') -> list:
"""Recursively get Key from S3."""
def _get_s3_keys(self, keys=[], marker=''):
"""Recursively get Key from S3.
Uses the S3 client API from boto3.
Reference: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html
Args:
keys: The object keys collected so far. Grows with each recursion.
marker: S3 pagination marker. Recursion ends when the response is no longer truncated.
Returns:
Object keys from S3 as dicts of 'Key' and 'LastModified', for example: [{'Key': 'hoge', ...}, ...]
"""
response = self.s3client.list_objects(Bucket=self.bucket_name, Prefix=os.path.join(self.prefix, 'log/task_log'), Marker=marker)
if 'Contents' in response:
keys.extend([{'Key': content['Key'], 'LastModified': content['LastModified']} for content in response['Contents']])
if 'IsTruncated' in response:
if 'Contents' in response and 'IsTruncated' in response:
return self._get_s3_keys(keys=keys, marker=keys[-1]['Key'])
return keys
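
For reference, an iterative sketch of the Marker-based pagination that _get_s3_keys performs recursively (an illustration under the same boto3 API, separate from the diff above; it tests the boolean value of IsTruncated rather than the key's presence):

import os
from boto3 import Session

def list_task_log_keys(bucket_name: str, prefix: str) -> list:
    # Collect 'Key'/'LastModified' pairs under <prefix>/log/task_log, page by page.
    client = Session().client('s3')
    keys, marker = [], ''
    while True:
        response = client.list_objects(Bucket=bucket_name, Prefix=os.path.join(prefix, 'log/task_log'), Marker=marker)
        contents = response.get('Contents', [])
        keys.extend({'Key': c['Key'], 'LastModified': c['LastModified']} for c in contents)
        if not contents or not response.get('IsTruncated'):
            return keys
        marker = keys[-1]['Key']  # the last key returned becomes the next page's marker
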

def get_task_df(self, all_data: bool = False) -> pd.DataFrame:
"""Get task's pandas data frame."""
"""Get task's pandas DataFrame.
Args:
all_data: If True, add `task unique hash` and `task log data` to DataFrame.
Returns:
A pandas.DataFrame with information on all gokart tasks.
"""
df = pd.DataFrame([{
'task_id': k,
'task_name': v['task_name'],
@@ -86,5 +119,16 @@ def get_task_df(self, all_data: bool = False) -> pd.DataFrame:
return df[['task_id', 'task_name', 'last_modified', 'task_params']]

def load(self, task_id: int) -> list:
"""Load File."""
return [gokart.target.make_target(file_path=x).load() for x in self.tasks[task_id]['task_log']['file_path']]
"""Load File using gokart.load.
Args:
task_id: The ID assigned by Thunderbolt; the corresponding data is read into memory.
Please check `task_id` using Thunderbolt.get_task_df.
Returns:
A list, because gokart may split a single dump into multiple files.
"""
return [
gokart.target.make_target(file_path=os.path.join(os.path.dirname(self.workspace_directory), x)).load()
for x in self.tasks[task_id]['task_log']['file_path']
]
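
Taken together, a short usage sketch of get_task_df and load as documented above (the workspace path and task name are hypothetical):

import thunderbolt

tb = thunderbolt.Thunderbolt('./resources', use_tqdm=True)
df = tb.get_task_df(all_data=True)  # all_data=True adds the task hash and task log columns
task_id = int(df.loc[df['task_name'] == 'SampleTask', 'task_id'].iloc[0])
data = tb.load(task_id)  # a list; gokart may split a single dump into several files
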
