Skip empty embeddings in custom DC #4836

Merged · 1 commit · Jan 11, 2025

5 changes: 5 additions & 0 deletions nl_server/embeddings.py
@@ -95,3 +95,8 @@ def vector_search(self, queries: List[str], top_k: int) -> SearchVarsResult:

    # Turn this into a map:
    return {k: v for k, v in zip(queries, results)}
+
+
+class NoEmbeddingsException(Exception):
+  """Custom exception raised when no embeddings are found in the embeddings csv."""
+  pass

11 changes: 9 additions & 2 deletions nl_server/flask.py
@@ -58,8 +58,15 @@ def create_app():
  if not lib_utils.is_test_env():
    # Below is a safe check to ensure that the model and embedding is loaded.
    server_config = reg.server_config()
-    idx_type = server_config.default_indexes[0]
-    embeddings = reg.get_index(idx_type)
+
+    def _get_first_available_embeddings():
+      for idx in server_config.default_indexes:
+        embeddings = reg.get_index(idx)
+        if embeddings:
+          return (idx, embeddings)
+      raise ValueError('No embeddings found')
+
+    idx_type, embeddings = _get_first_available_embeddings()
    query = server_config.indexes[idx_type].healthcheck_query
    result = search.search_vars([embeddings], [query]).get(query)
    if not result or not result.svs:
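
For reference, a minimal standalone sketch of the fallback behavior the new _get_first_available_embeddings helper gives the healthcheck: walk the configured default indexes in order, use the first one whose embeddings actually loaded, and raise only if none did. The index names and the loaded-store dict below are hypothetical stand-ins, not the actual registry API.

# Sketch of the healthcheck fallback: pick the first default index whose store loaded.
# Index names and the `loaded` dict are hypothetical, not the real registry API.
from typing import Dict, List, Optional, Tuple


def first_available_index(default_indexes: List[str],
                          loaded: Dict[str, Optional[object]]) -> Tuple[str, object]:
  """Returns (index_name, store) for the first default index with a loaded store."""
  for idx in default_indexes:
    store = loaded.get(idx)
    if store:
      return idx, store
  raise ValueError('No embeddings found')


# Example: the first index was skipped (empty embeddings CSV), so the healthcheck
# query runs against the second index instead of failing startup.
print(first_available_index(['custom_ft', 'base_ft'],
                            {'custom_ft': None, 'base_ft': object()}))

This keeps the startup healthcheck working even when the first default index was skipped for having an empty embeddings CSV.
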
10 changes: 10 additions & 0 deletions nl_server/registry.py
@@ -23,6 +23,7 @@
from nl_server.config import StoreType
from nl_server.embeddings import Embeddings
from nl_server.embeddings import EmbeddingsModel
+from nl_server.embeddings import NoEmbeddingsException
from nl_server.model.attribute_model import AttributeModel
from nl_server.model.create import create_embeddings_model
from nl_server.ranking import RerankingModel

@@ -108,6 +109,15 @@ def _set_embeddings(self, idx_name: str, idx_info: IndexConfig):
        return
      elif idx_info.store_type == StoreType.VERTEXAI:
        store = VertexAIStore(idx_info)
+    except NoEmbeddingsException as e:
+      if not is_custom_dc():
+        raise e
+      # Some custom DCs may not have SVs or topics in which case no embeddings is a valid condition.
+      # We log a warning and skip it in that case.
+      logging.warning(
+          f'No embeddings found for the following index and will be skipped: {idx_info}'
+      )
+      store = None
    except Exception as e:
      logging.error(f'error loading index {idx_name}: {str(e)} ')
      raise e
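
Read together with the store change below, the flow this hunk sets up is: the memory store raises NoEmbeddingsException when its embeddings CSV has no data rows, and the registry treats that as a skip only on a custom DC, re-raising everywhere else. A condensed sketch of that flow, using simplified stand-ins (build_store, load_index) rather than the real classes:

# Condensed sketch of the skip-on-custom-DC flow; build_store and load_index are
# simplified stand-ins for MemoryEmbeddingsStore and Registry._set_embeddings.
import logging


class NoEmbeddingsException(Exception):
  """Raised when an embeddings CSV has no data rows."""


def build_store(has_rows: bool):
  """Stand-in for the store constructor: reject an empty embeddings CSV."""
  if not has_rows:
    raise NoEmbeddingsException()
  return object()


def load_index(idx_name: str, has_rows: bool, is_custom_dc: bool):
  """Stand-in for the registry: skip empty indexes only on a custom DC."""
  try:
    return build_store(has_rows)
  except NoEmbeddingsException:
    if not is_custom_dc:
      raise
    logging.warning('No embeddings found for %s; skipping.', idx_name)
    return None


print(load_index('custom_ft', has_rows=False, is_custom_dc=True))  # None (skipped)
# With is_custom_dc=False the exception propagates and startup fails as before.
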
30 changes: 30 additions & 0 deletions nl_server/store/memory.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""In-memory Embeddings store."""

+import csv
import logging
from typing import List

@@ -25,6 +26,7 @@
from nl_server.embeddings import EmbeddingsMatch
from nl_server.embeddings import EmbeddingsResult
from nl_server.embeddings import EmbeddingsStore
+from nl_server.embeddings import NoEmbeddingsException
from shared.lib.custom_dc_util import use_anonymous_gcs_client
from shared.lib.gcs import is_gcs_path
from shared.lib.gcs import maybe_download

@@ -53,6 +55,10 @@ def __init__(self, idx_info: MemoryIndexConfig) -> None:
          f'"embeddings_path" path must start with `/` or `gs://`: {idx_info.embeddings_path}'
      )

+    # Raise no embeddings exception if the embeddings path does not have any embeddings.
+    if _is_csv_empty_or_header_only(embeddings_path):
+      raise NoEmbeddingsException()
+
    self.dataset_embeddings: torch.Tensor = None
    self.dcids: List[str] = []
    self.sentences: List[str] = []

@@ -99,3 +105,27 @@ def vector_search(self, query_embeddings: torch.Tensor,
      results.append(matches)

    return results
+
+
+def _is_csv_empty_or_header_only(file_path):
+  """
+  Checks if a CSV file is empty or only contains the header row.
+
+  Args:
+    file_path: The path to the CSV file.
+
+  Returns:
+    True if the CSV file is empty or has only the header, False otherwise.
+  """
+  with open(file_path, 'r', newline='') as csvfile:
+    reader = csv.reader(csvfile)
+    try:
+      # Read the first row (header)
+      next(reader)
+      # Try reading the second row
+      next(reader)
+      # If no exception is raised, there are more rows than just the header
+      return False
+    except StopIteration:
+      # StopIteration is raised if there are no more rows to read
+      return True
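
A quick usage check of _is_csv_empty_or_header_only from the hunk above, run against two small temporary CSVs; the column names here are illustrative, not the actual embeddings schema.

# Assumes _is_csv_empty_or_header_only (above) is in scope.
# Column names are illustrative only.
import csv
import tempfile


def _write_csv(rows):
  f = tempfile.NamedTemporaryFile('w', suffix='.csv', delete=False, newline='')
  csv.writer(f).writerows(rows)
  f.close()
  return f.name


header_only = _write_csv([['dcid', 'sentence']])
with_data = _write_csv([['dcid', 'sentence'],
                        ['Count_Person', 'number of people']])

print(_is_csv_empty_or_header_only(header_only))  # True  -> NoEmbeddingsException raised
print(_is_csv_empty_or_header_only(with_data))    # False -> store loads as before

A header-only file returns True, which makes MemoryEmbeddingsStore raise NoEmbeddingsException before any embeddings are loaded.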