Commit 984c08b

Merge branch 'main' into dependabot/github_actions/actions/upload-artifact-5
2 parents 0b32039 + 68678c0 commit 984c08b

27 files changed

Lines changed: 1298 additions & 33 deletions

CHANGELOG.md

Lines changed: 12 additions & 2 deletions
@@ -2,13 +2,21 @@
 
 ## Development
 
+### Introducing ProxySPEX
+Adds the ProxySPEX approximator for efficient computation of sparse interaction values.
+For further details refer to: Butler, L., Kang, J. S., Agarwal, A., Erginbas, Y. E., Yu, B., Ramchandran, K. (2025). ProxySPEX: Inference-Efficient Interpretability via Sparse Feature Interactions in LLMs. https://arxiv.org/pdf/2505.17495
+
+
+### Introducing ProductKernelExplainer
+The ProductKernelExplainer is a new model-specific explanation method for product-kernel-based machine learning models, such as Gaussian Processes or Support Vector Machines.
+For further details refer to: https://arxiv.org/abs/2505.16516
+
 ### Shapiq Statically Typechecked [#430](https://github.com/mmschlk/shapiq/pull/430)
 We have introduced static type checking to `shapiq` using [Pyright](https://github.com/microsoft/pyright), and integrated it into our `pre-commit` hooks.
 This ensures that type inconsistencies are caught early during development, improving code quality and maintainability.
 Developers will now benefit from immediate feedback on type errors, making the codebase more robust and reliable as it evolves.
 
 ### Separation of `shapiq` into `shapiq`, `shapiq_games`, and `shapiq-benchmark`
-
 We have begun the process of modularizing the `shapiq` package by splitting it into three distinct packages: `shapiq`, `shapiq_games`, and `shapiq-benchmark`.
 
 - The `shapiq` package now serves as the core library. It contains the main functionality, including approximators, explainers, computation routines, interaction value logic, and plotting utilities.
@@ -28,8 +36,10 @@ This restructuring aims to improve maintainability and development scalability.
 ### Bugfixes
 - fixes a bug where the RegressionFBII approximator threw an error when the index was `'BV'` or `'FBII'` [#420](https://github.com/mmschlk/shapiq/pull/420)
 
-### New Features
+### All New Features
 - adds the ProxySPEX (Proxy Sparse Explanation) module in `approximator.sparse` for even more efficient computation of sparse interaction values [#442](https://github.com/mmschlk/shapiq/pull/442)
+- uses the `predict_logits` method of sklearn-like classifiers, if available, in favor of `predict_proba` to support models that also offer logit outputs (such as TabPFNClassifier) for better interpretability of the explanations [#426](https://github.com/mmschlk/shapiq/issues/426)
+- adds `shapiq.explainer.ProductKernelExplainer` for model-specific explanation of product-kernel-based models like Gaussian Processes and Support Vector Machines [#431](https://github.com/mmschlk/shapiq/pull/431)
 
 ### Removed Features
 - removes the ability to load `InteractionValues` from pickle files. This is now deprecated and will be removed in the next release. Use `InteractionValues.save(..., as_json=True)` to save interaction values as JSON files instead. [#413](https://github.com/mmschlk/shapiq/issues/413)
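The `predict_logits`-over-`predict_proba` preference described in [#426] can be sketched as follows. This is a hedged illustration, not shapiq's implementation: the helper `get_prediction_function` and the toy `LogitModel` are hypothetical names introduced here.

```python
import math


class LogitModel:
    """Toy sklearn-like classifier exposing logits as well as probabilities."""

    def predict_logits(self, X):
        # one logit per class; purely illustrative
        return [[row[0], -row[0]] for row in X]

    def predict_proba(self, X):
        probs = []
        for l0, l1 in self.predict_logits(X):
            m = max(l0, l1)  # numerically stabilised softmax
            e0, e1 = math.exp(l0 - m), math.exp(l1 - m)
            probs.append([e0 / (e0 + e1), e1 / (e0 + e1)])
        return probs


def get_prediction_function(model):
    """Prefer logit outputs when the model offers them, else fall back to probabilities."""
    if hasattr(model, "predict_logits"):
        return model.predict_logits
    return model.predict_proba


fn = get_prediction_function(LogitModel())
# fn now returns raw logits, which are often easier to attribute additively
# than softmax-squashed probabilities.
```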

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -221,11 +221,11 @@ pythonPlatform = "Linux"
 
 [dependency-groups]
 all_ml = [
-    "tabpfn>=2.0.7",
+    "tabpfn>=2.1.3",
     "torchvision",
     "torch",
     "xgboost",
-    "lightgbm; platform_system != 'Darwin'", # lightgbm has install problems on macOS
+    "lightgbm; platform_system != 'Darwin'", # lightgbm has install problems on macOS in github actions
     "transformers",
     "scikit-image",
     "tensorflow; python_version < '3.13' and platform_system != 'Windows'", # only up to py 3.12

src/shapiq/approximator/regression/base.py

Lines changed: 4 additions & 1 deletion
@@ -111,7 +111,10 @@ def _init_kernel_weights(self, interaction_size: int) -> FloatVector:
             else:
                 weight_vector[coalition_size] = 1 / (
                     (self.n - 2 * interaction_size + 1)
-                    * binom(self.n - 2 * interaction_size, coalition_size - interaction_size)
+                    * binom(
+                        self.n - 2 * interaction_size,
+                        coalition_size - interaction_size,
+                    )
                 )
             return weight_vector
         msg = f"Index {self.index} not available for Regression Approximator."
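The reformatted line wrap above does not change the value being computed. As a sanity check, the weight can be written as a standalone sketch using `math.comb` in place of scipy's `binom` (the function name `kernel_weight` is introduced here for illustration; `n`, `interaction_size`, and `coalition_size` mirror the names in the diff):

```python
from math import comb


def kernel_weight(n: int, interaction_size: int, coalition_size: int) -> float:
    """Weight from the else-branch above: 1 / ((n - 2s + 1) * C(n - 2s, t - s))."""
    s, t = interaction_size, coalition_size
    return 1.0 / ((n - 2 * s + 1) * comb(n - 2 * s, t - s))


# e.g. n=10 players, interaction size 1, coalition size 3:
# 1 / ((10 - 2 + 1) * C(8, 2)) = 1 / (9 * 28) = 1 / 252
```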

src/shapiq/explainer/custom_types.py

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@
 from typing import Literal
 
 ExplainerIndices = Literal["SV", "SII", "k-SII", "STII", "FSII", "BV", "BII", "FBII"]
+ValidProductKernelExplainerIndices = Literal["SV"]
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+"""Implementation of the ProductKernelComputer and the ProductKernelExplainer."""
+
+from .base import ProductKernelModel
+from .explainer import ProductKernelExplainer
+from .product_kernel import ProductKernelComputer
+
+__all__ = ["ProductKernelModel", "ProductKernelExplainer", "ProductKernelComputer"]
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+"""The base class for product kernel model conversion."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import numpy as np
+
+
+@dataclass
+class ProductKernelModel:
+    """A dataclass for storing the information of a product kernel model.
+
+    Attributes:
+        alpha: The alpha parameter of the product kernel model.
+        X_train: The training data used to fit the product kernel model.
+        n: The number of samples in the training data.
+        d: The number of features in the training data.
+        gamma: The gamma parameter of the product kernel model.
+        intercept: The intercept term of the product kernel model. For Gaussian Processes this should be zero, but support vector machines often have non-zero intercepts.
+    """
+
+    X_train: np.ndarray
+    alpha: np.ndarray
+    n: int
+    d: int
+    gamma: float | None = None
+    kernel_type: str = "rbf"
+    intercept: float = 0.0
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+"""Functions for converting scikit-learn models to a format used by shapiq."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import numpy as np
+
+from shapiq.explainer.product_kernel.base import ProductKernelModel
+
+if TYPE_CHECKING:
+    from sklearn.gaussian_process import GaussianProcessRegressor
+    from sklearn.svm import SVC, SVR
+
+
+def convert_svm(model: SVC | SVR) -> ProductKernelModel:
+    """Converts a scikit-learn SVM model to the product kernel format used by shapiq.
+
+    Args:
+        model: The scikit-learn SVM model to convert. Can be either a binary support vector classifier (SVC) or a support vector regressor (SVR).
+
+    Returns:
+        ProductKernelModel: The converted model in the product kernel format.
+
+    """
+    X_train = model.support_vectors_
+    n, d = X_train.shape
+
+    if hasattr(model, "kernel"):
+        kernel_type = model.kernel  # pyright: ignore[reportAttributeAccessIssue]
+        if kernel_type != "rbf":
+            msg = "Currently only RBF kernel is supported for SVM models."
+            raise ValueError(msg)
+    else:
+        msg = "Kernel type not found in the model. Ensure the model is a valid SVC or SVR."
+        raise ValueError(msg)
+
+    return ProductKernelModel(
+        alpha=model.dual_coef_.flatten(),  # pyright: ignore[reportAttributeAccessIssue]
+        X_train=X_train,
+        n=n,
+        d=d,
+        gamma=model._gamma,  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]  # noqa: SLF001
+        kernel_type=kernel_type,
+        intercept=model.intercept_[0],
+    )
+
+
+def convert_gp_reg(model: GaussianProcessRegressor) -> ProductKernelModel:
+    """Converts a scikit-learn Gaussian Process Regression model to the product kernel format used by shapiq.
+
+    Args:
+        model: The scikit-learn Gaussian Process Regression model to convert.
+
+    Returns:
+        ProductKernelModel: The converted model in the product kernel format.
+
+    """
+    X_train = np.array(model.X_train_)
+    n, d = X_train.shape
+
+    if hasattr(model, "kernel"):
+        kernel_type = model.kernel_.__class__.__name__.lower()  # Get the kernel type as a string
+        if kernel_type != "rbf":
+            msg = "Currently only RBF kernel is supported for Gaussian Process Regression models."
+            raise ValueError(msg)
+    else:
+        msg = "Kernel type not found in the model. Ensure the model is a valid Gaussian Process Regressor."
+        raise ValueError(msg)
+
+    alphas = np.array(model.alpha_).flatten()
+    parameters = (
+        model.kernel_.get_params()  # pyright: ignore[reportAttributeAccessIssue]
+    )
+    if "length_scale" in parameters:
+        length_scale = parameters["length_scale"]
+    else:
+        msg = "Length scale parameter not found in the kernel."
+        raise ValueError(msg)
+
+    return ProductKernelModel(
+        alpha=alphas,
+        X_train=X_train,
+        n=n,
+        d=d,
+        gamma=(2 * (length_scale**2)) ** -1,
+        kernel_type=kernel_type,
+    )
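The `gamma=(2 * (length_scale**2)) ** -1` conversion in `convert_gp_reg` maps sklearn's GP-style RBF parameterisation, `exp(-||x - y||^2 / (2 * length_scale^2))`, onto the SVM-style `exp(-gamma * ||x - y||^2)` form used for SVMs. A small self-contained check of that identity (pure Python, no sklearn; the two function names are introduced here for illustration):

```python
import math


def rbf_length_scale(x, y, length_scale):
    """GP-style RBF kernel: exp(-||x - y||^2 / (2 * length_scale^2))."""
    sq = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq / (2 * length_scale**2))


def rbf_gamma(x, y, gamma):
    """SVM-style RBF kernel: exp(-gamma * ||x - y||^2)."""
    sq = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-gamma * sq)


length_scale = 1.5
gamma = (2 * length_scale**2) ** -1  # the same conversion as in convert_gp_reg
# both parameterisations now evaluate to the same kernel value for any x, y
```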
Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
+"""Implementation of the ProductKernelExplainer class."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from shapiq import InteractionValues
+from shapiq.explainer.base import Explainer
+from shapiq.game_theory import get_computation_index
+
+from .product_kernel import ProductKernelComputer, ProductKernelSHAPIQIndices
+from .validation import validate_pk_model
+
+if TYPE_CHECKING:
+    import numpy as np
+    from sklearn.gaussian_process import GaussianProcessRegressor
+    from sklearn.svm import SVC, SVR
+
+    from shapiq.typing import Model
+
+    from .base import ProductKernelModel
+
+
+class ProductKernelExplainer(Explainer):
+    """The ProductKernelExplainer class for product kernel-based models.
+
+    The ProductKernelExplainer can be used with a variety of product kernel-based models. The
+    explainer can handle both regression and classification models. See [pkex-shapley]_ for details.
+
+    References:
+        .. [pkex-shapley] Majid Mohammadi, Siu Lun Chau, and Krikamol Muandet. (2025). Computing Exact Shapley Values in Polynomial Time for Product-Kernel Methods. https://arxiv.org/abs/2505.16516
+
+    Attributes:
+        model: The product kernel model to explain. Can be a dictionary, a ProductKernelModel, or a list of ProductKernelModels.
+            Note that the model will be converted to a ProductKernelModel if it is not already in that format.
+            Supported models include scikit-learn's SVR, SVC (binary classification only), and GaussianProcessRegressor.
+            Beware that for classification models, the class to explain is set to the predicted class of the model.
+            For further details, see the `validate_pk_model` function in `shapiq.explainer.product_kernel.validation`.
+        max_order: The maximum interaction order to be computed. Defaults to ``1``.
+        min_order: The minimum interaction order to be computed. Defaults to ``0``.
+        index: The type of interaction to be computed. Currently, only ``"SV"`` is supported.
+    """
+
+    def __init__(
+        self,
+        model: (
+            ProductKernelModel | Model | SVR | SVC | GaussianProcessRegressor  # pyright: ignore[reportInvalidTypeVarUse]
+        ),
+        *,
+        min_order: int = 0,
+        max_order: int = 1,
+        index: ProductKernelSHAPIQIndices = "SV",
+        **kwargs: Any,  # noqa: ARG002
+    ) -> None:
+        """Initializes the ProductKernelExplainer.
+
+        Args:
+            model: A product kernel-based model to explain.
+
+            min_order: The minimum interaction order to be computed. Defaults to ``0``.
+
+            max_order: The maximum interaction order to be computed. An interaction order of ``1``
+                corresponds to the Shapley value. Defaults to ``1``.
+
+            index: The type of interaction to be computed. Currently, only ``"SV"`` is supported.
+
+            **kwargs: Additional keyword arguments are ignored.
+
+        """
+        if max_order > 1:
+            msg = "ProductKernelExplainer currently only supports max_order=1."
+            raise ValueError(msg)
+
+        super().__init__(model, index=index, max_order=max_order)
+
+        self._min_order = min_order
+        self._max_order = max_order
+
+        self._index = index
+        self._base_index: str = get_computation_index(self._index)
+
+        # validate model
+        self.converted_model = validate_pk_model(model)
+
+        self.explainer = ProductKernelComputer(
+            model=self.converted_model,
+            max_order=max_order,
+            index=index,
+        )
+
+        self.empty_prediction = self._compute_baseline_value()
+
+    def explain_function(
+        self,
+        x: np.ndarray,
+        **kwargs: Any,  # noqa: ARG002
+    ) -> InteractionValues:
+        """Compute Shapley values for all features of an instance.
+
+        Args:
+            x: The instance (1D array) for which to compute Shapley values.
+            **kwargs: Additional keyword arguments are ignored.
+
+        Returns:
+            The interaction values for the instance.
+        """
+        n_players = self.converted_model.d
+
+        # compute the kernel vectors for the instance x
+        kernel_vectors = self.explainer.compute_kernel_vectors(self.converted_model.X_train, x)
+
+        shapley_values = {}
+        for j in range(self.converted_model.d):
+            shapley_values[(j,)] = self.explainer.compute_shapley_value(kernel_vectors, j)
+
+        return InteractionValues(
+            values=shapley_values,
+            index=self._base_index,
+            min_order=self._min_order,
+            max_order=self.max_order,
+            n_players=n_players,
+            estimated=False,
+            baseline_value=self.empty_prediction,
+            target_index=self._index,
+        )
+
+    def _compute_baseline_value(self) -> float:
+        """Computes the baseline value for the explainer.
+
+        Returns:
+            The baseline value for the explainer.
+
+        """
+        return self.converted_model.alpha.sum() + self.converted_model.intercept
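`_compute_baseline_value` returns `alpha.sum() + intercept` because, for the empty coalition, the per-feature kernel product contains no factors and so equals 1. A minimal sketch of the assumed product-kernel form illustrates this (the `predict` helper and its toy data are introduced here for illustration and are not the shapiq implementation):

```python
import math


def predict(alpha, X_train, intercept, gamma, x, features):
    """f(x) = sum_i alpha_i * prod_{d in features} exp(-gamma * (x[d] - X_i[d])^2) + intercept."""
    total = intercept
    for a_i, x_i in zip(alpha, X_train):
        k = 1.0
        for d in features:  # the empty product over no features stays 1.0
            k *= math.exp(-gamma * (x[d] - x_i[d]) ** 2)
        total += a_i * k
    return total


alpha = [0.5, -0.2, 0.1]
X_train = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]

# with no active features the prediction collapses to sum(alpha) + intercept
baseline = predict(alpha, X_train, intercept=0.3, gamma=0.5, x=[0.0, 0.0], features=[])
```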
