Skip to content

Commit

Permalink
feat: Make ChatMessage use ContentBlocks for data storage (#17039)
Browse files Browse the repository at this point in the history
  • Loading branch information
masci authored Nov 26, 2024
1 parent 8247ad8 commit 55d4891
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 55 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ jobs:
with:
# v0 makes it easy to bust the cache if needed
# just increase the integer to start with a fresh cache
gha-cache-key: v1-py${{ matrix.python_version }}
named-caches-hash: v1-py${{ matrix.python_version }}
pants-python-version: ${{ matrix.python-version }}
named-caches-hash: v3-py${{ matrix.python_version }}-${{ hashFiles('./**/pyproject.toml') }}
pants-ci-config: pants.toml
- name: Check BUILD files
run: |
Expand Down
92 changes: 61 additions & 31 deletions llama-index-core/llama_index/core/base/llms/types.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,31 @@
from __future__ import annotations

import base64
import requests
from enum import Enum
from io import BytesIO
from typing import (
Annotated,
Any,
AsyncGenerator,
Dict,
Generator,
List,
Literal,
Optional,
Union,
List,
Any,
)

import requests
from typing_extensions import Self

from llama_index.core.bridge.pydantic import (
BaseModel,
Field,
ConfigDict,
Field,
field_serializer,
)
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
from llama_index.core.schema import ImageType

try:
from pydantic import BaseModel as V2BaseModel
from pydantic.v1 import BaseModel as V1BaseModel
except ImportError:
from pydantic import BaseModel as V2BaseModel

V1BaseModel = V2BaseModel # type: ignore


class MessageRole(str, Enum):
"""Message role."""
Expand All @@ -44,21 +39,13 @@ class MessageRole(str, Enum):
MODEL = "model"


# ===== Generic Model Input - Chat =====
class ContentBlockTypes(str, Enum):
TEXT = "text"
IMAGE = "image"


class TextBlock(BaseModel):
type: Literal[ContentBlockTypes.TEXT] = ContentBlockTypes.TEXT

block_type: Literal["text"] = "text"
text: str


class ImageBlock(BaseModel):
type: Literal[ContentBlockTypes.IMAGE] = ContentBlockTypes.IMAGE

block_type: Literal["image"] = "image"
image: Optional[str] = None
image_path: Optional[str] = None
image_url: Optional[str] = None
Expand All @@ -78,12 +65,58 @@ def resolve_image(self) -> ImageType:
raise ValueError("No image found in the chat message!")


ContentBlock = Annotated[
Union[TextBlock, ImageBlock], Field(discriminator="block_type")
]


class ChatMessage(BaseModel):
"""Chat message."""

role: MessageRole = MessageRole.USER
content: Optional[Any] = ""
additional_kwargs: dict = Field(default_factory=dict)
additional_kwargs: dict[str, Any] = Field(default_factory=dict)
blocks: list[ContentBlock] = Field(default_factory=list)

def __init__(self, /, content: Any | None = None, **data: Any) -> None:
"""Keeps backward compatibility with the old `content` field.
If content was passed and contained text, store a single TextBlock.
If content was passed and it was a list, assume it's a list of content blocks and store it.
"""
if content is not None:
if isinstance(content, str):
data["blocks"] = [TextBlock(text=content)]
elif isinstance(content, list):
data["blocks"] = content

super().__init__(**data)

@property
def content(self) -> str | None:
"""Keeps backward compatibility with the old `content` field.
Returns:
The block content if there's a single TextBlock, an empty string otherwise.
"""
if len(self.blocks) == 1 and isinstance(self.blocks[0], TextBlock):
return self.blocks[0].text
return None

@content.setter
def content(self, content: str) -> None:
"""Keeps backward compatibility with the old `content` field.
Raises:
ValueError: if blocks contains more than a block, or a block that's not TextBlock.
"""
if not self.blocks:
self.blocks = [TextBlock(text=content)]
elif len(self.blocks) == 1 and isinstance(self.blocks[0], TextBlock):
self.blocks = [TextBlock(text=content)]
else:
raise ValueError(
"ChatMessage contains multiple blocks, use 'ChatMessage.blocks' instead."
)

def __str__(self) -> str:
return f"{self.role.value}: {self.content}"
Expand All @@ -94,13 +127,13 @@ def from_str(
content: str,
role: Union[MessageRole, str] = MessageRole.USER,
**kwargs: Any,
) -> "ChatMessage":
) -> Self:
if isinstance(role, str):
role = MessageRole(role)
return cls(role=role, content=content, **kwargs)
return cls(role=role, blocks=[TextBlock(text=content)], **kwargs)

