Skip to content

Commit 4d79e26

Browse files
committed
feat: Allow struct field access projections to be pushed down into scans
1 parent f8a22a5 commit 4d79e26

34 files changed

Lines changed: 1508 additions & 188 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,13 @@ use datafusion_physical_plan::{
6464
sorts::sort::SortExec,
6565
};
6666

67+
use super::pushdown_utils::{
68+
OptimizationTest, TestNode, TestScanBuilder, TestSource, format_plan_for_test,
69+
};
6770
use datafusion_physical_plan::union::UnionExec;
6871
use futures::StreamExt;
6972
use object_store::{ObjectStore, memory::InMemory};
7073
use regex::Regex;
71-
use util::{OptimizationTest, TestNode, TestScanBuilder, format_plan_for_test};
72-
73-
use crate::physical_optimizer::filter_pushdown::util::TestSource;
74-
75-
mod util;
7674

7775
#[test]
7876
fn test_pushdown_into_scan() {

datafusion/core/tests/physical_optimizer/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ mod combine_partial_final_agg;
2424
mod enforce_distribution;
2525
mod enforce_sorting;
2626
mod enforce_sorting_monotonicity;
27-
#[expect(clippy::needless_pass_by_value)]
2827
mod filter_pushdown;
2928
mod join_selection;
3029
#[expect(clippy::needless_pass_by_value)]
@@ -38,3 +37,5 @@ mod sanity_checker;
3837
#[expect(clippy::needless_pass_by_value)]
3938
mod test_utils;
4039
mod window_optimize;
40+
41+
mod pushdown_utils;

datafusion/core/tests/physical_optimizer/projection_pushdown.rs

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,29 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21+
use arrow::array::{Int32Array, RecordBatch, StructArray};
2122
use arrow::compute::SortOptions;
2223
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
24+
use arrow_schema::Fields;
2325
use datafusion::datasource::listing::PartitionedFile;
2426
use datafusion::datasource::memory::MemorySourceConfig;
2527
use datafusion::datasource::physical_plan::CsvSource;
2628
use datafusion::datasource::source::DataSourceExec;
29+
use datafusion::prelude::get_field;
2730
use datafusion_common::config::{ConfigOptions, CsvOptions};
2831
use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue};
2932
use datafusion_datasource::TableSchema;
3033
use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
3134
use datafusion_execution::object_store::ObjectStoreUrl;
3235
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
3336
use datafusion_expr::{
34-
Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
37+
Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, lit,
3538
};
3639
use datafusion_expr_common::columnar_value::ColumnarValue;
3740
use datafusion_physical_expr::expressions::{
3841
BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, binary, cast, col,
3942
};
43+
use datafusion_physical_expr::planner::logical2physical;
4044
use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr};
4145
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
4246
use datafusion_physical_expr_common::sort_expr::{
@@ -64,6 +68,8 @@ use datafusion_physical_plan::{ExecutionPlan, displayable};
6468
use insta::assert_snapshot;
6569
use itertools::Itertools;
6670

71+
use crate::physical_optimizer::pushdown_utils::TestScanBuilder;
72+
6773
/// Mocked UDF
6874
#[derive(Debug, PartialEq, Eq, Hash)]
6975
struct DummyUDF {
@@ -1723,3 +1729,87 @@ fn test_cooperative_exec_after_projection() -> Result<()> {
17231729

17241730
Ok(())
17251731
}
1732+
1733+
#[test]
1734+
fn test_pushdown_projection_through_repartition_filter() {
1735+
let struct_fields = Fields::from(vec![Field::new("a", DataType::Int32, false)]);
1736+
let array = StructArray::new(
1737+
struct_fields.clone(),
1738+
vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
1739+
None,
1740+
);
1741+
let batches = vec![
1742+
RecordBatch::try_new(
1743+
Arc::new(Schema::new(vec![Field::new(
1744+
"struct",
1745+
DataType::Struct(struct_fields.clone()),
1746+
true,
1747+
)])),
1748+
vec![Arc::new(array)],
1749+
)
1750+
.unwrap(),
1751+
];
1752+
let build_side_schema = Arc::new(Schema::new(vec![Field::new(
1753+
"struct",
1754+
DataType::Struct(struct_fields),
1755+
true,
1756+
)]));
1757+
1758+
let scan = TestScanBuilder::new(Arc::clone(&build_side_schema))
1759+
.with_support(true)
1760+
.with_batches(batches)
1761+
.build();
1762+
let scan_schema = scan.schema();
1763+
let struct_access = get_field(datafusion_expr::col("struct"), "a");
1764+
let filter = struct_access.clone().gt(lit(2));
1765+
let repartition =
1766+
RepartitionExec::try_new(scan, Partitioning::RoundRobinBatch(32)).unwrap();
1767+
let filter_exec = FilterExec::try_new(
1768+
logical2physical(&filter, &scan_schema),
1769+
Arc::new(repartition),
1770+
)
1771+
.unwrap();
1772+
let projection: Arc<dyn ExecutionPlan> = Arc::new(
1773+
ProjectionExec::try_new(
1774+
vec![ProjectionExpr::new(
1775+
logical2physical(&struct_access, &scan_schema),
1776+
"a",
1777+
)],
1778+
Arc::new(filter_exec),
1779+
)
1780+
.unwrap(),
1781+
) as _;
1782+
1783+
let initial = displayable(projection.as_ref()).indent(true).to_string();
1784+
let actual = initial.trim();
1785+
1786+
assert_snapshot!(
1787+
actual,
1788+
@r"
1789+
ProjectionExec: expr=[get_field(struct@0, a) as a]
1790+
FilterExec: get_field(struct@0, a) > 2
1791+
RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1
1792+
DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[struct], file_type=test, pushdown_supported=true
1793+
"
1794+
);
1795+
1796+
let after_optimize = ProjectionPushdown::new()
1797+
.optimize(projection, &ConfigOptions::new())
1798+
.unwrap();
1799+
1800+
let after_optimize_string = displayable(after_optimize.as_ref())
1801+
.indent(true)
1802+
.to_string();
1803+
let actual = after_optimize_string.trim();
1804+
1805+
// Projection should be pushed all the way down to the DataSource, and
1806+
// filter predicate should be rewritten to reference projection's output column
1807+
assert_snapshot!(
1808+
actual,
1809+
@r"
1810+
FilterExec: a@0 > 2
1811+
RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1
1812+
DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[get_field(struct@0, a) as a], file_type=test, pushdown_supported=true
1813+
"
1814+
);
1815+
}

datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs renamed to datafusion/core/tests/physical_optimizer/pushdown_utils.rs

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use datafusion_datasource::{
2424
file_scan_config::FileScanConfigBuilder, file_stream::FileOpenFuture,
2525
file_stream::FileOpener, source::DataSourceExec,
2626
};
27+
use datafusion_physical_expr::projection::ProjectionExprs;
2728
use datafusion_physical_expr_common::physical_expr::fmt_sql;
2829
use datafusion_physical_optimizer::PhysicalOptimizerRule;
2930
use datafusion_physical_plan::filter::batch_filter;
@@ -50,7 +51,7 @@ use std::{
5051
pub struct TestOpener {
5152
batches: Vec<RecordBatch>,
5253
batch_size: Option<usize>,
53-
projection: Option<Vec<usize>>,
54+
projection: Option<ProjectionExprs>,
5455
predicate: Option<Arc<dyn PhysicalExpr>>,
5556
}
5657

@@ -60,6 +61,7 @@ impl FileOpener for TestOpener {
6061
if self.batches.is_empty() {
6162
return Ok((async { Ok(TestStream::new(vec![]).boxed()) }).boxed());
6263
}
64+
let schema = self.batches[0].schema();
6365
if let Some(batch_size) = self.batch_size {
6466
let batch = concat_batches(&batches[0].schema(), &batches)?;
6567
let mut new_batches = Vec::new();
@@ -83,9 +85,10 @@ impl FileOpener for TestOpener {
8385
batches = new_batches;
8486

8587
if let Some(projection) = &self.projection {
88+
let projector = projection.make_projector(&schema)?;
8689
batches = batches
8790
.into_iter()
88-
.map(|batch| batch.project(projection).unwrap())
91+
.map(|batch| projector.project_batch(&batch).unwrap())
8992
.collect();
9093
}
9194

@@ -103,14 +106,13 @@ pub struct TestSource {
103106
batch_size: Option<usize>,
104107
batches: Vec<RecordBatch>,
105108
metrics: ExecutionPlanMetricsSet,
106-
projection: Option<Vec<usize>>,
109+
projection: Option<ProjectionExprs>,
107110
table_schema: datafusion_datasource::TableSchema,
108111
}
109112

110113
impl TestSource {
111114
pub fn new(schema: SchemaRef, support: bool, batches: Vec<RecordBatch>) -> Self {
112-
let table_schema =
113-
datafusion_datasource::TableSchema::new(Arc::clone(&schema), vec![]);
115+
let table_schema = datafusion_datasource::TableSchema::new(schema, vec![]);
114116
Self {
115117
support,
116118
metrics: ExecutionPlanMetricsSet::new(),
@@ -210,6 +212,30 @@ impl FileSource for TestSource {
210212
}
211213
}
212214

215+
fn try_pushdown_projection(
216+
&self,
217+
projection: &ProjectionExprs,
218+
) -> Result<Option<Arc<dyn FileSource>>> {
219+
if let Some(existing_projection) = &self.projection {
220+
// Combine existing projection with new projection
221+
let combined_projection = existing_projection.try_merge(projection)?;
222+
Ok(Some(Arc::new(TestSource {
223+
projection: Some(combined_projection),
224+
table_schema: self.table_schema.clone(),
225+
..self.clone()
226+
})))
227+
} else {
228+
Ok(Some(Arc::new(TestSource {
229+
projection: Some(projection.clone()),
230+
..self.clone()
231+
})))
232+
}
233+
}
234+
235+
fn projection(&self) -> Option<&ProjectionExprs> {
236+
self.projection.as_ref()
237+
}
238+
213239
fn table_schema(&self) -> &datafusion_datasource::TableSchema {
214240
&self.table_schema
215241
}
@@ -332,6 +358,7 @@ pub struct OptimizationTest {
332358
}
333359

334360
impl OptimizationTest {
361+
#[expect(clippy::needless_pass_by_value)]
335362
pub fn new<O>(
336363
input_plan: Arc<dyn ExecutionPlan>,
337364
opt: O,

datafusion/expr-common/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,7 @@ pub mod operator;
4444
pub mod signature;
4545
pub mod sort_properties;
4646
pub mod statistics;
47+
pub mod triviality;
4748
pub mod type_coercion;
49+
50+
pub use triviality::ArgTriviality;
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Triviality classification for expressions and function arguments.
19+
20+
/// Classification of argument triviality for scalar functions.
21+
///
22+
/// This enum is used by [`ScalarUDFImpl::triviality`] to allow
23+
/// functions to make context-dependent decisions about whether they are
24+
/// trivial based on the nature of their arguments.
25+
///
26+
/// For example, `get_field(struct_col, 'field_name')` is trivial (static field
27+
/// lookup), but `get_field(struct_col, key_col)` is not (dynamic per-row lookup).
28+
///
29+
/// [`ScalarUDFImpl::triviality`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/trait.ScalarUDFImpl.html#tymethod.triviality
30+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
31+
pub enum ArgTriviality {
32+
/// Argument is a literal constant value or an expression that can be
33+
/// evaluated to a constant at planning time.
34+
Literal,
35+
/// Argument is a simple column reference.
36+
Column,
37+
/// Argument is a complex expression that declares itself trivial.
38+
/// For example, if `get_field(struct_col, 'field_name')` is implemented as a
39+
/// trivial expression, then it would return this variant.
40+
/// Then `other_trivial_function(get_field(...), 42)` could also be classified as
41+
/// a trivial expression using the knowledge that `get_field(...)` is trivial.
42+
TrivialExpr,
43+
/// Argument is a complex expression that declares itself non-trivial.
44+
/// For example, `min(col1 + col2)` is non-trivial because it requires per-row computation.
45+
NonTrivial,
46+
}

datafusion/expr/src/expr.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use std::sync::Arc;
2727
use crate::expr_fn::binary_expr;
2828
use crate::function::WindowFunctionSimplification;
2929
use crate::logical_plan::Subquery;
30-
use crate::{AggregateUDF, Volatility};
30+
use crate::{AggregateUDF, ArgTriviality, Volatility};
3131
use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};
3232

3333
use arrow::datatypes::{DataType, Field, FieldRef};
@@ -1933,6 +1933,32 @@ impl Expr {
19331933
}
19341934
}
19351935

1936+
/// Returns the triviality classification of this expression.
1937+
///
1938+
/// Trivial expressions include column references, literals, and nested
1939+
/// field access via `get_field`.
1940+
///
1941+
/// # Example
1942+
/// ```
1943+
/// # use datafusion_expr::{col, ArgTriviality};
1944+
/// let expr = col("foo");
1945+
/// assert_eq!(expr.triviality(), ArgTriviality::Column);
1946+
/// ```
1947+
pub fn triviality(&self) -> ArgTriviality {
1948+
match self {
1949+
Expr::Column(_) => ArgTriviality::Column,
1950+
Expr::Literal(_, _) => ArgTriviality::Literal,
1951+
Expr::ScalarFunction(func) => {
1952+
// Classify each argument's triviality for context-aware decision making
1953+
let arg_trivialities: Vec<ArgTriviality> =
1954+
func.args.iter().map(|arg| arg.triviality()).collect();
1955+
1956+
func.func.triviality_with_args(&arg_trivialities)
1957+
}
1958+
_ => ArgTriviality::NonTrivial,
1959+
}
1960+
}
1961+
19361962
/// Return all references to columns in this expression.
19371963
///
19381964
/// # Example

datafusion/expr/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ pub use datafusion_doc::{
9292
DocSection, Documentation, DocumentationBuilder, aggregate_doc_sections,
9393
scalar_doc_sections, window_doc_sections,
9494
};
95+
pub use datafusion_expr_common::ArgTriviality;
9596
pub use datafusion_expr_common::accumulator::Accumulator;
9697
pub use datafusion_expr_common::columnar_value::ColumnarValue;
9798
pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};

0 commit comments

Comments
 (0)