Googleai compatibility tools #1594

Merged · 16 commits · Jan 21, 2025
8 changes: 4 additions & 4 deletions application/cache.py
@@ -33,7 +33,7 @@ def gen_cache_key(messages, model="docgpt", tools=None):
     if not all(isinstance(msg, dict) for msg in messages):
         raise ValueError("All messages must be dictionaries.")
     messages_str = json.dumps(messages)
-    tools_str = json.dumps(tools) if tools else ""
+    tools_str = json.dumps(str(tools)) if tools else ""
     combined = f"{model}_{messages_str}_{tools_str}"
     cache_key = get_hash(combined)
     return cache_key
@@ -68,8 +68,8 @@ def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
 
 
 def stream_cache(func):
-    def wrapper(self, model, messages, stream, *args, **kwargs):
-        cache_key = gen_cache_key(messages)
+    def wrapper(self, model, messages, stream, tools=None, *args, **kwargs):
+        cache_key = gen_cache_key(messages, model, tools)
         logger.info(f"Stream cache key: {cache_key}")
 
         redis_client = get_redis_instance()
@@ -86,7 +86,7 @@ def wrapper(self, model, messages, stream, *args, **kwargs):
         except redis.ConnectionError as e:
             logger.error(f"Redis connection error: {e}")
 
-        result = func(self, model, messages, stream, *args, **kwargs)
+        result = func(self, model, messages, stream, tools=tools, *args, **kwargs)
         stream_cache_data = []
 
         for chunk in result:
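To see what the tools-aware key buys, here is a minimal standalone sketch (the md5-based get_hash is a stand-in for the project's real hashing helper, which is not shown in this diff):

# Standalone sketch: with model and tools folded into the key, tool-enabled
# calls no longer collide with plain calls that share the same messages.
import hashlib
import json

def get_hash(text):  # stand-in for the real helper used by application/cache.py
    return hashlib.md5(text.encode()).hexdigest()

def gen_cache_key(messages, model="docgpt", tools=None):
    messages_str = json.dumps(messages)
    tools_str = json.dumps(str(tools)) if tools else ""
    return get_hash(f"{model}_{messages_str}_{tools_str}")

messages = [{"role": "user", "content": "hi"}]
# Same messages, different tools -> different cache entries.
assert gen_cache_key(messages) != gen_cache_key(messages, tools=[{"type": "function"}])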
35 changes: 28 additions & 7 deletions application/llm/base.py
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 
+from application.cache import gen_cache, stream_cache
 from application.usage import gen_token_usage, stream_token_usage
-from application.cache import stream_cache, gen_cache
 
 
 class BaseLLM(ABC):
@@ -18,18 +19,38 @@ def _raw_gen(self, model, messages, stream, tools, *args, **kwargs):
 
     def gen(self, model, messages, stream=False, tools=None, *args, **kwargs):
         decorators = [gen_token_usage, gen_cache]
-        return self._apply_decorator(self._raw_gen, decorators=decorators, model=model, messages=messages, stream=stream, tools=tools, *args, **kwargs)
+        return self._apply_decorator(
+            self._raw_gen,
+            decorators=decorators,
+            model=model,
+            messages=messages,
+            stream=stream,
+            tools=tools,
+            *args,
+            **kwargs
+        )
 
     @abstractmethod
     def _raw_gen_stream(self, model, messages, stream, *args, **kwargs):
         pass
 
-    def gen_stream(self, model, messages, stream=True, *args, **kwargs):
+    def gen_stream(self, model, messages, stream=True, tools=None, *args, **kwargs):
         decorators = [stream_cache, stream_token_usage]
-        return self._apply_decorator(self._raw_gen_stream, decorators=decorators, model=model, messages=messages, stream=stream, *args, **kwargs)
-
+        return self._apply_decorator(
+            self._raw_gen_stream,
+            decorators=decorators,
+            model=model,
+            messages=messages,
+            stream=stream,
+            tools=tools,
+            *args,
+            **kwargs
+        )
 
     def supports_tools(self):
-        return hasattr(self, '_supports_tools') and callable(getattr(self, '_supports_tools'))
+        return hasattr(self, "_supports_tools") and callable(
+            getattr(self, "_supports_tools")
+        )
 
     def _supports_tools(self):
-        raise NotImplementedError("Subclass must implement _supports_tools method")
+        raise NotImplementedError("Subclass must implement _supports_tools method")
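The duck-typed capability check can be exercised in isolation. A trimmed sketch (the abstract generation methods are omitted, and the default _supports_tools is deliberately left off the base class):

# Trimmed sketch of the supports_tools() gating.
class Base:
    def supports_tools(self):
        return hasattr(self, "_supports_tools") and callable(
            getattr(self, "_supports_tools")
        )

class ToolCapable(Base):
    def _supports_tools(self):
        return True

class Plain(Base):
    pass

print(ToolCapable().supports_tools())  # True
print(Plain().supports_tools())        # False

Note that in the PR itself BaseLLM defines a default _supports_tools, so hasattr succeeds on every subclass and supports_tools() returns True across the board; the sketch drops the default to show the gating the check appears to intend.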
151 changes: 126 additions & 25 deletions application/llm/google_ai.py
@@ -1,48 +1,149 @@
+from google import genai
+from google.genai import types
+
 from application.llm.base import BaseLLM
 
-class GoogleLLM(BaseLLM):
+
+class GoogleLLM(BaseLLM):
     def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
-
         super().__init__(*args, **kwargs)
         self.api_key = api_key
         self.user_api_key = user_api_key
 
     def _clean_messages_google(self, messages):
-        return [
-            {
-                "role": "model" if message["role"] == "system" else message["role"],
-                "parts": [message["content"]],
-            }
-            for message in messages[1:]
-        ]
+        cleaned_messages = []
+        for message in messages:
+            role = message.get("role")
+            content = message.get("content")
+
+            if role == "assistant":
+                role = "model"
+
+            parts = []
+            if role and content is not None:
+                if isinstance(content, str):
+                    parts = [types.Part.from_text(content)]
+                elif isinstance(content, list):
+                    for item in content:
+                        if "text" in item:
+                            parts.append(types.Part.from_text(item["text"]))
+                        elif "function_call" in item:
+                            parts.append(
+                                types.Part.from_function_call(
+                                    name=item["function_call"]["name"],
+                                    args=item["function_call"]["args"],
+                                )
+                            )
+                        elif "function_response" in item:
+                            parts.append(
+                                types.Part.from_function_response(
+                                    name=item["function_response"]["name"],
+                                    response=item["function_response"]["response"],
+                                )
+                            )
+                        else:
+                            raise ValueError(
+                                f"Unexpected content dictionary format:{item}"
+                            )
+                else:
+                    raise ValueError(f"Unexpected content type: {type(content)}")
+
+            cleaned_messages.append(types.Content(role=role, parts=parts))
+
+        return cleaned_messages
+
+    def _clean_tools_format(self, tools_list):
+        genai_tools = []
+        for tool_data in tools_list:
+            if tool_data["type"] == "function":
+                function = tool_data["function"]
+                genai_function = dict(
+                    name=function["name"],
+                    description=function["description"],
+                    parameters={
+                        "type": "OBJECT",
+                        "properties": {
+                            k: {
+                                **v,
+                                "type": v["type"].upper() if v["type"] else None,
+                            }
+                            for k, v in function["parameters"]["properties"].items()
+                        },
+                        "required": (
+                            function["parameters"]["required"]
+                            if "required" in function["parameters"]
+                            else []
+                        ),
+                    },
+                )
+                genai_tool = types.Tool(function_declarations=[genai_function])
+                genai_tools.append(genai_tool)
+
+        return genai_tools
 
     def _raw_gen(
         self,
         baseself,
         model,
         messages,
         stream=False,
-        **kwargs
-    ):
-        import google.generativeai as genai
-        genai.configure(api_key=self.api_key)
-        model = genai.GenerativeModel(model, system_instruction=messages[0]["content"])
-        response = model.generate_content(self._clean_messages_google(messages))
-        return response.text
+        tools=None,
+        formatting="openai",
+        **kwargs,
+    ):
+        client = genai.Client(api_key=self.api_key)
+        if formatting == "openai":
+            messages = self._clean_messages_google(messages)
+        config = types.GenerateContentConfig()
+        if messages[0].role == "system":
+            config.system_instruction = messages[0].parts[0].text
+            messages = messages[1:]
+
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
+            response = client.models.generate_content(
+                model=model,
+                contents=messages,
+                config=config,
+            )
+            return response
+        else:
+            response = client.models.generate_content(
+                model=model, contents=messages, config=config
+            )
+            return response.text
 
     def _raw_gen_stream(
         self,
         baseself,
         model,
         messages,
         stream=True,
-        **kwargs
-    ):
-        import google.generativeai as genai
-        genai.configure(api_key=self.api_key)
-        model = genai.GenerativeModel(model, system_instruction=messages[0]["content"])
-        response = model.generate_content(self._clean_messages_google(messages), stream=True)
-        for line in response:
-            if line.text is not None:
-                yield line.text
+        tools=None,
+        formatting="openai",
+        **kwargs,
+    ):
+        client = genai.Client(api_key=self.api_key)
+        if formatting == "openai":
+            messages = self._clean_messages_google(messages)
+        config = types.GenerateContentConfig()
+        if messages[0].role == "system":
+            config.system_instruction = messages[0].parts[0].text
+            messages = messages[1:]
+
+        if tools:
+            cleaned_tools = self._clean_tools_format(tools)
+            config.tools = cleaned_tools
+
+        response = client.models.generate_content_stream(
+            model=model,
+            contents=messages,
+            config=config,
+        )
+        for chunk in response:
+            if chunk.text is not None:
+                yield chunk.text
+
+    def _supports_tools(self):
+        return True
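An end-to-end sketch of the new tool path (the model name, API key, and get_weather schema are illustrative, and gen() also runs the cache and token-usage decorators, so the surrounding app context is assumed):

# Illustrative only: an OpenAI-style tool spec, which _clean_tools_format()
# converts into google.genai types.Tool objects ("string" becomes "STRING").
from application.llm.google_ai import GoogleLLM

llm = GoogleLLM(api_key="YOUR_GEMINI_API_KEY")  # placeholder key

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the weather in Paris?"},
]
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "properties": {
                    "city": {"type": "string", "description": "City name"},
                },
                "required": ["city"],
            },
        },
    }
]

if llm.supports_tools():
    # With tools, _raw_gen returns the full response object so callers can
    # inspect function calls; without tools it returns response.text only.
    response = llm.gen(model="gemini-2.0-flash", messages=messages, tools=tools)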
20 changes: 7 additions & 13 deletions application/llm/openai.py
@@ -1,6 +1,5 @@
-from application.llm.base import BaseLLM
 from application.core.settings import settings
-
+from application.llm.base import BaseLLM
 
 
 class OpenAILLM(BaseLLM):
@@ -10,10 +9,7 @@ def __init__(self, api_key=None, user_api_key=None, *args, **kwargs):
 
         super().__init__(*args, **kwargs)
         if settings.OPENAI_BASE_URL:
-            self.client = OpenAI(
-                api_key=api_key,
-                base_url=settings.OPENAI_BASE_URL
-            )
+            self.client = OpenAI(api_key=api_key, base_url=settings.OPENAI_BASE_URL)
         else:
             self.client = OpenAI(api_key=api_key)
         self.api_key = api_key
@@ -27,8 +23,8 @@ def _raw_gen(
         stream=False,
         tools=None,
         engine=settings.AZURE_DEPLOYMENT_NAME,
-        **kwargs
-    ):
+        **kwargs,
+    ):
         if tools:
             response = self.client.chat.completions.create(
                 model=model, messages=messages, stream=stream, tools=tools, **kwargs
@@ -48,18 +44,16 @@ def _raw_gen_stream(
         stream=True,
         tools=None,
         engine=settings.AZURE_DEPLOYMENT_NAME,
-        **kwargs
-    ):
+        **kwargs,
+    ):
         response = self.client.chat.completions.create(
             model=model, messages=messages, stream=stream, **kwargs
         )
 
         for line in response:
-            # import sys
-            # print(line.choices[0].delta.content, file=sys.stderr)
             if line.choices[0].delta.content is not None:
                 yield line.choices[0].delta.content
 
     def _supports_tools(self):
         return True
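The streaming path is consumed the same way as before; a short sketch (placeholder key and model name, and the stream_cache decorator assumes a reachable Redis):

# Sketch: gen_stream() now threads tools through to _raw_gen_stream and
# yields content deltas as plain strings.
from application.llm.openai import OpenAILLM

llm = OpenAILLM(api_key="YOUR_OPENAI_API_KEY")  # placeholder key
for token in llm.gen_stream(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Say hello"}],
):
    print(token, end="", flush=True)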
2 changes: 2 additions & 0 deletions application/requirements.txt
@@ -14,6 +14,8 @@ esutils==1.0.1
 Flask==3.1.0
 faiss-cpu==1.9.0.post1
 flask-restx==1.3.0
+google-genai==0.5.0
+google-generativeai==0.8.3
 gTTS==2.5.4
 gunicorn==23.0.0
 html2text==2024.2.26
10 changes: 4 additions & 6 deletions application/retriever/brave_search.py
@@ -74,12 +74,10 @@ def gen(self):
         if len(self.chat_history) > 1:
             for i in self.chat_history:
                 if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
-                    messages_combine.append(
-                        {"role": "system", "content": i["response"]}
-                    )
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
+                    messages_combine.append(
+                        {"role": "assistant", "content": i["response"]}
+                    )
         messages_combine.append({"role": "user", "content": self.question})
 
         llm = LLMCreator.create_llm(
13 changes: 5 additions & 8 deletions application/retriever/classic_rag.py
@@ -5,7 +5,6 @@
 from application.vectorstore.vector_creator import VectorCreator
 
 
-
 class ClassicRAG(BaseRetriever):
 
     def __init__(
@@ -74,13 +73,11 @@ def gen(self):
 
         if len(self.chat_history) > 1:
             for i in self.chat_history:
-                if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
-                    messages_combine.append(
-                        {"role": "system", "content": i["response"]}
-                    )
+                if "prompt" in i and "response" in i:
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
+                    messages_combine.append(
+                        {"role": "assistant", "content": i["response"]}
+                    )
         messages_combine.append({"role": "user", "content": self.question})
         # llm = LLMCreator.create_llm(
         #    settings.LLM_NAME, api_key=settings.API_KEY, user_api_key=self.user_api_key
6 changes: 2 additions & 4 deletions application/retriever/duckduck_search.py
@@ -91,11 +91,9 @@ def gen(self):
         if len(self.chat_history) > 1:
             for i in self.chat_history:
                 if "prompt" in i and "response" in i:
-                    messages_combine.append(
-                        {"role": "user", "content": i["prompt"]}
-                    )
+                    messages_combine.append({"role": "user", "content": i["prompt"]})
                     messages_combine.append(
-                        {"role": "system", "content": i["response"]}
+                        {"role": "assistant", "content": i["response"]}
                     )
         messages_combine.append({"role": "user", "content": self.question})
