diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index 7ef1efbb..6e849976 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -2146,7 +2146,7 @@ def merge_topics( # Update topics documents.Topic = documents.Topic.map(mapping) - self.topic_mapper_.add_mappings(mapping) + self.topic_mapper_.add_mappings(mapping, topic_model=self) documents = self._sort_mappings_by_frequency(documents) self._extract_topics(documents, mappings=mappings) self._update_topic_size(documents) @@ -4396,50 +4396,12 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) # Map topics documents.Topic = new_topics self._update_topic_size(documents) - self.topic_mapper_.add_mappings(mapped_topics) + self.topic_mapper_.add_mappings(mapped_topics, topic_model=self) # Update representations documents = self._sort_mappings_by_frequency(documents) self._extract_topics(documents, mappings=mappings) - # When zero-shot topic(s) are present in the topics to merge, - # determine whether to take one of the zero-shot topic labels - # or use a calculated representation. - if self._is_zeroshot(): - new_topic_id_to_zeroshot_topic_idx = {} - topics_to_map = { - topic_mapping[0]: topic_mapping[1] for topic_mapping in np.array(self.topic_mapper_.mappings_)[:, -2:] - } - - for topic_to, topics_from in basic_mappings.items(): - # When extracting topics, the reduced topics were reordered. - # Must get the updated topic_to. - topic_to = topics_to_map[topic_to] - - # which of the original topics are zero-shot - zeroshot_topic_ids = [ - topic_id for topic_id in topics_from if topic_id in self._topic_id_to_zeroshot_topic_idx - ] - if len(zeroshot_topic_ids) == 0: - continue - - # If any of the original topics are zero-shot, take the best fitting zero-shot label - # if the cosine similarity with the new topic exceeds the zero-shot threshold - zeroshot_labels = [ - self.zeroshot_topic_list[self._topic_id_to_zeroshot_topic_idx[topic_id]] - for topic_id in zeroshot_topic_ids - ] - zeroshot_embeddings = self._extract_embeddings(zeroshot_labels) - cosine_similarities = cosine_similarity( - zeroshot_embeddings, [self.topic_embeddings_[topic_to]] - ).flatten() - best_zeroshot_topic_idx = np.argmax(cosine_similarities) - best_cosine_similarity = cosine_similarities[best_zeroshot_topic_idx] - if best_cosine_similarity >= self.zeroshot_min_similarity: - new_topic_id_to_zeroshot_topic_idx[topic_to] = zeroshot_topic_ids[best_zeroshot_topic_idx] - - self._topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx - self._update_topic_size(documents) return documents @@ -4492,7 +4454,7 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) } # Update documents and topics - self.topic_mapper_.add_mappings(mapped_topics) + self.topic_mapper_.add_mappings(mapped_topics, topic_model=self) documents = self._sort_mappings_by_frequency(documents) self._extract_topics(documents, mappings=mappings) self._update_topic_size(documents) @@ -4528,7 +4490,7 @@ def _sort_mappings_by_frequency(self, documents: pd.DataFrame) -> pd.DataFrame: df = pd.DataFrame(self.topic_sizes_.items(), columns=["Old_Topic", "Size"]).sort_values("Size", ascending=False) df = df[df.Old_Topic != -1] sorted_topics = {**{-1: -1}, **dict(zip(df.Old_Topic, range(len(df))))} - self.topic_mapper_.add_mappings(sorted_topics) + self.topic_mapper_.add_mappings(sorted_topics, topic_model=self) # Map documents documents.Topic = documents.Topic.map(sorted_topics).fillna(documents.Topic).astype(int) @@ -4718,11 +4680,12 @@ def get_mappings(self, original_topics: bool = True) -> Mapping[int, int]: mappings = dict(zip(mappings[:, 0], mappings[:, 1])) return mappings - def add_mappings(self, mappings: Mapping[int, int]): + def add_mappings(self, mappings: Mapping[int, int], topic_model: BERTopic): """Add new column(s) of topic mappings. Arguments: mappings: The mappings to add + topic_model: The topic model this TopicMapper belongs to """ for topics in self.mappings_: topic = topics[-1] @@ -4731,6 +4694,50 @@ def add_mappings(self, mappings: Mapping[int, int]): else: topics.append(-1) + # When zero-shot topic(s) are present in the topics to merge, + # determine whether to take one of the zero-shot topic labels + # or use a calculated representation. + if topic_model._is_zeroshot() and len(topic_model._topic_id_to_zeroshot_topic_idx) > 0: + new_topic_id_to_zeroshot_topic_idx = {} + topics_to_map = { + topic_mapping[0]: topic_mapping[1] + for topic_mapping in np.array(topic_model.topic_mapper_.mappings_)[:, -2:] + } + + # Map topic_to to topics_from + mapping = defaultdict(list) + for key, value in topics_to_map.items(): + mapping[value].append(key) + + for topic_to, topics_from in mapping.items(): + # which of the original topics are zero-shot + zeroshot_topic_ids = [ + topic_id for topic_id in topics_from if topic_id in topic_model._topic_id_to_zeroshot_topic_idx + ] + if len(zeroshot_topic_ids) == 0: + continue + + # If any of the original topics are zero-shot, take the best fitting zero-shot label + # if the cosine similarity with the new topic exceeds the zero-shot threshold + zeroshot_labels = [ + topic_model.zeroshot_topic_list[topic_model._topic_id_to_zeroshot_topic_idx[topic_id]] + for topic_id in zeroshot_topic_ids + ] + zeroshot_embeddings = topic_model._extract_embeddings(zeroshot_labels) + cosine_similarities = cosine_similarity( + zeroshot_embeddings, [topic_model.topic_embeddings_[topic_to]] + ).flatten() + best_zeroshot_topic_idx = np.argmax(cosine_similarities) + best_cosine_similarity = cosine_similarities[best_zeroshot_topic_idx] + + if best_cosine_similarity >= topic_model.zeroshot_min_similarity: + # Using the topic ID from before mapping, get the idx into the zeroshot topic list + new_topic_id_to_zeroshot_topic_idx[topic_to] = topic_model._topic_id_to_zeroshot_topic_idx[ + zeroshot_topic_ids[best_zeroshot_topic_idx] + ] + + topic_model._topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx + def add_new_topics(self, mappings: Mapping[int, int]): """Add new row(s) of topic mappings.