Skip to content

Commit 0a354bc

Browse files
gaogaotiantian authored and zhengruifeng committed
[SPARK-55076][PYTHON] Fix the type hint issue in ml/mllib and add scipy requirement
### What changes were proposed in this pull request? * Pin scipy version to >=1.8.0, which is the first minor version to support 3.10 * Install scipy on lint image so we can find scipy related lint failures * Add `sparray` as that's the preferred type for scipy now * Expand `VectorLike` to include other vector like types to simplify our code * Replace some `type(x)` check with `isinstance()` because that's the recommended way and mypy understands it * Fix a few `numpy` 1 vs 2 related type hints so they can pass with both versions * Add a few assertions to make mypy happy about attributes ### Why are the changes needed? Currently, local `mypy` check will fail with a lot of failures due to scipy/numpy because our lint image does not include those stubs. This is bad because it's really hard for people to do mypy check locally - they'll think that their environment setup has issues so `mypy` result is not to be trusted. We want to make `mypy` result consistent between CI and local and make it clean. ### Does this PR introduce _any_ user-facing change? It should not. Almost all changes are type annotation related. ### How was this patch tested? CI should pass. ### Was this patch authored or co-authored using generative AI tooling? No Closes #53841 from gaogaotiantian/fix-ml-typehint. Authored-by: Tian Gao <[email protected]> Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 69f1d5c commit 0a354bc

File tree

10 files changed

+74
-60
lines changed

10 files changed

+74
-60
lines changed

dev/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ numpy>=1.22
66
pyarrow>=18.0.0
77
six==1.16.0
88
pandas>=2.2.0
9-
scipy
9+
scipy>=1.8.0
1010
plotly<6.0.0
1111
mlflow>=2.3.1
1212
scikit-learn

dev/spark-test-image/lint/Dockerfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ RUN python3.11 -m pip install \
9999
'pyarrow>=22.0.0' \
100100
'pytest-mypy-plugins==1.9.3' \
101101
'pytest==7.1.3' \
102+
'scipy>=1.8.0' \
103+
'scipy-stubs' \
102104
&& python3.11 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu \
103105
&& python3.11 -m pip install torcheval \
104106
&& python3.11 -m pip cache purge

python/pyspark/ml/_typing.pyi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,20 @@
1616
# specific language governing permissions and limitations
1717
# under the License.
1818

19-
from typing import Any, Dict, List, TypeVar, Tuple, Union
19+
from typing import Any, Dict, List, TYPE_CHECKING, TypeVar, Tuple, Union
2020
from typing_extensions import Literal
2121

2222
from numpy import ndarray
2323
from py4j.java_gateway import JavaObject
2424

2525
import pyspark.ml.base
2626
import pyspark.ml.param
27-
import pyspark.ml.util
2827
from pyspark.ml.linalg import Vector
2928
import pyspark.ml.wrapper
3029

30+
if TYPE_CHECKING:
31+
from scipy.sparse import spmatrix, sparray
32+
3133
ParamMap = Dict[pyspark.ml.param.Param, Any]
3234
PipelineStage = Union[pyspark.ml.base.Estimator, pyspark.ml.base.Transformer]
3335

@@ -81,4 +83,4 @@ RankingEvaluatorMetricType = Union[
8183
Literal["recallAtK"],
8284
]
8385

84-
VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]]
86+
VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...], "spmatrix", "sparray", range]

python/pyspark/ml/linalg/__init__.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070
if TYPE_CHECKING:
7171
from pyspark.mllib._typing import NormType
7272
from pyspark.ml._typing import VectorLike
73-
from scipy.sparse import spmatrix
7473

7574

7675
# Check whether we have SciPy. MLlib works without it too, but if we have it, some methods,
@@ -85,23 +84,25 @@
8584
_have_scipy = False
8685

8786

