diff --git a/App_Function_Libraries/RAG/RAG_Library_2.py b/App_Function_Libraries/RAG/RAG_Library_2.py index 9e3fad3b..10c8c5df 100644 --- a/App_Function_Libraries/RAG/RAG_Library_2.py +++ b/App_Function_Libraries/RAG/RAG_Library_2.py @@ -343,8 +343,7 @@ def perform_vector_search(query: str, relevant_media_ids: List[str] = None, top_ vector_results = [] try: for collection in all_collections: - k = top_k - collection_results = vector_search(collection.name, query, k) + collection_results = vector_search(collection.name, query, k=top_k) filtered_results = [ result for result in collection_results if relevant_media_ids is None or result['metadata'].get('media_id') in relevant_media_ids diff --git a/Tests/RAG/test_RAG_Library_2.py b/Tests/RAG/test_RAG_Library_2.py index d5c26491..d0dc7d05 100644 --- a/Tests/RAG/test_RAG_Library_2.py +++ b/Tests/RAG/test_RAG_Library_2.py @@ -115,7 +115,7 @@ def test_perform_vector_search_with_relevant_media_ids(self, mock_chroma_client, mock_chroma_client.list_collections.assert_called_once() # Assert vector_search was called with correct arguments - mock_vector_search.assert_called_once_with('collection1', query, k=5) + mock_vector_search.assert_called_once_with('collection1', query, k=10) @patch('App_Function_Libraries.RAG.RAG_Library_2.vector_search') @patch('App_Function_Libraries.RAG.RAG_Library_2.chroma_client') @@ -152,7 +152,7 @@ def test_perform_vector_search_without_relevant_media_ids(self, mock_chroma_clie mock_chroma_client.list_collections.assert_called_once() # Assert vector_search was called with correct arguments - mock_vector_search.assert_called_once_with('collection1', query, k=5) + mock_vector_search.assert_called_once_with('collection1', query, k=10) @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') def test_perform_full_text_search_with_relevant_media_ids(self, mock_search_db): @@ -171,7 +171,7 @@ def test_perform_full_text_search_with_relevant_media_ids(self, mock_search_db): relevant_media_ids = [1, 3] # Call the function - result = perform_full_text_search(query, relevant_media_ids) + result = perform_full_text_search(query, relevant_media_ids, fts_top_k=10) # Expected to filter out id 2 expected = [ @@ -182,7 +182,7 @@ def test_perform_full_text_search_with_relevant_media_ids(self, mock_search_db): # Assert search_db was called with correct arguments mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=5) + query, ['content'], '', page=1, results_per_page=10) @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') def test_perform_full_text_search_without_relevant_media_ids(self, mock_search_db): @@ -211,7 +211,7 @@ def test_perform_full_text_search_without_relevant_media_ids(self, mock_search_d # Assert search_db was called with correct arguments mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=5) + query, ['content'], '', page=1, results_per_page=10) @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') def test_perform_full_text_search_empty_results(self, mock_search_db): @@ -234,7 +234,7 @@ def test_perform_full_text_search_empty_results(self, mock_search_db): # Assert search_db was called with correct arguments mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=5) + query, ['content'], '', page=1, results_per_page=10) @patch('App_Function_Libraries.RAG.RAG_Library_2.fetch_keywords_for_media') @patch('App_Function_Libraries.RAG.RAG_Library_2.logging') @@ -344,7 +344,7 @@ def test_perform_full_text_search_case_insensitive_filtering(self, mock_search_d # Assert search_db was called with correct arguments mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=5) + query, ['content'], '', page=1, results_per_page=10) @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') def test_perform_full_text_search_multiple_pages(self, mock_search_db): @@ -380,7 +380,7 @@ def test_perform_full_text_search_multiple_pages(self, mock_search_db): # Assert search_db was called with correct arguments mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=5) + query, ['content'], '', page=1, results_per_page=10) @patch('App_Function_Libraries.RAG.RAG_Library_2.chroma_client') @patch('App_Function_Libraries.RAG.RAG_Library_2.vector_search') @@ -429,8 +429,8 @@ def vector_search_side_effect(collection_name, query, k): mock_chroma_client.list_collections.assert_called_once() # Assert vector_search was called twice with correct arguments - mock_vector_search.assert_any_call('collection1', query, k=5) - mock_vector_search.assert_any_call('collection2', query, k=5) + mock_vector_search.assert_any_call('collection1', query, k=10) + mock_vector_search.assert_any_call('collection2', query, k=10) self.assertEqual(mock_vector_search.call_count, 2) @patch('App_Function_Libraries.RAG.RAG_Library_2.search_db') @@ -462,7 +462,7 @@ def test_perform_full_text_search_partial_matches(self, mock_search_db): # Assert search_db was called with correct arguments mock_search_db.assert_called_once_with( - query, ['content'], '', page=1, results_per_page=5) + query, ['content'], '', page=1, results_per_page=10) if __name__ == '__main__':