Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LLM: Add image recognition and generation support #89

Merged
merged 18 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion app/common/message_converters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Literal

from langchain_core.messages import BaseMessage

Expand Down Expand Up @@ -47,7 +48,7 @@ def convert_langchain_message_to_iris_message(
)


def map_role_to_str(role: IrisMessageRole) -> str:
def map_role_to_str(role: IrisMessageRole) -> Literal["user", "assistant", "system"]:
match role:
case IrisMessageRole.USER:
return "user"
Expand Down
1 change: 1 addition & 0 deletions app/domain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
TutorChatPipelineExecutionDTO,
)
from .pyris_message import PyrisMessage, IrisMessageRole
from app.domain.data import image_message_content_dto
6 changes: 3 additions & 3 deletions app/domain/data/image_message_content_dto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import BaseModel
from typing import Optional

from pydantic import BaseModel, Field


class ImageMessageContentDTO(BaseModel):
image_data: Optional[str] = Field(alias="imageData", default=None)
base64: str
prompt: Optional[str]
8 changes: 5 additions & 3 deletions app/domain/data/json_message_content_dto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pydantic import BaseModel, Field, Json
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, Json
from typing import Any


class JsonMessageContentDTO(BaseModel):
json_content: Optional[Json[Any]] = Field(alias="jsonContent", default=None)
model_config = ConfigDict(populate_by_name=True)

json_content: Json[Any] = Field(alias="jsonContent")
8 changes: 4 additions & 4 deletions app/domain/data/text_message_content_dto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field


class TextMessageContentDTO(BaseModel):
text_content: Optional[str] = Field(alias="textContent", default=None)
model_config = ConfigDict(populate_by_name=True)

text_content: str = Field(alias="textContent")
4 changes: 3 additions & 1 deletion app/domain/pyris_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import List

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from app.domain.data.message_content_dto import MessageContentDTO

Expand All @@ -14,6 +14,8 @@ class IrisMessageRole(str, Enum):


class PyrisMessage(BaseModel):
model_config = ConfigDict(populate_by_name=True)

sent_at: datetime | None = Field(alias="sentAt", default=None)
sender: IrisMessageRole
contents: List[MessageContentDTO] = []
Expand Down
24 changes: 24 additions & 0 deletions app/llm/external/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,27 @@ def embed(self, text: str) -> list[float]:
raise NotImplementedError(
f"The LLM {self.__str__()} does not support embeddings"
)


class ImageGenerationModel(LanguageModel, metaclass=ABCMeta):
"""Abstract class for the llm image generation wrappers"""

@classmethod
def __subclasshook__(cls, subclass):
return hasattr(subclass, "generate_images") and callable(
subclass.generate_images
)

@abstractmethod
def generate_images(
self,
prompt: str,
n: int = 1,
size: str = "256x256",
quality: str = "standard",
**kwargs,
) -> list:
"""Create an image from the prompt"""
raise NotImplementedError(
f"The LLM {self.__str__()} does not support image generation"
)
64 changes: 55 additions & 9 deletions app/llm/external/ollama.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,65 @@
import base64
from datetime import datetime
from typing import Literal, Any
from typing import Literal, Any, Optional

from ollama import Client, Message

from ...common.message_converters import map_role_to_str, map_str_to_role
from ...domain.data.json_message_content_dto import JsonMessageContentDTO
from ...domain.data.text_message_content_dto import TextMessageContentDTO
from ...domain.data.image_message_content_dto import ImageMessageContentDTO
from ...domain import PyrisMessage
from ...llm import CompletionArguments
from ...llm.external.model import ChatModel, CompletionModel, EmbeddingModel


def convert_to_ollama_images(base64_images: list[str]) -> list[bytes] | None:
"""
Convert a list of base64 images to a list of bytes
"""
if not base64_images:
return None
return [base64.b64decode(base64_image) for base64_image in base64_images]


