diff --git a/application/cache.py b/application/cache.py
index 76b594c93..80dee4f48 100644
--- a/application/cache.py
+++ b/application/cache.py
@@ -33,7 +33,7 @@ def gen_cache_key(messages, model="docgpt", tools=None):
     if not all(isinstance(msg, dict) for msg in messages):
         raise ValueError("All messages must be dictionaries.")
     messages_str = json.dumps(messages)
-    tools_str = json.dumps(tools) if tools else ""
+    tools_str = json.dumps(str(tools)) if tools else ""
     combined = f"{model}_{messages_str}_{tools_str}"
     cache_key = get_hash(combined)
     return cache_key
@@ -68,8 +68,8 @@ def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
 
 
 def stream_cache(func):
-    def wrapper(self, model, messages, stream, *args, **kwargs):
-        cache_key = gen_cache_key(messages)
+    def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
+        cache_key = gen_cache_key(messages, model, tools)
         logger.info(f"Stream cache key: {cache_key}")
 
         redis_client = get_redis_instance()
@@ -86,7 +86,7 @@ def wrapper(self, model, messages, stream, *args, **kwargs):
         except redis.ConnectionError as e:
             logger.error(f"Redis connection error: {e}")
 
-        result = func(self, model, messages, stream, *args, **kwargs)
+        result = func(self, model, messages, stream, tools=tools, *args, **kwargs)
         stream_cache_data = []
 
         for chunk in result:
diff --git a/application/llm/base.py b/application/llm/base.py
index b9b0e5243..e687e567b 100644
--- a/application/llm/base.py
+++ b/application/llm/base.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
+
+from application.cache import gen_cache, stream_cache
 from application.usage import gen_token_usage, stream_token_usage
-from application.cache import stream_cache, gen_cache
 
 
 class BaseLLM(ABC):
@@ -18,18 +19,38 @@ def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
 
     def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
         decorators = [gen_token_usage, gen_cache]
-        return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
+        return self._apply_decorator(
+            self._raw_gen,
+            decorators=decorators,
+            model=model,
+            messages=messages,
+            stream=stream,
+            tools=tools,
+            *args,
+            **kwargs
+        )
 
     @abstractmethod
     def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
         pass
 
-    def gen_stream(self, model, messages, stream=True, *args, **kwargs):
+    def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
         decorators = [stream_cache, stream_token_usage]
-        return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
-
+        return self._apply_decorator(
+            self._raw_gen_stream,
+            decorators=decorators,
+            model=model,
+            messages=messages,
+            stream=stream,
+            tools=tools,
+            *args,
+            **kwargs
+        )
+
     def supports_tools(self):
-        return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
+        return hasattr(self, "_supports_tools") and callable(
+            getattr(self, "_supports_tools")
+        )
 
     def _supports_tools(self):
-        raise NotImplementedError("Subclass must implement _supports_tools method")
\ No newline at end of file
+        raise NotImplementedError("Subclass must implement _supports_tools method")
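[Reviewer note, not part of the diff] The two changes above work together: gen_stream now accepts tools, and stream_cache folds them (plus the model name) into the cache key, so tool-enabled and tool-free runs of the same conversation no longer share a cache entry. A minimal sketch of the effect, reimplementing gen_cache_key inline and assuming get_hash is a plain hex-digest helper (hashlib.md5 is used here purely for illustration):

    import hashlib
    import json


    def gen_cache_key(messages, model="docgpt", tools=None):
        # Same logic as application/cache.py after this diff; get_hash is
        # assumed to be a simple hex digest of the combined string.
        messages_str = json.dumps(messages)
        tools_str = json.dumps(str(tools)) if tools else ""
        combined = f"{model}_{messages_str}_{tools_str}"
        return hashlib.md5(combined.encode()).hexdigest()


    messages = [{"role": "user", "content": "What's the weather in Berlin?"}]
    tools = [{"type": "function", "function": {"name": "get_weather_abc123"}}]

    # Before this diff the stream path hashed only the messages; now the same
    # conversation with and without tools yields distinct cache keys.
    assert gen_cache_key(messages) != gen_cache_key(messages, tools=tools)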
diff --git a/application/llm/google_ai.py b/application/llm/google_ai.py
index df252abfc..ae8880421 100644
--- a/application/llm/google_ai.py
+++ b/application/llm/google_ai.py
@@ -1,21 +1,85 @@
+from google import genai
+from google.genai import types
+
 from application.llm.base import BaseLLM
 
 
-class GoogleLLM(BaseLLM):
 
+class GoogleLLM(BaseLLM):
     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-        super().__init__(*args, **kwargs)
         self.api_key = api_key
         self.user_api_key = user_api_key
 
     def _clean_messages_google(self, messages):
-        return [
-            {
-                "role": "model" if message["role"] == "system" else message["role"],
-                "parts": [message["content"]],
-            }
-            for message in messages[1:]
-        ]
+        cleaned_messages = []
+        for message in messages:
+            role = message.get("role")
+            content = message.get("content")
+
+            if role == "assistant":
+                role = "model"
+
+            parts = []
+            if role and content is not None:
+                if isinstance(content, str):
+                    parts = [types.Part.from_text(content)]
+                elif isinstance(content, list):
+                    for item in content:
+                        if "text" in item:
+                            parts.append(types.Part.from_text(item["text"]))
+                        elif "function_call" in item:
+                            parts.append(
+                                types.Part.from_function_call(
+                                    name=item["function_call"]["name"],
+                                    args=item["function_call"]["args"],
+                                )
+                            )
+                        elif "function_response" in item:
+                            parts.append(
+                                types.Part.from_function_response(
+                                    name=item["function_response"]["name"],
+                                    response=item["function_response"]["response"],
+                                )
+                            )
+                        else:
+                            raise ValueError(
+                                f"Unexpected content dictionary format: {item}"
+                            )
+                else:
+                    raise ValueError(f"Unexpected content type: {type(content)}")
+
+            cleaned_messages.append(types.Content(role=role, parts=parts))
+
+        return cleaned_messages
+
+    def _clean_tools_format(self, tools_list):
+        genai_tools = []
+        for tool_data in tools_list:
+            if tool_data["type"] == "function":
+                function = tool_data["function"]
+                genai_function = dict(
+                    name=function["name"],
+                    description=function["description"],
+                    parameters={
+                        "type": "OBJECT",
+                        "properties": {
+                            k: {
+                                **v,
+                                "type": v["type"].upper() if v["type"] else None,
+                            }
+                            for k, v in function["parameters"]["properties"].items()
+                        },
+                        "required": (
+                            function["parameters"]["required"]
+                            if "required" in function["parameters"]
+                            else []
+                        ),
+                    },
+                )
+                genai_tool = types.Tool(function_declarations=[genai_function])
+                genai_tools.append(genai_tool)
+
+        return genai_tools
 
     def _raw_gen(
         self,
@@ -23,13 +87,32 @@ def _raw_gen(
         model,
         messages,
         stream=False,
-        **kwargs
-    ):
-        import google.generativeai as genai
-        genai.configure(api_key=self.api_key)
-        model = genai.GenerativeModel(model, system_instruction=messages[0]["content"])
-        response = model.generate_content(self._clean_messages_google(messages))
-        return response.text
+        tools=None,
+        formatting="openai",
+        **kwargs,
+    ):
+        client = genai.Client(api_key=self.api_key)
+        if formatting == "openai":
+            messages = self._clean_messages_google(messages)
+        config = types.GenerateContentConfig()
+        if messages[0].role == "system":
+            config.system_instruction = messages[0].parts[0].text
+            messages = messages[1:]
+
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
+            response = client.models.generate_content(
+                model=model,
+                contents=messages,
+                config=config,
+            )
+            return response
+        else:
+            response = client.models.generate_content(
+                model=model, contents=messages, config=config
+            )
+            return response.text
 
     def _raw_gen_stream(
         self,
@@ -37,12 +120,30 @@ def _raw_gen_stream(
         model,
         messages,
         stream=True,
-        **kwargs
-    ):
-        import google.generativeai as genai
-        genai.configure(api_key=self.api_key)
-        model = genai.GenerativeModel(model, system_instruction=messages[0]["content"])
-        response = model.generate_content(self._clean_messages_google(messages), stream=True)
-        for line in response:
-            if line.text is not None:
-                yield line.text
\ No newline at end of file
+        tools=None,
+        formatting="openai",
+        **kwargs,
+    ):
+        client = genai.Client(api_key=self.api_key)
+        if formatting == "openai":
+            messages = self._clean_messages_google(messages)
+        config = types.GenerateContentConfig()
+        if messages[0].role == "system":
+            config.system_instruction = messages[0].parts[0].text
+            messages = messages[1:]
+
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
+
+        response = client.models.generate_content_stream(
+            model=model,
+            contents=messages,
+            config=config,
+        )
+        for chunk in response:
+            if chunk.text is not None:
+                yield chunk.text
+
+    def _supports_tools(self):
+        return True
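[Reviewer note, not part of the diff] _clean_tools_format bridges the OpenAI-style tool schema the Agent already produces and the google-genai types.Tool structure. A sketch with a hypothetical get_weather tool; the output shape in the comment follows directly from the code above:

    # Hypothetical OpenAI-style tool definition, roughly as the Agent
    # supplies it (the "_abc123" suffix is the tool id convention used
    # by ToolActionParser later in this PR).
    openai_tool = {
        "type": "function",
        "function": {
            "name": "get_weather_abc123",
            "description": "Fetch current weather for a city.",
            "parameters": {
                "properties": {
                    "city": {"type": "string", "description": "City name"},
                },
                "required": ["city"],
            },
        },
    }

    # GoogleLLM._clean_tools_format([openai_tool]) wraps this in, schematically:
    #
    #   types.Tool(function_declarations=[{
    #       "name": "get_weather_abc123",
    #       "description": "Fetch current weather for a city.",
    #       "parameters": {
    #           "type": "OBJECT",
    #           "properties": {"city": {"type": "STRING", "description": "City name"}},
    #           "required": ["city"],
    #       },
    #   }])
    #
    # i.e. parameter types are upper-cased and a top-level OBJECT schema is added.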
diff --git a/application/llm/openai.py b/application/llm/openai.py
index cc2285a12..b507a1da8 100644
--- a/application/llm/openai.py
+++ b/application/llm/openai.py
@@ -1,6 +1,5 @@
-from application.llm.base import BaseLLM
 from application.core.settings import settings
-
+from application.llm.base import BaseLLM
 
 
 class OpenAILLM(BaseLLM):
@@ -10,10 +9,7 @@ def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         if settings.OPENAI_BASE_URL:
-            self.client = OpenAI(
-                api_key=api_key,
-                base_url=settings.OPENAI_BASE_URL
-            )
+            self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
         else:
             self.client = OpenAI(api_key=api_key)
         self.api_key = api_key
@@ -27,8 +23,8 @@ def _raw_gen(
         self,
         model,
         messages,
         stream=False,
         tools=None,
         engine=settings.AZURE_DEPLOYMENT_NAME,
-        **kwargs
-    ):
+        **kwargs,
+    ):
         if tools:
             response = self.client.chat.completions.create(
                 model=model, messages=messages, stream=stream, tools=tools, **kwargs
@@ -48,18 +44,16 @@ def _raw_gen_stream(
         self,
         model,
         messages,
         stream=True,
         tools=None,
         engine=settings.AZURE_DEPLOYMENT_NAME,
-        **kwargs
-    ):
+        **kwargs,
+    ):
         response = self.client.chat.completions.create(
             model=model, messages=messages, stream=stream, **kwargs
         )
 
         for line in response:
-            # import sys
-            # print(line.choices[0].delta.content, file=sys.stderr)
             if line.choices[0].delta.content is not None:
                 yield line.choices[0].delta.content
-
+
     def _supports_tools(self):
         return True
diff --git a/application/requirements.txt b/application/requirements.txt
index 3fc6d02d6..c193f38de 100644
--- a/application/requirements.txt
+++ b/application/requirements.txt
@@ -14,6 +14,8 @@ esutils==1.0.1
 Flask==3.1.0
 faiss-cpu==1.9.0.post1
 flask-restx==1.3.0
+google-genai==0.5.0
+google-generativeai==0.8.3
 gTTS==2.5.4
 gunicorn==23.0.0
 html2text==2024.2.26
diff --git a/application/retriever/brave_search.py b/application/retriever/brave_search.py
index 3d9ae89e6..efcae8ab9 100644
--- a/application/retriever/brave_search.py
+++ b/application/retriever/brave_search.py
@@ -74,12 +74,10 @@ def gen(self):
         if len(self.chat_history) > 1:
             for i in self.chat_history:
                 if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
-                    messages_combine.append(
-                        {"role": "system", "content": i["response"]}
-                    )
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
+                    messages_combine.append(
+                        {"role": "assistant", "content": i["response"]}
+                    )
         messages_combine.append({"role": "user", "content": self.question})
 
         llm = LLMCreator.create_llm(
diff --git a/application/retriever/classic_rag.py b/application/retriever/classic_rag.py
index 2e3555137..b3735a963 100644
--- a/application/retriever/classic_rag.py
+++ b/application/retriever/classic_rag.py
@@ -5,7 +5,6 @@
 from application.vectorstore.vector_creator import VectorCreator
 
 
-
 class ClassicRAG(BaseRetriever):
 
     def __init__(
@@ -74,13 +73,11 @@ def gen(self):
         if len(self.chat_history) > 1:
             for i in self.chat_history:
-                if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
-                    messages_combine.append(
-                        {"role": "system", "content": i["response"]}
-                    )
+                if "prompt" in i and "response" in i:
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
+                    messages_combine.append(
+                        {"role": "assistant", "content": i["response"]}
+                    )
         messages_combine.append({"role": "user", "content": self.question})
 
         # llm = LLMCreator.create_llm(
         #     settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
diff --git a/application/retriever/duckduck_search.py b/application/retriever/duckduck_search.py
index fa19ead03..321c6fd96 100644
--- a/application/retriever/duckduck_search.py
+++ b/application/retriever/duckduck_search.py
@@ -91,11 +91,9 @@ def gen(self):
         if len(self.chat_history) > 1:
             for i in self.chat_history:
                 if "prompt" in i and "response" in i:
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
                     messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
-                    messages_combine.append(
-                        {"role": "system", "content": i["response"]}
+                        {"role": "assistant", "content": i["response"]}
                     )
         messages_combine.append({"role": "user", "content": self.question})
 
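[Reviewer note, not part of the diff] All three retrievers now replay stored history with the standard "assistant" role instead of "system", which both the OpenAI path and the new Google path (via the assistant-to-model mapping in _clean_messages_google) understand. A sketch of the list this loop builds for a hypothetical one-turn history; the system prompt text is illustrative:

    chat_history = [{"prompt": "Hi there", "response": "Hello! How can I help?"}]
    messages_combine = [{"role": "system", "content": "You are a helpful assistant."}]

    for i in chat_history:
        if "prompt" in i and "response" in i:
            messages_combine.append({"role": "user", "content": i["prompt"]})
            messages_combine.append({"role": "assistant", "content": i["response"]})
    messages_combine.append({"role": "user", "content": "What can you do?"})

    # -> system prompt, then user/assistant pairs, then the new question;
    #    previously the replayed responses were mislabeled as "system" turns.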
"function_call", "refusal"} - filtered_data = { - k: v for k, v in message.items() if k not in keys_to_remove - } - messages.append(filtered_data) - tool_calls = resp.message.tool_calls - for call in tool_calls: - try: - tool_response, call_id = self._execute_tool_action(tools_dict, call) - messages.append( - { - "role": "tool", - "content": str(tool_response), - "tool_call_id": call_id, - } - ) - except Exception as e: - messages.append( - { - "role": "tool", - "content": f"Error executing tool: {str(e)}", - "tool_call_id": call.id, - } - ) - # Generate a new response from the LLM after processing tools - resp = self.llm.gen( - model=self.gpt_model, messages=messages, tools=self.tools - ) + resp = self.llm_handler.handle_response(self, resp, tools_dict, messages) - # If no tool calls are needed, generate the final response if isinstance(resp, str): yield resp - elif resp.message.content: + elif hasattr(resp, "message") and hasattr(resp.message, "content"): yield resp.message.content else: completion = self.llm.gen_stream( @@ -138,7 +109,6 @@ def _simple_tool_agent(self, messages): return def gen(self, messages): - # Generate initial response from the LLM if self.llm.supports_tools(): resp = self._simple_tool_agent(messages) for line in resp: diff --git a/application/tools/llm_handler.py b/application/tools/llm_handler.py new file mode 100644 index 000000000..cc7494c02 --- /dev/null +++ b/application/tools/llm_handler.py @@ -0,0 +1,97 @@ +import json +from abc import ABC, abstractmethod + + +class LLMHandler(ABC): + @abstractmethod + def handle_response(self, agent, resp, tools_dict, messages, **kwargs): + pass + + +class OpenAILLMHandler(LLMHandler): + def handle_response(self, agent, resp, tools_dict, messages): + while resp.finish_reason == "tool_calls": + message = json.loads(resp.model_dump_json())["message"] + keys_to_remove = {"audio", "function_call", "refusal"} + filtered_data = { + k: v for k, v in message.items() if k not in keys_to_remove + } + messages.append(filtered_data) + + tool_calls = resp.message.tool_calls + for call in tool_calls: + try: + tool_response, call_id = agent._execute_tool_action( + tools_dict, call + ) + messages.append( + { + "role": "tool", + "content": str(tool_response), + "tool_call_id": call_id, + } + ) + except Exception as e: + messages.append( + { + "role": "tool", + "content": f"Error executing tool: {str(e)}", + "tool_call_id": call_id, + } + ) + resp = agent.llm.gen( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + return resp + + +class GoogleLLMHandler(LLMHandler): + def handle_response(self, agent, resp, tools_dict, messages): + from google.genai import types + + while True: + response = agent.llm.gen( + model=agent.gpt_model, messages=messages, tools=agent.tools + ) + if response.candidates and response.candidates[0].content.parts: + tool_call_found = False + for part in response.candidates[0].content.parts: + if part.function_call: + tool_call_found = True + tool_response, call_id = agent._execute_tool_action( + tools_dict, part.function_call + ) + function_response_part = types.Part.from_function_response( + name=part.function_call.name, + response={"result": tool_response}, + ) + + messages.append( + {"role": "model", "content": [part.to_json_dict()]} + ) + messages.append( + { + "role": "tool", + "content": [function_response_part.to_json_dict()], + } + ) + + if ( + not tool_call_found + and response.candidates[0].content.parts + and response.candidates[0].content.parts[0].text + ): + return 
diff --git a/application/tools/llm_handler.py b/application/tools/llm_handler.py
new file mode 100644
index 000000000..cc7494c02
--- /dev/null
+++ b/application/tools/llm_handler.py
@@ -0,0 +1,97 @@
+import json
+from abc import ABC, abstractmethod
+
+
+class LLMHandler(ABC):
+    @abstractmethod
+    def handle_response(self, agent, resp, tools_dict, messages, **kwargs):
+        pass
+
+
+class OpenAILLMHandler(LLMHandler):
+    def handle_response(self, agent, resp, tools_dict, messages):
+        while resp.finish_reason == "tool_calls":
+            message = json.loads(resp.model_dump_json())["message"]
+            keys_to_remove = {"audio", "function_call", "refusal"}
+            filtered_data = {
+                k: v for k, v in message.items() if k not in keys_to_remove
+            }
+            messages.append(filtered_data)
+
+            tool_calls = resp.message.tool_calls
+            for call in tool_calls:
+                try:
+                    tool_response, call_id = agent._execute_tool_action(
+                        tools_dict, call
+                    )
+                    messages.append(
+                        {
+                            "role": "tool",
+                            "content": str(tool_response),
+                            "tool_call_id": call_id,
+                        }
+                    )
+                except Exception as e:
+                    messages.append(
+                        {
+                            "role": "tool",
+                            "content": f"Error executing tool: {str(e)}",
+                            "tool_call_id": call.id,
+                        }
+                    )
+            resp = agent.llm.gen(
+                model=agent.gpt_model, messages=messages, tools=agent.tools
+            )
+        return resp
+
+
+class GoogleLLMHandler(LLMHandler):
+    def handle_response(self, agent, resp, tools_dict, messages):
+        from google.genai import types
+
+        while True:
+            response = agent.llm.gen(
+                model=agent.gpt_model, messages=messages, tools=agent.tools
+            )
+            if response.candidates and response.candidates[0].content.parts:
+                tool_call_found = False
+                for part in response.candidates[0].content.parts:
+                    if part.function_call:
+                        tool_call_found = True
+                        tool_response, call_id = agent._execute_tool_action(
+                            tools_dict, part.function_call
+                        )
+                        function_response_part = types.Part.from_function_response(
+                            name=part.function_call.name,
+                            response={"result": tool_response},
+                        )
+
+                        messages.append(
+                            {"role": "model", "content": [part.to_json_dict()]}
+                        )
+                        messages.append(
+                            {
+                                "role": "tool",
+                                "content": [function_response_part.to_json_dict()],
+                            }
+                        )
+
+                if (
+                    not tool_call_found
+                    and response.candidates[0].content.parts
+                    and response.candidates[0].content.parts[0].text
+                ):
+                    return response.candidates[0].content.parts[0].text
+                elif not tool_call_found:
+                    return response.candidates[0].content.parts
+
+            else:
+                return response
+
+
+def get_llm_handler(llm_type):
+    handlers = {
+        "openai": OpenAILLMHandler(),
+        "google": GoogleLLMHandler(),
+    }
+    return handlers.get(llm_type, OpenAILLMHandler())
diff --git a/application/tools/tool_action_parser.py b/application/tools/tool_action_parser.py
new file mode 100644
index 000000000..ac0a70c16
--- /dev/null
+++ b/application/tools/tool_action_parser.py
@@ -0,0 +1,26 @@
+import json
+
+
+class ToolActionParser:
+    def __init__(self, llm_type):
+        self.llm_type = llm_type
+        self.parsers = {
+            "OpenAILLM": self._parse_openai_llm,
+            "GoogleLLM": self._parse_google_llm,
+        }
+
+    def parse_args(self, call):
+        parser = self.parsers.get(self.llm_type, self._parse_openai_llm)
+        return parser(call)
+
+    def _parse_openai_llm(self, call):
+        call_args = json.loads(call.function.arguments)
+        tool_id = call.function.name.split("_")[-1]
+        action_name = call.function.name.rsplit("_", 1)[0]
+        return tool_id, action_name, call_args
+
+    def _parse_google_llm(self, call):
+        call_args = call.args
+        tool_id = call.name.split("_")[-1]
+        action_name = call.name.rsplit("_", 1)[0]
+        return tool_id, action_name, call_args
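[Reviewer note, not part of the diff] ToolActionParser relies on tool names of the form "<action_name>_<tool_id>" and on the providers' differing argument encodings. A self-contained sketch of both conventions, mirroring the parser methods above:

    import json

    # Tool names are split from the right, so action names may themselves
    # contain underscores: "get_weather_abc123" -> ("get_weather", "abc123").
    name = "get_weather_abc123"
    tool_id = name.split("_")[-1]
    action_name = name.rsplit("_", 1)[0]
    assert (action_name, tool_id) == ("get_weather", "abc123")

    # OpenAI delivers arguments as a JSON string, Google as a ready dict;
    # both parsers normalize to the same dict before execute_action runs.
    openai_style = json.loads('{"city": "Berlin"}')
    google_style = {"city": "Berlin"}
    assert openai_style == google_style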