Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft commit of the refactoring of the rotate function #7683

Draft
wants to merge 1 commit into
base: geometric
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
from monai.transforms.croppad.array import ResizeWithPadOrCrop
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.lazy.utils import apply_to_geometry
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.utils import (
KindKeys,
LazyAttr,
TraceKeys,
convert_to_dst_type,
Expand Down Expand Up @@ -339,7 +341,7 @@ def resize(
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out


def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info):
def rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info):
"""
Functional implementation of rotate.
This function operates eagerly or lazily according to
Expand Down Expand Up @@ -395,6 +397,14 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l
transform_info=transform_info,
lazy=lazy,
)
return transform, meta_info


def rotate_raster(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info):
"""
Raster-specific rotation functionality
"""
transform, meta_info = rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info)
out = _maybe_new_metatensor(img)
if lazy:
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
Expand All @@ -410,6 +420,48 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out


def rotate_geom(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info):
"""
Geometry-specific rotation functionality
"""
_, meta_info = rotate_impl(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info)
out = _maybe_new_metatensor(img)
if lazy:
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
out = apply_to_geometry(out, meta_info)
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out


def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info):
"""
Functional implementation of rotate.
This function operates eagerly or lazily according to
``lazy`` (default ``False``).

Args:
img: data to be changed, assuming `img` is channel-first.
angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D.
output_shape: output shape of the rotated data.
mode: {``"bilinear"``, ``"nearest"``}
Interpolation mode to calculate output values.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
align_corners: See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
dtype: data type for resampling computation.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``float32``.
lazy: a flag that indicates whether the operation should be performed lazily or not
transform_info: a dictionary with the relevant information pertaining to an applied transform.
"""
if isinstance(img, MetaTensor):
if img.kind == KindKeys.PIXEL:
return rotate_raster(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info)
elif img.kind == KindKeys.GEOMETRY:
return rotate_geom(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info)


def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype, lazy, transform_info):
"""
Functional implementation of zoom.
Expand Down
Loading