Skip to content

Commit a50f43a

Browse files
committed
re-run ruff
1 parent 531f90b commit a50f43a

9 files changed

Lines changed: 49 additions & 38 deletions

File tree

docs/source/notebooks/language_notebooks/language_model_game.ipynb

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -502,22 +502,21 @@
502502
"metadata": {},
503503
"outputs": [],
504504
"source": [
505-
"text = \\\n",
506-
"\"\"\"\n",
505+
"text = \"\"\"\n",
507506
"shapiq is a valuable Python library designed for Explainable AI (XAI), focusing specifically on approaches like\n",
508507
"Shapley values and their extensions. Its core strength lies in providing a unified framework to compute not only individual feature attributions but also sophisticated interaction indices (e.g., Shapley Interaction Index, Banzhaf Index). This allows users to gain deeper insights into how features collaborate or conflict within complex machine learning models, going beyond simple importance scores. A notable weakness stems from the inherent computational complexity of these game-theoretic measures. Calculating exact values, especially for higher-order interactions, is often infeasible, and even approximations can be computationally intensive and time-consuming, particularly for models with many features or large datasets. Despite this, shapiq remains a powerful tool for detailed model inspection.\n",
509508
"\"\"\"\n",
510-
"big_game = SentimentClassificationGame(classifier=classifier,\n",
511-
" tokenizer=tokenizer,\n",
512-
" test_sentence=text)\n",
509+
"big_game = SentimentClassificationGame(\n",
510+
" classifier=classifier, tokenizer=tokenizer, test_sentence=text\n",
511+
")\n",
513512
"print(f\"There are a total of {big_game.n_players} players.\")\n",
514513
"# To speed up inference, run pipeline with gpu support. Takes ~10 minutes on Mac M1 with MPS.\n",
515514
"scalable_approximator = shapiq.SPEX(n=big_game.n_players, index=\"SII\")\n",
516515
"large_sii = scalable_approximator.approximate(budget=32000, game=big_game)\n",
517516
"print(f\"Game for the full coalition: {game_class(full_coalition)[0]}\")\n",
518517
"print(f\"Game for the empty coalition: {game_class(empty_coalition)[0]}\")\n",
519-
"interactions = (list(large_sii.dict_values.items()))\n",
520-
"interactions.sort(key= lambda x : abs(x[1]), reverse=True)"
518+
"interactions = list(large_sii.dict_values.items())\n",
519+
"interactions.sort(key=lambda x: abs(x[1]), reverse=True)"
521520
]
522521
},
523522
{
@@ -557,7 +556,7 @@
557556
"source": [
558557
"for inter, value in interactions[:10]:\n",
559558
" tokens = [big_game.tokenizer.decode(big_game.tokenized_input[idx]) for idx in inter]\n",
560-
" print(f'Tokens: {tokens}, Value: {value:.3f}')"
559+
" print(f\"Tokens: {tokens}, Value: {value:.3f}\")"
561560
]
562561
}
563562
],

