More parallel blocks
fourndo committed Nov 29, 2023
1 parent (2061306), commit 9ce0956
Showing 1 changed file with 90 additions and 61 deletions.
SimPEG/dask/electromagnetics/time_domain/simulation.py (151 changes: 90 additions, 61 deletions)
```diff
@@ -156,26 +156,30 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields):
 
 
 @delayed
-def block_deriv(
-    time_index, field_type, source, rx_count, mesh, time_mesh, fields, Jmatrix
-):
+def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jmatrix):
     """Compute derivatives for sources and receivers in a block"""
-    field_len = len(fields[source, field_type, 0])
+    field_len = len(fields[source_list[0], field_type, 0])
     df_duT = []
+    rx_count = 0
+    for source in source_list:
+        sources_block = []
 
-    for rx in source.receiver_list:
-        PTv = rx.getP(mesh, time_mesh, fields).tocsr()
-        derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None)
-        rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int)
-        cur = derivative_fun(
-            time_index,
-            source,
-            None,
-            PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T,
-            adjoint=True,
-        )
-        df_duT.append(cur[0])
-        Jmatrix[rx_ind, :] += cur[1].T
+        for rx in source.receiver_list:
+            PTv = rx.getP(mesh, time_mesh, fields).tocsr()
+            derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None)
+            rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int)
+            cur = derivative_fun(
+                time_index,
+                source,
+                None,
+                PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T,
+                adjoint=True,
+            )
+            sources_block.append(cur[0])
+            Jmatrix[rx_ind, :] += cur[1].T
+            rx_count += rx.nD
+
+        df_duT.append(sources_block)
 
     return df_duT
```
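Note: `block_deriv` now takes the whole `source_list` and loops over sources internally, so one dask task covers every source at a given time step instead of one task per (time step, source) pair, and the receiver offset `rx_count` is tracked inside the task. A minimal sketch of that batching pattern (toy function and data, not SimPEG's API):

```python
import dask
from dask import delayed

@delayed
def block(time_index, source_list):
    # One task per time step, covering every source in the block.
    return [f"t{time_index}:{src}" for src in source_list]

sources = ["src-A", "src-B", "src-C"]
tasks = [block(t, sources) for t in range(4)]  # 4 tasks instead of 12
per_step = dask.compute(tasks)[0]              # one list of results per time step
```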

```diff
@@ -184,28 +188,20 @@ def compute_field_derivs(simulation, Jmatrix, fields):
     """
     Compute the derivative of the fields
     """
-
     df_duT = []
 
     for time_index in range(simulation.nT + 1):
-        rx_count = 0
-        sources_block = []
-        for source in simulation.survey.source_list:
-            sources_block.append(
-                block_deriv(
-                    time_index,
-                    simulation._fieldType + "Solution",
-                    source,
-                    rx_count,
-                    simulation.mesh,
-                    simulation.time_mesh,
-                    fields,
-                    Jmatrix,
-                )
-            )
-            rx_count += source.nD
-
-        df_duT.append(sources_block)
+        df_duT.append(
+            block_deriv(
+                time_index,
+                simulation._fieldType + "Solution",
+                simulation.survey.source_list,
+                simulation.mesh,
+                simulation.time_mesh,
+                fields,
+                Jmatrix,
+            )
+        )
 
     df_duT = dask.compute(df_duT)[0]
```
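Note: `compute_field_derivs` now builds a single delayed `block_deriv` per time step and materializes all of them with one `dask.compute` call. `dask.compute` returns a tuple of its materialized arguments, hence the trailing `[0]`; a minimal illustration:

```python
import dask
from dask import delayed

lazy = [delayed(pow)(i, 2) for i in range(4)]
(values,) = dask.compute(lazy)  # equivalent to dask.compute(lazy)[0]
assert values == [0, 1, 4, 9]
```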

```diff
@@ -247,6 +243,30 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int):
     return blocks
 
 
+@delayed
+def deriv_block(
+    s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, simulation, tInd
+):
+    if (s_id, r_id, b_id) not in ATinv_df_duT_v:
+        # last timestep (first to be solved)
+        stacked_block = simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[
+            :, sub_ind
+        ]
+
+    else:
+        stacked_block = np.asarray(
+            simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind]
+            - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind]
+        )
+
+    return stacked_block
+
+
+def update_deriv_blocks(address, indices, derivatives, solve):
+    columns, local_ind = indices[address]
+    derivatives[:, local_ind] = solve[:, columns]
+
+
 def get_field_deriv_block(
     simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict, time_mask
 ):
```
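Note: `deriv_block` packages one step of the adjoint back-propagation as a delayed task: at the last time step (the first one solved) the right-hand side is simply the stored field derivative, while at earlier steps the previously solved block is propagated through `Asubdiag.T` and subtracted. A rough sketch of the same branching with stand-in arrays (`rhs_block`, `previous`, `field_deriv`, and `subdiag` are hypothetical names, not SimPEG's):

```python
import numpy as np
import dask
from dask import delayed

@delayed
def rhs_block(address, previous, field_deriv, subdiag):
    if address not in previous:
        # Last time step: nothing solved yet, use the raw derivative.
        return field_deriv
    # Earlier steps: subtract the sub-diagonal term of the adjoint recursion.
    return field_deriv - subdiag.T @ previous[address]

previous = {}  # filled in as the backward recursion proceeds
blocks = [rhs_block(("s0", "r0", 0), previous, np.ones((3, 2)), np.eye(3))]
rhs = np.hstack(dask.compute(blocks)[0])  # gathered input for one batched solve
```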
```diff
@@ -256,6 +276,11 @@ def get_field_deriv_block(
     stacked_blocks = []
     indices = {}
     count = 0
+
+    Asubdiag = None
+    if tInd < simulation.nT - 1:
+        Asubdiag = simulation.getAsubdiag(tInd + 1)
+
     for (s_id, r_id, b_id), (rx_ind, j_ind) in block.items():
         # Cut out early data
         rx = simulation.survey.source_list[s_id].receiver_list[r_id]
```
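Note: hoisting `getAsubdiag(tInd + 1)` above the block loop assembles the sub-diagonal operator once per time step rather than once per block; the `None` default covers the final time step (`tInd == simulation.nT - 1`), where `deriv_block` takes its first branch and never touches `Asubdiag`. A toy version of the hoist (`get_subdiag` is a hypothetical stand-in):

```python
import numpy as np

def get_subdiag(t):  # stand-in for simulation.getAsubdiag
    return np.eye(2) * (t + 1)

n_steps, t_ind = 5, 2
subdiag = get_subdiag(t_ind + 1) if t_ind < n_steps - 1 else None
for _block in range(3):
    assert subdiag is not None  # every block reuses the same operator
```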
```diff
@@ -274,38 +299,44 @@ def get_field_deriv_block(
         )
         count += len(sub_ind)
 
-        if (s_id, r_id, b_id) not in ATinv_df_duT_v:
-            # last timestep (first to be solved)
-            stacked_blocks.append(
-                simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[:, sub_ind]
-            )
-
-        else:
-            Asubdiag = simulation.getAsubdiag(tInd + 1)
-            stacked_blocks.append(
-                np.asarray(
-                    simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind]
-                    - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind]
-                )
-            )
+        stacked_blocks.append(
+            deriv_block(
+                s_id,
+                r_id,
+                b_id,
+                ATinv_df_duT_v,
+                Asubdiag,
+                local_ind,
+                sub_ind,
+                simulation,
+                tInd,
+            )
+        )
 
     if len(stacked_blocks) > 0:
-        solve = AdiagTinv * np.hstack(stacked_blocks)
+        solve = AdiagTinv * np.hstack(dask.compute(stacked_blocks)[0])
 
+    update_list = []
     for s_id, r_id, b_id in block:
-        ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros(
-            (
-                simulation.field_derivs[tInd][s_id][r_id].shape[0],
-                len(block[(s_id, r_id, b_id)][0]),
+        if (s_id, r_id, b_id) not in ATinv_df_duT_v:
+            ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros(
+                (
+                    simulation.field_derivs[tInd][s_id][r_id].shape[0],
+                    len(block[(s_id, r_id, b_id)][0]),
+                )
             )
-        )
 
-        if (s_id, r_id, b_id) in indices:
-            try:
-                columns, local_ind = indices[(s_id, r_id, b_id)]
-                ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] = solve[:, columns]
-            except:
-                print("ouch")
+        update_list.append(
+            update_deriv_blocks(
+                (s_id, r_id, b_id),
+                indices,
+                ATinv_df_duT_v[(s_id, r_id, b_id)],
+                solve,
+            )
+        )
+
+    dask.compute(update_list)
 
     return ATinv_df_duT_v
```
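Note: the rewritten loop gathers all delayed right-hand-side blocks, materializes them with a single `dask.compute`, applies the factored `AdiagTinv` to the stacked columns in one batched solve, and then scatters the solution back to each `(s_id, r_id, b_id)` address through `update_deriv_blocks`. A toy version of the scatter step with made-up shapes:

```python
import numpy as np

indices = {"a": (np.r_[0, 1], np.r_[0, 1]), "b": (np.r_[2], np.r_[0])}
solve = 2.0 * np.arange(12.0).reshape(4, 3)   # stand-in for AdiagTinv * rhs
targets = {"a": np.zeros((4, 2)), "b": np.zeros((4, 1))}
for address, target in targets.items():
    columns, local_ind = indices[address]     # global -> local column mapping
    target[:, local_ind] = solve[:, columns]  # same update as update_deriv_blocks
```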

```diff
@@ -370,9 +401,7 @@ def compute_J(self, f=None, Ainv=None):
         blocks = get_parallel_blocks(
             self.survey.source_list, self.model.shape[0], self.max_chunk_size
         )
-
         self.field_derivs = compute_field_derivs(self, Jmatrix, f)
-
         ATinv_df_duT_v = {}
         for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))):
             AdiagTinv = Ainv[dt]
```
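Note: in `compute_J` the parallel blocks and field derivatives are prepared up front, and the dictionary `ATinv_df_duT_v` then carries the adjoint solutions backward through the time loop, reusing one factorization `Ainv[dt]` per distinct step length. A compact, runnable sketch of that backward drive (all names here are stand-ins):

```python
import numpy as np

time_steps = [0.1, 0.1, 0.2]                  # two distinct step lengths
factorizations = {dt: np.eye(2) / dt for dt in set(time_steps)}
state = {}                                    # plays the role of ATinv_df_duT_v
for t_ind, dt in zip(reversed(range(len(time_steps))), reversed(time_steps)):
    AdiagTinv = factorizations[dt]            # one factorization per dt
    state[t_ind] = AdiagTinv @ np.ones(2)     # stand-in for the block solve
```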
