diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index fe9b7cca08..2039fce1f3 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -156,26 +156,30 @@ def evaluate_receiver(source, receiver, mesh, time_mesh, fields): @delayed -def block_deriv( - time_index, field_type, source, rx_count, mesh, time_mesh, fields, Jmatrix -): +def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jmatrix): """Compute derivatives for sources and receivers in a block""" - field_len = len(fields[source, field_type, 0]) + field_len = len(fields[source_list[0], field_type, 0]) df_duT = [] + rx_count = 0 + for source in source_list: + sources_block = [] - for rx in source.receiver_list: - PTv = rx.getP(mesh, time_mesh, fields).tocsr() - derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) - rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) - cur = derivative_fun( - time_index, - source, - None, - PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, - adjoint=True, - ) - df_duT.append(cur[0]) - Jmatrix[rx_ind, :] += cur[1].T + for rx in source.receiver_list: + PTv = rx.getP(mesh, time_mesh, fields).tocsr() + derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) + rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) + cur = derivative_fun( + time_index, + source, + None, + PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, + adjoint=True, + ) + sources_block.append(cur[0]) + Jmatrix[rx_ind, :] += cur[1].T + rx_count += rx.nD + + df_duT.append(sources_block) return df_duT @@ -184,28 +188,20 @@ def compute_field_derivs(simulation, Jmatrix, fields): """ Compute the derivative of the fields """ - df_duT = [] for time_index in range(simulation.nT + 1): - rx_count = 0 - sources_block = [] - for source in simulation.survey.source_list: - sources_block.append( - block_deriv( - time_index, - simulation._fieldType + "Solution", - source, - rx_count, - simulation.mesh, - simulation.time_mesh, - fields, - Jmatrix, - ) + df_duT.append( + block_deriv( + time_index, + simulation._fieldType + "Solution", + simulation.survey.source_list, + simulation.mesh, + simulation.time_mesh, + fields, + Jmatrix, ) - rx_count += source.nD - - df_duT.append(sources_block) + ) df_duT = dask.compute(df_duT)[0] @@ -247,6 +243,30 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): return blocks +@delayed +def deriv_block( + s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, simulation, tInd +): + if (s_id, r_id, b_id) not in ATinv_df_duT_v: + # last timestep (first to be solved) + stacked_block = simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[ + :, sub_ind + ] + + else: + stacked_block = np.asarray( + simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind] + - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] + ) + + return stacked_block + + +def update_deriv_blocks(address, indices, derivatives, solve): + columns, local_ind = indices[address] + derivatives[:, local_ind] = solve[:, columns] + + def get_field_deriv_block( simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict, time_mask ): @@ -256,6 +276,11 @@ def get_field_deriv_block( stacked_blocks = [] indices = {} count = 0 + + Asubdiag = None + if tInd < simulation.nT - 1: + Asubdiag = simulation.getAsubdiag(tInd + 1) + for (s_id, r_id, b_id), (rx_ind, j_ind) in block.items(): # Cut out early data rx = simulation.survey.source_list[s_id].receiver_list[r_id] @@ -274,38 +299,44 @@ def get_field_deriv_block( ) count += len(sub_ind) - if (s_id, r_id, b_id) not in ATinv_df_duT_v: - # last timestep (first to be solved) - stacked_blocks.append( - simulation.field_derivs[tInd + 1][s_id][r_id].toarray()[:, sub_ind] - ) - - else: - Asubdiag = simulation.getAsubdiag(tInd + 1) - stacked_blocks.append( - np.asarray( - simulation.field_derivs[tInd + 1][s_id][r_id][:, sub_ind] - - Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] - ) + stacked_blocks.append( + deriv_block( + s_id, + r_id, + b_id, + ATinv_df_duT_v, + Asubdiag, + local_ind, + sub_ind, + simulation, + tInd, ) + ) if len(stacked_blocks) > 0: - solve = AdiagTinv * np.hstack(stacked_blocks) + solve = AdiagTinv * np.hstack(dask.compute(stacked_blocks)[0]) + update_list = [] for s_id, r_id, b_id in block: - ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros( - ( - simulation.field_derivs[tInd][s_id][r_id].shape[0], - len(block[(s_id, r_id, b_id)][0]), + if (s_id, r_id, b_id) not in ATinv_df_duT_v: + ATinv_df_duT_v[(s_id, r_id, b_id)] = np.zeros( + ( + simulation.field_derivs[tInd][s_id][r_id].shape[0], + len(block[(s_id, r_id, b_id)][0]), + ) ) - ) if (s_id, r_id, b_id) in indices: - try: - columns, local_ind = indices[(s_id, r_id, b_id)] - ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind] = solve[:, columns] - except: - print("ouch") + update_list.append( + update_deriv_blocks( + (s_id, r_id, b_id), + indices, + ATinv_df_duT_v[(s_id, r_id, b_id)], + solve, + ) + ) + + dask.compute(update_list) return ATinv_df_duT_v @@ -370,9 +401,7 @@ def compute_J(self, f=None, Ainv=None): blocks = get_parallel_blocks( self.survey.source_list, self.model.shape[0], self.max_chunk_size ) - self.field_derivs = compute_field_derivs(self, Jmatrix, f) - ATinv_df_duT_v = {} for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): AdiagTinv = Ainv[dt]