Skip to content

Commit 8fa18fe

Browse files
Add the SPEX approximator to shapiq (#379)
1 parent e6e7a2e commit 8fa18fe

File tree

15 files changed

+971
-319
lines changed

15 files changed

+971
-319
lines changed

docs/source/notebooks/language_notebooks/language_model_game.ipynb

Lines changed: 110 additions & 249 deletions
Large diffs are not rendered by default.

docs/source/notebooks/vision_notebooks/vision_transformer.ipynb

Lines changed: 95 additions & 33 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ dependencies = [
1515
"scikit-learn",
1616
"tqdm",
1717
"requests",
18+
"sparse-transform",
19+
"galois",
1820
# plotting
1921
"matplotlib",
2022
"networkx",
21-
"colour"
23+
"colour",
2224
]
2325
authors = [
2426
{name = "Maximilian Muschalik", email = "Maximilian.Muschalik@lmu.de"},

shapiq/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# approximator classes
88
from .approximator import (
99
SHAPIQ,
10+
SPEX,
1011
SVARM,
1112
SVARMIQ,
1213
InconsistentKernelSHAPIQ,
@@ -95,6 +96,7 @@
9596
"SVARMIQ",
9697
"kADDSHAP",
9798
"UnbiasedKernelSHAP",
99+
"SPEX",
98100
# explainers
99101
"Explainer",
100102
"TabularExplainer",

shapiq/approximator/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
RegressionFSII,
1515
kADDSHAP,
1616
)
17+
from .sparse import SPEX
1718

1819
# contains all SV approximators
1920
SV_APPROXIMATORS: list[Approximator.__class__] = [
@@ -24,6 +25,7 @@
2425
PermutationSamplingSV,
2526
KernelSHAP,
2627
kADDSHAP,
28+
SPEX,
2729
]
2830

2931
# contains all SI approximators
@@ -44,6 +46,7 @@
4446
InconsistentKernelSHAPIQ,
4547
SVARMIQ,
4648
SHAPIQ,
49+
SPEX,
4750
]
4851

4952
# contains all approximators that can be used for STII
@@ -53,6 +56,7 @@
5356
InconsistentKernelSHAPIQ,
5457
SVARMIQ,
5558
SHAPIQ,
59+
SPEX,
5660
]
5761

5862
# contains all approximators that can be used for FSII
@@ -62,11 +66,13 @@
6266
InconsistentKernelSHAPIQ,
6367
SVARMIQ,
6468
SHAPIQ,
69+
SPEX,
6570
]
6671

6772
# contains all approximators that can be used for FBII
6873
FBII_APPROXIMATORS: list[Approximator.__class__] = [
6974
RegressionFBII,
75+
SPEX,
7076
]
7177

