|
18 | 18 | use std::any::Any; |
19 | 19 | use std::sync::Arc; |
20 | 20 |
|
| 21 | +use arrow::array::{Int32Array, RecordBatch, StructArray}; |
21 | 22 | use arrow::compute::SortOptions; |
22 | 23 | use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; |
| 24 | +use arrow_schema::Fields; |
23 | 25 | use datafusion::datasource::listing::PartitionedFile; |
24 | 26 | use datafusion::datasource::memory::MemorySourceConfig; |
25 | 27 | use datafusion::datasource::physical_plan::CsvSource; |
26 | 28 | use datafusion::datasource::source::DataSourceExec; |
| 29 | +use datafusion::prelude::get_field; |
27 | 30 | use datafusion_common::config::{ConfigOptions, CsvOptions}; |
28 | 31 | use datafusion_common::{JoinSide, JoinType, NullEquality, Result, ScalarValue}; |
29 | 32 | use datafusion_datasource::TableSchema; |
30 | 33 | use datafusion_datasource::file_scan_config::FileScanConfigBuilder; |
31 | 34 | use datafusion_execution::object_store::ObjectStoreUrl; |
32 | 35 | use datafusion_execution::{SendableRecordBatchStream, TaskContext}; |
33 | 36 | use datafusion_expr::{ |
34 | | - Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, |
| 37 | + Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, lit, |
35 | 38 | }; |
36 | 39 | use datafusion_expr_common::columnar_value::ColumnarValue; |
37 | 40 | use datafusion_physical_expr::expressions::{ |
38 | 41 | BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, binary, cast, col, |
39 | 42 | }; |
| 43 | +use datafusion_physical_expr::planner::logical2physical; |
40 | 44 | use datafusion_physical_expr::{Distribution, Partitioning, ScalarFunctionExpr}; |
41 | 45 | use datafusion_physical_expr_common::physical_expr::PhysicalExpr; |
42 | 46 | use datafusion_physical_expr_common::sort_expr::{ |
@@ -64,6 +68,8 @@ use datafusion_physical_plan::{ExecutionPlan, displayable}; |
64 | 68 | use insta::assert_snapshot; |
65 | 69 | use itertools::Itertools; |
66 | 70 |
|
| 71 | +use crate::physical_optimizer::pushdown_utils::TestScanBuilder; |
| 72 | + |
67 | 73 | /// Mocked UDF |
68 | 74 | #[derive(Debug, PartialEq, Eq, Hash)] |
69 | 75 | struct DummyUDF { |
@@ -1723,3 +1729,87 @@ fn test_cooperative_exec_after_projection() -> Result<()> { |
1723 | 1729 |
|
1724 | 1730 | Ok(()) |
1725 | 1731 | } |
| 1732 | + |
#[test]
fn test_pushdown_projection_through_repartition_filter() {
    // Regression test: a projection that extracts a struct member must be
    // pushed through RepartitionExec and FilterExec all the way into the
    // data source, with the filter predicate rewritten to reference the
    // projection's output column.

    // One `struct` column with a single non-null Int32 member `a`.
    let member_fields = Fields::from(vec![Field::new("a", DataType::Int32, false)]);
    let values = StructArray::new(
        member_fields.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]))],
        None,
    );
    let schema: SchemaRef = Arc::new(Schema::new(vec![Field::new(
        "struct",
        DataType::Struct(member_fields),
        true,
    )]));
    let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(values)]).unwrap();

    // Build: scan -> repartition(32) -> filter(struct.a > 2) -> project(struct.a AS a)
    let scan = TestScanBuilder::new(Arc::clone(&schema))
        .with_support(true)
        .with_batches(vec![batch])
        .build();
    let scan_schema = scan.schema();

    let struct_access = get_field(datafusion_expr::col("struct"), "a");
    let predicate = struct_access.clone().gt(lit(2));

    let repartition =
        Arc::new(RepartitionExec::try_new(scan, Partitioning::RoundRobinBatch(32)).unwrap());
    let filter_exec = Arc::new(
        FilterExec::try_new(logical2physical(&predicate, &scan_schema), repartition).unwrap(),
    );
    let projection: Arc<dyn ExecutionPlan> = Arc::new(
        ProjectionExec::try_new(
            vec![ProjectionExpr::new(
                logical2physical(&struct_access, &scan_schema),
                "a",
            )],
            filter_exec,
        )
        .unwrap(),
    ) as _;

    // Sanity-check the plan shape before optimization.
    let before = displayable(projection.as_ref()).indent(true).to_string();
    assert_snapshot!(
        before.trim(),
        @r"
    ProjectionExec: expr=[get_field(struct@0, a) as a]
      FilterExec: get_field(struct@0, a) > 2
        RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1
          DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[struct], file_type=test, pushdown_supported=true
    "
    );

    let optimized = ProjectionPushdown::new()
        .optimize(projection, &ConfigOptions::new())
        .unwrap();

    // Projection should be pushed all the way down to the DataSource, and
    // filter predicate should be rewritten to reference projection's output column
    let after = displayable(optimized.as_ref()).indent(true).to_string();
    assert_snapshot!(
        after.trim(),
        @r"
    FilterExec: a@0 > 2
      RepartitionExec: partitioning=RoundRobinBatch(32), input_partitions=1
        DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[get_field(struct@0, a) as a], file_type=test, pushdown_supported=true
    "
    );
}
0 commit comments