Commit

More optimization
fourndo committed Nov 29, 2023
1 parent dded48e commit 2061306
Showing 1 changed file with 86 additions and 71 deletions.
157 changes: 86 additions & 71 deletions SimPEG/dask/electromagnetics/time_domain/simulation.py
@@ -155,31 +155,27 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields):
Sim.field_derivs = None


def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jmatrix):
@delayed
def block_deriv(
time_index, field_type, source, rx_count, mesh, time_mesh, fields, Jmatrix
):
"""Compute derivatives for sources and receivers in a block"""
field_len = len(fields[source_list[0], field_type, 0])
df_duT = {src: {} for src in source_list}

rx_count = 0
for src in source_list:
df_duT[src] = {rx: {} for rx in src.receiver_list}

for rx in src.receiver_list:
PTv = rx.getP(mesh, time_mesh, fields).tocsr()
derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None)
rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int)

cur = derivative_fun(
time_index,
src,
None,
PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T,
adjoint=True,
)
df_duT[src][rx] = cur[0]
Jmatrix[rx_ind, :] += cur[1].T
field_len = len(fields[source, field_type, 0])
df_duT = []

rx_count += rx.nD
for rx in source.receiver_list:
PTv = rx.getP(mesh, time_mesh, fields).tocsr()
derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None)
rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int)
cur = derivative_fun(
time_index,
source,
None,
PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T,
adjoint=True,
)
df_duT.append(cur[0])
Jmatrix[rx_ind, :] += cur[1].T

return df_duT
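
For context, a minimal sketch of the projection-slicing pattern used in block_deriv: the receiver projection matrix has one column block per time step, and the block for the current step is transposed before being passed to the field derivative in adjoint mode. Toy shapes and a random sparse stand-in for rx.getP, not the SimPEG API.

import numpy as np
import scipy.sparse as sp

n_data, field_len, n_times = 4, 6, 3
P = sp.random(n_data, field_len * n_times, density=0.2, format="csr")  # stand-in for rx.getP(...)

time_index = 1
# Columns belonging to one time step, transposed so rows index the field
# unknowns at that step and columns index the receiver data.
PTv = P[:, time_index * field_len : (time_index + 1) * field_len].T
print(PTv.shape)  # (field_len, n_data)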

@@ -192,17 +188,24 @@ def compute_field_derivs(simulation, Jmatrix, fields):
df_duT = []

for time_index in range(simulation.nT + 1):
df_duT.append(
delayed(block_deriv, pure=True)(
time_index,
simulation._fieldType + "Solution",
simulation.survey.source_list,
simulation.mesh,
simulation.time_mesh,
fields,
Jmatrix,
rx_count = 0
sources_block = []
for source in simulation.survey.source_list:
sources_block.append(
block_deriv(
time_index,
simulation._fieldType + "Solution",
source,
rx_count,
simulation.mesh,
simulation.time_mesh,
fields,
Jmatrix,
)
)
)
rx_count += source.nD

df_duT.append(sources_block)

df_duT = dask.compute(df_duT)[0]
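
The loop above creates one delayed task per source and time step and evaluates them in a single dask.compute call. A minimal, self-contained sketch of that scheduling pattern (toy task function and data, not the SimPEG simulation objects):

import dask
import numpy as np
from dask import delayed

@delayed
def toy_block_deriv(time_index, source_id, rx_count, n_rx):
    # Stand-in for block_deriv: one derivative block per receiver of this source.
    return [np.full(n_rx, time_index + source_id + rx_count) for _ in range(2)]

n_times = 3
source_nD = [4, 2, 5]  # data count per source

tasks = []
for time_index in range(n_times):
    rx_count = 0
    sources_block = []
    for s_id, n_rx in enumerate(source_nD):
        sources_block.append(toy_block_deriv(time_index, s_id, rx_count, n_rx))
        rx_count += n_rx  # running row offset into the sensitivity matrix
    tasks.append(sources_block)

df_duT = dask.compute(tasks)[0]  # one call evaluates the whole nested list of tasks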

@@ -221,8 +224,8 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int):
row_index = 0
block_count = 0
blocks = {0: {}}
for src in source_list:
for rx in src.receiver_list:
for s_id, src in enumerate(source_list):
for r_id, rx in enumerate(src.receiver_list):
indices = np.arange(rx.nD).astype(int)
chunks = np.split(indices, int(np.ceil(len(indices) / data_block_size)))

@@ -235,7 +238,7 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int):
block_count += 1
blocks[block_count] = {}

blocks[block_count][(src, rx, ind)] = chunk, np.arange(
blocks[block_count][(s_id, r_id, ind)] = chunk, np.arange(
row_index, row_index + chunk_size
).astype(int)
row_index += chunk_size
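
A toy sketch of the chunking above: the data indices of each receiver are split into chunks whose row count is capped so that one chunk of the m_size-column sensitivity stays near max_chunk_size entries. The data_block_size formula is an assumption here (its definition sits outside the visible hunk), and np.array_split is used so uneven splits are tolerated.

import numpy as np

m_size = 1000          # model parameters (columns of J)
max_chunk_size = 2500  # rough target for J entries per chunk
n_data = 13            # data count for one receiver

data_block_size = int(np.ceil(max_chunk_size / m_size))  # assumed definition
indices = np.arange(n_data).astype(int)
chunks = np.array_split(indices, int(np.ceil(len(indices) / data_block_size)))

row_index = 0
for ind, chunk in enumerate(chunks):
    j_ind = np.arange(row_index, row_index + len(chunk)).astype(int)
    row_index += len(chunk)
    print(ind, chunk, j_ind)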
@@ -253,89 +256,103 @@ def get_field_deriv_block(
stacked_blocks = []
indices = {}
count = 0
for (src, rx, ind), (rx_ind, j_ind) in block.items():
for (s_id, r_id, b_id), (rx_ind, j_ind) in block.items():
# Cut out early data
rx = simulation.survey.source_list[s_id].receiver_list[r_id]
time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[
rx_ind
]
sub_ind = rx_ind[time_check]
local_ind = np.arange(rx_ind.shape[0])[time_check]

if len(sub_ind) < 1:
continue

indices[(src, rx, ind)] = (np.arange(count, count + len(sub_ind)), sub_ind)
indices[(s_id, r_id, b_id)] = (
np.arange(count, count + len(sub_ind)),
local_ind,
)
count += len(sub_ind)

if (src, rx, ind) not in ATinv_df_duT_v:
if (s_id, r_id, b_id) not in ATinv_df_duT_v:
# last timestep (first to be solved)
stacked_blocks.append(
simulation.field_derivs[tInd + 1][src][rx].toarray()[:, sub_ind]
simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[:, sub_ind]
)

else:
Asubdiag = simulation.getAsubdiag(tInd + 1)
stacked_blocks.append(
np.asarray(
simulation.field_derivs[tInd + 1][src][rx][:, sub_ind]
- Asubdiag.T * ATinv_df_duT_v[(src, rx, ind)][:, sub_ind]
simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind]
- Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind]
)
)

if len(stacked_blocks) > 1:
if len(stacked_blocks) > 0:
solve = AdiagTinv * np.hstack(stacked_blocks)

for src, rx, ind in block:
ATinv_df_duT_v[(src, rx, ind)] = np.zeros(
for s_id, r_id, b_id in block:
ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros(
(
simulation.field_derivs[tInd][src][rx].shape[0],
len(block[(src, rx, ind)][0]),
simulation.field_derivs[tInd][s_id][r_id].shape[0],
len(block[(s_id, r_id, b_id)][0]),
)
)

if (src, rx, ind) in indices:
columns, sub_ind = indices[(src, rx, ind)]
ATinv_df_duT_v[(src, rx, ind)][:, sub_ind] = solve[:, columns]
if (s_id, r_id, b_id) in indices:
columns, local_ind = indices[(s_id, r_id, b_id)]
ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] = solve[:, columns]

return ATinv_df_duT_v
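
For orientation, a toy sketch of the backward-in-time adjoint recursion that get_field_deriv_block performs: at each step the stacked right-hand sides are the stored field derivatives minus the coupling through Asubdiag to the step already solved, and the result comes from a transposed diagonal-block solve. Dense numpy stand-ins, not the SimPEG operators or the cached direct solver.

import numpy as np

rng = np.random.default_rng(0)
n_fields, n_cols, n_times = 5, 3, 4

Adiag = [np.eye(n_fields) + 0.1 * rng.standard_normal((n_fields, n_fields))
         for _ in range(n_times)]
Asubdiag = [0.1 * rng.standard_normal((n_fields, n_fields))
            for _ in range(n_times + 1)]
field_derivs = [rng.standard_normal((n_fields, n_cols))
                for _ in range(n_times + 1)]

v = np.zeros((n_fields, n_cols))  # plays the role of ATinv_df_duT_v
for tInd in reversed(range(n_times)):
    # Stored receiver derivative at tInd + 1 minus the coupling to the
    # already-solved later step (zero on the first, i.e. last-time, pass).
    rhs = field_derivs[tInd + 1] - Asubdiag[tInd + 1].T @ v
    # Stands in for "AdiagTinv * np.hstack(stacked_blocks)".
    v = np.linalg.solve(Adiag[tInd].T, rhs)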


@delayed
def compute_rows(
simulation,
tInd,
src,
rx,
rx_ind,
j_ind,
address, # (s_id, r_id, b_id)
indices, # (rx_ind, j_ind)
ATinv_df_duT_v,
f,
fields,
Jmatrix,
ftype,
time_mask,
):
"""
Compute the rows of the sensitivity matrix for a given source and receiver.
"""
time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[rx_ind]
sub_ind = rx_ind[time_check]

if len(sub_ind) < 1:
src = simulation.survey.source_list[address[0]]
rx = src.receiver_list[address[1]]
time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[
indices[0]
]
local_ind = np.arange(indices[0].shape[0])[time_check]

if len(local_ind) < 1:
return

dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
tInd, f[src, ftype, tInd], ATinv_df_duT_v[:, sub_ind], adjoint=True
tInd,
fields[src, ftype, tInd],
ATinv_df_duT_v[address][:, local_ind],
adjoint=True,
)

dRHST_dm_v = simulation.getRHSDeriv(
tInd + 1, src, ATinv_df_duT_v[:, sub_ind], adjoint=True
tInd + 1, src, ATinv_df_duT_v[address][:, local_ind], adjoint=True
) # on nodes of time mesh

un_src = f[src, ftype, tInd + 1]
un_src = fields[src, ftype, tInd + 1]
# cell centered on time mesh
dAT_dm_v = simulation.getAdiagDeriv(
tInd, un_src, ATinv_df_duT_v[:, sub_ind], adjoint=True
tInd, un_src, ATinv_df_duT_v[address][:, local_ind], adjoint=True
)

Jmatrix[j_ind[time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T
Jmatrix[indices[1][time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T
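
A toy sketch of the final row update in compute_rows: the three adjoint derivative terms are combined and written into the global J rows owned by this chunk. Dense random stand-ins replace the SimPEG getAdiagDeriv / getAsubdiagDeriv / getRHSDeriv calls.

import numpy as np

rng = np.random.default_rng(1)
n_rows, m_size = 3, 7
Jmatrix = np.zeros((10, m_size))
j_ind = np.array([2, 5, 6])  # global J rows handled by this chunk

dAT_dm_v = rng.standard_normal((m_size, n_rows))         # stand-in for getAdiagDeriv(..., adjoint=True)
dAsubdiagT_dm_v = rng.standard_normal((m_size, n_rows))  # stand-in for getAsubdiagDeriv(..., adjoint=True)
dRHST_dm_v = rng.standard_normal((m_size, n_rows))       # stand-in for getRHSDeriv(..., adjoint=True)

Jmatrix[j_ind, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T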


def compute_J(self, f=None, Ainv=None):
@@ -367,16 +384,14 @@ def compute_J(self, f=None, Ainv=None):
self, block, tInd, AdiagTinv, ATinv_df_duT_v, time_mask
)

for (src, rx, ind), (rx_ind, j_ind) in block.items():
for address, indices in block.items():
j_row_updates.append(
delayed(compute_rows, pure=True)(
compute_rows(
self,
tInd,
src,
rx,
rx_ind,
j_ind,
ATinv_df_duT_v[(src, rx, ind)],
address,
indices,
ATinv_df_duT_v,
f,
Jmatrix,
ftype,
