Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add faq consistency check #209

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/domain/rewriting_pipeline_execution_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@

class RewritingPipelineExecutionDTO(BaseModel):
execution: PipelineExecutionDTO
course_id: int = Field(alias="courseId")
to_be_rewritten: str = Field(alias="toBeRewritten")
5 changes: 5 additions & 0 deletions app/domain/status/rewriting_status_update_dto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import List

from app.domain.status.status_update_dto import StatusUpdateDTO


class RewritingStatusUpdateDTO(StatusUpdateDTO):
result: str = ""
suggestions: List[str] = []
inconsistencies: List[str] = []
improvement: str = ""
46 changes: 46 additions & 0 deletions app/pipeline/prompts/faq_consistency_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
faq_consistency_prompt = """
You are an AI assistant responsible for verifying the consistency of information.
### Task:
You have been provided with a list of FAQs and a final result. Your task is to determine whether the
final result is consistent with the given FAQs. Please compare each FAQ with the final result separately.



### Given FAQs:
{faqs}

### Final Result:
{final_result}

### Output:
Generate the following response dictionary:
"type": "consistent" or "inconsistent"

The following four entries to the dictionary are optional and can only be set if inconsistencies are detected:
"faqs": This entry should be a list of Strings, each string represents an FAQ.
-Make sure each faq is separated by comma.
-Also end each faq with a newline character.
-The fields are exactly named faq_id, faq_question_title and faq_question_answer
and reside within properties dict of each list entry.
-Make sure to only include inconsistent faqs
-Do not include any additional FAQs that are consistent with the final_result.

"message": "The provided text was rephrased, however it contains inconsistent information with existing FAQs."
-Localize the message to the language of the ###Final Result.
-Make sure to always insert two new lines after the last character of this sentences.
The affected FAQs can only contain the faq_id, faq_question_title, and faq_question_answer of inconsistent FAQs.
Make sure to not include any additional FAQs, that are consistent with the final_result.
Insert the faq_id, faq_question_title, and faq_question_answer of the inconsistent FAQ in the placeholder.

-"suggestion": This entry is a list of strings, each string represents a suggestion to improve the final result.\n
- Each suggestion should focus on a different inconsistency.
- Each suggestions highlights what is the inconsistency and how it can be improved.
- Do not mention the term final result, call it provided text
- Please ensure that at no time, you have a different amount of suggestions than inconsistencies.\n
Both should have the same amount of entries.

-"improved version": This entry should be a string that represents the improved version of the final result.


Do NOT provide any explanations or additional text.
"""
32 changes: 0 additions & 32 deletions app/pipeline/prompts/faq_rewriting.py

This file was deleted.

91 changes: 85 additions & 6 deletions app/pipeline/rewriting_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
from typing import Literal, Optional
from typing import Literal, Optional, List, Dict

from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import (
Expand All @@ -12,10 +13,13 @@
from app.domain.rewriting_pipeline_execution_dto import RewritingPipelineExecutionDTO
from app.llm import CapabilityRequestHandler, RequirementList, CompletionArguments
from app.pipeline import Pipeline
from app.pipeline.prompts.faq_consistency_prompt import faq_consistency_prompt
from app.pipeline.prompts.rewriting_prompts import (
system_prompt_faq,
system_prompt_problem_statement,
)
from app.retrieval.faq_retrieval import FaqRetrieval
from app.vector_database.database import VectorDatabase
from app.web.status.status_update import RewritingCallback

logger = logging.getLogger(__name__)
Expand All @@ -38,8 +42,11 @@ def __init__(
context_length=16385,
)
)

self.db = VectorDatabase()
self.tokens = []
self.variant = variant
self.faq_retriever = FaqRetrieval(self.db.client)

