@@ -26,23 +26,28 @@ use crate::logical_plan::{
2626} ;
2727use crate :: physical_plan:: coalesce_batches:: concat_batches;
2828use crate :: physical_plan:: expressions:: PhysicalSortExpr ;
29- use crate :: physical_plan:: hash_aggregate:: { append_value, create_builder} ;
29+ use crate :: physical_plan:: group_scalar:: GroupByScalar ;
30+ use crate :: physical_plan:: hash_aggregate:: {
31+ append_value, create_accumulators, create_builder, create_group_by_value,
32+ } ;
3033use crate :: physical_plan:: planner:: ExtensionPlanner ;
3134use crate :: physical_plan:: sort:: SortExec ;
3235use crate :: physical_plan:: {
3336 collect, AggregateExpr , ColumnarValue , Distribution , ExecutionPlan , Partitioning ,
34- PhysicalPlanner , SendableRecordBatchStream ,
37+ PhysicalExpr , PhysicalPlanner , SendableRecordBatchStream ,
3538} ;
3639use crate :: scalar:: ScalarValue ;
37- use arrow:: array:: { make_array, BooleanBuilder , MutableArrayData } ;
40+ use arrow:: array:: { make_array, ArrayRef , BooleanBuilder , MutableArrayData , UInt64Array } ;
3841use arrow:: compute:: filter;
3942use arrow:: datatypes:: { DataType , Schema , SchemaRef } ;
4043use arrow:: record_batch:: RecordBatch ;
4144use async_trait:: async_trait;
4245use chrono:: { TimeZone , Utc } ;
46+ use hashbrown:: HashMap ;
4347use itertools:: Itertools ;
4448use std:: any:: Any ;
4549use std:: cmp:: { max, Ordering } ;
50+ use std:: convert:: TryFrom ;
4651use std:: sync:: Arc ;
4752
4853#[ derive( Debug ) ]
@@ -55,6 +60,8 @@ pub struct RollingWindowAggregate {
5560 pub every : Expr ,
5661 pub partition_by : Vec < Column > ,
5762 pub rolling_aggs : Vec < Expr > ,
63+ pub group_by_dimension : Option < Expr > ,
64+ pub aggs : Vec < Expr > ,
5865}
5966
6067impl UserDefinedLogicalNode for RollingWindowAggregate {
@@ -79,6 +86,10 @@ impl UserDefinedLogicalNode for RollingWindowAggregate {
7986 ] ;
8087 e. extend ( self . partition_by . iter ( ) . map ( |c| Expr :: Column ( c. clone ( ) ) ) ) ;
8188 e. extend_from_slice ( self . rolling_aggs . as_slice ( ) ) ;
89+ e. extend_from_slice ( self . aggs . as_slice ( ) ) ;
90+ if let Some ( d) = & self . group_by_dimension {
91+ e. push ( d. clone ( ) ) ;
92+ }
8293 e
8394 }
8495
@@ -96,7 +107,13 @@ impl UserDefinedLogicalNode for RollingWindowAggregate {
96107 inputs : & [ LogicalPlan ] ,
97108 ) -> Arc < dyn UserDefinedLogicalNode + Send + Sync > {
98109 assert_eq ! ( inputs. len( ) , 1 ) ;
99- assert ! ( 4 + self . partition_by. len( ) <= exprs. len( ) ) ;
110+ assert_eq ! (
111+ exprs. len( ) ,
112+ 4 + self . partition_by. len( )
113+ + self . rolling_aggs. len( )
114+ + self . aggs. len( )
115+ + self . group_by_dimension. as_ref( ) . map( |_| 1 ) . unwrap_or( 0 )
116+ ) ;
100117 let input = inputs[ 0 ] . clone ( ) ;
101118 let dimension = match & exprs[ 0 ] {
102119 Expr :: Column ( c) => c. clone ( ) ,
@@ -105,14 +122,30 @@ impl UserDefinedLogicalNode for RollingWindowAggregate {
105122 let from = exprs[ 1 ] . clone ( ) ;
106123 let to = exprs[ 2 ] . clone ( ) ;
107124 let every = exprs[ 3 ] . clone ( ) ;
108- let partition_by = exprs[ 4 ..4 + self . partition_by . len ( ) ]
125+ let exprs = & exprs[ 4 ..] ;
126+ let partition_by = exprs[ ..self . partition_by . len ( ) ]
109127 . iter ( )
110128 . map ( |c| match c {
111129 Expr :: Column ( c) => c. clone ( ) ,
112130 o => panic ! ( "Expected column for partition_by, got {:?}" , o) ,
113131 } )
114132 . collect_vec ( ) ;
115- let rolling_aggs = exprs[ 4 + self . partition_by . len ( ) ..] . to_vec ( ) ;
133+ let exprs = & exprs[ self . partition_by . len ( ) ..] ;
134+
135+ let rolling_aggs = exprs[ ..self . rolling_aggs . len ( ) ] . to_vec ( ) ;
136+ let exprs = & exprs[ self . rolling_aggs . len ( ) ..] ;
137+
138+ let aggs = exprs[ ..self . aggs . len ( ) ] . to_vec ( ) ;
139+ let exprs = & exprs[ self . aggs . len ( ) ..] ;
140+
141+ let group_by_dimension = if self . group_by_dimension . is_some ( ) {
142+ debug_assert_eq ! ( exprs. len( ) , 1 ) ;
143+ Some ( exprs[ 0 ] . clone ( ) )
144+ } else {
145+ debug_assert_eq ! ( exprs. len( ) , 0 ) ;
146+ None
147+ } ;
148+
116149 Arc :: new ( RollingWindowAggregate {
117150 schema : self . schema . clone ( ) ,
118151 input,
@@ -122,6 +155,8 @@ impl UserDefinedLogicalNode for RollingWindowAggregate {
122155 every,
123156 partition_by,
124157 rolling_aggs,
158+ group_by_dimension,
159+ aggs,
125160 } )
126161 }
127162}
@@ -211,6 +246,21 @@ impl ExtensionPlanner for Planner {
211246 } )
212247 . collect :: < Result < Vec < _ > , _ > > ( ) ?;
213248
249+ let group_by_dimension = node
250+ . group_by_dimension
251+ . as_ref ( )
252+ . map ( |d| {
253+ planner. create_physical_expr ( d, input_dfschema, & input_schema, ctx_state)
254+ } )
255+ . transpose ( ) ?;
256+ let aggs = node
257+ . aggs
258+ . iter ( )
259+ . map ( |a| {
260+ planner. create_aggregate_expr ( a, input_dfschema, & input_schema, ctx_state)
261+ } )
262+ . collect :: < Result < _ , _ > > ( ) ?;
263+
214264 // TODO: filter inputs by date.
215265 // Do preliminary sorting.
216266 let mut sort_key = Vec :: with_capacity ( input_schema. fields ( ) . len ( ) ) ;
@@ -229,6 +279,7 @@ impl ExtensionPlanner for Planner {
229279 } ) ;
230280
231281 let sort = Arc :: new ( SortExec :: try_new ( sort_key, input. clone ( ) ) ?) ;
282+
232283 let schema = node. schema . to_schema_ref ( ) ;
233284
234285 Ok ( Some ( Arc :: new ( RollingWindowAggExec {
@@ -237,6 +288,8 @@ impl ExtensionPlanner for Planner {
237288 group_key,
238289 rolling_aggs,
239290 dimension,
291+ group_by_dimension,
292+ aggs,
240293 from,
241294 to,
242295 every,
@@ -297,6 +350,8 @@ pub struct RollingWindowAggExec {
297350 pub group_key : Vec < crate :: physical_plan:: expressions:: Column > ,
298351 pub rolling_aggs : Vec < RollingAgg > ,
299352 pub dimension : crate :: physical_plan:: expressions:: Column ,
353+ pub group_by_dimension : Option < Arc < dyn PhysicalExpr > > ,
354+ pub aggs : Vec < Arc < dyn AggregateExpr > > ,
300355 pub from : ScalarValue ,
301356 pub to : ScalarValue ,
302357 pub every : ScalarValue ,
@@ -335,6 +390,8 @@ impl ExecutionPlan for RollingWindowAggExec {
335390 group_key : self . group_key . clone ( ) ,
336391 rolling_aggs : self . rolling_aggs . clone ( ) ,
337392 dimension : self . dimension . clone ( ) ,
393+ group_by_dimension : self . group_by_dimension . clone ( ) ,
394+ aggs : self . aggs . clone ( ) ,
338395 from : self . from . clone ( ) ,
339396 to : self . to . clone ( ) ,
340397 every : self . every . clone ( ) ,
@@ -357,6 +414,7 @@ impl ExecutionPlan for RollingWindowAggExec {
357414 . iter ( )
358415 . map ( |c| input. columns ( ) [ c. index ( ) ] . clone ( ) )
359416 . collect_vec ( ) ;
417+
360418 let other_cols = input
361419 . columns ( )
362420 . iter ( )
@@ -374,15 +432,7 @@ impl ExecutionPlan for RollingWindowAggExec {
374432 let agg_inputs = self
375433 . rolling_aggs
376434 . iter ( )
377- . map ( |r| {
378- r. agg
379- . expressions ( )
380- . iter ( )
381- . map ( |e| -> Result < _ , DataFusionError > {
382- Ok ( e. evaluate ( & input) ?. into_array ( num_rows) )
383- } )
384- . collect :: < Result < Vec < _ > , _ > > ( )
385- } )
435+ . map ( |r| compute_agg_inputs ( r. agg . as_ref ( ) , & input) )
386436 . collect :: < Result < Vec < _ > , _ > > ( ) ?;
387437 let mut accumulators = self
388438 . rolling_aggs
@@ -396,6 +446,19 @@ impl ExecutionPlan for RollingWindowAggExec {
396446 dimension = arrow:: compute:: cast ( & dimension, & dim_iter_type) ?;
397447 }
398448
449+ let extra_aggs_dimension = self
450+ . group_by_dimension
451+ . as_ref ( )
452+ . map ( |d| -> Result < _ , DataFusionError > {
453+ Ok ( d. evaluate ( & input) ?. into_array ( num_rows) )
454+ } )
455+ . transpose ( ) ?;
456+ let extra_aggs_inputs = self
457+ . aggs
458+ . iter ( )
459+ . map ( |a| compute_agg_inputs ( a. as_ref ( ) , & input) )
460+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
461+
399462 let mut out_dim = create_builder ( & self . from ) ;
400463 let mut out_keys = key_cols
401464 . iter ( )
@@ -404,6 +467,12 @@ impl ExecutionPlan for RollingWindowAggExec {
404467 let mut out_aggs = Vec :: with_capacity ( self . rolling_aggs . len ( ) ) ;
405468 // This filter must be applied prior to returning the values.
406469 let mut out_aggs_keep = BooleanBuilder :: new ( 0 ) ;
470+ let extra_agg_nulls = self
471+ . aggs
472+ . iter ( )
473+ . map ( |a| ScalarValue :: try_from ( a. field ( ) ?. data_type ( ) ) )
474+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
475+ let mut out_extra_aggs = extra_agg_nulls. iter ( ) . map ( create_builder) . collect_vec ( ) ;
407476 let mut out_other = other_cols
408477 . iter ( )
409478 . map ( |c| MutableArrayData :: new ( vec ! [ c. data( ) ] , true , 0 ) )
@@ -491,6 +560,32 @@ impl ExecutionPlan for RollingWindowAggExec {
491560 }
492561 }
493562
563+ // Compute non-rolling aggregates for the group.
564+ let mut dim_to_extra_aggs = HashMap :: new ( ) ;
565+ if let Some ( key) = & extra_aggs_dimension {
566+ let mut key_to_rows = HashMap :: new ( ) ;
567+ for i in group_start..group_end {
568+ let key = create_group_by_value ( key, i) ?;
569+ key_to_rows. entry ( key) . or_insert ( Vec :: new ( ) ) . push ( i as u64 ) ;
570+ }
571+
572+ for ( k, rows) in key_to_rows {
573+ let mut accumulators = create_accumulators ( & self . aggs ) ?;
574+ let rows = UInt64Array :: from ( rows) ;
575+ let mut values = Vec :: with_capacity ( accumulators. len ( ) ) ;
576+ for i in 0 ..accumulators. len ( ) {
577+ let accum_inputs = extra_aggs_inputs[ i]
578+ . iter ( )
579+ . map ( |a| arrow:: compute:: take ( a. as_ref ( ) , & rows, None ) )
580+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
581+ accumulators[ i] . update_batch ( & accum_inputs) ?;
582+ values. push ( accumulators[ i] . evaluate ( ) ?) ;
583+ }
584+
585+ dim_to_extra_aggs. insert ( k, values) ;
586+ }
587+ }
588+
494589 // Add keys, dimension and non-aggregate columns to the output.
495590 let mut d = self . from . clone ( ) ;
496591 let mut d_iter = 0 ;
@@ -509,6 +604,19 @@ impl ExecutionPlan for RollingWindowAggExec {
509604 for i in 0 ..key_cols. len ( ) {
510605 out_keys[ i] . extend ( 0 , group_start, group_start + 1 )
511606 }
607+ // Add aggregates.
608+ match dim_to_extra_aggs. get ( & GroupByScalar :: try_from ( & d) ?) {
609+ Some ( aggs) => {
610+ for i in 0 ..out_extra_aggs. len ( ) {
611+ append_value ( out_extra_aggs[ i] . as_mut ( ) , & aggs[ i] ) ?
612+ }
613+ }
614+ None => {
615+ for i in 0 ..out_extra_aggs. len ( ) {
616+ append_value ( out_extra_aggs[ i] . as_mut ( ) , & extra_agg_nulls[ i] ) ?
617+ }
618+ }
619+ }
512620 // Find the matching row to add other columns.
513621 while matching_row_lower_bound < group_end
514622 && cmp_same_types (
@@ -590,10 +698,16 @@ impl ExecutionPlan for RollingWindowAggExec {
590698 for o in out_other {
591699 r. push ( make_array ( o. freeze ( ) ) ) ;
592700 }
701+
593702 let out_aggs_keep = out_aggs_keep. finish ( ) ;
594703 for mut a in out_aggs {
595704 r. push ( filter ( a. finish ( ) . as_ref ( ) , & out_aggs_keep) ?) ;
596705 }
706+
707+ for mut a in out_extra_aggs {
708+ r. push ( a. finish ( ) )
709+ }
710+
597711 let r = RecordBatch :: try_new ( self . schema ( ) , r) ?;
598712 Ok ( Box :: pin ( StreamWithSchema :: wrap (
599713 self . schema ( ) ,
@@ -621,6 +735,18 @@ fn add_dim(l: &ScalarValue, r: &ScalarValue) -> ScalarValue {
621735 }
622736}
623737
738+ fn compute_agg_inputs (
739+ a : & dyn AggregateExpr ,
740+ input : & RecordBatch ,
741+ ) -> Result < Vec < ArrayRef > , DataFusionError > {
742+ a. expressions ( )
743+ . iter ( )
744+ . map ( |e| -> Result < _ , DataFusionError > {
745+ Ok ( e. evaluate ( input) ?. into_array ( input. num_rows ( ) ) )
746+ } )
747+ . collect ( )
748+ }
749+
624750fn meets_lower_bound (
625751 value : & ScalarValue ,
626752 current : & ScalarValue ,
0 commit comments