Skip to content

Commit ceb85ff

Browse files
committed
replace recursive and add test
1 parent 0d075ba commit ceb85ff

2 files changed

Lines changed: 96 additions & 53 deletions

File tree

datafusion/physical-expr/src/simplifier/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use datafusion_common::{
2424
};
2525
use std::sync::Arc;
2626

27-
use crate::{simplifier::not::simplify_not_expr_recursive, PhysicalExpr};
27+
use crate::{simplifier::not::simplify_not_expr, PhysicalExpr};
2828

2929
pub mod not;
3030
pub mod unwrap_cast;
@@ -58,7 +58,7 @@ impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> {
5858

5959
fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
6060
// Apply NOT expression simplification first
61-
let not_expr_simplified = simplify_not_expr_recursive(&node, self.schema)?;
61+
let not_expr_simplified = simplify_not_expr(&node, self.schema)?;
6262
let node = not_expr_simplified.data;
6363
let transformed = not_expr_simplified.transformed;
6464

datafusion/physical-expr/src/simplifier/not.rs

Lines changed: 94 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::expressions::{in_list, lit, BinaryExpr, InListExpr, Literal, NotExpr}
3434
use crate::PhysicalExpr;
3535

3636
/// Attempts to simplify NOT expressions
37-
pub(crate) fn simplify_not_expr(
37+
pub(crate) fn simplify_not_expr_impl(
3838
expr: Arc<dyn PhysicalExpr>,
3939
schema: &Schema,
4040
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
@@ -48,10 +48,8 @@ pub(crate) fn simplify_not_expr(
4848

4949
// Handle NOT(NOT(expr)) -> expr (double negation elimination)
5050
if let Some(inner_not) = inner_expr.as_any().downcast_ref::<NotExpr>() {
51-
// Recursively simplify the inner expression
52-
let simplified = simplify_not_expr_recursive(inner_not.arg(), schema)?;
5351
// We eliminated double negation, so always return transformed=true
54-
return Ok(Transformed::yes(simplified.data));
52+
return Ok(Transformed::yes(Arc::clone(inner_not.arg())));
5553
}
5654

5755
// Handle NOT(literal) -> !literal
@@ -81,10 +79,8 @@ pub(crate) fn simplify_not_expr(
8179
if let Some(binary_expr) = inner_expr.as_any().downcast_ref::<BinaryExpr>() {
8280
if let Some(negated_op) = negate_operator(binary_expr.op()) {
8381
// Recursively simplify the left and right expressions first
84-
let left_simplified =
85-
simplify_not_expr_recursive(binary_expr.left(), schema)?;
86-
let right_simplified =
87-
simplify_not_expr_recursive(binary_expr.right(), schema)?;
82+
let left_simplified = simplify_not_expr(binary_expr.left(), schema)?;
83+
let right_simplified = simplify_not_expr(binary_expr.right(), schema)?;
8884

8985
let new_binary = Arc::new(BinaryExpr::new(
9086
left_simplified.data,
@@ -105,8 +101,8 @@ pub(crate) fn simplify_not_expr(
105101
Arc::new(NotExpr::new(Arc::clone(binary_expr.right())));
106102

107103
// Recursively simplify the NOT expressions
108-
let simplified_left = simplify_not_expr_recursive(&not_left, schema)?;
109-
let simplified_right = simplify_not_expr_recursive(&not_right, schema)?;
104+
let simplified_left = simplify_not_expr(&not_left, schema)?;
105+
let simplified_right = simplify_not_expr(&not_right, schema)?;
110106

111107
let new_binary = Arc::new(BinaryExpr::new(
112108
simplified_left.data,
@@ -123,8 +119,8 @@ pub(crate) fn simplify_not_expr(
123119
Arc::new(NotExpr::new(Arc::clone(binary_expr.right())));
124120

125121
// Recursively simplify the NOT expressions
126-
let simplified_left = simplify_not_expr_recursive(&not_left, schema)?;
127-
let simplified_right = simplify_not_expr_recursive(&not_right, schema)?;
122+
let simplified_left = simplify_not_expr(&not_left, schema)?;
123+
let simplified_right = simplify_not_expr(&not_right, schema)?;
128124

129125
let new_binary = Arc::new(BinaryExpr::new(
130126
simplified_left.data,
@@ -141,43 +137,43 @@ pub(crate) fn simplify_not_expr(
141137
Ok(Transformed::no(expr))
142138
}
143139

144-
/// Helper function that recursively simplifies expressions, including NOT expressions
145-
pub fn simplify_not_expr_recursive(
140+
/// Simplifies NOT expressions in `expr` using a loop rather than self-recursion
/// for repeated rewrites of the same node.
///
/// `simplify_not_expr_impl` is applied to the current expression for as long as
/// it reports a transformation. Once it reports none, and the current expression
/// is a `BinaryExpr`, both operands are simplified through this function; if
/// either operand changed, a new `BinaryExpr` with the same operator is returned
/// as `Transformed::yes`.
///
/// Otherwise the current expression is returned, marked `Transformed::yes` if any
/// loop iteration rewrote it and `Transformed::no` if nothing changed.
///
/// `schema` is passed through unchanged to `simplify_not_expr_impl` and to the
/// operand simplification calls.
///
/// # Errors
///
/// Propagates any error returned by `simplify_not_expr_impl` or by simplifying
/// the operands of a binary expression.
pub fn simplify_not_expr(
146141
expr: &Arc<dyn PhysicalExpr>,
147142
schema: &Schema,
148143
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
149-
// First, try to simplify any NOT expressions in this expression
150-
let not_simplified = simplify_not_expr(Arc::clone(expr), schema)?;
151-
152-
// If the expression was transformed, we might have created new opportunities for simplification
153-
if not_simplified.transformed {
154-
// Recursively simplify the result
155-
let further_simplified =
156-
simplify_not_expr_recursive(&not_simplified.data, schema)?;
157-
if further_simplified.transformed {
158-
return Ok(Transformed::yes(further_simplified.data));
159-
} else {
160-
return Ok(not_simplified);
144+
// Working copy of the expression; replaced every time a rewrite applies.
let mut current_expr = Arc::clone(expr);
145+
// Remembers whether any loop iteration rewrote the expression.
let mut overall_transformed = false;
146+
147+
loop {
148+
// Re-run the NOT rewrite on the latest result until it reports no change.
let not_simplified = simplify_not_expr_impl(Arc::clone(&current_expr), schema)?;
149+
if not_simplified.transformed {
150+
overall_transformed = true;
151+
current_expr = not_simplified.data;
152+
continue;
161153
}
162-
}
163154

164-
// If this expression wasn't a NOT expression, try to simplify its children
165-
// This handles cases where NOT expressions might be nested deeper in the tree
166-
if let Some(binary_expr) = expr.as_any().downcast_ref::<BinaryExpr>() {
167-
let left_simplified = simplify_not_expr_recursive(binary_expr.left(), schema)?;
168-
let right_simplified = simplify_not_expr_recursive(binary_expr.right(), schema)?;
155+
// No rewrite applied at this node: if it is a binary expression, simplify both operands.
if let Some(binary_expr) = current_expr.as_any().downcast_ref::<BinaryExpr>() {
156+
let left_simplified = simplify_not_expr(binary_expr.left(), schema)?;
157+
let right_simplified = simplify_not_expr(binary_expr.right(), schema)?;
169158

170-
if left_simplified.transformed || right_simplified.transformed {
171-
let new_binary = Arc::new(BinaryExpr::new(
172-
left_simplified.data,
173-
*binary_expr.op(),
174-
right_simplified.data,
175-
));
176-
return Ok(Transformed::yes(new_binary));
159+
if left_simplified.transformed || right_simplified.transformed {
160+
let new_binary = Arc::new(BinaryExpr::new(
161+
left_simplified.data,
162+
*binary_expr.op(),
163+
right_simplified.data,
164+
));
165+
// NOTE(review): the rebuilt node is returned directly; it is not passed
// through `simplify_not_expr_impl` again.
return Ok(Transformed::yes(new_binary));
166+
}
177167
}
168+
169+
// Nothing left to rewrite at this node or in its operands.
break;
178170
}
179171

180-
Ok(not_simplified)
172+
// Report `yes` only if some loop iteration actually rewrote the expression.
if overall_transformed {
173+
Ok(Transformed::yes(current_expr))
174+
} else {
175+
Ok(Transformed::no(current_expr))
176+
}
181177
}
182178

183179
/// Returns the negated version of a comparison operator, if possible
@@ -224,7 +220,7 @@ mod tests {
224220
let inner_not = Arc::new(NotExpr::new(Arc::clone(&inner_expr)));
225221
let double_not: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(inner_not));
226222

227-
let result = simplify_not_expr_recursive(&double_not, &schema)?;
223+
let result = simplify_not_expr(&double_not, &schema)?;
228224

229225
assert!(result.transformed);
230226
// Should be simplified back to the original b > 5
@@ -238,7 +234,7 @@ mod tests {
238234

239235
// NOT(TRUE) -> FALSE
240236
let not_true = Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(true)))));
241-
let result = simplify_not_expr(not_true, &schema)?;
237+
let result = simplify_not_expr_impl(not_true, &schema)?;
242238
assert!(result.transformed);
243239

244240
if let Some(literal) = result.data.as_any().downcast_ref::<Literal>() {
@@ -250,7 +246,7 @@ mod tests {
250246
// NOT(FALSE) -> TRUE
251247
let not_false: Arc<dyn PhysicalExpr> =
252248
Arc::new(NotExpr::new(lit(ScalarValue::Boolean(Some(false)))));
253-
let result = simplify_not_expr_recursive(&not_false, &schema)?;
249+
let result = simplify_not_expr(&not_false, &schema)?;
254250
assert!(result.transformed);
255251

256252
if let Some(literal) = result.data.as_any().downcast_ref::<Literal>() {
@@ -274,7 +270,7 @@ mod tests {
274270
));
275271
let not_eq: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(eq_expr));
276272

277-
let result = simplify_not_expr_recursive(&not_eq, &schema)?;
273+
let result = simplify_not_expr(&not_eq, &schema)?;
278274
assert!(result.transformed);
279275

280276
if let Some(binary) = result.data.as_any().downcast_ref::<BinaryExpr>() {
@@ -298,7 +294,7 @@ mod tests {
298294
));
299295
let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));
300296

301-
let result = simplify_not_expr_recursive(&not_and, &schema)?;
297+
let result = simplify_not_expr(&not_and, &schema)?;
302298
assert!(result.transformed);
303299

304300
if let Some(binary) = result.data.as_any().downcast_ref::<BinaryExpr>() {
@@ -325,7 +321,7 @@ mod tests {
325321
));
326322
let not_or: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(or_expr));
327323

328-
let result = simplify_not_expr_recursive(&not_or, &schema)?;
324+
let result = simplify_not_expr(&not_or, &schema)?;
329325
assert!(result.transformed);
330326

331327
if let Some(binary) = result.data.as_any().downcast_ref::<BinaryExpr>() {
@@ -359,7 +355,7 @@ mod tests {
359355
let and_expr = Arc::new(BinaryExpr::new(eq1, Operator::And, eq2));
360356
let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));
361357

362-
let result = simplify_not_expr_recursive(&not_and, &schema)?;
358+
let result = simplify_not_expr(&not_and, &schema)?;
363359
assert!(result.transformed, "Expression should be transformed");
364360

365361
// Verify the result is an OR expression
@@ -401,7 +397,7 @@ mod tests {
401397
let and_expr = Arc::new(BinaryExpr::new(not_a, Operator::And, not_b));
402398
let not_and: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(and_expr));
403399

404-
let result = simplify_not_expr_recursive(&not_and, &schema)?;
400+
let result = simplify_not_expr(&not_and, &schema)?;
405401
assert!(result.transformed, "Expression should be transformed");
406402

407403
// Verify the result is an OR expression
@@ -435,7 +431,7 @@ mod tests {
435431
let in_list_expr = in_list(col("b", &schema)?, list, &false, &schema)?;
436432
let not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(in_list_expr));
437433

438-
let result = simplify_not_expr_recursive(&not_in, &schema)?;
434+
let result = simplify_not_expr(&not_in, &schema)?;
439435
assert!(result.transformed, "Expression should be transformed");
440436

441437
// Verify the result is an InList expression with negated=true
@@ -469,7 +465,7 @@ mod tests {
469465
let not_in_list_expr = in_list(col("b", &schema)?, list, &true, &schema)?;
470466
let not_not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in_list_expr));
471467

472-
let result = simplify_not_expr_recursive(&not_not_in, &schema)?;
468+
let result = simplify_not_expr(&not_not_in, &schema)?;
473469
assert!(result.transformed, "Expression should be transformed");
474470

475471
// Verify the result is an InList expression with negated=false
@@ -504,7 +500,7 @@ mod tests {
504500
let not_in = Arc::new(NotExpr::new(in_list_expr));
505501
let double_not: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in));
506502

507-
let result = simplify_not_expr_recursive(&double_not, &schema)?;
503+
let result = simplify_not_expr(&double_not, &schema)?;
508504
assert!(result.transformed, "Expression should be transformed");
509505

510506
// After double negation elimination, we should get back the original IN expression
@@ -524,4 +520,51 @@ mod tests {
524520

525521
Ok(())
526522
}
523+
524+
// Checks that `simplify_not_expr` collapses a 20000-deep chain of NOTs around
// `b > 5` back to `b > 5`, then tears the input down layer by layer.
#[test]
525+
fn test_deeply_nested_not() -> Result<()> {
526+
let schema = test_schema();
527+
528+
// Build a deeply nested NOT chain: NOT(NOT(NOT(...NOT(b > 5)...))).
529+
// This checks that simplifying many nested NOTs does not overflow the stack.
530+
let inner_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
531+
col("b", &schema)?,
532+
Operator::Gt,
533+
lit(ScalarValue::Int32(Some(5))),
534+
));
535+
536+
let mut expr = Arc::clone(&inner_expr);
537+
// Wrap `b > 5` in 20000 layers of NOT.
538+
for _ in 0..20000 {
539+
expr = Arc::new(NotExpr::new(expr));
540+
}
541+
542+
let result = simplify_not_expr(&expr, &schema)?;
543+
544+
// 20000 is even, so the NOTs should cancel out and leave the original expression.
// The comparison is made on the expressions' `to_string()` output.
// NOTE(review): unlike the sibling tests, `result.transformed` is not asserted here.
545+
assert_eq!(
546+
result.data.to_string(),
547+
inner_expr.to_string(),
548+
"Should simplify back to original expression"
549+
);
550+
551+
// Manually dismantle the deep input expression to avoid a stack overflow on drop.
552+
// If `expr` simply went out of scope, the nested expressions would be dropped
// recursively, one layer inside the next, which would blow the stack.
553+
// Instead, peel off the layers one at a time.
554+
while let Some(not_expr) = expr.as_any().downcast_ref::<NotExpr>() {
555+
// Clone the child (increments its Arc refcount).
556+
// The child now has at least 2 refs: one held by the parent, one in `child`
// (the innermost expression is also still held by `inner_expr`).
557+
let child = Arc::clone(not_expr.arg());
558+
559+
// Reassign `expr` to `child`.
560+
// This drops the old `expr` (the parent).
561+
// The parent's refcount goes to 0, so the parent is dropped.
562+
// While being dropped, the parent releases its reference to the child.
563+
// The child's refcount is decremented by one (2 -> 1 for an intermediate layer).
564+
// The child is NOT dropped recursively because `expr` still holds it after the assignment.
565+
expr = child;
566+
}
567+
568+
Ok(())
569+
}
527570
}

0 commit comments

Comments
 (0)