Skip to content

Commit 47574ba

Browse files
zhengruifengdongjoon-hyun
authored andcommitted
[SPARK-53453][PYTHON][ML] Unblock 'torch<2.6.0'
### What changes were proposed in this pull request? Unblock 'torch<2.6.0' ### Why are the changes needed? to test with latest torch ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? No Closes #52197 from zhengruifeng/torch_280. Authored-by: Ruifeng Zheng <ruifengz@apache.org> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent ce763c1 commit 47574ba

File tree

7 files changed

+8
-7
lines changed

7 files changed

+8
-7
lines changed

dev/spark-test-image/python-310/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,6 @@ RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
7272
RUN python3.10 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
7373
RUN python3.10 -m pip install --ignore-installed 'six==1.16.0' # Avoid `python3-six` installation
7474
RUN python3.10 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
75-
python3.10 -m pip install 'torch<2.6.0' torchvision --index-url https://download.pytorch.org/whl/cpu && \
75+
python3.10 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
7676
python3.10 -m pip install deepspeed torcheval && \
7777
python3.10 -m pip cache purge

dev/spark-test-image/python-311-classic-only/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,6 @@ ARG TEST_PIP_PKGS="coverage unittest-xml-reporting"
7474
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
7575
RUN python3.11 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
7676
RUN python3.11 -m pip install $BASIC_PIP_PKGS $TEST_PIP_PKGS && \
77-
python3.11 -m pip install 'torch<2.6.0' torchvision --index-url https://download.pytorch.org/whl/cpu && \
77+
python3.11 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
7878
python3.11 -m pip install deepspeed torcheval && \
7979
python3.11 -m pip cache purge

dev/spark-test-image/python-311/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,6 @@ ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 goog
7575
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
7676
RUN python3.11 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
7777
RUN python3.11 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
78-
python3.11 -m pip install 'torch<2.6.0' torchvision --index-url https://download.pytorch.org/whl/cpu && \
78+
python3.11 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
7979
python3.11 -m pip install deepspeed torcheval && \
8080
python3.11 -m pip cache purge

dev/spark-test-image/python-312/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,6 @@ ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 goog
7575
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12
7676
RUN python3.12 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
7777
RUN python3.12 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \
78-
python3.12 -m pip install 'torch<2.6.0' torchvision --index-url https://download.pytorch.org/whl/cpu && \
78+
python3.12 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
7979
python3.12 -m pip install torcheval && \
8080
python3.12 -m pip cache purge

dev/spark-test-image/python-313/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,6 @@ ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 goog
7575
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13
7676
RUN python3.13 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
7777
RUN python3.13 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \
78-
python3.13 -m pip install 'torch<2.6.0' torchvision --index-url https://download.pytorch.org/whl/cpu && \
78+
python3.13 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
7979
python3.13 -m pip install torcheval && \
8080
python3.13 -m pip cache purge

python/pyspark/ml/connect/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def _save_core_model(self, path: str) -> None:
381381
def _load_core_model(self, path: str) -> None:
382382
import torch
383383

384-
lor_torch_model = torch.load(path)
384+
lor_torch_model = torch.load(path, weights_only=False)
385385
self.torch_model = lor_torch_model[0]
386386

387387
def _get_extra_metadata(self) -> Dict[str, Any]:

python/pyspark/ml/tests/connect/test_legacy_mode_classification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ def test_save_load(self):
183183

184184
# test saved torch model can be loaded by pytorch solely
185185
lor_torch_model = torch.load(
186-
os.path.join(local_model_path, "LogisticRegressionModel.torch")
186+
os.path.join(local_model_path, "LogisticRegressionModel.torch"),
187+
weights_only=False,
187188
)
188189

189190
with torch.inference_mode():

0 commit comments

Comments
 (0)