Commit

fix: Google embeddings endpoint bug (#2416) (#2417)
sarahwooders authored Feb 11, 2025
2 parents c2e523d + 4da0e49 commit 1e51cba
Showing 5 changed files with 192 additions and 3 deletions.
30 changes: 30 additions & 0 deletions letta/embeddings.py
@@ -167,6 +167,27 @@ def get_text_embedding(self, text: str):
        return response_json["embedding"]


class GoogleEmbeddings:
    def __init__(self, api_key: str, model: str, base_url: str):
        self.api_key = api_key
        self.model = model
        self.base_url = base_url  # Expected to be "https://generativelanguage.googleapis.com"

    def get_text_embedding(self, text: str):
        import httpx

        headers = {"Content-Type": "application/json"}
        # Build the URL from base_url, model, and API key; strip any trailing slash
        # so the default "https://generativelanguage.googleapis.com/" also works.
        url = f"{self.base_url.rstrip('/')}/v1beta/models/{self.model}:embedContent?key={self.api_key}"
        payload = {"model": self.model, "content": {"parts": [{"text": text}]}}
        with httpx.Client() as client:
            response = client.post(url, headers=headers, json=payload)
            # Raise an error for non-success HTTP status codes.
            response.raise_for_status()
            response_json = response.json()
        return response_json["embedding"]["values"]


def query_embedding(embedding_model, query_text: str):
    """Generate padded embedding for querying database"""
    query_vec = embedding_model.get_text_embedding(query_text)
@@ -237,5 +258,14 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
        )
        return model

    elif endpoint_type == "google_ai":
        assert all([model_settings.gemini_api_key is not None, model_settings.gemini_base_url is not None])
        model = GoogleEmbeddings(
            model=config.embedding_model,
            api_key=model_settings.gemini_api_key,
            base_url=model_settings.gemini_base_url,
        )
        return model

    else:
        raise ValueError(f"Unknown endpoint type {endpoint_type}")
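For reference, a minimal usage sketch of the new class (not part of the commit). It assumes a valid GEMINI_API_KEY in the environment and uses the text-embedding-004 model exercised by the new tests below:

# Hypothetical usage sketch, not part of this commit.
# Assumes GEMINI_API_KEY is set and the Generative Language API is reachable.
import os

from letta.embeddings import GoogleEmbeddings

embedder = GoogleEmbeddings(
    api_key=os.environ["GEMINI_API_KEY"],
    model="text-embedding-004",
    base_url="https://generativelanguage.googleapis.com",
)
# POSTs to /v1beta/models/text-embedding-004:embedContent and returns
# response_json["embedding"]["values"], i.e. a plain list of floats.
vector = embedder.get_text_embedding("Hello, world!")
print(len(vector))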
2 changes: 1 addition & 1 deletion letta/settings.py
@@ -85,7 +85,7 @@ class ModelSettings(BaseSettings):

    # google ai
    gemini_api_key: Optional[str] = None

    gemini_base_url: str = "https://generativelanguage.googleapis.com/"
    # together
    together_api_key: Optional[str] = None

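Because ModelSettings is a pydantic BaseSettings class, the new default base URL can be overridden from the environment. A minimal configuration sketch follows; it assumes ModelSettings has no env prefix configured, so the field names map to GEMINI_API_KEY and GEMINI_BASE_URL (only GEMINI_API_KEY is confirmed by the new test below), and that ModelSettings can be instantiated without arguments:

# Hypothetical configuration sketch, not part of this commit.
import os

os.environ["GEMINI_API_KEY"] = "<your-gemini-api-key>"  # env var name used by the new test
os.environ["GEMINI_BASE_URL"] = "https://generativelanguage.googleapis.com"  # assumed name; optional, a default now exists

from letta.settings import ModelSettings

settings = ModelSettings()
print(settings.gemini_base_url)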
4 changes: 3 additions & 1 deletion tests/test_base_functions.py
@@ -180,7 +180,9 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
            found = True
            break

    print(f"In context messages of the sender agent (without system):\n\n{"\n".join([m.text for m in in_context_messages[1:]])}")
    # Compute the joined string first
    joined_messages = "\n".join([m.text for m in in_context_messages[1:]])
    print(f"In context messages of the sender agent (without system):\n\n{joined_messages}")
    if not found:
        raise Exception(f"Was not able to find an instance of the target snippet: {target_snippet}")

157 changes: 157 additions & 0 deletions tests/test_google_embeddings.py
@@ -0,0 +1,157 @@
import os
import threading
import time
import uuid

import httpx
import pytest
from dotenv import load_dotenv
from letta_client import CreateBlock
from letta_client import Letta as LettaSDKClient
from letta_client import MessageCreate

from letta.embeddings import GoogleEmbeddings  # Adjust the import based on your module structure

load_dotenv()

SERVER_PORT = 8283


def run_server():
    load_dotenv()

    from letta.server.rest_api.app import start_server

    print("Starting server...")
    start_server(debug=True)


@pytest.fixture(scope="module")
def client() -> LettaSDKClient:
    # Get URL from environment or start server
    server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
    if not os.getenv("LETTA_SERVER_URL"):
        print("Starting server thread")
        thread = threading.Thread(target=run_server, daemon=True)
        thread.start()
        time.sleep(5)
    print("Running client tests with server:", server_url)
    client = LettaSDKClient(base_url=server_url, token=None)
    yield client


def test_google_embeddings_response():
    api_key = os.environ.get("GEMINI_API_KEY")
    model = "text-embedding-004"
    base_url = "https://generativelanguage.googleapis.com"
    text = "Hello, world!"

    embedding_model = GoogleEmbeddings(api_key, model, base_url)
    response = None

    try:
        response = embedding_model.get_text_embedding(text)
    except httpx.HTTPStatusError as e:
        pytest.fail(f"Request failed with status code {e.response.status_code}")

    assert response is not None, "No response received from API"
    assert isinstance(response, list), "Response is not a list of embeddings"


def test_archival_insert_text_embedding_004(client: LettaSDKClient):
    """
    Test that an agent with model 'gemini-2.0-flash-exp' and embedding 'text-embedding-004'
    correctly inserts a message into its archival memory.

    The test works by:
    1. Creating an agent with the desired model and embedding.
    2. Sending a message instructing the agent to store the message in archival memory.
    3. Retrieving the archival memory via the agent messaging API.
    4. Verifying that the archival message is stored.
    """
    # Create an agent with the specified model and embedding.
    agent = client.agents.create(
        name="archival_insert_text_embedding_004",
        memory_blocks=[
            CreateBlock(label="human", value="name: archival_test"),
            CreateBlock(label="persona", value="You are a helpful assistant that loves helping out the user"),
        ],
        model="google_ai/gemini-2.0-flash-exp",
        embedding="google_ai/text-embedding-004",
    )

    # Define the archival message.
    archival_message = "Archival insertion test message"

    # Send a message instructing the agent to archive it.
    res = client.agents.messages.create(
        agent_id=agent.id,
        messages=[MessageCreate(role="user", content=f"Store this in your archive memory: {archival_message}")],
    )
    print(res.messages)

    # Retrieve the archival messages through the agent messaging API.
    archived_messages = client.agents.messages.create(
        agent_id=agent.id,
        messages=[MessageCreate(role="user", content=f"retrieve from archival memory : {archival_message}")],
    )
    print(archived_messages.messages)

    # Assert that the archival message is present.
    assert (
        any(message.status == "success" for message in archived_messages.messages if message.message_type == "tool_return_message")
    ), f"Archival message '{archival_message}' not found. Archived messages: {archived_messages}"

    # Cleanup: Delete the agent.
    client.agents.delete(agent.id)


def test_archival_insert_embedding_001(client: LettaSDKClient):
    """
    Test that an agent with model 'gemini-2.0-flash-exp' and embedding 'embedding-001'
    correctly inserts a message into its archival memory.

    The test works by:
    1. Creating an agent with the desired model and embedding.
    2. Sending a message prefixed with 'archive :' to instruct the agent to store the message in archival memory.
    3. Retrieving the archival memory via the agent messaging API.
    4. Verifying that the archival message is stored.
    """
    # Create an agent with the specified model and embedding.
    agent = client.agents.create(
        name="archival_insert_embedding_001",
        memory_blocks=[
            CreateBlock(label="human", value="name: archival_test"),
            CreateBlock(label="persona", value="You are a helpful assistant that loves helping out the user"),
        ],
        model="google_ai/gemini-2.0-flash-exp",
        embedding="google_ai/embedding-001",
    )

    # Define the archival message.
    archival_message = "Archival insertion test message"

    # Send a message instructing the agent to archive it.
    client.agents.messages.create(
        agent_id=agent.id,
        messages=[MessageCreate(role="user", content=f"archive : {archival_message}")],
    )

    # Retrieve the archival messages through the agent messaging API.
    archived_messages = client.agents.messages.create(
        agent_id=agent.id,
        messages=[MessageCreate(role="user", content=f"retrieve from archival memory : {archival_message}")],
    )

    # Assert that the archival message is present.
    assert (
        any(message.status == "success" for message in archived_messages.messages if message.message_type == "tool_return_message")
    ), f"Archival message '{archival_message}' not found. Archived messages: {archived_messages}"

    # Cleanup: Delete the agent.
    client.agents.delete(agent.id)
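To run the new tests locally, something like the following should work (a sketch, not part of the commit; it assumes pytest is installed and GEMINI_API_KEY is exported, and the two agent-level tests additionally rely on the client fixture, which starts a local Letta server when LETTA_SERVER_URL is unset):

# Hypothetical invocation sketch, not part of this commit.
import sys

import pytest

# Run only the direct-endpoint test; drop the "::..." selector to include the agent-level tests.
sys.exit(pytest.main(["-v", "tests/test_google_embeddings.py::test_google_embeddings_response"]))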
2 changes: 1 addition & 1 deletion tests/test_sdk_client.py
@@ -117,7 +117,7 @@ def test_shared_blocks(client: LettaSDKClient):
    )
    assert (
        "charles" in client.agents.core_memory.retrieve_block(agent_id=agent_state2.id, block_label="human").value.lower()
    ), f"Shared block update failed {client.agents.core_memory.retrieve_block(agent_id=agent_state2.id, block_label="human").value}"
    ), f"Shared block update failed {client.agents.core_memory.retrieve_block(agent_id=agent_state2.id, block_label='human').value}"

    # cleanup
    client.agents.delete(agent_state1.id)
