diff --git a/pyproject.toml b/pyproject.toml index 1e0c6b3d..71e0b6fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "numpy >=1.21.2", "opencv-python-headless >=4.6.0.66", "asdf >=2.15.0", - "gwcs >= 0.18.1", + "gwcs @ git+https://github.com/nden/gwcs.git@inverse-bbox", ] dynamic = [ "version", diff --git a/src/stcal/alignment/util.py b/src/stcal/alignment/util.py index f3e7b085..aeee9d90 100644 --- a/src/stcal/alignment/util.py +++ b/src/stcal/alignment/util.py @@ -4,11 +4,11 @@ import functools import logging from typing import TYPE_CHECKING, Protocol +import warnings import gwcs import numpy as np from astropy import units as u -from astropy import wcs as fitswcs from astropy.coordinates import SkyCoord from astropy.modeling import models as astmodels from astropy.utils.misc import isiterable @@ -436,7 +436,6 @@ def compute_scale( raise ValueError(msg) crpix = np.array(wcs.invert(*fiducial)) - delta = np.zeros_like(crpix) spatial_idx = np.where(np.array(wcs.output_frame.axes_type) == "SPATIAL")[0] delta[spatial_idx[0]] = 1 @@ -449,6 +448,7 @@ def compute_scale( dec=crval_with_offsets[spatial_idx[1]], unit="deg", ) + xscale = np.abs(coords[0].separation(coords[1]).value) yscale = np.abs(coords[0].separation(coords[2]).value) @@ -681,11 +681,11 @@ def update_s_region_imaging(model, center=True): # footprint is an array of shape (2, 4) as we # are interested only in the footprint on the sky - ### TODO: we shouldn't use center=True in the call below because we want to - ### calculate the coordinates of the footprint based on the *bounding box*, - ### which means we are interested in each pixel's vertice, not its center. - ### By using center=True, a difference of 0.5 pixel should be accounted for - ### when comparing the world coordinates of the bounding box and the footprint. + # TODO: we shouldn't use center=True in the call below because we want to + # calculate the coordinates of the footprint based on the *bounding box*, + # which means we are interested in each pixel's vertice, not its center. + # By using center=True, a difference of 0.5 pixel should be accounted for + # when comparing the world coordinates of the bounding box and the footprint. footprint = model.meta.wcs.footprint(bbox, center=center, axis_type="spatial").T # take only imaging footprint footprint = footprint[:2, :] @@ -763,31 +763,11 @@ def reproject(wcs1, wcs2): positions in ``wcs1`` and returns x, y positions in ``wcs2``. """ - def _get_forward_transform_func(wcs1): - """Get the forward transform function from the input WCS. If the wcs is a - fitswcs.WCS object all_pix2world requires three inputs, the x (str, ndarrray), - y (str, ndarray), and origin (int). The origin should be between 0, and 1 - https://docs.astropy.org/en/latest/wcs/index.html#loading-wcs-information-from-a-fits-file - ). - """ - if isinstance(wcs1, fitswcs.WCS): - forward_transform = wcs1.all_pix2world - elif isinstance(wcs1, gwcs.WCS): - forward_transform = wcs1.forward_transform - else: - msg = "Expected input to be astropy.wcs.WCS or gwcs.WCS object" - raise TypeError(msg) - return forward_transform - - def _get_backward_transform_func(wcs2): - if isinstance(wcs2, fitswcs.WCS): - backward_transform = wcs2.all_world2pix - elif isinstance(wcs2, gwcs.WCS): - backward_transform = wcs2.backward_transform - else: - msg = "Expected input to be astropy.wcs.WCS or gwcs.WCS object" - raise TypeError(msg) - return backward_transform + try: + forward_transform = wcs1.pixel_to_world_values + backward_transform = wcs2.world_to_pixel_values + except AttributeError as err: + raise TypeError("Input should be a WCS") from err def _reproject(x: float | np.ndarray, y: float | np.ndarray) -> tuple: """ @@ -805,22 +785,30 @@ def _reproject(x: float | np.ndarray, y: float | np.ndarray) -> tuple: tuple Tuple of np.ndarrays including reprojected x and y coordinates. """ - # example inputs to resulting function (12, 13, 0) # third number is origin - # uses np.arrays for shape functionality - if not isinstance(x, (np.ndarray)): - x = np.array(x) - if not isinstance(y, (np.ndarray)): - y = np.array(y) - if x.shape != y.shape: - msg = "x and y must be the same length" - raise ValueError(msg) - sky = _get_forward_transform_func(wcs1)(x, y, 0) - - # rearrange into array including flattened x and y values + # sky = forward_transform(x, y) + # flat_sky = [] + # for axis in sky: + # flat_sky.append(axis.flatten()) + # # Filter out RuntimeWarnings due to computed NaNs in the WCS + # with warnings.catch_warnings(): + # warnings.simplefilter("ignore", RuntimeWarning) + # det = backward_transform(*tuple(flat_sky)) + # det_reshaped = [] + # for axis in det: + # det_reshaped.append(axis.reshape(x.shape)) + + # return tuple(det_reshaped) + shape = np.array(x).shape + sky = forward_transform(x, y) flat_sky = [axis.flatten() for axis in sky] - det = np.array(_get_backward_transform_func(wcs2)(flat_sky[0], flat_sky[1], 0)) - det_reshaped = [axis.reshape(x.shape) for axis in det] - - return tuple(det_reshaped) + + # Filter out RuntimeWarnings due to computed NaNs in the WCS + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + detector = backward_transform(*tuple(flat_sky)) + + if shape == (): + return tuple([axis.item() for axis in detector]) + return tuple([axis.reshape(shape) for axis in detector]) return _reproject diff --git a/tests/test_alignment.py b/tests/test_alignment.py index a4041817..9825a97b 100644 --- a/tests/test_alignment.py +++ b/tests/test_alignment.py @@ -253,7 +253,7 @@ def get_fake_wcs(): [ (1000, 2000, np.array(2000), np.array(4000)), # string input test ([1000], [2000], np.array(2000), np.array(4000)), # array input test - pytest.param(1, 2, 3, 4, marks=pytest.mark.xfail), # expected failure test + (1, 2, 2, 4), ], ) def test_reproject(x_inp, y_inp, x_expected, y_expected):