Skip to content
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
cfc9bab
add logic for ProductKernel explanation
IsaH57 Jul 12, 2025
ca7c365
add code to run examples, minor refactoring
IsaH57 Jul 25, 2025
2d1b8d3
Merge remote-tracking branch 'refs/remotes/origin/main' into product-…
IsaH57 Aug 6, 2025
09f49a6
merge main
IsaH57 Aug 7, 2025
43a729c
refactor code
IsaH57 Aug 7, 2025
49b9eaa
make RBF optional, add files for game, fix bugs
IsaH57 Aug 7, 2025
713744f
clean up
IsaH57 Aug 15, 2025
ab6dd5f
make explain_function return InteractionValues object
IsaH57 Aug 15, 2025
d4a2d30
add init
IsaH57 Aug 15, 2025
b482ef2
Added KernelGame
Advueu963 Aug 16, 2025
8f1598d
add first product kernel tests
IsaH57 Aug 20, 2025
9cb0f71
add product kernel tests against exact computer
IsaH57 Aug 20, 2025
c178f06
add product kernel integration test
IsaH57 Aug 21, 2025
489cc7e
working kernel
Advueu963 Aug 26, 2025
124db3d
update tests and codebase
IsaH57 Aug 28, 2025
18f48e9
Removed src imports
Advueu963 Aug 28, 2025
cbc61d4
Added Baseline Computation
Advueu963 Aug 28, 2025
43a55af
Fixing Integration test
Advueu963 Aug 28, 2025
331586d
Refactor
Advueu963 Aug 28, 2025
97519ef
Corrected Intercept Integration in Empty game prediction
Advueu963 Aug 28, 2025
f740c70
Added Changelog entry
Advueu963 Aug 28, 2025
ae2d351
Removed unnecessary code
Advueu963 Sep 1, 2025
7d2c68d
Added Invalid
Advueu963 Sep 1, 2025
0344967
Merge branch 'main' into product-kernel-explainer
mmschlk Sep 10, 2025
b4b2bdc
Introduced Pyright Typesafty and General Refactoring
Advueu963 Sep 21, 2025
606a077
Merge branch 'main' into product-kernel-explainer
Advueu963 Oct 15, 2025
89d8b60
static typing of Product Kernel
Advueu963 Oct 15, 2025
21ff245
uv lock update
Advueu963 Oct 15, 2025
f0bbd9c
updated import in test_tabular_local_xai making
Advueu963 Oct 15, 2025
435a344
Added test to check that product_kernel methods are detected properly
Advueu963 Oct 21, 2025
e4a9a83
Merge branch 'main' into product-kernel-explainer
Advueu963 Oct 23, 2025
4a7fe41
Merge branch 'main' into product-kernel-explainer
mmschlk Oct 24, 2025
d522cae
Adjustments accoding to Code Review
Advueu963 Oct 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

## Development

### Introducing ProductKernelExplainer

The ProductKernelExplainer is a new model-specific explanation method for product-kernel-based machine learning models, such as Gaussian Processes or Support Vector Machines.

For further details refer to: https://arxiv.org/abs/2505.16516

