diff --git a/data_analysis/cache_manager.py b/data_analysis/cache_manager.py
new file mode 100644
index 0000000..89e9c2f
--- /dev/null
+++ b/data_analysis/cache_manager.py
@@ -0,0 +1,47 @@
+from pathlib import Path
+
+import logging
+
+import requests
+from io import BytesIO
+
+
+class CacheManager:
+    """Read-through cache: payloads live under DATA_DIR/<subdir>/<filename>."""
+
+    DATA_DIR = Path(__file__).parent.parent / "data_output" / "scratch"
+
+    def __init__(self, verbose=False):
+        self.verbose = verbose
+
+    def log(self, *args):
+        # Forward to logging.info only when verbose logging was requested.
+        if self.verbose:
+            logging.info(*args)
+
+    def retrieve(self, subdir: str, filename: str, url: str) -> BytesIO:
+        """Retrieve data from the local filesystem cache or a remote URL.
+
+        Args:
+            subdir (str): subdirectory under DATA_DIR.
+            filename (str): filename in subdir.
+            url (str): fetch data from this URL if the file does not exist locally.
+
+        Returns:
+            BytesIO: buffer containing payload data.
+        """
+        cache_dir = self.DATA_DIR / subdir
+        if not cache_dir.exists():
+            # parents=True: DATA_DIR itself may not exist on a fresh checkout.
+            # exist_ok=True: tolerate a concurrent caller creating the same dir.
+            cache_dir.mkdir(parents=True, exist_ok=True)
+        filepath = cache_dir / filename
+        if filepath.exists():
+            self.log(f'Retrieved cached {url} from {subdir}/{filename}')
+            # read_bytes() opens and closes the handle; open().read() leaked it.
+            return BytesIO(filepath.read_bytes())
+        bytes_io = BytesIO(requests.get(url).content)
+        with filepath.open('wb') as ofh:
+            ofh.write(bytes_io.getvalue())
+        self.log(f'Stored cached {url} in {subdir}/{filename}')
+        return bytes_io
diff --git a/data_analysis/static_gtfs_analysis.py b/data_analysis/static_gtfs_analysis.py
index 65bba88..759966f 100644
--- a/data_analysis/static_gtfs_analysis.py
+++ b/data_analysis/static_gtfs_analysis.py
@@ -27,6 +27,7 @@
 from tqdm import tqdm
 
 from scrape_data.scrape_schedule_versions import create_schedule_list
+from data_analysis.cache_manager import CacheManager
 
 VERSION_ID = "20220718"
 BUCKET = os.getenv('BUCKET_PUBLIC', 'chn-ghost-buses-public')
@@ -360,11 +361,10 @@
     """
     logger.info('Downloading CTA data')
     CTA_GTFS = zipfile.ZipFile(
-        BytesIO(
-            requests.get(
-                f"https://transitfeeds.com/p/chicago-transit-authority"
-                f"/165/{version_id}/download"
-            ).content
+        CacheManager(verbose=True).retrieve(
+            "transitfeeds_schedules",
+            f"{version_id}.zip",
+            f"https://transitfeeds.com/p/chicago-transit-authority/165/{version_id}/download"
         )
     )
     logging.info('Download complete')
diff --git a/utils/s3_csv_reader.py b/utils/s3_csv_reader.py
index ae1d63c..704feef 100644
--- a/utils/s3_csv_reader.py
+++ b/utils/s3_csv_reader.py
@@ -1,6 +1,9 @@
 import pandas as pd
 from pathlib import Path
 import data_analysis.compare_scheduled_and_rt as csrt
+from data_analysis.cache_manager import CacheManager
+
+CACHE_MANAGER = CacheManager(verbose=False)
 
 def read_csv(filename: str | Path) -> pd.DataFrame:
     """Read pandas csv from S3
@@ -14,9 +17,14 @@ def read_csv(filename: str | Path) -> pd.DataFrame:
 
     if isinstance(filename, str):
         filename = Path(filename)
     s3_filename = '/'.join(filename.parts[-2:])
+    cache_filename = f'{filename.stem}.csv'
     df = pd.read_csv(
-        f'https://{csrt.BUCKET_PUBLIC}.s3.us-east-2.amazonaws.com/{s3_filename}',
-        low_memory=False
+        CACHE_MANAGER.retrieve(
+            's3csv',
+            cache_filename,
+            f'https://{csrt.BUCKET_PUBLIC}.s3.us-east-2.amazonaws.com/{s3_filename}',
+        ),
+        low_memory=False
     )
     return df
\ No newline at end of file