
Commit

Merge branch 'main' into cli/tab-completion
Ahmedsaed authored Aug 12, 2024
2 parents ff0c17a + 0677d64 commit 9e91069
Showing 2 changed files with 11 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/fairseq2/models/wav2vec2/factory.py
@@ -72,6 +72,11 @@ class Wav2Vec2Config:
    final_proj_bias: bool = True
    """If ``True``, the final projection learns an additive bias."""

+    quantizer_encoder_grad: bool = True
+    """If ``True``, gradients are propagated from the quantizer through the
+    convolutional encoder. Otherwise, they are detached and the encoder is
+    trained only with gradients from the transformer."""
+
    # Mask
    temporal_mask_span_len: int = 10
    """The length of each temporal mask span that is applied over time steps."""
@@ -284,6 +289,7 @@ def build_model(self) -> Wav2Vec2Model:
            final_proj_bias=self._config.final_proj_bias,
            num_distractors=self._config.num_distractors,
            logit_temp=self._config.logit_temp,
+            quantizer_encoder_grad=self._config.quantizer_encoder_grad,
            device=self._device,
            dtype=self._dtype,
        )
6 changes: 5 additions & 1 deletion src/fairseq2/models/wav2vec2/model.py
@@ -54,6 +54,7 @@ def __init__(
        final_proj_bias: bool = True,
        num_distractors: int = 100,
        logit_temp: float = 0.1,
+        quantizer_encoder_grad: bool = True,
        device: Device | None = None,
        dtype: DataType | None = None,
    ) -> None:
@@ -151,7 +152,10 @@ def run_frontend(

        # We use the extracted features as context network targets after masking
        # and quantization.
-        targets = seqs.clone()
+        if self.quantizer_encoder_grad:
+            targets = seqs.clone()
+        else:
+            targets = seqs.detach().clone()

        if frontend.first_pass_dropout is not None:
            targets = frontend.first_pass_dropout(targets)
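For illustration, a minimal sketch in plain PyTorch (not fairseq2 code; the linear layer and tensor shapes are placeholders) of the behavior the new flag toggles: with clone() the targets stay attached to the autograd graph, so a loss computed on them also backpropagates into the feature extractor, whereas detach().clone() copies the values but cuts that gradient path.

# Plain-PyTorch sketch (placeholder module and shapes, not the fairseq2 model)
# showing what quantizer_encoder_grad controls.
import torch

extractor = torch.nn.Linear(4, 4)          # stand-in for the convolutional feature extractor
seqs = extractor(torch.randn(2, 4))        # "extracted features"

targets_attached = seqs.clone()            # quantizer_encoder_grad=True: graph is kept
targets_detached = seqs.detach().clone()   # quantizer_encoder_grad=False: graph is cut

targets_attached.sum().backward()          # gradients reach the extractor's parameters
print(extractor.weight.grad is not None)   # True

print(targets_detached.requires_grad)      # False: a loss on these targets cannot
                                           # propagate back into the extractor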
