Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5006,12 +5006,18 @@ def add_mappings(self, mappings: Mapping[int, int], topic_model: BERTopic):
def add_new_topics(self, mappings: Mapping[int, int]):
    """Add new row(s) of topic mappings.

    New topics did not exist at earlier states, so the intermediate
    history columns are backfilled with the topic's own ``key`` to
    keep ``mappings_`` a homogeneous integer matrix. Without this,
    ``None`` placeholders break ``model.save(serialization="safetensors")``
    which casts the matrix to ``np.array(..., dtype=int)``.

    Arguments:
        mappings: The mappings to add
    """
    # Every existing row has `length` columns; new rows must match that width.
    length = len(self.mappings_[0])
    for key, value in mappings.items():
        # Backfill all historical columns with the topic's own key so the
        # matrix stays homogeneous integers; only the last column carries
        # the current mapping target.
        to_append = [key] * (length - 1) + [value]
        self.mappings_.append(to_append)


Expand Down
33 changes: 33 additions & 0 deletions tests/test_bertopic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import copy
import pytest
import numpy as np
from bertopic import BERTopic
from bertopic._bertopic import TopicMapper
import importlib.util


Expand Down Expand Up @@ -153,3 +155,34 @@ def test_full_model(model, documents, request):
merged_model = BERTopic.merge_models([topic_model, topic_model1])

assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info())


def test_topic_mapper_add_new_topics_keeps_integer_matrix():
    """Regression test for #2432: ``TopicMapper.add_new_topics`` must keep
    ``mappings_`` as a homogeneous integer matrix.

    Previously, ``add_new_topics`` inserted ``None`` placeholders for the
    intermediate history columns, which broke ``model.save()`` because
    ``_save_utils.save_topics`` casts the matrix to ``np.array(..., dtype=int)``.
    """
    mapper = TopicMapper(topics=[-1, 0, 1, 2])
    # Widen the matrix beyond the 2-column default shape produced by
    # ``__init__`` — the buggy ``length - 2`` code path only misbehaves
    # once prior reduce_topics calls have added history columns.
    for _ in range(2):
        for row in mapper.mappings_:
            row.append(row[-1])
    snapshot = [list(row) for row in mapper.mappings_]

    # New clusters discovered during partial_fit
    mapper.add_new_topics({3: 2, 4: 3})

    # Must survive the same integer cast ``_save_utils.save_topics`` performs.
    matrix = np.array(mapper.mappings_, dtype=int)
    # Rows that existed before the call are left exactly as they were.
    assert mapper.mappings_[: len(snapshot)] == snapshot
    # Each new row: history columns backfilled with the topic's own key,
    # final column holding the current mapping target.
    assert (matrix[-2, :-1] == 3).all() and matrix[-2, -1] == 2
    assert (matrix[-1, :-1] == 4).all() and matrix[-1, -1] == 3
Loading