Skip to content

Commit

Permalink
Merge pull request Sygil-Dev#54 from ZeroCool940711/dev
Browse files Browse the repository at this point in the history
Added a couple of lines to print the total number of parameters the model has after it has been loaded for the vae and maskgit training scripts as well as the infer_vae.py
  • Loading branch information
ZeroCool940711 authored Jul 6, 2023
2 parents 9d0eb94 + eb0d1ae commit ed614c9
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 0 deletions.
6 changes: 6 additions & 0 deletions infer_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,12 @@ def main():
# move vae to device
vae = vae.to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")


# Use the parameters() method to get an iterator over all the learnable parameters of the model
total_params = sum(p.numel() for p in vae.parameters())

print(f"Total number of parameters: {format(total_params, ',d')}")

# then you plug the vae and transformer into your MaskGit as so

dataset = ImageDataset(
Expand Down
6 changes: 6 additions & 0 deletions train_muse_maskgit.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,12 @@ def main():
accelerator.print("Initialized new empty MaskGit model.")
current_step = 0


# Use the parameters() method to get an iterator over all the learnable parameters of the model
total_params = sum(p.numel() for p in maskgit.parameters())

print(f"Total number of parameters: {format(total_params, ',d')}")

# Create the dataset objects
with accelerator.main_process_first():
if args.no_cache and args.train_data_dir:
Expand Down
5 changes: 5 additions & 0 deletions train_muse_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,11 @@ def main():

current_step = 0

# Use the parameters() method to get an iterator over all the learnable parameters of the model
total_params = sum(p.numel() for p in vae.parameters())

print(f"Total number of parameters: {format(total_params, ',d')}")

dataset = ImageDataset(
dataset,
args.image_size,
Expand Down

0 comments on commit ed614c9

Please sign in to comment.