Skip to content

Commit

Permalink
Update generator.py
Browse files Browse the repository at this point in the history
  • Loading branch information
agent87 authored Nov 28, 2024
1 parent 890cf02 commit 2f6ef72
Showing 1 changed file with 48 additions and 51 deletions.
99 changes: 48 additions & 51 deletions tts/generator.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,64 @@
from TTS.utils.synthesizer import Synthesizer
import io
import os
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException, Form
from fastapi.responses import StreamingResponse
from TTS.utils.synthesizer import Synthesizer
from typing import Optional
import numpy as np
import soundfile as sf

class TTSResponse(BaseModel):
status_code: int = 0
error: Optional[str] = None

class tts_response(BaseModel):
status_code: int = 10
error: str


class TTS_MODEL(BaseModel):
MAX_TXT_LEN: int = os.getenv('TTS_MAX_TXT_LEN', 1000)
class TTSModel(BaseModel):
MAX_TXT_LEN: int = int(os.getenv('TTS_MAX_TXT_LEN', 1000))
SOUNDS_DIR: str = "sounds"
MODEL_PATH: str = "./model.pth"
CONFIG_PATH: str = "config.json"
SPEAKERS_PATH: str = "speakers.pth"
ENCODER_CHECKPOINT_PATH: str = "SE_checkpoint.pth.tar"
ENCODER_CONFIG: str = "config_se.json"
SPEAKER_WAV = "conditioning_audio.wav"
MODEL_PATH: str = r"./model_files/model.pth"
CONFIG_PATH: str = r"./model_files/config.json"
SPEAKERS_PATH: str = r"./model_files/speakers.pth"
ENCODER_CHECKPOINT_PATH: str = r"./model_files/SE_checkpoint.pth.tar"
ENCODER_CONFIG: str = r"./model_files/config_se.json"
SPEAKER_WAV: str = r"./model_files/conditioning_audio.wav"



#Initiate the model
engine_specs = TTS_MODEL()
# Initiate the model
engine_specs = TTSModel()

engine = Synthesizer(
engine_specs.MODEL_PATH,
engine_specs.CONFIG_PATH,
tts_speakers_file=engine_specs.SPEAKERS_PATH,
encoder_checkpoint=engine_specs.ENCODER_CHECKPOINT_PATH,
encoder_config=engine_specs.ENCODER_CONFIG,
)

engine_specs.MODEL_PATH,
engine_specs.CONFIG_PATH,
tts_speakers_file=engine_specs.SPEAKERS_PATH,
encoder_checkpoint=engine_specs.ENCODER_CHECKPOINT_PATH,
encoder_config=engine_specs.ENCODER_CONFIG,
)

class Generator:
MAX_TXT_LEN: int = 1000 # os.getenv('TTS_MAX_TXT_LEN')
SOUNDS_DIR: str = "sounds"
MODEL_PATH: str = "./model.pth"
CONFIG_PATH: str = "config.json"
SPEAKERS_PATH: str = "speakers.pth"
ENCODER_CHECKPOINT_PATH: str = "SE_checkpoint.pth.tar"
ENCODER_CONFIG: str = "config_se.json"
SPEAKER_WAV = "conditioning_audio.wav"
response = tts_response()
def __init__(self, text: str) -> None:
self.MAX_TXT_LEN = 1000
self.SPEAKER_WAV = engine_specs.SPEAKER_WAV
self.response = TTSResponse()
self.audio_bytes = None
self.audio_buffer = io.BytesIO()

def __init__(self, text) -> None:
# Initiate the tts response
if len(text) > self.MAX_TXT_LEN:
text = text[: self.MAX_TXT_LEN] # cut off text to the limit
self.response.status_code = 10
self.response.error = f"Input text was cutoff since it went over the {self.MAX_TXT_LEN} character limit."

self.audio_bytes: bytes = engine.tts(text, speaker_wav=self.SPEAKER_WAV)

# save the audio
self.save_audio()



def save_audio(self) -> str:
file_id = len(os.listdir(self.SOUNDS_DIR)) + 1
file_path : str = f"{self.SOUNDS_DIR}/sound-{file_id}.wav"

with open(file_path, "wb+") as audio_file:
engine.save_wav(self.audio_bytes, audio_file)

self.file_path = file_path
else:
try:
self.audio_bytes = engine.tts(text, speaker_wav=self.SPEAKER_WAV)
self.save_audio()
except Exception as e:
self.response.status_code = 500
self.response.error = str(e)

def save_audio(self) -> None:
# Ensure that all elements are converted to the correct type
if isinstance(self.audio_bytes, list):
self.audio_bytes = np.array(self.audio_bytes, dtype=np.float32)

# Write the audio data to the BytesIO buffer as a WAV file
sf.write(self.audio_buffer, self.audio_bytes, samplerate=22050, format='WAV')
self.audio_buffer.seek(0)

0 comments on commit 2f6ef72

Please sign in to comment.