1919
2020use crate :: error:: Result ;
2121use arrow:: { datatypes:: SchemaRef , record_batch:: RecordBatch } ;
22- use futures:: { Stream , StreamExt } ;
22+ use datafusion_common:: DataFusionError ;
23+ use futures:: stream:: BoxStream ;
24+ use futures:: { Future , Stream , StreamExt } ;
2325use pin_project_lite:: pin_project;
24- use tokio:: task:: JoinHandle ;
26+ use tokio:: task:: { JoinHandle , JoinSet } ;
2527use tokio_stream:: wrappers:: ReceiverStream ;
2628
2729use super :: common:: AbortOnDropSingle ;
2830use super :: { RecordBatchStream , SendableRecordBatchStream } ;
2931
32+ /// Builder for [`RecordBatchReceiverStream`]
33+ pub struct RecordBatchReceiverStreamBuilder {
34+ tx : tokio:: sync:: mpsc:: Sender < Result < RecordBatch > > ,
35+ rx : tokio:: sync:: mpsc:: Receiver < Result < RecordBatch > > ,
36+ schema : SchemaRef ,
37+ join_set : JoinSet < ( ) > ,
38+ }
39+
40+ impl RecordBatchReceiverStreamBuilder {
41+ /// create new channels with the specified buffer size
42+ pub fn new ( schema : SchemaRef , capacity : usize ) -> Self {
43+ let ( tx, rx) = tokio:: sync:: mpsc:: channel ( capacity) ;
44+
45+ Self {
46+ tx,
47+ rx,
48+ schema,
49+ join_set : JoinSet :: new ( ) ,
50+ }
51+ }
52+
53+ /// Get a handle for sending [`RecordBatch`]es to the output
54+ pub fn tx ( & self ) -> tokio:: sync:: mpsc:: Sender < Result < RecordBatch > > {
55+ self . tx . clone ( )
56+ }
57+
58+ /// Get a handle to the `JoinSet` on which tasks are launched
59+ pub fn join_set_mut ( & mut self ) -> & mut JoinSet < ( ) > {
60+ & mut self . join_set
61+ }
62+
63+ /// Spawn task that will be aborted if this builder (or the stream
64+ /// built from it) are dropped
65+ ///
66+ /// this is often used to spawn tasks that write to the sender
67+ /// retrieved from `Self::tx`
68+ pub fn spawn < F > ( & mut self , task : F )
69+ where
70+ F : Future < Output = ( ) > ,
71+ F : Send + ' static ,
72+ {
73+ self . join_set . spawn ( task) ;
74+ }
75+
76+ /// Create a stream of all `RecordBatch`es written to `tx`
77+ pub fn build ( self ) -> SendableRecordBatchStream {
78+ let Self {
79+ tx,
80+ rx,
81+ schema,
82+ mut join_set,
83+ } = self ;
84+
85+ // don't need tx
86+ drop ( tx) ;
87+
88+ let schema = schema. clone ( ) ;
89+
90+ // future that checks the result of the join set
91+ let check = async move {
92+ while let Some ( result) = join_set. join_next ( ) . await {
93+ match result {
94+ Ok ( ( ) ) => continue , // nothing to report
95+ // This means a tokio task error, likely a panic
96+ Err ( e) => {
97+ if e. is_panic ( ) {
98+ // resume on the main thread
99+ std:: panic:: resume_unwind ( e. into_panic ( ) ) ;
100+ } else {
101+ return Some ( Err ( DataFusionError :: Execution ( format ! (
102+ "Task error: {e}"
103+ ) ) ) ) ;
104+ }
105+ }
106+ }
107+ }
108+ None
109+ } ;
110+
111+ let check_stream = futures:: stream:: once ( check)
112+ // unwrap Option / only return the error
113+ . filter_map ( |item| async move { item } ) ;
114+
115+ let inner = ReceiverStream :: new ( rx) . chain ( check_stream) . boxed ( ) ;
116+
117+ Box :: pin ( RecordBatchReceiverStream { schema, inner } )
118+ }
119+ }
120+
30121/// Adapter for a tokio [`ReceiverStream`] that implements the
31122/// [`SendableRecordBatchStream`]
32123/// interface
33124pub struct RecordBatchReceiverStream {
34125 schema : SchemaRef ,
35-
36- inner : ReceiverStream < Result < RecordBatch > > ,
37-
38- #[ allow( dead_code) ]
39- drop_helper : AbortOnDropSingle < ( ) > ,
126+ inner : BoxStream < ' static , Result < RecordBatch > > ,
40127}
41128
42129impl RecordBatchReceiverStream {
@@ -48,12 +135,26 @@ impl RecordBatchReceiverStream {
48135 join_handle : JoinHandle < ( ) > ,
49136 ) -> SendableRecordBatchStream {
50137 let schema = schema. clone ( ) ;
51- let inner = ReceiverStream :: new ( rx) ;
52- Box :: pin ( Self {
53- schema,
54- inner,
55- drop_helper : AbortOnDropSingle :: new ( join_handle) ,
56- } )
138+ // wrap with AbortOnDropSingle to cancel task if this stream gets dropped prematurely
139+ let drop_helper = AbortOnDropSingle :: new ( join_handle) ;
140+
141+ // future that checks the result of the join handle
142+ let check = async move {
143+ match drop_helper. await {
144+ Ok ( ( ) ) => None ,
145+ // This means a tokio task error
146+ Err ( e) => {
147+ Some ( Err ( DataFusionError :: Execution ( format ! ( "Task error: {e}" ) ) ) )
148+ }
149+ }
150+ } ;
151+ let check_stream = futures:: stream:: once ( check)
152+ // unwrap Option / only return the error
153+ . filter_map ( |item| async move { item } ) ;
154+
155+ let inner = ReceiverStream :: new ( rx) . chain ( check_stream) . boxed ( ) ;
156+
157+ Box :: pin ( Self { schema, inner } )
57158 }
58159}
59160
0 commit comments