def __call__(
self,
Expand All @@ -54,10 +61,10 @@ def __call__(
"faq": system_prompt_faq,
"problem_statement": system_prompt_problem_statement,
}
print(variant_prompts[self.variant])
prompt = variant_prompts[self.variant].format(
rewritten_text=dto.to_be_rewritten,
)

format_args = {"rewritten_text": dto.to_be_rewritten}

prompt = variant_prompts[self.variant].format(**format_args)
prompt = PyrisMessage(
sender=IrisMessageRole.SYSTEM,
contents=[TextMessageContentDTO(text_content=prompt)],
Expand All @@ -77,4 +84,76 @@ def __call__(
response = response.strip()

final_result = response
self.callback.done(final_result=final_result, tokens=self.tokens)
inconsistencies = []
improvement = ""
suggestions = []

if self.variant == "faq":
faqs = self.faq_retriever.get_faqs_from_db(
course_id=dto.course_id, search_text=response, result_limit=10
)
consistency_result = self.check_faq_consistency(faqs, final_result)

if "inconsistent" in consistency_result["type"].lower():
logging.warning("Detected inconsistencies in FAQ retrieval.")
inconsistencies = parse_inconsistencies(consistency_result["faqs"])
improvement = consistency_result["improved version"]
suggestions = consistency_result["suggestion"]

self.callback.done(
final_result=final_result,
tokens=self.tokens,
inconsistencies=inconsistencies,
improvement=improvement,
suggestions=suggestions,
)
Comment on lines +87 to +109
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Callback parameters usage is comprehensive, but consider adding error handling.

The code now properly initializes and passes inconsistency-related parameters to the callback. However, there's no error handling for potential issues in the FAQ retrieval or consistency checking process.

Add error handling around the FAQ retrieval and consistency checking to ensure that the pipeline doesn't fail if there are issues with these operations:

        final_result = response
        inconsistencies = []
        improvement = ""
        suggestions = []

        if self.variant == "faq":
-           faqs = self.faq_retriever.get_faqs_from_db(
-               course_id=dto.course_id, search_text=response, result_limit=10
-           )
-           consistency_result = self.check_faq_consistency(faqs, final_result)
-
-           if "inconsistent" in consistency_result["type"].lower():
-               logging.warning("Detected inconsistencies in FAQ retrieval.")
-               inconsistencies = parse_inconsistencies(consistency_result["faqs"])
-               improvement = consistency_result["improved version"]
-               suggestions = consistency_result["suggestion"]
+           try:
+               faqs = self.faq_retriever.get_faqs_from_db(
+                   course_id=dto.course_id, search_text=response, result_limit=10
+               )
+               if faqs:  # Only check consistency if we have FAQs to check against
+                   consistency_result = self.check_faq_consistency(faqs, final_result)
+                   
+                   if "inconsistent" in consistency_result["type"].lower():
+                       logging.warning("Detected inconsistencies in FAQ retrieval.")
+                       inconsistencies = parse_inconsistencies(consistency_result["faqs"])
+                       improvement = consistency_result["improved version"]
+                       suggestions = consistency_result["suggestion"]
+               else:
+                   logging.info("No FAQs retrieved for consistency check.")
+           except Exception as e:
+               logging.error(f"Error during FAQ consistency checking: {str(e)}")
+               # Continue with the pipeline even if FAQ checking fails
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
inconsistencies = []
improvement = ""
suggestions = []
if self.variant == "faq":
faqs = self.faq_retriever.get_faqs_from_db(
course_id=dto.course_id, search_text=response, result_limit=10
)
consistency_result = self.check_faq_consistency(faqs, final_result)
if "inconsistent" in consistency_result["type"].lower():
logging.warning("Detected inconsistencies in FAQ retrieval.")
inconsistencies = parse_inconsistencies(consistency_result["faqs"])
improvement = consistency_result["improved version"]
suggestions = consistency_result["suggestion"]
self.callback.done(
final_result=final_result,
tokens=self.tokens,
inconsistencies=inconsistencies,
improvement=improvement,
suggestions=suggestions,
)
final_result = response
inconsistencies = []
improvement = ""
suggestions = []
if self.variant == "faq":
try:
faqs = self.faq_retriever.get_faqs_from_db(
course_id=dto.course_id, search_text=response, result_limit=10
)
if faqs: # Only check consistency if we have FAQs to check against
consistency_result = self.check_faq_consistency(faqs, final_result)
if "inconsistent" in consistency_result["type"].lower():
logging.warning("Detected inconsistencies in FAQ retrieval.")
inconsistencies = parse_inconsistencies(consistency_result["faqs"])
improvement = consistency_result["improved version"]
suggestions = consistency_result["suggestion"]
else:
logging.info("No FAQs retrieved for consistency check.")
except Exception as e:
logging.error(f"Error during FAQ consistency checking: {str(e)}")
# Continue with the pipeline even if FAQ checking fails
self.callback.done(
final_result=final_result,
tokens=self.tokens,
inconsistencies=inconsistencies,
improvement=improvement,
suggestions=suggestions,
)


def check_faq_consistency(
self, faqs: List[dict], final_result: str
) -> Dict[str, str]:
"""
Checks the consistency of the given FAQs with the provided final_result.
Returns "consistent" if there are no inconsistencies, otherwise returns "inconsistent".

:param faqs: List of retrieved FAQs.
:param final_result: The result to compare the FAQs against.

"""
properties_list = [entry["properties"] for entry in faqs]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Handle potential KeyError for 'properties'.

There's no error handling if 'properties' key is missing from a FAQ entry.

Add error handling for missing 'properties' key in FAQ entries:

-        properties_list = [entry["properties"] for entry in faqs]
+        properties_list = []
+        for entry in faqs:
+            if "properties" in entry:
+                properties_list.append(entry["properties"])
+            else:
+                logging.warning(f"FAQ entry missing 'properties' key: {entry}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
properties_list = [entry["properties"] for entry in faqs]
properties_list = []
for entry in faqs:
if "properties" in entry:
properties_list.append(entry["properties"])
else:
logging.warning(f"FAQ entry missing 'properties' key: {entry}")


consistency_prompt = faq_consistency_prompt.format(
faqs=properties_list, final_result=final_result
)
Comment on lines +123 to +126
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Check for empty FAQs list before processing.

There's no check to see if the FAQs list is empty before trying to process it, which could lead to unexpected behavior in the consistency checking.

Add a check for an empty FAQs list:

        properties_list = [entry["properties"] for entry in faqs]
+        
+        # Return early if no FAQs to check against
+        if not properties_list:
+            return {
+                "type": "consistent",
+                "message": "No FAQs found to check consistency against.",
+                "faqs": [],
+                "suggestion": [],
+                "improved version": "",
+            }

        consistency_prompt = faq_consistency_prompt.format(
            faqs=properties_list, final_result=final_result
        )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
consistency_prompt = faq_consistency_prompt.format(
faqs=properties_list, final_result=final_result
)
properties_list = [entry["properties"] for entry in faqs]
# Return early if no FAQs to check against
if not properties_list:
return {
"type": "consistent",
"message": "No FAQs found to check consistency against.",
"faqs": [],
"suggestion": [],
"improved version": "",
}
consistency_prompt = faq_consistency_prompt.format(
faqs=properties_list, final_result=final_result
)


prompt = PyrisMessage(
sender=IrisMessageRole.SYSTEM,
contents=[TextMessageContentDTO(text_content=consistency_prompt)],
)

response = self.request_handler.chat(
[prompt], CompletionArguments(temperature=0.0), tools=None
)

self._append_tokens(response.token_usage, PipelineEnum.IRIS_REWRITING_PIPELINE)
result = response.contents[0].text_content
data = json.loads(result)

result_dict = {
"type": data["type"],
"message": data["message"],
"faqs": data["faqs"],
"suggestion": data["suggestion"],
"improved version": data["improved version"],
}
logging.info(f"Consistency FAQ consistency check response: {result_dict}")

return result_dict


def parse_inconsistencies(inconsistencies: List[Dict[str, str]]) -> List[str]:
logging.info("parse consistency")
parsed_inconsistencies = [
f"FAQ ID: {entry['faq_id']}, Title: {entry['faq_question_title']}, Answer: {entry['faq_question_answer']}"
for entry in inconsistencies
]
return parsed_inconsistencies
47 changes: 46 additions & 1 deletion app/retrieval/faq_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import List
from langsmith import traceable
from weaviate import WeaviateClient
from weaviate.collections.classes.filters import Filter

from app.common.PipelineEnum import PipelineEnum
from .basic_retrieval import BaseRetrieval, merge_retrieved_chunks
from ..common.pyris_message import PyrisMessage
Expand Down Expand Up @@ -44,7 +46,6 @@ def __call__(
base_url: str = None,
) -> List[dict]:
course_language = self.fetch_course_language(course_id)

response, response_hyde = self.run_parallel_rewrite_tasks(
chat_history=chat_history,
student_query=student_query,
Expand All @@ -67,3 +68,47 @@ def __call__(
for obj in response_hyde.objects
]
return merge_retrieved_chunks(basic_retrieved_faqs, hyde_retrieved_faqs)

def get_faqs_from_db(
self,
course_id: int,
search_text: str = None,
result_limit: int = 10,
hybrid_factor: float = 0.75, # Gewichtung zwischen Vektor- und Textsuche
) -> List[dict]:
"""
Holt FAQs direkt aus der Datenbank, optional mit einer Ähnlichkeitssuche auf question_title und question_answer.

:param course_id: ID des Kurses, um nur FAQs eines bestimmten Kurses zu holen.
:param search_text: Optionaler Suchtext, der für eine semantische Suche verwendet wird.
:param result_limit: Anzahl der gewünschten Ergebnisse.
:param hybrid_factor: Gewichtung zwischen vektorbasierten und keywordbasierten Ergebnissen (0 = nur Vektor, 1 = nur Keywords).
:return: Liste der gefundenen FAQs.
"""
# Filter für den Kurs
filter_weaviate = Filter.by_property("course_id").equal(course_id)

if search_text:
vec = self.llm_embedding.embed(search_text)

response = self.collection.query.hybrid(
query=search_text, # Keyword-Suche
vector=vec, # Vektorbasierte Ähnlichkeitssuche
alpha=hybrid_factor, # Mischung aus Vektor- und Textsuche
return_properties=self.get_schema_properties(),
limit=result_limit,
filters=filter_weaviate,
)
else:
# Falls keine Suchanfrage, einfach nur nach Kurs filtern
response = self.collection.query.fetch_objects(
filters=filter_weaviate,
limit=result_limit,
return_properties=self.get_schema_properties(),
)

faqs = [
{"id": obj.uuid.int, "properties": obj.properties}
for obj in response.objects
]
return faqs
13 changes: 12 additions & 1 deletion app/web/status/status_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def done(
tokens: Optional[List[TokenUsageDTO]] = None,
next_stage_message: Optional[str] = None,
start_next_stage: bool = True,
inconsistencies: Optional[List[str]] = None,
improvement: Optional[str] = None,
):
"""
Transition the current stage to DONE and update the status.
Expand All @@ -122,6 +124,11 @@ def done(
self.status.tokens = tokens or self.status.tokens
if hasattr(self.status, "suggestions"):
self.status.suggestions = suggestions

if hasattr(self.status, "inconsistencies"):
self.status.inconsistencies = inconsistencies
if hasattr(self.status, "improvement"):
self.status.improvement = improvement
next_stage = self.get_next_stage()
if next_stage is not None:
self.stage = next_stage
Expand All @@ -133,6 +140,8 @@ def done(
self.status.result = None
if hasattr(self.status, "suggestions"):
self.status.suggestions = None
if hasattr(self.status, "inconsistencies"):
self.status.inconsistencies = None

def error(
self, message: str, exception=None, tokens: Optional[List[TokenUsageDTO]] = None
Expand Down Expand Up @@ -308,7 +317,9 @@ def __init__(
base_url: str,
initial_stages: List[StageDTO],
):
url = f"{base_url}/api/public/pyris/pipelines/rewriting/runs/{run_id}/status"
url = (
f"{base_url}/api/iris/public/pyris/pipelines/rewriting/runs/{run_id}/status"
)
stages = initial_stages or []
stages.append(
StageDTO(
Expand Down
Loading