Skip to content

Commit

Permalink
Paraphrase prompt fix & cross encoder (#414)
Browse files Browse the repository at this point in the history
* added cross-encoder
* Skip rerank when 1 or fewer results
* updated diagram
* add requirement
* removed use of SERVICE_IDENTITY
* added check for n_top:
  • Loading branch information
sidravi1 authored Aug 30, 2024
1 parent e2ac83b commit d58417c
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 23 deletions.
9 changes: 8 additions & 1 deletion core_backend/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi.middleware.cors import CORSMiddleware
from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
from redis import asyncio as aioredis
from sentence_transformers import CrossEncoder

from . import (
admin,
Expand All @@ -20,7 +21,7 @@
urgency_rules,
user_tools,
)
from .config import DOMAIN, LANGFUSE, REDIS_HOST
from .config import CROSS_ENCODER_MODEL, DOMAIN, LANGFUSE, REDIS_HOST, USE_CROSS_ENCODER
from .prometheus_middleware import PrometheusMiddleware
from .utils import setup_logger

Expand Down Expand Up @@ -92,7 +93,13 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:

logger.info("Application started")
app.state.redis = await aioredis.from_url(REDIS_HOST)
if USE_CROSS_ENCODER == "True":
app.state.crossencoder = CrossEncoder(
CROSS_ENCODER_MODEL,
)

yield

await app.state.redis.close()
logger.info("Application finished")

Expand Down
6 changes: 6 additions & 0 deletions core_backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@
SERVICE_IDENTITY = os.environ.get(
"SERVICE_IDENTITY", "air pollution and air quality chatbot"
)
# Cross-encoder
USE_CROSS_ENCODER = os.environ.get("USE_CROSS_ENCODER", "True")
CROSS_ENCODER_MODEL = os.environ.get(
"CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2"
)

# Rate limit variables
CHECK_CONTENT_LIMIT = os.environ.get("CHECK_CONTENT_LIMIT", True)
DEFAULT_CONTENT_QUOTA = int(os.environ.get("DEFAULT_CONTENT_QUOTA", 50))
Expand Down
6 changes: 3 additions & 3 deletions core_backend/app/llm_call/llm_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,11 @@ def get_prompt(cls) -> str:
},
]
PARAPHRASE_PROMPT = f"""You are a high-performing paraphrasing bot. \
The user has sent a message.
The user has sent a message for a question-answering service.
If the message is a question, do not answer it, \
just paraphrase it to remove unecessary information and focus on the question. \
Remove any irrelevant or offensive words.
just paraphrase it to focus on the question and include any relevant information.\
Remove any irrelevant or offensive words
If the input message is not a question, respond with the same message but \
remove any irrelevant or offensive words.
Expand Down
1 change: 1 addition & 0 deletions core_backend/app/question_answer/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

# Functionality variables
N_TOP_CONTENT_TO_CROSSENCODER = os.environ.get("N_TOP_CONTENT_TO_CROSSENCODER", "10")
N_TOP_CONTENT = os.environ.get("N_TOP_CONTENT", "4")
89 changes: 72 additions & 17 deletions core_backend/app/question_answer/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from typing import Tuple

from fastapi import APIRouter, Depends, status
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from ..auth.dependencies import authenticate_key, rate_limiter
from ..config import CUSTOM_SPEECH_ENDPOINT, GCS_SPEECH_BUCKET
from ..config import CUSTOM_SPEECH_ENDPOINT, GCS_SPEECH_BUCKET, USE_CROSS_ENCODER
from ..contents.models import (
get_similar_content_async,
increment_query_count,
Expand All @@ -30,6 +31,7 @@
generate_llm_query_response,
generate_tts__after,
)
from ..schemas import QuerySearchResult
from ..users.models import UserDB
from ..utils import (
create_langfuse_metadata,
Expand All @@ -39,7 +41,7 @@
setup_logger,
upload_file_to_gcs,
)
from .config import N_TOP_CONTENT
from .config import N_TOP_CONTENT, N_TOP_CONTENT_TO_CROSSENCODER
from .models import (
QueryDB,
check_secret_key_match,
Expand Down Expand Up @@ -88,6 +90,7 @@
)
async def search(
user_query: QueryBase,
request: Request,
asession: AsyncSession = Depends(get_async_session),
user_db: UserDB = Depends(authenticate_key),
) -> QueryResponse | JSONResponse:
Expand All @@ -114,8 +117,10 @@ async def search(
response=response_template,
user_id=user_db.user_id,
n_similar=int(N_TOP_CONTENT),
n_to_crossencoder=int(N_TOP_CONTENT_TO_CROSSENCODER),
asession=asession,
exclude_archived=True,
request=request,
)

if user_query.generate_llm_response:
Expand All @@ -138,17 +143,18 @@ async def search(
asession=asession,
)

if type(response) is QueryResponse:
if isinstance(response, QueryResponse):
return response
elif type(response) is QueryResponseError:

if isinstance(response, QueryResponseError):
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump()
)
else:
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"message": "Internal server error"},
)

return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"message": "Internal server error"},
)


@router.post(
Expand All @@ -167,6 +173,7 @@ async def search(
)
async def voice_search(
file_url: str,
request: Request,
asession: AsyncSession = Depends(get_async_session),
user_db: UserDB = Depends(authenticate_key),
) -> QueryAudioResponse | JSONResponse:
Expand Down Expand Up @@ -222,8 +229,10 @@ async def voice_search(
response=response_template,
user_id=user_db.user_id,
n_similar=int(N_TOP_CONTENT),
n_to_crossencoder=int(N_TOP_CONTENT_TO_CROSSENCODER),
asession=asession,
exclude_archived=True,
request=request,
)

if user_query.generate_llm_response:
Expand All @@ -250,17 +259,18 @@ async def voice_search(
os.remove(file_path)
file_stream.close()

if type(response) is QueryAudioResponse:
if isinstance(response, QueryAudioResponse):
return response
elif type(response) is QueryResponseError:

if isinstance(response, QueryResponseError):
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST, content=response.model_dump()
)
else:
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"error": "Internal server error"},
)

return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"error": "Internal server error"},
)

except ValueError as ve:
logger.error(f"ValueError: {str(ve)}")
Expand Down Expand Up @@ -328,7 +338,9 @@ async def get_search_response(
response: QueryResponse,
user_id: int,
n_similar: int,
n_to_crossencoder: int,
asession: AsyncSession,
request: Request,
exclude_archived: bool = True,
) -> QueryResponse | QueryResponseError:
"""Get similar content and construct the LLM answer for the user query.
Expand All @@ -347,8 +359,12 @@ async def get_search_response(
The ID of the user making the query.
n_similar
The number of similar contents to retrieve.
n_to_crossencoder
The number of similar contents to send to the cross-encoder.
asession
`AsyncSession` object for database transactions.
request
The FastAPI request object.
exclude_archived
Specifies whether to exclude archived content.
Expand All @@ -362,19 +378,56 @@ async def get_search_response(
# always do the embeddings search even if some guardrails have failed
metadata = create_langfuse_metadata(query_id=response.query_id, user_id=user_id)

if USE_CROSS_ENCODER == "True" and (n_to_crossencoder < n_similar):
raise ValueError(
"`n_to_crossencoder` must be less than or equal to `n_similar`."
)

search_results = await get_similar_content_async(
user_id=user_id,
question=query_refined.query_text, # use latest transformed version of the text
n_similar=n_similar,
n_similar=n_to_crossencoder if USE_CROSS_ENCODER == "True" else n_similar,
asession=asession,
metadata=metadata,
exclude_archived=exclude_archived,
)

if USE_CROSS_ENCODER and (len(search_results) > 1):
search_results = rerank_search_results(
n_similar=n_similar,
search_results=search_results,
query_text=query_refined.query_text,
request=request,
)

response.search_results = search_results

return response


def rerank_search_results(
    search_results: dict[int, QuerySearchResult],
    n_similar: int,
    query_text: str,
    request: Request,
) -> dict[int, QuerySearchResult]:
    """Rerank search results with the app-level cross-encoder model.

    Each (query, content) pair is scored by the cross-encoder stored on
    ``request.app.state.crossencoder`` (loaded at application startup when
    ``USE_CROSS_ENCODER == "True"``), and the top ``n_similar`` results are
    kept in descending score order.

    Parameters
    ----------
    search_results
        Mapping of rank to retrieved content to be re-ordered.
    n_similar
        Number of top-scoring results to keep after reranking.
    query_text
        The (refined) user query text scored against each content.
    request
        The FastAPI request object; provides the cross-encoder instance.

    Returns
    -------
    dict[int, QuerySearchResult]
        The top ``n_similar`` results, re-keyed 0..n-1 by descending score.
    """
    encoder = request.app.state.crossencoder
    contents = search_results.values()
    scores = encoder.predict(
        [(query_text, content.title + "\n" + content.text) for content in contents]
    )

    # Sort on the score alone: without an explicit key, tied scores would
    # make `sorted` compare the QuerySearchResult objects themselves, which
    # raises TypeError since they define no ordering.
    sorted_by_score = [
        content
        for _, content in sorted(
            zip(scores, contents), key=lambda pair: pair[0], reverse=True
        )
    ][:n_similar]

    return dict(enumerate(sorted_by_score))


@generate_tts__after
@check_align_score__after
async def get_generation_response(
Expand Down Expand Up @@ -418,6 +471,8 @@ async def get_user_query_and_response(
The user query database object.
asession
`AsyncSession` object for database transactions.
generate_tts
Specifies whether to generate a TTS audio response
Returns
-------
Expand Down
1 change: 1 addition & 0 deletions core_backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ google-cloud-storage==2.18.2
google-cloud-texttospeech==2.16.5
google-cloud-speech==2.27.0
pydub==0.25.1
sentence-transformers==3.0.1
8 changes: 6 additions & 2 deletions docs/components/qa-service/search.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ sequenceDiagram
LLM->>AAQ: <Translated text>
AAQ->>LLM: Paraphrase question
LLM->>AAQ: <Paraphrased question>
AAQ->>Vector DB: Request N most similar contents in DB
Vector DB->>AAQ: <N contents with similarity score>
AAQ->>Vector DB: Request M most similar contents in DB
Vector DB->>AAQ: <M contents with similarity score>
AAQ->>Cross-encoder: Re-rank to get top N contents
Cross-encoder->>AAQ: <N contents with similarity score>
AAQ->>User: Return JSON of N contents
```
Expand All @@ -37,6 +39,8 @@ sequenceDiagram
LLM->>AAQ: <Safety Classification>
AAQ->>Vector DB: Request N most similar contents in DB
Vector DB->>AAQ: <N contents with similarity score>
AAQ->>Cross-encoder: Re-rank to get top N contents
Cross-encoder->>AAQ: <N contents with similarity score>
AAQ->>LLM: Given contents, construct response in user's language to question
LLM->>AAQ: <LLM response>
AAQ->>LLM: Check if LLM response is consistent with contents
Expand Down

0 comments on commit d58417c

Please sign in to comment.