Skip to content

Commit

Permalink
Curate sources to avoid the UI crashing (#1212)
Browse files Browse the repository at this point in the history
* Curate sources to avoid the UI crashing

* Remove sources from chat history to avoid confusing the LLM
  • Loading branch information
imartinez authored Nov 12, 2023
1 parent a579c9b commit b764754
Showing 1 changed file with 40 additions and 21 deletions.
61 changes: 40 additions & 21 deletions private_gpt/ui/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,43 @@
from fastapi import FastAPI
from gradio.themes.utils.colors import slate # type: ignore
from llama_index.llms import ChatMessage, ChatResponse, MessageRole
from pydantic import BaseModel

from private_gpt.di import root_injector
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
from private_gpt.server.chunks.chunks_service import ChunksService
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.settings.settings import settings
from private_gpt.ui.images import logo_svg

logger = logging.getLogger(__name__)


UI_TAB_TITLE = "My Private GPT"
SOURCES_SEPARATOR = "\n\n Sources: \n"


class Source(BaseModel):
    """A single cited source displayed in the UI: source file, page label, and chunk text."""

    file: str
    page: str
    text: str

    class Config:
        # Frozen makes instances immutable and hashable, so identical
        # sources collapse when collected into a set.
        frozen = True

    @staticmethod
    def curate_sources(sources: list[Chunk]) -> set["Source"]:
        """Deduplicate retrieved chunks into a set of display-ready sources.

        A chunk whose document has no metadata (or is missing the
        ``file_name`` / ``page_label`` keys) falls back to ``"-"``.
        """

        def _as_source(chunk: Chunk) -> "Source":
            # doc_metadata may be None; treat that the same as missing keys.
            metadata = chunk.document.doc_metadata or {}
            return Source(
                file=metadata.get("file_name", "-"),
                page=metadata.get("page_label", "-"),
                text=chunk.text,
            )

        return {_as_source(chunk) for chunk in sources}


class PrivateGptUi:
Expand All @@ -44,21 +69,11 @@ def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
yield full_response

if completion_gen.sources:
full_response += "\n\n Sources: \n"
sources = (
{
"file": chunk.document.doc_metadata["file_name"]
if chunk.document.doc_metadata
else "",
"page": chunk.document.doc_metadata["page_label"]
if chunk.document.doc_metadata
else "",
}
for chunk in completion_gen.sources
)
full_response += SOURCES_SEPARATOR
cur_sources = Source.curate_sources(completion_gen.sources)
sources_text = "\n\n\n".join(
f"{index}. {source['file']} (page {source['page']})"
for index, source in enumerate(sources, start=1)
f"{index}. {source.file} (page {source.page})"
for index, source in enumerate(cur_sources, start=1)
)
full_response += sources_text
yield full_response
Expand All @@ -70,7 +85,9 @@ def build_history() -> list[ChatMessage]:
[
ChatMessage(content=interaction[0], role=MessageRole.USER),
ChatMessage(
content=interaction[1], role=MessageRole.ASSISTANT
# Remove from history content the Sources information
content=interaction[1].split(SOURCES_SEPARATOR)[0],
role=MessageRole.ASSISTANT,
),
]
for interaction in history
Expand Down Expand Up @@ -103,11 +120,13 @@ def build_history() -> list[ChatMessage]:
text=message, limit=4, prev_next_chunks=0
)

sources = Source.curate_sources(response)

yield "\n\n\n".join(
f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
f"{chunk.text}"
for index, chunk in enumerate(response, start=1)
f"{index}. **{source.file} "
f"(page {source.page})**\n "
f"{source.text}"
for index, source in enumerate(sources, start=1)
)

def _list_ingested_files(self) -> list[list[str]]:
Expand Down

0 comments on commit b764754

Please sign in to comment.