
Commit

update
phizaz committed Mar 18, 2022
1 parent 2a7c745 · commit 40f8568
Showing 15 changed files with 133 additions and 729 deletions.
17 changes: 8 additions & 9 deletions README.md
@@ -18,9 +18,15 @@ For manipulation: `manipulate.ipynb`

### Checkpoints

-Checkpoints ought to be put into a separate directory `checkpoints`.
+We provide checkpoints for the following models:

-The directory tree may look like:
+1. DDIM: **FFHQ128** ([72M](https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK?usp=sharing), [130M](https://drive.google.com/drive/folders/17T5YJXpYdgE6cWltN8gZFxRsJzpVxnLh?usp=sharing)), [**Bedroom128**](https://drive.google.com/drive/folders/19s-lAiK7fGD5Meo5obNV5o0L3MfqU0Sk?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1PiC5JWLcd8mZW9cghDCR0V4Hx0QCXOor?usp=sharing)
+2. DiffAE (autoencoding only): [**FFHQ256**](https://drive.google.com/drive/folders/1hTP9QbYXwv_Nl5sgcZNH0yKprJx7ivC5?usp=sharing), **FFHQ128** ([72M](https://drive.google.com/drive/folders/15QHmZP1G5jEMh80R1Nbtdb4ZKb6VvfII?usp=sharing), [130M](https://drive.google.com/drive/folders/1UlwLwgv16cEqxTn7g-V2ykIyopmY_fVz?usp=sharing)), [**Bedroom128**](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
+3. DiffAE (with latent DPM, can sample): [**FFHQ256**](https://drive.google.com/drive/folders/1MonJKYwVLzvCFYuVhp-l9mChq5V2XI6w?usp=sharing), [**FFHQ128**](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing), [**Bedroom128**](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
+4. DiffAE's classifiers (for manipulation): [**FFHQ256's latent on CelebAHQ**](https://drive.google.com/drive/folders/1QGkTfvNhgi_TbbV8GbX1Emrp0lStsqLj?usp=sharing), [**FFHQ128's latent on CelebAHQ**](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing)

+Checkpoints ought to be put into a separate directory `checkpoints`.
+Download the checkpoints and put them into the `checkpoints` directory; it should look like this:

```
checkpoints/
@@ -33,13 +39,6 @@ checkpoints/
- ...
```

-We provide checkpoints for the following models:
-
-1. DDIM: FFHQ128 ([72M](https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK?usp=sharing), [130M](https://drive.google.com/drive/folders/17T5YJXpYdgE6cWltN8gZFxRsJzpVxnLh?usp=sharing)), [Bedroom128](https://drive.google.com/drive/folders/19s-lAiK7fGD5Meo5obNV5o0L3MfqU0Sk?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1PiC5JWLcd8mZW9cghDCR0V4Hx0QCXOor?usp=sharing)
-2. DiffAE (autoencoding only): [FFHQ256](https://drive.google.com/drive/folders/1hTP9QbYXwv_Nl5sgcZNH0yKprJx7ivC5?usp=sharing), FFHQ128 ([72M](https://drive.google.com/drive/folders/15QHmZP1G5jEMh80R1Nbtdb4ZKb6VvfII?usp=sharing), [130M](https://drive.google.com/drive/folders/1UlwLwgv16cEqxTn7g-V2ykIyopmY_fVz?usp=sharing)), [Bedroom128](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
-3. DiffAE (with latent DPM, can sample): [FFHQ256](https://drive.google.com/drive/folders/1MonJKYwVLzvCFYuVhp-l9mChq5V2XI6w?usp=sharing), [FFHQ128](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing), [Bedroom128](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [Horse128](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
-4. DiffAE's classifiers (for manipulation): [FFHQ256's latent on CelebAHQ](https://drive.google.com/drive/folders/1QGkTfvNhgi_TbbV8GbX1Emrp0lStsqLj?usp=sharing), [FFHQ128's latent on CelebAHQ](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing)


### LMDB Datasets

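For reference, a minimal sketch of pulling one of the checkpoint folders above into `checkpoints/` with the third-party `gdown` package (an assumption: `gdown` is not part of this repo, and the output folder name here is hypothetical):

```python
# Sketch only: assumes `pip install gdown`; gdown.download_folder fetches a
# public Google Drive folder such as the ones linked above.
import gdown

# e.g. the DDIM FFHQ128 (72M) folder from the list above
url = 'https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK'
gdown.download_folder(url=url, output='checkpoints/ffhq128_ddpm', quiet=False)
```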
61 changes: 0 additions & 61 deletions choices.py
@@ -102,11 +102,6 @@ def can_sample(self):
        return self in [ModelType.ddpm]


-class ChamferType(Enum):
-    chamfer = 'chamfer'
-    stochastic = 'stochastic'
-
-
class ModelName(Enum):
    """
    List of all supported model classes
@@ -116,24 +111,12 @@ class ModelName(Enum):
    beatgans_autoenc = 'beatgans_autoenc'


-class EncoderName(Enum):
-    """
-    List of all encoders for ddpm models
-    """
-
-    v1 = 'v1'
-    v2 = 'v2'
-
-
class ModelMeanType(Enum):
    """
    Which type of output the model predicts.
    """

-    prev_x = 'x_prev'  # the model predicts x_{t-1}
-    start_x = 'x_start'  # the model predicts x_0
    eps = 'eps'  # the model predicts epsilon
-    scaled_start_x = 'scaledxstart'  # the model predicts sqrt(alphacum) x_0

class ModelVarType(Enum):
@@ -144,59 +127,15 @@ class ModelVarType(Enum):
    values between FIXED_SMALL and FIXED_LARGE, making its job easier.
    """

-    # learned directly
-    learned = 'learned'
    # posterior beta_t
    fixed_small = 'fixed_small'
    # beta_t
    fixed_large = 'fixed_large'
-    # predict values between FIXED_SMALL and FIXED_LARGE, making its job easier
-    learned_range = 'learned_range'


class LossType(Enum):
    mse = 'mse'  # use raw MSE loss (and KL when learning variances)
    l1 = 'l1'
-    # mse weighted by the variance, somewhat like in kl
-    mse_var_weighted = 'mse_weighted'
-    mse_rescaled = 'mse_rescaled'  # use raw MSE loss (with RESCALED_KL when learning variances)
-    kl = 'kl'  # use the variational lower-bound
-    kl_rescaled = 'kl_rescaled'  # like KL, but rescale to estimate the full VLB
-
-    def is_vb(self):
-        return self == LossType.kl or self == LossType.kl_rescaled
-
-
-class MSEWeightType(Enum):
-    # use the ddpm's default variance (either analytical or learned)
-    var = 'var'
-    # optimal variance by deriving the min kl per image (based on mse of epsilon)
-    # = small sigma + mse
-    var_min_kl_img = 'varoptimg'
-    # optimal variance regardless of the posterior sigmas
-    # = mse only
-    var_min_kl_mse_img = 'varoptmseimg'
-    # same as above, but based on the mse of the mu of xprev
-    var_min_kl_xprev_img = 'varoptxprevimg'
-
-
-class XStartWeightType(Enum):
-    # weights for the mse of xstart
-    # unweighted xstart
-    uniform = 'uniform'
-    # reciprocal of 1 - alpha_bar
-    reciprocal_alphabar = 'recipalpha'
-    # same as above, but not exceeding mse = 1
-    reciprocal_alphabar_safe = 'recipalphasafe'
-    # turning x0 into eps and using mse(eps)
-    eps = 'eps'
-    # same as above, but without turning into eps
-    eps2 = 'eps2'
-    # same as above, but not exceeding mse = 1
-    eps2_safe = 'eps2safe'
-    eps_huber = 'epshuber'
-    unit_mse_x0 = 'unitmsex0'
-    unit_mse_eps = 'unitmseeps'


class GenerativeType(Enum):
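To make the surviving options concrete, here is an illustrative sketch (not code from this repo) of how a diffusion trainer might branch on the remaining `ModelMeanType` and `LossType` values; `pred` and `target` are assumed to be tensors of the same shape:

```python
# Illustrative sketch: after this commit the model predicts only eps (the
# added noise), and the loss is plain MSE or L1.
import torch.nn.functional as F
from choices import LossType, ModelMeanType

def training_target(mean_type: ModelMeanType, x_start, noise):
    if mean_type == ModelMeanType.eps:
        return noise  # regress the noise that was added to x_start
    raise NotImplementedError(mean_type)

def compute_loss(loss_type: LossType, pred, target):
    if loss_type == LossType.mse:
        return F.mse_loss(pred, target)
    if loss_type == LossType.l1:
        return F.l1_loss(pred, target)
    raise NotImplementedError(loss_type)
```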
64 changes: 15 additions & 49 deletions config.py
@@ -72,33 +72,24 @@ class TrainConfig(BaseConfig):
    autoenc_mid_attn: bool = True
    batch_size: int = 16
    batch_size_eval: int = None
-    beatgans_gen_type: GenerativeType = GenerativeType.ddpm
+    beatgans_gen_type: GenerativeType = GenerativeType.ddim
    beatgans_loss_type: LossType = LossType.mse
    beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
    beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
-    beatgans_model_mse_weight_type: MSEWeightType = MSEWeightType.var
-    beatgans_xstart_weight_type: XStartWeightType = XStartWeightType.uniform
    beatgans_rescale_timesteps: bool = False
    latent_infer_path: str = None
    latent_znormalize: bool = False
-    latent_gen_type: GenerativeType = GenerativeType.ddpm
+    latent_gen_type: GenerativeType = GenerativeType.ddim
    latent_loss_type: LossType = LossType.mse
    latent_model_mean_type: ModelMeanType = ModelMeanType.eps
    latent_model_var_type: ModelVarType = ModelVarType.fixed_large
-    latent_model_mse_weight_type: MSEWeightType = MSEWeightType.var
-    latent_xstart_weight_type: XStartWeightType = XStartWeightType.uniform
    latent_rescale_timesteps: bool = False
    latent_T_eval: int = 1_000
    latent_clip_sample: bool = False
    latent_beta_scheduler: str = 'linear'
    beta_scheduler: str = 'linear'
-    data_name: str = 'ffhq'
+    data_name: str = ''
    data_val_name: str = None
-    def_beta_1: float = 1e-4
-    def_beta_T: float = 0.02
-    def_mean_type: str = 'epsilon'
-    def_var_type: str = 'fixedlarge'
-    device: str = 'cuda:0'
    diffusion_type: str = None
    dropout: float = 0.1
    ema_decay: float = 0.9999
@@ -109,10 +100,7 @@ class TrainConfig(BaseConfig):
    fp16: bool = False
    grad_clip: float = 1
    img_size: int = 64
-    kl_coef: float = None
-    chamfer_coef: float = 1
-    chamfer_type: ChamferType = ChamferType.chamfer
-    lr: float = 0.0002
+    lr: float = 0.0001
    optimizer: OptimizerType = OptimizerType.adam
    weight_decay: float = 0
    model_conf: ModelConfig = None
@@ -124,49 +112,32 @@ class TrainConfig(BaseConfig):
    net_beatgans_embed_channels: int = 512
    net_resblock_updown: bool = True
    net_enc_use_time: bool = False
-    net_enc_pool: str = 'depthconv'
-    net_enc_pool_tail_layer: int = None
+    net_enc_pool: str = 'adaptivenonzero'
    net_beatgans_gradient_checkpoint: bool = False
    net_beatgans_resnet_two_cond: bool = False
    net_beatgans_resnet_use_zero_module: bool = True
    net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
    net_beatgans_resnet_cond_channels: int = None
-    mmd_alphas: Tuple[float] = (0.5, )
-    mmd_coef: float = 0.1
-    latent_detach: bool = True
-    latent_unit_normalize: bool = False
    net_ch_mult: Tuple[int] = None
    net_ch: int = 64
    net_enc_attn: Tuple[int] = None
    net_enc_k: int = None
-    net_enc_name: EncoderName = EncoderName.v1
    # number of resblocks for the encoder (half-unet)
    net_enc_num_res_blocks: int = 2
-    net_enc_tail_depth: int = 2
    net_enc_channel_mult: Tuple[int] = None
    net_enc_grad_checkpoint: bool = False
    net_autoenc_stochastic: bool = False
    net_latent_activation: Activation = Activation.silu
-    net_latent_attn_resolutions: Tuple[int] = tuple()
-    net_latent_blocks: int = None
-    net_latent_channel_mult: Tuple[int] = (1, 2, 4)
-    net_latent_cond_both: bool = True
    net_latent_condition_bias: float = 0
    net_latent_dropout: float = 0
    net_latent_layers: int = None
    net_latent_net_last_act: Activation = Activation.none
    net_latent_net_type: LatentNetType = LatentNetType.none
    net_latent_num_hid_channels: int = 1024
    net_latent_num_res_blocks: int = 2
    net_latent_num_time_layers: int = 2
-    net_latent_pooling: str = 'linear'
-    net_latent_project_size: int = 4
-    net_latent_residual: bool = False
    net_latent_skip_layers: Tuple[int] = None
    net_latent_time_emb_channels: int = 64
-    net_latent_time_layer_init: bool = False
-    net_latent_unpool: str = 'conv'
-    net_latent_use_mid_attn: bool = True
    net_latent_use_norm: bool = False
    net_latent_time_last_act: bool = False
    net_num_res_blocks: int = 2
@@ -190,12 +161,11 @@ class TrainConfig(BaseConfig):
    eval_programs: Tuple[str] = None
    # if present load the checkpoint from this path instead
    eval_path: str = None
-    base_dir: str = 'logs'
+    base_dir: str = 'checkpoints'
    use_cache_dataset: bool = False
    data_cache_dir: str = os.path.expanduser('~/cache')
    work_cache_dir: str = os.path.expanduser('~/mycache')
-    # data_cache_dir: str = os.path.expanduser('/scratch/konpat')
-    # work_cache_dir: str = os.path.expanduser('/scratch/konpat')
+    # to be overridden
    name: str = ''

    def __post_init__(self):
Expand Down Expand Up @@ -265,15 +235,11 @@ def _make_diffusion_conf(self, T=None):
                betas=get_named_beta_schedule(self.beta_scheduler, self.T),
                model_mean_type=self.beatgans_model_mean_type,
                model_var_type=self.beatgans_model_var_type,
-                model_mse_weight_type=self.beatgans_model_mse_weight_type,
-                xstart_weight_type=self.beatgans_xstart_weight_type,
                loss_type=self.beatgans_loss_type,
                rescale_timesteps=self.beatgans_rescale_timesteps,
                use_timesteps=space_timesteps(num_timesteps=self.T,
                                              section_counts=section_counts),
                fp16=self.fp16,
-                mmd_alphas=self.mmd_alphas,
-                mmd_coef=self.mmd_coef,
            )
        else:
            raise NotImplementedError()
@@ -298,8 +264,6 @@ def _make_latent_diffusion_conf(self, T=None):
                betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
                model_mean_type=self.latent_model_mean_type,
                model_var_type=self.latent_model_var_type,
-                model_mse_weight_type=self.latent_model_mse_weight_type,
-                xstart_weight_type=self.latent_xstart_weight_type,
                loss_type=self.latent_loss_type,
                rescale_timesteps=self.latent_rescale_timesteps,
                use_timesteps=space_timesteps(num_timesteps=self.T,
@@ -348,10 +312,15 @@ def make_dataset(self, path=None, **kwargs):
            return Horse_lmdb(path=path or self.data_path,
                              image_size=self.img_size,
                              **kwargs)
+        elif self.data_name == 'celebalmdb':
+            # always use d2c crop
+            return CelebAlmdb(path=path or self.data_path,
+                              image_size=self.img_size,
+                              original_resolution=None,
+                              crop_d2c=True,
+                              **kwargs)
        else:
-            return ImageDataset(folder=path or self.data_path,
-                                image_size=self.img_size,
-                                **kwargs)
+            raise NotImplementedError()

    def make_loader(self,
                    dataset,
@@ -431,8 +400,6 @@ def make_model_conf(self):
                dropout=self.net_latent_dropout,
                last_act=self.net_latent_net_last_act,
                num_time_layers=self.net_latent_num_time_layers,
-                time_layer_init=self.net_latent_time_layer_init,
-                residual=self.net_latent_residual,
                time_last_act=self.net_latent_time_last_act,
            )
        else:
@@ -447,7 +414,6 @@ def make_model_conf(self):
            embed_channels=self.net_beatgans_embed_channels,
            enc_out_channels=self.style_ch,
            enc_pool=self.net_enc_pool,
-            enc_pool_tail_layer=self.net_enc_pool_tail_layer,
            enc_num_res_block=self.net_enc_num_res_blocks,
            enc_channel_mult=self.net_enc_channel_mult,
            enc_grad_checkpoint=self.net_enc_grad_checkpoint,
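A quick way to see the changed defaults is to instantiate the config. A sketch, assuming `TrainConfig` is constructible with defaults alone (all fields shown in this diff have default values):

```python
# Sketch: verifies the defaults changed by this commit (run from the repo root).
from config import TrainConfig
from choices import GenerativeType

conf = TrainConfig()
assert conf.beatgans_gen_type == GenerativeType.ddim  # was GenerativeType.ddpm
assert conf.latent_gen_type == GenerativeType.ddim    # was GenerativeType.ddpm
assert conf.lr == 0.0001                              # was 0.0002
assert conf.base_dir == 'checkpoints'                 # was 'logs'
assert conf.data_name == ''                           # was 'ffhq'
```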
75 changes: 65 additions & 10 deletions dataset.py
@@ -217,6 +217,62 @@ def d2c_crop():
    return Crop(x1, x2, y1, y2)


+class CelebAlmdb(Dataset):
+    """
+    Also supports d2c crop.
+    """
+    def __init__(self,
+                 path,
+                 image_size,
+                 original_resolution=128,
+                 split=None,
+                 as_tensor: bool = True,
+                 do_augment: bool = True,
+                 do_normalize: bool = True,
+                 crop_d2c: bool = False,
+                 **kwargs):
+        self.original_resolution = original_resolution
+        self.data = BaseLMDB(path, original_resolution, zfill=7)
+        self.length = len(self.data)
+        self.crop_d2c = crop_d2c
+
+        if split is None:
+            self.offset = 0
+        else:
+            raise NotImplementedError()
+
+        if crop_d2c:
+            transform = [
+                d2c_crop(),
+                transforms.Resize(image_size),
+            ]
+        else:
+            transform = [
+                transforms.Resize(image_size),
+                transforms.CenterCrop(image_size),
+            ]
+
+        if do_augment:
+            transform.append(transforms.RandomHorizontalFlip())
+        if as_tensor:
+            transform.append(transforms.ToTensor())
+        if do_normalize:
+            transform.append(
+                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
+        self.transform = transforms.Compose(transform)
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, index):
+        assert index < self.length
+        index = index + self.offset
+        img = self.data[index]
+        if self.transform is not None:
+            img = self.transform(img)
+        return {'img': img, 'index': index}


class Horse_lmdb(Dataset):
    def __init__(self,
                 path=os.path.expanduser('datasets/horse256.lmdb'),
@@ -534,16 +590,15 @@ class CelebHQAttrDataset(Dataset):
    ]
    cls_to_id = {v: k for k, v in enumerate(id_to_cls)}

-    def __init__(
-        self,
-        path=os.path.expanduser('datasets/celebahq256.lmdb'),
-        image_size=None,
-        attr_path=os.path.expanduser(
-            'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
-        original_resolution=256,
-        do_augment: bool = False,
-        do_transform: bool = True,
-        do_normalize: bool = True):
+    def __init__(self,
+                 path=os.path.expanduser('datasets/celebahq256.lmdb'),
+                 image_size=None,
+                 attr_path=os.path.expanduser(
+                     'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
+                 original_resolution=256,
+                 do_augment: bool = False,
+                 do_transform: bool = True,
+                 do_normalize: bool = True):
        super().__init__()
        self.image_size = image_size
        self.data = BaseLMDB(path, original_resolution, zfill=5)
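A usage sketch for the new `CelebAlmdb` class; the LMDB path is a placeholder and must point to an LMDB built with this repo's tooling:

```python
# Hypothetical path; the constructor arguments and returned dict keys are
# taken from the CelebAlmdb class added above.
from dataset import CelebAlmdb

data = CelebAlmdb(path='datasets/celeba.lmdb',
                  image_size=64,
                  original_resolution=None,
                  crop_d2c=True)  # d2c crop + resize instead of center crop
item = data[0]
print(len(data), item['img'].shape, item['index'])
# -> <dataset length> torch.Size([3, 64, 64]) 0, with pixels in [-1, 1]
```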
[Diffs for the remaining 11 changed files are not shown.]