### Shapiq Statically Typechecked [#430](https://github.com/mmschlk/shapiq/pull/430)
We have introduced static type checking to `shapiq` using [Pyright](https://github.com/microsoft/pyright), and integrated it into our `pre-commit` hooks.
This ensures that type inconsistencies are caught early during development, improving code quality and maintainability.
Expand Down
5 changes: 4 additions & 1 deletion src/shapiq/approximator/regression/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ def _init_kernel_weights(self, interaction_size: int) -> FloatVector:
else:
weight_vector[coalition_size] = 1 / (
(self.n - 2 * interaction_size + 1)
* binom(self.n - 2 * interaction_size, coalition_size - interaction_size)
* binom(
self.n - 2 * interaction_size,
coalition_size - interaction_size,
)
)
return weight_vector
msg = f"Index {self.index} not available for Regression Approximator."
Expand Down
1 change: 1 addition & 0 deletions src/shapiq/explainer/custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from typing import Literal

# Literal type enumerating the interaction index identifiers listed below.
ExplainerIndices = Literal["SV", "SII", "k-SII", "STII", "FSII", "BV", "BII", "FBII"]
# Literal type for the product kernel explainer indices; only "SV" is listed.
ValidProductKernelExplainerIndices = Literal["SV"]
7 changes: 7 additions & 0 deletions src/shapiq/explainer/product_kernel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Implementation of the ProductKernelComputer and the ProductKernelExplainer."""

from .base import ProductKernelModel
from .explainer import ProductKernelExplainer
from .product_kernel import ProductKernelComputer

__all__ = ["ProductKernelModel", "ProductKernelExplainer", "ProductKernelComputer"]
31 changes: 31 additions & 0 deletions src/shapiq/explainer/product_kernel/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""The base class for product kernel model conversion."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
import numpy as np


@dataclass
class ProductKernelModel:
    """A dataclass for storing the information of a product kernel model.

    Attributes:
        X_train: The training data used to fit the product kernel model.
        alpha: The alpha parameter of the product kernel model.
        n: The number of samples in the training data.
        d: The number of features in the training data.
        gamma: The gamma parameter of the product kernel model. Defaults to ``None``.
        kernel_type: The name of the kernel of the model. Defaults to ``"rbf"``.
        intercept: The intercept term of the product kernel model. For Gaussian Processes this
            should be zero, but support vector machines often have non-zero intercepts.
            Defaults to ``0.0``.
    """

    X_train: np.ndarray
    alpha: np.ndarray
    n: int
    d: int
    gamma: float | None = None
    kernel_type: str = "rbf"
    intercept: float = 0.0
88 changes: 88 additions & 0 deletions src/shapiq/explainer/product_kernel/conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""Functions for converting scikit-learn models to a format used by shapiq."""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from shapiq.explainer.product_kernel.base import ProductKernelModel

if TYPE_CHECKING:
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.svm import SVC, SVR


def convert_svm(model: SVC | SVR) -> ProductKernelModel:
    """Converts a scikit-learn SVM model to the product kernel format used by shapiq.

    Only models with a single decision function are supported, i.e. a binary support vector
    classifier (SVC) or a support vector regressor (SVR), and only with the RBF kernel.

    Args:
        model: The scikit-learn SVM model to convert. Can be either a binary support vector
            classifier (SVC) or a support vector regressor (SVR).

    Returns:
        ProductKernelModel: The converted model in the product kernel format.

    Raises:
        ValueError: If the model has no ``kernel`` attribute, if the kernel is not ``"rbf"``, or
            if the dual coefficients contain more than one row (e.g. a multi-class SVC).

    """
    X_train = model.support_vectors_
    n, d = X_train.shape

    if not hasattr(model, "kernel"):
        msg = "Kernel type not found in the model. Ensure the model is a valid SVC or SVR."
        raise ValueError(msg)

    kernel_type = model.kernel  # pyright: ignore[reportAttributeAccessIssue]
    if kernel_type != "rbf":
        msg = "Currently only RBF kernel is supported for SVM models."
        raise ValueError(msg)

    dual_coef = model.dual_coef_  # pyright: ignore[reportAttributeAccessIssue]
    # Flattening more than one row of dual coefficients would yield an ``alpha`` vector that no
    # longer lines up with the rows of ``X_train``; reject such models instead of silently
    # building an inconsistent ProductKernelModel.
    if len(dual_coef.shape) == 2 and dual_coef.shape[0] != 1:
        msg = (
            "Only binary SVC and SVR models are supported: expected a single row of dual "
            f"coefficients but found {dual_coef.shape[0]}."
        )
        raise ValueError(msg)

    return ProductKernelModel(
        alpha=dual_coef.flatten(),
        X_train=X_train,
        n=n,
        d=d,
        gamma=model._gamma,  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]  # noqa: SLF001
        kernel_type=kernel_type,
        # Cast to a builtin float so the value matches the ``intercept: float`` field.
        intercept=float(model.intercept_[0]),
    )


def convert_gp_reg(model: GaussianProcessRegressor) -> ProductKernelModel:
    """Converts a scikit-learn Gaussian Process Regression model to the product kernel format used by shapiq.

    The RBF length scale ``l`` of the kernel is translated into the ``gamma`` parameter via
    ``gamma = 1 / (2 * l**2)``.

    Args:
        model: The scikit-learn Gaussian Process Regression model to convert.

    Returns:
        ProductKernelModel: The converted model in the product kernel format.

    Raises:
        ValueError: If the model has no ``kernel_`` attribute, if the class of ``kernel_`` is not
            named ``RBF`` (case-insensitive), or if the kernel parameters contain no
            ``length_scale`` entry.

    """
    X_train = np.array(model.X_train_)
    n, d = X_train.shape

    # ``kernel_`` is the attribute that is read below, so its presence is what must be checked
    # (checking ``kernel`` instead lets a missing ``kernel_`` escape as an AttributeError).
    fitted_kernel = getattr(model, "kernel_", None)
    if fitted_kernel is None:
        msg = "Kernel type not found in the model. Ensure the model is a valid Gaussian Process Regressor."
        raise ValueError(msg)

    kernel_type = type(fitted_kernel).__name__.lower()  # Get the kernel type as a string
    if kernel_type != "rbf":
        msg = "Currently only RBF kernel is supported for Gaussian Process Regression models."
        raise ValueError(msg)

    alphas = np.array(model.alpha_).flatten()
    parameters = (
        fitted_kernel.get_params()  # pyright: ignore[reportAttributeAccessIssue]
    )
    if "length_scale" not in parameters:
        msg = "Length scale parameter not found in the kernel."
        raise ValueError(msg)
    length_scale = parameters["length_scale"]

    # NOTE(review): a per-feature (array-valued) ``length_scale`` would make ``gamma`` an array
    # although the model field is typed ``float | None`` - confirm whether that is supported.
    # NOTE(review): no intercept is passed, so the model default is used; presumably this
    # assumes the regressor was fitted without target normalization - TODO confirm.
    return ProductKernelModel(
        alpha=alphas,
        X_train=X_train,
        n=n,
        d=d,
        gamma=(2 * (length_scale**2)) ** -1,
        kernel_type=kernel_type,
    )
