Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tool choice fix #67

Merged
merged 11 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions client/astra_assistants/astra_assistants_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def run_tool(self, tool_call):
model: BaseModel = tool.get_model()
if issubclass(model, BaseModel):
self.arguments = model(**arguments)
else:
self.arguments = arguments
results = tool.call(self.arguments)
return results
except Exception as e:
Expand Down
44 changes: 33 additions & 11 deletions client/astra_assistants/tools/structured_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,22 @@ class IndentLeftEdit(BaseModel):
end_line_number: Optional[int] = Field(None, description="Line number where the indent left edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)")
class Config:
schema_extra = {
"example": {
"thoughts": "let's move lines 55 through 57 to the left by one indentation unit",
"start_line_number": 55,
"end_line_number": 57,
}
"examples": [
{
"thoughts": "let's move lines 55 through 57 to the left by one indentation unit",
"start_line_number": 55,
"end_line_number": 57,
},
{
"thoughts": "let's move line 12 to the left by one indentation unit",
"start_line_number": 12,
},
{
"thoughts": "let's move lines 100 through 101 to the left by one indentation unit",
"start_line_number": 100,
"end_line_number": 101,
},
]
}


Expand All @@ -49,11 +60,22 @@ class IndentRightEdit(BaseModel):
end_line_number: Optional[int] = Field(None, description="Line number where the indent right edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)")
class Config:
schema_extra = {
"example": {
"thoughts": "let's move lines 55 through 57 to the right by one indentation unit",
"start_line_number": 55,
"end_line_number": 57,
}
"examples": [
{
"thoughts": "let's move lines 55 through 57 to the right by one indentation unit",
"start_line_number": 55,
"end_line_number": 57,
},
{
"thoughts": "let's move line 12 to the right by one indentation unit",
"start_line_number": 12,
},
{
"thoughts": "let's move lines 100 through 101 to the right by one indentation unit",
"start_line_number": 100,
"end_line_number": 101,
},
]
}


