Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ jobs:
python-version: '3.12'
- name: Install dependencies for Python CodeGen check
run: |
python3.12 -m pip install 'black==26.3.1' 'protobuf==6.33.5' 'mypy==1.8.0' 'mypy-protobuf==3.3.0'
python3.12 -m pip install 'ruff==0.14.8' 'protobuf==6.33.5' 'mypy==1.8.0' 'mypy-protobuf==3.3.0'
python3.12 -m pip list
- name: Python CodeGen check for branch-3.5
if: inputs.branch == 'branch-3.5'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
run: |
pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' pyarrow 'pandas==2.3.3' 'plotly>=4.8' 'docutils<0.18.0' \
'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==26.3.1' \
'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'ruff==0.14.8' \
'pandas-stubs==1.2.0.53' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.5' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5'
- name: Install Ruby for documentation generation
Expand Down
2 changes: 1 addition & 1 deletion dev/create-release/spark-rm/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ RUN python3.10 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13'
sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' pyarrow pandas \
'plotly>=4.8' 'docutils<0.18.0' 'flake8==3.9.0' 'mypy==1.19.1' 'pytest==7.1.3' \
'pytest-mypy-plugins==1.9.3' 'black==26.3.1' 'pandas-stubs==1.2.0.53' \
'pytest-mypy-plugins==1.9.3' 'ruff==0.14.8' 'pandas-stubs==1.2.0.53' \
'grpcio==1.76.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' \
'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' \
Expand Down
2 changes: 1 addition & 1 deletion dev/gen-protos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ for f in `find gen/proto/python -name "*.py*"`; do
rm $f.bak
done

black --config $SPARK_HOME/pyproject.toml gen/proto/python
ruff format --config $SPARK_HOME/pyproject.toml gen/proto/python

# Last step copy the result files to the destination module.
for f in `find gen/proto/python -name "*.py*"`; do
Expand Down
2 changes: 1 addition & 1 deletion dev/is-changed.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def parse_opts():
"--modules",
type=str,
default=default_value,
help="A comma-separated list of modules to test " "(default: %s)" % default_value,
help="A comma-separated list of modules to test (default: %s)" % default_value,
)

args, unknown = parser.parse_known_args()
Expand Down
15 changes: 14 additions & 1 deletion dev/lint-python
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ while (( "$#" )); do
shift
done

if [[ -z "$COMPILE_TEST$BLACK_TEST$PYSPARK_CUSTOM_ERRORS_CHECK_TEST$FLAKE8_TEST$RUFF_TEST$MYPY_TEST$MYPY_EXAMPLES_TEST$MYPY_DATA_TEST" ]]; then
if [[ -z "$COMPILE_TEST$PYSPARK_CUSTOM_ERRORS_CHECK_TEST$FLAKE8_TEST$RUFF_TEST$MYPY_TEST$MYPY_EXAMPLES_TEST$MYPY_DATA_TEST" ]]; then
COMPILE_TEST=true
BLACK_TEST=true
PYSPARK_CUSTOM_ERRORS_CHECK_TEST=true
Expand Down Expand Up @@ -315,6 +315,19 @@ ruff checks failed."
echo
fi

RUFF_REPORT=$( ($RUFF_BUILD format --diff python/pyspark dev python/packaging python/benchmarks) 2>&1)
RUFF_STATUS=$?

if [ "$RUFF_STATUS" -ne 0 ]; then
echo "ruff format checks failed:"
echo "$RUFF_REPORT"
echo "$RUFF_STATUS"
exit "$RUFF_STATUS"
else
echo "ruff format checks passed."
echo
fi

}

