diff --git a/app/common/message_converters.py b/app/common/message_converters.py index 3059a57b..4ca1dd80 100644 --- a/app/common/message_converters.py +++ b/app/common/message_converters.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Literal from langchain_core.messages import BaseMessage @@ -47,7 +48,7 @@ def convert_langchain_message_to_iris_message( ) -def map_role_to_str(role: IrisMessageRole) -> str: +def map_role_to_str(role: IrisMessageRole) -> Literal["user", "assistant", "system"]: match role: case IrisMessageRole.USER: return "user" diff --git a/app/domain/__init__.py b/app/domain/__init__.py index 149df609..c2f4199e 100644 --- a/app/domain/__init__.py +++ b/app/domain/__init__.py @@ -5,3 +5,4 @@ TutorChatPipelineExecutionDTO, ) from .pyris_message import PyrisMessage, IrisMessageRole +from app.domain.data import image_message_content_dto diff --git a/app/domain/data/image_message_content_dto.py b/app/domain/data/image_message_content_dto.py index d48fd717..a73e2654 100644 --- a/app/domain/data/image_message_content_dto.py +++ b/app/domain/data/image_message_content_dto.py @@ -1,7 +1,7 @@ +from pydantic import BaseModel from typing import Optional -from pydantic import BaseModel, Field - class ImageMessageContentDTO(BaseModel): - image_data: Optional[str] = Field(alias="imageData", default=None) + base64: str + prompt: Optional[str] = None diff --git a/app/domain/data/json_message_content_dto.py b/app/domain/data/json_message_content_dto.py index 73a0d7cb..cd4ccfcb 100644 --- a/app/domain/data/json_message_content_dto.py +++ b/app/domain/data/json_message_content_dto.py @@ -1,6 +1,8 @@ -from pydantic import BaseModel, Field, Json -from typing import Any, Optional +from pydantic import BaseModel, ConfigDict, Field, Json +from typing import Any class JsonMessageContentDTO(BaseModel): - json_content: Optional[Json[Any]] = Field(alias="jsonContent", default=None) + model_config = ConfigDict(populate_by_name=True) + + json_content: Json[Any] = 
Field(alias="jsonContent") diff --git a/app/domain/data/text_message_content_dto.py b/app/domain/data/text_message_content_dto.py index b7ece8f9..9442dbd3 100644 --- a/app/domain/data/text_message_content_dto.py +++ b/app/domain/data/text_message_content_dto.py @@ -1,7 +1,7 @@ -from typing import Optional - -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class TextMessageContentDTO(BaseModel): - text_content: Optional[str] = Field(alias="textContent", default=None) + model_config = ConfigDict(populate_by_name=True) + + text_content: str = Field(alias="textContent") diff --git a/app/domain/pyris_message.py b/app/domain/pyris_message.py index 5f44cd9d..056f77ef 100644 --- a/app/domain/pyris_message.py +++ b/app/domain/pyris_message.py @@ -2,7 +2,7 @@ from enum import Enum from typing import List -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from app.domain.data.message_content_dto import MessageContentDTO @@ -14,6 +14,8 @@ class IrisMessageRole(str, Enum): class PyrisMessage(BaseModel): + model_config = ConfigDict(populate_by_name=True) + sent_at: datetime | None = Field(alias="sentAt", default=None) sender: IrisMessageRole contents: List[MessageContentDTO] = [] diff --git a/app/llm/external/model.py b/app/llm/external/model.py index 4d42745b..47b90962 100644 --- a/app/llm/external/model.py +++ b/app/llm/external/model.py @@ -60,3 +60,27 @@ def embed(self, text: str) -> list[float]: raise NotImplementedError( f"The LLM {self.__str__()} does not support embeddings" ) + + +class ImageGenerationModel(LanguageModel, metaclass=ABCMeta): + """Abstract class for the llm image generation wrappers""" + + @classmethod + def __subclasshook__(cls, subclass): + return hasattr(subclass, "generate_images") and callable( + subclass.generate_images + ) + + @abstractmethod + def generate_images( + self, + prompt: str, + n: int = 1, + size: str = "256x256", + quality: str = "standard", + **kwargs, + ) 
-> list: + """Create an image from the prompt""" + raise NotImplementedError( + f"The LLM {self.__str__()} does not support image generation" + ) diff --git a/app/llm/external/ollama.py b/app/llm/external/ollama.py index 72dbb04e..f2363b23 100644 --- a/app/llm/external/ollama.py +++ b/app/llm/external/ollama.py @@ -1,26 +1,65 @@ +import base64 from datetime import datetime -from typing import Literal, Any +from typing import Literal, Any, Optional from ollama import Client, Message from ...common.message_converters import map_role_to_str, map_str_to_role +from ...domain.data.json_message_content_dto import JsonMessageContentDTO from ...domain.data.text_message_content_dto import TextMessageContentDTO +from ...domain.data.image_message_content_dto import ImageMessageContentDTO from ...domain import PyrisMessage from ...llm import CompletionArguments from ...llm.external.model import ChatModel, CompletionModel, EmbeddingModel +def convert_to_ollama_images(base64_images: list[str]) -> list[bytes] | None: + """ + Convert a list of base64 images to a list of bytes + """ + if not base64_images: + return None + return [base64.b64decode(base64_image) for base64_image in base64_images] + + def convert_to_ollama_messages(messages: list[PyrisMessage]) -> list[Message]: - return [ - Message( - role=map_role_to_str(message.sender), - content=message.contents[0].text_content, + """ + Convert a list of PyrisMessages to a list of Ollama Messages + """ + messages_to_return = [] + for message in messages: + if len(message.contents) == 0: + continue + text_content = "" + images = [] + for content in message.contents: + match content: + case ImageMessageContentDTO(): + images.append(content.base64) + case TextMessageContentDTO(): + if len(text_content) > 0: + text_content += "\n" + text_content += content.text_content + case JsonMessageContentDTO(): + if len(text_content) > 0: + text_content += "\n" + text_content += str(content.json_content) + case _: + continue + 
messages_to_return.append( + Message( + role=map_role_to_str(message.sender), + content=text_content, + images=convert_to_ollama_images(images), + ) ) - for message in messages - ] + return messages_to_return def convert_to_iris_message(message: Message) -> PyrisMessage: + """ + Convert a Message to a PyrisMessage + """ contents = [TextMessageContentDTO(text_content=message["content"])] return PyrisMessage( sender=map_str_to_role(message["role"]), @@ -42,8 +81,15 @@ class OllamaModel( def model_post_init(self, __context: Any) -> None: self._client = Client(host=self.host) # TODO: Add authentication (httpx auth?) - def complete(self, prompt: str, arguments: CompletionArguments) -> str: - response = self._client.generate(model=self.model, prompt=prompt) + def complete( + self, + prompt: str, + arguments: CompletionArguments, + image: Optional[ImageMessageContentDTO] = None, + ) -> str: + response = self._client.generate( + model=self.model, prompt=prompt, images=[image.base64] if image else None + ) return response["response"] def chat( diff --git a/app/llm/external/openai_chat.py b/app/llm/external/openai_chat.py index 450efdd7..894b3b18 100644 --- a/app/llm/external/openai_chat.py +++ b/app/llm/external/openai_chat.py @@ -3,11 +3,13 @@ from openai import OpenAI from openai.lib.azure import AzureOpenAI -from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage +from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam -from ...common.message_converters import map_role_to_str, map_str_to_role +from ...common.message_converters import map_str_to_role, map_role_to_str from app.domain.data.text_message_content_dto import TextMessageContentDTO from ...domain import PyrisMessage +from ...domain.data.image_message_content_dto import ImageMessageContentDTO +from ...domain.data.json_message_content_dto import JsonMessageContentDTO from ...llm import CompletionArguments from ...llm.external.model import ChatModel @@ -15,16 +17,50 @@ 
def convert_to_open_ai_messages( messages: list[PyrisMessage], ) -> list[ChatCompletionMessageParam]: - return [ - { + """ + Convert a list of PyrisMessage to a list of ChatCompletionMessageParam + """ + openai_messages = [] + for message in messages: + openai_content = [] + for content in message.contents: + match content: + case ImageMessageContentDTO(): + openai_content.append( + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{content.base64}", + "detail": "high", + }, + } + ) + case TextMessageContentDTO(): + openai_content.append( + {"type": "text", "text": content.text_content} + ) + case JsonMessageContentDTO(): + openai_content.append( + { + "type": "json_object", + "json_object": content.json_content, + } + ) + case _: + pass + + openai_message = { "role": map_role_to_str(message.sender), - "content": message.contents[0].text_content, + "content": openai_content, } - for message in messages - ] + openai_messages.append(openai_message) + return openai_messages def convert_to_iris_message(message: ChatCompletionMessage) -> PyrisMessage: + """ + Convert a ChatCompletionMessage to a PyrisMessage + """ return PyrisMessage( sender=map_str_to_role(message.role), contents=[TextMessageContentDTO(textContent=message.content)], @@ -45,7 +81,6 @@ def chat( messages=convert_to_open_ai_messages(messages), temperature=arguments.temperature, max_tokens=arguments.max_tokens, - stop=arguments.stop, ) return convert_to_iris_message(response.choices[0].message) diff --git a/app/llm/external/openai_dalle.py b/app/llm/external/openai_dalle.py new file mode 100644 index 00000000..e8f9817c --- /dev/null +++ b/app/llm/external/openai_dalle.py @@ -0,0 +1,59 @@ +import base64 +from typing import List, Literal + +import requests + +from app.domain.data.image_message_content_dto import ImageMessageContentDTO + + +def generate_images( + self, + prompt: str, + n: int = 1, + size: Literal[ + "256x256", "512x512", "1024x1024", "1792x1024", "1024x1792" + ] = 
"256x256", + quality: Literal["standard", "hd"] = "standard", + **kwargs, +) -> List[ImageMessageContentDTO]: + """ + Generate images from the prompt. + """ + try: + response = self._client.images.generate( + model=self.model, + prompt=prompt, + size=size, + quality=quality, + n=n, + response_format="url", + **kwargs, + ) + except Exception as e: + print(f"Failed to generate images: {e}") + return [] + + images = response.data + iris_images = [] + for image in images: + revised_prompt = ( + prompt if image.revised_prompt is None else image.revised_prompt + ) + base64_data = image.b64_json + if base64_data is None: + try: + image_response = requests.get(image.url) + image_response.raise_for_status() + base64_data = base64.b64encode(image_response.content).decode("utf-8") + except requests.RequestException as e: + print(f"Failed to download or encode image: {e}") + continue + + iris_images.append( + ImageMessageContentDTO( + prompt=revised_prompt, + base64=base64_data, + ) + ) + + return iris_images diff --git a/app/llm/request_handler/basic_request_handler.py b/app/llm/request_handler/basic_request_handler.py index dc07d545..5756346f 100644 --- a/app/llm/request_handler/basic_request_handler.py +++ b/app/llm/request_handler/basic_request_handler.py @@ -1,4 +1,7 @@ +from typing import Optional + from app.domain import PyrisMessage +from app.domain.data.image_message_content_dto import ImageMessageContentDTO from app.llm.request_handler import RequestHandler from app.llm.completion_arguments import CompletionArguments from app.llm.llm_manager import LlmManager @@ -12,9 +15,14 @@ def __init__(self, model_id: str): self.model_id = model_id self.llm_manager = LlmManager() - def complete(self, prompt: str, arguments: CompletionArguments) -> str: + def complete( + self, + prompt: str, + arguments: CompletionArguments, + image: Optional[ImageMessageContentDTO] = None, + ) -> str: llm = self.llm_manager.get_llm_by_id(self.model_id) - return llm.complete(prompt, arguments) + 
return llm.complete(prompt, arguments, image) def chat( self, messages: list[PyrisMessage], arguments: CompletionArguments diff --git a/app/llm/request_handler/request_handler_interface.py b/app/llm/request_handler/request_handler_interface.py index 4acdbe6d..390a4cbc 100644 --- a/app/llm/request_handler/request_handler_interface.py +++ b/app/llm/request_handler/request_handler_interface.py @@ -1,6 +1,8 @@ from abc import ABCMeta, abstractmethod +from typing import Optional from ...domain import PyrisMessage +from ...domain.data.image_message_content_dto import ImageMessageContentDTO from ...llm import CompletionArguments @@ -19,7 +21,12 @@ def __subclasshook__(cls, subclass) -> bool: ) @abstractmethod - def complete(self, prompt: str, arguments: CompletionArguments) -> str: + def complete( + self, + prompt: str, + arguments: CompletionArguments, + image: Optional[ImageMessageContentDTO] = None, + ) -> str: """Create a completion from the prompt""" raise NotImplementedError diff --git a/app/pipeline/chat/tutor_chat_pipeline.py b/app/pipeline/chat/tutor_chat_pipeline.py index ed3e9347..5f36b1b8 100644 --- a/app/pipeline/chat/tutor_chat_pipeline.py +++ b/app/pipeline/chat/tutor_chat_pipeline.py @@ -74,6 +74,7 @@ def __call__(self, dto: TutorChatPipelineExecutionDTO, **kwargs): :param dto: The pipeline execution data transfer object :param kwargs: The keyword arguments """ + # Set up the initial prompt self.prompt = ChatPromptTemplate.from_messages( [