Skip to content

Commit

Permalink
Merge pull request #17 from datastax/modify-messages
Browse files Browse the repository at this point in the history
Modify messages
  • Loading branch information
phact authored Apr 1, 2024
2 parents 8923508 + c35a36b commit 0216897
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 29 deletions.
18 changes: 18 additions & 0 deletions impl/astra_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,7 @@ def get_assistant(self, id):
logger.info(f"parsed assistant from row: {assistant}")
return assistant


def delete_by_pk(self, key, value, table):
query_string = f"""
DELETE FROM {CASSANDRA_KEYSPACE}.{table} WHERE {key} = ?;
Expand All @@ -854,6 +855,23 @@ def delete_by_pk(self, key, value, table):
self.session.execute(bound)
return True


def delete_by_pks(self, keys, values, table):
    """Delete rows from ``table`` matching every ``key = value`` pair.

    Args:
        keys: column names forming the (composite) primary key.
        values: bound values, in the same order as ``keys``.
        table: table name inside ``CASSANDRA_KEYSPACE``.

    Returns:
        True once the statement has executed.
    """
    # Build "k1 = ? AND k2 = ?" with one placeholder per key column.
    # Values are bound via a prepared statement; table/column names are
    # interpolated directly — callers must pass trusted identifiers only.
    predicate = " AND ".join(f"{key} = ?" for key in keys)
    query_string = f"DELETE FROM {CASSANDRA_KEYSPACE}.{table} WHERE {predicate}"

    statement = self.session.prepare(query_string)
    statement.consistency_level = ConsistencyLevel.QUORUM
    bound = statement.bind(values)
    self.session.execute(bound)
    return True


def update_run_status(self, id, thread_id, status):
query_string = f"""
UPDATE {CASSANDRA_KEYSPACE}.runs SET status = ? WHERE id = ? and thread_id = ?;
Expand Down
10 changes: 10 additions & 0 deletions impl/model/modify_message_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Optional, Annotated

from pydantic import Field, StrictStr

from openapi_server.models.create_message_request import CreateMessageRequest


class ModifyMessageRequest(CreateMessageRequest):
    """Request body for modifying an existing thread message.

    Relaxes ``CreateMessageRequest`` by making ``content`` and ``role``
    optional, so a modify call may update only a subset of fields
    (e.g. metadata alone).
    """
    # Optional replacement text; when provided, must be 1..32768 characters.
    content: Optional[str] = Field(default=None, min_length=1, strict=True, max_length=32768, description="The content of the message.")
    # Optional role; per the field description only `user` is currently supported.
    role: Optional[StrictStr] = Field(default=None, description="The role of the entity that is creating the message. Currently only `user` is supported.")
50 changes: 37 additions & 13 deletions impl/routes/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from impl.model.list_messages_stream_response import ListMessagesStreamResponse
from impl.model.message_object import MessageObject
from impl.model.message_stream_response_object import MessageStreamResponseObject
from impl.model.modify_message_request import ModifyMessageRequest
from impl.model.open_ai_file import OpenAIFile
from impl.model.run_object import RunObject
from impl.model.submit_tool_outputs_run_request import SubmitToolOutputsRunRequest
Expand All @@ -43,6 +44,7 @@
from impl.services.inference_utils import get_chat_completion, get_async_chat_completion_response
from openapi_server.models.create_message_request import CreateMessageRequest
from openapi_server.models.create_thread_request import CreateThreadRequest
from openapi_server.models.delete_message_response import DeleteMessageResponse
from openapi_server.models.delete_thread_response import DeleteThreadResponse
from openapi_server.models.list_runs_response import ListRunsResponse
from openapi_server.models.message_content_delta_object import MessageContentDeltaObject
Expand All @@ -51,7 +53,6 @@
from openapi_server.models.message_content_text_object_text import (
MessageContentTextObjectText,
)
from openapi_server.models.modify_message_request import ModifyMessageRequest
from openapi_server.models.modify_thread_request import ModifyThreadRequest
from openapi_server.models.run_object_required_action import RunObjectRequiredAction
from openapi_server.models.run_object_required_action_submit_tool_outputs import RunObjectRequiredActionSubmitToolOutputs
Expand Down Expand Up @@ -184,14 +185,35 @@ async def modify_message(
object="thread.message",
created_at=None,
thread_id=thread_id,
role=None,
content=None,
role=modify_message_request.role,
content=[modify_message_request.content],
assistant_id=None,
run_id=None,
file_ids=None,
file_ids=modify_message_request.file_ids,
metadata=modify_message_request.metadata,
)

@router.delete(
    "/threads/{thread_id}/messages/{message_id}",
    responses={
        200: {"model": DeleteMessageResponse, "description": "OK"},
    },
    tags=["Assistants"],
    summary="Delete a message.",
    response_model_by_alias=True,
)
async def delete_message(
    # Fixed: previous description wrongly said the *thread* was deleted.
    thread_id: str = Path(..., description="The ID of the thread the message belongs to."),
    message_id: str = Path(..., description="The ID of the message to delete."),
    astradb: CassandraClient = Depends(verify_db_client),
) -> DeleteMessageResponse:
    """Delete a single message from a thread.

    Removes the row addressed by the composite key (id, thread_id) from the
    ``messages`` table and returns an OpenAI-style deletion acknowledgement.
    """
    # Both id and thread_id are required to address one row in `messages`.
    astradb.delete_by_pks(table="messages", keys=["id", "thread_id"], values=[message_id, thread_id])
    return DeleteMessageResponse(
        id=message_id,
        object="thread.message.deleted",
        deleted=True
    )


