Skip to content

Commit

Permalink
Update ip
Browse files Browse the repository at this point in the history
  • Loading branch information
domfournier committed Jan 12, 2025
1 parent f953297 commit 96c3a5f
Showing 1 changed file with 2 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,13 @@ def getJtJdiag(self, m, W=None):
"""
self.model = m
if getattr(self, "_jtjdiag", None) is None:
if isinstance(self.Jmatrix, Future):
self.Jmatrix # Wait to finish

if W is None:
W = self._scale * np.ones(self.nD)
else:
W = (self._scale * W.diagonal()) ** 2.0

diag = da.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix)

if isinstance(diag, da.Array):
diag = np.asarray(diag.compute())
diag = np.einsum("i,ij,ij->j", W, self.Jmatrix, self.Jmatrix)

self._jtjdiag = diag

Expand Down Expand Up @@ -121,7 +116,7 @@ def Jtvec(self, m, v, f=None):
if isinstance(self.Jmatrix, Future):
self.Jmatrix # Wait to finish

return da.dot(v * self._scale, self.Jmatrix).astype(np.float32)
return da.dot((v * self._scale).astype(np.float32), self.Jmatrix).astype(np.float32)


Sim.compute_J = compute_J
Expand Down

0 comments on commit 96c3a5f

Please sign in to comment.