Skip to content

Commit

Permalink
fix(issue-89): workaround for unmarshalling-strings issue rashadphz#89
Browse files Browse the repository at this point in the history
  • Loading branch information
jandoerntlein committed Sep 8, 2024
1 parent 883003f commit 759fd20
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
9 changes: 7 additions & 2 deletions src/backend/related_queries.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from backend.llm.base import BaseLLM
from backend.prompts import RELATED_QUESTION_PROMPT
from backend.schemas import RelatedQueries, SearchResult

import ast

async def generate_related_queries(
query: str, search_results: list[SearchResult], llm: BaseLLM
) -> list[str]:
context = "\n\n".join([f"{str(result)}" for result in search_results])
context = context[:4000]

related = llm.structured_complete(
RelatedQueries, RELATED_QUESTION_PROMPT.format(query=query, context=context)
)

return [query.lower().replace("?", "") for query in related.related_questions]
# Hotfix Part II (https://github.com/rashadphz/farfalle/issues/82)
if len(related.related_questions) == 1:
related.related_questions = ast.literal_eval(related.related_questions[0])

return [query.lower().replace("?", "").replace("}", "") for query in related.related_questions]
13 changes: 11 additions & 2 deletions src/backend/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dotenv import load_dotenv
from logfire.integrations.pydantic import PluginSettings
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, BaseConfig, validator

from backend.constants import ChatModel
from backend.utils import strtobool
Expand Down Expand Up @@ -39,9 +39,15 @@ class ChatRequest(BaseModel, plugin_settings=record_all):
pro_search: bool = False


# Hotfix Part I (https://github.com/rashadphz/farfalle/issues/82)
class RelatedQueries(BaseModel):
related_questions: List[str] = Field(..., min_length=3, max_length=3)
related_questions: List[str]

@validator('related_questions', pre=True)
def ensure_list(cls, v):
if isinstance(v, str):
return [v]
return v

class SearchResult(BaseModel):
title: str
Expand Down Expand Up @@ -184,6 +190,9 @@ class ChatSnapshot(BaseModel):
preview: str
model_name: str

class Config(BaseConfig):
protected_namespaces = ()


class ChatHistoryResponse(BaseModel):
snapshots: List[ChatSnapshot] = Field(default_factory=list)
Expand Down

0 comments on commit 759fd20

Please sign in to comment.