Initial commit
ChantalMP committed Nov 21, 2023 · commit b3417e5 (0 parents)
Showing 109 changed files with 416,647 additions and 0 deletions.
84 changes: 84 additions & 0 deletions README.md
## RaDialog: A Large Vision-Language Model for Radiology Report Generation and Conversational Assistance
**Authors**: [Chantal Pellegrini*][cp], [Ege Özsoy*][eo], [Benjamin Busam][bb], [Nassir Navab][nn], [Matthias Keicher][mk]

[cp]:https://www.cs.cit.tum.de/camp/members/chantal-pellegrini/
[eo]:https://www.cs.cit.tum.de/camp/members/ege-oezsoy/
[mk]:https://www.cs.cit.tum.de/camp/members/matthias-keicher/
[nn]:https://www.cs.cit.tum.de/camp/members/cv-nassir-navab/nassir-navab/
[bb]:https://www.cs.cit.tum.de/camp/members/benjamin-busam-1/

## [Paper](https://arxiv.org/abs/2106.02009) | [Demo](https://www.youtube.com/watch?v=8Z3QX6Q4Zq4) | Dataset - Coming Soon

<img align="right" src="figs/example.png" alt="teaser" width="50%" style="margin-left: 20px">

Conversational AI tools that can generate and discuss clinically correct radiology reports for a given medical image have the potential to transform radiology. Such a human-in-the-loop radiology assistant could facilitate a collaborative diagnostic process, thus saving time and improving the quality of reports. Towards this goal, we introduce RaDialog, the first thoroughly evaluated and publicly available large vision-language model for radiology report generation and interactive dialog. RaDialog effectively integrates visual image features and structured pathology findings with a large language model (LLM) while simultaneously adapting it to a specialized domain using parameter-efficient fine-tuning. To keep the conversational abilities of the underlying LLM, we propose a comprehensive, semi-automatically labeled, image-grounded instruct dataset for chest X-ray radiology tasks. By training with this dataset, our method achieves state-of-the-art clinical correctness in report generation and shows impressive abilities in interactive tasks such as correcting reports and answering questions, serving as a foundational step toward clinical dialog systems.

## Installation

### Environment Setup:
#### 1) RaDialog Environment
- Install the RaDialog environment with `conda create --name radialog python=3.7`
- Activate the environment with `conda activate radialog`
- Install the requirements with `pip install -r requirements.txt`
- Reinstall the correct versions of torch and transformers with `pip install torch==1.13.1 transformers==4.28.1` (a quick version check is sketched below)
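
A quick way to confirm the pinned versions are active (a minimal sanity-check snippet, not part of the repository):

```python
# Sanity check for the pinned dependency versions (illustrative, not part of the repo).
import torch
import transformers

assert torch.__version__.startswith("1.13.1"), f"unexpected torch version: {torch.__version__}"
assert transformers.__version__ == "4.28.1", f"unexpected transformers version: {transformers.__version__}"
print("torch", torch.__version__, "| transformers", transformers.__version__)
```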

#### 2) CheXbert Environment
- Install the CheXbert environment with `conda create --name chexbert python=3.7`
- Activate the environment with `conda activate chexbert`
- Move to the chexbert directory with `cd chexbert`
- Install the requirements with `pip install -r requirements.txt`
- Set the absolute path to the chexbert env and folder in `local_config.py` (a hypothetical sketch of such entries follows below)
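
This commit excerpt does not show `local_config.py` itself, so the snippet below is only a hypothetical sketch of the kind of entries the setup steps refer to; the variable names are assumptions, not the project's actual ones.

```python
# local_config.py -- hypothetical sketch; the actual variable names are defined in the repository.
CHEXBERT_ENV_PATH = "/path/to/conda/envs/chexbert"  # absolute path to the chexbert conda env
CHEXBERT_PATH = "/path/to/RaDialog/chexbert"        # absolute path to the chexbert folder
MIMIC_CXR_PATH = "/path/to/mimic-cxr"               # set after downloading MIMIC-CXR (next step)
MIMIC_NLE_PATH = "/path/to/mimic-nle"               # needed later for the instruct dataset (reasoning task)
```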

### Prepare the Data and Models:

#### 1) Download MIMIC-CXR
- Download the MIMIC-CXR dataset from [here](https://physionet.org/content/mimic-cxr/2.0.0/)
- In `local_config.py`, set the path to the MIMIC-CXR dataset
- In `model/lavis/defaults_report.yaml`, set the path to the MIMIC-CXR dataset

#### 2) Create sectioned report data
- Go to the mimic-cxr folder with `cd mimic-cxr`
- Run `python create_section_files.py` to prepare the report data

#### 3) Prepare the instruct dataset

- As MIMIC-CXR requires a credentialed PhysioNet account for access, we cannot publish our instruct dataset directly.
- We are working on publishing the instruct dataset on PhysioNet. In the meantime, you can create the instruct dataset yourself by following the steps below.

- The MIMIC-NLE data has to be generated first, as it also contains protected data. Follow the instructions [here](https://github.com/maximek3/MIMIC-NLE/tree/main) to generate the MIMIC-NLE data and set the path to it in `local_config.py`.
- For the correction task, you can write to us and we will share the incorrect predictions we used.
- To generate data without the correction or reasoning (MIMIC-NLE) tasks, comment out line 335 or 336 of `create_data.py` accordingly.

Data for RaDialog-RG:
- Run `python create_data.py --mode "RG"` to generate the report generation dataset in the required format (no instruct data)

Data for RaDialog-INS:
- Run `python create_data.py --mode "INS"` to generate the instruct dataset

#### 4) Download pretrained models
- Download the pretrained models from [here](TODO) and place them in the `checkpoints` folder

### Run Demo:
- Run `python demo.py --cfg-path configs/blip2_pretrain_stage1_emb.yaml` to start the demo
- Connect to the demo with a browser at `http://127.0.0.1:7860` (check the terminal output for the exact address) and start chatting with RaDialog

### Evaluate RaDialog on MIMIC-CXR test set:
- RaDialog-RG: run `python test.py --prompt img_matching_examples_ig2_noexamples_IMG_findings --use_embs --num_workers 0 --lora_model checkpoints/vicuna-7b-img-report/checkpoint-11200`
- RaDialog-INS: run `python test.py --prompt img_matching_examples_ig2_noexamples_IMG_findings --use_embs --num_workers 0 --lora_model checkpoints/vicuna-7b-img-instruct/checkpoint-4800`

### Train RaDialog:
#### 1) CheXbert Classifier Training
- Run `python -m findings_classifier.train --train --run_name "train_chexbert"`
- Then run `python -m findings_classifier.train --run_name "save_preds"` to save the predictions of the trained model

#### 2) Image Encoder Pretraining
- Run `python -m pretraining.train`

#### 3) LLM Training
Train RaDialog-RG:
- Run `python finetune.py --use_embs True --base_model 'vicuna_v7' --output_dir './lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens' --wandb_run_name lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens --prompt_template_name vicuna_v11 --data_path "data/data_files/mimic_cxr_reports_stratified.json" --cutoff_len 600`

Train RaDialog-INS:
- Run `python finetune.py --use_embs True --base_model 'vicuna_v7' --output_dir './lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens_reversed2' --wandb_run_name lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens_reversed2 --prompt_template_name vicuna_v11 --data_path "data/data_files/instruct_data_stratified.json" --cutoff_len 600`

**TODO:** finalize the training settings (e.g., number of epochs) in the commands above.
Empty file added biovil_t/__init__.py
Empty file.
180 changes: 180 additions & 0 deletions biovil_t/encoder.py
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from __future__ import annotations

from contextlib import contextmanager
from typing import Any, Generator, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from health_multimodal.common.device import get_module_device
from timm.models.layers import trunc_normal_

from .resnet import resnet18, resnet50
from .transformer import VisionTransformerPooler
from .types import ImageEncoderType

DEFAULT_DILATION_VALUES_FOR_RESNET = (False, False, True)
ImageEncoderOutputType = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]


class ImageEncoder(nn.Module):
    """Image encoder trunk module for the ``ImageModel`` class.

    :param img_encoder_type: Type of image encoder model to use, e.g. ``"resnet18"``, ``"resnet50"``,
        ``"resnet18_multi_image"`` or ``"resnet50_multi_image"``.
    """

    def __init__(self, img_encoder_type: str):
        super().__init__()
        self.img_encoder_type = img_encoder_type
        self.encoder = self._create_encoder()

    def _create_encoder(self, **kwargs: Any) -> nn.Module:
        if self.img_encoder_type in [ImageEncoderType.RESNET18, ImageEncoderType.RESNET18_MULTI_IMAGE]:
            encoder_class = resnet18
        elif self.img_encoder_type in [ImageEncoderType.RESNET50, ImageEncoderType.RESNET50_MULTI_IMAGE]:
            encoder_class = resnet50
        else:
            supported = ImageEncoderType.get_members(multi_image_encoders_only=False)
            raise NotImplementedError(f"Image encoder type \"{self.img_encoder_type}\" must be in {supported}")

        encoder = encoder_class(pretrained=True, **kwargs)

        return encoder

    def forward(self,
                current_image: torch.Tensor,
                return_patch_embeddings: bool = False) -> ImageEncoderOutputType:
        """Get image global and patch embeddings"""

        patch_emb = self.encoder(current_image)
        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_emb, (1, 1)), 1)
        if return_patch_embeddings:
            return patch_emb, avg_pooled_emb

        return avg_pooled_emb

    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
        """Workaround for enabling dilated convolutions after model initialization.

        :param replace_stride_with_dilation: Replace the 2x2 standard convolution stride with a dilated convolution
            in each layer in the last three blocks of the ResNet architecture.
        """
        if self.img_encoder_type == ImageEncoderType.RESNET18:
            # resnet18 uses the BasicBlock implementation, which does not support dilated convolutions.
            raise NotImplementedError("resnet18 does not support dilated convolutions")

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = DEFAULT_DILATION_VALUES_FOR_RESNET

        device = next(self.encoder.parameters()).device
        new_encoder = self._create_encoder(replace_stride_with_dilation=replace_stride_with_dilation).to(device)

        if self.encoder.training:
            new_encoder.train()
        else:
            new_encoder.eval()

        new_encoder.load_state_dict(self.encoder.state_dict())
        self.encoder = new_encoder


class MultiImageEncoder(ImageEncoder):
    """Multi-image encoder trunk module for the ``ImageModel`` class.

    It can be used to encode multiple images into a combined latent representation.
    Currently it only supports two input images, but it can be extended to support more in the future.

    :param img_encoder_type: Type of image encoder model to use: either ``"resnet18"`` or ``"resnet50"``.
    """

    def __init__(self, img_encoder_type: str):
        super().__init__(img_encoder_type)

        output_dim = 256  # The aggregate feature dim of the encoder is `2 * output_dim` i.e. [f_static, f_diff]
        grid_shape = (14, 14)  # Spatial dimensions of the patch grid.

        backbone_output_feature_dim = get_encoder_output_dim(self.encoder, device=get_module_device(self))

        self.backbone_to_vit = nn.Conv2d(in_channels=backbone_output_feature_dim, out_channels=output_dim,
                                         kernel_size=1, stride=1, padding=0, bias=False)
        self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=grid_shape)

        # Missing image embedding
        self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
        trunc_normal_(self.missing_previous_emb, std=.02)

    def forward(self,  # type: ignore[override]
                current_image: torch.Tensor,
                previous_image: Optional[torch.Tensor] = None,
                return_patch_embeddings: bool = False) -> ImageEncoderOutputType:

        batch_size = current_image.shape[0]

        if previous_image is not None:
            assert current_image.shape == previous_image.shape
            x = torch.cat([current_image, previous_image], dim=0)
            x = super().forward(x, return_patch_embeddings=True)[0]
            x = self.backbone_to_vit(x)
            patch_x, patch_x_previous = x[:batch_size], x[batch_size:]
            diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_x_previous)
        else:
            x = super().forward(current_image, return_patch_embeddings=True)[0]
            patch_x = self.backbone_to_vit(x)
            B, _, W, H = patch_x.shape
            diff_x = self.missing_previous_emb.repeat(B, 1, W, H)

        patch_fused = torch.cat([patch_x, diff_x], dim=1)
        avg_pooled_emb = torch.flatten(torch.nn.functional.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)

        if return_patch_embeddings:
            return patch_fused, avg_pooled_emb

        return avg_pooled_emb

    def reload_encoder_with_dilation(self, replace_stride_with_dilation: Optional[Sequence[bool]] = None) -> None:
        raise NotImplementedError


@torch.no_grad()
def get_encoder_output_dim(module: torch.nn.Module, device: torch.device) -> int:
    """Calculate the output dimension of an encoder by making a single forward pass.

    :param module: Encoder module.
    :param device: Compute device to use.
    """
    # Target device
    assert isinstance(device, torch.device)

    x = torch.rand((1, 3, 448, 448)).to(device)

    # Extract the number of output feature dimensions
    with restore_training_mode(module):
        module.eval()
        representations = module(x)
    return representations.shape[1]


@contextmanager
def restore_training_mode(module: nn.Module) -> Generator[None, None, None]:
    """Restore the training mode of a module after some operation.

    :param module: PyTorch module.
    """
    training_mode = module.training
    try:
        yield
    finally:
        module.train(mode=training_mode)


def get_encoder_from_type(img_encoder_type: str) -> ImageEncoder:
    """Returns an encoder instance for the given encoder type.

    :param img_encoder_type: Encoder type. {RESNET18, RESNET50, RESNET18_MULTI_IMAGE, RESNET50_MULTI_IMAGE}
    """
    if img_encoder_type in ImageEncoderType.get_members(multi_image_encoders_only=True):
        return MultiImageEncoder(img_encoder_type=img_encoder_type)
    else:
        return ImageEncoder(img_encoder_type=img_encoder_type)
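
A minimal usage sketch for the encoder factory above (illustrative only; it assumes the repository root is on `PYTHONPATH`, that the local `resnet`/`transformer`/`types` modules are importable, and that pretrained ResNet weights can be downloaded):

```python
import torch

from biovil_t.encoder import get_encoder_from_type

# "resnet50" yields a single-image ImageEncoder; "*_multi_image" types yield a MultiImageEncoder.
encoder = get_encoder_from_type(img_encoder_type="resnet50")
encoder.eval()

# The output-dim probe above also uses 448x448 inputs.
dummy = torch.rand(2, 3, 448, 448)
with torch.no_grad():
    patch_emb, global_emb = encoder(dummy, return_patch_embeddings=True)

# For a ResNet-50 trunk this is expected to be roughly (2, 2048, 14, 14) and (2, 2048).
print(patch_emb.shape, global_emb.shape)
```
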
128 changes: 128 additions & 0 deletions biovil_t/model.py
# -------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# -------------------------------------------------------------------------------------------

from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from health_multimodal.common.device import get_module_device

from .encoder import get_encoder_from_type, get_encoder_output_dim, MultiImageEncoder
from .modules import MLP, MultiTaskModel
from .types import ImageModelOutput


class BaseImageModel(nn.Module, ABC):
    """Abstract class for image models."""

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> ImageModelOutput:
        raise NotImplementedError

    @abstractmethod
    def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
        raise NotImplementedError


class ImageModel(BaseImageModel):
    """Image encoder module"""

    def __init__(self,
                 img_encoder_type: str,
                 joint_feature_size: int,
                 freeze_encoder: bool = False,
                 pretrained_model_path: Optional[Union[str, Path]] = None,
                 **downstream_classifier_kwargs: Any):
        super().__init__()

        # Initiate encoder, projector, and classifier
        self.encoder = get_encoder_from_type(img_encoder_type)
        self.feature_size = get_encoder_output_dim(self.encoder, device=get_module_device(self.encoder))
        self.projector = MLP(input_dim=self.feature_size, output_dim=joint_feature_size,
                             hidden_dim=joint_feature_size, use_1x1_convs=True)
        self.downstream_classifier_kwargs = downstream_classifier_kwargs
        self.classifier = self.create_downstream_classifier() if downstream_classifier_kwargs else None

        # Initialise the mode of modules
        self.freeze_encoder = freeze_encoder
        self.train()

        if pretrained_model_path is not None:
            if not isinstance(pretrained_model_path, (str, Path)):
                raise TypeError(f"Expected a string or Path, got {type(pretrained_model_path)}")
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
            # drop projector
            for k in list(state_dict.keys()):
                if k.startswith("projector"):
                    state_dict.pop(k)

            self.load_state_dict(state_dict, strict=False)

    def train(self, mode: bool = True) -> Any:
        """Switch the model between training and evaluation modes."""
        super().train(mode=mode)
        if self.freeze_encoder:
            self.encoder.train(mode=False)
            self.projector.train(mode=False)
        return self

    def forward(self, x: torch.Tensor) -> ImageModelOutput:  # type: ignore[override]
        with torch.set_grad_enabled(not self.freeze_encoder):
            patch_x, pooled_x = self.encoder(x, return_patch_embeddings=True)
        return self.forward_post_encoder(patch_x, pooled_x)

    def forward_post_encoder(self, patch_x: torch.Tensor, pooled_x: torch.Tensor) -> ImageModelOutput:
        with torch.set_grad_enabled(not self.freeze_encoder):
            projected_patch_embeddings = self.projector(patch_x)
            projected_global_embedding = torch.mean(projected_patch_embeddings, dim=(2, 3))

        logits = self.classifier(pooled_x) if self.classifier else None
        return ImageModelOutput(img_embedding=pooled_x,
                                patch_embeddings=patch_x,
                                class_logits=logits,
                                projected_patch_embeddings=projected_patch_embeddings,
                                projected_global_embedding=projected_global_embedding)

    def create_downstream_classifier(self, **kwargs: Any) -> MultiTaskModel:
        """Create the classification module for the downstream task."""
        downstream_classifier_kwargs = kwargs if kwargs else self.downstream_classifier_kwargs
        return MultiTaskModel(self.feature_size, **downstream_classifier_kwargs)

    @torch.no_grad()
    def get_patchwise_projected_embeddings(self, input_img: torch.Tensor, normalize: bool) -> torch.Tensor:
        """Get patch-wise projected embeddings from the CNN model.

        :param input_img: input tensor image [B, C, H, W].
        :param normalize: If ``True``, the embeddings are L2-normalized.
        :returns projected_embeddings: tensor of embeddings in shape [batch, n_patches_h, n_patches_w, feature_size].
        """
        assert not self.training, "This function is only implemented for evaluation mode"
        outputs = self.forward(input_img)
        projected_embeddings = outputs.projected_patch_embeddings.detach()  # type: ignore
        if normalize:
            projected_embeddings = F.normalize(projected_embeddings, dim=1)
        projected_embeddings = projected_embeddings.permute([0, 2, 3, 1])  # B D H W -> B H W D (D: Features)
        return projected_embeddings


class MultiImageModel(ImageModel):
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        assert isinstance(self.encoder, MultiImageEncoder), "MultiImageModel only supports MultiImageEncoder"

    def forward(self,  # type: ignore[override]
                current_image: torch.Tensor,
                previous_image: Optional[torch.Tensor] = None) -> ImageModelOutput:

        with torch.set_grad_enabled(not self.freeze_encoder):
            patch_x, pooled_x = self.encoder(current_image=current_image,
                                             previous_image=previous_image,
                                             return_patch_embeddings=True)
        return self.forward_post_encoder(patch_x, pooled_x)
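
And a matching sketch for the model wrappers above (illustrative only; `joint_feature_size=128` is an assumed value, not one taken from the RaDialog configs, and no pretrained checkpoint is loaded):

```python
import torch

from biovil_t.model import ImageModel, MultiImageModel

# Single-image variant: encoder trunk + MLP projector (no classifier kwargs -> no classifier head).
model = ImageModel(img_encoder_type="resnet50", joint_feature_size=128)
model.eval()
out = model(torch.rand(1, 3, 448, 448))
print(out.projected_global_embedding.shape)  # expected: (1, 128)

# Multi-image variant requires a *_multi_image encoder type; previous_image may be None.
mi_model = MultiImageModel(img_encoder_type="resnet50_multi_image", joint_feature_size=128)
mi_model.eval()
out = mi_model(current_image=torch.rand(1, 3, 448, 448), previous_image=None)
print(out.projected_global_embedding.shape)  # expected: (1, 128)
```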