Skip to content

Commit

Permalink
Implement the _get_default_model method
Browse files Browse the repository at this point in the history
  • Loading branch information
Flagro committed Oct 15, 2024
1 parent 19f4075 commit 2b8353a
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion bot/rp_bot/ai_agent/ai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import tiktoken
import io
from typing import AsyncIterator, Literal
from typing import AsyncIterator, Literal, Optional
from openai import OpenAI
from langchain_openai import OpenAI
from langchain_core.messages import HumanMessage, SystemMessage
Expand Down Expand Up @@ -30,6 +30,31 @@ def __init__(
model=self._get_default_image_generation_model_name(),
)

def _get_default_model(
self, model_type: Literal["text", "vision", "image_generation"]
) -> Optional[Model]:
params_dict = {
"text": {
"models_dict": self.ai_config.TextGeneration.Models,
"default_attr": "text_default",
},
"vision": {
"models_dict": self.ai_config.TextGeneration.Models,
"default_attr": "vision_default",
},
"image_generation": {
"models_dict": self.ai_config.ImageGeneration.Models,
"default_attr": "image_generation_default",
},
}
first_model = None
for model in params_dict[model_type]["models_dict"].values():
if getattr(model, params_dict[model_type]["default_attr"]):
return model
if first_model is None:
first_model = model
return first_model

def _get_default_model_name(
self, model_type: Literal["text", "vision", "image_generation"]
) -> str:
Expand Down

0 comments on commit 2b8353a

Please sign in to comment.