have transformer automatically handle binary cross entropy loss for token critic
lucidrains committed Jan 20, 2023
1 parent 3954fa9 commit 862b8e9
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -204,7 +204,8 @@ def __init__(
         self.transformer_blocks = TransformerBlocks(dim = dim, **kwargs)
         self.norm = LayerNorm(dim)

-        self.to_logits = nn.Linear(dim, default(dim_out, num_tokens), bias = False)
+        self.dim_out = default(dim_out, num_tokens)
+        self.to_logits = nn.Linear(dim, self.dim_out, bias = False)

         # text conditioning
@@ -318,6 +319,9 @@ def forward(
         if not exists(labels):
             return logits

+        if self.dim_out == 1:
+            return F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
+
         logits = rearrange(logits, 'b n c -> b c n')
         return F.cross_entropy(logits, labels, ignore_index = ignore_index)
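
In effect, the transformer's loss is now chosen by self.dim_out: with dim_out = 1 (the token-critic configuration) the single logit per position is flattened and scored with binary cross entropy against float labels, while any other width keeps the usual cross entropy over the codebook. Below is a minimal, self-contained sketch of that branching; the helper name compute_loss and the tensor shapes are illustrative, not part of the repository's API.

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def compute_loss(logits, labels, dim_out, ignore_index = -1):
    # token critic: one logit per position, labels are per-token floats in [0, 1]
    if dim_out == 1:
        # (b, n, 1) -> (b, n), then binary cross entropy with logits
        return F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)

    # generative head: logits over the codebook, labels are long token ids
    logits = rearrange(logits, 'b n c -> b c n')  # cross_entropy expects (b, c, n)
    return F.cross_entropy(logits, labels, ignore_index = ignore_index)

b, n, num_tokens = 2, 16, 1024

# critic path: e.g. labels mark which tokens were replaced by the generator
critic_logits = torch.randn(b, n, 1)
critic_labels = torch.rand(b, n)
print(compute_loss(critic_logits, critic_labels, dim_out = 1))

# standard path: predict token ids over the codebook
gen_logits = torch.randn(b, n, num_tokens)
gen_labels = torch.randint(0, num_tokens, (b, n))
print(compute_loss(gen_logits, gen_labels, dim_out = num_tokens))
```

With this in place, a token critic can be trained through the same Transformer class simply by constructing it with dim_out = 1 and passing float labels to forward.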

