-
Notifications
You must be signed in to change notification settings - Fork 52
Product kernel explainer #431
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 32 commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
cfc9bab
add logic for ProductKernel explanation
IsaH57 ca7c365
add code to run examples, minor refactoring
IsaH57 2d1b8d3
Merge remote-tracking branch 'refs/remotes/origin/main' into product-…
IsaH57 09f49a6
merge main
IsaH57 43a729c
refactor code
IsaH57 49b9eaa
make RBF optional, add files for game, fix bugs
IsaH57 713744f
clean up
IsaH57 ab6dd5f
make explain_function return InteractionValues object
IsaH57 d4a2d30
add init
IsaH57 b482ef2
Added KernelGame
Advueu963 8f1598d
add first product kernel tests
IsaH57 9cb0f71
add product kernel tests against exact computer
IsaH57 c178f06
add product kernel integration test
IsaH57 489cc7e
working kernel
Advueu963 124db3d
update tests and codebase
IsaH57 18f48e9
Removed src imports
Advueu963 cbc61d4
Added Baseline Computation
Advueu963 43a55af
Fixing Integration test
Advueu963 331586d
Refactor
Advueu963 97519ef
Corrected Intercept Integration in Empty game prediction
Advueu963 f740c70
Added Changelog entry
Advueu963 ae2d351
Removed unnecessary code
Advueu963 7d2c68d
Added Invalid
Advueu963 0344967
Merge branch 'main' into product-kernel-explainer
mmschlk b4b2bdc
Introduced Pyright Typesafty and General Refactoring
Advueu963 606a077
Merge branch 'main' into product-kernel-explainer
Advueu963 89d8b60
static typing of Product Kernel
Advueu963 21ff245
uv lock update
Advueu963 f0bbd9c
updated import in test_tabular_local_xai making
Advueu963 435a344
Added test to check that product_kernel methods are detected properly
Advueu963 e4a9a83
Merge branch 'main' into product-kernel-explainer
Advueu963 4a7fe41
Merge branch 'main' into product-kernel-explainer
mmschlk d522cae
Adjustments accoding to Code Review
Advueu963 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| """Implementation of the ProductKernelComputer and the ProductKernelExplainer.""" | ||
|
|
||
| from .base import ProductKernelModel | ||
| from .explainer import ProductKernelExplainer | ||
| from .product_kernel import ProductKernelComputer | ||
|
|
||
| __all__ = ["ProductKernelModel", "ProductKernelExplainer", "ProductKernelComputer"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| """The base class for product kernel model conversion.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| import numpy as np | ||
|
|
||
|
|
||
| @dataclass | ||
| class ProductKernelModel: | ||
| """A dataclass for storing the information of a product kernel model. | ||
|
|
||
| Attributes: | ||
| alpha: The alpha parameter of the product kernel model. | ||
| X_train: The training data used to fit the product kernel model. | ||
| n: The number of samples in the training data. | ||
| d: The number of features in the training data. | ||
| gamma: The gamma parameter of the product kernel model. | ||
| intercept: The intercept term of the product kernel model. For Gaussian Processes this should be zero, but support vectors have often non-zero intercepts. | ||
| """ | ||
|
|
||
| X_train: np.ndarray | ||
| alpha: np.ndarray | ||
| n: int | ||
| d: int | ||
| gamma: float | None = None | ||
| kernel_type: str = "rbf" | ||
| intercept: float = 0.0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| """Functions for converting scikit-learn models to a format used by shapiq.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import numpy as np | ||
|
|
||
| from shapiq.explainer.product_kernel.base import ProductKernelModel | ||
|
|
||
| if TYPE_CHECKING: | ||
| from sklearn.gaussian_process import GaussianProcessRegressor | ||
| from sklearn.svm import SVC, SVR | ||
|
|
||
|
|
||
| def convert_svm(model: SVC | SVR) -> ProductKernelModel: | ||
| """Converts a scikit-learn SVM model to the product kernel format used by shapiq. | ||
|
|
||
| Args: | ||
| model: The scikit-learn SVM model to convert. Can be either a binary support vector classifier (SVC) or a support vector regressor (SVR). | ||
|
|
||
| Returns: | ||
| ProductKernelModel: The converted model in the product kernel format. | ||
|
|
||
| """ | ||
| X_train = model.support_vectors_ | ||
| n, d = X_train.shape | ||
|
|
||
| if hasattr(model, "kernel"): | ||
| kernel_type = model.kernel # pyright: ignore[reportAttributeAccessIssue] | ||
| if kernel_type != "rbf": | ||
| msg = "Currently only RBF kernel is supported for SVM models." | ||
| raise ValueError(msg) | ||
| else: | ||
| msg = "Kernel type not found in the model. Ensure the model is a valid SVC or SVR." | ||
| raise ValueError(msg) | ||
|
|
||
| return ProductKernelModel( | ||
| alpha=model.dual_coef_.flatten(), # pyright: ignore[reportAttributeAccessIssue] | ||
| X_train=X_train, | ||
| n=n, | ||
| d=d, | ||
| gamma=model._gamma, # pyright: ignore[reportArgumentType, reportAttributeAccessIssue] # noqa: SLF001 | ||
| kernel_type=kernel_type, | ||
| intercept=model.intercept_[0], | ||
| ) | ||
|
|
||
|
|
||
| def convert_gp_reg(model: GaussianProcessRegressor) -> ProductKernelModel: | ||
| """Converts a scikit-learn Gaussian Process Regression model to the product kernel format used by shapiq. | ||
|
|
||
| Args: | ||
| model: The scikit-learn Gaussian Process Regression model to convert. | ||
|
|
||
| Returns: | ||
| ProductKernelModel: The converted model in the product kernel format. | ||
|
|
||
| """ | ||
| X_train = np.array(model.X_train_) | ||
| n, d = X_train.shape | ||
|
|
||
| if hasattr(model, "kernel"): | ||
| kernel_type = model.kernel_.__class__.__name__.lower() # Get the kernel type as a string | ||
| if kernel_type != "rbf": | ||
| msg = "Currently only RBF kernel is supported for Gaussian Process Regression models." | ||
| raise ValueError(msg) | ||
| else: | ||
| msg = "Kernel type not found in the model. Ensure the model is a valid Gaussian Process Regressor." | ||
| raise ValueError(msg) | ||
|
|
||
| alphas = np.array(model.alpha_).flatten() | ||
| parameters = ( | ||
| model.kernel_.get_params() # pyright: ignore[reportAttributeAccessIssue] | ||
| ) | ||
| if "length_scale" in parameters: | ||
| length_scale = parameters["length_scale"] | ||
| else: | ||
| msg = "Length scale parameter not found in the kernel." | ||
| raise ValueError(msg) | ||
|
|
||
| return ProductKernelModel( | ||
| alpha=alphas, | ||
| X_train=X_train, | ||
| n=n, | ||
| d=d, | ||
| gamma=(2 * (length_scale**2)) ** -1, | ||
| kernel_type=kernel_type, | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,137 @@ | ||
| """Implementation of the ProductKernelExplainer class.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from shapiq import InteractionValues | ||
| from shapiq.explainer.base import Explainer | ||
| from shapiq.game_theory import get_computation_index | ||
|
|
||
| from .product_kernel import ProductKernelComputer, ProductKernelSHAPIQIndices | ||
| from .validation import validate_pk_model | ||
|
|
||
| if TYPE_CHECKING: | ||
| import numpy as np | ||
|
|
||
| from shapiq.typing import Model | ||
|
|
||
| from .base import ProductKernelModel | ||
|
|
||
|
|
||
| class ProductKernelExplainer(Explainer): | ||
| """The ProductKernelExplainer class for product kernel-based models. | ||
|
|
||
| The ProductKernelExplainer can be used with a variety of product kernel-based models. The explainer can handle both regression and | ||
| classification models. See [pkex-shapley]_ for details. | ||
|
|
||
|
|
||
| References: | ||
| .. [pkex-shapley] Majid Mohammadi and Siu Lun Chau, Krikamol Muandet. (2025). Computing Exact Shapley Values in Polynomial Time for Product-Kernel Methods. https://arxiv.org/abs/2505.16516 | ||
|
|
||
| Attributes: | ||
| model: The product kernel model to explain. Can be a dictionary, a ProductKernelModel, or a list of ProductKernelModels. | ||
| max_order: The maximum interaction order to be computed. Defaults to ``1``. | ||
| min_order: The minimum interaction order to be computed. Defaults to ``0``. | ||
| index: The type of interaction to be computed. Currently, only ``"SV"`` is supported. | ||
|
|
||
| Note: | ||
| When explaining classification models, the class which is explained equals the predicted class of the model. | ||
| For further details, consult [pkex-shapley]_ . | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| model: ( | ||
|
Advueu963 marked this conversation as resolved.
|
||
| ProductKernelModel | Model # pyright: ignore[reportInvalidTypeVarUse] | ||
| ), | ||
| *, | ||
| min_order: int = 0, | ||
| max_order: int = 1, | ||
| index: ProductKernelSHAPIQIndices = "SV", | ||
| **kwargs: Any, # noqa: ARG002 | ||
| ) -> None: | ||
| """Initializes the ProductKernelExplainer. | ||
|
|
||
| Args: | ||
| model: A product kernel-based model to explain. | ||
|
|
||
| min_order: The minimum interaction order to be computed. Defaults to ``0``. | ||
|
|
||
| max_order: The maximum interaction order to be computed. An interaction order of ``1`` | ||
| corresponds to the Shapley value. Defaults to ``1``. | ||
|
|
||
| index: The type of interaction to be computed. Currently, only ``"SV"`` is supported. | ||
|
|
||
| class_index: The class index of the model to explain. Defaults to ``None``, which will | ||
| set the class index to ``1`` per default for classification models and is ignored | ||
| for regression models. | ||
|
|
||
| **kwargs: Additional keyword arguments are ignored. | ||
|
|
||
| """ | ||
| if max_order > 1: | ||
| msg = "ProductKernelExplainer currently only supports max_order=1." | ||
| raise ValueError(msg) | ||
|
|
||
| super().__init__(model, index=index, max_order=max_order) | ||
|
|
||
| self._min_order = min_order | ||
| self._max_order = max_order | ||
|
|
||
| self._index = index | ||
| self._base_index: str = get_computation_index(self._index) | ||
|
|
||
| # validate model | ||
| self.converted_model = validate_pk_model(model) | ||
|
|
||
| self.explainer = ProductKernelComputer( | ||
| model=self.converted_model, | ||
| max_order=max_order, | ||
| index=index, | ||
| ) | ||
|
|
||
| self.empty_prediction = self._compute_baseline_value() | ||
|
|
||
| def explain_function( | ||
| self, | ||
| x: np.ndarray, | ||
| **kwargs: Any, # noqa: ARG002 | ||
| ) -> InteractionValues: | ||
| """Compute Shapley values for all features of an instance. | ||
|
|
||
| Args: | ||
| x: The instance (1D array) for which to compute Shapley values. | ||
| **kwargs: Additional keyword arguments are ignored. | ||
|
|
||
| Returns: | ||
| The interaction values for the instance. | ||
| """ | ||
| n_players = self.converted_model.d | ||
|
|
||
| # compute the kernel vectors for the instance x | ||
| kernel_vectors = self.explainer.compute_kernel_vectors(self.converted_model.X_train, x) | ||
|
|
||
| shapley_values = {} | ||
| for j in range(self.converted_model.d): | ||
| shapley_values.update({(j,): self.explainer.compute_shapley_value(kernel_vectors, j)}) | ||
|
|
||
| return InteractionValues( | ||
| values=shapley_values, | ||
| index=self._base_index, | ||
| min_order=self._min_order, | ||
| max_order=self.max_order, | ||
| n_players=n_players, | ||
| estimated=False, | ||
| baseline_value=self.empty_prediction, | ||
| target_index=self._index, | ||
| ) | ||
|
|
||
| def _compute_baseline_value(self) -> float: | ||
| """Computes the baseline value for the explainer. | ||
|
|
||
| Returns: | ||
| The baseline value for the explainer. | ||
|
|
||
| """ | ||
| return self.converted_model.alpha.sum() + self.converted_model.intercept | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| """Product Kernel Game. | ||
|
|
||
| This module implements the product kernel game defined in https://arxiv.org/abs/2505.16516. | ||
| It is based on machine learning models using (product) kernels as decision functions. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import numpy as np | ||
| from sklearn.metrics.pairwise import rbf_kernel | ||
|
|
||
| from shapiq.game import Game | ||
|
|
||
| if TYPE_CHECKING: | ||
| from shapiq.typing import CoalitionMatrix, GameValues | ||
|
|
||
| from .base import ProductKernelModel | ||
|
|
||
|
|
||
| class ProductKernelGame(Game): | ||
| r"""Implements the product kernel game. | ||
|
|
||
| For models using the product kernel as the decision function the game can be formulated as | ||
| ..math:: | ||
| v(S) = \alpha^T (K(X_S, x_S)) | ||
|
|
||
| where K(., .) is the product kernel function, X_S are the samples (support vectors) restricted to the features in S and x_S is the point to explain restricted to the features in S. | ||
|
|
||
| See https://arxiv.org/abs/2505.16516 for more details. | ||
|
|
||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| n_players: int, | ||
| explain_point: np.ndarray, | ||
| model: ProductKernelModel, | ||
| *, | ||
| normalize: bool = False, | ||
| ) -> None: | ||
| """Initializes the product kernel game. | ||
|
|
||
| Args: | ||
| n_players (int): The number of players in the game. | ||
| explain_point (np.ndarray): The point to explain. | ||
| model (ProductKernelModel): The product kernel model. | ||
| normalize (bool): Whether to normalize the game values. | ||
|
|
||
| """ | ||
| self.model = model | ||
| self.explain_point = explain_point | ||
| self.n, self.d = self.model.X_train.shape | ||
| self._X_train = self.model.X_train | ||
| # The empty value can generally be defined by: \sum_{i=1}^n \alpha_i K(x^i, x) - \beta, where x^i are training points / support vectors. | ||
| normalization_value: float = float(self.model.alpha.sum()) + model.intercept | ||
|
|
||
| super().__init__(n_players, normalization_value=normalization_value, normalize=normalize) | ||
|
|
||
| def value_function(self, coalitions: CoalitionMatrix) -> GameValues: | ||
| """The product kernel game value function. | ||
|
|
||
| Args: | ||
| coalitions (CoalitionMatrix): The coalitions to evaluate. | ||
|
|
||
| Raises: | ||
| NotImplementedError: If the kernel type is not supported. | ||
|
|
||
| Returns: | ||
| GameValues: The values of the game for each coalition. | ||
| """ | ||
| alpha = self.model.alpha | ||
| n_coalitions, _ = coalitions.shape | ||
| res = [] | ||
| if self.model.kernel_type == "rbf": | ||
| for coalition in range(n_coalitions): | ||
| current_coalition = coalitions[coalition, :] | ||
|
|
||
| # The baseline value | ||
| if current_coalition.sum() == 0: | ||
| res.append(float(self.model.alpha.sum()) + self.model.intercept) | ||
| continue | ||
| # Extract X_S and x_S | ||
| coalition_features = self.explain_point[current_coalition] | ||
| X_train = self.model.X_train[:, current_coalition] | ||
|
|
||
| # Reshape into twodimensional vectors | ||
| if len(coalition_features.shape) == 1: | ||
| coalition_features = coalition_features.reshape(1, -1) | ||
| if len(X_train.shape) == 1: | ||
| X_train = X_train.reshape(1, -1) | ||
|
|
||
| # Compute the RBF kernel | ||
| kernel_values = rbf_kernel(X=X_train, Y=coalition_features, gamma=self.model.gamma) | ||
|
|
||
| # Compute the decision value | ||
| res.append((alpha @ kernel_values + self.model.intercept).squeeze()) | ||
| else: | ||
| msg = f"Kernel type '{self.model.kernel_type}' not supported" | ||
| raise NotImplementedError(msg) | ||
| return np.array(res) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.