Skip to content

Commit

Permalink
Improve & fix image recognition
Browse files Browse the repository at this point in the history
  • Loading branch information
Hialus committed Apr 27, 2024
1 parent ec964c3 commit 001b99d
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 57 deletions.
1 change: 1 addition & 0 deletions app/domain/data/image_message_content_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ class Config:
"base64": ["base64EncodedString==", "anotherBase64EncodedString=="],
}
}
populate_by_name = True
8 changes: 5 additions & 3 deletions app/domain/data/json_message_content_dto.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pydantic import BaseModel, Field, Json
from typing import Any, Optional
from pydantic import BaseModel, ConfigDict, Field, Json
from typing import Any


class JsonMessageContentDTO(BaseModel):
json_content: Optional[Json[Any]] = Field(alias="jsonContent", default=None)
model_config = ConfigDict(populate_by_name=True)

json_content: Json[Any] = Field(alias="jsonContent")
8 changes: 4 additions & 4 deletions app/domain/data/text_message_content_dto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field


class TextMessageContentDTO(BaseModel):
text_content: Optional[str] = Field(alias="textContent", default=None)
model_config = ConfigDict(populate_by_name=True)

text_content: str = Field(alias="textContent")
4 changes: 3 additions & 1 deletion app/domain/pyris_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import List

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from app.domain.data.message_content_dto import MessageContentDTO

Expand All @@ -14,6 +14,8 @@ class IrisMessageRole(str, Enum):


class PyrisMessage(BaseModel):
model_config = ConfigDict(populate_by_name=True)

sent_at: datetime | None = Field(alias="sentAt", default=None)
sender: IrisMessageRole
contents: List[MessageContentDTO] = []
Expand Down
53 changes: 27 additions & 26 deletions app/llm/external/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,36 @@ def convert_to_ollama_images(base64_images: list[str]) -> list[bytes] | None:

def convert_to_ollama_messages(messages: list[PyrisMessage]) -> list[Message]:
"""
Convert a list of PyrisMessage to a list of Message
Convert a list of PyrisMessages to a list of Ollama Messages
"""
messages_to_return = []
for message in messages:
match message.contents[0]:
case ImageMessageContentDTO():
messages_to_return.append(
Message(
role=map_role_to_str(message.sender),
content=message.contents[0].text_content,
images=message.contents[0].base64,
)
)
case TextMessageContentDTO():
messages_to_return.append(
Message(
role=map_role_to_str(message.sender),
content=message.contents[0].text_content,
)
)
case JsonMessageContentDTO():
messages_to_return.append(
Message(
role=map_role_to_str(message.sender),
content=message.contents[0].text_content,
)
)
case _:
continue
if len(message.contents) == 0:
continue
text_content = ""
images = []
for content in message.contents:
match content:
case ImageMessageContentDTO():
for image in content.base64:
images.append(image)
case TextMessageContentDTO():
if len(text_content) > 0:
text_content += "\n"
text_content += content.text_content
case JsonMessageContentDTO():
if len(text_content) > 0:
text_content += "\n"
text_content += content.json_content
case _:
continue
messages_to_return.append(
Message(
role=map_role_to_str(message.sender),
content=text_content,
images=convert_to_ollama_images(images),
)
)
return messages_to_return


Expand Down
52 changes: 29 additions & 23 deletions app/llm/external/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,38 @@ def convert_to_open_ai_messages(
"""
openai_messages = []
for message in messages:
match message.contents[0]:
case ImageMessageContentDTO():
content = [{"type": "text", "text": message.contents[0].prompt}]
for image_base64 in message.contents[0].base64:
content.append(
openai_content = []
for content in message.contents:
match content:
case ImageMessageContentDTO():
for image_base64 in content.base64:
openai_content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}",
"detail": "high",
},
}
)
case TextMessageContentDTO():
openai_content.append(
{"type": "text", "text": content.text_content}
)
case JsonMessageContentDTO():
openai_content.append(
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}",
"detail": "high",
},
"type": "json_object",
"json_object": content.json_content,
}
)
case TextMessageContentDTO():
content = [{"type": "text", "text": message.contents[0].text_content}]
case JsonMessageContentDTO():
content = [
{
"type": "json_object",
"json_object": message.contents[0].json_content,
}
]
case _:
content = [{"type": "text", "text": ""}]

openai_message = {"role": map_role_to_str(message.sender), "content": content}
case _:
pass

openai_message = {
"role": map_role_to_str(message.sender),
"content": openai_content,
}
openai_messages.append(openai_message)
return openai_messages

Expand Down

0 comments on commit 001b99d

Please sign in to comment.