Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions shapiq/games/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,30 @@ def precompute(self, coalitions: np.ndarray | None = None) -> None:
self.coalition_lookup = coalitions_dict
self.precompute_flag = True

def compute(
self, coalitions: np.ndarray | None = None, *, return_normalization: bool = False
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you actually always return the normalization_value as the third thing and not normalize the values in line 440. Then this function would be more usable because it's the "raw" computation.

) -> tuple:
"""Compute the game values for all or a given set of coalitions.

Args:
coalitions: The coalitions to evaluate.
return_normalization: Whether to return the normalization value. Defaults to ``False``.

Returns:
A tuple containing:
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deleted the return types because sphinx will display automatically and then we do not have to maintain two parts of the code. :)

- np.ndarray: The values of the coalitions.
- dict[tuple[int, ...], int]: The lookup of the coalitions
- float: The normalization value (optional, if return_normalization is 'True')

"""
coalitions: np.ndarray = self._check_coalitions(coalitions)
values = self.value_function(coalitions)
game_values = values - self.normalization_value

if return_normalization:
return (game_values, self.coalition_lookup, self.normalization_value)
return game_values, self.coalition_lookup
Comment thread
mmschlk marked this conversation as resolved.
Outdated

def save_values(self, path: Path | str) -> None:
"""Saves the game values to the given path.

Expand Down
25 changes: 25 additions & 0 deletions tests/tests_games/test_base_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,28 @@ def test_exact_computer_call():
sv = game.exact_values(index=index, order=order)
assert sv.index == index
assert sv.max_order == order


def test_compute_with_and_without_normalization():
"""Tests the compute function with and without returned normalization."""
normalization_value = 1.0 # not zero

n_players = 3
game = DummyGame(n=n_players, interaction=(0, 1))

coalitions = np.array([[1, 0, 0], [0, 1, 1]])

# Make sure normalization value is added
game.normalization_value = normalization_value
assert game.normalize

# Test without returned normalization
result = game.compute(coalitions=coalitions, return_normalization=False)
assert len(result[0]) == len(coalitions)
assert len(result) == 2 # game_values and coalition_lookup

# Test with returned normalization
result = game.compute(coalitions=coalitions, return_normalization=True)
assert len(result[0]) == len(coalitions)
assert result[2] == 1.0 # normalization_value
assert len(result) == 3 # game_values, normalization_value and coalition_lookup