137 changes: 137 additions & 0 deletions src/shapiq/explainer/product_kernel/explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""Implementation of the ProductKernelExplainer class."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from shapiq import InteractionValues
from shapiq.explainer.base import Explainer
from shapiq.game_theory import get_computation_index

from .product_kernel import ProductKernelComputer, ProductKernelSHAPIQIndices
from .validation import validate_pk_model

if TYPE_CHECKING:
import numpy as np

from shapiq.typing import Model

from .base import ProductKernelModel


class ProductKernelExplainer(Explainer):
    """The ProductKernelExplainer class for product kernel-based models.

    The ProductKernelExplainer can be used with a variety of product kernel-based models. The
    explainer can handle both regression and classification models. See [pkex-shapley]_ for
    details.

    References:
        .. [pkex-shapley] Majid Mohammadi, Siu Lun Chau, and Krikamol Muandet. (2025). Computing Exact Shapley Values in Polynomial Time for Product-Kernel Methods. https://arxiv.org/abs/2505.16516

    Attributes:
        model: The product kernel model to explain.
        max_order: The maximum interaction order to be computed. Defaults to ``1``.
        min_order: The minimum interaction order to be computed. Defaults to ``0``.
        index: The type of interaction to be computed. Currently, only ``"SV"`` is supported.
        converted_model: The result of ``validate_pk_model(model)``; its ``X_train``, ``alpha``,
            ``d`` and ``intercept`` fields are used for the computation.
        explainer: The ``ProductKernelComputer`` performing the Shapley value computation.
        empty_prediction: The baseline value reported with every explanation.

    Note:
        When explaining classification models, the class which is explained equals the predicted class of the model.
        For further details, consult [pkex-shapley]_ .
    """

    def __init__(
        self,
        model: (
            ProductKernelModel | Model  # pyright: ignore[reportInvalidTypeVarUse]
        ),
        *,
        min_order: int = 0,
        max_order: int = 1,
        index: ProductKernelSHAPIQIndices = "SV",
        **kwargs: Any,  # noqa: ARG002
    ) -> None:
        """Initializes the ProductKernelExplainer.

        Args:
            model: A product kernel-based model to explain.

            min_order: The minimum interaction order to be computed. Defaults to ``0``.

            max_order: The maximum interaction order to be computed. An interaction order of ``1``
                corresponds to the Shapley value. Defaults to ``1``.

            index: The type of interaction to be computed. Currently, only ``"SV"`` is supported.

            **kwargs: Additional keyword arguments are ignored.

        Raises:
            ValueError: If ``max_order`` is greater than ``1``.

        """
        if max_order > 1:
            msg = "ProductKernelExplainer currently only supports max_order=1."
            raise ValueError(msg)

        super().__init__(model, index=index, max_order=max_order)

        self._min_order = min_order
        self._max_order = max_order

        self._index = index
        self._base_index: str = get_computation_index(self._index)

        # validate model
        self.converted_model = validate_pk_model(model)

        self.explainer = ProductKernelComputer(
            model=self.converted_model,
            max_order=max_order,
            index=index,
        )

        self.empty_prediction = self._compute_baseline_value()

    def explain_function(
        self,
        x: np.ndarray,
        **kwargs: Any,  # noqa: ARG002
    ) -> InteractionValues:
        """Compute Shapley values for all features of an instance.

        Args:
            x: The instance (1D array) for which to compute Shapley values.
            **kwargs: Additional keyword arguments are ignored.

        Returns:
            The interaction values for the instance, holding one entry ``(j,)`` per feature ``j``.
        """
        pk_model = self.converted_model
        n_players = pk_model.d

        # compute the kernel vectors for the instance x once; they are reused for every feature
        kernel_vectors = self.explainer.compute_kernel_vectors(pk_model.X_train, x)

        shapley_values = {
            (j,): self.explainer.compute_shapley_value(kernel_vectors, j)
            for j in range(n_players)
        }

        return InteractionValues(
            values=shapley_values,
            index=self._base_index,
            min_order=self._min_order,
            max_order=self.max_order,
            n_players=n_players,
            estimated=False,
            baseline_value=self.empty_prediction,
            target_index=self._index,
        )

    def _compute_baseline_value(self) -> float:
        """Computes the baseline value for the explainer.

        The baseline is the sum of the model's ``alpha`` coefficients plus its intercept.

        Returns:
            The baseline value for the explainer as a builtin ``float``.

        """
        # Cast so the NumPy scalar produced by ``alpha.sum()`` matches the annotated return type.
        return float(self.converted_model.alpha.sum() + self.converted_model.intercept)
