Skip to content

Commit

Permalink
Merge pull request #379 from rmusser01/dev
Browse files Browse the repository at this point in the history
Keyword filter, top-k and re-ranking added to RAG chat UI
  • Loading branch information
rmusser01 authored Oct 20, 2024
2 parents 3d8b3d4 + 7eaccea commit cf62387
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
22 changes: 15 additions & 7 deletions App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,19 @@ def update_conversation_list():
)
existing_file = gr.Dropdown(label="Select Existing File", choices=[], interactive=True)
file_page = gr.State(value=1)
with gr.Row():
page_number = gr.Number(value=1, label="Page", precision=0)
page_size = gr.Number(value=20, label="Items per page", precision=0)
total_pages = gr.Number(label="Total Pages", interactive=False)
with gr.Row():
prev_page_btn = gr.Button("Previous Page")
next_page_btn = gr.Button("Next Page")
page_info = gr.HTML("Page 1")
top_k_input = gr.Number(value=10, label="Maximum amount of results to use (Default: 10)", minimum=1, maximum=50, step=1, precision=0, interactive=True)
keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by)", visible=True)
use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True)
use_re_ranking = gr.Checkbox(label="Use Re-ranking", value=True)
# with gr.Row():
# page_number = gr.Number(value=1, label="Page", precision=0)
# page_size = gr.Number(value=20, label="Items per page", precision=0)
# total_pages = gr.Number(label="Total Pages", interactive=False)


search_query = gr.Textbox(label="Search Query", visible=False)
search_button = gr.Button("Search", visible=False)
Expand Down Expand Up @@ -119,7 +124,6 @@ def update_conversation_list():
label="Select API for RAG",
value="OpenAI",
)
use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True)

with gr.Row():
with gr.Column(scale=2):
Expand Down Expand Up @@ -381,7 +385,8 @@ def rephrase_question(history, latest_question, api_choice):
return rephrased_question.strip()

def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_results, file_upload,
convert_to_text, keywords, api_choice, use_query_rewriting, state_value):
convert_to_text, keywords, api_choice, use_query_rewriting, state_value,
keywords_input, top_k_input, use_re_ranking):
try:
logging.info(f"Starting rag_qa_chat_wrapper with message: {message}")
logging.info(f"Context source: {context_source}")
Expand Down Expand Up @@ -417,7 +422,8 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_

if context_source == "All Files in the Database":
# Use the enhanced_rag_pipeline to search the entire database
context = enhanced_rag_pipeline(rephrased_question, api_choice)
context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords_input, top_k_input,
use_re_ranking)
logging.info(f"Using enhanced_rag_pipeline for database search")
elif context_source == "Search Database":
context = f"media_id:{search_results.split('(ID: ')[1][:-1]}"
Expand Down Expand Up @@ -525,6 +531,8 @@ def clear_chat_history():
api_choice,
use_query_rewriting,
state,
keywords_input,
top_k_input
],
outputs=[chatbot, msg, loading_indicator, state],
)
Expand Down
19 changes: 10 additions & 9 deletions App_Function_Libraries/RAG/RAG_Library_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@
# return {"error": "An unexpected error occurred", "details": str(e)}



# RAG Search with keyword filtering
def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, apply_re_ranking=True) -> Dict[str, Any]:
# FIXME - Update each called function to support modifiable top-k results
def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top_k=10, apply_re_ranking=True) -> Dict[str, Any]:
log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice})
start_time = time.time()
try:
Expand Down Expand Up @@ -175,8 +175,8 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, app
# Update all_results based on reranking
all_results = [all_results[result['id']] for result in reranked_results]

# Extract content from results (top 10)
context = "\n".join([result['content'] for result in all_results[:10]]) # Limit to top 10 results
# Extract content from results (top 10 by default)
context = "\n".join([result['content'] for result in all_results[:top_k]])
logging.debug(f"Context length: {len(context)}")
logging.debug(f"Context: {context[:200]}")

Expand Down Expand Up @@ -336,14 +336,15 @@ def generate_answer(api_choice: str, context: str, query: str) -> str:
logging.error(f"Error in generate_answer: {str(e)}")
return "An error occurred while generating the answer."

def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_k=10) -> List[Dict[str, Any]]:
log_counter("perform_vector_search_attempt")
start_time = time.time()
all_collections = chroma_client.list_collections()
vector_results = []
try:
for collection in all_collections:
collection_results = vector_search(collection.name, query, k=5)
k = top_k
collection_results = vector_search(collection.name, query, k)
filtered_results = [
result for result in collection_results
if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids
Expand All @@ -358,11 +359,11 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> L
logging.error(f"Error in perform_vector_search: {str(e)}")
raise

def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
def perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]:
log_counter("perform_full_text_search_attempt")
start_time = time.time()
try:
fts_results = search_db(query, ["content"], "", page=1, results_per_page=5)
fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10)
filtered_fts_results = [
{
"content": result['content'],
Expand All @@ -381,7 +382,7 @@ def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) -
raise


def fetch_relevant_media_ids(keywords: List[str]) -> List[int]:
def fetch_relevant_media_ids(keywords: List[str], top_k=10) -> List[int]:
log_counter("fetch_relevant_media_ids_attempt", labels={"keyword_count": len(keywords)})
start_time = time.time()
relevant_ids = set()
Expand Down
7 changes: 7 additions & 0 deletions Docs/Issues/ISSUES.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,10 @@ Update model suggestions for RAG vs Chatting/General use
Whisper pipeline
https://huggingface.co/spaces/aadnk/faster-whisper-webui
https://huggingface.co/spaces/zhang082799/openai-whisper-large-v3-turbo


Create documentation for how this project can help:
https://stevenberlinjohnson.com/how-to-use-notebooklm-as-a-research-tool-6ad5c3a227cc?gi=9a0b63820ff0

Create a blog post
tldwproject.com

0 comments on commit cf62387

Please sign in to comment.