@@ -25,10 +25,10 @@ use std::vec;
2525use ahash:: RandomState ;
2626use futures:: {
2727 stream:: { Stream , StreamExt } ,
28- Future ,
28+ Future , FutureExt ,
2929} ;
3030
31- use crate :: error:: Result ;
31+ use crate :: error:: { DataFusionError , Result } ;
3232use crate :: physical_plan:: hash_utils:: create_hashes;
3333use crate :: physical_plan:: {
3434 Accumulator , AggregateExpr , DisplayFormatType , Distribution , ExecutionPlan ,
@@ -576,16 +576,32 @@ impl GroupedHashAggregateStream {
576576 let elapsed_compute = baseline_metrics. elapsed_compute ( ) . clone ( ) ;
577577
578578 let join_handle = tokio:: spawn ( async move {
579- let result = compute_grouped_hash_aggregate (
580- mode,
581- schema_clone,
582- group_expr,
583- aggr_expr,
584- input,
585- elapsed_compute,
586- )
579+ let result = std:: panic:: AssertUnwindSafe ( async move {
580+ compute_grouped_hash_aggregate (
581+ mode,
582+ schema_clone,
583+ group_expr,
584+ aggr_expr,
585+ input,
586+ elapsed_compute,
587+ )
588+ . await
589+ . record_output ( & baseline_metrics)
590+ } )
591+ . catch_unwind ( )
587592 . await
588- . record_output ( & baseline_metrics) ;
593+ . unwrap_or_else ( |panic_payload| {
594+ let msg = if let Some ( s) = panic_payload. downcast_ref :: < & str > ( ) {
595+ s
596+ } else if let Some ( s) = panic_payload. downcast_ref :: < String > ( ) {
597+ s. as_str ( )
598+ } else {
599+ "unknown panic"
600+ } ;
601+ Err ( ArrowError :: ExternalError ( Box :: new (
602+ DataFusionError :: Execution ( format ! ( "Panic: {}" , msg) ) ,
603+ ) ) )
604+ } ) ;
589605
590606 // failing here is OK, the receiver is gone and does not care about the result
591607 tx. send ( result) . ok ( ) ;
@@ -804,15 +820,31 @@ impl HashAggregateStream {
804820 let schema_clone = schema. clone ( ) ;
805821 let elapsed_compute = baseline_metrics. elapsed_compute ( ) . clone ( ) ;
806822 let join_handle = tokio:: spawn ( async move {
807- let result = compute_hash_aggregate (
808- mode,
809- schema_clone,
810- aggr_expr,
811- input,
812- elapsed_compute,
813- )
823+ let result = std:: panic:: AssertUnwindSafe ( async move {
824+ compute_hash_aggregate (
825+ mode,
826+ schema_clone,
827+ aggr_expr,
828+ input,
829+ elapsed_compute,
830+ )
831+ . await
832+ . record_output ( & baseline_metrics)
833+ } )
834+ . catch_unwind ( )
814835 . await
815- . record_output ( & baseline_metrics) ;
836+ . unwrap_or_else ( |panic_payload| {
837+ let msg = if let Some ( s) = panic_payload. downcast_ref :: < & str > ( ) {
838+ s
839+ } else if let Some ( s) = panic_payload. downcast_ref :: < String > ( ) {
840+ s. as_str ( )
841+ } else {
842+ "unknown panic"
843+ } ;
844+ Err ( ArrowError :: ExternalError ( Box :: new (
845+ DataFusionError :: Execution ( format ! ( "Panic: {}" , msg) ) ,
846+ ) ) )
847+ } ) ;
816848
817849 // failing here is OK, the receiver is gone and does not care about the result
818850 tx. send ( result) . ok ( ) ;
0 commit comments