From ed36aabbf505085f2ba51543ed3ba1a764ff5bfb Mon Sep 17 00:00:00 2001 From: jacksykes17 Date: Wed, 4 May 2022 17:58:05 +0100 Subject: [PATCH] added AccessData epoch selection and saving functionality --- coexist/__version__.py | 2 +- coexist/access.py | 177 ++++++++++++++++++++++++++++++++++++++++- tests/cleanup.py | 1 + tests/test_access.py | 15 +++- 4 files changed, 192 insertions(+), 3 deletions(-) diff --git a/coexist/__version__.py b/coexist/__version__.py index ad11e19..05bb6fb 100644 --- a/coexist/__version__.py +++ b/coexist/__version__.py @@ -6,6 +6,6 @@ # Date : 20.01.2021 -VERSION = (0, 2, 2) +VERSION = (0, 2, 3) __version__ = '.'.join(map(str, VERSION)) diff --git a/coexist/access.py b/coexist/access.py index caf084c..a18d87a 100644 --- a/coexist/access.py +++ b/coexist/access.py @@ -14,6 +14,7 @@ import contextlib import subprocess import pickle +import shutil from datetime import datetime from concurrent.futures import ProcessPoolExecutor @@ -514,7 +515,12 @@ def update_paths(self, prefix): for attr in ["results", "outputs", "script", "setup", "epochs", "epochs_scaled", "history", "history_scaled"]: prev = getattr(self, attr) - setattr(self, attr, os.path.join(prefix, os.path.split(prev)[1])) + if isinstance(prev, str): + setattr( + self, + attr, + os.path.join(prefix, os.path.split(prev)[1]) + ) def save_history(self, setup, progress): @@ -613,6 +619,22 @@ def load_epochs(self, access): ) + def copy(self): + '''Create a copy of an `AccessPaths` object. + ''' + + return AccessPaths( + directory = self.directory, + results = self.results, + outputs = self.outputs, + script = self.script, + setup = self.setup, + epochs = self.epochs, + epochs_scaled = self.epochs_scaled, + history = self.history, + history_scaled = self.history_scaled, + ) + @autorepr(short = True) @@ -1752,6 +1774,159 @@ def legacy(self, access_path): self.results_scaled = results_scaled + def copy(self): + '''Return copy of `AccessData` object. + ''' + data = AccessData.empty() + data.paths = self.paths.copy() + data.parameters = self.parameters.copy() + data.parameters_scaled = self.parameters_scaled.copy() + data.scaling = self.scaling.copy() + data.population = self.population + data.num_epochs = self.num_epochs + data.target = self.target + data.seed = self.seed + data.epochs = self.epochs.copy() + data.epochs_scaled = self.epochs_scaled.copy() + data.results = self.results.copy() + data.results_scaled = self.results_scaled.copy() + return data + + + def save(self, dirname): + '''Save `AccessData` to a new directory at `dirname`. + ''' + # Copy previous folder to new location + shutil.copytree(self.paths.directory, dirname) + + self.paths.update_paths(dirname) + + # Save history + to_pad = len(self.results.columns) - len(self.parameters) - 1 + columns = self.parameters.index.to_list() + [ + f"error{i}" for i in range(to_pad) + ] + ["error"] + + np.savetxt( + self.paths.history, + self.results.to_numpy(), + header = " ".join(columns), + ) + + np.savetxt( + self.paths.history_scaled, + self.results_scaled.to_numpy(), + header = " ".join(columns), + ) + + # Save epochs + np.savetxt( + self.paths.epochs, + self.epochs.to_numpy(), + header = " ".join( + [f"{p}_mean" for p in self.parameters.index] + + [f"{p}_std" for p in self.parameters.index] + + ["overall_std"] + ), + ) + + np.savetxt( + self.paths.epochs_scaled, + self.epochs_scaled.to_numpy(), + header = " ".join( + [f"{p}_mean" for p in self.parameters.index] + + [f"{p}_std" for p in self.parameters.index] + + ["overall_std"] + ), + ) + + # Save setup + setup_dict = dict( + paths = self.paths.__dict__, + setup = dict( + parameters = self.parameters.to_dict(), + parameters_scaled = self.parameters_scaled.to_dict(), + scaling = self.scaling.tolist(), + population = self.population, + target = self.target, + seed = self.seed, + ), + ) + + with open(self.paths.setup, "w") as f: + toml.dump(setup_dict, f) + + + def __getitem__(self, index): + # Select AccessData epochs + if isinstance(index, int): + # Allow negative indices + while index < 0: + index += self.num_epochs + + if index >= self.num_epochs: + raise IndexError(textwrap.fill(( + f"The index=`{index}` is out of bounds for AccessData " + f"with {self.num_epochs} epochs." + ))) + + data = self.copy() + data.num_epochs = 1 + data.epochs = self.epochs.iloc[index:index + 1] + data.epochs_scaled = self.epochs_scaled.iloc[index:index + 1] + data.results = self.results.iloc[ + index * self.population:(index + 1) * self.population + ] + data.results_scaled = self.results_scaled.iloc[ + index * self.population:(index + 1) * self.population + ] + return data + + elif isinstance(index, slice): + if index.step is not None and index.step != 1: + raise ValueError(textwrap.fill(( + "Indexing with a `slice.step` is not yet available. " + "If this would be useful for you please get in touch." + ))) + + start = index.start if index.start is not None else 0 + stop = index.stop if index.stop is not None else self.num_epochs + + # Allow negative indices + while start < 0: + stop += self.num_epochs + + while stop < 0: + stop += self.num_epochs + + if stop > self.num_epochs or start >= self.num_epochs or \ + start >= stop: + raise IndexError(textwrap.fill(( + f"The slice=`{start}:{stop}` is out of bounds for " + f"AccessData with {self.num_epochs} epochs." + ))) + + data = self.copy() + data.num_epochs = stop - start + data.epochs = self.epochs.iloc[start:stop] + data.epochs_scaled = self.epochs_scaled.iloc[start:stop] + data.results = self.results.iloc[ + start * self.population:stop * self.population + ] + data.results_scaled = self.results_scaled.iloc[ + start * self.population:stop * self.population + ] + return data + + else: + raise TypeError(textwrap.fill(( + "Epoch selection via subscripting is only possible with " + "integer / slice indices (e.g. `access_data[5]` or " + "`access_data[2:5]`). Received index with type " + f"`{type(index)}`." + ))) + + def __repr__(self): name = self.__class__.__name__ underline = "-" * 80 diff --git a/tests/cleanup.py b/tests/cleanup.py index bafe75c..08193c0 100644 --- a/tests/cleanup.py +++ b/tests/cleanup.py @@ -21,6 +21,7 @@ def cleanup( "__pycache__", "access_seed123", "access_seed124", + "access_seed123_restore", "logs", "vtk_export", ], diff --git a/tests/test_access.py b/tests/test_access.py index bb96478..a327222 100644 --- a/tests/test_access.py +++ b/tests/test_access.py @@ -23,7 +23,20 @@ def test_access_data(): print(data) # Read data using direct class constructor - coexist.AccessData("access_data/access_seed123") + data = coexist.AccessData("access_data/access_seed123") + + # Epoch selection + assert len(data[0].results) == data.population + assert len(data[0:1].results) == data.population + data[-1] + data[:-1] + data[0:] + data[:] + + # Saving AccessData without last epoch + data[:-1].save("access_seed123_restore") + data2 = coexist.AccessData("access_seed123_restore") + assert data2.num_epochs == data.num_epochs - 1 def test_access():