Skip to content

Commit 522b0ab

Browse files
committed
fix integrations, run ruff linter
1 parent 546363c commit 522b0ab

10 files changed

Lines changed: 196 additions & 186 deletions

File tree

docs/source/notebooks/language_notebooks/language_model_game.ipynb

Lines changed: 61 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -16,33 +16,34 @@
1616
},
1717
{
1818
"cell_type": "code",
19+
"execution_count": null,
1920
"id": "96756a5298128aed",
2021
"metadata": {
2122
"collapsed": false
2223
},
24+
"outputs": [],
2325
"source": [
2426
"# Install the required packages\n",
2527
"!pip install transformers torch"
26-
],
27-
"outputs": [],
28-
"execution_count": null
28+
]
2929
},
3030
{
3131
"cell_type": "code",
32+
"execution_count": null,
3233
"id": "233a68eadd33ade3",
3334
"metadata": {
3435
"collapsed": false
3536
},
37+
"outputs": [],
3638
"source": [
3739
"# Import the required libraries\n",
38-
"from transformers import pipeline\n",
3940
"import numpy as np\n",
41+
"from transformers import pipeline\n",
42+
"\n",
4043
"import shapiq\n",
4144
"\n",
4245
"shapiq.__version__"
43-
],
44-
"outputs": [],
45-
"execution_count": null
46+
]
4647
},
4748
{
4849
"cell_type": "markdown",
@@ -59,10 +60,12 @@
5960
},
6061
{
6162
"cell_type": "code",
63+
"execution_count": null,
6264
"id": "50f59cc77301eef0",
6365
"metadata": {
6466
"collapsed": false
6567
},
68+
"outputs": [],
6669
"source": [
6770
"# Load the model and tokenizer\n",
6871
"classifier = pipeline(task=\"sentiment-analysis\", model=\"lvwerra/distilbert-imdb\")\n",
@@ -79,9 +82,7 @@
7982
"\n",
8083
"mask_toke_id = tokenizer.mask_token_id\n",
8184
"print(f\"Mask token id: {mask_toke_id}\")"
82-
],
83-
"outputs": [],
84-
"execution_count": null
85+
]
8586
},
8687
{
8788
"cell_type": "markdown",
@@ -95,10 +96,12 @@
9596
},
9697
{
9798
"cell_type": "code",
99+
"execution_count": null,
98100
"id": "c3b3b6f4193e7d73",
99101
"metadata": {
100102
"collapsed": false
101103
},
104+
"outputs": [],
102105
"source": [
103106
"# Test the tokenizer\n",
104107
"decoded_sentence = tokenizer.decode(tokenized_sentence[\"input_ids\"])\n",
@@ -110,9 +113,7 @@
110113
"print(\n",
111114
" f\"Decoded sentence: {decoded_sentence} - Tokenized input: {tokenized_input} - {len(tokenized_input)} tokens.\"\n",
112115
")"
113-
],
114-
"outputs": [],
115-
"execution_count": null
116+
]
116117
},
117118
{
118119
"cell_type": "markdown",
@@ -143,10 +144,12 @@
143144
},
144145
{
145146
"cell_type": "code",
147+
"execution_count": null,
146148
"id": "bce879ce457e9a98",
147149
"metadata": {
148150
"collapsed": false
149151
},
152+
"outputs": [],
150153
"source": [
151154
"# Define the model call function\n",
152155
"def model_call(input_texts: list[str]) -> np.ndarray[float]:\n",
@@ -170,9 +173,7 @@
170173
"\n",
171174
"# Test the model call function\n",
172175
"print(f\"Model call: {model_call(['I love this movie!', 'I hate this movie!'])}\")"
173-
],
174-
"outputs": [],
175-
"execution_count": null
176+
]
176177
},
177178
{
178179
"cell_type": "markdown",
@@ -186,10 +187,12 @@
186187
},
187188
{
188189
"cell_type": "code",
190+
"execution_count": null,
189191
"id": "d176905292347ec1",
190192
"metadata": {
191193
"collapsed": false
192194
},
195+
"outputs": [],
193196
"source": [
194197
"# Show coalitions\n",
195198
"n_players = len(tokenized_sentence[\"input_ids\"]) - 2 # remove [CLS] and [SEP]\n",
@@ -199,9 +202,7 @@
199202
"\n",
200203
"print(f\"Empty coalition: {empty_coalition}\")\n",
201204
"print(f\"Full coalition: {full_coalition}\")"
202-
],
203-
"outputs": [],
204-
"execution_count": null
205+
]
205206
},
206207
{
207208
"cell_type": "markdown",
@@ -219,10 +220,12 @@
219220
},
220221
{
221222
"cell_type": "code",
223+
"execution_count": null,
222224
"id": "79a5c423622a0904",
223225
"metadata": {
224226
"collapsed": false
225227
},
228+
"outputs": [],
226229
"source": [
227230
"# Define the value function\n",
228231
"def value_function(\n",
@@ -253,9 +256,7 @@
253256
" normalized_sentiments = sentiments - normalization_value\n",
254257
"\n",
255258
" return normalized_sentiments"
256-
],
257-
"outputs": [],
258-
"execution_count": null
259+
]
259260
},
260261
{
261262
"cell_type": "markdown",
@@ -269,10 +270,12 @@
269270
},
270271
{
271272
"cell_type": "code",
273+
"execution_count": null,
272274
"id": "22b2201ca139c0d0",
273275
"metadata": {
274276
"collapsed": false
275277
},
278+
"outputs": [],
276279
"source": [
277280
"# Test the value function without normalization\n",
278281
"print(f\"Output of the classifier: {classifier(test_sentence)}\")\n",
@@ -283,9 +286,7 @@
283286
"print(\n",
284287
" f\"Value function for the empty coalition: {value_function(empty_coalition, tokenized_input=tokenized_input)[0]}\"\n",
285288
")"
286-
],
287-
"outputs": [],
288-
"execution_count": null
289+
]
289290
},
290291
{
291292
"cell_type": "markdown",
@@ -299,10 +300,12 @@
299300
},
300301
{
301302
"cell_type": "code",
303+
"execution_count": null,
302304
"id": "338e1ae439120652",
303305
"metadata": {
304306
"collapsed": false
305307
},
308+
"outputs": [],
306309
"source": [
307310
"# Test the value function with normalization\n",
308311
"normalization_value = float(value_function(empty_coalition, tokenized_input=tokenized_input)[0])\n",
@@ -312,9 +315,7 @@
312315
"print(\n",
313316
" f\"Value function for the empty coalition: {value_function(empty_coalition, tokenized_input=tokenized_input, normalization_value=normalization_value)[0]}\"\n",
314317
")"
315-
],
316-
"outputs": [],
317-
"execution_count": null
318+
]
318319
},
319320
{
320321
"cell_type": "markdown",
@@ -328,10 +329,12 @@
328329
},
329330
{
330331
"cell_type": "code",
332+
"execution_count": null,
331333
"id": "91e8b195226e1ecb",
332334
"metadata": {
333335
"collapsed": false
334336
},
337+
"outputs": [],
335338
"source": [
336339
"# Define the game function\n",
337340
"def game_fun(coalitions: np.ndarray[bool]) -> np.ndarray[float]:\n",
@@ -351,9 +354,7 @@
351354
"# Test the game function\n",
352355
"print(f\"Game for the full coalition: {game_fun(full_coalition)[0]}\")\n",
353356
"print(f\"Game for the empty coalition: {game_fun(empty_coalition)[0]}\")"
354-
],
355-
"outputs": [],
356-
"execution_count": null
357+
]
357358
},
358359
{
359360
"cell_type": "markdown",
@@ -367,10 +368,12 @@
367368
},
368369
{
369370
"cell_type": "code",
371+
"execution_count": null,
370372
"id": "ea94eb7697abad0d",
371373
"metadata": {
372374
"collapsed": false
373375
},
376+
"outputs": [],
374377
"source": [
375378
"class SentimentClassificationGame(shapiq.Game):\n",
376379
" \"\"\"The sentiment analysis classifier modeled as a cooperative game.\n",
@@ -438,9 +441,7 @@
438441
"game_class = SentimentClassificationGame(classifier, tokenizer, test_sentence)\n",
439442
"print(f\"Game for the full coalition: {game_class(full_coalition)[0]}\")\n",
440443
"print(f\"Game for the empty coalition: {game_class(empty_coalition)[0]}\")"
441-
],
442-
"outputs": [],
443-
"execution_count": null
444+
]
444445
},
445446
{
446447
"cell_type": "markdown",
@@ -455,48 +456,51 @@
455456
},
456457
{
457458
"cell_type": "code",
459+
"execution_count": null,
458460
"id": "f62adc49538c8a79",
459461
"metadata": {
460462
"collapsed": false
461463
},
464+
"outputs": [],
462465
"source": [
463466
"# Compute Shapley interactions with the ShapIQ approximator for the game function\n",
464467
"approximator = shapiq.KernelSHAPIQ(n=n_players, max_order=2, index=\"k-SII\")\n",
465468
"sii_values = approximator.approximate(budget=2**n_players, game=game_fun)\n",
466469
"sii_values.dict_values"
467-
],
468-
"outputs": [],
469-
"execution_count": null
470+
]
470471
},
471472
{
472473
"cell_type": "code",
474+
"execution_count": null,
473475
"id": "7641d33a850cdd16",
474476
"metadata": {
475477
"collapsed": false
476478
},
479+
"outputs": [],
477480
"source": [
478481
"# Compute Shapley interactions with the ShapIQ approximator for the game object\n",
479482
"approximator = shapiq.KernelSHAPIQ(n=game_class.n_players, max_order=2, index=\"k-SII\")\n",
480483
"sii_values = approximator.approximate(budget=2**game_class.n_players, game=game_class)\n",
481484
"sii_values.dict_values"
482-
],
483-
"outputs": [],
484-
"execution_count": null
485+
]
485486
},
486487
{
487-
"metadata": {},
488488
"cell_type": "markdown",
489+
"id": "ef3641f671c8616b",
490+
"metadata": {},
489491
"source": [
490492
"Now let's say we want to do this for much larger inputs. We can use the `shapiq.SPEX` approximator, which is a sparse\n",
491493
"transform approximator. This approximator is much faster than the KernelSHAPIQ approximator when the number of\n",
492494
"players is large and can be used for larger inputs. Instead of computing all interactions, it computes only the\n",
493495
"most important ones."
494-
],
495-
"id": "ef3641f671c8616b"
496+
]
496497
},
497498
{
498-
"metadata": {},
499499
"cell_type": "code",
500+
"execution_count": null,
501+
"id": "ce6bc4d47530fa2b",
502+
"metadata": {},
503+
"outputs": [],
500504
"source": [
501505
"text = \\\n",
502506
"\"\"\"\n",
@@ -505,7 +509,7 @@
505509
"\"\"\"\n",
506510
"big_game = SentimentClassificationGame(classifier=classifier,\n",
507511
" tokenizer=tokenizer,\n",
508-
" test_sentence=text,)\n",
512+
" test_sentence=text)\n",
509513
"print(f\"There are a total of {big_game.n_players} players.\")\n",
510514
"# To speed up inference, run pipeline with gpu support. Takes ~10 minutes on Mac M1 with MPS.\n",
511515
"scalable_approximator = shapiq.SPEX(n=big_game.n_players, index=\"SII\")\n",
@@ -514,31 +518,24 @@
514518
"print(f\"Game for the empty coalition: {game_class(empty_coalition)[0]}\")\n",
515519
"interactions = (list(large_sii.dict_values.items()))\n",
516520
"interactions.sort(key= lambda x : abs(x[1]), reverse=True)"
517-
],
518-
"id": "ce6bc4d47530fa2b",
519-
"outputs": [],
520-
"execution_count": null
521+
]
521522
},
522523
{
523-
"metadata": {},
524524
"cell_type": "markdown",
525-
"source": "`shapiq.SPEX` identifies interactions between the most sentiment-rich tokens in the paragraph (i.e. *powerful*, *valuable*, *weakness*)",
526-
"id": "e81434a652831817"
525+
"id": "e81434a652831817",
526+
"metadata": {},
527+
"source": "`shapiq.SPEX` identifies interactions between the most sentiment-rich tokens in the paragraph (i.e. *powerful*, *valuable*, *weakness*)"
527528
},
528529
{
530+
"cell_type": "code",
531+
"execution_count": 28,
532+
"id": "a09121f781d9be77",
529533
"metadata": {
530534
"ExecuteTime": {
531535
"end_time": "2025-04-15T18:41:34.838256Z",
532536
"start_time": "2025-04-15T18:41:34.829935Z"
533537
}
534538
},
535-
"cell_type": "code",
536-
"source": [
537-
"for inter, value in interactions[:10]:\n",
538-
" tokens = [big_game.tokenizer.decode(big_game.tokenized_input[idx]) for idx in inter]\n",
539-
" print(f'Tokens: {tokens}, Value: {value:.3f}')"
540-
],
541-
"id": "a09121f781d9be77",
542539
"outputs": [
543540
{
544541
"name": "stdout",
@@ -557,7 +554,11 @@
557554
]
558555
}
559556
],
560-
"execution_count": 28
557+
"source": [
558+
"for inter, value in interactions[:10]:\n",
559+
" tokens = [big_game.tokenizer.decode(big_game.tokenized_input[idx]) for idx in inter]\n",
560+
" print(f'Tokens: {tokens}, Value: {value:.3f}')"
561+
]
561562
}
562563
],
563564
"metadata": {

0 commit comments

Comments
 (0)