Skip to content

Commit

Permalink
make it right
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 10, 2023
1 parent 5bdcefa commit 0707d6e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,10 @@ images # List[PIL.Image.Image]
## Todo

- [x] test end-to-end
- [x] separate cond_images_or_ids into cond_images and cond_token_ids (it was not done right before)
- [x] add training code for vae

- [ ] separate cond_images_or_ids, it is not done right
- [ ] hook up accelerate code
- [ ] hook up accelerate training code for maskgit
- [ ] combine with token critic paper, already implemented at <a href="https://github.com/lucidrains/phenaki-pytorch">Phenaki</a>

## Citations
Expand Down
20 changes: 9 additions & 11 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,8 @@ def forward(
self,
images_or_ids: torch.Tensor,
ignore_index = -1,
cond_images_or_ids: Optional[torch.Tensor] = None,
cond_images: Optional[torch.Tensor] = None,
cond_token_ids: Optional[torch.Tensor] = None,
texts: Optional[List[str]] = None,
text_embeds: Optional[torch.Tensor] = None,
cond_drop_prob = None
Expand Down Expand Up @@ -482,17 +483,14 @@ def forward(

# tokenize conditional images if needed

cond_ids = None
assert not (exists(cond_images) and exists(cond_token_ids)), 'if conditioning on low resolution, cannot pass in both images and token ids'

if exists(cond_images_or_ids):
if cond_images_or_ids.dtype == torch.float:
assert exists(self.cond_vae), 'cond vqgan vae must be passed in'
assert all([height_or_width == self.cond_image_size for height_or_width in cond_images_or_ids.shape[-2:]])
if exists(cond_images):
assert exists(self.cond_vae), 'cond vqgan vae must be passed in'
assert all([height_or_width == self.cond_image_size for height_or_width in cond_images.shape[-2:]])

with torch.no_grad():
_, cond_ids, _ = self.cond_vae.encode(cond_images_or_ids)
else:
cond_ids = cond_image_or_ids
with torch.no_grad():
_, cond_token_ids, _ = self.cond_vae.encode(cond_images)

# prepare mask

Expand All @@ -515,7 +513,7 @@ def forward(
ids,
texts = texts,
text_embeds = text_embeds,
conditioning_token_ids = cond_ids,
conditioning_token_ids = cond_token_ids,
labels = labels,
cond_drop_prob = cond_drop_prob,
ignore_index = ignore_index
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'muse-maskgit-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.15',
version = '0.0.16',
license='MIT',
description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 0707d6e

Please sign in to comment.