55import contextlib
66import copy
77import json
8- import pickle
98from pathlib import Path
109from typing import TYPE_CHECKING
1110from warnings import warn
1918 is_index_aggregated ,
2019 is_index_valid ,
2120)
22- from .utils .errors import raise_deprecation_warning
23- from .utils .saving import safe_str_to_tuple , safe_tuple_to_str
2421from .utils .sets import generate_interaction_lookup
2522
2623if TYPE_CHECKING :
3229
3330 from shapiq .typing import InteractionScores , JSONType
3431
35- SAVE_JSON_DEPRECATION_MSG = (
36- "Saving InteractionValues not as a JSON file is deprecated. "
37- "The parameters `as_pickle` and `as_npz` will be removed in the future. "
38- )
39-
4032
4133class InteractionValues :
4234 """This class contains the interaction values as estimated by an approximator.
@@ -773,61 +765,18 @@ def get_subset(self, players: list[int]) -> InteractionValues:
773765 baseline_value = self .baseline_value ,
774766 )
775767
776- def save (self , path : Path , * , as_pickle : bool = False , as_npz : bool = False ) -> None :
777- """Save the InteractionValues object to a file.
778-
779- By default, the InteractionValues object is saved as a JSON file.
768+ def save (self , path : Path ) -> None :
769+ """Save the InteractionValues object to a JSON file.
780770
781771 Args:
782772 path: The path to save the InteractionValues object to.
783- as_pickle: Whether to save the InteractionValues object as a pickle file (``True``).
784- as_npz: Whether to save the InteractionValues object as a ``npz`` file (``True``).
785-
786- Raises:
787- DeprecationWarning: If `as_pickle` or `as_npz` is set to ``True``, a deprecation
788- warning is raised
789773 """
790774 # check if the directory exists
791775 directory = Path (path ).parent
792776 if not Path (directory ).exists ():
793777 with contextlib .suppress (FileNotFoundError ):
794778 Path (directory ).mkdir (parents = True , exist_ok = True )
795- if as_pickle :
796- raise_deprecation_warning (
797- message = SAVE_JSON_DEPRECATION_MSG ,
798- deprecated_in = "1.3.1" ,
799- removed_in = "1.4.0" ,
800- )
801- with Path (path ).open ("wb" ) as file :
802- pickle .dump (self , file )
803- elif as_npz :
804- raise_deprecation_warning (
805- message = SAVE_JSON_DEPRECATION_MSG ,
806- deprecated_in = "1.3.1" ,
807- removed_in = "1.4.0" ,
808- )
809- # save object as npz file
810- interaction_keys = np .array (
811- list (map (safe_tuple_to_str , self .interaction_lookup .keys ()))
812- )
813- interaction_indices = np .array (list (self .interaction_lookup .values ()))
814- estimation_budget = self .estimation_budget if self .estimation_budget is not None else - 1
815-
816- np .savez (
817- path ,
818- values = self .values ,
819- index = self .index ,
820- max_order = self .max_order ,
821- n_players = self .n_players ,
822- min_order = self .min_order ,
823- interaction_lookup_keys = interaction_keys ,
824- interaction_lookup_indices = interaction_indices ,
825- estimated = self .estimated ,
826- estimation_budget = estimation_budget ,
827- baseline_value = self .baseline_value ,
828- )
829- else :
830- self .to_json_file (path )
779+ self .to_json_file (path )
831780
832781 @classmethod
833782 def load (cls , path : Path | str ) -> InteractionValues :
@@ -841,46 +790,10 @@ def load(cls, path: Path | str) -> InteractionValues:
841790
842791 """
843792 path = Path (path )
844- # check if path ends with .json
845- if path .name .endswith (".json" ):
846- return cls .from_json_file (path )
847-
848- raise_deprecation_warning (
849- SAVE_JSON_DEPRECATION_MSG , deprecated_in = "1.3.1" , removed_in = "1.4.0"
850- )
851-
852- # try loading as npz file
853- if path .name .endswith (".npz" ):
854- data = np .load (path , allow_pickle = True )
855- try :
856- # try to load Pyright save format
857- interaction_lookup = {
858- safe_str_to_tuple (key ): int (value )
859- for key , value in zip (
860- data ["interaction_lookup_keys" ],
861- data ["interaction_lookup_indices" ],
862- strict = False ,
863- )
864- }
865- except KeyError :
866- # fallback to old format
867- interaction_lookup = data ["interaction_lookup" ].item ()
868- estimation_budget = data ["estimation_budget" ].item ()
869- if estimation_budget == - 1 :
870- estimation_budget = None
871- return InteractionValues (
872- values = data ["values" ],
873- index = str (data ["index" ]),
874- max_order = int (data ["max_order" ]),
875- n_players = int (data ["n_players" ]),
876- min_order = int (data ["min_order" ]),
877- interaction_lookup = interaction_lookup ,
878- estimated = bool (data ["estimated" ]),
879- estimation_budget = estimation_budget ,
880- baseline_value = float (data ["baseline_value" ]),
881- )
882- msg = f"Path { path } does not end with .json or .npz. Cannot load InteractionValues."
883- raise ValueError (msg )
793+ if not path .name .endswith (".json" ):
794+ msg = f"Path { path } does not end with .json. Cannot load InteractionValues."
795+ raise ValueError (msg )
796+ return cls .from_json_file (path )
884797
885798 @classmethod
886799 def from_dict (cls , data : dict [str , Any ]) -> InteractionValues :
0 commit comments