Skip to content

Commit e3cf6a9

Browse files
committed
Update merge_topics and delete_topics
1 parent 0298751 commit e3cf6a9

File tree

2 files changed

+112
-170
lines changed

2 files changed

+112
-170
lines changed

bertopic/_bertopic.py

Lines changed: 42 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import pandas as pd
1919
import polars as pl
2020
import scipy.sparse as sp
21-
from copy import deepcopy
2221

2322

2423
from pathlib import Path
@@ -1510,14 +1509,15 @@ def generate_topic_labels(
15101509

15111510
return topic_labels
15121511

1513-
# TODO: Update
15141512
def merge_topics(
15151513
self,
1516-
docs: List[str],
1517-
topics_to_merge: List[Union[Iterable[int], int]],
1518-
images: List[str] | None = None,
1514+
docs: list[str],
1515+
topics_to_merge: list[Iterable[int] | int],
1516+
images: list[str] | None = None,
15191517
) -> None:
1520-
"""Arguments:
1518+
"""Merge multiple topics into a single topic.
1519+
1520+
Arguments:
15211521
docs: The documents you used when calling either `fit` or `fit_transform`
15221522
topics_to_merge: Either a list of topics or a list of list of topics
15231523
to merge. For example:
@@ -1546,15 +1546,8 @@ def merge_topics(
15461546
"""
15471547
check_is_fitted(self)
15481548
check_documents_type(docs)
1549-
documents = pd.DataFrame(
1550-
{
1551-
"Document": docs,
1552-
"Topic": self.topics_,
1553-
"Image": images,
1554-
"ID": range(len(docs)),
1555-
}
1556-
)
15571549

1550+
# Build mapping: all topics map to themselves, except merged topics map to target
15581551
mapping = {topic: topic for topic in set(self.topics_)}
15591552
if isinstance(topics_to_merge[0], int):
15601553
for topic in sorted(topics_to_merge):
@@ -1565,170 +1558,52 @@ def merge_topics(
15651558
mapping[topic] = topic_group[0]
15661559
else:
15671560
raise ValueError(
1568-
"Make sure that `topics_to_merge` is eithera list of topics or a list of list of topics."
1561+
"Make sure that `topics_to_merge` is either a list of topics or a list of list of topics."
15691562
)
15701563

1571-
# Track mappings and sizes of topics for merging topic embeddings
1572-
mappings = defaultdict(list)
1573-
for key, val in sorted(mapping.items()):
1574-
mappings[val].append(key)
1575-
mappings = {
1576-
topic_to: {
1577-
"topics_from": topics_from,
1578-
"topic_sizes": [self.topic_sizes_[topic] for topic in topics_from],
1579-
}
1580-
for topic_to, topics_from in mappings.items()
1581-
}
1564+
# Build Corpus BEFORE merge with current topic assignments and embeddings.
1565+
# This ensures embeddings correspond to original topics, so after merge
1566+
# the mean embedding correctly produces a weighted average.
1567+
topic_embeddings = {topic.id: topic.embedding for topic in self._topics}
1568+
doc_embeddings = np.array([topic_embeddings[topic_id] for topic_id in self.topics_])
1569+
corpus = Corpus(
1570+
documents=docs,
1571+
topics=np.array(self.topics_),
1572+
images=images,
1573+
embeddings=doc_embeddings,
1574+
)
15821575

1583-
# Update topics
1584-
documents.Topic = documents.Topic.map(mapping)
1585-
self.topic_mapper_.add_mappings(mapping, topic_model=self)
1586-
documents = self._sort_mappings_by_frequency(documents)
1587-
self._extract_representations(documents, mappings=mappings)
1588-
self._update_topic_size(documents)
1589-
self._save_representative_docs(documents)
1590-
self.probabilities_ = self._map_probabilities(self.probabilities_)
1576+
# Merge topics (updates topic IDs and cumulative mapping)
1577+
self._topics.merge(mapping)
15911578

1592-
# TODO: Update
1593-
def delete_topics(
1594-
self,
1595-
topics_to_delete: List[int],
1596-
) -> None:
1579+
# Map corpus topics to match merged state and then sort by frequency
1580+
corpus.map_topics_and_probabilities(self._topics, from_original=False)
1581+
self._topics.sort_by_frequency()
1582+
corpus.map_topics_and_probabilities(self._topics, from_original=False)
1583+
1584+
# Recalculate representations from merged documents
1585+
self._extract_representations(corpus)
1586+
self._save_representative_docs(corpus)
1587+
1588+
def delete_topics(self, topics_to_delete: list[int] | int) -> None:
    """Delete topics from the topic model.

    The deleted topics will be mapped to -1 (outlier topic). Document predictions
    for deleted topics become -1. Remaining topics are renumbered by frequency.

    Arguments:
        topics_to_delete: List of topic IDs to delete or a single topic ID.

    Examples:
        ```python
        # Delete topics 3 and 5
        topic_model.delete_topics([3, 5])
        ```
    """
    check_is_fitted(self)

    # Map deleted topics to the outlier (-1) and drop their Topic objects,
    # then renumber the remaining topics by descending frequency.
    self._topics.delete(topics_to_delete)
    self._topics.sort_by_frequency()
17321607

17331608
def reduce_topics(
17341609
self,

bertopic/_topics.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __str__(self) -> str:
2424
class Keywords(TopicRepresentation):
2525
"""Weighted keywords representing a topic."""
2626

27-
data: list[tuple[str, float]] = field(default_factory=list)
27+
data: list[tuple[str, float]] = field(default_factory=lambda: [("", 1e-05)])
2828

2929
def top_n(self, n: int = 10) -> list[tuple[str, float]]:
3030
"""Get the top N keywords by score."""
@@ -257,8 +257,7 @@ def to_dict(self) -> dict:
257257
info[name] = str(rep)
258258

259259
# Representative documents and images
260-
if self.representative_documents:
261-
info["Representative_Docs"] = self.representative_documents
260+
info["Representative_Docs"] = self.representative_documents if self.representative_documents else [""]
262261
if self.representative_images is not None and self.representative_images.size > 0:
263262
info["Representative_Images"] = self.representative_images
264263

@@ -605,6 +604,74 @@ def merge(self, old_to_new: dict[int, int]) -> None:
605604
self.mapping.apply(old_to_new)
606605
self.add_action(TopicAction.MERGED)
607606

607+
def delete(self, topics: list[int] | int) -> None:
    """Delete topics by mapping them to the outlier topic (-1).

    This will:
    * Create an outlier topic if it doesn't exist
    * Map deleted topic predictions to -1
    * Remove deleted Topic objects from the collection
    * Update outlier's nr_documents with the sum of deleted topic counts

    Arguments:
        topics: List of topic IDs to delete or a single topic ID.
    """
    # Normalize to a set of topic IDs (accepts a bare int for convenience).
    # NOTE: the original performed this normalization twice (an if/else block
    # immediately followed by an equivalent one-liner); once is sufficient.
    topics = {topics} if isinstance(topics, int) else set(topics)

    # Total number of documents being reassigned to the outlier topic.
    # Unknown topic IDs are silently ignored, matching the guards below.
    deleted_doc_count = sum(
        self.topics[topic_id].nr_documents for topic_id in topics if topic_id in self.topics
    )

    # Create outlier topic if it doesn't exist
    if -1 not in self.topics:
        sample_topic = next(iter(self.topics.values()))

        # If exactly one existing topic is deleted, reuse its data for the
        # outlier so its embedding/representations are not lost.
        # (Guarding membership avoids a KeyError when the single requested
        # ID is not actually present in the collection.)
        single_id = next(iter(topics)) if len(topics) == 1 else None
        if single_id is not None and single_id in self.topics:
            selected_topic = self.topics[single_id]
            embedding = selected_topic.embedding
            c_tf_idf = selected_topic.c_tf_idf
            representative_documents = selected_topic.representative_documents
            representations = selected_topic.representations
        else:
            # Otherwise start the outlier from zero/empty placeholders shaped
            # like an existing topic's attributes.
            embedding_dim = sample_topic.embedding.shape[0]
            embedding = np.zeros(embedding_dim) if embedding_dim > 0 else np.array([])
            c_tf_idf_dim = sample_topic.c_tf_idf.shape[1]
            c_tf_idf = csr_matrix((1, c_tf_idf_dim)) if c_tf_idf_dim > 0 else csr_matrix([])
            representative_documents = [""]
            representations = {name: type(rep)() for name, rep in sample_topic.representations.items()}

        # Build outlier topic
        self.topics[-1] = Topic(
            id=-1,
            representations=representations,
            representative_documents=representative_documents,
            embedding=embedding,
            c_tf_idf=c_tf_idf,
            topic_type=TopicType.OUTLIER,
            nr_documents=deleted_doc_count,
        )
    else:
        self.topics[-1].nr_documents += deleted_doc_count

    # Build mapping: deleted -> -1, others -> themselves
    old_to_new = {topic_id: -1 if topic_id in topics else topic_id for topic_id in self.topics}
    old_to_new[-1] = -1

    # Remove deleted Topic objects; IDs that were never present are ignored.
    for topic_id in topics:
        self.topics.pop(topic_id, None)

    self.mapping.apply(old_to_new)
    self.add_action(TopicAction.DELETED)
674+
608675
def map_predictions(self, predictions: list[int], from_original: bool) -> list[int]:
609676
"""Map a list of original predictions to current IDs.
610677

0 commit comments

Comments
 (0)