def _recursive_serialization(self, value: Any) -> Any:
if isinstance(value, V2BaseModel):
if isinstance(value, BaseModel):
value.model_rebuild() # ensures all fields are initialized and serializable
return value.model_dump() # type: ignore
if isinstance(value, dict):
Expand All @@ -116,9 +149,6 @@ def _recursive_serialization(self, value: Any) -> Any:
def serialize_additional_kwargs(self, value: Any, _info: Any) -> Any:
return self._recursive_serialization(value)

def dict(self, **kwargs: Any) -> Dict[str, Any]:
return self.model_dump(**kwargs)


class LogProb(BaseModel):
"""LogProb of a token."""
Expand Down
10 changes: 5 additions & 5 deletions llama-index-core/llama_index/core/memory/vector_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@

import uuid
from typing import Any, Dict, List, Optional, Union
from llama_index.core.bridge.pydantic import field_validator

from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.bridge.pydantic import Field
from llama_index.core.memory.types import BaseMemory
from llama_index.core.bridge.pydantic import Field, field_validator
from llama_index.core.embeddings.utils import EmbedType
from llama_index.core.memory.types import BaseMemory
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore


def _stringify_obj(d: Any) -> Union[str, list, dict]:
Expand All @@ -30,6 +29,7 @@ def _stringify_chat_message(msg: ChatMessage) -> Dict:
"""Utility function to convert chatmessage to serializable dict."""
msg_dict = msg.dict()
msg_dict["additional_kwargs"] = _stringify_obj(msg_dict["additional_kwargs"])
msg_dict["content"] = msg.content
return msg_dict


Expand Down
Empty file.
3 changes: 3 additions & 0 deletions llama-index-core/tests/base/llms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
python_tests(
name="tests",
)
Empty file.
96 changes: 96 additions & 0 deletions llama-index-core/tests/base/llms/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pytest
from llama_index.core.base.llms.types import (
ChatMessage,
ImageBlock,
MessageRole,
TextBlock,
)
from llama_index.core.bridge.pydantic import BaseModel


def test_chat_message_from_str():
m = ChatMessage.from_str(content="test content")
assert m.content == "test content"
assert len(m.blocks) == 1
assert type(m.blocks[0]) is TextBlock
assert m.blocks[0].text == "test content"


def test_chat_message_content_legacy_get():
m = ChatMessage(content="test content")
assert m.content == "test content"
assert len(m.blocks) == 1
assert type(m.blocks[0]) is TextBlock
assert m.blocks[0].text == "test content"

m = ChatMessage(role="user", content="test content")
assert m.role == "user"
assert m.content == "test content"
assert len(m.blocks) == 1
assert type(m.blocks[0]) is TextBlock
assert m.blocks[0].text == "test content"

m = ChatMessage(content=[TextBlock(text="test content")])
assert m.content == "test content"
assert len(m.blocks) == 1
assert type(m.blocks[0]) is TextBlock
assert m.blocks[0].text == "test content"


def test_chat_message_content_legacy_set():
m = ChatMessage()
m.content = "test content"
assert len(m.blocks) == 1
assert type(m.blocks[0]) is TextBlock
assert m.blocks[0].text == "test content"

m = ChatMessage(content="some original content")
m.content = "test content"
assert len(m.blocks) == 1
assert type(m.blocks[0]) is TextBlock
assert m.blocks[0].text == "test content"

m = ChatMessage(content=[TextBlock(text="test content"), ImageBlock()])
with pytest.raises(ValueError):
m.content = "test content"


def test_chat_message_content_returns_empty_string():
m = ChatMessage(content=[TextBlock(text="test content"), ImageBlock()])
assert m.content is None


def test__str__():
assert str(ChatMessage(content="test content")) == "user: test content"


def test_serializer():
class SimpleModel(BaseModel):
some_field: str = ""

m = ChatMessage(
content="test content",
additional_kwargs={"some_list": ["a", "b", "c"], "some_object": SimpleModel()},
)
assert m.model_dump() == {
"role": MessageRole.USER,
"additional_kwargs": {
"some_list": ["a", "b", "c"],
"some_object": {"some_field": ""},
},
"blocks": [{"block_type": "text", "text": "test content"}],
}


