Commit: update
phizaz committed Mar 17, 2022
1 parent 829e04a commit a9c4dde
Showing 15 changed files with 697 additions and 1,328 deletions.
101 changes: 6 additions & 95 deletions config.py
@@ -184,7 +184,7 @@ class TrainConfig(BaseConfig):
T_sampler: str = 'uniform'
T: int = 1_000
total_samples: int = 10_000_000
warmup: int = 5000
warmup: int = 0
pretrain: PretrainConfig = None
continue_from: PretrainConfig = None
eval_programs: Tuple[str] = None
@@ -194,107 +194,18 @@ class TrainConfig(BaseConfig):
use_cache_dataset: bool = False
data_cache_dir: str = os.path.expanduser('~/cache')
work_cache_dir: str = os.path.expanduser('~/mycache')

# data_cache_dir: str = os.path.expanduser('/scratch/konpat')
# work_cache_dir: str = os.path.expanduser('/scratch/konpat')
name: str = ''

def __post_init__(self):
self.batch_size_eval = self.batch_size_eval or self.batch_size
self.data_val_name = self.data_val_name or self.data_name

@property
def name(self):
self.make_model_conf()
names = []
tmp = f'{self.data_name}{self.img_size}-bs{self.batch_size}'
if self.accum_batches > 1:
tmp += f'accum{self.accum_batches}'
if self.optimizer != OptimizerType.adam:
tmp += f'-{self.optimizer.value}lr{self.lr}'
else:
tmp += f'-lr{self.lr}'
if self.weight_decay > 0:
tmp += f'wd{self.weight_decay}'
if self.grad_clip != 1:
if self.grad_clip < 0:
tmp += '-noclip'
else:
tmp += f'-clip{self.grad_clip}'
if self.warmup != 5000:
tmp += f'-warmup{self.warmup}'

if self.train_mode.is_manipulate():
tmp += f'_mani{self.manipulate_mode.value}'
if self.manipulate_mode.is_single_class():
tmp += f'-{self.manipulate_cls}'
if self.manipulate_mode.is_fewshot():
tmp += f'-{self.manipulate_shots}shots'
if self.manipulate_znormalize:
tmp += '-znorm'
if self.manipulate_mode.is_fewshot():
tmp += f'-seed{self.manipulate_seed}'

if self.train_mode.is_diffusion():
tmp += f'_ddpm-T{self.T}-Tgen{self.T_eval}'
if self.diffusion_type == 'default':
tmp += '-default'
elif self.diffusion_type == 'beatgans':
tmp += f'-beatgans-gen{self.beatgans_gen_type.value}'
if self.beta_scheduler != 'linear':
tmp += f'-beta{self.beta_scheduler}'
if self.beatgans_model_mean_type != ModelMeanType.eps:
tmp += f'-pred{self.beatgans_model_mean_type.value}'
if self.beatgans_loss_type != LossType.mse:
tmp += f'-loss{self.beatgans_loss_type.value}'
if self.beatgans_loss_type == LossType.mse_var_weighted:
tmp += f'{self.beatgans_model_mse_weight_type.value}'
else:
if self.beatgans_model_mean_type == ModelMeanType.start_x:
tmp += f'-weight{self.beatgans_xstart_weight_type.value}'
if self.beatgans_model_var_type != ModelVarType.fixed_large:
tmp += f'-var{self.beatgans_model_var_type.value}'

if self.train_mode.use_latent_net():
# latent diffusion configs
tmp += f'_latentddpm-Tgen{self.latent_T_eval}'
if self.latent_beta_scheduler != 'linear':
tmp += f'-beta{self.latent_beta_scheduler}'
tmp += f'-gen{self.latent_gen_type.value}'
if self.latent_model_mean_type != ModelMeanType.eps:
tmp += f'-pred{self.latent_model_mean_type.value}'
if self.latent_loss_type != LossType.mse:
tmp += f'-loss{self.latent_loss_type.value}'
if self.latent_loss_type == LossType.mse_var_weighted:
tmp += f'{self.latent_model_mse_weight_type.value}'
else:
if self.latent_model_mean_type == ModelMeanType.start_x:
tmp += f'-weight{self.latent_xstart_weight_type.value}'
if self.latent_model_var_type != ModelVarType.fixed_large:
tmp += f'-var{self.latent_model_var_type.value}'

if self.train_mode.is_latent_diffusion():
if self.latent_znormalize:
tmp += '-znorm'
if self.latent_clip_sample:
tmp += '-clip'
if self.latent_unit_normalize:
tmp += '-unit'

if self.ema_decay != 0.9999 and not self.train_mode.is_manipulate():
tmp += f'-ema{self.ema_decay}'

if self.fp16:
tmp += '_fp16'

if self.pretrain is not None:
tmp += f'_pt{self.pretrain.name}'

if self.continue_from is not None:
tmp += f'_contd{self.continue_from.name}'

names.append(tmp)
names.append(self.model_conf.name)
return '/'.join(names) + self.postfix
# @property
# def name(self):
# # self.make_model_conf()
# raise NotImplementedError()

def scale_up_gpus(self, num_gpus, num_nodes=1):
self.eval_ema_every_samples *= num_gpus * num_nodes
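
With the long auto-generated name property deleted, the run name now appears to come from the plain name field (name: str = ''), presumably set explicitly by each training template. A minimal sketch of that pattern; the template name ffhq128_autoenc_130M appears elsewhere in this commit, but its body here is illustrative and ffhq128_autoenc_base is an assumption:

    # Sketch only: a template setting the run name explicitly instead of
    # deriving it from hyperparameters (function bodies here are assumptions).
    def ffhq128_autoenc_130M():
        conf = ffhq128_autoenc_base()          # assumed shared base template
        conf.total_samples = 130_000_000
        conf.name = 'ffhq128_autoenc_130M'     # explicit name via the new plain field
        return conf

    conf = ffhq128_autoenc_130M()
    print(conf.name)   # used elsewhere to build f'checkpoints/{conf.name}/last.ckpt'
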
28 changes: 1 addition & 27 deletions diffusion/base.py
@@ -355,34 +355,8 @@ def training_losses(self,
else:
raise NotImplementedError(self.loss_type)

if self.conf.model_type == ModelType.vaeddpm:
# also include the vae prior loss
model_forward: AutoencReturn
# (n, c)
loss_kl = -0.5 * (1 + model_forward.cond_logvar -
model_forward.cond_mu**2 -
model_forward.cond_logvar.exp())
# factor between latent code and pixels
# because vae losses are "sum" over dimensions
dim_factor = loss_kl[0].numel() / x_t[0].numel()
# (n, )
loss_kl = mean_flat(loss_kl) * dim_factor
terms['vae'] = loss_kl
terms['loss'] = terms['loss'] + terms['vae']
elif self.conf.model_type == ModelType.mmdddpm:
# using mmd loss on the latent
terms['mmd'] = self.mmd_loss(model, model_forward.cond)
terms['loss'] = terms['loss'] + terms['mmd']

return terms

def mmd_loss(self, model: Model, cond):
cond = cond.float()
square_mmd = SquaredMMD(alphas=self.conf.mmd_alphas)
prior = th.randn_like(cond)
prior = model.noise_to_cond(prior)
return square_mmd(cond, prior) * self.conf.mmd_coef

def sample(self,
model: Model,
shape=None,
@@ -935,7 +909,7 @@ def ddim_reverse_sample_loop(
T.append(t)

return {
# x0' "
# xT "
'sample': sample,
# (1, ..., T)
'sample_t': sample_t,
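
The corrected comment above states that ddim_reverse_sample_loop ends at the fully noised code xT rather than x0. A short sketch of how its two returned keys might be used; sampler, model, x0, and cond stand for objects constructed elsewhere:

    # Sketch only: deterministic DDIM inversion of an image x0 into its noise code.
    out = sampler.ddim_reverse_sample_loop(model, x0, model_kwargs={'cond': cond})
    x_T = out['sample']       # the final noised code x_T (per the fixed comment)
    x_ts = out['sample_t']    # intermediate x_t for t = 1, ..., T
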
47 changes: 47 additions & 0 deletions experiment.py
@@ -91,6 +91,51 @@ def denormalize(self, cond):
self.device)
return cond

def sample(self, N, device):
noise = torch.randn(N,
3,
self.conf.img_size,
self.conf.img_size,
device=device)
pred_img = render_uncondition(
self.conf,
self.ema_model,
noise,
sampler=self.eval_sampler,
latent_sampler=self.eval_latent_sampler,
conds_mean=self.conds_mean,
conds_std=self.conds_std,
)
pred_img = (pred_img + 1) / 2
return pred_img

def render(self, noise, cond=None):
if cond is not None:
pred_img = render_condition(self.conf,
self.ema_model,
noise,
sampler=self.eval_sampler,
cond=cond)
else:
pred_img = render_uncondition(self.conf,
self.ema_model,
noise,
sampler=self.eval_sampler,
latent_sampler=None)
pred_img = (pred_img + 1) / 2
return pred_img

def encode(self, x):
# TODO:
assert self.conf.model_type.has_autoenc()
cond = self.ema_model.encoder.forward(x)
return cond

def encode_stochastic(self, x, cond):
out = self.eval_sampler.ddim_reverse_sample_loop(
self.ema_model, x, model_kwargs={'cond': cond})
return out['sample']

def forward(self, noise=None, x_start=None, ema_model: bool = False):
with amp.autocast(False):
if ema_model:
@@ -878,8 +923,10 @@ def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'):
every_n_train_steps=conf.save_every_samples //
conf.batch_size_effective)
checkpoint_path = f'{conf.logdir}/last.ckpt'
print('ckpt path:', checkpoint_path)
if os.path.exists(checkpoint_path):
resume = checkpoint_path
print('resume!')
else:
if conf.continue_from is not None:
# continue from a checkpoint
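
The methods added above (sample, render, encode, encode_stochastic) read like the model's inference API. A hedged usage sketch that round-trips one image through them; the template name, checkpoint path, and loading details are assumptions rather than part of this commit:

    # Sketch only, assuming a trained autoencoder checkpoint exists.
    import torch
    from experiment import LitModel
    from run_templates import ffhq128_autoenc_130M   # assumed to be defined there

    conf = ffhq128_autoenc_130M()
    model = LitModel(conf)
    state = torch.load(f'checkpoints/{conf.name}/last.ckpt', map_location='cpu')
    model.load_state_dict(state['state_dict'], strict=False)   # assumed checkpoint format
    model = model.to('cuda:0').eval()

    # Stand-in for a real image batch scaled to [-1, 1].
    x = torch.randn(1, 3, conf.img_size, conf.img_size, device='cuda:0')
    cond = model.encode(x)                  # semantic latent from the EMA encoder
    xT = model.encode_stochastic(x, cond)   # stochastic code via DDIM inversion
    recon = model.render(xT, cond=cond)     # decoded image scaled to [0, 1]
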
2 changes: 1 addition & 1 deletion gen_autoenc.py
@@ -2,7 +2,7 @@
from run_templates import *
from tqdm.autonotebook import tqdm

conf = ffhq128_autoenc_200M()
conf = ffhq128_autoenc_130M()
# conf = ffhq256_autoenc()
conf.device = 'cuda:0'
print(conf.name)
4 changes: 2 additions & 2 deletions gen_conditional.py
@@ -246,7 +246,7 @@ def cond_sample_rejection(
return pred_img

def load(self):
state = torch.load(f'log-latent/{self.conf.name}/last.ckpt',
state = torch.load(f'checkpoints/{self.conf.name}/last.ckpt',
map_location='cpu')
print('main step:', state['global_step'])
model = LitModel(self.conf)
@@ -255,7 +255,7 @@ def load(self):
return model

def load_cls(self):
state = torch.load(f'logs-cls/{self.cls_conf.name}/last.ckpt',
state = torch.load(f'checkpoints/{self.cls_conf.name}/last.ckpt',
map_location='cpu')
print('latent step:', state['global_step'])
model = ClsModel(self.cls_conf)
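
Both loaders above now read from a single checkpoints/ directory keyed by the config name, replacing the earlier log-latent/ and logs-cls/ split. The assumed layout (the autoencoder name is illustrative; ffhq256_autoenc_cls matches the copy step in main.py below):

    # checkpoints/
    #     ffhq128_autoenc_130M/last.ckpt   # diffusion autoencoder, read by load()
    #     ffhq256_autoenc_cls/last.ckpt    # classifier, read by load_cls()
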
37 changes: 33 additions & 4 deletions main.py
@@ -1,7 +1,9 @@
from run_cls_templates import cls_ffhq128_autoenc, cls_ffhq256_autoenc
from run_latent_templates import bedroom128_autoenc_latent, celeba64d2c_autoenc_latent, ffhq128_autoenc_latent, ffhq256_autoenc_latent, horse128_autoenc_latent
from run_templates import *

if __name__ == '__main__':
gpus = [0]
gpus = [3]

# conf = horse128_cosine_autoenc_thinner_deep_morech()
# conf.net_beatgans_resnet_use_inlayers_cond = True
@@ -22,9 +24,36 @@
# net_beatgans_three_cond=True,
# )

# conf = ffhq64_autoenc_48M()
# conf = ffhq128_autoenc()
# conf = ffhq128_autoenc_72M()
# conf = ffhq128_autoenc_130M()
# conf = ffhq128_ddpm_130M()
# conf = ffhq128_ddpm_72M()
# conf = ffhq256_autoenc()
# conf = bedroom128_autoenc()
# conf = bedroom128_ddpm()
# conf = horse128_autoenc_130M()
# conf = horse128_ddpm_130M()
# conf = celeba64d2c_autoenc()
# conf = celeba64d2c_ddpm()
# conf = ffhq128_autoenc_latent()
# conf = celeba64d2c_autoenc_latent()
# conf = horse128_autoenc_latent()
# conf = bedroom128_autoenc_latent()
# conf = ffhq256_autoenc_latent()
# conf = cls_ffhq128_all()
# conf = cls_ffhq256_all()

from shutil import copy

src = f'logs-cls/{conf.name}/last.ckpt'
tgt = f'checkpoints/ffhq256_autoenc_cls/last.ckpt'
if not os.path.exists(os.path.dirname(tgt)):
os.makedirs(os.path.dirname(tgt))
print('copying ..')
print(src, tgt)
copy(src, tgt)

# conf.batch_size = 8
# conf = pretrain_ffhq64_autoenc48M()
# conf = latent_diffusion_config(conf)
# # conf = latent_unet4_512(conf)
@@ -61,5 +90,5 @@
# conf.eval_programs = ['infer']
# conf.eval_programs = ['fid(20,20)']

train(conf, gpus=gpus)
# train(conf, gpus=gpus)
# train(conf, gpus=gpus, mode='eval')
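
main.py now keeps the released training templates as a block of comments plus a one-off step that copies a classifier checkpoint into checkpoints/; launching a run appears to amount to uncommenting one template and the train call. A minimal sketch, with the GPU index and template choice purely illustrative:

    # Sketch only: pick one template, optionally scale for the GPU count, then train.
    gpus = [0]                       # any visible CUDA device index
    conf = ffhq128_autoenc_130M()    # one of the templates listed above
    conf.scale_up_gpus(len(gpus))    # helper shown in the config.py diff
    train(conf, gpus=gpus)           # or: train(conf, gpus=gpus, mode='eval')
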
(The remaining 9 of the 15 changed files were not rendered on this page.)
