Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
762805d
Initial commit to add Sparse approximator files
justinkang221 Mar 20, 2025
a2b353b
Added list of interactions supported by sparse approximator.
justinkang221 Mar 20, 2025
02bc40e
template for SMT and SPEX
justinkang221 Mar 20, 2025
2a75524
Updated the list of approximators for the indicies SPEX can compute
justinkang221 Mar 21, 2025
21c4ba6
added base class functionality to Sparse
justinkang221 Mar 21, 2025
a4005e8
Added conversion from Fourier -> Moebius, added interface with Intera…
landonbutler Mar 21, 2025
53bb23c
added b computing functionality
justinkang221 Mar 21, 2025
6d5831d
Simplified logic.
justinkang221 Mar 21, 2025
0c875b2
Added Mobius logic for computing b, and comments for paths to new fun…
justinkang221 Mar 22, 2025
b066c13
Added test files.
justinkang221 Mar 22, 2025
3da4b9c
bug fixes in sparse/_base.py
justinkang221 Mar 22, 2025
abc8759
bug fixes for _base.approximate
justinkang221 Mar 23, 2025
2952190
modify logic in sparse/_base.py
justinkang221 Mar 23, 2025
a6de258
change the way baseline values are processed
justinkang221 Mar 24, 2025
72b8fde
remove some TODOs
justinkang221 Mar 25, 2025
0a207b2
Removed Mobius transform
justinkang221 Mar 25, 2025
f5f0cd4
rename mobius -> moebius
justinkang221 Mar 25, 2025
8bc717f
added order filtering to Sparse base class
justinkang221 Mar 27, 2025
59a586c
added docstrings for base sparse approximator
justinkang221 Mar 27, 2025
b3e8107
Update SPEX class to match Sparse
justinkang221 Mar 27, 2025
6b0c91a
fix filtering logic
justinkang221 Mar 27, 2025
e7cbfb6
Added SPEX tests, remove weights from SPEX
justinkang221 Mar 27, 2025
f5f5311
Modify Sparse to work with sparse-transform commit 5feb6bc, will be t…
justinkang221 Mar 27, 2025
e2de3a6
black-ing touched files
justinkang221 Mar 27, 2025
a69ae2e
fix weird formatting due to comment placment
justinkang221 Mar 27, 2025
af6fe9e
remove unused import
justinkang221 Mar 27, 2025
a1c41e8
Added logic for small budget setting
landonbutler Apr 2, 2025
adca81c
Added functionality for undersampling alongside tests
landonbutler Apr 3, 2025
4f0baa0
Added SPEX to init of ShapIQ
landonbutler Apr 3, 2025
5a4600e
sentiment notebook commit
justinkang221 Apr 9, 2025
1ee41d1
Merge branch 'main' of https://github.com/justinkang221/shapiq
justinkang221 Apr 9, 2025
5c841f0
sentiment notebook initial draft
justinkang221 Apr 9, 2025
8822083
added larger image patches example.
justinkang221 Apr 10, 2025
e7db63a
Added print out of top interactions
landonbutler Apr 15, 2025
5612970
minor bug fixes.
justinkang221 Apr 15, 2025
1cc05dd
Merge branch 'main' of https://github.com/justinkang221/shapiq
justinkang221 Apr 15, 2025
546363c
Merge remote-tracking branch 'upstream/main'
justinkang221 May 7, 2025
522b0ab
fix integrations, run ruff linter
justinkang221 May 7, 2025
531f90b
added sparse transform to dependancies
justinkang221 May 7, 2025
a50f43a
re-run ruff
justinkang221 May 7, 2025
5f26daf
bump number of samples for tests
justinkang221 May 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
359 changes: 110 additions & 249 deletions docs/source/notebooks/language_notebooks/language_model_game.ipynb

Large diffs are not rendered by default.

