Skip to content

Commit 5e1e49d

Browse files
committed
fix: tag reverted shuffle so AQE stage re-planning does not reintroduce Comet columnar shuffle
Without the tag, AQE re-plans each stage in isolation, and the isolated subplan (which no longer shows the parent aggregate) converts the reverted ShuffleExchangeExec back into a CometShuffleExchangeExec. Subsequent plan canonicalization then fails because a ColumnarToRowExec ends up with a non-columnar child. Persist the revert decision via a TreeNodeTag on the ShuffleExchangeExec. Both applyCometShuffle and the main transform now short-circuit when the tag is set, so the decision survives re-entrancy.
1 parent 804b120 commit 5e1e49d

1 file changed

Lines changed: 42 additions & 20 deletions

File tree

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ object CometExecRule {
8989

9090
val allExecs: Map[Class[_ <: SparkPlan], CometOperatorSerde[_]] = nativeExecs ++ sinks
9191

92+
/**
93+
* Tag set on a `ShuffleExchangeExec` that should be left as a plain Spark shuffle rather than
94+
* wrapped in `CometShuffleExchangeExec`. See `tagRedundantColumnarShuffle`.
95+
*/
96+
val SKIP_COMET_SHUFFLE_TAG: org.apache.spark.sql.catalyst.trees.TreeNodeTag[Unit] =
97+
org.apache.spark.sql.catalyst.trees.TreeNodeTag[Unit]("comet.skipCometShuffle")
98+
9299
}
93100

94101
/**
@@ -108,9 +115,13 @@ case class CometExecRule(session: SparkSession)
108115
* row->arrow->shuffle->arrow->row conversion overhead with no Comet consumer on either side.
109116
* See https://github.com/apache/datafusion-comet/issues/4004.
110117
*
111-
* The match is intentionally narrow (both sides must be row-based aggregates) so this does not
112-
* interfere with non-relational plan shapes such as object-mode Dataset plans where the shuffle
113-
* sits between encoder/serializer nodes.
118+
* The match is intentionally narrow (both sides must be row-based aggregates that remained JVM
119+
* after the main transform pass). Running the revert post-transform means we only fire when the
120+
* main conversion already decided to keep both aggregates JVM - we never create the dangerous
121+
* mixed mode where a Comet partial feeds a JVM final (see issue #1389).
122+
*
123+
* Also tag the reverted shuffle so AQE stage-isolated re-planning does not convert it back to a
124+
* Comet shuffle when the outer aggregate context is no longer visible.
114125
*/
115126
private def revertRedundantColumnarShuffle(plan: SparkPlan): SparkPlan = {
116127
def isAggregate(p: SparkPlan): Boolean =
@@ -127,26 +138,35 @@ case class CometExecRule(session: SparkSession)
127138
val newChildren = op.children.map {
128139
case s: CometShuffleExchangeExec
129140
if s.shuffleType == CometColumnarShuffle && isAggregate(s.child) =>
130-
s.originalPlan.withNewChildren(Seq(s.child)).asInstanceOf[SparkPlan]
141+
val reverted =
142+
s.originalPlan.withNewChildren(Seq(s.child)).asInstanceOf[ShuffleExchangeExec]
143+
reverted.setTagValue(CometExecRule.SKIP_COMET_SHUFFLE_TAG, ())
144+
reverted
131145
case other => other
132146
}
133147
op.withNewChildren(newChildren)
134148
}
135149
}
136150

151+
private def shouldSkipCometShuffle(s: ShuffleExchangeExec): Boolean =
152+
s.getTagValue(CometExecRule.SKIP_COMET_SHUFFLE_TAG).isDefined
153+
137154
private def applyCometShuffle(plan: SparkPlan): SparkPlan = {
138-
plan.transformUp { case s: ShuffleExchangeExec =>
139-
CometShuffleExchangeExec.shuffleSupported(s) match {
140-
case Some(CometNativeShuffle) =>
141-
// Switch to use Decimal128 regardless of precision, since Arrow native execution
142-
// doesn't support Decimal32 and Decimal64 yet.
143-
conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
144-
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
145-
case Some(CometColumnarShuffle) =>
146-
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
147-
case None =>
148-
s
149-
}
155+
plan.transformUp {
156+
case s: ShuffleExchangeExec if shouldSkipCometShuffle(s) =>
157+
s
158+
case s: ShuffleExchangeExec =>
159+
CometShuffleExchangeExec.shuffleSupported(s) match {
160+
case Some(CometNativeShuffle) =>
161+
// Switch to use Decimal128 regardless of precision, since Arrow native execution
162+
// doesn't support Decimal32 and Decimal64 yet.
163+
conf.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
164+
CometShuffleExchangeExec(s, shuffleType = CometNativeShuffle)
165+
case Some(CometColumnarShuffle) =>
166+
CometShuffleExchangeExec(s, shuffleType = CometColumnarShuffle)
167+
case None =>
168+
s
169+
}
150170
}
151171
}
152172

@@ -295,6 +315,9 @@ case class CometExecRule(session: SparkSession)
295315
case s @ ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) =>
296316
convertToComet(s, CometExchangeSink).getOrElse(s)
297317

318+
case s: ShuffleExchangeExec if shouldSkipCometShuffle(s) =>
319+
s
320+
298321
case s: ShuffleExchangeExec =>
299322
convertToComet(s, CometShuffleExchangeExec).getOrElse(s)
300323

@@ -498,10 +521,9 @@ case class CometExecRule(session: SparkSession)
498521
case CometScanWrapper(_, s) => s
499522
}
500523

501-
// Revert CometColumnarShuffle to Spark's ShuffleExchangeExec when both the parent and
502-
// the child are non-Comet (JVM) operators. In that case the Comet shuffle only adds
503-
// row->arrow->arrow->row conversion overhead with no Comet operator on either side to
504-
// benefit from columnar output. See https://github.com/apache/datafusion-comet/issues/4004.
524+
// Revert CometColumnarShuffle to Spark's ShuffleExchangeExec when sandwiched between two
525+
// non-Comet HashAggregate/ObjectHashAggregate operators that remained JVM after the main
526+
// transform pass. See https://github.com/apache/datafusion-comet/issues/4004.
505527
newPlan = revertRedundantColumnarShuffle(newPlan)
506528

507529
// Set up logical links

0 commit comments

Comments (0)