Added preliminary support for llama.cpp
rmusser01 committed May 5, 2024
1 parent dc8c111 commit c38a329
Showing 3 changed files with 56 additions and 57 deletions.
8 changes: 7 additions & 1 deletion README.md
@@ -49,12 +49,18 @@ By default videos, transcriptions and summaries are stored in a folder with the
3. Run `python diarize.py <video_url>` or `python diarize.py <List_of_videos.txt>`
4. If you want summarization, add your API keys (currently required) to the `config.txt` file, then re-run the script, passing the name of the API [or URL endpoint - to be added] as an argument.
    * `python diarize.py https://www.youtube.com/watch?v=4nd1CDZP21s --api_name anthropic` - This will attempt to download the video, transcribe it, and then send the resulting transcript JSON to the Anthropic API for summarization, using the API key and model set in the config file. (A llama.cpp example is sketched after the model list below.)
- OpenAI:
- Anthropic:
* Opus: `claude-3-opus-20240229`
* Sonnet: `claude-3-sonnet-20240229`
* Haiku: `claude-3-haiku-20240307`
- Cohere:
* `command-r`
* `command-r-plus`
- OpenAI:
* `gpt-4-turbo`
* `gpt-4-turbo-preview`
* `gpt-4`
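
For illustration, the llama.cpp support added by this commit is invoked the same way, e.g. `python diarize.py https://www.youtube.com/watch?v=4nd1CDZP21s --api_name llama`; instead of a hosted API key, the transcript is sent to the llama.cpp server address configured as `llama_api_IP` in `config.txt`.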


### What's in the repo?
- `diarize.py` - download, transcribe and diarize audio
2 changes: 2 additions & 0 deletions config.txt
@@ -5,6 +5,8 @@ openai_api_key = <openai_api_key>
openai_model = gpt-4-turbo
cohere_api_key = <your_cohere_api_key>
cohere_model = command-r-plus
llama_api_key = <llama.cpp api key>
llama_api_IP = <full URL of the llama.cpp server completion endpoint>

[Paths]
output_path = Results
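For reference, a llama.cpp server started with its defaults listens on port 8080, and the new `summarize_with_llama` function (added to `diarize.py` below) posts its request directly to whatever URL is stored in `llama_api_IP`, so the value should be the full completion-endpoint URL. A filled-in example might look as follows (host and port are illustrative); note that `llama_api_key` is read from the config but not yet sent anywhere in this commit:

    llama_api_key = <any placeholder - not used by summarize_with_llama yet>
    llama_api_IP = http://127.0.0.1:8080/completion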
103 changes: 47 additions & 56 deletions diarize.py
@@ -65,11 +65,15 @@
cohere_api_key = config.get('API', 'cohere_api_key', fallback=None)
anthropic_api_key = config.get('API', 'anthropic_api_key', fallback=None)
openai_api_key = config.get('API', 'openai_api_key', fallback=None)
llama_api_key = config.get('API', 'llama_api_key', fallback=None)

# Models
anthropic_model = config.get('API', 'anthropic_model', fallback='claude-v1')
cohere_model = config.get('API', 'cohere_model', fallback='base_model')
openai_model = config.get('API', 'openai_model', fallback='gpt-3.5-turbo')
anthropic_model = config.get('API', 'anthropic_model', fallback='claude-3-sonnet-20240229')
cohere_model = config.get('API', 'cohere_model', fallback='command-r-plus')
openai_model = config.get('API', 'openai_model', fallback='gpt-4-turbo')

# Local-Models
llama_ip = config.get('API', 'llama_api_IP', fallback='http://127.0.0.1:8080/completion')  # full URL of the llama.cpp server's native completion endpoint

# Retrieve output paths from the configuration file
output_path = config.get('Paths', 'output_path', fallback='Results')
@@ -651,60 +655,7 @@ def summarize_with_claude(api_key, file_path, model):
except Exception as e:
print("Error occurred while processing summary with Claude:", str(e))
return None
"""
def summarize_with_claude(api_key, file_path, model):
try:
# Load your JSON data
with open(file_path, 'r') as file:
segments = json.load(file)
# Extract text from the segments
text = extract_text_from_segments(segments)

headers = {
'x-api-key': api_key,
'anthropic-version': '2023-06-01',
'Content-Type': 'application/json'
}
# Prepare the data for the Claude API
user_message = {
"role": "user",
"content": f"{text} \n\n\n\nPlease provide a detailed, bulleted list of the points made throughout the transcribed video and any supporting arguments made for said points"
}
data = {
"model": model,
"messages": [user_message],
"max_tokens": 4096, # max _possible_ tokens to return
"stop_sequences": ["\n\nHuman:"],
"temperature": 0.7,
"top_k": 0,
"top_p": 1.0,
"metadata": {
"user_id": "example_user_id",
},
"stream": False,
"system": "You are a professional summarizer."
}
response = requests.post('https://api.anthropic.com/v1/messages', headers=headers, json=data)
if response.status_code == 200:
if 'completion' in response.json():
summary = response.json()['completion'].strip()
print("Summary processed successfully.")
return summary
else:
print("Unexpected response format from Claude API:", response.text)
return None
else:
print("Failed to process summary:", response.text)
return None
except Exception as e:
print("Error occurred while processing summary with Claude:", str(e))
return None
"""


# Summarize with Cohere
@@ -746,6 +697,44 @@ def summarize_with_cohere(api_key, file_path, model):



def summarize_with_llama(api_url, file_path):
try:
# Load your JSON data
with open(file_path, 'r') as file:
segments = json.load(file)

# Extract text from the segments
text = extract_text_from_segments(segments)

# Prepare the data for the llama.cpp API
data = {
"prompt": f"{text} \n\n\n\nPlease provide a detailed, bulleted list of the points made throughout the transcribed video and any supporting arguments made for said points",
"max_tokens": 4096,
"stop": ["\n\nHuman:"],
"temperature": 0.7,
"top_k": 0,
"top_p": 1.0,
"repeat_penalty": 1.0,
"repeat_last_n": 64,
"seed": -1,
"threads": 4,
"n_predict": 4096
}

response = requests.post(api_url, json=data)

if response.status_code == 200:
summary = response.json()['content'].strip()
print("Summary processed successfully.")
return summary
else:
print("Failed to process summary:", response.text)
return None
except Exception as e:
print("Error occurred while processing summary with llama.cpp:", str(e))
return None
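
A minimal sketch of how the new function slots in alongside the existing helpers, assuming a llama.cpp server is reachable at the configured `llama_ip` and a transcript JSON already exists (the file name below is hypothetical):

```python
# Hypothetical wiring, mirroring what main() does for the hosted APIs:
transcript_json = "Results/example_video.segments.json"  # assumed pre-existing transcript
summary = summarize_with_llama(llama_ip, transcript_json)
if summary:
    # save_summary_to_file() swaps ".segments.json" for "_summary.txt"
    save_summary_to_file(summary, transcript_json)
```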


def save_summary_to_file(summary, file_path):
summary_file_path = file_path.replace('.segments.json', '_summary.txt')
with open(summary_file_path, 'w') as file:
@@ -809,6 +798,8 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
elif api_name.lower() == 'cohere':
api_key = cohere_api_key
summary = summarize_with_cohere(api_key, json_file_path, cohere_model)
elif api_name.lower() == 'llama':
            summary = summarize_with_llama(llama_ip, json_file_path)
else:
logging.warning(f"Unsupported API: {api_name}")
summary = None
