This repository is the official implementation of the methods in the publication:
- Ella Tamir, Martin Trapp, and Arno Solin (2023). Transport with support: Data-conditional diffusion bridges. In Transactions on Machine Learning Research (TMLR). [arXiv preprint]
The dynamic Schrödinger bridge problem provides an appealing setting for solving constrained time-series data generation tasks posed as optimal transport problems. It consists of learning non-linear diffusion processes using efficient iterative solvers. Recent works have demonstrated state-of-the-art results (e.g., in modelling single-cell embryo RNA sequences or sampling from complex posteriors) but are limited to learning bridges with only initial and terminal constraints. Our work extends this paradigm by proposing the Iterative Smoothing Bridge (ISB). We integrate Bayesian filtering and optimal control into learning the diffusion process, enabling the generation of constrained stochastic processes governed by sparse observations at intermediate stages and terminal constraints. We assess the effectiveness of our method on synthetic and real-world data generation tasks and show that the ISB generalises well to high-dimensional data, is computationally efficient, and provides accurate estimates of the marginals at intermediate and terminal times.
In the root of this repo, run
conda env create --file=isb-env.yaml
conda activate isb-env
pip install -e ./src
The experiment scripts are in the folder tasks. To run (non-image) experiments, run the Python file tasks/isb_tasks/run_iterative_smoother.py.
To successfully run an experiment, follow these steps:
- If you are using a zero drift initialization, no IPFP run is necessary. If you need an unconstrained Schrödinger bridge model as a reference, run IPFP on flat data first.
- Run
conda activate isb-env
python3 tasks/isb_tasks/run_iterative_smoother.py --config-name circle_isb
where the config name should match the desired config file.
- If you are using a zero drift initialization, no IPFP run is necessary. If you choose to use a NN drift initialization, set the model name of a trained model here after completing the steps in "Running IPFP on MNIST"
- Run
conda activate isb-env
python3 tasks/isb_tasks/run_iterative_smoother_img.py --config-name mnist_isb
We recommend using a GPU for this experiment.
- Run
conda activate isb-env
python3 tasks/bridge_tasks/run_iftp_flat.py --config-name circle2d
where the config name should match the desired config file.
- Run
conda activate isb-env
python3 tasks/bridge_tasks/run_iftp_image.py --config-name mnist
We recommend using a GPU for this experiment.
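Since a GPU is recommended for the image experiments, it can help to confirm that PyTorch sees one before starting a long run. The snippet below is a minimal sketch, not the repository's own device-selection logic: `pick_device` is a hypothetical helper, and it assumes PyTorch is installed by the isb-env environment.

```python
def pick_device() -> str:
    """Return 'cuda' if PyTorch can see a GPU, otherwise 'cpu'."""
    try:
        import torch  # assumption: provided by the isb-env environment
    except ImportError:
        # Without PyTorch no GPU can be used, so report the CPU.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(f"Available device: {pick_device()}")
```

If this prints `cpu`, the MNIST experiments may still run but are likely to be slow.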
To train an ISB model for a dataset, first check the list below for the config files and modify them as needed:
- 2D sklearn circle: Bridge config in configs/bridge/circles2d.yaml, ISB config in configs/isb/circle_isb.yaml
- Single-cell: No bridge config, since pre-training an OT model is not necessary here; ISB config in configs/isb/rna_isb.yaml
- MNIST: Bridge config in configs/bridge/mnist.yaml, ISB config in configs/isb/mnist_isb.yaml
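As a rough sketch of how the ISB config names above map onto training commands, the helper below builds the command line for each one. `ISB_RUNS` and `isb_command` are illustrative names and not part of this repository. Pairing `rna_isb` with the non-image script is an assumption, based on the single-cell data not being image data.

```python
import shlex

# Hypothetical lookup table (not part of the repository): ISB config name -> script.
ISB_RUNS = {
    "circle_isb": "tasks/isb_tasks/run_iterative_smoother.py",
    # Assumption: the single-cell experiment uses the non-image script.
    "rna_isb": "tasks/isb_tasks/run_iterative_smoother.py",
    "mnist_isb": "tasks/isb_tasks/run_iterative_smoother_img.py",
}


def isb_command(config_name: str) -> list[str]:
    """Return the argv list for training an ISB model with the given config."""
    if config_name not in ISB_RUNS:
        raise KeyError(f"unknown ISB config: {config_name!r}")
    return ["python3", ISB_RUNS[config_name], "--config-name", config_name]


if __name__ == "__main__":
    # Print one shell-ready command per known config.
    for name in ISB_RUNS:
        print(shlex.join(isb_command(name)))
```

The returned list can be passed to `subprocess.run(..., check=True)` from the repository root, with the isb-env environment active.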
Running the scripts creates model files and plots. Below is a description of the folder contents:
- bridge_models: Unconstrained model files, to be loaded in ISB training if necessary
- isb_models: Trained ISB models
- plots: Ordered by dataset name and date; includes videos of the learned dynamics and pickled objects from the trained model. The video starting with "final_trajectory_final_particles" shows the result after learning.
- outputs: Hydra outputs; saves the config file used for each run
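To locate the result videos after a run, a small search over the plots folder may be convenient. The sketch below is a hypothetical helper (`find_final_videos` is not part of this repository). It searches recursively, so it does not need to know exactly how the dataset and date subfolders are nested.

```python
from pathlib import Path


def find_final_videos(plots_root="plots"):
    """Return files whose name starts with 'final_trajectory_final_particles',
    newest first (by modification time)."""
    root = Path(plots_root)
    if not root.is_dir():
        return []
    # rglob walks every subfolder below the plots root.
    matches = [
        p for p in root.rglob("final_trajectory_final_particles*") if p.is_file()
    ]
    return sorted(matches, key=lambda p: p.stat().st_mtime, reverse=True)


if __name__ == "__main__":
    for path in find_final_videos():
        print(path)
```

Run from the repository root, this lists the final-trajectory videos of all runs, with the most recently written file first.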
We wish to thank Adrien Corenflos for sharing an implementation of differentiable resampling in PyTorch (see transport_map.py and sinkhorn.py).
This software is provided under the MIT License.