|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import copy |
| 4 | +from collections.abc import Callable |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +from sparse_transform.qsft.qsft import transform as sparse_fourier_transform |
| 8 | +from sparse_transform.qsft.signals.input_signal_subsampled import ( |
| 9 | + SubsampledSignal as SubsampledSignalFourier, |
| 10 | +) |
| 11 | +from sparse_transform.qsft.utils.general import fourier_to_mobius as fourier_to_moebius |
| 12 | +from sparse_transform.qsft.utils.query import get_bch_decoder |
| 13 | + |
| 14 | +from ...game_theory.indices import is_index_aggregated |
| 15 | +from ...game_theory.moebius_converter import MoebiusConverter |
| 16 | +from ...interaction_values import InteractionValues |
| 17 | +from .._base import Approximator |
| 18 | + |
| 19 | + |
class Sparse(Approximator):
    """Approximator for interaction values using sparse transformation techniques.

    This class implements a sparse approximation method for computing various interaction indices
    using sparse Fourier transforms. It efficiently estimates interaction values with a limited
    sample budget by leveraging sparsity in the Fourier domain.

    Attributes:
        transform_type (str): Type of transform used (currently only "fourier" is supported).
        t (int): Error parameter for the sparse Fourier transform. Starts at 5 and may be
            lowered by ``_set_transform_budget`` when the sample budget is too small.
        decoder_type (str): The decoder in use, either "soft" or "hard".
        query_args (dict): Parameters for querying the signal.
        decoder_args (dict): Parameters for decoding the transform.

    Args:
        n (int): Number of players/features.
        index (str): Type of interaction index to compute (e.g., "STII", "FBII", "FSII").
        max_order (int, optional): Maximum interaction order to compute. It is not suggested to
            use this parameter since sparse approximation dynamically and implicitly adjusts the
            order based on the budget and function. Defaults to ``n`` when None.
        top_order (bool, optional): If True, only compute interactions of exactly max_order.
            If False, compute interactions up to max_order. Defaults to False.
        random_state (int, optional): Random seed for reproducibility. Defaults to None.
        transform_type (str, optional): Type of transform to use. Currently only "fourier"
            is supported. Defaults to "fourier".
        decoder_type (str, optional): Type of decoder to use, either "soft" or "hard".
            Defaults to "soft". Passing None explicitly selects "hard".

    Raises:
        ValueError: If transform_type is not "fourier" or if decoder_type is not "soft" or "hard".
    """

    def __init__(
        self,
        n: int,
        index: str,
        max_order: int | None = None,
        top_order: bool = False,
        random_state: int | None = None,
        transform_type: str = "fourier",
        decoder_type: str | None = "soft",
    ) -> None:
        if transform_type.lower() not in ["fourier"]:
            msg = "transform_type must be 'fourier'"
            raise ValueError(msg)
        self.transform_type = transform_type.lower()
        self.t = 5  # 5 could be a parameter
        # An explicit None falls back to the hard decoder.
        self.decoder_type = "hard" if decoder_type is None else decoder_type.lower()
        if self.decoder_type not in ["soft", "hard"]:
            msg = "decoder_type must be 'soft' or 'hard'"
            raise ValueError(msg)
        # The sampling parameters for the Fourier transform
        self.query_args = {
            "query_method": "complex",
            "num_subsample": 3,
            "delays_method_source": "joint-coded",
            "subsampling_method": "qsft",
            "delays_method_channel": "identity-siso",
            "num_repeat": 1,
            "t": self.t,
        }
        # The decoding parameters; the channel reconstruction depends on the decoder type
        self.decoder_args = {
            "num_subsample": 3,
            "num_repeat": 1,
            "reconstruct_method_source": "coded",
            "peeling_method": "multi-detect",
            "reconstruct_method_channel": "identity-siso"
            if self.decoder_type == "soft"
            else "identity",
            "regress": "lasso",
            "res_energy_cutoff": 0.9,
            "source_decoder": get_bch_decoder(n, self.t, self.decoder_type),
        }
        super().__init__(
            n=n,
            max_order=n if max_order is None else max_order,
            index=index,
            top_order=top_order,
            random_state=random_state,
            initialize_dict=False,  # Important for performance
        )

    def approximate(
        self,
        budget: int,
        game: Callable[[np.ndarray], np.ndarray],
    ) -> InteractionValues:
        """Approximates the interaction values using a sparse transform approach.

        Args:
            budget: The budget for the approximation.
            game: The game function that returns the values for the coalitions.

        Returns:
            The approximated Shapley interaction values.

        Raises:
            ValueError: If the budget is too low to compute the transform.
        """
        # Find the maximum value of b that fits within the given sample budget and get the
        # used budget
        used_budget = self._set_transform_budget(budget)
        signal = SubsampledSignalFourier(
            # The game is queried with boolean coalition arrays
            func=lambda inputs: game(inputs.astype(bool)),
            n=self.n,
            q=2,
            query_args=self.query_args,
        )
        # Extract the coefficients of the original transform, re-keyed by the tuple of
        # non-zero positions and reduced to their real part
        initial_transform = {
            tuple(np.nonzero(key)[0]): np.real(value)
            for key, value in sparse_fourier_transform(signal, **self.decoder_args).items()
        }
        # If we are using the fourier transform, we need to convert it to a Moebius transform
        moebius_transform = fourier_to_moebius(initial_transform)
        # Convert the Moebius transform to the desired index
        result = self._process_moebius(moebius_transform=moebius_transform)
        # Filter the output as needed (this also replaces the interaction lookup)
        if self.top_order:
            result = self._filter_order(result)
        # The lookup maps interactions to positions in ``result``; the baseline is the value
        # stored at the position of the empty interaction, not the position itself.
        empty_position = self.interaction_lookup.get(())
        baseline_value = 0.0 if empty_position is None else float(result[empty_position])
        output = InteractionValues(
            values=result,
            index=self.approximation_index,
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            interaction_lookup=copy.deepcopy(self.interaction_lookup),
            estimated=True,
            estimation_budget=used_budget,
            baseline_value=baseline_value,
        )
        # Aggregated indices need an additional aggregation step on the final values
        if is_index_aggregated(self.index):
            output = self.aggregate_interaction_values(output)
        return output

    def _filter_order(self, result: np.ndarray) -> np.ndarray:
        """Filters the interactions to keep only those of the maximum order.

        This method is used when top_order=True to filter out all interactions that are not
        of exactly the maximum order (self.max_order).

        Args:
            result: Array of interaction values, positioned according to the current
                interaction lookup.

        Returns:
            Filtered array containing only interaction values of the maximum order.
            The method also updates the internal _interaction_lookup dictionary.
        """
        kept_lookup: dict[tuple, int] = {}
        kept_values = []
        for position, interaction in enumerate(self.interaction_lookup):
            if len(interaction) == self.max_order:
                # The new position is the number of interactions kept so far
                kept_lookup[interaction] = len(kept_lookup)
                kept_values.append(result[position])
        self._interaction_lookup = kept_lookup
        return np.array(kept_values)

    def _process_moebius(self, moebius_transform: dict[tuple, float]) -> np.ndarray:
        """Processes the Moebius transform to extract the desired index.

        Args:
            moebius_transform: The Moebius transform to process (dict mapping tuples to float
                values).

        Returns:
            np.ndarray: The converted interaction values based on the specified index.
            The function also updates the internal _interaction_lookup dictionary.
        """
        # Values and lookup are both derived from the dict's iteration order, so they line up
        moebius_interactions = InteractionValues(
            values=np.array(list(moebius_transform.values())),
            index="Moebius",
            min_order=self.min_order,
            max_order=self.max_order,
            n_players=self.n,
            interaction_lookup={key: i for i, key in enumerate(moebius_transform)},
            estimated=True,
            baseline_value=moebius_transform.get((), 0.0),
        )
        autoconverter = MoebiusConverter(moebius_coefficients=moebius_interactions)
        converted_interaction_values = autoconverter(index=self.index, order=self.max_order)
        self._interaction_lookup = converted_interaction_values.interaction_lookup
        return converted_interaction_values.values

    def _set_transform_budget(self, budget: int) -> int:
        """Sets the appropriate transform budget parameters based on the given sample budget.

        This method calculates the maximum possible 'b' parameter (number of bits to subsample)
        that fits within the provided budget, then configures the query and decoder arguments
        accordingly. If 'b' is not larger than 2, the error parameter 't' is lowered step by
        step (down to 2 at most) until 'b' becomes sufficient. The actual number of samples
        that will be used is returned.

        Args:
            budget: The maximum number of samples allowed for the approximation.

        Returns:
            int: The number of samples reported by the signal class for the chosen 'b' and 't'.

        Raises:
            ValueError: If the budget is too low to compute the transform with acceptable
                parameters. In that case 't' is restored to the value it had on entry so the
                query and decoder settings stay consistent for later calls.
        """
        starting_t = self.t
        b = SubsampledSignalFourier.get_b_for_sample_budget(
            budget, self.n, self.t, 2, self.query_args
        )
        used_budget = SubsampledSignalFourier.get_number_of_samples(
            self.n, b, self.t, 2, self.query_args
        )

        # Lower 't' one step at a time while 'b' is too small
        while b <= 2 and self.t > 2:
            self.t -= 1
            self.query_args["t"] = self.t

            # Recalculate 'b' and the used budget with the updated 't'
            b = SubsampledSignalFourier.get_b_for_sample_budget(
                budget, self.n, self.t, 2, self.query_args
            )
            used_budget = SubsampledSignalFourier.get_number_of_samples(
                self.n, b, self.t, 2, self.query_args
            )

        # If 'b' is still too low, undo the changes to 't' and raise an error
        if b <= 2:
            self.t = starting_t
            self.query_args["t"] = starting_t
            msg = "Insufficient budget to compute the transform. Increase the budget or use a different approximator."
            raise ValueError(msg)

        # The source decoder depends on 't', so rebuild it only when 't' was lowered
        if self.t != starting_t:
            self.decoder_args["source_decoder"] = get_bch_decoder(
                self.n, self.t, self.decoder_type
            )

        # Store the final 'b' value
        self.query_args["b"] = b
        self.decoder_args["b"] = b
        return used_budget
0 commit comments