Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ jobs:
python-version: '3.12'
- name: Install dependencies for Python CodeGen check
run: |
python3.12 -m pip install 'black==26.3.1' 'protobuf==6.33.5' 'mypy==1.8.0' 'mypy-protobuf==3.3.0'
python3.12 -m pip install 'ruff==0.14.8' 'protobuf==6.33.5' 'mypy==1.8.0' 'mypy-protobuf==3.3.0'
python3.12 -m pip list
- name: Python CodeGen check for branch-3.5
if: inputs.branch == 'branch-3.5'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
run: |
pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' pyarrow 'pandas==2.3.3' 'plotly>=4.8' 'docutils<0.18.0' \
'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==26.3.1' \
'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'ruff==0.14.8' \
'pandas-stubs==1.2.0.53' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.5' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5'
- name: Install Ruby for documentation generation
Expand Down
2 changes: 1 addition & 1 deletion dev/create-release/spark-rm/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ RUN python3.10 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13'
sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' pyarrow pandas \
'plotly>=4.8' 'docutils<0.18.0' 'flake8==3.9.0' 'mypy==1.19.1' 'pytest==7.1.3' \
'pytest-mypy-plugins==1.9.3' 'black==26.3.1' 'pandas-stubs==1.2.0.53' \
'pytest-mypy-plugins==1.9.3' 'ruff==0.14.8' 'pandas-stubs==1.2.0.53' \
'grpcio==1.76.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' \
'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' \
Expand Down
2 changes: 1 addition & 1 deletion dev/gen-protos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ for f in `find gen/proto/python -name "*.py*"`; do
rm $f.bak
done
black --config $SPARK_HOME/pyproject.toml gen/proto/python
ruff format --config $SPARK_HOME/pyproject.toml gen/proto/python
# Last step copy the result files to the destination module.
for f in `find gen/proto/python -name "*.py*"`; do
Expand Down
2 changes: 1 addition & 1 deletion dev/is-changed.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def parse_opts():
"--modules",
type=str,
default=default_value,
help="A comma-separated list of modules to test " "(default: %s)" % default_value,
help="A comma-separated list of modules to test (default: %s)" % default_value,
)

args, unknown = parser.parse_known_args()
Expand Down
15 changes: 14 additions & 1 deletion dev/lint-python
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ while (( "$#" )); do
shift
done

if [[ -z "$COMPILE_TEST$BLACK_TEST$PYSPARK_CUSTOM_ERRORS_CHECK_TEST$FLAKE8_TEST$RUFF_TEST$MYPY_TEST$MYPY_EXAMPLES_TEST$MYPY_DATA_TEST" ]]; then
if [[ -z "$COMPILE_TEST$PYSPARK_CUSTOM_ERRORS_CHECK_TEST$FLAKE8_TEST$RUFF_TEST$MYPY_TEST$MYPY_EXAMPLES_TEST$MYPY_DATA_TEST" ]]; then
COMPILE_TEST=true
BLACK_TEST=true
PYSPARK_CUSTOM_ERRORS_CHECK_TEST=true
Expand Down Expand Up @@ -315,6 +315,19 @@ ruff checks failed."
echo
fi

RUFF_REPORT=$( ($RUFF_BUILD format --diff python/pyspark dev python/packaging python/benchmarks) 2>&1)
RUFF_STATUS=$?

if [ "$RUFF_STATUS" -ne 0 ]; then
echo "ruff format checks failed:"
echo "$RUFF_REPORT"
echo "$RUFF_STATUS"
exit "$RUFF_STATUS"
else
echo "ruff format checks passed."
echo
fi

}

