@@ -1045,7 +1045,19 @@ impl GroupedHashAggregateStream {
10451045 self . group_values . len ( )
10461046 } ;
10471047
1048- if let Some ( batch) = self . emit ( EmitTo :: First ( n) , false ) ? {
1048+ // Clamp to the sort boundary when using partial group ordering,
1049+ // otherwise remove_groups panics (#20445).
1050+ let n = match & self . group_ordering {
1051+ GroupOrdering :: None => n,
1052+ _ => match self . group_ordering . emit_to ( ) {
1053+ Some ( EmitTo :: First ( max) ) => n. min ( max) ,
1054+ _ => 0 ,
1055+ } ,
1056+ } ;
1057+
1058+ if n > 0
1059+ && let Some ( batch) = self . emit ( EmitTo :: First ( n) , false ) ?
1060+ {
10491061 Ok ( Some ( ExecutionState :: ProducingOutput ( batch) ) )
10501062 } else {
10511063 Err ( oom)
@@ -1305,6 +1317,7 @@ impl GroupedHashAggregateStream {
13051317#[ cfg( test) ]
13061318mod tests {
13071319 use super :: * ;
1320+ use crate :: InputOrderMode ;
13081321 use crate :: execution_plan:: ExecutionPlan ;
13091322 use crate :: test:: TestMemoryExec ;
13101323 use arrow:: array:: { Int32Array , Int64Array } ;
@@ -1567,4 +1580,88 @@ mod tests {
15671580
15681581 Ok ( ( ) )
15691582 }
1583+
1584+ #[ tokio:: test]
1585+ async fn test_emit_early_with_partially_sorted ( ) -> Result < ( ) > {
1586+ // Reproducer for #20445: EmitEarly with PartiallySorted panics in
1587+ // remove_groups because it emits more groups than the sort boundary.
1588+ let schema = Arc :: new ( Schema :: new ( vec ! [
1589+ Field :: new( "sort_col" , DataType :: Int32 , false ) ,
1590+ Field :: new( "group_col" , DataType :: Int32 , false ) ,
1591+ Field :: new( "value_col" , DataType :: Int64 , false ) ,
1592+ ] ) ) ;
1593+
1594+ // All rows share sort_col=1 (no sort boundary), with unique group_col
1595+ // values to create many groups and trigger memory pressure.
1596+ let n = 256 ;
1597+ let batch = RecordBatch :: try_new (
1598+ Arc :: clone ( & schema) ,
1599+ vec ! [
1600+ Arc :: new( Int32Array :: from( vec![ 1 ; n] ) ) ,
1601+ Arc :: new( Int32Array :: from( ( 0 ..n as i32 ) . collect:: <Vec <_>>( ) ) ) ,
1602+ Arc :: new( Int64Array :: from( vec![ 1 ; n] ) ) ,
1603+ ] ,
1604+ ) ?;
1605+
1606+ let runtime = RuntimeEnvBuilder :: default ( )
1607+ . with_memory_limit ( 4096 , 1.0 )
1608+ . build_arc ( ) ?;
1609+ let mut task_ctx = TaskContext :: default ( ) . with_runtime ( runtime) ;
1610+ let mut cfg = task_ctx. session_config ( ) . clone ( ) ;
1611+ cfg = cfg. set (
1612+ "datafusion.execution.batch_size" ,
1613+ & datafusion_common:: ScalarValue :: UInt64 ( Some ( 128 ) ) ,
1614+ ) ;
1615+ cfg = cfg. set (
1616+ "datafusion.execution.skip_partial_aggregation_probe_rows_threshold" ,
1617+ & datafusion_common:: ScalarValue :: UInt64 ( Some ( u64:: MAX ) ) ,
1618+ ) ;
1619+ task_ctx = task_ctx. with_session_config ( cfg) ;
1620+ let task_ctx = Arc :: new ( task_ctx) ;
1621+
1622+ let ordering = LexOrdering :: new ( vec ! [ PhysicalSortExpr :: new_default( Arc :: new(
1623+ Column :: new( "sort_col" , 0 ) ,
1624+ )
1625+ as _) ] )
1626+ . unwrap ( ) ;
1627+ let exec = TestMemoryExec :: try_new ( & [ vec ! [ batch] ] , Arc :: clone ( & schema) , None ) ?
1628+ . try_with_sort_information ( vec ! [ ordering] ) ?;
1629+ let exec = Arc :: new ( TestMemoryExec :: update_cache ( & Arc :: new ( exec) ) ) ;
1630+
1631+ // GROUP BY sort_col, group_col with input sorted on sort_col
1632+ // gives PartiallySorted([0])
1633+ let aggregate_exec = AggregateExec :: try_new (
1634+ AggregateMode :: Partial ,
1635+ PhysicalGroupBy :: new_single ( vec ! [
1636+ ( col( "sort_col" , & schema) ?, "sort_col" . to_string( ) ) ,
1637+ ( col( "group_col" , & schema) ?, "group_col" . to_string( ) ) ,
1638+ ] ) ,
1639+ vec ! [ Arc :: new(
1640+ AggregateExprBuilder :: new( count_udaf( ) , vec![ col( "value_col" , & schema) ?] )
1641+ . schema( Arc :: clone( & schema) )
1642+ . alias( "count_value" )
1643+ . build( ) ?,
1644+ ) ] ,
1645+ vec ! [ None ] ,
1646+ exec,
1647+ Arc :: clone ( & schema) ,
1648+ ) ?;
1649+ assert ! ( matches!(
1650+ aggregate_exec. input_order_mode( ) ,
1651+ InputOrderMode :: PartiallySorted ( _)
1652+ ) ) ;
1653+
1654+ // Must not panic with "assertion failed: *current_sort >= n"
1655+ let mut stream = GroupedHashAggregateStream :: new ( & aggregate_exec, & task_ctx, 0 ) ?;
1656+ while let Some ( result) = stream. next ( ) . await {
1657+ if let Err ( e) = result {
1658+ if e. to_string ( ) . contains ( "Resources exhausted" ) {
1659+ break ;
1660+ }
1661+ return Err ( e) ;
1662+ }
1663+ }
1664+
1665+ Ok ( ( ) )
1666+ }
15701667}
0 commit comments