Skip to content

Commit

Permalink
Fix for mag. Best so far
Browse files Browse the repository at this point in the history
  • Loading branch information
domfournier committed Jan 8, 2025
1 parent 704289b commit 3cf2f4d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 38 deletions.
11 changes: 7 additions & 4 deletions simpeg/dask/objective_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def deriv(self, m, f=None):
_deriv, objfct, multiplier, m_future, field, workers=worker
)
)

return _reduce(client, add, derivs)
derivs = _reduce(client, add, derivs)
return derivs

def deriv2(self, m, v=None, f=None):
"""
Expand Down Expand Up @@ -166,7 +166,9 @@ def deriv2(self, m, v=None, f=None):
)
)

return _reduce(client, add, derivs)
derivs = _reduce(client, add, derivs)

return derivs

def get_dpred(self, m, f=None):
self.model = m
Expand Down Expand Up @@ -229,7 +231,7 @@ def fields(self, m):
workers=worker,
)
)
self._stashed_fields = client.compute(f)
self._stashed_fields = f
return f

@property
Expand Down Expand Up @@ -259,6 +261,7 @@ def model(self, value):
)
)
self.client.gather(futures) # blocking call to ensure all models were stored
self._model = value

@property
def objfcts(self):
Expand Down
33 changes: 1 addition & 32 deletions simpeg/dask/potential_fields/magnetics/simulation.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,6 @@
import numpy as np
from dask import array
from ....potential_fields.magnetics import Simulation3DIntegral as Sim
from ..base import G
from ....utils import sdiag, mkvc


def getJtJdiag(self, m, W=None, f=None):
"""
Return the diagonal of JtJ
"""

self.model = m

if W is None:
W = np.ones(self.nD)
else:
W = W.diagonal()
if getattr(self, "_gtg_diagonal", None) is None:
if not self.is_amplitude_data:
diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix)
else:
ampDeriv = self.ampDeriv
J = (
ampDeriv[0, :, None] * self.Jmatrix[::3]
+ ampDeriv[1, :, None] * self.Jmatrix[1::3]
+ ampDeriv[2, :, None] * self.Jmatrix[2::3]
)
diag = array.einsum("i,ij,ij->j", W**2, J, J)
self._gtg_diagonal = np.asarray(diag)

return mkvc(
(sdiag(np.sqrt(self._gtg_diagonal)) @ self.chiDeriv).power(2).sum(axis=0)
)
from ...simulation import getJtJdiag


Sim.clean_on_model_update = []
Expand Down
3 changes: 1 addition & 2 deletions simpeg/potential_fields/magnetics/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,7 @@ def getJtJdiag(self, m, W=None, f=None):
if getattr(self, "_gtg_diagonal", None) is None:
diag = np.zeros(self.Jmatrix.shape[1])
if not self.is_amplitude_data:
for i in range(len(W)):
diag += W[i] * (self.Jmatrix[i] * self.Jmatrix[i])
diag = np.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix)
else:
ampDeriv = self.ampDeriv
Gx = self.Jmatrix[::3]
Expand Down

0 comments on commit 3cf2f4d

Please sign in to comment.