diff --git a/backend/rag_1/chain.py b/backend/rag_1/chain.py index a0c5103..09381ef 100644 --- a/backend/rag_1/chain.py +++ b/backend/rag_1/chain.py @@ -1,9 +1,13 @@ """RAG chain for Option 1.""" from langchain_core.messages import BaseMessage, HumanMessage -from langchain_core.output_parsers.string import StrOutputParser -from langchain_core.runnables import RunnableLambda, RunnablePassthrough -from langchain_core.runnables.base import RunnableSequence, RunnableSerializable +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import ( + RunnableLambda, + RunnablePassthrough, + RunnableSequence, + RunnableSerializable, +) from omegaconf.dictconfig import DictConfig from pydantic import BaseModel @@ -13,8 +17,8 @@ from backend.rag_components.chain_links.retrieve_and_format_multimodal_docs import ( fetch_docs_chain, ) -from backend.utils.llm import get_vision_llm -from backend.utils.retriever import get_retriever +from backend.rag_components.llm import get_vision_llm +from backend.rag_components.retriever import get_retriever from . import prompts diff --git a/backend/rag_1/ingest.py b/backend/rag_1/ingest.py index ac33adf..a09d5f1 100644 --- a/backend/rag_1/ingest.py +++ b/backend/rag_1/ingest.py @@ -10,13 +10,13 @@ from unstructured.partition.pdf import partition_pdf from backend.rag_1.config import validate_config -from backend.utils.unstructured import ( +from backend.rag_components.unstructured import ( load_chunking_func, select_images, select_tables, select_texts, ) -from backend.utils.vectorstore import get_vectorstore +from backend.rag_components.vectorstore import get_vectorstore logger = logging.getLogger(__name__) diff --git a/backend/rag_1/notebook.ipynb b/backend/rag_1/notebook.ipynb index d440e6c..abced02 100644 --- a/backend/rag_1/notebook.ipynb +++ b/backend/rag_1/notebook.ipynb @@ -73,16 +73,16 @@ "\n", "from backend.rag_1.chain import get_chain\n", "from backend.rag_1.config import validate_config\n", - "from backend.utils.elements import convert_documents_to_elements\n", - "from backend.utils.retriever import get_retriever\n", - "from backend.utils.unstructured import (\n", + "from backend.rag_components.elements import convert_documents_to_elements\n", + "from backend.rag_components.retriever import get_retriever\n", + "from backend.rag_components.unstructured import (\n", " load_chunking_func,\n", " select_images,\n", " select_tables,\n", " select_texts,\n", ")\n", + "from backend.rag_components.vectorstore import get_vectorstore\n", "from backend.utils.utils import format_time_delta\n", - "from backend.utils.vectorstore import get_vectorstore\n", "\n", "logging.basicConfig(format=\"[%(asctime)s] - %(name)s - %(levelname)s - %(message)s\")\n", "logging.getLogger(\"backend\").setLevel(logging.INFO)\n", diff --git a/backend/rag_2/chain.py b/backend/rag_2/chain.py index 3eb7b6e..6cab9c7 100644 --- a/backend/rag_2/chain.py +++ b/backend/rag_2/chain.py @@ -1,9 +1,9 @@ """RAG chain for Option 2.""" -from langchain_core.output_parsers.string import StrOutputParser +from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate -from langchain_core.runnables import RunnablePassthrough -from langchain_core.runnables.base import ( +from langchain_core.runnables import ( + RunnablePassthrough, RunnableSequence, RunnableSerializable, ) @@ -16,8 +16,8 @@ from backend.rag_components.chain_links.retrieve_and_format_text_docs import ( fetch_docs_chain, ) -from backend.utils.llm import get_text_llm -from backend.utils.retriever import get_retriever +from backend.rag_components.llm import get_text_llm +from backend.rag_components.retriever import get_retriever from . import prompts diff --git a/backend/rag_2/config.yaml b/backend/rag_2/config.yaml index 7e23156..0482eb5 100644 --- a/backend/rag_2/config.yaml +++ b/backend/rag_2/config.yaml @@ -41,7 +41,7 @@ store: root_path: "${..path.database}/multi_vector_retriever_metadata/" retriever: - _target_: backend.utils.multi_vector.ThresholdedMultiVectorRetriever + _target_: backend.rag_components.multi_vector.ThresholdedMultiVectorRetriever vectorstore: ${..vectorstore} byte_store: ${..store} id_key: "doc_id" diff --git a/backend/rag_2/ingest.py b/backend/rag_2/ingest.py index 564802c..90dca1b 100644 --- a/backend/rag_2/ingest.py +++ b/backend/rag_2/ingest.py @@ -12,15 +12,14 @@ from backend.rag_2 import prompts from backend.rag_2.config import validate_config -from backend.utils.elements import Image, Table, Text -from backend.utils.ingest import add_elements_to_multivector_retriever -from backend.utils.llm import get_text_llm, get_vision_llm -from backend.utils.retriever import get_retriever -from backend.utils.summarization import ( - generate_image_summaries, - generate_text_summaries, +from backend.rag_components.ingest import ( + add_elements_to_multivector_retriever, + apply_summarize_image, + apply_summarize_table, + apply_summarize_text, ) -from backend.utils.unstructured import ( +from backend.rag_components.retriever import get_retriever +from backend.rag_components.unstructured import ( load_chunking_func, select_images, select_tables, @@ -30,108 +29,6 @@ logger = logging.getLogger(__name__) -async def apply_summarize_text(text_list: list[Text], config: DictConfig) -> None: - """Apply text summarization to a list of Text elements. - - The function directly modifies the Text elements inplace. - - Args: - text_list (list[Text]): List of Text elements. - config (DictConfig): Configuration object. - """ - if config.ingest.summarize_text: - str_list = [text.text for text in text_list] - - model = get_text_llm(config) - - text_summaries = await generate_text_summaries( - str_list, prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT, model=model - ) - - for text in text_list: - text.set_summary(text_summaries.pop(0)) - - else: - logger.info("Skipping text summarization") - - return - - -async def apply_summarize_table(table_list: list[Table], config: DictConfig) -> None: - """Apply table summarization to a list of Table elements. - - The function directly modifies the Table elements inplace. - - Args: - table_list (list[Table]): List of Table elements. - config (DictConfig): Configuration object. - - Raises: - ValueError: If the table format is "image" and summarize_table is False. - ValueError: If the table format is invalid. - """ - if config.ingest.summarize_table: - table_format = config.ingest.table_format - if table_format in ["text", "html"]: - str_list = [table.text for table in table_list] - - model = get_text_llm(config) - - table_summaries = await generate_text_summaries( - str_list, - prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT, - model=model, - ) - elif config.ingest.table_format == "image": - img_base64_list = [table.base64 for table in table_list] - img_mime_type_list = [table.mime_type for table in table_list] - model = get_vision_llm(config) - - table_summaries = await generate_image_summaries( - img_base64_list, - img_mime_type_list, - prompt=prompts.TABLE_SUMMARIZATION_PROMPT, - model=model, - ) - else: - raise ValueError(f"Invalid table format: {table_format}") - - for table in table_list: - table.set_summary(table_summaries.pop(0)) - - else: - logger.info("Skipping table summarization") - - return - - -async def apply_summarize_image(image_list: list[Image], config: DictConfig) -> None: - """Apply image summarization to a list of Image elements. - - The function directly modifies the Image elements inplace. - - Args: - image_list (list[Image]): List of Image elements. - config (DictConfig): Configuration object. - """ - img_base64_list = [image.base64 for image in image_list] - img_mime_type_list = [image.mime_type for image in image_list] - - model = get_vision_llm(config) - - image_summaries = await generate_image_summaries( - img_base64_list, - img_mime_type_list, - prompt=prompts.IMAGE_SUMMARIZATION_PROMPT, - model=model, - ) - - for image in image_list: - image.set_summary(image_summaries.pop(0)) - - return - - async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None: """Ingest a PDF file. @@ -173,13 +70,25 @@ async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None: ) # Summarize text - await apply_summarize_text(texts, config) + await apply_summarize_text( + text_list=texts, + config=config, + prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT, + ) # Summarize tables - await apply_summarize_table(tables, config) + await apply_summarize_table( + table_list=tables, + config=config, + prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT, + ) # Summarize images - await apply_summarize_image(images, config) + await apply_summarize_image( + image_list=images, + config=config, + prompt_template=prompts.IMAGE_SUMMARIZATION_PROMPT, + ) retriever = get_retriever(config) diff --git a/backend/rag_2/notebook.ipynb b/backend/rag_2/notebook.ipynb index 688ef5b..d227b42 100644 --- a/backend/rag_2/notebook.ipynb +++ b/backend/rag_2/notebook.ipynb @@ -72,17 +72,18 @@ "from hydra import compose, initialize\n", "from unstructured.partition.pdf import partition_pdf\n", "\n", + "from backend.rag_2 import prompts\n", "from backend.rag_2.chain import get_chain\n", "from backend.rag_2.config import validate_config\n", - "from backend.rag_2.ingest import (\n", + "from backend.rag_components.elements import convert_documents_to_elements\n", + "from backend.rag_components.ingest import (\n", + " add_elements_to_multivector_retriever,\n", " apply_summarize_image,\n", " apply_summarize_table,\n", " apply_summarize_text,\n", ")\n", - "from backend.utils.elements import convert_documents_to_elements\n", - "from backend.utils.ingest import add_elements_to_multivector_retriever\n", - "from backend.utils.retriever import get_retriever\n", - "from backend.utils.unstructured import (\n", + "from backend.rag_components.retriever import get_retriever\n", + "from backend.rag_components.unstructured import (\n", " load_chunking_func,\n", " select_images,\n", " select_tables,\n", @@ -295,7 +296,11 @@ "outputs": [], "source": [ "# Summarize text\n", - "await apply_summarize_text(texts, config)\n", + "await apply_summarize_text(\n", + " text_list=texts,\n", + " config=config,\n", + " prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT,\n", + ")\n", "for text in texts[:N_DISPLAY]:\n", " display(text)" ] @@ -307,7 +312,11 @@ "outputs": [], "source": [ "# Summarize tables\n", - "await apply_summarize_table(tables, config)\n", + "await apply_summarize_table(\n", + " table_list=tables,\n", + " config=config,\n", + " prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT,\n", + ")\n", "for table in tables[:N_DISPLAY]:\n", " display(table)" ] @@ -319,7 +328,11 @@ "outputs": [], "source": [ "# Summarize images\n", - "await apply_summarize_image(images, config)\n", + "await apply_summarize_image(\n", + " image_list=images,\n", + " config=config,\n", + " prompt_template=prompts.IMAGE_SUMMARIZATION_PROMPT,\n", + ")\n", "for image in images[:N_DISPLAY]:\n", " display(image)" ] diff --git a/backend/rag_3/chain.py b/backend/rag_3/chain.py index 58ed3ad..133f161 100644 --- a/backend/rag_3/chain.py +++ b/backend/rag_3/chain.py @@ -1,9 +1,10 @@ """RAG chain for Option 3.""" from langchain_core.messages import BaseMessage, HumanMessage -from langchain_core.output_parsers.string import StrOutputParser -from langchain_core.runnables import RunnableLambda, RunnablePassthrough -from langchain_core.runnables.base import ( +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import ( + RunnableLambda, + RunnablePassthrough, RunnableSequence, RunnableSerializable, ) @@ -16,8 +17,8 @@ from backend.rag_components.chain_links.retrieve_and_format_multimodal_docs import ( fetch_docs_chain, ) -from backend.utils.llm import get_vision_llm -from backend.utils.retriever import get_retriever +from backend.rag_components.llm import get_vision_llm +from backend.rag_components.retriever import get_retriever from . import prompts diff --git a/backend/rag_3/config.yaml b/backend/rag_3/config.yaml index b55e472..aebc473 100644 --- a/backend/rag_3/config.yaml +++ b/backend/rag_3/config.yaml @@ -41,7 +41,7 @@ store: root_path: "${..path.database}/multi_vector_retriever_metadata/" retriever: - _target_: backend.utils.multi_vector.ThresholdedMultiVectorRetriever + _target_: backend.rag_components.multi_vector.ThresholdedMultiVectorRetriever vectorstore: ${..vectorstore} byte_store: ${..store} id_key: "doc_id" diff --git a/backend/rag_3/ingest.py b/backend/rag_3/ingest.py index b654a09..0fad870 100644 --- a/backend/rag_3/ingest.py +++ b/backend/rag_3/ingest.py @@ -12,15 +12,14 @@ from backend.rag_3 import prompts from backend.rag_3.config import validate_config -from backend.utils.elements import Image, Table, Text -from backend.utils.ingest import add_elements_to_multivector_retriever -from backend.utils.llm import get_text_llm, get_vision_llm -from backend.utils.retriever import get_retriever -from backend.utils.summarization import ( - generate_image_summaries, - generate_text_summaries, +from backend.rag_components.ingest import ( + add_elements_to_multivector_retriever, + apply_summarize_image, + apply_summarize_table, + apply_summarize_text, ) -from backend.utils.unstructured import ( +from backend.rag_components.retriever import get_retriever +from backend.rag_components.unstructured import ( load_chunking_func, select_images, select_tables, @@ -30,108 +29,6 @@ logger = logging.getLogger(__name__) -async def apply_summarize_text(text_list: list[Text], config: DictConfig) -> None: - """Apply text summarization to a list of Text elements. - - The function directly modifies the Text elements inplace. - - Args: - text_list (list[Text]): List of Text elements. - config (DictConfig): Configuration object. - """ - if config.ingest.summarize_text: - str_list = [text.text for text in text_list] - - model = get_text_llm(config) - - text_summaries = await generate_text_summaries( - str_list, prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT, model=model - ) - - for text in text_list: - text.set_summary(text_summaries.pop(0)) - - else: - logger.info("Skipping text summarization") - - return - - -async def apply_summarize_table(table_list: list[Table], config: DictConfig) -> None: - """Apply table summarization to a list of Table elements. - - The function directly modifies the Table elements inplace. - - Args: - table_list (list[Table]): List of Table elements. - config (DictConfig): Configuration object. - - Raises: - ValueError: If the table format is "image" and summarize_table is False. - ValueError: If the table format is invalid. - """ - if config.ingest.summarize_table: - table_format = config.ingest.table_format - if table_format in ["text", "html"]: - str_list = [table.text for table in table_list] - - model = get_text_llm(config) - - table_summaries = await generate_text_summaries( - str_list, - prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT, - model=model, - ) - elif config.ingest.table_format == "image": - img_base64_list = [table.base64 for table in table_list] - img_mime_type_list = [table.mime_type for table in table_list] - model = get_vision_llm(config) - - table_summaries = await generate_image_summaries( - img_base64_list, - img_mime_type_list, - prompt=prompts.TABLE_SUMMARIZATION_PROMPT, - model=model, - ) - else: - raise ValueError(f"Invalid table format: {table_format}") - - for table in table_list: - table.set_summary(table_summaries.pop(0)) - - else: - logger.info("Skipping table summarization") - - return - - -async def apply_summarize_image(image_list: list[Image], config: DictConfig) -> None: - """Apply image summarization to a list of Image elements. - - The function directly modifies the Image elements inplace. - - Args: - image_list (list[Image]): List of Image elements. - config (DictConfig): Configuration object. - """ - img_base64_list = [image.base64 for image in image_list] - img_mime_type_list = [image.mime_type for image in image_list] - - model = get_vision_llm(config) - - image_summaries = await generate_image_summaries( - img_base64_list, - img_mime_type_list, - prompt=prompts.IMAGE_SUMMARIZATION_PROMPT, - model=model, - ) - - for image in image_list: - image.set_summary(image_summaries.pop(0)) - - return - - async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None: """Ingest a PDF file. @@ -173,13 +70,25 @@ async def ingest_pdf(file_path: str | Path, config: DictConfig) -> None: ) # Summarize text - await apply_summarize_text(texts, config) + await apply_summarize_text( + text_list=texts, + config=config, + prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT, + ) # Summarize tables - await apply_summarize_table(tables, config) + await apply_summarize_table( + table_list=tables, + config=config, + prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT, + ) # Summarize images - await apply_summarize_image(images, config) + await apply_summarize_image( + image_list=images, + config=config, + prompt_template=prompts.IMAGE_SUMMARIZATION_PROMPT, + ) retriever = get_retriever(config) diff --git a/backend/rag_3/notebook.ipynb b/backend/rag_3/notebook.ipynb index 51b8ccb..64ff469 100644 --- a/backend/rag_3/notebook.ipynb +++ b/backend/rag_3/notebook.ipynb @@ -72,17 +72,18 @@ "from hydra import compose, initialize\n", "from unstructured.partition.pdf import partition_pdf\n", "\n", + "from backend.rag_3 import prompts\n", "from backend.rag_3.chain import get_chain\n", "from backend.rag_3.config import validate_config\n", - "from backend.rag_3.ingest import (\n", + "from backend.rag_components.elements import convert_documents_to_elements\n", + "from backend.rag_components.ingest import (\n", + " add_elements_to_multivector_retriever,\n", " apply_summarize_image,\n", " apply_summarize_table,\n", " apply_summarize_text,\n", ")\n", - "from backend.utils.elements import convert_documents_to_elements\n", - "from backend.utils.ingest import add_elements_to_multivector_retriever\n", - "from backend.utils.retriever import get_retriever\n", - "from backend.utils.unstructured import (\n", + "from backend.rag_components.retriever import get_retriever\n", + "from backend.rag_components.unstructured import (\n", " load_chunking_func,\n", " select_images,\n", " select_tables,\n", @@ -298,7 +299,11 @@ "outputs": [], "source": [ "# Summarize text\n", - "await apply_summarize_text(texts, config)\n", + "await apply_summarize_text(\n", + " text_list=texts,\n", + " config=config,\n", + " prompt_template=prompts.TEXT_SUMMARIZATION_PROMPT,\n", + ")\n", "for text in texts[:N_DISPLAY]:\n", " display(text)" ] @@ -310,7 +315,11 @@ "outputs": [], "source": [ "# Summarize tables\n", - "await apply_summarize_table(tables, config)\n", + "await apply_summarize_table(\n", + " table_list=tables,\n", + " config=config,\n", + " prompt_template=prompts.TABLE_SUMMARIZATION_PROMPT,\n", + ")\n", "for table in tables[:N_DISPLAY]:\n", " display(table)" ] @@ -322,7 +331,11 @@ "outputs": [], "source": [ "# Summarize images\n", - "await apply_summarize_image(images, config)\n", + "await apply_summarize_image(\n", + " image_list=images,\n", + " config=config,\n", + " prompt_template=prompts.IMAGE_SUMMARIZATION_PROMPT,\n", + ")\n", "for image in images[:N_DISPLAY]:\n", " display(image)" ] diff --git a/backend/rag_components/chain_links/rag_with_history.py b/backend/rag_components/chain_links/rag_with_history.py index 9688386..1e743d7 100644 --- a/backend/rag_components/chain_links/rag_with_history.py +++ b/backend/rag_components/chain_links/rag_with_history.py @@ -1,13 +1,14 @@ """RAG pipeline with memory.""" -from langchain_core.runnables.base import RunnableSequence +from langchain_core.runnables import RunnableSequence from langchain_core.runnables.history import RunnableWithMessageHistory from omegaconf import DictConfig from pydantic import BaseModel -from backend.rag_components.chain_links.condense_question import condense_question from backend.rag_components.chat_message_history import get_chat_message_history -from backend.utils.llm import get_text_llm +from backend.rag_components.llm import get_text_llm + +from .condense_question import condense_question class QuestionWithHistory(BaseModel): diff --git a/backend/rag_components/chain_links/retrieve_and_format_text_docs.py b/backend/rag_components/chain_links/retrieve_and_format_text_docs.py index 4dcd3af..d3d0126 100644 --- a/backend/rag_components/chain_links/retrieve_and_format_text_docs.py +++ b/backend/rag_components/chain_links/retrieve_and_format_text_docs.py @@ -4,7 +4,7 @@ from langchain_core.documents import Document from langchain_core.prompts import PromptTemplate from langchain_core.retrievers import BaseRetriever -from langchain_core.runnables.base import RunnableLambda, RunnableSequence +from langchain_core.runnables import RunnableLambda, RunnableSequence from pydantic import BaseModel DOCUMENT_TEMPLATE = """\ diff --git a/backend/utils/elements.py b/backend/rag_components/elements.py similarity index 99% rename from backend/utils/elements.py rename to backend/rag_components/elements.py index 930414b..ce314e3 100644 --- a/backend/utils/elements.py +++ b/backend/rag_components/elements.py @@ -10,7 +10,7 @@ from langchain_core.documents import Document from pydantic import BaseModel, PrivateAttr, validator -from .image import local_image_to_base64 +from backend.utils.image import local_image_to_base64 class Element(BaseModel): diff --git a/backend/rag_components/ingest.py b/backend/rag_components/ingest.py new file mode 100644 index 0000000..ad0e37a --- /dev/null +++ b/backend/rag_components/ingest.py @@ -0,0 +1,183 @@ +"""Ingest utility functions.""" + +import logging +from collections.abc import Sequence + +from langchain.retrievers.multi_vector import MultiVectorRetriever +from omegaconf.dictconfig import DictConfig + +from .elements import Element, Image, Table, Text +from .llm import get_text_llm, get_vision_llm +from .retriever import add_documents_multivector +from .summarization import ( + generate_image_summaries, + generate_text_summaries, +) + +logger = logging.getLogger(__name__) + + +def get_attr_from_elements(elements: Sequence[Element], attr: str) -> list: + """Get a specific attribute from a list of elements. + + Args: + elements (list[Element]): List of elements. + attr (str): Attribute to get from the elements. + + Raises: + ValueError: If the attribute is not supported. + + Returns: + list: List of the specified attribute from the elements. + """ + match attr: + case "content": + return [element.get_content() for element in elements] + case "summary": + return [element.get_summary() for element in elements] + case "metadata": + return [element.get_metadata() for element in elements] + case other: + raise ValueError(f"Unsupported attribute: {other}") + + +def add_elements_to_multivector_retriever( + elements: Sequence[Element], + retriever: MultiVectorRetriever, + vectorstore_source: str, + docstore_source: str, +) -> None: + """Add a list of elements to the multi-vector retriever. + + Args: + elements (Sequence[Element]): List of elements to add. + retriever (MultiVectorRetriever): Multi-vector retriever. + vectorstore_source (str): Attribute of the elements to add to the vectorstore. + docstore_source (str): Attribute of the elements to add to the docstore. + """ + vectorstore_content = get_attr_from_elements(elements, vectorstore_source) + docstore_content = get_attr_from_elements(elements, docstore_source) + metadata_list = get_attr_from_elements(elements, "metadata") + + logging.info(f"Adding {vectorstore_source} to vectorstore.") + logging.info(f"Adding {docstore_source} to docstore.") + + add_documents_multivector( + retriever=retriever, + vectorstore_content=vectorstore_content, + docstore_content=docstore_content, + metadata_list=metadata_list, + vectorstore_source=vectorstore_source, + docstore_source=docstore_source, + ) + + +async def apply_summarize_text( + text_list: list[Text], config: DictConfig, prompt_template: str +) -> None: + """Apply text summarization to a list of Text elements. + + The function directly modifies the Text elements inplace. + + Args: + text_list (list[Text]): List of Text elements. + config (DictConfig): Configuration object. + prompt_template (str): Prompt template for the summarization. + """ + if config.ingest.summarize_text: + str_list = [text.text for text in text_list] + + model = get_text_llm(config) + + text_summaries = await generate_text_summaries( + str_list, prompt_template=prompt_template, model=model + ) + + for text in text_list: + text.set_summary(text_summaries.pop(0)) + + else: + logger.info("Skipping text summarization") + + return + + +async def apply_summarize_table( + table_list: list[Table], config: DictConfig, prompt_template: str +) -> None: + """Apply table summarization to a list of Table elements. + + The function directly modifies the Table elements inplace. + + Args: + table_list (list[Table]): List of Table elements. + config (DictConfig): Configuration object. + prompt_template (str): Prompt template for the summarization. + + Raises: + ValueError: If the table format is "image" and summarize_table is False. + ValueError: If the table format is invalid. + """ + if config.ingest.summarize_table: + table_format = config.ingest.table_format + if table_format in ["text", "html"]: + str_list = [table.text for table in table_list] + + model = get_text_llm(config) + + table_summaries = await generate_text_summaries( + str_list, + prompt_template=prompt_template, + model=model, + ) + elif config.ingest.table_format == "image": + img_base64_list = [table.base64 for table in table_list] + img_mime_type_list = [table.mime_type for table in table_list] + model = get_vision_llm(config) + + table_summaries = await generate_image_summaries( + img_base64_list, + img_mime_type_list, + prompt=prompt_template, + model=model, + ) + else: + raise ValueError(f"Invalid table format: {table_format}") + + for table in table_list: + table.set_summary(table_summaries.pop(0)) + + else: + logger.info("Skipping table summarization") + + return + + +async def apply_summarize_image( + image_list: list[Image], config: DictConfig, prompt_template: str +) -> None: + """Apply image summarization to a list of Image elements. + + The function directly modifies the Image elements inplace. + + Args: + image_list (list[Image]): List of Image elements. + config (DictConfig): Configuration object. + prompt_template (str): Prompt template for the summarization. + """ + img_base64_list = [image.base64 for image in image_list] + img_mime_type_list = [image.mime_type for image in image_list] + + model = get_vision_llm(config) + + image_summaries = await generate_image_summaries( + img_base64_list, + img_mime_type_list, + prompt=prompt_template, + model=model, + ) + + for image in image_list: + image.set_summary(image_summaries.pop(0)) + + return diff --git a/backend/utils/llm.py b/backend/rag_components/llm.py similarity index 94% rename from backend/utils/llm.py rename to backend/rag_components/llm.py index e94a2d2..6da2b66 100644 --- a/backend/utils/llm.py +++ b/backend/rag_components/llm.py @@ -1,7 +1,7 @@ """Utility functions for instantiating language models.""" from hydra.utils import instantiate -from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models import BaseChatModel from omegaconf.dictconfig import DictConfig diff --git a/backend/utils/multi_vector.py b/backend/rag_components/multi_vector.py similarity index 100% rename from backend/utils/multi_vector.py rename to backend/rag_components/multi_vector.py diff --git a/backend/utils/retriever.py b/backend/rag_components/retriever.py similarity index 100% rename from backend/utils/retriever.py rename to backend/rag_components/retriever.py diff --git a/backend/utils/summarization.py b/backend/rag_components/summarization.py similarity index 97% rename from backend/utils/summarization.py rename to backend/rag_components/summarization.py index 667f518..eb53338 100644 --- a/backend/utils/summarization.py +++ b/backend/rag_components/summarization.py @@ -4,9 +4,9 @@ from collections.abc import Sequence import openai -from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.language_models import BaseChatModel from langchain_core.messages import BaseMessage, HumanMessage -from langchain_core.output_parsers.string import StrOutputParser +from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableLambda from tenacity import ( diff --git a/backend/utils/unstructured.py b/backend/rag_components/unstructured.py similarity index 98% rename from backend/utils/unstructured.py rename to backend/rag_components/unstructured.py index 2562a42..3a61802 100644 --- a/backend/utils/unstructured.py +++ b/backend/rag_components/unstructured.py @@ -8,7 +8,7 @@ from omegaconf.dictconfig import DictConfig from unstructured.documents.coordinates import RelativeCoordinateSystem -from backend.utils.elements import Image, Table, TableImage, TableText, Text +from .elements import Image, Table, TableImage, TableText, Text def get_element_size(element: unstructured_elements.Element) -> tuple[float, float]: diff --git a/backend/utils/vectorstore.py b/backend/rag_components/vectorstore.py similarity index 100% rename from backend/utils/vectorstore.py rename to backend/rag_components/vectorstore.py diff --git a/backend/utils/ingest.py b/backend/utils/ingest.py deleted file mode 100644 index b19045b..0000000 --- a/backend/utils/ingest.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Ingest utility functions.""" - -import logging -from collections.abc import Sequence - -from langchain.retrievers.multi_vector import MultiVectorRetriever - -from .elements import Element -from .retriever import add_documents_multivector - -logger = logging.getLogger(__name__) - - -def get_attr_from_elements(elements: Sequence[Element], attr: str) -> list: - """Get a specific attribute from a list of elements. - - Args: - elements (list[Element]): List of elements. - attr (str): Attribute to get from the elements. - - Raises: - ValueError: If the attribute is not supported. - - Returns: - list: List of the specified attribute from the elements. - """ - match attr: - case "content": - return [element.get_content() for element in elements] - case "summary": - return [element.get_summary() for element in elements] - case "metadata": - return [element.get_metadata() for element in elements] - case other: - raise ValueError(f"Unsupported attribute: {other}") - - -def add_elements_to_multivector_retriever( - elements: Sequence[Element], - retriever: MultiVectorRetriever, - vectorstore_source: str, - docstore_source: str, -) -> None: - """Add a list of elements to the multi-vector retriever. - - Args: - elements (Sequence[Element]): List of elements to add. - retriever (MultiVectorRetriever): Multi-vector retriever. - vectorstore_source (str): Attribute of the elements to add to the vectorstore. - docstore_source (str): Attribute of the elements to add to the docstore. - """ - vectorstore_content = get_attr_from_elements(elements, vectorstore_source) - docstore_content = get_attr_from_elements(elements, docstore_source) - metadata_list = get_attr_from_elements(elements, "metadata") - - logging.info(f"Adding {vectorstore_source} to vectorstore.") - logging.info(f"Adding {docstore_source} to docstore.") - - add_documents_multivector( - retriever=retriever, - vectorstore_content=vectorstore_content, - docstore_content=docstore_content, - metadata_list=metadata_list, - vectorstore_source=vectorstore_source, - docstore_source=docstore_source, - ) diff --git a/tests/backend/utils/test_elements.py b/tests/backend/rag_components/test_elements.py similarity index 98% rename from tests/backend/utils/test_elements.py rename to tests/backend/rag_components/test_elements.py index 83a2b20..b1ba5f2 100644 --- a/tests/backend/utils/test_elements.py +++ b/tests/backend/rag_components/test_elements.py @@ -4,7 +4,14 @@ from pytest import FixtureRequest from pytest_lazy_fixtures import lf -from backend.utils.elements import Element, Image, Table, TableImage, TableText, Text +from backend.rag_components.elements import ( + Element, + Image, + Table, + TableImage, + TableText, + Text, +) # ----------------------------------- Text ----------------------------------- #