From 862b8e9b4a0fd9504826c2df35b22f6f4c3a2e96 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 20 Jan 2023 14:15:58 -0800
Subject: [PATCH] have transformer automatically handle binary cross entropy
 loss for token critic

---
 muse_maskgit_pytorch/muse_maskgit_pytorch.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/muse_maskgit_pytorch/muse_maskgit_pytorch.py b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
index f474bb7..2856551 100644
--- a/muse_maskgit_pytorch/muse_maskgit_pytorch.py
+++ b/muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -204,7 +204,8 @@ def __init__(
         self.transformer_blocks = TransformerBlocks(dim = dim, **kwargs)
         self.norm = LayerNorm(dim)
 
-        self.to_logits = nn.Linear(dim, default(dim_out, num_tokens), bias = False)
+        self.dim_out = default(dim_out, num_tokens)
+        self.to_logits = nn.Linear(dim, self.dim_out, bias = False)
 
         # text conditioning
 
@@ -318,6 +319,9 @@ def forward(
 
         if not exists(labels):
             return logits
 
+        if self.dim_out == 1:
+            return F.binary_cross_entropy_with_logits(rearrange(logits, '... 1 -> ...'), labels)
+
         logits = rearrange(logits, 'b n c -> b c n')
         return F.cross_entropy(logits, labels, ignore_index = ignore_index)
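
Note (editorial, not part of the patch to apply): below is a minimal, self-contained sketch of the loss dispatch this commit introduces. TinyTransformer and everything around it are hypothetical stand-ins, not the real muse_maskgit_pytorch.Transformer; only the dim_out == 1 -> binary-cross-entropy branch mirrors the patched code above.

    # a toy module reproducing the patched loss-selection logic
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from einops import rearrange

    class TinyTransformer(nn.Module):
        def __init__(self, num_tokens, dim, dim_out = None):
            super().__init__()
            self.embed = nn.Embedding(num_tokens, dim)
            # as in the patch: remember the output width so forward()
            # can pick the appropriate loss later
            self.dim_out = dim_out if dim_out is not None else num_tokens
            self.to_logits = nn.Linear(dim, self.dim_out, bias = False)

        def forward(self, ids, labels = None, ignore_index = -1):
            logits = self.to_logits(self.embed(ids))

            if labels is None:
                return logits

            # token critic case: one logit per token, binary labels
            if self.dim_out == 1:
                return F.binary_cross_entropy_with_logits(
                    rearrange(logits, '... 1 -> ...'), labels)

            # generator case: full vocabulary, standard cross entropy
            logits = rearrange(logits, 'b n c -> b c n')
            return F.cross_entropy(logits, labels, ignore_index = ignore_index)

    # usage: a critic head (dim_out = 1) scores each token as real / replaced
    critic = TinyTransformer(num_tokens = 512, dim = 64, dim_out = 1)
    ids = torch.randint(0, 512, (2, 16))
    binary_labels = torch.randint(0, 2, (2, 16)).float()  # 1 = token was replaced
    loss = critic(ids, labels = binary_labels)            # BCE branch fires automatically

    # the same class with the default dim_out behaves as the generator
    gen = TinyTransformer(num_tokens = 512, dim = 64)
    token_labels = torch.randint(0, 512, (2, 16))
    loss = gen(ids, labels = token_labels)                # cross-entropy branch

The point of the change is that one Transformer class can serve both roles: the generator (dim_out defaults to num_tokens, trained with cross entropy against token ids) and the token critic (dim_out = 1, trained with binary cross entropy against per-token real/fake labels), selected purely by the stored output width.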