SwinUNETR refactored img_size parameter and removed checkpointing dep… (#7093)

### Description
Make two changes to the SwinUNETR network:
- The `img_size` parameter has no effect beyond checking shape compatibility at
construction time. This is now stated in the docstring, and the parameter will
be deprecated in a future version. Instead, shape compatibility is checked
during the forward pass against the actual input shape (see the usage sketch
after this list).
- Newer PyTorch versions accept a
[`use_reentrant`](https://pytorch.org/docs/2.1/checkpoint.html) parameter for
activation checkpointing. The old default of `True` is deprecated in favor of
`False`. This PR sets the parameter to `False`, adopting the recommended value
and removing the deprecation warning.
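
As a minimal, hypothetical usage sketch of the new behavior (the sizes below are illustrative and not part of this PR):

```python
import torch
from monai.networks.nets import SwinUNETR

# img_size is now only an up-front divisibility check and will be deprecated;
# the effective check runs on the actual input shape inside forward().
model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=2, feature_size=48)

# any spatial shape whose dimensions are divisible by 2**5 = 32 is accepted at runtime
x = torch.randn(1, 1, 64, 64, 64)
out = model(x)  # _check_input_size() validates x.shape[2:] here

# model(torch.randn(1, 1, 60, 64, 64))  # would raise ValueError: 60 is not divisible by 2**5
```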


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: John Zielke <[email protected]>
john-zielke-snkeos authored Oct 6, 2023
1 parent 100db27 commit f239825
Showing 2 changed files with 34 additions and 10 deletions.
42 changes: 33 additions & 9 deletions monai/networks/nets/swin_unetr.py
@@ -20,11 +20,13 @@
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import LayerNorm
from typing_extensions import Final

from monai.networks.blocks import MLPBlock as Mlp
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
from monai.networks.layers import DropPath, trunc_normal_
from monai.utils import ensure_tuple_rep, look_up_option, optional_import
from monai.utils.deprecate_utils import deprecated_arg

rearrange, _ = optional_import("einops", name="rearrange")

@@ -49,6 +51,15 @@ class SwinUNETR(nn.Module):
<https://arxiv.org/abs/2201.01266>"
"""

patch_size: Final[int] = 2

@deprecated_arg(
name="img_size",
since="1.3",
removed="1.5",
msg_suffix="The img_size argument is not required anymore and "
"checks on the input size are run during forward().",
)
def __init__(
self,
img_size: Sequence[int] | int,
@@ -69,7 +80,10 @@ def __init__(
) -> None:
"""
Args:
img_size: dimension of input image.
img_size: spatial dimension of input image.
This argument is only used for checking that the input image size is divisible by the patch size.
The tensor passed to forward() can have a dynamic shape as long as its spatial dimensions are divisible by 2**5.
It will be removed in an upcoming version.
in_channels: dimension of input channels.
out_channels: dimension of output channels.
feature_size: dimension of network feature size.
@@ -103,16 +117,13 @@ def __init__(
super().__init__()

img_size = ensure_tuple_rep(img_size, spatial_dims)
patch_size = ensure_tuple_rep(2, spatial_dims)
patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
window_size = ensure_tuple_rep(7, spatial_dims)

if spatial_dims not in (2, 3):
raise ValueError("spatial dimension should be 2 or 3.")

for m, p in zip(img_size, patch_size):
for i in range(5):
if m % np.power(p, i + 1) != 0:
raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")
self._check_input_size(img_size)

if not (0 <= drop_rate <= 1):
raise ValueError("dropout rate should be between 0 and 1.")
@@ -132,7 +143,7 @@ def __init__(
in_chans=in_channels,
embed_dim=feature_size,
window_size=window_size,
patch_size=patch_size,
patch_size=patch_sizes,
depths=depths,
num_heads=num_heads,
mlp_ratio=4.0,
@@ -297,7 +308,20 @@ def load_from(self, weights):
weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
)

@torch.jit.unused
def _check_input_size(self, spatial_shape):
img_size = np.array(spatial_shape)
remainder = (img_size % np.power(self.patch_size, 5)) > 0
if remainder.any():
wrong_dims = (np.where(remainder)[0] + 2).tolist()
raise ValueError(
f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
f" must be divisible by {self.patch_size}**5."
)

def forward(self, x_in):
if not torch.jit.is_scripting():
self._check_input_size(x_in.shape[2:])
hidden_states_out = self.swinViT(x_in, self.normalize)
enc0 = self.encoder1(x_in)
enc1 = self.encoder2(hidden_states_out[0])
@@ -669,12 +693,12 @@ def load_from(self, weights, n_block, layer):
def forward(self, x, mask_matrix):
shortcut = x
if self.use_checkpoint:
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix, use_reentrant=False)
else:
x = self.forward_part1(x, mask_matrix)
x = shortcut + self.drop_path(x)
if self.use_checkpoint:
x = x + checkpoint.checkpoint(self.forward_part2, x)
x = x + checkpoint.checkpoint(self.forward_part2, x, use_reentrant=False)
else:
x = x + self.forward_part2(x)
return x
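For reference, a minimal sketch of the non-reentrant checkpointing call adopted above (assumes a PyTorch version that accepts `use_reentrant`; the function below is a stand-in, not part of MONAI):

```python
import torch
import torch.utils.checkpoint as checkpoint

def block(x):
    # stand-in for forward_part1 / forward_part2: any differentiable sub-module call
    return torch.nn.functional.gelu(x)

x = torch.randn(2, 8, requires_grad=True)
# use_reentrant=False selects the recommended non-reentrant implementation
# and avoids the deprecation warning emitted when the argument is omitted
y = checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```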
2 changes: 1 addition & 1 deletion runtests.sh
@@ -261,7 +261,7 @@ do
doBlackFormat=true
doIsortFormat=true
doFlake8Format=true
doPylintFormat=true
# doPylintFormat=true # https://github.com/Project-MONAI/MONAI/issues/7094
doRuffFormat=true
doCopyRight=true
;;