88-
def _convert_to_vector(d: Union["VectorLike", "spmatrix", range]) -> "Vector":
87+
def _convert_to_vector(d: "VectorLike") -> "Vector":
8988
if isinstance(d, Vector):
9089
return d
91-
elif type(d) in (array.array, np.array, np.ndarray, list, tuple, range):
90+
elif isinstance(d, (array.array, np.ndarray, list, tuple, range)):
9291
return DenseVector(d)
9392
elif _have_scipy and scipy.sparse.issparse(d):
94-
assert cast("spmatrix", d).shape[1] == 1, "Expected column vector"
93+
assert hasattr(d, "shape")
94+
assert d.shape[1] == 1, "Expected column vector"
9595
# Make sure the converted csc_matrix has sorted indices.
96-
csc = cast("spmatrix", d).tocsc()
96+
assert hasattr(d, "tocsc")
97+
csc = d.tocsc()
9798
if not csc.has_sorted_indices:
9899
csc.sort_indices()
99-
return SparseVector(cast("spmatrix", d).shape[0], csc.indices, csc.data)
100+
return SparseVector(d.shape[0], csc.indices, csc.data)
100101
else:
101102
raise TypeError("Cannot convert type %s into Vector" % type(d))
102103

103104

104-
def _vector_size(v: Union["VectorLike", "spmatrix", range]) -> int:
105+
def _vector_size(v: "VectorLike") -> int:
105106
"""
106107
Returns the size of the vector.
107108
@@ -124,16 +125,17 @@ def _vector_size(v: Union["VectorLike", "spmatrix", range]) -> int:
124125
"""
125126
if isinstance(v, Vector):
126127
return len(v)
127-
elif type(v) in (array.array, list, tuple, range):
128+
elif isinstance(v, (array.array, list, tuple, range)):
128129
return len(v)
129-
elif type(v) == np.ndarray:
130+
elif isinstance(v, np.ndarray):
130131
if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1):
131132
return len(v)
132133
else:
133134
raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape))
134135
elif _have_scipy and scipy.sparse.issparse(v):
135-
assert cast("spmatrix", v).shape[1] == 1, "Expected column vector"
136-
return cast("spmatrix", v).shape[0]
136+
assert hasattr(v, "shape")
137+
assert v.shape[1] == 1, "Expected column vector"
138+
return v.shape[0]
137139
else:
138140
raise TypeError("Cannot treat type %s as a vector" % type(v))
139141

@@ -337,13 +339,13 @@ def __init__(self, ar: Union[bytes, np.ndarray, Iterable[float]]):
337339
def __reduce__(self) -> Tuple[Type["DenseVector"], Tuple[bytes]]:
338340
return DenseVector, (self.array.tobytes(),)
339341

340-
def numNonzeros(self) -> int:
342+
def numNonzeros(self) -> Union[int, np.intp]:
341343
"""
342344
Number of nonzero elements. This scans all active values and count non zeros
343345
"""
344346
return np.count_nonzero(self.array)
345347

