Skip to content

Commit 0383194

Browse files
committed
fix(plot): validate input shape in beeswarm_plot and sentence_plot
Both functions previously accepted malformed input silently: - beeswarm_plot with data.shape[1] != n_players would plot a subset or scramble columns without warning. - sentence_plot with len(words) != n_players would index past the InteractionValues or drop entries silently. Each gets a ValueError guard with a clear message. Re-enables the two dropped edge-case tests in TestPlotEdgeCases. https://claude.ai/code/session_01DHsGf4an1Dnnw4qTnmdB22
1 parent 8b32ad8 commit 0383194

3 files changed

Lines changed: 27 additions & 0 deletions

File tree

src/shapiq/plot/beeswarm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,12 @@ def beeswarm_plot(
202202
n_samples = len(data)
203203
n_players = interaction_values_list[0].n_players
204204

205+
if data.shape[1] != n_players:
206+
error_message = (
207+
f"data must have {n_players} columns to match n_players, but got {data.shape[1]}."
208+
)
209+
raise ValueError(error_message)
210+
205211
if feature_names is not None:
206212
if abbreviate:
207213
feature_names = abbreviate_feature_names(feature_names)

src/shapiq/plot/sentence.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,13 @@ def sentence_plot(
9191
:align: center
9292
9393
"""
94+
if len(words) != interaction_values.n_players:
95+
error_message = (
96+
f"Number of words ({len(words)}) must match "
97+
f"interaction_values.n_players ({interaction_values.n_players})."
98+
)
99+
raise ValueError(error_message)
100+
94101
# set all the size parameters
95102
fontsize = 20
96103
word_spacing = 15

tests/shapiq/test_plots.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,17 @@ def test_max_display_below_n_features(self, sample_iv):
421421
max_display=2,
422422
)
423423
assert ax is not None
424+
425+
def test_sentence_plot_word_count_mismatch_raises(self):
426+
"""``sentence_plot`` must raise when len(words) != n_players."""
427+
iv = _build_iv(n=3).get_n_order(order=1)
428+
with pytest.raises(ValueError, match="must match"):
429+
sentence_plot(iv, words=["only-one"])
430+
with pytest.raises(ValueError, match="must match"):
431+
sentence_plot(iv, words=["a", "b", "c", "d", "e"])
432+
433+
def test_beeswarm_plot_data_column_mismatch_raises(self, sample_iv_list):
434+
"""``beeswarm_plot`` must raise when data has wrong number of columns."""
435+
bad_data = np.random.default_rng(0).normal(size=(len(sample_iv_list), 99))
436+
with pytest.raises(ValueError, match="columns"):
437+
beeswarm_plot(sample_iv_list, data=bad_data, show=False)

0 commit comments

Comments
 (0)