From 2061306ba14debb5c82e5bc0c8ad5ef229cc7e58 Mon Sep 17 00:00:00 2001 From: fourndo Date: Wed, 29 Nov 2023 10:45:37 -0800 Subject: [PATCH] More optimization --- .../time_domain/simulation.py | 157 ++++++++++-------- 1 file changed, 86 insertions(+), 71 deletions(-) diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 997908bebf..fe9b7cca08 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -155,31 +155,27 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): Sim.field_derivs = None -def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jmatrix): +@delayed +def block_deriv( + time_index, field_type, source, rx_count, mesh, time_mesh, fields, Jmatrix +): """Compute derivatives for sources and receivers in a block""" - field_len = len(fields[source_list[0], field_type, 0]) - df_duT = {src: {} for src in source_list} - - rx_count = 0 - for src in source_list: - df_duT[src] = {rx: {} for rx in src.receiver_list} - - for rx in src.receiver_list: - PTv = rx.getP(mesh, time_mesh, fields).tocsr() - derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) - rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) - - cur = derivative_fun( - time_index, - src, - None, - PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, - adjoint=True, - ) - df_duT[src][rx] = cur[0] - Jmatrix[rx_ind, :] += cur[1].T + field_len = len(fields[source, field_type, 0]) + df_duT = [] - rx_count += rx.nD + for rx in source.receiver_list: + PTv = rx.getP(mesh, time_mesh, fields).tocsr() + derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) + rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) + cur = derivative_fun( + time_index, + source, + None, + PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, + adjoint=True, + ) + df_duT.append(cur[0]) + Jmatrix[rx_ind, :] += cur[1].T return df_duT @@ -192,17 +188,24 @@ def compute_field_derivs(simulation, Jmatrix, fields): df_duT = [] for time_index in range(simulation.nT + 1): - df_duT.append( - delayed(block_deriv, pure=True)( - time_index, - simulation._fieldType + "Solution", - simulation.survey.source_list, - simulation.mesh, - simulation.time_mesh, - fields, - Jmatrix, + rx_count = 0 + sources_block = [] + for source in simulation.survey.source_list: + sources_block.append( + block_deriv( + time_index, + simulation._fieldType + "Solution", + source, + rx_count, + simulation.mesh, + simulation.time_mesh, + fields, + Jmatrix, + ) ) - ) + rx_count += source.nD + + df_duT.append(sources_block) df_duT = dask.compute(df_duT)[0] @@ -221,8 +224,8 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): row_index = 0 block_count = 0 blocks = {0: {}} - for src in source_list: - for rx in src.receiver_list: + for s_id, src in enumerate(source_list): + for r_id, rx in enumerate(src.receiver_list): indices = np.arange(rx.nD).astype(int) chunks = np.split(indices, int(np.ceil(len(indices) / data_block_size))) @@ -235,7 +238,7 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): block_count += 1 blocks[block_count] = {} - blocks[block_count][(src, rx, ind)] = chunk, np.arange( + blocks[block_count][(s_id, r_id, ind)] = chunk, np.arange( row_index, row_index + chunk_size ).astype(int) row_index += chunk_size @@ -253,61 +256,68 @@ def get_field_deriv_block( stacked_blocks = [] indices = {} count = 0 - for (src, rx, ind), (rx_ind, j_ind) in block.items(): + for (s_id, r_id, b_id), (rx_ind, j_ind) in block.items(): # Cut out early data + rx = simulation.survey.source_list[s_id].receiver_list[r_id] time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ rx_ind ] sub_ind = rx_ind[time_check] + local_ind = np.arange(rx_ind.shape[0])[time_check] if len(sub_ind) < 1: continue - indices[(src, rx, ind)] = (np.arange(count, count + len(sub_ind)), sub_ind) + indices[(s_id, r_id, b_id)] = ( + np.arange(count, count + len(sub_ind)), + local_ind, + ) count += len(sub_ind) - if (src, rx, ind) not in ATinv_df_duT_v: + if (s_id, r_id, b_id) not in ATinv_df_duT_v: # last timestep (first to be solved) stacked_blocks.append( - simulation.field_derivs[tInd + 1][src][rx].toarray()[:, sub_ind] + simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[:, sub_ind] ) else: Asubdiag = simulation.getAsubdiag(tInd + 1) stacked_blocks.append( np.asarray( - simulation.field_derivs[tInd + 1][src][rx][:, sub_ind] - - Asubdiag.T * ATinv_df_duT_v[(src, rx, ind)][:, sub_ind] + simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind] + - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] ) ) - if len(stacked_blocks) > 1: + if len(stacked_blocks) > 0: solve = AdiagTinv * np.hstack(stacked_blocks) - for src, rx, ind in block: - ATinv_df_duT_v[(src, rx, ind)] = np.zeros( + for s_id, r_id, b_id in block: + ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros( ( - simulation.field_derivs[tInd][src][rx].shape[0], - len(block[(src, rx, ind)][0]), + simulation.field_derivs[tInd][s_id][r_id].shape[0], + len(block[(s_id, r_id, b_id)][0]), ) ) - if (src, rx, ind) in indices: - columns, sub_ind = indices[(src, rx, ind)] - ATinv_df_duT_v[(src, rx, ind)][:, sub_ind] = solve[:, columns] + if (s_id, r_id, b_id) in indices: + try: + columns, local_ind = indices[(s_id, r_id, b_id)] + ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] = solve[:, columns] + except: + print("ouch") return ATinv_df_duT_v +@delayed def compute_rows( simulation, tInd, - src, - rx, - rx_ind, - j_ind, + address, # (s_id, r_id, b_id) + indices, # (rx_ind, j_ind), ATinv_df_duT_v, - f, + fields, Jmatrix, ftype, time_mask, @@ -315,27 +325,34 @@ def compute_rows( """ Compute the rows of the sensitivity matrix for a given source and receiver. """ - time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[rx_ind] - sub_ind = rx_ind[time_check] - - if len(sub_ind) < 1: + src = simulation.survey.source_list[address[0]] + rx = src.receiver_list[address[1]] + time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ + indices[0] + ] + local_ind = np.arange(indices[0].shape[0])[time_check] + + if len(local_ind) < 1: return dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, f[src, ftype, tInd], ATinv_df_duT_v[:, sub_ind], adjoint=True + tInd, + fields[src, ftype, tInd], + ATinv_df_duT_v[address][:, local_ind], + adjoint=True, ) dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, ATinv_df_duT_v[:, sub_ind], adjoint=True + tInd + 1, src, ATinv_df_duT_v[address][:, local_ind], adjoint=True ) # on nodes of time mesh - un_src = f[src, ftype, tInd + 1] + un_src = fields[src, ftype, tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, ATinv_df_duT_v[:, sub_ind], adjoint=True + tInd, un_src, ATinv_df_duT_v[address][:, local_ind], adjoint=True ) - Jmatrix[j_ind[time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T + Jmatrix[indices[1][time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T def compute_J(self, f=None, Ainv=None): @@ -367,16 +384,14 @@ def compute_J(self, f=None, Ainv=None): self, block, tInd, AdiagTinv, ATinv_df_duT_v, time_mask ) - for (src, rx, ind), (rx_ind, j_ind) in block.items(): + for address, indices in block.items(): j_row_updates.append( - delayed(compute_rows, pure=True)( + compute_rows( self, tInd, - src, - rx, - rx_ind, - j_ind, - ATinv_df_duT_v[(src, rx, ind)], + address, + indices, + ATinv_df_duT_v, f, Jmatrix, ftype,