@@ -89,6 +89,7 @@ impl OptimizerRule for EliminateCrossJoin {
8989 let mut possible_join_keys = JoinKeySet :: new ( ) ;
9090 let mut all_inputs: Vec < LogicalPlan > = vec ! [ ] ;
9191 let mut all_filters: Vec < Expr > = vec ! [ ] ;
92+ let mut null_equals_null = false ;
9293
9394 let parent_predicate = if let LogicalPlan :: Filter ( filter) = plan {
9495 // if input isn't a join that can potentially be rewritten
@@ -113,6 +114,12 @@ impl OptimizerRule for EliminateCrossJoin {
113114 let Filter {
114115 input, predicate, ..
115116 } = filter;
117+
118+ // Extract null_equals_null setting from the input join
119+ if let LogicalPlan :: Join ( join) = input. as_ref ( ) {
120+ null_equals_null = join. null_equals_null ;
121+ }
122+
116123 flatten_join_inputs (
117124 Arc :: unwrap_or_clone ( input) ,
118125 & mut possible_join_keys,
@@ -122,26 +129,30 @@ impl OptimizerRule for EliminateCrossJoin {
122129
123130 extract_possible_join_keys ( & predicate, & mut possible_join_keys) ;
124131 Some ( predicate)
125- } else if matches ! (
126- plan,
127- LogicalPlan :: Join ( Join {
128- join_type: JoinType :: Inner ,
129- ..
130- } )
131- ) {
132- if !can_flatten_join_inputs ( & plan) {
133- return Ok ( Transformed :: no ( plan) ) ;
134- }
135- flatten_join_inputs (
136- plan,
137- & mut possible_join_keys,
138- & mut all_inputs,
139- & mut all_filters,
140- ) ?;
141- None
142132 } else {
143- // recursively try to rewrite children
144- return rewrite_children ( self , plan, config) ;
133+ match plan {
134+ LogicalPlan :: Join ( Join {
135+ join_type : JoinType :: Inner ,
136+ null_equals_null : original_null_equals_null,
137+ ..
138+ } ) => {
139+ if !can_flatten_join_inputs ( & plan) {
140+ return Ok ( Transformed :: no ( plan) ) ;
141+ }
142+ flatten_join_inputs (
143+ plan,
144+ & mut possible_join_keys,
145+ & mut all_inputs,
146+ & mut all_filters,
147+ ) ?;
148+ null_equals_null = original_null_equals_null;
149+ None
150+ }
151+ _ => {
152+ // recursively try to rewrite children
153+ return rewrite_children ( self , plan, config) ;
154+ }
155+ }
145156 } ;
146157
147158 // Join keys are handled locally:
@@ -153,6 +164,7 @@ impl OptimizerRule for EliminateCrossJoin {
153164 & mut all_inputs,
154165 & possible_join_keys,
155166 & mut all_join_keys,
167+ null_equals_null,
156168 ) ?;
157169 }
158170
@@ -290,6 +302,7 @@ fn find_inner_join(
290302 rights : & mut Vec < LogicalPlan > ,
291303 possible_join_keys : & JoinKeySet ,
292304 all_join_keys : & mut JoinKeySet ,
305+ null_equals_null : bool ,
293306) -> Result < LogicalPlan > {
294307 for ( i, right_input) in rights. iter ( ) . enumerate ( ) {
295308 let mut join_keys = vec ! [ ] ;
@@ -328,7 +341,7 @@ fn find_inner_join(
328341 on : join_keys,
329342 filter : None ,
330343 schema : join_schema,
331- null_equals_null : false ,
344+ null_equals_null,
332345 } ) ) ;
333346 }
334347 }
@@ -350,7 +363,7 @@ fn find_inner_join(
350363 filter : None ,
351364 join_type : JoinType :: Inner ,
352365 join_constraint : JoinConstraint :: On ,
353- null_equals_null : false ,
366+ null_equals_null,
354367 } ) )
355368}
356369
@@ -1333,4 +1346,69 @@ mod tests {
13331346 "
13341347 )
13351348 }
1349+
/// Checks that `EliminateCrossJoin` carries `null_equals_null: true` from the
/// original inner join over to every join present in the rewritten plan.
///
/// The test also requires that the rewritten plan still contains at least one
/// join, so it cannot pass vacuously on a plan with no `Join` nodes.
#[test]
fn preserve_null_equals_null_setting() -> Result<()> {
    /// Appends the `null_equals_null` flag of every `Join` found in `plan`
    /// (visiting the node itself, then its inputs) to `flags`.
    fn collect_join_flags(plan: &LogicalPlan, flags: &mut Vec<bool>) {
        if let LogicalPlan::Join(join) = plan {
            flags.push(join.null_equals_null);
        }
        for input in plan.inputs() {
            collect_join_flags(input, flags);
        }
    }

    let t1 = test_table_scan_with_name("t1")?;
    let t2 = test_table_scan_with_name("t2")?;

    // Build the inner join by hand so that `null_equals_null` can be set to
    // `true` explicitly.
    let join_schema = Arc::new(build_join_schema(
        t1.schema(),
        t2.schema(),
        &JoinType::Inner,
    )?);

    let inner_join = LogicalPlan::Join(Join {
        left: Arc::new(t1),
        right: Arc::new(t2),
        join_type: JoinType::Inner,
        join_constraint: JoinConstraint::On,
        on: vec![],
        filter: None,
        schema: join_schema,
        null_equals_null: true, // the setting whose preservation is under test
    });

    // The filter contains an equality between the two inputs (a candidate
    // join key) plus an unrelated single-table predicate.
    let plan = LogicalPlanBuilder::from(inner_join)
        .filter(binary_expr(
            col("t1.a").eq(col("t2.a")),
            And,
            col("t2.c").lt(lit(20u32)),
        ))?
        .build()?;

    let rule = EliminateCrossJoin::new();
    let optimized_plan = rule.rewrite(plan, &OptimizerContext::new())?.data;

    let mut join_flags = vec![];
    collect_join_flags(&optimized_plan, &mut join_flags);

    // Guard against a vacuous pass: with no joins, the `all` below is
    // trivially true.
    assert!(
        !join_flags.is_empty(),
        "optimized plan should still contain at least one join"
    );
    assert!(
        join_flags.iter().all(|&flag| flag),
        "null_equals_null setting should be preserved after optimization"
    );

    Ok(())
}
13361414}
0 commit comments