Skip to content

Commit 829045e

Browse files
devin-petersohn authored and HyukjinKwon committed
[SPARK-46163][PS] DataFrame.update parameters filter_func and errors
### What changes were proposed in this pull request? DataFrame.update parameters filter_func and errors ### Why are the changes needed? To add missing parameters to `update` function ### Does this PR introduce _any_ user-facing change? Yes, new parameter implementation ### How was this patch tested? CI ### Was this patch authored or co-authored using generative AI tooling? Co-authored-by: Claude Sonnet 4.5 Closes #54287 from devin-petersohn/devin/update_params. Authored-by: Devin Petersohn <devin.petersohn@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent d2a47b9 commit 829045e

File tree

2 files changed

+181
-7
lines changed

2 files changed

+181
-7
lines changed

python/pyspark/pandas/frame.py

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9132,12 +9132,21 @@ def combine_first(self, other: "DataFrame") -> "DataFrame":
91329132
)
91339133
return DataFrame(internal)
91349134

9135-
# TODO(SPARK-46163): add 'filter_func' and 'errors' parameter
9136-
def update(self, other: "DataFrame", join: str = "left", overwrite: bool = True) -> None:
9135+
def update(
9136+
self,
9137+
other: "DataFrame",
9138+
join: str = "left",
9139+
overwrite: bool = True,
9140+
filter_func: Optional[Callable[[Any], bool]] = None,
9141+
errors: str = "ignore",
9142+
) -> None:
91379143
"""
91389144
Modify in place using non-NA values from another DataFrame.
91399145
Aligns on indices. There is no return value.
91409146

9147+
.. note:: When ``errors='raise'``, this method forces materialization to check
9148+
for overlapping non-NA data, which may impact performance on large datasets.
9149+
91419150
Parameters
91429151
----------
91439152
other : DataFrame, or Series
@@ -9149,10 +9158,23 @@ def update(self, other: "DataFrame", join: str = "left", overwrite: bool = True)
91499158
* True: overwrite original DataFrame's values with values from `other`.
91509159
* False: only update values that are NA in the original DataFrame.
91519160

9161+
filter_func : callable(1d-array) -> bool 1d-array, optional
9162+
Can choose to replace values other than NA. Return True for values
9163+
which should be updated. Applied to original DataFrame's values.
9164+
errors : {'ignore', 'raise'}, default 'ignore'
9165+
If 'raise', will raise a ValueError if the DataFrame and other both
9166+
contain non-NA data in the same place.
9167+
91529168
Returns
91539169
-------
91549170
None : method directly changes calling object
91559171

9172+
Raises
9173+
------
9174+
ValueError
9175+
If errors='raise' and overlapping non-NA data is detected.
9176+
If errors is not 'ignore' or 'raise'.
9177+
91569178
See Also
91579179
--------
91589180
DataFrame.merge : For column(s)-on-columns(s) operations.
@@ -9204,9 +9226,22 @@ def update(self, other: "DataFrame", join: str = "left", overwrite: bool = True)
92049226
0 1 4.0
92059227
1 2 500.0
92069228
2 3 6.0
9229+
9230+
Using filter_func to selectively update values:
9231+
9232+
>>> df = ps.DataFrame({'A': [1, 2, 3], 'B': [400, 500, 600]})
9233+
>>> new_df = ps.DataFrame({'B': [4, 5, 6]})
9234+
>>> df.update(new_df, filter_func=lambda x: x > 450)
9235+
>>> df.sort_index()
9236+
A B
9237+
0 1 400
9238+
1 2 5
9239+
2 3 6
92079240
"""
92089241
if join != "left":
92099242
raise NotImplementedError("Only left join is supported")
9243+
if errors not in ("ignore", "raise"):
9244+
raise ValueError("errors must be either 'ignore' or 'raise'")
92109245

92119246
if isinstance(other, ps.Series):
92129247
other = other.to_frame()
@@ -9218,21 +9253,53 @@ def update(self, other: "DataFrame", join: str = "left", overwrite: bool = True)
92189253
other[update_columns], rsuffix="_new"
92199254
)._internal.resolved_copy.spark_frame
92209255

