Skip to content

Commit

Permalink
Require explicit azure auth settings when using AOI. (#1665)
Browse files Browse the repository at this point in the history
* Require explicit azure auth settings when using AOI.

- Must set LanguageModel.azure_auth_type to either
"api_key" or "managed_identity" when using AOI.

* Fix smoke tests

* Use general auth_type property instead of azure_auth_type

* Remove unused error type

* Update validation

* Update validation comment
  • Loading branch information
dworthen authored Jan 29, 2025
1 parent d31750f commit 94bd2bb
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 24 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250128230417263466.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Require explicit azure auth settings when using AOI."
}
2 changes: 2 additions & 0 deletions graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from graphrag.config.enums import (
AsyncType,
AuthType,
CacheType,
ChunkStrategyType,
InputFileType,
Expand All @@ -24,6 +25,7 @@
ASYNC_MODE = AsyncType.Threaded
ENCODING_MODEL = "cl100k_base"
AZURE_AUDIENCE = "https://cognitiveservices.azure.com/.default"
AUTH_TYPE = AuthType.APIKey
#
# LLM Parameters
#
Expand Down
6 changes: 3 additions & 3 deletions graphrag/config/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,11 @@ def __repr__(self):
return f'"{self.value}"'


class AzureAuthType(str, Enum):
"""AzureAuthType enum class definition."""
class AuthType(str, Enum):
"""AuthType enum class definition."""

APIKey = "api_key"
ManagedIdentity = "managed_identity"
AzureManagedIdentity = "azure_managed_identity"


class AsyncType(str, Enum):
Expand Down
2 changes: 2 additions & 0 deletions graphrag/config/init_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
{defs.DEFAULT_CHAT_MODEL_ID}:
api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file
type: {defs.LLM_TYPE.value} # or azure_openai_chat
auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
model: {defs.LLM_MODEL}
model_supports_json: true # recommended if this is available for your model.
parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
Expand All @@ -29,6 +30,7 @@
{defs.DEFAULT_EMBEDDING_MODEL_ID}:
api_key: ${{GRAPHRAG_API_KEY}}
type: {defs.EMBEDDING_TYPE.value} # or azure_openai_embedding
auth_type: {defs.AUTH_TYPE.value} # or azure_managed_identity
model: {defs.EMBEDDING_MODEL}
parallelization_num_threads: {defs.PARALLELIZATION_NUM_THREADS}
parallelization_stagger: {defs.PARALLELIZATION_STAGGER}
Expand Down
38 changes: 27 additions & 11 deletions graphrag/config/models/language_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import BaseModel, Field, model_validator

import graphrag.config.defaults as defs
from graphrag.config.enums import AsyncType, AzureAuthType, LLMType
from graphrag.config.enums import AsyncType, AuthType, LLMType
from graphrag.config.errors import (
ApiKeyMissingError,
AzureApiBaseMissingError,
Expand Down Expand Up @@ -40,27 +40,42 @@ def _validate_api_key(self) -> None:
ApiKeyMissingError
If the API key is missing and is required.
"""
if (
self.type == LLMType.OpenAIEmbedding
or self.type == LLMType.OpenAIChat
or self.azure_auth_type == AzureAuthType.APIKey
) and (self.api_key is None or self.api_key.strip() == ""):
if self.auth_type == AuthType.APIKey and (
self.api_key is None or self.api_key.strip() == ""
):
raise ApiKeyMissingError(
self.type.value,
self.azure_auth_type.value if self.azure_auth_type else None,
self.auth_type.value,
)

if (self.azure_auth_type == AzureAuthType.ManagedIdentity) and (
if (self.auth_type == AuthType.AzureManagedIdentity) and (
self.api_key is not None and self.api_key.strip() != ""
):
msg = "API Key should not be provided when using Azure Managed Identity. Please rerun `graphrag init` and remove the api_key when using Azure Managed Identity."
raise ConflictingSettingsError(msg)

azure_auth_type: AzureAuthType | None = Field(
description="The Azure authentication type to use when using AOI.",
default=None,
auth_type: AuthType = Field(
description="The authentication type.",
default=defs.AUTH_TYPE,
)

def _validate_auth_type(self) -> None:
"""Validate the authentication type.
auth_type must be api_key when using OpenAI and
can be either api_key or azure_managed_identity when using AOI.
Raises
------
ConflictingSettingsError
If the Azure authentication type conflicts with the model being used.
"""
if self.auth_type == AuthType.AzureManagedIdentity and (
self.type == LLMType.OpenAIChat or self.type == LLMType.OpenAIEmbedding
):
msg = f"auth_type of azure_managed_identity is not supported for model type {self.type.value}. Please rerun `graphrag init` and set the auth_type to api_key."
raise ConflictingSettingsError(msg)

type: LLMType = Field(description="The type of LLM model to use.")
model: str = Field(description="The LLM model to use.")
encoding_model: str = Field(description="The encoding model to use", default="")
Expand Down Expand Up @@ -233,6 +248,7 @@ def _validate_azure_settings(self) -> None:

@model_validator(mode="after")
def _validate_model(self):
self._validate_auth_type()
self._validate_api_key()
self._validate_azure_settings()
self._validate_encoding_model()
Expand Down
8 changes: 5 additions & 3 deletions graphrag/query/llm/get_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from azure.identity import DefaultAzureCredential, get_bearer_token_provider

from graphrag.config.enums import LLMType
from graphrag.config.enums import AuthType, LLMType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
Expand All @@ -31,7 +31,8 @@ def get_llm(config: GraphRagConfig) -> ChatOpenAI:
api_key=default_llm_settings.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not default_llm_settings.api_key
if is_azure_client
and default_llm_settings.auth_type == AuthType.AzureManagedIdentity
else None
),
api_base=default_llm_settings.api_base,
Expand Down Expand Up @@ -65,7 +66,8 @@ def get_text_embedder(config: GraphRagConfig) -> OpenAIEmbedding:
api_key=embeddings_llm_settings.api_key,
azure_ad_token_provider=(
get_bearer_token_provider(DefaultAzureCredential(), audience)
if is_azure_client and not embeddings_llm_settings.api_key
if is_azure_client
and embeddings_llm_settings.auth_type == AuthType.AzureManagedIdentity
else None
),
api_base=embeddings_llm_settings.api_base,
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/min-csv/settings.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
models:
default_chat_model:
azure_auth_type: api_key
type: ${GRAPHRAG_LLM_TYPE}
api_key: ${GRAPHRAG_API_KEY}
api_base: ${GRAPHRAG_API_BASE}
Expand All @@ -13,6 +14,7 @@ models:
parallelization_stagger: 0.3
async_mode: threaded
default_embedding_model:
azure_auth_type: api_key
type: ${GRAPHRAG_EMBEDDING_TYPE}
api_key: ${GRAPHRAG_API_KEY}
api_base: ${GRAPHRAG_API_BASE}
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/text/settings.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
models:
default_chat_model:
azure_auth_type: api_key
type: ${GRAPHRAG_LLM_TYPE}
api_key: ${GRAPHRAG_API_KEY}
api_base: ${GRAPHRAG_API_BASE}
Expand All @@ -13,6 +14,7 @@ models:
parallelization_stagger: 0.3
async_mode: threaded
default_embedding_model:
azure_auth_type: api_key
type: ${GRAPHRAG_EMBEDDING_TYPE}
api_key: ${GRAPHRAG_API_KEY}
api_base: ${GRAPHRAG_API_BASE}
Expand Down
26 changes: 20 additions & 6 deletions tests/unit/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import graphrag.config.defaults as defs
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import AzureAuthType, LLMType
from graphrag.config.enums import AuthType, LLMType
from graphrag.config.load_config import load_config
from tests.unit.config.utils import (
DEFAULT_EMBEDDING_MODEL_CONFIG,
Expand Down Expand Up @@ -46,7 +46,7 @@ def test_missing_azure_api_key() -> None:
model_config_missing_api_key = {
defs.DEFAULT_CHAT_MODEL_ID: {
"type": LLMType.AzureOpenAIChat,
"azure_auth_type": AzureAuthType.APIKey,
"auth_type": AuthType.APIKey,
"model": defs.LLM_MODEL,
"api_base": "some_api_base",
"api_version": "some_api_version",
Expand All @@ -59,17 +59,31 @@ def test_missing_azure_api_key() -> None:
create_graphrag_config({"models": model_config_missing_api_key})

# API Key not required for managed identity
model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["azure_auth_type"] = (
AzureAuthType.ManagedIdentity
model_config_missing_api_key[defs.DEFAULT_CHAT_MODEL_ID]["auth_type"] = (
AuthType.AzureManagedIdentity
)
create_graphrag_config({"models": model_config_missing_api_key})


def test_conflicting_auth_type() -> None:
model_config_invalid_auth_type = {
defs.DEFAULT_CHAT_MODEL_ID: {
"auth_type": AuthType.AzureManagedIdentity,
"type": LLMType.OpenAIChat,
"model": defs.LLM_MODEL,
},
defs.DEFAULT_EMBEDDING_MODEL_ID: DEFAULT_EMBEDDING_MODEL_CONFIG,
}

with pytest.raises(ValidationError):
create_graphrag_config({"models": model_config_invalid_auth_type})


def test_conflicting_azure_api_key() -> None:
model_config_conflicting_api_key = {
defs.DEFAULT_CHAT_MODEL_ID: {
"type": LLMType.AzureOpenAIChat,
"azure_auth_type": AzureAuthType.ManagedIdentity,
"auth_type": AuthType.AzureManagedIdentity,
"model": defs.LLM_MODEL,
"api_base": "some_api_base",
"api_version": "some_api_version",
Expand All @@ -85,7 +99,7 @@ def test_conflicting_azure_api_key() -> None:

base_azure_model_config = {
"type": LLMType.AzureOpenAIChat,
"azure_auth_type": AzureAuthType.ManagedIdentity,
"auth_type": AuthType.AzureManagedIdentity,
"model": defs.LLM_MODEL,
"api_base": "some_api_base",
"api_version": "some_api_version",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def assert_language_model_configs(
actual: LanguageModelConfig, expected: LanguageModelConfig
) -> None:
assert actual.api_key == expected.api_key
assert actual.azure_auth_type == expected.azure_auth_type
assert actual.auth_type == expected.auth_type
assert actual.type == expected.type
assert actual.model == expected.model
assert actual.encoding_model == expected.encoding_model
Expand Down

0 comments on commit 94bd2bb

Please sign in to comment.