diff --git a/src/fairseq2/models/wav2vec2/model.py b/src/fairseq2/models/wav2vec2/model.py index 1f6dbc68d..d2f811fc5 100644 --- a/src/fairseq2/models/wav2vec2/model.py +++ b/src/fairseq2/models/wav2vec2/model.py @@ -109,6 +109,7 @@ def __init__( self.num_distractors = num_distractors self.logit_temp = logit_temp + self.quantizer_encoder_grad = quantizer_encoder_grad def forward(self, batch: SequenceBatch) -> Wav2Vec2Output: """