Skip to content

Commit 6bcdf04

Browse files
adampolomski and rtyler
authored and committed
fix: enforce duplicate-match validation for merge
Signed-off-by: adam.polomski <adam.polomski@relativity.com>
1 parent 0bb2d6b commit 6bcdf04

3 files changed

Lines changed: 573 additions & 10 deletions

File tree

crates/core/src/operations/merge/mod.rs

Lines changed: 160 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use datafusion::common::{
4343
use datafusion::datasource::provider_as_source;
4444
use datafusion::error::Result as DataFusionResult;
4545
use datafusion::execution::session_state::SessionStateBuilder;
46+
use datafusion::functions_window::expr_fn::row_number;
4647
use datafusion::logical_expr::build_join_schema;
4748
use datafusion::logical_expr::execution_props::ExecutionProps;
4849
use datafusion::logical_expr::simplify::SimplifyContext;
@@ -72,6 +73,7 @@ use tracing::*;
7273
use uuid::Uuid;
7374

7475
use self::barrier::{MergeBarrier, MergeBarrierExec};
76+
use self::validation::{MergeValidation, MergeValidationExec};
7577
use super::{CustomExecuteHandler, Operation};
7678
use crate::delta_datafusion::expr::fmt_expr_to_sql;
7779
use crate::delta_datafusion::logical::MetricObserver;
@@ -102,17 +104,22 @@ use crate::{DeltaResult, DeltaTable, DeltaTableError};
102104

103105
mod barrier;
104106
mod filter;
107+
mod validation;
105108

106109
const SOURCE_COLUMN: &str = "__delta_rs_source";
107110
const TARGET_COLUMN: &str = "__delta_rs_target";
108111

109112
const OPERATION_COLUMN: &str = "__delta_rs_operation";
110113
const DELETE_COLUMN: &str = "__delta_rs_delete";
114+
const TARGET_ROW_INDEX_COLUMN: &str = "__delta_rs_target_row_index";
111115
pub(crate) const TARGET_INSERT_COLUMN: &str = "__delta_rs_target_insert";
112116
pub(crate) const TARGET_UPDATE_COLUMN: &str = "__delta_rs_target_update";
113117
pub(crate) const TARGET_DELETE_COLUMN: &str = "__delta_rs_target_delete";
114118
pub(crate) const TARGET_COPY_COLUMN: &str = "__delta_rs_target_copy";
115119

120+
// Duplicate match validation markers
121+
const TARGET_MATCH_CARDINALITY_CLASS_COLUMN: &str = "__delta_rs_match_cardinality_class";
122+
116123
const SOURCE_COUNT_METRIC: &str = "num_source_rows";
117124
const TARGET_COUNT_METRIC: &str = "num_target_rows";
118125
const TARGET_COPY_METRIC: &str = "num_copied_rows";
@@ -536,6 +543,27 @@ enum OperationType {
536543
Copy,
537544
}
538545

546+
/// Duplicate-match validation class encoding.
547+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
548+
#[repr(i32)]
549+
enum CardinalityClass {
550+
Ignore = 0,
551+
MatchedUnconditionalDelete = 1,
552+
DuplicateInvalidating = 2,
553+
}
554+
555+
impl CardinalityClass {
556+
fn for_matched_operation(op_type: OperationType, is_unconditional: bool) -> Self {
557+
match op_type {
558+
OperationType::Delete if is_unconditional => Self::MatchedUnconditionalDelete,
559+
OperationType::Delete | OperationType::Update | OperationType::Copy => {
560+
Self::DuplicateInvalidating
561+
}
562+
OperationType::Insert | OperationType::SourceDelete => Self::Ignore,
563+
}
564+
}
565+
}
566+
539567
//Encapsulate the User's Merge configuration for later processing
540568
struct MergeOperationConfig {
541569
/// Which records to update
@@ -551,6 +579,8 @@ struct MergeOperation {
551579
/// How to update columns in a record that match the predicate
552580
operations: HashMap<Column, Expr>,
553581
r#type: OperationType,
582+
/// Duplicate-match validation class for this operation.
583+
cardinality_class: CardinalityClass,
554584
}
555585

556586
impl MergeOperation {
@@ -610,8 +640,17 @@ impl MergeOperation {
610640
predicate: maybe_into_expr(config.predicate, schema, state)?,
611641
operations: ops,
612642
r#type: config.r#type,
643+
cardinality_class: CardinalityClass::Ignore,
613644
})
614645
}
646+
647+
fn into_matched(mut self) -> Self {
648+
let is_unconditional =
649+
matches!(self.r#type, OperationType::Delete) && self.predicate.is_none();
650+
self.cardinality_class =
651+
CardinalityClass::for_matched_operation(self.r#type, is_unconditional);
652+
self
653+
}
615654
}
616655

617656
impl MergeOperationConfig {
@@ -744,6 +783,18 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner {
744783
}
745784
}
746785

786+
if let Some(validation) = node.as_any().downcast_ref::<MergeValidation>() {
787+
if physical_inputs.len() != 1 {
788+
return plan_err!("MergeValidationExec expects exactly one input");
789+
}
790+
791+
let schema = validation.input.schema();
792+
return Ok(Some(Arc::new(MergeValidationExec::new(
793+
physical_inputs.first().unwrap().clone(),
794+
planner.create_physical_expr(&validation.expr, schema, session_state)?,
795+
))));
796+
}
797+
747798
if let Some(barrier) = node.as_any().downcast_ref::<MergeBarrier>() {
748799
if physical_inputs.len() != 1 {
749800
return plan_err!("MergeBarrierExec expects exactly one input");
@@ -946,14 +997,18 @@ async fn execute(
946997
}),
947998
});
948999
let target = DataFrame::new(state.clone(), target);
1000+
let target = target.with_column(TARGET_ROW_INDEX_COLUMN, row_number())?;
9491001
let target = target.with_column(TARGET_COLUMN, lit(true))?;
9501002

9511003
let join = source.join(target, JoinType::Full, &[], &[], Some(predicate.clone()))?;
9521004
let join_schema_df = join.schema().to_owned();
9531005

9541006
let match_operations: Vec<MergeOperation> = match_operations
9551007
.into_iter()
956-
.map(|op| MergeOperation::try_from(op, &join_schema_df, &state, &target_alias))
1008+
.map(|op| {
1009+
MergeOperation::try_from(op, &join_schema_df, &state, &target_alias)
1010+
.map(MergeOperation::into_matched)
1011+
})
9571012
.collect::<Result<Vec<MergeOperation>, DeltaTableError>>()?;
9581013

9591014
let not_match_target_operations: Vec<MergeOperation> = not_match_target_operations
@@ -1033,11 +1088,12 @@ async fn execute(
10331088

10341089
let mut when_expr = Vec::with_capacity(operations_size);
10351090
let mut then_expr = Vec::with_capacity(operations_size);
1036-
let mut ops = Vec::with_capacity(operations_size);
1091+
let mut ops: Vec<(HashMap<Column, Expr>, OperationType, CardinalityClass)> =
1092+
Vec::with_capacity(operations_size);
10371093

10381094
fn update_case(
10391095
operations: Vec<MergeOperation>,
1040-
ops: &mut Vec<(HashMap<Column, Expr>, OperationType)>,
1096+
ops: &mut Vec<(HashMap<Column, Expr>, OperationType, CardinalityClass)>,
10411097
when_expr: &mut Vec<Expr>,
10421098
then_expr: &mut Vec<Expr>,
10431099
base_expr: &Expr,
@@ -1053,7 +1109,7 @@ async fn execute(
10531109
when_expr.push(predicate);
10541110
then_expr.push(lit(ops.len() as i32));
10551111

1056-
ops.push((op.operations, op.r#type));
1112+
ops.push((op.operations, op.r#type, op.cardinality_class));
10571113

10581114
let action_type = match op.r#type {
10591115
OperationType::Update => "update",
@@ -1107,15 +1163,27 @@ async fn execute(
11071163

11081164
when_expr.push(matched);
11091165
then_expr.push(lit(ops.len() as i32));
1110-
ops.push((HashMap::new(), OperationType::Copy));
1166+
ops.push((
1167+
HashMap::new(),
1168+
OperationType::Copy,
1169+
CardinalityClass::DuplicateInvalidating,
1170+
));
11111171

11121172
when_expr.push(not_matched_target);
11131173
then_expr.push(lit(ops.len() as i32));
1114-
ops.push((HashMap::new(), OperationType::SourceDelete));
1174+
ops.push((
1175+
HashMap::new(),
1176+
OperationType::SourceDelete,
1177+
CardinalityClass::Ignore,
1178+
));
11151179

11161180
when_expr.push(not_matched_source);
11171181
then_expr.push(lit(ops.len() as i32));
1118-
ops.push((HashMap::new(), OperationType::Copy));
1182+
ops.push((
1183+
HashMap::new(),
1184+
OperationType::Copy,
1185+
CardinalityClass::Ignore,
1186+
));
11191187

11201188
let case = CaseBuilder::new(None, when_expr, then_expr, None).end()?;
11211189

@@ -1171,8 +1239,8 @@ async fn execute(
11711239
Column::new(source_qualifier.clone(), name)
11721240
};
11731241

1174-
for (idx, (operations, _)) in ops.iter().enumerate() {
1175-
let op = operations
1242+
for (idx, (operations, _, _)) in ops.iter().enumerate() {
1243+
let op: Expr = operations
11761244
.get(&column)
11771245
.map(|expr| expr.to_owned())
11781246
.unwrap_or_else(|| col(column.clone()));
@@ -1230,7 +1298,7 @@ async fn execute(
12301298
let mut copy_when = Vec::with_capacity(ops.len());
12311299
let mut copy_then = Vec::with_capacity(ops.len());
12321300

1233-
for (idx, (_operations, r#type)) in ops.iter().enumerate() {
1301+
for (idx, (_operations, r#type, _)) in ops.iter().enumerate() {
12341302
let op = idx as i32;
12351303

12361304
// Used to indicate the record should be dropped prior to write
@@ -1323,6 +1391,37 @@ async fn execute(
13231391
LogicalPlanBuilder::from(plan).project(fields)?.build()?
13241392
};
13251393

1394+
let new_columns = if !match_operations.is_empty() {
1395+
let mut cardinality_when = Vec::with_capacity(ops.len());
1396+
let mut cardinality_then = Vec::with_capacity(ops.len());
1397+
1398+
for (idx, (_, _, cardinality_class)) in ops.iter().enumerate() {
1399+
cardinality_when.push(lit(idx as i32));
1400+
cardinality_then.push(lit(*cardinality_class as i32));
1401+
}
1402+
1403+
let cardinality_class = CaseBuilder::new(
1404+
Some(Box::new(col(OPERATION_COLUMN))),
1405+
cardinality_when,
1406+
cardinality_then,
1407+
Some(Box::new(lit(0))),
1408+
)
1409+
.end()?;
1410+
1411+
let new_columns = DataFrame::new(state.clone(), new_columns)
1412+
.with_column(TARGET_MATCH_CARDINALITY_CLASS_COLUMN, cardinality_class)?
1413+
.into_unoptimized_plan();
1414+
1415+
LogicalPlan::Extension(Extension {
1416+
node: Arc::new(MergeValidation {
1417+
input: new_columns,
1418+
expr: col(TARGET_ROW_INDEX_COLUMN),
1419+
}),
1420+
})
1421+
} else {
1422+
new_columns
1423+
};
1424+
13261425
let distribute_expr = col(file_column.as_str());
13271426

13281427
let merge_barrier = LogicalPlan::Extension(Extension {
@@ -3349,6 +3448,57 @@ mod tests {
33493448
assert!(res.is_err())
33503449
}
33513450

3451+
#[tokio::test]
3452+
async fn test_merge_update_multiple_source_match_error() {
3453+
let schema = get_arrow_schema(&None);
3454+
let table = setup_table(None).await;
3455+
let table = write_data(table, &schema).await;
3456+
let ctx = SessionContext::new();
3457+
let batch = RecordBatch::try_new(
3458+
Arc::clone(&schema),
3459+
vec![
3460+
Arc::new(arrow::array::StringArray::from(vec!["B", "B"])),
3461+
Arc::new(arrow::array::Int32Array::from(vec![11, 12])),
3462+
Arc::new(arrow::array::StringArray::from(vec![
3463+
"2023-07-04",
3464+
"2023-07-05",
3465+
])),
3466+
],
3467+
)
3468+
.unwrap();
3469+
let source = ctx.read_batch(batch).unwrap();
3470+
3471+
let expected = vec![
3472+
"+----+-------+------------+",
3473+
"| id | value | modified |",
3474+
"+----+-------+------------+",
3475+
"| A | 1 | 2021-02-01 |",
3476+
"| B | 10 | 2021-02-01 |",
3477+
"| C | 10 | 2021-02-02 |",
3478+
"| D | 100 | 2021-02-02 |",
3479+
"+----+-------+------------+",
3480+
];
3481+
3482+
let res = table
3483+
.clone()
3484+
.merge(source, "target.id = source.id")
3485+
.with_source_alias("source")
3486+
.with_target_alias("target")
3487+
.when_matched_update(|update| {
3488+
update
3489+
.update("value", "source.value")
3490+
.update("modified", "source.modified")
3491+
})
3492+
.unwrap()
3493+
.await;
3494+
3495+
assert!(res.is_err());
3496+
assert_eq!(table.version(), Some(1));
3497+
3498+
let actual = get_data(&table).await;
3499+
assert_batches_sorted_eq!(&expected, &actual);
3500+
}
3501+
33523502
#[tokio::test]
33533503
async fn test_merge_partitions() {
33543504
/* Validate the join predicate works with table partitions */

0 commit comments

Comments
 (0)