feat: draft summarization nodes
chloedia committed Jan 2, 2025
1 parent e0ccd3d commit 85bc1e5
Showing 6 changed files with 249 additions and 16 deletions.
2 changes: 2 additions & 0 deletions core/quivr_core/brain/brain.py
@@ -530,6 +530,8 @@ async def ask_streaming(
if rag_pipeline is None:
rag_pipeline = QuivrQARAGLangGraph

logger.info(f"LLLL Using vector db : {self.vector_db}")

rag_instance = rag_pipeline(
retrieval_config=retrieval_config, llm=llm, vector_store=self.vector_db
)
15 changes: 9 additions & 6 deletions core/quivr_core/llm_tools/llm_tools.py
@@ -1,23 +1,26 @@
from typing import Dict, Any, Type, Union
from typing import Any, Dict, Type, Union

from quivr_core.llm_tools.entity import ToolWrapper

from quivr_core.llm_tools.web_search_tools import (
WebSearchTools,
)

from quivr_core.llm_tools.other_tools import (
OtherTools,
)
from quivr_core.llm_tools.summarization_tool import (
SummarizationTools,
)
from quivr_core.llm_tools.web_search_tools import (
WebSearchTools,
)

TOOLS_CATEGORIES = {
WebSearchTools.name: WebSearchTools,
SummarizationTools.name: SummarizationTools,
OtherTools.name: OtherTools,
}

# Register all ToolsList enums
TOOLS_LISTS = {
**{tool.value: tool for tool in WebSearchTools.tools},
**{tool.value: tool for tool in SummarizationTools.tools},
**{tool.value: tool for tool in OtherTools.tools},
}
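For reference (not part of the diff), a quick sketch of how the new registry entries resolve once this module imports cleanly; it only uses names defined in the hunk above:

```python
from quivr_core.llm_tools.llm_tools import TOOLS_CATEGORIES, TOOLS_LISTS
from quivr_core.llm_tools.summarization_tool import SummarizationTools, SummaryToolsList

# Category lookup: the Summarization category declares map_reduce as its default tool.
category = TOOLS_CATEGORIES[SummarizationTools.name]
assert category.default_tool == SummaryToolsList.MAPREDUCE

# Per-tool lookup by string value, as built in TOOLS_LISTS above.
assert TOOLS_LISTS["map_reduce"] == SummaryToolsList.MAPREDUCE
```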

186 changes: 186 additions & 0 deletions core/quivr_core/llm_tools/summarization_tool.py
@@ -0,0 +1,186 @@
import operator
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, TypedDict

from langchain.chains.combine_documents.reduce import (
acollapse_docs,
split_list_of_docs,
)
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate
from langgraph.constants import Send
from langgraph.graph import END, START, StateGraph

from quivr_core.llm_tools.entity import ToolRegistry, ToolsCategory, ToolWrapper

# This code is widely inspired by https://python.langchain.com/docs/tutorials/summarization/

default_map_prompt = "Write a concise summary of the following:\n\n{context}"

default_reduce_prompt = """
The following is a set of summaries:
{docs}
Take these and distill them into a final, consolidated summary
of the main themes.
"""


class SummaryToolsList(str, Enum):
MAPREDUCE = "map_reduce"


# This will be the overall state of the main graph.
# It will contain the input document contents, corresponding
# summaries, and a final summary.
class OverallState(TypedDict):
# Notice here we use the operator.add
# This is because we want to combine all the summaries we generate
# from individual nodes back into one list - this is essentially
# the "reduce" part
contents: List[str]
summaries: Annotated[list, operator.add]
collapsed_summaries: List[Document]
final_summary: str


# This will be the state of the node that we will "map" all
# documents to in order to generate summaries
class SummaryState(TypedDict):
content: str


class MPSummarizeTool:
def __init__(
self,
llm: BaseChatModel,
token_max: int = 8000,
map_prompt: str = default_map_prompt,
reduce_prompt: str = default_reduce_prompt,
):
self.map_prompt = ChatPromptTemplate.from_messages([("system", map_prompt)])
self.reduce_prompt = ChatPromptTemplate([("human", reduce_prompt)])
self.llm = llm
self.token_max = token_max

def length_function(self, documents: List[Document]) -> int:
"""Get number of tokens for input contents."""
return sum(self.llm.get_num_tokens(doc.page_content) for doc in documents)

# Here we generate a summary, given a document
async def generate_summary(self, state: SummaryState):
prompt = self.map_prompt.invoke(state["content"])
response = await self.llm.ainvoke(prompt)
return {"summaries": [response.content]}

# Here we define the logic to map out over the documents
# We will use this as an edge in the graph
def map_summaries(self, state: OverallState):
# We will return a list of `Send` objects
# Each `Send` object consists of the name of a node in the graph
# as well as the state to send to that node
return [
Send("generate_summary", {"content": content})
for content in state["contents"]
]

def collect_summaries(self, state: OverallState):
return {
"collapsed_summaries": [Document(summary) for summary in state["summaries"]]
}

async def _reduce(self, input: list) -> str:
prompt = self.reduce_prompt.invoke(input)
response = await self.llm.ainvoke(prompt)
return response.content

# Collapse the summaries so that their total length stays within token_max
async def collapse_summaries(self, state: OverallState):
doc_lists = split_list_of_docs(
state["collapsed_summaries"], self.length_function, self.token_max
)
results = []
for doc_list in doc_lists:
results.append(await acollapse_docs(doc_list, self._reduce))

return {"collapsed_summaries": results}

# This represents a conditional edge in the graph that determines
# if we should collapse the summaries or not
def should_collapse(
self,
state: OverallState,
) -> Literal["collapse_summaries", "generate_final_summary"]:
num_tokens = self.length_function(state["collapsed_summaries"])
if num_tokens > self.token_max:
return "collapse_summaries"
else:
return "generate_final_summary"

# Here we will generate the final summary
async def generate_final_summary(self, state: OverallState):
response = await self._reduce(state["collapsed_summaries"])
return {"final_summary": response}

def build(self):
summary_graph = StateGraph(OverallState)

summary_graph.add_node("generate_summary", self.generate_summary)
summary_graph.add_node("collect_summaries", self.collect_summaries)
summary_graph.add_node("collapse_summaries", self.collapse_summaries)
summary_graph.add_node("generate_final_summary", self.generate_final_summary)

# Edges:
summary_graph.add_conditional_edges(
START, self.map_summaries, ["generate_summary"]
)
summary_graph.add_edge("generate_summary", "collect_summaries")
summary_graph.add_conditional_edges("collect_summaries", self.should_collapse)
summary_graph.add_conditional_edges("collapse_summaries", self.should_collapse)
summary_graph.add_edge("generate_final_summary", END)

return summary_graph.compile()


def create_summary_tool(config: Dict[str, Any]):
summary_tool = MPSummarizeTool(
llm=config.get("llm"),
token_max=config.get("token_max", 8000),
map_prompt=config.get("map_prompt", default_map_prompt),
reduce_prompt=config.get("reduce_prompt", default_reduce_prompt),
)

def format_input(task: str) -> Dict[str, Any]:
# The summarization graph's OverallState expects the chunk texts under "contents".
return {"contents": [task]}

def format_output(response: Any) -> List[Document]:
# The compiled graph returns the final OverallState, which includes "final_summary".
return [Document(page_content=response["final_summary"])]

return ToolWrapper(summary_tool.build(), format_input, format_output)


# Initialize the registry and register tools
summarization_tool_registry = ToolRegistry()
summarization_tool_registry.register_tool(
SummaryToolsList.MAPREDUCE, create_summary_tool
)


def create_summarization_tool(tool_name: str, config: Dict[str, Any]) -> ToolWrapper:
return summarization_tool_registry.create_tool(tool_name, config)


SummarizationTools = ToolsCategory(
name="Summarization",
description="Tools for summarizing documents",
tools=[SummaryToolsList.MAPREDUCE],
default_tool=SummaryToolsList.MAPREDUCE,
create_tool=create_summarization_tool,
)
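As an aside, a minimal usage sketch of the map-reduce summarizer defined above (not part of the commit); the model name and chunk strings are placeholders, and the compiled graph is invoked directly via LangGraph's `ainvoke`:

```python
import asyncio

from langchain_openai import ChatOpenAI

from quivr_core.llm_tools.summarization_tool import MPSummarizeTool


async def main() -> None:
    llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model name
    graph = MPSummarizeTool(llm=llm, token_max=8000).build()

    # OverallState takes the raw chunk texts under "contents"; the graph summarizes
    # each chunk, collapses the summaries while they exceed token_max, then reduces
    # them into one final summary.
    state = await graph.ainvoke({"contents": ["first chunk ...", "second chunk ..."]})
    print(state["final_summary"])


asyncio.run(main())
```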
20 changes: 11 additions & 9 deletions core/quivr_core/rag/quivr_rag_langgraph.py
Original file line number Diff line number Diff line change
@@ -24,13 +24,12 @@
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.vectorstores import VectorStore
from langfuse.callback import CallbackHandler
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.types import Send
from pydantic import BaseModel, Field

from langfuse.callback import CallbackHandler

from quivr_core.llm import LLMEndpoint
from quivr_core.llm_tools.llm_tools import LLMToolFactory
from quivr_core.rag.entities.chat import ChatHistory
@@ -502,7 +501,7 @@ async def rewrite(self, state: AgentState) -> AgentState:
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

# Replace each question with its condensed version
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
tasks.set_definition(task_id, response.content)

return {**state, "tasks": tasks}
@@ -558,7 +557,7 @@ async def tool_routing(self, state: AgentState):
)
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
tasks.set_completion(task_id, response.is_task_completable)
if not response.is_task_completable and response.tool:
tasks.set_tool(task_id, response.tool)
@@ -599,7 +598,7 @@ async def run_tool(self, state: AgentState) -> AgentState:
)
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = tool_wrapper.format_output(response)
_docs = self.filter_chunks_by_relevance(_docs)
tasks.set_docs(task_id, _docs)
@@ -634,7 +633,6 @@ async def retrieve(self, state: AgentState) -> AgentState:
compression_retriever = ContextualCompressionRetriever(
base_compressor=reranker, base_retriever=base_retriever
)

# Prepare the async tasks for all questions
async_jobs = []
for task_id in tasks.ids:
@@ -652,7 +650,7 @@ async def retrieve(self, state: AgentState) -> AgentState:
task_ids = [task[1] for task in async_jobs] if async_jobs else []

# Process responses and associate docs with tasks
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = self.filter_chunks_by_relevance(response)
tasks.set_docs(task_id, _docs) # Associate docs with the specific task

@@ -715,7 +713,7 @@ async def dynamic_retrieve(self, state: AgentState) -> AgentState:
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

_n = []
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = self.filter_chunks_by_relevance(response)
_n.append(len(_docs))
tasks.set_docs(task_id, _docs)
@@ -737,6 +735,10 @@ async def dynamic_retrieve(self, state: AgentState) -> AgentState:

return {**state, "tasks": tasks}

def retrieve_all_chunks_from_file(self, file_id: UUID) -> List[Document]:
# get_by_ids lives on the vector store (not the retriever) and expects string ids.
return self.vector_store.get_by_ids([str(file_id)])

def _sort_docs_by_relevance(self, docs: List[Document]) -> List[Document]:
return sorted(
docs,
@@ -1012,7 +1014,7 @@ def invoke_structured_output(
)
return structured_llm.invoke(prompt)
except openai.BadRequestError:
structured_llm = self.llm_endpoint._llm.with_structured_output(output_class)
return structured_llm.invoke(prompt)

def _build_rag_prompt_inputs(
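To make the intended flow concrete, here is a rough, hand-wired equivalent (not part of the diff) of the Summarizer workflow configured in the YAML below (START → retrieve_all_chunks_from_file → run_tool → END); treat the ToolWrapper `.tool` attribute and the state keys as assumptions about the surrounding quivr_core plumbing:

```python
from typing import List
from uuid import UUID

from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel

from quivr_core.llm_tools.summarization_tool import create_summary_tool
from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph


async def summarize_file(rag: QuivrQARAGLangGraph, llm: BaseChatModel, file_id: UUID) -> str:
    # New node from this commit: pull every chunk that belongs to the file.
    chunks: List[Document] = rag.retrieve_all_chunks_from_file(file_id)

    # Build the map-reduce summarization tool and run its compiled graph over the
    # chunk texts (assumes the wrapper exposes the compiled graph as .tool).
    tool_wrapper = create_summary_tool({"llm": llm})
    state = await tool_wrapper.tool.ainvoke(
        {"contents": [chunk.page_content for chunk in chunks]}
    )
    return state["final_summary"]
```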
40 changes: 40 additions & 0 deletions core/tests/summarization_config.yaml
@@ -0,0 +1,40 @@
ingestion_config:
parser_config:
megaparse_config:
strategy: "fast"
pdf_parser: "unstructured"
splitter_config:
chunk_size: 400
chunk_overlap: 100


retrieval_config:
workflow_config:
name: "Summarizer"
available_tools:
- "Summarization"
nodes:
- name: "START"
edges: ["retrieve_all_chunks_from_file"]

- name: "retrieve_all_chunks_from_file"
edges: ["tool"]

- name: "run_tool"
edges: ["END"]

llm_config:
# The LLM supplier to use
supplier: "openai"

# The model to use for the LLM for the given supplier
model: "gpt-3.5-turbo-0125"

# Maximum number of tokens to pass to the LLM
# as a context to generate the answer
max_context_tokens: 2000

# Maximum number of tokens the LLM can generate in the answer
max_output_tokens: 2000

temperature: 0.7
streaming: true
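A short sketch of how this test config might be loaded (the `RetrievalConfig(**raw[...])` construction and the `workflow_config.name` attribute follow the usual quivr_core pattern, but are assumptions here):

```python
import yaml

from quivr_core.rag.entities.config import RetrievalConfig

with open("core/tests/summarization_config.yaml") as f:
    raw = yaml.safe_load(f)

# Assumes RetrievalConfig is a pydantic model whose fields mirror the YAML keys.
retrieval_config = RetrievalConfig(**raw["retrieval_config"])
print(retrieval_config.workflow_config.name)  # "Summarizer"
```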
2 changes: 1 addition & 1 deletion core/tests/test_quivr_rag.py
@@ -1,9 +1,9 @@
from uuid import uuid4

import pytest
from quivr_core.llm import LLMEndpoint
from quivr_core.rag.entities.chat import ChatHistory
from quivr_core.rag.entities.config import LLMEndpointConfig, RetrievalConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.rag.entities.models import ParsedRAGChunkResponse, RAGResponseMetadata
from quivr_core.rag.quivr_rag_langgraph import QuivrQARAGLangGraph

