LangChain: Support for LCEL Runnables (#1586)
joshuasundance-swca authored Oct 29, 2023
1 parent 62e97dd commit b57a8db
Showing 1 changed file with 87 additions and 16 deletions.
103 changes: 87 additions & 16 deletions bertopic/representation/_langchain.py
@@ -1,23 +1,26 @@
import pandas as pd
from tqdm import tqdm
-from scipy.sparse import csr_matrix
-from typing import Mapping, List, Tuple, Union, Callable
from langchain.docstore.document import Document
+from scipy.sparse import csr_matrix
+from typing import Callable, Dict, Mapping, List, Tuple, Union

from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document


DEFAULT_PROMPT = "What are these documents about? Please give a single label."


class LangChain(BaseRepresentation):
""" Using chains in langchain to generate topic labels.
-Currently, only chains from question answering is implemented. See:
-https://langchain.readthedocs.io/en/latest/modules/chains/combine_docs_examples/question_answering.html
+The classic example uses `langchain.chains.question_answering.load_qa_chain`.
+This returns a chain that takes a list of documents and a question as input.
+You can also use Runnables such as those composed using the LangChain Expression Language.
Arguments:
-chain: A langchain chain that has two input parameters, `input_documents` and `query`.
+chain: The langchain chain or Runnable with a `batch` method.
+    Input keys must be `input_documents` and `question`.
+    Output key must be `output_text`.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
nr_docs: The number of documents to pass to LangChain if a prompt
@@ -42,6 +45,8 @@ class LangChain(BaseRepresentation):
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
+chain_config: The configuration for the langchain chain. Can be used to set options
+    like `max_concurrency` to avoid rate limiting errors.
Usage:
To use this, you will need to install the langchain package first.
@@ -77,18 +82,60 @@ class LangChain(BaseRepresentation):
prompt = "What are these documents about? Please give a single label."
representation_model = LangChain(chain, prompt=prompt)
```
+You can also use a Runnable instead of a chain.
+The example below uses the LangChain Expression Language:
+```python
+from bertopic.representation import LangChain
+from langchain.chains.question_answering import load_qa_chain
+from langchain.chat_models import ChatAnthropic
+from langchain.schema.document import Document
+from langchain.schema.runnable import RunnablePassthrough
+from langchain_experimental.data_anonymizer.presidio import PresidioReversibleAnonymizer
+
+prompt = ...
+llm = ...
+
+# We will construct a special privacy-preserving chain using Microsoft Presidio
+pii_handler = PresidioReversibleAnonymizer(analyzed_fields=["PERSON"])
+
+chain = (
+    {
+        "input_documents": (
+            lambda inp: [
+                Document(
+                    page_content=pii_handler.anonymize(
+                        d.page_content,
+                        language="en",
+                    ),
+                )
+                for d in inp["input_documents"]
+            ]
+        ),
+        "question": RunnablePassthrough(),
+    }
+    | load_qa_chain(llm, chain_type="stuff")
+    | (lambda output: {"output_text": pii_handler.deanonymize(output["output_text"])})
+)
+
+representation_model = LangChain(chain, prompt=prompt)
+```
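Any Runnable that maps those same input keys to an `output_text` key satisfies this contract. As a minimal sketch (the `fake_chain` function below is illustrative and not part of BERTopic or LangChain):
```python
from langchain.schema.runnable import RunnableLambda

def fake_chain(inputs: dict) -> dict:
    # Consumes the required `input_documents` and `question` keys
    # and produces the required `output_text` key.
    docs, question = inputs["input_documents"], inputs["question"]
    return {"output_text": f"{len(docs)} documents about: {question}"}

# RunnableLambda supplies the `batch` method that LangChain(...) calls.
chain = RunnableLambda(fake_chain)
representation_model = LangChain(chain)
```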
"""
def __init__(self,
chain,
prompt: str = None,
nr_docs: int = 4,
diversity: float = None,
doc_length: int = None,
-tokenizer: Union[str, Callable] = None
+tokenizer: Union[str, Callable] = None,
+chain_config = None,
):
self.chain = chain
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
+self.chain_config = chain_config
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
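Because `chain_config` is forwarded unchanged to the chain's `batch` call, a plain LangChain `RunnableConfig` dict works here; a hedged example using its `max_concurrency` key:
```python
# Cap concurrent LLM calls during .batch() to avoid provider rate limits.
representation_model = LangChain(
    chain,
    prompt="What are these documents about? Please give a single label.",
    chain_config={"max_concurrency": 4},
)
```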
@@ -99,7 +146,7 @@ def extract_topics(self,
documents: pd.DataFrame,
c_tf_idf: csr_matrix,
topics: Mapping[str, List[Tuple[str, float]]]
-) -> Mapping[str, List[Tuple[str, float]]]:
+) -> Mapping[str, List[Tuple[str, int]]]:
""" Extract topics
Arguments:
@@ -121,12 +168,36 @@ def extract_topics(self,
diversity=self.diversity
)

-# Generate label using langchain
-updated_topics = {}
-for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
-    truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
-    chain_docs = [Document(page_content=doc) for doc in truncated_docs]
-    label = self.chain.run(input_documents=chain_docs, question=self.prompt).strip()
-    updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]
+# Generate label using langchain's batch functionality
+chain_docs: List[List[Document]] = [
+    [
+        Document(
+            page_content=truncate_document(
+                topic_model,
+                self.doc_length,
+                self.tokenizer,
+                doc
+            )
+        )
+        for doc in docs
+    ]
+    for docs in repr_docs_mappings.values()
+]
+
+# `self.chain` must take `input_documents` and `question` as input keys
+inputs = [
+    {"input_documents": docs, "question": self.prompt}
+    for docs in chain_docs
+]
+
+# `self.chain` must return a dict with an `output_text` key
+# same output key as the `StuffDocumentsChain` returned by `load_qa_chain`
+outputs = self.chain.batch(inputs=inputs, config=self.chain_config)
+labels = [output["output_text"].strip() for output in outputs]
+
+updated_topics = {
+    topic: [(label, 1)] + [("", 0) for _ in range(9)]
+    for topic, label in zip(repr_docs_mappings.keys(), labels)
+}

return updated_topics
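For reference, the new batched path can be exercised end to end roughly as follows (a sketch: it assumes `langchain` is installed and an OpenAI API key is configured, and the model choice is illustrative):
```python
from bertopic import BERTopic
from bertopic.representation import LangChain
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI

# The "stuff" QA chain takes `input_documents` and `question` as inputs
# and returns `output_text`, matching what extract_topics sends to .batch().
chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff")

representation_model = LangChain(chain, chain_config={"max_concurrency": 4})
topic_model = BERTopic(representation_model=representation_model)
```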