def convert_to_ollama_messages(messages: list[PyrisMessage]) -> list[Message]:
return [
Message(
role=map_role_to_str(message.sender),
content=message.contents[0].text_content,
"""
Convert a list of PyrisMessages to a list of Ollama Messages
"""
messages_to_return = []
for message in messages:
if len(message.contents) == 0:
continue
text_content = ""
images = []
for content in message.contents:
match content:
case ImageMessageContentDTO():
images.append(content.base64)
case TextMessageContentDTO():
if len(text_content) > 0:
text_content += "\n"
text_content += content.text_content
case JsonMessageContentDTO():
if len(text_content) > 0:
text_content += "\n"
text_content += content.json_content
case _:
continue
messages_to_return.append(
Message(
role=map_role_to_str(message.sender),
content=text_content,
images=convert_to_ollama_images(images),
)
)
for message in messages
]
return messages_to_return


def convert_to_iris_message(message: Message) -> PyrisMessage:
"""
Convert a Message to a PyrisMessage
"""
contents = [TextMessageContentDTO(text_content=message["content"])]
return PyrisMessage(
sender=map_str_to_role(message["role"]),
Expand All @@ -42,8 +81,15 @@ class OllamaModel(
def model_post_init(self, __context: Any) -> None:
self._client = Client(host=self.host) # TODO: Add authentication (httpx auth?)

def complete(self, prompt: str, arguments: CompletionArguments) -> str:
response = self._client.generate(model=self.model, prompt=prompt)
def complete(
self,
prompt: str,
arguments: CompletionArguments,
image: Optional[ImageMessageContentDTO] = None,
) -> str:
response = self._client.generate(
model=self.model, prompt=prompt, images=[image.base64] if image else None
)
return response["response"]

def chat(
Expand Down
51 changes: 43 additions & 8 deletions app/llm/external/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,64 @@

from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam

from ...common.message_converters import map_role_to_str, map_str_to_role
from ...common.message_converters import map_str_to_role, map_role_to_str
from app.domain.data.text_message_content_dto import TextMessageContentDTO
from ...domain import PyrisMessage
from ...domain.data.image_message_content_dto import ImageMessageContentDTO
from ...domain.data.json_message_content_dto import JsonMessageContentDTO
from ...llm import CompletionArguments
from ...llm.external.model import ChatModel


def convert_to_open_ai_messages(
messages: list[PyrisMessage],
) -> list[ChatCompletionMessageParam]:
return [
{
"""
Convert a list of PyrisMessage to a list of ChatCompletionMessageParam
"""
openai_messages = []
for message in messages:
openai_content = []
for content in message.contents:
match content:
case ImageMessageContentDTO():
openai_content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{content.base64}",
"detail": "high",
},
}
)
case TextMessageContentDTO():
openai_content.append(
{"type": "text", "text": content.text_content}
)
case JsonMessageContentDTO():
openai_content.append(
{
"type": "json_object",
"json_object": content.json_content,
}
)
case _:
pass

openai_message = {
"role": map_role_to_str(message.sender),
"content": message.contents[0].text_content,
"content": openai_content,
}
for message in messages
]
openai_messages.append(openai_message)
return openai_messages


def convert_to_iris_message(message: ChatCompletionMessage) -> PyrisMessage:
"""
Convert a ChatCompletionMessage to a PyrisMessage
"""
return PyrisMessage(
sender=map_str_to_role(message.role),
contents=[TextMessageContentDTO(textContent=message.content)],
Expand All @@ -45,7 +81,6 @@ def chat(
messages=convert_to_open_ai_messages(messages),
temperature=arguments.temperature,
max_tokens=arguments.max_tokens,
stop=arguments.stop,
)
return convert_to_iris_message(response.choices[0].message)

Expand Down
59 changes: 59 additions & 0 deletions app/llm/external/openai_dalle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import base64
from typing import List, Literal

import requests

from app.domain.data.image_message_content_dto import ImageMessageContentDTO


def generate_images(
self,
prompt: str,
n: int = 1,
size: Literal[
"256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"
] = "256x256",
quality: Literal["standard", "hd"] = "standard",
**kwargs,
) -> List[ImageMessageContentDTO]:
"""
Generate images from the prompt.
"""
try:
response = self._client.images.generate(
model=self.model,
prompt=prompt,
size=size,
quality=quality,
n=n,
response_format="url",
**kwargs,
)
except Exception as e:
print(f"Failed to generate images: {e}")
return []

images = response.data
iris_images = []
for image in images:
revised_prompt = (
prompt if image.revised_prompt is None else image.revised_prompt
)
base64_data = image.b64_json
if base64_data is None:
try:
image_response = requests.get(image.url)
image_response.raise_for_status()
base64_data = base64.b64encode(image_response.content).decode("utf-8")
except requests.RequestException as e:
print(f"Failed to download or encode image: {e}")
continue

iris_images.append(
ImageMessageContentDTO(
prompt=revised_prompt,
base64=base64_data,
)
)
Comment on lines +52 to +57
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not work, because ImageMessageContentDTO.base64 is an array.


return iris_images
12 changes: 10 additions & 2 deletions app/llm/request_handler/basic_request_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Optional

from app.domain import PyrisMessage
from app.domain.data.image_message_content_dto import ImageMessageContentDTO
from app.llm.request_handler import RequestHandler
from app.llm.completion_arguments import CompletionArguments
from app.llm.llm_manager import LlmManager
Expand All @@ -12,9 +15,14 @@ def __init__(self, model_id: str):
self.model_id = model_id
self.llm_manager = LlmManager()

def complete(self, prompt: str, arguments: CompletionArguments) -> str:
def complete(
self,
prompt: str,
arguments: CompletionArguments,
image: Optional[ImageMessageContentDTO] = None,
) -> str:
llm = self.llm_manager.get_llm_by_id(self.model_id)
return llm.complete(prompt, arguments)
return llm.complete(prompt, arguments, image)

def chat(
self, messages: list[PyrisMessage], arguments: CompletionArguments
Expand Down
9 changes: 8 additions & 1 deletion app/llm/request_handler/request_handler_interface.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from abc import ABCMeta, abstractmethod
from typing import Optional

from ...domain import PyrisMessage
from ...domain.data.image_message_content_dto import ImageMessageContentDTO
from ...llm import CompletionArguments


Expand All @@ -19,7 +21,12 @@ def __subclasshook__(cls, subclass) -> bool:
)

@abstractmethod
def complete(self, prompt: str, arguments: CompletionArguments) -> str:
def complete(
self,
prompt: str,
arguments: CompletionArguments,
image: Optional[ImageMessageContentDTO] = None,
) -> str:
"""Create a completion from the prompt"""
raise NotImplementedError

Expand Down
1 change: 1 addition & 0 deletions app/pipeline/chat/tutor_chat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __call__(self, dto: TutorChatPipelineExecutionDTO, **kwargs):
:param dto: The pipeline execution data transfer object
:param kwargs: The keyword arguments
"""

# Set up the initial prompt
self.prompt = ChatPromptTemplate.from_messages(
[
Expand Down
Loading