Skip to content

Commit 002bd39

Browse files
mmschlkclaude
andauthored
Add ty linter. (#493)
* add MANIFEST.in * add MANIFEST.in * updated MANIFEST.in and python-publish.yml workflow * updated CHANGELOG.md * updated shapiq-games warning * migrate type checker from pyright to ty and fix all type errors Replace pyright with Astral's ty (Rust-based, 10-100x faster) across the entire project: update pyproject.toml dependency group, replace [tool.pyright] config with [tool.ty] sections, swap the pre-commit hook to a local ty hook, and update CLAUDE.md references. Fix all resulting type errors across 30+ source files: align override signatures (approximate, explain_function) with base class, replace pyright-specific ignore comments with ty-compatible ones, annotate ambiguous types (decoder_args, bounds_players, coalition_lookup), and add targeted type: ignore directives for unresolvable stubs (lightgbm, matplotlib internals). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 92e48ba commit 002bd39

38 files changed

Lines changed: 261 additions & 92 deletions

.pre-commit-config.yaml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ repos:
1717
args: [--fix, --exit-non-zero-on-fix, --no-cache]
1818
- id: ruff-format
1919

20-
- repo: https://github.com/RobertCraigie/pyright-python
21-
rev: v1.1.404
20+
- repo: local
2221
hooks:
23-
- id: pyright
24-
args: [--verbose, --venvpath=.]
22+
- id: ty
23+
name: ty
24+
entry: uv run ty check
25+
language: system
26+
pass_filenames: false
27+
types: [python]

CLAUDE.md

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# CLAUDE.md
2+
3+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4+
5+
## Project Overview
6+
7+
`shapiq` is a Python library for computing Shapley Interactions for Machine Learning. It approximates any-order Shapley interactions, benchmarks game-theoretical algorithms, and explains feature interactions in model predictions. The repo contains two importable packages: `shapiq` (core) and `shapiq_games` (benchmark games).
8+
9+
## Development Setup
10+
11+
This project uses `uv` for package management.
12+
13+
```sh
14+
# Install dev dependencies (test + lint + all_ml)
15+
uv sync
16+
17+
# Install only test dependencies
18+
uv sync --group test
19+
20+
# Install lint tools
21+
uv sync --group lint
22+
```
23+
24+
## Commands
25+
26+
### Testing
27+
28+
```sh
29+
# Run all shapiq unit tests
30+
uv run pytest tests/shapiq
31+
32+
# Run all shapiq_games tests (parallel)
33+
uv run pytest tests/shapiq_games -n logical
34+
35+
# Run a single test file
36+
uv run pytest tests/shapiq/tests_unit/test_interaction_values.py
37+
38+
# Run with coverage
39+
uv run pytest tests/shapiq --cov=shapiq --cov-report=xml -n logical
40+
```
41+
42+
### Linting and Code Quality
43+
44+
```sh
45+
# Run all pre-commit hooks (ruff lint + format + ty)
46+
uv run pre-commit run --all-files
47+
48+
# Run ruff linter only
49+
uv run ruff check src/ tests/ --fix
50+
51+
# Run ruff formatter only
52+
uv run ruff format src/ tests/
53+
54+
# Run type checking
55+
uv run ty check
56+
```
57+
58+
### Documentation
59+
60+
```sh
61+
uv sync --no-dev --group docs
62+
uv run sphinx-build -b html docs/source docs/build/html
63+
```
64+
65+
## Code Architecture
66+
67+
### Package Structure
68+
69+
```
70+
src/
71+
├── shapiq/ # Core package
72+
│ ├── interaction_values.py # InteractionValues data class (central output type)
73+
│ ├── game.py # Base Game class for cooperative games
74+
│ ├── approximator/ # Approximation algorithms
75+
│ │ ├── base.py # Approximator base class
76+
│ │ ├── marginals/ # Owen, Stratified sampling
77+
│ │ ├── montecarlo/ # SHAPIQ, SVARM, SVARMIQ, UnbiasedKernelSHAP
78+
│ │ ├── permutation/ # Permutation sampling for SII, STII, SV
79+
│ │ ├── regression/ # KernelSHAP, KernelSHAPIQ, RegressionFSII/FBII
80+
│ │ └── sparse/ # SPEX, ProxySPEX (for large feature spaces)
81+
│ ├── explainer/ # High-level explainer interfaces
82+
│ │ ├── tabular.py # TabularExplainer (main user-facing class)
83+
│ │ ├── tree/ # TreeExplainer with model-specific conversions
84+
│ │ └── product_kernel/ # ProductKernelExplainer
85+
│ ├── imputer/ # Imputation strategies for missing features
86+
│ │ ├── marginal_imputer.py # MarginalImputer (most common)
87+
│ │ ├── baseline_imputer.py
88+
│ │ ├── gaussian_imputer.py
89+
│ │ └── tabpfn_imputer.py
90+
│ ├── game_theory/ # Mathematical game-theory utilities
91+
│ │ ├── exact.py # ExactComputer for exact interaction values
92+
│ │ ├── indices.py # ALL_AVAILABLE_CONCEPTS index registry
93+
│ │ ├── moebius_converter.py
94+
│ │ └── aggregation.py
95+
│ ├── plot/ # Visualization functions
96+
│ └── utils/ # Shared utilities (sets, saving, typing)
97+
└── shapiq_games/ # Benchmark games package (separate from shapiq)
98+
├── benchmark/ # Pre-defined benchmark games per use-case
99+
├── synthetic/ # Synthetic game functions
100+
└── tabular/ # Tabular ML games
101+
```
102+
103+
### Core Data Flow
104+
105+
1. **Game** (`game.py`): Wraps any callable as a cooperative game. Subclasses implement `value_function(coalitions) -> np.ndarray`. Takes a boolean coalition matrix and returns scalar game values.
106+
107+
2. **Approximator** (`approximator/`): Takes a `Game` and a budget, calls `approximate(budget, game)` → returns `InteractionValues`. All approximators inherit from `Approximator` base class.
108+
109+
3. **InteractionValues** (`interaction_values.py`): Central data class storing interaction scores as a numpy array with an `interaction_lookup` dict mapping coalition tuples → array indices. Supports arithmetic operations between instances.
110+
111+
4. **Explainer** (`explainer/`): High-level interface combining an ML model + data + an `Imputer` into a `Game`, then calling an `Approximator`. `Explainer.explain(x)``InteractionValues`.
112+
113+
5. **Imputer** (`imputer/`): Converts ML model + data into a game by handling missing features. `MarginalImputer` is the default for tabular data.
114+
115+
### Interaction Indices
116+
117+
Available indices are defined in `game_theory/indices.py` (`ALL_AVAILABLE_CONCEPTS`). Key ones:
118+
- `SV` – Shapley Values (order 1 only)
119+
- `SII` – Shapley Interaction Index
120+
- `k-SII` – k-Shapley Interaction Index (most common for explanations)
121+
- `STII` – Shapley-Taylor Interaction Index
122+
- `FSII` – Faithful Shapley Interaction Index
123+
- `FBII` – Faithful Banzhaf Interaction Index
124+
- `BV` – Banzhaf Values
125+
126+
### Code Style
127+
128+
- **Formatter/Linter**: `ruff` with `black` style, line length 100, Google-style docstrings
129+
- **Type checking**: `ty` (checks `src/shapiq/`, excluded for tests)
130+
- **All files** must start with `from __future__ import annotations`
131+
- `isort` is configured with `required-imports = ["from __future__ import annotations"]`
132+
- Variable names `X` (uppercase) in functions are allowed (common in ML code)
133+
- Test files live in `tests/shapiq/` and `tests/shapiq_games/` with separate conftest files
134+
135+
### Test Organization
136+
137+
- `tests/shapiq/tests_unit/` – Unit tests per module
138+
- `tests/shapiq/tests_integration_tests/` – Integration tests
139+
- `tests/shapiq/tests_deprecation/` – Deprecation behavior tests
140+
- `tests/shapiq/fixtures/` – Shared pytest fixtures (data, games, models, interaction values)
141+
- `tests/shapiq_games/` – Tests for the `shapiq_games` package
142+
143+
### Two-Package Setup
144+
145+
The repo hosts two installable packages:
146+
- `shapiq` in `src/shapiq/` — the core library
147+
- `shapiq_games` in `src/shapiq_games/` — optional benchmark games requiring extra ML dependencies (`torch`, `transformers`, `tabpfn`)
148+
149+
`shapiq_games` requires `uv sync --group all_ml` for full functionality.

pyproject.toml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,15 @@ exclude = [
207207
[tool.ruff.lint.pydocstyle]
208208
convention = "google"
209209

210+
[tool.ty.src]
211+
include = ["src/shapiq"]
212+
exclude = ["tests", "docs", "benchmark", "scripts", "src/shapiq_games"]
213+
214+
[tool.ty.environment]
215+
python-version = "3.10"
216+
python = "./.venv"
217+
python-platform = "linux"
218+
210219
[tool.ruff.lint.isort]
211220
known-first-party = ["shapiq"]
212221
extra-standard-library = ["typing_extensions"]
@@ -215,15 +224,6 @@ force-wrap-aliases = true
215224
no-lines-before = ["future"]
216225
required-imports = ["from __future__ import annotations"]
217226

218-
[tool.pyright]
219-
include = ["src/shapiq"]
220-
exclude = ["tests", "docs", "benchmark", "scripts", "src/shapiq_games"]
221-
venv = ".venv"
222-
pythonVersion = "3.10"
223-
defineConstant = { DEBUG = true }
224-
reportMissingImports = "error"
225-
reportMissingTypeStubs = false
226-
pythonPlatform = "Linux"
227227

228228
[dependency-groups]
229229
all_ml = [
@@ -246,7 +246,7 @@ test = [
246246
lint = [
247247
"ruff>=0.14.1", # for linting
248248
"pre-commit>=4.3.0", # for running the linters pre-commit hooks
249-
"pyright>=1.1.402", # for type checking
249+
"ty>=0.0.21", # for type checking
250250
]
251251
docs = [
252252
"sphinx>=8.0.0", # documentation generator

src/shapiq/approximator/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ def approximate(
179179
self,
180180
budget: int,
181181
game: Game | Callable[[np.ndarray], np.ndarray],
182-
*args: Any,
183182
**kwargs: Any,
184183
) -> InteractionValues:
185184
"""Approximates the interaction values.

src/shapiq/approximator/montecarlo/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Literal, TypeVar, get_args
5+
from typing import TYPE_CHECKING, Any, Literal, TypeVar, get_args
66

77
import numpy as np
88
from scipy.special import binom, factorial
@@ -14,6 +14,7 @@
1414
if TYPE_CHECKING:
1515
from collections.abc import Callable
1616

17+
from shapiq.game import Game
1718
from shapiq.typing import FloatVector
1819

1920
ValidMonteCarloIndices = Literal["k-SII", "SII", "STII", "FSII", "FBII", "SV", "CHII", "BII", "BV"]
@@ -38,7 +39,7 @@ def __init__(
3839
self,
3940
n: int,
4041
max_order: int,
41-
index: TIndices = "k-SII",
42+
index: ValidMonteCarloIndices = "k-SII",
4243
*,
4344
stratify_coalition_size: bool = True,
4445
stratify_intersection: bool = True,
@@ -91,13 +92,15 @@ def __init__(
9192
def approximate(
9293
self,
9394
budget: int,
94-
game: Callable[[np.ndarray], np.ndarray],
95+
game: Game | Callable[[np.ndarray], np.ndarray],
96+
**kwargs: Any, # noqa: ARG002
9597
) -> InteractionValues:
9698
"""Approximates the Shapley interaction values using Monte Carlo sampling.
9799
98100
Args:
99101
budget: The budget for the approximation.
100102
game: The game function that returns the values for the coalitions.
103+
**kwargs: Additional keyword arguments (unused).
101104
102105
Returns:
103106
The approximated Shapley interaction values.

src/shapiq/approximator/montecarlo/shapiq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252
self,
5353
n: int,
5454
max_order: int = 2,
55-
index: TIndices = "k-SII",
55+
index: ValidMonteCarloIndices = "k-SII",
5656
*,
5757
top_order: bool = False,
5858
sampling_weights: FloatVector | None = None,

src/shapiq/approximator/permutation/sii.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING, Literal, get_args
5+
from typing import TYPE_CHECKING, Any, Literal, get_args
66

77
import numpy as np
88

@@ -13,6 +13,8 @@
1313
if TYPE_CHECKING:
1414
from collections.abc import Callable
1515

16+
from shapiq.game import Game
17+
1618
ValidPermutationSIIIndices = Literal["SII", "k-SII"]
1719

1820

@@ -99,15 +101,17 @@ def _compute_order_iterator(self) -> np.ndarray:
99101
def approximate(
100102
self,
101103
budget: int,
102-
game: Callable[[np.ndarray], np.ndarray],
104+
game: Game | Callable[[np.ndarray], np.ndarray],
103105
batch_size: int | None = 5,
106+
**kwargs: Any, # noqa: ARG002
104107
) -> InteractionValues:
105108
"""Approximates the interaction values.
106109
107110
Args:
108111
budget: The budget for the approximation.
109112
game: The game function as a callable that takes a set of players and returns the value.
110113
batch_size: The size of the batch. If ``None``, the batch size is set to ``1``. Defaults to ``5``.
114+
**kwargs: Additional keyword arguments (unused).
111115
112116
Returns:
113117
InteractionValues: The estimated interaction values.

src/shapiq/approximator/permutation/stii.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
if TYPE_CHECKING:
1616
from collections.abc import Callable
1717

18+
from shapiq.game import Game
19+
1820
ValidPermutationSTIIIndices = Literal["STII"]
1921

2022

@@ -89,9 +91,8 @@ def __init__(
8991
def approximate(
9092
self,
9193
budget: int,
92-
game: Callable[[np.ndarray], np.ndarray],
94+
game: Game | Callable[[np.ndarray], np.ndarray],
9395
batch_size: int = 1,
94-
*args: Any, # noqa: ARG002
9596
**kwargs: Any, # noqa: ARG002
9697
) -> InteractionValues:
9798
"""Approximates the interaction values.

src/shapiq/approximator/permutation/sv.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def approximate(
6767
budget: int,
6868
game: Game | Callable[[np.ndarray], np.ndarray],
6969
batch_size: int | None = 5,
70-
*args: Any, # noqa: ARG002
7170
**kwargs: Any, # noqa: ARG002
7271
) -> InteractionValues:
7372
"""Approximates the Shapley values using ApproShapley.

src/shapiq/approximator/regression/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
if TYPE_CHECKING:
1717
from collections.abc import Callable
1818

19+
from shapiq.game import Game
1920
from shapiq.typing import CoalitionMatrix, FloatVector
2021

2122

@@ -123,7 +124,7 @@ def _init_kernel_weights(self, interaction_size: int) -> FloatVector:
123124
def approximate(
124125
self,
125126
budget: int,
126-
game: Callable[[np.ndarray], np.ndarray],
127+
game: Game | Callable[[np.ndarray], np.ndarray],
127128
*args: Any | None, # noqa: ARG002
128129
**kwargs: Any, # noqa: ARG002
129130
) -> InteractionValues:

0 commit comments

Comments
 (0)