|
16 | 16 | }, |
17 | 17 | { |
18 | 18 | "cell_type": "code", |
| 19 | + "execution_count": null, |
19 | 20 | "id": "96756a5298128aed", |
20 | 21 | "metadata": { |
21 | 22 | "collapsed": false |
22 | 23 | }, |
| 24 | + "outputs": [], |
23 | 25 | "source": [ |
24 | 26 | "# Install the required packages\n", |
25 | 27 | "!pip install transformers torch" |
26 | | - ], |
27 | | - "outputs": [], |
28 | | - "execution_count": null |
| 28 | + ] |
29 | 29 | }, |
30 | 30 | { |
31 | 31 | "cell_type": "code", |
| 32 | + "execution_count": null, |
32 | 33 | "id": "233a68eadd33ade3", |
33 | 34 | "metadata": { |
34 | 35 | "collapsed": false |
35 | 36 | }, |
| 37 | + "outputs": [], |
36 | 38 | "source": [ |
37 | 39 | "# Import the required libraries\n", |
38 | | - "from transformers import pipeline\n", |
39 | 40 | "import numpy as np\n", |
| 41 | + "from transformers import pipeline\n", |
| 42 | + "\n", |
40 | 43 | "import shapiq\n", |
41 | 44 | "\n", |
42 | 45 | "shapiq.__version__" |
43 | | - ], |
44 | | - "outputs": [], |
45 | | - "execution_count": null |
| 46 | + ] |
46 | 47 | }, |
47 | 48 | { |
48 | 49 | "cell_type": "markdown", |
|
59 | 60 | }, |
60 | 61 | { |
61 | 62 | "cell_type": "code", |
| 63 | + "execution_count": null, |
62 | 64 | "id": "50f59cc77301eef0", |
63 | 65 | "metadata": { |
64 | 66 | "collapsed": false |
65 | 67 | }, |
| 68 | + "outputs": [], |
66 | 69 | "source": [ |
67 | 70 | "# Load the model and tokenizer\n", |
68 | 71 | "classifier = pipeline(task=\"sentiment-analysis\", model=\"lvwerra/distilbert-imdb\")\n", |
|
79 | 82 | "\n", |
80 | 83 | "mask_toke_id = tokenizer.mask_token_id\n", |
81 | 84 | "print(f\"Mask token id: {mask_toke_id}\")" |
82 | | - ], |
83 | | - "outputs": [], |
84 | | - "execution_count": null |
| 85 | + ] |
85 | 86 | }, |
86 | 87 | { |
87 | 88 | "cell_type": "markdown", |
|
95 | 96 | }, |
96 | 97 | { |
97 | 98 | "cell_type": "code", |
| 99 | + "execution_count": null, |
98 | 100 | "id": "c3b3b6f4193e7d73", |
99 | 101 | "metadata": { |
100 | 102 | "collapsed": false |
101 | 103 | }, |
| 104 | + "outputs": [], |
102 | 105 | "source": [ |
103 | 106 | "# Test the tokenizer\n", |
104 | 107 | "decoded_sentence = tokenizer.decode(tokenized_sentence[\"input_ids\"])\n", |
|
110 | 113 | "print(\n", |
111 | 114 | " f\"Decoded sentence: {decoded_sentence} - Tokenized input: {tokenized_input} - {len(tokenized_input)} tokens.\"\n", |
112 | 115 | ")" |
113 | | - ], |
114 | | - "outputs": [], |
115 | | - "execution_count": null |
| 116 | + ] |
116 | 117 | }, |
117 | 118 | { |
118 | 119 | "cell_type": "markdown", |
|
143 | 144 | }, |
144 | 145 | { |
145 | 146 | "cell_type": "code", |
| 147 | + "execution_count": null, |
146 | 148 | "id": "bce879ce457e9a98", |
147 | 149 | "metadata": { |
148 | 150 | "collapsed": false |
149 | 151 | }, |
| 152 | + "outputs": [], |
150 | 153 | "source": [ |
151 | 154 | "# Define the model call function\n", |
152 | 155 | "def model_call(input_texts: list[str]) -> np.ndarray[float]:\n", |
|
170 | 173 | "\n", |
171 | 174 | "# Test the model call function\n", |
172 | 175 | "print(f\"Model call: {model_call(['I love this movie!', 'I hate this movie!'])}\")" |
173 | | - ], |
174 | | - "outputs": [], |
175 | | - "execution_count": null |
| 176 | + ] |
176 | 177 | }, |
177 | 178 | { |
178 | 179 | "cell_type": "markdown", |
|
186 | 187 | }, |
187 | 188 | { |
188 | 189 | "cell_type": "code", |
| 190 | + "execution_count": null, |
189 | 191 | "id": "d176905292347ec1", |
190 | 192 | "metadata": { |
191 | 193 | "collapsed": false |
192 | 194 | }, |
| 195 | + "outputs": [], |
193 | 196 | "source": [ |
194 | 197 | "# Show coalitions\n", |
195 | 198 | "n_players = len(tokenized_sentence[\"input_ids\"]) - 2 # remove [CLS] and [SEP]\n", |
|
199 | 202 | "\n", |
200 | 203 | "print(f\"Empty coalition: {empty_coalition}\")\n", |
201 | 204 | "print(f\"Full coalition: {full_coalition}\")" |
202 | | - ], |
203 | | - "outputs": [], |
204 | | - "execution_count": null |
| 205 | + ] |
205 | 206 | }, |
206 | 207 | { |
207 | 208 | "cell_type": "markdown", |
|
219 | 220 | }, |
220 | 221 | { |
221 | 222 | "cell_type": "code", |
| 223 | + "execution_count": null, |
222 | 224 | "id": "79a5c423622a0904", |
223 | 225 | "metadata": { |
224 | 226 | "collapsed": false |
225 | 227 | }, |
| 228 | + "outputs": [], |
226 | 229 | "source": [ |
227 | 230 | "# Define the value function\n", |
228 | 231 | "def value_function(\n", |
|
253 | 256 | " normalized_sentiments = sentiments - normalization_value\n", |
254 | 257 | "\n", |
255 | 258 | " return normalized_sentiments" |
256 | | - ], |
257 | | - "outputs": [], |
258 | | - "execution_count": null |
| 259 | + ] |
259 | 260 | }, |
260 | 261 | { |
261 | 262 | "cell_type": "markdown", |
|
269 | 270 | }, |
270 | 271 | { |
271 | 272 | "cell_type": "code", |
| 273 | + "execution_count": null, |
272 | 274 | "id": "22b2201ca139c0d0", |
273 | 275 | "metadata": { |
274 | 276 | "collapsed": false |
275 | 277 | }, |
| 278 | + "outputs": [], |
276 | 279 | "source": [ |
277 | 280 | "# Test the value function without normalization\n", |
278 | 281 | "print(f\"Output of the classifier: {classifier(test_sentence)}\")\n", |
|
283 | 286 | "print(\n", |
284 | 287 | " f\"Value function for the empty coalition: {value_function(empty_coalition, tokenized_input=tokenized_input)[0]}\"\n", |
285 | 288 | ")" |
286 | | - ], |
287 | | - "outputs": [], |
288 | | - "execution_count": null |
| 289 | + ] |
289 | 290 | }, |
290 | 291 | { |
291 | 292 | "cell_type": "markdown", |
|
299 | 300 | }, |
300 | 301 | { |
301 | 302 | "cell_type": "code", |
| 303 | + "execution_count": null, |
302 | 304 | "id": "338e1ae439120652", |
303 | 305 | "metadata": { |
304 | 306 | "collapsed": false |
305 | 307 | }, |
| 308 | + "outputs": [], |
306 | 309 | "source": [ |
307 | 310 | "# Test the value function with normalization\n", |
308 | 311 | "normalization_value = float(value_function(empty_coalition, tokenized_input=tokenized_input)[0])\n", |
|
312 | 315 | "print(\n", |
313 | 316 | " f\"Value function for the empty coalition: {value_function(empty_coalition, tokenized_input=tokenized_input, normalization_value=normalization_value)[0]}\"\n", |
314 | 317 | ")" |
315 | | - ], |
316 | | - "outputs": [], |
317 | | - "execution_count": null |
| 318 | + ] |
318 | 319 | }, |
319 | 320 | { |
320 | 321 | "cell_type": "markdown", |
|
328 | 329 | }, |
329 | 330 | { |
330 | 331 | "cell_type": "code", |
| 332 | + "execution_count": null, |
331 | 333 | "id": "91e8b195226e1ecb", |
332 | 334 | "metadata": { |
333 | 335 | "collapsed": false |
334 | 336 | }, |
| 337 | + "outputs": [], |
335 | 338 | "source": [ |
336 | 339 | "# Define the game function\n", |
337 | 340 | "def game_fun(coalitions: np.ndarray[bool]) -> np.ndarray[float]:\n", |
|
351 | 354 | "# Test the game function\n", |
352 | 355 | "print(f\"Game for the full coalition: {game_fun(full_coalition)[0]}\")\n", |
353 | 356 | "print(f\"Game for the empty coalition: {game_fun(empty_coalition)[0]}\")" |
354 | | - ], |
355 | | - "outputs": [], |
356 | | - "execution_count": null |
| 357 | + ] |
357 | 358 | }, |
358 | 359 | { |
359 | 360 | "cell_type": "markdown", |
|
367 | 368 | }, |
368 | 369 | { |
369 | 370 | "cell_type": "code", |
| 371 | + "execution_count": null, |
370 | 372 | "id": "ea94eb7697abad0d", |
371 | 373 | "metadata": { |
372 | 374 | "collapsed": false |
373 | 375 | }, |
| 376 | + "outputs": [], |
374 | 377 | "source": [ |
375 | 378 | "class SentimentClassificationGame(shapiq.Game):\n", |
376 | 379 | " \"\"\"The sentiment analysis classifier modeled as a cooperative game.\n", |
|
438 | 441 | "game_class = SentimentClassificationGame(classifier, tokenizer, test_sentence)\n", |
439 | 442 | "print(f\"Game for the full coalition: {game_class(full_coalition)[0]}\")\n", |
440 | 443 | "print(f\"Game for the empty coalition: {game_class(empty_coalition)[0]}\")" |
441 | | - ], |
442 | | - "outputs": [], |
443 | | - "execution_count": null |
| 444 | + ] |
444 | 445 | }, |
445 | 446 | { |
446 | 447 | "cell_type": "markdown", |
|
455 | 456 | }, |
456 | 457 | { |
457 | 458 | "cell_type": "code", |
| 459 | + "execution_count": null, |
458 | 460 | "id": "f62adc49538c8a79", |
459 | 461 | "metadata": { |
460 | 462 | "collapsed": false |
461 | 463 | }, |
| 464 | + "outputs": [], |
462 | 465 | "source": [ |
463 | 466 | "# Compute Shapley interactions with the ShapIQ approximator for the game function\n", |
464 | 467 | "approximator = shapiq.KernelSHAPIQ(n=n_players, max_order=2, index=\"k-SII\")\n", |
465 | 468 | "sii_values = approximator.approximate(budget=2**n_players, game=game_fun)\n", |
466 | 469 | "sii_values.dict_values" |
467 | | - ], |
468 | | - "outputs": [], |
469 | | - "execution_count": null |
| 470 | + ] |
470 | 471 | }, |
471 | 472 | { |
472 | 473 | "cell_type": "code", |
| 474 | + "execution_count": null, |
473 | 475 | "id": "7641d33a850cdd16", |
474 | 476 | "metadata": { |
475 | 477 | "collapsed": false |
476 | 478 | }, |
| 479 | + "outputs": [], |
477 | 480 | "source": [ |
478 | 481 | "# Compute Shapley interactions with the ShapIQ approximator for the game object\n", |
479 | 482 | "approximator = shapiq.KernelSHAPIQ(n=game_class.n_players, max_order=2, index=\"k-SII\")\n", |
480 | 483 | "sii_values = approximator.approximate(budget=2**game_class.n_players, game=game_class)\n", |
481 | 484 | "sii_values.dict_values" |
482 | | - ], |
483 | | - "outputs": [], |
484 | | - "execution_count": null |
| 485 | + ] |
485 | 486 | }, |
486 | 487 | { |
487 | | - "metadata": {}, |
488 | 488 | "cell_type": "markdown", |
| 489 | + "id": "ef3641f671c8616b", |
| 490 | + "metadata": {}, |
489 | 491 | "source": [ |
490 | 492 | "Now let's say we want to do this for much larger inputs. We can use the `shapiq.SPEX` approximator, which is a sparse\n", |
491 | 493 | "transform approximator. This approximator is much faster than the KernelSHAPIQ approximator when the number of\n", |
492 | 494 | "players is large and can be used for larger inputs. Instead of computing all interactions, it computes only the\n", |
493 | 495 | "most important ones." |
494 | | - ], |
495 | | - "id": "ef3641f671c8616b" |
| 496 | + ] |
496 | 497 | }, |
497 | 498 | { |
498 | | - "metadata": {}, |
499 | 499 | "cell_type": "code", |
| 500 | + "execution_count": null, |
| 501 | + "id": "ce6bc4d47530fa2b", |
| 502 | + "metadata": {}, |
| 503 | + "outputs": [], |
500 | 504 | "source": [ |
501 | 505 | "text = \\\n", |
502 | 506 | "\"\"\"\n", |
|
505 | 509 | "\"\"\"\n", |
506 | 510 | "big_game = SentimentClassificationGame(classifier=classifier,\n", |
507 | 511 | " tokenizer=tokenizer,\n", |
508 | | - " test_sentence=text,)\n", |
| 512 | + " test_sentence=text)\n", |
509 | 513 | "print(f\"There are a total of {big_game.n_players} players.\")\n", |
510 | 514 | "# To speed up inference, run the pipeline with GPU support. Takes ~10 minutes on Mac M1 with MPS.\n", |
511 | 515 | "scalable_approximator = shapiq.SPEX(n=big_game.n_players, index=\"SII\")\n", |
|
514 | 518 | "print(f\"Game for the empty coalition: {game_class(empty_coalition)[0]}\")\n", |
515 | 519 | "interactions = (list(large_sii.dict_values.items()))\n", |
516 | 520 | "interactions.sort(key= lambda x : abs(x[1]), reverse=True)" |
517 | | - ], |
518 | | - "id": "ce6bc4d47530fa2b", |
519 | | - "outputs": [], |
520 | | - "execution_count": null |
| 521 | + ] |
521 | 522 | }, |
522 | 523 | { |
523 | | - "metadata": {}, |
524 | 524 | "cell_type": "markdown", |
525 | | - "source": "`shapiq.SPEX` identifies interactions between the most sentiment-rich tokens in the paragraph (i.e. *powerful*, *valuable*, *weakness*)", |
526 | | - "id": "e81434a652831817" |
| 525 | + "id": "e81434a652831817", |
| 526 | + "metadata": {}, |
| 527 | + "source": "`shapiq.SPEX` identifies interactions between the most sentiment-rich tokens in the paragraph (i.e. *powerful*, *valuable*, *weakness*)" |
527 | 528 | }, |
528 | 529 | { |
| 530 | + "cell_type": "code", |
| 531 | + "execution_count": 28, |
| 532 | + "id": "a09121f781d9be77", |
529 | 533 | "metadata": { |
530 | 534 | "ExecuteTime": { |
531 | 535 | "end_time": "2025-04-15T18:41:34.838256Z", |
532 | 536 | "start_time": "2025-04-15T18:41:34.829935Z" |
533 | 537 | } |
534 | 538 | }, |
535 | | - "cell_type": "code", |
536 | | - "source": [ |
537 | | - "for inter, value in interactions[:10]:\n", |
538 | | - " tokens = [big_game.tokenizer.decode(big_game.tokenized_input[idx]) for idx in inter]\n", |
539 | | - " print(f'Tokens: {tokens}, Value: {value:.3f}')" |
540 | | - ], |
541 | | - "id": "a09121f781d9be77", |
542 | 539 | "outputs": [ |
543 | 540 | { |
544 | 541 | "name": "stdout", |
|
557 | 554 | ] |
558 | 555 | } |
559 | 556 | ], |
560 | | - "execution_count": 28 |
| 557 | + "source": [ |
| 558 | + "for inter, value in interactions[:10]:\n", |
| 559 | + " tokens = [big_game.tokenizer.decode(big_game.tokenized_input[idx]) for idx in inter]\n", |
| 560 | + " print(f'Tokens: {tokens}, Value: {value:.3f}')" |
| 561 | + ] |
561 | 562 | } |
562 | 563 | ], |
563 | 564 | "metadata": { |
|
0 commit comments