|
| 1 | +"""Integration tests for the public shapiq user flows. |
| 2 | +
|
| 3 | +These tests mirror the canonical flows documented in ``README.md`` and |
| 4 | +``docs/source/introduction/start.rst`` — load data, fit a real sklearn |
| 5 | +model, build a shapiq explainer, call ``.explain()``, then consume the |
| 6 | +``InteractionValues`` downstream (serialisation, plots). Their role is to |
| 7 | +catch regressions in the seams between sklearn, the imputer, the |
| 8 | +approximator, ``InteractionValues`` and the plot module — a class of bug |
| 9 | +that the per-module unit tests can miss. |
| 10 | +""" |
| 11 | + |
| 12 | +from __future__ import annotations |
| 13 | + |
| 14 | +import matplotlib as mpl |
| 15 | + |
| 16 | +mpl.use("Agg") |
| 17 | + |
| 18 | +import matplotlib.pyplot as plt |
| 19 | +import numpy as np |
| 20 | +import pytest |
| 21 | +from sklearn.ensemble import RandomForestRegressor |
| 22 | + |
| 23 | +import shapiq |
| 24 | + |
| 25 | +from .conftest import assert_iv_close |
| 26 | + |
| 27 | + |
| 28 | +@pytest.fixture(scope="session") |
| 29 | +def california_rf(): |
| 30 | + """Real sklearn RF on subsampled California housing — matches README.""" |
| 31 | + x_data, y_data = shapiq.load_california_housing(to_numpy=True) |
| 32 | + rng = np.random.default_rng(42) |
| 33 | + idx = rng.choice(len(x_data), size=150, replace=False) |
| 34 | + x_data, y_data = x_data[idx], y_data[idx] |
| 35 | + model = RandomForestRegressor(n_estimators=10, max_depth=4, random_state=42, n_jobs=1) |
| 36 | + model.fit(x_data, y_data) |
| 37 | + return model, x_data, y_data |
| 38 | + |
| 39 | + |
| 40 | +@pytest.mark.parametrize( |
| 41 | + ("index", "max_order"), |
| 42 | + [ |
| 43 | + ("SV", 1), |
| 44 | + ("k-SII", 2), |
| 45 | + ("FSII", 2), |
| 46 | + ("STII", 3), |
| 47 | + ], |
| 48 | +) |
| 49 | +def test_tabular_explainer_readme_flow(california_rf, index, max_order): |
| 50 | + """README quickstart through TabularExplainer, parametrised across indices. |
| 51 | +
|
| 52 | + Asserts the **efficiency axiom** that all four indices satisfy by design: |
| 53 | + summing the interaction values across all subsets up to ``max_order`` -- |
| 54 | + ``InteractionValues.values`` includes the empty-coalition entry storing |
| 55 | + the baseline -- recovers the model prediction on ``x``. |
| 56 | + """ |
| 57 | + model, x_data, _ = california_rf |
| 58 | + explainer = shapiq.TabularExplainer( |
| 59 | + model=model, |
| 60 | + data=x_data, |
| 61 | + index=index, |
| 62 | + max_order=max_order, |
| 63 | + random_state=42, |
| 64 | + ) |
| 65 | + iv = explainer.explain(x_data[0], budget=256) |
| 66 | + |
| 67 | + assert isinstance(iv, shapiq.InteractionValues) |
| 68 | + assert iv.index == index |
| 69 | + assert iv.max_order == max_order |
| 70 | + assert iv.n_players == x_data.shape[1] |
| 71 | + assert np.all(np.isfinite(iv.values)) |
| 72 | + |
| 73 | + pred = float(model.predict(x_data[:1])[0]) |
| 74 | + assert iv.values.sum() == pytest.approx(pred, abs=1e-4) |
| 75 | + |
| 76 | + |
| 77 | +@pytest.mark.parametrize( |
| 78 | + ("index", "max_order"), |
| 79 | + [ |
| 80 | + ("SV", 1), |
| 81 | + ("k-SII", 2), |
| 82 | + ], |
| 83 | +) |
| 84 | +def test_tree_explainer_efficiency(california_rf, index, max_order): |
| 85 | + """TreeExplainer pointwise efficiency — holds exactly for tree models.""" |
| 86 | + model, x_data, _ = california_rf |
| 87 | + x = x_data[0] |
| 88 | + iv = shapiq.TreeExplainer(model=model, index=index, max_order=max_order).explain(x) |
| 89 | + |
| 90 | + pred = float(model.predict(x.reshape(1, -1))[0]) |
| 91 | + assert iv.values.sum() == pytest.approx(pred, abs=1e-4) |
| 92 | + |
| 93 | + |
| 94 | +def test_agnostic_explainer_on_soum(soum_7, exact_soum_7): |
| 95 | + """AgnosticExplainer on a Game — researcher flow against exact ground truth.""" |
| 96 | + iv = shapiq.AgnosticExplainer(game=soum_7, index="k-SII", max_order=2, random_state=42).explain( |
| 97 | + budget=2**7 |
| 98 | + ) |
| 99 | + |
| 100 | + ground_truth = exact_soum_7("k-SII", order=2) |
| 101 | + assert_iv_close(iv, ground_truth, atol=1e-6) |
| 102 | + |
| 103 | + |
| 104 | +def test_interaction_values_roundtrip_and_plots(california_rf, tmp_path): |
| 105 | + """End-to-end consumption: serialise, reload, plot across all plot surfaces.""" |
| 106 | + model, x_data, _ = california_rf |
| 107 | + |
| 108 | + iv = shapiq.TabularExplainer( |
| 109 | + model=model, data=x_data, index="k-SII", max_order=2, random_state=42 |
| 110 | + ).explain(x_data[0], budget=256) |
| 111 | + |
| 112 | + path = tmp_path / "iv.json" |
| 113 | + iv.save(path) |
| 114 | + iv_loaded = shapiq.InteractionValues.load(path) |
| 115 | + assert np.allclose(iv.values, iv_loaded.values) |
| 116 | + assert iv.index == iv_loaded.index |
| 117 | + assert iv.max_order == iv_loaded.max_order |
| 118 | + |
| 119 | + feature_names = [f"f{i}" for i in range(x_data.shape[1])] |
| 120 | + assert shapiq.network_plot(iv, feature_names=feature_names) is not None |
| 121 | + assert shapiq.stacked_bar_plot(iv, feature_names=feature_names) is not None |
| 122 | + assert shapiq.bar_plot([iv], feature_names=feature_names) is not None |
| 123 | + assert shapiq.force_plot(iv.get_n_order(order=1), feature_names=feature_names) is not None |
| 124 | + assert shapiq.upset_plot(iv, feature_names=feature_names) is not None |
| 125 | + plt.close("all") |
0 commit comments