From 2a7c74571963ce19124557273d41119853166d74 Mon Sep 17 00:00:00 2001 From: Konpat Preechakul Date: Fri, 18 Mar 2022 08:57:51 +0700 Subject: [PATCH] update --- config.py | 61 ++------------------------------------------------- experiment.py | 9 ++------ 2 files changed, 4 insertions(+), 66 deletions(-) diff --git a/config.py b/config.py index d103418..50b000a 100644 --- a/config.py +++ b/config.py @@ -202,11 +202,6 @@ def __post_init__(self): self.batch_size_eval = self.batch_size_eval or self.batch_size self.data_val_name = self.data_val_name or self.data_name - # @property - # def name(self): - # # self.make_model_conf() - # raise NotImplementedError() - def scale_up_gpus(self, num_gpus, num_nodes=1): self.eval_ema_every_samples *= num_gpus * num_nodes self.eval_every_samples *= num_gpus * num_nodes @@ -253,19 +248,7 @@ def generate_dir(self): return f'{self.work_cache_dir}/gen_images/{self.name}' def _make_diffusion_conf(self, T=None): - if self.diffusion_type == 'default': - assert T == self.T - assert self.beta_scheduler == 'linear' - return DiffusionDefaultConfig(beta_1=self.def_beta_1, - beta_T=self.def_beta_T, - T=self.T, - img_size=self.img_size, - mean_type=self.def_mean_type, - var_type=self.def_var_type, - model_type=self.model_type, - kl_coef=self.kl_coef, - fp16=self.fp16) - elif self.diffusion_type == 'beatgans': + if self.diffusion_type == 'beatgans': # can use T < self.T for evaluation # follows the guided-diffusion repo conventions # t's are evenly spaced @@ -326,15 +309,7 @@ def _make_latent_diffusion_conf(self, T=None): @property def model_out_channels(self): - if self.diffusion_type == 'beatgans': - if self.beatgans_model_var_type in [ - ModelVarType.learned, ModelVarType.learned_range - ]: - return 6 - else: - return 3 - else: - return 3 + return 3 def make_T_sampler(self): if self.T_sampler == 'uniform': @@ -373,43 +348,11 @@ def make_dataset(self, path=None, **kwargs): return Horse_lmdb(path=path or self.data_path, image_size=self.img_size, **kwargs) - elif self.data_name == 'celebalmdb': - return CelebAlmdb(path=path or self.data_path, - image_size=self.img_size, - original_resolution=None, - crop_d2c=True, - **kwargs) - elif self.data_name == 'celebaalignlmdb': - return CelebAlmdb(path=path or self.data_path, - image_size=self.img_size, - **kwargs) - - elif self.data_name == 'celebahq': - return CelebHQLMDB(path=path or self.data_path, - image_size=self.img_size, - **kwargs) else: return ImageDataset(folder=path or self.data_path, image_size=self.img_size, **kwargs) - def make_test_dataset(self, **kwargs): - if self.data_val_name == 'ffhqlmdbsplit256': - print('test on ffhq split') - return FFHQlmdb(path=data_paths['ffhqlmdbsplit256'][0], - original_resolution=256, - image_size=self.img_size, - split='test', - **kwargs) - elif self.data_val_name == 'celebhq': - print('test on celebhq') - return CelebHQLMDB(path=data_paths['celebahq'][0], - original_resolution=256, - image_size=self.img_size, - **kwargs) - else: - return None - def make_loader(self, dataset, shuffle: bool, diff --git a/experiment.py b/experiment.py index 24e6352..2488e0f 100755 --- a/experiment.py +++ b/experiment.py @@ -164,13 +164,8 @@ def setup(self, stage=None) -> None: self.train_data = self.conf.make_dataset() print('train data:', len(self.train_data)) - self.val_data = self.conf.make_test_dataset() - if self.val_data is None: - # val data is provided, use the train data - self.val_data = self.train_data - else: - # val data is provided - print('val data:', len(self.val_data)) + self.val_data = self.train_data + print('val data:', len(self.val_data)) def _train_dataloader(self, drop_last=True): """