Skip to content

Commit

Permalink
Merge branch 'master' into lints
Browse files Browse the repository at this point in the history
  • Loading branch information
gmechali authored Jan 13, 2025
2 parents c3cb7b3 + c015ad2 commit ebe2fdc
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 2 deletions.
5 changes: 5 additions & 0 deletions nl_server/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,8 @@ def vector_search(self, queries: List[str], top_k: int) -> SearchVarsResult:

# Turn this into a map:
return {k: v for k, v in zip(queries, results)}


class NoEmbeddingsException(Exception):
  """Signals that the embeddings csv for an index contains no embeddings.

  Raised by an embeddings store when its source file has nothing to load, so
  that callers can tell "no data" apart from other loading failures.
  """
11 changes: 9 additions & 2 deletions nl_server/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,15 @@ def create_app():
if not lib_utils.is_test_env():
# Below is a sanity check to ensure that the model and embeddings are loaded.
server_config = reg.server_config()
idx_type = server_config.default_indexes[0]
embeddings = reg.get_index(idx_type)

def _get_first_available_embeddings():
  """Return `(index name, embeddings)` for the first usable default index.

  Walks `server_config.default_indexes` in order and returns the first one
  for which `reg.get_index` yields a truthy value.

  Raises:
    ValueError: if none of the default indexes yields embeddings.
  """
  for candidate in server_config.default_indexes:
    loaded = reg.get_index(candidate)
    if not loaded:
      # Nothing usable registered under this index name; try the next one.
      continue
    return (candidate, loaded)
  raise ValueError('No embeddings found')

idx_type, embeddings = _get_first_available_embeddings()
query = server_config.indexes[idx_type].healthcheck_query
result = search.search_vars([embeddings], [query]).get(query)
if not result or not result.svs:
Expand Down
10 changes: 10 additions & 0 deletions nl_server/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nl_server.config import StoreType
from nl_server.embeddings import Embeddings
from nl_server.embeddings import EmbeddingsModel
from nl_server.embeddings import NoEmbeddingsException
from nl_server.model.attribute_model import AttributeModel
from nl_server.model.create import create_embeddings_model
from nl_server.ranking import RerankingModel
Expand Down Expand Up @@ -108,6 +109,15 @@ def _set_embeddings(self, idx_name: str, idx_info: IndexConfig):
return
elif idx_info.store_type == StoreType.VERTEXAI:
store = VertexAIStore(idx_info)
except NoEmbeddingsException as e:
if not is_custom_dc():
raise e
# Some custom DCs may not have SVs or topics, in which case having no embeddings is a valid condition.
# We log a warning and skip the index in that case.
logging.warning(
f'No embeddings found for the following index and will be skipped: {idx_info}'
)
store = None
except Exception as e:
logging.error(f'error loading index {idx_name}: {str(e)} ')
raise e
Expand Down
30 changes: 30 additions & 0 deletions nl_server/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""In-memory Embeddings store."""

import csv
import logging
from typing import List

Expand All @@ -25,6 +26,7 @@
from nl_server.embeddings import EmbeddingsMatch
from nl_server.embeddings import EmbeddingsResult
from nl_server.embeddings import EmbeddingsStore
from nl_server.embeddings import NoEmbeddingsException
from shared.lib.custom_dc_util import use_anonymous_gcs_client
from shared.lib.gcs import is_gcs_path
from shared.lib.gcs import maybe_download
Expand Down Expand Up @@ -53,6 +55,10 @@ def __init__(self, idx_info: MemoryIndexConfig) -> None:
f'"embeddings_path" path must start with `/` or `gs://`: {idx_info.embeddings_path}'
)

# Raise no embeddings exception if the embeddings path does not have any embeddings.
if _is_csv_empty_or_header_only(embeddings_path):
raise NoEmbeddingsException()

self.dataset_embeddings: torch.Tensor = None
self.dcids: List[str] = []
self.sentences: List[str] = []
Expand Down Expand Up @@ -99,3 +105,27 @@ def vector_search(self, query_embeddings: torch.Tensor,
results.append(matches)

return results


def _is_csv_empty_or_header_only(file_path):
  """
  Checks if a CSV file is empty or only contains the header row.

  Blank lines (and rows whose cells are all empty or whitespace) are ignored,
  so a header followed only by trailing newlines still counts as header-only.

  Args:
    file_path: The path to the CSV file.
  Returns:
    True if the CSV file has at most one non-blank row (i.e. it is empty or
    has only the header), False otherwise.
  """
  # Explicit encoding so the check does not depend on the platform locale;
  # newline='' is what the csv module requires for correct line handling.
  with open(file_path, 'r', newline='', encoding='utf-8') as csvfile:
    non_blank_rows = 0
    for row in csv.reader(csvfile):
      # csv.reader yields [] for blank lines; those carry no data.
      if not any(cell.strip() for cell in row):
        continue
      non_blank_rows += 1
      if non_blank_rows > 1:
        # Found a data row after the header; no need to read further.
        return False
  return True

0 comments on commit ebe2fdc

Please sign in to comment.