Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dev/spark-test-image/python-311/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,6 @@ ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.1 goog
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
RUN python3.11 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
RUN python3.11 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
python3.11 -m pip install 'torch<2.6.0' torchvision --index-url https://download.pytorch.org/whl/cpu && \
python3.11 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do other Dockerfiles need to be handled in separate pr?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch

python3.11 -m pip install deepspeed torcheval && \
python3.11 -m pip cache purge
2 changes: 1 addition & 1 deletion python/pyspark/ml/connect/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def _save_core_model(self, path: str) -> None:
def _load_core_model(self, path: str) -> None:
import torch

lor_torch_model = torch.load(path)
lor_torch_model = torch.load(path, weights_only=False)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to have the weights_only=False . From there page "Loading un-trusted checkpoint with weights_only=False MUST never be done." https://github.com/pytorch/pytorch/security

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default value of weights_only is False in <2.6.0. And the default value was changed to True since 2.6.0.
This PR keeps the behavior.

Ideally, we should use weights_only=True, it will needs some investigation. Thanks for pointing it out.

self.torch_model = lor_torch_model[0]

def _get_extra_metadata(self) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def test_save_load(self):

# test saved torch model can be loaded by pytorch solely
lor_torch_model = torch.load(
os.path.join(local_model_path, "LogisticRegressionModel.torch")
os.path.join(local_model_path, "LogisticRegressionModel.torch"),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dito but this is for a tests so I think its ok.

weights_only=False,
)

with torch.inference_mode():
Expand Down