def extractFunctionArguments(content):
pattern = r"\`\`\`.*({.*})\n\`\`\`"
Expand Down Expand Up @@ -761,9 +783,10 @@ async def process_rag(

if 'gemini' in model:
async for part in response:
text += part.choices[0].delta.content
start_time = await maybe_checkpoint(assistant_id, astradb, created_at, file_ids, frequency_in_seconds, message_id,
run_id, start_time, text, thread_id)
if part.choices[0].delta.content is not None:
text += part.choices[0].delta.content
start_time = await maybe_checkpoint(assistant_id, astradb, created_at, file_ids, frequency_in_seconds, message_id,
run_id, start_time, text, thread_id)
else:
done = False
while not done:
Expand Down Expand Up @@ -1391,12 +1414,13 @@ async def message_delta_streamer(message_id, created_at, response, run, astradb)

if 'gemini' in run.model:
async for part in response:
delta = part.choices[0].delta.content
event_json = await make_text_delta_event_from_chunk(delta, i, run, message_id, )
i += 1
yield f"data: {event_json}\n\n"
text += delta
start_time = await maybe_checkpoint(run.assistant_id, astradb, created_at, run.file_ids, frequency_in_seconds, message_id,
if part.choices[0].delta.content is not None:
delta = part.choices[0].delta.content
event_json = await make_text_delta_event_from_chunk(delta, i, run, message_id, )
i += 1
yield f"data: {event_json}\n\n"
text += delta
start_time = await maybe_checkpoint(run.assistant_id, astradb, created_at, run.file_ids, frequency_in_seconds, message_id,
run.id, start_time, text, run.thread_id)
else:
done = False
Expand Down
22 changes: 11 additions & 11 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ docx2txt = "^0.8"
pypdf2 = "^3.0.1"
python-pptx = "^0.6.23"
gunicorn = "^21.2.0"
litellm = "1.33.4"
litellm = "1.34.18"
boto3 = "^1.29.6"
prometheus-fastapi-instrumentator = "^6.1.0"
google-cloud-aiplatform = "^1.38.0"
google-generativeai = "^0.3.1"
streaming-assistants = "^0.15.0rc3"
streaming-assistants = "^0.15.3"
annotated-types = "^0.6.0"
pydantic-core = "^2.16.3"
pydantic = "^2.6.4"
Expand Down
42 changes: 41 additions & 1 deletion tests/http/test_assistants_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from impl.model.create_assistant_request import CreateAssistantRequest
from impl.model.create_run_request import CreateRunRequest
from impl.model.message_object import MessageObject
from impl.model.modify_message_request import ModifyMessageRequest
from openapi_server.models.assistant_file_object import AssistantFileObject # noqa: F401
from openapi_server.models.assistant_object import AssistantObject # noqa: F401
from openapi_server.models.create_assistant_file_request import CreateAssistantFileRequest # noqa: F401
Expand All @@ -22,7 +23,6 @@
from openapi_server.models.list_run_steps_response import ListRunStepsResponse # noqa: F401
from openapi_server.models.list_runs_response import ListRunsResponse # noqa: F401
from openapi_server.models.message_file_object import MessageFileObject # noqa: F401
from openapi_server.models.modify_message_request import ModifyMessageRequest # noqa: F401
from openapi_server.models.modify_run_request import ModifyRunRequest # noqa: F401
from openapi_server.models.modify_thread_request import ModifyThreadRequest # noqa: F401
from openapi_server.models.run_object import RunObject # noqa: F401
Expand Down Expand Up @@ -469,6 +469,46 @@ def test_modify_message(client: TestClient):
# uncomment below to assert the status code of the HTTP response
assert response.status_code == 200

def test_modify_message_content(client: TestClient):
    """Test case for modify_message: updating a message's content and metadata."""

    # Create a message first so there is something to modify.
    message = test_create_message(client)
    modify_message_request = {"metadata": {}, "content": "puppies"}

    headers = get_headers(MODEL)
    response = client.request(
        "POST",
        f"/threads/{message.thread_id}/messages/{message.id}",
        headers=headers,
        json=modify_message_request,
    )

    logger.info(response)
    assert response.status_code == 200

def test_delete_message(client: TestClient):
    """Test case for delete_message: deleting a message from a thread."""

    # Create a message first so there is something to delete.
    message = test_create_message(client)

    headers = get_headers(MODEL)
    response = client.request(
        "DELETE",
        f"/threads/{message.thread_id}/messages/{message.id}",
        headers=headers,
    )

    logger.info(response)
    assert response.status_code == 200


def test_modify_thread(client: TestClient):
"""Test case for modify_thread
Expand Down
4 changes: 2 additions & 2 deletions tests/streaming-assistants/test_streaming_run_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ def test_run_gpt3_5(patched_openai_client):
)
run_with_assistant(gpt3_assistant, patched_openai_client)

@pytest.mark.skip(reason="replace with command-r because context window")
#@pytest.mark.skip(reason="replace with command-r because context window")
def test_run_cohere(patched_openai_client):
model = "cohere/command"
model = "command-r"
name = f"{model} Math Tutor"

cohere_assistant = patched_openai_client.beta.assistants.create(
Expand Down

0 comments on commit 0216897

Please sign in to comment.