diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py index 6563868..3e01412 100644 --- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py +++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py @@ -218,7 +218,8 @@ def forward( if exists(texts): text_embeds = self.encode_text(texts) - context = self.text_embed_proj(text_embeds) + + context = self.text_embed_proj(text_embeds) context_mask = (text_embeds != 0).any(dim = -1) @@ -365,6 +366,9 @@ def generate( starting_temperature = temperature cond_ids = None + + text_embeds = self.transformer.encode_text(texts) + if self.resize_image_for_cond_image: assert exists(cond_images), 'conditioning image must be passed in to generate for super res maskgit' with torch.no_grad(): @@ -381,7 +385,7 @@ def generate( logits = self.transformer.forward_with_cond_scale( ids, - texts = texts, + text_embeds = text_embeds, conditioning_token_ids = cond_ids, cond_scale = cond_scale ) diff --git a/setup.py b/setup.py index dfa6d3d..f34404b 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'muse-maskgit-pytorch', packages = find_packages(exclude=[]), - version = '0.0.12', + version = '0.0.14', license='MIT', description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch', author = 'Phil Wang',