From 152427eca13da070cc03f3f245a43bff312e43d1 Mon Sep 17 00:00:00 2001 From: Isaac Francisco <78627776+isahers1@users.noreply.github.com> Date: Fri, 26 Jul 2024 17:39:57 -0700 Subject: [PATCH] make image inputs compatible with langchain_ollama (#24619) --- .../ollama/langchain_ollama/chat_models.py | 17 ++++++++------ .../integration_tests/test_chat_models.py | 22 +++++++++++++++++++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py index 339c2b2711ab1..775d8a5132419 100644 --- a/libs/partners/ollama/langchain_ollama/chat_models.py +++ b/libs/partners/ollama/langchain_ollama/chat_models.py @@ -346,7 +346,7 @@ def _convert_messages_to_ollama_messages( ) -> Sequence[Message]: ollama_messages: List = [] for message in messages: - role = "" + role: Literal["user", "assistant", "system", "tool"] tool_call_id: Optional[str] = None tool_calls: Optional[List[Dict[str, Any]]] = None if isinstance(message, HumanMessage): @@ -383,11 +383,13 @@ def _convert_messages_to_ollama_messages( image_url = None temp_image_url = content_part.get("image_url") if isinstance(temp_image_url, str): - image_url = content_part["image_url"] + image_url = temp_image_url elif ( - isinstance(temp_image_url, dict) and "url" in temp_image_url + isinstance(temp_image_url, dict) + and "url" in temp_image_url + and isinstance(temp_image_url["url"], str) ): - image_url = temp_image_url + image_url = temp_image_url["url"] else: raise ValueError( "Only string image_url or dict with string 'url' " @@ -408,15 +410,16 @@ def _convert_messages_to_ollama_messages( "Must either have type 'text' or type 'image_url' " "with a string 'image_url' field." ) - msg = { + # Should convert to ollama.Message once role includes tool, and tool_call_id is in Message # noqa: E501 + msg: dict = { "role": role, "content": content, "images": images, } + if tool_calls: + msg["tool_calls"] = tool_calls # type: ignore if tool_call_id: msg["tool_call_id"] = tool_call_id - if tool_calls: - msg["tool_calls"] = tool_calls ollama_messages.append(msg) return ollama_messages diff --git a/libs/partners/ollama/tests/integration_tests/test_chat_models.py b/libs/partners/ollama/tests/integration_tests/test_chat_models.py index 0d54b2c738bba..10ffcb39d2777 100644 --- a/libs/partners/ollama/tests/integration_tests/test_chat_models.py +++ b/libs/partners/ollama/tests/integration_tests/test_chat_models.py @@ -2,6 +2,8 @@ from typing import Type +import pytest +from langchain_core.language_models import BaseChatModel from langchain_standard_tests.integration_tests import ChatModelIntegrationTests from langchain_ollama.chat_models import ChatOllama @@ -15,3 +17,23 @@ def chat_model_class(self) -> Type[ChatOllama]: @property def chat_model_params(self) -> dict: return {"model": "llama3-groq-tool-use"} + + @property + def supports_image_inputs(self) -> bool: + return True + + @pytest.mark.xfail( + reason=( + "Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet." + ) + ) + def test_structured_output(self, model: BaseChatModel) -> None: + super().test_structured_output(model) + + @pytest.mark.xfail( + reason=( + "Fails with 'AssertionError'. Ollama does not support 'tool_choice' yet." + ) + ) + def test_structured_output_pydantic_2_v1(self, model: BaseChatModel) -> None: + super().test_structured_output_pydantic_2_v1(model)