Skip to content

Commit

Permalink
Revert "Skip dpred for residuals"
Browse files Browse the repository at this point in the history
This reverts commit a2727e1.
  • Loading branch information
domfournier committed Jan 17, 2025
1 parent 580bda9 commit 74777e5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion simpeg/dask/inverse_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def dask_evalFunction(self, m, return_g=True, return_H=True):
residuals = []
print("Computing residuals")
if isinstance(self.dmisfit, DaskComboMisfits):
residuals = self.dmisfit.residuals(m, self.dpred)
residuals = self.dmisfit.residuals(m)
else:
for (_, objfct), pred in zip(self.dmisfit, self.dpred):
residuals.append(objfct.W * (objfct.data.dobs - pred))
Expand Down
15 changes: 9 additions & 6 deletions simpeg/dask/objective_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ def _calc_dpred(objfct, model):
return objfct.simulation.dpred(m=objfct.simulation.model)


def _calc_residual(objfct, dpred):
return objfct.W * (objfct.data.dobs - dpred)
def _calc_residual(objfct, model):
return objfct.W * (
objfct.data.dobs - objfct.simulation.dpred(m=objfct.simulation.model)
)


def _deriv(objfct, multiplier, model):
Expand Down Expand Up @@ -278,7 +280,6 @@ def get_dpred(self, m, f=None):
client = self.client
m_future = self._m_as_future
dpred = []
print("in dpred")
for futures in self._futures:
for objfct, worker in zip(futures, self._workers):
dpred.append(
Expand Down Expand Up @@ -396,20 +397,22 @@ def objfcts(self, objfcts):
self._futures = futures
self._workers = workers

def residuals(self, m, dpreds, f=None):
def residuals(self, m, f=None):
"""
Compute the residual for the data misfit.
"""
self.model = m

client = self.client
m_future = self._m_as_future
residuals = []
for futures in self._futures:
for objfct, worker, dpred in zip(futures, self._workers, dpreds):
for objfct, worker in zip(futures, self._workers):
residuals.append(
client.submit(
_calc_residual,
objfct,
dpred,
m_future,
workers=worker,
)
)
Expand Down
2 changes: 1 addition & 1 deletion simpeg/directives/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3271,7 +3271,7 @@ def get_values(self, values: list[np.ndarray] | None):
print("Computing dpred")
dpred = self.invProb.get_dpred(self.invProb.model)
self.invProb.dpred = dpred
print("Done")

if self.joint_index is not None:
dpred = [dpred[ind] for ind in self.joint_index]

Expand Down

0 comments on commit 74777e5

Please sign in to comment.