diff --git a/logfire/_internal/config.py b/logfire/_internal/config.py index 1a8ebe59b..4f145cf27 100644 --- a/logfire/_internal/config.py +++ b/logfire/_internal/config.py @@ -334,9 +334,7 @@ def _load_configuration( scrubbing_callback: ScrubCallback | None = None, ) -> None: """Merge the given parameters with the environment variables file configurations.""" - config_dir = Path(config_dir or os.getenv('LOGFIRE_CONFIG_DIR') or '.') - config_from_file = self._load_config_from_file(config_dir) - param_manager = ParamManager(config_from_file=config_from_file) + param_manager = ParamManager.create(config_dir) self.base_url = param_manager.load_param('base_url', base_url) self.metrics_endpoint = os.getenv(OTEL_EXPORTER_OTLP_METRICS_ENDPOINT) or urljoin(self.base_url, '/v1/metrics') @@ -375,11 +373,7 @@ def _load_configuration( if isinstance(pydantic_plugin, dict): # This is particularly for deserializing from a dict as in executors.py pydantic_plugin = PydanticPlugin(**pydantic_plugin) # type: ignore - self.pydantic_plugin = pydantic_plugin or PydanticPlugin( - record=param_manager.load_param('pydantic_plugin_record'), - include=param_manager.load_param('pydantic_plugin_include'), - exclude=param_manager.load_param('pydantic_plugin_exclude'), - ) + self.pydantic_plugin = pydantic_plugin or param_manager.pydantic_plugin() self.fast_shutdown = fast_shutdown self.id_generator = id_generator or RandomIdGenerator() @@ -396,16 +390,6 @@ def _load_configuration( # ignore them pass - def _load_config_from_file(self, config_dir: Path) -> dict[str, Any]: - config_file = config_dir / 'pyproject.toml' - if not config_file.exists(): - return {} - try: - data = read_toml_file(config_file) - return data.get('tool', {}).get('logfire', {}) - except Exception as exc: - raise LogfireConfigError(f'Invalid config file: {config_file}') from exc - class LogfireConfig(_LogfireConfigData): def __init__( diff --git a/logfire/_internal/config_params.py b/logfire/_internal/config_params.py index 270b6f705..86278967c 100644 --- a/logfire/_internal/config_params.py +++ b/logfire/_internal/config_params.py @@ -11,8 +11,10 @@ from logfire.exceptions import LogfireConfigError +from . import config from .constants import LOGFIRE_BASE_URL from .exporters.console import ConsoleColorsValues +from .utils import read_toml_file try: import opentelemetry.instrumentation.system_metrics # noqa: F401 # type: ignore @@ -117,6 +119,12 @@ class ParamManager: config_from_file: dict[str, Any] """Config loaded from the config file.""" + @classmethod + def create(cls, config_dir: Path | None = None) -> ParamManager: + config_dir = Path(config_dir or os.getenv('LOGFIRE_CONFIG_DIR') or '.') + config_from_file = _load_config_from_file(config_dir) + return ParamManager(config_from_file=config_from_file) + def load_param(self, name: str, runtime: Any = None) -> Any: """Load a parameter given its name. @@ -151,6 +159,13 @@ def load_param(self, name: str, runtime: Any = None) -> Any: return self._cast(param.default, name, param.tp) + def pydantic_plugin(self): + return config.PydanticPlugin( + record=self.load_param('pydantic_plugin_record'), + include=self.load_param('pydantic_plugin_include'), + exclude=self.load_param('pydantic_plugin_exclude'), + ) + def _cast(self, value: Any, name: str, tp: type[T]) -> T | None: if tp is str: return value @@ -191,3 +206,14 @@ def _check_bool(value: Any, name: str) -> bool | None: def _extract_set_of_str(value: str | set[str]) -> set[str]: return set(map(str.strip, value.split(','))) if isinstance(value, str) else value + + +def _load_config_from_file(config_dir: Path) -> dict[str, Any]: + config_file = config_dir / 'pyproject.toml' + if not config_file.exists(): + return {} + try: + data = read_toml_file(config_file) + return data.get('tool', {}).get('logfire', {}) + except Exception as exc: + raise LogfireConfigError(f'Invalid config file: {config_file}') from exc diff --git a/logfire/integrations/pydantic.py b/logfire/integrations/pydantic.py index 5b8d1d3ea..c9eb26f7e 100644 --- a/logfire/integrations/pydantic.py +++ b/logfire/integrations/pydantic.py @@ -13,7 +13,8 @@ import logfire from logfire import LogfireSpan -from .._internal.config import GLOBAL_CONFIG +from .._internal.config import GLOBAL_CONFIG, PydanticPlugin +from .._internal.config_params import ParamManager if TYPE_CHECKING: # pragma: no cover from pydantic import ValidationError @@ -116,6 +117,14 @@ def __call__(self, validator: Any) -> Any: @functools.wraps(validator) def wrapped_validator(input_data: Any, *args: Any, **kwargs: Any) -> Any: + if not GLOBAL_CONFIG._initialized: # type: ignore + # These wrappers should be created when the model is defined if the plugin is activated + # by env vars even if logfire.configure() hasn't been called yet, + # but we don't want to actually record anything until logfire.configure() has been called. + # For example it would be annoying if the user didn't want to send data to logfire + # but validation ran into an error about not being authenticated. + return validator(input_data, *args, **kwargs) + # If we get a validation error, we want to let it bubble through. # If we used `with span:` this would set the log level to 'error' and export it, # but we want the log level to be 'warn', so we end the span manually. @@ -140,6 +149,10 @@ def wrapped_validator(input_data: Any, *args: Any, **kwargs: Any) -> Any: @functools.wraps(validator) def wrapped_validator(input_data: Any, *args: Any, **kwargs: Any) -> Any: + if not GLOBAL_CONFIG._initialized: # type: ignore + # Only start recording after logfire has been configured. + return validator(input_data, *args, **kwargs) + try: result = validator(input_data, *args, **kwargs) except ValidationError as error: @@ -158,6 +171,10 @@ def wrapped_validator(input_data: Any, *args: Any, **kwargs: Any) -> Any: @functools.wraps(validator) def wrapped_validator(input_data: Any, *args: Any, **kwargs: Any) -> Any: + if not GLOBAL_CONFIG._initialized: # type: ignore + # Only start recording after logfire has been configured. + return validator(input_data, *args, **kwargs) + try: result = validator(input_data, *args, **kwargs) except Exception: @@ -318,7 +335,7 @@ def new_schema_validator( if logfire_settings and 'record' in logfire_settings: record = logfire_settings['record'] else: - record = GLOBAL_CONFIG.pydantic_plugin.record + record = _pydantic_plugin_config().record if record == 'off': return None, None, None @@ -341,10 +358,18 @@ def new_schema_validator( IGNORED_MODULE_PREFIXES: tuple[str, ...] = tuple(f'{module}.' for module in IGNORED_MODULES) +def _pydantic_plugin_config() -> PydanticPlugin: + if GLOBAL_CONFIG._initialized: # type: ignore + return GLOBAL_CONFIG.pydantic_plugin + else: + return ParamManager.create().pydantic_plugin() + + def _include_model(schema_type_path: SchemaTypePath) -> bool: """Check whether a model should be instrumented.""" - include = GLOBAL_CONFIG.pydantic_plugin.include - exclude = GLOBAL_CONFIG.pydantic_plugin.exclude + config = _pydantic_plugin_config() + include = config.include + exclude = config.exclude # check if the model is in ignored model module = schema_type_path.module diff --git a/tests/test_pydantic_plugin.py b/tests/test_pydantic_plugin.py index 1ad2d5a43..657718347 100644 --- a/tests/test_pydantic_plugin.py +++ b/tests/test_pydantic_plugin.py @@ -1,7 +1,9 @@ from __future__ import annotations import importlib.metadata +import os from typing import Any +from unittest.mock import patch import pytest from dirty_equals import IsInt @@ -1081,3 +1083,136 @@ def double(v: Any) -> Any: } ] ) + + +def test_record_all_env_var(exporter: TestExporter) -> None: + # Pretend that logfire.configure() hasn't been called yet. + GLOBAL_CONFIG._initialized = False # type: ignore + + with patch.dict(os.environ, {'LOGFIRE_PYDANTIC_PLUGIN_RECORD': 'all'}): + # This model should be instrumented even though logfire.configure() hasn't been called + # because of the LOGFIRE_PYDANTIC_PLUGIN_RECORD env var. + class MyModel(BaseModel): + x: int + + # But validations shouldn't be recorded yet. + MyModel(x=1) + assert exporter.exported_spans_as_dict() == [] + + # Equivalent to calling logfire.configure() with the args in the `config` test fixture. + GLOBAL_CONFIG._initialized = True # type: ignore + + MyModel(x=2) + assert exporter.exported_spans_as_dict() == snapshot( + [ + { + 'name': 'pydantic.validate_python', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 2000000000, + 'attributes': { + 'code.filepath': 'pydantic.py', + 'code.function': '_on_enter', + 'code.lineno': 123, + 'schema_name': 'MyModel', + 'validation_method': 'validate_python', + 'input_data': '{"x":2}', + 'logfire.msg_template': 'Pydantic {schema_name} {validation_method}', + 'logfire.level_num': 9, + 'logfire.span_type': 'span', + 'success': True, + 'result': '{"x":2}', + 'logfire.msg': 'Pydantic MyModel validate_python succeeded', + 'logfire.json_schema': '{"type":"object","properties":{"schema_name":{},"validation_method":{},"input_data":{"type":"object"},"success":{},"result":{"type":"object","title":"MyModel","x-python-datatype":"PydanticModel"}}}', + }, + } + ] + ) + + +def test_record_failure_env_var(exporter: TestExporter) -> None: + # Same as test_record_all_env_var but with LOGFIRE_PYDANTIC_PLUGIN_RECORD=failure. + + GLOBAL_CONFIG._initialized = False # type: ignore + + with patch.dict(os.environ, {'LOGFIRE_PYDANTIC_PLUGIN_RECORD': 'failure'}): + + class MyModel(BaseModel): + x: int + + with pytest.raises(ValidationError): + MyModel(x='a') # type: ignore + assert exporter.exported_spans_as_dict() == [] + + GLOBAL_CONFIG._initialized = True # type: ignore + + with pytest.raises(ValidationError): + MyModel(x='b') # type: ignore + assert exporter.exported_spans_as_dict() == snapshot( + [ + { + 'name': 'Validation on {schema_name} failed', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 1000000000, + 'attributes': { + 'code.filepath': 'test_pydantic_plugin.py', + 'code.function': 'test_record_failure_env_var', + 'code.lineno': 123, + 'schema_name': 'MyModel', + 'logfire.msg_template': 'Validation on {schema_name} failed', + 'logfire.level_num': 13, + 'error_count': 1, + 'errors': '[{"type":"int_parsing","loc":["x"],"msg":"Input should be a valid integer, unable to parse string as an integer","input":"b"}]', + 'logfire.span_type': 'log', + 'logfire.msg': 'Validation on MyModel failed', + 'logfire.json_schema': '{"type":"object","properties":{"schema_name":{},"error_count":{},"errors":{"type":"array","items":{"type":"object","properties":{"loc":{"type":"array","x-python-datatype":"tuple"}}}}}}', + }, + } + ] + ) + + +def test_record_metrics_env_var(metrics_reader: InMemoryMetricReader) -> None: + # Same as test_record_all_env_var but with LOGFIRE_PYDANTIC_PLUGIN_RECORD=metrics. + + GLOBAL_CONFIG._initialized = False # type: ignore + + with patch.dict(os.environ, {'LOGFIRE_PYDANTIC_PLUGIN_RECORD': 'metrics'}): + + class MyModel(BaseModel): + x: int + + MyModel(x=1) + assert metrics_reader.get_metrics_data() is None # type: ignore + + GLOBAL_CONFIG._initialized = True # type: ignore + + MyModel(x=2) + assert get_collected_metrics(metrics_reader) == snapshot( + [ + { + 'name': 'pydantic.validations', + 'description': '', + 'unit': '', + 'data': { + 'data_points': [ + { + 'attributes': { + 'success': True, + 'schema_name': 'MyModel', + 'validation_method': 'validate_python', + }, + 'start_time_unix_nano': IsInt(gt=0), + 'time_unix_nano': IsInt(gt=0), + 'value': 1, + } + ], + 'aggregation_temporality': 1, + 'is_monotonic': True, + }, + } + ] + )