@@ -30,7 +30,7 @@ use arrow::datatypes::Schema;
3030use datafusion_common:: { tree_node:: Transformed , Result , ScalarValue } ;
3131use datafusion_expr:: Operator ;
3232
33- use crate :: expressions:: { lit, BinaryExpr , Literal , NotExpr } ;
33+ use crate :: expressions:: { in_list , lit, BinaryExpr , InListExpr , Literal , NotExpr } ;
3434use crate :: PhysicalExpr ;
3535
3636/// Attempts to simplify NOT expressions
@@ -64,6 +64,19 @@ pub(crate) fn simplify_not_expr(
6464 }
6565 }
6666
67+ // Handle NOT(IN list) -> NOT IN list
68+ if let Some ( in_list_expr) = inner_expr. as_any ( ) . downcast_ref :: < InListExpr > ( ) {
69+ // Create a new InList expression with negated flag flipped
70+ let negated = !in_list_expr. negated ( ) ;
71+ let new_in_list = in_list (
72+ Arc :: clone ( in_list_expr. expr ( ) ) ,
73+ in_list_expr. list ( ) . to_vec ( ) ,
74+ & negated,
75+ schema,
76+ ) ?;
77+ return Ok ( Transformed :: yes ( new_in_list) ) ;
78+ }
79+
6780 // Handle NOT(binary_expr) where we can flip the operator
6881 if let Some ( binary_expr) = inner_expr. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) {
6982 if let Some ( negated_op) = negate_operator ( binary_expr. op ( ) ) {
@@ -186,7 +199,7 @@ fn negate_operator(op: &Operator) -> Option<Operator> {
186199#[ cfg( test) ]
187200mod tests {
188201 use super :: * ;
189- use crate :: expressions:: { col, lit, BinaryExpr , NotExpr } ;
202+ use crate :: expressions:: { col, in_list , lit, BinaryExpr , NotExpr } ;
190203 use arrow:: datatypes:: { DataType , Field , Schema } ;
191204 use datafusion_common:: ScalarValue ;
192205 use datafusion_expr:: Operator ;
@@ -396,26 +409,117 @@ mod tests {
396409 assert_eq ! ( or_binary. op( ) , & Operator :: Or , "Top level should be OR" ) ;
397410
398411 // Verify left side is just 'a'
412+ assert ! ( or_binary. left( ) . as_any( ) . downcast_ref:: <NotExpr >( ) . is_none( ) ,
413+ "Left should not be a NOT expression, it should be simplified to just 'a'" ) ;
414+
415+ // Verify right side is just 'b'
416+ assert ! ( or_binary. right( ) . as_any( ) . downcast_ref:: <NotExpr >( ) . is_none( ) ,
417+ "Right should not be a NOT expression, it should be simplified to just 'b'" ) ;
418+ } else {
419+ panic ! ( "Expected binary OR expression result" ) ;
420+ }
421+
422+ Ok ( ( ) )
423+ }
424+
425+ #[ test]
426+ fn test_not_in_list ( ) -> Result < ( ) > {
427+ let schema = test_schema ( ) ;
428+
429+ // NOT(b IN (1, 2, 3)) -> b NOT IN (1, 2, 3)
430+ let list = vec ! [
431+ lit( ScalarValue :: Int32 ( Some ( 1 ) ) ) ,
432+ lit( ScalarValue :: Int32 ( Some ( 2 ) ) ) ,
433+ lit( ScalarValue :: Int32 ( Some ( 3 ) ) ) ,
434+ ] ;
435+ let in_list_expr = in_list ( col ( "b" , & schema) ?, list, & false , & schema) ?;
436+ let not_in: Arc < dyn PhysicalExpr > = Arc :: new ( NotExpr :: new ( in_list_expr) ) ;
437+
438+ let result = simplify_not_expr_recursive ( & not_in, & schema) ?;
439+ assert ! ( result. transformed, "Expression should be transformed" ) ;
440+
441+ // Verify the result is an InList expression with negated=true
442+ if let Some ( in_list_result) = result. data . as_any ( ) . downcast_ref :: < InListExpr > ( ) {
399443 assert ! (
400- or_binary
401- . left( )
402- . as_any( )
403- . downcast_ref:: <NotExpr >( )
404- . is_none( ) ,
405- "Left should be simplified to just 'a'"
444+ in_list_result. negated( ) ,
445+ "InList should be negated (NOT IN)"
446+ ) ;
447+ assert_eq ! (
448+ in_list_result. list( ) . len( ) ,
449+ 3 ,
450+ "Should have 3 items in list"
406451 ) ;
452+ } else {
453+ panic ! ( "Expected InListExpr result" ) ;
454+ }
407455
408- // Verify right side is just 'b'
456+ Ok ( ( ) )
457+ }
458+
459+ #[ test]
460+ fn test_not_not_in_list ( ) -> Result < ( ) > {
461+ let schema = test_schema ( ) ;
462+
463+ // NOT(b NOT IN (1, 2, 3)) -> b IN (1, 2, 3)
464+ let list = vec ! [
465+ lit( ScalarValue :: Int32 ( Some ( 1 ) ) ) ,
466+ lit( ScalarValue :: Int32 ( Some ( 2 ) ) ) ,
467+ lit( ScalarValue :: Int32 ( Some ( 3 ) ) ) ,
468+ ] ;
469+ let not_in_list_expr = in_list ( col ( "b" , & schema) ?, list, & true , & schema) ?;
470+ let not_not_in: Arc < dyn PhysicalExpr > = Arc :: new ( NotExpr :: new ( not_in_list_expr) ) ;
471+
472+ let result = simplify_not_expr_recursive ( & not_not_in, & schema) ?;
473+ assert ! ( result. transformed, "Expression should be transformed" ) ;
474+
475+ // Verify the result is an InList expression with negated=false
476+ if let Some ( in_list_result) = result. data . as_any ( ) . downcast_ref :: < InListExpr > ( ) {
409477 assert ! (
410- or_binary
411- . right( )
412- . as_any( )
413- . downcast_ref:: <NotExpr >( )
414- . is_none( ) ,
415- "Right should be simplified to just 'b'"
478+ !in_list_result. negated( ) ,
479+ "InList should not be negated (IN)"
480+ ) ;
481+ assert_eq ! (
482+ in_list_result. list( ) . len( ) ,
483+ 3 ,
484+ "Should have 3 items in list"
416485 ) ;
417486 } else {
418- panic ! ( "Expected binary OR expression result" ) ;
487+ panic ! ( "Expected InListExpr result" ) ;
488+ }
489+
490+ Ok ( ( ) )
491+ }
492+
493+ #[ test]
494+ fn test_double_not_in_list ( ) -> Result < ( ) > {
495+ let schema = test_schema ( ) ;
496+
497+ // NOT(NOT(b IN (1, 2, 3))) -> b IN (1, 2, 3)
498+ let list = vec ! [
499+ lit( ScalarValue :: Int32 ( Some ( 1 ) ) ) ,
500+ lit( ScalarValue :: Int32 ( Some ( 2 ) ) ) ,
501+ lit( ScalarValue :: Int32 ( Some ( 3 ) ) ) ,
502+ ] ;
503+ let in_list_expr = in_list ( col ( "b" , & schema) ?, list, & false , & schema) ?;
504+ let not_in = Arc :: new ( NotExpr :: new ( in_list_expr) ) ;
505+ let double_not: Arc < dyn PhysicalExpr > = Arc :: new ( NotExpr :: new ( not_in) ) ;
506+
507+ let result = simplify_not_expr_recursive ( & double_not, & schema) ?;
508+ assert ! ( result. transformed, "Expression should be transformed" ) ;
509+
510+ // After double negation elimination, we should get back the original IN expression
511+ if let Some ( in_list_result) = result. data . as_any ( ) . downcast_ref :: < InListExpr > ( ) {
512+ assert ! (
513+ !in_list_result. negated( ) ,
514+ "InList should not be negated (IN)"
515+ ) ;
516+ assert_eq ! (
517+ in_list_result. list( ) . len( ) ,
518+ 3 ,
519+ "Should have 3 items in list"
520+ ) ;
521+ } else {
522+ panic ! ( "Expected InListExpr result" ) ;
419523 }
420524
421525 Ok ( ( ) )
0 commit comments