Skip to content

Commit

Permalink
Add vector store id reference to embeddings config. (#1662)
Browse files Browse the repository at this point in the history
  • Loading branch information
dworthen authored Jan 28, 2025
1 parent 1bbce33 commit eeee84e
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 18 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250127224919088925.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add vector store id reference to embeddings config."
}
2 changes: 1 addition & 1 deletion graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@
VECTOR_STORE_DB_URI = str(Path(OUTPUT_BASE_DIR) / "lancedb")
VECTOR_STORE_CONTAINER_NAME = "default"
VECTOR_STORE_OVERWRITE = True
VECTOR_STORE_INDEX_NAME = "output"
VECTOR_STORE_DEFAULT_ID = "default_vector_store"

# Local Search
LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5
Expand Down
14 changes: 3 additions & 11 deletions graphrag/config/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,10 @@ def get_embedding_settings(
embeddings_llm_settings = settings.get_language_model_config(
settings.embeddings.model_id
)
num_entries = len(settings.vector_store)
if num_entries == 1:
store = next(iter(settings.vector_store.values()))
vector_store_settings = store.model_dump()
else:
# The vector_store dict should only have more than one entry for multi-index query
vector_store_settings = None
vector_store_settings = settings.get_vector_store_config(
settings.embeddings.vector_store_id
).model_dump()

if vector_store_settings is None:
return {
"strategy": settings.embeddings.resolved_strategy(embeddings_llm_settings)
}
#
# If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
# settings.vector_store.base contains connection information, or may be undefined
Expand Down
3 changes: 2 additions & 1 deletion graphrag/config/init_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@
# deployment_name: <azure_model_deployment_name>
vector_store:
{defs.VECTOR_STORE_INDEX_NAME}:
{defs.VECTOR_STORE_DEFAULT_ID}:
type: {defs.VECTOR_STORE_TYPE}
db_uri: {defs.VECTOR_STORE_DB_URI}
container_name: {defs.VECTOR_STORE_CONTAINER_NAME}
overwrite: {defs.VECTOR_STORE_OVERWRITE}
embeddings:
model_id: {defs.DEFAULT_EMBEDDING_MODEL_ID}
vector_store_id: {defs.VECTOR_STORE_DEFAULT_ID}
### Input settings ###
Expand Down
26 changes: 25 additions & 1 deletion graphrag/config/models/graph_rag_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def _validate_update_index_output_base_dir(self) -> None:

vector_store: dict[str, VectorStoreConfig] = Field(
description="The vector store configuration.",
default={"output": VectorStoreConfig()},
default={defs.VECTOR_STORE_DEFAULT_ID: VectorStoreConfig()},
)
"""The vector store configuration."""

Expand Down Expand Up @@ -263,6 +263,30 @@ def get_language_model_config(self, model_id: str) -> LanguageModelConfig:

return self.models[model_id]

def get_vector_store_config(self, vector_store_id: str) -> VectorStoreConfig:
    """Look up a vector store configuration by its ID.

    Parameters
    ----------
    vector_store_id : str
        Key identifying the desired entry in the ``vector_store`` mapping.

    Returns
    -------
    VectorStoreConfig
        The configuration stored under ``vector_store_id``.

    Raises
    ------
    ValueError
        If ``vector_store_id`` is not a key of the ``vector_store`` mapping.
    """
    stores = self.vector_store
    if vector_store_id in stores:
        return stores[vector_store_id]

    # Unknown ID: surface a configuration error instead of a bare KeyError.
    message = f"Vector Store ID {vector_store_id} not found in configuration. Please rerun `graphrag init` and set the vector store configuration."
    raise ValueError(message)

@model_validator(mode="after")
def _validate_model(self):
"""Validate the model configuration."""
Expand Down
4 changes: 4 additions & 0 deletions graphrag/config/models/text_embedding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class TextEmbeddingConfig(BaseModel):
description="The model ID to use for text embeddings.",
default=defs.EMBEDDING_MODEL_ID,
)
vector_store_id: str = Field(
description="The vector store ID to use for text embeddings.",
default=defs.VECTOR_STORE_DEFAULT_ID,
)

def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
"""Get the resolved text embedding strategy."""
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/azure/settings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ claim_extraction:

embeddings:
vector_store:
output:
default_vector_store:
type: "azure_ai_search"
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/settings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ models:
async_mode: threaded

vector_store:
output:
default_vector_store:
type: "lancedb"
db_uri: "./tests/fixtures/min-csv/lancedb"
container_name: "lancedb_ci"
Expand Down
2 changes: 1 addition & 1 deletion tests/fixtures/text/settings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ models:
async_mode: threaded

vector_store:
output:
default_vector_store:
type: "azure_ai_search"
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
api_key: ${AZURE_AI_SEARCH_API_KEY}
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
DEFAULT_GRAPHRAG_CONFIG_SETTINGS = {
"models": DEFAULT_MODEL_CONFIG,
"vector_store": {
"output": {
defs.VECTOR_STORE_DEFAULT_ID: {
"type": defs.VECTOR_STORE_TYPE,
"db_uri": defs.VECTOR_STORE_DB_URI,
"container_name": defs.VECTOR_STORE_CONTAINER_NAME,
Expand Down

0 comments on commit eeee84e

Please sign in to comment.