Skip to content

Commit

Permalink
more efficient generation
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 10, 2023
1 parent d6a4fb4 commit 277239a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
8 changes: 6 additions & 2 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ def forward(

if exists(texts):
text_embeds = self.encode_text(texts)
context = self.text_embed_proj(text_embeds)

context = self.text_embed_proj(text_embeds)

context_mask = (text_embeds != 0).any(dim = -1)

Expand Down Expand Up @@ -365,6 +366,9 @@ def generate(
starting_temperature = temperature

cond_ids = None

text_embeds = self.transformer.encode_text(texts)

if self.resize_image_for_cond_image:
assert exists(cond_images), 'conditioning image must be passed in to generate for super res maskgit'
with torch.no_grad():
Expand All @@ -381,7 +385,7 @@ def generate(

logits = self.transformer.forward_with_cond_scale(
ids,
texts = texts,
text_embeds = text_embeds,
conditioning_token_ids = cond_ids,
cond_scale = cond_scale
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'muse-maskgit-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.12',
version = '0.0.14',
license='MIT',
description = 'MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 277239a

Please sign in to comment.