Skip to content

Commit b5f4eea

Browse files
committed
add end-to-end integration tests for public user flows
These tests exercise the full seam — sklearn → imputer → approximator → InteractionValues → plots/serialisation — mirroring the canonical flows from README.md and docs/source/introduction/start.rst. They catch cross-module regressions that pass every per-module unit test. Coverage (8 test invocations, <2s total): - test_tabular_explainer_readme_flow (parametrised SV / k-SII / FSII / STII), asserts the efficiency axiom holds end-to-end - test_tree_explainer_efficiency (parametrised SV / k-SII), asserts pointwise efficiency for TreeExplainer - test_agnostic_explainer_on_soum, verifies the Game-based researcher path against ExactComputer ground truth - test_interaction_values_roundtrip_and_plots, covers JSON save/load and all five top-level plot functions https://claude.ai/code/session_01DHsGf4an1Dnnw4qTnmdB22
1 parent d3b422b commit b5f4eea

1 file changed

Lines changed: 125 additions & 0 deletions

File tree

tests/shapiq/test_integration.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Integration tests for the public shapiq user flows.
2+
3+
These tests mirror the canonical flows documented in ``README.md`` and
4+
``docs/source/introduction/start.rst`` — load data, fit a real sklearn
5+
model, build a shapiq explainer, call ``.explain()``, then consume the
6+
``InteractionValues`` downstream (serialisation, plots). Their role is to
7+
catch regressions in the seams between sklearn, the imputer, the
8+
approximator, ``InteractionValues`` and the plot module — a class of bug
9+
that the per-module unit tests can miss.
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import matplotlib as mpl
15+
16+
mpl.use("Agg")
17+
18+
import matplotlib.pyplot as plt
19+
import numpy as np
20+
import pytest
21+
from sklearn.ensemble import RandomForestRegressor
22+
23+
import shapiq
24+
25+
from .conftest import assert_iv_close
26+
27+
28+
@pytest.fixture(scope="session")
def california_rf():
    """Session-scoped fixture mirroring the README setup: a small
    RandomForestRegressor fitted on a 150-row subsample of the California
    housing data (subsample keeps the whole suite well under two seconds)."""
    features, targets = shapiq.load_california_housing(to_numpy=True)
    # Deterministic subsample so every test run sees the same model.
    sample_idx = np.random.default_rng(42).choice(len(features), size=150, replace=False)
    features = features[sample_idx]
    targets = targets[sample_idx]
    forest = RandomForestRegressor(n_estimators=10, max_depth=4, random_state=42, n_jobs=1)
    forest.fit(features, targets)
    return forest, features, targets
38+
39+
40+
@pytest.mark.parametrize(
    ("index", "max_order"),
    [("SV", 1), ("k-SII", 2), ("FSII", 2), ("STII", 3)],
)
def test_tabular_explainer_readme_flow(california_rf, index, max_order):
    """README quickstart via TabularExplainer, run for each interaction index.

    Each of the four indices satisfies the efficiency axiom by construction:
    summing every entry of ``InteractionValues.values`` up to ``max_order``
    (the array carries the empty-coalition baseline entry) recovers the model
    prediction for the explained point.
    """
    model, x_data, _ = california_rf

    explainer = shapiq.TabularExplainer(
        model=model,
        data=x_data,
        index=index,
        max_order=max_order,
        random_state=42,
    )
    iv = explainer.explain(x_data[0], budget=256)

    # Structural sanity of the returned InteractionValues object.
    assert isinstance(iv, shapiq.InteractionValues)
    assert (iv.index, iv.max_order) == (index, max_order)
    assert iv.n_players == x_data.shape[1]
    assert np.isfinite(iv.values).all()

    # Efficiency axiom: interaction values sum to the model prediction.
    prediction = float(model.predict(x_data[:1])[0])
    assert iv.values.sum() == pytest.approx(prediction, abs=1e-4)
75+
76+
77+
@pytest.mark.parametrize(
    ("index", "max_order"),
    [("SV", 1), ("k-SII", 2)],
)
def test_tree_explainer_efficiency(california_rf, index, max_order):
    """Pointwise efficiency for TreeExplainer — exact for tree ensembles."""
    model, x_data, _ = california_rf
    point = x_data[0]

    explainer = shapiq.TreeExplainer(model=model, index=index, max_order=max_order)
    iv = explainer.explain(point)

    expected = float(model.predict(point.reshape(1, -1))[0])
    assert iv.values.sum() == pytest.approx(expected, abs=1e-4)
92+
93+
94+
def test_agnostic_explainer_on_soum(soum_7, exact_soum_7):
    """Researcher flow: AgnosticExplainer on a raw Game vs. exact ground truth."""
    explainer = shapiq.AgnosticExplainer(
        game=soum_7, index="k-SII", max_order=2, random_state=42
    )
    # A full budget of 2**7 coalitions lets the estimate match exactly.
    estimate = explainer.explain(budget=2**7)

    expected = exact_soum_7("k-SII", order=2)
    assert_iv_close(estimate, expected, atol=1e-6)
102+
103+
104+
def test_interaction_values_roundtrip_and_plots(california_rf, tmp_path):
    """Downstream consumption of InteractionValues: JSON round-trip and plots."""
    model, x_data, _ = california_rf

    explainer = shapiq.TabularExplainer(
        model=model, data=x_data, index="k-SII", max_order=2, random_state=42
    )
    iv = explainer.explain(x_data[0], budget=256)

    # Serialise to JSON and reload; the reloaded object must be equivalent.
    save_path = tmp_path / "iv.json"
    iv.save(save_path)
    restored = shapiq.InteractionValues.load(save_path)
    assert np.allclose(iv.values, restored.values)
    assert (iv.index, iv.max_order) == (restored.index, restored.max_order)

    # Every top-level plot surface must accept the explanation object.
    names = [f"f{i}" for i in range(x_data.shape[1])]
    assert shapiq.network_plot(iv, feature_names=names) is not None
    assert shapiq.stacked_bar_plot(iv, feature_names=names) is not None
    assert shapiq.bar_plot([iv], feature_names=names) is not None
    assert shapiq.force_plot(iv.get_n_order(order=1), feature_names=names) is not None
    assert shapiq.upset_plot(iv, feature_names=names) is not None
    plt.close("all")

0 commit comments

Comments
 (0)