@@ -26,7 +26,7 @@ use std::sync::Arc;
2626use crate :: PhysicalOptimizerRule ;
2727
2828use datafusion_common:: config:: ConfigOptions ;
29- use datafusion_common:: tree_node:: { Transformed , TreeNode , TreeNodeRecursion } ;
29+ use datafusion_common:: tree_node:: { Transformed , TreeNode } ;
3030use datafusion_common:: Result ;
3131use datafusion_physical_plan:: coop:: CooperativeExec ;
3232use datafusion_physical_plan:: execution_plan:: { EvaluationType , SchedulingType } ;
@@ -67,23 +67,57 @@ impl PhysicalOptimizerRule for EnsureCooperative {
6767 plan : Arc < dyn ExecutionPlan > ,
6868 _config : & ConfigOptions ,
6969 ) -> Result < Arc < dyn ExecutionPlan > > {
70- plan. transform_up ( |plan| {
71- let is_leaf = plan. children ( ) . is_empty ( ) ;
72- let is_exchange = plan. properties ( ) . evaluation_type == EvaluationType :: Eager ;
73- if ( is_leaf || is_exchange)
74- && plan. properties ( ) . scheduling_type != SchedulingType :: Cooperative
75- {
76- // Wrap non-cooperative leaves or eager evaluation roots in a cooperative exec to
77- // ensure the plans they participate in are properly cooperative.
78- Ok ( Transformed :: new (
79- Arc :: new ( CooperativeExec :: new ( Arc :: clone ( & plan) ) ) ,
80- true ,
81- TreeNodeRecursion :: Continue ,
82- ) )
83- } else {
70+ use std:: cell:: RefCell ;
71+
72+ let ancestry_stack = RefCell :: new ( Vec :: < ( SchedulingType , EvaluationType ) > :: new ( ) ) ;
73+
74+ plan. transform_down_up (
75+ // Down phase: Push parent properties <SchedulingType, EvaluationType> into the stack
76+ |plan| {
77+ let props = plan. properties ( ) ;
78+ ancestry_stack
79+ . borrow_mut ( )
80+ . push ( ( props. scheduling_type , props. evaluation_type ) ) ;
8481 Ok ( Transformed :: no ( plan) )
85- }
86- } )
82+ } ,
83+ // Up phase: Wrap nodes with CooperativeExec if needed
84+ |plan| {
85+ ancestry_stack. borrow_mut ( ) . pop ( ) ;
86+
87+ let props = plan. properties ( ) ;
88+ let is_cooperative = props. scheduling_type == SchedulingType :: Cooperative ;
89+ let is_leaf = plan. children ( ) . is_empty ( ) ;
90+ let is_exchange = props. evaluation_type == EvaluationType :: Eager ;
91+
92+ let mut is_under_cooperative_context = false ;
93+ for ( scheduling_type, evaluation_type) in
94+ ancestry_stack. borrow ( ) . iter ( ) . rev ( )
95+ {
96+ // If nearest ancestor is cooperative, we are under a cooperative context
97+ if * scheduling_type == SchedulingType :: Cooperative {
98+ is_under_cooperative_context = true ;
99+ break ;
100+ // If nearest ancestor is eager, the cooperative context will be reset
101+ } else if * evaluation_type == EvaluationType :: Eager {
102+ is_under_cooperative_context = false ;
103+ break ;
104+ }
105+ }
106+
107+ // Wrap if:
108+ // 1. Node is a leaf or exchange point
109+ // 2. Node is not already cooperative
110+ // 3. Not under any Cooperative context
111+ if ( is_leaf || is_exchange)
112+ && !is_cooperative
113+ && !is_under_cooperative_context
114+ {
115+ return Ok ( Transformed :: yes ( Arc :: new ( CooperativeExec :: new ( plan) ) ) ) ;
116+ }
117+
118+ Ok ( Transformed :: no ( plan) )
119+ } ,
120+ )
87121 . map ( |t| t. data )
88122 }
89123
@@ -110,9 +144,269 @@ mod tests {
110144
111145 let display = displayable ( optimized. as_ref ( ) ) . indent ( true ) . to_string ( ) ;
112146 // Use insta snapshot to ensure full plan structure
113- assert_snapshot ! ( display, @r###"
114- CooperativeExec
115- DataSourceExec: partitions=1, partition_sizes=[1]
116- "### ) ;
147+ assert_snapshot ! ( display, @r"
148+ CooperativeExec
149+ DataSourceExec: partitions=1, partition_sizes=[1]
150+ " ) ;
151+ }
152+
153+ #[ tokio:: test]
154+ async fn test_optimizer_is_idempotent ( ) {
155+ // Comprehensive idempotency test: verify f(f(...f(x))) = f(x)
156+ // This test covers:
157+ // 1. Multiple runs on unwrapped plan
158+ // 2. Multiple runs on already-wrapped plan
159+ // 3. No accumulation of CooperativeExec nodes
160+
161+ let config = ConfigOptions :: new ( ) ;
162+ let rule = EnsureCooperative :: new ( ) ;
163+
164+ // Test 1: Start with unwrapped plan, run multiple times
165+ let unwrapped_plan = scan_partitioned ( 1 ) ;
166+ let mut current = unwrapped_plan;
167+ let mut stable_result = String :: new ( ) ;
168+
169+ for run in 1 ..=5 {
170+ current = rule. optimize ( current, & config) . unwrap ( ) ;
171+ let display = displayable ( current. as_ref ( ) ) . indent ( true ) . to_string ( ) ;
172+
173+ if run == 1 {
174+ stable_result = display. clone ( ) ;
175+ assert_eq ! ( display. matches( "CooperativeExec" ) . count( ) , 1 ) ;
176+ } else {
177+ assert_eq ! (
178+ display, stable_result,
179+ "Run {run} should match run 1 (idempotent)"
180+ ) ;
181+ assert_eq ! (
182+ display. matches( "CooperativeExec" ) . count( ) ,
183+ 1 ,
184+ "Should always have exactly 1 CooperativeExec, not accumulate"
185+ ) ;
186+ }
187+ }
188+
189+ // Test 2: Start with already-wrapped plan, verify no double wrapping
190+ let pre_wrapped = Arc :: new ( CooperativeExec :: new ( scan_partitioned ( 1 ) ) ) ;
191+ let result = rule. optimize ( pre_wrapped, & config) . unwrap ( ) ;
192+ let display = displayable ( result. as_ref ( ) ) . indent ( true ) . to_string ( ) ;
193+
194+ assert_eq ! (
195+ display. matches( "CooperativeExec" ) . count( ) ,
196+ 1 ,
197+ "Should not double-wrap already cooperative plans"
198+ ) ;
199+ assert_eq ! (
200+ display, stable_result,
201+ "Pre-wrapped plan should produce same result as unwrapped after optimization"
202+ ) ;
203+ }
204+
205+ #[ tokio:: test]
206+ async fn test_selective_wrapping ( ) {
207+ // Test that wrapping is selective: only leaf/eager nodes, not intermediate nodes
208+ // Also verify depth tracking prevents double wrapping in subtrees
209+ use datafusion_physical_expr:: expressions:: lit;
210+ use datafusion_physical_plan:: filter:: FilterExec ;
211+
212+ let config = ConfigOptions :: new ( ) ;
213+ let rule = EnsureCooperative :: new ( ) ;
214+
215+ // Case 1: Filter -> Scan (middle node should not be wrapped)
216+ let scan = scan_partitioned ( 1 ) ;
217+ let filter = Arc :: new ( FilterExec :: try_new ( lit ( true ) , scan) . unwrap ( ) ) ;
218+ let optimized = rule. optimize ( filter, & config) . unwrap ( ) ;
219+ let display = displayable ( optimized. as_ref ( ) ) . indent ( true ) . to_string ( ) ;
220+
221+ assert_eq ! ( display. matches( "CooperativeExec" ) . count( ) , 1 ) ;
222+ assert ! ( display. contains( "FilterExec" ) ) ;
223+
224+ // Case 2: Filter -> CoopExec -> Scan (depth tracking prevents double wrap)
225+ let scan2 = scan_partitioned ( 1 ) ;
226+ let wrapped_scan = Arc :: new ( CooperativeExec :: new ( scan2) ) ;
227+ let filter2 = Arc :: new ( FilterExec :: try_new ( lit ( true ) , wrapped_scan) . unwrap ( ) ) ;
228+ let optimized2 = rule. optimize ( filter2, & config) . unwrap ( ) ;
229+ let display2 = displayable ( optimized2. as_ref ( ) ) . indent ( true ) . to_string ( ) ;
230+
231+ assert_eq ! ( display2. matches( "CooperativeExec" ) . count( ) , 1 ) ;
232+ }
233+
234+ #[ tokio:: test]
235+ async fn test_multiple_leaf_nodes ( ) {
236+ // When there are multiple leaf nodes, each should be wrapped separately
237+ use datafusion_physical_plan:: union:: UnionExec ;
238+
239+ let scan1 = scan_partitioned ( 1 ) ;
240+ let scan2 = scan_partitioned ( 1 ) ;
241+ let union = UnionExec :: try_new ( vec ! [ scan1, scan2] ) . unwrap ( ) ;
242+
243+ let config = ConfigOptions :: new ( ) ;
244+ let optimized = EnsureCooperative :: new ( )
245+ . optimize ( union as Arc < dyn ExecutionPlan > , & config)
246+ . unwrap ( ) ;
247+
248+ let display = displayable ( optimized. as_ref ( ) ) . indent ( true ) . to_string ( ) ;
249+
250+ // Each leaf should have its own CooperativeExec
251+ assert_eq ! (
252+ display. matches( "CooperativeExec" ) . count( ) ,
253+ 2 ,
254+ "Each leaf node should be wrapped separately"
255+ ) ;
256+ assert_eq ! (
257+ display. matches( "DataSourceExec" ) . count( ) ,
258+ 2 ,
259+ "Both data sources should be present"
260+ ) ;
261+ }
262+
263+ #[ tokio:: test]
264+ async fn test_eager_evaluation_resets_cooperative_context ( ) {
265+ // Test that cooperative context is reset when encountering an eager evaluation boundary.
266+ use arrow:: datatypes:: Schema ;
267+ use datafusion_common:: { internal_err, Result } ;
268+ use datafusion_execution:: TaskContext ;
269+ use datafusion_physical_expr:: EquivalenceProperties ;
270+ use datafusion_physical_plan:: {
271+ execution_plan:: { Boundedness , EmissionType } ,
272+ DisplayAs , DisplayFormatType , Partitioning , PlanProperties ,
273+ SendableRecordBatchStream ,
274+ } ;
275+ use std:: any:: Any ;
276+ use std:: fmt:: Formatter ;
277+
278+ #[ derive( Debug ) ]
279+ struct DummyExec {
280+ name : String ,
281+ input : Arc < dyn ExecutionPlan > ,
282+ scheduling_type : SchedulingType ,
283+ evaluation_type : EvaluationType ,
284+ properties : PlanProperties ,
285+ }
286+
287+ impl DummyExec {
288+ fn new (
289+ name : & str ,
290+ input : Arc < dyn ExecutionPlan > ,
291+ scheduling_type : SchedulingType ,
292+ evaluation_type : EvaluationType ,
293+ ) -> Self {
294+ let properties = PlanProperties :: new (
295+ EquivalenceProperties :: new ( Arc :: new ( Schema :: empty ( ) ) ) ,
296+ Partitioning :: UnknownPartitioning ( 1 ) ,
297+ EmissionType :: Incremental ,
298+ Boundedness :: Bounded ,
299+ )
300+ . with_scheduling_type ( scheduling_type)
301+ . with_evaluation_type ( evaluation_type) ;
302+
303+ Self {
304+ name : name. to_string ( ) ,
305+ input,
306+ scheduling_type,
307+ evaluation_type,
308+ properties,
309+ }
310+ }
311+ }
312+
313+ impl DisplayAs for DummyExec {
314+ fn fmt_as (
315+ & self ,
316+ _: DisplayFormatType ,
317+ f : & mut Formatter ,
318+ ) -> std:: fmt:: Result {
319+ write ! ( f, "{}" , self . name)
320+ }
321+ }
322+
323+ impl ExecutionPlan for DummyExec {
324+ fn name ( & self ) -> & str {
325+ & self . name
326+ }
327+ fn as_any ( & self ) -> & dyn Any {
328+ self
329+ }
330+ fn properties ( & self ) -> & PlanProperties {
331+ & self . properties
332+ }
333+ fn children ( & self ) -> Vec < & Arc < dyn ExecutionPlan > > {
334+ vec ! [ & self . input]
335+ }
336+ fn with_new_children (
337+ self : Arc < Self > ,
338+ children : Vec < Arc < dyn ExecutionPlan > > ,
339+ ) -> Result < Arc < dyn ExecutionPlan > > {
340+ Ok ( Arc :: new ( DummyExec :: new (
341+ & self . name ,
342+ Arc :: clone ( & children[ 0 ] ) ,
343+ self . scheduling_type ,
344+ self . evaluation_type ,
345+ ) ) )
346+ }
347+ fn execute (
348+ & self ,
349+ _: usize ,
350+ _: Arc < TaskContext > ,
351+ ) -> Result < SendableRecordBatchStream > {
352+ internal_err ! ( "DummyExec does not support execution" )
353+ }
354+ }
355+
356+ // Build a plan similar to the original test:
357+ // scan -> exch1(NonCoop,Eager) -> CoopExec -> filter -> exch2(Coop,Eager) -> filter
358+ let scan = scan_partitioned ( 1 ) ;
359+ let exch1 = Arc :: new ( DummyExec :: new (
360+ "exch1" ,
361+ scan,
362+ SchedulingType :: NonCooperative ,
363+ EvaluationType :: Eager ,
364+ ) ) ;
365+ let coop = Arc :: new ( CooperativeExec :: new ( exch1) ) ;
366+ let filter1 = Arc :: new ( DummyExec :: new (
367+ "filter1" ,
368+ coop,
369+ SchedulingType :: NonCooperative ,
370+ EvaluationType :: Lazy ,
371+ ) ) ;
372+ let exch2 = Arc :: new ( DummyExec :: new (
373+ "exch2" ,
374+ filter1,
375+ SchedulingType :: Cooperative ,
376+ EvaluationType :: Eager ,
377+ ) ) ;
378+ let filter2 = Arc :: new ( DummyExec :: new (
379+ "filter2" ,
380+ exch2,
381+ SchedulingType :: NonCooperative ,
382+ EvaluationType :: Lazy ,
383+ ) ) ;
384+
385+ let config = ConfigOptions :: new ( ) ;
386+ let optimized = EnsureCooperative :: new ( ) . optimize ( filter2, & config) . unwrap ( ) ;
387+
388+ let display = displayable ( optimized. as_ref ( ) ) . indent ( true ) . to_string ( ) ;
389+
390+ // Expected wrapping:
391+ // - Scan (leaf) gets wrapped
392+ // - exch1 (eager+noncoop) keeps its manual CooperativeExec wrapper
393+ // - filter1 is protected by exch2's cooperative context, no extra wrap
394+ // - exch2 (already Cooperative) does NOT get wrapped
395+ // - filter2 (not leaf or eager) does NOT get wrapped
396+ assert_eq ! (
397+ display. matches( "CooperativeExec" ) . count( ) ,
398+ 2 ,
399+ "Should have 2 CooperativeExec: one wrapping scan, one wrapping exch1"
400+ ) ;
401+
402+ assert_snapshot ! ( display, @r"
403+ filter2
404+ exch2
405+ filter1
406+ CooperativeExec
407+ exch1
408+ CooperativeExec
409+ DataSourceExec: partitions=1, partition_sizes=[1]
410+ " ) ;
117411 }
118412}
0 commit comments