diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py index 3e3f5c4393..8243560d25 100644 --- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py +++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py @@ -5,7 +5,8 @@ from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix import numpy as np import scipy.sparse as sp -from multiprocessing import cpu_count + +# from multiprocessing import cpu_count from dask import array, compute, delayed from simpeg.dask.utils import get_parallel_blocks from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary @@ -98,32 +99,39 @@ def getSourceTerm(self, freq, source=None): of the correct size """ if source is None: - source_list = self.survey.get_sources_by_frequency(freq) - source_block = np.array_split(source_list, cpu_count()) - - block_compute = [] - - if self.client: - sim = self.client.scatter(self) - - for block in source_block: - if len(block) == 0: - continue - - if self.client: - block_compute.append(self.client.submit(source_evaluation, sim, block)) - else: - block_compute.append(delayed(source_evaluation)(self, block)) - - if self.client: - blocks = self.client.gather(block_compute) - else: - blocks = compute(block_compute)[0] + # if self.client: + # n_splits = int(self.client.cluster.scheduler.total_nthreads / len(self.client.cluster.scheduler.workers)) + # else: + # n_splits = cpu_count() + # + # source_list = self.survey.get_sources_by_frequency(freq) + # source_block = np.array_split(source_list, n_splits) + # + # block_compute = [] + # + # if self.client: + # sim = self.client.scatter(self) + # source_block = self.client.scatter(source_block) + # + # for block in source_block: + # if self.client: + # block_compute.append(self.client.submit(source_evaluation, sim, block)) + # else: + # block_compute.append(delayed(source_evaluation)(self, block)) + # + # if self.client: + # blocks = self.client.gather(block_compute) + # else: + # blocks = compute(block_compute)[0] s_m, s_e = [], [] - for block in blocks: - if block[0]: - s_m += block[0] - s_e += block[1] + # for block in blocks: + # if block[0]: + for source in self.survey.get_sources_by_frequency(freq): + sm, se = source.eval(self) + s_m.append(sm) + s_e.append(se) + # s_m += block[0] + # s_e += block[1] else: sm, se = source.eval(self) diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py index 4fbf0d46a5..8768265021 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py @@ -70,7 +70,7 @@ def dpred(self, m=None, f=None): return np.asarray(data) -def getJtJdiag(self, m, W=None): +def getJtJdiag(self, m, W=None, f=None): """ Return the diagonal of JtJ """ diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index c07c4fa046..a52608a8f3 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -55,10 +55,12 @@ def __init__( objfcts: list[BaseObjectiveFunction], multipliers=None, client: Client | None = None, + workers: list[str] | None = None, **kwargs, ): self._model: np.ndarray | None = None self.client = client + self.workers = workers super().__init__(objfcts=objfcts, multipliers=multipliers, **kwargs) @@ -98,6 +100,20 @@ def client(self, client): self._client = client + @property + def workers(self): + """ + List of worker addresses + """ + return self._workers + + @workers.setter + def workers(self, workers): + if not isinstance(workers, list | type(None)): + raise TypeError("workers must be a list of strings") + + self._workers = workers + def deriv(self, m, f=None): """ First derivative of the composite objective function is the sum of the @@ -276,7 +292,12 @@ def objfcts(self, objfcts): client = self.client futures, workers = _validate_type_or_future_of_type( - "objfcts", objfcts, L2DataMisfit, client, return_workers=True + "objfcts", + objfcts, + L2DataMisfit, + client, + workers=self.workers, + return_workers=True, ) for objfct, future in zip(objfcts, futures): if hasattr(objfct, "name"): @@ -307,10 +328,3 @@ def residuals(self, m, f=None): ) ) return client.gather(residuals) - - @property - def workers(self): - """ - Get the list of dask.distributed.workers associated with the objective functions. - """ - return self._workers