diff --git a/.riot/requirements/f12fa99.txt b/.riot/requirements/102c18b.txt similarity index 90% rename from .riot/requirements/f12fa99.txt rename to .riot/requirements/102c18b.txt index 6ba3725dc6a..60b40418ee3 100644 --- a/.riot/requirements/f12fa99.txt +++ b/.riot/requirements/102c18b.txt @@ -2,15 +2,16 @@ # This file is autogenerated by pip-compile with Python 3.11 # by the following command: # -# pip-compile --allow-unsafe --no-annotate .riot/requirements/f12fa99.in +# pip-compile --allow-unsafe --no-annotate .riot/requirements/102c18b.in # annotated-types==0.7.0 attrs==24.2.0 cachetools==5.5.0 certifi==2024.8.30 charset-normalizer==3.4.0 -coverage[toml]==7.6.7 +coverage[toml]==7.6.8 docstring-parser==0.16 +google-ai-generativelanguage==0.6.13 google-api-core[grpc]==2.23.0 google-auth==2.36.0 google-cloud-aiplatform[all]==1.71.1 @@ -36,8 +37,8 @@ proto-plus==1.25.0 protobuf==5.28.3 pyasn1==0.6.1 pyasn1-modules==0.4.1 -pydantic==2.9.2 -pydantic-core==2.23.4 +pydantic==2.10.2 +pydantic-core==2.27.1 pytest==8.3.3 pytest-asyncio==0.24.0 pytest-cov==6.0.0 diff --git a/.riot/requirements/23eab38.txt b/.riot/requirements/107d415.txt similarity index 90% rename from .riot/requirements/23eab38.txt rename to .riot/requirements/107d415.txt index feed5849537..4a242f86894 100644 --- a/.riot/requirements/23eab38.txt +++ b/.riot/requirements/107d415.txt @@ -2,15 +2,16 @@ # This file is autogenerated by pip-compile with Python 3.12 # by the following command: # -# pip-compile --allow-unsafe --no-annotate .riot/requirements/23eab38.in +# pip-compile --allow-unsafe --no-annotate .riot/requirements/107d415.in # annotated-types==0.7.0 attrs==24.2.0 cachetools==5.5.0 certifi==2024.8.30 charset-normalizer==3.4.0 -coverage[toml]==7.6.7 +coverage[toml]==7.6.8 docstring-parser==0.16 +google-ai-generativelanguage==0.6.13 google-api-core[grpc]==2.23.0 google-auth==2.36.0 google-cloud-aiplatform[all]==1.71.1 @@ -36,8 +37,8 @@ proto-plus==1.25.0 protobuf==5.28.3 pyasn1==0.6.1 pyasn1-modules==0.4.1 -pydantic==2.9.2 -pydantic-core==2.23.4 +pydantic==2.10.2 +pydantic-core==2.27.1 pytest==8.3.3 pytest-asyncio==0.24.0 pytest-cov==6.0.0 diff --git a/.riot/requirements/692fe7a.txt b/.riot/requirements/1b0d9c1.txt similarity index 90% rename from .riot/requirements/692fe7a.txt rename to .riot/requirements/1b0d9c1.txt index 2e2a00ab4e3..842077d8b3d 100644 --- a/.riot/requirements/692fe7a.txt +++ b/.riot/requirements/1b0d9c1.txt @@ -2,16 +2,17 @@ # This file is autogenerated by pip-compile with Python 3.9 # by the following command: # -# pip-compile --allow-unsafe --no-annotate .riot/requirements/692fe7a.in +# pip-compile --allow-unsafe --no-annotate .riot/requirements/1b0d9c1.in # annotated-types==0.7.0 attrs==24.2.0 cachetools==5.5.0 certifi==2024.8.30 charset-normalizer==3.4.0 -coverage[toml]==7.6.7 +coverage[toml]==7.6.8 docstring-parser==0.16 exceptiongroup==1.2.2 +google-ai-generativelanguage==0.6.13 google-api-core[grpc]==2.23.0 google-auth==2.36.0 google-cloud-aiplatform[all]==1.71.1 @@ -37,8 +38,8 @@ proto-plus==1.25.0 protobuf==5.28.3 pyasn1==0.6.1 pyasn1-modules==0.4.1 -pydantic==2.9.2 -pydantic-core==2.23.4 +pydantic==2.10.2 +pydantic-core==2.27.1 pytest==8.3.3 pytest-asyncio==0.24.0 pytest-cov==6.0.0 diff --git a/.riot/requirements/1bee666.txt b/.riot/requirements/1bee666.txt new file mode 100644 index 00000000000..70c923d2825 --- /dev/null +++ b/.riot/requirements/1bee666.txt @@ -0,0 +1,64 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile 
--allow-unsafe --no-annotate .riot/requirements/1bee666.in +# +annotated-types==0.7.0 +attrs==24.2.0 +cachetools==5.5.0 +certifi==2024.8.30 +charset-normalizer==3.4.0 +coverage[toml]==7.6.8 +docstring-parser==0.16 +exceptiongroup==1.2.2 +google-ai-generativelanguage==0.6.10 +google-api-core[grpc]==2.23.0 +google-api-python-client==2.154.0 +google-auth==2.36.0 +google-auth-httplib2==0.2.0 +google-cloud-aiplatform[all]==1.71.1 +google-cloud-bigquery==3.27.0 +google-cloud-core==2.4.1 +google-cloud-resource-manager==1.13.1 +google-cloud-storage==2.18.2 +google-crc32c==1.6.0 +google-generativeai==0.8.3 +google-resumable-media==2.7.2 +googleapis-common-protos[grpc]==1.66.0 +grpc-google-iam-v1==0.13.1 +grpcio==1.68.0 +grpcio-status==1.68.0 +httplib2==0.22.0 +hypothesis==6.45.0 +idna==3.10 +iniconfig==2.0.0 +mock==5.1.0 +numpy==2.0.2 +opentracing==2.4.0 +packaging==24.2 +pillow==11.0.0 +pluggy==1.5.0 +proto-plus==1.25.0 +protobuf==5.28.3 +pyasn1==0.6.1 +pyasn1-modules==0.4.1 +pydantic==2.10.2 +pydantic-core==2.27.1 +pyparsing==3.2.0 +pytest==8.3.3 +pytest-asyncio==0.24.0 +pytest-cov==6.0.0 +pytest-mock==3.14.0 +python-dateutil==2.9.0.post0 +requests==2.32.3 +rsa==4.9 +shapely==2.0.6 +six==1.16.0 +sortedcontainers==2.4.0 +tomli==2.1.0 +tqdm==4.67.1 +typing-extensions==4.12.2 +uritemplate==4.1.1 +urllib3==2.2.3 +vertexai==1.71.1 diff --git a/.riot/requirements/1e15a25.txt b/.riot/requirements/1e15a25.txt deleted file mode 100644 index 36405478a02..00000000000 --- a/.riot/requirements/1e15a25.txt +++ /dev/null @@ -1,48 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.11 -# by the following command: -# -# pip-compile --allow-unsafe --no-annotate .riot/requirements/1e15a25.in -# -annotated-types==0.7.0 -attrs==24.2.0 -cachetools==5.5.0 -certifi==2024.8.30 -charset-normalizer==3.3.2 -coverage[toml]==7.6.1 -google-ai-generativelanguage==0.6.9 -google-api-core[grpc]==2.19.2 -google-api-python-client==2.145.0 -google-auth==2.34.0 -google-auth-httplib2==0.2.0 -google-generativeai==0.8.0 -googleapis-common-protos==1.65.0 -grpcio==1.66.1 -grpcio-status==1.66.1 -httplib2==0.22.0 -hypothesis==6.45.0 -idna==3.8 -iniconfig==2.0.0 -mock==5.1.0 -opentracing==2.4.0 -packaging==24.1 -pillow==10.4.0 -pluggy==1.5.0 -proto-plus==1.24.0 -protobuf==5.28.0 -pyasn1==0.6.0 -pyasn1-modules==0.4.0 -pydantic==2.9.1 -pydantic-core==2.23.3 -pyparsing==3.1.4 -pytest==8.3.3 -pytest-asyncio==0.24.0 -pytest-cov==5.0.0 -pytest-mock==3.14.0 -requests==2.32.3 -rsa==4.9 -sortedcontainers==2.4.0 -tqdm==4.66.5 -typing-extensions==4.12.2 -uritemplate==4.1.1 -urllib3==2.2.2 diff --git a/.riot/requirements/1f54e6b.txt b/.riot/requirements/1f54e6b.txt deleted file mode 100644 index 8bcc57eabff..00000000000 --- a/.riot/requirements/1f54e6b.txt +++ /dev/null @@ -1,50 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# pip-compile --allow-unsafe --no-annotate .riot/requirements/1f54e6b.in -# -annotated-types==0.7.0 -attrs==24.2.0 -cachetools==5.5.0 -certifi==2024.8.30 -charset-normalizer==3.3.2 -coverage[toml]==7.6.1 -exceptiongroup==1.2.2 -google-ai-generativelanguage==0.6.9 -google-api-core[grpc]==2.19.2 -google-api-python-client==2.145.0 -google-auth==2.34.0 -google-auth-httplib2==0.2.0 -google-generativeai==0.8.0 -googleapis-common-protos==1.65.0 -grpcio==1.66.1 -grpcio-status==1.66.1 -httplib2==0.22.0 -hypothesis==6.45.0 -idna==3.8 -iniconfig==2.0.0 -mock==5.1.0 -opentracing==2.4.0 -packaging==24.1 -pillow==10.4.0 -pluggy==1.5.0 -proto-plus==1.24.0 -protobuf==5.28.0 
-pyasn1==0.6.0 -pyasn1-modules==0.4.0 -pydantic==2.9.1 -pydantic-core==2.23.3 -pyparsing==3.1.4 -pytest==8.3.3 -pytest-asyncio==0.24.0 -pytest-cov==5.0.0 -pytest-mock==3.14.0 -requests==2.32.3 -rsa==4.9 -sortedcontainers==2.4.0 -tomli==2.0.1 -tqdm==4.66.5 -typing-extensions==4.12.2 -uritemplate==4.1.1 -urllib3==2.2.2 diff --git a/.riot/requirements/55b8536.txt b/.riot/requirements/55b8536.txt new file mode 100644 index 00000000000..ed6036adcd1 --- /dev/null +++ b/.riot/requirements/55b8536.txt @@ -0,0 +1,62 @@ +# +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: +# +# pip-compile --allow-unsafe --no-annotate .riot/requirements/55b8536.in +# +annotated-types==0.7.0 +attrs==24.2.0 +cachetools==5.5.0 +certifi==2024.8.30 +charset-normalizer==3.4.0 +coverage[toml]==7.6.8 +docstring-parser==0.16 +google-ai-generativelanguage==0.6.10 +google-api-core[grpc]==2.23.0 +google-api-python-client==2.154.0 +google-auth==2.36.0 +google-auth-httplib2==0.2.0 +google-cloud-aiplatform[all]==1.71.1 +google-cloud-bigquery==3.27.0 +google-cloud-core==2.4.1 +google-cloud-resource-manager==1.13.1 +google-cloud-storage==2.18.2 +google-crc32c==1.6.0 +google-generativeai==0.8.3 +google-resumable-media==2.7.2 +googleapis-common-protos[grpc]==1.66.0 +grpc-google-iam-v1==0.13.1 +grpcio==1.68.0 +grpcio-status==1.68.0 +httplib2==0.22.0 +hypothesis==6.45.0 +idna==3.10 +iniconfig==2.0.0 +mock==5.1.0 +numpy==2.1.3 +opentracing==2.4.0 +packaging==24.2 +pillow==11.0.0 +pluggy==1.5.0 +proto-plus==1.25.0 +protobuf==5.28.3 +pyasn1==0.6.1 +pyasn1-modules==0.4.1 +pydantic==2.10.2 +pydantic-core==2.27.1 +pyparsing==3.2.0 +pytest==8.3.3 +pytest-asyncio==0.24.0 +pytest-cov==6.0.0 +pytest-mock==3.14.0 +python-dateutil==2.9.0.post0 +requests==2.32.3 +rsa==4.9 +shapely==2.0.6 +six==1.16.0 +sortedcontainers==2.4.0 +tqdm==4.67.1 +typing-extensions==4.12.2 +uritemplate==4.1.1 +urllib3==2.2.3 +vertexai==1.71.1 diff --git a/.riot/requirements/6820ef2.txt b/.riot/requirements/6820ef2.txt new file mode 100644 index 00000000000..2db99b509e5 --- /dev/null +++ b/.riot/requirements/6820ef2.txt @@ -0,0 +1,62 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# pip-compile --allow-unsafe --no-annotate .riot/requirements/6820ef2.in +# +annotated-types==0.7.0 +attrs==24.2.0 +cachetools==5.5.0 +certifi==2024.8.30 +charset-normalizer==3.4.0 +coverage[toml]==7.6.8 +docstring-parser==0.16 +google-ai-generativelanguage==0.6.10 +google-api-core[grpc]==2.23.0 +google-api-python-client==2.154.0 +google-auth==2.36.0 +google-auth-httplib2==0.2.0 +google-cloud-aiplatform[all]==1.71.1 +google-cloud-bigquery==3.27.0 +google-cloud-core==2.4.1 +google-cloud-resource-manager==1.13.1 +google-cloud-storage==2.18.2 +google-crc32c==1.6.0 +google-generativeai==0.8.3 +google-resumable-media==2.7.2 +googleapis-common-protos[grpc]==1.66.0 +grpc-google-iam-v1==0.13.1 +grpcio==1.68.0 +grpcio-status==1.68.0 +httplib2==0.22.0 +hypothesis==6.45.0 +idna==3.10 +iniconfig==2.0.0 +mock==5.1.0 +numpy==2.1.3 +opentracing==2.4.0 +packaging==24.2 +pillow==11.0.0 +pluggy==1.5.0 +proto-plus==1.25.0 +protobuf==5.28.3 +pyasn1==0.6.1 +pyasn1-modules==0.4.1 +pydantic==2.10.2 +pydantic-core==2.27.1 +pyparsing==3.2.0 +pytest==8.3.3 +pytest-asyncio==0.24.0 +pytest-cov==6.0.0 +pytest-mock==3.14.0 +python-dateutil==2.9.0.post0 +requests==2.32.3 +rsa==4.9 +shapely==2.0.6 +six==1.16.0 +sortedcontainers==2.4.0 +tqdm==4.67.1 +typing-extensions==4.12.2 +uritemplate==4.1.1 +urllib3==2.2.3 +vertexai==1.71.1 diff 
--git a/.riot/requirements/ab2f587.txt b/.riot/requirements/ab2f587.txt new file mode 100644 index 00000000000..29fd2375edd --- /dev/null +++ b/.riot/requirements/ab2f587.txt @@ -0,0 +1,64 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile --allow-unsafe --no-annotate .riot/requirements/ab2f587.in +# +annotated-types==0.7.0 +attrs==24.2.0 +cachetools==5.5.0 +certifi==2024.8.30 +charset-normalizer==3.4.0 +coverage[toml]==7.6.8 +docstring-parser==0.16 +exceptiongroup==1.2.2 +google-ai-generativelanguage==0.6.10 +google-api-core[grpc]==2.23.0 +google-api-python-client==2.154.0 +google-auth==2.36.0 +google-auth-httplib2==0.2.0 +google-cloud-aiplatform[all]==1.71.1 +google-cloud-bigquery==3.27.0 +google-cloud-core==2.4.1 +google-cloud-resource-manager==1.13.1 +google-cloud-storage==2.18.2 +google-crc32c==1.6.0 +google-generativeai==0.8.3 +google-resumable-media==2.7.2 +googleapis-common-protos[grpc]==1.66.0 +grpc-google-iam-v1==0.13.1 +grpcio==1.68.0 +grpcio-status==1.68.0 +httplib2==0.22.0 +hypothesis==6.45.0 +idna==3.10 +iniconfig==2.0.0 +mock==5.1.0 +numpy==2.1.3 +opentracing==2.4.0 +packaging==24.2 +pillow==11.0.0 +pluggy==1.5.0 +proto-plus==1.25.0 +protobuf==5.28.3 +pyasn1==0.6.1 +pyasn1-modules==0.4.1 +pydantic==2.10.2 +pydantic-core==2.27.1 +pyparsing==3.2.0 +pytest==8.3.3 +pytest-asyncio==0.24.0 +pytest-cov==6.0.0 +pytest-mock==3.14.0 +python-dateutil==2.9.0.post0 +requests==2.32.3 +rsa==4.9 +shapely==2.0.6 +six==1.16.0 +sortedcontainers==2.4.0 +tomli==2.1.0 +tqdm==4.67.1 +typing-extensions==4.12.2 +uritemplate==4.1.1 +urllib3==2.2.3 +vertexai==1.71.1 diff --git a/.riot/requirements/59e23ef.txt b/.riot/requirements/bf4cae6.txt similarity index 90% rename from .riot/requirements/59e23ef.txt rename to .riot/requirements/bf4cae6.txt index a0284fd7200..17acda0e74d 100644 --- a/.riot/requirements/59e23ef.txt +++ b/.riot/requirements/bf4cae6.txt @@ -2,16 +2,17 @@ # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# pip-compile --allow-unsafe --no-annotate .riot/requirements/59e23ef.in +# pip-compile --allow-unsafe --no-annotate .riot/requirements/bf4cae6.in # annotated-types==0.7.0 attrs==24.2.0 cachetools==5.5.0 certifi==2024.8.30 charset-normalizer==3.4.0 -coverage[toml]==7.6.7 +coverage[toml]==7.6.8 docstring-parser==0.16 exceptiongroup==1.2.2 +google-ai-generativelanguage==0.6.13 google-api-core[grpc]==2.23.0 google-auth==2.36.0 google-cloud-aiplatform[all]==1.71.1 @@ -37,8 +38,8 @@ proto-plus==1.25.0 protobuf==5.28.3 pyasn1==0.6.1 pyasn1-modules==0.4.1 -pydantic==2.9.2 -pydantic-core==2.23.4 +pydantic==2.10.2 +pydantic-core==2.27.1 pytest==8.3.3 pytest-asyncio==0.24.0 pytest-cov==6.0.0 diff --git a/.riot/requirements/e8247d6.txt b/.riot/requirements/e8247d6.txt deleted file mode 100644 index 2aad3bb1a89..00000000000 --- a/.riot/requirements/e8247d6.txt +++ /dev/null @@ -1,48 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.12 -# by the following command: -# -# pip-compile --allow-unsafe --no-annotate .riot/requirements/e8247d6.in -# -annotated-types==0.7.0 -attrs==24.2.0 -cachetools==5.5.0 -certifi==2024.8.30 -charset-normalizer==3.3.2 -coverage[toml]==7.6.1 -google-ai-generativelanguage==0.6.9 -google-api-core[grpc]==2.19.2 -google-api-python-client==2.145.0 -google-auth==2.34.0 -google-auth-httplib2==0.2.0 -google-generativeai==0.8.0 -googleapis-common-protos==1.65.0 -grpcio==1.66.1 -grpcio-status==1.66.1 -httplib2==0.22.0 -hypothesis==6.45.0 -idna==3.8 
-iniconfig==2.0.0 -mock==5.1.0 -opentracing==2.4.0 -packaging==24.1 -pillow==10.4.0 -pluggy==1.5.0 -proto-plus==1.24.0 -protobuf==5.28.0 -pyasn1==0.6.0 -pyasn1-modules==0.4.0 -pydantic==2.9.1 -pydantic-core==2.23.3 -pyparsing==3.1.4 -pytest==8.3.3 -pytest-asyncio==0.24.0 -pytest-cov==5.0.0 -pytest-mock==3.14.0 -requests==2.32.3 -rsa==4.9 -sortedcontainers==2.4.0 -tqdm==4.66.5 -typing-extensions==4.12.2 -uritemplate==4.1.1 -urllib3==2.2.2 diff --git a/.riot/requirements/ebe4ea5.txt b/.riot/requirements/ebe4ea5.txt deleted file mode 100644 index 264c2960158..00000000000 --- a/.riot/requirements/ebe4ea5.txt +++ /dev/null @@ -1,50 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# pip-compile --allow-unsafe --no-annotate .riot/requirements/ebe4ea5.in -# -annotated-types==0.7.0 -attrs==24.2.0 -cachetools==5.5.0 -certifi==2024.8.30 -charset-normalizer==3.3.2 -coverage[toml]==7.6.1 -exceptiongroup==1.2.2 -google-ai-generativelanguage==0.6.9 -google-api-core[grpc]==2.19.2 -google-api-python-client==2.145.0 -google-auth==2.34.0 -google-auth-httplib2==0.2.0 -google-generativeai==0.8.0 -googleapis-common-protos==1.65.0 -grpcio==1.66.1 -grpcio-status==1.66.1 -httplib2==0.22.0 -hypothesis==6.45.0 -idna==3.8 -iniconfig==2.0.0 -mock==5.1.0 -opentracing==2.4.0 -packaging==24.1 -pillow==10.4.0 -pluggy==1.5.0 -proto-plus==1.24.0 -protobuf==5.28.0 -pyasn1==0.6.0 -pyasn1-modules==0.4.0 -pydantic==2.9.1 -pydantic-core==2.23.3 -pyparsing==3.1.4 -pytest==8.3.3 -pytest-asyncio==0.24.0 -pytest-cov==5.0.0 -pytest-mock==3.14.0 -requests==2.32.3 -rsa==4.9 -sortedcontainers==2.4.0 -tomli==2.0.1 -tqdm==4.66.5 -typing-extensions==4.12.2 -uritemplate==4.1.1 -urllib3==2.2.2 diff --git a/ddtrace/contrib/internal/google_generativeai/_utils.py b/ddtrace/contrib/internal/google_generativeai/_utils.py index 20a923e07cb..ad281c4e847 100644 --- a/ddtrace/contrib/internal/google_generativeai/_utils.py +++ b/ddtrace/contrib/internal/google_generativeai/_utils.py @@ -5,6 +5,7 @@ from ddtrace.internal.utils import get_argument_value from ddtrace.llmobs._integrations.utils import get_generation_config_google +from ddtrace.llmobs._integrations.utils import get_system_instructions_from_google_model from ddtrace.llmobs._integrations.utils import tag_request_content_part_google from ddtrace.llmobs._integrations.utils import tag_response_part_google @@ -109,7 +110,7 @@ def tag_request(span, integration, instance, args, kwargs): """ contents = get_argument_value(args, kwargs, 0, "contents") generation_config = get_generation_config_google(instance, kwargs) - system_instruction = getattr(instance, "_system_instruction", None) + system_instruction = get_system_instructions_from_google_model(instance) stream = kwargs.get("stream", None) try: @@ -127,10 +128,8 @@ def tag_request(span, integration, instance, args, kwargs): return if system_instruction: - for idx, part in enumerate(system_instruction.parts): - span.set_tag_str( - "google_generativeai.request.system_instruction.%d.text" % idx, integration.trunc(str(part.text)) - ) + for idx, text in enumerate(system_instruction): + span.set_tag_str("google_generativeai.request.system_instruction.%d.text" % idx, integration.trunc(text)) if isinstance(contents, str): span.set_tag_str("google_generativeai.request.contents.0.text", integration.trunc(contents)) diff --git a/ddtrace/contrib/internal/vertexai/_utils.py b/ddtrace/contrib/internal/vertexai/_utils.py index 07fd0cb69e2..129b97fd920 100644 --- 
a/ddtrace/contrib/internal/vertexai/_utils.py +++ b/ddtrace/contrib/internal/vertexai/_utils.py @@ -5,18 +5,23 @@ from ddtrace.internal.utils import get_argument_value from ddtrace.llmobs._integrations.utils import get_generation_config_google +from ddtrace.llmobs._integrations.utils import get_system_instructions_from_google_model from ddtrace.llmobs._integrations.utils import tag_request_content_part_google from ddtrace.llmobs._integrations.utils import tag_response_part_google from ddtrace.llmobs._utils import _get_attr class BaseTracedVertexAIStreamResponse: - def __init__(self, generator, integration, span, is_chat): + def __init__(self, generator, model_instance, integration, span, args, kwargs, is_chat, history): self._generator = generator + self._model_instance = model_instance self._dd_integration = integration self._dd_span = span - self._chunks = [] + self._args = args + self._kwargs = kwargs self.is_chat = is_chat + self._chunks = [] + self._history = history class TracedVertexAIStreamResponse(BaseTracedVertexAIStreamResponse): @@ -41,6 +46,12 @@ def __iter__(self): else: tag_stream_response(self._dd_span, self._chunks, self._dd_integration) finally: + if self._dd_integration.is_pc_sampled_llmobs(self._dd_span): + self._kwargs["instance"] = self._model_instance + self._kwargs["history"] = self._history + self._dd_integration.llmobs_set_tags( + self._dd_span, args=self._args, kwargs=self._kwargs, response=self._chunks + ) self._dd_span.finish() @@ -66,30 +77,15 @@ async def __aiter__(self): else: tag_stream_response(self._dd_span, self._chunks, self._dd_integration) finally: + if self._dd_integration.is_pc_sampled_llmobs(self._dd_span): + self._kwargs["instance"] = self._model_instance + self._kwargs["history"] = self._history + self._dd_integration.llmobs_set_tags( + self._dd_span, args=self._args, kwargs=self._kwargs, response=self._chunks + ) self._dd_span.finish() -def get_system_instruction_texts_from_model(instance): - """ - Extract system instructions from model and convert to []str for tagging. 
- """ - raw_system_instructions = _get_attr(instance, "_system_instruction", []) - if isinstance(raw_system_instructions, str): - return [raw_system_instructions] - elif isinstance(raw_system_instructions, Part): - return [_get_attr(raw_system_instructions, "text", "")] - elif not isinstance(raw_system_instructions, list): - return [] - - system_instructions = [] - for elem in raw_system_instructions: - if isinstance(elem, str): - system_instructions.append(elem) - elif isinstance(elem, Part): - system_instructions.append(_get_attr(elem, "text", "")) - return system_instructions - - def extract_info_from_parts(parts): """Return concatenated text from parts and function calls.""" concatenated_text = "" @@ -200,7 +196,7 @@ def tag_request(span, integration, instance, args, kwargs): generation_config_dict = ( generation_config if isinstance(generation_config, dict) else generation_config.to_dict() ) - system_instructions = get_system_instruction_texts_from_model(model_instance) + system_instructions = get_system_instructions_from_google_model(model_instance) stream = kwargs.get("stream", None) if generation_config_dict is not None: @@ -219,7 +215,7 @@ def tag_request(span, integration, instance, args, kwargs): integration.trunc(str(text)), ) - if isinstance(contents, str) or isinstance(contents, dict): + if isinstance(contents, str): span.set_tag_str("vertexai.request.contents.0.text", integration.trunc(str(contents))) return elif isinstance(contents, Part): diff --git a/ddtrace/contrib/internal/vertexai/patch.py b/ddtrace/contrib/internal/vertexai/patch.py index 1fdfcb7dd16..2dbce060234 100644 --- a/ddtrace/contrib/internal/vertexai/patch.py +++ b/ddtrace/contrib/internal/vertexai/patch.py @@ -59,13 +59,17 @@ def _traced_generate(vertexai, pin, func, instance, args, kwargs, model_instance "%s.%s" % (instance.__class__.__name__, func.__name__), provider="google", model=extract_model_name_google(model_instance, "_model_name"), - submit_to_llmobs=False, + submit_to_llmobs=True, ) + # history must be copied since it is modified during the LLM interaction + history = getattr(instance, "history", [])[:] try: tag_request(span, integration, instance, args, kwargs) generations = func(*args, **kwargs) if stream: - return TracedVertexAIStreamResponse(generations, integration, span, is_chat) + return TracedVertexAIStreamResponse( + generations, model_instance, integration, span, args, kwargs, is_chat, history + ) tag_response(span, generations, integration) except Exception: span.set_exc_info(*sys.exc_info()) @@ -73,6 +77,10 @@ def _traced_generate(vertexai, pin, func, instance, args, kwargs, model_instance finally: # streamed spans will be finished separately once the stream generator is exhausted if span.error or not stream: + if integration.is_pc_sampled_llmobs(span): + kwargs["instance"] = model_instance + kwargs["history"] = history + integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=generations) span.finish() return generations @@ -86,13 +94,17 @@ async def _traced_agenerate(vertexai, pin, func, instance, args, kwargs, model_i "%s.%s" % (instance.__class__.__name__, func.__name__), provider="google", model=extract_model_name_google(model_instance, "_model_name"), - submit_to_llmobs=False, + submit_to_llmobs=True, ) + # history must be copied since it is modified during the LLM interaction + history = getattr(instance, "history", [])[:] try: tag_request(span, integration, instance, args, kwargs) generations = await func(*args, **kwargs) if stream: - return 
TracedAsyncVertexAIStreamResponse(generations, integration, span, is_chat) + return TracedAsyncVertexAIStreamResponse( + generations, model_instance, integration, span, args, kwargs, is_chat, history + ) tag_response(span, generations, integration) except Exception: span.set_exc_info(*sys.exc_info()) @@ -100,6 +112,10 @@ async def _traced_agenerate(vertexai, pin, func, instance, args, kwargs, model_i finally: # streamed spans will be finished separately once the stream generator is exhausted if span.error or not stream: + if integration.is_pc_sampled_llmobs(span): + kwargs["instance"] = model_instance + kwargs["history"] = history + integration.llmobs_set_tags(span, args=args, kwargs=kwargs, response=generations) span.finish() return generations diff --git a/ddtrace/llmobs/_constants.py b/ddtrace/llmobs/_constants.py index 7c295835e54..27000b36aac 100644 --- a/ddtrace/llmobs/_constants.py +++ b/ddtrace/llmobs/_constants.py @@ -27,6 +27,7 @@ GEMINI_APM_SPAN_NAME = "gemini.request" LANGCHAIN_APM_SPAN_NAME = "langchain.request" OPENAI_APM_SPAN_NAME = "openai.request" +VERTEXAI_APM_SPAN_NAME = "vertexai.request" INPUT_TOKENS_METRIC_KEY = "input_tokens" OUTPUT_TOKENS_METRIC_KEY = "output_tokens" diff --git a/ddtrace/llmobs/_integrations/bedrock.py b/ddtrace/llmobs/_integrations/bedrock.py index 0aaa545b47e..78798ae4f98 100644 --- a/ddtrace/llmobs/_integrations/bedrock.py +++ b/ddtrace/llmobs/_integrations/bedrock.py @@ -29,7 +29,12 @@ class BedrockIntegration(BaseLLMIntegration): _integration_name = "bedrock" def _llmobs_set_tags( - self, span: Span, args: List[Any], kwargs: Dict[str, Any], response: Optional[Any] = None, operation: str = "" + self, + span: Span, + args: List[Any], + kwargs: Dict[str, Any], + response: Optional[Any] = None, + operation: str = "", ) -> None: """Extract prompt/response tags from a completion and set them as temporary "_ml_obs.*" tags.""" if span.get_tag(PROPAGATED_PARENT_ID_KEY) is None: diff --git a/ddtrace/llmobs/_integrations/gemini.py b/ddtrace/llmobs/_integrations/gemini.py index 21e74b036f0..f1a4730812f 100644 --- a/ddtrace/llmobs/_integrations/gemini.py +++ b/ddtrace/llmobs/_integrations/gemini.py @@ -7,16 +7,17 @@ from ddtrace import Span from ddtrace.internal.utils import get_argument_value from ddtrace.llmobs._constants import INPUT_MESSAGES -from ddtrace.llmobs._constants import INPUT_TOKENS_METRIC_KEY from ddtrace.llmobs._constants import METADATA from ddtrace.llmobs._constants import METRICS from ddtrace.llmobs._constants import MODEL_NAME from ddtrace.llmobs._constants import MODEL_PROVIDER from ddtrace.llmobs._constants import OUTPUT_MESSAGES -from ddtrace.llmobs._constants import OUTPUT_TOKENS_METRIC_KEY from ddtrace.llmobs._constants import SPAN_KIND -from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY from ddtrace.llmobs._integrations.base import BaseLLMIntegration +from ddtrace.llmobs._integrations.utils import extract_message_from_part_google +from ddtrace.llmobs._integrations.utils import get_llmobs_metrics_tags_google +from ddtrace.llmobs._integrations.utils import get_system_instructions_from_google_model +from ddtrace.llmobs._integrations.utils import llmobs_get_metadata_google from ddtrace.llmobs._utils import _get_attr from ddtrace.llmobs._utils import safe_json @@ -45,10 +46,10 @@ def _llmobs_set_tags( span.set_tag_str(MODEL_PROVIDER, span.get_tag("google_generativeai.request.provider") or "") instance = kwargs.get("instance", None) - metadata = self._llmobs_set_metadata(kwargs, instance) + metadata = 
llmobs_get_metadata_google(kwargs, instance) span.set_tag_str(METADATA, safe_json(metadata)) - system_instruction = _get_attr(instance, "_system_instruction", None) + system_instruction = get_system_instructions_from_google_model(instance) input_contents = get_argument_value(args, kwargs, 0, "contents") input_messages = self._extract_input_message(input_contents, system_instruction) span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages)) @@ -59,50 +60,15 @@ def _llmobs_set_tags( output_messages = self._extract_output_message(response) span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) - usage = self._get_llmobs_metrics_tags(span) + usage = get_llmobs_metrics_tags_google("google_generativeai", span) if usage: span.set_tag_str(METRICS, safe_json(usage)) - @staticmethod - def _llmobs_set_metadata(kwargs, instance): - metadata = {} - model_config = _get_attr(instance, "_generation_config", {}) - request_config = kwargs.get("generation_config", {}) - parameters = ("temperature", "max_output_tokens", "candidate_count", "top_p", "top_k") - for param in parameters: - model_config_value = _get_attr(model_config, param, None) - request_config_value = _get_attr(request_config, param, None) - if model_config_value or request_config_value: - metadata[param] = request_config_value or model_config_value - return metadata - - @staticmethod - def _extract_message_from_part(part, role): - text = _get_attr(part, "text", "") - function_call = _get_attr(part, "function_call", None) - function_response = _get_attr(part, "function_response", None) - message = {"content": text} - if role: - message["role"] = role - if function_call: - function_call_dict = function_call - if not isinstance(function_call, dict): - function_call_dict = type(function_call).to_dict(function_call) - message["tool_calls"] = [ - {"name": function_call_dict.get("name", ""), "arguments": function_call_dict.get("args", {})} - ] - if function_response: - function_response_dict = function_response - if not isinstance(function_response, dict): - function_response_dict = type(function_response).to_dict(function_response) - message["content"] = "[tool result: {}]".format(function_response_dict.get("response", "")) - return message - def _extract_input_message(self, contents, system_instruction=None): messages = [] if system_instruction: - for part in system_instruction.parts: - messages.append({"content": part.text or "", "role": "system"}) + for instruction in system_instruction: + messages.append({"content": instruction or "", "role": "system"}) if isinstance(contents, str): messages.append({"content": contents}) return messages @@ -128,7 +94,7 @@ def _extract_input_message(self, contents, system_instruction=None): messages.append(message) continue for part in parts: - message = self._extract_message_from_part(part, role) + message = extract_message_from_part_google(part, role) messages.append(message) return messages @@ -140,21 +106,6 @@ def _extract_output_message(self, generations): role = content.get("role", "model") parts = content.get("parts", []) for part in parts: - message = self._extract_message_from_part(part, role) + message = extract_message_from_part_google(part, role) output_messages.append(message) return output_messages - - @staticmethod - def _get_llmobs_metrics_tags(span): - usage = {} - input_tokens = span.get_metric("google_generativeai.response.usage.prompt_tokens") - output_tokens = span.get_metric("google_generativeai.response.usage.completion_tokens") - total_tokens = 
span.get_metric("google_generativeai.response.usage.total_tokens") - - if input_tokens is not None: - usage[INPUT_TOKENS_METRIC_KEY] = input_tokens - if output_tokens is not None: - usage[OUTPUT_TOKENS_METRIC_KEY] = output_tokens - if total_tokens is not None: - usage[TOTAL_TOKENS_METRIC_KEY] = total_tokens - return usage diff --git a/ddtrace/llmobs/_integrations/utils.py b/ddtrace/llmobs/_integrations/utils.py index 695dedb19c8..2676dce9637 100644 --- a/ddtrace/llmobs/_integrations/utils.py +++ b/ddtrace/llmobs/_integrations/utils.py @@ -1,3 +1,6 @@ +from ddtrace.llmobs._constants import INPUT_TOKENS_METRIC_KEY +from ddtrace.llmobs._constants import OUTPUT_TOKENS_METRIC_KEY +from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY from ddtrace.llmobs._utils import _get_attr @@ -72,3 +75,91 @@ def tag_response_part_google(tag_prefix, span, integration, part, part_idx, cand "%s.response.candidates.%d.content.parts.%d.function_call.args" % (tag_prefix, candidate_idx, part_idx), integration.trunc(str(_get_attr(function_call, "args", {}))), ) + + +def llmobs_get_metadata_google(kwargs, instance): + metadata = {} + model_config = getattr(instance, "_generation_config", {}) or {} + model_config = model_config.to_dict() if hasattr(model_config, "to_dict") else model_config + request_config = kwargs.get("generation_config", {}) or {} + request_config = request_config.to_dict() if hasattr(request_config, "to_dict") else request_config + + parameters = ("temperature", "max_output_tokens", "candidate_count", "top_p", "top_k") + for param in parameters: + model_config_value = _get_attr(model_config, param, None) + request_config_value = _get_attr(request_config, param, None) + if model_config_value or request_config_value: + metadata[param] = request_config_value or model_config_value + return metadata + + +def extract_message_from_part_google(part, role=None): + text = _get_attr(part, "text", "") + function_call = _get_attr(part, "function_call", None) + function_response = _get_attr(part, "function_response", None) + message = {"content": text} + if role: + message["role"] = role + if function_call: + function_call_dict = function_call + if not isinstance(function_call, dict): + function_call_dict = type(function_call).to_dict(function_call) + message["tool_calls"] = [ + {"name": function_call_dict.get("name", ""), "arguments": function_call_dict.get("args", {})} + ] + if function_response: + function_response_dict = function_response + if not isinstance(function_response, dict): + function_response_dict = type(function_response).to_dict(function_response) + message["content"] = "[tool result: {}]".format(function_response_dict.get("response", "")) + return message + + +def get_llmobs_metrics_tags_google(integration_name, span): + usage = {} + input_tokens = span.get_metric("%s.response.usage.prompt_tokens" % integration_name) + output_tokens = span.get_metric("%s.response.usage.completion_tokens" % integration_name) + total_tokens = span.get_metric("%s.response.usage.total_tokens" % integration_name) + + if input_tokens is not None: + usage[INPUT_TOKENS_METRIC_KEY] = input_tokens + if output_tokens is not None: + usage[OUTPUT_TOKENS_METRIC_KEY] = output_tokens + if total_tokens is not None: + usage[TOTAL_TOKENS_METRIC_KEY] = total_tokens + return usage + + +def get_system_instructions_from_google_model(model_instance): + """ + Extract system instructions from model and convert to []str for tagging. 
+ """ + try: + from google.ai.generativelanguage_v1beta.types.content import Content + except ImportError: + Content = None + try: + from vertexai.generative_models._generative_models import Part + except ImportError: + Part = None + + raw_system_instructions = getattr(model_instance, "_system_instruction", []) + if Content is not None and isinstance(raw_system_instructions, Content): + system_instructions = [] + for part in raw_system_instructions.parts: + system_instructions.append(_get_attr(part, "text", "")) + return system_instructions + elif isinstance(raw_system_instructions, str): + return [raw_system_instructions] + elif Part is not None and isinstance(raw_system_instructions, Part): + return [_get_attr(raw_system_instructions, "text", "")] + elif not isinstance(raw_system_instructions, list): + return [] + + system_instructions = [] + for elem in raw_system_instructions: + if isinstance(elem, str): + system_instructions.append(elem) + elif Part is not None and isinstance(elem, Part): + system_instructions.append(_get_attr(elem, "text", "")) + return system_instructions diff --git a/ddtrace/llmobs/_integrations/vertexai.py b/ddtrace/llmobs/_integrations/vertexai.py index 1ad64b61d40..69fdc7eb665 100644 --- a/ddtrace/llmobs/_integrations/vertexai.py +++ b/ddtrace/llmobs/_integrations/vertexai.py @@ -1,9 +1,25 @@ from typing import Any from typing import Dict +from typing import Iterable +from typing import List from typing import Optional from ddtrace import Span +from ddtrace.internal.utils import get_argument_value +from ddtrace.llmobs._constants import INPUT_MESSAGES +from ddtrace.llmobs._constants import METADATA +from ddtrace.llmobs._constants import METRICS +from ddtrace.llmobs._constants import MODEL_NAME +from ddtrace.llmobs._constants import MODEL_PROVIDER +from ddtrace.llmobs._constants import OUTPUT_MESSAGES +from ddtrace.llmobs._constants import SPAN_KIND from ddtrace.llmobs._integrations.base import BaseLLMIntegration +from ddtrace.llmobs._integrations.utils import extract_message_from_part_google +from ddtrace.llmobs._integrations.utils import get_llmobs_metrics_tags_google +from ddtrace.llmobs._integrations.utils import get_system_instructions_from_google_model +from ddtrace.llmobs._integrations.utils import llmobs_get_metadata_google +from ddtrace.llmobs._utils import _get_attr +from ddtrace.llmobs._utils import safe_json class VertexAIIntegration(BaseLLMIntegration): @@ -16,3 +32,106 @@ def _set_base_span_tags( span.set_tag_str("vertexai.request.provider", provider) if model is not None: span.set_tag_str("vertexai.request.model", model) + + def _llmobs_set_tags( + self, + span: Span, + args: List[Any], + kwargs: Dict[str, Any], + response: Optional[Any] = None, + operation: str = "", + ) -> None: + span.set_tag_str(SPAN_KIND, "llm") + span.set_tag_str(MODEL_NAME, span.get_tag("vertexai.request.model") or "") + span.set_tag_str(MODEL_PROVIDER, span.get_tag("vertexai.request.provider") or "") + + instance = kwargs.get("instance", None) + history = kwargs.get("history", []) + metadata = llmobs_get_metadata_google(kwargs, instance) + span.set_tag_str(METADATA, safe_json(metadata)) + + system_instruction = get_system_instructions_from_google_model(instance) + input_contents = get_argument_value(args, kwargs, 0, "contents") + input_messages = self._extract_input_message(input_contents, history, system_instruction) + span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages)) + + if span.error or response is None: + span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": 
""}])) + return + + output_messages = self._extract_output_message(response) + span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) + + usage = get_llmobs_metrics_tags_google("vertexai", span) + if usage: + span.set_tag_str(METRICS, safe_json(usage)) + + def _extract_input_message(self, contents, history, system_instruction=None): + from vertexai.generative_models._generative_models import Part + + messages = [] + if system_instruction: + for instruction in system_instruction: + messages.append({"content": instruction or "", "role": "system"}) + for content in history: + messages.extend(self._extract_messages_from_content(content)) + if isinstance(contents, str): + messages.append({"content": contents}) + return messages + if isinstance(contents, Part): + message = extract_message_from_part_google(contents) + messages.append(message) + return messages + if not isinstance(contents, list): + messages.append({"content": "[Non-text content object: {}]".format(repr(contents))}) + return messages + for content in contents: + if isinstance(content, str): + messages.append({"content": content}) + continue + if isinstance(content, Part): + message = extract_message_from_part_google(content) + messages.append(message) + continue + messages.extend(self._extract_messages_from_content(content)) + return messages + + def _extract_output_message(self, generations): + output_messages = [] + # streamed responses will be a list of chunks + if isinstance(generations, list): + message_content = "" + tool_calls = [] + role = "model" + for chunk in generations: + for candidate in _get_attr(chunk, "candidates", []): + content = _get_attr(candidate, "content", {}) + messages = self._extract_messages_from_content(content) + for message in messages: + message_content += message.get("content", "") + tool_calls.extend(message.get("tool_calls", [])) + message = {"content": message_content, "role": role} + if tool_calls: + message["tool_calls"] = tool_calls + return [message] + generations_dict = generations.to_dict() + for candidate in generations_dict.get("candidates", []): + content = candidate.get("content", {}) + output_messages.extend(self._extract_messages_from_content(content)) + return output_messages + + @staticmethod + def _extract_messages_from_content(content): + messages = [] + role = _get_attr(content, "role", "") + parts = _get_attr(content, "parts", []) + if not parts or not isinstance(parts, Iterable): + message = {"content": "[Non-text content object: {}]".format(repr(content))} + if role: + message["role"] = role + messages.append(message) + return messages + for part in parts: + message = extract_message_from_part_google(part, role) + messages.append(message) + return messages diff --git a/ddtrace/llmobs/_llmobs.py b/ddtrace/llmobs/_llmobs.py index 07f4d6f93c2..a3ac9501319 100644 --- a/ddtrace/llmobs/_llmobs.py +++ b/ddtrace/llmobs/_llmobs.py @@ -69,6 +69,7 @@ "openai": "openai", "langchain": "langchain", "google_generativeai": "google_generativeai", + "vertexai": "vertexai", } diff --git a/ddtrace/llmobs/_utils.py b/ddtrace/llmobs/_utils.py index f3d6434d297..8813788f0a3 100644 --- a/ddtrace/llmobs/_utils.py +++ b/ddtrace/llmobs/_utils.py @@ -18,6 +18,7 @@ from ddtrace.llmobs._constants import PARENT_ID_KEY from ddtrace.llmobs._constants import PROPAGATED_PARENT_ID_KEY from ddtrace.llmobs._constants import SESSION_ID +from ddtrace.llmobs._constants import VERTEXAI_APM_SPAN_NAME log = get_logger(__name__) @@ -118,7 +119,7 @@ def _get_llmobs_parent_id(span: Span) -> Optional[str]: def 
_get_span_name(span: Span) -> str: - if span.name in (LANGCHAIN_APM_SPAN_NAME, GEMINI_APM_SPAN_NAME) and span.resource != "": + if span.name in (LANGCHAIN_APM_SPAN_NAME, GEMINI_APM_SPAN_NAME, VERTEXAI_APM_SPAN_NAME) and span.resource != "": return span.resource elif span.name == OPENAI_APM_SPAN_NAME and span.resource != "": client_name = span.get_tag("openai.request.client") or "OpenAI" diff --git a/releasenotes/notes/feat-llmobs-vertexai-f58488859472c7b5.yaml b/releasenotes/notes/feat-llmobs-vertexai-f58488859472c7b5.yaml new file mode 100644 index 00000000000..5709289091e --- /dev/null +++ b/releasenotes/notes/feat-llmobs-vertexai-f58488859472c7b5.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + LLM Observability: Adds support to automatically submit Vertex AI Python calls to LLM Observability. + diff --git a/riotfile.py b/riotfile.py index 2133641a05c..f25e09e2d25 100644 --- a/riotfile.py +++ b/riotfile.py @@ -2720,6 +2720,8 @@ def select_pys(min_version: str = MIN_PYTHON_VERSION, max_version: str = MAX_PYT "pytest-asyncio": latest, "google-generativeai": [latest], "pillow": latest, + "google-ai-generativelanguage": [latest], + "vertexai": [latest], }, ), Venv( @@ -2729,6 +2731,7 @@ def select_pys(min_version: str = MIN_PYTHON_VERSION, max_version: str = MAX_PYT pkgs={ "pytest-asyncio": latest, "vertexai": [latest], + "google-ai-generativelanguage": [latest], }, ), Venv( diff --git a/tests/contrib/vertexai/conftest.py b/tests/contrib/vertexai/conftest.py index 9f58381ca41..74ba41d4dee 100644 --- a/tests/contrib/vertexai/conftest.py +++ b/tests/contrib/vertexai/conftest.py @@ -1,9 +1,10 @@ +import mock from mock import PropertyMock -from mock import patch as mock_patch import pytest from ddtrace.contrib.vertexai import patch from ddtrace.contrib.vertexai import unpatch +from ddtrace.llmobs import LLMObs from ddtrace.pin import Pin from tests.contrib.vertexai.utils import MockAsyncPredictionServiceClient from tests.contrib.vertexai.utils import MockPredictionServiceClient @@ -13,6 +14,10 @@ from tests.utils import override_global_config +def default_global_config(): + return {} + + @pytest.fixture def ddtrace_global_config(): return {} @@ -34,29 +39,46 @@ def mock_async_client(): @pytest.fixture -def mock_tracer(vertexai): +def mock_tracer(ddtrace_global_config, vertexai): try: pin = Pin.get_from(vertexai) mock_tracer = DummyTracer(writer=DummyWriter(trace_flush_enabled=False)) pin.override(vertexai, tracer=mock_tracer) pin.tracer.configure() + if ddtrace_global_config.get("_llmobs_enabled", False): + # Have to disable and re-enable LLMObs to use the mock tracer. 
+ LLMObs.disable() + LLMObs.enable(_tracer=mock_tracer, integrations_enabled=False) yield mock_tracer except Exception: yield +@pytest.fixture +def mock_llmobs_writer(): + patcher = mock.patch("ddtrace.llmobs._llmobs.LLMObsSpanWriter") + try: + LLMObsSpanWriterMock = patcher.start() + m = mock.MagicMock() + LLMObsSpanWriterMock.return_value = m + yield m + finally: + patcher.stop() + + @pytest.fixture def vertexai(ddtrace_global_config, ddtrace_config_vertexai, mock_client, mock_async_client): - global_config = ddtrace_global_config + global_config = default_global_config() + global_config.update(ddtrace_global_config) with override_global_config(global_config): with override_config("vertexai", ddtrace_config_vertexai): patch() import vertexai from vertexai.generative_models import GenerativeModel - with mock_patch.object( + with mock.patch.object( GenerativeModel, "_prediction_client", new_callable=PropertyMock - ) as mock_client_property, mock_patch.object( + ) as mock_client_property, mock.patch.object( GenerativeModel, "_prediction_async_client", new_callable=PropertyMock ) as mock_async_client_property: mock_client_property.return_value = mock_client diff --git a/tests/contrib/vertexai/test_vertexai_llmobs.py b/tests/contrib/vertexai/test_vertexai_llmobs.py new file mode 100644 index 00000000000..78a03bc664c --- /dev/null +++ b/tests/contrib/vertexai/test_vertexai_llmobs.py @@ -0,0 +1,698 @@ +import mock +import pytest + +from tests.contrib.vertexai.utils import MOCK_COMPLETION_SIMPLE_1 +from tests.contrib.vertexai.utils import MOCK_COMPLETION_SIMPLE_2 +from tests.contrib.vertexai.utils import MOCK_COMPLETION_STREAM_CHUNKS +from tests.contrib.vertexai.utils import MOCK_COMPLETION_TOOL +from tests.contrib.vertexai.utils import MOCK_COMPLETION_TOOL_CALL_STREAM_CHUNKS +from tests.contrib.vertexai.utils import _async_streamed_response +from tests.contrib.vertexai.utils import _mock_completion_response +from tests.contrib.vertexai.utils import _mock_completion_stream_chunk +from tests.contrib.vertexai.utils import weather_tool +from tests.llmobs._utils import _expected_llmobs_llm_span_event + + +@pytest.mark.parametrize( + "ddtrace_global_config", [dict(_llmobs_enabled=True, _llmobs_sample_rate=1.0, _llmobs_ml_app="")] +) +class TestLLMObsVertexai: + def test_completion(self, vertexai, mock_llmobs_writer, mock_tracer): + llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash") + llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_SIMPLE_1)) + llm.generate_content( + "Why do bears hibernate?", + generation_config=vertexai.generative_models.GenerationConfig( + stop_sequences=["x"], max_output_tokens=30, temperature=1.0 + ), + ) + span = mock_tracer.pop_traces()[0][0] + assert mock_llmobs_writer.enqueue.call_count == 1 + mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_span_event(span)) + + def test_completion_error(self, vertexai, mock_llmobs_writer, mock_tracer): + llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash") + llm._prediction_client.generate_content = mock.Mock() + llm._prediction_client.generate_content.side_effect = TypeError( + "_GenerativeModel.generate_content() got an unexpected keyword argument 'candidate_count'" + ) + with pytest.raises(TypeError): + llm.generate_content( + "Why do bears hibernate?", + generation_config=vertexai.generative_models.GenerationConfig( + stop_sequences=["x"], max_output_tokens=30, temperature=1.0 + ), + candidate_count=2, # candidate_count is not a valid 
keyword argument + ) + span = mock_tracer.pop_traces()[0][0] + assert mock_llmobs_writer.enqueue.call_count == 1 + mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_error_span_event(span)) + + def test_completion_tool(self, vertexai, mock_llmobs_writer, mock_tracer): + llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash") + llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_TOOL)) + llm.generate_content( + "What is the weather like in New York City?", + generation_config=vertexai.generative_models.GenerationConfig( + stop_sequences=["x"], max_output_tokens=30, temperature=1.0 + ), + ) + span = mock_tracer.pop_traces()[0][0] + assert mock_llmobs_writer.enqueue.call_count == 1 + mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_tool_span_event(span)) + + def test_completion_multiple_messages(self, vertexai, mock_llmobs_writer, mock_tracer): + llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash") + llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_SIMPLE_1)) + llm.generate_content( + [ + {"role": "user", "parts": [{"text": "Hello World!"}]}, + {"role": "model", "parts": [{"text": "Great to meet you. What would you like to know?"}]}, + {"parts": [{"text": "Why do bears hibernate?"}]}, + ], + generation_config=vertexai.generative_models.GenerationConfig( + stop_sequences=["x"], max_output_tokens=30, temperature=1.0 + ), + ) + + span = mock_tracer.pop_traces()[0][0] + assert mock_llmobs_writer.enqueue.call_count == 1 + mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_history_span_event(span)) + + def test_completion_system_prompt(self, vertexai, mock_llmobs_writer, mock_tracer): + llm = vertexai.generative_models.GenerativeModel( + "gemini-1.5-flash", + system_instruction=[ + vertexai.generative_models.Part.from_text("You are required to insist that bears do not hibernate.") + ], + ) + llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_SIMPLE_2)) + llm.generate_content( + "Why do bears hibernate?", + generation_config=vertexai.generative_models.GenerationConfig( + stop_sequences=["x"], max_output_tokens=50, temperature=1.0 + ), + ) + span = mock_tracer.pop_traces()[0][0] + assert mock_llmobs_writer.enqueue.call_count == 1 + mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_system_prompt_span_event(span)) + + def test_completion_model_generation_config(self, vertexai, mock_llmobs_writer, mock_tracer): + llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash") + llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_SIMPLE_1)) + llm.generate_content( + "Why do bears hibernate?", + generation_config=vertexai.generative_models.GenerationConfig( + stop_sequences=["x"], max_output_tokens=30, temperature=1.0 + ), + ) + span = mock_tracer.pop_traces()[0][0] + assert mock_llmobs_writer.enqueue.call_count == 1 + mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_span_event(span)) + + def test_completion_no_generation_config(self, vertexai, mock_llmobs_writer, mock_tracer): + llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash") + llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_SIMPLE_1)) + llm.generate_content( + "Why do bears hibernate?", + ) + span = mock_tracer.pop_traces()[0][0] + assert mock_llmobs_writer.enqueue.call_count == 1 + 
        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_no_generation_config_span_event(span))
