Skip to content

Commit 132b043

Browse files
alamb and akurmustafa authored
Fix incorrect SortExec removal before AggregateExec (option 2) (#20247)
## Which issue does this PR close? - Fixes #20244 This is an alternative to - #20245 ## Rationale for this change A wrong-answers bug was exposed by #19287 in 52. See #20244 and the backstory here - #19287 (comment) ## What changes are included in this PR? Fix the bug by properly implementing sort pushdown through `AggregateExec` ## Are these changes tested? Yes, a new test is added ## Are there any user-facing changes? A bug is fixed --------- Co-authored-by: Mustafa Akur <akurmustafa@gmail.com>
1 parent b0566c5 commit 132b043

2 files changed

Lines changed: 284 additions & 0 deletions

File tree

datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ use datafusion_physical_expr_common::sort_expr::{
3535
LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr,
3636
PhysicalSortRequirement,
3737
};
38+
use datafusion_physical_plan::aggregates::AggregateExec;
3839
use datafusion_physical_plan::execution_plan::CardinalityEffect;
3940
use datafusion_physical_plan::filter::FilterExec;
4041
use datafusion_physical_plan::joins::utils::{
@@ -353,6 +354,8 @@ fn pushdown_requirement_to_children(
353354
Ok(None)
354355
}
355356
}
357+
} else if let Some(aggregate_exec) = plan.as_any().downcast_ref::<AggregateExec>() {
358+
handle_aggregate_pushdown(aggregate_exec, parent_required)
356359
} else if maintains_input_order.is_empty()
357360
|| !maintains_input_order.iter().any(|o| *o)
358361
|| plan.as_any().is::<RepartitionExec>()
@@ -388,6 +391,77 @@ fn pushdown_requirement_to_children(
388391
// TODO: Add support for Projection push down
389392
}
390393

