Skip to content

Commit

Permalink
Adding save/load methods (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
DiogenesAnalytics committed Jan 8, 2024
1 parent 0be245c commit 4816a82
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 20 deletions.
3 changes: 3 additions & 0 deletions notebooks/demo/mnist_dataset.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@
"# get instance\n",
"mnist_recon_error = ReconstructionError(autoencoder, test_ds, axis=(1, 2))\n",
"\n",
"# calculate recon error\n",
"mnist_recon_error.calculate_error()\n",
"\n",
"# view distribution\n",
"mnist_recon_error.histogram(\"MNIST Autoencoder: Reconstruction Error Distribution\")"
]
Expand Down
3 changes: 3 additions & 0 deletions notebooks/demo/tf_flowers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@
"# get instance\n",
"tfflower_recon_error = ReconstructionError(autoencoder, x_val)\n",
"\n",
"# calculate recon error\n",
"tfflower_recon_error.calculate_error()\n",
"\n",
"# view distribution\n",
"tfflower_recon_error.histogram(\"tf_flowers Autoencoder: Reconstruction Error Distribution\", bins=[100])"
]
Expand Down
87 changes: 67 additions & 20 deletions src/autoencoder/data/anomaly.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tools for evaluating an autoencoder's performance on a dataset."""
from dataclasses import InitVar
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Generator
from typing import List
Expand All @@ -24,26 +25,12 @@ class ReconstructionError:
ae: Union[tf.keras.Model, BaseAutoencoder]
dataset: tf.data.Dataset
axis: Tuple[int, ...] = (1, 2, 3)
file_paths: InitVar[Optional[List[str]]] = None

def __post_init__(self, file_paths: Optional[List[str]]) -> None:
"""Calculate and store errors, and threshold."""
# check file paths
if file_paths is None:
# get file paths from dataset
file_paths = self.get_file_paths(self.dataset)

# get the reconstrution error
self.errors = pd.DataFrame(
data=self.gen_reconstruction_error(),
columns=["reconstruction_error"],
index=file_paths,
)

# store 95th threshold
self.threshold = self.calc_95th_threshold(
self.errors["reconstruction_error"].values.tolist()
)
def _check_data_attrs_set(self) -> None:
    """Make sure errors and threshold attributes have been set.

    Raises:
        AttributeError: If ``calculate_error()`` has not yet populated the
            ``errors`` and ``threshold`` attributes.
    """
    # NOTE: an `assert` here would be silently stripped under `python -O`,
    # so raise explicitly to guarantee the guard always runs.
    if not (hasattr(self, "errors") and hasattr(self, "threshold")):
        raise AttributeError(
            "errors/threshold attributes should be set before running this method."
        )

@staticmethod
def get_file_paths(dataset: tf.data.Dataset) -> Optional[List[str]]:
Expand Down Expand Up @@ -77,6 +64,25 @@ def calc_max_threshold(
"""Calculate threshold for anomalous data using maximum value."""
return np.max(errors)

def calculate_error(self, file_paths: Optional[List[str]] = None) -> None:
    """Calculate and store errors, and threshold.

    Args:
        file_paths: Optional row labels for the error dataframe; when
            omitted, paths are extracted from the dataset itself.
    """
    # default to paths pulled from the wrapped dataset
    index_labels = (
        self.get_file_paths(self.dataset) if file_paths is None else file_paths
    )

    # materialize the per-sample reconstruction errors
    error_values = list(self.gen_reconstruction_error())

    # store errors as a single-column dataframe keyed by file path
    self.errors = pd.DataFrame(
        data=error_values,
        columns=["reconstruction_error"],
        index=index_labels,
    )

    # record the 95th-percentile cutoff (see calc_95th_threshold)
    error_list = self.errors["reconstruction_error"].values.tolist()
    self.threshold = self.calc_95th_threshold(error_list)

def gen_batch_predictions(
self, axis: Optional[Tuple[int, ...]] = None
) -> Generator[Tuple[tf.Tensor, tf.Tensor, tf.Tensor], None, None]:
Expand Down Expand Up @@ -112,6 +118,10 @@ def _plot_error_distribution(
alphas: Optional[List[float]],
bins: Optional[List[int]],
) -> None:
"""Handles all histogram plots."""
# make sure attrs set
self._check_data_attrs_set()

# setup list of data and labels
error_data = [self.errors["reconstruction_error"].values.tolist()]

Expand Down Expand Up @@ -194,3 +204,40 @@ def probability_distribution(
density=True,
alphas=alphas,
)

def save(self, output_path: Union[str, Path]) -> None:
    """Save object instance data to path.

    Writes ``errors.csv`` (reconstruction errors indexed by file path) and
    ``threshold.json`` into a newly created directory.

    Args:
        output_path: Directory to create; must not already exist.

    Raises:
        AttributeError: If ``calculate_error()`` has not been run yet.
        FileExistsError: If ``output_path`` already exists.
    """
    # nothing to persist until calculate_error() has populated the attrs
    self._check_data_attrs_set()

    # create path obj
    output_path = Path(output_path)

    # refuse to touch an existing directory; `assert` would vanish under
    # `python -O`, so raise explicitly instead
    if output_path.exists():
        raise FileExistsError(
            "Save method expects a new directory not an existing one."
        )

    # now create it
    output_path.mkdir()

    # save errors dataframe (index holds the file paths)
    self.errors.to_csv(output_path / "errors.csv", columns=["reconstruction_error"])

    # open new JSON file
    with open(output_path / "threshold.json", "w") as outfile:
        # save threshold with pretty print
        json.dump({"threshold": self.threshold}, outfile, indent=4)

def load(self, input_path: Union[str, Path]) -> None:
    """Load previously saved object instance data.

    Args:
        input_path: Directory previously written by ``save()``.

    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
    """
    # create path obj
    input_path = Path(input_path)

    # `assert` would be stripped under `python -O`; raise explicitly
    if not input_path.exists():
        raise FileNotFoundError("Load method cannot find directory path.")

    # open threshold file
    with open(input_path / "threshold.json") as infile:
        # update threshold
        self.threshold = json.load(infile)["threshold"]

    # restore the errors dataframe; index_col=0 re-uses the first CSV column
    # (the file paths written by save) as the index, so a save/load round
    # trip reproduces the original dataframe instead of adding a RangeIndex
    # and a spurious unnamed column
    self.errors = pd.read_csv(input_path / "errors.csv", index_col=0)

0 comments on commit 4816a82

Please sign in to comment.