@@ -23,13 +23,11 @@ use std::task::{Context, Poll};
2323use std:: vec;
2424
2525use ahash:: RandomState ;
26- use futures:: {
27- stream:: { Stream , StreamExt } ,
28- Future ,
29- } ;
26+ use futures:: stream:: { Stream , StreamExt } ;
3027
3128use crate :: error:: Result ;
3229use crate :: physical_plan:: hash_utils:: create_hashes;
30+ use crate :: physical_plan:: stream:: RecordBatchReceiverStream ;
3331use crate :: physical_plan:: {
3432 Accumulator , AggregateExpr , DisplayFormatType , Distribution , ExecutionPlan ,
3533 Partitioning , PhysicalExpr ,
@@ -39,19 +37,17 @@ use crate::scalar::ScalarValue;
3937use arrow:: { array:: ArrayRef , compute, compute:: cast} ;
4038use arrow:: {
4139 array:: { Array , UInt32Builder } ,
42- error:: { ArrowError , Result as ArrowResult } ,
40+ error:: Result as ArrowResult ,
4341} ;
4442use arrow:: {
4543 datatypes:: { Field , Schema , SchemaRef } ,
4644 record_batch:: RecordBatch ,
4745} ;
4846use hashbrown:: raw:: RawTable ;
49- use pin_project_lite:: pin_project;
5047
5148use crate :: execution:: context:: TaskContext ;
5249use async_trait:: async_trait;
5350
54- use super :: common:: AbortOnDropSingle ;
5551use super :: expressions:: PhysicalSortExpr ;
5652use super :: metrics:: {
5753 self , BaselineMetrics , ExecutionPlanMetricsSet , MetricsSet , RecordOutput ,
@@ -356,14 +352,9 @@ Example: average
356352* Once all N record batches arrive, `merge` is performed, which builds a RecordBatch with N rows and 2 columns.
357353* Finally, `get_value` returns an array with one entry computed from the state
358354*/
359- pin_project ! {
360- struct GroupedHashAggregateStream {
361- schema: SchemaRef ,
362- #[ pin]
363- output: futures:: channel:: oneshot:: Receiver <ArrowResult <RecordBatch >>,
364- finished: bool ,
365- drop_helper: AbortOnDropSingle <( ) >,
366- }
355+ struct GroupedHashAggregateStream {
356+ schema : SchemaRef ,
357+ stream : SendableRecordBatchStream ,
367358}
368359
369360fn group_aggregate_batch (
@@ -570,12 +561,16 @@ impl GroupedHashAggregateStream {
570561 input : SendableRecordBatchStream ,
571562 baseline_metrics : BaselineMetrics ,
572563 ) -> Self {
573- let ( tx, rx) = futures:: channel:: oneshot:: channel ( ) ;
564+ // Use the panic-propagating builder so that panics in the
565+ // compute task are re-raised on the consumer side instead of
566+ // being reported as a closed channel.
567+ let mut builder = RecordBatchReceiverStream :: builder ( schema. clone ( ) , 1 ) ;
568+ let tx = builder. tx ( ) ;
574569
575570 let schema_clone = schema. clone ( ) ;
576571 let elapsed_compute = baseline_metrics. elapsed_compute ( ) . clone ( ) ;
577572
578- let join_handle = tokio :: spawn ( async move {
573+ builder . spawn ( async move {
579574 let result = compute_grouped_hash_aggregate (
580575 mode,
581576 schema_clone,
@@ -588,14 +583,12 @@ impl GroupedHashAggregateStream {
588583 . record_output ( & baseline_metrics) ;
589584
590585 // failing here is OK, the receiver is gone and does not care about the result
591- tx. send ( result) . ok ( ) ;
586+ tx. send ( result) . await . ok ( ) ;
592587 } ) ;
593588
594589 Self {
595590 schema,
596- output : rx,
597- finished : false ,
598- drop_helper : AbortOnDropSingle :: new ( join_handle) ,
591+ stream : builder. build ( ) ,
599592 }
600593 }
601594}
@@ -647,31 +640,10 @@ impl Stream for GroupedHashAggregateStream {
647640 type Item = ArrowResult < RecordBatch > ;
648641
649642 fn poll_next (
650- self : std:: pin:: Pin < & mut Self > ,
643+ mut self : std:: pin:: Pin < & mut Self > ,
651644 cx : & mut Context < ' _ > ,
652645 ) -> Poll < Option < Self :: Item > > {
653- if self . finished {
654- return Poll :: Ready ( None ) ;
655- }
656-
657- // is the output ready?
658- let this = self . project ( ) ;
659- let output_poll = this. output . poll ( cx) ;
660-
661- match output_poll {
662- Poll :: Ready ( result) => {
663- * this. finished = true ;
664-
665- // check for error in receiving channel and unwrap actual result
666- let result = match result {
667- Err ( e) => Err ( ArrowError :: ExternalError ( Box :: new ( e) ) ) , // error receiving
668- Ok ( result) => result,
669- } ;
670-
671- Poll :: Ready ( Some ( result) )
672- }
673- Poll :: Pending => Poll :: Pending ,
674- }
646+ self . stream . poll_next_unpin ( cx)
675647 }
676648}
677649
@@ -748,15 +720,10 @@ fn aggregate_expressions(
748720 }
749721}
750722
751- pin_project ! {
752- /// stream struct for hash aggregation
753- pub struct HashAggregateStream {
754- schema: SchemaRef ,
755- #[ pin]
756- output: futures:: channel:: oneshot:: Receiver <ArrowResult <RecordBatch >>,
757- finished: bool ,
758- drop_helper: AbortOnDropSingle <( ) >,
759- }
723+ /// stream struct for hash aggregation
724+ pub struct HashAggregateStream {
725+ schema : SchemaRef ,
726+ stream : SendableRecordBatchStream ,
760727}
761728
762729/// Special case aggregate with no groups
@@ -799,11 +766,15 @@ impl HashAggregateStream {
799766 input : SendableRecordBatchStream ,
800767 baseline_metrics : BaselineMetrics ,
801768 ) -> Self {
802- let ( tx, rx) = futures:: channel:: oneshot:: channel ( ) ;
769+ // Use the panic-propagating builder so that panics in the
770+ // compute task are re-raised on the consumer side instead of
771+ // being reported as a closed channel.
772+ let mut builder = RecordBatchReceiverStream :: builder ( schema. clone ( ) , 1 ) ;
773+ let tx = builder. tx ( ) ;
803774
804775 let schema_clone = schema. clone ( ) ;
805776 let elapsed_compute = baseline_metrics. elapsed_compute ( ) . clone ( ) ;
806- let join_handle = tokio :: spawn ( async move {
777+ builder . spawn ( async move {
807778 let result = compute_hash_aggregate (
808779 mode,
809780 schema_clone,
@@ -815,14 +786,12 @@ impl HashAggregateStream {
815786 . record_output ( & baseline_metrics) ;
816787
817788 // failing here is OK, the receiver is gone and does not care about the result
818- tx. send ( result) . ok ( ) ;
789+ tx. send ( result) . await . ok ( ) ;
819790 } ) ;
820791
821792 Self {
822793 schema,
823- output : rx,
824- finished : false ,
825- drop_helper : AbortOnDropSingle :: new ( join_handle) ,
794+ stream : builder. build ( ) ,
826795 }
827796 }
828797}
@@ -863,31 +832,10 @@ impl Stream for HashAggregateStream {
863832 type Item = ArrowResult < RecordBatch > ;
864833
865834 fn poll_next (
866- self : std:: pin:: Pin < & mut Self > ,
835+ mut self : std:: pin:: Pin < & mut Self > ,
867836 cx : & mut Context < ' _ > ,
868837 ) -> Poll < Option < Self :: Item > > {
869- if self . finished {
870- return Poll :: Ready ( None ) ;
871- }
872-
873- // is the output ready?
874- let this = self . project ( ) ;
875- let output_poll = this. output . poll ( cx) ;
876-
877- match output_poll {
878- Poll :: Ready ( result) => {
879- * this. finished = true ;
880-
881- // check for error in receiving channel and unwrap actual result
882- let result = match result {
883- Err ( e) => Err ( ArrowError :: ExternalError ( Box :: new ( e) ) ) , // error receiving
884- Ok ( result) => result,
885- } ;
886-
887- Poll :: Ready ( Some ( result) )
888- }
889- Poll :: Pending => Poll :: Pending ,
890- }
838+ self . stream . poll_next_unpin ( cx)
891839 }
892840}
893841
0 commit comments