
Merge pull request #1594 from arc53/googleai-compatability-tools
Googleai compatability tools
dartpain authored Jan 21, 2025
2 parents c0a2daa + b965ce7 commit 2606e6b
Showing 11 changed files with 312 additions and 108 deletions.
8 changes: 4 additions & 4 deletions application/cache.py
@@ -33,7 +33,7 @@ def gen_cache_key(messages, model="docgpt", tools=None):
     if not all(isinstance(msg, dict) for msg in messages):
         raise ValueError("All messages must be dictionaries.")
     messages_str = json.dumps(messages)
-    tools_str = json.dumps(tools) if tools else ""
+    tools_str = json.dumps(str(tools)) if tools else ""
     combined = f"{model}_{messages_str}_{tools_str}"
     cache_key = get_hash(combined)
     return cache_key
@@ -68,8 +68,8 @@ def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
 
 
 def stream_cache(func):
-    def wrapper(self, model, messages, stream, *args, **kwargs):
-        cache_key = gen_cache_key(messages)
+    def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
+        cache_key = gen_cache_key(messages, model, tools)
         logger.info(f"Stream cache key: {cache_key}")
 
         redis_client = get_redis_instance()
@@ -86,7 +86,7 @@ def wrapper(self, model, messages, stream, *args, **kwargs):
         except redis.ConnectionError as e:
             logger.error(f"Redis connection error: {e}")
 
-        result = func(self, model, messages, stream, *args, **kwargs)
+        result = func(self, model, messages, stream, tools=tools, *args, **kwargs)
         stream_cache_data = []
 
        for chunk in result:
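Why the key now stringifies tools before JSON-encoding: tool definitions can carry values that json.dumps rejects, and str() normalizes them first. A minimal sketch of the failure mode this guards against (the set-valued "tags" field is a contrived, hypothetical example):

import json

# Hypothetical tool entry carrying a non-JSON-serializable value (a set).
tools = [{"type": "function", "function": {"name": "lookup", "tags": {"a", "b"}}}]

try:
    json.dumps(tools)  # raises TypeError: Object of type set is not JSON serializable
except TypeError:
    pass

tools_str = json.dumps(str(tools))  # stringify first; never raises for this input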
35 changes: 28 additions & 7 deletions application/llm/base.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 
+from application.cache import gen_cache, stream_cache
 from application.usage import gen_token_usage, stream_token_usage
-from application.cache import stream_cache, gen_cache
 
 
 class BaseLLM(ABC):
@@ -18,18 +19,38 @@ def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
 
     def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
         decorators = [gen_token_usage, gen_cache]
-        return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
+        return self._apply_decorator(
+            self._raw_gen,
+            decorators=decorators,
+            model=model,
+            messages=messages,
+            stream=stream,
+            tools=tools,
+            *args,
+            **kwargs
+        )
 
     @abstractmethod
     def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
         pass
 
-    def gen_stream(self, model, messages, stream=True, *args, **kwargs):
+    def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
         decorators = [stream_cache, stream_token_usage]
-        return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
-
+        return self._apply_decorator(
+            self._raw_gen_stream,
+            decorators=decorators,
+            model=model,
+            messages=messages,
+            stream=stream,
+            tools=tools,
+            *args,
+            **kwargs
+        )
 
     def supports_tools(self):
-        return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
+        return hasattr(self, "_supports_tools") and callable(
+            getattr(self, "_supports_tools")
+        )
 
     def _supports_tools(self):
-        raise NotImplementedError("Subclass must implement _supports_tools method")
+        raise NotImplementedError("Subclass must implement _supports_tools method")
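With tools threaded through both gen and gen_stream, every provider subclass now receives them in its _raw_gen and _raw_gen_stream hooks. A hypothetical minimal subclass sketching the contract (EchoLLM and its trivial bodies are invented for illustration and assume a running DocsGPT application context; note the extra baseself slot the decorator chain passes through):

from application.llm.base import BaseLLM


class EchoLLM(BaseLLM):
    # Hypothetical subclass illustrating the tools-aware hook signatures.

    def _raw_gen(self, baseself, model, messages, stream=False, tools=None, **kwargs):
        # tools arrives here via gen() -> _apply_decorator -> gen_cache/gen_token_usage
        return messages[-1]["content"]

    def _raw_gen_stream(self, baseself, model, messages, stream=True, tools=None, **kwargs):
        # the same tools kwarg now reaches the streaming path too
        yield messages[-1]["content"]

    def _supports_tools(self):
        return True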
151 changes: 126 additions & 25 deletions application/llm/google_ai.py
@@ -1,48 +1,149 @@
+from google import genai
+from google.genai import types
+
 from application.llm.base import BaseLLM
 
-class GoogleLLM(BaseLLM):
 
+class GoogleLLM(BaseLLM):
     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-
         super().__init__(*args, **kwargs)
         self.api_key = api_key
         self.user_api_key = user_api_key
 
     def _clean_messages_google(self, messages):
-        return [
-            {
-                "role": "model" if message["role"] == "system" else message["role"],
-                "parts": [message["content"]],
-            }
-            for message in messages[1:]
-        ]
+        cleaned_messages = []
+        for message in messages:
+            role = message.get("role")
+            content = message.get("content")
+
+            if role == "assistant":
+                role = "model"
+
+            parts = []
+            if role and content is not None:
+                if isinstance(content, str):
+                    parts = [types.Part.from_text(content)]
+                elif isinstance(content, list):
+                    for item in content:
+                        if "text" in item:
+                            parts.append(types.Part.from_text(item["text"]))
+                        elif "function_call" in item:
+                            parts.append(
+                                types.Part.from_function_call(
+                                    name=item["function_call"]["name"],
+                                    args=item["function_call"]["args"],
+                                )
+                            )
+                        elif "function_response" in item:
+                            parts.append(
+                                types.Part.from_function_response(
+                                    name=item["function_response"]["name"],
+                                    response=item["function_response"]["response"],
+                                )
+                            )
+                        else:
+                            raise ValueError(
+                                f"Unexpected content dictionary format:{item}"
+                            )
+                else:
+                    raise ValueError(f"Unexpected content type: {type(content)}")
+
+            cleaned_messages.append(types.Content(role=role, parts=parts))
+
+        return cleaned_messages
+
+    def _clean_tools_format(self, tools_list):
+        genai_tools = []
+        for tool_data in tools_list:
+            if tool_data["type"] == "function":
+                function = tool_data["function"]
+                genai_function = dict(
+                    name=function["name"],
+                    description=function["description"],
+                    parameters={
+                        "type": "OBJECT",
+                        "properties": {
+                            k: {
+                                **v,
+                                "type": v["type"].upper() if v["type"] else None,
+                            }
+                            for k, v in function["parameters"]["properties"].items()
+                        },
+                        "required": (
+                            function["parameters"]["required"]
+                            if "required" in function["parameters"]
+                            else []
+                        ),
+                    },
+                )
+                genai_tool = types.Tool(function_declarations=[genai_function])
+                genai_tools.append(genai_tool)
+
+        return genai_tools
 
     def _raw_gen(
         self,
         baseself,
         model,
         messages,
         stream=False,
-        **kwargs
-    ):
-        import google.generativeai as genai
-        genai.configure(api_key=self.api_key)
-        model = genai.GenerativeModel(model, system_instruction=messages[0]["content"])
-        response = model.generate_content(self._clean_messages_google(messages))
-        return response.text
+        tools=None,
+        formatting="openai",
+        **kwargs,
+    ):
+        client = genai.Client(api_key=self.api_key)
+        if formatting == "openai":
+            messages = self._clean_messages_google(messages)
+        config = types.GenerateContentConfig()
+        if messages[0].role == "system":
+            config.system_instruction = messages[0].parts[0].text
+            messages = messages[1:]
+
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
+            response = client.models.generate_content(
+                model=model,
+                contents=messages,
+                config=config,
+            )
+            return response
+        else:
+            response = client.models.generate_content(
+                model=model, contents=messages, config=config
+            )
+            return response.text
 
     def _raw_gen_stream(
         self,
         baseself,
         model,
         messages,
         stream=True,
-        **kwargs
-    ):
-        import google.generativeai as genai
-        genai.configure(api_key=self.api_key)
-        model = genai.GenerativeModel(model, system_instruction=messages[0]["content"])
-        response = model.generate_content(self._clean_messages_google(messages), stream=True)
-        for line in response:
-            if line.text is not None:
-                yield line.text
+        tools=None,
+        formatting="openai",
+        **kwargs,
+    ):
+        client = genai.Client(api_key=self.api_key)
+        if formatting == "openai":
+            messages = self._clean_messages_google(messages)
+        config = types.GenerateContentConfig()
+        if messages[0].role == "system":
+            config.system_instruction = messages[0].parts[0].text
+            messages = messages[1:]
+
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
+
+        response = client.models.generate_content_stream(
+            model=model,
+            contents=messages,
+            config=config,
+        )
+        for chunk in response:
+            if chunk.text is not None:
+                yield chunk.text
+
+    def _supports_tools(self):
+        return True
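Taken together, this rewrite moves GoogleLLM from the legacy google-generativeai SDK to the new google-genai client and teaches it to convert OpenAI-style tool schemas into types.Tool declarations. A hedged usage sketch (the model name, API key, and get_weather tool are illustrative, not part of the PR):

from application.llm.google_ai import GoogleLLM

llm = GoogleLLM(api_key="YOUR_GEMINI_API_KEY")  # hypothetical key

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather in Paris?"},
]
tools = [
    {
        "type": "function",  # OpenAI-style schema; _clean_tools_format converts it
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

if llm.supports_tools():
    # With tools passed, _raw_gen returns the raw google-genai response so the
    # caller can inspect function_call parts; without tools it returns plain text.
    response = llm.gen(model="gemini-1.5-flash", messages=messages, tools=tools)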
20 changes: 7 additions & 13 deletions application/llm/openai.py
@@ -1,6 +1,5 @@
-from application.llm.base import BaseLLM
 from application.core.settings import settings
 
+from application.llm.base import BaseLLM
 
 
 class OpenAILLM(BaseLLM):
@@ -10,10 +9,7 @@ def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
 
         super().__init__(*args, **kwargs)
         if settings.OPENAI_BASE_URL:
-            self.client = OpenAI(
-                api_key=api_key,
-                base_url=settings.OPENAI_BASE_URL
-            )
+            self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
         else:
             self.client = OpenAI(api_key=api_key)
         self.api_key = api_key
@@ -27,8 +23,8 @@ def _raw_gen(
         stream=False,
         tools=None,
         engine=settings.AZURE_DEPLOYMENT_NAME,
-        **kwargs
-    ):
+        **kwargs,
+    ):
         if tools:
             response = self.client.chat.completions.create(
                 model=model, messages=messages, stream=stream, tools=tools, **kwargs
@@ -48,18 +44,16 @@ def _raw_gen_stream(
         stream=True,
         tools=None,
         engine=settings.AZURE_DEPLOYMENT_NAME,
-        **kwargs
-    ):
+        **kwargs,
+    ):
         response = self.client.chat.completions.create(
             model=model, messages=messages, stream=stream, **kwargs
         )
 
         for line in response:
-            # import sys
-            # print(line.choices[0].delta.content, file=sys.stderr)
             if line.choices[0].delta.content is not None:
                 yield line.choices[0].delta.content
 
     def _supports_tools(self):
         return True
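The OpenAI changes are mostly formatting (import order, trailing commas, dead debug comments removed), but the streaming path is worth seeing end to end: _raw_gen_stream yields bare text deltas. A hedged consumption sketch (the key, model name, and prompt are illustrative):

from application.llm.openai import OpenAILLM

llm = OpenAILLM(api_key="sk-...")  # hypothetical key

for delta in llm.gen_stream(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Summarize the cache module."}],
):
    print(delta, end="", flush=True)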
2 changes: 2 additions & 0 deletions application/requirements.txt
@@ -14,6 +14,8 @@
 esutils==1.0.1
 Flask==3.1.0
 faiss-cpu==1.9.0.post1
 flask-restx==1.3.0
+google-genai==0.5.0
+google-generativeai==0.8.3
 gTTS==2.5.4
 gunicorn==23.0.0
 html2text==2024.2.26
10 changes: 4 additions & 6 deletions application/retriever/brave_search.py
@@ -74,12 +74,10 @@ def gen(self):
         if len(self.chat_history) > 1:
             for i in self.chat_history:
                 if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
-                    messages_combine.append(
-                        {"role": "system", "content": i["response"]}
-                    )
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
+                    messages_combine.append(
+                        {"role": "assistant", "content": i["response"]}
+                    )
         messages_combine.append({"role": "user", "content": self.question})
 
         llm = LLMCreator.create_llm(
13 changes: 5 additions & 8 deletions application/retriever/classic_rag.py
@@ -5,7 +5,6 @@
 from application.vectorstore.vector_creator import VectorCreator
 
 
-
 class ClassicRAG(BaseRetriever):
 
     def __init__(
@@ -74,13 +73,11 @@ def gen(self):
 
         if len(self.chat_history) > 1:
             for i in self.chat_history:
-                if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
-                    messages_combine.append(
-                        {"role": "system", "content": i["response"]}
-                    )
+                if "prompt" in i and "response" in i:
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
+                    messages_combine.append(
+                        {"role": "assistant", "content": i["response"]}
+                    )
         messages_combine.append({"role": "user", "content": self.question})
         # llm = LLMCreator.create_llm(
         #     settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
6 changes: 2 additions & 4 deletions application/retriever/duckduck_search.py
@@ -91,11 +91,9 @@ def gen(self):
         if len(self.chat_history) > 1:
             for i in self.chat_history:
                 if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
                     messages_combine.append(
-                        {"role": "system", "content": i["response"]}
+                        {"role": "assistant", "content": i["response"]}
                     )
         messages_combine.append({"role": "user", "content": self.question})
 
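All three retrievers now replay chat history with the "assistant" role rather than "system", matching what the OpenAI API expects and what _clean_messages_google maps onto Gemini's "model" role. A sketch of the messages list the retrievers build (the history contents are illustrative):

chat_history = [
    {"prompt": "What is DocsGPT?", "response": "An open-source RAG assistant."}
]

messages_combine = [{"role": "system", "content": "You are a helpful assistant."}]
for i in chat_history:
    if "prompt" in i and "response" in i:
        messages_combine.append({"role": "user", "content": i["prompt"]})
        messages_combine.append({"role": "assistant", "content": i["response"]})
messages_combine.append({"role": "user", "content": "How does caching work?"})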
(The remaining 3 of the 11 changed files were not rendered on this page.)
