diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index da6088cd19631..8572405f9c071 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1663,10 +1663,11 @@ dependencies = [ [[package]] name = "tokio" -version = "1.17.0" +version = "1.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2af73ac49756f3f7c01172e34a23e5d0216f6c32333757c2c61feb2bbff5a5ee" +checksum = "d76ce4a75fb488c605c54bf610f221cea8b0dafb53333c1a67e8ee199dcd2ae3" dependencies = [ + "autocfg", "num_cpus", "parking_lot", "pin-project-lite", diff --git a/datafusion/core/src/physical_plan/analyze.rs b/datafusion/core/src/physical_plan/analyze.rs index f8050f16ce398..5bad6102b342e 100644 --- a/datafusion/core/src/physical_plan/analyze.rs +++ b/datafusion/core/src/physical_plan/analyze.rs @@ -123,16 +123,19 @@ impl ExecutionPlan for AnalyzeExec { ))); } - let (tx, rx) = tokio::sync::mpsc::channel(input_partitions); + let mut builder = + RecordBatchReceiverStream::builder(self.schema(), input_partitions); + let tx = builder.tx(); let captured_input = self.input.clone(); let mut input_stream = captured_input.execute(0, context).await?; let captured_schema = self.schema.clone(); let verbose = self.verbose; - // Task reads batches the input and when complete produce a - // RecordBatch with a report that is written to `tx` when done - let join_handle = tokio::task::spawn(async move { + // Task reads batches from the input and when complete produces + // a RecordBatch with a report that is written to `tx` when + // done. Panics from this task are propagated via the builder. 
+ builder.spawn(async move { let start = Instant::now(); let mut total_rows = 0; @@ -201,11 +204,7 @@ impl ExecutionPlan for AnalyzeExec { tx.send(maybe_batch).await.ok(); }); - Ok(RecordBatchReceiverStream::create( - &self.schema, - rx, - join_handle, - )) + Ok(builder.build()) } fn fmt_as( diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs index 3ecbd61f2e4ac..a86a819f2b522 100644 --- a/datafusion/core/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs @@ -20,27 +20,20 @@ use std::any::Any; use std::sync::Arc; -use std::task::Poll; - -use futures::channel::mpsc; -use futures::Stream; use async_trait::async_trait; -use arrow::record_batch::RecordBatch; -use arrow::{datatypes::SchemaRef, error::Result as ArrowResult}; +use arrow::datatypes::SchemaRef; -use super::common::AbortOnDropMany; use super::expressions::PhysicalSortExpr; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use super::{RecordBatchStream, Statistics}; +use super::stream::{ObservedStream, RecordBatchReceiverStream}; +use super::Statistics; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; use super::SendableRecordBatchStream; use crate::execution::context::TaskContext; -use crate::physical_plan::common::spawn_execution; -use pin_project_lite::pin_project; /// Merge execution plan executes partitions in parallel and combines them into a single /// partition. No guarantees are made about the order of the resulting partition. @@ -134,27 +127,17 @@ impl ExecutionPlan for CoalescePartitionsExec { // use a stream that allows each sender to put in at // least one result in an attempt to maximize // parallelism. 
- let (sender, receiver) = - mpsc::channel::>(input_partitions); + let mut builder = + RecordBatchReceiverStream::builder(self.schema(), input_partitions); // spawn independent tasks whose resulting streams (of batches) // are sent to the channel for consumption. - let mut join_handles = Vec::with_capacity(input_partitions); for part_i in 0..input_partitions { - join_handles.push(spawn_execution( - self.input.clone(), - sender.clone(), - part_i, - context.clone(), - )); + builder.run_input(self.input.clone(), part_i, context.clone()); } - Ok(Box::pin(MergeStream { - input: receiver, - schema: self.schema(), - baseline_metrics, - drop_helper: AbortOnDropMany(join_handles), - })) + let stream = builder.build(); + Ok(Box::pin(ObservedStream::new(stream, baseline_metrics))) } } } @@ -180,35 +163,6 @@ impl ExecutionPlan for CoalescePartitionsExec { } } -pin_project! { - struct MergeStream { - schema: SchemaRef, - #[pin] - input: mpsc::Receiver>, - baseline_metrics: BaselineMetrics, - drop_helper: AbortOnDropMany<()>, - } -} - -impl Stream for MergeStream { - type Item = ArrowResult; - - fn poll_next( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - let this = self.project(); - let poll = this.input.poll_next(cx); - this.baseline_metrics.record_poll(poll) - } -} - -impl RecordBatchStream for MergeStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - #[cfg(test)] mod tests { @@ -220,7 +174,9 @@ mod tests { use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::{collect, common}; use crate::prelude::SessionContext; - use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; + use crate::test::exec::{ + assert_strong_count_converges_to_zero, BlockingExec, PanicExec, + }; use crate::test::{self, assert_is_pending}; use crate::test_util; @@ -288,4 +244,19 @@ mod tests { Ok(()) } + + #[tokio::test] + #[should_panic(expected = "PanickingStream did panic")] + async 
fn test_panic() { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + + let panicking_exec = Arc::new(PanicExec::new(Arc::clone(&schema), 2)); + let coalesce_partitions_exec = + Arc::new(CoalescePartitionsExec::new(panicking_exec)); + + collect(coalesce_partitions_exec, task_ctx).await.unwrap(); + } } diff --git a/datafusion/core/src/physical_plan/hash_aggregate.rs b/datafusion/core/src/physical_plan/hash_aggregate.rs index fa6b65b20b61d..6a57f520d78b8 100644 --- a/datafusion/core/src/physical_plan/hash_aggregate.rs +++ b/datafusion/core/src/physical_plan/hash_aggregate.rs @@ -23,13 +23,11 @@ use std::task::{Context, Poll}; use std::vec; use ahash::RandomState; -use futures::{ - stream::{Stream, StreamExt}, - Future, -}; +use futures::stream::{Stream, StreamExt}; use crate::error::Result; use crate::physical_plan::hash_utils::create_hashes; +use crate::physical_plan::stream::RecordBatchReceiverStream; use crate::physical_plan::{ Accumulator, AggregateExpr, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, @@ -39,19 +37,17 @@ use crate::scalar::ScalarValue; use arrow::{array::ArrayRef, compute, compute::cast}; use arrow::{ array::{Array, UInt32Builder}, - error::{ArrowError, Result as ArrowResult}, + error::Result as ArrowResult, }; use arrow::{ datatypes::{Field, Schema, SchemaRef}, record_batch::RecordBatch, }; use hashbrown::raw::RawTable; -use pin_project_lite::pin_project; use crate::execution::context::TaskContext; use async_trait::async_trait; -use super::common::AbortOnDropSingle; use super::expressions::PhysicalSortExpr; use super::metrics::{ self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput, @@ -356,14 +352,9 @@ Example: average * Once all N record batches arrive, `merge` is performed, which builds a RecordBatch with N rows and 2 columns. 
* Finally, `get_value` returns an array with one entry computed from the state */ -pin_project! { - struct GroupedHashAggregateStream { - schema: SchemaRef, - #[pin] - output: futures::channel::oneshot::Receiver>, - finished: bool, - drop_helper: AbortOnDropSingle<()>, - } +struct GroupedHashAggregateStream { + schema: SchemaRef, + stream: SendableRecordBatchStream, } fn group_aggregate_batch( @@ -570,12 +561,16 @@ impl GroupedHashAggregateStream { input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, ) -> Self { - let (tx, rx) = futures::channel::oneshot::channel(); + // Use the panic-propagating builder so that panics in the + // compute task are re-raised on the consumer side instead of + // being reported as a closed channel. + let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1); + let tx = builder.tx(); let schema_clone = schema.clone(); let elapsed_compute = baseline_metrics.elapsed_compute().clone(); - let join_handle = tokio::spawn(async move { + builder.spawn(async move { let result = compute_grouped_hash_aggregate( mode, schema_clone, @@ -588,14 +583,12 @@ impl GroupedHashAggregateStream { .record_output(&baseline_metrics); // failing here is OK, the receiver is gone and does not care about the result - tx.send(result).ok(); + tx.send(result).await.ok(); }); Self { schema, - output: rx, - finished: false, - drop_helper: AbortOnDropSingle::new(join_handle), + stream: builder.build(), } } } @@ -647,31 +640,10 @@ impl Stream for GroupedHashAggregateStream { type Item = ArrowResult; fn poll_next( - self: std::pin::Pin<&mut Self>, + mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - if self.finished { - return Poll::Ready(None); - } - - // is the output ready? 
- let this = self.project(); - let output_poll = this.output.poll(cx); - - match output_poll { - Poll::Ready(result) => { - *this.finished = true; - - // check for error in receiving channel and unwrap actual result - let result = match result { - Err(e) => Err(ArrowError::ExternalError(Box::new(e))), // error receiving - Ok(result) => result, - }; - - Poll::Ready(Some(result)) - } - Poll::Pending => Poll::Pending, - } + self.stream.poll_next_unpin(cx) } } @@ -748,15 +720,10 @@ fn aggregate_expressions( } } -pin_project! { - /// stream struct for hash aggregation - pub struct HashAggregateStream { - schema: SchemaRef, - #[pin] - output: futures::channel::oneshot::Receiver>, - finished: bool, - drop_helper: AbortOnDropSingle<()>, - } +/// stream struct for hash aggregation +pub struct HashAggregateStream { + schema: SchemaRef, + stream: SendableRecordBatchStream, } /// Special case aggregate with no groups @@ -799,11 +766,15 @@ impl HashAggregateStream { input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, ) -> Self { - let (tx, rx) = futures::channel::oneshot::channel(); + // Use the panic-propagating builder so that panics in the + // compute task are re-raised on the consumer side instead of + // being reported as a closed channel. 
+ let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1); + let tx = builder.tx(); let schema_clone = schema.clone(); let elapsed_compute = baseline_metrics.elapsed_compute().clone(); - let join_handle = tokio::spawn(async move { + builder.spawn(async move { let result = compute_hash_aggregate( mode, schema_clone, @@ -815,14 +786,12 @@ impl HashAggregateStream { .record_output(&baseline_metrics); // failing here is OK, the receiver is gone and does not care about the result - tx.send(result).ok(); + tx.send(result).await.ok(); }); Self { schema, - output: rx, - finished: false, - drop_helper: AbortOnDropSingle::new(join_handle), + stream: builder.build(), } } } @@ -863,31 +832,10 @@ impl Stream for HashAggregateStream { type Item = ArrowResult; fn poll_next( - self: std::pin::Pin<&mut Self>, + mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - if self.finished { - return Poll::Ready(None); - } - - // is the output ready? - let this = self.project(); - let output_poll = this.output.poll(cx); - - match output_poll { - Poll::Ready(result) => { - *this.finished = true; - - // check for error in receiving channel and unwrap actual result - let result = match result { - Err(e) => Err(ArrowError::ExternalError(Box::new(e))), // error receiving - Ok(result) => result, - }; - - Poll::Ready(Some(result)) - } - Poll::Pending => Poll::Pending, - } + self.stream.poll_next_unpin(cx) } } diff --git a/datafusion/core/src/physical_plan/stream.rs b/datafusion/core/src/physical_plan/stream.rs index 67b7090406901..78c28a8a97037 100644 --- a/datafusion/core/src/physical_plan/stream.rs +++ b/datafusion/core/src/physical_plan/stream.rs @@ -17,43 +17,354 @@ //! 
Stream wrappers for physical operators +use std::cell::Cell; +use std::pin::Pin; +use std::sync::{Arc, Once}; +use std::task::{Context, Poll}; + +use crate::error::DataFusionError; +use crate::execution::context::TaskContext; use arrow::{ - datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch, + datatypes::SchemaRef, + error::{ArrowError, Result as ArrowResult}, + record_batch::RecordBatch, }; -use futures::{Stream, StreamExt}; -use tokio::task::JoinHandle; +use futures::stream::BoxStream; +use futures::{Future, Stream, StreamExt}; +use log::debug; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::task::{JoinHandle, JoinSet}; use tokio_stream::wrappers::ReceiverStream; use super::common::AbortOnDropSingle; -use super::{RecordBatchStream, SendableRecordBatchStream}; +use super::displayable; +use super::metrics::BaselineMetrics; +use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream}; -/// Adapter for a tokio [`ReceiverStream`] that implements the -/// [`SendableRecordBatchStream`] -/// interface -pub struct RecordBatchReceiverStream { +thread_local! { + /// When set, the installed panic hook suppresses the default + /// stderr output for panics that happen on this thread. The flag + /// is scoped to the lifetime of a single poll/call of a task + /// spawned through the builder: panics in that code are already + /// captured by tokio and re-raised on the consumer thread, so + /// printing them a second time to stderr is only noise. + static SUPPRESS_PANIC_OUTPUT: Cell = const { Cell::new(false) }; +} + +static PANIC_HOOK_ONCE: Once = Once::new(); + +/// Install a process-wide panic hook that silences the default stderr +/// output only while a DataFusion-spawned task is running on the +/// current thread. Panics outside those tasks keep the previous hook. 
+fn install_panic_hook_once() { + PANIC_HOOK_ONCE.call_once(|| { + let prev = std::panic::take_hook(); + std::panic::set_hook(Box::new(move |info| { + if !SUPPRESS_PANIC_OUTPUT.with(|c| c.get()) { + prev(info); + } + })); + }); +} + +/// Guard that sets the thread-local flag on creation and clears it on +/// drop, so the flag is restored even when the wrapped code panics. +struct SuppressPanicGuard { + prev: bool, +} + +impl SuppressPanicGuard { + fn new() -> Self { + let prev = SUPPRESS_PANIC_OUTPUT.with(|c| c.replace(true)); + Self { prev } + } +} + +impl Drop for SuppressPanicGuard { + fn drop(&mut self) { + SUPPRESS_PANIC_OUTPUT.with(|c| c.set(self.prev)); + } +} + +pin_project_lite::pin_project! { + /// Future wrapper that activates the panic-hook suppression flag + /// around every poll of the inner future. Re-applying on every + /// poll is required because tokio may migrate the task to a + /// different worker thread across `.await` points. + struct SuppressPanicOutput { + #[pin] + inner: F, + } +} + +impl Future for SuppressPanicOutput { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let _guard = SuppressPanicGuard::new(); + self.project().inner.poll(cx) + } +} + +/// Builder for [`RecordBatchReceiverStream`] that propagates errors +/// and panics from spawned tasks to the consumer. +/// +/// [`RecordBatchReceiverStream`] is used to spawn one or more tasks +/// that produce `RecordBatch`es and send them to a single +/// `Receiver` which can improve parallelism. Previously, panics in +/// those tasks were silently dropped; this builder uses a +/// [`JoinSet`] so that panics are re-raised on the consumer thread +/// and outstanding tasks are aborted when the stream is dropped. 
+pub struct RecordBatchReceiverStreamBuilder { + tx: Sender>, + rx: Receiver>, schema: SchemaRef, + join_set: JoinSet<()>, +} + +impl RecordBatchReceiverStreamBuilder { + /// Create a new builder with an internal buffer of `capacity` batches. + pub fn new(schema: SchemaRef, capacity: usize) -> Self { + let (tx, rx) = tokio::sync::mpsc::channel(capacity); + + Self { + tx, + rx, + schema, + join_set: JoinSet::new(), + } + } + + /// Get a handle for sending [`RecordBatch`]es to the output. + pub fn tx(&self) -> Sender> { + self.tx.clone() + } + + /// Spawn a task that will be aborted if this builder (or the + /// stream built from it) is dropped. + /// + /// Often used to spawn tasks that write to the sender returned + /// by [`Self::tx`]. Panics inside the task are captured and + /// re-raised on the consumer thread; they are also silenced + /// from the default panic hook so the panic is not printed + /// twice to stderr. + pub fn spawn(&mut self, task: F) + where + F: Future + Send + 'static, + { + install_panic_hook_once(); + self.join_set.spawn(SuppressPanicOutput { inner: task }); + } + + /// Spawn a blocking task that will be aborted if this builder + /// (or the stream built from it) is dropped. + /// + /// Often used to spawn tasks that write to the sender returned + /// by [`Self::tx`]. Panics are propagated to the consumer and + /// silenced from the default panic hook. + pub fn spawn_blocking(&mut self, f: F) + where + F: FnOnce() + Send + 'static, + { + install_panic_hook_once(); + // `JoinSet::spawn_blocking` was added in tokio 1.24; this + // crate is on 1.22, so we spawn the blocking job via the + // free function and forward its JoinError into the JoinSet + // by re-raising any panic inside an adopted async task. 
+ let handle = tokio::task::spawn_blocking(move || { + let _guard = SuppressPanicGuard::new(); + f(); + }); + self.join_set.spawn(async move { + match handle.await { + Ok(()) => {} + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } + // Cancellation of the blocking task is only + // possible via abort of this adopter task, + // which implies the stream is being dropped. + } + } + }); + } - inner: ReceiverStream>, + /// Run a partition of the given `input` [`ExecutionPlan`] on the + /// tokio threadpool and forward its output batches to this + /// builder's channel. + /// + /// If the input partition produces an error, the error is + /// forwarded and no further batches are sent from that task. + pub(crate) fn run_input( + &mut self, + input: Arc, + partition: usize, + context: Arc, + ) { + let output = self.tx(); - #[allow(dead_code)] - drop_helper: AbortOnDropSingle<()>, + self.spawn(async move { + let mut stream = match input.execute(partition, context).await { + Err(e) => { + // If send fails, the plan is being torn down and + // there is no place to report the error. + let arrow_error = ArrowError::ExternalError(Box::new(e)); + output.send(Err(arrow_error)).await.ok(); + debug!( + "Stopping execution: error executing input: {}", + displayable(input.as_ref()).indent() + ); + return; + } + Ok(stream) => stream, + }; + + while let Some(item) = stream.next().await { + let is_err = item.is_err(); + + // If send fails, the consumer is gone; no reason to + // keep producing. + if output.send(item).await.is_err() { + debug!( + "Stopping execution: output is gone, plan cancelling: {}", + displayable(input.as_ref()).indent() + ); + return; + } + + // Stop after the first error so we don't drive every + // input to completion once one has failed. 
+ if is_err { + debug!( + "Stopping execution: plan returned error: {}", + displayable(input.as_ref()).indent() + ); + return; + } + } + }); + } + + /// Create a stream of all `RecordBatch`es written to the channel, + /// propagating any panics from spawned tasks. + pub fn build(self) -> SendableRecordBatchStream { + let Self { + tx, + rx, + schema, + mut join_set, + } = self; + + // drop our own sender so the receiver closes once all + // producer tasks have completed + drop(tx); + + // future that joins every spawned task and re-raises panics + let check = async move { + while let Some(result) = join_set.join_next().await { + match result { + Ok(()) => continue, + Err(e) => { + if e.is_panic() { + // resume on the consumer thread. Keep the + // suppression flag on for the re-raise + // too so the default panic hook stays + // silent when the unwind fires here. + install_panic_hook_once(); + let _guard = SuppressPanicGuard::new(); + std::panic::resume_unwind(e.into_panic()); + } else { + // Only reachable if the task was cancelled, + // which only happens when the JoinSet is + // dropped, i.e. when this stream has been + // dropped, so this code will not run. + return Some(Err(ArrowError::ExternalError(Box::new( + DataFusionError::Internal(format!( + "Non Panic Task error: {}", + e + )), + )))); + } + } + } + } + None + }; + + let check_stream = + futures::stream::once(check).filter_map(|item| async move { item }); + + // Interleave batches from the channel with the join-set + // check so whichever is ready first produces output. + let inner = + futures::stream::select(ReceiverStream::new(rx), check_stream).boxed(); + + Box::pin(RecordBatchReceiverStream { schema, inner }) + } +} + +/// Adapter for a tokio [`ReceiverStream`] that implements the +/// [`SendableRecordBatchStream`] interface and propagates panics and +/// errors from the tasks writing to the underlying channel. Use +/// [`Self::builder`] to construct one. 
+pub struct RecordBatchReceiverStream { + schema: SchemaRef, + inner: BoxStream<'static, ArrowResult>, } impl RecordBatchReceiverStream { + /// Create a builder with an internal buffer of `capacity` batches. + pub fn builder( + schema: SchemaRef, + capacity: usize, + ) -> RecordBatchReceiverStreamBuilder { + RecordBatchReceiverStreamBuilder::new(schema, capacity) + } + /// Construct a new [`RecordBatchReceiverStream`] which will send - /// batches of the specfied schema from `inner` + /// batches of the specified schema from `rx`, while monitoring + /// `join_handle` for panics. + /// + /// The task is aborted if the returned stream is dropped. pub fn create( schema: &SchemaRef, rx: tokio::sync::mpsc::Receiver>, join_handle: JoinHandle<()>, ) -> SendableRecordBatchStream { let schema = schema.clone(); - let inner = ReceiverStream::new(rx); - Box::pin(Self { - schema, - inner, - drop_helper: AbortOnDropSingle::new(join_handle), - }) + + // Hold the handle in an AbortOnDropSingle so dropping the + // resulting stream aborts the background task. + let abort_helper = AbortOnDropSingle::new(join_handle); + + let check = async move { + match abort_helper.await { + Ok(()) => None, + Err(e) => { + if e.is_panic() { + install_panic_hook_once(); + let _guard = SuppressPanicGuard::new(); + std::panic::resume_unwind(e.into_panic()); + } else { + Some(Err(ArrowError::ExternalError(Box::new( + DataFusionError::Internal(format!( + "Non Panic Task error: {}", + e + )), + )))) + } + } + } + }; + + let check_stream = + futures::stream::once(check).filter_map(|item| async move { item }); + + let inner = + futures::stream::select(ReceiverStream::new(rx), check_stream).boxed(); + + Box::pin(Self { schema, inner }) } } @@ -73,3 +384,215 @@ impl RecordBatchStream for RecordBatchReceiverStream { self.schema.clone() } } + +/// Combines a [`Stream`] with a [`SchemaRef`] implementing +/// [`SendableRecordBatchStream`] for the combination. 
+pub struct RecordBatchStreamAdapter { + schema: SchemaRef, + stream: S, +} + +impl RecordBatchStreamAdapter { + /// Creates a new [`RecordBatchStreamAdapter`] from the provided schema and stream + pub fn new(schema: SchemaRef, stream: S) -> Self { + Self { schema, stream } + } +} + +impl std::fmt::Debug for RecordBatchStreamAdapter { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RecordBatchStreamAdapter") + .field("schema", &self.schema) + .finish() + } +} + +impl Stream for RecordBatchStreamAdapter +where + S: Stream> + Unpin, +{ + type Item = ArrowResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.stream.poll_next_unpin(cx) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} + +impl RecordBatchStream for RecordBatchStreamAdapter +where + S: Stream> + Unpin, +{ + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +/// Stream wrapper that records [`BaselineMetrics`] for a particular +/// [`SendableRecordBatchStream`] (typically a partition). 
+pub(crate) struct ObservedStream { + inner: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, +} + +impl ObservedStream { + pub fn new( + inner: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, + ) -> Self { + Self { + inner, + baseline_metrics, + } + } +} + +impl RecordBatchStream for ObservedStream { + fn schema(&self) -> SchemaRef { + self.inner.schema() + } +} + +impl Stream for ObservedStream { + type Item = ArrowResult; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let poll = self.inner.poll_next_unpin(cx); + self.baseline_metrics.record_poll(poll) + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + + use crate::{ + prelude::SessionContext, + test::exec::{ + assert_strong_count_converges_to_zero, BlockingExec, MockExec, PanicExec, + }, + }; + + fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])) + } + + #[tokio::test] + #[should_panic(expected = "PanickingStream did panic")] + async fn record_batch_receiver_stream_propagates_panics() { + let schema = schema(); + + let num_partitions = 10; + let input = PanicExec::new(schema.clone(), num_partitions); + consume(input, 10).await + } + + #[tokio::test] + #[should_panic(expected = "PanickingStream did panic: 1")] + async fn record_batch_receiver_stream_propagates_panics_early_shutdown() { + let schema = schema(); + + // two partitions; the second one panics before the first + let num_partitions = 2; + let input = PanicExec::new(schema.clone(), num_partitions) + .with_partition_panic(0, 10) + .with_partition_panic(1, 3); + + // The stream should stop after the first panic; since the + // two partitions interleave (0,1,0,1,0,panic) it should not + // exceed 5 batches prior to the panic. 
+ let max_batches = 5; + consume(input, max_batches).await + } + + #[tokio::test] + async fn record_batch_receiver_stream_drop_cancel() { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let schema = schema(); + + let input = BlockingExec::new(schema.clone(), 1); + let refs = input.refs(); + + let mut builder = RecordBatchReceiverStream::builder(schema, 2); + builder.run_input(Arc::new(input), 0, task_ctx.clone()); + let stream = builder.build(); + + // input should still be present + assert!(std::sync::Weak::strong_count(&refs) > 0); + + // drop the stream, ensure the refs go to zero + drop(stream); + assert_strong_count_converges_to_zero(refs).await; + } + + /// Ensure that when an error is received from one stream the + /// [`RecordBatchReceiverStream`] stops early and does not drive + /// other streams to completion. + #[tokio::test] + async fn record_batch_receiver_stream_error_does_not_drive_completion() { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let schema = schema(); + + let error_stream = MockExec::new( + vec![ + Err(ArrowError::ComputeError("Test1".to_string())), + Err(ArrowError::ComputeError("Test2".to_string())), + ], + schema.clone(), + ) + .with_use_task(false); + + let mut builder = RecordBatchReceiverStream::builder(schema, 2); + builder.run_input(Arc::new(error_stream), 0, task_ctx.clone()); + let mut stream = builder.build(); + + // first result should be the first error + let first_batch = stream.next().await.unwrap(); + let first_err = first_batch.unwrap_err(); + assert_eq!(first_err.to_string(), "Compute error: Test1"); + + // no more batches should be produced (second error must not surface) + assert!(stream.next().await.is_none()); + } + + /// Collect every partition of `input` into a + /// [`RecordBatchReceiverStream`] and drive it to completion, + /// panicking if more than `max_batches` are produced. 
+ async fn consume(input: PanicExec, max_batches: usize) { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + + let input = Arc::new(input); + let num_partitions = input.output_partitioning().partition_count(); + + let mut builder = + RecordBatchReceiverStream::builder(input.schema(), num_partitions); + for partition in 0..num_partitions { + builder.run_input(input.clone(), partition, task_ctx.clone()); + } + let mut stream = builder.build(); + + let mut num_batches = 0; + while let Some(next) = stream.next().await { + next.unwrap(); + num_batches += 1; + assert!( + num_batches < max_batches, + "Got the limit of {} batches before seeing panic", + num_batches + ); + } + } +} diff --git a/datafusion/core/src/physical_plan/union.rs b/datafusion/core/src/physical_plan/union.rs index 21d0d2c167122..00cf1de0b9c36 100644 --- a/datafusion/core/src/physical_plan/union.rs +++ b/datafusion/core/src/physical_plan/union.rs @@ -23,20 +23,17 @@ use std::{any::Any, sync::Arc}; -use arrow::{ - datatypes::{Field, Schema, SchemaRef}, - record_batch::RecordBatch, -}; -use futures::StreamExt; +use arrow::datatypes::{Field, Schema, SchemaRef}; use itertools::Itertools; use super::{ expressions::PhysicalSortExpr, metrics::{ExecutionPlanMetricsSet, MetricsSet}, - ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, + ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; use crate::execution::context::TaskContext; +use crate::physical_plan::stream::ObservedStream; use crate::{ error::Result, physical_plan::{expressions, metrics::BaselineMetrics}, @@ -191,40 +188,6 @@ impl ExecutionPlan for UnionExec { } } -/// Stream wrapper that records `BaselineMetrics` for a particular -/// partition -struct ObservedStream { - inner: SendableRecordBatchStream, - baseline_metrics: BaselineMetrics, -} - -impl ObservedStream { - fn new(inner: SendableRecordBatchStream, baseline_metrics: 
BaselineMetrics) -> Self { - Self { - inner, - baseline_metrics, - } - } -} - -impl RecordBatchStream for ObservedStream { - fn schema(&self) -> arrow::datatypes::SchemaRef { - self.inner.schema() - } -} - -impl futures::Stream for ObservedStream { - type Item = arrow::error::Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - let poll = self.inner.poll_next_unpin(cx); - self.baseline_metrics.record_poll(poll) - } -} - fn col_stats_union( mut left: ColumnStatistics, right: ColumnStatistics, diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index 553b6f26b8b0d..1095ac2559078 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -19,11 +19,11 @@ use crate::error::Result; use crate::execution::context::TaskContext; -use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, }; +use crate::physical_plan::stream::RecordBatchReceiverStream; use crate::physical_plan::{ common, ColumnStatistics, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, @@ -31,13 +31,11 @@ use crate::physical_plan::{ use arrow::{ array::ArrayRef, datatypes::{Schema, SchemaRef}, - error::{ArrowError, Result as ArrowResult}, + error::Result as ArrowResult, record_batch::RecordBatch, }; use async_trait::async_trait; -use futures::stream::Stream; -use futures::FutureExt; -use pin_project_lite::pin_project; +use futures::stream::{Stream, StreamExt}; use std::any::Any; use std::pin::Pin; use std::sync::Arc; @@ -232,16 +230,11 @@ fn compute_window_aggregates( .collect() } -pin_project! 
{ - /// stream for window aggregation plan - pub struct WindowAggStream { - schema: SchemaRef, - drop_helper: AbortOnDropSingle<()>, - #[pin] - output: futures::channel::oneshot::Receiver>, - finished: bool, - baseline_metrics: BaselineMetrics, - } +/// stream for window aggregation plan +pub struct WindowAggStream { + schema: SchemaRef, + stream: SendableRecordBatchStream, + baseline_metrics: BaselineMetrics, } impl WindowAggStream { @@ -252,24 +245,30 @@ impl WindowAggStream { input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, ) -> Self { - let (tx, rx) = futures::channel::oneshot::channel(); + // Use the panic-propagating builder so that a panic in the + // compute task is re-raised on the consumer side rather than + // being reported as a closed channel. + let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1); + let tx = builder.tx(); + let schema_clone = schema.clone(); let elapsed_compute = baseline_metrics.elapsed_compute().clone(); - let join_handle = tokio::spawn(async move { - let schema = schema_clone.clone(); - let result = - WindowAggStream::process(input, window_expr, schema, elapsed_compute) - .await; + builder.spawn(async move { + let result = WindowAggStream::process( + input, + window_expr, + schema_clone, + elapsed_compute, + ) + .await; // failing here is OK, the receiver is gone and does not care about the result - tx.send(result).ok(); + tx.send(result).await.ok(); }); Self { schema, - drop_helper: AbortOnDropSingle::new(join_handle), - output: rx, - finished: false, + stream: builder.build(), baseline_metrics, } } @@ -308,39 +307,11 @@ impl Stream for WindowAggStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let poll = self.poll_next_inner(cx); + let poll = self.stream.poll_next_unpin(cx); self.baseline_metrics.record_poll(poll) } } -impl WindowAggStream { - #[inline] - fn poll_next_inner( - self: &mut Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - if self.finished { - 
return Poll::Ready(None); - } - - // is the output ready? - let output_poll = self.output.poll_unpin(cx); - - match output_poll { - Poll::Ready(result) => { - self.finished = true; - // check for error in receiving channel and unwrap actual result - let result = match result { - Err(e) => Some(Err(ArrowError::ExternalError(Box::new(e)))), // error receiving - Ok(result) => Some(result), - }; - Poll::Ready(result) - } - Poll::Pending => Poll::Pending, - } - } -} - impl RecordBatchStream for WindowAggStream { /// Get the schema fn schema(&self) -> SchemaRef { diff --git a/datafusion/core/src/test/exec.rs b/datafusion/core/src/test/exec.rs index 7e0cbe35f824c..a514a4d3d8d02 100644 --- a/datafusion/core/src/test/exec.rs +++ b/datafusion/core/src/test/exec.rs @@ -41,7 +41,7 @@ use crate::physical_plan::{ }; use crate::{ error::{DataFusionError, Result}, - physical_plan::stream::RecordBatchReceiverStream, + physical_plan::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}, }; /// Index into the data that has been returned so far @@ -119,22 +119,41 @@ impl RecordBatchStream for TestStream { } } -/// A Mock ExecutionPlan that can be used for writing tests of other ExecutionPlans -/// +/// A Mock ExecutionPlan that can be used for writing tests of other +/// ExecutionPlans #[derive(Debug)] pub struct MockExec { /// the results to send back data: Vec>, schema: SchemaRef, + /// if true (the default), sends data using a separate task to + /// ensure the batches are not available without this stream + /// yielding first + use_task: bool, } impl MockExec { - /// Create a new exec with a single partition that returns the - /// record batches in this Exec. Note the batches are not produced - /// immediately (the caller has to actually yield and another task - /// must run) to ensure any poll loops are correct. + /// Create a new `MockExec` with a single partition that returns + /// the specified `Result`s. 
+ /// + /// By default, the batches are not produced immediately (the + /// caller has to actually yield and another task must run) to + /// ensure any poll loops are correct. This behavior can be + /// changed with `with_use_task`. pub fn new(data: Vec>, schema: SchemaRef) -> Self { - Self { data, schema } + Self { + data, + schema, + use_task: true, + } + } + + /// If `use_task` is true (the default) then the batches are sent + /// back using a separate task to ensure the underlying stream is + /// not immediately ready. + pub fn with_use_task(mut self, use_task: bool) -> Self { + self.use_task = use_task; + self } } @@ -185,26 +204,31 @@ impl ExecutionPlan for MockExec { }) .collect(); - let (tx, rx) = tokio::sync::mpsc::channel(2); - - // task simply sends data in order but in a separate - // thread (to ensure the batches are not available without the - // DelayedStream yielding). - let join_handle = tokio::task::spawn(async move { - for batch in data { - println!("Sending batch via delayed stream"); - if let Err(e) = tx.send(batch).await { - println!("ERROR batch via delayed stream: {}", e); + if self.use_task { + let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2); + // task simply sends data in order but in a separate + // task (to ensure the batches are not available without + // the stream yielding). 
+ let tx = builder.tx(); + builder.spawn(async move { + for batch in data { + println!("Sending batch via delayed stream"); + if let Err(e) = tx.send(batch).await { + println!("ERROR batch via delayed stream: {}", e); + } } - } - }); - - // returned stream simply reads off the rx stream - Ok(RecordBatchReceiverStream::create( - &self.schema, - rx, - join_handle, - )) + }); + // returned stream simply reads off the rx stream + Ok(builder.build()) + } else { + // return a synchronous stream that yields pre-built + // results without needing a task + let stream = futures::stream::iter(data); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + stream, + ))) + } } fn fmt_as( @@ -314,12 +338,13 @@ impl ExecutionPlan for BarrierExec { ) -> Result { assert!(partition < self.data.len()); - let (tx, rx) = tokio::sync::mpsc::channel(2); + let mut builder = RecordBatchReceiverStream::builder(self.schema(), 2); // task simply sends data in order after barrier is reached let data = self.data[partition].clone(); let b = self.barrier.clone(); - let join_handle = tokio::task::spawn(async move { + let tx = builder.tx(); + builder.spawn(async move { println!("Partition {} waiting on barrier", partition); b.wait().await; for batch in data { @@ -331,11 +356,7 @@ impl ExecutionPlan for BarrierExec { }); // returned stream simply reads off the rx stream - Ok(RecordBatchReceiverStream::create( - &self.schema, - rx, - join_handle, - )) + Ok(builder.build()) } fn fmt_as( @@ -655,3 +676,139 @@ pub async fn assert_strong_count_converges_to_zero(refs: Weak) { .await .unwrap(); } + +/// Execution plan that emits streams that panic. +/// +/// This is useful to test panic handling of certain execution plans. +#[derive(Debug)] +pub struct PanicExec { + /// Schema that is mocked by this plan. + schema: SchemaRef, + + /// Number of output partitions. Each generates a panicking stream. 
+    batches_until_panics: Vec<usize>,
+}
+
+impl PanicExec {
+    /// Create new [`PanicExec`] with a given schema and number of
+    /// partitions, which will each panic immediately.
+    pub fn new(schema: SchemaRef, n_partitions: usize) -> Self {
+        Self {
+            schema,
+            batches_until_panics: vec![0; n_partitions],
+        }
+    }
+
+    /// Set the number of batches prior to panic for a partition.
+    pub fn with_partition_panic(mut self, partition: usize, count: usize) -> Self {
+        self.batches_until_panics[partition] = count;
+        self
+    }
+}
+
+#[async_trait]
+impl ExecutionPlan for PanicExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        // this is a leaf node and has no children
+        vec![]
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        Partitioning::UnknownPartitioning(self.batches_until_panics.len())
+    }
+
+    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+        None
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Err(DataFusionError::Internal(format!(
+            "Children cannot be replaced in {:?}",
+            self
+        )))
+    }
+
+    async fn execute(
+        &self,
+        partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        Ok(Box::pin(PanickingStream {
+            partition,
+            batches_until_panic: self.batches_until_panics[partition],
+            schema: Arc::clone(&self.schema),
+            ready: false,
+        }))
+    }
+
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                write!(f, "PanicExec")
+            }
+        }
+    }
+
+    fn statistics(&self) -> Statistics {
+        unimplemented!()
+    }
+}
+
+/// A [`RecordBatchStream`] that yields an empty batch `batches_until_panic`
+/// times and then panics.
+#[derive(Debug)]
+struct PanickingStream {
+    /// Which partition was this stream created for.
+    partition: usize,
+
+    /// Batches remaining before panic.
+    batches_until_panic: usize,
+
+    /// Schema mocked by this stream.
+    schema: SchemaRef,
+
+    /// Toggle to ensure other streams can be polled between outputs.
+    ready: bool,
+}
+
+impl Stream for PanickingStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        if self.batches_until_panic == 0 {
+            panic!("PanickingStream did panic: {}", self.partition)
+        }
+        if !self.ready {
+            self.ready = true;
+            cx.waker().wake_by_ref();
+            return Poll::Pending;
+        }
+        self.batches_until_panic -= 1;
+        self.ready = false;
+        Poll::Ready(Some(Ok(RecordBatch::new_empty(Arc::clone(&self.schema)))))
+    }
+}
+
+impl RecordBatchStream for PanickingStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}