Skip to content

Commit

Permalink
core[minor]: Add factory for looking up secrets from the env (langcha…
Browse files Browse the repository at this point in the history
…in-ai#25198)

Add factory method for looking secrets from the env.
  • Loading branch information
eyurtsev authored Aug 8, 2024
1 parent da9281f commit 429a0ee
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 3 deletions.
2 changes: 2 additions & 0 deletions libs/core/langchain_core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
guard_import,
mock_now,
raise_for_status_with_text,
secret_from_env,
xor_args,
)

Expand Down Expand Up @@ -56,4 +57,5 @@
"batch_iterate",
"abatch_iterate",
"from_env",
"secret_from_env",
]
63 changes: 61 additions & 2 deletions libs/core/langchain_core/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,11 +313,11 @@ def from_env(
This will be raised as a ValueError.
"""

def get_from_env_fn() -> str: # type: ignore
def get_from_env_fn() -> Optional[str]:
"""Get a value from an environment variable."""
if key in os.environ:
return os.environ[key]
elif isinstance(default, str):
elif isinstance(default, (str, type(None))):
return default
else:
if error_message:
Expand All @@ -330,3 +330,62 @@ def get_from_env_fn() -> str: # type: ignore
)

return get_from_env_fn


@overload
def secret_from_env(key: str, /) -> Callable[[], SecretStr]: ...


@overload
def secret_from_env(key: str, /, *, default: str) -> Callable[[], SecretStr]: ...


@overload
def secret_from_env(
key: str, /, *, default: None
) -> Callable[[], Optional[SecretStr]]: ...


@overload
def secret_from_env(key: str, /, *, error_message: str) -> Callable[[], SecretStr]: ...


def secret_from_env(
key: str,
/,
*,
default: Union[str, _NoDefaultType, None] = _NoDefault,
error_message: Optional[str] = None,
) -> Union[Callable[[], Optional[SecretStr]], Callable[[], SecretStr]]:
"""Secret from env.
Args:
key: The environment variable to look up.
default: The default value to return if the environment variable is not set.
error_message: the error message which will be raised if the key is not found
and no default value is provided.
This will be raised as a ValueError.
Returns:
factory method that will look up the secret from the environment.
"""

def get_secret_from_env() -> Optional[SecretStr]:
"""Get a value from an environment variable."""
if key in os.environ:
return SecretStr(os.environ[key])
elif isinstance(default, str):
return SecretStr(default)
elif isinstance(default, type(None)):
return None
else:
if error_message:
raise ValueError(error_message)
else:
raise ValueError(
f"Did not find {key}, please add an environment variable"
f" `{key}` which contains it, or pass"
f" `{key}` as a named parameter."
)

return get_secret_from_env
1 change: 1 addition & 0 deletions libs/core/tests/unit_tests/utils/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"stringify_value",
"pre_init",
"from_env",
"secret_from_env",
]


Expand Down
108 changes: 107 additions & 1 deletion libs/core/tests/unit_tests/utils/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import re
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from unittest.mock import patch

import pytest

from langchain_core import utils
from langchain_core.pydantic_v1 import SecretStr
from langchain_core.utils import (
check_package_version,
from_env,
Expand All @@ -15,6 +16,7 @@
)
from langchain_core.utils._merge import merge_dicts
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
from langchain_core.utils.utils import secret_from_env


@pytest.mark.parametrize(
Expand Down Expand Up @@ -254,3 +256,107 @@ def test_from_env_with_default_error_message() -> None:
get_value = from_env(key)
with pytest.raises(ValueError, match=f"Did not find {key}"):
get_value()


def test_secret_from_env_with_env_variable(monkeypatch: pytest.MonkeyPatch) -> None:
# Set the environment variable
monkeypatch.setenv("TEST_KEY", "secret_value")

# Get the function
get_secret: Callable[[], Optional[SecretStr]] = secret_from_env("TEST_KEY")

# Assert that it returns the correct value
assert get_secret() == SecretStr("secret_value")


def test_secret_from_env_with_default_value(monkeypatch: pytest.MonkeyPatch) -> None:
# Unset the environment variable
monkeypatch.delenv("TEST_KEY", raising=False)

# Get the function with a default value
get_secret: Callable[[], SecretStr] = secret_from_env(
"TEST_KEY", default="default_value"
)

# Assert that it returns the default value
assert get_secret() == SecretStr("default_value")


def test_secret_from_env_with_none_default(monkeypatch: pytest.MonkeyPatch) -> None:
# Unset the environment variable
monkeypatch.delenv("TEST_KEY", raising=False)

# Get the function with a default value of None
get_secret: Callable[[], Optional[SecretStr]] = secret_from_env(
"TEST_KEY", default=None
)

# Assert that it returns None
assert get_secret() is None


def test_secret_from_env_without_default_raises_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Unset the environment variable
monkeypatch.delenv("TEST_KEY", raising=False)

# Get the function without a default value
get_secret: Callable[[], SecretStr] = secret_from_env("TEST_KEY")

# Assert that it raises a ValueError with the correct message
with pytest.raises(ValueError, match="Did not find TEST_KEY"):
get_secret()


def test_secret_from_env_with_custom_error_message(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Unset the environment variable
monkeypatch.delenv("TEST_KEY", raising=False)

# Get the function without a default value but with a custom error message
get_secret: Callable[[], SecretStr] = secret_from_env(
"TEST_KEY", error_message="Custom error message"
)

# Assert that it raises a ValueError with the custom message
with pytest.raises(ValueError, match="Custom error message"):
get_secret()


def test_using_secret_from_env_as_default_factory(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# Set the environment variable
monkeypatch.setenv("TEST_KEY", "secret_value")
# Get the function
from langchain_core.pydantic_v1 import BaseModel, Field

class Foo(BaseModel):
secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))

assert Foo().secret.get_secret_value() == "secret_value"

class Bar(BaseModel):
secret: Optional[SecretStr] = Field(
default_factory=secret_from_env("TEST_KEY_2", default=None)
)

assert Bar().secret is None

class Buzz(BaseModel):
secret: Optional[SecretStr] = Field(
default_factory=secret_from_env("TEST_KEY_2", default="hello")
)

# We know it will be SecretStr rather than Optional[SecretStr]
assert Buzz().secret.get_secret_value() == "hello" # type: ignore

class OhMy(BaseModel):
secret: Optional[SecretStr] = Field(
default_factory=secret_from_env("FOOFOOFOOBAR")
)

with pytest.raises(ValueError, match="Did not find FOOFOOFOOBAR"):
OhMy()

0 comments on commit 429a0ee

Please sign in to comment.