diff --git a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py index 8da14328..6fe74be1 100644 --- a/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py +++ b/App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py @@ -73,6 +73,7 @@ def update_conversation_list(): top_k_input = gr.Number(value=10, label="Maximum amount of results to use (Default: 10)", minimum=1, maximum=50, step=1, precision=0, interactive=True) keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by)", visible=True) use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True) + use_re_ranking = gr.Checkbox(label="Use Re-ranking", value=True) # with gr.Row(): # page_number = gr.Number(value=1, label="Page", precision=0) # page_size = gr.Number(value=20, label="Items per page", precision=0) @@ -385,7 +386,7 @@ def rephrase_question(history, latest_question, api_choice): def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_results, file_upload, convert_to_text, keywords, api_choice, use_query_rewriting, state_value, - keywords_input, top_k_input): + keywords_input, top_k_input, use_re_ranking): try: logging.info(f"Starting rag_qa_chat_wrapper with message: {message}") logging.info(f"Context source: {context_source}") @@ -421,8 +422,8 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_ if context_source == "All Files in the Database": # Use the enhanced_rag_pipeline to search the entire database - context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords=keywords_input, - top_k=int(top_k_input)) + context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords_input, top_k_input, + use_re_ranking) logging.info(f"Using enhanced_rag_pipeline for database search") elif context_source == "Search Database": context = f"media_id:{search_results.split('(ID: ')[1][:-1]}" diff --git a/App_Function_Libraries/RAG/RAG_Library_2.py b/App_Function_Libraries/RAG/RAG_Library_2.py index 956523ea..9e3fad3b 100644 --- a/App_Function_Libraries/RAG/RAG_Library_2.py +++ b/App_Function_Libraries/RAG/RAG_Library_2.py @@ -336,14 +336,15 @@ def generate_answer(api_choice: str, context: str, query: str) -> str: logging.error(f"Error in generate_answer: {str(e)}") return "An error occurred while generating the answer." -def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: +def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_k=10) -> List[Dict[str, Any]]: log_counter("perform_vector_search_attempt") start_time = time.time() all_collections = chroma_client.list_collections() vector_results = [] try: for collection in all_collections: - collection_results = vector_search(collection.name, query, k=5) + k = top_k + collection_results = vector_search(collection.name, query, k) filtered_results = [ result for result in collection_results if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids @@ -358,11 +359,11 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> L logging.error(f"Error in perform_vector_search: {str(e)}") raise -def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]: +def perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]: log_counter("perform_full_text_search_attempt") start_time = time.time() try: - fts_results = search_db(query, ["content"], "", page=1, results_per_page=5) + fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10) filtered_fts_results = [ { "content": result['content'], @@ -381,7 +382,7 @@ def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) - raise -def fetch_relevant_media_ids(keywords: List[str]) -> List[int]: +def fetch_relevant_media_ids(keywords: List[str], top_k=10) -> List[int]: log_counter("fetch_relevant_media_ids_attempt", labels={"keyword_count": len(keywords)}) start_time = time.time() relevant_ids = set()