128 changes: 95 additions & 33 deletions docs/source/notebooks/vision_notebooks/vision_transformer.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ dependencies = [
"scikit-learn",
"tqdm",
"requests",
"sparse-transform",
"galois",
# plotting
"matplotlib",
"networkx",
"colour"
"colour",
]
authors = [
{name = "Maximilian Muschalik", email = "Maximilian.Muschalik@lmu.de"},
Expand Down
2 changes: 2 additions & 0 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# approximator classes
from .approximator import (
SHAPIQ,
SPEX,
SVARM,
SVARMIQ,
InconsistentKernelSHAPIQ,
Expand Down Expand Up @@ -95,6 +96,7 @@
"SVARMIQ",
"kADDSHAP",
"UnbiasedKernelSHAP",
"SPEX",
# explainers
"Explainer",
"TabularExplainer",
Expand Down
7 changes: 7 additions & 0 deletions shapiq/approximator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
RegressionFSII,
kADDSHAP,
)
from .sparse import SPEX

# contains all SV approximators
SV_APPROXIMATORS: list[Approximator.__class__] = [
Expand All @@ -24,6 +25,7 @@
PermutationSamplingSV,
KernelSHAP,
kADDSHAP,
SPEX,
]

# contains all SI approximators
Expand All @@ -44,6 +46,7 @@
InconsistentKernelSHAPIQ,
SVARMIQ,
SHAPIQ,
SPEX,
]

# contains all approximators that can be used for STII
Expand All @@ -53,6 +56,7 @@
InconsistentKernelSHAPIQ,
SVARMIQ,
SHAPIQ,
SPEX,
]

# contains all approximators that can be used for FSII
Expand All @@ -62,11 +66,13 @@
InconsistentKernelSHAPIQ,
SVARMIQ,
SHAPIQ,
SPEX,
]

# contains all approximators that can be used for FBII
FBII_APPROXIMATORS: list[Approximator.__class__] = [
RegressionFBII,
SPEX,
]

