Skip to content

Commit a1c886d

Browse files
committed
fix: Unwind panic in spawned threads
Signed-off-by: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com>
1 parent 03e683e commit a1c886d

8 files changed

Lines changed: 827 additions & 294 deletions

File tree

datafusion-cli/Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/core/src/physical_plan/analyze.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,16 +123,19 @@ impl ExecutionPlan for AnalyzeExec {
123123
)));
124124
}
125125

126-
let (tx, rx) = tokio::sync::mpsc::channel(input_partitions);
126+
let mut builder =
127+
RecordBatchReceiverStream::builder(self.schema(), input_partitions);
128+
let tx = builder.tx();
127129

128130
let captured_input = self.input.clone();
129131
let mut input_stream = captured_input.execute(0, context).await?;
130132
let captured_schema = self.schema.clone();
131133
let verbose = self.verbose;
132134

133-
// Task reads batches the input and when complete produce a
134-
// RecordBatch with a report that is written to `tx` when done
135-
let join_handle = tokio::task::spawn(async move {
135+
// Task reads batches from the input and when complete produces
136+
// a RecordBatch with a report that is written to `tx` when
137+
// done. Panics from this task are propagated via the builder.
138+
builder.spawn(async move {
136139
let start = Instant::now();
137140
let mut total_rows = 0;
138141

@@ -201,11 +204,7 @@ impl ExecutionPlan for AnalyzeExec {
201204
tx.send(maybe_batch).await.ok();
202205
});
203206

204-
Ok(RecordBatchReceiverStream::create(
205-
&self.schema,
206-
rx,
207-
join_handle,
208-
))
207+
Ok(builder.build())
209208
}
210209

211210
fn fmt_as(

datafusion/core/src/physical_plan/coalesce_partitions.rs

Lines changed: 26 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,20 @@
2020
2121
use std::any::Any;
2222
use std::sync::Arc;
23-
use std::task::Poll;
24-
25-
use futures::channel::mpsc;
26-
use futures::Stream;
2723

2824
use async_trait::async_trait;
2925

30-
use arrow::record_batch::RecordBatch;
31-
use arrow::{datatypes::SchemaRef, error::Result as ArrowResult};
26+
use arrow::datatypes::SchemaRef;
3227

33-
use super::common::AbortOnDropMany;
3428
use super::expressions::PhysicalSortExpr;
3529
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
36-
use super::{RecordBatchStream, Statistics};
30+
use super::stream::{ObservedStream, RecordBatchReceiverStream};
31+
use super::Statistics;
3732
use crate::error::{DataFusionError, Result};
3833
use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning};
3934

4035
use super::SendableRecordBatchStream;
4136
use crate::execution::context::TaskContext;
42-
use crate::physical_plan::common::spawn_execution;
43-
use pin_project_lite::pin_project;
4437

4538
/// Merge execution plan executes partitions in parallel and combines them into a single
4639
/// partition. No guarantees are made about the order of the resulting partition.
@@ -134,27 +127,17 @@ impl ExecutionPlan for CoalescePartitionsExec {
134127
// use a stream that allows each sender to put in at
135128
// least one result in an attempt to maximize
136129
// parallelism.
137-
let (sender, receiver) =
138-
mpsc::channel::<ArrowResult<RecordBatch>>(input_partitions);
130+
let mut builder =
131+
RecordBatchReceiverStream::builder(self.schema(), input_partitions);
139132

140133
// spawn independent tasks whose resulting streams (of batches)
141134
// are sent to the channel for consumption.
142-
let mut join_handles = Vec::with_capacity(input_partitions);
143135
for part_i in 0..input_partitions {
144-
join_handles.push(spawn_execution(
145-
self.input.clone(),
146-
sender.clone(),
147-
part_i,
148-
context.clone(),
149-
));
136+
builder.run_input(self.input.clone(), part_i, context.clone());
150137
}
151138

152-
Ok(Box::pin(MergeStream {
153-
input: receiver,
154-
schema: self.schema(),
155-
baseline_metrics,
156-
drop_helper: AbortOnDropMany(join_handles),
157-
}))
139+
let stream = builder.build();
140+
Ok(Box::pin(ObservedStream::new(stream, baseline_metrics)))
158141
}
159142
}
160143
}
@@ -180,35 +163,6 @@ impl ExecutionPlan for CoalescePartitionsExec {
180163
}
181164
}
182165

