Skip to content

Commit f892935

Browse files
Advueu963 and mmschlk authored
331 adding faith banzhaf approximator (#333)
* Added FBII * Introduces game normalization for FBII in exact. * Added FBII to MC_Approximator * Added FBI to Tabular Explainer * Update regression coefficient calculation for FBII & FSII. * refactoring of regression weight calculation * fix fbii exact computation missing baseline_value addition * Added FBII * Introduces game normalization for FBII in exact. * Added FBII to MC_Approximator * Added FBI to Tabular Explainer * Update regression coefficient calculation for FBII & FSII. * refactoring of regression weight calculation * fix fbii exact computation missing baseline_value addition * updated regression_coefficient calculation * improving approximation finalize result * Consistent baseline/empty player value throughout InteractionValues. * Sampling weight initialisation for FBII and corresponding approximation * renamed finalize function * finalized explanations for TreeExplainer --------- Co-authored-by: Maximilian <maximilian.muschalik@gmail.com>
1 parent af3038e commit f892935

File tree

29 files changed

+556
-169
lines changed

29 files changed

+556
-169
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
- fixes a bug with xgboost feature names where trees that did not contain all features would cause `TreeExplainer` to fail
77
- fixes a bug with `stacked_bar_plot` where the higher order interactions were inflated by the lower order interactions, thus wrongly showing the higher order interactions as higher than they are
88
- fixes a bug where `InteractionValues.get_subset()` returns a faulty `coalition_lookup` dictionary pointing to indices outside the subset of players [#336](https://github.com/mmschlk/shapiq/issues/336)
9+
- updates the default value of `TreeExplainer`'s `min_order` parameter from 1 to 0 so that the baseline value is included in the interaction values by default
10+
- adds the `RegressionFBII` approximator to estimate Faithful Banzhaf interactions via least squares regression [#333](https://github.com/mmschlk/shapiq/pull/333). Additionally, FBII support was introduced in TabularExplainer and MonteCarlo-Approximator.
911

1012
### v1.2.2 (2025-03-11)
1113
- changes python support to 3.10-3.13 [#318](https://github.com/mmschlk/shapiq/pull/318)

shapiq/approximator/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
from .permutation.sii import PermutationSamplingSII
77
from .permutation.stii import PermutationSamplingSTII
88
from .permutation.sv import PermutationSamplingSV
9-
from .regression import InconsistentKernelSHAPIQ, KernelSHAP, KernelSHAPIQ, RegressionFSII, kADDSHAP
9+
from .regression import (
10+
InconsistentKernelSHAPIQ,
11+
KernelSHAP,
12+
KernelSHAPIQ,
13+
RegressionFBII,
14+
RegressionFSII,
15+
kADDSHAP,
16+
)
1017

1118
# contains all SV approximators
1219
SV_APPROXIMATORS: list[Approximator.__class__] = [
@@ -57,6 +64,11 @@
5764
SHAPIQ,
5865
]
5966

67+
# contains all approximators that can be used for FBII
68+
FBII_APPROXIMATORS: list[Approximator.__class__] = [
69+
RegressionFBII,
70+
]
71+
6072
__all__ = [
6173
"PermutationSamplingSII",
6274
"PermutationSamplingSTII",
@@ -77,6 +89,7 @@
7789
"SII_APPROXIMATORS",
7890
"STII_APPROXIMATORS",
7991
"FSII_APPROXIMATORS",
92+
"FBII_APPROXIMATORS",
8093
]
8194

8295
# Path: shapiq/approximator/__init__.py

shapiq/approximator/_base.py

Lines changed: 27 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
"""This module contains the base approximator classes for the shapiq package."""
22

3-
import copy
3+
import warnings
44
from abc import ABC, abstractmethod
55
from collections.abc import Callable
66

77
import numpy as np
8+
from scipy.special import binom
89

910
from ..approximator.sampling import CoalitionSampler
1011
from ..game_theory.indices import (
1112
AVAILABLE_INDICES_FOR_APPROXIMATION,
1213
get_computation_index,
13-
is_empty_value_the_baseline,
14-
is_index_aggregated,
1514
)
1615
from ..interaction_values import InteractionValues
1716
from ..utils.sets import generate_interaction_lookup
@@ -143,13 +142,29 @@ def _init_sampling_weights(self) -> np.ndarray:
143142
The weights for sampling subsets of size ``s`` in shape ``(n + 1,)``.
144143
"""
145144
weight_vector = np.zeros(shape=self.n + 1)
146-
for coalition_size in range(0, self.n + 1):
147-
if (coalition_size < self.max_order) or (coalition_size > self.n - self.max_order):
148-
# prioritize these subsets
149-
weight_vector[coalition_size] = self._big_M
150-
else:
151-
# KernelSHAP sampling weights
152-
weight_vector[coalition_size] = 1 / (coalition_size * (self.n - coalition_size))
145+
if self.index in ["FBII"]:
146+
147+
try:
148+
for coalition_size in range(0, self.n + 1):
149+
weight_vector[coalition_size] = binom(self.n, coalition_size) / 2**self.n
150+
except OverflowError:
151+
for coalition_size in range(0, self.n + 1):
152+
weight_vector[coalition_size] = (
153+
1
154+
/ np.sqrt(2 * np.pi * 0.5)
155+
* np.exp(-(coalition_size - self.n / 2) * +2 / (self.n / 2))
156+
)
157+
warnings.warn(
158+
"The weights are approximated for n > 1000. While this is very close to the truth for sets for size in the region n/2, the approximation is not exact for size n or 0."
159+
)
160+
else:
161+
for coalition_size in range(0, self.n + 1):
162+
if (coalition_size < self.max_order) or (coalition_size > self.n - self.max_order):
163+
# prioritize these subsets
164+
weight_vector[coalition_size] = self._big_M
165+
else:
166+
# KernelSHAP sampling weights
167+
weight_vector[coalition_size] = 1 / (coalition_size * (self.n - coalition_size))
153168
sampling_weight = weight_vector / np.sum(weight_vector)
154169
return sampling_weight
155170

@@ -175,61 +190,6 @@ def _order_iterator(self) -> range:
175190
"""
176191
return range(self.min_order, self.max_order + 1)
177192

178-
def _finalize_result(
179-
self,
180-
result,
181-
baseline_value: float,
182-
*,
183-
estimated: bool | None = None,
184-
budget: int | None = None,
185-
) -> InteractionValues:
186-
"""Finalizes the result dictionary.
187-
188-
Args:
189-
result: Interaction values.
190-
baseline_value: Baseline value.
191-
estimated: Whether interaction values were estimated.
192-
budget: The budget for the approximation.
193-
194-
Returns:
195-
The interaction values.
196-
197-
Raises:
198-
ValueError: If the baseline value is not provided for SII and k-SII.
199-
"""
200-
201-
if budget is None: # try to get budget from sampler (exclude from coverage)
202-
budget = self._sampler.n_coalitions # pragma: no cover
203-
204-
if estimated is None:
205-
estimated = False if budget >= 2**self.n else True
206-
207-
# set empty value as baseline value if necessary
208-
if tuple() in self._interaction_lookup:
209-
idx = self._interaction_lookup[tuple()]
210-
empty_value = result[idx]
211-
# only for SII empty value is not the baseline value
212-
if empty_value != baseline_value and is_empty_value_the_baseline(self.index):
213-
result[idx] = baseline_value
214-
215-
interactions = InteractionValues(
216-
values=result,
217-
estimated=estimated,
218-
estimation_budget=budget,
219-
index=self.approximation_index, # can be different from self.index
220-
min_order=self.min_order,
221-
max_order=self.max_order,
222-
n_players=self.n,
223-
interaction_lookup=copy.deepcopy(self.interaction_lookup),
224-
baseline_value=baseline_value,
225-
)
226-
227-
# if index needs to be aggregated
228-
if is_index_aggregated(self.index):
229-
interactions = self.aggregate_interaction_values(interactions)
230-
231-
return interactions
232-
233193
@staticmethod
234194
def _calc_iteration_count(budget: int, batch_size: int, iteration_cost: int) -> tuple[int, int]:
235195
"""Computes the number of iterations and the size of the last batch given the batch size and
@@ -313,6 +273,6 @@ def aggregate_interaction_values(
313273
Returns:
314274
The aggregated interaction values.
315275
"""
316-
from shapiq.game_theory.aggregation import aggregate_interaction_values
276+
from shapiq.game_theory.aggregation import aggregate_base_interaction
317277

318-
return aggregate_interaction_values(base_interactions, order=order)
278+
return aggregate_base_interaction(base_interactions, order=order)

shapiq/approximator/marginals/owen.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88

9-
from ...interaction_values import InteractionValues
9+
from ...interaction_values import InteractionValues, finalize_computed_interactions
1010
from .._base import Approximator
1111

1212

@@ -101,8 +101,21 @@ def approximate(
101101
idx = self._interaction_lookup[(player,)]
102102
result_to_finalize[idx] = result[player]
103103

104-
return self._finalize_result(
105-
result_to_finalize, baseline_value=empty_value, budget=used_budget, estimated=True
104+
interaction = InteractionValues(
105+
n_players=self.n,
106+
values=result_to_finalize,
107+
index=self.approximation_index,
108+
interaction_lookup=self._interaction_lookup,
109+
baseline_value=empty_value,
110+
min_order=self.min_order,
111+
max_order=self.max_order,
112+
estimated=True,
113+
estimation_budget=used_budget,
114+
)
115+
116+
return finalize_computed_interactions(
117+
interaction,
118+
target_index=self.index,
106119
)
107120

108121
@staticmethod

shapiq/approximator/marginals/stratified.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from ...interaction_values import InteractionValues
7+
from ...interaction_values import InteractionValues, finalize_computed_interactions
88
from .._base import Approximator
99

1010

@@ -37,13 +37,14 @@ def __init__(
3737
self.iteration_cost: int = 2
3838

3939
def approximate(
40-
self, budget: int, game: Callable[[np.ndarray], np.ndarray]
40+
self, budget: int, game: Callable[[np.ndarray], np.ndarray], *args, **kwargs
4141
) -> InteractionValues:
4242
"""Approximates the Shapley values using ApproShapley.
4343
4444
Args:
4545
budget: The number of game evaluations for approximation
4646
game: The game function as a callable that takes a set of players and returns the value.
47+
*args and **kwargs: Additional arguments not used.
4748
4849
Returns:
4950
The estimated interaction values.
@@ -114,6 +115,16 @@ def approximate(
114115
idx = self._interaction_lookup[(player,)]
115116
result_to_finalize[idx] = result[player]
116117

117-
return self._finalize_result(
118-
result_to_finalize, baseline_value=empty_value, budget=used_budget, estimated=True
118+
interactions = InteractionValues(
119+
n_players=self.n,
120+
values=result_to_finalize,
121+
index=self.approximation_index,
122+
interaction_lookup=self._interaction_lookup,
123+
baseline_value=float(empty_value),
124+
min_order=self.min_order,
125+
max_order=self.max_order,
126+
estimated=True,
127+
estimation_budget=used_budget,
119128
)
129+
130+
return finalize_computed_interactions(interactions, target_index=self.index)

shapiq/approximator/montecarlo/_base.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from scipy.special import binom, factorial
77

88
from ...game_theory.indices import AVAILABLE_INDICES_MONTE_CARLO
9-
from ...interaction_values import InteractionValues
9+
from ...interaction_values import InteractionValues, finalize_computed_interactions
1010
from ...utils.sets import powerset
1111
from .._base import Approximator
1212

@@ -50,7 +50,7 @@ def __init__(
5050
f"Index {index} not available for Regression Approximator. Choose from "
5151
f"{AVAILABLE_INDICES_MONTE_CARLO}."
5252
)
53-
if index == "FSII":
53+
if index in ["FSII", "FBII"]:
5454
top_order = True
5555
super().__init__(
5656
n,
@@ -99,10 +99,20 @@ def approximate(
9999

100100
baseline_value = float(game_values[self._sampler.empty_coalition_index])
101101

102-
return self._finalize_result(
103-
result=shapley_interactions_values, baseline_value=baseline_value, budget=budget
102+
interactions = InteractionValues(
103+
shapley_interactions_values,
104+
index=self.approximation_index,
105+
n_players=self.n,
106+
interaction_lookup=self.interaction_lookup,
107+
min_order=self.min_order,
108+
max_order=self.max_order,
109+
baseline_value=baseline_value,
110+
estimated=False if budget >= 2**self.n else True,
111+
estimation_budget=budget,
104112
)
105113

114+
return finalize_computed_interactions(interactions, target_index=self.index)
115+
106116
def monte_carlo_routine(
107117
self,
108118
game_values: np.ndarray,
@@ -386,8 +396,7 @@ def _stii_weight(self, coalition_size: int, interaction_size: int) -> float:
386396
"""
387397
if interaction_size == self.max_order:
388398
return self.max_order / (self.n * binom(self.n - 1, coalition_size))
389-
else:
390-
return 1.0 * (coalition_size == 0)
399+
return 1.0 * (coalition_size == 0)
391400

392401
def _fsii_weight(self, coalition_size: int, interaction_size: int) -> float:
393402
"""Returns the FSII discrete derivative weight given the coalition size and interaction
@@ -411,8 +420,24 @@ def _fsii_weight(self, coalition_size: int, interaction_size: int) -> float:
411420
* factorial(coalition_size + self.max_order - 1)
412421
/ factorial(self.n + self.max_order - 1)
413422
)
414-
else:
415-
raise ValueError("Lower order interactions are not supported.")
423+
raise ValueError(f"Lower order interactions are not supported for {self.index}.")
424+
425+
def _fbii_weight(self, interaction_size: int) -> float:
426+
"""Returns the FSII discrete derivative weight given the coalition size and interaction
427+
size.
428+
429+
The representation is based on the FBII representation according to Theorem 17 by
430+
`Tsai et al. (2023) <https://doi.org/10.48550/arXiv.2203.00870>`_.
431+
432+
Args:
433+
interaction_size: The size of the interaction.
434+
435+
Returns:
436+
The weight for the interaction type.
437+
"""
438+
if interaction_size == self.max_order:
439+
return 1 / 2 ** (self.n - interaction_size)
440+
raise ValueError(f"Lower order interactions are not supported for {self.index}.")
416441

417442
def _weight(self, index: str, coalition_size: int, interaction_size: int) -> float:
418443
"""Returns the weight for each interaction type given coalition and interaction size.
@@ -429,6 +454,8 @@ def _weight(self, index: str, coalition_size: int, interaction_size: int) -> flo
429454
return self._stii_weight(coalition_size, interaction_size)
430455
elif index == "FSII":
431456
return self._fsii_weight(coalition_size, interaction_size)
457+
elif index == "FBII":
458+
return self._fbii_weight(interaction_size)
432459
elif index in ["SII", "SV"]:
433460
return self._sii_weight(coalition_size, interaction_size)
434461
elif index == "BII":

shapiq/approximator/permutation/sii.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from ...interaction_values import InteractionValues
7+
from ...interaction_values import InteractionValues, finalize_computed_interactions
88
from ...utils.sets import powerset
99
from .._base import Approximator
1010

@@ -143,6 +143,16 @@ def approximate(
143143
# compute mean of interactions
144144
result = np.divide(result, counts, out=result, where=counts != 0)
145145

146-
return self._finalize_result(
147-
result, baseline_value=empty_value, budget=used_budget, estimated=True
146+
interactions = InteractionValues(
147+
n_players=self.n,
148+
values=result,
149+
index=self.approximation_index,
150+
interaction_lookup=self._interaction_lookup,
151+
baseline_value=empty_value,
152+
min_order=self.min_order,
153+
max_order=self.max_order,
154+
estimated=True,
155+
estimation_budget=used_budget,
148156
)
157+
158+
return finalize_computed_interactions(interactions, target_index=self.index)

0 commit comments

Comments
 (0)