__all__ = [
Expand All @@ -84,6 +90,7 @@
"SVARM",
"SVARMIQ",
"kADDSHAP",
"SPEX",
"UnbiasedKernelSHAP",
"SV_APPROXIMATORS",
"SI_APPROXIMATORS",
Expand Down
16 changes: 11 additions & 5 deletions shapiq/approximator/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(
pairing_trick: bool = False,
sampling_weights: np.ndarray[float] | None = None,
random_state: int | None = None,
initialize_dict: bool = True,
) -> None:
# check if index can be approximated
self.index: str = index
Expand All @@ -89,11 +90,16 @@ def __init__(
self._grand_coalition_tuple = tuple(range(self.n))
self._grand_coalition_array: np.ndarray = np.arange(self.n + 1, dtype=int)
self.iteration_cost: int = 1 # default value, can be overwritten by subclasses
self._interaction_lookup = generate_interaction_lookup(
self.n,
self.min_order,
self.max_order,
)

# The interaction_lookup is not initialized in some cases for performance reasons
if initialize_dict:
self._interaction_lookup = generate_interaction_lookup(
self.n,
self.min_order,
self.max_order,
)
else:
self._interaction_lookup = {}

# set up random state and random number generators
self._random_state: int | None = random_state
Expand Down
7 changes: 7 additions & 0 deletions shapiq/approximator/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._base import Sparse
from .spex import SPEX

__all__ = [
"SPEX",
"Sparse",
]
251 changes: 251 additions & 0 deletions shapiq/approximator/sparse/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
from __future__ import annotations

import copy
from collections.abc import Callable

import numpy as np
from sparse_transform.qsft.qsft import transform as sparse_fourier_transform
from sparse_transform.qsft.signals.input_signal_subsampled import (
SubsampledSignal as SubsampledSignalFourier,
)
from sparse_transform.qsft.utils.general import fourier_to_mobius as fourier_to_moebius
from sparse_transform.qsft.utils.query import get_bch_decoder

from ...game_theory.indices import is_index_aggregated
from ...game_theory.moebius_converter import MoebiusConverter
from ...interaction_values import InteractionValues
from .._base import Approximator


class Sparse(Approximator):
    """Approximator for interaction values using sparse transformation techniques.

    This class implements a sparse approximation method for computing various interaction indices
    using sparse Fourier transforms. The Fourier coefficients recovered from a subsampled signal
    are converted to Moebius coefficients, which are then converted to the requested index.

    Attributes:
        transform_type (str): Type of transform used (currently only "fourier" is supported).
        t (int): Error parameter for the sparse Fourier transform. It starts at 5 and may be
            lowered (down to 2) by ``_set_transform_budget`` when the budget is small; a lowered
            value persists on the instance for later calls.
        decoder_type (str): The decoder in use, either "soft" or "hard".
        query_args (dict): Parameters for querying the signal.
        decoder_args (dict): Parameters for decoding the transform.

    Args:
        n (int): Number of players/features.
        index (str): Type of interaction index to compute (e.g., "STII", "FBII", "FSII").
        max_order (int, optional): Maximum interaction order to compute. If None, ``n`` is used.
            It is not suggested to use this parameter since sparse approximation dynamically and
            implicitly adjusts the order based on the budget and function.
        top_order (bool, optional): If True, only compute interactions of exactly max_order.
            If False, compute interactions up to max_order. Defaults to False.
        random_state (int, optional): Random seed for reproducibility. Defaults to None.
        transform_type (str, optional): Type of transform to use. Currently only "fourier"
            is supported. Defaults to "fourier".
        decoder_type (str, optional): Type of decoder to use, either "soft" or "hard".
            Defaults to "soft". Passing None selects "hard".

    Raises:
        ValueError: If transform_type is not "fourier" or if decoder_type is not "soft" or "hard".
    """

    def __init__(
        self,
        n: int,
        index: str,
        max_order: int | None = None,
        top_order: bool = False,
        random_state: int | None = None,
        transform_type: str = "fourier",
        decoder_type: str = "soft",
    ) -> None:
        if transform_type.lower() not in ["fourier"]:
            msg = "transform_type must be 'fourier'"
            raise ValueError(msg)
        self.transform_type = transform_type.lower()
        self.t = 5  # 5 could be a parameter
        # An explicit None falls back to the hard decoder (the signature default is "soft").
        self.decoder_type = "hard" if decoder_type is None else decoder_type.lower()
        if self.decoder_type not in ["soft", "hard"]:
            msg = "decoder_type must be 'soft' or 'hard'"
            raise ValueError(msg)
        # The sampling parameters for the Fourier transform
        self.query_args = {
            "query_method": "complex",
            "num_subsample": 3,
            "delays_method_source": "joint-coded",
            "subsampling_method": "qsft",
            "delays_method_channel": "identity-siso",
            "num_repeat": 1,
            "t": self.t,
        }
        # The decoding parameters; "b" is added later by _set_transform_budget.
        self.decoder_args = {
            "num_subsample": 3,
            "num_repeat": 1,
            "reconstruct_method_source": "coded",
            "peeling_method": "multi-detect",
            "reconstruct_method_channel": "identity-siso"
            if self.decoder_type == "soft"
            else "identity",
            "regress": "lasso",
            "res_energy_cutoff": 0.9,
            "source_decoder": get_bch_decoder(n, self.t, self.decoder_type),
        }
        super().__init__(
            n=n,
            max_order=n if max_order is None else max_order,
            index=index,
            top_order=top_order,
            random_state=random_state,
            initialize_dict=False,  # Important for performance
        )

    def approximate(
        self,
        budget: int,
        game: Callable[[np.ndarray], np.ndarray],
    ) -> InteractionValues:
        """Approximates the interaction values using a sparse transform approach.

        Args:
            budget: The budget for the approximation.
            game: The game function that returns the values for the coalitions. It is called
                with the sampled inputs cast to boolean arrays.

        Returns:
            The approximated interaction values. If the requested index is an aggregated index,
            the values are aggregated before being returned.

        Raises:
            ValueError: If the budget is too low to compute the transform.
        """
        # Find the value of b that fits within the given sample budget and get the used budget
        used_budget = self._set_transform_budget(budget)
        signal = SubsampledSignalFourier(
            func=lambda inputs: game(inputs.astype(bool)),
            n=self.n,
            q=2,
            query_args=self.query_args,
        )
        # Extract the coefficients of the original transform; each key is re-expressed as the
        # tuple of its non-zero positions and only the real part of each coefficient is kept.
        initial_transform = {
            tuple(np.nonzero(key)[0]): np.real(value)
            for key, value in sparse_fourier_transform(signal, **self.decoder_args).items()
        }
        # The Fourier coefficients need to be converted to a Moebius transform
        moebius_transform = fourier_to_moebius(initial_transform)
        # Convert the Moebius transform to the desired index (also sets the interaction lookup)
        result = self._process_moebius(moebius_transform=moebius_transform)
        # Filter the output as needed (also re-indexes the interaction lookup)
        if self.top_order:
            result = self._filter_order(result)
        # The lookup maps an interaction to its position in ``result``; the baseline is the
        # value stored for the empty interaction, not the position itself.
        empty_position = self.interaction_lookup.get(())
        baseline_value = 0.0 if empty_position is None else float(result[empty_position])
        output = InteractionValues(
            values=result,
            index=self.approximation_index,
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            interaction_lookup=copy.deepcopy(self.interaction_lookup),
            estimated=True,
            estimation_budget=used_budget,
            baseline_value=baseline_value,
        )
        # Aggregate the values if the requested index is an aggregated one
        if is_index_aggregated(self.index):
            output = self.aggregate_interaction_values(output)
        return output

    def _filter_order(self, result: np.ndarray) -> np.ndarray:
        """Filters the interactions to keep only those of the maximum order.

        This method is used when top_order=True to filter out all interactions that are not
        of exactly the maximum order (self.max_order).

        Args:
            result: Array of interaction values, indexed by the positions stored in the
                current interaction lookup.

        Returns:
            Filtered array containing only interaction values of the maximum order.
            The method also replaces the internal _interaction_lookup dictionary with one that
            maps the kept interactions to their positions in the returned array.
        """
        kept = [
            (key, position)
            for key, position in self.interaction_lookup.items()
            if len(key) == self.max_order
        ]
        # Re-index the surviving interactions contiguously, preserving their lookup order.
        self._interaction_lookup = {key: new_pos for new_pos, (key, _) in enumerate(kept)}
        return np.array([result[position] for _, position in kept])

    def _process_moebius(self, moebius_transform: dict[tuple, float]) -> np.ndarray:
        """Processes the Moebius transform to extract the desired index.

        Args:
            moebius_transform: The Moebius transform to process (dict mapping tuples to float values).

        Returns:
            np.ndarray: The converted interaction values based on the specified index.
            The function also updates the internal _interaction_lookup dictionary.
        """
        # dicts preserve insertion order, so values and lookup positions line up.
        moebius_interactions = InteractionValues(
            values=np.array(list(moebius_transform.values())),
            index="Moebius",
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            interaction_lookup={key: i for i, key in enumerate(moebius_transform)},
            estimated=True,
            baseline_value=moebius_transform.get((), 0.0),
        )
        autoconverter = MoebiusConverter(moebius_coefficients=moebius_interactions)
        converted_interaction_values = autoconverter(index=self.index, order=self.max_order)
        self._interaction_lookup = converted_interaction_values.interaction_lookup
        return converted_interaction_values.values

    def _set_transform_budget(self, budget: int) -> int:
        """Sets the appropriate transform budget parameters based on the given sample budget.

        This method asks the subsampled signal class for the 'b' parameter (number of bits to
        subsample) that fits within the provided budget, then configures the query and decoder
        arguments accordingly. If 'b' is too small (<= 2), the error parameter 't' is lowered
        step by step (but not below 2) until 'b' becomes large enough. A lowered 't' is kept on
        the instance after a successful call; after a failed call 't' is restored so that the
        query arguments and the source decoder stay consistent.

        Args:
            budget: The maximum number of samples allowed for the approximation.

        Returns:
            int: The actual number of samples that will be used, which is less than or equal to the budget.

        Raises:
            ValueError: If the budget is too low to compute the transform with acceptable parameters.
        """
        initial_t = self.t
        b = SubsampledSignalFourier.get_b_for_sample_budget(
            budget, self.n, self.t, 2, self.query_args
        )
        used_budget = SubsampledSignalFourier.get_number_of_samples(
            self.n, b, self.t, 2, self.query_args
        )

        # Trade error tolerance for a larger 'b' while the budget is too small.
        while b <= 2 and self.t > 2:
            self.t -= 1
            self.query_args["t"] = self.t

            # Recalculate 'b' with the updated 't'
            b = SubsampledSignalFourier.get_b_for_sample_budget(
                budget, self.n, self.t, 2, self.query_args
            )

            # Compute the used budget
            used_budget = SubsampledSignalFourier.get_number_of_samples(
                self.n, b, self.t, 2, self.query_args
            )

        # If 'b' is still too low, undo the changes to 't' and raise an error
        if b <= 2:
            self.t = initial_t
            self.query_args["t"] = initial_t
            msg = "Insufficient budget to compute the transform. Increase the budget or use a different approximator."
            raise ValueError(msg)

        # The source decoder depends on 't', so rebuild it whenever 't' was lowered
        if self.t != initial_t:
            self.decoder_args["source_decoder"] = get_bch_decoder(
                self.n, self.t, self.decoder_type
            )

        # Store the final 'b' value
        self.query_args["b"] = b
        self.decoder_args["b"] = b
        return used_budget
Loading