Skip to content

Commit

Permalink
Add params to truncate documents to length when using LLMs (#1539)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaartenGr authored Oct 11, 2023
1 parent 362ccc6 commit 62e97dd
Show file tree
Hide file tree
Showing 6 changed files with 369 additions and 39 deletions.
42 changes: 33 additions & 9 deletions bertopic/representation/_cohere.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import time
import pandas as pd
from tqdm import tqdm
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple
from typing import Mapping, List, Tuple, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document


DEFAULT_PROMPT = """
Expand Down Expand Up @@ -57,6 +59,21 @@ class Cohere(BaseRepresentation):
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
doc_length: The maximum length of each document. If a document is longer,
it will be truncated. If None, the entire document is passed.
tokenizer: The tokenizer used to calculate to split the document into segments
used to count the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and trunctated depending on `doc_length`
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
Usage:
Expand Down Expand Up @@ -92,7 +109,9 @@ def __init__(self,
prompt: str = None,
delay_in_seconds: float = None,
nr_docs: int = 4,
diversity: float = None
diversity: float = None,
doc_length: int = None,
tokenizer: Union[str, Callable] = None
):
self.client = client
self.model = model
Expand All @@ -101,6 +120,9 @@ def __init__(self,
self.delay_in_seconds = delay_in_seconds
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
self.prompts_ = []

def extract_topics(self,
topic_model,
Expand All @@ -124,8 +146,10 @@ def extract_topics(self,

# Generate using Cohere's Language Model
updated_topics = {}
for topic, docs in repr_docs_mappings.items():
prompt = self._create_prompt(docs, topic, topics)
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
prompt = self._create_prompt(truncated_docs, topic, topics)
self.prompts_.append(prompt)

# Delay
if self.delay_in_seconds:
Expand All @@ -140,21 +164,21 @@ def extract_topics(self,
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]

return updated_topics

def _create_prompt(self, docs, topic, topics):
keywords = list(zip(*topics[topic]))[0]

# Use the Default Chat Prompt
if self.prompt == self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords))
if self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompt = self._replace_documents(prompt, docs)

# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", " ".join(keywords))
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)

Expand All @@ -164,6 +188,6 @@ def _create_prompt(self, docs, topic, topics):
def _replace_documents(prompt, docs):
to_replace = ""
for doc in docs:
to_replace += f"- {doc[:255]}\n"
to_replace += f"- {doc}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt
49 changes: 44 additions & 5 deletions bertopic/representation/_langchain.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pandas as pd
from tqdm import tqdm
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple
from typing import Mapping, List, Tuple, Union, Callable
from langchain.docstore.document import Document
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document


DEFAULT_PROMPT = "What are these documents about? Please give a single label."
Expand All @@ -18,7 +20,28 @@ class LangChain(BaseRepresentation):
chain: A langchain chain that has two input parameters, `input_documents` and `query`.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
nr_docs: The number of documents to pass to LangChain if a prompt
with the `["DOCUMENTS"]` tag is used.
diversity: The diversity of documents to pass to LangChain.
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
doc_length: The maximum length of each document. If a document is longer,
it will be truncated. If None, the entire document is passed.
tokenizer: The tokenizer used to calculate to split the document into segments
used to count the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and trunctated depending on `doc_length`. They are decoded with
whitespaces.
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
Usage:
To use this, you will need to install the langchain package first.
Expand Down Expand Up @@ -58,10 +81,18 @@ class LangChain(BaseRepresentation):
def __init__(self,
chain,
prompt: str = None,
nr_docs: int = 4,
diversity: float = None,
doc_length: int = None,
tokenizer: Union[str, Callable] = None
):
self.chain = chain
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer

def extract_topics(self,
topic_model,
Expand All @@ -81,12 +112,20 @@ def extract_topics(self,
updated_topics: Updated topic representations
"""
# Extract the top 4 representative documents per topic
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(c_tf_idf, documents, topics, 500, 4)
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
c_tf_idf=c_tf_idf,
documents=documents,
topics=topics,
nr_samples=500,
nr_repr_docs=self.nr_docs,
diversity=self.diversity
)

