Wasserstein Distance Rivals Kullback-Leibler Divergence for Knowledge Distillation

In this repo, we provide the official implementation of this paper:

[NeurIPS2024] "Wasserstein Distance Rivals Kullback-Leibler Divergence for Knowledge Distillation" [Project] [Paper].

Introduction

In this paper, we propose a novel methodology of Wasserstein distance (WD) based knowledge distillation (WKD), going beyond the classical Kullback-Leibler (KL) divergence based one pioneered by Hinton et al. Specifically,

  • We present a discrete WD based logit distillation method (WKD-L). It leverages rich interrelations among classes via cross-category comparisons between the predicted probabilities of the teacher and student, overcoming the downside of category-to-category comparison in KL divergence.

  • We introduce continuous WD into intermediate layers for feature distillation (WKD-F). It effectively leverages the geometric structure of the Riemannian space of Gaussians, unlike geometry-unaware KL divergence. A simplified sketch of both losses is given below.
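
A minimal PyTorch sketch of how the two losses could look is given below. This is an illustration with hypothetical function names, not the repo's actual implementation: WKD-L is shown as an entropy-regularized (Sinkhorn) discrete WD between softened teacher/student predictions under a class-to-class cost matrix, and WKD-F as the closed-form 2-Wasserstein distance between Gaussians fitted to intermediate features, assuming diagonal covariances for simplicity.

import torch
import torch.nn.functional as F

def sinkhorn_wd(p, q, cost, eps=0.1, n_iters=50):
    # Entropy-regularized discrete Wasserstein distance between batches of
    # categorical distributions p, q (B x K) under a K x K cost matrix.
    gibbs = torch.exp(-cost / eps)                  # Gibbs kernel, K x K
    u = torch.ones_like(p)
    for _ in range(n_iters):                        # Sinkhorn iterations
        v = q / (u @ gibbs + 1e-9)
        u = p / (v @ gibbs.t() + 1e-9)
    plan = u.unsqueeze(2) * gibbs.unsqueeze(0) * v.unsqueeze(1)  # B x K x K transport plans
    return (plan * cost.unsqueeze(0)).sum(dim=(1, 2)).mean()

def wkd_l_loss(logits_s, logits_t, cost, tau=4.0):
    # Logit distillation: discrete WD between teacher and student predictions.
    p_t = F.softmax(logits_t / tau, dim=1)
    p_s = F.softmax(logits_s / tau, dim=1)
    return sinkhorn_wd(p_s, p_t, cost)

def wkd_f_loss(feat_s, feat_t):
    # Feature distillation: closed-form 2-Wasserstein distance between Gaussians
    # fitted to B x C x H x W features (diagonal covariances for simplicity).
    mu_s, var_s = feat_s.mean(dim=(2, 3)), feat_s.var(dim=(2, 3))
    mu_t, var_t = feat_t.mean(dim=(2, 3)), feat_t.var(dim=(2, 3))
    w2 = (mu_s - mu_t).pow(2) + (var_s.sqrt() - var_t.sqrt()).pow(2)
    return w2.sum(dim=1).mean()

The WKD-L+WKD-F rows in the result tables below correspond to training with both terms combined.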

We hope our work can shed light on the promise of WD and inspire further interest in this metric in knowledge distillation.

Citation

If this repo is helpful for your research, please consider citing the paper:

@inproceedings{WKD_NeurIPS2024,
  title={Wasserstein Distance Rivals Kullback-Leibler Divergence for Knowledge Distillation},
  author={Jiaming Lv and Haoyuan Yang and Peihua Li},
  booktitle={Advances in Neural Information Processing Systems},
  year={2024}
}

Experiments

We evaluate WKD for image classification on ImageNet and CIFAR-100, following the settings of CRD.

Image Classification on ImageNet

Method         ResNet34 -> ResNet18       ResNet50 -> MobileNetV1
               Top-1   Model   Log        Top-1   Model   Log
WKD-L          72.49   Model   Log        73.17   Model   Log
WKD-F          72.50   Model   Log        73.12   Model   Log
WKD-L+WKD-F    72.76   Model   Log        73.69   Model   Log

Image Classification on CIFAR-100

Method         RN32x4 -> RN8x4            WRN40-2 -> SNV1
               Top-1   Model   Log        Top-1   Model   Log
WKD-L          76.53   Model   Log        76.72   Model   Log
WKD-F          76.77   Model   Log        77.36   Model   Log
WKD-L+WKD-F    77.28   Model   Log        77.50   Model   Log

Installation

Note that test accuracy may vary slightly with different PyTorch/CUDA versions, GPUs, etc. All our experiments were conducted on a PC with an Intel Core i9-13900K CPU and GeForce RTX 4090 GPUs.

Environments:

We recommend using the PyTorch NGC Container provided by NVIDIA to reproduce our method. This environment contains:

  • CUDA 12.1
  • Python 3.8
  • PyTorch 2.1.0
  • torchvision 0.16.0

First, please make sure that Docker Engine is installed on your device by referring to this installation guide. Then, use the following command to pull the docker image:

docker pull nvcr.io/nvidia/pytorch:23.04-py3

Start a container from the image with the following command:

docker run --name wkd -dit --gpus=all nvcr.io/nvidia/pytorch:23.04-py3 /bin/bash
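
The container above runs in detached mode; you can then enter it with the following command (using the container name given above):

docker exec -it wkd /bin/bash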

Note: If you are unable to use Docker, you can try manually installing the PyTorch environment. It is recommended to keep the environment consistent with the one in the NGC Container.

Clone this repo:

Inside the NGC Container, clone this repository with:

git clone https://github.com/JiamingLv/WKD.git
cd WKD

Then install the extra packages:

pip install -r requirements.txt
python setup.py develop

Getting started

  1. Generate the cost matrix offline for WKD-L
  • We have provided the cost matrix in wkd_cost_matrix. If you want to regenerate it, run generate_cost_matrix.sh (a hypothetical sketch of one possible construction is given at the end of this list):
    sh generate_cost_matrix.sh
  2. Training on ImageNet
  • Put the ImageNet dataset in the default path data/imagenet. You can change this path by modifying the variable data_folder in mdistiller/dataset/imagenet.py.

    # for instance, our WKD-L method.
    python3 tools/train.py --cfg configs/imagenet/r34_r18/wkd_l.yaml
  3. Training on CIFAR-100
  • Download the cifar_teachers.tar provided by DKD and untar it to ./download_ckpts via tar xvf cifar_teachers.tar.

    # for instance, our WKD-L method.
    python3 tools/train.py --cfg configs/cifar100/wkd_l/res32x4_res8x4.yaml
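
For reference, the cost matrix used by WKD-L encodes how dissimilar each pair of categories is, so that transporting probability mass between related classes is cheap. Below is a minimal, hypothetical sketch of one plausible construction (class prototypes taken from the teacher's final FC weights, cost defined as one minus cosine similarity); the provided generate_cost_matrix.sh may compute it differently.

import torch
import torch.nn.functional as F

def build_cost_matrix(class_prototypes):
    # class_prototypes: K x D tensor, e.g. the teacher's final FC weights
    # (one row per class). This is an illustrative assumption, not the
    # repo's actual procedure.
    protos = F.normalize(class_prototypes, dim=1)  # unit-normalize each class vector
    sim = protos @ protos.t()                      # cosine similarity, K x K
    cost = 1.0 - sim                               # dissimilar classes cost more to match
    cost.fill_diagonal_(0.0)                       # moving mass within the same class is free
    return cost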

Acknowledgement

Contact

If you have any questions or suggestions, please contact us: