Skip to content

Commit

Permalink
Refactor phoneme handler by removing base class
Browse files Browse the repository at this point in the history
  • Loading branch information
tarepan committed Nov 27, 2023
1 parent ea65dc8 commit cdf0f1a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 93 deletions.
26 changes: 2 additions & 24 deletions test/test_acoustic_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,9 @@
from unittest import TestCase

from voicevox_engine.acoustic_feature_extractor import BasePhoneme, OjtPhoneme
from voicevox_engine.acoustic_feature_extractor import OjtPhoneme


class TestBasePhoneme(TestCase):
    """Checks for the behaviour that BasePhoneme itself provides."""

    def setUp(self):
        """Build one BasePhoneme per token, each spanning [index, index + 1]."""
        super().setUp()
        self.str_hello_hiho = "sil k o N n i ch i w a pau h i h o d e s U sil"
        phonemes = []
        for index, symbol in enumerate(self.str_hello_hiho.split()):
            phonemes.append(BasePhoneme(symbol, index, index + 1))
        self.base_hello_hiho = phonemes

    def test_repr_(self):
        """The textual representation exposes phoneme, start and end."""
        second = self.base_hello_hiho[1]
        self.assertEqual(repr(second), "Phoneme(phoneme='k', start=1, end=2)")

        pause = self.base_hello_hiho[10]
        self.assertEqual(
            repr(pause),
            "Phoneme(phoneme='pau', start=10, end=11)",
        )

    def test_convert(self):
        """The base class leaves convert() unimplemented."""
        self.assertRaises(
            NotImplementedError, BasePhoneme.convert, self.base_hello_hiho
        )


class TestOjtPhoneme(TestBasePhoneme):
class TestOjtPhoneme(TestCase):
def setUp(self):
super().setUp()
str_hello_hiho = "sil k o N n i ch i w a pau h i h o d e s U sil"
Expand Down
71 changes: 2 additions & 69 deletions voicevox_engine/acoustic_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,9 @@
from abc import abstractmethod
from typing import List, Sequence
from typing import List

import numpy


class BasePhoneme(object):
    """
    Abstract base class for the concrete phoneme classes.

    The three class-level attributes below are only annotated here; a
    subclass has to assign them before ``phoneme_id`` or ``onehot`` is used.

    Attributes
    ----------
    phoneme_list : Sequence[str]
        List of phonemes
    num_phoneme : int
        Number of elements in the phoneme list
    space_phoneme : str
        Phoneme that corresponds to the comma (pause) mark
    """

    phoneme_list: Sequence[str]
    num_phoneme: int
    space_phoneme: str

    def __init__(
        self,
        phoneme: str,
        start: float,
        end: float,
    ):
        """
        Parameters
        ----------
        phoneme : str
            Phoneme symbol
        start : float
            Start position; stored rounded to two decimal places
        end : float
            End position; stored rounded to two decimal places
        """
        self.phoneme = phoneme
        self.start = numpy.round(start, decimals=2)
        self.end = numpy.round(end, decimals=2)

    def __repr__(self):
        # The literal "Phoneme(" prefix is kept as-is (it is not the class
        # name) because callers compare against this exact text.
        return f"Phoneme(phoneme='{self.phoneme}', start={self.start}, end={self.end})"

    def __eq__(self, o: object):
        """
        Compare phoneme, start and end with another phoneme object.

        Returns ``NotImplemented`` for objects that are not ``BasePhoneme``
        instances so that Python can try the reflected comparison; the
        ``==`` operator still evaluates to ``False`` in that case.
        """
        if not isinstance(o, BasePhoneme):
            return NotImplemented
        # start/end are numpy scalars, so coerce the result to a plain bool.
        return bool(
            self.phoneme == o.phoneme and self.start == o.start and self.end == o.end
        )

    @property
    def phoneme_id(self):
        """
        Get the phoneme_id (the index within the phoneme list).

        Returns
        -------
        id : int
            The phoneme_id

        Raises
        ------
        ValueError
            If the phoneme is not contained in ``phoneme_list``
        """
        return self.phoneme_list.index(self.phoneme)

    @property
    def onehot(self):
        """
        Build a one-hot vector for this phoneme.

        Returns
        -------
        onehot : numpy.ndarray
            Boolean array of length ``num_phoneme`` that is False everywhere
            except at index ``phoneme_id``, which is True
        """
        array = numpy.zeros(self.num_phoneme, dtype=bool)
        array[self.phoneme_id] = True
        return array

    @classmethod
    @abstractmethod
    def convert(cls, phonemes: List["BasePhoneme"]) -> List["BasePhoneme"]:
        """
        Convert a list of phonemes; must be overridden by subclasses.

        This class does not use ``ABCMeta``, so ``abstractmethod`` does not
        block instantiation; calling the base implementation raises instead.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation
        """
        raise NotImplementedError


class OjtPhoneme(BasePhoneme):
class OjtPhoneme:
"""
OpenJTalkに含まれる音素群クラス
Expand Down

0 comments on commit cdf0f1a

Please sign in to comment.