Skip to content

Commit 62c10dd

Browse files
authored
removes for loop improving runtime of baseline imputer (#498)
1 parent af95153 commit 62c10dd

2 files changed

Lines changed: 22 additions & 4 deletions

File tree

src/shapiq/imputer/baseline_imputer.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,7 @@ def value_function(self, coalitions: np.ndarray) -> np.ndarray:
105105
``(n_subsets, n_outputs)``.
106106
107107
"""
108-
n_coalitions = coalitions.shape[0]
109-
data = np.tile(np.copy(self.x), (n_coalitions, 1))
110-
for i in range(n_coalitions):
111-
data[i, ~coalitions[i]] = self.baseline_values[0, ~coalitions[i]]
108+
data = np.where(coalitions, self.x, self.baseline_values)
112109
return self.predict(data)
113110

114111
def init_background(self, data: np.ndarray) -> BaselineImputer:

tests/shapiq/tests_unit/tests_imputer/test_baseline_imputer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,27 @@ def model_cat(x: np.ndarray) -> np.ndarray:
9393
assert imputer.baseline_values[0, 2] == np.mean(data[:, 2])
9494

9595

96+
def test_baseline_value_function_imputes_correctly():
97+
"""Test that value_function substitutes baseline values for absent features."""
98+
captured = []
99+
100+
def model(x: np.ndarray) -> np.ndarray:
101+
captured.append(x.copy())
102+
return np.zeros(x.shape[0])
103+
104+
x = np.array([[1, 2, 3]])
105+
baseline = np.array([[10, 20, 30]])
106+
imputer = BaselineImputer(model=model, data=baseline, x=x)
107+
108+
coalitions = np.array([[True, False, True], [False, True, False]])
109+
imputer(coalitions)
110+
111+
assert len(captured) == 1
112+
result = captured[0]
113+
np.testing.assert_array_equal(result[0], [1, 20, 3])
114+
np.testing.assert_array_equal(result[1], [10, 2, 30])
115+
116+
96117
def test_baseline_imputer_init():
97118
"""Test the initialization of the marginal imputer."""
98119

0 commit comments

Comments
 (0)