Skip to content

Commit

Permalink
Remove 2 multiplier
Browse files Browse the repository at this point in the history
  • Loading branch information
domfournier committed Jan 7, 2025
1 parent b2ad007 commit 704289b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
10 changes: 5 additions & 5 deletions simpeg/dask/objective_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ def _calc_residual(objfct, model, field):

def _deriv(objfct, multiplier, model, fields):
if fields is not None and objfct.has_fields:
return 2 * multiplier * objfct.deriv(objfct.simulation.model, f=fields)
return multiplier * objfct.deriv(objfct.simulation.model, f=fields)
else:
return 2 * multiplier * objfct.deriv(objfct.simulation.model)
return multiplier * objfct.deriv(objfct.simulation.model)


def _deriv2(objfct, multiplier, model, v, fields):
if fields is not None and objfct.has_fields:
return 2 * multiplier * objfct.deriv2(objfct.simulation.model, v, f=fields)
return multiplier * objfct.deriv2(objfct.simulation.model, v, f=fields)
else:
return 2 * multiplier * objfct.deriv2(objfct.simulation.model, v)
return multiplier * objfct.deriv2(objfct.simulation.model, v)


def _store_model(objfct, model):
Expand Down Expand Up @@ -229,7 +229,7 @@ def fields(self, m):
workers=worker,
)
)
self._stashed_fields = f
self._stashed_fields = client.compute(f)
return f

@property
Expand Down
14 changes: 6 additions & 8 deletions simpeg/dask/potential_fields/magnetics/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,20 @@ def getJtJdiag(self, m, W=None, f=None):
W = W.diagonal()
if getattr(self, "_gtg_diagonal", None) is None:
if not self.is_amplitude_data:
diag = array.einsum(
"i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix
).compute()
diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix)
else:
ampDeriv = self.ampDeriv
J = (
ampDeriv[0, :, None] * self.Jmatrix[::3]
+ ampDeriv[1, :, None] * self.Jmatrix[1::3]
+ ampDeriv[2, :, None] * self.Jmatrix[2::3]
)
diag = array.einsum("i,ij,ij->j", W**2, J, J).compute()
self._gtg_diagonal = diag
else:
diag = self._gtg_diagonal
diag = array.einsum("i,ij,ij->j", W**2, J, J)
self._gtg_diagonal = np.asarray(diag)

return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0))
return mkvc(
(sdiag(np.sqrt(self._gtg_diagonal)) @ self.chiDeriv).power(2).sum(axis=0)
)


Sim.clean_on_model_update = []
Expand Down

0 comments on commit 704289b

Please sign in to comment.