Commit: update
phizaz committed Mar 17, 2022
1 parent a9c4dde commit 826d44f
Showing 17 changed files with 263,224 additions and 446 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -4,3 +4,10 @@ log-latent
__pycache__
generated
latent_infer
datasets/bedroom256.lmdb
datasets/horse256.lmdb
datasets/celebahq
datasets/celebahq256.lmdb
datasets/ffhq
datasets/ffhq256.lmdb
checkpoints
84 changes: 84 additions & 0 deletions README.md
@@ -0,0 +1,84 @@
# Official implementation of Diffusion Autoencoders

A CVPR 2022 paper:

> Preechakul, Konpat, Nattanat Chatthee, Suttisak Wizadwongsa, and Supasorn Suwajanakorn. 2021. “Diffusion Autoencoders: Toward a Meaningful and Decodable Representation.” arXiv [cs.CV]. arXiv. http://arxiv.org/abs/2111.15640.

## Usage

Note: Since we expect many changes to the codebase, please fork the repo before using it.

### Quick start

Jupyter notebooks are provided for each task:

For unconditional generation: `sample.ipynb`

For manipulation: `manipulate.ipynb`

### Checkpoints

Checkpoints should be placed in a separate `checkpoints` directory.

The directory tree may look like:

```
checkpoints/
- bedroom128_autoenc
- last.ckpt # diffae checkpoint
- latent.ckpt # predicted z_sem on the dataset
- bedroom128_autoenc_latent
- last.ckpt # diffae + latent DPM checkpoint
- bedroom128_ddpm
- ...
```

We provide checkpoints for the following models:

1. DDIM: FFHQ128 ([72M](https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK?usp=sharing), [130M](https://drive.google.com/drive/folders/17T5YJXpYdgE6cWltN8gZFxRsJzpVxnLh?usp=sharing)), [Bedroom128](https://drive.google.com/drive/folders/19s-lAiK7fGD5Meo5obNV5o0L3MfqU0Sk?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1PiC5JWLcd8mZW9cghDCR0V4Hx0QCXOor?usp=sharing)
2. DiffAE (autoencoding only): [FFHQ256](https://drive.google.com/drive/folders/1hTP9QbYXwv_Nl5sgcZNH0yKprJx7ivC5?usp=sharing), FFHQ128 ([72M](https://drive.google.com/drive/folders/15QHmZP1G5jEMh80R1Nbtdb4ZKb6VvfII?usp=sharing), [130M](https://drive.google.com/drive/folders/1UlwLwgv16cEqxTn7g-V2ykIyopmY_fVz?usp=sharing)), [Bedroom128](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
3. DiffAE (with latent DPM, can sample): [FFHQ256](https://drive.google.com/drive/folders/1MonJKYwVLzvCFYuVhp-l9mChq5V2XI6w?usp=sharing), [FFHQ128](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing), [Bedroom128](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
4. DiffAE's classifiers (for manipulation): [FFHQ256's latent on CelebAHQ](https://drive.google.com/drive/folders/1QGkTfvNhgi_TbbV8GbX1Emrp0lStsqLj?usp=sharing), [FFHQ128's latent on CelebAHQ](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing)
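
As a quick sanity check after downloading, a checkpoint can be inspected with plain PyTorch. This is a minimal sketch, assuming the standard PyTorch Lightning checkpoint layout (a dict with a `state_dict` entry); the exact keys are an assumption, so verify against your file:

```
import torch

# Hypothetical inspection of a downloaded checkpoint (assumed Lightning format);
# the path and the 'state_dict' key are assumptions, not guaranteed by the repo.
ckpt = torch.load('checkpoints/bedroom128_autoenc/last.ckpt', map_location='cpu')
print(ckpt.keys())                      # e.g. dict_keys(['state_dict', ...])
print(len(ckpt.get('state_dict', {})))  # number of saved tensors
```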


### LMDB Datasets

We do not own any of the following datasets. We provide ready-to-use LMDB versions of them for convenience.

- [FFHQ](https://drive.google.com/drive/folders/1ww7itaSo53NDMa0q-wn-3HWZ3HHqK1IK?usp=sharing)
- [CelebAHQ](https://drive.google.com/drive/folders/1SX3JuVHjYA8sA28EGxr_IoHJ63s4Btbl?usp=sharing)
- [LSUN Bedroom](https://drive.google.com/drive/folders/1O_3aT3LtY1YDE2pOQCp6MFpCk7Pcpkhb?usp=sharing)
- [LSUN Horse](https://drive.google.com/drive/folders/1ooHW7VivZUs4i5CarPaWxakCwfeqAK8l?usp=sharing)

The directory tree should be:

```
datasets/
- bedroom256.lmdb
- celebahq256.lmdb
- ffhq256.lmdb
- horse256.lmdb
```

You can also download the datasets from their original sources and use our provided scripts to package them as LMDB files.
The original sources are as follows:

- FFHQ (https://github.com/NVlabs/ffhq-dataset)
- CelebAHQ (https://github.com/switchablenorms/CelebAMask-HQ)
- LSUN (https://github.com/fyu/lsun)

The conversion scripts are provided as:

```
data_resize_bedroom.py
data_resize_celebhq.py
data_resize_ffhq.py
data_resize_horse.py
```
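
For reference, here is a minimal read-back sketch (not part of the repo) for a converted LMDB file. It assumes the key layout the conversion scripts write, as in `data_resize_bedroom.py`: WebP-encoded image bytes under keys like `b"256-0000000"`, plus a `b"length"` entry holding the total count:

```
import lmdb
from io import BytesIO
from PIL import Image

# Read the first 256px image back out of the converted LMDB.
with lmdb.open('datasets/bedroom256.lmdb', readonly=True, readahead=False) as env:
    with env.begin(write=False) as txn:
        length = int(txn.get(b'length').decode('utf-8'))
        img_bytes = txn.get(f'256-{str(0).zfill(7)}'.encode('utf-8'))

img = Image.open(BytesIO(img_bytes))  # images were stored as webp
print(length, img.size)
```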

Google Drive: https://drive.google.com/drive/folders/1abNP4QKGbNnymjn8607BF0cwxX2L23jh?usp=sharing


## Training

Soon ...
25 changes: 0 additions & 25 deletions choices.py
@@ -5,8 +5,6 @@
class TrainMode(Enum):
    # manipulate mode = training the classifier
    manipulate = 'manipulate'
    # the classifier on the image domain
    manipulate_img = 'manipulateimg'
    # default training mode!
    diffusion = 'diffusion'
    # default latent training mode!
@@ -16,12 +14,6 @@ class TrainMode(Enum):
    def is_manipulate(self):
        return self in [
            TrainMode.manipulate,
            TrainMode.manipulate_img,
        ]

    def is_manipluate_img(self):
        return self in [
            TrainMode.manipulate_img,
        ]

    def is_diffusion(self):
@@ -61,49 +53,32 @@ class ManipulateMode(Enum):
    how to train the classifier to manipulate
    """
    # train on whole celeba attr dataset
    celeba_all = 'all'
    celebahq_all = 'celebahq_all'
    # train on a few-shot subset
    celeba_fewshot = 'fewshot'
    celeba_fewshot_allneg = 'fewshotallneg'
    # celeba with D2C's crop
    d2c_fewshot = 'd2cfewshot'
    d2c_fewshot_allneg = 'd2cfewshotallneg'
    celebahq_fewshot = 'celebahq_fewshot'
    relighting = 'light'

    def is_celeba_attr(self):
        return self in [
            ManipulateMode.celeba_all,
            ManipulateMode.celeba_fewshot,
            ManipulateMode.celeba_fewshot_allneg,
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
            ManipulateMode.celebahq_all,
            ManipulateMode.celebahq_fewshot,
        ]

    def is_single_class(self):
        return self in [
            ManipulateMode.celeba_fewshot,
            ManipulateMode.celeba_fewshot_allneg,
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
            ManipulateMode.celebahq_fewshot,
        ]

    def is_fewshot(self):
        return self in [
            ManipulateMode.celeba_fewshot,
            ManipulateMode.celeba_fewshot_allneg,
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
            ManipulateMode.celebahq_fewshot,
        ]

    def is_fewshot_allneg(self):
        return self in [
            ManipulateMode.celeba_fewshot_allneg,
            ManipulateMode.d2c_fewshot_allneg,
        ]
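
For orientation, a quick sketch (not part of this commit) of how these predicates relate, based solely on the membership lists above:

```
from choices import ManipulateMode

mode = ManipulateMode.celebahq_fewshot
assert mode.is_celeba_attr()         # listed in is_celeba_attr
assert mode.is_single_class()        # listed in is_single_class
assert mode.is_fewshot()             # listed in is_fewshot
assert not mode.is_fewshot_allneg()  # only the *_allneg modes qualify
```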

6 changes: 1 addition & 5 deletions config.py
@@ -62,7 +62,7 @@ class TrainConfig(BaseConfig):
    train_pred_xstart_detach: bool = True
    train_interpolate_prob: float = 0
    train_interpolate_img: bool = False
    manipulate_mode: ManipulateMode = ManipulateMode.celeba_all
    manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
    manipulate_cls: str = None
    manipulate_shots: int = None
    manipulate_loss: ManipulateLossType = ManipulateLossType.bce
@@ -365,10 +365,6 @@ def make_dataset(self, path=None, **kwargs):
                              image_size=self.img_size,
                              split='train',
                              **kwargs)
        elif self.data_name == 'horse':
            return LSUNHorse(path=path or self.data_path,
                             image_size=self.img_size,
                             **kwargs)
        elif self.data_name == 'horse256':
            return Horse_lmdb(path=path or self.data_path,
                              image_size=self.img_size,
101 changes: 101 additions & 0 deletions data_resize_bedroom.py
@@ -0,0 +1,101 @@
import argparse
import multiprocessing
import os
from os.path import join, exists
from functools import partial
from io import BytesIO
import shutil

import lmdb
from PIL import Image
from torchvision.datasets import LSUNClass
from torchvision.transforms import functional as trans_fn
from tqdm import tqdm

from multiprocessing import Process, Queue


def resize_and_convert(img, size, resample, quality=100):
    # resize the short side to `size`, center-crop to a square,
    # and encode as webp bytes
    img = trans_fn.resize(img, size, resample)
    img = trans_fn.center_crop(img, size)
    buffer = BytesIO()
    img.save(buffer, format="webp", quality=quality)
    val = buffer.getvalue()

    return val


def resize_multiple(img,
                    sizes=(128, 256, 512, 1024),
                    resample=Image.LANCZOS,
                    quality=100):
    imgs = []

    for size in sizes:
        imgs.append(resize_and_convert(img, size, resample, quality))

    return imgs


def resize_worker(idx, img, sizes, resample):
    img = img.convert("RGB")
    out = resize_multiple(img, sizes=sizes, resample=resample)
    return idx, out


from torch.utils.data import Dataset, DataLoader


class ConvertDataset(Dataset):
    def __init__(self, data) -> None:
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, _ = self.data[index]
        img_bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90)
        return img_bytes


if __name__ == "__main__":
    """
    Convert LSUN's original LMDB into our LMDB layout, which is more performant.
    """
    # path to the original LSUN lmdb
    src_path = 'datasets/bedroom_train_lmdb'
    out_path = 'datasets/bedroom256.lmdb'

    dataset = LSUNClass(root=os.path.expanduser(src_path))
    dataset = ConvertDataset(dataset)
    loader = DataLoader(dataset,
                        batch_size=50,
                        num_workers=12,
                        collate_fn=lambda x: x,
                        shuffle=False)

    target = os.path.expanduser(out_path)
    if os.path.exists(target):
        shutil.rmtree(target)

    with lmdb.open(target, map_size=1024**4, readahead=False) as env:
        with tqdm(total=len(dataset)) as progress:
            i = 0
            for batch in loader:
                with env.begin(write=True) as txn:
                    for img in batch:
                        # keys look like b"256-0000000"
                        key = f"{256}-{str(i).zfill(7)}".encode("utf-8")
                        txn.put(key, img)
                        i += 1
                        progress.update()

        with env.begin(write=True) as txn:
            # record the total number of images under the "length" key
            txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
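
Assuming the original LSUN bedroom LMDB has already been downloaded to `datasets/bedroom_train_lmdb` (e.g. with the download script from https://github.com/fyu/lsun), the conversion is then just `python data_resize_bedroom.py`; the source and output paths are hardcoded at the top of the `__main__` block.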
