From 5e8515fdf6decfeda4a979c2b166854658888792 Mon Sep 17 00:00:00 2001 From: William Jamieson Date: Fri, 8 Dec 2023 13:38:25 -0500 Subject: [PATCH] Fix for model type --- CHANGES.rst | 2 ++ .../datamodels/_datamodels.py | 27 ++++++++++++++----- tests/test_maker_utils.py | 4 ++- tests/test_models.py | 8 ++++++ 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 3d2b36c2..7b80c10c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,6 +6,8 @@ - Bugfix for ``model.meta.filename`` not matching the filename of the file on disk. [#295] +- Bugfix for ``meta.model_type`` not being set to match the model writing the file. [#296] + 0.18.0 (2023-11-06) =================== diff --git a/src/roman_datamodels/datamodels/_datamodels.py b/src/roman_datamodels/datamodels/_datamodels.py index 6337799c..dfcd35f8 100644 --- a/src/roman_datamodels/datamodels/_datamodels.py +++ b/src/roman_datamodels/datamodels/_datamodels.py @@ -25,29 +25,42 @@ class _DataModel(DataModel): def __init_subclass__(cls, **kwargs): """Register each subclass in the __all__ for this module""" super().__init_subclass__(**kwargs) + + # Don't register private classes + if cls.__name__.startswith("_"): + return + if cls.__name__ in __all__: raise ValueError(f"Duplicate model type {cls.__name__}") __all__.append(cls.__name__) -class MosaicModel(_DataModel): +class _RomanDataModel(_DataModel): + def __init__(self, init=None, **kwargs): + super().__init__(init, **kwargs) + + if init is not None: + self.meta.model_type = self.__class__.__name__ + + +class MosaicModel(_RomanDataModel): _node_type = stnode.WfiMosaic -class ImageModel(_DataModel): +class ImageModel(_RomanDataModel): _node_type = stnode.WfiImage -class ScienceRawModel(_DataModel): +class ScienceRawModel(_RomanDataModel): _node_type = stnode.WfiScienceRaw -class MsosStackModel(_DataModel): +class MsosStackModel(_RomanDataModel): _node_type = stnode.MsosStack -class RampModel(_DataModel): +class RampModel(_RomanDataModel): _node_type = stnode.Ramp @classmethod @@ -86,7 +99,7 @@ def from_science_raw(cls, model): raise ValueError("Input model must be a ScienceRawModel or RampModel") -class RampFitOutputModel(_DataModel): +class RampFitOutputModel(_RomanDataModel): _node_type = stnode.RampFitOutput @@ -107,7 +120,7 @@ def is_association(cls, asn_data): return isinstance(asn_data, dict) and "asn_id" in asn_data and "asn_pool" in asn_data -class GuidewindowModel(_DataModel): +class GuidewindowModel(_RomanDataModel): _node_type = stnode.Guidewindow diff --git a/tests/test_maker_utils.py b/tests/test_maker_utils.py index 2ac5b5ae..c3d75c45 100644 --- a/tests/test_maker_utils.py +++ b/tests/test_maker_utils.py @@ -7,6 +7,7 @@ from astropy.time import Time from roman_datamodels import datamodels, maker_utils, stnode +from roman_datamodels.datamodels._datamodels import _RomanDataModel from roman_datamodels.maker_utils import _ref_files as ref_files from roman_datamodels.testing import assert_node_equal @@ -109,7 +110,8 @@ def test_datamodel_maker(model_class): assert isinstance(model, model_class) model.validate() - assert model.meta.model_type == model_class.__name__ + if issubclass(model_class, _RomanDataModel): + assert model.meta.model_type == model_class.__name__ @pytest.mark.parametrize("node_class", [node for node in datamodels.MODEL_REGISTRY]) diff --git a/tests/test_models.py b/tests/test_models.py index b27661d7..3946e7dc 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -772,6 +772,14 @@ def test_ramp_from_science_raw(): if isinstance(ramp_value, np.ndarray): assert_array_equal(ramp_value, raw_value.astype(ramp_value.dtype)) + elif key == "meta": + for meta_key in ramp_value: + if meta_key == "model_type": + ramp_value[meta_key] = ramp.__class__.__name__ + raw_value[meta_key] = raw.__class__.__name__ + continue + assert_node_equal(ramp_value[meta_key], raw_value[meta_key]) + elif isinstance(ramp_value, stnode.DNode): assert_node_equal(ramp_value, raw_value)