Skip to content

Commit

Permalink
Fix for model type
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamJamieson committed Dec 15, 2023
1 parent 05faed7 commit 5e8515f
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

- Bugfix for ``model.meta.filename`` not matching the filename of the file on disk. [#295]

- Bugfix for ``meta.model_type`` not being set to match the model writing the file. [#296]

0.18.0 (2023-11-06)
===================

Expand Down
27 changes: 20 additions & 7 deletions src/roman_datamodels/datamodels/_datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,42 @@ class _DataModel(DataModel):
def __init_subclass__(cls, **kwargs):
"""Register each subclass in the __all__ for this module"""
super().__init_subclass__(**kwargs)

# Don't register private classes
if cls.__name__.startswith("_"):
return

if cls.__name__ in __all__:
raise ValueError(f"Duplicate model type {cls.__name__}")

__all__.append(cls.__name__)


class MosaicModel(_DataModel):
class _RomanDataModel(_DataModel):
def __init__(self, init=None, **kwargs):
super().__init__(init, **kwargs)

if init is not None:
self.meta.model_type = self.__class__.__name__


class MosaicModel(_RomanDataModel):
_node_type = stnode.WfiMosaic


class ImageModel(_DataModel):
class ImageModel(_RomanDataModel):
_node_type = stnode.WfiImage


class ScienceRawModel(_DataModel):
class ScienceRawModel(_RomanDataModel):
_node_type = stnode.WfiScienceRaw


class MsosStackModel(_DataModel):
class MsosStackModel(_RomanDataModel):
_node_type = stnode.MsosStack


class RampModel(_DataModel):
class RampModel(_RomanDataModel):
_node_type = stnode.Ramp

@classmethod
Expand Down Expand Up @@ -86,7 +99,7 @@ def from_science_raw(cls, model):
raise ValueError("Input model must be a ScienceRawModel or RampModel")


class RampFitOutputModel(_DataModel):
class RampFitOutputModel(_RomanDataModel):
_node_type = stnode.RampFitOutput


Expand All @@ -107,7 +120,7 @@ def is_association(cls, asn_data):
return isinstance(asn_data, dict) and "asn_id" in asn_data and "asn_pool" in asn_data


class GuidewindowModel(_DataModel):
class GuidewindowModel(_RomanDataModel):
_node_type = stnode.Guidewindow


Expand Down
4 changes: 3 additions & 1 deletion tests/test_maker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from astropy.time import Time

from roman_datamodels import datamodels, maker_utils, stnode
from roman_datamodels.datamodels._datamodels import _RomanDataModel
from roman_datamodels.maker_utils import _ref_files as ref_files
from roman_datamodels.testing import assert_node_equal

Expand Down Expand Up @@ -109,7 +110,8 @@ def test_datamodel_maker(model_class):
assert isinstance(model, model_class)
model.validate()

assert model.meta.model_type == model_class.__name__
if issubclass(model_class, _RomanDataModel):
assert model.meta.model_type == model_class.__name__


@pytest.mark.parametrize("node_class", [node for node in datamodels.MODEL_REGISTRY])
Expand Down
8 changes: 8 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,14 @@ def test_ramp_from_science_raw():
if isinstance(ramp_value, np.ndarray):
assert_array_equal(ramp_value, raw_value.astype(ramp_value.dtype))

elif key == "meta":
for meta_key in ramp_value:
if meta_key == "model_type":
ramp_value[meta_key] = ramp.__class__.__name__
raw_value[meta_key] = raw.__class__.__name__
continue
assert_node_equal(ramp_value[meta_key], raw_value[meta_key])

elif isinstance(ramp_value, stnode.DNode):
assert_node_equal(ramp_value, raw_value)

Expand Down

0 comments on commit 5e8515f

Please sign in to comment.