102 changes: 102 additions & 0 deletions src/shapiq/explainer/product_kernel/game.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""Product Kernel Game.

This module implements the product kernel game defined in https://arxiv.org/abs/2505.16516.
It is based on machine learning models using (product) kernels as decision functions.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from sklearn.metrics.pairwise import rbf_kernel

from shapiq.game import Game

if TYPE_CHECKING:
from shapiq.typing import CoalitionMatrix, GameValues

from .base import ProductKernelModel


class ProductKernelGame(Game):
    r"""Implements the product kernel game.

    For models using the product kernel as the decision function the game can be formulated as

    .. math::
        v(S) = \alpha^T K(X_S, x_S) + b

    where K(., .) is the product kernel function, X_S are the samples (support vectors) restricted
    to the features in S, x_S is the point to explain restricted to the features in S, and b is
    the intercept of the model. For the empty coalition the value is the sum of all alpha
    coefficients plus the intercept.

    See https://arxiv.org/abs/2505.16516 for more details.

    """

    def __init__(
        self,
        n_players: int,
        explain_point: np.ndarray,
        model: ProductKernelModel,
        *,
        normalize: bool = False,
    ) -> None:
        """Initializes the product kernel game.

        Args:
            n_players (int): The number of players in the game.
            explain_point (np.ndarray): The point to explain.
            model (ProductKernelModel): The product kernel model.
            normalize (bool): Whether to normalize the game values.

        """
        self.model = model
        self.explain_point = explain_point
        self.n, self.d = self.model.X_train.shape
        self._X_train = self.model.X_train
        # The value of the empty coalition, sum_i alpha_i + intercept, serves as the
        # normalization value of the game.
        normalization_value: float = float(self.model.alpha.sum()) + model.intercept

        super().__init__(n_players, normalization_value=normalization_value, normalize=normalize)

    def value_function(self, coalitions: CoalitionMatrix) -> GameValues:
        """The product kernel game value function.

        Args:
            coalitions (CoalitionMatrix): The coalitions to evaluate, one coalition per row. The
                entries are interpreted as a boolean mask over the features; a single
                one-dimensional coalition is treated as one row.

        Raises:
            NotImplementedError: If the kernel type is not supported.

        Returns:
            GameValues: The values of the game for each coalition.
        """
        if self.model.kernel_type != "rbf":
            msg = f"Kernel type '{self.model.kernel_type}' not supported"
            raise NotImplementedError(msg)

        alpha = self.model.alpha
        intercept = self.model.intercept
        # Coerce to boolean masks: an integer 0/1 matrix would otherwise be used as integer
        # indices below and select the wrong features.
        masks = np.atleast_2d(np.asarray(coalitions, dtype=bool))
        # Flatten so that a point given as a single row of shape (1, d) can be masked per feature.
        point = np.asarray(self.explain_point).reshape(-1)
        # The baseline value of the empty coalition does not depend on the loop.
        empty_value = float(alpha.sum()) + intercept

        # NOTE(review): ``gamma`` may be ``None`` on the model; presumably rbf_kernel then falls
        # back to a default depending on the number of selected features - confirm this is intended.
        values = []
        for mask in masks:
            if not mask.any():
                values.append(empty_value)
                continue

            # Extract x_S as a single-row matrix and X_S (column slicing keeps it 2-D)
            x_s = point[mask].reshape(1, -1)
            X_s = self.model.X_train[:, mask]

            # Compute the RBF kernel between every training sample and the point to explain
            kernel_values = rbf_kernel(X=X_s, Y=x_s, gamma=self.model.gamma)

            # Compute the decision value
            values.append((alpha @ kernel_values + intercept).squeeze())
        return np.array(values)
Loading