diff --git a/SimPEG/dask/electromagnetics/time_domain/simulation.py b/SimPEG/dask/electromagnetics/time_domain/simulation.py index 359ca13167..997908bebf 100644 --- a/SimPEG/dask/electromagnetics/time_domain/simulation.py +++ b/SimPEG/dask/electromagnetics/time_domain/simulation.py @@ -165,9 +165,7 @@ def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jm df_duT[src] = {rx: {} for rx in src.receiver_list} for rx in src.receiver_list: - PTv = np.asarray( - rx.getP(mesh, time_mesh, fields).todense().T - ).reshape((field_len, time_mesh.n_faces, -1), order="F") + PTv = rx.getP(mesh, time_mesh, fields).tocsr() derivative_fun = getattr(fields, "_{}Deriv".format(rx.projField), None) rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) @@ -175,7 +173,7 @@ def block_deriv(time_index, field_type, source_list, mesh, time_mesh, fields, Jm time_index, src, None, - sp.csr_matrix(PTv[:, time_index, :]), + PTv[:, (time_index * field_len) : ((time_index + 1) * field_len)].T, adjoint=True, ) df_duT[src][rx] = cur[0] @@ -194,21 +192,23 @@ def compute_field_derivs(simulation, Jmatrix, fields): df_duT = [] for time_index in range(simulation.nT + 1): - df_duT.append(delayed(block_deriv, pure=True)( - time_index, - simulation._fieldType + "Solution", - simulation.survey.source_list, - simulation.mesh, - simulation.time_mesh, - fields, - Jmatrix - )) - + df_duT.append( + delayed(block_deriv, pure=True)( + time_index, + simulation._fieldType + "Solution", + simulation.survey.source_list, + simulation.mesh, + simulation.time_mesh, + fields, + Jmatrix, + ) + ) df_duT = dask.compute(df_duT)[0] return df_duT + def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): """ Get the blocks of sources and receivers to be computed in parallel. @@ -223,9 +223,8 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): blocks = {0: {}} for src in source_list: for rx in src.receiver_list: - indices = np.arange(rx.nD).astype(int) - chunks = np.split(indices, int(np.ceil(len(indices)/data_block_size))) + chunks = np.split(indices, int(np.ceil(len(indices) / data_block_size))) for ind, chunk in enumerate(chunks): chunk_size = len(chunk) @@ -236,68 +235,107 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int): block_count += 1 blocks[block_count] = {} - blocks[block_count][(src, rx, ind)] = chunk, np.arange(row_index, row_index + chunk_size).astype(int) + blocks[block_count][(src, rx, ind)] = chunk, np.arange( + row_index, row_index + chunk_size + ).astype(int) row_index += chunk_size row_count += chunk_size return blocks -def get_field_deriv_block(simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict): +def get_field_deriv_block( + simulation, block: dict, tInd: int, AdiagTinv, ATinv_df_duT_v: dict, time_mask +): """ Stack the blocks of field derivatives for a given timestep and call the direct solver. """ stacked_blocks = [] - indices = [] + indices = {} count = 0 for (src, rx, ind), (rx_ind, j_ind) in block.items(): - indices.append( - np.arange(count, count + len(rx_ind)) - ) - count += len(rx_ind) + # Cut out early data + time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[ + rx_ind + ] + sub_ind = rx_ind[time_check] + + if len(sub_ind) < 1: + continue + + indices[(src, rx, ind)] = (np.arange(count, count + len(sub_ind)), sub_ind) + count += len(sub_ind) + if (src, rx, ind) not in ATinv_df_duT_v: # last timestep (first to be solved) stacked_blocks.append( - simulation.field_derivs[tInd + 1][src][rx].toarray()[:, rx_ind] + simulation.field_derivs[tInd + 1][src][rx].toarray()[:, sub_ind] ) else: Asubdiag = simulation.getAsubdiag(tInd + 1) stacked_blocks.append( np.asarray( - simulation.field_derivs[tInd + 1][src][rx][:, rx_ind] - - Asubdiag.T * ATinv_df_duT_v[(src, rx, ind)] + simulation.field_derivs[tInd + 1][src][rx][:, sub_ind] + - Asubdiag.T * ATinv_df_duT_v[(src, rx, ind)][:, sub_ind] ) ) - solve = AdiagTinv * np.hstack(stacked_blocks) + if len(stacked_blocks) > 1: + solve = AdiagTinv * np.hstack(stacked_blocks) + + for src, rx, ind in block: + ATinv_df_duT_v[(src, rx, ind)] = np.zeros( + ( + simulation.field_derivs[tInd][src][rx].shape[0], + len(block[(src, rx, ind)][0]), + ) + ) - for (src, rx, ind), columns in zip(block, indices): - ATinv_df_duT_v[(src, rx, ind)] = solve[:, columns] + if (src, rx, ind) in indices: + columns, sub_ind = indices[(src, rx, ind)] + ATinv_df_duT_v[(src, rx, ind)][:, sub_ind] = solve[:, columns] return ATinv_df_duT_v -def compute_rows(simulation, tInd, src, rx_ind, j_ind, ATinv_df_duT_v, f, Jmatrix, ftype): +def compute_rows( + simulation, + tInd, + src, + rx, + rx_ind, + j_ind, + ATinv_df_duT_v, + f, + Jmatrix, + ftype, + time_mask, +): """ Compute the rows of the sensitivity matrix for a given source and receiver. """ + time_check = np.kron(time_mask, np.ones(rx.locations.shape[0], dtype=bool))[rx_ind] + sub_ind = rx_ind[time_check] + + if len(sub_ind) < 1: + return + dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, f[src, ftype, tInd], ATinv_df_duT_v, adjoint=True + tInd, f[src, ftype, tInd], ATinv_df_duT_v[:, sub_ind], adjoint=True ) dRHST_dm_v = simulation.getRHSDeriv( - tInd + 1, src, ATinv_df_duT_v, adjoint=True + tInd + 1, src, ATinv_df_duT_v[:, sub_ind], adjoint=True ) # on nodes of time mesh un_src = f[src, ftype, tInd + 1] # cell centered on time mesh dAT_dm_v = simulation.getAdiagDeriv( - tInd, un_src, ATinv_df_duT_v, adjoint=True + tInd, un_src, ATinv_df_duT_v[:, sub_ind], adjoint=True ) - Jmatrix[j_ind, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T - + Jmatrix[j_ind[time_check], :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T def compute_J(self, f=None, Ainv=None): @@ -310,7 +348,11 @@ def compute_J(self, f=None, Ainv=None): ftype = self._fieldType + "Solution" Jmatrix = np.zeros((self.survey.nD, self.model.size), dtype=np.float32) - blocks = get_parallel_blocks(self.survey.source_list, self.model.shape[0], self.max_chunk_size) + simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 + data_times = self.survey.source_list[0].receiver_list[0].times + blocks = get_parallel_blocks( + self.survey.source_list, self.model.shape[0], self.max_chunk_size + ) self.field_derivs = compute_field_derivs(self, Jmatrix, f) @@ -318,70 +360,32 @@ def compute_J(self, f=None, Ainv=None): for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): AdiagTinv = Ainv[dt] j_row_updates = [] + time_mask = data_times > simulation_times[tInd] for block in blocks.values(): - ATinv_df_duT_v = get_field_deriv_block(self, block, tInd, AdiagTinv, ATinv_df_duT_v) + ATinv_df_duT_v = get_field_deriv_block( + self, block, tInd, AdiagTinv, ATinv_df_duT_v, time_mask + ) for (src, rx, ind), (rx_ind, j_ind) in block.items(): - j_row_updates.append(delayed(compute_rows, pure=True)( - self, - tInd, - src, - rx_ind, - j_ind, - ATinv_df_duT_v[(src, rx, ind)], - f, - Jmatrix, - ftype - )) - # for (src, rx, ind), (rx_ind, j_ind) in block.items(): - + j_row_updates.append( + delayed(compute_rows, pure=True)( + self, + tInd, + src, + rx, + rx_ind, + j_ind, + ATinv_df_duT_v[(src, rx, ind)], + f, + Jmatrix, + ftype, + time_mask, + ) + ) dask.compute(j_row_updates) - # rx_count = 0 - # for isrc, src in enumerate(self.survey.source_list): - # - # if isrc not in ATinv_df_duT_v: - # ATinv_df_duT_v[isrc] = {} - # - # for rx in src.receiver_list: - # if rx not in ATinv_df_duT_v[isrc]: - # ATinv_df_duT_v[isrc][rx] = {} - # - # rx_ind = np.arange(rx_count, rx_count + rx.nD).astype(int) - # # solve against df_duT_v - # if tInd >= self.nT - 1: - # # last timestep (first to be solved) - # ATinv_df_duT_v[isrc][rx] = ( - # AdiagTinv - # * self.field_derivs[tInd+1][src][rx].toarray() - # ) - # elif tInd > -1: - # ATinv_df_duT_v[isrc][rx] = AdiagTinv * np.asarray( - # self.field_derivs[tInd+1][src][rx] - # - Asubdiag.T * ATinv_df_duT_v[isrc][rx] - # ) - # - # dAsubdiagT_dm_v = self.getAsubdiagDeriv( - # tInd, f[src, ftype, tInd], ATinv_df_duT_v[isrc][rx], adjoint=True - # ) - # - # dRHST_dm_v = self.getRHSDeriv( - # tInd + 1, src, ATinv_df_duT_v[isrc][rx], adjoint=True - # ) # on nodes of time mesh - # - # un_src = f[src, ftype, tInd + 1] - # # cell centered on time mesh - # dAT_dm_v = self.getAdiagDeriv( - # tInd, un_src, ATinv_df_duT_v[isrc][rx], adjoint=True - # ) - # - # Jmatrix[rx_ind, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T - - # rx_count += rx.nD - - for A in Ainv.values(): A.clean() @@ -392,265 +396,4 @@ def compute_J(self, f=None, Ainv=None): return Jmatrix - # if f is None: - # f, Ainv = self.fields(self.model, return_Ainv=True) - # - # m_size = self.model.size - # row_chunks = int( - # np.ceil( - # float(self.survey.nD) - # / np.ceil(float(m_size) * self.survey.nD * 8.0 * 1e-6 / self.max_chunk_size) - # ) - # ) - # - # if self.store_sensitivities == "disk": - # self.J_initializer = zarr.open( - # self.sensitivity_path + f"J_initializer.zarr", - # mode="w", - # shape=(self.survey.nD, m_size), - # chunks=(row_chunks, m_size), - # ) - # else: - # self.J_initializer = np.zeros((self.survey.nD, m_size), dtype=np.float32) - # solution_type = self._fieldType + "Solution" # the thing we solved for - # - # if self.field_derivs is None: - # # print("Start loop for field derivs") - # block_size = len(f[self.survey.source_list[0], solution_type, 0]) - # - # field_derivs = [] - # for tInd in range(self.nT + 1): - # d_count = 0 - # df_duT_v = [] - # for i_s, src in enumerate(self.survey.source_list): - # src_field_derivs = delayed(block_deriv, pure=True)( - # self, src, tInd, f, block_size, d_count - # ) - # df_duT_v += [src_field_derivs] - # d_count += np.sum([rx.nD for rx in src.receiver_list]) - # - # field_derivs += [df_duT_v] - # # print("Dask loop field derivs") - # # tc = time() - # - # self.field_derivs = dask.compute(field_derivs)[0] - # # print(f"Done in {time() - tc} seconds") - # - # if self.store_sensitivities == "disk": - # Jmatrix = ( - # zarr.open( - # self.sensitivity_path + f"J.zarr", - # mode="w", - # shape=(self.survey.nD, m_size), - # chunks=(row_chunks, m_size), - # ) - # + self.J_initializer - # ) - # else: - # Jmatrix = dask.delayed( - # np.zeros((self.survey.nD, m_size), dtype=np.float32) + self.J_initializer - # ) - # - # f = dask.delayed(f) - # field_derivs_t = {} - # d_block_size = np.ceil(128.0 / (m_size * 8.0 * 1e-6)) - # - # # Check which time steps we need to compute - # simulation_times = np.r_[0, np.cumsum(self.time_steps)] + self.t0 - # data_times = self.survey.source_list[0].receiver_list[0].times - # n_times = len(data_times) - # - # for tInd, dt in tqdm(zip(reversed(range(self.nT)), reversed(self.time_steps))): - # AdiagTinv = Ainv[dt] - # Asubdiag = self.getAsubdiag(tInd) - # row_count = 0 - # row_blocks = [] - # field_derivs = {} - # source_blocks = [] - # d_count = 0 - # - # data_bool = data_times > simulation_times[tInd] - # - # if data_bool.sum() == 0: - # continue - # - # # tc_loop = time() - # # print(f"Loop sources for {tInd}") - # for isrc, src in enumerate(self.survey.source_list): - # - # column_inds = np.hstack([ - # np.kron(np.ones(rec.locations.shape[0], dtype=bool), data_bool - # ) for rec in src.receiver_list]) - # - # if isrc not in field_derivs_t: - # field_derivs[(isrc, src)] = self.field_derivs[tInd + 1][isrc].toarray()[ - # :, column_inds - # ] - # else: - # field_derivs[(isrc, src)] = field_derivs_t[isrc][:, column_inds] - # - # d_count += column_inds.sum() - # - # if d_count > d_block_size: - # source_blocks = block_append( - # self, - # f, - # AdiagTinv, - # field_derivs, - # m_size, - # row_count, - # tInd, - # solution_type, - # Jmatrix, - # Asubdiag, - # source_blocks, - # data_bool, - # ) - # field_derivs = {} - # row_count = d_count - # d_count = 0 - # - # if field_derivs: - # source_blocks = block_append( - # self, - # f, - # AdiagTinv, - # field_derivs, - # m_size, - # row_count, - # tInd, - # solution_type, - # Jmatrix, - # Asubdiag, - # source_blocks, - # data_bool, - # ) - # - # # print(f"Done in {time() - tc_loop} seconds") - # # tc = time() - # # print(f"Compute field derivs for {tInd}") - # del field_derivs_t - # field_derivs_t = { - # isrc: elem for isrc, elem in enumerate(dask.compute(source_blocks)[0]) - # } - # # print(f"Done in {time() - tc} seconds") - # - # for A in Ainv.values(): - # A.clean() - # - # if self.store_sensitivities == "disk": - # del Jmatrix - # return array.from_zarr(self.sensitivity_path + f"J.zarr") - # else: - # return Jmatrix.compute() - - Sim.compute_J = compute_J - - -def block_append( - simulation, - fields, - AdiagTinv, - field_derivs, - m_size, - row_count, - tInd, - solution_type, - Jmatrix, - Asubdiag, - source_blocks, - data_bool, -): - solves = AdiagTinv * np.hstack(list(field_derivs.values())) - count = 0 - - for (isrc, src), block in field_derivs.items(): - - column_inds = np.hstack([ - np.kron(np.ones(rec.locations.shape[0], dtype=bool), data_bool - ) for rec in src.receiver_list]) - - n_rows = column_inds.sum() - source_blocks.append( - dask.array.from_delayed( - delayed(parallel_block_compute, pure=True)( - simulation, - fields, - src, - solves[:, count : count + n_rows], - row_count, - tInd, - solution_type, - Jmatrix, - Asubdiag, - simulation.field_derivs[tInd][isrc], - column_inds, - ), - shape=simulation.field_derivs[tInd + 1][isrc].shape, - dtype=np.float32, - ) - ) - count += n_rows - # print(f"Appending block {isrc} in {time() - tc} seconds") - row_count += len(column_inds) - - return source_blocks - - -# def block_deriv(simulation, src, tInd, f, block_size, row_count): -# src_field_derivs = None -# for rx in src.receiver_list: -# v = sp.eye(rx.nD, dtype=float) -# PT_v = rx.evalDeriv( -# src, simulation.mesh, simulation.time_mesh, f, v, adjoint=True -# ) -# df_duTFun = getattr(f, "_{}Deriv".format(rx.projField), None) -# -# cur = df_duTFun( -# simulation.nT, -# src, -# None, -# PT_v[tInd * block_size: (tInd + 1) * block_size, :], -# adjoint=True, -# ) -# -# if not isinstance(cur[1], Zero): -# simulation.J_initializer[row_count: row_count + rx.nD, :] += cur[1].T -# -# if src_field_derivs is None: -# src_field_derivs = cur[0] -# else: -# src_field_derivs = sp.hstack([src_field_derivs, cur[0]]) -# -# row_count += rx.nD -# -# return src_field_derivs - - -def parallel_block_compute( - simulation, - f, - src, - ATinv_df_duT_v, - row_count, - tInd, - solution_type, - Jmatrix, - Asubdiag, - field_derivs, - data_bool, -): - rows = row_count + np.where(data_bool)[0] - field_derivs_t = np.asarray(field_derivs.todense()) - field_derivs_t[:, data_bool] -= Asubdiag.T * ATinv_df_duT_v - - dAsubdiagT_dm_v = simulation.getAsubdiagDeriv( - tInd, f[src, solution_type, tInd], ATinv_df_duT_v, adjoint=True - ) - dRHST_dm_v = simulation.getRHSDeriv(tInd + 1, src, ATinv_df_duT_v, adjoint=True) - un_src = f[src, solution_type, tInd + 1] - dAT_dm_v = simulation.getAdiagDeriv(tInd, un_src, ATinv_df_duT_v, adjoint=True) - Jmatrix[rows, :] += (-dAT_dm_v - dAsubdiagT_dm_v + dRHST_dm_v).T - - return field_derivs_t