346-
def norm(self, p: "NormType") -> np.float64:
348+
def norm(self, p: "NormType") -> np.floating[Any]:
347349
"""
348350
Calculates the norm of a DenseVector.
349351
@@ -386,15 +388,17 @@ def dot(self, other: Iterable[float]) -> np.float64:
386388
...
387389
AssertionError: dimension mismatch
388390
"""
389-
if type(other) == np.ndarray:
391+
if isinstance(other, np.ndarray):
390392
if other.ndim > 1:
391393
assert len(self) == other.shape[0], "dimension mismatch"
392394
return np.dot(self.array, other)
393395
elif _have_scipy and scipy.sparse.issparse(other):
394-
assert len(self) == cast("spmatrix", other).shape[0], "dimension mismatch"
395-
return cast("spmatrix", other).transpose().dot(self.toArray())
396+
assert hasattr(other, "shape")
397+
assert len(self) == other.shape[0], "dimension mismatch"
398+
assert hasattr(other, "transpose")
399+
return other.transpose().dot(self.toArray())
396400
else:
397-
assert len(self) == _vector_size(other), "dimension mismatch"
401+
assert len(self) == _vector_size(other), "dimension mismatch" # type: ignore[arg-type]
398402
if isinstance(other, SparseVector):
399403
return other.dot(self)
400404
elif isinstance(other, Vector):
@@ -429,10 +433,11 @@ def squared_distance(self, other: Iterable[float]) -> np.float64:
429433
...
430434
AssertionError: dimension mismatch
431435
"""
432-
assert len(self) == _vector_size(other), "dimension mismatch"
436+
assert len(self) == _vector_size(other), "dimension mismatch" # type: ignore[arg-type]
433437
if isinstance(other, SparseVector):
434438
return other.squared_distance(self)
435439
elif _have_scipy and scipy.sparse.issparse(other):
440+
assert isinstance(other, scipy.sparse.spmatrix), "other must be a scipy.sparse.spmatrix"
436441
return _convert_to_vector(other).squared_distance(self) # type: ignore[attr-defined]
437442

438443
if isinstance(other, Vector):
@@ -636,13 +641,13 @@ def __init__(
636641
)
637642
assert np.min(self.indices) >= 0, "Contains negative index %d" % (np.min(self.indices))
638643

639-
def numNonzeros(self) -> int:
644+
def numNonzeros(self) -> Union[int, np.intp]:
640645
"""
641646
Number of nonzero elements. This scans all active values and count non zeros.
642647
"""
643648
return np.count_nonzero(self.values)
644649

645-
def norm(self, p: "NormType") -> np.float64:
650+
def norm(self, p: "NormType") -> np.floating[Any]:
646651
"""
647652
Calculates the norm of a SparseVector.
648653
@@ -699,7 +704,7 @@ def dot(self, other: Iterable[float]) -> np.float64:
699704
assert len(self) == other.shape[0], "dimension mismatch"
700705
return np.dot(self.values, other[self.indices])
701706

702-
assert len(self) == _vector_size(other), "dimension mismatch"
707+
assert len(self) == _vector_size(other), "dimension mismatch" # type: ignore[arg-type]
703708

704709
if isinstance(other, DenseVector):
705710
return np.dot(other.array[self.indices], self.values)
@@ -717,7 +722,7 @@ def dot(self, other: Iterable[float]) -> np.float64:
717722
else:
718723
return self.dot(_convert_to_vector(other)) # type: ignore[arg-type]
719724

720-
def squared_distance(self, other: Iterable[float]) -> np.float64:
725+
def squared_distance(self, other: "VectorLike") -> np.float64:
721726
"""
722727
Squared distance from a SparseVector or 1-dimensional NumPy array.
723728
@@ -785,7 +790,7 @@ def squared_distance(self, other: Iterable[float]) -> np.float64:
785790
j += 1
786791
return result
787792
else:
788-
return self.squared_distance(_convert_to_vector(other)) # type: ignore[arg-type]
793+
return self.squared_distance(_convert_to_vector(other))
789794

790795
def toArray(self) -> np.ndarray:
791796
"""

python/pyspark/mllib/_typing.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,22 @@
1616
# specific language governing permissions and limitations
1717
# under the License.
1818

19-
from typing import List, Tuple, TypeVar, Union
19+
from typing import List, Tuple, TYPE_CHECKING, TypeVar, Union
2020

2121
from typing_extensions import Literal
2222
from numpy import ndarray # noqa: F401
2323
from py4j.java_gateway import JavaObject
2424

2525
from pyspark.mllib.linalg import Vector
2626

27-
VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]]
27+
if TYPE_CHECKING:
28+
from scipy.sparse import spmatrix, sparray
29+
2830
C = TypeVar("C", bound=type)
2931
JavaObjectOrPickleDump = Union[JavaObject, bytearray, bytes]
3032