docs/source/notebooks/vision_notebooks/vision_transformer.ipynb

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,9 +371,11 @@
371371
" x_explain_path=image_path,\n",
372372
")\n",
373373
"print(\"Number of players:\", big_game.n_players) # the image is transformed into 36 patches\n",
374-
"print('Takes ~6min on an M1 Macbook Pro')\n",
374+
"print(\"Takes ~6min on an M1 Macbook Pro\")\n",
375375
"# For 36 patches, consider using order 1 for simpler explanations\n",
376-
"scalable_approximator = shapiq.SPEX(n=big_game.n_players, max_order=2, index=\"SII\") # Here we use max order=2,\n",
376+
"scalable_approximator = shapiq.SPEX(\n",
377+
" n=big_game.n_players, max_order=2, index=\"SII\"\n",
378+
") # Here we use max order=2,\n",
377379
"large_sii = scalable_approximator.approximate(budget=9000, game=big_game)\n",
378380
"large_sii.plot_network(center_image=center_image, draw_legend=False)"
379381
]

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ dependencies = [
1515
"scikit-learn",
1616
"tqdm",
1717
"requests",
18+
"sparse-transform",
19+
"galois",
1820
# plotting
1921
"matplotlib",
2022
"networkx",
2123
"colour",
22-
"sparse-transform",
23-
"galois",
2424
]
2525
authors = [
2626
{name = "Maximilian Muschalik", email = "Maximilian.Muschalik@lmu.de"},

shapiq/approximator/sparse/_base.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def __init__(
8181
"num_repeat": 1,
8282
"reconstruct_method_source": "coded",
8383
"peeling_method": "multi-detect",
84-
"reconstruct_method_channel": "identity-siso" if self.decoder_type == "soft" else "identity",
84+
"reconstruct_method_channel": "identity-siso"
85+
if self.decoder_type == "soft"
86+
else "identity",
8587
"regress": "lasso",
8688
"res_energy_cutoff": 0.9,
8789
"source_decoder": get_bch_decoder(n, self.t, self.decoder_type),
@@ -96,7 +98,9 @@ def __init__(
9698
)
9799

98100
def approximate(
99-
self, budget: int, game: Callable[[np.ndarray], np.ndarray],
101+
self,
102+
budget: int,
103+
game: Callable[[np.ndarray], np.ndarray],
100104
) -> InteractionValues:
101105
"""Approximates the interaction values using a sparse transform approach.
102106
@@ -109,11 +113,12 @@ def approximate(
109113
"""
110114
# Find the maximum value of b that fits within the given sample budget and get the used budget
111115
used_budget = self._set_transform_budget(budget)
112-
signal = SubsampledSignalFourier(func=lambda inputs: game(inputs.astype(bool)),
113-
n=self.n,
114-
q=2,
115-
query_args=self.query_args,
116-
)
116+
signal = SubsampledSignalFourier(
117+
func=lambda inputs: game(inputs.astype(bool)),
118+
n=self.n,
119+
q=2,
120+
query_args=self.query_args,
121+
)
117122
# Extract the coefficients of the original transform
118123
initial_transform = {
119124
tuple(np.nonzero(key)[0]): np.real(value)
@@ -142,7 +147,6 @@ def approximate(
142147
output = self.aggregate_interaction_values(output)
143148
return output
144149

145-
146150
def _filter_order(self, result: np.ndarray) -> np.ndarray:
147151
"""Filters the interactions to keep only those of the maximum order.
148152
@@ -232,7 +236,9 @@ def _set_transform_budget(self, budget: int) -> int:
232236

233237
# Break if 'b' is now sufficient
234238
if b > 2:
235-
self.decoder_args["source_decoder"] = get_bch_decoder(self.n, self.t, self.decoder_type)
239+
self.decoder_args["source_decoder"] = get_bch_decoder(
240+
self.n, self.t, self.decoder_type
241+
)
236242
break
237243

238244
# If 'b' is still too low, raise an error

shapiq/benchmark/configuration.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,10 @@
745745
dict.fromkeys(SV_APPROXIMATORS, ("n", "random_state")),
746746
)
747747
APPROXIMATION_BENCHMARK_PARAMS.update(
748-
dict.fromkeys(SI_APPROXIMATORS + SII_APPROXIMATORS + STII_APPROXIMATORS + FSII_APPROXIMATORS, ("n", "random_state", "index", "max_order")),
748+
dict.fromkeys(
749+
SI_APPROXIMATORS + SII_APPROXIMATORS + STII_APPROXIMATORS + FSII_APPROXIMATORS,
750+
("n", "random_state", "index", "max_order"),
751+
),
749752
)
750753

751754

shapiq/games/benchmark/_setup/_vit_setup.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,30 +147,24 @@ def _transform_coalition_into_bool_mask(coalition: np.ndarray, n_patches: int) -
147147
bool_mask_2d = torch.ones((12, 12), dtype=torch.int)
148148

149149
# Calculate the size of each super-patch based on the grid size
150-
patch_size = 12 // int(n_patches ** 0.5)
150+
patch_size = 12 // int(n_patches**0.5)
151151

152152
for player, is_present in enumerate(coalition):
153153
if is_present:
154154
x, y = (
155155
MAPPING_PLAYER_MASK[n_patches][player]["x"],
156156
MAPPING_PLAYER_MASK[n_patches][player]["y"],
157157
)
158-
bool_mask_2d[y:y + patch_size, x:x + patch_size] = 0
158+
bool_mask_2d[y : y + patch_size, x : x + patch_size] = 0
159159

160160
bool_mask_1d = bool_mask_2d.flatten()
161161
return bool_mask_1d
162162

163163

164164
# constants for the boolean mask generation for the Vision Transformer model
165165
MAPPING_PLAYER_MASK = {
166-
36: {
167-
player: {"x": (player % 6) * 2, "y": (player // 6) * 2}
168-
for player in range(36)
169-
},
170-
144: {
171-
player: {"x": player % 12, "y": player // 12}
172-
for player in range(144)
173-
},
166+
36: {player: {"x": (player % 6) * 2, "y": (player // 6) * 2} for player in range(36)},
167+
144: {player: {"x": player % 12, "y": player // 12} for player in range(144)},
174168
16: {
175169
0: {"x": 0, "y": 0},
176170
1: {"x": 3, "y": 0},

shapiq/games/benchmark/local_xai/benchmark_image.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,15 @@ def __init__(
8080
raise ValueError(msg)
8181

8282
# validate inputs
83-
valid_models = ["vit_144_patches", "vit_36_patches", "vit_16_patches", "vit_9_patches", "resnet_18"]
83+
valid_models = [
84+
"vit_144_patches",
85+
"vit_36_patches",
86+
"vit_16_patches",
87+
"vit_9_patches",
88+
"resnet_18",
89+
]
8490
if model_name.lower() not in valid_models:
85-
msg = (
86-
f"Invalid model {model_name}. The model must be one of {valid_models}"
87-
)
91+
msg = f"Invalid model {model_name}. The model must be one of {valid_models}"
8892
raise ValueError(
8993
msg,
9094
)
@@ -104,7 +108,7 @@ def __init__(
104108
"vit_144_patches": 144,
105109
"vit_36_patches": 36,
106110
"vit_16_patches": 16,
107-
"vit_9_patches": 9
111+
"vit_9_patches": 9,
108112
}
109113
n_players = patch_sizes[model_name]
110114

tests/tests_approximators/test_approximator_base_sparse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""This test module contains all tests regarding the base sparse approximator."""
2+
23
from __future__ import annotations
34

45
import pytest
@@ -188,7 +189,7 @@ def test_approximate(
188189
# generate the set of expected interactions
189190
expected_interactions = set()
190191
if estimates.min_order == 0:
191-
expected_interactions.update( {(i,) for i in range(n)} )
192+
expected_interactions.update({(i,) for i in range(n)})
192193
if estimates.max_order > 1:
193194
expected_interactions.add(interaction)
194195

tests/tests_approximators/test_approximator_spex.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def test_spex_vs_sparse():
125125
# Check that the interaction is estimated similarly
126126
assert abs(spex_estimates[interaction] - sparse_estimates[interaction]) < 0.1
127127

128+
128129
@pytest.mark.parametrize(
129130
"n, interaction, budget, correct_b, correct_t",
130131
[
@@ -148,6 +149,7 @@ def test_sparsity_parameter(n, interaction, budget, correct_b, correct_t):
148149
assert spex.query_args["b"] == correct_b
149150
assert spex.t == correct_t
150151

152+
151153
@pytest.mark.parametrize(
152154
"n, interaction, budget",
153155
[

0 commit comments

Comments
 (0)