diff --git a/simpeg/dask/objective_function.py b/simpeg/dask/objective_function.py index c501de42df..47ecf0326f 100644 --- a/simpeg/dask/objective_function.py +++ b/simpeg/dask/objective_function.py @@ -24,16 +24,16 @@ def _calc_residual(objfct, model, field): def _deriv(objfct, multiplier, model, fields): if fields is not None and objfct.has_fields: - return 2 * multiplier * objfct.deriv(objfct.simulation.model, f=fields) + return multiplier * objfct.deriv(objfct.simulation.model, f=fields) else: - return 2 * multiplier * objfct.deriv(objfct.simulation.model) + return multiplier * objfct.deriv(objfct.simulation.model) def _deriv2(objfct, multiplier, model, v, fields): if fields is not None and objfct.has_fields: - return 2 * multiplier * objfct.deriv2(objfct.simulation.model, v, f=fields) + return multiplier * objfct.deriv2(objfct.simulation.model, v, f=fields) else: - return 2 * multiplier * objfct.deriv2(objfct.simulation.model, v) + return multiplier * objfct.deriv2(objfct.simulation.model, v) def _store_model(objfct, model): @@ -229,7 +229,7 @@ def fields(self, m): workers=worker, ) ) - self._stashed_fields = f + self._stashed_fields = client.compute(f) return f @property diff --git a/simpeg/dask/potential_fields/magnetics/simulation.py b/simpeg/dask/potential_fields/magnetics/simulation.py index e19ca3f4d5..09c558f38d 100644 --- a/simpeg/dask/potential_fields/magnetics/simulation.py +++ b/simpeg/dask/potential_fields/magnetics/simulation.py @@ -18,9 +18,7 @@ def getJtJdiag(self, m, W=None, f=None): W = W.diagonal() if getattr(self, "_gtg_diagonal", None) is None: if not self.is_amplitude_data: - diag = array.einsum( - "i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix - ).compute() + diag = array.einsum("i,ij,ij->j", W**2, self.Jmatrix, self.Jmatrix) else: ampDeriv = self.ampDeriv J = ( @@ -28,12 +26,12 @@ def getJtJdiag(self, m, W=None, f=None): + ampDeriv[1, :, None] * self.Jmatrix[1::3] + ampDeriv[2, :, None] * self.Jmatrix[2::3] ) - diag = array.einsum("i,ij,ij->j", W**2, J, J).compute() - self._gtg_diagonal = diag - else: - diag = self._gtg_diagonal + diag = array.einsum("i,ij,ij->j", W**2, J, J) + self._gtg_diagonal = np.asarray(diag) - return mkvc((sdiag(np.sqrt(diag)) @ self.chiDeriv).power(2).sum(axis=0)) + return mkvc( + (sdiag(np.sqrt(self._gtg_diagonal)) @ self.chiDeriv).power(2).sum(axis=0) + ) Sim.clean_on_model_update = []