7278
__all__ = [
@@ -84,6 +90,7 @@
8490
"SVARM",
8591
"SVARMIQ",
8692
"kADDSHAP",
93+
"SPEX",
8794
"UnbiasedKernelSHAP",
8895
"SV_APPROXIMATORS",
8996
"SI_APPROXIMATORS",

shapiq/approximator/_base.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def __init__(
6969
pairing_trick: bool = False,
7070
sampling_weights: np.ndarray[float] | None = None,
7171
random_state: int | None = None,
72+
initialize_dict: bool = True,
7273
) -> None:
7374
# check if index can be approximated
7475
self.index: str = index
@@ -89,11 +90,16 @@ def __init__(
8990
self._grand_coalition_tuple = tuple(range(self.n))
9091
self._grand_coalition_array: np.ndarray = np.arange(self.n + 1, dtype=int)
9192
self.iteration_cost: int = 1 # default value, can be overwritten by subclasses
92-
self._interaction_lookup = generate_interaction_lookup(
93-
self.n,
94-
self.min_order,
95-
self.max_order,
96-
)
93+
94+
# The interaction_lookup is not initialized is some cases due to performance reasons
95+
if initialize_dict:
96+
self._interaction_lookup = generate_interaction_lookup(
97+
self.n,
98+
self.min_order,
99+
self.max_order,
100+
)
101+
else:
102+
self._interaction_lookup = {}
97103

98104
# set up random state and random number generators
99105
self._random_state: int | None = random_state
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from ._base import Sparse
2+
from .spex import SPEX
3+
4+
__all__ = [
5+
"SPEX",
6+
"Sparse",
7+
]
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
from __future__ import annotations
2+
3+
import copy
4+
from collections.abc import Callable
5+
6+
import numpy as np
7+
from sparse_transform.qsft.qsft import transform as sparse_fourier_transform
8+
from sparse_transform.qsft.signals.input_signal_subsampled import (
9+
SubsampledSignal as SubsampledSignalFourier,
10+
)
11+
from sparse_transform.qsft.utils.general import fourier_to_mobius as fourier_to_moebius
12+
from sparse_transform.qsft.utils.query import get_bch_decoder
13+
14+
from ...game_theory.indices import is_index_aggregated
15+
from ...game_theory.moebius_converter import MoebiusConverter
16+
from ...interaction_values import InteractionValues
17+
from .._base import Approximator
18+
19+
20+
class Sparse(Approximator):
21+
"""Approximator for interaction values using sparse transformation techniques.
22+
23+
This class implements a sparse approximation method for computing various interaction indices
24+
using sparse Fourier transforms. It efficiently estimates interaction values with a limited
25+
sample budget by leveraging sparsity in the Fourier domain.
26+
27+
Attributes:
28+
transform_type (str): Type of transform used (currently only "fourier" is supported).
29+
t (int): Error parameter for the sparse Fourier transform (currently fixed to 5).
30+
query_args (dict): Parameters for querying the signal.
31+
decoder_args (dict): Parameters for decoding the transform.
32+
33+
Args:
34+
n (int): Number of players/features.
35+
index (str): Type of interaction index to compute (e.g., "STII", "FBII", "FSII").
36+
max_order (int, optional): Maximum interaction order to compute. It is not suggested to use this parameter
37+
since sparse approximation dynamically and implicitly adjusts the order based on the budget and function.
38+
top_order (bool, optional): If True, only compute interactions of exactly max_order.
39+
If False, compute interactions up to max_order. Defaults to False.
40+
random_state (int, optional): Random seed for reproducibility. Defaults to None.
41+
transform_type (str, optional): Type of transform to use. Currently only "fourier"
42+
is supported. Defaults to "fourier".
43+
decoder_type (str, optional): Type of decoder to use, either "soft" or "hard".
44+
Defaults to "soft"
45+
46+
Raises:
47+
ValueError: If transform_type is not "fourier" or if decoder_type is not "soft" or "hard".
48+
"""
49+
50+
def __init__(
51+
self,
52+
n: int,
53+
index: str,
54+
max_order: int | None = None,
55+
top_order: bool = False,
56+
random_state: int | None = None,
57+
transform_type: str = "fourier",
58+
decoder_type: str = "soft",
59+
) -> None:
60+
if transform_type.lower() not in ["fourier"]:
61+
msg = "transform_type must be 'fourier'"
62+
raise ValueError(msg)
63+
self.transform_type = transform_type.lower()
64+
self.t = 5 # 5 could be a parameter
65+
self.decoder_type = "hard" if decoder_type is None else decoder_type.lower()
66+
if self.decoder_type not in ["soft", "hard"]:
67+
msg = "decoder_type must be 'soft' or 'hard'"
68+
raise ValueError(msg)
69+
# The sampling parameters for the Fourier transform
70+
self.query_args = {
71+
"query_method": "complex",
72+
"num_subsample": 3,
73+
"delays_method_source": "joint-coded",
74+
"subsampling_method": "qsft",
75+
"delays_method_channel": "identity-siso",
76+
"num_repeat": 1,
77+
"t": self.t,
78+
}
79+
self.decoder_args = {
80+
"num_subsample": 3,
81+
"num_repeat": 1,
82+
"reconstruct_method_source": "coded",
83+
"peeling_method": "multi-detect",
84+
"reconstruct_method_channel": "identity-siso"
85+
if self.decoder_type == "soft"
86+
else "identity",
87+
"regress": "lasso",
88+
"res_energy_cutoff": 0.9,
89+
"source_decoder": get_bch_decoder(n, self.t, self.decoder_type),
90+
}
91+
super().__init__(
92+
n=n,
93+
max_order=n if max_order is None else max_order,
94+
index=index,
95+
top_order=top_order,
96+
random_state=random_state,
97+
initialize_dict=False, # Important for performance
98+
)
99+
100+
def approximate(
101+
self,
102+
budget: int,
103+
game: Callable[[np.ndarray], np.ndarray],
104+
) -> InteractionValues:
105+
"""Approximates the interaction values using a sparse transform approach.
106+
107+
Args:
108+
budget: The budget for the approximation.
109+
game: The game function that returns the values for the coalitions.
110+
111+
Returns:
112+
The approximated Shapley interaction values.
113+
"""
114+
# Find the maximum value of b that fits within the given sample budget and get the used budget
115+
used_budget = self._set_transform_budget(budget)
116+
signal = SubsampledSignalFourier(
117+
func=lambda inputs: game(inputs.astype(bool)),
118+
n=self.n,
119+
q=2,
120+
query_args=self.query_args,
121+
)
122+
# Extract the coefficients of the original transform
123+
initial_transform = {
124+
tuple(np.nonzero(key)[0]): np.real(value)
125+
for key, value in sparse_fourier_transform(signal, **self.decoder_args).items()
126+
}
127+
# If we are using the fourier transform, we need to convert it to a Moebius transform
128+
moebius_transform = fourier_to_moebius(initial_transform)
129+
# Convert the Moebius transform to the desired index
130+
result = self._process_moebius(moebius_transform=moebius_transform)
131+
# Filter the output as needed
132+
if self.top_order:
133+
result = self._filter_order(result)
134+
output = InteractionValues(
135+
values=result,
136+
index=self.approximation_index,
137+
min_order=self.min_order,
138+
max_order=self.max_order,
139+
n_players=self.n,
140+
interaction_lookup=copy.deepcopy(self.interaction_lookup),
141+
estimated=True,
142+
estimation_budget=used_budget,
143+
baseline_value=self.interaction_lookup.get((), 0.0),
144+
)
145+
# Update the interaction lookup to reflect the filtered results
146+
if is_index_aggregated(self.index):
147+
output = self.aggregate_interaction_values(output)
148+
return output
149+
150+
def _filter_order(self, result: np.ndarray) -> np.ndarray:
151+
"""Filters the interactions to keep only those of the maximum order.
152+
153+
This method is used when top_order=True to filter out all interactions that are not
154+
of exactly the maximum order (self.max_order).
155+
156+
Args:
157+
result: Array of interaction values.
158+
159+
Returns:
160+
Filtered array containing only interaction values of the maximum order.
161+
The method also updates the internal _interaction_lookup dictionary.
162+
"""
163+
filtered_interactions = {}
164+
filtered_results = []
165+
i = 0
166+
for j, key in enumerate(self.interaction_lookup):
167+
if len(key) == self.max_order:
168+
filtered_interactions[key] = i
169+
filtered_results.append(result[j])
170+
i += 1
171+
self._interaction_lookup = filtered_interactions
172+
return np.array(filtered_results)
173+
174+
def _process_moebius(self, moebius_transform: dict[tuple, float]) -> np.ndarray:
175+
"""Processes the Moebius transform to extract the desired index.
176+
177+
Args:
178+
moebius_transform: The Moebius transform to process (dict mapping tuples to float values).
179+
180+
Returns:
181+
np.ndarray: The converted interaction values based on the specified index.
182+
The function also updates the internal _interaction_lookup dictionary.
183+
"""
184+
moebius_interactions = InteractionValues(
185+
values=np.array([moebius_transform[key] for key in moebius_transform.keys()]),
186+
index="Moebius",
187+
min_order=self.min_order,
188+
max_order=self.max_order,
189+
n_players=self.n,
190+
interaction_lookup={key: i for i, key in enumerate(moebius_transform.keys())},
191+
estimated=True,
192+
baseline_value=moebius_transform.get((), 0.0),
193+
)
194+
autoconverter = MoebiusConverter(moebius_coefficients=moebius_interactions)
195+
converted_interaction_values = autoconverter(index=self.index, order=self.max_order)
196+
self._interaction_lookup = converted_interaction_values.interaction_lookup
197+
return converted_interaction_values.values
198+
199+
def _set_transform_budget(self, budget: int) -> int:
200+
"""Sets the appropriate transform budget parameters based on the given sample budget.
201+
202+
This method calculates the maximum possible 'b' parameter (number of bits to subsample)
203+
that fits within the provided budget, then configures the query and decoder arguments
204+
accordingly. The actual number of samples that will be used is returned.
205+
206+
Args:
207+
budget: The maximum number of samples allowed for the approximation.
208+
209+
Returns:
210+
int: The actual number of samples that will be used, which is less than or equal to the budget.
211+
212+
Raises:
213+
ValueError: If the budget is too low to compute the transform with acceptable parameters.
214+
"""
215+
b = SubsampledSignalFourier.get_b_for_sample_budget(
216+
budget, self.n, self.t, 2, self.query_args
217+
)
218+
used_budget = SubsampledSignalFourier.get_number_of_samples(
219+
self.n, b, self.t, 2, self.query_args
220+
)
221+
222+
if b <= 2:
223+
while self.t > 2:
224+
self.t -= 1
225+
self.query_args["t"] = self.t
226+
227+
# Recalculate 'b' with the updated 't'
228+
b = SubsampledSignalFourier.get_b_for_sample_budget(
229+
budget, self.n, self.t, 2, self.query_args
230+
)
231+
232+
# Compute the used budget
233+
used_budget = SubsampledSignalFourier.get_number_of_samples(
234+
self.n, b, self.t, 2, self.query_args
235+
)
236+
237+
# Break if 'b' is now sufficient
238+
if b > 2:
239+
self.decoder_args["source_decoder"] = get_bch_decoder(
240+
self.n, self.t, self.decoder_type
241+
)
242+
break
243+
244+
# If 'b' is still too low, raise an error
245+
if b <= 2:
246+
msg = "Insufficient budget to compute the transform. Increase the budget or use a different approximator."
247+
raise ValueError(msg)
248+
# Store the final 'b' value
249+
self.query_args["b"] = b
250+
self.decoder_args["b"] = b
251+
return used_budget

0 commit comments

Comments
 (0)