Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial draft implementation of the resize function (still needs work) #7685

Draft
wants to merge 1 commit into
base: geometric
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 71 additions & 2 deletions monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
from monai.transforms.croppad.array import ResizeWithPadOrCrop
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.lazy.utils import apply_to_geometry
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.utils import (
KindKeys,
LazyAttr,
TraceKeys,
convert_to_dst_type,
Expand Down Expand Up @@ -265,7 +267,7 @@ def flip(img, sp_axes, lazy, transform_info):
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out


def resize(
def resize_impl(
img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info
):
"""
Expand Down Expand Up @@ -300,23 +302,36 @@ def resize(
"dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32
"new_dim": len(orig_size) - input_ndim,
}
affine = scale_affine(orig_size, out_size)
meta_info = TraceableTransform.track_transform_meta(
img,
sp_size=out_size,
affine=scale_affine(orig_size, out_size),
affine=affine,
extra_info=extra_info,
orig_size=orig_size,
transform_info=transform_info,
lazy=lazy,
)

return affine, orig_size, meta_info


def resize_raster(
img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info
):
_, orig_size, meta_info = resize_impl(
img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info
)
if lazy:
if anti_aliasing and lazy:
warnings.warn("anti-aliasing is not compatible with lazy evaluation.")
out = _maybe_new_metatensor(img)
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info

if tuple(convert_to_numpy(orig_size)) == out_size:
out = _maybe_new_metatensor(img, dtype=torch.float32)
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out

out = _maybe_new_metatensor(img)
img_ = convert_to_tensor(out, dtype=dtype, track_meta=False) # convert to a regular tensor
if anti_aliasing and any(x < y for x, y in zip(out_size, img_.shape[1:])):
Expand All @@ -339,6 +354,60 @@ def resize(
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out


def resize_geom(
    img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info
):
    """
    Resize path used for geometry-kind inputs (see the dispatch in ``resize``).

    The resize is first recorded through ``resize_impl``, which yields the
    tracked metadata for the operation. When ``lazy`` is true nothing is
    applied to the data and only the tracked metadata is attached/returned.
    Otherwise the tracked transform is applied via ``apply_to_geometry``.

    Args:
        img: data to be resized.
        out_size: expected spatial size after the resize operation.
        mode: interpolation mode; only forwarded to ``resize_impl`` here.
        align_corners: only forwarded to ``resize_impl`` here.
        dtype: only forwarded to ``resize_impl`` here.
        input_ndim: number of spatial dimensions; forwarded to ``resize_impl``.
        anti_aliasing: only forwarded to ``resize_impl`` here.
        anti_aliasing_sigma: only forwarded to ``resize_impl`` here.
        lazy: whether the operation should be recorded without being applied.
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        In lazy mode: the input (as produced by ``_maybe_new_metatensor``) with the
        tracked metadata copied onto it when it is a ``MetaTensor``, otherwise the
        tracked metadata itself. In eager mode: the result of ``apply_to_geometry``,
        with the tracked metadata copied onto it when it is a ``MetaTensor``.
    """
    # only the tracked metadata is needed on this path; affine and original size are unused
    _unused_affine, _unused_orig_size, tracked_meta = resize_impl(
        img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info
    )
    result = _maybe_new_metatensor(img)

    if not lazy:
        # eager evaluation: apply the recorded transform to the geometry right away
        # NOTE(review): the exact contract of apply_to_geometry is not visible here - confirm
        result = apply_to_geometry(result, tracked_meta)
        return result.copy_meta_from(tracked_meta) if isinstance(result, MetaTensor) else result

    # lazy evaluation: hand back the pending metadata without touching the data
    if isinstance(result, MetaTensor):
        return result.copy_meta_from(tracked_meta)
    return tracked_meta


def resize(
    img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info
):
    """
    Functional implementation of resize.
    This function operates eagerly or lazily according to
    ``lazy`` (default ``False``).

    The call is dispatched on the kind of data carried by ``img``:
    pixel data goes to ``resize_raster`` and geometry data goes to
    ``resize_geom``. Inputs that are not ``MetaTensor`` instances carry no
    ``kind`` and are treated as pixel data.

    Args:
        img: data to be changed, assuming `img` is channel-first.
        out_size: expected shape of spatial dimensions after resize operation.
        mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``,
            ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
            The interpolation mode.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        align_corners: This only has an effect when mode is
            'linear', 'bilinear', 'bicubic' or 'trilinear'.
        dtype: data type for resampling computation. If None, use the data type of input data.
        input_ndim: number of spatial dimensions.
        anti_aliasing: whether to apply a Gaussian filter to smooth the image prior
            to downsampling. It is crucial to filter when downsampling
            the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
        anti_aliasing_sigma: {float, tuple of floats}, optional
            Standard deviation for Gaussian filtering used when anti-aliasing.
        lazy: a flag that indicates whether the operation should be performed lazily or not
        transform_info: a dictionary with the relevant information pertaining to an applied transform.

    Returns:
        Whatever the selected implementation (``resize_raster`` or ``resize_geom``) returns.

    Raises:
        ValueError: when ``img`` is a ``MetaTensor`` whose ``kind`` is neither
            ``KindKeys.PIXEL`` nor ``KindKeys.GEOMETRY``.
    """
    # Plain (non-MetaTensor) inputs have no `kind`; previously they matched no branch and
    # the function silently returned None. The raster path already handles such inputs.
    if not isinstance(img, MetaTensor) or img.kind == KindKeys.PIXEL:
        return resize_raster(
            img,
            out_size,
            mode,
            align_corners,
            dtype,
            input_ndim,
            anti_aliasing,
            anti_aliasing_sigma,
            lazy,
            transform_info,
        )

    if img.kind == KindKeys.GEOMETRY:
        return resize_geom(
            img,
            out_size,
            mode,
            align_corners,
            dtype,
            input_ndim,
            anti_aliasing,
            anti_aliasing_sigma,
            lazy,
            transform_info,
        )

    raise ValueError(f"Unsupported value for 'kind': {img.kind}")


def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info):
"""
Functional implementation of rotate.
Expand Down
Loading