diff --git a/brain_pipe/save/default.py b/brain_pipe/save/default.py
index 0580e2c..859ecd0 100644
--- a/brain_pipe/save/default.py
+++ b/brain_pipe/save/default.py
@@ -1,10 +1,11 @@
"""Default save class."""
import abc
+import copy
import gc
import json
import logging
import os
-from typing import Any, Dict
+from typing import Any, Dict, Optional, Union, Callable, Mapping
import numpy as np
@@ -13,8 +14,19 @@
pickle_load_wrapper,
)
from brain_pipe.save.base import Save
+from brain_pipe.utils.list import wrap_in_list
from brain_pipe.utils.multiprocess import MultiprocessingSingleton
+# Shorthand interfaces.
+CheckInterface = Callable[[Dict[str, Any], str, Dict[str, Any]], Union[str, bool]]
+FilenameFnInterface = Callable[[Dict[str, Any], Optional[str], Optional[str]], str]
+SaveFnInterface = Union[
+ Callable[[Any, str], None], Mapping[str, Callable[[Any, str], None]], None
+]
+ReloadFnInterface = Union[
+ Callable[[str], Any], Mapping[str, Callable[[str], Any]], None
+]
+
def default_metadata_key_fn(data_dict: Dict[str, Any]) -> str:
"""Generate a key for the metadata.
@@ -41,54 +53,125 @@ def default_metadata_key_fn(data_dict: Dict[str, Any]) -> str:
raise ValueError("No data_path or stimulus_path in data_dict.")
-def default_filename_fn(data_dict, feature_name, set_name=None, separator="_-_"):
- """Generate a filename for the data_dict.
+class DefaultFilenameFn(FilenameFnInterface):
+ """Default filename function to create paths to save data."""
- Parameters
- ----------
- data_dict: Dict[str, Any]
- The data dict containing the data to save.
- feature_name: Optional[str]
- The name of the feature.
- set_name: Optional[str]
- The name of the set. If no set name is given, the set name is not
- included in the filename.
- separator: str
- The separator to use between the different parts of the filename.
+ SPLIT_CHAR = "/"
- Returns
- -------
- str
- The filename.
- """
- parts = []
- if "data_path" in data_dict:
- parts += [os.path.basename(data_dict["data_path"]).split(".")[0]]
+ def __init__(
+ self,
+ path_keys=("data_path", "stimulus_path"),
+ other_keys=("event_info/snr",),
+ separator="_-_",
+ data_dict_extension=".data_dict",
+ feature_extension=".npy",
+ ):
+ """Create a new DefaultFilenameFn instance.
- if "stimulus_path" in data_dict:
- parts += [os.path.basename(data_dict["stimulus_path"]).split(".")[0]]
+ Parameters
+ ----------
+ path_keys: Sequence[str]
+ The keys of the paths to include in the filename.
+ other_keys: Sequence[str]
+ The keys of other data to include in the filename.
+ separator: str
+ The separator to use between parts of the filename.
+ data_dict_extension: str
+ The extension to use when saving the entire data_dict.
+ feature_extension: str
+ The extension to use when saving a single feature.
+ """
+ self.path_keys = path_keys
+ self.other_keys = other_keys
+ self.separator = separator
+ self.data_dict_extension = data_dict_extension
+ self.feature_extension = feature_extension
- if feature_name is None and set_name is None:
- return separator.join(parts) + ".data_dict"
+ def __call__(self, data_dict, feature_name=None, set_name=None):
+ """Generate a filename for the data_dict.
- if "event_info" in data_dict and "snr" in data_dict["event_info"]:
- parts += [str(data_dict["event_info"]["snr"])]
+ Parameters
+ ----------
+ data_dict: Dict[str, Any]
+ The data dict containing the data to save.
+ feature_name: Optional[str]
+ The name of the feature.
+ set_name: Optional[str]
+ The name of the set. If no set name is given, the set name is not
+ included in the filename.
- keys = parts + [feature_name]
- if set_name is not None:
- keys = [set_name] + keys
- return separator.join(keys) + ".npy"
+ Returns
+ -------
+ str
+ The filename.
+ """
+ parts = []
+
+        # Add the indicated path keys (basename without extension) to the filename
+        for path_key in self.path_keys:
+            if path_key in data_dict:
+                parts.append(os.path.basename(data_dict[path_key]).split(".")[0])
+ # If no feature name or set name is given, return the data_dict filename.
+ if feature_name is None and set_name is None:
+ return self.separator.join(parts) + self.data_dict_extension
+
+        # Add the indicated other keys to the filename
+ for full_key in self.other_keys:
+ item = data_dict
+ append = True
+ for key in full_key.split(self.SPLIT_CHAR):
+ if key not in item:
+ append = False
+ break
+ item = item[key]
+ if append:
+                # Convert to a string so it can be used in the filename
+ parts.append(str(item))
+
+ # Add the feature name as the last part of the filename, if given
+ if feature_name is not None:
+ parts.append(feature_name)
+ # Add the set name as the first part of the filename, if given
+ if set_name is not None:
+ parts.insert(0, set_name)
+ return self.separator.join(parts) + self.feature_extension
+
+
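
A quick sketch (illustrative, not part of the patch) of the filenames the new DefaultFilenameFn produces with its default settings:

    from brain_pipe.save.default import DefaultFilenameFn

    filename_fn = DefaultFilenameFn()
    data_dict = {
        "data_path": "/data/sub-01_eeg.bdf",
        "stimulus_path": "/stimuli/audiobook_1.npz.gz",
    }
    # No feature/set name: the entire data_dict is saved.
    filename_fn(data_dict)
    # -> 'sub-01_eeg_-_audiobook_1.data_dict'
    # A single feature for a specific set.
    filename_fn(data_dict, feature_name="envelope", set_name="train")
    # -> 'train_-_sub-01_eeg_-_audiobook_1_-_envelope.npy'
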
+class AttachSave(abc.ABC):
+ """Mixin class to attach a Save object."""
+
+ def __init__(self, saver=None):
+ """Initialize the AttachSaver.
+ Parameters
+ ----------
+ saver: Optional[Save]
+ The saver to use. Can be attached later with :meth:`attach_saver`, but
+ that will create a new object.
-class CheckFunctor(abc.ABC):
- """Functor to use with DefaultSave to check something about of a metadata item."""
+ """
+ self.saver = saver
- def __init__(self):
- """Create a new CheckFunctor."""
- self.saver = None
+ def attach_saver(self, saver):
+ """Initialize a new object with the saver.
+
+ Parameters
+ ----------
+ saver: Save
+ The saver.
+ """
+ new_ = copy.deepcopy(self)
+ new_.saver = saver
+ return new_
+
+
+class CheckFunctor(CheckInterface, AttachSave, abc.ABC):
+ """Functor to use with DefaultSave to check something about of a metadata item."""
@abc.abstractmethod
- def __call__(self, metadata_item, feature_name):
+ def __call__(self, metadata_item, feature_name, data_dict):
"""Check something about the metadata item.
Parameters
@@ -99,26 +182,15 @@ def __call__(self, metadata_item, feature_name):
optionally :attr:`DefaultSave.OLD_FORMAT_STR`
feature_name: Optional[str]
The name of the feature.
+ data_dict: Dict[str, Any]
+ The data dict containing the data to save.
Returns
-------
- Optional[str]
- This value will normally be passed to
- :meth:`DefaultSave._iterate_over_metadata_item`. If this value is
- :attr:`DefaultSave.STOP_ITERATING`, the iteration will stop immediately.
- """
- pass
-
- @abc.abstractmethod
- def clear(self, *args, **kwargs):
- """Clear the state of the CheckFunctor.
-
- Parameters
- ----------
- args: List[Any]
- Additional positional arguments.
- kwargs: Dict[str, Any]
- Additional keyword arguments.
+ Union[Any, bool, None]
+ False if the data does not pass the check, otherwise a useful representation
+ of the data (such as the path). If None is returned, the check was not
+ applicable.
"""
pass
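
The new return protocol replaces the old STOP_ITERATION/clear() state: a check returns False when it fails, None when it is not applicable, and a useful value (such as a path) when it passes. A hypothetical custom check under this protocol, for illustration:

    class HasSetNameCheck(CheckFunctor):
        """Hypothetical check: passes when the item was saved for some set."""

        def __call__(self, metadata_item, feature_name, data_dict):
            if feature_name is None:
                return None  # not applicable when the full data_dict was saved
            set_name = metadata_item.get(self.saver.metadata.SET_NAME_STR)
            return set_name if set_name is not None else False
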
@@ -126,12 +198,7 @@ def clear(self, *args, **kwargs):
class IsDoneCheck(CheckFunctor):
"""Check if data has already been saved."""
- def __init__(self):
- """Create a new IsDoneCheck."""
- super().__init__()
- self.is_done = []
-
- def __call__(self, metadata_item, feature_name):
+ def __call__(self, metadata_item, feature_name, data_dict):
"""Check if data has already been saved.
Parameters
@@ -145,44 +212,28 @@ def __call__(self, metadata_item, feature_name):
Returns
-------
- Optional[str]
- This value will normally be passed to
- :meth:`DefaultSave._iterate_over_metadata_item`. If this value is
- :attr:`DefaultSave.STOP_ITERATING`, the iteration will stop immediately.
+ Union[str, bool, None]
+ False if the data has not been saved.
+ None if the check was not applicable for the given feature name.
+            The path to the data otherwise.
"""
- if metadata_item[self.saver.FEATURE_NAME_STR] is feature_name:
+ if metadata_item[self.saver.metadata.FEATURE_NAME_STR] == feature_name:
path = os.path.join(
- self.saver.root_dir, metadata_item[self.saver.FILENAME_STR]
+ self.saver.root_dir, metadata_item[self.saver.metadata.FILENAME_STR]
)
- path_is_done = os.path.exists(path)
- self.is_done.append(path_is_done)
- if not path_is_done:
- return self.saver.STOP_ITERATION
-
- def clear(self, *args, **kwargs):
- """Clear the state of the CheckFunctor.
-
- Parameters
- ----------
- args: List[Any]
- Additional positional arguments.
- kwargs: Dict[str, Any]
- Additional keyword arguments.
- """
- self.is_done.clear()
+ return path if os.path.exists(path) else False
class IsReloadableCheck(CheckFunctor):
"""Check if data has already been saved and is reloadable."""
- def __init__(self):
- """Create a new ReloadableCheck."""
- super().__init__()
- self.reloadable = None
- self.expected_filename = None
-
- def __call__(self, metadata_item, feature_name):
+ def __call__(
+ self,
+ metadata_item: Dict[str, Any],
+ feature_name: Optional[str],
+ data_dict: Dict[str, Any],
+ ):
"""Check if data has already been saved and is reloadable.
Parameters
@@ -193,44 +244,366 @@ def __call__(self, metadata_item, feature_name):
optionally :attr:`DefaultSave.OLD_FORMAT_STR`.
feature_name: Optional[str]
The name of the feature.
+ data_dict: Dict[str, Any]
+            The data dict containing the unprocessed data.
Returns
-------
- Optional[str]
- This value will normally be passed to
- :meth:`DefaultSave._iterate_over_metadata_item`. If this value is
- :attr:`DefaultSave.STOP_ITERATING`, the iteration will stop immediately.
+ Union[str, bool, None]
+ False if the data is not reloadable.
+ None if the check was not applicable for the given feature name.
+ The path to the data otherwise.
"""
- if metadata_item[self.saver.FEATURE_NAME_STR] is None:
- is_old_format = metadata_item.get(self.saver.OLD_FORMAT_STR, False)
- filename = metadata_item[self.saver.FILENAME_STR]
+ if metadata_item[self.saver.metadata.FEATURE_NAME_STR] is None:
+ is_old_format = self.saver.metadata.is_old_format(metadata_item)
+ filename = metadata_item[self.saver.metadata.FILENAME_STR]
+ expected_filename = os.path.relpath(
+ self.saver.filename_fn(data_dict, None, None), self.saver.root_dir
+ )
# If old format and not the expected filename to be reloadable
# then skip
- if is_old_format and filename != self.expected_filename:
- return
+ if is_old_format and filename != expected_filename:
+ return False
path = os.path.join(
- self.saver.root_dir, metadata_item[self.saver.FILENAME_STR]
+ self.saver.root_dir, metadata_item[self.saver.metadata.FILENAME_STR]
+ )
+ return path if os.path.exists(path) else False
+
+
+class SaveMetadata(AttachSave, abc.ABC):
+ """Abstract class for metadata to use when saving/reloading."""
+
+ FEATURE_NAME_STR = "feature_name"
+ FILENAME_STR = "filename"
+ SET_NAME_STR = "set_name"
+ OLD_FORMAT_STR = "old_format"
+
+ def __init__(
+ self, key_fn: Callable[[Dict[str, Any]], str] = default_metadata_key_fn
+ ):
+ """Create a new SaveMetadata.
+
+ Parameters
+ ----------
+ key_fn: Callable[[Dict[str, Any]], str]
+ The function to use to get the key from the data dict.
+ """
+ super().__init__()
+ self.key_fn = key_fn
+
+ @abc.abstractmethod
+ def clear(self):
+ """Clear the metadata."""
+
+ @abc.abstractmethod
+ def add(
+ self,
+ data_dict: Dict[str, Any],
+ filepath: str,
+ feature_name: Optional[str],
+ set_name: Optional[str],
+ ):
+ """Add a metadata entry.
+
+ Parameters
+ ----------
+ data_dict: Dict[str, Any]
+ The data dict containing the data to save.
+ filepath: str
+ The path to the data.
+ feature_name: Optional[str]
+ The name of the feature.
+ set_name: Optional[str]
+ The name of the set.
+ """
+
+ @abc.abstractmethod
+ def __contains__(self, item):
+ """Check if the metadata contains a certain item.
+
+ Parameters
+ ----------
+ item: Any
+ The item to check.
+
+ Returns
+ -------
+ bool
+ Whether the item is contained.
+ """
+
+ @abc.abstractmethod
+ def __getitem__(self, key: Any):
+ """Retrieve a metadata item.
+
+ Parameters
+ ----------
+ key: Any
+ The key to retrieve.
+
+ Returns
+ -------
+ Dict[str, Any]
+ The metadata item.
+ """
+
+ @classmethod
+ def is_old_format(cls, metadata_item: Union[str, Dict[str, Any]]):
+ """Check if the metadata item is in the old format.
+
+ Parameters
+ ----------
+ metadata_item: Union[str, Dict[str, Any]]
+ The metadata item to check.
+
+ Returns
+ -------
+ bool
+ Whether the metadata item is in the old format.
+ """
+ return isinstance(metadata_item, str) or cls.OLD_FORMAT_STR in metadata_item
+
+
+class OldMetadataCompliant(abc.ABC):
+ """Mixin class for metadata that is compliant with the old format."""
+
+ @abc.abstractmethod
+ def convert_old_format(self, metadata_item: str, data_dict: Dict[str, Any]):
+ """Convert the metadata item from the old format.
+
+ Parameters
+ ----------
+ metadata_item: str
+ The metadata item to convert.
+
+ data_dict: Dict[str, Any]
+ The data dict containing the data to save.
+
+ Returns
+ -------
+ Dict[str, Any]
+ The converted metadata item.
+ """
+
+
+class DefaultSaveMetadata(OldMetadataCompliant, SaveMetadata):
+ """Implementation of SaveMetadata to work with DefaultSave."""
+
+ def __init__(
+ self,
+ key_fn: Callable[[Dict[str, Any]], str] = default_metadata_key_fn,
+ filename: str = ".save_metadata.json",
+ ):
+ """Create a new DefaultSaveMetadata.
+
+ Parameters
+ ----------
+ key_fn: Callable[[Dict[str, Any]], str]
+ The function to use to get the key from the data dict.
+ filename: str
+ The filename to use for the metadata.
+ """
+ super().__init__(key_fn=key_fn)
+ self.filename = filename
+ self.saver = None
+
+ def get_path(self):
+ """Get the path to the metadata file.
+
+ Returns
+ -------
+ str
+ The path to the metadata file.
+ """
+ if self.saver is not None and isinstance(self.saver, DefaultSave):
+ path = os.path.join(self.saver.root_dir, self.filename)
+ return path
+ return self.filename
+
+ def get_relpath(self, path: str):
+ """Construct a relative path with regard to save folder.
+
+ Parameters
+ ----------
+ path: str
+ The path to make relative. If no saver is attached, this is returned as is.
+
+ Returns
+ -------
+ str
+ The relative path.
+ """
+ if (
+ self.saver is not None
+ and isinstance(self.saver, DefaultSave)
+ and os.path.isabs(path)
+ ):
+ return os.path.relpath(path, self.saver.root_dir)
+ return path
+
+ @property
+ def lock(self):
+ """Retrieve the lock to use for the metadata file.
+
+ Returns
+ -------
+ multiprocessing.Lock
+ The lock to use for the metadata file.
+ """
+ return MultiprocessingSingleton.get_lock(self.get_path())
+
+ def clear(self):
+ """Clear the metadata."""
+ self.lock.acquire()
+ metadata_path = self.get_path()
+ if os.path.exists(metadata_path):
+ os.remove(metadata_path)
+ self.lock.release()
+
+ def get_metadata_for_savepath(
+ self,
+ path: str,
+ feature_name: Optional[str],
+ set_name: Optional[str],
+ from_old_format=False,
+ ):
+ """Get the metadata associated for path where data is saved.
+
+ Parameters
+ ----------
+ path: str
+ The path to the data.
+ feature_name: Optional[str]
+ The name of the feature.
+ set_name: Optional[str]
+ The name of the set.
+ from_old_format: bool
+ Whether the metadata is in the old format.
+
+ Returns
+ -------
+ Dict[str, Any]
+ The metadata associated with the path.
+ """
+ metadata = {
+ self.FILENAME_STR: self.get_relpath(path),
+ self.FEATURE_NAME_STR: feature_name,
+ self.SET_NAME_STR: set_name,
+ }
+ if from_old_format:
+ metadata[self.OLD_FORMAT_STR] = True
+ return metadata
+
+ def convert_old_format(self, metadata_item: str, data_dict: Dict[str, Any]):
+ """Convert the metadata item from the old format.
+
+ Parameters
+ ----------
+ metadata_item: str
+ The metadata item to convert.
+ data_dict: Dict[str, Any]
+ The data dict containing the data to save.
+
+ Returns
+ -------
+ Dict[str, Any]
+ The converted metadata item.
+ """
+ return self.get_metadata_for_savepath(metadata_item, None, None, True)
+
+ def get(self):
+ """Load the metadata.
+
+ Returns
+ -------
+ Dict[str, Any]
+ The metadata.
+ """
+ metadata_path = self.get_path()
+ if not os.path.exists(metadata_path):
+ return {}
+ self.lock.acquire()
+ with open(metadata_path) as fp:
+ metadata = json.load(fp)
+ self.lock.release()
+ return metadata
+
+ def add(
+ self,
+ data_dict: Dict[str, Any],
+ filepath: str,
+ feature_name: Optional[str],
+ set_name: Optional[str],
+ ):
+ """Add metadata for a file.
+
+ Parameters
+ ----------
+ data_dict: Dict[str, Any]
+ The data dictionary.
+ filepath: str
+ The path to the file.
+ feature_name: Optional[str]
+ The name of the feature.
+ set_name: Optional[str]
+ The name of the set.
+ """
+ metadata = self.get()
+ key = self.key_fn(data_dict)
+ if key not in metadata:
+ metadata[key] = []
+ all_filepaths = wrap_in_list(filepath)
+ for path in all_filepaths:
+ metadata_for_savepath = self.get_metadata_for_savepath(
+ path, feature_name, set_name
)
- if os.path.exists(path):
- self.reloadable = path
- return self.saver.STOP_ITERATION
+ if metadata_for_savepath not in metadata[key]:
+ metadata[key] += [metadata_for_savepath]
+ self.write(metadata)
+
+ def write(self, metadata_dict: Dict[str, Any]):
+ """Write the metadata to disk.
+
+ Parameters
+ ----------
+ metadata_dict: Dict[str, Any]
+ A dictionary containing the metadata.
+ """
+ self.lock.acquire()
+ with open(self.get_path(), "w") as fp:
+ json.dump(metadata_dict, fp)
+ self.lock.release()
- def clear(self, expected_filename, *args, **kwargs):
- """Clear the state of the CheckFunctor.
+ def __contains__(self, item: Any):
+ """Check if the metadata contains a certain item.
Parameters
----------
- expected_filename: str
- The expected filename of the data.
- args: List[Any]
- Additional positional arguments.
- kwargs: Dict[str, Any]
- Additional keyword arguments.
+        item: Any
+            The item to check.
+
+ Returns
+ -------
+ bool
+ Whether the item is contained.
"""
- self.reloadable = None
- self.expected_filename = expected_filename
+ return item in self.get()
+
+ def __getitem__(self, key: Any):
+ """Retrieve a metadata item.
+
+ Parameters
+ ----------
+ key: Any
+ The key to retrieve.
+
+ Returns
+ -------
+ Any
+ The metadata item.
+ """
+ return self.get()[key]
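
For orientation, the dict returned by DefaultSaveMetadata.get() (and serialized to .save_metadata.json) maps the key produced by key_fn to a list of per-file entries; keys and filenames below are illustrative:

    {
        "sub-01_eeg.bdf": [
            {"filename": "sub-01_eeg_-_audiobook_1.data_dict",
             "feature_name": None, "set_name": None},
            {"filename": "train_-_sub-01_eeg_-_audiobook_1_-_envelope.npy",
             "feature_name": "envelope", "set_name": "train"},
        ],
    }

Metadata written before 0.0.3 stored plain filename strings instead of these dicts; convert_old_format upgrades such entries on the fly.
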
class DefaultSave(Save):
@@ -241,14 +614,11 @@ class DefaultSave(Save):
between an unprocessed input filename and multiple possible output filenames.
"""
- lock = MultiprocessingSingleton.manager.Lock()
-
DEFAULT_SAVE_FUNCTIONS = {
"npy": np.save,
"pickle": pickle_dump_wrapper,
"data_dict": pickle_dump_wrapper,
}
-
DEFAULT_RELOAD_FUNCTIONS = {
"npy": np.load,
"npz": np.load,
@@ -256,28 +626,20 @@ class DefaultSave(Save):
"data_dict": pickle_load_wrapper,
}
- FEATURE_NAME_STR = "feature_name"
- FILENAME_STR = "filename"
- SET_NAME_STR = "set_name"
- OLD_FORMAT_STR = "old_format"
-
- STOP_ITERATION = "stop_iteration"
-
_metadata_deprecation_warning_logged = False
def __init__(
self,
- root_dir,
- to_save=None,
- overwrite=False,
- clear_output=False,
- filename_fn=default_filename_fn,
- save_fn=None,
- reload_fn=None,
- metadata_filename=".save_metadata.json",
- metadata_key_fn=default_metadata_key_fn,
- check_done: IsDoneCheck = IsDoneCheck(),
- check_reloadable: IsReloadableCheck = IsReloadableCheck(),
+ root_dir: str,
+ to_save: Optional[Mapping[str, Any]] = None,
+ overwrite: bool = False,
+ clear_output: bool = False,
+ filename_fn: FilenameFnInterface = DefaultFilenameFn(),
+ metadata: SaveMetadata = DefaultSaveMetadata(),
+ save_fn: SaveFnInterface = None,
+ reload_fn: ReloadFnInterface = None,
+ check_done: Optional[CheckInterface] = IsDoneCheck(),
+ check_reloadable: Optional[CheckInterface] = IsReloadableCheck(),
):
"""Create a Save step.
@@ -294,32 +656,28 @@ def __init__(
clear_output: bool
Whether to clear the output data_dict after saving. This can save space
when save is the last step in a pipeline.
- filename_fn: Callable[[Dict[str, Any], Optional[str], Optional[str], str], str]
+ filename_fn: FilenameFnInterface
A function to generate a filename for the data. The function should take
-            the data_dict, the feature name, the set name and a separator as input
-            and return a filename.
+            the data_dict, the feature name and the set name as input and return
+            a filename.
- save_fn: Union[Callable[[Any, str], None], Mapping[str, Callable[[Any, str], None]], None] # noqa: E501
+ save_fn: SaveFnInterface
A function to save the data. The function should take the data and the
filepath as inputs and save the data. If a mapping between file extensions
and functions is given, the function corresponding to the file extension
is used to save the data. If None, the default save functions (defined in
self.DEFAULT_SAVE_FUNCTIONS) are used.
- reload_fn: Union[Callable[[str], Any], Mapping[str, Callable[[str], Any]], None]
+ reload_fn: ReloadFnInterface
A function to reload the data. The function should take the filepath as
input and return the data. If a mapping between file extensions and
functions is given, the function corresponding to the file extension is
used to reload the data. If None, the default reload functions (defined in
self.DEFAULT_RELOAD_FUNCTIONS) are used.
- metadata_filename: str
- The filename of the metadata file.
- metadata_key_fn: Callable[[Dict[str, Any]], str]
- A function to generate a key for the metadata. The function should take
- the data_dict as input and return a key. This key will be used to check
- whether the data has already been saved.
- check_done: IsDoneCheck
- A functor to check whether the data has already been saved.
- check_reloadable: IsReloadableCheck
- A functor to check whether the data can be reloaded.
+ check_done: Optional[CheckInterface]
+            A functor to check whether the data has already been saved. If None, no
+            checking is done.
+        check_reloadable: Optional[CheckInterface]
+            A functor to check whether the data can be reloaded. If None, no checking
+            is done.
"""
super().__init__(clear_output=clear_output, overwrite=overwrite)
self.root_dir = root_dir
@@ -331,12 +689,29 @@ def __init__(
self.reload_fn = reload_fn
if self.reload_fn is None:
self.reload_fn = self.DEFAULT_RELOAD_FUNCTIONS
- self.metadata_filename = metadata_filename
- self.metadata_key_fn = metadata_key_fn
- self.check_done = check_done
- self.check_reload = check_reloadable
- self.check_done.saver = self
- self.check_reload.saver = self
+ self.check_done = self._attach_saver(check_done)
+ self.check_reloadable = self._attach_saver(check_reloadable)
+ self.metadata = self._attach_saver(metadata)
+
+ def _attach_saver(self, check: Optional[AttachSave]):
+ """Attach this :class:`Save` object to a :class:`AttachSave` object.
+
+ Parameters
+ ----------
+ check: Optional[AttachSave]
+ An optional object to attach this :class:`Save` object to.
+ Note that the object should implement the :meth:`attach_saver` method, which
+ will return a copy of the :class:`AttachSave` object with the :class:`Save`
+ object attached.
+
+ Returns
+ -------
+        Optional[AttachSave]
+            The object with the saver attached, or the input unchanged if it does
+            not implement :meth:`attach_saver`.
+ """
+ if hasattr(check, "attach_saver"):
+ return check.attach_saver(self)
+ return check
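
For callers migrating from the old constructor: metadata_filename and metadata_key_fn are folded into the metadata object. A sketch with illustrative arguments:

    from brain_pipe.save.default import (
        DefaultSave,
        DefaultSaveMetadata,
        default_metadata_key_fn,
    )

    saver = DefaultSave(
        "/output/derivatives",  # root_dir, illustrative
        metadata=DefaultSaveMetadata(
            key_fn=default_metadata_key_fn,  # previously metadata_key_fn=
            filename=".save_metadata.json",  # previously metadata_filename=
        ),
    )
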
@property
def overwrite(self):
@@ -359,13 +734,9 @@ def overwrite(self, value):
Whether to overwrite existing files.
"""
self._overwrite = value
- if self._overwrite:
- self._clear_metadata()
-
- def _single_obj_to_list(self, obj):
- if not isinstance(obj, (list, tuple)):
- obj = [obj]
- return obj
+ # Clear the metadata if overwrite is True and it has been initialized
+ if self._overwrite and hasattr(self, "metadata"):
+ self.metadata.clear()
def is_already_done(self, data_dict):
"""Check whether the data_dict has already been saved.
@@ -381,89 +752,55 @@ def is_already_done(self, data_dict):
Whether the data_dict has already been saved. This will be checked in the
stored metadata.
"""
+ # If overwrite is True, the data is never done
if self.overwrite:
return False
- metadata = self._get_metadata()
- key = self.metadata_key_fn(data_dict)
- if key not in metadata:
+ # If the key is not in the metadata, the data is not done
+ key = self.metadata.key_fn(data_dict)
+ if key not in self.metadata:
return False
- self.check_done.clear()
- self._iterate_over_metadata_item(metadata[key], self.check_done)
-
- return all(self.check_done.is_done)
+ # If the key is in the metadata, check whether the data is done
+ results = self._iterate_over_metadata_item(
+ self.metadata[key], self.check_done, data_dict
+ )
+ # All items should be done for it to be considered done
+        return bool(results) and all(results)
- def _iterate_over_metadata_item(self, metadata_item, callback):
- file_infos = self._single_obj_to_list(metadata_item)
+ def _iterate_over_metadata_item(self, metadata_item, callback, data_dict):
+ # Make a list of metadata items
+ file_infos = wrap_in_list(metadata_item)
+ # Select the feature names for this Save step
feature_names = [None] if self.to_save is None else self.to_save.keys()
+ # Iterate over all file_infos to check them
+ results = []
for info in file_infos:
- if isinstance(info, str):
+ # Info should be a dict, but it can also be a string if the data was saved
+ # with the old metadata format (version <= 0.0.2).
+ if self.metadata.is_old_format(info):
# Only log this if it hasn't been logged before.
if not self._metadata_deprecation_warning_logged:
logging.warning(
"Found previously saved data with the old metadata format "
"(version <= 0.0.2). DefaultSave will attempt to reload the "
"data, but it is recommended to delete the old data and "
- f"{os.path.join(self.root_dir, self.metadata_filename)} if "
- f"possible."
+ "metadata file if possible."
)
self._metadata_deprecation_warning_logged = True
- info = {
- self.FILENAME_STR: info,
- self.FEATURE_NAME_STR: None,
- self.SET_NAME_STR: None,
- self.OLD_FORMAT_STR: True,
- }
+ if isinstance(self.metadata, OldMetadataCompliant):
+ info = self.metadata.convert_old_format(info, data_dict)
+ else:
+ raise ValueError(
+ "The metadata class used is not compatible with the old "
+ "format. Please delete the old data and metadata file."
+ )
for feature_name in feature_names:
- status = callback(info, feature_name)
- # If the callback returns STOP_ITERATION, we stop iterating over the
- # metadata immediately.
- if status == self.STOP_ITERATION:
- return
-
- def _clear_metadata(self):
- self.lock.acquire()
- if not hasattr(self, "root_dir"):
- # If the instance is not initialized
- self.lock.release()
- return
- metadata_path = os.path.join(self.root_dir, self.metadata_filename)
- if os.path.exists(metadata_path):
- os.remove(metadata_path)
- self.lock.release()
-
- def _get_metadata(self):
- metadata_path = os.path.join(self.root_dir, self.metadata_filename)
- if not os.path.exists(metadata_path):
- return {}
- self.lock.acquire()
- with open(metadata_path) as fp:
- metadata = json.load(fp)
- self.lock.release()
- return metadata
-
- def _add_metadata(self, data_dict, filepath, feature_name, set_name):
- metadata = self._get_metadata()
- key = self.metadata_key_fn(data_dict)
- if key not in metadata:
- metadata[key] = []
- all_filepaths = self._single_obj_to_list(filepath)
- for path in all_filepaths:
- metadata_for_path = {
- self.FILENAME_STR: os.path.relpath(path, self.root_dir),
- self.FEATURE_NAME_STR: feature_name,
- self.SET_NAME_STR: set_name,
- }
- if metadata_for_path not in metadata[key]:
- metadata[key] += [metadata_for_path]
- self._write_metadata(metadata)
-
- def _write_metadata(self, metadata):
- metadata_path = os.path.join(self.root_dir, self.metadata_filename)
- self.lock.acquire()
- with open(metadata_path, "w") as fp:
- json.dump(metadata, fp)
- self.lock.release()
+ item = callback(info, feature_name, data_dict)
+ # If the callback returns None, the check is skipped because it is
+ # assumed that the check is not applicable.
+ if item is not None:
+ results.append(item)
+ return results
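
The checks' results are now collected and reduced by the callers instead of driving a STOP_ITERATION state machine. A sketch of the reduction, with illustrative values:

    results = ["out/a.npy", False]   # one saved file found, one missing
    bool(results) and all(results)   # is_already_done: every check must pass -> False
    bool(results) and any(results)   # is_reloadable: a single hit suffices -> True
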
def _serialization_wrapper(self, fn, filepath, *args, action="save", **kwargs):
if not isinstance(fn, dict):
@@ -482,7 +819,7 @@ def _apply_to_data(self, data_dict, fn):
if self.to_save is None:
path = os.path.join(self.root_dir, self.filename_fn(data_dict, None, None))
self._serialization_wrapper(fn, path, data_dict, action="save")
- self._add_metadata(data_dict, [path], None, None)
+ self.metadata.add(data_dict, [path], None, None)
return
# Save singular features
@@ -496,15 +833,15 @@ def _apply_to_data(self, data_dict, fn):
path = os.path.join(self.root_dir, filename)
self._serialization_wrapper(fn, path, set_data, action="save")
paths += [path]
- self._add_metadata(data_dict, paths, feature_name, set_name)
+ self.metadata.add(data_dict, paths, feature_name, set_name)
# In full
else:
- filename = self.filename_fn(data_dict, feature_name)
+ filename = self.filename_fn(data_dict, feature_name, None)
path = os.path.join(self.root_dir, filename)
self._serialization_wrapper(fn, path, data, action="save")
paths += [path]
- self._add_metadata(data_dict, paths, feature_name, None)
+ self.metadata.add(data_dict, paths, feature_name, None)
def is_reloadable(self, data_dict: Dict[str, Any]) -> bool:
"""Check whether an already processed data_dict can be reloaded.
@@ -520,9 +857,8 @@ def is_reloadable(self, data_dict: Dict[str, Any]) -> bool:
Whether an already processed data_dict can be reloaded to continue
processing.
"""
- metadata = self._get_metadata()
- key = self.metadata_key_fn(data_dict)
- if key not in metadata:
+ key = self.metadata.key_fn(data_dict)
+ if key not in self.metadata:
return False
# No support to reload singular features
@@ -533,13 +869,10 @@ def is_reloadable(self, data_dict: Dict[str, Any]) -> bool:
if self.overwrite:
return False
- expected_filename = os.path.relpath(
- self.filename_fn(data_dict, None, None), self.root_dir
+ is_reloadable = self._iterate_over_metadata_item(
+ self.metadata[key], self.check_reloadable, data_dict
)
- self.check_reload.clear(expected_filename)
- self._iterate_over_metadata_item(metadata[key], self.check_reload)
-
- return self.check_reload.reloadable is not None
+        return bool(is_reloadable) and any(is_reloadable)
def reload(self, data_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Reload the data_dict from the saved file.
@@ -554,19 +887,17 @@ def reload(self, data_dict: Dict[str, Any]) -> Dict[str, Any]:
Dict[str, Any]
The reloaded data_dict.
"""
- metadata = self._get_metadata()
- key = self.metadata_key_fn(data_dict)
- expected_filename = os.path.relpath(
- self.filename_fn(data_dict, None, None), self.root_dir
+ key = self.metadata.key_fn(data_dict)
+ paths = self._iterate_over_metadata_item(
+ self.metadata[key], self.check_reloadable, data_dict
)
- self.check_reload.clear(expected_filename)
- self._iterate_over_metadata_item(metadata[key], self.check_reload)
- if self.check_reload.reloadable is None:
+        if not any(paths):
raise ValueError("Didn't find any file that can be reloaded.")
+ selected_path = [p for p in paths if p][0]
return self._serialization_wrapper(
self.reload_fn,
- self.check_reload.reloadable,
+ selected_path,
action="reload",
)
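
Taken together, a sketch (illustrative, not from this patch) of how a pipeline can use these hooks around a Save step:

    if saver.is_already_done(data_dict):
        pass  # all output files already exist; skip the step
    elif saver.is_reloadable(data_dict):
        data_dict = saver.reload(data_dict)  # resume from the saved data_dict
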
diff --git a/brain_pipe/utils/list.py b/brain_pipe/utils/list.py
index 7197b70..10c6370 100644
--- a/brain_pipe/utils/list.py
+++ b/brain_pipe/utils/list.py
@@ -1,17 +1,18 @@
"""List utilities."""
+from typing import Union, Any, List, Tuple
-def flatten(lst):
- """Flatten a list.
+def flatten(lst: Union[List[Any], Tuple[Any, ...]]) -> List[Any]:
+ """Flatten a list or tuple (recursively).
Parameters
----------
- lst: Union[list, tuple]
- A list to be flattened.
+    lst: Union[List[Any], Tuple[Any, ...]]
+ A list to be flattened. Sublists and tuples are also flattened.
Returns
-------
- list
+ List[Any]
A flattened list.
"""
result = []
@@ -21,3 +22,22 @@ def flatten(lst):
else:
result.append(lst)
return result
+
+
+def wrap_in_list(obj: Any) -> Union[List[Any], Tuple[Any, ...]]:
+    """Wrap an object in a list if it is not already a list or tuple.
+
+ Parameters
+ ----------
+ obj: Any
+        An object to be wrapped in a list. If it is already a list or tuple, it is
+        returned as is.
+
+ Returns
+ -------
+    Union[List[Any], Tuple[Any, ...]]
+        A list containing the object, or the object itself if it was already a
+        list or tuple.
+ """
+ if not isinstance(obj, (list, tuple)):
+ obj = [obj]
+ return obj
diff --git a/brain_pipe/utils/multiprocess.py b/brain_pipe/utils/multiprocess.py
index 5a2690c..750a134 100644
--- a/brain_pipe/utils/multiprocess.py
+++ b/brain_pipe/utils/multiprocess.py
@@ -46,6 +46,7 @@ class MultiprocessingSingleton:
"""Singleton class for multiprocessing."""
manager = multiprocessing.Manager()
+ locks = {}
to_clean = []
@@ -115,3 +116,21 @@ def clean(cls):
pool.close()
pool.join()
cls.to_clean = []
+
+ @classmethod
+ def get_lock(cls, id_str):
+ """Create or get a lock for multiprocessing.
+
+ Parameters
+ ----------
+ id_str: str
+            Identifier for the lock. If the lock does not already exist in cls.locks,
+            it will be created and added to cls.locks.
+
+ Returns
+ -------
+        multiprocessing.Lock
+            The lock associated with ``id_str``.
+ """
+ if id_str not in cls.locks:
+ cls.locks[id_str] = cls.manager.Lock()
+ return cls.locks[id_str]
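
Because get_lock memoizes one lock per identifier, savers that target the same metadata file now share a lock instead of the previous single class-wide lock. A sketch, with illustrative paths:

    from brain_pipe.utils.multiprocess import MultiprocessingSingleton

    a = MultiprocessingSingleton.get_lock("/out/.save_metadata.json")
    b = MultiprocessingSingleton.get_lock("/out/.save_metadata.json")
    assert a is b  # same manager lock for the same identifier

    c = MultiprocessingSingleton.get_lock("/other/.save_metadata.json")
    assert a is not c  # distinct files get distinct locks
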
diff --git a/setup.cfg b/setup.cfg
index 8ade7df..3a2361a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
[flake8]
# black default
max-line-length = 88
-ignore = E203
+ignore = E203,W503
exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,bids_preprocessing/utils,bids_preprocessing/tests,bids_preprocessing/preprocessing/cache
\ No newline at end of file
diff --git a/tests/test_data/truncated_sparrKULee/derivatives/stimuli/audiobook_7_5.data_dict.gz b/tests/test_data/truncated_sparrKULee/derivatives/stimuli/audiobook_7_5.data_dict.gz
index 9789ff4..2abfdc9 100644
Binary files a/tests/test_data/truncated_sparrKULee/derivatives/stimuli/audiobook_7_5.data_dict.gz and b/tests/test_data/truncated_sparrKULee/derivatives/stimuli/audiobook_7_5.data_dict.gz differ
diff --git a/tests/test_save/test_default.py b/tests/test_save/test_default.py
index 46ed76f..6773e4f 100644
--- a/tests/test_save/test_default.py
+++ b/tests/test_save/test_default.py
@@ -9,8 +9,9 @@
from brain_pipe.save.default import (
default_metadata_key_fn,
- default_filename_fn,
+ DefaultFilenameFn,
DefaultSave,
+ DefaultSaveMetadata,
)
from brain_pipe.utils.serialization import pickle_load_wrapper
@@ -30,16 +31,17 @@ def test_defaultMetadataKeyFn(self):
class DefaultFilenameFnTest(unittest.TestCase):
def test_defaultFilenameFn(self):
self.assertEqual(
- default_filename_fn({"data_path": "a"}, "b", "c", "_-_"), "c_-_a_-_b.npy"
+ DefaultFilenameFn(path_keys=["a", "b", "c"])(
+ {"a": "c", "b": "a"}, "b", "d"
+ ),
+ "d_-_c_-_a_-_b.npy",
)
self.assertEqual(
- default_filename_fn(
- {"data_path": "a", "stimulus_path": "d"}, None, None, "_"
- ),
+ DefaultFilenameFn(separator="_")({"data_path": "a", "stimulus_path": "d"}),
"a_d.data_dict",
)
self.assertEqual(
- default_filename_fn(
+ DefaultFilenameFn(separator="|")(
{
"data_path": "a",
"stimulus_path": "d",
@@ -47,7 +49,6 @@ def test_defaultFilenameFn(self):
},
"b",
"c",
- "|",
),
"c|a|d|123.123|b.npy",
)
@@ -75,10 +76,11 @@ def setUp(self) -> None:
self.tmp_dir = tempfile.TemporaryDirectory()
def test_is_already_done(self):
+ metadata_path = os.path.join(self.tmp_dir.name, ".save_metadata.json")
saver = DefaultSave(
self.tmp_dir.name,
filename_fn=self.MockupFilenameFn(),
- metadata_key_fn=self.MockupMetadataKeyFn(),
+ metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
)
path = os.path.join(self.tmp_dir.name, "b")
temp_dict = {"output_filename": path, "metadata_key": "d"}
@@ -86,7 +88,7 @@ def test_is_already_done(self):
self.assertFalse(saver.is_already_done(temp_dict))
# Make manually sure the file is already done/there.
- with open(os.path.join(self.tmp_dir.name, ".save_metadata.json"), "w") as f:
+ with open(metadata_path, "w") as f:
json.dump({"d": os.path.basename(path)}, f)
with open(path, "w") as f:
@@ -115,7 +117,7 @@ def test_is_reloadable(self):
saver = DefaultSave(
self.tmp_dir.name,
filename_fn=self.MockupFilenameFn(),
- metadata_key_fn=self.MockupMetadataKeyFn(),
+ metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
)
path = os.path.join(self.tmp_dir.name, "b")
temp_dict = {"output_filename": path, "metadata_key": "d"}
@@ -136,7 +138,7 @@ def test_is_reloadable(self):
saver = DefaultSave(
self.tmp_dir.name,
filename_fn=self.MockupFilenameFn(),
- metadata_key_fn=self.MockupMetadataKeyFn(),
+ metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
overwrite=True,
)
# When overwrite is True, it should always return False
@@ -154,7 +156,7 @@ def test_reload(self):
saver = DefaultSave(
self.tmp_dir.name,
filename_fn=self.MockupFilenameFn(),
- metadata_key_fn=self.MockupMetadataKeyFn(),
+ metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
)
path = os.path.join(self.tmp_dir.name, "b")
temp_dict = {"output_filename": path, "metadata_key": "d"}
@@ -200,7 +202,7 @@ def test_call(self):
saver = DefaultSave(
self.tmp_dir.name,
filename_fn=self.MockupFilenameFn(),
- metadata_key_fn=self.MockupMetadataKeyFn(),
+ metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
)
data_dict = {"output_filename": "a.pickle", "metadata_key": "b", "data": 123}
output = saver(data_dict)
@@ -221,7 +223,7 @@ def test_to_save(self):
saver = DefaultSave(
self.tmp_dir.name,
filename_fn=self.MockupFilenameFn(),
- metadata_key_fn=self.MockupMetadataKeyFn(),
+ metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
to_save={"envelope": "data"},
)
data_dict = {"output_filename": "a.npy", "metadata_key": "b", "data": 123}
@@ -302,7 +304,7 @@ def test_clear_output(self):
saver = DefaultSave(
self.tmp_dir.name,
filename_fn=self.MockupFilenameFn(),
- metadata_key_fn=self.MockupMetadataKeyFn(),
+ metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
clear_output=True,
)
self.assertEqual(saver(data_dict), {})
@@ -311,15 +313,23 @@ def test_multiple_savers(self):
saver1 = DefaultSave(
self.tmp_dir.name,
filename_fn=self.MockupFilenameFn(),
- metadata_key_fn=self.MockupMetadataKeyFn(),
+ metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
)
saver2 = DefaultSave(
self.tmp_dir.name,
to_save={"envelope": "data"},
filename_fn=self.MockupFilenameFn(output_override="a.npy"),
- metadata_key_fn=self.MockupMetadataKeyFn(),
+ metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
)
+        # Check that check_done, check_reloadable and metadata have the right saver.
+ self.assertEqual(saver1.check_done.saver, saver1)
+ self.assertEqual(saver2.check_done.saver, saver2)
+ self.assertEqual(saver1.check_reloadable.saver, saver1)
+ self.assertEqual(saver2.check_reloadable.saver, saver2)
+ self.assertEqual(saver1.metadata.saver, saver1)
+ self.assertEqual(saver2.metadata.saver, saver2)
+
data_dict = {
"output_filename": "a.data_dict",
"metadata_key": "b",
diff --git a/tests/test_utils/test_list.py b/tests/test_utils/test_list.py
new file mode 100644
index 0000000..9912b19
--- /dev/null
+++ b/tests/test_utils/test_list.py
@@ -0,0 +1,16 @@
+import unittest
+
+from brain_pipe.utils.list import flatten, wrap_in_list
+
+
+class ListUtilsTest(unittest.TestCase):
+ def test_flatten(self):
+ flat_list = flatten([1, [2, [3, 4, [5], 6]]])
+ self.assertEqual(flat_list, [1, 2, 3, 4, 5, 6])
+
+    def test_wrap_in_list(self):
+ self.assertEqual(wrap_in_list(1), [1])
+ self.assertEqual(wrap_in_list([1]), [1])
+ self.assertEqual(wrap_in_list((1, 2)), (1, 2))
+ self.assertEqual(wrap_in_list([]), [])
+ self.assertEqual(wrap_in_list(None), [None])