9256+
if errors == "raise" and update_columns:
9257+
from pyspark.sql.types import BooleanType
9258+
9259+
any_overlap = F.lit(False)
9260+
for column_labels in update_columns:
9261+
column_name = self._internal.spark_column_name_for(column_labels)
9262+
old_col = scol_for(update_sdf, column_name)
9263+
new_col = scol_for(
9264+
update_sdf, other._internal.spark_column_name_for(column_labels) + "_new"
9265+
)
9266+
9267+
overlap = old_col.isNotNull() & new_col.isNotNull()
9268+
if filter_func is not None:
9269+
overlap = overlap & pandas_udf( # type: ignore[call-overload]
9270+
filter_func, BooleanType()
9271+
)(old_col)
9272+
9273+
any_overlap = any_overlap | overlap
9274+
9275+
if update_sdf.select(F.max(F.when(any_overlap, 1).otherwise(0))).first()[0]:
9276+
raise ValueError("Data overlaps.")
9277+
92219278
data_fields = self._internal.data_fields.copy()
92229279
for column_labels in update_columns:
92239280
column_name = self._internal.spark_column_name_for(column_labels)
92249281
old_col = scol_for(update_sdf, column_name)
92259282
new_col = scol_for(
92269283
update_sdf, other._internal.spark_column_name_for(column_labels) + "_new"
92279284
)
9228-
if overwrite:
9229-
update_sdf = update_sdf.withColumn(
9230-
column_name, F.when(new_col.isNull(), old_col).otherwise(new_col)
9285+
9286+
if filter_func is not None:
9287+
from pyspark.sql.types import BooleanType
9288+
9289+
mask = pandas_udf(filter_func, BooleanType())(old_col) # type: ignore[call-overload]
9290+
updated_col = (
9291+
F.when(new_col.isNull() | mask.isNull() | ~mask, old_col).otherwise(new_col)
9292+
if overwrite
9293+
else F.when(old_col.isNull() & mask, new_col).otherwise(old_col)
92319294
)
92329295
else:
9233-
update_sdf = update_sdf.withColumn(
9234-
column_name, F.when(old_col.isNull(), new_col).otherwise(old_col)
9296+
updated_col = (
9297+
F.when(new_col.isNull(), old_col).otherwise(new_col)
9298+
if overwrite
9299+
else F.when(old_col.isNull(), new_col).otherwise(old_col)
92359300
)
9301+
9302+
update_sdf = update_sdf.withColumn(column_name, updated_col)
92369303
data_fields[self._internal.column_labels.index(column_labels)] = None
92379304
sdf = update_sdf.select(
92389305
*[scol_for(update_sdf, col) for col in self._internal.spark_column_names],

python/pyspark/pandas/tests/computation/test_combine.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,113 @@ def get_data(left_columns=None, right_columns=None):
658658
left_psdf.sort_values(by=[("X", "A"), ("X", "B")]),
659659
)
660660

661+
def test_update_with_filter_func(self):
662+
# Test filter_func parameter
663+
left_pdf = pd.DataFrame({"A": [1, 2, 3, 4], "B": [10, 20, 30, 40]})
664+
right_pdf = pd.DataFrame({"B": [100, 200, 300, 400]})
665+
666+
left_psdf = ps.from_pandas(left_pdf)
667+
right_psdf = ps.from_pandas(right_pdf)
668+
669+
# Only update values > 25
670+
left_pdf.update(right_pdf, filter_func=lambda x: x > 25)
671+
left_psdf.update(right_psdf, filter_func=lambda x: x > 25)
672+
673+
self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
674+
675+
def test_update_filter_func_overwrite_false(self):
676+
# Test filter_func with overwrite=False
677+
left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [None, 20, None]})
678+
right_pdf = pd.DataFrame({"B": [100, 200, 300]})
679+
680+
left_psdf = ps.from_pandas(left_pdf)
681+
right_psdf = ps.from_pandas(right_pdf)
682+
683+
# Only update where new value > 150 (and old is null)
684+
left_pdf.update(right_pdf, overwrite=False, filter_func=lambda x: x > 150)
685+
left_psdf.update(right_psdf, overwrite=False, filter_func=lambda x: x > 150)
686+
687+
self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
688+
689+
def test_update_errors_raise_with_overlap(self):
690+
# Test that errors='raise' raises ValueError on overlap
691+
left_psdf = ps.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
692+
right_psdf = ps.DataFrame({"B": [100, 200, 300]})
693+
694+
# Should raise because both have non-null values
695+
with self.assertRaisesRegex(ValueError, "Data overlaps."):
696+
left_psdf.update(right_psdf, errors="raise")
697+
698+
def test_update_errors_raise_no_overlap(self):
699+
# Test that errors='raise' works when no overlap
700+
left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [None, None, 30]})
701+
right_pdf = pd.DataFrame({"B": [100, 200, None]})
702+
703+
left_psdf = ps.from_pandas(left_pdf)
704+
right_psdf = ps.from_pandas(right_pdf)
705+
706+
left_pdf.update(right_pdf, errors="raise")
707+
left_psdf.update(right_psdf, errors="raise")
708+
709+
self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
710+
711+
def test_update_errors_invalid_value(self):
712+
# Test that invalid errors parameter raises ValueError
713+
left_psdf = ps.DataFrame({"A": [1, 2, 3]})
714+
right_psdf = ps.DataFrame({"A": [4, 5, 6]})
715+
716+
with self.assertRaisesRegex(ValueError, "errors must be either 'ignore' or 'raise'"):
717+
left_psdf.update(right_psdf, errors="invalid")
718+
719+
def test_update_filter_func_and_errors_raise(self):
720+
# Test combination of filter_func and errors='raise'
721+
left_psdf = ps.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
722+
right_psdf = ps.DataFrame({"B": [100, 200, 300]})
723+
724+
# Filter only values < 25 - should find overlaps at positions 0 and 1
725+
with self.assertRaisesRegex(ValueError, "Data overlaps."):
726+
left_psdf.update(right_psdf, filter_func=lambda x: x < 25, errors="raise")
727+
728+
# Filter only values > 100 - no overlaps since no original values > 100
729+
left_psdf2 = ps.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
730+
right_psdf2 = ps.DataFrame({"B": [100, 200, 300]})
731+
732+
# Should not raise - no values in original DataFrame match filter
733+
left_psdf2.update(right_psdf2, filter_func=lambda x: x > 100, errors="raise")
734+
735+
def test_update_filter_func_all_false(self):
736+
# Test filter_func that returns all False
737+
left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
738+
right_pdf = pd.DataFrame({"B": [100, 200, 300]})
739+
740+
left_psdf = ps.from_pandas(left_pdf.copy())
741+
right_psdf = ps.from_pandas(right_pdf)
742+
743+
# Filter that matches nothing
744+
original_left_pdf = left_pdf.copy()
745+
original_left_psdf = left_psdf.copy()
746+
747+
left_pdf.update(right_pdf, filter_func=lambda x: x > 1000)
748+
left_psdf.update(right_psdf, filter_func=lambda x: x > 1000)
749+
750+
# DataFrame should be unchanged
751+
self.assert_eq(left_pdf.sort_index(), original_left_pdf.sort_index())
752+
self.assert_eq(left_psdf.sort_index(), original_left_psdf.sort_index())
753+
754+
def test_update_filter_func_with_nulls(self):
755+
# Test filter_func handling of null values
756+
left_pdf = pd.DataFrame({"A": [1, 2, 3], "B": [None, 20, None]})
757+
right_pdf = pd.DataFrame({"B": [100, 200, 300]})
758+
759+
left_psdf = ps.from_pandas(left_pdf)
760+
right_psdf = ps.from_pandas(right_pdf)
761+
762+
# Filter values > 10 (nulls will not match)
763+
left_pdf.update(right_pdf, filter_func=lambda x: x > 10)
764+
left_psdf.update(right_psdf, filter_func=lambda x: x > 10)
765+
766+
self.assert_eq(left_pdf.sort_index(), left_psdf.sort_index())
767+
661768

662769
class FrameCombineTests(
663770
FrameCombineMixin,

0 commit comments

Comments
 (0)