The Spatial-VAE architecture, as defined in [1], is shown below.
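For readers who want the gist in code, here is a minimal sketch of the decoder idea from [1]: each pixel is generated from its own 2D coordinate, transformed by a per-image rotation and translation inferred by the encoder, together with the unconstrained latent variables. The class name, layer sizes, and coordinate grid below are illustrative assumptions, not this repository's exact configuration.

```python
# A minimal sketch of the spatial-VAE decoder idea from [1]: an MLP maps one
# (rotated and translated) pixel coordinate plus the shared latent z to that
# pixel's value, so pose acts on coordinates rather than on the image.
# All names and sizes here are illustrative, not the repo's actual code.
import torch
import torch.nn as nn


class SpatialDecoder(nn.Module):
    def __init__(self, z_dim: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 + z_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
            nn.Linear(hidden, 1),  # one output value per coordinate
        )

    def forward(self, coords, z, theta, dx):
        # coords: (N, 2) pixel coordinates; z: (z_dim,) latent;
        # theta: scalar rotation angle; dx: (2,) translation.
        c, s = torch.cos(theta), torch.sin(theta)
        rot = torch.stack([torch.stack([c, -s]), torch.stack([s, c])])
        t_coords = coords @ rot.T + dx   # rotate, then translate, the grid
        z_rep = z.expand(coords.shape[0], -1)  # share z across all pixels
        return self.net(torch.cat([t_coords, z_rep], dim=1))


# Example: decode a 28x28 image from a 2-dimensional latent.
ys, xs = torch.meshgrid(torch.linspace(-1, 1, 28),
                        torch.linspace(-1, 1, 28), indexing="ij")
coords = torch.stack([xs.flatten(), ys.flatten()], dim=1)  # (784, 2)
decoder = SpatialDecoder(z_dim=2)
pixels = decoder(coords, z=torch.randn(2), theta=torch.tensor(0.3),
                 dx=torch.zeros(2))
print(pixels.shape)  # torch.Size([784, 1])
```

In the full model, `theta` and `dx` are themselves latent variables inferred per image, which is what lets the unconstrained latents capture content independently of pose.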
Make sure to check our interactive demo of the model here! 😇
The experiments were run in Python 3.7. Start by installing PyTorch as per the docs. Then run the commands below (they are for Ubuntu and may differ slightly on other operating systems):
```bash
# Clone the repository.
git clone https://github.com/COMP6248-Reproducability-Challenge/SVAE.git
cd SVAE

# (Optional) Create a new Python environment and activate it.
python3 -m venv .env
source .env/bin/activate

# Install the dependencies.
pip install -r requirements.txt
```
To train a model with the default parameters, run:

```bash
python -m src.train
```
To see the available options, run:

```bash
python -m src.train --help
```
After the model is trained, a state dict (a `.pt` file) and the loss log (a `.csv` file) will be stored in `model_logs/`. The files are named `<dataset>_<svae><has_rotation><has_translation>_<n_unconstrained>`.
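Once training finishes, the state dict can be loaded back into the model for evaluation. The sketch below is a minimal example; the import path, constructor arguments, and file names are assumptions (check `src/models/svae.py` and your `model_logs/` directory for the real values).

```python
# A minimal sketch of loading a trained model and its loss log. The import
# path, constructor arguments, and file names are assumptions, not the
# repo's guaranteed API.
import csv

import torch

from src.models.svae import SpatialVAE  # assumed import path

state_path = "model_logs/mnist_svae11_2.pt"  # hypothetical file name
log_path = "model_logs/mnist_svae11_2.csv"   # hypothetical file name

model = SpatialVAE()  # constructor arguments assumed to default correctly
model.load_state_dict(torch.load(state_path, map_location="cpu"))
model.eval()  # switch to inference mode

# The loss log is a plain CSV; print how many rows it holds.
with open(log_path, newline="") as f:
    rows = list(csv.reader(f))
print(f"Loaded {len(rows)} rows from the loss log.")
```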
- `data` - the three MNIST datasets are here.
- `doc` - the report is here.
- `model_logs` - after training, models and logs will be stored here.
- `src` - the source code.
  - `models` - the PyTorch models.
    - `svae.py` - implementation of the Spatial-VAE architecture.
    - `mnist_model.py` - adds training methods to the Spatial-VAE from `svae.py`.
  - `notebooks` - an example notebook for the model.
  - `train.py` - the training script.
- `report` - contains the LaTeX source and scripts for plotting the results; all logs are stored here.
- `gh-pages` - contains scripts for converting the PyTorch model to ONNX and the HTML files for the interactive demo (see the sketch after this list).
- `vanilla_vae` - an implementation of a standard (vanilla) VAE, plus training scripts.
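As a rough illustration of the conversion step mentioned under `gh-pages`, the sketch below exports a model with `torch.onnx.export`. The model construction, input shape, and file names are assumptions; the repository's actual scripts may export a different part of the model (for example, only the decoder) for the demo.

```python
# A minimal sketch of exporting a trained model to ONNX for the browser demo.
# Import path, constructor arguments, input shape, and file names are
# assumptions, not the repo's actual gh-pages script.
import torch

from src.models.svae import SpatialVAE  # assumed import path

model = SpatialVAE()  # constructor arguments assumed to default correctly
model.load_state_dict(torch.load("model_logs/mnist_svae11_2.pt",
                                 map_location="cpu"))  # hypothetical name
model.eval()

# Dummy input sized for MNIST; the real script may instead trace only the
# decoder, with coordinate and latent inputs.
dummy = torch.randn(1, 1, 28, 28)

torch.onnx.export(
    model, dummy, "svae.onnx",
    input_names=["image"], output_names=["output"],
    opset_version=11,
)
```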
[1] Tristan Bepler, Ellen D. Zhong, Kotaro Kelley, Edward Brignole, and Bonnie Berger. Explicitly disentangling image content from translation and rotation with spatial-VAE (online).