From b8b949f3bb7efbc7aefcc8df75ae73a174560358 Mon Sep 17 00:00:00 2001 From: Josh Bradley Date: Thu, 13 Feb 2025 16:31:08 -0500 Subject: [PATCH] Cleanup query api - remove code duplication (#1690) * consolidate query api functions and remove code duplication * refactor and remove more code duplication * Add semversioner file * fix basic search * fix drift search and update base class function names * update example notebooks --- .../patch-20250210204532206223.json | 4 + docs/examples_notebooks/drift_search.ipynb | 4 +- docs/examples_notebooks/global_search.ipynb | 4 +- ...rch_with_dynamic_community_selection.ipynb | 4 +- docs/examples_notebooks/local_search.ipynb | 8 +- .../graph-visualization.ipynb | 4 +- graphrag/api/query.py | 225 +++++++----------- graphrag/query/structured_search/base.py | 19 +- .../structured_search/basic_search/search.py | 72 +----- .../structured_search/drift_search/action.py | 4 +- .../structured_search/drift_search/primer.py | 2 +- .../structured_search/drift_search/search.py | 36 +-- .../structured_search/global_search/search.py | 17 +- .../structured_search/local_search/search.py | 70 +----- 14 files changed, 130 insertions(+), 343 deletions(-) create mode 100644 .semversioner/next-release/patch-20250210204532206223.json diff --git a/.semversioner/next-release/patch-20250210204532206223.json b/.semversioner/next-release/patch-20250210204532206223.json new file mode 100644 index 0000000000..fcb6ad6eb2 --- /dev/null +++ b/.semversioner/next-release/patch-20250210204532206223.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "cleanup query code duplication." +} diff --git a/docs/examples_notebooks/drift_search.ipynb b/docs/examples_notebooks/drift_search.ipynb index 9f40eb81ba..6bd5df4684 100644 --- a/docs/examples_notebooks/drift_search.ipynb +++ b/docs/examples_notebooks/drift_search.ipynb @@ -327,7 +327,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -388,7 +388,7 @@ } ], "source": [ - "resp = await search.asearch(\"Who is agent Mercer?\")" + "resp = await search.search(\"Who is agent Mercer?\")" ] }, { diff --git a/docs/examples_notebooks/global_search.ipynb b/docs/examples_notebooks/global_search.ipynb index 1beff6b3ab..2250292444 100644 --- a/docs/examples_notebooks/global_search.ipynb +++ b/docs/examples_notebooks/global_search.ipynb @@ -392,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -420,7 +420,7 @@ } ], "source": [ - "result = await search_engine.asearch(\n", + "result = await search_engine.search(\n", " \"What is Cosmic Vocalization and who are involved in it?\"\n", ")\n", "\n", diff --git a/docs/examples_notebooks/global_search_with_dynamic_community_selection.ipynb b/docs/examples_notebooks/global_search_with_dynamic_community_selection.ipynb index 589dcd976f..fcacc426c4 100644 --- a/docs/examples_notebooks/global_search_with_dynamic_community_selection.ipynb +++ b/docs/examples_notebooks/global_search_with_dynamic_community_selection.ipynb @@ -394,7 +394,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -420,7 +420,7 @@ } ], "source": [ - "result = await search_engine.asearch(\n", + "result = await search_engine.search(\n", " \"What is Cosmic Vocalization and who are involved in it?\"\n", ")\n", "\n", diff --git a/docs/examples_notebooks/local_search.ipynb b/docs/examples_notebooks/local_search.ipynb index 7c7600c92a..1601032027 
100644 --- a/docs/examples_notebooks/local_search.ipynb +++ b/docs/examples_notebooks/local_search.ipynb @@ -963,7 +963,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -991,13 +991,13 @@ } ], "source": [ - "result = await search_engine.asearch(\"Tell me about Agent Mercer\")\n", + "result = await search_engine.search(\"Tell me about Agent Mercer\")\n", "print(result.response)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1026,7 +1026,7 @@ ], "source": [ "question = \"Tell me about Dr. Jordan Hayes\"\n", - "result = await search_engine.asearch(question)\n", + "result = await search_engine.search(question)\n", "print(result.response)" ] }, diff --git a/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb b/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb index 6e771d4a8a..bbf000f36e 100644 --- a/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb +++ b/examples_notebooks/community_contrib/yfiles-jupyter-graphs/graph-visualization.ipynb @@ -384,7 +384,7 @@ "metadata": {}, "outputs": [], "source": [ - "result = await search_engine.asearch(\"Tell me about Agent Mercer\")\n", + "result = await search_engine.search(\"Tell me about Agent Mercer\")\n", "print(result.response)" ] }, @@ -395,7 +395,7 @@ "outputs": [], "source": [ "question = \"Tell me about Dr. Jordan Hayes\"\n", - "result = await search_engine.asearch(question)\n", + "result = await search_engine.search(question)\n", "print(result.response)" ] }, diff --git a/graphrag/api/query.py b/graphrag/api/query.py index f9c3ffcae4..8f2f539fd2 100644 --- a/graphrag/api/query.py +++ b/graphrag/api/query.py @@ -18,7 +18,7 @@ """ from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Any +from typing import Any import pandas as pd from pydantic import validate_call @@ -53,9 +53,6 @@ ) from graphrag.utils.cli import redact -if TYPE_CHECKING: - from graphrag.query.structured_search.base import SearchResult - logger = PrintProgressLogger("") @@ -94,40 +91,27 @@ async def global_search( ------ TODO: Document any exceptions to expect. """ - communities_ = read_indexer_communities(communities, community_reports) - reports = read_indexer_reports( - community_reports, - communities, + full_response = "" + context_data = {} + get_context_data = True + # NOTE: when streaming, the first chunk of returned data is the complete context data. + # All subsequent chunks are the query response. 
+ async for chunk in global_search_streaming( + config=config, + entities=entities, + communities=communities, + community_reports=community_reports, community_level=community_level, dynamic_community_selection=dynamic_community_selection, - ) - entities_ = read_indexer_entities( - entities, communities, community_level=community_level - ) - - map_prompt = load_search_prompt(config.root_dir, config.global_search.map_prompt) - reduce_prompt = load_search_prompt( - config.root_dir, config.global_search.reduce_prompt - ) - knowledge_prompt = load_search_prompt( - config.root_dir, config.global_search.knowledge_prompt - ) - - search_engine = get_global_search_engine( - config, - reports=reports, - entities=entities_, - communities=communities_, response_type=response_type, - dynamic_community_selection=dynamic_community_selection, - map_system_prompt=map_prompt, - reduce_system_prompt=reduce_prompt, - general_knowledge_inclusion_prompt=knowledge_prompt, - ) - result: SearchResult = await search_engine.asearch(query=query) - response = result.response - context_data = reformat_context_data(result.context_data) # type: ignore - return response, context_data + query=query, + ): + if get_context_data: + context_data = chunk + get_context_data = False + else: + full_response += chunk + return full_response, context_data @validate_call(config={"arbitrary_types_allowed": True}) @@ -193,11 +177,11 @@ async def global_search_streaming( reduce_system_prompt=reduce_prompt, general_knowledge_inclusion_prompt=knowledge_prompt, ) - search_result = search_engine.astream_search(query=query) + search_result = search_engine.stream_search(query=query) - # when streaming results, a context data object is returned as the first result + # NOTE: when streaming results, a context data object is returned as the first result # and the query response in subsequent tokens - context_data = None + context_data = {} get_context_data = True async for stream_chunk in search_result: if get_context_data: @@ -385,34 +369,29 @@ async def local_search( ------ TODO: Document any exceptions to expect. """ - vector_store_args = {} - for index, store in config.vector_store.items(): - vector_store_args[index] = store.model_dump() - logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa - - description_embedding_store = get_embedding_store( - config_args=vector_store_args, # type: ignore - embedding_name=entity_description_embedding, - ) - entities_ = read_indexer_entities(entities, communities, community_level) - covariates_ = read_indexer_covariates(covariates) if covariates is not None else [] - prompt = load_search_prompt(config.root_dir, config.local_search.prompt) - search_engine = get_local_search_engine( + full_response = "" + context_data = {} + get_context_data = True + # NOTE: when streaming, the first chunk of returned data is the complete context data. + # All subsequent chunks are the query response. 
+ async for chunk in local_search_streaming( config=config, - reports=read_indexer_reports(community_reports, communities, community_level), - text_units=read_indexer_text_units(text_units), - entities=entities_, - relationships=read_indexer_relationships(relationships), - covariates={"claims": covariates_}, - description_embedding_store=description_embedding_store, # type: ignore + entities=entities, + communities=communities, + community_reports=community_reports, + text_units=text_units, + relationships=relationships, + covariates=covariates, + community_level=community_level, response_type=response_type, - system_prompt=prompt, - ) - - result: SearchResult = await search_engine.asearch(query=query) - response = result.response - context_data = reformat_context_data(result.context_data) # type: ignore - return response, context_data + query=query, + ): + if get_context_data: + context_data = chunk + get_context_data = False + else: + full_response += chunk + return full_response, context_data @validate_call(config={"arbitrary_types_allowed": True}) @@ -475,11 +454,11 @@ async def local_search_streaming( response_type=response_type, system_prompt=prompt, ) - search_result = search_engine.astream_search(query=query) + search_result = search_engine.stream_search(query=query) - # when streaming results, a context data object is returned as the first result + # NOTE: when streaming results, a context data object is returned as the first result # and the query response in subsequent tokens - context_data = None + context_data = {} get_context_data = True async for stream_chunk in search_result: if get_context_data: @@ -751,47 +730,28 @@ async def drift_search( ------ TODO: Document any exceptions to expect. """ - vector_store_args = {} - for index, store in config.vector_store.items(): - vector_store_args[index] = store.model_dump() - logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa - - description_embedding_store = get_embedding_store( - config_args=vector_store_args, # type: ignore - embedding_name=entity_description_embedding, - ) - - full_content_embedding_store = get_embedding_store( - config_args=vector_store_args, # type: ignore - embedding_name=community_full_content_embedding, - ) - - entities_ = read_indexer_entities(entities, communities, community_level) - reports = read_indexer_reports(community_reports, communities, community_level) - read_indexer_report_embeddings(reports, full_content_embedding_store) - prompt = load_search_prompt(config.root_dir, config.drift_search.prompt) - reduce_prompt = load_search_prompt( - config.root_dir, config.drift_search.reduce_prompt - ) - search_engine = get_drift_search_engine( + full_response = "" + context_data = {} + get_context_data = True + # NOTE: when streaming, the first chunk of returned data is the complete context data. + # All subsequent chunks are the query response. 
+ async for chunk in drift_search_streaming( config=config, - reports=reports, - text_units=read_indexer_text_units(text_units), - entities=entities_, - relationships=read_indexer_relationships(relationships), - description_embedding_store=description_embedding_store, # type: ignore - local_system_prompt=prompt, - reduce_system_prompt=reduce_prompt, + entities=entities, + communities=communities, + community_reports=community_reports, + text_units=text_units, + relationships=relationships, + community_level=community_level, response_type=response_type, - ) - - result: SearchResult = await search_engine.asearch(query=query) - response = result.response - context_data = {} - for key in result.context_data: - context_data[key] = reformat_context_data(result.context_data[key]) # type: ignore - - return response, context_data + query=query, + ): + if get_context_data: + context_data = chunk + get_context_data = False + else: + full_response += chunk + return full_response, context_data @validate_call(config={"arbitrary_types_allowed": True}) @@ -860,12 +820,11 @@ async def drift_search_streaming( reduce_system_prompt=reduce_prompt, response_type=response_type, ) + search_result = search_engine.stream_search(query=query) - search_result = search_engine.astream_search(query=query) - - # when streaming results, a context data object is returned as the first result + # NOTE: when streaming results, a context data object is returned as the first result # and the query response in subsequent tokens - context_data = None + context_data = {} get_context_data = True async for stream_chunk in search_result: if get_context_data: @@ -1105,29 +1064,22 @@ async def basic_search( ------ TODO: Document any exceptions to expect. """ - vector_store_args = {} - for index, store in config.vector_store.items(): - vector_store_args[index] = store.model_dump() - logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa - - description_embedding_store = get_embedding_store( - config_args=vector_store_args, # type: ignore - embedding_name=text_unit_text_embedding, - ) - - prompt = load_search_prompt(config.root_dir, config.basic_search.prompt) - - search_engine = get_basic_search_engine( + full_response = "" + context_data = {} + get_context_data = True + # NOTE: when streaming, the first chunk of returned data is the complete context data. + # All subsequent chunks are the query response. 
+ async for chunk in basic_search_streaming( config=config, - text_units=read_indexer_text_units(text_units), - text_unit_embeddings=description_embedding_store, - system_prompt=prompt, - ) - - result: SearchResult = await search_engine.asearch(query=query) - response = result.response - context_data = reformat_context_data(result.context_data) # type: ignore - return response, context_data + text_units=text_units, + query=query, + ): + if get_context_data: + context_data = chunk + get_context_data = False + else: + full_response += chunk + return full_response, context_data @validate_call(config={"arbitrary_types_allowed": True}) @@ -1155,8 +1107,6 @@ async def basic_search_streaming( vector_store_args = {} for index, store in config.vector_store.items(): vector_store_args[index] = store.model_dump() - else: - vector_store_args = None logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa description_embedding_store = get_embedding_store( @@ -1172,12 +1122,11 @@ async def basic_search_streaming( text_unit_embeddings=description_embedding_store, system_prompt=prompt, ) + search_result = search_engine.stream_search(query=query) - search_result = search_engine.astream_search(query=query) - - # when streaming results, a context data object is returned as the first result + # NOTE: when streaming results, a context data object is returned as the first result # and the query response in subsequent tokens - context_data = None + context_data = {} get_context_data = True async for stream_chunk in search_result: if get_context_data: diff --git a/graphrag/query/structured_search/base.py b/graphrag/query/structured_search/base.py index 749e89a07d..0f741dc259 100644 --- a/graphrag/query/structured_search/base.py +++ b/graphrag/query/structured_search/base.py @@ -69,27 +69,22 @@ def __init__( self.context_builder_params = context_builder_params or {} @abstractmethod - def search( - self, - query: str, - conversation_history: ConversationHistory | None = None, - **kwargs, - ) -> SearchResult: - """Search for the given query.""" - - @abstractmethod - async def asearch( + async def search( self, query: str, conversation_history: ConversationHistory | None = None, **kwargs, ) -> SearchResult: """Search for the given query asynchronously.""" + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) @abstractmethod - def astream_search( + def stream_search( self, query: str, conversation_history: ConversationHistory | None = None, - ) -> AsyncGenerator[str, None] | None: + ) -> AsyncGenerator[Any, None]: """Stream search for the given query.""" + msg = "Subclasses must implement this method" + raise NotImplementedError(msg) diff --git a/graphrag/query/structured_search/basic_search/search.py b/graphrag/query/structured_search/basic_search/search.py index b97213c9f9..34372419ea 100644 --- a/graphrag/query/structured_search/basic_search/search.py +++ b/graphrag/query/structured_search/basic_search/search.py @@ -55,7 +55,7 @@ def __init__( self.callbacks = callbacks self.response_type = response_type - async def asearch( + async def search( self, query: str, conversation_history: ConversationHistory | None = None, @@ -121,77 +121,11 @@ async def asearch( output_tokens=0, ) - def search( + async def stream_search( self, query: str, conversation_history: ConversationHistory | None = None, - **kwargs, - ) -> SearchResult: - """Build basic search context that fits a single context window and generate answer for the user question.""" - start_time = time.time() - 
search_prompt = "" - llm_calls, prompt_tokens, output_tokens = {}, {}, {} - context_result = self.context_builder.build_context( - query=query, - conversation_history=conversation_history, - **kwargs, - **self.context_builder_params, - ) - llm_calls["build_context"] = context_result.llm_calls - prompt_tokens["build_context"] = context_result.prompt_tokens - output_tokens["build_context"] = context_result.output_tokens - - log.info("GENERATE ANSWER: %d. QUERY: %s", start_time, query) - try: - search_prompt = self.system_prompt.format( - context_data=context_result.context_chunks, - response_type=self.response_type, - ) - search_messages = [ - {"role": "system", "content": search_prompt}, - {"role": "user", "content": query}, - ] - - response = self.llm.generate( - messages=search_messages, - streaming=True, - callbacks=self.callbacks, - **self.llm_params, - ) - llm_calls["response"] = 1 - prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder) - output_tokens["response"] = num_tokens(response, self.token_encoder) - - return SearchResult( - response=response, - context_data=context_result.context_records, - context_text=context_result.context_chunks, - completion_time=time.time() - start_time, - llm_calls=sum(llm_calls.values()), - prompt_tokens=sum(prompt_tokens.values()), - output_tokens=sum(output_tokens.values()), - llm_calls_categories=llm_calls, - prompt_tokens_categories=prompt_tokens, - output_tokens_categories=output_tokens, - ) - - except Exception: - log.exception("Exception in _map_response_single_batch") - return SearchResult( - response="", - context_data=context_result.context_records, - context_text=context_result.context_chunks, - completion_time=time.time() - start_time, - llm_calls=1, - prompt_tokens=num_tokens(search_prompt, self.token_encoder), - output_tokens=0, - ) - - async def astream_search( - self, - query: str, - conversation_history: ConversationHistory | None = None, - ) -> AsyncGenerator: + ) -> AsyncGenerator[Any, None]: """Build basic search context that fits a single context window and generate answer for the user query.""" start_time = time.time() diff --git a/graphrag/query/structured_search/drift_search/action.py b/graphrag/query/structured_search/drift_search/action.py index 7f10a32109..7254eb33d0 100644 --- a/graphrag/query/structured_search/drift_search/action.py +++ b/graphrag/query/structured_search/drift_search/action.py @@ -50,7 +50,7 @@ def is_complete(self) -> bool: """Check if the action is complete (i.e., an answer is available).""" return self.answer is not None - async def asearch(self, search_engine: Any, global_query: str, scorer: Any = None): + async def search(self, search_engine: Any, global_query: str, scorer: Any = None): """ Execute an asynchronous search using the search engine, and update the action with the results. @@ -70,7 +70,7 @@ async def asearch(self, search_engine: Any, global_query: str, scorer: Any = Non log.warning("Action already complete. 
Skipping search.") return self - search_result = await search_engine.asearch( + search_result = await search_engine.search( drift_query=global_query, query=self.query ) diff --git a/graphrag/query/structured_search/drift_search/primer.py b/graphrag/query/structured_search/drift_search/primer.py index 8f3895a25c..53d472996b 100644 --- a/graphrag/query/structured_search/drift_search/primer.py +++ b/graphrag/query/structured_search/drift_search/primer.py @@ -154,7 +154,7 @@ async def decompose_query( return parsed_response, token_ct - async def asearch( + async def search( self, query: str, top_k_reports: pd.DataFrame, diff --git a/graphrag/query/structured_search/drift_search/search.py b/graphrag/query/structured_search/drift_search/search.py index 982e195cdb..773a7f1ea9 100644 --- a/graphrag/query/structured_search/drift_search/search.py +++ b/graphrag/query/structured_search/drift_search/search.py @@ -144,7 +144,7 @@ def _process_primer_results( error_msg = "Response must be a list of dictionaries." raise ValueError(error_msg) - async def asearch_step( + async def _search_step( self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction] ) -> list[DriftAction]: """ @@ -160,12 +160,12 @@ async def asearch_step( list[DriftAction]: The results from executing the search actions asynchronously. """ tasks = [ - action.asearch(search_engine=search_engine, global_query=global_query) + action.search(search_engine=search_engine, global_query=global_query) for action in actions ] return await tqdm_asyncio.gather(*tasks, leave=False) - async def asearch( + async def search( self, query: str, conversation_history: Any = None, @@ -204,7 +204,7 @@ async def asearch( prompt_tokens["build_context"] = token_ct["prompt_tokens"] output_tokens["build_context"] = token_ct["prompt_tokens"] - primer_response = await self.primer.asearch( + primer_response = await self.primer.search( query=query, top_k_reports=primer_context ) llm_calls["primer"] = primer_response.llm_calls @@ -229,7 +229,7 @@ async def asearch( len(actions) - self.context_builder.config.drift_k_followups ) # Process actions - results = await self.asearch_step( + results = await self._search_step( global_query=query, search_engine=self.local_search, actions=actions ) @@ -278,37 +278,17 @@ async def asearch( output_tokens_categories=output_tokens, ) - def search( - self, - query: str, - conversation_history: Any = None, - **kwargs, - ) -> SearchResult: - """ - Perform a synchronous DRIFT search (Not Implemented). - - Args: - query (str): The query to search for. - conversation_history (Any, optional): The conversation history. - - Raises - ------ - NotImplementedError: Synchronous DRIFT is not implemented. - """ - error_msg = "Synchronous DRIFT is not implemented." - raise NotImplementedError(error_msg) - - async def astream_search( + async def stream_search( self, query: str, conversation_history: ConversationHistory | None = None ) -> AsyncGenerator[str, None]: """ - Perform a streaming DRIFT search (Not Implemented). + Perform a streaming DRIFT search asynchronously. Args: query (str): The query to search for. conversation_history (ConversationHistory, optional): The conversation history. 
""" - result = await self.asearch( + result = await self.search( query=query, conversation_history=conversation_history, reduce=False ) diff --git a/graphrag/query/structured_search/global_search/search.py b/graphrag/query/structured_search/global_search/search.py index ec0a5b71ec..2b076a77d0 100644 --- a/graphrag/query/structured_search/global_search/search.py +++ b/graphrag/query/structured_search/global_search/search.py @@ -102,7 +102,7 @@ def __init__( self.semaphore = asyncio.Semaphore(concurrent_coroutines) - async def astream_search( + async def stream_search( self, query: str, conversation_history: ConversationHistory | None = None, @@ -135,7 +135,7 @@ async def astream_search( ): yield response - async def asearch( + async def search( self, query: str, conversation_history: ConversationHistory | None = None, @@ -204,15 +204,6 @@ async def asearch( output_tokens_categories=output_tokens, ) - def search( - self, - query: str, - conversation_history: ConversationHistory | None = None, - **kwargs: Any, - ) -> GlobalSearchResult: - """Perform a global search synchronously.""" - return asyncio.run(self.asearch(query, conversation_history)) - async def _map_response_single_batch( self, context_data: str, @@ -235,7 +226,7 @@ async def _map_response_single_batch( log.info("Map response: %s", search_response) try: # parse search response json - processed_response = self.parse_search_response(search_response) + processed_response = self._parse_search_response(search_response) except ValueError: log.warning( "Warning: Error parsing search response json - skipping this batch" @@ -264,7 +255,7 @@ async def _map_response_single_batch( output_tokens=0, ) - def parse_search_response(self, search_response: str) -> list[dict[str, Any]]: + def _parse_search_response(self, search_response: str) -> list[dict[str, Any]]: """Parse the search response json and return a list of key points. Parameters diff --git a/graphrag/query/structured_search/local_search/search.py b/graphrag/query/structured_search/local_search/search.py index 1e29bde759..a26f6c691c 100644 --- a/graphrag/query/structured_search/local_search/search.py +++ b/graphrag/query/structured_search/local_search/search.py @@ -54,7 +54,7 @@ def __init__( self.callbacks = callbacks self.response_type = response_type - async def asearch( + async def search( self, query: str, conversation_history: ConversationHistory | None = None, @@ -128,7 +128,7 @@ async def asearch( output_tokens=0, ) - async def astream_search( + async def stream_search( self, query: str, conversation_history: ConversationHistory | None = None, @@ -158,69 +158,3 @@ async def astream_search( **self.llm_params, ): yield response - - def search( - self, - query: str, - conversation_history: ConversationHistory | None = None, - **kwargs, - ) -> SearchResult: - """Build local search context that fits a single context window and generate answer for the user question.""" - start_time = time.time() - search_prompt = "" - llm_calls, prompt_tokens, output_tokens = {}, {}, {} - context_result = self.context_builder.build_context( - query=query, - conversation_history=conversation_history, - **kwargs, - **self.context_builder_params, - ) - llm_calls["build_context"] = context_result.llm_calls - prompt_tokens["build_context"] = context_result.prompt_tokens - output_tokens["build_context"] = context_result.output_tokens - - log.info("GENERATE ANSWER: %d. 
QUERY: %s", start_time, query) - try: - search_prompt = self.system_prompt.format( - context_data=context_result.context_chunks, - response_type=self.response_type, - ) - search_messages = [ - {"role": "system", "content": search_prompt}, - {"role": "user", "content": query}, - ] - - response = self.llm.generate( - messages=search_messages, - streaming=True, - callbacks=self.callbacks, - **self.llm_params, - ) - llm_calls["response"] = 1 - prompt_tokens["response"] = num_tokens(search_prompt, self.token_encoder) - output_tokens["response"] = num_tokens(response, self.token_encoder) - - return SearchResult( - response=response, - context_data=context_result.context_records, - context_text=context_result.context_chunks, - completion_time=time.time() - start_time, - llm_calls=sum(llm_calls.values()), - prompt_tokens=sum(prompt_tokens.values()), - output_tokens=sum(output_tokens.values()), - llm_calls_categories=llm_calls, - prompt_tokens_categories=prompt_tokens, - output_tokens_categories=output_tokens, - ) - - except Exception: - log.exception("Exception in _map_response_single_batch") - return SearchResult( - response="", - context_data=context_result.context_records, - context_text=context_result.context_chunks, - completion_time=time.time() - start_time, - llm_calls=1, - prompt_tokens=num_tokens(search_prompt, self.token_encoder), - output_tokens=0, - )