Skip to content

Commit 4889d13

Browse files
wip
1 parent c5d0e2f commit 4889d13

10 files changed

Lines changed: 746 additions & 45 deletions

File tree

datafusion/core/tests/physical_optimizer/filter_pushdown.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3986,7 +3986,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_is_used() {
39863986

39873987
// Verify that a dynamic filter was created
39883988
let dynamic_filter = hash_join
3989-
.dynamic_filter_for_test()
3989+
.dynamic_filter()
39903990
.expect("Dynamic filter should be created");
39913991

39923992
// Verify that is_used() returns the expected value based on probe side support.

datafusion/physical-expr-common/src/physical_expr.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,9 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
7575
/// Returns the physical expression as [`Any`] so that it can be
7676
/// downcast to a specific implementation.
7777
fn as_any(&self) -> &dyn Any;
78-
/// Get the data type of this expression, given the schema of the input
78+
/// Get the data type of this expression, given the schema of the input.
79+
/// Returns an error if the data type cannot be determined, ex. if the
80+
/// schema is missing a required field.
7981
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
8082
Ok(self.return_field(input_schema)?.data_type().to_owned())
8183
}

datafusion/physical-plan/src/aggregates/mod.rs

Lines changed: 170 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,6 +1036,34 @@ impl AggregateExec {
10361036
&self.input_order_mode
10371037
}
10381038

