diff --git a/bot/rp_bot/ai_agent/ai.py b/bot/rp_bot/ai_agent/ai.py index a203582..fff4d1f 100644 --- a/bot/rp_bot/ai_agent/ai.py +++ b/bot/rp_bot/ai_agent/ai.py @@ -1,6 +1,6 @@ import tiktoken import io -from typing import AsyncIterator, Literal +from typing import AsyncIterator, Literal, Optional from openai import OpenAI from langchain_openai import OpenAI from langchain_core.messages import HumanMessage, SystemMessage @@ -30,6 +30,31 @@ def __init__( model=self._get_default_image_generation_model_name(), ) + def _get_default_model( + self, model_type: Literal["text", "vision", "image_generation"] + ) -> Optional[Model]: + params_dict = { + "text": { + "models_dict": self.ai_config.TextGeneration.Models, + "default_attr": "text_default", + }, + "vision": { + "models_dict": self.ai_config.TextGeneration.Models, + "default_attr": "vision_default", + }, + "image_generation": { + "models_dict": self.ai_config.ImageGeneration.Models, + "default_attr": "image_generation_default", + }, + } + first_model = None + for model in params_dict[model_type]["models_dict"].values(): + if getattr(model, params_dict[model_type]["default_attr"]): + return model + if first_model is None: + first_model = model + return first_model + def _get_default_model_name( self, model_type: Literal["text", "vision", "image_generation"] ) -> str: