Skip to content

Commit 0d075ba

Browse files
committed
support not list
1 parent a0116ce commit 0d075ba

2 files changed

Lines changed: 123 additions & 19 deletions

File tree

datafusion/physical-expr/src/simplifier/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> {
5858

5959
fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
6060
// Apply NOT expression simplification first
61-
let not_simplified = simplify_not_expr_recursive(&node, self.schema)?;
62-
let node = not_simplified.data;
63-
let transformed = not_simplified.transformed;
61+
let not_expr_simplified = simplify_not_expr_recursive(&node, self.schema)?;
62+
let node = not_expr_simplified.data;
63+
let transformed = not_expr_simplified.transformed;
6464

6565
// Apply unwrap cast optimization
6666
#[cfg(test)]

datafusion/physical-expr/src/simplifier/not.rs

Lines changed: 120 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use arrow::datatypes::Schema;
3030
use datafusion_common::{tree_node::Transformed, Result, ScalarValue};
3131
use datafusion_expr::Operator;
3232

33-
use crate::expressions::{lit, BinaryExpr, Literal, NotExpr};
33+
use crate::expressions::{in_list, lit, BinaryExpr, InListExpr, Literal, NotExpr};
3434
use crate::PhysicalExpr;
3535

3636
/// Attempts to simplify NOT expressions
@@ -64,6 +64,19 @@ pub(crate) fn simplify_not_expr(
6464
}
6565
}
6666

67+
// Handle NOT(IN list) -> NOT IN list
68+
if let Some(in_list_expr) = inner_expr.as_any().downcast_ref::<InListExpr>() {
69+
// Create a new InList expression with negated flag flipped
70+
let negated = !in_list_expr.negated();
71+
let new_in_list = in_list(
72+
Arc::clone(in_list_expr.expr()),
73+
in_list_expr.list().to_vec(),
74+
&negated,
75+
schema,
76+
)?;
77+
return Ok(Transformed::yes(new_in_list));
78+
}
79+
6780
// Handle NOT(binary_expr) where we can flip the operator
6881
if let Some(binary_expr) = inner_expr.as_any().downcast_ref::<BinaryExpr>() {
6982
if let Some(negated_op) = negate_operator(binary_expr.op()) {
@@ -186,7 +199,7 @@ fn negate_operator(op: &Operator) -> Option<Operator> {
186199
#[cfg(test)]
187200
mod tests {
188201
use super::*;
189-
use crate::expressions::{col, lit, BinaryExpr, NotExpr};
202+
use crate::expressions::{col, in_list, lit, BinaryExpr, NotExpr};
190203
use arrow::datatypes::{DataType, Field, Schema};
191204
use datafusion_common::ScalarValue;
192205
use datafusion_expr::Operator;
@@ -396,26 +409,117 @@ mod tests {
396409
assert_eq!(or_binary.op(), &Operator::Or, "Top level should be OR");
397410

398411
// Verify left side is just 'a'
412+
assert!(or_binary.left().as_any().downcast_ref::<NotExpr>().is_none(),
413+
"Left should not be a NOT expression, it should be simplified to just 'a'");
414+
415+
// Verify right side is just 'b'
416+
assert!(or_binary.right().as_any().downcast_ref::<NotExpr>().is_none(),
417+
"Right should not be a NOT expression, it should be simplified to just 'b'");
418+
} else {
419+
panic!("Expected binary OR expression result");
420+
}
421+
422+
Ok(())
423+
}
424+
425+
#[test]
426+
fn test_not_in_list() -> Result<()> {
427+
let schema = test_schema();
428+
429+
// NOT(b IN (1, 2, 3)) -> b NOT IN (1, 2, 3)
430+
let list = vec![
431+
lit(ScalarValue::Int32(Some(1))),
432+
lit(ScalarValue::Int32(Some(2))),
433+
lit(ScalarValue::Int32(Some(3))),
434+
];
435+
let in_list_expr = in_list(col("b", &schema)?, list, &false, &schema)?;
436+
let not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(in_list_expr));
437+
438+
let result = simplify_not_expr_recursive(&not_in, &schema)?;
439+
assert!(result.transformed, "Expression should be transformed");
440+
441+
// Verify the result is an InList expression with negated=true
442+
if let Some(in_list_result) = result.data.as_any().downcast_ref::<InListExpr>() {
399443
assert!(
400-
or_binary
401-
.left()
402-
.as_any()
403-
.downcast_ref::<NotExpr>()
404-
.is_none(),
405-
"Left should be simplified to just 'a'"
444+
in_list_result.negated(),
445+
"InList should be negated (NOT IN)"
446+
);
447+
assert_eq!(
448+
in_list_result.list().len(),
449+
3,
450+
"Should have 3 items in list"
406451
);
452+
} else {
453+
panic!("Expected InListExpr result");
454+
}
407455

408-
// Verify right side is just 'b'
456+
Ok(())
457+
}
458+
459+
#[test]
460+
fn test_not_not_in_list() -> Result<()> {
461+
let schema = test_schema();
462+
463+
// NOT(b NOT IN (1, 2, 3)) -> b IN (1, 2, 3)
464+
let list = vec![
465+
lit(ScalarValue::Int32(Some(1))),
466+
lit(ScalarValue::Int32(Some(2))),
467+
lit(ScalarValue::Int32(Some(3))),
468+
];
469+
let not_in_list_expr = in_list(col("b", &schema)?, list, &true, &schema)?;
470+
let not_not_in: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in_list_expr));
471+
472+
let result = simplify_not_expr_recursive(&not_not_in, &schema)?;
473+
assert!(result.transformed, "Expression should be transformed");
474+
475+
// Verify the result is an InList expression with negated=false
476+
if let Some(in_list_result) = result.data.as_any().downcast_ref::<InListExpr>() {
409477
assert!(
410-
or_binary
411-
.right()
412-
.as_any()
413-
.downcast_ref::<NotExpr>()
414-
.is_none(),
415-
"Right should be simplified to just 'b'"
478+
!in_list_result.negated(),
479+
"InList should not be negated (IN)"
480+
);
481+
assert_eq!(
482+
in_list_result.list().len(),
483+
3,
484+
"Should have 3 items in list"
416485
);
417486
} else {
418-
panic!("Expected binary OR expression result");
487+
panic!("Expected InListExpr result");
488+
}
489+
490+
Ok(())
491+
}
492+
493+
#[test]
494+
fn test_double_not_in_list() -> Result<()> {
495+
let schema = test_schema();
496+
497+
// NOT(NOT(b IN (1, 2, 3))) -> b IN (1, 2, 3)
498+
let list = vec![
499+
lit(ScalarValue::Int32(Some(1))),
500+
lit(ScalarValue::Int32(Some(2))),
501+
lit(ScalarValue::Int32(Some(3))),
502+
];
503+
let in_list_expr = in_list(col("b", &schema)?, list, &false, &schema)?;
504+
let not_in = Arc::new(NotExpr::new(in_list_expr));
505+
let double_not: Arc<dyn PhysicalExpr> = Arc::new(NotExpr::new(not_in));
506+
507+
let result = simplify_not_expr_recursive(&double_not, &schema)?;
508+
assert!(result.transformed, "Expression should be transformed");
509+
510+
// After double negation elimination, we should get back the original IN expression
511+
if let Some(in_list_result) = result.data.as_any().downcast_ref::<InListExpr>() {
512+
assert!(
513+
!in_list_result.negated(),
514+
"InList should not be negated (IN)"
515+
);
516+
assert_eq!(
517+
in_list_result.list().len(),
518+
3,
519+
"Should have 3 items in list"
520+
);
521+
} else {
522+
panic!("Expected InListExpr result");
419523
}
420524

421525
Ok(())

0 commit comments

Comments
 (0)