## RaDialog: A Large Vision-Language Model for Radiology Report Generation and Conversational Assistance

**Authors**: [Chantal Pellegrini*][cp], [Ege Özsoy*][eo], [Benjamin Busam][bb], [Nassir Navab][nn], [Matthias Keicher][mk]

[cp]:https://www.cs.cit.tum.de/camp/members/chantal-pellegrini/
[eo]:https://www.cs.cit.tum.de/camp/members/ege-oezsoy/
[mk]:https://www.cs.cit.tum.de/camp/members/matthias-keicher/
[nn]:https://www.cs.cit.tum.de/camp/members/cv-nassir-navab/nassir-navab/
[bb]:https://www.cs.cit.tum.de/camp/members/benjamin-busam-1/

## [Paper](https://arxiv.org/abs/2106.02009) | [Demo](https://www.youtube.com/watch?v=8Z3QX6Q4Zq4) | Dataset - Coming Soon

<img align="right" src="figs/example.png" alt="teaser" width="50%" style="margin-left: 20px">

Conversational AI tools that can generate and discuss clinically correct radiology reports for a given medical image have the potential to transform radiology. Such a human-in-the-loop radiology assistant could facilitate a collaborative diagnostic process, thus saving time and improving the quality of reports. Towards this goal, we introduce RaDialog, the first thoroughly evaluated and publicly available large vision-language model for radiology report generation and interactive dialog. RaDialog effectively integrates visual image features and structured pathology findings with a large language model (LLM) while simultaneously adapting it to a specialized domain using parameter-efficient fine-tuning. To keep the conversational abilities of the underlying LLM, we propose a comprehensive, semi-automatically labeled, image-grounded instruct dataset for chest X-ray radiology tasks. By training with this dataset, our method achieves state-of-the-art clinical correctness in report generation and shows impressive abilities in interactive tasks such as correcting reports and answering questions, serving as a foundational step toward clinical dialog systems.

## Installation

### Environment Setup:

#### 1) RaDialog Environment
- Install the RaDialog environment with `conda create --name radialog python=3.7`
- Activate the environment with `conda activate radialog`
- Install the requirements with `pip install -r requirements.txt`
- Reinstall the correct versions of torch and transformers with `pip install torch==1.13.1 transformers==4.28.1`

#### 2) CheXbert Environment
- Install the CheXbert environment with `conda create --name chexbert python=3.7`
- Activate the environment with `conda activate chexbert`
- Move to the chexbert directory with `cd chexbert`
- Set the absolute paths to the chexbert environment and folder in `local_config.py` (see the sketch below)
- Install the requirements with `pip install -r requirements.txt`
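
The concrete variable names in `local_config.py` are defined by the repository itself; the sketch below only illustrates the kind of absolute paths the file collects. All names shown are assumptions for illustration, not the file's actual keys, so check the file shipped with the repo and adapt accordingly:

```python
# local_config.py -- illustrative sketch only; the real file in this
# repository defines its own variable names, so adapt accordingly.

# Absolute path to the conda environment used by CheXbert
CHEXBERT_ENV_PATH = "/home/user/miniconda3/envs/chexbert"

# Absolute path to the chexbert folder inside this repository
CHEXBERT_PATH = "/home/user/RaDialog/chexbert"

# Absolute path to the downloaded MIMIC-CXR dataset (set in the data steps below)
MIMIC_CXR_PATH = "/data/mimic-cxr"

# Absolute path to the generated MIMIC-NLE data (used for the instruct dataset)
MIMIC_NLE_PATH = "/data/mimic-nle"
```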

### Prepare the Data and Models:

#### 1) Download MIMIC-CXR
- Download the MIMIC-CXR dataset from [here](https://physionet.org/content/mimic-cxr/2.0.0/)
- In `local_config.py`, set the path to the MIMIC-CXR dataset (cf. the sketch above)
- In `model/lavis/defaults_report.yaml`, set the path to the MIMIC-CXR dataset

#### 2) Create sectioned report data
- Go to the mimic-cxr folder with `cd mimic-cxr`
- Run `python create_section_files.py` to prepare the report data

#### 3) Prepare the instruct dataset

- As MIMIC-CXR requires a certified PhysioNet account to be accessed, we cannot publish our instruct dataset directly.
- We are working on publishing the instruct dataset on PhysioNet. In the meantime, you can create an instruct dataset yourself by following the steps below.

- The MIMIC-NLE data has to be generated first, as it also contains protected data. Follow the instructions [here](https://github.com/maximek3/MIMIC-NLE/tree/main) to generate the MIMIC-NLE data and set the path to the MIMIC-NLE data in `local_config.py`.
- For the correction task, you can write to us and we will share the incorrect predictions we used.
- To generate data without the correction or reasoning (MIMIC-NLE) tasks, comment out line 335 or 336 in `create_data.py` accordingly.

Data for RaDialog-RG:
- Run `python create_data.py --mode "RG"` to generate the report generation dataset in the required format (no instruct data)

Data for RaDialog-INS:
- Run `python create_data.py --mode "INS"` to generate the instruct dataset

#### 4) Download pretrained models
- Download the pretrained models from [here](TODO) and place them in the `checkpoints` folder

### Run Demo:
- Run `python demo.py --cfg-path configs/blip2_pretrain_stage1_emb.yaml` to start the demo
- Connect to the demo with a browser at `http://127.0.0.1:7860` (check the terminal for the address) and start chatting with RaDialog

### Evaluate RaDialog on the MIMIC-CXR test set:
- RaDialog-RG: run `python test.py --prompt img_matching_examples_ig2_noexamples_IMG_findings --use_embs --num_workers 0 --lora_model checkpoints/vicuna-7b-img-report/checkpoint-11200`
- RaDialog-INS: run `python test.py --prompt img_matching_examples_ig2_noexamples_IMG_findings --use_embs --num_workers 0 --lora_model checkpoints/vicuna-7b-img-instruct/checkpoint-4800`

### Train RaDialog:

#### 1) CheXbert Classifier Training
- Run `python -m findings_classifier.train --train --run_name "train_chexbert"`
- Then run `python -m findings_classifier.train --run_name "save_preds"` to save the predictions of the trained model

#### 2) Image Encoder Pretraining
- Run `python -m pretraining.train`

#### 3) LLM Training
Train RaDialog-RG:
- Run `python finetune.py --use_embs True --base_model 'vicuna_v7' --output_dir './lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens' --wandb_run_name lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens --prompt_template_name vicuna_v11 --data_path "data/data_files/mimic_cxr_reports_stratified.json" --cutoff_len 600`

Train RaDialog-INS:
- Run `python finetune.py --use_embs True --base_model 'vicuna_v7' --output_dir './lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens_reversed2' --wandb_run_name lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens_reversed2 --prompt_template_name vicuna_v11 --data_path "data/data_files/instruct_data_stratified.json" --cutoff_len 600`
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Generator, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from health_multimodal.common.device import get_module_device
from timm.models.layers import trunc_normal_

from .resnet import resnet18, resnet50
from .transformer import VisionTransformerPooler
from .types import ImageEncoderType

DEFAULT_DILATION_VALUES_FOR_RESNET = (False, False, True)
ImageEncoderOutputType = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

class ImageEncoder(nn.Module):
    """Image encoder trunk module for the ``ImageModel`` class.

    :param img_encoder_type: Type of image encoder model to use, one of ``"resnet18"``,
        ``"resnet50"``, ``"resnet18_multi_image"``, or ``"resnet50_multi_image"``.
    """

    def __init__(self, img_encoder_type: str):
        super().__init__()
        self.img_encoder_type = img_encoder_type
        self.encoder = self._create_encoder()

    def _create_encoder(self, **kwargs: Any) -> nn.Module:
        if self.img_encoder_type in [ImageEncoderType.RESNET18, ImageEncoderType.RESNET18_MULTI_IMAGE]:
            encoder_class = resnet18
        elif self.img_encoder_type in [ImageEncoderType.RESNET50, ImageEncoderType.RESNET50_MULTI_IMAGE]:
            encoder_class = resnet50
        else:
            supported = ImageEncoderType.get_members(multi_image_encoders_only=False)
            raise NotImplementedError(f"Image encoder type \"{self.img_encoder_type}\" must be in {supported}")

        encoder = encoder_class(pretrained=True, **kwargs)

        return encoder

    def forward(self,
                current_image: torch.Tensor,
                return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
        """Get image global and patch embeddings."""
        patch_emb = self.encoder(current_image)
        # Global embedding: spatially average-pool the patch feature map to one vector per image.
        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_emb, (1, 1)), 1)
        if return_patch_embeddings:
            return patch_emb, avg_pooled_emb

        return avg_pooled_emb

    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
        """Workaround for enabling dilated convolutions after model initialization.

        :param replace_stride_with_dilation: Replace the 2x2 standard convolution stride with a dilated convolution
            in each layer in the last three blocks of the ResNet architecture.
        """
        if self.img_encoder_type == ImageEncoderType.RESNET18:
            # resnet18 uses the BasicBlock implementation, which does not support dilated convolutions.
            raise NotImplementedError("resnet18 does not support dilated convolutions")

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = DEFAULT_DILATION_VALUES_FOR_RESNET

        device = next(self.encoder.parameters()).device
        new_encoder = self._create_encoder(replace_stride_with_dilation=replace_stride_with_dilation).to(device)

        if self.encoder.training:
            new_encoder.train()
        else:
            new_encoder.eval()

        # The dilated encoder has the same parameter shapes, so trained weights carry over directly.
        new_encoder.load_state_dict(self.encoder.state_dict())
        self.encoder = new_encoder
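
# Example usage (illustrative sketch, not part of the original module; it assumes
# ImageEncoderType.RESNET50 resolves to a value accepted by _create_encoder):
#
#   encoder = ImageEncoder(img_encoder_type=ImageEncoderType.RESNET50)
#   images = torch.rand(2, 3, 448, 448)  # batch of two RGB chest X-rays
#   patch_emb, global_emb = encoder(images, return_patch_embeddings=True)
#   # patch_emb has shape [2, C, h, w]; global_emb has shape [2, C]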


class MultiImageEncoder(ImageEncoder):
    """Multi-image encoder trunk module for the ``ImageModel`` class.

    It can be used to encode multiple images into a combined latent representation.
    Currently it only supports two input images, but it can be extended to support more in the future.

    :param img_encoder_type: Type of image encoder model to use: either ``"resnet18_multi_image"``
        or ``"resnet50_multi_image"``.
    """

    def __init__(self, img_encoder_type: str):
        super().__init__(img_encoder_type)

        output_dim = 256  # The aggregate feature dim of the encoder is `2 * output_dim`, i.e. [f_static, f_diff]
        grid_shape = (14, 14)  # Spatial dimensions of the patch grid.

        backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=get_module_device(self))

        # 1x1 convolution projecting backbone features to the transformer pooler's input dimension.
        self.backbone_to_vit = nn.Conv2d(in_channels=backbone_output_feature_dim, out_channels=output_dim,
                                         kernel_size=1, stride=1, padding=0, bias=False)
        self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=grid_shape)

        # Learned embedding used in place of the difference features when no previous image is given.
        self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
        trunc_normal_(self.missing_previous_emb, std=.02)

    def forward(self,  # type: ignore[override]
                current_image: torch.Tensor,
                previous_image: Optional[torch.Tensor] = None,
                return_patch_embeddings: bool = False) -> ImageEncoderOutputType:

        batch_size = current_image.shape[0]

        if previous_image is not None:
            assert current_image.shape == previous_image.shape
            # Encode both images in one backbone pass by stacking them along the batch dimension.
            x = torch.cat([current_image, previous_image], dim=0)
            x = super().forward(x, return_patch_embeddings=True)[0]
            x = self.backbone_to_vit(x)
            patch_x, patch_x_previous = x[:batch_size], x[batch_size:]
            diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_x_previous)
        else:
            x = super().forward(current_image, return_patch_embeddings=True)[0]
            patch_x = self.backbone_to_vit(x)
            B, _, W, H = patch_x.shape
            diff_x = self.missing_previous_emb.repeat(B, 1, W, H)

        # Fuse static and difference features channel-wise: [f_static, f_diff].
        patch_fused = torch.cat([patch_x, diff_x], dim=1)
        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)

        if return_patch_embeddings:
            return patch_fused, avg_pooled_emb

        return avg_pooled_emb

    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
        raise NotImplementedError
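
# Example usage (illustrative sketch, not part of the original module):
#
#   encoder = MultiImageEncoder(img_encoder_type=ImageEncoderType.RESNET50_MULTI_IMAGE)
#   current, previous = torch.rand(2, 3, 448, 448), torch.rand(2, 3, 448, 448)
#   fused_patches, fused_global = encoder(current, previous_image=previous,
#                                         return_patch_embeddings=True)
#   # Omitting previous_image substitutes the learned missing_previous_emb for the
#   # difference features, so output shapes stay identical either way.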


@torch.no_grad()
def get_encoder_output_dim(module: torch.nn.Module, device: torch.device) -> int:
    """Calculate the output dimension of an encoder by making a single forward pass.

    :param module: Encoder module.
    :param device: Compute device to use.
    """
    # Target device
    assert isinstance(device, torch.device)

    x = torch.rand((1, 3, 448, 448)).to(device)

    # Extract the number of output feature dimensions
    with restore_training_mode(module):
        module.eval()
        representations = module(x)
        return representations.shape[1]


@contextmanager
def restore_training_mode(module: nn.Module) -> Generator[None, None, None]:
    """Restore the training mode of a module after some operation.

    :param module: PyTorch module.
    """
    training_mode = module.training
    yield
    module.train(mode=training_mode)
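
# Example usage (illustrative sketch, not part of the original module): the context
# manager records module.training on entry and restores it on exit, so callers can
# temporarily switch modes without disturbing the surrounding training loop:
#
#   model.train()
#   with restore_training_mode(model):
#       model.eval()
#       ...  # run inference-only logic here
#   assert model.training  # back in training mode afterwards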


def get_encoder_from_type(img_encoder_type: str) -> ImageEncoder:
    """Returns an encoder instance for the given encoder type.

    :param img_encoder_type: Encoder type. {RESNET18, RESNET50, RESNET18_MULTI_IMAGE, RESNET50_MULTI_IMAGE}
    """
    if img_encoder_type in ImageEncoderType.get_members(multi_image_encoders_only=True):
        return MultiImageEncoder(img_encoder_type=img_encoder_type)
    else:
        return ImageEncoder(img_encoder_type=img_encoder_type)
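
# Example usage (illustrative sketch, not part of the original module; assumes the
# encoder type values are plain strings such as "resnet50_multi_image"):
#
#   encoder = get_encoder_from_type("resnet50_multi_image")
#   isinstance(encoder, MultiImageEncoder)  # True: multi-image types get the multi-image trunk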
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from health_multimodal.common.device import get_module_device

from .encoder import get_encoder_from_type, get_encoder_output_dim, MultiImageEncoder
from .modules import MLP, MultiTaskModel
from .types import ImageModelOutput


class BaseImageModel(nn.Module, ABC):
    """Abstract class for image models."""

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> ImageModelOutput:
        raise NotImplementedError

    @abstractmethod
    def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
        raise NotImplementedError

class ImageModel(BaseImageModel):
    """Image encoder module."""

    def __init__(self,
                 img_encoder_type: str,
                 joint_feature_size: int,
                 freeze_encoder: bool = False,
                 pretrained_model_path: Optional[Union[str, Path]] = None,
                 **downstream_classifier_kwargs: Any):
        super().__init__()

        # Initialize the encoder, projector, and classifier
        self.encoder = get_encoder_from_type(img_encoder_type)
        self.feature_size = get_encoder_output_dim(self.encoder, device=get_module_device(self.encoder))
        self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
                             hidden_dim=joint_feature_size, use_1x1_convs=True)
        self.downstream_classifier_kwargs = downstream_classifier_kwargs
        self.classifier = self.create_downstream_classifier() if downstream_classifier_kwargs else None

        # Initialize the mode of the modules
        self.freeze_encoder = freeze_encoder
        self.train()

        if pretrained_model_path is not None:
            if not isinstance(pretrained_model_path, (str, Path)):
                raise TypeError(f"Expected a string or Path, got {type(pretrained_model_path)}")
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
            # Drop the projector weights so the freshly initialized projector is kept.
            for k in list(state_dict.keys()):
                if k.startswith("projector"):
                    state_dict.pop(k)

            self.load_state_dict(state_dict, strict=False)

    def train(self, mode: bool = True) -> Any:
        """Switch the model between training and evaluation modes."""
        super().train(mode=mode)
        if self.freeze_encoder:
            self.encoder.train(mode=False)
            self.projector.train(mode=False)
        return self

    def forward(self, x: torch.Tensor) -> ImageModelOutput:  # type: ignore[override]
        with torch.set_grad_enabled(not self.freeze_encoder):
            patch_x, pooled_x = self.encoder(x, return_patch_embeddings=True)
        return self.forward_post_encoder(patch_x, pooled_x)

    def forward_post_encoder(self, patch_x: torch.Tensor, pooled_x: torch.Tensor) -> ImageModelOutput:
        with torch.set_grad_enabled(not self.freeze_encoder):
            projected_patch_embeddings = self.projector(patch_x)
            projected_global_embedding = torch.mean(projected_patch_embeddings, dim=(2, 3))

        logits = self.classifier(pooled_x) if self.classifier else None
        return ImageModelOutput(img_embedding=pooled_x,
                                patch_embeddings=patch_x,
                                class_logits=logits,
                                projected_patch_embeddings=projected_patch_embeddings,
                                projected_global_embedding=projected_global_embedding)

    def create_downstream_classifier(self, **kwargs: Any) -> MultiTaskModel:
        """Create the classification module for the downstream task."""
        downstream_classifier_kwargs = kwargs if kwargs else self.downstream_classifier_kwargs
        return MultiTaskModel(self.feature_size, **downstream_classifier_kwargs)

    @torch.no_grad()
    def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
        """Get patch-wise projected embeddings from the CNN model.

        :param input_img: Input tensor image [B, C, H, W].
        :param normalize: If ``True``, the embeddings are L2-normalized.
        :returns: Tensor of embeddings of shape [batch, n_patches_h, n_patches_w, feature_size].
        """
        assert not self.training, "This function is only implemented for evaluation mode"
        outputs = self.forward(input_img)
        projected_embeddings = outputs.projected_patch_embeddings.detach()  # type: ignore
        if normalize:
            projected_embeddings = F.normalize(projected_embeddings, dim=1)
        projected_embeddings = projected_embeddings.permute([0, 2, 3, 1])  # B D H W -> B H W D (D: features)
        return projected_embeddings
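
# Example usage (illustrative sketch, not part of the original module; the encoder
# type string and joint_feature_size value are arbitrary choices for demonstration):
#
#   model = ImageModel(img_encoder_type="resnet50", joint_feature_size=128)
#   model.eval()
#   out = model(torch.rand(1, 3, 448, 448))
#   # out.projected_global_embedding: [1, 128] embedding in the joint image-text space
#   patches = model.get_patchwise_projected_embeddings(torch.rand(1, 3, 448, 448), normalize=True)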


class MultiImageModel(ImageModel):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        assert isinstance(self.encoder, MultiImageEncoder), "MultiImageModel only supports MultiImageEncoder"

    def forward(self,  # type: ignore[override]
                current_image: torch.Tensor,
                previous_image: Optional[torch.Tensor] = None) -> ImageModelOutput:

        with torch.set_grad_enabled(not self.freeze_encoder):
            patch_x, pooled_x = self.encoder(current_image=current_image,
                                             previous_image=previous_image,
                                             return_patch_embeddings=True)
        return self.forward_post_encoder(patch_x, pooled_x)
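
# Example usage (illustrative sketch, not part of the original module; assumes a
# multi-image encoder type string such as "resnet50_multi_image"):
#
#   model = MultiImageModel(img_encoder_type="resnet50_multi_image", joint_feature_size=128)
#   model.eval()
#   out = model(current_image=torch.rand(1, 3, 448, 448),
#               previous_image=torch.rand(1, 3, 448, 448))  # previous_image may be None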