diff --git a/app/common/__init__.py b/app/common/__init__.py index dc1e08c8..b2186c1d 100644 --- a/app/common/__init__.py +++ b/app/common/__init__.py @@ -1,5 +1,5 @@ -from ..common.singleton import Singleton -from ..common.message_converters import ( +from app.common.singleton import Singleton +from app.common.message_converters import ( convert_iris_message_to_langchain_message, convert_langchain_message_to_iris_message, ) diff --git a/app/common/message_converters.py b/app/common/message_converters.py index fbcb17c7..3059a57b 100644 --- a/app/common/message_converters.py +++ b/app/common/message_converters.py @@ -1,9 +1,15 @@ +from datetime import datetime + from langchain_core.messages import BaseMessage -from ..domain.iris_message import IrisMessage, IrisMessageRole + +from app.domain.data.text_message_content_dto import TextMessageContentDTO +from app.domain.pyris_message import PyrisMessage, IrisMessageRole -def convert_iris_message_to_langchain_message(iris_message: IrisMessage) -> BaseMessage: - match iris_message.role: +def convert_iris_message_to_langchain_message( + iris_message: PyrisMessage, +) -> BaseMessage: + match iris_message.sender: case IrisMessageRole.USER: role = "human" case IrisMessageRole.ASSISTANT: @@ -11,11 +17,19 @@ def convert_iris_message_to_langchain_message(iris_message: IrisMessage) -> Base case IrisMessageRole.SYSTEM: role = "system" case _: - raise ValueError(f"Unknown message role: {iris_message.role}") - return BaseMessage(content=iris_message.text, type=role) + raise ValueError(f"Unknown message role: {iris_message.sender}") + if len(iris_message.contents) == 0: + raise ValueError("IrisMessage contents must not be empty") + message = iris_message.contents[0] + # Check if the message is of type TextMessageContentDTO + if not isinstance(message, TextMessageContentDTO): + raise ValueError("Message must be of type TextMessageContentDTO") + return BaseMessage(content=message.text_content, type=role) -def 
convert_langchain_message_to_iris_message(base_message: BaseMessage) -> IrisMessage: +def convert_langchain_message_to_iris_message( + base_message: BaseMessage, +) -> PyrisMessage: match base_message.type: case "human": role = IrisMessageRole.USER @@ -25,4 +39,33 @@ def convert_langchain_message_to_iris_message(base_message: BaseMessage) -> Iris role = IrisMessageRole.SYSTEM case _: raise ValueError(f"Unknown message type: {base_message.type}") - return IrisMessage(text=base_message.content, role=role) + contents = [TextMessageContentDTO(textContent=base_message.content)] + return PyrisMessage( + contents=contents, + sender=role, + sent_at=datetime.now(), + ) + + +def map_role_to_str(role: IrisMessageRole) -> str: + match role: + case IrisMessageRole.USER: + return "user" + case IrisMessageRole.ASSISTANT: + return "assistant" + case IrisMessageRole.SYSTEM: + return "system" + case _: + raise ValueError(f"Unknown message role: {role}") + + +def map_str_to_role(role: str) -> IrisMessageRole: + match role: + case "user": + return IrisMessageRole.USER + case "assistant": + return IrisMessageRole.ASSISTANT + case "system": + return IrisMessageRole.SYSTEM + case _: + raise ValueError(f"Unknown message role: {role}") diff --git a/app/domain/__init__.py b/app/domain/__init__.py index 2b67a350..149df609 100644 --- a/app/domain/__init__.py +++ b/app/domain/__init__.py @@ -1,7 +1,7 @@ from .error_response_dto import IrisErrorResponseDTO from .pipeline_execution_dto import PipelineExecutionDTO from .pipeline_execution_settings_dto import PipelineExecutionSettingsDTO -from ..domain.tutor_chat.tutor_chat_pipeline_execution_dto import ( +from app.domain.tutor_chat.tutor_chat_pipeline_execution_dto import ( TutorChatPipelineExecutionDTO, ) -from .iris_message import IrisMessage, IrisMessageRole +from .pyris_message import PyrisMessage, IrisMessageRole diff --git a/app/domain/data/feedback_dto.py b/app/domain/data/feedback_dto.py index 2615ef5e..d285a132 100644 ---
a/app/domain/data/feedback_dto.py +++ b/app/domain/data/feedback_dto.py @@ -5,7 +5,7 @@ class FeedbackDTO(BaseModel): text: Optional[str] = None - test_case_name: str = Field(alias="testCaseName") + test_case_name: Optional[str] = Field(alias="testCaseName", default=None) credits: float def __str__(self): diff --git a/app/domain/data/message_dto.py b/app/domain/data/message_dto.py deleted file mode 100644 index 8ed76917..00000000 --- a/app/domain/data/message_dto.py +++ /dev/null @@ -1,51 +0,0 @@ -from datetime import datetime -from enum import Enum -from typing import List, Literal - -from langchain_core.messages import HumanMessage, AIMessage - -from .message_content_dto import MessageContentDTO -from ...domain.iris_message import IrisMessage - -from pydantic import BaseModel, Field - - -class IrisMessageSender(str, Enum): - USER = "USER" - LLM = "LLM" - - -class MessageDTO(BaseModel): - sent_at: datetime | None = Field(alias="sentAt", default=None) - sender: Literal[IrisMessageSender.USER, IrisMessageSender.LLM] - contents: List[MessageContentDTO] = [] - - def __str__(self): - match self.sender: - case IrisMessageSender.USER: - sender = "user" - case IrisMessageSender.LLM: - sender = "assistant" - case _: - raise ValueError(f"Unknown message sender: {self.sender}") - return f"{sender}: {self.contents[0].text_content}" - - def convert_to_iris_message(self): - match self.sender: - case IrisMessageSender.USER: - sender = "user" - case IrisMessageSender.LLM: - sender = "assistant" - case _: - raise ValueError(f"Unknown message sender: {self.sender}") - - return IrisMessage(text=self.contents[0].text_content, role=sender) - - def convert_to_langchain_message(self): - match self.sender: - case IrisMessageSender.USER: - return HumanMessage(content=self.contents[0].text_content) - case IrisMessageSender.LLM: - return AIMessage(content=self.contents[0].text_content) - case _: - raise ValueError(f"Unknown message sender: {self.sender}") diff --git 
a/app/domain/data/programming_exercise_dto.py b/app/domain/data/programming_exercise_dto.py index 3f30c8d2..75d7ffb4 100644 --- a/app/domain/data/programming_exercise_dto.py +++ b/app/domain/data/programming_exercise_dto.py @@ -21,7 +21,7 @@ class ProgrammingLanguage(str, Enum): class ProgrammingExerciseDTO(BaseModel): id: int name: str - programming_language: ProgrammingLanguage = Field(alias="programmingLanguage") + programming_language: Optional[str] = Field(alias="programmingLanguage") template_repository: Dict[str, str] = Field(alias="templateRepository") solution_repository: Dict[str, str] = Field(alias="solutionRepository") test_repository: Dict[str, str] = Field(alias="testRepository") diff --git a/app/domain/iris_message.py b/app/domain/iris_message.py deleted file mode 100644 index 94969c96..00000000 --- a/app/domain/iris_message.py +++ /dev/null @@ -1,17 +0,0 @@ -from enum import Enum - -from pydantic import BaseModel - - -class IrisMessageRole(str, Enum): - USER = "user" - ASSISTANT = "assistant" - SYSTEM = "system" - - -class IrisMessage(BaseModel): - text: str = "" - role: IrisMessageRole - - def __str__(self): - return f"{self.role.lower()}: {self.text}" diff --git a/app/domain/pipeline_execution_dto.py b/app/domain/pipeline_execution_dto.py index 3f384b05..1b8ced83 100644 --- a/app/domain/pipeline_execution_dto.py +++ b/app/domain/pipeline_execution_dto.py @@ -2,8 +2,8 @@ from pydantic import BaseModel, Field -from ..domain.pipeline_execution_settings_dto import PipelineExecutionSettingsDTO -from ..domain.status.stage_dto import StageDTO +from app.domain.pipeline_execution_settings_dto import PipelineExecutionSettingsDTO +from app.domain.status.stage_dto import StageDTO class PipelineExecutionDTO(BaseModel): diff --git a/app/domain/pyris_message.py b/app/domain/pyris_message.py new file mode 100644 index 00000000..5f44cd9d --- /dev/null +++ b/app/domain/pyris_message.py @@ -0,0 +1,22 @@ +from datetime import datetime +from enum import Enum +from 
typing import List + +from pydantic import BaseModel, Field + +from app.domain.data.message_content_dto import MessageContentDTO + + +class IrisMessageRole(str, Enum): + USER = "USER" + ASSISTANT = "LLM" + SYSTEM = "SYSTEM" + + +class PyrisMessage(BaseModel): + sent_at: datetime | None = Field(alias="sentAt", default=None) + sender: IrisMessageRole + contents: List[MessageContentDTO] = [] + + def __str__(self): + return f"{self.sender.lower()}: {self.contents}" diff --git a/app/domain/tutor_chat/tutor_chat_pipeline_execution_dto.py b/app/domain/tutor_chat/tutor_chat_pipeline_execution_dto.py index 8c1db7c9..4221d124 100644 --- a/app/domain/tutor_chat/tutor_chat_pipeline_execution_dto.py +++ b/app/domain/tutor_chat/tutor_chat_pipeline_execution_dto.py @@ -2,9 +2,9 @@ from pydantic import Field +from ...domain.pyris_message import PyrisMessage from ...domain import PipelineExecutionDTO from ...domain.data.course_dto import CourseDTO -from ...domain.data.message_dto import MessageDTO from ...domain.data.programming_exercise_dto import ProgrammingExerciseDTO from ...domain.data.user_dto import UserDTO from ...domain.data.submission_dto import SubmissionDTO @@ -14,5 +14,5 @@ class TutorChatPipelineExecutionDTO(PipelineExecutionDTO): submission: Optional[SubmissionDTO] = None exercise: ProgrammingExerciseDTO course: CourseDTO - chat_history: List[MessageDTO] = Field(alias="chatHistory", default=[]) + chat_history: List[PyrisMessage] = Field(alias="chatHistory", default=[]) user: Optional[UserDTO] = None diff --git a/app/llm/external/model.py b/app/llm/external/model.py index 04520e81..4d42745b 100644 --- a/app/llm/external/model.py +++ b/app/llm/external/model.py @@ -1,7 +1,7 @@ from abc import ABCMeta, abstractmethod from pydantic import BaseModel -from ...domain import IrisMessage +from ...domain import PyrisMessage from ...llm import CompletionArguments from ...llm.capability import CapabilityList @@ -39,8 +39,8 @@ def __subclasshook__(cls, subclass) -> bool: 
@abstractmethod def chat( - self, messages: list[IrisMessage], arguments: CompletionArguments - ) -> IrisMessage: + self, messages: list[PyrisMessage], arguments: CompletionArguments + ) -> PyrisMessage: """Create a completion from the chat messages""" raise NotImplementedError( f"The LLM {self.__str__()} does not support chat completion" ) diff --git a/app/llm/external/ollama.py b/app/llm/external/ollama.py index 03a832a2..72dbb04e 100644 --- a/app/llm/external/ollama.py +++ b/app/llm/external/ollama.py @@ -1,20 +1,32 @@ +from datetime import datetime from typing import Literal, Any from ollama import Client, Message -from ...domain import IrisMessage, IrisMessageRole +from ...common.message_converters import map_role_to_str, map_str_to_role +from ...domain.data.text_message_content_dto import TextMessageContentDTO +from ...domain import PyrisMessage from ...llm import CompletionArguments from ...llm.external.model import ChatModel, CompletionModel, EmbeddingModel -def convert_to_ollama_messages(messages: list[IrisMessage]) -> list[Message]: return [ - Message(role=message.role.value, content=message.text) for message in messages + Message( + role=map_role_to_str(message.sender), + content=message.contents[0].text_content, + ) + for message in messages ] -def convert_to_iris_message(message: Message) -> IrisMessage: - return IrisMessage(role=IrisMessageRole(message["role"]), text=message["content"]) +def convert_to_iris_message(message: Message) -> PyrisMessage: + contents = [TextMessageContentDTO(text_content=message["content"])] + return PyrisMessage( + sender=map_str_to_role(message["role"]), + contents=contents, + sent_at=datetime.now(), + ) class OllamaModel( @@ -35,8 +47,8 @@ def complete(self, prompt: str, arguments: CompletionArguments) -> str: return response["response"] def chat( - self, 
messages: list[PyrisMessage], arguments: CompletionArguments + ) -> PyrisMessage: response = self._client.chat( model=self.model, messages=convert_to_ollama_messages(messages) ) diff --git a/app/llm/external/openai_chat.py b/app/llm/external/openai_chat.py index 9e035810..450efdd7 100644 --- a/app/llm/external/openai_chat.py +++ b/app/llm/external/openai_chat.py @@ -1,26 +1,35 @@ +from datetime import datetime from typing import Literal, Any from openai import OpenAI from openai.lib.azure import AzureOpenAI from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage -from ...domain import IrisMessage, IrisMessageRole +from ...common.message_converters import map_role_to_str, map_str_to_role +from app.domain.data.text_message_content_dto import TextMessageContentDTO +from ...domain import PyrisMessage from ...llm import CompletionArguments from ...llm.external.model import ChatModel def convert_to_open_ai_messages( - messages: list[IrisMessage], + messages: list[PyrisMessage], ) -> list[ChatCompletionMessageParam]: return [ - {"role": message.role.value, "content": message.text} for message in messages + { + "role": map_role_to_str(message.sender), + "content": message.contents[0].text_content, + } + for message in messages ] -def convert_to_iris_message(message: ChatCompletionMessage) -> IrisMessage: - # Get IrisMessageRole from the string message.role - message_role = IrisMessageRole(message.role) - return IrisMessage(role=message_role, text=message.content) +def convert_to_iris_message(message: ChatCompletionMessage) -> PyrisMessage: + return PyrisMessage( + sender=map_str_to_role(message.role), + contents=[TextMessageContentDTO(textContent=message.content)], + sent_at=datetime.now(), + ) class OpenAIChatModel(ChatModel): @@ -29,8 +38,8 @@ class OpenAIChatModel(ChatModel): _client: OpenAI def chat( - self, messages: list[IrisMessage], arguments: CompletionArguments - ) -> IrisMessage: + self, messages: list[PyrisMessage], arguments: 
CompletionArguments + ) -> PyrisMessage: response = self._client.chat.completions.create( model=self.model, messages=convert_to_open_ai_messages(messages), diff --git a/app/llm/request_handler/basic_request_handler.py b/app/llm/request_handler/basic_request_handler.py index de8c87ea..dc07d545 100644 --- a/app/llm/request_handler/basic_request_handler.py +++ b/app/llm/request_handler/basic_request_handler.py @@ -1,4 +1,4 @@ -from app.domain import IrisMessage +from app.domain import PyrisMessage from app.llm.request_handler import RequestHandler from app.llm.completion_arguments import CompletionArguments from app.llm.llm_manager import LlmManager @@ -17,8 +17,8 @@ def complete(self, prompt: str, arguments: CompletionArguments) -> str: return llm.complete(prompt, arguments) def chat( - self, messages: list[IrisMessage], arguments: CompletionArguments - ) -> IrisMessage: + self, messages: list[PyrisMessage], arguments: CompletionArguments + ) -> PyrisMessage: llm = self.llm_manager.get_llm_by_id(self.model_id) return llm.chat(messages, arguments) diff --git a/app/llm/request_handler/capability_request_handler.py b/app/llm/request_handler/capability_request_handler.py index dc9d1f4a..1ed05b3d 100644 --- a/app/llm/request_handler/capability_request_handler.py +++ b/app/llm/request_handler/capability_request_handler.py @@ -1,6 +1,6 @@ from enum import Enum -from app.domain import IrisMessage +from app.domain import PyrisMessage from app.llm.capability import RequirementList from app.llm.external.model import ( ChatModel, @@ -41,8 +41,8 @@ def complete(self, prompt: str, arguments: CompletionArguments) -> str: return llm.complete(prompt, arguments) def chat( - self, messages: list[IrisMessage], arguments: CompletionArguments - ) -> IrisMessage: + self, messages: list[PyrisMessage], arguments: CompletionArguments + ) -> PyrisMessage: llm = self._select_model(ChatModel) return llm.chat(messages, arguments) diff --git a/app/llm/request_handler/request_handler_interface.py 
b/app/llm/request_handler/request_handler_interface.py index fede2ab7..4acdbe6d 100644 --- a/app/llm/request_handler/request_handler_interface.py +++ b/app/llm/request_handler/request_handler_interface.py @@ -1,6 +1,6 @@ from abc import ABCMeta, abstractmethod -from ...domain import IrisMessage +from ...domain import PyrisMessage from ...llm import CompletionArguments @@ -24,7 +24,7 @@ def complete(self, prompt: str, arguments: CompletionArguments) -> str: raise NotImplementedError @abstractmethod - def chat(self, messages: list[any], arguments: CompletionArguments) -> IrisMessage: + def chat(self, messages: list[any], arguments: CompletionArguments) -> PyrisMessage: """Create a completion from the chat messages""" raise NotImplementedError diff --git a/app/pipeline/chat/file_selector_pipeline.py b/app/pipeline/chat/file_selector_pipeline.py index 1f63422f..129a241f 100644 --- a/app/pipeline/chat/file_selector_pipeline.py +++ b/app/pipeline/chat/file_selector_pipeline.py @@ -7,7 +7,8 @@ from langchain_core.runnables import Runnable from pydantic import BaseModel -from ...llm import BasicRequestHandler, CompletionArguments +from ...llm import CapabilityRequestHandler, RequirementList +from ...llm import CompletionArguments from ...llm.langchain import IrisLangchainChatModel from ...pipeline import Pipeline from ...pipeline.chat.output_models.output_models.selected_file_model import ( @@ -41,7 +42,12 @@ class FileSelectorPipeline(Pipeline): def __init__(self, callback: Optional[StatusCallback] = None): super().__init__(implementation_id="file_selector_pipeline_reference_impl") - request_handler = BasicRequestHandler("gpt35") + request_handler = CapabilityRequestHandler( + requirements=RequirementList( + gpt_version_equivalent=3.5, + context_length=4096, + ) + ) completion_args = CompletionArguments(temperature=0, max_tokens=500) self.llm = IrisLangchainChatModel( request_handler=request_handler, completion_args=completion_args diff --git 
a/app/pipeline/chat/tutor_chat_pipeline.py b/app/pipeline/chat/tutor_chat_pipeline.py index ac79268a..ed3e9347 100644 --- a/app/pipeline/chat/tutor_chat_pipeline.py +++ b/app/pipeline/chat/tutor_chat_pipeline.py @@ -10,6 +10,9 @@ ) from langchain_core.runnables import Runnable +from ...common import convert_iris_message_to_langchain_message +from ...domain import PyrisMessage +from ...llm import CapabilityRequestHandler, RequirementList from ...domain.data.build_log_entry import BuildLogEntryDTO from ...domain.data.feedback_dto import FeedbackDTO from ..prompts.iris_tutor_chat_prompts import ( @@ -20,10 +23,9 @@ ) from ...domain import TutorChatPipelineExecutionDTO from ...domain.data.submission_dto import SubmissionDTO -from ...domain.data.message_dto import MessageDTO from ...web.status.status_update import TutorChatStatusCallback from .file_selector_pipeline import FileSelectorPipeline -from ...llm import BasicRequestHandler, CompletionArguments +from ...llm import CompletionArguments from ...llm.langchain import IrisLangchainChatModel from ..pipeline import Pipeline @@ -43,7 +45,13 @@ class TutorChatPipeline(Pipeline): def __init__(self, callback: TutorChatStatusCallback): super().__init__(implementation_id="tutor_chat_pipeline") # Set the langchain chat model - request_handler = BasicRequestHandler("gpt35") + request_handler = CapabilityRequestHandler( + requirements=RequirementList( + gpt_version_equivalent=3.5, + context_length=16385, + privacy_compliance=True, + ) + ) completion_args = CompletionArguments(temperature=0.2, max_tokens=2000) self.llm = IrisLangchainChatModel( request_handler=request_handler, completion_args=completion_args @@ -74,8 +82,8 @@ def __call__(self, dto: TutorChatPipelineExecutionDTO, **kwargs): ] ) logger.info("Running tutor chat pipeline...") - history: List[MessageDTO] = dto.chat_history[:-1] - query: MessageDTO = dto.chat_history[-1] + history: List[PyrisMessage] = dto.chat_history[:-1] + query: PyrisMessage = 
dto.chat_history[-1] submission: SubmissionDTO = dto.submission build_logs: List[BuildLogEntryDTO] = [] @@ -88,7 +96,7 @@ def __call__(self, dto: TutorChatPipelineExecutionDTO, **kwargs): problem_statement: str = dto.exercise.problem_statement exercise_title: str = dto.exercise.name - programming_language = dto.exercise.programming_language.value.lower() + programming_language = dto.exercise.programming_language.lower() # Add the chat history and user question to the prompt self._add_conversation_to_prompt(history, query) @@ -138,12 +146,13 @@ def __call__(self, dto: TutorChatPipelineExecutionDTO, **kwargs): logger.info(f"Response from tutor chat pipeline: {response}") self.callback.done("Generated response", final_result=response) except Exception as e: + print(e) self.callback.error(f"Failed to generate response: {e}") def _add_conversation_to_prompt( self, - chat_history: List[MessageDTO], - user_question: MessageDTO, + chat_history: List[PyrisMessage], + user_question: PyrisMessage, ): """ Adds the chat history and user question to the prompt @@ -153,13 +162,14 @@ def _add_conversation_to_prompt( """ if chat_history is not None and len(chat_history) > 0: chat_history_messages = [ - message.convert_to_langchain_message() for message in chat_history + convert_iris_message_to_langchain_message(message) + for message in chat_history ] self.prompt += chat_history_messages self.prompt += SystemMessagePromptTemplate.from_template( "Now, consider the student's newest and latest input:" ) - self.prompt += user_question.convert_to_langchain_message() + self.prompt += convert_iris_message_to_langchain_message(user_question) def _add_student_repository_to_prompt( self, student_repository: Dict[str, str], selected_files: List[str] diff --git a/app/pipeline/shared/summary_pipeline.py b/app/pipeline/shared/summary_pipeline.py index 9d6572d6..382881a2 100644 --- a/app/pipeline/shared/summary_pipeline.py +++ b/app/pipeline/shared/summary_pipeline.py @@ -1,12 +1,11 @@ import 
logging import os -from typing import Dict from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate from langchain_core.runnables import Runnable -from ...llm import BasicRequestHandler +from ...llm import CapabilityRequestHandler, RequirementList from ...llm.langchain import IrisLangchainCompletionModel from ...pipeline import Pipeline @@ -16,7 +15,6 @@ class SummaryPipeline(Pipeline): """A generic summary pipeline that can be used to summarize any text""" - _cache: Dict = {} llm: IrisLangchainCompletionModel pipeline: Runnable prompt_str: str @@ -25,7 +23,12 @@ class SummaryPipeline(Pipeline): def __init__(self): super().__init__(implementation_id="summary_pipeline") # Set the langchain chat model - request_handler = BasicRequestHandler("gpt35-completion") + request_handler = CapabilityRequestHandler( + requirements=RequirementList( + gpt_version_equivalent=3.5, + context_length=4096, + ) + ) self.llm = IrisLangchainCompletionModel( request_handler=request_handler, max_tokens=1000 ) @@ -59,10 +62,6 @@ def __call__(self, query: str, **kwargs) -> str: if query is None: raise ValueError("Query must not be None") logger.info("Running summary pipeline...") - if _cache := self._cache.get(query): - logger.info(f"Returning cached summary for query: {query[:20]}...") - return _cache response: str = self.pipeline.invoke({"text": query}) logger.info(f"Response from summary pipeline: {response[:20]}...") - self._cache[query] = response return response diff --git a/app/web/status/status_update.py b/app/web/status/status_update.py index 9b23e552..2997409a 100644 --- a/app/web/status/status_update.py +++ b/app/web/status/status_update.py @@ -43,23 +43,16 @@ def __init__( self, run_id: str, base_url: str, initial_stages: List[StageDTO] = None ): url = f"{base_url}/api/public/pyris/pipelines/tutor-chat/runs/{run_id}/status" - if initial_stages is not None and len(initial_stages) > 0: - stages = 
initial_stages - current_stage_index = len(initial_stages) - else: - stages = [] - current_stage_index = 0 - - stages.append( - StageDTO(weight=30, state=StageStateEnum.NOT_STARTED, name="File Lookup") - ) - stages.append( + current_stage_index = len(initial_stages) if initial_stages else 0 + stages = initial_stages or [] + stages += [ + StageDTO(weight=30, state=StageStateEnum.NOT_STARTED, name="File Lookup"), StageDTO( weight=70, state=StageStateEnum.NOT_STARTED, name="Response Generation", - ) - ) + ), + ] status = TutorChatStatusUpdateDTO(stages=stages) stage = stages[current_stage_index] super().__init__(url, run_id, status, stage, current_stage_index)