
Confused by this LogitMask #25

Open
JJASMINE22 opened this issue Jan 3, 2024 · 0 comments

Comments

@JJASMINE22

rqvae/models/rqtransformer/primitives.py
from typing import Iterable

import torch.nn as nn
from torch import Tensor


class LogitMask(nn.Module):
    def __init__(self, vocab_size: Iterable[int], value=-1e6):
        super().__init__()

        self.vocab_size = vocab_size
        # Mask only when the per-depth vocab sizes are not all equal.
        self.mask_cond = [vocab_size[0]] * len(vocab_size) != vocab_size
        self.value = value

    def forward(self, logits: Tensor) -> Tensor:
        if not self.mask_cond:
            return logits
        else:
            for idx, vocab_size in enumerate(self.vocab_size):
                logits[:, idx, vocab_size:].fill_(-float('Inf'))
            return logits

The masking line in LogitMask's forward should probably be written as logits[:, idx, vocab_size:] = logits[:, idx, vocab_size:].fill_(-float('Inf'))
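For context, here is a minimal sketch with toy shapes (batch 2, depth 3, vocab 5 are made-up values, not from the repo). In PyTorch, basic slicing like logits[:, idx, vocab_size:] returns a view, so the in-place fill_ already modifies logits without reassignment; the explicit-assignment form is equivalent here:

```python
import torch

# Toy logits tensor: (batch=2, depth=3, vocab=5) -- hypothetical sizes.
logits = torch.zeros(2, 3, 5)
vocab_size = [5, 3, 3]  # example per-depth vocab sizes (not all equal)

# In-place fill_ on a basic slice: slicing yields a view of logits,
# so the original tensor is mutated even without reassignment.
for idx, vs in enumerate(vocab_size):
    logits[:, idx, vs:].fill_(-float('Inf'))

# Depth 1 has vocab size 3, so positions 3 and 4 are now -inf.
print(logits[0, 1])  # tensor([0., 0., 0., -inf, -inf])
```

The suggested form logits[:, idx, vs:] = logits[:, idx, vs:].fill_(-float('Inf')) produces the same result; the assignment is redundant for basic (slice-based) indexing but would matter if advanced indexing (which copies) were used.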
