Commit

- merge everything
ChantalMP committed Nov 24, 2023
1 parent b3417e5 commit f22b000
Showing 24 changed files with 382 additions and 1,129 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -0,0 +1,5 @@
/data/instruct_prompts/instruct_task_correction_preds.json
/chexbert/src/checkpoint/chexbert.pth
/pretraining/outputs/stage1_pt_instruct_blip_origlr_img448/checkpoint_4-pth
/pretraining/embs/stage1_pt_instruct_blip_origlr_img448_embeddings_test.pkl
/pretraining/embs/stage1_pt_instruct_blip_origlr_img448_embeddings_val.pkl
71 changes: 46 additions & 25 deletions README.md
@@ -7,7 +7,7 @@
[nn]:https://www.cs.cit.tum.de/camp/members/cv-nassir-navab/nassir-navab/
[bb]:https://www.cs.cit.tum.de/camp/members/benjamin-busam-1/

## [Paper](https://arxiv.org/abs/2106.02009) | [Demo](https://www.youtube.com/watch?v=8Z3QX6Q4Zq4) | Dataset - Coming Soon
## [Paper](https://arxiv.org/abs/2106.02009) | Dataset - Coming Soon

<img align="right" src="figs/example.png" alt="teaser" width="50%" style="margin-left: 20px">

@@ -16,69 +16,90 @@ Conversational AI tools that can generate and discuss clinically correct radiolo
## Installation

### Environment Setup:

#### 1) RaDialog Environment
- Clone this repository and move to the RaDialog directory with `cd RaDialog`
- Install the RaDialog environment with `conda create --name radialog python=3.7`
- Activate the environment with `conda activate radialog`
- Install the requirements with `pip install -r requirements.txt`
- Reinstall correct versions of torch and transformers with `pip install torch==1.13.1 transformers==4.28.1`
- Install hi-ml-multimodal with `pip install hi-ml-multimodal==0.2.0`
- Reinstall correct versions of torch and transformers with `pip install torch==1.13.0 transformers==4.28.1`

- Install Java and set JAVA_HOME and PATH in local_config.py (we used jre1.8.0)

#### 2) CheXbert Environment
- Install the CheXbert environment with `conda create --name chexbert python=3.7`
- Activate the environment with `conda activate chexbert`
- Move to the chexbert directory with `cd chexbert`
- Install the requirements with `pip install -r requirements.txt`
- Set the absolute path to the chexbert env and folder in `local_config.py`
- Set the absolute path to the chexbert env and folder in `RaDialog/local_config.py`

### Prepare the Data and Models:

#### 1) Download MIMIC-CXR
- Download the MIMIC-CXR dataset from [here](https://physionet.org/content/mimic-cxr/2.0.0/)
- in local_config.py set the path to the MIMIC-CXR dataset
- in model/lavis/defaults_report.yaml set the path to the MIMIC-CXR dataset
#### 1) Download pretrained models
- Download the pretrained models from [here](https://github.com/ChantalMP/RaDialog/releases/tag/models)
- place chexbert.pth in RaDialog/chexbert/src/checkpoint/
- unzip vicuna-7b-img-instruct.zip and vicuna-7b-img-report.zip and place folders into RaDialog/checkpoints/
- unzip chexpert_train and place folder into RaDialog/findings_classifier/checkpoints/
- unzip embs and place folder into RaDialog/pretraining/
- unzip checkpoint_4.pth and place it into outputs/stage1_pt_instruct_blip_origlr_img448/


#### 2) Download MIMIC-CXR
- Download the MIMIC-CXR-JPG dataset from [here](https://www.physionet.org/content/mimic-cxr-jpg/2.0.0/)
- The dataset should be saved in .../physionet.org/files/mimic-cxr-jpg
- Go to physionet.org/files/mimic-cxr-jpg/files/ and unzip mimic-cxr-2.0.0-split.csv.gz

#### 2) Create sectioned report data
- go to the mimic-cxr folder with `cd mimic-cxr`
- from [here](https://physionet.org/content/mimic-cxr/2.0.0/), download mimic-cxr-reports.zip
- unzip it and place the folder in the same directory as the MIMIC-CXR-JPG dataset (e.g. physionet.org/files/)

- in local_config.py set the path to the MIMIC-CXR dataset (e.g. .../physionet.org/files/)
- in model/lavis/defaults_report.yaml set the path to the MIMIC-CXR-JPG dataset (e.g. .../physionet.org/files/mimic-cxr-jpg/2.0.0 )
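
The settings referenced above (Java, the CheXbert environment, and the MIMIC-CXR paths) all live in `local_config.py`. A minimal sketch is shown below; `JAVA_HOME`, `JAVA_PATH`, `PATH_TO_MIMIC_CXR` and `VIS_ROOT` are names imported elsewhere in the repository, while the remaining variable names and all paths are placeholders you need to adapt to your setup.

```python
# local_config.py -- minimal sketch, adapt every path to your machine.
# JAVA_HOME, JAVA_PATH, PATH_TO_MIMIC_CXR and VIS_ROOT are imported by the code;
# the other names are placeholders for the settings described in this README.

JAVA_HOME = "/home/<user>/java/jre1.8.0_361"
JAVA_PATH = JAVA_HOME + "/bin:"  # prepended to os.environ["PATH"] in demo.py

PATH_TO_MIMIC_CXR = "/data/physionet.org/files/"             # root containing mimic-cxr-jpg/
VIS_ROOT = PATH_TO_MIMIC_CXR + "mimic-cxr-jpg/2.0.0"         # image root used by the dataloaders

CHEXBERT_ENV_PATH = "/home/<user>/miniconda3/envs/chexbert"  # placeholder name
CHEXBERT_PATH = "/home/<user>/RaDialog/chexbert"             # placeholder name
MIMIC_NLE_PATH = "/data/mimic-nle"                           # placeholder name, only needed for the instruct data
```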

#### 3) Create sectioned report data
- go to the mimic-cxr folder in the code with `cd mimic-cxr`
- run `python create_section_files.py` to prepare the report data
- go back to the RaDialog directory with `cd ..`

#### 3) Prepare the instruct dataset
#### 4) Prepare the instruct dataset

- As MIMIC-CXR needs a certified PhysioNet account to be accessed, we cannot publish our instruct dataset directly.
- We are working on publishing the instruct dataset on PhysioNet. In the meantime, you can create an instruct dataset yourself by following the steps below.
- We are working on publishing the instruct dataset on PhysioNet. In the meantime, you can create an instruct dataset yourself by following the steps below or just use our pre-trained model.

- The MIMIC-NLE data has to be generated first, as it also contains protected data. Follow the instructions [here](https://github.com/maximek3/MIMIC-NLE/tree/main) to generate the MIMIC-NLE data and set the path to the MIMIC-NLE data in `local_config.py`.
- For the correction task, you can write to us and we can share the incorrect predictions we used.
- To generate data without the correction or reasoning (MIMIC-NLE) tasks, comment out line 335 or 336 in `create_data.py` accordingly.

Data for RaDialog-RG:
- run `python create_data.py --mode "RG"` to generate the report generation dataset in the required format (no instruct data)
- run `python -m data.create_data --mode "RG"` to generate the report generation dataset in the required format (no instruct data)

Data for RaDialog-INS:
- run `python create_data.py --mode "INS"` to generate the instruct dataset
- run `python -m data.create_data --mode "INS"` to generate the instruct dataset

4) Download pretrained models
- Download the pretrained models from [here](TODO) and place them in the checkpoints folder

### Run Demo:
- run `python demo.py --cfg-path configs/blip2_pretrain_stage1_emb.yaml` to start the demo
- connect to the demo with a browser at `http://127.0.0.1:7860` (check terminal for address) and start chatting with RaDialog
- run `python demo.py --cfg-path pretraining/configs/blip2_pretrain_stage1_emb.yaml` to start the demo
- connect to the demo with a browser at `http://127.0.0.1:7860` and start chatting with RaDialog
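
When only an image is given, the demo predicts CheXpert findings for it and builds a report-generation prompt around 32 `<IMG>` placeholder tokens. The sketch below is a simplified excerpt of the prompt construction in `get_response` (demo.py), shown for orientation only.

```python
# Sketch of the report-generation prompt assembled in demo.py's get_response.
def build_report_prompt(findings: str, num_img_tokens: int = 32) -> str:
    img_tokens = "<IMG>" * num_img_tokens  # placeholders later filled with image embeddings
    return (
        f"Image information: {img_tokens}. Predicted Findings: {findings}. "
        "You are to act as a radiologist and write the finding section of a chest x-ray "
        "radiology report for this X-ray image and the given predicted findings. "
        "Write in the style of a radiologist, write one fluent text without enumeration, "
        "be concise and don't provide explanations or reasons."
    )

# e.g. build_report_prompt("cardiomegaly, pleural effusion")
```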

### Evaluate RaDialog on MIMIC-CXR test set:
- RaDialog-RG: run `python test.py --prompt img_matching_examples_ig2_noexamples_IMG_findings --use_embs --num_workers 0 --lora_model checkpoints/vicuna-7b-img-report/checkpoint-11200`
- RaDialog-INS: run `python test.py --prompt img_matching_examples_ig2_noexamples_IMG_findings --use_embs --num_workers 0 --lora_model checkpoints/vicuna-7b-img-instruct/checkpoint-4800`
- RaDialog-INS (correction): run `python test.py --prompt img_matching_examples_ig2_noexamples_IMG_findings --use_embs --num_workers 0 --lora_model checkpoints/vicuna-7b-img-instruct/checkpoint-4800 --do_corr`
- RaDialog-INS (findings QA): run `python test.py --prompt img_matching_examples_ig2_noexamples_IMG_findings --use_embs --num_workers 0 --lora_model checkpoints/vicuna-7b-img-instruct/checkpoint-4800 --do_cp_all_qa` (or `--do_cp_bin_qa`)

### Train RaDialog:
#### 1) CheXbert classifier Training
- run `python -m findings_classifier.train --train --run_name "train_chexbert" `
- then run `python -m findings_classifier.train --run_name "save_preds" ` to save the predictions of the trained model
- run `python -m findings_classifier.chexpert_train --train --run_name "train_chexbert"`
- in chexpert_train.py, set ckpt_path (line 152) to the checkpoint path of the model you just trained
- then run `python -m findings_classifier.chexpert_train --run_name "save_preds"` to save the predictions of the trained model
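
The saved classifier is later loaded for inference the same way `demo.py` does it in `init_chexpert_predictor`; the sketch below mirrors that code and uses the checkpoint name from the pretrained release as an example — point `ckpt_path` at your own checkpoint.

```python
# Sketch of CheXpert-classifier inference, following init_chexpert_predictor in demo.py.
import numpy as np
import torch
from torchvision.transforms import CenterCrop, Compose, Resize, ToTensor

from chexpert_train import LitIGClassifier                    # imported as in demo.py
from model.lavis.data.ReportDataset import ExpandChannels

ckpt_path = "findings_classifier/checkpoints/chexpert_train/ChexpertClassifier-epoch=06-val_f1=0.36.ckpt"
class_names = ["No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity",
               "Lung Lesion", "Edema", "Consolidation", "Pneumonia", "Atelectasis",
               "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture", "Support Devices"]

model = LitIGClassifier.load_from_checkpoint(ckpt_path, num_classes=14,
                                             class_names=class_names, strict=False)
model.eval().cuda().half()
cp_transforms = Compose([Resize(512), CenterCrop(488), ToTensor(), ExpandChannels()])

# image = <PIL image loaded as in demo.py's load_image>
# logits = model(cp_transforms(image)[None].half().cuda())
# findings = np.asarray(class_names)[(torch.sigmoid(logits) > 0.5)[0].cpu().numpy()]
```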

#### 2) Image Encoder Pretraining
- run `python -m pretraining.train`
#### 2) Alignment Module Pretraining
- run `python -m pretraining.train --cfg-path pretraining/configs/blip2_pretrain_stage1.yaml`; we used the checkpoint from the 4th epoch
- run `python -m pretraining.train --cfg-path pretraining/configs/blip2_pretrain_stage1_emb.yaml` to save the embeddings of the trained model

#### 3) LLM Training
Train RaDialog-RG:
- run `python finetune.py --use_embs True --base_model 'vicuna_v7' --output_dir './lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens' --wandb_run_name lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens --prompt_template_name vicuna_v11 --data_path "data/data_files/mimic_cxr_reports_stratified.json" --cutoff_len 600`
- run `python finetune.py --use_embs True --base_model 'vicuna_v7' --output_dir 'checkpoints/lora-vicuna-7b-report' --wandb_run_name lora-vicuna-7b-report --prompt_template_name vicuna_v11 --data_path "data/data_files/mimic_cxr_reports_stratified.json" --cutoff_len 600 --num_epochs 4`

Train RaDialog-INS:
- run `python finetune.py --use_embs True --base_model 'vicuna_v7' --output_dir './lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens_reversed2' --wandb_run_name lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens_reversed2 --prompt_template_name vicuna_v11 --data_path "data/data_files/instruct_data_stratified.json" --cutoff_len 600`

# TODO fix all epochs etc etc
- run `python finetune.py --use_embs True --base_model 'vicuna_v7' --output_dir 'checkpoints/lora-vicuna-7b-instruct' --wandb_run_name lora-vicuna-7b-instruct --prompt_template_name vicuna_v11 --data_path "data/data_files/mimic_cxr_instruct_stratified.json" --cutoff_len 800 --num_epochs 1`
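
The LoRA adapters written by `finetune.py` are attached to the base Vicuna model at inference time the same way `demo.py` does it in `init_vicuna`. A rough sketch follows; the base-model and tokenizer paths are placeholders and the checkpoint step depends on your run.

```python
# Sketch: attaching a trained LoRA adapter, following init_vicuna in demo.py.
# "checkpoints/vicuna-7b" is a placeholder for the base Vicuna-7B weights;
# the adapter path/step below is an example from the RaDialog-INS run above.
import torch
from peft import PeftModelForCausalLM
from transformers import LlamaTokenizer
from model.lavis.models.blip2_models.modeling_llama_imgemb import LlamaForCausalLM

base_model = LlamaForCausalLM.from_pretrained("checkpoints/vicuna-7b", torch_dtype=torch.float16)
tokenizer = LlamaTokenizer.from_pretrained("checkpoints/vicuna-7b")
tokenizer.add_special_tokens({"additional_special_tokens": ["<IMG>"]})  # as in demo.py

lang_model = PeftModelForCausalLM.from_pretrained(
    base_model,
    "checkpoints/lora-vicuna-7b-instruct/checkpoint-4800",  # or checkpoints/lora-vicuna-7b-report/...
    torch_dtype=torch.float16,
    use_ram_optimized_load=False,
).half()
lang_model.eval()
```
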
4 changes: 2 additions & 2 deletions chexbert/requirements.txt
@@ -2,7 +2,7 @@ numpy==1.18.2
pandas==0.25.3
scikit-learn==0.22.1
tokenizers==0.5.2
torch==1.4.0
torch==1.13.1
tqdm==4.38.0
transformers==2.5.1
statsmodels
statsmodels==0.13.5
6 changes: 3 additions & 3 deletions data/create_data.py
@@ -16,9 +16,9 @@
from transformers import AutoTokenizer
from torch.utils.data.sampler import Sampler

from instruct_tasks import create_direct_task_data, create_cp_task_data, create_correction_task_data, create_nle_task_data
from data.instruct_tasks import create_direct_task_data, create_cp_task_data, create_correction_task_data, create_nle_task_data
from local_config import VIS_ROOT, PATH_TO_MIMIC_CXR
from model.modeling_llama_imgemb import LlamaForCausalLM
from model.lavis.models.blip2_models.modeling_llama_imgemb import LlamaForCausalLM


class MyReportProcessor():
@@ -408,7 +408,7 @@ def fuse_instruct_dataset(prompt_type="img_matching_examples_ig2_noexamples_IMG_
random.shuffle(combined_jsons)

# save to json
with open(f"instruct_prompts_{prompt_type}_stratified.json", "w") as f:
with open(f"data/data_files/mimic_cxr_instruct_stratified.json", "w") as f:
json.dump(combined_jsons, f, indent=4)


3 changes: 0 additions & 3 deletions data/instruct_tasks.py
@@ -317,9 +317,6 @@ def create_nle_task_data():
obj = json.loads(line)
mimic_nle.append(obj)

# truncate to 10 samples
mimic_nle = mimic_nle[:10]

prompts = pd.read_csv(f"data/instruct_prompts/RE_prompts.csv")["instruction"].tolist()
report_jsons = []
reports = pd.read_csv('mimic-cxr/reports_processed/mimic_cxr_sectioned.csv')
101 changes: 66 additions & 35 deletions demo.py
@@ -1,11 +1,14 @@
import argparse
import os
import random
import numpy as np
import torch
from torch.backends import cudnn

from local_config import PATH_TO_MIMIC_CXR
from chexpert_train import LitIGClassifier
from local_config import JAVA_HOME, JAVA_PATH

# Fix random seeds for a deterministic demo; comment out otherwise
SEED = 16
random.seed(SEED)
np.random.seed(SEED)
@@ -14,8 +17,8 @@
cudnn.deterministic = True

# set java path
os.environ["JAVA_HOME"] = "/home/guests/chantal_pellegrini/java/jre1.8.0_361"
os.environ["PATH"] = "/home/guests/chantal_pellegrini/java/jre1.8.0_361/bin:" + os.environ["PATH"]
os.environ["JAVA_HOME"] = JAVA_HOME
os.environ["PATH"] = JAVA_PATH + os.environ["PATH"]
os.environ['GRADIO_TEMP_DIR'] = os.path.join(os.getcwd(), "gradio_tmp")

import dataclasses
@@ -24,20 +27,38 @@
from enum import auto, Enum
from typing import List, Any


import gradio as gr
from PIL import Image
from peft import PeftModelForCausalLM
from skimage import io
from torch import nn
from transformers import LlamaTokenizer
from torchvision.transforms import Compose, Resize, ToTensor, CenterCrop, transforms

from model.lavis import tasks
from model.lavis.common.config import Config
from model.lavis.data.ReportDataset import create_chest_xray_transform_for_inference, MIMIC_CXR_Dataset
from model.lavis.data.ReportDataset import create_chest_xray_transform_for_inference, ExpandChannels
from model.lavis.models.blip2_models.modeling_llama_imgemb import LlamaForCausalLM
from train import parse_args


def parse_args():
parser = argparse.ArgumentParser(description="Training")

parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training.")
parser.add_argument(
"--options",
nargs="+",
help="override some settings in the used config, the key-value pair "
"in xxx=yyy format will be merged into config file (deprecate), "
"change to --cfg-options instead.",
)

args = parser.parse_args()

return args

class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
@@ -123,15 +144,31 @@ def dict(self):
vis_transforms = create_chest_xray_transform_for_inference(512, center_crop_size=448)
use_img = False
gen_report = True
pred_chexpert_labels = json.load(open('chexbert/chexbert_data/structured_preds_chexpert_log_weighting_test_macro_dicom.json', 'r'))

pred_chexpert_labels = json.load(open('findings_classifier/predictions/structured_preds_chexpert_log_weighting_test_macro.json', 'r'))

def init_blip(cfg):
task = tasks.setup_task(cfg)
model = task.build_model(cfg)
model = model.to(torch.device('cpu'))
return model

def init_chexpert_predictor():
ckpt_path = f"findings_classifier/checkpoints/chexpert_train/ChexpertClassifier-epoch=06-val_f1=0.36.ckpt"
chexpert_cols = ["No Finding", "Enlarged Cardiomediastinum",
"Cardiomegaly", "Lung Opacity",
"Lung Lesion", "Edema",
"Consolidation", "Pneumonia",
"Atelectasis", "Pneumothorax",
"Pleural Effusion", "Pleural Other",
"Fracture", "Support Devices"]
model = LitIGClassifier.load_from_checkpoint(ckpt_path, num_classes=14, class_names=chexpert_cols, strict=False)
model.eval()
model.cuda()
model.half()
cp_transforms = Compose([Resize(512), CenterCrop(488), ToTensor(), ExpandChannels()])

return model, np.asarray(model.class_names), cp_transforms


def remap_to_uint8(array: np.ndarray, percentiles=None) -> np.ndarray:
"""Remap values in input so the output range is :math:`[0, 255]`.
@@ -193,31 +230,36 @@ def init_vicuna():
vicuna_tokenizer.add_special_tokens({"additional_special_tokens": ["<IMG>"]})

lang_model = PeftModelForCausalLM.from_pretrained(lang_model,
f"checkpoints/lora-cxr-vicuna-specific-7b-noexamples-imgemb-instruct-cp-rightpadding-stratified_32imgtokens_800tokens/checkpoint-4800",
f"checkpoints/vicuna-7b-img-instruct/checkpoint-4800",
torch_dtype=torch.float16, use_ram_optimized_load=False).half()
# lang_model = PeftModelForCausalLM.from_pretrained(lang_model, f"{LORA_ADAPT_PATH}/lora-cxr-vicuna-specific-7b-noexamples-imgemb-findings-rightpadding-stratified_32imgtokens_600tokens/checkpoint-11200", torch_dtype=torch.float16, use_ram_optimized_load=False).half()
# lang_model = PeftModelForCausalLM.from_pretrained(lang_model, f"checkpoints/vicuna-7b-img-report/checkpoint-11200", torch_dtype=torch.float16, use_ram_optimized_load=False).half()
return lang_model, vicuna_tokenizer

blip_model = init_blip(cfg)
lang_model, vicuna_tokenizer = init_vicuna()
blip_model.eval()
lang_model.eval()

cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()

def get_response(input_text, dicom):
global use_img, blip_model, lang_model, vicuna_tokenizer

if input_text == "": # only dicom was given
findings = ', '.join(pred_chexpert_labels[dicom]).lower().strip()
if gen_report:
input_text = (
f"Image information: <IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. "
"Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons.")

elif input_text[-1].endswith(".png") or input_text[-1].endswith(".jpg"):
if input_text[-1].endswith(".png") or input_text[-1].endswith(".jpg"):
image = load_image(input_text[-1])
cp_image = cp_transforms(image)
image = vis_transforms(image)
dicom = input_text[-1].split('/')[-1].split('.')[0]
findings = ', '.join(pred_chexpert_labels[dicom]).lower().strip()
if dicom in pred_chexpert_labels:
findings = ', '.join(pred_chexpert_labels[dicom]).lower().strip()
else:
logits = cp_model(cp_image[None].half().cuda())
preds_probs = torch.sigmoid(logits)
preds = preds_probs > 0.5
pred = preds[0].cpu().numpy()
findings = cp_class_names[pred].tolist()
findings = ', '.join(findings).lower().strip()

if gen_report:
input_text = (
f"Image information: <IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG><IMG>. Predicted Findings: {findings}. You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. "
@@ -234,6 +276,7 @@ def get_response(input_text, dicom):

else: # free chat
input_text = input_text
findings = None

'''Generate prompt given input prompt'''
conv.append_message(conv.roles[0], input_text)
@@ -259,7 +302,7 @@ def get_response(input_text, dicom):
# remove last message in conv
conv.messages.pop()
conv.append_message(conv.roles[1], new_pred)
return new_pred
return new_pred, findings


'''Conversation template for prompt'''
@@ -303,25 +346,16 @@ def clear_history(button_name):
return [] # Return empty history to the Chatbot


def clear_qa(button_name):
global chat_history, use_img, conv
conv.messages = conv.messages[:2]
return [[conv.messages[0][1], conv.messages[1][1]]]


def bot(history):
# You can now access the global `dicom` variable here if needed
if history == []:
global dicom
response = get_response("", dicom)

else:
response = get_response(history[-1][0], None)
print(response)
response, findings = get_response(history[-1][0], None)
print(response)

# show report generation prompt if first message after image
if len(history) == 1:
input_text = f"You are to act as a radiologist and write the finding section of a chest x-ray radiology report for this X-ray image and the given predicted findings. Write in the style of a radiologist, write one fluent text without enumeration, be concise and don't provide explanations or reasons."
if findings is not None:
input_text = f"Image information: (img_tokens) Predicted Findings: {findings}. {input_text}"
history.append([input_text, None])

history[-1][1] = ""
@@ -333,9 +367,6 @@ def bot(history):


if __name__ == '__main__':
# setup_seeds(42)
mimic_dataset = MIMIC_CXR_Dataset(vis_processor=None, text_processor=None, vis_root=f"{PATH_TO_MIMIC_CXR}/mimic-cxr-jpg/2.0.0", split="test", cfg=cfg, truncate=None)
a = 1
with gr.Blocks() as demo:

