class NoEmbeddingsException(Exception):
    """Signal that an embeddings source contains no embeddings.

    Raised when the embeddings CSV backing an index is empty or holds only
    a header row, so that callers can tell "nothing to load" apart from a
    genuine loading failure and decide whether to skip the index.
    """
def _is_csv_empty_or_header_only(file_path):
    """Check whether a CSV file has no data rows.

    The first non-blank row is treated as the header. Blank lines are
    ignored wherever they appear, so a header followed only by empty lines
    (or preceded by one) still counts as "header only".

    Args:
        file_path: The path to the CSV file.

    Returns:
        True if the CSV file is empty or has only the header, False otherwise.
    """
    # newline='' is what the csv module asks for so that it can handle line
    # endings (including ones embedded in quoted fields) itself.
    with open(file_path, 'r', newline='') as csvfile:
        # csv.reader yields an empty list for every blank line; dropping
        # those keeps stray or trailing empty lines from being counted as
        # the header or as data.
        rows = (row for row in csv.reader(csvfile) if row)
        # Consume the header, if there is one.
        next(rows, None)
        # Rows are read lazily, so at most the header and the first data
        # row are parsed before a decision is made.
        return next(rows, None) is None