diff --git a/src/fairseq2/models/wav2vec2/factory.py b/src/fairseq2/models/wav2vec2/factory.py
index 7244380f3..36247df17 100644
--- a/src/fairseq2/models/wav2vec2/factory.py
+++ b/src/fairseq2/models/wav2vec2/factory.py
@@ -72,6 +72,11 @@ class Wav2Vec2Config:
     final_proj_bias: bool = True
     """If ``True``, the final projection learns an additive bias."""
 
+    quantizer_encoder_grad: bool = True
+    """If ``True``, gradients are propagated from the quantizer through the
+    convolutional encoder. Otherwise, they are detached and the encoder is
+    trained only with gradients from the transformer."""
+
     # Mask
     temporal_mask_span_len: int = 10
     """The length of each temporal mask span that is applied over time steps."""
@@ -284,6 +289,7 @@ def build_model(self) -> Wav2Vec2Model:
            final_proj_bias=self._config.final_proj_bias,
            num_distractors=self._config.num_distractors,
            logit_temp=self._config.logit_temp,
+           quantizer_encoder_grad=self._config.quantizer_encoder_grad,
            device=self._device,
            dtype=self._dtype,
        )
diff --git a/src/fairseq2/models/wav2vec2/model.py b/src/fairseq2/models/wav2vec2/model.py
index 749f931ce..1f6dbc68d 100644
--- a/src/fairseq2/models/wav2vec2/model.py
+++ b/src/fairseq2/models/wav2vec2/model.py
@@ -54,6 +54,7 @@ def __init__(
        final_proj_bias: bool = True,
        num_distractors: int = 100,
        logit_temp: float = 0.1,
+       quantizer_encoder_grad: bool = True,
        device: Device | None = None,
        dtype: DataType | None = None,
    ) -> None:
@@ -151,7 +152,10 @@ def run_frontend(
 
        # We use the extracted features as context network targets after masking
        # and quantization.
-       targets = seqs.clone()
+       if self.quantizer_encoder_grad:
+           targets = seqs.clone()
+       else:
+           targets = seqs.detach().clone()
 
        if frontend.first_pass_dropout is not None:
            targets = frontend.first_pass_dropout(targets)
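For reviewers, a minimal standalone sketch of what the flag changes; `encoder` and `waveform` below are hypothetical stand-ins for the convolutional feature extractor and a raw audio batch, not fairseq2 objects:

```python
import torch

# Hypothetical stand-ins for the convolutional encoder and its input.
encoder = torch.nn.Linear(16, 16)
waveform = torch.randn(2, 16)

for quantizer_encoder_grad in (True, False):
    seqs = encoder(waveform)

    if quantizer_encoder_grad:
        # clone() is differentiable, so any loss computed on `targets`
        # (the quantizer / contrastive-target path) backpropagates into
        # the encoder.
        targets = seqs.clone()
    else:
        # detach() severs the autograd graph first: the target path then
        # contributes no encoder gradients, and the encoder is updated
        # only through the transformer branch.
        targets = seqs.detach().clone()

    print(quantizer_encoder_grad, targets.requires_grad)
    # True  True   -> target losses reach the encoder
    # False False  -> target losses are cut off at the encoder output
```

With the default `quantizer_encoder_grad=True`, the behavior is identical to the previous `targets = seqs.clone()`, so existing configs are unaffected; setting it to `False` in `Wav2Vec2Config` opts into the detached-target training.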