diff --git a/client/astra_assistants/astra_assistants_event_handler.py b/client/astra_assistants/astra_assistants_event_handler.py index b45191d..79c642b 100644 --- a/client/astra_assistants/astra_assistants_event_handler.py +++ b/client/astra_assistants/astra_assistants_event_handler.py @@ -63,6 +63,8 @@ def run_tool(self, tool_call): model: BaseModel = tool.get_model() if issubclass(model, BaseModel): self.arguments = model(**arguments) + else: + self.arguments = arguments results = tool.call(self.arguments) return results except Exception as e: diff --git a/client/astra_assistants/tools/structured_code.py b/client/astra_assistants/tools/structured_code.py index fa2366a..b724e6c 100644 --- a/client/astra_assistants/tools/structured_code.py +++ b/client/astra_assistants/tools/structured_code.py @@ -34,11 +34,22 @@ class IndentLeftEdit(BaseModel): end_line_number: Optional[int] = Field(None, description="Line number where the indent left edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)") class Config: schema_extra = { - "example": { - "thoughts": "let's move lines 55 through 57 to the left by one indentation unit", - "start_line_number": 55, - "end_line_number": 57, - } + "examples": [ + { + "thoughts": "let's move lines 55 through 57 to the left by one indentation unit", + "start_line_number": 55, + "end_line_number": 57, + }, + { + "thoughts": "let's move line 12 to the left by one indentation unit", + "start_line_number": 12, + }, + { + "thoughts": "let's move lines 100 through 101 to the left by one indentation unit", + "start_line_number": 100, + "end_line_number": 101, + }, + ] } @@ -49,11 +60,22 @@ class IndentRightEdit(BaseModel): end_line_number: Optional[int] = Field(None, description="Line number where the indent right edit ends (line numbers are inclusive, i.e. start_line_number 1 end_line_number 1 will indent 1 line, start_line_number 1 end_line_number 2 will indent two lines)") class Config: schema_extra = { - "example": { - "thoughts": "let's move lines 55 through 57 to the right by one indentation unit", - "start_line_number": 55, - "end_line_number": 57, - } + "examples": [ + { + "thoughts": "let's move lines 55 through 57 to the right by one indentation unit", + "start_line_number": 55, + "end_line_number": 57, + }, + { + "thoughts": "let's move line 12 to the right by one indentation unit", + "start_line_number": 12, + }, + { + "thoughts": "let's move lines 100 through 101 to the right by one indentation unit", + "start_line_number": 100, + "end_line_number": 101, + }, + ] } @@ -303,7 +325,7 @@ def call(self, edit: IndentLeftEdit): indentation_unit = get_indentation_unit(program.to_string(with_line_numbers=False), edit.start_line_number-1) i = edit.start_line_number-1 if edit.end_line_number is not None: - while i < edit.end_line_number: + while i < edit.end_line_number and i < len(program.lines): program.lines[i] = program.lines[i].replace(indentation_unit, "", 1) i += 1 else: diff --git a/client/astra_assistants/tools/tool_interface.py b/client/astra_assistants/tools/tool_interface.py index 69c71ff..05305f5 100644 --- a/client/astra_assistants/tools/tool_interface.py +++ b/client/astra_assistants/tools/tool_interface.py @@ -34,6 +34,8 @@ def to_function(self): param_type = param.annotation if issubclass(param_type, BaseModel): parameters = param_type.schema() + if hasattr(param_type, "Config") and hasattr(param_type.Config, "schema_extra"): + parameters.update(param_type.Config.schema_extra) else: parameters = { "type": "object", diff --git a/client/tests/tools/test_structured_code.py b/client/tests/tools/test_structured_code.py index 92e96bf..e929e7a 100644 --- a/client/tests/tools/test_structured_code.py +++ b/client/tests/tools/test_structured_code.py @@ -58,6 +58,7 @@ def test_structured_code_raw(patched_openai_client): program_id = programs[0]['program_id'] program = programs[0]['output'] patched_openai_client.beta.threads.messages.create(thread.id, content=f"nice, now add trigonometric functions to program_id {program_id}: \n{program.to_string()}" , role="user") + code_editor.set_program_id(program_id) with patched_openai_client.beta.threads.runs.create_and_stream( thread_id=thread.id, assistant_id=assistant.id, @@ -90,6 +91,7 @@ async def test_structured_code_with_manager(patched_openai_client): tool=code_generator ) content = f"nice, now add trigonometric functions to program_id {result['program_id']}: \n{result['output'].to_string()}" + code_editor.set_program_id(result['program_id']) result = await assistant_manager.run_thread( content=content, tool=code_editor @@ -131,7 +133,8 @@ def factorial(n): chunks: ToolOutput = assistant_manager.stream_thread( content="Rewrite to use memoization.", - tool_choice=code_rewriter + #tool_choice=code_rewriter + tool_choice="auto" ) text = "" @@ -157,6 +160,8 @@ def factorial(n): program_id = add_program_to_cache(program, programs) print(program_id) + + def test_structured_rewrite_and_edit_with_manager(patched_openai_client): programs: List[Dict[str, StructuredProgram]] = [] program_content = """ @@ -186,7 +191,7 @@ def factorial(n): assistant_manager = AssistantManager( instructions="use the structured code tool to generate code to help the user.", tools=tools, - model="gpt-4o", + model="openai/gpt-4o-2024-08-06", ) #code_indent_left.set_program_id(program_id) @@ -217,6 +222,6 @@ def factorial(n): tool_choice=code_rewriter ) - program_id = add_chunks_to_cache(chunks, programs) + program_id = add_chunks_to_cache(chunks, programs)['program_id'] assert len(programs) == 3 print(program_id) diff --git a/impl/main.py b/impl/main.py index b2dbf8c..fadbaa4 100644 --- a/impl/main.py +++ b/impl/main.py @@ -8,6 +8,7 @@ import httpx import openai from fastapi import FastAPI, Request, HTTPException +from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from prometheus_client import Counter, Summary, Histogram from prometheus_fastapi_instrumentator import Instrumentator @@ -24,6 +25,7 @@ from loguru import logger + cass_logger = logging.getLogger('cassandra') cass_logger.setLevel(logging.WARN) @@ -112,7 +114,7 @@ async def dispatch(self, request: Request, call_next): response = await call_next(request) return response except Exception as e: - logger.error(f"Error: {e} {request.json()}") + logger.error(f"Error: {e}") print(e) raise e @@ -231,7 +233,7 @@ async def shutdown_event(): @app.exception_handler(Exception) async def generic_exception_handler(request: Request, exc: Exception): # Log the error - logger.error(f"Unexpected error: {exc} for request url {request.url} request method {request.method} request path params {request.path_params} request query params {request.query_params} request body {await request.body()} base_url {request.base_url}") + logger.error(f"Unexpected error: {exc} for request url {request.url} request method {request.method} request path params {request.path_params} request query params {request.query_params} base_url {request.base_url}") if isinstance(exc, HTTPException): raise exec @@ -241,6 +243,17 @@ async def generic_exception_handler(request: Request, exc: Exception): ) +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(request: Request, exc: RequestValidationError): + logging.error(f"Validation error for request: {request.url}") + logging.error(f"Body: {exc.body}") + logger.error(f"Validation error: {exc} for request url {request.url} request method {request.method} request path params {request.path_params} request query params {request.query_params} base_url {request.base_url}") + logging.error(f"Errors: {exc.errors()}") + return JSONResponse( + status_code=422, + content={"detail": exc.errors()}, + ) + @app.api_route( "/{full_path:path}", methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH", "TRACE"], diff --git a/impl/model_v2/create_run_request.py b/impl/model_v2/create_run_request.py index b8183e5..d52ed94 100644 --- a/impl/model_v2/create_run_request.py +++ b/impl/model_v2/create_run_request.py @@ -1,8 +1,7 @@ -from typing import Optional +from typing import Optional, Any -from impl.model_v2.assistants_api_tool_choice_option import AssistantsApiToolChoiceOption from openapi_server_v2.models.create_run_request import CreateRunRequest as GeneratedCreateRunRequest class CreateRunRequest(GeneratedCreateRunRequest): - tool_choice: Optional[AssistantsApiToolChoiceOption] = None + tool_choice: Optional[Any] = None diff --git a/impl/model_v2/create_thread_and_run_request.py b/impl/model_v2/create_thread_and_run_request.py new file mode 100644 index 0000000..ee2863e --- /dev/null +++ b/impl/model_v2/create_thread_and_run_request.py @@ -0,0 +1,6 @@ +from typing import Optional, Any + +from openapi_server_v2.models.create_thread_and_run_request import CreateThreadAndRunRequest as GeneratedCreateThreadAndRunRequest + +class CreateThreadAndRunRequest(GeneratedCreateThreadAndRunRequest): + tool_choice: Optional[Any] = None \ No newline at end of file diff --git a/impl/routes_v2/threads_v2.py b/impl/routes_v2/threads_v2.py index 79fcebb..fcded13 100644 --- a/impl/routes_v2/threads_v2.py +++ b/impl/routes_v2/threads_v2.py @@ -5,7 +5,7 @@ import logging import re import time -from typing import Dict, Any, Union, get_origin, Type, List, Optional +from typing import Dict, Any, Union, Type, List, Optional from fastapi import APIRouter, Body, Depends, Path, HTTPException, Query @@ -24,6 +24,7 @@ from impl.routes_v2.vector_stores import read_vsf from impl.services.inference_utils import get_chat_completion, get_async_chat_completion_response from impl.utils import map_model, store_object, read_object, read_objects, generate_id +from impl.model_v2.create_thread_and_run_request import CreateThreadAndRunRequest from openapi_server_v2.models.assistants_api_response_format_option import AssistantsApiResponseFormatOption from openapi_server_v2.models.assistants_api_tool_choice_option import AssistantsApiToolChoiceOption from openapi_server_v2.models.message_delta_object_delta_content_inner import MessageDeltaObjectDeltaContentInner @@ -37,7 +38,6 @@ from openapi_server_v2.models.truncation_object import TruncationObject from openapi_server_v2.models.assistant_stream_event import AssistantStreamEvent from openapi_server_v2.models.create_message_request import CreateMessageRequest -from openapi_server_v2.models.create_thread_and_run_request import CreateThreadAndRunRequest from openapi_server_v2.models.create_thread_request import CreateThreadRequest from openapi_server_v2.models.delete_message_response import DeleteMessageResponse from openapi_server_v2.models.delete_thread_response import DeleteThreadResponse @@ -88,6 +88,7 @@ tags=["Assistants"], summary="Create a thread.", response_model_by_alias=True, + response_model=None ) async def create_thread( create_thread_request: CreateThreadRequest = Body(None, description=""), @@ -136,6 +137,7 @@ async def get_thread( tags=["Assistants"], summary="Modifies a thread.", response_model_by_alias=True, + response_model=None ) async def modify_thread( thread_id: str = Path(..., @@ -160,6 +162,7 @@ async def modify_thread( tags=["Assistants"], summary="Delete a thread.", response_model_by_alias=True, + response_model=None ) async def delete_thread( thread_id: str = Path(..., description="The ID of the thread to delete."), @@ -265,6 +268,7 @@ def messages_json_to_objects(raw_messages): tags=["Assistants"], summary="Modifies a message.", response_model_by_alias=True, + response_model=None ) async def modify_message( thread_id: str = Path(..., description="The ID of the thread to which this message belongs."), @@ -302,6 +306,7 @@ async def modify_message( tags=["Assistants"], summary="Delete a message.", response_model_by_alias=True, + response_model=None ) async def delete_message( thread_id: str = Path(..., description="The ID of the thread to delete."), @@ -757,6 +762,24 @@ async def create_run( if create_run_request.additional_instructions is not None: instructions = instructions + "\n Additional Instructions:\n" + create_run_request.additional_instructions + required_action = None + + # TODO consider initializing the run here otherwise we need a retry elsewhere + #run = await store_run( + # id=run_id, + # created_at=created_at, + # thread_id=thread_id, + # assistant_id=create_run_request.assistant_id, + # status=status, + # required_action=None, + # model=model, + # tools=tools, + # instructions=instructions, + # create_run_request=create_run_request, + # astradb=astradb, + #) + #logger.info(f"initial create run {run.id} for thread {run.thread_id} will upsert later") + toolsJson = [] if len(tools) == 0: @@ -850,12 +873,12 @@ async def create_run( if tool.type == "function": toolsJson.append(tool.dict()) - required_action = None - if len(toolsJson) > 0: litellm_kwargs[0]["tools"] = toolsJson - if create_run_request.tool_choice is not None and hasattr(create_run_request.tool_choice, "to_dict"): - litellm_kwargs[0]["tool_choice"] = create_run_request.tool_choice.to_dict() + if create_run_request.tool_choice is not None and isinstance(create_run_request.tool_choice, dict): + litellm_kwargs[0]["tool_choice"] = create_run_request.tool_choice + elif create_run_request.tool_choice is not None and isinstance(create_run_request.tool_choice, str): + litellm_kwargs[0]["tool_choice"] = create_run_request.tool_choice else: litellm_kwargs[0]["tool_choice"] = "auto" message_content = summarize_message_content(instructions, messages.data, False) @@ -865,6 +888,7 @@ async def create_run( logger.error(f"error: {e}, tenant {astradb.dbid}, model {model}, messages.data {messages.data}, create_run_request {create_run_request}") raise HTTPException(status_code=500, detail=f"Error processing message, {e}") + logger.info(f"tool_call message: {message}") tool_call_object_id = generate_id("call") run_tool_calls = [] # TODO: fix this, we can't hang off message.content because it turns out you can have both a message and a tool call. @@ -1014,6 +1038,8 @@ async def store_run(id, created_at, thread_id, assistant_id, status, required_ac tool_choice = create_run_request.tool_choice if tool_choice is None: tool_choice = AssistantsApiToolChoiceOption(actual_instance="auto") + else: + tool_choice = AssistantsApiToolChoiceOption(actual_instance=tool_choice) response_format = create_run_request.response_format if response_format is None: @@ -1038,7 +1064,6 @@ async def store_run(id, created_at, thread_id, assistant_id, status, required_ac "truncation_strategy": truncation_strategy, "tool_choice": tool_choice, "response_format": response_format, - } run = await store_object(astradb=astradb, obj=create_run_request, target_class=RunObject, table_name="runs_v2", extra_fields=extra_fields) return run @@ -1806,10 +1831,9 @@ async def make_text_delta_obj_from_chunk(chunk, i, run, message_id): tags=["Assistants"], summary="Create a thread and run it in one request.", response_model_by_alias=True, + response_model=None ) async def create_thread_and_run( - # TODO - make copy of CreateThreadAndRunRequest to handle LiteralGenericAlias issue with Tools - # also do it for create run create_thread_and_run_request: CreateThreadAndRunRequest = Body(None, description=""), astradb: CassandraClient = Depends(verify_db_client), embedding_model: str = Depends(infer_embedding_model), diff --git a/impl/utils.py b/impl/utils.py index 9f980f4..c29b305 100644 --- a/impl/utils.py +++ b/impl/utils.py @@ -4,6 +4,7 @@ import json import logging import secrets +import traceback from typing import Type, Dict, Any, List, get_origin, Annotated, get_args, Union from fastapi import HTTPException @@ -88,6 +89,7 @@ def read_object(astradb: CassandraClient, target_class: Type[BaseModel], table_n objs = read_objects(astradb, target_class, table_name, partition_keys, args) except Exception as e: logger.error(f"read_object failed {e} for table {table_name}") + logger.error(f"trace: {traceback.format_exc()}") raise HTTPException(status_code=404, detail=f"{target_class.__name__} not found.") if len(objs) == 0: # Maybe pass down name