Skip to content

Commit

Permalink
add google
Browse files Browse the repository at this point in the history
  • Loading branch information
dartpain committed Jan 13, 2025
1 parent 8935dc4 commit 51225b1
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 16 deletions.
39 changes: 24 additions & 15 deletions application/llm/google_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,29 +72,38 @@ def _raw_gen(
messages,
stream=False,
tools=None,
formatting="openai",
**kwargs
):
import google.generativeai as genai
genai.configure(api_key=self.api_key)
from google import genai
from google.genai import types
client = genai.Client(api_key=self.api_key)


config = {
}
model = 'gemini-2.0-flash-exp'
if formatting=="raw":
response = client.models.generate_content(
model=model,
contents=messages
)

model = genai.GenerativeModel(
model_name=model,
generation_config=config,
system_instruction=messages[0]["content"],
tools=self._clean_tools_format(tools)
else:
model = genai.GenerativeModel(
model_name=model,
generation_config=config,
system_instruction=messages[0]["content"],
tools=self._clean_tools_format(tools)
)
chat_session = model.start_chat(
history=self._clean_messages_google(messages)[:-1]
)
chat_session = model.start_chat(
history=self._clean_messages_google(messages)[:-1]
)
response = chat_session.send_message(
self._clean_messages_google(messages)[-1]
)
logging.info(response)
return response.text
response = chat_session.send_message(
self._clean_messages_google(messages)[-1]
)
logging.info(response)
return response.text

def _raw_gen_stream(
self,
Expand Down
32 changes: 31 additions & 1 deletion application/tools/agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging

from application.core.mongo_db import MongoDB
from application.llm.llm_creator import LLMCreator
Expand Down Expand Up @@ -79,6 +80,25 @@ def _execute_tool_action(self, tools_dict, call):
print(f"Executing tool: {action_name} with args: {call_args}")
return tool.execute_action(action_name, **call_args), call_id

def _execute_tool_action_google(self, tools_dict, call):
call_args = json.loads(call.args)
tool_id = call.name.split("_")[-1]
action_name = call.name.rsplit("_", 1)[0]

tool_data = tools_dict[tool_id]
action_data = next(
action for action in tool_data["actions"] if action["name"] == action_name
)

for param, details in action_data["parameters"]["properties"].items():
if param not in call_args and "value" in details:
call_args[param] = details["value"]

tm = ToolManager(config={})
tool = tm.load_tool(tool_data["name"], tool_config=tool_data["config"])
print(f"Executing tool: {action_name} with args: {call_args}")
return tool.execute_action(action_name, **call_args)

def _simple_tool_agent(self, messages):
tools_dict = self._get_user_tools()
self._prepare_tools(tools_dict)
Expand All @@ -91,8 +111,18 @@ def _simple_tool_agent(self, messages):
if resp.message.content:
yield resp.message.content
return
# check if self.llm class is GoogleLLM
while self.llm.__class__.__name__ == "GoogleLLM" and resp.content.parts[0].function_call:
from google.genai import types

function_call_part = resp.candidates[0].content.parts[0]
tool_response = self._execute_tool_action_google(tools_dict, function_call_part.function_call)
function_response_part = types.Part.from_function_response(
name=function_call_part.function_call.name,
response=tool_response
)

while resp.finish_reason == "tool_calls":
while self.llm.__class__.__name__ == "OpenAILLM" and resp.finish_reason == "tool_calls":
message = json.loads(resp.model_dump_json())["message"]
keys_to_remove = {"audio", "function_call", "refusal"}
filtered_data = {
Expand Down

0 comments on commit 51225b1

Please sign in to comment.