3133
CorrMethodType = Union[Literal["spearman"], Literal["pearson"]]
3234
KolmogorovSmirnovTestDistNameType = Literal["norm"]
3335
NormType = Union[None, float, Literal["fro"], Literal["nuc"]]
36+
37+
VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...], "spmatrix", "sparray", range]

python/pyspark/mllib/linalg/__init__.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161

6262
if TYPE_CHECKING:
6363
from pyspark.mllib._typing import VectorLike, NormType
64-
from scipy.sparse import spmatrix
6564
from numpy.typing import ArrayLike
6665

6766

@@ -94,23 +93,25 @@
9493
_have_scipy = False
9594

9695

97-
def _convert_to_vector(d: Union["VectorLike", "spmatrix", range]) -> "Vector":
96+
def _convert_to_vector(d: "VectorLike") -> "Vector":
9897
if isinstance(d, Vector):
9998
return d
100-
elif type(d) in (array.array, np.array, np.ndarray, list, tuple, range):
99+
elif isinstance(d, (array.array, np.ndarray, list, tuple, range)):
101100
return DenseVector(d)
102101
elif _have_scipy and scipy.sparse.issparse(d):
103-
assert cast("spmatrix", d).shape[1] == 1, "Expected column vector"
102+
assert hasattr(d, "shape")
103+
assert d.shape[1] == 1, "Expected column vector"
104104
# Make sure the converted csc_matrix has sorted indices.
105-
csc = cast("spmatrix", d).tocsc()
105+
assert hasattr(d, "tocsc")
106+
csc = d.tocsc()
106107
if not csc.has_sorted_indices:
107108
csc.sort_indices()
108-
return SparseVector(cast("spmatrix", d).shape[0], csc.indices, csc.data)
109+
return SparseVector(d.shape[0], csc.indices, csc.data)
109110
else:
110111
raise TypeError("Cannot convert type %s into Vector" % type(d))
111112

112113

113-
def _vector_size(v: Union["VectorLike", "spmatrix", range]) -> int:
114+
def _vector_size(v: "VectorLike") -> int:
114115
"""
115116
Returns the size of the vector.
116117
@@ -133,16 +134,17 @@ def _vector_size(v: Union["VectorLike", "spmatrix", range]) -> int:
133134
"""
134135
if isinstance(v, Vector):
135136
return len(v)
136-
elif type(v) in (array.array, list, tuple, range):
137+
elif isinstance(v, (array.array, list, tuple, range)):
137138
return len(v)
138-
elif type(v) == np.ndarray:
139+
elif isinstance(v, np.ndarray):
139140
if v.ndim == 1 or (v.ndim == 2 and v.shape[1] == 1):
140141
return len(v)
141142
else:
142143
raise ValueError("Cannot treat an ndarray of shape %s as a vector" % str(v.shape))
143144
elif _have_scipy and scipy.sparse.issparse(v):
144-
assert cast("spmatrix", v).shape[1] == 1, "Expected column vector"
145-
return cast("spmatrix", v).shape[0]
145+
assert hasattr(v, "shape")
146+
assert v.shape[1] == 1, "Expected column vector"
147+
return v.shape[0]
146148
else:
147149
raise TypeError("Cannot treat type %s as a vector" % type(v))
148150

@@ -390,13 +392,13 @@ def parse(s: str) -> "DenseVector":
390392
def __reduce__(self) -> Tuple[Type["DenseVector"], Tuple[bytes]]:
391393
return DenseVector, (self.array.tobytes(),)
392394

393-
def numNonzeros(self) -> int:
395+
def numNonzeros(self) -> Union[int, np.intp]:
394396
"""
395397
Number of nonzero elements. This scans all active values and count non zeros
396398
"""
397399
return np.count_nonzero(self.array)
398400

399-
def norm(self, p: "NormType") -> np.float64:
401+
def norm(self, p: "NormType") -> np.floating[Any]:
400402
"""
401403
Calculates the norm of a DenseVector.
402404
@@ -410,7 +412,7 @@ def norm(self, p: "NormType") -> np.float64:
410412
"""
411413
return np.linalg.norm(self.array, p)
412414

