@@ -1036,6 +1036,34 @@ impl AggregateExec {
10361036 & self . input_order_mode
10371037 }
10381038
1039+ /// Returns the dynamic filter expression for this aggregate, if set.
1040+ pub fn dynamic_filter ( & self ) -> Option < & Arc < DynamicFilterPhysicalExpr > > {
1041+ self . dynamic_filter . as_ref ( ) . map ( |df| & df. filter )
1042+ }
1043+
1044+ /// Replace the dynamic filter expression, recomputing any internal state
1045+ /// which may depend on the previous dynamic filter.
1046+ ///
1047+ /// This is a no-op if the aggregate does not support dynamic filtering.
1048+ ///
1049+ /// If dynamic filtering is supported, this method returns an error if the filter's
1050+ /// children reference invalid columns in the aggregate's input schema.
1051+ pub fn with_dynamic_filter (
1052+ mut self ,
1053+ filter : Arc < DynamicFilterPhysicalExpr > ,
1054+ ) -> Result < Self > {
1055+ if let Some ( supported_accumulators_info) = self . supported_accumulators_info ( ) {
1056+ for child in filter. children ( ) {
1057+ child. data_type ( & self . input_schema ) ?;
1058+ }
1059+ self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
1060+ filter,
1061+ supported_accumulators_info,
1062+ } ) ) ;
1063+ }
1064+ Ok ( self )
1065+ }
1066+
10391067 fn statistics_inner ( & self , child_statistics : & Statistics ) -> Result < Statistics > {
10401068 // TODO stats: group expressions:
10411069 // - once expressions will be able to compute their own stats, use it here
@@ -1116,27 +1144,40 @@ impl AggregateExec {
11161144 /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field.
11171145 /// - If not supported, `self.dynamic_filter` should be kept `None`
11181146 fn init_dynamic_filter ( & mut self ) {
1119- if ( !self . group_by . is_empty ( ) ) || ( !matches ! ( self . mode, AggregateMode :: Partial ) ) {
1120- debug_assert ! (
1121- self . dynamic_filter. is_none( ) ,
1122- "The current operator node does not support dynamic filter"
1123- ) ;
1124- return ;
1125- }
1126-
11271147 // Already initialized.
11281148 if self . dynamic_filter . is_some ( ) {
11291149 return ;
11301150 }
11311151
1132- // Collect supported accumulators
1133- // It is assumed the order of aggregate expressions are not changed from `AggregateExec`
1134- // to `AggregateStream`
1152+ if let Some ( supported_accumulators_info) = self . supported_accumulators_info ( ) {
1153+ // Collect column references for the dynamic filter expression.
1154+ let all_cols: Vec < Arc < dyn PhysicalExpr > > = supported_accumulators_info
1155+ . iter ( )
1156+ . map ( |info| Arc :: clone ( & self . aggr_expr [ info. aggr_index ] . expressions ( ) [ 0 ] ) )
1157+ . collect ( ) ;
1158+
1159+ self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
1160+ filter : Arc :: new ( DynamicFilterPhysicalExpr :: new ( all_cols, lit ( true ) ) ) ,
1161+ supported_accumulators_info,
1162+ } ) ) ;
1163+ }
1164+ }
1165+
1166+ /// Returns the supported accumulator info if this aggregate supports
1167+ /// dynamic filtering, or `None` otherwise.
1168+ ///
1169+ /// Dynamic filtering requires:
1170+ /// - `Partial` aggregation mode with no group-by expressions
1171+ /// - All aggregate functions are `min` or `max` with a single column arg
1172+ fn supported_accumulators_info ( & self ) -> Option < Vec < PerAccumulatorDynFilter > > {
1173+ if !self . group_by . is_empty ( ) || !matches ! ( self . mode, AggregateMode :: Partial ) {
1174+ return None ;
1175+ }
1176+
1177+ // Collect supported accumulators.
1178+ // It is assumed the order of aggregate expressions are not changed
1179+ // from `AggregateExec` to `AggregateStream`.
11351180 let mut aggr_dyn_filters = Vec :: new ( ) ;
1136- // All column references in the dynamic filter, used when initializing the dynamic
1137- // filter, and it's used to decide if this dynamic filter is able to get push
1138- // through certain node during optimization.
1139- let mut all_cols: Vec < Arc < dyn PhysicalExpr > > = Vec :: new ( ) ;
11401181 for ( i, aggr_expr) in self . aggr_expr . iter ( ) . enumerate ( ) {
11411182 // 1. Only `min` or `max` aggregate function
11421183 let fun_name = aggr_expr. fun ( ) . name ( ) ;
@@ -1147,14 +1188,13 @@ impl AggregateExec {
11471188 } else if fun_name. eq_ignore_ascii_case ( "max" ) {
11481189 DynamicFilterAggregateType :: Max
11491190 } else {
1150- return ;
1191+ return None ;
11511192 } ;
11521193
11531194 // 2. arg should be only 1 column reference
11541195 if let [ arg] = aggr_expr. expressions ( ) . as_slice ( )
11551196 && arg. as_any ( ) . is :: < Column > ( )
11561197 {
1157- all_cols. push ( Arc :: clone ( arg) ) ;
11581198 aggr_dyn_filters. push ( PerAccumulatorDynFilter {
11591199 aggr_type,
11601200 aggr_index : i,
@@ -1163,11 +1203,10 @@ impl AggregateExec {
11631203 }
11641204 }
11651205
1166- if !aggr_dyn_filters. is_empty ( ) {
1167- self . dynamic_filter = Some ( Arc :: new ( AggrDynFilter {
1168- filter : Arc :: new ( DynamicFilterPhysicalExpr :: new ( all_cols, lit ( true ) ) ) ,
1169- supported_accumulators_info : aggr_dyn_filters,
1170- } ) )
1206+ if aggr_dyn_filters. is_empty ( ) {
1207+ None
1208+ } else {
1209+ Some ( aggr_dyn_filters)
11711210 }
11721211 }
11731212
@@ -1964,6 +2003,7 @@ mod tests {
19642003 use crate :: coalesce_partitions:: CoalescePartitionsExec ;
19652004 use crate :: common;
19662005 use crate :: common:: collect;
2006+ use crate :: empty:: EmptyExec ;
19672007 use crate :: execution_plan:: Boundedness ;
19682008 use crate :: expressions:: col;
19692009 use crate :: metrics:: MetricValue ;
@@ -1987,6 +2027,7 @@ mod tests {
19872027 use datafusion_functions_aggregate:: count:: count_udaf;
19882028 use datafusion_functions_aggregate:: first_last:: { first_value_udaf, last_value_udaf} ;
19892029 use datafusion_functions_aggregate:: median:: median_udaf;
2030+ use datafusion_functions_aggregate:: min_max:: min_udaf;
19902031 use datafusion_functions_aggregate:: sum:: sum_udaf;
19912032 use datafusion_physical_expr:: Partitioning ;
19922033 use datafusion_physical_expr:: PhysicalSortExpr ;
@@ -3459,13 +3500,10 @@ mod tests {
34593500 // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
34603501 let aggregates: Vec < Arc < AggregateFunctionExpr > > = vec ! [
34613502 Arc :: new(
3462- AggregateExprBuilder :: new(
3463- datafusion_functions_aggregate:: min_max:: min_udaf( ) ,
3464- vec![ col( "b" , & schema) ?] ,
3465- )
3466- . schema( Arc :: clone( & schema) )
3467- . alias( "MIN(b)" )
3468- . build( ) ?,
3503+ AggregateExprBuilder :: new( min_udaf( ) , vec![ col( "b" , & schema) ?] )
3504+ . schema( Arc :: clone( & schema) )
3505+ . alias( "MIN(b)" )
3506+ . build( ) ?,
34693507 ) ,
34703508 Arc :: new(
34713509 AggregateExprBuilder :: new( avg_udaf( ) , vec![ col( "b" , & schema) ?] )
@@ -3604,13 +3642,10 @@ mod tests {
36043642 // Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
36053643 let aggregates: Vec < Arc < AggregateFunctionExpr > > = vec ! [
36063644 Arc :: new(
3607- AggregateExprBuilder :: new(
3608- datafusion_functions_aggregate:: min_max:: min_udaf( ) ,
3609- vec![ col( "b" , & schema) ?] ,
3610- )
3611- . schema( Arc :: clone( & schema) )
3612- . alias( "MIN(b)" )
3613- . build( ) ?,
3645+ AggregateExprBuilder :: new( min_udaf( ) , vec![ col( "b" , & schema) ?] )
3646+ . schema( Arc :: clone( & schema) )
3647+ . alias( "MIN(b)" )
3648+ . build( ) ?,
36143649 ) ,
36153650 Arc :: new(
36163651 AggregateExprBuilder :: new( avg_udaf( ) , vec![ col( "b" , & schema) ?] )
@@ -3950,4 +3985,103 @@ mod tests {
39503985
39513986 Ok ( ( ) )
39523987 }
3988+
/// Replacing the dynamic filter on a supporting aggregate must install the
/// supplied filter (same inner id) in place of the one created at construction.
#[test]
fn test_with_dynamic_filter() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let input = Arc::new(EmptyExec::new(Arc::clone(&schema)));

    // A group-less `Partial` `min` aggregate: the `expect` below asserts that
    // it carries a dynamic filter right after construction.
    let min_a = AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("min_a")
        .build()?;
    let aggregate = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(vec![]),
        vec![Arc::new(min_a)],
        vec![None],
        input,
        Arc::clone(&schema),
    )?;
    let initial_id = aggregate
        .dynamic_filter()
        .expect("should have dynamic filter after init")
        .inner_id();

    let replacement = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![col("a", &schema)?],
        lit(true),
    ));
    let aggregate = aggregate.with_dynamic_filter(Arc::clone(&replacement))?;
    let installed = aggregate
        .dynamic_filter()
        .expect("should still have dynamic filter");

    assert_eq!(installed.inner_id(), replacement.inner_id());
    assert_ne!(installed.inner_id(), initial_id);
    Ok(())
}
4025+
/// `with_dynamic_filter` must leave an aggregate without dynamic filter
/// support untouched.
#[test]
fn test_with_dynamic_filter_noop_when_unsupported() -> Result<()> {
    let fields = vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int64, false),
    ];
    let schema = Arc::new(Schema::new(fields));
    let input = Arc::new(EmptyExec::new(Arc::clone(&schema)));

    // `Final` mode with a group-by: no dynamic filter is expected.
    let group_by = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]);
    let sum_b = AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("sum_b")
        .build()?;
    let aggregate = AggregateExec::try_new(
        AggregateMode::Final,
        group_by,
        vec![Arc::new(sum_b)],
        vec![None],
        input,
        Arc::clone(&schema),
    )?;
    assert!(aggregate.dynamic_filter().is_none());

    // Supplying a filter must not install one.
    let candidate = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![col("a", &schema)?],
        lit(true),
    ));
    let aggregate = aggregate.with_dynamic_filter(candidate)?;
    assert!(aggregate.dynamic_filter().is_none());
    Ok(())
}
4059+
/// `with_dynamic_filter` must fail when a child of the supplied filter
/// references a column outside the aggregate's input schema.
#[test]
fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
    let input = Arc::new(EmptyExec::new(Arc::clone(&schema)));

    let min_a = AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
        .schema(Arc::clone(&schema))
        .alias("min_a")
        .build()?;
    let aggregate = AggregateExec::try_new(
        AggregateMode::Partial,
        PhysicalGroupBy::new_single(vec![]),
        vec![Arc::new(min_a)],
        vec![None],
        input,
        Arc::clone(&schema),
    )?;

    // The schema has a single field, so column index 99 is out of bounds.
    let invalid_filter = Arc::new(DynamicFilterPhysicalExpr::new(
        vec![Arc::new(Column::new("bad", 99)) as _],
        lit(true),
    ));
    let outcome = aggregate.with_dynamic_filter(invalid_filter);
    assert!(outcome.is_err());
    Ok(())
}
39534087}
0 commit comments