From 4fa859a89fc09d3a662628d0f5f709aed588d343 Mon Sep 17 00:00:00 2001 From: Robert Date: Sun, 20 Oct 2024 15:33:27 -0700 Subject: [PATCH 1/2] Keyword filtering + top-k results returned added --- .../Gradio_UI/RAG_QA_Chat_tab.py | 21 ++++++++++++------- App_Function_Libraries/RAG/RAG_Library_2.py | 8 +++---- Docs/Issues/ISSUES.md | 7 +++++++ 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py index 68c8a092..8da14328 100644 --- a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py +++ b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py @@ -66,14 +66,18 @@ def update_conversation_list(): ) existing_file = gr.Dropdown(label="Select Existing File", choices=[], interactive=True) file_page = gr.State(value=1) - with gr.Row(): - page_number = gr.Number(value=1, label="Page", precision=0) - page_size = gr.Number(value=20, label="Items per page", precision=0) - total_pages = gr.Number(label="Total Pages", interactive=False) with gr.Row(): prev_page_btn = gr.Button("Previous Page") next_page_btn = gr.Button("Next Page") page_info = gr.HTML("Page 1") + top_k_input = gr.Number(value=10, label="Maximum amount of results to use (Default: 10)", minimum=1, maximum=50, step=1, precision=0, interactive=True) + keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by)", visible=True) + use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True) + # with gr.Row(): + # page_number = gr.Number(value=1, label="Page", precision=0) + # page_size = gr.Number(value=20, label="Items per page", precision=0) + # total_pages = gr.Number(label="Total Pages", interactive=False) + search_query = gr.Textbox(label="Search Query", visible=False) search_button = gr.Button("Search", visible=False) @@ -119,7 +123,6 @@ def update_conversation_list(): label="Select API for RAG", value="OpenAI", ) - use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True) with gr.Row(): with gr.Column(scale=2): @@ -381,7 +384,8 @@ def rephrase_question(history, latest_question, api_choice): return rephrased_question.strip() def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_results, file_upload, - convert_to_text, keywords, api_choice, use_query_rewriting, state_value): + convert_to_text, keywords, api_choice, use_query_rewriting, state_value, + keywords_input, top_k_input): try: logging.info(f"Starting rag_qa_chat_wrapper with message: {message}") logging.info(f"Context source: {context_source}") @@ -417,7 +421,8 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ if context_source == "All Files in the Database": # Use the enhanced_rag_pipeline to search the entire database - context = enhanced_rag_pipeline(rephrased_question, api_choice) + context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords=keywords_input, + top_k=int(top_k_input)) logging.info(f"Using enhanced_rag_pipeline for database search") elif context_source == "Search Database": context = f"media_id:{search_results.split('(ID: ')[1][:-1]}" @@ -525,6 +530,8 @@ def clear_chat_history(): api_choice, use_query_rewriting, state, + keywords_input, + top_k_input ], outputs=[chatbot, msg, loading_indicator, state], ) diff --git a/App_Function_Libraries/RAG/RAG_Library_2.py b/App_Function_Libraries/RAG/RAG_Library_2.py index 131ed7f8..956523ea 100644 --- a/App_Function_Libraries/RAG/RAG_Library_2.py +++ b/App_Function_Libraries/RAG/RAG_Library_2.py @@ -115,9 +115,9 @@ # return {"error": "An unexpected error occurred", "details": str(e)} - # RAG Search with keyword filtering -def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, apply_re_ranking=True) -> Dict[str, Any]: +# FIXME - Update each called function to support modifiable top-k results +def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, top_k=10, apply_re_ranking=True) -> Dict[str, Any]: log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice}) start_time = time.time() try: @@ -175,8 +175,8 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, app # Update all_results based on reranking all_results = [all_results[result['id']] for result in reranked_results] - # Extract content from results (top 10) - context = "\n".join([result['content'] for result in all_results[:10]]) # Limit to top 10 results + # Extract content from results (top 10 by default) + context = "\n".join([result['content'] for result in all_results[:top_k]]) logging.debug(f"Context length: {len(context)}") logging.debug(f"Context: {context[:200]}") diff --git a/Docs/Issues/ISSUES.md b/Docs/Issues/ISSUES.md index c2290be7..2fee461e 100644 --- a/Docs/Issues/ISSUES.md +++ b/Docs/Issues/ISSUES.md @@ -19,3 +19,10 @@ Update model suggestions for RAG vs Chatting/General use Whisper pipeline https://huggingface.co/spaces/aadnk/faster-whisper-webui https://huggingface.co/spaces/zhang082799/openai-whisper-large-v3-turbo + + +Create Documentation for how this can help + https://stevenberlinjohnson.com/how-to-use-notebooklm-as-a-research-tool-6ad5c3a227cc?gi=9a0b63820ff0 + +Create a blog post + tldwproject.com \ No newline at end of file From 7eaccea588cccbad058ffd252eb326c2b8417eae Mon Sep 17 00:00:00 2001 From: Robert Date: Sun, 20 Oct 2024 16:17:22 -0700 Subject: [PATCH 2/2] And now we have re-ranking + top_k --- App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py | 7 ++++--- App_Function_Libraries/RAG/RAG_Library_2.py | 11 ++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py index 8da14328..6fe74be1 100644 --- a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py +++ b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py @@ -73,6 +73,7 @@ def update_conversation_list(): top_k_input = gr.Number(value=10, label="Maximum amount of results to use (Default: 10)", minimum=1, maximum=50, step=1, precision=0, interactive=True) keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by)", visible=True) use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True) + use_re_ranking = gr.Checkbox(label="Use Re-ranking", value=True) # with gr.Row(): # page_number = gr.Number(value=1, label="Page", precision=0) # page_size = gr.Number(value=20, label="Items per page", precision=0) @@ -385,7 +386,7 @@ def rephrase_question(history, latest_question, api_choice): def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_results, file_upload, convert_to_text, keywords, api_choice, use_query_rewriting, state_value, - keywords_input, top_k_input): + keywords_input, top_k_input, use_re_ranking): try: logging.info(f"Starting rag_qa_chat_wrapper with message: {message}") logging.info(f"Context source: {context_source}") @@ -421,8 +422,8 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ if context_source == "All Files in the Database": # Use the enhanced_rag_pipeline to search the entire database - context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords=keywords_input, - top_k=int(top_k_input)) + context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords_input, top_k_input, + use_re_ranking) logging.info(f"Using enhanced_rag_pipeline for database search") elif context_source == "Search Database": context = f"media_id:{search_results.split('(ID: ')[1][:-1]}" diff --git a/App_Function_Libraries/RAG/RAG_Library_2.py b/App_Function_Libraries/RAG/RAG_Library_2.py index 956523ea..9e3fad3b 100644 --- a/App_Function_Libraries/RAG/RAG_Library_2.py +++ b/App_Function_Libraries/RAG/RAG_Library_2.py @@ -336,14 +336,15 @@ def generate_answer(api_choice: str, context: str, query: str) -> str: logging.error(f"Error in generate_answer: {str(e)}") return "An error occurred while generating the answer." -def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: +def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_k=10) -> List[Dict[str, Any]]: log_counter("perform_vector_search_attempt") start_time = time.time() all_collections = chroma_client.list_collections() vector_results = [] try: for collection in all_collections: - collection_results = vector_search(collection.name, query, k=5) + k = top_k + collection_results = vector_search(collection.name, query, k) filtered_results = [ result for result in collection_results if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids @@ -358,11 +359,11 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> L logging.error(f"Error in perform_vector_search: {str(e)}") raise -def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: +def perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]: log_counter("perform_full_text_search_attempt") start_time = time.time() try: - fts_results = search_db(query, ["content"], "", page=1, results_per_page=5) + fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10) filtered_fts_results = [ { "content": result['content'], @@ -381,7 +382,7 @@ def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) - raise -def fetch_relevant_media_ids(keywords: List[str]) -> List[int]: +def fetch_relevant_media_ids(keywords: List[str], top_k=10) -> List[int]: log_counter("fetch_relevant_media_ids_attempt", labels={"keyword_count": len(keywords)}) start_time = time.time() relevant_ids = set()