1818import pandas as pd
1919import polars as pl
2020import scipy .sparse as sp
21- from copy import deepcopy
2221
2322
2423from pathlib import Path
@@ -1510,14 +1509,15 @@ def generate_topic_labels(
15101509
15111510 return topic_labels
15121511
1513- # TODO: Update
15141512 def merge_topics (
15151513 self ,
1516- docs : List [str ],
1517- topics_to_merge : List [ Union [ Iterable [int ], int ] ],
1518- images : List [str ] | None = None ,
1514+ docs : list [str ],
1515+ topics_to_merge : list [ Iterable [int ] | int ],
1516+ images : list [str ] | None = None ,
15191517 ) -> None :
1520- """Arguments:
1518+ """Merge multiple topics into a single topic.
1519+
1520+ Arguments:
15211521 docs: The documents you used when calling either `fit` or `fit_transform`
15221522 topics_to_merge: Either a list of topics or a list of list of topics
15231523 to merge. For example:
@@ -1546,15 +1546,8 @@ def merge_topics(
15461546 """
15471547 check_is_fitted (self )
15481548 check_documents_type (docs )
1549- documents = pd .DataFrame (
1550- {
1551- "Document" : docs ,
1552- "Topic" : self .topics_ ,
1553- "Image" : images ,
1554- "ID" : range (len (docs )),
1555- }
1556- )
15571549
1550+ # Build mapping: all topics map to themselves, except merged topics map to target
15581551 mapping = {topic : topic for topic in set (self .topics_ )}
15591552 if isinstance (topics_to_merge [0 ], int ):
15601553 for topic in sorted (topics_to_merge ):
@@ -1565,170 +1558,52 @@ def merge_topics(
15651558 mapping [topic ] = topic_group [0 ]
15661559 else :
15671560 raise ValueError (
1568- "Make sure that `topics_to_merge` is eithera list of topics or a list of list of topics."
1561+ "Make sure that `topics_to_merge` is either a list of topics or a list of list of topics."
15691562 )
15701563
1571- # Track mappings and sizes of topics for merging topic embeddings
1572- mappings = defaultdict ( list )
1573- for key , val in sorted ( mapping . items ()):
1574- mappings [ val ]. append ( key )
1575- mappings = {
1576- topic_to : {
1577- "topics_from" : topics_from ,
1578- "topic_sizes" : [ self .topic_sizes_ [ topic ] for topic in topics_from ] ,
1579- }
1580- for topic_to , topics_from in mappings . items ()
1581- }
1564+ # Build the Corpus BEFORE merging, using the current topic assignments.
1565+ # Each document is given the embedding of its current (pre-merge) topic, so
1566+ # after the merge the mean embedding per topic is a size-weighted average of the merged topics.
1567+ topic_embeddings = { topic . id : topic . embedding for topic in self . _topics }
1568+ doc_embeddings = np . array ([ topic_embeddings [ topic_id ] for topic_id in self . topics_ ])
1569+ corpus = Corpus (
1570+ documents = docs ,
1571+ topics = np . array ( self .topics_ ) ,
1572+ images = images ,
1573+ embeddings = doc_embeddings ,
1574+ )
15821575
1583- # Update topics
1584- documents .Topic = documents .Topic .map (mapping )
1585- self .topic_mapper_ .add_mappings (mapping , topic_model = self )
1586- documents = self ._sort_mappings_by_frequency (documents )
1587- self ._extract_representations (documents , mappings = mappings )
1588- self ._update_topic_size (documents )
1589- self ._save_representative_docs (documents )
1590- self .probabilities_ = self ._map_probabilities (self .probabilities_ )
1576+ # Merge topics (updates topic IDs and cumulative mapping)
1577+ self ._topics .merge (mapping )
15911578
1592- # TODO: Update
1593- def delete_topics (
1594- self ,
1595- topics_to_delete : List [int ],
1596- ) -> None :
1579+ # Map corpus topics to the merged state, sort topics by frequency, then map the corpus again after the sort
1580+ corpus .map_topics_and_probabilities (self ._topics , from_original = False )
1581+ self ._topics .sort_by_frequency ()
1582+ corpus .map_topics_and_probabilities (self ._topics , from_original = False )
1583+
1584+ # Recalculate representations from merged documents
1585+ self ._extract_representations (corpus )
1586+ self ._save_representative_docs (corpus )
1587+
1588+ def delete_topics (self , topics_to_delete : list [int ] | int ) -> None :
15971589 """Delete topics from the topic model.
15981590
1599- The deleted topics will be mapped to -1 (outlier topic). Core topic attributes
1600- like topic embeddings and c-TF-IDF will be automatically updated .
1591+ The deleted topics will be mapped to -1 (outlier topic). Document predictions
1592+ for deleted topics become -1. Remaining topics are renumbered by frequency .
16011593
16021594 Arguments:
1603- topics_to_delete: List of topics to delete
1595+ topics_to_delete: List of topic IDs to delete or a single topic ID.
1596+
1597+ Examples:
1598+ ```python
1599+ # Delete topics 3 and 5
1600+ topic_model.delete_topics([3, 5])
1601+ ```
16041602 """
16051603 check_is_fitted (self )
16061604
1607- topics_df = pd .DataFrame ({"Topic" : self .topics_ })
1608-
1609- # Check if -1 exists in the current topics
1610- had_outliers = - 1 in set (self .topics_ )
1611-
1612- # If adding -1 for the first time, initialize its attributes
1613- if not had_outliers and any (topic in topics_to_delete for topic in self .topics_ ):
1614- # Initialize c-TF-IDF for -1 topic (zeros)
1615- outlier_row = np .zeros ((1 , self .c_tf_idf_ .shape [1 ]))
1616- outlier_row = sp .csr_matrix (outlier_row )
1617- self .c_tf_idf_ = sp .vstack ([outlier_row , self .c_tf_idf_ ])
1618-
1619- # Initialize topic embeddings for -1 topic (zeros)
1620- outlier_embedding = np .zeros ((1 , self .topic_embeddings_ .shape [1 ]))
1621- self .topic_embeddings_ = np .vstack ([outlier_embedding , self .topic_embeddings_ ])
1622-
1623- # Initialize topic representations for -1 topic: ("", 1e-05)
1624- self .topic_representations_ [- 1 ] = [("" , 1e-05 )]
1625-
1626- # Initialize representative docs for -1 topic (empty list)
1627- self .representative_docs_ [- 1 ] = []
1628-
1629- # Initialize representative images for -1 topic if images are being used
1630- if self .representative_images_ is not None :
1631- outlier_image = np .zeros ((1 , self .representative_images_ .shape [1 ]))
1632- self .representative_images_ = np .vstack ([outlier_image , self .representative_images_ ])
1633-
1634- # Initialize custom labels for -1 topic if they exist
1635- if hasattr (self , "custom_labels_" ) and self .custom_labels_ is not None :
1636- self .custom_labels_ [- 1 ] = ""
1637-
1638- # Initialize ctfidf model diagonal for -1 topic (ones) if it exists
1639- if hasattr (self , "ctfidf_model" ) and self .ctfidf_model is not None :
1640- n_features = self .ctfidf_model ._idf_diag .shape [1 ]
1641- outlier_diag = sp .csr_matrix (([1.0 ], ([0 ], [0 ])), shape = (1 , n_features ))
1642- self .ctfidf_model ._idf_diag = sp .vstack ([outlier_diag , self .ctfidf_model ._idf_diag ])
1643-
1644- # Initialize topic aspects for -1 topic (empty dict for each aspect) if they exist
1645- if hasattr (self , "topic_aspects_" ) and self .topic_aspects_ is not None :
1646- for aspect in self .topic_aspects_ :
1647- self .topic_aspects_ [aspect ][- 1 ] = {}
1648-
1649- # First map deleted topics to -1
1650- mapping = {topic : - 1 if topic in topics_to_delete else topic for topic in set (self .topics_ )}
1651- mapping [- 1 ] = - 1
1652-
1653- # Track mappings and sizes of topics for merging topic embeddings
1654- mappings = defaultdict (list )
1655- for key , val in sorted (mapping .items ()):
1656- mappings [val ].append (key )
1657- mappings = {
1658- topic_to : {
1659- "topics_from" : topics_from ,
1660- "topic_sizes" : [self .topic_sizes_ [topic ] for topic in topics_from ],
1661- }
1662- for topic_to , topics_from in mappings .items ()
1663- }
1664-
1665- # remove deleted topics and update attributes
1666- topics_df .Topic = topics_df .Topic .map (mapping )
1667- self .topic_mapper_ .add_mappings (mapping , topic_model = deepcopy (self ))
1668- topics_df = self ._sort_mappings_by_frequency (topics_df )
1669- self ._update_topic_size (topics_df )
1670- self .probabilities_ = self ._map_probabilities (self .probabilities_ )
1671-
1672- final_mapping = self .topic_mapper_ .get_mappings (original_topics = False )
1673-
1674- # Update dictionary-based attributes to remove deleted topics
1675- # Handle topic_aspects_ if it exists
1676- if hasattr (self , "topic_aspects_" ) and self .topic_aspects_ is not None :
1677- new_aspects = {
1678- aspect : {
1679- (final_mapping [old_topic ] if old_topic != - 1 else - 1 ): content
1680- for old_topic , content in topics .items ()
1681- if old_topic not in topics_to_delete
1682- }
1683- for aspect , topics in self .topic_aspects_ .items ()
1684- }
1685- self .topic_aspects_ = new_aspects
1686-
1687- # Update custom labels if they exist
1688- if hasattr (self , "custom_labels_" ) and self .custom_labels_ is not None :
1689- new_labels = {
1690- (final_mapping [old_topic ] if old_topic != - 1 else - 1 ): label
1691- for old_topic , label in self .custom_labels_ .items ()
1692- if old_topic not in topics_to_delete
1693- }
1694- self .custom_labels_ = new_labels
1695-
1696- # Update topic representations
1697- new_representations = {
1698- (final_mapping [old_topic ] if old_topic != - 1 else - 1 ): content
1699- for old_topic , content in self .topic_representations_ .items ()
1700- if old_topic not in topics_to_delete
1701- }
1702- self .topic_representations_ = new_representations
1703-
1704- # Update representative docs if they exist
1705- new_representative_docs = {
1706- (final_mapping [old_topic ] if old_topic != - 1 else - 1 ): docs
1707- for old_topic , docs in self .representative_docs_ .items ()
1708- if old_topic not in topics_to_delete
1709- }
1710- self .representative_docs_ = new_representative_docs
1711-
1712- # Update representative images if they exist
1713- if self .representative_images_ is not None :
1714- # Create a mask for non-deleted topics
1715- mask = np .array (
1716- [topic not in topics_to_delete for topic in range (len (self .representative_images_ ))]
1717- )
1718- self .representative_images_ = self .representative_images_ [mask ] if mask .any () else None
1719-
1720- # Update array-based attributes using masks to remove deleted topics
1721- for attr in ["topic_embeddings_" , "c_tf_idf_" ]:
1722- matrix = getattr (self , attr )
1723- mask = np .array ([topic not in topics_to_delete for topic in range (matrix .shape [0 ])])
1724- setattr (self , attr , matrix [mask ])
1725-
1726- # Update ctfidf model to remove deleted topics if it exists
1727- if hasattr (self , "ctfidf_model" ) and self .ctfidf_model is not None :
1728- mask = np .array (
1729- [topic not in topics_to_delete for topic in range (self .ctfidf_model ._idf_diag .shape [0 ])]
1730- )
1731- self .ctfidf_model ._idf_diag = self .ctfidf_model ._idf_diag [mask ]
1605+ self ._topics .delete (topics_to_delete )
1606+ self ._topics .sort_by_frequency ()
17321607
17331608 def reduce_topics (
17341609 self ,
0 commit comments