Skip to content

Commit 4d5a0bd

Browse files
committed
Support aggregation in rolling window queries
The idea is to aggregate inside each matching partition and dimension. The `ROLLING_WINDOW` clause now has an optional `GROUP BY DIMENSION <expr>` argument. The corresponding expression is used both as a grouping key for non-rolling aggregates and as a "join" key to match rows to the rolling window output dimension.
1 parent 605c921 commit 4d5a0bd

4 files changed

Lines changed: 246 additions & 73 deletions

File tree

datafusion/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ ahash = "0.7"
4848
hashbrown = "0.11"
4949
arrow = { git = "https://github.com/cube-js/arrow-rs.git", branch = "cube", features = ["prettyprint"] }
5050
parquet = { git = "https://github.com/cube-js/arrow-rs.git", branch = "cube", features = ["arrow"] }
51-
sqlparser = { git = "https://github.com/cube-js/sqlparser-rs.git", rev = "6008dfab082a3455c54b023be878d92ec9acef43" }
51+
sqlparser = { git = "https://github.com/cube-js/sqlparser-rs.git", rev = "2fcd06f7354e8c85f170b49a08fc018749289a40" }
5252
paste = "^1.0"
5353
num_cpus = "1.13.0"
5454
chrono = "0.4"

datafusion/src/cube_ext/rolling.rs

Lines changed: 141 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,28 @@ use crate::logical_plan::{
2626
};
2727
use crate::physical_plan::coalesce_batches::concat_batches;
2828
use crate::physical_plan::expressions::PhysicalSortExpr;
29-
use crate::physical_plan::hash_aggregate::{append_value, create_builder};
29+
use crate::physical_plan::group_scalar::GroupByScalar;
30+
use crate::physical_plan::hash_aggregate::{
31+
append_value, create_accumulators, create_builder, create_group_by_value,
32+
};
3033
use crate::physical_plan::planner::ExtensionPlanner;
3134
use crate::physical_plan::sort::SortExec;
3235
use crate::physical_plan::{
3336
collect, AggregateExpr, ColumnarValue, Distribution, ExecutionPlan, Partitioning,
34-
PhysicalPlanner, SendableRecordBatchStream,
37+
PhysicalExpr, PhysicalPlanner, SendableRecordBatchStream,
3538
};
3639
use crate::scalar::ScalarValue;
37-
use arrow::array::{make_array, BooleanBuilder, MutableArrayData};
40+
use arrow::array::{make_array, ArrayRef, BooleanBuilder, MutableArrayData, UInt64Array};
3841
use arrow::compute::filter;
3942
use arrow::datatypes::{DataType, Schema, SchemaRef};
4043
use arrow::record_batch::RecordBatch;
4144
use async_trait::async_trait;
4245
use chrono::{TimeZone, Utc};
46+
use hashbrown::HashMap;
4347
use itertools::Itertools;
4448
use std::any::Any;
4549
use std::cmp::{max, Ordering};
50+
use std::convert::TryFrom;
4651
use std::sync::Arc;
4752

4853
#[derive(Debug)]
@@ -55,6 +60,8 @@ pub struct RollingWindowAggregate {
5560
pub every: Expr,
5661
pub partition_by: Vec<Column>,
5762
pub rolling_aggs: Vec<Expr>,
63+
pub group_by_dimension: Option<Expr>,
64+
pub aggs: Vec<Expr>,
5865
}
5966

6067
impl UserDefinedLogicalNode for RollingWindowAggregate {
@@ -79,6 +86,10 @@ impl UserDefinedLogicalNode for RollingWindowAggregate {
7986
];
8087
e.extend(self.partition_by.iter().map(|c| Expr::Column(c.clone())));
8188
e.extend_from_slice(self.rolling_aggs.as_slice());
89+
e.extend_from_slice(self.aggs.as_slice());
90+
if let Some(d) = &self.group_by_dimension {
91+
e.push(d.clone());
92+
}
8293
e
8394
}
8495

