Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/release/0.21.2.1' into release…
Browse files Browse the repository at this point in the history
…/0.21.2.1
  • Loading branch information
sebhmg committed Nov 8, 2024
2 parents 1c9a653 + c0c8e7a commit 24b6b39
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions simpeg/directives/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3575,6 +3575,7 @@ class ScaleMisfitMultipliers(InversionDirective):

def __init__(self, path: Path | None, **kwargs):
self.last_beta = None
self.chi_factors = None

if path is None:
path = Path()
Expand All @@ -3599,41 +3600,52 @@ def initialize(self):

def endIter(self):
ratio = self.invProb.beta / self.last_beta

if ratio > 1:
return

chi_factors = []
for objfct, pred in zip(self.invProb.dmisfit.objfcts, self.invProb.dpred):
residual = objfct.W * (objfct.data.dobs - pred)
phi_d = np.vdot(residual, residual)
chi_factors.append(phi_d / objfct.nD)

chi_factors = np.asarray(chi_factors)
self.chi_factors = np.asarray(chi_factors)

if np.all(self.chi_factors < 1) or ratio >= 1:
self.last_beta = self.invProb.beta
self.write_log()
return

# Normalize scaling between [ratio, 1]
scalings = (
1 - (1 - ratio) * (chi_factors.max() - chi_factors) / chi_factors.max()
1
- (1 - ratio)
* (self.chi_factors.max() - self.chi_factors)
/ self.chi_factors.max()
)

# Force the ones that overshot target
scalings[chi_factors < 1] = ratio * chi_factors[chi_factors < 1]
scalings[self.chi_factors < 1] = (
ratio # * self.chi_factors[self.chi_factors < 1]
)

# Update the scaling
self.scalings *= scalings

# Normalize total phi_d with scalings
multipliers = self.multipliers * self.scalings
self.invProb.dmisfit.multipliers = self.multipliers * self.scalings
self.last_beta = self.invProb.beta
self.write_log()

def write_log(self):
"""
Write the scaling factors to the log file.
"""
with open(self.filepath, "a", encoding="utf-8") as f:
f.write(
f"{self.opt.iter}\t"
+ "\t".join(
f"{multi:.2e} * {chi:.2e}"
for multi, chi in zip(multipliers, chi_factors)
for multi, chi in zip(
self.invProb.dmisfit.multipliers, self.chi_factors
)
)
+ "\n"
)

self.invProb.dmisfit.multipliers = multipliers.tolist()
self.last_beta = self.invProb.beta

0 comments on commit 24b6b39

Please sign in to comment.