
Commit 81e8678

ueshin authored and HyukjinKwon committed
[SPARK-56187][PS] Fix Series.argsort null ordering for pandas 3
### What changes were proposed in this pull request?

This PR updates pandas-on-Spark `Series.argsort()` to follow the pandas 3 behavior for null values.

Before this change, `Series.argsort()` always dropped nulls from the ordering step and appended `-1` for null positions. That matches pandas 2 behavior, but it no longer matches pandas 3, where nulls are ordered last and receive real positional indices.

This patch keeps the existing pandas `< 3.0.0` behavior, including the deprecation warning, and switches pandas `>= 3.0.0` to sort the full Series with nulls ordered last so the returned positions match upstream pandas.

### Why are the changes needed?

`pyspark.pandas.tests.series.test_arg_ops SeriesArgOpsTests.test_argsort` fails in the pandas 3 environment because pandas-on-Spark still implements the deprecated pandas 2 null-handling semantics. For example, with null values present:

- pandas 3 returns positional indices for all rows, with nulls ordered last
- pandas-on-Spark returned `-1` for null rows

This makes `Series.argsort()` inconsistent with pandas 3 and causes the existing compatibility test to fail.

### Does this PR introduce _any_ user-facing change?

Yes, it will behave more like pandas 3.

### How was this patch tested?

The existing tests should pass.

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Codex (GPT-5)

Closes #54989 from ueshin/issues/SPARK-56187/argsort.

Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
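
The pandas 2 vs. pandas 3 difference described above can be sketched with plain NumPy; the function names here are illustrative, not part of the patch:

```python
import numpy as np


def argsort_pandas2_style(values):
    # pandas < 3: drop nulls, argsort only the non-null values,
    # and put -1 at the null positions.
    values = np.asarray(values, dtype=float)
    mask = np.isnan(values)
    result = np.full(len(values), -1, dtype=int)
    result[~mask] = np.argsort(values[~mask])
    return result


def argsort_pandas3_style(values):
    # pandas >= 3: argsort the full array; NaN sorts last and
    # keeps a real positional index instead of -1.
    return np.argsort(np.asarray(values, dtype=float))


vals = [3.0, 1.0, float("nan"), 2.0]
print(argsort_pandas2_style(vals))  # [ 1  2 -1  0]
print(argsort_pandas3_style(vals))  # [1 3 0 2]
```

Note that in the pandas 3 result the null's own position (`2` here) appears last, which is exactly the "nulls ordered last with real positional indices" semantics this patch reproduces.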
1 parent ab2a069 commit 81e8678

File tree

1 file changed

+25
-16
lines changed


python/pyspark/pandas/series.py

Lines changed: 25 additions & 16 deletions
@@ -6422,14 +6422,17 @@ def argsort(self) -> "Series":
         10    10
         dtype: int64
         """
-        warnings.warn(
-            "The behavior of Series.argsort in the presence of NA values is deprecated. "
-            "In a future version, NA values will be ordered last instead of set to -1.",
-            FutureWarning,
-        )
-        notnull = self.loc[self.notnull()]
+        if LooseVersion(pd.__version__) < "3.0.0":
+            warnings.warn(
+                "The behavior of Series.argsort in the presence of NA values is deprecated. "
+                "In a future version, NA values will be ordered last instead of set to -1.",
+                FutureWarning,
+            )
+            source = self.loc[self.notnull()]
+        else:
+            source = self

-        sdf_for_index = notnull._internal.spark_frame.select(notnull._internal.index_spark_columns)
+        sdf_for_index = source._internal.spark_frame.select(source._internal.index_spark_columns)

         tmp_join_key = verify_temp_column_name(sdf_for_index, "__tmp_join_key__")
         sdf_for_index = InternalFrame.attach_distributed_sequence_column(
@@ -6446,8 +6449,8 @@ def argsort(self) -> "Series":
         # |               4|                4|
         # +----------------+-----------------+

-        sdf_for_data = notnull._internal.spark_frame.select(
-            notnull.spark.column.alias("values"), NATURAL_ORDER_COLUMN_NAME
+        sdf_for_data = source._internal.spark_frame.select(
+            source.spark.column.alias("values"), NATURAL_ORDER_COLUMN_NAME
         )
         sdf_for_data = InternalFrame.attach_distributed_sequence_column(
             sdf_for_data, SPARK_DEFAULT_SERIES_NAME
@@ -6463,9 +6466,12 @@ def argsort(self) -> "Series":
         # |  4|     2|     128849018880|
         # +---+------+-----------------+

-        sdf_for_data = sdf_for_data.sort(
-            scol_for(sdf_for_data, "values"), NATURAL_ORDER_COLUMN_NAME
-        ).drop("values", NATURAL_ORDER_COLUMN_NAME)
+        value_scol = scol_for(sdf_for_data, "values")
+        if LooseVersion(pd.__version__) < "3.0.0":
+            sdf_for_data = sdf_for_data.sort(value_scol, NATURAL_ORDER_COLUMN_NAME)
+        else:
+            sdf_for_data = sdf_for_data.sort(value_scol.asc_nulls_last(), NATURAL_ORDER_COLUMN_NAME)
+        sdf_for_data = sdf_for_data.drop("values", NATURAL_ORDER_COLUMN_NAME)

         tmp_join_key = verify_temp_column_name(sdf_for_data, "__tmp_join_key__")
         sdf_for_data = InternalFrame.attach_distributed_sequence_column(sdf_for_data, tmp_join_key)
@@ -6492,10 +6498,13 @@ def argsort(self) -> "Series":
         )
         psser = first_series(DataFrame(internal))

-        return cast(
-            Series,
-            ps.concat([psser, self.loc[self.isnull()].spark.transform(lambda _: F.lit(-1))]),
-        )
+        if LooseVersion(pd.__version__) < "3.0.0":
+            return cast(
+                Series,
+                ps.concat([psser, self.loc[self.isnull()].spark.transform(lambda _: F.lit(-1))]),
+            )
+        else:
+            return psser

     def argmax(self, axis: Axis = None, skipna: bool = True) -> int:
         """

0 commit comments
