#111 #132 #133 #134 Refactor TTS module; no longer generate TTS markup in transcript
souzatharsis committed Oct 31, 2024
1 parent fa3799c commit cfa5155
Showing 17 changed files with 1,027 additions and 707 deletions.
27 changes: 20 additions & 7 deletions podcastfy.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion podcastfy/__init__.py
@@ -1,2 +1,2 @@
 # This file can be left empty for now
-__version__ = "0.2.15" # or whatever version you're on
+__version__ = "0.2.16" # or whatever version you're on
43 changes: 20 additions & 23 deletions podcastfy/client.py
@@ -47,26 +47,25 @@ def process_content(
     try:
         if config is None:
             config = load_config()

         # Load default conversation config
         conv_config = load_conversation_config()

         # Update with provided config if any
         if conversation_config:
             conv_config.configure(conversation_config)

         # Get output directories from conversation config
-        tts_config = conv_config.get('text_to_speech', {})
-        output_directories = tts_config.get('output_directories', {})
+        tts_config = conv_config.get("text_to_speech", {})
+        output_directories = tts_config.get("output_directories", {})

         if transcript_file:
             logger.info(f"Using transcript file: {transcript_file}")
             with open(transcript_file, "r") as file:
                 qa_content = file.read()
         else:
             content_generator = ContentGenerator(
-                api_key=config.GEMINI_API_KEY,
-                conversation_config=conv_config.to_dict()
+                api_key=config.GEMINI_API_KEY, conversation_config=conv_config.to_dict()
             )

             combined_content = ""
@@ -83,8 +82,8 @@ def process_content(
             # Generate Q&A content using output directory from conversation config
             random_filename = f"transcript_{uuid.uuid4().hex}.txt"
             transcript_filepath = os.path.join(
-                output_directories.get("transcripts", "data/transcripts"),
-                random_filename
+                output_directories.get("transcripts", "data/transcripts"),
+                random_filename,
             )
             qa_content = content_generator.generate_qa_content(
                 combined_content,
@@ -99,15 +98,14 @@ def process_content(
             api_key = getattr(config, f"{tts_model.upper()}_API_KEY")

             text_to_speech = TextToSpeech(
-                model=tts_model,
-                api_key=api_key,
-                conversation_config=conv_config.to_dict()
+                model=tts_model,
+                api_key=api_key,
+                conversation_config=conv_config.to_dict(),
             )

             random_filename = f"podcast_{uuid.uuid4().hex}.mp3"
             audio_file = os.path.join(
-                output_directories.get("audio", "data/audio"),
-                random_filename
+                output_directories.get("audio", "data/audio"), random_filename
             )
             text_to_speech.convert_to_speech(qa_content, audio_file)
             logger.info(f"Podcast generated successfully using {tts_model} TTS model")
@@ -120,6 +118,7 @@ def process_content(
         logger.error(f"An error occurred in the process_content function: {str(e)}")
         raise

+
 @app.command()
 def main(
     urls: list[str] = typer.Option(None, "--url", "-u", help="URLs to process"),
@@ -130,7 +129,7 @@ def main(
         None, "--transcript", "-t", help="Path to a transcript file"
     ),
     tts_model: str = typer.Option(
-        "openai",
+        None,
         "--tts-model",
         "-tts",
         help="TTS model to use (openai, elevenlabs or edge)",
@@ -169,14 +168,12 @@ def main(
     if conversation_config_path:
         with open(conversation_config_path, "r") as f:
             conversation_config: Dict[str, Any] | None = yaml.safe_load(f)

-
-
     # Use default TTS model from conversation config if not specified
     if tts_model is None:
-        tts_config = load_conversation_config().get('text_to_speech', {})
-        tts_model = tts_config.get('default_tts_model', 'openai')
+        tts_config = load_conversation_config().get("text_to_speech", {})
+        tts_model = tts_config.get("default_tts_model", "openai")

     if transcript:
         if image_paths:
             logger.warning("Image paths are ignored when using a transcript file.")
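Taken together, the changes in this file (the `--tts-model` flag here and, below, the `generate_podcast` parameter now default to `None`) mean the conversation config's `default_tts_model` is honored unless the caller overrides it. A minimal sketch of the resulting resolution order — `resolve_tts_model` is a hypothetical helper for illustration, not part of this commit:

```python
from typing import Any, Dict, Optional

def resolve_tts_model(cli_value: Optional[str], conv_config: Dict[str, Any]) -> str:
    # 1. An explicit --tts-model / tts_model argument always wins.
    if cli_value is not None:
        return cli_value
    # 2. Otherwise defer to the conversation config...
    tts_config = conv_config.get("text_to_speech", {})
    # 3. ...falling back to "openai" if the config is silent.
    return tts_config.get("default_tts_model", "openai")

assert resolve_tts_model("elevenlabs", {}) == "elevenlabs"
assert resolve_tts_model(None, {"text_to_speech": {"default_tts_model": "edge"}}) == "edge"
assert resolve_tts_model(None, {}) == "openai"
```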
@@ -230,7 +227,7 @@ def generate_podcast(
     urls: Optional[List[str]] = None,
     url_file: Optional[str] = None,
     transcript_file: Optional[str] = None,
-    tts_model: Optional[str] = 'openai',
+    tts_model: Optional[str] = None,
     transcript_only: bool = False,
     config: Optional[Dict[str, Any]] = None,
     conversation_config: Optional[Dict[str, Any]] = None,
@@ -315,7 +312,7 @@ def generate_podcast(
             conversation_config=conversation_config,
             image_paths=image_paths,
             is_local=is_local,
-            text=text
+            text=text,
         )

     except Exception as e:
4 changes: 2 additions & 2 deletions podcastfy/config.yaml
@@ -1,8 +1,8 @@
 content_generator:
   gemini_model: "gemini-1.5-pro-latest"
   max_output_tokens: 8192
-  prompt_template: "souzatharsis/podcastfy_multimodal"
-  prompt_commit: "c67bea9c"
+  prompt_template: "podcastfy_multimodal_cleanmarkup"
+  prompt_commit: "3d5b42fc"
 content_extractor:
   youtube_url_patterns:
     - "youtube.com"
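For context, `content_generator.py` (next file) joins these two keys into a single LangChain Hub reference, pinning the prompt to a specific commit. A rough sketch of that assembly, using the values from this diff:

```python
from langchain import hub

# "template:commit" is the ref shape the code below builds; values taken from this diff.
prompt_ref = "podcastfy_multimodal_cleanmarkup" + ":" + "3d5b42fc"
prompt_template = hub.pull(prompt_ref)  # network call; requires hub access
```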
106 changes: 80 additions & 26 deletions podcastfy/content_generator.py
@@ -8,17 +8,17 @@

 import os
 from typing import Optional, Dict, Any, List
+import re

 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_community.llms.llamafile import Llamafile
-from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain import hub
 from podcastfy.utils.config_conversation import load_conversation_config
 from podcastfy.utils.config import load_config
 import logging
-from langchain.prompts import HumanMessagePromptTemplate, SystemMessagePromptTemplate
-from langchain.schema import SystemMessage
+from langchain.prompts import HumanMessagePromptTemplate

 logger = logging.getLogger(__name__)

@@ -29,7 +29,7 @@ def __init__(
         is_local: bool,
         temperature: float,
         max_output_tokens: int,
-        model_name: str
+        model_name: str,
     ):
         """
         Initialize the LLMBackend.
@@ -72,14 +72,14 @@ def __init__(
         self.content_generator_config = self.config.get("content_generator", {})

         self.config_conversation = load_conversation_config(conversation_config)
-        self.tts_config = self.config_conversation.get('text_to_speech', {})
+        self.tts_config = self.config_conversation.get("text_to_speech", {})

         # Get output directories from conversation config
-        self.output_directories = self.tts_config.get('output_directories', {})
+        self.output_directories = self.tts_config.get("output_directories", {})

         # Create output directories if they don't exist
-        transcripts_dir = self.output_directories.get('transcripts')
+        transcripts_dir = self.output_directories.get("transcripts")

         if transcripts_dir and not os.path.exists(transcripts_dir):
             os.makedirs(transcripts_dir)

@@ -90,18 +90,21 @@ def __compose_prompt(self, num_images: int):
         prompt_template = hub.pull(
             self.config.get("content_generator", {}).get(
                 "prompt_template", "souzatharsis/podcastfy_multimodal"
-            ) + ":" + self.config.get("content_generator", {}).get(
-                "prompt_commit", "c67bea9c"
             )
+            + ":"
+            + self.config.get("content_generator", {}).get("prompt_commit", "3d5b42fc")
         )

         image_path_keys = []
         messages = []

         # Only add text content if input_text is not empty
-        text_content = {"type": "text", "text": "Please analyze this input and generate a conversation. {input_text}"}
+        text_content = {
+            "type": "text",
+            "text": "Please analyze this input and generate a conversation. {input_text}",
+        }
         messages.append(text_content)

         for i in range(num_images):
             key = f"image_path_{i}"
             image_content = {
@@ -115,19 +118,28 @@
             messages=[HumanMessagePromptTemplate.from_template(messages)]
         )
         user_instructions = self.config_conversation.get("user_instructions", "")

-        user_instructions = "[[MAKE SURE TO FOLLOW THESE INSTRUCTIONS OVERRIDING THE PROMPT TEMPLATE IN CASE OF CONFLICT: " + user_instructions + "]]"
-
-        new_system_message = prompt_template.messages[0].prompt.template + "\n" + user_instructions
+        user_instructions = (
+            "[[MAKE SURE TO FOLLOW THESE INSTRUCTIONS OVERRIDING THE PROMPT TEMPLATE IN CASE OF CONFLICT: "
+            + user_instructions
+            + "]]"
+        )
+
+        new_system_message = (
+            prompt_template.messages[0].prompt.template + "\n" + user_instructions
+        )

         # Create new prompt with updated system message
-        #prompt_template = ChatPromptTemplate.from_messages([
+        # prompt_template = ChatPromptTemplate.from_messages([
         #    SystemMessagePromptTemplate.from_template(new_system_message),
         #    HumanMessagePromptTemplate.from_template(messages)
-        #])
+        # ])

         # Compose messages from podcastfy_prompt_template and user_prompt_template
-        combined_messages = ChatPromptTemplate.from_messages([new_system_message]).messages + user_prompt_template.messages
+        combined_messages = (
+            ChatPromptTemplate.from_messages([new_system_message]).messages
+            + user_prompt_template.messages
+        )

         # Create a new ChatPromptTemplate object with the combined messages
         composed_prompt_template = ChatPromptTemplate.from_messages(combined_messages)
@@ -162,6 +174,43 @@ def __compose_prompt_params(

         return prompt_params

+    def __clean_scratchpad(self, text: str) -> str:
+        """
+        Remove scratchpad blocks from the text.
+
+        Args:
+            text (str): Input text that may contain scratchpad blocks
+
+        Returns:
+            str: Text with scratchpad blocks removed
+
+        Example:
+            Input: '<Person1> (scratchpad)\n```\nSome notes\n```\nActual content</Person1>'
+            Output: '<Person1>Actual content</Person1>'
+        """
+        try:
+            # Pattern to match scratchpad blocks:
+            # 1. Optional whitespace
+            # 2. (scratchpad) marker
+            # 3. Optional whitespace
+            # 4. Code block with any content
+            # 5. Optional whitespace before next content
+            pattern = r"\s*\(scratchpad\)\s*```.*?```\s*"
+
+            # Remove scratchpad blocks using regex
+            cleaned_text = re.sub(pattern, "", text, flags=re.DOTALL)
+
+            # Remove any resulting double newlines
+            cleaned_text = re.sub(r"\n\s*\n", "\n", cleaned_text)
+
+            # Remove any standalone (scratchpad) markers that might remain
+            cleaned_text = cleaned_text.replace("(scratchpad)", "")
+
+            return cleaned_text.strip()
+        except Exception as e:
+            logger.error(f"Error cleaning scratchpad content: {str(e)}")
+            return text  # Return original text if cleaning fails
+
     def generate_qa_content(
         self,
         input_texts: str = "",
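To see what the new `__clean_scratchpad` does end to end, here is a self-contained sketch that mirrors its regex on a toy transcript line (a free-function copy for illustration, not the method itself):

````python
import re

def clean_scratchpad(text: str) -> str:
    # Mirror of __clean_scratchpad above: drop "(scratchpad)" markers plus their
    # fenced notes block...
    text = re.sub(r"\s*\(scratchpad\)\s*```.*?```\s*", "", text, flags=re.DOTALL)
    # ...collapse leftover blank lines, then remove any stray markers.
    text = re.sub(r"\n\s*\n", "\n", text)
    return text.replace("(scratchpad)", "").strip()

raw = "<Person1> (scratchpad)\n```\nplan the segue\n```\nWelcome to the show!</Person1>"
print(clean_scratchpad(raw))  # -> <Person1>Welcome to the show!</Person1>
````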
@@ -175,11 +224,11 @@ def generate_qa_content(
         Args:
             input_texts (str): Input texts to generate content from.
             image_file_paths (List[str]): List of image file paths.
-            output_filepath (Optional[str]): Filepath to save the response content. Defaults to None.
-            is_local (bool): Whether to use a local LLM or not. Defaults to False.
+            output_filepath (Optional[str]): Filepath to save the response content.
+            is_local (bool): Whether to use a local LLM or not.

         Returns:
-            str: Formatted Q&A content.
+            str: Formatted Q&A content with scratchpad blocks removed.

         Raises:
             Exception: If there's an error in generating content.
@@ -197,7 +246,7 @@
                 )
                 if not is_local
                 else "User provided model"
-            )
+            ),
         )

         num_images = 0 if is_local else len(image_file_paths)
@@ -209,7 +258,12 @@
                 image_file_paths, image_path_keys, input_texts
             )

-            self.response = self.chain.invoke(prompt_params)
+            response_raw = self.chain.invoke(
+                prompt_params
+            )  # in the future, make sure we have structured output
+
+            # Clean up scratchpad blocks from response
+            self.response = self.__clean_scratchpad(response_raw)

             logger.info(f"Content generated successfully")

0 comments on commit cfa5155
