Skip to content

Commit

Permalink
feat(Message/add_message): add role_type field and update method (#126)
Browse files Browse the repository at this point in the history
* Working sdk

* tests

* feat: Add ai and human prefixes to zep chat message history

* updating role types

---------

Co-authored-by: paulpaliychuk <[email protected]>
  • Loading branch information
petergarbers and paul-paliychuk authored Mar 8, 2024
1 parent 6cc145d commit b7ed932
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 3 deletions.
17 changes: 17 additions & 0 deletions examples/chat_history/chat_history_shoe_purchase.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
history = [
{
"role": "user",
"role_type": "user",
"content": (
"Hi there, I'm looking for a new pair of shoes. I'm not sure where"
" to start."
),
},
{
"role": "assistant",
"role_type": "assistant",
"content": (
"Hello! I'd be happy to help you find the perfect pair of shoes. To"
" get started, could you tell me what brand you're interested in,"
Expand All @@ -16,10 +18,12 @@
},
{
"role": "user",
"role_type": "user",
"content": "I like Nike and Adidas, but I'm open to other brands too.",
},
{
"role": "assistant",
"role_type": "assistant",
"content": (
"That's great! It's important to find the right shoe for your"
" needs. Could you tell me what you'll be using these shoes for"
Expand All @@ -29,10 +33,12 @@
},
{
"role": "user",
"role_type": "user",
"content": "I'll be using them for running.",
},
{
"role": "assistant",
"role_type": "assistant",
"content": (
"Fantastic! Now, could you please provide me with your shoe size"
" and foot width? This information will help me narrow down the"
Expand All @@ -41,24 +47,28 @@
},
{
"role": "user",
"role_type": "user",
"content": ("I wear a men's size 10, and I'm not sure about the foot width."),
},
{
"role": "assistant",
"role_type": "assistant",
"content": (
"Thank you. Based on your size, I'll make sure to suggest shoes"
" with a wide range of sizes and widths available."
),
},
{
"role": "user",
"role_type": "user",
"content": (
"I have a bit of a pronation issue. Do I need to consider that"
" while choosing shoes?"
),
},
{
"role": "assistant",
"role_type": "assistant",
"content": (
"Yes, pronation is essential when choosing running shoes. It refers"
" to the way your foot rolls inward or outward when you walk or"
Expand All @@ -69,10 +79,12 @@
},
{
"role": "user",
"role_type": "user",
"content": "Great, I have a budget of around $120.",
},
{
"role": "assistant",
"role_type": "assistant",
"content": (
"Perfect! I'll make sure to find the best running shoes that fit"
" your preferences, taking into account your brand preferences,"
Expand All @@ -81,6 +93,7 @@
},
{
"role": "assistant",
"role_type": "assistant",
"content": (
"I've found a few great options for you. The first one is the"
" Brooks Ghost 14, which is known for its cushioning and support,"
Expand All @@ -91,13 +104,15 @@
},
{
"role": "user",
"role_type": "user",
"content": (
"I've heard good things about Brooks, but I was kind of hoping to"
" get a Nike or an Adidas shoe."
),
},
{
"role": "assistant",
"role_type": "assistant",
"content": (
"I understand. In that case, I'd recommend the Adidas Ultraboost"
" 21. It's a popular choice among runners for its comfort, energy"
Expand All @@ -107,10 +122,12 @@
},
{
"role": "user",
"role_type": "user",
"content": "They're nice looking. Yes, I think I'll go with those.",
},
{
"role": "assistant",
"role_type": "assistant",
"content": (
"Excellent choice! To complete the purchase, I'll need you to"
" provide your credit card information in our secure payment form."
Expand Down
14 changes: 13 additions & 1 deletion tests/memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
{
"uuid": "msg-uuid",
"role": "user",
"role_type": "user",
"content": "Test message",
"metadata": {"key": "value"},
},
{
"uuid": "msg-uuid2",
"role": "ai",
"role_type": "assistant",
"content": "Test message2",
"metadata": {"key2": "value2"},
},
Expand Down Expand Up @@ -101,10 +103,11 @@ async def test_aget_memory_missing_values(httpx_mock: HTTPXMock):

mock_response = {
"messages": [
{"role": "user", "content": "Test message"},
{"role": "user", "role_type": "user", "content": "Test message"},
{
"uuid": "msg-uuid2",
"role": "ai",
"role_type": "assistant",
},
],
}
Expand Down Expand Up @@ -214,6 +217,7 @@ async def test_asearch_memory(httpx_mock: HTTPXMock):
"message": {
"uuid": "msg-uuid",
"role": "user",
"role_type": "user",
"content": "Test message",
},
"score": 0.9,
Expand Down Expand Up @@ -247,6 +251,7 @@ async def test_asearch_memory_mmr(httpx_mock: HTTPXMock):
"message": {
"uuid": "msg-uuid",
"role": "user",
"role_type": "user",
"content": "Test message",
},
"score": 0.9,
Expand Down Expand Up @@ -421,6 +426,7 @@ async def test_aget_session_message(httpx_mock: HTTPXMock):
mock_message = {
"uuid": message_id,
"role": "user",
"role_type": "user",
"content": "Test message",
"metadata": {"key": "value"},
}
Expand Down Expand Up @@ -472,6 +478,7 @@ def test_get_session_message(httpx_mock: HTTPXMock):
mock_message = {
"uuid": message_id,
"role": "user",
"role_type": "user",
"content": "Test message",
"metadata": {"key": "value"},
}
Expand All @@ -481,6 +488,7 @@ def test_get_session_message(httpx_mock: HTTPXMock):
session_message = client.message.get_session_message(session_id, message_id)
assert session_message.uuid == mock_message["uuid"]
assert session_message.role == mock_message["role"]
assert session_message.role_type == mock_message["role_type"]
assert session_message.content == mock_message["content"]
assert session_message.metadata == mock_message["metadata"]
assert filter_unset_fields(session_message.model_dump()) == mock_message
Expand Down Expand Up @@ -520,6 +528,7 @@ async def test_aupdate_message_metadata(httpx_mock: HTTPXMock):
mock_message = {
"uuid": message_id,
"role": "user",
"role_type": "user",
"content": "Test message",
"metadata": {"metadata": {"foo": "bar"}},
}
Expand All @@ -531,6 +540,7 @@ async def test_aupdate_message_metadata(httpx_mock: HTTPXMock):
)
assert updated_message.uuid == mock_message["uuid"]
assert updated_message.role == mock_message["role"]
assert updated_message.role_type == mock_message["role_type"]
assert updated_message.content == mock_message["content"]
assert updated_message.metadata == mock_message["metadata"]
assert filter_unset_fields(updated_message.model_dump()) == mock_message
Expand Down Expand Up @@ -577,6 +587,7 @@ def test_update_message_metadata(httpx_mock: HTTPXMock):
mock_message = {
"uuid": message_id,
"role": "user",
"role_type": "user",
"content": "Test message",
"metadata": {"metadata": {"foo": "bar"}},
}
Expand All @@ -588,6 +599,7 @@ def test_update_message_metadata(httpx_mock: HTTPXMock):
)
assert updated_message.uuid == mock_message["uuid"]
assert updated_message.role == mock_message["role"]
assert updated_message.role_type == mock_message["role_type"]
assert updated_message.content == mock_message["content"]
assert updated_message.metadata == mock_message["metadata"]
assert filter_unset_fields(updated_message.model_dump()) == mock_message
Expand Down
17 changes: 16 additions & 1 deletion zep_python/langchain/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from zep_python import API_URL, NotFoundError, ZepClient
from zep_python.memory.models import Memory, Message
from zep_python.message.models import get_zep_message_role_type

try:
from langchain_core.chat_history import BaseChatMessageHistory
Expand Down Expand Up @@ -50,6 +51,8 @@ def __init__(
api_url: Optional[str] = API_URL,
api_key: Optional[str] = None,
memory_type: Optional[str] = None,
ai_prefix: Optional[str] = None,
human_prefix: Optional[str] = None,
) -> None:
if zep_client is None:
self._client = ZepClient(api_url=api_url, api_key=api_key)
Expand All @@ -59,6 +62,9 @@ def __init__(
self.session_id = session_id
self.memory_type = memory_type or "perpetual"

self.ai_prefix = ai_prefix or "ai"
self.human_prefix = human_prefix or "human"

@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve messages from Zep memory"""
Expand Down Expand Up @@ -159,8 +165,17 @@ def add_message(
if isinstance(message.content, list):
raise ValueError("Message content cannot be a list")

if message.type == "ai":
message.name = self.ai_prefix
elif message.type == "human":
message.name = self.human_prefix

zep_message = Message(
content=message.content, role=message.type, metadata=metadata
content=message.content,
# If name is not set, use type as role
role=message.name or message.type,
role_type=get_zep_message_role_type(message.type),
metadata=metadata,
)
zep_memory = Memory(messages=[zep_message])

Expand Down
37 changes: 36 additions & 1 deletion zep_python/message/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
from enum import Enum
from typing import Any, Dict, Optional
from uuid import UUID

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator


class RoleType(str, Enum):
USER_ROLE = "user"
ASSISTANT_ROLE = "assistant"
SYSTEM_ROLE = "system"
FUNCTION_ROLE = "function"
TOOL_ROLE = "tool"


def get_zep_message_role_type(role):
if role == "human":
return RoleType.USER_ROLE
elif role == "ai":
return RoleType.ASSISTANT_ROLE
elif role == "system":
return RoleType.SYSTEM_ROLE
elif role == "function":
return RoleType.FUNCTION_ROLE
elif role == "tool":
return RoleType.TOOL_ROLE
else:
return RoleType.SYSTEM_ROLE


class Message(BaseModel):
Expand All @@ -16,6 +40,8 @@ class Message(BaseModel):
The timestamp of when the message was created.
role : str
The role of the sender of the message (e.g., "user", "assistant").
role_type: RoleType
The type of the role (e.g., "user", "system").
content : str
The content of the message.
token_count : int, optional
Expand All @@ -28,12 +54,21 @@ class Message(BaseModel):
"""

role: str = Field("A role is required")
role_type: RoleType = Field("A role type is required")
content: str = Field("Content is required")
uuid: Optional[str] = Field(default=None)
created_at: Optional[str] = Field(default=None)
token_count: Optional[int] = Field(default=None)
metadata: Optional[Dict[str, Any]] = Field(default=None)

@validator("role_type")
def validate_role_type(cls, v):
if isinstance(v, RoleType):
return v
if v not in RoleType._value2member_map_:
raise ValueError(f"Invalid role type: {v}")
return v

def to_dict(self) -> Dict[str, Any]:
"""
Returns a dictionary representation of the message.
Expand Down

0 comments on commit b7ed932

Please sign in to comment.