function black_test {
Expand Down
10 changes: 5 additions & 5 deletions dev/reformat-python
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
FWDIR="$( cd "$DIR"/.. && pwd )"
cd "$FWDIR"

BLACK_BUILD="${PYTHON_EXECUTABLE} -m black"
BLACK_VERSION="26.3.1"
$PYTHON_EXECUTABLE -c 'import black' 2> /dev/null
RUFF_BUILD="${PYTHON_EXECUTABLE} -m ruff"
RUFF_VERSION="0.14.8"
$PYTHON_EXECUTABLE -c 'import ruff' 2> /dev/null
if [ $? -ne 0 ]; then
echo "The Python library providing the 'black' module was not found. Please install Black, for example, via 'pip install black==$BLACK_VERSION'."
echo "The Python library providing the 'ruff' module was not found. Please install Ruff, for example, via 'pip install ruff==$RUFF_VERSION'."
exit 1
fi

$BLACK_BUILD python/pyspark dev python/packaging python/benchmarks
$RUFF_BUILD format python/pyspark dev python/packaging python/benchmarks
2 changes: 1 addition & 1 deletion dev/spark-test-image/connect-gen-protos/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH"
RUN python3.12 -m pip install \
'mypy==1.19.1' \
'mypy-protobuf==3.3.0' \
'black==26.3.1'
'ruff==0.14.8'

# Mount the Spark repo at /spark
WORKDIR /spark
Expand Down
2 changes: 1 addition & 1 deletion dev/spark-test-image/docs/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# See 'docutils<0.18.0' in SPARK-39421
RUN python3.12 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe \
ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' 'pyarrow>=23.0.0' 'pandas==2.3.3' 'plotly>=4.8' 'docutils<0.18.0' \
'flake8==3.9.0' 'mypy==1.19.1' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==26.3.1' \
'flake8==3.9.0' 'mypy==1.19.1' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'ruff==0.14.8' \
'pandas-stubs==1.2.0.53' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.5' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' \
&& python3.12 -m pip cache purge
1 change: 0 additions & 1 deletion dev/spark-test-image/lint/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ RUN python3.12 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

RUN python3.12 -m pip install \
'black==26.3.1' \
'flake8==3.9.0' \
'ruff==0.14.8' \
'googleapis-common-protos-stubs==2.2.0' \
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ exclude = [
"*python/pyspark/sql/streaming/proto/*",
"*venv*/*",
]
line-length = 100
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this changed from 88?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we use 100 length line "Code style guide" at https://spark.apache.org/contributing.html


[tool.ruff.lint]
extend-select = [
Expand Down
5 changes: 2 additions & 3 deletions python/conf_viztracer/daemon_viztracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@


def viztracer_wrapper(func):

def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
tracer = viztracer.get_tracer()
if tracer is not None:
tracer.exit_routine()
return result

return wrapper


Expand All @@ -43,6 +43,5 @@ def wrapper(*args, **kwargs):
else:
output_dir = "./"

sys.argv[:] = ["viztracer", "-m", "pyspark.daemon", "--quiet", "-u",
"--output_dir", output_dir]
sys.argv[:] = ["viztracer", "-m", "pyspark.daemon", "--quiet", "-u", "--output_dir", output_dir]
main()
5 changes: 2 additions & 3 deletions python/conf_vscode/sitecustomize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
import os
import sys

if (
"DEBUGPY_ADAPTER_ENDPOINTS" in os.environ and
not any("debugpy" in arg for arg in sys.orig_argv)
if "DEBUGPY_ADAPTER_ENDPOINTS" in os.environ and not any(
"debugpy" in arg for arg in sys.orig_argv
):

def install_debugpy():
Expand Down
4 changes: 1 addition & 3 deletions python/pyspark/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,7 @@ def _start_update_server(
os.remove(socket_path)
server = AccumulatorUnixServer(socket_path, UpdateRequestHandler)
else:
server = AccumulatorTCPServer(
("localhost", 0), UpdateRequestHandler, auth_token
) # type: ignore[assignment]
server = AccumulatorTCPServer(("localhost", 0), UpdateRequestHandler, auth_token) # type: ignore[assignment]

thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
Expand Down
6 changes: 4 additions & 2 deletions python/pyspark/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ def signal_handler(signal: Any, frame: Any) -> NoReturn:

# see http://stackoverflow.com/questions/23206787/
if isinstance(
threading.current_thread(), threading._MainThread # type: ignore[attr-defined]
threading.current_thread(),
threading._MainThread, # type: ignore[attr-defined]
):
signal.signal(signal.SIGINT, signal_handler)

Expand Down Expand Up @@ -857,7 +858,8 @@ def f(split: int, iterator: Iterable[T]) -> Iterable:
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
batchSize = max(
1, min(len(c) // numSlices, self._batchSize or 1024) # type: ignore[arg-type]
1,
min(len(c) // numSlices, self._batchSize or 1024), # type: ignore[arg-type]
)
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)

Expand Down
3 changes: 1 addition & 2 deletions python/pyspark/core/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1731,8 +1731,7 @@ def collectWithJobGroup(
:meth:`SparkContext.setJobGroup`
"""
warnings.warn(
"Deprecated in 3.1, Use pyspark.InheritableThread with "
"the pinned thread mode enabled.",
"Deprecated in 3.1, Use pyspark.InheritableThread with the pinned thread mode enabled.",
FutureWarning,
)

Expand Down
5 changes: 1 addition & 4 deletions python/pyspark/errors/error_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
# For more information, please see: https://issues.apache.org/jira/browse/SPARK-46810
# This discrepancy will be resolved as part of: https://issues.apache.org/jira/browse/SPARK-47429
ERROR_CLASSES_JSON = (
importlib.resources
.files("pyspark.errors")
.joinpath("error-conditions.json")
.read_text()
importlib.resources.files("pyspark.errors").joinpath("error-conditions.json").read_text()
)
ERROR_CLASSES_MAP = json.loads(ERROR_CLASSES_JSON)
2 changes: 1 addition & 1 deletion python/pyspark/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def checked_versions(spark_version, hadoop_version, hive_version):
spark_version = "spark-%s" % spark_version
if not spark_version.startswith("spark-"):
raise RuntimeError(
"Spark version should start with 'spark-' prefix; however, " "got %s" % spark_version
"Spark version should start with 'spark-' prefix; however, got %s" % spark_version
)

if hadoop_version == "without":
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3565,7 +3565,7 @@ def _fit(self, dataset: DataFrame) -> "OneVsRestModel":
weightCol = self.getWeightCol()
else:
warnings.warn(
"weightCol is ignored, " "as it is not supported by {} now.".format(classifier)
"weightCol is ignored, as it is not supported by {} now.".format(classifier)
)

if weightCol:
Expand Down
3 changes: 1 addition & 2 deletions python/pyspark/ml/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,8 +1270,7 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval):
optimizer: Param[str] = Param(
Params._dummy(),
"optimizer",
"Optimizer or inference algorithm used to estimate the LDA model. "
"Supported: online, em",
"Optimizer or inference algorithm used to estimate the LDA model. Supported: online, em",
typeConverter=TypeConverters.toString,
)
learningOffset: Param[float] = Param(
Expand Down
4 changes: 1 addition & 3 deletions python/pyspark/ml/connect/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ class Pipeline(Estimator["PipelineModel"], _PipelineReadWrite):
>>> loaded_pipeline_model = PipelineModel.loadFromLocal("/tmp/pipeline")
"""

stages: Param[List[Params]] = Param(
Params._dummy(), "stages", "a list of pipeline stages"
) # type: ignore[assignment]
stages: Param[List[Params]] = Param(Params._dummy(), "stages", "a list of pipeline stages") # type: ignore[assignment]

_input_kwargs: Dict[str, Any]

Expand Down
8 changes: 2 additions & 6 deletions python/pyspark/ml/connect/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,10 @@ def _get_skip_saving_params(self) -> List[str]:

def _save_meta_algorithm(self, root_path: str, node_path: List[str]) -> Dict[str, Any]:
metadata = self._get_metadata_to_save()
metadata[
"estimator"
] = self.getEstimator()._save_to_node_path( # type: ignore[attr-defined]
metadata["estimator"] = self.getEstimator()._save_to_node_path( # type: ignore[attr-defined]
root_path, node_path + ["crossvalidator_estimator"]
)
metadata[
"evaluator"
] = self.getEvaluator()._save_to_node_path( # type: ignore[attr-defined]
metadata["evaluator"] = self.getEvaluator()._save_to_node_path( # type: ignore[attr-defined]
root_path, node_path + ["crossvalidator_evaluator"]
)
metadata["estimator_param_maps"] = [
Expand Down
5 changes: 2 additions & 3 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,8 +1961,7 @@ class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, Has
missingValue: Param[float] = Param(
Params._dummy(),
"missingValue",
"The placeholder for the missing values. All occurrences of missingValue "
"will be imputed.",
"The placeholder for the missing values. All occurrences of missingValue will be imputed.",
typeConverter=TypeConverters.toFloat,
)

Expand Down Expand Up @@ -3824,7 +3823,7 @@ class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
withScaling: Param[bool] = Param(
Params._dummy(),
"withScaling",
"Whether to scale the data to " "quantile range",
"Whether to scale the data to quantile range",
typeConverter=TypeConverters.toBoolean,
)

Expand Down
4 changes: 1 addition & 3 deletions python/pyspark/ml/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,9 +826,7 @@ def predict(data: Iterator[Union[pd.Series, pd.DataFrame]]) -> Iterator[pd.DataF
raise ValueError(msg.format(num_expected_cols, num_input_cols))

# return transformed predictions to Spark
yield _validate_and_transform_prediction_result(
preds, num_input_rows, return_type
) # type: ignore
yield _validate_and_transform_prediction_result(preds, num_input_rows, return_type) # type: ignore

return pandas_udf(predict, return_type) # type: ignore[call-overload]

Expand Down
23 changes: 12 additions & 11 deletions python/pyspark/ml/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ def serialize(
def deserialize(
self, datum: Tuple[int, Optional[int], Optional[List[int]], List[float]]
) -> "Vector":
assert (
len(datum) == 4
), "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
assert len(datum) == 4, (
"VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
)
tpe = datum[0]
if tpe == 0:
return SparseVector(cast(int, datum[1]), cast(List[int], datum[2]), datum[3])
Expand Down Expand Up @@ -265,9 +265,9 @@ def deserialize(
self,
datum: Tuple[int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool],
) -> "Matrix":
assert (
len(datum) == 7
), "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
assert len(datum) == 7, (
"MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
)
tpe = datum[0]
if tpe == 0:
return SparseMatrix(*datum[1:]) # type: ignore[arg-type]
Expand Down Expand Up @@ -625,11 +625,12 @@ def __init__(
)

if self.indices.size > 0:
assert (
np.max(self.indices) < self.size
), "Index %d is out of the size of vector with size=%d" % (
np.max(self.indices),
self.size,
assert np.max(self.indices) < self.size, (
"Index %d is out of the size of vector with size=%d"
% (
np.max(self.indices),
self.size,
)
)
assert np.min(self.indices) >= 0, "Contains negative index %d" % (np.min(self.indices))

Expand Down
6 changes: 4 additions & 2 deletions python/pyspark/ml/torch/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def _check_encryption(self) -> None:
)
)
return
raise RuntimeError(textwrap.dedent(f"""
raise RuntimeError(
textwrap.dedent(f"""
This cluster has TLS encryption enabled;
however, {name} does not support
data encryption in transit. To override
Expand All @@ -285,7 +286,8 @@ def _check_encryption(self) -> None:
to 'true' in the Spark configuration. Please note this
will cause model parameters and possibly training
data to be sent between nodes unencrypted.
"""))
""")
)


class TorchDistributor(Distributor):
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/ml/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ class _RandomForestParams(_TreeEnsembleParams):
bootstrap: Param[bool] = Param(
Params._dummy(),
"bootstrap",
"Whether bootstrap samples are used " "when building trees.",
"Whether bootstrap samples are used when building trees.",
typeConverter=TypeConverters.toBoolean,
)

Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/mllib/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,10 @@ def setInitialWeights(

# LogisticRegressionWithSGD does only binary classification.
self._model = LogisticRegressionModel(
initialWeights, 0, initialWeights.size, 2 # type: ignore[attr-defined]
initialWeights,
0,
initialWeights.size, # type: ignore[attr-defined]
2,
)
return self

Expand Down
6 changes: 2 additions & 4 deletions python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,12 +1020,10 @@ def latestModel(self) -> Optional[StreamingKMeansModel]:
def _validate(self, dstream: Any) -> None:
if self._model is None:
raise ValueError(
"Initial centers should be set either by setInitialCenters " "or setRandomCenters."
"Initial centers should be set either by setInitialCenters or setRandomCenters."
)
if not isinstance(dstream, DStream):
raise TypeError(
"Expected dstream to be of type DStream, " "got type %s" % type(dstream)
)
raise TypeError("Expected dstream to be of type DStream, got type %s" % type(dstream))

@since("1.5.0")
def setK(self, k: int) -> "StreamingKMeans":
Expand Down
Loading