Skip to content

Commit

Permalink
added AccessData epoch selection and saving functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
jacksykes17 committed May 4, 2022
1 parent 9604b84 commit ed36aab
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 3 deletions.
2 changes: 1 addition & 1 deletion coexist/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
# Date : 20.01.2021


VERSION = (0, 2, 2)
VERSION = (0, 2, 3)

__version__ = '.'.join(map(str, VERSION))
177 changes: 176 additions & 1 deletion coexist/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import contextlib
import subprocess
import pickle
import shutil
from datetime import datetime
from concurrent.futures import ProcessPoolExecutor

Expand Down Expand Up @@ -514,7 +515,12 @@ def update_paths(self, prefix):
for attr in ["results", "outputs", "script", "setup", "epochs",
"epochs_scaled", "history", "history_scaled"]:
prev = getattr(self, attr)
setattr(self, attr, os.path.join(prefix, os.path.split(prev)[1]))
if isinstance(prev, str):
setattr(
self,
attr,
os.path.join(prefix, os.path.split(prev)[1])
)


def save_history(self, setup, progress):
Expand Down Expand Up @@ -613,6 +619,22 @@ def load_epochs(self, access):
)


def copy(self):
'''Create a copy of an `AccessPaths` object.
'''

return AccessPaths(
directory = self.directory,
results = self.results,
outputs = self.outputs,
script = self.script,
setup = self.setup,
epochs = self.epochs,
epochs_scaled = self.epochs_scaled,
history = self.history,
history_scaled = self.history_scaled,
)



@autorepr(short = True)
Expand Down Expand Up @@ -1752,6 +1774,159 @@ def legacy(self, access_path):
self.results_scaled = results_scaled


def copy(self):
'''Return copy of `AccessData` object.
'''
data = AccessData.empty()
data.paths = self.paths.copy()
data.parameters = self.parameters.copy()
data.parameters_scaled = self.parameters_scaled.copy()
data.scaling = self.scaling.copy()
data.population = self.population
data.num_epochs = self.num_epochs
data.target = self.target
data.seed = self.seed
data.epochs = self.epochs.copy()
data.epochs_scaled = self.epochs_scaled.copy()
data.results = self.results.copy()
data.results_scaled = self.results_scaled.copy()
return data


def save(self, dirname):
'''Save `AccessData` to a new directory at `dirname`.
'''
# Copy previous folder to new location
shutil.copytree(self.paths.directory, dirname)

self.paths.update_paths(dirname)

# Save history
to_pad = len(self.results.columns) - len(self.parameters) - 1
columns = self.parameters.index.to_list() + [
f"error{i}" for i in range(to_pad)
] + ["error"]

np.savetxt(
self.paths.history,
self.results.to_numpy(),
header = " ".join(columns),
)

np.savetxt(
self.paths.history_scaled,
self.results_scaled.to_numpy(),
header = " ".join(columns),
)

# Save epochs
np.savetxt(
self.paths.epochs,
self.epochs.to_numpy(),
header = " ".join(
[f"{p}_mean" for p in self.parameters.index] +
[f"{p}_std" for p in self.parameters.index] +
["overall_std"]
),
)

np.savetxt(
self.paths.epochs_scaled,
self.epochs_scaled.to_numpy(),
header = " ".join(
[f"{p}_mean" for p in self.parameters.index] +
[f"{p}_std" for p in self.parameters.index] +
["overall_std"]
),
)

# Save setup
setup_dict = dict(
paths = self.paths.__dict__,
setup = dict(
parameters = self.parameters.to_dict(),
parameters_scaled = self.parameters_scaled.to_dict(),
scaling = self.scaling.tolist(),
population = self.population,
target = self.target,
seed = self.seed,
),
)

with open(self.paths.setup, "w") as f:
toml.dump(setup_dict, f)


def __getitem__(self, index):
# Select AccessData epochs
if isinstance(index, int):
# Allow negative indices
while index < 0:
index += self.num_epochs

if index >= self.num_epochs:
raise IndexError(textwrap.fill((
f"The index=`{index}` is out of bounds for AccessData "
f"with {self.num_epochs} epochs."
)))

data = self.copy()
data.num_epochs = 1
data.epochs = self.epochs.iloc[index:index + 1]
data.epochs_scaled = self.epochs_scaled.iloc[index:index + 1]
data.results = self.results.iloc[
index * self.population:(index + 1) * self.population
]
data.results_scaled = self.results_scaled.iloc[
index * self.population:(index + 1) * self.population
]
return data

elif isinstance(index, slice):
if index.step is not None and index.step != 1:
raise ValueError(textwrap.fill((
"Indexing with a `slice.step` is not yet available. "
"If this would be useful for you please get in touch."
)))

start = index.start if index.start is not None else 0
stop = index.stop if index.stop is not None else self.num_epochs

# Allow negative indices
while start < 0:
stop += self.num_epochs

while stop < 0:
stop += self.num_epochs

if stop > self.num_epochs or start >= self.num_epochs or \
start >= stop:
raise IndexError(textwrap.fill((
f"The slice=`{start}:{stop}` is out of bounds for "
f"AccessData with {self.num_epochs} epochs."
)))

data = self.copy()
data.num_epochs = stop - start
data.epochs = self.epochs.iloc[start:stop]
data.epochs_scaled = self.epochs_scaled.iloc[start:stop]
data.results = self.results.iloc[
start * self.population:stop * self.population
]
data.results_scaled = self.results_scaled.iloc[
start * self.population:stop * self.population
]
return data

else:
raise TypeError(textwrap.fill((
"Epoch selection via subscripting is only possible with "
"integer / slice indices (e.g. `access_data[5]` or "
"`access_data[2:5]`). Received index with type "
f"`{type(index)}`."
)))


def __repr__(self):
name = self.__class__.__name__
underline = "-" * 80
Expand Down
1 change: 1 addition & 0 deletions tests/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def cleanup(
"__pycache__",
"access_seed123",
"access_seed124",
"access_seed123_restore",
"logs",
"vtk_export",
],
Expand Down
15 changes: 14 additions & 1 deletion tests/test_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,20 @@ def test_access_data():
print(data)

# Read data using direct class constructor
coexist.AccessData("access_data/access_seed123")
data = coexist.AccessData("access_data/access_seed123")

# Epoch selection
assert len(data[0].results) == data.population
assert len(data[0:1].results) == data.population
data[-1]
data[:-1]
data[0:]
data[:]

# Saving AccessData without last epoch
data[:-1].save("access_seed123_restore")
data2 = coexist.AccessData("access_seed123_restore")
assert data2.num_epochs == data.num_epochs - 1


def test_access():
Expand Down

0 comments on commit ed36aab

Please sign in to comment.