chaudatascience/diverse_channel_vit


A PyTorch implementation of Diverse Channel ViT (DiChaViT) from our paper. This code was tested with PyTorch 2.4.1+cu121 and Python 3.10.

If you find our work useful, please consider citing:

@InProceedings{phamDiChaVit2024,
  author    = {Chau Pham and Bryan A. Plummer},
  title     = {Enhancing Feature Diversity Boosts Channel-Adaptive Vision Transformers},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year      = {2024}
}


Setup

Install required packages:

conda create -n dichavit python=3.10 -y
conda activate dichavit
pip install -r requirements.txt
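
After installing, you can quickly confirm the environment matches the versions this code was tested with (PyTorch 2.4.1+cu121, Python 3.10):

# check_env.py -- confirm the installed versions match the tested setup
import sys
import torch

print("python:", sys.version.split()[0])       # tested with 3.10
print("torch:", torch.__version__)             # tested with 2.4.1+cu121
print("cuda available:", torch.cuda.is_available())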

Dataset

After downloading the following datasets, update the dataset paths in the config files under configs/dataset/.
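
A quick way to verify the edited paths is to load the config and print it. This is a minimal sketch; the file name morphem70k_v2_12channels.yaml is only assumed from the dataset override shown in the Training section, and OmegaConf is the config library that ships with Hydra:

# check_dataset_cfg.py -- print a dataset config after editing its paths (sanity-check sketch)
from omegaconf import OmegaConf

# file name assumed from the dataset=morphem70k_v2_12channels override used under Training
cfg = OmegaConf.load("configs/dataset/morphem70k_v2_12channels.yaml")
print(OmegaConf.to_yaml(cfg))  # eyeball the path fields you just edited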

1. CHAMMI

1.1. Download

The dataset can be downloaded from https://doi.org/10.5281/zenodo.7988357

1.2. Install evaluation package

To run the evaluation, install the MorphEm evaluation package: https://github.com/broadinstitute/MorphEm

More detail about the dataset can be found here.

2. JUMP-CP

You can refer to insitro's dataset repo for further details. Here's a quick overview to help you get started.

The processed data is stored in an S3 bucket as follows:

s3://insitro-research-2023-context-vit
└── jumpcp/
    ├──  platemap_and_metadata/
    ├──  BR00116991/
    │    ├── BR00116991_A01_1_12.npy
    │    ├── BR00116991_A01_1_13.npy
    │    └── ...
    ├──  BR00116993/
    ├──  BR00117000/
    ├──  BR00116991.pq
    ├──  BR00116993.pq
    └──  BR00117000.pq

We conduct experiments on BR00116991, which requires downloading the platemap_and_metadata/ and BR00116991/ folders and the BR00116991.pq file. First install the AWS CLI, then run these commands in a terminal:

aws s3 cp s3://insitro-research-2023-context-vit/jumpcp/platemap_and_metadata jumpcp/platemap_and_metadata --recursive --no-sign-request
aws s3 cp s3://insitro-research-2023-context-vit/jumpcp/BR00116991 jumpcp/BR00116991 --recursive --no-sign-request
aws s3 cp s3://insitro-research-2023-context-vit/jumpcp/BR00116991.pq jumpcp/BR00116991.pq --no-sign-request
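
After the download finishes, a quick look at one image file and the plate metadata helps confirm everything is in place. This is only a sanity-check sketch: the array layout and the parquet schema are assumptions, not documented here.

# inspect_jumpcp.py -- peek at one field of view and the plate metadata (sanity-check sketch)
import numpy as np
import pandas as pd  # reading .pq requires pyarrow or fastparquet

# a single .npy file is assumed to hold the stacked channels for one field of view
img = np.load("jumpcp/BR00116991/BR00116991_A01_1_12.npy")
print("image array:", img.shape, img.dtype)

# BR00116991.pq is assumed to be a parquet table with per-image metadata
meta = pd.read_parquet("jumpcp/BR00116991.pq")
print("metadata rows:", len(meta))
print("columns:", meta.columns.tolist())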

3. So2Sat

We use the city split (version 1) of the So2Sat dataset. For example, the validation split can be downloaded by running

wget --no-check-certificate "https://dataserv.ub.tum.de/s/m1454690/download?path=%2F&files=validation.h5&downloadStartSecret=p5bjok57fil"

For more details, refer to the So2Sat-LCZ42 repo.
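
Once downloaded, you can peek inside the HDF5 file with h5py to confirm its contents. The dataset names in the comment (sen1, sen2, label) follow the public So2Sat-LCZ42 release, but treat them as an assumption and check the printed keys:

# inspect_so2sat.py -- list the datasets stored in validation.h5 (sanity-check sketch)
import h5py

with h5py.File("validation.h5", "r") as f:
    print(list(f.keys()))  # expected, but not guaranteed: ['label', 'sen1', 'sen2']
    for name, dset in f.items():
        print(name, dset.shape, dset.dtype)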

Training

In this project, we use Hydra to manage configurations. To submit a job using Hydra, you need to specify the config file. Here are some key parameters:

  • -m: multi-run mode (submit multiple runs with one job)
  • -cp: config folder; all config files are in configs/
  • -cn: config file name (without the .yaml extension)

Parameters on the command line override those in the config file. For example, to train DiChaViT on the CHAMMI dataset:

python main.py -m -cn chammi_cfg model=dichavit ++model.enable_sample=True ++model.pretrained_model_name=small tag=test_demo dataset=morphem70k_v2_12channels ++optimizer.params.lr=0.00004 ++model.temperature=0.07 ++train.num_epochs=10 ++train.batch_size=64 ++model.new_channel_inits=[zero] ++logging.wandb.use_wandb=False ++eval.skip_eval_first_epoch=True
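
If you are new to Hydra, the standalone toy script below illustrates how such ++key=value overrides get merged into the loaded config; it mirrors the call above but is not this repository's main.py, and the config_path/config_name values are only assumptions taken from that command:

# hydra_override_demo.py -- toy illustration of Hydra command-line overrides (not this repo's main.py)
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="configs", config_name="chammi_cfg")
def main(cfg: DictConfig) -> None:
    # anything passed as key=value or ++key=value on the command line
    # (e.g. ++optimizer.params.lr=0.00004) overrides the values from the YAML files
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()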

To reproduce the results, please refer to train_scripts.sh.

Add a Wandb key: If you would like to use Wandb to track experiments, add your Wandb API key to the .env file:

echo WANDB_API_KEY=your_wandb_key >> .env

and change use_wandb to True in configs/logging/wandb.yaml, or set ++logging.wandb.use_wandb=True on the command line.
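
To double-check that the key is picked up before launching a long run, a small check like the one below works, assuming the repo reads .env with python-dotenv (an assumption about the setup):

# check_wandb_key.py -- confirm WANDB_API_KEY from .env is visible (sanity-check sketch)
import os
from dotenv import load_dotenv  # assumes python-dotenv is installed

load_dotenv()  # reads the .env file in the current directory
print("WANDB_API_KEY set:", bool(os.getenv("WANDB_API_KEY")))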

Checkpoints

The model checkpoints for DiChaViT can be found here.

Acknowledgements

  • The ChannelViT model and the dataloaders for So2Sat and JUMP-CP are adapted from ChannelViT.
  • CHAMMI's baseline models, dataloader, and evaluation benchmark are from CHAMMI and MorphEm.