@@ -96,7 +107,13 @@ impl UserDefinedLogicalNode for RollingWindowAggregate {
96107
inputs: &[LogicalPlan],
97108
) -> Arc<dyn UserDefinedLogicalNode + Send + Sync> {
98109
assert_eq!(inputs.len(), 1);
99-
assert!(4 + self.partition_by.len() <= exprs.len());
110+
assert_eq!(
111+
exprs.len(),
112+
4 + self.partition_by.len()
113+
+ self.rolling_aggs.len()
114+
+ self.aggs.len()
115+
+ self.group_by_dimension.as_ref().map(|_| 1).unwrap_or(0)
116+
);
100117
let input = inputs[0].clone();
101118
let dimension = match &exprs[0] {
102119
Expr::Column(c) => c.clone(),
@@ -105,14 +122,30 @@ impl UserDefinedLogicalNode for RollingWindowAggregate {
105122
let from = exprs[1].clone();
106123
let to = exprs[2].clone();
107124
let every = exprs[3].clone();
108-
let partition_by = exprs[4..4 + self.partition_by.len()]
125+
let exprs = &exprs[4..];
126+
let partition_by = exprs[..self.partition_by.len()]
109127
.iter()
110128
.map(|c| match c {
111129
Expr::Column(c) => c.clone(),
112130
o => panic!("Expected column for partition_by, got {:?}", o),
113131
})
114132
.collect_vec();
115-
let rolling_aggs = exprs[4 + self.partition_by.len()..].to_vec();
133+
let exprs = &exprs[self.partition_by.len()..];
134+
135+
let rolling_aggs = exprs[..self.rolling_aggs.len()].to_vec();
136+
let exprs = &exprs[self.rolling_aggs.len()..];
137+
138+
let aggs = exprs[..self.aggs.len()].to_vec();
139+
let exprs = &exprs[self.aggs.len()..];
140+
141+
let group_by_dimension = if self.group_by_dimension.is_some() {
142+
debug_assert_eq!(exprs.len(), 1);
143+
Some(exprs[0].clone())
144+
} else {
145+
debug_assert_eq!(exprs.len(), 0);
146+
None
147+
};
148+
116149
Arc::new(RollingWindowAggregate {
117150
schema: self.schema.clone(),
118151
input,
@@ -122,6 +155,8 @@ impl UserDefinedLogicalNode for RollingWindowAggregate {
122155
every,
123156
partition_by,
124157
rolling_aggs,
158+
group_by_dimension,
159+
aggs,
125160
})
126161
}
127162
}
@@ -211,6 +246,21 @@ impl ExtensionPlanner for Planner {
211246
})
212247
.collect::<Result<Vec<_>, _>>()?;
213248

249+
let group_by_dimension = node
250+
.group_by_dimension
251+
.as_ref()
252+
.map(|d| {
253+
planner.create_physical_expr(d, input_dfschema, &input_schema, ctx_state)
254+
})
255+
.transpose()?;
256+
let aggs = node
257+
.aggs
258+
.iter()
259+
.map(|a| {
260+
planner.create_aggregate_expr(a, input_dfschema, &input_schema, ctx_state)
261+
})
262+
.collect::<Result<_, _>>()?;
263+
214264
// TODO: filter inputs by date.
215265
// Do preliminary sorting.
216266
let mut sort_key = Vec::with_capacity(input_schema.fields().len());
@@ -229,6 +279,7 @@ impl ExtensionPlanner for Planner {
229279
});
230280

231281
let sort = Arc::new(SortExec::try_new(sort_key, input.clone())?);
282+
232283
let schema = node.schema.to_schema_ref();
233284

