From 4816a825a5660452c6e4e88081ea5bc8f5411366 Mon Sep 17 00:00:00 2001 From: Diogenes Analytics Date: Mon, 8 Jan 2024 17:09:57 -0500 Subject: [PATCH] Adding save/load methods (#25) --- notebooks/demo/mnist_dataset.ipynb | 3 ++ notebooks/demo/tf_flowers.ipynb | 3 ++ src/autoencoder/data/anomaly.py | 87 +++++++++++++++++++++++------- 3 files changed, 73 insertions(+), 20 deletions(-) diff --git a/notebooks/demo/mnist_dataset.ipynb b/notebooks/demo/mnist_dataset.ipynb index ad3bd92..e267a32 100644 --- a/notebooks/demo/mnist_dataset.ipynb +++ b/notebooks/demo/mnist_dataset.ipynb @@ -195,6 +195,9 @@ "# get instance\n", "mnist_recon_error = ReconstructionError(autoencoder, test_ds, axis=(1, 2))\n", "\n", + "# calculate recon error\n", + "mnist_recon_error.calculate_error()\n", + "\n", "# view distribution\n", "mnist_recon_error.histogram(\"MNIST Autoencoder: Reconstruction Error Distribution\")" ] diff --git a/notebooks/demo/tf_flowers.ipynb b/notebooks/demo/tf_flowers.ipynb index 33d0db7..d18360c 100644 --- a/notebooks/demo/tf_flowers.ipynb +++ b/notebooks/demo/tf_flowers.ipynb @@ -205,6 +205,9 @@ "# get instance\n", "tfflower_recon_error = ReconstructionError(autoencoder, x_val)\n", "\n", + "# calculate recon error\n", + "tfflower_recon_error.calculate_error()\n", + "\n", "# view distribution\n", "tfflower_recon_error.histogram(\"tf_flowers Autoencoder: Reconstruction Error Distribution\", bins=[100])" ] diff --git a/src/autoencoder/data/anomaly.py b/src/autoencoder/data/anomaly.py index 2a137f1..501cafd 100644 --- a/src/autoencoder/data/anomaly.py +++ b/src/autoencoder/data/anomaly.py @@ -1,6 +1,7 @@ """Tools for evaluating an autoencoder's perfomance on a dataset.""" -from dataclasses import InitVar +import json from dataclasses import dataclass +from pathlib import Path from typing import Any from typing import Generator from typing import List @@ -24,26 +25,12 @@ class ReconstructionError: ae: Union[tf.keras.Model, BaseAutoencoder] dataset: tf.data.Dataset axis: Tuple[int, ...] = (1, 2, 3) - file_paths: InitVar[Optional[List[str]]] = None - def __post_init__(self, file_paths: Optional[List[str]]) -> None: - """Calculate and store errors, and threshold.""" - # check file paths - if file_paths is None: - # get file paths from dataset - file_paths = self.get_file_paths(self.dataset) - - # get the reconstrution error - self.errors = pd.DataFrame( - data=self.gen_reconstruction_error(), - columns=["reconstruction_error"], - index=file_paths, - ) - - # store 95th threshold - self.threshold = self.calc_95th_threshold( - self.errors["reconstruction_error"].values.tolist() - ) + def _check_data_attrs_set(self) -> None: + """Make sure errors and threshold attributes have been set.""" + assert all( + [hasattr(self, "errors"), hasattr(self, "threshold")] + ), "errors/threshold attributes should be set before running this method." @staticmethod def get_file_paths(dataset: tf.data.Dataset) -> Optional[List[str]]: @@ -77,6 +64,25 @@ def calc_max_threshold( """Calculate threshold for anomalous data using maximum value.""" return np.max(errors) + def calculate_error(self, file_paths: Optional[List[str]] = None) -> None: + """Calculate and store errors, and threshold.""" + # check file paths + if file_paths is None: + # get file paths from dataset + file_paths = self.get_file_paths(self.dataset) + + # get the reconstrution error + self.errors = pd.DataFrame( + data=self.gen_reconstruction_error(), + columns=["reconstruction_error"], + index=file_paths, + ) + + # store 95th threshold + self.threshold = self.calc_95th_threshold( + self.errors["reconstruction_error"].values.tolist() + ) + def gen_batch_predictions( self, axis: Optional[Tuple[int, ...]] = None ) -> Generator[Tuple[tf.Tensor, tf.Tensor, tf.Tensor], None, None]: @@ -112,6 +118,10 @@ def _plot_error_distribution( alphas: Optional[List[float]], bins: Optional[List[int]], ) -> None: + """Handles all histogram plots.""" + # make sure attrs set + self._check_data_attrs_set() + # setup list of data and labels error_data = [self.errors["reconstruction_error"].values.tolist()] @@ -194,3 +204,40 @@ def probability_distribution( density=True, alphas=alphas, ) + + def save(self, output_path: Union[str, Path]) -> None: + """Save object instance data to path.""" + # creat path obj + output_path = Path(output_path) + + # make sure path doesn't exist + assert ( + not output_path.exists() + ), "Save method expects a new directory not an existing one." + + # now create it + output_path.mkdir() + + # save errors dataframe + self.errors.to_csv(output_path / "errors.csv", columns=["reconstruction_error"]) + + # open new JSON file + with open(output_path / "threshold.json", "w") as outfile: + # save threshold with pretty print + json.dump({"threshold": self.threshold}, outfile, indent=4) + + def load(self, input_path: Union[str, Path]) -> None: + """Load previously saved object instance data.""" + # create path obj + input_path = Path(input_path) + + # make sure path doesn't exist + assert input_path.exists(), "Load method cannot find directory path." + + # open threshold file + with open(input_path / "threshold.json") as infile: + # update threshold + self.threshold = json.load(infile)["threshold"] + + # now get errors attribute + self.errors = pd.read_csv(input_path / "errors.csv")