diff --git a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py index beee5f7338..4fbf0d46a5 100644 --- a/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py +++ b/simpeg/dask/electromagnetics/static/induced_polarization/simulation.py @@ -76,18 +76,13 @@ def getJtJdiag(self, m, W=None): """ self.model = m if getattr(self, "_jtjdiag", None) is None: - if isinstance(self.Jmatrix, Future): - self.Jmatrix # Wait to finish if W is None: W = self._scale * np.ones(self.nD) else: W = (self._scale * W.diagonal()) ** 2.0 - diag = da.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) - - if isinstance(diag, da.Array): - diag = np.asarray(diag.compute()) + diag = np.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix) self._jtjdiag = diag @@ -121,7 +116,7 @@ def Jtvec(self, m, v, f=None): if isinstance(self.Jmatrix, Future): self.Jmatrix # Wait to finish - return da.dot(v * self._scale, self.Jmatrix).astype(np.float32) + return da.dot((v * self._scale).astype(np.float32), self.Jmatrix).astype(np.float32) Sim.compute_J = compute_J