From 14e7d2b22818241ab6a6840123061738678ef706 Mon Sep 17 00:00:00 2001 From: domfournier Date: Fri, 27 Dec 2024 21:37:11 -0800 Subject: [PATCH] Comment out Explicit version. Clean ups --- simpeg/dask/simulation.py | 1 - simpeg/data_misfit.py | 12 +-- simpeg/meta/dask_sim.py | 176 ++++++++++++++++++++------------------ 3 files changed, 92 insertions(+), 97 deletions(-) diff --git a/simpeg/dask/simulation.py b/simpeg/dask/simulation.py index 535bb8d38c..ca915ed62e 100644 --- a/simpeg/dask/simulation.py +++ b/simpeg/dask/simulation.py @@ -115,7 +115,6 @@ def Jmatrix(self): """ if getattr(self, "_Jmatrix", None) is None: self._Jmatrix = self.compute_J(self.model) - self._stashed_fields = None return self._Jmatrix diff --git a/simpeg/data_misfit.py b/simpeg/data_misfit.py index 05dabcbb41..ef8273b36f 100644 --- a/simpeg/data_misfit.py +++ b/simpeg/data_misfit.py @@ -1,5 +1,5 @@ import numpy as np -from .utils import Counter, mkvc, sdiag, timeIt, Identity, validate_type +from .utils import Counter, sdiag, timeIt, Identity, validate_type from .data import Data from .simulation import BaseSimulation from .objective_function import L2ObjectiveFunction @@ -359,16 +359,6 @@ def getJtJdiag(self, m): + "Cannot form the sensitivity explicitly" ) - # mapping_deriv = self.model_map.deriv(m) - # - # if self.model_map is not None: - # m = mapping_deriv @ m - jtjdiag = self.simulation.getJtJdiag(m, W=self.W) - # if self.model_map is not None: - # jtjdiag = mkvc( - # (sdiag(np.sqrt(jtjdiag)) @ mapping_deriv).power(2).sum(axis=0) - # ) - return jtjdiag diff --git a/simpeg/meta/dask_sim.py b/simpeg/meta/dask_sim.py index cbb1582d5e..a30bdc4d71 100644 --- a/simpeg/meta/dask_sim.py +++ b/simpeg/meta/dask_sim.py @@ -158,6 +158,8 @@ class DaskMetaSimulation(MetaSimulation): The dask client to use for communication. """ + clean_on_model_update = ["_jtjdiag", "_stashed_fields"] + def __init__(self, simulations, mappings, client): self._client = validate_type("client", client, Client, cast=False) self._concrete_simulations = None @@ -324,6 +326,8 @@ def fields(self, m): self.model = m client = self.client m_future = self._m_as_future + if getattr(self, "_stashed_fields", None) is not None: + return self._stashed_fields # The above should pass the model to all the internal simulations. f = [] for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): @@ -337,6 +341,7 @@ def fields(self, m): workers=worker, ) ) + self._stashed_fields = f return f def dpred(self, m=None, f=None): @@ -447,91 +452,92 @@ def getJtJdiag(self, m, W=None, f=None): return self._jtjdiag -def _compute_j(sim, model): - sim.model = model - jmatrix = getattr(sim, "_Jmatrix", None) - - if jmatrix is None: - jmatrix = sim.compute_J(model) - - return jmatrix - - -def set_jmatrix(sim, jmatrix): - sim._Jmatrix = jmatrix - return sim - - -class DaskMetaSimulationExplicit(DaskMetaSimulation): - clean_on_model_update = ["_Jmatrix", "_stashed_fields"] - - def fields(self, m): - self.model = m - - if getattr(self, "_stashed_fields", None) is not None: - return self._stashed_fields - - client = self.client - m_future = self._m_as_future - # The above should pass the model to all the internal simulations. - f = [] - simulations = [] - for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): - # jmatrix = client.submit( - # _compute_j, - # sim, - # m_future, - # workers=worker, - # ) - # sim = client.submit(set_jmatrix, sim, jmatrix, workers=worker) - f.append( - client.submit( - _calc_fields, - mapping, - sim, - m_future, - self._repeat_sim, - workers=worker, - ) - ) - simulations.append(sim) - - self._stashed_fields = f - self.simulations = simulations - return f - - def getJtJdiag(self, m, W=None, f=None): - self.model = m - m_future = self._m_as_future - if getattr(self, "_jtjdiag", None) is None: - if W is None: - W = np.ones(self.survey.nD) - else: - W = W.diagonal() - jtj_diag = [] - client = self.client - if f is None: - f = self.fields(m) - for i, (mapping, sim, worker, field) in enumerate( - zip(self.mappings, self.simulations, self._workers, f) - ): - sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]] - - jtj_diag.append( - client.submit( - _get_jtj_diag, - mapping, - sim, - m_future, - field, - sim_w, - self._repeat_sim, - workers=worker, - ) - ) - self._jtjdiag = _reduce(client, add, jtj_diag) - - return self._jtjdiag +# +# def _compute_j(sim, model): +# sim.model = model +# jmatrix = getattr(sim, "_Jmatrix", None) +# +# if jmatrix is None: +# jmatrix = sim.compute_J(model) +# +# return jmatrix +# +# +# def set_jmatrix(sim, jmatrix): +# sim._Jmatrix = jmatrix +# return sim + + +# class DaskMetaSimulationExplicit(DaskMetaSimulation): +# clean_on_model_update = ["_Jmatrix", "_stashed_fields"] +# +# def fields(self, m): +# self.model = m +# +# if getattr(self, "_stashed_fields", None) is not None: +# return self._stashed_fields +# +# client = self.client +# m_future = self._m_as_future +# # The above should pass the model to all the internal simulations. +# f = [] +# simulations = [] +# for mapping, sim, worker in zip(self.mappings, self.simulations, self._workers): +# # jmatrix = client.submit( +# # _compute_j, +# # sim, +# # m_future, +# # workers=worker, +# # ) +# # sim = client.submit(set_jmatrix, sim, jmatrix, workers=worker) +# f.append( +# client.submit( +# _calc_fields, +# mapping, +# sim, +# m_future, +# self._repeat_sim, +# workers=worker, +# ) +# ) +# simulations.append(sim) +# +# self._stashed_fields = f +# # self.simulations = simulations +# return f +# +# def getJtJdiag(self, m, W=None, f=None): +# self.model = m +# m_future = self._m_as_future +# if getattr(self, "_jtjdiag", None) is None: +# if W is None: +# W = np.ones(self.survey.nD) +# else: +# W = W.diagonal() +# jtj_diag = [] +# client = self.client +# if f is None: +# f = self.fields(m) +# for i, (mapping, sim, worker, field) in enumerate( +# zip(self.mappings, self.simulations, self._workers, f) +# ): +# sim_w = W[self._data_offsets[i] : self._data_offsets[i + 1]] +# +# jtj_diag.append( +# client.submit( +# _get_jtj_diag, +# mapping, +# sim, +# m_future, +# field, +# sim_w, +# self._repeat_sim, +# workers=worker, +# ) +# ) +# self._jtjdiag = _reduce(client, add, jtj_diag) +# +# return self._jtjdiag class DaskSumMetaSimulation(DaskMetaSimulation, SumMetaSimulation):