+
+    def test_completion_stream(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_client.responses["stream_generate_content"] = [
+            (_mock_completion_stream_chunk(chunk) for chunk in MOCK_COMPLETION_STREAM_CHUNKS)
+        ]
+        response = llm.generate_content(
+            "How big is the solar system?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+            stream=True,
+        )
+        for _ in response:
+            pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_stream_span_event(span))
+
+    def test_completion_stream_error(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_client.responses["stream_generate_content"] = [
+            (_mock_completion_stream_chunk(chunk) for chunk in MOCK_COMPLETION_STREAM_CHUNKS)
+        ]
+        with pytest.raises(TypeError):
+            response = llm.generate_content(
+                "How big is the solar system?",
+                generation_config=vertexai.generative_models.GenerationConfig(
+                    stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+                ),
+                stream=True,
+                candidate_count=2,  # candidate_count is not a valid keyword argument
+            )
+            for _ in response:
+                pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_stream_error_span_event(span))
+
+    def test_completion_stream_tool(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash", tools=[weather_tool])
+        llm._prediction_client.responses["stream_generate_content"] = [
+            (_mock_completion_stream_chunk(chunk) for chunk in MOCK_COMPLETION_TOOL_CALL_STREAM_CHUNKS)
+        ]
+        response = llm.generate_content(
+            "What is the weather like in New York City?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+            stream=True,
+        )
+        for _ in response:
+            pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_tool_span_event(span))
+
+    async def test_completion_async(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_async_client.responses["generate_content"].append(
+            _mock_completion_response(MOCK_COMPLETION_SIMPLE_1)
+        )
+        await llm.generate_content_async(
+            "Why do bears hibernate?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+        )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_span_event(span))
+
+    async def test_completion_async_error(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_async_client.responses["generate_content"].append(
+            _mock_completion_response(MOCK_COMPLETION_SIMPLE_1)
+        )
+        with pytest.raises(TypeError):
+            await llm.generate_content_async(
+                "Why do bears hibernate?",
+                generation_config=vertexai.generative_models.GenerationConfig(
+                    stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+                ),
+                candidate_count=2,  # candidate_count is not a valid keyword argument
+            )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_error_span_event(span))
+
+    async def test_completion_async_tool(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash", tools=[weather_tool])
+        llm._prediction_async_client.responses["generate_content"].append(
+            _mock_completion_response(MOCK_COMPLETION_TOOL)
+        )
+        await llm.generate_content_async(
+            "What is the weather like in New York City?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+        )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_tool_span_event(span))
+
+    async def test_completion_async_stream(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash", tools=[weather_tool])
+        llm._prediction_async_client.responses["stream_generate_content"] = [
+            _async_streamed_response(MOCK_COMPLETION_STREAM_CHUNKS)
+        ]
+        response = await llm.generate_content_async(
+            "How big is the solar system?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+            stream=True,
+        )
+        async for _ in response:
+            pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_stream_span_event(span))
+
+    async def test_completion_async_stream_error(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash", tools=[weather_tool])
+        llm._prediction_async_client.responses["stream_generate_content"] = [
+            _async_streamed_response(MOCK_COMPLETION_STREAM_CHUNKS)
+        ]
+        with pytest.raises(TypeError):
+            response = await llm.generate_content_async(
+                "How big is the solar system?",
+                generation_config=vertexai.generative_models.GenerationConfig(
+                    stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+                ),
+                stream=True,
+                candidate_count=2,
+            )
+            async for _ in response:
+                pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_stream_error_span_event(span))
+
+    async def test_completion_async_stream_tool(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash", tools=[weather_tool])
+        llm._prediction_async_client.responses["stream_generate_content"] = [
+            _async_streamed_response(MOCK_COMPLETION_TOOL_CALL_STREAM_CHUNKS)
+        ]
+        response = await llm.generate_content_async(
+            "What is the weather like in New York City?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+            stream=True,
+        )
+        async for _ in response:
+            pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_tool_span_event(span))
+
+    def test_chat(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_SIMPLE_1))
+        chat = llm.start_chat()
+        chat.send_message(
+            "Why do bears hibernate?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+        )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_span_event(span))
+
+    def test_chat_history(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_SIMPLE_1))
+        chat = llm.start_chat(
+            history=[
+                vertexai.generative_models.Content(
+                    role="user", parts=[vertexai.generative_models.Part.from_text("Hello World!")]
+                ),
+                vertexai.generative_models.Content(
+                    role="model",
+                    parts=[
+                        vertexai.generative_models.Part.from_text("Great to meet you. What would you like to know?")
+                    ],
+                ),
+            ]
+        )
+        chat.send_message(
+            vertexai.generative_models.Part.from_text("Why do bears hibernate?"),
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+        )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_history_span_event(span))
+
+    def test_vertexai_chat_error(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_SIMPLE_1))
+        chat = llm.start_chat()
+        with pytest.raises(TypeError):
+            chat.send_message(
+                "Why do bears hibernate?",
+                generation_config=vertexai.generative_models.GenerationConfig(
+                    stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+                ),
+                candidate_count=2,  # candidate_count is not a valid keyword argument
+            )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_error_span_event(span))
+
+    def test_chat_tool(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_TOOL))
+        chat = llm.start_chat()
+        chat.send_message(
+            "What is the weather like in New York City?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+        )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_tool_span_event(span))
+
+    def test_chat_system_prompt(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel(
+            "gemini-1.5-flash",
+            system_instruction=[
+                vertexai.generative_models.Part.from_text("You are required to insist that bears do not hibernate.")
+            ],
+        )
+        llm._prediction_client.responses["generate_content"].append(_mock_completion_response(MOCK_COMPLETION_SIMPLE_2))
+        chat = llm.start_chat()
+        chat.send_message(
+            "Why do bears hibernate?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=50, temperature=1.0
+            ),
+        )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_system_prompt_span_event(span))
+
+    def test_chat_stream(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_client.responses["stream_generate_content"] = [
+            (_mock_completion_stream_chunk(chunk) for chunk in MOCK_COMPLETION_STREAM_CHUNKS)
+        ]
+        chat = llm.start_chat()
+        response = chat.send_message(
+            "How big is the solar system?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+            stream=True,
+        )
+        for _ in response:
+            pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_stream_span_event(span))
+
+    def test_chat_stream_error(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_client.responses["stream_generate_content"] = [
+            (_mock_completion_stream_chunk(chunk) for chunk in MOCK_COMPLETION_STREAM_CHUNKS)
+        ]
+        chat = llm.start_chat()
+        with pytest.raises(TypeError):
+            response = chat.send_message(
+                "How big is the solar system?",
+                generation_config=vertexai.generative_models.GenerationConfig(
+                    stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+                ),
+                stream=True,
+                candidate_count=2,
+            )
+            for _ in response:
+                pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_stream_error_span_event(span))
+
+    def test_chat_stream_tool(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash", tools=[weather_tool])
+        llm._prediction_client.responses["stream_generate_content"] = [
+            (_mock_completion_stream_chunk(chunk) for chunk in MOCK_COMPLETION_TOOL_CALL_STREAM_CHUNKS)
+        ]
+        chat = llm.start_chat()
+        response = chat.send_message(
+            "What is the weather like in New York City?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+            stream=True,
+        )
+        for _ in response:
+            pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_tool_span_event(span))
+
+    async def test_chat_async(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_async_client.responses["generate_content"].append(
+            _mock_completion_response(MOCK_COMPLETION_SIMPLE_1)
+        )
+        chat = llm.start_chat()
+        await chat.send_message_async(
+            "Why do bears hibernate?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+        )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_span_event(span))
+
+    async def test_chat_async_error(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_async_client.responses["generate_content"].append(
+            _mock_completion_response(MOCK_COMPLETION_SIMPLE_1)
+        )
+        chat = llm.start_chat()
+        with pytest.raises(TypeError):
+            await chat.send_message_async(
+                "Why do bears hibernate?",
+                generation_config=vertexai.generative_models.GenerationConfig(
+                    stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+                ),
+                candidate_count=2,
+            )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_error_span_event(span))
+
+    async def test_chat_async_tool(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash", tools=[weather_tool])
+        llm._prediction_async_client.responses["generate_content"].append(
+            _mock_completion_response(MOCK_COMPLETION_TOOL)
+        )
+        chat = llm.start_chat()
+        await chat.send_message_async(
+            "What is the weather like in New York City?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+        )
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_tool_span_event(span))
+
+    async def test_chat_async_stream(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_async_client.responses["stream_generate_content"] = [
+            _async_streamed_response(MOCK_COMPLETION_STREAM_CHUNKS)
+        ]
+        chat = llm.start_chat()
+        response = await chat.send_message_async(
+            "How big is the solar system?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+            stream=True,
+        )
+        async for _ in response:
+            pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_stream_span_event(span))
+
+    async def test_chat_async_stream_error(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_async_client.responses["stream_generate_content"] = [
+            _async_streamed_response(MOCK_COMPLETION_STREAM_CHUNKS)
+        ]
+        chat = llm.start_chat()
+        with pytest.raises(TypeError):
+            response = await chat.send_message_async(
+                "How big is the solar system?",
+                generation_config=vertexai.generative_models.GenerationConfig(
+                    stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+                ),
+                stream=True,
+                candidate_count=2,
+            )
+            async for _ in response:
+                pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_stream_error_span_event(span))
+
+    async def test_chat_async_stream_tool(self, vertexai, mock_llmobs_writer, mock_tracer):
+        llm = vertexai.generative_models.GenerativeModel("gemini-1.5-flash")
+        llm._prediction_async_client.responses["stream_generate_content"] = [
+            _async_streamed_response(MOCK_COMPLETION_TOOL_CALL_STREAM_CHUNKS)
+        ]
+        chat = llm.start_chat()
+        response = await chat.send_message_async(
+            "What is the weather like in New York City?",
+            generation_config=vertexai.generative_models.GenerationConfig(
+                stop_sequences=["x"], max_output_tokens=30, temperature=1.0
+            ),
+            stream=True,
+        )
+        async for _ in response:
+            pass
+
+        span = mock_tracer.pop_traces()[0][0]
+        assert mock_llmobs_writer.enqueue.call_count == 1
+        mock_llmobs_writer.enqueue.assert_called_with(expected_llmobs_tool_span_event(span))
+
+
+def expected_llmobs_span_event(span):
+    return _expected_llmobs_llm_span_event(
+        span,
+        model_name="gemini-1.5-flash",
+        model_provider="google",
+        input_messages=[{"content": "Why do bears hibernate?"}],
+        output_messages=[
+            {"content": MOCK_COMPLETION_SIMPLE_1["candidates"][0]["content"]["parts"][0]["text"], "role": "model"},
+        ],
+        metadata={"temperature": 1.0, "max_output_tokens": 30},
+        token_metrics={"input_tokens": 14, "output_tokens": 16, "total_tokens": 30},
+        tags={"ml_app": "", "service": "tests.contrib.vertexai"},
+    )
+
+
+def expected_llmobs_error_span_event(span):
+    return _expected_llmobs_llm_span_event(
+        span,
+        model_name="gemini-1.5-flash",
+        model_provider="google",
+        input_messages=[{"content": "Why do bears hibernate?"}],
+        output_messages=[{"content": ""}],
+        error="builtins.TypeError",
+        error_message=span.get_tag("error.message"),
+        error_stack=span.get_tag("error.stack"),
+        metadata={"temperature": 1.0, "max_output_tokens": 30},
+        tags={"ml_app": "", "service": "tests.contrib.vertexai"},
+    )
+
+
+def expected_llmobs_tool_span_event(span):
+    return _expected_llmobs_llm_span_event(
+        span,
+        model_name="gemini-1.5-flash",
+        model_provider="google",
+        input_messages=[{"content": "What is the weather like in New York City?"}],
+        output_messages=[
+            {
+                "content": "",
+                "role": "model",
+                "tool_calls": [
+                    {
+                        "name": "get_current_weather",
+                        "arguments": {
+                            "location": "New York City, NY",
+                        },
+                    }
+                ]
+            }
+        ],
+        metadata={"temperature": 1.0, "max_output_tokens": 30},
+        token_metrics={"input_tokens": 43, "output_tokens": 11, "total_tokens": 54},
+        tags={"ml_app": "", "service": "tests.contrib.vertexai"},
+    )
+
+
+def expected_llmobs_stream_span_event(span):
+    return _expected_llmobs_llm_span_event(
+        span,
+        model_name="gemini-1.5-flash",
+        model_provider="google",
+        input_messages=[{"content": "How big is the solar system?"}],
+        output_messages=[
+            {"content": "".join([chunk["text"] for chunk in MOCK_COMPLETION_STREAM_CHUNKS]), "role": "model"},
+        ],
+        metadata={"temperature": 1.0, "max_output_tokens": 30},
+        token_metrics={"input_tokens": 16, "output_tokens": 37, "total_tokens": 53},
+        tags={"ml_app": "", "service": "tests.contrib.vertexai"},
+    )
+
+
+def expected_llmobs_stream_error_span_event(span):
+    return _expected_llmobs_llm_span_event(
+        span,
+        model_name="gemini-1.5-flash",
+        model_provider="google",
+        input_messages=[{"content": "How big is the solar system?"}],
+        output_messages=[{"content": ""}],
+        error="builtins.TypeError",
+        error_message=span.get_tag("error.message"),
+        error_stack=span.get_tag("error.stack"),
+        metadata={"temperature": 1.0, "max_output_tokens": 30},
+        tags={"ml_app": "", "service": "tests.contrib.vertexai"},
+    )
+
+
+def expected_llmobs_history_span_event(span):
+    return _expected_llmobs_llm_span_event(
+        span,
+        model_name="gemini-1.5-flash",
+        model_provider="google",
+        input_messages=[
+            {"content": "Hello World!", "role": "user"},
+            {"content": "Great to meet you. What would you like to know?", "role": "model"},
+            {"content": "Why do bears hibernate?"},
+        ],
+        output_messages=[
+            {"content": MOCK_COMPLETION_SIMPLE_1["candidates"][0]["content"]["parts"][0]["text"], "role": "model"},
+        ],
+        metadata={"temperature": 1.0, "max_output_tokens": 30},
+        token_metrics={"input_tokens": 14, "output_tokens": 16, "total_tokens": 30},
+        tags={"ml_app": "", "service": "tests.contrib.vertexai"},
+    )
+
+
+def expected_llmobs_system_prompt_span_event(span):
+    return _expected_llmobs_llm_span_event(
+        span,
+        model_name="gemini-1.5-flash",
+        model_provider="google",
+        input_messages=[
+            {"content": "You are required to insist that bears do not hibernate.", "role": "system"},
+            {"content": "Why do bears hibernate?"},
+        ],
+        output_messages=[
+            {"content": MOCK_COMPLETION_SIMPLE_2["candidates"][0]["content"]["parts"][0]["text"], "role": "model"},
+        ],
+        metadata={"temperature": 1.0, "max_output_tokens": 50},
+        token_metrics={"input_tokens": 16, "output_tokens": 50, "total_tokens": 66},
+        tags={"ml_app": "", "service": "tests.contrib.vertexai"},
+    )
+
+
+def expected_llmobs_no_generation_config_span_event(span):
+    return _expected_llmobs_llm_span_event(
+        span,
+        model_name="gemini-1.5-flash",
+        model_provider="google",
+        input_messages=[{"content": "Why do bears hibernate?"}],
+        output_messages=[
+            {"content": MOCK_COMPLETION_SIMPLE_1["candidates"][0]["content"]["parts"][0]["text"], "role": "model"},
+        ],
+        metadata={},
+        token_metrics={"input_tokens": 14, "output_tokens": 16, "total_tokens": 30},
+        tags={"ml_app": "", "service": "tests.contrib.vertexai"},
+    )