forked from phizaz/diffae
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
263,224 additions
and
446 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Official implementation of Diffusion Autoencoders | ||
|
||
A CVPR 2022 paper: | ||
|
||
> Preechakul, Konpat, Nattanat Chatthee, Suttisak Wizadwongsa, and Supasorn Suwajanakorn. 2021. “Diffusion Autoencoders: Toward a Meaningful and Decodable Representation.” arXiv [cs.CV]. arXiv. http://arxiv.org/abs/2111.15640. | ||
## Usage | ||
|
||
Note: Since we expect a lot of changes on the codebase, please fork the repo before using. | ||
|
||
### Quick start | ||
|
||
A jupyter notebook. | ||
|
||
For unconditional generation: `sample.ipynb` | ||
|
||
For manipulation: `manipulate.ipynb` | ||
|
||
### Checkpoints | ||
|
||
Checkpoints ought to be put into a separate directory `checkpoints`. | ||
|
||
The directory tree may look like: | ||
|
||
``` | ||
checkpoints/ | ||
- bedroom128_autoenc | ||
- last.ckpt # diffae checkpoint | ||
- latent.ckpt # predicted z_sem on the dataset | ||
- bedroom128_autoenc_latent | ||
- last.ckpt # diffae + latent DPM checkpoint | ||
- bedroom128_ddpm | ||
- ... | ||
``` | ||
|
||
We provide checkpoints for the following models: | ||
|
||
1. DDIM: FFHQ128 ([72M](https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK?usp=sharing), [130M](https://drive.google.com/drive/folders/17T5YJXpYdgE6cWltN8gZFxRsJzpVxnLh?usp=sharing)), [Bedroom128](https://drive.google.com/drive/folders/19s-lAiK7fGD5Meo5obNV5o0L3MfqU0Sk?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1PiC5JWLcd8mZW9cghDCR0V4Hx0QCXOor?usp=sharing) | ||
2. DiffAE (autoencoding only): [FFHQ256](https://drive.google.com/drive/folders/1hTP9QbYXwv_Nl5sgcZNH0yKprJx7ivC5?usp=sharing), FFHQ128 ([72M](https://drive.google.com/drive/folders/15QHmZP1G5jEMh80R1Nbtdb4ZKb6VvfII?usp=sharing), [130M](https://drive.google.com/drive/folders/1UlwLwgv16cEqxTn7g-V2ykIyopmY_fVz?usp=sharing)), [Bedroom128](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing) | ||
3. DiffAE (with latent DPM, can sample): [FFHQ256](https://drive.google.com/drive/folders/1MonJKYwVLzvCFYuVhp-l9mChq5V2XI6w?usp=sharing), [FFHQ128](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing), [Bedroom128](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing) | ||
4. DiffAE's classifiers (for manipulation): [FFHQ256's latent on CelebAHQ](https://drive.google.com/drive/folders/1QGkTfvNhgi_TbbV8GbX1Emrp0lStsqLj?usp=sharing), [FFHQ128's latent on CelebAHQ](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing) | ||
|
||
|
||
### LMDB Datasets | ||
|
||
We do not own any of the following datasets. We provide the LMDB ready-to-use dataset for the sake of convenience. | ||
|
||
- [FFHQ](https://drive.google.com/drive/folders/1ww7itaSo53NDMa0q-wn-3HWZ3HHqK1IK?usp=sharing) | ||
- [CelebAHQ](https://drive.google.com/drive/folders/1SX3JuVHjYA8sA28EGxr_IoHJ63s4Btbl?usp=sharing) | ||
- [LSUN Bedroom](https://drive.google.com/drive/folders/1O_3aT3LtY1YDE2pOQCp6MFpCk7Pcpkhb?usp=sharing) | ||
- [LSUN Horse](https://drive.google.com/drive/folders/1ooHW7VivZUs4i5CarPaWxakCwfeqAK8l?usp=sharing) | ||
|
||
The directory tree should be: | ||
|
||
``` | ||
datasets/ | ||
- bedroom256.lmdb | ||
- celebahq256.lmdb | ||
- ffhq256.lmdb | ||
- horse256.lmdb | ||
``` | ||
|
||
You can also download from the original sources, and use our provided codes to package them as LMDB files. | ||
Original sources for each dataset is as follows: | ||
|
||
- FFHQ (https://github.com/NVlabs/ffhq-dataset) | ||
- CelebAHQ (https://github.com/switchablenorms/CelebAMask-HQ) | ||
- LSUN (https://github.com/fyu/lsun) | ||
|
||
The conversion codes are provided as: | ||
|
||
``` | ||
data_resize_bedroom.py | ||
data_resize_celebhq.py | ||
data_resize_ffhq.py | ||
data_resize_horse.py | ||
``` | ||
|
||
Google drive: https://drive.google.com/drive/folders/1abNP4QKGbNnymjn8607BF0cwxX2L23jh?usp=sharing | ||
|
||
|
||
## Training | ||
|
||
Soon ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import argparse | ||
import multiprocessing | ||
import os | ||
from os.path import join, exists | ||
from functools import partial | ||
from io import BytesIO | ||
import shutil | ||
|
||
import lmdb | ||
from PIL import Image | ||
from torchvision.datasets import LSUNClass | ||
from torchvision.transforms import functional as trans_fn | ||
from tqdm import tqdm | ||
|
||
from multiprocessing import Process, Queue | ||
|
||
|
||
def resize_and_convert(img, size, resample, quality=100): | ||
img = trans_fn.resize(img, size, resample) | ||
img = trans_fn.center_crop(img, size) | ||
buffer = BytesIO() | ||
img.save(buffer, format="webp", quality=quality) | ||
val = buffer.getvalue() | ||
|
||
return val | ||
|
||
|
||
def resize_multiple(img, | ||
sizes=(128, 256, 512, 1024), | ||
resample=Image.LANCZOS, | ||
quality=100): | ||
imgs = [] | ||
|
||
for size in sizes: | ||
imgs.append(resize_and_convert(img, size, resample, quality)) | ||
|
||
return imgs | ||
|
||
|
||
def resize_worker(idx, img, sizes, resample): | ||
img = img.convert("RGB") | ||
out = resize_multiple(img, sizes=sizes, resample=resample) | ||
return idx, out | ||
|
||
|
||
from torch.utils.data import Dataset, DataLoader | ||
|
||
|
||
class ConvertDataset(Dataset): | ||
def __init__(self, data) -> None: | ||
self.data = data | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
def __getitem__(self, index): | ||
img, _ = self.data[index] | ||
bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90) | ||
return bytes | ||
|
||
|
||
if __name__ == "__main__": | ||
""" | ||
converting lsun' original lmdb to our lmdb, which is somehow more performant. | ||
""" | ||
from tqdm import tqdm | ||
|
||
# path to the original lsun's lmdb | ||
src_path = 'datasets/bedroom_train_lmdb' | ||
out_path = 'datasets/bedroom256.lmdb' | ||
|
||
dataset = LSUNClass(root=os.path.expanduser(src_path)) | ||
dataset = ConvertDataset(dataset) | ||
loader = DataLoader(dataset, | ||
batch_size=50, | ||
num_workers=12, | ||
collate_fn=lambda x: x, | ||
shuffle=False) | ||
|
||
target = os.path.expanduser(out_path) | ||
if os.path.exists(target): | ||
shutil.rmtree(target) | ||
|
||
with lmdb.open(target, map_size=1024**4, readahead=False) as env: | ||
with tqdm(total=len(dataset)) as progress: | ||
i = 0 | ||
for batch in loader: | ||
with env.begin(write=True) as txn: | ||
for img in batch: | ||
key = f"{256}-{str(i).zfill(7)}".encode("utf-8") | ||
# print(key) | ||
txn.put(key, img) | ||
i += 1 | ||
progress.update() | ||
# if i == 1000: | ||
# break | ||
# if total == len(imgset): | ||
# break | ||
|
||
with env.begin(write=True) as txn: | ||
txn.put("length".encode("utf-8"), str(i).encode("utf-8")) |
Oops, something went wrong.