From 1b1c43bc39dca33096d8993129668e9f40a4a335 Mon Sep 17 00:00:00 2001 From: Diane Adjavon Date: Wed, 14 Feb 2024 11:45:38 -0500 Subject: [PATCH] fix: Add context manager to file read/writes Related to #82 --- dacapo/store/file_config_store.py | 9 ++++++--- dacapo/store/file_stats_store.py | 12 ++++++++---- dacapo/store/local_weights_store.py | 7 ++----- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/dacapo/store/file_config_store.py b/dacapo/store/file_config_store.py index 725c4a46f..20257ecb2 100644 --- a/dacapo/store/file_config_store.py +++ b/dacapo/store/file_config_store.py @@ -98,10 +98,12 @@ def __save_insert(self, collection, data, ignore=None): file_store = collection / name if not file_store.exists(): - pickle.dump(dict(data), file_store.open("wb")) + with file_store.open("wb") as fd: + pickle.dump(dict(data), fd) else: - existing = pickle.load(file_store.open("rb")) + with file_store.open("rb") as fd: + existing = pickle.load(fd) if not self.__same_doc(existing, data, ignore): raise DuplicateNameError( @@ -113,7 +115,8 @@ def __save_insert(self, collection, data, ignore=None): def __load(self, collection, name): file_store = collection / name if file_store.exists(): - return pickle.load(file_store.open("rb")) + with file_store.open("rb") as fd: + return pickle.load(fd) else: raise ValueError(f"No config with name: {name} in collection: {collection}") diff --git a/dacapo/store/file_stats_store.py b/dacapo/store/file_stats_store.py index 8a299bcf7..bb8394218 100644 --- a/dacapo/store/file_stats_store.py +++ b/dacapo/store/file_stats_store.py @@ -88,12 +88,14 @@ def __store_training_stats(self, stats, begin, end, run_name): if docs: file_store = self.training_stats / run_name - pickle.dump(docs, file_store.open("wb")) + with file_store.open("wb") as fd: + pickle.dump(docs, fd) def __read_training_stats(self, run_name): file_store = self.training_stats / run_name if file_store.exists(): - docs = pickle.load(file_store.open("rb")) + with file_store.open("rb") as fd: + docs = pickle.load(fd) else: docs = [] stats = TrainingStats(converter.structure(docs, List[TrainingIterationStats])) @@ -117,12 +119,14 @@ def __store_validation_iteration_scores( if docs: file_store = self.validation_scores / run_name - pickle.dump(docs, file_store.open("wb")) + with file_store.open("wb") as fd: + pickle.dump(docs, fd) def __read_validation_iteration_scores(self, run_name): file_store = self.validation_scores / run_name if file_store.exists(): - docs = pickle.load(file_store.open("rb")) + with file_store.open("rb") as fd: + docs = pickle.load(fd) else: docs = [] scores = converter.structure(docs, List[ValidationIterationScores]) diff --git a/dacapo/store/local_weights_store.py b/dacapo/store/local_weights_store.py index c5f0ba5ff..d7a830528 100644 --- a/dacapo/store/local_weights_store.py +++ b/dacapo/store/local_weights_store.py @@ -103,11 +103,8 @@ def store_best(self, run: str, iteration: int, dataset: str, criterion: str): def retrieve_best(self, run: str, dataset: str, criterion: str) -> int: logger.info("Retrieving weights for run %s, criterion %s", run, criterion) - weights_info = json.loads( - (self.__get_weights_dir(run) / criterion / f"{dataset}.json") - .open("r") - .read() - ) + with (self.__get_weights_dir__(run) / criterion / f"{dataset}.json").open("r") as fd: + weights_info = json.load(fd) return weights_info["iteration"]