diff --git a/.semversioner/next-release/patch-20250127224919088925.json b/.semversioner/next-release/patch-20250127224919088925.json new file mode 100644 index 0000000000..5e0d890434 --- /dev/null +++ b/.semversioner/next-release/patch-20250127224919088925.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Add vector store id reference to embeddings config." +} diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index cb961b8b51..f33985db62 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -106,7 +106,7 @@ VECTOR_STORE_DB_URI = str(Path(OUTPUT_BASE_DIR) / "lancedb") VECTOR_STORE_CONTAINER_NAME = "default" VECTOR_STORE_OVERWRITE = True -VECTOR_STORE_INDEX_NAME = "output" +VECTOR_STORE_DEFAULT_ID = "default_vector_store" # Local Search LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5 diff --git a/graphrag/config/embeddings.py b/graphrag/config/embeddings.py index a322290125..11fa82ef08 100644 --- a/graphrag/config/embeddings.py +++ b/graphrag/config/embeddings.py @@ -57,18 +57,10 @@ def get_embedding_settings( embeddings_llm_settings = settings.get_language_model_config( settings.embeddings.model_id ) - num_entries = len(settings.vector_store) - if num_entries == 1: - store = next(iter(settings.vector_store.values())) - vector_store_settings = store.model_dump() - else: - # The vector_store dict should only have more than one entry for multi-index query - vector_store_settings = None + vector_store_settings = settings.get_vector_store_config( + settings.embeddings.vector_store_id + ).model_dump() - if vector_store_settings is None: - return { - "strategy": settings.embeddings.resolved_strategy(embeddings_llm_settings) - } # # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding. # settings.vector_store.base contains connection information, or may be undefined diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py index f510ef7f2b..eccd05e4eb 100644 --- a/graphrag/config/init_content.py +++ b/graphrag/config/init_content.py @@ -40,7 +40,7 @@ # deployment_name: vector_store: - {defs.VECTOR_STORE_INDEX_NAME}: + {defs.VECTOR_STORE_DEFAULT_ID}: type: {defs.VECTOR_STORE_TYPE} db_uri: {defs.VECTOR_STORE_DB_URI} container_name: {defs.VECTOR_STORE_CONTAINER_NAME} @@ -48,6 +48,7 @@ embeddings: model_id: {defs.DEFAULT_EMBEDDING_MODEL_ID} + vector_store_id: {defs.VECTOR_STORE_DEFAULT_ID} ### Input settings ### diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py index 9c50714f7d..1e5fd84ce1 100644 --- a/graphrag/config/models/graph_rag_config.py +++ b/graphrag/config/models/graph_rag_config.py @@ -226,7 +226,7 @@ def _validate_update_index_output_base_dir(self) -> None: vector_store: dict[str, VectorStoreConfig] = Field( description="The vector store configuration.", - default={"output": VectorStoreConfig()}, + default={defs.VECTOR_STORE_DEFAULT_ID: VectorStoreConfig()}, ) """The vector store configuration.""" @@ -263,6 +263,30 @@ def get_language_model_config(self, model_id: str) -> LanguageModelConfig: return self.models[model_id] + def get_vector_store_config(self, vector_store_id: str) -> VectorStoreConfig: + """Get a vector store configuration by ID. + + Parameters + ---------- + vector_store_id : str + The ID of the vector store to get. Should match an ID in the vector_store list. + + Returns + ------- + VectorStoreConfig + The vector store configuration if found. + + Raises + ------ + ValueError + If the vector store ID is not found in the configuration. + """ + if vector_store_id not in self.vector_store: + err_msg = f"Vector Store ID {vector_store_id} not found in configuration. Please rerun `graphrag init` and set the vector store configuration." + raise ValueError(err_msg) + + return self.vector_store[vector_store_id] + @model_validator(mode="after") def _validate_model(self): """Validate the model configuration.""" diff --git a/graphrag/config/models/text_embedding_config.py b/graphrag/config/models/text_embedding_config.py index c26e13ee08..9a8763fd12 100644 --- a/graphrag/config/models/text_embedding_config.py +++ b/graphrag/config/models/text_embedding_config.py @@ -34,6 +34,10 @@ class TextEmbeddingConfig(BaseModel): description="The model ID to use for text embeddings.", default=defs.EMBEDDING_MODEL_ID, ) + vector_store_id: str = Field( + description="The vector store ID to use for text embeddings.", + default=defs.VECTOR_STORE_DEFAULT_ID, + ) def resolved_strategy(self, model_config: LanguageModelConfig) -> dict: """Get the resolved text embedding strategy.""" diff --git a/tests/fixtures/azure/settings.yml b/tests/fixtures/azure/settings.yml index 3f054b6717..6303c771c1 100644 --- a/tests/fixtures/azure/settings.yml +++ b/tests/fixtures/azure/settings.yml @@ -3,7 +3,7 @@ claim_extraction: embeddings: vector_store: - output: + default_vector_store: type: "azure_ai_search" url: ${AZURE_AI_SEARCH_URL_ENDPOINT} api_key: ${AZURE_AI_SEARCH_API_KEY} diff --git a/tests/fixtures/min-csv/settings.yml b/tests/fixtures/min-csv/settings.yml index 09642c9260..ebd9b5f31b 100644 --- a/tests/fixtures/min-csv/settings.yml +++ b/tests/fixtures/min-csv/settings.yml @@ -26,7 +26,7 @@ models: async_mode: threaded vector_store: - output: + default_vector_store: type: "lancedb" db_uri: "./tests/fixtures/min-csv/lancedb" container_name: "lancedb_ci" diff --git a/tests/fixtures/text/settings.yml b/tests/fixtures/text/settings.yml index 09b5f13d38..d05d384d97 100644 --- a/tests/fixtures/text/settings.yml +++ b/tests/fixtures/text/settings.yml @@ -26,7 +26,7 @@ models: async_mode: threaded vector_store: - output: + default_vector_store: type: "azure_ai_search" url: ${AZURE_AI_SEARCH_URL_ENDPOINT} api_key: ${AZURE_AI_SEARCH_API_KEY} diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index d231b5c277..6535f448e9 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -50,7 +50,7 @@ DEFAULT_GRAPHRAG_CONFIG_SETTINGS = { "models": DEFAULT_MODEL_CONFIG, "vector_store": { - "output": { + defs.VECTOR_STORE_DEFAULT_ID: { "type": defs.VECTOR_STORE_TYPE, "db_uri": defs.VECTOR_STORE_DB_URI, "container_name": defs.VECTOR_STORE_CONTAINER_NAME,