Skip to content

Commit

Permalink
Merge pull request #43 from MiraGeoscience/GEOPY-1408
Browse files Browse the repository at this point in the history
Geopy 1408
  • Loading branch information
andrewg-mira authored Apr 25, 2024
2 parents a9b8107 + ce44396 commit 6840d2f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
41 changes: 28 additions & 13 deletions SimPEG/dask/electromagnetics/time_domain/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from SimPEG.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag
from SimPEG.dask.utils import get_parallel_blocks
from SimPEG.utils import mkvc
import zarr
from time import time
from tqdm import tqdm
Expand Down Expand Up @@ -215,6 +216,10 @@ def dask_dpred(self, m=None, f=None, compute_J=False):
rows = []
receiver_projection = self.survey.source_list[0].receiver_list[0].projField
fields_array = f[:, receiver_projection, :]

if len(self.survey.source_list) == 1:
fields_array = fields_array[:, np.newaxis, :]

all_receivers = []

for ind, src in enumerate(self.survey.source_list):
Expand Down Expand Up @@ -275,15 +280,16 @@ def delayed_block_deriv(
j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32)
continue

projection = sp.kron(timeP[:, time_index], spatialP)
projection = sp.kron(timeP[:, time_index], spatialP, format="csr")
cur = derivative_fun(
time_index,
source,
None,
projection.T,
adjoint=True,
)
time_derivs.append(cur[0])

time_derivs.append(cur[0][:, arrays[0]])

if not isinstance(cur[1], Zero):
j_update += cur[1].T
Expand Down Expand Up @@ -344,15 +350,15 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape):

@delayed
def deriv_block(
s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, field_derivs, tInd
s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, field_derivs, tInd
):
if (s_id, r_id, b_id) not in ATinv_df_duT_v:
# last timestep (first to be solved)
stacked_block = field_derivs.toarray()[:, sub_ind]
stacked_block = field_derivs.toarray()[:, local_ind]

else:
stacked_block = np.asarray(
field_derivs[:, sub_ind]
field_derivs[:, local_ind]
- Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind]
)

Expand All @@ -367,7 +373,8 @@ def update_deriv_blocks(address, indices, derivatives, solve, shape):

if address in indices:
columns, local_ind = indices[address]
deriv_array[:, local_ind] = solve[:, columns]
if solve is not None:
deriv_array[:, local_ind] = solve[:, columns]

derivatives[address] = deriv_array

Expand Down Expand Up @@ -397,29 +404,26 @@ def get_field_deriv_block(
):
# Cut out early data
time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind]
sub_ind = rx_ind[time_check]
local_ind = np.arange(rx_ind.shape[0])[time_check]

if len(sub_ind) < 1:
if len(local_ind) < 1:
continue

indices[(s_id, r_id, b_id)] = (
np.arange(count, count + len(sub_ind)),
np.arange(count, count + len(local_ind)),
local_ind,
)
count += len(sub_ind)
count += len(local_ind)
deriv_comp = deriv_block(
s_id,
r_id,
b_id,
ATinv_df_duT_v,
Asubdiag,
local_ind,
sub_ind,
field_deriv,
tInd,
)

stacked_blocks.append(
array.from_delayed(
deriv_comp,
Expand Down Expand Up @@ -469,7 +473,11 @@ def compute_rows(
local_ind = np.arange(len(ind_array[0]))[time_check]

if len(local_ind) < 1:
return
row_block = np.zeros(
(len(ind_array[1]), simulation.model.size), dtype=np.float32
)
rows.append(row_block)
continue

field_derivs = ATinv_df_duT_v[address]
dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
Expand Down Expand Up @@ -530,6 +538,10 @@ def compute_J(self, f=None, Ainv=None):
self.survey.source_list, self.model.shape[0], self.max_chunk_size
)
fields_array = f[:, ftype, :]

if len(self.survey.source_list) == 1:
fields_array = fields_array[:, np.newaxis, :]

times_field_derivs, Jmatrix = compute_field_derivs(
self, f, blocks, Jmatrix, fields_array.shape
)
Expand All @@ -540,6 +552,9 @@ def compute_J(self, f=None, Ainv=None):
j_row_updates = []
time_mask = data_times > simulation_times[tInd]

if not np.any(time_mask):
continue

tc = time()
for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]):
ATinv_df_duT_v = get_field_deriv_block(
Expand Down
2 changes: 1 addition & 1 deletion SimPEG/dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int) ->
if (row_count + chunk_size) > (data_block_size * cpu_count()):
row_count = 0
block_count += 1
blocks.append = []
blocks.append([])

blocks[block_count].append(
(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

[tool.poetry]
name = "Mira-SimPEG"
version = "0.19.0.dev7"
version = "0.19.0.dev8"
license = "MIT"
description = "Mira Geoscience fork of SimPEG: Simulation and Parameter Estimation in Geophysics"

Expand Down

0 comments on commit 6840d2f

Please sign in to comment.