diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index cfab3fd..fc132d3 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -20,12 +20,13 @@ jobs:
         run: black --check .
 
   test:
-    runs-on: ubuntu-latest
     strategy:
       fail-fast: false
       matrix:
         python: ["3.8", "3.9", "3.10", "3.11"]
+        os: [ubuntu-latest, windows-latest, macos-latest]
+    runs-on: ${{ matrix.os }}
 
     steps:
       - uses: actions/checkout@v3
      - name: Setup Python
diff --git a/brain_pipe/preprocessing/brain/link.py b/brain_pipe/preprocessing/brain/link.py
index 152012d..e63a1e2 100644
--- a/brain_pipe/preprocessing/brain/link.py
+++ b/brain_pipe/preprocessing/brain/link.py
@@ -158,8 +158,8 @@ def default_multiprocessing_key_fn(data_dict):
 class LinkStimulusToBrainResponse(PipelineStep):
     """Link stimulus to Brain data."""
 
-    multiprocessing_dict = MultiprocessingSingleton.manager.dict()
-    multiprocessing_condition = MultiprocessingSingleton.manager.Condition()
+    _multiprocessing_dict = None
+    _multiprocessing_condition = None
 
     def __init__(
         self,
@@ -233,24 +233,39 @@ def __call__(self, data_dict: Dict[str, Any]) -> Dict[str, Any]:
         for stim_info in stimulus_info_from_brain:
             prototype_stim_dict = self.grouper(stim_info)
             key = self.key_fn_for_multiprocessing(prototype_stim_dict)
-            with self.multiprocessing_condition:
+            with self.get_multiprocessing_condition():
                 # Check if no other processes are already running this
-                while key in self.multiprocessing_dict:
+                while key in self.get_multiprocessing_dict():
                     # Wait for the process to finish
-                    self.multiprocessing_condition.wait()
-                self.multiprocessing_dict[key] = True
+                    self.get_multiprocessing_condition().wait()
+                self.get_multiprocessing_dict()[key] = True
             try:
                 stimulus_dicts = self.stimulus_data(prototype_stim_dict)
             finally:
                 # Notify all waiting processes of that this is done
-                with self.multiprocessing_condition:
+                with self.get_multiprocessing_condition():
                     # Remove the key from the multiprocessing dict to signal that
                     # this specific stimulus is processed
-                    del self.multiprocessing_dict[key]
-                    self.multiprocessing_condition.notify_all()
+                    del self.get_multiprocessing_dict()[key]
+                    self.get_multiprocessing_condition().notify_all()
             if isinstance(stimulus_dicts, dict):
                 stimulus_dicts = [stimulus_dicts]
             all_stimuli += stimulus_dicts
 
         data_dict[self.stimuli_key] = all_stimuli
         return data_dict
+
+    @classmethod
+    def get_multiprocessing_dict(cls):
+        """Get the multiprocessing dict."""
+        if cls._multiprocessing_dict is None:
+            cls._multiprocessing_dict = MultiprocessingSingleton.get_manager().dict()
+        return cls._multiprocessing_dict
+
+    @classmethod
+    def get_multiprocessing_condition(cls):
+        """Get the multiprocessing condition."""
+        if cls._multiprocessing_condition is None:
+            cls._multiprocessing_condition = MultiprocessingSingleton.get_manager().Condition()
+        return cls._multiprocessing_condition
+
diff --git a/brain_pipe/utils/parallellization.py b/brain_pipe/utils/parallellization.py
index 803e13e..36a50d0 100644
--- a/brain_pipe/utils/parallellization.py
+++ b/brain_pipe/utils/parallellization.py
@@ -45,7 +45,7 @@ def __call__(self, result):
 class MultiprocessingSingleton:
     """Singleton class for multiprocessing."""
 
-    manager = multiprocess.Manager()
+    _manager = None
 
     locks = {}
     to_clean = []
@@ -132,5 +132,15 @@ def get_lock(cls, id_str):
         multiprocessing.Lock
         """
         if id_str not in cls.locks:
-            cls.locks[id_str] = cls.manager.Lock()
+            cls.locks[id_str] = cls.get_manager().Lock()
         return cls.locks[id_str]
+
+    @classmethod
+    def get_manager(cls):
+        """Get the multiprocessing manager."""
+        if cls._manager is None:
+            cls._manager = multiprocess.Manager()
+        return cls._manager
+
+
+
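For reference, below is a minimal standalone sketch of the lazy-initialization pattern these hunks introduce, mirroring the patched MultiprocessingSingleton and using the multiprocess package the module already imports. The rationale stated here is an assumption, not spelled out in the patch: constructing multiprocess.Manager() in the class body starts a manager process at import time, which is fragile under the spawn start method that Windows and macOS (the platforms newly added to the CI matrix) use by default. The __main__ demo is illustrative only and not part of the patch.

# Sketch (not part of the patch): a lazily created multiprocessing manager.
# Assumption: deferring multiprocess.Manager() to first use avoids spawning a
# manager process at import time, which matters on spawn-based platforms.
import multiprocess


class MultiprocessingSingleton:
    """Singleton holding a lazily created multiprocess manager."""

    _manager = None
    locks = {}

    @classmethod
    def get_manager(cls):
        """Create the shared manager on first use and reuse it afterwards."""
        if cls._manager is None:
            cls._manager = multiprocess.Manager()
        return cls._manager

    @classmethod
    def get_lock(cls, id_str):
        """Return one shared lock per identifier, backed by the manager."""
        if id_str not in cls.locks:
            cls.locks[id_str] = cls.get_manager().Lock()
        return cls.locks[id_str]


if __name__ == "__main__":
    # Importing the class starts no extra process; the manager only appears
    # once get_manager()/get_lock() is actually called.
    with MultiprocessingSingleton.get_lock("demo"):
        print("manager started lazily")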