diff --git a/simpeg/dask/electromagnetics/frequency_domain/simulation.py b/simpeg/dask/electromagnetics/frequency_domain/simulation.py
index 0b105b0e0b..d9d43297e7 100644
--- a/simpeg/dask/electromagnetics/frequency_domain/simulation.py
+++ b/simpeg/dask/electromagnetics/frequency_domain/simulation.py
@@ -1,20 +1,15 @@
 from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim
-
 from ....utils import Zero
+from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix
 import numpy as np
 import scipy.sparse as sp
 from multiprocessing import cpu_count
 from dask import array, compute, delayed
-
 from simpeg.dask.utils import get_parallel_blocks
-
 from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary
-
 import zarr
-from tqdm import tqdm


-@delayed
 def evaluate_receivers(block, mesh, fields):
     data = []
     for source, _, receiver in block:
@@ -23,7 +18,6 @@ def evaluate_receivers(block, mesh, fields):
     return np.hstack(data)


-@delayed
 def source_evaluation(simulation, sources):
     s_m, s_e = [], []
     for source in sources:
@@ -34,7 +28,6 @@ def source_evaluation(simulation, sources):
     return s_m, s_e


-@delayed
 def receiver_derivs(survey, mesh, fields, blocks):
     field_derivatives = []
     for address in blocks:
@@ -55,7 +48,6 @@ def receiver_derivs(survey, mesh, fields, blocks):
     return field_derivatives


-@delayed
 def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address):
     """
     Evaluate the sensitivities for the block or data
@@ -108,15 +100,23 @@ def getSourceTerm(self, freq, source=None):
         source_block = np.array_split(source_list, cpu_count())

     block_compute = []
+
+    if self.client:
+        sim = self.client.scatter(self)  # ship the simulation to the workers once
+
     for block in source_block:
         if len(block) == 0:
             continue
-        block_compute.append(
-            self.client.submit(source_evaluation, self, block, workers=self.worker)
-        )
+        if self.client:
+            block_compute.append(self.client.submit(source_evaluation, sim, block))
+        else:
+            block_compute.append(delayed(source_evaluation)(self, block))

-    blocks = self.client.gather(block_compute)
+    if self.client:
+        blocks = self.client.gather(block_compute)
+    else:
+        blocks = compute(block_compute)[0]

     s_m, s_e = [], []
     for block in blocks:
         if block[0]:
@@ -174,23 +174,34 @@ def dpred(self, m=None, f=None):
             for rx in src.receiver_list:
                 all_receivers.append((src, ind, rx))

+    if self.client:
+        f = self.client.scatter(f)
+        mesh = self.client.scatter(self.mesh)
+    else:
+        mesh = delayed(self.mesh)
+
     receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count())

     rows = []
-    mesh = delayed(self.mesh)
     for block in receiver_blocks:
         n_data = np.sum([rec.nD for _, _, rec in block])
         if n_data == 0:
             continue

-        rows.append(
-            array.from_delayed(
-                evaluate_receivers(block, mesh, f),
-                dtype=np.float64,
-                shape=(n_data,),
+        if self.client:
+            rows.append(self.client.submit(evaluate_receivers, block, mesh, f))
+        else:
+            rows.append(
+                array.from_delayed(
+                    delayed(evaluate_receivers)(block, mesh, f),
+                    dtype=np.float64,
+                    shape=(n_data,),
+                )
             )
-        )

-    data = compute(array.hstack(rows))[0]
+    if self.client:
+        data = np.hstack(self.client.gather(rows))
+    else:
+        data = compute(array.hstack(rows))[0]

     return data
@@ -199,8 +210,8 @@ def fields(self, m=None):
     if m is not None:
         self.model = m

-    # if getattr(self, "_stashed_fields", None) is not None:
-    #     return self._stashed_fields
+    if getattr(self, "_stashed_fields", None) is not None:
+        return self._stashed_fields

     f = self.fieldsPair(self)
     Ainv = {}
@@ -213,9 +224,9 @@ def fields(self, m=None):
         f[sources, self._solutionType] = u
         Ainv[freq] = Ainv_solve

-    # self.Ainv = Ainv
-    #
-    # self._stashed_fields = f
+    self.Ainv = Ainv
+
+    self._stashed_fields = f

     return f

@@ -226,19 +237,13 @@ def compute_J(self, m, f=None):
     if f is None:
         f = self.fields(m)

-    Ainv = {}
-    for freq in self.survey.frequencies:
-        A = self.getA(freq)
-        Ainv_solve = self.solver(sp.csr_matrix(A), **self.solver_opts)
-        Ainv[freq] = Ainv_solve
-
-    if len(Ainv) > 1:
+    if len(self.Ainv) > 1:
         raise NotImplementedError(
             "Current implementation of parallelization assumes a single frequency per simulation. "
             "Consider creating one misfit per frequency."
         )

-    A_i = list(Ainv.values())[0]
+    A_i = list(self.Ainv.values())[0]
     m_size = m.size

     if self.store_sensitivities == "disk":
@@ -255,37 +260,51 @@ def compute_J(self, m, f=None):
     blocks = get_parallel_blocks(
         self.survey.source_list, compute_row_size, optimize=False
     )
-    fields_array = delayed(f[:, self._solutionType])
-    fields = delayed(f)
-    survey = delayed(self.survey)
-    mesh = delayed(self.mesh)
-    blocks_receiver_derivs = []
-
-    for block in blocks:
-        blocks_receiver_derivs.append(
-            receiver_derivs(
-                survey,
-                mesh,
-                fields,
-                block,
-            )
+    if self.client:
+        fields_array = self.client.scatter(f[:, self._solutionType])
+        fields = self.client.scatter(f)
+        survey = self.client.scatter(self.survey)
+        mesh = self.client.scatter(self.mesh)
+        blocks_receiver_derivs = self.client.map(
+            receiver_derivs,
+            [survey] * len(blocks),
+            [mesh] * len(blocks),
+            [fields] * len(blocks),
+            blocks,
        )
+    else:
+        fields_array = delayed(f[:, self._solutionType])
+        fields = delayed(f)
+        survey = delayed(self.survey)
+        mesh = delayed(self.mesh)
+        blocks_receiver_derivs = []
+        delayed_derivs = delayed(receiver_derivs)
+        for block in blocks:
+            blocks_receiver_derivs.append(
+                delayed_derivs(
+                    survey,
+                    mesh,
+                    fields,
+                    block,
+                )
+            )

     # Dask process for all derivatives
-    blocks_receiver_derivs = compute(blocks_receiver_derivs)[0]
+    if self.client:
+        blocks_receiver_derivs = self.client.gather(blocks_receiver_derivs)
+    else:
+        blocks_receiver_derivs = compute(blocks_receiver_derivs)[0]

-    for block_derivs_chunks, addresses_chunks in tqdm(
-        zip(blocks_receiver_derivs, blocks),
-        ncols=len(blocks_receiver_derivs),
-        desc=f"Sensitivities at {list(Ainv)[0]} Hz",
-    ):
+    for block_derivs_chunks, addresses_chunks in zip(blocks_receiver_derivs, blocks):
         Jmatrix = self.parallel_block_compute(
             m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks
         )

-    for A in Ainv.values():
+    for A in self.Ainv.values():
         A.clean()

+    del self.Ainv
+
     if self.store_sensitivities == "disk":
         del Jmatrix
         Jmatrix = array.from_zarr(self.sensitivity_path)
@@ -298,7 +317,13 @@ def parallel_block_compute(
 ):
     m_size = m.size
     block_stack = sp.hstack(blocks_receiver_derivs).toarray()
-    ATinvdf_duT = delayed(A_i * block_stack)
+
+    ATinvdf_duT = A_i * block_stack
+    if self.client:
+        ATinvdf_duT = self.client.scatter(ATinvdf_duT)
+        sim = self.client.scatter(self)
+    else:
+        ATinvdf_duT = delayed(ATinvdf_duT)
     count = 0
     rows = []
     block_delayed = []
@@ -306,39 +331,63 @@ def parallel_block_compute(
     for address, dfduT in zip(addresses, blocks_receiver_derivs):
         n_cols = dfduT.shape[1]
         n_rows = address[1][2]
-        block_delayed.append(
-            array.from_delayed(
-                eval_block(
-                    self,
+
+        if self.client:
+            block_delayed.append(
+                self.client.submit(
+                    eval_block,
+                    sim,
                     ATinvdf_duT,
                     np.arange(count, count + n_cols),
                     Zero(),
                     fields_array,
                     address,
-                ),
-                dtype=np.float32,
-                shape=(n_rows, m_size),
+                )
+            )
+        else:
+            delayed_eval = delayed(eval_block)
+            block_delayed.append(
+                array.from_delayed(
+                    delayed_eval(
+                        self,
+                        ATinvdf_duT,
+                        np.arange(count, count + n_cols),
+                        Zero(),
+                        fields_array,
+                        address,
+                    ),
+                    dtype=np.float32,
+                    shape=(n_rows, m_size),
+                )
             )
-        )
         count += n_cols
         rows += address[1][1].tolist()

     indices = np.hstack(rows)

+    if self.client:
+        block = np.vstack(self.client.gather(block_delayed))
+    else:
+        block = compute(array.vstack(block_delayed))[0]
+
     if self.store_sensitivities == "disk":
         Jmatrix.set_orthogonal_selection(
             (indices, slice(None)),
-            compute(array.vstack(block_delayed))[0],
+            block,
         )
     else:
         # Dask process to compute row and store
-        Jmatrix[indices, :] = compute(array.vstack(block_delayed))[0]
+        Jmatrix[indices, :] = block

     return Jmatrix


 Sim.parallel_block_compute = parallel_block_compute
 Sim.compute_J = compute_J
+Sim.getJtJdiag = getJtJdiag
+Sim.Jvec = Jvec
+Sim.Jtvec = Jtvec
+Sim.Jmatrix = Jmatrix
 Sim.fields = fields
 Sim.dpred = dpred
 Sim.getSourceTerm = getSourceTerm
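
Note for reviewers: the patch replaces module-level @delayed decorators with a runtime branch. When the simulation carries a distributed client, large shared inputs (the simulation itself, the fields, mesh, and survey) are scattered once and work is dispatched with client.submit/client.map; without a client, the same helpers are wrapped in dask.delayed and run through the local scheduler. A minimal, self-contained sketch of that pattern follows; the names evaluate and run_blocks are illustrative stand-ins, not part of this patch:

    import numpy as np
    from dask import compute, delayed
    from dask.distributed import Client

    def evaluate(shared, bounds):
        # Stand-in for evaluate_receivers/eval_block: works on a slice of a
        # large shared array.
        start, stop = bounds
        return shared[start:stop].sum()

    def run_blocks(shared, chunks, client=None):
        if client:
            # Distributed path: scatter the shared input once so every worker
            # holds a single copy, then submit one future per chunk.
            shared_ref = client.scatter(shared, broadcast=True)
            futures = [client.submit(evaluate, shared_ref, c) for c in chunks]
            return client.gather(futures)
        # Local path: build a delayed graph and evaluate it in one pass.
        lazy = [delayed(evaluate)(shared, c) for c in chunks]
        return compute(lazy)[0]

    if __name__ == "__main__":
        shared = np.arange(1_000_000.0)
        chunks = [(i, i + 250_000) for i in range(0, 1_000_000, 250_000)]
        print(run_blocks(shared, chunks))  # delayed/compute path
        with Client(processes=False) as client:
            print(run_blocks(shared, chunks, client=client))  # submit/gather path

Scattering before submitting mirrors what the patch does with self, f, and the mesh: it avoids re-serializing the same large object for every task, which is also why getSourceTerm and parallel_block_compute scatter the simulation once before looping over blocks.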