Skip to content

Commit ae4b975

Browse files
committed
feat: support sklearn ensure_all_finite; deprecate force_all_finite with warning
- Add a public `ensure_all_finite` parameter in `utils/_pairwise.py`; accept `force_all_finite` as a deprecated alias and map it to `ensure_all_finite` with a `DeprecationWarning`.
- Dynamically use `ensure_all_finite`/`force_all_finite` based on the sklearn function signature; fall back to leaving the argument unspecified when neither is present.
- Mirror the same handling in `_core/_core.py` for `check_array`.
- Remove an unused overload that caused a lint error.
1 parent 72aacea commit ae4b975

2 files changed

Lines changed: 75 additions & 23 deletions

File tree

src/kennard_stone/_core/_core.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from typing import Callable, Generator
1717

1818

19+
from inspect import signature
20+
1921
import numpy as np
2022
from numpy.typing import ArrayLike
2123
from sklearn.feature_selection import VarianceThreshold
@@ -195,19 +197,6 @@ def _iter_indices(
195197
yield ind_train.tolist(), ind_test.tolist()
196198

197199

198-
@overload
199-
def train_test_split(
200-
*arrays: T,
201-
test_size: Optional[Union[float, int]] = None,
202-
train_size: Optional[Union[float, int]] = None,
203-
metric: Union[
204-
Metrics, Callable[[ArrayLike, ArrayLike], np.ndarray]
205-
] = "euclidean",
206-
n_jobs: Optional[int] = None,
207-
device: Device = "cpu",
208-
) -> list[T]: ...
209-
210-
211200
def train_test_split(
212201
*arrays: T,
213202
test_size: Optional[Union[float, int]] = None,
@@ -426,14 +415,42 @@ def get_indexes(self, X: ArrayLike) -> list[array[int]]:
426415
The sorted indexes.
427416
"""
428417
# check input array
429-
X_checked: np.ndarray = check_array(
430-
X,
418+
# scikit-learn 1.6+ deprecates 'force_all_finite' and 1.8 renames to
419+
# 'ensure_all_finite'. Check the signature dynamically.
420+
check_array_sig = signature(check_array)
421+
supports_ensure_all_finite = (
422+
"ensure_all_finite" in check_array_sig.parameters
423+
)
424+
supports_force_all_finite = (
425+
"force_all_finite" in check_array_sig.parameters
426+
)
427+
428+
check_kwargs: dict[str, Any] = dict(
431429
ensure_2d=True,
432430
dtype="numeric",
433-
force_all_finite="allow-nan"
434-
if self.metric == "nan_euclidean"
435-
else True,
436431
)
432+
if supports_ensure_all_finite:
433+
check_kwargs["ensure_all_finite"] = (
434+
"allow-nan" if self.metric == "nan_euclidean" else True
435+
)
436+
elif supports_force_all_finite:
437+
check_kwargs["force_all_finite"] = (
438+
"allow-nan" if self.metric == "nan_euclidean" else True
439+
)
440+
441+
try:
442+
X_checked: np.ndarray = check_array(
443+
X,
444+
**check_kwargs,
445+
)
446+
except TypeError:
447+
# Fallback when the argument is not accepted at runtime
448+
check_kwargs.pop("ensure_all_finite", None)
449+
check_kwargs.pop("force_all_finite", None)
450+
X_checked = check_array(
451+
X,
452+
**check_kwargs,
453+
)
437454
n_samples = X_checked.shape[0]
438455

439456
# drop no variance

src/kennard_stone/utils/_pairwise.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
else:
88
from typing import Callable
99

10+
from inspect import signature
11+
1012
import numpy as np
1113
import sklearn.metrics.pairwise
1214
from numpy.typing import ArrayLike
@@ -25,7 +27,8 @@ def pairwise_distances(
2527
Metrics, Callable[[ArrayLike, ArrayLike], np.ndarray]
2628
] = "euclidean",
2729
n_jobs: Optional[int] = None,
28-
force_all_finite=True,
30+
ensure_all_finite: bool = True,
31+
force_all_finite: Optional[bool] = None,
2932
device: Device = "cpu",
3033
verbose: int = 1,
3134
**kwargs,
@@ -60,9 +63,13 @@ def pairwise_distances(
6063
down the pairwise matrix into n_jobs even slices and computing them in
6164
parallel. (Note: 'n_jobs' is not supported by PyTorch.)
6265
63-
force_all_finite : bool, default=True
66+
ensure_all_finite : bool, default=True
6467
Whether to raise an error on np.inf and np.nan in X.
6568
69+
force_all_finite : Optional[bool], default=None
70+
Deprecated alias of 'ensure_all_finite'. If provided, a warning is
71+
emitted and its value overrides 'ensure_all_finite'.
72+
6673
device : Literal['cpu', 'cuda', 'mps'] or torch.device or str
6774
, default="cpu"
6875
Device to use for calculating pairwise distances.
@@ -80,6 +87,14 @@ def pairwise_distances(
8087
else:
8188
available_torch = False
8289

90+
# Handle deprecated alias
91+
if force_all_finite is not None:
92+
warnings.warn(
93+
"'force_all_finite' is deprecated. Use 'ensure_all_finite' instead.",
94+
DeprecationWarning,
95+
)
96+
ensure_all_finite = force_all_finite
97+
8398
if available_torch:
8499
# Convert NumPy array to PyTorch tensor and move it to GPU
85100
X_tensor = torch.tensor(X, dtype=torch.float32, device=device)
@@ -109,14 +124,34 @@ def pairwise_distances(
109124
_logger.info(
110125
"Calculating pairwise distances using scikit-learn.\n"
111126
)
112-
return sklearn.metrics.pairwise.pairwise_distances(
113-
X,
127+
# scikit-learn 1.6+ deprecates 'force_all_finite' and 1.8 renames to
128+
# 'ensure_all_finite'. Dynamically use whichever is available.
129+
pd_sig = signature(sklearn.metrics.pairwise.pairwise_distances)
130+
supports_ensure_all_finite = "ensure_all_finite" in pd_sig.parameters
131+
supports_force_all_finite = "force_all_finite" in pd_sig.parameters
132+
133+
call_kwargs = dict(
114134
Y=Y,
115135
metric=metric,
116136
n_jobs=n_jobs,
117-
force_all_finite=force_all_finite,
118137
**kwargs,
119138
)
139+
if supports_ensure_all_finite:
140+
call_kwargs["ensure_all_finite"] = ensure_all_finite
141+
elif supports_force_all_finite:
142+
call_kwargs["force_all_finite"] = ensure_all_finite
143+
144+
try:
145+
return sklearn.metrics.pairwise.pairwise_distances(
146+
X, **call_kwargs
147+
)
148+
except TypeError:
149+
# Fallback for environments where the arg is rejected at runtime
150+
call_kwargs.pop("ensure_all_finite", None)
151+
call_kwargs.pop("force_all_finite", None)
152+
return sklearn.metrics.pairwise.pairwise_distances(
153+
X, **call_kwargs
154+
)
120155
else:
121156
if verbose > 0:
122157
_logger.info(

0 commit comments

Comments
 (0)