# Generate label using langchain
updated_topics = {}
for topic, docs in repr_docs_mappings.items():
chain_docs = [Document(page_content=doc[:1000]) for doc in docs]
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
chain_docs = [Document(page_content=doc) for doc in truncated_docs]
label = self.chain.run(input_documents=chain_docs, question=self.prompt).strip()
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]

Expand Down
55 changes: 39 additions & 16 deletions bertopic/representation/_openai.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import time
import openai
import pandas as pd
from tqdm import tqdm
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Any
from typing import Mapping, List, Tuple, Any, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import retry_with_exponential_backoff
from bertopic.representation._utils import retry_with_exponential_backoff, truncate_document


DEFAULT_PROMPT = """
Expand Down Expand Up @@ -47,13 +48,13 @@

class OpenAI(BaseRepresentation):
""" Using the OpenAI API to generate topic labels based
on one of their Completion of ChatCompletion models.
on one of their Completion of ChatCompletion models.
The default method is `openai.Completion` if `chat=False`.
The prompts will also need to follow a completion task. If you
The default method is `openai.Completion` if `chat=False`.
The prompts will also need to follow a completion task. If you
are looking for a more interactive chats, use `chat=True`
with `model=gpt-3.5-turbo`.
with `model=gpt-3.5-turbo`.
For an overview see:
https://platform.openai.com/docs/models
Expand Down Expand Up @@ -83,6 +84,21 @@ class OpenAI(BaseRepresentation):
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
doc_length: The maximum length of each document. If a document is longer,
it will be truncated. If None, the entire document is passed.
tokenizer: The tokenizer used to calculate to split the document into segments
used to count the length of a document.
* If tokenizer is 'char', then the document is split up
into characters which are counted to adhere to `doc_length`
* If tokenizer is 'whitespace', the the document is split up
into words separated by whitespaces. These words are counted
and truncated depending on `doc_length`
* If tokenizer is 'vectorizer', then the internal CountVectorizer
is used to tokenize the document. These tokens are counted
and trunctated depending on `doc_length`
* If tokenizer is a callable, then that callable is used to tokenize
the document. These tokens are counted and truncated depending
on `doc_length`
Usage:
Expand Down Expand Up @@ -112,7 +128,7 @@ class OpenAI(BaseRepresentation):
```
If you want to use OpenAI's ChatGPT model:
```python
representation_model = OpenAI(model="gpt-3.5-turbo", delay_in_seconds=10, chat=True)
```
Expand All @@ -125,10 +141,12 @@ def __init__(self,
exponential_backoff: bool = False,
chat: bool = False,
nr_docs: int = 4,
diversity: float = None
diversity: float = None,
doc_length: int = None,
tokenizer: Union[str, Callable] = None
):
self.model = model

if prompt is None:
self.prompt = DEFAULT_CHAT_PROMPT if chat else DEFAULT_PROMPT
else:
Expand All @@ -140,6 +158,9 @@ def __init__(self,
self.chat = chat
self.nr_docs = nr_docs
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
self.prompts_ = []

self.generator_kwargs = generator_kwargs
if self.generator_kwargs.get("model"):
Expand Down Expand Up @@ -171,15 +192,17 @@ def extract_topics(self,

# Generate using OpenAI's Language Model
updated_topics = {}
for topic, docs in repr_docs_mappings.items():
prompt = self._create_prompt(docs, topic, topics)
for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose):
truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs]
prompt = self._create_prompt(truncated_docs, topic, topics)
self.prompts_.append(prompt)

# Delay
if self.delay_in_seconds:
time.sleep(self.delay_in_seconds)

if self.chat:
messages=[
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
]
Expand All @@ -205,15 +228,15 @@ def _create_prompt(self, docs, topic, topics):

# Use the Default Chat Prompt
if self.prompt == DEFAULT_CHAT_PROMPT or self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords))
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompt = self._replace_documents(prompt, docs)

# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", " ".join(keywords))
prompt = prompt.replace("[KEYWORDS]", ", ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)

Expand All @@ -223,7 +246,7 @@ def _create_prompt(self, docs, topic, topics):
def _replace_documents(prompt, docs):
to_replace = ""
for doc in docs:
to_replace += f"- {doc[:255]}\n"
to_replace += f"- {doc}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt

Expand Down
Loading

0 comments on commit 62e97dd

Please sign in to comment.