Skip to content

Commit 0a96357

Browse files
committed
Consolidate panic propagation into RecordBatchReceiverStream
1 parent 742597a commit 0a96357

2 files changed

Lines changed: 124 additions & 79 deletions

File tree

datafusion/core/src/physical_plan/coalesce_partitions.rs

Lines changed: 10 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,14 @@
1919
//! into a single partition
2020
2121
use std::any::Any;
22-
use std::panic;
2322
use std::sync::Arc;
24-
use std::task::Poll;
25-
26-
use futures::{Future, Stream};
27-
use tokio::sync::mpsc;
2823

2924
use arrow::datatypes::SchemaRef;
30-
use arrow::record_batch::RecordBatch;
31-
use tokio::task::JoinSet;
3225

3326
use super::expressions::PhysicalSortExpr;
3427
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
35-
use super::{RecordBatchStream, Statistics};
28+
use super::stream::RecordBatchReceiverStreamBuilder;
29+
use super::Statistics;
3630
use crate::error::{DataFusionError, Result};
3731
use crate::physical_plan::{
3832
DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning,
@@ -138,28 +132,25 @@ impl ExecutionPlan for CoalescePartitionsExec {
138132
// use a stream that allows each sender to put in at
139133
// least one result in an attempt to maximize
140134
// parallelism.
141-
let (sender, receiver) =
142-
mpsc::channel::<Result<RecordBatch>>(input_partitions);
135+
let mut builder = RecordBatchReceiverStreamBuilder::new(
136+
self.schema(),
137+
input_partitions,
138+
);
143139

144140
// spawn independent tasks whose resulting streams (of batches)
145141
// are sent to the channel for consumption.
146-
let mut tasks = JoinSet::new();
147142
for part_i in 0..input_partitions {
143+
let sender = builder.tx();
148144
spawn_execution(
149-
&mut tasks,
145+
builder.join_set_mut(),
150146
self.input.clone(),
151-
sender.clone(),
147+
sender,
152148
part_i,
153149
context.clone(),
154150
);
155151
}
156152

157-
Ok(Box::pin(MergeStream {
158-
input: receiver,
159-
schema: self.schema(),
160-
baseline_metrics,
161-
tasks,
162-
}))
153+
Ok(builder.build())
163154
}
164155
}
165156
}
@@ -185,53 +176,6 @@ impl ExecutionPlan for CoalescePartitionsExec {
185176
}
186177
}
187178

