Skip to content

Commit 7fcb36f

Browse files
authored
Merge pull request #431 from mmschlk/product-kernel-explainer
Product kernel explainer
2 parents 7d31eb3 + d522cae commit 7fcb36f

File tree

24 files changed

+998
-16
lines changed

24 files changed

+998
-16
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
## Development
44

5+
### Introducing ProductKernelExplainer
6+
7+
The ProductKernelExplainer is a new model-specific explanation method for product-kernel-based machine learning models, such as Gaussian Processes or Support Vector Machines.
8+
9+
For further details refer to: https://arxiv.org/abs/2505.16516
10+
511
### Shapiq Statically Typechecked [#430](https://github.com/mmschlk/shapiq/pull/430)
612
We have introduced static type checking to `shapiq` using [Pyright](https://github.com/microsoft/pyright), and integrated it into our `pre-commit` hooks.
713
This ensures that type inconsistencies are caught early during development, improving code quality and maintainability.

src/shapiq/approximator/regression/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,10 @@ def _init_kernel_weights(self, interaction_size: int) -> FloatVector:
111111
else:
112112
weight_vector[coalition_size] = 1 / (
113113
(self.n - 2 * interaction_size + 1)
114-
* binom(self.n - 2 * interaction_size, coalition_size - interaction_size)
114+
* binom(
115+
self.n - 2 * interaction_size,
116+
coalition_size - interaction_size,
117+
)
115118
)
116119
return weight_vector
117120
msg = f"Index {self.index} not available for Regression Approximator."

src/shapiq/explainer/custom_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
from typing import Literal
66

77
# Names of the interaction indices that can be used as an explainer ``index``.
ExplainerIndices = Literal[
    "SV",
    "SII",
    "k-SII",
    "STII",
    "FSII",
    "BV",
    "BII",
    "FBII",
]

# Index names for the product kernel explainer; only "SV" is listed.
ValidProductKernelExplainerIndices = Literal["SV"]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Implementation of the ProductKernelComputer and the ProductKernelExplainer."""
2+
3+
from .base import ProductKernelModel
4+
from .explainer import ProductKernelExplainer
5+
from .product_kernel import ProductKernelComputer
6+
7+
__all__ = ["ProductKernelModel", "ProductKernelExplainer", "ProductKernelComputer"]
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""The base class for product kernel model conversion."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from typing import TYPE_CHECKING
7+
8+
if TYPE_CHECKING:
9+
import numpy as np
10+
11+
12+
@dataclass
class ProductKernelModel:
    """Container holding everything shapiq needs to know about a product kernel model.

    Attributes:
        X_train: The training data used to fit the product kernel model.
        alpha: The alpha parameter of the product kernel model.
        n: The number of samples in the training data.
        d: The number of features in the training data.
        gamma: The gamma parameter of the product kernel model. Defaults to ``None``.
        kernel_type: The name of the kernel. Defaults to ``"rbf"``.
        intercept: The intercept term of the product kernel model. For Gaussian Processes this
            should be zero, whereas support vector machines often have a non-zero intercept.
            Defaults to ``0.0``.
    """

    # Required fields (positional order is part of the constructor interface).
    X_train: np.ndarray
    alpha: np.ndarray
    n: int
    d: int

    # Optional fields with defaults.
    gamma: float | None = None
    kernel_type: str = "rbf"
    intercept: float = 0.0
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Functions for converting scikit-learn models to a format used by shapiq."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
import numpy as np
8+
9+
from shapiq.explainer.product_kernel.base import ProductKernelModel
10+
11+
if TYPE_CHECKING:
12+
from sklearn.gaussian_process import GaussianProcessRegressor
13+
from sklearn.svm import SVC, SVR
14+
15+
16+
def convert_svm(model: SVC | SVR) -> ProductKernelModel:
    """Translate a scikit-learn SVM into the product kernel format used by shapiq.

    Args:
        model: The scikit-learn SVM model to translate. Can be either a binary support vector
            classifier (SVC) or a support vector regressor (SVR).

    Returns:
        ProductKernelModel: The model in the product kernel format, built from the support
            vectors, the flattened dual coefficients, the gamma value and the first intercept.

    Raises:
        ValueError: If the model has no ``kernel`` attribute or its kernel is not ``"rbf"``.

    """
    support_vectors = model.support_vectors_
    n_support, n_features = support_vectors.shape

    # Guard clauses: reject everything that is not an RBF-kernel SVM before converting.
    if not hasattr(model, "kernel"):
        msg = "Kernel type not found in the model. Ensure the model is a valid SVC or SVR."
        raise ValueError(msg)

    kernel_name = model.kernel  # pyright: ignore[reportAttributeAccessIssue]
    if kernel_name != "rbf":
        msg = "Currently only RBF kernel is supported for SVM models."
        raise ValueError(msg)

    dual_coefficients = model.dual_coef_.flatten()  # pyright: ignore[reportAttributeAccessIssue]
    gamma_value = model._gamma  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]  # noqa: SLF001
    first_intercept = model.intercept_[0]

    return ProductKernelModel(
        alpha=dual_coefficients,
        X_train=support_vectors,
        n=n_support,
        d=n_features,
        gamma=gamma_value,
        kernel_type=kernel_name,
        intercept=first_intercept,
    )
47+
48+
49+
def convert_gp_reg(model: GaussianProcessRegressor) -> ProductKernelModel:
    """Converts a scikit-learn Gaussian Process Regression model to the product kernel format used by shapiq.

    Args:
        model: The scikit-learn Gaussian Process Regression model to convert. The model must
            expose the fitted kernel as ``kernel_`` as well as ``X_train_`` and ``alpha_``.

    Returns:
        ProductKernelModel: The converted model in the product kernel format. ``gamma`` is
            derived from the kernel's length scale as ``1 / (2 * length_scale**2)``.

    Raises:
        ValueError: If the model has no fitted kernel (``kernel_``), if the fitted kernel is not
            an RBF kernel, or if the kernel does not expose a ``length_scale`` parameter.

    """
    # The kernel that is read below is ``kernel_``; checking for ``kernel`` (the constructor
    # argument) would not guarantee that ``kernel_`` is available.
    fitted_kernel = getattr(model, "kernel_", None)
    if fitted_kernel is None:
        msg = "Kernel type not found in the model. Ensure the model is a valid Gaussian Process Regressor."
        raise ValueError(msg)

    kernel_type = fitted_kernel.__class__.__name__.lower()  # Get the kernel type as a string
    if kernel_type != "rbf":
        msg = "Currently only RBF kernel is supported for Gaussian Process Regression models."
        raise ValueError(msg)

    X_train = np.array(model.X_train_)
    n, d = X_train.shape

    alphas = np.array(model.alpha_).flatten()
    parameters = fitted_kernel.get_params()
    if "length_scale" not in parameters:
        msg = "Length scale parameter not found in the kernel."
        raise ValueError(msg)
    length_scale = parameters["length_scale"]

    return ProductKernelModel(
        alpha=alphas,
        X_train=X_train,
        n=n,
        d=d,
        gamma=(2 * (length_scale**2)) ** -1,
        kernel_type=kernel_type,
    )
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Implementation of the ProductKernelExplainer class."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING, Any
6+
7+
from shapiq import InteractionValues
8+
from shapiq.explainer.base import Explainer
9+
from shapiq.game_theory import get_computation_index
10+
11+
from .product_kernel import ProductKernelComputer, ProductKernelSHAPIQIndices
12+
from .validation import validate_pk_model
13+
14+
if TYPE_CHECKING:
15+
import numpy as np
16+
from sklearn.gaussian_process import GaussianProcessRegressor
17+
from sklearn.svm import SVC, SVR
18+
19+
from shapiq.typing import Model
20+
21+
from .base import ProductKernelModel
22+
23+
24+
class ProductKernelExplainer(Explainer):
    """The ProductKernelExplainer class for product kernel-based models.

    The ProductKernelExplainer can be used with a variety of product kernel-based models. The
    explainer can handle both regression and classification models. See [pkex-shapley]_ for
    details.

    Supported models include scikit-learn's SVR, SVC (binary classification only), and
    GaussianProcessRegressor. Beware that for classification models, the class to explain is set
    to the predicted class of the model. The model is passed through ``validate_pk_model`` (see
    ``shapiq.explainer.product_kernel.validation``) and the result is stored as
    ``converted_model``.

    References:
        .. [pkex-shapley] Majid Mohammadi, Siu Lun Chau, and Krikamol Muandet. (2025). Computing Exact Shapley Values in Polynomial Time for Product-Kernel Methods. https://arxiv.org/abs/2505.16516

    Attributes:
        converted_model: The result of ``validate_pk_model(model)``; its ``d``, ``X_train``,
            ``alpha`` and ``intercept`` fields are read by this class.
        explainer: The ``ProductKernelComputer`` that performs the actual computations.
        empty_prediction: The baseline value, i.e. the sum of ``alpha`` plus the intercept of
            the converted model.
    """

    def __init__(
        self,
        model: (
            ProductKernelModel | Model | SVR | SVC | GaussianProcessRegressor  # pyright: ignore[reportInvalidTypeVarUse]
        ),
        *,
        min_order: int = 0,
        max_order: int = 1,
        index: ProductKernelSHAPIQIndices = "SV",
        **kwargs: Any,  # noqa: ARG002
    ) -> None:
        """Initializes the ProductKernelExplainer.

        Args:
            model: A product kernel-based model to explain.

            min_order: The minimum interaction order to be computed. Defaults to ``0``.

            max_order: The maximum interaction order to be computed. An interaction order of ``1``
                corresponds to the Shapley value. Defaults to ``1``.

            index: The type of interaction to be computed. Currently, only ``"SV"`` is supported.

            **kwargs: Additional keyword arguments are ignored.

        Raises:
            ValueError: If ``max_order`` is greater than ``1``.

        """
        if max_order > 1:
            msg = "ProductKernelExplainer currently only supports max_order=1."
            raise ValueError(msg)

        super().__init__(model, index=index, max_order=max_order)

        self._min_order = min_order
        self._max_order = max_order

        self._index = index
        self._base_index: str = get_computation_index(self._index)

        # validate model
        self.converted_model = validate_pk_model(model)

        self.explainer = ProductKernelComputer(
            model=self.converted_model,
            max_order=max_order,
            index=index,
        )

        self.empty_prediction = self._compute_baseline_value()

    def explain_function(
        self,
        x: np.ndarray,
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Compute Shapley values for all features of an instance.

        Args:
            x: The instance (1D array) for which to compute Shapley values.
            **kwargs: Additional keyword arguments are ignored.

        Returns:
            The interaction values for the instance, holding one value per feature keyed by the
            single-element tuple ``(feature_index,)``.
        """
        n_players = self.converted_model.d

        # compute the kernel vectors for the instance x once and reuse them for every feature
        kernel_vectors = self.explainer.compute_kernel_vectors(self.converted_model.X_train, x)

        shapley_values = {
            (player,): self.explainer.compute_shapley_value(kernel_vectors, player)
            for player in range(n_players)
        }

        return InteractionValues(
            values=shapley_values,
            index=self._base_index,
            min_order=self._min_order,
            max_order=self.max_order,
            n_players=n_players,
            estimated=False,
            baseline_value=self.empty_prediction,
            target_index=self._index,
        )

    def _compute_baseline_value(self) -> float:
        """Computes the baseline value for the explainer.

        Returns:
            The sum of the converted model's ``alpha`` values plus its intercept.

        """
        return self.converted_model.alpha.sum() + self.converted_model.intercept

0 commit comments

Comments
 (0)