From 4483bb66e000305122a6e8e741afd6f7c9a74ebc Mon Sep 17 00:00:00 2001
From: Matthias Cremon
Date: Fri, 2 Aug 2024 09:22:03 -0700
Subject: [PATCH] Add Wav2Vec2 base model (#4513)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4513

As titled.

Reviewed By: zonglinpengmeta

Differential Revision: D60619295

fbshipit-source-id: 00fd48029bc2413cf2a4a1453c80bbf65d29c57f
---
 examples/cadence/models/wav2vec2.py | 65 +++++++++++++++++++++++++++++
 1 file changed, 65 insertions(+)
 create mode 100644 examples/cadence/models/wav2vec2.py

diff --git a/examples/cadence/models/wav2vec2.py b/examples/cadence/models/wav2vec2.py
new file mode 100644
index 0000000000..5db9ea2a6d
--- /dev/null
+++ b/examples/cadence/models/wav2vec2.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Example script for exporting simple models to flatbuffer
+
+import logging
+
+from executorch.backends.cadence.aot.ops_registrations import *  # noqa
+
+import torch
+
+from executorch.backends.cadence.aot.export_example import export_model
+from torchaudio.models.wav2vec2.model import wav2vec2_model, Wav2Vec2Model
+
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+def main() -> None:
+    # The wrapper is needed to avoid issues with the optional second argument
+    # (lengths) of Wav2Vec2Model's forward, and to return only the features.
+    class Wav2Vec2ModelWrapper(torch.nn.Module):
+        def __init__(self, model: Wav2Vec2Model):
+            super().__init__()
+            self.model = model
+
+        def forward(self, x):
+            out, _ = self.model(x)
+            return out
+
+    _model = wav2vec2_model(
+        extractor_mode="layer_norm",
+        extractor_conv_layer_config=None,
+        extractor_conv_bias=False,
+        encoder_embed_dim=768,
+        encoder_projection_dropout=0.1,
+        encoder_pos_conv_kernel=128,
+        encoder_pos_conv_groups=16,
+        encoder_num_layers=12,
+        encoder_num_heads=12,
+        encoder_attention_dropout=0.1,
+        encoder_ff_interm_features=3072,
+        encoder_ff_interm_dropout=0.0,
+        encoder_dropout=0.1,
+        encoder_layer_norm_first=False,
+        encoder_layer_drop=0.1,
+        aux_num_out=None,
+    )
+    _model.eval()
+
+    model = Wav2Vec2ModelWrapper(_model)
+    model.eval()
+
+    # Test input: a single batch of 1680 audio samples.
+    audio_len = 1680
+    example_inputs = (torch.rand(1, audio_len),)
+
+    export_model(model, example_inputs)
+
+
+if __name__ == "__main__":
+    main()
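
For reference, a minimal eager-mode sanity check of the wrapper and config
above (a sketch, not part of the diff; it assumes torchaudio is installed
and relies on torchaudio's default conv feature extractor, which
extractor_conv_layer_config=None selects):

    import torch
    from torchaudio.models.wav2vec2.model import wav2vec2_model, Wav2Vec2Model

    class Wav2Vec2ModelWrapper(torch.nn.Module):
        # Same wrapper as the diff: call the model with a single positional
        # argument and keep only the features tensor, dropping the lengths.
        def __init__(self, model: Wav2Vec2Model):
            super().__init__()
            self.model = model

        def forward(self, x):
            out, _ = self.model(x)
            return out

    _model = wav2vec2_model(
        extractor_mode="layer_norm",
        extractor_conv_layer_config=None,  # default 7-layer conv extractor
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=0.1,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_attention_dropout=0.1,
        encoder_ff_interm_features=3072,
        encoder_ff_interm_dropout=0.0,
        encoder_dropout=0.1,
        encoder_layer_norm_first=False,
        encoder_layer_drop=0.1,
        aux_num_out=None,
    )
    model = Wav2Vec2ModelWrapper(_model).eval()

    with torch.no_grad():
        out = model(torch.rand(1, 1680))

    # The default extractor downsamples by a total stride of 320, so the
    # 1680-sample test input should yield 5 frames of encoder_embed_dim 768.
    print(out.shape)  # expected: torch.Size([1, 5, 768])

This mirrors what export_model will trace: a single tensor in, a single
tensor out, which keeps the exported graph free of the optional lengths
handling in Wav2Vec2Model's forward.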