In this repo, we provide the official implementation of the following paper:
[NeurIPS2024] "Wasserstein Distance Rivals Kullback-Leibler Divergence for Knowledge Distillation" [Project] [Paper].
In this paper, we propose a novel methodology of Wasserstein distance based knowledge distillation (WKD), extending beyond the classical Kullback-Leibler divergence based one pioneered by Hinton et al. Specifically,

- We present a discrete WD based logit distillation method (WKD-L). It can leverage rich interrelations among classes via cross-category comparisons between predicted probabilities of the teacher and student, overcoming the downside of category-to-category KL divergence.
- We introduce continuous WD into intermediate layers for feature distillation (WKD-F). It can effectively leverage the geometric structure of the Riemannian space of Gaussians, unlike the geometry-unaware KL divergence.
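For intuition, here is a minimal PyTorch sketch of the two distances involved. It is not the paper's exact training losses, and all function names are ours: an entropic-regularized discrete WD between class-probability vectors (the quantity behind WKD-L) and the closed-form squared 2-Wasserstein distance between Gaussians (the quantity behind WKD-F).

```python
import torch

def sinkhorn_wd(p, q, cost, eps=0.1, n_iters=50):
    """Entropic-regularized discrete Wasserstein distance between two
    probability vectors p, q of shape [K], given a KxK class cost matrix."""
    ker = torch.exp(-cost / eps)                 # Gibbs kernel
    u = torch.ones_like(p)
    for _ in range(n_iters):                     # Sinkhorn fixed-point updates
        v = q / (ker.t() @ u)
        u = p / (ker @ v)
    plan = u.unsqueeze(1) * ker * v.unsqueeze(0) # optimal transport plan
    return (plan * cost).sum()

def gaussian_w2_sq(mu1, cov1, mu2, cov2):
    """Closed-form squared 2-Wasserstein distance between two Gaussians:
    ||mu1 - mu2||^2 + tr(cov1 + cov2 - 2 (cov2^{1/2} cov1 cov2^{1/2})^{1/2})."""
    def psd_sqrt(m):
        vals, vecs = torch.linalg.eigh(m)        # square root of a symmetric PSD matrix
        return vecs @ torch.diag(vals.clamp(min=0).sqrt()) @ vecs.t()
    s2 = psd_sqrt(cov2)
    cross = psd_sqrt(s2 @ cov1 @ s2)
    return (mu1 - mu2).pow(2).sum() + torch.trace(cov1 + cov2 - 2.0 * cross)
```

In this picture, `p` and `q` would come from the softened softmax outputs of teacher and student, while the Gaussians would be fitted to intermediate feature maps; see the paper for the actual formulations.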
We hope our work can shed light on the promise of WD and inspire further interest in this metric for knowledge distillation.
If this repo is helpful for your research, please consider citing the paper:
```
@inproceedings{WKD_NeurIPS2024,
  title={Wasserstein Distance Rivals Kullback-Leibler Divergence for Knowledge Distillation},
  author={Jiaming Lv and Haoyuan Yang and Peihua Li},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}
```
We evaluate WKD for image classification on ImageNet and CIFAR-100, following the settings of CRD.
On ImageNet:

| Method | ResNet34 -> ResNet18 (Top-1) | Model | Log | ResNet50 -> MobileNetV1 (Top-1) | Model | Log |
|---|---|---|---|---|---|---|
| WKD-L | 72.49 | Model | Log | 73.17 | Model | Log |
| WKD-F | 72.50 | Model | Log | 73.12 | Model | Log |
| WKD-L+WKD-F | 72.76 | Model | Log | 73.69 | Model | Log |
On CIFAR-100:

| Method | RN32x4 -> RN8x4 (Top-1) | Model | Log | WRN40-2 -> SNV1 (Top-1) | Model | Log |
|---|---|---|---|---|---|---|
| WKD-L | 76.53 | Model | Log | 76.72 | Model | Log |
| WKD-F | 76.77 | Model | Log | 77.36 | Model | Log |
| WKD-L+WKD-F | 77.28 | Model | Log | 77.50 | Model | Log |
Note that the test accuracy may vary slightly with different PyTorch/CUDA versions, GPUs, etc. All our experiments were conducted on a PC with an Intel Core i9-13900K CPU and GeForce RTX 4090 GPUs.
We recommend using the PyTorch NGC Container environment provided by NVIDIA to reproduce our method. This environment contains:
- CUDA 12.1
- Python 3.8
- PyTorch 2.1.0
- torchvision 0.16.0
First, please make sure that Docker Engine is installed on your device by referring to this installation guide. Then, use the following command to pull the docker image:
```
docker pull nvcr.io/nvidia/pytorch:23.04-py3
```
Run the docker image and enter the container with the following command:
```
docker run --name wkd -dit --gpus=all nvcr.io/nvidia/pytorch:23.04-py3 /bin/bash
```
Note: If you are unable to use Docker, you can try manually installing the PyTorch environment. It is recommended to keep the environment consistent with the one in the NGC Container.
Clone this repo inside the NGC Container:
```
git clone https://github.com/JiamingLv/WKD.git
cd WKD
```
Then install the extra packages:

```
pip install -r requirements.txt
python setup.py develop
```
- Generate the cost matrix offline for WKD-L

  We have provided the cost matrix in `wkd_cost_matrix`. If you want to regenerate it, run `generate_cost_matrix.sh`:

  ```
  sh generate_cost_matrix.sh
  ```
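  For illustration only, one plausible way to build such a cost matrix is from cosine distances between per-class prototype vectors (e.g., a teacher's classifier weights); the repo's actual procedure is whatever `generate_cost_matrix.sh` runs, and the function below is a hypothetical sketch.

  ```python
  import torch

  def cost_from_prototypes(prototypes):
      """prototypes: [K, D] tensor with one row per class.
      Returns a KxK cost matrix (cosine distance, zero diagonal)."""
      z = torch.nn.functional.normalize(prototypes, dim=1)
      cost = 1.0 - z @ z.t()      # dissimilar classes are expensive to transport between
      cost.fill_diagonal_(0.0)    # moving mass within the same class is free
      return cost
  ```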
- Training on ImageNet

  Put the ImageNet dataset in the default path `data/imagenet`. You can change this path by modifying the variable `data_folder` in `mdistiller/dataset/imagenet.py`, as shown in the sketch below.

  ```
  # for instance, our WKD-L method.
  python3 tools/train.py --cfg configs/imagenet/r34_r18/wkd_l.yaml
  ```
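  A minimal illustration of that edit (hypothetical snippet; the actual code around the variable may differ):

  ```python
  # mdistiller/dataset/imagenet.py -- illustrative only
  data_folder = "data/imagenet"  # point this at the actual location of your ImageNet folder
  ```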
- Training on CIFAR-100

  Download the `cifar_teachers.tar` provided by DKD and untar it to `./download_ckpts` via `tar xvf cifar_teachers.tar`.

  ```
  # for instance, our WKD-L method.
  python3 tools/train.py --cfg configs/cifar100/wkd_l/res32x4_res8x4.yaml
  ```
- Thanks to DKD. We build this repo based on the mdistiller codebase provided by DKD.
- We also thank OFA, NKD, CRD, and ReviewKD for their codebases.
- Thanks also go to the authors of other papers who make their code publicly available.
If you have any questions or suggestions, please contact us:
- Jiaming Lv ([email protected])
- Haoyuan Yang ([email protected])