413-
def dot(self, other: Iterable[float]) -> np.float64:
415+
def dot(self, other: "VectorLike") -> np.float64:
414416
"""
415417
Compute the dot product of two Vectors. We support
416418
(Numpy array, list, SparseVector, or SciPy sparse)
@@ -444,8 +446,10 @@ def dot(self, other: Iterable[float]) -> np.float64:
444446
assert len(self) == other.shape[0], "dimension mismatch"
445447
return np.dot(self.array, other)
446448
elif _have_scipy and scipy.sparse.issparse(other):
447-
assert len(self) == cast("spmatrix", other).shape[0], "dimension mismatch"
448-
return cast("spmatrix", other).transpose().dot(self.toArray())
449+
assert hasattr(other, "shape")
450+
assert len(self) == other.shape[0], "dimension mismatch"
451+
assert hasattr(other, "transpose")
452+
return other.transpose().dot(self.toArray())
449453
else:
450454
assert len(self) == _vector_size(other), "dimension mismatch"
451455
if isinstance(other, SparseVector):
@@ -455,7 +459,7 @@ def dot(self, other: Iterable[float]) -> np.float64:
455459
else:
456460
return np.dot(self.toArray(), cast("ArrayLike", other))
457461

458-
def squared_distance(self, other: Iterable[float]) -> np.float64:
462+
def squared_distance(self, other: "VectorLike") -> np.float64:
459463
"""
460464
Squared distance of two Vectors.
461465
@@ -685,13 +689,13 @@ def __init__(
685689
% (self.indices[i], self.indices[i + 1])
686690
)
687691

688-
def numNonzeros(self) -> int:
692+
def numNonzeros(self) -> Union[int, np.intp]:
689693
"""
690694
Number of nonzero elements. This scans all active values and count non zeros.
691695
"""
692696
return np.count_nonzero(self.values)
693697

694-
def norm(self, p: "NormType") -> np.float64:
698+
def norm(self, p: "NormType") -> np.floating[Any]:
695699
"""
696700
Calculates the norm of a SparseVector.
697701
@@ -766,7 +770,7 @@ def parse(s: str) -> "SparseVector":
766770
raise ValueError("Unable to parse values from %s." % s)
767771
return SparseVector(cast(int, size), indices, values)
768772

769-
def dot(self, other: Iterable[float]) -> np.float64:
773+
def dot(self, other: "VectorLike") -> np.float64:
770774
"""
771775
Dot product with a SparseVector or 1- or 2-dimensional Numpy array.
772776
@@ -822,9 +826,9 @@ def dot(self, other: Iterable[float]) -> np.float64:
822826
return np.dot(self_values, other.values[other_cmind])
823827

824828
else:
825-
return self.dot(_convert_to_vector(other)) # type: ignore[arg-type]
829+
return self.dot(_convert_to_vector(other))
826830

827-
def squared_distance(self, other: Iterable[float]) -> np.float64:
831+
def squared_distance(self, other: "VectorLike") -> np.float64:
828832
"""
829833
Squared distance from a SparseVector or 1-dimensional NumPy array.
830834
@@ -892,7 +896,7 @@ def squared_distance(self, other: Iterable[float]) -> np.float64:
892896
j += 1
893897
return result
894898
else:
895-
return self.squared_distance(_convert_to_vector(other)) # type: ignore[arg-type]
899+
return self.squared_distance(_convert_to_vector(other))
896900

897901
def toArray(self) -> np.ndarray:
898902
"""

python/pyspark/mllib/linalg/distributed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
VT = TypeVar("VT", bound="Matrix")
3636

3737
if TYPE_CHECKING:
38-
from pyspark.ml._typing import VectorLike
38+
from pyspark.mllib._typing import VectorLike
3939

4040
__all__ = [
4141
"BlockMatrix",

0 commit comments

Comments (0)