Skip to content

Commit

Permalink
Remove parallel compute of RHS
Browse files Browse the repository at this point in the history
  • Loading branch information
domfournier committed Jan 13, 2025
1 parent 96c3a5f commit e882cf1
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 35 deletions.
60 changes: 34 additions & 26 deletions simpeg/dask/electromagnetics/frequency_domain/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix
import numpy as np
import scipy.sparse as sp
from multiprocessing import cpu_count

# from multiprocessing import cpu_count
from dask import array, compute, delayed
from simpeg.dask.utils import get_parallel_blocks
from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary
Expand Down Expand Up @@ -98,32 +99,39 @@ def getSourceTerm(self, freq, source=None):
of the correct size
"""
if source is None:
source_list = self.survey.get_sources_by_frequency(freq)
source_block = np.array_split(source_list, cpu_count())

block_compute = []

if self.client:
sim = self.client.scatter(self)

for block in source_block:
if len(block) == 0:
continue

if self.client:
block_compute.append(self.client.submit(source_evaluation, sim, block))
else:
block_compute.append(delayed(source_evaluation)(self, block))

if self.client:
blocks = self.client.gather(block_compute)
else:
blocks = compute(block_compute)[0]
# if self.client:
# n_splits = int(self.client.cluster.scheduler.total_nthreads / len(self.client.cluster.scheduler.workers))
# else:
# n_splits = cpu_count()
#
# source_list = self.survey.get_sources_by_frequency(freq)
# source_block = np.array_split(source_list, n_splits)
#
# block_compute = []
#
# if self.client:
# sim = self.client.scatter(self)
# source_block = self.client.scatter(source_block)
#
# for block in source_block:
# if self.client:
# block_compute.append(self.client.submit(source_evaluation, sim, block))
# else:
# block_compute.append(delayed(source_evaluation)(self, block))
#
# if self.client:
# blocks = self.client.gather(block_compute)
# else:
# blocks = compute(block_compute)[0]
s_m, s_e = [], []
for block in blocks:
if block[0]:
s_m += block[0]
s_e += block[1]
# for block in blocks:
# if block[0]:
for source in self.survey.get_sources_by_frequency(freq):
sm, se = source.eval(self)
s_m.append(sm)
s_e.append(se)
# s_m += block[0]
# s_e += block[1]

else:
sm, se = source.eval(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def dpred(self, m=None, f=None):
return np.asarray(data)


def getJtJdiag(self, m, W=None):
def getJtJdiag(self, m, W=None, f=None):
"""
Return the diagonal of JtJ
"""
Expand Down
30 changes: 22 additions & 8 deletions simpeg/dask/objective_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ def __init__(
objfcts: list[BaseObjectiveFunction],
multipliers=None,
client: Client | None = None,
workers: list[str] | None = None,
**kwargs,
):
self._model: np.ndarray | None = None
self.client = client
self.workers = workers

super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs)

Expand Down Expand Up @@ -98,6 +100,20 @@ def client(self, client):

self._client = client

@property
def workers(self):
"""
List of worker addresses
"""
return self._workers

@workers.setter
def workers(self, workers):
if not isinstance(workers, list | type(None)):
raise TypeError("workers must be a list of strings")

self._workers = workers

def deriv(self, m, f=None):
"""
First derivative of the composite objective function is the sum of the
Expand Down Expand Up @@ -276,7 +292,12 @@ def objfcts(self, objfcts):
client = self.client

futures, workers = _validate_type_or_future_of_type(
"objfcts", objfcts, L2DataMisfit, client, return_workers=True
"objfcts",
objfcts,
L2DataMisfit,
client,
workers=self.workers,
return_workers=True,
)
for objfct, future in zip(objfcts, futures):
if hasattr(objfct, "name"):
Expand Down Expand Up @@ -307,10 +328,3 @@ def residuals(self, m, f=None):
)
)
return client.gather(residuals)

@property
def workers(self):
"""
Get the list of dask.distributed.workers associated with the objective functions.
"""
return self._workers

0 comments on commit e882cf1

Please sign in to comment.