@@ -43,6 +43,7 @@ use datafusion::common::{
4343use datafusion:: datasource:: provider_as_source;
4444use datafusion:: error:: Result as DataFusionResult ;
4545use datafusion:: execution:: session_state:: SessionStateBuilder ;
46+ use datafusion:: functions_window:: expr_fn:: row_number;
4647use datafusion:: logical_expr:: build_join_schema;
4748use datafusion:: logical_expr:: execution_props:: ExecutionProps ;
4849use datafusion:: logical_expr:: simplify:: SimplifyContext ;
@@ -72,6 +73,7 @@ use tracing::*;
7273use uuid:: Uuid ;
7374
7475use self :: barrier:: { MergeBarrier , MergeBarrierExec } ;
76+ use self :: validation:: { MergeValidation , MergeValidationExec } ;
7577use super :: { CustomExecuteHandler , Operation } ;
7678use crate :: delta_datafusion:: expr:: fmt_expr_to_sql;
7779use crate :: delta_datafusion:: logical:: MetricObserver ;
@@ -102,17 +104,22 @@ use crate::{DeltaResult, DeltaTable, DeltaTableError};
102104
103105mod barrier;
104106mod filter;
107+ mod validation;
105108
106109const SOURCE_COLUMN : & str = "__delta_rs_source" ;
107110const TARGET_COLUMN : & str = "__delta_rs_target" ;
108111
109112const OPERATION_COLUMN : & str = "__delta_rs_operation" ;
110113const DELETE_COLUMN : & str = "__delta_rs_delete" ;
114+ const TARGET_ROW_INDEX_COLUMN : & str = "__delta_rs_target_row_index" ;
111115pub ( crate ) const TARGET_INSERT_COLUMN : & str = "__delta_rs_target_insert" ;
112116pub ( crate ) const TARGET_UPDATE_COLUMN : & str = "__delta_rs_target_update" ;
113117pub ( crate ) const TARGET_DELETE_COLUMN : & str = "__delta_rs_target_delete" ;
114118pub ( crate ) const TARGET_COPY_COLUMN : & str = "__delta_rs_target_copy" ;
115119
120+ // Duplicate match validation markers
121+ const TARGET_MATCH_CARDINALITY_CLASS_COLUMN : & str = "__delta_rs_match_cardinality_class" ;
122+
116123const SOURCE_COUNT_METRIC : & str = "num_source_rows" ;
117124const TARGET_COUNT_METRIC : & str = "num_target_rows" ;
118125const TARGET_COPY_METRIC : & str = "num_copied_rows" ;
@@ -536,6 +543,27 @@ enum OperationType {
536543 Copy ,
537544}
538545
546+ /// Duplicate-match validation class encoding.
547+ #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
548+ #[ repr( i32 ) ]
549+ enum CardinalityClass {
550+ Ignore = 0 ,
551+ MatchedUnconditionalDelete = 1 ,
552+ DuplicateInvalidating = 2 ,
553+ }
554+
555+ impl CardinalityClass {
556+ fn for_matched_operation ( op_type : OperationType , is_unconditional : bool ) -> Self {
557+ match op_type {
558+ OperationType :: Delete if is_unconditional => Self :: MatchedUnconditionalDelete ,
559+ OperationType :: Delete | OperationType :: Update | OperationType :: Copy => {
560+ Self :: DuplicateInvalidating
561+ }
562+ OperationType :: Insert | OperationType :: SourceDelete => Self :: Ignore ,
563+ }
564+ }
565+ }
566+
539567//Encapsute the User's Merge configuration for later processing
540568struct MergeOperationConfig {
541569 /// Which records to update
@@ -551,6 +579,8 @@ struct MergeOperation {
551579 /// How to update columns in a record that match the predicate
552580 operations : HashMap < Column , Expr > ,
553581 r#type : OperationType ,
582+ /// Duplicate-match validation class for this operation.
583+ cardinality_class : CardinalityClass ,
554584}
555585
556586impl MergeOperation {
@@ -610,8 +640,17 @@ impl MergeOperation {
610640 predicate : maybe_into_expr ( config. predicate , schema, state) ?,
611641 operations : ops,
612642 r#type : config. r#type ,
643+ cardinality_class : CardinalityClass :: Ignore ,
613644 } )
614645 }
646+
647+ fn into_matched ( mut self ) -> Self {
648+ let is_unconditional =
649+ matches ! ( self . r#type, OperationType :: Delete ) && self . predicate . is_none ( ) ;
650+ self . cardinality_class =
651+ CardinalityClass :: for_matched_operation ( self . r#type , is_unconditional) ;
652+ self
653+ }
615654}
616655
617656impl MergeOperationConfig {
@@ -744,6 +783,18 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner {
744783 }
745784 }
746785
786+ if let Some ( validation) = node. as_any ( ) . downcast_ref :: < MergeValidation > ( ) {
787+ if physical_inputs. len ( ) != 1 {
788+ return plan_err ! ( "MergeValidationExec expects exactly one input" ) ;
789+ }
790+
791+ let schema = validation. input . schema ( ) ;
792+ return Ok ( Some ( Arc :: new ( MergeValidationExec :: new (
793+ physical_inputs. first ( ) . unwrap ( ) . clone ( ) ,
794+ planner. create_physical_expr ( & validation. expr , schema, session_state) ?,
795+ ) ) ) ) ;
796+ }
797+
747798 if let Some ( barrier) = node. as_any ( ) . downcast_ref :: < MergeBarrier > ( ) {
748799 if physical_inputs. len ( ) != 1 {
749800 return plan_err ! ( "MergeBarrierExec expects exactly one input" ) ;
@@ -946,14 +997,18 @@ async fn execute(
946997 } ) ,
947998 } ) ;
948999 let target = DataFrame :: new ( state. clone ( ) , target) ;
1000+ let target = target. with_column ( TARGET_ROW_INDEX_COLUMN , row_number ( ) ) ?;
9491001 let target = target. with_column ( TARGET_COLUMN , lit ( true ) ) ?;
9501002
9511003 let join = source. join ( target, JoinType :: Full , & [ ] , & [ ] , Some ( predicate. clone ( ) ) ) ?;
9521004 let join_schema_df = join. schema ( ) . to_owned ( ) ;
9531005
9541006 let match_operations: Vec < MergeOperation > = match_operations
9551007 . into_iter ( )
956- . map ( |op| MergeOperation :: try_from ( op, & join_schema_df, & state, & target_alias) )
1008+ . map ( |op| {
1009+ MergeOperation :: try_from ( op, & join_schema_df, & state, & target_alias)
1010+ . map ( MergeOperation :: into_matched)
1011+ } )
9571012 . collect :: < Result < Vec < MergeOperation > , DeltaTableError > > ( ) ?;
9581013
9591014 let not_match_target_operations: Vec < MergeOperation > = not_match_target_operations
@@ -1033,11 +1088,12 @@ async fn execute(
10331088
10341089 let mut when_expr = Vec :: with_capacity ( operations_size) ;
10351090 let mut then_expr = Vec :: with_capacity ( operations_size) ;
1036- let mut ops = Vec :: with_capacity ( operations_size) ;
1091+ let mut ops: Vec < ( HashMap < Column , Expr > , OperationType , CardinalityClass ) > =
1092+ Vec :: with_capacity ( operations_size) ;
10371093
10381094 fn update_case (
10391095 operations : Vec < MergeOperation > ,
1040- ops : & mut Vec < ( HashMap < Column , Expr > , OperationType ) > ,
1096+ ops : & mut Vec < ( HashMap < Column , Expr > , OperationType , CardinalityClass ) > ,
10411097 when_expr : & mut Vec < Expr > ,
10421098 then_expr : & mut Vec < Expr > ,
10431099 base_expr : & Expr ,
@@ -1053,7 +1109,7 @@ async fn execute(
10531109 when_expr. push ( predicate) ;
10541110 then_expr. push ( lit ( ops. len ( ) as i32 ) ) ;
10551111
1056- ops. push ( ( op. operations , op. r#type ) ) ;
1112+ ops. push ( ( op. operations , op. r#type , op . cardinality_class ) ) ;
10571113
10581114 let action_type = match op. r#type {
10591115 OperationType :: Update => "update" ,
@@ -1107,15 +1163,27 @@ async fn execute(
11071163
11081164 when_expr. push ( matched) ;
11091165 then_expr. push ( lit ( ops. len ( ) as i32 ) ) ;
1110- ops. push ( ( HashMap :: new ( ) , OperationType :: Copy ) ) ;
1166+ ops. push ( (
1167+ HashMap :: new ( ) ,
1168+ OperationType :: Copy ,
1169+ CardinalityClass :: DuplicateInvalidating ,
1170+ ) ) ;
11111171
11121172 when_expr. push ( not_matched_target) ;
11131173 then_expr. push ( lit ( ops. len ( ) as i32 ) ) ;
1114- ops. push ( ( HashMap :: new ( ) , OperationType :: SourceDelete ) ) ;
1174+ ops. push ( (
1175+ HashMap :: new ( ) ,
1176+ OperationType :: SourceDelete ,
1177+ CardinalityClass :: Ignore ,
1178+ ) ) ;
11151179
11161180 when_expr. push ( not_matched_source) ;
11171181 then_expr. push ( lit ( ops. len ( ) as i32 ) ) ;
1118- ops. push ( ( HashMap :: new ( ) , OperationType :: Copy ) ) ;
1182+ ops. push ( (
1183+ HashMap :: new ( ) ,
1184+ OperationType :: Copy ,
1185+ CardinalityClass :: Ignore ,
1186+ ) ) ;
11191187
11201188 let case = CaseBuilder :: new ( None , when_expr, then_expr, None ) . end ( ) ?;
11211189
@@ -1171,8 +1239,8 @@ async fn execute(
11711239 Column :: new ( source_qualifier. clone ( ) , name)
11721240 } ;
11731241
1174- for ( idx, ( operations, _) ) in ops. iter ( ) . enumerate ( ) {
1175- let op = operations
1242+ for ( idx, ( operations, _, _ ) ) in ops. iter ( ) . enumerate ( ) {
1243+ let op: Expr = operations
11761244 . get ( & column)
11771245 . map ( |expr| expr. to_owned ( ) )
11781246 . unwrap_or_else ( || col ( column. clone ( ) ) ) ;
@@ -1230,7 +1298,7 @@ async fn execute(
12301298 let mut copy_when = Vec :: with_capacity ( ops. len ( ) ) ;
12311299 let mut copy_then = Vec :: with_capacity ( ops. len ( ) ) ;
12321300
1233- for ( idx, ( _operations, r#type) ) in ops. iter ( ) . enumerate ( ) {
1301+ for ( idx, ( _operations, r#type, _ ) ) in ops. iter ( ) . enumerate ( ) {
12341302 let op = idx as i32 ;
12351303
12361304 // Used to indicate the record should be dropped prior to write
@@ -1323,6 +1391,37 @@ async fn execute(
13231391 LogicalPlanBuilder :: from ( plan) . project ( fields) ?. build ( ) ?
13241392 } ;
13251393
1394+ let new_columns = if !match_operations. is_empty ( ) {
1395+ let mut cardinality_when = Vec :: with_capacity ( ops. len ( ) ) ;
1396+ let mut cardinality_then = Vec :: with_capacity ( ops. len ( ) ) ;
1397+
1398+ for ( idx, ( _, _, cardinality_class) ) in ops. iter ( ) . enumerate ( ) {
1399+ cardinality_when. push ( lit ( idx as i32 ) ) ;
1400+ cardinality_then. push ( lit ( * cardinality_class as i32 ) ) ;
1401+ }
1402+
1403+ let cardinality_class = CaseBuilder :: new (
1404+ Some ( Box :: new ( col ( OPERATION_COLUMN ) ) ) ,
1405+ cardinality_when,
1406+ cardinality_then,
1407+ Some ( Box :: new ( lit ( 0 ) ) ) ,
1408+ )
1409+ . end ( ) ?;
1410+
1411+ let new_columns = DataFrame :: new ( state. clone ( ) , new_columns)
1412+ . with_column ( TARGET_MATCH_CARDINALITY_CLASS_COLUMN , cardinality_class) ?
1413+ . into_unoptimized_plan ( ) ;
1414+
1415+ LogicalPlan :: Extension ( Extension {
1416+ node : Arc :: new ( MergeValidation {
1417+ input : new_columns,
1418+ expr : col ( TARGET_ROW_INDEX_COLUMN ) ,
1419+ } ) ,
1420+ } )
1421+ } else {
1422+ new_columns
1423+ } ;
1424+
13261425 let distribute_expr = col ( file_column. as_str ( ) ) ;
13271426
13281427 let merge_barrier = LogicalPlan :: Extension ( Extension {
@@ -3349,6 +3448,57 @@ mod tests {
33493448 assert ! ( res. is_err( ) )
33503449 }
33513450
3451+ #[ tokio:: test]
3452+ async fn test_merge_update_multiple_source_match_error ( ) {
3453+ let schema = get_arrow_schema ( & None ) ;
3454+ let table = setup_table ( None ) . await ;
3455+ let table = write_data ( table, & schema) . await ;
3456+ let ctx = SessionContext :: new ( ) ;
3457+ let batch = RecordBatch :: try_new (
3458+ Arc :: clone ( & schema) ,
3459+ vec ! [
3460+ Arc :: new( arrow:: array:: StringArray :: from( vec![ "B" , "B" ] ) ) ,
3461+ Arc :: new( arrow:: array:: Int32Array :: from( vec![ 11 , 12 ] ) ) ,
3462+ Arc :: new( arrow:: array:: StringArray :: from( vec![
3463+ "2023-07-04" ,
3464+ "2023-07-05" ,
3465+ ] ) ) ,
3466+ ] ,
3467+ )
3468+ . unwrap ( ) ;
3469+ let source = ctx. read_batch ( batch) . unwrap ( ) ;
3470+
3471+ let expected = vec ! [
3472+ "+----+-------+------------+" ,
3473+ "| id | value | modified |" ,
3474+ "+----+-------+------------+" ,
3475+ "| A | 1 | 2021-02-01 |" ,
3476+ "| B | 10 | 2021-02-01 |" ,
3477+ "| C | 10 | 2021-02-02 |" ,
3478+ "| D | 100 | 2021-02-02 |" ,
3479+ "+----+-------+------------+" ,
3480+ ] ;
3481+
3482+ let res = table
3483+ . clone ( )
3484+ . merge ( source, "target.id = source.id" )
3485+ . with_source_alias ( "source" )
3486+ . with_target_alias ( "target" )
3487+ . when_matched_update ( |update| {
3488+ update
3489+ . update ( "value" , "source.value" )
3490+ . update ( "modified" , "source.modified" )
3491+ } )
3492+ . unwrap ( )
3493+ . await ;
3494+
3495+ assert ! ( res. is_err( ) ) ;
3496+ assert_eq ! ( table. version( ) , Some ( 1 ) ) ;
3497+
3498+ let actual = get_data ( & table) . await ;
3499+ assert_batches_sorted_eq ! ( & expected, & actual) ;
3500+ }
3501+
33523502 #[ tokio:: test]
33533503 async fn test_merge_partitions ( ) {
33543504 /* Validate the join predicate works with table partitions */
0 commit comments