This is the official Python implementation of the paper Energy-Guided Continuous Entropic Barycenter Estimation for General Costs (paper on arXiv) by Alexander Kolesov, Petr Mokrov, Igor Udovichenko, Milena Gazdieva, Anastasis Kratsios, Gudmund Pammer, Evgeny Burnaev and Alexander Korotin.
The implementation is GPU-based. A single GTX 1080 Ti GPU is enough to run each experiment. We tested the code with torch==2.1.1+cu121; it might not run as intended with older or newer torch versions. Versions of the other libraries are specified in requirements.txt. Pre-trained models for maps and potentials are located here.
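Because the code was tested only with torch==2.1.1+cu121, it can help to check the installed version before running a notebook. The sketch below is not part of the repository: the helper names (`split_version`, `torch_matches_tested`, `installed_torch_version`) are hypothetical, and the check compares only the public release number (e.g. `2.1.1`), ignoring the CUDA build tag after `+`.

```python
import importlib.metadata
from typing import Optional, Tuple

# Version stated in this README as the one the code was tested with.
TESTED_TORCH = "2.1.1+cu121"


def split_version(v: str) -> Tuple[str, str]:
    """Split '2.1.1+cu121' into ('2.1.1', 'cu121'); the local tag may be empty."""
    public, _, local = v.partition("+")
    return public, local


def torch_matches_tested(installed: str, tested: str = TESTED_TORCH) -> bool:
    """True if the public release number (e.g. '2.1.1') equals the tested one."""
    return split_version(installed)[0] == split_version(tested)[0]


def installed_torch_version() -> Optional[str]:
    """Return the installed torch version string, or None if torch is absent."""
    try:
        return importlib.metadata.version("torch")
    except importlib.metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_torch_version()
    if v is None:
        print("torch is not installed")
    elif torch_matches_tested(v):
        print(f"torch {v} matches the tested release {TESTED_TORCH}")
    else:
        print(f"warning: torch {v} differs from the tested release {TESTED_TORCH}")
```

Ignoring the CUDA tag is a deliberate simplification: a CPU-only or different-CUDA build of 2.1.1 passes the check, so GPU availability still has to be confirmed separately.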
All experiments are provided as self-explanatory Jupyter notebooks (stylegan2/notebooks/).
src/ - auxiliary source code for the experiments: training, plotting, logging, etc.
stylegan2/ - folder with auxiliary code for using StyleGAN2.
stylegan2/notebooks - Jupyter notebooks with the evaluation of barycenters on 2D and image datasets.
data/ - folder with datasets.
SG2_ckpt/ - folder with checkpoints of trained StyleGAN2 models.
stylegan2/notebooks/twister2D.ipynb -- toy experiments on the 2D Twister dataset;
stylegan2/notebooks/Gauss2D.ipynb -- evaluating the metrics of our method in the Gaussian case;
notebooks/MNIST_01_barycenter_in_data_space.ipynb -- estimating barycenters of the 0 and 1 digits of the MNIST dataset in image space;
notebooks/MNIST_01_barycenter_in_latent_space.ipynb -- estimating barycenters of the 0 and 1 digits of the MNIST dataset in latent space;
notebooks/Ave_celeba_in_data_space.ipynb -- estimating barycenters of the Ave, Celeba! dataset in image space;
notebooks/Ave_celeba_in_latent_space.ipynb -- estimating barycenters of the Ave, Celeba! dataset in latent space.
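In the MNIST notebooks the barycenter is taken between two input distributions: images of the digit 0 and images of the digit 1. As an illustration only (this is not the repository's data loader), the hypothetical helper `split_by_digit` below shows one way to obtain the sample indices of each input distribution from a list of class labels.

```python
from typing import Dict, List, Sequence


def split_by_digit(labels: Sequence[int],
                   digits: Sequence[int] = (0, 1)) -> Dict[int, List[int]]:
    """Map each requested digit to the indices of the samples carrying that label.

    Each resulting index list defines one input distribution of the barycenter
    problem (e.g. all 0s and all 1s for the MNIST 0/1 experiment).
    """
    groups: Dict[int, List[int]] = {d: [] for d in digits}
    for idx, y in enumerate(labels):
        if y in groups:  # ignore digits that are not part of the problem
            groups[y].append(idx)
    return groups


if __name__ == "__main__":
    toy_labels = [0, 1, 2, 1, 0, 7]
    print(split_by_digit(toy_labels))
```

With torchvision installed, the labels of `torchvision.datasets.MNIST` are available as `dataset.targets`, so `split_by_digit(dataset.targets.tolist())` followed by `torch.utils.data.Subset(dataset, indices)` would give one subset per digit; how the notebooks actually build their datasets may differ.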
- Download the repository.
git clone https://github.com/justkolesov/EnergyGuidedBarycenters.git
- Create a virtual environment and install the dependencies.
pip install -r requirements.txt
- Download either the MNIST or the Ave, Celeba! 64x64 dataset.
- Place the downloaded dataset in the appropriate subfolder of data/.
- If you run experiments in image space, download the appropriate StyleGAN2 model from here (folder StyleGan2/).
- Place the StyleGAN2 model in the appropriate subfolder of SG2_ckpt/.
- Run the notebook for training, or take the appropriate checkpoint from here and load it.
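Before launching a notebook, a quick look at the local folders can catch a missing dataset or checkpoint. The sketch below is hypothetical and not part of the repository: `missing_setup_items` only checks that data/ contains at least one non-empty subfolder and, optionally, that SG2_ckpt/ contains at least one file. The exact subfolder and file names the notebooks expect are not specified here, so the check stays deliberately loose.

```python
from pathlib import Path
from typing import List


def missing_setup_items(root: str = ".", need_stylegan: bool = False) -> List[str]:
    """Return a list of problems with the local layout (an empty list means OK).

    need_stylegan: set True for experiments that rely on a StyleGAN2 model.
    """
    base = Path(root)
    problems: List[str] = []

    data = base / "data"
    # The dataset is expected in some subfolder of data/; its name is not checked.
    has_dataset = data.is_dir() and any(
        p.is_dir() and any(p.iterdir()) for p in data.iterdir()
    )
    if not has_dataset:
        problems.append("data/: no non-empty dataset subfolder found")

    if need_stylegan:
        ckpt = base / "SG2_ckpt"
        # Any file anywhere under SG2_ckpt/ counts as a checkpoint here.
        has_ckpt = ckpt.is_dir() and any(f.is_file() for f in ckpt.rglob("*"))
        if not has_ckpt:
            problems.append("SG2_ckpt/: no StyleGAN2 checkpoint file found")

    return problems


if __name__ == "__main__":
    for problem in missing_setup_items(".", need_stylegan=True):
        print("setup issue:", problem)
```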
- Ave, Celeba! faces dataset;
- MNIST images dataset;
- UNet architecture for maps in image space;
- ResNet architectures for maps in latent space.