function black_test {
Expand Down
10 changes: 5 additions & 5 deletions dev/reformat-python
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
FWDIR="$( cd "$DIR"/.. && pwd )"
cd "$FWDIR"

BLACK_BUILD="${PYTHON_EXECUTABLE} -m black"
BLACK_VERSION="26.3.1"
$PYTHON_EXECUTABLE -c 'import black' 2> /dev/null
RUFF_BUILD="${PYTHON_EXECUTABLE} -m ruff"
RUFF_VERSION="0.14.8"
$PYTHON_EXECUTABLE -c 'import ruff' 2> /dev/null
if [ $? -ne 0 ]; then
echo "The Python library providing the 'black' module was not found. Please install Black, for example, via 'pip install black==$BLACK_VERSION'."
echo "The Python library providing the 'ruff' module was not found. Please install Ruff, for example, via 'pip install ruff==$RUFF_VERSION'."
exit 1
fi

$BLACK_BUILD python/pyspark dev python/packaging python/benchmarks
$RUFF_BUILD format python/pyspark dev python/packaging python/benchmarks
2 changes: 1 addition & 1 deletion dev/spark-test-image/docs/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# See 'docutils<0.18.0' in SPARK-39421
RUN python3.12 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe \
ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' 'pyarrow>=23.0.0' 'pandas==2.3.3' 'plotly>=4.8' 'docutils<0.18.0' \
'flake8==3.9.0' 'mypy==1.19.1' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==26.3.1' \
'flake8==3.9.0' 'mypy==1.19.1' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'ruff==0.14.8' \
'pandas-stubs==1.2.0.53' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.5' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' \
&& python3.12 -m pip cache purge
1 change: 0 additions & 1 deletion dev/spark-test-image/lint/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ RUN python3.12 -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

RUN python3.12 -m pip install \
'black==26.3.1' \
'flake8==3.9.0' \
'ruff==0.14.8' \
'googleapis-common-protos-stubs==2.2.0' \
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ exclude = [
"*python/pyspark/sql/streaming/proto/*",
"*venv*/*",
]
line-length = 100
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this changed from 88?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we use a 100-character line length — see the "Code style guide" at https://spark.apache.org/contributing.html


[tool.ruff.lint]
extend-select = [
Expand Down
5 changes: 2 additions & 3 deletions python/conf_viztracer/daemon_viztracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@


def viztracer_wrapper(func):
    """Wrap *func* so the active VizTracer routine is exited after each call.

    Parameters
    ----------
    func : callable
        The function to wrap; it is invoked with the original positional and
        keyword arguments unchanged.

    Returns
    -------
    callable
        A wrapper that calls ``func``, then — if ``viztracer.get_tracer()``
        returns a tracer — calls ``tracer.exit_routine()`` before returning
        ``func``'s result.
    """

    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        # Only close out the traced routine when a tracer is actually active;
        # get_tracer() returns None when tracing is not running.
        tracer = viztracer.get_tracer()
        if tracer is not None:
            tracer.exit_routine()
        return result

    return wrapper


Expand All @@ -43,6 +43,5 @@ def wrapper(*args, **kwargs):
else:
output_dir = "./"

sys.argv[:] = ["viztracer", "-m", "pyspark.daemon", "--quiet", "-u",
"--output_dir", output_dir]
sys.argv[:] = ["viztracer", "-m", "pyspark.daemon", "--quiet", "-u", "--output_dir", output_dir]
main()
5 changes: 2 additions & 3 deletions python/conf_vscode/sitecustomize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
import os
import sys

if (
"DEBUGPY_ADAPTER_ENDPOINTS" in os.environ and
not any("debugpy" in arg for arg in sys.orig_argv)
if "DEBUGPY_ADAPTER_ENDPOINTS" in os.environ and not any(
"debugpy" in arg for arg in sys.orig_argv
):

def install_debugpy():
Expand Down
4 changes: 1 addition & 3 deletions python/pyspark/accumulators.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,7 @@ def _start_update_server(
os.remove(socket_path)
server = AccumulatorUnixServer(socket_path, UpdateRequestHandler)
else:
server = AccumulatorTCPServer(
("localhost", 0), UpdateRequestHandler, auth_token
) # type: ignore[assignment]
server = AccumulatorTCPServer(("localhost", 0), UpdateRequestHandler, auth_token) # type: ignore[assignment]

thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
Expand Down
6 changes: 4 additions & 2 deletions python/pyspark/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ def signal_handler(signal: Any, frame: Any) -> NoReturn:

# see http://stackoverflow.com/questions/23206787/
if isinstance(
threading.current_thread(), threading._MainThread # type: ignore[attr-defined]
threading.current_thread(),
threading._MainThread, # type: ignore[attr-defined]
):
signal.signal(signal.SIGINT, signal_handler)

Expand Down Expand Up @@ -857,7 +858,8 @@ def f(split: int, iterator: Iterable[T]) -> Iterable:
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
batchSize = max(
1, min(len(c) // numSlices, self._batchSize or 1024) # type: ignore[arg-type]
1,
min(len(c) // numSlices, self._batchSize or 1024), # type: ignore[arg-type]
)
serializer = BatchedSerializer(self._unbatched_serializer, batchSize)

Expand Down
3 changes: 1 addition & 2 deletions python/pyspark/core/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1731,8 +1731,7 @@ def collectWithJobGroup(
:meth:`SparkContext.setJobGroup`
"""
warnings.warn(
"Deprecated in 3.1, Use pyspark.InheritableThread with "
"the pinned thread mode enabled.",
"Deprecated in 3.1, Use pyspark.InheritableThread with the pinned thread mode enabled.",
FutureWarning,
)

Expand Down
5 changes: 1 addition & 4 deletions python/pyspark/errors/error_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
# For more information, please see: https://issues.apache.org/jira/browse/SPARK-46810
# This discrepancy will be resolved as part of: https://issues.apache.org/jira/browse/SPARK-47429
ERROR_CLASSES_JSON = (
importlib.resources
.files("pyspark.errors")
.joinpath("error-conditions.json")
.read_text()
importlib.resources.files("pyspark.errors").joinpath("error-conditions.json").read_text()
)
ERROR_CLASSES_MAP = json.loads(ERROR_CLASSES_JSON)
2 changes: 1 addition & 1 deletion python/pyspark/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def checked_versions(spark_version, hadoop_version, hive_version):
spark_version = "spark-%s" % spark_version
if not spark_version.startswith("spark-"):
raise RuntimeError(
"Spark version should start with 'spark-' prefix; however, " "got %s" % spark_version
"Spark version should start with 'spark-' prefix; however, got %s" % spark_version
)

if hadoop_version == "without":
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3565,7 +3565,7 @@ def _fit(self, dataset: DataFrame) -> "OneVsRestModel":
weightCol = self.getWeightCol()
else:
warnings.warn(
"weightCol is ignored, " "as it is not supported by {} now.".format(classifier)
"weightCol is ignored, as it is not supported by {} now.".format(classifier)
)

if weightCol:
Expand Down
3 changes: 1 addition & 2 deletions python/pyspark/ml/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,8 +1270,7 @@ class _LDAParams(HasMaxIter, HasFeaturesCol, HasSeed, HasCheckpointInterval):
optimizer: Param[str] = Param(
Params._dummy(),
"optimizer",
"Optimizer or inference algorithm used to estimate the LDA model. "
"Supported: online, em",
"Optimizer or inference algorithm used to estimate the LDA model. Supported: online, em",
typeConverter=TypeConverters.toString,
)
learningOffset: Param[float] = Param(
Expand Down
4 changes: 1 addition & 3 deletions python/pyspark/ml/connect/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ class Pipeline(Estimator["PipelineModel"], _PipelineReadWrite):
>>> loaded_pipeline_model = PipelineModel.loadFromLocal("/tmp/pipeline")
"""

stages: Param[List[Params]] = Param(
Params._dummy(), "stages", "a list of pipeline stages"
) # type: ignore[assignment]
stages: Param[List[Params]] = Param(Params._dummy(), "stages", "a list of pipeline stages") # type: ignore[assignment]

_input_kwargs: Dict[str, Any]

Expand Down
8 changes: 2 additions & 6 deletions python/pyspark/ml/connect/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,10 @@ def _get_skip_saving_params(self) -> List[str]:

def _save_meta_algorithm(self, root_path: str, node_path: List[str]) -> Dict[str, Any]:
metadata = self._get_metadata_to_save()
metadata[
"estimator"
] = self.getEstimator()._save_to_node_path( # type: ignore[attr-defined]
metadata["estimator"] = self.getEstimator()._save_to_node_path( # type: ignore[attr-defined]
root_path, node_path + ["crossvalidator_estimator"]
)
metadata[
"evaluator"
] = self.getEvaluator()._save_to_node_path( # type: ignore[attr-defined]
metadata["evaluator"] = self.getEvaluator()._save_to_node_path( # type: ignore[attr-defined]
root_path, node_path + ["crossvalidator_evaluator"]
)
metadata["estimator_param_maps"] = [
Expand Down
5 changes: 2 additions & 3 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,8 +1961,7 @@ class _ImputerParams(HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, Has
missingValue: Param[float] = Param(
Params._dummy(),
"missingValue",
"The placeholder for the missing values. All occurrences of missingValue "
"will be imputed.",
"The placeholder for the missing values. All occurrences of missingValue will be imputed.",
typeConverter=TypeConverters.toFloat,
)

Expand Down Expand Up @@ -3824,7 +3823,7 @@ class _RobustScalerParams(HasInputCol, HasOutputCol, HasRelativeError):
withScaling: Param[bool] = Param(
Params._dummy(),
"withScaling",
"Whether to scale the data to " "quantile range",
"Whether to scale the data to quantile range",
typeConverter=TypeConverters.toBoolean,
)

Expand Down
4 changes: 1 addition & 3 deletions python/pyspark/ml/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,9 +826,7 @@ def predict(data: Iterator[Union[pd.Series, pd.DataFrame]]) -> Iterator[pd.DataF
raise ValueError(msg.format(num_expected_cols, num_input_cols))

# return transformed predictions to Spark
yield _validate_and_transform_prediction_result(
preds, num_input_rows, return_type
) # type: ignore
yield _validate_and_transform_prediction_result(preds, num_input_rows, return_type) # type: ignore

return pandas_udf(predict, return_type) # type: ignore[call-overload]

Expand Down
23 changes: 12 additions & 11 deletions python/pyspark/ml/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ def serialize(
def deserialize(
self, datum: Tuple[int, Optional[int], Optional[List[int]], List[float]]
) -> "Vector":
assert (
len(datum) == 4
), "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
assert len(datum) == 4, (
"VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
)
tpe = datum[0]
if tpe == 0:
return SparseVector(cast(int, datum[1]), cast(List[int], datum[2]), datum[3])
Expand Down Expand Up @@ -265,9 +265,9 @@ def deserialize(
self,
datum: Tuple[int, int, int, Optional[List[int]], Optional[List[int]], List[float], bool],
) -> "Matrix":
assert (
len(datum) == 7
), "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
assert len(datum) == 7, (
"MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
)
tpe = datum[0]
if tpe == 0:
return SparseMatrix(*datum[1:]) # type: ignore[arg-type]
Expand Down Expand Up @@ -625,11 +625,12 @@ def __init__(
)

if self.indices.size > 0:
assert (
np.max(self.indices) < self.size
), "Index %d is out of the size of vector with size=%d" % (
np.max(self.indices),
self.size,
assert np.max(self.indices) < self.size, (
"Index %d is out of the size of vector with size=%d"
% (
np.max(self.indices),
self.size,
)
)
assert np.min(self.indices) >= 0, "Contains negative index %d" % (np.min(self.indices))

Expand Down
6 changes: 4 additions & 2 deletions python/pyspark/ml/torch/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def _check_encryption(self) -> None:
)
)
return
raise RuntimeError(textwrap.dedent(f"""
raise RuntimeError(
textwrap.dedent(f"""
This cluster has TLS encryption enabled;
however, {name} does not support
data encryption in transit. To override
Expand All @@ -285,7 +286,8 @@ def _check_encryption(self) -> None:
to 'true' in the Spark configuration. Please note this
will cause model parameters and possibly training
data to be sent between nodes unencrypted.
"""))
""")
)


class TorchDistributor(Distributor):
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/ml/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ class _RandomForestParams(_TreeEnsembleParams):
bootstrap: Param[bool] = Param(
Params._dummy(),
"bootstrap",
"Whether bootstrap samples are used " "when building trees.",
"Whether bootstrap samples are used when building trees.",
typeConverter=TypeConverters.toBoolean,
)

Expand Down
5 changes: 4 additions & 1 deletion python/pyspark/mllib/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,10 @@ def setInitialWeights(

# LogisticRegressionWithSGD does only binary classification.
self._model = LogisticRegressionModel(
initialWeights, 0, initialWeights.size, 2 # type: ignore[attr-defined]
initialWeights,
0,
initialWeights.size, # type: ignore[attr-defined]
2,
)
return self

Expand Down
6 changes: 2 additions & 4 deletions python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,12 +1020,10 @@ def latestModel(self) -> Optional[StreamingKMeansModel]:
def _validate(self, dstream: Any) -> None:
if self._model is None:
raise ValueError(
"Initial centers should be set either by setInitialCenters " "or setRandomCenters."
"Initial centers should be set either by setInitialCenters or setRandomCenters."
)
if not isinstance(dstream, DStream):
raise TypeError(
"Expected dstream to be of type DStream, " "got type %s" % type(dstream)
)
raise TypeError("Expected dstream to be of type DStream, got type %s" % type(dstream))

@since("1.5.0")
def setK(self, k: int) -> "StreamingKMeans":
Expand Down
Loading