From d06fe7be123aab6d350508eb5be0c593d95ad333 Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Wed, 26 Jun 2024 15:45:01 +0200 Subject: [PATCH 01/14] MINOTAUR-1124 | Make it possible to run allms with python 3.8 --- allms/models/__init__.py | 4 +- allms/models/abstract.py | 4 +- poetry.lock | 102 +++++++++++++++++++++++---------------- pyproject.toml | 2 +- tests/conftest.py | 15 +++--- 5 files changed, 73 insertions(+), 54 deletions(-) diff --git a/allms/models/__init__.py b/allms/models/__init__.py index 9780f1b..3087552 100644 --- a/allms/models/__init__.py +++ b/allms/models/__init__.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Dict, Type from allms.domain.enumerables import AvailableModels from allms.models.abstract import AbstractModel @@ -20,7 +20,7 @@ ] -def get_available_models() -> dict[str, Type[AbstractModel]]: +def get_available_models() -> Dict[str, Type[AbstractModel]]: return { AvailableModels.AZURE_OPENAI_MODEL: AzureOpenAIModel, AvailableModels.AZURE_LLAMA2_MODEL: AzureLlama2Model, diff --git a/allms/models/abstract.py b/allms/models/abstract.py index 376bd6c..31f8547 100644 --- a/allms/models/abstract.py +++ b/allms/models/abstract.py @@ -167,7 +167,7 @@ async def _build_chat_prompts( self, prompt_template_args: dict, system_prompt: SystemMessagePromptTemplate - ) -> list[SystemMessagePromptTemplate | HumanMessagePromptTemplate]: + ) -> typing.List[typing.Union[SystemMessagePromptTemplate, HumanMessagePromptTemplate]]: human_message = HumanMessagePromptTemplate(prompt=PromptTemplate(**prompt_template_args)) if not system_prompt: return [human_message] @@ -330,7 +330,7 @@ def _validate_system_prompt(self, system_prompt: typing.Optional[str] = None) -> raise ValueError(input_exception_message.get_system_prompt_contains_input_variables()) @staticmethod - def _extract_input_variables_from_prompt(prompt: str) -> set[str]: + def _extract_input_variables_from_prompt(prompt: str) -> typing.Set[str]: input_variables_pattern = r'(?=4.6", markers = "python_version < \"3.10\""} packaging = ">=19.0" pyproject_hooks = "*" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} @@ -1218,6 +1219,24 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link perf = ["ipython"] testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +[[package]] +name = "importlib-resources" +version = "6.4.0" +description = "Read resources from Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "importlib_resources-6.4.0-py3-none-any.whl", hash = "sha256:50d10f043df931902d4194ea07ec57960f66a80449ff867bfe782b4c486ba78c"}, + {file = "importlib_resources-6.4.0.tar.gz", hash = "sha256:cdb2b453b8046ca4e3798eb1d84f3cce1446a0e8e7b5ef4efb600f19fc398145"}, +] + +[package.dependencies] +zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"] +testing = ["jaraco.test (>=5.4)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)", "zipp (>=3.17)"] + [[package]] name = "iniconfig" version = "2.0.0" @@ -1331,6 +1350,7 @@ files = [ [package.dependencies] importlib-metadata = {version = ">=4.11.4", markers = "python_version < \"3.12\""} +importlib-resources = {version = "*", markers = "python_version < \"3.9\""} "jaraco.classes" = "*" jeepney = {version = ">=0.4.2", markers = "sys_platform == \"linux\""} pywin32-ctypes = {version = ">=0.2.0", markers = "sys_platform == \"win32\""} @@ -1503,6 +1523,9 @@ files = [ {file = "Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8"}, ] +[package.dependencies] +importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} + [package.extras] docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] testing = ["coverage", "pyyaml"] @@ -1668,6 +1691,7 @@ files = [ click = ">=7.0" colorama = {version = ">=0.4", markers = "platform_system == \"Windows\""} ghp-import = ">=1.0" +importlib-metadata = {version = ">=4.3", markers = "python_version < \"3.10\""} jinja2 = ">=2.11.1" markdown = ">=3.2.1" markupsafe = ">=2.0.1" @@ -1815,47 +1839,39 @@ files = [ [[package]] name = "numpy" -version = "1.26.3" +version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false -python-versions = ">=3.9" -files = [ - {file = "numpy-1.26.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:806dd64230dbbfaca8a27faa64e2f414bf1c6622ab78cc4264f7f5f028fee3bf"}, - {file = "numpy-1.26.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02f98011ba4ab17f46f80f7f8f1c291ee7d855fcef0a5a98db80767a468c85cd"}, - {file = "numpy-1.26.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d45b3ec2faed4baca41c76617fcdcfa4f684ff7a151ce6fc78ad3b6e85af0a6"}, - {file = "numpy-1.26.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdd2b45bf079d9ad90377048e2747a0c82351989a2165821f0c96831b4a2a54b"}, - {file = "numpy-1.26.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:211ddd1e94817ed2d175b60b6374120244a4dd2287f4ece45d49228b4d529178"}, - {file = "numpy-1.26.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b1240f767f69d7c4c8a29adde2310b871153df9b26b5cb2b54a561ac85146485"}, - {file = "numpy-1.26.3-cp310-cp310-win32.whl", hash = "sha256:21a9484e75ad018974a2fdaa216524d64ed4212e418e0a551a2d83403b0531d3"}, - {file = "numpy-1.26.3-cp310-cp310-win_amd64.whl", hash = "sha256:9e1591f6ae98bcfac2a4bbf9221c0b92ab49762228f38287f6eeb5f3f55905ce"}, - {file = "numpy-1.26.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b831295e5472954104ecb46cd98c08b98b49c69fdb7040483aff799a755a7374"}, - {file = "numpy-1.26.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9e87562b91f68dd8b1c39149d0323b42e0082db7ddb8e934ab4c292094d575d6"}, - {file = "numpy-1.26.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c66d6fec467e8c0f975818c1796d25c53521124b7cfb760114be0abad53a0a2"}, - {file = "numpy-1.26.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f25e2811a9c932e43943a2615e65fc487a0b6b49218899e62e426e7f0a57eeda"}, - {file = "numpy-1.26.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:af36e0aa45e25c9f57bf684b1175e59ea05d9a7d3e8e87b7ae1a1da246f2767e"}, - {file = "numpy-1.26.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:51c7f1b344f302067b02e0f5b5d2daa9ed4a721cf49f070280ac202738ea7f00"}, - {file = "numpy-1.26.3-cp311-cp311-win32.whl", hash = "sha256:7ca4f24341df071877849eb2034948459ce3a07915c2734f1abb4018d9c49d7b"}, - {file = "numpy-1.26.3-cp311-cp311-win_amd64.whl", hash = "sha256:39763aee6dfdd4878032361b30b2b12593fb445ddb66bbac802e2113eb8a6ac4"}, - {file = "numpy-1.26.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a7081fd19a6d573e1a05e600c82a1c421011db7935ed0d5c483e9dd96b99cf13"}, - {file = "numpy-1.26.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12c70ac274b32bc00c7f61b515126c9205323703abb99cd41836e8125ea0043e"}, - {file = "numpy-1.26.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f784e13e598e9594750b2ef6729bcd5a47f6cfe4a12cca13def35e06d8163e3"}, - {file = "numpy-1.26.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f24750ef94d56ce6e33e4019a8a4d68cfdb1ef661a52cdaee628a56d2437419"}, - {file = "numpy-1.26.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:77810ef29e0fb1d289d225cabb9ee6cf4d11978a00bb99f7f8ec2132a84e0166"}, - {file = "numpy-1.26.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8ed07a90f5450d99dad60d3799f9c03c6566709bd53b497eb9ccad9a55867f36"}, - {file = "numpy-1.26.3-cp312-cp312-win32.whl", hash = "sha256:f73497e8c38295aaa4741bdfa4fda1a5aedda5473074369eca10626835445511"}, - {file = "numpy-1.26.3-cp312-cp312-win_amd64.whl", hash = "sha256:da4b0c6c699a0ad73c810736303f7fbae483bcb012e38d7eb06a5e3b432c981b"}, - {file = "numpy-1.26.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1666f634cb3c80ccbd77ec97bc17337718f56d6658acf5d3b906ca03e90ce87f"}, - {file = "numpy-1.26.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:18c3319a7d39b2c6a9e3bb75aab2304ab79a811ac0168a671a62e6346c29b03f"}, - {file = "numpy-1.26.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b7e807d6888da0db6e7e75838444d62495e2b588b99e90dd80c3459594e857b"}, - {file = "numpy-1.26.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4d362e17bcb0011738c2d83e0a65ea8ce627057b2fdda37678f4374a382a137"}, - {file = "numpy-1.26.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b8c275f0ae90069496068c714387b4a0eba5d531aace269559ff2b43655edd58"}, - {file = "numpy-1.26.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cc0743f0302b94f397a4a65a660d4cd24267439eb16493fb3caad2e4389bccbb"}, - {file = "numpy-1.26.3-cp39-cp39-win32.whl", hash = "sha256:9bc6d1a7f8cedd519c4b7b1156d98e051b726bf160715b769106661d567b3f03"}, - {file = "numpy-1.26.3-cp39-cp39-win_amd64.whl", hash = "sha256:867e3644e208c8922a3be26fc6bbf112a035f50f0a86497f98f228c50c607bb2"}, - {file = "numpy-1.26.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3c67423b3703f8fbd90f5adaa37f85b5794d3366948efe9a5190a5f3a83fc34e"}, - {file = "numpy-1.26.3-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46f47ee566d98849323f01b349d58f2557f02167ee301e5e28809a8c0e27a2d0"}, - {file = "numpy-1.26.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a8474703bffc65ca15853d5fd4d06b18138ae90c17c8d12169968e998e448bb5"}, - {file = "numpy-1.26.3.tar.gz", hash = "sha256:697df43e2b6310ecc9d95f05d5ef20eacc09c7c4ecc9da3f235d39e71b7da1e4"}, +python-versions = ">=3.8" +files = [ + {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, + {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, + {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, + {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, + {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, + {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, + {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, + {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, + {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, + {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, + {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, + {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, + {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, + {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, + {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, + {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, + {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, + {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, + {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] [[package]] @@ -2126,6 +2142,7 @@ mccabe = ">=0.6,<0.8" platformdirs = ">=2.2.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} tomlkit = ">=0.10.1" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} [package.extras] spelling = ["pyenchant (>=3.2,<4.0)"] @@ -2466,6 +2483,7 @@ files = [ [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" +typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] @@ -3421,5 +3439,5 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" -python-versions = "^3.10" -content-hash = "7915b11fb574bdc236e238e884d7cbc0e43849b0750577f6d1aedd69b00162f6" +python-versions = ">=3.8.1,<4.0" +content-hash = "161eceb20f8c1a1a95987444d1e10f175134ee645aa46d5c60c683b930635c7a" diff --git a/pyproject.toml b/pyproject.toml index d6d97d8..a785866 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ readme = "README.md" packages = [{include = "allms"}] [tool.poetry.dependencies] -python = "^3.10" +python = ">=3.8.1,<4.0" fsspec = "^2023.6.0" google-cloud-aiplatform = "1.38.0" pydash = "^7.0.6" diff --git a/tests/conftest.py b/tests/conftest.py index e28fea9..47e7ad2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import asyncio import typing +from contextlib import ExitStack from dataclasses import dataclass from unittest.mock import patch @@ -36,13 +37,13 @@ def __init__(self, *args, **kwargs): def models(): event_loop = asyncio.new_event_loop() - with ( - patch("allms.models.vertexai_palm.CustomVertexAI", ModelWithoutAsyncRequestsMock), - patch("allms.models.vertexai_gemini.CustomVertexAI", ModelWithoutAsyncRequestsMock), - patch("allms.models.vertexai_gemma.VertexAIModelGardenWrapper", ModelWithoutAsyncRequestsMock), - patch("allms.models.azure_llama2.AzureMLOnlineEndpointAsync", ModelWithoutAsyncRequestsMock), - patch("allms.models.azure_mistral.AzureMLOnlineEndpointAsync", ModelWithoutAsyncRequestsMock) - ): + with ExitStack() as stack: + stack.enter_context(patch("allms.models.vertexai_palm.CustomVertexAI", ModelWithoutAsyncRequestsMock)) + stack.enter_context(patch("allms.models.vertexai_gemini.CustomVertexAI", ModelWithoutAsyncRequestsMock)) + stack.enter_context(patch("allms.models.vertexai_gemma.VertexAIModelGardenWrapper", ModelWithoutAsyncRequestsMock)) + stack.enter_context(patch("allms.models.azure_llama2.AzureMLOnlineEndpointAsync", ModelWithoutAsyncRequestsMock)) + stack.enter_context(patch("allms.models.azure_mistral.AzureMLOnlineEndpointAsync", ModelWithoutAsyncRequestsMock)) + return { "azure_open_ai": AzureOpenAIModel( config=AzureOpenAIConfiguration( From f8bdebc59d52ca382fb59a7c33e18b621232ae0b Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Wed, 26 Jun 2024 15:46:49 +0200 Subject: [PATCH 02/14] MINOTAUR-1124 | Update version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a785866..d6d997c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "allms" -version = "1.0.4" +version = "1.0.5" description = "" authors = ["Allegro Opensource "] readme = "README.md" From 95f0801abea7ce1d67e46494b2431f58cab94e30 Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Fri, 28 Jun 2024 14:34:49 +0200 Subject: [PATCH 03/14] Update langchain --- poetry.lock | 155 ++++++++++++++++++++++++++++++------------------- pyproject.toml | 4 +- 2 files changed, 98 insertions(+), 61 deletions(-) diff --git a/poetry.lock b/poetry.lock index 722ec4b..65141d7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -124,28 +124,6 @@ files = [ [package.dependencies] frozenlist = ">=1.1.0" -[[package]] -name = "anyio" -version = "4.2.0" -description = "High level compatibility layer for multiple asynchronous event loop implementations" -optional = false -python-versions = ">=3.8" -files = [ - {file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"}, - {file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"}, -] - -[package.dependencies] -exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} -idna = ">=2.8" -sniffio = ">=1.1" -typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} - -[package.extras] -doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] -trio = ["trio (>=0.23)"] - [[package]] name = "astroid" version = "2.15.8" @@ -1363,13 +1341,13 @@ testing = ["pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", [[package]] name = "langchain" -version = "0.0.351" +version = "0.1.8" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langchain-0.0.351-py3-none-any.whl", hash = "sha256:90cdaee27db2b2aeeb7b0709a79cbfe3e858fc9536b6bc3ea262135a6affc70f"}, - {file = "langchain-0.0.351.tar.gz", hash = "sha256:6bf2a8665a7a3ca2bbd4eea9889ecfd3d39ab23a505549a03860272474399b38"}, + {file = "langchain-0.1.8-py3-none-any.whl", hash = "sha256:19e951b0e2be099ff048ee483acecb47e1a39c33a47dadfee70fcfa20f45cc19"}, + {file = "langchain-0.1.8.tar.gz", hash = "sha256:c8b1c2954a07cd6422c9027459473bafae90c78f07015bf2fc6262fadf97ea44"}, ] [package.dependencies] @@ -1377,9 +1355,9 @@ aiohttp = ">=3.8.3,<4.0.0" async-timeout = {version = ">=4.0.0,<5.0.0", markers = "python_version < \"3.11\""} dataclasses-json = ">=0.5.7,<0.7" jsonpatch = ">=1.33,<2.0" -langchain-community = ">=0.0.2,<0.1" -langchain-core = ">=0.1,<0.2" -langsmith = ">=0.0.70,<0.1.0" +langchain-community = ">=0.0.21,<0.1" +langchain-core = ">=0.1.24,<0.2" +langsmith = ">=0.1.0,<0.2.0" numpy = ">=1,<2" pydantic = ">=1,<3" PyYAML = ">=5.3" @@ -1394,7 +1372,7 @@ cli = ["typer (>=0.9.0,<0.10.0)"] cohere = ["cohere (>=4,<5)"] docarray = ["docarray[hnswlib] (>=0.32.0,<0.33.0)"] embeddings = ["sentence-transformers (>=2,<3)"] -extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "couchbase (>=4.1.9,<5.0.0)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "couchbase (>=4.1.9,<5.0.0)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "langchain-openai (>=0.0.2,<0.1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] javascript = ["esprima (>=4.0.1,<5.0.0)"] llms = ["clarifai (>=9.1.0)", "cohere (>=4,<5)", "huggingface_hub (>=0,<1)", "manifest-ml (>=0.0.1,<0.0.2)", "nlpcloud (>=1,<2)", "openai (<2)", "openlm (>=0.0.5,<0.0.6)", "torch (>=1,<3)", "transformers (>=4,<5)"] openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"] @@ -1403,20 +1381,20 @@ text-helpers = ["chardet (>=5.1.0,<6.0.0)"] [[package]] name = "langchain-community" -version = "0.0.16" +version = "0.0.38" description = "Community contributed LangChain integrations." optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_community-0.0.16-py3-none-any.whl", hash = "sha256:0f1dfc1a6205ce8d39931d3515974a208a9f69c16157c649f83490a7cc830b73"}, - {file = "langchain_community-0.0.16.tar.gz", hash = "sha256:c06512a93013a06fba7679cd5a1254ff8b927cddd2d1fbe0cc444bf7bbdf0b8c"}, + {file = "langchain_community-0.0.38-py3-none-any.whl", hash = "sha256:ecb48660a70a08c90229be46b0cc5f6bc9f38f2833ee44c57dfab9bf3a2c121a"}, + {file = "langchain_community-0.0.38.tar.gz", hash = "sha256:127fc4b75bc67b62fe827c66c02e715a730fef8fe69bd2023d466bab06b5810d"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" dataclasses-json = ">=0.5.7,<0.7" -langchain-core = ">=0.1.16,<0.2" -langsmith = ">=0.0.83,<0.1" +langchain-core = ">=0.1.52,<0.2.0" +langsmith = ">=0.1.0,<0.2.0" numpy = ">=1,<2" PyYAML = ">=5.3" requests = ">=2,<3" @@ -1425,27 +1403,25 @@ tenacity = ">=8.1.0,<9.0.0" [package.extras] cli = ["typer (>=0.9.0,<0.10.0)"] -extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)", "zhipuai (>=1.0.7,<2.0.0)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "azure-ai-documentintelligence (>=1.0.0b1,<2.0.0)", "azure-identity (>=1.15.0,<2.0.0)", "azure-search-documents (==11.4.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.6,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cloudpickle (>=2.0.0)", "cohere (>=4,<5)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "elasticsearch (>=8.12.0,<9.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "friendli-client (>=1.2.4,<2.0.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "gradientai (>=1.4.0,<2.0.0)", "hdbcli (>=2.19.21,<3.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "httpx (>=0.24.1,<0.25.0)", "httpx-sse (>=0.4.0,<0.5.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.3,<6.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "nvidia-riva-client (>=2.14.0,<3.0.0)", "oci (>=2.119.1,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "oracle-ads (>=2.9.1,<3.0.0)", "oracledb (>=2.2.0,<3.0.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "premai (>=0.3.25,<0.4.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pyjwt (>=2.8.0,<3.0.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "rdflib (==7.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "tidb-vector (>=0.0.3,<1.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "tree-sitter (>=0.20.2,<0.21.0)", "tree-sitter-languages (>=1.8.0,<2.0.0)", "upstash-redis (>=0.15.0,<0.16.0)", "vdms (>=0.0.20,<0.0.21)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] [[package]] name = "langchain-core" -version = "0.1.17" +version = "0.1.52" description = "Building applications with LLMs through composability" optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_core-0.1.17-py3-none-any.whl", hash = "sha256:026155cf97867bde410ab1834799ab4c5ba64c39380f2a4328bcf9c78623ca64"}, - {file = "langchain_core-0.1.17.tar.gz", hash = "sha256:59016e457cd6a1708d83a3a454acc97cf02c2a2c3af95626d13f83894fd4e777"}, + {file = "langchain_core-0.1.52-py3-none-any.whl", hash = "sha256:62566749c92e8a1181c255c788548dc16dbc319d896cd6b9c95dc17af9b2a6db"}, + {file = "langchain_core-0.1.52.tar.gz", hash = "sha256:084c3fc452f5a6966c28ab3ec5dbc8b8d26fc3f63378073928f4e29d90b6393f"}, ] [package.dependencies] -anyio = ">=3,<5" jsonpatch = ">=1.33,<2.0" -langsmith = ">=0.0.83,<0.1" +langsmith = ">=0.1.0,<0.2.0" packaging = ">=23.2,<24.0" pydantic = ">=1,<3" PyYAML = ">=5.3" -requests = ">=2,<3" tenacity = ">=8.1.0,<9.0.0" [package.extras] @@ -1453,19 +1429,36 @@ extended-testing = ["jinja2 (>=3,<4)"] [[package]] name = "langsmith" -version = "0.0.85" +version = "0.1.81" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.0.85-py3-none-any.whl", hash = "sha256:9d0ccbcda7b69c83828060603a51bb4319e43b8dc807fbd90b6355f8ec709500"}, - {file = "langsmith-0.0.85.tar.gz", hash = "sha256:fefc631fc30d836b54d4e3f99961c41aea497633898b8f09e305b6c7216c2c54"}, + {file = "langsmith-0.1.81-py3-none-any.whl", hash = "sha256:3251d823225eef23ee541980b9d9e506367eabbb7f985a086b5d09e8f78ba7e9"}, + {file = "langsmith-0.1.81.tar.gz", hash = "sha256:585ef3a2251380bd2843a664c9a28da4a7d28432e3ee8bcebf291ffb8e1f0af0"}, ] [package.dependencies] +orjson = ">=3.9.14,<4.0.0" pydantic = ">=1,<3" requests = ">=2,<3" +[[package]] +name = "langsmith" +version = "0.1.82" +description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langsmith-0.1.82-py3-none-any.whl", hash = "sha256:9b3653e7d316036b0c60bf0bc3e280662d660f485a4ebd8e5c9d84f9831ae79c"}, + {file = "langsmith-0.1.82.tar.gz", hash = "sha256:c02e2bbc488c10c13b52c69d271eb40bd38da078d37b6ae7ae04a18bd48140be"}, +] + +[package.dependencies] +orjson = ">=3.9.14,<4.0.0" +pydantic = {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""} +requests = ">=2,<3" + [[package]] name = "lazy-object-proxy" version = "1.10.0" @@ -1896,6 +1889,61 @@ dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-moc embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] +[[package]] +name = "orjson" +version = "3.10.5" +description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" +optional = false +python-versions = ">=3.8" +files = [ + {file = "orjson-3.10.5-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:545d493c1f560d5ccfc134803ceb8955a14c3fcb47bbb4b2fee0232646d0b932"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4324929c2dd917598212bfd554757feca3e5e0fa60da08be11b4aa8b90013c1"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c13ca5e2ddded0ce6a927ea5a9f27cae77eee4c75547b4297252cb20c4d30e6"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b6c8e30adfa52c025f042a87f450a6b9ea29649d828e0fec4858ed5e6caecf63"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:338fd4f071b242f26e9ca802f443edc588fa4ab60bfa81f38beaedf42eda226c"}, + {file = "orjson-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6970ed7a3126cfed873c5d21ece1cd5d6f83ca6c9afb71bbae21a0b034588d96"}, + {file = "orjson-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:235dadefb793ad12f7fa11e98a480db1f7c6469ff9e3da5e73c7809c700d746b"}, + {file = "orjson-3.10.5-cp310-none-win32.whl", hash = "sha256:be79e2393679eda6a590638abda16d167754393f5d0850dcbca2d0c3735cebe2"}, + {file = "orjson-3.10.5-cp310-none-win_amd64.whl", hash = "sha256:c4a65310ccb5c9910c47b078ba78e2787cb3878cdded1702ac3d0da71ddc5228"}, + {file = "orjson-3.10.5-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:cdf7365063e80899ae3a697def1277c17a7df7ccfc979990a403dfe77bb54d40"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b68742c469745d0e6ca5724506858f75e2f1e5b59a4315861f9e2b1df77775a"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d10cc1b594951522e35a3463da19e899abe6ca95f3c84c69e9e901e0bd93d38"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dcbe82b35d1ac43b0d84072408330fd3295c2896973112d495e7234f7e3da2e1"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10c0eb7e0c75e1e486c7563fe231b40fdd658a035ae125c6ba651ca3b07936f5"}, + {file = "orjson-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:53ed1c879b10de56f35daf06dbc4a0d9a5db98f6ee853c2dbd3ee9d13e6f302f"}, + {file = "orjson-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:099e81a5975237fda3100f918839af95f42f981447ba8f47adb7b6a3cdb078fa"}, + {file = "orjson-3.10.5-cp311-none-win32.whl", hash = "sha256:1146bf85ea37ac421594107195db8bc77104f74bc83e8ee21a2e58596bfb2f04"}, + {file = "orjson-3.10.5-cp311-none-win_amd64.whl", hash = "sha256:36a10f43c5f3a55c2f680efe07aa93ef4a342d2960dd2b1b7ea2dd764fe4a37c"}, + {file = "orjson-3.10.5-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:68f85ecae7af14a585a563ac741b0547a3f291de81cd1e20903e79f25170458f"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28afa96f496474ce60d3340fe8d9a263aa93ea01201cd2bad844c45cd21f5268"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cd684927af3e11b6e754df80b9ffafd9fb6adcaa9d3e8fdd5891be5a5cad51e"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d21b9983da032505f7050795e98b5d9eee0df903258951566ecc358f6696969"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ad1de7fef79736dde8c3554e75361ec351158a906d747bd901a52a5c9c8d24b"}, + {file = "orjson-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d97531cdfe9bdd76d492e69800afd97e5930cb0da6a825646667b2c6c6c0211"}, + {file = "orjson-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d69858c32f09c3e1ce44b617b3ebba1aba030e777000ebdf72b0d8e365d0b2b3"}, + {file = "orjson-3.10.5-cp312-none-win32.whl", hash = "sha256:64c9cc089f127e5875901ac05e5c25aa13cfa5dbbbd9602bda51e5c611d6e3e2"}, + {file = "orjson-3.10.5-cp312-none-win_amd64.whl", hash = "sha256:b2efbd67feff8c1f7728937c0d7f6ca8c25ec81373dc8db4ef394c1d93d13dc5"}, + {file = "orjson-3.10.5-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:03b565c3b93f5d6e001db48b747d31ea3819b89abf041ee10ac6988886d18e01"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:584c902ec19ab7928fd5add1783c909094cc53f31ac7acfada817b0847975f26"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a35455cc0b0b3a1eaf67224035f5388591ec72b9b6136d66b49a553ce9eb1e6"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1670fe88b116c2745a3a30b0f099b699a02bb3482c2591514baf5433819e4f4d"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:185c394ef45b18b9a7d8e8f333606e2e8194a50c6e3c664215aae8cf42c5385e"}, + {file = "orjson-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ca0b3a94ac8d3886c9581b9f9de3ce858263865fdaa383fbc31c310b9eac07c9"}, + {file = "orjson-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:dfc91d4720d48e2a709e9c368d5125b4b5899dced34b5400c3837dadc7d6271b"}, + {file = "orjson-3.10.5-cp38-none-win32.whl", hash = "sha256:c05f16701ab2a4ca146d0bca950af254cb7c02f3c01fca8efbbad82d23b3d9d4"}, + {file = "orjson-3.10.5-cp38-none-win_amd64.whl", hash = "sha256:8a11d459338f96a9aa7f232ba95679fc0c7cedbd1b990d736467894210205c09"}, + {file = "orjson-3.10.5-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:85c89131d7b3218db1b24c4abecea92fd6c7f9fab87441cfc342d3acc725d807"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb66215277a230c456f9038d5e2d84778141643207f85336ef8d2a9da26bd7ca"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:51bbcdea96cdefa4a9b4461e690c75ad4e33796530d182bdd5c38980202c134a"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbead71dbe65f959b7bd8cf91e0e11d5338033eba34c114f69078d59827ee139"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5df58d206e78c40da118a8c14fc189207fffdcb1f21b3b4c9c0c18e839b5a214"}, + {file = "orjson-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c4057c3b511bb8aef605616bd3f1f002a697c7e4da6adf095ca5b84c0fd43595"}, + {file = "orjson-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b39e006b00c57125ab974362e740c14a0c6a66ff695bff44615dcf4a70ce2b86"}, + {file = "orjson-3.10.5-cp39-none-win32.whl", hash = "sha256:eded5138cc565a9d618e111c6d5c2547bbdd951114eb822f7f6309e04db0fb47"}, + {file = "orjson-3.10.5-cp39-none-win_amd64.whl", hash = "sha256:cc28e90a7cae7fcba2493953cff61da5a52950e78dc2dacfe931a317ee3d8de7"}, + {file = "orjson-3.10.5.tar.gz", hash = "sha256:7a5baef8a4284405d96c90c7c62b755e9ef1ada84c2406c24a9ebec86b89f46d"}, +] + [[package]] name = "packaging" version = "23.2" @@ -2733,17 +2781,6 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] -[[package]] -name = "sniffio" -version = "1.3.0" -description = "Sniff out which async library your code is running under" -optional = false -python-versions = ">=3.7" -files = [ - {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, - {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, -] - [[package]] name = "sqlalchemy" version = "2.0.25" @@ -3440,4 +3477,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "161eceb20f8c1a1a95987444d1e10f175134ee645aa46d5c60c683b930635c7a" +content-hash = "30cde429de81bf605d63c55e63e8a1ecb32d51c7456d60f274719cbdf51d607f" diff --git a/pyproject.toml b/pyproject.toml index d6d997c..beeb962 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "allms" -version = "1.0.5" +version = "1.0.6" description = "" authors = ["Allegro Opensource "] readme = "README.md" @@ -13,7 +13,7 @@ google-cloud-aiplatform = "1.38.0" pydash = "^7.0.6" transformers = "^4.34.1" pydantic = "1.10.13" -langchain = "^0.0.351" +langchain = "^0.1.8" aioresponses = "^0.7.6" tiktoken = "^0.6.0" openai = "^0.27.8" From ab82d0340086a9b5cd8ce452b331f04fa238a818 Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Mon, 1 Jul 2024 15:26:46 +0200 Subject: [PATCH 04/14] Enable passing gemini safety settings --- allms/domain/configuration.py | 5 +- allms/models/__init__.py | 3 + allms/models/vertexai_base.py | 2 +- allms/models/vertexai_gemini.py | 3 +- allms/models/vertexai_gemma.py | 2 +- allms/models/vertexai_palm.py | 2 +- poetry.lock | 84 ++++++++++++------- pyproject.toml | 3 +- tests/test_end_to_end.py | 26 ++++-- ...model_behavior_for_different_input_data.py | 6 +- tests/test_output_parser.py | 14 ++-- tests/test_utf_characters_data.py | 2 +- 12 files changed, 96 insertions(+), 56 deletions(-) diff --git a/allms/domain/configuration.py b/allms/domain/configuration.py index 20b8037..b2ccf42 100644 --- a/allms/domain/configuration.py +++ b/allms/domain/configuration.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import Optional +from typing import Dict, Optional + +from langchain_google_vertexai import HarmBlockThreshold, HarmCategory from allms.defaults.vertex_ai import GeminiModelDefaults, PalmModelDefaults @@ -26,6 +28,7 @@ class VertexAIConfiguration: cloud_location: str palm_model_name: Optional[str] = PalmModelDefaults.GCP_MODEL_NAME gemini_model_name: Optional[str] = GeminiModelDefaults.GCP_MODEL_NAME + gemini_safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None class VertexAIModelGardenConfiguration(VertexAIConfiguration): diff --git a/allms/models/__init__.py b/allms/models/__init__.py index 3087552..fccf062 100644 --- a/allms/models/__init__.py +++ b/allms/models/__init__.py @@ -1,5 +1,6 @@ from typing import Dict, Type +from allms.domain.configuration import HarmBlockThreshold, HarmCategory from allms.domain.enumerables import AvailableModels from allms.models.abstract import AbstractModel from allms.models.azure_llama2 import AzureLlama2Model @@ -16,6 +17,8 @@ "VertexAIPalmModel", "VertexAIGeminiModel", "VertexAIGemmaModel", + "HarmCategory", + "HarmBlockThreshold", "get_available_models" ] diff --git a/allms/models/vertexai_base.py b/allms/models/vertexai_base.py index 946ea83..3724212 100644 --- a/allms/models/vertexai_base.py +++ b/allms/models/vertexai_base.py @@ -1,7 +1,7 @@ from typing import List, Optional, Any, Dict from google.cloud.aiplatform.models import Prediction -from langchain_community.llms.vertexai import VertexAI, VertexAIModelGarden +from langchain_google_vertexai import VertexAI, VertexAIModelGarden from langchain_core.callbacks import AsyncCallbackManagerForLLMRun from langchain_core.outputs import LLMResult, Generation from pydash import chain diff --git a/allms/models/vertexai_gemini.py b/allms/models/vertexai_gemini.py index 98dd155..1d6c94d 100644 --- a/allms/models/vertexai_gemini.py +++ b/allms/models/vertexai_gemini.py @@ -1,5 +1,5 @@ from asyncio import AbstractEventLoop -from langchain_community.llms.vertexai import VertexAI +from langchain_google_vertexai import VertexAI from typing import Optional from allms.defaults.general_defaults import GeneralDefaults @@ -44,6 +44,7 @@ def _create_llm(self) -> VertexAI: temperature=self._temperature, top_p=self._top_p, top_k=self._top_k, + safety_settings=self._config.gemini_safety_settings, verbose=self._verbose, project=self._config.cloud_project, location=self._config.cloud_location diff --git a/allms/models/vertexai_gemma.py b/allms/models/vertexai_gemma.py index eb725c4..7834d4a 100644 --- a/allms/models/vertexai_gemma.py +++ b/allms/models/vertexai_gemma.py @@ -1,6 +1,6 @@ from asyncio import AbstractEventLoop -from langchain_community.llms.vertexai import VertexAIModelGarden +from langchain_google_vertexai import VertexAIModelGarden from typing import Optional from allms.defaults.general_defaults import GeneralDefaults diff --git a/allms/models/vertexai_palm.py b/allms/models/vertexai_palm.py index 9845068..674823a 100644 --- a/allms/models/vertexai_palm.py +++ b/allms/models/vertexai_palm.py @@ -1,5 +1,5 @@ from asyncio import AbstractEventLoop -from langchain_community.llms.vertexai import VertexAI +from langchain_google_vertexai import VertexAI from typing import Optional from allms.defaults.general_defaults import GeneralDefaults diff --git a/poetry.lock b/poetry.lock index 65141d7..5a83f92 100644 --- a/poetry.lock +++ b/poetry.lock @@ -491,6 +491,17 @@ files = [ graph = ["objgraph (>=1.7.2)"] profile = ["gprof2dot (>=2022.7.29)"] +[[package]] +name = "docstring-parser" +version = "0.16" +description = "Parse Python docstrings in reST, Google and Numpydoc format" +optional = false +python-versions = ">=3.6,<4.0" +files = [ + {file = "docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637"}, + {file = "docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e"}, +] + [[package]] name = "docutils" version = "0.20.1" @@ -725,42 +736,50 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"] [[package]] name = "google-cloud-aiplatform" -version = "1.38.0" +version = "1.57.0" description = "Vertex AI API client library" optional = false python-versions = ">=3.8" files = [ - {file = "google-cloud-aiplatform-1.38.0.tar.gz", hash = "sha256:dff91f79b64e279f0e61dfd63c4e067ba5fa75ef0f4614289bbdca70d086a9e2"}, - {file = "google_cloud_aiplatform-1.38.0-py2.py3-none-any.whl", hash = "sha256:7eec50d9a36d43e163f019a1ade9284d4580602a5108738a0ebff8940ea47ce0"}, + {file = "google-cloud-aiplatform-1.57.0.tar.gz", hash = "sha256:113905f100cb0a9ad744a2445a7675f92f28600233ba499614aa704d11a809b7"}, + {file = "google_cloud_aiplatform-1.57.0-py2.py3-none-any.whl", hash = "sha256:ca5391a56e0cc8f4ed39a2beb7be02f51936ff04fd5304775a72a86c345d0e47"}, ] [package.dependencies] -google-api-core = {version = ">=1.32.0,<2.0.dev0 || >=2.8.dev0,<3.0.0dev", extras = ["grpc"]} -google-cloud-bigquery = ">=1.15.0,<4.0.0dev" +docstring-parser = "<1" +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.8.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<3.0.0dev" +google-cloud-bigquery = ">=1.15.0,<3.20.0 || >3.20.0,<4.0.0dev" google-cloud-resource-manager = ">=1.3.3,<3.0.0dev" google-cloud-storage = ">=1.32.0,<3.0.0dev" packaging = ">=14.3" proto-plus = ">=1.22.0,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" -setuptools = {version = "*", markers = "python_version >= \"3.12\""} +pydantic = "<3" shapely = "<3.0.0dev" [package.extras] autologging = ["mlflow (>=1.27.0,<=2.1.1)"] cloud-profiler = ["tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "werkzeug (>=2.0.0,<2.1.0dev)"] -datasets = ["pyarrow (>=10.0.1)", "pyarrow (>=3.0.0,<8.0dev)"] +datasets = ["pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)"] endpoint = ["requests (>=2.28.1)"] -full = ["cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<0.103.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (==0.0.11)", "google-vizier (==0.0.4)", "google-vizier (>=0.0.14)", "google-vizier (>=0.1.6)", "httpx (>=0.23.0,<0.25.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pyarrow (>=10.0.1)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml (==5.3.1)", "ray[default] (>=2.4,<2.5)", "ray[default] (>=2.5,<2.5.1)", "requests (>=2.28.1)", "starlette (>=0.17.1)", "tensorflow (>=2.3.0,<3.0.0dev)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)"] +full = ["cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.109.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "nest-asyncio (>=1.0.0,<1.6.0)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "requests (>=2.28.1)", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)"] +langchain = ["langchain (>=0.1.16,<0.3)", "langchain-core (<0.2)", "langchain-google-vertexai (<2)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)", "tenacity (<=8.3)"] +langchain-testing = ["absl-py", "cloudpickle (>=3.0,<4.0)", "langchain (>=0.1.16,<0.3)", "langchain-core (<0.2)", "langchain-google-vertexai (<2)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)", "pytest-xdist", "tenacity (<=8.3)"] lit = ["explainable-ai-sdk (>=1.0.0)", "lit-nlp (==0.4.0)", "pandas (>=1.0.0)", "tensorflow (>=2.3.0,<3.0.0dev)"] metadata = ["numpy (>=1.15.0)", "pandas (>=1.0.0)"] -pipelines = ["pyyaml (==5.3.1)"] -prediction = ["docker (>=5.0.3)", "fastapi (>=0.71.0,<0.103.1)", "httpx (>=0.23.0,<0.25.0)", "starlette (>=0.17.1)", "uvicorn[standard] (>=0.16.0)"] +pipelines = ["pyyaml (>=5.3.1,<7)"] +prediction = ["docker (>=5.0.3)", "fastapi (>=0.71.0,<=0.109.1)", "httpx (>=0.23.0,<0.25.0)", "starlette (>=0.17.1)", "uvicorn[standard] (>=0.16.0)"] preview = ["cloudpickle (<3.0)", "google-cloud-logging (<4.0)"] private-endpoints = ["requests (>=2.28.1)", "urllib3 (>=1.21.1,<1.27)"] -ray = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "pandas (>=1.0.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "ray[default] (>=2.4,<2.5)", "ray[default] (>=2.5,<2.5.1)"] -tensorboard = ["tensorflow (>=2.3.0,<3.0.0dev)"] -testing = ["bigframes", "cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<0.103.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (==0.0.11)", "google-vizier (==0.0.4)", "google-vizier (>=0.0.14)", "google-vizier (>=0.1.6)", "grpcio-testing", "httpx (>=0.23.0,<0.25.0)", "ipython", "kfp", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pyarrow (>=10.0.1)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyfakefs", "pytest-asyncio", "pytest-xdist", "pyyaml (==5.3.1)", "ray[default] (>=2.4,<2.5)", "ray[default] (>=2.5,<2.5.1)", "requests (>=2.28.1)", "requests-toolbelt (<1.0.0)", "scikit-learn", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<=2.12.0)", "tensorflow (>=2.4.0,<3.0.0dev)", "torch (>=2.0.0,<2.1.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)", "xgboost", "xgboost-ray"] -vizier = ["google-vizier (==0.0.11)", "google-vizier (==0.0.4)", "google-vizier (>=0.0.14)", "google-vizier (>=0.1.6)"] +rapid-evaluation = ["nest-asyncio (>=1.0.0,<1.6.0)", "pandas (>=1.0.0,<2.2.0)"] +ray = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "setuptools (<70.0.0)"] +ray-testing = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pytest-xdist", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "ray[train] (==2.9.3)", "scikit-learn", "setuptools (<70.0.0)", "tensorflow", "torch (>=2.0.0,<2.1.0)", "xgboost", "xgboost-ray"] +reasoningengine = ["cloudpickle (>=3.0,<4.0)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)"] +tensorboard = ["tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "werkzeug (>=2.0.0,<2.1.0dev)"] +testing = ["bigframes", "cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.109.1)", "google-api-core (>=2.11,<3.0.0)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "grpcio-testing", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "ipython", "kfp (>=2.6.0,<3.0.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "nest-asyncio (>=1.0.0,<1.6.0)", "nltk", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyfakefs", "pytest-asyncio", "pytest-xdist", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "requests (>=2.28.1)", "requests-toolbelt (<1.0.0)", "scikit-learn", "sentencepiece (>=0.2.0)", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (==2.13.0)", "tensorflow (==2.16.1)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "torch (>=2.0.0,<2.1.0)", "torch (>=2.2.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)", "xgboost"] +tokenization = ["sentencepiece (>=0.2.0)"] +vizier = ["google-vizier (>=0.1.6)"] xai = ["tensorflow (>=2.3.0,<3.0.0dev)"] [[package]] @@ -1427,6 +1446,25 @@ tenacity = ">=8.1.0,<9.0.0" [package.extras] extended-testing = ["jinja2 (>=3,<4)"] +[[package]] +name = "langchain-google-vertexai" +version = "1.0.4" +description = "An integration package connecting Google VertexAI and LangChain" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langchain_google_vertexai-1.0.4-py3-none-any.whl", hash = "sha256:f9d217df2d5cfafb2e551ddd5f1c43611222f542ee0df0cc3b5faed82e657ee3"}, + {file = "langchain_google_vertexai-1.0.4.tar.gz", hash = "sha256:bb2d2e93cc2896b9bdc96789c2df247f6392184dffc0c3dddc06889f2b530465"}, +] + +[package.dependencies] +google-cloud-aiplatform = ">=1.47.0,<2.0.0" +google-cloud-storage = ">=2.14.0,<3.0.0" +langchain-core = ">=0.1.42,<0.3" + +[package.extras] +anthropic = ["anthropic[vertexai] (>=0.23.0,<1)"] + [[package]] name = "langsmith" version = "0.1.81" @@ -2697,22 +2735,6 @@ files = [ cryptography = ">=2.0" jeepney = ">=0.6" -[[package]] -name = "setuptools" -version = "69.0.3" -description = "Easily download, build, install, upgrade, and uninstall Python packages" -optional = false -python-versions = ">=3.8" -files = [ - {file = "setuptools-69.0.3-py3-none-any.whl", hash = "sha256:385eb4edd9c9d5c17540511303e39a147ce2fc04bc55289c322b9e5904fe2c05"}, - {file = "setuptools-69.0.3.tar.gz", hash = "sha256:be1af57fc409f93647f2e8e4573a142ed38724b8cdd389706a867bb4efcf1e78"}, -] - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"] -testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] -testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] - [[package]] name = "shapely" version = "2.0.2" @@ -3477,4 +3499,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "30cde429de81bf605d63c55e63e8a1ecb32d51c7456d60f274719cbdf51d607f" +content-hash = "486f861c4a1327f96eb9e32b7c895cdf603bd09c065be20e143af0808da0d717" diff --git a/pyproject.toml b/pyproject.toml index beeb962..895916b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,11 +9,12 @@ packages = [{include = "allms"}] [tool.poetry.dependencies] python = ">=3.8.1,<4.0" fsspec = "^2023.6.0" -google-cloud-aiplatform = "1.38.0" +google-cloud-aiplatform = "^1.47.0" pydash = "^7.0.6" transformers = "^4.34.1" pydantic = "1.10.13" langchain = "^0.1.8" +langchain-google-vertexai = "1.0.4" aioresponses = "^0.7.6" tiktoken = "^0.6.0" openai = "^0.27.8" diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 1ef9e84..1d3485e 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -6,7 +6,7 @@ from allms.constants.input_data import IODataConstants from allms.domain.configuration import VertexAIConfiguration from allms.domain.prompt_dto import KeywordsOutputClass -from allms.models.vertexai_gemini import VertexAIGeminiModel +from allms.models import VertexAIGeminiModel, HarmBlockThreshold, HarmCategory from allms.utils import io_utils from tests.conftest import AzureOpenAIEnv @@ -148,19 +148,29 @@ def test_prompt_is_not_modified_for_open_source_models(self, mock_aioresponse, m ) ]) - def test_gemini_version_is_passed_to_model(self): + def test_gemini_specific_args_are_passed_to_model(self): # GIVEN + gemini_model_name = "gemini-model-name" + gemini_safety_settings = { + HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + } model_config = VertexAIConfiguration( - cloud_project="dummy-project-id", - cloud_location="us-central1", - gemini_model_name="gemini-model-name" - ) + cloud_project="dummy-project-id", + cloud_location="us-central1", + gemini_model_name=gemini_model_name, + gemini_safety_settings=gemini_safety_settings + ) # WHEN gemini_model = VertexAIGeminiModel(config=model_config) - # WHEN - gemini_model._llm.model_name == "gemini-model-name" + # THEN + assert gemini_model._llm.model_name == gemini_model_name + assert gemini_model._llm.safety_settings == gemini_safety_settings def test_model_times_out( self, diff --git a/tests/test_model_behavior_for_different_input_data.py b/tests/test_model_behavior_for_different_input_data.py index 9e4d1cf..fa464e2 100644 --- a/tests/test_model_behavior_for_different_input_data.py +++ b/tests/test_model_behavior_for_different_input_data.py @@ -78,7 +78,7 @@ def test_exception_when_input_data_is_missing_and_prompt_contains_input_key(self model.generate(prompt, None) @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") def test_exception_when_num_prompt_tokens_larger_than_model_total_max_tokens(self, tokens_mock, chain_run_mock, models): # GIVEN chain_run_mock.return_value = "{}" @@ -97,7 +97,7 @@ def test_exception_when_num_prompt_tokens_larger_than_model_total_max_tokens(sel assert "Value Error has occurred: Prompt is too long" in response.error @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") def test_whether_curly_brackets_are_not_breaking_the_prompt(self, tokens_mock, chain_run_mock, models): # GIVEN chain_run_mock.return_value = "{}" @@ -115,7 +115,7 @@ def test_whether_curly_brackets_are_not_breaking_the_prompt(self, tokens_mock, c assert response.response is not None @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") def test_warning_when_num_prompt_tokens_plus_max_output_tokens_larger_than_model_total_max_tokens( self, tokens_mock, diff --git a/tests/test_output_parser.py b/tests/test_output_parser.py index b32179c..976768c 100644 --- a/tests/test_output_parser.py +++ b/tests/test_output_parser.py @@ -10,7 +10,7 @@ class TestOutputModelParserForDifferentModelOutputs: @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") def test_output_parser_returns_desired_format(self, tokens_mock, chain_run_mock, models): # GIVEN text_output = "This is the model output" @@ -28,7 +28,7 @@ def test_output_parser_returns_desired_format(self, tokens_mock, chain_run_mock, assert model_response[0].response.summary == text_output @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") def test_output_parser_returns_error_when_model_output_returns_different_field(self, tokens_mock, chain_run_mock, models): # GIVEN text_output = "This is the model output" @@ -46,7 +46,7 @@ def test_output_parser_returns_error_when_model_output_returns_different_field(s assert model_response[0].response is None @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") @pytest.mark.parametrize("json_response", [ ("{\"summary\": \"This is the model output\"}"), ("Sure! Here's the JSON you wanted: {\"summary\": \"This is the model output\"} Have a nice day!"), @@ -67,7 +67,7 @@ def test_output_parser_extracts_json_from_response(self, tokens_mock, chain_run_ assert model_response[0].response == SummaryOutputClass(summary="This is the model output") @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") def test_output_parser_returns_error_when_json_is_garbled(self, tokens_mock, chain_run_mock, models): # GIVEN chain_run_mock.return_value = "Sure! Here's the JSON you wanted: {\"summary: \"text\"}" @@ -83,7 +83,7 @@ def test_output_parser_returns_error_when_json_is_garbled(self, tokens_mock, cha assert model_response[0].response is None @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") def test_output_parser_returns_parsed_class_when_model_output_returns_too_many_fields(self, tokens_mock, chain_run_mock, models): # GIVEN text_output = "This is the model output" @@ -101,7 +101,7 @@ def test_output_parser_returns_parsed_class_when_model_output_returns_too_many_f assert model_response[0].response.summary == text_output @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") def test_model_returns_output_as_python_list_correctly(self, tokens_mock, chain_run_mock, models): # GIVEN text_output = [1, 2, 3] @@ -119,7 +119,7 @@ def test_model_returns_output_as_python_list_correctly(self, tokens_mock, chain_ assert model_response[0].response.keywords == list(map(str, text_output)) @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") def test_model_output_when_input_data_is_empty(self, tokens_mock, chain_run_mock, models): # GIVEN expected_model_response = "2+2 is 4" diff --git a/tests/test_utf_characters_data.py b/tests/test_utf_characters_data.py index f7606f7..b8981bb 100644 --- a/tests/test_utf_characters_data.py +++ b/tests/test_utf_characters_data.py @@ -6,7 +6,7 @@ class TestModelBehaviorForSpecialCharacters: @patch("langchain.chains.base.Chain.arun") - @patch("langchain_community.llms.vertexai.VertexAI.get_num_tokens") + @patch("langchain_google_vertexai.llms.VertexAI.get_num_tokens") @pytest.mark.parametrize("input_character", list(html.entities.entitydefs.values())) def test_model_is_not_broken_by_special_characters(self, tokens_mock, arun_mock, input_character, models): # GIVEN From 3d939c2c778eb80fb87344dd670c6d4d198fdd14 Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Mon, 1 Jul 2024 17:51:22 +0200 Subject: [PATCH 05/14] Bump package version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 895916b..da86781 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "allms" -version = "1.0.6" +version = "1.0.7" description = "" authors = ["Allegro Opensource "] readme = "README.md" From 801da3929ea314f97cc1487fc102233ed9e9194c Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Thu, 4 Jul 2024 13:28:08 +0200 Subject: [PATCH 06/14] Add possibility to pass azure_ad_token --- allms/domain/configuration.py | 3 ++- allms/models/azure_openai.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/allms/domain/configuration.py b/allms/domain/configuration.py index b2ccf42..55ab712 100644 --- a/allms/domain/configuration.py +++ b/allms/domain/configuration.py @@ -12,7 +12,8 @@ class AzureOpenAIConfiguration: deployment: str model_name: str api_version: str - api_key: str + api_key: Optional[str] = None + azure_ad_token: Optional[str] = None @dataclass diff --git a/allms/models/azure_openai.py b/allms/models/azure_openai.py index 2007b02..3e3a037 100644 --- a/allms/models/azure_openai.py +++ b/allms/models/azure_openai.py @@ -33,6 +33,7 @@ def __init__( event_loop=event_loop ) + from langchain_community.chat_models import AzureChatOpenAI def _create_llm(self) -> AzureChatOpenAI: return AzureChatOpenAI( deployment_name=self._config.deployment, @@ -40,6 +41,7 @@ def _create_llm(self) -> AzureChatOpenAI: model_name=self._config.model_name, base_url=self._config.base_url, api_key=self._config.api_key, + azure_ad_token=self._config.azure_ad_token, temperature=self._temperature, max_tokens=self._max_output_tokens, request_timeout=self._request_timeout_s From cd44d24741a9299eb53800915f805b42393fbafd Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Thu, 4 Jul 2024 13:31:10 +0200 Subject: [PATCH 07/14] Remove unnecessary line --- allms/models/azure_openai.py | 1 - 1 file changed, 1 deletion(-) diff --git a/allms/models/azure_openai.py b/allms/models/azure_openai.py index 3e3a037..645fb41 100644 --- a/allms/models/azure_openai.py +++ b/allms/models/azure_openai.py @@ -33,7 +33,6 @@ def __init__( event_loop=event_loop ) - from langchain_community.chat_models import AzureChatOpenAI def _create_llm(self) -> AzureChatOpenAI: return AzureChatOpenAI( deployment_name=self._config.deployment, From 7f6428fac51db6b8d23d7a6188eb422801401dd4 Mon Sep 17 00:00:00 2001 From: Riccardo Belluzzo Date: Thu, 8 Aug 2024 13:20:27 +0200 Subject: [PATCH 08/14] Added gcp tokenizer for gemini models --- allms/defaults/vertex_ai.py | 2 +- allms/models/vertexai_gemini.py | 20 ++++++++- poetry.lock | 73 ++++++++++++++++++++++++++++++--- pyproject.toml | 1 + tests/test_end_to_end.py | 2 +- 5 files changed, 90 insertions(+), 8 deletions(-) diff --git a/allms/defaults/vertex_ai.py b/allms/defaults/vertex_ai.py index 12fd2c4..681ab43 100644 --- a/allms/defaults/vertex_ai.py +++ b/allms/defaults/vertex_ai.py @@ -10,7 +10,7 @@ class PalmModelDefaults: class GeminiModelDefaults: - GCP_MODEL_NAME = "gemini-pro" + GCP_MODEL_NAME = "gemini-1.0-pro-001" MODEL_TOTAL_MAX_TOKENS = 30720 MAX_OUTPUT_TOKENS = 2048 TEMPERATURE = 0.0 diff --git a/allms/models/vertexai_gemini.py b/allms/models/vertexai_gemini.py index 1d6c94d..d68ec36 100644 --- a/allms/models/vertexai_gemini.py +++ b/allms/models/vertexai_gemini.py @@ -1,10 +1,15 @@ +import typing from asyncio import AbstractEventLoop + +from langchain_core.prompts import ChatPromptTemplate from langchain_google_vertexai import VertexAI +from vertexai.preview import tokenization from typing import Optional from allms.defaults.general_defaults import GeneralDefaults from allms.defaults.vertex_ai import GeminiModelDefaults from allms.domain.configuration import VertexAIConfiguration +from allms.domain.input_data import InputData from allms.models.vertexai_base import CustomVertexAI from allms.models.abstract import AbstractModel @@ -28,6 +33,8 @@ def __init__( self._verbose = verbose self._config = config + self._gcp_tokenizer = tokenization.get_tokenizer_for_model(self._config.gemini_model_name) + super().__init__( temperature=temperature, model_total_max_tokens=model_total_max_tokens, @@ -48,4 +55,15 @@ def _create_llm(self) -> VertexAI: verbose=self._verbose, project=self._config.cloud_project, location=self._config.cloud_location - ) \ No newline at end of file + ) + + def _get_prompt_tokens_number(self, prompt: ChatPromptTemplate, input_data: InputData) -> int: + return self._gcp_tokenizer.count_tokens( + prompt.format_prompt(**input_data.input_mappings).to_string() + ).total_tokens + + def _get_model_response_tokens_number(self, model_response: typing.Optional[str]) -> int: + if model_response: + return self._gcp_tokenizer.count_tokens(model_response).total_tokens + return 0 + diff --git a/poetry.lock b/poetry.lock index 5a83f92..830e4bc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -696,11 +696,11 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -2337,6 +2337,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2735,6 +2736,68 @@ files = [ cryptography = ">=2.0" jeepney = ">=0.6" +[[package]] +name = "sentencepiece" +version = "0.2.0" +description = "SentencePiece python wrapper" +optional = false +python-versions = "*" +files = [ + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227"}, + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452"}, + {file = "sentencepiece-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e"}, + {file = "sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040"}, + {file = "sentencepiece-0.2.0-cp310-cp310-win32.whl", hash = "sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d"}, + {file = "sentencepiece-0.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e"}, + {file = "sentencepiece-0.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553"}, + {file = "sentencepiece-0.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d"}, + {file = "sentencepiece-0.2.0-cp311-cp311-win32.whl", hash = "sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75"}, + {file = "sentencepiece-0.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c"}, + {file = "sentencepiece-0.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7"}, + {file = "sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109"}, + {file = "sentencepiece-0.2.0-cp312-cp312-win32.whl", hash = "sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251"}, + {file = "sentencepiece-0.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4547683f330289ec4f093027bfeb87f9ef023b2eb6f879fdc4a8187c7e0ffb90"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd6175f7eaec7142d2bf6f6597ce7db4c9ac89acf93fcdb17410c3a8b781eeb"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:859ba1acde782609a0910a26a60e16c191a82bf39b5621107552c0cd79fad00f"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcbbef6cc277f8f18f36959e305f10b1c620442d75addc79c21d7073ae581b50"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-win32.whl", hash = "sha256:536b934e244829e3fe6c4f198652cd82da48adb9aa145c9f00889542726dee3d"}, + {file = "sentencepiece-0.2.0-cp36-cp36m-win_amd64.whl", hash = "sha256:0a91aaa3c769b52440df56fafda683b3aa48e3f2169cf7ee5b8c8454a7f3ae9b"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:787e480ca4c1d08c9985a7eb1eae4345c107729c99e9b5a9a00f2575fc7d4b4b"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e5ca43013e8935f25457a4fca47e315780172c3e821b4b13a890668911c792"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7140d9e5a74a0908493bb4a13f1f16a401297bd755ada4c707e842fbf6f0f5bf"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-win32.whl", hash = "sha256:6cf333625234f247ab357b0bd9836638405ea9082e1543d5b8408f014979dcbf"}, + {file = "sentencepiece-0.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20813a68d4c221b1849c62c30e1281ea81687894d894b8d4a0f4677d9311e0f5"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:926ef920ae2e8182db31d3f5d081ada57804e3e1d3a8c4ef8b117f9d9fb5a945"}, + {file = "sentencepiece-0.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:89f65f69636b7e9c015b79dff9c9985a9bc7d19ded6f79ef9f1ec920fdd73ecf"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f67eae0dbe6f2d7d6ba50a354623d787c99965f068b81e145d53240198021b0"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98501e075f35dd1a1d5a20f65be26839fcb1938752ec61539af008a5aa6f510b"}, + {file = "sentencepiece-0.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e3d1d2cc4882e8d6a1adf9d5927d7716f80617fc693385661caff21888972269"}, + {file = "sentencepiece-0.2.0-cp38-cp38-win32.whl", hash = "sha256:b99a308a2e5e569031ab164b74e6fab0b6f37dfb493c32f7816225f4d411a6dd"}, + {file = "sentencepiece-0.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:cdb701eec783d3ec86b7cd4c763adad8eaf6b46db37ee1c36e5e6c44b3fe1b5f"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1e0f9c4d0a6b0af59b613175f019916e28ade076e21242fd5be24340d8a2f64a"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:298f21cc1366eb60311aedba3169d30f885c363ddbf44214b0a587d2908141ad"}, + {file = "sentencepiece-0.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3f1ec95aa1e5dab11f37ac7eff190493fd87770f7a8b81ebc9dd768d1a3c8704"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7b06b70af54daa4b4904cbb90b4eb6d35c9f3252fdc86c9c32d5afd4d30118d8"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e37bac44dd6603388cb598c64ff7a76e41ca774646f21c23aadfbf5a2228ab"}, + {file = "sentencepiece-0.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5"}, + {file = "sentencepiece-0.2.0-cp39-cp39-win32.whl", hash = "sha256:38aed822fb76435fa1f12185f10465a94ab9e51d5e8a9159e9a540ce926f0ffd"}, + {file = "sentencepiece-0.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:d8cf876516548b5a1d6ac4745d8b554f5c07891d55da557925e5c13ff0b4e6ad"}, + {file = "sentencepiece-0.2.0.tar.gz", hash = "sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843"}, +] + [[package]] name = "shapely" version = "2.0.2" @@ -2862,7 +2925,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\""} +greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} typing-extensions = ">=4.6.0" [package.extras] @@ -3499,4 +3562,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "486f861c4a1327f96eb9e32b7c895cdf603bd09c065be20e143af0808da0d717" +content-hash = "2c4179c18c5e18fc3e8d34e619eb759454d1f3e5c80d0d985085c652001eb02d" diff --git a/pyproject.toml b/pyproject.toml index da86781..772533d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ aioresponses = "^0.7.6" tiktoken = "^0.6.0" openai = "^0.27.8" pytest-mock = "^3.14.0" +sentencepiece = "^0.2.0" [tool.poetry.group.dev.dependencies] pytest = "^7.4.0" diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 1d3485e..6c52042 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -150,7 +150,7 @@ def test_prompt_is_not_modified_for_open_source_models(self, mock_aioresponse, m def test_gemini_specific_args_are_passed_to_model(self): # GIVEN - gemini_model_name = "gemini-model-name" + gemini_model_name = "gemini-1.0-pro-001" gemini_safety_settings = { HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, From a8d8b3fa2edf9ccd3f027194feffdc2202ea1862 Mon Sep 17 00:00:00 2001 From: Riccardo Belluzzo Date: Thu, 8 Aug 2024 14:39:42 +0200 Subject: [PATCH 09/14] requirement: google-cloud-aiplatform = ">=1.57.0" modified --- allms/defaults/vertex_ai.py | 2 +- poetry.lock | 22 +++++++++++----------- pyproject.toml | 2 +- tests/test_end_to_end.py | 3 +-- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/allms/defaults/vertex_ai.py b/allms/defaults/vertex_ai.py index 681ab43..5d03a75 100644 --- a/allms/defaults/vertex_ai.py +++ b/allms/defaults/vertex_ai.py @@ -10,7 +10,7 @@ class PalmModelDefaults: class GeminiModelDefaults: - GCP_MODEL_NAME = "gemini-1.0-pro-001" + GCP_MODEL_NAME = "gemini-1.5-flash-001" MODEL_TOTAL_MAX_TOKENS = 30720 MAX_OUTPUT_TOKENS = 2048 TEMPERATURE = 0.0 diff --git a/poetry.lock b/poetry.lock index 830e4bc..8efaa7d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -736,13 +736,13 @@ requests = ["requests (>=2.20.0,<3.0.0.dev0)"] [[package]] name = "google-cloud-aiplatform" -version = "1.57.0" +version = "1.60.0" description = "Vertex AI API client library" optional = false python-versions = ">=3.8" files = [ - {file = "google-cloud-aiplatform-1.57.0.tar.gz", hash = "sha256:113905f100cb0a9ad744a2445a7675f92f28600233ba499614aa704d11a809b7"}, - {file = "google_cloud_aiplatform-1.57.0-py2.py3-none-any.whl", hash = "sha256:ca5391a56e0cc8f4ed39a2beb7be02f51936ff04fd5304775a72a86c345d0e47"}, + {file = "google-cloud-aiplatform-1.60.0.tar.gz", hash = "sha256:782c7f1ec0e77a7c7daabef3b65bfd506ed2b4b1dc2186753c43cd6faf8dd04e"}, + {file = "google_cloud_aiplatform-1.60.0-py2.py3-none-any.whl", hash = "sha256:5f14159c9575f4b46335027e3ceb8fa57bd5eaa76a07f858105b8c6c034ec0d6"}, ] [package.dependencies] @@ -753,7 +753,7 @@ google-cloud-bigquery = ">=1.15.0,<3.20.0 || >3.20.0,<4.0.0dev" google-cloud-resource-manager = ">=1.3.3,<3.0.0dev" google-cloud-storage = ">=1.32.0,<3.0.0dev" packaging = ">=14.3" -proto-plus = ">=1.22.0,<2.0.0dev" +proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev" pydantic = "<3" shapely = "<3.0.0dev" @@ -763,21 +763,21 @@ autologging = ["mlflow (>=1.27.0,<=2.1.1)"] cloud-profiler = ["tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "werkzeug (>=2.0.0,<2.1.0dev)"] datasets = ["pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)"] endpoint = ["requests (>=2.28.1)"] -full = ["cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.109.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "nest-asyncio (>=1.0.0,<1.6.0)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "requests (>=2.28.1)", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)"] -langchain = ["langchain (>=0.1.16,<0.3)", "langchain-core (<0.2)", "langchain-google-vertexai (<2)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)", "tenacity (<=8.3)"] -langchain-testing = ["absl-py", "cloudpickle (>=3.0,<4.0)", "langchain (>=0.1.16,<0.3)", "langchain-core (<0.2)", "langchain-google-vertexai (<2)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)", "pytest-xdist", "tenacity (<=8.3)"] +full = ["cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.109.1)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "requests (>=2.28.1)", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "tqdm (>=4.23.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)"] +langchain = ["langchain (>=0.1.16,<0.3)", "langchain-core (<0.3)", "langchain-google-vertexai (<2)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)", "tenacity (<=8.3)"] +langchain-testing = ["absl-py", "cloudpickle (>=3.0,<4.0)", "google-cloud-trace (<2)", "langchain (>=0.1.16,<0.3)", "langchain-core (<0.3)", "langchain-google-vertexai (<2)", "openinference-instrumentation-langchain (>=0.1.19,<0.2)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)", "pytest-xdist", "tenacity (<=8.3)"] lit = ["explainable-ai-sdk (>=1.0.0)", "lit-nlp (==0.4.0)", "pandas (>=1.0.0)", "tensorflow (>=2.3.0,<3.0.0dev)"] metadata = ["numpy (>=1.15.0)", "pandas (>=1.0.0)"] pipelines = ["pyyaml (>=5.3.1,<7)"] prediction = ["docker (>=5.0.3)", "fastapi (>=0.71.0,<=0.109.1)", "httpx (>=0.23.0,<0.25.0)", "starlette (>=0.17.1)", "uvicorn[standard] (>=0.16.0)"] preview = ["cloudpickle (<3.0)", "google-cloud-logging (<4.0)"] private-endpoints = ["requests (>=2.28.1)", "urllib3 (>=1.21.1,<1.27)"] -rapid-evaluation = ["nest-asyncio (>=1.0.0,<1.6.0)", "pandas (>=1.0.0,<2.2.0)"] +rapid-evaluation = ["pandas (>=1.0.0,<2.2.0)", "tqdm (>=4.23.0)"] ray = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "setuptools (<70.0.0)"] ray-testing = ["google-cloud-bigquery", "google-cloud-bigquery-storage", "immutabledict", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pytest-xdist", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "ray[train] (==2.9.3)", "scikit-learn", "setuptools (<70.0.0)", "tensorflow", "torch (>=2.0.0,<2.1.0)", "xgboost", "xgboost-ray"] -reasoningengine = ["cloudpickle (>=3.0,<4.0)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)"] +reasoningengine = ["cloudpickle (>=3.0,<4.0)", "google-cloud-trace (<2)", "opentelemetry-exporter-gcp-trace (<2)", "opentelemetry-sdk (<2)", "pydantic (>=2.6.3,<3)"] tensorboard = ["tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "werkzeug (>=2.0.0,<2.1.0dev)"] -testing = ["bigframes", "cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.109.1)", "google-api-core (>=2.11,<3.0.0)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "grpcio-testing", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "ipython", "kfp (>=2.6.0,<3.0.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "nest-asyncio (>=1.0.0,<1.6.0)", "nltk", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyfakefs", "pytest-asyncio", "pytest-xdist", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "requests (>=2.28.1)", "requests-toolbelt (<1.0.0)", "scikit-learn", "sentencepiece (>=0.2.0)", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (==2.13.0)", "tensorflow (==2.16.1)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "torch (>=2.0.0,<2.1.0)", "torch (>=2.2.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)", "xgboost"] +testing = ["bigframes", "cloudpickle (<3.0)", "docker (>=5.0.3)", "explainable-ai-sdk (>=1.0.0)", "fastapi (>=0.71.0,<=0.109.1)", "google-api-core (>=2.11,<3.0.0)", "google-cloud-bigquery", "google-cloud-bigquery-storage", "google-cloud-logging (<4.0)", "google-vizier (>=0.1.6)", "grpcio-testing", "httpx (>=0.23.0,<0.25.0)", "immutabledict", "ipython", "kfp (>=2.6.0,<3.0.0)", "lit-nlp (==0.4.0)", "mlflow (>=1.27.0,<=2.1.1)", "nltk", "numpy (>=1.15.0)", "pandas (>=1.0.0)", "pandas (>=1.0.0,<2.2.0)", "pyarrow (>=10.0.1)", "pyarrow (>=14.0.0)", "pyarrow (>=3.0.0,<8.0dev)", "pyarrow (>=6.0.1)", "pydantic (<2)", "pyfakefs", "pytest-asyncio", "pytest-xdist", "pyyaml (>=5.3.1,<7)", "ray[default] (>=2.4,<2.5.dev0 || >2.9.0,!=2.9.1,!=2.9.2,<=2.9.3)", "ray[default] (>=2.5,<=2.9.3)", "requests (>=2.28.1)", "requests-toolbelt (<1.0.0)", "scikit-learn", "sentencepiece (>=0.2.0)", "setuptools (<70.0.0)", "starlette (>=0.17.1)", "tensorboard-plugin-profile (>=2.4.0,<3.0.0dev)", "tensorflow (==2.13.0)", "tensorflow (==2.16.1)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.3.0,<3.0.0dev)", "tensorflow (>=2.4.0,<3.0.0dev)", "torch (>=2.0.0,<2.1.0)", "torch (>=2.2.0)", "tqdm (>=4.23.0)", "urllib3 (>=1.21.1,<1.27)", "uvicorn[standard] (>=0.16.0)", "werkzeug (>=2.0.0,<2.1.0dev)", "xgboost"] tokenization = ["sentencepiece (>=0.2.0)"] vizier = ["google-vizier (>=0.1.6)"] xai = ["tensorflow (>=2.3.0,<3.0.0dev)"] @@ -3562,4 +3562,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "2c4179c18c5e18fc3e8d34e619eb759454d1f3e5c80d0d985085c652001eb02d" +content-hash = "3307e09df9b656a276976144be13f75a1499cd9c4b5963e91c647e01fc6242ad" diff --git a/pyproject.toml b/pyproject.toml index 772533d..a996a83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ packages = [{include = "allms"}] [tool.poetry.dependencies] python = ">=3.8.1,<4.0" fsspec = "^2023.6.0" -google-cloud-aiplatform = "^1.47.0" +google-cloud-aiplatform = ">=1.57.0" pydash = "^7.0.6" transformers = "^4.34.1" pydantic = "1.10.13" diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 6c52042..74414a1 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -72,7 +72,6 @@ def test_model_is_queried_successfully( assert list(map(lambda output: int(output[IODataConstants.GENERATED_TOKENS_NUMBER]), expected_output)) == list( map(lambda example: example.number_of_generated_tokens, parsed_responses)) - def test_prompt_is_not_modified_for_open_source_models(self, mock_aioresponse, models, mocker): # GIVEN open_source_models = ["azure_llama2", "azure_mistral", "vertex_gemma"] @@ -150,7 +149,7 @@ def test_prompt_is_not_modified_for_open_source_models(self, mock_aioresponse, m def test_gemini_specific_args_are_passed_to_model(self): # GIVEN - gemini_model_name = "gemini-1.0-pro-001" + gemini_model_name = "gemini-1.5-pro-001" gemini_safety_settings = { HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, From f09ccb9ef064bda24921986e8e038e10e46d1fa1 Mon Sep 17 00:00:00 2001 From: Riccardo Belluzzo Date: Thu, 8 Aug 2024 14:42:31 +0200 Subject: [PATCH 10/14] README.md updated and pkg version bumped up --- README.md | 14 +++++++------- pyproject.toml | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 0495b59..65e3135 100644 --- a/README.md +++ b/README.md @@ -26,13 +26,13 @@ ___ ## Supported Models -| LLM Family | Hosting | Supported LLMs | -|-------------|---------------------|-----------------------------------------| -| GPT(s) | OpenAI endpoint | `gpt-3.5-turbo`, `gpt-4`, `gpt-4-turbo` | -| Google LLMs | VertexAI deployment | `text-bison@001`, `gemini-pro` | -| Llama2 | Azure deployment | `llama2-7b`, `llama2-13b`, `llama2-70b` | -| Mistral | Azure deployment | `Mistral-7b`, `Mixtral-7bx8` | -| Gemma | GCP deployment | `gemma` | +| LLM Family | Hosting | Supported LLMs | +|-------------|---------------------|------------------------------------------------------------------| +| GPT(s) | OpenAI endpoint | `gpt-3.5-turbo`, `gpt-4`, `gpt-4-turbo`, `gpt4-o`, `gpt4-o mini` | +| Google LLMs | VertexAI deployment | `text-bison@001`, `gemini-pro`, `gemini-flash` | +| Llama2 | Azure deployment | `llama2-7b`, `llama2-13b`, `llama2-70b` | +| Mistral | Azure deployment | `Mistral-7b`, `Mixtral-7bx8` | +| Gemma | GCP deployment | `gemma` | * Do you already have a subscription to a Cloud Provider for any the models above? Configure the model using your credentials and start querying! diff --git a/pyproject.toml b/pyproject.toml index a996a83..1454196 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "allms" -version = "1.0.7" +version = "1.0.8" description = "" authors = ["Allegro Opensource "] readme = "README.md" From 2483a09b776e6b41e037e35c6b39db0845352f38 Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Thu, 29 Aug 2024 17:42:14 +0200 Subject: [PATCH 11/14] Due to a bug in `langchain`, the response is empty if the request was blocked because of a prompt being unsafe. We catch it here and raise a proper exception --- allms/models/abstract.py | 3 ++- allms/models/vertexai_base.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/allms/models/abstract.py b/allms/models/abstract.py index 31f8547..da4391f 100644 --- a/allms/models/abstract.py +++ b/allms/models/abstract.py @@ -32,6 +32,7 @@ from allms.domain.input_data import InputData from allms.domain.prompt_dto import SummaryOutputClass, KeywordsOutputClass from allms.domain.response import ResponseData +from allms.models.vertexai_base import GCPInvalidRequestError from allms.utils.long_text_processing_utils import get_max_allowed_number_of_tokens from allms.utils.response_parsing_utils import ResponseParser @@ -260,7 +261,7 @@ async def _predict_example( model_response = None error_message = f"{IODataConstants.ERROR_MESSAGE_STR}: {invalid_request_error}" - except (InvalidArgument, ValueError, TimeoutError, openai.error.Timeout) as other_error: + except (InvalidArgument, ValueError, TimeoutError, openai.error.Timeout, GCPInvalidRequestError) as other_error: model_response = None logger.info(f"Error for id {input_data.id} has occurred. Message: {other_error} ") error_message = f"{type(other_error).__name__}: {other_error}" diff --git a/allms/models/vertexai_base.py b/allms/models/vertexai_base.py index 3724212..4b9898e 100644 --- a/allms/models/vertexai_base.py +++ b/allms/models/vertexai_base.py @@ -9,6 +9,10 @@ from allms.constants.vertex_ai import VertexModelConstants +class GCPInvalidRequestError(Exception): + pass + + class CustomVertexAI(VertexAI): async def _agenerate( self, @@ -31,6 +35,9 @@ def was_response_blocked(generation: Generation) -> bool: **kwargs ) + if not all(result.generations): + raise GCPInvalidRequestError("The response is empty. It may have been blocked due to content filtering.") + return LLMResult( generations=( chain(result.generations) From 58f8099c31c1e8ccf8302b83afc85e424c78e542 Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Thu, 29 Aug 2024 17:45:50 +0200 Subject: [PATCH 12/14] Bump up package version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1454196..6305237 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "allms" -version = "1.0.8" +version = "1.0.9" description = "" authors = ["Allegro Opensource "] readme = "README.md" From 1ceaa17aa6e1b7e0015a4e5ed3134b0cf5121108 Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Thu, 12 Sep 2024 13:33:17 +0200 Subject: [PATCH 13/14] Better gemini model names support --- allms/models/vertexai_gemini.py | 19 +++++++++++++--- tests/test_end_to_end.py | 39 +++++++++++++++++++++++++++++++-- 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/allms/models/vertexai_gemini.py b/allms/models/vertexai_gemini.py index d68ec36..cfdf88f 100644 --- a/allms/models/vertexai_gemini.py +++ b/allms/models/vertexai_gemini.py @@ -1,17 +1,20 @@ import typing from asyncio import AbstractEventLoop +from typing import Optional from langchain_core.prompts import ChatPromptTemplate from langchain_google_vertexai import VertexAI from vertexai.preview import tokenization -from typing import Optional +from vertexai.tokenization._tokenizers import Tokenizer from allms.defaults.general_defaults import GeneralDefaults from allms.defaults.vertex_ai import GeminiModelDefaults from allms.domain.configuration import VertexAIConfiguration from allms.domain.input_data import InputData -from allms.models.vertexai_base import CustomVertexAI from allms.models.abstract import AbstractModel +from allms.models.vertexai_base import CustomVertexAI + +BASE_GEMINI_MODEL_NAMES = ["gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash"] class VertexAIGeminiModel(AbstractModel): @@ -33,7 +36,7 @@ def __init__( self._verbose = verbose self._config = config - self._gcp_tokenizer = tokenization.get_tokenizer_for_model(self._config.gemini_model_name) + self._gcp_tokenizer = self._get_gcp_tokenizer(self._config.gemini_model_name) super().__init__( temperature=temperature, @@ -67,3 +70,13 @@ def _get_model_response_tokens_number(self, model_response: typing.Optional[str] return self._gcp_tokenizer.count_tokens(model_response).total_tokens return 0 + @staticmethod + def _get_gcp_tokenizer(model_name) -> Tokenizer: + try: + return tokenization.get_tokenizer_for_model(model_name) + except ValueError: + for base_model_name in BASE_GEMINI_MODEL_NAMES: + if model_name.startswith(base_model_name): + return tokenization.get_tokenizer_for_model(base_model_name) + raise + diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 74414a1..d7ca5e4 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -1,7 +1,8 @@ import re -from unittest.mock import patch -from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate, SystemMessagePromptTemplate +import pytest +from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, PromptTemplate, \ + SystemMessagePromptTemplate from allms.constants.input_data import IODataConstants from allms.domain.configuration import VertexAIConfiguration @@ -171,6 +172,40 @@ def test_gemini_specific_args_are_passed_to_model(self): assert gemini_model._llm.model_name == gemini_model_name assert gemini_model._llm.safety_settings == gemini_safety_settings + @pytest.mark.parametrize( + "model_name", [ + "gemini-1.0-pro", "gemini-1.5-pro", "gemini-1.5-flash","gemini-1.0-pro-001", "gemini-1.0-pro-002", + "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-1.5-pro-preview-0514" + ] + ) + def test_correct_gemini_model_name_work(self, model_name): + # GIVEN + model_config = VertexAIConfiguration( + cloud_project="dummy-project-id", + cloud_location="us-central1", + gemini_model_name=model_name, + ) + + # WHEN & THEN + VertexAIGeminiModel(config=model_config) + + @pytest.mark.parametrize( + "model_name", [ + "gemini-2.0-pro", "geminis-1.5-pro", "gemini-flash", "gemini-1.5-preview-pro", "gpt4" + ] + ) + def test_incorrect_gemini_model_name_fail(self, model_name): + # GIVEN + model_config = VertexAIConfiguration( + cloud_project="dummy-project-id", + cloud_location="us-central1", + gemini_model_name=model_name, + ) + + # WHEN & THEN + with pytest.raises(ValueError, match=f"Model {model_name} is not supported."): + VertexAIGeminiModel(config=model_config) + def test_model_times_out( self, mock_aioresponse, From d54a711aa76bf3324f1f6459bcd8e05508cc8966 Mon Sep 17 00:00:00 2001 From: Piotr C Zielinski Date: Thu, 12 Sep 2024 13:34:13 +0200 Subject: [PATCH 14/14] Bump up version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6305237..a85c703 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "allms" -version = "1.0.9" +version = "1.0.10" description = "" authors = ["Allegro Opensource "] readme = "README.md"