diff --git a/brain_pipe/save/default.py b/brain_pipe/save/default.py
index 0580e2c..859ecd0 100644
--- a/brain_pipe/save/default.py
+++ b/brain_pipe/save/default.py
@@ -1,10 +1,11 @@
 """Default save class."""
 import abc
+import copy
 import gc
 import json
 import logging
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Optional, Union, Callable, Mapping
 
 import numpy as np
 
@@ -13,8 +14,19 @@
     pickle_load_wrapper,
 )
 from brain_pipe.save.base import Save
+from brain_pipe.utils.list import wrap_in_list
 from brain_pipe.utils.multiprocess import MultiprocessingSingleton
 
+# Shorthand interfaces.
+CheckInterface = Callable[[Dict[str, Any], str, Dict[str, Any]], Union[str, bool]]
+FilenameFnInterface = Callable[[Dict[str, Any], Optional[str], Optional[str]], str]
+SaveFnInterface = Union[
+    Callable[[Any, str], None], Mapping[str, Callable[[Any, str], None]], None
+]
+ReloadFnInterface = Union[
+    Callable[[str], Any], Mapping[str, Callable[[str], Any]], None
+]
+
 
 def default_metadata_key_fn(data_dict: Dict[str, Any]) -> str:
     """Generate a key for the metadata.
@@ -41,54 +53,125 @@ def default_metadata_key_fn(data_dict: Dict[str, Any]) -> str:
     raise ValueError("No data_path or stimulus_path in data_dict.")
 
 
-def default_filename_fn(data_dict, feature_name, set_name=None, separator="_-_"):
-    """Generate a filename for the data_dict.
+class DefaultFilenameFn(FilenameFnInterface):
+    """Default filename function to create paths to save data."""
 
-    Parameters
-    ----------
-    data_dict: Dict[str, Any]
-        The data dict containing the data to save.
-    feature_name: Optional[str]
-        The name of the feature.
-    set_name: Optional[str]
-        The name of the set. If no set name is given, the set name is not
-        included in the filename.
-    separator: str
-        The separator to use between the different parts of the filename.
+    SPLIT_CHAR = "/"
 
-    Returns
-    -------
-    str
-        The filename.
-    """
-    parts = []
-    if "data_path" in data_dict:
-        parts += [os.path.basename(data_dict["data_path"]).split(".")[0]]
+    def __init__(
+        self,
+        path_keys=("data_path", "stimulus_path"),
+        other_keys=("event_info/snr",),
+        separator="_-_",
+        data_dict_extension=".data_dict",
+        feature_extension=".npy",
+    ):
+        """Create a new DefaultFilenameFn instance.
 
-    if "stimulus_path" in data_dict:
-        parts += [os.path.basename(data_dict["stimulus_path"]).split(".")[0]]
+        Parameters
+        ----------
+        path_keys: Sequence[str]
+            The keys of the paths to include in the filename.
+        other_keys: Sequence[str]
+            The keys of other data to include in the filename.
+        separator: str
+            The separator to use between parts of the filename.
+        data_dict_extension: str
+            The extension to use when saving the entire data_dict.
+        feature_extension: str
+            The extension to use when saving a single feature.
+        """
+        self.path_keys = path_keys
+        self.other_keys = other_keys
+        self.separator = separator
+        self.data_dict_extension = data_dict_extension
+        self.feature_extension = feature_extension
 
-    if feature_name is None and set_name is None:
-        return separator.join(parts) + ".data_dict"
+    def __call__(self, data_dict, feature_name=None, set_name=None):
+        """Generate a filename for the data_dict.
-    if "event_info" in data_dict and "snr" in data_dict["event_info"]:
-        parts += [str(data_dict["event_info"]["snr"])]
+
+        Parameters
+        ----------
+        data_dict: Dict[str, Any]
+            The data dict containing the data to save.
+        feature_name: Optional[str]
+            The name of the feature.
+        set_name: Optional[str]
+            The name of the set. If no set name is given, the set name is not
+            included in the filename.
 
-    keys = parts + [feature_name]
-    if set_name is not None:
-        keys = [set_name] + keys
-    return separator.join(keys) + ".npy"
+        Returns
+        -------
+        str
+            The filename.
+        """
+        parts = []
+
+        # Add indicated path keys to filename
+        for path_key in self.path_keys:
+            if path_key in data_dict:
+                parts.append(os.path.basename(data_dict[path_key]).split(".")[0])
+        # If no feature name or set name is given, return the data_dict filename.
+        if feature_name is None and set_name is None:
+            return self.separator.join(parts) + self.data_dict_extension
+
+        # Add indicated other keys to data
+        for full_key in self.other_keys:
+            item = data_dict
+            append = True
+            for key in full_key.split(self.SPLIT_CHAR):
+                if key not in item:
+                    append = False
+                    break
+                item = item[key]
+            if append:
+                # Has to be added as a string to be able to be used in filename
+                parts.append(str(item))
+
+        # Add the feature name as the last part of the filename, if given
+        if feature_name is not None:
+            parts.append(feature_name)
+        # Add the set name as the first part of the filename, if given
+        if set_name is not None:
+            parts.insert(0, set_name)
+        return self.separator.join(parts) + self.feature_extension
+
+
+class AttachSave(abc.ABC):
+    """Mixin class to attach a Save object."""
+
+    def __init__(self, saver=None):
+        """Initialize the AttachSave mixin.
+
+        Parameters
+        ----------
+        saver: Optional[Save]
+            The saver to use. Can be attached later with :meth:`attach_saver`, but
+            that will create a new object.
 
-class CheckFunctor(abc.ABC):
-    """Functor to use with DefaultSave to check something about of a metadata item."""
+        """
+        self.saver = saver
 
-    def __init__(self):
-        """Create a new CheckFunctor."""
-        self.saver = None
+    def attach_saver(self, saver):
+        """Create a copy of this object with the saver attached.
+
+        Parameters
+        ----------
+        saver: Save
+            The saver.
+
+        Returns
+        -------
+        AttachSave
+            A deep copy of this object with ``saver`` attached.
+        """
+        new_ = copy.deepcopy(self)
+        new_.saver = saver
+        return new_
+
+
+class CheckFunctor(CheckInterface, AttachSave, abc.ABC):
+    """Functor to use with DefaultSave to check something about a metadata item."""
 
     @abc.abstractmethod
-    def __call__(self, metadata_item, feature_name):
+    def __call__(self, metadata_item, feature_name, data_dict):
         """Check something about the metadata item.
 
         Parameters
@@ -99,26 +182,15 @@ def __call__(self, metadata_item, feature_name):
             optionally :attr:`DefaultSave.OLD_FORMAT_STR`
         feature_name: Optional[str]
             The name of the feature.
+        data_dict: Dict[str, Any]
+            The data dict containing the data to save.
 
         Returns
         -------
-        Optional[str]
-            This value will normally be passed to
-            :meth:`DefaultSave._iterate_over_metadata_item`. If this value is
-            :attr:`DefaultSave.STOP_ITERATING`, the iteration will stop immediately.
-        """
-        pass
-
-    @abc.abstractmethod
-    def clear(self, *args, **kwargs):
-        """Clear the state of the CheckFunctor.
-
-        Parameters
-        ----------
-        args: List[Any]
-            Additional positional arguments.
-        kwargs: Dict[str, Any]
-            Additional keyword arguments.
+        Union[Any, bool, None]
+            False if the data does not pass the check, otherwise a useful
+            representation of the data (such as the path). If None is returned,
+            the check was not applicable.
""" pass @@ -126,12 +198,7 @@ def clear(self, *args, **kwargs): class IsDoneCheck(CheckFunctor): """Check if data has already been saved.""" - def __init__(self): - """Create a new IsDoneCheck.""" - super().__init__() - self.is_done = [] - - def __call__(self, metadata_item, feature_name): + def __call__(self, metadata_item, feature_name, data_dict): """Check if data has already been saved. Parameters @@ -145,44 +212,28 @@ def __call__(self, metadata_item, feature_name): Returns ------- - Optional[str] - This value will normally be passed to - :meth:`DefaultSave._iterate_over_metadata_item`. If this value is - :attr:`DefaultSave.STOP_ITERATING`, the iteration will stop immediately. + Union[str, bool, None] + False if the data has not been saved. + None if the check was not applicable for the given feature name. + The path to the data otherwise """ - if metadata_item[self.saver.FEATURE_NAME_STR] is feature_name: + if metadata_item[self.saver.metadata.FEATURE_NAME_STR] == feature_name: path = os.path.join( - self.saver.root_dir, metadata_item[self.saver.FILENAME_STR] + self.saver.root_dir, metadata_item[self.saver.metadata.FILENAME_STR] ) - path_is_done = os.path.exists(path) - self.is_done.append(path_is_done) - if not path_is_done: - return self.saver.STOP_ITERATION - - def clear(self, *args, **kwargs): - """Clear the state of the CheckFunctor. - - Parameters - ---------- - args: List[Any] - Additional positional arguments. - kwargs: Dict[str, Any] - Additional keyword arguments. - """ - self.is_done.clear() + return path if os.path.exists(path) else False class IsReloadableCheck(CheckFunctor): """Check if data has already been saved and is reloadable.""" - def __init__(self): - """Create a new ReloadableCheck.""" - super().__init__() - self.reloadable = None - self.expected_filename = None - - def __call__(self, metadata_item, feature_name): + def __call__( + self, + metadata_item: Dict[str, Any], + feature_name: Optional[str], + data_dict: Dict[str, Any], + ): """Check if data has already been saved and is reloadable. Parameters @@ -193,44 +244,366 @@ def __call__(self, metadata_item, feature_name): optionally :attr:`DefaultSave.OLD_FORMAT_STR`. feature_name: Optional[str] The name of the feature. + data_dict: Dict[str, Any] + The data dict containing the unprocessed data Returns ------- - Optional[str] - This value will normally be passed to - :meth:`DefaultSave._iterate_over_metadata_item`. If this value is - :attr:`DefaultSave.STOP_ITERATING`, the iteration will stop immediately. + Union[str, bool, None] + False if the data is not reloadable. + None if the check was not applicable for the given feature name. + The path to the data otherwise. 
""" - if metadata_item[self.saver.FEATURE_NAME_STR] is None: - is_old_format = metadata_item.get(self.saver.OLD_FORMAT_STR, False) - filename = metadata_item[self.saver.FILENAME_STR] + if metadata_item[self.saver.metadata.FEATURE_NAME_STR] is None: + is_old_format = self.saver.metadata.is_old_format(metadata_item) + filename = metadata_item[self.saver.metadata.FILENAME_STR] + expected_filename = os.path.relpath( + self.saver.filename_fn(data_dict, None, None), self.saver.root_dir + ) # If old format and not the expected filename to be reloadable # then skip - if is_old_format and filename != self.expected_filename: - return + if is_old_format and filename != expected_filename: + return False path = os.path.join( - self.saver.root_dir, metadata_item[self.saver.FILENAME_STR] + self.saver.root_dir, metadata_item[self.saver.metadata.FILENAME_STR] + ) + return path if os.path.exists(path) else False + + +class SaveMetadata(AttachSave, abc.ABC): + """Abstract class for metadata to use when saving/reloading.""" + + FEATURE_NAME_STR = "feature_name" + FILENAME_STR = "filename" + SET_NAME_STR = "set_name" + OLD_FORMAT_STR = "old_format" + + def __init__( + self, key_fn: Callable[[Dict[str, Any]], str] = default_metadata_key_fn + ): + """Create a new SaveMetadata. + + Parameters + ---------- + key_fn: Callable[[Dict[str, Any]], str] + The function to use to get the key from the data dict. + """ + super().__init__() + self.key_fn = key_fn + + @abc.abstractmethod + def clear(self): + """Clear the metadata.""" + + @abc.abstractmethod + def add( + self, + data_dict: Dict[str, Any], + filepath: str, + feature_name: Optional[str], + set_name: Optional[str], + ): + """Add a metadata entry. + + Parameters + ---------- + data_dict: Dict[str, Any] + The data dict containing the data to save. + filepath: str + The path to the data. + feature_name: Optional[str] + The name of the feature. + set_name: Optional[str] + The name of the set. + """ + + @abc.abstractmethod + def __contains__(self, item): + """Check if the metadata contains a certain item. + + Parameters + ---------- + item: Any + The item to check. + + Returns + ------- + bool + Whether the item is contained. + """ + + @abc.abstractmethod + def __getitem__(self, key: Any): + """Retrieve a metadata item. + + Parameters + ---------- + key: Any + The key to retrieve. + + Returns + ------- + Dict[str, Any] + The metadata item. + """ + + @classmethod + def is_old_format(cls, metadata_item: Union[str, Dict[str, Any]]): + """Check if the metadata item is in the old format. + + Parameters + ---------- + metadata_item: Union[str, Dict[str, Any]] + The metadata item to check. + + Returns + ------- + bool + Whether the metadata item is in the old format. + """ + return isinstance(metadata_item, str) or cls.OLD_FORMAT_STR in metadata_item + + +class OldMetadataCompliant(abc.ABC): + """Mixin class for metadata that is compliant with the old format.""" + + @abc.abstractmethod + def convert_old_format(self, metadata_item: str, data_dict: Dict[str, Any]): + """Convert the metadata item from the old format. + + Parameters + ---------- + metadata_item: str + The metadata item to convert. + + data_dict: Dict[str, Any] + The data dict containing the data to save. + + Returns + ------- + Dict[str, Any] + The converted metadata item. 
+        """
+
+
+class DefaultSaveMetadata(OldMetadataCompliant, SaveMetadata):
+    """Implementation of SaveMetadata to work with DefaultSave."""
+
+    def __init__(
+        self,
+        key_fn: Callable[[Dict[str, Any]], str] = default_metadata_key_fn,
+        filename: str = ".save_metadata.json",
+    ):
+        """Create a new DefaultSaveMetadata.
+
+        Parameters
+        ----------
+        key_fn: Callable[[Dict[str, Any]], str]
+            The function to use to get the key from the data dict.
+        filename: str
+            The filename to use for the metadata.
+        """
+        super().__init__(key_fn=key_fn)
+        self.filename = filename
+
+    def get_path(self):
+        """Get the path to the metadata file.
+
+        Returns
+        -------
+        str
+            The path to the metadata file.
+        """
+        if self.saver is not None and isinstance(self.saver, DefaultSave):
+            path = os.path.join(self.saver.root_dir, self.filename)
+            return path
+        return self.filename
+
+    def get_relpath(self, path: str):
+        """Construct a relative path with regard to the save folder.
+
+        Parameters
+        ----------
+        path: str
+            The path to make relative. If no saver is attached, this is returned
+            as is.
+
+        Returns
+        -------
+        str
+            The relative path.
+        """
+        if (
+            self.saver is not None
+            and isinstance(self.saver, DefaultSave)
+            and os.path.isabs(path)
+        ):
+            return os.path.relpath(path, self.saver.root_dir)
+        return path
+
+    @property
+    def lock(self):
+        """Retrieve the lock to use for the metadata file.
+
+        Returns
+        -------
+        multiprocessing.Lock
+            The lock to use for the metadata file.
+        """
+        return MultiprocessingSingleton.get_lock(self.get_path())
+
+    def clear(self):
+        """Clear the metadata."""
+        self.lock.acquire()
+        metadata_path = self.get_path()
+        if os.path.exists(metadata_path):
+            os.remove(metadata_path)
+        self.lock.release()
+
+    def get_metadata_for_savepath(
+        self,
+        path: str,
+        feature_name: Optional[str],
+        set_name: Optional[str],
+        from_old_format=False,
+    ):
+        """Get the metadata associated with a path where data is saved.
+
+        Parameters
+        ----------
+        path: str
+            The path to the data.
+        feature_name: Optional[str]
+            The name of the feature.
+        set_name: Optional[str]
+            The name of the set.
+        from_old_format: bool
+            Whether the metadata is in the old format.
+
+        Returns
+        -------
+        Dict[str, Any]
+            The metadata associated with the path.
+        """
+        metadata = {
+            self.FILENAME_STR: self.get_relpath(path),
+            self.FEATURE_NAME_STR: feature_name,
+            self.SET_NAME_STR: set_name,
+        }
+        if from_old_format:
+            metadata[self.OLD_FORMAT_STR] = True
+        return metadata
+
+    def convert_old_format(self, metadata_item: str, data_dict: Dict[str, Any]):
+        """Convert the metadata item from the old format.
+
+        Parameters
+        ----------
+        metadata_item: str
+            The metadata item to convert.
+        data_dict: Dict[str, Any]
+            The data dict containing the data to save.
+
+        Returns
+        -------
+        Dict[str, Any]
+            The converted metadata item.
+        """
+        return self.get_metadata_for_savepath(metadata_item, None, None, True)
+
+    def get(self):
+        """Load the metadata.
+
+        Returns
+        -------
+        Dict[str, Any]
+            The metadata.
+        """
+        metadata_path = self.get_path()
+        if not os.path.exists(metadata_path):
+            return {}
+        self.lock.acquire()
+        with open(metadata_path) as fp:
+            metadata = json.load(fp)
+        self.lock.release()
+        return metadata
+
+    def add(
+        self,
+        data_dict: Dict[str, Any],
+        filepath: str,
+        feature_name: Optional[str],
+        set_name: Optional[str],
+    ):
+        """Add metadata for a file.
+
+        Parameters
+        ----------
+        data_dict: Dict[str, Any]
+            The data dictionary.
+        filepath: str
+            The path to the file.
+        feature_name: Optional[str]
+            The name of the feature.
+        set_name: Optional[str]
+            The name of the set.
+        """
+        metadata = self.get()
+        key = self.key_fn(data_dict)
+        if key not in metadata:
+            metadata[key] = []
+        all_filepaths = wrap_in_list(filepath)
+        for path in all_filepaths:
+            metadata_for_savepath = self.get_metadata_for_savepath(
+                path, feature_name, set_name
             )
-            if os.path.exists(path):
-                self.reloadable = path
-                return self.saver.STOP_ITERATION
+            if metadata_for_savepath not in metadata[key]:
+                metadata[key] += [metadata_for_savepath]
+        self.write(metadata)
+
+    def write(self, metadata_dict: Dict[str, Any]):
+        """Write the metadata to disk.
+
+        Parameters
+        ----------
+        metadata_dict: Dict[str, Any]
+            A dictionary containing the metadata.
+        """
+        self.lock.acquire()
+        with open(self.get_path(), "w") as fp:
+            json.dump(metadata_dict, fp)
+        self.lock.release()
 
-    def clear(self, expected_filename, *args, **kwargs):
-        """Clear the state of the CheckFunctor.
+    def __contains__(self, item: Any):
+        """Check if the metadata contains a certain item.
 
         Parameters
         ----------
-        expected_filename: str
-            The expected filename of the data.
-        args: List[Any]
-            Additional positional arguments.
-        kwargs: Dict[str, Any]
-            Additional keyword arguments.
+        item: Any
+            The item to check.
+
+        Returns
+        -------
+        bool
+            Whether the item is contained.
         """
-        self.reloadable = None
-        self.expected_filename = expected_filename
+        return item in self.get()
+
+    def __getitem__(self, key: Any):
+        """Retrieve a metadata item.
+
+        Parameters
+        ----------
+        key: Any
+            The key to retrieve.
+
+        Returns
+        -------
+        Any
+            The metadata item.
+        """
+        return self.get()[key]
 
 
 class DefaultSave(Save):
@@ -241,14 +614,11 @@ class DefaultSave(Save):
     between an unprocessed input filename and multiple possible output filenames.
     """
 
-    lock = MultiprocessingSingleton.manager.Lock()
-
     DEFAULT_SAVE_FUNCTIONS = {
         "npy": np.save,
         "pickle": pickle_dump_wrapper,
         "data_dict": pickle_dump_wrapper,
    }
-
     DEFAULT_RELOAD_FUNCTIONS = {
         "npy": np.load,
         "npz": np.load,
@@ -256,28 +626,20 @@
         "pickle": pickle_load_wrapper,
         "data_dict": pickle_load_wrapper,
     }
 
-    FEATURE_NAME_STR = "feature_name"
-    FILENAME_STR = "filename"
-    SET_NAME_STR = "set_name"
-    OLD_FORMAT_STR = "old_format"
-
-    STOP_ITERATION = "stop_iteration"
-
     _metadata_deprecation_warning_logged = False
 
     def __init__(
         self,
-        root_dir,
-        to_save=None,
-        overwrite=False,
-        clear_output=False,
-        filename_fn=default_filename_fn,
-        save_fn=None,
-        reload_fn=None,
-        metadata_filename=".save_metadata.json",
-        metadata_key_fn=default_metadata_key_fn,
-        check_done: IsDoneCheck = IsDoneCheck(),
-        check_reloadable: IsReloadableCheck = IsReloadableCheck(),
+        root_dir: str,
+        to_save: Optional[Mapping[str, Any]] = None,
+        overwrite: bool = False,
+        clear_output: bool = False,
+        filename_fn: FilenameFnInterface = DefaultFilenameFn(),
+        metadata: SaveMetadata = DefaultSaveMetadata(),
+        save_fn: SaveFnInterface = None,
+        reload_fn: ReloadFnInterface = None,
+        check_done: Optional[CheckInterface] = IsDoneCheck(),
+        check_reloadable: Optional[CheckInterface] = IsReloadableCheck(),
    ):
         """Create a Save step.
 
@@ -294,32 +656,28 @@ def __init__(
         clear_output: bool
             Whether to clear the output data_dict after saving. This can save
             space when save is the last step in a pipeline.
-        filename_fn: Callable[[Dict[str, Any], Optional[str], Optional[str], str], str]
+        filename_fn: FilenameFnInterface
             A function to generate a filename for the data.
-            The function should take the data_dict, the feature name, the set
-            name and a separator as input and return a filename.
+            The function should take the data_dict, the feature name and the
+            set name as input and return a filename.
-        save_fn: Union[Callable[[Any, str], None], Mapping[str, Callable[[Any, str], None]], None]  # noqa: E501
+        save_fn: SaveFnInterface
             A function to save the data. The function should take the data and
             the filepath as inputs and save the data. If a mapping between file
             extensions and functions is given, the function corresponding to
             the file extension is used to save the data. If None, the default
             save functions (defined in self.DEFAULT_SAVE_FUNCTIONS) are used.
-        reload_fn: Union[Callable[[str], Any], Mapping[str, Callable[[str], Any]], None]
+        reload_fn: ReloadFnInterface
             A function to reload the data. The function should take the filepath
             as input and return the data. If a mapping between file extensions
             and functions is given, the function corresponding to the file
             extension is used to reload the data. If None, the default reload
             functions (defined in self.DEFAULT_RELOAD_FUNCTIONS) are used.
-        metadata_filename: str
-            The filename of the metadata file.
-        metadata_key_fn: Callable[[Dict[str, Any]], str]
-            A function to generate a key for the metadata. The function should take
-            the data_dict as input and return a key. This key will be used to check
-            whether the data has already been saved.
-        check_done: IsDoneCheck
-            A functor to check whether the data has already been saved.
-        check_reloadable: IsReloadableCheck
-            A functor to check whether the data can be reloaded.
+        check_done: Optional[CheckInterface]
+            A functor to check whether the data has already been saved. If None,
+            no checking is done.
+        check_reloadable: Optional[CheckInterface]
+            A functor to check whether the data can be reloaded. If None, no
+            checking is done.
         """
         super().__init__(clear_output=clear_output, overwrite=overwrite)
         self.root_dir = root_dir
@@ -331,12 +689,29 @@ def __init__(
         self.reload_fn = reload_fn
         if self.reload_fn is None:
             self.reload_fn = self.DEFAULT_RELOAD_FUNCTIONS
-        self.metadata_filename = metadata_filename
-        self.metadata_key_fn = metadata_key_fn
-        self.check_done = check_done
-        self.check_reload = check_reloadable
-        self.check_done.saver = self
-        self.check_reload.saver = self
+        self.check_done = self._attach_saver(check_done)
+        self.check_reloadable = self._attach_saver(check_reloadable)
+        self.metadata = self._attach_saver(metadata)
+
+    def _attach_saver(self, check: Optional[AttachSave]):
+        """Attach this :class:`Save` object to an :class:`AttachSave` object.
+
+        Parameters
+        ----------
+        check: Optional[AttachSave]
+            An optional object to attach this :class:`Save` object to.
+            Note that the object should implement the :meth:`attach_saver`
+            method, which will return a copy of the :class:`AttachSave` object
+            with the :class:`Save` object attached.
+
+        Returns
+        -------
+        Optional[AttachSave]
+            A copy of ``check`` with this saver attached, or ``check`` unchanged
+            if it does not support attaching.
+        """
+        if hasattr(check, "attach_saver"):
+            return check.attach_saver(self)
+        return check
 
     @property
     def overwrite(self):
@@ -359,13 +734,9 @@ def overwrite(self, value):
             Whether to overwrite existing files.
         """
         self._overwrite = value
-        if self._overwrite:
-            self._clear_metadata()
-
-    def _single_obj_to_list(self, obj):
-        if not isinstance(obj, (list, tuple)):
-            obj = [obj]
-        return obj
+        # Clear the metadata if overwrite is True and it has been initialized
+        if self._overwrite and hasattr(self, "metadata"):
+            self.metadata.clear()
 
     def is_already_done(self, data_dict):
         """Check whether the data_dict has already been saved.
@@ -381,89 +752,55 @@
             Whether the data_dict has already been saved. This will be checked
             in the stored metadata.
         """
+        # If overwrite is True, the data is never done
         if self.overwrite:
             return False
 
-        metadata = self._get_metadata()
-        key = self.metadata_key_fn(data_dict)
-        if key not in metadata:
+        # If the key is not in the metadata, the data is not done
+        key = self.metadata.key_fn(data_dict)
+        if key not in self.metadata:
             return False
 
-        self.check_done.clear()
-        self._iterate_over_metadata_item(metadata[key], self.check_done)
-
-        return all(self.check_done.is_done)
+        # If the key is in the metadata, check whether the data is done
+        results = self._iterate_over_metadata_item(
+            self.metadata[key], self.check_done, data_dict
+        )
+        # All items should be done for it to be considered done
+        return len(results) > 0 and all(results)
 
-    def _iterate_over_metadata_item(self, metadata_item, callback):
-        file_infos = self._single_obj_to_list(metadata_item)
+    def _iterate_over_metadata_item(self, metadata_item, callback, data_dict):
+        # Make a list of metadata items
+        file_infos = wrap_in_list(metadata_item)
+        # Select the feature names for this Save step
         feature_names = [None] if self.to_save is None else self.to_save.keys()
+        # Iterate over all file_infos to check them
+        results = []
         for info in file_infos:
-            if isinstance(info, str):
+            # Info should be a dict, but it can also be a string if the data was
+            # saved with the old metadata format (version <= 0.0.2).
+            if self.metadata.is_old_format(info):
                 # Only log this if it hasn't been logged before.
                 if not self._metadata_deprecation_warning_logged:
                     logging.warning(
                         "Found previously saved data with the old metadata format "
                         "(version <= 0.0.2). DefaultSave will attempt to reload the "
                         "data, but it is recommended to delete the old data and "
-                        f"{os.path.join(self.root_dir, self.metadata_filename)} if "
-                        f"possible."
+                        "metadata file if possible."
                     )
                     self._metadata_deprecation_warning_logged = True
-                info = {
-                    self.FILENAME_STR: info,
-                    self.FEATURE_NAME_STR: None,
-                    self.SET_NAME_STR: None,
-                    self.OLD_FORMAT_STR: True,
-                }
+                if isinstance(self.metadata, OldMetadataCompliant):
+                    info = self.metadata.convert_old_format(info, data_dict)
+                else:
+                    raise ValueError(
+                        "The metadata class used is not compatible with the old "
+                        "format. Please delete the old data and metadata file."
+                    )
             for feature_name in feature_names:
-                status = callback(info, feature_name)
-                # If the callback returns STOP_ITERATION, we stop iterating over
-                # the metadata immediately.
-                if status == self.STOP_ITERATION:
-                    return
-
-    def _clear_metadata(self):
-        self.lock.acquire()
-        if not hasattr(self, "root_dir"):
-            # If the instance is not initialized
-            self.lock.release()
-            return
-        metadata_path = os.path.join(self.root_dir, self.metadata_filename)
-        if os.path.exists(metadata_path):
-            os.remove(metadata_path)
-        self.lock.release()
-
-    def _get_metadata(self):
-        metadata_path = os.path.join(self.root_dir, self.metadata_filename)
-        if not os.path.exists(metadata_path):
-            return {}
-        self.lock.acquire()
-        with open(metadata_path) as fp:
-            metadata = json.load(fp)
-        self.lock.release()
-        return metadata
-
-    def _add_metadata(self, data_dict, filepath, feature_name, set_name):
-        metadata = self._get_metadata()
-        key = self.metadata_key_fn(data_dict)
-        if key not in metadata:
-            metadata[key] = []
-        all_filepaths = self._single_obj_to_list(filepath)
-        for path in all_filepaths:
-            metadata_for_path = {
-                self.FILENAME_STR: os.path.relpath(path, self.root_dir),
-                self.FEATURE_NAME_STR: feature_name,
-                self.SET_NAME_STR: set_name,
-            }
-            if metadata_for_path not in metadata[key]:
-                metadata[key] += [metadata_for_path]
-        self._write_metadata(metadata)
-
-    def _write_metadata(self, metadata):
-        metadata_path = os.path.join(self.root_dir, self.metadata_filename)
-        self.lock.acquire()
-        with open(metadata_path, "w") as fp:
-            json.dump(metadata, fp)
-        self.lock.release()
+                item = callback(info, feature_name, data_dict)
+                # If the callback returns None, the check is skipped because it is
+                # assumed that the check is not applicable.
+                if item is not None:
+                    results.append(item)
+        return results
 
     def _serialization_wrapper(self, fn, filepath, *args, action="save", **kwargs):
         if not isinstance(fn, dict):
@@ -482,7 +819,7 @@ def _apply_to_data(self, data_dict, fn):
         if self.to_save is None:
             path = os.path.join(self.root_dir, self.filename_fn(data_dict, None, None))
             self._serialization_wrapper(fn, path, data_dict, action="save")
-            self._add_metadata(data_dict, [path], None, None)
+            self.metadata.add(data_dict, [path], None, None)
             return
 
         # Save singular features
@@ -496,15 +833,15 @@
                 path = os.path.join(self.root_dir, filename)
                 self._serialization_wrapper(fn, path, set_data, action="save")
                 paths += [path]
-                self._add_metadata(data_dict, paths, feature_name, set_name)
+                self.metadata.add(data_dict, paths, feature_name, set_name)
 
             # In full
             else:
-                filename = self.filename_fn(data_dict, feature_name)
+                filename = self.filename_fn(data_dict, feature_name, None)
                 path = os.path.join(self.root_dir, filename)
                 self._serialization_wrapper(fn, path, data, action="save")
                 paths += [path]
-                self._add_metadata(data_dict, paths, feature_name, None)
+                self.metadata.add(data_dict, paths, feature_name, None)
 
     def is_reloadable(self, data_dict: Dict[str, Any]) -> bool:
         """Check whether an already processed data_dict can be reloaded.
@@ -520,9 +857,8 @@
             Whether an already processed data_dict can be reloaded to continue
             processing.
""" - metadata = self._get_metadata() - key = self.metadata_key_fn(data_dict) - if key not in metadata: + key = self.metadata.key_fn(data_dict) + if key not in self.metadata: return False # No support to reload singular features @@ -533,13 +869,10 @@ def is_reloadable(self, data_dict: Dict[str, Any]) -> bool: if self.overwrite: return False - expected_filename = os.path.relpath( - self.filename_fn(data_dict, None, None), self.root_dir + is_reloadable = self._iterate_over_metadata_item( + self.metadata[key], self.check_reloadable, data_dict ) - self.check_reload.clear(expected_filename) - self._iterate_over_metadata_item(metadata[key], self.check_reload) - - return self.check_reload.reloadable is not None + return len(is_reloadable) and any(is_reloadable) def reload(self, data_dict: Dict[str, Any]) -> Dict[str, Any]: """Reload the data_dict from the saved file. @@ -554,19 +887,17 @@ def reload(self, data_dict: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any] The reloaded data_dict. """ - metadata = self._get_metadata() - key = self.metadata_key_fn(data_dict) - expected_filename = os.path.relpath( - self.filename_fn(data_dict, None, None), self.root_dir + key = self.metadata.key_fn(data_dict) + paths = self._iterate_over_metadata_item( + self.metadata[key], self.check_reloadable, data_dict ) - self.check_reload.clear(expected_filename) - self._iterate_over_metadata_item(metadata[key], self.check_reload) - if self.check_reload.reloadable is None: + if not (len(paths) and any(paths)): raise ValueError("Didn't find any file that can be reloaded.") + selected_path = [p for p in paths if p][0] return self._serialization_wrapper( self.reload_fn, - self.check_reload.reloadable, + selected_path, action="reload", ) diff --git a/brain_pipe/utils/list.py b/brain_pipe/utils/list.py index 7197b70..10c6370 100644 --- a/brain_pipe/utils/list.py +++ b/brain_pipe/utils/list.py @@ -1,17 +1,18 @@ """List utilities.""" +from typing import Union, Any, List, Tuple -def flatten(lst): - """Flatten a list. +def flatten(lst: Union[List[Any], Tuple[Any]]) -> List[Any]: + """Flatten a list or tuple (recursively). Parameters ---------- - lst: Union[list, tuple] - A list to be flattened. + lst: Union[List[Any], Tuple[Any]] + A list to be flattened. Sublists and tuples are also flattened. Returns ------- - list + List[Any] A flattened list. """ result = [] @@ -21,3 +22,22 @@ def flatten(lst): else: result.append(lst) return result + + +def wrap_in_list(obj: Any) -> List[Any]: + """Wrap an object in a list if it is not already a list. + + Parameters + ---------- + obj: Any + An object to be wrapped in a list. If it is already a list, it is returned + as is. + + Returns + ------- + List[Any] + A list containing the object. + """ + if not isinstance(obj, (list, tuple)): + obj = [obj] + return obj diff --git a/brain_pipe/utils/multiprocess.py b/brain_pipe/utils/multiprocess.py index 5a2690c..750a134 100644 --- a/brain_pipe/utils/multiprocess.py +++ b/brain_pipe/utils/multiprocess.py @@ -46,6 +46,7 @@ class MultiprocessingSingleton: """Singleton class for multiprocessing.""" manager = multiprocessing.Manager() + locks = {} to_clean = [] @@ -115,3 +116,21 @@ def clean(cls): pool.close() pool.join() cls.to_clean = [] + + @classmethod + def get_lock(cls, id_str): + """Create or get a lock for multiprocessing. + + Parameters + ---------- + id_str: str + Identifier for the lock. If the lock does not already exist in self.locks, + it will be created and added to self.locks. 
+
+        Returns
+        -------
+        multiprocessing.Lock
+            The lock associated with the given identifier.
+        """
+        if id_str not in cls.locks:
+            cls.locks[id_str] = cls.manager.Lock()
+        return cls.locks[id_str]
diff --git a/setup.cfg b/setup.cfg
index 8ade7df..3a2361a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [flake8]
 # black default
 max-line-length = 88
-ignore = E203
+ignore = E203,W503
 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,bids_preprocessing/utils,bids_preprocessing/tests,bids_preprocessing/preprocessing/cache
\ No newline at end of file
diff --git a/tests/test_data/truncated_sparrKULee/derivatives/stimuli/audiobook_7_5.data_dict.gz b/tests/test_data/truncated_sparrKULee/derivatives/stimuli/audiobook_7_5.data_dict.gz
index 9789ff4..2abfdc9 100644
Binary files a/tests/test_data/truncated_sparrKULee/derivatives/stimuli/audiobook_7_5.data_dict.gz and b/tests/test_data/truncated_sparrKULee/derivatives/stimuli/audiobook_7_5.data_dict.gz differ
diff --git a/tests/test_save/test_default.py b/tests/test_save/test_default.py
index 46ed76f..6773e4f 100644
--- a/tests/test_save/test_default.py
+++ b/tests/test_save/test_default.py
@@ -9,8 +9,9 @@
 from brain_pipe.save.default import (
     default_metadata_key_fn,
-    default_filename_fn,
+    DefaultFilenameFn,
     DefaultSave,
+    DefaultSaveMetadata,
 )
 
 from brain_pipe.utils.serialization import pickle_load_wrapper
@@ -30,16 +31,17 @@ class DefaultFilenameFnTest(unittest.TestCase):
     def test_defaultFilenameFn(self):
         self.assertEqual(
-            default_filename_fn({"data_path": "a"}, "b", "c", "_-_"), "c_-_a_-_b.npy"
+            DefaultFilenameFn(path_keys=["a", "b", "c"])(
+                {"a": "c", "b": "a"}, "b", "d"
+            ),
+            "d_-_c_-_a_-_b.npy",
         )
         self.assertEqual(
-            default_filename_fn(
-                {"data_path": "a", "stimulus_path": "d"}, None, None, "_"
-            ),
+            DefaultFilenameFn(separator="_")({"data_path": "a", "stimulus_path": "d"}),
             "a_d.data_dict",
         )
         self.assertEqual(
-            default_filename_fn(
+            DefaultFilenameFn(separator="|")(
                 {
                     "data_path": "a",
                     "stimulus_path": "d",
@@ -47,7 +49,6 @@ def test_defaultFilenameFn(self):
                     "event_info": {"snr": 123.123},
                 },
                 "b",
                 "c",
-                "|",
             ),
             "c|a|d|123.123|b.npy",
         )
@@ -75,10 +76,11 @@
     def setUp(self) -> None:
         self.tmp_dir = tempfile.TemporaryDirectory()
 
     def test_is_already_done(self):
+        metadata_path = os.path.join(self.tmp_dir.name, ".save_metadata.json")
         saver = DefaultSave(
             self.tmp_dir.name,
             filename_fn=self.MockupFilenameFn(),
-            metadata_key_fn=self.MockupMetadataKeyFn(),
+            metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
         )
         path = os.path.join(self.tmp_dir.name, "b")
         temp_dict = {"output_filename": path, "metadata_key": "d"}
 
         self.assertFalse(saver.is_already_done(temp_dict))
 
         # Make manually sure the file is already done/there.
-        with open(os.path.join(self.tmp_dir.name, ".save_metadata.json"), "w") as f:
+        with open(metadata_path, "w") as f:
             json.dump({"d": os.path.basename(path)}, f)
 
         with open(path, "w") as f:
@@ -115,7 +117,7 @@ def test_is_reloadable(self):
         saver = DefaultSave(
             self.tmp_dir.name,
             filename_fn=self.MockupFilenameFn(),
-            metadata_key_fn=self.MockupMetadataKeyFn(),
+            metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
         )
         path = os.path.join(self.tmp_dir.name, "b")
         temp_dict = {"output_filename": path, "metadata_key": "d"}
@@ -136,7 +138,7 @@
         saver = DefaultSave(
             self.tmp_dir.name,
             filename_fn=self.MockupFilenameFn(),
-            metadata_key_fn=self.MockupMetadataKeyFn(),
+            metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
             overwrite=True,
         )
 
         # When overwrite is True, it should always return False
@@ -154,7 +156,7 @@ def test_reload(self):
         saver = DefaultSave(
             self.tmp_dir.name,
             filename_fn=self.MockupFilenameFn(),
-            metadata_key_fn=self.MockupMetadataKeyFn(),
+            metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
         )
         path = os.path.join(self.tmp_dir.name, "b")
         temp_dict = {"output_filename": path, "metadata_key": "d"}
@@ -200,7 +202,7 @@ def test_call(self):
         saver = DefaultSave(
             self.tmp_dir.name,
             filename_fn=self.MockupFilenameFn(),
-            metadata_key_fn=self.MockupMetadataKeyFn(),
+            metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
         )
         data_dict = {"output_filename": "a.pickle", "metadata_key": "b", "data": 123}
         output = saver(data_dict)
@@ -221,7 +223,7 @@ def test_to_save(self):
         saver = DefaultSave(
             self.tmp_dir.name,
             filename_fn=self.MockupFilenameFn(),
-            metadata_key_fn=self.MockupMetadataKeyFn(),
+            metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
             to_save={"envelope": "data"},
         )
         data_dict = {"output_filename": "a.npy", "metadata_key": "b", "data": 123}
@@ -302,7 +304,7 @@ def test_clear_output(self):
         saver = DefaultSave(
             self.tmp_dir.name,
             filename_fn=self.MockupFilenameFn(),
-            metadata_key_fn=self.MockupMetadataKeyFn(),
+            metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
             clear_output=True,
         )
         self.assertEqual(saver(data_dict), {})
 
@@ -311,15 +313,23 @@ def test_multiple_savers(self):
         saver1 = DefaultSave(
             self.tmp_dir.name,
             filename_fn=self.MockupFilenameFn(),
-            metadata_key_fn=self.MockupMetadataKeyFn(),
+            metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
         )
         saver2 = DefaultSave(
             self.tmp_dir.name,
             to_save={"envelope": "data"},
             filename_fn=self.MockupFilenameFn(output_override="a.npy"),
-            metadata_key_fn=self.MockupMetadataKeyFn(),
+            metadata=DefaultSaveMetadata(key_fn=self.MockupMetadataKeyFn()),
         )
+        # Do check_done, check_reloadable and metadata have the correct savers attached?
+        self.assertEqual(saver1.check_done.saver, saver1)
+        self.assertEqual(saver2.check_done.saver, saver2)
+        self.assertEqual(saver1.check_reloadable.saver, saver1)
+        self.assertEqual(saver2.check_reloadable.saver, saver2)
+        self.assertEqual(saver1.metadata.saver, saver1)
+        self.assertEqual(saver2.metadata.saver, saver2)
+
         data_dict = {
             "output_filename": "a.data_dict",
             "metadata_key": "b",
diff --git a/tests/test_utils/test_list.py b/tests/test_utils/test_list.py
new file mode 100644
index 0000000..9912b19
--- /dev/null
+++ b/tests/test_utils/test_list.py
@@ -0,0 +1,16 @@
+import unittest
+
+from brain_pipe.utils.list import flatten, wrap_in_list
+
+
+class ListUtilsTest(unittest.TestCase):
+    def test_flatten(self):
+        flat_list = flatten([1, [2, [3, 4, [5], 6]]])
+        self.assertEqual(flat_list, [1, 2, 3, 4, 5, 6])
+
+    def test_wrap_in_list(self):
+        self.assertEqual(wrap_in_list(1), [1])
+        self.assertEqual(wrap_in_list([1]), [1])
+        self.assertEqual(wrap_in_list((1, 2)), (1, 2))
+        self.assertEqual(wrap_in_list([]), [])
+        self.assertEqual(wrap_in_list(None), [None])
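
Usage sketch (editor's addendum, not part of the patch): the snippet below shows how the pieces introduced above fit together once the patch is applied. The data_dict values and paths are invented for illustration, and the saver(data_dict) call relies on Save.__call__ from brain_pipe.save.base, as exercised by the tests in this diff.

# A minimal sketch of the refactored save API, assuming this patch is applied.
import tempfile

import numpy as np

from brain_pipe.save.default import (
    DefaultFilenameFn,
    DefaultSave,
    DefaultSaveMetadata,
)

root_dir = tempfile.mkdtemp()

saver = DefaultSave(
    root_dir,
    # Save data_dict["data"] under the feature name "envelope" (as in
    # test_to_save above).
    to_save={"envelope": "data"},
    # Filename options (keys, separator, extensions) now live on the filename
    # functor instead of being passed to every call.
    filename_fn=DefaultFilenameFn(path_keys=("data_path",), separator="_-_"),
    # Metadata bookkeeping (key_fn, JSON file, locking) now lives in its own
    # object instead of DefaultSave's old metadata_filename/metadata_key_fn.
    metadata=DefaultSaveMetadata(),
)

# The data_path below is made up; only its basename is used for the filename.
data_dict = {"data_path": "/tmp/sub-01_task-rest.fif", "data": np.zeros(8)}
saver(data_dict)  # writes sub-01_task-rest_-_envelope.npy and the metadata JSON

# The attached checks consult the metadata object on subsequent runs:
assert saver.is_already_done(data_dict)
assert saver.metadata.saver is saver  # attach_saver bound a copy to this saver

The last assertion is the design point of AttachSave: the default IsDoneCheck(), IsReloadableCheck() and DefaultSaveMetadata() arguments are instantiated once, when DefaultSave.__init__ is defined, so DefaultSave._attach_saver deep-copies them before binding the saver. That is why test_multiple_savers can assert that saver1 and saver2 each see their own saver even though both were built from the same default instances.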