diff --git a/simpeg/directives/directives.py b/simpeg/directives/directives.py index f88f5517f2..5cdaba6152 100644 --- a/simpeg/directives/directives.py +++ b/simpeg/directives/directives.py @@ -2987,23 +2987,29 @@ class SaveIterationsGeoH5(InversionDirective): Saves inversion results to a geoh5 file """ - def __init__(self, h5_object, **kwargs): + def __init__( + self, h5_object, dmisfit=None, attribute_type: str = "model", **kwargs + ): self.data_type = {} self._association = None - self.attribute_type = "model" + self.attribute_type = attribute_type self._label = None self.channels = [""] self.components = [""] - self._h5_object = None - self._workspace = None self._transforms: list = [] self.save_objective_function = False self.sorting = None self._reshape = None self.h5_object = h5_object self._joint_index = None + + if attribute_type == "sensitivities" and dmisfit is None: + raise ValueError( + "To save sensitivities, the data misfit object must be provided." + ) + super().__init__( - inversion=None, dmisfit=None, reg=None, verbose=False, **kwargs + inversion=None, dmisfit=dmisfit, reg=None, verbose=False, **kwargs ) def initialize(self): @@ -3085,9 +3091,10 @@ def get_values(self, values: list[np.ndarray] | None): prop = self.stack_channels(dpred) elif self.attribute_type == "sensitivities": - for directive in self.inversion.directiveList.dList: - if isinstance(directive, directives.UpdateSensitivityWeights): - prop = self.reshape(np.sum(directive.JtJdiag, axis=0) ** 0.5) + + prop = np.zeros_like(self.invProb.model) + for fun in self.dmisfit.objfcts: + prop += fun.getJtJdiag(self.invProb.model) return prop