This is the official Python implementation of the ICLR 2021 paper Continuous Wasserstein-2 Barycenter Estimation without Minimax Optimization (available on OpenReview) by Alexander Korotin, Lingxiao Li, Justin Solomon and Evgeny Burnaev.
The repository contains fully reproducible PyTorch source code for computing Wasserstein-2 barycenters in high dimensions with the non-minimax method proposed in the paper, which uses input convex neural networks. Examples are provided for various toy problems and for averaging image color palettes.
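For reference, the quantity being estimated is the standard Wasserstein-2 barycenter: given input distributions $\mathbb{P}_1, \dots, \mathbb{P}_N$ and non-negative weights $\alpha_n$ summing to one, it is the distribution that minimizes the weighted sum of squared $\mathbb{W}_2$ distances to the inputs. The notation below is ours and may differ from the paper's; $\Pi(\mathbb{P}, \mathbb{Q})$ denotes the set of couplings with marginals $\mathbb{P}$ and $\mathbb{Q}$.

```latex
\mathbb{W}_2^2(\mathbb{P}, \mathbb{Q})
  = \min_{\pi \in \Pi(\mathbb{P}, \mathbb{Q})} \int \|x - y\|^2 \, d\pi(x, y),
\qquad
\overline{\mathbb{P}}
  = \operatorname*{arg\,min}_{\mathbb{Q}} \sum_{n=1}^{N} \alpha_n \, \mathbb{W}_2^2(\mathbb{P}_n, \mathbb{Q}),
\quad \alpha_n \ge 0, \ \sum_{n=1}^{N} \alpha_n = 1.
```

Scaling the cost by $\tfrac{1}{2}$, a common convention, does not change the minimizer $\overline{\mathbb{P}}$.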
@inproceedings{
korotin2021continuous,
title={Continuous Wasserstein-2 Barycenter Estimation without Minimax Optimization},
author={Alexander Korotin and Lingxiao Li and Justin Solomon and Evgeny Burnaev},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=3tFAs5E-Pe}
}
The implementation is GPU-based. A single GPU (~GTX 1080 Ti) is enough to run each experiment. Tested with torch==1.3.0. The code might not run as intended with newer torch versions.
Related repositories:
- Repository for Kantorovich Strikes Back! Wasserstein GANs are not Optimal Transport? paper.
- Repository for Wasserstein-2 Generative Networks paper.
- Repository for Continuous Regularized Wasserstein Barycenters paper.
- Repository for Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark paper.
- Repository for Large-Scale Wasserstein Gradient Flows paper.
The code for running the experiments is located in self-contained Jupyter notebooks (notebooks/). For convenience, most of the evaluation output is preserved. Other auxiliary source code is moved to .py modules (src/).
- notebooks/CW2B_toy_experiments.ipynb.ipynb -- toy experiments (in dimensions up to 256) and subset posterior aggregation.
- notebooks/CW2B_averaging_color_palettes.ipynb -- averaging color palettes of images.
- src/icnn.py -- modules for Input Convex Neural Network architectures (DenseICNN).
- poster/CW2B_poster.png -- poster (landscape format).
- poster/CW2B_poster.svg -- source file for the poster.
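By Brenier's theorem, the optimal transport map for the quadratic cost (from an absolutely continuous distribution) is the gradient of a convex function, which is why the method parametrizes potentials with input convex neural networks. The sketch below is a hypothetical minimal ICNN in the style of Amos et al. (2017), written in NumPy only so that it stays self-contained; it is not the repository's DenseICNN, whose layer layout may differ. Convexity in x rests on two assumptions made explicit in the code: hidden-to-hidden (and output) weights are non-negative, and the activation (softplus) is convex and non-decreasing.

```python
import numpy as np


def softplus(t):
    # log(1 + exp(t)) computed stably; convex and non-decreasing
    return np.logaddexp(0.0, t)


class ToyICNN:
    """Hypothetical minimal ICNN (not the repository's DenseICNN).

    z_1     = softplus(A_0 x + b_0)
    z_{k+1} = softplus(W_k z_k + A_k x + b_k),  W_k >= 0 elementwise
    f(x)    = w_out . z_L + a_out . x,           w_out >= 0
    """

    def __init__(self, dim, hidden=(32, 32), seed=0):
        rng = np.random.default_rng(seed)
        # Input-to-hidden weights are unconstrained (affine in x keeps convexity)
        self.A = [rng.normal(size=(h, dim)) / np.sqrt(dim) for h in hidden]
        self.b = [np.zeros(h) for h in hidden]
        # Hidden-to-hidden and output weights are made non-negative via abs()
        self.W = [np.abs(rng.normal(size=(h_out, h_in))) / h_in
                  for h_in, h_out in zip(hidden[:-1], hidden[1:])]
        self.w_out = np.abs(rng.normal(size=hidden[-1])) / hidden[-1]
        self.a_out = rng.normal(size=dim) / np.sqrt(dim)

    def __call__(self, x):
        # x: array of shape (n, dim); returns potentials of shape (n,)
        z = softplus(x @ self.A[0].T + self.b[0])
        for W, A, b in zip(self.W, self.A[1:], self.b[1:]):
            z = softplus(z @ W.T + x @ A.T + b)
        return z @ self.w_out + x @ self.a_out


# Usage: check midpoint convexity on random pairs of points
f = ToyICNN(dim=2)
rng = np.random.default_rng(1)
x, y = rng.normal(size=(1000, 2)), rng.normal(size=(1000, 2))
gap = 0.5 * (f(x) + f(y)) - f(0.5 * (x + y))
print(gap.min() >= -1e-9)  # True: f is convex in x
```

The transport map associated with such a potential is its gradient, which a PyTorch implementation can obtain with autograd rather than by hand.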
The provided code generates the following visual results, which are included in the paper.
The example below shows 4 initial distributions (left), the ground-truth barycenter (middle), and the barycenters computed from each of the 4 potentials recovered by our algorithm (right).
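Evaluating a barycenter estimate requires a ground truth. When every input distribution is Gaussian, one standard way to compute the exact $\mathbb{W}_2$ barycenter is the fixed-point iteration of Álvarez-Esteban et al.: the barycenter is Gaussian, its mean is the weighted average of the input means, and its covariance $S$ solves $S = \sum_k w_k (S^{1/2} \Sigma_k S^{1/2})^{1/2}$. The sketch below assumes Gaussian inputs and is an illustration only, not necessarily how the notebooks construct their ground truth; the function names are hypothetical.

```python
import numpy as np


def sqrtm_psd(S):
    # Square root of a symmetric PSD matrix via eigendecomposition
    vals, vecs = np.linalg.eigh(S)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def gaussian_w2_barycenter(means, covs, weights, n_iter=100, tol=1e-10):
    """Fixed-point iteration for the W2 barycenter of Gaussians N(m_k, S_k)."""
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()
    # The barycenter's mean is the weighted average of the means
    mean = np.tensordot(weights, np.asarray(means, dtype=float), axes=1)
    S = np.eye(np.asarray(covs[0]).shape[0])  # any positive definite start
    for _ in range(n_iter):
        root = sqrtm_psd(S)
        inv_root = np.linalg.inv(root)
        T = sum(w * sqrtm_psd(root @ C @ root) for w, C in zip(weights, covs))
        S_new = inv_root @ T @ T @ inv_root
        S_new = 0.5 * (S_new + S_new.T)  # symmetrize against round-off
        converged = np.linalg.norm(S_new - S) < tol
        S = S_new
        if converged:
            break
    return mean, S


# Usage: for diagonal (commuting) covariances the barycenter's standard
# deviations are the weighted averages of the input standard deviations
m, S = gaussian_w2_barycenter(
    means=[np.zeros(2), np.ones(2)],
    covs=[np.diag([1.0, 4.0]), np.diag([9.0, 16.0])],
    weights=[0.5, 0.5],
)
print(m)                                    # [0.5 0.5]
print(np.allclose(S, np.diag([4.0, 9.0])))  # True
```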
The example below demonstrates barycenters of the RGB (3D) color palettes of three images.
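In this experiment each image is treated as a point cloud of its RGB pixel values in R^3, and a natural reading of the "averaged" images is that every pixel of an image is moved by the map from that image's palette to the barycenter. The sketch below illustrates that pipeline under a loud simplification: instead of the ICNN potentials learned by the paper's algorithm, it approximates each palette by a Gaussian, computes the Gaussian W2 barycenter by fixed-point iteration, and applies the closed-form linear Monge map between Gaussians. All function names are hypothetical.

```python
import numpy as np


def sqrtm_psd(S):
    # Square root of a symmetric PSD matrix via eigendecomposition
    vals, vecs = np.linalg.eigh(S)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def to_palette(image):
    # An H x W x 3 RGB image with values in [0, 1] becomes H*W points in R^3
    return image.reshape(-1, 3).astype(float)


def gaussian_barycenter_cov(covs, weights, n_iter=200):
    # Fixed-point iteration for the covariance of the Gaussian W2 barycenter
    S = np.eye(3)
    for _ in range(n_iter):
        r = sqrtm_psd(S)
        r_inv = np.linalg.inv(r)
        T = sum(w * sqrtm_psd(r @ C @ r) for w, C in zip(weights, covs))
        S = r_inv @ T @ T @ r_inv
        S = 0.5 * (S + S.T)
    return S


def average_images(images, weights=None, eps=1e-6):
    """Linear-Gaussian stand-in for palette averaging (not the ICNN method)."""
    k = len(images)
    weights = np.full(k, 1.0 / k) if weights is None else np.asarray(weights, float)
    weights = weights / weights.sum()
    pals = [to_palette(im) for im in images]
    means = [p.mean(axis=0) for p in pals]
    # eps * I keeps covariances invertible for nearly flat palettes
    covs = [np.cov(p, rowvar=False) + eps * np.eye(3) for p in pals]
    m_bar = sum(w * m for w, m in zip(weights, means))
    S_bar = gaussian_barycenter_cov(covs, weights)
    out = []
    for im, p, m, C in zip(images, pals, means, covs):
        r = sqrtm_psd(C)
        r_inv = np.linalg.inv(r)
        # Monge map from N(m, C) to N(m_bar, S_bar): x -> m_bar + A (x - m)
        A = r_inv @ sqrtm_psd(r @ S_bar @ r) @ r_inv
        recolored = m_bar + (p - m) @ A.T
        # Clip back to the valid RGB range and restore the image shape
        out.append(np.clip(recolored, 0.0, 1.0).reshape(im.shape))
    return out


# Usage on synthetic "images" with different tone curves
rng = np.random.default_rng(0)
imgs = [rng.uniform(0, 1, size=(32, 48, 3)) ** g for g in (0.5, 1.0, 2.0)]
averaged = average_images(imgs)
print([a.shape for a in averaged])  # [(32, 48, 3), (32, 48, 3), (32, 48, 3)]
```

With learned ICNN potentials, the linear map A would be replaced by the gradient of the potential, so the recoloring can be nonlinear.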
Original images and color palettes
"Averaged" images and color palettes (estimated by each of three potentials computed by our algorithm)