Skip to content

Commit

Permalink
Fixed issue #2144
Browse files: browse the repository at this point in the history
  • Loading branch information
PipaFlores committed Oct 15, 2024
1 parent 9518035 commit b006e46
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
56 changes: 38 additions & 18 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,21 +478,20 @@ def fit_transform(
if documents.Document.values[0] is None:
custom_documents = self._images_to_text(documents, embeddings)

# Extract topics by calculating c-TF-IDF
self._extract_topics(custom_documents, embeddings=embeddings)
self._create_topic_vectors(documents=documents, embeddings=embeddings)

# Reduce topics
# Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
self._extract_topics(custom_documents, embeddings=embeddings, calculate_representation=not self.nr_topics)
if self.nr_topics:
custom_documents = self._reduce_topics(custom_documents)
self._create_topic_vectors(documents=documents, embeddings=embeddings)

# Save the top 3 most representative documents per topic
self._save_representative_docs(custom_documents)
else:
# Extract topics by calculating c-TF-IDF
self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose)

# Reduce topics
else:
# Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
self._extract_topics(
documents, embeddings=embeddings, verbose=self.verbose, calculate_representation=not self.nr_topics
)
if self.nr_topics:
documents = self._reduce_topics(documents)

Expand Down Expand Up @@ -3972,6 +3971,7 @@ def _extract_topics(
embeddings: np.ndarray = None,
mappings=None,
verbose: bool = False,
calculate_representation: bool = True,
):
"""Extract topics from the clusters using a class-based TF-IDF.
Expand All @@ -3980,18 +3980,29 @@ def _extract_topics(
embeddings: The document embeddings
mappings: The mappings from topic to word
verbose: Whether to log the process of extracting topics
calculate_representation: Whether to extract the topic representations
Returns:
c_tf_idf: The resulting matrix giving a value (importance score) for each word per topic
"""
if verbose:
logger.info("Representation - Extracting topics from clusters using representation models.")
action = "Representation" if calculate_representation else "Topics"
logger.info(
f"{action} - Extracting topics from clusters{' using representation models' if calculate_representation else ''}."
)

documents_per_topic = documents.groupby(["Topic"], as_index=False).agg({"Document": " ".join})
self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic)
self.topic_representations_ = self._extract_words_per_topic(words, documents)
self.topic_representations_ = self._extract_words_per_topic(
words,
documents,
calculate_representation=calculate_representation,
calculate_aspects=calculate_representation,
)
self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings)

if verbose:
logger.info("Representation - Completed \u2713")
logger.info(f"{action} - Completed \u2713")

def _save_representative_docs(self, documents: pd.DataFrame):
"""Save the 3 most representative docs per topic.
Expand Down Expand Up @@ -4245,6 +4256,7 @@ def _extract_words_per_topic(
words: List[str],
documents: pd.DataFrame,
c_tf_idf: csr_matrix = None,
calculate_representation: bool = True,
calculate_aspects: bool = True,
) -> Mapping[str, List[Tuple[str, float]]]:
"""Based on tf_idf scores per topic, extract the top n words per topic.
Expand All @@ -4258,6 +4270,7 @@ def _extract_words_per_topic(
words: List of all words (sorted according to tf_idf matrix position)
documents: DataFrame with documents and their topic IDs
c_tf_idf: A c-TF-IDF matrix from which to calculate the top words
calculate_representation: Whether to calculate the topic representations
calculate_aspects: Whether to calculate additional topic aspects
Returns:
Expand Down Expand Up @@ -4288,15 +4301,15 @@ def _extract_words_per_topic(

# Fine-tune the topic representations
topics = base_topics.copy()
if not self.representation_model:
if not self.representation_model or not calculate_representation:
# Default representation: c_tf_idf + top_n_words
topics = {label: values[: self.top_n_words] for label, values in topics.items()}
elif isinstance(self.representation_model, list):
elif calculate_representation and isinstance(self.representation_model, list):
for tuner in self.representation_model:
topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(self.representation_model, BaseRepresentation):
elif calculate_representation and isinstance(self.representation_model, BaseRepresentation):
topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(self.representation_model, dict):
elif calculate_representation and isinstance(self.representation_model, dict):
if self.representation_model.get("Main"):
main_model = self.representation_model["Main"]
if isinstance(main_model, BaseRepresentation):
Expand Down Expand Up @@ -4350,6 +4363,13 @@ def _reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) -> p
if isinstance(self.nr_topics, int):
if self.nr_topics < initial_nr_topics:
documents = self._reduce_to_n_topics(documents, use_ctfidf)
else:
logger.info(
f"Topic reduction - Number of topics ({self.nr_topics}) is equal or higher than the clustered topics({len(self.get_topics())})."
)
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, verbose=self.verbose)
return documents
elif isinstance(self.nr_topics, str):
documents = self._auto_reduce_topics(documents, use_ctfidf)
else:
Expand Down Expand Up @@ -4412,7 +4432,7 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)

# Update representations
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)
self._extract_topics(documents, mappings=mappings, verbose=self.verbose)

self._update_topic_size(documents)
return documents
Expand Down Expand Up @@ -4468,7 +4488,7 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
# Update documents and topics
self.topic_mapper_.add_mappings(mapped_topics, topic_model=self)
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)
self._extract_topics(documents, mappings=mappings, verbose=self.verbose)
self._update_topic_size(documents)
return documents

Expand Down
1 change: 1 addition & 0 deletions tests/test_representation/test_representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def test_topic_reduction_edge_cases(model, documents, request):
topics = np.random.randint(-1, nr_topics - 1, len(documents))
old_documents = pd.DataFrame({"Document": documents, "ID": range(len(documents)), "Topic": topics})
topic_model._update_topic_size(old_documents)
old_documents = topic_model._sort_mappings_by_frequency(old_documents)
topic_model._extract_topics(old_documents)
old_freq = topic_model.get_topic_freq()

Expand Down

0 comments on commit b006e46

Please sign in to comment.