Skip to content

Commit

Permalink
Allow using pydantic plugin with models defined before calling logfire.configure
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmojaki committed Apr 29, 2024
1 parent 47e6bb2 commit da266ff
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 22 deletions.
20 changes: 2 additions & 18 deletions logfire/_internal/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,7 @@ def _load_configuration(
scrubbing_callback: ScrubCallback | None = None,
) -> None:
"""Merge the given parameters with the environment variables file configurations."""
config_dir = Path(config_dir or os.getenv('LOGFIRE_CONFIG_DIR') or '.')
config_from_file = self._load_config_from_file(config_dir)
param_manager = ParamManager(config_from_file=config_from_file)
param_manager = ParamManager.create(config_dir)

self.base_url = param_manager.load_param('base_url', base_url)
self.metrics_endpoint = os.getenv(OTEL_EXPORTER_OTLP_METRICS_ENDPOINT) or urljoin(self.base_url, '/v1/metrics')
Expand Down Expand Up @@ -375,11 +373,7 @@ def _load_configuration(
if isinstance(pydantic_plugin, dict):
# This is particularly for deserializing from a dict as in executors.py
pydantic_plugin = PydanticPlugin(**pydantic_plugin) # type: ignore
self.pydantic_plugin = pydantic_plugin or PydanticPlugin(
record=param_manager.load_param('pydantic_plugin_record'),
include=param_manager.load_param('pydantic_plugin_include'),
exclude=param_manager.load_param('pydantic_plugin_exclude'),
)
self.pydantic_plugin = pydantic_plugin or param_manager.pydantic_plugin()
self.fast_shutdown = fast_shutdown

self.id_generator = id_generator or RandomIdGenerator()
Expand All @@ -396,16 +390,6 @@ def _load_configuration(
# ignore them
pass

def _load_config_from_file(self, config_dir: Path) -> dict[str, Any]:
config_file = config_dir / 'pyproject.toml'
if not config_file.exists():
return {}
try:
data = read_toml_file(config_file)
return data.get('tool', {}).get('logfire', {})
except Exception as exc:
raise LogfireConfigError(f'Invalid config file: {config_file}') from exc


class LogfireConfig(_LogfireConfigData):
def __init__(
Expand Down
26 changes: 26 additions & 0 deletions logfire/_internal/config_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

from logfire.exceptions import LogfireConfigError

from . import config
from .constants import LOGFIRE_BASE_URL
from .exporters.console import ConsoleColorsValues
from .utils import read_toml_file

try:
import opentelemetry.instrumentation.system_metrics # noqa: F401 # type: ignore
Expand Down Expand Up @@ -117,6 +119,12 @@ class ParamManager:
config_from_file: dict[str, Any]
"""Config loaded from the config file."""

@classmethod
def create(cls, config_dir: Path | None = None) -> ParamManager:
    """Build a ParamManager with config loaded from `pyproject.toml`.

    The directory is resolved in order: the explicit argument, the
    LOGFIRE_CONFIG_DIR environment variable, then the current directory.
    """
    resolved = Path(config_dir or os.getenv('LOGFIRE_CONFIG_DIR') or '.')
    return ParamManager(config_from_file=_load_config_from_file(resolved))

def load_param(self, name: str, runtime: Any = None) -> Any:
"""Load a parameter given its name.
Expand Down Expand Up @@ -151,6 +159,13 @@ def load_param(self, name: str, runtime: Any = None) -> Any:

return self._cast(param.default, name, param.tp)

def pydantic_plugin(self):
    """Assemble a PydanticPlugin from the `pydantic_plugin_*` parameters."""
    settings = {
        key: self.load_param(f'pydantic_plugin_{key}')
        for key in ('record', 'include', 'exclude')
    }
    return config.PydanticPlugin(**settings)

def _cast(self, value: Any, name: str, tp: type[T]) -> T | None:
if tp is str:
return value
Expand Down Expand Up @@ -191,3 +206,14 @@ def _check_bool(value: Any, name: str) -> bool | None:

def _extract_set_of_str(value: str | set[str]) -> set[str]:
return set(map(str.strip, value.split(','))) if isinstance(value, str) else value


def _load_config_from_file(config_dir: Path) -> dict[str, Any]:
config_file = config_dir / 'pyproject.toml'
if not config_file.exists():
return {}
try:
data = read_toml_file(config_file)
return data.get('tool', {}).get('logfire', {})
except Exception as exc:
raise LogfireConfigError(f'Invalid config file: {config_file}') from exc
33 changes: 29 additions & 4 deletions logfire/integrations/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import logfire
from logfire import LogfireSpan

from .._internal.config import GLOBAL_CONFIG
from .._internal.config import GLOBAL_CONFIG, PydanticPlugin
from .._internal.config_params import ParamManager

if TYPE_CHECKING: # pragma: no cover
from pydantic import ValidationError
Expand Down Expand Up @@ -116,6 +117,14 @@ def __call__(self, validator: Any) -> Any:

@functools.wraps(validator)
def wrapped_validator(input_data: Any, *args: Any, **kwargs: Any) -> Any:
if not GLOBAL_CONFIG._initialized: # type: ignore
# These wrappers should be created when the model is defined if the plugin is activated
# by env vars even if logfire.configure() hasn't been called yet,
# but we don't want to actually record anything until logfire.configure() has been called.
# For example it would be annoying if the user didn't want to send data to logfire
# but validation ran into an error about not being authenticated.
return validator(input_data, *args, **kwargs)

# If we get a validation error, we want to let it bubble through.
# If we used `with span:` this would set the log level to 'error' and export it,
# but we want the log level to be 'warn', so we end the span manually.
Expand All @@ -140,6 +149,10 @@ def wrapped_validator(input_data: Any, *args: Any, **kwargs: Any) -> Any:

@functools.wraps(validator)
def wrapped_validator(input_data: Any, *args: Any, **kwargs: Any) -> Any:
if not GLOBAL_CONFIG._initialized: # type: ignore
# Only start recording after logfire has been configured.
return validator(input_data, *args, **kwargs)

try:
result = validator(input_data, *args, **kwargs)
except ValidationError as error:
Expand All @@ -158,6 +171,10 @@ def wrapped_validator(input_data: Any, *args: Any, **kwargs: Any) -> Any:

@functools.wraps(validator)
def wrapped_validator(input_data: Any, *args: Any, **kwargs: Any) -> Any:
if not GLOBAL_CONFIG._initialized: # type: ignore
# Only start recording after logfire has been configured.
return validator(input_data, *args, **kwargs)

try:
result = validator(input_data, *args, **kwargs)
except Exception:
Expand Down Expand Up @@ -318,7 +335,7 @@ def new_schema_validator(
if logfire_settings and 'record' in logfire_settings:
record = logfire_settings['record']
else:
record = GLOBAL_CONFIG.pydantic_plugin.record
record = _pydantic_plugin_config().record

if record == 'off':
return None, None, None
Expand All @@ -341,10 +358,18 @@ def new_schema_validator(
IGNORED_MODULE_PREFIXES: tuple[str, ...] = tuple(f'{module}.' for module in IGNORED_MODULES)


def _pydantic_plugin_config() -> PydanticPlugin:
    """Return the pydantic plugin settings currently in effect.

    Once logfire.configure() has run, the global config is authoritative.
    Before that, read the settings (env vars / config file) directly so models
    defined early can still be instrumented.
    """
    if not GLOBAL_CONFIG._initialized:  # type: ignore
        return ParamManager.create().pydantic_plugin()
    return GLOBAL_CONFIG.pydantic_plugin


def _include_model(schema_type_path: SchemaTypePath) -> bool:
"""Check whether a model should be instrumented."""
include = GLOBAL_CONFIG.pydantic_plugin.include
exclude = GLOBAL_CONFIG.pydantic_plugin.exclude
config = _pydantic_plugin_config()
include = config.include
exclude = config.exclude

# check if the model is in ignored model
module = schema_type_path.module
Expand Down
135 changes: 135 additions & 0 deletions tests/test_pydantic_plugin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import importlib.metadata
import os
from typing import Any
from unittest.mock import patch

import pytest
from dirty_equals import IsInt
Expand Down Expand Up @@ -1081,3 +1083,136 @@ def double(v: Any) -> Any:
}
]
)


def test_record_all_env_var(exporter: TestExporter) -> None:
    """LOGFIRE_PYDANTIC_PLUGIN_RECORD=all instruments models defined before configure(),
    but spans are only exported once logfire is configured."""
    # Pretend that logfire.configure() hasn't been called yet.
    GLOBAL_CONFIG._initialized = False  # type: ignore

    with patch.dict(os.environ, {'LOGFIRE_PYDANTIC_PLUGIN_RECORD': 'all'}):
        # This model should be instrumented even though logfire.configure() hasn't been called
        # because of the LOGFIRE_PYDANTIC_PLUGIN_RECORD env var.
        class MyModel(BaseModel):
            x: int

    # But validations shouldn't be recorded yet.
    MyModel(x=1)
    assert exporter.exported_spans_as_dict() == []

    # Equivalent to calling logfire.configure() with the args in the `config` test fixture.
    GLOBAL_CONFIG._initialized = True  # type: ignore

    # Now a validation produces a span from the already-wrapped validator.
    MyModel(x=2)
    assert exporter.exported_spans_as_dict() == snapshot(
        [
            {
                'name': 'pydantic.validate_python',
                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'parent': None,
                'start_time': 1000000000,
                'end_time': 2000000000,
                'attributes': {
                    'code.filepath': 'pydantic.py',
                    'code.function': '_on_enter',
                    'code.lineno': 123,
                    'schema_name': 'MyModel',
                    'validation_method': 'validate_python',
                    'input_data': '{"x":2}',
                    'logfire.msg_template': 'Pydantic {schema_name} {validation_method}',
                    'logfire.level_num': 9,
                    'logfire.span_type': 'span',
                    'success': True,
                    'result': '{"x":2}',
                    'logfire.msg': 'Pydantic MyModel validate_python succeeded',
                    'logfire.json_schema': '{"type":"object","properties":{"schema_name":{},"validation_method":{},"input_data":{"type":"object"},"success":{},"result":{"type":"object","title":"MyModel","x-python-datatype":"PydanticModel"}}}',
                },
            }
        ]
    )


def test_record_failure_env_var(exporter: TestExporter) -> None:
    # Same as test_record_all_env_var but with LOGFIRE_PYDANTIC_PLUGIN_RECORD=failure.

    # Pretend that logfire.configure() hasn't been called yet.
    GLOBAL_CONFIG._initialized = False  # type: ignore

    with patch.dict(os.environ, {'LOGFIRE_PYDANTIC_PLUGIN_RECORD': 'failure'}):

        class MyModel(BaseModel):
            x: int

    # Failing validations shouldn't be recorded before configure().
    with pytest.raises(ValidationError):
        MyModel(x='a')  # type: ignore
    assert exporter.exported_spans_as_dict() == []

    # Simulate logfire.configure() having been called.
    GLOBAL_CONFIG._initialized = True  # type: ignore

    # Now a failing validation emits a warning-level log (no span for successes).
    with pytest.raises(ValidationError):
        MyModel(x='b')  # type: ignore
    assert exporter.exported_spans_as_dict() == snapshot(
        [
            {
                'name': 'Validation on {schema_name} failed',
                'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False},
                'parent': None,
                'start_time': 1000000000,
                'end_time': 1000000000,
                'attributes': {
                    'code.filepath': 'test_pydantic_plugin.py',
                    'code.function': 'test_record_failure_env_var',
                    'code.lineno': 123,
                    'schema_name': 'MyModel',
                    'logfire.msg_template': 'Validation on {schema_name} failed',
                    'logfire.level_num': 13,
                    'error_count': 1,
                    'errors': '[{"type":"int_parsing","loc":["x"],"msg":"Input should be a valid integer, unable to parse string as an integer","input":"b"}]',
                    'logfire.span_type': 'log',
                    'logfire.msg': 'Validation on MyModel failed',
                    'logfire.json_schema': '{"type":"object","properties":{"schema_name":{},"error_count":{},"errors":{"type":"array","items":{"type":"object","properties":{"loc":{"type":"array","x-python-datatype":"tuple"}}}}}}',
                },
            }
        ]
    )


def test_record_metrics_env_var(metrics_reader: InMemoryMetricReader) -> None:
    # Same as test_record_all_env_var but with LOGFIRE_PYDANTIC_PLUGIN_RECORD=metrics.

    # Pretend that logfire.configure() hasn't been called yet.
    GLOBAL_CONFIG._initialized = False  # type: ignore

    with patch.dict(os.environ, {'LOGFIRE_PYDANTIC_PLUGIN_RECORD': 'metrics'}):

        class MyModel(BaseModel):
            x: int

    # No metrics should be recorded before configure().
    MyModel(x=1)
    assert metrics_reader.get_metrics_data() is None  # type: ignore

    # Simulate logfire.configure() having been called.
    GLOBAL_CONFIG._initialized = True  # type: ignore

    # Now a successful validation increments the pydantic.validations counter.
    MyModel(x=2)
    assert get_collected_metrics(metrics_reader) == snapshot(
        [
            {
                'name': 'pydantic.validations',
                'description': '',
                'unit': '',
                'data': {
                    'data_points': [
                        {
                            'attributes': {
                                'success': True,
                                'schema_name': 'MyModel',
                                'validation_method': 'validate_python',
                            },
                            'start_time_unix_nano': IsInt(gt=0),
                            'time_unix_nano': IsInt(gt=0),
                            'value': 1,
                        }
                    ],
                    'aggregation_temporality': 1,
                    'is_monotonic': True,
                },
            }
        ]
    )

0 comments on commit da266ff

Please sign in to comment.