Skip to content

Commit fe8b82d

Browse files
committed
Consolidate panic propagation into RecordBatchReceiverStream
1 parent 742597a commit fe8b82d

8 files changed

Lines changed: 241 additions & 235 deletions

File tree

datafusion/core/src/physical_plan/analyze.rs

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@ use crate::{
2929
};
3030
use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch};
3131
use futures::StreamExt;
32-
use tokio::task::JoinSet;
3332

3433
use super::expressions::PhysicalSortExpr;
35-
use super::stream::RecordBatchStreamAdapter;
34+
use super::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter};
3635
use super::{Distribution, SendableRecordBatchStream};
3736
use crate::execution::context::TaskContext;
3837

@@ -121,23 +120,15 @@ impl ExecutionPlan for AnalyzeExec {
121120
// Gather futures that will run each input partition in
122121
// parallel (on a separate tokio task) using a JoinSet to
123122
// cancel outstanding futures on drop
124-
let mut set = JoinSet::new();
125123
let num_input_partitions = self.input.output_partitioning().partition_count();
124+
let mut builder =
125+
RecordBatchReceiverStream::builder(self.schema(), num_input_partitions);
126126

127127
for input_partition in 0..num_input_partitions {
128-
let input_stream = self.input.execute(input_partition, context.clone());
129-
130-
set.spawn(async move {
131-
let mut total_rows = 0;
132-
let mut input_stream = input_stream?;
133-
while let Some(batch) = input_stream.next().await {
134-
let batch = batch?;
135-
total_rows += batch.num_rows();
136-
}
137-
Ok(total_rows) as Result<usize>
138-
});
128+
builder.run_input(self.input.clone(), input_partition, context.clone());
139129
}
140130

131+
// Create future that computes the final output
141132
let start = Instant::now();
142133
let captured_input = self.input.clone();
143134
let captured_schema = self.schema.clone();
@@ -146,18 +137,12 @@ impl ExecutionPlan for AnalyzeExec {
146137
// future that gathers the results from all the tasks in the
147138
// JoinSet that computes the overall row count and final
148139
// record batch
140+
let mut input_stream = builder.build();
149141
let output = async move {
150142
let mut total_rows = 0;
151-
while let Some(res) = set.join_next().await {
152-
// translate join errors (aka task panics) into ExecutionErrors
153-
match res {
154-
Ok(row_count) => total_rows += row_count?,
155-
Err(e) => {
156-
return Err(DataFusionError::Execution(format!(
157-
"Join error in AnalyzeExec: {e}"
158-
)))
159-
}
160-
}
143+
while let Some(batch) = input_stream.next().await {
144+
let batch = batch?;
145+
total_rows += batch.num_rows();
161146
}
162147

163148
let duration = Instant::now() - start;

datafusion/core/src/physical_plan/coalesce_partitions.rs

Lines changed: 7 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,21 @@
1919
//! into a single partition
2020
2121
use std::any::Any;
22-
use std::panic;
2322
use std::sync::Arc;
24-
use std::task::Poll;
25-
26-
use futures::{Future, Stream};
27-
use tokio::sync::mpsc;
2823

2924
use arrow::datatypes::SchemaRef;
30-
use arrow::record_batch::RecordBatch;
31-
use tokio::task::JoinSet;
3225

3326
use super::expressions::PhysicalSortExpr;
3427
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
35-
use super::{RecordBatchStream, Statistics};
28+
use super::stream::{ObservedStream, RecordBatchReceiverStream};
29+
use super::Statistics;
3630
use crate::error::{DataFusionError, Result};
3731
use crate::physical_plan::{
3832
DisplayFormatType, EquivalenceProperties, ExecutionPlan, Partitioning,
3933
};
4034

4135
use super::SendableRecordBatchStream;
4236
use crate::execution::context::TaskContext;
43-
use crate::physical_plan::common::spawn_execution;
4437

4538
/// Merge execution plan executes partitions in parallel and combines them into a single
4639
/// partition. No guarantees are made about the order of the resulting partition.
@@ -138,28 +131,17 @@ impl ExecutionPlan for CoalescePartitionsExec {
138131
// use a stream that allows each sender to put in at
139132
// least one result in an attempt to maximize
140133
// parallelism.
141-
let (sender, receiver) =
142-
mpsc::channel::<Result<RecordBatch>>(input_partitions);
134+
let mut builder =
135+
RecordBatchReceiverStream::builder(self.schema(), input_partitions);
143136

144137
// spawn independent tasks whose resulting streams (of batches)
145138
// are sent to the channel for consumption.
146-
let mut tasks = JoinSet::new();
147139
for part_i in 0..input_partitions {
148-
spawn_execution(
149-
&mut tasks,
150-
self.input.clone(),
151-
sender.clone(),
152-
part_i,
153-
context.clone(),
154-
);
140+
builder.run_input(self.input.clone(), part_i, context.clone());
155141
}
156142

157-
Ok(Box::pin(MergeStream {
158-
input: receiver,
159-
schema: self.schema(),
160-
baseline_metrics,
161-
tasks,
162-
}))
143+
let stream = builder.build();
144+
return Ok(Box::pin(ObservedStream::new(stream, baseline_metrics)));
163145
}
164146
}
165147
}
@@ -185,53 +167,6 @@ impl ExecutionPlan for CoalescePartitionsExec {
185167
}
186168
}
187169

188-
struct MergeStream {
189-
schema: SchemaRef,
190-
input: mpsc::Receiver<Result<RecordBatch>>,
191-
baseline_metrics: BaselineMetrics,
192-
tasks: JoinSet<()>,
193-
}
194-
195-
impl Stream for MergeStream {
196-
type Item = Result<RecordBatch>;
197-
198-
fn poll_next(
199-
mut self: std::pin::Pin<&mut Self>,
200-
cx: &mut std::task::Context<'_>,
201-
) -> Poll<Option<Self::Item>> {
202-
let poll = self.input.poll_recv(cx);
203-
204-
// If the input stream is done, wait for all tasks to finish and return
205-
// the failure if any.
206-
if let Poll::Ready(None) = poll {
207-
let fut = self.tasks.join_next();
208-
tokio::pin!(fut);
209-
210-
match fut.poll(cx) {
211-
Poll::Ready(task_poll) => {
212-
if let Some(Err(e)) = task_poll {
213-
if e.is_panic() {
214-
panic::resume_unwind(e.into_panic());
215-
}
216-
return Poll::Ready(Some(Err(DataFusionError::Execution(
217-
format!("{e:?}"),
218-
))));
219-
}
220-
}
221-
Poll::Pending => {}
222-
}
223-
}
224-
225-
self.baseline_metrics.record_poll(poll)
226-
}
227-
}
228-
229-
impl RecordBatchStream for MergeStream {
230-
fn schema(&self) -> SchemaRef {
231-
self.schema.clone()
232-
}
233-
}
234-
235170
#[cfg(test)]
236171
mod tests {
237172

datafusion/core/src/physical_plan/common.rs

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,23 @@
1919
2020
use super::SendableRecordBatchStream;
2121
use crate::error::{DataFusionError, Result};
22-
use crate::execution::context::TaskContext;
2322
use crate::execution::memory_pool::MemoryReservation;
2423
use crate::physical_plan::stream::RecordBatchReceiverStream;
25-
use crate::physical_plan::{displayable, ColumnStatistics, ExecutionPlan, Statistics};
24+
use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics};
2625
use arrow::datatypes::Schema;
2726
use arrow::ipc::writer::{FileWriter, IpcWriteOptions};
2827
use arrow::record_batch::RecordBatch;
2928
use datafusion_physical_expr::expressions::{BinaryExpr, Column};
3029
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
3130
use futures::{Future, StreamExt, TryStreamExt};
32-
use log::debug;
3331
use parking_lot::Mutex;
3432
use pin_project_lite::pin_project;
3533
use std::fs;
3634
use std::fs::{metadata, File};
3735
use std::path::{Path, PathBuf};
3836
use std::sync::Arc;
3937
use std::task::{Context, Poll};
40-
use tokio::sync::mpsc;
41-
use tokio::task::{JoinHandle, JoinSet};
38+
use tokio::task::JoinHandle;
4239

4340
/// [`MemoryReservation`] used across query execution streams
4441
pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
@@ -96,66 +93,30 @@ fn build_file_list_recurse(
9693
Ok(())
9794
}
9895

99-
/// Spawns a task to the tokio threadpool and writes its outputs to the provided mpsc sender
100-
pub(crate) fn spawn_execution(
101-
join_set: &mut JoinSet<()>,
102-
input: Arc<dyn ExecutionPlan>,
103-
output: mpsc::Sender<Result<RecordBatch>>,
104-
partition: usize,
105-
context: Arc<TaskContext>,
106-
) {
107-
join_set.spawn(async move {
108-
let mut stream = match input.execute(partition, context) {
109-
Err(e) => {
110-
// If send fails, plan being torn down,
111-
// there is no place to send the error.
112-
output.send(Err(e)).await.ok();
113-
debug!(
114-
"Stopping execution: error executing input: {}",
115-
displayable(input.as_ref()).one_line()
116-
);
117-
return;
118-
}
119-
Ok(stream) => stream,
120-
};
121-
122-
while let Some(item) = stream.next().await {
123-
// If send fails, plan being torn down,
124-
// there is no place to send the error.
125-
if output.send(item).await.is_err() {
126-
debug!(
127-
"Stopping execution: output is gone, plan cancelling: {}",
128-
displayable(input.as_ref()).one_line()
129-
);
130-
return;
131-
}
132-
}
133-
});
134-
}
135-
13696
/// If running in a tokio context spawns the execution of `stream` to a separate task
13797
/// allowing it to execute in parallel with an intermediate buffer of size `buffer`
13898
pub(crate) fn spawn_buffered(
13999
mut input: SendableRecordBatchStream,
140100
buffer: usize,
141101
) -> SendableRecordBatchStream {
142102
// Use tokio only if running from a tokio context (#2201)
143-
let handle = match tokio::runtime::Handle::try_current() {
144-
Ok(handle) => handle,
145-
Err(_) => return input,
103+
if let Err(_) = tokio::runtime::Handle::try_current() {
104+
return input;
146105
};
147106

148-
let schema = input.schema();
149-
let (sender, receiver) = mpsc::channel(buffer);
150-
let join = handle.spawn(async move {
107+
let mut builder = RecordBatchReceiverStream::builder(input.schema(), buffer);
108+
109+
let sender = builder.tx();
110+
111+
builder.spawn(async move {
151112
while let Some(item) = input.next().await {
152113
if sender.send(item).await.is_err() {
153114
return;
154115
}
155116
}
156117
});
157118

158-
RecordBatchReceiverStream::create(&schema, receiver, join)
119+
builder.build()
159120
}
160121

161122
/// Computes the statistics for an in-memory RecordBatch

datafusion/core/src/physical_plan/sorts/sort.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ use std::io::BufReader;
5252
use std::path::{Path, PathBuf};
5353
use std::sync::Arc;
5454
use tempfile::NamedTempFile;
55-
use tokio::sync::mpsc::{Receiver, Sender};
55+
use tokio::sync::mpsc::Sender;
5656
use tokio::task;
5757

5858
struct ExternalSorterMetrics {
@@ -373,18 +373,16 @@ fn read_spill_as_stream(
373373
path: NamedTempFile,
374374
schema: SchemaRef,
375375
) -> Result<SendableRecordBatchStream> {
376-
let (sender, receiver): (Sender<Result<RecordBatch>>, Receiver<Result<RecordBatch>>) =
377-
tokio::sync::mpsc::channel(2);
378-
let join_handle = task::spawn_blocking(move || {
376+
let mut builder = RecordBatchReceiverStream::builder(schema, 2);
377+
let sender = builder.tx();
378+
379+
builder.spawn_blocking(move || {
379380
if let Err(e) = read_spill(sender, path.path()) {
380381
error!("Failure while reading spill file: {:?}. Error: {}", path, e);
381382
}
382383
});
383-
Ok(RecordBatchReceiverStream::create(
384-
&schema,
385-
receiver,
386-
join_handle,
387-
))
384+
385+
Ok(builder.build())
388386
}
389387

390388
fn write_sorted(

datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -792,21 +792,20 @@ mod tests {
792792
let mut streams = Vec::with_capacity(partition_count);
793793

794794
for partition in 0..partition_count {
795-
let (sender, receiver) = tokio::sync::mpsc::channel(1);
795+
let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1);
796+
797+
let sender = builder.tx();
798+
796799
let mut stream = batches.execute(partition, task_ctx.clone()).unwrap();
797-
let join_handle = tokio::spawn(async move {
800+
builder.spawn(async move {
798801
while let Some(batch) = stream.next().await {
799802
sender.send(batch).await.unwrap();
800803
// This causes the MergeStream to wait for more input
801804
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
802805
}
803806
});
804807

805-
streams.push(RecordBatchReceiverStream::create(
806-
&schema,
807-
receiver,
808-
join_handle,
809-
));
808+
streams.push(builder.build());
810809
}
811810

812811
let metrics = ExecutionPlanMetricsSet::new();

0 commit comments

Comments
 (0)