1039+
/// Returns the dynamic filter expression for this aggregate, if set.
1040+
pub fn dynamic_filter(&self) -> Option<&Arc<DynamicFilterPhysicalExpr>> {
1041+
self.dynamic_filter.as_ref().map(|df| &df.filter)
1042+
}
1043+
1044+
/// Replace the dynamic filter expression, recomputing any internal state
1045+
/// which may depend on the previous dynamic filter.
1046+
///
1047+
/// This is a no-op if the aggregate does not support dynamic filtering.
1048+
///
1049+
/// If dynamic filtering is supported, this method returns an error if the filter's
1050+
/// children reference invalid columns in the aggregate's input schema.
1051+
pub fn with_dynamic_filter(
1052+
mut self,
1053+
filter: Arc<DynamicFilterPhysicalExpr>,
1054+
) -> Result<Self> {
1055+
if let Some(supported_accumulators_info) = self.supported_accumulators_info() {
1056+
for child in filter.children() {
1057+
child.data_type(&self.input_schema)?;
1058+
}
1059+
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1060+
filter,
1061+
supported_accumulators_info,
1062+
}));
1063+
}
1064+
Ok(self)
1065+
}
1066+
10391067
fn statistics_inner(&self, child_statistics: &Statistics) -> Result<Statistics> {
10401068
// TODO stats: group expressions:
10411069
// - once expressions will be able to compute their own stats, use it here
@@ -1116,27 +1144,40 @@ impl AggregateExec {
11161144
/// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field.
11171145
/// - If not supported, `self.dynamic_filter` should be kept `None`
11181146
fn init_dynamic_filter(&mut self) {
1119-
if (!self.group_by.is_empty()) || (!matches!(self.mode, AggregateMode::Partial)) {
1120-
debug_assert!(
1121-
self.dynamic_filter.is_none(),
1122-
"The current operator node does not support dynamic filter"
1123-
);
1124-
return;
1125-
}
1126-
11271147
// Already initialized.
11281148
if self.dynamic_filter.is_some() {
11291149
return;
11301150
}
11311151

1132-
// Collect supported accumulators
1133-
// It is assumed the order of aggregate expressions are not changed from `AggregateExec`
1134-
// to `AggregateStream`
1152+
if let Some(supported_accumulators_info) = self.supported_accumulators_info() {
1153+
// Collect column references for the dynamic filter expression.
1154+
let all_cols: Vec<Arc<dyn PhysicalExpr>> = supported_accumulators_info
1155+
.iter()
1156+
.map(|info| Arc::clone(&self.aggr_expr[info.aggr_index].expressions()[0]))
1157+
.collect();
1158+
1159+
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1160+
filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
1161+
supported_accumulators_info,
1162+
}));
1163+
}
1164+
}
1165+
1166+
/// Returns the supported accumulator info if this aggregate supports
1167+
/// dynamic filtering, or `None` otherwise.
1168+
///
1169+
/// Dynamic filtering requires:
1170+
/// - `Partial` aggregation mode with no group-by expressions
1171+
/// - All aggregate functions are `min` or `max` with a single column arg
1172+
fn supported_accumulators_info(&self) -> Option<Vec<PerAccumulatorDynFilter>> {
1173+
if !self.group_by.is_empty() || !matches!(self.mode, AggregateMode::Partial) {
1174+
return None;
1175+
}
1176+
1177+
// Collect supported accumulators.
1178+
// It is assumed the order of aggregate expressions are not changed
1179+
// from `AggregateExec` to `AggregateStream`.
11351180
let mut aggr_dyn_filters = Vec::new();
1136-
// All column references in the dynamic filter, used when initializing the dynamic
1137-
// filter, and it's used to decide if this dynamic filter is able to get push
1138-
// through certain node during optimization.
1139-
let mut all_cols: Vec<Arc<dyn PhysicalExpr>> = Vec::new();
11401181
for (i, aggr_expr) in self.aggr_expr.iter().enumerate() {
11411182
// 1. Only `min` or `max` aggregate function
11421183
let fun_name = aggr_expr.fun().name();
@@ -1147,14 +1188,13 @@ impl AggregateExec {
11471188
} else if fun_name.eq_ignore_ascii_case("max") {
11481189
DynamicFilterAggregateType::Max
11491190
} else {
1150-
return;
1191+
return None;
11511192
};
11521193

11531194
// 2. arg should be only 1 column reference
11541195
if let [arg] = aggr_expr.expressions().as_slice()
11551196
&& arg.as_any().is::<Column>()
11561197
{
1157-
all_cols.push(Arc::clone(arg));
11581198
aggr_dyn_filters.push(PerAccumulatorDynFilter {
11591199
aggr_type,
11601200
aggr_index: i,
@@ -1163,11 +1203,10 @@ impl AggregateExec {
11631203
}
11641204
}
11651205

1166-
if !aggr_dyn_filters.is_empty() {
1167-
self.dynamic_filter = Some(Arc::new(AggrDynFilter {
1168-
filter: Arc::new(DynamicFilterPhysicalExpr::new(all_cols, lit(true))),
1169-
supported_accumulators_info: aggr_dyn_filters,
1170-
}))
1206+
if aggr_dyn_filters.is_empty() {
1207+
None
1208+
} else {
1209+
Some(aggr_dyn_filters)
11711210
}
11721211
}
11731212

@@ -1964,6 +2003,7 @@ mod tests {
19642003
use crate::coalesce_partitions::CoalescePartitionsExec;
19652004
use crate::common;
19662005
use crate::common::collect;
2006+
use crate::empty::EmptyExec;
19672007
use crate::execution_plan::Boundedness;
19682008
use crate::expressions::col;
19692009
use crate::metrics::MetricValue;
@@ -1987,6 +2027,7 @@ mod tests {
19872027
use datafusion_functions_aggregate::count::count_udaf;
19882028
use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
19892029
use datafusion_functions_aggregate::median::median_udaf;
2030+
use datafusion_functions_aggregate::min_max::min_udaf;
19902031
use datafusion_functions_aggregate::sum::sum_udaf;
19912032
use datafusion_physical_expr::Partitioning;
19922033
use datafusion_physical_expr::PhysicalSortExpr;
@@ -3459,13 +3500,10 @@ mod tests {
34593500
// Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
34603501
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
34613502
Arc::new(
3462-
AggregateExprBuilder::new(
3463-
datafusion_functions_aggregate::min_max::min_udaf(),
3464-
vec![col("b", &schema)?],
3465-
)
3466-
.schema(Arc::clone(&schema))
3467-
.alias("MIN(b)")
3468-
.build()?,
3503+
AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
3504+
.schema(Arc::clone(&schema))
3505+
.alias("MIN(b)")
3506+
.build()?,
34693507
),
34703508
Arc::new(
34713509
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
@@ -3604,13 +3642,10 @@ mod tests {
36043642
// Test with MIN for simple intermediate state (min) and AVG for multiple intermediate states (partial sum, partial count).
36053643
let aggregates: Vec<Arc<AggregateFunctionExpr>> = vec![
36063644
Arc::new(
3607-
AggregateExprBuilder::new(
3608-
datafusion_functions_aggregate::min_max::min_udaf(),
3609-
vec![col("b", &schema)?],
3610-
)
3611-
.schema(Arc::clone(&schema))
3612-
.alias("MIN(b)")
3613-
.build()?,
3645+
AggregateExprBuilder::new(min_udaf(), vec![col("b", &schema)?])
3646+
.schema(Arc::clone(&schema))
3647+
.alias("MIN(b)")
3648+
.build()?,
36143649
),
36153650
Arc::new(
36163651
AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?])
@@ -3950,4 +3985,103 @@ mod tests {
39503985

39513986
Ok(())
39523987
}
3988+
3989+
#[test]
3990+
fn test_with_dynamic_filter() -> Result<()> {
3991+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
3992+
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
3993+
3994+
// Partial min triggers init_dynamic_filter.
3995+
let agg = AggregateExec::try_new(
3996+
AggregateMode::Partial,
3997+
PhysicalGroupBy::new_single(vec![]),
3998+
vec![Arc::new(
3999+
AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
4000+
.schema(Arc::clone(&schema))
4001+
.alias("min_a")
4002+
.build()?,
4003+
)],
4004+
vec![None],
4005+
child,
4006+
Arc::clone(&schema),
4007+
)?;
4008+
let original_inner_id = agg
4009+
.dynamic_filter()
4010+
.expect("should have dynamic filter after init")
4011+
.inner_id();
4012+
4013+
let new_df = Arc::new(DynamicFilterPhysicalExpr::new(
4014+
vec![col("a", &schema)?],
4015+
lit(true),
4016+
));
4017+
let agg = agg.with_dynamic_filter(Arc::clone(&new_df))?;
4018+
let restored = agg
4019+
.dynamic_filter()
4020+
.expect("should still have dynamic filter");
4021+
assert_eq!(restored.inner_id(), new_df.inner_id());
4022+
assert_ne!(restored.inner_id(), original_inner_id);
4023+
Ok(())
4024+
}
4025+
4026+
#[test]
4027+
fn test_with_dynamic_filter_noop_when_unsupported() -> Result<()> {
4028+
let schema = Arc::new(Schema::new(vec![
4029+
Field::new("a", DataType::Int64, false),
4030+
Field::new("b", DataType::Int64, false),
4031+
]));
4032+
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
4033+
4034+
// Final mode with a group-by does not support dynamic filters.
4035+
let agg = AggregateExec::try_new(
4036+
AggregateMode::Final,
4037+
PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]),
4038+
vec![Arc::new(
4039+
AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?])
4040+
.schema(Arc::clone(&schema))
4041+
.alias("sum_b")
4042+
.build()?,
4043+
)],
4044+
vec![None],
4045+
child,
4046+
Arc::clone(&schema),
4047+
)?;
4048+
assert!(agg.dynamic_filter().is_none());
4049+
4050+
// with_dynamic_filter should be a no-op.
4051+
let df = Arc::new(DynamicFilterPhysicalExpr::new(
4052+
vec![col("a", &schema)?],
4053+
lit(true),
4054+
));
4055+
let agg = agg.with_dynamic_filter(df)?;
4056+
assert!(agg.dynamic_filter().is_none());
4057+
Ok(())
4058+
}
4059+
4060+
#[test]
4061+
fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> {
4062+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
4063+
let child = Arc::new(EmptyExec::new(Arc::clone(&schema)));
4064+
4065+
let agg = AggregateExec::try_new(
4066+
AggregateMode::Partial,
4067+
PhysicalGroupBy::new_single(vec![]),
4068+
vec![Arc::new(
4069+
AggregateExprBuilder::new(min_udaf(), vec![col("a", &schema)?])
4070+
.schema(Arc::clone(&schema))
4071+
.alias("min_a")
4072+
.build()?,
4073+
)],
4074+
vec![None],
4075+
child,
4076+
Arc::clone(&schema),
4077+
)?;
4078+
4079+
// Column index 99 is out of bounds for the input schema.
4080+
let df = Arc::new(DynamicFilterPhysicalExpr::new(
4081+
vec![Arc::new(Column::new("bad", 99)) as _],
4082+
lit(true),
4083+
));
4084+
assert!(agg.with_dynamic_filter(df).is_err());
4085+
Ok(())
4086+
}
39534087
}

0 commit comments

Comments
 (0)