Commit

update

phizaz committed Mar 18, 2022
1 parent 826d44f commit 2a7c745
Showing 2 changed files with 4 additions and 66 deletions.
config.py: 61 changes (2 additions, 59 deletions)
@@ -202,11 +202,6 @@ def __post_init__(self):
self.batch_size_eval = self.batch_size_eval or self.batch_size
self.data_val_name = self.data_val_name or self.data_name

# @property
# def name(self):
# # self.make_model_conf()
# raise NotImplementedError()

def scale_up_gpus(self, num_gpus, num_nodes=1):
self.eval_ema_every_samples *= num_gpus * num_nodes
self.eval_every_samples *= num_gpus * num_nodes
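For context, the intervals above are counted in training samples seen globally, so moving to more GPUs multiplies how many samples each optimizer step consumes; a minimal illustrative sketch of that scaling rule (not the repo's code):

# Illustrative only: sample-based intervals are multiplied by the device count
# so that the interval measured in optimizer steps stays roughly the same.
def scale_interval(every_samples: int, num_gpus: int, num_nodes: int = 1) -> int:
    return every_samples * num_gpus * num_nodes

print(scale_interval(200_000, num_gpus=4, num_nodes=2))  # 1600000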
@@ -253,19 +248,7 @@ def generate_dir(self):
return f'{self.work_cache_dir}/gen_images/{self.name}'

def _make_diffusion_conf(self, T=None):
if self.diffusion_type == 'default':
assert T == self.T
assert self.beta_scheduler == 'linear'
return DiffusionDefaultConfig(beta_1=self.def_beta_1,
beta_T=self.def_beta_T,
T=self.T,
img_size=self.img_size,
mean_type=self.def_mean_type,
var_type=self.def_var_type,
model_type=self.model_type,
kl_coef=self.kl_coef,
fp16=self.fp16)
elif self.diffusion_type == 'beatgans':
if self.diffusion_type == 'beatgans':
# can use T < self.T for evaluation
# follows the guided-diffusion repo conventions
# t's are evenly spaced
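The comments above describe evaluation-time respacing: a shorter chain of T timesteps is picked evenly from the full training schedule, following the guided-diffusion conventions. A minimal sketch of that even spacing (illustrative, not the repo's implementation):

# Illustrative only: choose eval_T timesteps evenly spaced over [0, train_T).
def evenly_spaced_timesteps(train_T: int, eval_T: int) -> list:
    stride = train_T / eval_T
    return [round(i * stride) for i in range(eval_T)]

print(evenly_spaced_timesteps(1000, 10))  # [0, 100, 200, ..., 900]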
@@ -326,15 +309,7 @@ def _make_latent_diffusion_conf(self, T=None):

@property
def model_out_channels(self):
if self.diffusion_type == 'beatgans':
if self.beatgans_model_var_type in [
ModelVarType.learned, ModelVarType.learned_range
]:
return 6
else:
return 3
else:
return 3
return 3

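For context on the deleted branch above: when the variance is learned (ModelVarType.learned or learned_range, as in guided-diffusion), the network predicts a second set of per-pixel values alongside the 3-channel mean, so the output width doubles to 6; after this commit only the 3-channel case remains. A small illustrative helper (not the repo's code):

# Illustrative only: 3 image channels for the predicted mean, doubled when a
# per-channel variance (or interpolation coefficient) is also predicted.
def out_channels(img_channels: int = 3, learned_variance: bool = False) -> int:
    return img_channels * 2 if learned_variance else img_channels

assert out_channels(learned_variance=True) == 6
assert out_channels(learned_variance=False) == 3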
def make_T_sampler(self):
if self.T_sampler == 'uniform':
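The truncated branch above selects how training timesteps are drawn; 'uniform' conventionally means each t is sampled independently and uniformly from {0, ..., T-1}. An illustrative sketch of such a sampler (assumed behaviour, not the repo's actual class):

import torch

# Illustrative only: draw one uniform timestep per example in the batch.
def sample_uniform_t(batch_size: int, T: int) -> torch.Tensor:
    return torch.randint(low=0, high=T, size=(batch_size,))

print(sample_uniform_t(batch_size=4, T=1000))  # e.g. tensor([ 12, 873, 440,  95])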
@@ -373,43 +348,11 @@ def make_dataset(self, path=None, **kwargs):
return Horse_lmdb(path=path or self.data_path,
image_size=self.img_size,
**kwargs)
elif self.data_name == 'celebalmdb':
return CelebAlmdb(path=path or self.data_path,
image_size=self.img_size,
original_resolution=None,
crop_d2c=True,
**kwargs)
elif self.data_name == 'celebaalignlmdb':
return CelebAlmdb(path=path or self.data_path,
image_size=self.img_size,
**kwargs)

elif self.data_name == 'celebahq':
return CelebHQLMDB(path=path or self.data_path,
image_size=self.img_size,
**kwargs)
else:
return ImageDataset(folder=path or self.data_path,
image_size=self.img_size,
**kwargs)

def make_test_dataset(self, **kwargs):
if self.data_val_name == 'ffhqlmdbsplit256':
print('test on ffhq split')
return FFHQlmdb(path=data_paths['ffhqlmdbsplit256'][0],
original_resolution=256,
image_size=self.img_size,
split='test',
**kwargs)
elif self.data_val_name == 'celebhq':
print('test on celebhq')
return CelebHQLMDB(path=data_paths['celebahq'][0],
original_resolution=256,
image_size=self.img_size,
**kwargs)
else:
return None

def make_loader(self,
dataset,
shuffle: bool,
experiment.py: 9 changes (2 additions, 7 deletions)
@@ -164,13 +164,8 @@ def setup(self, stage=None) -> None:

self.train_data = self.conf.make_dataset()
print('train data:', len(self.train_data))
self.val_data = self.conf.make_test_dataset()
if self.val_data is None:
# val data is not provided, use the train data
self.val_data = self.train_data
else:
# val data is provided
print('val data:', len(self.val_data))
self.val_data = self.train_data
print('val data:', len(self.val_data))

def _train_dataloader(self, drop_last=True):
"""
