From f42fb50fcef8ef4dd00c168551a5e3b624c25878 Mon Sep 17 00:00:00 2001 From: Rachel Yang Date: Thu, 15 Aug 2024 15:40:28 -0400 Subject: [PATCH] chore(integrations): move openai,psycopg,pylibmc,pymemcache,pymongo,pymysql to internal (#10186) - Moves all integration internals in ddtrace/contrib/(integration name)/ to ddtrace/contrib/internal/(integration name)/ for openai, psycopg, pylibmc, pymemcache, pymongo, and pymysql - Ensures ddtrace/contrib/(integration name)/ and ddtrace/contrib/(integration name)/ continue to expose the same functions, classes, imports, and module level variables (via from ..internal.integration.module import * imports). - Log a deprecation warning if internal modules in ddtrace/contrib/(integration name)/ and ddtrace/contrib/(integration name)/. Only patch and unpack methods should be exposed by these packages. - https://github.com/DataDog/dd-trace-py/pull/9996 ## Checklist - [x] PR author has checked that all the criteria below are met - The PR description includes an overview of the change - The PR description articulates the motivation for the change - The change includes tests OR the PR description describes a testing strategy - The PR description notes risks associated with the change, if any - Newly-added code is easy to change - The change follows the [library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) - The change includes or references documentation updates if necessary - Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) ## Reviewer Checklist - [x] Reviewer has checked that all the criteria below are met - Title is accurate - All changes are related to the pull request's stated goal - Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - Testing strategy adequately addresses listed risks - Newly-added code is easy to change - Release note makes sense to a user of the library - If necessary, author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --------- Co-authored-by: Emmett Butler <723615+emmettbutler@users.noreply.github.com> Co-authored-by: Munir Abdinur Co-authored-by: Munir Abdinur --- ddtrace/contrib/internal/aiopg/patch.py | 6 +- ddtrace/contrib/internal/django/patch.py | 8 +- .../internal/psycopg/async_connection.py | 66 ++ .../contrib/internal/psycopg/async_cursor.py | 11 + .../contrib/internal/psycopg/connection.py | 110 +++ ddtrace/contrib/internal/psycopg/cursor.py | 28 + .../contrib/internal/psycopg/extensions.py | 180 +++++ ddtrace/contrib/internal/psycopg/patch.py | 213 +++++ ddtrace/contrib/internal/pylibmc/addrs.py | 14 + ddtrace/contrib/internal/pylibmc/client.py | 193 +++++ ddtrace/contrib/internal/pylibmc/patch.py | 26 + ddtrace/contrib/internal/pymemcache/client.py | 362 +++++++++ ddtrace/contrib/internal/pymemcache/patch.py | 49 ++ ddtrace/contrib/internal/pymongo/client.py | 372 +++++++++ ddtrace/contrib/internal/pymongo/parse.py | 204 +++++ ddtrace/contrib/internal/pymongo/patch.py | 98 +++ ddtrace/contrib/internal/pymysql/patch.py | 68 ++ ddtrace/contrib/openai/__init__.py | 10 +- ddtrace/contrib/openai/_endpoint_hooks.py | 762 +----------------- ddtrace/contrib/openai/patch.py | 358 +------- ddtrace/contrib/openai/utils.py | 360 +-------- ddtrace/contrib/psycopg/__init__.py | 18 +- ddtrace/contrib/psycopg/async_connection.py | 73 +- ddtrace/contrib/psycopg/async_cursor.py | 18 +- ddtrace/contrib/psycopg/connection.py | 117 +-- ddtrace/contrib/psycopg/cursor.py | 35 +- ddtrace/contrib/psycopg/extensions.py | 188 +---- ddtrace/contrib/psycopg/patch.py | 246 +----- ddtrace/contrib/pylibmc/__init__.py | 10 +- ddtrace/contrib/pylibmc/addrs.py | 23 +- ddtrace/contrib/pylibmc/client.py | 200 +---- ddtrace/contrib/pylibmc/patch.py | 33 +- ddtrace/contrib/pymemcache/__init__.py | 10 +- ddtrace/contrib/pymemcache/client.py | 369 +-------- ddtrace/contrib/pymemcache/patch.py | 61 +- ddtrace/contrib/pymongo/__init__.py | 8 +- ddtrace/contrib/pymongo/client.py | 379 +-------- ddtrace/contrib/pymongo/parse.py | 211 +---- ddtrace/contrib/pymongo/patch.py | 140 +--- ddtrace/contrib/pymysql/__init__.py | 8 +- ddtrace/contrib/pymysql/patch.py | 89 +- ...s-to-internal-openai-0d4ab4241552ff94.yaml | 14 + tests/.suitespec.json | 10 +- tests/contrib/openai/test_openai_llmobs.py | 12 +- tests/contrib/openai/test_openai_v0.py | 10 +- tests/contrib/openai/test_openai_v1.py | 14 +- tests/contrib/pymongo/test.py | 4 +- 47 files changed, 2233 insertions(+), 3565 deletions(-) create mode 100644 ddtrace/contrib/internal/psycopg/async_connection.py create mode 100644 ddtrace/contrib/internal/psycopg/async_cursor.py create mode 100644 ddtrace/contrib/internal/psycopg/connection.py create mode 100644 ddtrace/contrib/internal/psycopg/cursor.py create mode 100644 ddtrace/contrib/internal/psycopg/extensions.py create mode 100644 ddtrace/contrib/internal/psycopg/patch.py create mode 100644 ddtrace/contrib/internal/pylibmc/addrs.py create mode 100644 ddtrace/contrib/internal/pylibmc/client.py create mode 100644 ddtrace/contrib/internal/pylibmc/patch.py create mode 100644 ddtrace/contrib/internal/pymemcache/client.py create mode 100644 ddtrace/contrib/internal/pymemcache/patch.py create mode 100644 ddtrace/contrib/internal/pymongo/client.py create mode 100644 ddtrace/contrib/internal/pymongo/parse.py create mode 100644 ddtrace/contrib/internal/pymongo/patch.py create mode 100644 ddtrace/contrib/internal/pymysql/patch.py create mode 100644 releasenotes/notes/move-integrations-to-internal-openai-0d4ab4241552ff94.yaml diff --git a/ddtrace/contrib/internal/aiopg/patch.py b/ddtrace/contrib/internal/aiopg/patch.py index 35cf78c375c..5406de636c0 100644 --- a/ddtrace/contrib/internal/aiopg/patch.py +++ b/ddtrace/contrib/internal/aiopg/patch.py @@ -4,9 +4,9 @@ from ddtrace import config from ddtrace.contrib.aiopg.connection import AIOTracedConnection -from ddtrace.contrib.psycopg.connection import patch_conn as psycopg_patch_conn -from ddtrace.contrib.psycopg.extensions import _patch_extensions -from ddtrace.contrib.psycopg.extensions import _unpatch_extensions +from ddtrace.contrib.internal.psycopg.connection import patch_conn as psycopg_patch_conn +from ddtrace.contrib.internal.psycopg.extensions import _patch_extensions +from ddtrace.contrib.internal.psycopg.extensions import _unpatch_extensions from ddtrace.internal.schema import schematize_service_name from ddtrace.internal.utils.wrappers import unwrap as _u from ddtrace.vendor import wrapt diff --git a/ddtrace/contrib/internal/django/patch.py b/ddtrace/contrib/internal/django/patch.py index 75b0d3d70fc..a2a6f7aaf67 100644 --- a/ddtrace/contrib/internal/django/patch.py +++ b/ddtrace/contrib/internal/django/patch.py @@ -105,13 +105,13 @@ def patch_conn(django, conn): try: from psycopg.cursor import Cursor as psycopg_cursor_cls - from ddtrace.contrib.psycopg.cursor import Psycopg3TracedCursor + from ddtrace.contrib.internal.psycopg.cursor import Psycopg3TracedCursor except ImportError: Psycopg3TracedCursor = None try: from psycopg2._psycopg import cursor as psycopg_cursor_cls - from ddtrace.contrib.psycopg.cursor import Psycopg2TracedCursor + from ddtrace.contrib.internal.psycopg.cursor import Psycopg2TracedCursor except ImportError: psycopg_cursor_cls = None Psycopg2TracedCursor = None @@ -148,12 +148,12 @@ def cursor(django, pin, func, instance, args, kwargs): try: if cursor.cursor.__class__.__module__.startswith("psycopg2."): # Import lazily to avoid importing psycopg2 if not already imported. - from ddtrace.contrib.psycopg.cursor import Psycopg2TracedCursor + from ddtrace.contrib.internal.psycopg.cursor import Psycopg2TracedCursor traced_cursor_cls = Psycopg2TracedCursor elif type(cursor.cursor).__name__ == "Psycopg3TracedCursor": # Import lazily to avoid importing psycopg if not already imported. - from ddtrace.contrib.psycopg.cursor import Psycopg3TracedCursor + from ddtrace.contrib.internal.psycopg.cursor import Psycopg3TracedCursor traced_cursor_cls = Psycopg3TracedCursor except AttributeError: diff --git a/ddtrace/contrib/internal/psycopg/async_connection.py b/ddtrace/contrib/internal/psycopg/async_connection.py new file mode 100644 index 00000000000..14ec854ffd1 --- /dev/null +++ b/ddtrace/contrib/internal/psycopg/async_connection.py @@ -0,0 +1,66 @@ +from ddtrace import Pin +from ddtrace import config +from ddtrace.constants import SPAN_KIND +from ddtrace.constants import SPAN_MEASURED_KEY +from ddtrace.contrib import dbapi_async +from ddtrace.contrib.internal.psycopg.async_cursor import Psycopg3FetchTracedAsyncCursor +from ddtrace.contrib.internal.psycopg.async_cursor import Psycopg3TracedAsyncCursor +from ddtrace.contrib.internal.psycopg.connection import patch_conn +from ddtrace.contrib.trace_utils import ext_service +from ddtrace.ext import SpanKind +from ddtrace.ext import SpanTypes +from ddtrace.ext import db +from ddtrace.internal.constants import COMPONENT + + +class Psycopg3TracedAsyncConnection(dbapi_async.TracedAsyncConnection): + def __init__(self, conn, pin=None, cursor_cls=None): + if not cursor_cls: + # Do not trace `fetch*` methods by default + cursor_cls = ( + Psycopg3FetchTracedAsyncCursor if config.psycopg.trace_fetch_methods else Psycopg3TracedAsyncCursor + ) + + super(Psycopg3TracedAsyncConnection, self).__init__(conn, pin, config.psycopg, cursor_cls=cursor_cls) + + async def execute(self, *args, **kwargs): + """Execute a query and return a cursor to read its results.""" + span_name = "{}.{}".format(self._self_datadog_name, "execute") + + async def patched_execute(*args, **kwargs): + try: + cur = self.cursor() + if kwargs.get("binary", None): + cur.format = 1 # set to 1 for binary or 0 if not + return await cur.execute(*args, **kwargs) + except Exception as ex: + raise ex.with_traceback(None) + + return await self._trace_method(patched_execute, span_name, {}, *args, **kwargs) + + +def patched_connect_async_factory(psycopg_module): + async def patched_connect_async(connect_func, _, args, kwargs): + traced_conn_cls = Psycopg3TracedAsyncConnection + + pin = Pin.get_from(psycopg_module) + + if not pin or not pin.enabled() or not pin._config.trace_connect: + conn = await connect_func(*args, **kwargs) + else: + with pin.tracer.trace( + "{}.{}".format(connect_func.__module__, connect_func.__name__), + service=ext_service(pin, pin._config), + span_type=SpanTypes.SQL, + ) as span: + span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) + span.set_tag_str(COMPONENT, pin._config.integration_name) + if span.get_tag(db.SYSTEM) is None: + span.set_tag_str(db.SYSTEM, pin._config.dbms_name) + + span.set_tag(SPAN_MEASURED_KEY) + conn = await connect_func(*args, **kwargs) + + return patch_conn(conn, pin=pin, traced_conn_cls=traced_conn_cls) + + return patched_connect_async diff --git a/ddtrace/contrib/internal/psycopg/async_cursor.py b/ddtrace/contrib/internal/psycopg/async_cursor.py new file mode 100644 index 00000000000..a7e1f4a710b --- /dev/null +++ b/ddtrace/contrib/internal/psycopg/async_cursor.py @@ -0,0 +1,11 @@ +from ddtrace.contrib import dbapi_async +from ddtrace.contrib.internal.psycopg.cursor import Psycopg3TracedCursor + + +class Psycopg3TracedAsyncCursor(Psycopg3TracedCursor, dbapi_async.TracedAsyncCursor): + def __init__(self, cursor, pin, cfg, *args, **kwargs): + super(Psycopg3TracedAsyncCursor, self).__init__(cursor, pin, cfg) + + +class Psycopg3FetchTracedAsyncCursor(Psycopg3TracedAsyncCursor, dbapi_async.FetchTracedAsyncCursor): + """Psycopg3FetchTracedAsyncCursor for psycopg""" diff --git a/ddtrace/contrib/internal/psycopg/connection.py b/ddtrace/contrib/internal/psycopg/connection.py new file mode 100644 index 00000000000..c823e17dc61 --- /dev/null +++ b/ddtrace/contrib/internal/psycopg/connection.py @@ -0,0 +1,110 @@ +from ddtrace import Pin +from ddtrace import config +from ddtrace.constants import SPAN_KIND +from ddtrace.constants import SPAN_MEASURED_KEY +from ddtrace.contrib import dbapi +from ddtrace.contrib.internal.psycopg.cursor import Psycopg2FetchTracedCursor +from ddtrace.contrib.internal.psycopg.cursor import Psycopg2TracedCursor +from ddtrace.contrib.internal.psycopg.cursor import Psycopg3FetchTracedCursor +from ddtrace.contrib.internal.psycopg.cursor import Psycopg3TracedCursor +from ddtrace.contrib.internal.psycopg.extensions import _patch_extensions +from ddtrace.contrib.trace_utils import ext_service +from ddtrace.ext import SpanKind +from ddtrace.ext import SpanTypes +from ddtrace.ext import db +from ddtrace.ext import net +from ddtrace.ext import sql +from ddtrace.internal.constants import COMPONENT + + +class Psycopg3TracedConnection(dbapi.TracedConnection): + def __init__(self, conn, pin=None, cursor_cls=None): + if not cursor_cls: + # Do not trace `fetch*` methods by default + cursor_cls = Psycopg3FetchTracedCursor if config.psycopg.trace_fetch_methods else Psycopg3TracedCursor + + super(Psycopg3TracedConnection, self).__init__(conn, pin, config.psycopg, cursor_cls=cursor_cls) + + def execute(self, *args, **kwargs): + """Execute a query and return a cursor to read its results.""" + + def patched_execute(*args, **kwargs): + try: + cur = self.cursor() + if kwargs.get("binary", None): + cur.format = 1 # set to 1 for binary or 0 if not + return cur.execute(*args, **kwargs) + except Exception as ex: + raise ex.with_traceback(None) + + return patched_execute(*args, **kwargs) + + +class Psycopg2TracedConnection(dbapi.TracedConnection): + """TracedConnection wraps a Connection with tracing code.""" + + def __init__(self, conn, pin=None, cursor_cls=None): + if not cursor_cls: + # Do not trace `fetch*` methods by default + cursor_cls = Psycopg2FetchTracedCursor if config.psycopg.trace_fetch_methods else Psycopg2TracedCursor + + super(Psycopg2TracedConnection, self).__init__(conn, pin, config.psycopg, cursor_cls=cursor_cls) + + +def patch_conn(conn, traced_conn_cls, pin=None): + """Wrap will patch the instance so that its queries are traced.""" + # ensure we've patched extensions (this is idempotent) in + # case we're only tracing some connections. + _config = None + if pin: + extensions_to_patch = pin._config.get("_extensions_to_patch", None) + _config = pin._config + if extensions_to_patch: + _patch_extensions(extensions_to_patch) + + c = traced_conn_cls(conn) + + # if the connection has an info attr, we are using psycopg3 + if hasattr(conn, "dsn"): + dsn = sql.parse_pg_dsn(conn.dsn) + else: + dsn = sql.parse_pg_dsn(conn.info.dsn) + + tags = { + net.TARGET_HOST: dsn.get("host"), + net.TARGET_PORT: dsn.get("port", 5432), + net.SERVER_ADDRESS: dsn.get("host"), + db.NAME: dsn.get("dbname"), + db.USER: dsn.get("user"), + "db.application": dsn.get("application_name"), + db.SYSTEM: "postgresql", + } + Pin(tags=tags, _config=_config).onto(c) + return c + + +def patched_connect_factory(psycopg_module): + def patched_connect(connect_func, _, args, kwargs): + traced_conn_cls = Psycopg3TracedConnection if psycopg_module.__name__ == "psycopg" else Psycopg2TracedConnection + + pin = Pin.get_from(psycopg_module) + + if not pin or not pin.enabled() or not pin._config.trace_connect: + conn = connect_func(*args, **kwargs) + else: + with pin.tracer.trace( + "{}.{}".format(connect_func.__module__, connect_func.__name__), + service=ext_service(pin, pin._config), + span_type=SpanTypes.SQL, + ) as span: + span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) + span.set_tag_str(COMPONENT, pin._config.integration_name) + if span.get_tag(db.SYSTEM) is None: + span.set_tag_str(db.SYSTEM, pin._config.dbms_name) + + span.set_tag(SPAN_MEASURED_KEY) + conn = connect_func(*args, **kwargs) + + return patch_conn(conn, pin=pin, traced_conn_cls=traced_conn_cls) + + return patched_connect diff --git a/ddtrace/contrib/internal/psycopg/cursor.py b/ddtrace/contrib/internal/psycopg/cursor.py new file mode 100644 index 00000000000..6596b558cd3 --- /dev/null +++ b/ddtrace/contrib/internal/psycopg/cursor.py @@ -0,0 +1,28 @@ +from ddtrace.contrib import dbapi + + +class Psycopg3TracedCursor(dbapi.TracedCursor): + """TracedCursor for psycopg instances""" + + def __init__(self, cursor, pin, cfg, *args, **kwargs): + super(Psycopg3TracedCursor, self).__init__(cursor, pin, cfg) + + def _trace_method(self, method, name, resource, extra_tags, dbm_propagator, *args, **kwargs): + # treat Composable resource objects as strings + if resource.__class__.__name__ == "SQL" or resource.__class__.__name__ == "Composed": + resource = resource.as_string(self.__wrapped__) + return super(Psycopg3TracedCursor, self)._trace_method( + method, name, resource, extra_tags, dbm_propagator, *args, **kwargs + ) + + +class Psycopg3FetchTracedCursor(Psycopg3TracedCursor, dbapi.FetchTracedCursor): + """Psycopg3FetchTracedCursor for psycopg""" + + +class Psycopg2TracedCursor(Psycopg3TracedCursor): + """TracedCursor for psycopg2""" + + +class Psycopg2FetchTracedCursor(Psycopg3FetchTracedCursor): + """FetchTracedCursor for psycopg2""" diff --git a/ddtrace/contrib/internal/psycopg/extensions.py b/ddtrace/contrib/internal/psycopg/extensions.py new file mode 100644 index 00000000000..0c8d97cac38 --- /dev/null +++ b/ddtrace/contrib/internal/psycopg/extensions.py @@ -0,0 +1,180 @@ +""" +Tracing utilities for the psycopg2 potgres client library. +""" +import functools + +from ddtrace import config +from ddtrace.constants import SPAN_KIND +from ddtrace.constants import SPAN_MEASURED_KEY +from ddtrace.ext import SpanKind +from ddtrace.ext import SpanTypes +from ddtrace.ext import db +from ddtrace.ext import net +from ddtrace.internal.constants import COMPONENT +from ddtrace.internal.schema import schematize_database_operation +from ddtrace.vendor import wrapt + + +def get_psycopg2_extensions(psycopg_module): + class TracedCursor(psycopg_module.extensions.cursor): + """Wrapper around cursor creating one span per query""" + + def __init__(self, *args, **kwargs): + self._datadog_tracer = kwargs.pop("datadog_tracer", None) + self._datadog_service = kwargs.pop("datadog_service", None) + self._datadog_tags = kwargs.pop("datadog_tags", None) + super(TracedCursor, self).__init__(*args, **kwargs) + + def execute(self, query, vars=None): # noqa: A002 + """just wrap the cursor execution in a span""" + if not self._datadog_tracer: + return psycopg_module.extensions.cursor.execute(self, query, vars) + + with self._datadog_tracer.trace( + schematize_database_operation("postgres.query", database_provider="postgresql"), + service=self._datadog_service, + span_type=SpanTypes.SQL, + ) as s: + s.set_tag_str(COMPONENT, config.psycopg.integration_name) + s.set_tag_str(db.SYSTEM, config.psycopg.dbms_name) + + # set span.kind to the type of operation being performed + s.set_tag_str(SPAN_KIND, SpanKind.CLIENT) + + s.set_tag(SPAN_MEASURED_KEY) + if s.context.sampling_priority is None or s.context.sampling_priority <= 0: + return super(TracedCursor, self).execute(query, vars) + + s.resource = query + s.set_tags(self._datadog_tags) + try: + return super(TracedCursor, self).execute(query, vars) + finally: + s.set_metric(db.ROWCOUNT, self.rowcount) + + def callproc(self, procname, vars=None): # noqa: A002 + """just wrap the execution in a span""" + return psycopg_module.extensions.cursor.callproc(self, procname, vars) + + class TracedConnection(psycopg_module.extensions.connection): + """Wrapper around psycopg2 for tracing""" + + def __init__(self, *args, **kwargs): + self._datadog_tracer = kwargs.pop("datadog_tracer", None) + self._datadog_service = kwargs.pop("datadog_service", None) + + super(TracedConnection, self).__init__(*args, **kwargs) + + # add metadata (from the connection, string, etc) + dsn = psycopg_module.extensions.parse_dsn(self.dsn) + self._datadog_tags = { + net.TARGET_HOST: dsn.get("host"), + net.TARGET_PORT: dsn.get("port"), + net.SERVER_ADDRESS: dsn.get("host"), + db.NAME: dsn.get("dbname"), + db.USER: dsn.get("user"), + db.SYSTEM: config.psycopg.dbms_name, + "db.application": dsn.get("application_name"), + } + + self._datadog_cursor_class = functools.partial( + TracedCursor, + datadog_tracer=self._datadog_tracer, + datadog_service=self._datadog_service, + datadog_tags=self._datadog_tags, + ) + + def cursor(self, *args, **kwargs): + """register our custom cursor factory""" + kwargs.setdefault("cursor_factory", self._datadog_cursor_class) + return super(TracedConnection, self).cursor(*args, **kwargs) + + # extension hooks + _extensions = [ + ( + psycopg_module.extensions.register_type, + psycopg_module.extensions, + "register_type", + _extensions_register_type, + ), + (psycopg_module._psycopg.register_type, psycopg_module._psycopg, "register_type", _extensions_register_type), + (psycopg_module.extensions.adapt, psycopg_module.extensions, "adapt", _extensions_adapt), + ] + + # `_json` attribute is only available for psycopg >= 2.5 + if getattr(psycopg_module, "_json", None): + _extensions += [ + (psycopg_module._json.register_type, psycopg_module._json, "register_type", _extensions_register_type), + ] + + # `quote_ident` attribute is only available for psycopg >= 2.7 + if getattr(psycopg_module, "extensions", None) and getattr(psycopg_module.extensions, "quote_ident", None): + _extensions += [ + (psycopg_module.extensions.quote_ident, psycopg_module.extensions, "quote_ident", _extensions_quote_ident), + ] + + return _extensions + + +def _extensions_register_type(func, _, args, kwargs): + def _unroll_args(obj, scope=None): + return obj, scope + + obj, scope = _unroll_args(*args, **kwargs) + + # register_type performs a c-level check of the object + # type so we must be sure to pass in the actual db connection + if scope and isinstance(scope, wrapt.ObjectProxy): + scope = scope.__wrapped__ + + return func(obj, scope) if scope else func(obj) + + +def _extensions_quote_ident(func, _, args, kwargs): + def _unroll_args(obj, scope=None): + return obj, scope + + obj, scope = _unroll_args(*args, **kwargs) + + # register_type performs a c-level check of the object + # type so we must be sure to pass in the actual db connection + if scope and isinstance(scope, wrapt.ObjectProxy): + scope = scope.__wrapped__ + + return func(obj, scope) if scope else func(obj) + + +def _extensions_adapt(func, _, args, kwargs): + adapt = func(*args, **kwargs) + if hasattr(adapt, "prepare"): + return AdapterWrapper(adapt) + return adapt + + +class AdapterWrapper(wrapt.ObjectProxy): + def prepare(self, *args, **kwargs): + func = self.__wrapped__.prepare + if not args: + return func(*args, **kwargs) + conn = args[0] + + # prepare performs a c-level check of the object type so + # we must be sure to pass in the actual db connection + if isinstance(conn, wrapt.ObjectProxy): + conn = conn.__wrapped__ + + return func(conn, *args[1:], **kwargs) + + +def _patch_extensions(_extensions): + # we must patch extensions all the time (it's pretty harmless) so split + # from global patching of connections. must be idempotent. + for _, module, func, wrapper in _extensions: + if not hasattr(module, func) or isinstance(getattr(module, func), wrapt.ObjectProxy): + continue + wrapt.wrap_function_wrapper(module, func, wrapper) + + +def _unpatch_extensions(_extensions): + for original, module, func, _ in _extensions: + setattr(module, func, original) diff --git a/ddtrace/contrib/internal/psycopg/patch.py b/ddtrace/contrib/internal/psycopg/patch.py new file mode 100644 index 00000000000..22fd580e477 --- /dev/null +++ b/ddtrace/contrib/internal/psycopg/patch.py @@ -0,0 +1,213 @@ +from importlib import import_module +import inspect +import os +from typing import List # noqa:F401 + +from ddtrace import Pin +from ddtrace import config +from ddtrace.contrib import dbapi + + +try: + from ddtrace.contrib.internal.psycopg.async_connection import patched_connect_async_factory + from ddtrace.contrib.internal.psycopg.async_cursor import Psycopg3FetchTracedAsyncCursor + from ddtrace.contrib.internal.psycopg.async_cursor import Psycopg3TracedAsyncCursor +# catch async function syntax errors when using Python<3.7 with no async support +except SyntaxError: + pass +from ddtrace.contrib.internal.psycopg.connection import patched_connect_factory +from ddtrace.contrib.internal.psycopg.cursor import Psycopg3FetchTracedCursor +from ddtrace.contrib.internal.psycopg.cursor import Psycopg3TracedCursor +from ddtrace.contrib.internal.psycopg.extensions import _patch_extensions +from ddtrace.contrib.internal.psycopg.extensions import _unpatch_extensions +from ddtrace.contrib.internal.psycopg.extensions import get_psycopg2_extensions +from ddtrace.internal.schema import schematize_database_operation +from ddtrace.internal.schema import schematize_service_name +from ddtrace.internal.utils.formats import asbool +from ddtrace.internal.utils.wrappers import unwrap as _u +from ddtrace.propagation._database_monitoring import _DBM_Propagator +from ddtrace.propagation._database_monitoring import default_sql_injector as _default_sql_injector +from ddtrace.vendor.wrapt import wrap_function_wrapper as _w + + +try: + psycopg_import = import_module("psycopg") + + # must get the original connect class method from the class __dict__ to use later in unpatch + # Python 3.11 and wrapt result in the class method being rebinded as an instance method when + # using unwrap + _original_connect = psycopg_import.Connection.__dict__["connect"] + _original_async_connect = psycopg_import.AsyncConnection.__dict__["connect"] +# AttributeError can happen due to circular imports under certain integration methods +except (ImportError, AttributeError): + pass + + +def _psycopg_sql_injector(dbm_comment, sql_statement): + for psycopg_module in config.psycopg["_patched_modules"]: + if ( + hasattr(psycopg_module, "sql") + and hasattr(psycopg_module.sql, "Composable") + and isinstance(sql_statement, psycopg_module.sql.Composable) + ): + return psycopg_module.sql.SQL(dbm_comment) + sql_statement + return _default_sql_injector(dbm_comment, sql_statement) + + +config._add( + "psycopg", + dict( + _default_service=schematize_service_name("postgres"), + _dbapi_span_name_prefix="postgres", + _dbapi_span_operation_name=schematize_database_operation("postgres.query", database_provider="postgresql"), + _patched_modules=set(), + trace_fetch_methods=asbool( + os.getenv("DD_PSYCOPG_TRACE_FETCH_METHODS", default=False) + or os.getenv("DD_PSYCOPG2_TRACE_FETCH_METHODS", default=False) + ), + trace_connect=asbool( + os.getenv("DD_PSYCOPG_TRACE_CONNECT", default=False) + or os.getenv("DD_PSYCOPG2_TRACE_CONNECT", default=False) + ), + _dbm_propagator=_DBM_Propagator(0, "query", _psycopg_sql_injector), + dbms_name="postgresql", + ), +) + + +def get_version(): + # type: () -> str + return "" + + +PATCHED_VERSIONS = {} + + +def get_versions(): + # type: () -> List[str] + return PATCHED_VERSIONS + + +def _psycopg_modules(): + module_names = ( + "psycopg", + "psycopg2", + ) + for module_name in module_names: + try: + module = import_module(module_name) + PATCHED_VERSIONS[module_name] = getattr(module, "__version__", "") + yield module + except ImportError: + pass + + +def patch(): + for psycopg_module in _psycopg_modules(): + _patch(psycopg_module) + + +def _patch(psycopg_module): + """Patch monkey patches psycopg's connection function + so that the connection's functions are traced. + """ + if getattr(psycopg_module, "_datadog_patch", False): + return + psycopg_module._datadog_patch = True + + Pin(_config=config.psycopg).onto(psycopg_module) + + if psycopg_module.__name__ == "psycopg2": + # patch all psycopg2 extensions + _psycopg2_extensions = get_psycopg2_extensions(psycopg_module) + config.psycopg["_extensions_to_patch"] = _psycopg2_extensions + _patch_extensions(_psycopg2_extensions) + + _w(psycopg_module, "connect", patched_connect_factory(psycopg_module)) + + config.psycopg["_patched_modules"].add(psycopg_module) + else: + _w(psycopg_module, "connect", patched_connect_factory(psycopg_module)) + _w(psycopg_module, "Cursor", init_cursor_from_connection_factory(psycopg_module)) + _w(psycopg_module, "AsyncCursor", init_cursor_from_connection_factory(psycopg_module)) + + _w(psycopg_module.Connection, "connect", patched_connect_factory(psycopg_module)) + _w(psycopg_module.AsyncConnection, "connect", patched_connect_async_factory(psycopg_module)) + + config.psycopg["_patched_modules"].add(psycopg_module) + + +def unpatch(): + for psycopg_module in _psycopg_modules(): + _unpatch(psycopg_module) + + +def _unpatch(psycopg_module): + if getattr(psycopg_module, "_datadog_patch", False): + psycopg_module._datadog_patch = False + + if psycopg_module.__name__ == "psycopg2": + _u(psycopg_module, "connect") + + _psycopg2_extensions = get_psycopg2_extensions(psycopg_module) + _unpatch_extensions(_psycopg2_extensions) + else: + _u(psycopg_module, "connect") + _u(psycopg_module, "Cursor") + _u(psycopg_module, "AsyncCursor") + + # _u throws an attribute error for Python 3.11, no __get__ on the BoundFunctionWrapper + # unlike Python Class Methods which implement __get__ + psycopg_module.Connection.connect = _original_connect + psycopg_module.AsyncConnection.connect = _original_async_connect + + pin = Pin.get_from(psycopg_module) + if pin: + pin.remove_from(psycopg_module) + + +def init_cursor_from_connection_factory(psycopg_module): + def init_cursor_from_connection(wrapped_cursor_cls, _, args, kwargs): + connection = kwargs.pop("connection", None) + if not connection: + args = list(args) + index = next((i for i, x in enumerate(args) if isinstance(x, dbapi.TracedConnection)), None) + if index is not None: + connection = args.pop(index) + + # if we do not have an example of a traced connection, call the original cursor function + if not connection: + return wrapped_cursor_cls(*args, **kwargs) + + pin = Pin.get_from(connection).clone() + cfg = config.psycopg + + if cfg and cfg.trace_fetch_methods: + trace_fetch_methods = True + else: + trace_fetch_methods = False + + if issubclass(wrapped_cursor_cls, psycopg_module.AsyncCursor): + traced_cursor_cls = Psycopg3FetchTracedAsyncCursor if trace_fetch_methods else Psycopg3TracedAsyncCursor + else: + traced_cursor_cls = Psycopg3FetchTracedCursor if trace_fetch_methods else Psycopg3TracedCursor + + args_mapping = inspect.signature(wrapped_cursor_cls.__init__).parameters + # inspect.signature returns ordered dict[argument_name: str, parameter_type: type] + if "row_factory" in args_mapping and "row_factory" not in kwargs: + # check for row_factory in args by checking for functions + row_factory = None + for i in range(len(args)): + if callable(args[i]): + row_factory = args.pop(i) + break + # else just use the connection row factory + if row_factory is None: + row_factory = connection.row_factory + cursor = wrapped_cursor_cls(connection=connection, row_factory=row_factory, *args, **kwargs) # noqa: B026 + else: + cursor = wrapped_cursor_cls(connection, *args, **kwargs) + + return traced_cursor_cls(cursor=cursor, pin=pin, cfg=cfg) + + return init_cursor_from_connection diff --git a/ddtrace/contrib/internal/pylibmc/addrs.py b/ddtrace/contrib/internal/pylibmc/addrs.py new file mode 100644 index 00000000000..0f11d2ac44c --- /dev/null +++ b/ddtrace/contrib/internal/pylibmc/addrs.py @@ -0,0 +1,14 @@ +translate_server_specs = None + +try: + # NOTE: we rely on an undocumented method to parse addresses, + # so be a bit defensive and don't assume it exists. + from pylibmc.client import translate_server_specs +except ImportError: + pass + + +def parse_addresses(addrs): + if not translate_server_specs: + return [] + return translate_server_specs(addrs) diff --git a/ddtrace/contrib/internal/pylibmc/client.py b/ddtrace/contrib/internal/pylibmc/client.py new file mode 100644 index 00000000000..742747b062b --- /dev/null +++ b/ddtrace/contrib/internal/pylibmc/client.py @@ -0,0 +1,193 @@ +from contextlib import contextmanager +import random + +import pylibmc + +# project +import ddtrace +from ddtrace import config +from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY +from ddtrace.constants import SPAN_KIND +from ddtrace.constants import SPAN_MEASURED_KEY +from ddtrace.contrib.internal.pylibmc.addrs import parse_addresses +from ddtrace.ext import SpanKind +from ddtrace.ext import SpanTypes +from ddtrace.ext import db +from ddtrace.ext import memcached +from ddtrace.ext import net +from ddtrace.internal.compat import Iterable +from ddtrace.internal.constants import COMPONENT +from ddtrace.internal.logger import get_logger +from ddtrace.internal.schema import schematize_cache_operation +from ddtrace.internal.schema import schematize_service_name +from ddtrace.vendor.wrapt import ObjectProxy + + +# Original Client class +_Client = pylibmc.Client + + +log = get_logger(__name__) + + +class TracedClient(ObjectProxy): + """TracedClient is a proxy for a pylibmc.Client that times it's network operations.""" + + def __init__(self, client=None, service=memcached.SERVICE, tracer=None, *args, **kwargs): + """Create a traced client that wraps the given memcached client.""" + # The client instance/service/tracer attributes are kept for compatibility + # with the old interface: TracedClient(client=pylibmc.Client(['localhost:11211'])) + # TODO(Benjamin): Remove these in favor of patching. + if not isinstance(client, _Client): + # We are in the patched situation, just pass down all arguments to the pylibmc.Client + # Note that, in that case, client isn't a real client (just the first argument) + client = _Client(client, *args, **kwargs) + else: + log.warning( + "TracedClient instantiation is deprecated and will be remove " + "in future versions (0.6.0). Use patching instead (see the docs)." + ) + + super(TracedClient, self).__init__(client) + + schematized_service = schematize_service_name(service) + pin = ddtrace.Pin(service=schematized_service, tracer=tracer) + pin.onto(self) + + # attempt to collect the pool of urls this client talks to + try: + self._addresses = parse_addresses(client.addresses) + except Exception: + log.debug("error setting addresses", exc_info=True) + + def clone(self, *args, **kwargs): + # rewrap new connections. + cloned = self.__wrapped__.clone(*args, **kwargs) + traced_client = TracedClient(cloned) + pin = ddtrace.Pin.get_from(self) + if pin: + pin.clone().onto(traced_client) + return traced_client + + def add(self, *args, **kwargs): + return self._trace_cmd("add", *args, **kwargs) + + def get(self, *args, **kwargs): + return self._trace_cmd("get", *args, **kwargs) + + def set(self, *args, **kwargs): + return self._trace_cmd("set", *args, **kwargs) + + def delete(self, *args, **kwargs): + return self._trace_cmd("delete", *args, **kwargs) + + def gets(self, *args, **kwargs): + return self._trace_cmd("gets", *args, **kwargs) + + def touch(self, *args, **kwargs): + return self._trace_cmd("touch", *args, **kwargs) + + def cas(self, *args, **kwargs): + return self._trace_cmd("cas", *args, **kwargs) + + def incr(self, *args, **kwargs): + return self._trace_cmd("incr", *args, **kwargs) + + def decr(self, *args, **kwargs): + return self._trace_cmd("decr", *args, **kwargs) + + def append(self, *args, **kwargs): + return self._trace_cmd("append", *args, **kwargs) + + def prepend(self, *args, **kwargs): + return self._trace_cmd("prepend", *args, **kwargs) + + def get_multi(self, *args, **kwargs): + return self._trace_multi_cmd("get_multi", *args, **kwargs) + + def set_multi(self, *args, **kwargs): + return self._trace_multi_cmd("set_multi", *args, **kwargs) + + def delete_multi(self, *args, **kwargs): + return self._trace_multi_cmd("delete_multi", *args, **kwargs) + + def _trace_cmd(self, method_name, *args, **kwargs): + """trace the execution of the method with the given name and will + patch the first arg. + """ + method = getattr(self.__wrapped__, method_name) + with self._span(method_name) as span: + result = method(*args, **kwargs) + if span is None: + return result + + if args: + span.set_tag_str(memcached.QUERY, "%s %s" % (method_name, args[0])) + if method_name == "get": + span.set_metric(db.ROWCOUNT, 1 if result else 0) + elif method_name == "gets": + # returns a tuple object that may be (None, None) + span.set_metric(db.ROWCOUNT, 1 if isinstance(result, Iterable) and len(result) > 0 and result[0] else 0) + return result + + def _trace_multi_cmd(self, method_name, *args, **kwargs): + """trace the execution of the multi command with the given name.""" + method = getattr(self.__wrapped__, method_name) + with self._span(method_name) as span: + result = method(*args, **kwargs) + if span is None: + return result + + pre = kwargs.get("key_prefix") + if pre: + span.set_tag_str(memcached.QUERY, "%s %s" % (method_name, pre)) + + if method_name == "get_multi": + # returns mapping of key -> value if key exists, but does not include a missing key. Empty result = {} + span.set_metric( + db.ROWCOUNT, sum(1 for doc in result if doc) if result and isinstance(result, Iterable) else 0 + ) + return result + + @contextmanager + def _no_span(self): + yield None + + def _span(self, cmd_name): + """Return a span timing the given command.""" + pin = ddtrace.Pin.get_from(self) + if not pin or not pin.enabled(): + return self._no_span() + + span = pin.tracer.trace( + schematize_cache_operation("memcached.cmd", cache_provider="memcached"), + service=pin.service, + resource=cmd_name, + span_type=SpanTypes.CACHE, + ) + + span.set_tag_str(COMPONENT, config.pylibmc.integration_name) + span.set_tag_str(db.SYSTEM, memcached.DBMS_NAME) + + # set span.kind to the type of operation being performed + span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) + + span.set_tag(SPAN_MEASURED_KEY) + + try: + self._tag_span(span) + except Exception: + log.debug("error tagging span", exc_info=True) + return span + + def _tag_span(self, span): + # FIXME[matt] the host selection is buried in c code. we can't tell what it's actually + # using, so fallback to randomly choosing one. can we do better? + if self._addresses: + _, host, port, _ = random.choice(self._addresses) # nosec + span.set_tag_str(net.TARGET_HOST, host) + span.set_tag(net.TARGET_PORT, port) + span.set_tag_str(net.SERVER_ADDRESS, host) + + # set analytics sample rate + span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.pylibmc.get_analytics_sample_rate()) diff --git a/ddtrace/contrib/internal/pylibmc/patch.py b/ddtrace/contrib/internal/pylibmc/patch.py new file mode 100644 index 00000000000..0d5075a01b5 --- /dev/null +++ b/ddtrace/contrib/internal/pylibmc/patch.py @@ -0,0 +1,26 @@ +import pylibmc + +from .client import TracedClient + + +# Original Client class +_Client = pylibmc.Client + + +def get_version(): + # type: () -> str + return getattr(pylibmc, "__version__", "") + + +def patch(): + if getattr(pylibmc, "_datadog_patch", False): + return + + pylibmc._datadog_patch = True + pylibmc.Client = TracedClient + + +def unpatch(): + if getattr(pylibmc, "_datadog_patch", False): + pylibmc._datadog_patch = False + pylibmc.Client = _Client diff --git a/ddtrace/contrib/internal/pymemcache/client.py b/ddtrace/contrib/internal/pymemcache/client.py new file mode 100644 index 00000000000..937f336106d --- /dev/null +++ b/ddtrace/contrib/internal/pymemcache/client.py @@ -0,0 +1,362 @@ +import os +import sys +from typing import Iterable + +import pymemcache +from pymemcache.client.base import Client +from pymemcache.client.base import PooledClient +from pymemcache.client.hash import HashClient +from pymemcache.exceptions import MemcacheClientError +from pymemcache.exceptions import MemcacheIllegalInputError +from pymemcache.exceptions import MemcacheServerError +from pymemcache.exceptions import MemcacheUnknownCommandError +from pymemcache.exceptions import MemcacheUnknownError + +# 3p +from ddtrace import config + +# project +from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY +from ddtrace.constants import SPAN_KIND +from ddtrace.constants import SPAN_MEASURED_KEY +from ddtrace.ext import SpanKind +from ddtrace.ext import SpanTypes +from ddtrace.ext import db +from ddtrace.ext import memcached as memcachedx +from ddtrace.ext import net +from ddtrace.internal.constants import COMPONENT +from ddtrace.internal.logger import get_logger +from ddtrace.internal.schema import schematize_cache_operation +from ddtrace.internal.utils.formats import asbool +from ddtrace.pin import Pin +from ddtrace.vendor import wrapt + + +log = get_logger(__name__) + + +config._add( + "pymemcache", + { + "command_enabled": asbool(os.getenv("DD_TRACE_MEMCACHED_COMMAND_ENABLED", default=False)), + }, +) + + +# keep a reference to the original unpatched clients +_Client = Client +_HashClient = HashClient + + +class _WrapperBase(wrapt.ObjectProxy): + def __init__(self, wrapped_class, *args, **kwargs): + c = wrapped_class(*args, **kwargs) + super(_WrapperBase, self).__init__(c) + + # tags to apply to each span generated by this client + tags = _get_address_tags(*args, **kwargs) + + parent_pin = Pin.get_from(pymemcache) + + if parent_pin: + pin = parent_pin.clone(tags=tags) + else: + pin = Pin(tags=tags) + + # attach the pin onto this instance + pin.onto(self) + + def _trace_function_as_command(self, func, cmd, *args, **kwargs): + p = Pin.get_from(self) + + if not p or not p.enabled(): + return func(*args, **kwargs) + + return _trace(func, p, cmd, *args, **kwargs) + + +class WrappedClient(_WrapperBase): + """Wrapper providing patched methods of a pymemcache Client. + + Relevant connection information is obtained during initialization and + attached to each span. + + Keys are tagged in spans for methods that act upon a key. + """ + + def __init__(self, *args, **kwargs): + super(WrappedClient, self).__init__(_Client, *args, **kwargs) + + def set(self, *args, **kwargs): + return self._traced_cmd("set", *args, **kwargs) + + def set_many(self, *args, **kwargs): + return self._traced_cmd("set_many", *args, **kwargs) + + def add(self, *args, **kwargs): + return self._traced_cmd("add", *args, **kwargs) + + def replace(self, *args, **kwargs): + return self._traced_cmd("replace", *args, **kwargs) + + def append(self, *args, **kwargs): + return self._traced_cmd("append", *args, **kwargs) + + def prepend(self, *args, **kwargs): + return self._traced_cmd("prepend", *args, **kwargs) + + def cas(self, *args, **kwargs): + return self._traced_cmd("cas", *args, **kwargs) + + def get(self, *args, **kwargs): + return self._traced_cmd("get", *args, **kwargs) + + def get_many(self, *args, **kwargs): + return self._traced_cmd("get_many", *args, **kwargs) + + def gets(self, *args, **kwargs): + return self._traced_cmd("gets", *args, **kwargs) + + def gets_many(self, *args, **kwargs): + return self._traced_cmd("gets_many", *args, **kwargs) + + def delete(self, *args, **kwargs): + return self._traced_cmd("delete", *args, **kwargs) + + def delete_many(self, *args, **kwargs): + return self._traced_cmd("delete_many", *args, **kwargs) + + def incr(self, *args, **kwargs): + return self._traced_cmd("incr", *args, **kwargs) + + def decr(self, *args, **kwargs): + return self._traced_cmd("decr", *args, **kwargs) + + def touch(self, *args, **kwargs): + return self._traced_cmd("touch", *args, **kwargs) + + def stats(self, *args, **kwargs): + return self._traced_cmd("stats", *args, **kwargs) + + def version(self, *args, **kwargs): + return self._traced_cmd("version", *args, **kwargs) + + def flush_all(self, *args, **kwargs): + return self._traced_cmd("flush_all", *args, **kwargs) + + def quit(self, *args, **kwargs): + return self._traced_cmd("quit", *args, **kwargs) + + def set_multi(self, *args, **kwargs): + """set_multi is an alias for set_many""" + return self._traced_cmd("set_many", *args, **kwargs) + + def get_multi(self, *args, **kwargs): + """set_multi is an alias for set_many""" + return self._traced_cmd("get_many", *args, **kwargs) + + def _traced_cmd(self, command, *args, **kwargs): + return self._trace_function_as_command( + lambda *_args, **_kwargs: getattr(self.__wrapped__, command)(*_args, **_kwargs), command, *args, **kwargs + ) + + +class WrappedHashClient(_WrapperBase): + """Wrapper that traces HashClient commands + + This wrapper proxies its command invocations to the underlying HashClient instance. + When the use_pooling setting is in use, this wrapper starts a span before + doing the proxy call. + + This is necessary because the use_pooling setting causes Client instances to be + created and destroyed dynamically in a manner that isn't affected by the + patch() function. + """ + + def _ensure_traced(self, cmd, key, default_val, *args, **kwargs): + """ + PooledClient creates Client instances dynamically on request, which means + those Client instances aren't affected by the wrappers applied in patch(). + We handle this case here by calling trace() before running the command, + specifically when the client that will be used for the command is a + PooledClient. + + To avoid double-tracing when the key's client is not a PooledClient, we + don't create a span and instead rely on patch(). In this case the + underlying Client instance is long-lived and has been patched already. + """ + client_for_key = self._get_client(key) + if isinstance(client_for_key, PooledClient): + return self._traced_cmd(cmd, client_for_key, key, default_val, *args, **kwargs) + else: + return getattr(self.__wrapped__, cmd)(key, *args, **kwargs) + + def __init__(self, *args, **kwargs): + super(WrappedHashClient, self).__init__(_HashClient, *args, **kwargs) + + def set(self, key, *args, **kwargs): + return self._ensure_traced("set", key, False, *args, **kwargs) + + def add(self, key, *args, **kwargs): + return self._ensure_traced("add", key, False, *args, **kwargs) + + def replace(self, key, *args, **kwargs): + return self._ensure_traced("replace", key, False, *args, **kwargs) + + def append(self, key, *args, **kwargs): + return self._ensure_traced("append", key, False, *args, **kwargs) + + def prepend(self, key, *args, **kwargs): + return self._ensure_traced("prepend", key, False, *args, **kwargs) + + def cas(self, key, *args, **kwargs): + return self._ensure_traced("cas", key, False, *args, **kwargs) + + def get(self, key, *args, **kwargs): + return self._ensure_traced("get", key, None, *args, **kwargs) + + def gets(self, key, *args, **kwargs): + return self._ensure_traced("gets", key, None, *args, **kwargs) + + def delete(self, key, *args, **kwargs): + return self._ensure_traced("delete", key, False, *args, **kwargs) + + def incr(self, key, *args, **kwargs): + return self._ensure_traced("incr", key, False, *args, **kwargs) + + def decr(self, key, *args, **kwargs): + return self._ensure_traced("decr", key, False, *args, **kwargs) + + def touch(self, key, *args, **kwargs): + return self._ensure_traced("touch", key, False, *args, **kwargs) + + def _traced_cmd(self, command, client, key, default_val, *args, **kwargs): + # NB this function mimics the logic of HashClient._run_cmd, tracing the call to _safely_run_func + if client is None: + return default_val + + args = list(args) + args.insert(0, key) + + return self._trace_function_as_command( + lambda *_args, **_kwargs: self._safely_run_func( + client, getattr(client, command), default_val, *_args, **_kwargs + ), + command, + *args, + **kwargs, + ) + + +_HashClient.client_class = WrappedClient + + +def _get_address_tags(*args, **kwargs): + """Attempt to get host and port from args passed to Client initializer.""" + tags = {} + try: + if len(args): + host, port = args[0] + tags[net.TARGET_HOST] = host + tags[net.TARGET_PORT] = port + tags[net.SERVER_ADDRESS] = host + except Exception: + log.debug("Error collecting client address tags") + + return tags + + +def _get_query_string(args): + """Return the query values given the arguments to a pymemcache command. + + If there are multiple query values, they are joined together + space-separated. + """ + keys = "" + + # shortcut if no args + if not args: + return keys + + # pull out the first arg which will contain any key + arg = args[0] + + # if we get a dict, convert to list of keys + if type(arg) is dict: + arg = list(arg) + + if type(arg) is str: + keys = arg + elif type(arg) is bytes: + keys = arg.decode() + elif type(arg) is list and len(arg): + if type(arg[0]) is str: + keys = " ".join(arg) + elif type(arg[0]) is bytes: + keys = b" ".join(arg).decode() + + return keys + + +def _trace(func, p, method_name, *args, **kwargs): + """Run and trace the given command. + + Any pymemcache exception is caught and span error information is + set. The exception is then reraised for the application to handle + appropriately. + + Relevant tags are set in the span. + """ + with p.tracer.trace( + schematize_cache_operation(memcachedx.CMD, cache_provider="memcached"), + service=p.service, + resource=method_name, + span_type=SpanTypes.CACHE, + ) as span: + span.set_tag_str(COMPONENT, config.pymemcache.integration_name) + span.set_tag_str(db.SYSTEM, memcachedx.DBMS_NAME) + + # set span.kind to the type of operation being performed + span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) + + span.set_tag(SPAN_MEASURED_KEY) + # set analytics sample rate + span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.pymemcache.get_analytics_sample_rate()) + + # try to set relevant tags, catch any exceptions so we don't mess + # with the application + try: + span.set_tags(p.tags) + if config.pymemcache.command_enabled: + vals = _get_query_string(args) + query = "{}{}{}".format(method_name, " " if vals else "", vals) + span.set_tag_str(memcachedx.QUERY, query) + except Exception: + log.debug("Error setting relevant pymemcache tags") + + try: + result = func(*args, **kwargs) + + if method_name == "get_many" or method_name == "gets_many": + # gets_many returns a map of key -> (value, cas), else an empty dict if no matches + # get many returns a map with values, else an empty map if no matches + span.set_metric( + db.ROWCOUNT, sum(1 for doc in result if doc) if result and isinstance(result, Iterable) else 0 + ) + elif method_name == "get": + # get returns key or None + span.set_metric(db.ROWCOUNT, 1 if result else 0) + elif method_name == "gets": + # gets returns a tuple of (None, None) if key not found, else tuple of (key, index) + span.set_metric(db.ROWCOUNT, 1 if result[0] else 0) + return result + except ( + MemcacheClientError, + MemcacheServerError, + MemcacheUnknownCommandError, + MemcacheUnknownError, + MemcacheIllegalInputError, + ): + (typ, val, tb) = sys.exc_info() + span.set_exc_info(typ, val, tb) + raise diff --git a/ddtrace/contrib/internal/pymemcache/patch.py b/ddtrace/contrib/internal/pymemcache/patch.py new file mode 100644 index 00000000000..07402680e9e --- /dev/null +++ b/ddtrace/contrib/internal/pymemcache/patch.py @@ -0,0 +1,49 @@ +import pymemcache +import pymemcache.client.hash + +from ddtrace.ext import memcached as memcachedx +from ddtrace.internal.schema import schematize_service_name +from ddtrace.pin import _DD_PIN_NAME +from ddtrace.pin import _DD_PIN_PROXY_NAME +from ddtrace.pin import Pin + +from .client import WrappedClient +from .client import WrappedHashClient + + +_Client = pymemcache.client.base.Client +_hash_Client = pymemcache.client.hash.Client +_hash_HashClient = pymemcache.client.hash.Client + + +def get_version(): + # type: () -> str + return getattr(pymemcache, "__version__", "") + + +def patch(): + if getattr(pymemcache, "_datadog_patch", False): + return + + pymemcache._datadog_patch = True + pymemcache.client.base.Client = WrappedClient + pymemcache.client.hash.Client = WrappedClient + pymemcache.client.hash.HashClient = WrappedHashClient + + # Create a global pin with default configuration for our pymemcache clients + service = schematize_service_name(memcachedx.SERVICE) + Pin(service=service).onto(pymemcache) + + +def unpatch(): + """Remove pymemcache tracing""" + if not getattr(pymemcache, "_datadog_patch", False): + return + pymemcache._datadog_patch = False + pymemcache.client.base.Client = _Client + pymemcache.client.hash.Client = _hash_Client + pymemcache.client.hash.HashClient = _hash_HashClient + + # Remove any pins that may exist on the pymemcache reference + setattr(pymemcache, _DD_PIN_NAME, None) + setattr(pymemcache, _DD_PIN_PROXY_NAME, None) diff --git a/ddtrace/contrib/internal/pymongo/client.py b/ddtrace/contrib/internal/pymongo/client.py new file mode 100644 index 00000000000..21ee754ac73 --- /dev/null +++ b/ddtrace/contrib/internal/pymongo/client.py @@ -0,0 +1,372 @@ +# stdlib +import contextlib +import json +from typing import Iterable + +# 3p +import pymongo + +# project +import ddtrace +from ddtrace import config +from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY +from ddtrace.constants import SPAN_KIND +from ddtrace.constants import SPAN_MEASURED_KEY +from ddtrace.ext import SpanKind +from ddtrace.ext import SpanTypes +from ddtrace.ext import db +from ddtrace.ext import mongo as mongox +from ddtrace.ext import net as netx +from ddtrace.internal.constants import COMPONENT +from ddtrace.internal.logger import get_logger +from ddtrace.internal.schema import schematize_database_operation +from ddtrace.internal.schema import schematize_service_name +from ddtrace.internal.utils import get_argument_value +from ddtrace.vendor.wrapt import ObjectProxy + +from .parse import parse_msg +from .parse import parse_query +from .parse import parse_spec + + +BATCH_PARTIAL_KEY = "Batch" + +# Original Client class +_MongoClient = pymongo.MongoClient + +VERSION = pymongo.version_tuple + +if VERSION < (3, 6, 0): + from pymongo.helpers import _unpack_response + + +log = get_logger(__name__) + +_DEFAULT_SERVICE = schematize_service_name("pymongo") + + +class TracedMongoClient(ObjectProxy): + def __init__(self, client=None, *args, **kwargs): + # To support the former trace_mongo_client interface, we have to keep this old interface + # TODO(Benjamin): drop it in a later version + if not isinstance(client, _MongoClient): + # Patched interface, instantiate the client + + # client is just the first arg which could be the host if it is + # None, then it could be that the caller: + + # if client is None then __init__ was: + # 1) invoked with host=None + # 2) not given a first argument (client defaults to None) + # we cannot tell which case it is, but it should not matter since + # the default value for host is None, in either case we can simply + # not provide it as an argument + if client is None: + client = _MongoClient(*args, **kwargs) + # else client is a value for host so just pass it along + else: + client = _MongoClient(client, *args, **kwargs) + + super(TracedMongoClient, self).__init__(client) + client._datadog_proxy = self + # NOTE[matt] the TracedMongoClient attempts to trace all of the network + # calls in the trace library. This is good because it measures the + # actual network time. It's bad because it uses a private API which + # could change. We'll see how this goes. + if not isinstance(client._topology, TracedTopology): + client._topology = TracedTopology(client._topology) + + # Default Pin + ddtrace.Pin(service=_DEFAULT_SERVICE).onto(self) + + def __setddpin__(self, pin): + pin.onto(self._topology) + + def __getddpin__(self): + return ddtrace.Pin.get_from(self._topology) + + +@contextlib.contextmanager +def wrapped_validate_session(wrapped, instance, args, kwargs): + # We do this to handle a validation `A is B` in pymongo that + # relies on IDs being equal. Since we are proxying objects, we need + # to ensure we're compare proxy with proxy or wrapped with wrapped + # or this validation will fail + client = args[0] + session = args[1] + session_client = session._client + if isinstance(session_client, TracedMongoClient): + if isinstance(client, _MongoClient): + client = getattr(client, "_datadog_proxy", client) + elif isinstance(session_client, _MongoClient): + if isinstance(client, TracedMongoClient): + client = client.__wrapped__ + + yield wrapped(client, session) + + +class TracedTopology(ObjectProxy): + def __init__(self, topology): + super(TracedTopology, self).__init__(topology) + + def select_server(self, *args, **kwargs): + s = self.__wrapped__.select_server(*args, **kwargs) + if not isinstance(s, TracedServer): + s = TracedServer(s) + # Reattach the pin every time in case it changed since the initial patching + ddtrace.Pin.get_from(self).onto(s) + return s + + +class TracedServer(ObjectProxy): + def __init__(self, server): + super(TracedServer, self).__init__(server) + + def _datadog_trace_operation(self, operation): + cmd = None + # Only try to parse something we think is a query. + if self._is_query(operation): + try: + cmd = parse_query(operation) + except Exception: + log.exception("error parsing query") + + pin = ddtrace.Pin.get_from(self) + # if we couldn't parse or shouldn't trace the message, just go. + if not cmd or not pin or not pin.enabled(): + return None + + span = pin.tracer.trace( + schematize_database_operation("pymongo.cmd", database_provider="mongodb"), + span_type=SpanTypes.MONGODB, + service=pin.service, + ) + + span.set_tag_str(COMPONENT, config.pymongo.integration_name) + + # set span.kind to the operation type being performed + span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) + + span.set_tag(SPAN_MEASURED_KEY) + span.set_tag_str(mongox.DB, cmd.db) + span.set_tag_str(mongox.COLLECTION, cmd.coll) + span.set_tag_str(db.SYSTEM, mongox.SERVICE) + span.set_tags(cmd.tags) + + # set `mongodb.query` tag and resource for span + _set_query_metadata(span, cmd) + + # set analytics sample rate + sample_rate = config.pymongo.get_analytics_sample_rate() + if sample_rate is not None: + span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, sample_rate) + return span + + if VERSION >= (4, 5, 0): + + @contextlib.contextmanager + def checkout(self, *args, **kwargs): + with self.__wrapped__.checkout(*args, **kwargs) as s: + if not isinstance(s, TracedSocket): + s = TracedSocket(s) + ddtrace.Pin.get_from(self).onto(s) + yield s + + else: + + @contextlib.contextmanager + def get_socket(self, *args, **kwargs): + with self.__wrapped__.get_socket(*args, **kwargs) as s: + if not isinstance(s, TracedSocket): + s = TracedSocket(s) + ddtrace.Pin.get_from(self).onto(s) + yield s + + if VERSION >= (3, 12, 0): + + def run_operation(self, sock_info, operation, *args, **kwargs): + span = self._datadog_trace_operation(operation) + if span is None: + return self.__wrapped__.run_operation(sock_info, operation, *args, **kwargs) + with span: + result = self.__wrapped__.run_operation(sock_info, operation, *args, **kwargs) + if result: + if hasattr(result, "address"): + set_address_tags(span, result.address) + if self._is_query(operation) and hasattr(result, "docs"): + set_query_rowcount(docs=result.docs, span=span) + return result + + elif (3, 9, 0) <= VERSION < (3, 12, 0): + + def run_operation_with_response(self, sock_info, operation, *args, **kwargs): + span = self._datadog_trace_operation(operation) + if span is None: + return self.__wrapped__.run_operation_with_response(sock_info, operation, *args, **kwargs) + with span: + result = self.__wrapped__.run_operation_with_response(sock_info, operation, *args, **kwargs) + if result: + if hasattr(result, "address"): + set_address_tags(span, result.address) + if self._is_query(operation) and hasattr(result, "docs"): + set_query_rowcount(docs=result.docs, span=span) + return result + + else: + + def send_message_with_response(self, operation, *args, **kwargs): + span = self._datadog_trace_operation(operation) + if span is None: + return self.__wrapped__.send_message_with_response(operation, *args, **kwargs) + with span: + result = self.__wrapped__.send_message_with_response(operation, *args, **kwargs) + if result: + if hasattr(result, "address"): + set_address_tags(span, result.address) + if self._is_query(operation): + if hasattr(result, "data"): + if VERSION >= (3, 6, 0) and hasattr(result.data, "unpack_response"): + set_query_rowcount(docs=result.data.unpack_response(), span=span) + else: + data = _unpack_response(response=result.data) + if VERSION < (3, 2, 0) and data.get("number_returned", None): + span.set_metric(db.ROWCOUNT, data.get("number_returned")) + elif (3, 2, 0) <= VERSION < (3, 6, 0): + docs = data.get("data", None) + set_query_rowcount(docs=docs, span=span) + return result + + @staticmethod + def _is_query(op): + # NOTE: _Query should always have a spec field + return hasattr(op, "spec") + + +class TracedSocket(ObjectProxy): + def __init__(self, socket): + super(TracedSocket, self).__init__(socket) + + def command(self, dbname, spec, *args, **kwargs): + cmd = None + try: + cmd = parse_spec(spec, dbname) + except Exception: + log.exception("error parsing spec. skipping trace") + + pin = ddtrace.Pin.get_from(self) + # skip tracing if we don't have a piece of data we need + if not dbname or not cmd or not pin or not pin.enabled(): + return self.__wrapped__.command(dbname, spec, *args, **kwargs) + + cmd.db = dbname + with self.__trace(cmd): + return self.__wrapped__.command(dbname, spec, *args, **kwargs) + + def write_command(self, *args, **kwargs): + msg = get_argument_value(args, kwargs, 1, "msg") + cmd = None + try: + cmd = parse_msg(msg) + except Exception: + log.exception("error parsing msg") + + pin = ddtrace.Pin.get_from(self) + # if we couldn't parse it, don't try to trace it. + if not cmd or not pin or not pin.enabled(): + return self.__wrapped__.write_command(*args, **kwargs) + + with self.__trace(cmd) as s: + result = self.__wrapped__.write_command(*args, **kwargs) + if result: + s.set_metric(db.ROWCOUNT, result.get("n", -1)) + return result + + def __trace(self, cmd): + pin = ddtrace.Pin.get_from(self) + s = pin.tracer.trace( + schematize_database_operation("pymongo.cmd", database_provider="mongodb"), + span_type=SpanTypes.MONGODB, + service=pin.service, + ) + + s.set_tag_str(COMPONENT, config.pymongo.integration_name) + s.set_tag_str(db.SYSTEM, mongox.SERVICE) + + # set span.kind to the type of operation being performed + s.set_tag_str(SPAN_KIND, SpanKind.CLIENT) + + s.set_tag(SPAN_MEASURED_KEY) + if cmd.db: + s.set_tag_str(mongox.DB, cmd.db) + if cmd: + s.set_tag(mongox.COLLECTION, cmd.coll) + s.set_tags(cmd.tags) + s.set_metrics(cmd.metrics) + + # set `mongodb.query` tag and resource for span + _set_query_metadata(s, cmd) + + # set analytics sample rate + s.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.pymongo.get_analytics_sample_rate()) + + if self.address: + set_address_tags(s, self.address) + return s + + +def normalize_filter(f=None): + if f is None: + return {} + elif isinstance(f, list): + # normalize lists of filters + # e.g. {$or: [ { age: { $lt: 30 } }, { type: 1 } ]} + return [normalize_filter(s) for s in f] + elif isinstance(f, dict): + # normalize dicts of filters + # {$or: [ { age: { $lt: 30 } }, { type: 1 } ]}) + out = {} + for k, v in f.items(): + if k == "$in" or k == "$nin": + # special case $in queries so we don't loop over lists. + out[k] = "?" + elif isinstance(v, list) or isinstance(v, dict): + # RECURSION ALERT: needs to move to the agent + out[k] = normalize_filter(v) + else: + # NOTE: this shouldn't happen, but let's have a safeguard. + out[k] = "?" + return out + else: + # FIXME[matt] unexpected type. not sure this should ever happen, but at + # least it won't crash. + return {} + + +def set_address_tags(span, address): + # the address is only set after the cursor is done. + if address: + span.set_tag_str(netx.TARGET_HOST, address[0]) + span.set_tag_str(netx.SERVER_ADDRESS, address[0]) + span.set_tag(netx.TARGET_PORT, address[1]) + + +def _set_query_metadata(span, cmd): + """Sets span `mongodb.query` tag and resource given command query""" + if cmd.query: + nq = normalize_filter(cmd.query) + span.set_tag("mongodb.query", nq) + # needed to dump json so we don't get unicode + # dict keys like {u'foo':'bar'} + q = json.dumps(nq) + span.resource = "{} {} {}".format(cmd.name, cmd.coll, q) + else: + span.resource = "{} {}".format(cmd.name, cmd.coll) + + +def set_query_rowcount(docs, span): + # results returned in batches, get len of each batch + if isinstance(docs, Iterable) and len(docs) > 0: + cursor = docs[0].get("cursor", None) + if cursor: + rowcount = sum([len(documents) for batch_key, documents in cursor.items() if BATCH_PARTIAL_KEY in batch_key]) + span.set_metric(db.ROWCOUNT, rowcount) diff --git a/ddtrace/contrib/internal/pymongo/parse.py b/ddtrace/contrib/internal/pymongo/parse.py new file mode 100644 index 00000000000..f4db22929bd --- /dev/null +++ b/ddtrace/contrib/internal/pymongo/parse.py @@ -0,0 +1,204 @@ +import ctypes +import struct + +# 3p +import bson +from bson.codec_options import CodecOptions +from bson.son import SON + +# project +from ddtrace.ext import net as netx +from ddtrace.internal.compat import to_unicode +from ddtrace.internal.logger import get_logger + + +log = get_logger(__name__) + + +# MongoDB wire protocol commands +# http://docs.mongodb.com/manual/reference/mongodb-wire-protocol +OP_CODES = { + 1: "reply", + 1000: "msg", # DEV: 1000 was deprecated at some point, use 2013 instead + 2001: "update", + 2002: "insert", + 2003: "reserved", + 2004: "query", + 2005: "get_more", + 2006: "delete", + 2007: "kill_cursors", + 2010: "command", + 2011: "command_reply", + 2013: "msg", +} + +# The maximum message length we'll try to parse +MAX_MSG_PARSE_LEN = 1024 * 1024 + +header_struct = struct.Struct("= 3.1 stores the db and coll separately + coll = getattr(query, "coll", None) + db = getattr(query, "db", None) + + # pymongo < 3.1 _Query does not have a name field, so default to 'query' + cmd = Command(getattr(query, "name", "query"), db, coll) + cmd.query = query.spec + return cmd + + +def parse_spec(spec, db=None): + """Return a Command that has parsed the relevant detail for the given + pymongo SON spec. + """ + + # the first element is the command and collection + items = list(spec.items()) + if not items: + return None + name, coll = items[0] + cmd = Command(name, db or spec.get("$db"), coll) + + if "ordered" in spec: # in insert and update + cmd.tags["mongodb.ordered"] = spec["ordered"] + + if cmd.name == "insert": + if "documents" in spec: + cmd.metrics["mongodb.documents"] = len(spec["documents"]) + + elif cmd.name == "update": + updates = spec.get("updates") + if updates: + # FIXME[matt] is there ever more than one here? + cmd.query = updates[0].get("q") + + elif cmd.name == "delete": + dels = spec.get("deletes") + if dels: + # FIXME[matt] is there ever more than one here? + cmd.query = dels[0].get("q") + + return cmd + + +def _cstring(raw): + """Return the first null terminated cstring from the buffer.""" + return ctypes.create_string_buffer(raw).value + + +def _split_namespace(ns): + """Return a tuple of (db, collection) from the 'db.coll' string.""" + if ns: + # NOTE[matt] ns is unicode or bytes depending on the client version + # so force cast to unicode + split = to_unicode(ns).split(".", 1) + if len(split) == 1: + raise Exception("namespace doesn't contain period: %s" % ns) + return split + return (None, None) diff --git a/ddtrace/contrib/internal/pymongo/patch.py b/ddtrace/contrib/internal/pymongo/patch.py new file mode 100644 index 00000000000..13718882347 --- /dev/null +++ b/ddtrace/contrib/internal/pymongo/patch.py @@ -0,0 +1,98 @@ +import contextlib + +import pymongo + +from ddtrace import Pin +from ddtrace import config +from ddtrace.constants import SPAN_KIND +from ddtrace.constants import SPAN_MEASURED_KEY +from ddtrace.contrib import trace_utils +from ddtrace.contrib.trace_utils import unwrap as _u +from ddtrace.ext import SpanKind +from ddtrace.ext import SpanTypes +from ddtrace.ext import db +from ddtrace.ext import mongo +from ddtrace.internal.constants import COMPONENT +from ddtrace.vendor.wrapt import wrap_function_wrapper as _w + +from .client import TracedMongoClient +from .client import set_address_tags +from .client import wrapped_validate_session + + +config._add( + "pymongo", + dict(_default_service="pymongo"), +) + + +def get_version(): + # type: () -> str + return getattr(pymongo, "__version__", "") + + +# Original Client class +_MongoClient = pymongo.MongoClient + +_VERSION = pymongo.version_tuple +_CHECKOUT_FN_NAME = "get_socket" if _VERSION < (4, 5) else "checkout" +_VERIFY_VERSION_CLASS = pymongo.pool.SocketInfo if _VERSION < (4, 5) else pymongo.pool.Connection + + +def patch(): + patch_pymongo_module() + # We should progressively get rid of TracedMongoClient. We now try to + # wrap methods individually. cf #1501 + pymongo.MongoClient = TracedMongoClient + + +def unpatch(): + unpatch_pymongo_module() + pymongo.MongoClient = _MongoClient + + +def patch_pymongo_module(): + if getattr(pymongo, "_datadog_patch", False): + return + pymongo._datadog_patch = True + Pin().onto(pymongo.server.Server) + + # Whenever a pymongo command is invoked, the lib either: + # - Creates a new socket & performs a TCP handshake + # - Grabs a socket already initialized before + _w("pymongo.server", "Server.%s" % _CHECKOUT_FN_NAME, traced_get_socket) + _w("pymongo.pool", f"{_VERIFY_VERSION_CLASS.__name__}.validate_session", wrapped_validate_session) + + +def unpatch_pymongo_module(): + if not getattr(pymongo, "_datadog_patch", False): + return + pymongo._datadog_patch = False + + _u(pymongo.server.Server, _CHECKOUT_FN_NAME) + _u(_VERIFY_VERSION_CLASS, "validate_session") + + +@contextlib.contextmanager +def traced_get_socket(wrapped, instance, args, kwargs): + pin = Pin._find(wrapped, instance) + if not pin or not pin.enabled(): + with wrapped(*args, **kwargs) as sock_info: + yield sock_info + return + + with pin.tracer.trace( + "pymongo.%s" % _CHECKOUT_FN_NAME, + service=trace_utils.int_service(pin, config.pymongo), + span_type=SpanTypes.MONGODB, + ) as span: + span.set_tag_str(COMPONENT, config.pymongo.integration_name) + span.set_tag_str(db.SYSTEM, mongo.SERVICE) + + # set span.kind tag equal to type of operation being performed + span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) + + with wrapped(*args, **kwargs) as sock_info: + set_address_tags(span, sock_info.address) + span.set_tag(SPAN_MEASURED_KEY) + yield sock_info diff --git a/ddtrace/contrib/internal/pymysql/patch.py b/ddtrace/contrib/internal/pymysql/patch.py new file mode 100644 index 00000000000..1068368831d --- /dev/null +++ b/ddtrace/contrib/internal/pymysql/patch.py @@ -0,0 +1,68 @@ +import os + +import pymysql + +from ddtrace import Pin +from ddtrace import config +from ddtrace.contrib.dbapi import TracedConnection +from ddtrace.contrib.trace_utils import _convert_to_string +from ddtrace.ext import db +from ddtrace.ext import net +from ddtrace.internal.schema import schematize_database_operation +from ddtrace.internal.schema import schematize_service_name +from ddtrace.internal.utils.formats import asbool +from ddtrace.propagation._database_monitoring import _DBM_Propagator +from ddtrace.vendor import wrapt + + +config._add( + "pymysql", + dict( + _default_service=schematize_service_name("pymysql"), + _dbapi_span_name_prefix="pymysql", + _dbapi_span_operation_name=schematize_database_operation("pymysql.query", database_provider="mysql"), + trace_fetch_methods=asbool(os.getenv("DD_PYMYSQL_TRACE_FETCH_METHODS", default=False)), + _dbm_propagator=_DBM_Propagator(0, "query"), + ), +) + + +def get_version(): + # type: () -> str + return getattr(pymysql, "__version__", "") + + +CONN_ATTR_BY_TAG = { + net.TARGET_HOST: "host", + net.TARGET_PORT: "port", + net.SERVER_ADDRESS: "host", + db.USER: "user", + db.NAME: "db", +} + + +def patch(): + wrapt.wrap_function_wrapper("pymysql", "connect", _connect) + pymysql._datadog_patch = True + + +def unpatch(): + if isinstance(pymysql.connect, wrapt.ObjectProxy): + pymysql.connect = pymysql.connect.__wrapped__ + pymysql._datadog_patch = False + + +def _connect(func, instance, args, kwargs): + conn = func(*args, **kwargs) + return patch_conn(conn) + + +def patch_conn(conn): + tags = {t: _convert_to_string(getattr(conn, a)) for t, a in CONN_ATTR_BY_TAG.items() if getattr(conn, a, "") != ""} + tags[db.SYSTEM] = "mysql" + pin = Pin(tags=tags) + + # grab the metadata from the conn + wrapped = TracedConnection(conn, pin=pin, cfg=config.pymysql) + pin.onto(wrapped) + return wrapped diff --git a/ddtrace/contrib/openai/__init__.py b/ddtrace/contrib/openai/__init__.py index 33765465aec..c0435ec81fd 100644 --- a/ddtrace/contrib/openai/__init__.py +++ b/ddtrace/contrib/openai/__init__.py @@ -253,10 +253,12 @@ with require_modules(required_modules) as missing_modules: if not missing_modules: - from . import patch as _patch + # Required to allow users to import from `ddtrace.contrib.openai.patch` directly + from . import patch as _ # noqa: F401, I001 - patch = _patch.patch - unpatch = _patch.unpatch - get_version = _patch.get_version + # Expose public methods + from ..internal.openai.patch import patch + from ..internal.openai.patch import unpatch + from ..internal.openai.patch import get_version __all__ = ["patch", "unpatch", "get_version"] diff --git a/ddtrace/contrib/openai/_endpoint_hooks.py b/ddtrace/contrib/openai/_endpoint_hooks.py index 7e46af3c31c..9bd446b7038 100644 --- a/ddtrace/contrib/openai/_endpoint_hooks.py +++ b/ddtrace/contrib/openai/_endpoint_hooks.py @@ -1,757 +1,15 @@ -from openai.version import VERSION as OPENAI_VERSION +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate -from ddtrace.contrib.openai.utils import TracedOpenAIAsyncStream -from ddtrace.contrib.openai.utils import TracedOpenAIStream -from ddtrace.contrib.openai.utils import _format_openai_api_key -from ddtrace.contrib.openai.utils import _is_async_generator -from ddtrace.contrib.openai.utils import _is_generator -from ddtrace.contrib.openai.utils import _loop_handler -from ddtrace.contrib.openai.utils import _process_finished_stream -from ddtrace.contrib.openai.utils import _tag_tool_calls -from ddtrace.internal.utils.version import parse_version -from ddtrace.llmobs._constants import SPAN_KIND +from ..internal.openai._endpoint_hooks import * # noqa: F401,F403 -API_VERSION = "v1" - - -class _EndpointHook: - """ - Base class for all OpenAI endpoint hooks. - Each new endpoint hook should declare `_request_arg_params` and `_request_kwarg_params`, - which will be tagged automatically by _EndpointHook._record_request(). - For endpoint-specific request/response parameters that requires special casing, add that logic to - the endpoint hook's `_record_request()` after a super call to the base `_EndpointHook._record_request()`. - """ - - # _request_arg_params must include the names of arg parameters in order. - # If a given arg requires special casing, replace with `None` to avoid automatic tagging. - _request_arg_params = () - # _request_kwarg_params must include the names of kwarg parameters to tag automatically. - # If a given kwarg requires special casing, remove from this tuple to avoid automatic tagging. - _request_kwarg_params = () - # _response_attrs is used to automatically tag specific response attributes. - _response_attrs = () - _base_level_tag_args = ("api_base", "api_type", "api_version") - ENDPOINT_NAME = "openai" - HTTP_METHOD_TYPE = "" - OPERATION_ID = "" # Each endpoint hook must provide an operationID as specified in the OpenAI API specs: - # https://raw.githubusercontent.com/openai/openai-openapi/master/openapi.yaml - - def _record_request(self, pin, integration, span, args, kwargs): - """ - Set base-level openai tags, as well as request params from args and kwargs. - All inherited EndpointHook classes should include a super call to this method before performing - endpoint-specific request tagging logic. - """ - endpoint = self.ENDPOINT_NAME - if endpoint is None: - endpoint = "%s" % args[0].OBJECT_NAME - span.set_tag_str("openai.request.endpoint", "/%s/%s" % (API_VERSION, endpoint)) - span.set_tag_str("openai.request.method", self.HTTP_METHOD_TYPE) - - if self._request_arg_params and len(self._request_arg_params) > 1: - for idx, arg in enumerate(self._request_arg_params, 1): - if idx >= len(args): - break - if arg is None or args[idx] is None: - continue - if arg in self._base_level_tag_args: - span.set_tag_str("openai.%s" % arg, str(args[idx])) - elif arg == "organization": - span.set_tag_str("openai.organization.id", args[idx]) - elif arg == "api_key": - span.set_tag_str("openai.user.api_key", _format_openai_api_key(args[idx])) - else: - span.set_tag_str("openai.request.%s" % arg, str(args[idx])) - for kw_attr in self._request_kwarg_params: - if kw_attr not in kwargs: - continue - if isinstance(kwargs[kw_attr], dict): - for k, v in kwargs[kw_attr].items(): - span.set_tag_str("openai.request.%s.%s" % (kw_attr, k), str(v)) - elif kw_attr == "engine": # Azure OpenAI requires using "engine" instead of "model" - span.set_tag_str("openai.request.model", str(kwargs[kw_attr])) - else: - span.set_tag_str("openai.request.%s" % kw_attr, str(kwargs[kw_attr])) - - def handle_request(self, pin, integration, span, args, kwargs): - self._record_request(pin, integration, span, args, kwargs) - resp, error = yield - if hasattr(resp, "parse"): - # Users can request the raw response, in which case we need to process on the parsed response - # and return the original raw APIResponse. - self._record_response(pin, integration, span, args, kwargs, resp.parse(), error) - return resp - return self._record_response(pin, integration, span, args, kwargs, resp, error) - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - for resp_attr in self._response_attrs: - if hasattr(resp, resp_attr): - span.set_tag_str("openai.response.%s" % resp_attr, str(getattr(resp, resp_attr, ""))) - return resp - - -class _BaseCompletionHook(_EndpointHook): - _request_arg_params = ("api_key", "api_base", "api_type", "request_id", "api_version", "organization") - - def _handle_streamed_response(self, integration, span, kwargs, resp, is_completion=False): - """Handle streamed response objects returned from completions/chat endpoint calls. - - This method returns a wrapped version of the OpenAIStream/OpenAIAsyncStream objects - to trace the response while it is read by the user. - """ - if parse_version(OPENAI_VERSION) >= (1, 6, 0): - if _is_async_generator(resp): - return TracedOpenAIAsyncStream(resp, integration, span, kwargs, is_completion) - elif _is_generator(resp): - return TracedOpenAIStream(resp, integration, span, kwargs, is_completion) - - def shared_gen(): - try: - streamed_chunks = yield - _process_finished_stream(integration, span, kwargs, streamed_chunks, is_completion=is_completion) - finally: - span.finish() - integration.metric(span, "dist", "request.duration", span.duration_ns) - - if _is_async_generator(resp): - - async def traced_streamed_response(): - g = shared_gen() - g.send(None) - n = kwargs.get("n", 1) or 1 - if is_completion: - prompts = kwargs.get("prompt", "") - if isinstance(prompts, list) and not isinstance(prompts[0], int): - n *= len(prompts) - streamed_chunks = [[] for _ in range(n)] - try: - async for chunk in resp: - _loop_handler(span, chunk, streamed_chunks) - yield chunk - finally: - try: - g.send(streamed_chunks) - except StopIteration: - pass - - return traced_streamed_response() - - elif _is_generator(resp): - - def traced_streamed_response(): - g = shared_gen() - g.send(None) - n = kwargs.get("n", 1) or 1 - if is_completion: - prompts = kwargs.get("prompt", "") - if isinstance(prompts, list) and not isinstance(prompts[0], int): - n *= len(prompts) - streamed_chunks = [[] for _ in range(n)] - try: - for chunk in resp: - _loop_handler(span, chunk, streamed_chunks) - yield chunk - finally: - try: - g.send(streamed_chunks) - except StopIteration: - pass - - return traced_streamed_response() - return resp - - -class _CompletionHook(_BaseCompletionHook): - _request_kwarg_params = ( - "model", - "engine", - "suffix", - "max_tokens", - "temperature", - "top_p", - "n", - "stream", - "logprobs", - "echo", - "stop", - "presence_penalty", - "frequency_penalty", - "best_of", - "logit_bias", - "user", - ) - _response_attrs = ("created", "id", "model") - ENDPOINT_NAME = "completions" - HTTP_METHOD_TYPE = "POST" - OPERATION_ID = "createCompletion" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - if integration.is_pc_sampled_llmobs(span): - span.set_tag_str(SPAN_KIND, "llm") - if integration.is_pc_sampled_span(span): - prompt = kwargs.get("prompt", "") - if isinstance(prompt, str): - prompt = [prompt] - for idx, p in enumerate(prompt): - span.set_tag_str("openai.request.prompt.%d" % idx, integration.trunc(str(p))) - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - if kwargs.get("stream") and error is None: - return self._handle_streamed_response(integration, span, kwargs, resp, is_completion=True) - if integration.is_pc_sampled_log(span): - attrs_dict = {"prompt": kwargs.get("prompt", "")} - if error is None: - log_choices = resp.choices - if hasattr(resp.choices[0], "model_dump"): - log_choices = [choice.model_dump() for choice in resp.choices] - attrs_dict.update({"choices": log_choices}) - integration.log( - span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict - ) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("completion", resp, span, kwargs, err=error) - if not resp: - return - for choice in resp.choices: - span.set_tag_str("openai.response.choices.%d.finish_reason" % choice.index, str(choice.finish_reason)) - if integration.is_pc_sampled_span(span): - span.set_tag_str("openai.response.choices.%d.text" % choice.index, integration.trunc(choice.text)) - integration.record_usage(span, resp.usage) - return resp - - -class _ChatCompletionHook(_BaseCompletionHook): - _request_arg_params = ("api_key", "api_base", "api_type", "request_id", "api_version", "organization") - _request_kwarg_params = ( - "model", - "engine", - "temperature", - "top_p", - "n", - "stream", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - ) - _response_attrs = ("created", "id", "model") - ENDPOINT_NAME = "chat/completions" - HTTP_METHOD_TYPE = "POST" - OPERATION_ID = "createChatCompletion" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - if integration.is_pc_sampled_llmobs(span): - span.set_tag_str(SPAN_KIND, "llm") - for idx, m in enumerate(kwargs.get("messages", [])): - role = getattr(m, "role", "") - name = getattr(m, "name", "") - content = getattr(m, "content", "") - if isinstance(m, dict): - content = m.get("content", "") - role = m.get("role", "") - name = m.get("name", "") - if integration.is_pc_sampled_span(span): - span.set_tag_str("openai.request.messages.%d.content" % idx, integration.trunc(str(content))) - span.set_tag_str("openai.request.messages.%d.role" % idx, str(role)) - span.set_tag_str("openai.request.messages.%d.name" % idx, str(name)) - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - if kwargs.get("stream") and error is None: - return self._handle_streamed_response(integration, span, kwargs, resp, is_completion=False) - if integration.is_pc_sampled_log(span): - log_choices = resp.choices - if hasattr(resp.choices[0], "model_dump"): - log_choices = [choice.model_dump() for choice in resp.choices] - attrs_dict = {"messages": kwargs.get("messages", []), "completion": log_choices} - integration.log( - span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict - ) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("chat", resp, span, kwargs, err=error) - if not resp: - return - for choice in resp.choices: - idx = choice.index - finish_reason = getattr(choice, "finish_reason", None) - message = choice.message - span.set_tag_str("openai.response.choices.%d.finish_reason" % idx, str(finish_reason)) - span.set_tag_str("openai.response.choices.%d.message.role" % idx, choice.message.role) - if integration.is_pc_sampled_span(span): - span.set_tag_str( - "openai.response.choices.%d.message.content" % idx, integration.trunc(message.content or "") - ) - if getattr(message, "function_call", None): - _tag_tool_calls(integration, span, [message.function_call], idx) - if getattr(message, "tool_calls", None): - _tag_tool_calls(integration, span, message.tool_calls, idx) - integration.record_usage(span, resp.usage) - return resp - - -class _EmbeddingHook(_EndpointHook): - _request_arg_params = ("api_key", "api_base", "api_type", "request_id", "api_version", "organization") - _request_kwarg_params = ("model", "engine", "user") - _response_attrs = ("model",) - ENDPOINT_NAME = "embeddings" - HTTP_METHOD_TYPE = "POST" - OPERATION_ID = "createEmbedding" - - def _record_request(self, pin, integration, span, args, kwargs): - """ - Embedding endpoint allows multiple inputs, each of which we specify a request tag for, so have to - manually set them in _pre_response(). - """ - super()._record_request(pin, integration, span, args, kwargs) - embedding_input = kwargs.get("input", "") - if integration.is_pc_sampled_span(span): - if isinstance(embedding_input, str) or isinstance(embedding_input[0], int): - embedding_input = [embedding_input] - for idx, inp in enumerate(embedding_input): - span.set_tag_str("openai.request.input.%d" % idx, integration.trunc(str(inp))) - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("embedding", resp, span, kwargs, err=error) - if not resp: - return - span.set_metric("openai.response.embeddings_count", len(resp.data)) - span.set_metric("openai.response.embedding-length", len(resp.data[0].embedding)) - integration.record_usage(span, resp.usage) - return resp - - -class _ListHook(_EndpointHook): - """ - Hook for openai.ListableAPIResource, which is used by Model.list, File.list, and FineTune.list. - """ - - _request_arg_params = ("api_key", "request_id", "api_version", "organization", "api_base", "api_type") - _request_kwarg_params = ("user",) - ENDPOINT_NAME = None - HTTP_METHOD_TYPE = "GET" - OPERATION_ID = "list" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - endpoint = span.get_tag("openai.request.endpoint") - if endpoint.endswith("/models"): - span.resource = "listModels" - elif endpoint.endswith("/files"): - span.resource = "listFiles" - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - if not resp: - return - span.set_metric("openai.response.count", len(resp.data or [])) - return resp - - -class _ModelListHook(_ListHook): - """ - Hook for openai.resources.models.Models.list (v1) - """ - - ENDPOINT_NAME = "models" - OPERATION_ID = "listModels" - - -class _FileListHook(_ListHook): - """ - Hook for openai.resources.files.Files.list (v1) - """ - - ENDPOINT_NAME = "files" - OPERATION_ID = "listFiles" - - -class _RetrieveHook(_EndpointHook): - """Hook for openai.APIResource, which is used by Model.retrieve, File.retrieve, and FineTune.retrieve.""" - - _request_arg_params = (None, "api_key", "request_id", "request_timeout") - _request_kwarg_params = ("user",) - _response_attrs = ( - "id", - "owned_by", - "model", - "parent", - "root", - "bytes", - "created", - "created_at", - "purpose", - "filename", - "fine_tuned_model", - "status", - "status_details", - "updated_at", - ) - ENDPOINT_NAME = None - HTTP_METHOD_TYPE = "GET" - OPERATION_ID = "retrieve" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - endpoint = span.get_tag("openai.request.endpoint") - if endpoint.endswith("/models"): - span.resource = "retrieveModel" - span.set_tag_str("openai.request.model", args[1] if len(args) >= 2 else kwargs.get("model", "")) - elif endpoint.endswith("/files"): - span.resource = "retrieveFile" - span.set_tag_str("openai.request.file_id", args[1] if len(args) >= 2 else kwargs.get("file_id", "")) - span.set_tag_str("openai.request.endpoint", "%s/*" % endpoint) - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - if not resp: - return - if hasattr(resp, "hyperparams"): - for hyperparam in ("batch_size", "learning_rate_multiplier", "n_epochs", "prompt_loss_weight"): - val = getattr(resp.hyperparams, hyperparam, "") - span.set_tag_str("openai.response.hyperparams.%s" % hyperparam, str(val)) - for resp_attr in ("result_files", "training_files", "validation_files"): - if hasattr(resp, resp_attr): - span.set_metric("openai.response.%s_count" % resp_attr, len(getattr(resp, resp_attr, []))) - if hasattr(resp, "events"): - span.set_metric("openai.response.events_count", len(resp.events)) - return resp - - -class _ModelRetrieveHook(_RetrieveHook): - """ - Hook for openai.resources.models.Models.retrieve - """ - - ENDPOINT_NAME = "models" - OPERATION_ID = "retrieveModel" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - span.set_tag_str("openai.request.model", args[1] if len(args) >= 2 else kwargs.get("model", "")) - - -class _FileRetrieveHook(_RetrieveHook): - """ - Hook for openai.resources.files.Files.retrieve - """ - - ENDPOINT_NAME = "files" - OPERATION_ID = "retrieveFile" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - span.set_tag_str("openai.request.file_id", args[1] if len(args) >= 2 else kwargs.get("file_id", "")) - - -class _DeleteHook(_EndpointHook): - """Hook for openai.DeletableAPIResource, which is used by File.delete, and Model.delete.""" - - _request_arg_params = (None, "api_type", "api_version") - _request_kwarg_params = ("user",) - ENDPOINT_NAME = None - HTTP_METHOD_TYPE = "DELETE" - OPERATION_ID = "delete" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - endpoint = span.get_tag("openai.request.endpoint") - if endpoint.endswith("/models"): - span.resource = "deleteModel" - span.set_tag_str("openai.request.model", args[1] if len(args) >= 2 else kwargs.get("model", "")) - elif endpoint.endswith("/files"): - span.resource = "deleteFile" - span.set_tag_str("openai.request.file_id", args[1] if len(args) >= 2 else kwargs.get("file_id", "")) - span.set_tag_str("openai.request.endpoint", "%s/*" % endpoint) - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - if not resp: - return - if hasattr(resp, "data"): - if resp._headers.get("openai-organization"): - span.set_tag_str("openai.organization.name", resp._headers.get("openai-organization")) - span.set_tag_str("openai.response.id", resp.data.get("id", "")) - span.set_tag_str("openai.response.deleted", str(resp.data.get("deleted", ""))) - else: - span.set_tag_str("openai.response.id", str(resp.id)) - span.set_tag_str("openai.response.deleted", str(resp.deleted)) - return resp - - -class _FileDeleteHook(_DeleteHook): - """ - Hook for openai.resources.files.Files.delete - """ - - ENDPOINT_NAME = "files" - - -class _ModelDeleteHook(_DeleteHook): - """ - Hook for openai.resources.models.Models.delete - """ - - ENDPOINT_NAME = "models" - - -class _ImageHook(_EndpointHook): - _response_attrs = ("created",) - ENDPOINT_NAME = "images" - HTTP_METHOD_TYPE = "POST" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - span.set_tag_str("openai.request.model", "dall-e") - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - if integration.is_pc_sampled_log(span): - attrs_dict = {} - if kwargs.get("response_format", "") == "b64_json": - attrs_dict.update({"choices": [{"b64_json": "returned"} for _ in resp.data]}) - else: - log_choices = resp.data - if hasattr(resp.data[0], "model_dump"): - log_choices = [choice.model_dump() for choice in resp.data] - attrs_dict.update({"choices": log_choices}) - if "prompt" in self._request_kwarg_params: - attrs_dict.update({"prompt": kwargs.get("prompt", "")}) - if "image" in self._request_kwarg_params: - image = args[1] if len(args) >= 2 else kwargs.get("image", "") - attrs_dict.update({"image": image.name.split("/")[-1]}) - if "mask" in self._request_kwarg_params: - mask = args[2] if len(args) >= 3 else kwargs.get("mask", "") - attrs_dict.update({"mask": mask.name.split("/")[-1]}) - integration.log( - span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict - ) - if not resp: - return - choices = resp.data - span.set_metric("openai.response.images_count", len(choices)) - if integration.is_pc_sampled_span(span): - for idx, choice in enumerate(choices): - if getattr(choice, "b64_json", None) is not None: - span.set_tag_str("openai.response.images.%d.b64_json" % idx, "returned") - else: - span.set_tag_str("openai.response.images.%d.url" % idx, integration.trunc(choice.url)) - return resp - - -class _ImageCreateHook(_ImageHook): - _request_arg_params = ("api_key", "api_base", "api_type", "api_version", "organization") - _request_kwarg_params = ("prompt", "n", "size", "response_format", "user") - ENDPOINT_NAME = "images/generations" - OPERATION_ID = "createImage" - - -class _ImageEditHook(_ImageHook): - _request_arg_params = (None, None, "api_key", "api_base", "api_type", "api_version", "organization") - _request_kwarg_params = ("prompt", "n", "size", "response_format", "user", "image", "mask") - ENDPOINT_NAME = "images/edits" - OPERATION_ID = "createImageEdit" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - if not integration.is_pc_sampled_span: - return - image = args[1] if len(args) >= 2 else kwargs.get("image", "") - mask = args[2] if len(args) >= 3 else kwargs.get("mask", "") - if image: - if hasattr(image, "name"): - span.set_tag_str("openai.request.image", integration.trunc(image.name.split("/")[-1])) - else: - span.set_tag_str("openai.request.image", "") - if mask: - if hasattr(mask, "name"): - span.set_tag_str("openai.request.mask", integration.trunc(mask.name.split("/")[-1])) - else: - span.set_tag_str("openai.request.mask", "") - - -class _ImageVariationHook(_ImageHook): - _request_arg_params = (None, "api_key", "api_base", "api_type", "api_version", "organization") - _request_kwarg_params = ("n", "size", "response_format", "user", "image") - ENDPOINT_NAME = "images/variations" - OPERATION_ID = "createImageVariation" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - if not integration.is_pc_sampled_span: - return - image = args[1] if len(args) >= 2 else kwargs.get("image", "") - if image: - if hasattr(image, "name"): - span.set_tag_str("openai.request.image", integration.trunc(image.name.split("/")[-1])) - else: - span.set_tag_str("openai.request.image", "") - - -class _BaseAudioHook(_EndpointHook): - _request_arg_params = ("model", None, "api_key", "api_base", "api_type", "api_version", "organization") - _response_attrs = ("language", "duration") - ENDPOINT_NAME = "audio" - HTTP_METHOD_TYPE = "POST" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - if not integration.is_pc_sampled_span: - return - audio_file = args[2] if len(args) >= 3 else kwargs.get("file", "") - if audio_file and hasattr(audio_file, "name"): - span.set_tag_str("openai.request.filename", integration.trunc(audio_file.name.split("/")[-1])) - else: - span.set_tag_str("openai.request.filename", "") - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - text = "" - if resp: - resp_to_tag = resp.model_dump() if hasattr(resp, "model_dump") else resp - if isinstance(resp_to_tag, str): - text = resp - elif isinstance(resp_to_tag, dict): - text = resp_to_tag.get("text", "") - if "segments" in resp_to_tag: - span.set_metric("openai.response.segments_count", len(resp_to_tag.get("segments"))) - if integration.is_pc_sampled_span(span): - span.set_tag_str("openai.response.text", integration.trunc(text)) - if integration.is_pc_sampled_log(span): - file_input = args[2] if len(args) >= 3 else kwargs.get("file", "") - integration.log( - span, - "info" if error is None else "error", - "sampled %s" % self.OPERATION_ID, - attrs={ - "file": getattr(file_input, "name", "").split("/")[-1], - "prompt": kwargs.get("prompt", ""), - "language": kwargs.get("language", ""), - "text": text, - }, - ) - return resp - - -class _AudioTranscriptionHook(_BaseAudioHook): - _request_kwarg_params = ( - "prompt", - "response_format", - "temperature", - "language", - "user", - ) - ENDPOINT_NAME = "audio/transcriptions" - OPERATION_ID = "createTranscription" - - -class _AudioTranslationHook(_BaseAudioHook): - _request_kwarg_params = ( - "prompt", - "response_format", - "temperature", - "user", - ) - ENDPOINT_NAME = "audio/translations" - OPERATION_ID = "createTranslation" - - -class _ModerationHook(_EndpointHook): - _request_arg_params = ("input", "model", "api_key") - _request_kwarg_params = ("input", "model") - _response_attrs = ("id", "model") - _response_categories = ( - "hate", - "hate/threatening", - "harassment", - "harassment/threatening", - "self-harm", - "self-harm/intent", - "self-harm/instructions", - "sexual", - "sexual/minors", - "violence", - "violence/graphic", +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, ) - ENDPOINT_NAME = "moderations" - HTTP_METHOD_TYPE = "POST" - OPERATION_ID = "createModeration" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - if not resp: - return - results = resp.results[0] - categories = results.categories - scores = results.category_scores - for category in self._response_categories: - span.set_metric("openai.response.category_scores.%s" % category, getattr(scores, category, 0)) - span.set_metric("openai.response.categories.%s" % category, int(getattr(categories, category))) - span.set_metric("openai.response.flagged", int(results.flagged)) - return resp - - -class _BaseFileHook(_EndpointHook): - ENDPOINT_NAME = "files" - - -class _FileCreateHook(_BaseFileHook): - _request_arg_params = ( - None, - "purpose", - "model", - "api_key", - "api_base", - "api_type", - "api_version", - "organization", - "user_provided_filename", - ) - _request_kwarg_params = ("purpose",) - _response_attrs = ("id", "bytes", "created_at", "filename", "purpose", "status", "status_details") - HTTP_METHOD_TYPE = "POST" - OPERATION_ID = "createFile" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - fp = args[1] if len(args) >= 2 else kwargs.get("file", "") - if fp and hasattr(fp, "name"): - span.set_tag_str("openai.request.filename", fp.name.split("/")[-1]) - else: - span.set_tag_str("openai.request.filename", "") - - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - return resp - - -class _FileDownloadHook(_BaseFileHook): - _request_arg_params = (None, "api_key", "api_base", "api_type", "api_version", "organization") - HTTP_METHOD_TYPE = "GET" - OPERATION_ID = "downloadFile" - ENDPOINT_NAME = "files/*/content" - - def _record_request(self, pin, integration, span, args, kwargs): - super()._record_request(pin, integration, span, args, kwargs) - span.set_tag_str("openai.request.file_id", args[1] if len(args) >= 2 else kwargs.get("file_id", "")) - def _record_response(self, pin, integration, span, args, kwargs, resp, error): - resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) - if not resp: - return - if isinstance(resp, bytes) or isinstance(resp, str): - span.set_metric("openai.response.total_bytes", len(resp)) - else: - span.set_metric("openai.response.total_bytes", getattr(resp, "total_bytes", 0)) - return resp + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/openai/patch.py b/ddtrace/contrib/openai/patch.py index 5e3bf2caead..89ea9d21adf 100644 --- a/ddtrace/contrib/openai/patch.py +++ b/ddtrace/contrib/openai/patch.py @@ -1,358 +1,4 @@ -import os -import sys +from ..internal.openai.patch import * # noqa: F401,F403 -from openai import version -from ddtrace import config -from ddtrace.internal.logger import get_logger -from ddtrace.internal.schema import schematize_service_name -from ddtrace.internal.utils.formats import asbool -from ddtrace.internal.utils.formats import deep_getattr -from ddtrace.internal.utils.version import parse_version -from ddtrace.internal.wrapping import wrap -from ddtrace.llmobs._integrations import OpenAIIntegration - -from ...pin import Pin -from . import _endpoint_hooks -from .utils import _format_openai_api_key - - -log = get_logger(__name__) - - -config._add( - "openai", - { - "logs_enabled": asbool(os.getenv("DD_OPENAI_LOGS_ENABLED", False)), - "metrics_enabled": asbool(os.getenv("DD_OPENAI_METRICS_ENABLED", True)), - "span_prompt_completion_sample_rate": float(os.getenv("DD_OPENAI_SPAN_PROMPT_COMPLETION_SAMPLE_RATE", 1.0)), - "log_prompt_completion_sample_rate": float(os.getenv("DD_OPENAI_LOG_PROMPT_COMPLETION_SAMPLE_RATE", 0.1)), - "span_char_limit": int(os.getenv("DD_OPENAI_SPAN_CHAR_LIMIT", 128)), - }, -) - - -def get_version(): - # type: () -> str - return version.VERSION - - -OPENAI_VERSION = parse_version(get_version()) - - -if OPENAI_VERSION >= (1, 0, 0): - _RESOURCES = { - "models.Models": { - "list": _endpoint_hooks._ModelListHook, - "retrieve": _endpoint_hooks._ModelRetrieveHook, - "delete": _endpoint_hooks._ModelDeleteHook, - }, - "completions.Completions": { - "create": _endpoint_hooks._CompletionHook, - }, - "chat.Completions": { - "create": _endpoint_hooks._ChatCompletionHook, - }, - "images.Images": { - "generate": _endpoint_hooks._ImageCreateHook, - "edit": _endpoint_hooks._ImageEditHook, - "create_variation": _endpoint_hooks._ImageVariationHook, - }, - "audio.Transcriptions": { - "create": _endpoint_hooks._AudioTranscriptionHook, - }, - "audio.Translations": { - "create": _endpoint_hooks._AudioTranslationHook, - }, - "embeddings.Embeddings": { - "create": _endpoint_hooks._EmbeddingHook, - }, - "moderations.Moderations": { - "create": _endpoint_hooks._ModerationHook, - }, - "files.Files": { - "create": _endpoint_hooks._FileCreateHook, - "retrieve": _endpoint_hooks._FileRetrieveHook, - "list": _endpoint_hooks._FileListHook, - "delete": _endpoint_hooks._FileDeleteHook, - "retrieve_content": _endpoint_hooks._FileDownloadHook, - }, - } -else: - _RESOURCES = { - "model.Model": { - "list": _endpoint_hooks._ListHook, - "retrieve": _endpoint_hooks._RetrieveHook, - }, - "completion.Completion": { - "create": _endpoint_hooks._CompletionHook, - }, - "chat_completion.ChatCompletion": { - "create": _endpoint_hooks._ChatCompletionHook, - }, - "image.Image": { - "create": _endpoint_hooks._ImageCreateHook, - "create_edit": _endpoint_hooks._ImageEditHook, - "create_variation": _endpoint_hooks._ImageVariationHook, - }, - "audio.Audio": { - "transcribe": _endpoint_hooks._AudioTranscriptionHook, - "translate": _endpoint_hooks._AudioTranslationHook, - }, - "embedding.Embedding": { - "create": _endpoint_hooks._EmbeddingHook, - }, - "moderation.Moderation": { - "create": _endpoint_hooks._ModerationHook, - }, - "file.File": { - # File.list() and File.retrieve() share the same underlying method as Model.list() and Model.retrieve() - # which means they are already wrapped - "create": _endpoint_hooks._FileCreateHook, - "delete": _endpoint_hooks._DeleteHook, - "download": _endpoint_hooks._FileDownloadHook, - }, - } - - -def _wrap_classmethod(obj, wrapper): - wrap(obj.__func__, wrapper) - - -def patch(): - # Avoid importing openai at the module level, eventually will be an import hook - import openai - - if getattr(openai, "__datadog_patch", False): - return - - Pin().onto(openai) - integration = OpenAIIntegration(integration_config=config.openai, openai=openai) - - if OPENAI_VERSION >= (1, 0, 0): - if OPENAI_VERSION >= (1, 8, 0): - wrap(openai._base_client.SyncAPIClient._process_response, _patched_convert(openai, integration)) - wrap(openai._base_client.AsyncAPIClient._process_response, _patched_convert(openai, integration)) - else: - wrap(openai._base_client.BaseClient._process_response, _patched_convert(openai, integration)) - wrap(openai.OpenAI.__init__, _patched_client_init(openai, integration)) - wrap(openai.AsyncOpenAI.__init__, _patched_client_init(openai, integration)) - wrap(openai.AzureOpenAI.__init__, _patched_client_init(openai, integration)) - wrap(openai.AsyncAzureOpenAI.__init__, _patched_client_init(openai, integration)) - - for resource, method_hook_dict in _RESOURCES.items(): - if deep_getattr(openai.resources, resource) is None: - continue - for method_name, endpoint_hook in method_hook_dict.items(): - sync_method = deep_getattr(openai.resources, "%s.%s" % (resource, method_name)) - async_method = deep_getattr( - openai.resources, "%s.%s" % (".Async".join(resource.split(".")), method_name) - ) - wrap(sync_method, _patched_endpoint(openai, integration, endpoint_hook)) - wrap(async_method, _patched_endpoint_async(openai, integration, endpoint_hook)) - else: - import openai.api_requestor - - wrap(openai.api_requestor._make_session, _patched_make_session) - wrap(openai.util.convert_to_openai_object, _patched_convert(openai, integration)) - - for resource, method_hook_dict in _RESOURCES.items(): - if deep_getattr(openai.api_resources, resource) is None: - continue - for method_name, endpoint_hook in method_hook_dict.items(): - sync_method = deep_getattr(openai.api_resources, "%s.%s" % (resource, method_name)) - async_method = deep_getattr(openai.api_resources, "%s.a%s" % (resource, method_name)) - _wrap_classmethod(sync_method, _patched_endpoint(openai, integration, endpoint_hook)) - _wrap_classmethod(async_method, _patched_endpoint_async(openai, integration, endpoint_hook)) - - openai.__datadog_patch = True - - -def unpatch(): - # FIXME: add unpatching. The current wrapping.unwrap method requires - # the wrapper function to be provided which we don't keep a reference to. - pass - - -def _patched_client_init(openai, integration): - """ - Patch for `openai.OpenAI/AsyncOpenAI` client init methods to add the client object to the OpenAIIntegration object. - """ - - def patched_client_init(func, args, kwargs): - func(*args, **kwargs) - client = args[0] - integration._client = client - api_key = kwargs.get("api_key") - if api_key is None: - api_key = client.api_key - if api_key is not None: - integration.user_api_key = api_key - return - - return patched_client_init - - -def _patched_make_session(func, args, kwargs): - """Patch for `openai.api_requestor._make_session` which sets the service name on the - requests session so that spans from the requests integration will use the service name openai. - This is done so that the service break down will include OpenAI time spent querying the OpenAI backend. - - This should technically be a ``peer.service`` but this concept doesn't exist yet. - """ - session = func(*args, **kwargs) - service = schematize_service_name("openai") - Pin.override(session, service=service) - return session - - -def _traced_endpoint(endpoint_hook, integration, pin, args, kwargs): - span = integration.trace(pin, endpoint_hook.OPERATION_ID) - openai_api_key = _format_openai_api_key(kwargs.get("api_key")) - err = None - if openai_api_key: - # API key can either be set on the import or per request - span.set_tag_str("openai.user.api_key", openai_api_key) - try: - # Start the hook - hook = endpoint_hook().handle_request(pin, integration, span, args, kwargs) - hook.send(None) - - resp, err = yield - - # Record any error information - if err is not None: - span.set_exc_info(*sys.exc_info()) - integration.metric(span, "incr", "request.error", 1) - - # Pass the response and the error to the hook - try: - hook.send((resp, err)) - except StopIteration as e: - if err is None: - return e.value - finally: - # Streamed responses will be finished when the generator exits, so finish non-streamed spans here. - # Streamed responses with error will need to be finished manually as well. - if not kwargs.get("stream") or err is not None: - span.finish() - integration.metric(span, "dist", "request.duration", span.duration_ns) - - -def _patched_endpoint(openai, integration, patch_hook): - def patched_endpoint(func, args, kwargs): - # FIXME: this is a temporary workaround for the fact that our bytecode wrapping seems to modify - # a function keyword argument into a cell when it shouldn't. This is only an issue on - # Python 3.11+. - if sys.version_info >= (3, 11) and kwargs.get("encoding_format", None): - kwargs["encoding_format"] = kwargs["encoding_format"].cell_contents - - pin = Pin._find(openai, args[0]) - if not pin or not pin.enabled(): - return func(*args, **kwargs) - - g = _traced_endpoint(patch_hook, integration, pin, args, kwargs) - g.send(None) - resp, err = None, None - try: - resp = func(*args, **kwargs) - return resp - except Exception as e: - err = e - raise - finally: - try: - g.send((resp, err)) - except StopIteration as e: - if err is None: - # This return takes priority over `return resp` - return e.value # noqa: B012 - - return patched_endpoint - - -def _patched_endpoint_async(openai, integration, patch_hook): - # Same as _patched_endpoint but async - async def patched_endpoint(func, args, kwargs): - # FIXME: this is a temporary workaround for the fact that our bytecode wrapping seems to modify - # a function keyword argument into a cell when it shouldn't. This is only an issue on - # Python 3.11+. - if sys.version_info >= (3, 11) and kwargs.get("encoding_format", None): - kwargs["encoding_format"] = kwargs["encoding_format"].cell_contents - - pin = Pin._find(openai, args[0]) - if not pin or not pin.enabled(): - return await func(*args, **kwargs) - g = _traced_endpoint(patch_hook, integration, pin, args, kwargs) - g.send(None) - resp, err = None, None - try: - resp = await func(*args, **kwargs) - return resp - except Exception as e: - err = e - raise - finally: - try: - g.send((resp, err)) - except StopIteration as e: - if err is None: - # This return takes priority over `return resp` - return e.value # noqa: B012 - - return patched_endpoint - - -def _patched_convert(openai, integration): - def patched_convert(func, args, kwargs): - """Patch convert captures header information in the openai response""" - pin = Pin.get_from(openai) - if not pin or not pin.enabled(): - return func(*args, **kwargs) - - span = pin.tracer.current_span() - if not span: - return func(*args, **kwargs) - - if OPENAI_VERSION < (1, 0, 0): - resp = args[0] - if not isinstance(resp, openai.openai_response.OpenAIResponse): - return func(*args, **kwargs) - headers = resp._headers - else: - resp = kwargs.get("response", {}) - headers = resp.headers - # This function is called for each chunk in the stream. - # To prevent needlessly setting the same tags for each chunk, short-circuit here. - if span.get_tag("openai.organization.name") is not None: - return func(*args, **kwargs) - if headers.get("openai-organization"): - org_name = headers.get("openai-organization") - span.set_tag_str("openai.organization.name", org_name) - - # Gauge total rate limit - if headers.get("x-ratelimit-limit-requests"): - v = headers.get("x-ratelimit-limit-requests") - if v is not None: - integration.metric(span, "gauge", "ratelimit.requests", int(v)) - span.set_metric("openai.organization.ratelimit.requests.limit", int(v)) - if headers.get("x-ratelimit-limit-tokens"): - v = headers.get("x-ratelimit-limit-tokens") - if v is not None: - integration.metric(span, "gauge", "ratelimit.tokens", int(v)) - span.set_metric("openai.organization.ratelimit.tokens.limit", int(v)) - # Gauge and set span info for remaining requests and tokens - if headers.get("x-ratelimit-remaining-requests"): - v = headers.get("x-ratelimit-remaining-requests") - if v is not None: - integration.metric(span, "gauge", "ratelimit.remaining.requests", int(v)) - span.set_metric("openai.organization.ratelimit.requests.remaining", int(v)) - if headers.get("x-ratelimit-remaining-tokens"): - v = headers.get("x-ratelimit-remaining-tokens") - if v is not None: - integration.metric(span, "gauge", "ratelimit.remaining.tokens", int(v)) - span.set_metric("openai.organization.ratelimit.tokens.remaining", int(v)) - - return func(*args, **kwargs) - - return patched_convert +# TODO: deprecate and remove this module diff --git a/ddtrace/contrib/openai/utils.py b/ddtrace/contrib/openai/utils.py index 53566a4336f..b0290c20ca8 100644 --- a/ddtrace/contrib/openai/utils.py +++ b/ddtrace/contrib/openai/utils.py @@ -1,353 +1,15 @@ -import re -import sys -from typing import Any -from typing import AsyncGenerator -from typing import Dict -from typing import Generator -from typing import List +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate -from ddtrace.internal.logger import get_logger -from ddtrace.vendor import wrapt +from ..internal.openai.utils import * # noqa: F401,F403 -try: - from tiktoken import encoding_for_model +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) - tiktoken_available = True -except ModuleNotFoundError: - tiktoken_available = False - - -log = get_logger(__name__) - -_punc_regex = re.compile(r"[\w']+|[.,!?;~@#$%^&*()+/-]") - - -def _process_finished_stream(integration, span, kwargs, streamed_chunks, is_completion=False): - completions, messages = None, None - prompts = kwargs.get("prompt", None) - messages = kwargs.get("messages", None) - try: - _set_metrics_on_request(integration, span, kwargs, prompts=prompts, messages=messages) - if is_completion: - completions = [_construct_completion_from_streamed_chunks(choice) for choice in streamed_chunks] - if integration.is_pc_sampled_span(span): - _tag_streamed_completion_response(integration, span, completions) - else: - messages = [_construct_message_from_streamed_chunks(choice) for choice in streamed_chunks] - if integration.is_pc_sampled_span(span): - _tag_streamed_chat_completion_response(integration, span, messages) - _set_metrics_on_streamed_response(integration, span, completions=completions, messages=messages) - if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags( - "completion" if is_completion else "chat", - None, - span, - kwargs, - streamed_completions=completions if is_completion else messages, - ) - except Exception: - log.warning("Error processing streamed completion/chat response.", exc_info=True) - - -class BaseTracedOpenAIStream(wrapt.ObjectProxy): - def __init__(self, wrapped, integration, span, kwargs, is_completion=False): - super().__init__(wrapped) - n = kwargs.get("n", 1) or 1 - if is_completion: - prompts = kwargs.get("prompt", "") - if isinstance(prompts, list) and not isinstance(prompts[0], int): - n *= len(prompts) - self._dd_span = span - self._streamed_chunks = [[] for _ in range(n)] - self._dd_integration = integration - self._is_completion = is_completion - self._kwargs = kwargs - - -class TracedOpenAIStream(BaseTracedOpenAIStream): - def __enter__(self): - self.__wrapped__.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.__wrapped__.__exit__(exc_type, exc_val, exc_tb) - - def __iter__(self): - return self - - def __next__(self): - try: - chunk = self.__wrapped__.__next__() - _loop_handler(self._dd_span, chunk, self._streamed_chunks) - return chunk - except StopIteration: - _process_finished_stream( - self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion - ) - self._dd_span.finish() - self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns) - raise - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - self._dd_span.finish() - self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns) - raise - - -class TracedOpenAIAsyncStream(BaseTracedOpenAIStream): - async def __aenter__(self): - await self.__wrapped__.__aenter__() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb) - - def __aiter__(self): - return self - - async def __anext__(self): - try: - chunk = await self.__wrapped__.__anext__() - _loop_handler(self._dd_span, chunk, self._streamed_chunks) - return chunk - except StopAsyncIteration: - _process_finished_stream( - self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion - ) - self._dd_span.finish() - self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns) - raise - except Exception: - self._dd_span.set_exc_info(*sys.exc_info()) - self._dd_span.finish() - self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns) - raise - - -def _compute_token_count(content, model): - # type: (Union[str, List[int]], Optional[str]) -> Tuple[bool, int] - """ - Takes in prompt/response(s) and model pair, and returns a tuple of whether or not the number of prompt - tokens was estimated, and the estimated/calculated prompt token count. - """ - num_prompt_tokens = 0 - estimated = False - if model is not None and tiktoken_available is True: - try: - enc = encoding_for_model(model) - if isinstance(content, str): - num_prompt_tokens += len(enc.encode(content)) - elif isinstance(content, list) and isinstance(content[0], int): - num_prompt_tokens += len(content) - return estimated, num_prompt_tokens - except KeyError: - # tiktoken.encoding_for_model() will raise a KeyError if it doesn't have a tokenizer for the model - estimated = True - else: - estimated = True - - # If model is unavailable or tiktoken is not imported, then provide a very rough estimate of the number of tokens - return estimated, _est_tokens(content) - - -def _est_tokens(prompt): - # type: (Union[str, List[int]]) -> int - """ - Provide a very rough estimate of the number of tokens in a string prompt. - Note that if the prompt is passed in as a token array (list of ints), the token count - is just the length of the token array. - """ - # If model is unavailable or tiktoken is not imported, then provide a very rough estimate of the number of tokens - # Approximate using the following assumptions: - # * English text - # * 1 token ~= 4 chars - # * 1 token ~= ¾ words - est_tokens = 0 - if isinstance(prompt, str): - est1 = len(prompt) / 4 - est2 = len(_punc_regex.findall(prompt)) * 0.75 - return round((1.5 * est1 + 0.5 * est2) / 2) - elif isinstance(prompt, list) and isinstance(prompt[0], int): - return len(prompt) - return est_tokens - - -def _format_openai_api_key(openai_api_key): - # type: (Optional[str]) -> Optional[str] - """ - Returns `sk-...XXXX`, where XXXX is the last 4 characters of the provided OpenAI API key. - This mimics how OpenAI UI formats the API key. - """ - if not openai_api_key: - return None - return "sk-...%s" % openai_api_key[-4:] - - -def _is_generator(resp): - # type: (...) -> bool - import openai - - # In OpenAI v1, the response is type `openai.Stream` instead of Generator. - if isinstance(resp, Generator): - return True - if hasattr(openai, "Stream") and isinstance(resp, openai.Stream): - return True - return False - - -def _is_async_generator(resp): - # type: (...) -> bool - import openai - - # In OpenAI v1, the response is type `openai.AsyncStream` instead of AsyncGenerator. - if isinstance(resp, AsyncGenerator): - return True - if hasattr(openai, "AsyncStream") and isinstance(resp, openai.AsyncStream): - return True - return False - - -def _construct_completion_from_streamed_chunks(streamed_chunks: List[Any]) -> Dict[str, str]: - """Constructs a completion dictionary of form {"text": "...", "finish_reason": "..."} from streamed chunks.""" - completion = {"text": "".join(c.text for c in streamed_chunks if getattr(c, "text", None))} - if streamed_chunks[-1].finish_reason is not None: - completion["finish_reason"] = streamed_chunks[-1].finish_reason - return completion - - -def _construct_message_from_streamed_chunks(streamed_chunks: List[Any]) -> Dict[str, str]: - """Constructs a chat completion message dictionary from streamed chunks. - The resulting message dictionary is of form {"content": "...", "role": "...", "finish_reason": "..."} - """ - message = {} - content = "" - formatted_content = "" - idx = None - for chunk in streamed_chunks: - chunk_content = getattr(chunk.delta, "content", "") - if chunk_content: - content += chunk_content - elif getattr(chunk.delta, "function_call", None): - if idx is None: - formatted_content += "\n\n[function: {}]\n\n".format(getattr(chunk.delta.function_call, "name", "")) - idx = chunk.index - function_args = getattr(chunk.delta.function_call, "arguments", "") - content += "{}".format(function_args) - formatted_content += "{}".format(function_args) - elif getattr(chunk.delta, "tool_calls", None): - for tool_call in chunk.delta.tool_calls: - if tool_call.index != idx: - formatted_content += "\n\n[tool: {}]\n\n".format(getattr(tool_call.function, "name", "")) - idx = tool_call.index - function_args = getattr(tool_call.function, "arguments", "") - content += "{}".format(function_args) - formatted_content += "{}".format(function_args) - - message["role"] = streamed_chunks[0].delta.role or "assistant" - if streamed_chunks[-1].finish_reason is not None: - message["finish_reason"] = streamed_chunks[-1].finish_reason - message["content"] = content.strip() - if formatted_content: - message["formatted_content"] = formatted_content.strip() - return message - - -def _tag_streamed_completion_response(integration, span, completions): - """Tagging logic for streamed completions.""" - if completions is None: - return - for idx, choice in enumerate(completions): - span.set_tag_str("openai.response.choices.%d.text" % idx, integration.trunc(choice["text"])) - if choice.get("finish_reason") is not None: - span.set_tag_str("openai.response.choices.%d.finish_reason" % idx, choice["finish_reason"]) - - -def _tag_streamed_chat_completion_response(integration, span, messages): - """Tagging logic for streamed chat completions.""" - if messages is None: - return - for idx, message in enumerate(messages): - span.set_tag_str("openai.response.choices.%d.message.content" % idx, integration.trunc(message["content"])) - span.set_tag_str("openai.response.choices.%d.message.role" % idx, message["role"]) - if message.get("finish_reason") is not None: - span.set_tag_str("openai.response.choices.%d.finish_reason" % idx, message["finish_reason"]) - - -def _set_metrics_on_request(integration, span, kwargs, prompts=None, messages=None): - """Set token span metrics on streamed chat/completion requests.""" - num_prompt_tokens = 0 - estimated = False - if messages is not None: - for m in messages: - estimated, prompt_tokens = _compute_token_count(m.get("content", ""), kwargs.get("model")) - num_prompt_tokens += prompt_tokens - elif prompts is not None: - if isinstance(prompts, str) or isinstance(prompts, list) and isinstance(prompts[0], int): - prompts = [prompts] - for prompt in prompts: - estimated, prompt_tokens = _compute_token_count(prompt, kwargs.get("model")) - num_prompt_tokens += prompt_tokens - span.set_metric("openai.request.prompt_tokens_estimated", int(estimated)) - span.set_metric("openai.response.usage.prompt_tokens", num_prompt_tokens) - if not estimated: - integration.metric(span, "dist", "tokens.prompt", num_prompt_tokens) - else: - integration.metric(span, "dist", "tokens.prompt", num_prompt_tokens, tags=["openai.estimated:true"]) - - -def _set_metrics_on_streamed_response(integration, span, completions=None, messages=None): - """Set token span metrics on streamed chat/completion responses.""" - num_completion_tokens = 0 - estimated = False - if messages is not None: - for m in messages: - estimated, completion_tokens = _compute_token_count( - m.get("content", ""), span.get_tag("openai.response.model") - ) - num_completion_tokens += completion_tokens - elif completions is not None: - for c in completions: - estimated, completion_tokens = _compute_token_count( - c.get("text", ""), span.get_tag("openai.response.model") - ) - num_completion_tokens += completion_tokens - span.set_metric("openai.response.completion_tokens_estimated", int(estimated)) - span.set_metric("openai.response.usage.completion_tokens", num_completion_tokens) - num_prompt_tokens = span.get_metric("openai.response.usage.prompt_tokens") or 0 - total_tokens = num_prompt_tokens + num_completion_tokens - span.set_metric("openai.response.usage.total_tokens", total_tokens) - if not estimated: - integration.metric(span, "dist", "tokens.completion", num_completion_tokens) - integration.metric(span, "dist", "tokens.total", total_tokens) - else: - integration.metric(span, "dist", "tokens.completion", num_completion_tokens, tags=["openai.estimated:true"]) - integration.metric(span, "dist", "tokens.total", total_tokens, tags=["openai.estimated:true"]) - - -def _loop_handler(span, chunk, streamed_chunks): - """Sets the openai model tag and appends the chunk to the correct index in the streamed_chunks list. - - When handling a streamed chat/completion response, this function is called for each chunk in the streamed response. - """ - if span.get_tag("openai.response.model") is None: - span.set_tag("openai.response.model", chunk.model) - for choice in chunk.choices: - streamed_chunks[choice.index].append(choice) - - -def _tag_tool_calls(integration, span, tool_calls, choice_idx): - # type: (...) -> None - """ - Tagging logic if function_call or tool_calls are provided in the chat response. - Note: since function calls are deprecated and will be replaced with tool calls, apply the same tagging logic/schema. - """ - for idy, tool_call in enumerate(tool_calls): - if hasattr(tool_call, "function"): - # tool_call is further nested in a "function" object - tool_call = tool_call.function - span.set_tag( - "openai.response.choices.%d.message.tool_calls.%d.arguments" % (choice_idx, idy), - integration.trunc(str(tool_call.arguments)), - ) - span.set_tag("openai.response.choices.%d.message.tool_calls.%d.name" % (choice_idx, idy), str(tool_call.name)) + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/psycopg/__init__.py b/ddtrace/contrib/psycopg/__init__.py index 6e428177182..df1e1177bc3 100644 --- a/ddtrace/contrib/psycopg/__init__.py +++ b/ddtrace/contrib/psycopg/__init__.py @@ -60,9 +60,19 @@ cursor = db.cursor() cursor.execute("select * from users where id = 1") """ -from .patch import get_version -from .patch import get_versions -from .patch import patch +from ...internal.utils.importlib import require_modules -__all__ = ["patch", "get_version", "get_versions"] +required_modules = ["psycopg", "psycopg2"] +with require_modules(required_modules) as missing_modules: + # If psycopg and/or psycopg2 is available, patch these modules + if len(missing_modules) < len(required_modules): + # Required to allow users to import from `ddtrace.contrib.openai.patch` directly + from . import patch as _ # noqa: F401, I001 + + # Expose public methods + from ..internal.psycopg.patch import get_version + from ..internal.psycopg.patch import get_versions + from ..internal.psycopg.patch import patch + + __all__ = ["patch", "get_version", "get_versions"] diff --git a/ddtrace/contrib/psycopg/async_connection.py b/ddtrace/contrib/psycopg/async_connection.py index 8ac8989de6c..62c62d3604c 100644 --- a/ddtrace/contrib/psycopg/async_connection.py +++ b/ddtrace/contrib/psycopg/async_connection.py @@ -1,66 +1,15 @@ -from ddtrace import Pin -from ddtrace import config -from ddtrace.constants import SPAN_KIND -from ddtrace.constants import SPAN_MEASURED_KEY -from ddtrace.contrib import dbapi_async -from ddtrace.contrib.psycopg.async_cursor import Psycopg3FetchTracedAsyncCursor -from ddtrace.contrib.psycopg.async_cursor import Psycopg3TracedAsyncCursor -from ddtrace.contrib.psycopg.connection import patch_conn -from ddtrace.contrib.trace_utils import ext_service -from ddtrace.ext import SpanKind -from ddtrace.ext import SpanTypes -from ddtrace.ext import db -from ddtrace.internal.constants import COMPONENT +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate +from ..internal.psycopg.async_connection import * # noqa: F401,F403 -class Psycopg3TracedAsyncConnection(dbapi_async.TracedAsyncConnection): - def __init__(self, conn, pin=None, cursor_cls=None): - if not cursor_cls: - # Do not trace `fetch*` methods by default - cursor_cls = ( - Psycopg3FetchTracedAsyncCursor if config.psycopg.trace_fetch_methods else Psycopg3TracedAsyncCursor - ) - super(Psycopg3TracedAsyncConnection, self).__init__(conn, pin, config.psycopg, cursor_cls=cursor_cls) +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) - async def execute(self, *args, **kwargs): - """Execute a query and return a cursor to read its results.""" - span_name = "{}.{}".format(self._self_datadog_name, "execute") - - async def patched_execute(*args, **kwargs): - try: - cur = self.cursor() - if kwargs.get("binary", None): - cur.format = 1 # set to 1 for binary or 0 if not - return await cur.execute(*args, **kwargs) - except Exception as ex: - raise ex.with_traceback(None) - - return await self._trace_method(patched_execute, span_name, {}, *args, **kwargs) - - -def patched_connect_async_factory(psycopg_module): - async def patched_connect_async(connect_func, _, args, kwargs): - traced_conn_cls = Psycopg3TracedAsyncConnection - - pin = Pin.get_from(psycopg_module) - - if not pin or not pin.enabled() or not pin._config.trace_connect: - conn = await connect_func(*args, **kwargs) - else: - with pin.tracer.trace( - "{}.{}".format(connect_func.__module__, connect_func.__name__), - service=ext_service(pin, pin._config), - span_type=SpanTypes.SQL, - ) as span: - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - span.set_tag_str(COMPONENT, pin._config.integration_name) - if span.get_tag(db.SYSTEM) is None: - span.set_tag_str(db.SYSTEM, pin._config.dbms_name) - - span.set_tag(SPAN_MEASURED_KEY) - conn = await connect_func(*args, **kwargs) - - return patch_conn(conn, pin=pin, traced_conn_cls=traced_conn_cls) - - return patched_connect_async + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/psycopg/async_cursor.py b/ddtrace/contrib/psycopg/async_cursor.py index 0a712771fe8..38fcd54f9c5 100644 --- a/ddtrace/contrib/psycopg/async_cursor.py +++ b/ddtrace/contrib/psycopg/async_cursor.py @@ -1,11 +1,15 @@ -from ddtrace.contrib import dbapi_async -from ddtrace.contrib.psycopg.cursor import Psycopg3TracedCursor +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate +from ..internal.psycopg.async_cursor import * # noqa: F401,F403 -class Psycopg3TracedAsyncCursor(Psycopg3TracedCursor, dbapi_async.TracedAsyncCursor): - def __init__(self, cursor, pin, cfg, *args, **kwargs): - super(Psycopg3TracedAsyncCursor, self).__init__(cursor, pin, cfg) +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) -class Psycopg3FetchTracedAsyncCursor(Psycopg3TracedAsyncCursor, dbapi_async.FetchTracedAsyncCursor): - """Psycopg3FetchTracedAsyncCursor for psycopg""" + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/psycopg/connection.py b/ddtrace/contrib/psycopg/connection.py index 62647744d12..556b39a2e1e 100644 --- a/ddtrace/contrib/psycopg/connection.py +++ b/ddtrace/contrib/psycopg/connection.py @@ -1,110 +1,15 @@ -from ddtrace import Pin -from ddtrace import config -from ddtrace.constants import SPAN_KIND -from ddtrace.constants import SPAN_MEASURED_KEY -from ddtrace.contrib import dbapi -from ddtrace.contrib.psycopg.cursor import Psycopg2FetchTracedCursor -from ddtrace.contrib.psycopg.cursor import Psycopg2TracedCursor -from ddtrace.contrib.psycopg.cursor import Psycopg3FetchTracedCursor -from ddtrace.contrib.psycopg.cursor import Psycopg3TracedCursor -from ddtrace.contrib.psycopg.extensions import _patch_extensions -from ddtrace.contrib.trace_utils import ext_service -from ddtrace.ext import SpanKind -from ddtrace.ext import SpanTypes -from ddtrace.ext import db -from ddtrace.ext import net -from ddtrace.ext import sql -from ddtrace.internal.constants import COMPONENT +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate +from ..internal.psycopg.connection import * # noqa: F401,F403 -class Psycopg3TracedConnection(dbapi.TracedConnection): - def __init__(self, conn, pin=None, cursor_cls=None): - if not cursor_cls: - # Do not trace `fetch*` methods by default - cursor_cls = Psycopg3FetchTracedCursor if config.psycopg.trace_fetch_methods else Psycopg3TracedCursor - super(Psycopg3TracedConnection, self).__init__(conn, pin, config.psycopg, cursor_cls=cursor_cls) +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) - def execute(self, *args, **kwargs): - """Execute a query and return a cursor to read its results.""" - - def patched_execute(*args, **kwargs): - try: - cur = self.cursor() - if kwargs.get("binary", None): - cur.format = 1 # set to 1 for binary or 0 if not - return cur.execute(*args, **kwargs) - except Exception as ex: - raise ex.with_traceback(None) - - return patched_execute(*args, **kwargs) - - -class Psycopg2TracedConnection(dbapi.TracedConnection): - """TracedConnection wraps a Connection with tracing code.""" - - def __init__(self, conn, pin=None, cursor_cls=None): - if not cursor_cls: - # Do not trace `fetch*` methods by default - cursor_cls = Psycopg2FetchTracedCursor if config.psycopg.trace_fetch_methods else Psycopg2TracedCursor - - super(Psycopg2TracedConnection, self).__init__(conn, pin, config.psycopg, cursor_cls=cursor_cls) - - -def patch_conn(conn, traced_conn_cls, pin=None): - """Wrap will patch the instance so that its queries are traced.""" - # ensure we've patched extensions (this is idempotent) in - # case we're only tracing some connections. - _config = None - if pin: - extensions_to_patch = pin._config.get("_extensions_to_patch", None) - _config = pin._config - if extensions_to_patch: - _patch_extensions(extensions_to_patch) - - c = traced_conn_cls(conn) - - # if the connection has an info attr, we are using psycopg3 - if hasattr(conn, "dsn"): - dsn = sql.parse_pg_dsn(conn.dsn) - else: - dsn = sql.parse_pg_dsn(conn.info.dsn) - - tags = { - net.TARGET_HOST: dsn.get("host"), - net.TARGET_PORT: dsn.get("port", 5432), - net.SERVER_ADDRESS: dsn.get("host"), - db.NAME: dsn.get("dbname"), - db.USER: dsn.get("user"), - "db.application": dsn.get("application_name"), - db.SYSTEM: "postgresql", - } - Pin(tags=tags, _config=_config).onto(c) - return c - - -def patched_connect_factory(psycopg_module): - def patched_connect(connect_func, _, args, kwargs): - traced_conn_cls = Psycopg3TracedConnection if psycopg_module.__name__ == "psycopg" else Psycopg2TracedConnection - - pin = Pin.get_from(psycopg_module) - - if not pin or not pin.enabled() or not pin._config.trace_connect: - conn = connect_func(*args, **kwargs) - else: - with pin.tracer.trace( - "{}.{}".format(connect_func.__module__, connect_func.__name__), - service=ext_service(pin, pin._config), - span_type=SpanTypes.SQL, - ) as span: - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - span.set_tag_str(COMPONENT, pin._config.integration_name) - if span.get_tag(db.SYSTEM) is None: - span.set_tag_str(db.SYSTEM, pin._config.dbms_name) - - span.set_tag(SPAN_MEASURED_KEY) - conn = connect_func(*args, **kwargs) - - return patch_conn(conn, pin=pin, traced_conn_cls=traced_conn_cls) - - return patched_connect + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/psycopg/cursor.py b/ddtrace/contrib/psycopg/cursor.py index 6596b558cd3..ba496677b26 100644 --- a/ddtrace/contrib/psycopg/cursor.py +++ b/ddtrace/contrib/psycopg/cursor.py @@ -1,28 +1,15 @@ -from ddtrace.contrib import dbapi +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate +from ..internal.psycopg.cursor import * # noqa: F401,F403 -class Psycopg3TracedCursor(dbapi.TracedCursor): - """TracedCursor for psycopg instances""" - def __init__(self, cursor, pin, cfg, *args, **kwargs): - super(Psycopg3TracedCursor, self).__init__(cursor, pin, cfg) +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) - def _trace_method(self, method, name, resource, extra_tags, dbm_propagator, *args, **kwargs): - # treat Composable resource objects as strings - if resource.__class__.__name__ == "SQL" or resource.__class__.__name__ == "Composed": - resource = resource.as_string(self.__wrapped__) - return super(Psycopg3TracedCursor, self)._trace_method( - method, name, resource, extra_tags, dbm_propagator, *args, **kwargs - ) - - -class Psycopg3FetchTracedCursor(Psycopg3TracedCursor, dbapi.FetchTracedCursor): - """Psycopg3FetchTracedCursor for psycopg""" - - -class Psycopg2TracedCursor(Psycopg3TracedCursor): - """TracedCursor for psycopg2""" - - -class Psycopg2FetchTracedCursor(Psycopg3FetchTracedCursor): - """FetchTracedCursor for psycopg2""" + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/psycopg/extensions.py b/ddtrace/contrib/psycopg/extensions.py index 746f7a6b77d..62ae3d2d3bd 100644 --- a/ddtrace/contrib/psycopg/extensions.py +++ b/ddtrace/contrib/psycopg/extensions.py @@ -1,181 +1,15 @@ -""" -Tracing utilities for the psycopg2 potgres client library. -""" -import functools +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate -from ddtrace import config -from ddtrace.internal.constants import COMPONENT -from ddtrace.internal.schema import schematize_database_operation -from ddtrace.vendor import wrapt +from ..internal.psycopg.extensions import * # noqa: F401,F403 -from ...constants import SPAN_KIND -from ...constants import SPAN_MEASURED_KEY -from ...ext import SpanKind -from ...ext import SpanTypes -from ...ext import db -from ...ext import net +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) -def get_psycopg2_extensions(psycopg_module): - class TracedCursor(psycopg_module.extensions.cursor): - """Wrapper around cursor creating one span per query""" - - def __init__(self, *args, **kwargs): - self._datadog_tracer = kwargs.pop("datadog_tracer", None) - self._datadog_service = kwargs.pop("datadog_service", None) - self._datadog_tags = kwargs.pop("datadog_tags", None) - super(TracedCursor, self).__init__(*args, **kwargs) - - def execute(self, query, vars=None): # noqa: A002 - """just wrap the cursor execution in a span""" - if not self._datadog_tracer: - return psycopg_module.extensions.cursor.execute(self, query, vars) - - with self._datadog_tracer.trace( - schematize_database_operation("postgres.query", database_provider="postgresql"), - service=self._datadog_service, - span_type=SpanTypes.SQL, - ) as s: - s.set_tag_str(COMPONENT, config.psycopg.integration_name) - s.set_tag_str(db.SYSTEM, config.psycopg.dbms_name) - - # set span.kind to the type of operation being performed - s.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - - s.set_tag(SPAN_MEASURED_KEY) - if s.context.sampling_priority is None or s.context.sampling_priority <= 0: - return super(TracedCursor, self).execute(query, vars) - - s.resource = query - s.set_tags(self._datadog_tags) - try: - return super(TracedCursor, self).execute(query, vars) - finally: - s.set_metric(db.ROWCOUNT, self.rowcount) - - def callproc(self, procname, vars=None): # noqa: A002 - """just wrap the execution in a span""" - return psycopg_module.extensions.cursor.callproc(self, procname, vars) - - class TracedConnection(psycopg_module.extensions.connection): - """Wrapper around psycopg2 for tracing""" - - def __init__(self, *args, **kwargs): - self._datadog_tracer = kwargs.pop("datadog_tracer", None) - self._datadog_service = kwargs.pop("datadog_service", None) - - super(TracedConnection, self).__init__(*args, **kwargs) - - # add metadata (from the connection, string, etc) - dsn = psycopg_module.extensions.parse_dsn(self.dsn) - self._datadog_tags = { - net.TARGET_HOST: dsn.get("host"), - net.TARGET_PORT: dsn.get("port"), - net.SERVER_ADDRESS: dsn.get("host"), - db.NAME: dsn.get("dbname"), - db.USER: dsn.get("user"), - db.SYSTEM: config.psycopg.dbms_name, - "db.application": dsn.get("application_name"), - } - - self._datadog_cursor_class = functools.partial( - TracedCursor, - datadog_tracer=self._datadog_tracer, - datadog_service=self._datadog_service, - datadog_tags=self._datadog_tags, - ) - - def cursor(self, *args, **kwargs): - """register our custom cursor factory""" - kwargs.setdefault("cursor_factory", self._datadog_cursor_class) - return super(TracedConnection, self).cursor(*args, **kwargs) - - # extension hooks - _extensions = [ - ( - psycopg_module.extensions.register_type, - psycopg_module.extensions, - "register_type", - _extensions_register_type, - ), - (psycopg_module._psycopg.register_type, psycopg_module._psycopg, "register_type", _extensions_register_type), - (psycopg_module.extensions.adapt, psycopg_module.extensions, "adapt", _extensions_adapt), - ] - - # `_json` attribute is only available for psycopg >= 2.5 - if getattr(psycopg_module, "_json", None): - _extensions += [ - (psycopg_module._json.register_type, psycopg_module._json, "register_type", _extensions_register_type), - ] - - # `quote_ident` attribute is only available for psycopg >= 2.7 - if getattr(psycopg_module, "extensions", None) and getattr(psycopg_module.extensions, "quote_ident", None): - _extensions += [ - (psycopg_module.extensions.quote_ident, psycopg_module.extensions, "quote_ident", _extensions_quote_ident), - ] - - return _extensions - - -def _extensions_register_type(func, _, args, kwargs): - def _unroll_args(obj, scope=None): - return obj, scope - - obj, scope = _unroll_args(*args, **kwargs) - - # register_type performs a c-level check of the object - # type so we must be sure to pass in the actual db connection - if scope and isinstance(scope, wrapt.ObjectProxy): - scope = scope.__wrapped__ - - return func(obj, scope) if scope else func(obj) - - -def _extensions_quote_ident(func, _, args, kwargs): - def _unroll_args(obj, scope=None): - return obj, scope - - obj, scope = _unroll_args(*args, **kwargs) - - # register_type performs a c-level check of the object - # type so we must be sure to pass in the actual db connection - if scope and isinstance(scope, wrapt.ObjectProxy): - scope = scope.__wrapped__ - - return func(obj, scope) if scope else func(obj) - - -def _extensions_adapt(func, _, args, kwargs): - adapt = func(*args, **kwargs) - if hasattr(adapt, "prepare"): - return AdapterWrapper(adapt) - return adapt - - -class AdapterWrapper(wrapt.ObjectProxy): - def prepare(self, *args, **kwargs): - func = self.__wrapped__.prepare - if not args: - return func(*args, **kwargs) - conn = args[0] - - # prepare performs a c-level check of the object type so - # we must be sure to pass in the actual db connection - if isinstance(conn, wrapt.ObjectProxy): - conn = conn.__wrapped__ - - return func(conn, *args[1:], **kwargs) - - -def _patch_extensions(_extensions): - # we must patch extensions all the time (it's pretty harmless) so split - # from global patching of connections. must be idempotent. - for _, module, func, wrapper in _extensions: - if not hasattr(module, func) or isinstance(getattr(module, func), wrapt.ObjectProxy): - continue - wrapt.wrap_function_wrapper(module, func, wrapper) - - -def _unpatch_extensions(_extensions): - for original, module, func, _ in _extensions: - setattr(module, func, original) + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/psycopg/patch.py b/ddtrace/contrib/psycopg/patch.py index 38a70d7768f..91676a9aff5 100644 --- a/ddtrace/contrib/psycopg/patch.py +++ b/ddtrace/contrib/psycopg/patch.py @@ -1,246 +1,4 @@ -from importlib import import_module -import inspect -import os -from typing import List # noqa:F401 +from ..internal.psycopg.patch import * # noqa: F401,F403 -from ddtrace import Pin -from ddtrace import config -from ddtrace.contrib import dbapi - -try: - from ddtrace.contrib.psycopg.async_connection import patched_connect_async_factory - from ddtrace.contrib.psycopg.async_cursor import Psycopg3FetchTracedAsyncCursor - from ddtrace.contrib.psycopg.async_cursor import Psycopg3TracedAsyncCursor -# catch async function syntax errors when using Python<3.7 with no async support -except SyntaxError: - pass -from ddtrace.contrib.psycopg.connection import patched_connect_factory -from ddtrace.contrib.psycopg.cursor import Psycopg3FetchTracedCursor -from ddtrace.contrib.psycopg.cursor import Psycopg3TracedCursor -from ddtrace.contrib.psycopg.extensions import _patch_extensions -from ddtrace.contrib.psycopg.extensions import _unpatch_extensions -from ddtrace.contrib.psycopg.extensions import get_psycopg2_extensions -from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning -from ddtrace.propagation._database_monitoring import default_sql_injector as _default_sql_injector -from ddtrace.vendor.debtcollector import deprecate -from ddtrace.vendor.wrapt import wrap_function_wrapper as _w - -from ...internal.schema import schematize_database_operation -from ...internal.schema import schematize_service_name -from ...internal.utils.formats import asbool -from ...internal.utils.wrappers import unwrap as _u -from ...propagation._database_monitoring import _DBM_Propagator - - -try: - psycopg_import = import_module("psycopg") - - # must get the original connect class method from the class __dict__ to use later in unpatch - # Python 3.11 and wrapt result in the class method being rebinded as an instance method when - # using unwrap - _original_connect = psycopg_import.Connection.__dict__["connect"] - _original_async_connect = psycopg_import.AsyncConnection.__dict__["connect"] -# AttributeError can happen due to circular imports under certain integration methods -except (ImportError, AttributeError): - pass - - -def _psycopg_sql_injector(dbm_comment, sql_statement): - for psycopg_module in config.psycopg["_patched_modules"]: - if ( - hasattr(psycopg_module, "sql") - and hasattr(psycopg_module.sql, "Composable") - and isinstance(sql_statement, psycopg_module.sql.Composable) - ): - return psycopg_module.sql.SQL(dbm_comment) + sql_statement - return _default_sql_injector(dbm_comment, sql_statement) - - -config._add( - "psycopg", - dict( - _default_service=schematize_service_name("postgres"), - _dbapi_span_name_prefix="postgres", - _dbapi_span_operation_name=schematize_database_operation("postgres.query", database_provider="postgresql"), - _patched_modules=set(), - trace_fetch_methods=asbool( - os.getenv("DD_PSYCOPG_TRACE_FETCH_METHODS", default=False) - or os.getenv("DD_PSYCOPG2_TRACE_FETCH_METHODS", default=False) - ), - trace_connect=asbool( - os.getenv("DD_PSYCOPG_TRACE_CONNECT", default=False) - or os.getenv("DD_PSYCOPG2_TRACE_CONNECT", default=False) - ), - _dbm_propagator=_DBM_Propagator(0, "query", _psycopg_sql_injector), - dbms_name="postgresql", - ), -) - - -def _get_version(): - # type: () -> str - return "" - - -def get_version(): - deprecate( - "get_version is deprecated", - message="get_version is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _get_version() - - -PATCHED_VERSIONS = {} - - -def _get_versions(): - # type: () -> List[str] - return PATCHED_VERSIONS - - -def get_versions(): - deprecate( - "get_versions is deprecated", - message="get_versions is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _get_versions() - - -def _psycopg_modules(): - module_names = ( - "psycopg", - "psycopg2", - ) - for module_name in module_names: - try: - module = import_module(module_name) - PATCHED_VERSIONS[module_name] = getattr(module, "__version__", "") - yield module - except ImportError: - pass - - -def patch(): - for psycopg_module in _psycopg_modules(): - _patch(psycopg_module) - - -def _patch(psycopg_module): - """Patch monkey patches psycopg's connection function - so that the connection's functions are traced. - """ - if getattr(psycopg_module, "_datadog_patch", False): - return - psycopg_module._datadog_patch = True - - Pin(_config=config.psycopg).onto(psycopg_module) - - if psycopg_module.__name__ == "psycopg2": - # patch all psycopg2 extensions - _psycopg2_extensions = get_psycopg2_extensions(psycopg_module) - config.psycopg["_extensions_to_patch"] = _psycopg2_extensions - _patch_extensions(_psycopg2_extensions) - - _w(psycopg_module, "connect", patched_connect_factory(psycopg_module)) - - config.psycopg["_patched_modules"].add(psycopg_module) - else: - _w(psycopg_module, "connect", patched_connect_factory(psycopg_module)) - _w(psycopg_module, "Cursor", _init_cursor_from_connection_factory(psycopg_module)) - _w(psycopg_module, "AsyncCursor", _init_cursor_from_connection_factory(psycopg_module)) - - _w(psycopg_module.Connection, "connect", patched_connect_factory(psycopg_module)) - _w(psycopg_module.AsyncConnection, "connect", patched_connect_async_factory(psycopg_module)) - - config.psycopg["_patched_modules"].add(psycopg_module) - - -def unpatch(): - for psycopg_module in _psycopg_modules(): - _unpatch(psycopg_module) - - -def _unpatch(psycopg_module): - if getattr(psycopg_module, "_datadog_patch", False): - psycopg_module._datadog_patch = False - - if psycopg_module.__name__ == "psycopg2": - _u(psycopg_module, "connect") - - _psycopg2_extensions = get_psycopg2_extensions(psycopg_module) - _unpatch_extensions(_psycopg2_extensions) - else: - _u(psycopg_module, "connect") - _u(psycopg_module, "Cursor") - _u(psycopg_module, "AsyncCursor") - - # _u throws an attribute error for Python 3.11, no __get__ on the BoundFunctionWrapper - # unlike Python Class Methods which implement __get__ - psycopg_module.Connection.connect = _original_connect - psycopg_module.AsyncConnection.connect = _original_async_connect - - pin = Pin.get_from(psycopg_module) - if pin: - pin.remove_from(psycopg_module) - - -def _init_cursor_from_connection_factory(psycopg_module): - def init_cursor_from_connection(wrapped_cursor_cls, _, args, kwargs): - connection = kwargs.pop("connection", None) - if not connection: - args = list(args) - index = next((i for i, x in enumerate(args) if isinstance(x, dbapi.TracedConnection)), None) - if index is not None: - connection = args.pop(index) - - # if we do not have an example of a traced connection, call the original cursor function - if not connection: - return wrapped_cursor_cls(*args, **kwargs) - - pin = Pin.get_from(connection).clone() - cfg = config.psycopg - - if cfg and cfg.trace_fetch_methods: - trace_fetch_methods = True - else: - trace_fetch_methods = False - - if issubclass(wrapped_cursor_cls, psycopg_module.AsyncCursor): - traced_cursor_cls = Psycopg3FetchTracedAsyncCursor if trace_fetch_methods else Psycopg3TracedAsyncCursor - else: - traced_cursor_cls = Psycopg3FetchTracedCursor if trace_fetch_methods else Psycopg3TracedCursor - - args_mapping = inspect.signature(wrapped_cursor_cls.__init__).parameters - # inspect.signature returns ordered dict[argument_name: str, parameter_type: type] - if "row_factory" in args_mapping and "row_factory" not in kwargs: - # check for row_factory in args by checking for functions - row_factory = None - for i in range(len(args)): - if callable(args[i]): - row_factory = args.pop(i) - break - # else just use the connection row factory - if row_factory is None: - row_factory = connection.row_factory - cursor = wrapped_cursor_cls(connection=connection, row_factory=row_factory, *args, **kwargs) # noqa: B026 - else: - cursor = wrapped_cursor_cls(connection, *args, **kwargs) - - return traced_cursor_cls(cursor=cursor, pin=pin, cfg=cfg) - - return init_cursor_from_connection - - -def init_cursor_from_connection_factory(psycopg_module): - deprecate( - "init_cursor_from_connection_factory is deprecated", - message="init_cursor_from_connection_factory is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _init_cursor_from_connection_factory(psycopg_module) +# TODO: deprecate and remove this module diff --git a/ddtrace/contrib/pylibmc/__init__.py b/ddtrace/contrib/pylibmc/__init__.py index 2480dd1a313..c54e8c0a690 100644 --- a/ddtrace/contrib/pylibmc/__init__.py +++ b/ddtrace/contrib/pylibmc/__init__.py @@ -26,8 +26,12 @@ with require_modules(required_modules) as missing_modules: if not missing_modules: - from .client import TracedClient - from .patch import get_version - from .patch import patch + # Required to allow users to import from `ddtrace.contrib.pylibmc.patch` directly + from . import patch as _ # noqa: F401, I001 + + # Expose public methods + from ..internal.pylibmc.client import TracedClient + from ..internal.pylibmc.patch import get_version + from ..internal.pylibmc.patch import patch __all__ = ["TracedClient", "patch", "get_version"] diff --git a/ddtrace/contrib/pylibmc/addrs.py b/ddtrace/contrib/pylibmc/addrs.py index 0f11d2ac44c..40d0bfc3a51 100644 --- a/ddtrace/contrib/pylibmc/addrs.py +++ b/ddtrace/contrib/pylibmc/addrs.py @@ -1,14 +1,15 @@ -translate_server_specs = None +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate -try: - # NOTE: we rely on an undocumented method to parse addresses, - # so be a bit defensive and don't assume it exists. - from pylibmc.client import translate_server_specs -except ImportError: - pass +from ..internal.pylibmc.addrs import * # noqa: F401,F403 -def parse_addresses(addrs): - if not translate_server_specs: - return [] - return translate_server_specs(addrs) +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) + + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/pylibmc/client.py b/ddtrace/contrib/pylibmc/client.py index fd164ffbd93..ecce78b670a 100644 --- a/ddtrace/contrib/pylibmc/client.py +++ b/ddtrace/contrib/pylibmc/client.py @@ -1,193 +1,15 @@ -from contextlib import contextmanager -import random +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate -import pylibmc +from ..internal.pylibmc.client import * # noqa: F401,F403 -# project -import ddtrace -from ddtrace import config -from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY -from ddtrace.constants import SPAN_KIND -from ddtrace.constants import SPAN_MEASURED_KEY -from ddtrace.contrib.pylibmc.addrs import parse_addresses -from ddtrace.ext import SpanKind -from ddtrace.ext import SpanTypes -from ddtrace.ext import db -from ddtrace.ext import memcached -from ddtrace.ext import net -from ddtrace.internal.compat import Iterable -from ddtrace.internal.constants import COMPONENT -from ddtrace.internal.logger import get_logger -from ddtrace.internal.schema import schematize_cache_operation -from ddtrace.internal.schema import schematize_service_name -from ddtrace.vendor.wrapt import ObjectProxy +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) -# Original Client class -_Client = pylibmc.Client - - -log = get_logger(__name__) - - -class TracedClient(ObjectProxy): - """TracedClient is a proxy for a pylibmc.Client that times it's network operations.""" - - def __init__(self, client=None, service=memcached.SERVICE, tracer=None, *args, **kwargs): - """Create a traced client that wraps the given memcached client.""" - # The client instance/service/tracer attributes are kept for compatibility - # with the old interface: TracedClient(client=pylibmc.Client(['localhost:11211'])) - # TODO(Benjamin): Remove these in favor of patching. - if not isinstance(client, _Client): - # We are in the patched situation, just pass down all arguments to the pylibmc.Client - # Note that, in that case, client isn't a real client (just the first argument) - client = _Client(client, *args, **kwargs) - else: - log.warning( - "TracedClient instantiation is deprecated and will be remove " - "in future versions (0.6.0). Use patching instead (see the docs)." - ) - - super(TracedClient, self).__init__(client) - - schematized_service = schematize_service_name(service) - pin = ddtrace.Pin(service=schematized_service, tracer=tracer) - pin.onto(self) - - # attempt to collect the pool of urls this client talks to - try: - self._addresses = parse_addresses(client.addresses) - except Exception: - log.debug("error setting addresses", exc_info=True) - - def clone(self, *args, **kwargs): - # rewrap new connections. - cloned = self.__wrapped__.clone(*args, **kwargs) - traced_client = TracedClient(cloned) - pin = ddtrace.Pin.get_from(self) - if pin: - pin.clone().onto(traced_client) - return traced_client - - def add(self, *args, **kwargs): - return self._trace_cmd("add", *args, **kwargs) - - def get(self, *args, **kwargs): - return self._trace_cmd("get", *args, **kwargs) - - def set(self, *args, **kwargs): - return self._trace_cmd("set", *args, **kwargs) - - def delete(self, *args, **kwargs): - return self._trace_cmd("delete", *args, **kwargs) - - def gets(self, *args, **kwargs): - return self._trace_cmd("gets", *args, **kwargs) - - def touch(self, *args, **kwargs): - return self._trace_cmd("touch", *args, **kwargs) - - def cas(self, *args, **kwargs): - return self._trace_cmd("cas", *args, **kwargs) - - def incr(self, *args, **kwargs): - return self._trace_cmd("incr", *args, **kwargs) - - def decr(self, *args, **kwargs): - return self._trace_cmd("decr", *args, **kwargs) - - def append(self, *args, **kwargs): - return self._trace_cmd("append", *args, **kwargs) - - def prepend(self, *args, **kwargs): - return self._trace_cmd("prepend", *args, **kwargs) - - def get_multi(self, *args, **kwargs): - return self._trace_multi_cmd("get_multi", *args, **kwargs) - - def set_multi(self, *args, **kwargs): - return self._trace_multi_cmd("set_multi", *args, **kwargs) - - def delete_multi(self, *args, **kwargs): - return self._trace_multi_cmd("delete_multi", *args, **kwargs) - - def _trace_cmd(self, method_name, *args, **kwargs): - """trace the execution of the method with the given name and will - patch the first arg. - """ - method = getattr(self.__wrapped__, method_name) - with self._span(method_name) as span: - result = method(*args, **kwargs) - if span is None: - return result - - if args: - span.set_tag_str(memcached.QUERY, "%s %s" % (method_name, args[0])) - if method_name == "get": - span.set_metric(db.ROWCOUNT, 1 if result else 0) - elif method_name == "gets": - # returns a tuple object that may be (None, None) - span.set_metric(db.ROWCOUNT, 1 if isinstance(result, Iterable) and len(result) > 0 and result[0] else 0) - return result - - def _trace_multi_cmd(self, method_name, *args, **kwargs): - """trace the execution of the multi command with the given name.""" - method = getattr(self.__wrapped__, method_name) - with self._span(method_name) as span: - result = method(*args, **kwargs) - if span is None: - return result - - pre = kwargs.get("key_prefix") - if pre: - span.set_tag_str(memcached.QUERY, "%s %s" % (method_name, pre)) - - if method_name == "get_multi": - # returns mapping of key -> value if key exists, but does not include a missing key. Empty result = {} - span.set_metric( - db.ROWCOUNT, sum(1 for doc in result if doc) if result and isinstance(result, Iterable) else 0 - ) - return result - - @contextmanager - def _no_span(self): - yield None - - def _span(self, cmd_name): - """Return a span timing the given command.""" - pin = ddtrace.Pin.get_from(self) - if not pin or not pin.enabled(): - return self._no_span() - - span = pin.tracer.trace( - schematize_cache_operation("memcached.cmd", cache_provider="memcached"), - service=pin.service, - resource=cmd_name, - span_type=SpanTypes.CACHE, - ) - - span.set_tag_str(COMPONENT, config.pylibmc.integration_name) - span.set_tag_str(db.SYSTEM, memcached.DBMS_NAME) - - # set span.kind to the type of operation being performed - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - - span.set_tag(SPAN_MEASURED_KEY) - - try: - self._tag_span(span) - except Exception: - log.debug("error tagging span", exc_info=True) - return span - - def _tag_span(self, span): - # FIXME[matt] the host selection is buried in c code. we can't tell what it's actually - # using, so fallback to randomly choosing one. can we do better? - if self._addresses: - _, host, port, _ = random.choice(self._addresses) # nosec - span.set_tag_str(net.TARGET_HOST, host) - span.set_tag(net.TARGET_PORT, port) - span.set_tag_str(net.SERVER_ADDRESS, host) - - # set analytics sample rate - span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.pylibmc.get_analytics_sample_rate()) + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/pylibmc/patch.py b/ddtrace/contrib/pylibmc/patch.py index faa77b65c95..b1934884365 100644 --- a/ddtrace/contrib/pylibmc/patch.py +++ b/ddtrace/contrib/pylibmc/patch.py @@ -1,33 +1,4 @@ -import pylibmc +from ..internal.pylibmc.patch import * # noqa: F401,F403 -from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning -from ddtrace.vendor.debtcollector import deprecate -from .client import TracedClient - - -# Original Client class -_Client = pylibmc.Client - - -def _get_version(): - # type: () -> str - return getattr(pylibmc, "__version__", "") - - -def get_version(): - deprecate( - "get_version is deprecated", - message="get_version is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _get_version() - - -def patch(): - pylibmc.Client = TracedClient - - -def unpatch(): - pylibmc.Client = _Client +# TODO: deprecate and remove this module diff --git a/ddtrace/contrib/pymemcache/__init__.py b/ddtrace/contrib/pymemcache/__init__.py index 25f93549652..647fb092e63 100644 --- a/ddtrace/contrib/pymemcache/__init__.py +++ b/ddtrace/contrib/pymemcache/__init__.py @@ -37,8 +37,12 @@ with require_modules(required_modules) as missing_modules: if not missing_modules: - from .patch import get_version - from .patch import patch - from .patch import unpatch + # Required to allow users to import from `ddtrace.contrib.pymemcache.patch` directly + from . import patch as _ # noqa: F401, I001 + + # Expose public methods + from ..internal.pymemcache.patch import get_version + from ..internal.pymemcache.patch import patch + from ..internal.pymemcache.patch import unpatch __all__ = ["patch", "unpatch", "get_version"] diff --git a/ddtrace/contrib/pymemcache/client.py b/ddtrace/contrib/pymemcache/client.py index d4f037e6160..84a5be9f31c 100644 --- a/ddtrace/contrib/pymemcache/client.py +++ b/ddtrace/contrib/pymemcache/client.py @@ -1,362 +1,15 @@ -import os -import sys -from typing import Iterable +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate -import pymemcache -from pymemcache.client.base import Client -from pymemcache.client.base import PooledClient -from pymemcache.client.hash import HashClient -from pymemcache.exceptions import MemcacheClientError -from pymemcache.exceptions import MemcacheIllegalInputError -from pymemcache.exceptions import MemcacheServerError -from pymemcache.exceptions import MemcacheUnknownCommandError -from pymemcache.exceptions import MemcacheUnknownError +from ..internal.pymemcache.client import * # noqa: F401,F403 -# 3p -from ddtrace import config -from ddtrace.internal.constants import COMPONENT -from ddtrace.vendor import wrapt -# project -from ...constants import ANALYTICS_SAMPLE_RATE_KEY -from ...constants import SPAN_KIND -from ...constants import SPAN_MEASURED_KEY -from ...ext import SpanKind -from ...ext import SpanTypes -from ...ext import db -from ...ext import memcached as memcachedx -from ...ext import net -from ...internal.logger import get_logger -from ...internal.schema import schematize_cache_operation -from ...internal.utils.formats import asbool -from ...pin import Pin +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) - -log = get_logger(__name__) - - -config._add( - "pymemcache", - { - "command_enabled": asbool(os.getenv("DD_TRACE_MEMCACHED_COMMAND_ENABLED", default=False)), - }, -) - - -# keep a reference to the original unpatched clients -_Client = Client -_HashClient = HashClient - - -class _WrapperBase(wrapt.ObjectProxy): - def __init__(self, wrapped_class, *args, **kwargs): - c = wrapped_class(*args, **kwargs) - super(_WrapperBase, self).__init__(c) - - # tags to apply to each span generated by this client - tags = _get_address_tags(*args, **kwargs) - - parent_pin = Pin.get_from(pymemcache) - - if parent_pin: - pin = parent_pin.clone(tags=tags) - else: - pin = Pin(tags=tags) - - # attach the pin onto this instance - pin.onto(self) - - def _trace_function_as_command(self, func, cmd, *args, **kwargs): - p = Pin.get_from(self) - - if not p or not p.enabled(): - return func(*args, **kwargs) - - return _trace(func, p, cmd, *args, **kwargs) - - -class WrappedClient(_WrapperBase): - """Wrapper providing patched methods of a pymemcache Client. - - Relevant connection information is obtained during initialization and - attached to each span. - - Keys are tagged in spans for methods that act upon a key. - """ - - def __init__(self, *args, **kwargs): - super(WrappedClient, self).__init__(_Client, *args, **kwargs) - - def set(self, *args, **kwargs): - return self._traced_cmd("set", *args, **kwargs) - - def set_many(self, *args, **kwargs): - return self._traced_cmd("set_many", *args, **kwargs) - - def add(self, *args, **kwargs): - return self._traced_cmd("add", *args, **kwargs) - - def replace(self, *args, **kwargs): - return self._traced_cmd("replace", *args, **kwargs) - - def append(self, *args, **kwargs): - return self._traced_cmd("append", *args, **kwargs) - - def prepend(self, *args, **kwargs): - return self._traced_cmd("prepend", *args, **kwargs) - - def cas(self, *args, **kwargs): - return self._traced_cmd("cas", *args, **kwargs) - - def get(self, *args, **kwargs): - return self._traced_cmd("get", *args, **kwargs) - - def get_many(self, *args, **kwargs): - return self._traced_cmd("get_many", *args, **kwargs) - - def gets(self, *args, **kwargs): - return self._traced_cmd("gets", *args, **kwargs) - - def gets_many(self, *args, **kwargs): - return self._traced_cmd("gets_many", *args, **kwargs) - - def delete(self, *args, **kwargs): - return self._traced_cmd("delete", *args, **kwargs) - - def delete_many(self, *args, **kwargs): - return self._traced_cmd("delete_many", *args, **kwargs) - - def incr(self, *args, **kwargs): - return self._traced_cmd("incr", *args, **kwargs) - - def decr(self, *args, **kwargs): - return self._traced_cmd("decr", *args, **kwargs) - - def touch(self, *args, **kwargs): - return self._traced_cmd("touch", *args, **kwargs) - - def stats(self, *args, **kwargs): - return self._traced_cmd("stats", *args, **kwargs) - - def version(self, *args, **kwargs): - return self._traced_cmd("version", *args, **kwargs) - - def flush_all(self, *args, **kwargs): - return self._traced_cmd("flush_all", *args, **kwargs) - - def quit(self, *args, **kwargs): - return self._traced_cmd("quit", *args, **kwargs) - - def set_multi(self, *args, **kwargs): - """set_multi is an alias for set_many""" - return self._traced_cmd("set_many", *args, **kwargs) - - def get_multi(self, *args, **kwargs): - """set_multi is an alias for set_many""" - return self._traced_cmd("get_many", *args, **kwargs) - - def _traced_cmd(self, command, *args, **kwargs): - return self._trace_function_as_command( - lambda *_args, **_kwargs: getattr(self.__wrapped__, command)(*_args, **_kwargs), command, *args, **kwargs - ) - - -class WrappedHashClient(_WrapperBase): - """Wrapper that traces HashClient commands - - This wrapper proxies its command invocations to the underlying HashClient instance. - When the use_pooling setting is in use, this wrapper starts a span before - doing the proxy call. - - This is necessary because the use_pooling setting causes Client instances to be - created and destroyed dynamically in a manner that isn't affected by the - patch() function. - """ - - def _ensure_traced(self, cmd, key, default_val, *args, **kwargs): - """ - PooledClient creates Client instances dynamically on request, which means - those Client instances aren't affected by the wrappers applied in patch(). - We handle this case here by calling trace() before running the command, - specifically when the client that will be used for the command is a - PooledClient. - - To avoid double-tracing when the key's client is not a PooledClient, we - don't create a span and instead rely on patch(). In this case the - underlying Client instance is long-lived and has been patched already. - """ - client_for_key = self._get_client(key) - if isinstance(client_for_key, PooledClient): - return self._traced_cmd(cmd, client_for_key, key, default_val, *args, **kwargs) - else: - return getattr(self.__wrapped__, cmd)(key, *args, **kwargs) - - def __init__(self, *args, **kwargs): - super(WrappedHashClient, self).__init__(_HashClient, *args, **kwargs) - - def set(self, key, *args, **kwargs): - return self._ensure_traced("set", key, False, *args, **kwargs) - - def add(self, key, *args, **kwargs): - return self._ensure_traced("add", key, False, *args, **kwargs) - - def replace(self, key, *args, **kwargs): - return self._ensure_traced("replace", key, False, *args, **kwargs) - - def append(self, key, *args, **kwargs): - return self._ensure_traced("append", key, False, *args, **kwargs) - - def prepend(self, key, *args, **kwargs): - return self._ensure_traced("prepend", key, False, *args, **kwargs) - - def cas(self, key, *args, **kwargs): - return self._ensure_traced("cas", key, False, *args, **kwargs) - - def get(self, key, *args, **kwargs): - return self._ensure_traced("get", key, None, *args, **kwargs) - - def gets(self, key, *args, **kwargs): - return self._ensure_traced("gets", key, None, *args, **kwargs) - - def delete(self, key, *args, **kwargs): - return self._ensure_traced("delete", key, False, *args, **kwargs) - - def incr(self, key, *args, **kwargs): - return self._ensure_traced("incr", key, False, *args, **kwargs) - - def decr(self, key, *args, **kwargs): - return self._ensure_traced("decr", key, False, *args, **kwargs) - - def touch(self, key, *args, **kwargs): - return self._ensure_traced("touch", key, False, *args, **kwargs) - - def _traced_cmd(self, command, client, key, default_val, *args, **kwargs): - # NB this function mimics the logic of HashClient._run_cmd, tracing the call to _safely_run_func - if client is None: - return default_val - - args = list(args) - args.insert(0, key) - - return self._trace_function_as_command( - lambda *_args, **_kwargs: self._safely_run_func( - client, getattr(client, command), default_val, *_args, **_kwargs - ), - command, - *args, - **kwargs, - ) - - -_HashClient.client_class = WrappedClient - - -def _get_address_tags(*args, **kwargs): - """Attempt to get host and port from args passed to Client initializer.""" - tags = {} - try: - if len(args): - host, port = args[0] - tags[net.TARGET_HOST] = host - tags[net.TARGET_PORT] = port - tags[net.SERVER_ADDRESS] = host - except Exception: - log.debug("Error collecting client address tags") - - return tags - - -def _get_query_string(args): - """Return the query values given the arguments to a pymemcache command. - - If there are multiple query values, they are joined together - space-separated. - """ - keys = "" - - # shortcut if no args - if not args: - return keys - - # pull out the first arg which will contain any key - arg = args[0] - - # if we get a dict, convert to list of keys - if type(arg) is dict: - arg = list(arg) - - if type(arg) is str: - keys = arg - elif type(arg) is bytes: - keys = arg.decode() - elif type(arg) is list and len(arg): - if type(arg[0]) is str: - keys = " ".join(arg) - elif type(arg[0]) is bytes: - keys = b" ".join(arg).decode() - - return keys - - -def _trace(func, p, method_name, *args, **kwargs): - """Run and trace the given command. - - Any pymemcache exception is caught and span error information is - set. The exception is then reraised for the application to handle - appropriately. - - Relevant tags are set in the span. - """ - with p.tracer.trace( - schematize_cache_operation(memcachedx.CMD, cache_provider="memcached"), - service=p.service, - resource=method_name, - span_type=SpanTypes.CACHE, - ) as span: - span.set_tag_str(COMPONENT, config.pymemcache.integration_name) - span.set_tag_str(db.SYSTEM, memcachedx.DBMS_NAME) - - # set span.kind to the type of operation being performed - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - - span.set_tag(SPAN_MEASURED_KEY) - # set analytics sample rate - span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.pymemcache.get_analytics_sample_rate()) - - # try to set relevant tags, catch any exceptions so we don't mess - # with the application - try: - span.set_tags(p.tags) - if config.pymemcache.command_enabled: - vals = _get_query_string(args) - query = "{}{}{}".format(method_name, " " if vals else "", vals) - span.set_tag_str(memcachedx.QUERY, query) - except Exception: - log.debug("Error setting relevant pymemcache tags") - - try: - result = func(*args, **kwargs) - - if method_name == "get_many" or method_name == "gets_many": - # gets_many returns a map of key -> (value, cas), else an empty dict if no matches - # get many returns a map with values, else an empty map if no matches - span.set_metric( - db.ROWCOUNT, sum(1 for doc in result if doc) if result and isinstance(result, Iterable) else 0 - ) - elif method_name == "get": - # get returns key or None - span.set_metric(db.ROWCOUNT, 1 if result else 0) - elif method_name == "gets": - # gets returns a tuple of (None, None) if key not found, else tuple of (key, index) - span.set_metric(db.ROWCOUNT, 1 if result[0] else 0) - return result - except ( - MemcacheClientError, - MemcacheServerError, - MemcacheUnknownCommandError, - MemcacheUnknownError, - MemcacheIllegalInputError, - ): - (typ, val, tb) = sys.exc_info() - span.set_exc_info(typ, val, tb) - raise + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/pymemcache/patch.py b/ddtrace/contrib/pymemcache/patch.py index f53a4e013ae..00fcf02a3c1 100644 --- a/ddtrace/contrib/pymemcache/patch.py +++ b/ddtrace/contrib/pymemcache/patch.py @@ -1,61 +1,4 @@ -import pymemcache -import pymemcache.client.hash +from ..internal.pymemcache.patch import * # noqa: F401,F403 -from ddtrace.ext import memcached as memcachedx -from ddtrace.internal.schema import schematize_service_name -from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning -from ddtrace.pin import _DD_PIN_NAME -from ddtrace.pin import _DD_PIN_PROXY_NAME -from ddtrace.pin import Pin -from ddtrace.vendor.debtcollector import deprecate -from .client import WrappedClient -from .client import WrappedHashClient - - -_Client = pymemcache.client.base.Client -_hash_Client = pymemcache.client.hash.Client -_hash_HashClient = pymemcache.client.hash.Client - - -def _get_version(): - # type: () -> str - return getattr(pymemcache, "__version__", "") - - -def get_version(): - deprecate( - "get_version is deprecated", - message="get_version is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _get_version() - - -def patch(): - if getattr(pymemcache.client, "_datadog_patch", False): - return - - pymemcache.client._datadog_patch = True - pymemcache.client.base.Client = WrappedClient - pymemcache.client.hash.Client = WrappedClient - pymemcache.client.hash.HashClient = WrappedHashClient - - # Create a global pin with default configuration for our pymemcache clients - service = schematize_service_name(memcachedx.SERVICE) - Pin(service=service).onto(pymemcache) - - -def unpatch(): - """Remove pymemcache tracing""" - if not getattr(pymemcache.client, "_datadog_patch", False): - return - pymemcache.client._datadog_patch = False - pymemcache.client.base.Client = _Client - pymemcache.client.hash.Client = _hash_Client - pymemcache.client.hash.HashClient = _hash_HashClient - - # Remove any pins that may exist on the pymemcache reference - setattr(pymemcache, _DD_PIN_NAME, None) - setattr(pymemcache, _DD_PIN_PROXY_NAME, None) +# TODO: deprecate and remove this module diff --git a/ddtrace/contrib/pymongo/__init__.py b/ddtrace/contrib/pymongo/__init__.py index c653dce3335..04f206a6e1b 100644 --- a/ddtrace/contrib/pymongo/__init__.py +++ b/ddtrace/contrib/pymongo/__init__.py @@ -42,7 +42,11 @@ with require_modules(required_modules) as missing_modules: if not missing_modules: - from .patch import get_version - from .patch import patch + # Required to allow users to import from `ddtrace.contrib.pymongo.patch` directly + from . import patch as _ # noqa: F401, I001 + + # Expose public methods + from ..internal.pymongo.patch import get_version + from ..internal.pymongo.patch import patch __all__ = ["patch", "get_version"] diff --git a/ddtrace/contrib/pymongo/client.py b/ddtrace/contrib/pymongo/client.py index be2c217834b..64fa5cd0190 100644 --- a/ddtrace/contrib/pymongo/client.py +++ b/ddtrace/contrib/pymongo/client.py @@ -1,372 +1,15 @@ -# stdlib -import contextlib -import json -from typing import Iterable +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate -# 3p -import pymongo +from ..internal.pymongo.client import * # noqa: F401,F403 -# project -import ddtrace -from ddtrace import config -from ddtrace.internal.constants import COMPONENT -from ddtrace.vendor.wrapt import ObjectProxy -from ...constants import ANALYTICS_SAMPLE_RATE_KEY -from ...constants import SPAN_KIND -from ...constants import SPAN_MEASURED_KEY -from ...ext import SpanKind -from ...ext import SpanTypes -from ...ext import db -from ...ext import mongo as mongox -from ...ext import net as netx -from ...internal.logger import get_logger -from ...internal.schema import schematize_database_operation -from ...internal.schema import schematize_service_name -from ...internal.utils import get_argument_value -from .parse import parse_msg -from .parse import parse_query -from .parse import parse_spec +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) - -BATCH_PARTIAL_KEY = "Batch" - -# Original Client class -_MongoClient = pymongo.MongoClient - -VERSION = pymongo.version_tuple - -if VERSION < (3, 6, 0): - from pymongo.helpers import _unpack_response - - -log = get_logger(__name__) - -_DEFAULT_SERVICE = schematize_service_name("pymongo") - - -class TracedMongoClient(ObjectProxy): - def __init__(self, client=None, *args, **kwargs): - # To support the former trace_mongo_client interface, we have to keep this old interface - # TODO(Benjamin): drop it in a later version - if not isinstance(client, _MongoClient): - # Patched interface, instantiate the client - - # client is just the first arg which could be the host if it is - # None, then it could be that the caller: - - # if client is None then __init__ was: - # 1) invoked with host=None - # 2) not given a first argument (client defaults to None) - # we cannot tell which case it is, but it should not matter since - # the default value for host is None, in either case we can simply - # not provide it as an argument - if client is None: - client = _MongoClient(*args, **kwargs) - # else client is a value for host so just pass it along - else: - client = _MongoClient(client, *args, **kwargs) - - super(TracedMongoClient, self).__init__(client) - client._datadog_proxy = self - # NOTE[matt] the TracedMongoClient attempts to trace all of the network - # calls in the trace library. This is good because it measures the - # actual network time. It's bad because it uses a private API which - # could change. We'll see how this goes. - if not isinstance(client._topology, TracedTopology): - client._topology = TracedTopology(client._topology) - - # Default Pin - ddtrace.Pin(service=_DEFAULT_SERVICE).onto(self) - - def __setddpin__(self, pin): - pin.onto(self._topology) - - def __getddpin__(self): - return ddtrace.Pin.get_from(self._topology) - - -@contextlib.contextmanager -def wrapped_validate_session(wrapped, instance, args, kwargs): - # We do this to handle a validation `A is B` in pymongo that - # relies on IDs being equal. Since we are proxying objects, we need - # to ensure we're compare proxy with proxy or wrapped with wrapped - # or this validation will fail - client = args[0] - session = args[1] - session_client = session._client - if isinstance(session_client, TracedMongoClient): - if isinstance(client, _MongoClient): - client = getattr(client, "_datadog_proxy", client) - elif isinstance(session_client, _MongoClient): - if isinstance(client, TracedMongoClient): - client = client.__wrapped__ - - yield wrapped(client, session) - - -class TracedTopology(ObjectProxy): - def __init__(self, topology): - super(TracedTopology, self).__init__(topology) - - def select_server(self, *args, **kwargs): - s = self.__wrapped__.select_server(*args, **kwargs) - if not isinstance(s, TracedServer): - s = TracedServer(s) - # Reattach the pin every time in case it changed since the initial patching - ddtrace.Pin.get_from(self).onto(s) - return s - - -class TracedServer(ObjectProxy): - def __init__(self, server): - super(TracedServer, self).__init__(server) - - def _datadog_trace_operation(self, operation): - cmd = None - # Only try to parse something we think is a query. - if self._is_query(operation): - try: - cmd = parse_query(operation) - except Exception: - log.exception("error parsing query") - - pin = ddtrace.Pin.get_from(self) - # if we couldn't parse or shouldn't trace the message, just go. - if not cmd or not pin or not pin.enabled(): - return None - - span = pin.tracer.trace( - schematize_database_operation("pymongo.cmd", database_provider="mongodb"), - span_type=SpanTypes.MONGODB, - service=pin.service, - ) - - span.set_tag_str(COMPONENT, config.pymongo.integration_name) - - # set span.kind to the operation type being performed - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - - span.set_tag(SPAN_MEASURED_KEY) - span.set_tag_str(mongox.DB, cmd.db) - span.set_tag_str(mongox.COLLECTION, cmd.coll) - span.set_tag_str(db.SYSTEM, mongox.SERVICE) - span.set_tags(cmd.tags) - - # set `mongodb.query` tag and resource for span - _set_query_metadata(span, cmd) - - # set analytics sample rate - sample_rate = config.pymongo.get_analytics_sample_rate() - if sample_rate is not None: - span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, sample_rate) - return span - - if VERSION >= (4, 5, 0): - - @contextlib.contextmanager - def checkout(self, *args, **kwargs): - with self.__wrapped__.checkout(*args, **kwargs) as s: - if not isinstance(s, TracedSocket): - s = TracedSocket(s) - ddtrace.Pin.get_from(self).onto(s) - yield s - - else: - - @contextlib.contextmanager - def get_socket(self, *args, **kwargs): - with self.__wrapped__.get_socket(*args, **kwargs) as s: - if not isinstance(s, TracedSocket): - s = TracedSocket(s) - ddtrace.Pin.get_from(self).onto(s) - yield s - - if VERSION >= (3, 12, 0): - - def run_operation(self, sock_info, operation, *args, **kwargs): - span = self._datadog_trace_operation(operation) - if span is None: - return self.__wrapped__.run_operation(sock_info, operation, *args, **kwargs) - with span: - result = self.__wrapped__.run_operation(sock_info, operation, *args, **kwargs) - if result: - if hasattr(result, "address"): - set_address_tags(span, result.address) - if self._is_query(operation) and hasattr(result, "docs"): - set_query_rowcount(docs=result.docs, span=span) - return result - - elif (3, 9, 0) <= VERSION < (3, 12, 0): - - def run_operation_with_response(self, sock_info, operation, *args, **kwargs): - span = self._datadog_trace_operation(operation) - if span is None: - return self.__wrapped__.run_operation_with_response(sock_info, operation, *args, **kwargs) - with span: - result = self.__wrapped__.run_operation_with_response(sock_info, operation, *args, **kwargs) - if result: - if hasattr(result, "address"): - set_address_tags(span, result.address) - if self._is_query(operation) and hasattr(result, "docs"): - set_query_rowcount(docs=result.docs, span=span) - return result - - else: - - def send_message_with_response(self, operation, *args, **kwargs): - span = self._datadog_trace_operation(operation) - if span is None: - return self.__wrapped__.send_message_with_response(operation, *args, **kwargs) - with span: - result = self.__wrapped__.send_message_with_response(operation, *args, **kwargs) - if result: - if hasattr(result, "address"): - set_address_tags(span, result.address) - if self._is_query(operation): - if hasattr(result, "data"): - if VERSION >= (3, 6, 0) and hasattr(result.data, "unpack_response"): - set_query_rowcount(docs=result.data.unpack_response(), span=span) - else: - data = _unpack_response(response=result.data) - if VERSION < (3, 2, 0) and data.get("number_returned", None): - span.set_metric(db.ROWCOUNT, data.get("number_returned")) - elif (3, 2, 0) <= VERSION < (3, 6, 0): - docs = data.get("data", None) - set_query_rowcount(docs=docs, span=span) - return result - - @staticmethod - def _is_query(op): - # NOTE: _Query should always have a spec field - return hasattr(op, "spec") - - -class TracedSocket(ObjectProxy): - def __init__(self, socket): - super(TracedSocket, self).__init__(socket) - - def command(self, dbname, spec, *args, **kwargs): - cmd = None - try: - cmd = parse_spec(spec, dbname) - except Exception: - log.exception("error parsing spec. skipping trace") - - pin = ddtrace.Pin.get_from(self) - # skip tracing if we don't have a piece of data we need - if not dbname or not cmd or not pin or not pin.enabled(): - return self.__wrapped__.command(dbname, spec, *args, **kwargs) - - cmd.db = dbname - with self.__trace(cmd): - return self.__wrapped__.command(dbname, spec, *args, **kwargs) - - def write_command(self, *args, **kwargs): - msg = get_argument_value(args, kwargs, 1, "msg") - cmd = None - try: - cmd = parse_msg(msg) - except Exception: - log.exception("error parsing msg") - - pin = ddtrace.Pin.get_from(self) - # if we couldn't parse it, don't try to trace it. - if not cmd or not pin or not pin.enabled(): - return self.__wrapped__.write_command(*args, **kwargs) - - with self.__trace(cmd) as s: - result = self.__wrapped__.write_command(*args, **kwargs) - if result: - s.set_metric(db.ROWCOUNT, result.get("n", -1)) - return result - - def __trace(self, cmd): - pin = ddtrace.Pin.get_from(self) - s = pin.tracer.trace( - schematize_database_operation("pymongo.cmd", database_provider="mongodb"), - span_type=SpanTypes.MONGODB, - service=pin.service, - ) - - s.set_tag_str(COMPONENT, config.pymongo.integration_name) - s.set_tag_str(db.SYSTEM, mongox.SERVICE) - - # set span.kind to the type of operation being performed - s.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - - s.set_tag(SPAN_MEASURED_KEY) - if cmd.db: - s.set_tag_str(mongox.DB, cmd.db) - if cmd: - s.set_tag(mongox.COLLECTION, cmd.coll) - s.set_tags(cmd.tags) - s.set_metrics(cmd.metrics) - - # set `mongodb.query` tag and resource for span - _set_query_metadata(s, cmd) - - # set analytics sample rate - s.set_tag(ANALYTICS_SAMPLE_RATE_KEY, config.pymongo.get_analytics_sample_rate()) - - if self.address: - set_address_tags(s, self.address) - return s - - -def normalize_filter(f=None): - if f is None: - return {} - elif isinstance(f, list): - # normalize lists of filters - # e.g. {$or: [ { age: { $lt: 30 } }, { type: 1 } ]} - return [normalize_filter(s) for s in f] - elif isinstance(f, dict): - # normalize dicts of filters - # {$or: [ { age: { $lt: 30 } }, { type: 1 } ]}) - out = {} - for k, v in f.items(): - if k == "$in" or k == "$nin": - # special case $in queries so we don't loop over lists. - out[k] = "?" - elif isinstance(v, list) or isinstance(v, dict): - # RECURSION ALERT: needs to move to the agent - out[k] = normalize_filter(v) - else: - # NOTE: this shouldn't happen, but let's have a safeguard. - out[k] = "?" - return out - else: - # FIXME[matt] unexpected type. not sure this should ever happen, but at - # least it won't crash. - return {} - - -def set_address_tags(span, address): - # the address is only set after the cursor is done. - if address: - span.set_tag_str(netx.TARGET_HOST, address[0]) - span.set_tag_str(netx.SERVER_ADDRESS, address[0]) - span.set_tag(netx.TARGET_PORT, address[1]) - - -def _set_query_metadata(span, cmd): - """Sets span `mongodb.query` tag and resource given command query""" - if cmd.query: - nq = normalize_filter(cmd.query) - span.set_tag("mongodb.query", nq) - # needed to dump json so we don't get unicode - # dict keys like {u'foo':'bar'} - q = json.dumps(nq) - span.resource = "{} {} {}".format(cmd.name, cmd.coll, q) - else: - span.resource = "{} {}".format(cmd.name, cmd.coll) - - -def set_query_rowcount(docs, span): - # results returned in batches, get len of each batch - if isinstance(docs, Iterable) and len(docs) > 0: - cursor = docs[0].get("cursor", None) - if cursor: - rowcount = sum([len(documents) for batch_key, documents in cursor.items() if BATCH_PARTIAL_KEY in batch_key]) - span.set_metric(db.ROWCOUNT, rowcount) + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/pymongo/parse.py b/ddtrace/contrib/pymongo/parse.py index 1a4330d2e5d..24aae0e8418 100644 --- a/ddtrace/contrib/pymongo/parse.py +++ b/ddtrace/contrib/pymongo/parse.py @@ -1,204 +1,15 @@ -import ctypes -import struct +from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning +from ddtrace.vendor.debtcollector import deprecate -# 3p -import bson -from bson.codec_options import CodecOptions -from bson.son import SON +from ..internal.pymongo.parse import * # noqa: F401,F403 -# project -from ...ext import net as netx -from ...internal.compat import to_unicode -from ...internal.logger import get_logger +def __getattr__(name): + deprecate( + ("%s.%s is deprecated" % (__name__, name)), + category=DDTraceDeprecationWarning, + ) -log = get_logger(__name__) - - -# MongoDB wire protocol commands -# http://docs.mongodb.com/manual/reference/mongodb-wire-protocol -OP_CODES = { - 1: "reply", - 1000: "msg", # DEV: 1000 was deprecated at some point, use 2013 instead - 2001: "update", - 2002: "insert", - 2003: "reserved", - 2004: "query", - 2005: "get_more", - 2006: "delete", - 2007: "kill_cursors", - 2010: "command", - 2011: "command_reply", - 2013: "msg", -} - -# The maximum message length we'll try to parse -MAX_MSG_PARSE_LEN = 1024 * 1024 - -header_struct = struct.Struct("= 3.1 stores the db and coll separately - coll = getattr(query, "coll", None) - db = getattr(query, "db", None) - - # pymongo < 3.1 _Query does not have a name field, so default to 'query' - cmd = Command(getattr(query, "name", "query"), db, coll) - cmd.query = query.spec - return cmd - - -def parse_spec(spec, db=None): - """Return a Command that has parsed the relevant detail for the given - pymongo SON spec. - """ - - # the first element is the command and collection - items = list(spec.items()) - if not items: - return None - name, coll = items[0] - cmd = Command(name, db or spec.get("$db"), coll) - - if "ordered" in spec: # in insert and update - cmd.tags["mongodb.ordered"] = spec["ordered"] - - if cmd.name == "insert": - if "documents" in spec: - cmd.metrics["mongodb.documents"] = len(spec["documents"]) - - elif cmd.name == "update": - updates = spec.get("updates") - if updates: - # FIXME[matt] is there ever more than one here? - cmd.query = updates[0].get("q") - - elif cmd.name == "delete": - dels = spec.get("deletes") - if dels: - # FIXME[matt] is there ever more than one here? - cmd.query = dels[0].get("q") - - return cmd - - -def _cstring(raw): - """Return the first null terminated cstring from the buffer.""" - return ctypes.create_string_buffer(raw).value - - -def _split_namespace(ns): - """Return a tuple of (db, collection) from the 'db.coll' string.""" - if ns: - # NOTE[matt] ns is unicode or bytes depending on the client version - # so force cast to unicode - split = to_unicode(ns).split(".", 1) - if len(split) == 1: - raise Exception("namespace doesn't contain period: %s" % ns) - return split - return (None, None) + if name in globals(): + return globals()[name] + raise AttributeError("%s has no attribute %s", __name__, name) diff --git a/ddtrace/contrib/pymongo/patch.py b/ddtrace/contrib/pymongo/patch.py index 6cb891f0992..e0cf8589dd4 100644 --- a/ddtrace/contrib/pymongo/patch.py +++ b/ddtrace/contrib/pymongo/patch.py @@ -1,140 +1,4 @@ -import contextlib +from ..internal.pymongo.patch import * # noqa: F401,F403 -import pymongo -from ddtrace import Pin -from ddtrace import config -from ddtrace.contrib import trace_utils -from ddtrace.internal.constants import COMPONENT -from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning -from ddtrace.vendor.debtcollector import deprecate -from ddtrace.vendor.wrapt import wrap_function_wrapper as _w - -from ...constants import SPAN_KIND -from ...constants import SPAN_MEASURED_KEY -from ...ext import SpanKind -from ...ext import SpanTypes -from ...ext import db -from ...ext import mongo -from ..trace_utils import unwrap as _u -from .client import TracedMongoClient -from .client import set_address_tags -from .client import wrapped_validate_session - - -config._add( - "pymongo", - dict(_default_service="pymongo"), -) - - -def _get_version(): - # type: () -> str - return getattr(pymongo, "__version__", "") - - -def get_version(): - deprecate( - "get_version is deprecated", - message="get_version is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _get_version() - - -# Original Client class -_MongoClient = pymongo.MongoClient - -_VERSION = pymongo.version_tuple -_CHECKOUT_FN_NAME = "get_socket" if _VERSION < (4, 5) else "checkout" -_VERIFY_VERSION_CLASS = pymongo.pool.SocketInfo if _VERSION < (4, 5) else pymongo.pool.Connection - - -def patch(): - _patch_pymongo_module() - # We should progressively get rid of TracedMongoClient. We now try to - # wrap methods individually. cf #1501 - pymongo.MongoClient = TracedMongoClient - - -def unpatch(): - _unpatch_pymongo_module() - pymongo.MongoClient = _MongoClient - - -def _patch_pymongo_module(): - if getattr(pymongo, "_datadog_patch", False): - return - pymongo._datadog_patch = True - Pin().onto(pymongo.server.Server) - - # Whenever a pymongo command is invoked, the lib either: - # - Creates a new socket & performs a TCP handshake - # - Grabs a socket already initialized before - _w("pymongo.server", "Server.%s" % _CHECKOUT_FN_NAME, _traced_get_socket) - _w("pymongo.pool", f"{_VERIFY_VERSION_CLASS.__name__}.validate_session", wrapped_validate_session) - - -def patch_pymongo_module(): - deprecate( - "patch_pymongo_module is deprecated", - message="patch_pymongo_module is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _patch_pymongo_module() - - -def _unpatch_pymongo_module(): - if not getattr(pymongo, "_datadog_patch", False): - return - pymongo._datadog_patch = False - - _u(pymongo.server.Server, _CHECKOUT_FN_NAME) - _u(_VERIFY_VERSION_CLASS, "validate_session") - - -def unpatch_pymongo_module(): - deprecate( - "unpatch_pymongo_module is deprecated", - message="unpatch_pymongo_module is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _unpatch_pymongo_module() - - -@contextlib.contextmanager -def _traced_get_socket(wrapped, instance, args, kwargs): - pin = Pin._find(wrapped, instance) - if not pin or not pin.enabled(): - with wrapped(*args, **kwargs) as sock_info: - yield sock_info - return - - with pin.tracer.trace( - "pymongo.%s" % _CHECKOUT_FN_NAME, - service=trace_utils.int_service(pin, config.pymongo), - span_type=SpanTypes.MONGODB, - ) as span: - span.set_tag_str(COMPONENT, config.pymongo.integration_name) - span.set_tag_str(db.SYSTEM, mongo.SERVICE) - - # set span.kind tag equal to type of operation being performed - span.set_tag_str(SPAN_KIND, SpanKind.CLIENT) - - with wrapped(*args, **kwargs) as sock_info: - set_address_tags(span, sock_info.address) - span.set_tag(SPAN_MEASURED_KEY) - yield sock_info - - -def traced_get_socket(wrapped, instance, args, kwargs): - deprecate( - "traced_get_socket is deprecated", - message="traced_get_socket is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _traced_get_socket(wrapped, instance, args, kwargs) +# TODO: deprecate and remove this module diff --git a/ddtrace/contrib/pymysql/__init__.py b/ddtrace/contrib/pymysql/__init__.py index 43283e1ae22..b1524937194 100644 --- a/ddtrace/contrib/pymysql/__init__.py +++ b/ddtrace/contrib/pymysql/__init__.py @@ -62,7 +62,11 @@ with require_modules(required_modules) as missing_modules: if not missing_modules: - from .patch import get_version - from .patch import patch + # Required to allow users to import from `ddtrace.contrib.pymysql.patch` directly + from . import patch as _ # noqa: F401, I001 + + # Expose public methods + from ..internal.pymysql.patch import get_version + from ..internal.pymysql.patch import patch __all__ = ["patch", "get_version"] diff --git a/ddtrace/contrib/pymysql/patch.py b/ddtrace/contrib/pymysql/patch.py index 6bfca3b2763..415d70e010e 100644 --- a/ddtrace/contrib/pymysql/patch.py +++ b/ddtrace/contrib/pymysql/patch.py @@ -1,89 +1,4 @@ -import os +from ..internal.pymysql.patch import * # noqa: F401,F403 -import pymysql -from ddtrace import Pin -from ddtrace import config -from ddtrace.contrib.dbapi import TracedConnection -from ddtrace.internal.utils.deprecations import DDTraceDeprecationWarning -from ddtrace.vendor import wrapt -from ddtrace.vendor.debtcollector import deprecate - -from ...ext import db -from ...ext import net -from ...internal.schema import schematize_database_operation -from ...internal.schema import schematize_service_name -from ...internal.utils.formats import asbool -from ...propagation._database_monitoring import _DBM_Propagator -from ..trace_utils import _convert_to_string - - -config._add( - "pymysql", - dict( - _default_service=schematize_service_name("pymysql"), - _dbapi_span_name_prefix="pymysql", - _dbapi_span_operation_name=schematize_database_operation("pymysql.query", database_provider="mysql"), - trace_fetch_methods=asbool(os.getenv("DD_PYMYSQL_TRACE_FETCH_METHODS", default=False)), - _dbm_propagator=_DBM_Propagator(0, "query"), - ), -) - - -def _get_version(): - # type: () -> str - return getattr(pymysql, "__version__", "") - - -def get_version(): - deprecate( - "get_version is deprecated", - message="get_version is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _get_version() - - -CONN_ATTR_BY_TAG = { - net.TARGET_HOST: "host", - net.TARGET_PORT: "port", - net.SERVER_ADDRESS: "host", - db.USER: "user", - db.NAME: "db", -} - - -def patch(): - wrapt.wrap_function_wrapper("pymysql", "connect", _connect) - - -def unpatch(): - if isinstance(pymysql.connect, wrapt.ObjectProxy): - pymysql.connect = pymysql.connect.__wrapped__ - - -def _connect(func, instance, args, kwargs): - conn = func(*args, **kwargs) - return _patch_conn(conn) - - -def _patch_conn(conn): - tags = {t: _convert_to_string(getattr(conn, a)) for t, a in CONN_ATTR_BY_TAG.items() if getattr(conn, a, "") != ""} - tags[db.SYSTEM] = "mysql" - pin = Pin(tags=tags) - - # grab the metadata from the conn - wrapped = TracedConnection(conn, pin=pin, cfg=config.pymysql) - pin.onto(wrapped) - return wrapped - - -def patch_conn(conn): - deprecate( - "patch_conn is deprecated", - message="patch_conn is deprecated", - removal_version="3.0.0", - category=DDTraceDeprecationWarning, - ) - return _patch_conn(conn) +# TODO: deprecate and remove this module diff --git a/releasenotes/notes/move-integrations-to-internal-openai-0d4ab4241552ff94.yaml b/releasenotes/notes/move-integrations-to-internal-openai-0d4ab4241552ff94.yaml new file mode 100644 index 00000000000..50339236534 --- /dev/null +++ b/releasenotes/notes/move-integrations-to-internal-openai-0d4ab4241552ff94.yaml @@ -0,0 +1,14 @@ +--- +deprecations: + - | + openai: Deprecates all modules in the ``ddtrace.contrib.openai`` package. Use attributes exposed in ``ddtrace.contrib.openai.__all__`` instead. + - | + psycopg: Deprecates all modules in the ``ddtrace.contrib.psycopg`` package. Use attributes exposed in ``ddtrace.contrib.psycopg.__all__`` instead. + - | + pylibmc: Deprecates all modules in the ``ddtrace.contrib.pylibmc`` package. Use attributes exposed in ``ddtrace.contrib.pylibmc.__all__`` instead. + - | + pymemcache: Deprecates all modules in the ``ddtrace.contrib.pymemcache`` package. Use attributes exposed in ``ddtrace.contrib.pymemcache.__all__`` instead. + - | + pymongo: Deprecates all modules in the ``ddtrace.contrib.pymongo`` package. Use attributes exposed in ``ddtrace.contrib.pymongo.__all__`` instead. + - | + pymysql: Deprecates all modules in the ``ddtrace.contrib.pymysql`` package. Use attributes exposed in ``ddtrace.contrib.pymysql.__all__`` instead. \ No newline at end of file diff --git a/tests/.suitespec.json b/tests/.suitespec.json index 9411cbe8100..dbb7ad06443 100644 --- a/tests/.suitespec.json +++ b/tests/.suitespec.json @@ -158,6 +158,7 @@ ], "mongo": [ "ddtrace/contrib/pymongo/*", + "ddtrace/contrib/internal/pymongo/*", "ddtrace/contrib/mongoengine/*", "ddtrace/contrib/internal/mongoengine/*", "ddtrace/ext/mongo.py" @@ -166,6 +167,7 @@ "ddtrace/contrib/aiopg/*", "ddtrace/contrib/internal/aiopg/*", "ddtrace/contrib/asyncpg/*", + "ddtrace/contrib/internal/psycopg/*", "ddtrace/contrib/internal/asyncpg/*", "ddtrace/contrib/psycopg/*" ], @@ -249,6 +251,7 @@ ], "pymemcache": [ "ddtrace/contrib/pymemcache/*", + "ddtrace/contrib/internal/pymemcache/*", "ddtrace/ext/memcached.py" ], "snowflake": [ @@ -373,11 +376,13 @@ "ddtrace/contrib/mysqldb/*", "ddtrace/contrib/internal/mysqldb/*", "ddtrace/contrib/pymysql/*", + "ddtrace/contrib/internal/pymysql/*", "ddtrace/contrib/aiomysql/*", "ddtrace/contrib/internal/aiomysql/*" ], "pylibmc": [ - "ddtrace/contrib/pylibmc/*" + "ddtrace/contrib/pylibmc/*", + "ddtrace/contrib/internal/pylibmc/*" ], "logbook": [ "ddtrace/contrib/logbook/*", @@ -420,7 +425,8 @@ "ddtrace/contrib/internal/aiohttp_jinja2/*" ], "openai": [ - "ddtrace/contrib/openai/*" + "ddtrace/contrib/openai/*", + "ddtrace/contrib/internal/openai/*" ], "falcon": [ "ddtrace/contrib/falcon/*", diff --git a/tests/contrib/openai/test_openai_llmobs.py b/tests/contrib/openai/test_openai_llmobs.py index 6a79486a7d3..92f5d549a5c 100644 --- a/tests/contrib/openai/test_openai_llmobs.py +++ b/tests/contrib/openai/test_openai_llmobs.py @@ -105,7 +105,7 @@ async def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_ if not hasattr(openai, "ChatCompletion"): pytest.skip("ChatCompletion not supported for this version of openai") with get_openai_vcr(subdirectory_name="v0").use_cassette("chat_completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: model = "gpt-3.5-turbo" resp_model = model input_messages = [{"role": "user", "content": "Who won the world series in 2020?"}] @@ -182,7 +182,7 @@ def test_chat_completion_function_call_stream(self, openai, ddtrace_global_confi if not hasattr(openai, "ChatCompletion"): pytest.skip("ChatCompletion not supported for this version of openai") with get_openai_vcr(subdirectory_name="v0").use_cassette("chat_completion_function_call_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: model = "gpt-3.5-turbo" resp_model = model mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] @@ -366,8 +366,8 @@ def test_completion(self, openai, ddtrace_global_config, mock_llmobs_writer, moc def test_completion_stream(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer): with get_openai_vcr(subdirectory_name="v1").use_cassette("completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: - with mock.patch("ddtrace.contrib.openai.utils._est_tokens") as mock_est: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils._est_tokens") as mock_est: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2] mock_est.return_value = 2 model = "ada" @@ -431,8 +431,8 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation. """ with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: - with mock.patch("ddtrace.contrib.openai.utils._est_tokens") as mock_est: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils._est_tokens") as mock_est: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8] mock_est.return_value = 8 model = "gpt-3.5-turbo" diff --git a/tests/contrib/openai/test_openai_v0.py b/tests/contrib/openai/test_openai_v0.py index 268d7dc0e85..fa66c60857c 100644 --- a/tests/contrib/openai/test_openai_v0.py +++ b/tests/contrib/openai/test_openai_v0.py @@ -10,7 +10,7 @@ import ddtrace from ddtrace import patch -from ddtrace.contrib.openai.utils import _est_tokens +from ddtrace.contrib.internal.openai.utils import _est_tokens from ddtrace.internal.utils.version import parse_version from tests.contrib.openai.utils import chat_completion_custom_functions from tests.contrib.openai.utils import chat_completion_input_description @@ -1264,7 +1264,7 @@ def test_span_finish_on_stream_error(openai, openai_vcr, snapshot_tracer): def test_completion_stream(openai, openai_vcr, mock_metrics, mock_tracer): with openai_vcr.use_cassette("completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2] expected_completion = '! ... A page layouts page drawer? ... Interesting. The "Tools" is' resp = openai.Completion.create(model="ada", prompt="Hello world", stream=True) @@ -1305,7 +1305,7 @@ def test_completion_stream(openai, openai_vcr, mock_metrics, mock_tracer): @pytest.mark.asyncio async def test_completion_async_stream(openai, openai_vcr, mock_metrics, mock_tracer): with openai_vcr.use_cassette("completion_async_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2] expected_completion = "\" and just start creating stuff. Don't expect it to draw like this." resp = await openai.Completion.acreate(model="ada", prompt="Hello world", stream=True) @@ -1345,7 +1345,7 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace pytest.skip("ChatCompletion not supported for this version of openai") with openai_vcr.use_cassette("chat_completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8] expected_completion = "The Los Angeles Dodgers won the World Series in 2020." resp = openai.ChatCompletion.create( @@ -1395,7 +1395,7 @@ async def test_chat_completion_async_stream(openai, openai_vcr, mock_metrics, sn if not hasattr(openai, "ChatCompletion"): pytest.skip("ChatCompletion not supported for this version of openai") with openai_vcr.use_cassette("chat_completion_streamed_async.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] expected_completion = "As an AI language model, I do not have access to real-time information but as of the 2021 season, the captain of the Toronto Maple Leafs is John Tavares." # noqa: E501 resp = await openai.ChatCompletion.acreate( diff --git a/tests/contrib/openai/test_openai_v1.py b/tests/contrib/openai/test_openai_v1.py index 0a5c531215e..e72c47f12b1 100644 --- a/tests/contrib/openai/test_openai_v1.py +++ b/tests/contrib/openai/test_openai_v1.py @@ -6,7 +6,7 @@ import ddtrace from ddtrace import patch -from ddtrace.contrib.openai.utils import _est_tokens +from ddtrace.contrib.internal.openai.utils import _est_tokens from ddtrace.internal.utils.version import parse_version from tests.contrib.openai.utils import chat_completion_custom_functions from tests.contrib.openai.utils import chat_completion_input_description @@ -920,7 +920,7 @@ def test_span_finish_on_stream_error(openai, openai_vcr, snapshot_tracer): def test_completion_stream(openai, openai_vcr, mock_metrics, mock_tracer): with openai_vcr.use_cassette("completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2] expected_completion = '! ... A page layouts page drawer? ... Interesting. The "Tools" is' client = openai.OpenAI() @@ -958,7 +958,7 @@ def test_completion_stream(openai, openai_vcr, mock_metrics, mock_tracer): async def test_completion_async_stream(openai, openai_vcr, mock_metrics, mock_tracer): with openai_vcr.use_cassette("completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2] expected_completion = '! ... A page layouts page drawer? ... Interesting. The "Tools" is' client = openai.AsyncOpenAI() @@ -1000,7 +1000,7 @@ async def test_completion_async_stream(openai, openai_vcr, mock_metrics, mock_tr ) def test_completion_stream_context_manager(openai, openai_vcr, mock_metrics, mock_tracer): with openai_vcr.use_cassette("completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2] expected_completion = '! ... A page layouts page drawer? ... Interesting. The "Tools" is' client = openai.OpenAI() @@ -1038,7 +1038,7 @@ def test_completion_stream_context_manager(openai, openai_vcr, mock_metrics, moc def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_tracer): with openai_vcr.use_cassette("chat_completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8] expected_completion = "The Los Angeles Dodgers won the World Series in 2020." client = openai.OpenAI() @@ -1087,7 +1087,7 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace async def test_chat_completion_async_stream(openai, openai_vcr, mock_metrics, snapshot_tracer): with openai_vcr.use_cassette("chat_completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8] expected_completion = "The Los Angeles Dodgers won the World Series in 2020." client = openai.AsyncOpenAI() @@ -1139,7 +1139,7 @@ async def test_chat_completion_async_stream(openai, openai_vcr, mock_metrics, sn ) async def test_chat_completion_async_stream_context_manager(openai, openai_vcr, mock_metrics, snapshot_tracer): with openai_vcr.use_cassette("chat_completion_streamed.yaml"): - with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: + with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8] expected_completion = "The Los Angeles Dodgers won the World Series in 2020." client = openai.AsyncOpenAI() diff --git a/tests/contrib/pymongo/test.py b/tests/contrib/pymongo/test.py index 2ef8e10af7f..6c8b010d95b 100644 --- a/tests/contrib/pymongo/test.py +++ b/tests/contrib/pymongo/test.py @@ -7,8 +7,8 @@ # project from ddtrace import Pin from ddtrace.constants import ANALYTICS_SAMPLE_RATE_KEY -from ddtrace.contrib.pymongo.client import normalize_filter -from ddtrace.contrib.pymongo.patch import _CHECKOUT_FN_NAME +from ddtrace.contrib.internal.pymongo.client import normalize_filter +from ddtrace.contrib.internal.pymongo.patch import _CHECKOUT_FN_NAME from ddtrace.contrib.pymongo.patch import patch from ddtrace.contrib.pymongo.patch import unpatch from ddtrace.ext import SpanTypes