Outfit frequency simulations with client option
domfournier committed Jan 9, 2025
1 parent 3cf2f4d commit e647079
Showing 1 changed file with 113 additions and 66 deletions.
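This commit threads an optional dask.distributed client through the frequency-domain simulation: when `simulation.client` is set, work is dispatched with `client.scatter` and `client.submit`; otherwise the original `dask.delayed` path is kept. A minimal usage sketch, assuming a standard SimPEG FDEM setup (the commented construction lines are illustrative, not part of this commit):

from dask.distributed import Client

# Importing the patched module rebinds the FDEM simulation methods defined below.
import simpeg.dask.electromagnetics.frequency_domain.simulation  # noqa: F401

client = Client()  # local distributed scheduler

# simulation = Simulation3DMagneticFluxDensity(mesh, survey=survey, ...)  # built as usual
# simulation.client = client  # route tasks through client.submit / client.scatter
# simulation.client = None    # or fall back to the dask.delayed code path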
179 changes: 113 additions & 66 deletions simpeg/dask/electromagnetics/frequency_domain/simulation.py
@@ -1,20 +1,15 @@
from ....electromagnetics.frequency_domain.simulation import BaseFDEMSimulation as Sim

from ....utils import Zero
from ...simulation import getJtJdiag, Jvec, Jtvec, Jmatrix
import numpy as np
import scipy.sparse as sp
from multiprocessing import cpu_count
from dask import array, compute, delayed

from simpeg.dask.utils import get_parallel_blocks

from simpeg.electromagnetics.natural_source.sources import PlanewaveXYPrimary

import zarr
from tqdm import tqdm


def evaluate_receivers(block, mesh, fields):
data = []
for source, _, receiver in block:
@@ -23,7 +18,6 @@ def evaluate_receivers(block, mesh, fields):
return np.hstack(data)


def source_evaluation(simulation, sources):
s_m, s_e = [], []
for source in sources:
@@ -34,7 +28,6 @@ def source_evaluation(simulation, sources):
return s_m, s_e


def receiver_derivs(survey, mesh, fields, blocks):
field_derivatives = []
for address in blocks:
@@ -55,7 +48,6 @@ def receiver_derivs(survey, mesh, fields, blocks):
return field_derivatives


def eval_block(simulation, Ainv_deriv_u, deriv_indices, deriv_m, fields, address):
"""
Evaluate the sensitivities for the block or data
@@ -108,15 +100,23 @@ def getSourceTerm(self, freq, source=None):
source_block = np.array_split(source_list, cpu_count())

block_compute = []

if self.client:
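        # Ship one copy of the simulation to the workers; every task below reuses the remote handle.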
sim = self.client.scatter(self)

for block in source_block:
if len(block) == 0:
continue

        if self.client:
            block_compute.append(self.client.submit(source_evaluation, sim, block))
        else:
            block_compute.append(delayed(source_evaluation)(self, block))

    if self.client:
        blocks = self.client.gather(block_compute)
    else:
        blocks = compute(block_compute)[0]
s_m, s_e = [], []
for block in blocks:
if block[0]:
@@ -174,23 +174,32 @@ def dpred(self, m=None, f=None):
for rx in src.receiver_list:
all_receivers.append((src, ind, rx))

    if self.client:
        f = self.client.scatter(f)
        mesh = self.client.scatter(self.mesh)
    else:
        mesh = delayed(self.mesh)

    receiver_blocks = np.array_split(np.asarray(all_receivers), cpu_count())
    rows = []
for block in receiver_blocks:
n_data = np.sum([rec.nD for _, _, rec in block])
if n_data == 0:
continue

        if self.client:
            rows.append(self.client.submit(evaluate_receivers, block, mesh, f))
        else:
            rows.append(
                array.from_delayed(
                    delayed(evaluate_receivers)(block, mesh, f),
                    dtype=np.float64,
                    shape=(n_data,),
                )
            )

    if self.client:
        data = np.hstack(self.client.gather(rows))
    else:
        data = compute(array.hstack(rows))[0]

return data

@@ -199,8 +208,8 @@ def fields(self, m=None):
if m is not None:
self.model = m

    if getattr(self, "_stashed_fields", None) is not None:
        return self._stashed_fields

f = self.fieldsPair(self)
Ainv = {}
@@ -213,9 +222,9 @@
f[sources, self._solutionType] = u
Ainv[freq] = Ainv_solve

    # Cache the factorizations and the fields; compute_J reuses both.
    self.Ainv = Ainv

    self._stashed_fields = f

return f

@@ -226,19 +235,13 @@ def compute_J(self, m, f=None):
if f is None:
f = self.fields(m)

    if len(self.Ainv) > 1:
raise NotImplementedError(
"Current implementation of parallelization assumes a single frequency per simulation. "
"Consider creating one misfit per frequency."
)

    A_i = list(self.Ainv.values())[0]
m_size = m.size

if self.store_sensitivities == "disk":
@@ -255,37 +258,51 @@
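    # Split the survey into work blocks sized by compute_row_size.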
blocks = get_parallel_blocks(
self.survey.source_list, compute_row_size, optimize=False
)
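    # With a client, scatter the shared inputs once and map receiver blocks across workers.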
if self.client:
fields_array = self.client.scatter(f[:, self._solutionType])
fields = self.client.scatter(f)
survey = self.client.scatter(self.survey)
mesh = self.client.scatter(self.mesh)
blocks_receiver_derivs = self.client.map(
receiver_derivs,
[survey] * len(blocks),
[mesh] * len(blocks),
[fields] * len(blocks),
blocks,
)
else:
fields_array = delayed(f[:, self._solutionType])
fields = delayed(f)
survey = delayed(self.survey)
mesh = delayed(self.mesh)
blocks_receiver_derivs = []
delayed_derivs = delayed(receiver_derivs)
for block in blocks:
blocks_receiver_derivs.append(
delayed_derivs(
survey,
mesh,
fields,
block,
)
)

    # Resolve all receiver derivatives before assembling the sensitivity blocks.
    if self.client:
        blocks_receiver_derivs = self.client.gather(blocks_receiver_derivs)
    else:
        blocks_receiver_derivs = compute(blocks_receiver_derivs)[0]

    for block_derivs_chunks, addresses_chunks in zip(blocks_receiver_derivs, blocks):
Jmatrix = self.parallel_block_compute(
m, Jmatrix, block_derivs_chunks, A_i, fields_array, addresses_chunks
)

    for A in self.Ainv.values():
A.clean()

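    # Drop the cached factorizations now that the sensitivities are assembled.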
del self.Ainv

if self.store_sensitivities == "disk":
del Jmatrix
Jmatrix = array.from_zarr(self.sensitivity_path)
@@ -298,47 +315,77 @@ def parallel_block_compute(
):
m_size = m.size
block_stack = sp.hstack(blocks_receiver_derivs).toarray()
    # Apply the factorized operator to every right-hand side at once.
    ATinvdf_duT = A_i * block_stack
if self.client:
ATinvdf_duT = self.client.scatter(ATinvdf_duT)
sim = self.client.scatter(self)
else:
ATinvdf_duT = delayed(ATinvdf_duT)
count = 0
rows = []
block_delayed = []

for address, dfduT in zip(addresses, blocks_receiver_derivs):
n_cols = dfduT.shape[1]
n_rows = address[1][2]
        if self.client:
            block_delayed.append(
                self.client.submit(
                    eval_block,
                    sim,
                    ATinvdf_duT,
                    np.arange(count, count + n_cols),
                    Zero(),
                    fields_array,
                    address,
                )
            )
        else:
            delayed_eval = delayed(eval_block)
            block_delayed.append(
                array.from_delayed(
                    delayed_eval(
                        self,
                        ATinvdf_duT,
                        np.arange(count, count + n_cols),
                        Zero(),
                        fields_array,
                        address,
                    ),
                    dtype=np.float32,
                    shape=(n_rows, m_size),
                )
            )
count += n_cols
rows += address[1][1].tolist()

indices = np.hstack(rows)

if self.client:
block = np.vstack(self.client.gather(block_delayed))
else:
block = compute(array.vstack(block_delayed))[0]

if self.store_sensitivities == "disk":
Jmatrix.set_orthogonal_selection(
(indices, slice(None)),
            block,
)
else:
        # Rows are already computed or gathered; store them directly.
        Jmatrix[indices, :] = block

return Jmatrix


Sim.parallel_block_compute = parallel_block_compute
Sim.compute_J = compute_J
Sim.getJtJdiag = getJtJdiag
Sim.Jvec = Jvec
Sim.Jtvec = Jtvec
Sim.Jmatrix = Jmatrix
Sim.fields = fields
Sim.dpred = dpred
Sim.getSourceTerm = getSourceTerm
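
These module-level assignments overwrite the base simulation's methods when this file is imported. A small sketch of the effect (assumes only that the import succeeds):

from simpeg.electromagnetics.frequency_domain.simulation import BaseFDEMSimulation

import simpeg.dask.electromagnetics.frequency_domain.simulation  # noqa: F401  applies the patches above

# The base class now resolves to the dask-aware implementations from this file:
print(BaseFDEMSimulation.compute_J.__module__)  # -> simpeg.dask.electromagnetics.frequency_domain.simulation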
