Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hf demo #81

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
17 changes: 17 additions & 0 deletions SwinUNETR/BTCV/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,23 @@ Dependencies can be installed using:
pip install -r requirements.txt
```

# Hugging Face inference API

To install the necessary dependencies, run the following commands in bash:
```
git clone https://github.com/darraghdog/Project-MONAI-research-contributions pmrc
pip install -r pmrc/requirements.txt
cd pmrc/SwinUNETR/BTCV
```

To load the model from the Hugging Face hub, run the following in Python from the `pmrc/SwinUNETR/BTCV` directory:
```
from swinunetr import SwinUnetrModelForInference
model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
```

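You can also use `predict.py` to run inference on the sample NIfTI volumes (`dataset/imagesSampleTs/*.nii.gz`); it writes one side-by-side input/segmentation video per volume to the `videos/` directory.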

# Models

Please download the self-supervised pre-trained weights for Swin UNETR backbone (CVPR paper [1]) from this <a href="https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/model_swinvit.pt"> link</a>.
Expand Down
Binary file not shown.
Binary file not shown.
98 changes: 98 additions & 0 deletions SwinUNETR/BTCV/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import os
import glob
import shutil
import torch
import argparse
import cv2
import mediapy
import numpy as np
from skimage import color, img_as_ubyte
from monai import transforms, data
from swinunetr import SwinUnetrModelForInference, SwinUnetrConfig

# Command-line interface of the demo script.
# The device defaults to CUDA whenever torch reports it as available.
_default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

parser = argparse.ArgumentParser(description='Swin UNETR segmentation pipeline')
parser.add_argument('--device', type=str, default=_default_device, help='device for model - cpu/gpu')

# Remaining options as (flag, default, type, help) rows; registration order is
# kept so that the generated --help output is unchanged.
_numeric_options = (
    ('--a_min', -175.0, float, 'a_min in ScaleIntensityRanged'),
    ('--a_max', 250.0, float, 'a_max in ScaleIntensityRanged'),
    ('--b_min', 0.0, float, 'b_min in ScaleIntensityRanged'),
    ('--b_max', 1.0, float, 'b_max in ScaleIntensityRanged'),
    ('--infer_overlap', 0.5, float, 'sliding window inference overlap'),
    ('--space_x', 1.5, float, 'spacing in x direction'),
    ('--space_y', 1.5, float, 'spacing in y direction'),
    ('--space_z', 2.0, float, 'spacing in z direction'),
    ('--roi_x', 96, int, 'roi size in x direction'),
    ('--roi_y', 96, int, 'roi size in y direction'),
    ('--roi_z', 96, int, 'roi size in z direction'),
    ('--last_n_frames', 64, int, 'Limit the frames inference. -1 for all frames.'),
)
for _flag, _default, _type, _help in _numeric_options:
    parser.add_argument(_flag, default=_default, type=_type, help=_help)

args = parser.parse_args()

# mediapy needs an ffmpeg executable to encode the output videos.
ffmpeg_path = shutil.which('ffmpeg')
if ffmpeg_path is None:
    # shutil.which returns None when nothing is found; stop here with a clear
    # message instead of handing None to mediapy and failing later, after the
    # model has been loaded and inference has run.
    raise RuntimeError(
        'ffmpeg executable not found on PATH; it is required to write the output videos'
    )
mediapy.set_ffmpeg(ffmpeg_path)

# Load the pretrained checkpoint from the hub, switch to eval mode and move
# the model to the requested device.
model = SwinUnetrModelForInference.from_pretrained('darragh/swinunetr-btcv-tiny')
model.eval()
model.to(args.device)

# One {"image": path} dictionary per sample volume found on disk.
test_files = [{'image': path} for path in glob.glob('dataset/imagesSampleTs/*.nii.gz')]

# Resampling target and intensity window, both taken from the CLI arguments.
_target_spacing = (args.space_x, args.space_y, args.space_z)
_intensity_window = dict(
    a_min=args.a_min,
    a_max=args.a_max,
    b_min=args.b_min,
    b_max=args.b_max,
)

# Preprocessing: load, add a channel axis, resample, rescale/clip intensities,
# convert to tensor.
test_transform = transforms.Compose([
    transforms.LoadImaged(keys=["image"]),
    transforms.AddChanneld(keys=["image"]),
    transforms.Spacingd(keys="image", pixdim=_target_spacing, mode="bilinear"),
    transforms.ScaleIntensityRanged(keys=["image"], clip=True, **_intensity_window),
    transforms.ToTensord(keys=["image"]),
])

# The composed transform is called directly on the list of file dictionaries
# and the result is served one item per batch, in order.
test_ds = test_transform(test_files)
test_loader = data.DataLoader(test_ds, batch_size=1, shuffle=False)

# Run inference on every sample volume and render an input/segmentation video.

# Create the output directory up front so that writing the first video cannot
# fail because 'videos/' does not exist.
os.makedirs('videos', exist_ok=True)

for batch in test_loader:

    # Keep this tensor on the CPU: it is reused below to render the frames.
    tst_inputs = batch["image"]
    if args.last_n_frames > 0:
        # Only keep the trailing `last_n_frames` entries of the last axis.
        tst_inputs = tst_inputs[:, :, :, :, -args.last_n_frames:]

    with torch.no_grad():
        # The model was moved to args.device, so its input must live there too.
        # NOTE(review): the positional 8 is presumably the sliding-window
        # batch size - confirm against SwinUnetrModelForInference.forward.
        outputs = model(tst_inputs.to(args.device),
                        (args.roi_x, args.roi_y, args.roi_z),
                        8,
                        overlap=args.infer_overlap,
                        mode="gaussian")

    # Softmax is monotonic, so the arg-max over the class axis can be taken on
    # the logits directly. Bring the labels back to the CPU for NumPy.
    tst_outputs = torch.argmax(outputs.logits, dim=1).cpu()

    fnames = batch['image_meta_dict']['filename_or_obj']

    # Write frames to video
    for fname, inp, outp in zip(fnames, tst_inputs, tst_outputs):

        volume_name = os.path.basename(fname)
        video_name = f'videos/{volume_name}.mp4'
        frames = []
        for idx in range(inp.shape[-1]):
            # Segmentation labels for this slice
            seg = outp[:, :, idx].numpy().astype(np.uint8)
            # Input slice: scaled by 255 and shown as a grey RGB image
            img = (inp[0, :, :, idx] * 255).numpy().astype(np.uint8)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
            # Colour the labels over the input, leaving label 0 uncoloured
            frame = color.label2rgb(seg, img, bg_label=0)
            frame = img_as_ubyte(frame)
            # Plain input and overlay joined along axis 1
            frame = np.concatenate((img, frame), 1)
            frames.append(frame)
        mediapy.write_video(video_name, frames, fps=4)
8 changes: 7 additions & 1 deletion SwinUNETR/BTCV/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
transformers==4.20.1
torch==1.10.0

git+https://github.com/Project-MONAI/MONAI#egg.gitmonai@0.8.1+271.g07de215c
nibabel==3.1.1
tqdm==4.59.0
einops==0.4.1
tensorboardX==2.1
scipy==1.2.1
scipy==1.5.0
mediapy==1.0.3
scikit-image==0.17.2
opencv-python==4.6.0.66
53 changes: 53 additions & 0 deletions SwinUNETR/BTCV/swinunetr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2021 NAVER CLOVA Team. All rights reserved.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.file_utils import (
_LazyModule,
is_torch_available,
)

# Maps each submodule to the public names it provides; handed to _LazyModule
# below so that submodules are only imported on first attribute access.
_import_structure = {
    "configuration_swinunetr": [
        "SWINUNETR_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "SwinUnetrConfig",
    ],
}


# The modeling names are only registered when PyTorch is installed.
if is_torch_available():
    _import_structure["modeling_swinunetr"] = [
        "SWINUNETR_PRETRAINED_MODEL_ARCHIVE_LIST",
        "SwinUnetrModelForInference",
    ]

if TYPE_CHECKING:
    # Eager imports for static type checkers; they must mirror
    # _import_structure above.
    from .configuration_swinunetr import SWINUNETR_PRETRAINED_CONFIG_ARCHIVE_MAP, SwinUnetrConfig

    if is_torch_available():
        # Import from the module registered above, not the copy-pasted
        # "modeling_bros" name.
        from .modeling_swinunetr import (
            SWINUNETR_PRETRAINED_MODEL_ARCHIVE_LIST,
            SwinUnetrModelForInference,
        )

else:
    import sys

    # Replace this package in sys.modules with a lazy proxy.
    sys.modules[__name__] = _LazyModule(
        __name__, globals()["__file__"], _import_structure
    )
94 changes: 94 additions & 0 deletions SwinUNETR/BTCV/swinunetr/configuration_swinunetr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# coding=utf-8
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Swin UNETR configuration """

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

# Module-level logger obtained through the transformers logging helper.
logger = logging.get_logger(__name__)

# Maps each checkpoint short name to the URL of its config.json on the hub.
SWINUNETR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    checkpoint: f"https://huggingface.co/darragh/{checkpoint}/raw/main/config.json"
    for checkpoint in ("swinunetr-btcv-tiny", "swinunetr-btcv-small")
}


class SwinUnetrConfig(PretrainedConfig):
    r"""
    Configuration class that stores the hyper-parameters of a Swin UNETR model.

    Every argument below is stored as an attribute of the same name. Any extra
    keyword arguments are forwarded to :class:`~transformers.PretrainedConfig`.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig`. Read the documentation from
    :class:`~transformers.PretrainedConfig` for more information.


    Args:
        architecture: architecture identifier string (default ``"SwinUNETR"``).
        img_size: dimension of input image (default ``96``).
        in_channels: dimension of input channels (default ``1``).
        out_channels: dimension of output channels (default ``14``).
        depths: number of layers in each stage (default ``(2, 2, 2, 2)``).
        num_heads: number of attention heads (default ``(3, 6, 12, 24)``).
        feature_size: dimension of network feature size (default ``12``).
        norm_name: feature normalization type and arguments (default ``"instance"``).
        drop_rate: dropout rate (default ``0.0``).
        attn_drop_rate: attention dropout rate (default ``0.0``).
        dropout_path_rate: drop path rate (default ``0.0``).
        normalize: normalize output intermediate features in each stage (default ``True``).
        use_checkpoint: use gradient checkpointing for reduced memory usage (default ``False``).
        spatial_dims: number of spatial dims (default ``3``).

    Examples::

        >>> config = SwinUnetrConfig(feature_size=48)
        >>> config.feature_size
        48
        >>> config.out_channels
        14
    """
    model_type = "swinunetr"

    def __init__(
        self,
        architecture="SwinUNETR",
        img_size=96,
        in_channels=1,
        out_channels=14,
        depths=(2, 2, 2, 2),
        num_heads=(3, 6, 12, 24),
        feature_size=12,
        norm_name="instance",
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        normalize=True,
        use_checkpoint=False,
        spatial_dims=3,
        **kwargs
    ):
        # Store the model hyper-parameters explicitly rather than routing them
        # through the base class's generic **kwargs handling.
        self.architecture = architecture
        self.img_size = img_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.depths = depths
        self.num_heads = num_heads
        self.feature_size = feature_size
        self.norm_name = norm_name
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.dropout_path_rate = dropout_path_rate
        self.normalize = normalize
        self.use_checkpoint = use_checkpoint
        self.spatial_dims = spatial_dims

        # Only the remaining keyword arguments go to PretrainedConfig.
        super().__init__(**kwargs)
Loading