Skip to content

Commit

Permalink
And now we have re-ranking + top_k
Browse files (browse the repository at this point in the history)
  • Loading branch information
rmusser01 committed Oct 20, 2024
1 parent 4fa859a commit 7eaccea
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
7 changes: 4 additions & 3 deletions App_Function_Libraries/Gradio_UI/RAG_QA_Chat_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def update_conversation_list():
top_k_input = gr.Number(value=10, label="Maximum amount of results to use (Default: 10)", minimum=1, maximum=50, step=1, precision=0, interactive=True)
keywords_input = gr.Textbox(label="Keywords (comma-separated) to filter results by)", visible=True)
use_query_rewriting = gr.Checkbox(label="Use Query Rewriting", value=True)
use_re_ranking = gr.Checkbox(label="Use Re-ranking", value=True)
# with gr.Row():
# page_number = gr.Number(value=1, label="Page", precision=0)
# page_size = gr.Number(value=20, label="Items per page", precision=0)
Expand Down Expand Up @@ -385,7 +386,7 @@ def rephrase_question(history, latest_question, api_choice):

def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_results, file_upload,
convert_to_text, keywords, api_choice, use_query_rewriting, state_value,
keywords_input, top_k_input):
keywords_input, top_k_input, use_re_ranking):
try:
logging.info(f"Starting rag_qa_chat_wrapper with message: {message}")
logging.info(f"Context source: {context_source}")
Expand Down Expand Up @@ -421,8 +422,8 @@ def rag_qa_chat_wrapper(message, history, context_source, existing_file, search_

if context_source == "All Files in the Database":
# Use the enhanced_rag_pipeline to search the entire database
context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords=keywords_input,
top_k=int(top_k_input))
context = enhanced_rag_pipeline(rephrased_question, api_choice, keywords_input, top_k_input,
use_re_ranking)
logging.info(f"Using enhanced_rag_pipeline for database search")
elif context_source == "Search Database":
context = f"media_id:{search_results.split('(ID: ')[1][:-1]}"
Expand Down
11 changes: 6 additions & 5 deletions App_Function_Libraries/RAG/RAG_Library_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,14 +336,15 @@ def generate_answer(api_choice: str, context: str, query: str) -> str:
logging.error(f"Error in generate_answer: {str(e)}")
return "An error occurred while generating the answer."

def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_k=10) -> List[Dict[str, Any]]:
log_counter("perform_vector_search_attempt")
start_time = time.time()
all_collections = chroma_client.list_collections()
vector_results = []
try:
for collection in all_collections:
collection_results = vector_search(collection.name, query, k=5)
k = top_k
collection_results = vector_search(collection.name, query, k)
filtered_results = [
result for result in collection_results
if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids
Expand All @@ -358,11 +359,11 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None) -> L
logging.error(f"Error in perform_vector_search: {str(e)}")
raise

def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) -> List[Dict[str, Any]]:
def perform_full_text_search(query: str, relevant_media_ids: List[str] = None, fts_top_k=None) -> List[Dict[str, Any]]:
log_counter("perform_full_text_search_attempt")
start_time = time.time()
try:
fts_results = search_db(query, ["content"], "", page=1, results_per_page=5)
fts_results = search_db(query, ["content"], "", page=1, results_per_page=fts_top_k or 10)
filtered_fts_results = [
{
"content": result['content'],
Expand All @@ -381,7 +382,7 @@ def perform_full_text_search(query: str, relevant_media_ids: List[str] = None) -
raise


def fetch_relevant_media_ids(keywords: List[str]) -> List[int]:
def fetch_relevant_media_ids(keywords: List[str], top_k=10) -> List[int]:
log_counter("fetch_relevant_media_ids_attempt", labels={"keyword_count": len(keywords)})
start_time = time.time()
relevant_ids = set()
Expand Down

0 comments on commit 7eaccea

Please sign in to comment.