Skip to content

Commit

Permalink
Merge pull request #13 from artefactory/refactoring/rag_components
Browse files Browse the repository at this point in the history
♻️ refactoring rag_components
  • Loading branch information
baptiste-pasquier authored Mar 25, 2024
2 parents 754e333 + b03e00a commit 7116b19
Show file tree
Hide file tree
Showing 23 changed files with 315 additions and 341 deletions.
14 changes: 9 additions & 5 deletions backend/rag_1/chain.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
"""RAG chain for Option 1."""

from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.runnables.base import RunnableSequence, RunnableSerializable
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import (
RunnableLambda,
RunnablePassthrough,
RunnableSequence,
RunnableSerializable,
)
from omegaconf.dictconfig import DictConfig
from pydantic import BaseModel

Expand All @@ -13,8 +17,8 @@
from backend.rag_components.chain_links.retrieve_and_format_multimodal_docs import (
fetch_docs_chain,
)
from backend.utils.llm import get_vision_llm
from backend.utils.retriever import get_retriever
from backend.rag_components.llm import get_vision_llm
from backend.rag_components.retriever import get_retriever

from . import prompts

Expand Down
4 changes: 2 additions & 2 deletions backend/rag_1/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from unstructured.partition.pdf import partition_pdf

from backend.rag_1.config import validate_config
from backend.utils.unstructured import (
from backend.rag_components.unstructured import (
load_chunking_func,
select_images,
select_tables,
select_texts,
)
from backend.utils.vectorstore import get_vectorstore
from backend.rag_components.vectorstore import get_vectorstore

logger = logging.getLogger(__name__)

Expand Down
8 changes: 4 additions & 4 deletions backend/rag_1/notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,16 @@
"\n",
"from backend.rag_1.chain import get_chain\n",
"from backend.rag_1.config import validate_config\n",
"from backend.utils.elements import convert_documents_to_elements\n",
"from backend.utils.retriever import get_retriever\n",
"from backend.utils.unstructured import (\n",
"from backend.rag_components.elements import convert_documents_to_elements\n",
"from backend.rag_components.retriever import get_retriever\n",
"from backend.rag_components.unstructured import (\n",
" load_chunking_func,\n",
" select_images,\n",
" select_tables,\n",
" select_texts,\n",
")\n",
"from backend.rag_components.vectorstore import get_vectorstore\n",
"from backend.utils.utils import format_time_delta\n",
"from backend.utils.vectorstore import get_vectorstore\n",
"\n",
"logging.basicConfig(format=\"[%(asctime)s] - %(name)s - %(levelname)s - %(message)s\")\n",
"logging.getLogger(\"backend\").setLevel(logging.INFO)\n",
Expand Down
10 changes: 5 additions & 5 deletions backend/rag_2/chain.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""RAG chain for Option 2."""

from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables.base import (
from langchain_core.runnables import (
RunnablePassthrough,
RunnableSequence,
RunnableSerializable,
)
Expand All @@ -16,8 +16,8 @@
from backend.rag_components.chain_links.retrieve_and_format_text_docs import (
fetch_docs_chain,
)
from backend.utils.llm import get_text_llm
from backend.utils.retriever import get_retriever
from backend.rag_components.llm import get_text_llm
from backend.rag_components.retriever import get_retriever

from . import prompts

Expand Down
2 changes: 1 addition & 1 deletion backend/rag_2/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ store:
root_path: "${..path.database}/multi_vector_retriever_metadata/"

retriever:
_target_: backend.utils.multi_vector.ThresholdedMultiVectorRetriever
_target_: backend.rag_components.multi_vector.ThresholdedMultiVectorRetriever
vectorstore: ${..vectorstore}
byte_store: ${..store}
id_key: "doc_id"
Expand Down
135 changes: 22 additions & 113 deletions backend/rag_2/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,14 @@

from backend.rag_2 import prompts
from backend.rag_2.config import validate_config
from backend.utils.elements import Image, Table, Text
from backend.utils.ingest import add_elements_to_multivector_retriever
from backend.utils.llm import get_text_llm, get_vision_llm
from backend.utils.retriever import get_retriever
from backend.utils.summarization import (
generate_image_summaries,
generate_text_summaries,
from backend.rag_components.ingest import (
add_elements_to_multivector_retriever,
apply_summarize_image,
apply_summarize_table,
apply_summarize_text,
)
from backend.utils.unstructured import (
from backend.rag_components.retriever import get_retriever
from backend.rag_components.unstructured import (
load_chunking_func,
select_images,
select_tables,
Expand All @@ -30,108 +29,6 @@
logger = logging.getLogger(__name__)


async def apply_summarize_text(text_list: list[Text], config: DictConfig) -> None:
    """Apply text summarization to a list of Text elements.

    The function directly modifies the Text elements inplace. When
    ``config.ingest.summarize_text`` is falsy, nothing is summarized and the
    elements are left untouched.

    Args:
        text_list (list[Text]): List of Text elements.
        config (DictConfig): Configuration object.

    Raises:
        ValueError: If the number of generated summaries differs from the number
            of Text elements.
    """
    if not config.ingest.summarize_text:
        logger.info("Skipping text summarization")
        return

    str_list = [text.text for text in text_list]

    model = get_text_llm(config)

    # Assumes generate_text_summaries returns one summary per input string, in
    # input order -- TODO confirm against its implementation.
    text_summaries = await generate_text_summaries(
        str_list, prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT, model=model
    )

    # Pair elements with summaries positionally instead of repeatedly calling
    # ``pop(0)``: that was O(n) per call and emptied the result list as a side
    # effect. ``strict=True`` turns a length mismatch into an explicit error.
    for text, summary in zip(text_list, text_summaries, strict=True):
        text.set_summary(summary)


async def apply_summarize_table(table_list: list[Table], config: DictConfig) -> None:
    """Apply table summarization to a list of Table elements.

    The function directly modifies the Table elements inplace. When
    ``config.ingest.summarize_table`` is falsy, nothing is summarized and the
    table format is not inspected.

    Tables in "text" or "html" format are summarized from their text with the
    text LLM; tables in "image" format are summarized from their base64 payload
    and MIME type with the vision LLM.

    Args:
        table_list (list[Table]): List of Table elements.
        config (DictConfig): Configuration object.

    Raises:
        ValueError: If the table format is invalid (not "text", "html" or
            "image").
        ValueError: If the number of generated summaries differs from the number
            of Table elements.
    """
    # NOTE(review): the previous docstring also listed a ValueError for table
    # format "image" with summarize_table disabled, but this function never
    # performed that check -- presumably it lives in config validation; confirm.
    if not config.ingest.summarize_table:
        logger.info("Skipping table summarization")
        return

    table_format = config.ingest.table_format
    if table_format in ["text", "html"]:
        str_list = [table.text for table in table_list]

        model = get_text_llm(config)

        table_summaries = await generate_text_summaries(
            str_list,
            prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT,
            model=model,
        )
    elif table_format == "image":
        img_base64_list = [table.base64 for table in table_list]
        img_mime_type_list = [table.mime_type for table in table_list]
        model = get_vision_llm(config)

        table_summaries = await generate_image_summaries(
            img_base64_list,
            img_mime_type_list,
            prompt=prompts.TABLE_SUMMARIZATION_PROMPT,
            model=model,
        )
    else:
        raise ValueError(f"Invalid table format: {table_format}")

    # Pair tables with summaries positionally instead of repeatedly calling
    # ``pop(0)``: that was O(n) per call and emptied the result list as a side
    # effect. ``strict=True`` turns a length mismatch into an explicit error.
    for table, summary in zip(table_list, table_summaries, strict=True):
        table.set_summary(summary)


async def apply_summarize_image(image_list: list[Image], config: DictConfig) -> None:
    """Apply image summarization to a list of Image elements.

    The function directly modifies the Image elements inplace. Unlike the text
    and table variants, no configuration flag is consulted here: images are
    always summarized.

    Args:
        image_list (list[Image]): List of Image elements.
        config (DictConfig): Configuration object.

    Raises:
        ValueError: If the number of generated summaries differs from the number
            of Image elements.
    """
    img_base64_list = [image.base64 for image in image_list]
    img_mime_type_list = [image.mime_type for image in image_list]

    model = get_vision_llm(config)

    # Assumes generate_image_summaries returns one summary per input image, in
    # input order -- TODO confirm against its implementation.
    image_summaries = await generate_image_summaries(
        img_base64_list,
        img_mime_type_list,
        prompt=prompts.IMAGE_SUMMARIZATION_PROMPT,
        model=model,
    )

    # Pair images with summaries positionally instead of repeatedly calling
    # ``pop(0)``: that was O(n) per call and emptied the result list as a side
    # effect. ``strict=True`` turns a length mismatch into an explicit error.
    for image, summary in zip(image_list, image_summaries, strict=True):
        image.set_summary(summary)


async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None:
"""Ingest a PDF file.
Expand Down Expand Up @@ -173,13 +70,25 @@ async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None:
)

