diff --git a/bertopic/representation/_langchain.py b/bertopic/representation/_langchain.py
index 2aac661e..bbf8e4aa 100644
--- a/bertopic/representation/_langchain.py
+++ b/bertopic/representation/_langchain.py
@@ -1,23 +1,26 @@
 import pandas as pd
-from tqdm import tqdm
-from scipy.sparse import csr_matrix
-from typing import Mapping, List, Tuple, Union, Callable
 from langchain.docstore.document import Document
+from scipy.sparse import csr_matrix
+from typing import Callable, Dict, Mapping, List, Tuple, Union
+
 from bertopic.representation._base import BaseRepresentation
 from bertopic.representation._utils import truncate_document
-
 DEFAULT_PROMPT = "What are these documents about? Please give a single label."
 
 
 class LangChain(BaseRepresentation):
     """
     Using chains in langchain to generate topic labels.
 
-    Currently, only chains from question answering is implemented. See:
-    https://langchain.readthedocs.io/en/latest/modules/chains/combine_docs_examples/question_answering.html
+    The classic example uses `langchain.chains.question_answering.load_qa_chain`.
+    This returns a chain that takes a list of documents and a question as input.
+
+    You can also use Runnables, such as those composed with the LangChain Expression Language.
 
     Arguments:
-        chain: A langchain chain that has two input parameters, `input_documents` and `query`.
+        chain: A langchain chain or Runnable with a `batch` method.
+               Input keys must be `input_documents` and `question`.
+               Output key must be `output_text`.
         prompt: The prompt to be used in the model. If no prompt is given,
                 `self.default_prompt_` is used instead.
         nr_docs: The number of documents to pass to LangChain if a prompt
@@ -42,7 +45,9 @@ class LangChain(BaseRepresentation):
                     * If tokenizer is a callable, then that callable is used to tokenize
                       the document. These tokens are counted and truncated depending
                       on `doc_length`
+        chain_config: The configuration passed to the chain's `batch` method. Can be used
+                      to set options such as `max_concurrency` to avoid rate-limit errors.
 
     Usage:
 
     To use this, you will need to install the langchain package first.
@@ -77,6 +82,47 @@ class LangChain(BaseRepresentation):
     prompt = "What are these documents about? Please give a single label."
     representation_model = LangChain(chain, prompt=prompt)
     ```
+
+    You can also use a Runnable instead of a chain.
+    The example below uses the LangChain Expression Language:
+
+    ```python
+    from bertopic.representation import LangChain
+    from langchain.chains.question_answering import load_qa_chain
+    from langchain.chat_models import ChatAnthropic
+    from langchain.schema.document import Document
+    from langchain_experimental.data_anonymizer.presidio import PresidioReversibleAnonymizer
+
+    prompt = ...
+    llm = ...
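+    # NOTE: `prompt` and `llm` are placeholders to fill in; any LangChain chat
+    # model will work here, for example the `ChatAnthropic` imported above.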
+
+    # We will construct a special privacy-preserving chain using Microsoft Presidio
+
+    pii_handler = PresidioReversibleAnonymizer(analyzed_fields=["PERSON"])
+
+    chain = (
+        {
+            "input_documents": (
+                lambda inp: [
+                    Document(
+                        page_content=pii_handler.anonymize(
+                            d.page_content,
+                            language="en",
+                        ),
+                    )
+                    for d in inp["input_documents"]
+                ]
+            ),
+            "question": (lambda inp: inp["question"]),
+        }
+        | load_qa_chain(llm, chain_type="stuff")
+        | (lambda output: {"output_text": pii_handler.deanonymize(output["output_text"])})
+    )
+
+    representation_model = LangChain(chain, prompt=prompt)
+    ```
     """
     def __init__(self,
                  chain,
@@ -84,11 +130,13 @@ def __init__(self,
                  nr_docs: int = 4,
                  diversity: float = None,
                  doc_length: int = None,
-                 tokenizer: Union[str, Callable] = None
+                 tokenizer: Union[str, Callable] = None,
+                 chain_config=None,
                  ):
         self.chain = chain
         self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
         self.default_prompt_ = DEFAULT_PROMPT
+        self.chain_config = chain_config
         self.nr_docs = nr_docs
         self.diversity = diversity
         self.doc_length = doc_length
@@ -99,7 +147,7 @@ def extract_topics(self,
                        documents: pd.DataFrame,
                        c_tf_idf: csr_matrix,
                        topics: Mapping[str, List[Tuple[str, float]]]
-                       ) -> Mapping[str, List[Tuple[str, float]]]:
+                       ) -> Mapping[str, List[Tuple[str, int]]]:
         """ Extract topics
 
         Arguments:
@@ -121,12 +169,36 @@ def extract_topics(self,
             diversity=self.diversity
         )
 
-        # Generate label using langchain
-        updated_topics = {}
-        for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
-            truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
-            chain_docs = [Document(page_content=doc) for doc in truncated_docs]
-            label = self.chain.run(input_documents=chain_docs, question=self.prompt).strip()
-            updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
+        # Generate labels using langchain's batch functionality
+        chain_docs: List[List[Document]] = [
+            [
+                Document(
+                    page_content=truncate_document(
+                        topic_model,
+                        self.doc_length,
+                        self.tokenizer,
+                        doc
+                    )
+                )
+                for doc in docs
+            ]
+            for docs in repr_docs_mappings.values()
+        ]
+
+        # `self.chain` must take `input_documents` and `question` as input keys
+        inputs = [
+            {"input_documents": docs, "question": self.prompt}
+            for docs in chain_docs
+        ]
+
+        # `self.chain` must return a dict with an `output_text` key, the same
+        # output key as the `StuffDocumentsChain` returned by `load_qa_chain`
+        outputs = self.chain.batch(inputs=inputs, config=self.chain_config)
+        labels = [output["output_text"].strip() for output in outputs]
+
+        updated_topics = {
+            topic: [(label, 1)] + [("", 0) for _ in range(9)]
+            for topic, label in zip(repr_docs_mappings.keys(), labels)
+        }
 
         return updated_topics
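
A minimal usage sketch for the new `chain_config` parameter (not part of the patch;
the `ChatOpenAI` model and the `max_concurrency` value are illustrative assumptions):

```python
from bertopic.representation import LangChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chat_models import ChatOpenAI

# Any chain that takes `input_documents` and `question` as input keys and
# returns an `output_text` key will do; `load_qa_chain` is the simplest option.
chain = load_qa_chain(ChatOpenAI(), chain_type="stuff")

# `chain_config` is forwarded as `config` to the chain's `batch` call; setting
# `max_concurrency` caps concurrent LLM requests to avoid rate-limit errors.
representation_model = LangChain(chain, chain_config={"max_concurrency": 5})
```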