234285
Ok(Some(Arc::new(RollingWindowAggExec {
@@ -237,6 +288,8 @@ impl ExtensionPlanner for Planner {
237288
group_key,
238289
rolling_aggs,
239290
dimension,
291+
group_by_dimension,
292+
aggs,
240293
from,
241294
to,
242295
every,
@@ -297,6 +350,8 @@ pub struct RollingWindowAggExec {
297350
pub group_key: Vec<crate::physical_plan::expressions::Column>,
298351
pub rolling_aggs: Vec<RollingAgg>,
299352
pub dimension: crate::physical_plan::expressions::Column,
353+
pub group_by_dimension: Option<Arc<dyn PhysicalExpr>>,
354+
pub aggs: Vec<Arc<dyn AggregateExpr>>,
300355
pub from: ScalarValue,
301356
pub to: ScalarValue,
302357
pub every: ScalarValue,
@@ -335,6 +390,8 @@ impl ExecutionPlan for RollingWindowAggExec {
335390
group_key: self.group_key.clone(),
336391
rolling_aggs: self.rolling_aggs.clone(),
337392
dimension: self.dimension.clone(),
393+
group_by_dimension: self.group_by_dimension.clone(),
394+
aggs: self.aggs.clone(),
338395
from: self.from.clone(),
339396
to: self.to.clone(),
340397
every: self.every.clone(),
@@ -357,6 +414,7 @@ impl ExecutionPlan for RollingWindowAggExec {
357414
.iter()
358415
.map(|c| input.columns()[c.index()].clone())
359416
.collect_vec();
417+
360418
let other_cols = input
361419
.columns()
362420
.iter()
@@ -374,15 +432,7 @@ impl ExecutionPlan for RollingWindowAggExec {
374432
let agg_inputs = self
375433
.rolling_aggs
376434
.iter()
377-
.map(|r| {
378-
r.agg
379-
.expressions()
380-
.iter()
381-
.map(|e| -> Result<_, DataFusionError> {
382-
Ok(e.evaluate(&input)?.into_array(num_rows))
383-
})
384-
.collect::<Result<Vec<_>, _>>()
385-
})
435+
.map(|r| compute_agg_inputs(r.agg.as_ref(), &input))
386436
.collect::<Result<Vec<_>, _>>()?;
387437
let mut accumulators = self
388438
.rolling_aggs
@@ -396,6 +446,19 @@ impl ExecutionPlan for RollingWindowAggExec {
396446
dimension = arrow::compute::cast(&dimension, &dim_iter_type)?;
397447
}
398448

449+
let extra_aggs_dimension = self
450+
.group_by_dimension
451+
.as_ref()
452+
.map(|d| -> Result<_, DataFusionError> {
453+
Ok(d.evaluate(&input)?.into_array(num_rows))
454+
})
455+
.transpose()?;
456+
let extra_aggs_inputs = self
457+
.aggs
458+
.iter()
459+
.map(|a| compute_agg_inputs(a.as_ref(), &input))
460+
.collect::<Result<Vec<_>, _>>()?;
461+
399462
let mut out_dim = create_builder(&self.from);
400463
let mut out_keys = key_cols
401464
.iter()
@@ -404,6 +467,12 @@ impl ExecutionPlan for RollingWindowAggExec {
404467
let mut out_aggs = Vec::with_capacity(self.rolling_aggs.len());
405468
// This filter must be applied prior to returning the values.
406469
let mut out_aggs_keep = BooleanBuilder::new(0);
470+
let extra_agg_nulls = self
471+
.aggs
472+
.iter()
473+
.map(|a| ScalarValue::try_from(a.field()?.data_type()))
474+
.collect::<Result<Vec<_>, _>>()?;
475+
let mut out_extra_aggs = extra_agg_nulls.iter().map(create_builder).collect_vec();
407476
let mut out_other = other_cols
408477
.iter()
409478
.map(|c| MutableArrayData::new(vec![c.data()], true, 0))
@@ -491,6 +560,32 @@ impl ExecutionPlan for RollingWindowAggExec {
491560
}
492561
}
493562

563+
// Compute non-rolling aggregates for the group.
564+
let mut dim_to_extra_aggs = HashMap::new();
565+
if let Some(key) = &extra_aggs_dimension {
566+
let mut key_to_rows = HashMap::new();
567+
for i in group_start..group_end {
568+
let key = create_group_by_value(key, i)?;
569+
key_to_rows.entry(key).or_insert(Vec::new()).push(i as u64);
570+
}
571+
572+
for (k, rows) in key_to_rows {
573+
let mut accumulators = create_accumulators(&self.aggs)?;
574+
let rows = UInt64Array::from(rows);
575+
let mut values = Vec::with_capacity(accumulators.len());
576+
for i in 0..accumulators.len() {
577+
let accum_inputs = extra_aggs_inputs[i]
578+
.iter()
579+
.map(|a| arrow::compute::take(a.as_ref(), &rows, None))
580+
.collect::<Result<Vec<_>, _>>()?;
581+
accumulators[i].update_batch(&accum_inputs)?;
582+
values.push(accumulators[i].evaluate()?);
583+
}
584+
585+
dim_to_extra_aggs.insert(k, values);
586+
}
587+
}
588+
494589
// Add keys, dimension and non-aggregate columns to the output.
495590
let mut d = self.from.clone();
496591
let mut d_iter = 0;
@@ -509,6 +604,19 @@ impl ExecutionPlan for RollingWindowAggExec {
509604
for i in 0..key_cols.len() {
510605
out_keys[i].extend(0, group_start, group_start + 1)
511606
}
607+
// Add aggregates.
608+
match dim_to_extra_aggs.get(&GroupByScalar::try_from(&d)?) {
609+
Some(aggs) => {
610+
for i in 0..out_extra_aggs.len() {
611+
append_value(out_extra_aggs[i].as_mut(), &aggs[i])?
612+
}
613+
}
614+
None => {
615+
for i in 0..out_extra_aggs.len() {
616+
append_value(out_extra_aggs[i].as_mut(), &extra_agg_nulls[i])?
617+
}
618+
}
619+
}
512620
// Find the matching row to add other columns.
513621
while matching_row_lower_bound < group_end
514622
&& cmp_same_types(
@@ -590,10 +698,16 @@ impl ExecutionPlan for RollingWindowAggExec {
590698
for o in out_other {
591699
r.push(make_array(o.freeze()));
592700
}
701+
593702
let out_aggs_keep = out_aggs_keep.finish();
594703
for mut a in out_aggs {
595704
r.push(filter(a.finish().as_ref(), &out_aggs_keep)?);
596705
}
706+
707+
for mut a in out_extra_aggs {
708+
r.push(a.finish())
709+
}
710+
597711
let r = RecordBatch::try_new(self.schema(), r)?;
598712
Ok(Box::pin(StreamWithSchema::wrap(
599713
self.schema(),
@@ -621,6 +735,18 @@ fn add_dim(l: &ScalarValue, r: &ScalarValue) -> ScalarValue {
621735
}
622736
}
623737

738+
fn compute_agg_inputs(
739+
a: &dyn AggregateExpr,
740+
input: &RecordBatch,
741+
) -> Result<Vec<ArrayRef>, DataFusionError> {
742+
a.expressions()
743+
.iter()
744+
.map(|e| -> Result<_, DataFusionError> {
745+
Ok(e.evaluate(input)?.into_array(input.num_rows()))
746+
})
747+
.collect()
748+
}
749+
624750
fn meets_lower_bound(
625751
value: &ScalarValue,
626752
current: &ScalarValue,

0 commit comments

Comments
 (0)