Skip to content

Commit

Permalink
[ci/cd] try to pass ut
Browse files Browse the repository at this point in the history
  • Loading branch information
xingchensong committed Dec 12, 2023
1 parent 469c90b commit a197ed4
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions test/wenet/cli/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,34 @@
# -*- coding: utf-8 -*-
# Copyright [2023-12-12] <[email protected], Xingchen Song>

import os
import pytest

from wenet.cli.hub import download
from wenet.cli.model import Model


@pytest.mark.parametrize("model_link", [
"https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_libtorch.tar.gz",
"https://wenet.org.cn/downloads?models=wenet&version=aishell2_u2pp_conformer_libtorch.tar.gz",
"https://wenet.org.cn/downloads?models=wenet&version=gigaspeech_u2pp_conformer_libtorch.tar.gz",
"https://wenet.org.cn/downloads?models=wenet&version=librispeech_u2pp_conformer_libtorch.tar.gz",
"https://wenet.org.cn/downloads?models=wenet&version=multi_cn_unified_conformer_libtorch.tar.gz",
"https://wenet.org.cn/downloads?models=wenet&version=wenetspeech_u2pp_conformer_libtorch.tar.gz"
@pytest.mark.parametrize("model", [
"aishell_u2pp_conformer_libtorch.tar.gz",
"aishell2_u2pp_conformer_libtorch.tar.gz",
"gigaspeech_u2pp_conformer_libtorch.tar.gz",
"librispeech_u2pp_conformer_libtorch.tar.gz",
"multi_cn_unified_conformer_libtorch.tar.gz",
"wenetspeech_u2pp_conformer_libtorch.tar.gz"
])
def test_model(model_link):
dest = model_link.split('=')[-1].split('.')[0] # aishell_u2pp_conformer_libtorch
dataset = model_link.split('_')[-4].split('=')[-1] # aishell
download(model_link, dest=dest)
model = Model(model_link,
def test_model(model):
dest = model.split('.')[0] # aishell_u2pp_conformer_libtorch
dataset = model.split('_')[0] # aishell
if not os.path.exists(dest):
os.makedirs(dest)
response = requests.get(
"https://modelscope.cn/api/v1/datasets/wenet/wenet_pretrained_models/oss/tree" # noqa
)
model_info = next(data for data in response.json()["Data"]
if data["Key"] == model)
model_url = model_info['Url']
download(model_url, dest=dest, only_child=True)
model = Model(dest,
gpu=-1,
beam=5,
resample_rate=16000)
Expand Down

0 comments on commit a197ed4

Please sign in to comment.