diff --git a/logfire/_internal/main.py b/logfire/_internal/main.py index 99a5b32c..2387831d 100644 --- a/logfire/_internal/main.py +++ b/logfire/_internal/main.py @@ -1,10 +1,11 @@ from __future__ import annotations +import atexit import sys import traceback import typing import warnings -from functools import cached_property +from functools import cached_property, partial from time import time from types import TracebackType from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Literal, Sequence, TypeVar, Union, cast @@ -1105,16 +1106,19 @@ def shutdown(self, timeout_millis: int = 30_000, flush: bool = True) -> bool: # class FastLogfireSpan: """A simple version of `LogfireSpan` optimized for auto-tracing.""" - __slots__ = ('_span', '_token') + __slots__ = ('_span', '_token', '_atexit') def __init__(self, span: trace_api.Span) -> None: self._span = span self._token = context_api.attach(trace_api.set_span_in_context(self._span)) + self._atexit = partial(self.__exit__, None, None, None) + atexit.register(self._atexit) def __enter__(self) -> FastLogfireSpan: return self def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None: + atexit.unregister(self._atexit) context_api.detach(self._token) _exit_span(self._span, exc_value) self._span.end() @@ -1139,6 +1143,7 @@ def __init__( self._token: None | object = None self._span: None | trace_api.Span = None self.end_on_exit = True + self._atexit: Callable[[], None] | None = None if not TYPE_CHECKING: # pragma: no branch @@ -1154,12 +1159,19 @@ def __enter__(self) -> LogfireSpan: ) if self._token is None: # pragma: no branch self._token = context_api.attach(trace_api.set_span_in_context(self._span)) + + self._atexit = partial(self.__exit__, None, None, None) + atexit.register(self._atexit) + return self def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None: if self._token is None: # pragma: no cover return + if self._atexit: # pragma: no branch + atexit.unregister(self._atexit) + context_api.detach(self._token) self._token = None