From cdf0f1aa13212d6936f4455a259660f22a7b1220 Mon Sep 17 00:00:00 2001 From: tarepan Date: Mon, 27 Nov 2023 00:39:22 +0000 Subject: [PATCH] Refactor phoneme handler by removing base class --- test/test_acoustic_feature_extractor.py | 26 +------ voicevox_engine/acoustic_feature_extractor.py | 71 +------------------ 2 files changed, 4 insertions(+), 93 deletions(-) diff --git a/test/test_acoustic_feature_extractor.py b/test/test_acoustic_feature_extractor.py index df0b7ad62..a2a520c21 100644 --- a/test/test_acoustic_feature_extractor.py +++ b/test/test_acoustic_feature_extractor.py @@ -1,31 +1,9 @@ from unittest import TestCase -from voicevox_engine.acoustic_feature_extractor import BasePhoneme, OjtPhoneme +from voicevox_engine.acoustic_feature_extractor import OjtPhoneme -class TestBasePhoneme(TestCase): - def setUp(self): - super().setUp() - self.str_hello_hiho = "sil k o N n i ch i w a pau h i h o d e s U sil" - self.base_hello_hiho = [ - BasePhoneme(s, i, i + 1) for i, s in enumerate(self.str_hello_hiho.split()) - ] - - def test_repr_(self): - self.assertEqual( - self.base_hello_hiho[1].__repr__(), "Phoneme(phoneme='k', start=1, end=2)" - ) - self.assertEqual( - self.base_hello_hiho[10].__repr__(), - "Phoneme(phoneme='pau', start=10, end=11)", - ) - - def test_convert(self): - with self.assertRaises(NotImplementedError): - BasePhoneme.convert(self.base_hello_hiho) - - -class TestOjtPhoneme(TestBasePhoneme): +class TestOjtPhoneme(TestCase): def setUp(self): super().setUp() str_hello_hiho = "sil k o N n i ch i w a pau h i h o d e s U sil" diff --git a/voicevox_engine/acoustic_feature_extractor.py b/voicevox_engine/acoustic_feature_extractor.py index 7fa0e9f25..e14e7c62e 100644 --- a/voicevox_engine/acoustic_feature_extractor.py +++ b/voicevox_engine/acoustic_feature_extractor.py @@ -1,76 +1,9 @@ -from abc import abstractmethod -from typing import List, Sequence +from typing import List import numpy -class BasePhoneme(object): - """ - 音素の応用クラス群の抽象基底クラス - - Attributes - ---------- - phoneme_list : Sequence[str] - 音素のリスト - num_phoneme : int - 音素リストの要素数 - space_phoneme : str - 読点に値する音素 - """ - - phoneme_list: Sequence[str] - num_phoneme: int - space_phoneme: str - - def __init__( - self, - phoneme: str, - start: float, - end: float, - ): - self.phoneme = phoneme - self.start = numpy.round(start, decimals=2) - self.end = numpy.round(end, decimals=2) - - def __repr__(self): - return f"Phoneme(phoneme='{self.phoneme}', start={self.start}, end={self.end})" - - def __eq__(self, o: object): - return isinstance(o, BasePhoneme) and ( - self.phoneme == o.phoneme and self.start == o.start and self.end == o.end - ) - - @property - def phoneme_id(self): - """ - phoneme_id (phoneme list内でのindex)を取得する - Returns - ------- - id : int - phoneme_idを返す - """ - return self.phoneme_list.index(self.phoneme) - - @property - def onehot(self): - """ - phoneme listの長さ分の0埋め配列のうち、phoneme id番目がTrue(1)の配列を返す - Returns - ------- - onehot : numpu.ndarray - 関数内で変更された配列を返す - """ - array = numpy.zeros(self.num_phoneme, dtype=bool) - array[self.phoneme_id] = True - return array - - @classmethod - @abstractmethod - def convert(cls, phonemes: List["BasePhoneme"]) -> List["BasePhoneme"]: - raise NotImplementedError - - -class OjtPhoneme(BasePhoneme): +class OjtPhoneme: """ OpenJTalkに含まれる音素群クラス