394+
/// Try to push sorting through [`AggregateExec`]
395+
///
396+
/// `AggregateExec` only preserves the input order of its group by columns
397+
/// (not aggregates in general, which are formed from arbitrary expressions over
398+
/// input)
399+
///
400+
/// Thus function rewrites the parent required ordering in terms of the
401+
/// aggregate input if possible. This rewritten requirement represents the
402+
/// ordering of the `AggregateExec`'s **input** that would also satisfy the
403+
/// **parent** ordering.
404+
///
405+
/// If no such mapping is possible (e.g. because the sort references aggregate
406+
/// columns), returns None.
407+
fn handle_aggregate_pushdown(
408+
aggregate_exec: &AggregateExec,
409+
parent_required: OrderingRequirements,
410+
) -> Result<Option<Vec<Option<OrderingRequirements>>>> {
411+
if !aggregate_exec
412+
.maintains_input_order()
413+
.into_iter()
414+
.any(|o| o)
415+
{
416+
return Ok(None);
417+
}
418+
419+
let group_expr = aggregate_exec.group_expr();
420+
// GROUPING SETS introduce additional output columns and NULL substitutions;
421+
// skip pushdown until we can map those cases safely.
422+
if group_expr.has_grouping_set() {
423+
return Ok(None);
424+
}
425+
426+
let group_input_exprs = group_expr.input_exprs();
427+
let parent_requirement = parent_required.into_single();
428+
let mut child_requirement = Vec::with_capacity(parent_requirement.len());
429+
430+
for req in parent_requirement {
431+
// Sort above AggregateExec should reference its output columns. Map each
432+
// output group-by column to its original input expression.
433+
let Some(column) = req.expr.as_any().downcast_ref::<Column>() else {
434+
return Ok(None);
435+
};
436+
if column.index() >= group_input_exprs.len() {
437+
// AggregateExec does not produce output that is sorted on aggregate
438+
// columns so those can not be pushed through.
439+
return Ok(None);
440+
}
441+
child_requirement.push(PhysicalSortRequirement::new(
442+
Arc::clone(&group_input_exprs[column.index()]),
443+
req.options,
444+
));
445+
}
446+
447+
let Some(child_requirement) = LexRequirement::new(child_requirement) else {
448+
return Ok(None);
449+
};
450+
451+
// Keep sort above aggregate unless input ordering already satisfies the
452+
// mapped requirement.
453+
if aggregate_exec
454+
.input()
455+
.equivalence_properties()
456+
.ordering_satisfy_requirement(child_requirement.iter().cloned())?
457+
{
458+
let child_requirements = OrderingRequirements::new(child_requirement);
459+
Ok(Some(vec![Some(child_requirements)]))
460+
} else {
461+
Ok(None)
462+
}
463+
}
464+
391465
/// Return true if pushing the sort requirements through a node would violate
392466
/// the input sorting requirements for the plan
393467
fn pushdown_would_violate_requirements(

datafusion/sqllogictest/test_files/sort_pushdown.slt

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,210 @@ LIMIT 3;
851851
5 4
852852
2 -3
853853

854+
# Test 3.7: Aggregate ORDER BY expression should keep SortExec
855+
# Source pattern declared on parquet scan: [x ASC, y ASC].
856+
# Requested pattern in ORDER BY: [x ASC, CAST(y AS BIGINT) % 2 ASC].
857+
# Example for x=1 input y order 1,2,3 gives bucket order 1,0,1, which does not
858+
# match requested bucket ASC order. SortExec is required above AggregateExec.
859+
statement ok
860+
SET datafusion.execution.target_partitions = 1;
861+
862+
statement ok
863+
CREATE TABLE agg_expr_data(x INT, y INT, v INT) AS VALUES
864+
(1, 1, 10),
865+
(1, 2, 20),
866+
(1, 3, 30),
867+
(2, 1, 40),
868+
(2, 2, 50),
869+
(2, 3, 60);
870+
871+
query I
872+
COPY (SELECT * FROM agg_expr_data ORDER BY x, y)
873+
TO 'test_files/scratch/sort_pushdown/agg_expr_sorted.parquet';
874+
----
875+
6
876+
877+
statement ok
878+
CREATE EXTERNAL TABLE agg_expr_parquet(x INT, y INT, v INT)
879+
STORED AS PARQUET
880+
LOCATION 'test_files/scratch/sort_pushdown/agg_expr_sorted.parquet'
881+
WITH ORDER (x ASC, y ASC);
882+
883+
query TT
884+
EXPLAIN SELECT
885+
x,
886+
CAST(y AS BIGINT) % 2,
887+
SUM(v)
888+
FROM agg_expr_parquet
889+
GROUP BY x, CAST(y AS BIGINT) % 2
890+
ORDER BY x, CAST(y AS BIGINT) % 2;
891+
----
892+
logical_plan
893+
01)Sort: agg_expr_parquet.x ASC NULLS LAST, agg_expr_parquet.y % Int64(2) ASC NULLS LAST
894+
02)--Aggregate: groupBy=[[agg_expr_parquet.x, CAST(agg_expr_parquet.y AS Int64) % Int64(2)]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
895+
03)----TableScan: agg_expr_parquet projection=[x, y, v]
896+
physical_plan
897+
01)SortExec: expr=[x@0 ASC NULLS LAST, agg_expr_parquet.y % Int64(2)@1 ASC NULLS LAST], preserve_partitioning=[false]
898+
02)--AggregateExec: mode=Single, gby=[x@0 as x, CAST(y@1 AS Int64) % 2 as agg_expr_parquet.y % Int64(2)], aggr=[sum(agg_expr_parquet.v)], ordering_mode=PartiallySorted([0])
899+
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, y, v], output_ordering=[x@0 ASC NULLS LAST, y@1 ASC NULLS LAST], file_type=parquet
900+
901+
# Expected output pattern from ORDER BY [x, bucket]:
902+
# rows grouped by x, and within each x bucket appears as 0 then 1.
903+
query III
904+
SELECT
905+
x,
906+
CAST(y AS BIGINT) % 2,
907+
SUM(v)
908+
FROM agg_expr_parquet
909+
GROUP BY x, CAST(y AS BIGINT) % 2
910+
ORDER BY x, CAST(y AS BIGINT) % 2;
911+
----
912+
1 0 20
913+
1 1 40
914+
2 0 50
915+
2 1 100
916+
917+
# Test 3.8: Aggregate ORDER BY monotonic expression can push down (no SortExec)
918+
query TT
919+
EXPLAIN SELECT
920+
x,
921+
CAST(y AS BIGINT),
922+
SUM(v)
923+
FROM agg_expr_parquet
924+
GROUP BY x, CAST(y AS BIGINT)
925+
ORDER BY x, CAST(y AS BIGINT);
926+
----
927+
logical_plan
928+
01)Sort: agg_expr_parquet.x ASC NULLS LAST, agg_expr_parquet.y ASC NULLS LAST
929+
02)--Aggregate: groupBy=[[agg_expr_parquet.x, CAST(agg_expr_parquet.y AS Int64)]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
930+
03)----TableScan: agg_expr_parquet projection=[x, y, v]
931+
physical_plan
932+
01)AggregateExec: mode=Single, gby=[x@0 as x, CAST(y@1 AS Int64) as agg_expr_parquet.y], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
933+
02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, y, v], output_ordering=[x@0 ASC NULLS LAST, y@1 ASC NULLS LAST], file_type=parquet
934+
935+
query III
936+
SELECT
937+
x,
938+
CAST(y AS BIGINT),
939+
SUM(v)
940+
FROM agg_expr_parquet
941+
GROUP BY x, CAST(y AS BIGINT)
942+
ORDER BY x, CAST(y AS BIGINT);
943+
----
944+
1 1 10
945+
1 2 20
946+
1 3 30
947+
2 1 40
948+
2 2 50
949+
2 3 60
950+
951+
# Test 3.9: Aggregate ORDER BY aggregate output should keep SortExec
952+
query TT
953+
EXPLAIN SELECT x, SUM(v)
954+
FROM agg_expr_parquet
955+
GROUP BY x
956+
ORDER BY SUM(v);
957+
----
958+
logical_plan
959+
01)Sort: sum(agg_expr_parquet.v) ASC NULLS LAST
960+
02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
961+
03)----TableScan: agg_expr_parquet projection=[x, v]
962+
physical_plan
963+
01)SortExec: expr=[sum(agg_expr_parquet.v)@1 ASC NULLS LAST], preserve_partitioning=[false]
964+
02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
965+
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet
966+
967+
query II
968+
SELECT x, SUM(v)
969+
FROM agg_expr_parquet
970+
GROUP BY x
971+
ORDER BY SUM(v);
972+
----
973+
1 60
974+
2 150
975+
976+
# Test 3.10: Aggregate with non-preserved input order should keep SortExec
977+
# v is not part of the file's declared ordering (x ASC, y ASC)
978+
query TT
979+
EXPLAIN SELECT v, SUM(y)
980+
FROM agg_expr_parquet
981+
GROUP BY v
982+
ORDER BY v;
983+
----
984+
logical_plan
985+
01)Sort: agg_expr_parquet.v ASC NULLS LAST
986+
02)--Aggregate: groupBy=[[agg_expr_parquet.v]], aggr=[[sum(CAST(agg_expr_parquet.y AS Int64))]]
987+
03)----TableScan: agg_expr_parquet projection=[y, v]
988+
physical_plan
989+
01)SortExec: expr=[v@0 ASC NULLS LAST], preserve_partitioning=[false]
990+
02)--AggregateExec: mode=Single, gby=[v@1 as v], aggr=[sum(agg_expr_parquet.y)]
991+
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[y, v], file_type=parquet
992+
993+
query II
994+
SELECT v, SUM(y)
995+
FROM agg_expr_parquet
996+
GROUP BY v
997+
ORDER BY v;
998+
----
999+
10 1
1000+
20 2
1001+
30 3
1002+
40 1
1003+
50 2
1004+
60 3
1005+
1006+
# Test 3.11: Aggregate ORDER BY non-column expression (unsatisfied) keeps SortExec
1007+
# (though note in theory DataFusion could figure out that data sorted by x will also be sorted by x+1)
1008+
query TT
1009+
EXPLAIN SELECT x, SUM(v)
1010+
FROM agg_expr_parquet
1011+
GROUP BY x
1012+
ORDER BY x + 1 DESC;
1013+
----
1014+
logical_plan
1015+
01)Sort: CAST(agg_expr_parquet.x AS Int64) + Int64(1) DESC NULLS FIRST
1016+
02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
1017+
03)----TableScan: agg_expr_parquet projection=[x, v]
1018+
physical_plan
1019+
01)SortExec: expr=[CAST(x@0 AS Int64) + 1 DESC], preserve_partitioning=[false]
1020+
02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
1021+
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet
1022+
1023+
query II
1024+
SELECT x, SUM(v)
1025+
FROM agg_expr_parquet
1026+
GROUP BY x
1027+
ORDER BY x + 1 DESC;
1028+
----
1029+
2 150
1030+
1 60
1031+
1032+
# Test 3.12: Aggregate ORDER BY non-column expression (unsatisfied) keeps SortExec
1033+
# (though note in theory DataFusion could figure out that data sorted by x will also be sorted by 2 * x)
1034+
query TT
1035+
EXPLAIN SELECT x, SUM(v)
1036+
FROM agg_expr_parquet
1037+
GROUP BY x
1038+
ORDER BY 2 * x ASC;
1039+
----
1040+
logical_plan
1041+
01)Sort: Int64(2) * CAST(agg_expr_parquet.x AS Int64) ASC NULLS LAST
1042+
02)--Aggregate: groupBy=[[agg_expr_parquet.x]], aggr=[[sum(CAST(agg_expr_parquet.v AS Int64))]]
1043+
03)----TableScan: agg_expr_parquet projection=[x, v]
1044+
physical_plan
1045+
01)SortExec: expr=[2 * CAST(x@0 AS Int64) ASC NULLS LAST], preserve_partitioning=[false]
1046+
02)--AggregateExec: mode=Single, gby=[x@0 as x], aggr=[sum(agg_expr_parquet.v)], ordering_mode=Sorted
1047+
03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/sort_pushdown/agg_expr_sorted.parquet]]}, projection=[x, v], output_ordering=[x@0 ASC NULLS LAST], file_type=parquet
1048+
1049+
query II
1050+
SELECT x, SUM(v)
1051+
FROM agg_expr_parquet
1052+
GROUP BY x
1053+
ORDER BY 2 * x ASC;
1054+
----
1055+
1 60
1056+
2 150
1057+
8541058
# Test 4: Reversed filesystem order with inferred ordering
8551059
# Create 3 parquet files with non-overlapping id ranges, named so filesystem
8561060
# order is OPPOSITE to data order. Each file is internally sorted by id ASC.
@@ -1420,5 +1624,11 @@ DROP TABLE signed_data;
14201624
statement ok
14211625
DROP TABLE signed_parquet;
14221626

1627+
statement ok
1628+
DROP TABLE agg_expr_data;
1629+
1630+
statement ok
1631+
DROP TABLE agg_expr_parquet;
1632+
14231633
statement ok
14241634
SET datafusion.optimizer.enable_sort_pushdown = true;

0 commit comments

Comments
 (0)