Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Geopy 1408 #43

Merged
merged 3 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions SimPEG/dask/electromagnetics/time_domain/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from SimPEG.dask.simulation import dask_Jvec, dask_Jtvec, dask_getJtJdiag
from SimPEG.dask.utils import get_parallel_blocks
from SimPEG.utils import mkvc
import zarr
from time import time
from tqdm import tqdm
Expand Down Expand Up @@ -215,6 +216,10 @@ def dask_dpred(self, m=None, f=None, compute_J=False):
rows = []
receiver_projection = self.survey.source_list[0].receiver_list[0].projField
fields_array = f[:, receiver_projection, :]

if len(self.survey.source_list) == 1:
fields_array = fields_array[:, np.newaxis, :]

all_receivers = []

for ind, src in enumerate(self.survey.source_list):
Expand Down Expand Up @@ -275,15 +280,16 @@ def delayed_block_deriv(
j_update += sp.csr_matrix((arrays[0].shape[0], shape), dtype=np.float32)
continue

projection = sp.kron(timeP[:, time_index], spatialP)
projection = sp.kron(timeP[:, time_index], spatialP, format="csr")
cur = derivative_fun(
time_index,
source,
None,
projection.T,
adjoint=True,
)
time_derivs.append(cur[0])

time_derivs.append(cur[0][:, arrays[0]])

if not isinstance(cur[1], Zero):
j_update += cur[1].T
Expand Down Expand Up @@ -344,15 +350,15 @@ def compute_field_derivs(simulation, fields, blocks, Jmatrix, fields_shape):

@delayed
def deriv_block(
s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, sub_ind, field_derivs, tInd
s_id, r_id, b_id, ATinv_df_duT_v, Asubdiag, local_ind, field_derivs, tInd
):
if (s_id, r_id, b_id) not in ATinv_df_duT_v:
# last timestep (first to be solved)
stacked_block = field_derivs.toarray()[:, sub_ind]
stacked_block = field_derivs.toarray()[:, local_ind]

else:
stacked_block = np.asarray(
field_derivs[:, sub_ind]
field_derivs[:, local_ind]
- Asubdiag.T * ATinv_df_duT_v[(s_id, r_id, b_id)][:, local_ind]
)

Expand All @@ -367,7 +373,8 @@ def update_deriv_blocks(address, indices, derivatives, solve, shape):

if address in indices:
columns, local_ind = indices[address]
deriv_array[:, local_ind] = solve[:, columns]
if solve is not None:
deriv_array[:, local_ind] = solve[:, columns]

derivatives[address] = deriv_array

Expand Down Expand Up @@ -397,29 +404,26 @@ def get_field_deriv_block(
):
# Cut out early data
time_check = np.kron(time_mask, np.ones(shape, dtype=bool))[rx_ind]
sub_ind = rx_ind[time_check]
local_ind = np.arange(rx_ind.shape[0])[time_check]

if len(sub_ind) < 1:
if len(local_ind) < 1:
continue

indices[(s_id, r_id, b_id)] = (
np.arange(count, count + len(sub_ind)),
np.arange(count, count + len(local_ind)),
local_ind,
)
count += len(sub_ind)
count += len(local_ind)
deriv_comp = deriv_block(
s_id,
r_id,
b_id,
ATinv_df_duT_v,
Asubdiag,
local_ind,
sub_ind,
field_deriv,
tInd,
)

stacked_blocks.append(
array.from_delayed(
deriv_comp,
Expand Down Expand Up @@ -469,7 +473,11 @@ def compute_rows(
local_ind = np.arange(len(ind_array[0]))[time_check]

if len(local_ind) < 1:
return
row_block = np.zeros(
(len(ind_array[1]), simulation.model.size), dtype=np.float32
)
rows.append(row_block)
continue

field_derivs = ATinv_df_duT_v[address]
dAsubdiagT_dm_v = simulation.getAsubdiagDeriv(
Expand Down Expand Up @@ -530,6 +538,10 @@ def compute_J(self, f=None, Ainv=None):
self.survey.source_list, self.model.shape[0], self.max_chunk_size
)
fields_array = f[:, ftype, :]

if len(self.survey.source_list) == 1:
fields_array = fields_array[:, np.newaxis, :]

times_field_derivs, Jmatrix = compute_field_derivs(
self, f, blocks, Jmatrix, fields_array.shape
)
Expand All @@ -540,6 +552,9 @@ def compute_J(self, f=None, Ainv=None):
j_row_updates = []
time_mask = data_times > simulation_times[tInd]

if not np.any(time_mask):
continue

tc = time()
for block, field_deriv in zip(blocks, times_field_derivs[tInd + 1]):
ATinv_df_duT_v = get_field_deriv_block(
Expand Down
2 changes: 1 addition & 1 deletion SimPEG/dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_parallel_blocks(source_list: list, m_size: int, max_chunk_size: int) ->
if (row_count + chunk_size) > (data_block_size * cpu_count()):
row_count = 0
block_count += 1
blocks.append = []
blocks.append([])

blocks[block_count].append(
(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

[tool.poetry]
name = "Mira-SimPEG"
version = "0.19.0.dev7"
version = "0.19.0.dev8"
license = "MIT"
description = "Mira Geoscience fork of SimPEG: Simulation and Parameter Estimation in Geophysics"

Expand Down
Loading