Skip to content

Commit

Permalink
RCAL-977 - Version datamodels (#445)
Browse files Browse the repository at this point in the history
Co-authored-by: Eddie Schlafly <[email protected]>
  • Loading branch information
braingram and schlafly authored Jan 27, 2025
1 parent ffd56b1 commit e2039db
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 181 deletions.
1 change: 1 addition & 0 deletions changes/445.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Start versioning files by allowing Node instances to use multiple versions of tags.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ dependencies = [
"gwcs >=0.19.0",
"numpy >=1.24",
"astropy >=5.3.0",
"rad >=0.23.0, <0.24.0",
# "rad @ git+https://github.com/spacetelescope/rad.git",
# "rad >=0.23.0, <0.24.0",
"rad @ git+https://github.com/spacetelescope/rad.git",
"asdf-standard >=1.1.0",
]
dynamic = ["version"]
Expand Down
59 changes: 26 additions & 33 deletions src/roman_datamodels/stnode/_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
from asdf.extension import Converter, ManifestExtension
from astropy.time import Time

from ._registry import LIST_NODE_CLASSES_BY_TAG, NODE_CONVERTERS, OBJECT_NODE_CLASSES_BY_TAG, SCALAR_NODE_CLASSES_BY_TAG
from ._registry import (
LIST_NODE_CLASSES_BY_PATTERN,
NODE_CLASSES_BY_TAG,
NODE_CONVERTERS,
OBJECT_NODE_CLASSES_BY_PATTERN,
SCALAR_NODE_CLASSES_BY_PATTERN,
)
from ._stnode import _MANIFESTS

__all__ = [
"NODE_EXTENSIONS",
Expand Down Expand Up @@ -34,6 +41,14 @@ def __init_subclass__(cls, **kwargs) -> None:

NODE_CONVERTERS[cls.__name__] = cls()

def select_tag(self, obj, tags, ctx):
return obj.tag

def from_yaml_tree(self, node, tag, ctx):
obj = NODE_CLASSES_BY_TAG[tag](node)
obj._read_tag = tag
return obj


class TaggedObjectNodeConverter(_RomanConverter):
"""
Expand All @@ -42,21 +57,15 @@ class TaggedObjectNodeConverter(_RomanConverter):

@property
def tags(self):
return list(OBJECT_NODE_CLASSES_BY_TAG.keys())
return list(OBJECT_NODE_CLASSES_BY_PATTERN.keys())

@property
def types(self):
return list(OBJECT_NODE_CLASSES_BY_TAG.values())

def select_tag(self, obj, tags, ctx):
return obj.tag
return list(OBJECT_NODE_CLASSES_BY_PATTERN.values())

def to_yaml_tree(self, obj, tag, ctx):
return dict(obj._data)

def from_yaml_tree(self, node, tag, ctx):
return OBJECT_NODE_CLASSES_BY_TAG[tag](node)


class TaggedListNodeConverter(_RomanConverter):
"""
Expand All @@ -65,21 +74,15 @@ class TaggedListNodeConverter(_RomanConverter):

@property
def tags(self):
return list(LIST_NODE_CLASSES_BY_TAG.keys())
return list(LIST_NODE_CLASSES_BY_PATTERN.keys())

@property
def types(self):
return list(LIST_NODE_CLASSES_BY_TAG.values())

def select_tag(self, obj, tags, ctx):
return obj.tag
return list(LIST_NODE_CLASSES_BY_PATTERN.values())

def to_yaml_tree(self, obj, tag, ctx):
return list(obj)

def from_yaml_tree(self, node, tag, ctx):
return LIST_NODE_CLASSES_BY_TAG[tag](node)


class TaggedScalarNodeConverter(_RomanConverter):
"""
Expand All @@ -88,37 +91,27 @@ class TaggedScalarNodeConverter(_RomanConverter):

@property
def tags(self):
return list(SCALAR_NODE_CLASSES_BY_TAG.keys())
return list(SCALAR_NODE_CLASSES_BY_PATTERN.keys())

@property
def types(self):
return list(SCALAR_NODE_CLASSES_BY_TAG.values())

def select_tag(self, obj, tags, ctx):
return obj.tag
return list(SCALAR_NODE_CLASSES_BY_PATTERN.values())

def to_yaml_tree(self, obj, tag, ctx):
from ._stnode import FileDate, FpsFileDate, TvacFileDate

node = obj.__class__.__bases__[0](obj)

if tag in (FileDate._tag, FpsFileDate._tag, TvacFileDate._tag):
if "file_date" in tag:
converter = ctx.extension_manager.get_converter_for_type(type(node))
node = converter.to_yaml_tree(node, tag, ctx)

return node

def from_yaml_tree(self, node, tag, ctx):
from ._stnode import FileDate, FpsFileDate, TvacFileDate

if tag in (FileDate._tag, FpsFileDate._tag, TvacFileDate._tag):
if "file_date" in tag:
converter = ctx.extension_manager.get_converter_for_type(Time)
node = converter.from_yaml_tree(node, tag, ctx)

return SCALAR_NODE_CLASSES_BY_TAG[tag](node)
return super().from_yaml_tree(node, tag, ctx)


# Create the ASDF extension for the STNode classes.
NODE_EXTENSIONS = [
ManifestExtension.from_uri("asdf://stsci.edu/datamodels/roman/manifests/datamodels-1.0", converters=NODE_CONVERTERS.values()),
]
NODE_EXTENSIONS = [ManifestExtension.from_uri(manifest["id"], converters=NODE_CONVERTERS.values()) for manifest in _MANIFESTS]
141 changes: 51 additions & 90 deletions src/roman_datamodels/stnode/_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,59 +3,23 @@
These are used to dynamically create classes from the RAD manifest.
"""

import importlib.resources

import yaml
from astropy.time import Time
from rad import resources

from . import _mixins
from ._tagged import TaggedListNode, TaggedObjectNode, TaggedScalarNode, name_from_tag_uri

__all__ = ["stnode_factory"]

# Map of scalar types in the schemas to the python types
SCALAR_TYPE_MAP = {
"string": str,
"http://stsci.edu/schemas/asdf/time/time-1.1.0": Time,
# Map of scalar types by pattern (str is default)
_SCALAR_TYPE_BY_PATTERN = {
"asdf://stsci.edu/datamodels/roman/tags/file_date-*": Time,
"asdf://stsci.edu/datamodels/roman/tags/fps/file_date-*": Time,
"asdf://stsci.edu/datamodels/roman/tags/tvac/file_date-*": Time,
}
# Map of node types by pattern (TaggedObjectNode is default)
_NODE_TYPE_BY_PATTERN = {
"asdf://stsci.edu/datamodels/roman/tags/cal_logs-*": TaggedListNode,
}

BASE_SCHEMA_PATH = importlib.resources.files(resources) / "schemas"


def load_schema_from_uri(schema_uri):
"""
Load the actual schema from the rad resources directly (outside ASDF)
Outside ASDF because this has to occur before the ASDF extensions are
registered.
Parameters
----------
schema_uri : str
The schema_uri found in the RAD manifest
Returns
-------
yaml library dictionary from the schema
"""
filename = f"{schema_uri.split('/')[-1]}.yaml"

if "reference_files" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "reference_files" / filename
elif "/fps/tagged_scalars" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "fps/tagged_scalars" / filename
elif "/fps/" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "fps" / filename
elif "/tvac/tagged_scalars" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "tvac/tagged_scalars" / filename
elif "/tvac/" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "tvac" / filename
elif "tagged_scalars" in schema_uri:
schema_path = BASE_SCHEMA_PATH / "tagged_scalars" / filename
else:
schema_path = BASE_SCHEMA_PATH / filename

return yaml.safe_load(schema_path.read_bytes())


def class_name_from_tag_uri(tag_uri):
Expand All @@ -79,94 +43,83 @@ def class_name_from_tag_uri(tag_uri):
return class_name


def docstring_from_tag(tag):
def docstring_from_tag(tag_def):
"""
Read the docstring (if it exists) from the RAD manifest and generate a docstring
for the dynamically generated class.
Parameters
----------
tag: dict
tag_def: dict
A tag entry from the RAD manifest
Returns
-------
A docstring for the class based on the tag
"""
docstring = f"{tag['description']}\n\n" if "description" in tag else ""
docstring = f"{tag_def['description']}\n\n" if "description" in tag_def else ""

return docstring + f"Class generated from tag '{tag['tag_uri']}'"
return docstring + f"Class generated from tag '{tag_def['tag_uri']}'"


def scalar_factory(tag):
def scalar_factory(pattern, tag_def):
"""
Factory to create a TaggedScalarNode class from a tag
Parameters
----------
tag: dict
pattern: str
A tag pattern/wildcard
tag_def: dict
A tag entry from the RAD manifest
Returns
-------
A dynamically generated TaggedScalarNode subclass
"""
class_name = class_name_from_tag_uri(tag["tag_uri"])
schema = load_schema_from_uri(tag["schema_uri"])
class_name = class_name_from_tag_uri(pattern)

# TaggedScalarNode subclasses are really subclasses of the type of the scalar,
# with the TaggedScalarNode as a mixin. This is because the TaggedScalarNode
# is supposed to be the scalar, but it needs to be serializable under a specific
# ASDF tag.
# SCALAR_TYPE_MAP will need to be updated as new wrappers of scalar types are added
# _SCALAR_TYPE_BY_PATTERN will need to be updated as new wrappers of scalar types are added
# to the RAD manifest.
if "type" in schema:
type_ = schema["type"]
elif "allOf" in schema:
type_ = schema["allOf"][0]["$ref"]
else:
raise RuntimeError(f"Unknown schema type: {schema}")
# assume everything is a string if not otherwise defined
type_ = _SCALAR_TYPE_BY_PATTERN.get(pattern, str)

return type(
class_name,
(SCALAR_TYPE_MAP[type_], TaggedScalarNode),
{"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring_from_tag(tag)},
(type_, TaggedScalarNode),
{
"_pattern": pattern,
"_default_tag": tag_def["tag_uri"],
"__module__": "roman_datamodels.stnode",
"__doc__": docstring_from_tag(tag_def),
},
)


def node_factory(tag):
def node_factory(pattern, tag_def):
"""
Factory to create a TaggedObjectNode or TaggedListNode class from a tag
Parameters
----------
tag: dict
pattern: str
A tag pattern/wildcard
tag_def: dict
A tag entry from the RAD manifest
Returns
-------
A dynamically generated TaggedObjectNode or TaggedListNode subclass
"""
class_name = class_name_from_tag_uri(tag["tag_uri"])
schema = load_schema_from_uri(tag["schema_uri"])

if "type" in schema:
# Determine if the class is a TaggedObjectNode or TaggedListNode based on the
# type defined in the schema:
# - TaggedObjectNode if type is "object"
# - TaggedListNode if type is "array" (array in jsonschema represents Python list)
if schema["type"] == "object":
class_type = TaggedObjectNode
elif schema["type"] == "array":
class_type = TaggedListNode
else:
raise RuntimeError(f"Unknown schema type: {schema['type']}")
# Use of allOf in the schema indicates that the class is a TaggedObjectNode
# which is "extending" another class.
elif "allOf" in schema:
class_type = TaggedObjectNode
else:
raise RuntimeError(f"Unknown schema type for: {tag['schema_uri']}")
class_name = class_name_from_tag_uri(pattern)

class_type = _NODE_TYPE_BY_PATTERN.get(pattern, TaggedObjectNode)

# In special cases one may need to add additional features to a tagged node class.
# This is done by creating a mixin class with the name <ClassName>Mixin in _mixins.py
Expand All @@ -179,17 +132,25 @@ def node_factory(tag):
return type(
class_name,
class_type,
{"_tag": tag["tag_uri"], "__module__": "roman_datamodels.stnode", "__doc__": docstring_from_tag(tag)},
{
"_pattern": pattern,
"_default_tag": tag_def["tag_uri"],
"__module__": "roman_datamodels.stnode",
"__doc__": docstring_from_tag(tag_def),
},
)


def stnode_factory(tag):
def stnode_factory(pattern, tag_def):
"""
Construct a tagged STNode class from a tag
Parameters
----------
tag: dict
pattern: str
A tag pattern/wildcard
tag_def: dict
A tag entry from the RAD manifest
Returns
Expand All @@ -198,7 +159,7 @@ def stnode_factory(tag):
"""
# TaggedScalarNodes are a special case because they are not a subclass of a
# _node class, but rather a subclass of the type of the scalar.
if "tagged_scalar" in tag["schema_uri"]:
return scalar_factory(tag)
if "tagged_scalar" in tag_def["schema_uri"]:
return scalar_factory(pattern, tag_def)
else:
return node_factory(tag)
return node_factory(pattern, tag_def)
4 changes: 2 additions & 2 deletions src/roman_datamodels/stnode/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class DNode(MutableMapping):
Base class describing all "object" (dict-like) data nodes for STNode classes.
"""

_tag = None
_pattern = None
_ctx = None

def __init__(self, node=None, parent=None, name=None):
Expand Down Expand Up @@ -311,7 +311,7 @@ class LNode(UserList):
Base class describing all "array" (list-like) data nodes for STNode classes.
"""

_tag = None
_pattern = None

def __init__(self, node=None):
if node is None:
Expand Down
7 changes: 4 additions & 3 deletions src/roman_datamodels/stnode/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
whenever they are generated.
"""

OBJECT_NODE_CLASSES_BY_TAG = {}
LIST_NODE_CLASSES_BY_TAG = {}
SCALAR_NODE_CLASSES_BY_TAG = {}
OBJECT_NODE_CLASSES_BY_PATTERN = {}
LIST_NODE_CLASSES_BY_PATTERN = {}
SCALAR_NODE_CLASSES_BY_PATTERN = {}
SCALAR_NODE_CLASSES_BY_KEY = {}
NODE_CONVERTERS = {}
NODE_CLASSES_BY_TAG = {}
Loading

0 comments on commit e2039db

Please sign in to comment.