Expand Down Expand Up @@ -303,7 +325,7 @@ def call(self, edit: IndentLeftEdit):
indentation_unit = get_indentation_unit(program.to_string(with_line_numbers=False), edit.start_line_number-1)
i = edit.start_line_number-1
if edit.end_line_number is not None:
while i < edit.end_line_number:
while i < edit.end_line_number and i < len(program.lines):
program.lines[i] = program.lines[i].replace(indentation_unit, "", 1)
i += 1
else:
Expand Down
2 changes: 2 additions & 0 deletions client/astra_assistants/tools/tool_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def to_function(self):
param_type = param.annotation
if issubclass(param_type, BaseModel):
parameters = param_type.schema()
if hasattr(param_type, "Config") and hasattr(param_type.Config, "schema_extra"):
parameters.update(param_type.Config.schema_extra)
else:
parameters = {
"type": "object",
Expand Down
11 changes: 8 additions & 3 deletions client/tests/tools/test_structured_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def test_structured_code_raw(patched_openai_client):
program_id = programs[0]['program_id']
program = programs[0]['output']
patched_openai_client.beta.threads.messages.create(thread.id, content=f"nice, now add trigonometric functions to program_id {program_id}: \n{program.to_string()}" , role="user")
code_editor.set_program_id(program_id)
with patched_openai_client.beta.threads.runs.create_and_stream(
thread_id=thread.id,
assistant_id=assistant.id,
Expand Down Expand Up @@ -90,6 +91,7 @@ async def test_structured_code_with_manager(patched_openai_client):
tool=code_generator
)
content = f"nice, now add trigonometric functions to program_id {result['program_id']}: \n{result['output'].to_string()}"
code_editor.set_program_id(result['program_id'])
result = await assistant_manager.run_thread(
content=content,
tool=code_editor
Expand Down Expand Up @@ -131,7 +133,8 @@ def factorial(n):

chunks: ToolOutput = assistant_manager.stream_thread(
content="Rewrite to use memoization.",
tool_choice=code_rewriter
#tool_choice=code_rewriter
tool_choice="auto"
)

text = ""
Expand All @@ -157,6 +160,8 @@ def factorial(n):
program_id = add_program_to_cache(program, programs)
print(program_id)



def test_structured_rewrite_and_edit_with_manager(patched_openai_client):
programs: List[Dict[str, StructuredProgram]] = []
program_content = """
Expand Down Expand Up @@ -186,7 +191,7 @@ def factorial(n):
assistant_manager = AssistantManager(
instructions="use the structured code tool to generate code to help the user.",
tools=tools,
model="gpt-4o",
model="openai/gpt-4o-2024-08-06",
)

#code_indent_left.set_program_id(program_id)
Expand Down Expand Up @@ -217,6 +222,6 @@ def factorial(n):
tool_choice=code_rewriter
)

program_id = add_chunks_to_cache(chunks, programs)
program_id = add_chunks_to_cache(chunks, programs)['program_id']
assert len(programs) == 3
print(program_id)
17 changes: 15 additions & 2 deletions impl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import httpx
import openai
from fastapi import FastAPI, Request, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from prometheus_client import Counter, Summary, Histogram
from prometheus_fastapi_instrumentator import Instrumentator
Expand All @@ -24,6 +25,7 @@

from loguru import logger


cass_logger = logging.getLogger('cassandra')
cass_logger.setLevel(logging.WARN)

Expand Down Expand Up @@ -112,7 +114,7 @@ async def dispatch(self, request: Request, call_next):
response = await call_next(request)
return response
except Exception as e:
logger.error(f"Error: {e} {request.json()}")
logger.error(f"Error: {e}")
print(e)
raise e

Expand Down Expand Up @@ -231,7 +233,7 @@ async def shutdown_event():
@app.exception_handler(Exception)
async def generic_exception_handler(request: Request, exc: Exception):
# Log the error
logger.error(f"Unexpected error: {exc} for request url {request.url} request method {request.method} request path params {request.path_params} request query params {request.query_params} request body {await request.body()} base_url {request.base_url}")
logger.error(f"Unexpected error: {exc} for request url {request.url} request method {request.method} request path params {request.path_params} request query params {request.query_params} base_url {request.base_url}")

if isinstance(exc, HTTPException):
raise exec
Expand All @@ -241,6 +243,17 @@ async def generic_exception_handler(request: Request, exc: Exception):
)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
logging.error(f"Validation error for request: {request.url}")
logging.error(f"Body: {exc.body}")
logger.error(f"Validation error: {exc} for request url {request.url} request method {request.method} request path params {request.path_params} request query params {request.query_params} base_url {request.base_url}")
logging.error(f"Errors: {exc.errors()}")
return JSONResponse(
status_code=422,
content={"detail": exc.errors()},
)

@app.api_route(
"/{full_path:path}",
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH", "TRACE"],
Expand Down
5 changes: 2 additions & 3 deletions impl/model_v2/create_run_request.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Optional
from typing import Optional, Any

from impl.model_v2.assistants_api_tool_choice_option import AssistantsApiToolChoiceOption
from openapi_server_v2.models.create_run_request import CreateRunRequest as GeneratedCreateRunRequest


class CreateRunRequest(GeneratedCreateRunRequest):
tool_choice: Optional[AssistantsApiToolChoiceOption] = None
tool_choice: Optional[Any] = None
6 changes: 6 additions & 0 deletions impl/model_v2/create_thread_and_run_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import Optional, Any

from openapi_server_v2.models.create_thread_and_run_request import CreateThreadAndRunRequest as GeneratedCreateThreadAndRunRequest

class CreateThreadAndRunRequest(GeneratedCreateThreadAndRunRequest):
tool_choice: Optional[Any] = None
42 changes: 33 additions & 9 deletions impl/routes_v2/threads_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import re
import time
from typing import Dict, Any, Union, get_origin, Type, List, Optional
from typing import Dict, Any, Union, Type, List, Optional


from fastapi import APIRouter, Body, Depends, Path, HTTPException, Query
Expand All @@ -24,6 +24,7 @@
from impl.routes_v2.vector_stores import read_vsf
from impl.services.inference_utils import get_chat_completion, get_async_chat_completion_response
from impl.utils import map_model, store_object, read_object, read_objects, generate_id
from impl.model_v2.create_thread_and_run_request import CreateThreadAndRunRequest
from openapi_server_v2.models.assistants_api_response_format_option import AssistantsApiResponseFormatOption
from openapi_server_v2.models.assistants_api_tool_choice_option import AssistantsApiToolChoiceOption
from openapi_server_v2.models.message_delta_object_delta_content_inner import MessageDeltaObjectDeltaContentInner
Expand All @@ -37,7 +38,6 @@
from openapi_server_v2.models.truncation_object import TruncationObject
from openapi_server_v2.models.assistant_stream_event import AssistantStreamEvent
from openapi_server_v2.models.create_message_request import CreateMessageRequest
from openapi_server_v2.models.create_thread_and_run_request import CreateThreadAndRunRequest
from openapi_server_v2.models.create_thread_request import CreateThreadRequest
from openapi_server_v2.models.delete_message_response import DeleteMessageResponse
from openapi_server_v2.models.delete_thread_response import DeleteThreadResponse
Expand Down Expand Up @@ -88,6 +88,7 @@
tags=["Assistants"],
summary="Create a thread.",
response_model_by_alias=True,
response_model=None
)
async def create_thread(
create_thread_request: CreateThreadRequest = Body(None, description=""),
Expand Down Expand Up @@ -136,6 +137,7 @@ async def get_thread(
tags=["Assistants"],
summary="Modifies a thread.",
response_model_by_alias=True,
response_model=None
)
async def modify_thread(
thread_id: str = Path(...,
Expand All @@ -160,6 +162,7 @@ async def modify_thread(
tags=["Assistants"],
summary="Delete a thread.",
response_model_by_alias=True,
response_model=None
)
async def delete_thread(
thread_id: str = Path(..., description="The ID of the thread to delete."),
Expand Down Expand Up @@ -265,6 +268,7 @@ def messages_json_to_objects(raw_messages):
tags=["Assistants"],
summary="Modifies a message.",
response_model_by_alias=True,
response_model=None
)
async def modify_message(
thread_id: str = Path(..., description="The ID of the thread to which this message belongs."),
Expand Down Expand Up @@ -302,6 +306,7 @@ async def modify_message(
tags=["Assistants"],
summary="Delete a message.",
response_model_by_alias=True,
response_model=None
)
async def delete_message(
thread_id: str = Path(..., description="The ID of the thread to delete."),
Expand Down Expand Up @@ -757,6 +762,24 @@ async def create_run(
if create_run_request.additional_instructions is not None:
instructions = instructions + "\n Additional Instructions:\n" + create_run_request.additional_instructions

required_action = None

# TODO consider initializing the run here otherwise we need a retry elsewhere
#run = await store_run(
# id=run_id,
# created_at=created_at,
# thread_id=thread_id,
# assistant_id=create_run_request.assistant_id,
# status=status,
# required_action=None,
# model=model,
# tools=tools,
# instructions=instructions,
# create_run_request=create_run_request,
# astradb=astradb,
#)
#logger.info(f"initial create run {run.id} for thread {run.thread_id} will upsert later")

toolsJson = []
if len(tools) == 0:

Expand Down Expand Up @@ -850,12 +873,12 @@ async def create_run(
if tool.type == "function":
toolsJson.append(tool.dict())

required_action = None

if len(toolsJson) > 0:
litellm_kwargs[0]["tools"] = toolsJson
if create_run_request.tool_choice is not None and hasattr(create_run_request.tool_choice, "to_dict"):
litellm_kwargs[0]["tool_choice"] = create_run_request.tool_choice.to_dict()
if create_run_request.tool_choice is not None and isinstance(create_run_request.tool_choice, dict):
litellm_kwargs[0]["tool_choice"] = create_run_request.tool_choice
elif create_run_request.tool_choice is not None and isinstance(create_run_request.tool_choice, str):
litellm_kwargs[0]["tool_choice"] = create_run_request.tool_choice
else:
litellm_kwargs[0]["tool_choice"] = "auto"
message_content = summarize_message_content(instructions, messages.data, False)
Expand All @@ -865,6 +888,7 @@ async def create_run(
logger.error(f"error: {e}, tenant {astradb.dbid}, model {model}, messages.data {messages.data}, create_run_request {create_run_request}")
raise HTTPException(status_code=500, detail=f"Error processing message, {e}")

logger.info(f"tool_call message: {message}")
tool_call_object_id = generate_id("call")
run_tool_calls = []
# TODO: fix this, we can't hang off message.content because it turns out you can have both a message and a tool call.
Expand Down Expand Up @@ -1014,6 +1038,8 @@ async def store_run(id, created_at, thread_id, assistant_id, status, required_ac
tool_choice = create_run_request.tool_choice
if tool_choice is None:
tool_choice = AssistantsApiToolChoiceOption(actual_instance="auto")
else:
tool_choice = AssistantsApiToolChoiceOption(actual_instance=tool_choice)

response_format = create_run_request.response_format
if response_format is None:
Expand All @@ -1038,7 +1064,6 @@ async def store_run(id, created_at, thread_id, assistant_id, status, required_ac
"truncation_strategy": truncation_strategy,
"tool_choice": tool_choice,
"response_format": response_format,

}
run = await store_object(astradb=astradb, obj=create_run_request, target_class=RunObject, table_name="runs_v2", extra_fields=extra_fields)
return run
Expand Down Expand Up @@ -1806,10 +1831,9 @@ async def make_text_delta_obj_from_chunk(chunk, i, run, message_id):
tags=["Assistants"],
summary="Create a thread and run it in one request.",
response_model_by_alias=True,
response_model=None
)
async def create_thread_and_run(
# TODO - make copy of CreateThreadAndRunRequest to handle LiteralGenericAlias issue with Tools
# also do it for create run
create_thread_and_run_request: CreateThreadAndRunRequest = Body(None, description=""),
astradb: CassandraClient = Depends(verify_db_client),
embedding_model: str = Depends(infer_embedding_model),
Expand Down
2 changes: 2 additions & 0 deletions impl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import secrets
import traceback
from typing import Type, Dict, Any, List, get_origin, Annotated, get_args, Union

from fastapi import HTTPException
Expand Down Expand Up @@ -88,6 +89,7 @@ def read_object(astradb: CassandraClient, target_class: Type[BaseModel], table_n
objs = read_objects(astradb, target_class, table_name, partition_keys, args)
except Exception as e:
logger.error(f"read_object failed {e} for table {table_name}")
logger.error(f"trace: {traceback.format_exc()}")
raise HTTPException(status_code=404, detail=f"{target_class.__name__} not found.")
if len(objs) == 0:
# Maybe pass down name
Expand Down
Loading