def test_legacy_roundtrip():
legacy_message = {
"role": MessageRole.USER,
"content": "foo",
"additional_kwargs": {},
}
m = ChatMessage(**legacy_message)
assert m.model_dump() == {
"additional_kwargs": {},
"blocks": [{"block_type": "text", "text": "foo"}],
"role": MessageRole.USER,
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ async def test_async_postgres_add_message(postgres_chat_store: PostgresChatStore
message = ChatMessage(content="async_add_message_test", role="user")
await postgres_chat_store.async_add_message(key, message=message)

result = await postgres_chat_store.async_get_messages(key)
result = await postgres_chat_store.aget_messages(key)

assert result[0].content == "async_add_message_test" and result[0].role == "user"

Expand All @@ -190,23 +190,23 @@ async def test_async_set_and_retrieve_messages(postgres_chat_store: PostgresChat
ChatMessage(content="Second async message", role="user"),
]
key = "test_async_set_key"
await postgres_chat_store.async_set_messages(key, messages)
await postgres_chat_store.aset_messages(key, messages)

retrieved_messages = await postgres_chat_store.async_get_messages(key)
retrieved_messages = await postgres_chat_store.aget_messages(key)
assert len(retrieved_messages) == 2
assert retrieved_messages[0].content == "First async message"
assert retrieved_messages[1].content == "Second async message"


@pytest.mark.skipif(no_packages, reason="ayncpg, pscopg and sqlalchemy not installed")
@pytest.mark.asyncio()
async def test_async_delete_messages(postgres_chat_store: PostgresChatStore):
async def test_adelete_messages(postgres_chat_store: PostgresChatStore):
messages = [ChatMessage(content="Async message to delete", role="user")]
key = "test_async_delete_key"
await postgres_chat_store.async_set_messages(key, messages)
await postgres_chat_store.aset_messages(key, messages)

await postgres_chat_store.async_delete_messages(key)
retrieved_messages = await postgres_chat_store.async_get_messages(key)
await postgres_chat_store.adelete_messages(key)
retrieved_messages = await postgres_chat_store.aget_messages(key)
assert retrieved_messages == []


Expand All @@ -217,11 +217,11 @@ async def test_async_delete_specific_message(postgres_chat_store: PostgresChatSt
ChatMessage(content="Async keep me", role="user"),
ChatMessage(content="Async delete me", role="user"),
]
key = "test_async_delete_message_key"
await postgres_chat_store.async_set_messages(key, messages)
key = "test_adelete_message_key"
await postgres_chat_store.aset_messages(key, messages)

await postgres_chat_store.async_delete_message(key, 1)
retrieved_messages = await postgres_chat_store.async_get_messages(key)
await postgres_chat_store.adelete_message(key, 1)
retrieved_messages = await postgres_chat_store.aget_messages(key)
assert len(retrieved_messages) == 1
assert retrieved_messages[0].content == "Async keep me"

Expand All @@ -230,10 +230,10 @@ async def test_async_delete_specific_message(postgres_chat_store: PostgresChatSt
@pytest.mark.asyncio()
async def test_async_get_keys(postgres_chat_store: PostgresChatStore):
# Add some test data
await postgres_chat_store.async_set_messages(
await postgres_chat_store.aset_messages(
"async_key1", [ChatMessage(content="Test1", role="user")]
)
await postgres_chat_store.async_set_messages(
await postgres_chat_store.aset_messages(
"async_key2", [ChatMessage(content="Test2", role="user")]
)

Expand All @@ -250,13 +250,13 @@ async def test_async_delete_last_message(postgres_chat_store: PostgresChatStore)
ChatMessage(content="First async message", role="user"),
ChatMessage(content="Last async message", role="user"),
]
await postgres_chat_store.async_set_messages(key, messages)
await postgres_chat_store.aset_messages(key, messages)

deleted_message = await postgres_chat_store.async_delete_last_message(key)
deleted_message = await postgres_chat_store.adelete_last_message(key)

assert deleted_message.content == "Last async message"

remaining_messages = await postgres_chat_store.async_get_messages(key)
remaining_messages = await postgres_chat_store.aget_messages(key)

assert len(remaining_messages) == 1
assert remaining_messages[0].content == "First async message"

0 comments on commit 55d4891

Please sign in to comment.