183-
pin_project! {
184-
struct MergeStream {
185-
schema: SchemaRef,
186-
#[pin]
187-
input: mpsc::Receiver<ArrowResult<RecordBatch>>,
188-
baseline_metrics: BaselineMetrics,
189-
drop_helper: AbortOnDropMany<()>,
190-
}
191-
}
192-
193-
impl Stream for MergeStream {
194-
type Item = ArrowResult<RecordBatch>;
195-
196-
fn poll_next(
197-
self: std::pin::Pin<&mut Self>,
198-
cx: &mut std::task::Context<'_>,
199-
) -> Poll<Option<Self::Item>> {
200-
let this = self.project();
201-
let poll = this.input.poll_next(cx);
202-
this.baseline_metrics.record_poll(poll)
203-
}
204-
}
205-
206-
impl RecordBatchStream for MergeStream {
207-
fn schema(&self) -> SchemaRef {
208-
self.schema.clone()
209-
}
210-
}
211-
212166
#[cfg(test)]
213167
mod tests {
214168

@@ -220,7 +174,9 @@ mod tests {
220174
use crate::physical_plan::file_format::{CsvExec, FileScanConfig};
221175
use crate::physical_plan::{collect, common};
222176
use crate::prelude::SessionContext;
223-
use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
177+
use crate::test::exec::{
178+
assert_strong_count_converges_to_zero, BlockingExec, PanicExec,
179+
};
224180
use crate::test::{self, assert_is_pending};
225181
use crate::test_util;
226182

@@ -288,4 +244,19 @@ mod tests {
288244

289245
Ok(())
290246
}
247+
248+
#[tokio::test]
249+
#[should_panic(expected = "PanickingStream did panic")]
250+
async fn test_panic() {
251+
let session_ctx = SessionContext::new();
252+
let task_ctx = session_ctx.task_ctx();
253+
let schema =
254+
Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
255+
256+
let panicking_exec = Arc::new(PanicExec::new(Arc::clone(&schema), 2));
257+
let coalesce_partitions_exec =
258+
Arc::new(CoalescePartitionsExec::new(panicking_exec));
259+
260+
collect(coalesce_partitions_exec, task_ctx).await.unwrap();
261+
}
291262
}

datafusion/core/src/physical_plan/hash_aggregate.rs

Lines changed: 30 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,11 @@ use std::task::{Context, Poll};
2323
use std::vec;
2424

2525
use ahash::RandomState;
26-
use futures::{
27-
stream::{Stream, StreamExt},
28-
Future,
29-
};
26+
use futures::stream::{Stream, StreamExt};
3027

3128
use crate::error::Result;
3229
use crate::physical_plan::hash_utils::create_hashes;
30+
use crate::physical_plan::stream::RecordBatchReceiverStream;
3331
use crate::physical_plan::{
3432
Accumulator, AggregateExpr, DisplayFormatType, Distribution, ExecutionPlan,
3533
Partitioning, PhysicalExpr,
@@ -39,19 +37,17 @@ use crate::scalar::ScalarValue;
3937
use arrow::{array::ArrayRef, compute, compute::cast};
4038
use arrow::{
4139
array::{Array, UInt32Builder},
42-
error::{ArrowError, Result as ArrowResult},
40+
error::Result as ArrowResult,
4341
};
4442
use arrow::{
4543
datatypes::{Field, Schema, SchemaRef},
4644
record_batch::RecordBatch,
4745
};
4846
use hashbrown::raw::RawTable;
49-
use pin_project_lite::pin_project;
5047

5148
use crate::execution::context::TaskContext;
5249
use async_trait::async_trait;
5350

54-
use super::common::AbortOnDropSingle;
5551
use super::expressions::PhysicalSortExpr;
5652
use super::metrics::{
5753
self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
@@ -356,14 +352,9 @@ Example: average
356352
* Once all N record batches arrive, `merge` is performed, which builds a RecordBatch with N rows and 2 columns.
357353
* Finally, `get_value` returns an array with one entry computed from the state
358354
*/
359-
pin_project! {
360-
struct GroupedHashAggregateStream {
361-
schema: SchemaRef,
362-
#[pin]
363-
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
364-
finished: bool,
365-
drop_helper: AbortOnDropSingle<()>,
366-
}
355+
struct GroupedHashAggregateStream {
356+
schema: SchemaRef,
357+
stream: SendableRecordBatchStream,
367358
}
368359

369360
fn group_aggregate_batch(
@@ -570,12 +561,16 @@ impl GroupedHashAggregateStream {
570561
input: SendableRecordBatchStream,
571562
baseline_metrics: BaselineMetrics,
572563
) -> Self {
573-
let (tx, rx) = futures::channel::oneshot::channel();
564+
// Use the panic-propagating builder so that panics in the
565+
// compute task are re-raised on the consumer side instead of
566+
// being reported as a closed channel.
567+
let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1);
568+
let tx = builder.tx();
574569

575570
let schema_clone = schema.clone();
576571
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
577572

578-
let join_handle = tokio::spawn(async move {
573+
builder.spawn(async move {
579574
let result = compute_grouped_hash_aggregate(
580575
mode,
581576
schema_clone,
@@ -588,14 +583,12 @@ impl GroupedHashAggregateStream {
588583
.record_output(&baseline_metrics);
589584

590585
// failing here is OK, the receiver is gone and does not care about the result
591-
tx.send(result).ok();
586+
tx.send(result).await.ok();
592587
});
593588

594589
Self {
595590
schema,
596-
output: rx,
597-
finished: false,
598-
drop_helper: AbortOnDropSingle::new(join_handle),
591+
stream: builder.build(),
599592
}
600593
}
601594
}
@@ -647,31 +640,10 @@ impl Stream for GroupedHashAggregateStream {
647640
type Item = ArrowResult<RecordBatch>;
648641

649642
fn poll_next(
650-
self: std::pin::Pin<&mut Self>,
643+
mut self: std::pin::Pin<&mut Self>,
651644
cx: &mut Context<'_>,
652645
) -> Poll<Option<Self::Item>> {
653-
if self.finished {
654-
return Poll::Ready(None);
655-
}
656-
657-
// is the output ready?
658-
let this = self.project();
659-
let output_poll = this.output.poll(cx);
660-
661-
match output_poll {
662-
Poll::Ready(result) => {
663-
*this.finished = true;
664-
665-
// check for error in receiving channel and unwrap actual result
666-
let result = match result {
667-
Err(e) => Err(ArrowError::ExternalError(Box::new(e))), // error receiving
668-
Ok(result) => result,
669-
};
670-
671-
Poll::Ready(Some(result))
672-
}
673-
Poll::Pending => Poll::Pending,
674-
}
646+
self.stream.poll_next_unpin(cx)
675647
}
676648
}
677649

@@ -748,15 +720,10 @@ fn aggregate_expressions(
748720
}
749721
}
750722

751-
pin_project! {
752-
/// stream struct for hash aggregation
753-
pub struct HashAggregateStream {
754-
schema: SchemaRef,
755-
#[pin]
756-
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
757-
finished: bool,
758-
drop_helper: AbortOnDropSingle<()>,
759-
}
723+
/// stream struct for hash aggregation
724+
pub struct HashAggregateStream {
725+
schema: SchemaRef,
726+
stream: SendableRecordBatchStream,
760727
}
761728

762729
/// Special case aggregate with no groups
@@ -799,11 +766,15 @@ impl HashAggregateStream {
799766
input: SendableRecordBatchStream,
800767
baseline_metrics: BaselineMetrics,
801768
) -> Self {
802-
let (tx, rx) = futures::channel::oneshot::channel();
769+
// Use the panic-propagating builder so that panics in the
770+
// compute task are re-raised on the consumer side instead of
771+
// being reported as a closed channel.
772+
let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1);
773+
let tx = builder.tx();
803774

804775
let schema_clone = schema.clone();
805776
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
806-
let join_handle = tokio::spawn(async move {
777+
builder.spawn(async move {
807778
let result = compute_hash_aggregate(
808779
mode,
809780
schema_clone,
@@ -815,14 +786,12 @@ impl HashAggregateStream {
815786
.record_output(&baseline_metrics);
816787

817788
// failing here is OK, the receiver is gone and does not care about the result
818-
tx.send(result).ok();
789+
tx.send(result).await.ok();
819790
});
820791

821792
Self {
822793
schema,
823-
output: rx,
824-
finished: false,
825-
drop_helper: AbortOnDropSingle::new(join_handle),
794+
stream: builder.build(),
826795
}
827796
}
828797
}
@@ -863,31 +832,10 @@ impl Stream for HashAggregateStream {
863832
type Item = ArrowResult<RecordBatch>;
864833

865834
fn poll_next(
866-
self: std::pin::Pin<&mut Self>,
835+
mut self: std::pin::Pin<&mut Self>,
867836
cx: &mut Context<'_>,
868837
) -> Poll<Option<Self::Item>> {
869-
if self.finished {
870-
return Poll::Ready(None);
871-
}
872-
873-
// is the output ready?
874-
let this = self.project();
875-
let output_poll = this.output.poll(cx);
876-
877-
match output_poll {
878-
Poll::Ready(result) => {
879-
*this.finished = true;
880-
881-
// check for error in receiving channel and unwrap actual result
882-
let result = match result {
883-
Err(e) => Err(ArrowError::ExternalError(Box::new(e))), // error receiving
884-
Ok(result) => result,
885-
};
886-
887-
Poll::Ready(Some(result))
888-
}
889-
Poll::Pending => Poll::Pending,
890-
}
838+
self.stream.poll_next_unpin(cx)
891839
}
892840
}
893841

0 commit comments

Comments
 (0)