77else :
88 from typing import Callable
99
10+ from inspect import signature
11+
1012import numpy as np
1113import sklearn .metrics .pairwise
1214from numpy .typing import ArrayLike
@@ -25,7 +27,8 @@ def pairwise_distances(
2527 Metrics , Callable [[ArrayLike , ArrayLike ], np .ndarray ]
2628 ] = "euclidean" ,
2729 n_jobs : Optional [int ] = None ,
28- force_all_finite = True ,
30+ ensure_all_finite : bool = True ,
31+ force_all_finite : Optional [bool ] = None ,
2932 device : Device = "cpu" ,
3033 verbose : int = 1 ,
3134 ** kwargs ,
@@ -60,9 +63,13 @@ def pairwise_distances(
6063 down the pairwise matrix into n_jobs even slices and computing them in
6164 parallel. (Note: 'n_jobs' is not supported by PyTorch.)
6265
63- force_all_finite : bool, default=True
66+ ensure_all_finite : bool, default=True
6467 Whether to raise an error on np.inf and np.nan in X.
6568
69+ force_all_finite : Optional[bool], default=None
70+ Deprecated alias of 'ensure_all_finite'. If provided, a warning is
71+ emitted and its value overrides 'ensure_all_finite'.
72+
6673 device : Literal['cpu', 'cuda', 'mps'] or torch.device or str
6774 , default="cpu"
6875 Device to use for calculating pairwise distances.
@@ -80,6 +87,14 @@ def pairwise_distances(
8087 else :
8188 available_torch = False
8289
90+ # Handle deprecated alias
91+ if force_all_finite is not None :
92+ warnings .warn (
93+ "'force_all_finite' is deprecated. Use 'ensure_all_finite' instead." ,
94+ DeprecationWarning ,
95+ )
96+ ensure_all_finite = force_all_finite
97+
8398 if available_torch :
8499 # Convert NumPy array to PyTorch tensor and move it to GPU
85100 X_tensor = torch .tensor (X , dtype = torch .float32 , device = device )
@@ -109,14 +124,34 @@ def pairwise_distances(
109124 _logger .info (
110125 "Calculating pairwise distances using scikit-learn.\n "
111126 )
112- return sklearn .metrics .pairwise .pairwise_distances (
113- X ,
127+ # scikit-learn 1.6+ deprecates 'force_all_finite' and 1.8 renames to
128+ # 'ensure_all_finite'. Dynamically use whichever is available.
129+ pd_sig = signature (sklearn .metrics .pairwise .pairwise_distances )
130+ supports_ensure_all_finite = "ensure_all_finite" in pd_sig .parameters
131+ supports_force_all_finite = "force_all_finite" in pd_sig .parameters
132+
133+ call_kwargs = dict (
114134 Y = Y ,
115135 metric = metric ,
116136 n_jobs = n_jobs ,
117- force_all_finite = force_all_finite ,
118137 ** kwargs ,
119138 )
139+ if supports_ensure_all_finite :
140+ call_kwargs ["ensure_all_finite" ] = ensure_all_finite
141+ elif supports_force_all_finite :
142+ call_kwargs ["force_all_finite" ] = ensure_all_finite
143+
144+ try :
145+ return sklearn .metrics .pairwise .pairwise_distances (
146+ X , ** call_kwargs
147+ )
148+ except TypeError :
149+ # Fallback for environments where the arg is rejected at runtime
150+ call_kwargs .pop ("ensure_all_finite" , None )
151+ call_kwargs .pop ("force_all_finite" , None )
152+ return sklearn .metrics .pairwise .pairwise_distances (
153+ X , ** call_kwargs
154+ )
120155 else :
121156 if verbose > 0 :
122157 _logger .info (
0 commit comments