Skip to content

Commit

Permalink
Comment out Explicit version. Clean ups
Browse files Browse the repository at this point in the history
  • Loading branch information
domfournier committed Dec 28, 2024
1 parent de9ee29 commit 14e7d2b
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 97 deletions.
1 change: 0 additions & 1 deletion simpeg/dask/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def Jmatrix(self):
"""
if getattr(self, "_Jmatrix", None) is None:
self._Jmatrix = self.compute_J(self.model)
self._stashed_fields = None

return self._Jmatrix

Expand Down
12 changes: 1 addition & 11 deletions simpeg/data_misfit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from .utils import Counter, mkvc, sdiag, timeIt, Identity, validate_type
from .utils import Counter, sdiag, timeIt, Identity, validate_type
from .data import Data
from .simulation import BaseSimulation
from .objective_function import L2ObjectiveFunction
Expand Down Expand Up @@ -359,16 +359,6 @@ def getJtJdiag(self, m):
+ "Cannot form the sensitivity explicitly"
)

# mapping_deriv = self.model_map.deriv(m)
#
# if self.model_map is not None:
# m = mapping_deriv @ m

jtjdiag = self.simulation.getJtJdiag(m, W=self.W)

# if self.model_map is not None:
# jtjdiag = mkvc(
# (sdiag(np.sqrt(jtjdiag)) @ mapping_deriv).power(2).sum(axis=0)
# )

return jtjdiag
176 changes: 91 additions & 85 deletions simpeg/meta/dask_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ class DaskMetaSimulation(MetaSimulation):
The dask client to use for communication.
"""

clean_on_model_update = ["_jtjdiag", "_stashed_fields"]

def __init__(self, simulations, mappings, client):
self._client = validate_type("client", client, Client, cast=False)
self._concrete_simulations = None
Expand Down Expand Up @@ -324,6 +326,8 @@ def fields(self, m):
self.model = m
client = self.client
m_future = self._m_as_future
if getattr(self, "_stashed_fields", None) is not None:
return self._stashed_fields
# The above should pass the model to all the internal simulations.
f = []
for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers):
Expand All @@ -337,6 +341,7 @@ def fields(self, m):
workers=worker,
)
)
self._stashed_fields = f
return f

def dpred(self, m=None, f=None):
Expand Down Expand Up @@ -447,91 +452,92 @@ def getJtJdiag(self, m, W=None, f=None):
return self._jtjdiag


def _compute_j(sim, model):
sim.model = model
jmatrix = getattr(sim, "_Jmatrix", None)

if jmatrix is None:
jmatrix = sim.compute_J(model)

return jmatrix


def set_jmatrix(sim, jmatrix):
sim._Jmatrix = jmatrix
return sim


class DaskMetaSimulationExplicit(DaskMetaSimulation):
clean_on_model_update = ["_Jmatrix", "_stashed_fields"]

def fields(self, m):
self.model = m

if getattr(self, "_stashed_fields", None) is not None:
return self._stashed_fields

client = self.client
m_future = self._m_as_future
# The above should pass the model to all the internal simulations.
f = []
simulations = []
for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers):
# jmatrix = client.submit(
# _compute_j,
# sim,
# m_future,
# workers=worker,
# )
# sim = client.submit(set_jmatrix, sim, jmatrix, workers=worker)
f.append(
client.submit(
_calc_fields,
mapping,
sim,
m_future,
self._repeat_sim,
workers=worker,
)
)
simulations.append(sim)

self._stashed_fields = f
self.simulations = simulations
return f

def getJtJdiag(self, m, W=None, f=None):
self.model = m
m_future = self._m_as_future
if getattr(self, "_jtjdiag", None) is None:
if W is None:
W = np.ones(self.survey.nD)
else:
W = W.diagonal()
jtj_diag = []
client = self.client
if f is None:
f = self.fields(m)
for i, (mapping, sim, worker, field) in enumerate(
zip(self.mappings, self.simulations, self._workers, f)
):
sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]]

jtj_diag.append(
client.submit(
_get_jtj_diag,
mapping,
sim,
m_future,
field,
sim_w,
self._repeat_sim,
workers=worker,
)
)
self._jtjdiag = _reduce(client, add, jtj_diag)

return self._jtjdiag
#
# def _compute_j(sim, model):
# sim.model = model
# jmatrix = getattr(sim, "_Jmatrix", None)
#
# if jmatrix is None:
# jmatrix = sim.compute_J(model)
#
# return jmatrix
#
#
# def set_jmatrix(sim, jmatrix):
# sim._Jmatrix = jmatrix
# return sim


# class DaskMetaSimulationExplicit(DaskMetaSimulation):
# clean_on_model_update = ["_Jmatrix", "_stashed_fields"]
#
# def fields(self, m):
# self.model = m
#
# if getattr(self, "_stashed_fields", None) is not None:
# return self._stashed_fields
#
# client = self.client
# m_future = self._m_as_future
# # The above should pass the model to all the internal simulations.
# f = []
# simulations = []
# for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers):
# # jmatrix = client.submit(
# # _compute_j,
# # sim,
# # m_future,
# # workers=worker,
# # )
# # sim = client.submit(set_jmatrix, sim, jmatrix, workers=worker)
# f.append(
# client.submit(
# _calc_fields,
# mapping,
# sim,
# m_future,
# self._repeat_sim,
# workers=worker,
# )
# )
# simulations.append(sim)
#
# self._stashed_fields = f
# # self.simulations = simulations
# return f
#
# def getJtJdiag(self, m, W=None, f=None):
# self.model = m
# m_future = self._m_as_future
# if getattr(self, "_jtjdiag", None) is None:
# if W is None:
# W = np.ones(self.survey.nD)
# else:
# W = W.diagonal()
# jtj_diag = []
# client = self.client
# if f is None:
# f = self.fields(m)
# for i, (mapping, sim, worker, field) in enumerate(
# zip(self.mappings, self.simulations, self._workers, f)
# ):
# sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]]
#
# jtj_diag.append(
# client.submit(
# _get_jtj_diag,
# mapping,
# sim,
# m_future,
# field,
# sim_w,
# self._repeat_sim,
# workers=worker,
# )
# )
# self._jtjdiag = _reduce(client, add, jtj_diag)
#
# return self._jtjdiag


class DaskSumMetaSimulation(DaskMetaSimulation, SumMetaSimulation):
Expand Down

0 comments on commit 14e7d2b

Please sign in to comment.