@@ -89,6 +89,13 @@ object CometExecRule {
8989
9090 val allExecs : Map [Class [_ <: SparkPlan ], CometOperatorSerde [_]] = nativeExecs ++ sinks
9191
92+ /**
93+ * Tag set on a `ShuffleExchangeExec` that should be left as a plain Spark shuffle rather than
94+ * wrapped in `CometShuffleExchangeExec`. See `tagRedundantColumnarShuffle`.
95+ */
96+ val SKIP_COMET_SHUFFLE_TAG : org.apache.spark.sql.catalyst.trees.TreeNodeTag [Unit ] =
97+ org.apache.spark.sql.catalyst.trees.TreeNodeTag [Unit ]("comet.skipCometShuffle" )
98+
9299}
93100
94101/**
@@ -100,19 +107,78 @@ case class CometExecRule(session: SparkSession)
100107
101108 private lazy val showTransformations = CometConf .COMET_EXPLAIN_TRANSFORMATIONS .get()
102109
110+ /**
111+ * Revert any `CometShuffleExchangeExec` with `CometColumnarShuffle` whose parent and child are
112+ * both non-Comet `HashAggregateExec` / `ObjectHashAggregateExec` operators back to the original
113+ * Spark `ShuffleExchangeExec`. This is the partial-final-aggregate pattern where Comet couldn't
114+ * convert either aggregate; keeping a columnar shuffle between them only adds
115+ * row->arrow->shuffle->arrow->row conversion overhead with no Comet consumer on either side.
116+ * See https://github.com/apache/datafusion-comet/issues/4004.
117+ *
118+ * The match is intentionally narrow (both sides must be row-based aggregates that remained JVM
119+ * after the main transform pass). Running the revert post-transform means we only fire when the
120+ * main conversion already decided to keep both aggregates JVM - we never create the dangerous
121+ * mixed mode where a Comet partial feeds a JVM final (see issue #1389).
122+ *
123+ * Correctness depends on running as part of `preColumnarTransitions`: if the revert ran after
124+ * Spark inserted `ColumnarToRowExec` between the aggregate and the columnar shuffle, the
125+ * pattern would no longer match (the shuffle would be separated from the aggregate by the
126+ * transition) and the unnecessary conversion could not be eliminated.
127+ *
128+ * The reverted shuffle is tagged with `SKIP_COMET_SHUFFLE_TAG` so both the AQE
129+ * `QueryStagePrepRule` pass and the `ColumnarRule` `preColumnarTransitions` pass leave it alone
130+ * on re-entry - AQE in particular re-runs the rule on each stage in isolation, where the outer
131+ * aggregate context is no longer visible and the shuffle would otherwise be re-wrapped as a
132+ * Comet columnar shuffle.
133+ */
134+ private def revertRedundantColumnarShuffle (plan : SparkPlan ): SparkPlan = {
135+ def isAggregate (p : SparkPlan ): Boolean =
136+ p.isInstanceOf [HashAggregateExec ] || p.isInstanceOf [ObjectHashAggregateExec ]
137+
138+ def isRedundantShuffle (child : SparkPlan ): Boolean = child match {
139+ case s : CometShuffleExchangeExec =>
140+ s.shuffleType == CometColumnarShuffle && isAggregate(s.child)
141+ case _ => false
142+ }
143+
144+ plan.transform {
145+ case op if isAggregate(op) && op.children.exists(isRedundantShuffle) =>
146+ val newChildren = op.children.map {
147+ case s : CometShuffleExchangeExec
148+ if s.shuffleType == CometColumnarShuffle && isAggregate(s.child) =>
149+ val reverted =
150+ s.originalPlan.withNewChildren(Seq (s.child)).asInstanceOf [ShuffleExchangeExec ]
151+ reverted.setTagValue(CometExecRule .SKIP_COMET_SHUFFLE_TAG , ())
152+ logInfo(
153+ "Reverting Comet columnar shuffle to Spark shuffle between " +
154+ s " ${op.getClass.getSimpleName} and ${s.child.getClass.getSimpleName} " +
155+ " (no Comet operator on either side to consume columnar output)" )
156+ reverted
157+ case other => other
158+ }
159+ op.withNewChildren(newChildren)
160+ }
161+ }
162+
163+ private def shouldSkipCometShuffle (s : ShuffleExchangeExec ): Boolean =
164+ s.getTagValue(CometExecRule .SKIP_COMET_SHUFFLE_TAG ).isDefined
165+
103166 private def applyCometShuffle (plan : SparkPlan ): SparkPlan = {
104- plan.transformUp { case s : ShuffleExchangeExec =>
105- CometShuffleExchangeExec .shuffleSupported(s) match {
106- case Some (CometNativeShuffle ) =>
107- // Switch to use Decimal128 regardless of precision, since Arrow native execution
108- // doesn't support Decimal32 and Decimal64 yet.
109- conf.setConfString(CometConf .COMET_USE_DECIMAL_128 .key, "true" )
110- CometShuffleExchangeExec (s, shuffleType = CometNativeShuffle )
111- case Some (CometColumnarShuffle ) =>
112- CometShuffleExchangeExec (s, shuffleType = CometColumnarShuffle )
113- case None =>
114- s
115- }
167+ plan.transformUp {
168+ case s : ShuffleExchangeExec if shouldSkipCometShuffle(s) =>
169+ s
170+ case s : ShuffleExchangeExec =>
171+ CometShuffleExchangeExec .shuffleSupported(s) match {
172+ case Some (CometNativeShuffle ) =>
173+ // Switch to use Decimal128 regardless of precision, since Arrow native execution
174+ // doesn't support Decimal32 and Decimal64 yet.
175+ conf.setConfString(CometConf .COMET_USE_DECIMAL_128 .key, "true" )
176+ CometShuffleExchangeExec (s, shuffleType = CometNativeShuffle )
177+ case Some (CometColumnarShuffle ) =>
178+ CometShuffleExchangeExec (s, shuffleType = CometColumnarShuffle )
179+ case None =>
180+ s
181+ }
116182 }
117183 }
118184
@@ -261,6 +327,9 @@ case class CometExecRule(session: SparkSession)
261327 case s @ ShuffleQueryStageExec (_, ReusedExchangeExec (_, _ : CometShuffleExchangeExec ), _) =>
262328 convertToComet(s, CometExchangeSink ).getOrElse(s)
263329
330+ case s : ShuffleExchangeExec if shouldSkipCometShuffle(s) =>
331+ s
332+
264333 case s : ShuffleExchangeExec =>
265334 convertToComet(s, CometShuffleExchangeExec ).getOrElse(s)
266335
@@ -464,6 +533,13 @@ case class CometExecRule(session: SparkSession)
464533 case CometScanWrapper (_, s) => s
465534 }
466535
536+ // Revert CometColumnarShuffle to Spark's ShuffleExchangeExec when both its parent and child
537+ // are non-Comet HashAggregate/ObjectHashAggregate operators that remained JVM after the main
538+ // transform pass. See https://github.com/apache/datafusion-comet/issues/4004.
539+ if (CometConf .COMET_EXEC_SHUFFLE_REVERT_REDUNDANT_COLUMNAR_ENABLED .get()) {
540+ newPlan = revertRedundantColumnarShuffle(newPlan)
541+ }
542+
467543 // Set up logical links
468544 newPlan = newPlan.transform {
469545 case op : CometExec =>