# Summarize text
await apply_summarize_text(texts, config)
await apply_summarize_text(
text_list=texts,
config=config,
prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT,
)

# Summarize tables
await apply_summarize_table(tables, config)
await apply_summarize_table(
table_list=tables,
config=config,
prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT,
)

# Summarize images
await apply_summarize_image(images, config)
await apply_summarize_image(
image_list=images,
config=config,
prompt_template=prompts.IMAGE_SUMMARIZATION_PROMPT,
)

retriever = get_retriever(config)

Expand Down
29 changes: 21 additions & 8 deletions backend/rag_2/notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,18 @@
"from hydra import compose, initialize\n",
"from unstructured.partition.pdf import partition_pdf\n",
"\n",
"from backend.rag_2 import prompts\n",
"from backend.rag_2.chain import get_chain\n",
"from backend.rag_2.config import validate_config\n",
"from backend.rag_2.ingest import (\n",
"from backend.rag_components.elements import convert_documents_to_elements\n",
"from backend.rag_components.ingest import (\n",
" add_elements_to_multivector_retriever,\n",
" apply_summarize_image,\n",
" apply_summarize_table,\n",
" apply_summarize_text,\n",
")\n",
"from backend.utils.elements import convert_documents_to_elements\n",
"from backend.utils.ingest import add_elements_to_multivector_retriever\n",
"from backend.utils.retriever import get_retriever\n",
"from backend.utils.unstructured import (\n",
"from backend.rag_components.retriever import get_retriever\n",
"from backend.rag_components.unstructured import (\n",
" load_chunking_func,\n",
" select_images,\n",
" select_tables,\n",
Expand Down Expand Up @@ -295,7 +296,11 @@
"outputs": [],
"source": [
"# Summarize text\n",
"await apply_summarize_text(texts, config)\n",
"await apply_summarize_text(\n",
" text_list=texts,\n",
" config=config,\n",
" prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT,\n",
")\n",
"for text in texts[:N_DISPLAY]:\n",
" display(text)"
]
Expand All @@ -307,7 +312,11 @@
"outputs": [],
"source": [
"# Summarize tables\n",
"await apply_summarize_table(tables, config)\n",
"await apply_summarize_table(\n",
" table_list=tables,\n",
" config=config,\n",
" prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT,\n",
")\n",
"for table in tables[:N_DISPLAY]:\n",
" display(table)"
]
Expand All @@ -319,7 +328,11 @@
"outputs": [],
"source": [
"# Summarize images\n",
"await apply_summarize_image(images, config)\n",
"await apply_summarize_image(\n",
" image_list=images,\n",
" config=config,\n",
" prompt_template=prompts.IMAGE_SUMMARIZATION_PROMPT,\n",
")\n",
"for image in images[:N_DISPLAY]:\n",
" display(image)"
]
Expand Down
11 changes: 6 additions & 5 deletions backend/rag_3/chain.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""RAG chain for Option 3."""

from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.runnables.base import (
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import (
RunnableLambda,
RunnablePassthrough,
RunnableSequence,
RunnableSerializable,
)
Expand All @@ -16,8 +17,8 @@
from backend.rag_components.chain_links.retrieve_and_format_multimodal_docs import (
fetch_docs_chain,
)
from backend.utils.llm import get_vision_llm
from backend.utils.retriever import get_retriever
from backend.rag_components.llm import get_vision_llm
from backend.rag_components.retriever import get_retriever

from . import prompts

Expand Down
2 changes: 1 addition & 1 deletion backend/rag_3/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ store:
root_path: "${..path.database}/multi_vector_retriever_metadata/"

retriever:
_target_: backend.utils.multi_vector.ThresholdedMultiVectorRetriever
_target_: backend.rag_components.multi_vector.ThresholdedMultiVectorRetriever
vectorstore: ${..vectorstore}
byte_store: ${..store}
id_key: "doc_id"
Expand Down
Loading

0 comments on commit 7116b19

Please sign in to comment.