diff --git a/README.md b/README.md index 7cb2227..2b3c150 100644 --- a/README.md +++ b/README.md @@ -225,9 +225,10 @@ images # List[PIL.Image.Image] ## Todo - [x] test end-to-end +- [x] separate cond_images_or_ids, it is not done right +- [x] add training code for vae -- [ ] separate cond_images_or_ids, it is not done right -- [ ] hook up accelerate code +- [ ] hook up accelerate training code for maskgit - [ ] combine with token critic paper, already implemented at Phenaki ## Citations diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index b76b031..af28e3f 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -452,7 +452,8 @@ def forward( self, images_or_ids: torch.Tensor, ignore_index = -1, - cond_images_or_ids: Optional[torch.Tensor] = None, + cond_images: Optional[torch.Tensor] = None, + cond_token_ids: Optional[torch.Tensor] = None, texts: Optional[List[str]] = None, text_embeds: Optional[torch.Tensor] = None, cond_drop_prob = None @@ -482,17 +483,14 @@ def forward( # tokenize conditional images if needed - cond_ids = None + assert not (exists(cond_images) and exists(cond_token_ids)), 'if conditioning on low resolution, cannot pass in both images and token ids' - if exists(cond_images_or_ids): - if cond_images_or_ids.dtype == torch.float: - assert exists(self.cond_vae), 'cond vqgan vae must be passed in' - assert all([height_or_width == self.cond_image_size for height_or_width in cond_images_or_ids.shape[-2:]]) + if exists(cond_images): + assert exists(self.cond_vae), 'cond vqgan vae must be passed in' + assert all([height_or_width == self.cond_image_size for height_or_width in cond_images.shape[-2:]]) - with torch.no_grad(): - _, cond_ids, _ = self.cond_vae.encode(cond_images_or_ids) - else: - cond_ids = cond_image_or_ids + with torch.no_grad(): + _, cond_token_ids, _ = self.cond_vae.encode(cond_images) # prepare mask @@ -515,7 +513,7 @@ def forward( ids, texts = texts, text_embeds = text_embeds, - conditioning_token_ids = cond_ids, + conditioning_token_ids = cond_token_ids, labels = labels, cond_drop_prob = cond_drop_prob, ignore_index = ignore_index diff --git a/setup.py b/setup.py index 0367e8f..dc90fdb 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'muse-maskgit-pytorch', packages = find_packages(exclude=[]), - version = '0.0.15', + version = '0.0.16', license='MIT', description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch', author = 'Phil Wang',