@@ -89,6 +89,13 @@ object CometExecRule {
8989
9090 val allExecs : Map [Class [_ <: SparkPlan ], CometOperatorSerde [_]] = nativeExecs ++ sinks
9191
92+ /**
93+ * Tag set on a `ShuffleExchangeExec` that should be left as a plain Spark shuffle rather than
94+ * wrapped in `CometShuffleExchangeExec`. See `tagRedundantColumnarShuffle`.
95+ */
96+ val SKIP_COMET_SHUFFLE_TAG : org.apache.spark.sql.catalyst.trees.TreeNodeTag [Unit ] =
97+ org.apache.spark.sql.catalyst.trees.TreeNodeTag [Unit ]("comet.skipCometShuffle")
98+
9299}
93100
94101/**
@@ -108,9 +115,13 @@ case class CometExecRule(session: SparkSession)
108115 * row->arrow->shuffle->arrow->row conversion overhead with no Comet consumer on either side.
109116 * See https://github.com/apache/datafusion-comet/issues/4004.
110117 *
111- * The match is intentionally narrow (both sides must be row-based aggregates) so this does not
112- * interfere with non-relational plan shapes such as object-mode Dataset plans where the shuffle
113- * sits between encoder/serializer nodes.
118+ * The match is intentionally narrow (both sides must be row-based aggregates that remained JVM
119+ * after the main transform pass). Running the revert post-transform means we only fire when the
120+ * main conversion already decided to keep both aggregates JVM - we never create the dangerous
121+ * mixed mode where a Comet partial feeds a JVM final (see issue #1389).
122+ *
123+ * Also tag the reverted shuffle so AQE stage-isolated re-planning does not convert it back to a
124+ * Comet shuffle when the outer aggregate context is no longer visible.
114125 */
115126 private def revertRedundantColumnarShuffle (plan : SparkPlan ): SparkPlan = {
116127 def isAggregate (p : SparkPlan ): Boolean =
@@ -127,26 +138,35 @@ case class CometExecRule(session: SparkSession)
127138 val newChildren = op.children.map {
128139 case s : CometShuffleExchangeExec
129140 if s.shuffleType == CometColumnarShuffle && isAggregate(s.child) =>
130- s.originalPlan.withNewChildren(Seq (s.child)).asInstanceOf [SparkPlan ]
141+ val reverted =
142+ s.originalPlan.withNewChildren(Seq (s.child)).asInstanceOf [ShuffleExchangeExec ]
143+ reverted.setTagValue(CometExecRule .SKIP_COMET_SHUFFLE_TAG , ())
144+ reverted
131145 case other => other
132146 }
133147 op.withNewChildren(newChildren)
134148 }
135149 }
136150
151+ private def shouldSkipCometShuffle (s : ShuffleExchangeExec ): Boolean =
152+ s.getTagValue(CometExecRule .SKIP_COMET_SHUFFLE_TAG ).isDefined
153+
137154 private def applyCometShuffle (plan : SparkPlan ): SparkPlan = {
138- plan.transformUp { case s : ShuffleExchangeExec =>
139- CometShuffleExchangeExec .shuffleSupported(s) match {
140- case Some (CometNativeShuffle ) =>
141- // Switch to use Decimal128 regardless of precision, since Arrow native execution
142- // doesn't support Decimal32 and Decimal64 yet.
143- conf.setConfString(CometConf .COMET_USE_DECIMAL_128 .key, "true")
144- CometShuffleExchangeExec (s, shuffleType = CometNativeShuffle )
145- case Some (CometColumnarShuffle ) =>
146- CometShuffleExchangeExec (s, shuffleType = CometColumnarShuffle )
147- case None =>
148- s
149- }
155+ plan.transformUp {
156+ case s : ShuffleExchangeExec if shouldSkipCometShuffle(s) =>
157+ s
158+ case s : ShuffleExchangeExec =>
159+ CometShuffleExchangeExec .shuffleSupported(s) match {
160+ case Some (CometNativeShuffle ) =>
161+ // Switch to use Decimal128 regardless of precision, since Arrow native execution
162+ // doesn't support Decimal32 and Decimal64 yet.
163+ conf.setConfString(CometConf .COMET_USE_DECIMAL_128 .key, "true")
164+ CometShuffleExchangeExec (s, shuffleType = CometNativeShuffle )
165+ case Some (CometColumnarShuffle ) =>
166+ CometShuffleExchangeExec (s, shuffleType = CometColumnarShuffle )
167+ case None =>
168+ s
169+ }
150170 }
151171 }
152172
@@ -295,6 +315,9 @@ case class CometExecRule(session: SparkSession)
295315 case s @ ShuffleQueryStageExec (_, ReusedExchangeExec (_, _ : CometShuffleExchangeExec ), _) =>
296316 convertToComet(s, CometExchangeSink ).getOrElse(s)
297317
318+ case s : ShuffleExchangeExec if shouldSkipCometShuffle(s) =>
319+ s
320+
298321 case s : ShuffleExchangeExec =>
299322 convertToComet(s, CometShuffleExchangeExec ).getOrElse(s)
300323
@@ -498,10 +521,9 @@ case class CometExecRule(session: SparkSession)
498521 case CometScanWrapper (_, s) => s
499522 }
500523
501- // Revert CometColumnarShuffle to Spark's ShuffleExchangeExec when both the parent and
502- // the child are non-Comet (JVM) operators. In that case the Comet shuffle only adds
503- // row->arrow->arrow->row conversion overhead with no Comet operator on either side to
504- // benefit from columnar output. See https://github.com/apache/datafusion-comet/issues/4004.
524+ // Revert CometColumnarShuffle to Spark's ShuffleExchangeExec when sandwiched between two
525+ // non-Comet HashAggregate/ObjectHashAggregate operators that remained JVM after the main
526+ // transform pass. See https://github.com/apache/datafusion-comet/issues/4004.
505527 newPlan = revertRedundantColumnarShuffle(newPlan)
506528
507529 // Set up logical links
0 commit comments