188-
struct MergeStream {
189-
schema: SchemaRef,
190-
input: mpsc::Receiver<Result<RecordBatch>>,
191-
baseline_metrics: BaselineMetrics,
192-
tasks: JoinSet<()>,
193-
}
194-
195-
impl Stream for MergeStream {
196-
type Item = Result<RecordBatch>;
197-
198-
fn poll_next(
199-
mut self: std::pin::Pin<&mut Self>,
200-
cx: &mut std::task::Context<'_>,
201-
) -> Poll<Option<Self::Item>> {
202-
let poll = self.input.poll_recv(cx);
203-
204-
// If the input stream is done, wait for all tasks to finish and return
205-
// the failure if any.
206-
if let Poll::Ready(None) = poll {
207-
let fut = self.tasks.join_next();
208-
tokio::pin!(fut);
209-
210-
match fut.poll(cx) {
211-
Poll::Ready(task_poll) => {
212-
if let Some(Err(e)) = task_poll {
213-
if e.is_panic() {
214-
panic::resume_unwind(e.into_panic());
215-
}
216-
return Poll::Ready(Some(Err(DataFusionError::Execution(
217-
format!("{e:?}"),
218-
))));
219-
}
220-
}
221-
Poll::Pending => {}
222-
}
223-
}
224-
225-
self.baseline_metrics.record_poll(poll)
226-
}
227-
}
228-
229-
impl RecordBatchStream for MergeStream {
230-
fn schema(&self) -> SchemaRef {
231-
self.schema.clone()
232-
}
233-
}
234-
235179
#[cfg(test)]
236180
mod tests {
237181

datafusion/core/src/physical_plan/stream.rs

Lines changed: 114 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,111 @@
1919
2020
use crate::error::Result;
2121
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
22-
use futures::{Stream, StreamExt};
22+
use datafusion_common::DataFusionError;
23+
use futures::stream::BoxStream;
24+
use futures::{Future, Stream, StreamExt};
2325
use pin_project_lite::pin_project;
24-
use tokio::task::JoinHandle;
26+
use tokio::task::{JoinHandle, JoinSet};
2527
use tokio_stream::wrappers::ReceiverStream;
2628

2729
use super::common::AbortOnDropSingle;
2830
use super::{RecordBatchStream, SendableRecordBatchStream};
2931

32+
/// Builder for [`RecordBatchReceiverStream`]
33+
pub struct RecordBatchReceiverStreamBuilder {
34+
tx: tokio::sync::mpsc::Sender<Result<RecordBatch>>,
35+
rx: tokio::sync::mpsc::Receiver<Result<RecordBatch>>,
36+
schema: SchemaRef,
37+
join_set: JoinSet<()>,
38+
}
39+
40+
impl RecordBatchReceiverStreamBuilder {
41+
/// Create a new channel with the specified buffer size
42+
pub fn new(schema: SchemaRef, capacity: usize) -> Self {
43+
let (tx, rx) = tokio::sync::mpsc::channel(capacity);
44+
45+
Self {
46+
tx,
47+
rx,
48+
schema,
49+
join_set: JoinSet::new(),
50+
}
51+
}
52+
53+
/// Get a handle for sending [`RecordBatch`]es to the output
54+
pub fn tx(&self) -> tokio::sync::mpsc::Sender<Result<RecordBatch>> {
55+
self.tx.clone()
56+
}
57+
58+
/// Get a handle to the `JoinSet` on which tasks are launched
59+
pub fn join_set_mut(&mut self) -> &mut JoinSet<()> {
60+
&mut self.join_set
61+
}
62+
63+
/// Spawn task that will be aborted if this builder (or the stream
64+
/// built from it) is dropped
65+
///
66+
/// This is often used to spawn tasks that write to the sender
67+
/// retrieved from `Self::tx`
68+
pub fn spawn<F>(&mut self, task: F)
69+
where
70+
F: Future<Output = ()>,
71+
F: Send + 'static,
72+
{
73+
self.join_set.spawn(task);
74+
}
75+
76+
/// Create a stream of all `RecordBatch`es written to `tx`
77+
pub fn build(self) -> SendableRecordBatchStream {
78+
let Self {
79+
tx,
80+
rx,
81+
schema,
82+
mut join_set,
83+
} = self;
84+
85+
// don't need tx
86+
drop(tx);
87+
88+
let schema = schema.clone();
89+
90+
// future that checks the result of the join set
91+
let check = async move {
92+
while let Some(result) = join_set.join_next().await {
93+
match result {
94+
Ok(()) => continue, // nothing to report
95+
// This means a tokio task error, likely a panic
96+
Err(e) => {
97+
if e.is_panic() {
98+
// re-raise the panic in the task that is polling this stream
99+
std::panic::resume_unwind(e.into_panic());
100+
} else {
101+
return Some(Err(DataFusionError::Execution(format!(
102+
"Task error: {e}"
103+
))));
104+
}
105+
}
106+
}
107+
}
108+
None
109+
};
110+
111+
let check_stream = futures::stream::once(check)
112+
// unwrap Option / only return the error
113+
.filter_map(|item| async move { item });
114+
115+
let inner = ReceiverStream::new(rx).chain(check_stream).boxed();
116+
117+
Box::pin(RecordBatchReceiverStream { schema, inner })
118+
}
119+
}
120+
30121
/// Adapter for a tokio [`ReceiverStream`] that implements the
31122
/// [`SendableRecordBatchStream`]
32123
/// interface
33124
pub struct RecordBatchReceiverStream {
34125
schema: SchemaRef,
35-
36-
inner: ReceiverStream<Result<RecordBatch>>,
37-
38-
#[allow(dead_code)]
39-
drop_helper: AbortOnDropSingle<()>,
126+
inner: BoxStream<'static, Result<RecordBatch>>,
40127
}
41128

42129
impl RecordBatchReceiverStream {
@@ -48,12 +135,26 @@ impl RecordBatchReceiverStream {
48135
join_handle: JoinHandle<()>,
49136
) -> SendableRecordBatchStream {
50137
let schema = schema.clone();
51-
let inner = ReceiverStream::new(rx);
52-
Box::pin(Self {
53-
schema,
54-
inner,
55-
drop_helper: AbortOnDropSingle::new(join_handle),
56-
})
138+
// wrap with AbortOnDropSingle to cancel task if this stream gets dropped prematurely
139+
let drop_helper = AbortOnDropSingle::new(join_handle);
140+
141+
// future that checks the result of the join handle
142+
let check = async move {
143+
match drop_helper.await {
144+
Ok(()) => None,
145+
// This means a tokio task error
146+
Err(e) => {
147+
Some(Err(DataFusionError::Execution(format!("Task error: {e}"))))
148+
}
149+
}
150+
};
151+
let check_stream = futures::stream::once(check)
152+
// unwrap Option / only return the error
153+
.filter_map(|item| async move { item });
154+
155+
let inner = ReceiverStream::new(rx).chain(check_stream).boxed();
156+
157+
Box::pin(Self { schema, inner })
57158
}
58159
}
59160

0 commit comments

Comments
 (0)