Feat/local llm bug fix (#1758)
# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):
StanGirard authored Nov 29, 2023
1 parent c6d4566 commit e1cde0f
Showing 10 changed files with 144 additions and 230 deletions.
3 changes: 3 additions & 0 deletions .backend_env.example
@@ -10,6 +10,9 @@ GOOGLE_CLOUD_PROJECT=<change-me>
 CELERY_BROKER_URL=redis://redis:6379/0
 CELEBRY_BROKER_QUEUE_NAME=quivr-preview.fifo

+#LOCAL
+#OLLAMA_API_BASE_URL=http://host.docker.internal:11434 # all-in-one local setup: uncomment to use a local LLM with Ollama
+

 #RESEND
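The new `OLLAMA_API_BASE_URL` default points at Docker's host gateway (`host.docker.internal`), so a backend running inside Docker Compose can reach an Ollama server on the host machine. A quick way to sanity-check the value before uncommenting it is to query Ollama's `/api/tags` endpoint, which lists locally pulled models; a minimal sketch, not part of this commit, assuming the `requests` package is available:

```python
import os

import requests

# Hypothetical check: confirm the configured base URL actually
# reaches an Ollama server before enabling local-LLM mode.
base_url = os.environ.get("OLLAMA_API_BASE_URL", "http://host.docker.internal:11434")
try:
    resp = requests.get(f"{base_url}/api/tags", timeout=5)
    resp.raise_for_status()
    models = [m["name"] for m in resp.json().get("models", [])]
    print(f"Ollama reachable at {base_url}; local models: {models}")
except requests.RequestException as exc:
    print(f"Ollama not reachable at {base_url}: {exc}")
```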
12 changes: 7 additions & 5 deletions backend/llm/api_brain_qa.py
@@ -3,6 +3,7 @@
 from uuid import UUID

 from fastapi import HTTPException
+from logger import get_logger
 from litellm import completion
 from models.chats import ChatQuestion
 from models.databases.supabase.chats import CreateChatHistory
@@ -17,6 +18,7 @@
     get_api_brain_definition_as_json_schema,
 )

+logger = get_logger(__name__)

 class APIBrainQA(
     QABaseBrainPicking,
@@ -53,7 +55,6 @@ async def make_completion(
         brain_id: UUID,
     ):
         yield "🧠<Deciding what to do>🧠"
-
         response = completion(
             model=self.model,
             temperature=self.temperature,
@@ -73,8 +74,7 @@

             if finish_reason == "stop":
                 break
-
-            if "function_call" in chunk.choices[0].delta:
+            if "function_call" in chunk.choices[0].delta and chunk.choices[0].delta["function_call"]:
                 if "name" in chunk.choices[0].delta["function_call"]:
                     function_call["name"] = chunk.choices[0].delta["function_call"][
                         "name"
@@ -86,10 +86,12 @@

             elif finish_reason == "function_call":
                 try:
+                    logger.info(f"Function call: {function_call}")
                     arguments = json.loads(function_call["arguments"])
+
                 except Exception:
                     arguments = {}
-                yield f"🧠<Calling API with arguments {arguments} and brain id {brain_id}>🧠"
+                yield f"🧠<Calling {brain_id} with arguments {arguments}>🧠"

                 try:
                     api_call_response = call_brain_api(
@@ -106,7 +108,7 @@
                 messages.append(
                     {
                         "role": "function",
-                        "name": function_call["name"],
+                        "name": str(brain_id),
                         "content": api_call_response,
                     }
                 )
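The strengthened guard matters because streamed deltas can carry a `function_call` key whose value is `None` (for example on chunks that only carry text content); the old `"function_call" in chunk.choices[0].delta` check passed on such chunks, and the subsequent `"name" in ...` lookup then raised a `TypeError`. A self-contained sketch of the accumulation pattern, using hypothetical delta payloads rather than real API output:

```python
# Sketch of accumulating a function call from streamed deltas.
function_call = {"name": "", "arguments": ""}

deltas = [
    {"function_call": {"name": "get_weather", "arguments": ""}},
    {"content": "...", "function_call": None},  # would break a bare `in` check
    {"function_call": {"arguments": '{"city": "Paris"}'}},
]

for delta in deltas:
    # Both conditions: the key exists AND its value is truthy (not None).
    if "function_call" in delta and delta["function_call"]:
        if "name" in delta["function_call"]:
            function_call["name"] += delta["function_call"]["name"]
        if "arguments" in delta["function_call"]:
            function_call["arguments"] += delta["function_call"]["arguments"]

print(function_call)  # {'name': 'get_weather', 'arguments': '{"city": "Paris"}'}
```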
16 changes: 14 additions & 2 deletions backend/llm/qa_base.py
@@ -7,6 +7,7 @@
 from langchain.chains import ConversationalRetrievalChain, LLMChain
 from langchain.chains.question_answering import load_qa_chain
 from langchain.chat_models import ChatLiteLLM
+from langchain.embeddings.ollama import OllamaEmbeddings
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms.base import BaseLLM
 from langchain.prompts.chat import (
@@ -84,8 +85,13 @@ def _determine_callback_array(
     ]

     @property
-    def embeddings(self) -> OpenAIEmbeddings:
-        return OpenAIEmbeddings()  # pyright: ignore reportPrivateUsage=none
+    def embeddings(self):
+        if self.brain_settings.ollama_api_base_url:
+            return OllamaEmbeddings(
+                base_url=self.brain_settings.ollama_api_base_url
+            )  # pyright: ignore reportPrivateUsage=none
+        else:
+            return OpenAIEmbeddings()

     supabase_client: Optional[Client] = None
     vector_store: Optional[CustomSupabaseVectorStore] = None
@@ -143,13 +149,19 @@ def _create_llm(
         :param callbacks: Callbacks to be used for streaming
         :return: Language model instance
         """
+        api_base = None
+        if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
+            api_base = self.brain_settings.ollama_api_base_url
+
+
         return ChatLiteLLM(
             temperature=temperature,
             max_tokens=self.max_tokens,
             model=model,
             streaming=streaming,
             verbose=False,
             callbacks=callbacks,
+            api_base=api_base,
         )  # pyright: ignore reportPrivateUsage=none

     def _create_prompt_template(self):
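With this change, `_create_llm` passes an explicit `api_base` to LiteLLM only when the model name starts with `ollama` and an Ollama URL is configured, so OpenAI-style models keep their default endpoint. A hedged usage sketch of the same routing logic (the model name and URL are illustrative, not taken from the commit):

```python
from langchain.chat_models import ChatLiteLLM

# Illustrative values; any "ollama/<model>" name gets routed to the local server.
ollama_api_base_url = "http://host.docker.internal:11434"
model = "ollama/llama2"

api_base = None
if ollama_api_base_url and model.startswith("ollama"):
    api_base = ollama_api_base_url

llm = ChatLiteLLM(model=model, api_base=api_base, temperature=0.0, max_tokens=256)
# llm.predict("Hello")  # would hit the local Ollama server via LiteLLM
```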
9 changes: 8 additions & 1 deletion backend/llm/qa_headless.py
@@ -7,6 +7,7 @@
 from langchain.chains import LLMChain
 from langchain.chat_models import ChatLiteLLM
 from langchain.chat_models.base import BaseChatModel
+from models import BrainSettings  # Importing settings related to the 'brain'
 from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
 from logger import get_logger
 from models.chats import ChatQuestion
@@ -30,6 +31,7 @@


 class HeadlessQA(BaseModel):
+    brain_settings = BrainSettings()
     model: str
     temperature: float = 0.0
     max_tokens: int = 2000
@@ -78,13 +80,18 @@ def _create_llm(
         :param callbacks: Callbacks to be used for streaming
         :return: Language model instance
         """
+        api_base = None
+        if self.brain_settings.ollama_api_base_url and model.startswith("ollama"):
+            api_base = self.brain_settings.ollama_api_base_url
+
         return ChatLiteLLM(
-            temperature=0.1,
+            temperature=temperature,
             model=model,
             streaming=streaming,
             verbose=True,
             callbacks=callbacks,
             max_tokens=self.max_tokens,
+            api_base=api_base,
         )

     def _create_prompt_template(self):
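Besides wiring up `api_base`, this hunk fixes a real bug: `HeadlessQA._create_llm` hardcoded `temperature=0.1` and silently ignored the caller's value. A toy illustration (not the project's code) of that failure mode:

```python
# Toy illustration of why hardcoding defeats the parameter.
def create_llm_old(temperature: float) -> dict:
    return {"temperature": 0.1}  # caller's value silently dropped

def create_llm_fixed(temperature: float) -> dict:
    return {"temperature": temperature}

assert create_llm_old(0.9)["temperature"] == 0.1    # bug: always 0.1
assert create_llm_fixed(0.9)["temperature"] == 0.9  # fixed: honors caller
```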
18 changes: 13 additions & 5 deletions backend/models/settings.py
@@ -1,9 +1,13 @@
-from langchain.embeddings.openai import OpenAIEmbeddings
 from models.databases.supabase.supabase import SupabaseDB
 from pydantic import BaseSettings
 from supabase.client import Client, create_client
 from vectorstore.supabase import SupabaseVectorStore
+from langchain.embeddings.ollama import OllamaEmbeddings
+from langchain.embeddings.openai import OpenAIEmbeddings

+from logger import get_logger
+
+logger = get_logger(__name__)

 class BrainRateLimiting(BaseSettings):
     max_brain_per_user: int = 5
@@ -15,6 +19,7 @@ class BrainSettings(BaseSettings):
     supabase_service_key: str
     resend_api_key: str = "null"
     resend_email_address: str = "[email protected]"
+    ollama_api_base_url: str = None


 class ContactsSettings(BaseSettings):
@@ -39,11 +44,14 @@ def get_supabase_db() -> SupabaseDB:
     return SupabaseDB(supabase_client)


-def get_embeddings() -> OpenAIEmbeddings:
+def get_embeddings():
     settings = BrainSettings()  # pyright: ignore reportPrivateUsage=none
-    embeddings = OpenAIEmbeddings(
-        openai_api_key=settings.openai_api_key
-    )  # pyright: ignore reportPrivateUsage=none
+    if settings.ollama_api_base_url:
+        embeddings = OllamaEmbeddings(
+            base_url=settings.ollama_api_base_url,
+        )  # pyright: ignore reportPrivateUsage=none
+    else:
+        embeddings = OpenAIEmbeddings()  # pyright: ignore reportPrivateUsage=none
     return embeddings
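Because `BrainSettings` extends pydantic's `BaseSettings`, the new `ollama_api_base_url` field is populated from the `OLLAMA_API_BASE_URL` environment variable automatically (matching is case-insensitive in pydantic v1), which is what connects the `.backend_env.example` change to the code paths above. A small sketch of that behavior, assuming pydantic v1 as used by this codebase:

```python
import os

from pydantic import BaseSettings  # pydantic v1


class BrainSettings(BaseSettings):
    # Mirrors the field added above; a None default makes it Optional in pydantic v1.
    ollama_api_base_url: str = None


os.environ["OLLAMA_API_BASE_URL"] = "http://host.docker.internal:11434"
print(BrainSettings().ollama_api_base_url)  # -> http://host.docker.internal:11434
```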
7 changes: 4 additions & 3 deletions backend/requirements.txt
@@ -1,8 +1,8 @@
 # pymupdf==1.22.3
-langchain==0.0.332
-litellm==0.13.2
+langchain==0.0.341
+litellm==1.7.7
 # Markdown==3.4.4
-openai==0.27.8
+openai==1.1.1
 GitPython==3.1.36
 pdf2image==1.16.3
 pypdf==3.9.0
@@ -36,3 +36,4 @@ python-dotenv
 pytest-mock
 pytest-celery
 pytesseract==0.3.10
+async_generator
(Diffs for the remaining 4 changed files are not shown.)