Skip to content

Commit

Permalink
RCAL-930 Roundtrip L3 wcsinfo especially when skycell specifications …
Browse files Browse the repository at this point in the history
…are used (#1585)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
stscieisenhamer and pre-commit-ci[bot] authored Jan 28, 2025
1 parent cb24749 commit 8422362
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 30 deletions.
1 change: 1 addition & 0 deletions changes/1585.mosaic_pipeline.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Roundtrip L3 wcsinfo especially when skycell specifications are used
103 changes: 80 additions & 23 deletions romancal/pipeline/mosaic_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from astropy import units as u
from astropy.modeling import models
from gwcs import WCS, coordinate_frames
from stcal.alignment import util as wcs_util

import romancal.datamodels.filetype as filetype
from romancal.datamodels import ModelLibrary
Expand Down Expand Up @@ -114,20 +115,21 @@ def process(self, input):

# check to see if there exists a skycell on disk if not create it
if not isfile(skycell_file_name):
# extract the wcs info from the record for generate_tan_wcs
# extract the wcs info from the record for skycell_to_wcs
log.info(
"Creating skycell image at ra: %f dec %f",
float(skycell_record["ra_center"]),
float(skycell_record["dec_center"]),
)
skycell_wcs = generate_tan_wcs(skycell_record)
skycell_wcs = skycell_to_wcs(skycell_record)
# skycell_wcs.bounding_box = bounding_box

# For resample to use an external grid we need to pass it the skycell gwcs object
# Currently we cannot do that directly so create an asdf file to read the skycell gwcs object
wcs_tree = {"wcs": skycell_wcs}
wcs_file = asdf.AsdfFile(wcs_tree)
wcs_file.write_to("skycell_wcs.asdf")

self.resample.output_wcs = "skycell_wcs.asdf"
self.resample.output_shape = (
int(skycell_record["nx"]),
Expand Down Expand Up @@ -162,41 +164,96 @@ def process(self, input):
return result


def generate_tan_wcs(skycell_record):
"""extract the wcs info from the record for generate_tan_wcs
we need the scale, ra, dec, bounding_box"""
def skycell_to_wcs(skycell_record):
"""From a skycell record, generate a GWCS
Parameters
----------
skycell_record : dict
A skycell record, or row, from the skycell patches table.
Returns
-------
wcsobj : wcs.GWCS
The GWCS object from the skycell record.
"""
wcsinfo = dict()

# The scale is given in arcseconds per pixel. Convert to degrees.
wcsinfo["pixel_scale"] = float(skycell_record["pixel_scale"]) / 3600.0

# Remaining components of the wcsinfo block
wcsinfo["ra_ref"] = float(skycell_record["ra_projection_center"])
wcsinfo["dec_ref"] = float(skycell_record["dec_projection_center"])
wcsinfo["x_ref"] = float(skycell_record["x0_projection"])
wcsinfo["y_ref"] = float(skycell_record["y0_projection"])
wcsinfo["orientat"] = float(skycell_record["orientat_projection_center"])
wcsinfo["rotation_matrix"] = None

scale = float(skycell_record["pixel_scale"])
ra_center = float(skycell_record["ra_projection_center"])
dec_center = float(skycell_record["dec_projection_center"])
shiftx = float(skycell_record["x0_projection"])
shifty = float(skycell_record["y0_projection"])
# Bounding box of the skycell. Note that the center of the pixels are at (0.5, 0.5)
bounding_box = (
(-0.5, -0.5 + skycell_record["nx"]),
(-0.5, -0.5 + skycell_record["ny"]),
)

# components of the model
# shift = models.Shift(shiftx) & models.Shift(shifty)
wcsobj = wcsinfo_to_wcs(wcsinfo, bounding_box=bounding_box)
return wcsobj


def wcsinfo_to_wcs(wcsinfo, bounding_box=None, name="wcsinfo"):
"""Create a GWCS from the L3 wcsinfo meta
Parameters
----------
wcsinfo : dict or MosaicModel.meta.wcsinfo
The L3 wcsinfo to create a GWCS from.
# select a scale for the skycell image, this will come from INS and may
# be optimized for the different survey programs
scale_x = scale
scale_y = scale
# set the pixelsscale to 0.1 arcsec/pixel
pixelscale = models.Scale(scale_x / 3600.0) & models.Scale(scale_y / 3600.0)
bounding_box : None or 4-tuple
The bounding box in detector/pixel space. Form of input is:
(x_left, x_right, y_bottom, y_top)
pixelshift = models.Shift(-1.0 * shiftx) & models.Shift(-1.0 * shifty)
name : str
Value of the `name` attribute of the GWCS object.
Returns
-------
wcs : wcs.GWCS
The GWCS object created.
"""
pixelshift = models.Shift(-wcsinfo["x_ref"], name="crpix1") & models.Shift(
-wcsinfo["y_ref"], name="crpix2"
)
pixelscale = models.Scale(wcsinfo["pixel_scale"], name="cdelt1") & models.Scale(
wcsinfo["pixel_scale"], name="cdelt2"
)
tangent_projection = models.Pix2Sky_TAN()
celestial_rotation = models.RotateNative2Celestial(ra_center, dec_center, 180.0)
det2sky = pixelshift | pixelscale | tangent_projection | celestial_rotation
celestial_rotation = models.RotateNative2Celestial(
wcsinfo["ra_ref"], wcsinfo["dec_ref"], 180.0
)

matrix = wcsinfo.get("rotation_matrix", None)
if matrix:
matrix = np.array(matrix)
else:
orientat = wcsinfo.get("orientat", 0.0)
matrix = wcs_util.calc_rotation_matrix(
np.deg2rad(orientat), v3i_yangle=0.0, vparity=1
)
matrix = np.reshape(matrix, (2, 2))
rotation = models.AffineTransformation2D(matrix, name="pc_rotation_matrix")
det2sky = (
pixelshift | rotation | pixelscale | tangent_projection | celestial_rotation
)

detector_frame = coordinate_frames.Frame2D(
name="detector", axes_names=("x", "y"), unit=(u.pix, u.pix)
)
sky_frame = coordinate_frames.CelestialFrame(
reference_frame=coordinates.ICRS(), name="icrs", unit=(u.deg, u.deg)
)
wcsobj = WCS([(detector_frame, det2sky), (sky_frame, None)])
wcsobj.bounding_box = bounding_box
wcsobj = WCS([(detector_frame, det2sky), (sky_frame, None)], name=name)

if bounding_box:
wcsobj.bounding_box = bounding_box

return wcsobj
133 changes: 133 additions & 0 deletions romancal/pipeline/tests/test_mosaic_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""Unit tests for the mosaic pipeline"""

import numpy as np

import romancal.pipeline.mosaic_pipeline as mp


def test_skycell_to_wcs():
"""Test integrity of skycell_to_wcs"""

skycell = np.void(
(
"r274dp63x31y81",
269.7783307416819,
66.04965143695566,
1781.5,
1781.5,
355.9788,
3564,
3564,
67715.5,
-110484.5,
269.6657957545588,
65.9968687812357,
269.6483032937494,
66.09523979539262,
269.89132874168854,
66.10234971630734,
269.9079118635897,
66.00394719483091,
0.1,
274.2857142857143,
63.0,
0.0,
463181,
),
dtype=[
("name", "<U20"),
("ra_center", "<f8"),
("dec_center", "<f8"),
("x_center", "<f4"),
("y_center", "<f4"),
("orientat", "<f4"),
("nx", "<i4"),
("ny", "<i4"),
("x0_projection", "<f4"),
("y0_projection", "<f4"),
("ra_corn1", "<f8"),
("dec_corn1", "<f8"),
("ra_corn2", "<f8"),
("dec_corn2", "<f8"),
("ra_corn3", "<f8"),
("dec_corn3", "<f8"),
("ra_corn4", "<f8"),
("dec_corn4", "<f8"),
("pixel_scale", "<f4"),
("ra_projection_center", "<f8"),
("dec_projection_center", "<f8"),
("orientat_projection_center", "<f4"),
("index", "<i8"),
],
)

wcs = mp.skycell_to_wcs(skycell)

assert np.allclose(
wcs(
skycell["x0_projection"], skycell["y0_projection"], with_bounding_box=False
),
(skycell["ra_projection_center"], skycell["dec_projection_center"]),
)
assert np.allclose(
wcs(skycell["x_center"], skycell["y_center"]),
(skycell["ra_center"], skycell["dec_center"]),
)
assert np.allclose(wcs(0.0, 0.0), (skycell["ra_corn1"], skycell["dec_corn1"]))
assert np.allclose(
wcs(0.0, skycell["ny"] - 1), (skycell["ra_corn2"], skycell["dec_corn2"])
)
assert np.allclose(
wcs(skycell["nx"] - 1, skycell["ny"] - 1),
(skycell["ra_corn3"], skycell["dec_corn3"]),
)
assert np.allclose(
wcs(skycell["nx"] - 1, 0.0), (skycell["ra_corn4"], skycell["dec_corn4"])
)


def test_wcsinfo_to_wcs():
"""Test integrity of wcsinfo_to_wcs"""
wcsinfo = {
"ra_ref": 269.83219987378925,
"dec_ref": 66.04081466149024,
"x_ref": 2069.0914958388985,
"y_ref": 2194.658767532754,
"rotation_matrix": [
[-0.9999964196507396, -0.00267594575838714],
[-0.00267594575838714, 0.9999964196507396],
],
"pixel_scale": 3.036307317109957e-05,
"pixel_shape": [4389, 4138],
"ra_center": 269.82284964811464,
"dec_center": 66.0369888162117,
"ra_corn1": 269.98694025887136,
"dec_corn1": 65.97426875366378,
"ra_corn2": 269.98687579251805,
"dec_corn2": 66.09988065827382,
"ra_corn3": 269.6579498847431,
"dec_corn3": 66.099533603104,
"ra_corn4": 269.6596332616879,
"dec_corn4": 65.97389321243348,
"orientat": 359.8466793994546,
}

wcs = mp.wcsinfo_to_wcs(wcsinfo)

assert np.allclose(
wcs(wcsinfo["x_ref"], wcsinfo["y_ref"]), (wcsinfo["ra_ref"], wcsinfo["dec_ref"])
)
assert np.allclose(
wcs(4389 / 2.0, 4138 / 2.0), (wcsinfo["ra_center"], wcsinfo["dec_center"])
)
assert np.allclose(wcs(0.0, 0.0), (wcsinfo["ra_corn1"], wcsinfo["dec_corn1"]))
assert np.allclose(
wcs(0.0, wcsinfo["pixel_shape"][1]), (wcsinfo["ra_corn2"], wcsinfo["dec_corn2"])
)
assert np.allclose(
wcs(wcsinfo["pixel_shape"][0], wcsinfo["pixel_shape"][1]),
(wcsinfo["ra_corn3"], wcsinfo["dec_corn3"]),
)
assert np.allclose(
wcs(wcsinfo["pixel_shape"][0], 0.0), (wcsinfo["ra_corn4"], wcsinfo["dec_corn4"])
)
10 changes: 10 additions & 0 deletions romancal/regtest/test_mos_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import pytest
import roman_datamodels as rdm

from romancal.pipeline import mosaic_pipeline
from romancal.pipeline.mosaic_pipeline import MosaicPipeline

from . import util
from .regtestdata import compare_asdf

# mark all tests in this module
Expand Down Expand Up @@ -116,3 +118,11 @@ def test_added_background(output_model):
def test_added_background_level(output_model):
# DMS400
assert any(output_model.meta.individual_image_meta.background["level"] != 0)


def test_wcsinfo_wcs_roundtrip(output_model):
"""Test that the contents of wcsinfo reproduces the wcs"""
wcs_from_wcsinfo = mosaic_pipeline.wcsinfo_to_wcs(output_model.meta.wcsinfo)

ra_mad, dec_mad = util.comp_wcs_grids_arcs(output_model.meta.wcs, wcs_from_wcsinfo)
assert (ra_mad + dec_mad) / 2.0 < 1.0e-5
10 changes: 10 additions & 0 deletions romancal/regtest/test_mos_skycell_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
import roman_datamodels as rdm

from romancal.pipeline import mosaic_pipeline
from romancal.pipeline.mosaic_pipeline import MosaicPipeline

from . import util
from .regtestdata import compare_asdf

# mark all tests in this module
Expand Down Expand Up @@ -56,3 +58,11 @@ def test_resample_ran(output_model):
def test_location_name(output_model):
# test that the location_name matches the skycell selected
assert output_model.meta.basic.location_name == "r274dp63x31y81"


def test_wcsinfo_wcs_roundtrip(output_model):
"""Test that the contents of wcsinfo reproduces the wcs"""
wcs_from_wcsinfo = mosaic_pipeline.wcsinfo_to_wcs(output_model.meta.wcsinfo)

ra_mad, dec_mad = util.comp_wcs_grids_arcs(output_model.meta.wcs, wcs_from_wcsinfo)
assert (ra_mad + dec_mad) / 2.0 < 1.0e-5
33 changes: 33 additions & 0 deletions romancal/regtest/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Test utilities"""

import numpy as np
from astropy.stats import mad_std


def comp_wcs_grids_arcs(wcs_a, wcs_b, npix=4088, interval=10):
"""Compare world grids produced by the two wcs
Parameters
----------
wcs_a, wcs_b : gwcs.WCS
The wcs object to compare.
npix : int
The size of the grid to produce.
interval : int
The interval to check over.
Returns
-------
mad_std : float
The numpy MAD_STD in arcseconds
"""
xx, yy = np.meshgrid(np.linspace(0, npix, interval), np.linspace(0, npix, interval))
ra_a, dec_a = wcs_a(xx, yy, with_bounding_box=False)
ra_b, dec_b = wcs_b(xx, yy, with_bounding_box=False)

ra_mad = mad_std(ra_a - ra_b, ignore_nan=True) * 60.0 * 60.0 * 1000.0
dec_mad = mad_std(dec_a - dec_b, ignore_nan=True) * 60.0 * 60.0 * 1000.0

return ra_mad, dec_mad
Loading

0 comments on commit 8422362

Please sign in to comment.