Skip to content

Commit

Permalink
Lazy load the multiprocessing code for Windows (test also on Windows)
Browse files Browse the repository at this point in the history
  • Loading branch information
berndie committed Aug 29, 2024
1 parent 187aac0 commit 360f0eb
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 12 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ jobs:
run: black --check .

test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python: ["3.8", "3.9", "3.10", "3.11"]
os: [ubuntu-latest, windows-latest, macos-latest]

runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Setup Python
Expand Down
33 changes: 24 additions & 9 deletions brain_pipe/preprocessing/brain/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def default_multiprocessing_key_fn(data_dict):
class LinkStimulusToBrainResponse(PipelineStep):
"""Link stimulus to Brain data."""

multiprocessing_dict = MultiprocessingSingleton.manager.dict()
multiprocessing_condition = MultiprocessingSingleton.manager.Condition()
_multiprocessing_dict = None
_multiprocessing_condition = None

def __init__(
self,
Expand Down Expand Up @@ -233,24 +233,39 @@ def __call__(self, data_dict: Dict[str, Any]) -> Dict[str, Any]:
for stim_info in stimulus_info_from_brain:
prototype_stim_dict = self.grouper(stim_info)
key = self.key_fn_for_multiprocessing(prototype_stim_dict)
with self.multiprocessing_condition:
with self.get_multiprocessing_condition():
# Check if no other processes are already running this
while key in self.multiprocessing_dict:
while key in self.get_multiprocessing_dict():
# Wait for the process to finish
self.multiprocessing_condition.wait()
self.multiprocessing_dict[key] = True
self.get_multiprocessing_condition().wait()
self.get_multiprocessing_dict()[key] = True
try:
stimulus_dicts = self.stimulus_data(prototype_stim_dict)
finally:
# Notify all waiting processes that this is done
with self.multiprocessing_condition:
with self.get_multiprocessing_condition():
# Remove the key from the multiprocessing dict to signal that
# this specific stimulus is processed
del self.multiprocessing_dict[key]
self.multiprocessing_condition.notify_all()
del self.get_multiprocessing_dict()[key]
self.get_multiprocessing_condition().notify_all()
if isinstance(stimulus_dicts, dict):
stimulus_dicts = [stimulus_dicts]
all_stimuli += stimulus_dicts

data_dict[self.stimuli_key] = all_stimuli
return data_dict

@classmethod
def get_multiprocessing_dict(cls):
    """Return the class-level shared dict, creating it on first access.

    The dict is obtained from ``MultiprocessingSingleton.get_manager().dict()``
    and cached on the class, so the manager is only requested when this
    method is first called rather than at import time.

    NOTE(review): the cache lives in a class attribute of the calling
    process; confirm that worker processes end up with the same managed
    dict instead of lazily creating one of their own.

    Returns
    -------
    The cached object returned by ``Manager().dict()``.
    """
    shared_dict = cls._multiprocessing_dict
    if shared_dict is None:
        # First access: ask the singleton for its manager and cache the dict.
        shared_dict = MultiprocessingSingleton.get_manager().dict()
        cls._multiprocessing_dict = shared_dict
    return shared_dict

@classmethod
def get_multiprocessing_condition(cls):
    """Return the class-level shared condition, creating it on first access.

    The condition is obtained from
    ``MultiprocessingSingleton.get_manager().Condition()`` and cached on the
    class, so the manager is only requested when this method is first
    called rather than at import time.

    NOTE(review): the cache lives in a class attribute of the calling
    process; confirm that worker processes end up with the same managed
    condition instead of lazily creating one of their own.

    Returns
    -------
    The cached object returned by ``Manager().Condition()``.
    """
    condition = cls._multiprocessing_condition
    if condition is None:
        # First access: ask the singleton for its manager and cache the
        # condition it hands out.
        condition = MultiprocessingSingleton.get_manager().Condition()
        cls._multiprocessing_condition = condition
    return condition

14 changes: 12 additions & 2 deletions brain_pipe/utils/parallellization.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __call__(self, result):
class MultiprocessingSingleton:
"""Singleton class for multiprocessing."""

manager = multiprocess.Manager()
_manager = None
locks = {}

to_clean = []
Expand Down Expand Up @@ -132,5 +132,15 @@ def get_lock(cls, id_str):
multiprocessing.Lock
"""
if id_str not in cls.locks:
cls.locks[id_str] = cls.manager.Lock()
cls.locks[id_str] = cls.get_manager().Lock()
return cls.locks[id_str]

@classmethod
def get_manager(cls):
    """Return the cached multiprocessing manager, starting it on first access.

    ``multiprocess.Manager()`` is only called the first time this method
    runs; the result is stored in ``cls._manager`` and reused afterwards.

    Returns
    -------
    The cached object returned by ``multiprocess.Manager()``.
    """
    manager = cls._manager
    if manager is None:
        # Deferred until first use so that nothing is started at import time.
        manager = multiprocess.Manager()
        cls._manager = manager
    return manager



0 comments on commit 360f0eb

Please sign in to comment.