Skip to content

Commit

Permalink
use SUNO
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielchua committed Sep 30, 2024
1 parent 5534c51 commit 8ddd281
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 25 deletions.
19 changes: 17 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,21 @@
from prompts import SYSTEM_PROMPT
from utils import generate_script, generate_audio, parse_url

LANGUAGE_MAPPING = {
"English": "en",
"Chinese": "zh",
"French": "fr",
"German": "de",
"Hindi": "hi",
"Italian": "it",
"Japanese": "ja",
"Korean": "ko",
"Polish": "pl",
"Portuguese": "pt",
"Russian": "ru",
"Spanish": "es",
"Turkish": "tr"
}

class DialogueItem(BaseModel):
"""A single dialogue item."""
Expand Down Expand Up @@ -139,7 +154,7 @@ def generate_podcast(

# Get audio file path
audio_file_path = generate_audio(
line.text, line.speaker, language_mapping[language]
line.text, line.speaker, LANGUAGE_MAPPING[language]
)
# Read the audio file into an AudioSegment
audio_segment = AudioSegment.from_file(audio_file_path)
Expand Down Expand Up @@ -206,7 +221,7 @@ def generate_podcast(
value="Medium (3-5 min)"
),
gr.Dropdown(
choices=["English", "Spanish", "French", "Chinese", "Japanese", "Korean"],
choices=list(LANGUAGE_MAPPING.keys()),
value="English",
label="6. 🌐 Choose the language"
),
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ pypdf==4.1
sentry-sdk==2.5
spaces==0.30.2

tenacity==8.3
tenacity==8.3
git+https://github.com/suno-ai/bark.git
60 changes: 38 additions & 22 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

import os
import requests

from gradio_client import Client
from openai import OpenAI
from pydantic import ValidationError

from bark import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav

MODEL_ID = "accounts/fireworks/models/llama-v3p1-405b-instruct"
JINA_URL = "https://r.jina.ai/"

Expand All @@ -22,7 +24,10 @@
api_key=os.getenv("FIREWORKS_API_KEY"),
)

hf_client = Client("mrfakename/MeloTTS")
# hf_client = Client("mrfakename/MeloTTS")

# download and load all models
preload_models()


def generate_script(system_prompt: str, input_text: str, output_model):
Expand Down Expand Up @@ -73,23 +78,34 @@ def parse_url(url: str) -> str:
return response.text


def generate_audio(text: str, speaker: str, language: str) -> bytes:
"""Get the audio from the TTS model from HF Spaces and adjust pitch if necessary."""
if speaker == "Guest":
accent = "EN-US" if language == "EN" else language
speed = 0.9
else: # host
accent = "EN-Default" if language == "EN" else language
speed = 1
if language != "EN" and speaker != "Guest":
speed = 1.1

# Generate audio
result = hf_client.predict(
text=text,
language=language,
speaker=accent,
speed=speed,
api_name="/synthesize",
)
return result
def generate_audio(text: str, speaker: str, language: str) -> str:

audio_array = generate_audio(text, history_prompt=f"v2/{language}_speaker_{'1' if speaker == 'Host (Jane)' else '3'}")

file_path = f"audio_{language}_{speaker}.mp3"

# save audio to disk
write_wav(file_path, SAMPLE_RATE, audio_array)

return file_path


# """Get the audio from the TTS model from HF Spaces and adjust pitch if necessary."""
# if speaker == "Guest":
# accent = "EN-US" if language == "EN" else language
# speed = 0.9
# else: # host
# accent = "EN-Default" if language == "EN" else language
# speed = 1
# if language != "EN" and speaker != "Guest":
# speed = 1.1

# # Generate audio
# result = hf_client.predict(
# text=text,
# language=language,
# speaker=accent,
# speed=speed,
# api_name="/synthesize",
# )
# return result

0 comments on commit 8ddd281

Please sign in to comment.