From 85bc1e53d6a8e7bcc0b6f90880ee1f10182454cb Mon Sep 17 00:00:00 2001
From: chloedia
Date: Thu, 2 Jan 2025 12:03:47 +0100
Subject: [PATCH] feat: draft summarization nodes

---
 core/quivr_core/brain/brain.py                |   2 +
 core/quivr_core/llm_tools/llm_tools.py        |  15 +-
 .../llm_tools/summarization_tool.py           | 186 ++++++++++++++++++
 core/quivr_core/rag/quivr_rag_langgraph.py    |  20 +-
 core/tests/summarization_config.yaml          |  40 ++++
 core/tests/test_quivr_rag.py                  |   2 +-
 6 files changed, 249 insertions(+), 16 deletions(-)
 create mode 100644 core/quivr_core/llm_tools/summarization_tool.py
 create mode 100644 core/tests/summarization_config.yaml

diff --git a/core/quivr_core/brain/brain.py b/core/quivr_core/brain/brain.py
index b19c7c0a7e3f..e3514a1f17f7 100644
--- a/core/quivr_core/brain/brain.py
+++ b/core/quivr_core/brain/brain.py
@@ -530,6 +530,8 @@ async def ask_streaming(
         if rag_pipeline is None:
             rag_pipeline = QuivrQARAGLangGraph
 
+        logger.debug(f"Using vector db: {self.vector_db}")
+
         rag_instance = rag_pipeline(
             retrieval_config=retrieval_config, llm=llm, vector_store=self.vector_db
         )
diff --git a/core/quivr_core/llm_tools/llm_tools.py b/core/quivr_core/llm_tools/llm_tools.py
index 6e35bdcdc097..2c8814f2c263 100644
--- a/core/quivr_core/llm_tools/llm_tools.py
+++ b/core/quivr_core/llm_tools/llm_tools.py
@@ -1,23 +1,26 @@
-from typing import Dict, Any, Type, Union
+from typing import Any, Dict, Type, Union
 
 from quivr_core.llm_tools.entity import ToolWrapper
-
-from quivr_core.llm_tools.web_search_tools import (
-    WebSearchTools,
-)
-
 from quivr_core.llm_tools.other_tools import (
     OtherTools,
 )
+from quivr_core.llm_tools.summarization_tool import (
+    SummarizationTools,
+)
+from quivr_core.llm_tools.web_search_tools import (
+    WebSearchTools,
+)
 
 TOOLS_CATEGORIES = {
     WebSearchTools.name: WebSearchTools,
+    SummarizationTools.name: SummarizationTools,
     OtherTools.name: OtherTools,
 }
 
 # Register all ToolsList enums
 TOOLS_LISTS = {
     **{tool.value: tool for tool in WebSearchTools.tools},
+    **{tool.value: tool for tool in SummarizationTools.tools},
     **{tool.value: tool for tool in OtherTools.tools},
 }
diff --git a/core/quivr_core/llm_tools/summarization_tool.py b/core/quivr_core/llm_tools/summarization_tool.py
new file mode 100644
index 000000000000..8bae0f77e1cf
--- /dev/null
+++ b/core/quivr_core/llm_tools/summarization_tool.py
@@ -0,0 +1,186 @@
+import operator
+from enum import Enum
+from typing import Annotated, Any, Dict, List, Literal, TypedDict
+
+from langchain.chains.combine_documents.reduce import (
+    acollapse_docs,
+    split_list_of_docs,
+)
+from langchain_core.documents import Document
+from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.prompts import ChatPromptTemplate
+from langgraph.constants import Send
+from langgraph.graph import END, START, StateGraph
+
+from quivr_core.llm_tools.entity import ToolRegistry, ToolsCategory, ToolWrapper
+
+# This code is widely inspired by
+# https://python.langchain.com/docs/tutorials/summarization/
+
+default_map_prompt = "Write a concise summary of the following:\n\n{context}"
+
+default_reduce_prompt = """
+The following is a set of summaries:
+{docs}
+Take these and distill them into a final, consolidated summary
+of the main themes.
+"""
+
+
+class SummaryToolsList(str, Enum):
+    MAPREDUCE = "map_reduce"
+
+
+# This is the overall state of the main graph. It contains the input
+# document contents, the corresponding summaries, and a final summary.
+class OverallState(TypedDict):
+    # Note the operator.add reducer on summaries: the summaries generated
+    # by the individual map nodes are combined back into a single list --
+    # this is essentially the "reduce" part.
+    contents: List[str]
+    summaries: Annotated[list, operator.add]
+    collapsed_summaries: List[Document]
+    final_summary: str
+
+
+# The state of the node that each document is "mapped" to
+# in order to generate a summary.
+class SummaryState(TypedDict):
+    content: str
+
+
+class MPSummarizeTool:
+    def __init__(
+        self,
+        llm: BaseChatModel,
+        token_max: int = 8000,
+        map_prompt: str = default_map_prompt,
+        reduce_prompt: str = default_reduce_prompt,
+    ):
+        self.map_prompt = ChatPromptTemplate.from_messages([("system", map_prompt)])
+        self.reduce_prompt = ChatPromptTemplate.from_messages([("human", reduce_prompt)])
+        self.llm = llm
+        self.token_max = token_max
+
+    def length_function(self, documents: List[Document]) -> int:
+        """Get the number of tokens in the input contents."""
+        return sum(self.llm.get_num_tokens(doc.page_content) for doc in documents)
+
+    # Generate a summary for a single document.
+    async def generate_summary(self, state: SummaryState):
+        prompt = self.map_prompt.invoke(state["content"])
+        response = await self.llm.ainvoke(prompt)
+        return {"summaries": [response.content]}
+
+    # Define the logic to map out over the documents.
+    # This is used as an edge in the graph.
+    def map_summaries(self, state: OverallState):
+        # Return a list of `Send` objects. Each `Send` object consists of the
+        # name of a node in the graph and the state to send to that node.
+        return [
+            Send("generate_summary", {"content": content})
+            for content in state["contents"]
+        ]
+
+    def collect_summaries(self, state: OverallState):
+        return {
+            "collapsed_summaries": [Document(summary) for summary in state["summaries"]]
+        }
+
+    async def _reduce(self, docs: list) -> str:
+        prompt = self.reduce_prompt.invoke(docs)
+        response = await self.llm.ainvoke(prompt)
+        return response.content
+
+    # Collapse summaries that exceed the token budget.
+    async def collapse_summaries(self, state: OverallState):
+        doc_lists = split_list_of_docs(
+            state["collapsed_summaries"], self.length_function, self.token_max
+        )
+        results = []
+        for doc_list in doc_lists:
+            results.append(await acollapse_docs(doc_list, self._reduce))
+
+        return {"collapsed_summaries": results}
+
+    # A conditional edge that determines whether to collapse the
+    # summaries further or to generate the final summary.
+    def should_collapse(
+        self,
+        state: OverallState,
+    ) -> Literal["collapse_summaries", "generate_final_summary"]:
+        num_tokens = self.length_function(state["collapsed_summaries"])
+        if num_tokens > self.token_max:
+            return "collapse_summaries"
+        else:
+            return "generate_final_summary"
+
+    # Generate the final summary.
+    async def generate_final_summary(self, state: OverallState):
+        response = await self._reduce(state["collapsed_summaries"])
+        return {"final_summary": response}
+
+    def build(self):
+        summary_graph = StateGraph(OverallState)
+
+        summary_graph.add_node("generate_summary", self.generate_summary)
+        summary_graph.add_node("collect_summaries", self.collect_summaries)
+        summary_graph.add_node("collapse_summaries", self.collapse_summaries)
+        summary_graph.add_node("generate_final_summary", self.generate_final_summary)
+
+        # Edges:
+        summary_graph.add_conditional_edges(
+            START, self.map_summaries, ["generate_summary"]
+        )
+        summary_graph.add_edge("generate_summary", "collect_summaries")
"collect_summaries") + summary_graph.add_conditional_edges("collect_summaries", self.should_collapse) + summary_graph.add_conditional_edges("collapse_summaries", self.should_collapse) + summary_graph.add_edge("generate_final_summary", END) + + return summary_graph.compile() + + +def create_summary_tool(config: Dict[str, Any]): + summary_tool = MPSummarizeTool( + config.get("map_prompt", None), + config.get("reduce_prompt", None), + config.get("llm", None), + config.get("token_max", None), + ) + + def format_input(task: str) -> Dict[str, Any]: + return {"query": task} + + def format_output(response: Any) -> List[Document]: + return [ + Document( + page_content=d["content"], + ) + for d in response + ] + + return ToolWrapper(summary_tool, format_input, format_output) + + +# Initialize the registry and register tools +summarization_tool_registry = ToolRegistry() +summarization_tool_registry.register_tool( + SummaryToolsList.MAPREDUCE, create_summary_tool +) + + +def create_summarization_tool(tool_name: str, config: Dict[str, Any]) -> ToolWrapper: + return summarization_tool_registry.create_tool(tool_name, config) + + +SummarizationTools = ToolsCategory( + name="Summarization", + description="Tools for summarizing documents", + tools=[SummaryToolsList.MAPREDUCE], + default_tool=SummaryToolsList.MAPREDUCE, + create_tool=create_summarization_tool, +) diff --git a/core/quivr_core/rag/quivr_rag_langgraph.py b/core/quivr_core/rag/quivr_rag_langgraph.py index 3fd3349bd9a9..6fd2b2803b6d 100644 --- a/core/quivr_core/rag/quivr_rag_langgraph.py +++ b/core/quivr_core/rag/quivr_rag_langgraph.py @@ -24,13 +24,12 @@ from langchain_core.messages.ai import AIMessageChunk from langchain_core.prompts.base import BasePromptTemplate from langchain_core.vectorstores import VectorStore +from langfuse.callback import CallbackHandler from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages from langgraph.types import Send from pydantic import BaseModel, Field -from langfuse.callback import CallbackHandler - from quivr_core.llm import LLMEndpoint from quivr_core.llm_tools.llm_tools import LLMToolFactory from quivr_core.rag.entities.chat import ChatHistory @@ -502,7 +501,7 @@ async def rewrite(self, state: AgentState) -> AgentState: task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else [] # Replace each question with its condensed version - for response, task_id in zip(responses, task_ids): + for response, task_id in zip(responses, task_ids, strict=False): tasks.set_definition(task_id, response.content) return {**state, "tasks": tasks} @@ -558,7 +557,7 @@ async def tool_routing(self, state: AgentState): ) task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else [] - for response, task_id in zip(responses, task_ids): + for response, task_id in zip(responses, task_ids, strict=False): tasks.set_completion(task_id, response.is_task_completable) if not response.is_task_completable and response.tool: tasks.set_tool(task_id, response.tool) @@ -599,7 +598,7 @@ async def run_tool(self, state: AgentState) -> AgentState: ) task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else [] - for response, task_id in zip(responses, task_ids): + for response, task_id in zip(responses, task_ids, strict=False): _docs = tool_wrapper.format_output(response) _docs = self.filter_chunks_by_relevance(_docs) tasks.set_docs(task_id, _docs) @@ -634,7 +633,6 @@ async def retrieve(self, state: AgentState) -> AgentState: compression_retriever = ContextualCompressionRetriever( 
diff --git a/core/quivr_core/rag/quivr_rag_langgraph.py b/core/quivr_core/rag/quivr_rag_langgraph.py
index 3fd3349bd9a9..6fd2b2803b6d 100644
--- a/core/quivr_core/rag/quivr_rag_langgraph.py
+++ b/core/quivr_core/rag/quivr_rag_langgraph.py
@@ -24,13 +24,12 @@
 from langchain_core.messages.ai import AIMessageChunk
 from langchain_core.prompts.base import BasePromptTemplate
 from langchain_core.vectorstores import VectorStore
+from langfuse.callback import CallbackHandler
 from langgraph.graph import END, START, StateGraph
 from langgraph.graph.message import add_messages
 from langgraph.types import Send
 from pydantic import BaseModel, Field
 
-from langfuse.callback import CallbackHandler
-
 from quivr_core.llm import LLMEndpoint
 from quivr_core.llm_tools.llm_tools import LLMToolFactory
 from quivr_core.rag.entities.chat import ChatHistory
@@ -502,7 +501,7 @@ async def rewrite(self, state: AgentState) -> AgentState:
         task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
 
         # Replace each question with its condensed version
-        for response, task_id in zip(responses, task_ids):
+        for response, task_id in zip(responses, task_ids, strict=False):
             tasks.set_definition(task_id, response.content)
 
         return {**state, "tasks": tasks}
@@ -558,7 +557,7 @@ async def tool_routing(self, state: AgentState):
         )
         task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
 
-        for response, task_id in zip(responses, task_ids):
+        for response, task_id in zip(responses, task_ids, strict=False):
             tasks.set_completion(task_id, response.is_task_completable)
             if not response.is_task_completable and response.tool:
                 tasks.set_tool(task_id, response.tool)
@@ -599,7 +598,7 @@ async def run_tool(self, state: AgentState) -> AgentState:
         )
         task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
 
-        for response, task_id in zip(responses, task_ids):
+        for response, task_id in zip(responses, task_ids, strict=False):
             _docs = tool_wrapper.format_output(response)
             _docs = self.filter_chunks_by_relevance(_docs)
             tasks.set_docs(task_id, _docs)
@@ -634,7 +633,6 @@ async def retrieve(self, state: AgentState) -> AgentState:
         compression_retriever = ContextualCompressionRetriever(
             base_compressor=reranker, base_retriever=base_retriever
         )
-
         # Prepare the async tasks for all questions
         async_jobs = []
         for task_id in tasks.ids:
@@ -652,7 +650,7 @@ async def retrieve(self, state: AgentState) -> AgentState:
         task_ids = [task[1] for task in async_jobs] if async_jobs else []
 
         # Process responses and associate docs with tasks
-        for response, task_id in zip(responses, task_ids):
+        for response, task_id in zip(responses, task_ids, strict=False):
             _docs = self.filter_chunks_by_relevance(response)
             tasks.set_docs(task_id, _docs)  # Associate docs with the specific task
@@ -715,7 +713,7 @@ async def dynamic_retrieve(self, state: AgentState) -> AgentState:
         task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []
 
         _n = []
-        for response, task_id in zip(responses, task_ids):
+        for response, task_id in zip(responses, task_ids, strict=False):
             _docs = self.filter_chunks_by_relevance(response)
             _n.append(len(_docs))
             tasks.set_docs(task_id, _docs)
@@ -737,6 +735,10 @@ async def dynamic_retrieve(self, state: AgentState) -> AgentState:
 
         return {**state, "tasks": tasks}
 
+    def retrieve_all_chunks_from_file(self, file_id: UUID) -> List[Document]:
+        # VectorStore.get_by_ids expects a positional sequence of string ids
+        return self.vector_store.get_by_ids([str(file_id)])
+
     def _sort_docs_by_relevance(self, docs: List[Document]) -> List[Document]:
         return sorted(
             docs,
@@ -1012,7 +1014,7 @@ def invoke_structured_output(
             )
             return structured_llm.invoke(prompt)
         except openai.BadRequestError:
-            structured_llm = self.llm_endpoint._llm.with_stuctured_output(output_class)
+            structured_llm = self.llm_endpoint._llm.with_structured_output(output_class)
             return structured_llm.invoke(prompt)
 
     def _build_rag_prompt_inputs(
diff --git a/core/tests/summarization_config.yaml b/core/tests/summarization_config.yaml
new file mode 100644
index 000000000000..c00745132bcd
--- /dev/null
+++ b/core/tests/summarization_config.yaml
@@ -0,0 +1,40 @@
+ingestion_config:
+  parser_config:
+    megaparse_config:
+      strategy: "fast"
+      pdf_parser: "unstructured"
+    splitter_config:
+      chunk_size: 400
+      chunk_overlap: 100
+
+
+retrieval_config:
+  workflow_config:
+    name: "Summarizer"
+    available_tools:
+      - "Summarization"
+    nodes:
+      - name: "START"
+        edges: ["retrieve_all_chunks_from_file"]
+
+      - name: "retrieve_all_chunks_from_file"
+        edges: ["run_tool"]
+
+      - name: "run_tool"
+        edges: ["END"]
+
+  llm_config:
+    # The LLM supplier to use
+    supplier: "openai"
+
+    # The model to use for the given supplier
+    model: "gpt-3.5-turbo-0125"
+
+    # Maximum number of tokens to pass to the LLM
+    # as context to generate the answer
+    max_context_tokens: 2000
+
+    # Maximum number of tokens the LLM may generate in the answer
+    max_output_tokens: 2000
+
+    temperature: 0.7
+    streaming: true
diff --git a/core/tests/test_quivr_rag.py b/core/tests/test_quivr_rag.py
index f6184bf16655..84405292bc03 100644
--- a/core/tests/test_quivr_rag.py
+++ b/core/tests/test_quivr_rag.py
@@ -1,9 +1,9 @@
 from uuid import uuid4
 
 import pytest
+from quivr_core.llm import LLMEndpoint
 from quivr_core.rag.entities.chat import ChatHistory
 from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
-from quivr_core.llm import LLMEndpoint
 from quivr_core.rag.entities.models import ParsedRAGChunkResponse, RAGResponseMetadata
 from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph
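
Note for reviewers: a small, hypothetical sanity check for the Summarizer
workflow wiring declared in core/tests/summarization_config.yaml. It is not
part of the patch and only assumes PyYAML and the file path created above;
the expected output is the START -> retrieve_all_chunks_from_file ->
run_tool -> END chain.

    # Inspect the workflow config added by this patch (assumes PyYAML).
    import yaml

    with open("core/tests/summarization_config.yaml") as f:
        config = yaml.safe_load(f)

    workflow = config["retrieval_config"]["workflow_config"]
    print(f"workflow: {workflow['name']}, tools: {workflow['available_tools']}")
    for node in workflow["nodes"]:
        # Each node lists the nodes it hands off to.
        print(f"{node['name']} -> {node['edges']}")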